You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
158 lines
5.7 KiB
158 lines
5.7 KiB
import asyncio
|
|
import time
|
|
from .prompts import *
|
|
from .langchain_memory import BotConversationSummerBufferWindowMemory
|
|
|
|
from langchain import PromptTemplate
|
|
from langchain import LLMChain, ConversationChain
|
|
from langchain.memory import ConversationBufferMemory
|
|
|
|
from langchain.chains.base import Chain
|
|
from typing import Dict, List
|
|
|
|
from langchain.document_loaders import TextLoader
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain.embeddings import SentenceTransformerEmbeddings
|
|
from langchain.vectorstores import Chroma
|
|
|
|
import logging
|
|
|
|
# Module-level logger named after this module (standard logging hierarchy).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RoleplayChain(Chain):
    """Chain wrapper that forwards every non-persona input to ``llm_chain``.

    The keys listed in ``input_keys`` are stripped from the incoming inputs;
    whatever remains is passed as keyword arguments to ``llm_chain.predict``
    and the prediction is returned under ``output_key``.

    NOTE(review): ``input_keys`` includes ``"llm_chain"``, which is a field of
    this chain rather than a runtime input; presumably Chain's input
    validation then requires callers to pass an ``"llm_chain"`` key too --
    confirm. The persona/setting fields below are stored but never read by
    ``_call``; they are not forwarded to the prompt.
    """

    # The wrapped chain that actually produces the text.
    llm_chain: LLMChain

    # Character/setting values held on the chain (currently unused by _call).
    character_name: str
    persona: str
    scenario: str
    ai_name_chat: str
    human_name_chat: str

    output_key: str = "output_text"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Keys this chain declares as inputs; they are removed before forwarding."""
        return [
            "character_name",
            "persona",
            "scenario",
            "ai_name_chat",
            "human_name_chat",
            "llm_chain",
        ]

    @property
    def output_keys(self) -> List[str]:
        """Single output key under which the prediction is returned."""
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Predict with ``llm_chain`` using only the inputs not in ``input_keys``."""
        reserved = self.input_keys
        forwarded = {}
        for key, value in inputs.items():
            if key in reserved:
                continue
            forwarded[key] = value
        return {self.output_key: self.llm_chain.predict(**forwarded)}
|
|
|
|
|
|
|
|
class AI:
    """Generate bot replies through KoboldCpp-backed LangChain chains.

    Per-room conversation memory lives in the in-process ``self.rooms`` dict
    (``room_id -> {"memory": ConversationBufferMemory}``).
    """

    # KoboldCpp generate endpoint used by default for both LLM clients.
    DEFAULT_ENDPOINT_URL = "http://172.16.85.10:5001/api/latest/generate"
    # Source file and fixed query for the debug-only retrieval experiment
    # run by ``generate`` when DEBUG logging is enabled.
    RETRIEVAL_SOURCE_PATH = './germany.txt'
    RETRIEVAL_DEBUG_QUERY = "How is climate in Germany?"

    def __init__(self, bot, text_wrapper, image_wrapper, memory_path: str,
                 endpoint_url: str = DEFAULT_ENDPOINT_URL):
        """Set up the chat and summary LLM clients for ``bot``.

        Args:
            bot: Bot configuration; ``name``, ``temperature``, ``greeting``,
                ``persona`` and ``scenario`` are read from it in this class.
            text_wrapper: Stored on the instance; not used within this class.
            image_wrapper: Stored on the instance; not used within this class.
            memory_path: Base directory; the Chroma store is opened at
                ``{memory_path}/chroma-db``.
            endpoint_url: KoboldCpp generate endpoint for both LLM clients.
        """
        self.name = bot.name
        self.bot = bot
        self.memory_path = memory_path
        self.rooms = {}
        # Opened on first use by _get_vectorstore(); loading the embedding
        # model is expensive, so the store is reused across messages.
        self._vectorstore = None

        # Function-scope import kept from the original; presumably avoids an
        # import cycle or an eager wrapper load -- confirm before hoisting.
        from ..wrappers.langchain_koboldcpp import KoboldCpp

        self.llm_chat = KoboldCpp(temperature=self.bot.temperature, endpoint_url=endpoint_url, stop=['<|endoftext|>'])
        # Low-temperature client for the (currently disabled) summary memory.
        self.llm_summary = KoboldCpp(temperature=0.2, endpoint_url=endpoint_url, stop=['<|endoftext|>'])
        self.text_wrapper = text_wrapper
        self.image_wrapper = image_wrapper

        # Summary-buffer memory is disabled; get_memory() uses a plain
        # per-room ConversationBufferMemory instead:
        # self.memory = BotConversationSummerBufferWindowMemory(llm=self.llm_summary, max_token_limit=1200, min_token_limit=200)

    def get_memory(self, message):
        """Return the conversation memory for ``message.room_id``, creating it on first use.

        A new room's memory is seeded with ``self.bot.greeting`` as the first
        AI message, and is only stored in ``self.rooms`` once fully built.

        NOTE(review): the human prefix is fixed to the ``user_name`` of the
        first message seen in the room, so in multi-user rooms later speakers
        are recorded under that first user's name.
        """
        if message.room_id not in self.rooms:
            memory = ConversationBufferMemory(memory_key="chat_history", human_prefix=message.user_name, ai_prefix=self.bot.name)
            memory.chat_memory.add_ai_message(self.bot.greeting)
            self.rooms[message.room_id] = {"memory": memory}
        else:
            memory = self.rooms[message.room_id]["memory"]

        # Guarded so the buffer dump and token count are only computed when
        # they will actually be logged.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("memory: %s", memory.load_memory_variables({}))
            logger.debug("memory has an estimated %s number of tokens", self.llm_chat.get_num_tokens(memory.buffer))
        return memory

    async def generate(self, message, reply_fn, typing_fn):
        """Return the stripped LLM completion for ``message.message``.

        The prompt is the user's message verbatim; no memory or persona is
        applied. ``reply_fn`` and ``typing_fn`` are accepted for interface
        compatibility and are not used.

        When this module's logger is enabled for DEBUG, a retrieval
        experiment (see ``_debug_retrieval``) runs first and its hits are
        logged; its results do not influence the reply.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # Synchronous and potentially slow: blocks the event loop while
            # it runs, which is acceptable only as a debug aid.
            self._debug_retrieval(self.RETRIEVAL_DEBUG_QUERY)

        chain = LLMChain(
            llm=self.llm_chat,
            prompt=PromptTemplate.from_template("{input}"),
        )
        output = await chain.arun(message.message)
        return output.strip()

    def _get_vectorstore(self):
        """Return the Chroma store at ``{memory_path}/chroma-db``, opening it only once."""
        if self._vectorstore is None:
            embeddings = SentenceTransformerEmbeddings()
            self._vectorstore = Chroma(persist_directory=f'{self.memory_path}/chroma-db', embedding_function=embeddings)
        return self._vectorstore

    def _debug_retrieval(self, query):
        """Debug aid: split the source file and log a similarity search for ``query``.

        Nothing is indexed here -- the store is only queried -- so hits
        presumably come from whatever was previously persisted in the Chroma
        directory.
        """
        documents = TextLoader(self.RETRIEVAL_SOURCE_PATH).load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=600,
            chunk_overlap=100,
            length_function=len,
        )
        docs = text_splitter.split_documents(documents)

        db = self._get_vectorstore()
        # To (re)build the index from these chunks:
        #   db.add_texts(texts=[d.page_content for d in docs],
        #                metadatas=[d.metadata for d in docs], ids=None)
        #   db.persist()
        logger.debug("Split %s into %d chunks (indexing is disabled)", self.RETRIEVAL_SOURCE_PATH, len(docs))

        separator = "-" * 80
        logger.debug("%s\n###", query)
        for doc, score in db.similarity_search_with_score(query):
            logger.debug("%s\nScore: %s\n%s\n%s", separator, score, doc.page_content, separator)

    async def generate_roleplay(self, message, reply_fn, typing_fn):
        """Return the bot's in-character reply to ``message``, stripped of whitespace.

        Fills ``prompt_vicuna`` with the bot's name, persona and scenario and
        the user's name, then runs it through a ``ConversationChain`` bound to
        the room's memory from ``get_memory``. ``reply_fn`` and ``typing_fn``
        are not used.
        """
        memory = self.get_memory(message)

        prompt = prompt_vicuna.partial(
            ai_name=self.bot.name,
            persona=self.bot.persona,
            scenario=self.bot.scenario,
            human_name=message.user_name,
            # example_dialogue is disabled; when re-enabled it was meant to be:
            # example_dialogue=replace_all(self.bot.example_dialogue, {"{{user}}": message.user_name, "{{char}}": self.bot.name})
            ai_name_chat=self.bot.name,
        )

        # RoleplayChain is an unused alternative wrapper around this chain.
        chain = ConversationChain(
            llm=self.llm_chat,
            prompt=prompt,
            verbose=True,
            memory=memory,
        )

        # Stop at end-of-text or when the model starts writing the user's
        # next turn. Passed via the inputs dict; presumably LLMChain reads a
        # "stop" key as stop sequences -- confirm for the installed version.
        stop = ['<|endoftext|>', f"\n{message.user_name}:"]
        logger.debug('Message is: "%s"', message.message)
        output = await chain.arun({"input": message.message, "stop": stop})
        return output.strip()
|
|
|
|
|
|
def replace_all(text, dic):
    """Apply every ``old -> new`` substitution in *dic* to *text*, in order.

    Substitutions run sequentially in the mapping's iteration order, so a
    later pair also rewrites the output of earlier ones (for example
    ``{"a": "b", "b": "c"}`` turns ``"a"`` into ``"c"``).
    """
    result = text
    for needle, substitute in dic.items():
        result = result.replace(needle, substitute)
    return result
|
|
|