84
84
"outputs" : [],
85
85
"source" : [
86
86
" def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\" cpu\" ):\n " ,
87
- " for epoch in range(epochs):\n " ,
87
+ " for epoch in range(1, epochs+1 ):\n " ,
88
88
" training_loss = 0.0\n " ,
89
89
" valid_loss = 0.0\n " ,
90
90
" model.train()\n " ,
131
131
" return True\n " ,
132
132
" except:\n " ,
133
133
" return False\n " ,
134
+ " \n " ,
134
135
" img_transforms = transforms.Compose([\n " ,
135
136
" transforms.Resize((64,64)), \n " ,
136
137
" transforms.ToTensor(),\n " ,
137
138
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n " ,
138
- " std=[0.229, 0.224, 0.225] )\n " ,
139
+ " std=[0.229, 0.224, 0.225] )\n " ,
139
140
" ])\n " ,
140
141
" train_data_path = \" ./train/\"\n " ,
141
142
" train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms, is_valid_file=check_image)\n " ,
144
145
" batch_size=64\n " ,
145
146
" train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n " ,
146
147
" val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=True)\n " ,
148
+ " \n " ,
147
149
" if torch.cuda.is_available():\n " ,
148
150
" device = torch.device(\" cuda\" ) \n " ,
149
151
" else:\n " ,
175
177
"metadata" : {},
176
178
"outputs" : [],
177
179
"source" : [
178
- " train(transfer_model, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=5, device=device)"
180
+ " train(transfer_model, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader, val_data_loader, epochs=5,\n " ,
181
+ " device=device)"
179
182
]
180
183
},
181
184
{
406
409
},
407
410
"nbformat" : 4 ,
408
411
"nbformat_minor" : 2
409
- }
412
+ }
0 commit comments