Skip to content

Commit 85ea86d

Browse files
committed
🎨 Standarize code ./lib directory
1 parent 2a38e7c commit 85ea86d

32 files changed

+1839
-1959
lines changed

lib/nmatrix.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@
2525
# This file is a stub that only loads the main NMatrix file.
2626
#
2727

28-
require 'nmatrix/nmatrix.rb'
28+
require "nmatrix/nmatrix.rb"

lib/nmatrix/atlas.rb

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@
2727
# nice ruby interfaces for ATLAS functions.
2828
#++
2929

30-
require 'nmatrix/nmatrix.rb'
31-
#need to have nmatrix required first or else bad things will happen
32-
require_relative 'lapack_ext_common'
30+
require "nmatrix/nmatrix.rb"
31+
# need to have nmatrix required first or else bad things will happen
32+
require_relative "lapack_ext_common"
3333

3434
NMatrix.register_lapack_extension("nmatrix-atlas")
3535

3636
require "nmatrix_atlas.so"
3737

3838
class NMatrix
39-
40-
#Add functions from the ATLAS C extension to the main LAPACK and BLAS modules.
41-
#This will overwrite the original functions where applicable.
39+
# Add functions from the ATLAS C extension to the main LAPACK and BLAS modules.
40+
# This will overwrite the original functions where applicable.
4241
module LAPACK
4342
class << self
4443
NMatrix::ATLAS::LAPACK.singleton_methods.each do |m|
@@ -68,8 +67,8 @@ def posv(uplo, a, b)
6867
unless a.stype == :dense && b.stype == :dense
6968

7069
raise(DataTypeError, "only works for non-integer, non-object dtypes") \
71-
if a.integer_dtype? || a.object_dtype? || \
72-
b.integer_dtype? || b.object_dtype?
70+
if a.integer_dtype? || a.object_dtype? || \
71+
b.integer_dtype? || b.object_dtype?
7372

7473
x = b.clone
7574
clone = a.clone
@@ -83,78 +82,81 @@ def posv(uplo, a, b)
8382
x.transpose
8483
end
8584

86-
def geev(matrix, which=:both)
85+
def geev(matrix, which = :both)
8786
raise(StorageTypeError, "LAPACK functions only work on dense matrices") \
8887
unless matrix.dense?
8988

9089
raise(ShapeError, "eigenvalues can only be computed for square matrices") \
9190
unless matrix.dim == 2 && matrix.shape[0] == matrix.shape[1]
9291

93-
jobvl = (which == :both || which == :left) ? :t : false
94-
jobvr = (which == :both || which == :right) ? :t : false
92+
jobvl = which == :both || which == :left ? :t : false
93+
jobvr = which == :both || which == :right ? :t : false
9594

9695
n = matrix.shape[0]
9796

9897
# Outputs
9998
eigenvalues = NMatrix.new([n, 1], dtype: matrix.dtype)
100-
# For real dtypes this holds only the real part of the eigenvalues.
99+
# For real dtypes this holds only the real part of the eigenvalues.
101100
imag_eigenvalues = matrix.complex_dtype? ? nil : NMatrix.new([n, 1], \
102-
dtype: matrix.dtype) # For complex dtypes, this is unused.
101+
dtype: matrix.dtype) # For complex dtypes, this is unused.
103102
left_output = jobvl ? matrix.clone_structure : nil
104103
right_output = jobvr ? matrix.clone_structure : nil
105104

106105
# lapack_geev is a pure LAPACK routine so it expects column-major matrices,
107106
# so we need to transpose the input as well as the output.
108107
temporary_matrix = matrix.transpose
109-
NMatrix::LAPACK::lapack_geev(jobvl, # compute left eigenvectors of A?
110-
jobvr, # compute right eigenvectors of A? (left eigenvectors of A**T)
111-
n, # order of the matrix
112-
temporary_matrix,# input matrix (used as work)
113-
n, # leading dimension of matrix
114-
eigenvalues,# real part of computed eigenvalues
115-
imag_eigenvalues,# imag part of computed eigenvalues
116-
left_output, # left eigenvectors, if applicable
117-
n, # leading dimension of left_output
118-
right_output, # right eigenvectors, if applicable
119-
n, # leading dimension of right_output
120-
2*n)
108+
NMatrix::LAPACK.lapack_geev(jobvl, # compute left eigenvectors of A?
109+
jobvr, # compute right eigenvectors of A? (left eigenvectors of A**T)
110+
n, # order of the matrix
111+
temporary_matrix, # input matrix (used as work)
112+
n, # leading dimension of matrix
113+
eigenvalues, # real part of computed eigenvalues
114+
imag_eigenvalues, # imag part of computed eigenvalues
115+
left_output, # left eigenvectors, if applicable
116+
n, # leading dimension of left_output
117+
right_output, # right eigenvectors, if applicable
118+
n, # leading dimension of right_output
119+
2 * n)
121120
left_output = left_output.transpose if jobvl
122121
right_output = right_output.transpose if jobvr
123122

124-
125123
# For real dtypes, transform left_output and right_output into correct forms.
126124
# If the j'th and the (j+1)'th eigenvalues form a complex conjugate
127125
# pair, then the j'th and (j+1)'th columns of the matrix are
128126
# the real and imag parts of the eigenvector corresponding
129127
# to the j'th eigenvalue.
130-
if !matrix.complex_dtype?
128+
unless matrix.complex_dtype?
131129
complex_indices = []
132130
n.times do |i|
133131
complex_indices << i if imag_eigenvalues[i] != 0.0
134132
end
135133

136-
if !complex_indices.empty?
134+
unless complex_indices.empty?
137135
# For real dtypes, put the real and imaginary parts together
138-
eigenvalues = eigenvalues + imag_eigenvalues * \
139-
Complex(0.0,1.0)
140-
left_output = left_output.cast(dtype: \
141-
NMatrix.upcast(:complex64, matrix.dtype)) if left_output
142-
right_output = right_output.cast(dtype: NMatrix.upcast(:complex64, \
143-
matrix.dtype)) if right_output
136+
eigenvalues += imag_eigenvalues * \
137+
Complex(0.0, 1.0)
138+
if left_output
139+
left_output = left_output.cast(dtype: \
140+
NMatrix.upcast(:complex64, matrix.dtype))
141+
end
142+
if right_output
143+
right_output = right_output.cast(dtype: NMatrix.upcast(:complex64, \
144+
matrix.dtype))
145+
end
144146
end
145147

146148
complex_indices.each_slice(2) do |i, _|
147149
if right_output
148-
right_output[0...n,i] = right_output[0...n,i] + \
149-
right_output[0...n,i+1] * Complex(0.0,1.0)
150-
right_output[0...n,i+1] = \
151-
right_output[0...n,i].complex_conjugate
150+
right_output[0...n, i] = right_output[0...n, i] + \
151+
right_output[0...n, i + 1] * Complex(0.0, 1.0)
152+
right_output[0...n, i + 1] = \
153+
right_output[0...n, i].complex_conjugate
152154
end
153155

154156
if left_output
155-
left_output[0...n,i] = left_output[0...n,i] + \
156-
left_output[0...n,i+1] * Complex(0.0,1.0)
157-
left_output[0...n,i+1] = left_output[0...n,i].complex_conjugate
157+
left_output[0...n, i] = left_output[0...n, i] + \
158+
left_output[0...n, i + 1] * Complex(0.0, 1.0)
159+
left_output[0...n, i + 1] = left_output[0...n, i].complex_conjugate
158160
end
159161
end
160162
end
@@ -168,7 +170,7 @@ def geev(matrix, which=:both)
168170
end
169171
end
170172

171-
def gesvd(matrix, workspace_size=1)
173+
def gesvd(matrix, workspace_size = 1)
172174
result = alloc_svd_result(matrix)
173175

174176
m = matrix.shape[0]
@@ -177,16 +179,16 @@ def gesvd(matrix, workspace_size=1)
177179
# This is a pure LAPACK function so it expects column-major functions.
178180
# So we need to transpose the input as well as the output.
179181
matrix = matrix.transpose
180-
NMatrix::LAPACK::lapack_gesvd(:a, :a, m, n, matrix, \
181-
m, result[1], result[0], m, result[2], n, workspace_size)
182+
NMatrix::LAPACK.lapack_gesvd(:a, :a, m, n, matrix, \
183+
m, result[1], result[0], m, result[2], n, workspace_size)
182184
result[0] = result[0].transpose
183185
result[2] = result[2].transpose
184186
result
185187
end
186188

187-
def gesdd(matrix, workspace_size=nil)
189+
def gesdd(matrix, workspace_size = nil)
188190
min_workspace_size = matrix.shape.min * \
189-
(6 + 4 * matrix.shape.min) + matrix.shape.max
191+
(6 + 4 * matrix.shape.min) + matrix.shape.max
190192
workspace_size = min_workspace_size if \
191193
workspace_size.nil? || workspace_size < min_workspace_size
192194

@@ -198,8 +200,8 @@ def gesdd(matrix, workspace_size=nil)
198200
# This is a pure LAPACK function so it expects column-major functions.
199201
# So we need to transpose the input as well as the output.
200202
matrix = matrix.transpose
201-
NMatrix::LAPACK::lapack_gesdd(:a, m, n, matrix, m, result[1], \
202-
result[0], m, result[2], n, workspace_size)
203+
NMatrix::LAPACK.lapack_gesdd(:a, m, n, matrix, m, result[1], \
204+
result[0], m, result[2], n, workspace_size)
203205
result[0] = result[0].transpose
204206
result[2] = result[2].transpose
205207
result
@@ -209,36 +211,36 @@ def gesdd(matrix, workspace_size=nil)
209211

210212
def invert!
211213
raise(StorageTypeError, "invert only works on dense matrices currently") \
212-
unless self.dense?
214+
unless dense?
213215

214216
raise(ShapeError, "Cannot invert non-square matrix") \
215217
unless shape[0] == shape[1]
216218

217219
raise(DataTypeError, "Cannot invert an integer matrix in-place") \
218-
if self.integer_dtype?
220+
if integer_dtype?
219221

220222
# Even though we are using the ATLAS plugin, we still might be missing
221223
# CLAPACK (and thus clapack_getri) if we are on OS X.
222224
if NMatrix.has_clapack?
223225
# Get the pivot array; factor the matrix
224226
# We can't used getrf! here since it doesn't have the clapack behavior,
225227
# so it doesn't play nicely with clapack_getri
226-
n = self.shape[0]
227-
pivot = NMatrix::LAPACK::clapack_getrf(:row, n, n, self, n)
228+
n = shape[0]
229+
pivot = NMatrix::LAPACK.clapack_getrf(:row, n, n, self, n)
228230
# Now calculate the inverse using the pivot array
229-
NMatrix::LAPACK::clapack_getri(:row, n, self, n, pivot)
231+
NMatrix::LAPACK.clapack_getri(:row, n, self, n, pivot)
230232
self
231233
else
232-
__inverse__(self,true)
234+
__inverse__(self, true)
233235
end
234236
end
235237

236238
def potrf!(which)
237239
raise(StorageTypeError, "ATLAS functions only work on dense matrices") \
238-
unless self.dense?
240+
unless dense?
239241
raise(ShapeError, "Cholesky decomposition only valid for square matrices") \
240-
unless self.dim == 2 && self.shape[0] == self.shape[1]
242+
unless dim == 2 && shape[0] == shape[1]
241243

242-
NMatrix::LAPACK::clapack_potrf(:row, which, self.shape[0], self, self.shape[1])
244+
NMatrix::LAPACK.clapack_potrf(:row, which, shape[0], self, shape[1])
243245
end
244246
end

0 commit comments

Comments
 (0)