Sync public subset from Flux (private)
This commit is contained in:
138
include/numerics/initializers/eye.h
Normal file
138
include/numerics/initializers/eye.h
Normal file
@@ -0,0 +1,138 @@
|
||||
#pragma once
|
||||
#include "./utils/matrix.h"
|
||||
#include "./core/omp_config.h"
|
||||
|
||||
|
||||
namespace numerics {
|
||||
|
||||
template <typename T>
|
||||
void inplace_eye(utils::Matrix<T>& A, uint64_t N = 0){
|
||||
|
||||
bool need_full_zero = true;
|
||||
|
||||
if (N != 0){
|
||||
A.resize(N,N,T{0});
|
||||
need_full_zero = false;
|
||||
}else{
|
||||
N = A.rows();
|
||||
if (N != A.cols()) {
|
||||
throw std::runtime_error("inplace_eye: non-square matrix");
|
||||
}
|
||||
}
|
||||
// 1) Zero the whole matrix if we didn't just resize with zeros
|
||||
if (need_full_zero){
|
||||
for (uint64_t i = 0; i < N; ++i){
|
||||
for (uint64_t j = 0; j < N; ++j){
|
||||
if (i==j){
|
||||
A(i,j) = T{1};
|
||||
}else{
|
||||
A(i,j) = T{0};
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
for (uint64_t i = 0; i < N; ++i){
|
||||
A(i,i) = T{1};
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void inplace_eye_omp(utils::Matrix<T>& A, uint64_t N = 0){
|
||||
|
||||
bool need_full_zero = true;
|
||||
|
||||
if (N != 0){
|
||||
A.resize(N,N,T{0});
|
||||
need_full_zero = false;
|
||||
}else{
|
||||
N = A.rows();
|
||||
if (N != A.cols()) {
|
||||
throw std::runtime_error("inplace_eye_omp: non-square matrix");
|
||||
}
|
||||
}
|
||||
|
||||
// 1) Zero the whole matrix if we didn't just resize with zeros
|
||||
if (need_full_zero){
|
||||
T* ptr = A.data();
|
||||
uint64_t NN = N*N;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (uint64_t i = 0; i < NN; ++i){
|
||||
ptr[i] = T{0};
|
||||
}
|
||||
}
|
||||
// 2) Set the diagonal to 1
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (uint64_t i = 0; i < N; ++i){
|
||||
A(i,i) = T{1};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
utils::Matrix<T> eye(uint64_t N){
|
||||
utils::Matrix<T> A;
|
||||
inplace_eye(A, N);
|
||||
return A;
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
utils::Matrix<T> eye_omp(uint64_t N){
|
||||
utils::Matrix<T> A;
|
||||
inplace_eye_omp(A, N);
|
||||
return A;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
utils::Matrix<T> eye_omp_auto(uint64_t N){
|
||||
|
||||
uint64_t work = N*N;
|
||||
utils::Matrix<T> A(N,N,T{0});
|
||||
|
||||
#ifdef _OPENMP
|
||||
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||
#else
|
||||
bool can_parallel = false;
|
||||
uint64_t threads = 1;
|
||||
#endif
|
||||
|
||||
if (can_parallel || work > threads * 4ull) {
|
||||
inplace_eye_omp(A, 0);
|
||||
}
|
||||
else{
|
||||
// Safe fallback
|
||||
inplace_eye(A, 0);
|
||||
}
|
||||
|
||||
return A;
|
||||
}
|
||||
// Untested:
|
||||
template <typename T>
|
||||
void inplace_eye_omp_auto(utils::Matrix<T>& A, uint64_t N = 0){
|
||||
|
||||
uint64_t work = N*N;
|
||||
|
||||
#ifdef _OPENMP
|
||||
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||
#else
|
||||
bool can_parallel = false;
|
||||
uint64_t threads = 1;
|
||||
#endif
|
||||
|
||||
if (can_parallel || work > threads * 4ull) {
|
||||
inplace_eye_omp(A, 0);
|
||||
}
|
||||
else{
|
||||
// Safe fallback
|
||||
inplace_eye(A, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace utils
|
||||
Reference in New Issue
Block a user