| from transformers import AutoModelForCausalLM, AutoProcessor |
| from PIL import Image |
| import requests |
| import torch |
| import io |
|
|
class EndpointHandler:
    """Inference-endpoint handler that captions an image fetched by URL.

    Expects requests shaped ``{"inputs": {"url": "<image url>"}}`` and
    responds with ``{"caption": <str>}`` on success or ``{"error": <str>}``
    on any failure.
    """

    # Upper bound on the image download so a dead or slow host cannot
    # hang the serving worker forever (the original had no timeout).
    _FETCH_TIMEOUT_S = 30

    def __init__(self, model_dir):
        """Load the captioning model and its processor from *model_dir*.

        ``trust_remote_code=True`` is required because the model ships its
        own modeling/processing code (Florence-2-style checkpoint).
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True
        ).to(device)
        # Inference-only handler: disable dropout etc. so generation is
        # deterministic. torch.inference_mode() alone does not do this.
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.device = device

    def __call__(self, data):
        """Handle one request: download the image, generate, decode.

        Returns ``{"caption": str}`` or ``{"error": str}``; never raises,
        so a bad request cannot crash the serving worker.
        """
        try:
            url = data.get("inputs", {}).get("url")
            if not url:
                return {"error": "Missing URL"}

            headers = {
                "User-Agent": "Mozilla/5.0",
                "Accept": "image/*",
            }
            # SECURITY FIX: the original passed verify=False, disabling TLS
            # certificate verification on a caller-supplied URL (MITM risk).
            # Verification is left at its default (enabled), and the request
            # is bounded by a timeout.
            response = requests.get(
                url, headers=headers, timeout=self._FETCH_TIMEOUT_S
            )
            response.raise_for_status()

            image = Image.open(io.BytesIO(response.content)).convert("RGB")

            # Task-prompt token understood by the remote processing code.
            inputs = self.processor(
                text="<MORE_DETAILED_CAPTION>",
                images=image,
                return_tensors="pt",
            ).to(self.device)

            with torch.inference_mode():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=3,
                )

            text = self.processor.batch_decode(output, skip_special_tokens=True)[0]
            return {"caption": text}

        except Exception as e:
            # Top-level service boundary: report the failure to the caller
            # as a JSON-serializable payload instead of raising.
            return {"error": str(e)}