import numpy as np
from typing import Any, Tuple
from numpy.typing import ArrayLike


def inv(val):
    """Return the inverse of a scalar or square matrix.

    * Scalars (including 0-d arrays) get their reciprocal, except that ``0``
      maps to ``0`` instead of raising.
    * A single-element array with two or more dimensions (e.g. a 1x1 matrix)
      gets its element-wise reciprocal with the shape preserved, so the
      result can still be used in ``@`` products.
    * A length-1 one-dimensional sequence gets the reciprocal of its only
      element.
    * Anything else is passed to :func:`numpy.linalg.inv`.

    Args:
        val: Scalar, sequence or array to invert.

    Returns:
        The inverse, in the form described above.

    Raises:
        numpy.linalg.LinAlgError: If a larger matrix is singular or not square.
    """
    if np.ndim(val) == 0:
        # np.isscalar() is False for 0-d arrays; np.ndim() covers both cases.
        if val == 0:
            return 0
        return 1 / val
    if np.ndim(val) >= 2 and np.size(val) == 1:
        # Keep the (1, 1) shape: indexing with val[0] would drop a dimension
        # and turn the Kalman gain into a 1-D array, breaking later matmuls.
        return 1 / np.asanyarray(val)
    if len(val) == 1:
        return 1 / val[0]
    return np.linalg.inv(val)


class KalmanFilter:
    """Linear Kalman filter whose model matrices may depend on the time step.

    Every model matrix can be given either as an array or as a callable that
    takes ``dt`` and returns the array. Callables are evaluated on each
    ``predict`` / ``correct`` call.
    """

    def __init__(self, A: ArrayLike, B: ArrayLike, Q: ArrayLike, H: ArrayLike, R: ArrayLike) -> None:
        """Simple Kalman Filter implementation.

        Args:
            A (ArrayLike): State transition matrix, or a callable of ``dt`` returning it.
            B (ArrayLike): Control matrix, or a callable of ``dt`` returning it.
            Q (ArrayLike): Process noise matrix, or a callable of ``dt`` returning it.
            H (ArrayLike): Observation matrix, or a callable of ``dt`` returning it.
            R (ArrayLike): Measurement noise matrix, or a callable of ``dt`` returning it.
        """
        self.__A = A
        self.__B = B
        self.__Q = Q
        self.__H = H
        self.__R = R

    def predict(self, dt: float, x: ArrayLike, err: ArrayLike, u: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """Perform the prediction step of the Kalman Filter.

        Computes ``x' = A x + B u`` and ``P' = A P A^T + Q``.

        Args:
            dt (float): The change in time since the last prediction.
            x (ArrayLike): The current state.
            err (ArrayLike): The current error covariance.
            u (ArrayLike): The control input.

        Returns:
            Tuple[ArrayLike, ArrayLike]: The predicted state and its error covariance.
        """
        # Resolve time-dependent model matrices for this step.
        A = self.__A(dt) if callable(self.__A) else self.__A
        B = self.__B(dt) if callable(self.__B) else self.__B
        Q = self.__Q(dt) if callable(self.__Q) else self.__Q

        x = A @ x + B @ u
        err = A @ err @ A.T + Q
        return x, err

    def correct(self, dt: float, x: ArrayLike, err: ArrayLike, z: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """Perform the correction step of the Kalman Filter.

        Computes the gain ``K = P H^T (H P H^T + R)^-1``, the corrected state
        ``x + K (z - H x)`` and the covariance ``(I - K H) P``.

        Args:
            dt (float): The change in time since the last correction.
            x (ArrayLike): The predicted state. A 1-D vector, a row or a
                column holding the n state values are all accepted.
            err (ArrayLike): The (n, n) error covariance after prediction.
            z (ArrayLike): The new measurements (any shape holding m values).

        Returns:
            Tuple[ArrayLike, ArrayLike]: The corrected state, squeezed to a
            1-D array, and the updated error covariance.
        """
        # Resolve time-dependent model matrices for this step.
        H = self.__H(dt) if callable(self.__H) else self.__H
        R = self.__R(dt) if callable(self.__R) else self.__R

        # The covariance is always (n, n). len(x) would be 1 for a (1, n) row
        # or np.matrix state, and the identity below would then broadcast.
        n = np.shape(err)[0]

        # Compute the Kalman gain.
        K = err @ H.T @ inv(H @ err @ H.T + R)

        # Compute the corrected state on flat vectors, so row, column and
        # 1-D states/measurements never broadcast into an (n, n) result.
        x_vec = np.ravel(np.asarray(x))
        innovation = np.ravel(np.asarray(z)) - np.ravel(np.asarray(H @ x_vec))
        x_vec = x_vec + np.ravel(np.asarray(K @ innovation))

        # Compute the error after correction.
        err = (np.identity(n) - K @ H) @ err
        return np.squeeze(x_vec), err