Added a few metrics for measuring performance

This commit is contained in:
dario 2023-12-14 23:29:58 +01:00
parent 33e5eba3a6
commit 4a93fd4ee6
5 changed files with 49 additions and 7 deletions

20
spatz/metrics/max_dev.py Normal file
View File

@ -0,0 +1,20 @@
from numpy.typing import ArrayLike
from spatz.metrics import Metric
class MaxAbsDeviation(Metric):
def __init__(self) -> None:
"""A metric tracking the maximum absolute deviation from the true value."""
super().__init__()
def _update(self, x: ArrayLike, y: ArrayLike):
self._score = max(self._score, abs(x - y))
class MaxRelDeviation(Metric):
def __init__(self) -> None:
"""A metric tracking the maximum deviation from the true value in percent."""
super().__init__()
def _update(self, x: ArrayLike, y: ArrayLike):
self._score = max(self._score, (x - y) / y)

View File

@ -1,14 +1,19 @@
from abc import abstractmethod
from numpy.typing import ArrayLike
from typing import Any from typing import Any
class Metric: class Metric:
def __init__(self) -> None: def __init__(self) -> None:
self.__sum = 0 self._score = None
def get_score(): def get_score(self):
pass return self._score
def __call__(self, *args) -> Any: @abstractmethod
self.__sum += abs(x - y) def _update(self, x: ArrayLike, y: ArrayLike):
raise NotImplementedError()
def __call__(self, x: ArrayLike, y: ArrayLike) -> Any:
self._update(x, y)

15
spatz/metrics/mse.py Normal file
View File

@ -0,0 +1,15 @@
from numpy.typing import ArrayLike
from spatz.metrics import Metric
class MSEMetric(Metric):
def __init__(self) -> None:
"""Mean squared error
"""
super().__init__()
def _update(self, x: ArrayLike, y: ArrayLike):
if self._score is None:
self._score = 0.5 * (x - y)**2
else:
self._score += 0.5 * (x - y)**2

View File

@ -14,4 +14,6 @@ class Erinome_I(GPS):
def _sensor_specific_effects(self, x: ArrayLike) -> ArrayLike: def _sensor_specific_effects(self, x: ArrayLike) -> ArrayLike:
# TODO: What's the GPS module's behavior? # TODO: What's the GPS module's behavior?
# TODO: Only return measurements every second
return x return x

View File

@ -16,7 +16,7 @@ class GPS(Sensor):
"""GPS Module which provides the following information: """GPS Module which provides the following information:
- Longitude (in °) - Longitude (in °)
- Latitiude (in °) - Latitiude (in °)
- Altitude (in °) - Altitude (in m)
""" """
super().__init__(dataset, logger, transforms) super().__init__(dataset, logger, transforms)