The documentation mentions that:

rms_scaling: If True, center and scale are ignored, and the inputs are scaled by gamma and the inverse square root of the square of all inputs. This is an approximate and faster approach that avoids ever computing the mean of the input.
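For reference, here is a minimal, self-contained NumPy sketch of what that docstring describes: pure RMS scaling, where the mean of the input is never computed. The function name and defaults are illustrative, not taken from the Keras source:

    import numpy as np

    def rms_scale(x, gamma, axis=-1, epsilon=1e-6):
        # Documented behavior: scale by gamma and the inverse square root
        # of the mean of the squared inputs; no mean subtraction anywhere.
        mean_square = np.mean(np.square(x), axis=axis, keepdims=True)
        return x * gamma / np.sqrt(mean_square + epsilon)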
However, in the implementation, it actually does the following:
if self.rms_scaling:
    # Calculate outputs with only variance and gamma if rms scaling
    # is enabled
    # Calculate the variance along self.axis (layer activations).
    variance = ops.var(inputs, axis=self.axis, keepdims=True)
    inv = ops.rsqrt(variance + self.epsilon)
    outputs = (
        inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
    )
So the mean is indeed used: what is computed here is the variance (which subtracts the mean internally), not the root mean square.
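A quick way to see the difference: var(x) = mean(x**2) - mean(x)**2, so the variance-based scaling and the documented mean-square scaling only coincide when the input has zero mean. A small illustrative check (the numbers are just an example, not from the issue):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    variance = np.var(x)                 # 1.25 -- the mean (2.5) is subtracted internally
    mean_square = np.mean(np.square(x))  # 7.5  -- no mean involved
    # The current code scales by rsqrt(variance + epsilon), while the
    # docstring describes scaling by rsqrt(mean_square + epsilon).
    print(variance, mean_square)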
There was also a discussion during the addition of RMS Normalization (#20911 (comment)) that confirms this behavior.
I think the docs could use an update to clarify this behavior. Right now, it sounds like the mean isn't used when rms_scaling is on, but the code suggests otherwise.
Alternatively, the implementation could be adjusted to match the docs, though that would make LayerNormalization with rms_scaling=True behave just like RMSNormalization, correct?
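If the implementation were changed that way, one rough sanity check of the resulting equivalence (assuming keras.layers.RMSNormalization from #20911 applies the usual x * rsqrt(mean(x**2) + epsilon) * scale formula) could look like the sketch below; differences in the default epsilon and in gamma/scale initialization would still need to be accounted for:

    import numpy as np
    import keras

    x = np.random.randn(2, 8).astype("float32")

    ln = keras.layers.LayerNormalization(rms_scaling=True)
    rms = keras.layers.RMSNormalization()

    # With a docs-matching implementation these outputs would be expected
    # to agree (up to the epsilon defaults noted above).
    print(np.allclose(ln(x), rms(x), atol=1e-3))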