| import torch |
| from torchvision import transforms |
| from model import MNISTModel |
|
|
|
|
class InferenceWrapper:
    """Thin convenience layer for running a trained ``MNISTModel`` in eval mode.

    The wrapper selects a device, loads the weights once at construction time
    and exposes single-sample and batch prediction helpers that both return
    the arg-max class together with its softmax probability.
    """

    def __init__(self, model_path: str):
        """
        Initialize the inference wrapper with a model path.

        Args:
            model_path (str): Path to the model weights file
        """
        # Prefer the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        # Preprocessing pipeline exposed for callers. The predict_* methods do
        # NOT apply it, so tensors handed to them must already be preprocessed.
        # NOTE(review): (0.1307, 0.3081) are presumably the MNIST mean/std used
        # at training time -- confirm against the training pipeline.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def _load_model(self):
        """Build an ``MNISTModel``, load the weights from ``self.model_path``
        onto ``self.device`` and return it switched to evaluation mode."""
        model = MNISTModel().to(self.device)
        model.load_state_dict(
            # weights_only=True is passed so torch.load uses its restricted
            # unpickler instead of executing arbitrary pickled objects.
            torch.load(self.model_path, map_location=self.device, weights_only=True)
        )
        model.eval()
        return model

    def predict_tensor(self, input_tensor: torch.Tensor):
        """
        Run inference on a single input tensor.

        Args:
            input_tensor (torch.Tensor): One sample, given as ``[28, 28]``,
                ``[1, 28, 28]`` or an already-batched ``[1, 1, 28, 28]`` tensor.

        Returns:
            tuple: (prediction, confidence) where ``prediction`` is the arg-max
            class index (int) and ``confidence`` is its softmax probability
            (float).

        Raises:
            ValueError: If a 4-D input holds a number of samples other than
                one. Use :meth:`predict_batch` for multiple samples.
        """
        # Promote a bare [H, W] image to [C, H, W] ...
        if input_tensor.dim() == 2:
            input_tensor = input_tensor.unsqueeze(0)
        # ... and a [C, H, W] sample to a batch of one.
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.unsqueeze(0)

        # The scalar return value is only meaningful for exactly one sample;
        # fail early with a clear message instead of letting `.item()` raise
        # an opaque conversion error after the forward pass.
        if input_tensor.dim() == 4 and input_tensor.size(0) != 1:
            raise ValueError(
                "predict_tensor expects exactly one sample, got a batch of "
                f"{input_tensor.size(0)}; use predict_batch for multiple samples"
            )

        with torch.no_grad():
            output = self.model(input_tensor.to(self.device))
            probs = torch.softmax(output, dim=1)
            prediction = output.argmax(1).item()
            confidence = probs[0][prediction].item()
        return prediction, confidence

    def predict_batch(self, input_tensors: torch.Tensor):
        """
        Run inference on a batch of input tensors.

        Args:
            input_tensors (torch.Tensor): Batch of input tensors of shape [N, 1, 28, 28]

        Returns:
            tuple: (predictions, confidences) as NumPy arrays of length N;
            ``confidences[i]`` is the softmax probability of ``predictions[i]``.
        """
        with torch.no_grad():
            input_tensors = input_tensors.to(self.device)
            output = self.model(input_tensors)
            probs = torch.softmax(output, dim=1)
            predictions = output.argmax(1)
            # Pick, for every row, the probability of the predicted class.
            confidences = torch.gather(probs, 1, predictions.unsqueeze(1)).squeeze(1)
        return predictions.cpu().numpy(), confidences.cpu().numpy()
|
|
|
|
def main():
    """Command-line smoke test for ``InferenceWrapper``.

    Reads the required ``--model-path`` option, loads the model, then runs one
    random single-sample prediction and one random batch of four, printing
    the results.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--model-path', required=True, help='Path to the model weights')
    weights_path = cli.parse_args().model_path

    engine = InferenceWrapper(weights_path)

    # Single random sample shaped [1, 28, 28].
    sample = torch.randn(1, 28, 28)
    label, score = engine.predict_tensor(sample)
    print(f"Single prediction: {label}, confidence: {score:.4f}")

    # Random batch of four samples shaped [4, 1, 28, 28].
    batch = torch.randn(4, 1, 28, 28)
    labels, scores = engine.predict_batch(batch)
    print(f"Batch predictions: {labels}")
    print(f"Batch confidences: {scores}")
|
|
# Run the command-line smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()