mirror of
https://git.intern.spaceteamaachen.de/ALPAKA/sta-peak.git
synced 2025-06-10 01:55:59 +00:00
Add matrix and kf code
This commit is contained in:
parent
78ed2e4eab
commit
290c48bfb1
16
README.md
16
README.md
@ -1,3 +1,15 @@
|
||||
# sta-peak
|
||||
# Performant Embedded Algebra Kit (PEAK)
|
||||
|
||||
|
||||
> ⚠️ **Warning:** WORK IN PROGRESS and UNTESTED
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
The Performant Embeddded Algebra Kit (PEAK) is a lightweight and easy-to-use library for performing various algebraic operations.
|
||||
|
||||
## Features
|
||||
|
||||
- Matrix operations: Perform matrix addition, subtraction, multiplication, and inversion.
|
||||
- Kalman Filter implementation
|
||||
|
||||
The Performant Embedded Algebra Kit (PEAK) provides structures, classes and functions for various algebraic operations.
|
38
include/sta/math/algorithms/kalmanFilter.hpp
Normal file
38
include/sta/math/algorithms/kalmanFilter.hpp
Normal file
@ -0,0 +1,38 @@
|
||||
#ifndef KALMAN_FILTER_HPP
|
||||
#define KALMAN_FILTER_HPP
|
||||
#include <sta/math/linalg/matrix.hpp>
|
||||
|
||||
namespace math
|
||||
{
|
||||
|
||||
struct KalmanState
|
||||
{
|
||||
matrix error;
|
||||
matrix x;
|
||||
};
|
||||
|
||||
class KalmanFilter
|
||||
{
|
||||
private:
|
||||
matrix A_;
|
||||
matrix B_;
|
||||
matrix C_;
|
||||
matrix Q_;
|
||||
matrix R_;
|
||||
uint8_t n_;
|
||||
matrix identity_;
|
||||
|
||||
|
||||
|
||||
public:
|
||||
KalmanFilter(matrix, matrix, matrix, matrix, matrix);
|
||||
~KalmanFilter();
|
||||
KalmanState predict(float, KalmanState, matrix);
|
||||
KalmanState correct(float,KalmanState, matrix);
|
||||
};
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif // KALMAN_FILTER_HPP
|
32
include/sta/math/linalg/linalg.hpp
Normal file
32
include/sta/math/linalg/linalg.hpp
Normal file
@ -0,0 +1,32 @@
|
||||
#ifndef INC_LINALG_HPP_
|
||||
#define INC_LINALG_HPP_
|
||||
#include <sta/math/linalg/matrix.hpp>
|
||||
|
||||
namespace math
|
||||
{
|
||||
namespace linalg
|
||||
{
|
||||
|
||||
matrix dot(matrix, matrix);
|
||||
float norm(matrix);
|
||||
matrix normalize(matrix);
|
||||
matrix cross(matrix, matrix);
|
||||
matrix skew_symmetric(matrix);
|
||||
matrix add(matrix, matrix);
|
||||
matrix subtract(matrix, matrix);
|
||||
matrix dot(matrix, float);
|
||||
matrix cof(matrix);
|
||||
matrix adj(matrix);
|
||||
|
||||
matrix inv(matrix);
|
||||
|
||||
matrix inv_adj(matrix);
|
||||
matrix inv_char_poly(matrix);
|
||||
matrix inv_schur_dec(matrix);
|
||||
|
||||
matrix _inv_char_poly_3x3(matrix);
|
||||
matrix _inv_char_poly_2x2(matrix);
|
||||
|
||||
}
|
||||
}
|
||||
#endif /* INC_LINALG_HPP_ */
|
60
include/sta/math/linalg/matrix.hpp
Normal file
60
include/sta/math/linalg/matrix.hpp
Normal file
@ -0,0 +1,60 @@
|
||||
#ifndef INC_MATRIX_HPP_
|
||||
#define INC_MATRIX_HPP_
|
||||
#include <cstdint>
|
||||
|
||||
namespace math
|
||||
{
|
||||
|
||||
struct matrix
|
||||
{
|
||||
|
||||
float * datafield = nullptr;
|
||||
uint8_t * shape = nullptr;
|
||||
|
||||
matrix();
|
||||
matrix(const matrix&);
|
||||
matrix(uint8_t, uint8_t);
|
||||
matrix(uint8_t, uint8_t, float*);
|
||||
~matrix();
|
||||
|
||||
bool is_valid();
|
||||
|
||||
uint16_t get_size();
|
||||
uint8_t get_rows();
|
||||
uint8_t get_cols();
|
||||
|
||||
matrix clone();
|
||||
void show_serial();
|
||||
void show_shape();
|
||||
|
||||
matrix& operator=(matrix);
|
||||
void reshape(uint8_t, uint8_t);
|
||||
|
||||
float det();
|
||||
matrix get_block(uint8_t, uint8_t, uint8_t, uint8_t);
|
||||
void set_block(uint8_t, uint8_t, matrix);
|
||||
void set(uint8_t, uint8_t, float);
|
||||
void set(uint16_t, float);
|
||||
matrix get_submatrix(uint8_t, uint8_t);
|
||||
|
||||
static matrix eye(uint8_t);
|
||||
static matrix zeros(uint8_t, uint8_t);
|
||||
|
||||
float operator()(uint8_t, uint8_t);
|
||||
float operator[](uint16_t);
|
||||
uint16_t get_idx(uint8_t, uint8_t);
|
||||
|
||||
matrix T();
|
||||
matrix flatten();
|
||||
float minor(uint8_t, uint8_t);
|
||||
|
||||
matrix operator*(float);
|
||||
matrix operator*(matrix);
|
||||
matrix operator+(matrix);
|
||||
matrix operator-(matrix);
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
#endif /* INC_MATRIX_HPP_ */
|
10
include/sta/math/utils.hpp
Normal file
10
include/sta/math/utils.hpp
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef INC_UTILS_HPP_
|
||||
#define INC_UTILS_HPP_
|
||||
namespace math
|
||||
{
|
||||
float fast_inv_sqrt(float);
|
||||
|
||||
|
||||
} // namespace stamath
|
||||
|
||||
#endif /* INC_UTILS_HPP_ */
|
11
library.json
Normal file
11
library.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"owner" : "sta",
|
||||
"name": "sta-peak",
|
||||
"version": "0.1.0",
|
||||
"dependencies": [
|
||||
{
|
||||
"url": "git@gitlab.com:sta-git/alpaka/sta-core.git",
|
||||
"ref": "main"
|
||||
}
|
||||
]
|
||||
}
|
46
src/algorithms/kalmanFilter.cpp
Normal file
46
src/algorithms/kalmanFilter.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
#include <sta/math/algorithms/kalmanFilter.hpp>
|
||||
#include <sta/math/linalg/linalg.hpp>
|
||||
#include <sta/debug/debug.hpp>
|
||||
#include <sta/debug/assert.hpp>
|
||||
|
||||
namespace math
|
||||
{
|
||||
|
||||
KalmanFilter::KalmanFilter(matrix A, matrix B, matrix C, matrix Q, matrix R) : A_{A}, B_{B}, C_{C}, Q_{Q}, R_{R}, n_{A.get_cols()}
|
||||
{
|
||||
STA_ASSERT_MSG(A.get_rows() == B.get_rows(), "#rows mismatch: A, B!");
|
||||
STA_ASSERT_MSG(A.get_rows() == C.get_rows(), "#rows mismatch: A, C!");
|
||||
STA_ASSERT_MSG(A.get_rows() == Q.get_rows(), "#rows mismatch: A, Q!");
|
||||
STA_ASSERT_MSG(A.get_cols() == Q.get_rows(), "Q not square!");
|
||||
STA_ASSERT_MSG(C.get_rows() == R.get_rows(), "#rows mismatch: C, R");
|
||||
STA_ASSERT_MSG(R.get_cols() == R.get_rows(), "R not square!");
|
||||
identity_ = matrix::eye(n_);
|
||||
}
|
||||
|
||||
KalmanFilter::~KalmanFilter()
|
||||
{
|
||||
// Destructor implementation
|
||||
}
|
||||
|
||||
KalmanState KalmanFilter::predict(float dt, KalmanState state, matrix u)
|
||||
{
|
||||
// Predict step implementation
|
||||
// Update the state based on the system dynamics
|
||||
state.x = A_ * state.x + B_ * u;
|
||||
// Update the error covariance matrix
|
||||
state.error = A_ * state.error * A_.T() + Q_;
|
||||
return state;
|
||||
}
|
||||
|
||||
KalmanState KalmanFilter::correct(float dt, KalmanState state, matrix z)
|
||||
{
|
||||
// Correct step implementation
|
||||
// Calculate the Kalman gain
|
||||
matrix K = state.error * C_.T() * linalg::inv(C_ * state.error * C_.T() + R_);
|
||||
// Update the state based on the measurement
|
||||
state.x = state.x + K * (z - C_ * state.x); //TODO check transpose
|
||||
// Update the error covariance matrix
|
||||
state.error = (identity_ - K * C_) * state.error;
|
||||
return state;
|
||||
}
|
||||
}
|
352
src/linalg/linalg.cpp
Normal file
352
src/linalg/linalg.cpp
Normal file
@ -0,0 +1,352 @@
|
||||
#include <sta/math/linalg/linalg.hpp>
|
||||
#include <sta/math/utils.hpp>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include <sta/debug/debug.hpp>
|
||||
#include <sta/debug/assert.hpp>
|
||||
|
||||
namespace math {
|
||||
|
||||
namespace linalg {
|
||||
|
||||
|
||||
matrix dot(matrix a, matrix b) {
|
||||
|
||||
STA_ASSERT_MSG(a.get_cols() == b.get_rows(), "Matrix dimension mismatch");
|
||||
|
||||
uint8_t k = a.get_cols();
|
||||
uint8_t m = a.get_rows();
|
||||
uint8_t n = b.get_cols();
|
||||
matrix output(m, n);
|
||||
|
||||
for (uint8_t r = 0; r < m; r++) {
|
||||
for (uint8_t c = 0; c < n; c++) {
|
||||
|
||||
|
||||
float S = 0;
|
||||
|
||||
for (uint8_t h = 0; h < k; h++) {
|
||||
|
||||
S += a(r, h) * b(h, c);
|
||||
|
||||
}
|
||||
|
||||
output.set(r, c, S);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
};
|
||||
float norm(matrix m) {
|
||||
|
||||
if( m.get_rows() == 1 || m.get_cols() == 1 ){
|
||||
|
||||
// apply euclid norm on vector
|
||||
|
||||
uint16_t size = m.get_size();
|
||||
float S = 0;
|
||||
for(uint8_t i = 0; i < size; i++) {
|
||||
S += m[i] * m[i];
|
||||
}
|
||||
|
||||
float s = sqrt(S);
|
||||
return s;
|
||||
|
||||
}
|
||||
|
||||
// todo: implement different matrix norms
|
||||
|
||||
return 0;
|
||||
|
||||
};
|
||||
matrix normalize(matrix m) {
|
||||
|
||||
if( m.get_rows() == 1 || m.get_cols() == 1 ){
|
||||
|
||||
// apply euclid normalization to vector
|
||||
|
||||
uint16_t size = m.get_size();
|
||||
float S = 0;
|
||||
for(uint8_t i = 0; i < size; i++) {
|
||||
S += m[i] * m[i];
|
||||
}
|
||||
|
||||
float s = fast_inv_sqrt(S);
|
||||
return m * s;
|
||||
|
||||
}
|
||||
|
||||
// TODO: implement different matrix normalization techniques
|
||||
|
||||
return m * (1/norm(m));
|
||||
|
||||
};
|
||||
matrix cross(matrix a, matrix b) {
|
||||
|
||||
STA_ASSERT_MSG(a.get_size() == 3 && b.get_size() == 3, "Input Vectors need to be 3 long");
|
||||
|
||||
float d[] = {
|
||||
a[1]*b[2] - a[2]*b[1],
|
||||
a[2]*b[0] - a[0]*b[2],
|
||||
a[0]*b[1] - a[1]*b[0]
|
||||
};
|
||||
|
||||
matrix out(3, 1, d);
|
||||
return out;
|
||||
|
||||
};
|
||||
matrix skew_symmetric(matrix m) {
|
||||
|
||||
STA_ASSERT_MSG( m.get_rows() == 1 && m.get_cols() == 1 , "Input vectors not a vector!");
|
||||
|
||||
STA_ASSERT_MSG( m.get_size() == 3, "Input vector needs to be of size 3!");
|
||||
|
||||
float d[] = {
|
||||
0, -m[2], m[1],
|
||||
m[2], 0, -m[0],
|
||||
-m[1], m[0], 0
|
||||
};
|
||||
|
||||
matrix output(3, 3, d);
|
||||
return output;
|
||||
|
||||
};
|
||||
matrix add(matrix a, matrix b) {
|
||||
|
||||
STA_ASSERT_MSG( a.get_rows() == b.get_rows() && a.get_cols() == b.get_cols(), "Matrix dimensions mismatch!" );
|
||||
|
||||
matrix output = a.clone();
|
||||
uint16_t size = a.get_size();
|
||||
for (uint16_t i = 0; i < size; i++) {
|
||||
output.datafield[i] += b.datafield[i];
|
||||
}
|
||||
|
||||
return output;
|
||||
};
|
||||
matrix subtract(matrix a, matrix b) {
|
||||
|
||||
STA_ASSERT_MSG( a.get_rows() == b.get_rows() && a.get_cols() == b.get_cols(), "Matrix dimensions mismatch!" );
|
||||
matrix output = a.clone();
|
||||
uint16_t size = a.get_size();
|
||||
for (uint16_t i = 0; i < size; i++) {
|
||||
output.datafield[i] -= b.datafield[i];
|
||||
}
|
||||
|
||||
return output;
|
||||
};
|
||||
matrix dot(matrix m, float s) {
|
||||
|
||||
float size = m.get_size();
|
||||
|
||||
matrix output = m.clone();
|
||||
|
||||
for(uint8_t i = 0; i < size; i++) {
|
||||
|
||||
output.datafield[i] *= s;
|
||||
|
||||
}
|
||||
|
||||
return output;
|
||||
|
||||
};
|
||||
matrix cof(matrix m) {
|
||||
|
||||
uint8_t rows = m.get_rows();
|
||||
uint8_t cols = m.get_cols();
|
||||
|
||||
matrix output(rows, cols);
|
||||
|
||||
for (uint8_t r = 0; r < rows; r++) {
|
||||
for (uint8_t c = 0; c < cols; c++) {
|
||||
|
||||
float cof;
|
||||
|
||||
if( (r+c) % 2 == 0 ) {
|
||||
cof = 1;
|
||||
} else {
|
||||
cof = -1;
|
||||
}
|
||||
cof *= m.minor(r, c);
|
||||
|
||||
output.set(r, c, cof);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
};
|
||||
|
||||
matrix adj(matrix m) {
|
||||
|
||||
matrix output = cof(m).T();
|
||||
return output;
|
||||
|
||||
};
|
||||
|
||||
matrix inv(matrix m) {
|
||||
|
||||
STA_ASSERT_MSG( m.get_cols() == m.get_rows(), "Matrix not square. Inverse not valid" );
|
||||
|
||||
uint8_t size = m.get_cols();
|
||||
|
||||
if(size == 1) {
|
||||
matrix output = m.clone();
|
||||
output.set(0, 0, 1/output(0, 0));
|
||||
return output;
|
||||
}
|
||||
|
||||
if(size == 2) {
|
||||
//return inv_adj(m);
|
||||
return _inv_char_poly_2x2(m);
|
||||
}
|
||||
|
||||
if(size == 3) {
|
||||
return inv_adj(m);
|
||||
}
|
||||
|
||||
if(size % 2 == 0) {
|
||||
return inv_schur_dec(m);
|
||||
}
|
||||
|
||||
return inv_adj(m);
|
||||
|
||||
};
|
||||
|
||||
|
||||
matrix inv_adj(matrix m) {
|
||||
|
||||
STA_ASSERT_MSG( m.get_cols() == m.get_rows(), "Matrix not square. Inverse not valid" );
|
||||
|
||||
float d = m.det();
|
||||
|
||||
STA_ASSERT_MSG( d!=0, "Matrix is singular. No inverse could be computed." );
|
||||
d = 1/d;
|
||||
|
||||
matrix a = adj(m);
|
||||
//a.show_serial();
|
||||
|
||||
return a * d;
|
||||
|
||||
};
|
||||
|
||||
matrix inv_char_poly(matrix m) {
|
||||
|
||||
STA_ASSERT_MSG( m.get_cols() == m.get_rows(), "Matrix not square. Inverse not valid" );
|
||||
|
||||
uint8_t size = m.get_cols();
|
||||
|
||||
if( size == 2 ) {
|
||||
|
||||
return _inv_char_poly_2x2(m);
|
||||
|
||||
}
|
||||
|
||||
if( size == 3 ) {
|
||||
|
||||
return _inv_char_poly_3x3(m);
|
||||
|
||||
}
|
||||
|
||||
// revert to different inv function, if matrix size is not correct
|
||||
|
||||
return inv(m);
|
||||
|
||||
};
|
||||
|
||||
matrix inv_schur_dec(matrix m) {
|
||||
|
||||
uint8_t rows = m.get_rows();
|
||||
uint8_t cols = m.get_cols();
|
||||
|
||||
STA_ASSERT_MSG( cols == rows, "Matrix not square. Inverse not valid" );
|
||||
|
||||
if( cols % 2 != 0) {
|
||||
// matrix size not integer, function cant be applied.
|
||||
return inv(m);
|
||||
}
|
||||
|
||||
float det = m.det();
|
||||
if(det == 0) {
|
||||
STA_DEBUG_PRINTLN("Matrix is singular. No inverse could be computed. returned identity");
|
||||
return matrix();
|
||||
}
|
||||
|
||||
uint8_t sub_size = cols/2;
|
||||
|
||||
matrix M_inv(cols, cols);
|
||||
|
||||
matrix A = m.get_block(0, 0, sub_size, sub_size);
|
||||
matrix B = m.get_block(0, sub_size, sub_size, sub_size);
|
||||
matrix C = m.get_block(sub_size, 0, sub_size, sub_size);
|
||||
matrix D = m.get_block(sub_size, sub_size, sub_size, sub_size);
|
||||
|
||||
matrix D_inv = inv(D);
|
||||
matrix M_D = A - (B * (D_inv * C));
|
||||
matrix M_D_inv = inv(M_D);
|
||||
|
||||
if(!D_inv.is_valid() || !M_D_inv.is_valid()) {
|
||||
return matrix();
|
||||
}
|
||||
|
||||
matrix _new_B = ((M_D_inv * (B * D_inv)) * -1 );
|
||||
matrix _new_C = ((D_inv * (C * M_D_inv)) * -1 );
|
||||
matrix _new_D = D_inv + (D_inv * (C * (M_D_inv * (B * D_inv))) );
|
||||
|
||||
M_inv.set_block(0, 0, M_D_inv);
|
||||
M_inv.set_block(0, sub_size, _new_B);
|
||||
M_inv.set_block(sub_size, 0, _new_C);
|
||||
M_inv.set_block(sub_size, sub_size, _new_D);
|
||||
|
||||
return M_inv;
|
||||
|
||||
|
||||
|
||||
};
|
||||
|
||||
matrix _inv_char_poly_3x3(matrix m) {
|
||||
|
||||
float det = m.det();
|
||||
|
||||
if(det == 0) {
|
||||
// matrix is singular. Inverse is invalid
|
||||
STA_DEBUG_PRINTLN("Matrix is singular. No inverse could be computed. returned identity");
|
||||
return matrix();
|
||||
}
|
||||
|
||||
float a0 = -1/det;
|
||||
float a1 = (m(2, 1) * m(1, 2)) + (m(1, 0) * m(0, 1)) + (m(2, 0) * m(0, 2)) - (m(0, 0) * m(1, 1)) - (m(0, 0) * m(2, 2)) - (m(1, 1) * m(2, 2));
|
||||
float a2 = m(0, 0) + m(1, 1) + m(2, 2);
|
||||
|
||||
matrix M_2 = m * m;
|
||||
|
||||
matrix out = ((matrix::eye(3) * a1 ) + (m * a2) - M_2) * a0;
|
||||
return out;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
matrix _inv_char_poly_2x2(matrix m) {
|
||||
|
||||
float a0 = (m(0, 0) * m(1, 1)) - (m(1, 0) * m(0, 1));
|
||||
float a1 = - m(0, 0) - m(1, 1);
|
||||
|
||||
if(a0 == 0) {
|
||||
STA_DEBUG_PRINTLN("matrix is singular. No inverse could be computed. returned identity");
|
||||
return matrix();
|
||||
}
|
||||
|
||||
float fac = -1/a0;
|
||||
|
||||
matrix I = matrix::eye(2);
|
||||
|
||||
return linalg::dot( linalg::add( linalg::dot(I, a1), m ) , fac );
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
424
src/linalg/matrix.cpp
Normal file
424
src/linalg/matrix.cpp
Normal file
@ -0,0 +1,424 @@
|
||||
#include <sta/math/linalg/matrix.hpp>
|
||||
#include <sta/math/linalg/linalg.hpp>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <sta/debug/debug.hpp>
|
||||
#include <sta/debug/assert.hpp>
|
||||
|
||||
namespace math {
|
||||
|
||||
matrix::matrix() {
|
||||
|
||||
datafield = nullptr;
|
||||
shape = nullptr;
|
||||
|
||||
}
|
||||
|
||||
matrix::matrix(const matrix &m) {
|
||||
|
||||
if (shape != nullptr) {
|
||||
|
||||
shape[0] -= 1;
|
||||
if (shape[0] <= 0) {
|
||||
free(datafield);
|
||||
free(shape);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
datafield = m.datafield;
|
||||
shape = m.shape;
|
||||
shape[0] += 1;
|
||||
|
||||
}
|
||||
|
||||
matrix::matrix(uint8_t rows, uint8_t cols) {
|
||||
|
||||
uint16_t size = rows * cols;
|
||||
datafield = (float *) malloc((sizeof(float) * size));
|
||||
shape = (uint8_t *) malloc(sizeof(uint8_t) * 4);
|
||||
|
||||
shape[0] = 1;
|
||||
shape[1] = rows;
|
||||
shape[2] = cols;
|
||||
shape[3] = rows * cols;
|
||||
|
||||
}
|
||||
|
||||
matrix::matrix(uint8_t rows, uint8_t cols, float *vals) {
|
||||
|
||||
uint16_t size = rows * cols;
|
||||
datafield = (float *) malloc((sizeof(float) * size));
|
||||
shape = (uint8_t *) malloc(sizeof(uint8_t) * 4);
|
||||
|
||||
shape[0] = 1;
|
||||
shape[1] = rows;
|
||||
shape[2] = cols;
|
||||
shape[3] = rows * cols;
|
||||
|
||||
for (uint16_t i = 0; i < size; i++) {
|
||||
datafield[i] = vals[i];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
matrix::~matrix() {
|
||||
|
||||
if (shape != nullptr) {
|
||||
|
||||
shape[0] -= 1;
|
||||
if (shape[0] <= 0) {
|
||||
free(datafield);
|
||||
free(shape);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool matrix::is_valid() {
|
||||
if (shape == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint16_t matrix::get_size() {
|
||||
STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr");
|
||||
return shape[3];
|
||||
}
|
||||
|
||||
uint8_t matrix::get_rows() {
|
||||
STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr");
|
||||
return shape[1];
|
||||
}
|
||||
|
||||
uint8_t matrix::get_cols() {
|
||||
STA_ASSERT_MSG(shape != nullptr, "Shape is nullptr");
|
||||
return shape[2];
|
||||
}
|
||||
|
||||
matrix matrix::clone() {
|
||||
|
||||
matrix m(get_rows(), get_cols());
|
||||
|
||||
uint16_t size = get_size();
|
||||
for (uint16_t i = 0; i < size; i++) {
|
||||
m.datafield[i] = datafield[i];
|
||||
}
|
||||
|
||||
return m;
|
||||
|
||||
}
|
||||
|
||||
matrix &matrix::operator=(matrix m) {
|
||||
|
||||
if (shape != nullptr) {
|
||||
|
||||
shape[0] -= 1;
|
||||
if (shape[0] <= 0) {
|
||||
free(datafield);
|
||||
free(shape);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
datafield = m.datafield;
|
||||
shape = m.shape;
|
||||
shape[0] += 1;
|
||||
return *this;
|
||||
|
||||
}
|
||||
|
||||
void matrix::reshape(uint8_t r, uint8_t c) {
|
||||
|
||||
STA_ASSERT_MSG(r * c == get_size(), "New shape does not match old shape");
|
||||
shape[1] = r;
|
||||
shape[2] = c;
|
||||
shape[3] = r * c;
|
||||
|
||||
}
|
||||
|
||||
float matrix::det() {
|
||||
|
||||
uint8_t rows = get_rows();
|
||||
uint8_t cols = get_cols();
|
||||
|
||||
STA_ASSERT_MSG(rows == cols && rows >= 1, "Matrix is not square. Determinant can not be computed." );
|
||||
|
||||
if (rows == 1) {
|
||||
return datafield[0];
|
||||
}
|
||||
|
||||
if (rows == 2) {
|
||||
return (operator()(0, 0) * operator()(1, 1)) - (operator()(1, 0) * operator()(0, 1));
|
||||
}
|
||||
|
||||
if (rows == 3) {
|
||||
float S = 0;
|
||||
S += operator()(0, 0) * ((operator()(1, 1) * operator()(2, 2)) - (operator()(2, 1) * operator()(1, 2)));
|
||||
S -= operator()(0, 1) * ((operator()(1, 0) * operator()(2, 2)) - (operator()(2, 0) * operator()(1, 2)));
|
||||
S += operator()(0, 2) * ((operator()(1, 0) * operator()(2, 1)) - (operator()(2, 0) * operator()(1, 1)));
|
||||
return S;
|
||||
}
|
||||
|
||||
if (rows > 3) {
|
||||
|
||||
float determinant = 0;
|
||||
|
||||
for (uint8_t k = 0; k < rows; k++) {
|
||||
|
||||
matrix submatrix(rows - 1, rows - 1);
|
||||
uint8_t i = 0;
|
||||
|
||||
for (uint8_t x = 0; x < rows; x++) {
|
||||
|
||||
for (uint8_t y = 1; y < cols; y++) {
|
||||
|
||||
float val = operator()(x, y);
|
||||
|
||||
if (x != k) {
|
||||
submatrix.set(i, y - 1, val);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (x != k) {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
float new_determinant = submatrix.det() * operator()(k, 0);
|
||||
if (k % 2 == 1) {
|
||||
new_determinant *= -1;
|
||||
}
|
||||
|
||||
determinant += new_determinant;
|
||||
|
||||
|
||||
}
|
||||
|
||||
return determinant;
|
||||
|
||||
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::get_block(uint8_t start_r, uint8_t start_c, uint8_t len_r, uint8_t len_c) {
|
||||
|
||||
//matrix output(len_r, len_c);
|
||||
matrix output = matrix::zeros(len_r, len_c);
|
||||
|
||||
STA_ASSERT_MSG(start_r + len_r <= get_rows() && start_c + len_c <= get_cols(), "get_block failed. Boundary conditions not initialized." );
|
||||
|
||||
for (uint8_t r = 0; r < len_r; r++) {
|
||||
|
||||
for (uint8_t c = 0; c < len_c; c++) {
|
||||
|
||||
float val = operator()(start_r + r, start_c + c);
|
||||
output.set(r, c, val);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
void matrix::set_block(uint8_t _r, uint8_t _c, matrix m) {
|
||||
|
||||
STA_ASSERT_MSG(_r + m.get_rows() <= get_rows() && _c + m.get_cols() <= get_cols(), "set_block failed. Boundary conditions not initialized." );
|
||||
|
||||
for (uint8_t r = 0; r < m.get_rows(); r++) {
|
||||
|
||||
for (uint8_t c = 0; c < m.get_cols(); c++) {
|
||||
|
||||
set(_r + r, _c + c, m(r, c));
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
void matrix::set(uint8_t r, uint8_t c, float v) {
|
||||
|
||||
datafield[get_idx(r, c)] = v;
|
||||
|
||||
}
|
||||
|
||||
void matrix::set(uint16_t i, float v) {
|
||||
|
||||
STA_ASSERT_MSG(i < get_size(), "Index out of Bounds" );
|
||||
|
||||
datafield[i] = v;
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::get_submatrix(uint8_t _r, uint8_t _c) {
|
||||
|
||||
matrix output = clone();
|
||||
|
||||
for (uint8_t r = 0; r < get_rows(); r++) {
|
||||
|
||||
output.set(r, _c, 0);
|
||||
|
||||
}
|
||||
|
||||
for (uint8_t c = 0; c < get_cols(); c++) {
|
||||
|
||||
output.set(_r, c, 0);
|
||||
|
||||
}
|
||||
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::eye(uint8_t s) {
|
||||
|
||||
matrix output = matrix::zeros(s, s);
|
||||
for (uint8_t i = 0; i < s; i++) {
|
||||
output.set(i, i, 1);
|
||||
}
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::zeros(uint8_t r, uint8_t c) {
|
||||
|
||||
matrix output(r, c);
|
||||
for (uint16_t i = 0; i < r * c; i++) {
|
||||
output.datafield[i] = 0;
|
||||
}
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
float matrix::operator()(uint8_t r, uint8_t c) {
|
||||
|
||||
return datafield[get_idx(r, c)];
|
||||
|
||||
}
|
||||
|
||||
float matrix::operator[](uint16_t i) {
|
||||
if (i > get_size()) {
|
||||
return 0;
|
||||
}
|
||||
return datafield[i];
|
||||
|
||||
}
|
||||
|
||||
uint16_t matrix::get_idx(uint8_t r, uint8_t c) {
|
||||
STA_ASSERT_MSG(r * c <= get_size(), "Index out of bounds get_idx");
|
||||
|
||||
return (r * get_cols()) + c;
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::T() {
|
||||
|
||||
matrix output(get_cols(), get_rows());
|
||||
|
||||
for (uint8_t r = 0; r < get_rows(); r++) {
|
||||
for (uint8_t c = 0; c < get_cols(); c++) {
|
||||
|
||||
output.set(c, r, operator()(r, c));
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::flatten() {
|
||||
|
||||
matrix output = clone();
|
||||
output.reshape(get_size(), 1);
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
float matrix::minor(uint8_t r, uint8_t c) {
|
||||
|
||||
matrix out(get_rows() - 1, get_cols() - 1);
|
||||
|
||||
|
||||
for (uint8_t row = 0; row < get_rows() - 1; row++) {
|
||||
for (uint8_t col = 0; col < get_cols() - 1; col++) {
|
||||
|
||||
|
||||
if (row < r && col < c) {
|
||||
out.set(row, col, operator()(row, col));
|
||||
} else if (row >= r && col < c) {
|
||||
out.set(row, col, operator()(row + 1, col));
|
||||
} else if (row < r && col >= c) {
|
||||
out.set(row, col, operator()(row, col + 1));
|
||||
} else {
|
||||
out.set(row, col, operator()(row + 1, col + 1));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
return out.det();
|
||||
|
||||
|
||||
}
|
||||
|
||||
matrix matrix::operator*(float s) {
|
||||
return linalg::dot(*this, s);
|
||||
}
|
||||
|
||||
matrix matrix::operator*(matrix m) {
|
||||
return linalg::dot(*this, m);
|
||||
}
|
||||
|
||||
matrix matrix::operator+(matrix m) {
|
||||
return linalg::add(*this, m);
|
||||
}
|
||||
|
||||
matrix matrix::operator-(matrix m) {
|
||||
return linalg::subtract(*this, m);
|
||||
}
|
||||
|
||||
|
||||
void matrix::show_serial() {
|
||||
|
||||
show_shape();
|
||||
|
||||
for(uint8_t r = 0; r < get_rows(); r++) {
|
||||
|
||||
for(uint8_t c = 0; c < get_cols(); c++) {
|
||||
|
||||
STA_DEBUG_PRINT("| ");
|
||||
STA_DEBUG_PRINT(operator()(r, c));
|
||||
if(c == get_cols() - 1) {
|
||||
STA_DEBUG_PRINTLN(" |");
|
||||
} else {
|
||||
STA_DEBUG_PRINT(" ");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void matrix::show_shape() {
|
||||
|
||||
STA_DEBUG_PRINTF("Matrix shape: (%d x %d)", get_rows(), get_cols());
|
||||
|
||||
}
|
||||
|
||||
}
|
26
src/utils.cpp
Normal file
26
src/utils.cpp
Normal file
@ -0,0 +1,26 @@
|
||||
#include <sta/math/utils.hpp>
|
||||
#include <cstdint>
|
||||
|
||||
namespace math
|
||||
{
|
||||
|
||||
|
||||
float fast_inv_sqrt(float v) {
|
||||
|
||||
long i;
|
||||
float x2, y;
|
||||
const float threehalfs = 1.5f;
|
||||
|
||||
y = v;
|
||||
x2 = y * 0.5f;
|
||||
i = * (long*)&y;
|
||||
i = 0x5f3759df - (i >> 1);
|
||||
y = *(float *) &i;
|
||||
y = y * (threehalfs - (x2 * y * y));
|
||||
//y = y * (threehalfs - (x2 * y * y));
|
||||
|
||||
return y;
|
||||
|
||||
}
|
||||
|
||||
} // namespace stamath
|
Loading…
x
Reference in New Issue
Block a user