From d5914f2d1ce9c03fb46cdae0cc70a99d90cd2f11 Mon Sep 17 00:00:00 2001 From: Milo Priegnitz Date: Mon, 15 Jul 2024 17:25:15 +0200 Subject: [PATCH] Optimize correct --- include/sta/math/algorithms/kalmanFilter.hpp | 2 +- src/algorithms/kalmanFilter.cpp | 18 ++++++++++-------- src/linalg/linalg.cpp | 5 ++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/include/sta/math/algorithms/kalmanFilter.hpp b/include/sta/math/algorithms/kalmanFilter.hpp index 1d76460..cbe65b6 100644 --- a/include/sta/math/algorithms/kalmanFilter.hpp +++ b/include/sta/math/algorithms/kalmanFilter.hpp @@ -78,7 +78,7 @@ public: * @param R The measurement noise covariance matrix. * @return The corrected state of the Kalman filter. */ - static KalmanState correct(KalmanState state, matrix z, matrix H, matrix R); + KalmanState correct(KalmanState state, matrix z, matrix H, matrix R); }; } // namespace math diff --git a/src/algorithms/kalmanFilter.cpp b/src/algorithms/kalmanFilter.cpp index 3242822..1ea9224 100644 --- a/src/algorithms/kalmanFilter.cpp +++ b/src/algorithms/kalmanFilter.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace sta { @@ -30,9 +31,9 @@ KalmanState KalmanFilter::predict(KalmanState state, matrix u) { // Predict step implementation // Update the state based on the system dynamics - state.x = F_ * state.x + B_ * u; + state.x = (F_ * state.x) + (B_ * u); // Update the error covariance matrix - state.error = F_ * state.error * F_.T() + Q_; + state.error = F_ *( state.error * F_.T()) + Q_; return state; } @@ -40,11 +41,13 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z) { // Correct step implementation // Calculate the Kalman gain - matrix K = state.error * H_.T() * linalg::inv(H_ * state.error * H_.T() + R_); - K.show_serial(); + matrix K; + { + + K = state.error * H_.T() * linalg::inv(H_ * state.error * H_.T() + R_); + } // Update the state based on the measurement state.x = state.x + K * (z - H_ * state.x); //TODO check transpose - state.x.show_serial(); // Update the error covariance matrix state.error = (identity_ - K * H_) * state.error; return state; @@ -55,12 +58,11 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z, matrix H, matrix // Correct step implementation // Calculate the Kalman gain matrix K = state.error * H.T() * linalg::inv(H * state.error * H.T() + R); - K.show_serial(); // Update the state based on the measurement state.x = state.x + K * (z - H * state.x); //TODO check transpose - state.x.show_serial(); // Update the error covariance matrix - state.error = (matrix::eye(state.x.get_rows()) - K * H) * state.error; + //state.error = (matrix::eye(state.x.get_rows()) - K * H) * state.error; + state.error = (identity_ - K * H) * state.error; return state; } diff --git a/src/linalg/linalg.cpp b/src/linalg/linalg.cpp index 4ee7eaa..3689546 100644 --- a/src/linalg/linalg.cpp +++ b/src/linalg/linalg.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace sta{ @@ -13,7 +14,6 @@ namespace linalg { matrix dot(matrix a, matrix b) { - STA_ASSERT_MSG(a.get_cols() == b.get_rows(), "Matrix dimension mismatch"); uint8_t k = a.get_cols(); @@ -119,7 +119,6 @@ matrix skew_symmetric(matrix m) { }; matrix add(matrix a, matrix b) { - STA_ASSERT_MSG( a.get_rows() == b.get_rows() && a.get_cols() == b.get_cols(), "Matrix dimensions mismatch!" ); matrix output = a.clone(); @@ -142,7 +141,7 @@ matrix subtract(matrix a, matrix b) { return output; }; matrix dot(matrix m, float s) { - + float size = m.get_size(); matrix output = m.clone();