import pandas as pd
import numpy as np

from typing import Literal, List
from numpy.typing import NDArray

from spatz.simulations.data_source import DataSource


class CSVSource(DataSource):
    """A data source that serves (interpolated) values read from a csv file."""

    def __init__(self, path: str, time_col: str, interpolation: Literal['linear'] = 'linear') -> None:
        """A data source that extracts all its data from a csv file.

        Args:
            path (str): Path of the csv file; it is loaded with ``pd.read_csv``.
            time_col (str): The name of the column that contains time data.
            interpolation (Literal['linear']): How values between two samples
                are computed. Only 'linear' is supported.
        """
        super().__init__()

        self._df = pd.read_csv(path)
        self._time_col = time_col
        self._idx = 0
        self._interpolation = interpolation

    def get_length(self) -> float:
        """Return the largest value found in the time column."""
        return float(self._df[self._time_col].max())

    def _on_reset(self):
        # Restore the row cursor to the value it is given in __init__ so that
        # no stale index survives a reset.
        self._idx = 0

    def _get_closest_idx(self, t: float) -> int:
        """Gets an index _idx_ for the dataframe _df_ such that the values at the given time _t_
        are somewhere between _idx_ and _idx+1_.

        The result is clamped so that both _idx_ and _idx+1_ are valid rows
        whenever the dataframe has at least two rows. Times before the first
        or after the last sample therefore map to the first or last interval.

        NOTE: relies on the default integer index that ``pd.read_csv`` creates
        (row label == row position).

        Args:
            t (float): The requested time.

        Returns:
            int: The computed index.
        """
        times = self._df[self._time_col]

        # Row whose timestamp is nearest to t ...
        idx = int((times - t).abs().idxmin())
        # ... moved one row back if that timestamp lies after t, so that t
        # falls into the interval [idx, idx + 1].
        if times.loc[idx] > t:
            idx -= 1

        # Keep idx (and idx + 1) inside the dataframe.
        highest_start = max(len(self._df) - 2, 0)
        return min(max(idx, 0), highest_start)

    def _on_step(self, _: float):
        self._idx = self._get_closest_idx(self.get_time())

    def fetch_value(self, name: str, t: float | None = None, custom_interpolation=None) -> float:
        """Get a specific value from the dataframe.

        Args:
            name (str): The name of the value to fetch.
            t (float): Allows specification of a different time instead of the current time. None for current time.
            custom_interpolation: Optional callable ``f(a, b, alpha)`` that
                receives the two neighbouring samples and the interpolation
                weight ``alpha`` in [0, 1] and returns the interpolated value.
                If None, the samples are blended linearly.

        Returns:
            float: Returns the requested value.

        Raises:
            ValueError: If the configured interpolation mode is not supported.
        """
        if self._interpolation != 'linear':
            raise ValueError(f"Unsupported interpolation mode: {self._interpolation!r}")

        # The weight must be computed for the same time the index belongs to.
        if t is None:
            idx = self._idx
            query_time = self.get_time()
        else:
            idx = self._get_closest_idx(t)
            query_time = t

        # With a single row there is no neighbour to interpolate with.
        if idx + 1 >= len(self._df):
            return self._df.at[idx, name]

        t_min = self._df.at[idx, self._time_col]
        t_max = self._df.at[idx + 1, self._time_col]

        # Sometimes no time passes in-between two samples.
        if t_max == t_min:
            return self._df.at[idx, name]

        # Compute the weight for interpolation. Clamping holds the boundary
        # sample for times outside the recorded range instead of extrapolating.
        alpha = (query_time - t_min) / (t_max - t_min)
        alpha = min(max(alpha, 0.0), 1.0)

        a = self._df.at[idx, name]
        b = self._df.at[idx + 1, name]

        if custom_interpolation is not None:
            return custom_interpolation(a, b, alpha)

        # Interpolate linearly between the two data points.
        return (1 - alpha) * a + alpha * b

    def fetch_values(self, names: List[str], t: float | None = None, custom_interpolation=None) -> NDArray:
        """Get specific values from the dataframe.

        Args:
            names (List[str]): Names of the values to get.
            t (float): Allows specification of a different time instead of the current time. None for current time.
            custom_interpolation: Optional callable forwarded to ``fetch_value``
                for every requested name.

        Returns:
            np.array: Returns a numpy array containing the requested values in the same order as in the input list.
        """
        return np.asarray([self.fetch_value(name, t, custom_interpolation) for name in names])