Source code for pyterrier.java._core

import os
import sys
from pyterrier.java import required_raise, required, before_init, started, mavenresolver, JavaClasses, JavaInitializer, register_config
from typing import Optional
import pyterrier as pt

# apt package that ColabJavaInit installs when no `java` executable is found on the PATH
_min_colab_jdk = "openjdk-11-jdk-headless"
# Module-level references to the Python stream proxies created by redirect_stdouterr();
# they start as None and are assigned when redirection is set up
_stdout_ref = None
_stderr_ref = None


# ----------------------------------------------------------
# Java Initialization
# ----------------------------------------------------------



class ColabJavaInit(JavaInitializer):
    """Installs a JDK with apt-get when running on Google Colab without Java.

    Nothing happens unless the ``google.colab`` module has been imported and no
    ``java`` executable can be found on the PATH.
    """

    def priority(self) -> int:
        return -101 # run this initializer before CoreJavaInit

    def pre_init(self, jnius_config):
        """Install ``_min_colab_jdk`` if this is a Colab runtime that lacks Java.

        The apt-get output is echoed to ``sys.stdout`` line by line as it is produced.
        A failed install is reported with a printed message, not an exception.

        :param jnius_config: not used by this initializer
        """
        # detect colab
        if 'google.colab' not in sys.modules:
            return
        import shutil
        # detect java on the PATH
        if shutil.which("java") is not None:
            return
        print(f"This Colab is missing Java - installing {_min_colab_jdk}, please wait")
        import subprocess

        cmd = [
            "apt-get",
            "install",
            "-y",
            _min_colab_jdk,
            "--option=Dpkg::Progress-Fancy=1",
            "--option=APT::Color=1",
        ]

        # The context manager closes the output pipe and waits for apt-get to exit,
        # even if echoing the output is interrupted part-way through.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr with stdout so nothing is lost
            text=True,
            bufsize=1,  # line buffered, so progress appears as it happens
            env={**os.environ, "TERM": "xterm-color"},
        ) as process:
            for line in process.stdout:
                sys.stdout.write(line)
                sys.stdout.flush()

        # Check exit status
        if process.returncode == 0:
            print(f"\n✅ apt-get install of {_min_colab_jdk} completed successfully.")
        else:
            print(f"\n❌ apt-get install of {_min_colab_jdk} failed with exit code {process.returncode}.")

class CoreJavaInit(JavaInitializer):
    """Applies the ``pyterrier.java`` configuration to the JVM and validates the Java version."""

    def priority(self) -> int:
        return -100 # run this initializer before anything else

    @staticmethod
    def _major_version(java_version):
        """Return the major Java release encoded in a ``java.version`` string, or None.

        Legacy strings such as ``1.8.0_292`` carry the release after the ``1.`` prefix (8);
        newer strings such as ``9``, ``10.0.2`` or ``17-ea`` lead with the release itself.
        None is returned when the string does not start with a number.
        """
        import re
        match = re.match(r'(?:1\.)?(\d+)', java_version or '')
        return int(match.group(1)) if match else None

    def pre_init(self, jnius_config):
        """Translate the configured settings into environment variables and JVM options.

        :param jnius_config: receives the JVM options and classpath entries
        """
        if configure['java_home']:
            os.environ['JAVA_HOME'] = configure['java_home']

        if pt.utils.is_windows() and "JAVA_HOME" in os.environ:
            # extend the PATH with the directories under JAVA_HOME listed below
            java_home = os.environ["JAVA_HOME"]
            fix = f'{java_home}\\jre\\bin\\server\\;{java_home}\\jre\\bin\\client\\;{java_home}\\bin\\server\\'
            os.environ["PATH"] = os.environ["PATH"] + ";" + fix

        mem = pt.java.configure['mem']
        if mem is not None:
            # -Xmx takes a whole number: a float such as 4096.0 must not be rendered as "4096.0"
            if isinstance(mem, float):
                mem = int(mem)
            jnius_config.add_options('-Xmx' + str(mem) + 'm')

        for opt in pt.java.configure['options']:
            jnius_config.add_options(opt)

        for jar in pt.java.configure['jars']:
            jnius_config.add_classpath(jar)

        # set the property that makes a process name visible in jps
        notebook_name : Optional[str] = pt.utils._get_notebook()
        if notebook_name is None:
            # sys.argv can be empty (or hold an empty string) in interactive/embedded interpreters
            script = sys.argv[0] if sys.argv and sys.argv[0] else '<interactive>'
            process_name = "python[pyterrier]:" + script
        else:
            process_name = "jupyter[pyterrier]:" + notebook_name
        jnius_config.add_options("-Dsun.java.command=%s" % process_name)

    @required_raise
    def post_init(self, jnius):
        """Apply the log level and IO redirection, check the Java version, and register Map$Entry helpers.

        :raises RuntimeError: if the running JVM reports a release older than Java 11
        """
        pt.java.set_log_level(pt.java.configure['log_level'])

        if pt.java.configure['redirect_io']:
            pt.java.redirect_stdouterr()

        java_version = pt.java.J.System.getProperty("java.version")
        # Compare the parsed major release rather than string prefixes, so that "9", "9-ea"
        # and "10.x" are rejected as well as "1.x" and "9.x". Unparseable strings are let through.
        major = self._major_version(java_version)
        if major is not None and major < 11:
            raise RuntimeError(f"Pyterrier requires Java 11 or newer, we only found Java version {java_version};"
                + " install a more recent Java, or change os.environ['JAVA_HOME'] to point to the proper Java installation")

        # Map$Entry can be indexed, iterated and measured like a 2-tuple of (key, value)
        jnius.protocol_map['java.util.Map$Entry'] = {
            '__getitem__' : _mapentry_getitem,
            '__iter__' : lambda self: iter([self.getKey(), self.getValue()]),
            '__len__' : lambda self: 2
        }


# Map$Entry can be decoded like a tuple
def _mapentry_getitem(self, i):
    if i == 0:
        return self.getKey()
    if i == 1:
        return self.getValue()
    raise IndexError()

class Java24Init(JavaInitializer):
    """Responsible for hacking around JDK safety checks from JDK 24 onwards"""
    def priority(self) -> int:
        return -99 # run this before TerrierJavaInit

    def pre_init(self, jnius_config):
        """Add ``--enable-native-access=ALL-UNNAMED`` when JAVA_HOME looks like JDK 24-29.

        :param jnius_config: receives the extra JVM option
        """
        # detect JDK 24 onwards - this is a best attempt - at best, the user will see a warning.
        # the plan is to use https://github.com/kivy/pyjnius/pull/780 to ask the JVM for its version
        # before starting the JVM.
        # We /could/ safely add this option from at least JDK 21 onwards. JDK 11 doesnt recognise it.
        java_home = os.environ.get("JAVA_HOME")
        if java_home is None:
            return
        # a contiguous range, so that no release (e.g. 27) is skipped by accident
        if any(f"{ver}." in java_home for ver in range(24, 30)):
            jnius_config.add_options("--enable-native-access=ALL-UNNAMED")

    @required_raise
    def post_init(self, jnius):
        """On Java 24 or newer, load Hadoop's FastByteComparisons holder while os.arch reads "sparc"."""
        from packaging.version import Version, parse
        import re
        java_version = pt.java.J.System.getProperty("java.version")
        # RTD has an annoying -internal in their Java version
        java_version = re.sub(r'[-_].*$', '', java_version)
        if parse(java_version) >= Version("24"):
            system = pt.java.J.System
            # Hadoop will fallback to pureJava for sparc architecture - lets pretend we are for just a minute.
            arch = system.getProperty("os.arch")
            system.setProperty("os.arch", "sparc")
            try:
                # force initialisation of FastByteComparisons using the sparc arch
                pt.java.autoclass("org.apache.hadoop.io.FastByteComparisons$LexicographicalComparerHolder")
            finally:
                # always restore the real architecture, even if the class failed to load
                system.setProperty("os.arch", arch)

def _is_binary(f):
    import io
    return isinstance(f, (io.RawIOBase, io.BufferedIOBase))


@required
def redirect_stdouterr():
    """Point Java's ``System.out`` and ``System.err`` at Python's ``sys.stdout`` and ``sys.stderr``.

    Each Python stream is wrapped in a proxy implementing ``org.terrier.python.OutputStreamable``,
    which is wrapped in turn in a ``ProxyableOutputStream`` and a ``java.io.PrintStream``.
    The proxies are stored in the module globals ``_stdout_ref`` and ``_stderr_ref``.
    """
    from jnius import autoclass, PythonJavaClass, java_method

    # TODO: encodings may be a problem here
    # Python-side implementation of the Java OutputStreamable interface; forwards
    # close/flush/write calls to the wrapped Python stream.
    class MyOut(PythonJavaClass):
        __javainterfaces__ = ['org.terrier.python.OutputStreamable']

        def __init__(self, pystream):
            super(MyOut, self).__init__()
            self.pystream = pystream
            # decided once: binary streams receive bytes, all other streams receive str
            self.binary = _is_binary(pystream)

        @java_method('()V')
        def close(self):
            self.pystream.close()

        @java_method('()V')
        def flush(self):
            self.pystream.flush()

        # write(byte[]): forwards every element to writeChar
        @java_method('([B)V', name='write')
        def writeByteArray(self, byteArray):
            # TODO probably this could be faster.
            for c in byteArray:
                self.writeChar(c)

        # write(byte[], int, int): forwards `length` elements starting at `offset` to writeChar
        @java_method('([BII)V', name='write')
        def writeByteArrayIntInt(self, byteArray, offset, length):
            # TODO probably this could be faster.
            for i in range(offset, offset + length):
                self.writeChar(byteArray[i])

        # write(int): writes a single value, as one byte for binary streams
        # or as the character with that code point otherwise
        @java_method('(I)V', name='write')
        def writeChar(self, chara):
            if self.binary:
                return self.pystream.write(bytes([chara]))
            return self.pystream.write(chr(chara))

    # We need to hold lifetime references to _stdout_ref/_stderr_ref, to ensure
    # they aren't garbage collected. This prevents a crash when Java calls back
    # into a Python object that has already been collected.

    global _stdout_ref
    global _stderr_ref
    import sys
    _stdout_ref = MyOut(sys.stdout)
    _stderr_ref = MyOut(sys.stderr)
    jls = autoclass("java.lang.System")
    # the explicit signature selects the PrintStream(OutputStream) constructor
    jls.setOut(
        autoclass('java.io.PrintStream')(
            autoclass('org.terrier.python.ProxyableOutputStream')(_stdout_ref),
            signature="(Ljava/io/OutputStream;)V"))
    jls.setErr(
        autoclass('java.io.PrintStream')(
            autoclass('org.terrier.python.ProxyableOutputStream')(_stderr_ref),
            signature="(Ljava/io/OutputStream;)V"))


def bytebuffer_to_array(buffer):
    """Copy the contents of ``buffer`` into a Python ``bytearray``.

    Every position from 0 up to ``buffer.capacity()`` is read with ``buffer.get(offset)``;
    negative values are shifted up by 256 so they fit the unsigned range of a bytearray.
    """
    assert buffer is not None
    values = []
    for offset in range(buffer.capacity()):
        raw = buffer.get(offset)
        if raw < 0:
            raw = raw + 256
        values.append(raw)
    return bytearray(values)


# ----------------------------------------------------------
# Configuration
# ----------------------------------------------------------

# Default settings registered under the 'pyterrier.java' configuration name
configure = register_config('pyterrier.java', {
    'jars': [],  # entries added to the JVM classpath by CoreJavaInit.pre_init
    'options': [],  # extra JVM options, each passed to jnius_config.add_options
    'mem': None,  # when not None, rendered as the JVM option -Xmx<mem>m
    'log_level': 'WARN',  # handed to set_log_level in CoreJavaInit.post_init
    'redirect_io': True,  # when true, CoreJavaInit.post_init calls redirect_stdouterr()
    'java_home': None,  # when set, copied into os.environ['JAVA_HOME'] by CoreJavaInit.pre_init
})


@before_init
def add_jar(jar_path):
    """Append ``jar_path`` to the ``'jars'`` entry of the java configuration."""
    configure.append('jars', jar_path)
@before_init
def add_package(org_name : str, package_name : str, version : Optional[str] = None, file_type : str = 'jar'):
    """Resolve a package through ``mavenresolver`` and register the resulting file with ``add_jar``.

    When ``version`` is None or ``'snapshot'``, the version is looked up with
    ``mavenresolver.latest_version_num``. ``file_type`` is passed on as the ``artifact`` argument.
    """
    use_latest = version is None or version == 'snapshot'
    if use_latest:
        version = mavenresolver.latest_version_num(org_name, package_name)
    jar_file = mavenresolver.get_package_jar(org_name, package_name, version, artifact=file_type)
    add_jar(jar_file)
@before_init
def set_memory_limit(mem: Optional[float]):
    """Store ``mem`` as the ``'mem'`` entry of the java configuration.

    CoreJavaInit renders a non-None value as the JVM option ``-Xmx<mem>m``.
    """
    configure['mem'] = mem
@before_init
def add_option(option: str):
    """Append ``option`` to the ``'options'`` entry of the java configuration."""
    configure.append('options', option)
@before_init
def set_redirect_io(redirect_io: bool):
    """Store ``redirect_io`` as the ``'redirect_io'`` entry of the java configuration."""
    configure['redirect_io'] = redirect_io
@before_init
def set_java_home(java_home: str):
    """
    Sets the directory to search when loading Java.

    Note that you can achieve the same outcome by setting the `JAVA_HOME` environment variable.
    """
    configure['java_home'] = java_home
def set_log_level(level):
    """
    Set the logging level. The following string values are allowed, corresponding
    to Java logging levels:

     - `'ERROR'`: only show error messages
     - `'WARN'`: only show warnings and error messages (default)
     - `'INFO'`: show information, warnings and error messages
     - `'DEBUG'`: show debugging, information, warnings and error messages

    Unlike other java settings, this can be changed either before or after init() has been called.
    """
    if started():
        # Java is running: apply the level directly
        J.PTUtils.setLogLevel(level, None) # noqa: PT100 handled by started() check above
        return
    # Java not started yet: remember the level in the configuration
    configure['log_level'] = level
# ----------------------------------------------------------
# Common classes (accessible via pt.java.J.[ClassName])
# ----------------------------------------------------------

# Maps short attribute names to fully-qualified Java class names
J = JavaClasses(
    ArrayList = 'java.util.ArrayList',
    Properties = 'java.util.Properties',
    PTUtils = 'org.terrier.python.PTUtils',
    System = 'java.lang.System',
    StringReader = 'java.io.StringReader',
    HashMap = 'java.util.HashMap',
    Arrays = 'java.util.Arrays',
    Array = 'java.lang.reflect.Array',
    String = 'java.lang.String',
    List = 'java.util.List',
)