@@ -87,28 +87,32 @@ def del_list_idx(l, id_to_del):
8787#all_features = [DocFeature(doc_id=ids, seg_labels=all_labels.loc[all_labels['id'] == ids],
8888# raw_text=all_doc_texts[all_doc_ids.index(ids)], train_or_test='test', tokenizer=TOKENIZER) for ids in tqdm(all_doc_ids)]
8989
90+ try :
91+ print ('Dataset Loading...' )
92+ train_ds = torch .load ('train_ds.pt' )
93+ dev_ds = torch .load ('dev_ds.pt' )
94+ test_ds = torch .load ('test_ds.pt' )
9095
91- train_features = [DocFeature (doc_id = ids , seg_labels = all_labels .loc [all_labels ['id' ] == ids ],
92- raw_text = train_doc_texts [train_doc_ids .index (ids )], train_or_test = 'train' , tokenizer = TOKENIZER ) for ids in tqdm (train_doc_ids )]
93- dev_features = [DocFeature (doc_id = ids , seg_labels = all_labels .loc [all_labels ['id' ] == ids ],
94- raw_text = dev_doc_texts [dev_doc_ids .index (ids )], train_or_test = 'train' , tokenizer = TOKENIZER ) for ids in tqdm (dev_doc_ids )]
95- test_features = [DocFeature (doc_id = ids , raw_text = test_doc_texts [test_doc_ids .index (
96- ids )], train_or_test = 'test' , tokenizer = TOKENIZER ) for ids in test_doc_ids ]
96+ except FileNotFoundError :
97+ print ('Create Dataset' )
98+ train_features = [DocFeature (doc_id = ids , seg_labels = all_labels .loc [all_labels ['id' ] == ids ],
99+ raw_text = train_doc_texts [train_doc_ids .index (ids )], train_or_test = 'train' , tokenizer = TOKENIZER ) for ids in tqdm (train_doc_ids )]
100+ dev_features = [DocFeature (doc_id = ids , seg_labels = all_labels .loc [all_labels ['id' ] == ids ],
101+ raw_text = dev_doc_texts [dev_doc_ids .index (ids )], train_or_test = 'train' , tokenizer = TOKENIZER ) for ids in tqdm (dev_doc_ids )]
102+ test_features = [DocFeature (doc_id = ids , raw_text = test_doc_texts [test_doc_ids .index (
103+ ids )], train_or_test = 'test' , tokenizer = TOKENIZER ) for ids in test_doc_ids ]
97104
98- train_features = [f for f in train_features if f .err is False ]
99- dev_features = [f for f in dev_features if f .err is False ]
100- test_features = [f for f in test_features if f .err is False ]
105+          train_features = [f for f in train_features if f .err is False ]
101106
102- train_ds = create_tensor_ds_sliding_window (train_features )
103- dev_ds = create_tensor_ds_sliding_window (dev_features )
104- test_ds = create_tensor_ds_sliding_window_test (dev_features )
105107
108+ train_ds = create_tensor_ds_sliding_window (train_features )
109+ dev_ds = create_tensor_ds_sliding_window (dev_features )
110+          test_ds = create_tensor_ds_sliding_window_test (test_features )
106111
107-
108- # Due to design of huggingface's tokenizer, not possible to multithread to speed up the loading
109- # Better run once and load for future development
110- #train_ds = torch.load('train_ds.pt')
111- #dev_ds = torch.load('dev_ds.pt')
112+ print ('Dataset Saving...' )
113+ torch .save (train_ds , 'train_ds.pt' )
114+ torch .save (dev_ds ,'dev_ds.pt' )
115+ torch .save (test_ds , 'test_ds.pt' )
112116
113117
114118train_sp = RandomSampler (train_ds )
@@ -308,4 +312,4 @@ def update(self, val, n=1):
308312 preds = list (itertools .chain .from_iterable (list (itertools .chain .from_iterable (preds ))))
309313 targets = list (itertools .chain .from_iterable (list (itertools .chain .from_iterable (targets ))))
310314 print (classification_report (targets , preds , digits = 4 ))
311- print ()
315+ print ()
0 commit comments