Files
2025-10-06 20:14:13 +00:00

138 lines
2.8 KiB
C++

#pragma once
#include <cstdint>
#include <stdexcept>
#ifdef _OPENMP
#include <omp.h>
#endif

#include "./utils/matrix.h"
#include "./core/omp_config.h"
namespace numerics {

namespace detail {
/// Shared heuristic for the *_auto variants: decide whether an identity fill
/// touching `work` elements should take the OpenMP path.
///
/// The OpenMP path is chosen only when parallelism is BOTH permitted
/// (omp_config::omp_parallel_allowed() returns true) AND worthwhile (more than
/// 4 elements per thread reported by omp_get_max_threads()). Without OpenMP
/// support this always returns false.
inline bool eye_use_omp(uint64_t work) {
#ifdef _OPENMP
    const bool can_parallel = omp_config::omp_parallel_allowed();
    const uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
    return can_parallel && work > threads * 4ull;
#else
    (void)work;
    return false;
#endif
}
} // namespace detail

/// Overwrite `A` with the identity matrix (serial implementation).
///
/// @param A  Matrix overwritten in place.
/// @param N  If non-zero, `A` is first resized to N x N. If zero (default),
///           `A` keeps its current shape, which must be square.
/// @throws std::runtime_error if N == 0 and `A` is not square.
template <typename T>
void inplace_eye(utils::Matrix<T>& A, uint64_t N = 0) {
    bool need_full_write = true;
    if (N != 0) {
        // Only a matrix that held no elements before the resize can be trusted
        // to be all-zero afterwards. A pre-existing matrix may keep stale values,
        // depending on how utils::Matrix::resize treats surviving elements.
        const bool was_empty = (A.rows() == 0 || A.cols() == 0);
        A.resize(N, N, T{0});
        need_full_write = !was_empty;
    } else {
        N = A.rows();
        if (N != A.cols()) {
            throw std::runtime_error("inplace_eye: non-square matrix");
        }
    }
    if (need_full_write) {
        // Single pass: every element receives its final value.
        for (uint64_t i = 0; i < N; ++i) {
            for (uint64_t j = 0; j < N; ++j) {
                A(i, j) = (i == j) ? T{1} : T{0};
            }
        }
    } else {
        // Newly created elements were filled with T{0} by resize (assumed from
        // the fill argument), so only the diagonal needs writing.
        for (uint64_t i = 0; i < N; ++i) {
            A(i, i) = T{1};
        }
    }
}

/// Overwrite `A` with the identity matrix, using OpenMP loops when compiled
/// with OpenMP (the pragmas are ignored otherwise).
///
/// @param A  Matrix overwritten in place.
/// @param N  If non-zero, `A` is first resized to N x N. If zero (default),
///           `A` keeps its current shape, which must be square.
/// @throws std::runtime_error if N == 0 and `A` is not square.
template <typename T>
void inplace_eye_omp(utils::Matrix<T>& A, uint64_t N = 0) {
    bool need_full_zero = true;
    if (N != 0) {
        // See inplace_eye: skip the zero pass only for storage that did not
        // exist before the resize.
        const bool was_empty = (A.rows() == 0 || A.cols() == 0);
        A.resize(N, N, T{0});
        need_full_zero = !was_empty;
    } else {
        N = A.rows();
        if (N != A.cols()) {
            throw std::runtime_error("inplace_eye_omp: non-square matrix");
        }
    }
    // 1) Zero every element unless the storage was just created zero-filled.
    //    Assumes A.data() exposes N*N contiguous elements (TODO confirm
    //    against utils::Matrix).
    if (need_full_zero) {
        T* const ptr = A.data();
        const uint64_t count = N * N;
#pragma omp parallel for schedule(static)
        for (uint64_t k = 0; k < count; ++k) {
            ptr[k] = T{0};
        }
    }
    // 2) Set the diagonal to 1. Each iteration writes a distinct element.
#pragma omp parallel for schedule(static)
    for (uint64_t i = 0; i < N; ++i) {
        A(i, i) = T{1};
    }
}

/// Return a new N x N identity matrix built serially.
template <typename T>
utils::Matrix<T> eye(uint64_t N) {
    utils::Matrix<T> A;
    inplace_eye(A, N);
    return A;
}

/// Return a new N x N identity matrix built with inplace_eye_omp.
template <typename T>
utils::Matrix<T> eye_omp(uint64_t N) {
    utils::Matrix<T> A;
    inplace_eye_omp(A, N);
    return A;
}

/// Return a new N x N identity matrix. detail::eye_use_omp chooses between
/// the OpenMP fill and the serial fill.
template <typename T>
utils::Matrix<T> eye_omp_auto(uint64_t N) {
    utils::Matrix<T> A(N, N, T{0});
    if (detail::eye_use_omp(N * N)) {
        inplace_eye_omp(A, 0);
    } else {
        // Safe fallback
        inplace_eye(A, 0);
    }
    return A;
}

/// Overwrite `A` with the identity matrix. detail::eye_use_omp chooses
/// between the OpenMP fill and the serial fill.
///
/// @param A  Matrix overwritten in place.
/// @param N  If non-zero, `A` is first resized to N x N. If zero (default),
///           `A` keeps its current shape, which must be square.
/// @throws std::runtime_error if N == 0 and `A` is not square.
// TODO: no test coverage yet.
template <typename T>
void inplace_eye_omp_auto(utils::Matrix<T>& A, uint64_t N = 0) {
    // Size the heuristic by the matrix that will actually be written: the
    // requested N, or A's current row count when N == 0.
    const uint64_t dim = (N != 0) ? N : A.rows();
    if (detail::eye_use_omp(dim * dim)) {
        inplace_eye_omp(A, N);
    } else {
        // Safe fallback
        inplace_eye(A, N);
    }
}

} // namespace numerics