# SPATZ/spatz/simulation.py
#
# (File-viewer metadata: 140 lines, 4.3 KiB, Python.)

from typing import List
from numpy.random import normal
from tqdm import tqdm
from spatz.simulations.advanceable import Advanceable
from spatz.simulations.data_source import DataSource
from spatz.dataset import Dataset
from spatz.logger import Logger
from spatz.sensors import Sensor
from spatz.dataset import Dataset, Phase
from spatz.logger import Logger, EmptyLogger
from spatz.sensors import Sensor
from spatz.observer import Observer
class UniformTimeSteps:
    """Time-step generator: a fixed nominal step plus optional Gaussian jitter.

    Instances are callables; every call draws one sample from
    ``numpy.random.normal(mu, sigma)`` and adds it to ``dt``.
    """

    def __init__(self, dt: float, mu: float = 0, sigma: float = 0, delay_only=True) -> None:
        """Configure the step generator.

        Args:
            dt (float): Nominal step size, returned unchanged when the jitter is zero.
            mu (float, optional): Mean of the Gaussian jitter. Defaults to 0.
            sigma (float, optional): Standard deviation of the Gaussian jitter.
                Defaults to 0, in which case the jitter is always exactly ``mu``.
            delay_only (bool, optional): If True, the jitter is replaced by its
                absolute value, so it can only lengthen a step and never shorten
                it. Defaults to True.
        """
        self.__delay_only = delay_only
        self.__sigma = sigma
        self.__mu = mu
        self.__dt = dt

    def __call__(self, t):
        """Return the next step size.

        Args:
            t: Current time. Not used by this implementation.

        Returns:
            float: ``dt`` plus one jitter sample (made non-negative when
            ``delay_only`` is set).
        """
        jitter = normal(self.__mu, self.__sigma)
        return self.__dt + (abs(jitter) if self.__delay_only else jitter)
class Simulation(Advanceable):
    """Advance a loaded data source in time and feed registered sensors/observers.

    A data source must be attached with :meth:`load` before :meth:`run` is
    iterated. Sensors and observers registered via :meth:`add_sensor` and
    :meth:`add_observer` are rebound to the current data source and logger
    every time :meth:`load` is called.
    """

    def __init__(self, time_steps=None):
        """Create a simulation with no data source and no sensors.

        Args:
            time_steps (callable, optional): Called as ``time_steps(t)`` with the
                current time; its return value is the step ``dt`` to advance by.
                Defaults to a fresh ``UniformTimeSteps(0.01)`` for this instance.
        """
        super().__init__()
        self.__data_source = None
        self.__logger = None
        self.__sensors: List[Sensor] = []
        # Build the default per instance instead of sharing one object that a
        # default argument would create once, at class-definition time.
        self.__time_steps = UniformTimeSteps(0.01) if time_steps is None else time_steps

    def run(self, verbose=False, until: Phase = None):
        """Step through the loaded data source, yielding after every step.

        This is a generator: nothing happens, including the reset, until it is
        first iterated.

        Args:
            verbose (bool, optional): Show a tqdm progress bar whose total is the
                data source's ``get_length()``. Defaults to False.
            until (Phase, optional): Stop, without advancing, once the data
                source's ``get_phase()`` equals this value. Defaults to None,
                meaning run until the length is exhausted.

        Yields:
            tuple: ``(idx, t_, step)``. ``idx`` is a step counter starting at 1,
            ``t_`` is ``t + dt``, and ``step`` is ``t_ - t``.

        Raises:
            RuntimeError: If :meth:`load` has not been called yet.
        """
        if self.__data_source is None:
            raise RuntimeError("No data source loaded; call load() before run().")

        idx = 0
        # Clear all logs and reset the dataset to the first time step.
        self.__data_source.reset()
        self.__logger.reset()

        pbar = tqdm(total=self.__data_source.get_length()) if verbose else None
        try:
            while True:
                t = self.get_time()
                dt = self.__time_steps(t)
                t_ = t + dt
                idx += 1

                # Stop before taking a step whose end time exceeds get_length().
                if t_ > self.__data_source.get_length():
                    break
                if until is not None and self.__data_source.get_phase() == until:
                    break

                self.advance(dt)
                if pbar is not None:
                    pbar.update(dt)
                yield idx, t_, t_ - t
        finally:
            # Runs on normal exhaustion and also when the consumer abandons the
            # generator early (break / close()), so the progress bar is not leaked.
            if pbar is not None:
                pbar.close()

    def _on_step(self, dt: float):
        """Advance the data source and the logger by ``dt``.

        NOTE(review): presumably invoked by ``Advanceable.advance``; confirm in
        the base class.
        """
        self.__data_source.advance(dt)
        self.__logger.advance(dt)

    def load(self, source: DataSource):
        """Attach a data source and rebind every registered sensor to it.

        A fresh ``EmptyLogger`` replaces any previous logger.

        Args:
            source (DataSource): The data source to simulate.

        Returns:
            Simulation: ``self``, so calls can be chained.
        """
        self.__data_source = source
        self.__logger = EmptyLogger()

        for sensor in self.__sensors:
            sensor.set_dataset(self.__data_source)
            sensor.set_logger(self.__logger)

        return self

    def get_data_source(self) -> DataSource:
        """Return the loaded data source, or None before :meth:`load`."""
        return self.__data_source

    def get_logger(self) -> Logger:
        """Return the current logger, or None before :meth:`load`."""
        return self.__logger

    def add_sensor(self, sensor, *args, **kwargs) -> Sensor:
        """Register a new sensor for this simulation.

        A registered sensor can be called like a function and returns the
        current measurements. The sensor class's own constructor arguments must
        be passed as well.

        Sensors added before :meth:`load` are constructed with a None data
        source and logger; :meth:`load` rebinds them.

        Args:
            sensor (type): A subclass of the abstract Sensor class (the class,
                not an instance).
            *args: Extra positional arguments for the sensor's constructor.
            **kwargs: Extra keyword arguments for the sensor's constructor.

        Returns:
            Sensor: The newly constructed instance of ``sensor``.
        """
        assert issubclass(sensor, Sensor), "Expected a subclass of Sensor."
        self.__sensors.append(sensor(self.__data_source, self.__logger, *args, **kwargs))
        return self.__sensors[-1]

    def add_observer(self, observer_or_attributes: List[str] | type[Observer]) -> Observer:
        """Register a new observer for this simulation.

        Args:
            observer_or_attributes (List[str] | type[Observer]): Either a
                non-empty list of attribute names to observe, which builds a
                plain ``Observer``, or a custom Observer subclass (the class,
                not an instance), which is constructed with the data source
                and logger.

        Returns:
            Observer: An observer object that can be called like a function to
            obtain the desired data.
        """
        assert isinstance(observer_or_attributes, list) or issubclass(observer_or_attributes, Observer), \
            "Expected a list of attribute names or a subclass of Observer."

        if isinstance(observer_or_attributes, list):
            attributes = observer_or_attributes
            assert len(attributes) != 0, "Observed attributes list must be nonempty."
            self.__sensors.append(Observer(self.__data_source, self.__logger, attributes))
        else:
            observer = observer_or_attributes
            self.__sensors.append(observer(self.__data_source, self.__logger))

        return self.__sensors[-1]