kernels.lora
Module defining Triton kernels for Low-Rank Adaptation (LoRA).
See “LoRA: Low-Rank Adaptation of Large Language Models” (https://arxiv.org/abs/2106.09685).
Also supports DoRA: see “DoRA: Weight-Decomposed Low-Rank Adaptation” (https://arxiv.org/abs/2402.09353).
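For reference, DoRA (per the paper above) decomposes the adapted weight into a learned magnitude vector m and a normalized direction, where ΔW is the usual LoRA update and ‖·‖_c denotes the column-wise vector norm:

```latex
W' = m \odot \frac{W_0 + \Delta W}{\lVert W_0 + \Delta W \rVert_c},
\qquad \Delta W = \frac{\alpha}{r}\, B A
```

Here s = α/r is the LoRA scaling factor that appears throughout this module.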
Credit to unsloth (https://unsloth.ai/) for inspiration for this implementation.
Classes
| Name | Description |
|---|---|
| LoRA_Embedding | Fused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T. |
| LoRA_MLP | Optimized LoRA MLP implementation. |
| LoRA_O | Optimized LoRA implementation for output projection. |
| LoRA_QKV | Optimized LoRA QKV implementation with quantization support. |
LoRA_Embedding
kernels.lora.LoRA_Embedding()
Fused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.
Supports dropout and DoRA.
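A minimal, unfused sketch of the embedding formula above in plain Python (the real class fuses this into Triton kernels; `lora_embedding` here is an illustrative reference, not part of the module). Shapes follow PEFT's embedding LoRA: W is [vocab, dim], A is [rank, vocab], B is [dim, rank].

```python
def lora_embedding(x, W, A, B, s):
    """Reference LoRA embedding: W[tok] + s * (A^T[tok] @ B^T) per token id."""
    rank, dim = len(A), len(B)
    out = []
    for tok in x:
        base = W[tok]                              # F.embedding(x, W): row lookup
        a_row = [A[r][tok] for r in range(rank)]   # F.embedding(x, A^T): column of A
        delta = [sum(B[d][r] * a_row[r] for r in range(rank)) for d in range(dim)]
        out.append([base[d] + s * delta[d] for d in range(dim)])
    return out

# Tiny example: vocab=3, dim=2, rank=1
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
A = [[1.0, 2.0, 3.0]]          # [rank, vocab]
B = [[0.5], [0.25]]            # [dim, rank]
print(lora_embedding([0, 2], W, A, B, s=2.0))
# token 0: [1 + 2*0.5*1, 0 + 2*0.25*1] = [2.0, 0.5]
```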
LoRA_MLP
kernels.lora.LoRA_MLP()
Optimized LoRA MLP implementation.
Supports bias, dropout, and DoRA. Dropout is applied to the input for gate/up projections. The down projection uses hidden states (post-activation) as input, so dropout is not applied there.
LoRA_O
kernels.lora.LoRA_O()
Optimized LoRA implementation for output projection.
Supports bias, dropout, and DoRA.
LoRA_QKV
kernels.lora.LoRA_QKV()
Optimized LoRA QKV implementation with quantization support.
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation). Dropout is applied outside this Function so autograd handles its backward.
Functions
| Name | Description |
|---|---|
| apply_lora_embedding | Applies LoRA to embedding layer. |
| apply_lora_mlp_geglu | Applies LoRA to MLP layer with GEGLU activation. |
| apply_lora_mlp_swiglu | Applies LoRA to MLP layer with SwiGLU activation. |
| apply_lora_o | Applies LoRA to output projection layer. |
| apply_lora_qkv | Applies LoRA to compute Query, Key, Value projections. |
| get_embedding_lora_parameters | Extracts LoRA parameters from a PEFT Embedding module. |
| get_lora_parameters | Gets LoRA parameters from a projection module. |
| matmul_lora | Efficient fused matmul + LoRA computation. |
apply_lora_embedding
kernels.lora.apply_lora_embedding(self, x)
Applies LoRA to embedding layer.
apply_lora_mlp_geglu
kernels.lora.apply_lora_mlp_geglu(self, X, inplace=True)
Applies LoRA to MLP layer with GEGLU activation.
Supports bias, dropout, and DoRA.
apply_lora_mlp_swiglu
kernels.lora.apply_lora_mlp_swiglu(self, X, inplace=True)
Applies LoRA to MLP layer with SwiGLU activation.
Supports bias, dropout, and DoRA.
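To make the fused computation concrete, here is an unfused pure-Python sketch of the SwiGLU-MLP forward with LoRA on each projection. `matmul_lora_ref` and `lora_mlp_swiglu` are illustrative stand-ins for the module's fused kernels; each projection is passed as a hypothetical `(W, A, B)` triple, and bias, dropout, and DoRA are omitted for brevity.

```python
import math

def matmul_lora_ref(x, W, A, B, s):
    """y = x @ W^T + s * (x @ A^T) @ B^T for a single row-vector x."""
    base = [sum(Wo[i] * x[i] for i in range(len(x))) for Wo in W]
    a = [sum(Ar[i] * x[i] for i in range(len(x))) for Ar in A]
    delta = [sum(Bo[r] * a[r] for r in range(len(a))) for Bo in B]
    return [b + s * d for b, d in zip(base, delta)]

def silu(v):
    return v / (1.0 + math.exp(-v))

def lora_mlp_swiglu(x, gate, up, down, s):
    g = matmul_lora_ref(x, *gate, s)              # gate projection
    u = matmul_lora_ref(x, *up, s)                # up projection
    h = [silu(gi) * ui for gi, ui in zip(g, u)]   # SwiGLU: silu(gate) * up
    return matmul_lora_ref(h, *down, s)           # down projection (no dropout here)

# 1-dimensional example with a zero LoRA update (reduces to the base weights):
proj = ([[1.0]], [[0.0]], [[0.0]])                # (W, A, B)
out = lora_mlp_swiglu([1.0], proj, proj, proj, s=1.0)
print(out)   # ≈ [0.7311], i.e. silu(1.0)
```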
apply_lora_o
kernels.lora.apply_lora_o(self, X)
Applies LoRA to output projection layer.
Supports bias, dropout, and DoRA.
apply_lora_qkv
kernels.lora.apply_lora_qkv(self, X, inplace=True)
Applies LoRA to compute Query, Key, Value projections.
Supports bias, dropout, and DoRA. Dropout is applied outside the autograd Function so PyTorch handles its backward automatically. A single shared dropout mask is used across Q, K, V projections for memory efficiency.
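The shared-mask idea can be sketched as follows: one inverted-dropout mask is sampled on the input, and the same dropped tensor feeds the Q, K, and V LoRA paths instead of three independently-masked copies. Names here are illustrative, not the module's internals.

```python
import random

def shared_dropout(x, p, rng):
    """Inverted dropout: zero with probability p, scale survivors by 1/(1-p)."""
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else xi * scale for xi in x]

rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]
x_drop = shared_dropout(x, p=0.5, rng=rng)   # mask sampled once ...
q_in = k_in = v_in = x_drop                  # ... and reused by all three LoRA paths
```

Since dropout happens before the autograd Function, PyTorch differentiates through it automatically, and only one mask needs to be kept alive for the backward pass.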
get_embedding_lora_parameters
kernels.lora.get_embedding_lora_parameters(embed)
Extracts LoRA parameters from a PEFT Embedding module.
get_lora_parameters
kernels.lora.get_lora_parameters(proj)
Gets LoRA parameters from a projection module.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| proj | nn.Module | The projection module to extract parameters from. | required |
Returns
A tuple containing:

| Name | Type | Description |
|---|---|---|
| W | torch.Tensor | Base weight tensor |
| b | torch.Tensor \| None | Base layer bias (or None) |
| quant_state | QuantState \| torch.Tensor \| None | Quantization state (or None) |
| A | torch.Tensor \| None | LoRA A weight (or None) |
| B | torch.Tensor \| None | LoRA B weight (or None) |
| s | float \| None | LoRA scaling factor (or None) |
| lora_bias | torch.Tensor \| None | LoRA B bias (or None) |
| dropout | nn.Module \| None | Dropout module (or None) |
| magnitude | torch.Tensor \| None | DoRA magnitude vector (or None) |
matmul_lora
kernels.lora.matmul_lora(
X,
W,
b,
W_quant,
A,
B,
s,
out=None,
X_drop=None,
lora_bias=None,
)Efficient fused matmul + LoRA computation.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| X | torch.Tensor | Input tensor [*, in_features] | required |
| W | torch.Tensor | Base weight matrix [out_features, in_features] | required |
| b | torch.Tensor \| None | Base layer bias [out_features] (or None) | required |
| W_quant | QuantState \| torch.Tensor \| None | Quantization state for W | required |
| A | torch.Tensor \| None | LoRA A matrix [rank, in_features] | required |
| B | torch.Tensor \| None | LoRA B matrix [out_features, rank] | required |
| s | float \| None | LoRA scaling factor | required |
| out | torch.Tensor \| None | Optional output tensor for inplace operations | None |
| X_drop | torch.Tensor \| None | Optional dropout-applied input for LoRA path (if None, uses X) | None |
| lora_bias | torch.Tensor \| None | Optional LoRA B layer bias [out_features] | None |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Result of X @ W^T + s * (X_drop @ A^T) @ B^T + b + s * lora_bias |
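A small numeric sanity check of the formula above (pure Python, tiny shapes, no dropout or bias): the fused two-matmul LoRA path equals folding the scaled low-rank update into the base weight first. `matvec_T` is an illustrative helper, not part of the module.

```python
def matvec_T(M, x):
    """x @ M^T for a single row-vector x and matrix M of shape [out, in]."""
    return [sum(Mo[i] * x[i] for i in range(len(x))) for Mo in M]

X = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0]]   # [out=2, in=2]
A = [[1.0, 1.0]]               # [rank=1, in=2]
B = [[2.0], [3.0]]             # [out=2, rank=1]
s = 0.5

# LoRA path: base matmul plus scaled low-rank correction
lora = [w + s * d for w, d in zip(matvec_T(W, X), matvec_T(B, matvec_T(A, X)))]

# Folded path: merge s * B @ A into W, then one matmul
W_merged = [[W[o][i] + s * sum(B[o][r] * A[r][i] for r in range(len(A)))
             for i in range(len(W[0]))] for o in range(len(W))]
folded = matvec_T(W_merged, X)

print(lora, folded)   # both [4.0, 6.5]
```

The fused path never materializes B @ A, which is why it stays cheap when rank is much smaller than the feature dimensions.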