Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 97 additions & 5 deletions core/common/search.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import gc
import inspect
import re
import threading
import time
import urllib
from collections import OrderedDict

from cid.locals import get_cid
from django.conf import settings
from django.db.models import Case, When, IntegerField
from elasticsearch_dsl import FacetedSearch, Q
from pydash import compact, get, has, set_
from sentence_transformers import CrossEncoder
import torch

from core.common.constants import ES_REQUEST_TIMEOUT
from core.common.utils import is_url_encoded_string
Expand Down Expand Up @@ -337,6 +342,8 @@ def __get_response(self, exact_count=True, load_fields=False):


class Reranker:
"""Rerank semantic search hits with model-specific score normalization."""

ENCODERS = [
# Best and Fastest overall lightweight medical reranker
# Size: ~110M
Expand Down Expand Up @@ -368,6 +375,9 @@ class Reranker:
]
SCORE_KEY = 'search_rerank_score'
MISSING_SCORE = -1000000.0
QWEN_RERANKER_PREFIX = 'Qwen/'
CUSTOM_ENCODER_CACHE = OrderedDict()
CUSTOM_ENCODER_CACHE_LOCK = threading.Lock()

def __init__(self, model_name=None):
self.model_name = model_name
Expand All @@ -383,6 +393,11 @@ def rerank( # pylint: disable=too-many-arguments
def default_model(self):
return settings.ENCODER_MODEL_NAME

@classmethod
def _get_default_model_name(cls):
"""Return the default boot-time reranker model name."""
return settings.ENCODER_MODEL_NAME

# private
def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_source_to_dict): # pylint: disable=too-many-arguments
if not hits or not txt:
Expand All @@ -399,12 +414,30 @@ def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_sourc
valid.append((i, d.strip()))
if not valid:
return scores_full
scores = self.encoder.predict([(txt, d) for _, d in valid])
scores = self.encoder.predict([(txt, d) for _, d in valid], **self._get_predict_kwargs())
for (i, _), s in zip(valid, scores):
scores_full[i] = float(s)

return scores_full

def _get_activation_fn(self):
"""Return the score activation required by the configured reranker model."""
model_name = self.model_name or self.default_model
if isinstance(model_name, str) and model_name.startswith(self.QWEN_RERANKER_PREFIX):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me this approach is too hardcoded, maybe its interesting to do somethins like:

RERANKER_SCORE_PROFILES = {
    "Qwen/": "logit_sigmoid",
    "BAAI/bge-reranker": "logit_sigmoid",
    "cross-encoder/ms-marco-": "raw_rank_score",
    "ncbi/MedCPT-Cross-Encoder": "logit_sigmoid",
}

But it also could be too YAGNI, so let's hear @paynejd thoughts

return torch.nn.Sigmoid()
return None

def _get_predict_kwargs(self):
"""Return compatibility kwargs for CrossEncoder.predict across library versions."""
activation_fn = self._get_activation_fn()
if activation_fn is None:
return {}

predict_signature = inspect.signature(self.encoder.predict)
if 'activation_fct' in predict_signature.parameters:
return {'activation_fct': activation_fn}
return {'activation_fn': activation_fn}

def _assign_score(self, hits, scores, score_key, order_results):
score_key = score_key or self.SCORE_KEY
key_to_set = score_key
Expand All @@ -420,10 +453,31 @@ def _assign_score(self, hits, scores, score_key, order_results):
def _order(hits, key_to_order):
return sorted(hits, key=lambda hit: get(hit, key_to_order), reverse=True)

def _get_encoder(self, model_name):
if model_name and model_name != self.default_model:
return self._load_encoder(model_name)
return self._load_default_encoder()
@classmethod
def _get_encoder(cls, model_name):
if model_name and model_name != cls._get_default_model_name():
return cls._get_custom_encoder(model_name)
return cls._load_default_encoder()

@classmethod
def _get_custom_encoder(cls, model_name):
"""Return a bounded cached custom encoder to avoid repeated large-model loads."""
now = time.time()
with cls.CUSTOM_ENCODER_CACHE_LOCK:
cls._evict_expired_custom_encoders(now)
cached_encoder = cls.CUSTOM_ENCODER_CACHE.get(model_name)
if cached_encoder:
cls.CUSTOM_ENCODER_CACHE.move_to_end(model_name)
cached_encoder['expires_at'] = now + cls._get_custom_encoder_cache_ttl()
return cached_encoder['encoder']

cls._evict_custom_encoders_for_capacity()
encoder = cls._load_encoder(model_name)
cls.CUSTOM_ENCODER_CACHE[model_name] = {
'encoder': encoder,
'expires_at': now + cls._get_custom_encoder_cache_ttl(),
}
return encoder

@staticmethod
def _load_encoder(model_name):
Expand All @@ -439,3 +493,41 @@ def _get_source(data, source_attr, should_convert_source_to_dict):
if should_convert_source_to_dict and source:
source = dict(source)
return source

@classmethod
def _get_custom_encoder_cache_size(cls):
"""Return the max number of custom encoders that may stay loaded per process."""
return max(1, int(getattr(settings, 'RERANKER_CUSTOM_ENCODER_CACHE_SIZE', 1)))

@classmethod
def _get_custom_encoder_cache_ttl(cls):
"""Return the idle TTL for custom encoders in seconds."""
return max(1, int(getattr(settings, 'RERANKER_CUSTOM_ENCODER_CACHE_TTL', 60 * 5)))

@classmethod
def _evict_custom_encoders_for_capacity(cls):
"""Evict least-recently-used custom encoders before loading another large model."""
while len(cls.CUSTOM_ENCODER_CACHE) >= cls._get_custom_encoder_cache_size():
_, cached_encoder = cls.CUSTOM_ENCODER_CACHE.popitem(last=False)
cls._release_encoder(cached_encoder['encoder'])

@classmethod
def _evict_expired_custom_encoders(cls, now=None):
"""Remove expired custom encoders so idle large models do not stay resident forever."""
now = now or time.time()
expired_models = [
model_name for model_name, cached_encoder in cls.CUSTOM_ENCODER_CACHE.items()
if cached_encoder['expires_at'] <= now
]
for model_name in expired_models:
cached_encoder = cls.CUSTOM_ENCODER_CACHE.pop(model_name, None)
if cached_encoder:
cls._release_encoder(cached_encoder['encoder'])

@staticmethod
def _release_encoder(encoder):
"""Release model references eagerly to give Python a chance to reclaim memory."""
del encoder
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
99 changes: 98 additions & 1 deletion core/concepts/tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import patch, ANY
from unittest.mock import ANY, Mock, patch

import factory
from django.test import override_settings
from pydash import omit

from core.collections.models import CollectionReference
Expand Down Expand Up @@ -123,6 +124,102 @@ def test_reranker_uses_finite_fallback_score_for_missing_candidate_text(self, ge

self.assertEqual(scores, [Reranker.MISSING_SCORE, Reranker.MISSING_SCORE])

@patch.object(Reranker, '_get_encoder')
def test_reranker_uses_sigmoid_activation_for_qwen_models(self, get_encoder_mock):
encoder_mock = Mock()
encoder_mock.predict.return_value = [0.42]
get_encoder_mock.return_value = encoder_mock
reranker = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')

scores = reranker._predict_scores( # pylint: disable=protected-access
hits=[{'_source': {'name': 'malaria test'}}],
txt='malaria',
name_key='name',
source_attr='_source',
should_convert_source_to_dict=True,
)

self.assertEqual(scores, [0.42])
_, kwargs = encoder_mock.predict.call_args
self.assertIsNotNone(kwargs['activation_fn'])
self.assertEqual(kwargs['activation_fn'].__class__.__name__, 'Sigmoid')

@patch.object(Reranker, '_get_encoder')
def test_reranker_uses_legacy_activation_fct_kwarg_when_needed(self, get_encoder_mock):
class LegacyEncoder:
def predict(self, _pairs, activation_fct=None):
self.activation_fct = activation_fct
return [0.42]

encoder = LegacyEncoder()
get_encoder_mock.return_value = encoder
reranker = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')

scores = reranker._predict_scores( # pylint: disable=protected-access
hits=[{'_source': {'name': 'malaria test'}}],
txt='malaria',
name_key='name',
source_attr='_source',
should_convert_source_to_dict=True,
)

self.assertEqual(scores, [0.42])
self.assertIsNotNone(encoder.activation_fct)
self.assertEqual(encoder.activation_fct.__class__.__name__, 'Sigmoid')

@override_settings(RERANKER_CUSTOM_ENCODER_CACHE_SIZE=1, RERANKER_CUSTOM_ENCODER_CACHE_TTL=300)
@patch.object(Reranker, '_load_encoder')
def test_custom_reranker_encoder_is_cached_between_requests(self, load_encoder_mock):
encoder = Mock()
load_encoder_mock.return_value = encoder
Reranker.CUSTOM_ENCODER_CACHE.clear()

first = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')
second = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')

self.assertIs(first.encoder, second.encoder)
load_encoder_mock.assert_called_once_with('Qwen/Qwen3-Reranker-0.6B')
Reranker.CUSTOM_ENCODER_CACHE.clear()

@override_settings(RERANKER_CUSTOM_ENCODER_CACHE_SIZE=1, RERANKER_CUSTOM_ENCODER_CACHE_TTL=300)
@patch.object(Reranker, '_release_encoder')
@patch.object(Reranker, '_load_encoder')
def test_custom_reranker_encoder_eviction_releases_previous_model(self, load_encoder_mock, release_encoder_mock):
old_encoder = Mock()
new_encoder = Mock()
load_encoder_mock.side_effect = [old_encoder, new_encoder]
Reranker.CUSTOM_ENCODER_CACHE.clear()

first = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')
second = Reranker(model_name='BAAI/bge-reranker-v2-m3-custom')

self.assertIs(first.encoder, old_encoder)
self.assertIs(second.encoder, new_encoder)
release_encoder_mock.assert_called_once_with(old_encoder)
self.assertEqual(list(Reranker.CUSTOM_ENCODER_CACHE.keys()), ['BAAI/bge-reranker-v2-m3-custom'])
Reranker.CUSTOM_ENCODER_CACHE.clear()

@override_settings(RERANKER_CUSTOM_ENCODER_CACHE_SIZE=1, RERANKER_CUSTOM_ENCODER_CACHE_TTL=10)
@patch.object(Reranker, '_release_encoder')
@patch.object(Reranker, '_load_encoder')
@patch('core.common.search.time.time')
def test_custom_reranker_encoder_ttl_expiry_reloads_model(
self, time_mock, load_encoder_mock, release_encoder_mock):
first_encoder = Mock()
second_encoder = Mock()
time_mock.side_effect = [100, 111]
load_encoder_mock.side_effect = [first_encoder, second_encoder]
Reranker.CUSTOM_ENCODER_CACHE.clear()

first = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')
second = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B')

self.assertIs(first.encoder, first_encoder)
self.assertIs(second.encoder, second_encoder)
release_encoder_mock.assert_called_once_with(first_encoder)
self.assertEqual(load_encoder_mock.call_count, 2)
Reranker.CUSTOM_ENCODER_CACHE.clear()

def test_default_name_locales(self):
es_locale = ConceptNameFactory.build(locale='es')
en_locale = ConceptNameFactory.build(locale='en')
Expand Down
2 changes: 2 additions & 0 deletions core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ def get_set_from_env(name):
MINIO_SECURE = os.environ.get('MINIO_SECURE') == 'TRUE'

NO_LM = os.environ.get('NO_LM') == 'TRUE'
RERANKER_CUSTOM_ENCODER_CACHE_SIZE = int(os.environ.get('RERANKER_CUSTOM_ENCODER_CACHE_SIZE', 1))
RERANKER_CUSTOM_ENCODER_CACHE_TTL = int(os.environ.get('RERANKER_CUSTOM_ENCODER_CACHE_TTL', 60 * 5))
ENCODER_MODEL_NAME = None
if ENV not in ['ci', 'demo'] and not NO_LM:
LM_MODEL_NAME = 'all-MiniLM-L6-v2'
Expand Down