Source code for suiteeval.suite.base

from __future__ import annotations

from abc import ABCMeta, ABC
import builtins
from collections.abc import Sequence as runtime_Sequence, Iterator
import inspect
from functools import cache
from typing import Callable, Generator, Optional, Any, Tuple, Union, Sequence
from logging import getLogger

import numpy as np
import ir_datasets as irds
from ir_measures import nDCG, Measure, parse_measure, parse_trec_measure
import pandas as pd
import pyterrier as pt
from pyterrier import Transformer

from suiteeval.context import DatasetContext
from suiteeval.utility import geometric_mean

logging = getLogger(__name__)


class SuiteMeta(ABCMeta):
    """
    Metaclass for :class:`Suite`.

    Responsibilities:
    - Maintain a registry of suite classes by name.
    - Enforce a singleton instance per suite class (i.e., one instance per subclass).
    - Provide a :meth:`register` helper to dynamically create and register suites.
    """

    _classes: dict[str, type] = {}
    _instances: dict[str, "Suite"] = {}

    def __call__(cls, *args, **kwargs):
        # singleton: only one instance per class
        if cls.__name__ not in SuiteMeta._instances:
            SuiteMeta._instances[cls.__name__] = super().__call__(*args, **kwargs)
        return SuiteMeta._instances[cls.__name__]

    @classmethod
    def register(
        mcs,
        suite_name: str,
        datasets: list[str],
        names: Optional[list[str]] = None,
        metadata: Optional[Union[list[dict[str, Any]], dict[str, Any]]] = None,
        query_field: Optional[str] = None,
    ) -> "Suite":
        """
        Create (or retrieve) a Suite singleton that wraps the given datasets.

        Args:
            suite_name: Name to assign to the dynamically created suite subclass.
            datasets: IRDS dataset identifiers (e.g., ``"msmarco-passage/trec-dl-2019"``).
            names: Optional display names corresponding one-to-one with ``datasets``.
                Defaults to ``datasets`` when omitted.
            metadata: Optional metadata. Accepted forms:
                * ``None`` → per-dataset empty dicts
                * ``list[dict]`` → each entry applies to the corresponding dataset in ``names``/``datasets``
                * ``dict[str, dict]`` → explicit mapping from dataset name/ID to metadata dict
                * ``dict[str, Any]`` where values are not dicts → treated as flat metadata applied to all
            query_field: Optional topic field name to use when fetching topics (e.g., ``"title"``).

        Returns:
            Suite: The singleton instance of the dynamically created suite class.

        Raises:
            ValueError: If ``metadata`` has an unsupported shape or length.
        """
        # if already registered, return existing instance
        if suite_name in mcs._classes:
            return mcs._classes[suite_name]()

        # build the dataset name → dataset_id mapping
        ds_names = names or datasets
        dataset_map = dict(zip(ds_names, datasets))

        # normalise metadata:
        #  • None            → empty per-dataset dicts
        #  • list[dict]      → metadata[i] applies to ds_names[i]
        #  • dict[str,dict]  → per-dataset mapping (keys are names or IDs)
        #  • dict[k,v] where v is NOT a dict → flat metadata for all
        if metadata is None:
            metadata_map = {name: {} for name in ds_names}
        elif isinstance(metadata, list):
            if len(metadata) != len(ds_names):
                raise ValueError("`metadata` list must match number of datasets")
            metadata_map = dict(zip(ds_names, metadata))
        elif isinstance(metadata, dict):
            if all(not isinstance(v, dict) for v in metadata.values()):
                metadata_map = {name: metadata for name in ds_names}
            else:
                metadata_map = metadata
        else:
            raise ValueError(f"Unsupported metadata type: {type(metadata)}")

        # dynamically create subclass with mappings
        attrs = {
            "_datasets": dataset_map,  # display-name -> dataset_id
            "_dataset_ids": dataset_map,  # alias used by other methods
            "_metadata": metadata_map,
            "_query_field": query_field,
        }
        new_cls = mcs(suite_name, (Suite,), attrs)

        # store class and return its singleton instance
        mcs._classes[suite_name] = new_cls
        return new_cls()


[docs] class Suite(ABC, metaclass=SuiteMeta): """ Abstract base class for a set of related evaluations across one or more datasets. Subclasses (or classes created via :meth:`SuiteMeta.register`) must populate: Attributes: _datasets: Either a ``dict[str, str]`` mapping display name → IRDS dataset ID, or a ``list[str]`` of IRDS dataset IDs. _dataset_ids: Normalized mapping of display name → IRDS dataset ID (filled in by registration helpers). _metadata: Optional per-dataset or global metadata. _measures: A list of :class:`ir_measures.Measure` or a mapping from dataset name to such a list. When not provided, defaults are derived from metadata or IRDS documentation; ultimately falling back to ``[nDCG@10]``. _query_field: Optional topic field name to use when fetching topics. Notes: Instances are singletons per subclass (enforced by :class:`SuiteMeta`). """ _datasets: Union[list[str], dict[str, str]] = {} _dataset_ids: dict[str, str] = {} _metadata: dict[str, Any] = {} _measures: Union[list[Measure], dict[str, list[Measure]]] = None __default_measures: list[Measure] = [nDCG @ 10] _query_field: Optional[str] = None # --------------------------- # Construction and validation # --------------------------- def __init__(self): self.coerce_measures(self._metadata) if "description" in self._metadata: self.__doc__ = self._metadata["description"] self.__post_init__() def __post_init__(self): assert self._datasets, ( "Suite must have at least one dataset defined in _datasets" ) if not isinstance(self._datasets, (dict, list)): raise AssertionError( "Suite _datasets must be a dict[name->id] or a list[dataset_id]" ) if isinstance(self._datasets, dict): if not all( isinstance(k, str) and isinstance(v, str) for k, v in self._datasets.items() ): raise AssertionError( "Suite _datasets must map string names to string dataset IDs" ) else: if not all(isinstance(ds, str) for ds in self._datasets): raise AssertionError( "Suite _datasets list must contain dataset IDs (str)" ) assert self._measures is not None, ( "Suite must have measures defined in _measures" ) # --------------------------- # Corpus grouping # --------------------------- def _iter_corpus_groups(self): """ Yield groups of datasets that share the same underlying corpus, determined by ir_datasets.docs_parent_id(dataset_id). Yields: (corpus_id: str, corpus_ds: pt.datasets.Dataset, members: list[tuple[str, str]]) # [(display_name, dataset_id), ...] """ # normalise to a list of (name, ds_id) if isinstance(self._datasets, dict): items = list(self._datasets.items()) else: items = [(ds_id, ds_id) for ds_id in self._datasets] # group by docs-parent (corpus) id groups: dict[str, dict] = {} for name, ds_id in items: try: corpus_id = irds.docs_parent_id(ds_id) or ds_id except Exception: corpus_id = ds_id if corpus_id not in groups: groups[corpus_id] = { "corpus_ds": pt.get_dataset(f"irds:{corpus_id}"), "members": [], } groups[corpus_id]["members"].append((name, ds_id)) # deterministic iteration order (insertion order is fine here) for corpus_id, g in groups.items(): yield corpus_id, g["corpus_ds"], g["members"] # --------------------------- # Measures # ---------------------------
[docs] @staticmethod def parse_measures(measures: list[Union[str, Measure]]) -> list[Measure]: """ Convert a list of measure strings or :class:`ir_measures.Measure` objects into a flat ``list[Measure]``. Args: measures: A sequence containing measure strings (e.g., ``"nDCG@10"``) and/or :class:`ir_measures.Measure` instances. Returns: list[Measure]: Parsed measure objects. Raises: ValueError: If a string entry cannot be parsed by either :func:`ir_measures.parse_measure` or :func:`ir_measures.parse_trec_measure`, or if an entry has an invalid type. """ out: list[Measure] = [] def _ensure_list(x: Union[Measure, Sequence[Measure]]) -> list[Measure]: if isinstance(x, Measure): return [x] return list(x) for m in measures: if isinstance(m, Measure): out.append(m) continue if isinstance(m, str): candidates: list[Measure] = [] for parser in (parse_measure, parse_trec_measure): try: parsed = parser(m) candidates.extend(_ensure_list(parsed)) except ValueError: continue if not candidates: raise ValueError(f"Unrecognised measure string: {m!r}") out.extend(candidates) continue raise ValueError(f"Invalid measure type: {type(m)}") return out
[docs] def coerce_measures(self, metadata: dict[str, Any]) -> None: """ Populate ``self._measures`` by aggregating available sources in priority order: 1. Global ``metadata['official_measures']`` if present. 2. Per-dataset ``metadata[name]['official_measures']`` if present. 3. IRDS documentation ``official_measures`` for each dataset (when available). If no measures are discovered, default to ``[nDCG@10]``. Args: metadata: The suite metadata dictionary as configured at construction time. Returns: None """ measures_accum: list[Measure] = [] seen: set[str] = set() def _add_many(items: Optional[list[Union[str, Measure]]]) -> None: if not items: return for m in self.parse_measures(items): sig = str(m) if sig not in seen: measures_accum.append(m) seen.add(sig) # (1) global metadata if isinstance(metadata, dict): _add_many(metadata.get("official_measures")) # (2) per-dataset metadata if isinstance(metadata, dict): # iterate over declared dataset names (works for dict; if list, keys are ids) names_iter = ( self._datasets if isinstance(self._datasets, dict) else self._datasets ) for name in names_iter: md = metadata.get(name, {}) if isinstance(md, dict): _add_many(md.get("official_measures")) # (3) ir_datasets documentation for name, ds_id in ( self._dataset_ids.items() if isinstance(self._dataset_ids, dict) else [] ): try: ds = irds.load(ds_id) docs = getattr(ds, "documentation", lambda: None)() if isinstance(docs, dict): _add_many(docs.get("official_measures")) except Exception as e: logging.warning( f"Failed to load measures from documentation for '{name}' ({ds_id}): {e}" ) if not measures_accum: logging.warning("No measures discovered; defaulting to [nDCG@10].") measures_accum = [nDCG @ 10] self._measures = measures_accum
@staticmethod def _normalize_generators( pipeline_generators: Union[Callable[[DatasetContext], Any], runtime_Sequence], what: str, ) -> list[Callable[[DatasetContext], Any]]: """ Normalize a callable or a sequence of callables to a list of callables. Args: pipeline_generators: Either a single callable taking ``DatasetContext`` and yielding pipelines, or a sequence of such callables. what: Human-readable label used in error messages. Returns: list[Callable[[DatasetContext], Any]]: The normalized list. Raises: TypeError: If the input is neither callable nor a sequence of callables. """ if not isinstance(pipeline_generators, runtime_Sequence) or isinstance( pipeline_generators, (str, bytes) ): if not builtins.callable(pipeline_generators): raise TypeError( f"{what} must be a callable or a sequence of callables." ) return [pipeline_generators] # type: ignore[list-item] if not all(builtins.callable(f) for f in pipeline_generators): # type: ignore[arg-type] raise TypeError(f"All elements of {what} must be callable.") return list(pipeline_generators) # type: ignore[return-value]
[docs] def coerce_pipelines_sequential( self, context: DatasetContext, pipeline_generators: Union[Callable[[DatasetContext], Any], runtime_Sequence], ): """ Yield pipelines lazily, one at a time, without materializing the full set. Use this when you want to minimize memory/VRAM footprint and you do not require joint analysis across all systems at once (e.g., significance testing). Args: context: The shared :class:`DatasetContext` for the current corpus group. pipeline_generators: Callable or sequence of callables that produce either: * a single :class:`pyterrier.Transformer`, * a sequence of transformers, * a tuple ``(pipelines, name_or_names)`` where names may be a single label applied to all pipelines or a sequence aligned with ``pipelines``. Yields: tuple[Transformer, Optional[str]]: The pipeline and an optional display name. Raises: ValueError: If a generator yields an invalid structure. """ gens = self._normalize_generators(pipeline_generators, "pipeline_generators") def _yield_item(item): if isinstance(item, tuple) and len(item) == 2: p, nm = item else: p, nm = item, None if isinstance(p, Transformer): yield p, (nm if isinstance(nm, str) else None) elif isinstance(p, runtime_Sequence) and all( isinstance(pi, Transformer) for pi in p ): if isinstance(nm, str): for pi in p: yield pi, nm elif isinstance(nm, runtime_Sequence): nm_list = list(nm) if len(nm_list) != len(p): raise ValueError( "Length of names does not match number of pipelines." ) for pi, nmi in zip(p, nm_list): yield pi, (nmi if isinstance(nmi, str) else None) else: for pi in p: yield pi, None else: raise ValueError(f"Generator yielded an invalid item: {type(p)}") for gen in gens: out = gen(context) if inspect.isgenerator(out) or isinstance(out, Iterator): for item in out: yield from _yield_item(item) else: if isinstance(out, tuple): _pipelines, *_rest = out _names = None if not _rest else _rest[0] yield from _yield_item((_pipelines, _names)) else: yield from _yield_item(out)
[docs] def coerce_pipelines_grouped( self, context: DatasetContext, pipeline_generators: Union[Callable[[DatasetContext], Any], runtime_Sequence], ) -> Tuple[list[Transformer], Optional[list[str]]]: """ Materialize all pipelines (and optional names) into lists. Use this when downstream evaluation requires access to the full set of systems simultaneously (e.g., significance tests). Args: context: The shared :class:`DatasetContext` for the current corpus group. pipeline_generators: Callable or sequence of callables following the same conventions as in :meth:`coerce_pipelines_sequential`. Returns: tuple[list[Transformer], Optional[list[str]]]: A list of pipelines and, if provided, a list of corresponding names. If no names were supplied, returns ``None`` for the second element. Raises: ValueError: If the generators produce no pipelines or an invalid structure. """ gens = self._normalize_generators(pipeline_generators, "pipeline_generators") pipelines: list[Transformer] = [] names: list[Optional[str]] = [] def _emit_item_to_lists(item): if isinstance(item, tuple) and len(item) == 2: p, nm = item else: p, nm = item, None if isinstance(p, Transformer): pipelines.append(p) names.append(nm if isinstance(nm, str) else None) elif isinstance(p, runtime_Sequence) and all( isinstance(pi, Transformer) for pi in p ): if isinstance(nm, str): pipelines.extend(p) names.extend([nm] * len(p)) elif isinstance(nm, runtime_Sequence): nm_list = list(nm) if len(nm_list) != len(p): raise ValueError( "Length of names does not match number of pipelines." ) pipelines.extend(p) names.extend([n if isinstance(n, str) else None for n in nm_list]) else: pipelines.extend(p) names.extend([None] * len(p)) else: raise ValueError(f"Generator yielded an invalid item: {type(p)}") for gen in gens: out = gen(context) if inspect.isgenerator(out) or isinstance(out, Iterator): for item in out: _emit_item_to_lists(item) else: if isinstance(out, tuple): _pipelines, *_rest = out _names = None if not _rest else _rest[0] _emit_item_to_lists((_pipelines, _names)) else: _emit_item_to_lists(out) if not pipelines: raise ValueError( "No pipelines generated. Ensure your generators produce valid Transformers." ) final_names = ( None if not any(names) else [ nm if nm is not None else f"pipeline_{i}" for i, nm in enumerate(names) ] ) return pipelines, final_names
[docs] def compute_overall_mean( self, results: pd.DataFrame, eval_metrics: Sequence[Any] = None, ) -> pd.DataFrame: """ Append overall (geometric mean) rows across datasets for each system name. This first aggregates per-dataset means over repeated runs, then computes the geometric mean across datasets for each metric and appends rows with ``dataset == "Overall"``. Args: results: DataFrame with at least ``["dataset", "name"]`` and metric columns. eval_metrics: Optional sequence of metrics to consider; defaults to ``self.__default_measures`` when not provided. Returns: pandas.DataFrame: The input results with additional ``Overall`` rows appended. """ measure_cols = [ str(m) for m in (eval_metrics or self.__default_measures) if str(m) in results.columns ] if measure_cols: per_ds = ( results.groupby(["dataset", "name"], dropna=False)[measure_cols] .mean() .reset_index() ) gmean_rows = [] for name, group in per_ds.groupby("name", dropna=False): row = {"dataset": "Overall", "name": name} for col in measure_cols: vals = pd.to_numeric(group[col], errors="coerce").dropna().values if np.any(vals <= 0): vals = vals + 1e-12 row[col] = geometric_mean(vals) gmean_rows.append(row) gmean_df = pd.DataFrame(gmean_rows) results = pd.concat([results, gmean_df], ignore_index=True) return results
[docs] @cache def get_measures(self, dataset: str) -> list[Measure]: """ Resolve the measures applicable to a given dataset name. Args: dataset: Dataset display name as used in this suite. Returns: list[Measure]: The list configured for this dataset (or the suite-wide list if a single list is maintained). Falls back to defaults when the dataset is unknown. """ if isinstance(self._measures, list): return self._measures if dataset not in self._measures: return self.__default_measures return self._measures[dataset]
@property def datasets(self) -> Generator[Tuple[str, pt.datasets.Dataset], None, None]: """ Iterate over declared datasets yielding display name and PyTerrier dataset. Yields: tuple[str, pyterrier.datasets.Dataset]: Pairs of (name, ``pt.get_dataset("irds:<id>")``). Raises: ValueError: If ``_datasets`` has an invalid type. """ if isinstance(self._datasets, list): for ds_id in self._datasets: yield ds_id, pt.get_dataset(f"irds:{ds_id}") elif isinstance(self._datasets, dict): for name, ds_id in self._datasets.items(): yield name, pt.get_dataset(f"irds:{ds_id}") else: raise ValueError( "Suite _datasets must be a list or dict mapping names to dataset IDs." ) @staticmethod def _topics_qrels(ds: pt.datasets.Dataset, query_field: Optional[str]): """ Fetch topics and qrels for a dataset. Args: ds: A :class:`pyterrier.datasets.Dataset` instance. query_field: Optional topic field name (e.g., ``"title"``). Returns: tuple[pandas.DataFrame, pandas.DataFrame]: ``(topics, qrels)``. """ topics = ds.get_topics(query_field, tokenise_query=False) qrels = ds.get_qrels() return topics, qrels @staticmethod def _free_cuda(): """ Best-effort memory cleanup helper. Calls ``gc.collect()`` and, if ``torch.cuda.is_available()``, empties the CUDA cache. Silently ignores any exceptions (CUDA and torch are optional). """ import gc gc.collect() try: import torch # noqa: WPS433 — optional dependency if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass
[docs] def __call__( self, ranking_generators: Union[ Callable[[DatasetContext], Any], Sequence[Callable[[DatasetContext], Any]] ], eval_metrics: Sequence[Any] = None, subset: Optional[str] = None, **experiment_kwargs: dict[str, Any], ) -> pd.DataFrame: """ Run the experiment(s) for each dataset in the suite and return a results table. If a ``baseline`` is provided in ``experiment_kwargs``, all pipelines are materialized together (grouped mode) to enable tests that require joint access (e.g., significance). Otherwise, pipelines are streamed one-by-one to reduce memory usage (sequential mode). Args: ranking_generators: Callable or sequence of callables producing pipelines per :class:`DatasetContext` (same conventions as in :meth:`coerce_pipelines_sequential`). eval_metrics: Optional explicit metrics to evaluate; defaults to the suite’s configuration for each dataset. subset: Optional dataset display name to restrict evaluation to a single member. **experiment_kwargs: Additional keyword arguments forwarded to :func:`pyterrier.Experiment`. If ``save_dir`` is provided, it is suffixed per dataset. Returns: pandas.DataFrame: The concatenated experiment results. When ``perquery`` is not set, an additional ``Overall`` row is appended per system with geometric-mean aggregation across datasets. Notes: This method reuses a single index per corpus group and cleans up GPU memory between pipeline evaluations. """ results: list[pd.DataFrame] = [] baseline = experiment_kwargs.get("baseline", None) coerce_grouped = baseline is not None if coerce_grouped: logging.warning( "Significance tests require pipelines to be grouped; this uses more memory." ) for corpus_id, corpus_ds, members in self._iter_corpus_groups(): # If a subset was requested, skip this corpus unless it contains the subset if subset and all(name != subset for name, _ in members): continue # Single shared context per corpus (indexing happens once here) context = DatasetContext(corpus_ds) if coerce_grouped: # Materialise all pipelines ONCE for the corpus pipelines, names = self.coerce_pipelines_grouped( context, ranking_generators ) # Evaluate the same systems across each dataset that shares this corpus for ds_name, ds_id in members: if subset and ds_name != subset: continue ds_member = pt.get_dataset(f"irds:{ds_id}") topics, qrels = self._topics_qrels(ds_member, self._query_field) save_dir = experiment_kwargs.pop("save_dir", None) if save_dir is not None: formatted_ds_name = ds_name.replace("/", "-").lower() ds_save_dir = f"{save_dir}/{formatted_ds_name}" experiment_kwargs["save_dir"] = ds_save_dir df = pt.Experiment( pipelines, eval_metrics=eval_metrics or self.get_measures(ds_name), topics=topics, qrels=qrels, names=names, **experiment_kwargs, ) df["dataset"] = ds_name results.append(df) # Release materialised pipelines after all member datasets are processed try: del pipelines, names finally: self._free_cuda() else: # Stream pipelines one at a time, but reuse each pipeline across ALL member datasets for pipeline, name in self.coerce_pipelines_sequential( context, ranking_generators ): for ds_name, ds_id in members: if subset and ds_name != subset: continue ds_member = pt.get_dataset(f"irds:{ds_id}") topics, qrels = self._topics_qrels(ds_member, self._query_field) df = pt.Experiment( [pipeline], eval_metrics=eval_metrics or self.get_measures(ds_name), topics=topics, qrels=qrels, names=None if name is None else [name], **experiment_kwargs, ) df["dataset"] = ds_name results.append(df) # Dispose of this pipeline (after all member datasets) try: del pipeline finally: self._free_cuda() # Release per-corpus context del context results_df = ( pd.concat(results, ignore_index=True) if results else pd.DataFrame() ) # Aggregate geometric mean only across actual Measure columns perquery = experiment_kwargs.get("perquery", False) if not perquery and not results_df.empty: results_df = self.compute_overall_mean(results_df) return results_df