Skip to content

Commit e5295e7

Browse files
committed
byol_train_v1
1 parent f12087b commit e5295e7

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

byol_pytorch/train.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch
2+
from byol_pytorch import BYOL
3+
from torchvision import models
4+
from torchvision.models import ResNet50_Weights
5+
from torch.utils.tensorboard import SummaryWriter
6+
7+
import datetime
8+
9+
def sample_unlabelled_images():
10+
return torch.randn(20, 3, 256, 256)
11+
12+
class Trainer():
13+
def __init__(self, params) -> None:
14+
self.train_dataloader = params['train_data']
15+
self.test_dataloader = params['test_data']
16+
self.BYOL = params['byol']
17+
self.optimizer = params['optim']
18+
19+
def train(self, epoch):
20+
self.BYOL.train()
21+
epoch_loss = 0
22+
num_total = 0
23+
# for i, data in enumerate(self.train_dataloader):
24+
images = sample_unlabelled_images() # images, label = data
25+
loss = self.BYOL(images)
26+
self.optimizer.zero_grad()
27+
loss.backward()
28+
self.optimizer.step()
29+
self.BYOL.update_moving_average()
30+
epoch_loss += loss.item()
31+
num_total += images.size(0)
32+
33+
loss = epoch_loss / num_total
34+
return loss
35+
36+
def test(self, epoch):
37+
self.BYOL.eval()
38+
epoch_loss = 0
39+
num_total = 0
40+
images = sample_unlabelled_images()
41+
loss = self.BYOL(images)
42+
epoch_loss += loss.item()
43+
num_total += images.size(0)
44+
45+
loss = epoch_loss / num_total
46+
return loss
47+
48+
def main():
49+
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
50+
learner = BYOL(
51+
resnet,
52+
image_size = 256,
53+
hidden_layer= 'avgpool')
54+
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
55+
log_dir = '.\\logs'
56+
run_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
57+
pt_dir = f'.\\experiments\\{run_time}_best.pt'
58+
# writer = SummaryWriter(log_dir)
59+
epochs = 3
60+
parameters = {'train_data':None,
61+
'test_data' :None,
62+
'byol' :learner,
63+
'optim' :opt}
64+
65+
trainer = Trainer(params=parameters)
66+
best_loss = 10000
67+
for epoch in range(epochs):
68+
# train_loss = trainer.train(epoch)
69+
# writer.add_scalar("Loss/Train", train_loss, epoch)
70+
# print(f'Train_loss: {train_loss:4f}')
71+
72+
test_loss = trainer.test(epoch)
73+
# writer.add_scalar("Loss/Train", test_loss, epoch)
74+
print(f'Test_loss: {test_loss:4f}')
75+
76+
if test_loss < best_loss:
77+
torch.save(trainer.BYOL.state_dict(), pt_dir)
78+
79+
# writer.close()
80+
81+
82+
main()

0 commit comments

Comments
 (0)