From af40569b9d13e1e959a6d6041d05256ff27fc744 Mon Sep 17 00:00:00 2001 From: Milo Priegnitz Date: Sun, 2 Jun 2024 20:31:43 +0200 Subject: [PATCH] Add: conviniece functions, Fix: Asserts and debug prints --- include/sta/math/linalg/matrix.hpp | 3 ++- src/algorithms/kalmanFilter.cpp | 6 ++++-- src/linalg/matrix.cpp | 17 ++++++++++++++--- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/include/sta/math/linalg/matrix.hpp b/include/sta/math/linalg/matrix.hpp index 71a2f94..cf4b3af 100644 --- a/include/sta/math/linalg/matrix.hpp +++ b/include/sta/math/linalg/matrix.hpp @@ -1,7 +1,6 @@ #ifndef INC_MATRIX_HPP_ #define INC_MATRIX_HPP_ #include - namespace math { @@ -39,6 +38,7 @@ struct matrix static matrix eye(uint8_t); static matrix zeros(uint8_t, uint8_t); + static matrix full(uint8_t, uint8_t, float); float operator()(uint8_t, uint8_t); float operator[](uint16_t); @@ -57,4 +57,5 @@ struct matrix } + #endif /* INC_MATRIX_HPP_ */ diff --git a/src/algorithms/kalmanFilter.cpp b/src/algorithms/kalmanFilter.cpp index c84af68..8bd8b62 100644 --- a/src/algorithms/kalmanFilter.cpp +++ b/src/algorithms/kalmanFilter.cpp @@ -9,11 +9,11 @@ namespace math KalmanFilter::KalmanFilter(matrix A, matrix B, matrix C, matrix Q, matrix R) : A_{A}, B_{B}, C_{C}, Q_{Q}, R_{R}, n_{A.get_cols()} { STA_ASSERT_MSG(A.get_rows() == B.get_rows(), "#rows mismatch: A, B!"); - STA_ASSERT_MSG(A.get_rows() == C.get_rows(), "#rows mismatch: A, C!"); + STA_ASSERT_MSG(A.get_cols() == C.get_cols(), "#cols mismatch: A, C!"); STA_ASSERT_MSG(A.get_rows() == Q.get_rows(), "#rows mismatch: A, Q!"); STA_ASSERT_MSG(A.get_cols() == Q.get_rows(), "Q not square!"); STA_ASSERT_MSG(C.get_rows() == R.get_rows(), "#rows mismatch: C, R"); - STA_ASSERT_MSG(R.get_cols() == R.get_rows(), "R not square!"); + STA_ASSERT_MSG(R.get_rows() == R.get_cols(), "#R not square"); identity_ = matrix::eye(n_); } @@ -37,8 +37,10 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z) // Correct step implementation // Calculate the Kalman gain matrix K = state.error * C_.T() * linalg::inv(C_ * state.error * C_.T() + R_); + K.show_serial(); // Update the state based on the measurement state.x = state.x + K * (z - C_ * state.x); //TODO check transpose + state.x.show_serial(); // Update the error covariance matrix state.error = (identity_ - K * C_) * state.error; return state; diff --git a/src/linalg/matrix.cpp b/src/linalg/matrix.cpp index 78e9aca..811d901 100644 --- a/src/linalg/matrix.cpp +++ b/src/linalg/matrix.cpp @@ -1,10 +1,10 @@ #include #include #include +#include #include #include #include - namespace math { matrix::matrix() { @@ -302,9 +302,19 @@ matrix matrix::zeros(uint8_t r, uint8_t c) { } -float matrix::operator()(uint8_t r, uint8_t c) { +matrix matrix::full(uint8_t r, uint8_t c, float v) { - return datafield[get_idx(r, c)]; + matrix output(r, c); + for (uint16_t i = 0; i < r * c; i++) { + output.datafield[i] = v; + } + return output; + +} + +float matrix::operator()(uint8_t r, uint8_t c) { + float i = datafield[get_idx(r, c)]; + return i; } @@ -415,6 +425,7 @@ void matrix::show_serial() { } + void matrix::show_shape() { STA_DEBUG_PRINTF("Matrix shape: (%d x %d)", get_rows(), get_cols());