Source code for seqme.metrics.fold

from typing import Literal

import numpy as np

from seqme.core.base import Metric, MetricResult


class Fold(Metric):
    """
    A wrapper for any metric, which splits the sequences into non-overlapping subsets,
    computes the metric on each split and aggregates the results.

    Fold splits the data into k-folds or fixed-size splits, with optional shuffling,
    and then aggregates the per-fold values into a mean and a deviation.
    """

    def __init__(
        self,
        metric: Metric,
        *,
        deviation: Literal["std", "se", "var"] = "std",
        estimate: Literal["biased", "unbiased"] = "unbiased",
        n_splits: int | None = None,
        split_size: int | None = None,
        drop_last: bool = False,
        strict: bool = True,
        shuffle: bool = False,
        seed: int = 0,
    ):
        """
        Initialize the Fold wrapper.

        Args:
            metric: The underlying metric to evaluate per fold.
            deviation: Type of deviation to compute:

                - ``'std'``: Standard deviation
                - ``'se'``: Standard error
                - ``'var'``: Variance

            estimate: How to estimate the deviation: ``'biased'`` uses ``ddof=0`` and
                ``'unbiased'`` uses ``ddof=1``.
            n_splits: Number of folds to create (exclusive with ``split_size``).
            split_size: Fixed size for each fold (exclusive with ``n_splits``).
            drop_last: Drop final fold if smaller than ``split_size``. Only consulted when
                ``split_size`` is used.
            strict: Error on any non-null fold deviation.
            shuffle: Shuffle data before splitting.
            seed: Seed for deterministic shuffling of sequences when creating folds.

        Raises:
            ValueError: If both or neither of ``n_splits`` and ``split_size`` are given, if
                ``n_splits < 2``, if ``split_size <= 0``, or if ``deviation`` / ``estimate``
                is not one of the supported values.
        """
        self.metric = metric
        self.deviation = deviation
        self.estimate = estimate
        self.n_splits = n_splits
        self.split_size = split_size
        self.drop_last = drop_last
        self.strict = strict
        self.shuffle = shuffle
        self.seed = seed

        if (self.n_splits is not None) and (self.split_size is not None):
            raise ValueError("Only one of n_splits or split_size may be specified.")
        if (self.n_splits is None) and (self.split_size is None):
            raise ValueError("One of n_splits or split_size must be specified.")
        if (self.n_splits is not None) and (self.n_splits < 2):
            raise ValueError("Expected n_splits >= 2.")
        if (self.split_size is not None) and (self.split_size <= 0):
            raise ValueError("Expected split_size > 0.")

        # Validate eagerly: at call time these are only inspected when more than one fold
        # is produced, so a typo would otherwise go unnoticed for single-fold inputs.
        if self.estimate not in ("biased", "unbiased"):
            raise ValueError(f"Invalid estimate: {self.estimate}")
        if self.deviation not in ("std", "se", "var"):
            raise ValueError(f"Invalid deviation: {self.deviation}")

    def __call__(self, sequences: list[str]) -> MetricResult:
        """
        Call the wrapped metric on each fold of ``sequences`` and aggregate the results.

        Args:
            sequences: Sequences to split into folds.

        Returns:
            MetricResult: Mean of the per-fold values, together with the standard deviation,
            standard error or variance across folds. If only a single fold is produced, the
            deviation reported by the wrapped metric for that fold is returned instead.

        Raises:
            ValueError: If no fold can be formed from ``sequences`` (too few or no sequences),
                or if ``strict`` is set and a fold result carries a non-null deviation.
        """
        n = len(sequences)
        indices = np.arange(n)

        if self.shuffle:
            # A fresh generator per call keeps the fold assignment reproducible.
            rng = np.random.default_rng(self.seed)
            rng.shuffle(indices)

        # Determine folds
        if self.n_splits is not None:
            if self.n_splits > n:
                raise ValueError(f"Cannot split into {self.n_splits} folds with only {n} sequences.")
            raw_folds = np.array_split(indices, self.n_splits)
        else:
            raw_folds = [indices[i : i + self.split_size] for i in range(0, n, self.split_size)]
            if self.drop_last:
                if raw_folds and len(raw_folds[-1]) < self.split_size:
                    raw_folds = raw_folds[:-1]
                if len(raw_folds) == 0:
                    raise ValueError(
                        f"With drop_last=True, cannot form any fold of size {self.split_size} from {n} sequences."
                    )

        # Explicit check instead of an assert: an empty input would otherwise surface as an
        # AssertionError (or an IndexError when asserts are stripped with -O).
        if len(raw_folds) == 0:
            raise ValueError("Cannot create folds from an empty list of sequences.")

        results = []
        for fold_idx in raw_folds:
            result = self.metric([sequences[i] for i in fold_idx.tolist()])
            if self.strict and (result.deviation is not None):
                raise ValueError("Fold result has non-null deviation in strict mode.")
            results.append(result)

        values = np.array([result.value for result in results], float)

        if len(results) > 1:
            # Re-checked here because the attributes are public and may have been reassigned.
            if self.estimate == "biased":
                ddof = 0
            elif self.estimate == "unbiased":
                ddof = 1
            else:
                raise ValueError(f"Invalid estimate: {self.estimate}")

            if self.deviation == "std":
                deviation = float(values.std(ddof=ddof))
            elif self.deviation == "var":
                deviation = float(values.var(ddof=ddof))
            elif self.deviation == "se":
                deviation = float(values.std(ddof=ddof)) / (len(values) ** 0.5)
            else:
                raise ValueError(f"Invalid deviation: {self.deviation}")
        else:
            # A single fold has no spread across folds; pass through the wrapped metric's own deviation.
            deviation = results[0].deviation

        return MetricResult(value=values.mean().item(), deviation=deviation)

    @property
    def name(self) -> str:
        """Name of the wrapped metric."""
        return self.metric.name

    @property
    def objective(self) -> Literal["minimize", "maximize"]:
        """Objective of the wrapped metric."""
        return self.metric.objective