# Source code for seqme.metrics.kl_divergence
from collections.abc import Callable
from typing import Literal
import numpy as np
from sklearn.neighbors import KernelDensity
from seqme.core.base import Metric, MetricResult
[docs]
class KLDivergence(Metric):
    r"""
    KL-divergence between samples and reference for a single property.

    This metric measures how much the empirical distribution of a property
    :math:`f(x)` in the generated samples deviates from the corresponding
    reference distribution.

    The KL-divergence is defined as:

    .. math::

        \mathrm{KL}\big(p_{f(\mathrm{ref})} \,\|\, p_{f(\mathrm{gen})}\big)
        = \int p_{f(\mathrm{ref})}(y)
        \log \frac{p_{f(\mathrm{ref})}(y)}{p_{f(\mathrm{gen})}(y)} \, dy,

    where :math:`p_{f(\mathrm{ref})}` denotes the reference distribution and
    :math:`p_{f(\mathrm{gen})}` denotes the generated distribution.

    The KL-divergence is approximated using Monte-Carlo sampling.
    """

    def __init__(
        self,
        reference: list[str],
        predictor: Callable[[list[str]], np.ndarray],
        *,
        n_draws: int = 10_000,
        kde_bandwidth: float | Literal["scott", "silverman"] = "silverman",
        seed: int = 0,
        name: str = "KL-divergence",
    ):
        """
        Initialize the metric.

        The predictor is evaluated on ``reference`` once, here, and the result
        is cached in ``reference_predictor`` for reuse on every call.

        Args:
            reference: Reference sequences assumed to represent the target distribution.
            predictor: Predictor function which returns a 1D NumPy array. One value per sequence.
            n_draws: Number of Monte Carlo samples to draw from reference distribution.
            kde_bandwidth: Bandwidth parameter for the Gaussian KDE.
            seed: Seed for KL-divergence Monte-Carlo sampling.
            name: Metric name.

        Raises:
            ValueError: If ``n_draws`` is not strictly positive.
        """
        # Validate the configuration first so an invalid n_draws fails fast,
        # before the (potentially expensive) predictor runs on the reference.
        if n_draws <= 0:
            raise ValueError("Expected n_draws > 0.")

        self.reference = reference
        self.predictor = predictor
        self.n_draws = n_draws
        self.kde_bandwidth = kde_bandwidth
        self.seed = seed
        self._name = name

        # Property values of the reference set, computed once and reused.
        self.reference_predictor = self.predictor(self.reference)

    def __call__(self, sequences: list[str]) -> MetricResult:
        """
        Compute the KL-divergence between reference and sequence predictor.

        Args:
            sequences: Sequences to evaluate.

        Returns:
            MetricResult: KL-divergence and standard error.
        """
        seqs_predictor = self.predictor(sequences)
        kl_div, standard_error = continuous_kl_mc(
            self.reference_predictor,
            seqs_predictor,
            kde_bandwidth=self.kde_bandwidth,
            n_draws=self.n_draws,
            seed=self.seed,
        )
        return MetricResult(value=kl_div, deviation=standard_error)

    @property
    def name(self) -> str:
        """Metric name as passed to the constructor."""
        return self._name

    @property
    def objective(self) -> Literal["minimize", "maximize"]:
        """Optimization direction; this metric always reports ``"minimize"``."""
        return "minimize"
def continuous_kl_mc(
x_reference: np.ndarray,
x_samples: np.ndarray,
kde_bandwidth: float | Literal["scott", "silverman"] = "silverman",
n_draws: int = 10_000,
seed: int = 0,
) -> tuple[float, float]:
x_reference = x_reference.reshape(-1, 1)
x_samples = x_samples.reshape(-1, 1)
kde_p = KernelDensity(kernel="gaussian", bandwidth=kde_bandwidth).fit(x_reference)
kde_q = KernelDensity(kernel="gaussian", bandwidth=kde_p.bandwidth_).fit(x_samples)
x_p = kde_p.sample(n_draws, random_state=seed)
log_p = kde_p.score_samples(x_p)
log_q = kde_q.score_samples(x_p)
log_diff = log_p - log_q
kl_estimate = float(log_diff.mean())
se = float(log_diff.std(ddof=1) / np.sqrt(n_draws))
return kl_estimate, se