diff --git a/brainscore_language/model_helpers/huggingface.py b/brainscore_language/model_helpers/huggingface.py
index 7f98b24b..6a7fd855 100644
--- a/brainscore_language/model_helpers/huggingface.py
+++ b/brainscore_language/model_helpers/huggingface.py
@@ -9,7 +9,7 @@
 from numpy.core import defchararray
 from torch.utils.hooks import RemovableHandle
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
+from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, DynamicCache
 from transformers.modeling_outputs import CausalLMOutput
 from typing import Union, List, Tuple, Dict, Callable
@@ -55,10 +55,10 @@ def __init__(
             self.basemodel = model
         elif multi_gpu:
             self.basemodel = AutoModelForCausalLM.from_pretrained(
-                self.model_id, low_cpu_mem_usage=True, device_map='auto')
+                self.model_id, low_cpu_mem_usage=True, device_map='auto', use_safetensors=True)
         else:
             self.basemodel = AutoModelForCausalLM.from_pretrained(
-                self.model_id, low_cpu_mem_usage=True)
+                self.model_id, low_cpu_mem_usage=True, use_safetensors=True)
 
         if multi_gpu:
             # With device_map='auto', model is split across GPUs. Inputs must go to
@@ -126,6 +126,11 @@ def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
         output = {'behavior': [], 'neural': []}
         number_of_tokens = 0
 
+        # KV cache state: only used for behavioral-only tasks (no layer activations needed)
+        _use_kv = self.behavioral_task is not None and not self.neural_recordings
+        _past_kv = None
+        _last_logit = None  # last-position logit from previous step, shape [1, 1, vocab]
+
         text_iterator = tqdm(text, desc='digest text') if len(text) > 100 else text  # show progress bar if many parts
         for part_number, text_part in enumerate(text_iterator):
             # prepare string representation of context
@@ -137,7 +142,55 @@ def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
             # run and remove hooks
             with torch.no_grad():
-                base_output = self.basemodel(**context_tokens)
+                if _use_kv and _past_kv is not None:
+                    # Slide the KV cache if adding new tokens would exceed the model's context window.
+                    # Drop the oldest entries so the cache stays at max_position_embeddings - new_len,
+                    # keeping attention map size constant at O(max_len) forever.
+                    _max_len = getattr(self.basemodel.config, 'max_position_embeddings', None)
+                    _past_len = _past_kv.get_seq_length() if hasattr(_past_kv, 'get_seq_length') \
+                        else _past_kv[0][0].shape[2]
+                    _new_len = self.current_tokens['input_ids'].shape[1]
+                    if _max_len is not None and _past_len + _new_len > _max_len:
+                        _keep = _max_len - _new_len
+                        if hasattr(_past_kv, 'get_seq_length'):
+                            # DynamicCache: use public to_legacy_cache/from_legacy_cache API
+                            # to avoid depending on internal attributes (key_cache/value_cache)
+                            # which have changed across transformers versions
+                            _legacy = _past_kv.to_legacy_cache()
+                            _sliced = tuple(
+                                (k[:, :, -_keep:, :], v[:, :, -_keep:, :])
+                                for k, v in _legacy
+                            )
+                            _past_kv = DynamicCache.from_legacy_cache(_sliced)
+                        else:
+                            # legacy tuple-of-tuples format (transformers < 4.36)
+                            _past_kv = tuple(
+                                (k[:, :, -_keep:, :], v[:, :, -_keep:, :])
+                                for k, v in _past_kv
+                            )
+
+                if _use_kv and _past_kv is not None:
+                    # Feed only new tokens; the KV cache covers the prefix
+                    new_tok = {k: v.to(self.device) for k, v in self.current_tokens.items()}
+                    # Attention mask must span past + new tokens
+                    past_len = _past_kv.get_seq_length() if hasattr(_past_kv, 'get_seq_length') \
+                        else _past_kv[0][0].shape[2]
+                    new_len = new_tok['input_ids'].shape[1]
+                    new_tok['attention_mask'] = torch.ones(
+                        1, past_len + new_len, dtype=torch.long, device=self.device)
+                    base_output = self.basemodel(**new_tok, past_key_values=_past_kv, use_cache=True)
+                    # Prepend last step's final logit so estimate_reading_times slicing is unchanged:
+                    # it expects logits[-new_len-1:-1], which needs the "bridge" logit from prior context
+                    base_output.logits = torch.cat(
+                        [_last_logit.to(base_output.logits.device), base_output.logits], dim=1)
+                else:
+                    base_output = self.basemodel(**context_tokens, use_cache=_use_kv)
+
+                # Update KV cache state after each step
+                if _use_kv:
+                    _past_kv = base_output.past_key_values
+                    _last_logit = base_output.logits[:, -1:, :].detach()
+
             for hook in hooks:
                 hook.remove()
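For reference, a minimal standalone sketch of the caching trick used above, not part of the patch and assuming a small causal LM such as 'gpt2' is available: feeding only the new tokens together with the KV cache, then prepending the previous step's last-position "bridge" logit, should reproduce the logits that a full forward pass gives over the same span.

# Hypothetical sanity check; the model choice and sentences are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').eval()

ctx = tok('The quick brown fox', return_tensors='pt')       # previously digested context
new = tok(' jumps over the lazy dog', return_tensors='pt')  # next text part

with torch.no_grad():
    # full pass over the prefix, keeping the KV cache and the last-position logit
    prefix = model(**ctx, use_cache=True)
    past_len = ctx['input_ids'].shape[1]
    new_len = new['input_ids'].shape[1]

    # incremental pass: feed only the new tokens; attention mask spans past + new
    step = model(input_ids=new['input_ids'],
                 attention_mask=torch.ones(1, past_len + new_len, dtype=torch.long),
                 past_key_values=prefix.past_key_values,
                 use_cache=True)

    # prepend the "bridge" logit so the slice over the new tokens lines up
    cached_logits = torch.cat([prefix.logits[:, -1:, :], step.logits], dim=1)

    # reference: one forward pass over the full concatenated context
    full_ids = torch.cat([ctx['input_ids'], new['input_ids']], dim=1)
    full_logits = model(input_ids=full_ids).logits

assert torch.allclose(cached_logits, full_logits[:, -new_len - 1:, :], atol=1e-4)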