Source code for pyterrier_caching.retriever_cache
from typing import List, Optional, Union
from pathlib import Path
from tempfile import TemporaryDirectory
from warnings import warn
import hashlib
import lz4.frame
import pandas as pd
import pyterrier as pt
import pickle
import json
import dbm.dumb
import pyterrier_alpha as pta
from pyterrier_caching import meta_file_compat
[docs]
class DbmRetrieverCache(pt.Artifact, pt.Transformer):
    """A :class:`~pyterrier_caching.RetrieverCache` that stores retrieved results in ``dbm.dumb`` database files.

    Each cache entry is keyed by the SHA-256 digest of the pickled values of the key column(s) (``on``),
    and holds the lz4-compressed pickle of a ``{column: values}`` dict of that key's results. A separate
    ``<sha256>.dbm`` database is kept under the artifact directory for each distinct set of key columns.

    Cache entries are deserialised with :mod:`pickle`, so only open caches from trusted sources.
    """
    ARTIFACT_TYPE = 'retriever_cache'
    ARTIFACT_FORMAT = 'dbm.dumb'
    ARTIFACT_SCHEMATIC_SHOW_AS_TRANSFORMER = True

    def __init__(self,
                 path: Optional[Union[str, Path]] = None,
                 retriever: Optional[pt.Transformer] = None,
                 *,
                 on: Optional[Union[str, List[str]]] = None,
                 verbose: bool = False):
        """
        Args:
            path: The path to the cache. If None, a temporary directory is created (and removed by :meth:`close`).
            retriever: The retriever that is cached.
            on: The column(s) to use as the key for the cache. If None, all columns will be used.
            verbose: If True, print progress information.

        Raises:
            ValueError: If ``on`` includes the reserved ``docno`` column.
        """
        if path is None:
            self._tmpdir = TemporaryDirectory()
            path = Path(self._tmpdir.name) / 'cache'
        else:
            self._tmpdir = None
        super().__init__(path)
        meta_file_compat(path)
        self.on = on
        self._validate_on(on)
        self.retriever = retriever
        self.verbose = verbose
        self.meta = None  # parsed pt_meta.json, loaded lazily by _ensure_built
        self.file = None  # currently open dbm database (read/create mode), or None
        self.file_name = None  # path of the currently open database, or None
        if not (Path(self.path)/'pt_meta.json').exists():
            with pta.ArtifactBuilder(self):
                pass  # just create the artifact

    def _validate_on(self, on):
        """Reject a key specification that includes the reserved ``docno`` column.

        Args:
            on: A column name, an iterable of column names, or None (no check needed).

        Raises:
            ValueError: If ``docno`` is one of the key columns.
        """
        if on is None:
            return
        if isinstance(on, str):
            # Wrap the argument itself (not self.on), so the check does not depend on attribute order.
            on = [on]
        if 'docno' in on:
            raise ValueError("The 'docno' column is reserved and cannot be used as a cache key.")

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Return cached results for the rows of ``inp``, retrieving and caching any missing ones.

        Rows whose key is already cached are served from the cache. The remaining rows are passed to
        the wrapped retriever, and its results are stored under their keys. Cached results come first
        in the output, followed by newly retrieved ones, so the output is not necessarily in input order.

        Args:
            inp: The input frame (e.g. queries). It must contain the key columns.

        Returns:
            The concatenated results. If there are none, an empty frame with the retriever's output
            columns is returned (or no columns, if those could not be determined).
        """
        if self.on is not None:
            if isinstance(self.on, str):
                on = [self.on]
            else:
                on = list(self.on)
        else:
            on = list(inp.columns)
        # we dont use _validate_on here because it happens in the input validation
        child_reqs = pt.inspect.transformer_inputs(self.retriever, strict=False)
        if child_reqs is not None:
            valid_reqs = [c for c in child_reqs if "docno" not in c and "docid" not in c]
            with pt.validate.any(inp.columns) as v:
                for c in valid_reqs:
                    v.columns(includes=list(set(c) | set(on)), excludes=['docno', 'docid'])
        else:
            pt.validate.query_frame(inp, extra_cols=on)
        on = tuple(sorted(on))
        self._ensure_built(on)

        results = []
        to_retrieve = []  # positional indices of rows whose key must be retrieved
        to_retrieve_hashes = []  # key hashes, parallel to to_retrieve
        pending_hashes = set()  # keys already scheduled for retrieval in this call
        duplicate_hashes = []  # keys of later rows that repeat a pending (missing) key

        # Step 1: Check Cache
        for i in range(len(inp)):
            row = inp.iloc[i]
            # NOTE: key hashing must stay byte-for-byte stable, or existing caches become unreachable.
            key = tuple(row[o] for o in on)
            key_hash = hashlib.sha256(pickle.dumps(key)).digest()
            if key_hash in self.file:
                stored_data = pickle.loads(lz4.frame.decompress(self.file[key_hash]))
                results.append(pd.DataFrame(stored_data))
            elif key_hash in pending_hashes:
                # In batch mode, retrieving the same key twice would merge both result sets into a
                # single (duplicated) cache entry. Serve it from the cache after retrieval instead.
                duplicate_hashes.append(key_hash)
            else:
                pending_hashes.add(key_hash)
                to_retrieve.append(i)
                to_retrieve_hashes.append(key_hash)

        # calculate the output columns. this is needed for detecting one_at_a_time, and
        # also for returning an empty dataframe for inspection
        out_cols = pt.inspect.transformer_outputs(self.retriever, list(inp.columns), strict=False)
        if out_cols is not None and all(o in out_cols for o in on):
            one_at_a_time = False
            retrieve_phases = [to_retrieve]
        else:
            # Results cannot be mapped back to keys, so each row is retrieved (and keyed) separately.
            one_at_a_time = True

        # Step 2: Retrieve and save missing results
        if to_retrieve:
            # Reopen the database in write mode; _ensure_built reopens it in 'c' mode on the next call.
            self.file.close()
            self.file = None
            with dbm.dumb.open(self.file_name, 'w') as file:
                self.file_name = None
                if one_at_a_time:
                    retrieve_phases = [[idx] for idx in to_retrieve]
                    warn("Running retriever one query at a time because retriever's outputs could not be determined or "
                         f"the outputs do not contain the cache key: {on}")
                if self.verbose:
                    retrieve_phases = pt.tqdm(retrieve_phases, unit='q', desc=f'{self}')
                for i, idxs in enumerate(retrieve_phases):
                    retrieved_results = self.retriever(inp.iloc[idxs])
                    retrieved_results.reset_index(drop=True, inplace=True)
                    results.append(retrieved_results)
                    if one_at_a_time:
                        # one phase per missing row, so i indexes to_retrieve_hashes
                        hash_groups = [(to_retrieve_hashes[i], retrieved_results)]
                    elif len(retrieved_results) == 0:
                        # groupby([]) raises "No group keys passed!", so skip storing empty output
                        hash_groups = []
                    else:
                        keys = retrieved_results[list(on)].itertuples(index=False)
                        key_hashes = [hashlib.sha256(pickle.dumps(tuple(key))).digest() for key in keys]
                        hash_groups = retrieved_results.groupby(key_hashes)
                    for key_hash, group in hash_groups:
                        if isinstance(key_hash, tuple):
                            # grouping by a list may yield 1-tuples as keys; unwrap them
                            assert len(key_hash) == 1
                            key_hash = key_hash[0]
                        stored_data = {c: group[c].values for c in group.columns}
                        file[key_hash] = lz4.frame.compress(pickle.dumps(stored_data))
                # Rows repeating a just-retrieved key are served from the entries written above.
                for key_hash in duplicate_hashes:
                    if key_hash in file:
                        stored_data = pickle.loads(lz4.frame.decompress(file[key_hash]))
                        results.append(pd.DataFrame(stored_data))
        if self.verbose and len(inp):
            misses = len(to_retrieve) + len(duplicate_hashes)
            print(f'{self}: {len(inp)-misses} hit(s), {misses} miss(es)')
        if not results:
            return pd.DataFrame([], columns=out_cols)
        return pd.concat(results, ignore_index=True)

    def _ensure_built(self, on):
        """Make sure the database for the key columns ``on`` is open (creating it if needed).

        Each tuple of key columns maps to its own ``<sha256>.dbm`` file under the artifact path.
        Switching to a different key tuple closes the previously open database.

        Args:
            on: The sorted tuple of key column names.

        Raises:
            ValueError: If ``pt_meta.json`` does not describe a ``retriever_cache`` / ``dbm.dumb`` artifact.
        """
        if self.meta is None:
            # Validate before opening or creating any database file. Only remember metadata that passed.
            with (self.path/'pt_meta.json').open('rt') as fin:
                meta = json.load(fin)
            if meta.get('type') != self.ARTIFACT_TYPE or meta.get('format') != self.ARTIFACT_FORMAT:
                raise ValueError(
                    f'{self.path} is not a {self.ARTIFACT_TYPE!r} artifact in {self.ARTIFACT_FORMAT!r} format '
                    f'(found type={meta.get("type")!r}, format={meta.get("format")!r})')
            self.meta = meta
        on_hash = hashlib.sha256(pickle.dumps(on)).hexdigest()
        fname = str(self.path/f'{on_hash}.dbm')
        if self.file_name is not None and self.file_name != fname:
            self.file.close()
            self.file = None
            self.file_name = None
        if self.file is None:
            self.file = dbm.dumb.open(fname, 'c')
            self.file_name = fname

    def close(self):
        """Close the open database (if any) and remove the temporary directory (if one was created)."""
        if self.file is not None:
            self.file.close()
            self.file = None
            self.file_name = None
        if self._tmpdir is not None:
            self._tmpdir.cleanup()
            self._tmpdir = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __repr__(self):
        return f'DbmRetrieverCache({repr(str(self.path))}, {self.retriever})'

    def _repr_html_(self):
        return pt.schematic.draw(self, outer_class='repr_html')
# ``RetrieverCache`` is the default retriever-cache implementation exported by this module:
# an alias of DbmRetrieverCache, which stores results in ``dbm.dumb`` database files.
RetrieverCache = DbmRetrieverCache