# Source code for pyterrier_rag.prompt._base
from typing import Optional, Union, List, Any
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
from pyterrier_rag.prompt.wrapper import prompt
from fastchat.model import get_conversation_template
from fastchat.conversation import get_conv_template
# [docs]
class PromptTransformer(pt.Transformer):
    """
    Transformer that constructs and formats prompts for conversational LLMs.

    Input rows are grouped by ``qid``. For each group, the first row's
    ``input_fields`` are passed as keyword arguments to ``instruction``, and the
    result is (optionally) wrapped in a FastChat conversation template. The
    output frame has the columns ``[output_field, "qid", "query_0"]``.

    Parameters:
        instruction (callable|str): Template or function returning the instruction
            segment. A ``str`` is wrapped with ``pyterrier_rag.prompt.wrapper.prompt``.
        model_name_or_path (str, optional): Model identifier used to select a
            conversation template when ``conversation_template`` is not given.
        system_message (str, optional): System context message for the conversation.
        conversation_template (Any, optional): Preconfigured conversation template.
            Falls back to FastChat's ``"zero_shot"`` template when neither this nor
            ``model_name_or_path`` is given.
        api_type (str, optional): API message format: 'openai', 'gemini', 'vertex'
            or 'reka'. When unset, the prompt is rendered via ``get_prompt``.
        output_field (str): Column name in which to store the generated prompt.
        input_fields (List[str], optional): Input columns required to build the
            prompt. Defaults to ``["query", "qcontext"]``.
        expects_logprobs (bool): Indicator for logprob-based backends.
        answer_extraction (callable, optional): Function to parse model outputs.
            Defaults to the identity method :meth:`answer_extraction`.
        raw_instruction (bool): If True, returns the formatted instruction without
            wrapping it in the conversation template.
    """

    # Maps ``api_type`` to the conversation method that renders the prompt.
    _API_OUTPUT_ATTRIBUTES = {
        "openai": "to_openai_api_messages",
        "gemini": "to_gemini_api_messages",
        "vertex": "to_vertex_api_messages",
        "reka": "to_reka_api_messages",
    }

    def __init__(
        self,
        instruction: Union[callable, str] = None,
        model_name_or_path: str = None,
        system_message: Optional[str] = None,
        conversation_template: Optional[Any] = None,
        api_type: Optional[str] = None,
        output_field: str = "prompt",
        input_fields: Optional[List[str]] = None,
        expects_logprobs: bool = False,
        answer_extraction: Optional[callable] = None,
        raw_instruction: bool = False,
    ):
        self.instruction = instruction
        self.model_name_or_path = model_name_or_path
        self.system_message = system_message
        self.output_field = output_field
        # Use a fresh list per instance: avoids a shared mutable default and
        # decouples the instance from the caller's list (also accepts tuples).
        self.input_fields = (
            list(input_fields) if input_fields is not None else ["query", "qcontext"]
        )
        self.conversation_template = conversation_template
        self.api_type = api_type
        self.expects_logprobs = expects_logprobs
        # An instance attribute overrides the default (identity) method below.
        self.answer_extraction = answer_extraction or self.answer_extraction
        self.raw_instruction = raw_instruction
        self.__post_init__()

    def __post_init__(self):
        """Resolve the instruction, conversation template, roles and output format."""
        if isinstance(self.instruction, str):
            self.instruction = prompt(self.instruction)
        if self.model_name_or_path is not None and self.conversation_template is None:
            self.conversation_template = get_conversation_template(self.model_name_or_path)
        if self.conversation_template is None:
            self.conversation_template = get_conv_template("zero_shot")
        if self.system_message is not None:
            # TODO: Set flag for if model supports system message
            self.conversation_template.set_system_message(self.system_message)

        roles = self.conversation_template.roles
        if len(roles) < 2:
            # Template does not declare both roles; fall back to chat-style names.
            self.user_role, self.llm_role = "user", "assistant"
        else:
            self.user_role, self.llm_role = roles[0], roles[1]

        if not self.api_type:
            self.output_attribute = "get_prompt"
        elif self.api_type in self._API_OUTPUT_ATTRIBUTES:
            self.output_attribute = self._API_OUTPUT_ATTRIBUTES[self.api_type]
        else:
            # Same exception type as the original dict lookup, but with a clear message.
            raise KeyError(
                f"Unsupported api_type {self.api_type!r}; expected one of "
                f"{sorted(self._API_OUTPUT_ATTRIBUTES)}"
            )

    def answer_extraction(self, output):
        """Default answer extraction: return the model output unchanged."""
        return output

    def set_output_attribute(self, supports_message_input: bool):
        """Select plain-string or OpenAI-message output depending on backend support."""
        # ``output_attribute`` indicates the method to call on the prompt object
        # In the future, we may support more message formats, but for now it's either a string or OpenAI-formatted messages
        self.output_attribute = 'to_openai_api_messages' if supports_message_input else 'get_prompt'

    @property
    def prompt(self):
        """A fresh copy of the conversation template, safe to append messages to."""
        return self.conversation_template.copy()

    def to_output(self, prompt) -> Union[str, List[dict]]:
        """Render a conversation using the method named by ``output_attribute``."""
        return getattr(prompt, self.output_attribute)()

    def create_prompt(self, fields: dict) -> Union[str, List[dict]]:
        """
        Build the prompt for one query.

        Parameters:
            fields (dict): Values for ``input_fields``, passed to ``instruction``
                as keyword arguments.

        Returns:
            The raw instruction if ``raw_instruction`` is set; otherwise the
            conversation rendered by :meth:`to_output`.
        """
        instruction = self.instruction(**fields)
        if self.raw_instruction:
            # No template needed, so skip copying it.
            return instruction
        current_prompt = self.prompt
        current_prompt.append_message(self.user_role, instruction)
        return self.to_output(current_prompt)

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """
        Produce one prompt per ``qid``.

        Only the first row of each ``qid`` group is used to build the prompt.

        Returns:
            A DataFrame with columns ``[output_field, "qid", "query_0"]``. The
            same columns are returned for ``None`` or empty input, so the output
            schema does not depend on whether rows were present.

        Raises:
            ValueError: if a required input field is missing from a row.
        """
        output_columns = [self.output_field, "qid", "query_0"]
        if inp is None:
            return pta.DataFrameBuilder(output_columns).to_df()
        pta.validate.columns(inp, includes=["qid"] + self.input_fields)
        if inp.empty:
            return pta.DataFrameBuilder(output_columns).to_df()

        rows = []
        for qid, group in inp.groupby("qid"):
            # Convert only the first row instead of the whole group.
            first_row = group.head(1).to_dict(orient="records")[0]
            query = first_row.get("query", None)
            fields = {k: v for k, v in first_row.items() if k in self.input_fields}
            if any(f not in fields for f in self.input_fields):
                message = f"Expected {self.input_fields} but received {fields}"
                raise ValueError(message)
            prompt_value = self.create_prompt(fields)
            rows.append({self.output_field: prompt_value, "qid": qid, "query_0": query})
        return pd.DataFrame(rows, columns=output_columns)
# Public API of this module: only the transformer class is exported.
__all__ = ["PromptTransformer"]