Skip to content

Commit cb59d00

Browse files
committed
Edit
1 parent e8621bf commit cb59d00

File tree

4 files changed

+79
-19
lines changed

4 files changed

+79
-19
lines changed

3. Model/classification/datasets.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os
2-
from torchvision.datasets import CIFAR10, CIFAR100, MNIST
2+
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, ImageFolder
33

44
def get_dataset(name, train_transform=None, test_transform=None):
55
data_paths = {
66
'mnist': r'C:\Users\gjust\Documents\Github\data',
77
'cifar10': r'C:\Users\gjust\Documents\Github\data',
88
'cifar100': r'C:\Users\gjust\Documents\Github\data',
9-
'coco': r'C:\Users\gjust\Documents\Github\data\COCO'
9+
'coco': r'C:\Users\gjust\Documents\Github\data\COCO',
10+
'catdog': r'C:\Users\gjust\Documents\Github\data\dogs-vs-cats'
1011
}
1112

1213
data_path = data_paths[name]
@@ -30,8 +31,14 @@ def get_dataset(name, train_transform=None, test_transform=None):
3031
train_set = CIFAR100(root=data_path, transform=train_transform, train=True, download=True)
3132
test_set = CIFAR100(root=data_path, transform=test_transform, train=False, download=True)
3233

34+
elif name == 'catdog':
35+
train_path = os.path.join(data_path, 'train')
36+
# test_path = os.path.join(data_path, 'test')
37+
train_set = ImageFolder(root=train_path, transform=train_transform)
38+
test_set = ImageFolder(root=train_path, transform=test_transform)
39+
3340
return train_set, test_set
3441

3542
if __name__ == '__main__':
36-
data_path = get_dataset('cifar10')
43+
data_path = get_dataset('catdog')
3744
print(data_path)

3. Model/classification/metrics.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torchmetrics
2+
3+
if __name__ == '__main__':
4+
import torch
5+
# import our library
6+
import torchmetrics
7+
8+
# initialize metric
9+
metric = torchmetrics.Accuracy(num_classes=5).cuda()
10+
recall = torchmetrics.Recall(num_classes=5, multiclass=True, average='samples').cuda()
11+
precision = torchmetrics.Precision(num_classes=5, multiclass=True).cuda()
12+
13+
n_batches = 100
14+
for i in range(n_batches):
15+
# simulate a classification problem
16+
preds = torch.randn(10, 5).cuda()
17+
target = torch.randint(5, (10,)).cuda()
18+
19+
# metric on current batch
20+
metric(preds, target)
21+
recall(preds, target)
22+
precision(preds, target)
23+
result(preds, target)
24+
#print(f"Accuracy on batch {i}: {acc:0.4f}")
25+
# metric on all batches using custom accumulation
26+
acc = metric.compute()
27+
recall = metric.compute()
28+
precision = metric.compute()
29+
30+
print(f"Accuracy on all data: {acc:0.4}")
31+
print(f"Recall: {recall:0.4}")
32+
print(f"Precision: {precision:0.4}")
33+
# print(f'result : {result.compute()}')
34+
acc = metric.compute()
35+
# Reseting internal state such that metric ready for new data
36+
metric.reset()

3. Model/classification/models.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,36 @@
22
from ResNet.model import ResNet, BasicBlock, Bottleneck
33
from GoogLeNet.model import GoogLeNet
44

5+
import torchvision
56

6-
def get_model(name='resnet18', num_classes=10):
7-
models = {
8-
9-
'vgg11': VGG('vgg11', num_classes=num_classes),
10-
'vgg13': VGG('vgg13', num_classes=num_classes),
11-
'vgg16': VGG('vgg16', num_classes=num_classes),
12-
'vgg19': VGG('vgg19', num_classes=num_classes),
13-
14-
'resnet18': ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes),
15-
'resnet34': ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes),
16-
'resnet50': ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes),
17-
'resnet101': ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes),
18-
'resnet152': ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes),
7+
8+
def get_model(name='resnet18', num_classes=10, pretrained=False):
9+
10+
if pretrained:
11+
models = {
12+
'vgg11': torchvision.models.vgg11_bn(pretrained=True),
13+
'vgg13': torchvision.models.vgg13_bn(pretrained=True),
14+
'vgg16': torchvision.models.vgg16_bn(pretrained=True),
15+
'vgg19': torchvision.models.vgg19_bn(pretrained=True)
16+
}
1917

20-
'googlenet': GoogLeNet(num_classes=num_classes)
21-
}
18+
return models[name]
19+
20+
else:
21+
models = {
22+
23+
'vgg11': VGG('vgg11', num_classes=num_classes),
24+
'vgg13': VGG('vgg13', num_classes=num_classes),
25+
'vgg16': VGG('vgg16', num_classes=num_classes),
26+
'vgg19': VGG('vgg19', num_classes=num_classes),
27+
28+
'resnet18': ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes),
29+
'resnet34': ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes),
30+
'resnet50': ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes),
31+
'resnet101': ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes),
32+
'resnet152': ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes),
33+
34+
'googlenet': GoogLeNet(num_classes=num_classes)
35+
}
2236

23-
return models[name]
37+
return models[name]

3. Model/classification/remove.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import torch
2+
3+
print(torch.randint(5, (10,)))

0 commit comments

Comments
 (0)