Spaces:
Runtime error
| # utils.py | |
| import os | |
| import numpy as np | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import pickle | |
# Hugging Face Hub id of the Bhagavad Gita dataset consumed by load_gita().
DATASET_NAME = "JDhruv14/Bhagavad-Gita_Dataset"
# sentence-transformers model used for embeddings; overridable via the
# EMBED_MODEL environment variable (read once, at import time).
EMBED_MODEL = os.environ.get("EMBED_MODEL", "all-MiniLM-L6-v2")
# Default file used by VectorStore.save() / VectorStore.load().
VECSTORE_PATH = "vecstore.pkl"
def load_gita(split="train", dataset_name=None):
    """Load the Bhagavad Gita dataset and return its rows as plain dicts.

    Args:
        split: Dataset split passed to ``load_dataset``. Defaults to ``"train"``.
        dataset_name: Hub dataset id. ``None`` means the module-level
            ``DATASET_NAME``.

    Returns:
        A list with one dict per row. Each dict has:
        - int ``chapter`` and int ``verse``;
        - ``sanskrit``, ``hindi``, ``english`` and ``transliteration``.
        A text value that is missing or null becomes ``""``.

    Raises:
        KeyError: a row has no ``chapter`` or ``verse`` field.
        ValueError / TypeError: ``chapter`` or ``verse`` is not int-convertible.
    """
    # Resolved here rather than in the signature so the default tracks
    # DATASET_NAME at call time.
    ds = load_dataset(dataset_name or DATASET_NAME, split=split)
    text_fields = ("sanskrit", "hindi", "english", "transliteration")
    # NOTE(review): the lowercase column names are assumed to match the
    # dataset schema. Confirm against the Hub dataset card if KeyError appears.
    return [
        {
            "chapter": int(row["chapter"]),
            "verse": int(row["verse"]),
            # `or ""` also normalizes columns that exist but hold None.
            # A plain .get(field, "") would pass None through, and None later
            # breaks embedding.
            **{field: row.get(field) or "" for field in text_fields},
        }
        for row in ds
    ]
class VectorStore:
    """Exact-L2 FAISS index over sentence-transformer embeddings, with metadata.

    Row ``i`` of ``self.index`` corresponds to ``self.meta[i]``. ``build`` and
    ``load`` keep the two aligned.
    """

    def __init__(self, model_name=EMBED_MODEL, dim=None):
        """Load the embedding model and create an empty ``IndexFlatL2``.

        Args:
            model_name: sentence-transformers model name or path.
            dim: Embedding dimension. When falsy, it is measured by encoding
                a probe string with the model.
        """
        self.model = SentenceTransformer(model_name)
        if not dim:
            # Measure the real output width instead of trusting config metadata.
            dim = self.model.encode(["test"]).shape[-1]
        self.dim = dim
        self.index = faiss.IndexFlatL2(self.dim)
        self.meta = []

    def build(self, docs, text_field="english"):
        """Rebuild the index from ``docs``, replacing any previous contents.

        Each doc is embedded from ``text_field``. If that field is empty or
        missing, ``"hindi"`` is used, then ``"sanskrit"``.

        Args:
            docs: Iterable of dicts, such as those returned by ``load_gita``.
            text_field: Preferred key holding the text to embed.
        """
        docs = list(docs)  # own copy: later caller mutations can't desync meta
        if not docs:
            self.index.reset()
            self.meta = []
            return
        texts = [
            d.get(text_field) or d.get("hindi") or d.get("sanskrit") or ""
            for d in docs
        ]
        embeddings = self.model.encode(
            texts, show_progress_bar=True, convert_to_numpy=True
        )
        # Reset before adding. Otherwise a second build() appends N more rows
        # while meta is replaced, and search would map ids to the wrong docs.
        # The reset happens only after encoding succeeds, so a failed encode
        # leaves the previous state intact.
        self.index.reset()
        self.index.add(embeddings.astype("float32"))
        self.meta = docs

    def save(self, path=VECSTORE_PATH):
        """Pickle the metadata and a byte-serialized copy of the index to ``path``."""
        # FAISS indexes are SWIG-wrapped C++ objects that pickle cannot
        # reliably handle. Store FAISS's own serialization instead
        # (a uint8 numpy array).
        payload = {"meta": self.meta, "index": faiss.serialize_index(self.index)}
        with open(path, "wb") as f:
            pickle.dump(payload, f)

    def load(self, path=VECSTORE_PATH):
        """Restore metadata and index from a file written by ``save``.

        Accepts both the serialized-array format written by ``save`` and a
        directly pickled index object, as written by older code.

        Security: this uses ``pickle.load``, which can execute arbitrary code.
        Only load files you produced yourself.
        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        index = obj["index"]
        if isinstance(index, np.ndarray):
            index = faiss.deserialize_index(index)
        self.meta = obj["meta"]
        self.index = index

    def search(self, query, top_k=6):
        """Return up to ``top_k`` docs nearest to ``query`` by L2 distance.

        Results follow FAISS's ordering, nearest first. The result is an
        empty list when the index is empty or ``top_k`` is not positive.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        q_emb = self.model.encode([query]).astype("float32")
        _, ids = self.index.search(q_emb, top_k)
        # FAISS pads with id -1 when fewer than top_k vectors exist. The
        # lower bound stops -1 from silently selecting self.meta[-1].
        return [self.meta[i] for i in ids[0] if 0 <= i < len(self.meta)]