Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions brainscore_language/model_helpers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.core import defchararray
from torch.utils.hooks import RemovableHandle
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, DynamicCache
from transformers.modeling_outputs import CausalLMOutput
from typing import Union, List, Tuple, Dict, Callable

Expand Down Expand Up @@ -55,10 +55,10 @@ def __init__(
self.basemodel = model
elif multi_gpu:
self.basemodel = AutoModelForCausalLM.from_pretrained(
self.model_id, low_cpu_mem_usage=True, device_map='auto')
self.model_id, low_cpu_mem_usage=True, device_map='auto', use_safetensors=True)
else:
self.basemodel = AutoModelForCausalLM.from_pretrained(
self.model_id, low_cpu_mem_usage=True)
self.model_id, low_cpu_mem_usage=True, use_safetensors=True)

if multi_gpu:
# With device_map='auto', model is split across GPUs. Inputs must go to
Expand Down Expand Up @@ -126,6 +126,11 @@ def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:
output = {'behavior': [], 'neural': []}
number_of_tokens = 0

# KV cache state: only used for behavioral-only tasks (no layer activations needed)
_use_kv = self.behavioral_task is not None and not self.neural_recordings
_past_kv = None
_last_logit = None # last-position logit from previous step, shape [1, 1, vocab]

text_iterator = tqdm(text, desc='digest text') if len(text) > 100 else text # show progress bar if many parts
for part_number, text_part in enumerate(text_iterator):
# prepare string representation of context
Expand All @@ -137,7 +142,55 @@ def digest_text(self, text: Union[str, List[str]]) -> Dict[str, DataAssembly]:

# run and remove hooks
with torch.no_grad():
base_output = self.basemodel(**context_tokens)
if _use_kv and _past_kv is not None:
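# Slide the KV cache if appending the new tokens would exceed the model's context window:
# drop the oldest cached entries so that only the most recent
# (max_position_embeddings - new_len) positions remain, which caps
# cached + new tokens at max_position_embeddings on every step.
# NOTE(review): this assumes new_len < max_position_embeddings; when _keep <= 0 the
# `-_keep:` slice does not trim the cache to the intended size -- confirm inputs.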
_max_len = getattr(self.basemodel.config, 'max_position_embeddings', None)
_past_len = _past_kv.get_seq_length() if hasattr(_past_kv, 'get_seq_length') \
else _past_kv[0][0].shape[2]
_new_len = self.current_tokens['input_ids'].shape[1]
if _max_len is not None and _past_len + _new_len > _max_len:
_keep = _max_len - _new_len
if hasattr(_past_kv, 'get_seq_length'):
# DynamicCache: use public to_legacy_cache/from_legacy_cache API
# to avoid depending on internal attributes (key_cache/value_cache)
# which have changed across transformers versions
_legacy = _past_kv.to_legacy_cache()
_sliced = tuple(
(k[:, :, -_keep:, :], v[:, :, -_keep:, :])
for k, v in _legacy
)
_past_kv = DynamicCache.from_legacy_cache(_sliced)
else:
# legacy tuple-of-tuples format (transformers < 4.36)
_past_kv = tuple(
(k[:, :, -_keep:, :], v[:, :, -_keep:, :])
for k, v in _past_kv
)

if _use_kv and _past_kv is not None:
# Feed only new tokens; the KV cache covers the prefix
new_tok = {k: v.to(self.device) for k, v in self.current_tokens.items()}
# Attention mask must span past + new tokens
past_len = _past_kv.get_seq_length() if hasattr(_past_kv, 'get_seq_length') \
else _past_kv[0][0].shape[2]
new_len = new_tok['input_ids'].shape[1]
new_tok['attention_mask'] = torch.ones(
1, past_len + new_len, dtype=torch.long, device=self.device)
base_output = self.basemodel(**new_tok, past_key_values=_past_kv, use_cache=True)
# Prepend last step's final logit so estimate_reading_times slicing is unchanged:
# it expects logits[-new_len-1:-1], which needs the "bridge" logit from prior context
base_output.logits = torch.cat(
[_last_logit.to(base_output.logits.device), base_output.logits], dim=1)
else:
base_output = self.basemodel(**context_tokens, use_cache=_use_kv)

# Update KV cache state after each step
if _use_kv:
_past_kv = base_output.past_key_values
_last_logit = base_output.logits[:, -1:, :].detach()

for hook in hooks:
hook.remove()

Expand Down
Loading