Extend correction step

This commit is contained in:
Milo Priegnitz
2024-06-12 15:03:38 +02:00
parent 5cf23f802e
commit c5ea7dc596
6 changed files with 185 additions and 13 deletions

View File

@@ -10,17 +10,20 @@ namespace sta
namespace math
{
DynamicKalmanFilter::DynamicKalmanFilter(matrix A, matrix T, matrix B, matrix H, matrix Q, matrix R) : A_{A},T_{T}, B_{B}, H_{H}, Q_{Q}, R_{R}, n_{A.get_cols()}
DynamicKalmanFilter::DynamicKalmanFilter(matrix A, matrix TA, matrix B, matrix H, matrix Q, matrix TQ, matrix R) : A_{A},TA_{TA}, B_{B}, H_{H}, Q_{Q}, TQ_{TQ}, R_{R}, n_{A.get_cols()}
{
STA_ASSERT_MSG(A.get_rows() == B.get_rows(), "#rows mismatch: A, B!");
STA_ASSERT_MSG(A.get_cols() == H.get_cols(), "#cols mismatch: A, H!");
STA_ASSERT_MSG(A.get_rows() == T.get_rows(), "#rows mismatch: A, T!");
STA_ASSERT_MSG(A.get_cols() == T.get_cols(), "#cols mismatch: A, T!");
STA_ASSERT_MSG(A.get_rows() == TA.get_rows(), "#rows mismatch: A, TA!");
STA_ASSERT_MSG(A.get_cols() == TA.get_cols(), "#cols mismatch: A, TA!");
STA_ASSERT_MSG(A.get_rows() == Q.get_rows(), "#rows mismatch: A, Q!");
STA_ASSERT_MSG(Q.get_cols() == Q.get_rows(), "Q not square!");
STA_ASSERT_MSG(A.get_cols() == A.get_rows(), "A not square!");
STA_ASSERT_MSG(H.get_rows() == R.get_rows(), "#rows mismatch: H, R");
STA_ASSERT_MSG(R.get_rows() == R.get_cols(), "#R not square");
STA_ASSERT_MSG(Q.get_rows() == TQ.get_rows(), "#rows mismatch: Q, TQ!");
STA_ASSERT_MSG(Q.get_cols() == TQ.get_cols(), "#cols mismatch: Q, TQ!");
identity_ = matrix::eye(n_);
}
@@ -32,14 +35,19 @@ DynamicKalmanFilter::~DynamicKalmanFilter()
KalmanState DynamicKalmanFilter::predict(float dt, KalmanState state, matrix u)
{
//Build the state transition matrix
matrix F = matrix::zeros(9, 9);
for(int i =0; i < n_*n_;i++){
F.set(i, A_[i] * std::pow(dt, T_[i]));
matrix F = matrix::zeros(A_.get_rows(), A_.get_rows());
for(int i =0; i < A_.get_rows()* A_.get_rows();i++){
F.set(i, A_[i] * std::pow(dt, TA_[i]));
}
matrix Q = matrix::zeros(Q_.get_rows(), Q_.get_cols());
for(int i =0; i < Q_.get_rows()* Q_.get_cols();i++){
Q.set(i, Q_[i] * std::pow(dt, TQ_[i]));
}
// Update the state based on the system dynamics
state.x = F * state.x + F * u;
// Update the error covariance matrix
state.error = F * state.error * F.T() + Q_;
state.error = F * state.error * F.T() + Q;
return state;
}

View File

@@ -50,6 +50,20 @@ KalmanState KalmanFilter::correct(KalmanState state, matrix z)
return state;
}
KalmanState KalmanFilter::correct(KalmanState state, matrix z, matrix H, matrix R)
{
// Correct step implementation
// Calculate the Kalman gain
matrix K = state.error * H.T() * linalg::inv(H * state.error * H.T() + R);
K.show_serial();
// Update the state based on the measurement
state.x = state.x + K * (z - H * state.x); //TODO check transpose
state.x.show_serial();
// Update the error covariance matrix
state.error = (identity_ - K * H) * state.error;
return state;
}
} // namespace math
} // namespace sta