
ML Integration: Data Factory Flow 🤖¶
This notebook demonstrates how to use synth-pdb as a high-speed data factory for training protein AI models.
We leverage the BatchedGenerator to produce thousands of structures in milliseconds and feed them directly into PyTorch and JAX with zero-copy memory handover.
The Data Factory Workflow¶
Traditional structural biology tools are optimized for single-file PDB processing; synth-pdb is optimized for tensor throughput.
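The difference is one of data layout. As an illustration (plain NumPy, not the synth-pdb API), a tensor-first tool stores a whole batch as one (N, L, 3) coordinate array, so a single vectorized operation touches every structure at once:

```python
import numpy as np

# Hypothetical illustration: 1,000 backbones of 50 atoms each,
# held as one (N, L, 3) tensor instead of 1,000 separate PDB files.
n_batch, n_atoms = 1000, 50
coords = np.random.default_rng(0).normal(size=(n_batch, n_atoms, 3))

# One vectorized operation processes all 1,000 structures at once,
# e.g. centering every structure on its own centroid:
centered = coords - coords.mean(axis=1, keepdims=True)

print(centered.shape)                            # (1000, 50, 3)
print(np.allclose(centered.mean(axis=1), 0.0))   # True
```

This is the layout BatchedGenerator produces natively, which is why the handover to ML frameworks later in this notebook is essentially free.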

⚠️ How to Run (Important!)¶
This notebook requires a specific environment setup. Follow these steps strictly:
- Run All Cells (Runtime -> Run all, or Ctrl+F9).
- Wait for the Crash: If on Colab, the setup cell will automatically restart the session to load libraries. This is normal.
- Local Users: If you are running locally after editing the library code, restart your kernel manually to ensure changes take effect.
- Wait 10 Seconds: Allow the session to reconnect.
- Run All Cells AGAIN: This time, the setup will detect it is ready ('✅ Dependencies Ready') and proceed normally.
# @title Setup & Installation { display-mode: "form" }
import os
import sys
from pathlib import Path
# Ensure the local synth_pdb source code is prioritized if running from the repo
try:
    current_path = Path(".").resolve()
    repo_root = current_path.parent.parent
    if (repo_root / "synth_pdb").exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
            print(f"📌 Added local library to path: {repo_root}")
except Exception:
    pass
if 'google.colab' in str(get_ipython()):
    if not os.path.exists("installed.marker"):
        print("Running on Google Colab. Installing dependencies...")
        get_ipython().run_line_magic('pip', 'install synth-pdb py3Dmol')
        with open("installed.marker", "w") as f:
            f.write("done")
        print("🔄 Installation complete. KERNEL RESTARTING AUTOMATICALLY...")
        print("⚠️ Please wait 10 seconds, then Run All Cells again.")
        os.kill(os.getpid(), 9)
    else:
        print("✅ Dependencies Ready.")
else:
    import synth_pdb
    print(f"✅ Running locally. Using synth-pdb version: {synth_pdb.__version__} from {synth_pdb.__file__}")
import time
import matplotlib.pyplot as plt
import numpy as np
from synth_pdb.batch_generator import BatchedGenerator
print("Libraries Loaded.")
1. High-Speed Generation¶
We'll generate a batch of 1,000 peptides of 49 residues each. In a traditional serial loop, this would take significant time. In synth-pdb, it's a single matrix operation.
# Construct a 49-residue sequence (7 residue types repeated 7 times)
residues = ["ALA", "GLY", "SER", "LEU", "VAL", "ILE", "MET"] * 7
sequence = "-".join(residues)
n_batch = 1000
generator = BatchedGenerator(sequence, n_batch=n_batch, full_atom=False)
start = time.time()
batch = generator.generate_batch(drift=5.0)
elapsed = time.time() - start
print(f"Generated {n_batch} structures in {elapsed * 1000:.1f} ms.")
Benchmark: Serial vs. Batched Generation¶
Why use BatchedGenerator? Below we compare the time to generate 1000 structures one-by-one vs. generating them in a single batch.
from synth_pdb.generator import generate_pdb_content
def run_benchmark(n=100):
    start_serial = time.time()
    for _ in range(n):
        _ = generate_pdb_content(sequence_str=sequence, minimize_energy=False)
    serial_dt = time.time() - start_serial

    start_batched = time.time()
    _ = generator.generate_batch(drift=1.0)
    batched_dt = time.time() - start_batched
    return serial_dt, batched_dt
n_test = 100
s_time, b_time = run_benchmark(n_test)
# Normalize both timings to "seconds per 1,000 structures"
s_1k = s_time * (1000 / n_test)
b_1k = b_time * (1000 / n_batch)
plt.figure(figsize=(8, 4))
bars = plt.bar(["Traditional Serial", "synth-pdb Batched"], [s_1k, b_1k], color=["#ff9999", "#667eea"])
plt.ylabel("Seconds per 1,000 Structures")
plt.title("Real-World Performance Comparison")
plt.grid(axis="y", linestyle="--", alpha=0.7)
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, yval + (s_1k*0.02), f"{yval:.3f}s", ha="center", va="bottom", fontweight="bold")
plt.show()
print(f"Vectorization Speedup: {s_1k / b_1k:.1f}x")
print(f"Measured throughput: {1000 / b_1k:.0f} structures/sec")
2. PyTorch Handover (Zero-Copy)¶
PyTorch can "wrap" a NumPy array with torch.from_numpy without copying it: the tensor and the array share the same memory, so any change to one is reflected in the other. Note that a dtype conversion such as .float() allocates a new copy when the array is not already float32.
try:
    import torch

    # torch.from_numpy() shares memory with the NumPy array (zero-copy);
    # .float() only copies if the array is not already float32.
    torch_tensor = torch.from_numpy(batch.coords).float()
    print("✅ PyTorch Handover successful!")
    print(f"Tensor Device: {torch_tensor.device}")
    print(f"Contiguous in memory: {torch_tensor.is_contiguous()}")
except ImportError:
    print("❌ PyTorch not found. Use 'pip install torch' to see this in action.")
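To see the shared-memory behavior in isolation, here is a minimal self-contained sketch (plain NumPy + PyTorch, independent of synth-pdb). The array is created as float32 so no dtype conversion forces a copy:

```python
import numpy as np

try:
    import torch

    arr = np.zeros((2, 3), dtype=np.float32)
    t = torch.from_numpy(arr)   # zero-copy: the tensor wraps the same buffer

    arr[0, 0] = 42.0            # mutate the NumPy side...
    print(t[0, 0].item())       # ...the tensor sees it: 42.0

    t[1, 2] = -1.0              # mutate the tensor side...
    print(arr[1, 2])            # ...NumPy sees it: -1.0
except ImportError:
    print("PyTorch not installed; skipping the shared-memory demo.")
```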
3. JAX / MLX Handover¶
JAX can also ingest the NumPy batch directly. Note that, unlike torch.from_numpy, jnp.array copies the data onto the JAX device.
try:
    import jax.numpy as jnp
    jax_array = jnp.array(batch.coords)
    print("✅ JAX Handover successful!")
    print(f"JAX Device: {jax_array.device}")
except ImportError:
    print("❌ JAX not found.")
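Because jnp.array copies the host buffer onto the device, later edits to the NumPy array are invisible to JAX. A minimal standalone check (independent of synth-pdb):

```python
import numpy as np

try:
    import jax.numpy as jnp

    arr = np.zeros(3, dtype=np.float32)
    jarr = jnp.array(arr)    # copies the host buffer onto the JAX device

    arr[0] = 99.0            # later NumPy edits are NOT visible in JAX
    print(float(jarr[0]))    # 0.0 — the device buffer kept its own copy
except ImportError:
    print("JAX not installed; skipping.")
```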
4. Educational Note: Why does this matter?¶
In deep learning for proteins, the Data Loading step is often the bottleneck. If your GPU has to wait for Python loops to calculate coordinates, it sits idle.
By using BatchedGenerator, you can:
- Keep generation on the CPU/AMX units while the GPU trains.
- Avoid expensive serialized PDB parsing.
- Feed thousands of "Hard Decoys" (structures with noise) to help your model learn the energy landscape.
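To make the pattern concrete, here is a toy sketch of that producer/consumer loop. `make_batch` is a hypothetical stand-in (not the synth-pdb API) that adds Gaussian "decoy" noise to a reference backbone on the CPU; in a real pipeline the resulting array would be wrapped with torch.from_numpy and consumed by the training step:

```python
import numpy as np

rng = np.random.default_rng(0)

def make_batch(base, drift=1.0, n=64):
    """Stand-in for a batched generator: reference coords + Gaussian decoy noise."""
    noise = rng.normal(scale=drift, size=(n,) + base.shape)
    return base[None, ...] + noise   # shape (n, L, 3)

base = np.zeros((50, 3))             # one reference backbone (toy placeholder)
for step in range(3):                # CPU generates while the GPU would train
    batch = make_batch(base, drift=2.0)
    # ... hand `batch` to torch.from_numpy(batch) and run a training step ...
    print(step, batch.shape)         # each step yields a fresh (64, 50, 3) batch
```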
5. Visualizing the Data: Structural Ensembles¶
In ML, we often want to train on "Hard Decoys"—structures that are mostly correct but have physical noise. BatchedGenerator can produce these ensembles instantly.
plt.figure(figsize=(10, 6))
for i in range(10):
    plt.plot(batch.coords[i, :, 0], batch.coords[i, :, 1], alpha=0.3, label="Sampled models" if i == 0 else None)
plt.title("Ensemble Drift: Structural Noise for ML Training")
plt.xlabel("X (Å)")
plt.ylabel("Y (Å)")
plt.legend()
plt.show()
Interactive 3D Inspection¶
Use 3Dmol.js to inspect a sample structure from the batch.
try:
    import py3Dmol
    from synth_pdb.batch_generator import BatchedPeptide

    # Take one structure from the batch, drop padded (all-zero) rows,
    # and re-center it for the viewer
    c = batch.coords[0].copy()
    mask = np.any(c != 0, axis=1)
    c_clean = c[mask]
    center = (c_clean.min(axis=0) + c_clean.max(axis=0)) / 2
    c_centered = c_clean - center

    p = BatchedPeptide(
        c_centered[np.newaxis, ...],
        batch.sequence,
        np.array(batch.atom_names)[mask].tolist(),
        np.array(batch.residue_indices)[mask].tolist(),
    )

    view = py3Dmol.view(width=800, height=400)
    view.setBackgroundColor("#fdfdfd")
    view.addModel(p.to_pdb(0), "pdb")
    view.setStyle({"stick": {"radius": 0.15}, "cartoon": {"color": "spectrum"}})
    view.zoomTo()
    view.center()
    view.zoom(1.2)
    view.show()
    print(f"Viewer Ready. Visualizing {len(c_clean)} atoms.")
except ImportError:
    print("py3Dmol not installed.")