from .model import coerce_queries_dataframe
from .batchretrieve import BatchRetrieveBase
from warnings import warn
import pandas as pd
from . import tqdm
# Set to True once _init_anserini() has run its one-time jnius/pyserini patching.
anserini_monkey = False


def _init_anserini():
    """
    One-time preparation of jnius/pyserini before Anserini classes are used.

    Verifies that an Anserini jar is on the jnius classpath. It then replaces
    ``pyserini.setup.configure_classpath`` and ``jnius_config.set_classpath``
    with no-op (identity) functions, presumably so that pyserini cannot alter
    the classpath PyTerrier has already configured. Once this has succeeded,
    later calls return immediately.

    Raises:
        AssertionError: if no jnius classpath entry contains ``"anserini"``.
    """
    global anserini_monkey
    if anserini_monkey:
        return

    # jnius monkeypatching
    import jnius_config
    anserini_found = any("anserini" in entry for entry in jnius_config.get_classpath())
    if not anserini_found:
        # Raised explicitly rather than via ``assert`` so the check still runs
        # under ``python -O``; AssertionError is kept so callers see the same
        # exception type as before.
        raise AssertionError(
            'Anserini jar not found: you should start PyTerrier, e.g. with '
            + 'pt.init(boot_packages=["io.anserini:anserini:0.22.0:fatjar"])')
    import pyserini.setup
    pyserini.setup.configure_classpath = lambda x: x
    jnius_config.set_classpath = lambda x: x
    anserini_monkey = True
    # NOTE(review): the Anserini early rank-cutoff rewrite rule is not
    # registered here (it was dead code after this function's final
    # ``return``). Call _register_anserini_rank_cutoff_rule() at this point to
    # enable it, after confirming matchpy and ``.transformer.rewrite_rules``
    # are available.


def _register_anserini_rank_cutoff_rule():
    """
    Append the Anserini early rank-cutoff rule to ``.transformer.rewrite_rules``.

    The matchpy rule matches ``RankCutoffTransformer(<AnseriniBatchRetrieve>, x)``.
    It replaces that expression with the retriever itself, after setting the
    retriever's ``k`` to ``int(x.value)``. Not currently called; see the note
    at the end of _init_anserini.
    """
    from matchpy import Wildcard, ReplacementRule, Pattern
    from .transformer import RankCutoffTransformer, rewrite_rules
    x = Wildcard.dot('x')
    _brAnserini = Wildcard.symbol('_brAnserini', AnseriniBatchRetrieve)

    def set_k(_brAnserini, x):
        # Fold the cutoff into the retriever so it fetches only k results.
        _brAnserini.k = int(x.value)
        return _brAnserini

    rewrite_rules.append(ReplacementRule(
        Pattern(RankCutoffTransformer(_brAnserini, x)),
        set_k
    ))
[docs]class AnseriniBatchRetrieve(BatchRetrieveBase):
"""
Allows retrieval from an Anserini index. To use this class, PyTerrier should have been started using `pt.init(boot_packages=["io.anserini:anserini:0.22.0:fatjar"])`.
"""
def __init__(self, index_location, k=1000, wmodel="BM25", **kwargs):
    """
    Create a retriever backed by ``pyserini.search.lucene.LuceneSearcher``.

    Args:
        index_location(str): Path of the Anserini index to search.
        k(int): Number of results to return per query. Defaults to 1000.
        wmodel(str): Anserini weighting model; one of:

            * `"BM25"` - the BM25 weighting model
            * `"QLD"` - Dirichlet language modelling
            * `"TFIDF"` - Lucene's `ClassicSimilarity <https://lucene.apache.org/core/8_1_0/core/org/apache/lucene/search/similarities/ClassicSimilarity.html>`_.

    Raises:
        AssertionError: if no Anserini jar is on the jnius classpath.
        ValueError: if ``wmodel`` is not one of the values listed above.
    """
    # NOTE(review): kwargs is forwarded as ONE positional dict rather than
    # unpacked with **kwargs. Confirm this matches BatchRetrieveBase.__init__;
    # if its first parameter is e.g. ``verbose``, options such as
    # verbose=False would be misrouted into it.
    super().__init__(kwargs)
    self.index_location, self.k = index_location, k
    # The jnius/pyserini patching runs before pyserini's Lucene module is
    # imported; presumably that ordering is required.
    _init_anserini()
    from pyserini.search.lucene import LuceneSearcher
    searcher = LuceneSearcher(index_location)
    self.searcher = searcher
    self.wmodel = wmodel
    self._setsimilarty(wmodel)
def __reduce__(self):
return (
self.__class__,
(self.index_location, self.k, self.wmodel),
self.__getstate__()
)
def __getstate__(self):
return {
'k' : self.k,
'wmodel' : self.wmodel,
'index_location' : self.index_location,
}
def __setstate__(self, d):
self.k = d["k"]
self.wmodel = d["wmodel"]
self.index_location = d["index_location"]
def _setsimilarty(self, wmodel):
if wmodel == "BM25":
self.searcher.set_bm25(k1=0.9, b=0.4)
elif wmodel == "QLD":
self.searcher.object.set_qld(1000.0)
elif wmodel == "TFIDF":
from jnius import autoclass
self.searcher.object.similarty = autoclass("org.apache.lucene.search.similarities.ClassicSimilarity")()
else:
raise ValueError("wmodel %s not support in AnseriniBatchRetrieve" % wmodel)
def _getsimilarty(self, wmodel):
    """
    Build a new Lucene ``Similarity`` object for a named weighting model.

    Args:
        wmodel(str): ``"BM25"``, ``"QLD"`` or ``"TFIDF"``.

    Returns:
        A Java object created via jnius ``autoclass``. This is
        ``BM25Similarity(0.9, 0.4)``, ``LMDirichletSimilarity(1000.0)`` or
        ``ClassicSimilarity()`` respectively.

    Raises:
        ValueError: for any other ``wmodel``.
    """
    from jnius import autoclass
    # (name, Java class, constructor args). The parameter values are fixed
    # constants matching those _setsimilarty applies; they are not read back
    # from the searcher.
    candidates = (
        ("BM25", "org.apache.lucene.search.similarities.BM25Similarity", (0.9, 0.4)),
        ("QLD", "org.apache.lucene.search.similarities.LMDirichletSimilarity", (1000.0,)),
        ("TFIDF", "org.apache.lucene.search.similarities.ClassicSimilarity", ()),
    )
    for name, java_class, ctor_args in candidates:
        if wmodel == name:
            return autoclass(java_class)(*ctor_args)
    raise ValueError("wmodel %s not support in AnseriniBatchRetrieve" % wmodel)
def __str__(self):
return "AnseriniBatchRetrieve()"
def __repr__(self):
return "AnseriniBatchRetrieve("+self.wmodel + ","+self.k+")"