"""
Compute SEmb for classification/ regression/ question-answering tasks.
"""
import argparse
import logging
import os
import numpy as np
import torch
from itrain import DATASET_MANAGER_CLASSES, DatasetArguments, DatasetManager, RunArguments
from itrain.datasets.tagging import TaggingDatasetManager
from itrain.runner import set_seed
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from transformers import AutoModel, AutoTokenizer, PreTrainedModel

logger = logging.getLogger(__name__)


def compute_semb(dataset_manager: DatasetManager, model: PreTrainedModel, run_args: RunArguments):
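    """Mean-pool the encoder's token embeddings for each training example and return a
    dict holding the average ("avg_sequence_output") over the whole training split."""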
    train_dataloader = DataLoader(
        dataset_manager.train_split,
        batch_size=run_args.batch_size,
        sampler=dataset_manager.train_sampler(),
        collate_fn=dataset_manager.collate_fn,
    )

    logger.info("***** Compute SEmb *****")
    logger.info("Num batches = %d", len(train_dataloader))
    logger.info("Batch size = %d", run_args.batch_size)

    model.eval()
    model.zero_grad()

    train_iterator = trange(int(run_args.num_train_epochs), desc="Epoch", disable=False)
    total_num_examples = 0
    global_feature_dict = {}
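    # Iterate over the training split, accumulating a running sum of per-example
    # mean-pooled token embeddings in global_feature_dict.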
    for _ in train_iterator:
        num_examples = 0
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
        for step, inputs in enumerate(epoch_iterator):
            # don't input labels as we don't have a head
            if "labels" in inputs:
                del inputs["labels"]
            elif "start_positions" in inputs:
                del inputs["start_positions"]
                del inputs["end_positions"]
            # HACK: DistilBERT does not accept token_type_ids
            if model.config.model_type == "distilbert" and "token_type_ids" in inputs:
                del inputs["token_type_ids"]
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(run_args.device).view(-1, v.size(-1))
            with torch.no_grad():
                input_mask = inputs["attention_mask"]
                outputs = model(**inputs)
                sequence_output = outputs[0]  # batch_size x max_seq_length x hidden_size
                input_mask = input_mask.view(-1, input_mask.size(-1))
                # zero out padding positions, then average over the remaining tokens
                active_sequence_output = torch.einsum("ijk,ij->ijk", [sequence_output, input_mask])
                avg_sequence_output = active_sequence_output.sum(1) / input_mask.sum(dim=1).view(input_mask.size(0), 1)
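            # add this batch's summed pooled embeddings to the running total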
            if len(global_feature_dict) == 0:
                global_feature_dict["avg_sequence_output"] = avg_sequence_output.sum(dim=0).detach().cpu().numpy()
            else:
                global_feature_dict["avg_sequence_output"] += avg_sequence_output.sum(dim=0).detach().cpu().numpy()
            num_examples += input_mask.size(0)
        total_num_examples += num_examples

    # Normalize
    for key in global_feature_dict:
        global_feature_dict[key] = global_feature_dict[key] / total_num_examples

    return global_feature_dict


def run_semb(args, data_args=None, run_args=None, seed=42):
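    """Compute SEmb features for a single dataset and save each feature array as a
    .npy file under output_dir/<dataset name>."""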
    run_args = run_args or RunArguments(
        batch_size=32,
        num_train_epochs=1,
    )

    # Setup logging
    logging.basicConfig(level=logging.INFO)
    logger.warning("Device: %s", run_args.device)

    # Dataset
    data_args = data_args or DatasetArguments(dataset_name=args["dataset"], task_name=args["dataset_task"])
    dataset_manager: DatasetManager = DATASET_MANAGER_CLASSES[data_args.dataset_name](data_args)

    # Create output directory if needed
    task_output_dir = os.path.join(args["output_dir"], dataset_manager.name)
    if not os.path.exists(task_output_dir):
        os.makedirs(task_output_dir)
    # skip computation if output already exists
    elif len(os.listdir(task_output_dir)) > 0 and not args["overwrite"]:
        logger.info("Output already exists, skipping {0}".format(dataset_manager.name))
        return

    # Set seed
    set_seed(seed)

    # Tokenizer
    if isinstance(dataset_manager, TaggingDatasetManager):
        add_prefix_space = True
    else:
        add_prefix_space = False
    tokenizer = AutoTokenizer.from_pretrained(
        args["model_name"],
        use_fast=args.get("fast_tokenizer", False),
        add_prefix_space=add_prefix_space,
    )
    dataset_manager.tokenizer = tokenizer

    # Load model and preprocess the dataset
    model = AutoModel.from_pretrained(args["model_name"])
    dataset_manager.load_and_preprocess()
    model.to(run_args.device)

    feature_dict = compute_semb(dataset_manager, model, run_args)
    for key in feature_dict:
        np.save(os.path.join(task_output_dir, "{}.npy".format(key)), feature_dict[key])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default=None, type=str, required=True)
    parser.add_argument("--dataset_task", default=None, type=str)
    parser.add_argument("--model_name", default="sentence-transformers/roberta-base-nli-stsb-mean-tokens", type=str)
    parser.add_argument("--output_dir", default="output/semb", type=str)
    parser.add_argument("--overwrite", action="store_true", default=False)
    args = parser.parse_args()

    run_semb(vars(args))


if __name__ == "__main__":
    main()