@@ -52,6 +52,7 @@ class AI(object):
         self.name = bot.name
         self.bot = bot
         self.memory_path = memory_path
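+        # Per-room state, keyed by room_id; each room carries its own chat memory.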
+        self.rooms = {}
 
         from ..wrappers.langchain_koboldcpp import KoboldCpp
         self.llm_chat = KoboldCpp(temperature=self.bot.temperature, endpoint_url="http://172.16.85.10:5001/api/latest/generate", stop=['<|endoftext|>'])
@@ -60,8 +61,20 @@ class AI(object):
         self.image_wrapper = image_wrapper
 
         #self.memory = BotConversationSummerBufferWindowMemory(llm=self.llm_summary, max_token_limit=1200, min_token_limit=200)
-        self.memory = ConversationBufferMemory(memory_key="chat_history", human_prefix="You", ai_prefix=self.bot.name)
+
+    def get_memory(self, message):
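+        # Return this room's ConversationBufferMemory, creating and seeding it on first use.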
+        if message.room_id not in self.rooms:
+            self.rooms[message.room_id] = {}
+            memory = ConversationBufferMemory(memory_key="chat_history", human_prefix=message.user_name, ai_prefix=self.bot.name)
+            self.rooms[message.room_id]["memory"] = memory
+            memory.chat_memory.add_ai_message(self.bot.greeting)
+            #memory.save_context({"input": None, "output": self.bot.greeting})
+            memory.load_memory_variables({})
+        else:
+            memory = self.rooms[message.room_id]["memory"]
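+        # Debug output: dump the current buffer and a rough token estimate.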
print(f"memory: {memory.load_memory_variables({})}") |
|
|
|
|
|
print(f"memory has an estimated {estimate_num_tokens(memory.buffer)} number of tokens") |
|
|
|
|
|
return memory |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def generate(self, message, reply_fn, typing_fn): |
|
|
async def generate(self, message, reply_fn, typing_fn): |
|
@@ -109,13 +122,14 @@ class AI(object):
 
     async def generate_roleplay(self, message, reply_fn, typing_fn):
-        self.memory.human_prefix = message.user_name
+        memory = self.get_memory(message)
 
         prompt = prompt_vicuna.partial(
             ai_name=self.bot.name,
             persona=self.bot.persona,
             scenario=self.bot.scenario,
             human_name=message.user_name,
+            #example_dialogue=replace_all(self.bot.example_dialogue, {"{{user}}": message.user_name, "{{char}}": self.bot.name})
             ai_name_chat=self.bot.name,
         )
@@ -123,7 +137,7 @@ class AI(object):
             llm=self.llm_chat,
             prompt=prompt,
             verbose=True,
-            memory=self.memory,
+            memory=memory,
             #stop=['<|endoftext|>', '\nYou:', f"\n{message.user_name}:"],
         )
@@ -131,9 +145,17 @@ class AI(object):
 
         #roleplay_chain = RoleplayChain(llm_chain=chain, character_name=self.bot.name, persona=self.bot.persona, scenario=self.bot.scenario, ai_name_chat=self.bot.name, human_name_chat=message.user_name)
 
-        output = chain.run({"input":message.message, "stop": ['<|endoftext|>', f"\n{message.user_name}:"]})
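+        # Halt generation at end-of-text or as soon as the model starts writing the user's next line.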
+        stop = ['<|endoftext|>', f"\n{message.user_name}:"]
+        print(f"Message is: \"{message.message}\"")
+        output = chain.run({"input": message.message, "stop": stop})
 
         return output.strip()
 
 
 def estimate_num_tokens(input_text: str):
     return len(input_text)//4+1
 
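+# Replace each dict key found in text with its value (e.g. the {{user}} / {{char}} placeholders).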
+def replace_all(text, dic):
+    for i, j in dic.items():
+        text = text.replace(i, j)
+    return text