
Commit f4a5e24: "make style"
Parent: 7892744

File tree: 3 files changed (+7, -21 lines)

MaxText/common_types.py (1 addition & 1 deletion, apparently whitespace-only)

@@ -81,4 +81,4 @@ class DecoderBlockType(enum.Enum):
   SIMPLE = "simple"
   SIMPLE_MLP = "simple_mlp"
   LLAMA4 = "llama4"
-  QWEN3 = "qwen3"
+  QWEN3 = "qwen3"
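For context, here is a minimal standalone sketch of how a DecoderBlockType value such as the QWEN3 member above is typically resolved from a config string. The get_decoder_block helper is hypothetical, for illustration only, and is not MaxText's actual dispatch code:

import enum


class DecoderBlockType(enum.Enum):
  SIMPLE = "simple"
  SIMPLE_MLP = "simple_mlp"
  LLAMA4 = "llama4"
  QWEN3 = "qwen3"


def get_decoder_block(name: str) -> DecoderBlockType:
  """Parses a config string such as "qwen3" into an enum member."""
  try:
    return DecoderBlockType(name)  # enum lookup by value
  except ValueError as e:
    raise ValueError(f"Unknown decoder block: {name!r}") from e


assert get_decoder_block("qwen3") is DecoderBlockType.QWEN3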

MaxText/layers/linears.py (5 additions & 16 deletions)

@@ -84,9 +84,7 @@ def __init__(
       axis: Union[Iterable[int], int] = -1,
       weight_dtype: DType = jnp.float32,
       dtype: DType = jnp.float32,
-      kernel_init: NdInitializer = nd_dense_init(
-          1.0, "fan_in", "truncated_normal"
-      ),
+      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
       kernel_axes: Tuple[Optional[str], ...] = (),
       quant: Optional[Quant] = None,
       use_bias: bool = False,
@@ -127,9 +125,7 @@ def __init__(
     # Parameter initialization
     kernel_shape = self.in_features + self.out_features
     kernel_in_axis = np.arange(len(self.axis))
-    kernel_out_axis = np.arange(
-        len(self.axis), len(self.axis) + len(self.out_features)
-    )
+    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features))

     self.kernel = nnx.Param(
         self.kernel_init(
@@ -217,9 +213,7 @@ def dense_general(
     axis: Union[Iterable[int], int] = -1,
     weight_dtype: DType = jnp.float32,
     dtype: DType = jnp.float32,
-    kernel_init: NdInitializer = nd_dense_init(
-        1.0, "fan_in", "truncated_normal"
-    ),
+    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
     kernel_axes: Tuple[Optional[str], ...] = (),
     quant: Optional[Quant] = None,
     use_bias: bool = False,
@@ -246,15 +240,11 @@ def dense_general(
     name: name passed to the ToLinen Module
   """
   if not (inputs_shape is not None) ^ (in_features is not None):
-    raise ValueError(
-        "Exactly one of inputs_shape or in_features must be specified."
-    )
+    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")

   if inputs_shape is not None:
     axis = _canonicalize_tuple(axis)
-    in_features = tuple(
-        inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape))
-    )
+    in_features = tuple(inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape)))
   else:
     assert in_features is not None
   module = nnx.bridge.to_linen(
@@ -400,4 +390,3 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):

   output = checkpoint_name(output, "mlpwo")
   return output
-
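Two patterns in the hunks above are worth spelling out. First, (inputs_shape is not None) ^ (in_features is not None) is an exclusive-or on booleans, so the surrounding not rejects both "neither argument given" and "both given". Second, the kernel axis bookkeeping places the input contraction axes first and the output feature axes after them. A self-contained sketch of both, where resolve_in_features is a hypothetical helper and the modulo stands in for _normalize_axes:

from typing import Optional, Sequence, Tuple

import numpy as np


def resolve_in_features(
    inputs_shape: Optional[Sequence[int]] = None,
    in_features: Optional[Tuple[int, ...]] = None,
    axis: Tuple[int, ...] = (-1,),
) -> Tuple[int, ...]:
  # Exactly-one-of check: XOR is true only when one side is given.
  if not (inputs_shape is not None) ^ (in_features is not None):
    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")
  if inputs_shape is not None:
    ndim = len(inputs_shape)
    # Map negative axes such as -1 to non-negative indices.
    in_features = tuple(inputs_shape[ax % ndim] for ax in axis)
  return in_features


assert resolve_in_features(inputs_shape=(4, 8, 16)) == (16,)
assert resolve_in_features(in_features=(32,)) == (32,)

# Kernel axis layout, as in __init__: input contraction axes first,
# then the output feature axes.
axis, out_features = (-1,), (8, 2)
kernel_in_axis = np.arange(len(axis))                                  # [0]
kernel_out_axis = np.arange(len(axis), len(axis) + len(out_features))  # [1, 2]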

MaxText/layers/models.py (1 addition & 4 deletions)

@@ -671,9 +671,7 @@ def __call__(
         inputs_shape=y.shape,
         features=cfg.vocab_size,
         weight_dtype=cfg.weight_dtype,
-        dtype=jnp.float32
-        if cfg.logits_dot_in_fp32
-        else cfg.dtype,  # for logit training stability
+        dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
         kernel_axes=("embed", "vocab"),
         name="logits_dense",
         matmul_precision=self.config.matmul_precision,
@@ -809,4 +807,3 @@ def __call__(
         image_embeddings=image_embeddings,
     )
     return logits
-
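The first hunk above keeps a conditional-dtype idiom on one line: when cfg.logits_dot_in_fp32 is set, the logits matmul runs in float32 even if the rest of the model computes in a lower precision such as bfloat16, which keeps the downstream softmax and cross-entropy numerically stable. A minimal sketch of the same pattern, with stand-in names for the config values:

import jax.numpy as jnp

logits_dot_in_fp32 = True        # stands in for cfg.logits_dot_in_fp32
activation_dtype = jnp.bfloat16  # stands in for cfg.dtype

y = jnp.ones((2, 4), dtype=activation_dtype)       # decoder output
kernel = jnp.ones((4, 8), dtype=activation_dtype)  # embed x vocab table

dot_dtype = jnp.float32 if logits_dot_in_fp32 else activation_dtype
logits = jnp.dot(y.astype(dot_dtype), kernel.astype(dot_dtype))
assert logits.dtype == jnp.float32  # full-precision logits for the loss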
