Sync public subset from Flux (private)
This commit is contained in:
180
include/decomp/lu.h
Normal file
180
include/decomp/lu.h
Normal file
@@ -0,0 +1,180 @@
|
||||
#pragma once
|
||||
|
||||
#include "./utils/vector.h"
|
||||
#include "./utils/matrix.h"
|
||||
|
||||
#include "./numerics/initializers/eye.h"
|
||||
|
||||
namespace decomp{
|
||||
|
||||
// Stores PA = LU with partial pivoting (row permutations).
|
||||
template <typename T>
|
||||
struct LUdcmp{
|
||||
uint64_t rows; // Stores number of rows.
|
||||
utils::Matrix<T> lu; // Stores the decomposition.
|
||||
std::vector<uint64_t> indx; // Stores the permutation.
|
||||
T d; // Used by det.
|
||||
|
||||
// Default Constructor
|
||||
LUdcmp() = default;
|
||||
|
||||
// Constructor
|
||||
LUdcmp(const utils::Matrix<T>& A){
|
||||
decomp(A);
|
||||
}
|
||||
|
||||
void decomp(const utils::Matrix<T>&A){
|
||||
|
||||
rows = A.rows();
|
||||
|
||||
if (rows != A.cols()){
|
||||
throw std::runtime_error("LUdcmp: decomp non-square");
|
||||
}
|
||||
|
||||
uint64_t imax{0};
|
||||
lu = A;
|
||||
indx.resize(rows);
|
||||
std::vector<T> vv(rows); // vv stores the implicit scaling of each row.
|
||||
T big{T{0}}, tmp{T{0}};// TINY{T{1.0e-40}};
|
||||
|
||||
d = T{1}; // No row interchanges yet.
|
||||
|
||||
// Loop over rows to get the implicit scaling information.
|
||||
for (uint64_t i = 0; i < rows; ++i){
|
||||
big = T{0};
|
||||
for (uint64_t j = 0; j < rows; ++j){
|
||||
tmp = std::abs(lu(i,j));
|
||||
if (tmp > big){
|
||||
big = tmp;
|
||||
}
|
||||
}
|
||||
if (big == T{0}){
|
||||
throw std::runtime_error("LUdcmp: Singular matrix");
|
||||
}
|
||||
// No nonzero largest element.
|
||||
vv[i] = T{1}/big; // Save the scaling.
|
||||
}
|
||||
// This is the outermost kij loop. Initialize for the search for largest pivot element.
|
||||
for (uint64_t k = 0; k < rows; ++k){
|
||||
big = T{0};
|
||||
imax = k;
|
||||
for (uint64_t i = k; i < rows; ++i){
|
||||
tmp = vv[i] * static_cast<T>(std::fabs(static_cast<double>(lu(i,k))));
|
||||
if (tmp > big){ // Is the figure of merit for the pivot better than the best so far?
|
||||
big = tmp;
|
||||
imax = i;
|
||||
}
|
||||
}
|
||||
if (k != imax){ // Do we need to interchange rows?
|
||||
lu.swap_rows(imax, k); // Yes, do so...
|
||||
d = -d; // ...and change the parity of d.
|
||||
vv[imax] = vv[k]; // Also interchange the scale factor.
|
||||
}
|
||||
indx[k] = imax;
|
||||
if (lu(k,k) == T{0.0}){ // if the pivot element is zero, the matrix is singular (at least to the precision of thealgorithm).
|
||||
throw std::runtime_error("LUdcmp: Singular matrix");
|
||||
//lu(k,k) = TINY; // For some applications on singular matrices, it is desirable to substitute TINY for zero.
|
||||
}
|
||||
for (uint64_t i = k+1; i < rows; ++i){
|
||||
tmp = lu(i,k) /= lu(k,k); // Divide by the pivot element.
|
||||
for (uint64_t j = k+1; j < rows; ++j){ // Innermost loop: reduce remaining submatrix.
|
||||
lu(i,j) -= tmp*lu(k,j);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end void decomp(const utils::Matrix<T>&A)
|
||||
|
||||
// Solves the set of n linear equations A*x=b using the stored LU decomposition of A.
|
||||
void inplace_solve(const utils::Vector<T>& b, utils::Vector<T>& x){
|
||||
T sum{T{0}};
|
||||
|
||||
uint64_t ii{0}, ip{0};
|
||||
|
||||
if (b.size() != rows || x.size() != rows){
|
||||
throw std::runtime_error("LUdcmp: inplace_solve bad sizes");
|
||||
}
|
||||
x = b;
|
||||
|
||||
for (uint64_t i = 0; i < rows; ++i){ // When ii is set to a positive value, it will become the index of the first nonvanishing element of b.
|
||||
ip = indx[i];
|
||||
sum = x[ip];
|
||||
x[ip] = x[i];
|
||||
if (ii >= 0){
|
||||
for (uint64_t j = ii; j < i; ++j){
|
||||
sum -= lu(i,j)*x[j];
|
||||
}
|
||||
}else if (sum != T{0}) { // A nonzero element was encountered, so from now on we will have to do the sums in the loop above.
|
||||
ii = i+1;
|
||||
}
|
||||
x[i] = sum;
|
||||
}
|
||||
for (int64_t i = static_cast<int64_t>(rows)-1; i >= 0; --i){ // Now we do the backsubstitution.
|
||||
sum=x[i];
|
||||
for (uint64_t j = static_cast<uint64_t>(i)+1; j < rows; ++j){
|
||||
sum -= lu(static_cast<uint64_t>(i),j)*x[j];
|
||||
}
|
||||
x[static_cast<uint64_t>(i)] = sum/lu(static_cast<uint64_t>(i),static_cast<uint64_t>(i)); // Store a component of the solution vector x.
|
||||
}
|
||||
} // end inplace_solve(utils::Vector<T>& b, utils::Vector<T>& x)
|
||||
|
||||
// SSolves m sets of n linear equations A*X=B using the stored LU decomposition of A.
|
||||
void inplace_solve(const utils::Matrix<T>& B, utils::Matrix<T>& X) {
|
||||
|
||||
uint64_t m{B.cols()};
|
||||
|
||||
if (B.rows() != rows || X.rows() != rows || B.cols() != X.cols()){
|
||||
throw std::runtime_error("LUdcmp: inplace_solve bad sizes");
|
||||
}
|
||||
|
||||
utils::Vector<T> xx(rows);
|
||||
|
||||
for (uint64_t j = 0; j < m; ++j){ // Copy and solve each column in turn.
|
||||
|
||||
xx = B.get_col(j);
|
||||
inplace_solve(xx,xx);
|
||||
X.set_col(j, xx);
|
||||
}
|
||||
|
||||
} // end inplace_solve(utils::Matrix<T>& B, utils::Matrix<T>& X)
|
||||
|
||||
// Solves the set of n linear equations A*x=b using the stored LU decomposition of A.
|
||||
utils::Vector<T> solve(const utils::Vector<T>& b) {
|
||||
utils::Vector<T> x(rows,T{0});
|
||||
inplace_solve(b, x);
|
||||
return x;
|
||||
}
|
||||
|
||||
// Solves the set of n linear equations A*X=B using the stored LU decomposition of A.
|
||||
utils::Matrix<T> solve(const utils::Matrix<T>& B) {
|
||||
utils::Matrix<T> X(rows,B.cols(),T{0});
|
||||
inplace_solve(B, X);
|
||||
return X;
|
||||
}
|
||||
|
||||
void inplace_inverse(utils::Matrix<T>& Ainv){
|
||||
numerics::inplace_eye<T>(Ainv);
|
||||
inplace_solve(Ainv, Ainv);
|
||||
}
|
||||
|
||||
utils::Matrix<T> inverse(){
|
||||
utils::Matrix<T> Ainv;
|
||||
inplace_inverse(Ainv);
|
||||
return Ainv;
|
||||
}
|
||||
|
||||
T det(){
|
||||
T dd = d;
|
||||
for (uint64_t i = 0; i < rows; ++i){
|
||||
dd *= lu(i,i);
|
||||
}
|
||||
return dd;
|
||||
}
|
||||
|
||||
|
||||
}; // struct LU
|
||||
|
||||
typedef LUdcmp<float> LUdcmpf;
|
||||
typedef LUdcmp<double> LUdcmpd;
|
||||
|
||||
|
||||
} // namespace decomp
|
||||
Reference in New Issue
Block a user