API Reference
Full auto-generated API documentation for the diff_ensemble package.
Model
Bases: Module
Variational Autoencoder for generating protein structural ensembles.
The model encodes sequence features into a latent Gaussian distribution,
draws ensemble_size samples, and decodes each sample into a set of
backbone torsion angles. Coordinates are obtained by calling
:func:`build_backbone_coords`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seq_len | | Number of residues in the protein. | required |
| latent_dim | | Dimensionality of the latent space. | required |
| ensemble_size | | Number of conformations to sample per forward pass. | required |
| hidden_dim | | Width of the hidden layers in the encoder and decoder. | required |
Source code in diff_ensemble/model.py
__call__(x, rng)
Forward pass: encode → reparameterise → decode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray | | required |
| rng | Any | JAX PRNG key used for latent sampling. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Tuple of |
| ndarray | |
| ndarray | |
| tuple[ndarray, ndarray, ndarray] | |
Source code in diff_ensemble/model.py
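The forward pass draws `ensemble_size` latent samples per input via the reparameterisation trick. Here is a minimal JAX sketch of that step. The function name `sample_latents` and the `(batch, ensemble_size, latent_dim)` output layout are assumptions for illustration, not the package's internals. It also assumes the encoder returns `(mean, logvar)` arrays of shape `(batch, latent_dim)`, matching the (μ, log σ²) pair documented for the encoder.

```python
import jax
import jax.numpy as jnp


def sample_latents(rng, mean, logvar, ensemble_size):
    """Hypothetical helper: draw z = mu + sigma * eps for every ensemble member.

    mean, logvar: (batch, latent_dim) encoder outputs (assumed layout).
    Returns: (batch, ensemble_size, latent_dim).
    """
    std = jnp.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = jax.random.normal(rng, (mean.shape[0], ensemble_size, mean.shape[-1]))
    # Broadcast mean/std over the ensemble axis; eps carries all the randomness.
    return mean[:, None, :] + std[:, None, :] * eps


z = sample_latents(jax.random.PRNGKey(0), jnp.zeros((1, 32)), jnp.zeros((1, 32)), ensemble_size=100)
print(z.shape)  # (1, 100, 32)
```

Because the noise `eps` does not depend on any learned parameter, gradients reach the encoder through `mean` and `std` only. This is what keeps the sampled ensemble differentiable.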
generate_coordinates(torsions)
Convert backbone torsions to N–Cα–C Cartesian coordinates.
Delegates to the module-level :func:`build_backbone_coords` function,
which is safe to call inside or outside a JIT-compiled context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| torsions | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
Source code in diff_ensemble/model.py
Bases: Module
Maps sequence features to latent distribution parameters (μ, log σ²).
Source code in diff_ensemble/model.py
__call__(x)
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| tuple[ndarray, ndarray] | Tuple of |
Source code in diff_ensemble/model.py
Bases: Module
Maps latent samples to protein backbone torsions (φ, ψ).
Source code in diff_ensemble/model.py
__call__(z)
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| z | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
Source code in diff_ensemble/model.py
Convert backbone torsions (φ, ψ) to N–Cα–C Cartesian coordinates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| torsions | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
| ndarray | backbone heavy atoms (N, Cα, C) in each model of the ensemble. |
Source code in diff_ensemble/model.py
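A common way to turn (φ, ψ) torsions into N–Cα–C coordinates is NeRF (Natural Extension Reference Frame). NeRF places each new atom from the three preceding atoms, given a bond length, a bond angle, and a torsion. The sketch below rests on several assumptions and is not necessarily how `build_backbone_coords` is implemented:

- torsions are shaped `(L, 2)` and given in radians;
- bond lengths and angles are fixed at approximately the Engh and Huber ideal values;
- the peptide bond is trans (ω = 180°);
- atoms are ordered N, CA, C per residue.

It uses `jax.lax.scan`, so like the documented function it can run under `jax.jit`. The names `place_atom`, `build_backbone_sketch` and `dihedral` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumed ideal geometry (Å, radians); approximately Engh & Huber values.
B_N_CA, B_CA_C, B_C_N = 1.458, 1.525, 1.329
A_N_CA_C = jnp.deg2rad(111.2)
A_CA_C_N = jnp.deg2rad(116.2)
A_C_N_CA = jnp.deg2rad(121.7)
OMEGA = jnp.pi  # assumed trans peptide bond


def place_atom(a, b, c, bond, angle, torsion):
    """NeRF: return d with |c - d| = bond, angle(b, c, d) = angle, dihedral(a, b, c, d) = torsion."""
    bc = (c - b) / jnp.linalg.norm(c - b)
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)
    m = jnp.stack([bc, jnp.cross(n, bc), n], axis=-1)  # local frame as matrix columns
    d_local = jnp.array([
        -bond * jnp.cos(angle),
        bond * jnp.sin(angle) * jnp.cos(torsion),
        bond * jnp.sin(angle) * jnp.sin(torsion),
    ])
    return c + m @ d_local


def build_backbone_sketch(torsions):
    """torsions: (L, 2) array of (phi, psi) in radians -> (3 * L, 3) N, CA, C coordinates."""
    # Fixed frame for the first residue.
    n0 = jnp.zeros(3)
    ca0 = jnp.array([B_N_CA, 0.0, 0.0])
    c0 = ca0 + B_CA_C * jnp.array([-jnp.cos(A_N_CA_C), jnp.sin(A_N_CA_C), 0.0])

    def step(carry, phi_psi):
        n_prev, ca_prev, c_prev, psi_prev = carry
        n = place_atom(n_prev, ca_prev, c_prev, B_C_N, A_CA_C_N, psi_prev)   # psi of previous residue
        ca = place_atom(ca_prev, c_prev, n, B_N_CA, A_C_N_CA, OMEGA)         # omega
        c = place_atom(c_prev, n, ca, B_CA_C, A_N_CA_C, phi_psi[0])          # phi of this residue
        return (n, ca, c, phi_psi[1]), jnp.stack([n, ca, c])

    _, rest = jax.lax.scan(step, (n0, ca0, c0, torsions[0, 1]), torsions[1:])
    first = jnp.stack([n0, ca0, c0])[None]
    return jnp.concatenate([first, rest], axis=0).reshape(-1, 3)


def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle in radians (IUPAC sign convention), useful for checking the output."""
    b0, b1, b2 = p0 - p1, p2 - p1, p3 - p2
    b1 = b1 / jnp.linalg.norm(b1)
    v = b0 - jnp.dot(b0, b1) * b1
    w = b2 - jnp.dot(b2, b1) * b1
    return jnp.arctan2(jnp.dot(jnp.cross(b1, v), w), jnp.dot(v, w))


helix = jnp.tile(jnp.deg2rad(jnp.array([[-60.0, -45.0]])), (5, 1))
xyz = jax.jit(build_backbone_sketch)(helix)  # (15, 3)
```

In this layout, φ of the first residue and ψ of the last residue do not move any atom, because each needs a neighbouring residue to define it.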
Ensemble Predictor
High-level wrapper for a trained :class:`EnsembleVAE`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | EnsembleVAE | An :class: | required |
| params | dict[str, Any] | Trained Flax parameter dict (from | required |
Example::

    import jax
    import jax.numpy as jnp
    from diff_ensemble import EnsembleVAE, EnsemblePredictor

    model = EnsembleVAE(seq_len=90, latent_dim=32, ensemble_size=100)
    # Use separate keys for initialisation and sampling rather than reusing one key.
    init_rng, sample_rng, predict_rng = jax.random.split(jax.random.PRNGKey(0), 3)
    params = model.init(init_rng, jnp.ones((1, 90, 4)), sample_rng)["params"]
    predictor = EnsemblePredictor(model, params)
    coords = predictor.predict(jnp.ones((1, 90, 4)), predict_rng)
    print(predictor.compute_rg(coords))
Source code in diff_ensemble/ensemble.py
compute_end_to_end_distance(coords, weights=None)
Compute the ensemble-averaged end-to-end distance.
Measures the distance between the N-terminal N atom (index 0) and the C-terminal C atom (last atom) of each backbone model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | | required |
| weights | ndarray \| None | | None |
Returns:
| Type | Description |
|---|---|
| ndarray | |
| ndarray | values (both in Ångströms). |
Source code in diff_ensemble/ensemble.py
compute_population_average(observable_fn, coords, weights=None, *args, **kwargs)
Compute a population-weighted ensemble average of an observable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| observable_fn | Callable[..., ndarray] | A function with signature | required |
| coords | ndarray | | required |
| weights | ndarray \| None | | None |
| *args | Any | Extra positional arguments forwarded to | () |
| **kwargs | Any | Extra keyword arguments forwarded to | {} |
Returns:
| Type | Description |
|---|---|
| ndarray | Observable averaged over the ensemble. |
Source code in diff_ensemble/ensemble.py
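A population-weighted average is a weighted sum over the ensemble axis, with the weights normalised to sum to one. Here is a minimal sketch under an assumption the truncated description above does not confirm: that `observable_fn(coords, *args, **kwargs)` maps the whole `(M, N, 3)` ensemble to per-model values of shape `(M, ...)`. The names `population_average_sketch` and `end_to_end` are hypothetical.

```python
import jax.numpy as jnp


def population_average_sketch(observable_fn, coords, weights=None, *args, **kwargs):
    """Weighted ensemble average; assumes observable_fn returns per-model values (M, ...)."""
    values = observable_fn(coords, *args, **kwargs)
    m = values.shape[0]
    if weights is None:
        weights = jnp.full((m,), 1.0 / m)  # uniform populations
    else:
        weights = jnp.asarray(weights) / jnp.sum(weights)  # normalise populations to sum to 1
    return jnp.tensordot(weights, values, axes=1)  # contract over the ensemble axis


def end_to_end(c, scale=1.0):
    # Per-model distance between the first and last atom, shape (M,).
    return scale * jnp.linalg.norm(c[:, -1] - c[:, 0], axis=-1)


coords = jnp.array([[[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]])  # distances 3 and 5
uniform = population_average_sketch(end_to_end, coords)                                  # 4.0
weighted = population_average_sketch(end_to_end, coords, jnp.array([1.0, 3.0]), scale=2.0)  # 9.0
positional = population_average_sketch(end_to_end, coords, None, 2.0)                    # 8.0
```

Because `*args` follows `weights` in the documented signature, any extra positional argument must be preceded by an explicit `weights` value (or `None`). Otherwise it is silently consumed as the weights.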
compute_rg(coords, weights=None)
Compute the ensemble-averaged radius of gyration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | | required |
| weights | ndarray \| None | | None |
Returns:
| Type | Description |
|---|---|
| ndarray | |
| ndarray | the per-model Rg array (both in Ångströms). |
Source code in diff_ensemble/ensemble.py
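The per-model radius of gyration is the root-mean-square distance of the atoms from their centre. A minimal sketch follows. It assumes equal atomic masses, returns `(weighted mean Rg, per-model Rg)` to mirror the two-array return documented above, and uses the hypothetical name `radius_of_gyration_sketch`.

```python
import jax.numpy as jnp


def radius_of_gyration_sketch(coords, weights=None):
    """coords: (M, N, 3) in Å -> (population-weighted mean Rg, per-model Rg of shape (M,)).

    Every atom is treated as having the same mass (an assumption; no mass weighting).
    """
    centre = jnp.mean(coords, axis=1, keepdims=True)      # (M, 1, 3) geometric centre per model
    sq_dev = jnp.sum((coords - centre) ** 2, axis=-1)     # (M, N) squared distance to centre
    rg = jnp.sqrt(jnp.mean(sq_dev, axis=1))               # (M,)
    if weights is None:
        weights = jnp.full(rg.shape, 1.0 / rg.shape[0])
    else:
        weights = jnp.asarray(weights) / jnp.sum(weights)
    return jnp.sum(weights * rg), rg


coords = jnp.array([[[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                    [[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]])
mean_rg, per_model = radius_of_gyration_sketch(coords)  # 1.5, [1.0, 2.0]
```

This averages Rg itself. Some analyses instead report the square root of the weighted mean of Rg², and the two differ whenever the ensemble is heterogeneous.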
predict(sequence_features, rng, n_samples=None)
Generate a structural ensemble from sequence features.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sequence_features | ndarray | | required |
| rng | Any | JAX PRNG key. | required |
| n_samples | int \| None | If provided, overrides the model's default | None |
Returns:
| Type | Description |
|---|---|
| ndarray | |
| ndarray | backbone atoms (N, Cα, C) in each generated conformation. |
Source code in diff_ensemble/ensemble.py
Observables
Calculate ensemble-averaged SAXS intensity via the Debye formula.
Args:
coords: ``(M, N, 3)`` coordinates where *M* is ensemble size and *N*
is atom count.
q_values: ``(Q,)`` scattering vector magnitudes in Å⁻¹.
form_factors: ``(N, Q)`` atomic form factors.
Returns:
``(Q,)`` ensemble-averaged intensity.
Source code in diff_ensemble/observables.py
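The Debye formula gives, for one model, I(q) = Σ_i Σ_j f_i(q) f_j(q) sin(q r_ij) / (q r_ij), where r_ij is the distance between atoms i and j. The ensemble average is the mean over the M models. Below is a minimal sketch matching the documented argument shapes. The name `debye_saxs_sketch` and the small epsilon inside the square root (added so gradients stay finite on the zero-distance diagonal) are assumptions, not necessarily what the package does.

```python
import jax
import jax.numpy as jnp


def debye_saxs_sketch(coords, q_values, form_factors):
    """coords: (M, N, 3) in Å; q_values: (Q,) in 1/Å; form_factors: (N, Q) -> (Q,) mean intensity."""
    def single_model(x):
        diff = x[:, None, :] - x[None, :, :]                 # (N, N, 3) pairwise displacements
        r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)      # (N, N); epsilon keeps sqrt differentiable
        qr = r[:, :, None] * q_values[None, None, :]         # (N, N, Q)
        sinc = jnp.sinc(qr / jnp.pi)                         # jnp.sinc(x) = sin(pi x)/(pi x), so sin(qr)/qr
        ff = form_factors[:, None, :] * form_factors[None, :, :]  # (N, N, Q): f_i(q) f_j(q)
        return jnp.sum(ff * sinc, axis=(0, 1))               # (Q,)

    return jnp.mean(jax.vmap(single_model)(coords), axis=0)


# Two atoms 1 Å apart with unit form factors: I(0) = 4, I(1) = 2 + 2 sin(1) ≈ 3.683.
pair = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])
profile = debye_saxs_sketch(pair, jnp.array([0.0, 1.0]), jnp.ones((2, 2)))
```

The pairwise tensor costs O(M · N² · Q) memory. For large ensembles or long chains, a `jax.lax.map` over models or a chunked loop over q may be needed.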
Calculate ensemble-averaged Residual Dipolar Couplings (RDCs).
.. note::
This function requires the diff_biophys.nmr.rdc kernel, which is
planned for a future release. Calling it will raise
:exc:`NotImplementedError` until the kernel is available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | | required |
| bond_vectors | ndarray | | required |
| alignment_tensor | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | Until the |
Source code in diff_ensemble/observables.py
Calculate ensemble-averaged backbone chemical shifts.
.. note::
This function requires the diff_biophys.nmr.shifts kernel, which is
planned for a future release. Calling it will raise
:exc:`NotImplementedError` until the kernel is available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | | required |
| sequence | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
| ndarray | (Hα, Hβ, C, Cα, Cβ, N) in ppm. |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | Until the |
Source code in diff_ensemble/observables.py
KL divergence between the encoder posterior and a unit Gaussian prior.
Uses the closed-form expression for Gaussian KL:
-0.5 * Σ (1 + log σ² − μ² − σ²).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | ndarray | | required |
| logvar | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar KL divergence (non-negative). |
Source code in diff_ensemble/observables.py
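Each summand of the closed form, ½(σ² + μ² − 1 − log σ²), is non-negative because e^x ≥ 1 + x. This is why the returned scalar is non-negative. A direct JAX transcription is sketched below. The name `kl_to_unit_gaussian` is hypothetical, and summing over all elements (rather than averaging over the batch) is an assumption taken from the Σ in the formula.

```python
import jax.numpy as jnp


def kl_to_unit_gaussian(mean, logvar):
    """Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over every element."""
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar))


kl_to_unit_gaussian(jnp.zeros((3, 4)), jnp.zeros((3, 4)))    # 0.0: posterior equals the prior
kl_to_unit_gaussian(jnp.array([[1.0]]), jnp.array([[0.0]]))  # 0.5: mean shifted by one unit
```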
Normalised MSE between predicted and experimental SAXS profiles.
Both profiles are normalised by their first data point (so that I(q₀) = 1, where q₀ is the smallest q value and not necessarily q = 0) before comparison. This makes the loss invariant to the absolute intensity scale.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predicted_saxs | ndarray | | required |
| experimental_saxs | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar mean-squared error after normalisation. |
Source code in diff_ensemble/observables.py
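The normalisation step can be sketched in a few lines. The name `normalised_saxs_mse` and the assumption that both inputs are 1-D `(Q,)` profiles sampled on the same q grid are illustrative, not taken from the package source.

```python
import jax.numpy as jnp


def normalised_saxs_mse(predicted_saxs, experimental_saxs):
    """MSE after dividing each (Q,) profile by its own first point."""
    pred = predicted_saxs / predicted_saxs[0]
    expt = experimental_saxs / experimental_saxs[0]
    return jnp.mean((pred - expt) ** 2)


profile = jnp.array([2.0, 1.5, 0.8])
normalised_saxs_mse(5.0 * profile, profile)                       # 0.0: overall scale is ignored
normalised_saxs_mse(jnp.array([1.0, 0.5]), jnp.array([2.0, 2.0]))  # 0.125
```

Since every point is divided by the first one, noise in that single measurement propagates into the whole normalised profile. The first point must also be non-zero.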
Training
Initialise model parameters and return an initial :class:`TrainState`.
Args:
model: An :class:`~diff_ensemble.model.EnsembleVAE` instance.
rng: JAX PRNG key.
learning_rate: Adam learning rate.
input_shape: Shape of a single input batch, e.g. ``(1, seq_len, 4)``.
Returns:
Initialised :class:`TrainState`.
Source code in diff_ensemble/train.py
Perform a single gradient update step.
The PRNG key stored in state is consumed for latent sampling, and the
state is returned with a freshly split key ready for the next step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | TrainState | Current :class: | required |
| batch_x | ndarray | | required |
| exp_saxs | ndarray | | required |
| q_values | ndarray | | required |
| form_factors | ndarray | | required |
| beta | float | KL-divergence weight in the ELBO loss. | 0.1 |
Returns:
| Type | Description |
|---|---|
| tuple[TrainState, ndarray, ndarray, ndarray] | |
Source code in diff_ensemble/train.py
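The update step consumes one PRNG key for latent sampling, hands back a fresh key, and weights the KL term by `beta`. The toy sketch below shows that pattern in plain JAX. Everything in it is hypothetical:

- a linear "encoder" and "decoder" stand in for `EnsembleVAE`;
- an MSE stands in for the SAXS term;
- plain SGD replaces the Adam optimiser mentioned for initialisation;
- the key is threaded explicitly rather than stored in a `TrainState`.

```python
import jax
import jax.numpy as jnp


def toy_elbo(params, x, target, sample_key, beta):
    # Hypothetical linear encoder/decoder standing in for the real model.
    mean, logvar = x @ params["w_mu"], x @ params["w_lv"]
    eps = jax.random.normal(sample_key, mean.shape)
    z = mean + jnp.exp(0.5 * logvar) * eps                   # reparameterisation
    recon = jnp.mean((z @ params["w_dec"] - target) ** 2)    # stand-in for the SAXS loss
    kl = -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar))
    return recon + beta * kl, (recon, kl)


@jax.jit
def toy_train_step(params, key, x, target, beta, lr):
    key, sample_key = jax.random.split(key)  # use one subkey now, keep the other for the next step
    grad_fn = jax.value_and_grad(toy_elbo, has_aux=True)
    (loss, (recon, kl)), grads = grad_fn(params, x, target, sample_key, beta)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)  # plain SGD update
    return params, key, loss, recon, kl


init_key, data_key, key = jax.random.split(jax.random.PRNGKey(0), 3)
k_mu, k_dec = jax.random.split(init_key)
params = {
    "w_mu": 0.1 * jax.random.normal(k_mu, (4, 2)),
    "w_lv": jnp.zeros((4, 2)),
    "w_dec": 0.1 * jax.random.normal(k_dec, (2, 3)),
}
x = jax.random.normal(data_key, (8, 4))
target = jnp.zeros((8, 3))
start_key = key
for _ in range(3):
    params, key, loss, recon, kl = toy_train_step(params, key, x, target, 0.1, 1e-2)
```

Splitting the key inside the jitted step keeps each call pure: the same starting key always reproduces the same sequence of latent samples.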
I/O
Save a structural ensemble to a multi-model PDB file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (M, N, 3) where M is ensemble size and N is atom count. | required |
| file_path | str | Path to the output PDB file. | required |
| res_names | list[str] \| None | List of residue names (length n_res). Defaults to ALA. | None |
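A multi-model PDB file wraps each conformation in `MODEL`/`ENDMDL` records, with atoms written as fixed-width `ATOM` lines (x, y, z in columns 31 to 54). Below is a minimal writer sketch. The function name is hypothetical, and it assumes:

- atoms are ordered N, CA, C per residue, consistent with the N–Cα–C layout documented above;
- every atom is on chain A;
- residues are numbered from 1;
- occupancy is 1.00 and the B-factor is 0.00.

```python
import numpy as np

# (atom name, element) for the assumed N, CA, C ordering within each residue.
BACKBONE = (("N", "N"), ("CA", "C"), ("C", "C"))


def save_pdb_sketch(coords, file_path, res_names=None):
    """Write (M, N, 3) backbone coordinates as a multi-model PDB file."""
    coords = np.asarray(coords)
    n_models, n_atoms, _ = coords.shape
    if n_atoms % 3:
        raise ValueError("expected 3 backbone atoms (N, CA, C) per residue")
    n_res = n_atoms // 3
    res_names = res_names or ["ALA"] * n_res
    with open(file_path, "w") as fh:
        for m in range(n_models):
            fh.write(f"MODEL     {m + 1:4d}\n")
            for a in range(n_atoms):
                name, element = BACKBONE[a % 3]
                x, y, z = coords[m, a]
                # Fixed-width ATOM record: serial, name, resName, chain, resSeq, x, y, z, occ, B, element.
                fh.write(
                    f"ATOM  {a + 1:5d}  {name:<3} {res_names[a // 3]:>3} A{a // 3 + 1:4d}    "
                    f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {element:>2}\n"
                )
            fh.write("ENDMDL\n")
        fh.write("END\n")


ensemble = np.arange(2 * 6 * 3, dtype=float).reshape(2, 6, 3)  # 2 models, 2 residues each
save_pdb_sketch(ensemble, "ensemble_sketch.pdb")
```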