Skip to content

Getting Started

Installation

pip install diff-ensemble

For GPU/TPU acceleration, install the appropriate JAX backend first:

# CUDA 12 (recommended for NVIDIA GPUs)
pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install diff-ensemble

Quick Start

import jax
import jax.numpy as jnp
from diff_ensemble import EnsembleVAE, save_ensemble_to_pdb

# 1. Build the model (90 residues, 32 latent dims, 100-member ensemble)
model = EnsembleVAE(seq_len=90, latent_dim=32, ensemble_size=100)
rng   = jax.random.PRNGKey(0)

# 2. Initialise parameters (one-time cost)
params = model.init(rng, jnp.ones((1, 90, 4)), rng)["params"]

# 3. Generate a structural ensemble from sequence features
sequence_features = jnp.ones((1, 90, 4))   # replace with real PSSM/one-hot
torsions, mean, logvar = model.apply({"params": params}, sequence_features, rng)

# 4. Convert torsions to 3D coordinates (N–Cα–C backbone)
coords = model.generate_coordinates(torsions)  # (100, 270, 3)

# 5. Export to a multi-model PDB for visualisation in PyMOL / VMD
save_ensemble_to_pdb(coords, "my_ensemble.pdb")

Using the High-Level Predictor

:class:`~diff_ensemble.EnsemblePredictor` wraps trained parameters and provides convenience methods for computing the radius of gyration (Rg), the end-to-end distance, and population-weighted observable averages:

from diff_ensemble import EnsemblePredictor

predictor = EnsemblePredictor(model, params)

# Generate coordinates
coords = predictor.predict(sequence_features, rng)

# Structural statistics
avg_rg, rg_per_model = predictor.compute_rg(coords)
print(f"Ensemble Rg = {avg_rg:.1f} Å")

avg_ree, _ = predictor.compute_end_to_end_distance(coords)
print(f"End-to-end distance = {avg_ree:.1f} Å")

Training from Experimental Data

from diff_ensemble.train import create_train_state, train_step
import jax.numpy as jnp
import numpy as np

# Load your experimental SAXS profile
data      = np.loadtxt("my_protein_saxs.dat", comments="#")
q_values  = jnp.array(data[:, 0])
exp_saxs  = jnp.array(data[:, 1])

# Dummy form factors (replace with real atomic scattering factors)
n_atoms      = 3 * 90   # N, Cα, C for 90 residues
form_factors = jnp.ones((n_atoms, len(q_values)))

# Set up training
state = create_train_state(model, rng, learning_rate=1e-4, input_shape=(1, 90, 4))

for step in range(1000):
    state, loss, bio, kl = train_step(
        state, sequence_features, exp_saxs, q_values, form_factors, beta=0.1
    )
    if step % 100 == 0:
        print(f"Step {step:4d} | total={loss:.4f} | SAXS={bio:.4f} | KL={kl:.4f}")