Skip to content

Commit 1539577

Browse files
authored
Merge pull request #349 from rust-ndarray/lax-tridiagonal
Add `LuTridiagonalWork`, merge `Tridiagonal_` into `Lapack`
2 parents ad19250 + 4913818 commit 1539577

File tree

7 files changed

+412
-270
lines changed

7 files changed

+412
-270
lines changed

lax/src/lib.rs

+42-11
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,14 @@ extern crate openblas_src as _src;
8484
#[cfg(any(feature = "netlib-system", feature = "netlib-static"))]
8585
extern crate netlib_src as _src;
8686

87-
pub mod error;
88-
pub mod flags;
89-
pub mod layout;
90-
87+
pub mod alloc;
9188
pub mod cholesky;
9289
pub mod eig;
9390
pub mod eigh;
9491
pub mod eigh_generalized;
92+
pub mod error;
93+
pub mod flags;
94+
pub mod layout;
9595
pub mod least_squares;
9696
pub mod opnorm;
9797
pub mod qr;
@@ -101,16 +101,12 @@ pub mod solveh;
101101
pub mod svd;
102102
pub mod svddc;
103103
pub mod triangular;
104+
pub mod tridiagonal;
104105

105-
mod alloc;
106-
mod tridiagonal;
107-
108-
pub use self::cholesky::*;
109106
pub use self::flags::*;
110107
pub use self::least_squares::LeastSquaresOwned;
111-
pub use self::opnorm::*;
112108
pub use self::svd::{SvdOwned, SvdRef};
113-
pub use self::tridiagonal::*;
109+
pub use self::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};
114110

115111
use self::{alloc::*, error::*, layout::*};
116112
use cauchy::*;
@@ -120,7 +116,7 @@ pub type Pivot = Vec<i32>;
120116

121117
#[cfg_attr(doc, katexit::katexit)]
122118
/// Trait for primitive types which implements LAPACK subroutines
123-
pub trait Lapack: Tridiagonal_ {
119+
pub trait Lapack: Scalar {
124120
/// Compute right eigenvalue and eigenvectors for a general matrix
125121
fn eig(
126122
calc_v: bool,
@@ -306,6 +302,19 @@ pub trait Lapack: Tridiagonal_ {
306302
a: &[Self],
307303
b: &mut [Self],
308304
) -> Result<()>;
305+
306+
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
307+
/// partial pivoting with row interchanges.
308+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
309+
310+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
311+
312+
fn solve_tridiagonal(
313+
lu: &LUFactorizedTridiagonal<Self>,
314+
bl: MatrixLayout,
315+
t: Transpose,
316+
b: &mut [Self],
317+
) -> Result<()>;
309318
}
310319

311320
macro_rules! impl_lapack {
@@ -491,6 +500,28 @@ macro_rules! impl_lapack {
491500
use triangular::*;
492501
SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b)
493502
}
503+
504+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
505+
use tridiagonal::*;
506+
let work = LuTridiagonalWork::<$s>::new(a.l);
507+
work.eval(a)
508+
}
509+
510+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
511+
use tridiagonal::*;
512+
let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l);
513+
work.calc(lu)
514+
}
515+
516+
fn solve_tridiagonal(
517+
lu: &LUFactorizedTridiagonal<Self>,
518+
bl: MatrixLayout,
519+
t: Transpose,
520+
b: &mut [Self],
521+
) -> Result<()> {
522+
use tridiagonal::*;
523+
SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b)
524+
}
494525
}
495526
};
496527
}

lax/src/tridiagonal.rs

-259
This file was deleted.

0 commit comments

Comments
 (0)