Skip to content

Commit 77030c1

Browse files
committed
transfer learning example
1 parent bd8fda7 commit 77030c1

File tree

4 files changed

+1441
-0
lines changed

4 files changed

+1441
-0
lines changed
7.48 KB
Loading
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from subprocess import call
2+
import os
3+
from urllib.request import urlretrieve
4+
from keras.preprocessing.image import load_img, img_to_array
5+
import numpy as np
6+
7+
dataset_url = "http://aisdatasets.informatik.uni-freiburg.de/" \
8+
"freiburg_groceries_dataset/freiburg_groceries_dataset.tar.gz"
9+
10+
def download():
11+
print("Downloading dataset.")
12+
urlretrieve(dataset_url, "freiburg_groceries_dataset.tar.gz")
13+
print("Extracting dataset.")
14+
call(["tar", "-xf", "freiburg_groceries_dataset.tar.gz", "-C", "."])
15+
os.remove("freiburg_groceries_dataset.tar.gz")
16+
print("Done.")
17+
18+
19+
def load_data():
20+
if (not os.path.exists("images")):
21+
download()
22+
23+
x_train = []
24+
y_train = []
25+
x_test = []
26+
y_test = []
27+
class_names = []
28+
category_num = 0
29+
30+
for category in sorted(os.listdir("images")):
31+
class_names.append(category)
32+
count = 0
33+
for img in sorted(os.listdir(os.path.join("images", category))):
34+
if (not img.endswith(".png")):
35+
continue
36+
37+
x = load_img(os.path.join("images", category, img),target_size=(224, 224))
38+
if count < 10:
39+
x_test.append(img_to_array(x))
40+
y_test.append(category_num)
41+
else:
42+
x_train.append(img_to_array(x))
43+
y_train.append(category_num)
44+
count += 1
45+
category_num += 1
46+
return (np.array(x_train), np.array(y_train)), (np.array(x_test), np.array(y_test)), class_names
47+

examples/keras-transfer-learning/transfer-demo.ipynb

Lines changed: 1389 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[default]
2+
entity = bloomberg-class
3+
project = groceries
4+
base_url = https://api.wandb.ai
5+

0 commit comments

Comments
 (0)