Add: conviniece functions, Fix: Asserts and debug prints

This commit is contained in:
Milo Priegnitz 2024-06-02 20:31:43 +02:00
parent 7c8388ebb9
commit af40569b9d
3 changed files with 20 additions and 6 deletions

View File

@ -1,7 +1,6 @@
#ifndef INC_MATRIX_HPP_ #ifndef INC_MATRIX_HPP_
#define INC_MATRIX_HPP_ #define INC_MATRIX_HPP_
#include <cstdint> #include <cstdint>
namespace math namespace math
{ {
@ -39,6 +38,7 @@ struct matrix
static matrix eye(uint8_t); static matrix eye(uint8_t);
static matrix zeros(uint8_t, uint8_t); static matrix zeros(uint8_t, uint8_t);
static matrix full(uint8_t, uint8_t, float);
float operator()(uint8_t, uint8_t); float operator()(uint8_t, uint8_t);
float operator[](uint16_t); float operator[](uint16_t);
@ -57,4 +57,5 @@ struct matrix
} }
#endif /* INC_MATRIX_HPP_ */ #endif /* INC_MATRIX_HPP_ */

View File

@ -9,11 +9,11 @@ namespace math
KalmanFilter::KalmanFilter(matrix A, matrix B, matrix C, matrix Q, matrix R) : A_{A}, B_{B}, C_{C}, Q_{Q}, R_{R}, n_{A.get_cols()} KalmanFilter::KalmanFilter(matrix A, matrix B, matrix C, matrix Q, matrix R) : A_{A}, B_{B}, C_{C}, Q_{Q}, R_{R}, n_{A.get_cols()}
{ {
STA_ASSERT_MSG(A.get_rows() == B.get_rows(), "#rows mismatch: A, B!"); STA_ASSERT_MSG(A.get_rows() == B.get_rows(), "#rows mismatch: A, B!");
STA_ASSERT_MSG(A.get_rows() == C.get_rows(), "#rows mismatch: A, C!"); STA_ASSERT_MSG(A.get_cols() == C.get_cols(), "#cols mismatch: A, C!");
STA_ASSERT_MSG(A.get_rows() == Q.get_rows(), "#rows mismatch: A, Q!"); STA_ASSERT_MSG(A.get_rows() == Q.get_rows(), "#rows mismatch: A, Q!");
STA_ASSERT_MSG(A.get_cols() == Q.get_rows(), "Q not square!"); STA_ASSERT_MSG(A.get_cols() == Q.get_rows(), "Q not square!");
STA_ASSERT_MSG(C.get_rows() == R.get_rows(), "#rows mismatch: C, R"); STA_ASSERT_MSG(C.get_rows() == R.get_rows(), "#rows mismatch: C, R");
STA_ASSERT_MSG(R.get_cols() == R.get_rows(), "R not square!"); STA_ASSERT_MSG(R.get_rows() == R.get_cols(), "#R not square");
identity_ = matrix::eye(n_); identity_ = matrix::eye(n_);
} }
@ -37,8 +37,10 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z)
// Correct step implementation // Correct step implementation
// Calculate the Kalman gain // Calculate the Kalman gain
matrix K = state.error * C_.T() * linalg::inv(C_ * state.error * C_.T() + R_); matrix K = state.error * C_.T() * linalg::inv(C_ * state.error * C_.T() + R_);
K.show_serial();
// Update the state based on the measurement // Update the state based on the measurement
state.x = state.x + K * (z - C_ * state.x); //TODO check transpose state.x = state.x + K * (z - C_ * state.x); //TODO check transpose
state.x.show_serial();
// Update the error covariance matrix // Update the error covariance matrix
state.error = (identity_ - K * C_) * state.error; state.error = (identity_ - K * C_) * state.error;
return state; return state;

View File

@ -1,10 +1,10 @@
#include <sta/math/linalg/matrix.hpp> #include <sta/math/linalg/matrix.hpp>
#include <sta/math/linalg/linalg.hpp> #include <sta/math/linalg/linalg.hpp>
#include <cstdint> #include <cstdint>
#include <cstring>
#include <iostream> #include <iostream>
#include <sta/debug/debug.hpp> #include <sta/debug/debug.hpp>
#include <sta/debug/assert.hpp> #include <sta/debug/assert.hpp>
namespace math { namespace math {
matrix::matrix() { matrix::matrix() {
@ -302,9 +302,19 @@ matrix matrix::zeros(uint8_t r, uint8_t c) {
} }
float matrix::operator()(uint8_t r, uint8_t c) { matrix matrix::full(uint8_t r, uint8_t c, float v) {
return datafield[get_idx(r, c)]; matrix output(r, c);
for (uint16_t i = 0; i < r * c; i++) {
output.datafield[i] = v;
}
return output;
}
float matrix::operator()(uint8_t r, uint8_t c) {
float i = datafield[get_idx(r, c)];
return i;
} }
@ -415,6 +425,7 @@ void matrix::show_serial() {
} }
void matrix::show_shape() { void matrix::show_shape() {
STA_DEBUG_PRINTF("Matrix shape: (%d x %d)", get_rows(), get_cols()); STA_DEBUG_PRINTF("Matrix shape: (%d x %d)", get_rows(), get_cols());