diff --git a/CHANGELOG b/CHANGELOG index 6dface96a..fe3bc6e09 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -13,13 +13,16 @@ The rules for CHANGELOG file: Unreleased ---------- +- Fix typos and make minor edits to ``PCovC`` and ``KPCovR`` docs (#254) +- Fix assertTrue typos in ``KPCovR`` testing suite (#254) +- Add ``KernelPCovC`` class and examples (#254) 0.3.0 (2025/06/12) ------------------ - Add ``_BasePCov`` class (#248) - Add ``PCovC`` class that inherits shared functionality from ``_BasePCov`` (#248) - Add ``PCovC`` testing suite and examples (#248) -- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov_`` (#248) +- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov`` (#248) - Update to sklearn >= 1.7.0 and scipy >= 1.15.0 (#239, #257) - Fixed moved function import from scipy and bump scipy dependency to 1.15.0 (#236) - Fix rendering issues for `SparseKDE` and `QuickShift` (#236) diff --git a/docs/src/references/decomposition.rst b/docs/src/references/decomposition.rst index 0ee5caf9c..483d5f7c4 100644 --- a/docs/src/references/decomposition.rst +++ b/docs/src/references/decomposition.rst @@ -54,3 +54,19 @@ Kernel PCovR .. automethod:: predict .. automethod:: inverse_transform .. automethod:: score + +.. _KPCovC-api: + +Kernel PCovC +------------ + +.. autoclass:: skmatter.decomposition.KernelPCovC + :show-inheritance: + :special-members: + + .. automethod:: fit + .. automethod:: transform + .. automethod:: predict + .. automethod:: inverse_transform + .. automethod:: decision_function + .. automethod:: score diff --git a/examples/pcovc/KPCovC_Comparison.py b/examples/pcovc/KPCovC_Comparison.py new file mode 100644 index 000000000..0dd1277e5 --- /dev/null +++ b/examples/pcovc/KPCovC_Comparison.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +Comparing KPCovC with KPCA +====================================== +""" +# %% +# + +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib as mpl +from matplotlib.colors import ListedColormap + +from sklearn import datasets +from sklearn.preprocessing import StandardScaler +from sklearn.svm import LinearSVC +from sklearn.decomposition import PCA, KernelPCA +from sklearn.inspection import DecisionBoundaryDisplay +from sklearn.model_selection import train_test_split +from sklearn.linear_model import ( + LogisticRegressionCV, + RidgeClassifierCV, + SGDClassifier, +) + +from skmatter.decomposition import PCovC, KernelPCovC + +plt.rcParams["scatter.edgecolors"] = "k" +cm_bright = ListedColormap(["#d7191c", "#fdae61", "#a6d96a", "#3a7cdf"]) + +random_state = 0 +n_components = 2 + +# %% +# +# For this, we will combine two ``sklearn`` datasets from +# :func:`sklearn.datasets.make_moons`. 
+ +X1, y1 = datasets.make_moons(n_samples=750, noise=0.10, random_state=random_state) +X2, y2 = datasets.make_moons(n_samples=750, noise=0.10, random_state=random_state) + +X2, y2 = X2 + 2, y2 + 2 +R = np.array( + [ + [np.cos(np.pi / 2), -np.sin(np.pi / 2)], + [np.sin(np.pi / 2), np.cos(np.pi / 2)], + ] +) +# rotate second pair of moons +X2 = X2 @ R.T + +X = np.vstack([X1, X2]) +y = np.concatenate([y1, y2]) + +# %% +# +# Original Data +# ------------- + +fig, ax = plt.subplots(figsize=(5.5, 5)) +ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cm_bright) +ax.set_title("Original Data") + + +# %% +# +# Scale Data + +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.25, stratify=y, random_state=random_state +) + +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# %% +# +# PCA and PCovC +# ------------- +# +# Both PCA and PCovC fail to produce linearly separable latent space +# maps. We will need a kernel method to effectively separate the moon classes. + +mixing = 0.10 +alpha_d = 0.5 +alpha_p = 0.4 + +models = { + PCA(n_components=n_components): "PCA", + PCovC( + n_components=n_components, + random_state=random_state, + mixing=mixing, + classifier=LinearSVC(), + ): "PCovC", +} + +fig, axs = plt.subplots(1, 2, figsize=(10, 4)) + +for ax, model in zip(axs, models): + t_train = model.fit_transform(X_train_scaled, y_train) + t_test = model.transform(X_test_scaled) + + ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_d, cmap=cm_bright, c=y_test) + ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train) + + ax.set_title(models[model]) + plt.tight_layout() + +# %% +# +# Kernel PCA and Kernel PCovC +# --------------------------- +# +# A comparison of the latent spaces produced by KPCA and KPCovC is shown. +# A logistic regression classifier is trained on the KPCA latent space (this is also +# the default classifier used in KPCovC), and we see the comparison of the respective +# decision boundaries and test data accuracy scores. 
+ +fig, axs = plt.subplots(1, 2, figsize=(13, 6)) + +center = True +resolution = 1000 + +kernel_params = {"kernel": "rbf", "gamma": 2} + +models = { + KernelPCA(n_components=n_components, **kernel_params): { + "title": "Kernel PCA", + "eps": 0.1, + }, + KernelPCovC( + n_components=n_components, + random_state=random_state, + mixing=mixing, + center=center, + **kernel_params, + ): {"title": "Kernel PCovC", "eps": 2}, +} + +for ax, model in zip(axs, models): + t_train = model.fit_transform(X_train_scaled, y_train) + t_test = model.transform(X_test_scaled) + + if isinstance(model, KernelPCA): + t_classifier = LinearSVC(random_state=random_state).fit(t_train, y_train) + score = t_classifier.score(t_test, y_test) + else: + t_classifier = model.classifier_ + score = model.score(X_test_scaled, y_test) + + DecisionBoundaryDisplay.from_estimator( + estimator=t_classifier, + X=t_test, + ax=ax, + response_method="predict", + cmap=cm_bright, + alpha=alpha_d, + eps=models[model]["eps"], + grid_resolution=resolution, + ) + ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test) + ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train) + ax.set_title(models[model]["title"]) + + ax.text( + 0.82, + 0.03, + f"Score: {round(score, 3)}", + fontsize=mpl.rcParams["axes.titlesize"], + transform=ax.transAxes, + ) + ax.set_xticks([]) + ax.set_yticks([]) + +fig.subplots_adjust(wspace=0.04) +plt.tight_layout() + + +# %% +# +# Effect of KPCovC Classifier on KPCovC Maps and Decision Boundaries +# ------------------------------------------------------------------------------ +# +# Based on the evidence :math:`\mathbf{Z}` generated by the underlying classifier fit +# on a computed kernel :math:`\mathbf{K}` and :math:`\mathbf{Y}`, Kernel PCovC will +# produce varying latent space maps. Hence, the decision boundaries produced by the +# linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}` to make +# predictions will also vary. 
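# %%
#
# Before comparing several classifiers, it helps to see which fitted pieces a single
# KPCovC model exposes: ``z_classifier_`` is fit between the computed kernel and the
# labels (it supplies the evidence :math:`\mathbf{Z}`), ``classifier_`` is refit
# between the projection :math:`\mathbf{T}` and the labels (it draws the decision
# boundaries), and ``pkt_``/``ptz_`` are the corresponding projectors. The check
# below is an illustrative sketch using the training data prepared earlier.

kpcovc_check = KernelPCovC(
    n_components=n_components,
    random_state=random_state,
    mixing=mixing,
    center=center,
    kernel="rbf",
    gamma=2,
).fit(X_train_scaled, y_train)

# classifier fit on the kernel, used to build the evidence Z
print(type(kpcovc_check.z_classifier_).__name__)

# classifier refit on the latent projection T
print(type(kpcovc_check.classifier_).__name__)

# projectors: kernel -> latent space, latent space -> class confidence scores
print(kpcovc_check.pkt_.shape, kpcovc_check.ptz_.shape)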
+ +names = ["Logistic Regression", "Ridge Classifier", "Linear SVC", "SGD Classifier"] + +models = { + LogisticRegressionCV(random_state=random_state): { + "kernel_params": {"kernel": "rbf", "gamma": 12}, + "title": "Logistic Regression", + }, + RidgeClassifierCV(): { + "kernel_params": {"kernel": "rbf", "gamma": 1}, + "title": "Ridge Classifier", + "eps": 0.40, + }, + LinearSVC(random_state=random_state): { + "kernel_params": {"kernel": "rbf", "gamma": 15}, + "title": "Support Vector Classification", + }, + SGDClassifier(random_state=random_state): { + "kernel_params": {"kernel": "rbf", "gamma": 15}, + "title": "SGD Classifier", + "eps": 10, + }, +} + +fig, axs = plt.subplots(1, len(models), figsize=(4 * len(models), 4)) + +for ax, name, model in zip(axs.flat, names, models): + kpcovc = KernelPCovC( + n_components=n_components, + random_state=random_state, + mixing=mixing, + classifier=model, + center=center, + **models[model]["kernel_params"], + ) + t_kpcovc_train = kpcovc.fit_transform(X_train_scaled, y_train) + t_kpcovc_test = kpcovc.transform(X_test_scaled) + kpcovc_score = kpcovc.score(X_test_scaled, y_test) + + DecisionBoundaryDisplay.from_estimator( + estimator=kpcovc.classifier_, + X=t_kpcovc_test, + ax=ax, + response_method="predict", + cmap=cm_bright, + alpha=alpha_d, + eps=models[model].get("eps", 1), + grid_resolution=resolution, + ) + + ax.scatter( + t_kpcovc_test[:, 0], + t_kpcovc_test[:, 1], + cmap=cm_bright, + alpha=alpha_p, + c=y_test, + ) + ax.scatter(t_kpcovc_train[:, 0], t_kpcovc_train[:, 1], cmap=cm_bright, c=y_train) + ax.text( + 0.70, + 0.03, + f"Score: {round(kpcovc_score, 3)}", + fontsize=mpl.rcParams["axes.titlesize"], + transform=ax.transAxes, + ) + + ax.set_title(name) + ax.set_xticks([]) + ax.set_yticks([]) + fig.subplots_adjust(wspace=0.04) + + plt.tight_layout() diff --git a/examples/pcovc/KPCovC_Hyperparameters.py b/examples/pcovc/KPCovC_Hyperparameters.py new file mode 100644 index 000000000..ce3948e25 --- /dev/null +++ b/examples/pcovc/KPCovC_Hyperparameters.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +KPCovC Hyperparameter Tuning +================================== +""" +# %% +# + +import matplotlib as mpl +import matplotlib.pyplot as plt + +from sklearn import datasets +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import KernelPCA +from skmatter.decomposition import KernelPCovC + +plt.rcParams["image.cmap"] = "tab20" +plt.rcParams["scatter.edgecolors"] = "k" + +random_state = 0 +n_components = 2 + +# %% +# +# For this, we will use the :func:`sklearn.datasets.make_circles` dataset from +# ``sklearn``. + +X, y = datasets.make_circles( + noise=0.1, factor=0.7, random_state=random_state, n_samples=1500 +) + +# %% +# Original Data +# ------------- + +fig, ax = plt.subplots(figsize=(5.5, 5)) +ax.scatter(X[:, 0], X[:, 1], c=y) +ax.set_title("Original Data") + +# %% +# +# Scale Data + +scaler = StandardScaler() +X_scaled = scaler.fit_transform(X) + +# %% +# +# Effect of Kernel on KPCA and KPCovC Projections +# ----------------------------------------------------------- +# +# Here, we see how Kernel PCovC with kernels such as a radial basis function +# can outperform Kernel PCA by producing cleanly separable projections from +# noisy circular data. 
+ +kernels = ["linear", "rbf", "sigmoid", "poly"] +kernel_params = { + "rbf": {"gamma": 0.5}, + "sigmoid": {"gamma": 1.0}, + "poly": {"degree": 6}, +} + +fig, axs = plt.subplots(2, len(kernels), figsize=(len(kernels) * 4, 8)) + +center = True +mixing = 0.10 + +for i, kernel in enumerate(kernels): + kpca = KernelPCA( + random_state=random_state, + n_components=n_components, + kernel=kernel, + **kernel_params.get(kernel, {}), + ) + t_kpca = kpca.fit_transform(X_scaled) + + kpcovc = KernelPCovC( + n_components=n_components, + mixing=mixing, + kernel=kernel, + random_state=random_state, + **kernel_params.get(kernel, {}), + center=center, + ) + t_kpcovc = kpcovc.fit_transform(X_scaled, y) + + axs[0][i].scatter(t_kpca[:, 0], t_kpca[:, 1], c=y) + axs[1][i].scatter(t_kpcovc[:, 0], t_kpcovc[:, 1], c=y) + + axs[0][i].set_title(kernel) + + axs[0][i].set_xticks([]) + axs[1][i].set_xticks([]) + axs[0][i].set_yticks([]) + axs[1][i].set_yticks([]) + +axs[0][0].set_ylabel("Kernel PCA", fontsize=mpl.rcParams["axes.titlesize"]) +axs[1][0].set_ylabel("Kernel PCovC", fontsize=mpl.rcParams["axes.titlesize"]) + +fig.subplots_adjust(wspace=0, hspace=0) +plt.tight_layout() + +# %% +# +# Decision Boundary Formation with Gamma Tuning +# --------------------------------------------- +# +# Depending on the data, tuning gamma values for the KPCovC kernel can greatly +# improve latent space projections, enabling clearer decision boundaries. + +gamma_vals = [0.001, 0.0016, 0.00165, 0.00167, 0.00169, 0.00175] + +fig, axs = plt.subplots(1, len(gamma_vals), figsize=(len(gamma_vals) * 3.5, 3.5)) + +for ax, gamma in zip(axs, gamma_vals): + kpcovc = KernelPCovC( + n_components=n_components, + random_state=random_state, + mixing=mixing, + center=center, + kernel="rbf", + gamma=gamma, + ) + t_kpcovc = kpcovc.fit_transform(X_scaled, y) + + ax.scatter(t_kpcovc[:, 0], t_kpcovc[:, 1], c=y) + ax.set_title(f"gamma: {gamma}") + + ax.set_xticks([]) + ax.set_yticks([]) + +fig.subplots_adjust(wspace=0) +plt.tight_layout() diff --git a/examples/pcovc/PCovC_Comparison.py b/examples/pcovc/PCovC_Comparison.py index d3e769142..eaa59ca1b 100644 --- a/examples/pcovc/PCovC_Comparison.py +++ b/examples/pcovc/PCovC_Comparison.py @@ -107,4 +107,6 @@ axs[1].set_title("LDA") axs[2].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=y) axs[2].set_title("PCovC") + +plt.tight_layout() plt.show() diff --git a/examples/pcovc/PCovC_Hyperparameters.py b/examples/pcovc/PCovC_Hyperparameters.py index 49846c687..22989a95d 100644 --- a/examples/pcovc/PCovC_Hyperparameters.py +++ b/examples/pcovc/PCovC_Hyperparameters.py @@ -71,9 +71,7 @@ fig, axs = plt.subplots(1, n_mixing, figsize=(4 * n_mixing, 4), sharey="row") -for id in range(0, n_mixing): - mixing = mixing_params[id] - +for ax, mixing in zip(axs, mixing_params): pcovc = PCovC( mixing=mixing, n_components=n_components, @@ -84,21 +82,22 @@ pcovc.fit(X_scaled, y) T = pcovc.transform(X_scaled) - axs[id].set_xticks([]) - axs[id].set_yticks([]) + ax.set_xticks([]) + ax.set_yticks([]) - axs[id].set_title(r"$\alpha=$" + str(mixing)) - axs[id].set_xlabel("PCov$_1$") - axs[id].scatter(T[:, 0], T[:, 1], c=y) + ax.set_title(r"$\alpha=$" + str(mixing)) + ax.set_xlabel("PCov$_1$") + ax.scatter(T[:, 0], T[:, 1], c=y) axs[0].set_ylabel("PCov$_2$") fig.subplots_adjust(wspace=0) +plt.tight_layout() # %% # -# Effect of PCovC Classifier on PCovC Map and Decision Boundaries -# --------------------------------------------------------------- +# Effect of PCovC Classifier on PCovC Maps and Decision Boundaries +# 
---------------------------------------------------------------- # # Here, we see how a PCovC model (:math:`\alpha` = 0.5) fitted with # different classifiers produces varying PCovC maps. In addition, @@ -115,9 +114,7 @@ Perceptron(random_state=random_state): "Single-Layer Perceptron", } -for id in range(0, len(models)): - model = list(models)[id] - +for ax, model in zip(axs, models): pcovc = PCovC( mixing=mixing, n_components=n_components, @@ -128,22 +125,21 @@ pcovc.fit(X_scaled, y) T = pcovc.transform(X_scaled) - graph = axs[id] - graph.set_title(models[model]) + ax.set_title(models[model]) DecisionBoundaryDisplay.from_estimator( estimator=pcovc.classifier_, X=T, - ax=graph, + ax=ax, response_method="predict", grid_resolution=1000, ) - scatter = graph.scatter(T[:, 0], T[:, 1], c=y) + scatter = ax.scatter(T[:, 0], T[:, 1], c=y) - graph.set_xlabel("PCov$_1$") - graph.set_xticks([]) - graph.set_yticks([]) + ax.set_xlabel("PCov$_1$") + ax.set_xticks([]) + ax.set_yticks([]) axs[0].set_ylabel("PCov$_2$") axs[0].legend( @@ -155,4 +151,5 @@ ) fig.subplots_adjust(wspace=0.04) +plt.tight_layout() plt.show() diff --git a/examples/pcovc/README.rst b/examples/pcovc/README.rst index 4018f7ffa..3a69e6e5d 100644 --- a/examples/pcovc/README.rst +++ b/examples/pcovc/README.rst @@ -1,2 +1,2 @@ -PCovC -===== +PCovC and KernelPCovC +===================== diff --git a/src/skmatter/decomposition/__init__.py b/src/skmatter/decomposition/__init__.py index 4fbb6d92c..2fd488c93 100644 --- a/src/skmatter/decomposition/__init__.py +++ b/src/skmatter/decomposition/__init__.py @@ -16,7 +16,8 @@ [Helfrecht2020]_ introduced the non-linear version of PCovR, Kernel Principal Covariates Regression (KPCovR), where the mixing parameter α now interpolates between kernel ridge regression (:math:`\alpha = 0`) and -kernel principal components analysis (KPCA, :math:`\alpha = 1`). +kernel principal components analysis (KPCA, :math:`\alpha = 1`). A non-linear version +of PCovC, Kernel Principal Covariates Classification (KPCovC), is also provided. The module includes: @@ -30,18 +31,24 @@ * :ref:`KPCovR-api` the Kernel Principal Covariates Regression. A kernel-based variation on the original PCovR method, proposed in [Helfrecht2020]_. +* :ref:`KPCovC-api` the Kernel Principal Covariates Classification. + A kernel-based modification on the original PCovC method. 
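As a sketch of the construction shared by the kernel variants, the modified kernel
that is eigendecomposed is a convex combination of the input kernel and the gram
matrix of the approximated targets (or class confidence scores); written out with
plain numpy and purely illustrative values:

>>> import numpy as np
>>> K = np.array([[1.0, 0.5], [0.5, 1.0]])  # input kernel
>>> Z = np.array([[1.0], [-1.0]])  # approximated targets / class confidence scores
>>> alpha = 0.5  # mixing parameter
>>> K_tilde = alpha * K + (1 - alpha) * Z @ Z.T
>>> K_tilde.shape
(2, 2)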
""" from ._pcov import _BasePCov +from ._kpcov import _BaseKPCov from ._pcovr import PCovR from ._pcovc import PCovC from ._kernel_pcovr import KernelPCovR +from ._kernel_pcovc import KernelPCovC __all__ = [ "_BasePCov", + "_BaseKPCov", "PCovR", "PCovC", "KernelPCovR", + "KernelPCovC", ] diff --git a/src/skmatter/decomposition/_kernel_pcovc.py b/src/skmatter/decomposition/_kernel_pcovc.py new file mode 100644 index 000000000..63cc5f9fb --- /dev/null +++ b/src/skmatter/decomposition/_kernel_pcovc.py @@ -0,0 +1,448 @@ +import numpy as np + +from sklearn import clone +from sklearn.svm import LinearSVC +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.linear_model import ( + Perceptron, + RidgeClassifier, + RidgeClassifierCV, + LogisticRegression, + LogisticRegressionCV, + SGDClassifier, +) +from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted, validate_data +from sklearn.linear_model._base import LinearClassifierMixin +from sklearn.utils.multiclass import check_classification_targets, type_of_target + +from skmatter.preprocessing import KernelNormalizer +from skmatter.utils import check_cl_fit +from skmatter.decomposition import _BaseKPCov + + +class KernelPCovC(LinearClassifierMixin, _BaseKPCov): + r"""Kernel Principal Covariates Classification (KPCovC). + + KPCovC is a modification on the PrincipalCovariates Classification + proposed in [Jorgensen2025]_. It determines a latent-space projection + :math:`\mathbf{T}` which minimizes a combined loss in supervised and unsupervised + tasks in the reproducing kernel Hilbert space (RKHS). + + This projection is determined by the eigendecomposition of a modified gram matrix + :math:`\mathbf{\tilde{K}}` + + .. math:: + \mathbf{\tilde{K}} = \alpha \mathbf{K} + + (1 - \alpha) \mathbf{Z}\mathbf{Z}^T + + where :math:`\alpha` is a mixing parameter, + :math:`\mathbf{K}` is the input kernel of shape :math:`(n_{samples}, n_{samples})` + and :math:`\mathbf{Z}` is a matrix of class confidence scores of shape + :math:`(n_{samples}, n_{classes})` + + Parameters + ---------- + mixing : float, default=0.5 + mixing parameter, as described in PCovC as :math:`{\alpha}` + + n_components : int, float or str, default=None + Number of components to keep. + if n_components is not set all components are kept:: + + n_components == n_samples + + svd_solver : {'auto', 'full', 'arpack', 'randomized'}, default='auto' + If auto : + The solver is selected by a default policy based on `X.shape` and + `n_components`: if the input data is larger than 500x500 and the + number of components to extract is lower than 80% of the smallest + dimension of the data, then the more efficient 'randomized' + method is enabled. Otherwise the exact full SVD is computed and + optionally truncated afterwards. + If full : + run exact full SVD calling the standard LAPACK solver via + `scipy.linalg.svd` and select the components by postprocessing + If arpack : + run SVD truncated to n_components calling ARPACK solver via + `scipy.sparse.linalg.svds`. It requires strictly + 0 < n_components < min(X.shape) + If randomized : + run randomized SVD by the method of Halko et al. + + classifier: `estimator object` or `precomputed`, default=None + classifier for computing :math:`{\mathbf{Z}}`. 
The classifier should be + one of the following: + + - ``sklearn.linear_model.LogisticRegression()`` + - ``sklearn.linear_model.LogisticRegressionCV()`` + - ``sklearn.svm.LinearSVC()`` + - ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()`` + - ``sklearn.linear_model.RidgeClassifier()`` + - ``sklearn.linear_model.RidgeClassifierCV()`` + - ``sklearn.linear_model.Perceptron()`` + + If a pre-fitted classifier is provided, it is used to compute :math:`{\mathbf{Z}}`. + If None, ``sklearn.linear_model.LogisticRegression()`` + is used as the classifier. + + kernel : {"linear", "poly", "rbf", "sigmoid", "precomputed"} or callable, default="linear" + Kernel. + + gamma : {'scale', 'auto'} or float, default=None + Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other + kernels. + + degree : int, default=3 + Degree for poly kernels. Ignored by other kernels. + + coef0 : float, default=1 + Independent term in poly and sigmoid kernels. + Ignored by other kernels. + + kernel_params : mapping of str to any, default=None + Parameters (keyword arguments) and values for kernel passed as + callable object. Ignored by other kernels. + + center : bool, default=False + Whether to center any computed kernels + + fit_inverse_transform : bool, default=False + Learn the inverse transform for non-precomputed kernels. + (i.e. learn to find the pre-image of a point) + + tol : float, default=1e-12 + Tolerance for singular values computed by svd_solver == 'arpack' + and for matrix inversions. + Must be of range [0.0, infinity). + + n_jobs : int, default=None + The number of parallel jobs to run. + :obj:`None` means 1 unless in a :obj:`joblib.parallel_backend` context. + ``-1`` means using all processors. + + iterated_power : int or 'auto', default='auto' + Number of iterations for the power method computed by + svd_solver == 'randomized'. + Must be of range [0, infinity). + + random_state : int, :class:`numpy.random.RandomState` instance or None, default=None + Used when the 'arpack' or 'randomized' solvers are used. Pass an int + for reproducible results across multiple function calls. + + Attributes + ---------- + classifier : estimator object + The linear classifier passed for fitting. If pre-fitted, it is assummed + to be fit on a precomputed kernel :math:`\mathbf{K}` and :math:`\mathbf{Y}`. + + z_classifier_ : estimator object + The linear classifier fit between the computed kernel :math:`\mathbf{K}` + and :math:`\mathbf{Y}`. + + classifier_ : estimator object + The linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}`. 
+ + pt__: numpy.darray of size :math:`({n_{components}, n_{components}})` + pseudo-inverse of the latent-space projection, which + can be used to contruct projectors from latent-space + + pkt_: numpy.ndarray of size :math:`({n_{samples}, n_{components}})` + the projector, or weights, from the input kernel :math:`\mathbf{K}` + to the latent-space projection :math:`\mathbf{T}` + + pkz_: numpy.ndarray of size :math:`({n_{samples}, })` or :math:`({n_{samples}, n_{classes}})` + the projector, or weights, from the input kernel :math:`\mathbf{K}` + to the class confidence scores :math:`\mathbf{Z}` + + ptz_: numpy.ndarray of size :math:`({n_{components}, })` or :math:`({n_{components}, n_{classes}})` + the projector, or weights, from the latent-space projection + :math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}` + + ptx_: numpy.ndarray of size :math:`({n_{components}, n_{features}})` + the projector, or weights, from the latent-space projection + :math:`\mathbf{T}` to the feature matrix :math:`\mathbf{X}` + + X_fit_: numpy.ndarray of shape (n_samples, n_features) + The data used to fit the model. This attribute is used to build kernels + from new data. + + Examples + -------- + >>> import numpy as np + >>> from skmatter.decomposition import KernelPCovC + >>> from sklearn.preprocessing import StandardScaler + >>> X = np.array([[-2, 3, -1, 0], [2, 0, -3, 1], [3, 0, -1, 3], [2, -2, 1, 0]]) + >>> X = StandardScaler().fit_transform(X) + >>> Y = np.array([[2], [0], [1], [2]]) + >>> kpcovc = KernelPCovC( + ... mixing=0.1, + ... n_components=2, + ... kernel="rbf", + ... gamma=1, + ... ) + >>> kpcovc.fit(X, Y) + KernelPCovC(gamma=1, kernel='rbf', mixing=0.1, n_components=2) + >>> kpcovc.transform(X) + array([[-4.45970689e-01, 8.95327566e-06], + [ 4.52745933e-01, 5.54810948e-01], + [ 4.52881359e-01, -5.54708315e-01], + [-4.45921092e-01, -7.32157649e-05]]) + >>> kpcovc.predict(X) + array([2, 0, 1, 2]) + >>> kpcovc.score(X, Y) + 1.0 + """ # NoQa: E501 + + def __init__( + self, + mixing=0.5, + n_components=None, + svd_solver="auto", + classifier=None, + kernel="linear", + gamma=None, + degree=3, + coef0=1, + kernel_params=None, + center=False, + fit_inverse_transform=False, + tol=1e-12, + n_jobs=None, + iterated_power="auto", + random_state=None, + ): + super().__init__( + mixing=mixing, + n_components=n_components, + svd_solver=svd_solver, + tol=tol, + iterated_power=iterated_power, + random_state=random_state, + center=center, + kernel=kernel, + gamma=gamma, + degree=degree, + coef0=coef0, + kernel_params=kernel_params, + n_jobs=n_jobs, + fit_inverse_transform=fit_inverse_transform, + ) + self.classifier = classifier + + def fit(self, X, Y, W=None): + r"""Fit the model with X and Y. + + A computed kernel K is derived from X, and W is taken from the + coefficients of a linear classifier fit between K and Y to compute + Z: + + .. math:: + \mathbf{Z} = \mathbf{K} \mathbf{W} + + We then call either `_fit_feature_space` or `_fit_sample_space`, + using Z as our approximation of Y. Finally, we refit a classifier on + T and Y to obtain :math:`\mathbf{P}_{TZ}`. + + Parameters + ---------- + X : numpy.ndarray, shape (n_samples, n_features) + Training data, where n_samples is the number of samples and + n_features is the number of features. + + It is suggested that :math:`\mathbf{X}` be centered by its column- + means and scaled. 
If features are related, the matrix should be + scaled to have unit variance, otherwise :math:`\mathbf{X}` should + be scaled so that each feature has a variance of 1 / n_features. + + Y : numpy.ndarray, shape (n_samples,) + Training data, where n_samples is the number of samples. + + W : numpy.ndarray, shape (n_features, n_classes) + Classification weights, optional when classifier = `precomputed`. If + not passed, it is assumed that the weights will be taken from a + linear classifier fit between K and Y. + + Returns + ------- + self: object + Returns the instance itself. + """ + X, Y = validate_data(self, X, Y, y_numeric=False) + check_classification_targets(Y) + self.classes_ = np.unique(Y) + + super().fit(X) + + K = self._get_kernel(X) + + if self.center: + self.centerer_ = KernelNormalizer() + K = self.centerer_.fit_transform(K) + + compatible_classifiers = ( + LogisticRegression, + LogisticRegressionCV, + LinearSVC, + LinearDiscriminantAnalysis, + RidgeClassifier, + RidgeClassifierCV, + SGDClassifier, + Perceptron, + ) + + if self.classifier not in ["precomputed", None] and not isinstance( + self.classifier, compatible_classifiers + ): + raise ValueError( + "Classifier must be an instance of `" + f"{'`, `'.join(c.__name__ for c in compatible_classifiers)}`" + ", or `precomputed`" + ) + + if self.classifier != "precomputed": + if self.classifier is None: + classifier = LogisticRegression() + else: + classifier = self.classifier + + # for convergence warnings + if hasattr(classifier, "max_iter") and ( + classifier.max_iter is None or classifier.max_iter < 500 + ): + classifier.max_iter = 500 + + # Check if classifier is fitted; if not, fit with precomputed K + self.z_classifier_ = check_cl_fit(classifier, K, Y) + W = self.z_classifier_.coef_.T.reshape(K.shape[1], -1) + + else: + # If precomputed, use default classifier to predict Y from T + classifier = LogisticRegression(max_iter=500) + if W is None: + W = LogisticRegression().fit(K, Y).coef_.T + W = W.reshape(K.shape[1], -1) + + Z = K @ W + + self._fit(K, Z, W) + + self.ptk_ = self.pt__ @ K + + if self.fit_inverse_transform: + self.ptx_ = self.pt__ @ X + + self.classifier_ = clone(classifier).fit(K @ self.pkt_, Y) + + self.ptz_ = self.classifier_.coef_.T + self.pkz_ = self.pkt_ @ self.ptz_ + + if len(Y.shape) == 1 and type_of_target(Y) == "binary": + self.pkz_ = self.pkz_.reshape( + K.shape[1], + ) + self.ptz_ = self.ptz_.reshape( + self.n_components_, + ) + + self.components_ = self.pkt_.T # for sklearn compatibility + return self + + def predict(self, X=None, T=None): + """Predicts the property labels using classification on T.""" + check_is_fitted(self, ["pkz_", "ptz_"]) + + if X is None and T is None: + raise ValueError("Either X or T must be supplied.") + + if X is not None: + X = validate_data(self, X, reset=False) + K = self._get_kernel(X, self.X_fit_) + if self.center: + K = self.centerer_.transform(K) + + return self.classifier_.predict(K @ self.pkt_) + else: + return self.classifier_.predict(T) + + def transform(self, X): + """Apply dimensionality reduction to X. + + ``X`` is projected on the first principal components as determined by the + modified Kernel PCovR distances. + + Parameters + ---------- + X : numpy.ndarray, shape (n_samples, n_features) + New data, where n_samples is the number of samples + and n_features is the number of features. + """ + return super().transform(X) + + def inverse_transform(self, T): + r"""Transform input data back to its original space. + + .. 
math:: + \mathbf{\hat{X}} = \mathbf{T} \mathbf{P}_{TX} + = \mathbf{K} \mathbf{P}_{KT} \mathbf{P}_{TX} + + Similar to KPCA, the original features are not always recoverable, + as the projection is computed from the kernel features, not the original + features, and the mapping between the original and kernel features + is not one-to-one. + + Parameters + ---------- + T : numpy.ndarray, shape (n_samples, n_components) + Projected data, where n_samples is the number of samples and n_components is + the number of components. + + Returns + ------- + X_original : numpy.ndarray, shape (n_samples, n_features) + """ + return super().inverse_transform(T) + + def decision_function(self, X=None, T=None): + r"""Predicts confidence scores from X or T. + + .. math:: + \mathbf{Z} = \mathbf{T} \mathbf{P}_{TZ} + = \mathbf{K} \mathbf{P}_{KT} \mathbf{P}_{TZ} + = \mathbf{K} \mathbf{P}_{KZ} + + Parameters + ---------- + X : ndarray, shape(n_samples, n_features) + Original data for which we want to get confidence scores, + where n_samples is the number of samples and n_features is the + number of features. + + T : ndarray, shape (n_samples, n_components) + Projected data for which we want to get confidence scores, + where n_samples is the number of samples and n_components is the + number of components. + + Returns + ------- + Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes) + Confidence scores. For binary classification, has shape `(n_samples,)`, + for multiclass classification, has shape `(n_samples, n_classes)` + """ + check_is_fitted(self, attributes=["pkz_", "ptz_"]) + + if X is None and T is None: + raise ValueError("Either X or T must be supplied.") + + if X is not None: + X = validate_data(self, X, reset=False) + K = self._get_kernel(X, self.X_fit_) + if self.center: + K = self.centerer_.transform(K) + + # Or self.classifier_.decision_function(K @ self.pxt_) + return K @ self.pkz_ + self.classifier_.intercept_ + + else: + T = check_array(T) + return T @ self.ptz_ + self.classifier_.intercept_ diff --git a/src/skmatter/decomposition/_kernel_pcovr.py b/src/skmatter/decomposition/_kernel_pcovr.py index 54ee8855c..ed276739a 100644 --- a/src/skmatter/decomposition/_kernel_pcovr.py +++ b/src/skmatter/decomposition/_kernel_pcovr.py @@ -1,28 +1,20 @@ -import numbers - import numpy as np -from scipy import linalg -from scipy.sparse.linalg import svds -from sklearn.decomposition._base import _BasePCA -from sklearn.decomposition._pca import _infer_dimension + from sklearn.exceptions import NotFittedError from sklearn.kernel_ridge import KernelRidge -from sklearn.linear_model._base import LinearModel -from sklearn.metrics.pairwise import pairwise_kernels -from sklearn.utils import check_random_state -from sklearn.utils._arpack import _init_arpack_v0 -from sklearn.utils.extmath import randomized_svd, stable_cumsum, svd_flip from sklearn.utils.validation import _check_n_features, check_is_fitted, validate_data -from ..preprocessing import KernelNormalizer -from ..utils import check_krr_fit, pcovr_kernel +from skmatter.utils import check_krr_fit +from skmatter.decomposition import _BaseKPCov +from skmatter.preprocessing import KernelNormalizer -class KernelPCovR(_BasePCA, LinearModel): - r"""Kernel Principal Covariates Regression, as described in [Helfrecht2020]_ - determines a latent-space projection :math:`\mathbf{T}` which minimizes a combined - loss in supervised and unsupervised tasks in the reproducing kernel Hilbert space - (RKHS). 
+class KernelPCovR(_BaseKPCov): + r"""Kernel Principal Covariates Regression (KPCovR). + + As described in [Helfrecht2020]_, KPCovR determines a latent-space projection + :math:`\mathbf{T}` which minimizes a combined loss in supervised and unsupervised + tasks in the reproducing kernel Hilbert space (RKHS). This projection is determined by the eigendecomposition of a modified gram matrix :math:`\mathbf{\tilde{K}}` @@ -76,8 +68,8 @@ class KernelPCovR(_BasePCA, LinearModel): If `precomputed`, we assume that the `y` passed to the `fit` function is the regressed form of the targets :math:`{\mathbf{\hat{Y}}}`. - kernel : "linear" | "poly" | "rbf" | "sigmoid" | "cosine" | "precomputed" - Kernel. Default="linear". + kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'} or callable, default='linear' + Kernel. gamma : float, default=None Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other @@ -95,7 +87,7 @@ class KernelPCovR(_BasePCA, LinearModel): callable object. Ignored by other kernels. center : bool, default=False - Whether to center any computed kernels + Whether to center any computed kernels fit_inverse_transform : bool, default=False Learn the inverse transform for non-precomputed kernels. @@ -123,24 +115,24 @@ class KernelPCovR(_BasePCA, LinearModel): Attributes ---------- pt__: numpy.darray of size :math:`({n_{components}, n_{components}})` - pseudo-inverse of the latent-space projection, which - can be used to contruct projectors from latent-space + pseudo-inverse of the latent-space projection, which + can be used to contruct projectors from latent-space pkt_: numpy.ndarray of size :math:`({n_{samples}, n_{components}})` - the projector, or weights, from the input kernel :math:`\mathbf{K}` - to the latent-space projection :math:`\mathbf{T}` + the projector, or weights, from the input kernel :math:`\mathbf{K}` + to the latent-space projection :math:`\mathbf{T}` pky_: numpy.ndarray of size :math:`({n_{samples}, n_{properties}})` - the projector, or weights, from the input kernel :math:`\mathbf{K}` - to the properties :math:`\mathbf{Y}` + the projector, or weights, from the input kernel :math:`\mathbf{K}` + to the properties :math:`\mathbf{Y}` pty_: numpy.ndarray of size :math:`({n_{components}, n_{properties}})` - the projector, or weights, from the latent-space projection - :math:`\mathbf{T}` to the properties :math:`\mathbf{Y}` + the projector, or weights, from the latent-space projection + :math:`\mathbf{T}` to the properties :math:`\mathbf{Y}` ptx_: numpy.ndarray of size :math:`({n_{components}, n_{features}})` - the projector, or weights, from the latent-space projection - :math:`\mathbf{T}` to the feature matrix :math:`\mathbf{X}` + the projector, or weights, from the latent-space projection + :math:`\mathbf{T}` to the feature matrix :math:`\mathbf{X}` X_fit_: numpy.ndarray of shape (n_samples, n_features) The data used to fit the model. 
This attribute is used to build kernels @@ -198,59 +190,23 @@ def __init__( iterated_power="auto", random_state=None, ): - self.mixing = mixing - self.n_components = n_components - - self.svd_solver = svd_solver - self.tol = tol - self.iterated_power = iterated_power - self.random_state = random_state - self.center = center - - self.kernel = kernel - self.gamma = gamma - self.degree = degree - self.coef0 = coef0 - self.kernel_params = kernel_params - - self.n_jobs = n_jobs - - self.fit_inverse_transform = fit_inverse_transform - - self.regressor = regressor - - def _get_kernel(self, X, Y=None): - if callable(self.kernel): - params = self.kernel_params or {} - else: - params = {"gamma": self.gamma, "degree": self.degree, "coef0": self.coef0} - return pairwise_kernels( - X, Y, metric=self.kernel, filter_params=True, n_jobs=self.n_jobs, **params + super().__init__( + mixing=mixing, + n_components=n_components, + svd_solver=svd_solver, + tol=tol, + iterated_power=iterated_power, + random_state=random_state, + center=center, + kernel=kernel, + gamma=gamma, + degree=degree, + coef0=coef0, + kernel_params=kernel_params, + n_jobs=n_jobs, + fit_inverse_transform=fit_inverse_transform, ) - - def _fit(self, K, Yhat, W): - """Fit the model with the computed kernel and approximated properties.""" - K_tilde = pcovr_kernel(mixing=self.mixing, X=K, Y=Yhat, kernel="precomputed") - - if self._fit_svd_solver == "full": - _, S, Vt = self._decompose_full(K_tilde) - elif self._fit_svd_solver in ["arpack", "randomized"]: - _, S, Vt = self._decompose_truncated(K_tilde) - else: - raise ValueError( - "Unrecognized svd_solver='{0}'".format(self._fit_svd_solver) - ) - - U = Vt.T - - P = (self.mixing * np.eye(K.shape[0])) + (1.0 - self.mixing) * (W @ Yhat.T) - - S_inv = np.array([1.0 / s if s > self.tol else 0.0 for s in S]) - - self.pkt_ = P @ U @ np.sqrt(np.diagflat(S_inv)) - - T = K @ self.pkt_ - self.pt__ = np.linalg.lstsq(T, np.eye(T.shape[0]), rcond=self.tol)[0] + self.regressor = regressor def fit(self, X, Y, W=None): r"""Fit the model with X and Y. @@ -276,7 +232,7 @@ def fit(self, X, Y, W=None): scaled so that each feature has a variance of 1 / n_features. W : numpy.ndarray, shape (n_samples, n_properties) - Regression weights, optional when regressor=`precomputed`. If not + Regression weights, optional when regressor = `precomputed`. If not passed, it is assumed that `W = np.linalg.lstsq(K, Y, self.tol)[0]` Returns @@ -284,21 +240,9 @@ def fit(self, X, Y, W=None): self: object Returns the instance itself. 
""" - if self.regressor not in ["precomputed", None] and not isinstance( - self.regressor, KernelRidge - ): - raise ValueError("Regressor must be an instance of `KernelRidge`") - X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True) - self.X_fit_ = X.copy() - if self.n_components is None: - if self.svd_solver != "arpack": - self.n_components_ = X.shape[0] - else: - self.n_components_ = X.shape[0] - 1 - else: - self.n_components_ = self.n_components + super().fit(X) K = self._get_kernel(X) @@ -306,7 +250,10 @@ def fit(self, X, Y, W=None): self.centerer_ = KernelNormalizer() K = self.centerer_.fit_transform(K) - self.n_samples_in_, self.n_features_in_ = X.shape + if self.regressor not in ["precomputed", None] and not isinstance( + self.regressor, KernelRidge + ): + raise ValueError("Regressor must be an instance of `KernelRidge`") if self.regressor != "precomputed": if self.regressor is None: @@ -349,12 +296,12 @@ def fit(self, X, Y, W=None): # Check if regressor is fitted; if not, fit with precomputed K # to avoid needing to compute the kernel a second time self.regressor_ = check_krr_fit(regressor, K, X, Y) - W = self.regressor_.dual_coef_.reshape(self.n_samples_in_, -1) # Use this instead of `self.regressor_.predict(K)` # so that we can handle the case of the pre-fitted regressor Yhat = K @ W + # When we have an unfitted regressor, # we fit it with a precomputed K # so we must subsequently "reset" it so that @@ -372,24 +319,7 @@ def fit(self, X, Y, W=None): if W is None: W = np.linalg.lstsq(K, Yhat, self.tol)[0] - # Handle svd_solver - self._fit_svd_solver = self.svd_solver - if self._fit_svd_solver == "auto": - # Small problem or self.n_components_ == 'mle', just call full PCA - if ( - max(self.n_samples_in_, self.n_features_in_) <= 500 - or self.n_components_ == "mle" - ): - self._fit_svd_solver = "full" - elif self.n_components_ >= 1 and self.n_components_ < 0.8 * max( - self.n_samples_in_, self.n_features_in_ - ): - self._fit_svd_solver = "randomized" - # This is also the case of self.n_components_ in (0,1) - else: - self._fit_svd_solver = "full" - - self._fit(K, Yhat, W) + super()._fit(K, Yhat, W) self.ptk_ = self.pt__ @ K self.pty_ = self.pt__ @ Y @@ -425,15 +355,7 @@ def transform(self, X): New data, where n_samples is the number of samples and n_features is the number of features. """ - check_is_fitted(self, ["pkt_", "X_fit_"]) - - X = validate_data(self, X, reset=False) - K = self._get_kernel(X, self.X_fit_) - - if self.center: - K = self.centerer_.transform(K) - - return K @ self.pkt_ + return super().transform(X) def inverse_transform(self, T): r"""Transform input data back to its original space. @@ -457,11 +379,13 @@ def inverse_transform(self, T): ------- X_original : numpy.ndarray, shape (n_samples, n_features) """ - return T @ self.ptx_ + return super().inverse_transform(T) def score(self, X, y): r"""Computes the (negative) loss values for KernelPCovR on the given predictor - and response variables. The loss in :math:`\mathbf{K}`, as explained in + and response variables. + + The loss in :math:`\mathbf{K}`, as explained in [Helfrecht2020]_ does not correspond to a traditional Gram loss :math:`\mathbf{K} - \mathbf{TT}^T`. 
Indicating the kernel between set A and B as :math:`\mathbf{K}_{AB}`, the projection of set A as :math:`\mathbf{T}_A`, and @@ -519,118 +443,3 @@ def score(self, X, y): Lkpca = np.trace(K_VV - 2 * K_VN @ w + w.T @ K_VV @ w) / np.trace(K_VV) return -sum([Lkpca, Lkrr]) - - def _decompose_truncated(self, mat): - if not 1 <= self.n_components_ <= self.n_samples_in_: - raise ValueError( - "n_components=%r must be between 1 and " - "n_samples=%r with " - "svd_solver='%s'" - % ( - self.n_components_, - self.n_samples_in_, - self.svd_solver, - ) - ) - elif not isinstance(self.n_components_, numbers.Integral): - raise ValueError( - "n_components=%r must be of type int " - "when greater than or equal to 1, was of type=%r" - % (self.n_components_, type(self.n_components_)) - ) - elif self.svd_solver == "arpack" and self.n_components_ == self.n_samples_in_: - raise ValueError( - "n_components=%r must be strictly less than " - "n_samples=%r with " - "svd_solver='%s'" - % ( - self.n_components_, - self.n_samples_in_, - self.svd_solver, - ) - ) - - random_state = check_random_state(self.random_state) - - if self._fit_svd_solver == "arpack": - v0 = _init_arpack_v0(min(mat.shape), random_state) - U, S, Vt = svds(mat, k=self.n_components_, tol=self.tol, v0=v0) - # svds doesn't abide by scipy.linalg.svd/randomized_svd - # conventions, so reverse its outputs. - S = S[::-1] - # flip eigenvectors' sign to enforce deterministic output - U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) - - # We have already eliminated all other solvers, so this must be "randomized" - else: - # sign flipping is done inside - U, S, Vt = randomized_svd( - mat, - n_components=self.n_components_, - n_iter=self.iterated_power, - flip_sign=True, - random_state=random_state, - ) - - U[:, S < self.tol] = 0.0 - Vt[S < self.tol] = 0.0 - S[S < self.tol] = 0.0 - - return U, S, Vt - - def _decompose_full(self, mat): - if self.n_components_ != "mle": - if not (0 <= self.n_components_ <= self.n_samples_in_): - raise ValueError( - "n_components=%r must be between 1 and " - "n_samples=%r with " - "svd_solver='%s'" - % ( - self.n_components_, - self.n_samples_in_, - self.svd_solver, - ) - ) - elif self.n_components_ >= 1: - if not isinstance(self.n_components_, numbers.Integral): - raise ValueError( - "n_components=%r must be of type int " - "when greater than or equal to 1, " - "was of type=%r" - % (self.n_components_, type(self.n_components_)) - ) - - U, S, Vt = linalg.svd(mat, full_matrices=False) - U[:, S < self.tol] = 0.0 - Vt[S < self.tol] = 0.0 - S[S < self.tol] = 0.0 - - # flip eigenvectors' sign to enforce deterministic output - U, Vt = svd_flip(U, Vt) - - # Get variance explained by singular values - explained_variance_ = (S**2) / (self.n_samples_in_ - 1) - total_var = explained_variance_.sum() - explained_variance_ratio_ = explained_variance_ / total_var - - # Postprocess the number of components required - if self.n_components_ == "mle": - self.n_components_ = _infer_dimension( - explained_variance_, self.n_samples_in_ - ) - elif 0 < self.n_components_ < 1.0: - # number of components for which the cumulated explained - # variance percentage is superior to the desired threshold - # side='right' ensures that number of features selected - # their variance is always greater than self.n_components_ float - # passed. 
More discussion in issue: #15669 - ratio_cumsum = stable_cumsum(explained_variance_ratio_) - self.n_components_ = ( - np.searchsorted(ratio_cumsum, self.n_components_, side="right") + 1 - ) - - return ( - U[:, : self.n_components_], - S[: self.n_components_], - Vt[: self.n_components_], - ) diff --git a/src/skmatter/decomposition/_kpcov.py b/src/skmatter/decomposition/_kpcov.py new file mode 100644 index 000000000..62a5614da --- /dev/null +++ b/src/skmatter/decomposition/_kpcov.py @@ -0,0 +1,268 @@ +from abc import ABCMeta, abstractmethod + +import numbers +import numpy as np + +from scipy import linalg +from scipy.sparse.linalg import svds + +from sklearn.exceptions import NotFittedError +from sklearn.decomposition._base import _BasePCA +from sklearn.linear_model._base import LinearModel +from sklearn.decomposition._pca import _infer_dimension +from sklearn.utils import check_random_state +from sklearn.utils._arpack import _init_arpack_v0 +from sklearn.utils.extmath import randomized_svd, stable_cumsum, svd_flip +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.validation import validate_data +from sklearn.metrics.pairwise import pairwise_kernels + +from skmatter.utils import pcovr_kernel + + +class _BaseKPCov(_BasePCA, LinearModel, metaclass=ABCMeta): + """Base class for KernelPCovR and KernelPCovC methods. + + Warning: This class should not be used directly. + Use derived classes instead. + """ + + @abstractmethod + def __init__( + self, + mixing=0.5, + n_components=None, + svd_solver="auto", + kernel="linear", + gamma=None, + degree=3, + coef0=1, + kernel_params=None, + center=False, + fit_inverse_transform=False, + tol=1e-12, + n_jobs=None, + iterated_power="auto", + random_state=None, + ): + self.mixing = mixing + self.n_components = n_components + self.svd_solver = svd_solver + self.kernel = kernel + self.gamma = gamma + self.degree = degree + self.coef0 = coef0 + self.kernel_params = kernel_params + self.center = center + self.fit_inverse_transform = fit_inverse_transform + self.tol = tol + self.n_jobs = n_jobs + self.iterated_power = iterated_power + self.random_state = random_state + + def _get_kernel(self, X, Y=None): + if callable(self.kernel): + params = self.kernel_params or {} + else: + params = { + "gamma": self.gamma, + "degree": self.degree, + "coef0": self.coef0, + } + + return pairwise_kernels( + X, Y, metric=self.kernel, filter_params=True, n_jobs=self.n_jobs, **params + ) + + def fit(self, X): + """Contains the common functionality for the KPCovR and KPCovC fit methods, + but leaves the rest of the functionality to the subclass. 
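Specifically, this step stores ``X_fit_``, resolves ``n_components_`` from
``n_components``, records ``n_samples_in_`` and ``n_features_in_``, and selects the
effective SVD solver in ``fit_svd_solver_``; building and centering the kernel, and
the eigendecomposition itself (``_fit``), are left to the subclasses.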
+ """ + self.X_fit_ = X.copy() + + if self.n_components is None: + if self.svd_solver != "arpack": + self.n_components_ = X.shape[0] + else: + self.n_components_ = X.shape[0] - 1 + else: + self.n_components_ = self.n_components + + self.n_samples_in_, self.n_features_in_ = X.shape + + # Handle svd_solver + self.fit_svd_solver_ = self.svd_solver + if self.fit_svd_solver_ == "auto": + # Small problem or self.n_components_ == 'mle', just call full PCA + if ( + max(self.n_samples_in_, self.n_features_in_) <= 500 + or self.n_components_ == "mle" + ): + self.fit_svd_solver_ = "full" + elif self.n_components_ >= 1 and self.n_components_ < 0.8 * max( + self.n_samples_in_, self.n_features_in_ + ): + self.fit_svd_solver_ = "randomized" + # This is also the case of self.n_components_ in (0,1) + else: + self.fit_svd_solver_ = "full" + + def _fit(self, K, Yhat, W): + """Fit the model with the computed kernel and approximated properties.""" + K_tilde = pcovr_kernel(mixing=self.mixing, X=K, Y=Yhat, kernel="precomputed") + + if self.fit_svd_solver_ == "full": + _, S, Vt = self._decompose_full(K_tilde) + elif self.fit_svd_solver_ in ["arpack", "randomized"]: + _, S, Vt = self._decompose_truncated(K_tilde) + else: + raise ValueError( + "Unrecognized svd_solver='{0}'".format(self.fit_svd_solver_) + ) + + U = Vt.T + + P = (self.mixing * np.eye(K.shape[0])) + (1.0 - self.mixing) * (W @ Yhat.T) + + S_inv = np.array([1.0 / s if s > self.tol else 0.0 for s in S]) + + self.pkt_ = P @ U @ np.sqrt(np.diagflat(S_inv)) + + T = K @ self.pkt_ + self.pt__ = np.linalg.lstsq(T, np.eye(T.shape[0]), rcond=self.tol)[0] + + def transform(self, X=None): + check_is_fitted(self, ["pkt_", "X_fit_"]) + + X = validate_data(self, X, reset=False) + K = self._get_kernel(X, self.X_fit_) + + if self.center: + K = self.centerer_.transform(K) + + return K @ self.pkt_ + + def inverse_transform(self, T): + if not self.fit_inverse_transform: + raise NotFittedError( + "The fit_inverse_transform parameter was not" + " set to True when instantiating and hence " + "the inverse transform is not available." + ) + + return T @ self.ptx_ + + def _decompose_truncated(self, mat): + if not 1 <= self.n_components_ <= self.n_samples_in_: + raise ValueError( + "n_components=%r must be between 1 and " + "n_samples=%r with " + "svd_solver='%s'" + % ( + self.n_components_, + self.n_samples_in_, + self.svd_solver, + ) + ) + elif not isinstance(self.n_components_, numbers.Integral): + raise ValueError( + "n_components=%r must be of type int " + "when greater than or equal to 1, was of type=%r" + % (self.n_components_, type(self.n_components_)) + ) + elif self.svd_solver == "arpack" and self.n_components_ == self.n_samples_in_: + raise ValueError( + "n_components=%r must be strictly less than " + "n_samples=%r with " + "svd_solver='%s'" + % ( + self.n_components_, + self.n_samples_in_, + self.svd_solver, + ) + ) + + random_state = check_random_state(self.random_state) + + if self.fit_svd_solver_ == "arpack": + v0 = _init_arpack_v0(min(mat.shape), random_state) + U, S, Vt = svds(mat, k=self.n_components_, tol=self.tol, v0=v0) + # svds doesn't abide by scipy.linalg.svd/randomized_svd + # conventions, so reverse its outputs. 
+ S = S[::-1] + # flip eigenvectors' sign to enforce deterministic output + U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) + + # We have already eliminated all other solvers, so this must be "randomized" + else: + # sign flipping is done inside + U, S, Vt = randomized_svd( + mat, + n_components=self.n_components_, + n_iter=self.iterated_power, + flip_sign=True, + random_state=random_state, + ) + + U[:, S < self.tol] = 0.0 + Vt[S < self.tol] = 0.0 + S[S < self.tol] = 0.0 + + return U, S, Vt + + def _decompose_full(self, mat): + if self.n_components_ != "mle": + if not (0 <= self.n_components_ <= self.n_samples_in_): + raise ValueError( + "n_components=%r must be between 1 and " + "n_samples=%r with " + "svd_solver='%s'" + % ( + self.n_components_, + self.n_samples_in_, + self.svd_solver, + ) + ) + elif self.n_components_ >= 1: + if not isinstance(self.n_components_, numbers.Integral): + raise ValueError( + "n_components=%r must be of type int " + "when greater than or equal to 1, " + "was of type=%r" + % (self.n_components_, type(self.n_components_)) + ) + + U, S, Vt = linalg.svd(mat, full_matrices=False) + U[:, S < self.tol] = 0.0 + Vt[S < self.tol] = 0.0 + S[S < self.tol] = 0.0 + + # flip eigenvectors' sign to enforce deterministic output + U, Vt = svd_flip(U, Vt) + + # Get variance explained by singular values + explained_variance_ = (S**2) / (self.n_samples_in_ - 1) + total_var = explained_variance_.sum() + explained_variance_ratio_ = explained_variance_ / total_var + + # Postprocess the number of components required + if self.n_components_ == "mle": + self.n_components_ = _infer_dimension( + explained_variance_, self.n_samples_in_ + ) + elif 0 < self.n_components_ < 1.0: + # number of components for which the cumulated explained + # variance percentage is superior to the desired threshold + # side='right' ensures that number of features selected + # their variance is always greater than self.n_components_ float + # passed. More discussion in issue: #15669 + ratio_cumsum = stable_cumsum(explained_variance_ratio_) + self.n_components_ = ( + np.searchsorted(ratio_cumsum, self.n_components_, side="right") + 1 + ) + + return ( + U[:, : self.n_components_], + S[: self.n_components_], + Vt[: self.n_components_], + ) diff --git a/src/skmatter/decomposition/_pcov.py b/src/skmatter/decomposition/_pcov.py index 878b22ff9..04dc93b4e 100644 --- a/src/skmatter/decomposition/_pcov.py +++ b/src/skmatter/decomposition/_pcov.py @@ -1,11 +1,14 @@ +from abc import ABCMeta, abstractmethod import numbers import warnings import numpy as np from numpy.linalg import LinAlgError + from scipy import linalg from scipy.linalg import sqrtm as MatrixSqrt from scipy.sparse.linalg import svds + from sklearn.decomposition._base import _BasePCA from sklearn.decomposition._pca import _infer_dimension from sklearn.linear_model._base import LinearModel @@ -17,7 +20,14 @@ from skmatter.utils import pcovr_covariance, pcovr_kernel -class _BasePCov(_BasePCA, LinearModel): +class _BasePCov(_BasePCA, LinearModel, metaclass=ABCMeta): + """Base class for PCovR and PCovC methods. + + Warning: This class should not be used directly. + Use derived classes instead. 
+ """ + + @abstractmethod def __init__( self, mixing=0.5, @@ -146,6 +156,7 @@ def _fit_sample_space(self, X, Y, Yhat, W, compute_pty_=True): ) P = (self.mixing * X.T) + (1.0 - self.mixing) * W @ Yhat.T + S_sqrt_inv = np.diagflat([1.0 / np.sqrt(s) if s > self.tol else 0.0 for s in S]) T = Vt.T @ S_sqrt_inv diff --git a/src/skmatter/decomposition/_pcovc.py b/src/skmatter/decomposition/_pcovc.py index fe4434b4f..ec8ce3202 100644 --- a/src/skmatter/decomposition/_pcovc.py +++ b/src/skmatter/decomposition/_pcovc.py @@ -14,15 +14,16 @@ from sklearn.utils import check_array from sklearn.utils.multiclass import check_classification_targets, type_of_target from sklearn.utils.validation import check_is_fitted, validate_data - from skmatter.decomposition import _BasePCov from skmatter.utils import check_cl_fit class PCovC(LinearClassifierMixin, _BasePCov): - r"""Principal Covariates Classification, as described in [Jorgensen2025]_, - determines a latent-space projection :math:`\mathbf{T}` - which minimizes a combined loss in supervised and unsupervised tasks. + r"""Principal Covariates Classification (PCovC). + + As described in [Jorgensen2025]_, PCovC determines a latent-space projection + :math:`\mathbf{T}` which minimizes a combined loss in supervised and + unsupervised tasks. This projection is determined by the eigendecomposition of a modified gram matrix :math:`\mathbf{\tilde{K}}` @@ -44,9 +45,9 @@ class PCovC(LinearClassifierMixin, _BasePCov): \mathbf{Z}\mathbf{Z}^T \mathbf{X} \left(\mathbf{X}^T \mathbf{X}\right)^{-\frac{1}{2}}\right) - For all PCovC methods, it is strongly suggested that :math:`\mathbf{X}` and - :math:`\mathbf{Y}` are centered and scaled to unit variance, otherwise the - results will change drastically near :math:`\alpha \to 0` and :math:`\alpha \to 1`. + For all PCovC methods, it is strongly suggested that :math:`\mathbf{X}` is centered + and scaled to unit variance, otherwise the results will change drastically near + :math:`\alpha \to 0` and :math:`\alpha \to 1`. This can be done with the companion preprocessing classes, where >>> from skmatter.preprocessing import StandardFlexibleScaler as SFS @@ -100,12 +101,19 @@ class PCovC(LinearClassifierMixin, _BasePCov): default=`sample` when :math:`{n_{samples} < n_{features}}` and `feature` when :math:`{n_{features} < n_{samples}}` - classifier: `estimator object` or `precomputed`, default=None - classifier for computing :math:`{\mathbf{Z}}`. The classifier should be one of - `sklearn.linear_model.LogisticRegression`, `sklearn.linear_model.LogisticRegressionCV`, - `sklearn.svm.LinearSVC`, `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`, - `sklearn.linear_model.RidgeClassifier`, `sklearn.linear_model.RidgeClassifierCV`, - `sklearn.linear_model.SGDClassifier`, or `Perceptron`. If a pre-fitted classifier + classifier: `estimator object` or `precomputed`, default=None + classifier for computing :math:`{\mathbf{Z}}`. The classifier should be + one of the following: + + - ``sklearn.linear_model.LogisticRegression()`` + - ``sklearn.linear_model.LogisticRegressionCV()`` + - ``sklearn.svm.LinearSVC()`` + - ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()`` + - ``sklearn.linear_model.RidgeClassifier()`` + - ``sklearn.linear_model.RidgeClassifierCV()`` + - ``sklearn.linear_model.Perceptron()`` + + If a pre-fitted classifier is provided, it is used to compute :math:`{\mathbf{Z}}`. 
Note that any pre-fitting of the classifier will be lost if `PCovC` is within a composite estimator that enforces cloning, e.g., @@ -219,9 +227,10 @@ def __init__( self.classifier = classifier def fit(self, X, Y, W=None): - r"""Fit the model with X and Y. Note that W is taken from the - coefficients of a linear classifier fit between X and Y to compute - Z: + r"""Fit the model with X and Y. + + Note that W is taken from the coefficients of a linear classifier fit + between X and Y to compute Z: .. math:: \mathbf{Z} = \mathbf{X} \mathbf{W} @@ -244,8 +253,8 @@ def fit(self, X, Y, W=None): Y : numpy.ndarray, shape (n_samples,) Training data, where n_samples is the number of samples. - W : numpy.ndarray, shape (n_features, n_properties) - Classification weights, optional when classifier= `precomputed`. If + W : numpy.ndarray, shape (n_features, n_classes) + Classification weights, optional when classifier = `precomputed`. If not passed, it is assumed that the weights will be taken from a linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}` """ @@ -394,6 +403,7 @@ def decision_function(self, X=None, T=None): Original data for which we want to get confidence scores, where n_samples is the number of samples and n_features is the number of features. + T : ndarray, shape (n_samples, n_components) Projected data for which we want to get confidence scores, where n_samples is the number of samples and n_components is the diff --git a/src/skmatter/decomposition/_pcovr.py b/src/skmatter/decomposition/_pcovr.py index 417a82c12..9a038c6ea 100644 --- a/src/skmatter/decomposition/_pcovr.py +++ b/src/skmatter/decomposition/_pcovr.py @@ -9,9 +9,11 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov): - r"""Principal Covariates Regression, as described in [deJong1992]_, - determines a latent-space projection :math:`\mathbf{T}` which - minimizes a combined loss in supervised and unsupervised tasks. + r"""Principal Covariates Regression (PCovR). + + As described in [deJong1992]_, PCovR determines a latent-space projection + :math:`\mathbf{T}` which minimizes a combined loss in supervised and + unsupervised tasks. This projection is determined by the eigendecomposition of a modified gram matrix :math:`\mathbf{\tilde{K}}` @@ -216,7 +218,7 @@ def fit(self, X, Y, W=None): Training data, where n_samples is the number of samples and n_properties is the number of properties - It is suggested that :math:`\mathbf{X}` be centered by its column- means and + It is suggested that :math:`\mathbf{Y}` be centered by its column- means and scaled. If features are related, the matrix should be scaled to have unit variance, otherwise :math:`\mathbf{Y}` should be scaled so that each feature has a variance of 1 / n_features. 
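Both ``PCovC`` and ``KernelPCovC`` document a ``classifier="precomputed"`` path in
which the caller supplies the classification weights ``W`` directly instead of an
estimator. A minimal, self-contained sketch of that path for ``KernelPCovC`` (the
data, kernel settings, and variable names below are illustrative only):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.pairwise import pairwise_kernels

from skmatter.decomposition import KernelPCovC

rng = np.random.RandomState(0)
X = rng.normal(size=(40, 5))
y = (X[:, 0] > 0).astype(int)

# weights of a linear classifier fit between the kernel K and y, passed in as W
K = pairwise_kernels(X, metric="rbf", gamma=0.5)
W = LogisticRegression(max_iter=500).fit(K, y).coef_.T  # (n_samples, 1) for binary y

kpcovc = KernelPCovC(
    mixing=0.5, n_components=2, kernel="rbf", gamma=0.5, classifier="precomputed"
)
kpcovc.fit(X, y, W=W)
print(kpcovc.transform(X).shape)  # (40, 2)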
diff --git a/src/skmatter/utils/__init__.py b/src/skmatter/utils/__init__.py index 70e32616c..4f0c32415 100644 --- a/src/skmatter/utils/__init__.py +++ b/src/skmatter/utils/__init__.py @@ -9,8 +9,6 @@ Y_sample_orthogonalizer, ) -from ._pcovc_utils import check_cl_fit - from ._pcovr_utils import ( check_krr_fit, check_lr_fit, @@ -18,6 +16,8 @@ pcovr_kernel, ) +from ._pcovc_utils import check_cl_fit + from ._progress_bar import ( get_progress_bar, no_progress_bar, @@ -35,6 +35,7 @@ "pcovr_kernel", "check_krr_fit", "check_lr_fit", + "check_cl_fit", "X_orthogonalizer", "Y_sample_orthogonalizer", "Y_feature_orthogonalizer", diff --git a/src/skmatter/utils/_pcovr_utils.py b/src/skmatter/utils/_pcovr_utils.py index 29463b633..0ddcba147 100644 --- a/src/skmatter/utils/_pcovr_utils.py +++ b/src/skmatter/utils/_pcovr_utils.py @@ -162,7 +162,7 @@ def pcovr_covariance( \mathbf{\hat{Y}}\mathbf{\hat{Y}}^T \mathbf{X} \left(\mathbf{X}^T \mathbf{X}\right)^{-\frac{1}{2}}\right) - where :math:`\mathbf{\hat{Y}}`` are the properties obtained by linear regression. + where :math:`\mathbf{\hat{Y}}` are the properties obtained by linear regression. Parameters ---------- diff --git a/tests/test_kernel_pcovc.py b/tests/test_kernel_pcovc.py new file mode 100644 index 000000000..10ef589af --- /dev/null +++ b/tests/test_kernel_pcovc.py @@ -0,0 +1,496 @@ +import unittest + +import numpy as np +from sklearn import exceptions +from sklearn.calibration import LinearSVC +from sklearn.datasets import load_breast_cancer as get_dataset +from sklearn.naive_bayes import GaussianNB +from sklearn.utils.validation import check_X_y +from sklearn.preprocessing import StandardScaler +from sklearn.linear_model import LogisticRegression, RidgeClassifier +from sklearn.metrics.pairwise import pairwise_kernels + +from skmatter.decomposition import KernelPCovC + + +class KernelPCovCBaseTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.random_state = np.random.RandomState(0) + + self.error_tol = 1e-6 + + self.X, self.Y = get_dataset(return_X_y=True) + + # for the sake of expedience, only use a subset of the dataset + idx = self.random_state.choice(len(self.X), 100) + self.X = self.X[idx] + self.Y = self.Y[idx] + + scaler = StandardScaler() + self.X = scaler.fit_transform(self.X) + + self.model = ( + lambda mixing=0.5, + classifier=LogisticRegression(), + n_components=4, + **kwargs: KernelPCovC( + mixing=mixing, + classifier=classifier, + n_components=n_components, + svd_solver=kwargs.pop("svd_solver", "full"), + **kwargs, + ) + ) + + def setUp(self): + pass + + +class KernelPCovCErrorTest(KernelPCovCBaseTest): + def test_cl_with_x_errors(self): + """ + Check that KernelPCovC returns a non-null property prediction + and that the prediction error increases with `mixing` + """ + prev_error = -1.0 + + for mixing in np.linspace(0, 1, 6): + kpcovc = KernelPCovC(mixing=mixing, n_components=4, tol=1e-12) + kpcovc.fit(self.X, self.Y) + + error = ( + np.linalg.norm(self.Y - kpcovc.predict(self.X)) ** 2.0 + / np.linalg.norm(self.Y) ** 2.0 + ) + + with self.subTest(error=error): + self.assertFalse(np.isnan(error)) + with self.subTest(error=error, alpha=round(mixing, 4)): + self.assertGreaterEqual(error, prev_error - self.error_tol) + + prev_error = error + + def test_cl_with_t_errors(self): + """Check that KernelPCovC returns a non-null property prediction from + the latent space projection and that the prediction error increases with + `mixing`. 
+ """ + prev_error = -1.0 + + for mixing in np.linspace(0, 1, 6): + kpcovc = self.model(mixing=mixing, n_components=2, tol=1e-12) + kpcovc.fit(self.X, self.Y) + + T = kpcovc.transform(self.X) + + error = ( + np.linalg.norm(self.Y - kpcovc.predict(T=T)) ** 2.0 + / np.linalg.norm(self.Y) ** 2.0 + ) + + with self.subTest(error=error): + self.assertFalse(np.isnan(error)) + with self.subTest(error=error, alpha=round(mixing, 4)): + self.assertGreaterEqual(error, prev_error - self.error_tol) + + prev_error = error + + def test_reconstruction_errors(self): + """Check that KernelPCovC returns a non-null reconstructed X and that the + reconstruction error decreases with `mixing`. + """ + prev_error = 1.0 + + for mixing in np.linspace(0, 1, 11): + kpcovc = self.model( + mixing=mixing, n_components=2, tol=1e-12, fit_inverse_transform=True + ) + kpcovc.fit(self.X, self.Y) + + Xr = kpcovc.inverse_transform(kpcovc.transform(self.X)) + error = np.linalg.norm(self.X - Xr) ** 2.0 / np.linalg.norm(self.X) ** 2.0 + + with self.subTest(error=error): + self.assertFalse(np.isnan(error)) + with self.subTest(error=error, alpha=round(mixing, 4)): + self.assertLessEqual(error, prev_error + self.error_tol) + + prev_error = error + + +class KernelPCovCInfrastructureTest(KernelPCovCBaseTest): + def test_nonfitted_failure(self): + """ + Check that KernelPCovC will raise a `NonFittedError` if + `transform` is called before the model is fitted + """ + kpcovc = KernelPCovC(mixing=0.5, n_components=4, tol=1e-12) + with self.assertRaises(exceptions.NotFittedError): + _ = kpcovc.transform(self.X) + + def test_no_arg_predict(self): + """ + Check that KernelPCovC will raise a `ValueError` if + `predict` is called without arguments + """ + kpcovc = KernelPCovC(mixing=0.5, n_components=4, tol=1e-12) + kpcovc.fit(self.X, self.Y) + with self.assertRaises(ValueError): + _ = kpcovc.predict() + + def test_T_shape(self): + """ + Check that KernelPCovC returns a latent space projection + consistent with the shape of the input matrix + """ + n_components = 5 + kpcovc = KernelPCovC(mixing=0.5, n_components=n_components, tol=1e-12) + kpcovc.fit(self.X, self.Y) + T = kpcovc.transform(self.X) + self.assertTrue(check_X_y(self.X, T, multi_output=True) == (self.X, T)) + self.assertTrue(T.shape[-1] == n_components) + + def test_Z_shape(self): + """Check that KPCovC returns an evidence matrix consistent with the number + of samples and the number of classes. + """ + n_components = 5 + kpcovc = self.model(n_components=n_components, tol=1e-12) + kpcovc.fit(self.X, self.Y) + + # Shape (n_samples, ) for binary classifcation + Z = kpcovc.decision_function(self.X) + + self.assertTrue(Z.ndim == 1) + self.assertTrue(Z.shape[0] == self.X.shape[0]) + + # Modify Y so that it now contains three classes + Y_multiclass = self.Y.copy() + Y_multiclass[0] = 2 + kpcovc.fit(self.X, Y_multiclass) + n_classes = len(np.unique(Y_multiclass)) + + # Shape (n_samples, n_classes) for multiclass classification + Z = kpcovc.decision_function(self.X) + + self.assertTrue(Z.ndim == 2) + self.assertTrue((Z.shape[0], Z.shape[1]) == (self.X.shape[0], n_classes)) + + def test_decision_function(self): + """Check that KPCovC's decision_function works when only T is + provided and throws an error when appropriate. 
+ """ + kpcovc = self.model(center=True) + kpcovc.fit(self.X, self.Y) + + with self.assertRaises(ValueError) as cm: + _ = kpcovc.decision_function() + self.assertEqual( + str(cm.exception), + "Either X or T must be supplied.", + ) + + _ = kpcovc.decision_function(self.X) + T = kpcovc.transform(self.X) + _ = kpcovc.decision_function(T=T) + + def test_no_centerer(self): + """Tests that when center=False, no centerer exists.""" + kpcovc = self.model(center=False) + kpcovc.fit(self.X, self.Y) + + with self.assertRaises(AttributeError): + kpcovc.centerer_ + + def test_centerer(self): + """Tests that all functionalities that rely on the centerer work properly.""" + kpcovc = self.model(center=True) + kpcovc.fit(self.X, self.Y) + + self.assertTrue(hasattr(kpcovc, "centerer_")) + _ = kpcovc.predict(self.X) + _ = kpcovc.transform(self.X) + _ = kpcovc.score(self.X, self.Y) + + def test_prefit_classifier(self): + # in KPCovC, our classifiers don't compute the kernel for us, hence we only + # allow prefit classifiers on K and y + kernel_params = {"kernel": "rbf", "gamma": 0.1, "degree": 3, "coef0": 0} + K = pairwise_kernels(self.X, metric="rbf", filter_params=True, **kernel_params) + + classifier = LinearSVC() + classifier.fit(K, self.Y) + + kpcovc = KernelPCovC(mixing=0.5, classifier=classifier, **kernel_params) + kpcovc.fit(self.X, self.Y) + + Z_classifier = classifier.decision_function(K).reshape(K.shape[0], -1) + W_classifier = classifier.coef_.T.reshape(K.shape[1], -1) + + Z_kpcovc = kpcovc.z_classifier_.decision_function(K).reshape(K.shape[0], -1) + W_kpcovc = kpcovc.z_classifier_.coef_.T.reshape(K.shape[1], -1) + + self.assertTrue(np.allclose(Z_classifier, Z_kpcovc)) + self.assertTrue(np.allclose(W_classifier, W_kpcovc)) + + def test_classifier_modifications(self): + classifier = RidgeClassifier() + kpcovc = self.model(mixing=0.5, classifier=classifier, kernel="rbf", gamma=0.1) + + # KPCovC classifier matches the original + self.assertTrue(classifier.get_params() == kpcovc.classifier.get_params()) + + # KPCovC classifier updates its parameters + # to match the original classifier + classifier.set_params(random_state=3) + self.assertTrue(classifier.get_params() == kpcovc.classifier.get_params()) + + # Fitting classifier outside KPCovC fits the KPCovC classifier + classifier.fit(self.X, self.Y) + self.assertTrue(hasattr(kpcovc.classifier, "coef_")) + + def test_incompatible_classifier(self): + classifier = GaussianNB() + classifier.fit(self.X, self.Y) + kpcovc = self.model(mixing=0.5, classifier=classifier) + + with self.assertRaises(ValueError) as cm: + kpcovc.fit(self.X, self.Y) + self.assertEqual( + str(cm.exception), + "Classifier must be an instance of " + "`LogisticRegression`, `LogisticRegressionCV`, `LinearSVC`, " + "`LinearDiscriminantAnalysis`, `RidgeClassifier`, " + "`RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, " + "or `precomputed`", + ) + + def test_none_classifier(self): + kpcovc = KernelPCovC(mixing=0.5, classifier=None) + kpcovc.fit(self.X, self.Y) + self.assertTrue(kpcovc.classifier is None) + self.assertTrue(kpcovc.classifier_ is not None) + + def test_incompatible_coef_shape(self): + kernel_params = {"kernel": "rbf", "gamma": 0.1, "degree": 3, "coef0": 0} + + K = pairwise_kernels(self.X, metric="rbf", filter_params=True, **kernel_params) + + # Modify Y to be multiclass + Y_multiclass = self.Y.copy() + Y_multiclass[0] = 2 + + classifier1 = LinearSVC() + classifier1.fit(K, Y_multiclass) + kpcovc1 = self.model(mixing=0.5, classifier=classifier1, **kernel_params) + + # Binary 
classification shape mismatch + with self.assertRaises(ValueError) as cm: + kpcovc1.fit(self.X, self.Y) + self.assertEqual( + str(cm.exception), + "For binary classification, expected classifier coefficients " + "to have shape (1, %d) but got shape %r" + % (K.shape[1], classifier1.coef_.shape), + ) + + classifier2 = LinearSVC() + classifier2.fit(K, self.Y) + kpcovc2 = self.model(mixing=0.5, classifier=classifier2) + + # Multiclass classification shape mismatch + with self.assertRaises(ValueError) as cm: + kpcovc2.fit(self.X, Y_multiclass) + self.assertEqual( + str(cm.exception), + "For multiclass classification, expected classifier coefficients " + "to have shape (%d, %d) but got shape %r" + % (len(np.unique(Y_multiclass)), K.shape[1], classifier2.coef_.shape), + ) + + def test_precomputed_classification(self): + kernel_params = {"kernel": "rbf", "gamma": 0.1, "degree": 3, "coef0": 0} + K = pairwise_kernels(self.X, metric="rbf", filter_params=True, **kernel_params) + + classifier = LogisticRegression() + classifier.fit(K, self.Y) + + W = classifier.coef_.T.reshape(K.shape[1], -1) + kpcovc1 = self.model(mixing=0.5, classifier="precomputed", **kernel_params) + kpcovc1.fit(self.X, self.Y, W) + t1 = kpcovc1.transform(self.X) + + kpcovc2 = self.model(mixing=0.5, classifier=classifier, **kernel_params) + kpcovc2.fit(self.X, self.Y) + t2 = kpcovc2.transform(self.X) + + self.assertTrue(np.linalg.norm(t1 - t2) < self.error_tol) + + # Now check for match when W is not passed: + kpcovc3 = self.model(mixing=0.5, classifier="precomputed", **kernel_params) + kpcovc3.fit(self.X, self.Y) + t3 = kpcovc3.transform(self.X) + + self.assertTrue(np.linalg.norm(t3 - t2) < self.error_tol) + self.assertTrue(np.linalg.norm(t3 - t1) < self.error_tol) + + +class KernelTests(KernelPCovCBaseTest): + def test_kernel_types(self): + """Check that KernelPCovC can handle all kernels passable to sklearn + kernel classes, including callable kernels + """ + + def _linear_kernel(X, Y): + return X @ Y.T + + kernel_params = { + "poly": {"degree": 2}, + "rbf": {"gamma": 3.0}, + "sigmoid": {"gamma": 3.0, "coef0": 0.5}, + } + + for kernel in ["linear", "poly", "rbf", "sigmoid", "cosine", _linear_kernel]: + with self.subTest(kernel=kernel): + kpcovc = KernelPCovC( + mixing=0.5, + n_components=2, + classifier=LogisticRegression(), + kernel=kernel, + **kernel_params.get(kernel, {}), + ) + kpcovc.fit(self.X, self.Y) + + +class KernelPCovCTestSVDSolvers(KernelPCovCBaseTest): + def test_svd_solvers(self): + """ + Check that KPCovC works with all svd_solver modes and assigns + the right n_components + """ + for solver in ["arpack", "full", "randomized", "auto"]: + with self.subTest(solver=solver): + kpcovc = self.model(tol=1e-12, n_components=None, svd_solver=solver) + kpcovc.fit(self.X, self.Y) + + if solver == "arpack": + self.assertTrue(kpcovc.n_components_ == self.X.shape[0] - 1) + else: + self.assertTrue(kpcovc.n_components_ == self.X.shape[0]) + + n_component_solvers = { + "mle": "full", + int(0.75 * max(self.X.shape)): "randomized", + 0.1: "full", + } + for n_components, solver in n_component_solvers.items(): + with self.subTest(solver=solver, n_components=n_components): + kpcovc = self.model( + tol=1e-12, n_components=n_components, svd_solver="auto" + ) + if solver == "randomized": + n_copies = (501 // max(self.X.shape)) + 1 + X = np.hstack(np.repeat(self.X.copy(), n_copies)).reshape( + self.X.shape[0] * n_copies, -1 + ) + Y = np.hstack(np.repeat(self.Y.copy(), n_copies)).reshape( + self.X.shape[0] * n_copies, -1 + ) + 
kpcovc.fit(X, Y) + else: + kpcovc.fit(self.X, self.Y) + + self.assertTrue(kpcovc.fit_svd_solver_ == solver) + + def test_bad_solver(self): + """ + Check that KPCovC will not work with a solver that isn't in + ['arpack', 'full', 'randomized', 'auto'] + """ + with self.assertRaises(ValueError) as cm: + kpcovc = self.model(svd_solver="bad") + kpcovc.fit(self.X, self.Y) + + self.assertEqual(str(cm.exception), "Unrecognized svd_solver='bad'") + + def test_good_n_components(self): + """Check that KPCovC will work with any allowed values of n_components.""" + # this one should pass + kpcovc = self.model(n_components=0.5, svd_solver="full") + kpcovc.fit(self.X, self.Y) + + for svd_solver in ["auto", "full"]: + # this one should pass + kpcovc = self.model(n_components=2, svd_solver=svd_solver) + kpcovc.fit(self.X, self.Y) + + # this one should pass + kpcovc = self.model(n_components="mle", svd_solver=svd_solver) + kpcovc.fit(self.X, self.Y) + + def test_bad_n_components(self): + """Check that KPCovC will not work with any prohibited values of n_components""" + with self.subTest(type="negative_ncomponents"): + with self.assertRaises(ValueError) as cm: + kpcovc = self.model(n_components=-1, svd_solver="auto") + kpcovc.fit(self.X, self.Y) + + self.assertEqual( + str(cm.exception), + "n_components=%r must be between 1 and " + "n_samples=%r with " + "svd_solver='%s'" + % ( + kpcovc.n_components, + self.X.shape[0], + kpcovc.svd_solver, + ), + ) + with self.subTest(type="0_ncomponents"): + with self.assertRaises(ValueError) as cm: + kpcovc = self.model(n_components=0, svd_solver="randomized") + kpcovc.fit(self.X, self.Y) + + self.assertEqual( + str(cm.exception), + "n_components=%r must be between 1 and " + "n_samples=%r with " + "svd_solver='%s'" + % ( + kpcovc.n_components, + self.X.shape[0], + kpcovc.svd_solver, + ), + ) + with self.subTest(type="arpack_X_ncomponents"): + with self.assertRaises(ValueError) as cm: + kpcovc = self.model(n_components=self.X.shape[0], svd_solver="arpack") + kpcovc.fit(self.X, self.Y) + self.assertEqual( + str(cm.exception), + "n_components=%r must be strictly less than " + "n_samples=%r with " + "svd_solver='%s'" + % ( + kpcovc.n_components, + self.X.shape[0], + kpcovc.svd_solver, + ), + ) + + for svd_solver in ["auto", "full"]: + with self.subTest(type="pi_ncomponents"): + with self.assertRaises(ValueError) as cm: + kpcovc = self.model(n_components=np.pi, svd_solver=svd_solver) + kpcovc.fit(self.X, self.Y) + self.assertEqual( + str(cm.exception), + "n_components=%r must be of type int " + "when greater than or equal to 1, was of type=%r" + % (kpcovc.n_components, type(kpcovc.n_components)), + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_kernel_pcovr.py b/tests/test_kernel_pcovr.py index aeaf30dff..aebdb404a 100644 --- a/tests/test_kernel_pcovr.py +++ b/tests/test_kernel_pcovr.py @@ -7,7 +7,7 @@ from sklearn.linear_model import Ridge, RidgeCV from sklearn.utils.validation import check_X_y -from skmatter.decomposition import KernelPCovR, PCovR +from skmatter.decomposition import PCovR, KernelPCovR from skmatter.preprocessing import StandardFlexibleScaler as SFS @@ -59,7 +59,6 @@ def test_lr_with_x_errors(self): for mixing in np.linspace(0, 1, 6): kpcovr = KernelPCovR(mixing=mixing, n_components=2, tol=1e-12) kpcovr.fit(self.X, self.Y) - error = ( np.linalg.norm(self.Y - kpcovr.predict(self.X)) ** 2.0 / np.linalg.norm(self.Y) ** 2.0 @@ -163,7 +162,7 @@ def test_T_shape(self): kpcovr = KernelPCovR(mixing=0.5, 
n_components=n_components, tol=1e-12) kpcovr.fit(self.X, self.Y) T = kpcovr.transform(self.X) - self.assertTrue(check_X_y(self.X, T, multi_output=True)) + self.assertTrue(check_X_y(self.X, T, multi_output=True) == (self.X, T)) self.assertTrue(T.shape[-1] == n_components) def test_no_centerer(self): @@ -219,12 +218,12 @@ def test_regressor_modifications(self): # kernel parameters now inconsistent with self.assertRaises(ValueError) as cm: kpcovr.fit(self.X, self.Y) - self.assertTrue( + self.assertEqual( str(cm.exception), "Kernel parameter mismatch: the regressor has kernel parameters " - "{kernel: linear, gamma: 0.2, degree: 3, coef0: 1, kernel_params: None}" + "{kernel: 'rbf', gamma: 0.2, degree: 3, coef0: 1, kernel_params: None}" " and KernelPCovR was initialized with kernel parameters " - "{kernel: linear, gamma: 0.1, degree: 3, coef0: 1, kernel_params: None}", + "{kernel: 'rbf', gamma: 0.1, degree: 3, coef0: 1, kernel_params: None}", ) def test_incompatible_regressor(self): @@ -234,7 +233,7 @@ def test_incompatible_regressor(self): with self.assertRaises(ValueError) as cm: kpcovr.fit(self.X, self.Y) - self.assertTrue( + self.assertEqual( str(cm.exception), "Regressor must be an instance of `KernelRidge`", ) @@ -250,29 +249,33 @@ def test_incompatible_coef_shape(self): # Don't need to test X shape, since this should # be caught by sklearn's _validate_data regressor = KernelRidge(alpha=1e-8, kernel="linear") - regressor.fit(self.X, self.Y[:, 0][:, np.newaxis]) + regressor.fit(self.X, self.Y[:, 0]) kpcovr = self.model(mixing=0.5, regressor=regressor) # Dimension mismatch with self.assertRaises(ValueError) as cm: - kpcovr.fit(self.X, np.zeros(self.Y.shape + (2,))) - self.assertTrue( + kpcovr.fit(self.X, self.Y) + self.assertEqual( str(cm.exception), "The regressor coefficients have a dimension incompatible " "with the supplied target space. " "The coefficients have dimension %d and the targets " - "have dimension %d" % (regressor.dual_coef_.ndim, self.Y[:, 0].ndim), + "have dimension %d" % (regressor.dual_coef_.ndim, self.Y.ndim), ) + Y_double = np.column_stack((self.Y, self.Y)) + Y_triple = np.column_stack((Y_double, self.Y)) + regressor.fit(self.X, Y_double) + # Shape mismatch (number of targets) with self.assertRaises(ValueError) as cm: - kpcovr.fit(self.X, self.Y) - self.assertTrue( + kpcovr.fit(self.X, Y_triple) + self.assertEqual( str(cm.exception), "The regressor coefficients have a shape incompatible " "with the supplied target space. 
" "The coefficients have shape %r and the targets " - "have shape %r" % (regressor.dual_coef_.shape, self.Y.shape), + "have shape %r" % (regressor.dual_coef_.shape, Y_triple.shape), ) def test_precomputed_regression(self): @@ -419,7 +422,7 @@ def test_svd_solvers(self): else: kpcovr.fit(self.X, self.Y) - self.assertTrue(kpcovr._fit_svd_solver == solver) + self.assertTrue(kpcovr.fit_svd_solver_ == solver) def test_bad_solver(self): """ @@ -454,10 +457,10 @@ def test_bad_n_components(self): kpcovr = self.model(n_components=-1, svd_solver="auto") kpcovr.fit(self.X, self.Y) - self.assertTrue( + self.assertEqual( str(cm.exception), - "self.n_components=%r must be between 0 and " - "min(n_samples, n_features)=%r with " + "n_components=%r must be between 1 and " + "n_samples=%r with " "svd_solver='%s'" % ( kpcovr.n_components, @@ -470,10 +473,10 @@ def test_bad_n_components(self): kpcovr = self.model(n_components=0, svd_solver="randomized") kpcovr.fit(self.X, self.Y) - self.assertTrue( + self.assertEqual( str(cm.exception), - "self.n_components=%r must be between 1 and " - "min(n_samples, n_features)=%r with " + "n_components=%r must be between 1 and " + "n_samples=%r with " "svd_solver='%s'" % ( kpcovr.n_components, @@ -485,10 +488,10 @@ def test_bad_n_components(self): with self.assertRaises(ValueError) as cm: kpcovr = self.model(n_components=self.X.shape[0], svd_solver="arpack") kpcovr.fit(self.X, self.Y) - self.assertTrue( + self.assertEqual( str(cm.exception), - "self.n_components=%r must be strictly less than " - "min(n_samples, n_features)=%r with " + "n_components=%r must be strictly less than " + "n_samples=%r with " "svd_solver='%s'" % ( kpcovr.n_components, @@ -502,9 +505,9 @@ def test_bad_n_components(self): with self.assertRaises(ValueError) as cm: kpcovr = self.model(n_components=np.pi, svd_solver=svd_solver) kpcovr.fit(self.X, self.Y) - self.assertTrue( + self.assertEqual( str(cm.exception), - "self.n_components=%r must be of type int " + "n_components=%r must be of type int " "when greater than or equal to 1, was of type=%r" % (kpcovr.n_components, type(kpcovr.n_components)), ) diff --git a/tests/test_pcovc.py b/tests/test_pcovc.py index dfd546035..5746c610f 100644 --- a/tests/test_pcovc.py +++ b/tests/test_pcovc.py @@ -3,6 +3,7 @@ import numpy as np from sklearn import exceptions +from sklearn.calibration import LinearSVC from sklearn.datasets import load_breast_cancer as get_dataset from sklearn.decomposition import PCA from sklearn.linear_model import LogisticRegression @@ -112,7 +113,7 @@ def test_cl_with_x_errors(self): prev_error = error def test_cl_with_t_errors(self): - """Check that PCovc returns a non-null property prediction from the latent space + """Check that PCovC returns a non-null property prediction from the latent space projection and that the prediction error increases with `mixing`. 
""" prev_error = -1.0 @@ -286,7 +287,7 @@ def test_bad_n_components(self): """Check that PCovC will not work with any prohibited values of n_components.""" with self.assertRaises(ValueError) as cm: pcovc = self.model( - n_components="mle", classifier=LogisticRegression(), svd_solver="full" + n_components="mle", classifier=LinearSVC(), svd_solver="full" ) pcovc.fit(self.X[:20], self.Y[:20]) self.assertEqual( @@ -401,6 +402,14 @@ def test_T_shape(self): self.assertTrue(check_X_y(self.X, T, multi_output=True)) self.assertTrue(T.shape[-1] == n_components) + def test_Y_Shape(self): + pcovc = self.model() + Y = np.vstack(self.Y) + pcovc.fit(self.X, Y) + + self.assertEqual(pcovc.pxz_.shape[0], self.X.shape[1]) + self.assertEqual(pcovc.ptz_.shape[0], pcovc.n_components_) + def test_Z_shape(self): """Check that PCovC returns an evidence matrix consistent with the number of samples and the number of classes. @@ -427,22 +436,30 @@ def test_Z_shape(self): self.assertTrue(Z.ndim == 2) self.assertTrue((Z.shape[0], Z.shape[1]) == (self.X.shape[0], n_classes)) + def test_decision_function(self): + """Check that PCovC's decision_function works when only T is + provided and throws an error when appropriate. + """ + pcovc = self.model() + pcovc.fit(self.X, self.Y) + with self.assertRaises(ValueError) as cm: + _ = pcovc.decision_function() + self.assertEqual( + str(cm.exception), + "Either X or T must be supplied.", + ) + + T = pcovc.transform(self.X) + _ = pcovc.decision_function(T=T) + def test_default_ncomponents(self): pcovc = PCovC(mixing=0.5) pcovc.fit(self.X, self.Y) self.assertEqual(pcovc.n_components_, min(self.X.shape)) - def test_Y_Shape(self): - pcovc = self.model() - Y = np.vstack(self.Y) - pcovc.fit(self.X, Y) - - self.assertEqual(pcovc.pxz_.shape[0], self.X.shape[1]) - self.assertEqual(pcovc.ptz_.shape[0], pcovc.n_components_) - def test_prefit_classifier(self): - classifier = LogisticRegression() + classifier = LinearSVC() classifier.fit(self.X, self.Y) pcovc = self.model(mixing=0.5, classifier=classifier) pcovc.fit(self.X, self.Y) @@ -482,7 +499,7 @@ def test_precomputed_classification(self): self.assertTrue(np.linalg.norm(t3 - t1) < self.error_tol) def test_classifier_modifications(self): - classifier = LogisticRegression() + classifier = LinearSVC() pcovc = self.model(mixing=0.5, classifier=classifier) # PCovC classifier matches the original @@ -504,7 +521,6 @@ def test_classifier_modifications(self): self.assertTrue(classifier.get_params() != pcovc.classifier_.get_params()) def test_incompatible_classifier(self): - self.maxDiff = None classifier = GaussianNB() classifier.fit(self.X, self.Y) pcovc = self.model(mixing=0.5, classifier=classifier) diff --git a/tests/test_pcovr.py b/tests/test_pcovr.py index 27f99e8e7..597dcc2ba 100644 --- a/tests/test_pcovr.py +++ b/tests/test_pcovr.py @@ -392,7 +392,7 @@ def test_T_shape(self): pcovr = self.model(n_components=n_components, tol=1e-12) pcovr.fit(self.X, self.Y) T = pcovr.transform(self.X) - self.assertTrue(check_X_y(self.X, T, multi_output=True)) + self.assertTrue(check_X_y(self.X, T, multi_output=True) == (self.X, T)) self.assertTrue(T.shape[-1] == n_components) def test_default_ncomponents(self):