Skip to content

Commit 1ed0714

Browse files
committed
2 parents 031d314 + 19ee3c4 commit 1ed0714

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

main.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,28 +87,32 @@ def del_list_idx(l, id_to_del):
8787
#all_features = [DocFeature(doc_id=ids, seg_labels=all_labels.loc[all_labels['id'] == ids],
8888
# raw_text=all_doc_texts[all_doc_ids.index(ids)], train_or_test='test', tokenizer=TOKENIZER) for ids in tqdm(all_doc_ids)]
8989

# Load cached tokenized datasets if present; otherwise build them from raw
# text and cache to disk. Due to the design of huggingface's tokenizer it is
# not possible to multithread the feature construction, so it is built once
# and reloaded on subsequent runs.
try:
    print('Dataset Loading...')
    train_ds = torch.load('train_ds.pt')
    dev_ds = torch.load('dev_ds.pt')
    test_ds = torch.load('test_ds.pt')
except FileNotFoundError:
    print('Create Dataset')
    # One DocFeature per document. Train/dev carry gold segment labels
    # (looked up by doc id); the test split has no labels.
    train_features = [DocFeature(doc_id=ids, seg_labels=all_labels.loc[all_labels['id'] == ids],
                                 raw_text=train_doc_texts[train_doc_ids.index(ids)], train_or_test='train', tokenizer=TOKENIZER) for ids in tqdm(train_doc_ids)]
    dev_features = [DocFeature(doc_id=ids, seg_labels=all_labels.loc[all_labels['id'] == ids],
                               raw_text=dev_doc_texts[dev_doc_ids.index(ids)], train_or_test='train', tokenizer=TOKENIZER) for ids in tqdm(dev_doc_ids)]
    test_features = [DocFeature(doc_id=ids, raw_text=test_doc_texts[test_doc_ids.index(
        ids)], train_or_test='test', tokenizer=TOKENIZER) for ids in test_doc_ids]

    # Drop features whose construction failed. FIX: the merge kept only the
    # train-split filter; the pre-merge code filtered all three splits, and
    # dev/test must be filtered too before tensorization.
    train_features = [f for f in train_features if f.err is False]
    dev_features = [f for f in dev_features if f.err is False]
    test_features = [f for f in test_features if f.err is False]

    train_ds = create_tensor_ds_sliding_window(train_features)
    dev_ds = create_tensor_ds_sliding_window(dev_features)
    # FIX: test_ds was built from dev_features, leaving test_features computed
    # but never used; build the test dataset from the test split.
    test_ds = create_tensor_ds_sliding_window_test(test_features)

    # Persist so future runs take the fast torch.load path above.
    print('Dataset Saving...')
    torch.save(train_ds, 'train_ds.pt')
    torch.save(dev_ds, 'dev_ds.pt')
    torch.save(test_ds, 'test_ds.pt')
112116

113117

114118
train_sp = RandomSampler(train_ds)
@@ -308,4 +312,4 @@ def update(self, val, n=1):
308312
preds=list(itertools.chain.from_iterable(list(itertools.chain.from_iterable(preds))))
309313
targets=list(itertools.chain.from_iterable(list(itertools.chain.from_iterable(targets))))
310314
print(classification_report(targets, preds, digits=4))
311-
print()
315+
print()

0 commit comments

Comments
 (0)