Skip to content

📐 Geometry API

chain_nerf(init_coords, bond_lengths, bond_angles, dihedrals)

Build a chain of atoms using the NeRF algorithm.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `init_coords` | `ndarray` | (3, 3) initial coordinates for the first 3 atoms | required |
| `bond_lengths` | `ndarray` | (N,) bond lengths for atoms 4 to N+3 | required |
| `bond_angles` | `ndarray` | (N,) bond angles (in radians) for atoms 4 to N+3 | required |
| `dihedrals` | `ndarray` | (N,) dihedral angles (in radians) for atoms 4 to N+3 | required |

Returns:

| Type | Description |
|------|-------------|
| `jnp.ndarray` | (N+3, 3) coordinates for the entire chain |

Source code in diff_biophys/geometry/nerf.py
@jit
def chain_nerf(init_coords: jnp.ndarray, bond_lengths: jnp.ndarray,
               bond_angles: jnp.ndarray, dihedrals: jnp.ndarray) -> jnp.ndarray:
    """
    Build a chain of atoms using the NeRF algorithm.

    Starting from the three seed atoms, every further atom is placed by
    ``position_atom_3d`` from the three atoms immediately preceding it.

    Args:
        init_coords: (3, 3) initial coordinates for the first 3 atoms
        bond_lengths: (N,) bond lengths for atoms 4 to N+3
        bond_angles: (N,) bond angles (in radians) for atoms 4 to N+3
        dihedrals: (N,) dihedral angles (in radians) for atoms 4 to N+3

    Returns:
        jnp.ndarray: (N+3, 3) coordinates for the entire chain

    Note:
        The three per-atom arrays must share the same leading dimension;
        ``lax.scan`` rejects inputs whose leading dimensions differ.
    """
    def place_next(frame, internal_coords):
        # frame holds the three most recently placed atoms, oldest first.
        p1, p2, p3 = frame
        length, angle, dihedral = internal_coords
        p4 = position_atom_3d(p1, p2, p3, length, angle, dihedral)
        # Slide the window forward by one atom and emit the new position.
        return (p2, p3, p4), p4

    seed_frame = (init_coords[0], init_coords[1], init_coords[2])
    # Scan over the per-atom arrays themselves rather than over an index
    # array: an index-based gather would be clamped when out of range, so
    # a too-short input would go unnoticed.
    _, placed = lax.scan(
        place_next, seed_frame, (bond_lengths, bond_angles, dihedrals)
    )

    return jnp.concatenate([init_coords, placed], axis=0)

position_atom_3d(p1, p2, p3, bond_length, bond_angle_rad, dihedral_angle_rad)

Differentiable NeRF implementation in JAX for a single atom.

Source code in diff_biophys/geometry/nerf.py
@jit
def position_atom_3d(p1: jnp.ndarray, p2: jnp.ndarray, p3: jnp.ndarray,
                     bond_length: jnp.ndarray, bond_angle_rad: jnp.ndarray, dihedral_angle_rad: jnp.ndarray) -> jnp.ndarray:
    """
    Place one atom relative to three reference atoms (differentiable NeRF step in JAX).

    A local frame is built from the p2->p3 direction and the normal of the
    p1/p2/p3 plane; the new atom is then offset from p3 inside that frame.

    Args:
        p1: coordinates of the first reference atom
        p2: coordinates of the second reference atom
        p3: coordinates of the third reference atom; the new atom is offset from it
        bond_length: scale of the offset from p3 to the new atom
        bond_angle_rad: angle (in radians) at p3 between p2 and the new atom
        dihedral_angle_rad: torsion (in radians) of p1-p2-p3-new about the
            p2->p3 axis; with a bond angle in (0, pi), a value of 0 puts the
            new atom on the same side of the axis as p1

    Returns:
        jnp.ndarray: coordinates of the newly placed atom
    """
    # Added to each norm so the divisions below never divide by exactly zero.
    eps = 1e-10

    # Unit vector along the p2 -> p3 bond.
    bond_vec = p3 - p2
    axis = bond_vec / (jnp.linalg.norm(bond_vec) + eps)

    # Unit normal of the plane through p1, p2 and p3.
    normal = jnp.cross(p1 - p2, axis)
    normal = normal / (jnp.linalg.norm(normal) + eps)

    # Third frame direction, perpendicular to both axis and normal.
    in_plane = jnp.cross(normal, axis)

    sin_angle = jnp.sin(bond_angle_rad)
    along_axis = -jnp.cos(bond_angle_rad)
    along_in_plane = -sin_angle * jnp.cos(dihedral_angle_rad)
    along_normal = -sin_angle * jnp.sin(dihedral_angle_rad)

    direction = along_axis * axis + along_in_plane * in_plane + along_normal * normal
    return p3 + bond_length * direction

kabsch_alignment(P, Q)

Optimal superposition of P onto Q using Kabsch algorithm in JAX.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `P` | `ndarray` | (N, 3) mobile coordinates | required |
| `Q` | `ndarray` | (N, 3) reference coordinates | required |

Returns:

| Type | Description |
|------|-------------|
| `tuple[jnp.ndarray, jnp.ndarray]` | (3x3 rotation matrix, 3-element translation vector) |

Source code in diff_biophys/geometry/superposition.py
@jit
def kabsch_alignment(P: jnp.ndarray, Q: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Compute the Kabsch superposition of P onto Q in JAX.

    The returned pair is meant to be applied as ``R @ p + t`` to a point
    ``p`` of P (``P @ R.T + t`` for the whole array); by construction the
    centroid of P is mapped onto the centroid of Q.

    Args:
        P: (N, 3) mobile coordinates
        Q: (N, 3) reference coordinates

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: (3x3 rotation matrix, 3-element translation vector)
    """
    centroid_p = P.mean(axis=0)
    centroid_q = Q.mean(axis=0)

    # 3x3 cross-covariance of the centred point sets.
    covariance = (P - centroid_p).T @ (Q - centroid_q)

    # Only the singular vectors are needed; the singular values are discarded.
    U, _, Vt = jnp.linalg.svd(covariance)
    V = Vt.T

    # If V @ U.T has a negative determinant it is a reflection; negating the
    # last singular direction turns the result back into a proper rotation.
    handedness = jnp.where(jnp.linalg.det(V @ U.T) > 0, 1.0, -1.0)
    correction = jnp.diag(jnp.array([1.0, 1.0, handedness]))

    R = V @ (correction @ U.T)
    t = centroid_q - R @ centroid_p

    return R, t