Source code for pyterrier_splade._model

from typing import Union, List, Literal, Dict
import torch
import numpy as np
import pyterrier as pt
import pyterrier_splade

[docs] class Splade: """A SPLADE model, which provides transformers for sparse encoding documents and queries, and scoring documents.""" def __init__( self, model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil", tokenizer=None, agg='max', max_length=256, device=None ): """Initializes the SPLADE model. Args: model: the SPLADE model to use, either a PyTorch model or a string to load from HuggingFace tokenizer: the tokenizer to use, if not included in the model agg: the aggregation function to use for the SPLADE model max_length: the maximum length of the input sequences device: the device to use, e.g. 'cuda' or 'cpu' """ self.max_length = max_length self.model = model self.tokenizer = tokenizer if device is None: self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') else: self.device = torch.device(device) if isinstance(model, str): from transformers import AutoModelForMaskedLM if self.tokenizer is None: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model) self.model = AutoModelForMaskedLM.from_pretrained(model) self.agg = agg self.model.output_dim = self.model.config.vocab_size self.model.eval() self.model = self.model.to(self.device) else: if self.tokenizer is None: raise ValueError("you must specify tokenizer if passing a model") self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}
[docs] def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer: """Returns a transformer that encodes a text field into a document representation. Args: text_field: the text field to encode batch_size: the batch size to use when encoding sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector verbose: if True, show a progress bar scale: the scale to apply to the term frequencies """ out_field = 'toks' if sparse else 'doc_vec' return pyterrier_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)
indexing = doc_encoder # backward compatible name
[docs] def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer: """Returns a transformer that encodes a query field into a query representation. Args: batch_size: the batch size to use when encoding sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector verbose: if True, show a progress bar scale: the scale to apply to the term frequencies """ out_field = 'query_toks' if sparse else 'query_vec' res = pyterrier_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale) return res
query = query_encoder # backward compatible name
[docs] def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer: """Returns a transformer that scores documents against queries. Args: text_field: the text field to score batch_size: the batch size to use when scoring verbose: if True, show a progress bar """ return pyterrier_splade.SpladeScorer(self, text_field, batch_size, verbose)
[docs] def encode( self, texts: List[str], rep: Literal['d', 'q'] = 'd', format: Literal['dict', 'np', 'torch'] ='dict', scale: float = 1., ) -> Union[List[Dict[str, float]], List[np.ndarray], torch.Tensor]: """Encodes a batch of texts into their SPLADE representations. Args: texts: the list of texts to encode rep: 'q' for query, 'd' for document format: 'dict' for a dict of term frequencies, 'np' for a list of numpy arrays, 'torch' for a torch tensor scale: the scale to apply to the term frequencies """ rtr = [] with torch.no_grad(): inputs = self.tokenizer( texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation="longest_first", # truncates to max model length, max_length=self.max_length, ).to(self.device) reps = self.model(**inputs).logits if self.agg == "sum": reps = torch.sum(torch.log(1 + torch.relu(reps)) * inputs["attention_mask"].unsqueeze(-1), dim=1) else: reps, _ = torch.max(torch.log(1 + torch.relu(reps)) * inputs["attention_mask"].unsqueeze(-1), dim=1) reps = reps * scale if format == 'dict': reps = reps.cpu() for i in range(reps.shape[0]): # get the number of non-zero dimensions in the rep: col = torch.nonzero(reps[i]).squeeze(1).tolist() # now let's create the bow representation as a dictionary weights = reps[i, col].cpu().tolist() # if document cast to int to make the weights ready for terrier indexing if rep == "d": weights = list(map(int, weights)) sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0])) # create the dict removing the weights less than 1, i.e. 0, that are not helpful d = {self.reverse_voc[k]: v for k, v in sorted_weights if v > 0} rtr.append(d) elif format == 'np': reps = reps.cpu().numpy() for i in range(reps.shape[0]): rtr.append(reps[i]) elif format == 'torch': rtr = reps return rtr