# Source code for do_dpc.control_utils.trajectory_collector

"""
Handles control input generation and trajectory data collection for data-driven control.
"""

from typing import Optional

import numpy as np
from tqdm import tqdm  # type: ignore

from do_dpc.control_utils.control_structs import InputOutputTrajectory
from do_dpc.control_utils.lti_systems import LTISimulator
from do_dpc.control_utils.noise_generators import WhiteNoiseGenerator
from do_dpc.control_utils.pid_control_utils import PIDCombo
from do_dpc.utils.logging_config import get_logger

logger = get_logger(__name__)

# Default number of time steps collected by `collect_trajectory_data`.
N_SAMPLES: int = 1_000
# Default excitation deviation; scales the std of the white-noise input.
EXC_DEV: float = 1


class TrajectoryCollector:
    """
    Collects system output (`y`) and control input (`u`) trajectories.

    This class handles the logging and structured collection of trajectory data
    for data-driven control applications. Samples are written column by column
    into pre-allocated NaN-filled arrays. `y` and `u` always share the same
    column index, so every stored output stays paired with the input stored
    alongside it.
    """

    def __init__(self, m: int, p: int, traj_length: int):
        """
        Initializes the TrajectoryCollector.

        Args:
            m (int): Number of control inputs (dimension of `u`).
            p (int): Number of system outputs (dimension of `y`).
            traj_length (int): Length of the trajectory (number of time steps).
        """
        self.m = m
        self.p = p
        self.traj_length = traj_length

        # Initialize trajectory arrays with NaNs to track missing values.
        self.y = np.full((p, traj_length), np.nan)
        self.u = np.full((m, traj_length), np.nan)

        # Index of the next free column, shared by `y` and `u`. Tracking it
        # explicitly (instead of searching for the first NaN in row 0) keeps
        # `y` and `u` aligned even when a stored sample itself contains NaN,
        # and makes each store O(1).
        self._n_stored = 0

        logger.info("TrajectoryCollector initialized with m=%s, p=%s, traj_length=%s", m, p, traj_length)

    def store_measurements(self, y_next: np.ndarray, u_next: np.ndarray):
        """
        Stores the next system output (`y_next`) and the corresponding control input (`u_next`)
        in the trajectory dataset.

        Both vectors are written to the same column: the next one not yet filled.
        If the trajectory is already full, a warning is logged for `u` and for `y`,
        and the sample is discarded. Array-likes (e.g. lists) are converted with
        `np.asarray` before validation.

        Args:
            y_next (np.ndarray): The new system output observation (shape `(p,)`).
            u_next (np.ndarray): The control input applied at this step (shape `(m,)`).

        Raises:
            ValueError: If `y_next` does not match the expected shape `(p,)`.
            ValueError: If `u_next` does not match the expected shape `(m,)`.
        """
        y_next = np.asarray(y_next)
        u_next = np.asarray(u_next)

        if y_next.shape != (self.p,):
            raise ValueError(f"y_next must have shape ({self.p},) but got {y_next.shape}")
        if u_next.shape != (self.m,):
            raise ValueError(f"u_next must have shape ({self.m},) but got {u_next.shape}")

        idx = self._n_stored
        if idx >= self.traj_length:
            logger.warning("Trajectory collection for u is already complete.")
            logger.warning("Trajectory collection for y is already complete.")
            return

        self.u[:, idx] = u_next
        logger.debug("Stored control input u at column index %s", idx)

        self.y[:, idx] = y_next
        logger.debug("Stored system output y at column index %s", idx)

        self._n_stored += 1

    def get_trajectory_data(self) -> InputOutputTrajectory:
        """
        Retrieves the collected trajectory data.

        Completeness of the collection is not enforced. If NaN values are present
        (columns never filled, or NaN entries inside stored samples), a warning is
        logged and every NaN is replaced with 0. Infinite values are left untouched.
        The cleaned arrays also replace `self.y` and `self.u`.

        Returns:
            InputOutputTrajectory: Trajectory data built from `y` and `u`,
            with NaN values replaced by zeros.
        """
        if np.isnan(self.u).any() or np.isnan(self.y).any():
            logger.warning("NaN values detected in trajectory data. Cleaning data by removing NaNs.")
            # Replace only NaN. `np.nan_to_num` would also map +/-inf to huge
            # finite numbers, but only when some NaN happened to be present.
            # That made inf handling depend on unrelated entries.
            self.y = np.where(np.isnan(self.y), 0.0, self.y)
            self.u = np.where(np.isnan(self.u), 0.0, self.u)

        logger.info("Trajectory collection complete. Returning cleaned trajectory data.")
        return InputOutputTrajectory(y=self.y, u=self.u)
# pylint: disable=R0913,R0917
def collect_trajectory_data(
    sys: LTISimulator,
    m: int,
    p: int,
    pid_combo: Optional[PIDCombo] = None,
    n_samples: int = N_SAMPLES,
    exc_dev: float = EXC_DEV,
    *,
    seed: int = 7473,
) -> InputOutputTrajectory:
    """
    Collects trajectory data from an LTI system with optional PID control.

    At every step a white-noise excitation input is drawn. If `pid_combo` is
    given, the PID output is subtracted from that excitation. The system output
    measured before applying the input is stored together with the input. The
    system is then stepped with that input.

    Args:
        sys (LTISimulator): The system to collect data from.
        m (int): Number of control inputs.
        p (int): Number of system outputs.
        pid_combo (PIDCombo, optional): PID controller with function to get state to error.
        n_samples (int, optional): Number of samples for the data collection.
        exc_dev (float, optional): Excitation Deviation for the input generation.
        seed (int, optional): Keyword-only seed for the white-noise input generator.
            Defaults to 7473, the value previously hard-coded, so default results
            are unchanged.

    Returns:
        InputOutputTrajectory: Collected trajectory data containing system outputs and control inputs.
    """
    u_mean = np.zeros((m,))
    traj_col = TrajectoryCollector(m, p, n_samples)
    ctrl_gen = WhiteNoiseGenerator(mean=u_mean, std=exc_dev * np.ones_like(u_mean), seed=seed)

    y_next = sys.get_output()
    for _ in tqdm(range(n_samples), desc="Collecting Training Data", ncols=80, position=0):
        u_next = ctrl_gen.generate()
        if pid_combo is not None:
            err, der_err = pid_combo.converter_function(y_next, None)
            # Out-of-place subtraction: the array returned by the generator is
            # never modified, regardless of whether it is freshly allocated.
            u_next = u_next - pid_combo.MIMO_PID.compute_with_derivative(err, der_err)
        traj_col.store_measurements(y_next, u_next)
        y_next = sys.step(u_next)

    return traj_col.get_trajectory_data()