@@ -84,9 +84,7 @@ def __init__(
84
84
axis : Union [Iterable [int ], int ] = - 1 ,
85
85
weight_dtype : DType = jnp .float32 ,
86
86
dtype : DType = jnp .float32 ,
87
- kernel_init : NdInitializer = nd_dense_init (
88
- 1.0 , "fan_in" , "truncated_normal"
89
- ),
87
+ kernel_init : NdInitializer = nd_dense_init (1.0 , "fan_in" , "truncated_normal" ),
90
88
kernel_axes : Tuple [Optional [str ], ...] = (),
91
89
quant : Optional [Quant ] = None ,
92
90
use_bias : bool = False ,
@@ -127,9 +125,7 @@ def __init__(
127
125
# Parameter initialization
128
126
kernel_shape = self .in_features + self .out_features
129
127
kernel_in_axis = np .arange (len (self .axis ))
130
- kernel_out_axis = np .arange (
131
- len (self .axis ), len (self .axis ) + len (self .out_features )
132
- )
128
+ kernel_out_axis = np .arange (len (self .axis ), len (self .axis ) + len (self .out_features ))
133
129
134
130
self .kernel = nnx .Param (
135
131
self .kernel_init (
@@ -217,9 +213,7 @@ def dense_general(
217
213
axis : Union [Iterable [int ], int ] = - 1 ,
218
214
weight_dtype : DType = jnp .float32 ,
219
215
dtype : DType = jnp .float32 ,
220
- kernel_init : NdInitializer = nd_dense_init (
221
- 1.0 , "fan_in" , "truncated_normal"
222
- ),
216
+ kernel_init : NdInitializer = nd_dense_init (1.0 , "fan_in" , "truncated_normal" ),
223
217
kernel_axes : Tuple [Optional [str ], ...] = (),
224
218
quant : Optional [Quant ] = None ,
225
219
use_bias : bool = False ,
@@ -246,15 +240,11 @@ def dense_general(
246
240
name: name passed to the ToLinen Module
247
241
"""
248
242
if not (inputs_shape is not None ) ^ (in_features is not None ):
249
- raise ValueError (
250
- "Exactly one of inputs_shape or in_features must be specified."
251
- )
243
+ raise ValueError ("Exactly one of inputs_shape or in_features must be specified." )
252
244
253
245
if inputs_shape is not None :
254
246
axis = _canonicalize_tuple (axis )
255
- in_features = tuple (
256
- inputs_shape [ax ] for ax in _normalize_axes (axis , len (inputs_shape ))
257
- )
247
+ in_features = tuple (inputs_shape [ax ] for ax in _normalize_axes (axis , len (inputs_shape )))
258
248
else :
259
249
assert in_features is not None
260
250
module = nnx .bridge .to_linen (
@@ -400,4 +390,3 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
400
390
401
391
output = checkpoint_name (output , "mlpwo" )
402
392
return output
403
-
0 commit comments