diff --git a/core/common/search.py b/core/common/search.py index f7cb9d24..49deb066 100644 --- a/core/common/search.py +++ b/core/common/search.py @@ -1,6 +1,10 @@ +import gc +import inspect import re +import threading import time import urllib +from collections import OrderedDict from cid.locals import get_cid from django.conf import settings @@ -8,6 +12,7 @@ from elasticsearch_dsl import FacetedSearch, Q from pydash import compact, get, has, set_ from sentence_transformers import CrossEncoder +import torch from core.common.constants import ES_REQUEST_TIMEOUT from core.common.utils import is_url_encoded_string @@ -337,6 +342,8 @@ def __get_response(self, exact_count=True, load_fields=False): class Reranker: + """Rerank semantic search hits with model-specific score normalization.""" + ENCODERS = [ # Best and Fastest overall lightweight medical reranker # Size: ~110M @@ -368,6 +375,9 @@ class Reranker: ] SCORE_KEY = 'search_rerank_score' MISSING_SCORE = -1000000.0 + QWEN_RERANKER_PREFIX = 'Qwen/' + CUSTOM_ENCODER_CACHE = OrderedDict() + CUSTOM_ENCODER_CACHE_LOCK = threading.Lock() def __init__(self, model_name=None): self.model_name = model_name @@ -383,6 +393,11 @@ def rerank( # pylint: disable=too-many-arguments def default_model(self): return settings.ENCODER_MODEL_NAME + @classmethod + def _get_default_model_name(cls): + """Return the default boot-time reranker model name.""" + return settings.ENCODER_MODEL_NAME + # private def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_source_to_dict): # pylint: disable=too-many-arguments if not hits or not txt: @@ -399,12 +414,30 @@ def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_sourc valid.append((i, d.strip())) if not valid: return scores_full - scores = self.encoder.predict([(txt, d) for _, d in valid]) + scores = self.encoder.predict([(txt, d) for _, d in valid], **self._get_predict_kwargs()) for (i, _), s in zip(valid, scores): scores_full[i] = float(s) return scores_full + def _get_activation_fn(self): + """Return the score activation required by the configured reranker model.""" + model_name = self.model_name or self.default_model + if isinstance(model_name, str) and model_name.startswith(self.QWEN_RERANKER_PREFIX): + return torch.nn.Sigmoid() + return None + + def _get_predict_kwargs(self): + """Return compatibility kwargs for CrossEncoder.predict across library versions.""" + activation_fn = self._get_activation_fn() + if activation_fn is None: + return {} + + predict_signature = inspect.signature(self.encoder.predict) + if 'activation_fct' in predict_signature.parameters: + return {'activation_fct': activation_fn} + return {'activation_fn': activation_fn} + def _assign_score(self, hits, scores, score_key, order_results): score_key = score_key or self.SCORE_KEY key_to_set = score_key @@ -420,10 +453,31 @@ def _assign_score(self, hits, scores, score_key, order_results): def _order(hits, key_to_order): return sorted(hits, key=lambda hit: get(hit, key_to_order), reverse=True) - def _get_encoder(self, model_name): - if model_name and model_name != self.default_model: - return self._load_encoder(model_name) - return self._load_default_encoder() + @classmethod + def _get_encoder(cls, model_name): + if model_name and model_name != cls._get_default_model_name(): + return cls._get_custom_encoder(model_name) + return cls._load_default_encoder() + + @classmethod + def _get_custom_encoder(cls, model_name): + """Return a bounded cached custom encoder to avoid repeated large-model loads.""" + now = time.time() + with cls.CUSTOM_ENCODER_CACHE_LOCK: + cls._evict_expired_custom_encoders(now) + cached_encoder = cls.CUSTOM_ENCODER_CACHE.get(model_name) + if cached_encoder: + cls.CUSTOM_ENCODER_CACHE.move_to_end(model_name) + cached_encoder['expires_at'] = now + cls._get_custom_encoder_cache_ttl() + return cached_encoder['encoder'] + + cls._evict_custom_encoders_for_capacity() + encoder = cls._load_encoder(model_name) + cls.CUSTOM_ENCODER_CACHE[model_name] = { + 'encoder': encoder, + 'expires_at': now + cls._get_custom_encoder_cache_ttl(), + } + return encoder @staticmethod def _load_encoder(model_name): @@ -439,3 +493,41 @@ def _get_source(data, source_attr, should_convert_source_to_dict): if should_convert_source_to_dict and source: source = dict(source) return source + + @classmethod + def _get_custom_encoder_cache_size(cls): + """Return the max number of custom encoders that may stay loaded per process.""" + return max(1, int(getattr(settings, 'RERANKER_CUSTOM_ENCODER_CACHE_SIZE', 1))) + + @classmethod + def _get_custom_encoder_cache_ttl(cls): + """Return the idle TTL for custom encoders in seconds.""" + return max(1, int(getattr(settings, 'RERANKER_CUSTOM_ENCODER_CACHE_TTL', 60 * 5))) + + @classmethod + def _evict_custom_encoders_for_capacity(cls): + """Evict least-recently-used custom encoders before loading another large model.""" + while len(cls.CUSTOM_ENCODER_CACHE) >= cls._get_custom_encoder_cache_size(): + _, cached_encoder = cls.CUSTOM_ENCODER_CACHE.popitem(last=False) + cls._release_encoder(cached_encoder['encoder']) + + @classmethod + def _evict_expired_custom_encoders(cls, now=None): + """Remove expired custom encoders so idle large models do not stay resident forever.""" + now = now or time.time() + expired_models = [ + model_name for model_name, cached_encoder in cls.CUSTOM_ENCODER_CACHE.items() + if cached_encoder['expires_at'] <= now + ] + for model_name in expired_models: + cached_encoder = cls.CUSTOM_ENCODER_CACHE.pop(model_name, None) + if cached_encoder: + cls._release_encoder(cached_encoder['encoder']) + + @staticmethod + def _release_encoder(encoder): + """Release model references eagerly to give Python a chance to reclaim memory.""" + del encoder + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/core/concepts/tests/tests.py b/core/concepts/tests/tests.py index df8a3422..d62b2ad3 100644 --- a/core/concepts/tests/tests.py +++ b/core/concepts/tests/tests.py @@ -1,6 +1,7 @@ -from unittest.mock import patch, ANY +from unittest.mock import ANY, Mock, patch import factory +from django.test import override_settings from pydash import omit from core.collections.models import CollectionReference @@ -123,6 +124,102 @@ def test_reranker_uses_finite_fallback_score_for_missing_candidate_text(self, ge self.assertEqual(scores, [Reranker.MISSING_SCORE, Reranker.MISSING_SCORE]) + @patch.object(Reranker, '_get_encoder') + def test_reranker_uses_sigmoid_activation_for_qwen_models(self, get_encoder_mock): + encoder_mock = Mock() + encoder_mock.predict.return_value = [0.42] + get_encoder_mock.return_value = encoder_mock + reranker = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + + scores = reranker._predict_scores( # pylint: disable=protected-access + hits=[{'_source': {'name': 'malaria test'}}], + txt='malaria', + name_key='name', + source_attr='_source', + should_convert_source_to_dict=True, + ) + + self.assertEqual(scores, [0.42]) + _, kwargs = encoder_mock.predict.call_args + self.assertIsNotNone(kwargs['activation_fn']) + self.assertEqual(kwargs['activation_fn'].__class__.__name__, 'Sigmoid') + + @patch.object(Reranker, '_get_encoder') + def test_reranker_uses_legacy_activation_fct_kwarg_when_needed(self, get_encoder_mock): + class LegacyEncoder: + def predict(self, _pairs, activation_fct=None): + self.activation_fct = activation_fct + return [0.42] + + encoder = LegacyEncoder() + get_encoder_mock.return_value = encoder + reranker = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + + scores = reranker._predict_scores( # pylint: disable=protected-access + hits=[{'_source': {'name': 'malaria test'}}], + txt='malaria', + name_key='name', + source_attr='_source', + should_convert_source_to_dict=True, + ) + + self.assertEqual(scores, [0.42]) + self.assertIsNotNone(encoder.activation_fct) + self.assertEqual(encoder.activation_fct.__class__.__name__, 'Sigmoid') + + @override_settings(RERANKER_CUSTOM_ENCODER_CACHE_SIZE=1, RERANKER_CUSTOM_ENCODER_CACHE_TTL=300) + @patch.object(Reranker, '_load_encoder') + def test_custom_reranker_encoder_is_cached_between_requests(self, load_encoder_mock): + encoder = Mock() + load_encoder_mock.return_value = encoder + Reranker.CUSTOM_ENCODER_CACHE.clear() + + first = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + second = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + + self.assertIs(first.encoder, second.encoder) + load_encoder_mock.assert_called_once_with('Qwen/Qwen3-Reranker-0.6B') + Reranker.CUSTOM_ENCODER_CACHE.clear() + + @override_settings(RERANKER_CUSTOM_ENCODER_CACHE_SIZE=1, RERANKER_CUSTOM_ENCODER_CACHE_TTL=300) + @patch.object(Reranker, '_release_encoder') + @patch.object(Reranker, '_load_encoder') + def test_custom_reranker_encoder_eviction_releases_previous_model(self, load_encoder_mock, release_encoder_mock): + old_encoder = Mock() + new_encoder = Mock() + load_encoder_mock.side_effect = [old_encoder, new_encoder] + Reranker.CUSTOM_ENCODER_CACHE.clear() + + first = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + second = Reranker(model_name='BAAI/bge-reranker-v2-m3-custom') + + self.assertIs(first.encoder, old_encoder) + self.assertIs(second.encoder, new_encoder) + release_encoder_mock.assert_called_once_with(old_encoder) + self.assertEqual(list(Reranker.CUSTOM_ENCODER_CACHE.keys()), ['BAAI/bge-reranker-v2-m3-custom']) + Reranker.CUSTOM_ENCODER_CACHE.clear() + + @override_settings(RERANKER_CUSTOM_ENCODER_CACHE_SIZE=1, RERANKER_CUSTOM_ENCODER_CACHE_TTL=10) + @patch.object(Reranker, '_release_encoder') + @patch.object(Reranker, '_load_encoder') + @patch('core.common.search.time.time') + def test_custom_reranker_encoder_ttl_expiry_reloads_model( + self, time_mock, load_encoder_mock, release_encoder_mock): + first_encoder = Mock() + second_encoder = Mock() + time_mock.side_effect = [100, 111] + load_encoder_mock.side_effect = [first_encoder, second_encoder] + Reranker.CUSTOM_ENCODER_CACHE.clear() + + first = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + second = Reranker(model_name='Qwen/Qwen3-Reranker-0.6B') + + self.assertIs(first.encoder, first_encoder) + self.assertIs(second.encoder, second_encoder) + release_encoder_mock.assert_called_once_with(first_encoder) + self.assertEqual(load_encoder_mock.call_count, 2) + Reranker.CUSTOM_ENCODER_CACHE.clear() + def test_default_name_locales(self): es_locale = ConceptNameFactory.build(locale='es') en_locale = ConceptNameFactory.build(locale='en') diff --git a/core/settings.py b/core/settings.py index 3cd410a4..f104dbb8 100644 --- a/core/settings.py +++ b/core/settings.py @@ -642,6 +642,8 @@ def get_set_from_env(name): MINIO_SECURE = os.environ.get('MINIO_SECURE') == 'TRUE' NO_LM = os.environ.get('NO_LM') == 'TRUE' +RERANKER_CUSTOM_ENCODER_CACHE_SIZE = int(os.environ.get('RERANKER_CUSTOM_ENCODER_CACHE_SIZE', 1)) +RERANKER_CUSTOM_ENCODER_CACHE_TTL = int(os.environ.get('RERANKER_CUSTOM_ENCODER_CACHE_TTL', 60 * 5)) ENCODER_MODEL_NAME = None if ENV not in ['ci', 'demo'] and not NO_LM: LM_MODEL_NAME = 'all-MiniLM-L6-v2'