#include #include #include #include #include #ifdef STA_CORE #include #include #else void STA_ASSERT_MSG(int cond, const char * msg) { if(!cond) { printf("%s\n", msg); std::exit(1); } } #define STA_DEBUG_PRINT printf #define STA_DEBUG_PRINTF printf void STA_DEBUG_PRINTLN(const char * msg) { printf("%s\n", msg); } #endif namespace sta { namespace math { matrix::matrix() { datafield = nullptr; shape = nullptr; } matrix::matrix(const matrix &m) { if (shape != nullptr) { shape[0] -= 1; if (shape[0] <= 0) { free(datafield); free(shape); } } datafield = m.datafield; shape = m.shape; shape[0] += 1; } matrix::matrix(uint8_t rows, uint8_t cols) { uint16_t size = rows * cols; datafield = (float *) malloc((sizeof(float) * size)); shape = (uint8_t *) malloc(sizeof(uint8_t) * 4); shape[0] = 1; shape[1] = rows; shape[2] = cols; shape[3] = rows * cols; } matrix::matrix(uint8_t rows, uint8_t cols, float *vals) { uint16_t size = rows * cols; datafield = (float *) malloc((sizeof(float) * size)); shape = (uint8_t *) malloc(sizeof(uint8_t) * 4); shape[0] = 1; shape[1] = rows; shape[2] = cols; shape[3] = rows * cols; for (uint16_t i = 0; i < size; i++) { datafield[i] = vals[i]; } } matrix::~matrix() { if (shape != nullptr) { shape[0] -= 1; if (shape[0] <= 0) { free(datafield); free(shape); } } } bool matrix::is_valid() { if (shape == nullptr) { return false; } return true; } uint16_t matrix::get_size() { STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr"); return shape[3]; } uint8_t matrix::get_rows() { STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr"); return shape[1]; } uint8_t matrix::get_cols() { STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr"); return shape[2]; } matrix matrix::clone() { matrix m(get_rows(), get_cols()); uint16_t size = get_size(); for (uint16_t i = 0; i < size; i++) { m.datafield[i] = datafield[i]; } return m; } matrix &matrix::operator=(matrix m) { if (shape != nullptr) { shape[0] -= 1; if (shape[0] <= 0) { free(datafield); free(shape); } } datafield = m.datafield; shape = m.shape; shape[0] += 1; return *this; } void matrix::reshape(uint8_t r, uint8_t c) { STA_ASSERT_MSG(r * c == get_size(), "New shape does not match old shape"); shape[1] = r; shape[2] = c; shape[3] = r * c; } float matrix::det() { uint8_t rows = get_rows(); uint8_t cols = get_cols(); STA_ASSERT_MSG(rows == cols && rows >= 1, "Matrix is not square. Determinant can not be computed." ); if (rows == 1) { return datafield[0]; } if (rows == 2) { return (operator()(0, 0) * operator()(1, 1)) - (operator()(1, 0) * operator()(0, 1)); } if (rows == 3) { float S = 0; S += operator()(0, 0) * ((operator()(1, 1) * operator()(2, 2)) - (operator()(2, 1) * operator()(1, 2))); S -= operator()(0, 1) * ((operator()(1, 0) * operator()(2, 2)) - (operator()(2, 0) * operator()(1, 2))); S += operator()(0, 2) * ((operator()(1, 0) * operator()(2, 1)) - (operator()(2, 0) * operator()(1, 1))); return S; } if (rows > 3) { float determinant = 0; for (uint8_t k = 0; k < rows; k++) { matrix submatrix(rows - 1, rows - 1); uint8_t i = 0; for (uint8_t x = 0; x < rows; x++) { for (uint8_t y = 1; y < cols; y++) { float val = operator()(x, y); if (x != k) { submatrix.set(i, y - 1, val); } } if (x != k) { i += 1; } } float new_determinant = submatrix.det() * operator()(k, 0); if (k % 2 == 1) { new_determinant *= -1; } determinant += new_determinant; } return determinant; } return 0; } matrix matrix::get_block(uint8_t start_r, uint8_t start_c, uint8_t len_r, uint8_t len_c) { //matrix output(len_r, len_c); matrix output = matrix::zeros(len_r, len_c); STA_ASSERT_MSG(start_r + len_r <= get_rows() && start_c + len_c <= get_cols(), "get_block failed. Boundary conditions not initialized." ); for (uint8_t r = 0; r < len_r; r++) { for (uint8_t c = 0; c < len_c; c++) { float val = operator()(start_r + r, start_c + c); output.set(r, c, val); } } return output; } void matrix::set_block(uint8_t _r, uint8_t _c, matrix m) { STA_ASSERT_MSG(_r + m.get_rows() <= get_rows() && _c + m.get_cols() <= get_cols(), "set_block failed. Boundary conditions not initialized." ); for (uint8_t r = 0; r < m.get_rows(); r++) { for (uint8_t c = 0; c < m.get_cols(); c++) { set(_r + r, _c + c, m(r, c)); } } } void matrix::set(uint8_t r, uint8_t c, float v) { datafield[get_idx(r, c)] = v; } void matrix::set(uint16_t i, float v) { STA_ASSERT_MSG(i < get_size(), "Index out of Bounds" ); datafield[i] = v; } matrix matrix::get_submatrix(uint8_t _r, uint8_t _c) { matrix output = clone(); for (uint8_t r = 0; r < get_rows(); r++) { output.set(r, _c, 0); } for (uint8_t c = 0; c < get_cols(); c++) { output.set(_r, c, 0); } return output; } matrix matrix::eye(uint8_t s) { matrix output = matrix::zeros(s, s); for (uint8_t i = 0; i < s; i++) { output.set(i, i, 1); } return output; } matrix matrix::zeros(uint8_t r, uint8_t c) { matrix output(r, c); for (uint16_t i = 0; i < r * c; i++) { output.datafield[i] = 0; } return output; } matrix matrix::full(uint8_t r, uint8_t c, float v) { matrix output(r, c); for (uint16_t i = 0; i < r * c; i++) { output.datafield[i] = v; } return output; } float matrix::operator()(uint8_t r, uint8_t c) { float i = datafield[get_idx(r, c)]; return i; } float matrix::operator[](uint16_t i) { if (i > get_size()) { return 0; } return datafield[i]; } uint16_t matrix::get_idx(uint8_t r, uint8_t c) { STA_ASSERT_MSG(c < get_cols(), "Column index out of bounds get_idx"); STA_ASSERT_MSG(r < get_rows(), "Row index out of bounds get_idx"); return (r * get_cols()) + c; } matrix matrix::T() { matrix output(get_cols(), get_rows()); for (uint8_t r = 0; r < get_rows(); r++) { for (uint8_t c = 0; c < get_cols(); c++) { output.set(c, r, operator()(r, c)); } } return output; } matrix matrix::flatten() { matrix output = clone(); output.reshape(get_size(), 1); return output; } float matrix::minor(uint8_t r, uint8_t c) { matrix out(get_rows() - 1, get_cols() - 1); for (uint8_t row = 0; row < get_rows() - 1; row++) { for (uint8_t col = 0; col < get_cols() - 1; col++) { if (row < r && col < c) { out.set(row, col, operator()(row, col)); } else if (row >= r && col < c) { out.set(row, col, operator()(row + 1, col)); } else if (row < r && col >= c) { out.set(row, col, operator()(row, col + 1)); } else { out.set(row, col, operator()(row + 1, col + 1)); } } } return out.det(); } matrix matrix::operator*(float s) { return linalg::dot(*this, s); } matrix matrix::operator*(matrix m) { return linalg::dot(*this, m); } matrix matrix::operator+(matrix m) { return linalg::add(*this, m); } matrix matrix::operator-(matrix m) { return linalg::subtract(*this, m); } void matrix::show_serial() { show_shape(); for(uint8_t r = 0; r < get_rows(); r++) { for(uint8_t c = 0; c < get_cols(); c++) { STA_DEBUG_PRINT("| "); STA_DEBUG_PRINTF("%f", operator()(r, c)); if(c == get_cols() - 1) { STA_DEBUG_PRINTLN(" |"); } else { STA_DEBUG_PRINT(" "); } } } } void matrix::show_shape() { STA_DEBUG_PRINTF("Matrix shape: (%d x %d)\n", get_rows(), get_cols()); } }// namespace math }//namespace sta