import sys

import math
import os.path
import random
from dataclasses import dataclass
from typing import List, Tuple

import datasets
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding, DataCollatorForSeq2Seq, AutoTokenizer
from transformers import PreTrainedTokenizer, BatchEncoding


class DatasetForEmbedding(Dataset):
    """Map-style dataset that wraps raw passages in a fixed embedding prompt.

    Every passage is tokenized without special tokens and then placed between
    a tokenized prefix (an opening double quote) and a tokenized instruction
    suffix that ends with eight ``<sN>`` placeholder tags.

    Args:
        dataset: Indexable collection supporting ``len()``. ``dataset[i]`` may
            be a single string or a list of strings.
        tokenizer_path: Name or path handed to ``AutoTokenizer.from_pretrained``.
        max_len: Total token budget per sequence, prefix and suffix included.
        type: ``'corpus'`` selects the "summarize" suffix (``<s1>``-``<s8>``);
            any other value selects the "predict" suffix (``<s9>``-``<s16>``).
    """

    def __init__(
            self,
            dataset,
            tokenizer_path: str,
            max_len: int = 256,
            type: str = 'corpus'
    ):
        self.dataset = dataset
        self.total_len = len(dataset)
        self.max_length = max_len

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.prefix = '"'
        # The prefix is the only piece tokenized with add_special_tokens=True,
        # so whatever special tokens the tokenizer adds are carried by it.
        self.prefix_ids = self.tokenizer(self.prefix, return_tensors=None, add_special_tokens=True)['input_ids']
        if type == 'corpus':
            self.suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
        else:
            self.suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
        self.suffix_ids = self.tokenizer(self.suffix, return_tensors=None, add_special_tokens=False)['input_ids']
        # Tokens left for the passage itself once the prompt is accounted for.
        # NOTE(review): this is not validated; a ``max_len`` smaller than the
        # prompt yields a non-positive truncation length.
        self.max_len = self.max_length - len(self.prefix_ids) - len(self.suffix_ids)

    def __len__(self):
        """Return the number of entries in the wrapped dataset."""
        return self.total_len

    def __getitem__(self, item):
        """Tokenize ``self.dataset[item]`` and wrap it in the prompt.

        The passage is truncated to ``self.max_len`` tokens, then
        ``self.prefix_ids`` and ``self.suffix_ids`` are attached and the
        attention mask is rebuilt as all ones over the new length.

        Args:
            item: Index forwarded unchanged to ``self.dataset``.

        Returns:
            The tokenizer output with ``input_ids`` and ``attention_mask``
            replaced. For a single string both are flat lists; for a list of
            strings both are lists with one entry per passage.
        """
        passage = self.dataset[item]
        passage_inputs = self.tokenizer(passage,
                                        return_tensors=None,
                                        max_length=self.max_len,
                                        truncation=True,
                                        add_special_tokens=False)

        if isinstance(passage, str):
            # A single text yields a flat list of ids, so it must be wrapped
            # as a whole rather than element by element.
            wrapped = self.prefix_ids + passage_inputs['input_ids'] + self.suffix_ids
            passage_inputs['input_ids'] = wrapped
            passage_inputs['attention_mask'] = [1] * len(wrapped)
            return passage_inputs

        # Batched input: one id list per passage.
        wrapped_batch = [
            self.prefix_ids + ids + self.suffix_ids
            for ids in passage_inputs['input_ids']
        ]
        passage_inputs['input_ids'] = wrapped_batch
        passage_inputs['attention_mask'] = [[1] * len(ids) for ids in wrapped_batch]
        return passage_inputs

class collater:
    """Callable that pads a batch of tokenized features via ``tokenizer.pad``.

    Args:
        tokenizer: Object exposing a ``pad`` method (e.g. a Hugging Face
            tokenizer).
        max_len: Value forwarded to ``pad`` as ``max_length``.
    """

    def __init__(self, tokenizer, max_len):
        self.max_len = max_len
        self.tokenizer = tokenizer

    def __call__(self, data):
        """Pad ``data`` and return the result of ``tokenizer.pad``.

        The call always requests ``padding=True``, lengths rounded up to a
        multiple of 8, and PyTorch tensors (``return_tensors='pt'``). How
        ``max_length`` interacts with ``padding=True`` is decided by the
        tokenizer implementation.
        """
        pad_options = {
            'padding': True,
            'max_length': self.max_len,
            'pad_to_multiple_of': 8,
            'return_tensors': 'pt',
        }
        return self.tokenizer.pad(data, **pad_options)