
Commit 177a8de

Value update for mask
Signed-off-by: Amit Raj <[email protected]>
1 parent c04ba8e commit 177a8de

3 files changed: +3 -6 lines changed


QEfficient/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 1 addition & 2 deletions
@@ -31,10 +31,9 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
     # if only "normal" attention layer implements causal mask
     query_length, key_length = query.size(-2), key.size(-2)
     causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
-    mask_value = MIN_MASKED_ATTENTION_VALUE
     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+    mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

     if attention_mask is not None:
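For readers skimming the hunk, a minimal self-contained sketch of the pattern it touches (my own, with made-up shapes; only MIN_MASKED_ATTENTION_VALUE, torch.full, and torch.where come from the code above): the fill value is built as a 0-dim tensor so it matches the dtype and device of attn_weights, then torch.where applies the causal mask.

# Sketch only -- shapes and the causal mask construction are illustrative,
# not copied from modeling_gpt2.py.
import torch

MIN_MASKED_ATTENTION_VALUE = float("-inf")  # value of the constant after this commit

batch, heads, q_len, k_len = 1, 2, 4, 4
attn_weights = torch.randn(batch, heads, q_len, k_len)

# Stand-in for module.bias sliced to the current query/key lengths.
causal_mask = torch.tril(torch.ones(q_len, k_len)).bool().view(1, 1, q_len, k_len)

# 0-dim tensor so the fill value shares attn_weights' dtype and device,
# avoiding the float/double and device-mismatch RuntimeErrors noted in the comments.
mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)

attn_weights = torch.where(causal_mask, attn_weights, mask_value)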

QEfficient/transformers/models/mllama/modeling_mllama.py

Lines changed: 1 addition & 3 deletions
@@ -179,9 +179,7 @@ def forward(
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
-            # attn_weights = torch.where(
-            #     attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
-            # )
+

         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
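The mllama path masks additively: causal_mask already carries the large-negative fill values, so adding it and softmaxing is enough, and the commented-out torch.where variant deleted here was a redundant leftover. A small sketch of that pattern (my own shapes; in the real code causal_mask is a slice of attention_mask):

# Sketch of additive masking; shapes are illustrative.
import torch
import torch.nn as nn

MIN_MASKED_ATTENTION_VALUE = float("-inf")

attn_weights = torch.randn(1, 1, 4, 4)

# Additive mask: 0 where attention is allowed, the fill value where it is not.
causal_mask = torch.triu(
    torch.full((4, 4), MIN_MASKED_ATTENTION_VALUE), diagonal=1
).view(1, 1, 4, 4)

attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
# Masked positions now carry exactly zero weight.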

QEfficient/utils/constants.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 ONNX_EXPORT_CTX_LEN = 1024

 # Minimum value for causal mask
-MIN_MASKED_ATTENTION_VALUE = -1e4
+MIN_MASKED_ATTENTION_VALUE = float("-inf")


 # Store the qeff_models inside the ~/.cache directory or over-ride with an env variable.
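A quick illustration (mine, not part of the commit) of what the constant change does downstream: used as the fill value for masked attention scores, float("-inf") drives the post-softmax weight of masked positions to exactly zero regardless of the scale of the real scores, whereas a finite sentinel like the old -1e4 only dominates while the unmasked scores sit well above it.

# Illustration only: effect of the mask fill value after softmax.
import torch

MIN_MASKED_ATTENTION_VALUE = float("-inf")

# One attention row with the last position masked out.
scores = torch.tensor([[2.5, 0.5, MIN_MASKED_ATTENTION_VALUE]])
probs = torch.softmax(scores, dim=-1)
print(probs)  # tensor([[0.8808, 0.1192, 0.0000]]) -- masked slot is exactly zero

# Caveat: if every position in a row is -inf, softmax yields NaN,
# while a finite sentinel such as -1e4 would fall back to a uniform row.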
