# Source code for do_dpc.utils.serialization
"""
serialization.py
This module provides utility functions for saving and loading `dataclass` objects using NumPy `.npz` format.
It supports NumPy arrays, floats, and integers.
Functions:
- save_dataclass_npz: Saves a dataclass object as a compressed `.npz` file.
- load_dataclass_npz: Loads a dataclass object from a `.npz` file.
"""
import dataclasses
from typing import Type, TypeVar
import numpy as np
T = TypeVar("T")
# [docs]  (Sphinx viewcode link artifact; not part of the module)
def save_dataclass_npz(obj: T, filename: str):
    """
    Saves a dataclass object as a compressed `.npz` file.

    Each dataclass field is stored in the archive under its field name.

    Args:
        obj (T): The dataclass object to save.
        filename (str): The file path where the object should be stored.
            `np.savez_compressed` appends ".npz" when the name does not
            already end with it.

    Raises:
        ValueError: If any field value is not a `numpy.ndarray`, a Python
            `float`/`int`, or a NumPy floating/integer scalar.
        TypeError: If `obj` is not a dataclass (raised by `dataclasses.fields`).

    Note:
        - Only numerical attributes are supported; strings, lists,
          dictionaries, etc. raise `ValueError`.
        - `bool` values pass the check because `bool` is a subclass of `int`.
        - NumPy scalars (e.g. `np.int64`, `np.float32`) are accepted alongside
          the built-in `int`/`float`, so values taken out of arrays can be
          stored without a manual conversion.
    """
    # np.float64 already subclasses float, but np.float32 / np.int64 and
    # friends do not subclass the builtins, so they are listed explicitly.
    supported_types = (np.ndarray, float, int, np.floating, np.integer)

    data = {}
    for field in dataclasses.fields(obj):  # type: ignore
        value = getattr(obj, field.name)
        if not isinstance(value, supported_types):
            raise ValueError(
                f"Unsupported data type {type(value)} for field '{field.name}'. "
                "Only NumPy arrays, floats, and ints are allowed."
            )
        data[field.name] = value

    np.savez_compressed(filename, **data)
# [docs]  (Sphinx viewcode link artifact; not part of the module)
def load_dataclass_npz(cls: Type[T], filename: str) -> T:
    """
    Loads a dataclass object from a `.npz` file.

    For every field declared on `cls`, the entry of the same name is read from
    the archive and passed to the dataclass constructor.

    Args:
        cls (Type[T]): The class type of the dataclass to reconstruct.
        filename (str): The `.npz` file path from which the object should be loaded.

    Returns:
        T: The reconstructed dataclass object.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        TypeError: If `cls` is not a dataclass, or if a required field is
            missing from the archive (raised when `cls` is instantiated).

    Note:
        - 0D NumPy arrays (e.g., `np.array(0.1)`) are converted back to Python
          scalars via `.item()`.
        - Archive entries that do not correspond to a dataclass field are ignored.
        - The archive is closed before returning, so no file handle is leaked.
    """
    try:
        # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code. Only load files from trusted
        # sources; kept as-is so archives holding object arrays still load.
        archive = np.load(filename, allow_pickle=True)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"File '{filename}' not found.") from e

    field_data = {}
    # The context manager closes the underlying zip file even if reading a
    # member or inspecting `cls` fails.
    with archive:
        for field in dataclasses.fields(cls):  # type: ignore
            if field.name not in archive:
                continue
            value = archive[field.name]
            # Scalars are stored as 0D arrays; unwrap them to Python scalars.
            if isinstance(value, np.ndarray) and value.shape == ():
                value = value.item()
            field_data[field.name] = value

    return cls(**field_data)