Skip to content

Commit c5b8102

Browse files
authored
Merge pull request lukas#52 from charlesfrye/cf-by-hand
Adds interactive nbs for (over-)fitting by hand
2 parents 95cdbbc + ecf6598 commit c5b8102

File tree

7 files changed

+440
-0
lines changed

7 files changed

+440
-0
lines changed

examples/by-hand/data/xs.npy

2.47 KB
Binary file not shown.

examples/by-hand/data/ys.npy

2.47 KB
Binary file not shown.

examples/by-hand/fitting.ipynb

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Fitting a Model by Hand"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"import matplotlib.pyplot as plt\n",
17+
"import numpy as np\n",
18+
"\n",
19+
"from utils.models import Parameters, LinearModel\n",
20+
"\n",
21+
"plt.rcParams.update({'font.size': 18})"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"%matplotlib notebook"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"xs, ys = np.load(\"data/xs.npy\")[:30], np.load(\"data/ys.npy\")[:30]"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"parameters = Parameters([0.5, 0], [[-1, 1], [-1, 1]], [\"bias\", \"weight\"])\n",
49+
"lm = LinearModel(input_values=np.linspace(0, 1), parameters=parameters)"
50+
]
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": null,
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"lm.plot()"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"lm.set_data(xs, ys)\n",
68+
"\n",
69+
"lm.show_MSE = True\n",
70+
"lm.make_interactive(log=True)"
71+
]
72+
}
73+
],
74+
"metadata": {
75+
"kernelspec": {
76+
"display_name": "Python 3",
77+
"language": "python",
78+
"name": "python3"
79+
},
80+
"language_info": {
81+
"codemirror_mode": {
82+
"name": "ipython",
83+
"version": 3
84+
},
85+
"file_extension": ".py",
86+
"mimetype": "text/x-python",
87+
"name": "python",
88+
"nbconvert_exporter": "python",
89+
"pygments_lexer": "ipython3",
90+
"version": "3.7.3"
91+
}
92+
},
93+
"nbformat": 4,
94+
"nbformat_minor": 4
95+
}

examples/by-hand/overfitting.ipynb

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# (Over-)Fitting a Model by Hand"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"import matplotlib.pyplot as plt\n",
17+
"import numpy as np\n",
18+
"\n",
19+
"from utils.landmarks import LandmarksModel, setup_plot\n",
20+
"\n",
21+
"plt.rcParams.update({'font.size': 18})"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"%matplotlib notebook"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"xs, ys = np.load(\"data/xs.npy\"), np.load(\"data/ys.npy\")"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"line = setup_plot()\n",
49+
"landmarks_model = LandmarksModel(line, xs, ys, log=True)\n",
50+
"\n",
51+
"plt.show()"
52+
]
53+
}
54+
],
55+
"metadata": {
56+
"kernelspec": {
57+
"display_name": "Python 3",
58+
"language": "python",
59+
"name": "python3"
60+
},
61+
"language_info": {
62+
"codemirror_mode": {
63+
"name": "ipython",
64+
"version": 3
65+
},
66+
"file_extension": ".py",
67+
"mimetype": "text/x-python",
68+
"name": "python",
69+
"nbconvert_exporter": "python",
70+
"pygments_lexer": "ipython3",
71+
"version": "3.7.3"
72+
}
73+
},
74+
"nbformat": 4,
75+
"nbformat_minor": 4
76+
}

examples/by-hand/utils/landmarks.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from matplotlib.lines import Line2D
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
import wandb
5+
6+
7+
class LocalLandmarks(object):
8+
9+
def __init__(self, points, num=200):
10+
self.points = np.array(points)
11+
self.grid = np.linspace(0, 1, num=num)
12+
self.pred_fun = np.array([self.grid, self.predict(self.grid)]).T
13+
14+
def predict(self, xs):
15+
idxs_closest_x = []
16+
for x in xs:
17+
idxs_closest_x.append(np.argmin(np.abs(self.points[:, 0] - x)))
18+
return self.points[idxs_closest_x, 1]
19+
20+
21+
class LandmarksModel(object):
22+
"""Interactively build a landmark-based model.
23+
"""
24+
def __init__(self, landmarks_polygon, xs, ys, train_size=15, log=True):
25+
self.landmarks_polygon = landmarks_polygon
26+
self.xp = list(landmarks_polygon.get_xdata())
27+
self.yp = list(landmarks_polygon.get_ydata())
28+
29+
self.canvas = landmarks_polygon.figure.canvas
30+
self.ax_main = landmarks_polygon.axes
31+
32+
self.observed_xs = xs[:train_size]
33+
self.observed_ys = ys[:train_size]
34+
35+
self.make_scatter(self.observed_xs, self.observed_ys, self.ax_main)
36+
self.test_xs = xs[train_size:]
37+
self.test_ys = ys[train_size:]
38+
39+
self.log = log
40+
41+
self.cid = self.canvas.mpl_connect('button_press_event', self)
42+
43+
prediction_line = Line2D([], [],
44+
c=np.divide([255, 204, 51], 256), lw=4)
45+
self.prediction_line_plot = self.ax_main.add_line(prediction_line)
46+
47+
if self.log:
48+
wandb.init()
49+
50+
def __call__(self, event):
51+
# Ignore clicks outside axes
52+
if event.inaxes != self.landmarks_polygon.axes:
53+
return
54+
55+
# Add point
56+
self.xp.append(event.xdata)
57+
self.yp.append(event.ydata)
58+
59+
self.landmarks_polygon.set_data(self.xp, self.yp)
60+
61+
# Rebuild prediction curve and update canvas
62+
self.prediction_line_plot.set_data(*self._rebuild_predictor())
63+
self._update()
64+
65+
def _update(self):
66+
self.canvas.draw()
67+
68+
def _rebuild_predictor(self):
69+
self.local_landmarks = LocalLandmarks(list(zip(self.xp, self.yp)))
70+
71+
train_MSE = self.compute_MSE(self.observed_xs, self.observed_ys)
72+
73+
if self.log:
74+
test_MSE = self.compute_MSE(self.test_xs, self.test_ys)
75+
76+
wandb.log({"train_loss": train_MSE,
77+
"test_loss": test_MSE})
78+
79+
x, y = self.local_landmarks.pred_fun.T
80+
81+
return x, y
82+
83+
def compute_MSE(self, xs, ys):
84+
predictions = self.local_landmarks.predict(xs)
85+
MSE = np.mean(np.square(ys - predictions))
86+
return MSE
87+
88+
def make_scatter(self, xs, ys, ax):
89+
ax.scatter(xs, ys, color='k', alpha=0.5, s=72)
90+
91+
92+
def setup_plot():
93+
fig, ax1 = plt.subplots(figsize=(10, 10))
94+
95+
line = Line2D([], [], ls='none', c='#616666',
96+
marker='x', mew=4, mec='k', ms=10, zorder=3)
97+
ax1.add_line(line)
98+
99+
ax1.set_xlim(0, 1)
100+
ax1.set_ylim(0, 1)
101+
ax1.axis("off")
102+
103+
return line

0 commit comments

Comments
 (0)