# SPATZ/spatz/models/kalman.py
# Original listing metadata: 99 lines, 2.7 KiB, Python.

import numpy as np
from typing import Any, Tuple
from numpy.typing import ArrayLike
def inv(val):
    """Return the inverse of a scalar or a square matrix.

    Behavior depends on the rank of ``val``:

    * rank 0 (Python/NumPy scalar or 0-d array): plain reciprocal. A zero
      scalar yields ``0`` instead of raising.
    * rank 1 with a single element: reciprocal of that element, as a scalar.
    * shape ``(1, 1)``: elementwise reciprocal, keeping the ``(1, 1)`` shape
      (and the array subclass, e.g. ``np.matrix``) so that downstream matrix
      products still line up.
    * anything else: ``np.linalg.inv`` (which raises for non-square input).

    Args:
        val: Scalar, 1-element sequence, or square matrix to invert.

    Returns:
        The inverse, with the shape described above.
    """
    # np.ndim (rather than np.isscalar) also treats 0-d arrays as scalars,
    # which would otherwise fail in len() below.
    if np.ndim(val) == 0:
        # NOTE(review): returning 0 for a zero scalar looks like a deliberate
        # pseudo-inverse style guard rather than a division error — confirm.
        if val == 0:
            return 0
        return 1 / val
    if np.ndim(val) == 1 and len(val) == 1:
        return 1 / val[0]
    if np.shape(val) == (1, 1):
        # Cheap 1x1 inverse that preserves the 2-D shape; indexing val[0]
        # here would drop an axis for plain ndarrays.
        return 1 / np.asanyarray(val)
    return np.linalg.inv(val)
class KalmanFilter:
    """Simple linear Kalman filter.

    Every model matrix (A, B, Q, H, R) may be supplied either as a constant
    array or as a callable taking the time step ``dt`` and returning the
    array. Callables are re-evaluated on every call to ``predict`` /
    ``correct``, which allows time-step dependent models.
    """

    def __init__(self, A: ArrayLike, B: ArrayLike, Q: ArrayLike, H: ArrayLike, R: ArrayLike) -> None:
        """Simple Kalman Filter implementation.

        Args:
            A (ArrayLike): State transition matrix, or a callable ``A(dt)``.
            B (ArrayLike): Control matrix, or a callable ``B(dt)``.
            Q (ArrayLike): Process noise matrix, or a callable ``Q(dt)``.
            H (ArrayLike): Observation matrix, or a callable ``H(dt)``.
            R (ArrayLike): Measurement noise matrix, or a callable ``R(dt)``.
        """
        self.__A = A
        self.__B = B
        self.__Q = Q
        self.__H = H
        self.__R = R

    @staticmethod
    def _resolve(param: Any, dt: float) -> Any:
        """Return ``param(dt)`` if ``param`` is callable, otherwise ``param``."""
        return param(dt) if callable(param) else param

    def predict(self, dt: float, x: ArrayLike, err: ArrayLike, u: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """Perform the prediction step of the Kalman Filter.

        Computes ``x' = A x + B u`` and ``err' = A err A^T + Q``.

        Args:
            dt (float): The change in time since the last prediction.
            x (ArrayLike): The current state.
            err (ArrayLike): The current error covariance.
            u (ArrayLike): The control input.

        Returns:
            Tuple[ArrayLike, ArrayLike]: The predicted state and the
            corresponding error covariance.
        """
        A = self._resolve(self.__A, dt)
        B = self._resolve(self.__B, dt)
        Q = self._resolve(self.__Q, dt)
        x = A @ x + B @ u
        err = A @ err @ A.T + Q
        return x, err

    def correct(self, dt: float, x: ArrayLike, err: ArrayLike, z: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """Perform the correction step of the Kalman Filter.

        Computes the gain ``K = err H^T (H err H^T + R)^-1``, the corrected
        state ``x + K (z - H x)`` and the updated covariance
        ``(I - K H) err``.

        Args:
            dt (float): The change in time since the last correction.
            x (ArrayLike): The predicted state.
            err (ArrayLike): The error covariance after prediction; must be
                square (n x n) for n state variables.
            z (ArrayLike): The new measurements.

        Returns:
            Tuple[ArrayLike, ArrayLike]: The corrected state, as a plain
            ndarray with singleton axes removed, and the updated error
            covariance.
        """
        H = self._resolve(self.__H, dt)
        R = self._resolve(self.__R, dt)
        # Size the identity from the (square) covariance rather than len(x),
        # which is wrong when the state arrives as a 2-D row.
        n = np.shape(err)[0]
        # Compute the Kalman gain.
        K = err @ H.T @ inv(H @ err @ H.T + R)
        # Compute the corrected state.
        # NOTE(review): the transposes appear to accommodate 2-D residuals
        # (e.g. np.matrix inputs); they are no-ops for 1-D arrays.
        x = x + (K @ (z - H @ x).T).T
        # Compute the error covariance after correction.
        err = (np.identity(n) - K @ H) @ err
        # asarray drops any ndarray subclass; squeeze removes singleton axes.
        return np.squeeze(np.asarray(x)), err