Skip to content

Commit 9053040

Browse files
committed
fix for 0.2
1 parent 9012fae commit 9053040

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

imagenet/main.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@
1616
import torchvision.models as models
1717

1818

19-
model_names = sorted(name for name in models.__dict__
20-
if name.islower() and not name.startswith("__")
21-
and callable(models.__dict__[name]))
22-
23-
2419
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
2520
parser.add_argument('data', metavar='DIR',
2621
help='path to dataset')
@@ -308,7 +303,7 @@ def accuracy(output, target, topk=(1,)):
308303

309304
res = []
310305
for k in topk:
311-
correct_k = correct[:k].view(-1).float().sum(0)
306+
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
312307
res.append(correct_k.mul_(100.0 / batch_size))
313308
return res
314309

0 commit comments

Comments
 (0)