Files
Flux-openbuild/include/numerics/transpose.h
2025-10-06 20:14:13 +00:00

156 lines
3.3 KiB
C++

#ifndef _transpose_n_
#define _transpose_n_
#include "./utils/matrix.h"
#include "./core/omp_config.h"
namespace numerics{
// Transposes a square matrix in place (serial version).
//
// Every element strictly above the main diagonal is exchanged with its
// mirror element below the diagonal; diagonal elements are left untouched.
//
// A      : matrix to transpose; modified in place.
// throws : std::runtime_error when A.rows() != A.cols().
template <typename T>
void inplace_transpose_square(utils::Matrix<T>& A){
    const uint64_t n_rows = A.rows();
    const uint64_t n_cols = A.cols();
    if (n_rows != n_cols){
        throw std::runtime_error("inplace_transpose only valid for square matrices");
    }

    for (uint64_t r = 0; r < n_rows; ++r){
        // Start at r + 1 so each off-diagonal pair is exchanged exactly once.
        for (uint64_t c = r + 1; c < n_cols; ++c){
            T lower = A(c, r);
            A(c, r) = A(r, c);
            A(r, c) = lower;
        }
    }
}
// Transposes a square matrix in place, distributing the outer (row) loop
// across threads with an OpenMP `parallel for` using a static schedule.
//
// Outer iteration r only touches the pairs (r, c) / (c, r) with c > r, so
// two different outer iterations never exchange the same pair of indices.
// NOTE(review): this relies on A(i, j) referring to distinct storage for
// distinct index pairs - confirm against utils::Matrix.
//
// A      : matrix to transpose; modified in place.
// throws : std::runtime_error when A.rows() != A.cols() (raised before the
//          parallel loop is entered).
template <typename T>
void inplace_transpose_square_omp(utils::Matrix<T>& A){
    const uint64_t n_rows = A.rows();
    const uint64_t n_cols = A.cols();
    if (n_rows != n_cols){
        throw std::runtime_error("inplace_transpose only valid for square matrices");
    }

    // Row r performs (n_cols - r - 1) exchanges, so later rows carry less
    // work than earlier ones.
    #pragma omp parallel for schedule(static)
    for (uint64_t r = 0; r < n_rows; ++r){
        for (uint64_t c = r + 1; c < n_cols; ++c){
            T lower = A(c, r);
            A(c, r) = A(r, c);
            A(r, c) = lower;
        }
    }
}
// Returns a new matrix holding the transpose of A (serial version).
//
// The result is constructed from (A.cols(), A.rows(), T{0}) - presumably
// (row count, column count, fill value); confirm against utils::Matrix -
// and then filled so that result(c, r) == A(r, c) for every element of A.
//
// A       : source matrix; not modified.
// returns : the transposed copy.
template <typename T>
utils::Matrix<T> transpose(const utils::Matrix<T>& A){
    const uint64_t src_rows = A.rows();
    const uint64_t src_cols = A.cols();

    utils::Matrix<T> result(src_cols, src_rows, T{0});
    for (uint64_t r = 0; r < src_rows; ++r){
        for (uint64_t c = 0; c < src_cols; ++c){
            result(c, r) = A(r, c);
        }
    }
    return result;
}
// Returns a new matrix holding the transpose of A, with the copy loop
// parallelised by OpenMP.
//
// Both loops are collapsed into a single iteration space (collapse(2)) and
// divided among threads with a static schedule. Each (r, c) iteration
// writes only result(c, r).
//
// A       : source matrix; not modified.
// returns : the transposed copy.
template <typename T>
utils::Matrix<T> transpose_omp(const utils::Matrix<T>& A){
    const uint64_t src_rows = A.rows();
    const uint64_t src_cols = A.cols();

    utils::Matrix<T> result(src_cols, src_rows, T{0});

    // The two loops must stay perfectly nested for collapse(2) to apply.
    #pragma omp parallel for collapse(2) schedule(static)
    for (uint64_t r = 0; r < src_rows; ++r){
        for (uint64_t c = 0; c < src_cols; ++c){
            result(c, r) = A(r, c);
        }
    }
    return result;
}
// -------- Auto selectors --------
// In-place square transpose that picks between the serial and the OpenMP
// implementation.
//
// The OpenMP variant is used only when the build has OpenMP enabled,
// omp_config::omp_parallel_allowed() returns true, and the number of element
// exchanges exceeds four times omp_get_max_threads(). In every other case
// the serial variant runs.
//
// A      : matrix to transpose; modified in place.
// throws : std::runtime_error when A.rows() != A.cols().
template <typename T>
void inplace_transpose_square_auto(utils::Matrix<T>& A) {
    const uint64_t n_rows = A.rows();
    const uint64_t n_cols = A.cols();
    if (n_rows != n_cols) {
        throw std::runtime_error("inplace_transpose_auto: only valid for square matrices");
    }

    // Count of off-diagonal pairs, i.e. how many exchanges will be performed.
    const std::uint64_t swap_count = static_cast<std::uint64_t>((n_rows * (n_rows - 1)) / 2);

#ifdef _OPENMP
    bool parallel_ok = omp_config::omp_parallel_allowed();
    uint64_t max_threads = static_cast<std::uint64_t>(omp_get_max_threads());
#else
    bool parallel_ok = false;
    uint64_t max_threads = 1;
#endif

    const bool use_omp = parallel_ok && swap_count > max_threads * 4ull;
    if (!use_omp) {
        inplace_transpose_square(A);
        return;
    }
    inplace_transpose_square_omp(A);
}
// Returns the transpose of A, choosing an implementation automatically.
//
// Square input: A is copied and the copy is transposed in place through
// inplace_transpose_square_auto, which makes its own serial/OpenMP choice.
//
// Non-square input: transpose_omp is used only when the build has OpenMP
// enabled, omp_config::omp_parallel_allowed() returns true, and the element
// count exceeds four times omp_get_max_threads(). Otherwise the serial
// transpose is used.
//
// A       : source matrix; not modified.
// returns : the transposed matrix.
template <typename T>
utils::Matrix<T> transpose_auto(const utils::Matrix<T>& A) {
    const uint64_t rows = A.rows();
    const uint64_t cols = A.cols();

    if (rows == cols){
        utils::Matrix<T> B = A;
        inplace_transpose_square_auto(B);
        return B;
    }

    // One element copy per entry of A.
    const uint64_t work = rows * cols;

#ifdef _OPENMP
    const bool can_parallel = omp_config::omp_parallel_allowed();
    const uint64_t threads = static_cast<std::uint64_t>(omp_get_max_threads());
#else
    const bool can_parallel = false;
    const uint64_t threads = 1;
#endif

    // Go parallel only when it is both permitted and worthwhile. The previous
    // condition (`!can_parallel || work > ...`) was inverted and dispatched to
    // the OpenMP variant precisely when parallelism was disallowed.
    if (can_parallel && work > threads * 4ull) {
        return transpose_omp(A);
    }
    return transpose(A);
}
} // namespace numerics
#endif // _transpose_n_