Source code for pyterrier_splade._scorer
import more_itertools
import pandas as pd
import numpy as np
import pyterrier as pt
import pyterrier_alpha as pta
[docs]
class SpladeScorer(pt.Transformer):
    """Scores (re-ranks) documents against queries using a SPLADE model."""

    def __init__(self, splade, text_field, batch_size=100, verbose=False):
        """Initializes the SPLADE scorer.

        Args:
            splade: :class:`pyterrier_splade.Splade` instance
            text_field: the text field to score
            batch_size: the batch size to use when scoring
            verbose: if True, show a progress bar
        """
        self.splade = splade
        self.text_field = text_field
        self.batch_size = batch_size
        self.verbose = verbose

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Scores (re-ranks) the documents against the queries in the input DataFrame.

        Rows are grouped by their ``query`` text. Each distinct query is encoded
        once, the group's ``text_field`` values are encoded in chunks of
        ``batch_size``, and each row's score is taken from the product of the
        query encoding with the transposed document encodings.

        Args:
            df: result frame containing at least the ``query`` column and the
                configured ``text_field`` column.

        Returns:
            The scored rows with a ``score`` column, passed through
            ``pyterrier.model.add_ranks``. If there is nothing to score, an
            empty frame with the input's columns plus ``score`` is ranked and
            returned instead of raising.
        """
        pta.validate.result_frame(df, ['query', self.text_field])
        groups = df.groupby('query')
        if self.verbose:
            groups = pt.tqdm(groups, unit='query')

        scored = []
        # The loop variable is deliberately not called `df`: the original input
        # frame is still needed below for the empty-result fallback.
        for query, group in groups:
            query_enc = self.splade.encode([query], 'q', 'torch')
            batch_scores = []
            for batch in more_itertools.chunked(group[self.text_field], self.batch_size):
                doc_enc = self.splade.encode(batch, 'd', 'torch')
                batch_scores.append((query_enc @ doc_enc.T).flatten().cpu().numpy())
            scored.append(group.assign(score=np.concatenate(batch_scores)))

        if scored:
            res = pd.concat(scored)
        else:
            # pd.concat([]) raises ValueError, so build an empty frame that
            # keeps the input schema and carries an (empty) float score column.
            res = df.iloc[:0].assign(score=np.empty(0, dtype=np.float64))

        from pyterrier.model import add_ranks
        return add_ranks(res)