quark.torch.algorithm.rotation.rotation_utils
Module Contents
Classes
RMSNorm | Root Mean Square Layer Normalization (RMSNorm).
Functions
rotate_in_channels | Rotate the input channels of a weight matrix.
rotate_out_channels | Rotate the output channels of a weight matrix.
get_rotation_matrix | Get a random rotation matrix for the given number of channels.
- class quark.torch.algorithm.rotation.rotation_utils.RMSNorm(hidden_size: int, eps: float = 1e-06)
Root Mean Square Layer Normalization (RMSNorm).
- forward(hidden_states: torch.Tensor) → torch.Tensor
Apply RMSNorm normalization to hidden states.
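For orientation, the normalization that RMSNorm usually refers to can be sketched as below. This is a minimal illustration of the common formulation, not Quark's source: the class name `RMSNormSketch` is hypothetical, and because this page does not say whether the Quark class carries a learnable scale, the sketch omits one.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Illustrative RMSNorm without a learnable scale (an assumption)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the last (hidden) dimension.
        mean_square = hidden_states.pow(2).mean(dim=-1, keepdim=True)
        # Divide by the root mean square; eps guards against division by zero.
        return hidden_states * torch.rsqrt(mean_square + self.eps)


norm = RMSNormSketch(hidden_size=8)
x = torch.randn(2, 4, 8)
y = norm(x)  # same shape as x, each hidden vector has RMS close to 1
```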
- quark.torch.algorithm.rotation.rotation_utils.rotate_in_channels(module: torch.nn.Module, rotation: torch.Tensor) → None
Rotate the input channels of a weight matrix.
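One way to picture an input-channel rotation is shown below for a `torch.nn.Linear` layer. The helper name `rotate_in_channels_sketch` and the choice to right-multiply the weight by the rotation are assumptions made for illustration; the Quark function may use a different convention or support other module types.

```python
import torch
from torch import nn


def rotate_in_channels_sketch(module: nn.Linear, rotation: torch.Tensor) -> None:
    # nn.Linear stores its weight as (out_features, in_features), so a
    # right multiplication mixes the input channels. Modifies in place.
    with torch.no_grad():
        module.weight.copy_(module.weight @ rotation)


torch.manual_seed(0)
linear = nn.Linear(4, 3, bias=False, dtype=torch.float64)
# An orthogonal matrix from the QR decomposition of a Gaussian matrix.
q, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
x = torch.randn(2, 4, dtype=torch.float64)

expected = linear(x)
rotate_in_channels_sketch(linear, q)
# Feeding the rotated input to the rotated layer reproduces the output.
actual = linear(x @ q)
```

Because `q` is orthogonal, the rotation applied to the input cancels against the one folded into the weight, so the layer output is unchanged.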
- quark.torch.algorithm.rotation.rotation_utils.rotate_out_channels(module: torch.nn.Module, rotation: torch.Tensor) → None
Rotate the output channels of a weight matrix.
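The output-channel counterpart can be sketched the same way. Here the hypothetical helper `rotate_out_channels_sketch` left-multiplies the weight by the transposed rotation and rotates the bias as well; both the convention and the bias handling are assumptions, not a description of the Quark implementation.

```python
import torch
from torch import nn


def rotate_out_channels_sketch(module: nn.Linear, rotation: torch.Tensor) -> None:
    # Rows of the weight correspond to output channels, so a left
    # multiplication mixes them. Modifies the module in place.
    with torch.no_grad():
        module.weight.copy_(rotation.T @ module.weight)
        if module.bias is not None:
            # The bias lives in output space and must be rotated too.
            module.bias.copy_(module.bias @ rotation)


torch.manual_seed(0)
linear = nn.Linear(4, 3, bias=True, dtype=torch.float64)
q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
x = torch.randn(2, 4, dtype=torch.float64)

expected = linear(x) @ q  # rotate the original output
rotate_out_channels_sketch(linear, q)
actual = linear(x)        # the rotated layer emits the rotated output
```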
- quark.torch.algorithm.rotation.rotation_utils.get_rotation_matrix(num_channels: int, random: bool = True) → torch.Tensor
Get a random rotation matrix for the given number of channels.
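A common way to draw a random orthogonal matrix is the QR decomposition of a Gaussian matrix, sketched below. The helper name `random_orthogonal_sketch` is hypothetical, and this page does not state how `get_rotation_matrix` builds its result or what it returns when `random=False`, so treat this only as an illustration of the idea.

```python
import torch


def random_orthogonal_sketch(num_channels: int) -> torch.Tensor:
    gaussian = torch.randn(num_channels, num_channels, dtype=torch.float64)
    q, r = torch.linalg.qr(gaussian)
    # Flip column signs so that R would have a positive diagonal; this is
    # the usual correction for sampling orthogonal matrices uniformly.
    return q * torch.sign(torch.diagonal(r)).unsqueeze(0)


rotation = random_orthogonal_sketch(8)
identity = torch.eye(8, dtype=torch.float64)
# Orthogonality check: rotation @ rotation.T should be the identity.
is_orthogonal = torch.allclose(rotation @ rotation.T, identity, atol=1e-10)
```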