
score — GNN Quality Scoring API

The synth_pdb.score module provides a single-import, zero-configuration interface for scoring protein structures using the bundled graph neural network (GNN) quality classifier, which is based on a Graph Attention Network. It is the recommended entry point for all quality scoring tasks.

Installation

pip install synth-pdb[gnn]    # installs torch + torch_geometric
The synth_pdb.score module can be imported without PyTorch installed — the dependency is only checked when a scoring function is actually called.
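
The deferred dependency check described above can be sketched roughly as follows. This is a generic illustration of the pattern, not the module's actual code: `require_module` and `score_something` are hypothetical names, and the standard-library `json` module stands in for torch so the sketch runs anywhere.

```python
import importlib


def require_module(name: str, extra: str = "gnn"):
    """Import `name` on demand, or raise ImportError with an install hint.

    Hypothetical helper for illustration; not part of synth_pdb.
    """
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"{name!r} is required for scoring. "
            f"Install it with: pip install synth-pdb[{extra}]"
        ) from exc


def score_something(pdb_text: str) -> int:
    # The dependency is resolved here, at call time, so importing the
    # enclosing module never fails just because the dependency is absent.
    json = require_module("json")  # stand-in for torch in this sketch
    return len(json.dumps(pdb_text))
```

With this arrangement a missing dependency only surfaces when a scoring call is made, and the error message can point at the right install command.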


Quick Start

from synth_pdb.score import score_structure, score_batch

# Score a PDB file by path
result = score_structure("my_helix.pdb")
print(f"Global quality: {result.global_score:.3f}  ({result.label})")
# Global quality: 0.999  (High Quality)

# Inspect per-residue pLDDT confidence
for i, (score, label) in enumerate(zip(result.per_residue, result.residue_labels)):
    print(f"  Residue {i+1:3d}: {score:.3f}  [{label}]")
# Residue   1: 0.958  [Very High]
# Residue   2: 0.960  [Very High]
# ...

# Score a batch efficiently (model loaded once)
results = score_batch(["helix.pdb", "strand.pdb", "decoy.pdb"])
best = max(results, key=lambda r: r.global_score)
print(f"Best structure: global_score={best.global_score:.3f}")

score_structure()

def score_structure(
    source: str | os.PathLike,
    *,
    model_path: str | None = None,
) -> QualityScore

Score a single protein structure and return a rich QualityScore object.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| source | str or path-like | Either a path to a PDB file or a raw PDB-format string. A string whose first non-whitespace characters are a PDB record keyword (ATOM, REMARK, HEADER, MODEL) is treated as inline PDB content; anything else is treated as a file path. |
| model_path | str, optional | Path to a custom .pt checkpoint. Defaults to the bundled gnn_quality_v2.pt (with per-residue head). Falls back to gnn_quality_v1.pt if v2 is unavailable. |
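
The path-versus-content rule for source can be sketched as below. The prefix tuple mirrors the check in the score_structure() source listed later on this page; `classify_source` and `PDB_RECORD_PREFIXES` are hypothetical names used only for this illustration.

```python
import os

# Record keywords that mark a string as inline PDB content.
PDB_RECORD_PREFIXES = ("ATOM", "REMARK", "HEADER", "MODEL")


def classify_source(source) -> str:
    """Return "content" for inline PDB text and "path" for everything else."""
    text = source if isinstance(source, str) else os.fspath(source)
    if text.strip().startswith(PDB_RECORD_PREFIXES):
        return "content"
    return "path"
```

One consequence of a prefix-based rule is that a relative file name beginning with one of the keywords, such as "MODEL_1.pdb", would be read as PDB content. Passing "./MODEL_1.pdb" or an absolute path avoids the ambiguity.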

Returns

A QualityScore dataclass.

Raises

| Exception | Condition |
| --- | --- |
| FileNotFoundError | source looks like a file path but the file does not exist. |
| ImportError | torch or torch_geometric are not installed. |
| ValueError | The PDB contains fewer than 2 residues with Cα atoms. |

Examples

from synth_pdb.score import score_structure

# From a file path
result = score_structure("/data/structures/ubiquitin.pdb")

# From an inline PDB string
with open("ubiquitin.pdb") as fh:
    pdb_string = fh.read()
result = score_structure(pdb_string)

# Using a custom checkpoint
result = score_structure("ubiquitin.pdb", model_path="my_retrained_gnn.pt")

score_batch()

def score_batch(
    sources: list[str | os.PathLike],
    *,
    model_path: str | None = None,
) -> list[QualityScore]

Score a list of structures efficiently. The GNN model is loaded once and reused for all structures — significantly faster than calling score_structure() in a loop for large collections.

If any individual structure fails (e.g. unparseable PDB, too few residues), a sentinel QualityScore with global_score=NaN and label="Error" is inserted at the corresponding index, so the output list always has the same length as the input list.
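
Because failed items carry global_score=NaN, it is usually worth separating them out before ranking: comparisons against NaN are always false, so sorted() and max() can return misleading orderings. A minimal sketch, using a stand-in dataclass (`FakeScore`, hypothetical) with only the two fields needed here, so it runs without synth_pdb installed:

```python
import math
from dataclasses import dataclass


@dataclass
class FakeScore:
    """Stand-in for QualityScore with just the fields this sketch uses."""
    global_score: float
    label: str


def split_failures(paths, results):
    """Partition (path, result) pairs into successful and failed items."""
    ok, failed = [], []
    for path, result in zip(paths, results):
        target = failed if math.isnan(result.global_score) else ok
        target.append((path, result))
    return ok, failed


paths = ["a.pdb", "b.pdb", "c.pdb"]
results = [
    FakeScore(0.91, "High Quality"),
    FakeScore(float("nan"), "Error"),  # sentinel for a failed input
    FakeScore(0.42, "Low Quality"),
]

ok, failed = split_failures(paths, results)
ranked = sorted(ok, key=lambda item: item[1].global_score, reverse=True)
```

Here ranked holds a.pdb followed by c.pdb, and failed holds b.pdb, which can be logged or retried separately.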

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| sources | list[str \| PathLike] | Mixed list of file paths and/or PDB strings. |
| model_path | str, optional | Custom checkpoint path. |

Returns

list[QualityScore] — one result per input, in the same order.

Example

import glob
from synth_pdb.score import score_batch

pdb_files = sorted(glob.glob("alphafold_predictions/*.pdb"))
results = score_batch(pdb_files)

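# Rank by global quality score, skipping failed entries (label == "Error", score NaN)
scored = [(p, r) for p, r in zip(pdb_files, results) if r.label != "Error"]
ranked = sorted(scored, key=lambda x: x[1].global_score, reverse=True)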
for path, r in ranked[:5]:
    print(f"{path:50s}  {r.global_score:.4f}  {r.label}")

QualityScore

@dataclass
class QualityScore:
    global_score:    float
    label:           str
    per_residue:     list[float]
    residue_labels:  list[str]
    features:        dict[str, float]
    n_residues:      int

Returned by score_structure(), score_batch(), and GNNQualityClassifier.score().

Fields

| Field | Type | Description |
| --- | --- | --- |
| global_score | float ∈ [0, 1] | P(Good): probability the structure is biophysically plausible. Values > 0.5 are classified as High Quality. |
| label | str | "High Quality" or "Low Quality". |
| per_residue | list[float] | Per-residue pLDDT-like confidence ∈ [0, 1]. Length equals n_residues. Analogous to AlphaFold's per-residue pLDDT. |
| residue_labels | list[str] | Human-readable confidence band for each residue. |
| features | dict[str, float] | Mean per-feature summary of the GNN input graph (useful for debugging). Keys: sin_phi, cos_phi, sin_psi, cos_psi, b_factor_norm, seq_position, is_n_terminus, is_c_terminus. |
| n_residues | int | Number of residues with Cα atoms in the PDB. |
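
The sin_phi / cos_phi style keys suggest that backbone dihedral angles are encoded as sine and cosine pairs. The sketch below shows why that encoding is useful and how a per-structure mean could be formed. It is an assumption based on the key names only; the actual feature extraction is not shown on this page, and `encode_dihedral` and `mean_phi_features` are hypothetical helpers.

```python
import math


def encode_dihedral(angle_deg: float) -> tuple[float, float]:
    """Map an angle in degrees to (sin, cos), which is continuous across +/-180."""
    rad = math.radians(angle_deg)
    return math.sin(rad), math.cos(rad)


def mean_phi_features(phis_deg: list[float]) -> dict[str, float]:
    """Average the encoded phi angles over all residues."""
    pairs = [encode_dihedral(a) for a in phis_deg]
    n = len(pairs)
    return {
        "sin_phi": sum(s for s, _ in pairs) / n,
        "cos_phi": sum(c for _, c in pairs) / n,
    }
```

With this encoding, -180 and +180 degrees map to (numerically) the same point, so a model never sees an artificial jump at the wrap-around. For a helix-like stretch with phi near -60 degrees, the mean sin_phi is close to -0.87 and the mean cos_phi close to 0.5.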

pLDDT Confidence Bands

| Label | Score range | Interpretation (AlphaFold equivalent) |
| --- | --- | --- |
| "Very High" | ≥ 0.90 | Backbone and side-chain likely accurate |
| "High" | ≥ 0.70 and < 0.90 | Backbone likely accurate |
| "Uncertain" | ≥ 0.50 and < 0.70 | Use with caution |
| "Low" | < 0.50 | Likely disordered or incorrect geometry |

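The band thresholds above translate directly into a small lookup. This is a sketch of the mapping as documented here (lower bounds inclusive, per the QualityScore docstring), not the library's own implementation; `confidence_band` is a hypothetical name.

```python
def confidence_band(score: float) -> str:
    """Map a per-residue confidence in [0, 1] to its documented band label."""
    if score >= 0.90:
        return "Very High"
    if score >= 0.70:
        return "High"
    if score >= 0.50:
        return "Uncertain"
    return "Low"


# Example: label a short list of per-residue scores.
labels = [confidence_band(s) for s in (0.958, 0.72, 0.50, 0.31)]
```

For the four scores shown, labels is ["Very High", "High", "Uncertain", "Low"].
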
Example Usage

result = score_structure("my_protein.pdb")

# Find low-confidence regions
low_conf = [
    i + 1  # 1-indexed residue number
    for i, label in enumerate(result.residue_labels)
    if label in ("Uncertain", "Low")
]
print(f"Low-confidence residues: {low_conf}")

# Check mean pLDDT
import numpy as np
mean_plddt = np.mean(result.per_residue)
print(f"Mean pLDDT: {mean_plddt:.3f}")

# Export to pandas for downstream analysis
import pandas as pd
df = pd.DataFrame({
    "residue": range(1, result.n_residues + 1),
    "plddt": result.per_residue,
    "band": result.residue_labels,
})
df.to_csv("plddt_per_residue.csv", index=False)

GNNQualityClassifier

The lower-level class underlying score_structure(). Import it when you need direct control over checkpoint loading or want to access the predict() method for backward compatibility.

from synth_pdb.quality import GNNQualityClassifier

clf = GNNQualityClassifier()                     # auto-loads bundled weights
clf = GNNQualityClassifier(model_path="v2.pt")   # explicit checkpoint

Methods

score(pdb_content: str) → QualityScore

The primary method. Equivalent to score_structure(pdb_content) but requires a PDB string (not a file path).

predict(pdb_content: str) → (bool, float, dict)

Legacy method for backward compatibility with the ProteinQualityClassifier (RF) API. Returns (is_good, probability, features_dict).
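
Since predict() returns the same (is_good, probability, features_dict) tuple for both classifiers, calling code can be written against that shape alone. A minimal sketch using typing.Protocol; `QualityPredictor`, `ThresholdStub`, and `keep_if_good` are hypothetical names for illustration, and the stub performs no real scoring.

```python
from typing import Protocol


class QualityPredictor(Protocol):
    """Structural type: anything with a compatible predict() method."""

    def predict(self, pdb_content: str) -> tuple[bool, float, dict[str, float]]:
        ...


class ThresholdStub:
    """Toy classifier that returns a fixed probability (illustration only)."""

    def __init__(self, probability: float) -> None:
        self.probability = probability

    def predict(self, pdb_content: str) -> tuple[bool, float, dict[str, float]]:
        return self.probability > 0.5, self.probability, {}


def keep_if_good(clf: QualityPredictor, pdb_content: str) -> bool:
    """Filter step that works with any classifier exposing predict()."""
    is_good, _probability, _features = clf.predict(pdb_content)
    return is_good
```

In real use the stub would be replaced by a GNNQualityClassifier or ProteinQualityClassifier instance; keep_if_good itself would not need to change.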

save(path: str) → None

Save model weights and architecture metadata to a .pt checkpoint.

load(path: str) → None

Load a checkpoint. The architecture (node features, hidden dim, etc.) is read from the checkpoint itself — no configuration file needed.


Retraining the Model

To retrain gnn_quality_v2.pt from scratch (e.g. after modifying the architecture or adding training data):

python scripts/train_gnn_quality_filter.py \
    --n-samples 200 \
    --epochs 50 \
    --output synth_pdb/quality/models/gnn_quality_v2.pt

The training script generates 200 synthetic structures across four classes (Good / Random / Distorted / Clashing) and trains with a joint objective:

\[\mathcal{L} = \mathcal{L}_{\text{NLL}} + \lambda \cdot \mathcal{L}_{\text{MSE}}\]

where \(\mathcal{L}_{\text{NLL}}\) is the global binary classification loss, \(\mathcal{L}_{\text{MSE}}\) is the per-residue Ramachandran Z-score regression loss, and \(\lambda = 0.3\) by default.
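
As a worked illustration of the joint objective for a single structure, the plain-Python sketch below evaluates both terms by hand. It assumes the model emits log-probabilities for the two classes and one regression value per residue; `joint_loss` and the input values are hypothetical, and the real training script may aggregate over batches differently.

```python
import math


def joint_loss(log_probs, target_class, residue_pred, residue_target, lam=0.3):
    """Compute L = L_NLL + lam * L_MSE for one structure."""
    # Negative log-likelihood of the true class.
    nll = -log_probs[target_class]
    # Mean squared error over per-residue regression targets.
    mse = sum((p - t) ** 2 for p, t in zip(residue_pred, residue_target)) / len(
        residue_pred
    )
    return nll + lam * mse


# P(Bad) = 0.2, P(Good) = 0.8; the true class is Good (index 1).
log_probs = [math.log(0.2), math.log(0.8)]
loss = joint_loss(log_probs, 1, [0.9, 0.5], [1.0, 0.7])
```

Here the NLL term is -ln(0.8) ≈ 0.2231, the MSE term is (0.01 + 0.04) / 2 = 0.025, and the total is approximately 0.2231 + 0.3 × 0.025 ≈ 0.2306.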


Full API Reference

score

synth_pdb.score

Top-level, user-facing API for protein structure quality scoring.

This module provides a single-import interface to the GNN quality scorer, hiding all internal complexity behind two clean functions:

from synth_pdb.score import score_structure, score_batch

# Score a single structure from a file path or PDB string
result = score_structure("my_protein.pdb")
print(f"Global quality: {result.global_score:.3f}  ({result.label})")
print(f"Per-residue pLDDT: {result.per_residue}")

# Score a batch (processes all structures in one call)
results = score_batch(["prot1.pdb", "prot2.pdb", "prot3.pdb"])
top = sorted(results, key=lambda r: r.global_score, reverse=True)

Design goals

  1. Zero-friction: works out of the box with the bundled pre-trained weights; no training required.
  2. Single import: from synth_pdb.score import score_structure is all a user ever needs.
  3. Lazy loading: PyTorch and torch_geometric are imported only when a scoring function is called - users without GPU/torch can still import the rest of synth_pdb.
  4. Backward compatibility: the underlying GNNQualityClassifier is still available for users who need lower-level access.

Requirements

pip install synth-pdb[gnn]

or equivalently:

pip install torch torch_geometric

Classes

QualityScore dataclass

Rich quality assessment result for a single protein structure.

This object serves as the Data Transfer Object (DTO) between the internal GNN inference engine and the end-user. It encapsulates not just a binary "Good/Bad" label, but a high-resolution map of the structure's physical confidence.

pLDDT: The Standard of Confidence

The per_residue scores are modeled after the predicted Local Distance Difference Test (pLDDT), the primary confidence metric used by AlphaFold. Values in [0, 1] represent the model's certainty that a residue is in its physically correct local environment.

Attributes

global_score : float
    The "Whole-Protein" probability P(Good) in [0, 1]. This is the output of the GNN's global pooling layer followed by a log-softmax.
    * 0.9 - 1.0: Extremely confident, well-folded model.
    * 0.5 - 0.9: Likely valid but may have minor local strains.
    * < 0.5: "Low Quality" - likely contains unphysical geometry.
label : str
    A human-readable categorical label ("High Quality" or "Low Quality") derived from the 0.5 global_score threshold.
per_residue : list[float]
    The "Confidence Heatmap". Each float represents the pLDDT of an individual residue. Length equals the number of residues (Calpha atoms). This is generated by the auxiliary regression head in v2 models.
residue_labels : list[str]
    The AlphaFold-standardized categorical bands:
    * "Very High" (>= 0.90): Crystallographic-quality geometry.
    * "High" (>= 0.70): Generally reliable backbone.
    * "Uncertain" (>= 0.50): Low-confidence loop or linker.
    * "Low" (< 0.50): Unphysical/Clashing region.
features : dict[str, float]
    A dictionary of the mean input node features (sin_phi, cos_phi, etc.). This is provided for Explainable AI (XAI) - it helps researchers understand if a low score was triggered by bad dihedrals or high B-factors.
n_residues : int
    The total number of nodes (amino acids) in the interaction graph.

Examples

clf = GNNQualityClassifier()
result = clf.score(pdb_string)
print(f"Protein quality: {result.label} ({result.global_score:.1%})")

# Identify local errors
clashes = [i for i, s in enumerate(result.per_residue) if s < 0.5]
print(f"Detected {len(clashes)} problematic residues.")

Source code in synth_pdb/quality/gnn/gnn_classifier.py
@dataclass
class QualityScore:
    """Rich quality assessment result for a single protein structure.

    This object serves as the **Data Transfer Object (DTO)** between the
    internal GNN inference engine and the end-user. It encapsulates not just
    a binary "Good/Bad" label, but a high-resolution map of the structure's
    physical confidence.

    -- pLDDT: The Standard of Confidence -------------------------------------
    The `per_residue` scores are modeled after the **predicted Local Distance
    Difference Test (pLDDT)**, the primary confidence metric used by AlphaFold.
    Values in [0, 1] represent the model's certainty that a residue is in its
    physically correct local environment.

    Attributes
    ----------
    global_score : float
        The "Whole-Protein" probability P(Good) in [0, 1]. This is the output
         of the GNN's global pooling layer followed by a log-softmax.
         * 0.9 - 1.0: Extremely confident, well-folded model.
         * 0.5 - 0.9: Likely valid but may have minor local strains.
         * < 0.5    : "Low Quality" - likely contains unphysical geometry.
    label : str
        A human-readable categorical label ("High Quality" or "Low Quality")
        derived from the 0.5 global_score threshold.
    per_residue : list[float]
        The "Confidence Heatmap". Each float represents the pLDDT of an
        individual residue. Length equals the number of residues (Calpha atoms).
        This is generated by the auxiliary regression head in v2 models.
    residue_labels : list[str]
        The AlphaFold-standardized categorical bands:
        * "Very High" (>= 0.90) : Crystallographic-quality geometry.
        * "High"      (>= 0.70) : Generally reliable backbone.
        * "Uncertain" (>= 0.50) : Low-confidence loop or linker.
        * "Low"       (< 0.50) : Unphysical/Clashing region.
    features : dict[str, float]
        A dictionary of the mean input node features (sin_phi, cos_phi, etc.).
        This is provided for **Explainable AI (XAI)** - it helps researchers
        understand if a low score was triggered by bad dihedrals or high B-factors.
    n_residues : int
        The total number of nodes (amino acids) in the interaction graph.

    Examples
    --------
    >>> clf = GNNQualityClassifier()
    >>> result = clf.score(pdb_string)
    >>> print(f"Protein quality: {result.label} ({result.global_score:.1%})")
    >>> # Identify local errors
    >>> clashes = [i for i, s in enumerate(result.per_residue) if s < 0.5]
    >>> print(f"Detected {len(clashes)} problematic residues.")
    """

    global_score: float
    label: str
    per_residue: list[float] = field(default_factory=list)
    residue_labels: list[str] = field(default_factory=list)
    features: dict[str, float] = field(default_factory=dict)
    n_residues: int = 0

Functions

score_structure(source, *, model_path=None)

Score a single protein structure and return a rich quality assessment.

Parameters

source : str or path-like
    Either:
    * A file path (typically ending in .pdb) - the file is read automatically.
    * A raw PDB-format string (must start with ATOM, REMARK, HEADER, or MODEL).
model_path : str, optional
    Path to a custom .pt checkpoint. Defaults to the bundled pre-trained weights (gnn_quality_v2.pt).

Returns

QualityScore
    Dataclass with:
    * global_score (float in [0, 1])
    * label ("High Quality" / "Low Quality")
    * per_residue (list of per-residue pLDDT scores in [0, 1])
    * residue_labels (list of "Very High"/"High"/"Uncertain"/"Low")
    * n_residues (int)
    * features (dict of mean input features, for debugging)

Raises

FileNotFoundError
    If source looks like a file path but does not exist.
ImportError
    If torch or torch_geometric are not installed.
ValueError
    If the PDB content has fewer than 2 residues with Calpha atoms.

Examples

from synth_pdb.score import score_structure

result = score_structure("outputs/my_helix.pdb")
print(result.global_score, result.label)
# 0.987 High Quality

low = [i for i, lbl in enumerate(result.residue_labels) if lbl == "Low"]
print(f"Low-confidence residues: {low}")
# Low-confidence residues: []

Source code in synth_pdb/score.py
def score_structure(
    source: Union[str, "os.PathLike[str]"],
    *,
    model_path: str | None = None,
) -> "QualityScore":  # type: ignore[name-defined]
    """Score a single protein structure and return a rich quality assessment.

    Parameters
    ----------
    source : str or path-like
        Either:
        * A file path (typically ending in ``.pdb``) - the file is read automatically.
        * A raw PDB-format string (must start with ``ATOM``, ``REMARK``, ``HEADER``, or ``MODEL``).
    model_path : str, optional
        Path to a custom ``.pt`` checkpoint.  Defaults to the bundled
        pre-trained weights (``gnn_quality_v2.pt``).

    Returns
    -------
    QualityScore
        Dataclass with:
        ``global_score`` (float in [0, 1]),
        ``label`` ("High Quality" / "Low Quality"),
        ``per_residue`` (list of per-residue pLDDT scores in [0, 1]),
        ``residue_labels`` (list of "Very High"/"High"/"Uncertain"/"Low"),
        ``n_residues`` (int),
        ``features`` (dict of mean input features, for debugging).

    Raises
    ------
    FileNotFoundError
        If *source* looks like a file path but does not exist.
    ImportError
        If ``torch`` or ``torch_geometric`` are not installed.
    ValueError
        If the PDB content has fewer than 2 residues with Calpha atoms.

    Examples
    --------
    >>> from synth_pdb.score import score_structure
    >>> result = score_structure("outputs/my_helix.pdb")
    >>> print(result.global_score, result.label)
    0.987 High Quality
    >>> low = [i for i, lbl in enumerate(result.residue_labels) if lbl == "Low"]
    >>> print(f"Low-confidence residues: {low}")
    Low-confidence residues: []

    """
    # -- Resolve source -> PDB string ------------------------------------
    source_str = os.fspath(source) if not isinstance(source, str) else source

    if source_str.strip().startswith(("ATOM", "REMARK", "HEADER", "MODEL")):
        # Treat as inline PDB content
        pdb_content = source_str
    else:
        # Treat as file path
        path = os.path.expanduser(source_str)
        if not os.path.exists(path):
            raise FileNotFoundError(f"PDB file not found: {path}")
        with open(path) as fh:
            pdb_content = fh.read()

    # -- Load classifier (singleton) ------------------------------------
    if model_path:
        from synth_pdb.quality.gnn.gnn_classifier import GNNQualityClassifier

        clf = GNNQualityClassifier(model_path=model_path)
    else:
        clf = _get_classifier()

    return clf.score(pdb_content)

score_batch(sources, *, model_path=None)

Score a list of protein structures efficiently using a shared model.

The GNN weights are loaded once and reused for all structures, making this significantly faster than calling score_structure() in a loop when the list is large (avoids repeated checkpoint loading).

Parameters

sources : list of str or path-like
    Each element is either a file path ending in .pdb or a raw PDB string. Mixed lists are accepted.
model_path : str, optional
    Path to a custom .pt checkpoint.

Returns

list[QualityScore]
    One result per input, in the same order as sources.

Examples

from synth_pdb.score import score_batch

paths = ["helix.pdb", "strand.pdb", "random.pdb"]
results = score_batch(paths)
top = max(results, key=lambda r: r.global_score)
print(f"Best: {top.global_score:.3f}  ({top.label})")

Source code in synth_pdb/score.py
def score_batch(
    sources: list[Union[str, "os.PathLike[str]"]],
    *,
    model_path: str | None = None,
) -> "list[QualityScore]":  # type: ignore[name-defined]
    """Score a list of protein structures efficiently using a shared model.

    The GNN weights are loaded once and reused for all structures, making
    this significantly faster than calling ``score_structure()`` in a loop
    when the list is large (avoids repeated checkpoint loading).

    Parameters
    ----------
    sources : list of str or path-like
        Each element is either a file path ending in ``.pdb`` or a raw PDB
        string.  Mixed lists are accepted.
    model_path : str, optional
        Path to a custom ``.pt`` checkpoint.

    Returns
    -------
    list[QualityScore]
        One result per input, in the same order as *sources*.

    Examples
    --------
    >>> from synth_pdb.score import score_batch
    >>> paths = ["helix.pdb", "strand.pdb", "random.pdb"]
    >>> results = score_batch(paths)
    >>> top = max(results, key=lambda r: r.global_score)
    >>> print(f"Best: {top.global_score:.3f}  ({top.label})")

    """
    if model_path:
        from synth_pdb.quality.gnn.gnn_classifier import GNNQualityClassifier

        clf = GNNQualityClassifier(model_path=model_path)
    else:
        clf = _get_classifier()

    results = []
    for i, source in enumerate(sources):
        try:
            source_str = os.fspath(source) if not isinstance(source, str) else source

            if source_str.strip().startswith(("ATOM", "REMARK", "HEADER", "MODEL")):
                pdb_content = source_str
            else:
                path = os.path.expanduser(source_str)
                if not os.path.exists(path):
                    raise FileNotFoundError(f"PDB file not found: {path}")
                with open(path) as fh:
                    pdb_content = fh.read()

            results.append(clf.score(pdb_content))
        except Exception as exc:
            logger.warning("score_batch: failed on item %d (%s): %s", i, source, exc)
            # Return a sentinel score so the list length always matches the input
            from synth_pdb.quality.gnn.gnn_classifier import QualityScore

            results.append(
                QualityScore(
                    global_score=float("nan"),
                    label="Error",
                    per_residue=[],
                    residue_labels=[],
                    features={},
                    n_residues=0,
                )
            )

    return results

gnn_classifier

synth_pdb.quality.gnn.gnn_classifier

GNN-based protein structure quality classifier with global and per-residue outputs.


DESIGN CONTRACT - Same API as ProteinQualityClassifier (RF)

Both classifiers expose:

predict(pdb_str) -> (is_good: bool, probability: float, features: dict)

This lets downstream code swap between the RF and GNN model without changes:

from synth_pdb.quality.classifier    import ProteinQualityClassifier   # RF
from synth_pdb.quality.gnn.gnn_classifier import GNNQualityClassifier  # GNN

clf = GNNQualityClassifier()       # or ProteinQualityClassifier()
is_good, prob, feats = clf.predict(pdb_string)

Additionally, GNNQualityClassifier exposes a richer API:

result = clf.score(pdb_string)
# result.global_score       -> float in [0, 1]  (P(Good))
# result.per_residue        -> list[float]      (pLDDT per residue)
# result.residue_labels     -> list[str]        ("Very High"/"High"/"Uncertain"/"Low")
# result.label              -> str              ("High Quality" / "Low Quality")

CHECKPOINT FORMAT (.pt)

GNN weights are saved with torch.save() as a dict:

{
  "state_dict"   : OrderedDict of parameter tensors,
  "node_features": int,    <- architecture metadata
  "edge_features": int,
  "hidden_dim"   : int,
  "num_classes"  : int,
}

We store architecture metadata alongside weights so the model can be re-instantiated without any external configuration file. This is the standard pattern for "self-describing" PyTorch checkpoints.
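
A loader for such a self-describing checkpoint would typically verify the metadata before rebuilding the model. The sketch below checks the five documented keys on an already-loaded dict, so it needs no torch import. `validate_checkpoint` and the example values are hypothetical; the real load() method may validate differently or not at all.

```python
# Keys documented in the checkpoint format above, with the expected type of each.
REQUIRED_KEYS = {
    "state_dict": dict,  # OrderedDict is a dict subclass, so it passes too
    "node_features": int,
    "edge_features": int,
    "hidden_dim": int,
    "num_classes": int,
}


def validate_checkpoint(ckpt: dict) -> dict:
    """Check required keys and types; return the architecture metadata."""
    missing = [key for key in REQUIRED_KEYS if key not in ckpt]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    for key, expected in REQUIRED_KEYS.items():
        if not isinstance(ckpt[key], expected):
            raise TypeError(f"{key!r} should be {expected.__name__}")
    return {key: ckpt[key] for key in REQUIRED_KEYS if key != "state_dict"}


# Illustrative values only; the bundled checkpoints may use different sizes.
example = {
    "state_dict": {},
    "node_features": 8,
    "edge_features": 2,
    "hidden_dim": 64,
    "num_classes": 2,
}
arch = validate_checkpoint(example)
```

The returned metadata is what a loader would pass to the model constructor before applying state_dict.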


HIGH-THROUGHPUT AUDITING - Vectorized Ensemble Scoring

For large-scale structural genomics or generative AI tasks, scoring structures individually is prohibitively slow. Each structure requires a Python function call, PDB serialization, and a GPU kernel launch.

The GNNQualityClassifier solves this via the score_batch method. By accepting a BatchedPeptide object, the classifier can process thousands of structures in a single massive GPU operation. This "vectorized auditing" allows synth-pdb to act as a real-time quality filter for high-diversity secondary structure ensembles.

Classes

GNNQualityClassifier

GNN-based protein structure quality classifier.

Predicts whether a PDB structure is "High Quality" (biophysically plausible, good Ramachandran geometry, no steric clashes) or "Low Quality".

This classifier implements the Graph Neural Network (GNN) philosophy: instead of relying on manual features like 'clash counts', it learns to recognise local and global patterns of structural discordance directly from the residue interaction graph.

When is a GNN better than a Random Forest?

The RF classifier uses hand-crafted, per-structure summary statistics (e.g. "fraction of residues in favoured Ramachandran regions").

The GNN works directly on the full residue interaction graph, so it can:

  * Learn WHICH specific contacts are problematic, not just aggregate counts
  * Capture spatial patterns (e.g. a single clashing i/i+4 contact pair)
  * Generalise to protein classes or contact patterns not seen in training (because the pattern recogniser is learned, not hand-engineered)

Performance Trade-offs

  * RF: Extremely fast (~0.1 ms/sample), zero dependencies, best for simple screening.
  * GNN: Richer signal (~0.3 ms/sample), requires torch, best for deep quality assessment.


Source code in synth_pdb/quality/gnn/gnn_classifier.py
class GNNQualityClassifier:
    """GNN-based protein structure quality classifier.

    Predicts whether a PDB structure is "High Quality" (biophysically plausible,
    good Ramachandran geometry, no steric clashes) or "Low Quality".

    This classifier implements the Graph Neural Network (GNN) philosophy:
    instead of relying on manual features like 'clash counts', it learns to
    recognise local and global patterns of structural discordance directly
    from the residue interaction graph.

    -- When is a GNN better than a Random Forest? -------------------------
    The RF classifier uses hand-crafted, per-structure summary statistics
    (e.g. "fraction of residues in favoured Ramachandran regions").

    The GNN works directly on the full residue interaction graph, so it can:
      * Learn WHICH specific contacts are problematic, not just aggregate counts
      * Capture spatial patterns (e.g. a single clashing i/i+4 contact pair)
      * Generalise to protein classes or contact patterns not seen in training
        (because the pattern recogniser is learned, not hand-engineered)

    -- Performance Trade-offs ---------------------------------------------
    * RF:  Extremely fast (~0.1 ms/sample), zero dependencies, best for simple screening.
    * GNN: Richer signal (~0.3 ms/sample), requires torch, best for deep quality assessment.
    -----------------------------------------------------------------------
    """

    def __init__(self, model_path: str | None = None):
        """Initialise the GNN quality classifier.

        The classifier follows a **'Lazy Loading'** pattern where the heavy GNN
        architecture and weights are only loaded when explicitly requested or
        during first inference. This prevents synth-pdb from crashing on
        systems where `torch` is not installed, as long as the user doesn't
        try to use the GNN features.

        -- Model Versioning ---------------------------------------------------
        * **v2 Models**: Include an auxiliary regression head for pLDDT.
        * **v1 Models**: Legacy global-only classification.
        -----------------------------------------------------------------------

        Args:
            model_path: Path to a .pt checkpoint written by GNNQualityClassifier.save().
                        If None, looks for the default bundled checkpoint (v2 first, v1 fallback).
                        If no checkpoint is found, initialises a random-weight model
                        (useful for testing graph construction without training).

        """
        # Internal model storage (Lazy loaded via self.load() or _init_fresh_model())
        self.model: Any | None = None

        # Keep track of where we loaded the weights from for provenance and auditing.
        self._model_path: str | None = None

        # Track if the model supports per-residue confidence (v2) or just global (v1).
        # v2 models have an auxiliary regression head that predicts local confidence.
        self._has_residue_head: bool = False

        # Track if weights were successfully loaded from a checkpoint.
        # This is used by unit tests to determine if accuracy assertions
        # are valid or if the model is in a random state.
        self._is_pretrained: bool = False

        if model_path:
            # User provided an explicit path - load it or die.
            # This allows researchers to use their own trained checkpoints (e.g. robust_final.pt).
            self.load(model_path)
        else:
            # Search for bundled pre-trained weights in the installation folder.
            # v2 (per-residue head) is the modern standard for synth-pdb.
            v2 = os.path.normpath(_DEFAULT_CHECKPOINT_V2)
            v1 = os.path.normpath(_DEFAULT_CHECKPOINT_V1)

            if os.path.exists(v2):
                # Standard case: load the best available model
                self.load(v2)
            elif os.path.exists(v1):
                # Fallback to legacy v1 model if v2 is missing (e.g. partial install)
                logger.info(
                    "v2 checkpoint not found, loading v1 (no per-residue pLDDT). "
                    "Run scripts/train_gnn_quality_filter.py to produce v2."
                )
                self.load(v1)
            else:
                # No weights found in models/ - this usually happens during
                # fresh development or minimal installs without weight assets.
                logger.info(
                    "No pre-trained GNN checkpoint found. "
                    "Classifier initialised with random weights. "
                    "Run scripts/train_gnn_quality_filter.py to train."
                )
                self._init_fresh_model()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def predict(self, pdb_content: str) -> tuple[bool, float, dict[str, float]]:
        """Predict the quality of a PDB structure (Legacy Protocol).

        This method satisfies the `ProteinQualityClassifier` protocol,
        allowing the GNN to act as a **drop-in replacement** for the
        Random Forest (RF) classifier used in the core synth-pdb `Validator`.

        -- The API Bridge -----------------------------------------------
        By maintaining this exact signature, we enable "Polymorphic Quality":
        a generation script can accept ANY classifier instance and use it
        without knowing if it's a fast RF model or a deep GNN.

        -- Internal Pipeline --------------------------------------------
        1. **PDB Parsing**: Extract 3D Calpha coordinates from ATOM records.
        2. **Graph Construction**: Build the spatial interaction graph (graph.py).
        3. **Message Passing**: Propagate geometric data across the graph edges.
        4. **Global Pooling**: Average node embeddings into a whole-protein vector.
        5. **Softmax Output**: Produce final probability P(Good).
        -----------------------------------------------------------------

        Args:
            pdb_content: Raw PDB string containing the structural model.

        Returns:
            is_good (bool): True if P(Good) exceeds the 0.5 decision threshold.
            probability (float): P(Good), a probability in [0, 1].
            features (dict): Summary statistics of the input geometric features.

        Raises:
            ImportError: If the 'gnn' optional dependencies are missing.
            ValueError: If the PDB contains fewer than 2 valid residues.

        """
        # score() handles the heavy lifting; predict() wraps it for the legacy API.
        # This ensures that both the categorical and rich assessment APIs
        # use the exact same weights and graph-building logic.
        result = self.score(pdb_content)
        # We unpack the rich result object into the (bool, float, dict) format
        # expected by the standard synth-pdb Validator interface.
        return (result.label == "High Quality"), result.global_score, result.features

    def score(self, pdb_content: str) -> QualityScore:
        """Score a PDB structure, returning a rich :class:`QualityScore` object.

        This is the **modern, rich API** for the quality system. While
        `predict()` returns a `(bool, float, dict)` tuple, `score()` returns
        the full pLDDT confidence map, allowing for fine-grained structural
        auditing.

        -- The Power of the v2 Model ------------------------------------------
        When a v2 checkpoint is loaded, this method activates the auxiliary
        **Per-Residue Head**. In addition to the whole-protein verdict, it
        returns a confidence value per residue, which helps localize
        problem regions such as a strained loop or a clashing residue.
        With a v1 checkpoint, the per-residue lists are empty.

        Args:
            pdb_content: PDB-format string representing the protein.

        Returns:
            A :class:`QualityScore` instance with global and residue-level metrics.

        Raises:
            ImportError: If PyTorch or PyG are not installed.
            ValueError: If the PDB has insufficient atoms to form a graph.

        """
        # PyTorch and PyG are heavy dependencies (~500MB). We only import
        # them inside this method to allow synth-pdb to maintain its
        # "Lean Core" philosophy. Users who only want PDB generation
        # don't need to install the deep learning stack.
        try:
            import torch
        except ImportError as exc:
            # Provide an actionable error message for the optional [gnn] extra.
            raise ImportError(
                "torch is required for GNNQualityClassifier. "
                "Install with: pip install synth-pdb[gnn]"
            ) from exc

        # Lazy runtime imports for PyTorch Geometric (PyG) utilities.
        # Batch is the container used to pack graphs into high-performance tensors.
        from torch_geometric.data import Batch
        from .graph import build_protein_graph

        # Step 1: Spatial Graph Reconstruction
        # ---------------------------------------------------------------------
        # Convert the raw PDB string into a Topological Interaction Graph.
        # Nodes = Residues, Edges = Contacts < 10 A.
        graph = build_protein_graph(pdb_content)

        # Step 2: Batch Packaging
        # ---------------------------------------------------------------------
        # Even for one structure, we wrap the graph in a Batch object.
        # This is the standard input format for the graph attention (GAT) layers,
        # as it adds the 'batch' vector of shape [Nodes] that maps each node
        # to the graph it belongs to (needed for global pooling).
        batch = Batch.from_data_list([graph])

        # Step 3: Global Model Audit
        # ---------------------------------------------------------------------
        # Verify the weights were successfully loaded from the checkpoint.
        assert self.model is not None, "Model not loaded"

        # Evaluation Mode: disables Dropout and makes normalization layers
        # (e.g. BatchNorm) use their stored running statistics. This is
        # essential for reproducible results.
        self.model.eval()

        # Inference Optimization: Disable the autograd engine (grad tracking).
        # This reduces memory use and overhead because no back-propagation
        # graph is built.
        with torch.no_grad():
            if self._has_residue_head:
                # DUAL-OUTPUT MODE (v2 Architecture)
                # We retrieve both the pooled global log-probs and the
                # un-pooled node confidence map (pLDDT).
                log_probs, per_res_tensor = type(self.model).forward_with_node_embeddings(
                    self.model,
                    batch.x,
                    batch.edge_index,
                    batch.edge_attr,
                    batch.batch,
                )
                # Move tensor back to standard Python floats for the user.
                # Squeeze removes the redundant unit dimension: [Nodes, 1] -> [Nodes].
                per_residue = per_res_tensor.squeeze(-1).tolist()
            else:
                # LEGACY MODE (v1 Architecture)
                # Only the global whole-protein classification is available.
                # per_residue is returned as an empty list.
                log_probs = self.model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                per_residue = []

        # Step 4: Final Probability Calculation
        # ---------------------------------------------------------------------
        # The raw output is in log-softmax space, (-inf, 0].
        # We exponentiate to get linear probabilities in [0, 1].
        # Index 1 corresponds to the 'Good' (Physically Valid) class.
        prob_good = float(log_probs.exp()[0, 1].item())

        # Categorize the prediction using the 0.5 decision boundary.
        label = "High Quality" if prob_good > 0.5 else "Low Quality"

        # Map the numeric pLDDT confidence floats to the AlphaFold color bands.
        # This provides immediate structural insight (e.g. blue = good, orange = bad).
        residue_labels = [_plddt_label(s) for s in per_residue]

        # Step 5: Feature Introspection (Explainability)
        # ---------------------------------------------------------------------
        # We extract the mean of each input node feature (e.g., mean phi angle).
        # This dictionary is returned so researchers can inspect the average
        # physical signals (inputs) that the GNN was given for this structure.
        node_feats = graph.x.numpy()
        feat_dict = {
            name: float(np.mean(node_feats[:, i])) for i, name in enumerate(_FEATURE_NAMES)
        }

        # Step 6: Package Result
        # ---------------------------------------------------------------------
        return QualityScore(
            global_score=prob_good,
            label=label,
            per_residue=per_residue,
            residue_labels=residue_labels,
            features=feat_dict,
            n_residues=graph.num_nodes,
        )

    def score_batch(self, batch: Any) -> list[QualityScore]:
        r"""Score an entire batch of structures in a single vectorized pass.

        -----------------------------------------------------------------------------
        VECTORIZED ENSEMBLE AUDITING
        -----------------------------------------------------------------------------
        While the standard ``score()`` method processes a single PDB string,
        this method is optimized for the **high-throughput generation pipeline**.
        It operates directly on the coordinate tensors of a ``BatchedPeptide``,
        eliminating the CPU-bound bottlenecks of string serialization.

        -- GPU Parallelism ----------------------------------------------------------
        Instead of iterating structure-by-structure, we leverage PyTorch
        Geometric's **Graph Batching** mechanism. All $B$ structures are
        packed into a single disjoint Interaction Graph:
          1. $N_{total} = \sum N_i$ nodes are combined into a single matrix.
          2. Edge indices are shifted by node offsets to maintain connectivity.
          3. A single forward pass on the model's device evaluates the entire ensemble.

        This architecture is critical for screening large-scale "Bio-Active"
        libraries where thousands of candidates must be validated in seconds.

        Args:
            batch: A :class:`synth_pdb.batch_generator.BatchedPeptide` object
                containing $B$ protein structures.

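        Returns:
            A list of $B$ :class:`QualityScore` objects, containing global and
            per-residue confidence for every member of the ensemble. An empty
            list is returned if the batch is empty or the optional 'gnn'
            dependencies are not installed.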
        """
        try:
            import torch
            from torch_geometric.data import Batch
        except ImportError:
            # Silent fallback - scoring will be skipped in downstream pipelines
            # if the optional 'gnn' dependencies are missing.
            return []

        from .graph import build_protein_graphs_from_batch

        # 1. Vectorized Graph Construction
        # ---------------------------------------------------------------------
        # We bypass PDB parsing and build graphs directly from the
        # (Batch, Atoms, 3) coordinate tensor using optimized NumPy kernels.
        graphs = build_protein_graphs_from_batch(batch)
        if not graphs:
            # No structures to score (empty batch)
            return []

        # 2. PyG Packing
        # ---------------------------------------------------------------------
        # Convert the list of individual graphs into a single large
        # 'Batch' object for parallel GPU processing.
        pyg_batch = Batch.from_data_list(graphs)

        # 3. Model & Device Synchronization
        # ---------------------------------------------------------------------
        # Ensure the model is loaded and in evaluation mode (disables dropout).
        assert self.model is not None, "Model not loaded"
        self.model.eval()

        # Identify the active device (CPU/CUDA/MPS) and move the entire packed
        # batch tensor to that device in one transfer.
        device = next(self.model.parameters()).device
        pyg_batch = pyg_batch.to(device)

        # 4. Forward Pass (Inference)
        # ---------------------------------------------------------------------
        # Gradient tracking is disabled to save VRAM and increase speed.
        with torch.no_grad():
            if self._has_residue_head:
                # v2 models: Retrieve both global log-probs and local pLDDT.
                # per_res_tensor shape: [TotalNodes, 1] - contains confidence
                # for every residue (node) across all B structures.
                log_probs, per_res_tensor = type(self.model).forward_with_node_embeddings(
                    self.model,
                    pyg_batch.x,
                    pyg_batch.edge_index,
                    pyg_batch.edge_attr,
                    pyg_batch.batch,
                )
                per_residue_all = per_res_tensor.squeeze(-1).cpu().numpy()
            else:
                # v1 models: Legacy global-only scoring head.
                log_probs = self.model(
                    pyg_batch.x, pyg_batch.edge_index, pyg_batch.edge_attr, pyg_batch.batch
                )
                per_residue_all = None

        # 5. Result Post-Processing & Slicing
        # ---------------------------------------------------------------------
        # Convert log-probabilities back to linear probability space [0, 1].
        # probs_good[i] represents P(High Quality) for structure i.
        probs_good = log_probs.exp()[:, 1].cpu().numpy()

        results = []
        # ptr (pointer) array defines the residue boundaries for each graph
        # in the packed batch tensor. node_ptr[i] to node_ptr[i+1] is the slice.
        node_ptr = pyg_batch.ptr.cpu().numpy()

        for i in range(len(graphs)):
            prob_good = float(probs_good[i])
            label = "High Quality" if prob_good > 0.5 else "Low Quality"

            # Slice the per-residue pLDDT scores using the node pointers
            plddt = []
            res_lbls = []
            if per_residue_all is not None:
                # Extract the confidence values for this specific structure
                plddt = per_residue_all[node_ptr[i] : node_ptr[i + 1]].tolist()
                # Map to human-readable AlphaFold confidence bands
                res_lbls = [_plddt_label(s) for s in plddt]

            # Extract aggregate feature statistics (e.g. mean dihedrals)
            # for downstream transparency and debugging.
            node_feats = graphs[i].x.numpy()
            feat_dict = {
                name: float(np.mean(node_feats[:, j])) for j, name in enumerate(_FEATURE_NAMES)
            }

            # Wrap in structured QualityScore container
            results.append(
                QualityScore(
                    global_score=prob_good,
                    label=label,
                    per_residue=plddt,
                    residue_labels=res_lbls,
                    features=feat_dict,
                    n_residues=graphs[i].num_nodes,
                )
            )

        return results

    def save(self, path: str) -> None:
        """Save model weights and architecture config to a ``.pt`` checkpoint.

        The checkpoint is 'self-describing' - it includes the layer dimensions
        required to re-instantiate the `ProteinGNN` class without an external
        config or JSON file.  This follows the **'Single File Per Asset'**
        philosophy which simplifies deployment, model versioning, and
        cloud-based distribution of pre-trained structural auditors.

        -- Persistence Mechanism ----------------------------------------------
        We use standard `torch.save()`, which uses Python's `pickle` module
        internally. The payload includes:
          1. The `state_dict`: An OrderedDict mapping layer names to parameter
             tensors (weights and biases).
          2. Architectural Hyperparameters: `hidden_dim`, `node_features`, etc.
        -----------------------------------------------------------------------

        Args:
            path: Destination file path (should end in .pt).

        """
        # Ensure we have torch before trying to save.
        # This is a safety check for systems where torch is an optional extra.
        try:
            import torch
        except ImportError as exc:
            raise ImportError("torch is required to save a GNN checkpoint.") from exc

        # Ensure destination directory exists before writing to prevent IOErrors.
        # We use abspath to handle relative paths provided by the user in the CLI.
        # This prevents crashes when users specify nested output directories.
        target_dir = os.path.dirname(os.path.abspath(path))
        os.makedirs(target_dir, exist_ok=True)

        # Verification check: we cannot save a non-existent or uninitialized model.
        # The internal self.model must be an instance of ProteinGNN.
        assert self.model is not None, "Model not loaded; nothing to save"

        # Compile the state dictionary (weights) and architecture metadata.
        # This ensures the model can be reconstructed on any machine
        # without needing to guess the hidden_dim or feature counts used.
        # We store the core hyper-parameters (node_features, hidden_dim, etc.)
        # directly in the checkpoint to ensure perfect reproducibility.
        payload = {
            "state_dict": self.model.state_dict(),
            "node_features": self.model.node_features,
            "edge_features": self.model.edge_features,
            "hidden_dim": self.model.hidden_dim,
            "num_classes": self.model.num_classes,
        }

        # Standard PyTorch serialization (uses pickle/zip under the hood).
        # We use standard torch.save() to remain compatible with standard tools
        # like Netron for model visualization.
        torch.save(payload, path)
        self._model_path = path
        logger.info("GNN checkpoint saved to %s", path)

    def load(self, path: str) -> None:
        """Load model weights and re-configure architecture from a checkpoint.

        This method acts as a **Dynamic Factory**: it reads the metadata
        inside the checkpoint to determine the model's layer dimensions
        (node/edge feature counts, hidden width, number of classes),
        builds the graph attention layers, and then injects the learned weights.
        This allows one classifier instance to seamlessly switch between
        different model variants (e.g. v1 vs v2) or generations.

        -- Portability (CPU/GPU) ----------------------------------------------
        Models trained on a GPU (CUDA) often contain tensors tied to that
        hardware. We use `map_location='cpu'` during the load phase to
        ensure the model can be deployed on standard workstations without
        specialized hardware; it can be moved to another device afterwards.
        -----------------------------------------------------------------------

        Args:
            path: Path to a .pt checkpoint written by GNNQualityClassifier.save().

        """
        # Runtime dependency check for torch.
        try:
            import torch
        except ImportError as exc:
            raise ImportError("torch is required to load a GNN checkpoint.") from exc

        # Lazy import of the GNN model class to minimize startup time.
        # ProteinGNN is the core Message Passing Neural Network architecture.
        from .model import ProteinGNN

        try:
            # Load into CPU memory by default for maximum compatibility across systems.
            # weights_only=False enables full unpickling, which can execute
            # arbitrary code: only load checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)

            import typing

            # Instantiate model using dimensions stored inside the checkpoint file.
            # This dynamic reconstruction is key for 'self-describing' assets.
            # node_features and hidden_dim MUST match what was used during training.
            # If they mismatch, layer shapes will be incompatible with the state_dict.
            self.model = typing.cast(
                Any,
                ProteinGNN(
                    node_features=checkpoint["node_features"],
                    edge_features=checkpoint["edge_features"],
                    hidden_dim=checkpoint["hidden_dim"],
                    num_classes=checkpoint["num_classes"],
                ),
            )

            # Inject parameter tensors (weights/biases) into the instantiated layers.
            # This uses the standard state_dict mechanism for weight restoration.
            # This step performs the heavy transfer of floating-point weights.
            self.model.load_state_dict(checkpoint["state_dict"])

            # Switch to evaluation mode (essential: fixes dropout and batchnorm behavior).
            # Without this, the model might produce stochastic or incorrect results
            # due to active Dropout layers.
            self.model.eval()
            self._model_path = path
            self._is_pretrained = True

            # Detect if this checkpoint has the per-residue pLDDT head (v2).
            # We look for the existence of the specific linear layer parameters
            # that only exist in the expanded v2 architecture (residue_lin1).
            # This allows the classifier to gracefully handle both v1 and v2 files.
            self._has_residue_head = "residue_lin1.weight" in checkpoint["state_dict"]

            logger.info(
                "GNN classifier loaded from %s (per-residue head: %s)",
                path,
                self._has_residue_head,
            )
        except Exception as exc:
            # Critical error: model weights are essential for GNN inference.
            # We log the full stack trace to help researchers debug corrupted
            # assets or version mismatches in the architecture/state_dict.
            logger.error("Failed to load GNN checkpoint from %s: %s", path, exc, exc_info=True)
            raise

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _init_fresh_model(self) -> None:
        """Initialise a randomly-weighted model (v2 architecture).

        Used primarily for integration tests or when bootstrapping a new
        training run from scratch.  Uses standard synth-pdb dimensions.
        """
        import typing

        from .model import ProteinGNN

        # Default architecture for fresh training:
        # - 8 node features (dihedrals, sequence info, b-factors)
        # - 2 edge features (euclidean distance, distance bins)
        self.model = typing.cast(
            Any, ProteinGNN(node_features=8, edge_features=2, hidden_dim=64, num_classes=2)
        )

        # Start in eval mode for safety
        self.model.eval()

        # New models always include the residue head (standard for v2+)
        self._has_residue_head = True
        self._is_pretrained = False

    @property
    def is_pretrained(self) -> bool:
        """Return True if the model weights were loaded from a checkpoint."""
        return self._is_pretrained
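The save() and load() methods in the listing above store the architecture hyperparameters next to the weights, so a checkpoint can rebuild its own model. The sketch below illustrates that pattern with the standard library only. `TinyModel`, `save_checkpoint`, and `load_checkpoint` are hypothetical stand-ins, and plain `pickle` replaces `torch.save()` so the example runs without PyTorch.

```python
import os
import pickle
import tempfile


class TinyModel:
    """Hypothetical stand-in for ProteinGNN: dimensions plus a weight dict."""

    def __init__(self, node_features: int, hidden_dim: int, with_residue_head: bool = False):
        self.node_features = node_features
        self.hidden_dim = hidden_dim
        # Weights are plain nested lists here; a real model would hold tensors.
        self.weights = {"lin1.weight": [[0.0] * node_features for _ in range(hidden_dim)]}
        if with_residue_head:
            self.weights["residue_lin1.weight"] = [[0.0] * hidden_dim]


def save_checkpoint(model: TinyModel, path: str) -> None:
    # Store the hyperparameters alongside the weights: one file per asset.
    payload = {
        "state_dict": model.weights,
        "node_features": model.node_features,
        "hidden_dim": model.hidden_dim,
    }
    with open(path, "wb") as fh:
        pickle.dump(payload, fh)


def load_checkpoint(path: str) -> tuple[TinyModel, bool]:
    # Pickle can execute arbitrary code on load: only open trusted files.
    with open(path, "rb") as fh:
        payload = pickle.load(fh)
    # Rebuild the architecture from the stored dimensions, then restore weights.
    model = TinyModel(payload["node_features"], payload["hidden_dim"])
    model.weights = payload["state_dict"]
    # Detect the optional head by the presence of its parameter key.
    has_residue_head = "residue_lin1.weight" in payload["state_dict"]
    return model, has_residue_head


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "demo.pt")
    save_checkpoint(TinyModel(node_features=8, hidden_dim=64, with_residue_head=True), ckpt)
    restored, has_head = load_checkpoint(ckpt)
```

Because the dimensions travel with the weights, the loader never has to guess the layer shapes, and checking for a parameter key is enough to tell a checkpoint with a per-residue head from one without.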
Attributes
is_pretrained property

Return True if the model weights were loaded from a checkpoint.

Functions
predict(pdb_content)

Predict the quality of a PDB structure (Legacy Protocol).

This method satisfies the ProteinQualityClassifier protocol, allowing the GNN to act as a drop-in replacement for the Random Forest (RF) classifier used in the core synth-pdb Validator.

The API Bridge

By maintaining this exact signature, we enable "Polymorphic Quality": a generation script can accept ANY classifier instance and use it without knowing if it's a fast RF model or a deep GNN.

Internal Pipeline

1. PDB Parsing: Extract 3D Calpha coordinates from ATOM records.
2. Graph Construction: Build the spatial interaction graph (graph.py).
3. Message Passing: Propagate geometric data across the graph edges.
4. Global Pooling: Average node embeddings into a whole-protein vector.
5. Softmax Output: Produce final probability P(Good).


Parameters:

Name Type Description Default
pdb_content str Raw PDB string containing the structural model. required

Returns:

Name Type Description
is_good bool True if P(Good) exceeds the 0.5 decision threshold.
probability float P(Good), a probability in [0, 1].
features dict Summary statistics of the input geometric features.

Raises:

Type Description
ImportError If the 'gnn' optional dependencies are missing.
ValueError If the PDB contains fewer than 2 valid residues.
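The "drop-in replacement" idea behind predict() can be sketched with a structural protocol. This is a minimal illustration, not the library's definition: `QualityClassifier`, `AtomCountClassifier`, and `keep_good` are hypothetical names, and the real `ProteinQualityClassifier` protocol may declare more than this one method.

```python
from typing import Protocol


class QualityClassifier(Protocol):
    """Hypothetical protocol mirroring the predict() signature documented above."""

    def predict(self, pdb_content: str) -> tuple[bool, float, dict[str, float]]: ...


class AtomCountClassifier:
    """Toy classifier: calls a structure 'good' if it has more than 4 ATOM records."""

    def predict(self, pdb_content: str) -> tuple[bool, float, dict[str, float]]:
        n_atoms = sum(1 for line in pdb_content.splitlines() if line.startswith("ATOM"))
        probability = min(n_atoms / 8.0, 1.0)
        return probability > 0.5, probability, {"n_atoms": float(n_atoms)}


def keep_good(structures: list[str], classifier: QualityClassifier) -> list[str]:
    # The caller only relies on the predict() contract, not the model type.
    return [pdb for pdb in structures if classifier.predict(pdb)[0]]


tiny = "ATOM      1  CA  ALA A   1\n" * 2
larger = "ATOM      1  CA  ALA A   1\n" * 6
kept = keep_good([tiny, larger], AtomCountClassifier())
```

Any object exposing the same predict() signature, whether a Random Forest wrapper or the GNN classifier, could be passed to `keep_good` unchanged.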

score(pdb_content)

Score a PDB structure, returning a rich QualityScore object.

This is the modern, rich API for the quality system. While predict() returns a (bool, float, dict) tuple, score() returns the full pLDDT confidence map, allowing for fine-grained structural auditing.

The Power of the v2 Model

When a v2 checkpoint is loaded, this method activates the auxiliary Per-Residue Head. In addition to the whole-protein verdict, it returns a confidence value per residue, which helps localize problem regions such as a strained loop or a clashing residue. With a v1 checkpoint, the per-residue lists are empty.
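To show how per-residue confidences can be turned into readable bands, here is a small sketch. It assumes the AlphaFold pLDDT bands rescaled to [0, 1] (cut-offs at 0.9, 0.7, and 0.5). The function name `plddt_band` and all label strings except "Very High" (which appears in the Quick Start output) are assumptions; the library's internal `_plddt_label` may use different thresholds or wording.

```python
def plddt_band(score: float) -> str:
    """Map a confidence in [0, 1] to an AlphaFold-style band (assumed cut-offs)."""
    if score >= 0.9:
        return "Very High"
    if score >= 0.7:
        return "Confident"
    if score >= 0.5:
        return "Low"
    return "Very Low"


# Hypothetical per-residue confidences, as score() might return for 4 residues.
per_residue = [0.958, 0.960, 0.72, 0.41]
labels = [plddt_band(s) for s in per_residue]

# 1-based positions of residues that fall below the 'Confident' band.
weak_positions = [i + 1 for i, s in enumerate(per_residue) if s < 0.7]
```

Filtering on the numeric values rather than the label strings keeps downstream code independent of the exact wording of the bands.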

Parameters:

Name Type Description Default
pdb_content str PDB-format string representing the protein. required

Returns:

Type Description
QualityScore A QualityScore instance with global and residue-level metrics.

Raises:

Type Description
ImportError If PyTorch or PyG are not installed.
ValueError If the PDB has insufficient atoms to form a graph.
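The global score is obtained by exponentiating the model's log-softmax output and reading the entry for the 'Good' class. The sketch below reproduces that arithmetic with the standard library; the logits are made-up numbers, and `log_softmax` is a local helper written for this example rather than a library call.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Subtract the max first for numerical stability (log-sum-exp trick).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


logits = [-1.2, 2.3]             # hypothetical raw scores for [Bad, Good]
log_probs = log_softmax(logits)  # every entry is <= 0
probs = [math.exp(lp) for lp in log_probs]

prob_good = probs[1]             # index 1 is the 'Good' class
label = "High Quality" if prob_good > 0.5 else "Low Quality"
```

With two classes, `prob_good` equals the logistic sigmoid of the logit difference, so the 0.5 threshold corresponds to the 'Good' logit exceeding the 'Bad' logit.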

Source code in synth_pdb/quality/gnn/gnn_classifier.py
def score(self, pdb_content: str) -> QualityScore:
    """Score a PDB structure, returning a rich :class:`QualityScore` object.

    This is the **modern, rich API** for the quality system. While
    `predict()` returns a simple bool, `score()` returns the full
    pLDDT confidence map, allowing for fine-grained structural auditing.

    -- The Power of the v2 Model ------------------------------------------
    When a v2 checkpoint is loaded, this method activates the auxiliary
    **Per-Residue Head**. This head doesn't just judge the whole protein;
    it performs "structural surgery" to identify exactly which loop
    is strained or which residue is clashing.

    Args:
        pdb_content: PDB-format string representing the protein.

    Returns:
        A :class:`QualityScore` instance with global and residue-level metrics.

    Raises:
        ImportError: If PyTorch or PyG are not installed.
        ValueError: If the PDB has insufficient atoms to form a graph.

    """
    # PyTorch and PyG are heavy dependencies (~500MB). We only import
    # them inside this method to allow synth-pdb to maintain its
    # "Lean Core" philosophy. Users who only want PDB generation
    # don't need to install the deep learning stack.
    try:
        import torch
    except ImportError as exc:
        # Provide an actionable error message for the optional [gnn] extra.
        raise ImportError(
            "torch is required for GNNQualityClassifier. "
            "Install with: pip install synth-pdb[gnn]"
        ) from exc

    # Lazy runtime imports for PyTorch Geometric (PyG) utilities.
    # Batch is the container used to pack graphs into high-performance tensors.
    from torch_geometric.data import Batch
    from .graph import build_protein_graph

    # Step 1: Spatial Graph Reconstruction
    # ---------------------------------------------------------------------
    # Convert the raw PDB string into a Topological Interaction Graph.
    # Nodes = Residues, Edges = Contacts < 10 A.
    graph = build_protein_graph(pdb_content)

    # Step 2: Batch Packaging
    # ---------------------------------------------------------------------
    # Even for one structure, we wrap the graph in a Batch object.
    # This is the standard input format for Graph Attention (GAT) layers,
    # as it includes the critical 'batch' mapping vector [Nodes, 1].
    batch = Batch.from_data_list([graph])

    # Step 3: Global Model Audit
    # ---------------------------------------------------------------------
    # Verify the weights were successfully loaded from the checkpoint.
    assert self.model is not None, "Model not loaded"

    # Evaluation Mode: Disables stochastic layers like Dropout and
    # BatchNormalization. This is essential for reproducible results.
    self.model.eval()

    # Inference Optimization: Disable the autograd engine (grad tracking).
    # This increases speed by 2x and saves memory by not building the
    # back-propagation graph.
    with torch.no_grad():
        if self._has_residue_head:
            # DUAL-OUTPUT MODE (v2 Architecture)
            # We retrieve both the pooled global log-probs and the
            # un-pooled node confidence map (pLDDT).
            log_probs, per_res_tensor = type(self.model).forward_with_node_embeddings(
                self.model,
                batch.x,
                batch.edge_index,
                batch.edge_attr,
                batch.batch,
            )
            # Move tensor back to standard Python floats for the user.
            # Squeeze removes the redundant unit dimension: [Nodes, 1] -> [Nodes].
            per_residue = per_res_tensor.squeeze(-1).tolist()
        else:
            # LEGACY MODE (v1 Architecture)
            # Only the global whole-protein classification is available.
            # per_residue is returned as an empty list.
            log_probs = self.model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            per_residue = []

    # Step 4: Final Probability Calculation
    # ---------------------------------------------------------------------
    # The raw output is in "Log-Softmax" space [-inf, 0].
    # We apply the exponent to get linear probabilities [0, 1].
    # Index 1 corresponds to the 'Good' (Physically Valid) class.
    prob_good = float(log_probs.exp()[0, 1].item())

    # Categorize the prediction using the 0.5 decision boundary.
    label = "High Quality" if prob_good > 0.5 else "Low Quality"

    # Map the numeric pLDDT confidence floats to the AlphaFold color bands.
    # This provides immediate structural insight (e.g. blue = good, orange = bad).
    residue_labels = [_plddt_label(s) for s in per_residue]

    # Step 5: Feature Introspection (Explainability)
    # ---------------------------------------------------------------------
    # We extract the mean of each input node feature (e.g., mean phi angle).
    # This dictionary is returned so researchers can verify IF the GNN is
    # paying attention to the correct physical signals.
    node_feats = graph.x.numpy()
    feat_dict = {
        name: float(np.mean(node_feats[:, i])) for i, name in enumerate(_FEATURE_NAMES)
    }

    # Step 6: Package Result
    # ---------------------------------------------------------------------
    return QualityScore(
        global_score=prob_good,
        label=label,
        per_residue=per_residue,
        residue_labels=residue_labels,
        features=feat_dict,
        n_residues=graph.num_nodes,
    )
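
The probability conversion in Step 4 can be illustrated without PyTorch. The following is a minimal sketch, assuming a two-class output in which index 1 is the "Good" class; `log_softmax` and `label_from_log_probs` are hypothetical helper names used only for illustration and are not part of synth_pdb.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def label_from_log_probs(log_probs, threshold=0.5):
    """Exponentiate the 'Good' log-probability (index 1) and apply the threshold."""
    prob_good = math.exp(log_probs[1])
    label = "High Quality" if prob_good > threshold else "Low Quality"
    return prob_good, label


# Hypothetical raw scores for the (Bad, Good) classes.
log_probs = log_softmax([0.0, 2.0])
prob_good, label = label_from_log_probs(log_probs)
print(f"{prob_good:.3f} {label}")
```

Because the comparison is strict (`> 0.5`), a structure scoring exactly 0.5 is labelled "Low Quality".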
score_batch(batch)

Score an entire batch of structures in a single vectorized pass.


VECTORIZED ENSEMBLE AUDITING

While the standard score() method processes a single PDB string, this method is optimized for the high-throughput generation pipeline. It operates directly on the coordinate tensors of a BatchedPeptide, eliminating the CPU-bound bottlenecks of string serialization.

GPU Parallelism

Instead of iterating structure-by-structure, we leverage PyTorch Geometric's Graph Batching mechanism. All \(B\) structures are packed into a single disjoint Interaction Graph:

1. \(N_{total} = \sum N_i\) nodes are combined into a single matrix.
2. Edge indices are shifted by node offsets to maintain connectivity.
3. A single forward pass on the GPU evaluates the entire ensemble.

This architecture is critical for screening large-scale "Bio-Active" libraries where thousands of candidates must be validated in seconds.
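
The packing steps described above can be sketched in plain Python. This is an illustrative stand-in for what a graph-batching utility does conceptually, not PyTorch Geometric's implementation; `pack_graphs` and the `(num_nodes, edges)` tuple format are assumptions made for this example.

```python
def pack_graphs(graphs):
    """Pack (num_nodes, edges) graphs into one disjoint graph.

    Returns the merged edge list, a per-node graph index, and a pointer
    list whose consecutive entries delimit each graph's nodes.
    """
    edges, batch, ptr = [], [], [0]
    for graph_id, (num_nodes, graph_edges) in enumerate(graphs):
        offset = ptr[-1]
        # Shift local node indices so they stay unique in the merged graph.
        edges.extend((src + offset, dst + offset) for src, dst in graph_edges)
        batch.extend([graph_id] * num_nodes)
        ptr.append(offset + num_nodes)
    return edges, batch, ptr


# Two toy graphs: a 3-node chain and a 2-node pair.
edges, batch, ptr = pack_graphs([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
print(edges)  # [(0, 1), (1, 2), (3, 4)]
print(batch)  # [0, 0, 0, 1, 1]
print(ptr)    # [0, 3, 5]
```

No edge crosses a graph boundary after the shift, so message passing on the merged graph cannot mix information between structures.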

Parameters:

Name Type Description Default
batch Any A synth_pdb.batch_generator.BatchedPeptide object containing \(B\) protein structures. required

Returns:

Type Description
list[QualityScore] A list of \(B\) QualityScore objects, containing global and per-residue confidence for every member of the ensemble.

Source code in synth_pdb/quality/gnn/gnn_classifier.py
def score_batch(self, batch: Any) -> list[QualityScore]:
    r"""Score an entire batch of structures in a single vectorized pass.

    -----------------------------------------------------------------------------
    VECTORIZED ENSEMBLE AUDITING
    -----------------------------------------------------------------------------
    While the standard ``score()`` method processes a single PDB string,
    this method is optimized for the **high-throughput generation pipeline**.
    It operates directly on the coordinate tensors of a ``BatchedPeptide``,
    eliminating the CPU-bound bottlenecks of string serialization.

    -- GPU Parallelism ----------------------------------------------------------
    Instead of iterating structure-by-structure, we leverage PyTorch
    Geometric's **Graph Batching** mechanism. All $B$ structures are
    packed into a single disjoint Interaction Graph:
      1. $N_{total} = \sum N_i$ nodes are combined into a single matrix.
      2. Edge indices are shifted by node offsets to maintain connectivity.
      3. A single forward pass on the GPU evaluates the entire ensemble.

    This architecture is critical for screening large-scale "Bio-Active"
    libraries where thousands of candidates must be validated in seconds.

    Args:
        batch: A :class:`synth_pdb.batch_generator.BatchedPeptide` object
            containing $B$ protein structures.

    Returns:
        A list of $B$ :class:`QualityScore` objects, containing global and
        per-residue confidence for every member of the ensemble.
    """
    try:
        import torch
        from torch_geometric.data import Batch
    except ImportError:
        # Silent fallback - scoring will be skipped in downstream pipelines
        # if the optional 'gnn' dependencies are missing.
        return []

    from .graph import build_protein_graphs_from_batch

    # 1. Vectorized Graph Construction
    # ---------------------------------------------------------------------
    # We bypass PDB parsing and build graphs directly from the
    # (Batch, Atoms, 3) coordinate tensor using optimized NumPy kernels.
    graphs = build_protein_graphs_from_batch(batch)
    if not graphs:
        # No structures to score (empty batch)
        return []

    # 2. PyG Packing
    # ---------------------------------------------------------------------
    # Convert the list of individual graphs into a single large
    # 'Batch' object for parallel GPU processing.
    pyg_batch = Batch.from_data_list(graphs)

    # 3. Model & Device Synchronization
    # ---------------------------------------------------------------------
    # Ensure the model is loaded and in evaluation mode (disables dropout).
    assert self.model is not None, "Model not loaded"
    self.model.eval()

    # Identify the active device (CPU/CUDA/MPS) and move the entire packed
    # batch tensor to that device in one transfer.
    device = next(self.model.parameters()).device
    pyg_batch = pyg_batch.to(device)

    # 4. Forward Pass (Inference)
    # ---------------------------------------------------------------------
    # Gradient tracking is disabled to save VRAM and increase speed.
    with torch.no_grad():
        if self._has_residue_head:
            # v2 models: Retrieve both global log-probs and local pLDDT.
            # per_res_tensor shape: [TotalNodes, 1] - contains confidence
            # for every residue (node) across all B structures.
            log_probs, per_res_tensor = type(self.model).forward_with_node_embeddings(
                self.model,
                pyg_batch.x,
                pyg_batch.edge_index,
                pyg_batch.edge_attr,
                pyg_batch.batch,
            )
            per_residue_all = per_res_tensor.squeeze(-1).cpu().numpy()
        else:
            # v1 models: Legacy global-only scoring head.
            log_probs = self.model(
                pyg_batch.x, pyg_batch.edge_index, pyg_batch.edge_attr, pyg_batch.batch
            )
            per_residue_all = None

    # 5. Result Post-Processing & Slicing
    # ---------------------------------------------------------------------
    # Convert log-probabilities back to linear probability space [0, 1].
    # probs_good[i] represents P(High Quality) for structure i.
    probs_good = log_probs.exp()[:, 1].cpu().numpy()

    results = []
    # ptr (pointer) array defines the residue boundaries for each graph
    # in the packed batch tensor. node_ptr[i] to node_ptr[i+1] is the slice.
    node_ptr = pyg_batch.ptr.cpu().numpy()

    for i in range(len(graphs)):
        prob_good = float(probs_good[i])
        label = "High Quality" if prob_good > 0.5 else "Low Quality"

        # Slice the per-residue pLDDT scores using the node pointers
        plddt = []
        res_lbls = []
        if per_residue_all is not None:
            # Extract the confidence values for this specific structure
            plddt = per_residue_all[node_ptr[i] : node_ptr[i + 1]].tolist()
            # Map to human-readable AlphaFold confidence bands
            res_lbls = [_plddt_label(s) for s in plddt]

        # Extract aggregate feature statistics (e.g. mean dihedrals)
        # for downstream transparency and debugging.
        node_feats = graphs[i].x.numpy()
        feat_dict = {
            name: float(np.mean(node_feats[:, j])) for j, name in enumerate(_FEATURE_NAMES)
        }

        # Wrap in structured QualityScore container
        results.append(
            QualityScore(
                global_score=prob_good,
                label=label,
                per_residue=plddt,
                residue_labels=res_lbls,
                features=feat_dict,
                n_residues=graphs[i].num_nodes,
            )
        )

    return results
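
The pointer-based slicing in step 5 can be shown on its own. A minimal sketch, assuming a flat list of per-node values and a pointer list of cumulative node counts; `split_by_ptr` is a hypothetical helper name and the confidence values are made up.

```python
def split_by_ptr(values, ptr):
    """Split a flat per-node list into one list per graph using pointer offsets."""
    return [values[ptr[i]:ptr[i + 1]] for i in range(len(ptr) - 1)]


# Hypothetical flat confidences for two structures with 3 and 2 residues.
flat = [0.95, 0.91, 0.40, 0.72, 0.66]
per_structure = split_by_ptr(flat, [0, 3, 5])
print(per_structure)  # [[0.95, 0.91, 0.4], [0.72, 0.66]]
```

A pointer list with B + 1 entries yields exactly B slices, one for each structure in the batch.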
save(path)

Save model weights and architecture config to a .pt checkpoint.

The checkpoint is 'self-describing' - it includes the layer dimensions required to re-instantiate the ProteinGNN class without an external config or JSON file. This follows the 'Single File Per Asset' philosophy which simplifies deployment, model versioning, and cloud-based distribution of pre-trained structural auditors.

Persistence Mechanism

We use standard torch.save(), which uses Python's pickle module internally. The payload includes:

1. The state_dict: An OrderedDict mapping layer names to parameter tensors (weights and biases).
2. Architectural Hyperparameters: hidden_dim, node_features, etc.


Parameters:

Name Type Description Default
path str Destination file path (should end in .pt). required
Source code in synth_pdb/quality/gnn/gnn_classifier.py
def save(self, path: str) -> None:
    """Save model weights and architecture config to a ``.pt`` checkpoint.

    The checkpoint is 'self-describing' - it includes the layer dimensions
    required to re-instantiate the `ProteinGNN` class without an external
    config or JSON file.  This follows the **'Single File Per Asset'**
    philosophy which simplifies deployment, model versioning, and
    cloud-based distribution of pre-trained structural auditors.

    -- Persistence Mechanism ----------------------------------------------
    We use standard `torch.save()`, which uses Python's `pickle` module
    internally. The payload includes:
      1. The `state_dict`: An OrderedDict mapping layer names to parameter
         tensors (weights and biases).
      2. Architectural Hyperparameters: `hidden_dim`, `node_features`, etc.
    -----------------------------------------------------------------------

    Args:
        path: Destination file path (should end in .pt).

    """
    # Ensure we have torch before trying to save.
    # This is a safety check for systems where torch is an optional extra.
    try:
        import torch
    except ImportError as exc:
        raise ImportError("torch is required to save a GNN checkpoint.") from exc

    # Ensure destination directory exists before writing to prevent IOErrors.
    # We use abspath to handle relative paths provided by the user in the CLI.
    # This prevents crashes when users specify nested output directories.
    target_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(target_dir, exist_ok=True)

    # Verification check: we cannot save a non-existent or uninitialized model.
    # The internal self.model must be an instance of ProteinGNN.
    assert self.model is not None, "Model not loaded; nothing to save"

    # Compile the state dictionary (weights) and architecture metadata.
    # This ensures the model can be reconstructed on any machine
    # without needing to guess the hidden_dim or feature counts used.
    # We store the core hyper-parameters (node_features, hidden_dim, etc.)
    # directly in the checkpoint to ensure perfect reproducibility.
    payload = {
        "state_dict": self.model.state_dict(),
        "node_features": self.model.node_features,
        "edge_features": self.model.edge_features,
        "hidden_dim": self.model.hidden_dim,
        "num_classes": self.model.num_classes,
    }

    # Standard PyTorch serialization (uses pickle/zip under the hood).
    # We use standard torch.save() to remain compatible with standard tools
    # like Netron for model visualization.
    torch.save(payload, path)
    self._model_path = path
    logger.info("GNN checkpoint saved to %s", path)
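
The self-describing payload and the directory handling can be tried without PyTorch. The sketch below deliberately swaps in the standard `pickle` module and plain lists in place of `torch.save()` and tensors so that it runs with the standard library alone; `save_payload`, the weight names, and the dimensions are all hypothetical.

```python
import os
import pickle
import tempfile


def save_payload(payload, path):
    """Create the parent directory if needed, then pickle the payload."""
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    with open(path, "wb") as fh:
        pickle.dump(payload, fh)


# Hypothetical dimensions and weights, for illustration only.
payload = {
    "state_dict": {"lin.weight": [[0.1, 0.2]], "lin.bias": [0.0]},
    "node_features": 2,
    "edge_features": 1,
    "hidden_dim": 4,
    "num_classes": 2,
}

with tempfile.TemporaryDirectory() as tmp:
    # The nested directories do not exist yet; save_payload creates them.
    path = os.path.join(tmp, "nested", "dir", "model.pt")
    save_payload(payload, path)
    with open(path, "rb") as fh:
        restored = pickle.load(fh)

print(restored == payload)
```

Keeping the hyperparameters next to the weights means a reader of the file needs nothing else to rebuild a model of matching shape.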
load(path)

Load model weights and re-configure architecture from a checkpoint.

This method acts as a Dynamic Factory: it reads the metadata inside the checkpoint to determine the model's layer dimensions, builds the graph attention layers, and then injects the learned weights. This allows one classifier instance to seamlessly switch between different model variants (e.g. v1 vs v2) or generations.

Portability (CPU/GPU)

Models trained on a GPU (CUDA) often contain tensors tied to that hardware. We use map_location='cpu' during the load phase to ensure the model can be deployed on standard workstations without specialized hardware, then move it to the active device later.


Parameters:

Name Type Description Default
path str Path to a .pt checkpoint written by GNNQualityClassifier.save(). required
Source code in synth_pdb/quality/gnn/gnn_classifier.py
def load(self, path: str) -> None:
    """Load model weights and re-configure architecture from a checkpoint.

    This method acts as a **Dynamic Factory**: it reads the metadata
    inside the checkpoint to determine the model's layer dimensions,
    builds the graph attention layers, and then injects the learned weights.
    This allows one classifier instance to seamlessly switch between
    different model variants (e.g. v1 vs v2) or generations.

    -- Portability (CPU/GPU) ----------------------------------------------
    Models trained on a GPU (CUDA) often contain tensors tied to that
    hardware. We use `map_location='cpu'` during the load phase to
    ensure the model can be deployed on standard workstations without
    specialized hardware, then move it to the active device later.
    -----------------------------------------------------------------------

    Args:
        path: Path to a .pt checkpoint written by GNNQualityClassifier.save().

    """
    # Runtime dependency check for torch.
    try:
        import torch
    except ImportError as exc:
        raise ImportError("torch is required to load a GNN checkpoint.") from exc

    # Lazy import of the GNN model class to minimize startup time.
    # ProteinGNN is the core Message Passing Neural Network architecture.
    from .model import ProteinGNN

    try:
        # Load into CPU memory by default for maximum compatibility across systems.
        # weights_only=False enables full pickle deserialization, so only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)

        import typing

        # Instantiate model using dimensions stored inside the checkpoint file.
        # This dynamic reconstruction is key for 'self-describing' assets.
        # node_features and hidden_dim MUST match what was used during training.
        # If they mismatch, layer shapes will be incompatible with the state_dict.
        self.model = typing.cast(
            Any,
            ProteinGNN(
                node_features=checkpoint["node_features"],
                edge_features=checkpoint["edge_features"],
                hidden_dim=checkpoint["hidden_dim"],
                num_classes=checkpoint["num_classes"],
            ),
        )

        # Inject parameter tensors (weights/biases) into the instantiated layers.
        # This uses the standard state_dict mechanism for weight restoration.
        # This step performs the heavy transfer of floating-point weights.
        self.model.load_state_dict(checkpoint["state_dict"])

        # Switch to evaluation mode (essential: fixes dropout and batchnorm behavior).
        # Without this, the model might produce stochastic or incorrect results
        # due to active Dropout layers.
        self.model.eval()
        self._model_path = path
        self._is_pretrained = True

        # Detect if this checkpoint has the per-residue pLDDT head (v2).
        # We look for the existence of the specific linear layer parameters
        # that only exist in the expanded v2 architecture (residue_lin1).
        # This allows the classifier to gracefully handle both v1 and v2 files.
        self._has_residue_head = "residue_lin1.weight" in checkpoint["state_dict"]

        logger.info(
            "GNN classifier loaded from %s (per-residue head: %s)",
            path,
            self._has_residue_head,
        )
    except Exception as exc:
        # Critical error: model weights are essential for GNN inference.
        # We log the full stack trace to help researchers debug corrupted
        # assets or version mismatches in the architecture/state_dict.
        logger.error("Failed to load GNN checkpoint from %s: %s", path, exc, exc_info=True)
        raise
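
The metadata lookup and the v1/v2 detection can be illustrated on a plain dictionary. A minimal sketch, assuming the checkpoint layout written by save(); `describe_checkpoint` is a hypothetical helper, the toy checkpoints below contain no real weights, and the explicit missing-key check is an addition not present in load().

```python
def describe_checkpoint(checkpoint):
    """Return constructor kwargs and whether a per-residue head is present."""
    keys = ("node_features", "edge_features", "hidden_dim", "num_classes")
    missing = [k for k in keys if k not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing metadata: {missing}")
    kwargs = {k: checkpoint[k] for k in keys}
    # The v2 architecture is recognized by a parameter that v1 lacks.
    has_residue_head = "residue_lin1.weight" in checkpoint["state_dict"]
    return kwargs, has_residue_head


# Toy checkpoints with placeholder values instead of tensors.
meta = {"node_features": 8, "edge_features": 2, "hidden_dim": 64, "num_classes": 2}
v1 = {**meta, "state_dict": {"conv1.weight": None}}
v2 = {**meta, "state_dict": {"conv1.weight": None, "residue_lin1.weight": None}}

print(describe_checkpoint(v1)[1])  # False
print(describe_checkpoint(v2)[1])  # True
```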

QualityScore dataclass

Rich quality assessment result for a single protein structure.

This object serves as the Data Transfer Object (DTO) between the internal GNN inference engine and the end-user. It encapsulates not just a binary "Good/Bad" label, but a high-resolution map of the structure's physical confidence.

pLDDT: The Standard of Confidence

The per_residue scores are modeled after the predicted Local Distance Difference Test (pLDDT), the primary confidence metric used by AlphaFold. Values in [0, 1] represent the model's certainty that a residue is in its physically correct local environment.

Attributes

global_score : float
    The "Whole-Protein" probability P(Good) in [0, 1]. This is the output of the GNN's global pooling layer followed by a log-softmax.
    * 0.9 - 1.0: Extremely confident, well-folded model.
    * 0.5 - 0.9: Likely valid but may have minor local strains.
    * < 0.5: "Low Quality" - likely contains unphysical geometry.

label : str
    A human-readable categorical label ("High Quality" or "Low Quality") derived from the 0.5 global_score threshold.

per_residue : list[float]
    The "Confidence Heatmap". Each float represents the pLDDT of an individual residue. Length equals the number of residues (Calpha atoms). This is generated by the auxiliary regression head in v2 models.

residue_labels : list[str]
    The AlphaFold-standardized categorical bands:
    * "Very High" (>= 0.90): Crystallographic-quality geometry.
    * "High" (>= 0.70): Generally reliable backbone.
    * "Uncertain" (>= 0.50): Low-confidence loop or linker.
    * "Low" (< 0.50): Unphysical/Clashing region.

features : dict[str, float]
    A dictionary of the mean input node features (sin_phi, cos_phi, etc.). This is provided for Explainable AI (XAI) - it helps researchers understand if a low score was triggered by bad dihedrals or high B-factors.

n_residues : int
    The total number of nodes (amino acids) in the interaction graph.
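
The residue_labels bands can be expressed as a small threshold function. This is a sketch based only on the thresholds listed above; the library's internal `_plddt_label` is not shown on this page, so `plddt_band` is a hypothetical stand-in rather than its actual implementation.

```python
def plddt_band(score):
    """Map a per-residue confidence in [0, 1] to its categorical band."""
    if score >= 0.90:
        return "Very High"
    if score >= 0.70:
        return "High"
    if score >= 0.50:
        return "Uncertain"
    return "Low"


# Hypothetical per-residue confidences.
scores = [0.96, 0.90, 0.75, 0.50, 0.31]
print([plddt_band(s) for s in scores])
# ['Very High', 'Very High', 'High', 'Uncertain', 'Low']
```

Each band is closed at its lower bound, so 0.90 maps to "Very High" and 0.50 maps to "Uncertain".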

Examples

clf = GNNQualityClassifier()
result = clf.score(pdb_string)
print(f"Protein quality: {result.label} ({result.global_score:.1%})")

# Identify local errors
clashes = [i for i, s in enumerate(result.per_residue) if s < 0.5]
print(f"Detected {len(clashes)} problematic residues.")

Source code in synth_pdb/quality/gnn/gnn_classifier.py
@dataclass
class QualityScore:
    """Rich quality assessment result for a single protein structure.

    This object serves as the **Data Transfer Object (DTO)** between the
    internal GNN inference engine and the end-user. It encapsulates not just
    a binary "Good/Bad" label, but a high-resolution map of the structure's
    physical confidence.

    -- pLDDT: The Standard of Confidence -------------------------------------
    The `per_residue` scores are modeled after the **predicted Local Distance
    Difference Test (pLDDT)**, the primary confidence metric used by AlphaFold.
    Values in [0, 1] represent the model's certainty that a residue is in its
    physically correct local environment.

    Attributes
    ----------
    global_score : float
        The "Whole-Protein" probability P(Good) in [0, 1]. This is the output
         of the GNN's global pooling layer followed by a log-softmax.
         * 0.9 - 1.0: Extremely confident, well-folded model.
         * 0.5 - 0.9: Likely valid but may have minor local strains.
         * < 0.5    : "Low Quality" - likely contains unphysical geometry.
    label : str
        A human-readable categorical label ("High Quality" or "Low Quality")
        derived from the 0.5 global_score threshold.
    per_residue : list[float]
        The "Confidence Heatmap". Each float represents the pLDDT of an
        individual residue. Length equals the number of residues (Calpha atoms).
        This is generated by the auxiliary regression head in v2 models.
    residue_labels : list[str]
        The AlphaFold-standardized categorical bands:
        * "Very High" (>= 0.90) : Crystallographic-quality geometry.
        * "High"      (>= 0.70) : Generally reliable backbone.
        * "Uncertain" (>= 0.50) : Low-confidence loop or linker.
        * "Low"       (< 0.50) : Unphysical/Clashing region.
    features : dict[str, float]
        A dictionary of the mean input node features (sin_phi, cos_phi, etc.).
        This is provided for **Explainable AI (XAI)** - it helps researchers
        understand if a low score was triggered by bad dihedrals or high B-factors.
    n_residues : int
        The total number of nodes (amino acids) in the interaction graph.

    Examples
    --------
    >>> clf = GNNQualityClassifier()
    >>> result = clf.score(pdb_string)
    >>> print(f"Protein quality: {result.label} ({result.global_score:.1%})")
    >>> # Identify local errors
    >>> clashes = [i for i, s in enumerate(result.per_residue) if s < 0.5]
    >>> print(f"Detected {len(clashes)} problematic residues.")
    """

    global_score: float
    label: str
    per_residue: list[float] = field(default_factory=list)
    residue_labels: list[str] = field(default_factory=list)
    features: dict[str, float] = field(default_factory=dict)
    n_residues: int = 0
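
A common follow-up to the per_residue list is grouping consecutive low-confidence residues into segments. The following is a minimal sketch operating on a plain list of floats; `low_confidence_segments` is a hypothetical helper, not part of synth_pdb, and the 0.5 default mirrors the "Low" band threshold.

```python
def low_confidence_segments(per_residue, threshold=0.5):
    """Return (start, end) index pairs (0-based, inclusive) of runs below threshold."""
    segments, start = [], None
    for i, score in enumerate(per_residue):
        if score < threshold:
            if start is None:
                start = i  # a new low-confidence run begins here
        elif start is not None:
            segments.append((start, i - 1))
            start = None
    if start is not None:
        # The final run extends to the last residue.
        segments.append((start, len(per_residue) - 1))
    return segments


# Hypothetical per-residue confidences for a six-residue peptide.
scores = [0.95, 0.42, 0.38, 0.91, 0.88, 0.45]
print(low_confidence_segments(scores))  # [(1, 2), (5, 5)]
```

For v1 checkpoints per_residue is an empty list, in which case the function returns an empty list as well.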

See Also