Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,23 @@ def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
# pylint: enable=protected-access


def _prepare_yarn_rope_scaling(
rope_scaling: Optional[Dict[str, Any]],
rope_theta: Optional[float],
) -> Optional[Dict[str, Any]]:
"""Ensure Yarn-specific scaling configs include the theta metadata."""
if rope_scaling is None:
return None
if rope_scaling.get("rope_type") != "yarn":
return rope_scaling

rope_scaling_updated = dict(rope_scaling)
if "inv_theta_log_scale" not in rope_scaling_updated and rope_theta is not None:
theta_value = float(rope_theta)
rope_scaling_updated["inv_theta_log_scale"] = 1.0 / (2 * math.log(theta_value))
return rope_scaling_updated


class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
"""Paged KV cache using FlashInfer (CUDA) kernels."""

Expand Down Expand Up @@ -372,6 +389,7 @@ def __init__( # pylint: disable=too-many-locals
Whether to enable disaggregation in the KV cache.
"""
assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support inline mode."
rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)

attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
Expand Down Expand Up @@ -561,6 +579,7 @@ def __init__( # pylint: disable=too-many-locals
target : Target
The target to build the model to.
"""
rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
attn_kind_single = "mha"
Expand Down
30 changes: 21 additions & 9 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import math
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, Union

from tvm import tir
from tvm.relax.frontend.nn import Tensor, op
Expand Down Expand Up @@ -180,38 +180,43 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments
def yarn_find_correction_dim(
num_rotations: int,
d: tir.Var,
theta: float,
max_position_embeddings: int,
inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
"""Inverse dim formula to find dim based on number of rotations"""
return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(theta)
return (
d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale
)


def yarn_find_correction_range(
low_rot: int,
high_rot: int,
d: tir.Var,
theta: float,
max_position_embeddings: int,
inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
"""Find the correction range based on the number of rotations"""
low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
low = yarn_find_correction_dim(
low_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
)
high = yarn_find_correction_dim(
high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
)
return tir.max(low, 0), tir.min(high, d - 1)


def rope_freq_yarn(
s: tir.Var,
d: tir.Var,
d_range: int,
theta: float,
theta: Union[float, tir.PrimExpr],
dtype: str,
original_max_position_embeddings: int,
scaling_factor: float,
beta_fast: int,
beta_slow: int,
inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
): # pylint: disable=too-many-arguments, too-many-locals
"""Compute the inverse frequency of RoPE for yarn RoPE scaling."""

Expand All @@ -221,7 +226,11 @@ def rope_freq_yarn(
freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)

low, high = yarn_find_correction_range(
beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
beta_fast,
beta_slow,
d_range,
original_max_position_embeddings,
inv_theta_log_scale=inv_theta_log_scale,
)
high = tir.if_then_else(low == high, high + 0.001, high)
inv_freq_mask = tir.const(1, "float32") - tir.max(
Expand Down Expand Up @@ -266,12 +275,15 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "yarn":
inv_theta_log_scale = rope_scaling.get("inv_theta_log_scale")
assert inv_theta_log_scale is not None, "inv_theta_log_scale must be precomputed for YaRN"
return partial(
rope_freq_yarn,
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
scaling_factor=rope_scaling["factor"],
beta_fast=rope_scaling["beta_fast"],
beta_slow=rope_scaling["beta_slow"],
inv_theta_log_scale=inv_theta_log_scale,
)
raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')

Expand Down
Loading