Benchmarking mRNA’s mRFP-expression

Contents

Benchmarking mRNA’s mRFP-expression

In this notebook, we apply seqme to mRNA sequences.

import subprocess
from collections.abc import Callable

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LinearRegression

import seqme as sm
# Preferred torch device: CUDA when a GPU is available, otherwise CPU.
# NOTE(review): not referenced by any of the cells shown here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Data

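Let’s download the mRFP expression dataset from the CodonBERT benchmark repository: mRNA sequences with measured expression values, already divided into train/val/test splits.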

DATA_PATH = "mRFP_Expression.csv"
DATA_URL = "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv"

# Download the CSV with the external `wget` binary; `-O` writes it to DATA_PATH.
wget_command = ["wget", "-O", DATA_PATH, DATA_URL]

try:
    # check=True turns a non-zero wget exit status into CalledProcessError.
    subprocess.run(wget_command, check=True)
except subprocess.CalledProcessError as e:
    print(f"Error occurred: {e}")
except FileNotFoundError as e:
    # Raised when the `wget` executable itself is not installed / not on PATH.
    print(f"Error occurred: wget is not available ({e})")
else:
    print(f"File downloaded successfully as {DATA_PATH}")

Hide code cell output

--2026-02-08 16:08:20--  https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8003::154, 2606:50c0:8002::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1048970 (1.0M) [text/plain]
Saving to: ‘mRFP_Expression.csv’

     0K .......... .......... .......... .......... ..........  4%  876K 1s
    50K .......... .......... .......... .......... ..........  9%  981K 1s
   100K .......... .......... .......... .......... .......... 14% 1.42M 1s
   150K .......... .......... .......... .......... .......... 19%  802K 1s
   200K .......... .......... .......... .......... .......... 24%  804K 1s
   250K .......... .......... .......... .......... .......... 29%  771K 1s
   300K .......... .......... .......... .......... .......... 34%  720K 1s
   350K .......... .......... .......... .......... .......... 39%  504K 1s
   400K .......... .......... .......... .......... .......... 43%  365K 1s
   450K .......... .......... .......... .......... .......... 48%  645K 1s
   500K .......... .......... .......... .......... .......... 53%  510K 1s
   550K .......... .......... .......... .......... .......... 58%  960K 1s
   600K .......... .......... .......... .......... .......... 63%  866K 1s
   650K .......... .......... .......... .......... .......... 68%  915K 0s
   700K .......... .......... .......... .......... .......... 73%  922K 0s
   750K .......... .......... .......... .......... .......... 78% 1.12M 0s
   800K .......... .......... .......... .......... .......... 82%  432K 0s
   850K .......... .......... .......... .......... .......... 87%  764K 0s
   900K .......... .......... .......... .......... .......... 92%  466K 0s
   950K .......... .......... .......... .......... ..
File downloaded successfully as mRFP_Expression.csv
........ 97%  273K 0s
  1000K .......... .......... ....                            100%  303K=1.6s

2026-02-08 16:08:22 (629 KB/s) - ‘mRFP_Expression.csv’ saved [1048970/1048970]
dataset = pd.read_csv(DATA_PATH)

# Partition the rows according to the value in the "Split" column.
split_column = dataset["Split"]
train_data = dataset[split_column.eq("train")]
eval_data = dataset[split_column.eq("val")]
test_data = dataset[split_column.eq("test")]

# Show the first rows (last expression of the cell is displayed).
dataset.head()
Sequence Value Dataset Split
0 AUGGCAUCAUCAGAAGACGUCAUAAAAGAAUUUAUGCGAUUCAAAG... 10.164760 mRFP Expression train
1 AUGGCGUCUUCAGAGGAUGUAAUCAAGGAAUUCAUGCGUUUUAAGG... 10.572869 mRFP Expression train
2 AUGGCAUCAUCGGAAGAUGUAAUAAAGGAAUUUAUGCGUUUCAAAG... 9.766912 mRFP Expression train
3 AUGGCGAGUAGUGAAGACGUUAUCAAAGAAUUUAUGCGUUUUAAGG... 9.926981 mRFP Expression train
4 AUGGCUUCUUCUGAGGACGUAAUAAAGGAGUUCAUGAGGUUCAAGG... 9.857074 mRFP Expression train

Models

Let’s define the models.

class LinearRegressor:
    """Regression head fitted on top of a sequence embedder.

    Sequences are first mapped to feature vectors by ``embedder``; a regressor
    (scikit-learn ``LinearRegression`` unless another one is supplied) is then
    fitted on, or applied to, those vectors.

    Args:
        embedder: Callable mapping a list of sequences to an embedding matrix.
        regressor: Optional object exposing ``fit(X, y)`` and ``predict(X)``.
            Defaults to a fresh ``LinearRegression()``.
    """

    def __init__(
        self,
        embedder: Callable[[list[str]], np.ndarray],
        regressor=None,
    ):
        self.embedder = embedder
        self.regressor = LinearRegression() if regressor is None else regressor

    def __call__(self, sequences: list[str]) -> np.ndarray:
        """Alias for ``predict`` so the instance can be used as a predictor callable."""
        return self.predict(sequences)

    def _embed(self, sequences) -> np.ndarray:
        # The embedder is typed to take list[str]; callers may pass a pandas
        # Series, tuple or generator, so normalise any iterable to a list.
        return self.embedder(list(sequences))

    def predict(self, sequences: list[str]) -> np.ndarray:
        """Return one predicted value per input sequence."""
        return self.regressor.predict(self._embed(sequences))

    def fit(self, sequences: list[str], labels: np.ndarray) -> "LinearRegressor":
        """Fit the regressor on the embeddings of ``sequences`` against ``labels``.

        Returns:
            ``self``, following the scikit-learn convention so calls can be chained.
        """
        self.regressor.fit(self._embed(sequences), labels)
        return self
# Optional path to cache contents previously saved with sm.to_pickle;
# set it (e.g. to "mrna_precomputed.pkl") to initialise the cache from that file.
CACHE_PATH = None

precomputed = None
if CACHE_PATH:
    precomputed = sm.read_pickle(CACHE_PATH)
cache = sm.Cache(init_cache=precomputed)

Let’s set up the embedding model.

# Register the RNA-FM embedding model in the cache under the key "RNA-fm".
rna_fm_model = sm.models.RNAFM(verbose=True)
cache.add("RNA-fm", rna_fm_model)

Let’s set up a model that predicts mRFP expression, and train it.

# Convert the training columns to plain containers, matching the
# list[str] / np.ndarray signature of LinearRegressor.fit.
xs = list(train_data["Sequence"])
labels = np.array(train_data["Value"])

# Linear regression on RNA-FM embeddings, trained on the "train" split.
regressor = LinearRegressor(embedder=cache.model("RNA-fm"))
regressor.fit(xs, labels)

# Register the fitted regressor so metrics can fetch it via cache.model("regressor").
cache.add("regressor", regressor)
100%|██████████| 4/4 [02:31<00:00, 37.83s/it]

Let’s look at a UMAP projection of the RNA-FM embeddings of the training mRNA sequences.

# Training sequences and their measured expression values.
seqs = train_data["Sequence"].tolist()
values = train_data["Value"].to_numpy(copy=True)

# Embed the sequences with RNA-FM, then project the embeddings with UMAP for plotting.
embed_rna_fm = cache.model("RNA-fm")
xs = embed_rna_fm(seqs)
xs_umap = sm.utils.umap(xs)
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
# Plot the UMAP projection with each sequence's expression value attached
# (presumably used to colour the points — confirm in sm.utils.plot_embeddings).
sm.utils.plot_embeddings(
    xs_umap,
    values=values,
    # Figure text.
    title="mRFP-expression",
    xlabel="UMAP1",
    ylabel="UMAP2",
    # Figure and marker appearance.
    figsize=(5, 4),
    point_size=4,
    alpha=0.8,
)
../_images/bc8a981bb1e506021bf0288d9b7cfd6331ad06c51758bd959e6108d58f9edcbc.png

Benchmark

Let’s run the benchmark on the test-set sequences, using a character-shuffled copy of them as a baseline.

# Test-split sequences plus a character-shuffled copy used as a baseline.
# NOTE(review): the shuffle is presumably random; the permuted-row results may
# vary between runs unless shuffle_characters can be seeded — check its API.
mRFP_sequences = test_data["Sequence"].tolist()
mRFP_shuffled = sm.utils.shuffle_characters(mRFP_sequences)

benchmark_sequences = {
    "mRFP": mRFP_sequences,
    "mRFP (permuted)": mRFP_shuffled,
}
benchmark_metrics = [
    # Predicted expression from the cached regressor; higher is better.
    sm.metrics.ID(
        predictor=cache.model("regressor"),
        name="mRFP-expression",
        objective="maximize",
    ),
    # Threshold on predicted expression at 10.5 (inclusivity of min_value not verified here).
    sm.metrics.Threshold(
        predictor=cache.model("regressor"),
        name="mRFP-expression (>10.5)",
        min_value=10.5,
    ),
    # FKEA score computed on RNA-FM embeddings with kernel bandwidth 2.0.
    sm.metrics.FKEA(embedder=cache.model("RNA-fm"), bandwidth=2.0),
]

df = sm.evaluate(sequences=benchmark_sequences, metrics=benchmark_metrics)
100%|██████████| 1/1 [00:31<00:00, 31.83s/it]FP, metric=mRFP-expression]
100%|██████████| 1/1 [00:32<00:00, 32.38s/it] data=mRFP (permuted), metric=mRFP-expression]
100%|██████████| 6/6 [01:11<00:00, 11.98s/it, data=mRFP (permuted), metric=FKEA]                   
# Display the benchmark results table produced by sm.evaluate.
sm.show(df)
  mRFP-expression↑ mRFP-expression (>10.5)↑ FKEA↑
mRFP 10.01±0.08 0.33 9.30
mRFP (permuted) -1.43±0.50 0.05 71.00

Let’s save the cache to a file.

# Persist the cache contents so a later run can reload them with sm.read_pickle.
CACHE_PATH = "mrna_precomputed.pkl"

cache_contents = cache.get()
sm.to_pickle(cache_contents, CACHE_PATH)