kernels.lora

Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.

See “LoRA: Low-Rank Adaptation of Large Language Models” (https://arxiv.org/abs/2106.09685).

Also supports DoRA: see “DoRA: Weight-Decomposed Low-Rank Adaptation” (https://arxiv.org/abs/2402.09353).

Credit to unsloth (https://unsloth.ai/) for inspiration for this implementation.
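For intuition, DoRA reparameterizes the adapted weight into a direction (the LoRA-merged weight, row-normalized) and a learned per-output-channel magnitude. The sketch below is an unfused illustration, assuming PyTorch's [out_features, in_features] weight layout; the name `dora_linear_ref` and the magnitude initialization are hypothetical, not this module's API.

```python
import torch

def dora_linear_ref(X, W, A, B, s, m):
    # DoRA sketch: merge the LoRA update, normalize each output row to
    # unit norm (the "direction"), then rescale by the magnitude m.
    W_adapted = W + s * (B @ A)                 # [out, in]
    norm = W_adapted.norm(dim=1, keepdim=True)  # per-output-row norm
    W_dora = m.view(-1, 1) * W_adapted / norm
    return X @ W_dora.T

torch.manual_seed(0)
out_f, in_f, r = 4, 6, 2
W = torch.randn(out_f, in_f)
A, B = torch.randn(r, in_f), torch.randn(out_f, r)
# Initializing m to the adapted weight's row norms makes DoRA start out
# identical to plain LoRA (the usual initialization in the paper).
m = (W + 0.5 * (B @ A)).norm(dim=1)
X = torch.randn(3, in_f)
y = dora_linear_ref(X, W, A, B, 0.5, m)
```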

Classes

Name Description
LoRA_Embedding Fused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.
LoRA_MLP Optimized LoRA MLP implementation.
LoRA_O Optimized LoRA implementation for output projection.
LoRA_QKV Optimized LoRA QKV implementation with quantization support.

LoRA_Embedding

kernels.lora.LoRA_Embedding()

Fused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.

Supports dropout and DoRA.
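The fused formula above can be checked against a plain-PyTorch sketch. Shapes and the scaling factor `s` are illustrative assumptions (`W`: [vocab, dim], `A`: [rank, vocab], `B`: [dim, rank]); the fused kernel avoids materializing intermediates the way this reference does.

```python
import torch
import torch.nn.functional as F

def lora_embedding_ref(x, W, A, B, s):
    # Unfused reference: base embedding lookup plus the low-rank update.
    base = F.embedding(x, W)             # [*, dim]
    lora = F.embedding(x, A.T) @ B.T     # [*, rank] @ [rank, dim] -> [*, dim]
    return base + s * lora

torch.manual_seed(0)
vocab, dim, rank = 10, 4, 2
W = torch.randn(vocab, dim)
A = torch.randn(rank, vocab)
B = torch.randn(dim, rank)
x = torch.tensor([1, 3, 5])
y = lora_embedding_ref(x, W, A, B, s=0.5)
# Mathematically equivalent to a lookup in the merged weight:
y_merged = F.embedding(x, W + 0.5 * (B @ A).T)
```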

LoRA_MLP

kernels.lora.LoRA_MLP()

Optimized LoRA MLP implementation.

Supports bias, dropout, and DoRA. Dropout is applied to the input for gate/up projections. The down projection uses hidden states (post-activation) as input, so dropout is not applied there.
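As a sketch of the math only (the parameter layout here is an assumption, not the module's actual API), a SwiGLU LoRA MLP with dropout restricted to the gate/up input looks like:

```python
import torch
import torch.nn.functional as F

def lora_linear(X, W, A, B, s, X_drop=None):
    # Base matmul plus low-rank update; X_drop feeds the LoRA path when
    # dropout is active (W: [out, in], A: [rank, in], B: [out, rank]).
    Xl = X if X_drop is None else X_drop
    return X @ W.T + s * (Xl @ A.T) @ B.T

def lora_mlp_swiglu_ref(X, gate, up, down, p_drop=0.0, training=False):
    # Dropout is applied once to X and shared by the gate and up LoRA
    # paths; the down projection consumes post-activation hidden states,
    # so no dropout is applied there.
    X_drop = F.dropout(X, p_drop, training)
    g = lora_linear(X, *gate, X_drop=X_drop)
    u = lora_linear(X, *up, X_drop=X_drop)
    h = F.silu(g) * u                     # SwiGLU
    return lora_linear(h, *down)

torch.manual_seed(0)
d, hdim, r = 8, 16, 2
mk = lambda o, i: (torch.randn(o, i), torch.randn(r, i), torch.randn(o, r), 0.5)
X = torch.randn(3, d)
out = lora_mlp_swiglu_ref(X, mk(hdim, d), mk(hdim, d), mk(d, hdim))
```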

LoRA_O

kernels.lora.LoRA_O()

Optimized LoRA implementation for output projection.

Supports bias, dropout, and DoRA.

LoRA_QKV

kernels.lora.LoRA_QKV()

Optimized LoRA QKV implementation with quantization support.

Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation). Dropout is applied outside this Function so autograd handles its backward.

Functions

Name Description
apply_lora_embedding Applies LoRA to embedding layer.
apply_lora_mlp_geglu Applies LoRA to MLP layer with GEGLU activation.
apply_lora_mlp_swiglu Applies LoRA to MLP layer with SwiGLU activation.
apply_lora_o Applies LoRA to output projection layer.
apply_lora_qkv Applies LoRA to compute Query, Key, Value projections.
get_embedding_lora_parameters Extracts LoRA parameters from a PEFT Embedding module.
get_lora_parameters Gets LoRA parameters from a projection module.
matmul_lora Efficient fused matmul + LoRA computation.

apply_lora_embedding

kernels.lora.apply_lora_embedding(self, x)

Applies LoRA to embedding layer.

apply_lora_mlp_geglu

kernels.lora.apply_lora_mlp_geglu(self, X, inplace=True)

Applies LoRA to MLP layer with GEGLU activation.

Supports bias, dropout, and DoRA.

apply_lora_mlp_swiglu

kernels.lora.apply_lora_mlp_swiglu(self, X, inplace=True)

Applies LoRA to MLP layer with SwiGLU activation.

Supports bias, dropout, and DoRA.

apply_lora_o

kernels.lora.apply_lora_o(self, X)

Applies LoRA to output projection layer.

Supports bias, dropout, and DoRA.

apply_lora_qkv

kernels.lora.apply_lora_qkv(self, X, inplace=True)

Applies LoRA to compute Query, Key, Value projections.

Supports bias, dropout, and DoRA. Dropout is applied outside the autograd Function so PyTorch handles its backward automatically. A single shared dropout mask is used across Q, K, V projections for memory efficiency.
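Conceptually (a sketch, not this Function's implementation), sharing one inverted-dropout mask across the three projections works like:

```python
import torch

def qkv_shared_dropout_ref(X, weights, loras, p=0.1, training=True):
    # Sample one inverted-dropout mask and reuse it for the Q, K, and V
    # LoRA paths, so a single mask tensor suffices for backward.
    if training and p > 0:
        mask = (torch.rand_like(X) >= p).to(X.dtype) / (1.0 - p)
        X_drop = X * mask
    else:
        X_drop = X
    return tuple(
        X @ W.T + s * (X_drop @ A.T) @ B.T
        for W, (A, B, s) in zip(weights, loras)
    )

torch.manual_seed(0)
d, r = 8, 2
Ws = [torch.randn(d, d) for _ in range(3)]
loras = [(torch.randn(r, d), torch.randn(d, r), 0.5) for _ in range(3)]
X = torch.randn(3, d)
q, k, v = qkv_shared_dropout_ref(X, Ws, loras, p=0.1, training=False)
```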

get_embedding_lora_parameters

kernels.lora.get_embedding_lora_parameters(embed)

Extracts LoRA parameters from a PEFT Embedding module.

get_lora_parameters

kernels.lora.get_lora_parameters(proj)

Gets LoRA parameters from a projection module.

Parameters

Name Type Description Default
proj nn.Module The projection module to extract parameters from. required

Returns

Name Type Description
tuple[torch.Tensor, torch.Tensor | None, QuantState | torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, float | None, torch.Tensor | None, nn.Module | None, torch.Tensor | None] A tuple containing:
- W (torch.Tensor): base weight tensor
- b (torch.Tensor | None): base layer bias (or None)
- quant_state (QuantState | torch.Tensor | None): quantization state (or None)
- A (torch.Tensor | None): LoRA A weight (or None)
- B (torch.Tensor | None): LoRA B weight (or None)
- s (float | None): LoRA scaling factor (or None)
- lora_bias (torch.Tensor | None): LoRA B bias (or None)
- dropout (nn.Module | None): dropout module (or None)
- magnitude (torch.Tensor | None): DoRA magnitude vector (or None)

matmul_lora

kernels.lora.matmul_lora(
    X,
    W,
    b,
    W_quant,
    A,
    B,
    s,
    out=None,
    X_drop=None,
    lora_bias=None,
)

Efficient fused matmul + LoRA computation.

Parameters

Name Type Description Default
X torch.Tensor Input tensor [*, in_features] required
W torch.Tensor Base weight matrix [out_features, in_features] required
b torch.Tensor | None Base layer bias [out_features] required
W_quant QuantState | torch.Tensor | None Quantization state for W required
A torch.Tensor | None LoRA A matrix [rank, in_features] required
B torch.Tensor | None LoRA B matrix [out_features, rank] required
s float | None LoRA scaling factor required
out torch.Tensor | None Optional output tensor for inplace operations None
X_drop torch.Tensor | None Optional dropout-applied input for LoRA path (if None, uses X) None
lora_bias torch.Tensor | None Optional LoRA B layer bias [out_features] None

Returns

Name Type Description
torch.Tensor Result of X @ W^T + s * (X_drop @ A^T) @ B^T + b + s * lora_bias
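An unfused plain-PyTorch reference for this result (the real kernel fuses dequantization with the matmuls; `W_quant` handling is omitted here, and `matmul_lora_ref` is an illustrative name, not this module's function):

```python
import torch

def matmul_lora_ref(X, W, b=None, A=None, B=None, s=1.0,
                    X_drop=None, lora_bias=None):
    # Reference: X @ W^T + s * (X_drop @ A^T) @ B^T + b + s * lora_bias,
    # with weights stored [out_features, in_features].
    out = X @ W.T
    if A is not None and B is not None:
        Xl = X if X_drop is None else X_drop   # dropout applies only to LoRA path
        out = out + s * (Xl @ A.T) @ B.T
    if b is not None:
        out = out + b
    if lora_bias is not None:
        out = out + s * lora_bias
    return out

torch.manual_seed(0)
X = torch.randn(3, 6)
W = torch.randn(4, 6)
A, B = torch.randn(2, 6), torch.randn(4, 2)
b = torch.randn(4)
y = matmul_lora_ref(X, W, b=b, A=A, B=B, s=0.5)
```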