Files
Flux/include/utils/matrix.h
T
2025-09-21 20:57:02 +02:00

165 lines
4.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#ifndef _matrix_n_
#define _matrix_n_
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
#include "./utils/vector.h"
namespace utils{
//#######################################
//#             MATRIX TYPE             #
//#  Dense, row-major storage held in   #
//#  a contiguous std::vector<T>        #
//#######################################
/// Dense matrix of `T` with `rows() x cols()` elements stored contiguously in
/// row-major order: element (i, j) lives at `data()[i * cols() + j]`.
///
/// `operator()` is unchecked for speed; the row/column helpers and the swap
/// functions validate their indices and throw `std::out_of_range` on failure.
template <typename T>
class Matrix{
public:
    /// Constructs an empty 0 x 0 matrix.
    Matrix() : rows_(0), cols_(0), data_() {}

    /// Constructs a `rows x cols` matrix with every element set to `value`.
    /// @throws std::length_error if `rows * cols` does not fit in uint64_t.
    Matrix(uint64_t rows, uint64_t cols, const T& value = T())
        : rows_(rows), cols_(cols), data_(checked_size(rows, cols), value) {}

    //# MATRIX: basic properties #
    /// Number of rows.
    uint64_t rows() const noexcept { return rows_; }
    /// Number of columns.
    uint64_t cols() const noexcept { return cols_; }

    //# MATRIX: element access (fast; unchecked) #
    /// Element (i, j). No bounds checking: the caller must ensure
    /// `i < rows()` and `j < cols()`.
    T& operator()(uint64_t i, uint64_t j) { return data_[i * cols_ + j]; }
    const T& operator()(uint64_t i, uint64_t j) const { return data_[i * cols_ + j]; }

    //# MATRIX: data access #
    /// Pointer to the contiguous row-major storage (`rows() * cols()` elements).
    T* data() noexcept { return data_.data(); }
    const T* data() const noexcept { return data_.data(); }

    /// Changes the shape to `rows x cols`.
    ///
    /// The storage is resized as one flat row-major buffer: existing elements keep
    /// their flat positions and newly added elements are set to `value`. When the
    /// column count changes, existing elements are therefore NOT kept at their
    /// previous (i, j) positions.
    /// @throws std::length_error if `rows * cols` does not fit in uint64_t.
    void resize(uint64_t rows, uint64_t cols, const T& value = T()){
        data_.resize(checked_size(rows, cols), value);
        // Update the shape only after the storage was resized successfully, so a
        // throwing resize (e.g. std::bad_alloc) cannot leave rows_/cols_ out of
        // sync with data_.
        rows_ = rows;
        cols_ = cols;
    }

    //# MATRIX: row helpers (copy out) #
    /// Returns a copy of row `row` as an owning Vector<T> of length cols().
    ///   utils::Vf v = M.get_row(2);
    /// @throws std::out_of_range if `row >= rows()`.
    Vector<T> get_row(const uint64_t row) const {
        if (row >= rows_) {
            throw std::out_of_range("Matrix::get_row -> row index");
        }
        utils::Vector<T> result(cols_, T{});
        const uint64_t base = row * cols_;
        for (uint64_t j = 0; j < cols_; ++j){
            result[j] = data_[base + j];
        }
        return result;
    }

    //# MATRIX: row helpers (copy in) #
    /// Overwrites row `row` with the elements of `vector`.
    ///   M.set_row(2, v);
    /// @throws std::out_of_range if `row >= rows()`.
    /// @throws std::runtime_error if `vector.size() != cols()`.
    void set_row(const uint64_t row, const Vector<T>& vector){
        if (row >= rows_) {
            throw std::out_of_range("Matrix::set_row -> row index");
        }
        if (vector.size() != cols_){
            throw std::runtime_error("Matrix::set_row -> size mismatch");
        }
        const uint64_t base = row * cols_;
        for (uint64_t j = 0; j < cols_; ++j){
            data_[base + j] = vector[j];
        }
    }

    //# MATRIX: col helpers (copy out) #
    /// Returns a copy of column `col` as an owning Vector<T> of length rows().
    ///   utils::Vf v = M.get_col(2);
    /// @throws std::out_of_range if `col >= cols()`.
    Vector<T> get_col(const uint64_t col) const {
        if (col >= cols_) {
            throw std::out_of_range("Matrix::get_col -> col index");
        }
        utils::Vector<T> result(rows_, T{});
        for (uint64_t i = 0; i < rows_; ++i){
            result[i] = data_[i * cols_ + col];
        }
        return result;
    }

    //# MATRIX: col helpers (copy in) #
    /// Overwrites column `col` with the elements of `vector`.
    ///   M.set_col(2, v);
    /// @throws std::out_of_range if `col >= cols()`.
    /// @throws std::runtime_error if `vector.size() != rows()`.
    void set_col(const uint64_t col, const Vector<T>& vector){
        if (col >= cols_) {
            throw std::out_of_range("Matrix::set_col -> col index");
        }
        if (vector.size() != rows_){
            throw std::runtime_error("Matrix::set_col -> size mismatch");
        }
        for (uint64_t i = 0; i < rows_; ++i){
            data_[i * cols_ + col] = vector[i];
        }
    }

    /// Exchanges rows `a` and `b` in place (no-op when a == b).
    /// @throws std::out_of_range if either index is >= rows().
    void swap_rows(uint64_t a, uint64_t b){
        if (a >= rows_ || b >= rows_) {
            throw std::out_of_range("Matrix::swap_rows -> row index");
        }
        if (a == b){
            return;
        }
        for (uint64_t j = 0; j < cols_; ++j){
            swap_elements(a * cols_ + j, b * cols_ + j);
        }
    }

    /// Exchanges columns `a` and `b` in place (no-op when a == b).
    /// @throws std::out_of_range if either index is >= cols().
    void swap_cols(uint64_t a, uint64_t b){
        if (a >= cols_ || b >= cols_) {
            throw std::out_of_range("Matrix::swap_cols -> col index");
        }
        if (a == b){
            return;
        }
        for (uint64_t i = 0; i < rows_; ++i){
            swap_elements(i * cols_ + a, i * cols_ + b);
        }
    }

    /// Writes the matrix as nested brackets, one row per line, each element
    /// formatted with width 4, fixed notation and precision 3.
    /// The stream's format flags and precision are restored before returning.
    friend std::ostream& operator<<(std::ostream& out, const Matrix& M) {
        // std::fixed and std::setprecision are sticky stream state (std::setw is
        // reset after each insertion); restore the caller's settings on exit so
        // printing a matrix does not change how later values on `out` look.
        struct FormatRestorer {
            std::ostream& stream;
            std::ios_base::fmtflags flags;
            std::streamsize precision;
            ~FormatRestorer() {
                stream.flags(flags);
                stream.precision(precision);
            }
        } restore_format{out, out.flags(), out.precision()};

        out << "[";
        for (uint64_t i = 0; i < M.rows_; ++i) {
            out << "[";
            for (uint64_t j = 0; j < M.cols_; ++j) {
                out << std::setw(4) << std::setprecision(3) << std::fixed << M(i, j);
                if (j + 1 < M.cols_) out << ", ";
            }
            out << "]";
            if (i + 1 < M.rows_) out << ",\n ";
        }
        out << "]";
        return out;
    }

    /// Prints the matrix to std::cout followed by a newline (and a flush).
    void print() const {
        std::cout << *this << std::endl;
    }

private:
    /// Returns `rows * cols`, or throws std::length_error if the product would
    /// overflow uint64_t (which would otherwise wrap to a too-small buffer).
    static uint64_t checked_size(uint64_t rows, uint64_t cols) {
        if (cols != 0 && rows > std::numeric_limits<uint64_t>::max() / cols) {
            throw std::length_error("Matrix -> rows * cols overflows uint64_t");
        }
        return rows * cols;
    }

    /// Swaps two storage slots by flat index, moving rather than copying so
    /// heavy element types are not duplicated.
    void swap_elements(uint64_t x, uint64_t y) {
        T tmp = std::move(data_[x]);
        data_[x] = std::move(data_[y]);
        data_[y] = std::move(tmp);
    }

    uint64_t rows_;      // number of rows
    uint64_t cols_;      // number of columns (row stride in data_)
    std::vector<T> data_; // row-major elements, size rows_ * cols_
};
/// Convenience aliases for the common element types.
using Mi = Matrix<int64_t>; // signed 64-bit integer matrix
using Mf = Matrix<float>;   // single-precision matrix
using Md = Matrix<double>;  // double-precision matrix
}
#endif // _matrix_n_