| import os |
| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| from PIL import Image |
| import json |
| import sys |
|
|
| |
def _log_startup_environment():
    """Print interpreter/library versions and visible files to aid deployment debugging.

    Lists the current working directory and, when it exists, ``/repository``.
    Executed once at import time (see the call below).
    """
    repo_dir = '/repository'
    for line in (
        "Handler module loaded",
        f"Python version: {sys.version}",
        f"PyTorch version: {torch.__version__}",
        f"Directory contents: {os.listdir('.')}",
    ):
        print(line)
    if os.path.exists(repo_dir):
        print(f"Repository directory contents: {os.listdir(repo_dir)}")


_log_startup_environment()
|
|
| |
class ViTForImageClassification:
    """Guard stub that makes any attempt to load a ViT classifier fail loudly.

    This application uses an EfficientNet-based model; if some caller tries to
    load a ``ViTForImageClassification`` via ``from_pretrained`` the call is
    logged and rejected with ``ValueError``.
    """

    @staticmethod
    def from_pretrained(model_dir):
        """Log the attempted load of ``model_dir`` and always raise ``ValueError``."""
        print(f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}")
        reason = "ViTForImageClassification is not the correct model for this application"
        raise ValueError(reason)
|
|
class EndpointHandler:
    """Inference handler that classifies an image as real or AI-generated.

    Wraps an EfficientNet-V2-S backbone with a custom classification head and
    exposes a callable interface accepting a base64 string, raw bytes, a
    file-like object, or a PIL image (optionally wrapped as ``{"inputs": ...}``).
    """

    def __init__(self, model_dir):
        """
        Initialize the model for AI image detection.

        Args:
            model_dir: Directory searched first for the fine-tuned weights
                file (see ``_load_model`` for the full search order).

        If loading the fine-tuned model fails for any reason, an
        ImageNet-pretrained EfficientNet-V2-S with a freshly initialized 2-way
        head is used instead so the endpoint can still start. NOTE: that
        fallback head is untrained, so its predictions are not meaningful.
        """
        print(f"Initializing EndpointHandler with model_dir: {model_dir}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Fixed 224x224 resize, then normalization with the ImageNet channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Order must match the model's output logits: index 0 = real, index 1 = AI.
        self.classes = ["Real Image", "AI-Generated Image"]

        try:
            self.model = self._load_model(model_dir)
            print("Model loaded successfully")
        except Exception as e:
            # Deliberate best-effort: keep the endpoint alive with a fallback model.
            print(f"Error loading model: {e}")
            print("Creating a dummy model as fallback")
            self.model = self._build_fallback_model()

    def _build_fallback_model(self):
        """Return an ImageNet-pretrained EfficientNet-V2-S with an untrained 2-class head.

        The model is moved to ``self.device`` (``__call__`` sends inputs there)
        and put in eval mode. torchvision downloads the weights if they are not
        already cached, so this can raise when the host has no network access.
        """
        # ``weights=`` replaces the deprecated ``pretrained=True`` flag;
        # IMAGENET1K_V1 is the checkpoint ``pretrained=True`` selected.
        model = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 2)
        model.to(self.device)
        model.eval()
        return model

    def _load_model(self, model_dir):
        """Build the custom-head EfficientNet-V2-S and load the fine-tuned weights.

        Candidate weight files are tried in this order; the first that exists
        is loaded:

        1. ``<model_dir>/best_model_improved.pth``
        2. ``<model_dir>/pytorch_model.bin``
        3. ``best_model_improved.pth`` relative to the current working directory
        4. ``/repository/best_model_improved.pth``

        Args:
            model_dir: Directory searched first for the weights file.

        Returns:
            The model, on ``self.device`` and in eval mode.

        Raises:
            OSError: if ``model_dir`` cannot be listed.
            FileNotFoundError: if none of the candidate paths exists.
            RuntimeError: if the checkpoint's keys/shapes do not match the head below.
        """
        print(f"Loading model from directory: {model_dir}")
        print(f"Directory contents: {os.listdir(model_dir)}")

        model = models.efficientnet_v2_s(weights=None)

        # This head must match the checkpoint's layout exactly: load_state_dict
        # (strict by default) rejects missing or unexpected keys.
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),
        )

        possible_paths = [
            os.path.join(model_dir, "best_model_improved.pth"),
            os.path.join(model_dir, "pytorch_model.bin"),
            "best_model_improved.pth",
            "/repository/best_model_improved.pth",
        ]

        for model_path in possible_paths:
            print(f"Trying model path: {model_path}")
            if os.path.exists(model_path):
                print(f"Found model at: {model_path}")
                # NOTE(security): torch.load unpickles the file; only point it at
                # trusted checkpoints (consider weights_only=True on torch >= 1.13).
                state_dict = torch.load(model_path, map_location=self.device)
                model.load_state_dict(state_dict)
                break
        else:
            raise FileNotFoundError(f"Model file not found in any of these locations: {possible_paths}")

        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _decode_image(input_data):
        """Convert a supported payload into an RGB PIL image.

        Supported payloads: base64 string (optionally with a data-URI prefix
        such as ``data:image/png;base64,``), raw bytes, a file-like object with
        ``read()``, or a ``PIL.Image.Image``.

        Returns:
            An RGB ``PIL.Image.Image``, or ``None`` if the payload type is unsupported.
        """
        from io import BytesIO

        if isinstance(input_data, str):
            print("Processing base64 string image")
            import base64

            # Base64 has no ',' in its alphabet, so a comma means a data-URI prefix.
            if ',' in input_data:
                input_data = input_data.split(",", 1)[1]
            return Image.open(BytesIO(base64.b64decode(input_data))).convert("RGB")
        if isinstance(input_data, (bytes, bytearray)):
            print("Processing raw bytes image")
            return Image.open(BytesIO(input_data)).convert("RGB")
        if hasattr(input_data, "read"):
            print("Processing file-like object image")
            return Image.open(input_data).convert("RGB")
        if isinstance(input_data, Image.Image):
            print("Processing PIL Image")
            # Normalize RGBA/L/P/CMYK etc. to the 3 channels the transform and model expect.
            return input_data.convert("RGB")
        return None

    def _build_response(self, probabilities):
        """Format per-class probabilities as ``[{"label", "score"}, ...]``.

        Args:
            probabilities: 1-D tensor of class probabilities indexed like ``self.classes``.

        Returns:
            One dict per class, in ``self.classes`` order, with ``score`` being
            the probability expressed as a percentage (probability * 100).
        """
        return [
            {"label": label, "score": float(probabilities[idx].item() * 100)}
            for idx, label in enumerate(self.classes)
        ]

    def __call__(self, data):
        """
        Run prediction on the input data.

        Args:
            data: Either ``{"inputs": <image>}`` or the image payload itself;
                see ``_decode_image`` for the accepted payload types.

        Returns:
            On success, a list of ``{"label", "score"}`` dicts (one per class,
            scores in percent). On an unsupported payload type,
            ``{"error": ...}``. On any other failure, ``{"error": ..., "traceback": ...}``.
        """
        try:
            print(f"Received prediction request with data type: {type(data)}")

            if isinstance(data, dict) and "inputs" in data:
                input_data = data["inputs"]
                print(f"Extracted input data from API format, type: {type(input_data)}")
            else:
                input_data = data

            image = self._decode_image(input_data)
            if image is None:
                print(f"Unsupported input type: {type(input_data)}")
                return {"error": f"Unsupported input type: {type(input_data)}"}

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

            return self._build_response(probabilities)

        except Exception as e:
            import traceback
            print(f"Error during prediction: {e}")
            traceback.print_exc()
            return {"error": str(e), "traceback": traceback.format_exc()}
|
|