138 lines
2.8 KiB
C++
138 lines
2.8 KiB
C++
#pragma once
|
|
#include "./utils/matrix.h"
|
|
#include "./core/omp_config.h"
|
|
|
|
|
|
namespace numerics {
|
|
|
|
template <typename T>
|
|
void inplace_eye(utils::Matrix<T>& A, uint64_t N = 0){
|
|
|
|
bool need_full_zero = true;
|
|
|
|
if (N != 0){
|
|
A.resize(N,N,T{0});
|
|
need_full_zero = false;
|
|
}else{
|
|
N = A.rows();
|
|
if (N != A.cols()) {
|
|
throw std::runtime_error("inplace_eye: non-square matrix");
|
|
}
|
|
}
|
|
// 1) Zero the whole matrix if we didn't just resize with zeros
|
|
if (need_full_zero){
|
|
for (uint64_t i = 0; i < N; ++i){
|
|
for (uint64_t j = 0; j < N; ++j){
|
|
if (i==j){
|
|
A(i,j) = T{1};
|
|
}else{
|
|
A(i,j) = T{0};
|
|
}
|
|
}
|
|
}
|
|
}else{
|
|
for (uint64_t i = 0; i < N; ++i){
|
|
A(i,i) = T{1};
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
|
|
template <typename T>
|
|
void inplace_eye_omp(utils::Matrix<T>& A, uint64_t N = 0){
|
|
|
|
bool need_full_zero = true;
|
|
|
|
if (N != 0){
|
|
A.resize(N,N,T{0});
|
|
need_full_zero = false;
|
|
}else{
|
|
N = A.rows();
|
|
if (N != A.cols()) {
|
|
throw std::runtime_error("inplace_eye_omp: non-square matrix");
|
|
}
|
|
}
|
|
|
|
// 1) Zero the whole matrix if we didn't just resize with zeros
|
|
if (need_full_zero){
|
|
T* ptr = A.data();
|
|
uint64_t NN = N*N;
|
|
#pragma omp parallel for schedule(static)
|
|
for (uint64_t i = 0; i < NN; ++i){
|
|
ptr[i] = T{0};
|
|
}
|
|
}
|
|
// 2) Set the diagonal to 1
|
|
#pragma omp parallel for schedule(static)
|
|
for (uint64_t i = 0; i < N; ++i){
|
|
A(i,i) = T{1};
|
|
}
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
utils::Matrix<T> eye(uint64_t N){
|
|
utils::Matrix<T> A;
|
|
inplace_eye(A, N);
|
|
return A;
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
utils::Matrix<T> eye_omp(uint64_t N){
|
|
utils::Matrix<T> A;
|
|
inplace_eye_omp(A, N);
|
|
return A;
|
|
}
|
|
|
|
template <typename T>
|
|
utils::Matrix<T> eye_omp_auto(uint64_t N){
|
|
|
|
uint64_t work = N*N;
|
|
utils::Matrix<T> A(N,N,T{0});
|
|
|
|
#ifdef _OPENMP
|
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
|
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
|
#else
|
|
bool can_parallel = false;
|
|
uint64_t threads = 1;
|
|
#endif
|
|
|
|
if (can_parallel || work > threads * 4ull) {
|
|
inplace_eye_omp(A, 0);
|
|
}
|
|
else{
|
|
// Safe fallback
|
|
inplace_eye(A, 0);
|
|
}
|
|
|
|
return A;
|
|
}
|
|
// Untested:
|
|
template <typename T>
|
|
void inplace_eye_omp_auto(utils::Matrix<T>& A, uint64_t N = 0){
|
|
|
|
uint64_t work = N*N;
|
|
|
|
#ifdef _OPENMP
|
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
|
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
|
#else
|
|
bool can_parallel = false;
|
|
uint64_t threads = 1;
|
|
#endif
|
|
|
|
if (can_parallel || work > threads * 4ull) {
|
|
inplace_eye_omp(A, 0);
|
|
}
|
|
else{
|
|
// Safe fallback
|
|
inplace_eye(A, 0);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
} // namespace numerics
|