import argparse
import json
import os
import sys

import tiktoken
from tornado import concurrent
from tqdm import tqdm

from langchain_openai import ChatOpenAI

from prompts import *
from tools import *

import os

# Configure the OpenAI-compatible endpoint at import time, before any
# ChatOpenAI client is constructed further down in this file.
# SECURITY(review): the API key below is hard-coded in source control. It
# should be treated as leaked: rotate it and load the value from the
# environment or a secrets store instead of committing it.
os.environ["OPENAI_API_KEY"] = 'sk-VH5XAwQCiu6LT45fA34c2cBfAd26476096922d072b5c56A4'
# Base URL of the API endpoint (a third-party proxy rather than api.openai.com).
os.environ["OPENAI_API_BASE"] = 'https://api.xiaoai.plus/v1'


def parse_option():
    """Parse the command-line options of the script.

    Every option is a string path with a machine-specific default:

    * ``--dev_path``   -- JSON file loaded as ``dev`` in ``main``.
    * ``--data_path``  -- JSON file loaded as ``data_all`` in ``main``.
    * ``--input_path`` -- text file read line by line as the predicted SQL.
    * ``--db_path``    -- directory that holds one sub-directory per database.

    Returns:
        argparse.Namespace: The parsed options.
    """
    path_defaults = {
        'dev_path': "/home/baaiks/cf/pycharm/DAMO-ConvAI/bird/llm/data/dev/dev.json",
        'data_path': "/home/baaiks/cf/pycharm/EvolSQL-version2-fix/generate_datasets_bird/preprocessed_data.json",
        'input_path': "/home/baaiks/cf/pycharm/EvolSQL-version2-fix/intermediate_datasets_bird/third_round.sql",
        'db_path': "/home/baaiks/cf/pycharm/DAMO-ConvAI/bird/llm/data/dev/dev_databases",
    }

    parser = argparse.ArgumentParser("")
    for option_name, default_value in path_defaults.items():
        parser.add_argument(f'--{option_name}', type=str, default=default_value)

    return parser.parse_args()


class AnalyzeTool:
    """Ask a chat model why a predicted SQL query differs from the gold query.

    The prompt is rendered from the ``find_error_reason`` template (a name
    expected to come from the star-import of ``prompts``) and sent to one of
    two ``ChatOpenAI`` clients, chosen by the prompt's token count.
    """

    # Prompts with at least this many tokens are sent to ``llm_long``.
    LONG_PROMPT_TOKENS = 3800
    # Temperature used for the single retry after a failed request.
    RETRY_TEMPERATURE = 0.5

    def __init__(self):
        # NOTE(review): the token count uses the text-davinci-003 tokenizer
        # while both clients are configured for "gpt-4"; the count is therefore
        # only an approximation of what the model sees -- confirm this is intended.
        self.encoder = tiktoken.encoding_for_model("text-davinci-003")

        self.prompt_template = find_error_reason
        # Both clients are currently configured identically; ``llm_long`` is
        # the hook for routing long prompts to a different model.
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4", request_timeout=600, max_retries=3)
        self.llm_long = ChatOpenAI(temperature=0, model_name="gpt-4", request_timeout=600, max_retries=3)

    def run(self, question, schema, foreign_keys, error_SQL, error_result, gold_sql, gold_result):
        """Render the prompt, query the model and return its answer.

        Args:
            question: Natural-language question of the example.
            schema: Schema text inserted into the prompt.
            foreign_keys: Foreign-key text inserted into the prompt.
            error_SQL: The predicted (wrong) SQL statement.
            error_result: Result of executing ``error_SQL``.
            gold_sql: The reference SQL statement.
            gold_result: Result of executing ``gold_sql``.

        Returns:
            The model's reply, which is also printed to stdout.

        Raises:
            Exception: Whatever the retry raises when the second attempt
                fails as well.
        """
        prompt = self.prompt_template.format(question=question, schema=schema, foreign_keys=foreign_keys,
                                             error=error_SQL, error_res=error_result,
                                             sql=gold_sql, res=gold_result).strip()
        # Collapse runs of whitespace inside each line but keep the line structure.
        prompt = '\n'.join(' '.join(line.split()) for line in prompt.split('\n'))

        # Pick the client once so the retry below targets the same one.
        token_count = len(self.encoder.encode(prompt))
        model = self.llm if token_count < self.LONG_PROMPT_TOKENS else self.llm_long

        try:
            reflect = model.predict(prompt)
        except Exception:
            # Retry once with a higher temperature on the client that actually
            # failed, and always restore its setting, even if the retry raises.
            previous_temperature = model.temperature
            model.temperature = self.RETRY_TEMPERATURE
            try:
                reflect = model.predict(prompt)
            finally:
                model.temperature = previous_temperature
        print(reflect)
        return reflect


def correct_sql(dev, data_all, sqls, db_path, errors):
    """Analyse each listed example with the LLM and checkpoint the findings.

    For every index in ``errors`` the predicted and the gold SQL are executed
    against the example's SQLite database, the analyzer is asked for the error
    reason, and a record is appended to the running list. The full list is
    rewritten to ``error_record_gpt4.json`` after every example, so an
    interrupted run keeps what was finished.

    Args:
        dev: Indexable collection of examples with ``db_id``, ``SQL`` and
            ``question`` keys.
        data_all: Indexable collection passed item-wise to
            ``generate_foreign_key`` and ``generate_schema``.
        sqls: Predicted SQL strings, indexed like ``dev``.
        db_path: Directory containing ``<db_id>/<db_id>.sqlite`` files.
        errors: Iterable of indices to analyse.
    """
    analyzer = AnalyzeTool()
    records = []

    for idx in tqdm(errors):
        example = dev[idx]
        db_id = example['db_id']
        sqlite_file = f'{db_path}/{db_id}/{db_id}.sqlite'
        predicted_sql = sqls[idx].strip()
        gold_sql = example['SQL']

        # new_run_sql returns a (result, flag) pair; the flag's meaning is
        # defined by that helper.
        predicted_result, predicted_flag = new_run_sql(sqlite_file, predicted_sql)
        gold_result, gold_flag = new_run_sql(sqlite_file, gold_sql)

        question = example['question']
        foreign_keys = generate_foreign_key(data_all[idx])
        schema = generate_schema(data_all[idx])
        error_reason = analyzer.run(question, schema, foreign_keys,
                                    predicted_sql, predicted_result,
                                    gold_sql, gold_result)

        records.append({
            'question': question,
            'schema': schema,
            'foreign_keys': foreign_keys,
            'pre-SQL': predicted_sql,
            'pre-res': str(predicted_result),
            'pre-flag': str(predicted_flag),
            'gold-SQL': str(gold_sql),
            'gold-res': str(gold_result),
            'gold-flag': str(gold_flag),
            'error-type': error_reason,
        })

        # Rewrite the whole file each time: a checkpoint after every example.
        with open('error_record_gpt4.json', 'w') as out_file:
            json.dump(records, out_file)



def get_output_name(path, idx):
    """Return ``path`` with ``idx`` inserted just before the file extension.

    ``os.path.splitext`` is used so that only the extension of the final path
    component counts: a path without an extension simply gets ``idx``
    appended, and dots in directory names are left alone.

    Args:
        path: File path, with or without an extension.
        idx: Value appended to the file stem (converted with ``str``).

    Returns:
        str: For example ``"out.sql"`` with ``idx=2`` gives ``"out2.sql"``.
    """
    stem, extension = os.path.splitext(path)
    return f'{stem}{idx}{extension}'

def main(opt):
    """Load the dataset and the predictions, then analyse the listed examples.

    Args:
        opt: Parsed options providing ``dev_path``, ``data_path``,
            ``input_path`` and ``db_path``.
    """
    # Context managers close the JSON files right after parsing instead of
    # leaving the handles open until garbage collection.
    with open(opt.dev_path) as f:
        dev = json.load(f)
    with open(opt.data_path) as f:
        data_all = json.load(f)
    input_path = opt.input_path
    db_path = opt.db_path

    # One predicted SQL statement per line; line i is paired with dev[i]
    # inside correct_sql.
    with open(input_path, 'r') as f:
        last_sqls = [line.strip() for line in f]

    # Hard-coded indices of the examples to analyse (presumably the ones the
    # previous round answered incorrectly -- confirm against the evaluation run).
    errors = [1, 2, 3, 4, 8, 9, 11, 12, 13, 14, 16, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 53, 54, 56, 59, 62, 63, 65, 68, 70, 71, 72, 74, 75, 76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 89, 90, 92, 93, 94, 95, 96, 97, 98, 99, 101, 106, 107, 108, 109, 112, 113, 114, 115, 118, 119, 124, 125, 128, 129, 130, 132, 133, 135, 136, 137, 141, 142, 143, 145, 149, 150, 153, 155, 157, 159, 162, 163, 165, 168, 169, 170, 171, 172, 173, 174, 176, 177, 179, 180, 181, 182, 183, 185, 186, 189, 192, 193, 197, 198, 199, 201, 205, 207, 211, 212, 215, 217, 218, 219, 220, 221, 223, 225, 231, 232, 234, 237, 239, 244, 245, 247, 248, 251, 252, 254, 255, 263, 264, 267, 269, 271, 280, 281, 282, 284, 286, 290, 292, 296, 300, 303, 306, 309, 310, 311, 317, 321, 324, 326, 328, 330, 332, 335, 337, 338, 340, 341, 342, 343, 344, 347, 349, 352, 354, 357, 359, 360, 361, 363, 366, 376, 383, 386, 387, 388, 389, 391, 392, 393, 394, 395, 397, 398, 399, 402, 403, 406, 407, 408, 409, 410, 411, 412, 415, 416, 417, 423, 424, 425, 427, 428, 429, 430, 431, 432, 433, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 453, 454, 456, 458, 463, 465, 467, 469, 472, 473, 476, 477, 478, 479, 482, 484, 487, 491, 494, 498, 500, 506, 507, 511, 512, 513, 514, 516, 519, 520, 521, 523, 525, 530, 533, 536, 556, 562, 571, 582, 583, 584, 586, 587, 590, 592, 593, 595, 596, 598, 600, 602, 603, 604, 605, 606, 608, 610, 614, 616, 617, 620, 628, 630, 631, 632, 635, 636, 637, 639, 640, 642, 643, 646, 649, 652, 655, 656, 662, 663, 665, 667, 670, 672, 679, 682, 683, 685, 686, 687, 689, 693, 694, 698, 701, 706, 708, 709, 710, 711, 712, 716, 720, 743, 756, 766, 767, 775, 791, 800, 802, 803, 805, 810, 811, 812, 819, 828, 832, 846, 851, 852, 855, 857, 860, 861, 864, 866, 869, 871, 872, 876, 879, 880, 883, 888, 889, 892, 894, 895, 896, 897, 898, 903, 904, 908, 911, 913, 921, 922, 928, 929, 930, 935, 936, 937, 941, 942, 944, 947, 948, 951, 952, 953, 954, 955, 956, 
957, 958, 959, 961, 962, 963, 966, 967, 970, 972, 973, 974, 977, 978, 979, 981, 982, 983, 984, 985, 986, 987, 989, 990, 993, 994, 996, 998, 999, 1000, 1001, 1004, 1006, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1021, 1023, 1024, 1027, 1029, 1031, 1034, 1036, 1037, 1041, 1050, 1052, 1058, 1064, 1068, 1084, 1085, 1087, 1091, 1092, 1093, 1094, 1099, 1102, 1107, 1108, 1110, 1113, 1115, 1118, 1119, 1120, 1121, 1122, 1126, 1127, 1128, 1131, 1135, 1139, 1144, 1148, 1149, 1152, 1153, 1160, 1161, 1163, 1166, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1181, 1182, 1185, 1186, 1187, 1189, 1190, 1191, 1192, 1195, 1196, 1199, 1200, 1203, 1204, 1205, 1207, 1211, 1213, 1216, 1217, 1219, 1223, 1224, 1225, 1226, 1227, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238, 1239, 1241, 1242, 1243, 1244, 1245, 1247, 1248, 1249, 1250, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1263, 1264, 1265, 1266, 1268, 1269, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1279, 1280, 1281, 1284, 1285, 1289, 1290, 1291, 1292, 1294, 1295, 1297, 1298, 1300, 1302, 1305, 1306, 1307, 1308, 1309, 1310, 1318, 1322, 1324, 1335, 1336, 1338, 1343, 1350, 1352, 1359, 1365, 1366, 1370, 1376, 1387, 1388, 1389, 1390, 1391, 1396, 1397, 1399, 1401, 1404, 1405, 1407, 1419, 1425, 1427, 1431, 1433, 1437, 1441, 1444, 1445, 1446, 1448, 1450, 1451, 1453, 1454, 1456, 1467, 1468, 1472, 1473, 1475, 1476, 1477, 1478, 1480, 1481, 1482, 1485, 1491, 1498, 1499, 1500, 1501, 1503, 1512, 1517, 1520, 1524, 1526, 1527, 1529, 1530, 1531, 1533]

    correct_sql(dev, data_all, last_sqls, db_path, errors)


# Script entry point: parse the command-line options and hand them to main().
if __name__ == "__main__":
    main(parse_option())