# API
This page documents the reusable modules in the repository. The project is primarily script-driven, but several modules define interfaces that are useful as importable components.
## API Overview
The most reusable modules are:
- `lnn`: model definition and the `LagrangianNN` entrypoint
- `data`: analytical system helpers and HDF5 dataset utilities
- training and simulation helpers in `train_utils`, `simulate`, and `losses`
The API documentation below focuses on importable functions and classes rather than `__main__` script blocks.
## Core Model

### `LagrangianNN`
Bases: Module
Neural network model for learning a structured Lagrangian representation of a physical system.
The model consists of three core modules:

1. `kinetic_net` - an MLP that maps trigonometric features of the generalized coordinates (`q`) to a Cholesky-factorised mass matrix representation. This branch does not take system parameters as direct inputs; instead, it is conditioned on normalized system parameters through FiLM modulation.
2. `potential_net` - an MLP that takes the same trigonometric features together with the normalized system parameters (`p`) and outputs a scalar normalized potential energy.
3. `film_net` - a Feature-wise Linear Modulation (FiLM) network that generates per-layer scaling (`gamma`) and shifting (`beta`) parameters for the hidden layers of `kinetic_net`. These parameters are conditioned on the normalized system parameters (e.g. masses, lengths) and thus allow the normalized kinetic energy branch to adapt across different system configurations.
The `__call__` method of this class uses automatic differentiation (`jax.grad`, `jax.jacobian`) on the learned normalized Lagrangian to derive the system's equations of motion, returning the generalized accelerations (`q_tt`).
Source code in `src/lnn/model.py`.
#### `__call__(q, q_t, p)`

Derives and returns the generalized accelerations (`q_tt`) from the learned normalized Lagrangian.

This method applies the Euler-Lagrange equations via automatic differentiation to compute the accelerations. The equations of motion are written as `M(q) * q_tt = f(q, q_t)`, where `M = d^2L / (dq_t dq_t)` is the Hessian of the normalized kinetic energy term with respect to the generalized velocities. Expanding `d/dt(dL/dq_t) = (d^2L / (dq_t dq)) * q_t + M * q_tt` in the Euler-Lagrange equations gives `f = dL/dq - (d^2L / (dq_t dq)) * q_t`, and solving the resulting linear system yields `q_tt`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `q` | `Array` | Generalized coordinates, shape `(pos_dim,)`. | *required* |
| `q_t` | `Array` | Generalized velocities, shape `(vel_dim,)`. | *required* |
| `p` | `Array` | System parameters, shape `(param_dim,)`. | *required* |

Returns:

| Type | Description |
|---|---|
| `jax.Array` | Generalized accelerations (`q_tt`), shape `(vel_dim,)`. |
Source code in `src/lnn/model.py`.
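As a sketch of the autodiff pattern described above, here is a toy single-pendulum Lagrangian (not the repo's learned model) with the linear system `M * q_tt = f` assembled and solved the same way:

```python
import jax
import jax.numpy as jnp

# Toy normalized Lagrangian for a simple pendulum (unit mass/length/gravity),
# standing in for the learned network. Analytically, q_tt = -sin(q).
def lagrangian(q, q_t):
    return 0.5 * jnp.sum(q_t ** 2) + jnp.sum(jnp.cos(q))

def accelerations(q, q_t):
    # M = Hessian of L with respect to the generalized velocities
    M = jax.hessian(lagrangian, argnums=1)(q, q_t)
    # f = dL/dq - (d^2L / (dq_t dq)) * q_t
    dL_dq = jax.grad(lagrangian, argnums=0)(q, q_t)
    C = jax.jacobian(jax.grad(lagrangian, argnums=1), argnums=0)(q, q_t)
    return jnp.linalg.solve(M, dL_dq - C @ q_t)

q_tt = accelerations(jnp.array([0.3]), jnp.array([0.0]))  # analytically -sin(0.3)
```

The model's `__call__` follows the same structure, with the scalar Lagrangian produced by `kinetic_net`, `potential_net`, and `film_net` instead of a closed-form expression.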
### `apply_film(h, film_params, net)`
Runs a network while applying FiLM modulation to each hidden layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `h` | | Input features for the network. | *required* |
| `film_params` | | Per-layer FiLM parameters of shape `(n_hidden, 2)`, storing `[gamma, beta]` for each hidden layer. | *required* |
| `net` | | Network whose hidden activations are modulated. | *required* |
Source code in `src/lnn/model.py`.
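The per-layer modulation can be illustrated with a minimal NumPy stand-in (hypothetical shapes and activation; the repo's version threads this through an Equinox MLP):

```python
import numpy as np

def film_modulate(h, gamma, beta):
    # FiLM: feature-wise affine transform of a hidden activation
    return gamma * h + beta

# film_params as described above: one [gamma, beta] row per hidden layer
film_params = np.array([[2.0, -1.0], [0.5, 0.0]])
h = np.ones(4)
for gamma, beta in film_params:
    h = np.tanh(film_modulate(h, gamma, beta))  # modulate, then nonlinearity
```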
## Analytical System And Dataset Utilities

### `DoublePendulum`
Bases: Module
Double pendulum code assuming the following state variables:

x = [t1, t2, w1, w2]

- `t1`: rad, angle of pendulum 1 from downward vertical
- `t2`: rad, angle of pendulum 2 from downward vertical
- `w1`: rad/s, angular velocity of pendulum 1
- `w2`: rad/s, angular velocity of pendulum 2

`analytical_state_transition` returns the state vector derivative d/dt x = [w1, w2, g1, g2].
Source code in `src/data/doublependulum.py`.
#### `analytical_state_transition(full_state, t)`

Note on the denominator term `1 - a1 * a2`: since `a1 * a2` is proportional to `cos²(t1 - t2)`, the denominator is closest to zero when the two angles are aligned or anti-aligned (`cos² → 1`); at `t1 - t2 = ±π/2` the cosine vanishes and the denominator instead approaches 1.
Source code in `src/data/doublependulum.py`.
#### `to_cartesian(q)`
Convert angles to Cartesian coordinates.
Source code in `src/data/doublependulum.py`.
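With angles measured from the downward vertical as above, the conversion follows the standard pendulum geometry. This is a sketch with assumed lengths `l1`, `l2`, not the repo's exact implementation:

```python
import numpy as np

def to_cartesian(q, l1=1.0, l2=1.0):
    t1, t2 = q
    x1, y1 = l1 * np.sin(t1), -l1 * np.cos(t1)            # bob 1
    x2, y2 = x1 + l2 * np.sin(t2), y1 - l2 * np.cos(t2)   # bob 2 hangs from bob 1
    return np.array([x1, y1, x2, y2])

hanging = to_cartesian([0.0, 0.0])  # both bobs straight down
```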
### `load_list_of_arrays_from_h5(system='doublependulum', filename='trajectories.h5')`
Loads a list of 2D NumPy arrays from a specified HDF5 file.
The function constructs the file path relative to the project's root data directory. It expects the arrays to be stored within an HDF5 group named 'trajectories', with individual datasets named systematically (e.g., 'trajectory_000').
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `system` | `str` | The name of the subdirectory within the project's `data/` folder where the HDF5 file is located. | `'doublependulum'` |
| `filename` | `str` | The name of the HDF5 file to load. | `'trajectories.h5'` |

Returns:

| Type | Description |
|---|---|
| `List[np.ndarray]` | A list of 2D NumPy arrays loaded from the HDF5 file. Returns an empty list if the specified file does not exist. |
Source code in `src/data/utils.py`.
### `save_list_of_arrays_to_h5(list_of_arrays, system='doublependulum', filename='list_of_trajectories.h5')`
Saves a list of 2D NumPy or JAX arrays into a single compressed HDF5 file.
Each array in the list is stored as a separate dataset within an HDF5 group named 'trajectories'. The individual datasets are named sequentially (e.g., 'trajectory_000', 'trajectory_001', etc.). The function ensures that the target directory structure exists before writing the file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `list_of_arrays` | `List[Union[np.ndarray, jax.Array]]` | A list of 2D arrays (either NumPy arrays or JAX arrays) to be saved. | *required* |
| `system` | `str` | The name of the subdirectory within the project's `data/` folder where the HDF5 file should be stored. | `'doublependulum'` |
| `filename` | `str` | The name of the HDF5 file to create. | `'list_of_trajectories.h5'` |
Side Effects
- Creates the target directory (e.g., 'project_root/data/doublependulum/') if it does not already exist.
- Creates an HDF5 file at the specified path.
- Prints messages indicating directory creation and successful file saving.
Source code in `src/data/utils.py`.
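Assuming `h5py` and the group/dataset naming described above, a save/load round trip looks roughly like this (a sketch, not the repo's exact path handling):

```python
import os
import tempfile

import h5py
import numpy as np

arrays = [np.random.rand(100, 5), np.random.rand(80, 5)]
path = os.path.join(tempfile.mkdtemp(), "trajectories.h5")

# save: one compressed dataset per trajectory inside a 'trajectories' group
with h5py.File(path, "w") as f:
    grp = f.create_group("trajectories")
    for i, arr in enumerate(arrays):
        grp.create_dataset(f"trajectory_{i:03d}", data=arr, compression="gzip")

# load: read datasets back in name order
with h5py.File(path, "r") as f:
    loaded = [f["trajectories"][k][()] for k in sorted(f["trajectories"])]
```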
### `get_project_data_path(sub_path='')`
Constructs an absolute path to a location within the project's /data directory. Assumes this script is run from within the project structure (e.g., src/data/).
Source code in `src/data/utils.py`.
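A hypothetical re-implementation of this helper with `pathlib` (the repo resolves the project root from the module's own location; here the root is passed in explicitly to keep the sketch self-contained):

```python
from pathlib import Path

def get_project_data_path(project_root, sub_path=""):
    # project_root stands in for the directory two levels above src/data/
    return Path(project_root) / "data" / sub_path

p = get_project_data_path("/repo", "doublependulum")
```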
## Training And Simulation Helpers

### `build_input_output(datasets, params, dt)`
Preprocesses raw trajectory data into model inputs (X) and targets (dXdt).
The input X is augmented with system parameters. The target dXdt is computed via numerical differentiation of velocities.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `datasets` | `List[Array]` | List of raw trajectory data, each of shape `(time_steps, 5)` where columns are `[time, q1, q2, w1, w2]`. | *required* |
| `params` | `List[Array]` | List of system parameters, each of shape `(4,)`: `[m1, m2, l1, l2]`. | *required* |
| `dt` | `float` | Time step size, used for numerical differentiation. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tuple[jax.Array, jax.Array]` | `X`: concatenated input states, shape `(num_trajectories, time_steps, features)`, with features ordered `[q1, q2, w1, w2, m1, m2, l1, l2]`. `dXdt`: computed accelerations, shape `(num_trajectories, time_steps, 2)`. |
Source code in `src/train_utils.py`.
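The numerical-differentiation step can be sketched with `np.gradient` on the velocity columns, using the column layout stated above:

```python
import numpy as np

dt = 0.01
t = np.arange(0.0, 1.0, dt)
# synthetic trajectory with known accelerations: d/dt [cos t, -sin t] = [-sin t, -cos t]
traj = np.stack([t, np.sin(t), np.cos(t), np.cos(t), -np.sin(t)], axis=1)

# targets = time derivative of the velocity columns w1, w2
dXdt = np.gradient(traj[:, 3:5], dt, axis=0)
```

On interior rows `np.gradient` uses central differences, so the targets match the analytic accelerations to second order in `dt`.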
### `build_temporal_batch(x, y, batch_size, temporal_chunk_len, step_key)`
Builds a batch of data by sampling random temporal chunks from random trajectories.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input data, shape `(num_trajectories, time_steps, features)`. | *required* |
| `y` | `Array` | Target data, shape `(num_trajectories, time_steps, output_dim)`. | *required* |
| `batch_size` | `int` | Number of trajectories to sample for this batch. | *required* |
| `temporal_chunk_len` | `int` | Length of the time chunk to extract from each sampled trajectory. | *required* |
| `step_key` | `Array` | JAX PRNG key for random sampling. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tuple[jax.Array, jax.Array]` | A batch of inputs (`x_batch`) and targets (`y_batch`), concatenated along the first axis. Shapes: `(batch_size * temporal_chunk_len, features)` and `(batch_size * temporal_chunk_len, output_dim)`. |
Source code in `src/train_utils.py`.
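A NumPy stand-in for the chunk-sampling logic (the repo's version uses a JAX PRNG key; `rng` here is a hypothetical replacement):

```python
import numpy as np

def build_temporal_batch(x, y, batch_size, chunk_len, rng):
    n_traj, n_steps = x.shape[0], x.shape[1]
    # pick random trajectories and random chunk start positions
    traj_idx = rng.integers(0, n_traj, size=batch_size)
    starts = rng.integers(0, n_steps - chunk_len + 1, size=batch_size)
    xb = np.concatenate([x[i, s:s + chunk_len] for i, s in zip(traj_idx, starts)])
    yb = np.concatenate([y[i, s:s + chunk_len] for i, s in zip(traj_idx, starts)])
    return xb, yb

rng = np.random.default_rng(0)
x = np.zeros((5, 100, 8))
y = np.zeros((5, 100, 2))
xb, yb = build_temporal_batch(x, y, batch_size=4, chunk_len=10, rng=rng)
```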
### `compute_V_stats(datasets, params, idx_train)`
Computes mean and standard deviation of potential energy for the training set.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `datasets` | `List[Array]` | List of all trajectory datasets. | *required* |
| `params` | `List[Array]` | List of all system parameters corresponding to `datasets`. | *required* |
| `idx_train` | `ndarray` | Indices of trajectories belonging to the training set. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tuple[jax.Array, jax.Array]` | Mean and standard deviation of potential energy. |
Source code in `src/train_utils.py`.
### `load_model(model, fname='model')`
Loads an Equinox model's leaves (trainable parameters) from a file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | An uninitialized Equinox model with the correct architecture to load the parameters into. | *required* |
| `fname` | `Path` | The base path/filename of the model to load. Assumes a `.eqx` extension. | `'model'` |

Returns:

| Type | Description |
|---|---|
| `eqx.Module` | The model with loaded parameters. |
Source code in `src/train_utils.py`.
### `normalize_data(Xtrain, Xval, Xtest, dXdt_train, dXdt_val, dXdt_test, len_params, normalize=True)`
Normalizes the input (X) and target (dXdt) datasets based on training set statistics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `Xtrain` | `Array` | Input training data. | *required* |
| `Xval` | `Array` | Input validation data. | *required* |
| `Xtest` | `Array` | Input test data. | *required* |
| `dXdt_train` | `Array` | Target training data (accelerations). | *required* |
| `dXdt_val` | `Array` | Target validation data (accelerations). | *required* |
| `dXdt_test` | `Array` | Target test data (accelerations). | *required* |
| `len_params` | `int` | The number of unique parameter sets (trajectories) in the dataset. Used to handle cases where parameters might be constant across all train samples (std = 0). | *required* |
| `normalize` | `bool` | If True, perform normalization. Otherwise, return data as is. | `True` |

Returns:

| Type | Description |
|---|---|
| `Tuple[jax.Array, ...]` | Normalized (or original) data and a dictionary of normalization statistics, in the order: `Xtrain_norm`, `Xval_norm`, `Xtest_norm`, `dXdt_train_norm`, `dXdt_val_norm`, `dXdt_test_norm`, `norm_stats`. |
Source code in `src/train_utils.py`.
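The core of the normalization, z-scoring with training-set statistics and a zero-std guard, can be sketched as follows. This is a hypothetical helper: the repo's function additionally handles the parameter columns via `len_params`:

```python
import numpy as np

def zscore(train, *others, eps=1e-8):
    mean, std = train.mean(axis=0), train.std(axis=0)
    std = np.where(std < eps, 1.0, std)  # constant features pass through unscaled
    norm = lambda a: (a - mean) / std
    stats = {"mean": mean, "std": std}
    return [norm(a) for a in (train, *others)], stats

rng = np.random.default_rng(0)
(tr, va), stats = zscore(rng.normal(2.0, 3.0, (100, 4)), rng.normal(size=(20, 4)))
```

Validation and test splits are transformed with the training statistics only, so no information leaks from held-out data into the normalization.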
### `run_diagnostics(model, Xtrain, Xtrain_norm, dXdt_train, dXdt_train_norm, norm_stats, params)`
One-step acceleration check and Lagrangian structure diagnostics.
Source code in `src/train_utils.py`.
### `save_model(model, fname)`
Saves an Equinox model's leaves (trainable parameters) to a file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The Equinox model to save. | *required* |
| `fname` | `Path` | The base path/filename for the model. A `.eqx` extension will be added. | *required* |
Source code in `src/train_utils.py`.
### `train_test_split(X, n_train=0.7, n_val=0.1, seed=42)`
Splits the dataset into training, validation, and test sets.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `Array` | The full dataset (e.g., input trajectories). | *required* |
| `n_train` | `float` | Proportion of data to use for the training set. | `0.7` |
| `n_val` | `float` | Proportion of data to use for the validation set. | `0.1` |
| `seed` | `int` | Random seed for reproducibility. | `42` |

Returns:

| Type | Description |
|---|---|
| `Tuple[np.ndarray, np.ndarray, np.ndarray]` | Indices for the train, validation, and test sets. |
Source code in `src/train_utils.py`.
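A minimal index-splitting sketch matching the defaults above (hypothetical; the repo's version takes the dataset itself rather than its length):

```python
import numpy as np

def train_test_split(n, n_train=0.7, n_val=0.1, seed=42):
    # shuffle indices once, then slice into the three proportions
    idx = np.random.default_rng(seed).permutation(n)
    a = int(n * n_train)
    b = a + int(n * n_val)
    return idx[:a], idx[a:b], idx[b:]

tr, va, te = train_test_split(100)  # 70 / 10 / 20 indices
```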
### `save_rollout_data(save_dir, filename_prefix, times, gt_states, sim_states, params_phys, case_label='')`
Saves ground truth and simulated trajectory data to a compressed .npz file.
Source code in `src/simulate.py`.
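Assuming `.npz` field names mirroring the argument names above (hypothetical keys), the dump/reload pattern looks like this:

```python
import os
import tempfile

import numpy as np

times = np.linspace(0.0, 1.0, 11)
gt_states = np.zeros((11, 4))
sim_states = np.ones((11, 4))

# write one compressed archive holding all rollout arrays
path = os.path.join(tempfile.mkdtemp(), "rollout_case.npz")
np.savez_compressed(path, times=times, gt_states=gt_states, sim_states=sim_states)

# reload by key
with np.load(path) as data:
    loaded_times = data["times"]
```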
### `energy_conservation_loss(model, x, split_size=2)`
Calculates a loss term that penalizes drift in the model's normalized Hamiltonian within trajectory chunks.
Because the model is trained in normalized coordinates, this quantity should be interpreted as a structured normalized energy induced by the learned Lagrangian, rather than as the exact physical Hamiltonian in original units. The loss encourages temporal consistency of that learned quantity along each trajectory chunk.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The neural network model. | *required* |
| `x` | `Array` | The input batch of state vectors containing generalized coordinates, generalized velocities, and normalized system parameters. | *required* |
| `split_size` | `int` | The dimensionality of the generalized coordinates. Defaults to 2 for 2D systems. | `2` |

Returns:

| Type | Description |
|---|---|
| `jax.Array` | Variance of the model's normalized Hamiltonian across the current flattened batch chunk. In the present training setup this corresponds to a single trajectory chunk; for multi-trajectory batches this would need to be made trajectory-local explicitly. |
Source code in `src/losses.py`.