Commit ad0ba22

zxyezxye authored and committed
add char-cnn
1 parent 17d6284 commit ad0ba22

4 files changed: +200 additions, −75 deletions


eval.py

Lines changed: 36 additions & 23 deletions
@@ -23,7 +23,7 @@
     help='score file location'
 )
 optparser.add_option(
-    "-f", "--crf", default="1",
+    "-f", "--crf", default="0",
     type='int', help="Use CRF (0 to disable)"
 )
 optparser.add_option(
@@ -42,6 +42,10 @@
     '--map_path', default='models/mapping.pkl',
     help='model path'
 )
+optparser.add_option(
+    '--char_mode', choices=['CNN', 'LSTM'], default='CNN',
+    help='char_CNN or char_LSTM'
+)
 
 opts = optparser.parse_args()[0]
 
@@ -102,7 +106,7 @@
 # l = 1
 # return maxl
 
-def eval(model, datas):
+def eval(model, datas, maxl=1):
     prediction = []
     confusion_matrix = torch.zeros((len(tag_to_id) - 2, len(tag_to_id) - 2))
     for data in datas:
@@ -113,23 +117,33 @@ def eval(model, datas):
         words = data['str_words']
         chars2 = data['chars']
         caps = data['caps']
-        chars2_sorted = sorted(chars2, key=lambda p: len(p), reverse=True)
-        d = {}
-        for i, ci in enumerate(chars2):
-            for j, cj in enumerate(chars2_sorted):
-                if ci == cj:
-                    d[j] = i
-                    continue
-        chars2_length = [len(c) for c in chars2_sorted]
-        char_maxl = max(chars2_length)
-        chars2_mask = np.zeros((len(chars2_sorted), char_maxl), dtype='int')
-        for i, c in enumerate(chars2_sorted):
-            chars2_mask[i, :chars2_length[i]] = c
-
-        chars2_mask = Variable(torch.LongTensor(chars2_mask))
+
+        if parameters['char_mode'] == 'LSTM':
+            chars2_sorted = sorted(chars2, key=lambda p: len(p), reverse=True)
+            d = {}
+            for i, ci in enumerate(chars2):
+                for j, cj in enumerate(chars2_sorted):
+                    if ci == cj:
+                        d[j] = i
+                        continue
+            chars2_length = [len(c) for c in chars2_sorted]
+            char_maxl = max(chars2_length)
+            chars2_mask = np.zeros((len(chars2_sorted), char_maxl), dtype='int')
+            for i, c in enumerate(chars2_sorted):
+                chars2_mask[i, :chars2_length[i]] = c
+            chars2_mask = Variable(torch.LongTensor(chars2_mask))
+
+        if parameters['char_mode'] == 'CNN':
+            d = {}
+            chars2_length = [len(c) for c in chars2]
+            char_maxl = max(chars2_length)
+            chars2_mask = np.zeros((len(chars2_length), char_maxl), dtype='int')
+            for i, c in enumerate(chars2):
+                chars2_mask[i, :chars2_length[i]] = c
+            chars2_mask = Variable(torch.LongTensor(chars2_mask))
+
         dwords = Variable(torch.LongTensor(data['words']))
         dcaps = Variable(torch.LongTensor(caps))
-
         if use_gpu:
             val, out = model(dwords.cuda(), chars2_mask.cuda(), dcaps.cuda(), chars2_length, d)
         else:
@@ -164,11 +178,10 @@ def eval(model, datas):
     ))
 
 # for l in range(1, 6):
-#     for i in range(10):
-#         eval(model, test_data, l)
-#     print()
-#     print()
-
-eval(model, test_data)
+#     print('maxl=', l)
+#     eval(model, test_data, l)
+# # print()
+# # for i in range(10):
+# # eval(model, test_data, 100)
 
 print(time.time() - t)
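
Note: in the new CNN branch above, each word's character list only needs zero-padding to the longest word in the sentence; the length-sorting and index map d that the LSTM branch keeps for pack_padded_sequence are not needed (d stays empty). A minimal standalone sketch of that padding step, with made-up character indices:

import numpy as np

chars2 = [[3, 1, 4], [1, 5], [9, 2, 6, 5]]      # hypothetical char ids for 3 words
chars2_length = [len(c) for c in chars2]        # [3, 2, 4]
char_maxl = max(chars2_length)                  # 4
chars2_mask = np.zeros((len(chars2), char_maxl), dtype='int')
for i, c in enumerate(chars2):
    chars2_mask[i, :chars2_length[i]] = c       # right-pad each word with zeros

print(chars2_mask)
# [[3 1 4 0]
#  [1 5 0 0]
#  [9 2 6 5]]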

loader.py

Lines changed: 6 additions & 3 deletions
@@ -71,12 +71,15 @@ def word_mapping(sentences, lower):
     """
     words = [[x[0].lower() if lower else x[0] for x in s] for s in sentences]
     dico = create_dico(words)
+
     dico['<PAD>'] = 10000001
     dico['<UNK>'] = 10000000
+    dico = {k:v for k,v in dico.items() if v>=3}
     word_to_id, id_to_word = create_mapping(dico)
-    print "Found %i unique words (%i in total)" % (
+
+    print("Found %i unique words (%i in total)" % (
         len(dico), sum(len(x) for x in words)
-    )
+    ))
     return dico, word_to_id, id_to_word
 
 
@@ -142,7 +145,7 @@ def f(x): return x.lower() if lower else x
     }
 
 
-def prepare_dataset(sentences, word_to_id, char_to_id, tag_to_id, lower=False):
+def prepare_dataset(sentences, word_to_id, char_to_id, tag_to_id, lower=True):
     """
     Prepare the dataset. Return a list of lists of dictionaries containing:
         - word indexes
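
Note: the new dictionary comprehension in word_mapping drops words seen fewer than 3 times from the vocabulary; '<PAD>' and '<UNK>' are given artificially large counts just before the filter so they always survive it. A tiny sketch with made-up frequencies:

dico = {'the': 120, 'jumps': 2, 'zygote': 1}        # hypothetical word counts
dico['<PAD>'] = 10000001
dico['<UNK>'] = 10000000
dico = {k: v for k, v in dico.items() if v >= 3}    # keep only words seen >= 3 times
print(sorted(dico))                                 # ['<PAD>', '<UNK>', 'the']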

model.py

Lines changed: 92 additions & 48 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.autograd as autograd
-import torch.nn as nn
 from torch.autograd import Variable
+from utils import *
 
 START_TAG = '<START>'
 STOP_TAG = '<STOP>'
@@ -34,7 +34,7 @@ class BiLSTM_CRF(nn.Module):
 
     def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim, char_lstm_dim=25,
                  char_to_ix=None, pre_word_embeds=None, char_embedding_dim=25, use_gpu=False,
-                 n_cap=None, cap_embedding_dim=None, use_crf=True):
+                 n_cap=None, cap_embedding_dim=None, use_crf=True, char_mode='CNN'):
         super(BiLSTM_CRF, self).__init__()
         self.use_gpu = use_gpu
         self.embedding_dim = embedding_dim
@@ -45,13 +45,21 @@ def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim, char_lstm_d
         self.cap_embedding_dim = cap_embedding_dim
         self.use_crf = use_crf
         self.tagset_size = len(tag_to_ix)
+        self.out_channels = char_lstm_dim
+        self.char_mode = char_mode
         if self.n_cap and self.cap_embedding_dim:
             self.cap_embeds = nn.Embedding(self.n_cap, self.cap_embedding_dim)
+            init_embedding(self.cap_embeds.weight)
 
         if char_embedding_dim is not None:
             self.char_lstm_dim = char_lstm_dim
             self.char_embeds = nn.Embedding(len(char_to_ix), char_embedding_dim)
-            self.char_lstm = nn.LSTM(char_embedding_dim, char_lstm_dim, num_layers=1, bidirectional=True)
+            init_embedding(self.char_embeds.weight)
+            if self.char_mode == 'LSTM':
+                self.char_lstm = nn.LSTM(char_embedding_dim, char_lstm_dim, num_layers=1, bidirectional=True)
+                init_lstm(self.char_lstm)
+            if self.char_mode == 'CNN':
+                self.char_cnn3 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, kernel_size=(3, char_embedding_dim), padding=(2,0))
 
         self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
         if pre_word_embeds is not None:
@@ -62,39 +70,65 @@ def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim, char_lstm_d
 
         self.dropout = nn.Dropout(0.5)
         if self.n_cap and self.cap_embedding_dim:
-            self.lstm = nn.LSTM(embedding_dim+char_lstm_dim*2+cap_embedding_dim, hidden_dim,
-                                num_layers=1, bidirectional=True)
+            if self.char_mode == 'LSTM':
+                self.lstm = nn.LSTM(embedding_dim+char_lstm_dim*2+cap_embedding_dim, hidden_dim, bidirectional=True)
+            if self.char_mode == 'CNN':
+                self.lstm = nn.LSTM(embedding_dim+self.out_channels+cap_embedding_dim, hidden_dim, bidirectional=True)
         else:
-            self.lstm = nn.LSTM(embedding_dim+char_lstm_dim*2, hidden_dim,
-                                num_layers=1, bidirectional=True)
+            if self.char_mode == 'LSTM':
+                self.lstm = nn.LSTM(embedding_dim+char_lstm_dim*2, hidden_dim, bidirectional=True)
+            if self.char_mode == 'CNN':
+                self.lstm = nn.LSTM(embedding_dim+self.out_channels, hidden_dim, bidirectional=True)
+        init_lstm(self.lstm)
+        self.hw_trans = nn.Linear(self.out_channels, self.out_channels)
+        self.hw_gate = nn.Linear(self.out_channels, self.out_channels)
         self.h2_h1 = nn.Linear(hidden_dim*2, hidden_dim)
         self.tanh = nn.Tanh()
         self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
+        init_linear(self.h2_h1)
+        init_linear(self.hidden2tag)
+        init_linear(self.hw_gate)
+        init_linear(self.hw_trans)
 
-        # trans is also a score tensor, not a probability
         if self.use_crf:
             self.transitions = nn.Parameter(
-                torch.randn(self.tagset_size, self.tagset_size))
+                torch.zeros(self.tagset_size, self.tagset_size))
             self.transitions.data[tag_to_ix[START_TAG], :] = -10000
             self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000
 
-    def init_char_hidden(self, batchsize):
-
-        if self.use_gpu:
-            return (autograd.Variable(torch.randn(2, batchsize, self.char_lstm_dim)).cuda(),
-                    autograd.Variable(torch.randn(2, batchsize, self.char_lstm_dim)).cuda())
-        else:
-            return (autograd.Variable(torch.randn(2, batchsize, self.char_lstm_dim)),
-                    autograd.Variable(torch.randn(2, batchsize, self.char_lstm_dim)))
-
-
-    def init_hidden(self):
-        if self.use_gpu:
-            return (autograd.Variable(torch.randn(2, 1, self.hidden_dim)).cuda(),
-                    autograd.Variable(torch.randn(2, 1, self.hidden_dim)).cuda())
+    def init_lstm_hidden(self, dim, bidirection=True, batchsize=1):
+        l = 1 + bidirection
+        if self.training:
+            if self.use_gpu:
+                return (Variable(torch.randn(l, batchsize, dim)).cuda(),
+                        Variable(torch.randn(l, batchsize, dim)).cuda())
+            else:
+                return (Variable(torch.randn(l, batchsize, dim)),
+                        Variable(torch.randn(l, batchsize, dim)))
         else:
-            return (autograd.Variable(torch.randn(2, 1, self.hidden_dim)),
-                    autograd.Variable(torch.randn(2, 1, self.hidden_dim)))
+            if self.use_gpu:
+                return (Variable(torch.zeros(l, batchsize, dim)).cuda(),
+                        Variable(torch.zeros(l, batchsize, dim)).cuda())
+            else:
+                return (Variable(torch.zeros(l, batchsize, dim)),
+                        Variable(torch.zeros(l, batchsize, dim)))
+
+
+    # def init_hidden(self):
+    #     if self.training:
+    #         if self.use_gpu:
+    #             return (autograd.Variable(torch.randn(2, 1, self.hidden_dim)).cuda(),
+    #                     autograd.Variable(torch.randn(2, 1, self.hidden_dim)).cuda())
+    #         else:
+    #             return (autograd.Variable(torch.randn(2, 1, self.hidden_dim)),
+    #                     autograd.Variable(torch.randn(2, 1, self.hidden_dim)))
+    #     else:
+    #         if self.use_gpu:
+    #             return (autograd.Variable(torch.zeros(2, 1, self.hidden_dim)).cuda(),
+    #                     autograd.Variable(torch.zeros(2, 1, self.hidden_dim)).cuda())
+    #         else:
+    #             return (autograd.Variable(torch.zeros(2, 1, self.hidden_dim)),
+    #                     autograd.Variable(torch.zeros(2, 1, self.hidden_dim)))
 
     def _score_sentence(self, feats, tags):
         # tags is ground_truth, a list of ints, length is len(sentence)
@@ -113,29 +147,41 @@ def _score_sentence(self, feats, tags):
         return score
 
     def _get_lstm_features(self, sentence, chars2, caps, chars2_length, d):
-        # sentence: a list of ints
-        # initialize lstm hidden state, h and c
-        self.hidden = self.init_hidden()
-        self.char_lstm_hidden = self.init_char_hidden(chars2.size(0))
-
-        chars_embeds = self.char_embeds(chars2).transpose(0, 1)
-        packed = torch.nn.utils.rnn.pack_padded_sequence(chars_embeds, chars2_length)
-        lstm_out, self.char_lstm_hidden = self.char_lstm(packed, self.char_lstm_hidden)
-
-        # outputs: maxlength * len(sentence) * hiddensize
-        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
-        outputs = outputs.transpose(0, 1)
-        chars_embeds_temp = Variable(torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2)))))
-        if self.use_gpu:
-            chars_embeds_temp = chars_embeds_temp.cuda()
-        for i, index in enumerate(output_lengths):
-            chars_embeds_temp[i] = outputs[i, index-1]
-        chars_embeds = chars_embeds_temp.clone()
-        for i in range(chars_embeds.size(0)):
-            chars_embeds[d[i]] = chars_embeds_temp[i]
 
-        embeds = self.word_embeds(sentence)
+        # # sentence: a list of ints
+        # # initialize lstm hidden state, h and c
+        # # self.hidden = self.init_hidden()
+        self.hidden = self.init_lstm_hidden(dim=self.hidden_dim, bidirection=True, batchsize=1)
 
+        if self.char_mode == 'LSTM':
+
+            self.char_lstm_hidden = self.init_lstm_hidden(dim=self.char_lstm_dim, bidirection=True, batchsize=chars2.size(0))
+            chars_embeds = self.char_embeds(chars2).transpose(0, 1)
+            packed = torch.nn.utils.rnn.pack_padded_sequence(chars_embeds, chars2_length)
+            lstm_out, self.char_lstm_hidden = self.char_lstm(packed, self.char_lstm_hidden)
+            outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
+            outputs = outputs.transpose(0, 1)
+            chars_embeds_temp = Variable(torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2)))))
+            if self.use_gpu:
+                chars_embeds_temp = chars_embeds_temp.cuda()
+            for i, index in enumerate(output_lengths):
+                chars_embeds_temp[i] = torch.cat((outputs[i, index-1, :self.char_lstm_dim], outputs[i, 0, self.char_lstm_dim:]))
+            chars_embeds = chars_embeds_temp.clone()
+            for i in range(chars_embeds.size(0)):
+                chars_embeds[d[i]] = chars_embeds_temp[i]
+
+        if self.char_mode == 'CNN':
+            chars_embeds = self.char_embeds(chars2).unsqueeze(1)
+            chars_cnn_out3 = self.char_cnn3(chars_embeds)
+            chars_embeds = nn.functional.max_pool2d(chars_cnn_out3,
+                                                    kernel_size=(chars_cnn_out3.size(2), 1)).view(chars_cnn_out3.size(0), self.out_channels)
+
+        # t = self.hw_gate(chars_embeds)
+        # g = nn.functional.sigmoid(t)
+        # h = nn.functional.relu(self.hw_trans(chars_embeds))
+        # chars_embeds = g * h + (1 - g) * chars_embeds
+
+        embeds = self.word_embeds(sentence)
         if self.n_cap and self.cap_embedding_dim:
             cap_embedding = self.cap_embeds(caps)
 
@@ -146,7 +192,6 @@ def _get_lstm_features(self, sentence, chars2, caps, chars2_length, d):
 
         embeds = embeds.unsqueeze(1)
         embeds = self.dropout(embeds)
-
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
         lstm_out = lstm_out.view(len(sentence), self.hidden_dim*2)
         lstm_out = self.h2_h1(lstm_out)
@@ -215,7 +260,6 @@ def neg_log_likelihood(self, sentence, tags, chars2, caps, chars2_length, d):
 
         if self.use_crf:
             forward_score = self._forward_alg(feats)
-            # calculate the score of the ground_truth, in CRF
            gold_score = self._score_sentence(feats, tags)
             return forward_score - gold_score
         else:
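
Note: a standalone sketch of the char-CNN path added to _get_lstm_features. Each word's padded character ids are embedded, a single Conv2d with kernel (3, char_embedding_dim) slides along the character axis, and max-pooling over that axis leaves one out_channels-dimensional vector per word. The sizes below are illustrative, not this repo's configuration:

import torch
import torch.nn as nn

n_chars, char_embedding_dim, out_channels = 50, 25, 25
char_embeds = nn.Embedding(n_chars, char_embedding_dim)
char_cnn3 = nn.Conv2d(in_channels=1, out_channels=out_channels,
                      kernel_size=(3, char_embedding_dim), padding=(2, 0))

chars2 = torch.randint(0, n_chars, (7, 11))     # 7 words, each padded to 11 characters
x = char_embeds(chars2).unsqueeze(1)            # (7, 1, 11, 25): add a channel dimension
conv = char_cnn3(x)                             # (7, 25, 13, 1): 11 + 2*2 - 3 + 1 = 13
pooled = nn.functional.max_pool2d(conv, kernel_size=(conv.size(2), 1))
chars_embeds = pooled.view(conv.size(0), out_channels)
print(chars_embeds.shape)                       # torch.Size([7, 25]): one vector per word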

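Note: init_lstm_hidden now keys off nn.Module's training flag, returning random initial states while training and zeros at evaluation time. A minimal sketch of that behaviour with a toy module (illustration only; plain CPU tensors stand in for the Variable/.cuda() handling above):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def init_lstm_hidden(self, dim, bidirection=True, batchsize=1):
        l = 1 + bidirection                     # 2 state slices for a bidirectional LSTM
        if self.training:                       # flag toggled by .train() / .eval()
            return (torch.randn(l, batchsize, dim),
                    torch.randn(l, batchsize, dim))
        return (torch.zeros(l, batchsize, dim),
                torch.zeros(l, batchsize, dim))

m = Toy()
m.train()
print(m.init_lstm_hidden(4)[0].abs().sum() > 0)     # tensor(True): random init in training
m.eval()
print(m.init_lstm_hidden(4)[0].abs().sum() == 0)    # tensor(True): zeros at evaluation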