Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion kvpress/presses/base_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
)


def is_prefilling(cache_position: torch.Tensor, q_len: int) -> bool:
"""Return whether the current forward pass is the initial prefill."""
prefilling = cache_position[-1] + 1 == q_len
return bool(prefilling.item() if isinstance(prefilling, torch.Tensor) else prefilling)


@dataclass
class BasePress:
"""
Expand Down Expand Up @@ -136,7 +142,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
q_len = hidden_states.shape[1]

# Don't compress after pre-filling
if kwargs["cache_position"][-1] > q_len:
if not is_prefilling(kwargs["cache_position"], q_len):
return output

keys, values = extract_keys_and_values(cache, module.layer_idx)
Expand Down
3 changes: 2 additions & 1 deletion kvpress/presses/cam_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import is_prefilling
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import extract_keys_and_values, get_prerope_query_states
Expand Down Expand Up @@ -238,7 +239,7 @@ def forward_hook(
layer_idx = int(module.layer_idx)

# Only operate during decoding
if kwargs["cache_position"][-1] <= q_len:
if is_prefilling(kwargs["cache_position"], q_len):
# Entering prefill for a (potentially new) sequence — drop any per-layer
# state left over from a previous sequence so that subsequent decoding
# steps don't try to `+=` against a stale-shaped running attention sum.
Expand Down
4 changes: 2 additions & 2 deletions kvpress/presses/decoding_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from transformers.cache_utils import QuantizedCache

from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.base_press import BasePress, is_prefilling
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import extract_keys_and_values

Expand Down Expand Up @@ -126,7 +126,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
layer_idx = module.layer_idx

# Only operate during decoding phase (after prefilling)
if kwargs["cache_position"][-1] <= q_len:
if is_prefilling(kwargs["cache_position"], q_len):
# We're still in prefilling phase, don't do anything
return output
# print(f"Adding hidden states to buffer: {hidden_states.shape}")
Expand Down
4 changes: 2 additions & 2 deletions kvpress/presses/dms_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn as nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.base_press import BasePress, is_prefilling
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import extract_keys_and_values

Expand Down Expand Up @@ -71,7 +71,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
cache = kwargs["past_key_values"]
q_len = hidden_states.shape[1]
cache_len = kwargs["cache_position"][-1] + 1
prefilling = cache_len == q_len
prefilling = is_prefilling(kwargs["cache_position"], q_len)

# Extract layer index as int for type safety
layer_idx: int = module.layer_idx # type: ignore[assignment]
Expand Down
4 changes: 2 additions & 2 deletions kvpress/presses/fastkvzip_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from transformers import AutoConfig, Gemma3ForConditionalGeneration, PreTrainedModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm

from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress, is_prefilling

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -223,7 +223,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
q_len = hidden_states.shape[1]

# Don't compress after pre-filling
if kwargs["cache_position"][-1] > q_len:
if not is_prefilling(kwargs["cache_position"], q_len):
return output

self._score_fast(module, hidden_states)
Expand Down
6 changes: 3 additions & 3 deletions kvpress/presses/prefill_decoding_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn as nn
from transformers import PreTrainedModel

from kvpress.presses.base_press import BasePress
from kvpress.presses.base_press import BasePress, is_prefilling
from kvpress.presses.decoding_press import DecodingPress

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,7 +54,7 @@ def compress(
q_len = hidden_states.shape[1]

# Determine if we're in prefilling or decoding phase
if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
if is_prefilling(kwargs["cache_position"], q_len) and self.prefilling_press is not None:
return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs)
elif self.decoding_press is not None:
return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs)
Expand All @@ -72,7 +72,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
q_len = hidden_states.shape[1]

# Determine if we're in prefilling or decoding phase
if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
if is_prefilling(kwargs["cache_position"], q_len) and self.prefilling_press is not None:
return self.prefilling_press.forward_hook(module, input, kwargs, output)
elif self.decoding_press is not None:
return self.decoding_press.forward_hook(module, input, kwargs, output)
Expand Down