quark.torch.algorithm.rotation.rotation_utils

Module Contents

Classes

Functions

class quark.torch.algorithm.rotation.rotation_utils.RMSNorm(hidden_size: int, eps: float = 1e-06)

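Root Mean Square Layer Normalization (RMSNorm). Unlike LayerNorm, RMSNorm does not subtract the mean. It divides each hidden-state vector by its root mean square over the feature dimension of size hidden_size, and eps is added for numerical stability.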

forward(hidden_states: torch.Tensor) → torch.Tensor

Apply RMS normalization to hidden_states and return a tensor of the same shape.
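The computation RMSNorm performs can be sketched in a few lines. This is a minimal sketch under stated assumptions, not Quark's source: the class name RMSNormSketch is hypothetical, and it assumes a learnable per-channel scale initialized to ones plus float32 accumulation, as in common Llama-style implementations.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Hypothetical stand-in for RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Assumption: one learnable scale per channel, initialized to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Accumulate the statistic in float32 for numerical stability.
        x = hidden_states.to(torch.float32)
        # Mean of squares over the last (feature) dimension; no mean subtraction.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)
        return self.weight * x.to(input_dtype)


norm = RMSNormSketch(hidden_size=4)
out = norm(torch.randn(2, 3, 4))
print(out.shape)  # torch.Size([2, 3, 4])
```

With the scale left at ones, this normalization commutes with any orthogonal rotation Q of the feature dimension, because x @ Q has the same root mean square as x. Rotation-based quantization schemes such as QuaRot rely on this property.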

quark.torch.algorithm.rotation.rotation_utils.rotate_in_channels(weight: torch.nn.Parameter, /, *, rotation: torch.Tensor) → None

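Rotate the input channels of a weight matrix in place by the given rotation matrix. weight is positional-only and rotation is keyword-only. The function returns None, so the rotated values are written back into weight.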
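Assume the usual nn.Linear layout, where weight has shape (out_features, in_features) and rotation is an orthogonal in_features x in_features matrix Q. Under that assumption, rotating the input channels means W ← W Q, so a layer fed rotated activations x Q produces x Q Q^T W^T = x W^T, the original output. The sketch below illustrates that idea and is not Quark's implementation. The name rotate_in_channels_sketch is hypothetical, and the weight layout and float64 accumulation are assumptions.

```python
import torch
from torch import nn


def rotate_in_channels_sketch(weight: nn.Parameter, /, *, rotation: torch.Tensor) -> None:
    """Hypothetical sketch: in-place W <- W @ Q for an (out_features, in_features) weight."""
    with torch.no_grad():
        # Multiply in float64 to limit rounding error, then cast back to the weight's dtype.
        rotated = weight.to(torch.float64) @ rotation.to(torch.float64)
        weight.copy_(rotated.to(weight.dtype))


torch.manual_seed(0)
linear = nn.Linear(8, 4, bias=False)
q, _ = torch.linalg.qr(torch.randn(8, 8, dtype=torch.float64))  # random orthogonal Q
x = torch.randn(3, 8)
expected = linear(x)
rotate_in_channels_sketch(linear.weight, rotation=q)
# Feeding rotated activations x @ Q reproduces the original output.
actual = linear(x @ q.float())
print(torch.allclose(actual, expected, atol=1e-5))  # True
```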

quark.torch.algorithm.rotation.rotation_utils.rotate_out_channels(weight: torch.nn.Parameter, /, *, rotation: torch.Tensor, bias: Optional[torch.nn.Parameter] = None) → None

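Rotate the output channels of a weight matrix in place by the given rotation matrix. If the layer has a bias, pass it as bias so that it is rotated consistently with the output channels.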
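Under the same assumed nn.Linear layout, with Q an orthogonal out_features x out_features matrix, rotating the output channels can be written as W ← Q^T W and b ← Q^T b. The layer then emits y Q instead of y. The bias must be rotated too because it is indexed by output channel. This is a minimal sketch, not Quark's code. rotate_out_channels_sketch is a hypothetical name, and the layout and float64 accumulation are assumptions.

```python
from typing import Optional

import torch
from torch import nn


def rotate_out_channels_sketch(
    weight: nn.Parameter,
    /,
    *,
    rotation: torch.Tensor,
    bias: Optional[nn.Parameter] = None,
) -> None:
    """Hypothetical sketch: in-place W <- Q^T @ W and b <- Q^T @ b for an (out, in) weight."""
    with torch.no_grad():
        q = rotation.to(torch.float64)
        weight.copy_((q.T @ weight.to(torch.float64)).to(weight.dtype))
        if bias is not None:
            # The bias lives in the output space, so it gets the same rotation.
            bias.copy_((q.T @ bias.to(torch.float64)).to(bias.dtype))


torch.manual_seed(0)
layer = nn.Linear(8, 6)
q, _ = torch.linalg.qr(torch.randn(6, 6, dtype=torch.float64))  # random orthogonal Q
x = torch.randn(3, 8)
y = layer(x)
rotate_out_channels_sketch(layer.weight, rotation=q, bias=layer.bias)
# The layer now emits rotated outputs: y' = y @ Q.
print(torch.allclose(layer(x), y @ q.float(), atol=1e-5))  # True
```

Under this sketch's conventions, a downstream layer consuming these outputs would need its input channels rotated by the same Q (W2 ← W2 Q) to cancel the rotation. That is how rotate_out_channels and rotate_in_channels would pair up to leave the network's function unchanged.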

quark.torch.algorithm.rotation.rotation_utils.get_rotation_matrix(num_channels: int, random: bool = True) → torch.Tensor

Get a num_channels × num_channels rotation (orthogonal) matrix. The matrix is randomized when random is True (the default).
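One standard way to draw a random orthogonal matrix is to QR-factorize a Gaussian matrix and correct the column signs so the result is Haar-distributed. get_rotation_matrix may or may not do this internally. The sketch covers only the random case, and random_rotation_sketch and its generator argument are hypothetical. Many rotation-based quantization methods, such as QuaRot, use randomized Hadamard matrices instead because they can be applied quickly.

```python
from typing import Optional

import torch


def random_rotation_sketch(num_channels: int, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical sketch: sample a random orthogonal (num_channels x num_channels) matrix."""
    gaussian = torch.randn(num_channels, num_channels, dtype=torch.float64, generator=generator)
    q, r = torch.linalg.qr(gaussian)
    # Flip column signs by sign(diag(R)) so Q is Haar-distributed rather than biased by QR's sign convention.
    # The result is orthogonal; its determinant may be -1, which does not affect the invariance argument.
    return q * torch.sign(torch.diagonal(r)).unsqueeze(0)


gen = torch.Generator().manual_seed(0)
q = random_rotation_sketch(16, generator=gen)
print(torch.allclose(q.T @ q, torch.eye(16, dtype=torch.float64)))  # True
```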