Skip to content

Commit 26ff850

Browse files
committed
do inference with ocean compute job
1 parent ad39db1 commit 26ff850

File tree

4 files changed

+126
-2
lines changed

4 files changed

+126
-2
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,10 @@ Once the trained model has been uploaded to IPFS, define the `IPFS_MODEL_HASH` v
2929

3030
Deploy model to web3 infrastructure
3131
```
32-
python deploy --infra ocean
32+
python deploy.py --infra ocean
33+
```
34+
35+
Make predictions with deployed model
36+
```
37+
python predict.py --infra ocean
3338
```

deploy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
deployed trained model to web3 infra
2+
deploy trained model to web3 infra
33
"""
44

55
from mlweb3.ocean.deployment import deploy

mlweb3/ocean/inference.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""
2+
inference logic for ocean model
3+
"""
4+
import os
5+
import time
6+
import datetime
7+
8+
from dotenv import load_dotenv
9+
10+
from brownie.network import accounts
11+
from web3 import Web3
12+
from ocean_lib.web3_internal.utils import connect_to_network
13+
from ocean_lib.example_config import ExampleConfig
14+
from ocean_lib.ocean.ocean import Ocean
15+
from ocean_lib.models.compute_input import ComputeInput
16+
17+
18+
def predict():
19+
# configure client
20+
load_dotenv()
21+
connect_to_network('mumbai')
22+
config = ExampleConfig.get_config('mumbai')
23+
ocean = Ocean(config)
24+
print(config)
25+
26+
# load accounts
27+
accounts.clear()
28+
alice_private_key = os.getenv('PRIVATE_KEY_0')
29+
alice_wallet = accounts.add(alice_private_key)
30+
alice_balance = accounts.at(alice_wallet.address).balance()
31+
print('Alice balance: {}'.format(alice_balance))
32+
assert alice_balance > 0
33+
34+
bob_private_key = os.getenv('PRIVATE_KEY_1')
35+
bob_wallet = accounts.add(bob_private_key)
36+
bob_balance = accounts.at(bob_wallet.address).balance()
37+
print('Bob balance: {}'.format(bob_balance))
38+
assert bob_balance > 0
39+
40+
# resolve assets
41+
data_asset = ocean.assets.resolve('did:op:9a4c1fe5fcec7d4071b8afdbb17dcc2c05b1ecf8cf512f556ef8f38116a594c3')
42+
weights_asset = ocean.assets.resolve('did:op:733ca2bfe560f6a66cc1e5cc2b6b91799db77e5fdbd41d4e1a402cb944bcdf43')
43+
algo_asset = ocean.assets.resolve('did:op:0c3c2c9099c67a49128379e5781e48fccb6125d335464d07edcceae78e7f729c')
44+
45+
data_token = ocean.get_datatoken(data_asset.datatokens[0]['address'])
46+
weights_token = ocean.get_datatoken(weights_asset.datatokens[0]['address'])
47+
algo_token = ocean.get_datatoken(algo_asset.datatokens[0]['address'])
48+
49+
# mint tokens to user
50+
data_token.mint(bob_wallet.address, Web3.toWei(5, 'ether'), {'from': alice_wallet})
51+
weights_token.mint(bob_wallet.address, Web3.toWei(5, 'ether'), {'from': alice_wallet})
52+
algo_token.mint(bob_wallet.address, Web3.toWei(5, 'ether'), {'from': alice_wallet})
53+
54+
# setup environment and inputs
55+
data_service = data_asset.services[0]
56+
weights_service = weights_asset.services[0]
57+
algo_service = algo_asset.services[0]
58+
free_c2d_env = ocean.compute.get_free_c2d_environment(data_service.service_endpoint)
59+
60+
data_compute_input = ComputeInput(data_asset, data_service)
61+
weights_compute_input = ComputeInput(weights_asset, weights_service)
62+
algo_compute_input = ComputeInput(algo_asset, algo_service)
63+
64+
# data_compute_input.transfer_tx_id = '0x8f9eb8f28402acf3d136edc063d2909059da2d418390eff52977f44548895197'
65+
# weights_compute_input.transfer_tx_id = '0x62c9b4ac3d0686d4354847f2f0be8a67b96b8e97a6582cfc675e12626279fc0d'
66+
# algo_compute_input.transfer_tx_id = '0x8b89f8dbcc0b2d9790c25cd403e264f9ccd876513f2926ae56ed1090fba26623'
67+
68+
# pay for dataset, weights, and algo for 1 day
69+
datasets, algorithm = ocean.assets.pay_for_compute_service(
70+
datasets=[data_compute_input, weights_compute_input],
71+
algorithm_data=algo_compute_input,
72+
consume_market_order_fee_address=bob_wallet.address,
73+
wallet=bob_wallet,
74+
compute_environment=free_c2d_env['id'],
75+
valid_until=int((datetime.datetime.utcnow() + datetime.timedelta(days=1)).timestamp()),
76+
consumer_address=free_c2d_env['consumerAddress'],
77+
)
78+
assert datasets, 'payment for dataset unsuccessful'
79+
assert algorithm, 'payment for algorithm unsuccessful'
80+
81+
# start compute job
82+
t0 = datetime.datetime.utcnow()
83+
job_id = ocean.compute.start(
84+
consumer_wallet=bob_wallet,
85+
dataset=data_compute_input,
86+
compute_environment=free_c2d_env['id'],
87+
algorithm=algo_compute_input,
88+
additional_datasets=[weights_compute_input]
89+
)
90+
print('Started compute job: {}'.format(job_id))
91+
92+
# monitor job
93+
while 1:
94+
status = ocean.compute.status(data_asset, data_service, job_id, bob_wallet)
95+
t1 = datetime.datetime.utcnow()
96+
print('status: {}, time elapsed: {} seconds'.format(status['statusText'], t1 - t0))
97+
if status.get('dateFinished'):
98+
print(status)
99+
break
100+
time.sleep(10)
101+
102+
# Retrieve algorithm output and log files
103+
output = ocean.compute.compute_job_result_logs(
104+
data_asset, data_service, job_id, bob_wallet
105+
)
106+
print(output)

predict.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
make predictions with deployed model on web3 infra
3+
"""
4+
5+
from mlweb3.ocean.inference import predict
6+
7+
8+
def main():
9+
predict()
10+
11+
12+
if __name__ == '__main__':
13+
main()

0 commit comments

Comments
 (0)