# Source code for pyterrier_dr.bge_m3

import pyterrier as pt
import pandas as pd
import numpy as np
import torch
import pyterrier_alpha as pta
from .biencoder import BiEncoder

class BGEM3(BiEncoder):
    """Bi-encoder factory wrapping FlagEmbedding's ``BGEM3FlagModel``.

    The object itself provides dense-only encoding (``encode_queries`` /
    ``encode_docs``) and builds transformer objects for dense-only or
    multi-representation (dense, sparse, ColBERT-style multi-vector) encoding.

    Args:
        model_name: Model identifier passed to ``BGEM3FlagModel``.
        batch_size: Default batch size, forwarded to ``BiEncoder.__init__``.
        max_length: Maximum sequence length passed to ``model.encode``.
        text_field: Name of the document text column, forwarded to
            ``BiEncoder.__init__``.
        verbose: Default verbosity, forwarded to ``BiEncoder.__init__``.
        device: Device specifier; when ``None``, ``'cuda'`` is used if
            available and ``'cpu'`` otherwise.
        use_fp16: Passed through to ``BGEM3FlagModel``.

    Raises:
        ImportError: If the ``FlagEmbedding`` package is not installed.
    """

    def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
        self.model_name = model_name
        self.use_fp16 = use_fp16
        self.max_length = max_length
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # FlagEmbedding is an optional dependency, so it is imported lazily.
        try:
            from FlagEmbedding import BGEM3FlagModel
        except ImportError as err:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'") from err
        self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)

    def __repr__(self):
        return f'BGEM3({repr(self.model_name)})'

    def _encode_dense(self, texts, batch_size=None):
        """Encode ``texts`` with the model and return only its ``'dense_vecs'`` output."""
        if batch_size is None:
            # Fall back to the configured default rather than handing ``None``
            # to the model. Assumes BiEncoder.__init__ stores ``batch_size`` on
            # the instance (the encoder classes in this module read it too).
            batch_size = self.batch_size
        output = self.model.encode(
            list(texts),
            batch_size=batch_size,
            max_length=self.max_length,
            return_dense=True,
            return_sparse=False,
            return_colbert_vecs=False,
        )
        return output['dense_vecs']

    def encode_queries(self, texts, batch_size=None):
        """Return the dense vectors for an iterable of query texts.

        Args:
            texts: Iterable of query strings.
            batch_size: Batch size for this call; ``None`` uses the instance default.
        """
        return self._encode_dense(texts, batch_size=batch_size)

    def encode_docs(self, texts, batch_size=None):
        """Return the dense vectors for an iterable of document texts.

        Args:
            texts: Iterable of document strings.
            batch_size: Batch size for this call; ``None`` uses the instance default.
        """
        return self._encode_dense(texts, batch_size=batch_size)

    # Only does dense (single_vec) encoding
    def query_encoder(self, verbose=None, batch_size=None):
        """Build a query transformer with the encoder's default (dense-only) outputs."""
        return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size)

    def doc_encoder(self, verbose=None, batch_size=None):
        """Build a document transformer with the encoder's default (dense-only) outputs."""
        return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size)

    # Does all three BGE-M3 encodings: dense, sparse and colbert(multivec)
    def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
        """Build a query transformer producing the selected representations.

        All three representations (dense, sparse, multi-vector) are enabled by default.
        """
        return BGEM3QueryEncoder(
            self,
            verbose=verbose,
            batch_size=batch_size,
            return_dense=return_dense,
            return_sparse=return_sparse,
            return_colbert_vecs=return_colbert_vecs,
        )

    def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
        """Build a document transformer producing the selected representations.

        All three representations (dense, sparse, multi-vector) are enabled by default.
        """
        return BGEM3DocEncoder(
            self,
            verbose=verbose,
            batch_size=batch_size,
            return_dense=return_dense,
            return_sparse=return_sparse,
            return_colbert_vecs=return_colbert_vecs,
        )
def _token_weights_as_list(token_weights):
    """Return ``convert_id_to_token`` output as a list with one dict per input text.

    NOTE(review): FlagEmbedding's ``convert_id_to_token`` appears to collapse a
    single-item result into a bare dict instead of a one-element list; this
    normalises that case so the result can be assigned as a DataFrame column.
    Confirm against the installed FlagEmbedding version.
    """
    if isinstance(token_weights, dict):
        return [token_weights]
    return token_weights


class BGEM3QueryEncoder(pt.Transformer):
    """Transformer that adds BGE-M3 query representations to a query frame.

    Depending on the flags, ``transform`` adds ``query_vec`` (dense),
    ``query_toks`` (sparse token weights) and/or ``query_embs`` (multi-vector)
    columns.

    Args:
        bge_factory: The ``BGEM3`` instance providing the model and defaults.
        verbose: Show a progress bar; ``None`` uses the factory's setting.
        batch_size: Encoding batch size; ``None`` uses the factory's setting.
        max_length: Maximum sequence length; ``None`` uses the factory's setting.
        return_dense: Whether to produce the ``query_vec`` column.
        return_sparse: Whether to produce the ``query_toks`` column.
        return_colbert_vecs: Whether to produce the ``query_embs`` column.
    """

    def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
        self.bge_factory = bge_factory
        self.verbose = verbose if verbose is not None else bge_factory.verbose
        self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
        self.max_length = max_length if max_length is not None else bge_factory.max_length
        self.dense = return_dense
        self.sparse = return_sparse
        self.multivecs = return_colbert_vecs

    def encode(self, texts):
        """Run the factory's model over ``texts`` and return its raw output mapping."""
        return self.bge_factory.model.encode(
            list(texts),
            batch_size=self.batch_size,
            max_length=self.max_length,
            return_dense=self.dense,
            return_sparse=self.sparse,
            return_colbert_vecs=self.multivecs,
        )

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Encode the ``query`` column and return ``inp`` with the enabled columns added.

        Each distinct query text is encoded once; results are then mapped back
        onto every input row.
        """
        pta.validate.columns(inp, includes=['query'])

        # Empty input: add the expected (empty) columns without invoking the model.
        if len(inp) == 0:
            if self.dense:
                inp = inp.assign(query_vec=[])
            if self.sparse:
                inp = inp.assign(query_toks=[])
            if self.multivecs:
                inp = inp.assign(query_embs=[])
            return inp

        # np.unique sorts and de-duplicates; ``inverse`` maps each input row
        # to the position of its query in the unique array.
        unique_queries, inverse = np.unique(inp['query'].values, return_inverse=True)
        if self.verbose:
            unique_queries = pt.tqdm(unique_queries, desc='Encoding Queries', unit='query')
        bgem3_results = self.encode(unique_queries)

        if self.dense:
            inp = inp.assign(query_vec=[bgem3_results['dense_vecs'][i] for i in inverse])
        if self.sparse:
            # for sparse convert ids to the actual tokens
            unique_toks = _token_weights_as_list(
                self.bge_factory.model.convert_id_to_token(bgem3_results['lexical_weights'])
            )
            # The weights are in unique-query order, so they must be mapped back
            # through ``inverse`` exactly like the dense and multi-vector outputs.
            inp = inp.assign(query_toks=[unique_toks[i] for i in inverse])
        if self.multivecs:
            inp = inp.assign(query_embs=[bgem3_results['colbert_vecs'][i] for i in inverse])
        return inp

    def __repr__(self):
        return f'{repr(self.bge_factory)}.query_encoder()'


class BGEM3DocEncoder(pt.Transformer):
    """Transformer that adds BGE-M3 document representations to a document frame.

    Depending on the flags, ``transform`` adds ``doc_vec`` (dense), ``toks``
    (sparse token weights) and/or ``doc_embs`` (multi-vector) columns.

    Args:
        bge_factory: The ``BGEM3`` instance providing the model and defaults.
        verbose: Show a progress bar; ``None`` uses the factory's setting.
        batch_size: Encoding batch size; ``None`` uses the factory's setting.
        max_length: Maximum sequence length; ``None`` uses the factory's setting.
        return_dense: Whether to produce the ``doc_vec`` column.
        return_sparse: Whether to produce the ``toks`` column.
        return_colbert_vecs: Whether to produce the ``doc_embs`` column.
    """

    def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
        self.bge_factory = bge_factory
        self.verbose = verbose if verbose is not None else bge_factory.verbose
        self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
        self.max_length = max_length if max_length is not None else bge_factory.max_length
        self.dense = return_dense
        self.sparse = return_sparse
        self.multivecs = return_colbert_vecs

    def encode(self, texts):
        """Run the factory's model over ``texts`` and return its raw output mapping."""
        return self.bge_factory.model.encode(
            list(texts),
            batch_size=self.batch_size,
            max_length=self.max_length,
            return_dense=self.dense,
            return_sparse=self.sparse,
            return_colbert_vecs=self.multivecs,
        )

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Encode the factory's text column and return ``inp`` with the enabled columns added."""
        text_field = self.bge_factory.text_field
        # check if the input dataframe contains the field(s) specified in the text_field
        pta.validate.columns(inp, includes=[text_field])

        # Empty input: add the expected (empty) columns without invoking the model.
        if len(inp) == 0:
            if self.dense:
                inp = inp.assign(doc_vec=[])
            if self.sparse:
                inp = inp.assign(toks=[])
            if self.multivecs:
                inp = inp.assign(doc_embs=[])
            return inp

        texts = inp[text_field]
        if self.verbose:
            texts = pt.tqdm(texts, desc='Encoding Documents', unit='doc')
        bgem3_results = self.encode(texts)

        if self.dense:
            inp = inp.assign(doc_vec=list(bgem3_results['dense_vecs']))
        if self.sparse:
            # for sparse convert ids to the actual tokens
            toks = _token_weights_as_list(
                self.bge_factory.model.convert_id_to_token(bgem3_results['lexical_weights'])
            )
            inp = inp.assign(toks=toks)
        if self.multivecs:
            inp = inp.assign(doc_embs=list(bgem3_results['colbert_vecs']))
        return inp

    def __repr__(self):
        return f'{repr(self.bge_factory)}.doc_encoder()'