Source code for pyterrier_doc2query
__version__ = '0.2.0'
import math
import pyterrier as pt
import pyterrier_alpha as pta
import pandas as pd
import torch
from transformers import T5Tokenizer, T5TokenizerFast, T5ForConditionalGeneration
from more_itertools import chunked
from typing import List
import re
from warnings import warn
from .artefact import Artefact
from .filtering import QueryScorer, QueryFilter
from .stores import Doc2QueryStore, QueryScoreStore
[docs]
class Doc2Query(pt.Transformer):
    """A :class:`~pyterrier.Transformer` that generates queries from documents.

    The text in ``doc_attr`` of each input row is fed to a T5 sequence-to-sequence
    model and ``num_samples`` queries are sampled per document. The queries for a
    document are joined with newlines and either stored in a new ``out_attr``
    column or, when ``append=True``, appended to the document text itself.
    """

    def __init__(self,
                 checkpoint='macavaney/doc2query-t5-base-msmarco',
                 num_samples=3,
                 batch_size=4,
                 doc_attr="text",
                 append=False,
                 out_attr="querygen",
                 verbose=False,
                 fast_tokenizer=False,
                 device=None,
                 max_input_length=64,
                 max_output_length=64):
        """
        Args:
            checkpoint: The checkpoint to use for the model. Defaults to 'macavaney/doc2query-t5-base-msmarco'.
            num_samples: The number of queries to generate per document.
            batch_size: The batch size to use for inference.
            doc_attr: The attribute in the input DataFrame that contains the documents.
            append: If True, the generated queries are appended to the documents. Otherwise, the queries are stored in a separate attribute.
            out_attr: The attribute in the output DataFrame to store the generated queries.
            verbose: If True, displays a progress bar.
            fast_tokenizer: If True, uses the fast tokenizer.
            device: The device to use for inference. If None, defaults to 'cuda' if available, otherwise 'cpu'.
            max_input_length: The maximum number of tokens of each document given to the model; longer documents are truncated. Defaults to 64.
            max_output_length: The maximum number of tokens in each generated query. Defaults to 64.

        Raises:
            AssertionError: If ``append=True`` is combined with a non-default ``out_attr``.
        """
        self.checkpoint = checkpoint
        self.num_samples = num_samples
        self.doc_attr = doc_attr
        self.append = append
        self.out_attr = out_attr
        if append:
            # With append=True the queries go into doc_attr, so a custom out_attr would be silently ignored.
            assert out_attr == 'querygen', "append=True cannot be used with out_attr"
        self.verbose = verbose
        self.batch_size = batch_size
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        # Matches optional leading whitespace followed by a token starting with "http" at the start of the text.
        self.pattern = re.compile(r"^\s*http\S+")
        if fast_tokenizer:
            self.tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
        else:
            warn('consider setting fast_tokenizer=True; it speeds up inference considerably')
            self.tokenizer = T5Tokenizer.from_pretrained(checkpoint)
        self.fast_tokenizer = fast_tokenizer
        self.model = T5ForConditionalGeneration.from_pretrained(checkpoint)
        self.model.to(self.device)
        self.model.eval()

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply the query generation transformation.

        Args:
            df: A frame containing the ``doc_attr`` column (a document frame, a result
                frame, or any frame that includes that column).

        Returns:
            A new DataFrame (via ``DataFrame.assign``). If ``append`` is set, ``doc_attr``
            holds each document followed by its generated queries. Otherwise, a new
            ``out_attr`` column holds the generated queries.
        """
        with pta.validate.any(df) as v:
            v.document_frame(extra_columns=[self.doc_attr])
            v.result_frame(extra_columns=[self.doc_attr])
            v.columns(includes=[self.doc_attr])

        # The bar's unit is documents ('d'), so advance it by the number of documents
        # in each batch rather than once per batch.
        pbar = None
        if self.verbose and len(df) > 0:
            pbar = pt.tqdm(total=len(df), unit='d', desc='doc2query')
        output = []
        try:
            # chunked() yields lists, so each batch can be iterated again when appending.
            for docs in chunked(df[self.doc_attr], self.batch_size):
                gens = self._doc2query(docs)
                if self.append:
                    gens = [f'{doc}\n{gen}' for doc, gen in zip(docs, gens)]
                output.extend(gens)
                if pbar is not None:
                    pbar.update(len(docs))
        finally:
            if pbar is not None:
                pbar.close()

        if self.append:
            return df.assign(**{self.doc_attr: output})  # replace doc content
        return df.assign(**{self.out_attr: output})  # add new column

    def _doc2query(self, docs: List[str]) -> List[str]:
        """Generate ``num_samples`` queries for each document in ``docs``.

        Args:
            docs: A batch of document texts.

        Returns:
            One string per input document, containing its generated queries joined by newlines.
        """
        # Drop a leading "http..." token (and the whitespace before it) from each document.
        docs = [self.pattern.sub("", doc) for doc in docs]
        with torch.no_grad():
            encoded = self.tokenizer(docs,
                                     max_length=self.max_input_length,
                                     return_tensors='pt',
                                     padding=True,
                                     truncation=True).to(self.device)
            outputs = self.model.generate(
                input_ids=encoded.input_ids,
                # The batch is padded to its longest document, so pass the mask explicitly
                # rather than relying on generate() to infer it from pad tokens.
                attention_mask=encoded.attention_mask,
                max_length=self.max_output_length,
                do_sample=True,
                top_k=10,
                num_return_sequences=self.num_samples)
        outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # Grouping relies on generate() returning num_return_sequences consecutive
        # sequences for each input document.
        return ['\n'.join(gens) for gens in chunked(outputs, self.num_samples)]
# Public API of the package: the generator transformer plus re-exported helpers.
__all__ = [
    'Doc2Query',
    'QueryScorer',
    'QueryFilter',
    'Doc2QueryStore',
    'QueryScoreStore',
    'Artefact',
]