Source code for do_dpc.environments.rocket_env.rocket_data_collection
"""
Data collection utilities for the Rocket lander.
Also contains the RocketInputGenerator class definition.
"""
import sys
import numpy as np
from tqdm import tqdm # type: ignore
from do_dpc.control_utils.control_structs import InputOutputTrajectory, Bounds
from do_dpc.control_utils.noise_generators import WhiteNoiseGenerator
from do_dpc.control_utils.pid_control_utils import PIDCombo
from do_dpc.control_utils.pid_profiles import ROCKET_PID_COMBO
from do_dpc.control_utils.trajectory_collector import TrajectoryCollector
from do_dpc.environments.rocket_env.rocket_env_facade import RocketEnvFacade
from do_dpc.environments.rocket_env.rocket_utils import is_rocket_outside_reasonable_bounds
from do_dpc.utils.logging_config import get_logger
# Module-level logger shared by the data-collection helpers in this module.
logger = get_logger(__name__)
# Default number of samples collected by collect_trajectory_data_env.
N_SAMPLES = 800
[docs]
def collect_trajectory_data_env(
    env: RocketEnvFacade, m: int, p: int, n_samples: int = N_SAMPLES
) -> InputOutputTrajectory:
    """
    Collects trajectory data from the given environment.

    The initial output is recorded once. At every step:
    1. The current output is copied and its first two entries are shifted relative to that initial output.
    2. The shifted output is passed to a ``RocketInputGenerator``, built from the env's input bounds,
       ``ROCKET_PID_COMBO`` and a white-noise generator, to compute the next input.
    3. The raw (unshifted) output and that input are stored.
    4. The environment is stepped with the input.

    Args:
        env: The environment object. It must provide ``get_output()`` and ``done`` (checked up
            front). It is also used through ``get_input_bounds()``, ``step()`` and ``close()``.
        m (int): Number of system inputs.
        p (int): Number of system outputs.
        n_samples (int, optional): Number of samples to collect. Defaults to `N_SAMPLES`.

    Returns:
        InputOutputTrajectory: The collected trajectory data containing `n_samples` stored
        (output, input) pairs. This function never returns ``None``; failures are raised.

    Raises:
        AttributeError: If `env` does not have the `get_output` or `done` attributes.
        ValueError: If `n_samples` is not a positive integer (``bool`` values are rejected).
        RuntimeError: If the environment is already done before collecting data, or becomes
            done before `n_samples` samples have been collected. In the latter case
            ``env.close()`` is called before raising.
    """
    # Presumably flushes pending stdout text so it does not interleave with the tqdm bar.
    sys.stdout.flush()

    if not hasattr(env, "get_output") or not hasattr(env, "done"):
        raise AttributeError("The provided environment must have 'get_output()' and 'done' attributes.")
    # bool is a subclass of int, so without the explicit check True would mean n_samples == 1.
    if isinstance(n_samples, bool) or not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")
    if env.done:
        raise RuntimeError("The environment is already done. Cannot collect trajectory data.")

    traj_col = TrajectoryCollector(m, p, n_samples)
    y_next = env.get_output()
    # Reference for the first two output entries; the input generator sees them relative to this.
    obs_init = y_next[:2].copy()

    # NOTE(review): the two arrays are presumably the per-input mean and scale of the excitation
    # noise (configured for 3 inputs); seed=1 fixes the sequence. Confirm against WhiteNoiseGenerator.
    exc_gen = WhiteNoiseGenerator(np.array([0.33, 0, 0]), np.array([0.5, 1, 1]), seed=1)
    stabilized_random_inputs = RocketInputGenerator(env.get_input_bounds(), ROCKET_PID_COMBO, exc_gen)

    for i in tqdm(range(n_samples), desc="Collecting Training Data", ncols=80):
        # The env may terminate mid-collection; that makes the trajectory incomplete.
        if env.done:
            logger.error("Data collection stopped prematurely: Environment reached 'done' output.")
            logger.info("Iteration: %d", i)
            env.close()
            raise RuntimeError("Could not complete the data collection as the environment is done.")

        output = y_next.copy()
        output[:2] -= obs_init
        u_next = stabilized_random_inputs.compute_action(output)
        # The raw (unshifted) output is stored alongside the input computed from it.
        traj_col.store_measurements(y_next, u_next)
        y_next = env.step(u_next)

    logger.info("Data Collection complete")
    logger.info("Total Samples: %d", n_samples)
    return traj_col.get_trajectory_data()