Optimize correct

This commit is contained in:
Milo Priegnitz 2024-07-15 17:25:15 +02:00 committed by dario
parent ee10ce1189
commit d5914f2d1c
3 changed files with 13 additions and 12 deletions

View File

@ -78,7 +78,7 @@ public:
* @param R The measurement noise covariance matrix. * @param R The measurement noise covariance matrix.
* @return The corrected state of the Kalman filter. * @return The corrected state of the Kalman filter.
*/ */
static KalmanState correct(KalmanState state, matrix z, matrix H, matrix R); KalmanState correct(KalmanState state, matrix z, matrix H, matrix R);
}; };
} // namespace math } // namespace math

View File

@ -2,6 +2,7 @@
#include <sta/math/linalg/linalg.hpp> #include <sta/math/linalg/linalg.hpp>
#include <sta/debug/debug.hpp> #include <sta/debug/debug.hpp>
#include <sta/debug/assert.hpp> #include <sta/debug/assert.hpp>
#include <sta/debug/profile.hpp>
namespace sta namespace sta
{ {
@ -30,9 +31,9 @@ KalmanState KalmanFilter::predict(KalmanState state, matrix u)
{ {
// Predict step implementation // Predict step implementation
// Update the state based on the system dynamics // Update the state based on the system dynamics
state.x = F_ * state.x + B_ * u; state.x = (F_ * state.x) + (B_ * u);
// Update the error covariance matrix // Update the error covariance matrix
state.error = F_ * state.error * F_.T() + Q_; state.error = F_ *( state.error * F_.T()) + Q_;
return state; return state;
} }
@ -40,11 +41,13 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z)
{ {
// Correct step implementation // Correct step implementation
// Calculate the Kalman gain // Calculate the Kalman gain
matrix K = state.error * H_.T() * linalg::inv(H_ * state.error * H_.T() + R_); matrix K;
K.show_serial(); {
K = state.error * H_.T() * linalg::inv(H_ * state.error * H_.T() + R_);
}
// Update the state based on the measurement // Update the state based on the measurement
state.x = state.x + K * (z - H_ * state.x); //TODO check transpose state.x = state.x + K * (z - H_ * state.x); //TODO check transpose
state.x.show_serial();
// Update the error covariance matrix // Update the error covariance matrix
state.error = (identity_ - K * H_) * state.error; state.error = (identity_ - K * H_) * state.error;
return state; return state;
@ -55,12 +58,11 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z, matrix H, matrix
// Correct step implementation // Correct step implementation
// Calculate the Kalman gain // Calculate the Kalman gain
matrix K = state.error * H.T() * linalg::inv(H * state.error * H.T() + R); matrix K = state.error * H.T() * linalg::inv(H * state.error * H.T() + R);
K.show_serial();
// Update the state based on the measurement // Update the state based on the measurement
state.x = state.x + K * (z - H * state.x); //TODO check transpose state.x = state.x + K * (z - H * state.x); //TODO check transpose
state.x.show_serial();
// Update the error covariance matrix // Update the error covariance matrix
state.error = (matrix::eye(state.x.get_rows()) - K * H) * state.error; //state.error = (matrix::eye(state.x.get_rows()) - K * H) * state.error;
state.error = (identity_ - K * H) * state.error;
return state; return state;
} }

View File

@ -4,6 +4,7 @@
#include <cmath> #include <cmath>
#include <sta/debug/debug.hpp> #include <sta/debug/debug.hpp>
#include <sta/debug/assert.hpp> #include <sta/debug/assert.hpp>
#include <sta/debug/profile.hpp>
namespace sta{ namespace sta{
@ -13,7 +14,6 @@ namespace linalg {
matrix dot(matrix a, matrix b) { matrix dot(matrix a, matrix b) {
STA_ASSERT_MSG(a.get_cols() == b.get_rows(), "Matrix dimension mismatch"); STA_ASSERT_MSG(a.get_cols() == b.get_rows(), "Matrix dimension mismatch");
uint8_t k = a.get_cols(); uint8_t k = a.get_cols();
@ -119,7 +119,6 @@ matrix skew_symmetric(matrix m) {
}; };
matrix add(matrix a, matrix b) { matrix add(matrix a, matrix b) {
STA_ASSERT_MSG( a.get_rows() == b.get_rows() && a.get_cols() == b.get_cols(), "Matrix dimensions mismatch!" ); STA_ASSERT_MSG( a.get_rows() == b.get_rows() && a.get_cols() == b.get_cols(), "Matrix dimensions mismatch!" );
matrix output = a.clone(); matrix output = a.clone();
@ -142,7 +141,7 @@ matrix subtract(matrix a, matrix b) {
return output; return output;
}; };
matrix dot(matrix m, float s) { matrix dot(matrix m, float s) {
float size = m.get_size(); float size = m.get_size();
matrix output = m.clone(); matrix output = m.clone();