SPATZ/spatz/logger.py

73 lines
2.0 KiB
Python

import numpy as np
import pandas as pd
from abc import abstractmethod
from typing import Any, Tuple, List
from numpy.typing import ArrayLike
from spatz.simulations.advanceable import Advanceable
class Logger(Advanceable):
    """Common interface for loggers that take part in the simulation loop.

    Concrete loggers implement :meth:`write`.
    NOTE(review): ``@abstractmethod`` is only enforced at instantiation time
    if ``Advanceable`` uses ``ABCMeta`` -- not visible from this file.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def write(self, attrib: str | List[str], value: Any | List[Any] | List[ArrayLike], domain: str = 'all'):
        """Record ``value`` under the name ``attrib`` within ``domain``.

        Args:
            attrib (str | List[str]): Name of the value, or a list of names.
            value (Any | List[Any] | List[ArrayLike]): The value, or a list of
                values when ``attrib`` is a list.
            domain (str, optional): The domain the value belongs to.
                Defaults to 'all'.
        """
class EmptyLogger(Logger):
    """Logger whose hooks are all no-ops: nothing written to it is stored."""

    def __init__(self) -> None:
        super().__init__()

    def _on_step(self, _: float):
        """Do nothing on a simulation step; there is no state to advance."""

    def _on_reset(self):
        """Do nothing on reset; there is no state to clear."""

    def write(self, attrib: str | List[str], value: Any | List[Any] | List[ArrayLike], domain: str = 'all'):
        """Discard the given value(s).

        Args:
            attrib (str | List[str]): Name(s) of the value(s); ignored.
            value (Any | List[Any] | List[ArrayLike]): The value(s); ignored.
            domain (str, optional): The domain the value belongs to; ignored.
                Defaults to 'all'.
        """

    def get_dataframe(self) -> pd.DataFrame:
        """Return an empty data frame, since this logger never records data.

        Returns:
            pd.DataFrame: A new frame without rows or columns, so the result
            honours the annotated return type instead of being ``None``.
        """
        return pd.DataFrame()
class CSVLogger(Logger):
    """Logger that collects values in a pandas data frame, one row per step.

    Each row holds the time of its step in the ``'time'`` column; logged
    values are stored in columns named ``'<domain>/<attrib>'``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Row cursor. -1 means "no step logged yet": the first step moves it
        # to 0 and overwrites the time of the placeholder row that
        # `_on_reset` creates.
        # NOTE(review): the frame itself is only created in `_on_reset`;
        # presumably the base class triggers a reset during construction --
        # confirm against `Advanceable`.
        self.__idx = -1

    def _on_step(self, _: float):
        """Open the row for the current step and stamp it with the time."""
        # Concatenating with `ignore_index=True` renumbers the rows 0..n-1
        # before the next row is addressed by label.
        self.__df = pd.concat([pd.DataFrame(), self.__df], ignore_index=True, copy=False)
        self.__idx += 1
        self.__df.at[self.__idx, 'time'] = self.get_time()

    def _on_reset(self):
        """Drop all logged data and start over with a single time-stamped row."""
        self.__df = pd.DataFrame.from_dict({'time': [self.get_time()]}).astype(np.float64)
        # Rewind the cursor together with the frame; otherwise the rows
        # written after a reset would continue at the old position and the
        # placeholder row above would never be overwritten.
        self.__idx = -1

    def write(self, attrib: str | List[str], value: Any | List[Any] | List[ArrayLike], domain: str = 'all'):
        """Writes one or several values to the row of the current step.

        Args:
            attrib (str | List[str]): The name of the value to log, or a list
                of names.
            value (Any | List[Any] | List[ArrayLike]): The value to log, or
                the values matching the names in ``attrib``.
            domain (str, optional): The domain the value belongs to.
                Defaults to 'all'.
        """
        if not isinstance(attrib, str):
            # Names and values are paired with zip, so surplus entries on
            # either side are ignored.
            for attr, val in zip(attrib, value):
                self.write(attr, val, domain=domain)
            return

        name = f'{domain}/{attrib}'
        if name not in self.__df.columns:
            # First use of this column: fill all existing rows with NA.
            self.__df[name] = pd.Series([pd.NA] * len(self.__df))
        self.__df.at[self.__idx, name] = value

    def get_dataframe(self) -> pd.DataFrame:
        """Return the data frame holding everything logged since the last reset.

        Returns:
            pd.DataFrame: The internal frame itself (not a copy).
        """
        return self.__df