mirror of
https://git.intern.spaceteamaachen.de/ALPAKA/SPATZ.git
synced 2025-06-10 01:55:59 +00:00
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
|
|
from typing import Any, Tuple, List
|
|
from numpy.typing import ArrayLike
|
|
|
|
from spatz.simulations.advanceable import Advanceable
|
|
|
|
|
|
class Logger(Advanceable):
    """Logger that records named values into one DataFrame row per step.

    Every row has a float64 ``time`` column stamped from ``get_time()``.
    All other columns are named ``"<domain>/<attrib>"`` and are created
    lazily, object-typed, the first time ``write`` targets them.
    """

    def __init__(self) -> None:
        super().__init__()

        # Row label that ``write`` targets. -1 means "no step taken yet"; the
        # first ``_on_step`` advances it to 0, i.e. onto the single row that
        # ``_on_reset`` creates.
        self.__idx = -1

    def _on_step(self, _: float):
        """Advance to a new row and stamp it with the current time.

        Args:
            _ (float): Argument supplied by the stepping machinery; unused here.
        """
        self.__idx += 1

        # Assigning through ``.at`` to a row label that does not exist yet
        # appends that row (setting with enlargement), so no explicit concat
        # or copy of the whole frame is needed here.
        self.__df.at[self.__idx, 'time'] = self.get_time()

    def _on_reset(self):
        """Discard all logged data and start again with a single time row."""
        self.__df = pd.DataFrame.from_dict({'time': [self.get_time()]}).astype(np.float64)

        # Rewind the row cursor too, so a reset logger behaves like a freshly
        # constructed one instead of continuing at stale row labels.
        self.__idx = -1

    def write(self, attrib: str | List[str], value: Any | List[Any] | List[ArrayLike], domain: str = 'all'):
        """Writes one or several values into the current row of the logger.

        Args:
            attrib (str | List[str]): The name of the value to log, or a
                sequence of names to log several values at once.
            value (Any | List[Any] | List[ArrayLike]): The value to log, or a
                sequence of values paired element-wise with ``attrib``.
            domain (str, optional): The domain the value belongs to; used as
                the column-name prefix ``"<domain>/<attrib>"``. Defaults to 'all'.

        Raises:
            ValueError: If ``attrib`` is a sequence whose length differs from
                that of ``value``. Nothing is written in that case.
        """
        if not isinstance(attrib, str):
            # strict=True turns a length mismatch into an error instead of
            # silently dropping the surplus names/values; pairing everything
            # first means no partial write happens before that error.
            pairs = list(zip(attrib, value, strict=True))

            for attr, val in pairs:
                self.write(attr, val, domain=domain)

            return

        name = f'{domain}/{attrib}'

        if name not in self.__df.columns:
            # Fill the new column with pd.NA for all existing rows. Object
            # dtype lets a cell hold arbitrary values (scalars or arrays), and
            # building it on the frame's own index keeps it aligned even when
            # the row labels are not exactly 0..len-1.
            self.__df[name] = pd.Series(pd.NA, index=self.__df.index, dtype=object)

        self.__df.at[self.__idx, name] = value

    def get_dataframe(self) -> pd.DataFrame:
        """Returns the logged data.

        Returns:
            pd.DataFrame: The logger's internal frame itself (not a copy), so
            mutating the result mutates the logger's data.
        """
        return self.__df