Skip to content

Commit a5f8ce8

Browse files
committed
🚑 forgot to add handling to update for skmultiflow input format
1 parent ab84173 commit a5f8ce8

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/predictive_model/classification/classification.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,15 @@ def _update(job: Job, data: DataFrame) -> dict:
153153
if not x.empty:
154154
y = x['label']
155155

156-
models[cluster].partial_fit(x.drop('label', 1), y.values.ravel())
156+
try:
157+
models[cluster].partial_fit(x.drop('label', 1), y.values.ravel())
158+
except (NotImplementedError, KeyError):
159+
try:
160+
models[cluster].partial_fit(x.drop('label', 1).T, y.values.ravel())
161+
except KeyError:
162+
models[cluster].partial_fit(x.drop('label', 1).values, y.values.ravel())
163+
except Exception as exception:
164+
raise exception
157165

158166
return {ModelType.CLUSTERER.value: clusterer, ModelType.CLASSIFIER.value: models}
159167

0 commit comments

Comments
 (0)