Source code for pyterrier.schematic

from copy import copy
import html
import uuid
from importlib import resources
from typing import Any, Dict, List, Optional, Protocol, Union, cast, runtime_checkable

import numpy as np
import pyterrier as pt
import re
from pyterrier._ops import Compose


def radix_tree_schematic(tree, input_columns=None):
    """Converts a radix (shared-prefix) tree of transformers into a ``tree`` schematic dict.

    Every edge of ``tree`` is labelled either with a single transformer or with a tuple/list of
    transformers (composed into one pipeline via ``Compose`` when there is more than one). Every
    node exposes ``children`` (a mapping of edge label -> child node), ``value`` (a non-``None``
    value marks the node as terminal) and ``node_id``.

    Args:
        tree: The radix tree; its ``root.children`` mapping provides the top-level edges.
        input_columns: Columns fed into each root edge's transformer, or ``None`` if unknown.

    Returns:
        A dict with keys ``type`` (always ``"tree"``), ``input_columns``, ``nodes`` (one node dict
        per root edge) and ``mode`` (``"branch"`` when the root has more than one child,
        otherwise ``"linear"``).
    """
    def output_leaf(parent_self):
        # "output" is not a real element type. Tagging the synthetic leaf this way lets the
        # renderer draw the evaluation point of a terminal node that also has children.
        columns = parent_self.get('output_columns')
        return {
            "type": "node",
            "children": [],
            "self": {
                "input_columns": columns,
                "output_columns": columns,
                "is_last": True,
                "type": "output",
            },
        }

    def build(edge_label, node, _input_columns=None):
        # Resolve the edge label to one transformer, composing multi-transformer labels.
        if not isinstance(edge_label, (tuple, list)):
            transformer = edge_label
        elif len(edge_label) == 1:
            transformer = edge_label[0]
        else:
            transformer = Compose(*list(edge_label))

        if transformer is None:
            self_schem = {}
        else:
            self_schem = pt.schematic.transformer_schematic(transformer, input_columns=_input_columns)

        # Each child is fed this node's output columns.
        child_inputs = self_schem.get('output_columns')
        children = [build(label, child, _input_columns=child_inputs) for label, child in node.children.items()]

        is_terminal = node.value is not None
        has_children = bool(node.children)
        self_schem['is_terminal'] = is_terminal
        if self_schem['type'] != 'pipeline':
            self_schem['node_id'] = node.node_id
            self_schem['is_last'] = is_terminal and not has_children
        else:
            # Pipelines get one id per stage; only the final stage of a childless node is "last".
            stages = self_schem.get('transformers', [])
            final_idx = len(stages) - 1
            for idx, stage in enumerate(stages):
                stage['node_id'] = f"{node.node_id}:{idx}"
                stage['is_last'] = idx == final_idx and not has_children

        if is_terminal and children:
            children = [output_leaf(self_schem), *children]

        result = {"type": "node", "children": children, "self": self_schem}
        if len(children) > 1:
            # Only nodes with several children are drawn as a fork.
            result["mode"] = "branch"
        return result

    root_nodes = [build(label, child, _input_columns=input_columns) for label, child in tree.root.children.items()]
    return {
        "type": "tree",
        "input_columns": input_columns,
        "nodes": root_nodes,
        "mode": "branch" if len(root_nodes) > 1 else "linear",
    }

def _apply_default_schematic(schematic: Dict[str, Any], transformer: pt.Transformer, *, input_columns: Optional[List[str]] = None):
    """Fills in, in place, every field of ``schematic`` that is not already set.

    Fields already present (e.g. supplied by a transformer's own ``schematic``) are never
    overwritten.

    Args:
        schematic: The (possibly partial) schematic dict to complete. Modified in place.
        transformer: The transformer the schematic describes.
        input_columns: Columns of the input frame, or ``None`` if unknown. Used as the default
            ``input_columns`` field and to compute ``output_columns``.
    """
    if 'type' not in schematic:
        # Evaluated lazily: only inspect the transformer when the type is actually needed.
        is_indexer = pt.inspect.transformer_type(transformer) == pt.inspect.TransformerType.indexer
        schematic['type'] = 'indexer' if is_indexer else 'transformer'
    assert schematic['type'] in ('transformer', 'indexer')

    if 'label' not in schematic:
        label = transformer.__class__.__name__
        if label.endswith('Transformer'):
            label = label[:-len('Transformer')]
        schematic['label'] = label

    # Fully-qualified class name shown in the infobox; the ``pyterrier.`` prefix is shortened to ``pt.``.
    # The guard previously checked an unrelated 'class_name' key, so a provided 'name' was always overwritten.
    if 'name' not in schematic:
        cls = transformer.__class__
        name = f'{cls.__module__}.{cls.__name__}'
        if name.startswith('pyterrier.'):
            name = 'pt.' + name[len('pyterrier.'):]
        schematic['name'] = name

    if 'input_columns' not in schematic:
        schematic['input_columns'] = input_columns
    if 'output_columns' not in schematic:
        if input_columns is None:
            schematic['output_columns'] = None
        else:
            try:
                schematic['output_columns'] = pt.inspect.transformer_outputs(transformer, input_columns)
            except pt.validate.InputValidationError as e:
                # Keep the error so the renderer can report missing/unexpected columns.
                schematic['output_columns'] = None
                schematic['input_validation_error'] = e
            except pt.inspect.InspectError:
                schematic['output_columns'] = None

    # Only settings generated here may be pruned below; caller-supplied settings are left alone.
    default_settings_applied = 'settings' not in schematic
    if default_settings_applied:
        try:
            schematic['settings'] = {attr.name: attr.value for attr in pt.inspect.transformer_attributes(transformer)}
        except pt.inspect.InspectError:
            schematic['settings'] = {}

    if 'help_url' not in schematic:
        schematic['help_url'] = pt.documentation.url_for_class(transformer)

    if 'inner_pipelines' not in schematic:
        try:
            subtransformers = pt.inspect.subtransformers(transformer)
        except pt.inspect.InspectError:
            subtransformers = {}
        if subtransformers:
            subtransformer_inputs = schematic['input_columns'] or _INFER
            # In 'unlinked' mode (the default) the inner pipelines' inputs are inferred
            # independently instead of being taken from this transformer's input columns.
            if schematic.get('inner_pipelines_mode', 'unlinked') == 'unlinked':
                subtransformer_inputs = _INFER
            pipelines = []
            pipeline_labels = []
            for key, value in subtransformers.items():
                # Sub-transformers are drawn as inner pipelines, so don't also list them as settings.
                if default_settings_applied and key in schematic['settings']:
                    del schematic['settings'][key]
                if isinstance(value, list):
                    for i, v in enumerate(value):
                        pipelines.append(transformer_schematic(v, input_columns=subtransformer_inputs))
                        pipeline_labels.append(f'{key}[{i}]')
                else:
                    pipelines.append(transformer_schematic(value, input_columns=subtransformer_inputs))
                    pipeline_labels.append(key)
            schematic['inner_pipelines'] = pipelines
            schematic['inner_pipelines_labels'] = pipeline_labels

    if schematic.get('inner_pipelines') and 'inner_pipelines_mode' not in schematic:
        schematic['inner_pipelines_mode'] = 'unlinked'


# Sentinel meaning "infer the input columns from the transformer itself".
_INFER = object()
def transformer_schematic(
    transformer: pt.Transformer,
    *,
    input_columns: Optional[Union[List[str],object]] = _INFER,
    default: bool = False,
) -> dict:
    """Builds a structured schematic of the transformer.

    Args:
        transformer: The transformer (or indexer) to describe.
        input_columns: Columns of the frame fed to ``transformer``. When left as the default
            sentinel, they are inferred via ``pt.inspect`` (the first supported input
            configuration is used), falling back to ``None`` when nothing is found.
        default: When ``True``, ignore any schematic the transformer provides itself and build
            the default one.

    Returns:
        The schematic dict. A transformer-provided schematic that already declares a ``type`` is
        returned as a shallow copy without further changes; otherwise missing fields are filled
        in by ``_apply_default_schematic``.
    """
    is_indexer = pt.inspect.transformer_type(transformer) == pt.inspect.TransformerType.indexer
    if input_columns is _INFER:
        if is_indexer:
            candidates = pt.inspect.indexer_inputs(cast(pt.Indexer, transformer), strict=False) # noqa: PT100
        else:
            candidates = pt.inspect.transformer_inputs(transformer, strict=False)
        input_columns = candidates[0] if candidates else None
    # From here on input_columns is a list of column names or None, never the sentinel.
    input_columns = cast(Optional[List[str]], input_columns) # noqa: PT100 (this is typing.cast, not jinus.cast)

    if default or not isinstance(transformer, HasSchematic):
        schematic = {}
    else:
        provided = transformer.schematic
        if callable(provided):
            provided = provided(input_columns=input_columns)
        # Shallow copy so the transformer's own dict is never modified below.
        schematic = copy(provided)
        if 'type' in schematic:
            return schematic
    _apply_default_schematic(schematic, transformer, input_columns=input_columns)
    return schematic

# Tools for converting the schematic diagrams to html
# Lazily-loaded contents of the bundled schematic stylesheet and script (see _get_schematic_css_js).
_css = None
_js = None
def _get_schematic_css_js(container_id):
    """Returns the schematic CSS and JS with their ``#ID`` placeholder bound to one container.

    ``data/schematic.css`` and ``data/schematic.js`` are read from the ``pyterrier`` package on
    first use and cached in module globals. Each resource is loaded independently, so a missing
    script never forces the stylesheet to be re-read.

    Args:
        container_id: DOM id of the container element (without the leading ``#``).

    Returns:
        A ``(css, js)`` tuple where every ``#ID`` has been replaced by ``#<container_id>``.
    """
    global _css, _js
    if _css is None:
        # Explicit encoding: the default would depend on the user's locale.
        _css = (resources.files('pyterrier') / 'data/schematic.css').read_text(encoding='utf-8')
    if _js is None:
        _js = (resources.files('pyterrier') / 'data/schematic.js').read_text(encoding='utf-8')
    scoped = f'#{container_id}'
    return _css.replace('#ID', scoped), _js.replace('#ID', scoped)


def draw(transformer: Union[pt.Transformer, dict], *, outer_class: Optional[str] = None, input_columns: Optional[List[str]] = None) -> str:
    """Draws a transformer as an HTML schematic.

    If the transformer is already a ``SchematicDict``, it will be drawn directly. Otherwise, it will
    first be converted to a structured schematic using :func:`transformer_schematic`, and that is drawn.

    Args:
        transformer: The transformer to draw, or a dict in ``SchematicDict`` format.
        outer_class: An optional CSS class to apply to the outer container of the schematic.
        input_columns: The input columns of the transformer (pipeline), if you want to specify them.
            ``None`` (the default) means they are inferred; an empty list is kept as given.

    Returns:
        An HTML string representing the schematic of the transformer.
    """
    if isinstance(transformer, dict):
        assert input_columns is None, "Cannot set input_columns and provide a SchematicDict input."
        schematic = transformer
    else:
        # Test against None rather than truthiness so an explicit [] is not turned into "infer".
        schematic = transformer_schematic(transformer, input_columns=_INFER if input_columns is None else input_columns)
    return draw_html_schematic(schematic, outer_class=outer_class)
def draw_html_schematic(schematic: dict, *, outer_class: Optional[str] = None) -> str:
    """Draws a structured schematic as an HTML representation.

    ``tree`` schematics are drawn with :func:`draw_radix_html_schematic` inside a scrollable
    wrapper; all other schematics are drawn with ``_draw_html_schematic``.

    Args:
        schematic: The structured schematic to draw.
        outer_class: Optional CSS class for the outer container (HTML-escaped).

    Returns:
        An HTML snippet containing the schematic plus its CSS and JS, obtained for a freshly
        generated container id.
    """
    uid = str(uuid.uuid4())
    css, js = _get_schematic_css_js(f'id-{uid}')
    if schematic.get('type') == 'tree':
        # Use the custom tree/radix renderer for tree schematics
        inner_html = f'<div class="pts-tree-scroll">{draw_radix_html_schematic(schematic, outer_class="outer")}</div>'
    else:
        inner_html = _draw_html_schematic(schematic)
    container_class = html.escape(outer_class or '')
    return f'''
<div id="id-{uid}" class="{container_class}" style="display: none;">
    <style>{css}</style>
    <div class="pts-infobox">
        <div class="pts-infobox-title"></div>
        <div class="pts-infobox-body"></div>
        <div class="pts-infobox-hint"><svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="currentColor" style="width: 12px; height: 12px; vertical-align: -2px;"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M4 11a1 1 0 0 1 .117 1.993l-.117 .007h-1a1 1 0 0 1 -.117 -1.993l.117 -.007h1z" /><path d="M12 2a1 1 0 0 1 .993 .883l.007 .117v1a1 1 0 0 1 -1.993 .117l-.007 -.117v-1a1 1 0 0 1 1 -1z" /><path d="M21 11a1 1 0 0 1 .117 1.993l-.117 .007h-1a1 1 0 0 1 -.117 -1.993l.117 -.007h1z" /><path d="M4.893 4.893a1 1 0 0 1 1.32 -.083l.094 .083l.7 .7a1 1 0 0 1 -1.32 1.497l-.094 -.083l-.7 -.7a1 1 0 0 1 0 -1.414z" /><path d="M17.693 4.893a1 1 0 0 1 1.497 1.32l-.083 .094l-.7 .7a1 1 0 0 1 -1.497 -1.32l.083 -.094l.7 -.7z" /><path d="M14 18a1 1 0 0 1 1 1a3 3 0 0 1 -6 0a1 1 0 0 1 .883 -.993l.117 -.007h4z" /><path d="M12 6a6 6 0 0 1 3.6 10.8a1 1 0 0 1 -.471 .192l-.129 .008h-6a1 1 0 0 1 -.6 -.2a6 6 0 0 1 3.6 -10.8z" /></svg> Click to explore!</div>
    </div>
    {inner_html}
    <script>{js}</script>
</div>
<div id="id-{uid}-pts-rendering-issue">
    Rendering issue. Try running the cell again.
</div>
'''


def _input_error_html(missing_columns, extra_columns) -> str:
    """HTML sentences listing the missing and unexpected input columns of one input mode."""
    text = ''
    if len(missing_columns) > 0:
        text += f'Missing input columns: {", ".join("<b>" + html.escape(c) + "</b>" for c in missing_columns)}. '
    if len(extra_columns) > 0:
        text += f'Unexpected input columns: {", ".join("<b>" + html.escape(c) + "</b>" for c in extra_columns)}. '
    return text


def render_transformer_infobox(record: Dict[str, Any]) :
    """Builds the hover infobox for one transformer record.

    Args:
        record: A transformer schematic. Reads the optional ``settings``, ``name``, ``help_url``
            and ``input_validation_error`` fields.

    Returns:
        A ``(infobox_html, infobox_attr, error_cls)`` tuple. All three are empty strings when the
        record has no settings, name or validation error. ``error_cls`` is
        ``'pts-input-validation-error'`` when the record carries a validation error.
    """
    uid = str(uuid.uuid4())
    infobox = ''
    infobox_attr = ''
    error_cls = ''
    settings = record.get('settings')  # may be absent even when 'name' is present
    validation_error = record.get('input_validation_error')
    if settings or record.get('name') or validation_error:
        help_url = record.get('help_url') or ''
        name = record.get('name') or ''
        attrs = ''
        error_info = ''
        if validation_error:
            modes = validation_error.modes
            error_cls = 'pts-input-validation-error'
            if len(modes) == 1:
                # Normal case: there's just one mode
                error_info = '<div class="pts-infobox-error">'
                error_info += _input_error_html(modes[0].missing_columns, modes[0].extra_columns)
                error_info += '</div>'
            else:
                error_info = '<div class="pts-infobox-error"><div>None of the supported input modes matched:</div><ul>'
                for i, error_mode in enumerate(modes):
                    error_info += f'<li>Mode {html.escape(error_mode.mode_name or str(i+1))}: '
                    error_info += _input_error_html(error_mode.missing_columns, error_mode.extra_columns)
                    error_info += '</li>'
                error_info += '</ul></div>'
        if settings:
            attr_rows = [f'<tr><th>{html.escape(key)}</th><td>{html.escape(str(value))}</td></tr>' for key, value in settings.items()]
            attrs = f'<table class="pts-df-columns">{"".join(attr_rows)}</table>'
        link_open = f'<a href="{html.escape(help_url)}" target="_blank" onclick="window.event.stopPropagation();">' if help_url else ''
        link_close = '</a>' if help_url else ''
        infobox = f'''
<div class="pts-infobox-item" id="id-{uid}" data-title="Transformer">
    <div style="font-family: monospace; padding: 3px 6px;">
        {link_open}
        {html.escape(name)}
        {link_close}
    </div>
    {error_info}
    {attrs}
</div>
'''
        infobox_attr = f'data-pts-infobox="id-{uid}"'
    return infobox, infobox_attr, error_cls


def draw_radix_html_schematic(radix_schematic, outer_class='outer') -> str:
    """Draws a radix-tree schematic (as built by ``radix_tree_schematic``) as HTML.

    Recurses over ``tree``, ``pipeline``, ``output``, ``node``, ``transformer`` and ``indexer``
    records. ``outer_class`` selects the framing: ``'outer'`` adds the Input label and arrow to a
    top-level tree. ``'inner'`` wraps non-tree records in a ``pts-parallel-item`` div whose closing
    tag is appended by the caller. ``'inner-pipeline'`` draws records inline.

    Raises:
        ValueError: if a record has an unknown ``type``.
    """
    def render_node(record, is_last):
        # One transformer box, followed by its output-frame arrow when output columns are known.
        node_id = record.get('node_id')
        dom_id = f"pts-node-{node_id}" if node_id is not None else ''
        infobox, infobox_attr, error_cls = render_transformer_infobox(record)
        html_block = f'''
<div class="pts-transformer pts-pending {error_cls}" id="{dom_id}" data-node-id="{node_id}" {infobox_attr}>
    {infobox}
    <div class="pts-transformer-title">{html.escape(record["label"])}</div>
</div>
'''
        output_columns = record.get("output_columns")
        input_columns = record.get("input_columns")
        if output_columns is not None:
            frame = _draw_df_html(output_columns, input_columns)
            if is_last:
                html_block += f'<div class="pts-hline pts-arr pts-arr-output">{frame}</div>'
                html_block += '<div class="pts-io-label">Evaluate</div>'
            elif outer_class == 'inner-pipeline':
                html_block += f'<div class="pts-hline pts-arr-inner">{frame}</div>'
            else:
                html_block += f'<div class="pts-hline pts-arr pts-arr-inner">{frame}</div>'
        return html_block

    def render_branch_node():
        # Opens the parallel scaffold; the tree renderer closes both divs after the branches.
        return '''<div class="pts-parallel-scaffold pts-inner">
    <div class="pts-hline"></div>
    <div class="pts-inner-schematic pts-inner-linked">
'''

    def as_linear_tree(record):
        # Re-express a pipeline record's transformers as a linear sub-tree of leaf nodes.
        return {
            'type': 'tree',
            'input_columns': record.get('input_columns', []),
            'nodes': [{'self': t, 'type': 'node'} for t in record.get('transformers', [])],
        }

    def children_tree(node, child_mode):
        # Sub-tree made of a node's children, drawn after the node itself.
        return {'nodes': node['children'], 'type': 'tree', 'mode': child_mode}

    result = ''
    mode = radix_schematic.get('mode', '')
    kind = radix_schematic['type']
    if kind == 'tree':
        result = '<div class="pts-pipeline">'
        if outer_class == 'outer':
            clz = 'pts-arr' if mode == 'linear' else ''
            result += '<div class="pts-io-label">Input</div>'
            result += f'<div class="pts-hline {clz} pts-arr-input">{_draw_df_html(radix_schematic["input_columns"])}</div>'
        if mode == 'branch':
            # Branching: render vertical lines and each branch as a parallel scaffold
            result += render_branch_node()
            for node in radix_schematic['nodes']:
                record = node['self']
                if node.get('mode', '') == 'branch':
                    result += draw_radix_html_schematic(record, outer_class='inner')
                    result += draw_radix_html_schematic(children_tree(node, 'branch'), outer_class='inner')
                elif record['type'] == 'pipeline':
                    pipe_tree = {
                        'type': 'pipeline',
                        'evaluation_index': record.get('evaluation_index', []),
                        'nodes': [{'self': t, 'type': 'node'} for t in record.get('transformers', [])],
                        'mode': 'linear',
                    }
                    result += draw_radix_html_schematic(pipe_tree, outer_class='inner-pipeline')
                else:
                    result += draw_radix_html_schematic(record, outer_class='inner')
                # Close the pts-parallel-item opened when the record above was rendered.
                result += '</div>'
            result += '</div></div>'
        else:
            # Linear or single node: render as before (no short hline after vline)
            for node in radix_schematic['nodes']:
                record = node['self']
                has_children = node.get('children', []) != []
                if node.get('mode', '') == 'branch':
                    if record['type'] == 'pipeline':
                        # outer_class can be anything except 'outer'
                        result += draw_radix_html_schematic(as_linear_tree(record), outer_class='inner')
                    else:
                        # Transformer node with branches - mark as last node before branching
                        result += render_node(record, is_last=False)
                    result += draw_radix_html_schematic(children_tree(node, 'branch'), outer_class='inner-pipeline')
                    result += '</div>'
                elif record['type'] == 'pipeline':
                    # Render the pipeline transformers (works with or without children)
                    result += draw_radix_html_schematic(as_linear_tree(record), outer_class='inner')
                    if has_children:
                        result += draw_radix_html_schematic(children_tree(node, 'linear'), outer_class='inner')
                    result += '</div>'
                elif has_children:
                    result += render_node(record, is_last=record['is_last'])
                    result += draw_radix_html_schematic(children_tree(node, 'linear'), outer_class='inner-pipeline')
                    result += '</div>'
                else:
                    result += render_node(record, is_last=record['is_last'])
        result += '</div>'
        return result
    if kind == 'pipeline':
        result += '<div class="pts-parallel-item"><div class="pts-vline"></div>'
        result += '<div class="pts-pipeline">'
        if 'transformers' in radix_schematic:
            result += '<div class="pts-hline pts-arr pts-arr-inner" style="width: 16px;"></div>'
            # outer_class can be anything except 'outer'
            result += draw_radix_html_schematic(as_linear_tree(radix_schematic), outer_class='inner')
        else:
            for item in radix_schematic['nodes']:
                result += draw_radix_html_schematic(item['self'], outer_class='inner-pipeline')
        result += '</div>'
        return result
    if kind == "output":
        result = ''
        if outer_class == 'inner':
            result += '<div class="pts-parallel-item"><div class="pts-vline"></div>'
            result += '<div class="pts-pipeline">'
            result += '<div class="pts-hline" style="width: 10px;"></div>'
        else:
            result += '<div class="pts-pipeline">'
        result += f'<div class="pts-hline pts-arr pts-arr-output">{_draw_df_html(radix_schematic.get("output_columns"), radix_schematic.get("input_columns"))}</div>'
        result += '<div class="pts-io-label">Evaluate</div>'
        result += '</div>'
        return result
    if kind in ('node', 'transformer', 'indexer'):
        if outer_class == 'inner':
            transformer_result = '<div class="pts-parallel-item"><div class="pts-vline"></div>'
            transformer_result += '<div class="pts-pipeline">'
            transformer_result += '<div class="pts-hline pts-arr pts-arr-inner" style="width: 16px;"></div>'
            transformer_result += render_node(radix_schematic, is_last=radix_schematic['is_last'])
            transformer_result += '</div>'
            return transformer_result
        if outer_class == 'inner-pipeline':
            transformer_result = '<div class="pts-hline pts-arr pts-arr-inner" style="width: 16px;"></div>'
            transformer_result += render_node(radix_schematic, radix_schematic['is_last'])
            return transformer_result
        return render_node(radix_schematic, radix_schematic['is_last'])
    raise ValueError(f"Unknown schematic type {radix_schematic['type']}")


def _draw_html_schematic(schematic: dict, *, mode: str = 'outer') -> str:
    """Draws a (non-tree) structured schematic as HTML.

    A single ``transformer`` or ``indexer`` schematic is first wrapped into a one-element
    ``pipeline``. ``mode`` selects the framing. ``'outer'`` adds the Input and Output/Artifact
    labels (``'outer-branch'`` only affects the output part). ``'inner_linked'`` is used for inner
    pipelines drawn in a parallel scaffold (``linked``/``combine``). ``'inner_labeled'`` is used for
    labelled ``unlinked`` inner pipelines.

    Raises:
        ValueError: if the schematic type is unknown.
    """
    if schematic['type'] == 'transformer':
        return _draw_html_schematic({
            'type': 'pipeline',
            'input_columns': schematic.get('input_columns'),
            'output_columns': schematic.get('output_columns'),
            'transformers': [schematic],
        }, mode=mode)
    if schematic['type'] == 'indexer':
        return _draw_html_schematic({
            'type': 'pipeline',
            'input_columns': schematic.get('input_columns'),
            'transformers': [schematic],
        }, mode=mode)
    if schematic['type'] != 'pipeline':
        raise ValueError(f"Unknown schematic type {schematic['type']}")

    def linked_pipelines_html(record):
        # Each inner pipeline as a parallel item framed by vertical lines.
        return ''.join(
            '<div class="pts-parallel-item"><div class="pts-vline"></div>'
            + _draw_html_schematic(pipeline, mode='inner_linked')
            + '<div class="pts-vline"></div></div>'
            for pipeline in record['inner_pipelines']
        )

    transformers = schematic['transformers']
    result = '<div class="pts-pipeline">'
    if mode == 'outer':
        # Only omit the arrow for 'combine' and 'branch' modes
        # (the previous `!= ... or != ...` test was always true).
        first_mode = transformers[0].get('inner_pipelines_mode')
        clz = '' if first_mode in ('combine', 'branch') else 'pts-arr'
        result += '<div class="pts-io-label">Input</div>'
        result += f'<div class="pts-hline {clz} pts-arr-input">{_draw_df_html(schematic["input_columns"])}</div>'
    elif mode == 'inner_linked':
        result += '<div class="pts-hline pts-arr pts-arr-inner" style="width: 16px;"></div>'
    elif mode == 'inner_labeled':
        result += f'<div class="pts-hline pts-arr pts-arr-input">{_draw_df_html(schematic["input_columns"])}</div>'

    columns = schematic["input_columns"]
    last = len(transformers) - 1
    for i, record in enumerate(transformers):
        assert record['type'] == 'transformer' or record['type'] == 'indexer', record
        # Each stage must consume exactly what the previous stage produced.
        assert record['input_columns'] == columns
        infobox, infobox_attr, error_cls = render_transformer_infobox(record)
        title = html.escape(record["label"])
        if 'inner_pipelines' in record:
            # 'unlinked' is the default applied by _apply_default_schematic.
            inner_mode = record.get('inner_pipelines_mode', 'unlinked')
            if inner_mode == 'linked':
                result += f'''
<div class="pts-transformer pts-inner pts-parallel-scaffold {error_cls}" {infobox_attr}>
    {infobox}
    <div class="pts-hline"></div>
    <div class="pts-transformer-title">{title}</div>
    <div class="pts-inner-schematic pts-inner-linked">{linked_pipelines_html(record)}</div>
    <div class="pts-hline pts-arr"></div> <!-- TODO this is unusual - an arrow AFTER something -->
</div>
'''
            elif inner_mode == 'combine':
                result += f'''
<div class="pts-combine-box">
    <div class="pts-parallel-scaffold pts-inner">
        <div class="pts-hline"></div>
        <div class="pts-inner-schematic pts-inner-linked">{linked_pipelines_html(record)}</div>
        <div class="pts-hline pts-arr"></div>
    </div>
    <!-- this is for the RRFusion part of the pipeline -->
    <div class="pts-transformer {error_cls}" {infobox_attr}>
        {infobox}
        <div class="pts-transformer-title">{title}</div>
    </div>
</div>
'''
            elif inner_mode == 'unlinked':
                pipelines = ''
                labels = record.get('inner_pipelines_labels', [])
                if len(labels) == len(record['inner_pipelines']):
                    for label, pipeline in zip(labels, record['inner_pipelines']):
                        inner_html = _draw_html_schematic(pipeline, mode='inner_labeled')
                        pipelines += f'''
<div class="pts-transformer-title">{html.escape(label)}</div>
<div class="pts-inner-schematic pts-inner-labeled">{inner_html}</div>
'''
                else:
                    for pipeline in record['inner_pipelines']:
                        pipelines += _draw_html_schematic(pipeline, mode='inner_labeled')
                result += f'''
<div class="pts-transformer pts-inner {error_cls}" {infobox_attr}>
    {infobox}
    <div class="pts-transformer-title">{title}</div>
    <div class="pts-inner-schematic pts-inner-labeled">{pipelines}</div>
</div>
'''
        else:
            # Plain transformers and indexers share the same box.
            result += f'''
<div class="pts-transformer {error_cls}" {infobox_attr}>
    {infobox}
    <div class="pts-transformer-title">{title}</div>
</div>
'''
        if i != last:
            result += f'<div class="pts-hline pts-arr pts-arr-inner">{_draw_df_html(record["output_columns"], record["input_columns"])}</div>'
        columns = record['output_columns']

    final_inputs = transformers[-1]["input_columns"]
    if mode == 'outer' or mode == 'outer-branch':
        if transformers[-1]["type"] == 'indexer':
            result += '<div class="pts-hline pts-arr pts-arr-output"><svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="pts-artifact-icon"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 6m-8 0a8 3 0 1 0 16 0a8 3 0 1 0 -16 0" /><path d="M4 6v6a8 3 0 0 0 16 0v-6" /><path d="M4 12v6a8 3 0 0 0 16 0v-6" /></svg></div>'
            result += '<div class="pts-io-label">Artifact</div>'
        else:
            result += f'<div class="pts-hline pts-arr pts-arr-output">{_draw_df_html(schematic.get("output_columns"), final_inputs)}</div>'
            result += '<div class="pts-io-label">Output</div>'
    elif mode == 'inner_linked':
        result += f'<div class="pts-hline" style="flex-grow: 1;">{_draw_df_html(schematic.get("output_columns"), final_inputs)}</div>'
    elif mode == 'inner_labeled':
        result += f'<div class="pts-hline pts-arr pts-arr-output">{_draw_df_html(schematic.get("output_columns"), final_inputs)}</div>'
    result += '</div>'
    return result


def _draw_df_html(columns, prev_columns = None) -> str:
    """Draws a DataFrame as an HTML table.

    Args:
        columns: The frame's column names, or ``None`` when unknown (drawn as an alert "?" frame).
        prev_columns: Columns of the previous frame. Columns not present there are marked with the
            ``pts-add`` class.

    Returns:
        A ``pts-df`` div containing the frame label and a hidden infobox listing the columns.
    """
    if columns is None:
        columns = []
        df_class = ' pts-df-alert'
        frame_info = {
            'label': '?',
            'title': 'Unknown Frame',
        }
    else:
        df_class = ''
        frame_info = pt.model.frame_info(columns) or {'label': 'DF', 'title': 'DataFrame'}
    df_label = frame_info['label']
    df_label_long = html.escape(frame_info['title'])
    # change underscore subscript into HTML subscript
    df_label = re.sub(r'_(\w+)', r'<sub>\1</sub>', df_label)
    uid = str(uuid.uuid4())
    if columns:
        column_rows = []
        for col in columns:
            col_info = pt.model.column_info(col) or {}
            col_desc = ''
            type_name = ''
            if 'type' in col_info:
                type_name = str(col_info['type'])
                if col_info['type'] == np.array:
                    type_name = 'np.array'
                elif hasattr(col_info['type'], '__name__'):
                    type_name = col_info['type'].__name__
                type_name = f'<span style="font-family: monospace;">{html.escape(type_name)}</span>'
            if 'phrase' in col_info:
                col_desc += f'<i>({html.escape(col_info["phrase"])})</i> '
            if 'short_desc' in col_info:
                col_desc += f'{html.escape(col_info["short_desc"])} '
            is_added = prev_columns and col not in prev_columns
            row_class = "pts-add" if is_added else ""
            column_rows.append(f'''
<tr class="{row_class}">
    <th>{html.escape(col)}</th>
    <td>{type_name}</td>
    <td>{col_desc}</td>
</tr>
''')
        rows_html = ''.join(column_rows)
        col_table = f'''
<div id="id-{uid}" class="pts-infobox-item" data-title="{df_label_long}">
    <table class="pts-df-columns">
        {rows_html}
    </table>
</div>'''
    else:
        col_table = f'''
<div id="id-{uid}" class="pts-infobox-item" data-title="{df_label_long}">
    <div class="pts-infobox-error">Unknown/incompatible columns</div>
</div>'''
    return f'<div class="pts-df {df_class}" data-pts-infobox="id-{uid}">{df_label}{col_table}</div>'
@runtime_checkable
class HasSchematic(Protocol):
    """Protocol for transformers that override details about their schematic representation.

    This is an optional extension interface to :class:`pyterrier.Transformer` that allows
    transformers to provide customizations to their schematics. Because the protocol is
    ``runtime_checkable``, any object exposing a ``schematic`` attribute satisfies
    ``isinstance(obj, HasSchematic)``.
    """

    def schematic(self, *, input_columns: Optional[List[str]]) -> Dict[str, Any]:
        """Returns a structured schematic representation of the transformer.

        The schematic should be a dictionary that follows the structure defined in
        :ref:`pt.schematic <pyterrier.schematic>`. For ease of use, the method can optionally
        return only some of the fields of the schematic; any missing fields will be filled in
        with default values.

        It can also be implemented as an instance or class member when the values do not need to
        be computed on-the-fly (e.g., overriding the schematic label). When ``schematic`` is not
        ``callable``, its dict value is used directly as the schematic.

        Args:
            input_columns: The input columns of the transformer (or ``None`` if unknown), used to
                determine schematic fields such as the output columns.

        Returns:
            A dictionary representing the schematic of the transformer, which will be used to draw
            the schematic diagram.
        """