
API

This page documents the reusable modules in the repository. The project is primarily script-driven, but several modules define interfaces that are useful as importable components.

API Overview

The most reusable modules are:

  • lnn: model definition and the LagrangianNN entrypoint
  • data: analytical system helpers and HDF5 dataset utilities
  • training and simulation helpers in train_utils, simulate, and losses

The API documentation below focuses on importable functions and classes rather than __main__ script blocks.

Core Model

LagrangianNN

Bases: Module

Neural network model for learning a structured Lagrangian representation of a physical system.

The model consists of three core modules:

  1. kinetic_net - an MLP that maps trigonometric features of the generalized coordinates (q) to a Cholesky-factorised mass matrix representation. This branch does not take system parameters as direct inputs; instead, it is conditioned on normalized system parameters through FiLM modulation.

  2. potential_net - an MLP that takes the same trigonometric features together with the normalized system parameters (p) and outputs a scalar normalized potential energy.

  3. film_net - a Feature-wise Linear Modulation (FiLM) network that generates per-layer scaling (gamma) and shifting (beta) parameters for the hidden layers of kinetic_net. These parameters are conditioned on the normalized system parameters (e.g. masses, lengths) and thus allow the normalized kinetic energy branch to adapt across different system configurations.

The __call__ method of this class uses automatic differentiation (jax.grad, jax.jacobian) on the learned normalized Lagrangian to derive the system's equations of motion, returning the generalized accelerations (q_tt).
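
A minimal construction-and-call sketch (hyperparameter values are illustrative, and the import path assumes the src/ layout is importable; the kinetic branch currently builds a 2x2 mass matrix, so pos_dim = vel_dim = 2):

import jax
import jax.numpy as jnp
from lnn.model import LagrangianNN

key = jax.random.PRNGKey(0)
model = LagrangianNN(pos_dim=2, vel_dim=2, hidden_dim=64, n_hidden=2, param_dim=4, key=key)

q   = jnp.array([0.1, -0.2])           # generalized coordinates [q1, q2]
q_t = jnp.array([0.0, 0.0])            # generalized velocities
p   = jnp.array([1.0, 1.0, 1.0, 1.0])  # normalized system parameters, e.g. [m1, m2, l1, l2]

q_tt = model(q, q_t, p)                # generalized accelerations, shape (2,)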

Source code in src/lnn/model.py
class LagrangianNN(eqx.Module):
    """
    Neural network model for learning a structured Lagrangian representation of a
    physical system.

    The model consists of three core modules:
    1. kinetic_net - an MLP that maps trigonometric features of the generalized
    coordinates (q) to a Cholesky-factorised mass matrix representation. This
    branch does not take system parameters as direct inputs; instead, it is
    conditioned on normalized system parameters through FiLM modulation.

    2. potential_net - an MLP that takes the same trigonometric features together
    with the normalized system parameters (p) and outputs a scalar normalized
    potential energy.

    3. film_net - a Feature-wise Linear Modulation (FiLM) network that generates
    per-layer scaling (gamma) and shifting (beta) parameters for the hidden layers
    of kinetic_net. These parameters are conditioned on the normalized system
    parameters (e.g. masses, lengths) and thus allow the normalized kinetic energy
    branch to adapt across different system configurations.

    The `__call__` method of this class uses automatic differentiation (jax.grad, jax.jacobian)
    on the learned normalized Lagrangian to derive the system's equations of motion,
    returning the generalized accelerations (q_tt).
    """
    kinetic_net:    eqx.nn.MLP
    potential_net:  eqx.nn.MLP
    film_net:       eqx.nn.MLP
    n_hidden:       int
    hidden_dim:     int

    def __init__(self,
                 pos_dim: int,
                 vel_dim: int,
                 hidden_dim: int, 
                 n_hidden: int, 
                 param_dim: int,
                 key: jnp.array, 
                 **kwds):
        super().__init__(**kwds)
        """Initializes the Lagrangian Neural Network.

        Args:
            pos_dim (int): Dimensionality of the generalized coordinates (q).
            vel_dim (int): Dimensionality of the generalized velocities (q_dot).
            param_dim (int): Dimensionality of the system parameters (p),
                             e.g., masses and lengths.
            hidden_dim (int): Width of the hidden layers in `kinetic_net` and `potential_net`.
            n_hidden (int): Number of hidden layers in `kinetic_net` and `potential_net`.
            key (jnp.array): JAX PRNGKey for initializing model weights.
            **kwds: Additional keyword arguments for Equinox.
        """
        self.hidden_dim = hidden_dim
        self.n_hidden = n_hidden
        trig_dim = pos_dim * 2

        kinetic_key, potential_key, film_key = jax.random.split(key, 3)

        # kinetic net
        output_dim_kin = vel_dim * (vel_dim+1)//2

        self.kinetic_net = eqx.nn.MLP(
            in_size=trig_dim, 
            out_size=output_dim_kin,
            width_size=hidden_dim,
            depth=n_hidden,
            activation=lambda x: x,    # identity for now; apply softplus later.
            key=kinetic_key,
        )

        # potential net
        self.potential_net = eqx.nn.MLP(
            in_size=trig_dim + param_dim,
            out_size=1,
            width_size=hidden_dim,
            depth=n_hidden,
            activation=jax.nn.tanh,
            key=potential_key
        )

        # Feature-wise Linear Modulation (FiLM)
        # ====================================
        # Outputs two FiLM parameters per hidden layer to condition the kinetic branch
        # on the normalized pendulum parameters.
        self.film_net = eqx.nn.MLP(
            in_size=param_dim, 
            out_size = 2 * n_hidden,
            width_size=32,
            depth=2,
            activation=jax.nn.softplus,
            key=film_key
        )

        # Initialize FiLM net to identity such that at first, the kinetic energy is not modulated at all
        # identity_bias  = jnp.tile(jnp.array([1.0, 0.]), n_hidden)
        # model = eqx.tree_at(lambda m: m.film_net.layers[-1].bias, self, identity_bias)


    def apply_film(self, h, film_params, net):
        """
        Runs a network while applying FiLM modulation to each hidden layer.

        Args:
            h: Input features for the network.
            film_params: Per-layer FiLM parameters of shape (n_hidden, 2), storing
                [gamma, beta] for each hidden layer.
            net: Network whose hidden activations are modulated.
        """
        for i in range(self.n_hidden):
            # Compute layer transformation
            h = net.layers[i](h)
            h = jax.nn.softplus(h)

            # FiLM scaling
            gamma = film_params[i, 0]
            beta = film_params[i, 1]

            h = gamma * h + beta
        return h

    def compute_cholesky_entries(self, q: jnp.Array, film_params: jnp.Array) -> jnp.Array:
        h = self.apply_film(q, film_params, self.kinetic_net)
        return self.kinetic_net.layers[self.n_hidden](h)

    def compute_potential(self, q: jnp.Array, p: jnp.Array) -> jnp.Array:
        h_pot = jnp.concatenate([q, p])
        return jnp.squeeze(self.potential_net(h_pot))

    def compute_lagrangian(self, q: jax.Array, q_t: jax.Array, p: jax.Array) -> jax.Array:    

        # Transform angles to trigonometric features to avoid discontinuities at
        # angle wraparound. Angles are kept in their original scale, while
        # velocities and parameters are normalized elsewhere for optimization
        # stability.
        trig_q = []
        for qi in q:
            trig_q.extend([jnp.sin(qi), jnp.cos(qi)])
        trig_q = jnp.array(trig_q)

        # Compute FiLM parameters from the normalized system parameters.
        # Reshape to (n_hidden, 2) where each row is [gamma_i, beta_i].
        film_params = self.film_net(p).reshape(self.n_hidden, 2)

        # Compute the normalized kinetic energy.
        chol_entries = self.compute_cholesky_entries(trig_q, film_params)

        # Build the positive-definite matrix used in the normalized kinetic energy term.
        L = jnp.array([
            [jax.nn.softplus(chol_entries[0]),                               0.0],
            [chol_entries[1],                   jax.nn.softplus(chol_entries[2])]
        ])

        M = L.T @ L + jnp.eye(2) * 1e-6
        T = 0.5 * q_t @ M @ q_t

        # Compute the normalized potential energy.
        V = self.compute_potential(trig_q, p)

        # Return the learned structured Lagrangian in normalized coordinates.
        return T - V

    def __call__(self, q: jax.Array, q_t: jax.Array, p: jax.Array) -> jax.Array:
        """Derives and returns the generalized accelerations (q_tt) from the
        learned normalized Lagrangian.

        This method applies Euler-Lagrange equations via automatic differentiation
        to compute the accelerations:
        M(q) * q_tt = f(q, q_t)
        where M = d^2L / (dq_t dq_t), i.e. the Hessian of the normalized kinetic
        energy term with respect to generalized velocities
        and f = dL/dq - d/dt(dL/dq_t) = dL/dq - (dL/dq_t dq)q_t - (dL/dq_t dq_t)q_tt (effectively)
        Rearranging and solving for q_tt.

        Args:
            q (jax.Array): Generalized coordinates, shape (pos_dim,).
            q_t (jax.Array): Generalized velocities, shape (vel_dim,).
            p (jax.Array): System parameters, shape (param_dim,).

        Returns:
            jax.Array: Generalized accelerations (q_tt), shape (vel_dim,).
        """
        lagrangian_fn = lambda _q, _qt: self.compute_lagrangian(_q, _qt, p)

        # 1. Compute dL/dq
        l_q = jax.grad(lagrangian_fn, argnums=0)(q, q_t)

        # 2. dL/dqt and its derivatives
        l_qt_fn = jax.grad(lagrangian_fn, argnums=1)        # get function dL/dqt
        l_qt_q = jax.jacobian(l_qt_fn, argnums=0)(q, q_t)   # l_qt_q = d^2L / (dqt dq), shape (vel_dim,vel_dim)
        l_qt_qt = jax.jacobian(l_qt_fn, argnums=1)(q, q_t)  # l_qt_qt = d^2L / (dqt dqt)  <-- The Mass Matrix, shape (vel_dim,vel_dim)

        # 3. Solve (L_qt_qt) * q_tt = L_q - (L_qt_q) * q_t
        l_qt_qt = l_qt_qt + jnp.eye(2) * 1e-6
        rhs     = l_q - l_qt_q @ q_t
        q_tt    = jnp.linalg.solve(l_qt_qt, rhs)    # acceleration
        return q_tt

__call__(q, q_t, p)

Derives and returns the generalized accelerations (q_tt) from the learned normalized Lagrangian.

This method applies the Euler-Lagrange equations via automatic differentiation to compute the accelerations. It solves the linear system M(q) * q_tt = f(q, q_t), where M = d^2L / (dq_t dq_t) is the Hessian of the normalized Lagrangian (equivalently, of its kinetic term) with respect to the generalized velocities, and f = dL/dq - (d^2L / (dq_t dq)) * q_t collects the remaining Euler-Lagrange terms. Solving this system yields q_tt.
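
Spelled out, the rearrangement follows from expanding the total time derivative in the Euler-Lagrange equation:

$$
\frac{d}{dt}\frac{\partial L}{\partial \dot q} - \frac{\partial L}{\partial q} = 0
\quad\Longrightarrow\quad
\frac{\partial^2 L}{\partial \dot q\,\partial \dot q}\,\ddot q
= \frac{\partial L}{\partial q} - \frac{\partial^2 L}{\partial \dot q\,\partial q}\,\dot q .
$$

The code solves this linear system for q_tt after adding a small diagonal term (1e-6 * I) to the velocity Hessian for numerical stability.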

Parameters:

Name Type Description Default
q Array

Generalized coordinates, shape (pos_dim,).

required
q_t Array

Generalized velocities, shape (vel_dim,).

required
p Array

System parameters, shape (param_dim,).

required

Returns:

Type Description
Array

jax.Array: Generalized accelerations (q_tt), shape (vel_dim,).

Source code in src/lnn/model.py
def __call__(self, q: jax.Array, q_t: jax.Array, p: jax.Array) -> jax.Array:
    """Derives and returns the generalized accelerations (q_tt) from the
    learned normalized Lagrangian.

    This method applies Euler-Lagrange equations via automatic differentiation
    to compute the accelerations:
    M(q) * q_tt = f(q, q_t)
    where M = d^2L / (dq_t dq_t), i.e. the Hessian of the normalized kinetic
    energy term with respect to generalized velocities
    and f = dL/dq - d/dt(dL/dq_t) = dL/dq - (dL/dq_t dq)q_t - (dL/dq_t dq_t)q_tt (effectively)
    Rearranging and solving for q_tt.

    Args:
        q (jax.Array): Generalized coordinates, shape (pos_dim,).
        q_t (jax.Array): Generalized velocities, shape (vel_dim,).
        p (jax.Array): System parameters, shape (param_dim,).

    Returns:
        jax.Array: Generalized accelerations (q_tt), shape (vel_dim,).
    """
    lagrangian_fn = lambda _q, _qt: self.compute_lagrangian(_q, _qt, p)

    # 1. Compute dL/dq
    l_q = jax.grad(lagrangian_fn, argnums=0)(q, q_t)

    # 2. dL/dqt and its derivatives
    l_qt_fn = jax.grad(lagrangian_fn, argnums=1)        # get function dL/dqt
    l_qt_q = jax.jacobian(l_qt_fn, argnums=0)(q, q_t)   # l_qt_q = d^2L / (dqt dq), shape (vel_dim,vel_dim)
    l_qt_qt = jax.jacobian(l_qt_fn, argnums=1)(q, q_t)  # l_qt_qt = d^2L / (dqt dqt)  <-- The Mass Matrix, shape (vel_dim,vel_dim)

    # 3. Solve (L_qt_qt) * q_tt = L_q - (L_qt_q) * q_t
    l_qt_qt = l_qt_qt + jnp.eye(2) * 1e-6
    rhs     = l_q - l_qt_q @ q_t
    q_tt    = jnp.linalg.solve(l_qt_qt, rhs)    # acceleration
    return q_tt

apply_film(h, film_params, net)

Runs a network while applying FiLM modulation to each hidden layer.

Parameters:

Name Type Description Default
h

Input features for the network.

required
film_params

Per-layer FiLM parameters of shape (n_hidden, 2), storing [gamma, beta] for each hidden layer.

required
net

Network whose hidden activations are modulated.

required
Source code in src/lnn/model.py
def apply_film(self, h, film_params, net):
    """
    Runs a network while applying FiLM modulation to each hidden layer.

    Args:
        h: Input features for the network.
        film_params: Per-layer FiLM parameters of shape (n_hidden, 2), storing
            [gamma, beta] for each hidden layer.
        net: Network whose hidden activations are modulated.
    """
    for i in range(self.n_hidden):
        # Compute layer transformation
        h = net.layers[i](h)
        h = jax.nn.softplus(h)

        # FiLM scaling
        gamma = film_params[i, 0]
        beta = film_params[i, 1]

        h = gamma * h + beta
    return h

Analytical System And Dataset Utilities

DoublePendulum

Bases: Module

Double pendulum code assuming the following state variables:

x = [t1, t2, w1, w2]
  t1: rad, angle of pendulum 1 from downward vertical
  t2: rad, angle of pendulum 2 from downward vertical
  w1: rad/s, angular velocity of pendulum 1
  w2: rad/s, angular velocity of pendulum 2

analytical_state_transition returns the state vector derivative d/dt x = [w1, w2, g1, g2].
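
A short usage sketch (parameter values are illustrative; the import path assumes the src/ layout is importable):

import jax.numpy as jnp
from data.doublependulum import DoublePendulum

dp = DoublePendulum(m1=1.0, m2=1.0, l1=1.0, l2=1.0)

state = jnp.array([0.5, -0.3, 0.0, 0.0])            # [t1, t2, w1, w2]
dxdt  = dp.analytical_state_transition(state, 0.0)  # [w1, w2, g1, g2]

q, q_dot = state[:2], state[2:]
energy = dp.hamiltonian_fn(q, q_dot)                # total energy T + V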

Source code in src/data/doublependulum.py
class DoublePendulum(eqx.Module):
    """
    Double pendulum code assuming the following state variables:

    x = [t1, t2, w1, w2]
    t1: rad, angle of pendulum 1 from downward vertical
    t2: rad, angle of pendulum 2 from downward vertical
    w1: rad/s, angular velocity of pendulum 1
    w2: rad/s, angular velocity of pendulum 2

    analytical state transition returns the state vector derivative d/dt x: [w1, w2, g1, g2]
    """

    m1: float = 1.0
    m2: float = 1.0
    l1: float = 1.0
    l2: float = 1.0
    g: float = GRAVITY

    @jit
    def kinetic_energy(self, q, q_dot):
        (t1, t2), (w1, w2) = q, q_dot
        T1 = 0.5 * self.m1 * (self.l1 * w1)**2
        T2 = 0.5 * self.m2 * ((self.l1 * w1)**2 + (self.l2 * w2)**2 + 2 * self.l1 * self.l2 * w1 * w2 * jnp.cos(t1 - t2))
        T  = T1 + T2
        return T

    @jit
    def potential_energy(self, q):
        if len(q) == 2:
            (t1, t2) = q
        else:
            t1, t2 = q[:, 0], q[:, 1]
        y1 = - self.l1 * jnp.cos(t1)
        y2 = y1 - self.l2 * jnp.cos(t2)
        V = self.m1 * self.g * y1 + self.m2 * self.g * y2
        return V

    def lagrangian_fn(self, q, q_dot):
        T = self.kinetic_energy(q, q_dot)
        V = self.potential_energy(q)
        return T - V

    def hamiltonian_fn(self, q, q_dot):
        T = self.kinetic_energy(q, q_dot)
        V = self.potential_energy(q)
        return T + V

    def to_cartesian(self, q: jax.Array):
        """Convert angles to Cartesian coordinates."""
        q1, q2 = q
        x1 = self.l1 * jnp.sin(q1)
        y1 = -self.l1 * jnp.cos(q1)
        x2 = x1 + self.l2 * jnp.sin(q2)
        y2 = y1 - self.l2 * jnp.cos(q2)
        return x1, y1, x2, y2

    @staticmethod
    def is_low_energy(q, q_dot, m1, m2, l1, l2, g=9.81):
        t1, t2 = q
        w1, w2 = q_dot

        # PE at unstable equilibrium (both up)
        V_max = (m1 + m2) * g * l1 + m2 * g * l2

        # Total energy at initial condition
        T = 0.5 * m1 * (l1*w1)**2 + 0.5 * m2 * ((l1*w1)**2 + (l2*w2)**2 + 2*l1*l2*w1*w2*jnp.cos(t1-t2))
        V = -(m1 + m2) * g * l1 * jnp.cos(t1) - m2 * g * l2 * jnp.cos(t2)
        H = T + V

        return H < V_max

    @jit
    def analytical_state_transition(self, full_state, t):
        """
        Note: the shared denominator is 1 - a1 * a2 = 1 - (m2 / (m1 + m2)) * cos^2(t1 - t2); it is closest to zero when cos^2(t1 - t2) = 1 and stays strictly positive for m1 > 0.
        """
        t1, t2, w1, w2 = full_state

        a1 = (self.l2 / self.l1) * (self.m2 / (self.m1 + self.m2)) * jnp.cos(t1 - t2)
        a2 = (self.l1 / self.l2) * jnp.cos(t1 - t2)

        f1 = -(self.l2 / self.l1) * (self.m2 / (self.m1 + self.m2)) * (w2**2) * jnp.sin(t1 - t2) - (self.g / self.l1) * jnp.sin(t1)
        f2 = (self.l1 / self.l2) * (w1**2) * jnp.sin(t1 - t2) - (self.g / self.l2) * jnp.sin(t2)

        g1 = (f1 - a1 * f2) / (1 - a1 * a2)
        g2 = (f2 - a2 * f1) / (1 - a1 * a2)
        return jnp.stack([w1, w2, g1, g2])

analytical_state_transition(full_state, t)

Note: the shared denominator is 1 - a1 * a2 = 1 - (m2 / (m1 + m2)) * cos²(t1 - t2); it is closest to zero when cos²(t1 - t2) = 1 (the pendula aligned) and stays strictly positive for m1 > 0.

Source code in src/data/doublependulum.py
@jit
def analytical_state_transition(self, full_state, t):
    """
    Note: the shared denominator is 1 - a1 * a2 = 1 - (m2 / (m1 + m2)) * cos^2(t1 - t2); it is closest to zero when cos^2(t1 - t2) = 1 and stays strictly positive for m1 > 0.
    """
    t1, t2, w1, w2 = full_state

    a1 = (self.l2 / self.l1) * (self.m2 / (self.m1 + self.m2)) * jnp.cos(t1 - t2)
    a2 = (self.l1 / self.l2) * jnp.cos(t1 - t2)

    f1 = -(self.l2 / self.l1) * (self.m2 / (self.m1 + self.m2)) * (w2**2) * jnp.sin(t1 - t2) - (self.g / self.l1) * jnp.sin(t1)
    f2 = (self.l1 / self.l2) * (w1**2) * jnp.sin(t1 - t2) - (self.g / self.l2) * jnp.sin(t2)

    g1 = (f1 - a1 * f2) / (1 - a1 * a2)
    g2 = (f2 - a2 * f1) / (1 - a1 * a2)
    return jnp.stack([w1, w2, g1, g2])

to_cartesian(q)

Convert angles to Cartesian coordinates.

Source code in src/data/doublependulum.py
def to_cartesian(self, q: jax.Array):
    """Convert angles to Cartesian coordinates."""
    q1, q2 = q
    x1 = self.l1 * jnp.sin(q1)
    y1 = -self.l1 * jnp.cos(q1)
    x2 = x1 + self.l2 * jnp.sin(q2)
    y2 = y1 - self.l2 * jnp.cos(q2)
    return x1, y1, x2, y2

load_list_of_arrays_from_h5(system='doublependulum', filename='trajectories.h5')

Loads a list of 2D NumPy arrays from a specified HDF5 file.

The function constructs the file path relative to the project's root data directory. It expects the arrays to be stored within an HDF5 group named 'trajectories', with individual datasets named systematically (e.g., 'trajectory_000').

Parameters:

Name Type Description Default
system str

The name of the subdirectory within the project's 'data/' folder where the HDF5 file is located. Defaults to 'doublependulum'.

'doublependulum'
filename str

The name of the HDF5 file to load. Defaults to 'trajectories.h5'.

'trajectories.h5'

Returns:

Type Description
List[array]

List[np.array]: A list of 2D NumPy arrays loaded from the HDF5 file. Returns an empty list if the specified file does not exist.

Source code in src/data/utils.py
def load_list_of_arrays_from_h5(system: str = 'doublependulum', filename: str = 'trajectories.h5') -> List[np.array]:
    """Loads a list of 2D NumPy arrays from a specified HDF5 file.

    The function constructs the file path relative to the project's root
    data directory. It expects the arrays to be stored within an HDF5 group
    named 'trajectories', with individual datasets named systematically (e.g., 'trajectory_000').

    Args:
        system (str, optional): The name of the subdirectory within the project's
            'data/' folder where the HDF5 file is located. Defaults to 'doublependulum'.
        filename (str, optional): The name of the HDF5 file to load. Defaults to 'trajectories.h5'.

    Returns:
        List[np.array]: A list of 2D NumPy arrays loaded from the HDF5 file.
            Returns an empty list if the specified file does not exist.
    """
    full_filepath = os.path.join(get_project_data_path(system), filename)

    loaded_trajs = []
    loaded_params = []
    if not os.path.exists(full_filepath):
        print(f"Error: File not found at {full_filepath}")
        return []

    with h5py.File(full_filepath, 'r') as f:
        group = f['trajectories']
        for key in sorted(group.keys()):
            if key.startswith('trajectory'):
                loaded_trajs.append(group[key][()])
            elif key.startswith('param'):
                loaded_params.append(group[key][()])

    print(f"Loaded {len(loaded_trajs)} trajectories from {full_filepath}")
    return loaded_trajs, loaded_params

save_list_of_arrays_to_h5(list_of_arrays, system='doublependulum', filename='list_of_trajectories.h5')

Saves a list of 2D NumPy or JAX arrays into a single compressed HDF5 file.

Each array in the list is stored as a separate dataset within an HDF5 group named 'trajectories'. The individual datasets are named sequentially (e.g., 'trajectory_000', 'trajectory_001', etc.). The function ensures that the target directory structure exists before writing the file.

Parameters:

Name Type Description Default
list_of_arrays List[Union[ndarray, ndarray]]

A list of 2D arrays (either NumPy arrays or JAX arrays) to be saved.

required
system str

The name of the subdirectory within the project's 'data/' folder where the HDF5 file should be stored. Defaults to 'doublependulum'.

'doublependulum'
filename str

The name of the HDF5 file to create. Defaults to 'list_of_trajectories.h5'.

'list_of_trajectories.h5'
Side Effects
  • Creates the target directory (e.g., 'project_root/data/doublependulum/') if it does not already exist.
  • Creates an HDF5 file at the specified path.
  • Prints messages indicating directory creation and successful file saving.
Source code in src/data/utils.py
def save_list_of_arrays_to_h5(list_of_arrays: List[np.array], system: str = 'doublependulum', filename: str = 'list_of_trajectories.h5') -> None:
    """Saves a list of 2D NumPy or JAX arrays into a single compressed HDF5 file.

    Each array in the list is stored as a separate dataset within an HDF5 group
    named 'trajectories'. The individual datasets are named sequentially
    (e.g., 'trajectory_000', 'trajectory_001', etc.). The function ensures that
    the target directory structure exists before writing the file.

    Args:
        list_of_arrays (List[Union[np.ndarray, jnp.ndarray]]): A list of 2D arrays
            (either NumPy arrays or JAX arrays) to be saved.
        system (str, optional): The name of the subdirectory within the project's
            'data/' folder where the HDF5 file should be stored. Defaults to 'doublependulum'.
        filename (str, optional): The name of the HDF5 file to create.
            Defaults to 'list_of_trajectories.h5'.

    Side Effects:
        - Creates the target directory (e.g., 'project_root/data/doublependulum/') if it does not already exist.
        - Creates an HDF5 file at the specified path.
        - Prints messages indicating directory creation and successful file saving.
    """
    full_path = os.path.join(get_project_data_path(system), filename)

    output_dir = os.path.dirname(full_path)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created directory: {output_dir}")

    with h5py.File(full_path, 'w') as f:
        # Create a group to better organize your arrays, if desired
        # If you have different types of lists, you could create multiple groups
        group = f.create_group('trajectories') 

        for i, arr in enumerate(list_of_arrays):
            # Convert JAX array to NumPy array for saving if it's not already
            np_arr = np.asarray(arr)
            # Create a dataset for each array within the group
            # Use gzip compression, which is generally a good balance.
            # You can also use 'lzf' for faster (de)compression, sometimes at lower ratio.
            group.create_dataset(f'trajectory_{i:03d}', data=np_arr, compression="gzip", compression_opts=9)
    print(f"List of arrays saved to {full_path}")

get_project_data_path(sub_path='')

Constructs an absolute path to a location within the project's /data directory. Assumes this script is run from within the project structure (e.g., src/data/).

Source code in src/data/utils.py
def get_project_data_path(sub_path=''):
    """
    Constructs an absolute path to a location within the project's /data directory.
    Assumes this script is run from within the project structure (e.g., src/data/).
    """
    # Get the directory of the current script (e.g., 'path/to/project_root/src/data')
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Navigate up to the project root (e.g., 'path/to/project_root')
    # Assuming script is in src/data/, need to go up two levels
    project_root = os.path.join(script_dir, '..', '..')

    # Construct the path to the desired data subdirectory
    full_data_path = os.path.join(project_root, 'data', sub_path)

    return full_data_path

Training And Simulation Helpers

build_input_output(datasets, params, dt)

Preprocesses raw trajectory data into model inputs (X) and targets (dXdt).

The input X is augmented with system parameters. The target dXdt is computed via numerical differentiation of velocities.

Parameters:

Name Type Description Default
datasets List[Array]

List of raw trajectory data, each of shape (time_steps, 5) where columns are [time, q1, q2, w1, w2].

required
params List[Array]

List of system parameters, each of shape (4,) [m1, m2, l1, l2].

required
dt float

Time step size, used for numerical differentiation.

required

Returns:

Type Description
Tuple[Array, Array]

Tuple[jax.Array, jax.Array]: A tuple containing:
  - X (jax.Array): Concatenated input states, shape (num_trajectories, time_steps, features). Features order: [q1, q2, w1, w2, m1, m2, l1, l2].
  - dXdt (jax.Array): Computed accelerations, shape (num_trajectories, time_steps, 2).
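
A small shape sketch with dummy data (the import path assumes the src/ layout is importable):

import numpy as np
from train_utils import build_input_output

# Four dummy trajectories with columns [time, q1, q2, w1, w2]
datasets = [np.random.rand(1000, 5) for _ in range(4)]
params   = [np.array([1.0, 1.0, 1.0, 1.0]) for _ in range(4)]  # [m1, m2, l1, l2]

X, dXdt = build_input_output(datasets, params, dt=0.01)
# X:    (4, 1000, 8) with features [q1, q2, w1, w2, m1, m2, l1, l2]
# dXdt: (4, 1000, 2) numerically differentiated [w1, w2]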

Source code in src/train_utils.py
def build_input_output(datasets: List[jax.Array], params: List[jax.Array], dt: float) -> Tuple[jax.Array, jax.Array]:
    """Preprocesses raw trajectory data into model inputs (X) and targets (dXdt).

    The input X is augmented with system parameters.
    The target dXdt is computed via numerical differentiation of velocities.

    Args:
        datasets (List[jnp.Array]): List of raw trajectory data, each of shape (time_steps, 5)
                                    where columns are [time, q1, q2, w1, w2].
        params (List[jnp.Array]): List of system parameters, each of shape (4,)
                                  [m1, m2, l1, l2].
        dt (float): Time step size, used for numerical differentiation.

    Returns:
        Tuple[jax.Array, jax.Array]: A tuple containing:
            - X (jax.Array): Concatenated input states, shape (num_trajectories, time_steps, features).
                             Features order: [q1, q2, w1, w2, m1, m2, l1, l2].
            - dXdt (jax.Array): Computed accelerations, shape (num_trajectories, time_steps, 2).
    """
    X_np, dXdt_np = [], []
    for i, traj in enumerate(datasets):
        # Keep preprocessing in NumPy here: np.gradient is robust for the offline
        # derivative estimate used to build supervision targets.
        x_np = np.asarray(traj[:, 1:])
        xdot_np = np.gradient(x_np[:, 2:], dt, axis=0, edge_order=2)

        # Tile physical parameters to match the trajectory length before converting
        # the assembled arrays to JAX.
        p_np = np.asarray(params[i])
        p_tiled_np = np.tile(p_np, (x_np.shape[0], 1))

        # Augmented state vector: [q1, q2, w1, w2, m1, m2, l1, l2]
        x_aug_np = np.concatenate([x_np, p_tiled_np], axis=1)

        X_np.append(x_aug_np)
        dXdt_np.append(xdot_np)

    # Convert once at the end so downstream training code can stay in JAX.
    X = jnp.asarray(np.stack(X_np))
    dXdt = jnp.asarray(np.stack(dXdt_np))
    return X, dXdt

build_temporal_batch(x, y, batch_size, temporal_chunk_len, step_key)

Builds a batch of data by sampling random temporal chunks from random trajectories.

Parameters:

Name Type Description Default
x Array

Input data, shape (num_trajectories, time_steps, features).

required
y Array

Target data, shape (num_trajectories, time_steps, output_dim).

required
batch_size int

Number of trajectories to sample for this batch.

required
temporal_chunk_len int

Length of the time chunk to extract from each sampled trajectory.

required
step_key Array

JAX PRNGKey for random sampling.

required

Returns:

Type Description
Tuple[Array, Array]

Tuple[jax.Array, jax.Array]: A batch of inputs (x_batch) and targets (y_batch), concatenated along the first axis. Shapes: (batch_size * temporal_chunk_len, features) and (batch_size * temporal_chunk_len, output_dim).
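
For example, with dummy arrays shaped like the output of build_input_output (the import path assumes the src/ layout is importable):

import jax
import jax.numpy as jnp
from train_utils import build_temporal_batch

X    = jnp.zeros((10, 1000, 8))  # (num_trajectories, time_steps, features)
dXdt = jnp.zeros((10, 1000, 2))  # (num_trajectories, time_steps, output_dim)

key = jax.random.PRNGKey(0)
x_batch, y_batch = build_temporal_batch(X, dXdt, batch_size=16, temporal_chunk_len=32, step_key=key)
# x_batch: (16 * 32, 8), y_batch: (16 * 32, 2)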

Source code in src/train_utils.py
def build_temporal_batch(x: jax.Array,
                         y: jax.Array,
                         batch_size: int,
                         temporal_chunk_len: int,
                         step_key) -> Tuple[jax.Array, jax.Array]:
    """Builds a batch of data by sampling random temporal chunks from random trajectories.

    Args:
        x (jax.Array): Input data, shape (num_trajectories, time_steps, features).
        y (jax.Array): Target data, shape (num_trajectories, time_steps, output_dim).
        batch_size (int): Number of trajectories to sample for this batch.
        temporal_chunk_len (int): Length of the time chunk to extract from each sampled trajectory.
        step_key (jax.Array): JAX PRNGKey for random sampling.

    Returns:
        Tuple[jax.Array, jax.Array]: A batch of inputs (x_batch) and targets (y_batch),
                                     concatenated along the first axis.
                                     Shapes: (batch_size * temporal_chunk_len, features)
                                     and (batch_size * temporal_chunk_len, output_dim).
    """
    # Split the key so trajectory sampling and time-index sampling use independent randomness
    traj_key, time_key = jax.random.split(step_key)

    # Sample 'batch_size' trajectory indices
    sample_idxs = jax.random.randint(traj_key, (batch_size,), 0, len(x))

    # Sample 'batch_size' starting time indices for chunks
    time_idxs   = jax.random.randint(time_key, (batch_size,), 0, x.shape[1] - temporal_chunk_len)

    # Extract chunks. Using a list comprehension and then concatenate is JAX-compatible.
    x_chunks = [x[sample_idxs[i]][time_idxs[i] : time_idxs[i] + temporal_chunk_len] for i in range(len(sample_idxs))]
    y_chunks = [y[sample_idxs[i]][time_idxs[i] : time_idxs[i] + temporal_chunk_len] for i in range(len(sample_idxs))]

    x_batch = jnp.concatenate(x_chunks, axis=0)  # [K*T, 8]
    y_batch = jnp.concatenate(y_chunks, axis=0)  # [K*T, 2]    
    return x_batch, y_batch
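
A small usage sketch (X and dXdt as produced by build_input_output above; batch_size and temporal_chunk_len are illustrative):

import jax

key = jax.random.PRNGKey(0)
x_batch, y_batch = build_temporal_batch(X, dXdt, batch_size=4, temporal_chunk_len=16, step_key=key)
print(x_batch.shape)  # (4 * 16, 8)
print(y_batch.shape)  # (4 * 16, 2)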

compute_V_stats(datasets, params, idx_train)

Computes mean and standard deviation of potential energy for the training set.

Parameters:

- datasets (List[Array], required): List of all trajectory datasets.
- params (List[Array], required): List of all system parameters corresponding to datasets.
- idx_train (ndarray, required): Indices of trajectories belonging to the training set.

Returns:

- Tuple[jax.Array, jax.Array]: Mean and standard deviation of potential energy.

Source code in src/train_utils.py
def compute_V_stats(datasets: List[jax.Array], params: List[jax.Array], idx_train: np.ndarray) -> Tuple[jax.Array, jax.Array]:
    """Computes mean and standard deviation of potential energy for the training set.

    Args:
        datasets (List[jax.Array]): List of all trajectory datasets.
        params (List[jax.Array]): List of all system parameters corresponding to datasets.
        idx_train (np.ndarray): Indices of trajectories belonging to the training set.

    Returns:
        Tuple[jax.Array, jax.Array]: Mean and standard deviation of potential energy.
    """
    V_all = []
    for i in idx_train:
        traj = datasets[i]
        p = params[i]
        #  Instantiate DoublePendulum for each parameter set to compute potential energy
        dp = DoublePendulum(m1=p[0], m2=p[1], l1=p[2], l2=p[3])
        # traj[:, 1:3] contains q1, q2
        V_all.append(dp.potential_energy(traj[:, 1:3]))

    V_all = jnp.stack(V_all).reshape((-1,))
    return jnp.mean(V_all), jnp.std(V_all)
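
A brief usage sketch, reusing the synthetic datasets and params from the build_input_output example (idx_train would normally come from train_test_split, documented further below):

import numpy as np

idx_train = np.array([0, 1])
V_mean, V_std = compute_V_stats(datasets, params, idx_train)
print(float(V_mean), float(V_std))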

load_model(model, fname='model')

Loads an Equinox model's leaves (trainable parameters) from a file.

Parameters:

- model (eqx.Module, required): An uninitialized Equinox model with the correct architecture to load the parameters into.
- fname (str, default 'model'): The base path/filename of the model to load. Assumes a '.eqx' extension.

Returns:

- eqx.Module: The model with loaded parameters.

Source code in src/train_utils.py
def load_model(model: eqx.Module, fname: str = "model") -> eqx.Module:
    """Loads an Equinox model's leaves (trainable parameters) from a file.

    Args:
        model (eqx.Module): An uninitialized Equinox model with the correct architecture
                            to load the parameters into.
        fname (str): The base path/filename of the model to load. Assumes a '.eqx' extension.

    Returns:
        eqx.Module: The model with loaded parameters.
    """
    model = eqx.tree_deserialise_leaves(str(fname)+".eqx", model)
    print(f"Model loaded from: {fname}")
    return model

normalize_data(Xtrain, Xval, Xtest, dXdt_train, dXdt_val, dXdt_test, len_params, normalize=True)

Normalizes the input (X) and target (dXdt) datasets based on training set statistics.

Parameters:

- Xtrain (Array, required): Input training data.
- Xval (Array, required): Input validation data.
- Xtest (Array, required): Input test data.
- dXdt_train (Array, required): Target training data (accelerations).
- dXdt_val (Array, required): Target validation data (accelerations).
- dXdt_test (Array, required): Target test data (accelerations).
- len_params (int, required): The number of unique parameter sets (trajectories) in the dataset. Used to handle cases where parameters might be constant across all train samples (std=0).
- normalize (bool, default True): If True, perform normalization. Otherwise, return data as is.

Returns:

- Tuple[jax.Array, ...]: Normalized (or original) X_train, X_val, X_test, dXdt_train, dXdt_val, dXdt_test, and a dictionary of norm_stats. The tuple order is: Xtrain_norm, Xval_norm, Xtest_norm, dXdt_train_norm, dXdt_val_norm, dXdt_test_norm, norm_stats.

Source code in src/train_utils.py
def normalize_data(Xtrain: jax.Array,
                   Xval: jax.Array,
                   Xtest: jax.Array,
                   dXdt_train: jax.Array,
                   dXdt_val: jax.Array,
                   dXdt_test: jax.Array,
                   len_params: int,
                   normalize: bool = True) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, Dict]:
    """Normalizes the input (X) and target (dXdt) datasets based on training set statistics.

    Args:
        Xtrain (jax.Array): Input training data.
        Xval (jax.Array): Input validation data.
        Xtest (jax.Array): Input test data.
        dXdt_train (jax.Array): Target training data (accelerations).
        dXdt_val (jax.Array): Target validation data (accelerations).
        dXdt_test (jax.Array): Target test data (accelerations).
        len_params (int): The number of unique parameter sets (trajectories)
                                          in the dataset. Used to handle cases where
                                          parameters might be constant across all train samples (std=0).
        normalize (bool, optional): If True, perform normalization. Otherwise, return data as is.
                                    Defaults to True.

    Returns:
        Tuple[jax.Array, ...]: Normalized (or original) X_train, X_val, X_test,
                               dXdt_train, dXdt_val, dXdt_test, and a dictionary of norm_stats.
                               The tuple order is: Xtrain_norm, Xval_norm, Xtest_norm,
                               dXdt_train_norm, dXdt_val_norm, dXdt_test_norm, norm_stats.
    """
    # Compute norm stats
    # ===================
    X_mean    = jnp.mean(Xtrain, axis=(0,1))    # average over multiple trajectories and full length of trajectory
    X_std     = jnp.std(Xtrain, axis=(0,1))     # average over multiple trajectories and full length of trajectory
    dXdt_mean = jnp.mean(dXdt_train, axis=(0,1))
    dXdt_std  = jnp.std(dXdt_train, axis=(0,1))

    # Leave the angle features (q1, q2) unnormalized: set mean to 0 and std to 1 so they pass through unchanged
    # --------------------------------------------------------
    X_mean = X_mean.at[0].set(0.0).at[1].set(0.0)
    X_std  = X_std.at[0].set(1.0).at[1].set(1.0)

    # Parameter normalization 
    # ========================
    # If training/testing on a single trajectory, do not normalize the parameters, since their std is 0.
    # This is fine because the parameter magnitudes are fairly small (around 1).
    if len_params == 1:
        for i in [4, 5, 6, 7]:
            X_mean = X_mean.at[i].set(0.0)
            X_std = X_std.at[i].set(1.0)

    # Normalization stats
    # =================
    norm_stats = {
        'X_mean': X_mean, 'X_std': X_std,
        'dXdt_mean': dXdt_mean, 'dXdt_std': dXdt_std
    }

    if normalize:
        # Input normalization
        Xtrain_norm = (Xtrain - X_mean) / X_std
        Xval_norm = (Xval - X_mean) / X_std
        Xtest_norm = (Xtest - X_mean) / X_std

        # Targets: accelerations — zero-mean normalize
        dXdt_train_norm = (dXdt_train - dXdt_mean) / dXdt_std
        dXdt_val_norm = (dXdt_val - dXdt_mean) / dXdt_std
        dXdt_test_norm = (dXdt_test - dXdt_mean) / dXdt_std

        return Xtrain_norm, Xval_norm, Xtest_norm, dXdt_train_norm, dXdt_val_norm, dXdt_test_norm, norm_stats
    else:
        return Xtrain, Xval, Xtest, dXdt_train, dXdt_val, dXdt_test, norm_stats
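
A usage sketch, assuming Xtrain/Xval/Xtest and the corresponding dXdt arrays were obtained by indexing X and dXdt with the split indices from train_test_split (documented further below):

(Xtrain_norm, Xval_norm, Xtest_norm,
 dXdt_train_norm, dXdt_val_norm, dXdt_test_norm,
 norm_stats) = normalize_data(Xtrain, Xval, Xtest,
                              dXdt_train, dXdt_val, dXdt_test,
                              len_params=len(params), normalize=True)
print(norm_stats['X_mean'].shape, norm_stats['X_std'].shape)  # (8,) (8,)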

run_diagnostics(model, Xtrain, Xtrain_norm, dXdt_train, dXdt_train_norm, norm_stats, params)

One-step acceleration check and Lagrangian structure diagnostics.

Source code in src/train_utils.py
def run_diagnostics(model: eqx.Module,
                    Xtrain: jax.Array,
                    Xtrain_norm: jax.Array,
                    dXdt_train: jax.Array,
                    dXdt_train_norm: jax.Array,
                    norm_stats: Dict,
                    params: List[jax.Array]):
    """One-step acceleration check and Lagrangian structure diagnostics."""

    state0_physical = Xtrain[0, 0, :4]
    state0_norm     = (state0_physical - norm_stats['X_mean'][:4]) / norm_stats['X_std'][:4]
    p_norm          = Xtrain_norm[0, 0, 4:]

    # model prediction
    q_tt_norm = model(state0_norm[:2], state0_norm[2:], p_norm)
    q_tt_phys = q_tt_norm * norm_stats['dXdt_std'] + norm_stats['dXdt_mean']
    print("q_tt_norm:", q_tt_norm)
    print("q_tt_phys:", q_tt_phys)

    # ground truth
    p      = Xtrain[0, 0, 4:]
    dp     = DoublePendulum(m1=p[0], m2=p[1], l1=p[2], l2=p[3])
    deriv  = dp.analytical_state_transition(Xtrain[0, 0, :4], 0.0)
    gt_norm = (deriv[2:] - norm_stats['dXdt_mean']) / norm_stats['dXdt_std']
    print("gt q_tt physical:", deriv[2:])
    print("gt q_tt normalized:", gt_norm)
    print("stored dXdt[0,0]:", dXdt_train[0, 0])

    # mass matrix
    lagrangian_fn = lambda _q, _qt: model.compute_lagrangian(_q, _qt, p_norm)
    l_qt_fn  = jax.grad(lagrangian_fn, argnums=1)
    l_qt_qt  = jax.jacobian(l_qt_fn, argnums=1)(state0_norm[:2], state0_norm[2:])
    l_q      = jax.grad(lagrangian_fn, argnums=0)(state0_norm[:2], state0_norm[2:])
    l_qt_q   = jax.jacobian(l_qt_fn, argnums=0)(state0_norm[:2], state0_norm[2:])
    rhs      = l_q - l_qt_q @ state0_norm[2:]

    print("l_qt_qt:\n", l_qt_qt)
    print("cond(l_qt_qt):", jnp.linalg.cond(l_qt_qt))
    print("eigenvalues:", jnp.linalg.eigvalsh(l_qt_qt))
    print("l_q:", l_q)
    print("rhs:", rhs)

    # physical M for reference
    M11_phys = (p[0] + p[1]) * p[2]**2
    M22_phys = p[1] * p[3]**2
    print(f"physical M11: {M11_phys:.3f}, M22: {M22_phys:.3f}")
    print(f"model    M11: {l_qt_qt[0,0]:.3f}, M22: {l_qt_qt[1,1]:.3f}")

save_model(model, fname)

Saves an Equinox model's leaves (trainable parameters) to a file.

Parameters:

- model (eqx.Module, required): The Equinox model to save.
- fname (Path, required): The base path/filename for the model. A '.eqx' extension will be added.

Source code in src/train_utils.py
def save_model(model: eqx.Module, fname: Path):
    """Saves an Equinox model's leaves (trainable parameters) to a file.

    Args:
        model (eqx.Module): The Equinox model to save.
        fname (Path): The base path/filename for the model. A '.eqx' extension will be added.
    """
    eqx.tree_serialise_leaves(str(fname)+'.eqx', model)
    print(f'Model saved: {fname}')
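
A minimal save/load round-trip sketch for these two helpers, using a small eqx.nn.MLP as a stand-in for the trained model (the filename is a placeholder):

import equinox as eqx
import jax

key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=8, out_size=2, width_size=32, depth=2, key=key)

save_model(mlp, "mlp_demo")  # writes mlp_demo.eqx
# load_model needs an instance with the same architecture; its leaves are then overwritten
restored = load_model(eqx.nn.MLP(in_size=8, out_size=2, width_size=32, depth=2, key=key), "mlp_demo")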

train_test_split(X, n_train=0.7, n_val=0.1, seed=42)

Splits the dataset into training, validation, and test sets.

Parameters:

- X (Array, required): The full dataset (e.g., input trajectories).
- n_train (float, default 0.7): Proportion of data to use for the training set.
- n_val (float, default 0.1): Proportion of data to use for the validation set.
- seed (int, default 42): Random seed for reproducibility.

Returns:

- Tuple[np.ndarray, np.ndarray, np.ndarray]: Indices for the train, validation, and test sets.

Source code in src/train_utils.py
def train_test_split(X: jax.Array, n_train: float = 0.7, n_val: float = 0.1, seed: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Splits the dataset into training, validation, and test sets.

    Args:
        X (jax.Array): The full dataset (e.g., input trajectories).
        n_train (float, optional): Proportion of data to use for the training set. Defaults to 0.7.
        n_val (float, optional): Proportion of data to use for the validation set. Defaults to 0.1.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: Indices for train, validation, and test sets.
    """

    total_samples = X.shape[0]
    rng_ = np.random.default_rng(seed)

    # Calculate sizes
    size_train = int(round(total_samples * n_train))
    size_val   = int(round(total_samples * n_val))
    size_test  = total_samples - size_train - size_val

    # Ensure sizes are non-negative
    size_train = max(0, size_train)
    size_val   = max(0, size_val)
    size_test  = max(0, size_test)
    if size_train + size_val + size_test == 0:
        raise ValueError("Calculated sizes lead to zero total samples. Check n_train, n_val, and X.shape[0].")

    # Get all indices
    all_indices = np.arange(total_samples)

    # Shuffle all indices for random split
    rng_.shuffle(all_indices)

    # Split into train, val, test based on shuffled indices
    idx_train = all_indices[:size_train]
    idx_val   = all_indices[size_train : size_train + size_val]
    idx_test  = all_indices[size_train + size_val : size_train + size_val + size_test]

    return idx_train, idx_val, idx_test
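
A usage sketch that ties the split indices back to the arrays from build_input_output:

idx_train, idx_val, idx_test = train_test_split(X, n_train=0.7, n_val=0.1, seed=42)
Xtrain, dXdt_train = X[idx_train], dXdt[idx_train]
Xval,   dXdt_val   = X[idx_val],   dXdt[idx_val]
Xtest,  dXdt_test  = X[idx_test],  dXdt[idx_test]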

save_rollout_data(save_dir, filename_prefix, times, gt_states, sim_states, params_phys, case_label='')

Saves ground truth and simulated trajectory data to a compressed .npz file.

Source code in src/simulate.py
def save_rollout_data(
    save_dir: Path,
    filename_prefix: str,
    times: jax.Array,
    gt_states: jax.Array,
    sim_states: jax.Array,
    params_phys: jax.Array,
    case_label: str = "" # e.g., "train_traj_0", "test_traj_0", "ood_case_0"
):
    """Saves ground truth and simulated trajectory data to a compressed .npz file."""
    full_filename = save_dir / f"{filename_prefix}_{case_label}.npz"

    # Convert JAX arrays to NumPy for saving, if they aren't already
    np.savez_compressed(
        full_filename,
        times=np.asarray(times),
        ground_truth=np.asarray(gt_states),
        simulated=np.asarray(sim_states),
        physical_parameters=np.asarray(params_phys)
    )
    print(f"✅ Rollout data saved to: {full_filename}")

energy_conservation_loss(model, x, split_size=2)

Calculates a loss term that penalizes drift in the model's normalized Hamiltonian within trajectory chunks.

Because the model is trained in normalized coordinates, this quantity should be interpreted as a structured normalized energy induced by the learned Lagrangian, rather than as the exact physical Hamiltonian in original units. The loss encourages temporal consistency of that learned quantity along each trajectory chunk.

Parameters:

- model (eqx.Module, required): The neural network model.
- x (Array, required): The input batch of state vectors containing generalized coordinates, generalized velocities, and normalized system parameters.
- split_size (int, default 2): The dimensionality of generalized coordinates. Defaults to 2 for 2D systems.

Returns:

- jax.Array: Variance of the model's normalized Hamiltonian across the current flattened batch chunk. In the present training setup this corresponds to a single trajectory chunk; for multi-trajectory batches this would need to be made trajectory-local explicitly.

Source code in src/losses.py
def energy_conservation_loss(model: eqx.Module, x: jax.Array, split_size: int = 2) -> jax.Array:
    """
    Calculates a loss term that penalizes drift in the model's normalized
    Hamiltonian within trajectory chunks.

    Because the model is trained in normalized coordinates, this quantity should
    be interpreted as a structured normalized energy induced by the learned
    Lagrangian, rather than as the exact physical Hamiltonian in original units.
    The loss encourages temporal consistency of that learned quantity along each
    trajectory chunk.

    Args:
        model (eqx.Module): The neural network model.
        x (jax.Array): The input batch of state vectors containing generalized
                       coordinates, generalized velocities, and normalized system
                       parameters.
        split_size (int, optional): The dimensionality of generalized coordinates.
                                    Defaults to 2 for 2D systems.

    Returns:
        jax.Array: Variance of the model's normalized Hamiltonian across the
                   current flattened batch chunk. In the present training setup
                   this corresponds to a single trajectory chunk; for
                   multi-trajectory batches this would need to be made
                   trajectory-local explicitly.
    """
    batch_q, batch_qt, batch_params = jnp.split(x, [split_size, split_size*2], axis=-1)

    def single_H(q: jax.Array, qt: jax.Array, p: jax.Array):
        """Computes the model's normalized Hamiltonian for a single timestep."""
        trig_q = jnp.array([jnp.sin(q[0]), jnp.cos(q[0]),
                            jnp.sin(q[1]), jnp.cos(q[1])])
        film_params = model.film_net(p).reshape(model.n_hidden, 2)
        chol = model.compute_cholesky_entries(trig_q, film_params)
        L = jnp.array([[jax.nn.softplus(chol[0]), 0.0],
                        [chol[1], jax.nn.softplus(chol[2])]])
        M = L.T @ L + jnp.eye(2) * 1e-6
        T = 0.5 * qt @ M @ qt
        V = model.compute_potential(trig_q, p)
        return T + V

    H = jax.vmap(single_H)(batch_q, batch_qt, batch_params)
    # NOTE:
    # This variance is computed over the full flattened batch. In the current
    # training setup, each batch is a single temporal chunk from one trajectory,
    # so this matches the intended notion of trajectory-local energy
    # consistency. If batching is later extended to multiple trajectories or
    # parameter settings at once, this should be updated to compute the variance
    # per trajectory chunk and then average across chunks.
    return jnp.var(H)  # should stay approximately constant along a trajectory chunk
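
As a sketch of how this penalty might be combined with a primary acceleration-fitting loss during training (the weighting lambda_energy and the data-fit term are illustrative, not the repository's exact objective):

import jax
import jax.numpy as jnp

def total_loss(model, x_batch, y_batch, lambda_energy=1e-2, split_size=2):
    # Data-fit term: mean squared error between predicted and target (normalized) accelerations
    q, qt, p = jnp.split(x_batch, [split_size, split_size * 2], axis=-1)
    q_tt_pred = jax.vmap(model)(q, qt, p)
    mse = jnp.mean((q_tt_pred - y_batch) ** 2)
    # Soft penalty on energy drift within the trajectory chunk
    return mse + lambda_energy * energy_conservation_loss(model, x_batch, split_size)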