rerank_model.py
Overview
The rerank_model.py file provides a comprehensive framework for performing reranking of search results or text documents based on their relevance to a query. It defines an abstract base class and multiple concrete implementations of reranking models that interact with either local models or remote APIs. These models compute similarity scores between a query and a list of candidate texts, returning a relevance ranking that can be used to reorder search results or improve information retrieval systems.
Key features:
Abstract base class Base defines the interface for all rerankers.
Concrete implementations include local model wrappers (e.g., DefaultRerank, YoudaoRerank) and API clients for external services (e.g., JinaRerank, NvidiaRerank, CoHereRerank).
Support for batch processing with dynamic batch size adjustment to optimize GPU memory usage.
Utilities for token counting and text truncation to manage input size limits.
Integration with various third-party model providers, including HuggingFace, Jina, Nvidia, BaiduYiyan, and others.
Exception handling and logging to aid debugging.
This file is a core component of a retrieval-augmented generation (RAG) or search system, allowing flexible use of different reranking backends under a common interface.
Classes and Methods
class Base(ABC)
Abstract base class for all reranking models.
Methods:
__init__(self, key, model_name, **kwargs): Abstract constructor. Does not store parameters; subclasses handle initialization.
similarity(self, query: str, texts: list) -> np.ndarray: Abstract method to compute similarity scores between a query and a list of texts. Must be implemented by subclasses. The concrete implementations documented below return a tuple of (scores, total token count) rather than a bare array.
total_token_count(self, resp) -> int: Helper to extract total token count from a response object or dictionary.
Parameters:
resp: Response object or dictionary from a reranking API or model.
Returns:
Total token count as an integer (0 if unavailable).
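As an illustration of the helper described above, here is a minimal sketch of how a token count might be pulled from either an SDK-style object or a parsed JSON dictionary. The `usage.total_tokens` field name is an assumption about the response shape, not something this document confirms, and the function is a standalone stand-in rather than the actual method.

```python
from types import SimpleNamespace


def total_token_count(resp) -> int:
    """Return the total token count reported by a response, or 0 if unavailable."""
    # Attribute-style access, e.g. an SDK object exposing resp.usage.total_tokens
    # (field name assumed for illustration).
    try:
        return int(resp.usage.total_tokens)
    except Exception:
        pass
    # Dictionary-style access, e.g. a parsed JSON body.
    try:
        return int(resp["usage"]["total_tokens"])
    except Exception:
        pass
    # Neither shape matched: report 0 rather than raising.
    return 0


sdk_style = SimpleNamespace(usage=SimpleNamespace(total_tokens=12))
dict_style = {"usage": {"total_tokens": 7}}
```

Falling back to 0 keeps token accounting from failing a rerank call when a provider omits usage data.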
class DefaultRerank(Base)
Local reranker using the FlagReranker model (from FlagEmbedding).
Attributes:
_FACTORY_NAME = "BAAI"
_model: Class-level singleton instance of the loaded model.
_model_lock: Threading lock to synchronize model loading.
_dynamic_batch_size: Current batch size for processing.
_min_batch_size: Minimum batch size allowed.
Constructor:
def __init__(self, key, model_name, **kwargs):
Loads the model from local cache or downloads from HuggingFace if necessary. Uses FP16 if CUDA is available.
Methods:
torch_empty_cache(self): Clears PyTorch CUDA cache to free GPU memory safely.
_process_batch(self, pairs, max_batch_size=None) -> np.ndarray: Processes pairs of (query, text) in batches, dynamically adjusting batch size on CUDA OOM errors.
Parameters:
pairs: List of (query, text) tuples.
max_batch_size: Optional maximum batch size.
Returns:
NumPy array of similarity scores.
_compute_batch_scores(self, batch_pairs, max_length=None) -> list: Computes similarity scores for a batch of pairs using the model's compute_score method.
similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]: Computes similarity scores between a single query and multiple texts.
Parameters:
query: The query string.
texts: List of candidate text strings.
Returns:
Tuple of:
NumPy array of similarity scores.
Total token count used in the input texts.
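The adaptive batching described for _process_batch can be sketched without a GPU. The version below is a simplified illustration, not the actual implementation: `MemoryError` stands in for a CUDA out-of-memory error, the batch size is halved on failure, and retries are counted per batch. All three choices are assumptions, as are the function and parameter names.

```python
def process_in_batches(pairs, score_batch, batch_size=8, min_batch_size=1, max_retries=5):
    """Score all pairs, shrinking the batch size when score_batch runs out of memory."""
    scores = []
    i = 0
    retries = 0
    while i < len(pairs):
        batch = pairs[i:i + batch_size]
        try:
            scores.extend(score_batch(batch))
        except MemoryError:
            # Stand-in for a CUDA OOM error: give up once the retry budget
            # is spent or the batch cannot shrink any further.
            if retries >= max_retries or batch_size <= min_batch_size:
                raise
            batch_size = max(batch_size // 2, min_batch_size)
            retries += 1
            continue  # retry the same slice with the smaller batch size
        i += len(batch)
        retries = 0  # a successful batch resets the retry budget
    return scores, batch_size


def fake_scorer(batch):
    """Hypothetical scorer that fails on batches larger than two pairs."""
    if len(batch) > 2:
        raise MemoryError("simulated out-of-memory")
    return [float(len(text)) for _, text in batch]


pairs = [("q", "a"), ("q", "bb"), ("q", "ccc"), ("q", "dddd"), ("q", "eeeee")]
scores, final_batch_size = process_in_batches(pairs, fake_scorer, batch_size=8)
```

Returning the final batch size lets a caller keep the reduced value for later calls, which is roughly the role _dynamic_batch_size appears to play.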
class JinaRerank(Base)
Reranker that calls the Jina API for similarity scoring.
Attributes:
_FACTORY_NAME = "Jina"
base_url: API endpoint URL, defaults to "https://api.jina.ai/v1/rerank".
headers: HTTP headers including authorization.
model_name: Name of the rerank model.
Constructor:
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank")
Method:
similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]: Sends a POST request to the Jina API with the query and texts, receiving relevance scores.
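To make the request/response flow concrete, the sketch below builds a rerank payload and maps a response back to per-text scores. It makes no network call. The response shape (a `results` list whose items carry `index` and `relevance_score`) is an assumption about the API, and both helper names are hypothetical.

```python
def build_rerank_payload(model_name, query, texts):
    """Assemble the JSON body: model name, query, documents, and requested top N."""
    return {"model": model_name, "query": query, "documents": texts, "top_n": len(texts)}


def scores_from_response(body, n_texts):
    """Place each returned score at the position of its original text."""
    scores = [0.0] * n_texts
    for item in body["results"]:
        # Results may arrive sorted by relevance, so use the reported index
        # (assumed field names) instead of the position in the list.
        scores[item["index"]] = item["relevance_score"]
    return scores


texts = ["Rising sea levels are a concern.", "Global warming impacts agriculture."]
payload = build_rerank_payload("jina-reranker-v2-base-multilingual", "Climate change effects", texts)

# In a real client the payload would be sent with something like
# requests.post(base_url, headers=headers, json=payload).json();
# here a hand-written body stands in for the response.
sample_body = {
    "results": [
        {"index": 1, "relevance_score": 0.9},
        {"index": 0, "relevance_score": 0.2},
    ],
    "usage": {"total_tokens": 15},
}
scores = scores_from_response(sample_body, len(texts))
```

Writing scores by index keeps the output aligned with the input order, so callers can pair each score with its text.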
class YoudaoRerank(DefaultRerank)
Local reranker using the RerankerModel from BCEmbedding.
Similar to DefaultRerank but uses a different underlying model and truncation length.
class XInferenceRerank(Base)
Reranker calling an inference API endpoint.
Handles URL normalization for endpoint path /v1/rerank.
Sends POST request with query and documents.
Parses results similarly to JinaRerank.
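One plausible form of the URL normalization mentioned above is sketched here. The accepted input shapes (bare host, host ending in /v1, or a full endpoint) are assumptions, and `normalize_rerank_url` is a hypothetical helper, not a function from the module.

```python
def normalize_rerank_url(base_url: str) -> str:
    """Return base_url ending in exactly one /v1/rerank path."""
    url = base_url.rstrip("/")
    if url.endswith("/v1/rerank"):
        return url  # already a full endpoint
    if url.endswith("/v1"):
        return url + "/rerank"  # only the final segment is missing
    return url + "/v1/rerank"  # bare host: append the whole path
```

This lets a user configure either the server root or the full endpoint and still reach the same path.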
class LocalAIRerank(Base)
Reranker calling a LocalAI API.
Normalizes URL.
Uses fixed truncation length (500 tokens).
Normalizes scores to range [0,1].
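The document does not say which normalization is used. Min-max scaling is one common way to map raw scores onto [0,1], and the sketch below assumes it. The handling of empty and constant inputs is likewise an assumption, and plain lists are used in place of NumPy arrays to keep the example dependency-free.

```python
def normalize_scores(scores):
    """Min-max scale scores into [0, 1]."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # All scores equal: no spread to scale, so return zeros (assumed convention).
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]
```

Scaling this way makes scores from backends with different raw ranges (for example logits versus probabilities) comparable within a single result list, though not across separate queries.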
class NvidiaRerank(Base)
Reranker calling NVIDIA's AI API.
Supports two model names with different API paths.
Sends POST request with query and passages.
Retrieves logits as scores.
class LmStudioRerank(Base)
Stub class for LM-Studio reranker.
Not implemented.
class OpenAI_APIRerank(Base)
Reranker compatible with OpenAI API style.
Truncates input texts to a length of 500.
Normalizes output scores.
class CoHereRerank(Base)
Reranker using Cohere's API client.
Uses cohere.Client.
Calls rerank method of the client.
class TogetherAIRerank(Base)
Stub class for TogetherAI reranker.
Not implemented.
class SILICONFLOWRerank(Base)
Reranker calling SiliconFlow API.
Sends POST request with detailed parameters.
Returns combined input and output token counts.
class BaiduYiyanRerank(Base)
Reranker using Baidu Yiyan API.
Initializes client with access keys.
Calls the client's .do() method.
class VoyageRerank(Base)
Reranker using the Voyage AI client.
Calls client.rerank() method.
class QWenRerank(Base)
Reranker using Tongyi-Qianwen API via dashscope.
Calls dashscope.TextReRank.call().
Handles HTTP status codes and errors.
class HuggingfaceRerank(DefaultRerank)
Local reranker that sends requests to a local HuggingFace rerank server.
Static method post() sends batched POST requests.
Constructor sets base URL.
similarity() calls the static method.
class GPUStackRerank(Base)
Reranker calling GPUStack API.
Raises error if base URL missing.
Sends POST request to /v1/rerank.
Parses results.
class NovitaRerank(JinaRerank)
Subclass of JinaRerank with a different default URL.
class GiteeRerank(JinaRerank)
Subclass of JinaRerank with a different default URL.
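The pattern used by NovitaRerank and GiteeRerank (same behaviour, different default endpoint) can be sketched as follows. Both classes here are hypothetical stand-ins, the second URL is a placeholder, and the Bearer-token header is an assumption about the authorization scheme.

```python
class JinaStyleRerank:
    """Hypothetical stand-in for a Jina-style client; not the real JinaRerank."""

    def __init__(self, key, model_name, base_url="https://api.jina.ai/v1/rerank"):
        self.base_url = base_url
        self.model_name = model_name
        # Bearer-token header is an assumed auth scheme for illustration.
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }


class OtherProviderRerank(JinaStyleRerank):
    """Reuses the parent's logic and only swaps the default endpoint (placeholder URL)."""

    def __init__(self, key, model_name, base_url="https://rerank.example.invalid/v1/rerank"):
        super().__init__(key, model_name, base_url)


client = OtherProviderRerank(key="your_api_key", model_name="some-model")
```

Because only the default argument changes, any provider that mirrors the same request and response format can be supported with a few lines.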
class Ai302Rerank(Base)
Reranker calling 302.AI API.
Constructor normalizes URL.
Inherits base initialization.
Important Implementation Details
Batch Processing in DefaultRerank: Uses adaptive batch sizing to handle GPU memory constraints. On CUDA out-of-memory errors, reduces batch size and retries up to 5 times.
Token Counting: Token counts are computed using num_tokens_from_string for input length tracking and API usage reporting.
Text Truncation: Input texts are truncated to model-specific maximum lengths to prevent exceeding token limits.
API Interaction: Most remote rerankers send POST requests with JSON payloads including model name, query, documents, and requested top N results.
Score Normalization: Some rerankers normalize scores to the [0,1] range to standardize relevance scores.
Thread Safety: Static model instances for local rerankers are protected by threading locks to avoid multiple loads.
Error Handling: Exceptions during API calls or model computation are logged with log_exception for diagnostics.
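The thread-safety point above (a class-level model guarded by a lock) might look like the following sketch. Double-checked locking is an assumed design choice, `LazyModelHolder` is a hypothetical class, and a plain `object()` stands in for the expensive model load.

```python
import threading


class LazyModelHolder:
    """Loads a shared model at most once, even under concurrent access."""

    _model = None
    _model_lock = threading.Lock()
    load_count = 0  # counts how many times the "expensive" load ran

    @classmethod
    def get_model(cls):
        # Fast path: skip the lock once the model exists.
        if cls._model is None:
            with cls._model_lock:
                # Re-check inside the lock: another thread may have loaded it
                # while this one was waiting.
                if cls._model is None:
                    cls.load_count += 1
                    cls._model = object()  # stand-in for loading real weights
        return cls._model


threads = [threading.Thread(target=LazyModelHolder.get_model) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Without the second check inside the lock, two threads that both saw `_model is None` could each load the model in turn.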
Interactions with Other System Components
Settings & Utilities:
Uses settings.LIGHTEN flag to conditionally load heavy models.
Uses utility functions get_home_cache_dir, num_tokens_from_string, truncate for filesystem and text handling.
Uses log_exception for robust error logging.
External Dependencies:
Relies on third-party APIs and clients (requests, httpx, huggingface_hub, cohere, dashscope, etc.) for reranking services.
Uses local model classes (FlagReranker, RerankerModel) from external packages.
API Keys & Auth:
Each reranker uses API keys or tokens for authorization headers, supporting secure access.
Model Caching & Downloading:
Uses HuggingFace snapshot downloads and local cache directories to manage model files.
Normalization & Preprocessing:
Applies text truncation and token counting consistently before sending data to models/APIs.
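The truncate-then-count preprocessing can be illustrated with whitespace tokens. The real num_tokens_from_string and truncate utilities presumably rely on a model tokenizer, so the helpers below are hypothetical simplifications. Counting tokens after truncation rather than before is also an assumption.

```python
def count_tokens_ws(text: str) -> int:
    """Approximate token count using whitespace-separated words (illustrative only)."""
    return len(text.split())


def truncate_ws(text: str, max_tokens: int) -> str:
    """Keep at most max_tokens whitespace-separated words."""
    return " ".join(text.split()[:max_tokens])


def prepare_texts(texts, max_tokens):
    """Truncate every text, then total the tokens that will actually be sent."""
    truncated = [truncate_ws(t, max_tokens) for t in texts]
    total = sum(count_tokens_ws(t) for t in truncated)
    return truncated, total


prepared, token_total = prepare_texts(["a b c d", "e f"], max_tokens=3)
```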
Usage Examples
Using DefaultRerank (Local Model)
reranker = DefaultRerank(key="dummy", model_name="BAAI/bge-reranker-v2-m3")
query = "What is AI?"
texts = ["AI stands for Artificial Intelligence.", "Machine learning is a subset of AI."]
scores, token_count = reranker.similarity(query, texts)
print(scores)
Using JinaRerank (API)
reranker = JinaRerank(key="your_api_key")
query = "Climate change effects"
texts = ["Rising sea levels are a concern.", "Global warming impacts agriculture."]
scores, token_count = reranker.similarity(query, texts)
print(scores)
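Reordering Results by Score
The scores returned by similarity() line up with the input texts, so reordering is a sort over indices. The helper below is a hypothetical sketch, not part of the module. It works with any indexable sequence of scores, including a NumPy array.

```python
def rank_by_score(texts, scores, top_n=None):
    """Return (text, score) pairs ordered from most to least relevant."""
    order = sorted(range(len(texts)), key=lambda i: scores[i], reverse=True)
    if top_n is not None:
        order = order[:top_n]  # keep only the best top_n candidates
    return [(texts[i], scores[i]) for i in order]


ranked = rank_by_score(["a", "b", "c"], [0.1, 0.9, 0.5])
best = rank_by_score(["a", "b", "c"], [0.1, 0.9, 0.5], top_n=1)
```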
Mermaid Class Diagram
classDiagram
class Base {
<<abstract>>
+__init__(key, model_name, **kwargs)
+similarity(query: str, texts: list)
+total_token_count(resp)
}
class DefaultRerank {
-_FACTORY_NAME: str = "BAAI"
-_model
-_model_lock
-_dynamic_batch_size: int
-_min_batch_size: int
+__init__(key, model_name, **kwargs)
+torch_empty_cache()
-_process_batch(pairs, max_batch_size=None)
-_compute_batch_scores(batch_pairs, max_length=None)
+similarity(query: str, texts: list)
}
class JinaRerank {
-_FACTORY_NAME: str = "Jina"
-base_url: str
-headers: dict
-model_name: str
+__init__(key, model_name, base_url)
+similarity(query: str, texts: list)
}
class YoudaoRerank {
-_FACTORY_NAME: str = "Youdao"
-_model
-_model_lock
+__init__(key, model_name, **kwargs)
+similarity(query: str, texts: list)
}
class XInferenceRerank {
-_FACTORY_NAME: str = "Xinference"
-base_url: str
-headers: dict
-model_name: str
+__init__(key, model_name, base_url)
+similarity(query: str, texts: list)
}
class LocalAIRerank {
-_FACTORY_NAME: str = "LocalAI"
-base_url: str
-headers: dict
-model_name: str
+__init__(key, model_name, base_url)
+similarity(query: str, texts: list)
}
class NvidiaRerank {
-_FACTORY_NAME: str = "NVIDIA"
-base_url: str
-headers: dict
-model_name: str
+__init__(key, model_name, base_url)
+similarity(query: str, texts: list)
}
class HuggingfaceRerank {
-_FACTORY_NAME: str = "HuggingFace"
-base_url: str
-model_name: str
+__init__(key, model_name, base_url)
+similarity(query: str, texts: list)
+post(query: str, texts: list, url)
}
%% Inheritance
DefaultRerank --|> Base
YoudaoRerank --|> DefaultRerank
JinaRerank --|> Base
XInferenceRerank --|> Base
LocalAIRerank --|> Base
NvidiaRerank --|> Base
HuggingfaceRerank --|> DefaultRerank
Summary
The rerank_model.py file defines a modular and extensible reranking system with multiple backend implementations. It abstracts interaction with local ML models and remote APIs behind a consistent interface for generating relevance scores. This flexibility enables the larger system to adapt to different deployment environments and external service providers while maintaining consistent functionality.
The adaptive batching and careful token management support efficient and scalable reranking workflows, critical for real-time search and retrieval applications.