Source code for pyterrier_anserini._reranker

from typing import Dict, Optional, Union

import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta

from pyterrier_anserini import J
from pyterrier_anserini._index import AnseriniIndex
from pyterrier_anserini._similarity import AnseriniSimilarity


@pt.java.required
class AnseriniReRanker(pt.Transformer):
    """A transformer that scores (i.e., re-ranks) the provided documents from an Anserini index."""

    def __init__(
        self,
        index: Union[AnseriniIndex, str],
        similarity: Union[str, AnseriniSimilarity],
        similarity_args: Optional[Dict] = None,
        *,
        verbose: bool = False,
    ):
        """Initializes the scorer.

        Args:
            index: The index to score from. If a string, an AnseriniIndex object is created for the path.
            similarity: The similarity function to use for scoring.
            similarity_args: A dictionary of arguments to use for the similarity function.
            verbose: Whether to display a progress bar when scoring.
        """
        self.index = index if isinstance(index, AnseriniIndex) else AnseriniIndex(index)
        self.similarity = AnseriniSimilarity(similarity)
        self.similarity_args = similarity_args
        self.verbose = verbose

    __repr__ = pta.transformer_repr

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Scores (i.e., re-ranks) documents from the index for each query in `inp`.

        Args:
            inp: A DataFrame with a 'query' column containing queries and a 'docno' column
                containing document IDs.

        Returns:
            A DataFrame containing the scored documents, with any columns included in `inp`,
            plus the 'score' and 'rank' of the scored documents.

        Raises:
            NotImplementedError: If input validation selects the 'query_lucene' or
                'query_toks' mode; only textual queries are scored at present.
        """
        # The accepted input shapes are declared in this order; the mode that is
        # ultimately selected is reported by the validator as `v.mode`.
        with pta.validate.any(inp) as v:
            v.result_frame(['query_lucene'], mode='query_lucene')
            v.result_frame(['query_toks'], mode='query_toks')
            v.result_frame(['query'], mode='query_text')

        # Reject unsupported modes before building the similarity or opening the
        # index reader, so that unusable input fails without doing that work.
        if v.mode == 'query_lucene':
            raise NotImplementedError('query_lucene not yet supported for AnseriniReRanker')
        if v.mode == 'query_toks':
            raise NotImplementedError('query_toks not yet supported for AnseriniReRanker')

        sim = AnseriniSimilarity(self.similarity).to_lucene_sim(self.similarity_args)
        index_reader = self.index._searcher().object.reader

        # One (docno, query) pair per input row, scored independently.
        pairs = zip(inp['docno'], inp['query'])
        if self.verbose:
            pairs = pt.tqdm(pairs, unit='d', total=len(inp), desc='AnseriniScorer')
        scores = [
            J.IndexReaderUtils.computeQueryDocumentScoreWithSimilarity(index_reader, docno, query, sim)
            for docno, query in pairs
        ]

        res = inp.assign(score=scores)
        return pt.model.add_ranks(res)