
Commit fe92fcb

ocean algorithm to do actual inference
1 parent 26ff850 commit fe92fcb

2 files changed: +79 −13 lines


mlweb3/ocean/algorithm.py

Lines changed: 74 additions & 10 deletions
@@ -6,24 +6,88 @@
 import argparse
 import json
 
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision.transforms import ToTensor
+
 
 def main():
     args = parse_arguments()
     print(args)
-
     print(os.getcwd())
 
-    dids = os.getenv('DIDS', None)
-    print(dids)
+    if args['local']:
+        filepath = args['weights']
+    else:
+        dids = os.getenv('DIDS', None)
+        print(dids)
+        if not dids:
+            print('no DIDs found in environment. exiting.')
+            return
+        dids = json.loads(dids)
+        if len(dids) == 0:
+            print('no DID found for model. exiting.')
+        filepath = f'data/inputs/{dids[0]}/0'
+
+    # load model weights
+    print(f'Loading SimpleCNN from {filepath}...')
+    model = SimpleCNN()
+    model.load_state_dict(torch.load(filepath))
+    model.eval()
+
+    # get/load data
+    # TODO use local data when multiple input assets are supported
+    os.makedirs('./etc/mnist', exist_ok=True)
+    data = DataLoader(
+        datasets.MNIST('./etc/mnist', train=False, download=True, transform=ToTensor()),
+        batch_size=64
+    )
+
+    # do inference
+    correct, total = 0, 0
+    predictions = []
+    with torch.no_grad():
+        for X, y in data:
+            # X, y = X.to(device), y.to(device)
+            pred = model(X)
+            correct += (pred.argmax(1) == y).sum().item()
+            total += len(X)
+            predictions.extend(pred.argmax(1).numpy().tolist())
+
+    print(f'test:\n accuracy: {correct / total:>0.4f}')
+
+    # write output
+    output_file = 'results.txt' if args['local'] else '/data/outputs/result'
+    with open(output_file, 'w') as f:
+        f.write(f'accuracy: {correct / total:>0.4f}\n\npredictions:\n')
+        for p in predictions:
+            f.write(f'{p}\n')
+
 
-    if not dids:
-        print('No DIDs found in environment. Aborting.')
-        return
+class SimpleCNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, (5, 5))
+        self.conv2 = nn.Conv2d(32, 64, (5, 5))
+        self.fc1 = nn.Linear(1024, 128)
+        self.fc2 = nn.Linear(128, 10)
 
-    dids = json.loads(dids)
-    for did in dids:
-        filename = f'data/inputs/{did}/0'
-        print(f'Reading asset file {filename}.')
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x)
+        return x
 
 
 def parse_arguments() -> dict:
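
For local testing of the new --local branch, the weights file only needs to be a plain state_dict that torch.load(filepath) can read, laid out under data/inputs/<did>/0 the way a compute-to-data job would deliver it. A minimal sketch of setting that up (the DID value and weights.pt filename are placeholders, and it assumes algorithm.py is importable from the working directory):

# sketch: create a SimpleCNN weights file and mimic the C2D input layout locally
import json
import os
import shutil

import torch

from algorithm import SimpleCNN  # assumes mlweb3/ocean is on the import path

model = SimpleCNN()
torch.save(model.state_dict(), 'weights.pt')  # untrained weights, enough for a smoke test

did = 'example-did'  # placeholder; a real job injects actual DIDs via the DIDS env var
os.makedirs(f'data/inputs/{did}', exist_ok=True)
shutil.copy('weights.pt', f'data/inputs/{did}/0')
os.environ['DIDS'] = json.dumps([did])  # algorithm.py reads this when --local is not set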

mlweb3/ocean/inference.py

Lines changed: 5 additions & 3 deletions
@@ -78,14 +78,16 @@ def predict():
     assert datasets, 'payment for dataset unsuccessful'
     assert algorithm, 'payment for algorithm unsuccessful'
 
+    # NOTE: not currently possible to use multiple input datasets on a single compute job
+
     # start compute job
     t0 = datetime.datetime.utcnow()
     job_id = ocean.compute.start(
         consumer_wallet=bob_wallet,
-        dataset=data_compute_input,
+        dataset=datasets[0],
         compute_environment=free_c2d_env['id'],
-        algorithm=algo_compute_input,
-        additional_datasets=[weights_compute_input]
+        algorithm=algorithm,
+        # additional_datasets=[datasets[1]]
     )
     print('Started compute job: {}'.format(job_id))
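
On the consumer side, once the job's output is downloaded, the file the algorithm writes to /data/outputs/result is plain text: an accuracy line, a blank line, a 'predictions:' header, and one predicted label per line. A small parsing sketch (the local filename is an assumption; use whatever path the download step produces):

# sketch: parse the result file produced by algorithm.py
def parse_result(path='result'):
    with open(path) as f:
        lines = [line.strip() for line in f if line.strip()]
    accuracy = float(lines[0].split(':')[1])  # 'accuracy: 0.9876'
    predictions = [int(p) for p in lines[lines.index('predictions:') + 1:]]
    return accuracy, predictions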
9193
