Source code for pyterrier_rag.prompt._context_aggregation

from typing import List, Optional, Any, Callable

import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta

from pyterrier_rag._util import concat


def score_sort(inp: List[dict]):
    if "score" in inp[0]:
        return sorted(inp, key=lambda x: x["score"], reverse=True)
    else:
        return inp


[docs] class Concatenator(pt.Transformer): """ Transformer that concatenates specified fields from document records into a context string. At query time, orders, loads text (if needed), and aggregates records into a single context. Parameters: in_fields (List[str]): Fields to extract from each record. Defaults to ["text"]. out_field (str): Name of the output context field. Defaults to "qcontext". text_loader (Callable, optional): Function to load document text by doc ID. intermediate_format (Callable, optional): Formatter for individual records. tokenizer (Any, optional): Tokenizer used for length-based truncation. max_length (int): Max total token length of the context. Defaults to -1 (no limit). max_elements (int): Max number of records to include. Defaults to -1 (no limit). max_per_context (int): Max tokens per record. truncation_rate (int): Token drop rate during truncation. Defaults to 50. aggregate_func (Callable, optional): Custom aggregation function. ordering_func (Callable): Record ordering function before aggregation. Defaults to score_sort, which sorts by "score" descending. Raises: ValueError: If 'text' is in in_fields but no text_loader is set. """ def __init__( self, in_fields: Optional[List[str]] = ["text"], out_field: Optional[str] = "qcontext", text_loader: Optional[Callable] = None, intermediate_format: Optional[Callable] = None, tokenizer: Optional[Any] = None, max_length: Optional[int] = -1, max_elements: Optional[int] = -1, max_per_context: Optional[int] = -1, truncation_rate: Optional[int] = 50, aggregate_func: Optional[Callable] = None, ordering_func: Optional[Callable] = score_sort, ): super().__init__() self.in_fields = in_fields self.out_field = out_field self.aggregate_func = aggregate_func self.text_loader = text_loader self.intermediate_format = intermediate_format self.tokenizer = tokenizer self.max_length = max_length self.max_elements = max_elements self.max_per_context = max_per_context self.truncation_rate = truncation_rate self.ordering_func = ordering_func def transform(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.columns(inp, includes=["qid"] + self.in_fields) output_frame = pta.DataFrameBuilder([self.out_field, "qid", "query"]) if inp is None or inp.empty: return output_frame.to_df() for qid, group in inp.groupby("qid"): inp = group.to_dict(orient="records") qid = inp[0].get("qid", None) query = inp[0].get("query", None) if self.ordering_func is not None: inp = self.ordering_func(inp) relevant = [{k: v for k, v in i.items() if k in self.in_fields} for i in inp] if "text" in self.in_fields and "text" not in inp[0].keys(): if self.text_loader is None: raise ValueError("Cannot retrieve text without a text loader") else: for d, t in zip(relevant, inp): d["text"] = self.text_loader(t["docno"]) if self.aggregate_func is not None: context = self.aggregate_func(relevant) else: context = concat( relevant, intermediate_format=self.intermediate_format, tokenizer=self.tokenizer, max_length=self.max_length, max_elements=self.max_elements, max_per_context=self.max_per_context, truncation_rate=self.truncation_rate, ) output_frame.extend({self.out_field: context, "qid": qid, "query": query}) return output_frame.to_df()
__all__ = ["Concatenator"]