Files
Flux-openbuild/test/test_transpose.cpp
2025-10-06 20:21:40 +00:00

151 lines
4.6 KiB
C++

#include "test_common.h"
//#include "./utils/matrix.h" // matrix.h, vector.h
#include "./numerics/transpose.h" // numerics::transpose / inplace_transpose
using utils::Mi; using utils::Mf; using utils::Md;
// ---- helpers ----
// Overwrites every element of M with an arithmetic sequence laid out in
// row-major order: the element at row-major position k (counting from 0)
// receives start + step * k.
//   M     - matrix to overwrite; its current shape is used as-is.
//   start - first value of the sequence (defaults to T(0)).
//   step  - difference between consecutive row-major positions (defaults to T(1)).
template <typename T>
static void fill_seq(utils::Matrix<T>& M, T start = T(0), T step = T(1)) {
    const auto n_rows = M.rows();
    const auto n_cols = M.cols();
    for (std::uint64_t row = 0; row < n_rows; ++row) {
        for (std::uint64_t col = 0; col < n_cols; ++col) {
            // Row-major linear position of (row, col).
            const std::uint64_t pos = row * n_cols + col;
            M(row, col) = start + step * static_cast<T>(pos);
        }
    }
}
// Exact element-wise comparison of two matrices.
// Returns false when the shapes differ or when any pair of corresponding
// elements compares unequal with operator!=; returns true otherwise.
template <typename T>
static bool mats_equal(const utils::Matrix<T>& X, const utils::Matrix<T>& Y) {
    const bool same_shape = (X.rows() == Y.rows()) && (X.cols() == Y.cols());
    if (!same_shape) {
        return false;
    }
    for (std::uint64_t row = 0; row < X.rows(); ++row) {
        for (std::uint64_t col = 0; col < X.cols(); ++col) {
            if (X(row, col) != Y(row, col)) {
                return false;
            }
        }
    }
    return true;
}
// ---- tests ----
// Empty and 1x1 edge cases
TEST_CASE(Transpose_Edges) {
    // Default-constructed matrix: its transpose is expected to report 0 rows and 0 cols.
    utils::Mi empty_in;
    auto empty_out = numerics::transpose(empty_in);
    CHECK(empty_out.rows()==0 && empty_out.cols()==0, "transpose of empty should be empty");

    // Single-element matrix: both the shape and the stored value must survive.
    utils::Mi single_in(1,1,42);
    auto single_out = numerics::transpose(single_in);
    CHECK(single_out.rows()==1 && single_out.cols()==1, "1x1 stays 1x1");
    CHECK(single_out(0,0)==42, "1x1 value preserved");
}
// Rectangular out-of-place
TEST_CASE(Transpose_Rectangular) {
    const std::uint64_t n_rows = 3;
    const std::uint64_t n_cols = 5;

    // Source holds 1, 2, 3, ... in row-major order so every element is distinct.
    utils::Mi src(n_rows, n_cols, 0);
    fill_seq(src, int64_t{1}, int64_t{1});

    auto flipped = numerics::transpose(src);
    CHECK(flipped.rows()==n_cols && flipped.cols()==n_rows, "shape swapped");

    // Every source element (row, col) must appear at (col, row) in the result.
    for (std::uint64_t row = 0; row < n_rows; ++row) {
        for (std::uint64_t col = 0; col < n_cols; ++col) {
            CHECK(flipped(col,row)==src(row,col), "transpose content mismatch");
        }
    }
}
// Square: in-place equals out-of-place
TEST_CASE(Transpose_Inplace_Equals_OutOfPlace) {
    const std::uint64_t dim = 7;
    utils::Mi src(dim, dim, 0);
    fill_seq(src, int64_t{10}, int64_t{3});

    // Reference result from the out-of-place routine.
    auto expected = numerics::transpose(src);

    // Transpose a copy in place so src itself is not the object being modified.
    auto work = src;
    numerics::inplace_transpose_square(work);

    CHECK(mats_equal(expected, work), "inplace transpose should match out-of-place");
}
// In-place should throw on non-square
TEST_CASE(Transpose_Inplace_Throws_On_Rect) {
    // 2x3 input: not square, so the in-place routine is expected to reject it.
    utils::Mi rect(2,3,0);

    bool caught = false;
    try {
        numerics::inplace_transpose_square(rect);
    } catch (const std::runtime_error&) {
        // Only std::runtime_error (or a type derived from it) counts as the expected failure.
        caught = true;
    }

    CHECK(caught, "inplace_transpose_square must throw on non-square");
}
// --- OMP variants (compile only with -fopenmp) ---
#ifdef _OPENMP
TEST_CASE(Transpose_OMP_OutOfPlace_Equals_Serial) {
    const std::uint64_t n_rows = 17;
    const std::uint64_t n_cols = 31;

    utils::Mi src(n_rows, n_cols, 0);
    fill_seq(src, int64_t{5}, int64_t{2});

    // The serial routine is the reference; the OMP variant must reproduce it exactly.
    auto from_serial = numerics::transpose(src);
    auto from_omp = numerics::transpose_omp(src);

    CHECK(mats_equal(from_serial, from_omp), "transpose_omp != transpose");
}
TEST_CASE(Transpose_OMP_Inplace_Equals_Serial) {
    const std::uint64_t dim = 32;
    utils::Mi src(dim, dim, 0);
    fill_seq(src, int64_t{0}, int64_t{1});

    // Reference result from the serial out-of-place routine.
    auto expected = numerics::transpose(src);

    // Run the OMP in-place variant on a copy and compare against the reference.
    auto work = src;
    numerics::inplace_transpose_square_omp(work);

    CHECK(mats_equal(expected, work), "inplace_transpose_square_omp != transpose");
}
// Auto selectors (if you added transpose_auto / inplace_transpose_square_auto)
TEST_CASE(Transpose_Auto_Equals_Serial) {
    // Part 1 - rectangular input through transpose_auto.
    utils::Mi rect(23,11,0);
    fill_seq(rect, int64_t{1}, int64_t{1});
    auto rect_expected = numerics::transpose(rect);
    auto rect_actual = numerics::transpose_auto(rect);
    CHECK(mats_equal(rect_expected, rect_actual), "transpose_auto != transpose");

    // Part 2 - square input through inplace_transpose_square_auto.
    // The reference is taken before the in-place call overwrites the matrix.
    utils::Mi square(29,29,0);
    fill_seq(square, int64_t{4}, int64_t{7});
    auto square_expected = numerics::transpose(square);
    numerics::inplace_transpose_square_auto(square);
    CHECK(mats_equal(square_expected, square), "inplace_transpose_square_auto != transpose");
}
// Nested callsite sanity: call OMP versions inside an outer region
TEST_CASE(Transpose_OMP_Nested_Callsite) {
    // Scoped override of OpenMP's max-active-levels setting. The constructor
    // records the current value and installs the requested one; the destructor
    // puts the recorded value back. This restores the setting on every exit
    // path from this test (early return, exception), not only when control
    // reaches the end of the body, so later tests do not inherit the override.
    struct ActiveLevelsGuard {
        explicit ActiveLevelsGuard(int levels)
            : saved_(omp_get_max_active_levels()) {
            omp_set_max_active_levels(levels);
        }
        ~ActiveLevelsGuard() { omp_set_max_active_levels(saved_); }
        ActiveLevelsGuard(const ActiveLevelsGuard&) = delete;
        ActiveLevelsGuard& operator=(const ActiveLevelsGuard&) = delete;
    private:
        int saved_;
    };

    // Out-of-place on rectangular
    utils::Mi A(19,37,0);
    fill_seq(A, int64_t(2), int64_t(3));
    auto Bref = numerics::transpose(A);

    // Permit two active levels for the rest of the test.
    // NOTE(review): presumably the *_omp routines open their own parallel
    // region, which then nests inside the outer one below - confirm.
    ActiveLevelsGuard levels_guard(2);

    utils::Mi Bnest;
    #pragma omp parallel num_threads(2)
    {
        // Exactly one thread of the outer team performs the call and the assignment.
        #pragma omp single
        {
            Bnest = numerics::transpose_omp(A);
        }
    }
    CHECK(mats_equal(Bref, Bnest), "nested transpose_omp mismatch");

    // In-place on square; the reference is taken before S is overwritten.
    utils::Mi S(41,41,0);
    fill_seq(S, int64_t(1), int64_t(5));
    auto Sref = numerics::transpose(S);
    #pragma omp parallel num_threads(2)
    {
        #pragma omp single
        {
            numerics::inplace_transpose_square_omp(S);
        }
    }
    CHECK(mats_equal(Sref, S), "nested inplace_transpose_square_omp mismatch");
}
#endif