#pragma once #include "./utils/matrix.h" #include "./core/omp_config.h" namespace numerics { template void inplace_eye(utils::Matrix& A, uint64_t N = 0){ bool need_full_zero = true; if (N != 0){ A.resize(N,N,T{0}); need_full_zero = false; }else{ N = A.rows(); if (N != A.cols()) { throw std::runtime_error("inplace_eye: non-square matrix"); } } // 1) Zero the whole matrix if we didn't just resize with zeros if (need_full_zero){ for (uint64_t i = 0; i < N; ++i){ for (uint64_t j = 0; j < N; ++j){ if (i==j){ A(i,j) = T{1}; }else{ A(i,j) = T{0}; } } } }else{ for (uint64_t i = 0; i < N; ++i){ A(i,i) = T{1}; } } } template void inplace_eye_omp(utils::Matrix& A, uint64_t N = 0){ bool need_full_zero = true; if (N != 0){ A.resize(N,N,T{0}); need_full_zero = false; }else{ N = A.rows(); if (N != A.cols()) { throw std::runtime_error("inplace_eye_omp: non-square matrix"); } } // 1) Zero the whole matrix if we didn't just resize with zeros if (need_full_zero){ T* ptr = A.data(); uint64_t NN = N*N; #pragma omp parallel for schedule(static) for (uint64_t i = 0; i < NN; ++i){ ptr[i] = T{0}; } } // 2) Set the diagonal to 1 #pragma omp parallel for schedule(static) for (uint64_t i = 0; i < N; ++i){ A(i,i) = T{1}; } } template utils::Matrix eye(uint64_t N){ utils::Matrix A; inplace_eye(A, N); return A; } template utils::Matrix eye_omp(uint64_t N){ utils::Matrix A; inplace_eye_omp(A, N); return A; } template utils::Matrix eye_omp_auto(uint64_t N){ uint64_t work = N*N; utils::Matrix A(N,N,T{0}); #ifdef _OPENMP bool can_parallel = omp_config::omp_parallel_allowed(); uint64_t threads = static_cast(omp_get_max_threads()); #else bool can_parallel = false; uint64_t threads = 1; #endif if (can_parallel || work > threads * 4ull) { inplace_eye_omp(A, 0); } else{ // Safe fallback inplace_eye(A, 0); } return A; } // Untested: template void inplace_eye_omp_auto(utils::Matrix& A, uint64_t N = 0){ uint64_t work = N*N; #ifdef _OPENMP bool can_parallel = omp_config::omp_parallel_allowed(); uint64_t threads = static_cast(omp_get_max_threads()); #else bool can_parallel = false; uint64_t threads = 1; #endif if (can_parallel || work > threads * 4ull) { inplace_eye_omp(A, 0); } else{ // Safe fallback inplace_eye(A, 0); } } } // namespace utils