Skip to content

Commit a0ab29b

Browse files
committed
update
1 parent 387ccc5 commit a0ab29b

File tree

9 files changed

+37
-49
lines changed

9 files changed

+37
-49
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ astro
77
z_experiments
88
include/duckdb_database
99
include/pretained_models/
10-
db_script.py
10+
db_script.py
11+
include/minio

dags/deploy_best_model.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,14 @@
55
"""
66

77
from airflow import Dataset as AirflowDataset
8-
from airflow.decorators import dag, task_group, task
9-
from astro import sql as aql
10-
from astro.sql import get_value_list
11-
from astro.files import get_file_list
12-
from astro.sql.table import Table
8+
from airflow.decorators import dag, task
139
from airflow.operators.empty import EmptyOperator
14-
from airflow.operators.bash import BashOperator
15-
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
1610
from airflow.models import Variable
17-
18-
from collections import Counter
19-
import pandas as pd
2011
from pendulum import datetime
21-
import os
2212
import logging
23-
import requests
24-
import numpy as np
25-
from PIL import Image
2613
import duckdb
27-
import json
28-
import pickle
29-
import shutil
30-
import torch
3114
from airflow.sensors.base import PokeReturnValue
3215

33-
from include.custom_operators.hugging_face import (
34-
TestHuggingFaceImageClassifierOperator,
35-
transform_function,
36-
)
3716

3817
task_logger = logging.getLogger("airflow.task")
3918

@@ -79,13 +58,22 @@ def pick_best_model_from_db(db_path):
7958
FROM model_results
8059
WHERE test_set_num = (SELECT MAX(test_set_num) FROM model_results)
8160
ORDER BY test_av_loss ASC
82-
LIMIT 1;"""
83-
).fetchall()[0][0]
61+
LIMIT 1;""" # want higher false negative - be more sensitive recall!!. maybe optimized for f score (long discussion on medium), want to use precision or recall
62+
).fetchall()[0][0] #ROC area
8463
con.close()
8564

8665
return best_model_latest_test_set
8766

88-
ensure_baseline_ran() >> pick_best_model_from_db(db_path=DUCKDB_PATH)
67+
@task
68+
def deploy_model(model):
69+
pass
70+
71+
(
72+
start
73+
>> ensure_baseline_ran()
74+
>> deploy_model(pick_best_model_from_db(db_path=DUCKDB_PATH))
75+
>> end
76+
)
8977

9078

9179
deploy_best_model()

dags/in_new_test_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
@dag(
3030
start_date=datetime(2023, 1, 1),
31-
schedule=None,
31+
schedule="@continuous",
32+
max_active_runs=1,
3233
catchup=False,
3334
)
3435
def in_new_test_data():

dags/in_new_train_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
@dag(
3030
start_date=datetime(2023, 1, 1),
31-
schedule=None,
31+
schedule="@continuous",
32+
max_active_runs=1,
3233
catchup=False,
3334
)
3435
def in_new_train_data():

dags/preprocess_test_data.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,15 @@
55
"""
66

77
from airflow import Dataset
8-
from airflow.decorators import dag, task_group, task
9-
from astro import sql as aql
10-
from astro.sql import get_value_list
8+
from airflow.decorators import dag, task
119
from astro.files import get_file_list
12-
from astro.sql.table import Table
1310
from airflow.operators.empty import EmptyOperator
1411
from airflow.operators.bash import BashOperator
1512
from airflow.models import Variable
1613

17-
import pandas as pd
1814
from pendulum import datetime
19-
import os
2015
import logging
21-
import requests
22-
import numpy as np
23-
from PIL import Image
2416
import duckdb
25-
import json
26-
import pickle
2717

2818
task_logger = logging.getLogger("airflow.task")
2919

dags/preprocess_train_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from airflow import Dataset
8-
from airflow.decorators import dag, task_group, task
8+
from airflow.decorators import dag, task
99
from astro import sql as aql
1010
from astro.sql import get_value_list
1111
from astro.files import get_file_list

dags/test_fine_tuned_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pickle
2121
import shutil
2222
import torch
23+
from airflow.models import Variable
2324

2425
from include.custom_operators.hugging_face import (
2526
TestHuggingFaceImageClassifierOperator,
@@ -161,15 +162,16 @@ def write_model_results_to_duckdb(db_path, table_name, **context):
161162
model_name = context["ti"].xcom_pull(task_ids="test_classifier")["model_name"]
162163

163164
con = duckdb.connect(db_path)
165+
test_set_num = Variable.get("test_set_num")
164166

165167
con.execute(
166168
f"""CREATE TABLE IF NOT EXISTS {table_name}
167-
(model_name TEXT PRIMARY KEY, timestamp DATETIME, test_av_loss FLOAT, test_accuracy FLOAT)"""
169+
(model_name TEXT PRIMARY KEY, timestamp DATETIME, test_av_loss FLOAT, test_accuracy FLOAT, test_set_num INT)"""
168170
)
169171

170172
con.execute(
171-
f"INSERT OR REPLACE INTO {table_name} (model_name, timestamp, test_av_loss, test_accuracy) VALUES (?, ?, ?, ?) ",
172-
(model_name, timestamp, test_av_loss, test_accuracy),
173+
f"INSERT OR REPLACE INTO {table_name} (model_name, timestamp, test_av_loss, test_accuracy, test_set_num) VALUES (?, ?, ?, ?, ?) ",
174+
(model_name, timestamp, test_av_loss, test_accuracy, test_set_num),
173175
)
174176

175177
con.close()

dags/train_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def load_training_images(keys):
102102

103103
train_classifier = TrainHuggingFaceImageClassifierOperator(
104104
task_id="train_classifier",
105-
model_name="microsoft/resnet-50",
106-
criterion=torch.nn.CrossEntropyLoss(),
105+
model_name="microsoft/resnet-50", # find newer one?
106+
criterion=torch.nn.CrossEntropyLoss(), # binary entropy loss!
107107
optimizer=torch.optim.Adam,
108108
local_images_filepaths=local_images_filepaths,
109109
labels=get_labels_from_duckdb.map(lambda x: x[0]),

include/custom_operators/hugging_face.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class TrainHuggingFaceImageClassifierOperator(BaseOperator):
100100
101101
"""
102102

103+
ui_color = "#91ed9d"
104+
103105
template_fields = (
104106
"model_name",
105107
"criterion",
@@ -131,7 +133,7 @@ def __init__(
131133
super().__init__(*args, **kwargs)
132134
self.model_name = model_name
133135
self.criterion = criterion
134-
self.optimizer = optimizer
136+
self.optimizer = optimizer #change optimizer
135137
self.local_images_filepaths = local_images_filepaths
136138
self.labels = labels
137139
self.num_classes = num_classes
@@ -154,16 +156,17 @@ def execute(self, context):
154156
num_workers=0,
155157
)
156158

157-
model = ResNetForImageClassification.from_pretrained(self.model_name)
159+
# figure out how fine tuning happens inside hugging face
160+
model = ResNetForImageClassification.from_pretrained(self.model_name) # does this do something optimized for fine tune or not?
158161
model.classifier[-1] = torch.nn.Linear(
159162
model.classifier[-1].in_features, self.num_classes
160163
)
161164

162165
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
163166
model = model.to(device)
164-
optimizer = self.optimizer(model.parameters(), lr=1e-4)
167+
optimizer = self.optimizer(model.parameters(), lr=1e-4) ### lr is a hyperparameter learning rate to be adjusted. add as a parameter
165168

166-
for epoch in range(self.num_epochs):
169+
for epoch in range(self.num_epochs): #num of epochs
167170
model.train()
168171
running_loss = 0.0
169172
running_corrects = 0
@@ -206,6 +209,8 @@ class TestHuggingFaceImageClassifierOperator(BaseOperator):
206209
207210
"""
208211

212+
ui_color = "#ebab34"
213+
209214
template_fields = (
210215
"model_name",
211216
"criterion",

0 commit comments

Comments
 (0)