from typing import Literal

import numpy as np
from numpy.typing import NDArray

from spatz.simulations.csv_source import CSVSource


# Quaternion columns in the CSV (header names carry a leading space, like
# every other column name used in this module).
_QUAT_COLS = [' e0', ' e1', ' e2', ' e3']


def _slerp(a: NDArray, b: NDArray, alpha: float) -> NDArray:
    """Spherically interpolate between unit quaternions ``a`` and ``b``.

    ``alpha = 0`` yields ``a``; ``alpha = 1`` yields ``b`` or ``-b`` (both
    encode the same rotation). The shorter great-circle arc is always used.
    """
    dot = float(np.dot(a, b))
    # q and -q describe the same rotation; flipping b when the dot product is
    # negative makes the interpolation follow the short arc instead of the
    # long way round.
    if dot < 0.0:
        b = -b
        dot = -dot
    theta = np.arccos(np.clip(dot, -1.0, 1.0))
    if np.isclose(theta, 0.0):
        # sin(theta) ~ 0 would make the slerp formula ill-conditioned; for
        # nearly identical orientations a normalised linear blend is accurate.
        q = (1.0 - alpha) * a + alpha * b
        return q / np.linalg.norm(q)
    return (
        a * np.sin((1.0 - alpha) * theta) + b * np.sin(alpha * theta)
    ) / np.sin(theta)


class RocketPyCSV(CSVSource):
    """Simulation data source backed by a RocketPy flight CSV export."""

    def __init__(self, path: str, interpolation: Literal['linear'] = 'linear') -> None:
        """Open the CSV at ``path``, using ``' Time (s)'`` as the time column."""
        super().__init__(path, ' Time (s)', interpolation)

    def _to_frame(self, vec_global: NDArray, frame: Literal['global', 'local']) -> NDArray:
        """Return ``vec_global`` unchanged for ``'global'``, otherwise rotated into the local frame."""
        if frame == 'global':
            return vec_global
        return self.global_to_local() @ vec_global

    def get_position(self) -> NDArray:
        """Return the position ``[X, Y, Z]`` in metres."""
        return self.fetch_values([' X (m)', ' Y (m)', ' Z (m)'])

    def get_velocity(self, frame: Literal['global', 'local']) -> NDArray:
        """Return the velocity in m/s, expressed in the requested ``frame``."""
        vel_global = self.fetch_values([' Vx (m/s)', ' Vy (m/s)', ' Vz (m/s)'])
        return self._to_frame(vel_global, frame)

    def get_acceleration(self, frame: Literal['global', 'local']) -> NDArray:
        """Return the acceleration in m/s², expressed in the requested ``frame``."""
        acc_global = self.fetch_values([' Ax (m/s²)', ' Ay (m/s²)', ' Az (m/s²)'])
        return self._to_frame(acc_global, frame)

    def _read_quaternion(self, idx) -> NDArray:
        """Return the quaternion ``[e0, e1, e2, e3]`` stored in row ``idx``, normalised to unit length."""
        # Values read back from CSV are rounded, so they need not be exactly
        # unit length; renormalise so the derived rotation matrix stays
        # orthogonal.
        q = self._df.loc[idx, _QUAT_COLS].to_numpy(dtype=float)
        return q / np.linalg.norm(q)

    def get_attitude(self) -> NDArray:
        """Return the attitude quaternion ``[e0, e1, e2, e3]`` at the current time.

        The quaternion is slerp-interpolated between row ``self._idx`` and the
        following row, with ``self.get_time()`` positioned between their
        timestamps. If there is no following row, or both rows share the same
        timestamp, the quaternion of row ``self._idx`` is returned as-is.
        """
        idx = self._idx
        qa = self._read_quaternion(idx)

        # Nothing to interpolate towards past the last sample; hold the
        # current one rather than raising a KeyError.
        if idx + 1 not in self._df.index:
            return qa

        t_min = self._df.at[idx, self._time_col]
        t_max = self._df.at[idx + 1, self._time_col]
        # A zero-length step would divide by zero below.
        if t_max == t_min:
            return qa

        qb = self._read_quaternion(idx + 1)
        alpha = (self.get_time() - t_min) / (t_max - t_min)
        return _slerp(qa, qb, alpha)

    def global_to_local(self) -> NDArray:
        """Return the 3x3 rotation matrix taking global-frame vectors into the local frame."""
        e0, e1, e2, e3 = self.get_attitude()

        # Taken from:
        # https://docs.rocketpy.org/en/latest/technical/equations_of_motion.html
        return np.array([
            [e0**2 + e1**2 - e2**2 - e3**2, 2*(e1*e2 + e0*e3), 2*(e1*e3 - e0*e2)],
            [2*(e1*e2 - e0*e3), e0**2 - e1**2 + e2**2 - e3**2, 2*(e2*e3 + e0*e1)],
            [2*(e1*e3 + e0*e2), 2*(e2*e3 - e0*e1), e0**2 - e1**2 - e2**2 + e3**2],
        ])

    def local_to_global(self) -> NDArray:
        """Return the rotation matrix taking local-frame vectors into the global frame."""
        # For a rotation matrix the inverse is the transpose.
        return self.global_to_local().T

    def get_angular_velocity(self) -> NDArray:
        """Return the angular velocity ``[ω1, ω2, ω3]`` in rad/s."""
        return self.fetch_values([' ω1 (rad/s)', ' ω2 (rad/s)', ' ω3 (rad/s)'])

    def get_static_pressure(self) -> float:
        """Return the static pressure in Pa."""
        return self.fetch_value(' Pressure (Pa)')

    def get_temperature(self) -> float:
        """Not supported by this data source."""
        raise NotImplementedError()

    def get_longitude(self) -> float:
        """Return the longitude in degrees."""
        return self.fetch_value(' Longitude (°)')

    def get_latitude(self) -> float:
        """Return the latitude in degrees."""
        return self.fetch_value(' Latitude (°)')

    def get_altitude(self) -> float:
        """Return the altitude, read from the ``' Z (m)'`` column."""
        return self.fetch_value(' Z (m)')