
🧠 Multi-Modal AI Research: The Structural-Magnetic Bridge ⚛️
Objective: Learn how to generate synchronized datasets of 3D Structural Tensors and Experimental observables (NMR Chemical Shifts) for Multi-Modal AI training.
🌟 The Vision: "AlphaFold-NMR"
In modern structural biology, 3D coordinates are only half the story. Experimental verification often comes from Nuclear Magnetic Resonance (NMR). NMR chemical shifts are highly sensitive to the local electronic environment, so each nucleus's resonance frequency acts as a "fingerprint" of the local geometry.
In this lab, we build an end-to-end pipeline that treats the protein as both a Geometric Object and a Magnetic Observable. This data is used to train models that can:
- Back-Calculate: Predict NMR shifts from structure.
- De-Novo Solve: Predict structure directly from chemical shifts.
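Much of the "local geometry" that chemical shifts fingerprint is captured by the backbone torsion angles phi and psi. Shift predictors (the SPARTA family, for example) use them as core inputs. The minimal NumPy sketch below computes both angles from per-residue N, CA and C coordinate arrays.

It assumes one row per residue, in sequence order. That layout is chosen for illustration and is not the synth-pdb batch format. biotite, installed below, also provides biotite.structure.dihedral_backbone for AtomArray inputs.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed torsion angle (degrees) defined by four points, in [-180, 180]."""
    b0 = p0 - p1
    b1 = (p2 - p1) / np.linalg.norm(p2 - p1)  # unit vector along the central bond
    b2 = p3 - p2
    # Project both outer bonds onto the plane perpendicular to the central bond
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return float(np.degrees(np.arctan2(y, x)))

def backbone_phi_psi(n, ca, c):
    """phi/psi per residue from (L, 3) arrays of N, CA, C coordinates.

    ASSUMPTION: rows are residues in sequence order. phi is undefined for
    the first residue and psi for the last, so those entries are NaN.
    """
    n, ca, c = (np.asarray(a, dtype=float) for a in (n, ca, c))
    length = len(ca)
    phi = np.full(length, np.nan)
    psi = np.full(length, np.nan)
    for i in range(length):
        if i > 0:
            phi[i] = dihedral(c[i - 1], n[i], ca[i], c[i])
        if i < length - 1:
            psi[i] = dihedral(n[i], ca[i], c[i], n[i + 1])
    return phi, psi
```

The sign convention follows IUPAC: looking down the central bond, a clockwise rotation from the front bond to the back bond is positive.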
# @title Setup & Installation { display-mode: "form" }
import os
import sys
from pathlib import Path
try:
    current_path = Path(".").resolve()
    repo_root = current_path.parent.parent
    if (repo_root / "synth_pdb").exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
            print(f"📌 Added local library to path: {repo_root}")
except Exception:
    pass
# Detect Colab via sys.modules (works even where get_ipython is undefined)
if 'google.colab' in sys.modules:
    if not os.path.exists("installed.marker"):
        print("Running on Google Colab. Installing dependencies...")
        get_ipython().run_line_magic('pip', 'install synth-pdb torch numpy matplotlib py3Dmol biotite')
        with open("installed.marker", "w") as f:
            f.write("done")
        print("🔄 Installation complete. KERNEL RESTARTING AUTOMATICALLY...")
        os.kill(os.getpid(), 9)  # Kill the kernel so Colab restarts it with the new packages
    else:
        print("✅ Dependencies Ready.")
else:
    import synth_pdb
    print(f"✅ Running locally. Using synth-pdb version: {synth_pdb.__version__}")
import matplotlib.pyplot as plt
import py3Dmol
import torch
from torch.utils.data import DataLoader, Dataset
from synth_pdb.batch_generator import BatchedGenerator
from synth_pdb.chemical_shifts import predict_chemical_shifts
print("Magnetic Resonance Engine: ONLINE ⚡")
1. Synchronized Generation: The Coords-Shift Tensor
We will generate a batch of structures with significant structural drift, then predict the resulting chemical shifts (for the first five samples in this demo). This creates a paired dataset: (X, Y) = (Coordinates, NMR Shifts).
# Join the repeats with '-' so the last residue of one repeat (ASP) and the first of the next (TRP) do not fuse into 'ASPTRP'
sequence = "-".join(["TRP-PHE-TYR-HIS-LYS-GLU-ASP"] * 3)  # 21 residues, rich in aromatics
n_samples = 100
print(f"🚀 Generating {n_samples} synchronized multi-modal structural samples...")
generator = BatchedGenerator(sequence, n_batch=n_samples, full_atom=True)
batch = generator.generate_batch(drift=2.0)
print("✅ Structural Tensors Generated.")
print("⚡ Predicting Chemical Shifts (SPARTA-Lite + Ring Currents)...")
from io import StringIO
import biotite.structure.io.pdb as pdb_io
all_shifts = []
for i in range(5):  # Analyze the first 5 samples in detail for the demo
    # Convert the batch member to a biotite AtomArray for the NMR engine
    pdb_str = batch.to_pdb(i)
    struct = pdb_io.PDBFile.read(StringIO(pdb_str)).get_structure(model=1)
    shifts = predict_chemical_shifts(struct)
    all_shifts.append(shifts)
print(f"✅ Paired Data Ready. Sample 0 Chain A Res 1 chemical shifts: {all_shifts[0]['A'][1]}")
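The indexing all_shifts[0]['A'][1] suggests that each prediction is a nested dictionary of the form {chain: {residue_id: {atom_name: ppm}}}. That structure is inferred from the cell above rather than taken from synth-pdb documentation.

Under that assumption, the sketch below flattens one prediction into a dense (residues x atoms) array. It fills missing atoms with NaN and returns a boolean mask, so gaps stay distinguishable from real values. The helper name and default atom list are hypothetical.

```python
import numpy as np

def shifts_to_array(shifts, chain="A", atoms=("CA", "HA", "CB", "N", "H")):
    """Flatten {chain: {res_id: {atom: ppm}}} into an (n_res, n_atoms) array.

    ASSUMPTION: the nested-dict layout is inferred from the notebook's indexing.
    Returns the sorted residue ids, the value array (NaN = not predicted),
    and a boolean mask that is True where a shift exists.
    """
    residues = shifts.get(chain, {})
    res_ids = sorted(residues)
    values = np.full((len(res_ids), len(atoms)), np.nan)
    for row, res_id in enumerate(res_ids):
        for col, atom in enumerate(atoms):
            if atom in residues[res_id]:
                values[row, col] = residues[res_id][atom]
    mask = ~np.isnan(values)
    return res_ids, values, mask
```

In the notebook this would be called as shifts_to_array(all_shifts[0]). Using NaN plus a mask, rather than a 0.0 placeholder, prevents a missing atom from masquerading as a real (if implausible) shift during training.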
2. Fold Recognition: The CSI Plot
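The Chemical Shift Index (CSI) is derived from the secondary shift, Δδ = δ(observed) - δ(random coil): the deviation of a nucleus's shift from its residue-specific "random coil" baseline. Each residue is then scored +1, 0, or -1 according to whether Δδ lies above, within, or below a threshold band (commonly ±0.7 ppm for C-alpha).
- Alpha Helices: Move C-alpha shifts Downfield (positive Δδ, higher ppm).
- Beta Sheets: Move C-alpha shifts Upfield (negative Δδ, lower ppm).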
Let's visualize this footprint for our first sample.
# Extract CA secondary shifts for Sample 0
sample_idx = 0
res_ids = sorted(all_shifts[sample_idx]['A'].keys())
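# Secondary CA shift against a single generic Ala baseline (52.5 ppm), for visualization only.
# Non-Ala residues have higher random-coil CA values, so this baseline biases every bar upward.
ca_deltas = [all_shifts[sample_idx]['A'][r].get('CA', float('nan')) - 52.5 for r in res_ids]  # NaN = no CA shift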
plt.figure(figsize=(12, 4))
plt.bar(res_ids, ca_deltas, color='#9b59b6', alpha=0.7, label="Delta-CA (Secondary Shift)")
plt.axhline(0.7, color='red', linestyle='--', alpha=0.3, label="Helix Threshold")
plt.axhline(-0.7, color='blue', linestyle='--', alpha=0.3, label="Sheet Threshold")
plt.title("The Magnetic Footprint of Protein Folding")
plt.xlabel("Residue Number")
plt.ylabel("Secondary Shift Δδ(CA) (ppm)")
plt.legend()
plt.grid(alpha=0.2)
plt.show()
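print("Educational Insight: Runs of consecutive residues above +0.7 ppm suggest helix; runs below -0.7 ppm suggest strand. "
      "With the single Ala baseline used here, part of the positive offset is an artifact.")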
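The plot above subtracts one Ala baseline from every residue. A proper CSI instead subtracts a residue-specific random-coil value and then looks for runs of consecutive +1 or -1 scores.

The sketch below does both. The random-coil Cα numbers are approximate placeholders for the residues in this sequence, rounded from commonly used tables. Check them against a published random-coil table before any real analysis.

The run-length rule (at least 4 consecutive +1 for a helix, at least 3 consecutive -1 for a strand) is a simplification loosely modeled on the original CSI method. The function names are hypothetical.

```python
import numpy as np

# ASSUMPTION: approximate random-coil C-alpha shifts (ppm); placeholders only.
RANDOM_COIL_CA = {
    "ALA": 52.5, "ASP": 54.2, "GLU": 56.6, "HIS": 55.0,
    "LYS": 56.2, "PHE": 57.7, "TRP": 57.5, "TYR": 57.9,
}

def csi_ca(res_names, ca_shifts, threshold=0.7):
    """Residue-specific secondary shifts and the ternary CSI (+1 / 0 / -1)."""
    baseline = np.array([RANDOM_COIL_CA[name] for name in res_names])
    delta = np.asarray(ca_shifts, dtype=float) - baseline
    index = np.where(delta > threshold, 1, np.where(delta < -threshold, -1, 0))
    return delta, index

def runs(index, value, min_len):
    """(start, end) positions of runs of `value` that are at least `min_len` long."""
    found, start = [], None
    for i, v in enumerate(list(index) + [None]):  # sentinel closes a trailing run
        if v == value and start is None:
            start = i
        elif v != value and start is not None:
            if i - start >= min_len:
                found.append((start, i - 1))
            start = None
    return found

# Toy, made-up CA shifts for one 7-residue repeat (illustration only)
names = ["TRP", "PHE", "TYR", "HIS", "LYS", "GLU", "ASP"]
toy_ca = [60.5, 60.9, 61.0, 58.5, 57.0, 56.2, 52.8]
delta, index = csi_ca(names, toy_ca)
helices = runs(index, 1, min_len=4)
strands = runs(index, -1, min_len=3)
```

For these toy values the index is [1, 1, 1, 1, 1, 0, -1]. That yields one helix call spanning positions 0 to 4 and no strand call.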
3. Visualizing Ring Current Effects (Tertiary Proximity)
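Aromatic rings (Phe, Tyr, Trp) act like tiny electromagnets. The applied field drives a ring current whose induced field opposes it above and below the ring face and reinforces it in the ring plane.

Nuclei positioned over the ring face are therefore shielded and shift upfield (to lower ppm), while nuclei near the ring edge are deshielded and shift downfield. Because the effect falls off steeply with distance, it reports on which residues pack against the aromatics. This is how NMR "sees" tertiary packing.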
view = py3Dmol.view(width=800, height=400)
view.setBackgroundColor("#fdfdfd")
pdb_str = batch.to_pdb(0)
# 1. Highlight the Aromatic Rings
view.addModel(pdb_str, 'pdb')
view.setStyle({'model': 0}, {'cartoon': {'color': '#667eea', 'opacity': 0.6}})
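# addStyle layers sticks on top of the cartoon; setStyle would replace the cartoon for these residues
view.addStyle({'resn': ['PHE', 'TYR', 'TRP']}, {'stick': {'radius': 0.25, 'color': '#ffcc00'}})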
# 2. Show the "Magnetic Cloud"
# We'll put a surface around aromatics to visualize the 'Influence Zone'
view.addSurface(py3Dmol.MS, {'opacity': 0.2, 'color': '#ffcc00'}, {'resn': ['PHE', 'TYR', 'TRP']})
view.zoomTo()
view.center()
view.show()
print("The translucent yellow surfaces outline the aromatic side chains, whose ring currents perturb the shifts of nearby nuclei.")
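The shielding geometry described above can be captured by the simplest ring-current model, a point magnetic dipole at the ring center: Δδ = B (1 - 3cos²θ) / r³. Here r is the distance from the ring center to the nucleus, θ is the angle between that vector and the ring normal, and B is an intensity factor.

This is a swapped-in textbook model, not necessarily what synth-pdb's predictor implements. More accurate models such as Johnson-Bovey or Haigh-Mallion exist. The function name and the unit default for intensity are hypothetical.

```python
import numpy as np

def ring_current_shift(nucleus, ring_atoms, intensity=1.0):
    """Point-dipole estimate of a ring-current shift (ppm).

    delta = intensity * (1 - 3 cos^2(theta)) / r^3
    Negative = shielded (upfield, above/below the ring face);
    positive = deshielded (downfield, in the ring plane).
    ASSUMPTION: `intensity` is an illustrative scale in ppm * Angstrom^3.
    """
    ring = np.asarray(ring_atoms, dtype=float)
    center = ring.mean(axis=0)
    # Ring normal = direction of least variance of the ring atoms (last SVD row)
    _, _, vt = np.linalg.svd(ring - center)
    normal = vt[-1]
    r_vec = np.asarray(nucleus, dtype=float) - center
    r = np.linalg.norm(r_vec)
    cos_theta = np.dot(r_vec, normal) / r  # sign of the normal cancels in cos^2
    return intensity * (1.0 - 3.0 * cos_theta**2) / r**3

# Idealized benzene-like hexagon (C-C ~1.39 Angstrom) in the xy-plane
hexagon = [(1.39 * np.cos(a), 1.39 * np.sin(a), 0.0) for a in np.arange(6) * np.pi / 3]
above = ring_current_shift((0.0, 0.0, 3.0), hexagon)     # over the ring face
in_plane = ring_current_shift((3.0, 0.0, 0.0), hexagon)  # beside the ring edge
```

At equal distance, a nucleus over the face is shifted twice as strongly (and in the opposite direction) as one in the plane. The 1/r³ falloff makes the effect a sensitive proxy for tertiary proximity.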
4. Multi-Modal PyTorch Pipeline
Finally, we combine both signals into a single PyTorch DataLoader. Every sample is a tuple of (Geometry, NMR).
class MultiModalProteinDataset(Dataset):
    def __init__(self, coords, shifts_list):
        # Accept a NumPy array or a tensor of shape (n_samples, n_atoms, 3)
        self.coords = torch.as_tensor(coords, dtype=torch.float32)
        # Tensorize the 'CA' and 'HA' shifts as per-residue features
        n_samples = len(shifts_list)
        # Size the residue axis from the shift dictionaries: coords is indexed by atom,
        # so its length is not a reliable residue count.
        n_res = max((len(s.get('A', {})) for s in shifts_list), default=0)
        self.nmr_features = torch.zeros((n_samples, n_res, 2))  # [CA_shift, HA_shift]; 0.0 = missing
        for i in range(n_samples):
            # Only use chain A for the demo
            if 'A' not in shifts_list[i]:
                continue
            s = shifts_list[i]['A']
            for r_idx, r_id in enumerate(sorted(s.keys())):
                self.nmr_features[i, r_idx, 0] = s[r_id].get('CA', 0.0)
                self.nmr_features[i, r_idx, 1] = s[r_id].get('HA', 0.0)

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        return self.coords[idx], self.nmr_features[idx]
# Create the synchronized dataset
ds = MultiModalProteinDataset(batch.coords[:5], all_shifts)
loader = DataLoader(ds, batch_size=2, shuffle=True)
batch_coords, batch_nmr = next(iter(loader))
print("✅ Multi-Modal Batch Data Ready.")
print(f"Geometry Shape: {batch_coords.shape}")
print(f"NMR Tensor Shape: {batch_nmr.shape} (per-residue [CA, HA] shifts, e.g. as Transformer input)")
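Every sample here shares one 21-residue sequence, so the default collate function can stack the tensors directly. If you mix sequences of different lengths, torch.stack will fail on mismatched shapes.

The sketch below is a hypothetical collate function for that case. It pads each modality to the longest member of the batch with torch.nn.utils.rnn.pad_sequence and returns a boolean mask that is True at real residues. It assumes each dataset item is a (coords, nmr) pair of 2D tensors, as in the dataset above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    """Pad variable-length (coords, nmr) pairs; mask is True at real residues."""
    coords, nmr = zip(*batch)
    coords_pad = pad_sequence(coords, batch_first=True)  # (B, max_atoms, 3)
    nmr_pad = pad_sequence(nmr, batch_first=True)        # (B, max_res, n_features)
    lengths = torch.tensor([x.shape[0] for x in nmr])
    mask = torch.arange(nmr_pad.shape[1])[None, :] < lengths[:, None]  # (B, max_res)
    return coords_pad, nmr_pad, mask

# Toy samples of different sizes (random placeholders, not real structures)
toy = [(torch.randn(80, 3), torch.randn(20, 2)),
       (torch.randn(60, 3), torch.randn(15, 2))]
loader_padded = DataLoader(toy, batch_size=2, collate_fn=pad_collate)
c_pad, n_pad, valid = next(iter(loader_padded))
```

Padding uses zeros, so downstream losses and attention should consult the mask rather than treat padded positions as data.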
🏆 Next Steps
- Predicting Reality: Try generating structures with --conformation beta and see how the CSI Plot flips! 📉
- Transformer Training: Feed the batch_nmr tensor into a 1D Transformer to see if it can recover secondary structure labels.
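For the Transformer exercise, the sketch below shows one possible starting point: a small per-residue classifier built from torch.nn.TransformerEncoder. The architecture, sizes, class labels (helix/strand/coil) and training step are all hypothetical choices, and the inputs and labels are random placeholders rather than notebook data.

Raw [CA, HA] features differ in scale by roughly an order of magnitude (about 55 ppm versus about 4.5 ppm). Standardizing each channel, or feeding secondary shifts, is advisable before training. The optional pad_mask follows PyTorch's convention of True at padded positions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ShiftToSS(nn.Module):
    """Hypothetical per-residue classifier: shift features -> SS-class logits."""

    def __init__(self, n_features=2, d_model=64, n_heads=4, n_layers=2,
                 n_classes=3, max_len=512):
        super().__init__()
        self.embed = nn.Linear(n_features, d_model)  # project shifts to model width
        self.pos = nn.Embedding(max_len, d_model)     # learned positional embedding
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=128,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x, pad_mask=None):
        # x: (B, L, n_features); pad_mask: (B, L) bool, True at PADDED positions
        positions = torch.arange(x.shape[1], device=x.device)
        h = self.embed(x) + self.pos(positions)[None, :, :]
        h = self.encoder(h, src_key_padding_mask=pad_mask)
        return self.head(h)  # (B, L, n_classes) logits

# One training step on random placeholder data (NOT real shifts or labels)
torch.manual_seed(0)
x = torch.randn(4, 21, 2)         # stands in for standardized [CA, HA] features
y = torch.randint(0, 3, (4, 21))  # placeholder labels: 0=helix, 1=strand, 2=coil
model = ShiftToSS()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, 3), y.reshape(-1))
opt.zero_grad()
loss.backward()
opt.step()
```

Real labels could come from the conformation used to generate each structure. A positional embedding is included because self-attention alone is order-agnostic, and residue order matters for secondary structure.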
You are now generating the same type of data used to train the next generation of experimental AI solvers. The lab is yours. 🧬🤖