from tqdm import trange

# 0-based indices that are skipped entirely: the same numbers are used both as
# positions in the preprocessed dataset JSON and as line numbers in the
# prediction file, so gold and predicted SQL stay aligned after filtering.
except_ids = [209, 210, 589, 590, 605, 606, 1640, 1646, 1792, 1793, 1794, 1795, 2207, 2208, 2209, 2210, 2486, 2487, 2488, 2489, 3125, 3126, 3153, 4234, 4235, 4482, 4483, 4513, 4514, 5158, 5159, 5635, 6955]
# error_ids (below): positions whose training target is the gold SQL instead of
# the predicted SQL. The script compares them against the running counter of
# kept examples (i.e. after except_ids have been removed), not against the raw
# dataset index.
# NOTE(review): some values (e.g. 589, 590, 2207-2210) also occur in
# except_ids -- confirm which numbering this list was generated against.
error_ids = [1, 13, 19, 44, 50, 51, 56, 58, 59, 68, 78, 82, 85, 89, 90, 103, 104, 123, 133, 134, 135, 141, 142, 143, 144, 145, 153, 154, 155, 156, 158, 161, 162, 168, 172, 174, 184, 188, 189, 190, 191, 192, 195, 196, 198, 205, 206, 208, 212, 213, 214, 215, 228, 242, 243, 244, 245, 259, 262, 269, 270, 271, 275, 276, 283, 284, 287, 289, 313, 314, 323, 324, 327, 328, 329, 330, 331, 332, 398, 405, 406, 407, 418, 427, 428, 429, 430, 431, 432, 435, 436, 441, 444, 451, 452, 455, 460, 489, 493, 494, 499, 500, 506, 507, 508, 511, 512, 514, 515, 516, 517, 518, 521, 522, 523, 524, 526, 527, 528, 529, 530, 534, 537, 540, 545, 546, 562, 563, 564, 565, 566, 570, 589, 590, 595, 596, 602, 603, 604, 605, 617, 618, 619, 620, 626, 631, 634, 650, 656, 659, 668, 671, 680, 681, 682, 688, 698, 705, 706, 710, 711, 712, 721, 722, 747, 748, 763, 767, 768, 790, 792, 808, 833, 834, 840, 844, 851, 852, 854, 858, 863, 864, 865, 866, 883, 884, 887, 888, 892, 893, 894, 899, 907, 908, 914, 915, 916, 924, 944, 945, 946, 947, 948, 959, 961, 962, 963, 964, 975, 1000, 1019, 1054, 1055, 1056, 1057, 1061, 1066, 1067, 1068, 1069, 1072, 1073, 1074, 1075, 1076, 1077, 1080, 1081, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1100, 1101, 1102, 1103, 1138, 1139, 1140, 1141, 1149, 1158, 1161, 1165, 1182, 1209, 1214, 1215, 1216, 1217, 1218, 1219, 1222, 1226, 1227, 1228, 1229, 1246, 1247, 1249, 1250, 1251, 1252, 1253, 1258, 1266, 1267, 1272, 1282, 1285, 1290, 1292, 1295, 1296, 1297, 1298, 1299, 1300, 1304, 1314, 1332, 1333, 1337, 1344, 1356, 1357, 1358, 1359, 1366, 1367, 1368, 1369, 1394, 1395, 1398, 1399, 1400, 1401, 1408, 1409, 1411, 1412, 1413, 1416, 1417, 1420, 1421, 1422, 1426, 1427, 1429, 1430, 1431, 1432, 1442, 1443, 1448, 1449, 1458, 1470, 1474, 1478, 1479, 1482, 1483, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1521, 1527, 1529, 1530, 1547, 1550, 1552, 1559, 1560, 1562, 1563, 1564, 1573, 1574, 1583, 1584, 1601, 1602, 1603, 1604, 1605, 1635, 1659, 1660, 1677, 1678, 1686, 
1703, 1706, 1718, 1726, 1727, 1739, 1742, 1746, 1747, 1748, 1749, 1753, 1754, 1756, 1757, 1760, 1761, 1762, 1763, 1800, 1827, 1838, 1848, 1849, 1860, 1861, 1866, 1867, 1884, 1894, 1896, 1897, 1904, 1906, 1909, 1910, 1913, 1915, 1918, 1930, 1931, 1950, 1954, 1981, 1986, 1987, 1988, 1989, 1990, 1995, 1996, 1997, 1998, 2008, 2009, 2010, 2012, 2017, 2018, 2023, 2024, 2033, 2035, 2036, 2038, 2069, 2070, 2076, 2079, 2084, 2109, 2110, 2117, 2118, 2137, 2138, 2140, 2161, 2162, 2164, 2167, 2168, 2169, 2170, 2171, 2172, 2173, 2174, 2176, 2177, 2181, 2182, 2207, 2208, 2209, 2210, 2212, 2227, 2228, 2232, 2239, 2240, 2245, 2246, 2261, 2262, 2277, 2278, 2281, 2289, 2300, 2301, 2319, 2330, 2331, 2332, 2333, 2334, 2340, 2341, 2342, 2343, 2347, 2349, 2352, 2353, 2358, 2360, 2361, 2362, 2363, 2364, 2365, 2366, 2367, 2370, 2371, 2380, 2384, 2385, 2389, 2393, 2400, 2401, 2402, 2406, 2407, 2421, 2436, 2444, 2445, 2452, 2454, 2455, 2462, 2463, 2466, 2467, 2478, 2479, 2482, 2483, 2484, 2485, 2486, 2487, 2488, 2489, 2494, 2495, 2496, 2497, 2499, 2503, 2504, 2505, 2508, 2509, 2510, 2511, 2524, 2543, 2546, 2547, 2556, 2557, 2558, 2559, 2562, 2563, 2564, 2566, 2567, 2568, 2569, 2592, 2593, 2595, 2601, 2618, 2619, 2630, 2631, 2635, 2636, 2637, 2650, 2651, 2656, 2657, 2666, 2695, 2710, 2711, 2712, 2713, 2747, 2749, 2751, 2756, 2757, 2771, 2779, 2780, 2781, 2793, 2800, 2805, 2826, 2828, 2833, 2836, 2860, 2861, 2906, 2908, 2909, 2919, 2924, 2929, 2933, 2935, 2936, 2941, 2947, 2949, 2964, 2975, 2978, 2985, 2988, 2990, 2991, 3002, 3003, 3014, 3015, 3020, 3021, 3034, 3035, 3040, 3044, 3045, 3052, 3069, 3070, 3083, 3104, 3105, 3108, 3109, 3112, 3114, 3117, 3121, 3123, 3131, 3141, 3147, 3154, 3155, 3156, 3157, 3166, 3167, 3175, 3179, 3183, 3187, 3190, 3191, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3206, 3207, 3211, 3215, 3217, 3228, 3229, 3234, 3235, 3238, 3239, 3242, 3243, 3246, 3247, 3248, 3249, 3250, 3251, 3254, 3255, 3265, 3268, 3269, 3270, 3276, 3277, 3282, 3284, 3285, 3286, 3287, 3288, 3289, 
3290, 3291, 3292, 3298, 3299, 3300, 3301, 3306, 3307, 3310, 3327, 3332, 3333, 3340, 3341, 3347, 3350, 3351, 3364, 3365, 3370, 3371, 3385, 3394, 3395, 3396, 3397, 3400, 3402, 3405, 3409, 3410, 3416, 3417, 3419, 3421, 3424, 3425, 3426, 3427, 3428, 3429, 3430, 3431, 3432, 3433, 3434, 3435, 3437, 3442, 3443, 3447, 3448, 3449, 3450, 3451, 3456, 3457, 3458, 3459, 3463, 3464, 3465, 3467, 3473, 3474, 3475, 3477, 3479, 3480, 3481, 3483, 3487, 3489, 3492, 3493, 3495, 3496, 3497, 3500, 3501, 3504, 3505, 3506, 3507, 3522, 3523, 3527, 3528, 3533, 3542, 3546, 3547, 3549, 3552, 3553, 3560, 3561, 3564, 3565, 3567, 3576, 3578, 3579, 3581, 3589, 3593, 3594, 3595, 3596, 3597, 3598, 3608, 3609, 3614, 3620, 3621, 3628, 3629, 3638, 3639, 3643, 3644, 3645, 3646, 3647, 3648, 3649, 3650, 3651, 3654, 3655, 3656, 3657, 3660, 3661, 3662, 3663, 3666, 3667, 3668, 3669, 3670, 3671, 3674, 3675, 3676, 3677, 3678, 3679, 3686, 3687, 3694, 3697, 3702, 3704, 3768, 3778, 3782, 3805, 3807, 3808, 3833, 3834, 3839, 3840, 3842, 3845, 3846, 3853, 3869, 3871, 3872, 3875, 3876, 3878, 3879, 3880, 3887, 3901, 3906, 3921, 3922, 3925, 3926, 3927, 3928, 3935, 3936, 3941, 3947, 3948, 3969, 3970, 4003, 4025, 4061, 4065, 4071, 4073, 4077, 4080, 4082, 4084, 4088, 4106, 4125, 4129, 4130, 4179, 4187, 4188, 4195, 4200, 4211, 4212, 4221, 4223, 4224, 4295, 4296, 4297, 4298, 4301, 4302, 4303, 4309, 4310, 4312, 4313, 4315, 4316, 4317, 4318, 4319, 4325, 4326, 4337, 4338, 4340, 4342, 4344, 4347, 4355, 4356, 4359, 4360, 4365, 4366, 4367, 4368, 4369, 4413, 4414, 4415, 4416, 4417, 4418, 4422, 4423, 4429, 4430, 4431, 4432, 4437, 4438, 4439, 4440, 4441, 4442, 4447, 4448, 4449, 4450, 4453, 4454, 4455, 4456, 4462, 4477, 4478, 4486, 4487, 4492, 4493, 4496, 4497, 4498, 4499, 4516, 4517, 4531, 4545, 4585, 4588, 4595, 4600, 4601, 4612, 4625, 4626, 4648, 4651, 4652, 4653, 4656, 4657, 4661, 4664, 4665, 4667, 4674, 4675, 4676, 4680, 4681, 4689, 4701, 4706, 4707, 4710, 4711, 4719, 4721, 4724, 4725, 4726, 4727, 4728, 4729, 4733, 4736, 4737, 
4739, 4744, 4745, 4762, 4763, 4789, 4794, 4812, 4814, 4817, 4820, 4822, 4827, 4865, 4866, 4867, 4874, 4896, 4901, 4903, 4905, 4909, 4910, 4913, 4914, 4916, 4918, 4931, 4932, 4955, 4959, 4960, 4964, 4965, 4980, 4985, 4986, 4994, 4997, 4998, 4999, 5000, 5001, 5002, 5005, 5006, 5007, 5008, 5035, 5043, 5044, 5074, 5103, 5104, 5105, 5106, 5109, 5115, 5116, 5117, 5118, 5132, 5135, 5139, 5140, 5152, 5155, 5156, 5162, 5196, 5219, 5276, 5277, 5278, 5279, 5280, 5283, 5307, 5308, 5313, 5317, 5318, 5319, 5320, 5323, 5324, 5353, 5354, 5359, 5360, 5361, 5362, 5368, 5369, 5370, 5373, 5374, 5375, 5384, 5387, 5388, 5392, 5402, 5403, 5404, 5408, 5410, 5442, 5443, 5444, 5445, 5446, 5447, 5448, 5451, 5452, 5453, 5454, 5459, 5460, 5462, 5463, 5464, 5465, 5466, 5467, 5468, 5477, 5478, 5503, 5504, 5506, 5525, 5526, 5545, 5546, 5552, 5558, 5567, 5582, 5586, 5587, 5596, 5607, 5613, 5619, 5625, 5627, 5629, 5654, 5671, 5674, 5675, 5684, 5685, 5694, 5695, 5702, 5707, 5724, 5725, 5732, 5733, 5744, 5745, 5746, 5747, 5752, 5753, 5762, 5763, 5778, 5779, 5780, 5781, 5796, 5797, 5802, 5803, 5806, 5807, 5826, 5832, 5833, 5835, 5838, 5839, 5840, 5841, 5842, 5843, 5847, 5851, 5852, 5861, 5863, 5864, 5865, 5866, 5870, 5889, 5890, 5909, 5910, 5911, 5912, 5913, 5914, 5920, 5927, 5928, 5931, 5932, 5953, 5955, 5966, 5975, 5976, 5977, 5978, 5982, 5985, 5994, 6005, 6006, 6012, 6036, 6039, 6040, 6069, 6070, 6075, 6076, 6080, 6091, 6099, 6100, 6109, 6110, 6135, 6148, 6153, 6155, 6156, 6159, 6160, 6163, 6167, 6179, 6186, 6187, 6198, 6199, 6207, 6208, 6209, 6210, 6214, 6215, 6222, 6224, 6230, 6231, 6232, 6233, 6236, 6237, 6242, 6254, 6255, 6258, 6259, 6262, 6263, 6266, 6267, 6275, 6284, 6285, 6292, 6293, 6296, 6297, 6298, 6299, 6310, 6311, 6324, 6338, 6343, 6349, 6370, 6386, 6390, 6435, 6436, 6465, 6485, 6486, 6487, 6488, 6501, 6509, 6510, 6519, 6520, 6524, 6525, 6526, 6534, 6549, 6550, 6552, 6553, 6554, 6555, 6556, 6561, 6562, 6563, 6564, 6575, 6576, 6594, 6618, 6619, 6633, 6645, 6679, 6701, 6702, 6706, 6708, 
6711, 6713, 6714, 6715, 6716, 6717, 6720, 6723, 6724, 6725, 6727, 6729, 6741, 6742, 6747, 6748, 6757, 6758, 6763, 6764, 6765, 6766, 6767, 6768, 6769, 6770, 6790, 6796, 6799, 6801, 6802, 6803, 6804, 6805, 6806, 6807, 6812, 6814, 6820, 6821, 6822, 6825, 6826, 6836, 6843, 6844, 6845, 6846, 6853, 6861, 6863, 6864, 6877, 6878, 6880, 6884, 6887, 6888, 6922, 6927, 6939, 6945, 6946, 6947, 6948, 6956]

import json
from tools import *

# Input locations. These are absolute paths on the author's machine and must be
# adjusted before running elsewhere.
# - processed_dataset_path: preprocessed training examples (JSON list).
# - pred_path: first-round predicted SQL, one query per line.
# - db_path: root directory holding one "<db_id>/<db_id>.sqlite" file per database.
processed_dataset_path="/home/baaiks/cf/pycharm/SQLFrameWorkSpiderNew/generate_datasets/preprocessed_data_train2.json"
pred_path = "/home/baaiks/cf/pycharm/SQLFrameWorkSpiderNew/intermediate_datasets/first_round_train.sql"
db_path="/home/baaiks/cf/pycharm/spider-master/spider/database"

# Instruction-style prompt template for the SQL-correction task. It is filled
# with str.format using the placeholders {schema}, {foreign_keys}, {question},
# {SQL} and {result}; the formatted text becomes the 'input' field of each
# training example. The literal is emitted verbatim, so do not reflow it.
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
This represents the SQLite SQL query that has been generated in response to the given question, along with the resulting outcome after executing the query.
Please judge its correctness based on the execution result and the explanation for the question.
If it's incorrect, output the correct sqlite SQL query; otherwise, output the original sqlite SQL query.

### Input:
### Sqlite SQL tables, with their properties:
#
{schema}
#
{foreign_keys}
#
### Question: {question}
### SQLite SQL query: {SQL}
### Run results: {result}

### Response:
"""

# Gold SQL and predicted SQL, kept index-aligned: both lists only contain
# entries for examples whose dataset index is not in except_ids.
gold = []
pred = []

# Use a context manager so the dataset file handle is closed as soon as the
# JSON has been parsed instead of lingering until interpreter exit.
with open(processed_dataset_path) as dataset_file:
    data = json.load(dataset_file)

for i, d in enumerate(data):
    if i in except_ids:
        continue
    normal_sql_tokens = d['norm_sql'].split()
    # Set for O(1) lookups; only membership matters, not order or multiplicity.
    skeleton_tokens = set(d['sql_skeleton'].split())
    # Upper-case every whitespace-separated token of the normalised SQL that
    # also occurs in the skeleton (presumably the SQL keywords -- verify
    # against the preprocessing step); all other tokens are left untouched.
    gold.append(' '.join(
        token.upper() if token in skeleton_tokens else token
        for token in normal_sql_tokens
    ))

# Collect the predicted SQL, one query per line of the prediction file.
# Lines whose 0-based line number is listed in except_ids are dropped, and
# only the newline characters at the ends of each kept line are stripped.
with open(pred_path, 'r') as pred_file:
    pred.extend(
        raw_line.strip('\n')
        for line_no, raw_line in enumerate(pred_file)
        if line_no not in except_ids
    )

# Build the fine-tuning examples: each one pairs a filled-in prompt (schema,
# foreign keys, question, predicted SQL and its execution result) with the SQL
# the model should output, then dump them all to train_data.json.
gold_data = []
# Hoisted into a set: membership is tested once per kept example, and the
# original list has well over a thousand entries.
error_id_set = set(error_ids)
# idx counts kept (non-excepted) examples and indexes pred/gold, which were
# filtered with the same except_ids. Note that error_ids is matched against
# idx, not against the raw dataset index i.
idx = 0
for i in trange(len(data)):
    if i in except_ids:
        continue
    example = data[i]
    p = pred[idx]
    g = gold[idx]
    question = example['question']
    foreign_keys = generate_foreign_key(example)
    schema = generate_schema(example)
    db_id = example['db_id']
    db_dir = f'{db_path}/{db_id}/{db_id}.sqlite'
    result, flag = new_run_sql(db_dir, p)
    # Only predictions for which new_run_sql reports flag == True are kept
    # (presumably "the query executed successfully" -- confirm in tools).
    if flag is True:
        gold_data.append({
            'input': prompt.format(schema=schema,
                                   foreign_keys=foreign_keys,
                                   question=question,
                                   SQL=p,
                                   result=result),
            # Flagged predictions are corrected to the gold SQL; all others
            # teach the model to repeat the prediction unchanged.
            'output': g if idx in error_id_set else p,
        })
    # Advance for every non-excepted example, whether or not it was kept.
    idx += 1

with open('train_data.json', 'w') as f:
    json.dump(gold_data, f)