Skip to content

Fix missing and fragile scikit-learn imports in Keras sklearn wrappers #21387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions keras/src/wrappers/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def _raise_or_return(target_type):
else:
return target_type

target_type = sklearn.utils.multiclass.type_of_target(
y, input_name=input_name
)
from sklearn.utils.multiclass import type_of_target as sk_type_of_target

target_type = sk_type_of_target(y, input_name=input_name)
return _raise_or_return(target_type)


Expand Down
8 changes: 6 additions & 2 deletions keras/src/wrappers/sklearn_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def fit(self, X, y, **kwargs):

def predict(self, X):
"""Predict using the model."""
sklearn.base.check_is_fitted(self)
from sklearn.utils.validation import check_is_fitted

check_is_fitted(self)
X = _validate_data(self, X, reset=False)
raw_output = self.model_.predict(X)
return self._reverse_process_target(raw_output)
Expand Down Expand Up @@ -472,7 +474,9 @@ def transform(self, X):
X_transformed: array-like, shape=(n_samples, n_features)
The transformed data.
"""
sklearn.base.check_is_fitted(self)
from sklearn.utils.validation import check_is_fitted

check_is_fitted(self)
X = _validate_data(self, X, reset=False)
return self.model_.predict(X)

Expand Down
13 changes: 8 additions & 5 deletions keras/src/wrappers/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

try:
import sklearn
from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -25,8 +27,8 @@ def _check_model(model):
# compile model if user gave us an un-compiled model
if not model.compiled or not model.loss or not model.optimizer:
raise RuntimeError(
"Given model needs to be compiled, and have a loss and an "
"optimizer."
"Given model needs to be compiled, and have a loss "
"and an optimizer."
)


Expand Down Expand Up @@ -80,8 +82,9 @@ def inverse_transform(self, y):
is passed, it will be squeezed back to 1D. Otherwise, it
will eb left untouched.
"""
sklearn.base.check_is_fitted(self)
xp, _ = sklearn.utils._array_api.get_namespace(y)
from sklearn.utils.validation import check_is_fitted

check_is_fitted(self)
if self.ndim_ == 1 and y.ndim == 2:
return xp.squeeze(y, axis=1)
return np.squeeze(y, axis=1)
return y