Rocket#

Environment Definition#

[1]:
import logging
import warnings

import numpy as np

from do_dpc.environments.rocket_env.rocket_env_facade import RocketEnvironmentArguments, RocketEnvFacade

# Raise the root logger threshold to WARNING so INFO messages are not printed.
logging.root.setLevel(logging.WARNING)
# Ignore every warning whose category is UserWarning.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Environment used for gathering training data: default arguments, with video
# recording enabled under the "data_collection" name prefix.
data_col_env = RocketEnvFacade(
    RocketEnvironmentArguments(),
    video_name_prefix="data_collection",
    record_video=True,
)
[2]:
from pyvirtualdisplay import Display

# Start a non-visible Xvfb-backed display for the videos.
# The handle is intentionally not called `display`: that name would shadow
# IPython's built-in `display()` function for the rest of the notebook.
virtual_display = Display(visible=False, size=(1400, 900), backend="xvfb")
virtual_display.start()
[2]:
<pyvirtualdisplay.display.Display at 0x12d41ce60>

Training Data Collection#

[3]:
from do_dpc.environments.rocket_env.rocket_data_collection import collect_trajectory_data_env
from do_dpc.utils.plots import plot_training_data

# Number of samples requested from the data-collection run.
n_samples = 800

try:
    # Dimensions reported by the environment, forwarded to the collector.
    m, p = data_col_env.get_dims()
    training_data = collect_trajectory_data_env(data_col_env, m, p, n_samples=n_samples)
finally:
    # Always close the environment, even when collection raises, so it is not
    # left open after a failed run.
    data_col_env.close()

plot_training_data(training_data)
Collecting Training Data: 100%|███████████████| 800/800 [00:08<00:00, 93.98it/s]
../_images/example_gallery_rocket_5_1.png
[4]:
from do_dpc.utils.plots import show_video

# "data_collection" is the same name prefix that was passed as
# `video_name_prefix` when the data-collection environment was created.
show_video("data_collection")

Online DPC Application#

[5]:
from do_dpc.control_utils.trajectory_collector import TrajectoryCollector
from do_dpc.environments.rocket_env.rocket_utils import calculate_normalized_thrust_to_hover
from do_dpc.environments.rocket_env.rocket_controller_profiles import CTRL_PARAMS_ROCKET
from do_dpc.dpc.tpc import TPC

# Environment for the closed-loop run: custom initial position, with video
# recording enabled under the "rocket_example" name prefix.
env = RocketEnvFacade(
    RocketEnvironmentArguments(initial_position=(0.8, 0.9, 0)), video_name_prefix="rocket_example", record_video=True
)
m, p = env.get_dims()

# Maximum number of closed-loop steps. The trajectory buffer and the control
# loop share this single value so they cannot drift apart.
n_sim = 2000
TIMESTEPS = n_sim

traj_col = TrajectoryCollector(m, p, n_sim)

controller_instance = TPC(CTRL_PARAMS_ROCKET, training_data)
controller_instance.add_input_constraints(env.get_input_bounds())
controller_instance.build_optimization_problem()

# Constant input used for the warm-up: the value returned by
# calculate_normalized_thrust_to_hover() followed by two zeros.
u_stable = np.array([calculate_normalized_thrust_to_hover(), 0, 0])

u_next, y_current, y_prev = u_stable, env.get_output(), env.get_output()
z_p_current = np.concatenate((y_prev, u_next))

# Push the initial (output, input) pair tau_p times into the controller's past
# measurements before the loop starts.
for _ in range(CTRL_PARAMS_ROCKET.tau_p):
    controller_instance.update_past_measurements(z_p_current)

# Apply the warm-up input once. The result must land in `y_current` (not
# `y_prev`): the loop below shifts `y_current` into `y_prev` after every step,
# so storing it in `y_prev` would feed the controller the outputs out of
# chronological order (first step, then the stale initial output).
y_current = env.step(u_next)
try:
    for i in range(TIMESTEPS):
        if env.landed_successfully:
            print(f"Landed successfully after {i} steps")
            break

        if env.done:
            print(f"Crashed after {i} steps")
            break

        y_r, u_r = env.get_y_u_reference()
        controller_instance.update_tracking_reference(y_r, u_r)

        # Record the output measured before `u_next` was applied together with
        # `u_next`, and hand the same pair to the controller.
        traj_col.store_measurements(y_prev, u_next)
        z_p_current = np.concatenate((y_prev, u_next))
        controller_instance.update_past_measurements(z_p_current)

        controller_instance.solve()
        u_next = controller_instance.get_next_control_action()
        # Shift the newest output into `y_prev` and step with the new input.
        y_prev, y_current = y_current, env.step(u_next)

    # Read the flag while the environment is still open.
    landed_successfully = env.landed_successfully
finally:
    # Close unconditionally; nothing in this clause can raise before close(),
    # so an exception from the loop propagates unmasked.
    env.close()

assert landed_successfully, "Rocket did not land successfully!"

sim_data = traj_col.get_trajectory_data()
2025-03-31 10:56:23,380 - WARNING - do_ddpc.ddpc.tpc - Lambda_uu is not positive semidefinite. Regularization cost is skipped to avoid non-convex cost.
Landed successfully after 1287 steps
2025-03-31 10:56:48,449 - WARNING - do_ddpc.control_utils.trajectory_collector - NaN values detected in trajectory data. Cleaning data by removing NaNs.

Analysis and Data Visualization#

[6]:
# Reuse the training-data plotting helper on the trajectory gathered during
# the closed-loop run.
plot_training_data(sim_data)
../_images/example_gallery_rocket_10_0.png
[7]:
# "rocket_example" is the same name prefix that was passed as
# `video_name_prefix` when the closed-loop environment was created.
show_video("rocket_example")