| """ |
| CutoutAI API Server |
| |
| FastAPI server providing: |
| - REST API endpoints for background removal |
| - Webhook endpoint for n8n/Make integration |
| - Health check for monitoring |
| - Startup model preloading |
| """ |
|
|
| import io |
| import base64 |
| import time |
| import logging |
| import httpx |
| from typing import Optional, Literal, Union |
| from pathlib import Path |
| from contextlib import asynccontextmanager |
|
|
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request |
| from fastapi.responses import Response, JSONResponse |
| from pydantic import BaseModel, Field |
|
|
| from cutoutai import CutoutAI, MODEL_VARIANTS, logger as cutout_logger |
|
|
| |
| logging.basicConfig( |
| level=logging.INFO, |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| ) |
| logger = logging.getLogger("CutoutAI-API") |
|
|
| |
| _models: dict[str, CutoutAI] = {} |
|
|
| def get_model(variant: str = "matting") -> CutoutAI: |
| """Get or create a model instance for the specified variant.""" |
| global _models |
| if variant not in _models: |
| _models[variant] = CutoutAI(model_variant=variant) |
| _models[variant].load_model() |
| return _models[variant] |
|
|
|
|
| |
| @asynccontextmanager |
| async def lifespan(app: FastAPI): |
| |
| print("Preloading matting model...") |
| get_model("matting") |
| print("Model preloaded and ready!") |
| yield |
| |
| _models.clear() |
|
|
|
|
| |
| app = FastAPI( |
| title="CutoutAI - Background Remover", |
| description="Flawless background removal for t-shirt mockups and design workflows", |
| version="1.1.0", |
| lifespan=lifespan |
| ) |
|
|
|
|
| |
| class ProcessOptions(BaseModel): |
| model: Literal["general", "matting", "portrait", "lite", "hr", "dynamic"] = "matting" |
| capture_all_elements: bool = True |
| edge_refinement: bool = True |
| edge_radius: int = 2 |
| threshold: Optional[float] = None |
| soft_threshold: bool = False |
| remove_artifacts: bool = True |
| min_artifact_size: int = 40 |
| adaptive_threshold: bool = True |
| return_mask: bool = False |
| output_format: Literal["png", "base64"] = "png" |
|
|
|
|
| class WebhookRequest(BaseModel): |
| image_base64: Optional[str] = None |
| image_url: Optional[str] = None |
| options: Optional[ProcessOptions] = None |
|
|
|
|
| class HealthResponse(BaseModel): |
| status: str |
| version: str |
| model_loaded: bool |
| models_loaded: list[str] |
| device: str |
|
|
|
|
| |
| @app.get("/health", response_model=HealthResponse) |
| async def health_check(): |
| """Health check endpoint for monitoring.""" |
| global _models |
| loaded_models = list(_models.keys()) |
| device = _models["matting"].device if "matting" in _models else "not loaded" |
| return HealthResponse( |
| status="healthy", |
| version="1.1.0", |
| model_loaded=len(_models) > 0, |
| models_loaded=loaded_models, |
| device=device |
| ) |
|
|
|
|
| @app.get("/") |
| async def root(): |
| """Root endpoint with API info.""" |
| return { |
| "name": "CutoutAI - Background Remover", |
| "version": "1.1.0", |
| "docs": "/docs", |
| "health": "/health" |
| } |
|
|
|
|
| @app.post("/api/v1/remove") |
| async def remove_bg( |
| image: UploadFile = File(...), |
| model: str = Form("matting"), |
| edge_refinement: bool = Form(True), |
| capture_all_elements: bool = Form(True), |
| threshold: Optional[float] = Form(None), |
| soft_threshold: bool = Form(False), |
| remove_artifacts: bool = Form(True), |
| adaptive_threshold: bool = Form(True), |
| return_mask: bool = Form(False), |
| output_format: str = Form("png") |
| ): |
| """ |
| Remove background from uploaded image. |
| |
| - **image**: Image file to process |
| - **model**: Model variant (matting recommended for designs) |
| - **edge_refinement**: Smooth edges for cleaner cutouts |
| - **capture_all_elements**: Lower threshold to capture bubbles/small elements |
| - **threshold**: Override mask threshold (0.0-1.0) |
| - **soft_threshold**: Use soft thresholding |
| - **remove_artifacts**: Remove small isolated islands from mask |
| - **adaptive_threshold**: Calculate threshold based on image confidence |
| - **return_mask**: Return a JSON object with both result and mask |
| - **output_format**: "png" for file download, "base64" for JSON response |
| """ |
| start_time = time.time() |
|
|
| try: |
| |
| if model not in MODEL_VARIANTS: |
| raise HTTPException(status_code=400, detail=f"Invalid model: {model}. Available variants: {list(MODEL_VARIANTS.keys())}") |
|
|
| |
| contents = await image.read() |
|
|
| |
| if len(contents) > 10 * 1024 * 1024: |
| raise HTTPException(status_code=413, detail="Image too large (max 10MB)") |
|
|
| |
| processor = get_model(model) |
| result = processor.process( |
| contents, |
| edge_refinement=edge_refinement, |
| capture_all_elements=capture_all_elements, |
| threshold=threshold, |
| soft_threshold=soft_threshold, |
| remove_artifacts=remove_artifacts, |
| adaptive_threshold=adaptive_threshold, |
| return_mask=return_mask, |
| output_format="bytes" if output_format == "png" and not return_mask else "base64" |
| ) |
|
|
| processing_time = time.time() - start_time |
|
|
| if return_mask: |
| |
| return JSONResponse({ |
| "success": True, |
| "result_base64": result["result"], |
| "mask_base64": result["mask"], |
| "threshold_used": round(result["threshold_used"], 4), |
| "processing_time_seconds": round(processing_time, 2) |
| }) |
|
|
| if output_format == "png": |
| return Response( |
| content=result, |
| media_type="image/png", |
| headers={ |
| "Content-Disposition": f'attachment; filename="{image.filename}_cutout.png"', |
| "X-Processing-Time": f"{processing_time:.2f}s" |
| } |
| ) |
| else: |
| return JSONResponse({ |
| "success": True, |
| "image_base64": result, |
| "processing_time_seconds": round(processing_time, 2) |
| }) |
|
|
| except HTTPException: |
| raise |
| except ValueError as e: |
| raise HTTPException(status_code=400, detail=str(e)) |
| except Exception as e: |
| logger.exception("Error processing request") |
| raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") |
|
|
|
|
| @app.post("/api/v1/batch") |
| async def batch_remove( |
| images: list[UploadFile] = File(...), |
| model: str = Form("matting"), |
| capture_all_elements: bool = Form(True) |
| ): |
| """Process multiple images in batch.""" |
| start_time = time.time() |
| results = [] |
| processor = get_model(model) |
| |
| for img in images: |
| contents = await img.read() |
| result = processor.process( |
| contents, |
| capture_all_elements=capture_all_elements, |
| output_format="base64" |
| ) |
| results.append({ |
| "filename": img.filename, |
| "image_base64": result |
| }) |
| |
| total_time = time.time() - start_time |
| |
| return JSONResponse({ |
| "success": True, |
| "count": len(results), |
| "results": results, |
| "total_processing_time_seconds": round(total_time, 2) |
| }) |
|
|
|
|
| @app.post("/webhook") |
| async def webhook_handler( |
| request: Request, |
| image: Optional[UploadFile] = File(None), |
| image_base64: Optional[str] = Form(None), |
| image_url: Optional[str] = Form(None), |
| model: str = Form("matting"), |
| edge_refinement: bool = Form(True), |
| capture_all_elements: bool = Form(True), |
| edge_radius: int = Form(2), |
| threshold: Optional[float] = Form(None), |
| soft_threshold: bool = Form(False), |
| return_mask: bool = Form(False), |
| callback_url: Optional[str] = Form(None) |
| ): |
| """ |
| Webhook endpoint for n8n/Make integration. |
| |
| Accepts image via: |
| - File upload (image) |
| - Base64 encoded string (image_base64) |
| - URL to fetch (image_url) |
| |
| Returns base64 encoded result for easy workflow integration. |
| """ |
| start_time = time.time() |
| logger.info(f"Webhook request received from {request.client.host}") |
|
|
| try: |
| |
| if request.headers.get("content-type") == "application/json": |
| try: |
| body = await request.json() |
| image_base64 = body.get("image_base64", image_base64) |
| image_url = body.get("image_url", image_url) |
| model = body.get("model", model) |
| edge_refinement = body.get("edge_refinement", edge_refinement) |
| capture_all_elements = body.get("capture_all_elements", capture_all_elements) |
| edge_radius = body.get("edge_radius", edge_radius) |
| threshold = body.get("threshold", threshold) |
| soft_threshold = body.get("soft_threshold", soft_threshold) |
| return_mask = body.get("return_mask", return_mask) |
| callback_url = body.get("callback_url", callback_url) |
| except Exception as e: |
| logger.warning(f"Failed to parse JSON body: {e}") |
|
|
| |
| if model not in MODEL_VARIANTS: |
| logger.error(f"Invalid model requested: {model}") |
| return JSONResponse( |
| {"success": False, "error": f"Invalid model: {model}. Available: {list(MODEL_VARIANTS.keys())}"}, |
| status_code=400 |
| ) |
|
|
| processor = get_model(model) |
|
|
| |
| img_data = None |
| if image: |
| img_data = await image.read() |
| logger.info(f"Using uploaded file: {image.filename}") |
| elif image_base64: |
| try: |
| |
| if "," in image_base64: |
| image_base64 = image_base64.split(",")[1] |
| |
| image_base64 = "".join(image_base64.split()) |
| img_data = base64.b64decode(image_base64) |
| logger.info("Using base64 image data") |
| except Exception as e: |
| return JSONResponse({"success": False, "error": f"Invalid base64 data: {e}"}, status_code=400) |
| elif image_url: |
| logger.info(f"Fetching image from URL: {image_url}") |
| async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: |
| try: |
| response = await client.get(image_url) |
| response.raise_for_status() |
| img_data = response.content |
| except httpx.HTTPStatusError as e: |
| return JSONResponse({"success": False, "error": f"Failed to fetch image: {e.response.status_code}"}, status_code=400) |
| except Exception as e: |
| return JSONResponse({"success": False, "error": f"Network error: {e}"}, status_code=500) |
| else: |
| return JSONResponse( |
| {"success": False, "error": "No image provided. Use 'image', 'image_base64', or 'image_url'"}, |
| status_code=400 |
| ) |
|
|
| |
| if not img_data: |
| return JSONResponse({"success": False, "error": "Empty image data"}, status_code=400) |
|
|
| |
| result = processor.process( |
| img_data, |
| edge_refinement=edge_refinement, |
| capture_all_elements=capture_all_elements, |
| edge_radius=edge_radius, |
| threshold=threshold, |
| soft_threshold=soft_threshold, |
| return_mask=return_mask, |
| output_format="base64" |
| ) |
|
|
| processing_time = time.time() - start_time |
|
|
| if isinstance(result, dict): |
| response_data = { |
| "success": True, |
| "image_base64": result["result"], |
| "mask_base64": result["mask"], |
| "model_used": model, |
| "threshold_used": round(result.get("threshold_used", 0), 4), |
| "processing_time_seconds": round(processing_time, 2) |
| } |
| else: |
| response_data = { |
| "success": True, |
| "image_base64": result, |
| "model_used": model, |
| "processing_time_seconds": round(processing_time, 2) |
| } |
|
|
| |
| if callback_url: |
| logger.info(f"Sending callback to: {callback_url}") |
| async with httpx.AsyncClient(timeout=10.0) as client: |
| try: |
| await client.post(callback_url, json=response_data) |
| except Exception as e: |
| logger.error(f"Callback failed: {e}") |
| response_data["callback_error"] = str(e) |
|
|
| return JSONResponse(response_data) |
|
|
| except Exception as e: |
| logger.exception("Unexpected error in webhook handler") |
| return JSONResponse( |
| {"success": False, "error": str(e)}, |
| status_code=500 |
| ) |
|
|
|
|
| |
| if __name__ == "__main__": |
| import uvicorn |
| import argparse |
| import os |
|
|
| parser = argparse.ArgumentParser(description="CutoutAI API Server") |
| parser.add_argument("--host", default="0.0.0.0", help="Host address") |
| parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 8000)), help="Port number") |
| args = parser.parse_args() |
|
|
| uvicorn.run(app, host=args.host, port=args.port) |
|
|