-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
73 lines (49 loc) · 2.05 KB
/
models.py
File metadata and controls
73 lines (49 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch import nn
class CNNClassifier(nn.Module):
    '''
    Simple 1D convolutional text classifier following Zhang & Wallace (2015).

    Multiple parallel Conv1d filters of different widths slide over the token
    embeddings; each filter bank is globally max-pooled over time, the pooled
    features are concatenated, passed through ReLU and dropout, and fed to a
    single linear layer that produces class logits.

    NOTE(review): the original docs said "expects a gensim KeyedVectors", but
    the code calls `.shape` and `.to(device)` on `pretrained_embed`, so it
    actually requires a 2-D torch FloatTensor (vocab_size, embed_dim) —
    e.g. `torch.tensor(keyed_vectors.vectors)`.

    @article{zhang2015sensitivity,
    title={A sensitivity analysis of (and practitioners' guide to) convolutional neural networks for sentence classification},
    author={Zhang, Ye and Wallace, Byron},
    journal={arXiv preprint arXiv:1510.03820},
    year={2015}
    }
    '''
    def __init__(self, pretrained_embed, num_classes=2, num_feature_maps=300, kernel_sizes=(2, 3, 4, 5), p_dropout=0.5, device='cpu'):
        super().__init__()
        embed_dim = pretrained_embed.shape[1]
        # Embedding table initialised from the supplied weight matrix
        # (from_pretrained freezes it by default).
        self.embedder = nn.Embedding.from_pretrained(pretrained_embed.to(device))
        # One Conv1d per kernel width; registration order (convs, then
        # classifier) matches the original so seeded init is reproducible.
        self.convs = nn.ModuleList(
            nn.Conv1d(embed_dim, num_feature_maps, width, device=device)
            for width in kernel_sizes
        )
        self.dropout = nn.Dropout(p=p_dropout)
        # Concatenated pooled features -> class logits.
        self.classifier = nn.Linear(
            num_feature_maps * len(kernel_sizes), num_classes, device=device)
    def conv_pool(self, X, conv):
        '''Apply one convolution, then global max-pool over the time axis.

        Takes X of shape (B, E, L) and returns (B, F), where F is the
        number of feature maps of `conv`.
        '''
        convolved = conv(X)                                   # (B, F, L')
        # Pool over the full convolved length -> (B, F, 1).
        pooled = nn.functional.max_pool1d(convolved, convolved.shape[2])
        return pooled.squeeze(2)                              # (B, F)
    def forward(self, X):
        '''
        INPUTS
        X: shape (B, L) where
            B is batch size
            L is length of sentences (should be 300)
        '''
        # Look up embeddings and move channels first: (B, E, L).
        embedded = self.embedder(X).permute(0, 2, 1)
        # Per-kernel pooled features, concatenated: (B, len(kernels) * F).
        pooled = [self.conv_pool(embedded, conv) for conv in self.convs]
        features = torch.cat(pooled, dim=1)
        # Non-linearity, then dropout, then the fully connected head.
        features = self.dropout(nn.functional.relu(features))
        return self.classifier(features)