Source code for pyterrier_doc2query.filtering
import numpy as np
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
[docs]
class QueryScorer(pt.Transformer):
    """A :class:`~pyterrier.Transformer` that scores queries generated by :class:`~pyterrier_doc2query.Doc2Query` with the provided ``scorer`` transformer."""

    def __init__(self, scorer: pt.Transformer):
        """
        Args:
            scorer: A pyterrier Transformer that takes a DataFrame with columns 'query', 'text', 'docno', 'qid'
                and returns a DataFrame with columns 'qid', 'score' (one row per input row).
        """
        self.scorer = scorer

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Applies the scoring transformation.

        Each row's ``querygen`` value is split on newlines into individual queries. Every
        (query, text) pair across all rows is scored by ``self.scorer`` in a single call.
        The scores are attached to ``inp`` as a new ``querygen_score`` column that holds,
        for each row, a numpy array of scores in the same order as the queries in ``querygen``.

        Args:
            inp: A document frame with ``docno``, ``text`` and ``querygen`` columns.

        Returns:
            ``inp`` with an added ``querygen_score`` column.

        Raises:
            ValueError: If the scorer does not return exactly one row per generated query.
        """
        pta.validate.document_frame(inp, extra_columns=['text', 'querygen'])
        slices = []
        scorer_inp = {
            'query': [],
            'text': [],
            'docno': [],  # must exist up front: the loop below extends it
        }
        for text, querygen, docno in zip(inp['text'], inp['querygen'], inp['docno']):
            queries = querygen.split('\n')
            start_idx = len(scorer_inp['query'])
            # Remember which contiguous span of the flattened batch belongs to this row.
            slices.append(slice(start_idx, start_idx + len(queries)))
            scorer_inp['query'].extend(queries)
            scorer_inp['text'].extend([text] * len(queries))
            scorer_inp['docno'].extend([docno] * len(queries))

        num_queries = len(scorer_inp['query'])
        if num_queries == 0:
            # Nothing to score (empty input): skip the scorer instead of reading an unset result.
            return inp.assign(querygen_score=[])

        # One unique qid per (query, text) pair, used below to map scores back to positions.
        scorer_inp['qid'] = list(range(num_queries))
        dout = self.scorer(pd.DataFrame(scorer_inp))
        if len(dout) != num_queries:
            raise ValueError(
                f'scorer returned {len(dout)} rows, expected {num_queries} (one per generated query)'
            )

        # Re-order the scores by the qids assigned above rather than trusting the output row order;
        # astype(int) also accepts scorers that cast qid to str.
        order = np.argsort(dout['qid'].astype(int).to_numpy(), kind='stable')
        scores = dout['score'].to_numpy()[order]
        return inp.assign(querygen_score=[scores[s] for s in slices])
[docs]
class QueryFilter(pt.Transformer):
    """A :class:`~pyterrier.Transformer` that filters out queries based on their scores (from :class:`~pyterrier_doc2query.QueryScorer`) and the threshold ``t``."""

    def __init__(self, t: float, append: bool = True):
        """
        Args:
            t: The threshold to filter queries by. A query passes the filter when its score is
                greater than or equal to this value.
            append: If True, the queries that pass the filter are appended to ``text`` (after a
                newline) and the ``querygen``/``querygen_score`` columns are dropped. Otherwise,
                ``text`` is left unchanged and ``querygen``/``querygen_score`` are replaced by only
                the queries (and their scores) that pass the filter.
        """
        self.t = t
        self.append = append

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Applies the filtering transformation.

        Args:
            inp: A document frame with ``text``, ``querygen`` (newline-separated queries) and
                ``querygen_score`` (one score per query, in the same order) columns.

        Returns:
            The filtered frame, shaped according to ``self.append`` (see ``__init__``).
        """
        pta.validate.document_frame(inp, extra_columns=['querygen', 'querygen_score', 'text'])
        # Positional lists built below are combined with columns, so use a fresh RangeIndex.
        inp = inp.reset_index(drop=True)

        kept_queries = []
        kept_scores = []
        for queries, scores in zip(inp['querygen'], inp['querygen_score']):
            # asarray is a no-op for numpy arrays and lets list-valued scores be masked too.
            scores = np.asarray(scores)
            keep = scores >= self.t  # computed once, applied to both queries and scores
            kept_queries.append('\n'.join(np.array(queries.split('\n'))[keep].tolist()))
            kept_scores.append(scores[keep])

        if self.append:
            if len(inp) > 0:
                inp = inp.assign(text=inp['text'] + '\n' + pd.Series(kept_queries))
            return inp.drop(['querygen', 'querygen_score'], axis='columns')

        # dtype=object keeps each row's score array as a single cell.
        return inp.assign(
            querygen=kept_queries,
            querygen_score=pd.Series(kept_scores, index=inp.index, dtype=object),
        )