Source code for pyterrier.validate

"""Validation utilities for checking the inputs of transformers."""
import warnings
from types import TracebackType
from typing import List, Optional, Type, Union, Set
import pandas as pd
import pyterrier as pt


class _TransformerMode:
    """A candidate input mode of a transformer, and how the input failed to match it.

    Instances are collected by the validation context managers for every check that fails, and are
    attached to :class:`InputValidationError` / :class:`InputValidationWarning` messages.
    """

    def __init__(self, missing_columns: List[str], extra_columns: List[str], mode_name: Optional[str] = None):
        # required columns that the input lacked
        self.missing_columns = missing_columns
        # forbidden columns that the input had
        self.extra_columns = extra_columns
        # optional human-readable name of the mode (may be None)
        self.mode_name = mode_name

    def __str__(self):
        return '{} (missing: {}, extra: {})'.format(self.mode_name, self.missing_columns, self.extra_columns)

    def __repr__(self):
        fields = ', '.join([
            f'missing_columns={self.missing_columns!r}',
            f'extra_columns={self.extra_columns!r}',
            f'mode_name={self.mode_name!r}',
        ])
        return f'TransformerMode({fields})'


class InputValidationError(KeyError):
    """Exception raised when input validation of a single transformer fails.

    This is raised when the input DataFrame does not match the expected columns for a transformer.
    It is typically raised by ``pt.validate.*`` methods.

    Attributes:
        modes: The candidate modes that were checked, each describing its missing and extra columns.
    """
    def __init__(self, message: str, modes: 'List[_TransformerMode]'):
        """Create an InputValidationError.

        Args:
            message: Human-readable description of the failure.
            modes: The candidate modes that failed to match; must be non-empty.
        """
        # internal invariant: callers only raise this after at least one failed check
        assert len(modes) > 0
        super().__init__(message)
        self.modes = modes

    def __str__(self):
        return self.args[0] + ' ' + str(self.modes)

    def __repr__(self):
        return f'InputValidationError({self.args[0]!r}, {self.modes!r})'

    def __reduce__(self):
        # BaseException's default pickling re-creates the exception as ``cls(*self.args)``, but ``args`` only
        # contains ``message`` here, so unpickling (e.g. across a multiprocessing boundary) or copy.copy()
        # would fail with a missing ``modes`` argument. Pass both constructor arguments explicitly, and keep
        # any other instance state (such as exception notes) via ``__dict__``.
        return (type(self), (self.args[0], self.modes), self.__dict__)


class InputValidationWarning(Warning):
    """Warning issued when input validation fails in warn mode.

    Emitted instead of raising :class:`InputValidationError` when the input DataFrame does not match the
    expected columns for a transformer but validation was requested with ``warn=True``. It is typically
    produced by ``pt.validate.*`` methods.
    """


def columns(
    inp: Union[pd.DataFrame, List[str]],
    *,
    includes: Optional[List[str]] = None,
    excludes: Optional[List[str]] = None,
    warn: bool = False
) -> None:
    """Validate that the input has all required columns and none of the forbidden ones.

    Args:
        inp: Input DataFrame, or the list of its column names, to validate
        includes: Columns that must be present
        excludes: Columns that must not be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails

    .. versionchanged:: 0.15.0
        Accept ``List[str]`` inp columns
    """
    validator = any(inp, warn=warn)
    with validator as ctx:
        ctx.columns(includes=includes, excludes=excludes)
def query_frame(inp: Union[pd.DataFrame, List[str]], extra_columns: Optional[List[str]] = None, warn: bool = False) -> None:
    """Validate that the input is a query frame (has ``qid``, lacks ``docno``).

    Args:
        inp: Input DataFrame, or the list of its column names, to validate
        extra_columns: Further columns that must be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails

    .. versionchanged:: 0.15.0
        Accept ``List[str]`` inp columns
    """
    validator = any(inp, warn=warn)
    with validator as ctx:
        ctx.query_frame(extra_columns)
def result_frame(inp: Union[pd.DataFrame, List[str]], extra_columns: Optional[List[str]] = None, warn: bool = False) -> None:
    """Validate that the input is a result frame (has both ``qid`` and ``docno``).

    Args:
        inp: Input DataFrame, or the list of its column names, to validate
        extra_columns: Further columns that must be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails

    .. versionchanged:: 0.15.0
        Accept ``List[str]`` inp columns
    """
    validator = any(inp, warn=warn)
    with validator as ctx:
        ctx.result_frame(extra_columns)
def document_frame(inp: Union[pd.DataFrame, List[str]], extra_columns: Optional[List[str]] = None, warn: bool = False) -> None:
    """Validate that the input is a document frame (has ``docno``, lacks ``qid``).

    Args:
        inp: Input DataFrame, or the list of its column names, to validate
        extra_columns: Further columns that must be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails

    .. versionchanged:: 0.15.0
        Accept ``List[str]`` inp columns
    """
    validator = any(inp, warn=warn)
    with validator as ctx:
        ctx.document_frame(extra_columns)
def columns_iter(
    inp: 'pt.utils.PeekableIter',
    *,
    includes: Optional[List[str]] = None,
    excludes: Optional[List[str]] = None,
    warn: bool = False
) -> None:
    """Check that the records of the input iterator have the expected columns.

    Only the first record of the iterator is inspected (via ``peek``); an empty iterator fails any
    ``includes`` requirement.

    Args:
        inp: Input peekable iterator to validate (see ``pt.utils.peekable``)
        includes: List of required columns
        excludes: List of forbidden columns
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails
        AttributeError: If ``inp`` is not a ``pt.utils.PeekableIter``

    Warns:
        InputValidationWarning: If warn=True and validation fails
    """
    with any_iter(inp, warn=warn) as v:
        v.columns(includes=includes, excludes=excludes)
def query_iter(inp: 'pt.utils.PeekableIter', extra_columns: Optional[List[str]] = None, warn: bool = False) -> None:
    """Validate that the iterator yields query records (with ``qid``, without ``docno``).

    Args:
        inp: Input peekable iterator to validate
        extra_columns: Further columns that must be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails
    """
    validator = any_iter(inp, warn=warn)
    with validator as ctx:
        ctx.query_iter(extra_columns)
def result_iter(inp: 'pt.utils.PeekableIter', extra_columns: Optional[List[str]] = None, warn: bool = False) -> None:
    """Validate that the iterator yields result records (with both ``qid`` and ``docno``).

    Args:
        inp: Input peekable iterator to validate
        extra_columns: Further columns that must be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails
    """
    validator = any_iter(inp, warn=warn)
    with validator as ctx:
        ctx.result_iter(extra_columns)
def document_iter(inp: 'pt.utils.PeekableIter', extra_columns: Optional[List[str]] = None, warn: bool = False) -> None:
    """Validate that the iterator yields document records (with ``docno``, without ``qid``).

    Args:
        inp: Input peekable iterator to validate
        extra_columns: Further columns that must be present
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        InputValidationError: If warn=False and validation fails

    Warns:
        InputValidationWarning: If warn=True and validation fails
    """
    validator = any_iter(inp, warn=warn)
    with validator as ctx:
        ctx.document_iter(extra_columns)
def any(inp: Union[pd.DataFrame, List[str]], warn: bool = False) -> '_ValidationContextManager':
    """Build a context manager that accepts the input if it matches any one of several modes.

    Each check called inside the ``with`` block is one candidate mode. On exit, validation fails only when
    at least one check was made and every one of them failed. For example::

        with pt.validate.any(inp) as v:
            v.query_frame(mode='retrieve')
            v.result_frame(mode='rerank')

    Args:
        inp: Input DataFrame, or the list of its column names, to validate
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure
    """
    manager = _ValidationContextManager(inp, warn=warn)
    return manager
def any_iter(inp: 'pt.utils.PeekableIter', warn: bool = False) -> '_IterValidationContextManager':
    """Build a context manager that accepts an iterator if it matches any one of several modes.

    Each check called inside the ``with`` block is one candidate mode; only the first record is inspected.

    Args:
        inp: Input iterator to validate; must already be wrapped with ``pt.utils.peekable``
        warn: If True, issue an ``InputValidationWarning`` instead of raising on failure

    Raises:
        AttributeError: If ``inp`` is not a ``pt.utils.PeekableIter``
    """
    if isinstance(inp, pt.utils.PeekableIter):
        return _IterValidationContextManager(inp, warn=warn)
    raise AttributeError('inp is not peekable. Run the following before calling this function.\n'
                         'inp = pt.utils.peekable(inp) # !! IMPORTANT: you must re-assign the input to peekable '
                         '(not just pass it in), otherwise you will skip the first record !!')
class _BaseValidationContextManager:
    """Bookkeeping shared by the DataFrame and iterator validation context managers.

    Every check performed inside the ``with`` block counts as one attempt at matching a transformer mode.
    A successful check records the first non-``None`` mode name in ``mode``; a failed check appends a
    ``_TransformerMode`` describing the missing/extra columns to ``errors``. When the ``with`` block exits
    without an exception and at least one check was attempted but every one of them failed, an
    ``InputValidationError`` is raised (or an ``InputValidationWarning`` is issued if ``warn`` is set).
    """

    def __init__(self, warn: bool = False):
        self.mode: Optional[str] = None  # first successful named mode, if any
        self.attempts = 0  # number of checks performed so far
        self.errors: List[_TransformerMode] = []  # one entry per failed check
        self.warn = warn  # warn instead of raising when every check fails

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType]
    ) -> Optional[bool]:
        if exc_type is not None:
            return False  # the captured exception takes priority
        if self.attempts > 0 and self.attempts == len(self.errors):
            message = self._failure_message()
            if self.warn:
                warnings.warn(f"{message} {self.errors}", InputValidationWarning)
            else:
                raise InputValidationError(message, self.errors)
        return None

    def _failure_message(self) -> str:
        """Return the message used when every attempted check failed."""
        raise NotImplementedError

    def _record_failure(self, missing_columns: List[str], extra_columns: List[str], mode: Optional[str]) -> bool:
        """Record a failed check for ``mode`` and return ``False``."""
        self.errors.append(_TransformerMode(
            missing_columns=missing_columns,
            extra_columns=extra_columns,
            mode_name=mode,
        ))
        return False

    def _record_success(self, mode: Optional[str]) -> bool:
        """Remember ``mode`` if it is the first named mode to succeed, and return ``True``."""
        if self.mode is None and mode is not None:
            self.mode = mode
        return True

    def _check_columns(
        self,
        present: Set[str],
        includes: Optional[List[str]],
        excludes: Optional[List[str]],
        mode: Optional[str]
    ) -> bool:
        """Compare the ``present`` column names against ``includes``/``excludes`` and record the outcome.

        Missing and extra columns are reported in the order they appear in ``includes``/``excludes``.
        """
        includes = includes if includes is not None else []
        excludes = excludes if excludes is not None else []
        missing = [c for c in includes if c not in present]
        extra = [c for c in excludes if c in present]
        if missing or extra:
            return self._record_failure(missing, extra, mode)
        return self._record_success(mode)


class _ValidationContextManager(_BaseValidationContextManager):
    """Context manager for validating the input to transformers."""

    def __init__(self, inp: Union[pd.DataFrame, List[str]], warn: bool = False):
        """Create a ValidationContextManager for the given DataFrame or list of column names."""
        super().__init__(warn=warn)
        if isinstance(inp, pd.DataFrame):
            self.inp_columns = list(inp.columns)
        else:
            # Copy into a list: this accepts any iterable of column names (tuple, pd.Index, generator, ...)
            # and ensures a one-shot iterator is not exhausted by the first check.
            self.inp_columns = list(inp)

    def _failure_message(self) -> str:
        return "DataFrame(columns=%s) does not match required columns for this transformer." % str(self.inp_columns)

    def columns(self, *, includes: Optional[List[str]] = None, excludes: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input frame has the ``includes`` columns and doesn't have the ``excludes`` columns."""
        self.attempts += 1
        return self._check_columns(set(self.inp_columns), includes, excludes, mode)

    def query_frame(self, extra_columns: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input frame is a valid query frame, with optional extra columns."""
        extra_columns = list(extra_columns) if extra_columns is not None else []
        return self.columns(includes=['qid'] + extra_columns, excludes=['docno'], mode=mode)

    def result_frame(self, extra_columns: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input frame is a valid result frame, with optional extra columns."""
        extra_columns = list(extra_columns) if extra_columns is not None else []
        return self.columns(includes=['qid', 'docno'] + extra_columns, mode=mode)

    def document_frame(self, extra_columns: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input frame is a valid document frame, with optional extra columns."""
        extra_columns = list(extra_columns) if extra_columns is not None else []
        return self.columns(includes=['docno'] + extra_columns, excludes=['qid'], mode=mode)


# Sentinel stored in ``sample_cols`` when the iterator had no record to peek at.
_EMPTY_ITER = object()


class _IterValidationContextManager(_BaseValidationContextManager):
    """Context manager for validating the records of a peekable iterator.

    Only the keys of the first record (obtained with ``peek``) are inspected.
    """

    def __init__(self, inp: 'pt.utils.PeekableIter', warn: bool = False):
        """Create an IterValidationContextManager, sampling the columns of the first record of ``inp``."""
        super().__init__(warn=warn)
        self.sample_cols: Union[Set[str], object]
        try:
            self.sample_cols = set(inp.peek().keys())
        except StopIteration:
            self.sample_cols = _EMPTY_ITER

    def _failure_message(self) -> str:
        return "Input does not match required columns for this transformer."

    def columns(self, *, includes: Optional[List[str]] = None, excludes: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the first record has the ``includes`` columns and doesn't have the ``excludes`` columns.

        For an empty iterator the check always fails, reporting every ``includes`` column as missing.
        """
        self.attempts += 1
        if not isinstance(self.sample_cols, set):
            # isinstance (rather than an identity check) is used so mypy narrows sample_cols below
            assert self.sample_cols is _EMPTY_ITER
            return self._record_failure(list(includes) if includes is not None else [], [], mode)
        return self._check_columns(self.sample_cols, includes, excludes, mode)

    def query_iter(self, extra_columns: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input iterator is a valid query iterator, with optional extra columns."""
        extra_columns = list(extra_columns) if extra_columns is not None else []
        return self.columns(includes=['qid'] + extra_columns, excludes=['docno'], mode=mode)

    def result_iter(self, extra_columns: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input iterator is a valid result iterator, with optional extra columns."""
        extra_columns = list(extra_columns) if extra_columns is not None else []
        return self.columns(includes=['qid', 'docno'] + extra_columns, mode=mode)

    def document_iter(self, extra_columns: Optional[List[str]] = None, mode: Optional[str] = None) -> bool:
        """Check that the input iterator is a valid document iterator, with optional extra columns."""
        extra_columns = list(extra_columns) if extra_columns is not None else []
        return self.columns(includes=['docno'] + extra_columns, excludes=['qid'], mode=mode)

    def empty(self, *, mode: Optional[str] = 'empty') -> bool:
        """Check that the input iterator has no records."""
        self.attempts += 1
        if self.sample_cols is not _EMPTY_ITER:
            return self._record_failure([], [], mode)
        return self._record_success(mode)