Commit 9e4e3d8

Switching from KPCovC w/SVC back to KPCovC w/linear classifiers

1 parent b701e23

File tree

12 files changed: +875 −593 lines

CHANGELOG

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@ Unreleased
 - Add ``_BasePCov`` class (#248)
 - Add ``PCovC`` class that inherits shared functionality from ``_BasePCov`` (#248)
 - Add ``PCovC`` testing suite and examples (#248)
-- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov_`` (#248)
-- Update to sklearn >= 1.7.0 and scipy >= 1.15.0 (#239, #257)
+- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov`` (#248)
+- Update to sklearn >= 1.6.0 and scipy >= 1.15.0 (#239)
 - Fixed moved function import from scipy and bump scipy dependency to 1.15.0 (#236)
 - Fix rendering issues for `SparseKDE` and `QuickShift` (#236)
 - Updating ``FPS`` to allow a numpy array of ints as an initialize parameter (#145)

examples/pcovc/KPCovC_Moons.ipynb

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@
 "    else model.fit_transform(X_train_scaled, y_train)\n",
 ")\n",
 "t_test = model.transform(X_test_scaled)\n",
-"axes[i].scatter(t_test[:, 0], t_test[:, 1], alpha=0.4, cmap=cm_bright, c=y_test)\n",
+"axes[i].scatter(t_test[:, 0], t_test[:, 1], alpha=0.6, cmap=cm_bright, c=y_test)\n",
 "\n",
 "axes[i].scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train)\n",
 "axes[i].set_title(models[model])"

examples/pcovc/KPCovC_Moons_linear.ipynb

Lines changed: 324 additions & 0 deletions
Large diffs are not rendered by default.

examples/pcovc/KPCovC_Rings.ipynb

Lines changed: 0 additions & 275 deletions
This file was deleted.

examples/pcovc/KPCovC_Rings_linear.ipynb

Lines changed: 343 additions & 0 deletions
Large diffs are not rendered by default.

src/skmatter/decomposition/_kernel_pcovc.py

Lines changed: 82 additions & 119 deletions
@@ -1,34 +1,36 @@
 import numpy as np
-from sklearn.calibration import LinearSVC, check_classification_targets
 
-
-from sklearn.svm import SVC
-from sklearn.utils import (
-    check_array,
-)
-from sklearn.utils.multiclass import check_classification_targets, type_of_target
+from sklearn import clone
 from sklearn.svm import LinearSVC
-
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.linear_model import (
+    Perceptron,
+    RidgeClassifier,
+    RidgeClassifierCV,
+    LogisticRegression,
+    LogisticRegressionCV,
+    SGDClassifier,
+)
 from sklearn.utils import check_array
 from sklearn.utils.validation import check_is_fitted, validate_data
 from sklearn.linear_model._base import LinearClassifierMixin
 from sklearn.utils.multiclass import check_classification_targets, type_of_target
 
-from skmatter.utils._pcovc_utils import check_svc_fit
-from skmatter.decomposition import _BaseKPCov
-from sklearn.metrics.pairwise import pairwise_kernels
 from skmatter.preprocessing import KernelNormalizer
+from skmatter.utils import check_cl_fit
+from skmatter.decomposition import _BaseKPCov
+from sklearn.preprocessing import StandardScaler
 
 import scipy.sparse as sp
 
-
 class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
-    r"""Kernel Principal Covariates Classification is a modification on the Principal
-    Covariates Classification proposed in [Jorgensen2025]_. It determines a latent-space
-    projection :math:`\mathbf{T}` which minimizes a combined loss in supervised and unsupervised
-    tasks in the reproducing kernel Hilbert space (RKHS).
+    r"""Kernel Principal Covariates Classification, as described in [Jorgensen2025]_,
+    determines a latent-space projection :math:`\mathbf{T}` which minimizes a combined
+    loss in supervised and unsupervised tasks in the reproducing kernel Hilbert space
+    (RKHS).
 
-    This projection is determined by the eigendecomposition of a modified covariance matrix
+    This projection is determined by the eigendecomposition of a modified gram matrix
     :math:`\mathbf{\tilde{K}}`
 
     .. math::
@@ -69,25 +71,26 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
         If randomized :
             run randomized SVD by the method of Halko et al.
 
-    classifier : {instance of `sklearn.svm.SVC`, None}, default=None
+    classifier : {`LogisticRegression`, `LogisticRegressionCV`, `LinearSVC`, `LinearDiscriminantAnalysis`,
+        `RidgeClassifier`, `RidgeClassifierCV`, `SGDClassifier`, `Perceptron`, `precomputed`}, default=None
         The classifier to use for computing
         the evidence :math:`{\mathbf{Z}}`.
         A pre-fitted classifier may be provided.
-        If the classifier is not `None`, its kernel parameters
-        (`kernel`, `gamma`, `degree`, and `coef0`)
-        must be identical to those passed directly to `KernelPCovC`.
 
-    kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'} or callable, default='rbf'
+        If None, ``sklearn.linear_model.LogisticRegression()``
+        is used as the classifier.
+
+    kernel : {"linear", "poly", "rbf", "sigmoid", "cosine", "precomputed"}, default="linear"
         Kernel.
 
-    gamma : {'scale', 'auto'} or float, default='scale'
+    gamma : {'scale', 'auto'} or float, default=None
         Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other
         kernels.
 
     degree : int, default=3
         Degree for poly kernels. Ignored by other kernels.
 
-    coef0 : float, default=0.0
+    coef0 : float, default=1
         Independent term in poly and sigmoid kernels.
         Ignored by other kernels.
 
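Note: for orientation, a minimal usage sketch of the classifier surface documented in the hunk above. The constructor parameters are taken from this diff; the toy data and the choice of `RidgeClassifier` are invented for illustration, not part of the commit.

    import numpy as np
    from sklearn.linear_model import RidgeClassifier
    from skmatter.decomposition import KernelPCovC

    # Invented toy data: 20 samples, 4 features, binary labels.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(20, 4))
    Y = rng.integers(0, 2, size=20)

    # Any classifier from the list above may be passed; classifier=None
    # falls back to sklearn.linear_model.LogisticRegression().
    kpcovc = KernelPCovC(
        mixing=0.5,
        n_components=2,
        classifier=RidgeClassifier(),
        kernel="rbf",
        gamma=1.0,
    )
    T = kpcovc.fit(X, Y).transform(X)  # latent projection, shape (20, 2)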
@@ -172,14 +175,14 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
     ...     gamma=1,
     ... )
     >>> kpcovc.fit(X, Y)
-    KernelPCovC(classifier=SVC(gamma=1), gamma=1, mixing=0.25, n_components=2)
+    KernelPCovC(gamma=1, kernel='rbf', mixing=0.1, n_components=2)
     >>> kpcovc.transform(X)
-    array([[ 1.91713954,  2.52318389],
-           [ 2.95581573,  0.78491499],
-           [ 3.00977646, -1.1252421 ],
-           [ 2.45390525, -1.5365844 ]])
+    array([[-4.45970689e-01,  8.95327566e-06],
+           [ 4.52745933e-01,  5.54810948e-01],
+           [ 4.52881359e-01, -5.54708315e-01],
+           [-4.45921092e-01, -7.32157649e-05]])
     >>> kpcovc.predict(X)
-    array([2, 1, 3, 0])
+    array([2, 0, 1, 2])
     >>> kpcovc.score(X, Y)
     1.0
     """  # NoQa: E501
@@ -190,10 +193,10 @@ def __init__(
         n_components=None,
         svd_solver="auto",
         classifier=None,
-        kernel="rbf",
-        gamma="scale",
+        kernel="linear",
+        gamma=None,
         degree=3,
-        coef0=0.0,
+        coef0=1,
         kernel_params=None,
         center=False,
         fit_inverse_transform=False,
@@ -218,10 +221,9 @@ def __init__(
             n_jobs=n_jobs,
             fit_inverse_transform=fit_inverse_transform,
         )
-
         self.classifier = classifier
 
-    def fit(self, X, Y):
+    def fit(self, X, Y, W=None):
         r"""Fit the model with X and Y.
 
         Parameters
@@ -238,6 +240,11 @@ def fit(self, X, Y, W=None):
         Y : numpy.ndarray, shape (n_samples,)
             Training data, where n_samples is the number of samples.
 
+        W : numpy.ndarray, shape (n_features, n_properties)
+            Classification weights, optional when classifier=`precomputed`. If
+            not passed, it is assumed that the weights will be taken from a
+            linear classifier fit between K and Y.
+
         Returns
         -------
         self: object
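Note: a hedged sketch of the new ``W`` argument documented above. With ``classifier="precomputed"``, weights fit outside the estimator can be passed straight to ``fit``. The kernel helper and the toy data here are assumptions for illustration, not part of the commit.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics.pairwise import pairwise_kernels
    from skmatter.decomposition import KernelPCovC

    rng = np.random.default_rng(0)
    X = rng.normal(size=(20, 4))
    Y = rng.integers(0, 2, size=20)

    # Mirror the fallback inside fit(): a linear classifier fit between K and Y.
    K = pairwise_kernels(X, metric="linear")  # kernel="linear" is now the default
    W = LogisticRegression().fit(K, Y).coef_.T

    kpcovc = KernelPCovC(mixing=0.5, n_components=2, classifier="precomputed")
    kpcovc.fit(X, Y, W=W)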
@@ -261,76 +268,56 @@ def fit(self, X, Y):
         super().fit(X)
 
         K = super()._get_kernel(X)
-
         if self.center:
             self.centerer_ = KernelNormalizer()
             K = self.centerer_.fit_transform(K)
+        compatible_classifiers = (
+            LogisticRegression,
+            LogisticRegressionCV,
+            LinearSVC,
+            LinearDiscriminantAnalysis,
+            RidgeClassifier,
+            RidgeClassifierCV,
+            SGDClassifier,
+            Perceptron,
+        )
 
-        if self.classifier and not isinstance(
-            self.classifier,
-            SVC,
+        if self.classifier not in ["precomputed", None] and not isinstance(
+            self.classifier, compatible_classifiers
         ):
-            raise ValueError("Classifier must be an instance of `SVC`")
-
-        if self.classifier is None:
-            classifier = SVC(
-                kernel=self.kernel,
-                gamma=self.gamma,
-                degree=self.degree,
-                coef0=self.coef0,
+            raise ValueError(
+                "Classifier must be an instance of `"
+                f"{'`, `'.join(c.__name__ for c in compatible_classifiers)}`"
+                ", or `precomputed`"
             )
+
+        if self.classifier != "precomputed":
+            if self.classifier is None:
+                classifier = LogisticRegression()
+            else:
+                classifier = self.classifier
+
+            # Check if classifier is fitted; if not, fit with precomputed K
+            self.z_classifier_ = check_cl_fit(classifier, K, Y)
+            W = self.z_classifier_.coef_.T.reshape(K.shape[1], -1)
+
         else:
-            classifier = self.classifier
-            kernel_attrs = ["kernel", "gamma", "degree", "coef0"]
-            if not all(
-                [
-                    getattr(self, attr) == getattr(classifier, attr)
-                    for attr in kernel_attrs
-                ]
-            ):
-                raise ValueError(
-                    "Kernel parameter mismatch: the classifier has kernel "
-                    "parameters {%s} and KernelPCovC was initialized with kernel "
-                    "parameters {%s}"
-                    % (
-                        ", ".join(
-                            [
-                                "%s: %r" % (attr, getattr(classifier, attr))
-                                for attr in kernel_attrs
-                            ]
-                        ),
-                        ", ".join(
-                            [
-                                "%s: %r" % (attr, getattr(self, attr))
-                                for attr in kernel_attrs
-                            ]
-                        ),
-                    )
-                )
-            if classifier.decision_function_shape != "ovr":
-                raise ValueError(
-                    f"Classifier must have parameter `decision_function_shape` set to 'ovr' "
-                    f"but was initialized with '{classifier.decision_function_shape}'"
-                )
-
-        # Check if classifier is fitted; if not, fit with precomputed K
-        # to avoid needing to compute the kernel a second time
-        self.z_classifier_ = check_svc_fit(classifier, K, X, Y)
-
-        # if we have fit the classifier on a precomputed K, we obtain Z
-        # from K, otherwise obtain it from X
-        if self.z_classifier_.kernel == "precomputed":
-            Z = self.z_classifier_.decision_function(K)
-        else:
-            Z = self.z_classifier_.decision_function(X)
+            # If precomputed, use default classifier to predict Y from T
+            classifier = LogisticRegression()
+            if W is None:
+                W = LogisticRegression().fit(K, Y).coef_.T
+            W = W.reshape(K.shape[1], -1)
+
+        Z = K @ W
 
-        print(f"KPCovC Z: {Z[:5]}")
-        super()._fit_covariance(K, Z)  # gives us T, Pkt, self.pt__
+        self._fit_gram(K, Z, W)
 
+        self.ptk_ = self.pt__ @ K
+        # ("KPCovc"+str(self.ptk_[:10][1]))
         if self.fit_inverse_transform:
            self.ptx_ = self.pt__ @ X
 
-        self.classifier_ = LinearSVC().fit(K @ self.pkt_, Y)
+        self.classifier_ = clone(classifier).fit(K @ self.pkt_, Y)
 
         self.ptz_ = self.classifier_.coef_.T
         self.pkz_ = self.pkt_ @ self.ptz_
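Note: a plain numpy/sklearn paraphrase of the rewritten fit() flow above, to make the two-stage logic explicit. ``_fit_gram`` internals are not reproduced; only the evidence computation Z = K @ W is traced, and the mixture formula in the final comment is an assumption by analogy with the KPCovR gram matrix.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics.pairwise import pairwise_kernels

    rng = np.random.default_rng(0)
    X = rng.normal(size=(20, 4))
    Y = rng.integers(0, 2, size=20)

    K = pairwise_kernels(X, metric="linear")

    # Stage 1: a linear classifier on the kernel gives the evidence Z = K @ W,
    # replacing the old SVC decision_function path.
    W = LogisticRegression().fit(K, Y).coef_.T.reshape(K.shape[1], -1)
    Z = K @ W

    # Stage 2 (inside _fit_gram): eigendecomposition of a mixture of K and the
    # evidence, presumably mixing * K + (1 - mixing) * Z @ Z.T, yields T and
    # pkt_; a clone of the classifier is then refit on T = K @ pkt_.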
@@ -367,7 +354,7 @@ def transform(self, X):
         """Apply dimensionality reduction to X.
 
         ``X`` is projected on the first principal components as determined by the
-        modified Kernel PCovC distances.
+        modified Kernel PCovR distances.
 
         Parameters
         ----------
@@ -402,31 +389,7 @@ def inverse_transform(self, T):
         return super().inverse_transform(T)
 
     def decision_function(self, X=None, T=None):
-        r"""Predicts confidence scores from X or T.
-
-        .. math::
-            \mathbf{Z} = \mathbf{T} \mathbf{P}_{TZ}
-                       = \mathbf{K} \mathbf{P}_{KT} \mathbf{P}_{TZ}
-                       = \mathbf{K} \mathbf{P}_{KZ}
-
-        Parameters
-        ----------
-        X : ndarray, shape(n_samples, n_features)
-            Original data for which we want to get confidence scores,
-            where n_samples is the number of samples and n_features is the
-            number of features.
-
-        T : ndarray, shape (n_samples, n_components)
-            Projected data for which we want to get confidence scores,
-            where n_samples is the number of samples and n_components is the
-            number of components.
-
-        Returns
-        -------
-        Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes)
-            Confidence scores. For binary classification, has shape `(n_samples,)`,
-            for multiclass classification, has shape `(n_samples, n_classes)`
-        """
+        """Predicts confidence scores from X or T."""
         check_is_fitted(self, attributes=["pkz_", "ptz_"])
 
         if X is None and T is None:
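Note: the deleted docstring above recorded the identity Z = T P_TZ = K P_KT P_TZ = K P_KZ behind decision_function. Below, a hedged check of that identity on a fitted model, using attribute names from this diff; the intercept handling is an assumption, so the comparison is left as a comment.

    import numpy as np
    from skmatter.decomposition import KernelPCovC

    rng = np.random.default_rng(0)
    X = rng.normal(size=(20, 4))
    Y = rng.integers(0, 2, size=20)

    kpcovc = KernelPCovC(mixing=0.5, n_components=2).fit(X, Y)

    T = kpcovc.transform(X)
    Z_from_T = T @ kpcovc.ptz_              # T @ P_TZ
    Z_from_X = kpcovc.decision_function(X)  # K @ P_KZ route, computed internally

    # Up to the classifier intercept, the two routes should agree:
    # np.allclose(Z_from_T + kpcovc.classifier_.intercept_, Z_from_X)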
