ENH make Random*Sampler accept dask array and dataframe #777
base: master
```diff
@@ -81,6 +81,9 @@ def __init__(
         self.replacement = replacement

     def _check_X_y(self, X, y):
+        if is_dask_container(y) and hasattr(y, "to_dask_array"):
+            y = y.to_dask_array()
+            y.compute_chunk_sizes()
         y, binarize_y, self._uniques = check_target_type(
             y,
             indicate_one_vs_all=True,
```

Review comment on `y.compute_chunk_sizes()`:

> In Dask-ML we (@stsievert I think? maybe me?) prefer to have the user do this: https://github.com/dask/dask-ml/blob/7e11ce1505a485104e02d49a3620c8264e63e12e/dask_ml/utils.py#L166-L173. If you're just fitting the one estimator then this is probably equivalent. If you're doing something like a […]

Author reply:

> This is something that I was unsure of, here. If I recall, the issue was that I could not have called […]
>
> However, if we assume that the checks are too expensive to be done in a distributed setting, we don't need to call the check below and we can directly pass the Series and handle it during the resampling. So we have fewer safeguards, but at least it is more performant, which is something you probably want in a distributed setting.
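The added lines duck-type the target: anything that reports itself as a dask container and exposes `to_dask_array` is converted, and `compute_chunk_sizes()` is called so downstream indexing has concrete chunk boundaries. A minimal sketch of that control flow, using stand-in classes so it runs without dask installed (`is_dask_container` here is an approximation of the PR's helper, and `FakeDaskSeries`/`FakeDaskArray` are hypothetical stand-ins):

```python
# Sketch of the duck-typed conversion in the diff above. `FakeDaskSeries`
# stands in for a dask.dataframe.Series so the flow can run without dask.

def is_dask_container(obj):
    # The PR's real helper inspects the object's type/module;
    # here we approximate it with a flag on the stand-in objects.
    return getattr(obj, "_is_dask", False)

class FakeDaskArray:
    def __init__(self, data):
        self._is_dask = True
        self.data = data
        self.chunks_known = False

    def compute_chunk_sizes(self):
        # On a real dask array this runs a computation that resolves
        # unknown (NaN) chunk sizes; here we just record that it happened.
        self.chunks_known = True
        return self

class FakeDaskSeries:
    def __init__(self, data):
        self._is_dask = True
        self.data = data

    def to_dask_array(self):
        return FakeDaskArray(self.data)

def coerce_target(y):
    # Mirrors the three added lines in `_check_X_y`.
    if is_dask_container(y) and hasattr(y, "to_dask_array"):
        y = y.to_dask_array()
        y.compute_chunk_sizes()
    return y

y = coerce_target(FakeDaskSeries([0, 1, 1, 0]))
print(type(y).__name__, y.chunks_known)  # FakeDaskArray True
```

A plain list passed through `coerce_target` is returned untouched, which is the point of duck-typing here: non-dask inputs take the existing code path.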
```diff
@@ -95,6 +98,9 @@ def _check_X_y(self, X, y):
                 dtype=None,
                 force_all_finite=False,
             )
+        elif is_dask_container(X) and hasattr(X, "to_dask_array"):
+            X = X.to_dask_array()
+            X.compute_chunk_sizes()
         return X, y, binarize_y

     @staticmethod
```
```diff
@@ -140,7 +146,7 @@ def _more_tags(self):
                 "2darray",
                 "string",
                 "dask-array",
-                # "dask-dataframe"
+                "dask-dataframe"
             ],
             "sample_indices": True,
             "allow_nan": True,
```
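This hunk un-comments the `"dask-dataframe"` entry in the estimator's `X_types` tag, declaring that the sampler now accepts dask DataFrames. In scikit-learn's tag system, each class contributes a `_more_tags` dict and `_get_tags` merges them along the MRO so subclass entries win. A simplified sketch of that merge (the base-class defaults and class names below are illustrative, not the library's actual values):

```python
# Simplified model of scikit-learn-style estimator tags: `_get_tags` walks
# the MRO from base to subclass, updating a dict, so the sampler's
# `_more_tags` override (as in the diff above) takes precedence.

class BaseEstimator:
    def _more_tags(self):
        # Illustrative defaults, not scikit-learn's real ones.
        return {"X_types": ["2darray"], "allow_nan": False}

    def _get_tags(self):
        tags = {}
        for klass in reversed(type(self).__mro__):
            if hasattr(klass, "_more_tags"):
                tags.update(klass._more_tags(self))
        return tags

class RandomUnderSampler(BaseEstimator):
    def _more_tags(self):
        return {
            "X_types": ["2darray", "string", "dask-array", "dask-dataframe"],
            "sample_indices": True,
            "allow_nan": True,
        }

tags = RandomUnderSampler()._get_tags()
print("dask-dataframe" in tags["X_types"], tags["allow_nan"])  # True True
```

Common test machinery reads these tags to decide which input types to exercise, which is why flipping the tag on is part of claiming dask-dataframe support.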
Review comment on the `check_target_type` call:

> I've struggled with this check in dask-ml. Depending on where it's called, it's potentially very expensive (you might be loading a ton of data just to check if it's multi-label, and then loading it again to do the training).
>
> Whenever possible, it's helpful to provide an option to skip this check by having the user specify it when creating the estimator, or in a keyword to `fit` (dunno if that applies here).
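The reviewer's suggestion amounts to letting the user assert "my target is fine, don't scan it". A minimal sketch under assumed names (`validate_target` and `expensive_multilabel_check` are hypothetical, chosen for illustration):

```python
# Sketch of an opt-out flag for an expensive target check, as the reviewer
# suggests. With lazy (dask-like) data, the check would force a full pass
# over the target just to classify it, so skipping it can matter.

calls = []

def expensive_multilabel_check(y):
    # Stands in for a check that would scan the full (possibly lazy) target.
    calls.append("checked")
    return False  # pretend y is not multi-label

class Sampler:
    def __init__(self, validate_target=True):
        # Hypothetical parameter: True keeps the safeguard, False trusts
        # the caller and avoids the extra pass over the data.
        self.validate_target = validate_target

    def fit(self, X, y):
        if self.validate_target and expensive_multilabel_check(y):
            raise ValueError("multi-label targets are not supported")
        # ... resampling would happen here ...
        return self

Sampler(validate_target=True).fit(None, [0, 1])
Sampler(validate_target=False).fit(None, [0, 1])
print(len(calls))  # 1: the check ran only for the validating sampler
```

The trade-off is exactly the one discussed above: fewer safeguards when the flag is off, but only one pass over distributed data instead of two.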
Author reply:

> I thought about it. Do you think that having a context manager outside would make sense?
>
> Though, we might get into trouble with issues related to scikit-learn/scikit-learn#18736.
>
> It might just be easier to have an optional class parameter that applies only for dask arrays.