|
|
|
|
| import fitz |
| import numpy as np |
| import cv2 |
| import torch |
| import torch.serialization |
| import os |
| import time |
| from typing import Optional, Tuple, List, Dict, Any |
| from ultralytics import YOLO |
| import logging |
| import gradio as gr |
| import shutil |
| import tempfile |
| import io |
|
|
| |
| |
| |
|
|
| |
# Keep a handle on the stock loader so the wrapper below can delegate to it.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Delegate to the original ``torch.load`` with ``weights_only`` forced to False.

    Any ``weights_only`` value supplied by the caller is overridden.

    NOTE(review): ``weights_only=False`` presumably re-enables full
    unpickling of the checkpoint -- confirm that only trusted weight files
    are ever loaded through this path.
    """
    forced_kwargs = {**kwargs, "weights_only": False}
    return _original_torch_load(*args, **forced_kwargs)


# Install the wrapper globally: every later ``torch.load`` call goes through it.
torch.load = patched_torch_load

# Root logger emits WARNING and above.
logging.basicConfig(level=logging.WARNING)
|
|
| |
| |
| |
|
|
# Path of the weights file handed to the YOLO constructor.
WEIGHTS_PATH: str = "best.pt"
# Zoom applied on both axes when a PDF page is rendered to a pixmap.
SCALE_FACTOR: float = 2.0

# Minimum confidence passed to model.predict.
CONF_THRESHOLD: float = 0.2
# Class names that are kept; detections of any other class are dropped.
TARGET_CLASSES: List[str] = ["figure", "equation"]
# Same-class boxes whose IoU exceeds this value are merged into one box.
IOU_MERGE_THRESHOLD: float = 0.4
# A smaller box is dropped when more than this fraction of its area lies
# inside a larger box.
IOA_SUPPRESSION_THRESHOLD: float = 0.7

# Running totals for the document being processed; reset at the start of
# every run.
GLOBAL_FIGURE_COUNT: int = 0
GLOBAL_EQUATION_COUNT: int = 0
|
|
| |
| |
| |
|
|
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Yields 0 when the combined area of the two boxes is not positive.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Clamp to zero so boxes that do not overlap give an empty intersection.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    overlap = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - overlap)

    if union > 0:
        return overlap / union
    return 0
|
|
|
|
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop boxes that lie mostly inside a larger box.

    Boxes are visited from largest to smallest area. A smaller box is
    suppressed when the fraction of its own area covered by an already kept
    larger box exceeds ``ioa_threshold``. Class labels are not compared.

    Args:
        detections: List of dicts, each with a ``'coords'`` entry of
            ``(x1, y1, x2, y2)``. Neither the list nor its dicts are
            modified.
        ioa_threshold: Intersection-over-area ratio above which the smaller
            box is removed.

    Returns:
        A new list of shallow-copied detection dicts, largest area first,
        each carrying an added ``'area'`` key.
    """
    if not detections:
        return []

    # Work on copies so the caller's list order and dicts stay untouched.
    ranked = []
    for det in detections:
        x1, y1, x2, y2 = det['coords']
        ranked.append({**det, 'area': (x2 - x1) * (y2 - y1)})
    ranked.sort(key=lambda d: d['area'], reverse=True)

    suppressed = [False] * len(ranked)
    kept = []
    for i, outer in enumerate(ranked):
        if suppressed[i]:
            continue
        kept.append(outer)
        ox1, oy1, ox2, oy2 = outer['coords']
        for j in range(i + 1, len(ranked)):
            if suppressed[j]:
                continue
            inner = ranked[j]
            ix1, iy1, ix2, iy2 = inner['coords']
            overlap = (
                max(0, min(ox2, ix2) - max(ox1, ix1))
                * max(0, min(oy2, iy2) - max(oy1, iy1))
            )
            # Zero-area boxes are never suppressed (also avoids dividing by 0).
            if inner['area'] > 0 and overlap / inner['area'] > ioa_threshold:
                suppressed[j] = True
    return kept
|
|
|
|
def merge_overlapping_boxes(detections, iou_threshold):
    """Merge same-class boxes that overlap a higher-confidence box.

    Detections are visited in descending confidence. Each unmerged detection
    becomes a seed; every later unmerged detection of the same class whose
    IoU with the seed exceeds ``iou_threshold`` is absorbed, and the result
    is the bounding rectangle of the seed and everything it absorbed.

    Args:
        detections: List of dicts with ``'coords'`` (``(x1, y1, x2, y2)``),
            ``'class'`` and ``'conf'`` entries. The list is not reordered.
        iou_threshold: IoU above which two same-class boxes are merged.

    Returns:
        A list of new dicts with keys ``'coords'``, ``'y1'``, ``'class'`` and
        ``'conf'`` (the seed's confidence), in descending seed confidence.
    """
    if not detections:
        return []

    # sorted() leaves the caller's list order intact; list.sort() would
    # reorder it in place as a hidden side effect.
    ranked = sorted(detections, key=lambda d: d['conf'], reverse=True)
    absorbed = [False] * len(ranked)
    merged_detections = []

    for i, seed in enumerate(ranked):
        if absorbed[i]:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        x1, y1, x2, y2 = seed_box
        for j in range(i + 1, len(ranked)):
            other = ranked[j]
            if absorbed[j] or other['class'] != seed_class:
                continue
            other_box = other['coords']
            # Overlap is measured against the seed box, not the growing union.
            if calculate_iou(seed_box, other_box) > iou_threshold:
                x1 = min(x1, other_box[0])
                y1 = min(y1, other_box[1])
                x2 = max(x2, other_box[2])
                y2 = max(y2, other_box[3])
                absorbed[j] = True
        merged_detections.append({
            'coords': (x1, y1, x2, y2),
            'y1': y1, 'class': seed_class, 'conf': seed['conf']
        })
    return merged_detections
|
|
| |
| |
| |
|
|
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Build an (h, w, n) uint8 array from a PyMuPDF pixmap's sample buffer.

    Four-channel input is converted with ``cv2.COLOR_RGBA2RGB`` and
    single-channel input with ``cv2.COLOR_GRAY2RGB``; any other channel
    count is returned as reshaped, without conversion.
    """
    channels = pix.n
    flat = np.frombuffer(pix.samples, dtype=np.uint8)
    img = flat.reshape((pix.h, pix.w, channels))

    if channels == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    return img
|
|
|
|
def run_yolo_detection_and_count(
    image: np.ndarray, model: YOLO, page_num: int
) -> Tuple[int, int]:
    """Detect figures and equations on one page image and count them.

    Runs ``model.predict`` at ``CONF_THRESHOLD``, keeps detections whose
    class name is in ``TARGET_CLASSES``, merges overlapping same-class boxes,
    drops nested boxes, and adds the surviving counts to the module-level
    totals.

    Args:
        image: Page image passed to ``model.predict``.
        model: Loaded YOLO model; ``model.names`` maps class ids to names.
        page_num: Page number used only in log messages.

    Returns:
        ``(page_equations, page_figures)``; ``(0, 0)`` if inference raised,
        in which case the module-level totals are left unchanged.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    candidates = []
    try:
        results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
        rows = results[0].boxes.data.tolist() if results and results[0].boxes else []
        # Each row is unpacked as x1, y1, x2, y2, confidence, class id.
        for x1, y1, x2, y2, conf, cls_id in rows:
            label = model.names[int(cls_id)]
            if label not in TARGET_CLASSES:
                continue
            candidates.append({
                'coords': (x1, y1, x2, y2),
                'class': label,
                'conf': conf
            })
    except Exception as e:
        logging.error(f"YOLO inference failed on page {page_num}: {e}")
        return 0, 0

    # Merge duplicates first, then remove boxes nested inside larger ones.
    survivors = filter_nested_boxes(
        merge_overlapping_boxes(candidates, IOU_MERGE_THRESHOLD),
        IOA_SUPPRESSION_THRESHOLD,
    )

    labels = [det['class'] for det in survivors]
    page_figures = labels.count('figure')
    page_equations = labels.count('equation')
    GLOBAL_FIGURE_COUNT += page_figures
    GLOBAL_EQUATION_COUNT += page_equations

    logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
    return page_equations, page_figures
|
|
|
|
| |
| |
| |
|
|
| |
def run_single_pdf_preprocessing(pdf_path: str) -> Tuple[int, int, int, str, float, Dict[str, int], List[str]]:
    """Count equations and figures on every page of one PDF.

    Loads the YOLO model from ``WEIGHTS_PATH``, renders each page at
    ``SCALE_FACTOR``, runs detection on it, and builds a Markdown report with
    totals and per-step timings. The module-level counters are reset at the
    start of every call.

    Args:
        pdf_path: Filesystem path of the PDF to process.

    Returns:
        A 7-tuple ``(total_pages, total_equations, total_figures, report,
        elapsed_seconds, equations_per_page, [])``. ``equations_per_page``
        maps the 1-based page number (as a string) to that page's equation
        count; pages whose rendering failed are absent. If the PDF is
        missing, or the model or PDF cannot be loaded, the result is
        ``(0, 0, 0, report, elapsed_seconds, {}, [])`` with the error text in
        ``report``. The last element is always an empty list.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    start_time = time.time()
    log_messages = []
    equation_counts_per_page: Dict[int, int] = {}

    # Totals are module-level state; start every run from zero.
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0

    # --- Model loading ---
    t0 = time.time()
    if not os.path.exists(pdf_path):
        report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
        return 0, 0, 0, report, time.time() - start_time, {}, []

    try:
        model = YOLO(WEIGHTS_PATH)
        logging.warning(f"✅ Loaded YOLO model from: {WEIGHTS_PATH}")
    except Exception as e:
        report = f"❌ ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
        return 0, 0, 0, report, time.time() - start_time, {}, []
    t1 = time.time()
    log_messages.append(f"Model Loading Time: {t1-t0:.4f}s")

    # --- PDF opening ---
    t2 = time.time()
    try:
        doc = fitz.open(pdf_path)
        total_pages = doc.page_count
        logging.warning(f"✅ Opened PDF with {doc.page_count} pages")
    except Exception as e:
        report = f"❌ ERROR loading PDF file: {e}"
        return 0, 0, 0, report, time.time() - start_time, {}, []
    t3 = time.time()
    log_messages.append(f"PDF Initialization Time: {t3-t2:.4f}s")

    mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)

    # --- Per-page render + detection ---
    t4 = time.time()
    try:
        for page_num_0_based in range(total_pages):
            page_start_time = time.time()
            fitz_page = doc.load_page(page_num_0_based)
            page_num = page_num_0_based + 1

            try:
                pix_start = time.time()
                pix = fitz_page.get_pixmap(matrix=mat)
                original_img = pixmap_to_numpy(pix)
                pix_time = time.time() - pix_start
            except Exception as e:
                # A page that cannot be rendered is skipped, not fatal.
                logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
                continue

            detect_start = time.time()
            page_equations, _ = run_yolo_detection_and_count(original_img, model, page_num)
            detect_time = time.time() - detect_start

            equation_counts_per_page[page_num] = page_equations

            page_total_time = time.time() - page_start_time
            log_messages.append(f"Page {page_num} Time: Total={page_total_time:.4f}s (Render={pix_time:.4f}s, Detect={detect_time:.4f}s)")
    finally:
        # Release the document even when a page load or detection raises.
        doc.close()
    t5 = time.time()
    detection_loop_time = t5 - t4
    log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")

    # String keys, as declared by the return annotation.
    equation_counts_per_page_str_keys: Dict[str, int] = {
        str(k): v for k, v in equation_counts_per_page.items()
    }

    total_execution_time = t5 - start_time

    report = (
        f"✅ **YOLO Counting Complete!**\n\n"
        f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
        f"**2) Total Equations Detected:** **{GLOBAL_EQUATION_COUNT}**\n"
        f"**3) Total Figures Detected:** **{GLOBAL_FIGURE_COUNT}**\n"
        f"---\n"
        f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
        f"### Detailed Step Timing\n"
        f"```\n"
        + "\n".join(log_messages) +
        f"\n```"
    )

    return total_pages, GLOBAL_EQUATION_COUNT, GLOBAL_FIGURE_COUNT, report, total_execution_time, equation_counts_per_page_str_keys, []
|
|
|
|
| |
| |
| |
|
|
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[str]]:
    """Gradio callback: run the counting pipeline on an uploaded PDF.

    Args:
        pdf_file: The uploaded file, or ``None`` when nothing was uploaded.
            Accepted either as a path string or as an object exposing the
            path through a ``name`` attribute.

    Returns:
        ``(pages, equations, figures, report, equations_per_page, [])`` with
        the three counts rendered as strings. When no file is given the
        counts are ``"N/A"``; when processing raises they are ``"Error"`` and
        ``report`` carries the error message. The last element is always an
        empty list.
    """
    if pdf_file is None:
        return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []

    # A plain str path has no ``.name``; fall back to the value itself so
    # both upload representations work.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    try:
        num_pages, num_equations, num_figures, report, _, equation_counts_per_page, _ = run_single_pdf_preprocessing(
            pdf_path
        )
        return str(num_pages), str(num_equations), str(num_figures), report, equation_counts_per_page, []
    except Exception as e:
        # Top-level UI boundary: log with traceback and show the message
        # to the user instead of crashing the request.
        error_msg = f"An unexpected error occurred: {e}"
        logging.error(error_msg, exc_info=True)
        return "Error", "Error", "Error", error_msg, {}, []
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # A missing weights file is only logged here; the UI is launched anyway.
    weights_present = os.path.exists(WEIGHTS_PATH)
    if not weights_present:
        logging.error(f"❌ FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")

    pdf_input = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])

    # Output components, in the same order as the tuple returned by
    # gradio_process_pdf.
    result_components = [
        gr.Textbox(label="Total Pages in PDF", interactive=False),
        gr.Textbox(label="Total Equations Detected", interactive=False),
        gr.Textbox(label="Total Figures Detected", interactive=False),
        gr.Markdown(label="Processing Summary and Timing"),
        gr.JSON(label="Equation Count Per Page (Dictionary)"),
        # The callback always returns an empty list for the gallery.
        gr.Gallery(
            label="Detected Equations (Disabled for Speed)",
            columns=5,
            height="auto",
            object_fit="contain",
            allow_preview=False,
        ),
    ]

    app = gr.Interface(
        fn=gradio_process_pdf,
        inputs=pdf_input,
        outputs=result_components,
        title="📊 YOLO Counting with Per-Page Data & Timing",
        description=(
            "Upload a PDF to run YOLO detection. The results include total counts, a breakdown of "
            "equation counts per page (in JSON format), and detailed timing."
        ),
    )

    print("\nStarting Gradio application...")
    app.launch(inbrowser=True)
|
|
|
|