Source code for seqme.models.hyformer

from enum import Enum

import numpy as np
import torch
from packaging.version import Version
from tqdm import tqdm

from .exceptions import OptionalDependencyError

# Maximum sequence length (in tokens) used by ``Hyformer`` for hyformer >= 2.0; older hyformer
# versions use the tokenizer's ``max_molecule_length`` instead. Generation subtracts the prefix
# length from this value to obtain ``num_tokens_to_generate``.
_MAX_SEQUENCE_LENGTH = 512


class HyformerCheckpoint(str, Enum):
    """
    Hyformer checkpoints from Izdebski et al.

    Available checkpoints:
        - molecules_8M: 8M parameters, 8 layers, embedding dim 256, pretrained on GuacaMol dataset [Brown et al.]
        - molecules_50M: 50M parameters, 12 layers, embedding dim 512, pretrained on Uni-Mol dataset [Zhou et al.]
        - peptides_34M: 34M parameters, 8 layers, embedding dim 512, pretrained on combined general-purpose peptide and AMP datasets [Izdebski et al.]
        - peptides_34M_mic: 34M parameters, 8 layers, embedding dim 512, pretrained on combined general-purpose peptide and MIC datasets [Izdebski et al.] and subsequently jointly fine-tuned on peptides with (log2 transformed) MIC values against E. coli bacteria [Szymczak et al.]

    If used for prediction, the pre-trained models (`molecules_8M`, `molecules_50M`, and `peptides_34M`) predict the physicochemical properties used for pre-training.
    The jointly fine-tuned model `peptides_34M_mic` predicts log2-transformed MIC values against E. coli bacteria.

    Each member's value is the checkpoint identifier that ``Hyformer`` forwards to hyformer's
    ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``. Because the enum mixes in
    ``str``, every member also compares equal to its plain-string identifier.

    Reference:
        Izdebski et al. "Synergistic Benefits of Joint Molecule Generation and Property Prediction"
        Brown et al. "GuacaMol: benchmarking models for de novo molecular design"
        Zhou et al. "Uni-mol: A universal 3d molecular representation learning framework"
        Szymczak et al. "Discovering highly potent antimicrobial peptides with deep generative model hydramp"
    """

    # Small-molecule checkpoints.
    molecules_8M = "SzczurekLab/hyformer_molecules_8M"
    molecules_50M = "SzczurekLab/hyformer_molecules_50M"

    # Peptide checkpoints (the `_mic` variant is fine-tuned for MIC regression).
    peptides_34M = "SzczurekLab/hyformer_peptides_34M"
    peptides_34M_mic = "SzczurekLab/hyformer_peptides_34M_mic"


class Hyformer:
    """
    Wrapper for the Hyformer molecule/peptide model.

    Provides sequence-level embeddings (the [CLS] token embedding), property prediction,
    per-sequence perplexity, and de novo generation on top of a pretrained Hyformer checkpoint.

    Installation for molecules: ``pip install "seqme[hyformer_molecules]" "hyformer @ git+https://github.com/szczurek-lab/hyformer.git@main"``

    Installation for peptides: ``pip install "seqme[hyformer]" "hyformer @ git+https://github.com/szczurek-lab/hyformer.git@v2.0"``.

    Reference:
        Izdebski et al., "Synergistic Benefits of Joint Molecule Generation and Property Prediction"
        (https://arxiv.org/abs/2504.16559)
    """

    def __init__(
        self,
        model_name: HyformerCheckpoint | str,
        *,
        device: str | None = None,
        batch_size: int = 256,
        cache_dir: str | None = None,
        verbose: bool = False,
    ):
        """
        Initialize Hyformer model.

        Args:
            model_name: Checkpoint enum member, or checkpoint identifier string passed to hyformer's
                ``from_pretrained``.
            device: Device to run inference on, e.g., ``"cuda"`` or ``"cpu"``. Defaults to CUDA when
                available, otherwise CPU.
            batch_size: Number of sequences to process per batch. Must be at least 1.
            cache_dir: Directory to cache the model (forwarded as ``local_dir``).
            verbose: Whether to display a progress bar.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            OptionalDependencyError: If the ``hyformer`` package is not installed.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}.")

        if isinstance(model_name, HyformerCheckpoint):
            model_name = model_name.value

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.batch_size = batch_size
        self.device = device
        self.verbose = verbose

        try:
            from hyformer import AutoModel, AutoTokenizer
            from hyformer import __version__ as hyformer_version
            from hyformer.utils import create_dataloader
        except ModuleNotFoundError:
            raise OptionalDependencyError("hyformer") from None

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_name, local_dir=cache_dir)

        # Hyformer 2.x renamed the task keys and model output keys; resolve them once here.
        self._version = Version(hyformer_version)
        legacy = self._is_legacy
        self._create_dataloader_fn = create_dataloader
        self._generative_task_key = "generation" if legacy else "lm"
        self._predictive_task_key = "prediction"
        self._max_sequence_length = self.tokenizer.max_molecule_length if legacy else _MAX_SEQUENCE_LENGTH
        self._logits_generation_key = "logits_generation" if legacy else "logits"
        self._logits_prediction_key = "logits_physchem" if legacy else "logits"

        self.model.to(device)
        self.model.eval()

    @property
    def _is_legacy(self) -> bool:
        """Whether the installed hyformer predates the 2.0 API."""
        return self._version < Version("2.0.0")

    def __call__(self, sequences: list[str]) -> np.ndarray:
        """Alias for :meth:`embed`."""
        return self.embed(sequences)

    def generate(
        self, num_samples: int, temperature: float = 1.0, top_k: int | None = None, seed: int = 0
    ) -> list[str]:
        """Generate sequences de novo.

        Delegates to the legacy generation path for Hyformer versions prior to 2.0, otherwise uses
        the newer generation API.

        Args:
            num_samples: Number of sequences to produce.
            temperature: Sampling temperature passed to the decoder.
            top_k: Optional top-k sampling parameter.
            seed: Base random seed. The 2.x path forwards ``seed + i`` for the ``i``-th batch so
                that batches differ; the legacy (<2.0) path does not use it, because its
                ``generate`` call takes no seed.

        Returns:
            A list of generated sequences, truncated to ``num_samples`` items (empty when
            ``num_samples <= 0``).
        """
        if self._is_legacy:
            return self._generate_legacy(num_samples, temperature, top_k, seed)
        return self._generate(num_samples, temperature, top_k, seed)

    def _batch_sizes(self, num_samples: int) -> list[int]:
        """Split ``num_samples`` into consecutive chunk sizes of at most ``self.batch_size``."""
        return [min(self.batch_size, num_samples - start) for start in range(0, num_samples, self.batch_size)]

    def _generate_legacy(
        self, num_samples: int, temperature: float = 1.0, top_k: int | None = None, seed: int = 0
    ) -> list[str]:
        """Generation path for hyformer < 2.0 (``seed`` is unused; see :meth:`generate`)."""
        generated: list[str] = []
        for chunk_size in tqdm(self._batch_sizes(num_samples), desc="Generating samples", disable=not self.verbose):
            samples = self.model.generate(self.tokenizer, chunk_size, temperature, top_k, self.device)
            generated.extend(self.tokenizer.decode(samples))
        return generated[:num_samples]

    def _generate(
        self, num_samples: int, temperature: float = 1.0, top_k: int | None = None, seed: int = 0
    ) -> list[str]:
        """Generation path for hyformer >= 2.0: sample continuations of a [task, BOS] prefix."""
        prefix_row = [self.tokenizer.task_token_id(self._generative_task_key), self.tokenizer.bos_token_id]
        num_tokens_to_generate = self._max_sequence_length - len(prefix_row)

        generated: list[str] = []
        with torch.inference_mode():
            chunks = tqdm(self._batch_sizes(num_samples), desc="Generating samples", disable=not self.verbose)
            for batch_idx, chunk_size in enumerate(chunks):
                prefix_input_ids = torch.tensor([prefix_row] * chunk_size, dtype=torch.long, device=self.device)
                outputs = self.model.generate(
                    prefix_input_ids=prefix_input_ids,
                    num_tokens_to_generate=num_tokens_to_generate,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=None,
                    use_cache=False,
                    # Offset the seed per batch. ``generate`` is called once per batch, so if it
                    # reseeds its sampler from ``seed`` (as the parameter suggests), a shared seed
                    # would repeat the same samples in every batch. Batch 0 still uses ``seed``.
                    seed=seed + batch_idx,
                )
                generated.extend(self.tokenizer.decode(outputs))
        return generated[:num_samples]

    def _create_dataloader(self, sequences: list[str], task_key: str):
        """Build an unshuffled hyformer dataloader over ``sequences`` for the single task ``task_key``.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        if len(sequences) == 0:
            raise ValueError("sequences must contain at least one sequence.")
        return self._create_dataloader_fn(
            dataset=sequences,
            tasks={task_key: 1.0},
            tokenizer=self.tokenizer,
            batch_size=min(len(sequences), self.batch_size),
            shuffle=False,
        )

    def predict(self, sequences: list[str]) -> np.ndarray:
        """
        Compute predictions for a list of sequences.

        Each sequence is tokenized and passed through the model's prediction head.

        Args:
            sequences: List of input sequences (non-empty).

        Returns:
            A NumPy array of shape (n_sequences, num_prediction_tasks) containing the predictions,
            in input order.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        dataloader = self._create_dataloader(sequences, self._predictive_task_key)

        predictions = []
        with torch.inference_mode():
            for batch in tqdm(dataloader, disable=not self.verbose):
                batch = batch.to_device(self.device)
                output = self.model.predict(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
                # Legacy (<2.0) models return a keyed container; 2.x returns the predictions directly.
                if self._is_legacy:
                    output = output[self._logits_prediction_key]
                predictions.append(output.cpu().numpy())
        return np.concatenate(predictions, axis=0)

    def embed(self, sequences: list[str]) -> np.ndarray:
        """
        Compute embeddings for a list of sequences.

        Each sequence is tokenized and passed through the model; the hidden state at the [CLS]
        position (index 0) is used as the sequence embedding.

        Args:
            sequences: List of input sequences (molecules or peptides, matching the checkpoint).

        Returns:
            A NumPy array of shape (n_sequences, embedding_dim) containing the embeddings, in input
            order.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        cls_token_idx = 0
        dataloader = self._create_dataloader(sequences, self._predictive_task_key)

        embeddings = []
        with torch.inference_mode():
            for batch in tqdm(dataloader, disable=not self.verbose):
                batch = batch.to_device(self.device)
                output = self.model(**batch, return_loss=False)
                embeddings.append(output["embeddings"][:, cls_token_idx].cpu().numpy())
        return np.concatenate(embeddings, axis=0)

    def compute_perplexity(self, sequences: list[str]) -> np.ndarray:
        """
        Compute perplexity for a list of sequences.

        Sequences are scored under the generative task using the model's next-token logits and the
        dataloader's ``input_labels``.

        Args:
            sequences: List of sequences (non-empty).

        Returns:
            np.ndarray: Perplexity scores, in the same order as the input sequences.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        dataloader = self._create_dataloader(sequences, self._generative_task_key)

        perplexities = []
        with torch.inference_mode():
            for batch in tqdm(dataloader, disable=not self.verbose):
                batch = batch.to_device(self.device)
                output = self.model(**batch, return_loss=False)
                batch_perplexities = _compute_perplexity(
                    logits=output[self._logits_generation_key],
                    labels=batch["input_labels"],
                )
                perplexities.append(batch_perplexities.cpu().numpy())
        return np.concatenate(perplexities)
def _compute_perplexity(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """
    Compute per-sequence perplexity from next-token logits.

    The logits at position ``t`` are scored against the label at position ``t + 1`` (causal shift).
    Target positions equal to ``ignore_index`` are excluded from both the loss and the token count.

    Args:
        logits: Tensor of shape (batch, seq_len, vocab_size).
        labels: Integer tensor of shape (batch, seq_len).
        ignore_index: Label value marking positions to skip (e.g. padding).

    Returns:
        Tensor of shape (batch,) holding ``exp(mean token NLL)`` per sequence. A row with no
        non-ignored targets yields NaN (0 / 0).
    """
    # Shift so that the logits at position t predict the token at position t + 1.
    shifted_logits = logits[:, :-1]
    targets = labels[:, 1:]
    mask = targets != ignore_index
    # cross_entropy expects classes on dim 1: (batch, vocab, seq_len - 1). ignore_index is forwarded
    # so that ignored targets holding a non-default value (e.g. -1) are skipped rather than being
    # looked up as class indices.
    token_nll = torch.nn.functional.cross_entropy(
        shifted_logits.transpose(1, 2), targets, ignore_index=ignore_index, reduction="none"
    )
    nll = (token_nll * mask).sum(dim=1) / mask.sum(dim=1)
    return nll.exp()