Skip to content

Commit b3cf9f2

Browse files
committed
Working on loss
1 parent 14b71e5 commit b3cf9f2

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

scripts/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from torch.utils.data import DataLoader
33

4+
from src.YOLOLoss import YOLOLoss
45
from src.YOLOv3 import YOLOv3
56
from src.YOLOv3Dataset import YOLOv3Dataset
67

@@ -11,4 +12,6 @@
1112
# model.load_state_dict(torch.load("weights.pt"))
1213

1314
train_dataset = YOLOv3Dataset()
14-
dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
15+
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
16+
17+
loss = YOLOLoss()

src/YOLOLoss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ def __init__(self):
1212
self.mse_loss = nn.MSELoss() # For Bounding Box Prediction
1313
self.bce_loss = nn.BCELoss() # For Class Prediction
1414

15-
def forward(self, x, y):
16-
# self.metrics, total_loss = self.calculate_metrics(pred_boxes=pred_boxes, pred_cls=pred_cls, targets=y, x=x_0,
17-
# y=y_0, w=w, h=h, pred_conf=pred_conf)
15+
def forward(self, predictions, ground_truth):
16+
self.metrics, total_loss = self.calculate_metrics(pred_boxes=pred_boxes, pred_cls=pred_cls, targets=y, x=x_0,
17+
y=y_0, w=w, h=h, pred_conf=pred_conf)
1818

1919
return
2020

src/YOLOv3.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch.nn as nn
22

3-
from src.YOLODummy import YOLODummy
43
from src.YOLOLayer import YOLOLayer
54
from src.YOLOModule import YOLOModule
65

src/YOLOv3Dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
import torch.nn as nn
2+
from torch.utils.data import Dataset
33
from torchvision import transforms
44
import torch.nn.functional as F
55

@@ -10,7 +10,7 @@
1010
import numpy as np
1111

1212

13-
class YOLOv3Dataset(nn.Module):
13+
class YOLOv3Dataset(Dataset):
1414

1515
def __init__(self, images_list_path="../data/images_list.txt", image_size=416, should_augment=True, transform=None):
1616
super(YOLOv3Dataset, self).__init__()

0 commit comments

Comments
 (0)