import asyncio
import os, time
from .prompts import *
from .langchain_memory import CustomMemory, ChangeNamesMemory # BotConversationSummaryBufferWindowMemory, TestMemory
from ..utilities.messages import Message

from langchain import PromptTemplate
from langchain import LLMChain, ConversationChain
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory, CombinedMemory, ConversationSummaryMemory

from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain
from typing import Any, Dict, List, Optional, Union

from langchain.document_loaders import TextLoader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings  # was SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma

from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser, ZeroShotAgent
from langchain.schema import AgentAction, AgentFinish
from langchain.schema import AIMessage, HumanMessage, SystemMessage, ChatMessage
from langchain.utilities import OpenWeatherMapAPIWrapper, SearxSearchWrapper, PythonREPL
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper

import humanize
from datetime import datetime, timedelta

from termcolor import colored
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class RoleplayChain(Chain):
    """Chain that forwards non-persona inputs to a wrapped ``LLMChain``.

    The persona-related keys listed in ``input_keys`` are required inputs of
    this chain but are stripped before the call. Every other input key is
    passed as a keyword argument to ``llm_chain.predict``. The prediction is
    returned under ``output_key``.
    """

    llm_chain: LLMChain

    character_name: str
    persona: str
    scenario: str
    ai_name_chat: str
    human_name_chat: str

    output_key: str = "output_text"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Keys this chain requires; they are NOT forwarded to ``llm_chain``."""
        return [
            "character_name",
            "persona",
            "scenario",
            "ai_name_chat",
            "human_name_chat",
            "llm_chain",
        ]

    @property
    def output_keys(self) -> List[str]:
        """Single output key holding the prediction text."""
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run ``llm_chain.predict`` with every input not listed in ``input_keys``."""
        reserved = self.input_keys
        forwarded = {}
        for key, value in inputs.items():
            if key in reserved:
                continue
            forwarded[key] = value
        prediction = self.llm_chain.predict(**forwarded)
        return {self.output_key: prediction}


class CustomOutputParser(AgentOutputParser):
    """Parse ReAct-style LLM output into an ``AgentAction`` or ``AgentFinish``.

    Accepted forms, checked in this order:

    * Any text containing ``Final Answer:``. This yields an ``AgentFinish``
      whose ``output`` is the stripped text after the last occurrence.
    * ``Action: <tool>`` followed on the next line by
      ``Action Input: <input>``. Digits after "Action"/"Input" are allowed,
      e.g. ``Action 1:``.
    * Fallback: ``Action: <tool>`` followed by a quoted input.
    """

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        """Convert one LLM completion into the agent's next step.

        Args:
            llm_output: Raw completion text from the agent LLM.

        Returns:
            ``AgentFinish`` if a final answer is present. Otherwise an
            ``AgentAction`` with the stripped tool name and the tool input
            (surrounding spaces and double quotes removed).

        Raises:
            ValueError: If neither action pattern matches.
        """
        # `re` is not imported at module level. Import it here so parsing does
        # not depend on it leaking in through `from .prompts import *`.
        import re

        # Check if agent should finish
        if "Final Answer:" in llm_output:
            return AgentFinish(
                # Return values is generally always a dictionary with a single `output` key
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )

        # Parse out the action and action input
        match = re.search(
            r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)",
            llm_output,
            re.DOTALL,
        )
        if not match:
            # Some models put the input in quotes right after the tool name.
            match = re.search(
                r"Action\s*\d*\s*:(.*?)[\s]*[\"\'](.*)[\"\']",
                llm_output,
                re.DOTALL,
            )
        if not match:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")

        action = match.group(1).strip()
        action_input = match.group(2).strip(" ").strip('"')
        return AgentAction(tool=action, tool_input=action_input, log=llm_output)


class AI(object):
    """LLM backend for a chat bot: chat/summary models, vector store and per-room memory.

    One ``CustomMemory`` is kept per room in ``self.rooms[room_id]["memory"]``.
    Room state that must survive a restart (moving summary, ids of
    already-summarized messages, diary entries) is mirrored into the owning
    bot's ``self.bot.rooms[room_id]`` dict. ``sleep()`` saves that dict via
    ``self.bot.write_conf2``.
    """

    def __init__(self, bot, text_wrapper, image_wrapper, memory_path: str,
                 endpoint_url: str = "http://172.16.33.10:5001/api/latest/generate"):
        """Create the LLM clients, the embeddings and the Chroma vector store.

        Args:
            bot: Owning bot. ``name`` and ``temperature`` are read here; other
                methods read ``rooms``, ``persona``, ``scenario``,
                ``example_dialogue`` and ``connection``.
            text_wrapper: Stored as ``self.text_wrapper``; not used in this class.
            image_wrapper: Stored as ``self.image_wrapper``; not used in this class.
            memory_path: Directory under which the ``chroma-db`` store is persisted.
            endpoint_url: KoboldCpp generate endpoint shared by both LLM clients.
        """
        self.name = bot.name
        self.bot = bot
        self.memory_path = memory_path
        self.rooms = {}
        self.max_context = 4096
        self.endpoint_url = endpoint_url

        from ..wrappers.langchain_koboldcpp import KoboldCpp
        self.llm_chat = KoboldCpp(temperature=self.bot.temperature, endpoint_url=endpoint_url, max_context=self.max_context, stop=['<|endoftext|>'], verbose=True)
        self.llm_summary = KoboldCpp(temperature=0.7, repeat_penalty=1.15, top_k=20, top_p=0.9, endpoint_url=endpoint_url, max_context=self.max_context, stop=['<|endoftext|>'], max_tokens=512, verbose=True)
        # The model-name prefix selects the prompt template in generate_roleplay().
        self.llm_chat_model = "pygmalion-7b"
        self.llm_summary_model = "vicuna-13b"
        self.text_wrapper = text_wrapper
        self.image_wrapper = image_wrapper
        self.embeddings = HuggingFaceEmbeddings()
        self.db = Chroma(persist_directory=os.path.join(self.memory_path, 'chroma-db'), embedding_function=self.embeddings)

    def get_memory(self, room_id, human_prefix=None):
        """Return the conversation memory for ``room_id``, creating it on first use.

        A new memory is seeded from ``self.bot.rooms[room_id]``: 'moving_summary'
        (default "No previous events.") and 'last_message_ids_summarized'
        (default empty list). For an already existing memory, a truthy
        ``human_prefix`` replaces the stored prefix.

        Args:
            room_id: Room identifier. It must be a key of ``self.bot.rooms``
                when the memory is first created.
            human_prefix: Speaker label for human turns. A new memory falls
                back to "Human" when this is falsy.
        """
        if room_id in self.rooms:
            memory = self.rooms[room_id]["memory"]
            if human_prefix:
                memory.human_prefix = human_prefix
            return memory

        room_state = self.bot.rooms[room_id]
        memory = CustomMemory(
            memory_key="chat_history",
            input_key="input",
            human_prefix=human_prefix or "Human",
            ai_prefix=self.bot.name,
            llm=self.llm_summary,
            summary_prompt=prompt_progressive_summary,
            moving_summary_buffer=room_state.get("moving_summary", "No previous events."),
            # max_len keeps 800 of max_context free; min_len is 10% of it.
            max_len=int(self.max_context - 800),
            min_len=int(0.1 * self.max_context),
            last_message_ids_summarized=room_state.get("last_message_ids_summarized", []),
        )
        # Register only after construction, so a failure does not leave a
        # half-initialised room entry without a "memory" key behind.
        self.rooms[room_id] = {"memory": memory}
        return memory

    async def add_chat_message(self, message):
        """Append ``message`` to its room's history unless it was already summarized.

        The room is taken from ``message.additional_kwargs['room_id']``. The
        message is appended to both ``chat_memory`` and ``chat_memory_day``.
        Messages whose 'event_id' appears in
        ``last_message_ids_summarized`` are skipped.
        """
        kwargs = message.additional_kwargs
        conversation_memory = self.get_memory(kwargs['room_id'])
        if 'event_id' in kwargs and kwargs['event_id'] in conversation_memory.last_message_ids_summarized:
            # don't add already summarized messages
            return
        conversation_memory.chat_memory.messages.append(message)
        conversation_memory.chat_memory_day.messages.append(message)

    async def clear(self, room_id):
        """Clear the conversation memory of ``room_id``."""
        conversation_memory = self.get_memory(room_id)
        conversation_memory.clear()

    async def ingest_textfile(self, filename, category):
        """Split a text file into chunks and add them to the Chroma store.

        The loaded document gets 'indexed' (local timestamp) and 'category'
        metadata, which the splitter copies to each chunk. Each chunk also
        gets 'part' = "<index>/<total>". The store is persisted afterwards.

        NOTE(review): file loading and embedding are synchronous calls and
        block the event loop while they run.
        """
        documents = TextLoader(filename).load()
        documents[0].metadata['indexed'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        documents[0].metadata['category'] = category

        # Chunk by characters: the embeddings come from sentence-transformers,
        # not from the chat model, so its tokenizer is not the right measure.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,
            chunk_overlap=80,
            length_function=len,
        )
        docs = text_splitter.split_documents(documents)

        total = len(docs)
        for index, doc in enumerate(docs):
            doc.metadata['part'] = f"{index}/{total}"

        print(f"Indexing {total} documents")
        self.db.add_texts(
            texts=[doc.page_content for doc in docs],
            metadatas=[doc.metadata for doc in docs],
            ids=None,
        )
        self.db.persist()

    async def search_vectordb(self, query, category):
        """Run a similarity search on the vector store and print the hits.

        Args:
            query: Text to search for.
            category: If truthy, restrict results to documents whose
                'category' metadata equals it.

        Returns:
            The ``(document, score)`` pairs returned by Chroma.
        """
        # https://github.com/chroma-core/chroma/blob/main/examples/where_filtering.ipynb
        search_kwargs = {"filter": {"category": category}} if category else {}
        output_docs = self.db.similarity_search_with_score(query, **search_kwargs)
        print(query)
        print('###')
        for doc, score in output_docs:
            print("-" * 80)
            print("Score: ", score)
            print(doc)
            print("-" * 80)
        return output_docs

    async def generate(self, message, reply_fn, typing_fn):
        """Send ``message.content`` to the chat LLM verbatim and return the stripped reply.

        ``reply_fn`` and ``typing_fn`` are accepted but not used here.
        """
        chain = LLMChain(
            llm=self.llm_chat,
            prompt=PromptTemplate.from_template("{input}"),
        )
        output = await chain.arun(message.content)
        return output.strip()

    def _chat_prompt_for_model(self, user_name):
        """Return ``(prompt_template, ai_label, human_label)`` for ``self.llm_chat_model``."""
        model = self.llm_chat_model
        if model.startswith('vicuna'):
            return prompt_vicuna, "### Assistant", "### Human"
        if model.startswith('pygmalion'):
            return prompt_pygmalion, self.bot.name, "You"
        if model.startswith('koboldai'):
            return prompt_koboldai, self.bot.name, user_name
        return prompt_alpaca, self.bot.name, user_name

    def _build_roleplay_prompt(self, prompt_chat, conversation_memory, chat_ai_name, chat_human_name):
        """Fill the persona variables, plus summary/example dialogue if the template uses them."""
        prompt = prompt_chat.partial(
            ai_name=self.bot.name,
            persona=self.bot.persona,
            scenario=self.bot.scenario,
            human_name=chat_human_name,
            ai_name_chat=chat_ai_name,
        )
        if "summary" in prompt.input_variables:
            prompt = prompt.partial(summary=conversation_memory.moving_summary_buffer)
        if "example_dialogue" in prompt.input_variables:
            prompt = prompt.partial(
                example_dialogue=self.bot.example_dialogue.replace("{{user}}", chat_human_name)
            )
        return prompt

    async def generate_roleplay(self, message, reply_fn, typing_fn):
        """Generate an in-character reply to ``message``, send it and update the room memory.

        Steps:

        1. Choose the prompt template and speaker labels from ``self.llm_chat_model``.
        2. If the rendered prompt plus a 200-token margin exceeds
           ``max_context``, warn (log and ``reply_fn``) and prune the memory
           to ``min_len``.
        3. Run a ``ConversationChain``. Its memory renames the speakers to the
           template's labels.
        4. Clean placeholders and send the reply with ``reply_fn``.
        5. Save the exchange. If the moving summary exceeds 400 tokens,
           condense it. Mirror the summary state into ``self.bot.rooms``.

        Args:
            message: Incoming message. ``additional_kwargs`` must contain
                'user_name', 'room_id' and 'room_name'.
            reply_fn: Coroutine function taking the text to send. Its result
                must expose ``event_id`` and ``room_id``.
            typing_fn: Accepted but not used here.

        Returns:
            The cleaned reply text.
        """
        user_name = message.additional_kwargs['user_name']
        room_id = message.additional_kwargs['room_id']
        prompt_chat, chat_ai_name, chat_human_name = self._chat_prompt_for_model(user_name)

        conversation_memory = self.get_memory(room_id, chat_human_name)
        custom_memory = ChangeNamesMemory(
            memory=conversation_memory,
            replace_ai_chat_names={self.bot.name: chat_ai_name},
            replace_human_chat_names={user_name: chat_human_name},
        )

        prompt = self._build_roleplay_prompt(prompt_chat, conversation_memory, chat_ai_name, chat_human_name)
        prompt_len = self.llm_chat.get_num_tokens(
            prompt.format(chat_history=conversation_memory.buffer, input=message.content)
        )
        if prompt_len + 200 > self.max_context:
            logger.warning(f"Prompt too large. Estimated {prompt_len} tokens. Summarizing...")
            await reply_fn(f"<WARNING> Prompt too large. Estimated {prompt_len} tokens")
            await conversation_memory.prune_memory(conversation_memory.min_len)
            # NOTE(review): prune_memory (CustomMemory, not shown here) may fold
            # pruned turns into the moving summary. Rebuild the prompt so its
            # summary partial reflects the post-prune state.
            prompt = self._build_roleplay_prompt(prompt_chat, conversation_memory, chat_ai_name, chat_human_name)
            # A further fallback (condensing the moving summary itself via
            # self.summarize) was sketched earlier but never reached.

        chain = ConversationChain(
            llm=self.llm_chat,
            prompt=prompt,
            verbose=True,
            memory=custom_memory,
        )

        # Stop when the model starts writing the human's next turn, under the
        # template label and, if different, the real user name.
        stop = ['<|endoftext|>', f"\n{chat_human_name}:"]
        if chat_human_name != user_name:
            stop.append(f"\n{user_name}:")
        await asyncio.sleep(0)  # yield for matrix-nio
        output = await chain.arun({"input": message.content, "stop": stop})
        output = output.replace("<BOT>", self.bot.name).replace("<USER>", chat_human_name)
        output = output.replace("### Assistant", self.bot.name)
        output = output.replace(f"\n{self.bot.name}: ", " ")
        output = output.strip()

        # TODO: hand off to self.agent() when the reply contains
        # "*activates the neural uplink*".

        own_message_resp = await reply_fn(output)

        output_message = AIMessage(
            content=output,
            additional_kwargs={
                "timestamp": datetime.now().timestamp(),
                "user_name": self.bot.name,
                "event_id": own_message_resp.event_id,
                "user_id": self.bot.connection.user_id,
                "room_name": message.additional_kwargs['room_name'],
                "room_id": own_message_resp.room_id,
            }
        )

        await conversation_memory.asave_context(message, output_message)
        summary_len = self.llm_chat.get_num_tokens(conversation_memory.moving_summary_buffer)
        if summary_len > 400:
            logger.warning("Summary is getting too long. Refining...")
            conversation_memory.moving_summary_buffer = await self.summarize(conversation_memory.moving_summary_buffer)
            new_summary_len = self.llm_chat.get_num_tokens(conversation_memory.moving_summary_buffer)
            logger.info(f"Refined summary from {summary_len} tokens to {new_summary_len} tokens ({new_summary_len-summary_len} tokens)")
        self.bot.rooms[room_id]['moving_summary'] = conversation_memory.moving_summary_buffer
        self.bot.rooms[room_id]['last_message_ids_summarized'] = conversation_memory.last_message_ids_summarized

        return output

    async def summarize(self, text, map_prompt=prompt_summary2, combine_prompt=prompt_summary2):
        """Condense ``text`` with the summary LLM using a small map-reduce.

        For up to two rounds: the text is split into chunks of at most 1600
        chat-model tokens. If there is more than one chunk, each is
        summarized with ``map_prompt`` and the results are joined with
        newlines. The resulting text is finally passed through ``combine_prompt``.
        Empty or whitespace-only input goes straight to the combine step.
        """
        map_chain = LLMChain(llm=self.llm_summary, prompt=map_prompt, verbose=True)
        reduce_chain = LLMChain(llm=self.llm_summary, prompt=combine_prompt, verbose=True)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1600,
            chunk_overlap=80,
            length_function=self.llm_chat.get_num_tokens,
        )

        current = text
        for _round in range(2):
            chunks = text_splitter.split_text(current)
            if len(chunks) <= 1:
                # Zero chunks (blank text) keeps `current` unchanged instead of crashing.
                if chunks:
                    current = chunks[0]
                break
            partial_summaries = []
            for chunk in chunks:
                await asyncio.sleep(0)  # yield for matrix-nio
                partial_summaries.append(await map_chain.arun(chunk))
            current = "\n".join(partial_summaries)

        await asyncio.sleep(0)  # yield for matrix-nio
        # TODO: max_tokens and stop
        return await reduce_chain.arun(text=current)

    async def diary(self, room_id):
        """Write a diary entry for ``room_id`` from its ``buffer_day`` history.

        The day buffer is split into chunks of at most 1600 summary-model
        tokens, and each chunk is turned into a page with ``prompt_outline``.
        If the joined pages exceed 1400 tokens they are summarized. If the
        result still exceeds 1600 tokens, it is replaced by the room's
        moving summary.
        """
        diary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_outline, verbose=True)
        conversation_memory = self.get_memory(room_id)

        text_splitter = RecursiveCharacterTextSplitter(
            separators=["\n", " ", ""],
            chunk_size=1600,
            chunk_overlap=40,
            length_function=self.llm_summary.get_num_tokens,
        )
        docs = text_splitter.create_documents([conversation_memory.buffer_day])
        pages = []
        for page_number, doc in enumerate(docs, start=1):
            logger.info("Writing diary... page %d of %d.", page_number, len(docs))
            await asyncio.sleep(0)  # yield for matrix-nio
            pages.append(await diary_chain.apredict(text=doc.page_content, ai_name=self.bot.name))
        diary_entry = "\n".join(pages)

        if self.llm_summary.get_num_tokens(diary_entry) > 1400:
            logger.info("Summarizing diary entry.")
            await asyncio.sleep(0)
            diary_entry = await self.summarize(diary_entry)
        if self.llm_summary.get_num_tokens(diary_entry) > 1600:
            logger.warning("Diary entry too long. Discarding.")
            diary_entry = conversation_memory.moving_summary_buffer

        return diary_entry

    async def agent(self):
        """Experimental tool-using agent; work in progress.

        NOTE(review): ``summry_chain``, ``llm`` and ``memory`` are not defined
        in this file. Unless the ``.prompts`` star import provides them,
        calling this raises ``NameError``. ``prefix``, ``suffix`` and
        ``prompt_agent`` must also come from that import. The ``prompt``
        built below and ``output_parser`` are unused; ``prompt_agent`` is
        what reaches the chain. The final query is a hard-coded smoke test.
        """
        # NOTE(review): API key committed to source; it belongs in the
        # environment/config. setdefault at least lets an existing key win.
        os.environ.setdefault("OPENWEATHERMAP_API_KEY", "82452fdb0d1e0e805ac096db87914342")
        # Tools
        search = DuckDuckGoSearchAPIWrapper()
        weather = OpenWeatherMapAPIWrapper()
        search2 = SearxSearchWrapper(searx_host="https://search.mdosch.de")
        python_repl = PythonREPL()

        tools = [
            Tool(
                name="Search",
                func=search.run,
                description="useful for when you need to answer questions about current events"
            ),
            Tool(
                name="Searx Search",
                func=search2.run,
                description="useful for when you need to answer questions about current events"
            ),
            Tool(
                name="Weather",
                func=weather.run,
                description="Useful for fetching current weather information for a specified location. Input should be a location string (e.g. 'London,GB')."
            ),
            Tool(
                name="Summary",
                func=summry_chain.run,
                description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
            )
        ]

        prompt = ZeroShotAgent.create_prompt(
            tools=tools,
            prefix=prefix,
            suffix=suffix,
            input_variables=["input", "chat_history", "agent_scratchpad"]
        )

        output_parser = CustomOutputParser()

        # LLM chain consisting of the LLM and a prompt
        llm_chain = LLMChain(llm=llm, prompt=prompt_agent)

        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        # Alternative: LLMSingleActionAgent(llm_chain=llm_chain, output_parser=output_parser,
        #     stop=["\nObservation:"], allowed_tools=[tool.name for tool in tools], verbose=True)

        agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)

        await agent_executor.arun(input="How many people live in canada as of 2023?")

    async def sleep(self):
        """Run end-of-day maintenance for every room that has an active memory.

        For each room:

        * Append a "~~~~ <date> ~~~~" system message to the chat history. If
          the last message is already such a marker, it is replaced.
        * If the room has messages in ``chat_memory_day``, store a diary entry
          under ``self.bot.rooms[room_id]['diary'][<yesterday's date>]``.
        * Clear the per-day history and prune the memory to ``min_len``.

        Finally the bot's rooms are saved via ``self.bot.write_conf2``.
        """
        logger.info(f"{self.bot.name} sleeping now... running background tasks...")
        now = datetime.now()
        date_marker = f"~~~~ {now.strftime('%A, %B %d, %Y')} ~~~~"
        yesterday = (now - timedelta(days=1)).strftime('%Y-%m-%d')

        # Snapshot the room ids: other tasks may create rooms while we await.
        room_ids = list(self.rooms)

        # Write the date into every room's chat history.
        for room_id in room_ids:
            messages = self.get_memory(room_id).chat_memory.messages
            if messages and messages[-1].content.startswith('~~~~ '):
                # Nothing was said since the last marker; replace it.
                messages.pop()
            messages.append(SystemMessage(
                content=date_marker,
                additional_kwargs={
                    "timestamp": now.timestamp(),
                    "user_name": None,
                    "event_id": None,
                    "user_id": None,
                    "room_name": None,
                    "room_id": room_id,
                }
            ))

        # Summarize the last day into a diary entry, then reset each room's day.
        for room_id in room_ids:
            conversation_memory = self.get_memory(room_id)
            if len(conversation_memory.chat_memory_day.messages) > 0:
                room_state = self.bot.rooms[room_id]
                if "diary" not in room_state:
                    room_state['diary'] = {}
                room_state["diary"][yesterday] = await self.diary(room_id)
            conversation_memory.chat_memory_day.clear()
            await conversation_memory.prune_memory(conversation_memory.min_len)

        await self.bot.write_conf2(self.bot.rooms)
        logger.info(f"{self.bot.name} done sleeping and ready for the next day...")

    async def prime_llm(self, text):
        """Send ``text`` to the chat LLM with ``max_tokens=1`` and discard the result.

        NOTE(review): presumably used to warm the backend with a prompt.
        This is a synchronous LLM call inside a coroutine, so it blocks the
        event loop until the backend answers.
        """
        self.llm_chat(text, max_tokens=1)


def replace_all(text, dic):
    """Apply every ``old -> new`` substitution in ``dic`` to ``text``.

    Substitutions run one after another in the mapping's iteration order,
    so a later key can match text produced by an earlier replacement.

    Example:
        replace_all(self.bot.example_dialogue,
                    {"{{user}}": chat_human_name, "{{char}}": chat_ai_name})
    """
    result = text
    for needle, replacement in dic.items():
        result = result.replace(needle, replacement)
    return result