Skip to content

Commit a3ee0fe

Browse files
Merge pull request falloutdurham#47 from MarcusFra/ch08_catfish_new
Ch8 Change catfish_server.py & add/change bash scripts
2 parents 681ba2f + c9866de commit a3ee0fe

17 files changed

+188
-52
lines changed

chapter8/catfish/catfish_server.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,28 @@
1-
import os
21
import requests
32
import torch
43
from flask import Flask, jsonify, request
54
from io import BytesIO
65
from PIL import Image
76
from torchvision import transforms
8-
from catfish_model import CatfishModel, CatfishClasses
9-
from urllib.request import urlopen
10-
from shutil import copyfileobj
11-
from tempfile import NamedTemporaryFile
127

8+
from catfish_model import CatfishModel, CatfishClasses
139

1410
def load_model():
15-
m = CatfishModel
16-
if "CATFISH_MODEL_LOCATION" in os.environ:
17-
parameter_url = os.environ["CATFISH_MODEL_LOCATION"]
18-
print(f"downloading {parameter_url}")
19-
with urlopen(parameter_url) as fsrc, NamedTemporaryFile() as fdst:
20-
copyfileobj(fsrc, fdst)
21-
m.load_state_dict(torch.load(fdst))
22-
m.load_state_dict(torch.load(location))
23-
return m
11+
m = CatfishModel
12+
m.eval()
13+
return m
2414

2515
model = load_model()
2616

2717
img_transforms = transforms.Compose([
2818
transforms.Resize((224,224)),
2919
transforms.ToTensor(),
3020
transforms.Normalize(mean=[0.485, 0.456, 0.406],
31-
std=[0.229, 0.224, 0.225] )
21+
std=[0.229, 0.224, 0.225])
3222
])
3323

3424
def create_app():
3525
app = Flask(__name__)
36-
3726

3827
@app.route("/")
3928
def status():
@@ -53,6 +42,4 @@ def predict():
5342
predicted_class = CatfishClasses[torch.argmax(prediction)]
5443
return jsonify({"image": img_url, "prediction": predicted_class})
5544

56-
return app
57-
if __name__ == '__main__':
58-
app.run(host=os.environ["CATFISH_HOST"], port=os.environ["CATFISH_PORT"])
45+
return app

chapter8/catfish/get-prediction.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
#run-model-service.sh
3+
4+
curl http://127.0.0.1:8080/predict\?image_url\=https://upload.wikimedia.org/wikipedia/commons/thumb/3/36/A_domestic_shorthair_tortie-tabby_cat.jpg/412px-A_domestic_shorthair_tortie-tabby_cat.jpg

chapter8/catfish/run-flask-server.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env bash
2+
#run-flask-server.sh
3+
4+
FLASK_APP=catfish_server.py FLASK_RUN_PORT=8080 flask run
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env bash
2+
#run-waitress-server.sh
3+
4+
waitress-serve --call 'catfish_server:create_app'
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch.nn as nn
2+
from torchvision import models
3+
4+
CatfishClasses = ["cat","fish"]
5+
6+
CatfishModel = models.resnet50()
7+
CatfishModel.fc = nn.Sequential(nn.Linear(CatfishModel.fc.in_features,500),
8+
nn.ReLU(),
9+
nn.Dropout(), nn.Linear(500,2))
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
import requests
3+
import torch
4+
from flask import Flask, jsonify, request
5+
from io import BytesIO
6+
from PIL import Image
7+
from shutil import copyfileobj
8+
from tempfile import NamedTemporaryFile
9+
from torchvision import transforms
10+
from urllib.request import urlopen
11+
12+
from catfish_model import CatfishModel, CatfishClasses
13+
14+
15+
def load_model():
16+
m = CatfishModel
17+
if "CATFISH_MODEL_LOCATION" in os.environ:
18+
parameter_url = os.environ["CATFISH_MODEL_LOCATION"]
19+
print(f"downloading {parameter_url}")
20+
with urlopen(parameter_url) as fsrc, NamedTemporaryFile() as fdst:
21+
copyfileobj(fsrc, fdst)
22+
m.load_state_dict(torch.load(fdst, map_location="cpu"))
23+
m.eval()
24+
return m
25+
26+
27+
model = load_model()
28+
29+
img_transforms = transforms.Compose([
30+
transforms.Resize((224,224)),
31+
transforms.ToTensor(),
32+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
33+
std=[0.229, 0.224, 0.225])
34+
])
35+
36+
def create_app():
37+
app = Flask(__name__)
38+
39+
@app.route("/")
40+
def status():
41+
return jsonify({"status": "ok"})
42+
43+
@app.route("/predict", methods=['GET', 'POST'])
44+
def predict():
45+
if request.method == 'POST':
46+
img_url = request.form.image_url
47+
else:
48+
img_url = request.args.get('image_url', '')
49+
50+
response = requests.get(img_url)
51+
img = Image.open(BytesIO(response.content))
52+
img_tensor = img_transforms(img).unsqueeze(0)
53+
prediction = model(img_tensor)
54+
predicted_class = CatfishClasses[torch.argmax(prediction)]
55+
return jsonify({"image": img_url, "prediction": predicted_class})
56+
57+
return app
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
#get-prediction.sh
3+
4+
curl http://127.0.0.1:5000/predict?image_url=https://upload.wikimedia.org/wikipedia/commons/thumb/3/36/A_domestic_shorthair_tortie-tabby_cat.jpg/412px-A_domestic_shorthair_tortie-tabby_cat.jpg
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/bin/bash
2+
#run-docker.sh
3+
4+
docker build -t catfish-service .
5+
docker run -d -p 5000:5000 --env CATFISH_MODEL_LOCATION=[URL] catfish-service:latest

chapter8/catfish/run-model-service.sh renamed to chapter8/catfish_docker_cloud/run-model-service.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
#run-model-service.sh
33

44
cd /app
5-
waitress-serve --port ${CATFISH_PORT} --call 'catfish_server:create_app'
5+
waitress-serve --port ${CATFISH_PORT} --call 'catfish_server:create_app'
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
FROM continuumio/miniconda3:latest
2+
3+
ARG model_parameter_file=catfishweights.pth
4+
ARG port=5000
5+
6+
ENV CATFISH_PORT=$port
7+
ENV CATFISH_MODEL_LOCATION=/app/$model_parameter_file
8+
9+
RUN conda install -y flask \
10+
&& conda install -c pytorch torchvision \
11+
&& conda install waitress
12+
RUN mkdir -p /app
13+
14+
COPY ./catfish_model.py /app
15+
COPY ./catfish_server.py /app
16+
COPY ./$model_parameter_file /app/
17+
COPY ./run-model-service.sh /
18+
19+
EXPOSE $port
20+
21+
ENTRYPOINT ["/run-model-service.sh"]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch.nn as nn
2+
from torchvision import models
3+
4+
CatfishClasses = ["cat","fish"]
5+
6+
CatfishModel = models.resnet50()
7+
CatfishModel.fc = nn.Sequential(nn.Linear(CatfishModel.fc.in_features,500),
8+
nn.ReLU(),
9+
nn.Dropout(), nn.Linear(500,2))
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import requests
3+
import torch
4+
from flask import Flask, jsonify, request
5+
from io import BytesIO
6+
from PIL import Image
7+
from torchvision import transforms
8+
9+
from catfish_model import CatfishModel, CatfishClasses
10+
11+
12+
def load_model():
13+
location = os.environ["CATFISH_MODEL_LOCATION"]
14+
m = CatfishModel
15+
m.load_state_dict(torch.load(location, map_location="cpu"))
16+
m.eval()
17+
return m
18+
19+
20+
model = load_model()
21+
22+
img_transforms = transforms.Compose([
23+
transforms.Resize((224,224)),
24+
transforms.ToTensor(),
25+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
26+
std=[0.229, 0.224, 0.225] )
27+
])
28+
29+
def create_app():
30+
app = Flask(__name__)
31+
32+
@app.route("/")
33+
def status():
34+
return jsonify({"status": "ok"})
35+
36+
@app.route("/predict", methods=['GET', 'POST'])
37+
def predict():
38+
if request.method == 'POST':
39+
img_url = request.form.image_url
40+
else:
41+
img_url = request.args.get('image_url', '')
42+
43+
response = requests.get(img_url)
44+
img = Image.open(BytesIO(response.content))
45+
img_tensor = img_transforms(img).unsqueeze(0)
46+
prediction = model(img_tensor)
47+
predicted_class = CatfishClasses[torch.argmax(prediction)]
48+
return jsonify({"image": img_url, "prediction": predicted_class})
49+
50+
return app
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
#get-prediction.sh
3+
4+
curl http://127.0.0.1:5000/predict?image_url=https://upload.wikimedia.org/wikipedia/commons/thumb/3/36/A_domestic_shorthair_tortie-tabby_cat.jpg/412px-A_domestic_shorthair_tortie-tabby_cat.jpg
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/bin/bash
2+
#run-docker.sh
3+
4+
docker build -t catfish-service .
5+
docker run -d -p 5000:5000 catfish-service:latest
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/bin/bash
2+
#run-model-service.sh
3+
4+
cd /app
5+
waitress-serve --port ${CATFISH_PORT} --call 'catfish_server:create_app'

chapter8/server.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)