Skip to content

Commit 696d873

Browse files
committed
Quick patch for limited support of fit args
1 parent 9eef63e commit 696d873

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vecstack/core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def transformer(y, func=None):
6565

6666
def model_action(model, X_train, y_train, X_test,
6767
sample_weight=None, action=None,
68-
transform=None):
68+
transform=None, **fit_args):
6969
"""Performs model action.
7070
This wrapper gives us ability to choose action dynamically
7171
(e.g. predict or predict_proba).
@@ -78,9 +78,9 @@ def model_action(model, X_train, y_train, X_test,
7878
# We use following condition, because some models (e.g. Lars) may not have
7979
# 'sample_weight' parameter of fit method
8080
if sample_weight is not None:
81-
return model.fit(X_train, transformer(y_train, func = transform), sample_weight=sample_weight)
81+
return model.fit(X_train, transformer(y_train, func = transform), sample_weight=sample_weight, **fit_args)
8282
else:
83-
return model.fit(X_train, transformer(y_train, func = transform))
83+
return model.fit(X_train, transformer(y_train, func = transform), **fit_args)
8484
elif 'predict' == action:
8585
return transformer(model.predict(X_test), func = transform)
8686
elif 'predict_proba' == action:
@@ -127,7 +127,7 @@ def stacking(models, X_train, y_train, X_test,
127127
transform_target=None, transform_pred=None,
128128
mode='oof_pred_bag', needs_proba=False, save_dir=None,
129129
metric=None, n_folds=4, stratified=False,
130-
shuffle=False, random_state=0, verbose=0):
130+
shuffle=False, random_state=0, verbose=0, **fit_args):
131131
"""Function 'stacking' takes train data, test data and list of 1-st level
132132
models, and returns stacking features, which can be used with 2-nd level model.
133133
@@ -565,7 +565,7 @@ def your_metric(y_true, y_pred):
565565

566566
# Fit 1-st level model
567567
if mode in ['pred_bag', 'oof', 'oof_pred', 'oof_pred_bag']:
568-
_ = model_action(model, X_tr, y_tr, None, sample_weight = sample_weight_tr, action = 'fit', transform = transform_target)
568+
_ = model_action(model, X_tr, y_tr, None, sample_weight = sample_weight_tr, action = 'fit', transform = transform_target, **fit_args)
569569

570570
# Predict out-of-fold part of train set
571571
if mode in ['oof', 'oof_pred', 'oof_pred_bag']:
@@ -625,7 +625,7 @@ def your_metric(y_true, y_pred):
625625
if mode in ['pred', 'oof_pred']:
626626
if verbose > 0:
627627
print(' Fitting on full train set...\n')
628-
_ = model_action(model, X_train, y_train, None, sample_weight = sample_weight, action = 'fit', transform = transform_target)
628+
_ = model_action(model, X_train, y_train, None, sample_weight = sample_weight, action = 'fit', transform = transform_target, **fit_args)
629629
if 'predict_proba' == action:
630630
col_slice_model = slice(model_counter * n_classes, model_counter * n_classes + n_classes)
631631
else:

0 commit comments

Comments
 (0)