Source code for seqme.metrics.authenticity

from collections.abc import Callable
from typing import Literal

import numpy as np
from sklearn.neighbors import NearestNeighbors

from seqme.core.base import Metric, MetricResult


class AuthPct(Metric):
    """
    Proportion of authentic generated samples.

    Authenticity is defined as the fraction of sequences whose nearest training neighbor is closer to some other
    training sample than to the sequence.

    References:
        [1] Alaa et al., "How Faithful is your Synthetic Data? Sample-level Metrics for Evaluating and Auditing
        Generative Models." (2022). (https://arxiv.org/abs/2102.08921)
    """

    def __init__(
        self,
        train_set: list[str],
        embedder: Callable[[list[str]], np.ndarray],
        *,
        name: str = "Authenticity",
    ):
        """Initialize the metric.

        Args:
            train_set: List of sequences used to train the generative model.
            embedder: A function mapping a list of sequences to a 2D NumPy array of embeddings.
            name: Metric name.

        Raises:
            ValueError: If the embedded training set contains fewer than two samples. Each training sample is
                compared with its nearest *other* training sample, so a single sample is not enough.
        """
        # Raw training sequences, kept for reference.
        self.train_set = train_set
        # Callable used to embed both the training set and evaluated sequences.
        self.embedder = embedder
        self._name = name

        # The training set is embedded once here and reused by every call.
        self.train_set_embeddings = self.embedder(self.train_set)
        # Fail early: with fewer than two training samples there is no "nearest other training sample",
        # so every subsequent call would fail.
        if self.train_set_embeddings.shape[0] < 2:
            raise ValueError("Reference embeddings must contain at least two samples.")

    def __call__(self, sequences: list[str]) -> MetricResult:
        """
        Compute the authenticity score of the input sequences relative to the training set.

        Args:
            sequences: Sequences to evaluate.

        Returns:
            MetricResult: Authenticity score, i.e. the fraction of sequences deemed authentic.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        if len(sequences) == 0:
            raise ValueError("Sequences must contain at least one sample.")

        embeddings = self.embedder(sequences)
        auth_score = compute_authenticity(
            real_data=self.train_set_embeddings,
            synthetic_data=embeddings,
        )
        return MetricResult(value=auth_score)

    @property
    def name(self) -> str:
        """The metric name given at construction (``"Authenticity"`` by default)."""
        return self._name

    @property
    def objective(self) -> Literal["minimize", "maximize"]:
        """Optimization direction: a higher authenticity score is better."""
        return "maximize"
def compute_authenticity(real_data: np.ndarray, synthetic_data: np.ndarray) -> float:
    """
    Compute the fraction of synthetic samples that are authentic with respect to the real data.

    Authenticity is defined as the fraction of sequences whose nearest training neighbor is closer to some other
    training sample than to the sequence. Concretely, a synthetic sample is authentic when its distance to its
    nearest real sample is strictly greater than that real sample's distance to its own nearest *other* real
    sample. Ties therefore count as not authentic.

    Args:
        real_data: Embeddings of the real data, a 2D array with at least two rows.
        synthetic_data: Embeddings of the synthetic data, a 2D array with at least one row.

    Returns:
        Authenticity score in [0, 1].

    Raises:
        ValueError: If ``real_data`` has fewer than two samples or ``synthetic_data`` is empty.
    """
    real_data = np.asarray(real_data)
    synthetic_data = np.asarray(synthetic_data)

    # Every real sample is compared with its nearest *other* real sample, which requires at least two of them.
    # Without this check, sklearn fails later inside kneighbors() with a far less descriptive message.
    if real_data.shape[0] < 2:
        raise ValueError("Real data must contain at least two samples.")
    if synthetic_data.shape[0] == 0:
        raise ValueError("Synthetic data must contain at least one sample.")

    knn_real = NearestNeighbors(n_neighbors=1, n_jobs=-1).fit(real_data)

    # Distance from each synthetic sample to its nearest real sample, plus that real sample's index.
    dist_synth_to_real, closest_real_per_synth_idx = knn_real.kneighbors(synthetic_data)
    # With no query argument, kneighbors() excludes each fitted point from its own neighbor search, so this is
    # the distance from every real sample to its nearest *other* real sample.
    dist_real_to_real, _ = knn_real.kneighbors()

    # Both operands have shape (n_synthetic, 1) because n_neighbors=1.
    auth_mask = dist_synth_to_real > dist_real_to_real[closest_real_per_synth_idx.squeeze(axis=-1)]
    return float(np.mean(auth_mask))