Beyond protein primary structure

Beyond protein primary structure

In this notebook, we show how to evaluate (and visualize) sequences when we also have access to their 3D-structure.

A protein’s 3D structure may be available from an online database, stored as a PDB file. However, since the primary purpose of seqme is to evaluate novel proteins, we assume that the 3D structure is not available and instead predict it with ESMFold, which is included in seqme.

# !pip install "seqme[esmfold]" BioPython py3Dmol
from collections.abc import Callable
from functools import partial
from typing import Literal

import numpy as np
import py3Dmol
from Bio.Align import PairwiseAligner

import seqme as sm

Hide code cell content

def _kabsch_transform(ref: np.ndarray, mob: np.ndarray) -> np.ndarray:
    """Rigidly superpose ``mob`` onto ``ref`` with the Kabsch algorithm.

    Finds the proper rotation (never a reflection) and the translation that
    minimise the RMSD between the transformed ``mob`` and ``ref``, then
    applies them to ``mob``. Points are paired by row: row ``i`` of ``mob``
    is matched with row ``i`` of ``ref``.

    Args:
        ref: Reference coordinates, shape ``(N, D)`` (``D == 3`` for atoms).
        mob: Mobile coordinates to superpose onto ``ref``, same shape as ``ref``.

    Returns:
        Superposed coordinates, shape ``(N, D)``.

    Raises:
        ValueError: If ``ref`` and ``mob`` are not 2-D arrays of the same shape.
    """
    if ref.ndim != 2 or ref.shape != mob.shape:
        raise ValueError(
            f"ref and mob must be 2-D arrays of the same shape, got {ref.shape} and {mob.shape}"
        )

    ref_center = ref.mean(axis=0)
    mob_center = mob.mean(axis=0)
    r = ref - ref_center
    m = mob - mob_center

    # Cross-covariance H = m^T r. With H = U S V^T, the optimal rotation is
    # R = V D U^T, which maps each centred mobile point onto its reference.
    U, _, Vt = np.linalg.svd(m.T @ r)

    # D flips the last axis when V U^T would be a reflection (det == -1).
    # Taking the sign keeps D exactly +/-1 rather than a float close to it.
    D = np.eye(ref.shape[1])
    D[-1, -1] = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ D @ U.T

    return (m @ R.T) + ref_center


def rmsd(a: np.ndarray, b: np.ndarray) -> float:
    """Root-mean-square deviation between two row-paired point sets.

    Args:
        a: Coordinates, shape ``(N, D)``; row ``i`` is paired with row ``i`` of ``b``.
        b: Coordinates with the same shape as ``a``.

    Returns:
        ``sqrt(sum_i ||a_i - b_i||^2 / N)`` as a Python float.

    Raises:
        ValueError: If the shapes differ (NumPy would otherwise broadcast them
            silently into a meaningless value) or if there are no points.
    """
    if a.shape != b.shape:
        raise ValueError(f"a and b must have the same shape, got {a.shape} and {b.shape}")
    if a.shape[0] == 0:
        raise ValueError("cannot compute the RMSD of zero points")
    return float(np.sqrt(((a - b) ** 2).sum() / a.shape[0]))


def indices(s1: str, s2: str) -> np.ndarray:
    res = []
    i = 0
    for c1, c2 in zip(s1, s2, strict=True):
        if c1 != "-":
            if c2 != "-":
                res.append(i)
            i += 1
    return np.array(res, dtype=np.int32)


def compute_rmsd(coords1: np.ndarray, coords2: np.ndarray, seq1: str, seq2: str) -> float:
    """RMSD between two structures over the residues paired by a sequence alignment.

    ``seq1`` and ``seq2`` are globally aligned with Biopython's
    ``PairwiseAligner`` using its default scoring (no substitution matrix or
    gap penalties are configured here). Only the first optimal alignment is
    used. Coordinates of residue pairs aligned to each other are kept,
    ``coords2`` is rigidly superposed onto ``coords1``, and their RMSD is
    returned.

    Args:
        coords1: Per-residue coordinates of ``seq1`` (one row per residue).
        coords2: Per-residue coordinates of ``seq2`` (one row per residue).
        seq1: Sequence corresponding to ``coords1``.
        seq2: Sequence corresponding to ``coords2``.
    """
    alignment = PairwiseAligner().align(seq1, seq2)[0]
    gapped1, gapped2 = alignment[0], alignment[1]

    # Keep only the rows of residues that face a residue (not a gap) in the other sequence.
    matched1 = coords1[indices(gapped1, gapped2)]
    matched2 = coords2[indices(gapped2, gapped1)]

    superposed2 = _kabsch_transform(matched1, matched2)
    return rmsd(matched1, superposed2)

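Let’s define a metric that uses atomic positions: the root-mean-square deviation (RMSD) between a sequence’s predicted structure and a reference structure, computed over sequence-aligned residues after optimal rigid-body superposition.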

class RMSD(sm.Metric):
    """Root mean square deviation of atomic positions.

    Every sequence's structure (as produced by ``folder``) is compared with the
    structure of ``reference`` via ``compute_rmsd``. The metric value is the
    mean RMSD over all sequences; lower is better.

    Args:
        reference: Sequence whose structure serves as the reference.
        folder: Maps a list of sequences to their per-residue coordinates,
            one coordinate array per sequence.
    """

    def __init__(self, reference: str, folder: Callable[[list[str]], np.ndarray]):
        self.reference = reference
        self.folder = folder

    def __call__(self, sequences: list[str]) -> sm.MetricResult:
        """Return the mean RMSD of ``sequences`` to the reference structure."""
        reference_coords = self.folder([self.reference])[0]
        per_sequence = np.fromiter(
            (
                compute_rmsd(coords, reference_coords, sequence, self.reference)
                for sequence, coords in zip(sequences, self.folder(sequences), strict=True)
            ),
            dtype=float,
        )
        return sm.MetricResult(per_sequence.mean().item())

    @property
    def name(self) -> str:
        """Display name of the metric."""
        return "RMSD"

    @property
    def objective(self) -> Literal["minimize", "maximize"]:
        """Lower RMSD means closer to the reference structure."""
        return "minimize"

Let’s define our protein folding model.

# Register ESMFold with the cache under the key "esm-fold". Folds are
# requested in the atom37 convention, with the pTM score computed.
cache = sm.Cache(
    models={
        "esm-fold": partial(
            sm.models.ESMFold().fold,
            convention="atom37",
            compute_ptm=True,
        ),
    },
)

Note: to compute the RMSD metric, we only extract the positions of the Cα atoms (one per amino acid) from the fold prediction.

# Cached ESMFold. The helpers below iterate its output one fold per
# sequence; each fold is indexed like a dict ("ptm", "positions", "plddt", "pae").
esm_fold = cache.model("esm-fold", stack=False)


def ptm_fn(sequences: list[str]) -> np.ndarray:
    """Predicted TM-score (pTM) of each sequence's ESMFold prediction."""
    return np.array([fold["ptm"] for fold in esm_fold(sequences)])


def positions_fn(sequences: list[str]) -> list:
    """Cα coordinates of each sequence's ESMFold prediction.

    ``positions`` is indexed ``[residue, atom, xyz]``. Atom index 1 is CA in
    the atom37 ordering (N, CA, C, ...).
    """
    return [fold["positions"][:, 1, :] for fold in esm_fold(sequences)]


def plddt_fn(sequences: list[str]) -> np.ndarray:
    """Mean pLDDT of each sequence's ESMFold prediction."""
    return np.array([fold["plddt"].mean() for fold in esm_fold(sequences)])


def pae_fn(sequences: list[str]) -> np.ndarray:
    """Mean predicted aligned error (pAE) of each sequence's ESMFold prediction."""
    return np.array([fold["pae"].mean() for fold in esm_fold(sequences)])

Let’s also compute the self-consistency perplexity (scPerplexity) and sequence recovery. To do so, we need a folding model (again ESMFold) and an inverse-folding model (ESM-IF1).

ESM-IF1 expects the coordinates of each amino acid’s backbone atoms (N, CA and C), so below we first extract those from ESMFold’s predictions.

# Protein folding: ESM-IF1 consumes backbone coordinates, so keep only the
# N, CA and C atoms (atom37 indices 0, 1, 2) of each ESMFold prediction.
atom_indices = [0, 1, 2]  # atoms: N, CA, C


def folder(sequences: list[str]) -> list:
    """Backbone (N, CA, C) coordinates of each sequence's ESMFold prediction."""
    return [fold["positions"][:, atom_indices, :] for fold in esm_fold(sequences)]


# Inverse protein folding with ESM-IF1. Both entry points come from the same
# third-party checkout, so its location is defined once.
ESMIF1_PATH = "../thirdparty/esmif1"
ESMIF1_URL = "https://github.com/szczurek-lab/seqme-esmif1"

inv_perplexity = sm.models.ThirdPartyModel(
    entry_point="esmif1.model:compute_perplexity",
    path=ESMIF1_PATH,
    url=ESMIF1_URL,
)
cache.add("scPerplexity", lambda sequences: inv_perplexity(folder(sequences), sequences))

inv_recovery = sm.models.ThirdPartyModel(
    entry_point="esmif1.model:compute_sequence_recovery",
    path=ESMIF1_PATH,
    url=ESMIF1_URL,
)
cache.add("recovery", lambda sequences: inv_recovery(folder(sequences), sequences))

Notice that ESMFold is stored in the cache and reused for pTM, pLDDT, pAE, positions, scPerplexity and sequence recovery. Hence, ESMFold is called only once per sequence in total!

Let’s create the metrics and sequences. We will evaluate four sequences: three natural proteins and a shuffled version of GNAT1_HUMAN as a negative control.

# Evaluation sets: each key names a set and maps to a list of amino-acid
# sequences (one sequence per set here; keys look like UniProt entry names).
sequences = {
    "SAV_STRAV": [
        "MRKIVVAAIAVSLTTVSITASASADPSKDSKAQVSAAEAGITGTWYNQLGSTFIVTAGADGALTGTYESAVGNAESRYVLTGRYDSAPATDGSGTALGWTVAWKNNYRNAHSATTWSGQYVGGAEARINTQWLLTSGTTEANAWKSTLVGHDTFTKVKPSAASIDAAKKAGVNNGNPLDAVQQ"
    ],
    "AVID_CHICK": [
        "MVHATSPLLLLLLLSLALVAPGLSARKCSLTGKWTNDLGSNMTIGAVNSRGEFTGTYITAVTATSNEIKESPLHGTQNTINKRTQPTFGFTVNWKFSESTTVFTGQCFIDRNGKEVLKTMWLLRSSVNDIGDDWKATRVGINIFTRLRTQKE"
    ],
    "GNAT1_HUMAN": [
        "MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF"
    ],
}

# Control set: shuffle_characters presumably permutes the residues (same
# composition, scrambled order).
# NOTE(review): no seed is passed here, so the shuffled sequence may differ
# between runs unless shuffle_characters seeds itself — confirm.
sequences["GNAT1_HUMAN (shuffled)"] = sm.utils.shuffle_characters(sequences["GNAT1_HUMAN"])

# sm.metrics.ID turns a per-sequence predictor into a named metric with the
# given objective. RMSD compares each sequence's Cα positions against the
# predicted structure of SAV_STRAV.
metrics = [
    sm.metrics.ID(predictor=cache.model("recovery"), name="Sequence recovery", objective="maximize"),
    sm.metrics.ID(predictor=cache.model("scPerplexity"), name="scPerplexity", objective="minimize"),
    RMSD(reference=sequences["SAV_STRAV"][0], folder=positions_fn),
    sm.metrics.ID(predictor=ptm_fn, name="pTM", objective="maximize"),
    sm.metrics.ID(predictor=plddt_fn, name="pLDDT", objective="maximize"),
    sm.metrics.ID(predictor=pae_fn, name="pAE", objective="minimize"),
]

Let’s compute the metrics.

# Evaluate every metric on every sequence set (4 sets x 6 metrics = 24 steps);
# the resulting table is stored in df.
df = sm.evaluate(sequences, metrics)
100%|██████████| 24/24 [12:47<00:00, 31.99s/it, data=GNAT1_HUMAN (shuffled), metric=pAE]              
# Render the results table with bar styling; arrows in the header mark each metric's objective.
sm.show(df, color_style="bar")
  Sequence recovery↑ scPerplexity↓ RMSD↓ pTM↑ pLDDT↑ pAE↓
SAV_STRAV 0.45 4.38 0.00 0.74 0.82 11.12
AVID_CHICK 0.37 5.20 16.74 0.71 0.79 10.74
GNAT1_HUMAN 0.57 2.74 21.89 0.95 0.97 2.66
GNAT1_HUMAN (shuffled) 0.10 13.88 21.96 0.16 0.27 26.84

Let’s visualize the 3D-structure of a protein.

def visualize_structure(
    pdb: str,
    min_plddt: float = 0.0,
    max_plddt: float = 1.0,
    rotation: float = 0,
    zoom: float | None = None,
    width: int = 400,
    height: int = 400,
) -> None:
    """Visualize 3D-structure and color amino-acids by confidence (plddt).

    The cartoon is colored from the PDB B-factor column, which is where
    ESMFold writes its per-atom pLDDT. A "roygb" gradient spans
    ``min_plddt`` (red end) to ``max_plddt`` (blue end).

    Args:
        pdb: Structure in PDB text format.
        min_plddt: Value mapped to the low end of the color gradient.
        max_plddt: Value mapped to the high end of the color gradient.
        rotation: Angle passed to ``view.rotate`` (degrees in 3Dmol.js).
        zoom: Zoom factor passed to ``view.zoom``; if ``None``, zoom to fit the model.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
    """
    view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js", width=width, height=height)
    view.addModel(pdb, "pdb")
    view.setStyle({"cartoon": {"colorscheme": {"prop": "b", "gradient": "roygb", "min": min_plddt, "max": max_plddt}}})
    view.rotate(rotation)
    # Use the explicit zoom factor if one was given; otherwise fit the whole model.
    if zoom is None:
        view.zoomTo()
    else:
        view.zoom(zoom)
    view.show()
# Structure to display.
name = "GNAT1_HUMAN"

# Fold every evaluation set. esm_fold is cached, so sets evaluated above are
# not re-folded. The comprehension uses its own variable name so it does not
# read like it overwrites `name`.
folds = {set_name: esm_fold(seqs) for set_name, seqs in sequences.items()}
pdb = folds[name][0]["pdb"]

visualize_structure(pdb, rotation=90, min_plddt=0.8)
print(f"{pdb[:600]} \t ... etc.")
PARENT N/A
ATOM      1  N   MET A   1       6.889  17.156  65.667  1.00  0.84           N  
ATOM      2  CA  MET A   1       5.656  16.625  65.093  1.00  0.87           C  
ATOM      3  C   MET A   1       5.435  17.166  63.684  1.00  0.85           C  
ATOM      4  CB  MET A   1       4.458  16.969  65.979  1.00  0.79           C  
ATOM      5  O   MET A   1       5.065  18.329  63.512  1.00  0.74           O  
ATOM      6  CG  MET A   1       4.479  16.279  67.333  1.00  0.72           C  
ATOM      7  SD  MET A   1       3.037  16.730  68.375  1.00  0.77           S  
ATOM      8  CE  MET A 	 ... etc.