
dataset Module

The dataset module provides tools for orchestrating the large-scale generation of synthetic protein datasets for AI model training.

Overview

Generating diverse, balanced datasets is critical for training robust deep learning models like AlphaFold or RoseTTAFold. The dataset module automates the production of thousands of (Structure, Sequence, Constraint) triplets.

Main Classes

DatasetGenerator

Orchestrates the generation of large-scale synthetic protein datasets for AI model training.

EDUCATIONAL NOTE - The Balanced Dataset Problem:

When training an AI model (like AlphaFold or a Forcefield predictor), the quality and BALANCE of the data are often more important than the quantity.

  1. The Alpha-Helix Trap: If you only generate structures using the 'alpha' preset, your AI will learn that all biology looks like a helix. This leads to "Halls of Mirrors" where the model fails on Beta sheets or intrinsically disordered regions (IDRs).
  2. Mixed Conformations: This generator encourages specifying a mix of 'alpha', 'beta', and 'random' conformations. A dataset that "covers" the Ramachandran plot uniformly ensures the AI learns both the rules and the exceptions of protein geometry.
  3. Structural Diversity: By varying 'length' and 'conformation', we minimize "Selection Bias", making the resulting AI model more robust and generalizable.

Data Factory Overview:

AI models for protein folding (like AlphaFold, RoseTTAFold) require massive datasets of (Structure, Sequence) pairs to learn the patterns of protein physics. Real PDB data is limited (~200k structures). Synthetic data allows us to:

  1. Augment training data with unlimited diversity.
  2. Balance the dataset (e.g., more examples of rare secondary structures).
  3. Create "uncurated" datasets to test model robustness.

This generator produces:

  - PDB files (coordinates)
  - Contact Maps (distance constraints)
  - Metadata Manifest (CSV)
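The mixed-conformation strategy described above can be sketched in a few lines. The presets and weights below mirror those used in DatasetGenerator.generate() (roughly 30% each for 'alpha', 'beta', and 'random', with rare 'ppii' and 'extended' samples):

```python
import random

# Conformation presets and weights as used in DatasetGenerator.generate()
CONFORMATIONS = ["alpha", "beta", "random", "ppii", "extended"]
WEIGHTS = [0.3, 0.3, 0.3, 0.05, 0.05]


def sample_conformations(n: int, seed: int = 42) -> list[str]:
    """Draw n conformation labels with the generator's weighting."""
    rng = random.Random(seed)
    return rng.choices(CONFORMATIONS, weights=WEIGHTS, k=n)


labels = sample_conformations(1000)
# Expect roughly 300 each of alpha/beta/random and ~50 each of ppii/extended
```

Sampling parameters in the main process (as generate() does) keeps the dataset reproducible under a fixed seed even when the heavy per-sample work runs in worker processes.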

Source code in synth_pdb/dataset.py
class DatasetGenerator:
    """Orchestrates the generation of large-scale synthetic protein datasets for AI model training.

    EDUCATIONAL NOTE - The Balanced Dataset Problem:
    -----------------------------------------------
    When training an AI model (like AlphaFold or a Forcefield predictor), the
    quality and BALANCE of the data are often more important than the quantity.

    1. The Alpha-Helix Trap: If you only generate structures using the 'alpha'
       preset, your AI will learn that *all* biology looks like a helix. This
       leads to "Halls of Mirrors" where the model fails on Beta sheets or
       intrinsically disordered regions (IDRs).
    2. Mixed Conformations: This generator encourages specifying a mix of
       'alpha', 'beta', and 'random' conformations. A dataset that "covers"
       the Ramachandran plot uniformly ensures the AI learns both the rules
       and the exceptions of protein geometry.
    3. Structural Diversity: By varying 'length' and 'conformation', we minimize
       "Selection Bias", making the resulting AI model more robust and generalizable.

    Data Factory Overview:
    ----------------------
    AI models for protein folding (like AlphaFold, RoseTTAFold) require massive datasets
    of (Structure, Sequence) pairs to learn the patterns of protein physics.
    Real PDB data is limited (~200k structures). Synthetic data allows us to:
    1. Augment training data with unlimited diversity.
    2. Balance the dataset (e.g., more examples of rare secondary structures).
    3. Create "uncurated" datasets to test model robustness.

    This generator produces:
    - PDB files (coordinates)
    - Contact Maps (distance constraints)
    - Metadata Manifest (CSV)
    """

    def __init__(
        self,
        output_dir: str,
        num_samples: int = 100,
        min_length: int = 10,
        max_length: int = 50,
        train_ratio: float = 0.8,
        seed: Optional[int] = None,
        max_workers: Optional[int] = None,
        dataset_format: str = "pdb",
    ):
        self.output_dir = Path(output_dir).absolute()
        self.num_samples = num_samples
        self.min_length = min_length
        self.max_length = max_length
        self.train_ratio = train_ratio
        self.max_workers = max_workers

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        self.dataset_format = dataset_format.lower() if dataset_format else "pdb"

    def prepare_directories(self) -> None:
        """Create the directory structure for the dataset."""
        train_dir = self.output_dir / "train"
        test_dir = self.output_dir / "test"

        train_dir.mkdir(parents=True, exist_ok=True)
        test_dir.mkdir(parents=True, exist_ok=True)

        # Initialize manifest if it doesn't exist
        manifest_path = self.output_dir / "dataset_manifest.csv"
        if not manifest_path.exists():
            with open(manifest_path, "w", newline="") as f:
                writer = csv.writer(f)
                if self.dataset_format == "npz":
                    writer.writerow(["id", "length", "conformation", "split", "npz_path"])
                else:
                    writer.writerow(
                        ["id", "length", "conformation", "split", "pdb_path", "cmap_path"]
                    )

    def generate(self) -> None:
        """Run the generation loop using multiprocessing."""
        import multiprocessing

        # Determine CPUs
        if self.max_workers is None:
            self.max_workers = max(1, multiprocessing.cpu_count() - 1)

        logger.info(
            f"Starting bulk generation of {self.num_samples} samples using {self.max_workers} cores..."
        )
        self.prepare_directories()

        manifest_path = self.output_dir / "dataset_manifest.csv"

        # Prepare Tasks
        tasks = []
        for i in range(self.num_samples):
            sample_id = f"synth_{i:06d}"

            # 1. Randomize Parameters (in main process for determinism with seed)
            length = random.randint(self.min_length, self.max_length)

            # weighted choice for conformation complexity
            conf_type = random.choices(
                ["alpha", "beta", "random", "ppii", "extended"], weights=[0.3, 0.3, 0.3, 0.05, 0.05]
            )[0]

            is_train = random.random() < self.train_ratio
            split = "train" if is_train else "test"

            # Pass format-specific args
            if self.dataset_format == "npz":
                tasks.append((sample_id, length, conf_type, split, str(self.output_dir), "npz"))
            else:
                tasks.append((sample_id, length, conf_type, split, str(self.output_dir), "pdb"))

        # Execute
        completed_count = 0
        # Determine appropriate task function
        task_func = (
            _generate_single_sample_npz_task
            if self.dataset_format == "npz"
            else _generate_single_sample_task
        )

        with open(manifest_path, "a", newline="") as f:
            writer = csv.writer(f)

            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                # Submit all tasks
                future_to_id = {executor.submit(task_func, task): task[0] for task in tasks}

                for future in concurrent.futures.as_completed(future_to_id):
                    sample_id = future_to_id[future]
                    try:
                        result = future.result()
                        if result["success"]:
                            if self.dataset_format == "npz":
                                writer.writerow(
                                    [
                                        result["sample_id"],
                                        result["length"],
                                        result["conformation"],
                                        result["split"],
                                        result["npz_path"],
                                    ]
                                )
                            else:
                                writer.writerow(
                                    [
                                        result["sample_id"],
                                        result["length"],
                                        result["conformation"],
                                        result["split"],
                                        result["pdb_path"],
                                        result["cmap_path"],
                                    ]
                                )
                            completed_count += 1
                        else:
                            logger.error(f"Failed to generate {sample_id}: {result.get('error')}")

                        # Logging progress
                        if completed_count % 100 == 0:
                            logger.info(
                                f"Progress: {completed_count}/{self.num_samples} ({completed_count / self.num_samples * 100:.1f}%)"
                            )

                    except Exception as exc:
                        logger.error(f"Generate task generated an exception: {exc}")

        logger.info(
            f"Bulk generation complete. Generated {completed_count}/{self.num_samples} samples."
        )

Functions

__init__(output_dir, num_samples=100, min_length=10, max_length=50, train_ratio=0.8, seed=None, max_workers=None, dataset_format='pdb')

Source code in synth_pdb/dataset.py
def __init__(
    self,
    output_dir: str,
    num_samples: int = 100,
    min_length: int = 10,
    max_length: int = 50,
    train_ratio: float = 0.8,
    seed: Optional[int] = None,
    max_workers: Optional[int] = None,
    dataset_format: str = "pdb",
):
    self.output_dir = Path(output_dir).absolute()
    self.num_samples = num_samples
    self.min_length = min_length
    self.max_length = max_length
    self.train_ratio = train_ratio
    self.max_workers = max_workers

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    self.dataset_format = dataset_format.lower() if dataset_format else "pdb"

generate()

Run the generation loop using multiprocessing.

Source code in synth_pdb/dataset.py
def generate(self) -> None:
    """Run the generation loop using multiprocessing."""
    import multiprocessing

    # Determine CPUs
    if self.max_workers is None:
        self.max_workers = max(1, multiprocessing.cpu_count() - 1)

    logger.info(
        f"Starting bulk generation of {self.num_samples} samples using {self.max_workers} cores..."
    )
    self.prepare_directories()

    manifest_path = self.output_dir / "dataset_manifest.csv"

    # Prepare Tasks
    tasks = []
    for i in range(self.num_samples):
        sample_id = f"synth_{i:06d}"

        # 1. Randomize Parameters (in main process for determinism with seed)
        length = random.randint(self.min_length, self.max_length)

        # weighted choice for conformation complexity
        conf_type = random.choices(
            ["alpha", "beta", "random", "ppii", "extended"], weights=[0.3, 0.3, 0.3, 0.05, 0.05]
        )[0]

        is_train = random.random() < self.train_ratio
        split = "train" if is_train else "test"

        # Pass format-specific args
        if self.dataset_format == "npz":
            tasks.append((sample_id, length, conf_type, split, str(self.output_dir), "npz"))
        else:
            tasks.append((sample_id, length, conf_type, split, str(self.output_dir), "pdb"))

    # Execute
    completed_count = 0
    # Determine appropriate task function
    task_func = (
        _generate_single_sample_npz_task
        if self.dataset_format == "npz"
        else _generate_single_sample_task
    )

    with open(manifest_path, "a", newline="") as f:
        writer = csv.writer(f)

        with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_id = {executor.submit(task_func, task): task[0] for task in tasks}

            for future in concurrent.futures.as_completed(future_to_id):
                sample_id = future_to_id[future]
                try:
                    result = future.result()
                    if result["success"]:
                        if self.dataset_format == "npz":
                            writer.writerow(
                                [
                                    result["sample_id"],
                                    result["length"],
                                    result["conformation"],
                                    result["split"],
                                    result["npz_path"],
                                ]
                            )
                        else:
                            writer.writerow(
                                [
                                    result["sample_id"],
                                    result["length"],
                                    result["conformation"],
                                    result["split"],
                                    result["pdb_path"],
                                    result["cmap_path"],
                                ]
                            )
                        completed_count += 1
                    else:
                        logger.error(f"Failed to generate {sample_id}: {result.get('error')}")

                    # Logging progress
                    if completed_count % 100 == 0:
                        logger.info(
                            f"Progress: {completed_count}/{self.num_samples} ({completed_count / self.num_samples * 100:.1f}%)"
                        )

                except Exception as exc:
                    logger.error(f"Generate task generated an exception: {exc}")

    logger.info(
        f"Bulk generation complete. Generated {completed_count}/{self.num_samples} samples."
    )

prepare_directories()

Create the directory structure for the dataset.

Source code in synth_pdb/dataset.py
def prepare_directories(self) -> None:
    """Create the directory structure for the dataset."""
    train_dir = self.output_dir / "train"
    test_dir = self.output_dir / "test"

    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)

    # Initialize manifest if it doesn't exist
    manifest_path = self.output_dir / "dataset_manifest.csv"
    if not manifest_path.exists():
        with open(manifest_path, "w", newline="") as f:
            writer = csv.writer(f)
            if self.dataset_format == "npz":
                writer.writerow(["id", "length", "conformation", "split", "npz_path"])
            else:
                writer.writerow(
                    ["id", "length", "conformation", "split", "pdb_path", "cmap_path"]
                )

Main Functions

_generate_single_sample_task(args)

Helper function that generates a single sample. Arguments are passed as a single tuple so the function is compatible with both map and submit; the tuple is unpacked inside the function for clarity.

_generate_single_sample_npz_task(args)

Generate a single sample in NPZ format (AI-Ready). Does NOT write intermediate PDB files.

Usage Examples

Bulk Dataset Generation

Generate a balanced dataset of 1,000 structures with varied secondary structures and lengths.

from synth_pdb.dataset import DatasetGenerator

generator = DatasetGenerator(
    output_dir="./synthetic_dataset",
    num_samples=1000,
    min_length=30,
    max_length=150,
    train_ratio=0.8,
    max_workers=8
)

generator.generate()

The resulting directory will contain:

  - train/: PDB and CASP (contact map) files for training.
  - test/: PDB and CASP files for testing.
  - dataset_manifest.csv: A manifest mapping IDs to file paths and metadata.
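After generation, the manifest can be parsed to audit class balance and split ratios. A minimal sketch, assuming the PDB-format header written by prepare_directories(); the rows below are illustrative stand-ins, not real output:

```python
import csv
import io
from collections import Counter

# Illustrative manifest contents; a real run writes dataset_manifest.csv
# with the header produced by prepare_directories().
manifest_text = """id,length,conformation,split,pdb_path,cmap_path
synth_000000,42,alpha,train,train/synth_000000.pdb,train/synth_000000.cmap
synth_000001,87,beta,train,train/synth_000001.pdb,train/synth_000001.cmap
synth_000002,55,random,test,test/synth_000002.pdb,test/synth_000002.cmap
"""

rows = list(csv.DictReader(io.StringIO(manifest_text)))
conf_counts = Counter(r["conformation"] for r in rows)   # conformation balance
split_counts = Counter(r["split"] for r in rows)         # train/test ratio
```

For a real dataset, replace the StringIO with `open(output_dir / "dataset_manifest.csv")`; a strongly skewed `conf_counts` is a warning sign of the Alpha-Helix Trap discussed above.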

AI-Ready NPZ Export

For deep learning frameworks, it is often more efficient to store data in compressed NumPy format.

generator = DatasetGenerator(
    output_dir="./ai_dataset",
    dataset_format="npz"
)
generator.generate()
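Samples stored this way can be loaded back with NumPy. The sketch below round-trips a toy sample through the compressed NPZ format; the array names ('coordinates', 'distance_matrix') are assumptions for illustration, so inspect a generated file with `np.load(path).files` to see the actual keys:

```python
import io
import numpy as np

# Toy stand-ins for a generated sample (hypothetical C-alpha coordinates).
coords = np.random.rand(30, 3).astype(np.float32)
dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

# Round-trip through compressed NPZ, as the npz dataset_format would store it.
buf = io.BytesIO()
np.savez_compressed(buf, coordinates=coords, distance_matrix=dists)
buf.seek(0)

sample = np.load(buf)
# sample.files lists the stored array names; arrays are accessed by key.
loaded_coords = sample["coordinates"]     # shape (30, 3)
loaded_dists = sample["distance_matrix"]  # shape (30, 30)
```

Skipping intermediate PDB files and loading arrays directly avoids text parsing in the training loop, which is the main advantage of the NPZ path.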

Educational Notes

The Balanced Dataset Problem

When training AI models, the quality and balance of the data are often more important than the quantity.

  1. The Alpha-Helix Trap: If a dataset only contains helices, the AI will fail to generalize to beta-sheets or disordered regions.
  2. Mixed Conformations: This module encourages a mix of 'alpha', 'beta', and 'random' conformations to ensure the model learns the full breadth of protein geometry.
  3. Structural Diversity: Varying lengths and sequences minimizes "Selection Bias," leading to more robust models.

Why Distance Matrices instead of Binary Contact Maps?

Binary contact maps (0/1) indicate whether atoms are within a threshold (usually 8.0 Å). While common, they discard detailed geometric information. Modern models (like AlphaFold) use Distograms (weighted distance bins) or raw distances to learn a continuous representation of the energy landscape. The dataset module can export exact ground-truth distances to support these advanced training objectives.
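The distinction can be shown in a few lines of NumPy on toy coordinates; the 8.0 Å cutoff matches the threshold mentioned above, and the distogram binning is a generic sketch rather than any particular model's scheme:

```python
import numpy as np

rng = np.random.default_rng(0)
coords = rng.uniform(0.0, 20.0, size=(50, 3))  # toy C-alpha coordinates in Angstroms

# Full pairwise distance matrix: continuous, keeps all geometric detail.
dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

# Binary contact map: 1 if within 8.0 A, 0 otherwise; discards magnitudes.
contact = (dist < 8.0).astype(np.int8)

# Distogram: bin distances into discrete ranges, retaining coarse geometry
# while remaining a classification target (as in AlphaFold-style training).
bin_edges = np.linspace(2.0, 22.0, 64)
distogram = np.digitize(dist, bin_edges)
```

Training against `dist` or `distogram` gives the model a gradient toward the correct distance even when a pair is far from the contact threshold, which a binary `contact` target cannot provide.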

See Also