Source code for pyterrier_doc2query.filtering

import numpy as np
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta


class QueryScorer(pt.Transformer):
    """A :class:`~pyterrier.Transformer` that scores queries generated by
    :class:`~pyterrier_doc2query.Doc2Query` with the provided ``scorer`` transformer."""

    def __init__(self, scorer: pt.Transformer):
        """
        Args:
            scorer: A pyterrier Transformer that takes a DataFrame with columns 'query', 'text',
                'docno', 'qid' and returns a DataFrame with columns 'qid', 'score'. Scores are
                read back by row position, so the output rows must be in the same order as the
                input rows.
        """
        self.scorer = scorer

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Applies the scoring transformation.

        Each document's ``querygen`` value is split on newlines into individual queries; every
        (query, text) pair is scored in one call to ``self.scorer``.

        Args:
            inp: A document frame with (at least) 'docno', 'text' and 'querygen' columns.

        Returns:
            ``inp`` with an added ``querygen_score`` column. Each cell holds the array of scores
            for that document's queries, in the same order as the queries in ``querygen``.
        """
        pta.validate.document_frame(inp, extra_columns=['text', 'querygen'])

        # One slice per input row, marking which rows of the flattened scorer frame belong to it.
        slices = []
        # 'docno' must be initialised here: the loop below extends it for every document.
        scorer_inp = {
            'query': [],
            'text': [],
            'docno': [],
        }
        for text, querygen, docno in zip(inp['text'], inp['querygen'], inp['docno']):
            queries = querygen.split('\n')
            start_idx = len(scorer_inp['query'])
            slices.append(slice(start_idx, start_idx + len(queries)))
            scorer_inp['query'].extend(queries)
            scorer_inp['text'].extend([text] * len(queries))
            scorer_inp['docno'].extend([docno] * len(queries))
        # Each (document, query) pair gets its own synthetic qid.
        scorer_inp['qid'] = list(range(len(scorer_inp['query'])))
        scorer_inp = pd.DataFrame(scorer_inp)

        if len(scorer_inp) == 0:
            # Nothing to score (empty input frame): skip the scorer entirely.
            return inp.assign(querygen_score=[])

        dout = self.scorer(scorer_inp)
        scores = dout['score'].values
        return inp.assign(querygen_score=[scores[s] for s in slices])
class QueryFilter(pt.Transformer):
    """A :class:`~pyterrier.Transformer` that filters out queries based on their scores
    (from :class:`~pyterrier_doc2query.QueryScorer`) and the threshold ``t``."""

    def __init__(self, t: float, append: bool = True):
        """
        Args:
            t: The threshold to filter queries by. The score must be greater than or equal to
                this value to pass the filter.
            append: If True, the queries that pass the filter are appended to the text
                (newline-separated) and the 'querygen' and 'querygen_score' columns are dropped.
                Otherwise, 'querygen' and 'querygen_score' are replaced by only the entries that
                pass the filter.
        """
        self.t = t
        self.append = append

    def _keep_mask(self, scores) -> np.ndarray:
        """Returns a boolean mask selecting the scores that are >= the threshold."""
        # np.asarray lets plain lists of scores work as well as numpy arrays.
        return np.asarray(scores) >= self.t

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Applies the filtering transformation.

        Args:
            inp: A document frame with (at least) 'text', 'querygen' (newline-separated queries)
                and 'querygen_score' (one array-like of scores per row) columns.

        Returns:
            The frame with a fresh 0..n-1 index, modified according to ``append`` (see
            ``__init__``).
        """
        pta.validate.document_frame(inp, extra_columns=['querygen', 'querygen_score', 'text'])
        # Reset the index so the positional pd.Series built below aligns with the frame's rows.
        inp = inp.reset_index(drop=True)
        querygen = [
            '\n'.join(np.array(qs.split('\n'))[self._keep_mask(ss)].tolist())
            for qs, ss in zip(inp['querygen'], inp['querygen_score'])
        ]
        if self.append:
            if len(inp) > 0:
                inp = inp.assign(text=inp['text'] + '\n' + pd.Series(querygen))
            return inp.drop(['querygen', 'querygen_score'], axis='columns')

        # Keep the score arrays in step with the filtered queries.
        querygen_score = inp['querygen_score'].apply(
            lambda scores: np.asarray(scores)[self._keep_mask(scores)]
        )
        return inp.assign(querygen=querygen, querygen_score=querygen_score)