def is_prefilling(cache_position: torch.Tensor, q_len: int) -> bool:
    """
    Tell whether the current forward pass is the initial prefill.

    The pass counts as prefill exactly when the last entry of
    ``cache_position``, plus one, equals ``q_len``.

    NOTE(review): presumably ``cache_position`` holds the absolute positions of
    the tokens processed in this pass, so equality would mean nothing was cached
    before it -- confirm against the callers.

    Parameters
    ----------
    cache_position : torch.Tensor
        Indexable sequence of positions; only its last element is read.
    q_len : int
        Number of query tokens in the current pass (callers pass
        ``hidden_states.shape[1]``).

    Returns
    -------
    bool
        A plain Python ``bool``: ``True`` during prefill, ``False`` otherwise.
    """
    # Only the final entry of cache_position is consulted.
    next_position = cache_position[-1] + 1
    # bool() handles both possible comparison results: a one-element tensor
    # (tensor input) or an ordinary Python value (non-tensor input).
    return bool(next_position == q_len)
diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py index 697b833f..b492ac5b 100644 --- a/kvpress/presses/decoding_press.py +++ b/kvpress/presses/decoding_press.py @@ -12,7 +12,7 @@ from transformers.cache_utils import QuantizedCache from kvpress.presses.adakv_press import AdaKVPress -from kvpress.presses.base_press import BasePress +from kvpress.presses.base_press import BasePress, is_prefilling from kvpress.presses.scorer_press import ScorerPress from kvpress.utils import extract_keys_and_values @@ -126,7 +126,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic layer_idx = module.layer_idx # Only operate during decoding phase (after prefilling) - if kwargs["cache_position"][-1] <= q_len: + if is_prefilling(kwargs["cache_position"], q_len): # We're still in prefilling phase, don't do anything return output # print(f"Adding hidden states to buffer: {hidden_states.shape}") diff --git a/kvpress/presses/dms_press.py b/kvpress/presses/dms_press.py index 6f1b3398..29634a30 100644 --- a/kvpress/presses/dms_press.py +++ b/kvpress/presses/dms_press.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from kvpress.presses.base_press import BasePress +from kvpress.presses.base_press import BasePress, is_prefilling from kvpress.presses.scorer_press import ScorerPress from kvpress.utils import extract_keys_and_values @@ -71,7 +71,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic cache = kwargs["past_key_values"] q_len = hidden_states.shape[1] cache_len = kwargs["cache_position"][-1] + 1 - prefilling = cache_len == q_len + prefilling = is_prefilling(kwargs["cache_position"], q_len) # Extract layer index as int for type safety layer_idx: int = module.layer_idx # type: ignore[assignment] diff --git a/kvpress/presses/fastkvzip_press.py b/kvpress/presses/fastkvzip_press.py index 1c57bbeb..75012cc0 100644 --- a/kvpress/presses/fastkvzip_press.py +++ 
b/kvpress/presses/fastkvzip_press.py @@ -15,7 +15,7 @@ from transformers import AutoConfig, Gemma3ForConditionalGeneration, PreTrainedModel from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm -from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress +from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress, is_prefilling logger = logging.getLogger(__name__) @@ -223,7 +223,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic q_len = hidden_states.shape[1] # Don't compress after pre-filling - if kwargs["cache_position"][-1] > q_len: + if not is_prefilling(kwargs["cache_position"], q_len): return output self._score_fast(module, hidden_states) diff --git a/kvpress/presses/prefill_decoding_press.py b/kvpress/presses/prefill_decoding_press.py index 244131ff..97357657 100644 --- a/kvpress/presses/prefill_decoding_press.py +++ b/kvpress/presses/prefill_decoding_press.py @@ -10,7 +10,7 @@ import torch.nn as nn from transformers import PreTrainedModel -from kvpress.presses.base_press import BasePress +from kvpress.presses.base_press import BasePress, is_prefilling from kvpress.presses.decoding_press import DecodingPress logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ def compress( q_len = hidden_states.shape[1] # Determine if we're in prefilling or decoding phase - if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None: + if is_prefilling(kwargs["cache_position"], q_len) and self.prefilling_press is not None: return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs) elif self.decoding_press is not None: return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs) @@ -72,7 +72,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic q_len = hidden_states.shape[1] # Determine if we're in prefilling or decoding phase - if kwargs["cache_position"][-1] <= q_len and self.prefilling_press 
is not None: + if is_prefilling(kwargs["cache_position"], q_len) and self.prefilling_press is not None: return self.prefilling_press.forward_hook(module, input, kwargs, output) elif self.decoding_press is not None: return self.decoding_press.forward_hook(module, input, kwargs, output)