Source code for pyterrier_anserini._retriever
from typing import Any, Dict, List, Optional, Union
import numpy as np
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
from pyterrier_anserini import J
from pyterrier_anserini._index import AnseriniIndex
from pyterrier_anserini._similarity import AnseriniSimilarity
def _noop_query_parser(query: str) -> str:
return query
def _toks_query_parser_factory(parser): # noqa: ANN001
def wrapped(toks: Dict[str, float]) -> Any:
res = []
for tok, weight in toks.items():
res.append(f'{parser.escape(tok)}^{weight:f}')
query = ' '.join(res)
return parser.parse(query)
return wrapped
[docs]
@pt.java.required
class AnseriniRetriever(pt.Transformer):
    """Retrieves from an Anserini index."""

    def __init__(self,
        index: Union[AnseriniIndex, str],
        similarity: Union[AnseriniSimilarity, str] = "BM25",
        similarity_args: Optional[Dict[str, Any]] = None,
        *,
        num_results: int = 1000,
        include_fields: Optional[List[str]] = None,
        verbose: bool = False,
    ):
        """Construct a retriever over an Anserini index.

        Args:
            index: The Anserini index. Anything that is not already an ``AnseriniIndex`` is passed to
                the ``AnseriniIndex`` constructor.
            similarity: The similarity function to use.
            similarity_args: model-specific arguments, like bm25.k1. `None` passes no arguments.
            num_results: number of results to return per query. Default is 1000.
            include_fields: a list of extra stored fields to include for each result. `None` indicates no extra fields.
            verbose: show a progress bar during retrieval?
        """
        if not isinstance(index, AnseriniIndex):
            index = AnseriniIndex(index)
        self.index = index
        self.similarity = similarity
        self.similarity_args = similarity_args
        self.num_results = num_results
        self.include_fields = include_fields
        self.verbose = verbose

    __repr__ = pta.transformer_repr

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Performs retrieval.

        The query representation is chosen from the columns present in ``inp``; the validation below
        registers them in this order: ``query_lucene`` (Lucene query syntax, parsed against the
        ``contents`` field), ``query_toks`` (a ``{token: weight}`` mapping turned into a boosted
        Lucene query), and ``query`` (text handed to the searcher unchanged).

        Args:
            inp: A pandas.DataFrame of queries.

        Returns:
            pandas.DataFrame built from the result columns ``docno``, ``score`` and ``rank`` (0-based, in
            the order returned by the searcher), plus one column per entry of ``include_fields``,
            merged with ``inp``.

        Raises:
            ValueError: if the validated input mode is not one of the supported query modes.
        """
        with pta.validate.any(inp) as v:
            v.query_frame(extra_columns=['query_lucene'], mode='query_lucene')
            v.query_frame(extra_columns=['query_toks'], mode='query_toks')
            v.query_frame(extra_columns=['query'], mode='query_text')

        sim = AnseriniSimilarity(self.similarity).to_lucene_sim(self.similarity_args)
        searcher = self.index._searcher()
        searcher.object.searcher.setSimilarity(sim)

        if v.mode in ('query_lucene', 'query_toks'):
            # Both structured modes parse against the "contents" field with the searcher's analyzer.
            parser = J.QueryParser("contents", searcher.object.analyzer)
            if v.mode == 'query_lucene':
                q_transform = parser.parse
                query_column = 'query_lucene'
            else:
                q_transform = _toks_query_parser_factory(parser)
                query_column = 'query_toks'
        elif v.mode == 'query_text':
            q_transform = _noop_query_parser
            query_column = 'query'
        else:
            # Defensive: the validation above should already have rejected any other input; fail
            # with a clear message rather than an unbound-variable error further down.
            raise ValueError(f'unsupported query mode: {v.mode!r}')

        it = enumerate(inp[query_column])
        if self.verbose:
            it = pt.tqdm(it, desc=str(self), total=len(inp), unit='d')

        extra_fields = self.include_fields or []
        # '_index' carries the position of the query within ``inp`` and is handed to the builder
        # alongside ``merge_on_index=inp`` when the final frame is assembled.
        result = pta.DataFrameBuilder(['_index', 'docno', 'score', 'rank', *extra_fields])
        for i, query in it:
            hits = searcher.search(q_transform(query), k=self.num_results)
            records = {
                '_index': i,
                'docno': [h.docid for h in hits],
                'score': [h.score for h in hits],
                'rank': np.arange(len(hits)),
            }
            records.update({
                f: [h.lucene_document.get(f) for h in hits]
                for f in extra_fields
            })
            result.extend(records)
        return result.to_df(merge_on_index=inp)