import argparse

from new_llama_modl_search import EvaluateLlamaModelSearch
from new_llama_modl_none import EvaluateLlamaModelSearchNone
from mteb import MTEB
import torch

def get_args(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``get_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with a single attribute ``device_id`` (int,
        default 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--device_id',
        default=0,
        type=int,
        help='Index of the CUDA device to run on; it also selects which '
             'evaluation configuration is executed (default: %(default)s).',
    )
    return parser.parse_args(argv)


# The twelve CQADupstack retrieval tasks. Every device id that has work
# assigned evaluates this same list, so it is defined exactly once instead
# of being repeated per device.
CQADUPSTACK_TASKS = (
    'CQADupstackGamingRetrieval',
    'CQADupstackGisRetrieval',
    'CQADupstackMathematicaRetrieval',
    'CQADupstackPhysicsRetrieval',
    'CQADupstackAndroidRetrieval',
    'CQADupstackEnglishRetrieval',
    'CQADupstackWebmastersRetrieval',
    'CQADupstackWordpressRetrieval',
    'CQADupstackProgrammersRetrieval',
    'CQADupstackStatsRetrieval',
    'CQADupstackTexRetrieval',
    'CQADupstackUnixRetrieval',
)

# Value assigned to ``model.mode`` for each device id that has work to do.
# Device ids that are not keys of this mapping have no tasks and stay idle.
DEVICE_MODES = {
    0: 'q2q',
    1: 'p2p',
    2: 'q2p',
    3: 'q2p',
}

# Passed as both positional arguments of the model constructors.
MODEL_NAME = "BAAI/LLARA-beir"
BATCH_SIZE = 12


def tasks_for_device(device_id):
    """Return the task names that ``device_id`` should evaluate.

    Ids 0-3 get the full CQADupstack list. Every other id (including
    negative ones and ids beyond the previously hard-coded eight slots)
    gets an empty tuple, so an unexpected id neither raises ``IndexError``
    nor wraps around to another device's work via negative indexing.
    """
    if device_id in DEVICE_MODES:
        return CQADUPSTACK_TASKS
    return ()


def eval_splits_for(task):
    """Return the list of evaluation splits for ``task``.

    'MSMARCO' is evaluated on 'dev'; every other task on 'test'.
    """
    return ['dev' if task == 'MSMARCO' else 'test']


def build_model(device_id):
    """Instantiate the evaluation model configured for ``device_id``.

    Ids 0, 1 and 2 use ``EvaluateLlamaModelSearch``; any other id uses
    ``EvaluateLlamaModelSearchNone``. ``model.mode`` comes from
    ``DEVICE_MODES`` and falls back to 'q2p' for ids not listed there.
    """
    if device_id in (0, 1, 2):
        model_cls = EvaluateLlamaModelSearch
    else:
        model_cls = EvaluateLlamaModelSearchNone
    model = model_cls(MODEL_NAME, MODEL_NAME, batch_size=BATCH_SIZE)
    model.mode = DEVICE_MODES.get(device_id, 'q2p')
    return model


def main():
    """Run the MTEB evaluation assigned to the device chosen on the CLI."""
    args = get_args()
    device_id = args.device_id
    # set_device accepts the integer index directly; no need to wrap it in
    # the torch.cuda.device context-manager object.
    torch.cuda.set_device(device_id)

    # NOTE: evaluation runs without torch.cuda.amp.autocast(); the wrapper
    # was already disabled (commented out) in the original script.
    for task in tasks_for_device(device_id):
        # A fresh model is built for every task, as before.
        # NOTE(review): if the model keeps no per-task state, construction
        # could be hoisted out of the loop to avoid reloading - confirm.
        model = build_model(device_id)
        evaluation = MTEB(tasks=[task], task_langs=['en'],
                          eval_splits=eval_splits_for(task))
        evaluation.run(model, output_folder=f"en_results/{device_id}")


if __name__ == "__main__":
    main()