Skip to content

Commit 4faf8e9

Browse files
nickdrhodesalexmirrington
authored andcommitted
default model and full dataset
1 parent 98f37d0 commit 4faf8e9

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

code/algorithm/main.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,13 @@ def main(config):
5555
if not config.noval:
5656
# Use 10% of data for validation
5757
split = int(0.9*len(train_val_data))
58-
# train_data = Subset(train_val_data, range(0, split))
59-
# val_data = Subset(train_val_data, range(split, len(train_val_data)))
60-
train_data = Subset(train_val_data, range(0, 2))
61-
val_data = Subset(train_val_data, range(2, 4))
58+
train_data = Subset(train_val_data, range(0, split))
59+
val_data = Subset(train_val_data, range(split, len(train_val_data)))
6260
print(f'Train: {len(train_data)}')
6361
print(f'Val: {len(val_data)}')
6462
else:
6563
# Do not create a validation set
66-
#train_data = train_val_data
67-
train_data = Subset(train_val_data, range(0, 2))
68-
64+
train_data = train_val_data
6965
val_data = None
7066
print(f'Train: {len(train_data)}')
7167

@@ -424,13 +420,14 @@ def parse_args(args):
424420
)
425421
model_params_group.add_argument(
426422
'--model-type',
423+
default='rcnn_lstm',
427424
choices=['rcnn_lstm',
428425
'rcnn',
429426
'lstm',
430427
'tfidf'
431428
],
432429
type=str,
433-
required=True,
430+
required=False,
434431
help='Which model type to run.'
435432
)
436433
model_params_group.add_argument(

0 commit comments

Comments
 (0)