Skip to content

Commit b5e5c33

Browse files
committed
Change catfish_server.py (add eval(),change load_model() to base scenario, delete if __name__ ) & add/change bash scripts for excecution for basic catfish service
1 parent 88287c7 commit b5e5c33

File tree

7 files changed

+76
-24
lines changed

7 files changed

+76
-24
lines changed

chapter8/catfish/catfish_server.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,28 @@
1-
import os
21
import requests
32
import torch
43
from flask import Flask, jsonify, request
54
from io import BytesIO
65
from PIL import Image
76
from torchvision import transforms
8-
from catfish_model import CatfishModel, CatfishClasses
9-
from urllib.request import urlopen
10-
from shutil import copyfileobj
11-
from tempfile import NamedTemporaryFile
127

8+
from catfish_model import CatfishModel, CatfishClasses
139

1410
def load_model():
15-
m = CatfishModel
16-
if "CATFISH_MODEL_LOCATION" in os.environ:
17-
parameter_url = os.environ["CATFISH_MODEL_LOCATION"]
18-
print(f"downloading {parameter_url}")
19-
with urlopen(parameter_url) as fsrc, NamedTemporaryFile() as fdst:
20-
copyfileobj(fsrc, fdst)
21-
m.load_state_dict(torch.load(fdst))
22-
m.load_state_dict(torch.load(location))
23-
return m
11+
m = CatfishModel
12+
m.eval()
13+
return m
2414

2515
model = load_model()
2616

2717
img_transforms = transforms.Compose([
2818
transforms.Resize((224,224)),
2919
transforms.ToTensor(),
3020
transforms.Normalize(mean=[0.485, 0.456, 0.406],
31-
std=[0.229, 0.224, 0.225] )
21+
std=[0.229, 0.224, 0.225])
3222
])
3323

3424
def create_app():
3525
app = Flask(__name__)
36-
3726

3827
@app.route("/")
3928
def status():
@@ -53,6 +42,4 @@ def predict():
5342
predicted_class = CatfishClasses[torch.argmax(prediction)]
5443
return jsonify({"image": img_url, "prediction": predicted_class})
5544

56-
return app
57-
if __name__ == '__main__':
58-
app.run(host=os.environ["CATFISH_HOST"], port=os.environ["CATFISH_PORT"])
45+
return app

chapter8/catfish/get-prediction.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
#run-model-service.sh
3+
4+
curl http://127.0.0.1:8080/predict\?image_url\=https://upload.wikimedia.org/wikipedia/commons/thumb/3/36/A_domestic_shorthair_tortie-tabby_cat.jpg/412px-A_domestic_shorthair_tortie-tabby_cat.jpg

chapter8/catfish/run-flask-server.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env bash
2+
#run-flask-server.sh
3+
4+
FLASK_APP=catfish_server.py FLASK_RUN_PORT=8080 flask run

chapter8/catfish/run-model-service.sh

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env bash
2+
#run-waitress-server.sh
3+
4+
waitress-serve --call 'catfish_server:create_app'
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import requests
3+
import torch
4+
from flask import Flask, jsonify, request
5+
from io import BytesIO
6+
from PIL import Image
7+
from torchvision import transforms
8+
from catfish_model import CatfishModel, CatfishClasses
9+
from urllib.request import urlopen
10+
from shutil import copyfileobj
11+
from tempfile import NamedTemporaryFile
12+
13+
14+
def load_model():
15+
m = CatfishModel
16+
if "CATFISH_MODEL_LOCATION" in os.environ:
17+
parameter_url = os.environ["CATFISH_MODEL_LOCATION"]
18+
print(f"downloading {parameter_url}")
19+
with urlopen(parameter_url) as fsrc, NamedTemporaryFile() as fdst:
20+
copyfileobj(fsrc, fdst)
21+
m.load_state_dict(torch.load(fdst))
22+
m.load_state_dict(torch.load(location))
23+
return m
24+
25+
model = load_model()
26+
27+
img_transforms = transforms.Compose([
28+
transforms.Resize((224,224)),
29+
transforms.ToTensor(),
30+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
31+
std=[0.229, 0.224, 0.225] )
32+
])
33+
34+
def create_app():
35+
app = Flask(__name__)
36+
37+
38+
@app.route("/")
39+
def status():
40+
return jsonify({"status": "ok"})
41+
42+
@app.route("/predict", methods=['GET', 'POST'])
43+
def predict():
44+
if request.method == 'POST':
45+
img_url = request.form.image_url
46+
else:
47+
img_url = request.args.get('image_url', '')
48+
49+
response = requests.get(img_url)
50+
img = Image.open(BytesIO(response.content))
51+
img_tensor = img_transforms(img).unsqueeze(0)
52+
prediction = model(img_tensor)
53+
predicted_class = CatfishClasses[torch.argmax(prediction)]
54+
return jsonify({"image": img_url, "prediction": predicted_class})
55+
56+
return app
57+
if __name__ == '__main__':
58+
app.run(host=os.environ["CATFISH_HOST"], port=os.environ["CATFISH_PORT"])

0 commit comments

Comments
 (0)