# File: SPATZ/spatz/simulations/rocketpy.py
# (84 lines, 2.9 KiB, Python)

import numpy as np
from typing import Literal
from numpy.typing import NDArray
from spatz.simulations.csv_source import CSVSource
class RocketPyCSV(CSVSource):
    """Simulation data source that reads a CSV file exported by RocketPy.

    Column headers are looked up verbatim, including their leading space
    (e.g. ``' Time (s)'``, ``' X (m)'``). Scalar and vector lookups are
    delegated to ``CSVSource.fetch_value`` / ``fetch_values``; only the
    attitude quaternion is interpolated in this class, using slerp rather
    than component-wise interpolation.
    """

    # Quaternion columns in scalar-first order: e0 is the scalar part, which
    # is the convention the rotation matrix in global_to_local() assumes.
    _QUAT_COLS = (' e0', ' e1', ' e2', ' e3')

    def __init__(self, path: str, interpolation: Literal['linear'] = 'linear') -> None:
        """Load the CSV at ``path`` with ``' Time (s)'`` as the time column.

        Args:
            path: Path to the RocketPy CSV export.
            interpolation: Interpolation mode, forwarded unchanged to
                ``CSVSource``.
        """
        super().__init__(path, ' Time (s)', interpolation)

    def _to_frame(self, vec_global: NDArray, frame: Literal['global', 'local']) -> NDArray:
        """Return ``vec_global`` as-is for ``'global'``; otherwise rotate it into the body frame."""
        if frame == 'global':
            return vec_global
        return self.global_to_local() @ vec_global

    def get_position(self) -> NDArray:
        """Return the ``[X, Y, Z]`` position in metres (global frame)."""
        return self.fetch_values([' X (m)', ' Y (m)', ' Z (m)'])

    def get_velocity(self, frame: Literal['global', 'local']) -> NDArray:
        """Return the velocity in m/s.

        Args:
            frame: ``'global'`` returns the CSV columns unchanged; any other
                value rotates them into the body frame via ``global_to_local``.
        """
        return self._to_frame(
            self.fetch_values([' Vx (m/s)', ' Vy (m/s)', ' Vz (m/s)']), frame)

    def get_acceleration(self, frame: Literal['global', 'local']) -> NDArray:
        """Return the acceleration in m/s².

        Args:
            frame: ``'global'`` returns the CSV columns unchanged; any other
                value rotates them into the body frame via ``global_to_local``.
        """
        return self._to_frame(
            self.fetch_values([' Ax (m/s²)', ' Ay (m/s²)', ' Az (m/s²)']), frame)

    @staticmethod
    def _slerp(qa: NDArray, qb: NDArray, alpha: float) -> NDArray:
        """Spherically interpolate between quaternions ``qa`` and ``qb``.

        Both inputs are normalized first, and the result is normalized
        before returning, so rounding in the CSV cannot produce a
        non-rotation. ``q`` and ``-q`` encode the same rotation, so ``qb``
        is negated when the two lie in opposite hemispheres; this makes the
        interpolation follow the shortest arc. When the quaternions are
        (nearly) parallel, the slerp formula would divide by ~0, so
        normalized linear interpolation is used instead.

        Args:
            qa: Quaternion at ``alpha == 0``.
            qb: Quaternion at ``alpha == 1``.
            alpha: Interpolation parameter.

        Returns:
            The interpolated unit quaternion.
        """
        qa = qa / np.linalg.norm(qa)
        qb = qb / np.linalg.norm(qb)
        dot = float(np.dot(qa, qb))
        if dot < 0.0:
            # Take the short way round.
            qb = -qb
            dot = -dot
        theta = np.arccos(np.clip(dot, -1.0, 1.0))
        if np.isclose(theta, 0.0):
            q = (1.0 - alpha) * qa + alpha * qb
        else:
            q = (qa * np.sin((1.0 - alpha) * theta)
                 + qb * np.sin(alpha * theta)) / np.sin(theta)
        return q / np.linalg.norm(q)

    def get_attitude(self) -> NDArray:
        """Return the attitude quaternion ``[e0, e1, e2, e3]`` at the current time.

        Slerps between the quaternions stored in rows ``self._idx`` and
        ``self._idx + 1``. The interpolation fraction comes from
        ``get_time()`` relative to those two rows' timestamps (presumably
        the rows that bracket the current time; confirm against
        ``CSVSource``). The result is a unit quaternion.
        """
        df = self._df
        idx = self._idx
        qa = np.array([df.at[idx, col] for col in self._QUAT_COLS], dtype=float)

        # NOTE(review): assumes _df is a pandas DataFrame indexed by row label
        # (consistent with the .at accessor). On the final row there is no
        # successor to interpolate toward, so hold the last attitude.
        if idx + 1 not in df.index:
            return qa / np.linalg.norm(qa)

        qb = np.array([df.at[idx + 1, col] for col in self._QUAT_COLS], dtype=float)
        t_min = df.at[idx, self._time_col]
        t_max = df.at[idx + 1, self._time_col]
        dt = t_max - t_min
        # Duplicate timestamps would otherwise divide by zero.
        alpha = 0.0 if dt == 0 else (self.get_time() - t_min) / dt
        return self._slerp(qa, qb, alpha)

    def global_to_local(self) -> NDArray:
        """Return the 3x3 matrix that rotates global-frame vectors into the body frame.

        The matrix is the transpose of the standard scalar-first
        body-to-global rotation matrix built from the current attitude.
        ``get_attitude`` returns a unit quaternion, so this matrix is
        orthogonal.
        """
        e0, e1, e2, e3 = self.get_attitude()
        # Taken from:
        # https://docs.rocketpy.org/en/latest/technical/equations_of_motion.html
        return np.array([
            [e0**2 + e1**2 - e2**2 - e3**2, 2*(e1*e2 + e0*e3), 2*(e1*e3 - e0*e2)],
            [2*(e1*e2 - e0*e3), e0**2 - e1**2 + e2**2 - e3**2, 2*(e2*e3 + e0*e1)],
            [2*(e1*e3 + e0*e2), 2*(e2*e3 - e0*e1), e0**2 - e1**2 - e2**2 + e3**2],
        ])

    def local_to_global(self) -> NDArray:
        """Return the body-to-global rotation matrix.

        This is the transpose of ``global_to_local()``, which is its inverse
        because that matrix is orthogonal.
        """
        return self.global_to_local().T

    def get_angular_velocity(self) -> NDArray:
        """Return the ``[ω1, ω2, ω3]`` columns in rad/s.

        NOTE(review): RocketPy reports these as body-frame components; confirm.
        """
        return self.fetch_values([' ω1 (rad/s)', ' ω2 (rad/s)', ' ω3 (rad/s)'])

    def get_static_pressure(self) -> float:
        """Return the ``' Pressure (Pa)'`` column value in pascals."""
        return self.fetch_value(' Pressure (Pa)')

    def get_longitude(self) -> float:
        """Return the longitude in degrees."""
        return self.fetch_value(' Longitude (°)')

    def get_latitude(self) -> float:
        """Return the latitude in degrees."""
        return self.fetch_value(' Latitude (°)')

    def get_altitude(self) -> float:
        """Return the ``' Z (m)'`` column as the altitude in metres.

        NOTE(review): presumably altitude above sea level in RocketPy's
        output; confirm whether callers expect AGL.
        """
        return self.fetch_value(' Z (m)')