class SNMFOptimizer:
-    def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500, tol=5e-7, components=None):
-        print("Initializing SNMF Optimizer")
+    """An implementation of stretched NMF (sNMF), including sparse stretched NMF.
+
+    Instantiating the SNMFOptimizer class runs all the analysis immediately.
+    The results matrices can then be accessed as instance attributes
+    of the class (X, Y, and A).
+
+    For more information on sNMF, please reference:
+    Gu, R., Rakita, Y., Lan, L. et al. Stretched non-negative matrix factorization.
+    npj Comput Mater 10, 193 (2024). https://doi.org/10.1038/s41524-024-01377-5
+
+    Attributes
+    ----------
+    MM : ndarray
+        The original, unmodified data to be decomposed and, later, compared against.
+        Shape is (length_of_signal, number_of_conditions).
+    Y : ndarray
+        The best guess (or while running, the current guess) for the matrix of
+        component weights at each stretching condition.
+    X : ndarray
+        The best guess (or while running, the current guess) for the matrix of
+        component intensities.
+    A : ndarray
+        The best guess (or while running, the current guess) for the matrix of
+        stretching factors for each component at each condition.
+    rho : float
+        The stretching factor that influences the decomposition. Zero corresponds to no
+        stretching present. Relatively insensitive and typically adjusted in powers of 10.
+    eta : float
+        The sparsity factor that influences the decomposition. Should be set to zero for
+        non-sparse data such as PDF. Can be used to improve results for sparse data such
+        as XRD, but due to instability, should be used only after first selecting the
+        best value for rho. Suggested adjustment is by powers of 2.
+    max_iter : int
+        The maximum number of times to update each of A, X, and Y before stopping
+        the optimization.
+    tol : float
+        The convergence threshold. The optimization stops once the fractional improvement
+        in the objective function per update falls below this value. Note that a minimum
+        of 20 updates are run before this parameter is checked.
+    n_components : int
+        The number of components to extract from MM. Must be provided when and only when
+        Y0 is not provided.
+    random_state : int
+        The seed for the initial guesses at the matrices (A, X, and Y) created by
+        the decomposition.
+    num_updates : int
+        The total number of times that any of (A, X, and Y) have had their values changed.
+        If not terminated by other means, this value is used to stop when reaching max_iter.
+    objective_function : float
+        The value corresponding to the minimization of the difference between MM and the
+        products of A, X, and Y. For full details see the sNMF paper. Smaller corresponds to
+        better agreement and is desirable.
+    objective_difference : float
+        The change in the objective function value since the last update, computed as the
+        previous value minus the current value. A positive value means that the result improved.
+    """
+
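A minimal usage sketch (illustrative only, not part of this patch). It assumes a data
matrix `MM` with shape (length_of_signal, number_of_conditions) is already available and
that the class is importable from a module named `snmf_class` (an assumption):

    import numpy as np
    from snmf_class import SNMFOptimizer  # hypothetical import path

    MM = np.load("my_data.npy")  # hypothetical input file
    model = SNMFOptimizer(MM=MM, n_components=2, random_state=0)

    X = model.X  # (length_of_signal, n_components) component intensities
    Y = model.Y  # (n_components, number_of_conditions) component weights
    A = model.A  # (n_components, number_of_conditions) stretching factors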
+    def __init__(
+        self,
+        MM,
+        Y0=None,
+        X0=None,
+        A0=None,
+        rho=1e12,
+        eta=610,
+        max_iter=500,
+        tol=5e-7,
+        n_components=None,
+        random_state=None,
+    ):
+        """Initialize an instance of SNMFOptimizer and run the optimization.
+
+        Parameters
+        ----------
+        MM : ndarray
+            The data to be decomposed. Shape is (length_of_signal, number_of_conditions).
+        Y0 : ndarray
+            The initial guesses for the component weights at each stretching condition.
+            Shape is (number_of_components, number_of_conditions). Must provide exactly one
+            of this or n_components.
+        X0 : ndarray
+            The initial guesses for the intensities of each component per
+            row/sample/angle. Shape is (length_of_signal, number_of_components).
+        A0 : ndarray
+            The initial guesses for the stretching factor for each component, at each
+            condition. Shape is (number_of_components, number_of_conditions).
+        rho : float
+            The stretching factor that influences the decomposition. Zero corresponds to no
+            stretching present. Relatively insensitive and typically adjusted in powers of 10.
+        eta : float
+            The sparsity factor that influences the decomposition. Should be set to zero for
+            non-sparse data such as PDF. Can be used to improve results for sparse data such
+            as XRD, but due to instability, should be used only after first selecting the
+            best value for rho. Suggested adjustment is by powers of 2.
+        max_iter : int
+            The maximum number of times to update each of A, X, and Y before stopping
+            the optimization.
+        tol : float
+            The convergence threshold. The optimization stops once the fractional improvement
+            in the objective function per update falls below this value. Note that a minimum
+            of 20 updates are run before this parameter is checked.
+        n_components : int
+            The number of components to extract from MM. Must be provided when and only when
+            Y0 is not provided.
+        random_state : int
+            The seed for the initial guesses at the matrices (A, X, and Y) created by
+            the decomposition.
+        """
+
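As a hedged illustration of the shapes documented above, explicit initial guesses could be
supplied like this (values are arbitrary placeholders, not a recommended initialization):

    import numpy as np

    signal_len, n_conditions, n_components = 1000, 10, 2
    MM = np.random.rand(signal_len, n_conditions)    # stand-in data
    X0 = np.random.rand(signal_len, n_components)    # component intensities
    Y0 = np.random.rand(n_components, n_conditions)  # component weights
    A0 = np.ones((n_components, n_conditions))       # stretching factors (start unstretched)
    model = SNMFOptimizer(MM=MM, Y0=Y0, X0=X0, A0=A0)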
        self.MM = MM
-        self.X0 = X0
-        self.Y0 = Y0
-        self.A = A
        self.rho = rho
        self.eta = eta
        # Capture matrix dimensions
-        self.N, self.M = MM.shape
+        self._N, self._M = MM.shape
        self.num_updates = 0
+        self._rng = np.random.default_rng(random_state)

+        # Enforce exclusive specification of n_components or Y0
+        if (n_components is None) == (Y0 is None):
+            raise ValueError("Must provide exactly one of Y0 or n_components, but not both.")
+
+        # Initialize Y and determine the number of components
        if Y0 is None:
-            if components is None:
-                raise ValueError("Must provide either Y0 or a number of components.")
-            else:
-                self.K = components
-                self.Y0 = np.random.beta(a=2.5, b=1.5, size=(self.K, self.M))  # This is untested
+            self._K = n_components
+            self.Y = self._rng.beta(a=2.5, b=1.5, size=(self._K, self._M))
        else:
-            self.K = Y0.shape[0]
+            self._K = Y0.shape[0]
+            self.Y = Y0

-        # Initialize A, X0 if not provided
-        if self.A is None:
-            self.A = np.ones((self.K, self.M)) + np.random.randn(self.K, self.M) * 1e-3  # Small perturbation
-        if self.X0 is None:
-            self.X0 = np.random.rand(self.N, self.K)  # Ensures values in [0,1]
+        # Initialize A if not provided
+        if A0 is None:
+            self.A = np.ones((self._K, self._M)) + self._rng.normal(0, 1e-3, size=(self._K, self._M))
+        else:
+            self.A = A0
+
+        # Initialize X if not provided
+        if X0 is None:
+            self.X = self._rng.random((self._N, self._K))
+        else:
+            self.X = X0

-        # Initialize solution matrices to be iterated on
-        self.X = np.maximum(0, self.X0)
-        self.Y = np.maximum(0, self.Y0)
+        # Enforce non-negativity
+        self.X = np.maximum(0, self.X)
+        self.Y = np.maximum(0, self.Y)

        # Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals)
-        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self.M - 2, self.M))
+        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self._M - 2, self._M))
        self.PP = self.P.T @ self.P
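To make the operator concrete, a small hedged example of what `P` looks like for M = 6
conditions (each row applies a 0.25-scaled second difference across neighboring columns):

    from scipy.sparse import diags

    M = 6
    P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(M - 2, M))
    print(P.toarray())
    # [[ 0.25 -0.5   0.25  0.    0.    0.  ]
    #  [ 0.    0.25 -0.5   0.25  0.    0.  ]
    #  [ 0.    0.    0.25 -0.5   0.25  0.  ]
    #  [ 0.    0.    0.    0.25 -0.5   0.25]]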
        # Set up residual matrix, objective function, and history
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]

        # Set up tracking variables for updateX()
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)

        regularization_term = 0.5 * rho * np.linalg.norm(self.P @ self.A.T, "fro") ** 2
        sparsity_term = eta * np.sum(np.sqrt(self.X))  # Square root penalty
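For orientation: combining the two terms above with the data misfit, the objective that the
updates below drive down appears to have the form (a sketch based on the referenced sNMF
paper; get_objective_function, not shown in this diff, holds the exact definition):

    F(X, Y, A) = misfit(MM, reconstruction from X stretched by A and weighted by Y)
                 + 0.5 * rho * ||P @ A.T||_F**2
                 + eta * sum(sqrt(X))

where the second and third terms are exactly the regularization_term and sparsity_term
computed above.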
@@ -83,53 +197,53 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500
        # loop to normalize X
        # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
        # reset difference trackers and initialize
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]
        for norm_iter in range(100):
            self.updateX()
            self.R = self.get_residual_matrix()
            self.objective_function = self.get_objective_function()
            print(f"Objective function after normX: {self.objective_function:.5e}")
-            self.objective_history.append(self.objective_function)
-            self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+            self._objective_history.append(self.objective_function)
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
            if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
                break
        # end of normalization (and program)
        # note that objective function may not fully recover after normalization, this is okay
        print("Finished optimization.")

    def optimize_loop(self):
-        self.preGraX = self.GraX.copy()
+        self._preGraX = self._GraX.copy()
        self.updateX()
        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateX: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)
        if self.objective_difference is None:
-            self.objective_difference = self.objective_history[-1] - self.objective_function
+            self.objective_difference = self._objective_history[-1] - self.objective_function

        # Now we update Y
        self.updateY2()
        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateY2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)

        self.updateA2()

        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateA2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
-        self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+        self._objective_history.append(self.objective_function)
+        self.objective_difference = self._objective_history[-2] - self._objective_history[-1]

    def apply_interpolation(self, a, x, return_derivatives=False):
        """
@@ -401,36 +515,38 @@ def updateX(self):
        # Compute `AX` using the interpolation function
        AX, _, _ = self.apply_interpolation_matrix()  # Skip the other two outputs
        # Compute RA and RR
-        intermediate_RA = AX.flatten(order="F").reshape((self.N * self.M, self.K), order="F")
-        RA = intermediate_RA.sum(axis=1).reshape((self.N, self.M), order="F")
+        intermediate_RA = AX.flatten(order="F").reshape((self._N * self._M, self._K), order="F")
+        RA = intermediate_RA.sum(axis=1).reshape((self._N, self._M), order="F")
        RR = RA - self.MM
        # Compute gradient `GraX`
-        self.GraX = self.apply_transformation_matrix(R=RR).toarray()  # toarray equivalent of full, make non-sparse
+        self._GraX = self.apply_transformation_matrix(
+            R=RR
+        ).toarray()  # toarray equivalent of full, make non-sparse

        # Compute initial step size `L0`
        L0 = np.linalg.eigvalsh(self.Y.T @ self.Y).max() * np.max([self.A.max(), 1 / self.A.min()])
        # Compute adaptive step size `L`
-        if self.preX is None:
+        if self._preX is None:
            L = L0
        else:
-            num = np.sum((self.GraX - self.preGraX) * (self.X - self.preX))  # Element-wise multiplication
-            denom = np.linalg.norm(self.X - self.preX, "fro") ** 2  # Frobenius norm squared
+            num = np.sum((self._GraX - self._preGraX) * (self.X - self._preX))  # Element-wise multiplication
+            denom = np.linalg.norm(self.X - self._preX, "fro") ** 2  # Frobenius norm squared
            L = num / denom if denom > 0 else L0
            if L <= 0:
                L = L0

        # Store our old X before updating because it is used in step selection
-        self.preX = self.X.copy()
+        self._preX = self.X.copy()

        while True:  # iterate updating X
-            x_step = self.preX - self.GraX / L
+            x_step = self._preX - self._GraX / L
            # Solve x^3 + p*x + q = 0 for the largest real root
            self.X = np.square(cubic_largest_real_root(-x_step, self.eta / (2 * L)))
            # Mask values that should be set to zero
            mask = self.X**2 * L / 2 - L * self.X * x_step + self.eta * np.sqrt(self.X) < 0
            self.X = mask * self.X

-            objective_improvement = self.objective_history[-1] - self.get_objective_function(
+            objective_improvement = self._objective_history[-1] - self.get_objective_function(
                R=self.get_residual_matrix()
            )
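A hedged aside on the cubic-root step in the while-loop above: each element of X is updated
to a stationary point of g(x) = (L/2) * (x - x_step)**2 + eta * sqrt(x) over x >= 0.
Substituting u = sqrt(x) and setting dg/dx = 0 gives u**3 - x_step*u + eta/(2*L) = 0, which
matches the call cubic_largest_real_root(-x_step, eta / (2 * L)) followed by squaring,
assuming that helper returns the largest real root of u**3 + p*u + q. The mask computed
afterwards keeps an element only where that interior solution beats x = 0, i.e. where
g(x) < g(0). A quick numerical cross-check (illustrative only):

    import numpy as np

    L, s, eta = 2.0, 1.3, 0.4
    # closed form: largest real root of u^3 - s*u + eta/(2L), then square it
    roots = np.roots([1.0, 0.0, -s, eta / (2 * L)])
    u = max(r.real for r in roots if abs(r.imag) < 1e-12)
    x_closed = u**2

    # brute-force minimization of g on a dense grid
    xs = np.linspace(0.0, 4.0, 400_001)
    g = 0.5 * L * (xs - s) ** 2 + eta * np.sqrt(xs)
    x_grid = xs[np.argmin(g)]

    print(x_closed, x_grid)  # both should be close to 1.21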
@@ -447,9 +563,9 @@ def updateY2(self):
        Updates Y using matrix operations, solving a quadratic program via `solve_mkr_box`.
        """

-        K = self.K
-        N = self.N
-        M = self.M
+        K = self._K
+        N = self._N
+        M = self._M

        for m in range(M):
            T = np.zeros((N, K))  # Initialize T as an (N, K) zero matrix
@@ -476,9 +592,9 @@ def regularize_function(self, A=None):
        if A is None:
            A = self.A

-        K = self.K
-        M = self.M
-        N = self.N
+        K = self._K
+        M = self._M
+        N = self._N

        # Compute interpolated matrices
        AX, TX, HX = self.apply_interpolation_matrix(A=A, return_derivatives=True)