import logging
import os
import re
import sys
import tempfile
import time
import warnings
from contextlib import contextmanager
from io import StringIO
from threading import Thread

import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
|
|
|
|
|
|
|
|
| |
# Suppress the "Some weights ... were not initialized" notice printed while
# loading checkpoints. This catches the message if it is raised through the
# ``warnings`` module.
warnings.filterwarnings('ignore', message='Some weights.*were not initialized')

# NOTE(review): in the transformers versions checked, this notice is emitted
# with ``logger.warning`` on the ``transformers.modeling_utils`` logger rather
# than ``warnings.warn``, so the filter above alone does not hide it. A logging
# filter on that logger drops only this message and keeps every other record.
# Confirm the logger name if transformers is upgraded.
_UNINITIALIZED_WEIGHTS_RE = re.compile(
    r"Some weights.*were not initialized", re.IGNORECASE
)


def _drop_uninitialized_weights_notice(record: logging.LogRecord) -> bool:
    """Return False (drop the record) for the uninitialized-weights notice.

    Mirrors the ``warnings`` filter semantics above: case-insensitive match
    anchored at the start of the message.
    """
    return _UNINITIALIZED_WEIGHTS_RE.match(record.getMessage()) is None


logging.getLogger("transformers.modeling_utils").addFilter(
    _drop_uninitialized_weights_notice
)

# Qwen3-VL is only present in recent transformers releases. Bind the name to
# None when it is missing so the Chandra-OCR loader can detect its absence.
try:
    from transformers import Qwen3VLForConditionalGeneration
except ImportError:
    Qwen3VLForConditionalGeneration = None
|
|
|
|
|
|
|
|
# Upper bound and default for the "Max New Tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048

# Environment-overridable settings.
# NOTE(review): neither value is referenced anywhere else in this file.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "./models")

# Device detected at import time; printed for diagnostics at startup.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f"Initial Device: {device}")
print(f"CUDA Available: {torch.cuda.is_available()}")
|
|
|
|
|
|
|
|
| |
# Chandra-OCR (Datalab) needs Qwen3VLForConditionalGeneration.
# ``processor_v`` and ``model_v`` stay None unless loading fully succeeds;
# generate_image treats ``model_v is None`` as "model not available".
MODEL_ID_V = "datalab-to/chandra"
processor_v = None
model_v = None
if Qwen3VLForConditionalGeneration is None:
    # Skip the processor download entirely: without the model class it
    # could never be used.
    print("✗ Chandra-OCR: Qwen3VL not available")
else:
    try:
        processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
        model_v = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_ID_V,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        ).eval()
        print("✓ Chandra-OCR loaded")
    except Exception as e:
        # Reset both handles so a half-loaded pair is never exposed.
        model_v = None
        processor_v = None
        print(f"✗ Chandra-OCR: Failed to load - {str(e)}")
|
|
|
|
|
|
|
|
| |
# Nanonets-OCR2-3B is served through the Qwen2.5-VL model class. Any failure
# resets both handles to None so generate_image reports it as unavailable.
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
try:
    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
    model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_X,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model_x.eval()  # inference mode; eval() mutates in place and returns self
    print("✓ Nanonets-OCR2-3B loaded")
except Exception as load_error:
    processor_x = model_x = None
    print(f"✗ Nanonets-OCR2-3B: Failed to load - {load_error!s}")
|
|
| |
# olmOCR-2-7B-1025 (Allen AI) is served through the Qwen2.5-VL model class.
# Any failure resets both handles to None so generate_image reports it as
# unavailable.
MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
try:
    processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
    model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_M,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model_m.eval()  # inference mode; eval() mutates in place and returns self
    print("✓ olmOCR-2-7B-1025 loaded")
except Exception as load_error:
    processor_m = model_m = None
    print(f"✗ olmOCR-2-7B-1025: Failed to load - {load_error!s}")
|
|
|
|
|
|
|
|
def _select_model(model_name):
    """Look up the (processor, model) pair registered under ``model_name``.

    Returns:
        tuple: ``(processor, model, error)``. On success ``error`` is None.
        Otherwise ``error`` is a user-facing message and the first two items
        are None.
    """
    registry = {
        "olmOCR-2-7B-1025": (processor_m, model_m),
        "Nanonets-OCR2-3B": (processor_x, model_x),
        "Chandra-OCR": (processor_v, model_v),
    }
    entry = registry.get(model_name)
    if entry is None:
        return None, None, "Invalid model selected."
    processor, model = entry
    if model is None:
        return None, None, f"{model_name} is not available."
    return processor, model, None


def _generate_in_background(model, streamer, generation_kwargs, errors):
    """Run ``model.generate`` on a worker thread, always closing the stream.

    On success ``generate`` ends the streamer itself. If it raises, nothing
    would end the stream and the consumer iterating the streamer would block
    forever. So the exception is appended to ``errors`` for the caller to
    re-raise, and the stream is ended explicitly.
    """
    try:
        model.generate(**generation_kwargs)
    except Exception as exc:  # reported to the consumer via ``errors``
        errors.append(exc)
        streamer.end()


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float):
    """
    Generates responses using the selected model for image input.
    Yields raw text and Markdown-formatted text (currently the same string).
    This function is decorated with @spaces.GPU to ensure it runs on GPU
    when available in Hugging Face Spaces.
    Args:
        model_name: Name of the OCR model to use
        text: Prompt text for the model
        image: PIL Image object to process
        max_new_tokens: Maximum number of tokens to generate (coerced to int)
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter (coerced to int)
        repetition_penalty: Penalty for repeating tokens
    Yields:
        tuple: (raw_text, markdown_text). The accumulated output is yielded
        after each streamed chunk. On failure a single
        "Error during generation: ..." message is yielded instead.
    """
    processor, model, error = _select_model(model_name)
    if error is not None:
        yield error, error
        return

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    try:
        # A single user turn: the image slot followed by the prompt text.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ],
        }]

        try:
            prompt_full = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception as template_error:
            # NOTE(review): the fallback prompt has no image placeholder
            # tokens, so the processor call below may still reject it.
            print(f"Chat template error: {template_error}. Using fallback prompt.")
            prompt_full = f"{text}"

        # Models are loaded with device_map="auto", so send inputs to wherever
        # the model reports it lives rather than assuming cuda:0.
        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True,
        ).to(model.device)

        streamer = TextIteratorStreamer(
            getattr(processor, "tokenizer", processor),
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # UI sliders may deliver floats; generate() needs integer token counts
        # and top_k.
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
        }

        errors = []
        thread = Thread(
            target=_generate_in_background,
            args=(model, streamer, generation_kwargs, errors),
            daemon=True,
        )
        thread.start()

        buffer = ""
        for new_text in streamer:
            # Strip the end-of-turn marker over the whole buffer, in case it
            # is split across streamed chunks.
            buffer = (buffer + new_text).replace("<|im_end|>", "")
            yield buffer, buffer

        thread.join()
        if errors:
            # Surface the worker-thread failure through the handler below.
            raise errors[0]

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(f"Full error: {e}")
        import traceback
        traceback.print_exc()
        yield error_msg, error_msg
|
|
|
|
|
|
|
|
| |
if __name__ == "__main__":
    import gradio as gr

    # (dropdown label, loaded model) in display order. Only models that
    # loaded successfully are offered to the user.
    model_status = (
        ("olmOCR-2-7B-1025", model_m),
        ("Nanonets-OCR2-3B", model_x),
        ("Chandra-OCR", model_v),
    )
    available_models = []
    for label, loaded in model_status:
        if loaded is not None:
            available_models.append(label)
            print(f" Added: {label}")

    if not available_models:
        print("ERROR: No models were loaded successfully!")
        # sys.exit, not the site-provided exit(): the latter is missing under
        # `python -S` and in some embedded/frozen interpreters.
        sys.exit(1)

    print(f"\n✓ Available models for dropdown: {', '.join(available_models)}")

    with gr.Blocks(title="Multi-Model OCR") as demo:
        gr.Markdown("# 🔍 Multi-Model OCR Application")
        gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")

        with gr.Row():
            # Left column: inputs and generation settings.
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="Select OCR Model"
                )
                image_input = gr.Image(type="pil", label="Upload Image")
                text_input = gr.Textbox(
                    value="Extract all text from this image.",
                    label="Prompt",
                    lines=2
                )

                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=1,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top P"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Repetition Penalty"
                    )

                submit_btn = gr.Button("Extract Text", variant="primary")

            # Right column: streamed results.
            with gr.Column():
                output_text = gr.Textbox(label="Extracted Text", lines=20)
                output_markdown = gr.Markdown(label="Formatted Output")

        gr.Markdown("""
        ### Available Models:
        - **olmOCR-2-7B-1025**: Allen AI's OCR model
        - **Nanonets-OCR2-3B**: Nanonets OCR model
        - **Chandra-OCR**: Datalab OCR model
        """)

        # generate_image is a generator, so each yield updates both outputs.
        submit_btn.click(
            fn=generate_image,
            inputs=[
                model_selector,
                text_input,
                image_input,
                max_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty
            ],
            outputs=[output_text, output_markdown]
        )

    demo.launch(share=True)
|
|