Instructions to use MRiabov/WireSegHR with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use MRiabov/WireSegHR with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-segmentation", model="MRiabov/WireSegHR")# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("MRiabov/WireSegHR", dtype="auto") - Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import pprint | |
| import time | |
| from typing import List, Tuple, Optional, Dict, Any | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.amp import autocast | |
| from tqdm import tqdm | |
| from safetensors.torch import load_file as safe_load_file | |
| import yaml | |
| from src.wireseghr.model import WireSegHR | |
| from pathlib import Path | |
| from src.wireseghr.metrics import compute_metrics | |
| def _pad_for_minmax(kernel: int) -> Tuple[int, int, int, int]: | |
| # Replicate the padding logic from train.validate for even/odd kernels | |
| if (kernel % 2) == 0: | |
| return (kernel // 2 - 1, kernel // 2, kernel // 2 - 1, kernel // 2) | |
| else: | |
| return (kernel // 2, kernel // 2, kernel // 2, kernel // 2) | |
| def _coarse_forward( | |
| model: WireSegHR, | |
| img_rgb: np.ndarray, | |
| coarse_size: int, | |
| minmax_enable: bool, | |
| minmax_kernel: int, | |
| device: torch.device, | |
| amp_flag: bool, | |
| amp_dtype, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| # Convert to tensor on device | |
| t_img = ( | |
| torch.from_numpy(np.transpose(img_rgb, (2, 0, 1))) | |
| .unsqueeze(0) | |
| .to(device) | |
| .float() | |
| ) # 1x3xHxW | |
| H = img_rgb.shape[0] | |
| W = img_rgb.shape[1] | |
| rgb_c = F.interpolate( | |
| t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False | |
| )[0] | |
| y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3] | |
| if minmax_enable: | |
| pad = _pad_for_minmax(minmax_kernel) | |
| y_p = F.pad(y_t, pad, mode="replicate") | |
| y_max_full = F.max_pool2d(y_p, kernel_size=minmax_kernel, stride=1) | |
| y_min_full = -F.max_pool2d(-y_p, kernel_size=minmax_kernel, stride=1) | |
| else: | |
| y_min_full = y_t | |
| y_max_full = y_t | |
| y_min_c = F.interpolate( | |
| y_min_full, | |
| size=(coarse_size, coarse_size), | |
| mode="bilinear", | |
| align_corners=False, | |
| )[0] | |
| y_max_c = F.interpolate( | |
| y_max_full, | |
| size=(coarse_size, coarse_size), | |
| mode="bilinear", | |
| align_corners=False, | |
| )[0] | |
| zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device) | |
| x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0) | |
| with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag): | |
| logits_c, cond_map = model.forward_coarse(x_t) | |
| prob = torch.softmax(logits_c, dim=1)[:, 1:2] | |
| prob_up = ( | |
| F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0] | |
| .detach() | |
| .cpu() | |
| .float() | |
| ) # HxW torch.Tensor on CPU | |
| return prob_up, cond_map, t_img, y_min_full, y_max_full | |
| def _tiled_fine_forward( | |
| model: WireSegHR, | |
| t_img: torch.Tensor, # 1x3xHxW on device | |
| cond_map: torch.Tensor, # 1x1xhxw | |
| y_min_full: torch.Tensor, # 1x1xHxW | |
| y_max_full: torch.Tensor, # 1x1xHxW | |
| patch_size: int, | |
| overlap: int, | |
| fine_batch: int, | |
| device: torch.device, | |
| amp_flag: bool, | |
| amp_dtype, | |
| ) -> torch.Tensor: | |
| H = int(t_img.shape[2]) | |
| W = int(t_img.shape[3]) | |
| P = patch_size | |
| stride = P - overlap | |
| assert stride > 0 | |
| assert H >= P and W >= P | |
| prob_sum_t = torch.zeros((H, W), device=device, dtype=torch.float32) | |
| weight_t = torch.zeros((H, W), device=device, dtype=torch.float32) | |
| hc4, wc4 = cond_map.shape[2], cond_map.shape[3] | |
| ys = list(range(0, H - P + 1, stride)) | |
| if ys[-1] != (H - P): | |
| ys.append(H - P) | |
| xs = list(range(0, W - P + 1, stride)) | |
| if xs[-1] != (W - P): | |
| xs.append(W - P) | |
| coords: List[Tuple[int, int]] = [] | |
| for y0 in ys: | |
| for x0 in xs: | |
| coords.append((y0, x0)) | |
| for i0 in range(0, len(coords), fine_batch): | |
| batch_coords = coords[i0 : i0 + fine_batch] | |
| xs_list: List[torch.Tensor] = [] | |
| for y0, x0 in batch_coords: | |
| y1, x1 = y0 + P, x0 + P | |
| # Map to cond grid | |
| y0c = (y0 * hc4) // H | |
| y1c = ((y1 * hc4) + H - 1) // H | |
| x0c = (x0 * wc4) // W | |
| x1c = ((x1 * wc4) + W - 1) // W | |
| cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c].float() | |
| cond_patch = F.interpolate( | |
| cond_sub, size=(P, P), mode="bilinear", align_corners=False | |
| ).squeeze(1) # 1xPxP | |
| rgb_t = t_img[0, :, y0:y1, x0:x1] # 3xPxP | |
| ymin_t = y_min_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP | |
| ymax_t = y_max_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP | |
| x_f = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0).unsqueeze(0) | |
| xs_list.append(x_f) | |
| x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP | |
| with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag): | |
| logits_f = model.forward_fine(x_f_batch) | |
| prob_f = torch.softmax(logits_f, dim=1)[:, 1:2] | |
| prob_f_up = F.interpolate( | |
| prob_f, size=(P, P), mode="bilinear", align_corners=False | |
| )[:, 0, :, :] # BxPxP | |
| for bi, (y0, x0) in enumerate(batch_coords): | |
| y1, x1 = y0 + P, x0 + P | |
| prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi] | |
| weight_t[y0:y1, x0:x1] += 1.0 | |
| prob_full = (prob_sum_t / weight_t).detach().cpu().float() | |
| return prob_full # HxW torch.Tensor on CPU | |
| def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR: | |
| pretrained_flag = bool(cfg.get("pretrained", False)) | |
| model = WireSegHR( | |
| backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag | |
| ) | |
| model = model.to(device) | |
| return model | |
| def infer_image( | |
| model: WireSegHR, | |
| img_path: str, | |
| cfg: dict, | |
| device: torch.device, | |
| amp_flag: bool, | |
| amp_dtype, | |
| out_dir: Optional[str] = None, | |
| save_prob: bool = False, | |
| prob_thresh: Optional[float] = None, | |
| ) -> Tuple[np.ndarray, np.ndarray]: | |
| assert Path(img_path).is_file(), f"Image not found: {img_path}" | |
| bgr = cv2.imread(img_path, cv2.IMREAD_COLOR) | |
| assert bgr is not None, f"Failed to read {img_path}" | |
| rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 | |
| coarse_size = int(cfg["coarse"]["test_size"]) | |
| patch_size = int(cfg["inference"]["fine_patch_size"]) # 1024 for inference | |
| overlap = int(cfg["fine"]["overlap"]) | |
| minmax_enable = bool(cfg["minmax"]["enable"]) | |
| minmax_kernel = int(cfg["minmax"]["kernel"]) | |
| if prob_thresh is None: | |
| prob_thresh = float(cfg["inference"]["prob_threshold"]) | |
| prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_forward( | |
| model, | |
| rgb, | |
| coarse_size, | |
| minmax_enable, | |
| minmax_kernel, | |
| device, | |
| amp_flag, | |
| amp_dtype, | |
| ) | |
| prob_f = _tiled_fine_forward( | |
| model, | |
| t_img, | |
| cond_map, | |
| y_min_full, | |
| y_max_full, | |
| patch_size, | |
| overlap, | |
| int(cfg.get("eval", {}).get("fine_batch", 16)), | |
| device, | |
| amp_flag, | |
| amp_dtype, | |
| ) | |
| # Threshold with torch on CPU; convert to numpy only for saving/returning | |
| pred_t = (prob_f > prob_thresh).to(torch.uint8) * 255 # HxW uint8 torch | |
| pred = pred_t.detach().cpu().numpy() | |
| if out_dir is not None: | |
| os.makedirs(out_dir, exist_ok=True) | |
| stem = Path(img_path).stem | |
| out_mask = Path(out_dir) / f"{stem}_pred.png" | |
| cv2.imwrite(str(out_mask), pred) | |
| if save_prob: | |
| out_prob = Path(out_dir) / f"{stem}_prob.npy" | |
| np.save(out_prob, prob_f.detach().cpu().float().numpy()) | |
| # Return numpy arrays for external consumers, computed via torch | |
| return pred, prob_f.detach().cpu().numpy() | |
| def main(): | |
| parser = argparse.ArgumentParser(description="WireSegHR inference") | |
| parser.add_argument( | |
| "--config", type=str, default="configs/default.yaml", help="Path to YAML config" | |
| ) | |
| parser.add_argument("--image", type=str, required=False, help="Path to input image") | |
| parser.add_argument( | |
| "--images_dir", | |
| type=str, | |
| required=False, | |
| help="Directory with .jpg/.jpeg images", | |
| ) | |
| parser.add_argument( | |
| "--out", type=str, default="outputs/infer", help="Directory to save predictions" | |
| ) | |
| parser.add_argument( | |
| "--ckpt", | |
| type=str, | |
| default="", | |
| help="Optional checkpoint (.pt with {'model': state_dict} or .safetensors with pure state_dict)", | |
| ) | |
| parser.add_argument( | |
| "--save_prob", action="store_true", help="Also save probability .npy" | |
| ) | |
| # Metrics options | |
| parser.add_argument( | |
| "--metrics", | |
| action="store_true", | |
| help="Compute IoU, F1, Precision, Recall if ground-truth masks are provided", | |
| ) | |
| parser.add_argument( | |
| "--mask", | |
| type=str, | |
| default="", | |
| help="Path to ground-truth mask (.png) for --image when --metrics is set", | |
| ) | |
| parser.add_argument( | |
| "--masks_dir", | |
| type=str, | |
| default="", | |
| help="Directory with ground-truth masks (.png) for --images_dir when --metrics is set", | |
| ) | |
| # Benchmarking options | |
| parser.add_argument( | |
| "--benchmark", | |
| action="store_true", | |
| help="Run benchmarking on a directory (defaults to cfg.data.test_images)", | |
| ) | |
| parser.add_argument( | |
| "--bench_images_dir", | |
| type=str, | |
| default="", | |
| help="Images dir for benchmark (overrides cfg.data.test_images if set)", | |
| ) | |
| parser.add_argument( | |
| "--bench_masks_dir", | |
| type=str, | |
| default="", | |
| help="Masks dir for benchmark (overrides cfg.data.test_masks if set; used with --metrics)", | |
| ) | |
| parser.add_argument( | |
| "--bench_limit", | |
| type=int, | |
| default=0, | |
| help="Limit number of images for benchmark (0 means all)", | |
| ) | |
| parser.add_argument( | |
| "--bench_warmup", | |
| type=int, | |
| default=2, | |
| help="Number of warmup images (excluded from stats)", | |
| ) | |
| parser.add_argument( | |
| "--bench_size_filter", | |
| type=str, | |
| default="", | |
| help="Only benchmark images matching HxW, e.g. 3000x4000", | |
| ) | |
| parser.add_argument( | |
| "--bench_report_json", | |
| type=str, | |
| default="", | |
| help="Optional path to save JSON report of timings", | |
| ) | |
| args = parser.parse_args() | |
| cfg_path = args.config | |
| if not Path(cfg_path).is_absolute(): | |
| cfg_path = str(Path.cwd() / cfg_path) | |
| with open(cfg_path, "r") as f: | |
| cfg = yaml.safe_load(f) | |
| print("[WireSegHR][infer] Loaded config from:", cfg_path) | |
| pprint.pprint(cfg) | |
| # If benchmarking, do not require --image/--images_dir | |
| if not args.benchmark: | |
| assert (args.image is not None) ^ (args.images_dir is not None), ( | |
| "Provide exactly one of --image or --images_dir" | |
| ) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| precision = str(cfg["optim"].get("precision", "fp32")).lower() | |
| assert precision in ("fp32", "fp16", "bf16") | |
| amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16")) | |
| amp_dtype = ( | |
| torch.float16 | |
| if precision == "fp16" | |
| else (torch.bfloat16 if precision == "bf16" else None) | |
| ) | |
| model = _build_model_from_cfg(cfg, device) | |
| ckpt_path = args.ckpt if args.ckpt else cfg.get("resume", "") | |
| if ckpt_path: | |
| assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}" | |
| print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}") | |
| suffix = Path(ckpt_path).suffix.lower() | |
| if suffix == ".safetensors": | |
| # Safetensors exports contain a pure state_dict | |
| state_dict = safe_load_file(ckpt_path) | |
| model.load_state_dict(state_dict) | |
| else: | |
| print( | |
| "[WireSegHR][infer][WARN] Loading a PyTorch checkpoint. Prefer .safetensors for inference-only weights." | |
| ) | |
| # PyTorch .pt/.pth checkpoints expected to have {'model': state_dict} | |
| state = torch.load(ckpt_path, map_location=device) | |
| assert "model" in state, ( | |
| "Expected a dict with key 'model' for PyTorch checkpoint. " | |
| "Use scripts/strip_checkpoint.py or provide a .safetensors file." | |
| ) | |
| model.load_state_dict(state["model"]) | |
| model.eval() | |
| # Benchmark mode | |
| if args.benchmark: | |
| # Resolve image and mask directories | |
| bench_dir = args.bench_images_dir or cfg["data"]["test_images"] | |
| assert Path(bench_dir).is_dir(), f"Not a directory: {bench_dir}" | |
| if args.metrics: | |
| bench_masks_dir = args.bench_masks_dir or cfg["data"]["test_masks"] | |
| assert Path(bench_masks_dir).is_dir(), f"Not a directory: {bench_masks_dir}" | |
| # Optional size filter | |
| size_filter: Optional[Tuple[int, int]] = None | |
| if args.bench_size_filter: | |
| try: | |
| h_str, w_str = args.bench_size_filter.lower().split("x") | |
| size_filter = (int(h_str), int(w_str)) | |
| except Exception: | |
| raise AssertionError( | |
| f"Invalid --bench_size_filter format: {args.bench_size_filter} (use HxW)" | |
| ) | |
| # Gather images | |
| img_files = sorted( | |
| [ | |
| str(Path(bench_dir) / p) | |
| for p in os.listdir(bench_dir) | |
| if p.lower().endswith((".jpg", ".jpeg")) | |
| ] | |
| ) | |
| assert len(img_files) > 0, f"No .jpg/.jpeg in {bench_dir}" | |
| # Filter by size if requested | |
| if size_filter is not None: | |
| filt_files: List[str] = [] | |
| for p in img_files: | |
| bgr = cv2.imread(p, cv2.IMREAD_COLOR) | |
| assert bgr is not None, f"Failed to read {p}" | |
| if bgr.shape[0] == size_filter[0] and bgr.shape[1] == size_filter[1]: | |
| filt_files.append(p) | |
| img_files = filt_files | |
| assert len(img_files) > 0, ( | |
| f"No images matching {size_filter[0]}x{size_filter[1]} in {bench_dir}" | |
| ) | |
| # Limit | |
| if args.bench_limit > 0: | |
| img_files = img_files[: args.bench_limit] | |
| print(f"[WireSegHR][bench] Images: {len(img_files)} from {bench_dir}") | |
| print(f"[WireSegHR][bench] Warmup: {args.bench_warmup}") | |
| def _sync(): | |
| if device.type == "cuda": | |
| torch.cuda.synchronize() | |
| timings: List[Dict[str, Any]] = [] | |
| # Metric accumulators (for timed runs only) | |
| if args.metrics: | |
| fine_sum: Dict[str, float] = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} | |
| coarse_sum: Dict[str, float] = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} | |
| n_metrics = 0 | |
| # Warmup | |
| for i in tqdm(range(min(args.bench_warmup, len(img_files))), desc="[bench] Warmup"): | |
| _ = infer_image( | |
| model, | |
| img_files[i], | |
| cfg, | |
| device, | |
| amp_enabled, | |
| amp_dtype, | |
| out_dir=None, | |
| save_prob=False, | |
| ) | |
| # Timed runs | |
| for p in tqdm(img_files[args.bench_warmup :], desc="[bench] Timed"): | |
| # Replicate internals to time coarse vs fine separately | |
| bgr = cv2.imread(p, cv2.IMREAD_COLOR) | |
| assert bgr is not None, f"Failed to read {p}" | |
| rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 | |
| coarse_size = int(cfg["coarse"]["test_size"]) | |
| minmax_enable = bool(cfg["minmax"]["enable"]) | |
| minmax_kernel = int(cfg["minmax"]["kernel"]) | |
| _sync(); t0 = time.perf_counter() | |
| prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_forward( | |
| model, | |
| rgb, | |
| coarse_size, | |
| minmax_enable, | |
| minmax_kernel, | |
| device, | |
| amp_enabled, | |
| amp_dtype, | |
| ) | |
| _sync(); t1 = time.perf_counter() | |
| patch_size = int(cfg["inference"]["fine_patch_size"]) # 1024 | |
| overlap = int(cfg["fine"]["overlap"]) | |
| prob_f = _tiled_fine_forward( | |
| model, | |
| t_img, | |
| cond_map, | |
| y_min_full, | |
| y_max_full, | |
| patch_size, | |
| overlap, | |
| int(cfg.get("eval", {}).get("fine_batch", 16)), | |
| device, | |
| amp_enabled, | |
| amp_dtype, | |
| ) | |
| _sync(); t2 = time.perf_counter() | |
| # Optional metrics computation | |
| if args.metrics: | |
| stem = Path(p).stem | |
| gt_path = Path(bench_masks_dir) / f"{stem}.png" | |
| assert gt_path.is_file(), f"Missing mask for {stem}: {gt_path}" | |
| gt = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE) | |
| assert gt is not None, f"Failed to read mask: {gt_path}" | |
| gt_bin = (gt > 0).astype(np.uint8) | |
| prob_thresh = float(cfg["inference"]["prob_threshold"]) | |
| pred_coarse = (prob_c > prob_thresh).to(torch.uint8).cpu().numpy() | |
| pred_fine = (prob_f > prob_thresh).to(torch.uint8).cpu().numpy() | |
| m_c = compute_metrics(pred_coarse, gt_bin) | |
| m_f = compute_metrics(pred_fine, gt_bin) | |
| for k in coarse_sum: | |
| coarse_sum[k] += m_c[k] | |
| for k in fine_sum: | |
| fine_sum[k] += m_f[k] | |
| n_metrics += 1 | |
| timings.append( | |
| { | |
| "path": p, | |
| "H": int(t_img.shape[2]), | |
| "W": int(t_img.shape[3]), | |
| "t_coarse_ms": (t1 - t0) * 1000.0, | |
| "t_fine_ms": (t2 - t1) * 1000.0, | |
| "t_total_ms": (t2 - t0) * 1000.0, | |
| } | |
| ) | |
| if len(timings) == 0: | |
| print("[WireSegHR][bench] Nothing to benchmark after warmup.") | |
| return | |
| # Aggregate | |
| def _agg(key: str) -> Tuple[float, float, float]: | |
| vals = sorted([t[key] for t in timings]) | |
| n = len(vals) | |
| p50 = vals[n // 2] | |
| p95 = vals[min(n - 1, int(0.95 * (n - 1)))] | |
| avg = sum(vals) / n | |
| return avg, p50, p95 | |
| avg_c, p50_c, p95_c = _agg("t_coarse_ms") | |
| avg_f, p50_f, p95_f = _agg("t_fine_ms") | |
| avg_t, p50_t, p95_t = _agg("t_total_ms") | |
| print("[WireSegHR][bench] Results (ms):") | |
| print(f" Coarse avg={avg_c:.2f} p50={p50_c:.2f} p95={p95_c:.2f}") | |
| print(f" Fine avg={avg_f:.2f} p50={p50_f:.2f} p95={p95_f:.2f}") | |
| print(f" Total avg={avg_t:.2f} p50={p50_t:.2f} p95={p95_t:.2f}") | |
| print(f" Target < 1000 ms per 3000x4000 image: {'YES' if p50_t < 1000.0 else 'NO'}") | |
| if args.bench_report_json: | |
| import json | |
| report = { | |
| "summary": { | |
| "avg_ms": avg_t, | |
| "p50_ms": p50_t, | |
| "p95_ms": p95_t, | |
| "avg_coarse_ms": avg_c, | |
| "avg_fine_ms": avg_f, | |
| "images": len(timings), | |
| }, | |
| "per_image": timings, | |
| } | |
| report_path = args.bench_report_json | |
| Path(report_path).parent.mkdir(parents=True, exist_ok=True) | |
| with open(report_path, "w") as f: | |
| json.dump(report, f, indent=2) | |
| # Print aggregated metrics if requested | |
| if args.metrics: | |
| if n_metrics > 0: | |
| for k in fine_sum: | |
| fine_sum[k] /= n_metrics | |
| for k in coarse_sum: | |
| coarse_sum[k] /= n_metrics | |
| print( | |
| f"[WireSegHR][bench][Fine] IoU={fine_sum['iou']:.4f} F1={fine_sum['f1']:.4f} P={fine_sum['precision']:.4f} R={fine_sum['recall']:.4f}" | |
| ) | |
| print( | |
| f"[WireSegHR][bench][Coarse] IoU={coarse_sum['iou']:.4f} F1={coarse_sum['f1']:.4f} P={coarse_sum['precision']:.4f} R={coarse_sum['recall']:.4f}" | |
| ) | |
| return | |
| # Single image mode | |
| if args.image is not None: | |
| pred, _ = infer_image( | |
| model, | |
| args.image, | |
| cfg, | |
| device, | |
| amp_enabled, | |
| amp_dtype, | |
| out_dir=args.out, | |
| save_prob=args.save_prob, | |
| ) | |
| if args.metrics: | |
| assert args.mask, "--mask is required with --image when --metrics is set" | |
| assert Path(args.mask).is_file(), f"Mask not found: {args.mask}" | |
| gt = cv2.imread(args.mask, cv2.IMREAD_GRAYSCALE) | |
| assert gt is not None, f"Failed to read mask: {args.mask}" | |
| gt_bin = (gt > 0).astype(np.uint8) | |
| pred_bin = (pred > 0).astype(np.uint8) | |
| m = compute_metrics(pred_bin, gt_bin) | |
| print( | |
| f"[Infer] IoU={m['iou']:.4f} F1={m['f1']:.4f} P={m['precision']:.4f} R={m['recall']:.4f}" | |
| ) | |
| print("[WireSegHR][infer] Done.") | |
| return | |
| # Directory mode | |
| img_dir = args.images_dir | |
| assert Path(img_dir).is_dir(), f"Not a directory: {img_dir}" | |
| img_files = sorted( | |
| [p for p in os.listdir(img_dir) if p.lower().endswith((".jpg", ".jpeg"))] | |
| ) | |
| assert len(img_files) > 0, f"No .jpg/.jpeg in {img_dir}" | |
| os.makedirs(args.out, exist_ok=True) | |
| if args.metrics: | |
| assert args.masks_dir, ( | |
| "--masks_dir is required with --images_dir when --metrics is set" | |
| ) | |
| assert Path(args.masks_dir).is_dir(), f"Not a directory: {args.masks_dir}" | |
| metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} | |
| n_eval = 0 | |
| for name in tqdm(img_files, desc="[infer] Dir"): | |
| path = str(Path(img_dir) / name) | |
| pred, _ = infer_image( | |
| model, | |
| path, | |
| cfg, | |
| device, | |
| amp_enabled, | |
| amp_dtype, | |
| out_dir=args.out, | |
| save_prob=args.save_prob, | |
| ) | |
| if args.metrics: | |
| stem = Path(name).stem | |
| mask_path = Path(args.masks_dir) / f"{stem}.png" | |
| assert mask_path.is_file(), f"Missing mask for {stem}: {mask_path}" | |
| gt = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) | |
| assert gt is not None, f"Failed to read mask: {mask_path}" | |
| gt_bin = (gt > 0).astype(np.uint8) | |
| pred_bin = (pred > 0).astype(np.uint8) | |
| m = compute_metrics(pred_bin, gt_bin) | |
| for k in metrics_sum: | |
| metrics_sum[k] += m[k] | |
| n_eval += 1 | |
| if args.metrics and n_eval > 0: | |
| for k in metrics_sum: | |
| metrics_sum[k] /= n_eval | |
| print( | |
| f"[Infer][Avg over {n_eval}] IoU={metrics_sum['iou']:.4f} F1={metrics_sum['f1']:.4f} P={metrics_sum['precision']:.4f} R={metrics_sum['recall']:.4f}" | |
| ) | |
| print("[WireSegHR][infer] Done.") | |
| if __name__ == "__main__": | |
| main() | |