Commit 68a6247

fix llama and baichuan typo (#1883)
Parent: f007978

3 files changed (+6, -5 lines)

.github/pylint.conf (+2, -1)
@@ -217,7 +217,8 @@ disable=raw-checker-failed,
         fixme,
         use-a-generator,
         nested-min-max,
-        method-hidden
+        method-hidden,
+        unsubscriptable-object
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
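
For reference, the unsubscriptable-object message (E1136) can also be suppressed per line instead of being disabled globally; a minimal illustration (not code from this repository):

    """Illustration only: per-line suppression of E1136 (unsubscriptable-object),
    as an alternative to the global disable added to pylint.conf above."""

    def first_element(values):
        # If pylint cannot infer that `values` is subscriptable, it may report
        # E1136 here; the trailing comment silences only this line.
        return values[0]  # pylint: disable=unsubscriptable-object

    print(first_element([1, 2, 3]))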

mindnlp/transformers/models/baichuan/modeling_baichuan.py (+3, -3)
@@ -1550,11 +1550,11 @@ def forward(
         if attention_mask is not None:
             if len(attention_mask.shape) == 2:
                 expanded_mask = attention_mask.to(alibi_mask.dtype)
-                expanded_mask = ops.tril(ops.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
-                                         ) * ops.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
+                expanded_mask = ops.tril((ops.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
+                                         ) * ops.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0).int()).bool()
             else:
                 expanded_mask = attention_mask
-            bsz = inputs_embeds.size(0)
+            bsz = inputs_embeds.shape[0]
             src_len, tgt_len = alibi_mask.shape[-2:]
             expanded_mask = expanded_mask.unsqueeze(1).broadcast_to((bsz, 1, src_len, tgt_len)).to(alibi_mask.dtype)
             inverted_mask = 1.0 - expanded_mask
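
The reworked expression groups the whole product inside ops.tril, casts the equality mask to int for the multiplication, and converts the result back to bool; the bsz line drops the torch-style .size(0) in favor of .shape[0]. A minimal NumPy sketch (not the MindSpore code path) of what the mask arithmetic computes, using a hypothetical 2-D mask with illustrative values:

    import numpy as np

    # Hypothetical 2-D mask: one batch row, values 1,1,2,2 (0 would mark padding).
    m = np.array([[1, 1, 2, 2]], dtype=np.float32)

    # Lower-triangular (causal) positions where both entries are non-zero,
    # restricted to pairs with equal mask values.
    expanded = np.tril((m[:, :, None] * m[:, None, :]) > 0) * (m[:, :, None] - m[:, None, :] == 0)
    print(expanded.astype(int))
    # [[[1 0 0 0]
    #   [1 1 0 0]
    #   [0 0 1 0]
    #   [0 0 1 1]]]

The set of retained positions is the same as before the fix, since zeroing the upper triangle commutes with the element-wise product; the regrouping and int/bool casts only change the dtypes the operations see.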

mindnlp/transformers/models/llama/modeling_llama.py (+1, -1)
@@ -300,7 +300,7 @@ def forward(self, x):
             )
             up_proj = ops.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
 
-            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+            intermediate_states = ops.split((self.act_fn(gate_proj) * up_proj), slice, dim=2)
             down_proj = [
                 F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
             ]
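
The fix swaps the tensor-method split for the functional ops.split, chunking the gated activation into pretraining_tp pieces of size `slice` along the last dimension. A minimal sketch of the split-by-chunk-size semantics with illustrative shapes, assuming MindSpore >= 2.x (the modeling code goes through mindnlp's ops wrapper, which accepts dim=):

    import numpy as np
    import mindspore
    from mindspore import ops

    # (batch, seq_len, intermediate_size); chunk plays the role of `slice`
    # = intermediate_size // pretraining_tp in the modeling code.
    x = mindspore.Tensor(np.ones((2, 4, 8), dtype=np.float32))
    chunk = 4
    pieces = ops.split(x, chunk, axis=2)  # tuple of 2 tensors, each (2, 4, 4)
    print(len(pieces), pieces[0].shape)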
