diff --git a/brainscore_language/model_helpers/huggingface.py b/brainscore_language/model_helpers/huggingface.py
index ed7bac7e..54f0655f 100644
--- a/brainscore_language/model_helpers/huggingface.py
+++ b/brainscore_language/model_helpers/huggingface.py
@@ -46,16 +46,34 @@ def __init__(
         self.model_id = model_id
         self.use_localizer = use_localizer
         self.region_layer_mapping = region_layer_mapping
-        self.basemodel = (model if model is not None else AutoModelForCausalLM.from_pretrained(self.model_id))
-        if torch.backends.mps.is_available():
+
+        multi_gpu = (model is None
+                     and torch.cuda.is_available()
+                     and torch.cuda.device_count() > 1)
+
+        if model is not None:
+            self.basemodel = model
+        elif multi_gpu:
+            self.basemodel = AutoModelForCausalLM.from_pretrained(
+                self.model_id, low_cpu_mem_usage=True, device_map='auto')
+        else:
+            self.basemodel = AutoModelForCausalLM.from_pretrained(
+                self.model_id, low_cpu_mem_usage=True)
+
+        if multi_gpu:
+            self.device = 'cuda'
+            self._logger.info(f"Using device_map='auto' across {torch.cuda.device_count()} GPUs")
+            print(f"Using device_map='auto' across {torch.cuda.device_count()} GPUs")
+        elif torch.backends.mps.is_available():
             self.device = 'mps'
+            self.basemodel.to(self.device)
         elif torch.cuda.is_available():
             self.device = 'cuda'
+            self.basemodel.to(self.device)
         else:
             self.device = 'cpu'
         self._logger.info(f"Using device: {self.device}")
         print(f"Using device: {self.device}")
-        self.basemodel.to(self.device)
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(self.model_id, truncation_side='left')
         self.current_tokens = None  # keep track of current tokens
 