#pragma once #include "./core/omp_config.h" #include "./utils/matrix.h" #include "./numerics/abs.h" namespace numerics{ // -------------- Serial ---------------- template bool matequal(const utils::Matrix& A, const utils::Matrix& B, double tol = 1e-9) { if (A.rows() != B.rows() || A.cols() != B.cols()) { return false; } bool decimal = std::is_floating_point::value; const uint64_t rows=A.rows(), cols=A.cols(); for (uint64_t i = 0; i < rows; ++i) for (uint64_t j = 0; j < cols; ++j) if (decimal) { if (numerics::abs(A(i,j) - B(i,j)) > static_cast(tol)){ return false; } } else { if (A(i,j) != B(i,j)){ return false; } } return true; } // -------------- Parallel ---------------- template bool matequal_omp(const utils::Matrix& A, const utils::Matrix& B, double tol = 1e-9) { if (A.rows() != B.rows() || A.cols() != B.cols()) { return false; } bool decimal = std::is_floating_point::value; bool eq = true; const uint64_t rows=A.rows(), cols=A.cols(); #pragma omp parallel for collapse(2) schedule(static) reduction(&&:eq) for (uint64_t i = 0; i < rows; ++i) for (uint64_t j = 0; j < cols; ++j) if (decimal) { eq = eq && (numerics::abs(A(i,j) - B(i,j)) <= static_cast(tol)); } else { eq = eq && (A(i,j) == B(i,j)); } return eq; } // -------------- Auto OpenMP ---------------- template bool matequal_auto(const utils::Matrix& A, const utils::Matrix& B, double tol = 1e-9) { if (A.rows() != B.rows() || A.cols() != B.cols()) { return false; } uint64_t work = A.rows() * A.cols(); #ifdef _OPENMP bool can_parallel = omp_config::omp_parallel_allowed(); uint64_t threads = static_cast(omp_get_max_threads()); #else bool can_parallel = false; uint64_t threads = 1; #endif if (can_parallel || work > threads * 4ull) { return matequal_omp(A,B,tol); } else{ // Safe fallback return matequal(A,B,tol); } } } // namespace numerics