import numpy as np
from numpy.typing import ArrayLike

from spatz.metrics import Metric


class MSEMetric(Metric):
    """Mean squared error metric.

    Each call to ``_update`` adds the element-wise term ``0.5 * (x - y)**2``
    to ``self._score``. The first call initialises the score; later calls
    accumulate into it.

    NOTE(review): the factor of 0.5 makes this a *half* squared error, which
    differs from the textbook MSE definition — confirm this is intended.
    This class only keeps a running sum; any division by the number of
    updates is not done here and presumably happens in the ``Metric`` base
    class — verify there.
    """

    def __init__(self) -> None:
        """Mean squared error."""
        super().__init__()

    def _update(self, x: ArrayLike, y: ArrayLike) -> None:
        """Add the half squared error between ``x`` and ``y`` to the score.

        Args:
            x: First array-like operand.
            y: Second array-like operand; must broadcast against ``x``.
        """
        # ArrayLike includes plain lists/tuples, for which ``-`` is not
        # element-wise subtraction (lists raise TypeError). Convert so the
        # arithmetic is always NumPy's; np.asarray does not copy ndarrays.
        error = 0.5 * (np.asarray(x) - np.asarray(y)) ** 2
        if self._score is None:
            self._score = error
        else:
            self._score += error