Integrate Sklearn OneHotEncoder #830
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,17 +70,16 @@ def fit(self, X, y=None): | |
X = pd.DataFrame(X) | ||
X_t = X | ||
cols_to_encode = self._get_cat_cols(X_t) | ||
self.col_unique_values = {} | ||
|
||
# If there are no categorical columns, nothing needs to happen | ||
if len(cols_to_encode) == 0: | ||
categories = 'auto' | ||
|
||
|
||
# Use the categories parameter | ||
|
||
if isinstance(self.parameters['categories'], list): | ||
elif isinstance(self.parameters['categories'], list): | ||
|
||
if top_n is not None: | ||
raise ValueError("Cannot use categories and top_n arguments simultaneously") | ||
|
||
categories = self.parameters['categories'] | ||
Could things break if this comes back as an empty list?
I'm assuming not, since we start with an empty list in the
This is a good catch, scikit-learn throws an error in that case. I'll add our own catch to provide more useful feedback.
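One way the "own catch" mentioned in this thread could look is sketched below. This is only an illustration under stated assumptions: `validate_categories` and its `n_categorical_cols` argument are hypothetical names, not part of this PR, and the error wording is invented. The idea is to reject a `categories` list whose length does not match the number of categorical columns before it reaches scikit-learn.

```python
def validate_categories(categories, n_categorical_cols):
    """Hypothetical helper (not from the PR): check a user-supplied categories list.

    Assumes `categories` should hold one list of values per categorical column.
    """
    if not isinstance(categories, list):
        raise TypeError("categories must be a list with one entry per categorical column")
    if len(categories) != n_categorical_cols:
        # Covers the empty-list case raised in review, plus any other length mismatch.
        raise ValueError(
            f"categories has {len(categories)} entries, but the data has "
            f"{n_categorical_cols} categorical columns"
        )
    return categories
```

Failing early this way lets the message name the parameter the user actually set, rather than surfacing a shape error from the underlying encoder.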
||
for col in X_t[cols_to_encode]: | ||
value_counts = X_t[col].value_counts(dropna=False).to_frame() | ||
unique_values = value_counts.index.tolist() | ||
self.col_unique_values[col] = np.sort(unique_values) | ||
|
||
# Use the top_n parameter | ||
else: | ||
|
@@ -98,12 +97,8 @@ def fit(self, X, y=None): | |
value_counts = value_counts.sort_values([col], ascending=False, kind='mergesort') | ||
unique_values = value_counts.head(top_n).index.tolist() | ||
unique_values = np.sort(unique_values) | ||
self.col_unique_values[col] = unique_values | ||
categories.append(unique_values) | ||
|
||
|
||
if len(cols_to_encode) == 0: | ||
categories = 'auto' | ||
|
||
# Create an encoder to pass off the rest of the computation to | ||
encoder = SKOneHotEncoder(categories=categories, | ||
drop=self.drop, | ||
|
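For readers following the `top_n` branch in the hunk above, here is a rough pure-Python sketch of the same idea: count values, keep the `top_n` most frequent, then sort the kept values. It is not the PR's code. `top_n_categories` is a hypothetical name, ties are broken by first-seen order here (which may differ from the pandas `value_counts` plus `mergesort` behavior in the diff), and the values are assumed to be mutually comparable.

```python
from collections import Counter


def top_n_categories(values, top_n):
    """Hypothetical sketch: return the top_n most frequent values, sorted."""
    counts = Counter(values)
    # sorted() is stable, so values with equal counts keep first-seen order.
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    kept = [value for value, _ in ranked[:top_n]]
    # Mirror the final sort applied to the kept categories in the diff.
    return sorted(kept)
```

Calling `top_n_categories(["a", "b", "a", "c", "b", "a", "d"], 2)` keeps `"a"` (3 occurrences) and `"b"` (2 occurrences).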
@@ -120,9 +115,7 @@ def transform(self, X, y=None): | |
Returns: | ||
Transformed dataframe, where each categorical feature has been encoded into numerical columns using one-hot encoding. | ||
""" | ||
try: | ||
col_values = self.col_unique_values | ||
except AttributeError: | ||
if self.encoder is None: | ||
raise RuntimeError("You must fit one hot encoder before calling transform!") | ||
if not isinstance(X, pd.DataFrame): | ||
X = pd.DataFrame(X) | ||
|
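The hunk above swaps a `try`/`except AttributeError` probe for an explicit `self.encoder is None` check. A toy illustration of that pattern follows. `TinyEncoder` is a hypothetical stand-in, not the component in this PR, and it assumes the attribute is initialized to `None` in `__init__`.

```python
class TinyEncoder:
    """Hypothetical toy encoder used only to show the fitted-state guard."""

    def __init__(self):
        # Start unfitted; transform() checks this sentinel explicitly.
        self.encoder = None

    def fit(self, values):
        # Store the sorted unique values as the learned categories.
        self.encoder = sorted(set(values))
        return self

    def transform(self, values):
        if self.encoder is None:
            raise RuntimeError("You must fit one hot encoder before calling transform!")
        # One indicator column per learned category.
        return [[int(v == c) for c in self.encoder] for v in values]
```

Checking a sentinel keeps the guard independent of which attributes `fit` happens to create, so removing `col_unique_values` does not silently change the unfitted behavior.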
@@ -134,11 +127,11 @@ def transform(self, X, y=None): | |
X_t = pd.DataFrame() | ||
# Add the non-categorical columns, untouched | ||
for col in X.columns: | ||
if col not in col_values: | ||
if col not in cat_cols: | ||
X_t = pd.concat([X_t, X[col]], axis=1) | ||
|
||
# Call sklearn's transform on the categorical columns | ||
if len(col_values) != 0: | ||
if len(cat_cols) != 0: | ||
|
||
X_cat = pd.DataFrame(self.encoder.transform(X[cat_cols]).toarray()) | ||
X_cat.columns = self.encoder.get_feature_names(input_features=cat_cols) | ||
X_t = pd.concat([X_t.reindex(X_cat.index), X_cat], axis=1) | ||
|
Do we have unit test coverage of this case?
Yep, test_all_numerical_dtype starting at line 284