API Reference

Full auto-generated API documentation for the diff_ensemble package.


Model

Bases: Module

Variational Autoencoder for generating protein structural ensembles.

The model encodes sequence features into a latent Gaussian distribution, draws ensemble_size samples, and decodes each sample into a set of backbone torsion angles. Coordinates are obtained by calling build_backbone_coords.

Parameters:

Name Type Description Default
seq_len int

Number of residues in the protein.

required
latent_dim int

Dimensionality of the latent space.

required
ensemble_size int

Number of conformations to sample per forward pass.

100
hidden_dim int

Width of the hidden layers in the encoder and decoder.

256
Source code in diff_ensemble/model.py
class EnsembleVAE(nn.Module):
    """Variational Autoencoder for generating protein structural ensembles.

    The model encodes sequence features into a latent Gaussian distribution,
    draws ``ensemble_size`` samples, and decodes each sample into a set of
    backbone torsion angles.  Coordinates are obtained by calling
    :func:`build_backbone_coords`.

    Args:
        seq_len: Number of residues in the protein.
        latent_dim: Dimensionality of the latent space.
        ensemble_size: Number of conformations to sample per forward pass.
        hidden_dim: Width of the hidden layers in the encoder and decoder.
    """

    seq_len: int
    latent_dim: int
    ensemble_size: int = 100
    hidden_dim: int = 256

    def setup(self) -> None:
        self.encoder = Encoder(latent_dim=self.latent_dim, hidden_dim=self.hidden_dim)
        self.decoder = Decoder(seq_len=self.seq_len, hidden_dim=self.hidden_dim)

    def __call__(self, x: jnp.ndarray, rng: Any) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Forward pass: encode → reparameterise → decode.

        Args:
            x: ``(1, seq_len, features)`` sequence features.  Batch size must
                be 1; EnsembleVAE generates diversity through the ensemble
                dimension, not the batch dimension.
            rng: JAX PRNG key used for latent sampling.

        Returns:
            Tuple of ``(torsions, mean, logvar)`` where

            * ``torsions``: ``(ensemble_size, seq_len, 2)``
            * ``mean``:     ``(latent_dim,)``
            * ``logvar``:   ``(latent_dim,)``
        """
        mean_batch, logvar_batch = self.encoder(x)  # (batch_size, latent_dim)

        # EnsembleVAE processes one sequence at a time; squeeze the batch dim
        # so broadcasting with the ensemble axis is unambiguous.
        mean = mean_batch[0]  # (latent_dim,)
        logvar = logvar_batch[0]  # (latent_dim,)

        std = jnp.exp(0.5 * logvar)
        eps = jax.random.normal(rng, (self.ensemble_size, self.latent_dim))
        z = mean + eps * std  # (ensemble_size, latent_dim) ✓

        torsions = self.decoder(z)  # (ensemble_size, seq_len, 2)
        return torsions, mean, logvar

    def generate_coordinates(self, torsions: jnp.ndarray) -> jnp.ndarray:
        """Convert backbone torsions to N–Cα–C Cartesian coordinates.

        Delegates to the module-level :func:`build_backbone_coords` function,
        which is safe to call inside or outside a JIT-compiled context.

        Args:
            torsions: ``(ensemble_size, seq_len, 2)`` torsion angles.

        Returns:
            ``(ensemble_size, seq_len * 3, 3)`` Cartesian coordinates.
        """
        return build_backbone_coords(torsions)

__call__(x, rng)

Forward pass: encode → reparameterise → decode.

Parameters:

Name Type Description Default
x ndarray

(1, seq_len, features) sequence features. Batch size must be 1; EnsembleVAE generates diversity through the ensemble dimension, not the batch dimension.

required
rng Any

JAX PRNG key used for latent sampling.

required

Returns:

Type Description
tuple[ndarray, ndarray, ndarray]

Tuple of (torsions, mean, logvar) where

  • torsions: (ensemble_size, seq_len, 2)
  • mean: (latent_dim,)
  • logvar: (latent_dim,)
Source code in diff_ensemble/model.py
def __call__(self, x: jnp.ndarray, rng: Any) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Forward pass: encode → reparameterise → decode.

    Args:
        x: ``(1, seq_len, features)`` sequence features.  Batch size must
            be 1; EnsembleVAE generates diversity through the ensemble
            dimension, not the batch dimension.
        rng: JAX PRNG key used for latent sampling.

    Returns:
        Tuple of ``(torsions, mean, logvar)`` where

        * ``torsions``: ``(ensemble_size, seq_len, 2)``
        * ``mean``:     ``(latent_dim,)``
        * ``logvar``:   ``(latent_dim,)``
    """
    mean_batch, logvar_batch = self.encoder(x)  # (batch_size, latent_dim)

    # EnsembleVAE processes one sequence at a time; squeeze the batch dim
    # so broadcasting with the ensemble axis is unambiguous.
    mean = mean_batch[0]  # (latent_dim,)
    logvar = logvar_batch[0]  # (latent_dim,)

    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, (self.ensemble_size, self.latent_dim))
    z = mean + eps * std  # (ensemble_size, latent_dim) ✓

    torsions = self.decoder(z)  # (ensemble_size, seq_len, 2)
    return torsions, mean, logvar
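
The reparameterisation step above relies on broadcasting a single (latent_dim,) posterior against (ensemble_size, latent_dim) noise. The following is a minimal sketch of only that shape logic, written in plain NumPy so it runs without JAX or Flax. The sizes, the posterior values, and the NumPy random generator are illustrative assumptions, not values taken from the package.

```python
import numpy as np

ensemble_size, latent_dim = 5, 3  # illustrative sizes only

# Stand-ins for the encoder output after the batch axis has been dropped.
mean = np.array([0.0, 1.0, -1.0])      # (latent_dim,)
logvar = np.array([0.0, -2.0, 2.0])    # (latent_dim,)

rng = np.random.default_rng(0)         # NumPy stand-in for the JAX PRNG key
std = np.exp(0.5 * logvar)             # sigma = exp(log(sigma^2) / 2)
eps = rng.standard_normal((ensemble_size, latent_dim))

# (latent_dim,) broadcasts against (ensemble_size, latent_dim):
# each row is an independent draw from N(mean, diag(std**2)).
z = mean + eps * std

print(z.shape)
```

Because mean and std are shared by every row, all ensemble members come from the same posterior and differ only through eps.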

generate_coordinates(torsions)

Convert backbone torsions to N–Cα–C Cartesian coordinates.

Delegates to the module-level build_backbone_coords function, which is safe to call inside or outside a JIT-compiled context.

Parameters:

Name Type Description Default
torsions ndarray

(ensemble_size, seq_len, 2) torsion angles.

required

Returns:

Type Description
ndarray

(ensemble_size, seq_len * 3, 3) Cartesian coordinates.

Source code in diff_ensemble/model.py
def generate_coordinates(self, torsions: jnp.ndarray) -> jnp.ndarray:
    """Convert backbone torsions to N–Cα–C Cartesian coordinates.

    Delegates to the module-level :func:`build_backbone_coords` function,
    which is safe to call inside or outside a JIT-compiled context.

    Args:
        torsions: ``(ensemble_size, seq_len, 2)`` torsion angles.

    Returns:
        ``(ensemble_size, seq_len * 3, 3)`` Cartesian coordinates.
    """
    return build_backbone_coords(torsions)

Bases: Module

Maps sequence features to latent distribution parameters (μ, log σ²).

Source code in diff_ensemble/model.py
class Encoder(nn.Module):
    """Maps sequence features to latent distribution parameters (μ, log σ²)."""

    latent_dim: int
    hidden_dim: int = 256

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Forward pass.

        Args:
            x: ``(batch_size, seq_len, features)`` sequence feature tensor.

        Returns:
            Tuple of ``(mean, logvar)`` each shaped ``(batch_size, latent_dim)``.
        """
        batch_size = x.shape[0]
        x = x.reshape((batch_size, -1))

        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)

        mean = nn.Dense(self.latent_dim)(x)
        logvar = nn.Dense(self.latent_dim)(x)
        return mean, logvar

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x ndarray

(batch_size, seq_len, features) sequence feature tensor.

required

Returns:

Type Description
tuple[ndarray, ndarray]

Tuple of (mean, logvar) each shaped (batch_size, latent_dim).

Source code in diff_ensemble/model.py
@nn.compact
def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Forward pass.

    Args:
        x: ``(batch_size, seq_len, features)`` sequence feature tensor.

    Returns:
        Tuple of ``(mean, logvar)`` each shaped ``(batch_size, latent_dim)``.
    """
    batch_size = x.shape[0]
    x = x.reshape((batch_size, -1))

    x = nn.Dense(self.hidden_dim)(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_dim)(x)
    x = nn.relu(x)

    mean = nn.Dense(self.latent_dim)(x)
    logvar = nn.Dense(self.latent_dim)(x)
    return mean, logvar

Bases: Module

Maps latent samples to protein backbone torsions (φ, ψ).

Source code in diff_ensemble/model.py
class Decoder(nn.Module):
    """Maps latent samples to protein backbone torsions (φ, ψ)."""

    seq_len: int
    hidden_dim: int = 256

    @nn.compact
    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
        """Forward pass.

        Args:
            z: ``(ensemble_size, latent_dim)`` latent samples.

        Returns:
            ``(ensemble_size, seq_len, 2)`` torsion angles in radians.
        """
        ensemble_size = z.shape[0]

        x = nn.Dense(self.hidden_dim)(z)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)

        # Output φ and ψ for each residue; tanh squashes to (−π, π).
        torsions = nn.Dense(self.seq_len * 2)(x)
        torsions = jnp.tanh(torsions) * jnp.pi

        return torsions.reshape((ensemble_size, self.seq_len, 2))

__call__(z)

Forward pass.

Parameters:

Name Type Description Default
z ndarray

(ensemble_size, latent_dim) latent samples.

required

Returns:

Type Description
ndarray

(ensemble_size, seq_len, 2) torsion angles in radians.

Source code in diff_ensemble/model.py
@nn.compact
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Forward pass.

    Args:
        z: ``(ensemble_size, latent_dim)`` latent samples.

    Returns:
        ``(ensemble_size, seq_len, 2)`` torsion angles in radians.
    """
    ensemble_size = z.shape[0]

    x = nn.Dense(self.hidden_dim)(z)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_dim)(x)
    x = nn.relu(x)

    # Output φ and ψ for each residue; tanh squashes to (−π, π).
    torsions = nn.Dense(self.seq_len * 2)(x)
    torsions = jnp.tanh(torsions) * jnp.pi

    return torsions.reshape((ensemble_size, self.seq_len, 2))
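
The output head above can be illustrated in isolation. This sketch uses NumPy and a made-up array in place of the final Dense layer output (an assumption for illustration); it shows how tanh scaling bounds the angles and how the row-major reshape pairs consecutive values as (φ, ψ).

```python
import numpy as np

seq_len, ensemble_size = 4, 3  # illustrative sizes only

# Stand-in for the last Dense layer: unbounded real values, (ensemble_size, seq_len * 2).
raw = np.linspace(-50.0, 50.0, ensemble_size * seq_len * 2).reshape(
    ensemble_size, seq_len * 2
)

# tanh maps into (-1, 1) mathematically (saturating to +/-1 in floating point);
# scaling by pi gives angles in radians.
flat = np.tanh(raw) * np.pi

# Row-major reshape: consecutive pairs along the flat axis become (phi, psi).
torsions = flat.reshape(ensemble_size, seq_len, 2)
phi, psi = torsions[..., 0], torsions[..., 1]

print(torsions.shape)
```

One consequence of this design is that −π and +π, which are the same angle, sit at opposite ends of the output range.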

Convert backbone torsions (φ, ψ) to N–Cα–C Cartesian coordinates.

Parameters:

Name Type Description Default
torsions ndarray

(ensemble_size, seq_len, 2) array where the last axis is [phi, psi] in radians.

required

Returns:

Type Description
ndarray

(ensemble_size, seq_len * 3, 3) Cartesian coordinates for all backbone heavy atoms (N, Cα, C) in each model of the ensemble.

Source code in diff_ensemble/model.py
def build_backbone_coords(torsions: jnp.ndarray) -> jnp.ndarray:
    """Convert backbone torsions (φ, ψ) to N–Cα–C Cartesian coordinates.

    Args:
        torsions: ``(ensemble_size, seq_len, 2)`` array where the last axis is
            ``[phi, psi]`` in radians.

    Returns:
        ``(ensemble_size, seq_len * 3, 3)`` Cartesian coordinates for all
        backbone heavy atoms (N, Cα, C) in each model of the ensemble.
    """
    ensemble_size, n_res, _ = torsions.shape

    # Pre-compute the *static* geometry arrays once, outside vmap.
    # Pattern per residue: (C→N bond, N→Cα bond, Cα→C bond)
    bond_pattern = jnp.array([_B_C_N, _B_N_CA, _B_CA_C] * n_res)  # (3*n_res,)
    angle_pattern = jnp.array([_A_CA_C_N, _A_C_N_CA, _A_N_CA_C] * n_res)  # (3*n_res,)
    omega_vec = jnp.full((n_res,), _OMEGA)  # (n_res,)

    # Drop first triplet — those atoms are provided via init_coords.
    bonds = bond_pattern[3:]  # (3*n_res - 3,)
    angles = angle_pattern[3:]  # (3*n_res - 3,)

    def _build_one(torsion_pair: jnp.ndarray) -> jnp.ndarray:
        """Build one structure from ``(n_res, 2)`` torsions."""
        phi = torsion_pair[:, 0]  # (n_res,)
        psi = torsion_pair[:, 1]  # (n_res,)

        # Interleave: for each residue the dihedral order is (ω, φ, ψ).
        # Stack columns then flatten to (3*n_res,) and drop the first triplet.
        dihedrals_full = jnp.stack([omega_vec, phi, psi], axis=1).reshape(-1)
        dihedrals = dihedrals_full[3:]  # (3*n_res - 3,)

        return cast(jnp.ndarray, chain_nerf(_INIT_COORDS, bonds, angles, dihedrals))

    return cast(jnp.ndarray, jax.vmap(_build_one)(torsions))  # (ensemble_size, 3*n_res, 3)
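
The dihedral interleaving inside _build_one is easy to get wrong, so here is a small NumPy sketch of that single step. The torsion values are invented, and ω is assumed to be π (a trans peptide bond); the package's actual _OMEGA constant is not shown in this page.

```python
import numpy as np

n_res = 3
omega = np.full(n_res, np.pi)          # assumption: trans peptide bond
phi = np.array([-1.0, -1.1, -1.2])
psi = np.array([2.0, 2.1, 2.2])

# Stack as columns -> (n_res, 3), then flatten row by row, giving
# [w0, phi0, psi0, w1, phi1, psi1, w2, phi2, psi2].
full = np.stack([omega, phi, psi], axis=1).reshape(-1)

# The first triplet is dropped because the first three atoms come from
# the fixed initial coordinates rather than from torsions.
dihedrals = full[3:]

print(full.shape, dihedrals.shape)
```

Stacking along axis=1 is what keeps each residue's three dihedrals adjacent; stacking along axis=0 and flattening would instead group all ω values, then all φ, then all ψ.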

Ensemble Predictor

High-level wrapper for a trained EnsembleVAE.

Parameters:

Name Type Description Default
model EnsembleVAE

An EnsembleVAE instance.

required
params dict[str, Any]

Trained Flax parameter dict (from model.init or a checkpoint).

required

Example:

import jax
import jax.numpy as jnp
from diff_ensemble import EnsembleVAE, EnsemblePredictor

model = EnsembleVAE(seq_len=90, latent_dim=32, ensemble_size=100)
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 90, 4)), rng)["params"]

predictor = EnsemblePredictor(model, params)
coords = predictor.predict(jnp.ones((1, 90, 4)), rng)
print(predictor.compute_rg(coords))
Source code in diff_ensemble/ensemble.py
class EnsemblePredictor:
    """High-level wrapper for a trained :class:`EnsembleVAE`.

    Args:
        model: An :class:`~diff_ensemble.model.EnsembleVAE` instance.
        params: Trained Flax parameter dict (from ``model.init`` or a
            checkpoint).

    Example::

        import jax
        import jax.numpy as jnp
        from diff_ensemble import EnsembleVAE, EnsemblePredictor

        model = EnsembleVAE(seq_len=90, latent_dim=32, ensemble_size=100)
        rng = jax.random.PRNGKey(0)
        params = model.init(rng, jnp.ones((1, 90, 4)), rng)["params"]

        predictor = EnsemblePredictor(model, params)
        coords = predictor.predict(jnp.ones((1, 90, 4)), rng)
        print(predictor.compute_rg(coords))
    """

    def __init__(self, model: EnsembleVAE, params: dict[str, Any]) -> None:
        self.model = model
        self.params = params

    # ------------------------------------------------------------------
    # Ensemble generation
    # ------------------------------------------------------------------

    def predict(
        self,
        sequence_features: jnp.ndarray,
        rng: Any,
        n_samples: int | None = None,
    ) -> jnp.ndarray:
        """Generate a structural ensemble from sequence features.

        Args:
            sequence_features: ``(1, seq_len, features)`` input tensor.
            rng: JAX PRNG key.
            n_samples: If provided, overrides the model's default
                ``ensemble_size`` by creating a temporary model variant.
                When ``None`` the model's ``ensemble_size`` is used.

        Returns:
            ``(ensemble_size, seq_len * 3, 3)`` Cartesian coordinates for all
            backbone atoms (N, Cα, C) in each generated conformation.
        """
        if n_samples is not None and n_samples != self.model.ensemble_size:
            # Create a lightweight variant with the requested ensemble size.
            tmp_model = EnsembleVAE(
                seq_len=self.model.seq_len,
                latent_dim=self.model.latent_dim,
                ensemble_size=n_samples,
                hidden_dim=self.model.hidden_dim,
            )
            out = tmp_model.apply({"params": self.params}, sequence_features, rng)
        else:
            out = self.model.apply({"params": self.params}, sequence_features, rng)

        torsions, _, _ = out  # type: ignore[misc]  # Flax apply return is untyped
        return build_backbone_coords(torsions)

    # ------------------------------------------------------------------
    # Population-weighted averaging
    # ------------------------------------------------------------------

    def compute_population_average(
        self,
        observable_fn: Callable[..., jnp.ndarray],
        coords: jnp.ndarray,
        weights: jnp.ndarray | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> jnp.ndarray:
        """Compute a population-weighted ensemble average of an observable.

        Args:
            observable_fn: A function with signature
                ``(coords_single: (N, 3), *args, **kwargs) -> (...)`` that
                computes the observable for a single conformation.
            coords: ``(M, N, 3)`` ensemble coordinates.
            weights: ``(M,)`` population weights (must sum to 1).  Defaults to
                uniform weights (``1/M`` each).
            *args: Extra positional arguments forwarded to ``observable_fn``.
            **kwargs: Extra keyword arguments forwarded to ``observable_fn``.

        Returns:
            Observable averaged over the ensemble.
        """
        ensemble_size = coords.shape[0]

        if weights is None:
            weights = jnp.ones(ensemble_size) / ensemble_size

        # vmap over the ensemble dimension.
        per_model_obs = jax.vmap(lambda c: observable_fn(c, *args, **kwargs))(coords)
        # Weighted average: weights shape (M,) broadcast over observable dims.
        return jnp.average(per_model_obs, axis=0, weights=weights)

    # ------------------------------------------------------------------
    # Structural statistics
    # ------------------------------------------------------------------

    def compute_rg(
        self,
        coords: jnp.ndarray,
        weights: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Compute the ensemble-averaged radius of gyration.

        Args:
            coords: ``(M, N, 3)`` ensemble coordinates.
            weights: ``(M,)`` population weights.  Defaults to uniform.

        Returns:
            ``(avg_rg, rg_per_model)`` — the population-weighted mean Rg and
            the per-model Rg array (both in Ångströms).
        """
        ensemble_size = coords.shape[0]
        if weights is None:
            weights = jnp.ones(ensemble_size) / ensemble_size

        center = jnp.mean(coords, axis=1, keepdims=True)  # (M, 1, 3)
        sq_dist = jnp.sum((coords - center) ** 2, axis=-1)  # (M, N)
        rg_per_model = jnp.sqrt(jnp.mean(sq_dist, axis=1))  # (M,)
        avg_rg = jnp.average(rg_per_model, weights=weights)

        return avg_rg, rg_per_model

    def compute_end_to_end_distance(
        self,
        coords: jnp.ndarray,
        weights: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Compute the ensemble-averaged end-to-end distance.

        Measures the distance between the N-terminal N atom (index 0) and the
        C-terminal C atom (last atom) of each backbone model.

        Args:
            coords: ``(M, N, 3)`` ensemble coordinates.
            weights: ``(M,)`` population weights.  Defaults to uniform.

        Returns:
            ``(avg_ree, ree_per_model)`` — weighted mean R_ee and per-model
            values (both in Ångströms).
        """
        ensemble_size = coords.shape[0]
        if weights is None:
            weights = jnp.ones(ensemble_size) / ensemble_size

        ree_per_model = jnp.linalg.norm(coords[:, -1, :] - coords[:, 0, :], axis=-1)
        avg_ree = jnp.average(ree_per_model, weights=weights)

        return avg_ree, ree_per_model

compute_end_to_end_distance(coords, weights=None)

Compute the ensemble-averaged end-to-end distance.

Measures the distance between the N-terminal N atom (index 0) and the C-terminal C atom (last atom) of each backbone model.

Parameters:

Name Type Description Default
coords ndarray

(M, N, 3) ensemble coordinates.

required
weights ndarray | None

(M,) population weights. Defaults to uniform.

None

Returns:

Type Description
tuple[ndarray, ndarray]

(avg_ree, ree_per_model): the weighted mean R_ee and the per-model values (both in Ångströms).

Source code in diff_ensemble/ensemble.py
def compute_end_to_end_distance(
    self,
    coords: jnp.ndarray,
    weights: jnp.ndarray | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the ensemble-averaged end-to-end distance.

    Measures the distance between the N-terminal N atom (index 0) and the
    C-terminal C atom (last atom) of each backbone model.

    Args:
        coords: ``(M, N, 3)`` ensemble coordinates.
        weights: ``(M,)`` population weights.  Defaults to uniform.

    Returns:
        ``(avg_ree, ree_per_model)`` — weighted mean R_ee and per-model
        values (both in Ångströms).
    """
    ensemble_size = coords.shape[0]
    if weights is None:
        weights = jnp.ones(ensemble_size) / ensemble_size

    ree_per_model = jnp.linalg.norm(coords[:, -1, :] - coords[:, 0, :], axis=-1)
    avg_ree = jnp.average(ree_per_model, weights=weights)

    return avg_ree, ree_per_model
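
As a quick check of the indexing, the sketch below applies the same two lines to a toy ensemble in NumPy. The coordinates and weights are invented for illustration; the only thing taken from the listing above is that atom 0 is the first N and the last atom is the final C.

```python
import numpy as np

# Two toy models of a 2-residue backbone: 6 atoms ordered N, CA, C, N, CA, C.
straight = np.array([[i * 1.5, 0.0, 0.0] for i in range(6)])
bent = straight.copy()
bent[-1] = [3.0, 4.0, 0.0]            # move only the C-terminal C atom
coords = np.stack([straight, bent])   # (M=2, N=6, 3)

# Distance between the last and first atom of every model.
ree_per_model = np.linalg.norm(coords[:, -1, :] - coords[:, 0, :], axis=-1)

# Population-weighted mean with non-uniform weights.
weights = np.array([0.75, 0.25])
avg_ree = np.average(ree_per_model, weights=weights)

print(ree_per_model, avg_ree)
```

Here the per-model distances are 7.5 and 5.0, so the weighted mean is 0.75 × 7.5 + 0.25 × 5.0 = 6.875.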

compute_population_average(observable_fn, coords, weights=None, *args, **kwargs)

Compute a population-weighted ensemble average of an observable.

Parameters:

Name Type Description Default
observable_fn Callable[..., ndarray]

A function with signature (coords_single: (N, 3), *args, **kwargs) -> (...) that computes the observable for a single conformation.

required
coords ndarray

(M, N, 3) ensemble coordinates.

required
weights ndarray | None

(M,) population weights (must sum to 1). Defaults to uniform weights (1/M each).

None
*args Any

Extra positional arguments forwarded to observable_fn.

()
**kwargs Any

Extra keyword arguments forwarded to observable_fn.

{}

Returns:

Type Description
ndarray

Observable averaged over the ensemble.

Source code in diff_ensemble/ensemble.py
def compute_population_average(
    self,
    observable_fn: Callable[..., jnp.ndarray],
    coords: jnp.ndarray,
    weights: jnp.ndarray | None = None,
    *args: Any,
    **kwargs: Any,
) -> jnp.ndarray:
    """Compute a population-weighted ensemble average of an observable.

    Args:
        observable_fn: A function with signature
            ``(coords_single: (N, 3), *args, **kwargs) -> (...)`` that
            computes the observable for a single conformation.
        coords: ``(M, N, 3)`` ensemble coordinates.
        weights: ``(M,)`` population weights (must sum to 1).  Defaults to
            uniform weights (``1/M`` each).
        *args: Extra positional arguments forwarded to ``observable_fn``.
        **kwargs: Extra keyword arguments forwarded to ``observable_fn``.

    Returns:
        Observable averaged over the ensemble.
    """
    ensemble_size = coords.shape[0]

    if weights is None:
        weights = jnp.ones(ensemble_size) / ensemble_size

    # vmap over the ensemble dimension.
    per_model_obs = jax.vmap(lambda c: observable_fn(c, *args, **kwargs))(coords)
    # Weighted average: weights shape (M,) broadcast over observable dims.
    return jnp.average(per_model_obs, axis=0, weights=weights)
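
The final averaging step can be sketched on its own. This example uses NumPy's np.average, with a hand-written (M, K) array standing in for the output of the vmapped observable_fn; the numbers are illustrative only.

```python
import numpy as np

# Three conformations, each yielding a 2-component observable.
per_model_obs = np.array([[1.0, 10.0],
                          [2.0, 20.0],
                          [3.0, 30.0]])    # (M, K)
weights = np.array([0.5, 0.25, 0.25])      # (M,), sums to 1

# 1-D weights are matched against axis 0 and applied to every observable column.
weighted = np.average(per_model_obs, axis=0, weights=weights)   # values 1.75 and 17.5
uniform = np.average(per_model_obs, axis=0)                     # values 2.0 and 20.0

# np.average divides by the sum of the weights, so rescaled weights
# give the same result in NumPy.
rescaled = np.average(per_model_obs, axis=0, weights=4.0 * weights)
```

Passing weights that already sum to 1, as the docstring asks, keeps the result interpretable as a population average regardless of how the averaging routine normalises.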

compute_rg(coords, weights=None)

Compute the ensemble-averaged radius of gyration.

Parameters:

Name Type Description Default
coords ndarray

(M, N, 3) ensemble coordinates.

required
weights ndarray | None

(M,) population weights. Defaults to uniform.

None

Returns:

Type Description
tuple[ndarray, ndarray]

(avg_rg, rg_per_model): the population-weighted mean Rg and the per-model Rg array (both in Ångströms).

Source code in diff_ensemble/ensemble.py
def compute_rg(
    self,
    coords: jnp.ndarray,
    weights: jnp.ndarray | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the ensemble-averaged radius of gyration.

    Args:
        coords: ``(M, N, 3)`` ensemble coordinates.
        weights: ``(M,)`` population weights.  Defaults to uniform.

    Returns:
        ``(avg_rg, rg_per_model)`` — the population-weighted mean Rg and
        the per-model Rg array (both in Ångströms).
    """
    ensemble_size = coords.shape[0]
    if weights is None:
        weights = jnp.ones(ensemble_size) / ensemble_size

    center = jnp.mean(coords, axis=1, keepdims=True)  # (M, 1, 3)
    sq_dist = jnp.sum((coords - center) ** 2, axis=-1)  # (M, N)
    rg_per_model = jnp.sqrt(jnp.mean(sq_dist, axis=1))  # (M,)
    avg_rg = jnp.average(rg_per_model, weights=weights)

    return avg_rg, rg_per_model
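
A worked example makes the definition concrete. The sketch below repeats the same three array operations in NumPy on a toy ensemble whose answer can be checked by hand; the coordinates are invented for illustration.

```python
import numpy as np

# Four atoms at unit distance from the origin, plus a copy scaled by 2.
square = np.array([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0]], dtype=float)
coords = np.stack([square, 2.0 * square])            # (M=2, N=4, 3)

center = coords.mean(axis=1, keepdims=True)          # (M, 1, 3) unweighted centroid
sq_dist = np.sum((coords - center) ** 2, axis=-1)    # (M, N) squared distances
rg_per_model = np.sqrt(sq_dist.mean(axis=1))         # (M,) root-mean-square distance
avg_rg = np.average(rg_per_model)                    # uniform weights

print(rg_per_model, avg_rg)
```

Every atom lies at distance 1 (or 2) from its centroid, so the per-model values are 1.0 and 2.0 and the uniform average is 1.5. Note that the centroid is a plain mean over atoms, so all atoms count equally rather than by mass.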

predict(sequence_features, rng, n_samples=None)

Generate a structural ensemble from sequence features.

Parameters:

Name Type Description Default
sequence_features ndarray

(1, seq_len, features) input tensor.

required
rng Any

JAX PRNG key.

required
n_samples int | None

If provided, overrides the model's default ensemble_size by creating a temporary model variant. When None the model's ensemble_size is used.

None

Returns:

Type Description
ndarray

(ensemble_size, seq_len * 3, 3) Cartesian coordinates for all backbone atoms (N, Cα, C) in each generated conformation.

Source code in diff_ensemble/ensemble.py
def predict(
    self,
    sequence_features: jnp.ndarray,
    rng: Any,
    n_samples: int | None = None,
) -> jnp.ndarray:
    """Generate a structural ensemble from sequence features.

    Args:
        sequence_features: ``(1, seq_len, features)`` input tensor.
        rng: JAX PRNG key.
        n_samples: If provided, overrides the model's default
            ``ensemble_size`` by creating a temporary model variant.
            When ``None`` the model's ``ensemble_size`` is used.

    Returns:
        ``(ensemble_size, seq_len * 3, 3)`` Cartesian coordinates for all
        backbone atoms (N, Cα, C) in each generated conformation.
    """
    if n_samples is not None and n_samples != self.model.ensemble_size:
        # Create a lightweight variant with the requested ensemble size.
        tmp_model = EnsembleVAE(
            seq_len=self.model.seq_len,
            latent_dim=self.model.latent_dim,
            ensemble_size=n_samples,
            hidden_dim=self.model.hidden_dim,
        )
        out = tmp_model.apply({"params": self.params}, sequence_features, rng)
    else:
        out = self.model.apply({"params": self.params}, sequence_features, rng)

    torsions, _, _ = out  # type: ignore[misc]  # Flax apply return is untyped
    return build_backbone_coords(torsions)

Observables

Calculate ensemble-averaged SAXS intensity via the Debye formula.

Parameters:

Name Type Description Default
coords ndarray

(M, N, 3) coordinates where M is ensemble size and N is atom count.

required
q_values ndarray

(Q,) scattering vector magnitudes in Å⁻¹.

required
form_factors ndarray

(N, Q) atomic form factors.

required

Returns:

Type Description
ndarray

(Q,) ensemble-averaged intensity.

Source code in diff_ensemble/observables.py
def get_ensemble_saxs(
    coords: jnp.ndarray,
    q_values: jnp.ndarray,
    form_factors: jnp.ndarray,
) -> jnp.ndarray:
    """Calculate ensemble-averaged SAXS intensity via the Debye formula.

    Args:
        coords: ``(M, N, 3)`` coordinates where *M* is ensemble size and *N*
            is atom count.
        q_values: ``(Q,)`` scattering vector magnitudes in Å⁻¹.
        form_factors: ``(N, Q)`` atomic form factors.

    Returns:
        ``(Q,)`` ensemble-averaged intensity.
    """
    ensemble = Ensemble(coords)
    # observable_fn signature: (N, 3) → (Q,)
    return cast(jnp.ndarray, ensemble.calculate_average(debye_saxs, q_values, form_factors))
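
For readers unfamiliar with the Debye formula, I(q) = Σ_i Σ_j f_i(q) f_j(q) sin(q r_ij) / (q r_ij), the following NumPy sketch evaluates it directly and then averages uniformly over conformations. The functions debye_intensity and ensemble_intensity are hypothetical names written for this illustration; they are not the debye_saxs kernel or the Ensemble class used above, whose implementations are not shown in this page.

```python
import numpy as np

def debye_intensity(coords, q_values, form_factors):
    """Hypothetical single-conformation Debye sum (illustration only).

    coords: (N, 3), q_values: (Q,), form_factors: (N, Q) -> (Q,)
    """
    diff = coords[:, None, :] - coords[None, :, :]
    r = np.linalg.norm(diff, axis=-1)                 # (N, N) pair distances
    qr = q_values[:, None, None] * r[None, :, :]      # (Q, N, N)
    # np.sinc(x) = sin(pi x) / (pi x) and equals 1 at x = 0,
    # which handles both the i == j terms and q == 0.
    kernel = np.sinc(qr / np.pi)                      # sin(qr) / (qr)
    f = form_factors.T                                # (Q, N)
    return np.einsum("qi,qij,qj->q", f, kernel, f)

def ensemble_intensity(coords, q_values, form_factors):
    """Uniform average of per-conformation profiles; coords is (M, N, 3)."""
    profiles = [debye_intensity(c, q_values, form_factors) for c in coords]
    return np.mean(profiles, axis=0)

# One conformation of two unit scatterers 2 Å apart.
coords = np.array([[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]])   # (M=1, N=2, 3)
q_values = np.array([0.0, 0.5])
form_factors = np.ones((2, 2))                              # (N, Q)
profile = ensemble_intensity(coords, q_values, form_factors)
```

For this two-atom case the sum reduces to 2 + 2 sin(2q) / (2q), which gives 4 at q = 0. The explicit pairwise sum costs O(N²) memory per q value, so it is only suitable as a reference for small systems.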

Calculate ensemble-averaged Residual Dipolar Couplings (RDCs).

Note: This function requires the diff_biophys.nmr.rdc kernel, which is planned for a future release. Calling it raises NotImplementedError until the kernel is available.

Parameters:

Name Type Description Default
coords ndarray

(M, N, 3) ensemble coordinates.

required
bond_vectors ndarray

(K, 2) atom-index pairs defining each NH (or other) bond for which RDCs are measured.

required
alignment_tensor ndarray

(3, 3) Saupe order tensor A for the alignment medium used in the experiment.

required

Returns:

Type Description
ndarray

(K,) ensemble-averaged RDC values in Hz.

Raises:

Type Description
NotImplementedError

Until the diff_biophys.nmr module is released.

Source code in diff_ensemble/observables.py
def get_ensemble_rdc(
    coords: jnp.ndarray,
    bond_vectors: jnp.ndarray,
    alignment_tensor: jnp.ndarray,
) -> jnp.ndarray:
    """Calculate ensemble-averaged Residual Dipolar Couplings (RDCs).

    .. note::
        This function requires the ``diff_biophys.nmr.rdc`` kernel, which is
        planned for a future release.  Calling it will raise
        :exc:`NotImplementedError` until the kernel is available.

    Args:
        coords: ``(M, N, 3)`` ensemble coordinates.
        bond_vectors: ``(K, 2)`` atom-index pairs defining each NH (or other)
            bond for which RDCs are measured.
        alignment_tensor: ``(3, 3)`` Saupe order tensor *A* for the alignment
            medium used in the experiment.

    Returns:
        ``(K,)`` ensemble-averaged RDC values in Hz.

    Raises:
        NotImplementedError: Until the ``diff_biophys.nmr`` module is released.
    """
    raise NotImplementedError(
        "RDC calculation requires the diff_biophys.nmr.rdc kernel, "
        "which is not yet available.  Track progress at "
        "https://github.com/elkins/diff-ensemble/issues."
    )

Calculate ensemble-averaged backbone chemical shifts.

Note: This function requires the diff_biophys.nmr.shifts kernel, which is planned for a future release. Calling it raises NotImplementedError until the kernel is available.

Parameters:

Name Type Description Default
coords ndarray

(M, N, 3) ensemble coordinates.

required
sequence ndarray

(seq_len,) integer-encoded amino-acid sequence (0 = ALA, …, 19 = VAL).

required

Returns:

Type Description
ndarray

(seq_len, 6) ensemble-averaged chemical shifts for backbone nuclei (Hα, Hβ, C, Cα, Cβ, N) in ppm.

Raises:

Type Description
NotImplementedError

Until the diff_biophys.nmr module is released.

Source code in diff_ensemble/observables.py
def get_ensemble_chemical_shifts(
    coords: jnp.ndarray,
    sequence: jnp.ndarray,
) -> jnp.ndarray:
    """Calculate ensemble-averaged backbone chemical shifts.

    .. note::
        This function requires the ``diff_biophys.nmr.shifts`` kernel, which is
        planned for a future release.  Calling it will raise
        :exc:`NotImplementedError` until the kernel is available.

    Args:
        coords: ``(M, N, 3)`` ensemble coordinates.
        sequence: ``(seq_len,)`` integer-encoded amino-acid sequence
            (0 = ALA, …, 19 = VAL).

    Returns:
        ``(seq_len, 6)`` ensemble-averaged chemical shifts for backbone nuclei
        (Hα, Hβ, C, Cα, Cβ, N) in ppm.

    Raises:
        NotImplementedError: Until the ``diff_biophys.nmr`` module is released.
    """
    raise NotImplementedError(
        "Chemical shift calculation requires the diff_biophys.nmr.shifts kernel, "
        "which is not yet available.  Track progress at "
        "https://github.com/elkins/diff-ensemble/issues."
    )

KL divergence between the encoder posterior and a unit Gaussian prior.

Uses the closed-form expression for Gaussian KL: -0.5 * Σ (1 + log σ² − μ² − σ²).

Parameters:

Name Type Description Default
mean ndarray

(latent_dim,) posterior mean.

required
logvar ndarray

(latent_dim,) posterior log-variance.

required

Returns:

Type Description
ndarray

Scalar KL divergence (non-negative).

Source code in diff_ensemble/observables.py
def kld_loss(mean: jnp.ndarray, logvar: jnp.ndarray) -> jnp.ndarray:
    """KL divergence between the encoder posterior and a unit Gaussian prior.

    Uses the closed-form expression for Gaussian KL:
    ``-0.5 * Σ (1 + log σ² − μ² − σ²)``.

    Args:
        mean: ``(latent_dim,)`` posterior mean.
        logvar: ``(latent_dim,)`` posterior log-variance.

    Returns:
        Scalar KL divergence (non-negative).
    """
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
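
A few hand-checkable cases help confirm the sign convention of the closed form. The sketch below re-implements the same expression in NumPy under the hypothetical name kld; the input values are chosen purely for illustration.

```python
import numpy as np

def kld(mean, logvar):
    # Same closed form as kld_loss, written with NumPy for a standalone check.
    return -0.5 * np.sum(1 + logvar - np.square(mean) - np.exp(logvar))

# Posterior identical to the unit Gaussian prior: divergence is zero.
zero = kld(np.zeros(4), np.zeros(4))

# Mean shifted by 1 in one dimension, unit variance: contributes mu^2 / 2 = 0.5.
shifted = kld(np.array([1.0, 0.0]), np.zeros(2))

# Variance 4 in one dimension: 0.5 * (4 - 1 - ln 4) = 1.5 - ln 2.
wide = kld(np.zeros(1), np.array([np.log(4.0)]))
```

All three values are non-negative, and only the case where the posterior matches the prior gives exactly zero.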

Normalised MSE between predicted and experimental SAXS profiles.

Both profiles are normalised to their first data point (I(q=0) = 1) before comparison, making the loss invariant to absolute intensity scale.

Parameters:

Name Type Description Default
predicted_saxs ndarray

(Q,) predicted SAXS intensities.

required
experimental_saxs ndarray

(Q,) experimental SAXS intensities.

required

Returns:

Type Description
ndarray

Scalar mean-squared error after normalisation.

Source code in diff_ensemble/observables.py
def biophysical_loss(
    predicted_saxs: jnp.ndarray,
    experimental_saxs: jnp.ndarray,
) -> jnp.ndarray:
    """Normalised MSE between predicted and experimental SAXS profiles.

    Each profile is divided by its first data point, which is taken as
    I(q=0) and therefore becomes 1, before comparison. This makes the loss
    invariant to absolute intensity scale.

    Args:
        predicted_saxs: ``(Q,)`` predicted SAXS intensities.
        experimental_saxs: ``(Q,)`` experimental SAXS intensities.

    Returns:
        Scalar mean-squared error after normalisation.
    """
    pred = predicted_saxs / (predicted_saxs[0] + 1e-10)
    exp = experimental_saxs / (experimental_saxs[0] + 1e-10)
    return jnp.mean(jnp.square(pred - exp))
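
The scale invariance described above can be checked with a small dependency-free sketch. It mirrors the normalisation step in plain Python. The helper name `normalised_mse` and the sample profiles are made up for illustration and do not come from diff_ensemble.

```python
def normalised_mse(predicted, experimental, eps=1e-10):
    """MSE after dividing each profile by its first intensity value.

    Hypothetical plain-Python stand-in for the arithmetic in biophysical_loss.
    """
    pred = [p / (predicted[0] + eps) for p in predicted]
    ref = [e / (experimental[0] + eps) for e in experimental]
    return sum((p - r) ** 2 for p, r in zip(pred, ref)) / len(pred)


profile = [4.0, 2.0, 1.0]
rescaled = [1000.0 * v for v in profile]  # same shape, different scale
different = [4.0, 3.0, 1.0]               # different shape

# Rescaling leaves the loss at (numerically) zero.
loss_rescaled = normalised_mse(rescaled, profile)

# A change in profile shape produces a clearly non-zero loss.
loss_different = normalised_mse(different, profile)
```

Because of the small `eps` guard in the denominator, the invariance is approximate rather than exact, so the rescaled loss is tiny but not guaranteed to be identically zero.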

Training

Bases: TrainState

Extended Flax TrainState that carries the PRNG key for sampling.

Source code in diff_ensemble/train.py
class TrainState(train_state.TrainState):
    """Extended Flax TrainState that carries the PRNG key for sampling."""

    key: Any

Initialise model parameters and return an initial :class:TrainState.

Parameters:

Name Type Description Default
model EnsembleVAE

An :class:~diff_ensemble.model.EnsembleVAE instance.

required
rng Any

JAX PRNG key.

required
learning_rate float

Adam learning rate.

required
input_shape tuple[int, ...]

Shape of a single input batch, e.g. (1, seq_len, 4).

required

Returns:

Type Description
TrainState

Initialised :class:TrainState.

Source code in diff_ensemble/train.py
def create_train_state(
    model: EnsembleVAE,
    rng: Any,
    learning_rate: float,
    input_shape: tuple[int, ...],
) -> TrainState:
    """Initialise model parameters and return an initial :class:`TrainState`.

    Args:
        model: An :class:`~diff_ensemble.model.EnsembleVAE` instance.
        rng: JAX PRNG key.
        learning_rate: Adam learning rate.
        input_shape: Shape of a single input batch, e.g. ``(1, seq_len, 4)``.

    Returns:
        Initialised :class:`TrainState`.
    """
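    # Use separate keys for parameter initialisation, the sampling pass run
    # during init, and the training state, so that no PRNG key is reused.
    init_key, sample_key, step_key = jax.random.split(rng, 3)
    params = model.init(init_key, jnp.ones(input_shape), sample_key)["params"]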
    tx = optax.adam(learning_rate)
    return cast(
        TrainState,
        TrainState.create(apply_fn=model.apply, params=params, tx=tx, key=step_key),
    )

Perform a single gradient update step.

The PRNG key stored in state is consumed for latent sampling, and the state is returned with a freshly split key ready for the next step.

Parameters:

Name Type Description Default
state TrainState

Current :class:TrainState.

required
batch_x ndarray

(1, seq_len, features) sequence feature tensor.

required
exp_saxs ndarray

(Q,) experimental SAXS intensities.

required
q_values ndarray

(Q,) scattering vector magnitudes.

required
form_factors ndarray

(N_atoms, Q) atomic form factors.

required
beta float

KL-divergence weight in the ELBO loss.

0.1

Returns:

Type Description
tuple[TrainState, ndarray, ndarray, ndarray]

(new_state, total_loss, bio_loss, kl_loss)

Source code in diff_ensemble/train.py
@jax.jit
def train_step(
    state: TrainState,
    batch_x: jnp.ndarray,
    exp_saxs: jnp.ndarray,
    q_values: jnp.ndarray,
    form_factors: jnp.ndarray,
    beta: float = 0.1,
) -> tuple[TrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Perform a single gradient update step.

    The PRNG key stored in ``state`` is consumed for latent sampling, and the
    state is returned with a freshly split key ready for the next step.

    Args:
        state: Current :class:`TrainState`.
        batch_x: ``(1, seq_len, features)`` sequence feature tensor.
        exp_saxs: ``(Q,)`` experimental SAXS intensities.
        q_values: ``(Q,)`` scattering vector magnitudes.
        form_factors: ``(N_atoms, Q)`` atomic form factors.
        beta: KL-divergence weight in the ELBO loss.

    Returns:
        ``(new_state, total_loss, bio_loss, kl_loss)``
    """
    # Split *before* use so each step consumes a unique key.
    step_key, next_key = jax.random.split(state.key)

    def loss_fn(params: Any) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]:
        torsions, mean, logvar = state.apply_fn({"params": params}, batch_x, step_key)

        # Use the module-level function directly — no need to re-instantiate the
        # model inside the loss closure, keeping gradient flow clean and explicit.
        coords = build_backbone_coords(torsions)

        pred_saxs = get_ensemble_saxs(coords, q_values, form_factors)

        bio_loss = biophysical_loss(pred_saxs, exp_saxs)
        kl_loss = kld_loss(mean, logvar)

        total_loss = bio_loss + beta * kl_loss
        return total_loss, (bio_loss, kl_loss)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (bio_loss, kl_loss)), grads = grad_fn(state.params)

    state = state.apply_gradients(grads=grads)
    state = state.replace(key=next_key)

    return state, loss, bio_loss, kl_loss
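
The split-before-use pattern in `train_step` can be seen in isolation in the minimal JAX sketch below. Here a bare `key` variable plays the role of `state.key`, and the loop stands in for repeated calls to `train_step`. The seed and the number of iterations are arbitrary choices for illustration.

```python
import jax

key = jax.random.PRNGKey(0)  # plays the role of state.key

draws = []
for _ in range(3):
    # Split before use: one half is consumed now, the other is carried forward.
    step_key, key = jax.random.split(key)
    draws.append(float(jax.random.normal(step_key, ())))

# Each iteration consumed a different key, so the draws differ.
all_distinct = len(set(draws)) == 3
```

Carrying the unused half forward is what lets each step sample fresh latent noise while the whole sequence stays reproducible from the initial key.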

I/O

Save a structural ensemble to a multi-model PDB file.

Parameters:

Name Type Description Default
coords ndarray

(M, N, 3) array, where M is the ensemble size and N is the atom count (three backbone atoms, N-CA-C, per residue).

required
file_path str

Path to the output PDB file.

required
res_names list[str] | None

List of residue names (length N // 3). Defaults to "ALA" for every residue.

None
Source code in diff_ensemble/io.py
def save_ensemble_to_pdb(
    coords: np.ndarray, file_path: str, res_names: list[str] | None = None
) -> None:
    """
    Save a structural ensemble to a multi-model PDB file.

    Args:
        coords: (M, N, 3) array, where M is the ensemble size and N is the
            atom count (three backbone atoms, N-CA-C, per residue).
        file_path: Path to the output PDB file.
        res_names: List of residue names (length N // 3). Defaults to "ALA"
            for every residue.
    """
    ensemble_size, n_atoms, _ = coords.shape
    n_res = n_atoms // 3  # Assuming N-Ca-C backbone

    if res_names is None:
        res_names = ["ALA"] * n_res

    # Create a template structure for one model
    # Atoms: N, CA, C for each residue
    atom_names = ["N", "CA", "C"] * n_res
    res_indices = []
    for i in range(n_res):
        res_indices.extend([i + 1] * 3)

    # Build biotite AtomArrayStack
    stack = struc.AtomArrayStack(ensemble_size, n_atoms)
    stack.coord = coords

    # Fill in the annotations once; an AtomArrayStack shares them across all models
    stack.chain_id = np.array(["A"] * n_atoms)
    stack.res_id = np.array(res_indices)
    stack.res_name = np.array([res_names[i - 1] for i in res_indices])
    stack.atom_name = np.array(atom_names)
    stack.element = np.array([name[0] for name in atom_names])  # N, C, C

    # Save to file
    pdb_file = pdb.PDBFile()
    pdb_file.set_structure(stack)
    pdb_file.write(file_path)
    print(f"Ensemble saved to {file_path}")
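
The per-atom annotation layout that the function builds can be illustrated with the small plain-Python sketch below. The helper `backbone_layout` is hypothetical and not part of diff_ensemble. Its divisibility check is an addition made for this sketch: the listing above does not validate the atom count itself.

```python
def backbone_layout(n_atoms):
    """Return (atom_names, res_ids) for an N-CA-C backbone with n_atoms atoms."""
    if n_atoms % 3 != 0:
        # Added for illustration only; not present in save_ensemble_to_pdb.
        raise ValueError("expected three backbone atoms (N, CA, C) per residue")
    n_res = n_atoms // 3
    atom_names = ["N", "CA", "C"] * n_res
    # Residue numbering is 1-based, with each index repeated once per atom.
    res_ids = [i + 1 for i in range(n_res) for _ in range(3)]
    return atom_names, res_ids


atom_names, res_ids = backbone_layout(6)
# atom_names == ["N", "CA", "C", "N", "CA", "C"]
# res_ids    == [1, 1, 1, 2, 2, 2]
```

An array of shape `(M, 6, 3)` would therefore be written as M models of two residues each, with every residue named ALA unless `res_names` is supplied.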