1
+ import torch
2
+ from byol_pytorch import BYOL
3
+ from torchvision import models
4
+ from torchvision .models import ResNet50_Weights
5
+ from torch .utils .tensorboard import SummaryWriter
6
+
7
+ import datetime
8
+
9
+ def sample_unlabelled_images ():
10
+ return torch .randn (20 , 3 , 256 , 256 )
11
+
12
+ class Trainer ():
13
+ def __init__ (self , params ) -> None :
14
+ self .train_dataloader = params ['train_data' ]
15
+ self .test_dataloader = params ['test_data' ]
16
+ self .BYOL = params ['byol' ]
17
+ self .optimizer = params ['optim' ]
18
+
19
+ def train (self , epoch ):
20
+ self .BYOL .train ()
21
+ epoch_loss = 0
22
+ num_total = 0
23
+ # for i, data in enumerate(self.train_dataloader):
24
+ images = sample_unlabelled_images () # images, label = data
25
+ loss = self .BYOL (images )
26
+ self .optimizer .zero_grad ()
27
+ loss .backward ()
28
+ self .optimizer .step ()
29
+ self .BYOL .update_moving_average ()
30
+ epoch_loss += loss .item ()
31
+ num_total += images .size (0 )
32
+
33
+ loss = epoch_loss / num_total
34
+ return loss
35
+
36
+ def test (self , epoch ):
37
+ self .BYOL .eval ()
38
+ epoch_loss = 0
39
+ num_total = 0
40
+ images = sample_unlabelled_images ()
41
+ loss = self .BYOL (images )
42
+ epoch_loss += loss .item ()
43
+ num_total += images .size (0 )
44
+
45
+ loss = epoch_loss / num_total
46
+ return loss
47
+
48
+ def main ():
49
+ resnet = models .resnet50 (weights = ResNet50_Weights .IMAGENET1K_V1 )
50
+ learner = BYOL (
51
+ resnet ,
52
+ image_size = 256 ,
53
+ hidden_layer = 'avgpool' )
54
+ opt = torch .optim .Adam (learner .parameters (), lr = 3e-4 )
55
+ log_dir = '.\\ logs'
56
+ run_time = datetime .datetime .now ().strftime ('%Y%m%d_%H%M%S' )
57
+ pt_dir = f'.\\ experiments\\ { run_time } _best.pt'
58
+ # writer = SummaryWriter(log_dir)
59
+ epochs = 3
60
+ parameters = {'train_data' :None ,
61
+ 'test_data' :None ,
62
+ 'byol' :learner ,
63
+ 'optim' :opt }
64
+
65
+ trainer = Trainer (params = parameters )
66
+ best_loss = 10000
67
+ for epoch in range (epochs ):
68
+ # train_loss = trainer.train(epoch)
69
+ # writer.add_scalar("Loss/Train", train_loss, epoch)
70
+ # print(f'Train_loss: {train_loss:4f}')
71
+
72
+ test_loss = trainer .test (epoch )
73
+ # writer.add_scalar("Loss/Train", test_loss, epoch)
74
+ print (f'Test_loss: { test_loss :4f} ' )
75
+
76
+ if test_loss < best_loss :
77
+ torch .save (trainer .BYOL .state_dict (), pt_dir )
78
+
79
+ # writer.close()
80
+
81
+
82
+ main ()
0 commit comments