# Source code for pyterrier_anserini._java

import importlib.metadata
import os
from glob import glob
from pathlib import Path
from typing import Optional, Tuple
from warnings import warn

import pyterrier as pt
from packaging.version import Version

# Version string of the Anserini jar placed on the classpath. It stays None
# until AnseriniJavaInit.pre_init assigns it.
_version = None

# Java configuration entry for this package, registered under the
# 'pyterrier.anserini' name. 'version' defaults to None, in which case
# AnseriniJavaInit.pre_init uses the jar found by _get_pyserini_jar().
configure = pt.java.register_config('pyterrier.anserini', {'version': None})


class AnseriniJavaInit(pt.java.JavaInitializer):
    """Java initializer that adds an Anserini fat jar to the JVM classpath."""

    def __init__(self):
        # Description of the jar that was selected; populated by pre_init().
        self._message = None

    def condition(self) -> bool:
        """Report whether the Anserini Java components should be loaded.

        Returns ``True`` when the installed ``pyserini`` distribution's metadata
        can be read. On any failure a warning is emitted and ``False`` is returned.

        NOTE(review): an earlier docstring said loading is disabled for
        anserini >= 0.36; no such version check is performed here -- confirm
        whether one is still intended.
        """
        try:
            importlib.metadata.version('pyserini')
        except Exception as ex:
            warn(f'error loading anserini java: {ex}')
            return False
        else:
            return True

    def pre_init(self, jnius_config): # noqa: ANN001
        """Pick the Anserini jar and add it to ``jnius_config``'s classpath.

        Also records the selected version in the module-level ``_version`` and
        prepares the text returned by :meth:`message`.

        Raises:
            RuntimeError: if no jar could be located.
        """
        global _version
        requested = configure['version']
        if requested is not None:
            # download and use the anserini version specified by the user
            jar = pt.java.mavenresolver.get_package_jar(
                'io.anserini', "anserini", requested, artifact='fatjar')
            self._message = f"version={requested} (local cache)"
            _version = requested
        else:
            # no explicit version: fall back to the jar shipped inside pyserini
            jar, bundled_version = _get_pyserini_jar()
            self._message = f"version={bundled_version} (from pyserini package)"
            _version = bundled_version

        if jar is None:
            raise RuntimeError('Could not find anserini jar')
        jnius_config.add_classpath(jar)

    def post_init(self, jnius): # noqa: ANN001
        """Import pyserini's Lucene bindings with its classpath setup disabled."""
        # pyserini would otherwise try to reconfigure jnius while being imported,
        # so its configure_classpath hook is swapped for a no-op for the duration
        # of the import and restored afterwards (even if the import fails).
        import pyserini.setup
        saved_hook = pyserini.setup.configure_classpath
        try:
            pyserini.setup.configure_classpath = pt.utils.noop
            import pyserini.search.lucene  # load the package
        finally:
            pyserini.setup.configure_classpath = saved_hook

    def message(self):
        """Return the version description prepared by :meth:`pre_init` (``None`` before it runs)."""
        return self._message


def _get_pyserini_jar() -> Tuple[Optional[str], Optional[str]]:
    """Find the Anserini fat jar distributed with the installed pyserini package.

    Adapted from pyserini/setup.py and pyserini/pyclass.py.

    Returns:
        A ``(jar_path, version)`` pair for the most recently created
        ``anserini-*-fatjar.jar`` in pyserini's ``resources/jars/`` directory,
        or ``(None, None)`` when no matching jar is present.
    """
    import pyserini.setup
    jar_root = os.path.join(os.path.split(pyserini.setup.__file__)[0], 'resources/jars/')
    paths = glob(os.path.join(jar_root, 'anserini-*-fatjar.jar'))
    if not paths:
        return None, None
    latest_jar = max(paths, key=os.path.getctime)
    # The glob pattern guarantees the fixed prefix and suffix, so the version is
    # whatever lies between them. Slicing (rather than splitting on '-') keeps
    # hyphenated versions such as '0.25.1-SNAPSHOT' intact.
    jar_name = Path(latest_jar).name
    version = jar_name[len('anserini-'):-len('-fatjar.jar')]
    return latest_jar, version


@pt.java.before_init
def set_version(version: Optional[str] = None):
    """Set the version of Anserini to use.

    If version is ``None`` (default), the version of Anserini distributed with the pyserini package is used.
    Otherwise, the specified version is downloaded from Maven and used instead.

    Note that this function must be run before Java is initialized.

    Args:
        version: The Anserini version to use, or ``None`` to use the jar bundled with pyserini.
    """
    configure['version'] = version
@pt.java.required
def check_version(min_version: str) -> bool:
    """Return whether the Anserini version in use is at least ``min_version``.

    Both versions are parsed with ``packaging.version.Version`` and compared.
    The version in use is the module-level ``_version``.

    NOTE(review): ``_version`` is ``None`` until the Java initializer has run;
    presumably ``@pt.java.required`` guarantees that happens first -- confirm.

    Args:
        min_version: The lowest acceptable Anserini version.

    Returns:
        ``True`` if ``min_version <= _version``, otherwise ``False``.
    """
    return Version(min_version) <= Version(_version)


# Java classes used by this package, registered by fully-qualified class name.
J = pt.java.JavaClasses(
    ClassicSimilarity='org.apache.lucene.search.similarities.ClassicSimilarity',
    BM25Similarity='org.apache.lucene.search.similarities.BM25Similarity',
    LMDirichletSimilarity='org.apache.lucene.search.similarities.LMDirichletSimilarity',
    IndexReaderUtils='io.anserini.index.IndexReaderUtils',
    QueryParser='org.apache.lucene.queryparser.classic.QueryParser',
    ImpactSimilarity='io.anserini.search.similarity.ImpactSimilarity',
)