# Mirror of https://git.intern.spaceteamaachen.de/ALPAKA/SPATZ.git
# (synced 2025-06-10 01:55:59 +00:00; 91 lines, 3.1 KiB, Python)
import pandas as pd
|
|
import numpy as np
|
|
|
|
from typing import Literal, List
|
|
from numpy.typing import NDArray
|
|
|
|
from spatz.simulations.data_source import DataSource
|
|
|
|
|
|
|
|
class CSVSource(DataSource):
    def __init__(self, path: str, time_col: str, interpolation: Literal['linear'] = 'linear') -> None:
        """A data source that extracts all its data from a csv file.

        Args:
            path (str): Path of the csv file to load with pandas.
            time_col (str): The name of the column that contains time data. Rows are expected
                to be sorted by this column in non-decreasing order.
            interpolation (Literal['linear']): How values between two samples are computed.
                Only 'linear' is supported.

        Raises:
            ValueError: If an unsupported interpolation mode is requested.
        """
        super().__init__()

        # Fail fast here; otherwise fetch_value would silently return None later on.
        if interpolation != 'linear':
            raise ValueError(f"Unsupported interpolation mode: {interpolation!r}")

        self._df = pd.read_csv(path)
        self._time_col = time_col
        # Row position of the sample at (or just before) the current time; updated in _on_step.
        self._idx = 0
        self._interpolation = interpolation

    def get_length(self) -> float:
        """Return the largest value found in the time column."""
        return float(self._df[self._time_col].max())

    def _on_reset(self):
        # Restore the cached row position to its initial value so a stale index from a
        # previous run is not used before the next step recomputes it.
        self._idx = 0

    def _get_closest_idx(self, t: float) -> int:
        """Gets an index _idx_ for the dataframe _df_ such that the values at the given time _t_ are
        somewhere between _idx_ and _idx+1_.

        The result is the position of the last row whose time is <= t, clamped to
        [0, len(df) - 2] so that idx + 1 is a valid row whenever there are at least two rows.
        Requires the time column to be sorted in non-decreasing order.

        Args:
            t (float): The requested time.

        Returns:
            int: The computed (positional) index.
        """
        times = self._df[self._time_col]

        # Binary search. side='right' skips past duplicate timestamps equal to t, so inside the
        # data range the following sample is always strictly later than the returned one.
        idx = int(times.searchsorted(t, side='right')) - 1

        # Requests before the first / after the last sample map onto the first / last interval.
        return max(0, min(idx, len(times) - 2))

    def _on_step(self, _: float):
        # Cache the interval for the current time so fetch_value(t=None) avoids a search.
        self._idx = self._get_closest_idx(self.get_time())

    def fetch_value(self, name: str, t: float | None = None, custom_interpolation=None) -> float:
        """Get a specific value from the dataframe.

        Outside the sampled time range the first / last sample is held rather than extrapolated.

        Args:
            name (str): The name of the value to fetch.
            t (float): Allows specification of a different time instead of the current time. None for current time.
            custom_interpolation: Optional callable ``f(a, b, alpha)`` used instead of linear
                interpolation between the neighbouring samples ``a`` and ``b`` with weight
                ``alpha`` in [0, 1].

        Returns:
            float: Returns the requested value.
        """
        # Interpolate at the time that was actually requested, not always at the current time.
        t_query = self.get_time() if t is None else t
        idx = self._idx if t is None else self._get_closest_idx(t)

        times = self._df[self._time_col]
        values = self._df[name]

        # A single sample cannot be interpolated; just hold it.
        if idx + 1 >= len(values):
            return values.iat[idx]

        t_min, t_max = times.iat[idx], times.iat[idx + 1]
        a, b = values.iat[idx], values.iat[idx + 1]

        # Sometimes no time passes in-between two samples (only possible at the clamped ends).
        if t_max == t_min:
            return b if t_query >= t_max else a

        # Compute the weight for interpolation; clamping holds the boundary samples.
        alpha = (t_query - t_min) / (t_max - t_min)
        alpha = min(max(alpha, 0.0), 1.0)

        if custom_interpolation is not None:
            return custom_interpolation(a, b, alpha)

        # Interpolate linearly between the two data points.
        return (1 - alpha) * a + alpha * b

    def fetch_values(self, names: List[str], t: float | None = None, custom_interpolation=None) -> NDArray:
        """Get specific values from the dataframe.

        Args:
            names (List[str]): Names of the values to get.
            t (float): Allows specification of a different time instead of the current time. None for current time.
            custom_interpolation: Optional callable forwarded to fetch_value for every name.

        Returns:
            np.array: Returns a numpy array containing the requested values in the same order as in the input list.
        """
        return np.asarray([self.fetch_value(name, t, custom_interpolation) for name in names])