// Source: Flux-openbuild/test/test_vector.cpp (2025-10-06 20:21:40 +00:00; 196 lines, 5.8 KiB, C++)

#include "test_common.h"
#include "./utils/utils.h"
using utils::Vf; using utils::Vd; using utils::Vi;
// ---------- helpers ----------
/// Exact element-by-element comparison of two vectors.
///
/// Returns true only when both vectors report the same size() and every pair
/// of corresponding elements compares equal under the element type's own
/// operator!= (no tolerance is applied here).
template <typename T>
static bool vec_equal_exact(const utils::Vector<T>& a, const utils::Vector<T>& b) {
    // Different lengths can never match.
    if (a.size() != b.size()) {
        return false;
    }

    std::uint64_t pos = 0;
    while (pos < a.size()) {
        if (a[pos] != b[pos]) {
            return false;  // first mismatch decides the result
        }
        ++pos;
    }
    return true;
}
// ---------- tests ----------
TEST_CASE(Vector_Construct_Size_Fill) {
    // A default-constructed vector is expected to report zero elements.
    utils::Vd empty_vec;
    CHECK(empty_vec.size()==0, "default size should be 0");

    // The (count, value) constructor is expected to yield `count` copies of `value`.
    utils::Vd filled(5, 3.5);
    CHECK(filled.size()==5, "size 5");
    for (std::uint64_t idx = 0; idx < 5; ++idx)
        CHECK(filled[idx]==3.5, "filled value 3.5");
}
TEST_CASE(Vector_PushBack_Resize) {
    // Grow an initially empty integer vector one element at a time.
    utils::Vi ints;
    ints.push_back(1);
    ints.push_back(2);
    CHECK(ints.size()==2, "push_back size");
    CHECK(ints[0]==1 && ints[1]==2, "push_back contents");

    // Growing via resize(count, value): the newly added slots (indices 2..4)
    // are expected to hold the supplied fill value.
    ints.resize(5, 7);
    CHECK(ints.size()==5, "resize size");
    CHECK(ints[2]==7 && ints[3]==7 && ints[4]==7, "resize fill value");
}
TEST_CASE(Vector_Data_ReadWrite) {
    utils::Vd values(4, 0.0);

    // Store 1, 2, 3, 4 through the mutable pointer returned by data().
    double* raw = values.data();
    for (std::uint64_t k = 0; k < 4; ++k) {
        raw[k] = static_cast<double>(k + 1);
    }
    CHECK(values[0]==1.0 && values[3]==4.0, "write via data()");

    // Read the same storage back through data() called on a const reference
    // and accumulate the elements; 1 + 2 + 3 + 4 = 10.
    const utils::Vd& const_view = values;
    const double* const_raw = const_view.data();
    double total = 0.0;
    for (const double* it = const_raw; it != const_raw + const_view.size(); ++it) {
        total += *it;
    }
    CHECK(std::fabs(total-10.0) < 1e-12, "read via const data()");
}
TEST_CASE(Vector_Equality_and_NearlyEqual) {
    utils::Vd lhs(3, 1.0);
    utils::Vd rhs(3, 1.0);
    CHECK(lhs==rhs, "operator== equal");

    // A 5e-7 offset is expected to still compare equal (the check message
    // names an eps of 1e-6 for operator==).
    rhs[1] += 5e-7;
    CHECK(lhs==rhs, "operator== within eps (1e-6)");

    // A further 2e-6 (2.5e-6 in total) is expected to break equality.
    rhs[1] += 2e-6;
    CHECK(!(lhs==rhs), "operator== beyond eps");

    // nearly_equal_vec with an explicit tolerance of 1e-9:
    // a 1e-10 offset should pass, an additional 1e-6 should fail.
    utils::Vd probe = lhs;
    probe[2] += 1e-10;
    CHECK(probe.nearly_equal_vec(lhs, 1e-9), "nearly_equal_vec within tol");
    probe[2] += 1e-6;
    CHECK(!probe.nearly_equal_vec(lhs, 1e-9), "nearly_equal_vec beyond tol");
}
TEST_CASE(Vector_Scalar_Arithmetic) {
    // Expected contents of each vector are tracked in the trailing comments.
    utils::Vi base(3, 10);                 // base    = [10,10,10]

    // --- addition ---
    auto plus_two = base + 2;              // result  = [12,12,12]
    CHECK(plus_two[0]==12 && plus_two[2]==12, "scalar +");
    base += 3;                             // base    = [13,13,13]
    CHECK(base[1]==13, "scalar +=");

    // --- subtraction ---
    auto minus_five = base - 5;            // result  = [8,8,8]
    CHECK(minus_five[2]==8, "scalar -");
    base -= 3;                             // base    = [10,10,10]
    CHECK(base[0]==10, "scalar -=");

    // --- multiplication ---
    auto doubled = base * 2;               // result  = [20,20,20]
    CHECK(doubled[0]==20 && doubled[1]==20, "scalar *");
    base *= 3;                             // base    = [30,30,30]
    CHECK(base[2]==30, "scalar *=");

    // --- division ---
    auto third = base / 3;                 // result  = [10,10,10]
    CHECK(third[0]==10 && third[1]==10, "scalar /");
    base /= 2;                             // base    = [15,15,15]
    CHECK(base[0]==15 && base[2]==15, "scalar /=");
}
TEST_CASE(Vector_Vector_Arithmetic) {
    // Expected contents of each vector are tracked in the trailing comments.
    utils::Vi ones(4, 1);                  // ones = [1,1,1,1]
    utils::Vi twos(4, 2);                  // twos = [2,2,2,2]

    // --- addition ---
    auto sum_vec = ones + twos;            // [3,3,3,3]
    CHECK(sum_vec[0]==3 && sum_vec[3]==3, "v+v");
    ones += twos;                          // ones = [3,3,3,3]
    CHECK(ones[1]==3, "v+=v");

    // --- subtraction ---
    auto diff_vec = ones - twos;           // [1,1,1,1]
    CHECK(diff_vec[2]==1, "v-v");
    ones -= twos;                          // ones = [1,1,1,1]
    CHECK(ones[0]==1, "v-=v");

    // --- element-wise multiplication ---
    auto prod_vec = ones * twos;           // [2,2,2,2]
    CHECK(prod_vec[1]==2, "v*v (elemwise)");
    ones *= twos;                          // ones = [2,2,2,2]
    CHECK(ones[3]==2, "v*=v");

    // --- element-wise division ---
    auto quot_vec = prod_vec / twos;       // [1,1,1,1]
    CHECK(quot_vec[0]==1 && quot_vec[3]==1, "v/v (elemwise)");
    prod_vec /= twos;                      // prod_vec = [1,1,1,1]
    CHECK(prod_vec[2]==1, "v/=v");
}
TEST_CASE(Vector_Friend_Scalar_Left) {
    // Exercises + and * with the scalar written on the left-hand side.
    utils::Vd twos(3, 2.0);                // [2,2,2]

    auto shifted = 3.0 + twos;             // [5,5,5]
    CHECK(shifted[0]==5.0 && shifted[2]==5.0, "left scalar +");

    auto scaled = 4.0 * twos;              // [8,8,8]
    CHECK(scaled[1]==8.0, "left scalar *");
}
TEST_CASE(Vector_Power_and_Sqrt) {
    utils::Vd fours(3, 4.0);               // [4,4,4]

    // power(2.0) is captured as a separate result: [16,16,16].
    auto squared = fours.power(2.0);
    CHECK(squared[0]==16.0 && squared[2]==16.0, "power scalar");

    // The expected value 2 below presumes `fours` still holds 4s at this
    // point, i.e. that power() above did not modify its operand.
    fours.inplace_sqrt();                  // [2,2,2]
    CHECK(fours[0]==2.0 && fours[1]==2.0, "inplace_sqrt");
}
TEST_CASE(Vector_Dot_Sum_Norm) {
    // lhs = [1,2,3]
    utils::Vd lhs(3, 0.0);
    lhs[0]=1.0;
    lhs[1]=2.0;
    lhs[2]=3.0;

    // rhs = [4,5,6]
    utils::Vd rhs(3, 0.0);
    rhs[0]=4.0;
    rhs[1]=5.0;
    rhs[2]=6.0;

    // Dot product: 1*4 + 2*5 + 3*6 = 32
    double inner = lhs.dot(rhs);
    CHECK(std::fabs(inner - 32.0) < 1e-12, "dot");

    // Element sum: 1 + 2 + 3 = 6
    double total = lhs.sum();
    CHECK(std::fabs(total - 6.0) < 1e-12, "sum");

    // Norm: the test expects sqrt(1 + 4 + 9) = sqrt(14)
    double length = lhs.norm();
    CHECK(std::fabs(length - std::sqrt(14.0)) < 1e-12, "norm");
}
TEST_CASE(Vector_Normalize_and_Throws) {
    // [3,4,0] has norm 5.
    utils::Vd vec(3, 0.0);
    vec[0]=3.0;
    vec[1]=4.0;
    vec[2]=0.0;

    // normalize() hands back a result vector, which should have unit length.
    auto unit = vec.normalize();
    CHECK(std::fabs(unit.norm() - 1.0) < 1e-12, "normalize() unit length");

    // inplace_normalize() is checked on `vec` itself.
    vec.inplace_normalize();
    CHECK(std::fabs(vec.norm() - 1.0) < 1e-12, "inplace_normalize unit length");

    // Normalizing an all-zero vector in place is expected to raise
    // std::runtime_error.
    utils::Vd zero_vec(3, 0.0);
    bool caught = false;
    try {
        zero_vec.inplace_normalize();
    } catch (const std::runtime_error&) {
        caught = true;
    }
    CHECK(caught, "normalize should throw on zero vector");
}
// Size mismatch must throw std::runtime_error (dot product and element-wise in-place ops)
TEST_CASE(Vector_Size_Mismatch_Throws) {
    // Operands of length 3 and 4: every binary operation below is expected to
    // reject the pair by throwing std::runtime_error. Each operation gets its
    // own flag so the outcomes are recorded independently.
    utils::Vi three(3,1);
    utils::Vi four(4,2);

    bool dot_threw = false;
    try {
        (void)three.dot(four);
    } catch (const std::runtime_error&) {
        dot_threw = true;
    }
    CHECK(dot_threw, "dot size mismatch should throw");

    bool add_threw = false;
    try {
        three.inplace_add(four);
    } catch (const std::runtime_error&) {
        add_threw = true;
    }
    CHECK(add_threw, "add size mismatch should throw");

    bool subtract_threw = false;
    try {
        three.inplace_subtract(four);
    } catch (const std::runtime_error&) {
        subtract_threw = true;
    }
    CHECK(subtract_threw, "subtract size mismatch should throw");

    bool multiply_threw = false;
    try {
        three.inplace_multiply(four);
    } catch (const std::runtime_error&) {
        multiply_threw = true;
    }
    CHECK(multiply_threw, "multiply size mismatch should throw");

    bool divide_threw = false;
    try {
        three.inplace_divide(four);
    } catch (const std::runtime_error&) {
        divide_threw = true;
    }
    CHECK(divide_threw, "divide size mismatch should throw");

    bool power_threw = false;
    try {
        three.inplace_power(four);
    } catch (const std::runtime_error&) {
        power_threw = true;
    }
    CHECK(power_threw, "power size mismatch should throw");
}