# Source code for pyterrier_dr.tctcolbert_model
from more_itertools import chunked
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from pyterrier_dr.util import Variants
from . import BiEncoder
# [docs]
class TctColBert(BiEncoder, metaclass=Variants):
    """Dense encoder for TCT-ColBERT (Tightly-Coupled Teachers over ColBERT).

    See :class:`~pyterrier_dr.BiEncoder` for usage information.

    .. cite.dblp:: journals/corr/abs-2010-11386

    .. automethod:: base()
    .. automethod:: hn()
    .. automethod:: hnp()
    """

    def __init__(self, model_name=None, batch_size=32, text_field='text', verbose=False, device=None):
        """Load the tokenizer and model and place the model on a device.

        Args:
            model_name: Identifier passed to ``from_pretrained``. When falsy,
                the first entry of ``VARIANTS`` is used.
            batch_size: Default number of texts encoded per forward pass.
            text_field: Forwarded to the ``BiEncoder`` constructor.
            verbose: Forwarded to the ``BiEncoder`` constructor.
            device: Anything accepted by ``torch.device``. When ``None``,
                ``'cuda'`` is chosen if available, otherwise ``'cpu'``.
        """
        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
        if not model_name:
            # Fall back to the first registered variant.
            model_name = next(iter(self.VARIANTS.values()))
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        if device is not None:
            target = device
        elif torch.cuda.is_available():
            target = 'cuda'
        else:
            target = 'cpu'
        self.device = torch.device(target)

        encoder = AutoModel.from_pretrained(self.model_name)
        self.model = encoder.to(self.device).eval()

    def encode_queries(self, texts, batch_size=None):
        """Encode query strings into one vector per query.

        Each query is wrapped as ``[CLS] [Q] <query>`` followed by 32
        ``[MASK]`` tokens, tokenized with truncation at 36 tokens, and the
        hidden states after the first four positions are averaged.

        Args:
            texts: Iterable of query strings.
            batch_size: Optional per-call batch size; falls back to
                ``self.batch_size`` when falsy.

        Returns:
            A numpy array with one row per query, or an empty ``(0, 0)``
            array when ``texts`` yields nothing.
        """
        step = batch_size or self.batch_size
        # The padding suffix is identical for every query, so build it once.
        mask_padding = ' '.join(['[MASK]'] * 32)
        encoded_batches = []
        with torch.no_grad():
            for batch in chunked(texts, step):
                prompts = [f'[CLS] [Q] {query} ' + mask_padding for query in batch]
                tokens = self.tokenizer(
                    prompts,
                    add_special_tokens=False,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=36,
                )
                tokens = {name: tensor.to(self.device) for name, tensor in tokens.items()}
                hidden = self.model(**tokens).last_hidden_state
                # Drop the first 4 tokens (representing [CLS] [ Q ]) and
                # average what remains.
                pooled = hidden[:, 4:, :].mean(dim=1)
                encoded_batches.append(pooled.cpu().numpy())
        if encoded_batches:
            return np.concatenate(encoded_batches, axis=0)
        return np.empty(shape=(0, 0))

    def encode_docs(self, texts, batch_size=None):
        """Encode document strings into one vector per document.

        Each document is wrapped as ``[CLS] [D] <doc>``, tokenized with
        truncation at 512 tokens, and the hidden states after the first four
        positions are averaged over the non-padding tokens only.

        Args:
            texts: Iterable of document strings.
            batch_size: Optional per-call batch size; falls back to
                ``self.batch_size`` when falsy.

        Returns:
            A numpy array with one row per document, or an empty ``(0, 0)``
            array when ``texts`` yields nothing.
        """
        step = batch_size or self.batch_size
        encoded_batches = []
        with torch.no_grad():
            for batch in chunked(texts, step):
                prompts = [f'[CLS] [D] {doc}' for doc in batch]
                tokens = self.tokenizer(
                    prompts,
                    add_special_tokens=False,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=512,
                )
                tokens = {name: tensor.to(self.device) for name, tensor in tokens.items()}
                hidden = self.model(**tokens).last_hidden_state
                # Skip the first 4 tokens (representing [CLS] [ D ]) in both
                # the hidden states and the attention mask.
                content_mask = tokens['attention_mask'][:, 4:]
                # Zero out padding positions so they do not affect the sum.
                content = hidden[:, 4:, :] * content_mask.unsqueeze(2)
                counts = content_mask.sum(dim=1, keepdim=True)
                # Guard against division by zero for rows with no content.
                counts = counts.masked_fill(counts == 0, 1)
                pooled = content.sum(dim=1) / counts
                encoded_batches.append(pooled.cpu().numpy())
        if encoded_batches:
            return np.concatenate(encoded_batches, axis=0)
        return np.empty(shape=(0, 0))

    def __repr__(self):
        """Return a constructor-like representation showing the model name."""
        return f'TctColBert({self.model_name!r})'

    # Short variant name -> model identifier. __init__ uses the first entry
    # as the default model.
    VARIANTS = {
        'base': 'castorini/tct_colbert-msmarco',
        'hn': 'castorini/tct_colbert-v2-hn-msmarco',
        'hnp': 'castorini/tct_colbert-v2-hnp-msmarco',
    }