quark.torch.algorithm.rotation.rotation_utils
Module Contents
Classes
Functions
- class quark.torch.algorithm.rotation.rotation_utils.RMSNorm(hidden_size: int, eps: float = 1e-06)
Root Mean Square Layer Normalization (RMSNorm).
- forward(hidden_states: torch.Tensor) → torch.Tensor
Apply RMSNorm to the hidden states.
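As an illustration of the normalization itself, here is a minimal sketch in plain PyTorch. It does not call Quark; `rms_norm_sketch` is a hypothetical helper, and it omits any learnable scale because this page does not say whether the class carries one. It assumes normalization over the last dimension.

```python
import torch


def rms_norm_sketch(hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: scale each vector by the reciprocal of its RMS."""
    # Mean of squares over the last (hidden) dimension.
    mean_square = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    # eps keeps the reciprocal square root finite for all-zero inputs.
    return hidden_states * torch.rsqrt(mean_square + eps)


x = torch.tensor([[3.0, 4.0], [0.5, -0.5]])
y = rms_norm_sketch(x)
# After normalization each row has a root mean square close to 1.
row_rms = y.pow(2).mean(dim=-1).sqrt()
```

Unlike standard layer normalization, this form does not subtract the mean, so it only rescales each vector.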
- quark.torch.algorithm.rotation.rotation_utils.rotate_in_channels(weight: torch.nn.Parameter, /, *, rotation: torch.Tensor) → None
Rotate the input channels of a weight matrix.
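The idea can be sketched in plain PyTorch, without calling Quark. The sketch assumes the `torch.nn.Linear` weight layout `(out_features, in_features)` and the convention `W <- W @ R`; the exact convention Quark uses (for example `R` versus its transpose) is not stated on this page. `rotate_in_channels_sketch` is a hypothetical name.

```python
import torch


def rotate_in_channels_sketch(weight: torch.nn.Parameter, *, rotation: torch.Tensor) -> None:
    """Hypothetical helper: in-place W <- W @ R for a (out, in) weight."""
    with torch.no_grad():
        weight.copy_(weight @ rotation.to(weight.dtype))


torch.manual_seed(0)
linear = torch.nn.Linear(4, 3, bias=False).double()
# The Q factor of a QR decomposition is orthogonal: Q.T @ Q = I.
rotation, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))

x = torch.randn(2, 4, dtype=torch.float64)
expected = linear(x).detach()

rotate_in_channels_sketch(linear.weight, rotation=rotation)
# Feeding rotated inputs to the rotated layer reproduces the original output,
# because (x R)(W R)^T = x R R^T W^T = x W^T.
rotated_output = linear(x @ rotation).detach()
```

Under these assumptions the rotation cancels as long as whatever produces the layer's input is rotated by the same matrix.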
- quark.torch.algorithm.rotation.rotation_utils.rotate_out_channels(weight: torch.nn.Parameter, /, *, rotation: torch.Tensor, bias: Optional[torch.nn.Parameter] = None) → None
Rotate the output channels of a weight matrix.
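A rough sketch of an output-channel rotation in plain PyTorch follows. It assumes a `(out_features, in_features)` weight and the convention `W <- R^T @ W`, `b <- R^T @ b`; Quark's actual convention is not stated on this page, and `rotate_out_channels_sketch` is a hypothetical name.

```python
from typing import Optional

import torch


def rotate_out_channels_sketch(
    weight: torch.nn.Parameter,
    *,
    rotation: torch.Tensor,
    bias: Optional[torch.nn.Parameter] = None,
) -> None:
    """Hypothetical helper: in-place W <- R^T @ W and, if given, b <- R^T @ b."""
    with torch.no_grad():
        rot_t = rotation.to(weight.dtype).T
        weight.copy_(rot_t @ weight)
        if bias is not None:
            bias.copy_(rot_t @ bias)


torch.manual_seed(0)
linear = torch.nn.Linear(4, 3, bias=True).double()
# Orthogonal matrix sized to the number of output channels.
rotation, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))

x = torch.randn(2, 4, dtype=torch.float64)
original_output = linear(x).detach()

rotate_out_channels_sketch(linear.weight, rotation=rotation, bias=linear.bias)
# The layer now emits the original output expressed in the rotated basis.
rotated_output = linear(x).detach()
```

With this convention the new output equals `original_output @ rotation`, which is why the bias has to be rotated together with the weight.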
- quark.torch.algorithm.rotation.rotation_utils.get_rotation_matrix(num_channels: int, random: bool = True) → torch.Tensor
Get a rotation matrix for the given number of channels. With the default random=True, the matrix is random.
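For intuition, one common way to draw a random orthogonal matrix is to take the Q factor of a QR decomposition of a Gaussian matrix. The sketch below uses that technique in plain PyTorch; it is not necessarily the construction Quark uses, and `random_orthogonal_matrix_sketch` and its `seed` parameter are hypothetical.

```python
import torch


def random_orthogonal_matrix_sketch(num_channels: int, seed: int = 0) -> torch.Tensor:
    """Hypothetical helper: random orthogonal matrix from a QR decomposition."""
    generator = torch.Generator().manual_seed(seed)
    gaussian = torch.randn(
        num_channels, num_channels, generator=generator, dtype=torch.float64
    )
    # Q has orthonormal columns, so Q.T @ Q is the identity.
    q, _ = torch.linalg.qr(gaussian)
    return q


rotation = random_orthogonal_matrix_sketch(8)
identity = torch.eye(8, dtype=torch.float64)
# Largest deviation of R^T R from the identity; should be near machine precision.
max_error = (rotation.T @ rotation - identity).abs().max().item()
```

Orthogonality is the property that lets a rotation applied to one layer's outputs be undone by the matching rotation on the next layer's inputs.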