# Source code for pyterrier_rag.readers._fid_models

import inspect
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import pyterrier as pt
import pyterrier_alpha as pta
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartModel,
    GenerationConfig,
    T5ForConditionalGeneration,
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.t5.configuration_t5 import T5Config


@dataclass
class FiDEncoderOuput(BaseModelOutput):
    """Encoder output that additionally carries the encoder attention mask.

    The readers in this module store the attention mask here after it has been
    reshaped to line up with ``last_hidden_state`` (``batch_size,
    num_passages * seq_length`` when the input was 3-D), and read it back in
    ``forward`` as the decoder's ``encoder_attention_mask``.

    NOTE(review): the class name is spelled ``Ouput`` (sic); it is referenced
    under that name elsewhere in this module, so it is left unchanged.
    NOTE(review): with dataclass inheritance, fields redeclared here keep the
    position they have in ``BaseModelOutput``, so ``attention_mask`` is
    presumably the *last* field in tuple order -- confirm before relying on
    positional indexing.
    """

    # Encoder hidden states of the final layer.
    last_hidden_state: torch.FloatTensor = None
    # Mask matching ``last_hidden_state`` along the (fused) sequence axis.
    attention_mask: torch.FloatTensor = None
    # Per-layer hidden states as returned by the wrapped encoder (if requested).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights as returned by the wrapped encoder (if requested).
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class T5FiDReader(T5ForConditionalGeneration):
    """T5 sequence-to-sequence model with Fusion-in-Decoder style encoding.

    When ``input_ids`` has three dimensions ``(batch_size, num_passages,
    seq_length)``, every passage is encoded independently and the resulting
    hidden states are concatenated along the sequence axis, so the decoder
    cross-attends over ``num_passages * seq_length`` positions per example.
    Two-dimensional inputs are encoded as-is.
    """

    def __init__(self, config: T5Config, **kwargs):
        # Extra keyword arguments are accepted for call-site compatibility but
        # are not forwarded to the parent constructor.
        super().__init__(config)

    def get_encoder_output(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        **kwargs
    ) -> "FiDEncoderOuput":
        """Encode the passages and fuse them along the sequence axis.

        Args:
            input_ids: Token ids, either ``(batch, seq_length)`` or
                ``(batch_size, num_passages, seq_length)``.
            attention_mask: Mask with the same shape as ``input_ids``; may be
                ``None``, in which case no mask is forwarded.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.
            **kwargs: Ignored; accepted so generation utilities can pass
                additional model kwargs (e.g. ``return_dict``).

        Returns:
            A ``FiDEncoderOuput`` whose ``last_hidden_state`` and
            ``attention_mask`` are reshaped to ``(batch_size,
            num_passages * seq_length, ...)`` when the input was 3-D.
        """
        need_flatten = input_ids.dim() > 2
        if need_flatten:
            batch_size, num_passages, seq_length = input_ids.shape
            # (batch_size * num_passages, seq_length): one encoder row per passage
            input_ids = input_ids.reshape(-1, seq_length)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(-1, seq_length)

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # (batch_size * num_passages, seq_length, hidden_size)
        hidden_states = encoder_outputs[0]
        hidden_size = hidden_states.shape[-1]
        if need_flatten:
            # Concatenate all passages of one example along the sequence axis.
            hidden_states = hidden_states.reshape(batch_size, num_passages * seq_length, hidden_size)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, num_passages * seq_length)

        return FiDEncoderOuput(
            last_hidden_state=hidden_states,
            attention_mask=attention_mask,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Dict[str, torch.Tensor]] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = False,
        **kwargs
    ):
        """Run the encoder (unless ``encoder_outputs`` is given) and the decoder.

        Args:
            input_ids, attention_mask: Encoder inputs, see
                :meth:`get_encoder_output`. Unused when ``encoder_outputs`` is
                supplied.
            decoder_input_ids: Decoder inputs; overwritten by the right-shifted
                ``labels`` whenever ``labels`` is given.
            decoder_attention_mask: Mask for the decoder inputs.
            encoder_outputs: Pre-computed output of :meth:`get_encoder_output`;
                must expose ``last_hidden_state`` and ``attention_mask``.
            labels: Target ids; positions equal to ``-100`` are ignored by the
                loss.
            output_attentions, output_hidden_states: Forwarded to the encoder
                and decoder.
            return_dict: When truthy a ``Seq2SeqLMOutput`` holding ``loss`` and
                ``logits`` is returned, otherwise a tuple
                ``([loss,] lm_logits, encoder_hidden_states)``.
            **kwargs: Forwarded to :meth:`get_encoder_output`, which ignores
                them.
        """
        if encoder_outputs is None:
            encoder_outputs = self.get_encoder_output(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                **kwargs
            )

        encoder_hidden_states = encoder_outputs.last_hidden_state
        encoder_attention_mask = encoder_outputs.attention_mask

        # Teacher forcing: the decoder inputs are the labels shifted right.
        if labels is not None:
            decoder_input_ids = self._shift_right(labels)

        decoder_output = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )

        sequence_output = decoder_output[0]
        if self.config.tie_word_embeddings:
            # Rescale before projecting onto the (tied) vocabulary embeddings.
            sequence_output = sequence_output * (self.model_dim**-0.5)
        if sequence_output.dtype != self.lm_head.weight.dtype:
            sequence_output = sequence_output.to(self.lm_head.weight.dtype)
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.reshape(-1, lm_logits.shape[-1]), labels.reshape(-1))

        if not return_dict:
            output = (lm_logits, encoder_hidden_states)
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(loss=loss, logits=lm_logits)

    @torch.no_grad()
    def generate(self, **kwargs):
        """Generate output ids for (possibly multi-passage) inputs.

        ``input_ids`` and ``attention_mask`` are required keyword arguments
        (a ``KeyError`` is raised otherwise). The fused encoder output is
        computed here and handed to the parent ``generate`` as
        ``encoder_outputs``; all remaining keyword arguments are forwarded
        unchanged.

        The method is wrapped in ``torch.no_grad()`` because the encoder pass
        below runs *before* the parent ``generate`` is entered and would
        otherwise build an autograd graph that is never used.
        """
        input_ids = kwargs.pop("input_ids")
        attention_mask = kwargs.pop("attention_mask")
        kwargs["encoder_outputs"] = self.get_encoder_output(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return super().generate(**kwargs)

    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, *args, **kwargs):
        """Compute ``model_kwargs["encoder_outputs"]`` via :meth:`get_encoder_output`.

        Mirrors the generation hook of the parent class but routes the call
        through the fusing encoder wrapper instead of the raw encoder.
        """
        # Keep only the kwargs that are relevant for the encoder.
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        encoder_signature = set(inspect.signature(self.get_encoder_output).parameters)

        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }

        # Bind the main input under its expected name and run the encoder.
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"] = self.get_encoder_output(**encoder_kwargs)

        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs,
    ):
        """Build the ``forward`` kwargs for one decoding step.

        ``input_ids`` here are the decoder ids generated so far. Only the
        encoder outputs, decoder ids and decoder mask are passed on; every
        other argument is dropped.
        """
        return {
            "encoder_outputs": encoder_outputs,
            "decoder_input_ids": input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Return a copy of ``input_ids`` moved one position to the right.

    The first column is set to ``decoder_start_token_id``, the last token of
    every row is dropped, and any ``-100`` entries that remain are replaced by
    ``pad_token_id``. The input tensor is not modified.

    Raises:
        ValueError: if ``pad_token_id`` is ``None``.
    """
    # Rolling by one along the sequence axis moves every token right; the
    # wrapped-around last token lands in column 0 and is overwritten below.
    shifted = input_ids.roll(shifts=1, dims=1)
    shifted[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # Labels use -100 as the "ignore" marker, which is not a valid token id.
    ignore_positions = shifted.eq(-100)
    shifted.masked_fill_(ignore_positions, pad_token_id)
    return shifted


class FiDBartTModel(BartModel):
    """BART encoder-decoder with Fusion-in-Decoder style encoding.

    Three-dimensional ``input_ids`` of shape ``(batch_size, num_passages,
    seq_length)`` are encoded passage by passage; the hidden states (and the
    attention mask) are then concatenated along the sequence axis before being
    handed to the decoder.
    """

    def get_encoder_output(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode the passages and fuse them along the sequence axis.

        Args:
            input_ids: ``(batch, seq_length)`` or ``(batch_size, num_passages,
                seq_length)`` token ids. May be ``None`` when
                ``inputs_embeds`` is given, in which case no fusing happens.
            attention_mask: Mask shaped like ``input_ids``; may be ``None``.
            output_attentions, output_hidden_states, head_mask, inputs_embeds:
                Forwarded to the encoder.
            return_dict: When truthy a ``FiDEncoderOuput`` (which carries the
                fused attention mask) is returned; otherwise a plain tuple of
                the non-``None`` values among
                ``(last_hidden_state, hidden_states, attentions)``.
        """
        need_flatten = input_ids is not None and input_ids.dim() > 2
        if need_flatten:
            batch_size, num_passages, seq_length = input_ids.shape
            # One encoder row per passage.
            input_ids = input_ids.reshape(-1, seq_length)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(-1, seq_length)

        # Always ask the encoder for a ModelOutput so fields can be read by
        # name; the requested output form is produced below.
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = encoder_outputs.last_hidden_state
        all_hidden_states = encoder_outputs.hidden_states
        all_attentions = encoder_outputs.attentions

        if need_flatten:
            hidden_size = hidden_states.shape[-1]
            # Concatenate all passages of one example along the sequence axis.
            hidden_states = hidden_states.reshape(batch_size, num_passages * seq_length, hidden_size)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, num_passages * seq_length)

        if not return_dict:
            return tuple(v for v in (hidden_states, all_hidden_states, all_attentions) if v is not None)

        return FiDEncoderOuput(
            last_hidden_state=hidden_states,
            attention_mask=attention_mask,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Run the fusing encoder (unless ``encoder_outputs`` is given) and the decoder.

        ``encoder_outputs`` may be a ``FiDEncoderOuput``, any object exposing
        ``last_hidden_state`` (optionally ``hidden_states``, ``attentions``
        and ``attention_mask``), or a tuple/list
        ``(last_hidden_state[, hidden_states[, attentions]])``. When it does
        not carry an attention mask, the ``attention_mask`` argument is
        flattened to ``(batch, -1)`` and used for cross-attention instead.

        Returns:
            A ``Seq2SeqModelOutput`` when ``return_dict`` resolves to true,
            otherwise the decoder output tuple followed by the encoder tuple.

        Raises:
            ValueError: if neither decoder inputs nor ``input_ids`` are given.
        """
        # Adapted from transformers v4.38.2, models/bart/modeling_bart.py.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            # Always request the structured output here: the tuple form drops
            # the fused attention mask that the decoder needs below.
            encoder_outputs = self.get_encoder_output(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )

        if isinstance(encoder_outputs, (tuple, list)):
            # Caller-supplied tuple: (last_hidden_state[, hidden_states[, attentions]])
            encoder_hidden_states = encoder_outputs[0]
            encoder_all_hidden_states = encoder_outputs[1] if len(encoder_outputs) > 1 else None
            encoder_attentions = encoder_outputs[2] if len(encoder_outputs) > 2 else None
            encoder_attention_mask = None
        else:
            encoder_hidden_states = encoder_outputs.last_hidden_state
            encoder_all_hidden_states = getattr(encoder_outputs, "hidden_states", None)
            encoder_attentions = getattr(encoder_outputs, "attentions", None)
            encoder_attention_mask = getattr(encoder_outputs, "attention_mask", None)

        if encoder_attention_mask is None and attention_mask is not None:
            # Fall back to the raw mask, fused the same way as the hidden states.
            encoder_attention_mask = attention_mask.reshape(attention_mask.shape[0], -1)

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            encoder_tuple = tuple(
                v
                for v in (encoder_hidden_states, encoder_all_hidden_states, encoder_attentions)
                if v is not None
            )
            return decoder_outputs + encoder_tuple

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
            encoder_hidden_states=encoder_all_hidden_states,
            encoder_attentions=encoder_attentions,
        )


class BARTFiDReader(BartForConditionalGeneration):
    """BART conditional-generation head on top of :class:`FiDBartTModel`.

    The parent constructor builds a standard BART model; ``self.model`` is
    then replaced by the Fusion-in-Decoder variant so that multi-passage
    inputs are encoded independently and fused before decoding.
    """

    def __init__(self, config: BartConfig, **kwargs):
        # Extra keyword arguments are accepted but not forwarded.
        super().__init__(config)
        # Swap in the fusing encoder-decoder and rebuild the pieces whose
        # shapes depend on its shared embedding table.
        self.model = FiDBartTModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder_output(self, *args, **kwargs):
        """Delegate to ``self.model.get_encoder_output``."""
        return self.model.get_encoder_output(*args, **kwargs)

    @torch.no_grad()
    def generate(self, **kwargs):
        """Generate output ids for (possibly multi-passage) inputs.

        ``input_ids`` and ``attention_mask`` are required keyword arguments
        (a ``KeyError`` is raised otherwise). The fused encoder output is
        computed here and passed to the parent ``generate`` as
        ``encoder_outputs``; all other keyword arguments are forwarded as-is.

        Wrapped in ``torch.no_grad()`` because the encoder pass below runs
        before the parent ``generate`` is entered and would otherwise record
        an autograd graph that is never used.
        """
        input_ids = kwargs.pop("input_ids")
        attention_mask = kwargs.pop("attention_mask")
        kwargs["encoder_outputs"] = self.get_encoder_output(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return super().generate(**kwargs)

    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, *args, **kwargs):
        """Compute ``model_kwargs["encoder_outputs"]`` with the fusing encoder wrapper.

        Mirrors the generation hook of the parent class but calls
        ``self.model.get_encoder_output`` instead of the raw encoder.
        """
        encoder = self.model.get_encoder_output

        # Keep only the kwargs that are relevant for the encoder.
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        encoder_signature = set(inspect.signature(encoder).parameters)

        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            # The wrapper takes no **kwargs, so unknown arguments must be dropped.
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }

        # Bind the main input under its expected name and request a ModelOutput.
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)

        return model_kwargs


class FiD(pt.Transformer):
    """PyTerrier transformer that answers a query from its retrieved passages.

    For every query, the retrieved rows are turned into one
    ``question: ... [title: ...] context: ...`` string per passage, tokenized
    into a ``(1, num_passages, text_max_length)`` batch and passed to the
    reader's ``generate``. One output row ``{"qid", "query", "qanswer"}`` is
    produced per query.

    Args:
        model: Fusion-in-Decoder reader used for generation.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Stored for API compatibility; not used by this class.
        text_field: Name of the input column holding the passage text.
        text_max_length: Every passage input is padded/truncated to this many
            tokens.
        num_context: Maximum number of passages per query, or ``"auto"`` to use
            all of them.
        max_new_tokens: Generation budget, applied when no
            ``generation_config`` is supplied.
        generation_config: Explicit generation configuration; when given it
            takes precedence over ``max_new_tokens``.
        verbose: Stored; not used by this class.
        device: Device for the model and inputs. ``None`` keeps the model on
            the device its parameters are already on.
    """

    def __init__(
        self,
        model: Union[T5FiDReader, BARTFiDReader],
        tokenizer: AutoTokenizer,
        batch_size: int = 4,
        text_field: str = 'text',
        text_max_length: int = 256,
        num_context: Union[int, str] = "auto",
        max_new_tokens: int = 32,
        generation_config: GenerationConfig = None,
        verbose: bool = False,
        device: Union[str, torch.device] = None,
        **kwargs
    ):
        if device is None:
            # No explicit device: leave the model where it is and send the
            # inputs to the same place.
            device = next(model.parameters()).device
        else:
            model = model.to(device)
        self.model = model
        self.model.eval()
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.text_field = text_field
        self.text_max_length = text_max_length
        self.num_context = num_context
        self.max_new_tokens = max_new_tokens
        self.generation_config = generation_config
        self.device = device
        self.query_prefix = "question:"
        self.title_prefix = "title:"
        self.context_prefix = "context:"
        self.verbose = verbose

    def get_context_by_query(self, inp: Iterable[dict]) -> Iterable[Union[str, Tuple[str]]]:
        """Return at most ``self.num_context`` retrieved passages for one query.

        ``inp`` must be an indexable sequence of rows. Rows are ordered by
        descending ``score`` when the first row has one. Each passage is a
        ``(title, text)`` tuple when the first row has a ``title`` key,
        otherwise just the text. ``None`` is returned when there are no rows
        or ``self.num_context`` is falsy.
        """
        if not (self.num_context and inp):
            return None

        limit = len(inp) if self.num_context == "auto" else self.num_context
        if "score" in inp[0]:
            inp = sorted(inp, key=lambda x: x["score"], reverse=True)
        if "title" in inp[0]:
            context = [(item["title"], item[self.text_field]) for item in inp]
        else:
            context = [item[self.text_field] for item in inp]
        return context[:limit]

    def format_input_texts(self, question: str, context: Iterable[Union[str, Tuple[str]]]) -> List[str]:
        """Build one reader input string per passage.

        Without any context the bare question is returned as the only input.
        """
        if not context:
            return [question]

        input_texts = []
        for item in context:
            # Passages with a title get both a title and a context prefix.
            if isinstance(item, tuple):
                title, text = item
                doc_text = f"{self.title_prefix} {title} {self.context_prefix} {text}"
            else:
                doc_text = f"{self.context_prefix} {item}"
            # The question is prepended to every passage.
            input_text = f"{self.query_prefix} {question} {doc_text.strip()}"
            input_texts.append(input_text.strip())

        return input_texts

    def tokenizer_encode(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize the passage inputs of a single query.

        Returns ``input_ids`` and ``attention_mask`` with a leading batch
        dimension of one, i.e. ``(1, len(texts), text_max_length)``.
        """
        tokenizer_outputs = self.tokenizer(
            texts,
            max_length=self.text_max_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt',
        )
        # Add the batch axis: this transformer handles one query at a time.
        input_ids = tokenizer_outputs["input_ids"].unsqueeze(0)
        attention_mask = tokenizer_outputs["attention_mask"].unsqueeze(0)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @pta.transform.by_query(add_ranks=False)
    def transform_iter(self, inp: Iterable[dict]) -> Iterable[dict]:
        """Answer the query of one group of rows (grouping is done by the decorator)."""
        return self.transform_by_query(inp)

    def transform_by_query(self, inp: Iterable[dict]) -> Iterable[dict]:
        """Generate an answer for the rows of a single query.

        Returns a one-element list ``[{"qid", "query", "qanswer"}]``, or an
        empty list when ``inp`` contains no rows.
        """
        inp = list(inp)
        if not inp:
            return []

        qid = inp[0]["qid"]
        query = inp[0]["query"]
        for row in inp:
            assert row["query"] == query, "All rows must have the same query for `transform_by_query`"

        context = self.get_context_by_query(inp)
        input_texts = self.format_input_texts(query, context)
        inputs = self.tokenizer_encode(input_texts)
        inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

        generate_kwargs = {"generation_config": self.generation_config}
        if self.generation_config is None and self.max_new_tokens is not None:
            # An explicit generation_config wins; otherwise honour the
            # max_new_tokens passed to the constructor.
            generate_kwargs["max_new_tokens"] = self.max_new_tokens

        # Inference only: make sure no autograd state is recorded.
        with torch.no_grad():
            generated_token_ids = self.model.generate(**inputs, **generate_kwargs)
        qanswer = self.tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True)[0]

        return [{"qid": qid, "query": query, "qanswer": qanswer}]


class T5FiD(FiD):
    """
    T5 FiD Reader for PyTerrier-RAG

    .. cite.dblp:: conf/eacl/IzacardG21

    Args:
        model_name_or_path: Checkpoint passed to ``T5FiDReader.from_pretrained``.
        tokenizer_name_or_path: Checkpoint passed to
            ``AutoTokenizer.from_pretrained``; defaults to
            ``model_name_or_path``.
        device: Device to run on. When ``None``, the device the loaded model
            is already on is used.

    All remaining arguments are forwarded unchanged to :class:`FiD`.
    """

    def __init__(
        self,
        model_name_or_path: str,
        tokenizer_name_or_path: str = None,
        batch_size: int = 4,
        text_field: str = 'text',
        text_max_length: int = 256,
        num_context: Union[int, str] = "auto",
        max_new_tokens: int = 32,
        generation_config: GenerationConfig = None,
        verbose: bool = False,
        device: Union[str, torch.device] = None,
        **kwargs
    ):
        model = T5FiDReader.from_pretrained(model_name_or_path)
        tokenizer_name_or_path = tokenizer_name_or_path or model_name_or_path
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        super().__init__(
            model,
            tokenizer,
            batch_size,
            text_field,
            text_max_length,
            num_context,
            max_new_tokens,
            generation_config,
            verbose,
            # Honour an explicitly requested device; otherwise keep the model's own.
            device if device is not None else model.device,
            **kwargs
        )
class BARTFiD(FiD):
    """
    BART FiD Reader for PyTerrier-RAG

    .. cite.dblp:: conf/eacl/IzacardG21

    Args:
        model_name_or_path: Checkpoint passed to
            ``BARTFiDReader.from_pretrained``.
        tokenizer_name_or_path: Checkpoint passed to
            ``AutoTokenizer.from_pretrained``; defaults to
            ``model_name_or_path``.
        device: Device to run on. When ``None``, the device the loaded model
            is already on is used.

    All remaining arguments are forwarded unchanged to :class:`FiD`.
    """

    def __init__(
        self,
        model_name_or_path: str,
        tokenizer_name_or_path: str = None,
        batch_size: int = 4,
        text_field: str = 'text',
        text_max_length: int = 256,
        num_context: Union[int, str] = "auto",
        max_new_tokens: int = 32,
        generation_config: GenerationConfig = None,
        verbose: bool = False,
        device: Union[str, torch.device] = None,
        **kwargs
    ):
        model = BARTFiDReader.from_pretrained(model_name_or_path)
        tokenizer_name_or_path = tokenizer_name_or_path or model_name_or_path
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        super().__init__(
            model,
            tokenizer,
            batch_size,
            text_field,
            text_max_length,
            num_context,
            max_new_tokens,
            generation_config,
            verbose,
            # Honour an explicitly requested device; otherwise keep the model's own.
            device if device is not None else model.device,
            **kwargs
        )