
Commit 373da11

Authored by John Halloran (john-halloran) with Simon Billinge (sbillinge)
feat: Add random state feature. (#150)
* feat: Add random state feature.
* Add class docstring
* components->n_components
* Updated docstring
* Shorten and reformat docstring
* docstring typo
* Flag self.rng as private
* Make logic for n_components and Y0 more rigid
* added class attributes to docstring
* fix: cleaner import of SNMFOptimizer
* fix: correct class instantiation after change in import

---------

Co-authored-by: John Halloran <[email protected]>
Co-authored-by: Simon Billinge <[email protected]>
1 parent 8613ea0 · commit 373da11

File tree

src/diffpy/snmf/main.py
src/diffpy/snmf/snmf_class.py

2 files changed: +169 -72 lines

src/diffpy/snmf/main.py

Lines changed: 2 additions & 21 deletions
@@ -1,33 +1,14 @@
 import numpy as np
-import snmf_class
+from snmf_class import SNMFOptimizer

 X0 = np.loadtxt("input/X0.txt", dtype=float)
 MM = np.loadtxt("input/MM.txt", dtype=float)
 A0 = np.loadtxt("input/A0.txt", dtype=float)
 Y0 = np.loadtxt("input/W0.txt", dtype=float)
 N, M = MM.shape

-# Convert to DataFrames for display
-# df_X = pd.DataFrame(X, columns=[f"Comp_{i+1}" for i in range(X.shape[1])])
-# df_Y = pd.DataFrame(Y, columns=[f"Sample_{i+1}" for i in range(Y.shape[1])])
-# df_MM = pd.DataFrame(MM, columns=[f"Sample_{i+1}" for i in range(MM.shape[1])])
-# df_Y0 = pd.DataFrame(Y0, columns=[f"Sample_{i+1}" for i in range(Y0.shape[1])])
-
-# Print the matrices
-"""
-print("Feature Matrix (X):\n", df_X, "\n")
-print("Coefficient Matrix (Y):\n", df_Y, "\n")
-print("Data Matrix (MM):\n", df_MM, "\n")
-print("Initial Guess (Y0):\n", df_Y0, "\n")
-"""
-
-
-my_model = snmf_class.SNMFOptimizer(MM=MM, Y0=Y0, X0=X0, A=A0, components=2)
+my_model = SNMFOptimizer(MM=MM, Y0=Y0, X0=X0, A0=A0)
 print("Done")
-# print(f"My final guess for X: {my_model.X}")
-# print(f"My final guess for Y: {my_model.Y}")
-# print(f"Compare to true X: {X_norm}")
-# print(f"Compare to true Y: {Y_norm}")
 np.savetxt("my_norm_X.txt", my_model.X, fmt="%.6g", delimiter=" ")
 np.savetxt("my_norm_Y.txt", my_model.Y, fmt="%.6g", delimiter=" ")
 np.savetxt("my_norm_A.txt", my_model.A, fmt="%.6g", delimiter=" ")

src/diffpy/snmf/snmf_class.py

Lines changed: 167 additions & 51 deletions
@@ -4,51 +4,165 @@


 class SNMFOptimizer:
-    def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500, tol=5e-7, components=None):
-        print("Initializing SNMF Optimizer")
+    """An implementation of stretched NMF (sNMF), including sparse stretched NMF.
+
+    Instantiating the SNMFOptimizer class runs all the analysis immediately.
+    The results matrices can then be accessed as instance attributes
+    of the class (X, Y, and A).
+
+    For more information on sNMF, please reference:
+    Gu, R., Rakita, Y., Lan, L. et al. Stretched non-negative matrix factorization.
+    npj Comput Mater 10, 193 (2024). https://doi.org/10.1038/s41524-024-01377-5
+
+    Attributes
+    ----------
+    MM : ndarray
+        The original, unmodified data to be decomposed and, later, compared against.
+        Shape is (length_of_signal, number_of_conditions).
+    Y : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        component weights at each stretching condition.
+    X : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        component intensities.
+    A : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        stretching factors.
+    rho : float
+        The stretching factor that influences the decomposition. Zero corresponds to no
+        stretching present. Relatively insensitive and typically adjusted in powers of 10.
+    eta : float
+        The sparsity factor that influences the decomposition. Should be set to zero for
+        non-sparse data such as PDF. Can be used to improve results for sparse data such
+        as XRD, but due to instability, should be used only after first selecting the
+        best value for rho. Suggested adjustment is by powers of 2.
+    max_iter : int
+        The maximum number of times to update each of A, X, and Y before stopping
+        the optimization.
+    tol : float
+        The convergence threshold. The optimization terminates once the fractional
+        improvement in the objective function drops below this value. Note that
+        a minimum of 20 updates are run before this parameter is checked.
+    n_components : int
+        The number of components to extract from MM. Must be provided when, and only
+        when, Y0 is not provided.
+    random_state : int
+        The seed for the initial guesses at the matrices (A, X, and Y) created by
+        the decomposition.
+    num_updates : int
+        The total number of times that any of (A, X, and Y) have had their values changed.
+        If not terminated by other means, this value is used to stop the optimization
+        when it reaches max_iter.
+    objective_function : float
+        The value corresponding to the minimization of the difference between MM and the
+        products of A, X, and Y. For full details see the sNMF paper. Smaller corresponds
+        to better agreement and is desirable.
+    objective_difference : float
+        The change in the objective function value since the last update. A negative value
+        means that the result improved.
+    """
+
+    def __init__(
+        self,
+        MM,
+        Y0=None,
+        X0=None,
+        A0=None,
+        rho=1e12,
+        eta=610,
+        max_iter=500,
+        tol=5e-7,
+        n_components=None,
+        random_state=None,
+    ):
+        """Initialize an instance of SNMF and run the optimization.
+
+        Parameters
+        ----------
+        MM : ndarray
+            The data to be decomposed. Shape is (length_of_signal, number_of_conditions).
+        Y0 : ndarray
+            The initial guesses for the component weights at each stretching condition.
+            Shape is (number_of_components, number_of_conditions). Must provide exactly
+            one of this or n_components.
+        X0 : ndarray
+            The initial guesses for the intensities of each component per
+            row/sample/angle. Shape is (length_of_signal, number_of_components).
+        A0 : ndarray
+            The initial guesses for the stretching factor for each component, at each
+            condition. Shape is (number_of_components, number_of_conditions).
+        rho : float
+            The stretching factor that influences the decomposition. Zero corresponds to no
+            stretching present. Relatively insensitive and typically adjusted in powers of 10.
+        eta : float
+            The sparsity factor that influences the decomposition. Should be set to zero for
+            non-sparse data such as PDF. Can be used to improve results for sparse data such
+            as XRD, but due to instability, should be used only after first selecting the
+            best value for rho. Suggested adjustment is by powers of 2.
+        max_iter : int
+            The maximum number of times to update each of A, X, and Y before stopping
+            the optimization.
+        tol : float
+            The convergence threshold. The optimization terminates once the fractional
+            improvement in the objective function drops below this value. Note that
+            a minimum of 20 updates are run before this parameter is checked.
+        n_components : int
+            The number of components to extract from MM. Must be provided when, and only
+            when, Y0 is not provided.
+        random_state : int
+            The seed for the initial guesses at the matrices (A, X, and Y) created by
+            the decomposition.
+        """
+
         self.MM = MM
-        self.X0 = X0
-        self.Y0 = Y0
-        self.A = A
         self.rho = rho
         self.eta = eta
         # Capture matrix dimensions
-        self.N, self.M = MM.shape
+        self._N, self._M = MM.shape
         self.num_updates = 0
+        self._rng = np.random.default_rng(random_state)

+        # Enforce exclusive specification of n_components or Y0
+        if (n_components is None) == (Y0 is None):
+            raise ValueError("Must provide exactly one of Y0 or n_components, but not both.")
+
+        # Initialize Y0 and determine number of components
         if Y0 is None:
-            if components is None:
-                raise ValueError("Must provide either Y0 or a number of components.")
-            else:
-                self.K = components
-                self.Y0 = np.random.beta(a=2.5, b=1.5, size=(self.K, self.M))  # This is untested
+            self._K = n_components
+            self.Y = self._rng.beta(a=2.5, b=1.5, size=(self._K, self._M))
         else:
-            self.K = Y0.shape[0]
+            self._K = Y0.shape[0]
+            self.Y = Y0

-        # Initialize A, X0 if not provided
-        if self.A is None:
-            self.A = np.ones((self.K, self.M)) + np.random.randn(self.K, self.M) * 1e-3  # Small perturbation
-        if self.X0 is None:
-            self.X0 = np.random.rand(self.N, self.K)  # Ensures values in [0,1]
+        # Initialize A if not provided
+        if A0 is None:
+            self.A = np.ones((self._K, self._M)) + self._rng.normal(0, 1e-3, size=(self._K, self._M))
+        else:
+            self.A = A0
+
+        # Initialize X if not provided
+        if X0 is None:
+            self.X = self._rng.random((self._N, self._K))
+        else:
+            self.X = X0

-        # Initialize solution matrices to be iterated on
-        self.X = np.maximum(0, self.X0)
-        self.Y = np.maximum(0, self.Y0)
+        # Enforce non-negativity
+        self.X = np.maximum(0, self.X)
+        self.Y = np.maximum(0, self.Y)

         # Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals)
-        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self.M - 2, self.M))
+        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self._M - 2, self._M))
         self.PP = self.P.T @ self.P

         # Set up residual matrix, objective function, and history
         self.R = self.get_residual_matrix()
         self.objective_function = self.get_objective_function()
         self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]

         # Set up tracking variables for updateX()
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)

         regularization_term = 0.5 * rho * np.linalg.norm(self.P @ self.A.T, "fro") ** 2
         sparsity_term = eta * np.sum(np.sqrt(self.X))  # Square root penalty
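The substance of the change above is the move from module-level `np.random` calls to a `numpy.random.Generator` stored in `self._rng`. A standalone sketch (illustrative shapes and seed, not taken from the diff) of the reproducibility this buys:

```python
import numpy as np

# Two generators seeded identically produce identical draws, so two
# SNMFOptimizer instances built with the same random_state start from
# the same initial guesses for Y, A, and X.
rng_a = np.random.default_rng(0)
rng_b = np.random.default_rng(0)

Y_a = rng_a.beta(a=2.5, b=1.5, size=(2, 8))  # mirrors the Y initialization
Y_b = rng_b.beta(a=2.5, b=1.5, size=(2, 8))
assert np.array_equal(Y_a, Y_b)

# random_state=None falls back to fresh OS entropy, preserving the old
# non-reproducible behavior by default.
```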
@@ -83,53 +197,53 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500
         # loop to normalize X
         # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
         # reset difference trackers and initialize
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
         self.R = self.get_residual_matrix()
         self.objective_function = self.get_objective_function()
         self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]
         for norm_iter in range(100):
             self.updateX()
             self.R = self.get_residual_matrix()
             self.objective_function = self.get_objective_function()
             print(f"Objective function after normX: {self.objective_function:.5e}")
-            self.objective_history.append(self.objective_function)
-            self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+            self._objective_history.append(self.objective_function)
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
             if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
                 break
         # end of normalization (and program)
         # note that objective function may not fully recover after normalization, this is okay
         print("Finished optimization.")

     def optimize_loop(self):
-        self.preGraX = self.GraX.copy()
+        self._preGraX = self._GraX.copy()
         self.updateX()
         self.num_updates += 1
         self.R = self.get_residual_matrix()
         self.objective_function = self.get_objective_function()
         print(f"Objective function after updateX: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)
         if self.objective_difference is None:
-            self.objective_difference = self.objective_history[-1] - self.objective_function
+            self.objective_difference = self._objective_history[-1] - self.objective_function

         # Now we update Y
         self.updateY2()
         self.num_updates += 1
         self.R = self.get_residual_matrix()
         self.objective_function = self.get_objective_function()
         print(f"Objective function after updateY2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)

         self.updateA2()

         self.num_updates += 1
         self.R = self.get_residual_matrix()
         self.objective_function = self.get_objective_function()
         print(f"Objective function after updateA2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
-        self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+        self._objective_history.append(self.objective_function)
+        self.objective_difference = self._objective_history[-2] - self._objective_history[-1]

     def apply_interpolation(self, a, x, return_derivatives=False):
         """
@@ -401,36 +515,38 @@ def updateX(self):
         # Compute `AX` using the interpolation function
         AX, _, _ = self.apply_interpolation_matrix()  # Skip the other two outputs
         # Compute RA and RR
-        intermediate_RA = AX.flatten(order="F").reshape((self.N * self.M, self.K), order="F")
-        RA = intermediate_RA.sum(axis=1).reshape((self.N, self.M), order="F")
+        intermediate_RA = AX.flatten(order="F").reshape((self._N * self._M, self._K), order="F")
+        RA = intermediate_RA.sum(axis=1).reshape((self._N, self._M), order="F")
         RR = RA - self.MM
         # Compute gradient `GraX`
-        self.GraX = self.apply_transformation_matrix(R=RR).toarray()  # toarray equivalent of full, make non-sparse
+        self._GraX = self.apply_transformation_matrix(
+            R=RR
+        ).toarray()  # toarray equivalent of full, make non-sparse

         # Compute initial step size `L0`
         L0 = np.linalg.eigvalsh(self.Y.T @ self.Y).max() * np.max([self.A.max(), 1 / self.A.min()])
         # Compute adaptive step size `L`
-        if self.preX is None:
+        if self._preX is None:
             L = L0
         else:
-            num = np.sum((self.GraX - self.preGraX) * (self.X - self.preX))  # Element-wise multiplication
-            denom = np.linalg.norm(self.X - self.preX, "fro") ** 2  # Frobenius norm squared
+            num = np.sum((self._GraX - self._preGraX) * (self.X - self._preX))  # Element-wise multiplication
+            denom = np.linalg.norm(self.X - self._preX, "fro") ** 2  # Frobenius norm squared
             L = num / denom if denom > 0 else L0
             if L <= 0:
                 L = L0

         # Store our old X before updating because it is used in step selection
-        self.preX = self.X.copy()
+        self._preX = self.X.copy()

         while True:  # iterate updating X
-            x_step = self.preX - self.GraX / L
+            x_step = self._preX - self._GraX / L
             # Solve x^3 + p*x + q = 0 for the largest real root
             self.X = np.square(cubic_largest_real_root(-x_step, self.eta / (2 * L)))
             # Mask values that should be set to zero
             mask = self.X**2 * L / 2 - L * self.X * x_step + self.eta * np.sqrt(self.X) < 0
             self.X = mask * self.X

-            objective_improvement = self.objective_history[-1] - self.get_objective_function(
+            objective_improvement = self._objective_history[-1] - self.get_objective_function(
                 R=self.get_residual_matrix()
             )

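Only the attribute names changed in `updateX` above, but the adaptive step size it computes is worth isolating: the ratio resembles a Barzilai-Borwein estimate of the gradient's local Lipschitz constant, with `L0` as a fallback when the quotient is undefined or non-positive. A self-contained sketch with made-up arrays:

```python
import numpy as np

rng = np.random.default_rng(1)
X, preX = rng.random((4, 3)), rng.random((4, 3))        # current / previous X
GraX, preGraX = rng.random((4, 3)), rng.random((4, 3))  # current / previous gradients
L0 = 1.0  # stands in for the eigenvalue-based initial step size

# <dGrad, dX> / ||dX||_F^2: larger L means a more cautious update,
# since X is moved by GraX / L.
num = np.sum((GraX - preGraX) * (X - preX))
denom = np.linalg.norm(X - preX, "fro") ** 2
L = num / denom if denom > 0 else L0
if L <= 0:
    L = L0
```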
@@ -447,9 +563,9 @@ def updateY2(self):
         Updates Y using matrix operations, solving a quadratic program via `solve_mkr_box`.
         """

-        K = self.K
-        N = self.N
-        M = self.M
+        K = self._K
+        N = self._N
+        M = self._M

         for m in range(M):
             T = np.zeros((N, K))  # Initialize T as an (N, K) zero matrix
@@ -476,9 +592,9 @@ def regularize_function(self, A=None):
         if A is None:
             A = self.A

-        K = self.K
-        M = self.M
-        N = self.N
+        K = self._K
+        M = self._M
+        N = self._N

         # Compute interpolated matrices
         AX, TX, HX = self.apply_interpolation_matrix(A=A, return_derivatives=True)
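Finally, the tightened `Y0`/`n_components` logic is easiest to see in isolation. A hedged sketch of the guard's behavior (the toy `MM` is made up; the error message is the one quoted in the diff):

```python
import numpy as np

from snmf_class import SNMFOptimizer

MM = np.random.default_rng(3).random((50, 10))  # toy data

# Neither argument, or both at once, should be rejected before any work runs.
for kwargs in ({}, {"n_components": 2, "Y0": np.ones((2, 10))}):
    try:
        SNMFOptimizer(MM=MM, **kwargs)
    except ValueError as err:
        print(err)  # Must provide exactly one of Y0 or n_components, but not both.
```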
