ocr.py
Overview
The ocr.py file implements a modular Optical Character Recognition (OCR) pipeline for analyzing document and text images. It detects text regions in an image, recognizes the text inside those regions, and manages the end-to-end OCR workflow. Model inference runs through ONNX Runtime on either CPU or GPU, and the module supports multi-device parallelism as well as several image preprocessing strategies tailored to different OCR model architectures.
Key capabilities include:
Text detection using a deep learning model to localize text boxes.
Text recognition using a separate model with multiple preprocessing options.
Utilities for image cropping, rotation, and normalization to prepare text regions.
Batch recognition for performance optimization.
Multi-GPU support and caching for model loading.
Integration with HuggingFace model hub for model retrieval.
A high-level OCR class that orchestrates the detection and recognition steps.
The file appears to be part of a larger document-understanding system, possibly within the InfiniFlow project, and depends on the related operators and postprocess modules.
Detailed Documentation
Functions
transform(data, ops=None)
Applies a sequence of operator functions to input data.
Parameters:
data (any): Initial input data to transform.
ops (list of callables, optional): Transformation functions (operators) to apply in sequence.
Returns:
The transformed data after all operators have been applied, or None if any operator returns None (the remaining operators are then skipped).
Example:
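import cv2

img = cv2.imread("doc.jpg")  # placeholder path; OpenCV returns an HWC image
data = {'image': img}
# 'order': 'hwc' tells NormalizeImage the image is still HWC (ToCHWImage runs afterwards)
ops = create_operators([{'NormalizeImage': {'order': 'hwc'}}, {'ToCHWImage': None}])
transformed_data = transform(data, ops)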
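The short-circuit behaviour of transform can be illustrated with a minimal stand-alone sketch. It mirrors the contract described above rather than reproducing the module's source, and the two toy operators (`mark_seen`, `drop_if_empty`) are hypothetical.

```python
from typing import Any, Callable, Iterable, Optional


def transform(data: Any, ops: Optional[Iterable[Callable[[Any], Any]]] = None) -> Any:
    """Apply each operator in order; stop at the first one that returns None."""
    for op in ops or []:
        data = op(data)
        if data is None:  # the operator rejected this sample
            return None
    return data


# Hypothetical toy operators that work on a dict payload.
def mark_seen(data: dict) -> dict:
    data["seen"] = True
    return data


def drop_if_empty(data: dict) -> Optional[dict]:
    return data if data.get("image") else None


print(transform({"image": [1, 2, 3]}, [mark_seen, drop_if_empty]))
# {'image': [1, 2, 3], 'seen': True}
print(transform({"image": []}, [mark_seen, drop_if_empty]))
# None
```

Returning None instead of raising lets a pipeline silently drop a sample that a preprocessing step cannot handle.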
create_operators(op_param_list, global_config=None)
Creates a list of operator instances based on configuration dictionaries.
Parameters:
op_param_list (list of dict): Each dict has a single key, the operator name, whose value is that operator's parameters (or None).
global_config (dict, optional): Additional parameters passed to every operator.
Returns:
list: List of operator instances.
Raises:
AssertionError if op_param_list is not a list or any entry is not a single-key dict.
Example:
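op_params = [
    # mean/std need one value per channel; 'hwc' matches the image layout at this stage
    {'NormalizeImage': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'order': 'hwc'}},
    {'ToCHWImage': None},
]
ops = create_operators(op_params)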
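A minimal sketch of the same construction step, assuming (as the parameter description implies) that each operator name is looked up as a class on the operators module and instantiated with its parameter dict as keyword arguments. The `SimpleNamespace` stand-in for that module and the toy classes `ToyScale` and `ToyIdentity` are hypothetical. This version also copies each per-operator dict before merging global_config, so the caller's config is not mutated; that is a choice made for the sketch.

```python
from types import SimpleNamespace


class ToyScale:
    """Hypothetical operator: multiplies every value in data["image"]."""

    def __init__(self, factor=1.0, **kwargs):
        self.factor = factor
        self.extra = kwargs  # global_config keys end up here

    def __call__(self, data):
        data["image"] = [v * self.factor for v in data["image"]]
        return data


class ToyIdentity:
    """Hypothetical operator that returns its input unchanged."""

    def __init__(self, **kwargs):
        self.extra = kwargs

    def __call__(self, data):
        return data


# Stand-in for the real operators module: names are resolved as attributes.
operators = SimpleNamespace(ToyScale=ToyScale, ToyIdentity=ToyIdentity)


def create_operators(op_param_list, global_config=None):
    assert isinstance(op_param_list, list), "operator config should be a list"
    ops = []
    for entry in op_param_list:
        assert isinstance(entry, dict) and len(entry) == 1, "each entry needs exactly one key"
        op_name = next(iter(entry))
        params = dict(entry[op_name] or {})  # None means "use the defaults"
        if global_config is not None:
            params.update(global_config)  # shared settings override per-op ones
        ops.append(getattr(operators, op_name)(**params))
    return ops
```

Combined with transform, the returned list can be applied to a data dict directly.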
load_model(model_dir, nm, device_id: int | None = None)
Loads an ONNX model with optional GPU support and caches it for reuse.
Parameters:
model_dir (str): Directory containing the model files.
nm (str): Model name prefix (without extension).
device_id (int or None): GPU device ID when using CUDA; None for CPU.
Returns:
tuple (onnxruntime.InferenceSession, onnxruntime.RunOptions): The loaded ONNX session and its run options.
Raises:
ValueError if the model file does not exist.
Implementation Details:
Uses ONNX Runtime with CPU or CUDA execution providers.
Limits GPU memory usage to 512 MB.
Enables memory arena shrinkage so that unused arena memory is released after each inference run.
Caches loaded models keyed by model path and device ID.
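The caching and device-selection logic can be sketched as below. This is not the module's code: the helper names (`model_cache_key`, `provider_config`, `load_model_sketch`) and the injectable `factory` argument are hypothetical, added so the cache can be exercised without real model files. The ONNX Runtime calls used (`InferenceSession` with a `(name, options)` provider tuple, and `RunOptions.add_run_config_entry` with the `memory.enable_memory_arena_shrinkage` key) are standard ONNX Runtime APIs; beyond the 512 MB limit and arena shrinkage listed above, the exact options ocr.py passes are an assumption.

```python
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

GPU_MEM_LIMIT = 512 * 1024 * 1024  # 512 MB cap, as stated above

# Process-wide cache: one entry per (model file, device) pair.
_loaded_models: Dict[str, Any] = {}


def model_cache_key(model_dir: str, nm: str, device_id: Optional[int]) -> str:
    path = os.path.join(model_dir, nm + ".onnx")
    return path if device_id is None else f"{path}{device_id}"


def provider_config(device_id: Optional[int]) -> Tuple[List[Any], str]:
    """Return (providers, arena-shrinkage target); None selects the CPU."""
    if device_id is None:
        return ["CPUExecutionProvider"], "cpu:0"
    cuda_options = {"device_id": device_id, "gpu_mem_limit": GPU_MEM_LIMIT}
    return [("CUDAExecutionProvider", cuda_options)], f"gpu:{device_id}"


def _onnx_factory(path: str, device_id: Optional[int]):
    if not os.path.exists(path):
        raise ValueError(f"model file not found: {path}")
    import onnxruntime as ort  # lazy import: the helpers above work without it

    providers, shrink_target = provider_config(device_id)
    session = ort.InferenceSession(path, sess_options=ort.SessionOptions(), providers=providers)
    run_options = ort.RunOptions()
    # Ask ORT to shrink this device's memory arena at the end of each run() call.
    run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", shrink_target)
    return session, run_options


def load_model_sketch(model_dir: str, nm: str, device_id: Optional[int] = None,
                      factory: Optional[Callable[[str, Optional[int]], Any]] = None):
    key = model_cache_key(model_dir, nm, device_id)
    if key in _loaded_models:  # reuse the session already built for this device
        return _loaded_models[key]
    loaded = (factory or _onnx_factory)(os.path.join(model_dir, nm + ".onnx"), device_id)
    _loaded_models[key] = loaded
    return loaded
```

Keying the cache on both path and device lets several GPUs each hold their own session for the same model file.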
Classes
TextRecognizer
Performs text recognition from cropped text images.
Constructor:
TextRecognizer(model_dir: str, device_id: int | None = None)
Loads the recognition model and its postprocessing configuration.
Initializes the input image shape and batch size.
Methods:
resize_norm_img(img, max_wh_ratio)
Resizes and normalizes an image for recognition, padding it to a fixed width derived from the maximum aspect ratio.
resize_norm_img_vl(img, image_shape)
Resizes an image to a specific shape and normalizes it (used by particular models).
resize_norm_img_srn(img, image_shape)
Specialized resizing and grayscale conversion for SRN models with variable width.
srn_other_inputs(image_shape, num_heads, max_text_length)
Constructs the positional encodings and attention bias matrices used as additional SRN model inputs.
process_image_srn(img, image_shape, num_heads, max_text_length)
Prepares the full set of inputs required for SRN model inference.
resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25)
Resizes and normalizes images for the SAR model with flexible width constraints.
resize_norm_img_spin(img)
Grayscale resizing and normalization suited to the SPIN model.
resize_norm_img_svtr(img, image_shape)
Resizes and normalizes an image for the SVTR model.
resize_norm_img_abinet(img, image_shape)
Resizes and normalizes images for the ABINet model, including channel-wise mean/std normalization.
norm_img_can(img, image_shape)
Normalizes images for the CAN model; pads grayscale images that are smaller than the expected size.
close()
Releases model resources and forces garbage collection.
__call__(img_list)
Recognizes text for a list of cropped images in batches. Returns a list of (text, score) tuples in input order, plus the elapsed time in seconds.
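The core of resize_norm_img can be sketched in NumPy. Assumptions, not taken from ocr.py: a 3-channel HWC input, a target height of 48 pixels (a hypothetical default), scaling of pixel values to [-1, 1], and a nearest-neighbour resize written in NumPy in place of whatever image-library resize the module uses. The idea taken from the description is that every crop in a batch is padded to the same width, `img_h * max_wh_ratio`.

```python
import math

import numpy as np


def resize_nearest(img: np.ndarray, new_w: int, new_h: int) -> np.ndarray:
    """Nearest-neighbour resize in plain NumPy (stand-in for an OpenCV resize)."""
    h, w = img.shape[:2]
    rows = (np.arange(new_h) * h / new_h).astype(int)
    cols = (np.arange(new_w) * w / new_w).astype(int)
    return img[rows][:, cols]


def resize_norm_img(img: np.ndarray, max_wh_ratio: float,
                    img_c: int = 3, img_h: int = 48) -> np.ndarray:
    """Resize to a fixed height, scale to [-1, 1], right-pad to the batch width."""
    assert img.shape[2] == img_c, "expected an HWC image with img_c channels"
    img_w = int(img_h * max_wh_ratio)  # padded width shared by the whole batch
    h, w = img.shape[:2]
    resized_w = min(img_w, int(math.ceil(img_h * w / float(h))))  # keep aspect ratio
    resized = resize_nearest(img, resized_w, img_h).astype(np.float32)
    resized = resized.transpose((2, 0, 1)) / 255.0  # HWC -> CHW, values in [0, 1]
    resized = (resized - 0.5) / 0.5  # [0, 1] -> [-1, 1]
    padded = np.zeros((img_c, img_h, img_w), dtype=np.float32)
    padded[:, :, :resized_w] = resized  # the right side stays 0 (mid-grey after scaling)
    return padded
```

Crops wider than the batch ratio are squeezed to the padded width rather than cropped, so no text is cut off.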
Usage Example:
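import cv2
from ocr import TextRecognizer  # adjust the import path to where ocr.py lives

crop_paths = ["crop_0.png", "crop_1.png"]  # hypothetical pre-cropped text-line images
recognizer = TextRecognizer(model_dir="/models/ocr", device_id=0)
cropped_images = [cv2.imread(path) for path in crop_paths]
results, elapsed = recognizer(cropped_images)
for text, score in results:
    if score > 0.5:
        print(text)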
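The batched recognition call can be sketched as follows: crops are sorted by width-to-height ratio so that images of similar width share a batch, each batch is recognized, and the results are scattered back to the caller's original order. The `recognize_batch` callable (model inference plus decoding) and the default batch size of 16 are hypothetical placeholders.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np


def recognize_in_batches(
    img_list: Sequence[np.ndarray],
    recognize_batch: Callable[[List[np.ndarray]], List[Tuple[str, float]]],
    batch_size: int = 16,
) -> List[Tuple[str, float]]:
    """Sort crops by width/height, recognize them in chunks, return results in input order."""
    ratios = np.array([img.shape[1] / float(img.shape[0]) for img in img_list])
    order = np.argsort(ratios)  # similar widths share a batch, so less padding
    results: List[Tuple[str, float]] = [("", 0.0)] * len(img_list)
    for start in range(0, len(img_list), batch_size):
        idx = order[start:start + batch_size]
        batch_out = recognize_batch([img_list[i] for i in idx])
        for i, res in zip(idx, batch_out):
            results[i] = res  # scatter back to the caller's original position
    return results
```

Because each batch is padded only to its own widest crop, grouping by aspect ratio keeps the wasted padded area small.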
TextDetector
Detects text boxes in an image.
Constructor:
TextDetector(model_dir: str, device_id: int | None = None)
Initializes the detection preprocessing pipeline and postprocessing.
Loads the detection ONNX model.
Methods:
order_points_clockwise(pts)
Orders the four points of a quadrilateral clockwise, starting from the top-left.
clip_det_res(points, img_height, img_width)
Clamps detected box points to the image boundaries.
filter_tag_det_res(dt_boxes, image_shape)
Orders and clips each box, then filters out boxes too small to contain text.
filter_tag_det_res_only_clip(dt_boxes, image_shape)
Clips boxes to the image bounds without filtering.
close()
Releases model resources and triggers garbage collection.
__call__(img)
Runs text detection on an input image and returns the filtered boxes along with the inference time.
Usage Example:
import cv2
from ocr import TextDetector  # adjust the import path to where ocr.py lives

detector = TextDetector(model_dir="/models/ocr", device_id=0)
image = cv2.imread("doc.jpg")
boxes, elapsed = detector(image)
for box in boxes:
    print(box)
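The box clean-up performed by order_points_clockwise, clip_det_res, and filter_tag_det_res can be sketched with NumPy as below. The corner-ordering trick (the top-left corner has the smallest x + y, the bottom-right the largest, and y - x separates the other two) is one standard way to implement the described ordering and is an assumption here, as is the exact size test (`<= 3` pixels).

```python
import numpy as np


def order_points_clockwise(pts) -> np.ndarray:
    """Order 4 corners as top-left, top-right, bottom-right, bottom-left."""
    pts = np.asarray(pts, dtype=np.float32)
    rect = np.zeros((4, 2), dtype=np.float32)
    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]  # smallest x + y: top-left
    rect[2] = pts[np.argmax(s)]  # largest x + y: bottom-right
    rest = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
    diff = rest[:, 1] - rest[:, 0]  # y - x for the two remaining corners
    rect[1] = rest[np.argmin(diff)]  # top-right: small y, large x
    rect[3] = rest[np.argmax(diff)]  # bottom-left: large y, small x
    return rect


def clip_det_res(points: np.ndarray, img_height: int, img_width: int) -> np.ndarray:
    """Clamp x to [0, width - 1] and y to [0, height - 1]."""
    clipped = points.copy()
    clipped[:, 0] = np.clip(clipped[:, 0], 0, img_width - 1)
    clipped[:, 1] = np.clip(clipped[:, 1], 0, img_height - 1)
    return clipped


def filter_tag_det_res(dt_boxes, image_shape, min_side: int = 3) -> np.ndarray:
    """Order and clip every box, then drop boxes with a side of min_side px or less."""
    img_height, img_width = image_shape[:2]
    kept = []
    for box in dt_boxes:
        box = clip_det_res(order_points_clockwise(box), img_height, img_width)
        width = int(np.linalg.norm(box[0] - box[1]))   # top edge length
        height = int(np.linalg.norm(box[0] - box[3]))  # left edge length
        if width <= min_side or height <= min_side:
            continue
        kept.append(box)
    return np.array(kept)
```

Ordering before clipping matters: the width and height checks assume box[1] is the top-right corner and box[3] the bottom-left.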
OCR
High-level class integrating detection and recognition into an end-to-end OCR pipeline.
Constructor:
OCR(model_dir: str | None = None)
Attempts to load the models from a local directory.
Falls back to downloading them from the HuggingFace Hub if loading fails.
Supports multi-device initialization based on the PARALLEL_DEVICES setting.
Initializes default thresholds (such as drop_score) and internal state.
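The constructor's device and fallback logic can be sketched with two small helpers. This assumes PARALLEL_DEVICES is an integer device count where 0 means no explicit devices, and that a missing device_id falls back to the first instance. `build_per_device`, `pick`, and `load_with_fallback` are hypothetical names, and the loader and downloader are passed in as callables so the sketch runs without models or network access.

```python
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def build_per_device(factory: Callable[[Optional[int]], T], parallel_devices: int) -> List[T]:
    """One instance per device if parallel_devices > 0, else a single default instance."""
    if parallel_devices > 0:
        return [factory(device_id) for device_id in range(parallel_devices)]
    return [factory(None)]  # None: let the loader pick its default device


def pick(instances: List[T], device_id: Optional[int]) -> T:
    """Route a request to the detector/recognizer owned by device_id (default 0)."""
    return instances[0 if device_id is None else device_id]


def load_with_fallback(load: Callable[[str], T], local_dir: str,
                       download: Callable[[], str]) -> T:
    """Try the local model directory first; on any failure, download and retry once."""
    try:
        return load(local_dir)
    except Exception:
        # e.g. huggingface_hub.snapshot_download, which returns the local folder path
        return load(download())
```

Keeping one detector and recognizer per device means concurrent calls with different device_id values never share an ONNX session.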
Methods:
get_rotate_crop_image(img, points)
Crops a text region from the image by perspective-transforming it into an upright rectangle and, if the aspect ratio suggests vertical text, tries rotating the crop.
sorted_boxes(dt_boxes)
Sorts detected boxes top-to-bottom and left-to-right into reading order.
detect(img, device_id: int | None = None)
Runs text detection and returns the detected boxes and timing information.
recognize(ori_im, box, device_id: int | None = None)
Recognizes the text within a single bounding box of the original image, applying cropping and rotation.
recognize_batch(img_list, device_id: int | None = None)
Recognizes text for a batch of cropped images.
__call__(img, device_id=0, cls=True)
Runs the full OCR process (detection followed by recognition) on the input image and returns a list of (box, (text, score)) pairs for the recognized text regions.
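sorted_boxes can be sketched as a two-pass sort: order boxes by the (y, x) coordinates of their top-left corner, then swap neighbours that sit on roughly the same line but are out of left-to-right order. The 10-pixel line tolerance (`y_tol`) is an illustrative assumption.

```python
from typing import List

import numpy as np


def sorted_boxes(dt_boxes: np.ndarray, y_tol: float = 10.0) -> List[np.ndarray]:
    """Reading order: sort by top-left (y, x), then fix near-equal rows left-to-right."""
    boxes = sorted(dt_boxes, key=lambda b: (b[0][1], b[0][0]))
    # Move a box leftwards while it is on the "same line" (|dy| < y_tol)
    # as its predecessor but starts further left.
    for i in range(len(boxes) - 1):
        for j in range(i, -1, -1):
            same_line = abs(boxes[j + 1][0][1] - boxes[j][0][1]) < y_tol
            if same_line and boxes[j + 1][0][0] < boxes[j][0][0]:
                boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j]
            else:
                break
    return boxes
```

The second pass is needed because a slightly higher box on the right of a line would otherwise be read before its left neighbour.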
Usage Example:
import cv2
from ocr import OCR  # adjust the import path to where ocr.py lives

ocr_engine = OCR(model_dir="/models/ocr")
image = cv2.imread("invoice.jpg")
results = ocr_engine(image)
for box, (text, score) in results:
    print(f"Text: {text}, Box: {box}")
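The orientation check in get_rotate_crop_image can be sketched as below, leaving out the perspective crop itself. The 1.5 height-to-width threshold for treating a crop as vertical and the `recognize` callable (returning `(text, score)`) are assumptions; `np.rot90` with `k=3` and `k=1` gives the clockwise and counterclockwise quarter turns.

```python
from typing import Callable, Tuple

import numpy as np


def pick_orientation(
    crop: np.ndarray,
    recognize: Callable[[np.ndarray], Tuple[str, float]],
    tall_ratio: float = 1.5,
) -> np.ndarray:
    """For tall crops, keep whichever of 0, 90 cw, 90 ccw rotations scores best."""
    h, w = crop.shape[:2]
    if h / float(w) < tall_ratio:
        return crop  # looks like horizontal text: leave it alone
    candidates = [
        crop,
        np.rot90(crop, k=3),  # three counterclockwise quarter turns = 90 deg clockwise
        np.rot90(crop, k=1),  # 90 deg counterclockwise
    ]
    scores = [recognize(c)[1] for c in candidates]
    return candidates[int(np.argmax(scores))]
```

Ties go to the earliest candidate, so the unrotated crop wins unless a rotation scores strictly higher.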
Important Implementation Details and Algorithms
Model Loading and Caching: Models are loaded via ONNX Runtime to run on the CPU or a CUDA GPU. Loaded sessions are cached globally, keyed by model path and device ID, so each model is loaded only once per device.
Image Preprocessing: Multiple specialized image resizing and normalization functions are implemented to accommodate different OCR model architectures (e.g., SRN, SAR, SPIN, SVTR, ABINet). This modular approach allows flexible adaptation to various model requirements.
Text Detection Postprocessing: Detected polygons are ordered clockwise and clipped to the image bounds to ensure valid crops. Boxes whose width or height is 3 pixels or less are filtered out.
Rotation Handling in Cropping: For tall (likely vertical) text regions, the get_rotate_crop_image method runs recognition on the original crop and on its 90-degree clockwise and counterclockwise rotations, and keeps the orientation with the highest recognition score.
Batch Recognition: Crops are sorted by aspect ratio and recognized in batches, so images of similar width share a batch; this improves GPU throughput and reduces padding overhead.
Multi-Device Parallelism: The OCR class supports multiple devices (GPUs) by keeping a separate detector and recognizer per device, so work on different devices can run in parallel in multi-GPU environments.
Error Handling and Retry: ONNX inference calls are wrapped in a retry loop (up to 3 retries, with a delay between attempts) to handle transient runtime failures.
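The retry behaviour can be expressed as a small wrapper. In this sketch the 5-second delay and the injectable `sleep` hook are illustrative assumptions. An ONNX Runtime call would be passed as, for example, `lambda: session.run(None, feeds, run_options)`, with `session` and `feeds` as placeholders; `InferenceSession.run(output_names, input_feed, run_options)` is the standard ONNX Runtime signature.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_retry(fn: Callable[[], T], max_retries: int = 3, delay_s: float = 5.0,
                   sleep: Callable[[float], None] = time.sleep) -> T:
    """Call fn(); on failure wait and retry, re-raising after max_retries retries."""
    attempt = 0
    while True:
        try:
            return fn()
        except Exception:
            if attempt >= max_retries:
                raise  # give up: propagate the last error
            attempt += 1
            sleep(delay_s)
```

With `max_retries=3`, a persistently failing call is attempted four times in total before the error propagates.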
Interaction with Other System Components
Dependencies:
operators module: Provides the operator classes used in preprocessing pipelines.
postprocess module: Provides postprocessing functions for detection and recognition outputs.
api.utils.file_utils: Used to resolve the project base directory.
rag.settings.PARALLEL_DEVICES: Configuration for multi-device support.
huggingface_hub.snapshot_download: Downloads the models when they are not available locally.
Integration Points:
The OCR pipeline is designed as a reusable component that other parts of the system can call when they need to extract text from images.
It can serve as a backend for higher-level document understanding or data extraction services.
The modular operator and postprocess architecture allows preprocessing and postprocessing steps to be adapted or extended easily.
Visual Diagram
Below is a class diagram illustrating the main classes, their core methods, and relationships in ocr.py.
classDiagram
class OCR {
-text_detector: list~TextDetector~
-text_recognizer: list~TextRecognizer~
-drop_score: float
+__init__(model_dir=None)
+detect(img, device_id)
+recognize(ori_im, box, device_id)
+recognize_batch(img_list, device_id)
+__call__(img, device_id, cls)
+get_rotate_crop_image(img, points)
+sorted_boxes(dt_boxes)
}
class TextDetector {
-preprocess_op: list
-postprocess_op
-predictor
-run_options
-input_tensor
+__init__(model_dir, device_id)
+__call__(img)
+order_points_clockwise(pts)
+clip_det_res(points, img_h, img_w)
+filter_tag_det_res(dt_boxes, image_shape)
+filter_tag_det_res_only_clip(dt_boxes, image_shape)
+close()
}
class TextRecognizer {
-rec_image_shape: list
-rec_batch_num: int
-postprocess_op
-predictor
-run_options
-input_tensor
+__init__(model_dir, device_id)
+__call__(img_list)
+resize_norm_img(img, max_wh_ratio)
+resize_norm_img_vl(img, image_shape)
+resize_norm_img_srn(img, image_shape)
+srn_other_inputs(image_shape, num_heads, max_text_length)
+process_image_srn(img, image_shape, num_heads, max_text_length)
+resize_norm_img_sar(img, image_shape, width_downsample_ratio)
+resize_norm_img_spin(img)
+resize_norm_img_svtr(img, image_shape)
+resize_norm_img_abinet(img, image_shape)
+norm_img_can(img, image_shape)
+close()
}
OCR "1" *-- "*" TextDetector : has
OCR "1" *-- "*" TextRecognizer : has
Summary
ocr.py implements an extensible OCR pipeline that runs deep-learning text detection and recognition models through ONNX Runtime. It offers configurable preprocessing, multi-device inference, and a modular design that integrates into larger document-processing systems. The OCR class is the primary interface, coordinating text detection and recognition with performance optimizations such as batched recognition and cropping with rotation correction.