Skip to content

Commit b81eeed

Browse files
Merge pull request falloutdurham#43 from MarcusFra/ch04_range
Change range() to count epochs from 1
2 parents 933b665 + b5648ad commit b81eeed

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

chapter4/Chapter 4.ipynb

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
"outputs": [],
8585
"source": [
8686
"def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
87-
" for epoch in range(epochs):\n",
87+
" for epoch in range(1, epochs+1):\n",
8888
" training_loss = 0.0\n",
8989
" valid_loss = 0.0\n",
9090
" model.train()\n",
@@ -131,11 +131,12 @@
131131
" return True\n",
132132
" except:\n",
133133
" return False\n",
134+
"\n",
134135
"img_transforms = transforms.Compose([\n",
135136
" transforms.Resize((64,64)), \n",
136137
" transforms.ToTensor(),\n",
137138
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
138-
" std=[0.229, 0.224, 0.225] )\n",
139+
" std=[0.229, 0.224, 0.225] )\n",
139140
" ])\n",
140141
"train_data_path = \"./train/\"\n",
141142
"train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms, is_valid_file=check_image)\n",
@@ -144,6 +145,7 @@
144145
"batch_size=64\n",
145146
"train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
146147
"val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
148+
"\n",
147149
"if torch.cuda.is_available():\n",
148150
" device = torch.device(\"cuda\") \n",
149151
"else:\n",
@@ -175,7 +177,8 @@
175177
"metadata": {},
176178
"outputs": [],
177179
"source": [
178-
"train(transfer_model, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=5, device=device)"
180+
"train(transfer_model, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader, val_data_loader, epochs=5,\n",
181+
" device=device)"
179182
]
180183
},
181184
{
@@ -406,4 +409,4 @@
406409
},
407410
"nbformat": 4,
408411
"nbformat_minor": 2
409-
}
412+
}

0 commit comments

Comments
 (0)