Source code for pyterrier_rag.backend._base

import re
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union

import pyterrier as pt
from more_itertools import chunked
from dataclasses import dataclass

from pyterrier_rag import backend as _backend


@dataclass
class BackendOutput:
    """
    Result of a single generation produced by a :class:`Backend`.

    Attributes:
        text (Optional[str]): The generated text. Defaults to ``None``.
        logprobs (Optional[List[Dict[str, float]]]): Logprobs of the generated tokens, or
            ``None`` when they were not requested. Presumably one mapping per generated
            token (token -> logprob); confirm against the concrete backend implementations.
    """
    # Annotated Optional because the default is None (the original ``str`` annotation
    # contradicted its own default value).
    text: Optional[str] = None
    logprobs: Optional[List[Dict[str, float]]] = None
class Backend(pt.Transformer, ABC):
    """
    Abstract base class for model-backed Transformers in PyTerrier.

    Subclasses must implement :meth:`generate`, which holds the raw generation logic.
    The base class builds on it to provide ready-made transformers
    (:meth:`text_generator`, :meth:`logprobs_generator`) and can itself be used as a
    transformer (it behaves like ``self.text_generator()``).

    Parameters:
        model_id (str): Model name or checkpoint path.
        max_input_length (int): Maximum token length for each input prompt.
        max_new_tokens (int): Maximum number of tokens to generate.
        verbose (bool): Flag to enable detailed logging.

    The following class attributes are available:

    Attributes:
        model_id (str): Model name or checkpoint path.
        supports_logprobs (bool): Indicates support for including the logprobs of generated tokens.
        supports_message_input (bool): Indicates support for message (chat)-formatted (``List[dict]``)
            inputs to ``generate``, in addition to ``str`` inputs.
        supports_num_responses (bool): Indicates support for generating more than one response per
            input. :class:`TextGenerator` refuses ``num_responses != 1`` for backends where this
            is ``False``.
    """

    supports_logprobs = False
    supports_message_input = False
    supports_num_responses = False

    def __init__(
        self,
        model_id: str,
        *,
        max_input_length: int = 512,
        max_new_tokens: int = 32,
        verbose: bool = False,
    ):
        super().__init__()
        self.model_id = model_id
        self.max_input_length = max_input_length
        self.max_new_tokens = max_new_tokens
        self.verbose = verbose

    # Abstract methods

    @abstractmethod
    def generate(
        self,
        inps: Union[List[str], List[List[dict]]],
        *,
        return_logprobs: bool = False,
        max_new_tokens: Optional[int] = None,
    ) -> List[BackendOutput]:
        """
        Generate text from input prompts.

        Parameters:
            inps (Union[List[str], List[List[dict]]]): Input prompts as strings or dictionaries.
                When strings, represent the prompts directly. When a list of dictionaries,
                represents a sequence of messages (if ``backend.supports_message_input==True``).
            return_logprobs (bool): Whether to return logprobs of generated tokens along with text.
                (Only available if ``backend.supports_logprobs==True``.)
            max_new_tokens (Optional[int]): Override for max tokens to generate.

        Returns:
            List[BackendOutput]: An output for each ``inp``, each containing the generated text
            and optionally logprobs.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Implement the generate method")

    # Transformer implementations

    def text_generator(
        self,
        *,
        input_field: str = 'prompt',
        output_field: str = 'qanswer',
        batch_size: int = 4,
        max_new_tokens: Optional[int] = None,
        num_responses: int = 1,
    ) -> pt.Transformer:
        """
        Create a text generator transformer using this backend.

        Parameters:
            input_field (str): Name of the field containing input prompts.
            output_field (str): Name of the field to store generated text.
            batch_size (int): Number of prompts to process in each batch.
            max_new_tokens (Optional[int]): Override for max tokens to generate. If None, uses
                the backend's max_new_tokens.
            num_responses (int): Number of responses to generate for each prompt.

        Returns:
            pt.Transformer: A :class:`TextGenerator` bound to this backend.
        """
        return TextGenerator(
            self,
            input_field=input_field,
            output_field=output_field,
            # Previously accepted but never forwarded, so callers always got the default.
            batch_size=batch_size,
            max_new_tokens=max_new_tokens,
            num_responses=num_responses,
        )

    def logprobs_generator(
        self,
        *,
        input_field: str = 'prompt',
        output_field: str = 'qanswer',
        logprobs_field: str = 'qanswer_logprobs',
        batch_size: int = 4,
        max_new_tokens: Optional[int] = None,
        num_responses: int = 1,
    ) -> pt.Transformer:
        """
        Create a text generator transformer that also returns the logprobs of each token
        using this backend.

        Parameters:
            input_field (str): Name of the field containing input prompts.
            output_field (str): Name of the field to store generated text.
            logprobs_field (str): Name of the field to store logprobs.
            batch_size (int): Number of prompts to process in each batch.
            max_new_tokens (Optional[int]): Override for max tokens to generate. If None, uses
                the backend's max_new_tokens.
            num_responses (int): Number of responses to generate for each prompt.

        Returns:
            pt.Transformer: A :class:`TextGenerator` bound to this backend that also emits logprobs.

        Raises:
            ValueError: If this backend does not support logprobs.
        """
        if not self.supports_logprobs:
            raise ValueError("This model cannot return logprobs")
        return TextGenerator(
            self,
            input_field=input_field,
            output_field=output_field,
            logprobs_field=logprobs_field,
            # Previously accepted but never forwarded, so callers always got the default.
            batch_size=batch_size,
            max_new_tokens=max_new_tokens,
            num_responses=num_responses,
        )

    def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
        """Apply this backend as a transformer, using the default :meth:`text_generator` settings."""
        return self.text_generator().transform_iter(inp)

    # factory methods

    @staticmethod
    def from_dsn(dsn: str) -> 'Backend':
        """
        Create a Backend instance from a DSN (Data Source Name) string.

        The DSN format is: ``<provider>:<model_id> [key1=value1 key2=value2 ...]``.

        Examples: ``"openai:gpt-3.5-turbo"``,
        ``"openai:meta-llama/Llama-4-Scout-17B-16E-Instruct base_path=http://localhost:8080/v1"``,
        ``"vllm:meta-llama/Llama-4-Scout-17B-16E-Instruct"``, and
        ``"huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct"``.

        See each backend implementation ``from_params`` method for their supported keys.

        Parameters:
            dsn (str): The DSN string to parse.

        Returns:
            Backend: An instance of the appropriate Backend subclass based on the provider.

        Raises:
            ValueError: If the DSN format is invalid, a parameter is not of the form
                ``key=value``, or the provider is unknown.
        """
        # The model id may contain '.' and ':' (e.g. "gpt-3.5-turbo"); the provider is
        # everything before the first ':' because ``\w`` cannot match a colon.
        pattern = r"^(?P<backend>\w+):(?P<model_id>[\w./:-]+)(?:\s+(?P<params>.*))?$"
        match = re.match(pattern, dsn)
        if not match:
            raise ValueError(f"Invalid DSN format: {dsn!r}")

        backend = match.group("backend")
        # Map provider -> attribute name and resolve only the requested class, so that
        # looking up one backend never touches the others.
        backend_cls_name = {
            'openai': 'OpenAIBackend',
            'huggingface': 'HuggingFaceBackend',
            'vllm': 'VLLMBackend',
        }.get(backend)
        if backend_cls_name is None:
            raise ValueError(f'unknown backend {backend}')
        backend_cls = getattr(_backend, backend_cls_name)

        params = {
            'model_id': match.group("model_id"),
        }
        for param in (match.group("params") or "").split():
            # Split on the first '=' only, so values may themselves contain '='
            # (e.g. URLs with a query string).
            key, sep, value = param.partition("=")
            if not sep or not key:
                raise ValueError(
                    f"Invalid DSN parameter {param!r} in {dsn!r}: expected key=value"
                )
            params[key] = value
        return backend_cls.from_params(params)
class TextGenerator(pt.Transformer):
    """
    Transformer that generates text from the specified backend.

    Each input record yields ``num_responses`` output records: a copy of the input with the
    generated text stored under ``output_field`` (and the logprobs under ``logprobs_field``
    when one is configured).
    """

    def __init__(
        self,
        backend: Backend,
        *,
        input_field: str = 'prompt',
        output_field: str = 'qanswer',
        logprobs_field: Optional[str] = None,
        batch_size: int = 4,
        max_new_tokens: Optional[int] = None,
        num_responses: int = 1,
    ):
        """
        Parameters:
            backend (Backend): The backend to use for text generation.
            input_field (str): Name of the field containing input prompts.
            output_field (str): Name of the field to store generated text.
            logprobs_field (Optional[str]): Name of the field to store generated logprobs.
                If None, logprobs are not returned.
            batch_size (int): Number of prompts to process in each batch.
            max_new_tokens (Optional[int]): Override for max tokens to generate. If None, uses
                the backend's max_new_tokens.
            num_responses (int): Number of responses to generate for each prompt.

        Raises:
            ValueError: If logprobs or multiple responses are requested from a backend that
                does not support them.
        """
        if logprobs_field is not None and not backend.supports_logprobs:
            raise ValueError("Backend does not support logprobs")
        if num_responses != 1 and not backend.supports_num_responses:
            raise ValueError("Backend does not support multiple responses per input")
        self.backend = backend
        self.input_field = input_field
        self.output_field = output_field
        self.logprobs_field = logprobs_field
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens
        self.num_responses = num_responses

    def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
        """
        Generate text for every record of ``inp``, processing ``batch_size`` records per
        backend call.

        Parameters:
            inp (pt.model.IterDict): Records holding a prompt under ``input_field``.

        Yields:
            dict: For each input record, ``num_responses`` copies extended with the generated
            text (and logprobs, if configured), in input order.
        """
        want_logprobs = self.logprobs_field is not None
        generate_kwargs = {
            'return_logprobs': want_logprobs,
            'max_new_tokens': self.max_new_tokens,
        }
        # Only forward ``num_responses`` to backends that declare support for it: a backend
        # implementing just the base ``generate`` signature would raise TypeError on the
        # unexpected keyword. The constructor guarantees num_responses == 1 otherwise.
        if self.backend.supports_num_responses:
            generate_kwargs['num_responses'] = self.num_responses

        for chunk in chunked(inp, self.batch_size):
            prompts = [rec[self.input_field] for rec in chunk]
            out = self.backend.generate(prompts, **generate_kwargs)
            # Outputs are laid out prompt-major: the responses for record i occupy
            # positions [i * num_responses, (i + 1) * num_responses).
            for i, rec in enumerate(chunk):
                for j in range(self.num_responses):
                    o = out[i * self.num_responses + j]
                    result = {**rec, self.output_field: o.text}
                    if want_logprobs:
                        result[self.logprobs_field] = o.logprobs
                    yield result
# Names exported by ``from <this module> import *``.
__all__ = ["Backend", "BackendOutput", "TextGenerator"]