# regenerate_embeddings.py
"""
Regenerate answer embeddings using the MuRIL model.
This script:
 - downloads model (if MODEL_DIR is a repo id),
 - reads CSV at CSV_PATH,
 - computes mean-pooled, L2-normalized embeddings for 'answer' column,
 - saves embeddings to OUT_EMBED_PATH.

Exit codes:
 - 0 on success
 - non-zero on failure
"""
import os
import sys
from pathlib import Path
import torch
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import snapshot_download

def mean_pooling(last_hidden_state, attention_mask):
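    """
    Mask-aware mean pooling over token embeddings.

    last_hidden_state: (B, T, H) transformer outputs
    attention_mask:    (B, T), 1 for real tokens, 0 for padding
    Returns a (B, H) tensor: the mean of each sequence's non-padding vectors.
    """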
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def parse_env():
    # All configuration comes from environment variables (no CLI flags)
    cfg = {}
    cfg['model_dir'] = os.getenv("MODEL_DIR", os.getenv("HF_REPO", "Sp2503/Finetuned-multilingualdataset-MuriL-model"))
    cfg['csv_path'] = os.getenv("CSV_PATH", "/app/export_artifacts/muril_multilingual_dataset.csv")
    cfg['out_path'] = os.getenv("OUT_EMBED_PATH", "/app/export_artifacts/answer_embeddings.pt")
    cfg['batch_size'] = int(os.getenv("EMBED_BATCH_SIZE", "64"))
    cfg['device'] = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
    cfg['download_cache'] = os.getenv("HF_CACHE_DIR", "/tmp/hf_cache")
    cfg['upload_back'] = os.getenv("UPLOAD_BACK", "false").lower() in ("1","true","yes")
    cfg['hf_repo'] = os.getenv("HF_REPO", None)  # used for upload_back if set
    return cfg

def main():
    cfg = parse_env()
    print("Regenerate embeddings with config:", cfg)
    model_dir = cfg['model_dir']
    # If model_dir looks like a HF repo id (contains '/'), snapshot_download to local cache
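    # (private repos require authentication, e.g. `huggingface-cli login`
    # or the HF_TOKEN environment variable)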
    if "/" in model_dir and not os.path.isdir(model_dir):
        print("Detected HF repo id for model. snapshot_download ->", cfg['download_cache'])
        try:
            model_dir = snapshot_download(repo_id=cfg['model_dir'], repo_type="model", cache_dir=cfg['download_cache'])
            print("Downloaded model to:", model_dir)
        except Exception as e:
            print("Failed to snapshot_download model:", e, file=sys.stderr)
            sys.exit(2)

    csv_path = cfg['csv_path']
    out_path = cfg['out_path']
    batch_size = cfg['batch_size']
    device = cfg['device']
    print(f"Loading CSV: {csv_path}")
    if not os.path.isfile(csv_path):
        print(f"CSV not found at {csv_path}", file=sys.stderr)
        sys.exit(3)
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if 'answer' not in df.columns:
        print("CSV must contain 'answer' column", file=sys.stderr)
        sys.exit(4)
    answers = df['answer'].astype(str).tolist()
    print(f"Encoding {len(answers)} answers on device {device} (batch_size={batch_size})")

    # Load tokenizer & model
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        model = AutoModel.from_pretrained(model_dir)
        model.to(device)
        model.eval()
    except Exception as e:
        print("Failed to load model/tokenizer:", e, file=sys.stderr)
        sys.exit(5)

    # compute embeddings
    all_embs = []
    try:
        with torch.inference_mode():
            for i in tqdm(range(0, len(answers), batch_size), desc="Batches"):
                batch = answers[i:i+batch_size]
                enc = tokenizer(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
                input_ids = enc["input_ids"].to(device)
                attention_mask = enc["attention_mask"].to(device)
                out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                pooled = mean_pooling(out.last_hidden_state, attention_mask)      # (B, H)
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)        # L2-normalize
                all_embs.append(pooled.cpu())  # move to CPU so GPU memory stays bounded by one batch
    except Exception as e:
        print("Error during encoding:", e, file=sys.stderr)
        sys.exit(6)

    all_embs = torch.cat(all_embs, dim=0)
    print("Final embeddings shape:", all_embs.shape)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(all_embs, out_path)
    print("Saved embeddings to:", out_path)

    # Optional: upload back to HF repo (requires HF_TOKEN set and HF_REPO)
    if cfg['upload_back'] and cfg['hf_repo']:
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            print(f"Uploading {out_path} back to repo {cfg['hf_repo']} ...")
            api.upload_file(
                path_or_fileobj=out_path,
                path_in_repo=os.path.basename(out_path),
                repo_id=cfg['hf_repo'],
                repo_type="model",
            )
            print("Upload complete.")
        except Exception as e:
            print("Upload back failed:", e, file=sys.stderr)

    # quick sanity check: L2 norms of the normalized rows should all be ~1.0
    norms = all_embs.norm(dim=1)
    print("Sample norms (should be ~1.0):", norms[:5].tolist())
    return 0
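
# Consuming the saved embeddings (a sketch, not executed by this script):
# because every row is L2-normalized, cosine similarity reduces to a dot
# product, so the answers closest to a query embedding `q` of shape (H,)
# can be found with:
#
#   embs = torch.load("/app/export_artifacts/answer_embeddings.pt")  # (N, H)
#   scores = embs @ q                       # cosine similarities, shape (N,)
#   top5 = torch.topk(scores, k=5).indices  # row indices of 5 nearest answers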

if __name__ == "__main__":
    sys.exit(main())