#ifndef _matmul_n_ #define _matmul_n_ #include "./utils/matrix.h" #include "./core/omp_config.h" namespace numerics{ // ---------------- Serial baseline ---------------- template utils::Matrix matmul(const utils::Matrix& A, const utils::Matrix& B){ if(A.cols() != B.rows()){ throw std::runtime_error("matmul: dimension mismatch"); } const uint64_t m = A.rows(); const uint64_t n = A.cols(); // also B.rows() const uint64_t p = B.cols(); T tmp; utils::Matrix C(m, p, T{0}); for (uint64_t i = 0; i < m; ++i){ for (uint64_t j = 0; j < n; ++j){ tmp = A(i,j); for (uint64_t k = 0; k < p; ++k){ C(i,k) += tmp * B(j,k); } } } return C; } // ---------------- Rows-only OpenMP ---------------- template utils::Matrix matmul_rows_omp(const utils::Matrix& A, const utils::Matrix& B) { if (A.cols() != B.rows()) throw std::runtime_error("matmul_rows_omp: dim mismatch"); const uint64_t m=A.rows(), n=A.cols(), p=B.cols(); utils::Matrix C(m, p, T{0}); #pragma omp parallel for schedule(static) for (uint64_t i=0;i utils::Matrix matmul_collapse_omp(const utils::Matrix& A, const utils::Matrix& B) { if (A.cols() != B.rows()) throw std::runtime_error("matmul_collapse_omp: dim mismatch"); const uint64_t m=A.rows(), n=A.cols(), p=B.cols(); utils::Matrix C(m, p, T{0}); #pragma omp parallel for collapse(2) schedule(static) for (uint64_t i=0;i utils::Matrix matmul_auto(const utils::Matrix& A, const utils::Matrix& B) { const uint64_t m=A.rows(), p=B.cols(); const uint64_t work = m * p; #ifdef _OPENMP bool can_parallel = omp_config::omp_parallel_allowed(); uint64_t threads = static_cast(omp_get_max_threads()); #else bool can_parallel = false; uint64_t threads = 1; #endif // Tiny problems: serial is cheapest. if (!can_parallel || work < threads*4ull) { return matmul(A,B); } // Plenty of (i,j) work → collapse(2) is a great default. else if (work >= 8ull * threads) { return matmul_collapse_omp(A,B); } // Many rows and very few columns → rows-only cheaper overhead. else if (m >= static_cast(threads) && p <= 4) { return matmul_rows_omp(A,B); } else{ // Safe fallback return matmul(A,B); } } } // namespace numerics #endif // _matmul_n_