// Repository listing metadata (captured from the source-hosting web view):
//   Path:          Flux-openbuild/include/utils/matrix.h
//   Last modified: 2025-10-09 08:44:15 +00:00
//   Size:          248 lines, 7.2 KiB (C++)
//
// The web view warned that this file contains ambiguous Unicode characters,
// i.e. characters that might be confused with similar-looking ones. If they
// are intentional, the warning can be safely ignored.
#ifndef _matrix_n_
#define _matrix_n_
#include <cstdint>
#include <initializer_list>
#include <iomanip>
#include <ios>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
#include "./utils/vector.h"
#include "./utils/random.h"
namespace utils{
//#######################################
//# MATRIX TYPE                         #
//# Dense row-major storage in a        #
//# std::vector<T>; rows and columns    #
//# are copied out as utils::Vector<T>  #
//#######################################
template <typename T>
class Matrix {
public:
    //# MATRIX: construction #

    // Empty 0 x 0 matrix.
    Matrix() : rows_(0), cols_(0), data_() {}

    // rows x cols matrix with every element initialised to `value`
    // (a value-initialised T, i.e. zero for arithmetic types, by default).
    Matrix(uint64_t rows, uint64_t cols, const T& value = T())
        : rows_(rows), cols_(cols), data_(rows * cols, value) {}

    // Construct from a list of equally sized rows:
    //   utils::Mf A{{1, 2, 3}, {4, 5, 6}};   // 2 x 3
    // An empty outer list yields a 0 x 0 matrix.
    // Throws std::runtime_error if the rows do not all have the same length.
    Matrix(std::initializer_list<std::initializer_list<T>> init)
        : rows_(static_cast<uint64_t>(init.size())),
          cols_(init.size() > 0 ? static_cast<uint64_t>(init.begin()->size())
                                : uint64_t{0}),
          data_() {
        for (const auto& row : init) {
            if (static_cast<uint64_t>(row.size()) != cols_) {
                throw std::runtime_error("Matrix(list of lists): ragged rows");
            }
        }
        // Appending the rows one after another gives the row-major layout.
        data_.reserve(rows_ * cols_);
        for (const auto& row : init) {
            data_.insert(data_.end(), row.begin(), row.end());
        }
    }

    // Replace the contents with a list of equally sized rows:
    //   utils::Md M; M = {{1, 2}, {3, 4}};
    // Throws std::runtime_error on ragged rows. The replacement is built
    // completely before *this is touched, so on a throw the matrix keeps its
    // previous shape and data (previously rows_/cols_ were overwritten first,
    // leaving them inconsistent with the stored data).
    Matrix& operator=(std::initializer_list<std::initializer_list<T>> init) {
        *this = Matrix(init);  // validate + build, then move-assign
        return *this;
    }

    //# MATRIX: basic properties #
    uint64_t rows() const noexcept { return rows_; }
    uint64_t cols() const noexcept { return cols_; }

    //# MATRIX: element access (fast; unchecked) #
    // No bounds checking: callers must guarantee i < rows() and j < cols().
    T& operator()(uint64_t i, uint64_t j) { return data_[i * cols_ + j]; }
    const T& operator()(uint64_t i, uint64_t j) const { return data_[i * cols_ + j]; }

    //# MATRIX: data access #
    // Pointer to the contiguous row-major buffer of rows() * cols() elements.
    T* data() noexcept { return data_.data(); }
    const T* data() const noexcept { return data_.data(); }

    // Set the shape to rows x cols. The flat row-major buffer is resized in
    // place: existing elements keep their flat positions and new slots are
    // filled with `value`.
    // NOTE(review): because elements keep their *flat* index, changing the
    // column count re-flows existing data across rows, e.g. resizing
    // [[1,2],[3,4]] to 2 x 3 gives [[1,2,3],[4,value,value]]. Kept as-is
    // for compatibility; build a new Matrix if the 2-D layout must survive.
    void resize(uint64_t rows, uint64_t cols, const T& value = T(0)) {
        // Resize storage first so a failed allocation leaves the shape intact.
        data_.resize(rows * cols, value);
        rows_ = rows;
        cols_ = cols;
    }

    // Reshape to rows x cols and overwrite every element, in row-major
    // order, with utils::random(lower, higher). The range semantics are
    // whatever utils::random implements.
    void random(uint64_t rows, uint64_t cols, const T& lower, const T& higher) {
        // Resize storage first so a failed allocation leaves the shape intact.
        data_.resize(rows * cols);
        rows_ = rows;
        cols_ = cols;
        for (T& element : data_) {
            element = utils::random(lower, higher);
        }
    }

    //# MATRIX: row helpers (copy out) #
    // Return a copy of row `row` as an owning Vector<T>:
    //   utils::Vf v = M.get_row(2);
    // Throws std::out_of_range if row >= rows().
    Vector<T> get_row(const uint64_t row) const {
        if (row >= rows_) {
            throw std::out_of_range("Matrix::get_row -> row index");
        }
        utils::Vector<T> result(cols_, T{});
        const uint64_t base = row * cols_;
        for (uint64_t j = 0; j < cols_; ++j) {
            result[j] = data_[base + j];
        }
        return result;
    }

    //# MATRIX: row helpers (copy in) #
    // Overwrite row `row` with the contents of `vector`:
    //   M.set_row(2, v);
    // Throws std::out_of_range if row >= rows(), and std::runtime_error if
    // vector.size() != cols().
    void set_row(const uint64_t row, const Vector<T>& vector) {
        if (row >= rows_) {
            throw std::out_of_range("Matrix::set_row -> row index");
        }
        if (vector.size() != cols_) {
            throw std::runtime_error("Matrix::set_row -> size mismatch");
        }
        const uint64_t base = row * cols_;
        for (uint64_t j = 0; j < cols_; ++j) {
            data_[base + j] = vector[j];
        }
    }

    //# MATRIX: col helpers (copy out) #
    // Return a copy of column `col` as an owning Vector<T>:
    //   utils::Vf v = M.get_col(2);
    // Throws std::out_of_range if col >= cols().
    Vector<T> get_col(const uint64_t col) const {
        if (col >= cols_) {
            throw std::out_of_range("Matrix::get_col -> col index");
        }
        utils::Vector<T> result(rows_, T{});
        for (uint64_t i = 0; i < rows_; ++i) {
            result[i] = data_[i * cols_ + col];
        }
        return result;
    }

    //# MATRIX: col helpers (copy in) #
    // Overwrite column `col` with the contents of `vector`:
    //   M.set_col(2, v);
    // Throws std::out_of_range if col >= cols(), and std::runtime_error if
    // vector.size() != rows().
    void set_col(const uint64_t col, const Vector<T>& vector) {
        if (col >= cols_) {
            throw std::out_of_range("Matrix::set_col -> col index");
        }
        if (vector.size() != rows_) {
            throw std::runtime_error("Matrix::set_col -> size mismatch");
        }
        for (uint64_t i = 0; i < rows_; ++i) {
            data_[i * cols_ + col] = vector[i];
        }
    }

    // Exchange rows a and b in place. Throws std::out_of_range if either
    // index is >= rows(); swapping a row with itself is a no-op.
    void swap_rows(uint64_t a, uint64_t b) {
        if (a >= rows_ || b >= rows_) {
            throw std::out_of_range("Matrix::swap_rows -> row index");
        }
        if (a == b) {
            return;
        }
        using std::swap;
        for (uint64_t j = 0; j < cols_; ++j) {
            swap(data_[a * cols_ + j], data_[b * cols_ + j]);
        }
    }

    // Exchange columns a and b in place. Throws std::out_of_range if either
    // index is >= cols(); swapping a column with itself is a no-op.
    void swap_cols(uint64_t a, uint64_t b) {
        if (a >= cols_ || b >= cols_) {
            throw std::out_of_range("Matrix::swap_cols -> col index");
        }
        if (a == b) {
            return;
        }
        using std::swap;
        for (uint64_t i = 0; i < rows_; ++i) {
            swap(data_[i * cols_ + a], data_[i * cols_ + b]);
        }
    }

    // Print as nested brackets, one row per line; for a double matrix:
    //   [[1.000, 2.000],
    //    [3.000, 4.000]]
    // Each element is written with std::setw(4), std::fixed and
    // std::setprecision(3). The stream's previous flags and precision are
    // restored afterwards, so printing a matrix no longer leaves the caller's
    // stream in fixed/precision-3 mode.
    friend std::ostream& operator<<(std::ostream& out, const Matrix& M) {
        const std::ios_base::fmtflags saved_flags = out.flags();
        const std::streamsize saved_precision = out.precision();
        out << std::setprecision(3) << std::fixed;

        out << "[";
        for (uint64_t i = 0; i < M.rows_; ++i) {
            out << "[";
            for (uint64_t j = 0; j < M.cols_; ++j) {
                out << std::setw(4) << M(i, j);  // setw applies to one output only
                if (j + 1 < M.cols_) out << ", ";
            }
            out << "]";
            if (i + 1 < M.rows_) out << ",\n ";
        }
        out << "]";

        out.flags(saved_flags);
        out.precision(saved_precision);
        return out;
    }

    // Write the matrix to std::cout followed by std::endl (which flushes).
    void print() const {
        std::cout << *this << std::endl;
    }

private:
    uint64_t rows_{0};
    uint64_t cols_{0};
    std::vector<T> data_;  // row-major: element (i, j) is at i * cols_ + j
};
// Convenience aliases for the element types used across the project.
using Mi = Matrix<int64_t>;  // 64-bit signed integer matrix
using Mf = Matrix<float>;    // single-precision floating-point matrix
using Md = Matrix<double>;   // double-precision floating-point matrix
}
#endif // _matrix_n_