-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBasicBlock.py
More file actions
112 lines (87 loc) · 4.2 KB
/
BasicBlock.py
File metadata and controls
112 lines (87 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
import numpy as np
import math
class ConvBlock(nn.Module):
    """Convolutional encoder that reduces video frames to one token per time step.

    All convolutions and the pooling layer use a temporal kernel size and
    temporal stride of 1, so the sequence length of the input is preserved
    while the spatial resolution is reduced. Each convolution is followed by
    ``InstanceNorm3d`` and a ReLU.

    Args:
        in_planes: number of channels of the input frames.
        planes: number of output channels (the token dimension). The first
            convolution produces ``planes // 4`` channels.
    """

    def __init__(self, in_planes=3, planes=128):
        super().__init__()
        self.planes = planes
        stem_planes = planes // 4

        def spatial_conv(c_in, c_out, kernel, stride):
            # Bias-free convolution acting on (H, W) only; padding of
            # kernel // 2 gives 3 for the 7x7 kernels and 1 for the 3x3 ones.
            return nn.Conv3d(
                c_in,
                c_out,
                kernel_size=(1, kernel, kernel),
                stride=(1, stride, stride),
                padding=(0, kernel // 2, kernel // 2),
                bias=False,
            )

        # Layers are registered in the same order as they are applied.
        self.conv1 = spatial_conv(in_planes, stem_planes, kernel=7, stride=2)
        self.in1 = nn.InstanceNorm3d(stem_planes)
        self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 0, 0))

        self.conv2 = spatial_conv(stem_planes, planes, kernel=7, stride=2)
        self.in2 = nn.InstanceNorm3d(planes)
        # A single stateless ReLU module is shared by every stage.
        self.relu = nn.ReLU(inplace=True)

        self.conv3 = spatial_conv(planes, planes, kernel=3, stride=2)
        self.in3 = nn.InstanceNorm3d(planes)

        self.conv4 = spatial_conv(planes, planes, kernel=3, stride=1)
        self.in4 = nn.InstanceNorm3d(planes)

        self.conv5 = spatial_conv(planes, planes, kernel=3, stride=2)
        self.in5 = nn.InstanceNorm3d(planes)

    def forward(self, x):
        """
        input:
        speaker_video_frames x: (batch_size, 3, seq_len, img_size, img_size)
        output:
        speaker_temporal_tokens y: (batch_size, token_dim, seq_len)
        """
        stem = self.maxpool(self.relu(self.in1(self.conv1(x))))

        feats = stem
        for conv, norm in ((self.conv2, self.in2), (self.conv3, self.in3)):
            feats = self.relu(norm(conv(feats)))

        # Residual stage: conv4 keeps the resolution, so the shapes match.
        residual = self.relu(self.in4(self.conv4(feats)))
        feats = feats + residual

        feats = self.relu(self.in5(self.conv5(feats)))

        # Average over width, then height, leaving (batch, channels, seq_len).
        return feats.mean(dim=-1).mean(dim=-1)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    The encoding table is precomputed for ``max_len`` positions and stored in
    the buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even feature
    columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: feature dimension of the input (may be even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions for which encodings are precomputed.
        batch_first: if True the input is indexed as (batch, seq, feature),
            otherwise as (seq, batch, feature).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=True):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns;
        # slicing keeps the assignment shape-compatible instead of raising.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + positional_encoding)`` with the same shape as ``x``."""
        # not used in the final model
        if self.batch_first:
            # (max_len, 1, d) -> (1, max_len, d), trimmed to the sequence length.
            x = x + self.pe.permute(1, 0, 2)[:, :x.shape[1], :]
        else:
            x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
def lengths_to_mask(lengths, device):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: integer sequence lengths, given as a list/tuple or a 1-D
            tensor.
        device: device on which the mask is created.

    Returns:
        Bool tensor of shape ``(len(lengths), max(lengths))`` where entry
        ``[i, t]`` is True iff ``t < lengths[i]``. An empty ``lengths`` yields
        an empty mask of shape ``(0, 0)``.
    """
    # as_tensor accepts both Python sequences and tensors without the
    # copy-construct warning that torch.tensor(tensor) emits.
    lengths = torch.as_tensor(lengths, device=device)
    if lengths.numel() == 0:
        return torch.zeros((0, 0), dtype=torch.bool, device=device)

    # Tensor reduction instead of the builtin max(), which would iterate the
    # tensor element by element in Python.
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=device)
    # (1, max_len) < (N, 1) broadcasts to (N, max_len).
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
# Temporal Bias, inspired by ALiBi: https://github.com/ofirpress/attention_with_linear_biases
def init_biased_mask(n_head, max_seq_len, period):
    """Build a causal attention mask with a periodic, per-head linear bias.

    For query position ``i`` and key position ``j``:

    * ``j > i``  -> ``-inf`` (future positions are masked out);
    * ``j <= i`` -> ``-slope_h * ((i - j) // period)``, i.e. the penalty grows
      by one slope unit for every full ``period`` steps of distance.

    Args:
        n_head: number of attention heads (>= 1); each head gets its own slope.
        max_seq_len: side length of the square mask.
        period: number of consecutive positions sharing the same bias (>= 1).

    Returns:
        Float tensor of shape ``(n_head, max_seq_len, max_seq_len)``.

    Raises:
        ValueError: if ``n_head`` or ``period`` is smaller than 1.
    """
    if n_head < 1:
        raise ValueError(f"n_head must be >= 1, got {n_head}")
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")

    def get_slopes(n):
        # Geometric sequence of per-head slopes (ALiBi scheme).
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        # Not a power of two: take the slopes of the closest lower power of
        # two and fill up with every other slope of the next higher one.
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        extra = get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
        return get_slopes_power_of_2(closest_power_of_2) + extra

    slopes = torch.tensor(get_slopes(n_head))

    steps = torch.arange(max_seq_len)
    # distance[i, j] = i - j; non-negative exactly on the causal (lower) part.
    distance = steps.unsqueeze(1) - steps.unsqueeze(0)
    causal = distance >= 0

    # Closed form of the row-by-row fill: -((i - j) // period) where j <= i,
    # 0 elsewhere. Negation happens on integers so zeros stay +0.0.
    penalty = (-(distance.clamp(min=0) // period)).to(slopes.dtype)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * penalty.unsqueeze(0)

    # 0 on and below the diagonal, -inf above it.
    mask = torch.zeros(max_seq_len, max_seq_len).masked_fill(~causal, float('-inf'))
    return mask.unsqueeze(0) + alibi