Source code for pyterrier_caching.retriever_cache
from typing import Optional, Union
from pathlib import Path
from tempfile import TemporaryDirectory
from warnings import warn
import hashlib
import lz4.frame
import pandas as pd
import pyterrier as pt
import pickle
import json
import dbm.dumb
import pyterrier_alpha as pta
from pyterrier_caching import meta_file_compat
[docs]
class DbmRetrieverCache(pta.Artifact, pt.Transformer):
    """A :class:`~pyterrier_caching.RetrieverCache` that stores retrieved results in ``dbm.dumb`` database files.

    Each distinct set of key columns gets its own ``<sha256-of-columns>.dbm`` database inside the
    artifact directory. Within a database, every entry maps the SHA-256 digest of a pickled key tuple
    to an lz4-compressed pickle of the result columns for that key.

    The cache can be used as a context manager; leaving the ``with`` block calls :meth:`close`.

    NOTE: cached values are read back with ``pickle.loads``; only open caches from trusted sources.
    """
    ARTIFACT_TYPE = 'retriever_cache'
    ARTIFACT_FORMAT = 'dbm.dumb'

    def __init__(self,
                 path: Optional[Union[str, Path]] = None,
                 retriever: Optional[pt.Transformer] = None,
                 *,
                 on: Optional[Union[str, list]] = None,
                 verbose: bool = False):
        """
        Args:
            path: The path to the cache. If None, a temporary directory is created and removed again
                by :meth:`close`.
            retriever: The retriever that is cached.
            on: The column(s) to use as the key for the cache. If None, all columns will be used.
            verbose: If True, print progress information.
        """
        if path is None:
            # Keep a reference to the TemporaryDirectory so that it lives as long as the cache does.
            self._tmpdir = TemporaryDirectory()
            path = Path(self._tmpdir.name) / 'cache'
        else:
            self._tmpdir = None
        super().__init__(path)
        meta_file_compat(path)  # presumably upgrades legacy metadata files; see pyterrier_caching
        self.on = on
        self.retriever = retriever
        self.verbose = verbose
        self.meta = None  # parsed pt_meta.json, loaded lazily by _ensure_built
        self.file = None  # currently open dbm.dumb database, if any
        self.file_name = None  # path of the database referenced by self.file
        if not (Path(self.path)/'pt_meta.json').exists():
            with pta.ArtifactBuilder(self):
                pass  # just create the artifact

    @staticmethod
    def _key_hashes(frame: pd.DataFrame, on: tuple) -> list:
        """Returns one SHA-256 digest per row of ``frame``, computed over the values of the ``on`` columns.

        Cache lookups and cache writes both go through this helper so that a given key always maps
        to the same digest, regardless of which other columns (and dtypes) the frame contains.
        """
        keys = frame[list(on)].itertuples(index=False, name=None)
        return [hashlib.sha256(pickle.dumps(tuple(key))).digest() for key in keys]

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Returns the retriever's results for ``inp``, reading from the cache wherever possible.

        Rows whose key is missing from the cache are passed to the retriever and the new results are
        written to the cache. The returned frame contains the cached results first, followed by the
        newly retrieved results.

        Args:
            inp: The input frame. It must contain the ``on`` columns (or be a valid query frame when
                ``on`` is None).

        Returns:
            The concatenated results with a fresh integer index. For an empty ``inp``, an empty frame
            is returned.
        """
        if self.on is not None:
            if isinstance(self.on, str):
                on = [self.on]
            else:
                on = list(self.on)
            pta.validate.columns(inp, includes=on)
        else:
            on = list(inp.columns)
            pta.validate.query_frame(inp, warn=True)
        # Sort so that the same set of key columns always selects the same database file.
        on = tuple(sorted(on))

        if len(inp) == 0:
            # Nothing to look up or retrieve; pd.concat would raise on an empty list of frames.
            out_cols = None
            if self.retriever is not None:
                out_cols = pta.inspect.transformer_outputs(self.retriever, list(inp.columns), strict=False)
            if out_cols is None:
                out_cols = list(inp.columns)
            return pd.DataFrame(columns=out_cols)

        self._ensure_built(on)

        results = []
        to_retrieve = []
        to_retrieve_hashes = []

        # Step 1: Check Cache
        for i, key_hash in enumerate(self._key_hashes(inp, on)):
            if key_hash in self.file:
                stored_data = pickle.loads(lz4.frame.decompress(self.file[key_hash]))
                results.append(pd.DataFrame(stored_data))
            else:
                to_retrieve.append(i)
                to_retrieve_hashes.append(key_hash)

        # Step 2: Retrieve and save missing results
        if to_retrieve:
            # Swap the shared handle for a short-lived one opened in 'w' mode. Clearing self.file and
            # self.file_name makes the next call to _ensure_built open a fresh handle.
            self.file.close()
            self.file = None
            with dbm.dumb.open(self.file_name, 'w') as file:
                self.file_name = None
                out_cols = pta.inspect.transformer_outputs(self.retriever, list(inp.columns), strict=False)
                if out_cols is not None and all(o in out_cols for o in on):
                    # The key columns appear in the output, so results of a single batched call can
                    # be split back up per key.
                    one_at_a_time = False
                    retrieve_phases = [to_retrieve]
                else:
                    warn("Running retriever one query at a time because retriever's outputs could not be determined or "
                         f"the outputs do not contain the cache key: {on}")
                    one_at_a_time = True
                    retrieve_phases = [[idx] for idx in to_retrieve]
                    if self.verbose:
                        retrieve_phases = pt.tqdm(retrieve_phases, unit='q', desc=f'{self}')
                for i, idxs in enumerate(retrieve_phases):
                    retrieved_results = self.retriever(inp.iloc[idxs])
                    retrieved_results.reset_index(drop=True, inplace=True)
                    results.append(retrieved_results)
                    if one_at_a_time:
                        # Exactly one input row per phase, so phase i belongs to the i-th missing key.
                        hash_groups = [(to_retrieve_hashes[i], retrieved_results)]
                    else:
                        key_hashes = self._key_hashes(retrieved_results, on)
                        hash_groups = retrieved_results.groupby(key_hashes)
                    for key_hash, group in hash_groups:
                        if isinstance(key_hash, tuple):
                            # Some pandas versions yield 1-tuples as group keys here.
                            assert len(key_hash) == 1
                            key_hash = key_hash[0]
                        stored_data = {c: group[c].values for c in group.columns}
                        file[key_hash] = lz4.frame.compress(pickle.dumps(stored_data))

        if self.verbose:
            print(f'{self}: {len(inp)-len(to_retrieve)} hit(s), {len(to_retrieve)} miss(es)')
        return pd.concat(results, ignore_index=True)

    def _ensure_built(self, on: tuple) -> None:
        """Opens the database for the key columns ``on`` (creating it if needed) and loads the metadata.

        If a database for a different set of key columns is currently open, it is closed first.
        """
        on_hash = hashlib.sha256(pickle.dumps(on)).hexdigest()
        fname = str(self.path/f'{on_hash}.dbm')
        if self.file_name is not None and self.file_name != fname:
            # self.file can already be None here if a previous transform() closed it and then failed
            # before clearing self.file_name.
            if self.file is not None:
                self.file.close()
            self.file = None
            self.file_name = None
        if self.file is None:
            self.file = dbm.dumb.open(fname, 'c')
            self.file_name = fname
        if self.meta is None:
            with (self.path/'pt_meta.json').open('rt') as fin:
                self.meta = json.load(fin)
            assert self.meta['type'] == self.ARTIFACT_TYPE, \
                f"unexpected artifact type: {self.meta['type']!r}"
            assert self.meta['format'] == self.ARTIFACT_FORMAT, \
                f"unexpected artifact format: {self.meta['format']!r}"

    def close(self):
        """Closes the open database (if any) and removes the temporary directory (if one was created).

        Calling this method more than once is harmless.
        """
        if self.file is not None:
            self.file.close()
            self.file = None
            self.file_name = None
        if self._tmpdir is not None:
            self._tmpdir.cleanup()
            self._tmpdir = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __repr__(self):
        return f'DbmRetrieverCache({repr(str(self.path))}, {self.retriever})'
# ``RetrieverCache`` is an alias for the default retriever cache implementation,
# which is currently ``DbmRetrieverCache``.
RetrieverCache = DbmRetrieverCache