Files
Flux-openbuild/include/numerics/matvec.h
2025-10-06 20:14:13 +00:00

156 lines
4.0 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#ifndef _matvec_n_
#define _matvec_n_
#include "./utils/matrix.h"
#include "./core/omp_config.h"
namespace numerics{
// =================================================
// y = A * x (matrix-vector product)
// =================================================
/// @brief Serial dense matrix-vector product, y = A * x.
///
/// @tparam T Element type shared by the matrix and the vectors.
/// @param A Left operand; its column count must equal x.size().
/// @param x Right operand.
/// @return Vector with A.rows() entries, where entry r is the sum over c of
///         A(r, c) * x[c], accumulated in ascending c starting from T{0}.
/// @throws std::runtime_error if A.cols() != x.size().
template <typename T>
utils::Vector<T> matvec(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
    const bool dims_agree = (A.cols() == x.size());
    if (!dims_agree) {
        throw std::runtime_error("matvec: dimension mismatch");
    }

    const uint64_t row_count = A.rows();
    const uint64_t col_count = A.cols();

    // One zero-initialised output entry per matrix row.
    utils::Vector<T> result(row_count, T{0});

    for (uint64_t r = 0; r != row_count; ++r) {
        for (uint64_t c = 0; c != col_count; ++c) {
            result[r] += A(r, c) * x[c];
        }
    }
    return result;
}
// -------------- Row-parallel OpenMP ----------------
/// @brief Matrix-vector product y = A * x with the rows distributed by OpenMP.
///
/// Reads the matrix through A.data() and indexes it as A(i, j) == data[i * n + j],
/// i.e. this routine relies on the storage being row-major and contiguous.
///
/// @tparam T Element type shared by the matrix and the vectors.
/// @param A Left operand; its column count must equal x.size().
/// @param x Right operand.
/// @return Vector with A.rows() entries; entry r is the dot product of row r with x.
/// @throws std::runtime_error if A.cols() != x.size().
template <typename T>
utils::Vector<T> matvec_omp(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
    const bool dims_agree = (A.cols() == x.size());
    if (!dims_agree) {
        throw std::runtime_error("matvec: dimension mismatch");
    }

    const uint64_t row_count = A.rows();
    const uint64_t col_count = A.cols();

    // Output has one entry per matrix row.
    utils::Vector<T> result(row_count, T{0});

    const T* x_data = x.data();
    const T* a_data = A.data();

    // Rows are independent of each other, so each iteration writes a
    // distinct result[r] and no synchronisation is required.
    #pragma omp parallel for schedule(static)
    for (uint64_t r = 0; r < row_count; ++r) {
        const T* a_row = a_data + r * col_count;  // start of contiguous row r
        T sum = T{0};
        #pragma omp simd reduction(+:sum)
        for (uint64_t c = 0; c < col_count; ++c) {
            sum += a_row[c] * x_data[c];
        }
        result[r] = sum;
    }
    return result;
}
// -------------- Auto OpenMP ----------------
/// @brief Computes y = A * x, dispatching to the OpenMP or the serial kernel.
///
/// matvec_omp is selected only when BOTH conditions hold:
///   1. omp_config::omp_parallel_allowed() returns true, and
///   2. the element count rows * cols exceeds 4 per thread reported by
///      omp_get_max_threads() (1 thread is assumed when built without OpenMP).
/// In every other case the serial matvec is used.
///
/// The previous condition joined the two tests with `||`, which sent the call
/// to the OpenMP kernel whenever the work threshold was met even though
/// omp_parallel_allowed() had returned false, and for arbitrarily small
/// inputs whenever it returned true.
///
/// NOTE(review): omp_parallel_allowed() is defined outside this file; it is
/// presumed to report whether opening a parallel region is currently
/// permitted — confirm against omp_config.
///
/// @tparam T Element type shared by the matrix and the vectors.
/// @param A Left operand.
/// @param x Right operand.
/// @return The result of matvec_omp(A, x) or matvec(A, x).
/// @throws std::runtime_error propagated from the selected kernel when
///         A.cols() != x.size().
template <typename T>
utils::Vector<T> matvec_auto(const utils::Matrix<T>& A,
                             const utils::Vector<T>& x) {
    // Widen before multiplying so the product is formed in 64 bits even if
    // rows()/cols() return a narrower integer type.
    const uint64_t work = static_cast<uint64_t>(A.rows()) * A.cols();
    const bool can_parallel = omp_config::omp_parallel_allowed();
#ifdef _OPENMP
    const int threads = omp_get_max_threads();
#else
    const int threads = 1;
#endif
    // Parallelise only when it is permitted AND there is enough work.
    if (can_parallel && work > static_cast<uint64_t>(threads) * 4ull) {
        return matvec_omp(A, x);
    }
    // Safe fallback: serial kernel.
    return matvec(A, x);
}
// =================================================
// y = x * A (vector-matrix product)
// =================================================
/// @brief Serial dense vector-matrix product, y = x * A.
///
/// The loops run row by row (outer i, inner j) so that A is read along its
/// rows. This assumes utils::Matrix is row-major, as the A.data() indexing in
/// matvec_omp implies; under that layout the earlier column-by-column order
/// read A with a stride of one full row per step. Each y[j] still receives
/// its terms in ascending i, so the accumulated values are unchanged.
///
/// @tparam T Element type shared by the matrix and the vectors.
/// @param x Left operand; its size must equal A.rows().
/// @param A Right operand.
/// @return Vector with A.cols() entries, where entry j is the sum over i of
///         x[i] * A(i, j), accumulated in ascending i starting from T{0}.
/// @throws std::runtime_error if x.size() != A.rows().
template <typename T>
utils::Vector<T> vecmat(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
    if (x.size() != A.rows()) {
        throw std::runtime_error("vecmat: dimension mismatch");
    }
    const uint64_t m = A.rows();
    const uint64_t n = A.cols();
    utils::Vector<T> y(n, T{0});
    for (uint64_t i = 0; i < m; ++i) {
        // x[i] is invariant across the inner loop; read it once per row.
        const T& xi = x[i];
        for (uint64_t j = 0; j < n; ++j) {
            y[j] += xi * A(i, j);
        }
    }
    return y;
}
// -------------- Column-parallel OpenMP ----------------
/// @brief Vector-matrix product y = x * A with the output columns distributed by OpenMP.
///
/// @tparam T Element type shared by the matrix and the vectors.
/// @param x Left operand; its size must equal A.rows().
/// @param A Right operand.
/// @return Vector with A.cols() entries, where entry c is the sum over r of
///         x[r] * A(r, c), accumulated in ascending r starting from T{0}.
/// @throws std::runtime_error if x.size() != A.rows().
template <typename T>
utils::Vector<T> vecmat_omp(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
    const bool dims_agree = (x.size() == A.rows());
    if (!dims_agree) {
        throw std::runtime_error("vecmat: dimension mismatch");
    }

    const uint64_t row_count = A.rows();
    const uint64_t col_count = A.cols();

    // Output has one entry per matrix column.
    utils::Vector<T> result(col_count, T{0});

    // Each iteration owns exactly one output entry result[c], so the
    // iterations do not write to shared locations.
    #pragma omp parallel for schedule(static)
    for (uint64_t c = 0; c < col_count; ++c) {
        T sum = T{0};
        for (uint64_t r = 0; r < row_count; ++r) {
            sum += x[r] * A(r, c);
        }
        result[c] = sum;
    }
    return result;
}
// -------------- Auto OpenMP ----------------
/// @brief Computes y = x * A, dispatching to the OpenMP or the serial kernel.
///
/// vecmat_omp is selected only when BOTH conditions hold:
///   1. omp_config::omp_parallel_allowed() returns true, and
///   2. the element count rows * cols exceeds 4 per thread reported by
///      omp_get_max_threads() (1 thread is assumed when built without OpenMP).
/// In every other case the serial vecmat is used.
///
/// The previous condition joined the two tests with `||`, which sent the call
/// to the OpenMP kernel whenever the work threshold was met even though
/// omp_parallel_allowed() had returned false, and for arbitrarily small
/// inputs whenever it returned true.
///
/// NOTE(review): omp_parallel_allowed() is defined outside this file; it is
/// presumed to report whether opening a parallel region is currently
/// permitted — confirm against omp_config.
///
/// @tparam T Element type shared by the matrix and the vectors.
/// @param x Left operand.
/// @param A Right operand.
/// @return The result of vecmat_omp(x, A) or vecmat(x, A).
/// @throws std::runtime_error propagated from the selected kernel when
///         x.size() != A.rows().
template <typename T>
utils::Vector<T> vecmat_auto(const utils::Vector<T>& x,
                             const utils::Matrix<T>& A) {
    // Widen before multiplying so the product is formed in 64 bits even if
    // rows()/cols() return a narrower integer type.
    const uint64_t work = static_cast<uint64_t>(A.rows()) * A.cols();
    const bool can_parallel = omp_config::omp_parallel_allowed();
#ifdef _OPENMP
    const int threads = omp_get_max_threads();
#else
    const int threads = 1;
#endif
    // Parallelise only when it is permitted AND there is enough work.
    if (can_parallel && work > static_cast<uint64_t>(threads) * 4ull) {
        return vecmat_omp(x, A);
    }
    // Safe fallback: serial kernel.
    return vecmat(x, A);
}
} // namespace numerics
#endif // _matvec_n_