From bcdfe329cc1a2df506f9c7a89908641c4240144f Mon Sep 17 00:00:00 2001 From: dario Date: Sat, 30 Dec 2023 12:54:34 +0100 Subject: [PATCH] Updated Kalman Filter base class --- spatz/models/kalman.py | 94 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 84 insertions(+), 10 deletions(-) diff --git a/spatz/models/kalman.py b/spatz/models/kalman.py index f109250..31674bb 100644 --- a/spatz/models/kalman.py +++ b/spatz/models/kalman.py @@ -1,20 +1,94 @@ -from typing import Any +import numpy as np + +from typing import Any, Tuple +from numpy.typing import ArrayLike -def A(dt: float): - return [[dt, 0], [0, 1]] - +def inv(val): + if np.isscalar(val): + return 1 / val + + if len(val) == 1: + return 1 / val[0] + + return np.linalg.inv(val) class KalmanFilter: - def __init__(self, A, B) -> None: + def __init__(self, A: ArrayLike, B: ArrayLike, Q: ArrayLike, H: ArrayLike, R: ArrayLike) -> None: + """Simple Kalman Filter implementation. + + Args: + A (ArrayLike): State transition matrix. + B (ArrayLike): Controll matrix. + Q (ArrayLike): Process noise matrix. + H (ArrayLike): Observation matrix. + R (ArrayLike): Measurement noise matrix. + """ self.__A = A self.__B = B + self.__Q = Q + self.__H = H + self.__R = R - def predict(self, dt: float) -> None: - self.__A(dt) + def predict(self, dt: float, x: ArrayLike, err: ArrayLike, u: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: + """Perform the prediction step of the Kalman Filter. - pass + Args: + dt (float): The change in time since the last prediction. + x (ArrayLike): The current state. + err (ArrayLike): The current error. + u (ArrayLike): The control input. - def correct(self) -> None: - pass \ No newline at end of file + Returns: + Tuple[ArrayLike, ArrayLike]: Returns the prediction state and the corresponding error. + """ + A = self.__A + B = self.__B + Q = self.__Q + + if hasattr(self.__A, '__call__'): + A = self.__A(dt) + + if hasattr(self.__B, '__call__'): + B = self.__B(dt) + + if hasattr(self.__Q, '__call__'): + Q = self.__Q(dt) + + x = A @ x + B @ u + err = A @ err @ A.T + Q + + return x, err + + def correct(self, dt: float, x: ArrayLike, err: ArrayLike, z: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: + """Perform the correction step of the Kalman Filter. + + Args: + dt (float): The change in time since the last correction. + x (ArrayLike): The predicted state + err (ArrayLike): The error after prediction. + z (ArrayLike): The new measurements. + + Returns: + Tuple[ArrayLike, ArrayLike]: Returns a corrected state and updated errors. + """ + H = self.__H + R = self.__R + + if hasattr(self.__H, '__call__'): + H = self.__H(dt) + + if hasattr(self.__R, '__call__'): + R = self.__R(dt) + + # Compute the Kalman gain. + K = err @ H.T @ inv(H @ err @ H.T + R) + + # Compute the corrected state. + x = x + K @ (z - H @ x) + + # Compute the error after correction. + err = (np.identity('TODO') - K @ H) @ err + + return x, err \ No newline at end of file