|
20 | 20 | import pickle |
21 | 21 | import shutil |
22 | 22 | import torch |
| 23 | +from airflow.models import Variable |
23 | 24 |
|
24 | 25 | from include.custom_operators.hugging_face import ( |
25 | 26 | TestHuggingFaceImageClassifierOperator, |
@@ -161,15 +162,16 @@ def write_model_results_to_duckdb(db_path, table_name, **context): |
161 | 162 | model_name = context["ti"].xcom_pull(task_ids="test_classifier")["model_name"] |
162 | 163 |
|
163 | 164 | con = duckdb.connect(db_path) |
| 165 | + test_set_num = Variable.get("test_set_num") |
164 | 166 |
|
165 | 167 | con.execute( |
166 | 168 | f"""CREATE TABLE IF NOT EXISTS {table_name} |
167 | | - (model_name TEXT PRIMARY KEY, timestamp DATETIME, test_av_loss FLOAT, test_accuracy FLOAT)""" |
| 169 | + (model_name TEXT PRIMARY KEY, timestamp DATETIME, test_av_loss FLOAT, test_accuracy FLOAT, test_set_num INT)""" |
168 | 170 | ) |
169 | 171 |
|
170 | 172 | con.execute( |
171 | | - f"INSERT OR REPLACE INTO {table_name} (model_name, timestamp, test_av_loss, test_accuracy) VALUES (?, ?, ?, ?) ", |
172 | | - (model_name, timestamp, test_av_loss, test_accuracy), |
| 173 | + f"INSERT OR REPLACE INTO {table_name} (model_name, timestamp, test_av_loss, test_accuracy, test_set_num) VALUES (?, ?, ?, ?, ?) ", |
| 174 | + (model_name, timestamp, test_av_loss, test_accuracy, test_set_num), |
173 | 175 | ) |
174 | 176 |
|
175 | 177 | con.close() |
|
0 commit comments