Sync public subset from Flux (private)
This commit is contained in:
124
include/numerics/matmul.h
Normal file
124
include/numerics/matmul.h
Normal file
@@ -0,0 +1,124 @@
|
||||
#ifndef _matmul_n_
|
||||
#define _matmul_n_
|
||||
|
||||
|
||||
#include "./utils/matrix.h"
|
||||
#include "./core/omp_config.h"
|
||||
|
||||
|
||||
namespace numerics{
|
||||
|
||||
// ---------------- Serial baseline ----------------
|
||||
template <typename T>
|
||||
utils::Matrix<T> matmul(const utils::Matrix<T>& A, const utils::Matrix<T>& B){
|
||||
|
||||
if(A.cols() != B.rows()){
|
||||
throw std::runtime_error("matmul: dimension mismatch");
|
||||
}
|
||||
|
||||
const uint64_t m = A.rows();
|
||||
const uint64_t n = A.cols(); // also B.rows()
|
||||
const uint64_t p = B.cols();
|
||||
T tmp;
|
||||
|
||||
utils::Matrix<T> C(m, p, T{0});
|
||||
|
||||
for (uint64_t i = 0; i < m; ++i){
|
||||
for (uint64_t j = 0; j < n; ++j){
|
||||
tmp = A(i,j);
|
||||
for (uint64_t k = 0; k < p; ++k){
|
||||
C(i,k) += tmp * B(j,k);
|
||||
}
|
||||
}
|
||||
}
|
||||
return C;
|
||||
}
|
||||
|
||||
// ---------------- Rows-only OpenMP ----------------
|
||||
template <typename T>
|
||||
utils::Matrix<T> matmul_rows_omp(const utils::Matrix<T>& A,
|
||||
const utils::Matrix<T>& B) {
|
||||
if (A.cols() != B.rows()) throw std::runtime_error("matmul_rows_omp: dim mismatch");
|
||||
const uint64_t m=A.rows(), n=A.cols(), p=B.cols();
|
||||
|
||||
utils::Matrix<T> C(m, p, T{0});
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (uint64_t i=0;i<m;++i) {
|
||||
for (uint64_t j=0;j<p;++j) {
|
||||
T acc=T{0};
|
||||
for (uint64_t k=0;k<n;++k) {
|
||||
acc += A(i,k)*B(k,j);
|
||||
}
|
||||
C(i,j)=acc;
|
||||
}
|
||||
}
|
||||
return C;
|
||||
}
|
||||
|
||||
// -------------- Collapse(2) OpenMP ----------------
|
||||
template <typename T>
|
||||
utils::Matrix<T> matmul_collapse_omp(const utils::Matrix<T>& A,
|
||||
const utils::Matrix<T>& B) {
|
||||
if (A.cols() != B.rows()) throw std::runtime_error("matmul_collapse_omp: dim mismatch");
|
||||
const uint64_t m=A.rows(), n=A.cols(), p=B.cols();
|
||||
utils::Matrix<T> C(m, p, T{0});
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(static)
|
||||
for (uint64_t i=0;i<m;++i) {
|
||||
for (uint64_t j=0;j<p;++j) {
|
||||
T acc=T{0};
|
||||
for (uint64_t k=0;k<n;++k){
|
||||
acc += A(i,k)*B(k,j);
|
||||
}
|
||||
C(i,j)=acc;
|
||||
}
|
||||
}
|
||||
return C;
|
||||
}
|
||||
|
||||
|
||||
// -------------------- Auto selector ---------------------
|
||||
template <typename T>
|
||||
utils::Matrix<T> matmul_auto(const utils::Matrix<T>& A,
|
||||
const utils::Matrix<T>& B) {
|
||||
const uint64_t m=A.rows(), p=B.cols();
|
||||
const uint64_t work = m * p;
|
||||
|
||||
|
||||
|
||||
#ifdef _OPENMP
|
||||
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||
#else
|
||||
bool can_parallel = false;
|
||||
uint64_t threads = 1;
|
||||
#endif
|
||||
|
||||
|
||||
// Tiny problems: serial is cheapest.
|
||||
if (!can_parallel || work < threads*4ull) {
|
||||
|
||||
return matmul(A,B);
|
||||
}
|
||||
// Plenty of (i,j) work → collapse(2) is a great default.
|
||||
else if (work >= 8ull * threads) {
|
||||
return matmul_collapse_omp(A,B);
|
||||
}
|
||||
// Many rows and very few columns → rows-only cheaper overhead.
|
||||
else if (m >= static_cast<uint64_t>(threads) && p <= 4) {
|
||||
return matmul_rows_omp(A,B);
|
||||
}
|
||||
else{
|
||||
// Safe fallback
|
||||
return matmul(A,B);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace numerics
|
||||
|
||||
#endif // _matmul_n_
|
||||
Reference in New Issue
Block a user