zcahjl3 commited on
Commit
0a7c540
·
verified ·
1 Parent(s): 66166af

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +225 -0
model.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import List, Tuple
5
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
6
+ from transformers.models.bert import BertModel
7
+ from fastNLP.modules.torch import MLP,ConditionalRandomField,allowed_transitions
8
+ from torch.nn import CrossEntropyLoss
9
+
10
+
11
+ class ConvFeatureExtractionModel(nn.Module):
12
+
13
+ def __init__(
14
+ self,
15
+ conv_layers: List[Tuple[int, int, int]],
16
+ conv_dropout: float = 0.0,
17
+ conv_bias: bool = False,
18
+ ):
19
+ super().__init__()
20
+
21
+ def block(n_in, n_out, k, stride=1, conv_bias=False):
22
+ padding = k // 2
23
+ return nn.Sequential(
24
+ nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=k, stride=stride, padding=padding, bias=conv_bias),
25
+ nn.Dropout(conv_dropout),
26
+ # nn.BatchNorm1d(n_out),
27
+ nn.ReLU(),
28
+ # nn.MaxPool1d(kernel_size=2, stride=2)
29
+ )
30
+
31
+ in_d = 1
32
+ self.conv_layers = nn.ModuleList()
33
+ for _, cl in enumerate(conv_layers):
34
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
35
+ (dim, k, stride) = cl
36
+
37
+ self.conv_layers.append(
38
+ block(in_d, dim, k, stride=stride, conv_bias=conv_bias))
39
+ in_d = dim
40
+
41
+ def forward(self, x):
42
+ # x = x.unsqueeze(1)
43
+ for conv in self.conv_layers:
44
+ x = conv(x)
45
+ return x
46
+
47
+
48
+ class ModelWiseCNNClassifier(nn.Module):
49
+
50
+ def __init__(self, id2labels, dropout_rate=0.1):
51
+ super(ModelWiseCNNClassifier, self).__init__()
52
+ feature_enc_layers = [(64, 5, 1)] + [(128, 3, 1)] * 3 + [(64, 3, 1)]
53
+ self.conv = ConvFeatureExtractionModel(
54
+ conv_layers=feature_enc_layers,
55
+ conv_dropout=0.0,
56
+ conv_bias=False,
57
+ )
58
+
59
+ embedding_size = 4 *64
60
+ self.norm = nn.LayerNorm(embedding_size)
61
+
62
+ self.label_num = len(id2labels)
63
+ self.dropout = nn.Dropout(dropout_rate)
64
+ self.classifier = nn.Sequential(nn.Linear(embedding_size, self.label_num))
65
+ self.crf = ConditionalRandomField(num_tags=self.label_num, allowed_transitions=allowed_transitions(id2labels))
66
+ self.crf.trans_m.data *= 0
67
+
68
+ def conv_feat_extract(self, x):
69
+ out = self.conv(x)
70
+ out = out.transpose(1, 2)
71
+ return out
72
+
73
+ def forward(self, x, labels):
74
+ x = x.transpose(1, 2)
75
+ out1 = self.conv_feat_extract(x[:, 0:1, :])
76
+ out2 = self.conv_feat_extract(x[:, 1:2, :])
77
+ out3 = self.conv_feat_extract(x[:, 2:3, :])
78
+ out4 = self.conv_feat_extract(x[:, 3:4, :])
79
+ outputs = torch.cat((out1, out2, out3, out4), dim=2)
80
+
81
+ outputs = self.norm(outputs)
82
+ dropout_outputs = self.dropout(outputs)
83
+ logits = self.classifier(dropout_outputs)
84
+
85
+ if self.training:
86
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
87
+ loss = loss_fct(logits.view(-1, self.label_num), labels.view(-1))
88
+ output = {'loss': loss, 'logits': logits}
89
+ else:
90
+ mask = labels.gt(-1)
91
+ paths, scores = self.crf.viterbi_decode(logits=logits, mask=mask)
92
+ paths[mask==0] = -1
93
+ output = {'preds': paths, 'logits': logits}
94
+ pass
95
+
96
+ return output
97
+
98
+
99
+ class ModelWiseTransformerClassifier(nn.Module):
100
+
101
+ def __init__(self, id2labels, seq_len, intermediate_size = 512, num_layers=2, dropout_rate=0.1):
102
+ super(ModelWiseTransformerClassifier, self).__init__()
103
+ # feature_enc_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]
104
+ feature_enc_layers = [(64, 5, 1)] + [(128, 3, 1)] * 3 + [(64, 3, 1)]
105
+ self.conv = ConvFeatureExtractionModel(
106
+ conv_layers=feature_enc_layers,
107
+ conv_dropout=0.0,
108
+ conv_bias=False,
109
+ )
110
+
111
+ self.seq_len = seq_len # MAX Seq_len
112
+ embedding_size = 4 *64
113
+ self.encoder_layer = TransformerEncoderLayer(
114
+ d_model=embedding_size,
115
+ nhead=16,
116
+ dim_feedforward=intermediate_size,
117
+ dropout=dropout_rate,
118
+ batch_first=True)
119
+ self.encoder = TransformerEncoder(encoder_layer=self.encoder_layer,
120
+ num_layers=num_layers)
121
+
122
+ self.position_encoding = torch.zeros((seq_len, embedding_size))
123
+ for pos in range(seq_len):
124
+ for i in range(0, embedding_size, 2):
125
+ self.position_encoding[pos, i] = torch.sin(
126
+ torch.tensor(pos / (10000**((2 * i) / embedding_size))))
127
+ self.position_encoding[pos, i + 1] = torch.cos(
128
+ torch.tensor(pos / (10000**((2 *
129
+ (i + 1)) / embedding_size))))
130
+
131
+ self.norm = nn.LayerNorm(embedding_size)
132
+
133
+ self.label_num = len(id2labels)
134
+ self.dropout = nn.Dropout(dropout_rate)
135
+ self.classifier = nn.Sequential(nn.Linear(embedding_size, self.label_num))
136
+ self.crf = ConditionalRandomField(num_tags=self.label_num, allowed_transitions=allowed_transitions(id2labels))
137
+ self.crf.trans_m.data *= 0
138
+
139
+ def conv_feat_extract(self, x):
140
+ out = self.conv(x)
141
+ out = out.transpose(1, 2)
142
+ return out
143
+
144
+ def forward(self, x, labels):
145
+ mask = labels.gt(-1)
146
+ padding_mask = ~mask
147
+
148
+ x = x.transpose(1, 2)
149
+ out1 = self.conv_feat_extract(x[:, 0:1, :])
150
+ out2 = self.conv_feat_extract(x[:, 1:2, :])
151
+ out3 = self.conv_feat_extract(x[:, 2:3, :])
152
+ out4 = self.conv_feat_extract(x[:, 3:4, :])
153
+ out = torch.cat((out1, out2, out3, out4), dim=2)
154
+
155
+ outputs = out + self.position_encoding.to(out.device)
156
+ outputs = self.norm(outputs)
157
+ outputs = self.encoder(outputs, src_key_padding_mask=padding_mask)
158
+ dropout_outputs = self.dropout(outputs)
159
+ logits = self.classifier(dropout_outputs)
160
+
161
+ if self.training:
162
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
163
+ loss = loss_fct(logits.view(-1, self.label_num), labels.view(-1))
164
+ output = {'loss': loss, 'logits': logits}
165
+ else:
166
+ paths, scores = self.crf.viterbi_decode(logits=logits, mask=mask)
167
+ paths[mask==0] = -1
168
+ output = {'preds': paths, 'logits': logits}
169
+ pass
170
+
171
+ return output
172
+
173
+
174
+ class TransformerOnlyClassifier(nn.Module):
175
+
176
+ def __init__(self, id2labels, seq_len, embedding_size=4, num_heads=2, intermediate_size=64, num_layers=2, dropout_rate=0.1):
177
+ super(TransformerOnlyClassifier, self).__init__()
178
+
179
+ self.encoder_layer = TransformerEncoderLayer(
180
+ d_model=embedding_size,
181
+ nhead=num_heads,
182
+ dim_feedforward=intermediate_size,
183
+ dropout=dropout_rate,
184
+ batch_first=True)
185
+ self.encoder = TransformerEncoder(encoder_layer=self.encoder_layer,
186
+ num_layers=num_layers)
187
+
188
+ self.position_encoding = torch.zeros((seq_len, embedding_size))
189
+ for pos in range(seq_len):
190
+ for i in range(0, embedding_size, 2):
191
+ self.position_encoding[pos, i] = torch.sin(
192
+ torch.tensor(pos / (10000**((2 * i) / embedding_size))))
193
+ self.position_encoding[pos, i + 1] = torch.cos(
194
+ torch.tensor(pos / (10000**((2 *
195
+ (i + 1)) / embedding_size))))
196
+
197
+ self.norm = nn.LayerNorm(embedding_size)
198
+
199
+ self.label_num = len(id2labels)
200
+ self.dropout = nn.Dropout(dropout_rate)
201
+ self.classifier = nn.Sequential(nn.Linear(embedding_size, self.label_num))
202
+ self.crf = ConditionalRandomField(num_tags=self.label_num, allowed_transitions=allowed_transitions(id2labels))
203
+ self.crf.trans_m.data *= 0
204
+
205
+ def forward(self, inputs, labels):
206
+ mask = labels.gt(-1)
207
+ padding_mask = ~mask
208
+
209
+ outputs = inputs + self.position_encoding.to(inputs.device)
210
+ outputs = self.norm(outputs)
211
+ outputs = self.encoder(outputs, src_key_padding_mask=padding_mask)
212
+ dropout_outputs = self.dropout(outputs)
213
+ logits = self.classifier(dropout_outputs)
214
+
215
+ if self.training:
216
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
217
+ loss = loss_fct(logits.view(-1, self.label_num), labels.view(-1))
218
+ output = {'loss': loss, 'logits': logits}
219
+ else:
220
+ paths, scores = self.crf.viterbi_decode(logits=logits, mask=mask)
221
+ paths[mask==0] = -1
222
+ output = {'preds': paths, 'logits': logits}
223
+ pass
224
+
225
+ return output