"""Text-embedding helpers and chunk-payload formatting."""

from typing import Dict, Any, List, Optional, Tuple

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

from backend.config import EMBEDDING_MODEL_NAME, MAX_LENGTH, DEVICE

# Loaded once at import time and shared by every get_embeddings() call.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME).to(DEVICE).eval()


def mean_pooling(model_output, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Model forward output; element ``0`` is used as the
            per-token embeddings (assumed ``(batch, seq, hidden)`` -- confirm
            for the configured model).
        attention_mask: Mask with non-zero entries at the tokens to include;
            it is broadcast over the last (hidden) dimension.

    Returns:
        The masked sum over dimension 1 divided by the number of unmasked
        tokens.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp so a fully-masked row divides by a tiny value instead of zero.
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / token_counts


def get_embeddings(texts: List[str]) -> np.ndarray:
    """Encode ``texts`` into L2-normalized, mean-pooled embeddings.

    Inputs are padded and truncated to ``MAX_LENGTH`` tokens and run through
    the module-level ``model`` on ``DEVICE`` with gradients disabled.

    Args:
        texts: Batch of strings to embed.

    Returns:
        A 2-D NumPy array with one unit-length row per input text. An empty
        list/tuple yields an array with zero rows.
    """
    # Short-circuit an empty batch rather than pushing zero sequences through
    # the tokenizer and model. Restricted to list/tuple so that any other
    # input (e.g. a bare string) follows the normal path unchanged.
    if isinstance(texts, (list, tuple)) and not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        emb = mean_pooling(outputs, inputs["attention_mask"])
        return torch.nn.functional.normalize(emb, p=2, dim=1).cpu().numpy()


def combine_embeddings(content_emb: np.ndarray) -> np.ndarray:
    """Return ``content_emb`` scaled to unit L2 norm along its last axis.

    A single 1-D vector is normalized as a whole; for a 2-D batch every row
    is normalized independently instead of the entire matrix being divided
    by one global norm. Vectors whose norm is at most ``1e-8`` are returned
    with their values unchanged.

    Args:
        content_emb: One embedding vector or a batch of row vectors.

    Returns:
        An array of the same shape as ``content_emb``.
    """
    combined = content_emb
    norms = np.linalg.norm(combined, axis=-1, keepdims=True)
    # Divide near-zero vectors by 1 so they pass through without NaN/inf.
    safe_norms = np.where(norms > 1e-8, norms, np.ones_like(norms))
    return combined / safe_norms


def process_chunk_data(payload: Dict[str, Any]) -> Tuple[str, List[str]]:
    """Build display markdown and collect image entries from a chunk payload.

    Args:
        payload: Mapping that may contain ``"source_file"``,
            ``"markdown_data"`` and ``"images"``. Missing or falsy entries
            are skipped.

    Returns:
        ``(markdown_content, images)``. ``markdown_content`` is the optional
        source-file header followed by the markdown body. ``images`` holds
        each string entry of ``payload["images"]`` as-is, and the ``"data"``
        value of each dict entry that has a truthy ``"data"``; other entries
        are ignored, as is an ``"images"`` value that is not a list.
    """
    parts: List[str] = []

    source_file = payload.get("source_file")
    if source_file:
        # Header label is user-facing text ("File gốc" = "Original file").
        parts.append(f"\n\n**File gốc:** {source_file}\n\n")

    markdown_data = payload.get("markdown_data")
    if markdown_data:
        parts.append(markdown_data)

    images: List[str] = []
    raw_images = payload.get("images")
    if isinstance(raw_images, list):
        for img_data in raw_images:
            if isinstance(img_data, str):
                images.append(img_data)
            elif isinstance(img_data, dict) and img_data.get("data"):
                images.append(img_data["data"])

    return "".join(parts), images