import os
import json
from packaging.version import Version
from typing import Optional, Union, List, Dict, Any
import pyterrier as pt
# Maven group id under which the Terrier artifacts used in this module are resolved.
TERRIER_PKG = "org.terrier"
# Holds references to unpickled scoring functions and their callbacks so that
# they are not garbage-collected (see _new_callable_wmodel).
_SAVED_FNS = []
# Java Properties object handed to ApplicationSetup; created in TerrierJavaInit.post_init.
_properties = None
# Configuration registered with pt.java under the 'pt.terrier.java' key.
# Versions default to the TERRIER_VERSION / TERRIER_HELPER_VERSION environment
# variables; an unset or empty variable becomes None.
configure = pt.java.register_config('pt.terrier.java', {
'terrier_version': os.environ.get("TERRIER_VERSION") or None,
'helper_version': os.environ.get("TERRIER_HELPER_VERSION") or None,
'boot_packages': [],
# set to False by TerrierJavaInit.pre_init once the jars have been resolved
'force_download': True,
})
[docs]
@pt.java.before_init
def set_version(version: Optional[str] = None):
    """
    Sets the Terrier version that should be loaded.

    The value is stored under ``configure['terrier_version']``. When it is ``None``,
    ``TerrierJavaInit.pre_init`` looks up the newest ``terrier-assemblies`` release;
    the special value ``"snapshot"`` selects the 5.x-SNAPSHOT build instead.

    Arguments:
        version: the requested Terrier version, or ``None`` for the latest release.
    """
    requested_version = version
    configure['terrier_version'] = requested_version
[docs]
@pt.java.before_init
def set_helper_version(version: Optional[str] = None):
    """
    Sets the version of the terrier-python-helper package that should be loaded.

    The value is stored under ``configure['helper_version']``. ``None`` and
    ``"snapshot"`` both make ``TerrierJavaInit.pre_init`` look up the newest
    ``terrier-python-helper`` release.

    Arguments:
        version: the requested helper version, or ``None`` for the latest release.
    """
    requested_version = version
    configure['helper_version'] = requested_version
class TerrierJavaInit(pt.java.JavaInitializer):
    """
    Java initializer that makes Terrier available to PyTerrier.

    ``pre_init`` resolves the Terrier and terrier-python-helper jars and adds them to
    the classpath of the supplied jnius configuration. ``post_init`` registers pyjnius
    protocol maps so that Terrier objects support Python dunder methods and pickling,
    installs the Python-implemented Java helper classes on ``pt.terrier.index`` and
    bootstraps Terrier's ``ApplicationSetup``.
    """

    def priority(self) -> int:
        """Returns the ordering priority of this initializer."""
        return -10  # between pt.java.core (-100) and default (0) to load earlier than extensions

    def pre_init(self, jnius_config):
        """
        Resolves the Terrier and helper jars and adds them to ``jnius_config``'s classpath.

        The concrete versions that were resolved are written back into ``configure``.
        """
        # If version is not specified, find newest and download it
        if configure['terrier_version'] is None:
            terrier_version = pt.java.mavenresolver.latest_version_num(TERRIER_PKG, "terrier-assemblies")
        else:
            terrier_version = str(configure['terrier_version'])  # just in case its a float
        configure['terrier_version'] = terrier_version  # save this specific version

        # obtain the fat jar from Maven
        # "snapshot" means use Jitpack.io to get a build of the current
        # 5.x branch from Github - see https://jitpack.io/#terrier-org/terrier-core/5.x-SNAPSHOT
        if terrier_version == "snapshot":
            trJar = pt.java.mavenresolver.get_package_jar(
                "com.github.terrier-org.terrier-core",
                "terrier-assemblies",
                "5.x-SNAPSHOT",
                artifact="jar-with-dependencies",
                force_download=configure['force_download'])
        else:
            trJar = pt.java.mavenresolver.get_package_jar(
                TERRIER_PKG,
                "terrier-assemblies",
                terrier_version,
                artifact="jar-with-dependencies")
        jnius_config.add_classpath(trJar)

        # now the helper classes
        if configure['helper_version'] is None or configure['helper_version'] == 'snapshot':
            helper_version = pt.java.mavenresolver.latest_version_num(TERRIER_PKG, "terrier-python-helper")
        else:
            helper_version = str(configure['helper_version'])  # just in case its a float
        # save this specific version (always as a string, as is done for terrier_version above)
        configure['helper_version'] = helper_version
        helper_jar = pt.java.mavenresolver.get_package_jar(TERRIER_PKG, 'terrier-python-helper', helper_version)
        jnius_config.add_classpath(helper_jar)

        # This is for parallel -- it means that when re-configured in a parallel process, force_download will be False
        # and mavenresolver will use the version that was just downloaded above (not try to do it again).
        configure['force_download'] = False

    @pt.java.required_raise
    def post_init(self, jnius):
        """
        Registers the Python protocols for Terrier classes and bootstraps Terrier.

        Also sets ``pt.IndexRef`` and ``pt.ApplicationSetup`` and creates the
        module-level ``_properties`` object.
        """
        global _properties

        jnius.protocol_map["org.terrier.structures.postings.IterablePosting"] = {
            '__iter__': lambda self: self,
            '__next__': lambda self: _iterableposting_next(self),
            '__str__': lambda self: self.toString()
        }

        jnius.protocol_map["org.terrier.structures.CollectionStatistics"] = {
            '__str__': lambda self: self.toString()
        }

        jnius.protocol_map["org.terrier.structures.LexiconEntry"] = {
            '__str__': lambda self: self.toString()
        }

        jnius.protocol_map["org.terrier.structures.Lexicon"] = {
            '__getitem__': _lexicon_getitem,
            '__contains__': lambda self, term: self.getLexiconEntry(term) is not None,
            '__len__': lambda self: self.numberOfEntries()
        }

        jnius.protocol_map["org.terrier.querying.IndexRef"] = {
            '__eq__': lambda self, other: self.equals(other),
            '__reduce__': _index_ref_reduce,
            '__getstate__': lambda self: None,
            'text_loader': pt.terrier.terrier_text_loader,
        }

        jnius.protocol_map["org.terrier.matching.models.WeightingModel"] = {
            '__reduce__': _wmodel_reduce,
            '__getstate__': lambda self: None,
        }

        jnius.protocol_map["org.terrier.python.CallableWeightingModel"] = {
            '__reduce__': _callable_wmodel_reduce,
            '__getstate__': lambda self: None,
        }

        jnius.protocol_map["org.terrier.structures.Index"] = {
            # this means that len(index) returns the number of documents in the index
            '__len__': lambda self: self.getCollectionStatistics().getNumberOfDocuments(),

            # document-wise composition of indices: adding more documents to an index, by merging two indices with
            # different numbers of documents. This is implemented by overloading the `+` Python operator
            '__add__': _index_add,

            # get_corpus_iter returns a generator that yields dicts such as {"docno": "d1", "toks" : {'a' : 1}}
            'get_corpus_iter': _index_corpusiter,
            'text_loader': pt.terrier.terrier_text_loader,
        }

        self._post_init_index(jnius)

        pt.IndexRef = J.IndexRef
        _properties = pt.java.J.Properties()
        pt.ApplicationSetup = J.ApplicationSetup
        J.ApplicationSetup.bootstrapInitialisation(_properties)

    @pt.java.required_raise
    def message(self):
        """Returns a string describing the loaded Terrier version and the helper version."""
        version_string = J.Version.VERSION
        if "BUILD_DATE" in dir(J.Version):
            version_string += f" (build: {J.Version.BUILD_USER} {J.Version.BUILD_DATE})"
        return f"version={version_string}, helper_version={configure['helper_version']}"

    def _post_init_index(self, jnius):
        """
        Defines the Python classes that implement Java interfaces for indexing.

        The classes are published as attributes of ``pt.terrier.index``.
        """

        @pt.java.required
        class DocListIterator(jnius.PythonJavaClass):
            """
            java.util.Iterator over a Python iterator of pre-tokenised document dicts.

            Each dict must have a "toks" entry mapping terms to frequencies; the
            remaining entries are passed on as the document's metadata.
            """
            __javainterfaces__ = [
                'java/util/Iterator',
            ]

            # NOTE(review): unlike the sibling proxy classes, this __init__ does not call
            # super().__init__() -- confirm that this is intentional.
            def __init__(self, pyiterator):
                self.pyiterator = pyiterator
                self.hasnext = True
                self.lastdoc = None
                self.tr57 = not pt.terrier.check_version("5.8")

            @staticmethod
            def pyDictToMap(a_dict):  # returns Map<String,String>
                rtr = pt.java.J.HashMap()
                for k, v in a_dict.items():
                    rtr.put(k, v)
                return rtr

            def pyDictToMapEntry(self, doc_dict: Dict[str, Any]):  # returns Map.Entry<Map<String,String>, DocumentPostingList>>
                dpl = pt.terrier.J.DocumentPostingList()
                # this works around a bug in the counting of doc lengths in Tr 5.7
                if self.tr57:
                    for t, tf in doc_dict["toks"].items():
                        for i in range(int(tf)):
                            dpl.insert(t)
                else:  # this code for 5.8 onwards
                    for t, tf in doc_dict["toks"].items():
                        dpl.insert(int(tf), t)

                # we cant make the toks column into the metaindex as it isnt a string. remove it.
                # (this mutates the dict that the caller's iterator produced)
                del doc_dict["toks"]
                return pt.terrier.J.MapEntry(DocListIterator.pyDictToMap(doc_dict), dpl)

            @jnius.java_method('()Z')
            def hasNext(self):
                return self.hasnext

            @jnius.java_method('()Ljava/lang/Object;')
            def next(self):
                try:
                    doc_dict = next(self.pyiterator)
                except StopIteration:
                    self.hasnext = False
                    # terrier will ignore a null return from an iterator
                    return None
                # keep this around to prevent being GCd before Java can read it
                self.lastdoc = self.pyDictToMapEntry(doc_dict)
                return self.lastdoc

        @pt.java.required
        class TQDMCollection(jnius.PythonJavaClass):
            """Collection wrapper that reports progress, in files, of a MultiDocumentFileCollection."""
            __javainterfaces__ = ['org/terrier/indexing/Collection']

            def __init__(self, collection):
                super(TQDMCollection, self).__init__()
                assert isinstance(collection, pt.terrier.J.MultiDocumentFileCollection)
                self.collection = collection
                size = self.collection.FilesToProcess.size()
                self.pbar = pt.tqdm(total=size, unit="files")
                self.last = -1

            @jnius.java_method('()Z')
            def nextDocument(self):
                rtr = self.collection.nextDocument()
                filenum = self.collection.FileNumber
                # advance the progress bar only when the collection has moved on to another file
                if filenum > self.last:
                    self.pbar.update(filenum - self.last)
                    self.last = filenum
                return rtr

            @jnius.java_method('()V')
            def reset(self):
                self.pbar.reset()
                self.collection.reset()

            @jnius.java_method('()V')
            def close(self):
                self.pbar.close()
                self.collection.close()

            @jnius.java_method('()Z')
            def endOfCollection(self):
                return self.collection.endOfCollection()

            @jnius.java_method('()Lorg/terrier/indexing/Document;')
            def getDocument(self):
                # the module-level reference keeps the document alive after this method returns
                global lastdoc
                lastdoc = self.collection.getDocument()
                return lastdoc

        class PythonListIterator(jnius.PythonJavaClass):
            """
            java.util.Iterator over an indexable sequence of texts paired with an iterator of metadata.

            ``convertFn(text, meta)``, when given, builds the object returned for each
            position; otherwise ``[text, meta]`` is returned.
            """
            __javainterfaces__ = ['java/util/Iterator']

            def __init__(self, text, meta, convertFn, len=None, index=0):
                super(PythonListIterator, self).__init__()
                self.text = text
                self.meta = meta
                self.index = index
                self.convertFn = convertFn
                if len is None:
                    # the `len` parameter shadows the builtin inside this method, so the
                    # builtin has to be reached explicitly (calling `len(...)` here would call None)
                    import builtins
                    self.len = builtins.len(self.text)
                else:
                    self.len = len

            @jnius.java_method('()V')
            def remove(self):
                # removal is not supported; deliberately a no-op
                pass

            @jnius.java_method('(Ljava/util/function/Consumer;)V')
            def forEachRemaining(self, action):
                # not implemented; deliberately a no-op
                pass

            @jnius.java_method('()Z')
            def hasNext(self):
                return self.index < self.len

            @jnius.java_method('()Ljava/lang/Object;')
            def next(self):
                # the module-level reference keeps the returned object alive after this method returns
                global lastdoc
                text = self.text[self.index]
                meta = self.meta.__next__()
                self.index += 1
                if self.convertFn is not None:
                    lastdoc = self.convertFn(text, meta)
                else:
                    lastdoc = [text, meta]
                return lastdoc

        @pt.java.required
        class FlatJSONDocumentIterator(jnius.PythonJavaClass):
            """java.util.Iterator that wraps each item of a Python iterator in a FlatJSONDocument."""
            __javainterfaces__ = ['java/util/Iterator']

            def __init__(self, it):
                super(FlatJSONDocumentIterator, self).__init__()
                self._it = it
                # easiest way to support hasNext is just to start consuming right away, I think
                # (the StopIteration class itself is used as the end-of-iteration sentinel)
                self._next = next(self._it, StopIteration)

            @jnius.java_method('()V')
            def remove(self):
                # removal is not supported; deliberately a no-op
                pass

            @jnius.java_method('(Ljava/util/function/Consumer;)V')
            def forEachRemaining(self, action):
                # not implemented; deliberately a no-op
                pass

            @jnius.java_method('()Z')
            def hasNext(self):
                return self._next is not StopIteration

            @jnius.java_method('()Ljava/lang/Object;')
            def next(self):
                # the module-level reference keeps the document alive after this method returns
                global lastdoc
                result = self._next
                self._next = next(self._it, StopIteration)
                if result is not StopIteration:
                    lastdoc = pt.terrier.J.FlatJSONDocument(json.dumps(result))
                    return lastdoc
                return None

        class TQDMSizeCollection(jnius.PythonJavaClass):
            """Collection wrapper that reports progress, in documents, against a known total."""
            __javainterfaces__ = ['org/terrier/indexing/Collection']

            def __init__(self, collection, total):
                super(TQDMSizeCollection, self).__init__()
                self.collection = collection
                self.pbar = pt.tqdm(total=total, unit="documents")

            @jnius.java_method('()Z')
            def nextDocument(self):
                rtr = self.collection.nextDocument()
                self.pbar.update()
                return rtr

            @jnius.java_method('()V')
            def reset(self):
                self.pbar.reset()
                self.collection.reset()

            @jnius.java_method('()V')
            def close(self):
                self.pbar.close()
                self.collection.close()

            @jnius.java_method('()Z')
            def endOfCollection(self):
                return self.collection.endOfCollection()

            @jnius.java_method('()Lorg/terrier/indexing/Document;')
            def getDocument(self):
                # the module-level reference keeps the document alive after this method returns
                global lastdoc
                lastdoc = self.collection.getDocument()
                return lastdoc

        pt.terrier.index.DocListIterator = DocListIterator
        pt.terrier.index.PythonListIterator = PythonListIterator
        pt.terrier.index.FlatJSONDocumentIterator = FlatJSONDocumentIterator
        pt.terrier.index.TQDMCollection = TQDMCollection
        pt.terrier.index.TQDMSizeCollection = TQDMSizeCollection
def _new_indexref(s):
    """Recreates an IndexRef from its string form; used by _index_ref_reduce when unpickling."""
    indexref_class = pt.IndexRef
    return indexref_class.of(s)
@pt.java.required
def _new_wmodel(b):
    """
    Recreates a WeightingModel from the bytes produced by _wmodel_reduce.

    Arguments:
        b: the serialized form of the weighting model.
    """
    serialization = J.Serialization
    wmodel_class = J.ApplicationSetup.getClass("org.terrier.matching.models.WeightingModel")
    return serialization.deserialize(b, wmodel_class)
def _new_callable_wmodel(byterep):
    """
    Recreates a callable-backed weighting model from the bytes produced by
    _callable_wmodel_reduce.

    NOTE(review): this unpickles ``byterep`` with dill, which can execute arbitrary
    code -- only use it on trusted data.

    Arguments:
        byterep: the dill-serialized scoring function.
    """
    import dill
    # see https://github.com/SeldonIO/alibi/issues/447#issuecomment-881552005
    dill.extend(use_dill=False)
    scoring_fn = dill.loads(byterep)
    # we need to prevent these functions from being GCd, so references to both the
    # function and its callback are retained in the module-level list
    _SAVED_FNS.append(scoring_fn)
    callback, wmodel = pt.terrier.retriever._function2wmodel(scoring_fn)
    _SAVED_FNS.append(callback)
    return wmodel
def _iterableposting_next(self):
''' dunder method for iterating IterablePosting '''
nextid = self.next()
# 2147483647 is IP.EOL. fix this once static fields can be read from instances.
if 2147483647 == nextid:
raise StopIteration()
return self
def _lexicon_getitem(self, term):
''' dunder method for accessing Lexicon '''
rtr = self.getLexiconEntry(term)
if rtr is None:
raise KeyError()
return rtr
def _index_ref_reduce(self):
    """
    ``__reduce__`` implementation for IndexRef: pickles the reference as its string
    form, to be rebuilt by _new_indexref.
    """
    location = str(self.toString())
    constructor_args = (location,)
    return (_new_indexref, constructor_args, None)
# Pickling support for WeightingModel classes, which are themselves usually Serializable in Java
@pt.java.required
def _wmodel_reduce(self):
    """
    ``__reduce__`` implementation for WeightingModel: serializes the model on the
    Java side and pickles the resulting bytes, to be rebuilt by _new_wmodel.
    """
    java_serialized = J.Serialization.serialize(self)
    constructor_args = (bytes(java_serialized),)
    return (_new_wmodel, constructor_args, None)
def _callable_wmodel_reduce(self):
    """
    ``__reduce__`` implementation for CallableWeightingModel: pickles the serialized
    scoring function, to be rebuilt by _new_callable_wmodel.
    """
    # get the bytebuffer representation of the function and convert it to a Python
    # byte array; the Java buffer is not kept referenced once it has been converted
    bytesrep = pt.java.bytebuffer_to_array(self.scoringClass.serializeFn())
    constructor_args = (bytesrep,)
    return (_new_callable_wmodel, constructor_args, None)
@pt.java.required
def _index_add(self, other):
    """
    ``__add__`` implementation for Index: document-wise composition of two indices
    into a MultiIndex.

    Arguments:
        other: the index to combine with this one.

    Returns:
        a MultiIndex over ``[self, other]``.

    Raises:
        ValueError: if the two indices differ in their number of fields, or if only
            one of them has positions.
    """
    self_stats = self.getCollectionStatistics()
    other_stats = other.getCollectionStatistics()

    fields_1 = self_stats.getNumberOfFields()
    # this must be read from `other`; reading it from `self` would make the check below vacuous
    fields_2 = other_stats.getNumberOfFields()
    if fields_1 != fields_2:
        raise ValueError("Cannot document-wise merge indices with different numbers of fields (%d vs %d)" % (fields_1, fields_2))

    blocks_1 = self_stats.hasPositions()
    blocks_2 = other_stats.hasPositions()
    if blocks_1 != blocks_2:
        raise ValueError("Cannot document-wise merge indices with and without positions (%r vs %r)" % (blocks_1, blocks_2))

    return J.MultiIndex([self, other], blocks_1, fields_1 > 0)
def _index_corpusiter(self, return_toks=True):
def _index_corpusiter_meta(self):
meta_inputstream = self.getIndexStructureInputStream("meta")
keys = self.getMetaIndex().getKeys()
keys_offset = { k: offset for offset, k in enumerate(keys) }
while meta_inputstream.hasNext():
item = meta_inputstream.next()
yield {k : item[keys_offset[k]] for k in keys_offset}
def _index_corpusiter_direct_pretok(self):
meta_inputstream = self.getIndexStructureInputStream("meta")
keys = self.getMetaIndex().getKeys()
keys_offset = { k: offset for offset, k in enumerate(keys) }
keys_offset = {'docno' : keys_offset['docno'] }
direct_inputstream = self.getIndexStructureInputStream("direct")
lex = self.getLexicon()
ip = None
while (ip := direct_inputstream.getNextPostings()) is not None: # this is the next() method
# yield empty toks dicts for empty documents
for skipped in range(0, direct_inputstream.getEntriesSkipped()):
meta = meta_inputstream.next()
rtr = {k : meta[keys_offset[k]] for k in keys_offset}
rtr['toks'] = {}
yield rtr
toks = {}
while ip.next() != ip.EOL:
t, _ = lex[ip.getId()]
toks[t] = ip.getFrequency()
meta = meta_inputstream.next()
rtr = {'toks' : toks}
rtr.update({k : meta[keys_offset[k]] for k in keys_offset})
yield rtr
# yield for trailing empty documents
for skipped in range(0, direct_inputstream.getEntriesSkipped()):
meta = meta_inputstream.next()
rtr = {k : meta[keys_offset[k]] for k in keys_offset}
rtr['toks'] = {}
yield rtr
if return_toks:
if not self.hasIndexStructureInputStream("direct"):
raise ValueError("No direct index input stream available, cannot use return_toks=True")
return _index_corpusiter_direct_pretok(self)
return _index_corpusiter_meta(self)
[docs]
@pt.java.required
def extend_classpath(packages: Union[str, List[str]]):
    """
    Allows to add packages to Terrier's classpath after the JVM has started.

    Arguments:
        packages: a single package specification, or a list of them. Each is passed
            on unchanged to Terrier's MavenResolver plugin.
    """
    requested = [packages] if isinstance(packages, str) else packages
    assert check_version(5.3), "Terrier 5.3 required for this functionality"

    # copy the requested packages into a Java list
    java_packages = pt.java.J.ArrayList()
    for spec in requested:
        java_packages.add(spec)

    plugin = J.ApplicationSetup.getPlugin("MavenResolver")
    assert plugin is not None
    resolver = pt.java.cast("org.terrier.utility.MavenResolver", plugin)
    resolver.addDependencies(java_packages)
[docs]
@pt.java.required
def set_property(k, v):
    """
    Allows to set a property in Terrier's global properties configuration. Example::

        pt.set_property("termpipelines", "")

    While Terrier has a variety of properties -- as discussed in its
    `indexing <https://github.com/terrier-org/terrier-core/blob/5.x/doc/configure_indexing.md>`_
    and `retrieval <https://github.com/terrier-org/terrier-core/blob/5.x/doc/configure_retrieval.md>`_
    configuration guides -- in PyTerrier, we aim to expose Terrier configuration through appropriate
    methods or arguments. So this method should be seen as a safety-valve - a way to override the
    Terrier configuration not explicitly supported by PyTerrier.

    Both the key and the value are converted to strings before being stored, and
    Terrier's ApplicationSetup is re-bootstrapped with the updated properties.
    """
    value = str(v)
    key = str(k)
    _properties[key] = value
    J.ApplicationSetup.bootstrapInitialisation(_properties)
[docs]
@pt.java.required
def set_properties(kwargs):
    """
    Allows to set many properties in Terrier's global properties configuration.

    Arguments:
        kwargs: a dict of property names to values; both are converted to strings.
            Terrier's ApplicationSetup is re-bootstrapped once all have been stored.
    """
    for name, setting in kwargs.items():
        _properties[str(name)] = str(setting)
    application_setup = J.ApplicationSetup
    application_setup.bootstrapInitialisation(_properties)
@pt.java.required
def run(cmd, args=None):
    """
    Allows to run a Terrier executable class, i.e. one that can be accessed from the `bin/terrier` commandline programme.

    Arguments:
        cmd: the name of the command to run.
        args: the arguments for the command, as a list, tuple or other iterable of
            strings. Defaults to no arguments.
    """
    # a None default avoids a shared mutable default argument, and unpacking (rather
    # than list concatenation) accepts any iterable of arguments, not only lists
    cli_args = [cmd] if args is None else [cmd, *args]
    J.CLITool.main(cli_args)
@pt.java.required
def version():
    """
    Returns the version string from the underlying Terrier platform.
    """
    terrier_version_class = J.Version
    return terrier_version_class.VERSION
def check_version(min):
    """
    Returns True iff the underlying Terrier version is no older than the requested version.

    Arguments:
        min: the minimum acceptable version; it is converted to a string before
            being parsed. A "-SNAPSHOT" suffix on the running version is ignored.
    """
    reported = version()
    assert reported is not None, "Could not obtain Terrier version"
    installed = Version(reported.replace("-SNAPSHOT", ""))
    required = Version(str(min))
    return installed >= required
def check_helper_version(min):
    """
    Returns True iff the underlying Terrier helper version is no older than the requested version.

    Arguments:
        min: the minimum acceptable version; it is converted to a string before
            being parsed. A "-SNAPSHOT" suffix on the configured version is ignored.
    """
    configured = configure['helper_version']
    assert configured is not None, "Could not obtain Terrier helper version"
    # the configured version may not be a string (e.g. a float passed to
    # set_helper_version), so coerce it before stripping the suffix
    installed = Version(str(configured).replace("-SNAPSHOT", ""))
    required = Version(str(min))
    return installed >= required
# Terrier-specific classes
# Maps the short names used throughout this module (e.g. J.ApplicationSetup) to
# fully-qualified Java class names. A value may also be a callable returning the
# class name, for when the name depends on the Terrier version that is loaded.
# NOTE(review): presumably classes are resolved lazily on attribute access -- confirm
# against pt.java.JavaClasses.
J = pt.java.JavaClasses(
ApplicationSetup = 'org.terrier.utility.ApplicationSetup',
IndexRef = 'org.terrier.querying.IndexRef',
Version = 'org.terrier.Version',
# Tokenizer and Tokeniser (below) both refer to the same Java class
Tokenizer = 'org.terrier.indexing.tokenisation.Tokeniser',
Serialization = 'org.terrier.python.Serialization',
IndexOnDisk = 'org.terrier.structures.IndexOnDisk',
IndexFactory = 'org.terrier.structures.IndexFactory',
MultiIndex = 'org.terrier.realtime.multi.MultiIndex',
CLITool = 'org.terrier.applications.CLITool',
ApplyTermPipeline = 'org.terrier.querying.ApplyTermPipeline',
ManagerFactory = 'org.terrier.querying.ManagerFactory',
Request = 'org.terrier.querying.Request',
# Indexing
TaggedDocument = 'org.terrier.indexing.TaggedDocument',
FlatJSONDocument = 'org.terrier.indexing.FlatJSONDocument',
Tokeniser = 'org.terrier.indexing.tokenisation.Tokeniser',
TRECCollection = 'org.terrier.indexing.TRECCollection',
SimpleFileCollection = 'org.terrier.indexing.SimpleFileCollection',
BasicIndexer = 'org.terrier.structures.indexing.classical.BasicIndexer',
BlockIndexer = 'org.terrier.structures.indexing.classical.BlockIndexer',
BasicSinglePassIndexer = 'org.terrier.structures.indexing.singlepass.BasicSinglePassIndexer',
BlockSinglePassIndexer = 'org.terrier.structures.indexing.singlepass.BlockSinglePassIndexer',
# the memory indexer class name depends on whether Terrier is at least version 5.7
BasicMemoryIndexer = lambda: 'org.terrier.realtime.memory.MemoryIndexer' if check_version("5.7") else 'org.terrier.python.MemoryIndexer',
Collection = 'org.terrier.indexing.Collection',
StructureMerger = 'org.terrier.structures.merging.StructureMerger',
BlockStructureMerger = 'org.terrier.structures.merging.BlockStructureMerger',
DocumentPostingList = 'org.terrier.structures.indexing.DocumentPostingList',
MapEntry = 'org.terrier.structures.collections.MapEntry',
MultiDocumentFileCollection = 'org.terrier.indexing.MultiDocumentFileCollection',
CollectionFactory = 'org.terrier.indexing.CollectionFactory',
TagSet = 'org.terrier.utility.TagSet',
CollectionFromDocumentIterator = 'org.terrier.python.CollectionFromDocumentIterator',
PTUtils = 'org.terrier.python.PTUtils',
Index = 'org.terrier.structures.Index',
JsonlDocumentIterator = 'org.terrier.python.JsonlDocumentIterator',
JsonlPretokenisedIterator = 'org.terrier.python.JsonlPretokenisedIterator',
ParallelIndexer = 'org.terrier.python.ParallelIndexer',
# PRF
TerrierQLParser = 'org.terrier.querying.TerrierQLParser',
TerrierQLToMatchingQueryTerms = 'org.terrier.querying.TerrierQLToMatchingQueryTerms',
QueryResultSet = 'org.terrier.matching.QueryResultSet',
DependenceModelPreProcess = 'org.terrier.querying.DependenceModelPreProcess',
RM3 = 'org.terrier.querying.RM3',
)