import torch
from peft import PeftModel
from sentence_transformers import SentenceTransformer
from torch import Tensor, device
import torch.multiprocessing as mp
from typing import List, Dict, Union, Tuple
import numpy as np
import logging
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from transformers import AutoModelForCausalLM, AutoTokenizer

from data import DatasetForEmbedding, collater

logger = logging.getLogger(__name__)


class EvaluateLlamaModelSearch:
    def __init__(self,
                 base_model_name: str = None,
                 tokenizer_name: str = None,
                 sep: str = " ",
                 max_length: int = 512,
                 batch_size: int = 6,
                 cache_dir: str = '/share/shared_models',
                 **kwargs):
        """
        Load the causal LM and its tokenizer used to produce embeddings.

        :param base_model_name: model name or path passed to ``AutoModelForCausalLM.from_pretrained``
        :param tokenizer_name: tokenizer name or path passed to ``AutoTokenizer.from_pretrained``;
            when omitted, ``base_model_name`` is used
        :param sep: separator inserted between a document's title and its text
        :param max_length: stored as ``self.max_length``
        :param batch_size: stored as ``self.batch_size``
        :param cache_dir: directory handed to both ``from_pretrained`` calls as ``cache_dir``
        :param kwargs: accepted and ignored

        :raises ValueError: if ``base_model_name`` is not given
        """
        if base_model_name is None:
            raise ValueError("base_model_name must be provided")
        if tokenizer_name is None:
            # Without an explicit tokenizer, use the one published with the model.
            tokenizer_name = base_model_name

        self.sep = sep
        self.max_length = max_length
        self.batch_size = batch_size

        self.model = AutoModelForCausalLM.from_pretrained(base_model_name, cache_dir=cache_dir)
        self.model.eval()
        # Cast the weights to float16.
        self.model.half()

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
        self.tokenizer_name = tokenizer_name

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Read by encode_queries / encode_corpus to pick the 'queries' or 'corpus' input type.
        self.mode = 'q2p'

    def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, object]:
        """
        Spawn one worker process per target device and return the pool handles.

        :param target_devices: device names, one worker per entry. When ``None``, every
            visible CUDA device is used, or a single ``'cpu'`` worker if CUDA is unavailable.

        :return: dict with keys ``'input'`` (queue), ``'output'`` (queue) and ``'processes'`` (list)
        """
        if target_devices is None:
            if torch.cuda.is_available():
                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
            else:
                target_devices = ['cpu']

        logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
        ctx = mp.get_context('spawn')
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for process_id, device_name in enumerate(target_devices):
            # NOTE(review): the worker is called with (process_id, device_name, self, input_queue,
            # output_queue) -- confirm the installed sentence-transformers worker accepts this.
            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker,
                            args=(process_id, device_name, self, input_queue, output_queue), daemon=True)
            p.start()
            processes.append(p)

        return {'input': input_queue, 'output': output_queue, 'processes': processes}

    def stop_multi_process_pool(self, pool: Dict[str, object]):
        """
        Drain the output queue, then shut down every worker and both queues.

        One item per worker is read from ``pool['output']`` (blocking) before any
        process is terminated.
        """
        workers = pool['processes']
        results = pool['output']

        # Read exactly one result per worker before tearing anything down.
        for _ in workers:
            results.get()

        # Terminate all workers first, then join and close each of them.
        for worker in workers:
            worker.terminate()
        for worker in workers:
            worker.join()
            worker.close()

        pool['input'].close()
        pool['output'].close()

    def encode_queries(self, queries: List[str], batch_size: int = 8,
                       **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """
        Embed queries via ``encode_process``.

        The input type is ``'queries'`` when ``self.mode`` is ``'q2p'`` or ``'q2q'``
        and ``'corpus'`` for any other mode.
        """
        input_type = 'queries' if self.mode in ('q2p', 'q2q') else 'corpus'
        return self.encode_process(queries, batch_size=batch_size, type=input_type, **kwargs)

    # def encode(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> \
    def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8,
                      **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """
        Flatten a corpus into plain strings and embed it via ``encode_process``.

        :param corpus: one of
            - a dict of columns with a ``"text"`` list and an optional ``"title"`` list,
            - a sequence of dicts with ``"text"`` and an optional ``"title"``,
            - a sequence that is passed through unchanged (also used for an empty corpus).
        :param batch_size: forwarded to ``encode_process``
        :param kwargs: forwarded to ``encode_process``

        When a title is present it is joined to the text with ``self.sep``; the result is stripped.
        The input type is ``'corpus'`` when ``self.mode`` is ``'q2p'`` or ``'p2p'``, else ``'queries'``.
        """
        if isinstance(corpus, dict):
            texts = corpus["text"]
            if "title" in corpus:
                titles = corpus["title"]
                sentences = [(titles[i] + self.sep + text).strip() for i, text in enumerate(texts)]
            else:
                sentences = [text.strip() for text in texts]
        elif len(corpus) > 0 and isinstance(corpus[0], dict):
            # The length guard keeps an empty corpus from raising IndexError on corpus[0].
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                         for doc in corpus]
        else:
            sentences = corpus

        input_type = 'corpus' if self.mode in ('q2p', 'p2p') else 'queries'
        return self.encode_process(sentences, batch_size=batch_size, type=input_type, **kwargs)

    ## Encoding corpus in parallel
    def encode_corpus_parallel(self, corpus: Union[List[Dict[str, str]], Dataset], pool: Dict[str, object],
                               batch_size: int = 8, chunk_id: int = None, **kwargs):
        """
        Flatten a corpus chunk into plain strings and enqueue it for the worker pool.

        :param corpus: a dict of columns (``"text"`` list, optional ``"title"`` list), a sequence of
            dicts with ``"text"`` and optional ``"title"``, or a sequence passed through unchanged
            (also used for an empty chunk)
        :param pool: dict with ``'input'`` and ``'output'`` queues and a ``'processes'`` list
        :param batch_size: placed on the input queue alongside the sentences
        :param chunk_id: index of this chunk; once it reaches the number of processes, one item
            is read (blocking) from the output queue before the chunk is enqueued
        :param kwargs: accepted and ignored
        """
        if isinstance(corpus, dict):
            texts = corpus["text"]
            if "title" in corpus:
                titles = corpus["title"]
                sentences = [(titles[i] + self.sep + text).strip() for i, text in enumerate(texts)]
            else:
                sentences = [text.strip() for text in texts]
        elif len(corpus) > 0 and isinstance(corpus[0], dict):
            # The length guard keeps an empty chunk from raising IndexError on corpus[0].
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                         for doc in corpus]
        else:
            sentences = corpus

        if chunk_id is not None and chunk_id >= len(pool['processes']):
            # Take one finished result off the output queue; the value itself is discarded.
            pool['output'].get()

        pool['input'].put([chunk_id, batch_size, sentences])

    def encode_process(self, sentences: Union[str, List[str]],
                       type: str = 'corpus',
                       batch_size: int = 32,
                       show_progress_bar: bool = None,
                       output_value: str = 'sentence_embedding',
                       convert_to_numpy: bool = True,
                       convert_to_tensor: bool = False,
                       device: str = None,
                       normalize_embeddings: bool = True) -> Union[List[Tensor], np.ndarray, Tensor]:
        """
        Computes sentence embeddings

        Each embedding is the mean over the last 8 positions of the model's final hidden state.

        :param sentences: the sentences to embed
        :param type: input type forwarded to ``DatasetForEmbedding`` ('corpus' or 'queries' in this class)
        :param batch_size: accepted for interface compatibility; batching uses ``self.batch_size``
        :param show_progress_bar: Output a progress bar when encode sentences. Defaults to on when the logger level is INFO or DEBUG
        :param output_value: Any value other than 'sentence_embedding' disables both numpy and tensor conversion, so a list of tensors is returned
        :param convert_to_numpy: If true, the output is a numpy matrix. Else, it is a list of pytorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
        :param device: Which torch.device to use for the computation. Defaults to ``self.device``
        :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.

        :return:
           By default, a numpy matrix is returned. If convert_to_tensor, a stacked tensor is returned. If both conversions are off, a list of tensors is returned. A single non-list input yields a single embedding.
        """
        self.model.eval()
        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)

        if convert_to_tensor:
            convert_to_numpy = False

        if output_value != 'sentence_embedding':
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):  # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        if device is None:
            device = self.device

        self.model = self.model.to(device)

        all_embeddings = []
        # Process the longest inputs first; the original order is restored after the loop.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        # Use the configured limit rather than a hard-coded 512 (the constructor default is still 512).
        max_length = self.max_length
        dataset = DatasetForEmbedding(sentences_sorted, self.tokenizer_name, max_length, type)
        collate = collater(self.tokenizer, max_length)
        for start_index in trange(0, len(sentences), self.batch_size, desc="Batches", disable=not show_progress_bar):
            sentences_batch = dataset[start_index:start_index + self.batch_size]
            features = collate(sentences_batch)
            features = features.to(device)

            with torch.no_grad():
                out_features = self.model(**features, return_dict=True, output_hidden_states=True)

                # Final hidden state, last 8 positions, averaged over the sequence axis.
                # NOTE(review): presumably relies on how the data module lays out each prompt -- confirm.
                embeddings = out_features.hidden_states[-1][:, -8:, :]
                embeddings = torch.mean(embeddings, dim=1)
                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
                if convert_to_numpy:
                    embeddings = embeddings.cpu()

                all_embeddings.extend(embeddings)

        # Undo the length sort so outputs line up with the caller's input order.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Return a length measure for one input.

        - dict: length of its first value
        - object without ``__len__``: 1
        - empty input, or a sequence whose first element is an int: ``len(text)``
        - anything else: the summed length of its elements
        """
        if isinstance(text, dict):
            first_value = next(iter(text.values()))
            return len(first_value)

        if not hasattr(text, '__len__'):
            return 1

        size = len(text)
        if size == 0 or isinstance(text[0], int):
            return size

        return sum(len(part) for part in text)


def batch_to_device(batch, target_device: device):
    """
    Move every tensor value of a batch mapping to ``target_device``.

    Non-tensor values are left as they are. The mapping is updated in place and returned.
    """
    for name, value in batch.items():
        if isinstance(value, Tensor):
            batch[name] = value.to(target_device)
    return batch