# Source code for seqme.models.gena_lm

from enum import Enum

import numpy as np
import torch
from tqdm import tqdm

from .exceptions import OptionalDependencyError


class GENALMCheckpoint(Enum):
    """
    GENA-LM checkpoints.

    Each member's value is a ``(checkpoint name, branch name)`` pair. The branch name is
    ``None`` for embedding checkpoints and names the fine-tuned branch for downstream
    classification checkpoints.

    Embedding checkpoints:
        - bert_base_t2t: 110M parameters, 12 layers, embedding dim: 768, max sequence length: 512bp - Trained on T2T+1000G SNPs.
        - bert_base_t2t_lastln_t2t: 110M parameters, 12 layers, embedding dim: 768, max sequence length: 512bp - Trained on T2T+1000G SNPs.
        - bert_base_t2t_multi: 110M parameters, 12 layers, embedding dim: 768, max sequence length: 512bp - Trained on T2T+1000G SNPs+Multispecies.
        - bert_large_t2t: 336M parameters, 24 layers, embedding dim: 1024, max sequence length: 512bp - Trained on T2T+1000G SNPs.
        - bigbird_base_t2t: 110M parameters, 12 layers, embedding dim: 768, max sequence length: 4096bp - Trained on T2T+1000G SNPs.

        Note: In practice a model listed with M layers has M+1 layers. The last layer is a LayerNorm.

    Downstream classification checkpoints:
        - bert_base_t2t_promoters: 110M parameters, 12 layers, task sequence length: 300bp.
            Classification: determining the absence (0) or presence (1) of a promoter within a given region.

        - bert_large_t2t_promoters: 336M parameters, 24 layers, task sequence length: 300bp.
            Classification: determining the absence (0) or presence (1) of a promoter within a given region.

        - bert_large_t2t_promoters2: 336M parameters, 24 layers, task sequence length: 2000bp.
            Classification: determining the absence (0) or presence (1) of a promoter within a given region.

        - bert_base_t2t_splice_site: 110M parameters, 12 layers, task sequence length: 15000bp. Identifies splicing sites.
            Classification: neither (0), splice acceptor (1) or splice donor (2).

        - bert_large_t2t_splice_site: 336M parameters, 24 layers, task sequence length: 15000bp. Identifies splicing sites.
            Classification: neither (0), splice acceptor (1) or splice donor (2).
    """

    # Embedding (no branch: the base checkpoint is used as-is)
    bert_base_t2t = ("AIRI-Institute/gena-lm-bert-base-t2t", None)
    bert_base_t2t_lastln_t2t = ("AIRI-Institute/gena-lm-bert-base-lastln-t2t", None)
    bert_base_t2t_multi = ("AIRI-Institute/gena-lm-bert-base-t2t-multi", None)
    bert_large_t2t = ("AIRI-Institute/gena-lm-bert-large-t2t", None)
    bigbird_base_t2t = ("AIRI-Institute/gena-lm-bigbird-base-t2t", None)

    # Downstream classification (second element selects the fine-tuned branch)
    bert_base_t2t_promoters = ("AIRI-Institute/gena-lm-bert-base-t2t", "promoters_300_run_1")
    bert_large_t2t_promoters = ("AIRI-Institute/gena-lm-bert-large-t2t", "promoters_300_run_1")
    bert_large_t2t_promoters2 = ("AIRI-Institute/gena-lm-bert-large-t2t", "promoters_2000_run_1")

    bert_base_t2t_splice_site = ("AIRI-Institute/gena-lm-bert-base-t2t", "spliceai_run_1")
    bert_large_t2t_splice_site = ("AIRI-Institute/gena-lm-bert-large-t2t", "spliceai_run_1")


class Task(Enum):
    """Kind of inference a GENA-LM checkpoint is used for."""

    # Checkpoint produces sequence embeddings.
    EMBEDDING = "embedding"
    # Checkpoint produces class probabilities.
    CLASSIFICATION = "classification"


# Task served by each checkpoint. Checkpoints without a fine-tuned branch
# (second tuple element is ``None``) are embedding models; every checkpoint
# that names a branch is a downstream classifier.
_TASK = {
    checkpoint: Task.EMBEDDING if checkpoint.value[1] is None else Task.CLASSIFICATION
    for checkpoint in GENALMCheckpoint
}


class GENALM:
    """
    GENA-LM is a family of Open-Source Foundational Models for Long DNA Sequences trained on human DNA sequence.

    Embedding checkpoints compute sequence-level embeddings by averaging token embeddings.
    Classification checkpoints return per-class probabilities.

    Installation: ``pip install "seqme[genalm]"``

    Reference:
        Fishman et al., "GENA-LM: a family of open-source foundational DNA language models for long sequences"
        (https://academic.oup.com/nar/article/53/2/gkae1310/7954523)
    """

    def __init__(
        self,
        model_name: GENALMCheckpoint,
        *,
        device: str | None = None,
        batch_size: int = 256,
        cache_dir: str | None = None,
        verbose: bool = False,
    ):
        """
        Initialize model.

        Args:
            model_name: Model checkpoint name.
            device: Device to run inference on, e.g., ``"cuda"`` or ``"cpu"``. If ``None``, CUDA is used when available, otherwise the CPU.
            batch_size: Number of sequences to process per batch.
            cache_dir: Directory to cache the tokenizer and the model.
            verbose: Whether to display a progress bar.

        Raises:
            OptionalDependencyError: If ``transformers`` is not installed.
            ValueError: If the checkpoint maps to an unknown task.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.batch_size = batch_size
        self.device = device
        self.verbose = verbose

        # Imported lazily so the package can be used without the optional dependency.
        try:
            from transformers import AutoModel, AutoTokenizer, BertForSequenceClassification
        except ModuleNotFoundError:
            raise OptionalDependencyError("genalm") from None

        self.task = _TASK[model_name]
        ckpt_name, branch_name = model_name.value

        # ``cache_dir`` is forwarded to every download so all artifacts land in the
        # requested directory, not only the classification weights.
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name, cache_dir=cache_dir)

        if self.task == Task.EMBEDDING:
            self.model = AutoModel.from_pretrained(
                ckpt_name,
                trust_remote_code=True,
                output_hidden_states=True,
                return_dict_in_generate=True,
                cache_dir=cache_dir,
            )
        elif self.task == Task.CLASSIFICATION:
            # Fine-tuned heads live on a separate branch of the base checkpoint.
            self.model = BertForSequenceClassification.from_pretrained(
                ckpt_name, revision=branch_name, trust_remote_code=True, cache_dir=cache_dir
            )
        else:
            raise ValueError(f"Invalid task: {self.task}.")

        self.model.to(device)
        self.model.eval()

    def __call__(self, sequences: list[str]) -> np.ndarray:
        """Embed or classify ``sequences``, depending on the task of the loaded checkpoint."""
        return self.embed(sequences) if self.task == Task.EMBEDDING else self.classify(sequences)

    def _tokenized_batches(self, sequences: list[str], **tokenizer_kwargs):
        """Yield tokenizer outputs for consecutive batches, with tensors moved to ``self.device``."""
        for start in tqdm(range(0, len(sequences), self.batch_size), disable=not self.verbose):
            batch = sequences[start : start + self.batch_size]
            tokens = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=False, **tokenizer_kwargs)
            yield {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in tokens.items()}

    @torch.inference_mode()
    def embed(self, sequences: list[str], layer: int = -1) -> np.ndarray:
        """
        Compute embeddings for a list of sequences.

        Each sequence is tokenized and passed through the model. Token embeddings are averaged
        (padding positions are masked out) to produce a single embedding per sequence.

        Args:
            sequences: List of DNA sequences.
            layer: Index into the model's hidden states to read embeddings from.

        Returns:
            A NumPy array of shape (n_sequences, embedding_dim) containing the embeddings.

        Raises:
            ValueError: If the loaded checkpoint is not an embedding model.
        """
        if self.task != Task.EMBEDDING:
            raise ValueError(f"Expected embedding model got {self.task} model.")

        # ``np.concatenate`` rejects an empty list, so return an empty result explicitly.
        # ``len`` is used instead of truthiness so array-like inputs also work.
        if len(sequences) == 0:
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        embeddings = []
        for tokens in self._tokenized_batches(sequences):
            hidden_state = self.model(**tokens)["hidden_states"][layer]

            # Masked mean over the token axis: padded positions contribute nothing.
            attention_mask = tokens["attention_mask"].unsqueeze(-1)
            attention_sum = attention_mask.sum(dim=1)
            batch_embeddings = torch.sum(hidden_state * attention_mask, dim=1) / attention_sum

            embeddings.append(batch_embeddings.cpu().numpy())

        return np.concatenate(embeddings)

    @torch.inference_mode()
    def classify(self, sequences: list[str]) -> np.ndarray:
        """
        Classify a list of sequences.

        Args:
            sequences: List of DNA sequences.

        Returns:
            A NumPy array of class probabilities of size (n_sequences, 2) for promoter prediction
            and (n_sequences, 3) for splice-site prediction.

        Raises:
            ValueError: If the loaded checkpoint is not a classification model.
        """
        if self.task != Task.CLASSIFICATION:
            raise ValueError(f"Expected classification model got {self.task} model.")

        # ``np.concatenate`` rejects an empty list, so return an empty result explicitly.
        if len(sequences) == 0:
            return np.empty((0, self.model.config.num_labels), dtype=np.float32)

        probs = []
        for tokens in self._tokenized_batches(sequences, add_special_tokens=True):
            logits = self.model(**tokens)["logits"]
            batch_prob = torch.softmax(logits, dim=-1)
            probs.append(batch_prob.cpu().numpy())

        return np.concatenate(probs)