# Source code for seqme.models.rna_fm
from typing import Literal
import numpy as np
import torch
from tqdm import tqdm
from .exceptions import OptionalDependencyError
# [docs]
class RNAFM:
"""
A language model trained on RNA sequences, which computes sequence-level embeddings by averaging token embeddings.
Two checkpoints are available:
- mRNA: 239M parameters, 12 layers, embedding dim 1280, trained on 45 million mRNA coding sequences (CDS). Must be codon aligned.
- ncRNA: 99M parameters, 12 layers, embedding dim 640, trained on 23.7 million non-coding RNA (ncRNA) sequences.
Installation: ``pip install "seqme[rnafm]"``
Reference:
Chen et al., "Interpretable RNA Foundation Model from Unannotated Data for Highly Accurate RNA Structure and Function Predictions"
(https://arxiv.org/pdf/2204.00300)
"""
[docs]
def __init__(
self,
*,
model_name: Literal["mRNA", "ncRNA"] = "mRNA",
device: str | None = None,
batch_size: int = 256,
verbose: bool = False,
):
"""
Initialize model.
Args:
model_name: Either a mRNA or ncRNA checkpoint.
device: Device to run inference on, e.g., ``"cuda"`` or ``"cpu"``.
batch_size: Number of sequences to process per batch.
verbose: Whether to display a progress bar.
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_name = model_name
self.batch_size = batch_size
self.device = device
self.verbose = verbose
try:
import fm
except ModuleNotFoundError:
raise OptionalDependencyError("rnafm") from None
model, alphabet = fm.pretrained.rna_fm_t12() if self.model_name == "ncRNA" else fm.pretrained.mrna_fm_t12()
batch_converter = alphabet.get_batch_converter()
self.model = model.to(device).eval()
self.batch_converter = batch_converter
[docs]
def __call__(self, sequences: list[str]) -> np.ndarray:
return self.embed(sequences)
@torch.inference_mode()
def embed(self, sequences: list[str], layer: int = 12) -> np.ndarray:
"""
Compute embeddings for the RNA sequences.
Each sequence is tokenized and passed through the model.
Token embeddings are averaged to produce a single embedding per sequence.
Args:
sequences: RNA sequences to embed.
layer: Embedding layer. Last layer is 12.
Returns:
A NumPy array of shape (n_sequences, embedding_dim) containing the embeddings.
"""
if self.model_name == "mRNA":
for sequence in sequences:
if len(sequence) % 3 != 0:
raise ValueError(f"Found non-codon aligned sequence with {len(sequence)} nucleotides.")
embeddings = []
for i in tqdm(range(0, len(sequences), self.batch_size), disable=not self.verbose):
batch = sequences[i : i + self.batch_size]
named_batch = [("", b) for b in batch]
tokens = self.batch_converter(named_batch)[2].to(self.device)
results = self.model(tokens, repr_layers=[layer])
hidden_state = results["representations"][layer]
lengths = [len(s) // 3 if self.model_name == "mRNA" else len(s) for s in batch]
means = [hidden_state[i, 1 : length + 1].mean(dim=-2) for i, length in enumerate(lengths)]
batch_embeddings = torch.stack(means, dim=0)
embeddings.append(batch_embeddings.cpu().numpy())
return np.concatenate(embeddings)