mirror of
https://git.intern.spaceteamaachen.de/ALPAKA/sta-peak.git
synced 2025-06-12 19:05:58 +00:00
458 lines
8.5 KiB
C++
458 lines
8.5 KiB
C++
#include <sta/math/linalg/matrix.hpp>
#include <sta/math/linalg/linalg.hpp>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
|
|
|
|
#ifdef STA_CORE
#include <sta/debug/debug.hpp>
#include <sta/debug/assert.hpp>
#else
// Standalone fallbacks for the STA debug facilities used in this file, for
// builds without STA_CORE. They are given internal linkage (anonymous
// namespace) so that another translation unit carrying an identical fallback
// does not produce a multiple-definition link error.
namespace {

/**
 * Prints @p msg on its own line and terminates the process with exit code 1
 * when @p cond is zero; does nothing otherwise.
 */
void STA_ASSERT_MSG(int cond, const char * msg) {
    if (!cond) {
        printf("%s\n", msg);
        std::exit(1);
    }
}

/** Prints @p msg followed by a newline. */
void STA_DEBUG_PRINTLN(const char * msg) {
    printf("%s\n", msg);
}

} // namespace

#define STA_DEBUG_PRINT printf
#define STA_DEBUG_PRINTF printf
#endif
|
|
namespace sta
|
|
{
|
|
namespace math {
|
|
|
|
/**
 * Creates an empty matrix handle: no data buffer and no shape block.
 * is_valid() reports false for it, and the destructor has nothing to release.
 */
matrix::matrix() {
    shape = nullptr;
    datafield = nullptr;
}
|
|
|
|
/**
 * Copy constructor: the new matrix shares m's data buffer and shape block.
 *
 * shape[0] is the shared reference count (the destructor decrements it and
 * frees both buffers at zero), so it is incremented here. A newly constructed
 * object owns nothing yet, so there is nothing to release first. Inspecting
 * this->shape before assigning it would read an indeterminate pointer unless
 * the header supplies a default member initializer.
 *
 * @param m  matrix to share; may be empty (shape == nullptr), in which case
 *           the copy is empty too.
 */
matrix::matrix(const matrix &m) {
    datafield = m.datafield;
    shape = m.shape;

    // Copying an empty matrix yields another empty matrix; there is no count to bump.
    if (shape != nullptr) {
        // The count is a uint8_t. Wrapping it to 0 would make a later destructor
        // free the buffers while other handles still use them.
        STA_ASSERT_MSG(shape[0] < UINT8_MAX, "Matrix reference count overflow");
        shape[0] += 1;
    }
}
|
|
|
|
/**
 * Allocates an uninitialised rows x cols matrix.
 *
 * Layout of the shape block: shape[0] = reference count, shape[1] = rows,
 * shape[2] = cols, shape[3] = element count.
 *
 * @param rows  number of rows
 * @param cols  number of columns
 */
matrix::matrix(uint8_t rows, uint8_t cols) {
    const uint16_t size = static_cast<uint16_t>(rows) * cols;

    datafield = static_cast<float *>(malloc(sizeof(float) * size));
    shape = static_cast<uint8_t *>(malloc(sizeof(uint8_t) * 4));

    // Both buffers are written through immediately, so a failed allocation must
    // not go unnoticed. malloc(0) may legitimately return nullptr, so an empty
    // data buffer is tolerated.
    STA_ASSERT_MSG(shape != nullptr && (datafield != nullptr || size == 0), "Matrix allocation failed");

    shape[0] = 1;
    shape[1] = rows;
    shape[2] = cols;
    // NOTE: uint8_t storage, so this is truncated modulo 256 when rows * cols > 255.
    shape[3] = static_cast<uint8_t>(size);
}
|
|
|
|
/**
 * Creates a rows x cols matrix initialised from vals.
 *
 * Allocation is delegated to matrix(rows, cols). The values are copied, not
 * adopted, so the caller keeps ownership of vals.
 *
 * @param rows  number of rows
 * @param cols  number of columns
 * @param vals  rows * cols values in the same row-major order get_idx() uses
 *              (index = r * cols + c); may be nullptr only when rows * cols == 0
 */
matrix::matrix(uint8_t rows, uint8_t cols, float *vals) : matrix(rows, cols) {
    const uint16_t size = static_cast<uint16_t>(rows) * cols;
    if (size == 0) {
        // Nothing to copy; also avoids calling memcpy with a possibly-null pointer.
        return;
    }

    STA_ASSERT_MSG(vals != nullptr, "Matrix initialisation values are nullptr");
    std::memcpy(datafield, vals, sizeof(float) * size);
}
|
|
|
|
/**
 * Releases this handle's reference to the shared buffers.
 *
 * shape[0] counts the handles sharing datafield/shape. The last handle to go
 * away frees both allocations. An empty matrix (shape == nullptr) owns nothing.
 */
matrix::~matrix() {
    if (shape == nullptr) {
        return;
    }

    shape[0] -= 1;

    // shape[0] is unsigned, so "<= 0" in the reference-count test means "== 0".
    if (shape[0] == 0) {
        free(datafield);
        free(shape);
    }
}
|
|
|
|
bool matrix::is_valid() {
|
|
if (shape == nullptr) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/**
 * @return the number of elements (rows * cols). Asserts that the matrix is valid.
 *
 * The cached count in shape[3] is a uint8_t and wraps for matrices with more
 * than 255 elements (a 16x16 matrix would report 0). The count is therefore
 * derived from the row and column counts, which always fit in the uint16_t
 * return type.
 */
uint16_t matrix::get_size() {
    STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr");
    return static_cast<uint16_t>(static_cast<uint16_t>(shape[1]) * shape[2]);
}
|
|
|
|
/**
 * @return the number of rows (stored in shape[1]). Asserts that the matrix is valid.
 */
uint8_t matrix::get_rows() {
    STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr");
    const uint8_t row_count = shape[1];
    return row_count;
}
|
|
|
|
/**
 * @return the number of columns (stored in shape[2]). Asserts that the matrix is valid.
 */
uint8_t matrix::get_cols() {
    STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr");
    const uint8_t col_count = shape[2];
    return col_count;
}
|
|
|
|
/**
 * Deep copy: allocates a fresh matrix with the same rows/cols and copies the
 * first get_size() values of the data buffer into it.
 *
 * Unlike copy construction, the result has its own data and shape block, so
 * later set()/reshape() calls on either matrix do not affect the other.
 */
matrix matrix::clone() {
    matrix duplicate(get_rows(), get_cols());

    const uint16_t count = get_size();
    const float *src = datafield;
    float *dst = duplicate.datafield;

    for (uint16_t k = 0; k < count; ++k) {
        dst[k] = src[k];
    }

    return duplicate;
}
|
|
|
|
/**
 * Assignment: makes this matrix share m's buffers (reference-counted).
 *
 * Copy-and-swap: m arrives by value, so the copy constructor has already taken
 * a reference to the source buffers. Swapping hands this object's previous
 * buffers to m, and m's destructor drops that reference, freeing them if this
 * was the last holder.
 *
 * This is self-assignment safe. It also handles an empty m (shape == nullptr):
 * the previous version dereferenced shape unconditionally and crashed in that
 * case.
 *
 * @param m  matrix to share (by value)
 * @return *this
 */
matrix &matrix::operator=(matrix m) {
    std::swap(datafield, m.datafield);
    std::swap(shape, m.shape);
    return *this;
}
|
|
|
|
/**
 * Reinterprets the data buffer as an r x c matrix without moving any values.
 *
 * Asserts that r * c equals the current element count. The shape block is
 * shared by every copy made through the copy constructor / operator=, so all of
 * those handles observe the new shape; use clone() first for an independent
 * reshape.
 *
 * @param r  new number of rows
 * @param c  new number of columns
 */
void matrix::reshape(uint8_t r, uint8_t c) {
    STA_ASSERT_MSG(r * c == get_size(), "New shape does not match old shape");

    const uint16_t element_count = r * c;
    // Element count first, then the dimensions. The writes are independent.
    shape[3] = static_cast<uint8_t>(element_count);
    shape[2] = c;
    shape[1] = r;
}
|
|
|
|
float matrix::det() {
|
|
|
|
uint8_t rows = get_rows();
|
|
uint8_t cols = get_cols();
|
|
|
|
STA_ASSERT_MSG(rows == cols && rows >= 1, "Matrix is not square. Determinant can not be computed." );
|
|
|
|
if (rows == 1) {
|
|
return datafield[0];
|
|
}
|
|
|
|
if (rows == 2) {
|
|
return (operator()(0, 0) * operator()(1, 1)) - (operator()(1, 0) * operator()(0, 1));
|
|
}
|
|
|
|
if (rows == 3) {
|
|
float S = 0;
|
|
S += operator()(0, 0) * ((operator()(1, 1) * operator()(2, 2)) - (operator()(2, 1) * operator()(1, 2)));
|
|
S -= operator()(0, 1) * ((operator()(1, 0) * operator()(2, 2)) - (operator()(2, 0) * operator()(1, 2)));
|
|
S += operator()(0, 2) * ((operator()(1, 0) * operator()(2, 1)) - (operator()(2, 0) * operator()(1, 1)));
|
|
return S;
|
|
}
|
|
|
|
if (rows > 3) {
|
|
|
|
float determinant = 0;
|
|
|
|
for (uint8_t k = 0; k < rows; k++) {
|
|
|
|
matrix submatrix(rows - 1, rows - 1);
|
|
uint8_t i = 0;
|
|
|
|
for (uint8_t x = 0; x < rows; x++) {
|
|
|
|
for (uint8_t y = 1; y < cols; y++) {
|
|
|
|
float val = operator()(x, y);
|
|
|
|
if (x != k) {
|
|
submatrix.set(i, y - 1, val);
|
|
}
|
|
|
|
}
|
|
|
|
if (x != k) {
|
|
i += 1;
|
|
}
|
|
|
|
}
|
|
|
|
float new_determinant = submatrix.det() * operator()(k, 0);
|
|
if (k % 2 == 1) {
|
|
new_determinant *= -1;
|
|
}
|
|
|
|
determinant += new_determinant;
|
|
|
|
|
|
}
|
|
|
|
return determinant;
|
|
|
|
|
|
}
|
|
|
|
return 0;
|
|
|
|
|
|
}
|
|
|
|
/**
 * Copies a len_r x len_c block into a new matrix.
 *
 * Element (r, c) of the result is element (start_r + r, start_c + c) of this
 * matrix. Asserts that the block lies entirely inside this matrix.
 *
 * @param start_r  row of the block's top-left corner
 * @param start_c  column of the block's top-left corner
 * @param len_r    number of rows to copy
 * @param len_c    number of columns to copy
 * @return a new, independent len_r x len_c matrix
 */
matrix matrix::get_block(uint8_t start_r, uint8_t start_c, uint8_t len_r, uint8_t len_c) {
    matrix block = matrix::zeros(len_r, len_c);

    STA_ASSERT_MSG(start_r + len_r <= get_rows() && start_c + len_c <= get_cols(), "get_block failed. Boundary conditions not initialized." );

    for (uint8_t dr = 0; dr < len_r; dr++) {
        for (uint8_t dc = 0; dc < len_c; dc++) {
            block.set(dr, dc, (*this)(start_r + dr, start_c + dc));
        }
    }

    return block;
}
|
|
|
|
void matrix::set_block(uint8_t _r, uint8_t _c, matrix m) {
|
|
|
|
STA_ASSERT_MSG(_r + m.get_rows() <= get_rows() && _c + m.get_cols() <= get_cols(), "set_block failed. Boundary conditions not initialized." );
|
|
|
|
for (uint8_t r = 0; r < m.get_rows(); r++) {
|
|
|
|
for (uint8_t c = 0; c < m.get_cols(); c++) {
|
|
|
|
set(_r + r, _c + c, m(r, c));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
/**
 * Stores v at row r, column c. Bounds are checked by get_idx().
 */
void matrix::set(uint8_t r, uint8_t c, float v) {
    const uint16_t idx = get_idx(r, c);
    datafield[idx] = v;
}
|
|
|
|
/**
 * Stores v at flat index i (row-major, as produced by get_idx()).
 * Asserts that i < get_size().
 */
void matrix::set(uint16_t i, float v) {
    STA_ASSERT_MSG(i < get_size(), "Index out of Bounds" );

    float &slot = datafield[i];
    slot = v;
}
|
|
|
|
/**
 * Returns a clone of this matrix with column _c and row _r set to zero.
 *
 * The dimensions are unchanged; nothing is removed. Out-of-range indices are
 * caught by set()'s bounds checks (column first, then row).
 *
 * @param _r  row to zero
 * @param _c  column to zero
 */
matrix matrix::get_submatrix(uint8_t _r, uint8_t _c) {
    matrix masked = clone();

    const uint8_t n_rows = get_rows();
    const uint8_t n_cols = get_cols();

    // Clear the column first, then the row.
    for (uint8_t row = 0; row < n_rows; row++) {
        masked.set(row, _c, 0.0f);
    }
    for (uint8_t col = 0; col < n_cols; col++) {
        masked.set(_r, col, 0.0f);
    }

    return masked;
}
|
|
|
|
/**
 * @param s  side length
 * @return an s x s identity matrix (ones on the diagonal, zeros elsewhere)
 */
matrix matrix::eye(uint8_t s) {
    matrix identity = matrix::zeros(s, s);

    for (uint8_t d = 0; d < s; ++d) {
        identity.set(d, d, 1.0f);
    }

    return identity;
}
|
|
|
|
/**
 * @param r  number of rows
 * @param c  number of columns
 * @return an r x c matrix with every element set to 0
 */
matrix matrix::zeros(uint8_t r, uint8_t c) {
    matrix result(r, c);

    const uint16_t count = r * c;
    for (uint16_t k = 0; k < count; ++k) {
        result.datafield[k] = 0.0f;
    }

    return result;
}
|
|
|
|
/**
 * @param r  number of rows
 * @param c  number of columns
 * @param v  fill value
 * @return an r x c matrix with every element set to v
 */
matrix matrix::full(uint8_t r, uint8_t c, float v) {
    matrix filled(r, c);

    float *cursor = filled.datafield;
    float *const end = cursor + r * c;
    while (cursor != end) {
        *cursor++ = v;
    }

    return filled;
}
|
|
|
|
/**
 * @return the element at row r, column c (by value). Bounds are checked by get_idx().
 */
float matrix::operator()(uint8_t r, uint8_t c) {
    return datafield[get_idx(r, c)];
}
|
|
|
|
/**
 * Flat, row-major element access.
 *
 * @param i  flat index (r * cols + c)
 * @return datafield[i], or 0 when i is out of range (i >= get_size())
 *
 * Valid indices are 0 .. get_size() - 1. The previous test `i > get_size()`
 * let i == get_size() read one element past the end of the buffer.
 */
float matrix::operator[](uint16_t i) {
    if (i >= get_size()) {
        return 0;
    }
    return datafield[i];
}
|
|
|
|
/**
 * Maps (r, c) to a flat row-major index into datafield.
 * Asserts that c < cols and r < rows (checked in that order).
 *
 * @return r * cols + c
 */
uint16_t matrix::get_idx(uint8_t r, uint8_t c) {
    STA_ASSERT_MSG(c < get_cols(), "Column index out of bounds get_idx");
    STA_ASSERT_MSG(r < get_rows(), "Row index out of bounds get_idx");

    const uint16_t row_offset = r * get_cols();
    return row_offset + c;
}
|
|
|
|
/**
 * @return a new cols x rows matrix holding the transpose: element (r, c) of
 *         this matrix becomes element (c, r) of the result.
 */
matrix matrix::T() {
    const uint8_t n_rows = get_rows();
    const uint8_t n_cols = get_cols();

    matrix transposed(n_cols, n_rows);

    // Every (c, r) slot is written exactly once, so the visiting order is irrelevant.
    for (uint8_t c = 0; c < n_cols; c++) {
        for (uint8_t r = 0; r < n_rows; r++) {
            transposed.set(c, r, (*this)(r, c));
        }
    }

    return transposed;
}
|
|
|
|
/**
 * @return an independent (size x 1) column-vector copy of this matrix, in
 *         row-major element order.
 *
 * reshape() takes a uint8_t row count, so a column vector can hold at most 255
 * elements. Previously the uint16_t element count was narrowed implicitly,
 * which could silently yield a wrongly-shaped result. The count is now computed
 * from rows * cols (not the uint8_t cache in shape[3]) and checked explicitly.
 */
matrix matrix::flatten() {
    const uint16_t size = static_cast<uint16_t>(static_cast<uint16_t>(get_rows()) * get_cols());
    STA_ASSERT_MSG(size <= UINT8_MAX, "flatten failed. Matrix has too many elements for a column vector.");

    matrix column = clone();
    column.reshape(static_cast<uint8_t>(size), 1);
    return column;
}
|
|
|
|
float matrix::minor(uint8_t r, uint8_t c) {
|
|
|
|
matrix out(get_rows() - 1, get_cols() - 1);
|
|
|
|
|
|
for (uint8_t row = 0; row < get_rows() - 1; row++) {
|
|
for (uint8_t col = 0; col < get_cols() - 1; col++) {
|
|
|
|
|
|
if (row < r && col < c) {
|
|
out.set(row, col, operator()(row, col));
|
|
} else if (row >= r && col < c) {
|
|
out.set(row, col, operator()(row + 1, col));
|
|
} else if (row < r && col >= c) {
|
|
out.set(row, col, operator()(row, col + 1));
|
|
} else {
|
|
out.set(row, col, operator()(row + 1, col + 1));
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return out.det();
|
|
|
|
|
|
}
|
|
|
|
/**
 * Matrix-scalar product. Delegates to linalg::dot(*this, s); presumably
 * scales every element by s (see linalg).
 */
matrix matrix::operator*(float s) {
    matrix &self = *this;
    return linalg::dot(self, s);
}
|
|
|
|
/**
 * Matrix product with m. Delegates to linalg::dot(*this, m); presumably the
 * standard matrix multiplication (see linalg). m is taken by value, which only
 * shares its buffers.
 */
matrix matrix::operator*(matrix m) {
    matrix &self = *this;
    return linalg::dot(self, m);
}
|
|
|
|
/**
 * Sum with m. Delegates to linalg::add(*this, m); presumably element-wise (see linalg).
 */
matrix matrix::operator+(matrix m) {
    matrix &self = *this;
    return linalg::add(self, m);
}
|
|
|
|
/**
 * Difference with m. Delegates to linalg::subtract(*this, m); presumably
 * element-wise (see linalg).
 */
matrix matrix::operator-(matrix m) {
    matrix &self = *this;
    return linalg::subtract(self, m);
}
|
|
|
|
|
|
void matrix::show_serial() {
|
|
|
|
show_shape();
|
|
|
|
for(uint8_t r = 0; r < get_rows(); r++) {
|
|
|
|
for(uint8_t c = 0; c < get_cols(); c++) {
|
|
|
|
STA_DEBUG_PRINT("| ");
|
|
STA_DEBUG_PRINTF("%f", operator()(r, c));
|
|
if(c == get_cols() - 1) {
|
|
STA_DEBUG_PRINTLN(" |");
|
|
} else {
|
|
STA_DEBUG_PRINT(" ");
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
/**
 * Prints "Matrix shape: (<rows> x <cols>)" through STA_DEBUG_PRINTF.
 */
void matrix::show_shape() {
    // uint8_t would be promoted to int for varargs anyway; the casts make %d's type explicit.
    STA_DEBUG_PRINTF("Matrix shape: (%d x %d)\n", static_cast<int>(get_rows()), static_cast<int>(get_cols()));
}
|
|
|
|
}// namespace math
|
|
|
|
}//namespace sta
|