# Model identifier string. Nothing in the visible code reads this constant;
# presumably it is the value to pass to EndpointHandler(...) -- TODO confirm.
model_name_or_path = "alfaxadeyembe/gemma2-2b-swahili-preview"
from typing import Any

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class EndpointHandler:
    """Callable wrapper around a sequence-classification model and tokenizer.

    The tokenizer and model are loaded once at construction time. Each call
    tokenizes the request's ``"inputs"`` entry and runs a single forward pass
    with gradient tracking disabled.
    """

    def __init__(self, model_name_or_path):
        """Load the tokenizer and model identified by *model_name_or_path*.

        Args:
            model_name_or_path: Identifier forwarded unchanged to both
                ``from_pretrained`` calls.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path
        )
        # This handler only runs inference, so switch to evaluation mode.
        self.model.eval()

    def __call__(self, data: dict[str, Any]) -> Any:
        """Tokenize ``data["inputs"]`` and return the model's forward output.

        Args:
            data: Request mapping. The ``"inputs"`` entry is handed to the
                tokenizer as-is; a missing key falls back to the empty string.

        Returns:
            Whatever ``self.model`` returns for the tokenized inputs.
        """
        inputs = data.get("inputs", "")
        # padding=True lets a list of texts with different lengths be stacked
        # into one tensor; a single string is tokenized exactly as before.
        tokens = self.tokenizer(inputs, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.model(**tokens)
        # NOTE(review): the raw model output is returned unchanged. If the
        # caller needs a JSON-serializable response, it presumably has to
        # convert this itself (e.g. from the logits) -- confirm with the caller.
        return outputs