mirror of
https://git.intern.spaceteamaachen.de/ALPAKA/SPATZ.git
synced 2025-06-10 18:15:59 +00:00
16 lines
381 B
Python
16 lines
381 B
Python
import numpy as np
from numpy.typing import ArrayLike
|
|
from spatz.metrics import Metric
|
|
|
|
|
|
class MSEMetric(Metric):
    """Squared-error metric accumulated over successive updates.

    Each call to ``_update(x, y)`` adds ``0.5 * (x - y) ** 2`` (element-wise)
    to ``self._score``. Note the factor 0.5: the accumulated value is *half*
    the squared error. This class never divides by the number of updates;
    NOTE(review): presumably any averaging implied by "mean" is done by the
    ``Metric`` base class -- confirm against ``spatz.metrics.Metric``.
    """

    def __init__(self) -> None:
        """Mean squared error (see the class docstring for the exact formula)."""
        super().__init__()

    def _update(self, x: ArrayLike, y: ArrayLike) -> None:
        """Add the half squared error between ``x`` and ``y`` to the score.

        Args:
            x: Real-valued data; anything ``numpy.asarray`` accepts
                (scalar, list, tuple, ndarray, ...).
            y: Real-valued reference data; must broadcast against ``x``.

        ``self._score`` being ``None`` marks "no update yet" (assumed to be
        the initial value set by the base class -- TODO confirm).
        """
        # Converting to float makes plain lists/tuples work (``list - list``
        # raises TypeError) and prevents wrap-around when both inputs are
        # unsigned integers (e.g. uint8: 1 - 2 == 255).
        residual = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
        half_squared_error = 0.5 * residual**2

        if self._score is None:
            self._score = half_squared_error
        else:
            # Out-of-place add: an in-place ``+=`` raises if a later update
            # broadcasts to a larger shape than the current accumulator.
            self._score = self._score + half_squared_error
|