Commit 7ff248d

job script and image for bacalhau
1 parent d9733a1 commit 7ff248d

9 files changed: 120 additions and 2 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ Copy `.env.template` to `.env` and define variables
 
 Update brownie configuration as needed for RPC access
 
-For Golem support, also install the [yagna](https://docs.golem.network/docs/creators/python/examples/tools/yagna-installation-for-requestors) CLI and service
+For Golem support, also install the [yagna](https://docs.golem.network/docs/quickstarts/python-quickstart) CLI and service
 
 
 ## Run
@@ -40,7 +40,7 @@ Once the trained model has been uploaded to IPFS, define the `IPFS_MODEL_HASH` v
 
 Activate appropriate virtual environment (e.g. for ocean)
 ```
-python activate mlweb3-ocean
+conda activate mlweb3-ocean
 ```
 
 Deploy model to web3 infrastructure

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -25,5 +25,6 @@ dependencies:
   - zlib
   - zstd
   - pip:
+    - numpy<2
     - python-dotenv
 prefix: /home/devin/miniconda3/envs/mlweb3

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+name: mlweb3-bacalhau
+channels:
+  - pytorch
+  - conda-forge
+  - defaults
+dependencies:
+  - bzip2
+  - ca-certificates
+  - pip
+  - python=3.10
+  - pytorch=1.12.1
+  - pytorch-mutex
+  - readline
+  - requests
+  - six
+  - sqlite
+  - tk
+  - torchvision=0.13.1
+  - typing_extensions
+  - tzdata
+  - urllib3
+  - wheel
+  - xz
+  - zlib
+  - zstd
+  - pip:
+    - python-dotenv
+    # ---- bacalhau ----
+    - bacalhau-sdk
+prefix: /home/devin/miniconda3/envs/mlweb3-bacalhau

etc/requirements/environment-golem.yml

Lines changed: 1 addition & 0 deletions
@@ -29,5 +29,6 @@ dependencies:
     # ---- golem ----
     - flask~=3.0.1
     - gvmkit-build
+    - numpy<2
     - yapapi~=0.12.0
 prefix: /home/devin/miniconda3/envs/mlweb3-golem
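
Both `environment.yml` and `etc/requirements/environment-golem.yml` gain a `numpy<2` pin. The commit does not say why; presumably it keeps the environments on the NumPy 1.x series that packages built before NumPy 2.0 expect. A minimal sketch of a runtime guard for that pin follows. `satisfies_numpy_pin` is a hypothetical helper, not part of the repo, and it only inspects the major version.

```python
def satisfies_numpy_pin(version: str) -> bool:
    """Return True if a version string falls under the `numpy<2` pin.

    Hypothetical helper: it compares the major version only, which is
    enough for an upper bound of 2.
    """
    major = int(version.split(".")[0])
    return major < 2


if __name__ == "__main__":
    try:
        import numpy
    except ImportError:
        print("numpy is not installed")
    else:
        ok = satisfies_numpy_pin(numpy.__version__)
        status = "ok" if ok else "violates numpy<2"
        print(f"numpy {numpy.__version__}: {status}")
```

Running this inside the activated conda environment gives a quick confirmation that the solver honored the pin.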

mlweb3/bacalhau/Dockerfile

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# dockerfile for mnist classification job on bacalhau
+FROM nvidia/cuda:11.3.1-cudnn8-runtime
+
+# patch: https://github.com/NVIDIA/nvidia-container-toolkit/issues/258
+RUN rm /etc/apt/sources.list.d/cuda.list
+RUN rm /etc/apt/sources.list.d/nvidia-ml.list
+
+# install python
+RUN apt-get update && \
+    apt-get install -y python3 python3-pip && \
+    rm -rf /var/lib/apt/lists/*
+RUN python3 --version
+
+# install packages
+RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu113
+
+# app
+WORKDIR /mlweb3
+COPY mlweb3/bacalhau/job.py job.py
+COPY mlweb3/model.py mlweb3/model.py
+COPY etc/models/cnn_mnist.pth etc/models/cnn_mnist.pth
+
+CMD python3 job.py
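
The `COPY` sources in this Dockerfile (`mlweb3/bacalhau/job.py`, `mlweb3/model.py`, `etc/models/cnn_mnist.pth`) are written relative to the repository root, so the build context has to be the root rather than `mlweb3/bacalhau/`. A minimal sketch of assembling the matching `docker build` invocation follows. `docker_build_command` and the image tag are hypothetical; only the `-f` and `-t` options of `docker build` are assumed.

```python
import shlex
from pathlib import PurePosixPath


def docker_build_command(tag: str, repo_root: str = ".") -> list[str]:
    """Build the argv for `docker build` with the repo root as context.

    Hypothetical helper: the Dockerfile path comes from the commit,
    the tag is whatever the caller chooses.
    """
    dockerfile = PurePosixPath(repo_root) / "mlweb3" / "bacalhau" / "Dockerfile"
    return ["docker", "build", "-f", str(dockerfile), "-t", tag, repo_root]


# Example tag is an assumption, not taken from the commit.
command = docker_build_command("mlweb3-bacalhau:latest")
print(shlex.join(command))
```

The resulting list can be passed to `subprocess.run(command, check=True)` from the repository root when Docker is available.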

mlweb3/bacalhau/__init__.py

Whitespace-only changes.

mlweb3/bacalhau/deployment.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+"""
+deployment logic for bacalhau
+"""
+
+
+def deploy():
+    print('no bacalhau deployment needed.')

mlweb3/bacalhau/inference.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+"""
+inference logic for bacalhau
+"""
+
+
+def predict():
+    raise NotImplementedError('bacalhau predict not implemented.')
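
For a stub such as `predict()`, the Python convention is to raise `NotImplementedError`. `NotImplemented` is a different built-in: a sentinel value that binary special methods such as `__eq__` return. It is not an exception class, and calling it fails with a `TypeError`. A short illustration, independent of the repo (both function names are hypothetical):

```python
def predict_stub():
    # Conventional way to mark an unimplemented function.
    raise NotImplementedError('bacalhau predict not implemented.')


def call_sentinel():
    # NotImplemented is a singleton constant, not a callable class,
    # so this line raises TypeError before any `raise` could use it.
    return NotImplemented('bacalhau predict not implemented.')


try:
    predict_stub()
except NotImplementedError as exc:
    stub_error = exc

try:
    call_sentinel()
except TypeError as exc:
    sentinel_error = exc

print(type(stub_error).__name__)      # NotImplementedError
print(type(sentinel_error).__name__)  # TypeError
```

Callers that catch `NotImplementedError` to detect missing backends only see the first form.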

mlweb3/bacalhau/job.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+"""
+inference logic for bacalhau job
+"""
+
+# lib
+import os
+import torch
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision.transforms import ToTensor, Compose
+
+# src
+from mlweb3.model import SimpleCNN
+
+
+def main():
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    print(f'using device {device}')
+
+    # load model
+    print(f'loading SimpleCNN...')
+    model = SimpleCNN()
+    model.load_state_dict(torch.load('./etc/models/cnn_mnist.pth', map_location=torch.device(device)))
+    model.eval()
+    model = model.to(device)
+
+    # get data
+    os.makedirs('./etc/mnist', exist_ok=True)
+    data = DataLoader(
+        datasets.MNIST('./etc/mnist', train=False, download=True, transform=ToTensor()),
+        batch_size=64
+    )
+
+    # do inference
+    correct, total = 0, 0
+    predictions = []
+    with torch.no_grad():
+        for X, y in data:
+            X, y = X.to(device), y.to(device)
+            pred = model(X)
+            correct += (pred.argmax(1) == y).sum().item()
+            total += len(X)
+            predictions.extend(pred.argmax(1).cpu().numpy().tolist())
+
+    print(f'test:\n correct: {correct}\n total: {total}\n accuracy: {correct / total:>0.4f}')
+
+
+if __name__ == '__main__':
+    main()
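
The accuracy bookkeeping in `main()` can be illustrated without PyTorch. The sketch below mirrors the `correct` / `total` accumulation across batches using plain lists. Each batch holds already-computed class scores rather than images, the scores and labels are made-up values, and `argmax` and `batch_accuracy` are hypothetical helpers, not code from the commit.

```python
def argmax(scores):
    """Index of the largest score, like `pred.argmax(1)` for one row."""
    return max(range(len(scores)), key=scores.__getitem__)


def batch_accuracy(batches):
    """Accumulate correct/total counts and predictions over batches."""
    correct, total = 0, 0
    predictions = []
    for scores, labels in batches:
        pred = [argmax(row) for row in scores]
        correct += sum(p == t for p, t in zip(pred, labels))
        total += len(scores)
        predictions.extend(pred)
    return correct, total, predictions


# Two toy batches of 3-class scores with their true labels.
batches = [
    ([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]], [1, 0]),
    ([[0.2, 0.3, 0.5]], [1]),
]
correct, total, predictions = batch_accuracy(batches)
print(f'correct: {correct}, total: {total}, accuracy: {correct / total:>0.4f}')
# correct: 2, total: 3, accuracy: 0.6667
```

Counting per batch and dividing once at the end, as the job script does, stays correct when the final batch is smaller than `batch_size`.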
