forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcheckpoint.py
More file actions
114 lines (93 loc) · 4 KB
/
checkpoint.py
File metadata and controls
114 lines (93 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
import torch
def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to resource files (which contain files such as the
    checkpoint and param files), either:
    1. Uses the path from importlib.resources, only works with buck2
    2. Uses default path located in examples/models/llama/params
    Expected to be called from with a `model.py` file located in a
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama/model.py`,
            where `executorch/examples/models/llama` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
    """
    try:
        import importlib.resources as _resources

        # 1st way: If we can import this path, we are running with buck2 and all
        # resources can be accessed with importlib.resources.
        # pyre-ignore
        from executorch.examples.models.llama import params  # noqa

        # Get the model name from the directory containing the model file,
        # assuming a path such as examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        model_dir = _resources.files(f"executorch.examples.models.{model_name}")
        with _resources.as_file(model_dir) as model_path:
            resource_dir = model_path / "params"
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would let a nonexistent path escape the
            # fallback. Raising here deliberately routes into the except branch.
            if not resource_dir.exists():
                raise FileNotFoundError(
                    f"Resource directory {resource_dir} does not exist."
                )
    except Exception:
        # 2nd way: fall back to the `params` directory next to the model file.
        resource_dir = Path(model_file_path).absolute().parent / "params"
    return resource_dir
def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
    """
    Get the dtype of the checkpoint, returning "None" if the checkpoint is
    empty or contains no entries that carry a dtype.

    The dtype of the first dtype-bearing entry is taken as the reference;
    any entries whose dtype differs are reported via print (not modified).

    Args:
        checkpoint: State-dict mapping names to (typically) tensors.

    Returns:
        The reference torch.dtype, or None if none could be determined.
    """
    dtype = None
    first_key = None
    # Find the first entry that actually has a dtype. Previously the very
    # first entry was assumed to be a tensor, which raised AttributeError on
    # checkpoints whose first entry is non-tensor metadata, even though the
    # mismatch scan below already guarded with hasattr.
    for key, value in checkpoint.items():
        if hasattr(value, "dtype"):
            first_key = key
            dtype = value.dtype
            break
    if dtype is not None:
        mismatched_dtypes = [
            (key, value.dtype)
            for key, value in checkpoint.items()
            if hasattr(value, "dtype") and value.dtype != dtype
        ]
        if len(mismatched_dtypes) > 0:
            print(
                f"Mixed dtype model. Dtype of {first_key}: {dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
            )
    return dtype
def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
    """
    Load a PyTorch checkpoint from `input_dir`.

    Supports both sharded checkpoints (a `pytorch_model.bin.index.json`
    index plus shard files) and single-file checkpoints
    (`pytorch_model.bin`). Tensors are always loaded onto the CPU.

    Args:
        input_dir: Directory containing the checkpoint files.

    Returns:
        A state dict mapping weight names to tensors.

    Raises:
        FileNotFoundError: if no recognizable checkpoint exists in
            `input_dir`.
    """
    base = Path(input_dir)

    index_file = base / "pytorch_model.bin.index.json"
    if index_file.exists():
        # Sharded checkpoint: the index maps each weight name to the shard
        # file that holds it.
        with open(index_file, "r") as fp:
            weight_map = json.load(fp)["weight_map"]

        # Load each distinct shard exactly once.
        loaded_shards = {
            shard_name: torch.load(
                base / shard_name,
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            for shard_name in sorted(set(weight_map.values()))
        }

        # Stitch the shards back together into one consolidated state dict.
        return {
            weight_name: loaded_shards[shard_name][weight_name]
            for weight_name, shard_name in weight_map.items()
        }

    # Single-file checkpoint.
    single_file = base / "pytorch_model.bin"
    if single_file.exists():
        return torch.load(
            single_file, weights_only=True, map_location=torch.device("cpu")
        )

    raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")