# File: Test/backend/services/embeddings.py
# (Hosting-page residue kept for provenance: "UKielz's picture", "Upload 14 files", "0bbe8e9 verified")
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from typing import Dict, Any, List, Optional, Tuple
from backend.config import EMBEDDING_MODEL_NAME, MAX_LENGTH, DEVICE
# The tokenizer and encoder are loaded once at import time and shared by every
# call in this module.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)

# Load the encoder, move it to the configured device, then switch it to
# evaluation mode, since this module only runs inference.
model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
model = model.to(DEVICE)
model = model.eval()
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over dimension 1, ignoring masked positions.

    Args:
        model_output: Indexable model result whose first element holds the
            per-token embeddings (presumably ``(batch, seq_len, hidden)`` --
            confirm against the model in use).
        attention_mask: Tensor that is broadcast over the last embedding
            dimension; positions with value 0 contribute nothing to the mean.

    Returns:
        A tensor with dimension 1 reduced away: the mask-weighted sum of the
        embeddings divided by the mask total for each row.
    """
    hidden_states = model_output[0]

    # Give the mask a trailing axis and stretch it to the embedding shape so it
    # can be multiplied element-wise with the token embeddings.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()

    weighted_sum = (hidden_states * weights).sum(dim=1)
    # Clamp the denominator so a fully masked row divides by 1e-9, not by 0.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_counts
def get_embeddings(texts):
    """Encode ``texts`` into L2-normalised embeddings as a NumPy array.

    The input is tokenized with padding and truncation to ``MAX_LENGTH``, run
    through the module-level ``model`` without gradient tracking, mean-pooled
    with the attention mask, and normalised along dimension 1.

    Args:
        texts: Whatever the module-level ``tokenizer`` accepts (presumably a
            string or a list of strings -- confirm against callers).

    Returns:
        The normalised embeddings, moved to the CPU and converted to NumPy.
    """
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(DEVICE)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        model_output = model(**encoded)

    pooled = mean_pooling(model_output, encoded["attention_mask"])
    unit_vectors = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return unit_vectors.cpu().numpy()
def combine_embeddings(content_emb: np.ndarray) -> np.ndarray:
    """Scale ``content_emb`` by its overall norm.

    The norm is ``np.linalg.norm`` with default arguments, i.e. computed over
    the whole array rather than per row.

    Args:
        content_emb: Embedding array to rescale.

    Returns:
        ``content_emb / norm`` when the norm is greater than ``1e-8``;
        otherwise the input array itself, unchanged.
    """
    magnitude = np.linalg.norm(content_emb)
    # Near-zero vectors are passed through untouched rather than divided by a
    # tiny norm.
    return content_emb / magnitude if magnitude > 1e-8 else content_emb
def process_chunk_data(payload: Dict[str, Any]) -> Tuple[str, List[str]]:
    """Build the markdown text and image list for one chunk payload.

    Args:
        payload: Mapping read with the keys ``"source_file"``,
            ``"markdown_data"`` and ``"images"``; all three are optional.

    Returns:
        A ``(markdown_content, images)`` tuple. ``markdown_content`` is an
        optional source-file header followed by the markdown body. ``images``
        holds every string entry of ``payload["images"]`` plus the ``"data"``
        value of every dict entry that has a truthy ``"data"``; other entries
        are skipped, and a non-list ``"images"`` value yields an empty list.
    """
    source_file = payload.get("source_file")
    header = f"\n\n**File gốc:** {source_file}\n\n" if source_file else ""
    body = payload.get("markdown_data") or ""
    markdown_content = header + body

    raw_images = payload.get("images")
    if not isinstance(raw_images, list):
        return markdown_content, []

    # String entries are taken as-is; dict entries contribute their "data".
    images = [
        entry if isinstance(entry, str) else entry["data"]
        for entry in raw_images
        if isinstance(entry, str) or (isinstance(entry, dict) and entry.get("data"))
    ]
    return markdown_content, images