# Source code for pyterrier_rag._measures

import string
from collections import Counter
from typing import List, Literal

import regex
import ir_measures
import pyterrier_alpha as pta


# Normalization from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
def normalize_answer(s : str) -> str:
    """Normalize an answer string for lenient comparison.

    The steps are applied in this order: lowercase, delete every character in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, and finally collapse all runs of whitespace into single
    spaces (which also trims both ends).
    """
    lowered = s.lower()
    # Punctuation is removed *before* article stripping, so e.g. "a.n"
    # first becomes "an" and is then dropped as an article.
    punctuation_free = lowered.translate(str.maketrans('', '', string.punctuation))
    article_free = regex.sub(r'\b(a|an|the)\b', ' ', punctuation_free)
    return ' '.join(article_free.split())

def exact_match_score(prediction : str, ground_truth : str) -> float:
    """Return 1.0 if both strings are equal after ``normalize_answer``, else 0.0."""
    is_match = normalize_answer(prediction) == normalize_answer(ground_truth)
    return 1.0 if is_match else 0.0

def ems(prediction : str, ground_truths : List[str]) -> float:
    """Return the best exact-match score of ``prediction`` over ``ground_truths``.

    Raises ``ValueError`` (from ``max``) when ``ground_truths`` is empty.
    """
    return max(exact_match_score(prediction, truth) for truth in ground_truths)

# F1 Evaluation from HotPotQA evaluation script: https://raw.githubusercontent.com/hotpotqa/hotpot/master/hotpot_evaluate_v1.py
def f1_score(prediction : str, ground_truth : str) -> float:
    """Token-level F1 between a prediction and a single ground-truth answer.

    Both strings are passed through ``normalize_answer`` and split on
    whitespace; precision and recall are computed from the multiset overlap
    of the resulting tokens.

    Arguments:
     - prediction(str): The generated answer.
     - ground_truth(str): One reference answer (a single string, not a list).

    Returns:
        The F1 value as a float. 0.0 is returned when there is no token
        overlap, or when either side normalizes to 'yes', 'no' or 'noanswer'
        and the two normalized strings are not identical.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # These answers only count when they match exactly; partial token
    # overlap with them is scored as zero.
    special_answers = ('yes', 'no', 'noanswer')
    if normalized_prediction != normalized_ground_truth and (
            normalized_prediction in special_answers
            or normalized_ground_truth in special_answers):
        return 0.0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    # Counter intersection keeps the minimum count of each shared token.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0

    # num_same > 0 guarantees both token lists are non-empty here.
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)

# Per-query measure: len() of the 'qanswer' value in the first result row.
def _answer_len(qrels, res):
    return len(res.iloc[0]['qanswer'])

AnswerLen = ir_measures.define_byquery(
    _answer_len, name='AnswerLen', support_cutoff=False)

# Per-query measure: 1 when the first result row's 'qanswer' has length 0,
# otherwise 0.
def _answer_is_empty(qrels, res):
    return int(len(res.iloc[0]['qanswer']) == 0)

AnswerZeroLen = ir_measures.define_byquery(
    _answer_is_empty, name='AnswerZeroLen', support_cutoff=False)

# The F1 of a query is the best (max) f1_score between the first result row's
# 'qanswer' and each value obtained by iterating qrels['gold_answer'].
# NOTE(review): this iterates over the rows of the gold_answer column, whereas
# _rouge indexes the first row's value as an array of answers - confirm which
# layout the qrels are expected to have.
def _best_f1(qrels, res):
    return max(
        f1_score(res.iloc[0]['qanswer'], gold)
        for gold in qrels['gold_answer']
    )

F1 = ir_measures.define_byquery(_best_f1, support_cutoff=False, name="F1")
# Exact match of the first result row's 'qanswer' against the gold answers;
# taking the max over the gold answers happens inside ems().
# NOTE(review): qrels['gold_answer'] is passed as the list of ground truths,
# i.e. one candidate per qrels row - confirm this matches the qrels layout.
def _exact_match(qrels, res):
    return ems(res.iloc[0]['qanswer'], qrels['gold_answer'])

EM = ir_measures.define_byquery(_exact_match, support_cutoff=False, name="EM")

def _rouge(measure : Literal['precision', 'recall', 'fmeaure'], type : Literal['rouge1', 'rouge2', 'rougeL'] ='rouge1', use_stemmer : bool = True):
    from rouge_score import rouge_scorer
    scorer = rouge_scorer.RougeScorer([type], use_stemmer=use_stemmer)

    name = "ROUGE"
    match type:
        case "rouge1":
            name += '1'
        case "rouge2":
            name += '2'
        case "rougeL":
            name += 'L'
        case _:
            raise ValueError("Unkown type %s" % type)

    match measure:
        case "precision":
            name += "P"
        case "recall":
            name += "R"
        case "fmeasure":
            name += "F"
        case _:
            raise ValueError("Unkown measure %s" % type)

    def _measure(qrels, res):
        assert len(qrels) == 1
        assert len(res) == 1        
        res = scorer.score(
            res.iloc[0]['qanswer'], 
            qrels['gold_answer'].iloc[0][0] # [0] for the first answer in the array of gold_answer
        )
        return getattr(res[type], measure)
    
    return ir_measures.define_byquery(_measure, name=name, support_cutoff=False)

# Ready-made ROUGE measures: one per (variant, reported value) combination.
ROUGE1P = _rouge(measure='precision', type='rouge1')
ROUGE1R = _rouge(measure='recall', type='rouge1')
ROUGE1F = _rouge(measure='fmeasure', type='rouge1')

ROUGE2P = _rouge(measure='precision', type='rouge2')
ROUGE2R = _rouge(measure='recall', type='rouge2')
ROUGE2F = _rouge(measure='fmeasure', type='rouge2')

ROUGELP = _rouge(measure='precision', type='rougeL')
ROUGELR = _rouge(measure='recall', type='rougeL')
ROUGELF = _rouge(measure='fmeasure', type='rougeL')

# Module-level cache for the loaded "bertscore" metric (filled on first use).
_bertscore_model = None
def _bertscore(qrels, res, rel = 3, submeasure='f1', agg='max'):
    """Compute BERTScore between one generated answer and the relevant qrels texts.

    Arguments:
     - qrels: dataframe with at least 'query_id', 'relevance' and 'text' columns.
     - res: dataframe with at least 'query_id' and 'qanswer' columns.
     - rel: minimum 'relevance' value for a qrels row to be used as a reference.
     - submeasure: key selecting 'precision', 'recall' or 'f1'.
     - agg: key selecting 'avg' or 'max' across the references.

    Returns:
        The selected aggregate of the selected submeasure.
    """
    global _bertscore_model

    pta.validate.columns(qrels, includes=['query_id', 'relevance', 'text'])
    pta.validate.columns(res, includes=['query_id', 'qanswer'])

    assert len(res), "Empty res df provided"
    assert len(qrels), "Empty qrels df provided"
    relevant = qrels[qrels.relevance >= rel]
    assert len(relevant), "No qrels found with minimum label of %d" % rel

    # Load the metric once and reuse it for every subsequent call.
    if _bertscore_model is None:
        from evaluate import load # this is a huggingface package
        _bertscore_model = load("bertscore")

    answers = res['qanswer'].to_list()
    assert len(answers) == 1, "Unexpected number of predictions"
    references = relevant['text'].to_list()
    # Repeat the answer so that it is paired with every reference text.
    answers = answers * len(references)

    scores = _bertscore_model.compute(predictions=answers, references=references, lang="en", model_type="bert-large-uncased", verbose=False)
    summary = {
        key: {'avg': sum(scores[key]) / len(scores[key]), 'max': max(scores[key])}
        for key in ('precision', 'recall', 'f1')
    }
    return summary[submeasure][agg]

def BERTScore(rel=3, submeasure : str = 'f1', agg : str = 'max'):
    '''
    Implements BERTScore, a semantic measure of equivalence. This is defined to take a qrels dataframe with an
    additional text attribute, and compare with the generated qanswers.

    NB: This is a function that returns a measure - it needs to be called.

    Arguments:
     - rel(int): Minimum label value for relevant qrels. Defaults to 3, which is the highest label in MSMARCO.
     - submeasure(str): One of 'precision', 'recall' and 'f1'. Defaults to 'f1'.
     - agg(str): How to combine (aggregate) when there are multiple relevant documents. Valid options are 'max' or 'avg'. Defaults to 'max'.

    Returns:
        An IR measures measure object that can be used in pt.Evaluate or pt.Experiment

    Raises:
        ValueError: if submeasure or agg is not one of the valid options listed above.
    '''
    # Fail fast on bad options instead of surfacing a KeyError per query later.
    if submeasure not in ('precision', 'recall', 'f1'):
        raise ValueError("Unknown submeasure %s" % submeasure)
    if agg not in ('max', 'avg'):
        raise ValueError("Unknown agg %s" % agg)

    # submeasure must be forwarded explicitly; otherwise _bertscore falls back
    # to its own default and always reports f1.
    return ir_measures.define_byquery(
        lambda qrels, res: _bertscore(qrels, res, rel=rel, submeasure=submeasure, agg=agg),
        name='BERTScore', support_cutoff=False)