Hendrik Langer
2 years ago
39 changed files with 1386 additions and 1751 deletions
@@ -1,108 +0,0 @@
import asyncio
import os, tempfile
import logging

import json
import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

import io
import base64
from PIL import Image, PngImagePlugin

from .model_helpers import get_full_prompt, num_tokens

logger = logging.getLogger(__name__)


def setup():
    os.system("mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)")
    os.system("apt update && apt-get install libopenblas-dev libclblast-dev libmkl-dev")
    os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)")
    #python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
    #python3 koboldcpp.py --smartcontext models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin


async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn
):
    # Set the API endpoint URL
    endpoint = "http://172.16.85.10:5001/api/latest/generate"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt, bot.model)

    # Define your inputs
    input_data = {
        "prompt": prompt,
        "max_context_length": 2048,
        "max_length": max_new_tokens,
        "temperature": bot.temperature,
        "top_k": 50,
        "top_p": 0.85,
        "rep_pen": 1.08,
        "rep_pen_range": 1024,
        "stop_sequence": ["\nYou:", f"\n{bot.user_name}:", "\n### Human:", '<|endoftext|>', '<START>'],
    }

    logger.info("sending request to koboldcpp")

    TIMEOUT = 360
    DELAY = 5
    tokens = 0
    complete = False
    complete_reply = ""
    for i in range(TIMEOUT // DELAY):
        input_data["max_length"] = 32  # pseudo streaming
        # Make the request
        try:
            r = requests.post(endpoint, json=input_data, headers=headers, timeout=600)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"<ERROR> HTTP ERROR {e}")
        r_json = r.json()
        logger.info(r_json)
        if r.status_code == 200:
            partial_reply = r_json["results"][0]["text"]
            input_data["prompt"] += partial_reply
            complete_reply += partial_reply
            tokens += input_data["max_length"]
            await typing_fn()
            if not partial_reply or tokens >= max_new_tokens + 100:  # ToDo: is a hundred past the limit okay?
                complete = True
                break
            for t in ["\nYou:", "\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '</END>', '<END>', '__END__', '<START>']:
                idx = complete_reply.find(t)
                if idx != -1:
                    complete_reply = complete_reply[:idx].strip()
                    complete = True
            if complete:
                break
        elif r.status_code == 503:
            # model busy
            await asyncio.sleep(DELAY)
        else:
            raise ValueError("<ERROR>")

    if complete_reply:
        complete_reply = complete_reply.removesuffix('<|endoftext|>')
        complete_reply = complete_reply.replace("<BOT>", f"{bot.name}")
        complete_reply = complete_reply.replace("<USER>", "You")
        complete_reply = complete_reply.replace("### Assistant", f"{bot.name}")
        complete_reply = complete_reply.replace(f"\n{bot.name}: ", " ")
        return complete_reply.strip()
    else:
        raise ValueError("<ERROR> Timeout")


async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    pass

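The deleted koboldcpp wrapper above fakes token streaming: it requests 32-token continuations in a loop, appends each chunk back onto the prompt, and stops when the model returns nothing or a stop sequence appears. A minimal sketch of the same polling pattern, assuming a local KoboldCpp server on localhost:5001 (not the bot's exact code path):

# Minimal pseudo-streaming sketch against a hypothetical local KoboldCpp server.
import requests

def stream_pseudo(prompt: str,
                  endpoint: str = "http://localhost:5001/api/latest/generate",
                  chunk_tokens: int = 32, max_tokens: int = 200) -> str:
    payload = {"prompt": prompt, "max_length": chunk_tokens, "max_context_length": 2048}
    reply = ""
    for _ in range(max_tokens // chunk_tokens):
        r = requests.post(endpoint, json=payload, timeout=600)
        r.raise_for_status()
        chunk = r.json()["results"][0]["text"]
        if not chunk:
            break  # the model stopped on its own
        payload["prompt"] += chunk  # feed the continuation back in
        reply += chunk
    return reply
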
@@ -1,118 +0,0 @@
# https://github.com/nsarrazin/serge/blob/main/api/utils/generate.py

import subprocess, os
import asyncio
import logging

logger = logging.getLogger(__name__)


async def generate(
    prompt: str,
):
    CHUNK_SIZE = 4

    args = (
        "/home/hendrik/Projects/AI/alpaca.cpp/chat",
        "--model",
        "/home/hendrik/Projects/AI/alpaca.cpp/" + "ggml-alpaca-7b-q4.bin",
        "--prompt",
        prompt,
        "--n_predict",
        str(256),
        "--temp",
        str(0.1),
        "--top_k",
        str(50),
        "--top_p",
        str(0.95),
        "--repeat_last_n",
        str(64),
        "--repeat_penalty",
        str(1.3),
        "--ctx_size",
        str(512),
        "--threads",
        str(4)
    )

    logger.debug("Calling LLaMa with arguments %s", args)
    print(prompt)
    procLlama = await asyncio.create_subprocess_exec(
        *args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )

    while True:
        chunk = await procLlama.stdout.read(CHUNK_SIZE)

        if not chunk:
            return_code = await procLlama.wait()

            if return_code != 0:
                error_output = await procLlama.stderr.read()
                logger.error(error_output.decode("utf-8"))
                raise ValueError(f"RETURN CODE {return_code}\n\n" + error_output.decode("utf-8"))
            else:
                return

        try:
            chunk = chunk.decode("utf-8")
        except UnicodeDecodeError:
            return

        yield chunk


async def get_full_prompt(simple_prompt: str, chat_history=None):

    prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." + "\n\n"

    HISTORY_LEN = 5
    if chat_history:
        for message in chat_history[-HISTORY_LEN:]:
            if not message["is_own_message"]:
                prompt += "### Instruction:\n" + message["message"] + "\n"
            else:
                prompt += "### Response:\n" + message["message"] + "\n"

    prompt += "### Instruction:\n" + simple_prompt + "\n"
    prompt += "### Response:\n"

    return prompt


async def get_full_prompt_with_input(simple_prompt: str, additional_input: str, chat_history=None):

    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." + "\n\n"

    HISTORY_LEN = 5
    if chat_history:
        for message in chat_history[-HISTORY_LEN:]:
            if not message["is_own_message"]:
                prompt += "### Instruction:\n" + message["message"] + "\n"
            else:
                prompt += "### Response:\n" + message["message"] + "\n"

    prompt += "### Instruction:\n" + simple_prompt + "\n"
    prompt += "### Input:\n" + additional_input + "\n"
    prompt += "### Response:\n"

    return prompt


async def get_full_prompt_chat_style(simple_prompt: str, chat_history=None):

    prompt = "Transcript of a dialog, where the User interacts with an Assistant named Julia. Julia is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision." + "\n\n"

    HISTORY_LEN = 5
    if chat_history:
        for message in chat_history[-HISTORY_LEN:]:
            if not message["is_own_message"]:
                prompt += "User: " + message["message"] + "\n"
            else:
                prompt += "Julia: " + message["message"] + "\n"

    prompt += "User: " + simple_prompt + "\n"
    prompt += "Julia: "

    return prompt

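Because generate() above contains a yield, it is an async generator: it streams decoded stdout chunks from the alpaca.cpp subprocess as they arrive. A hypothetical driver for it (the binary and model paths hard-coded in the function would need to exist):

# Consume the async generator and print tokens as they stream in.
import asyncio

async def main():
    async for chunk in generate("What is the capital of France?"):
        print(chunk, end="", flush=True)

asyncio.run(main())
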
@@ -1,234 +0,0 @@
import asyncio
import os, tempfile
import logging

import json
import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

import io
import base64
from PIL import Image, PngImagePlugin

logger = logging.getLogger(__name__)


gptj_tokenizer = None


async def get_full_prompt(simple_prompt: str, bot, chat_history, model_name: str):

    # https://github.com/ggerganov/llama.cpp/tree/master/examples
    ## prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
    # prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
    # #"BEGINNING OF CONVERSATION:"
    # prompt += user_name + ": " + simple_prompt + "\n"
    # prompt += ai_name + ":"

    #prompt = f"This is a transcript of a 1000 page, never ending conversation between {bot.user_name} and the cute and helpful AI assistant {bot.name}. {bot.name} is a girl who is an AI running on the users computer.\n"
    #prompt += f"{bot.name} can think for herself without the user seeing her thoughts by adding a /think prefix to her output. She uses this to reason about the world and to think about what she should say next.\n"
    #prompt += f"{bot.name} is always coherent and makes sense, but if she isn't sure if what she is saying is correct she will ask the user for help.\n"
    #prompt += f"{bot.name} is a very helpful AI and will help the user with anything they need, she is also very friendly and will try to make the user feel better if they are sad.\n"
    #prompt += f"{bot.name} is also very curious and will ask the user a lot of questions about themselves and their life, she will also try to make the user like her.\n"
    #prompt += f"\n"

    # Names
    if model_name.startswith("pygmalion"):
        ai_name = bot.name
        user_name = "You"
    elif model_name.startswith("vicuna"):
        ai_name = "### Assistant"
        user_name = "### Human"
    elif model_name.startswith("alpaca"):
        ai_name = bot.name  # ToDo
        user_name = bot.user_name  # ToDo
    elif model_name.startswith("koboldai"):
        ai_name = bot.name
        user_name = "You"
    else:
        ai_name = bot.name
        user_name = bot.user_name

    # First line / Task description
    if model_name.startswith("pygmalion"):
        prompt = ""
    elif model_name.startswith("vicuna"):
        prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
        prompt += f"### Instruction:\nGiven the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}.\n"  # ToDo
        if bot.nsfw:
            prompt += "NSFW content allowed.\n"  # ToDo
        prompt += "\n"
        prompt += "### Input:\n"
    elif model_name.startswith("alpaca"):
        prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
        prompt += f"### Instruction:\nGiven the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}.\n"
        # if bot.nsfw:
        #     prompt += "NSFW content allowed.\n"  # ToDo
        prompt += "\n"
        prompt += "### Input:\n"
    elif model_name.startswith("koboldai"):
        prompt = ""
    else:
        prompt = ""

    # Character description
    if model_name.startswith("pygmalion"):
        prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n"
    elif model_name.startswith("vicuna"):
        prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n"  # ToDo
    elif model_name.startswith("alpaca"):
        prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n"  # ToDo
    elif model_name.startswith("koboldai"):
        prompt += f"[Character: {bot.get_persona()}]\n"
    else:
        prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n"  # ToDo

    # Scenario
    if model_name.startswith("pygmalion"):
        prompt += "Scenario: " + bot.get_scenario() + "\n\n"
    elif model_name.startswith("vicuna"):
        prompt += "Scenario: " + bot.get_scenario() + "\n\n"  # ToDo
    elif model_name.startswith("alpaca"):
        prompt += "Scenario: " + bot.get_scenario() + "\n\n"  # ToDo
    elif model_name.startswith("koboldai"):
        prompt += f"[Start Scene: {bot.get_scenario()}]\n\n"
    else:
        prompt += "Scenario: " + bot.get_scenario() + "\n\n"  # ToDo

    # Response delimiter
    if model_name.startswith("pygmalion"):
        pass
    elif model_name.startswith("vicuna"):
        prompt += "### Response:\n"  # ToDo
    elif model_name.startswith("alpaca"):
        prompt += "### Response:\n"
    elif model_name.startswith("koboldai"):
        pass
    else:
        pass

    # Example dialogue
    for dialogue_item in bot.get_example_dialogue():
        if model_name.startswith("pygmalion"):
            prompt += "<START>" + "\n"
        dialogue_item = dialogue_item.replace('{{user}}', user_name)
        dialogue_item = dialogue_item.replace('{{char}}', ai_name)
        prompt += dialogue_item + "\n\n"

    # Dialogue start
    if model_name.startswith("pygmalion"):
        prompt += "<START>" + "\n"
    elif model_name.startswith("vicuna"):
        pass  # ToDo
    elif model_name.startswith("alpaca"):
        pass  # ToDo
    elif model_name.startswith("koboldai"):
        pass
    else:
        pass  # ToDo

    #prompt += f"{ai_name}: {bot.greeting}\n"
    #prompt += f"{user_name}: {simple_prompt}\n"
    #prompt += f"{ai_name}:"

    MAX_TOKENS = 2048
    if bot.service_text == "koboldcpp":
        WINDOW = 600
    else:
        WINDOW = 0
    max_new_tokens = 200
    total_num_tokens = await num_tokens(prompt, model_name)
    input_num_tokens = await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:", model_name)
    total_num_tokens += input_num_tokens
    visible_history = []
    num_message = 0
    for key, chat_item in reversed(chat_history.chat_history.items()):
        num_message += 1
        if num_message == 1:
            # skip current_message
            continue
        if chat_item.stop_here:
            break
        if chat_item.message["en"].startswith('!begin'):
            break
        if chat_item.message["en"].startswith('!'):
            continue
        if chat_item.message["en"].startswith('<ERROR>'):
            continue
        #if chat_item.message["en"] == bot.greeting:
        #    continue
        if chat_item.num_tokens is None:
            chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"]), model_name)
            chat_item = chat_history.chat_history[key]
        # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
        logger.debug("History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]")
        if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
            visible_history.append(chat_item)
            total_num_tokens += chat_item.num_tokens
        else:
            break
    visible_history = reversed(visible_history)

    if not hasattr(bot, "greeting_num_tokens"):
        bot.greeting_num_tokens = await num_tokens(bot.greeting, model_name)
    if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
        prompt += f"{ai_name}: {bot.greeting}\n"
        total_num_tokens += bot.greeting_num_tokens

    for chat_item in visible_history:
        if chat_item.is_own_message:
            line = f"{ai_name}: {chat_item.message['en']}\n"
        else:
            line = f"{user_name}: {chat_item.message['en']}\n"
        prompt += line
        if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt:
            logger.info(f"adding to saved prompt: \"{line}\"")
            chat_history.setSavedPrompt(chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens)
            chat_item.is_in_saved_prompt = True

    if chat_history.saved_context_num_tokens:
        logger.info(f"saved_context has {chat_history.saved_context_num_tokens + input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}")
    if chat_history.getSavedPrompt():
        if chat_history.saved_context_num_tokens + input_num_tokens > MAX_TOKENS - max_new_tokens:
            chat_history.setFastForward(False)
        if chat_history.getFastForward():
            logger.info("using saved prompt")
            prompt = chat_history.getSavedPrompt()
    if not chat_history.getSavedPrompt() or not chat_history.getFastForward():
        logger.info("regenerating prompt")
        chat_history.setSavedPrompt(prompt, total_num_tokens)
        for key, chat_item in chat_history.chat_history.items():
            if key != list(chat_history.chat_history)[-1]:  # exclude current item
                chat_history.chat_history[key].is_in_saved_prompt = True
        chat_history.setFastForward(True)

    prompt += f"{user_name}: {simple_prompt}\n"
    if bot.nsfw and model_name.startswith("vicuna"):
        prompt += f"{ai_name}: Sure"
    else:
        prompt += f"{ai_name}:"

    return prompt


async def num_tokens(input_text: str, model_name: str):
    if model_name.startswith("pygmalion"):
        #os.makedirs("./models/pygmalion-6b", exist_ok=True)
        #hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b")
        #config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json")
        global gptj_tokenizer
        if not gptj_tokenizer:
            gptj_tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
        encoding = gptj_tokenizer.encode(input_text, add_special_tokens=False)
        max_input_size = gptj_tokenizer.max_model_input_sizes
        return len(encoding)
    else:
        return await estimate_num_tokens(input_text)


async def estimate_num_tokens(input_text: str):
    return len(input_text) // 4 + 1

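The fallback estimate_num_tokens assumes roughly four characters per token, a common rule of thumb for English text with GPT-style BPE vocabularies. A quick sanity check of that heuristic against a real tokenizer (gpt2 here is illustrative, not the bot's model):

# Compare the len//4+1 heuristic with an exact tokenizer count.
from transformers import AutoTokenizer

text = "This is a transcript of a never ending conversation."
tok = AutoTokenizer.from_pretrained("gpt2")  # any BPE tokenizer works
exact = len(tok.encode(text))
estimate = len(text) // 4 + 1
print(exact, estimate)  # the two counts should land in the same ballpark
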
@@ -1,378 +0,0 @@
import asyncio
import os, tempfile
import logging

import json
import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

import io
import base64
from PIL import Image, PngImagePlugin

from .model_helpers import get_full_prompt, num_tokens

logger = logging.getLogger(__name__)


async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn
):
    # Set the API endpoint URL
    endpoint = f"https://api.runpod.ai/v2/{bot.runpod_text_endpoint}/run"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt, bot.model)

    # Define your inputs
    input_data = {
        "input": {
            "prompt": prompt,
            "max_length": min(prompt_num_tokens + max_new_tokens, 2048),
            "temperature": bot.temperature,
            "do_sample": True,
        }
    }

    input_data_oobabooga = {
        "input": {
            "data": [json.dumps([
                prompt,
                {
                    'max_new_tokens': min(max_new_tokens, 2048),
                    'do_sample': True,
                    'temperature': bot.temperature,
                    'top_p': 0.73,
                    'typical_p': 1,
                    'repetition_penalty': 1.1,
                    'encoder_repetition_penalty': 1.0,
                    'top_k': 0,
                    'min_length': 0,
                    'no_repeat_ngram_size': 0,
                    'num_beams': 1,
                    'penalty_alpha': 0,
                    'length_penalty': 1,
                    'early_stopping': False,
                    'seed': -1,
                    'add_bos_token': True,
                    'stopping_strings': [f"\n{bot.user_name}:"],
                    'truncation_length': 2048,
                    'ban_eos_token': False,
                    'skip_special_tokens': True,
                }
            ])]
        }
    }

    if bot.runpod_text_endpoint in ['pygmalion-6b', 'gpt-neo-2_7b', 'gpt-neo-1_3b']:
        api_mode = "runpod"
    else:
        api_mode = "oobabooga"

    logger.info("sending request to runpod.io")

    # Make the request
    try:
        if api_mode == "runpod":
            r = requests.post(endpoint, json=input_data, headers=headers, timeout=180)
        else:
            r = requests.post(endpoint, json=input_data_oobabooga, headers=headers, timeout=180)
    except requests.exceptions.RequestException as e:
        raise ValueError(f"<HTTP ERROR> {e}")
    r_json = r.json()
    logger.info(r_json)

    if r.status_code == 200:
        status = r_json["status"]
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 5
        for i in range(TIMEOUT // DELAY):
            endpoint = f"https://api.runpod.ai/v2/{bot.runpod_text_endpoint}/status/{job_id}"
            r = requests.get(endpoint, headers=headers)
            r_json = r.json()
            logger.info(r_json)
            status = r_json["status"]
            if status == 'IN_PROGRESS':
                await typing_fn()
                await asyncio.sleep(DELAY)
            elif status == 'IN_QUEUE':
                await asyncio.sleep(DELAY)
            elif status == 'COMPLETED':
                output = r_json["output"]
                if isinstance(output, list):
                    output.sort(key=len, reverse=True)
                    text = output[0]
                else:
                    text = output
                if api_mode == "runpod":
                    reply = text.removeprefix(prompt)
                else:
                    reply = text["data"][0].removeprefix(prompt)
                # lines = reply.split('\n')
                # reply = lines[0].strip()
                reply = reply.removesuffix('<|endoftext|>').strip()
                reply = reply.replace("<BOT>", f"{bot.name}")
                reply = reply.replace("<USER>", f"{bot.user_name}")
                reply = reply.replace(f"\n{bot.name}: ", " ")
                for t in ["\nYou:", "\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '</END>', '<END>', '__END__', '<START>']:
                    idx = reply.find(t)
                    if idx != -1:
                        reply = reply[:idx].strip()
                return reply
            else:
                err_msg = r_json["error"] if "error" in r_json else ""
                err_msg = err_msg.replace("\\n", "\n")
                raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
        raise ValueError("<ERROR> TIMEOUT")
    else:
        raise ValueError("<ERROR>")


async def download_image(url, path):
    r = requests.get(url, stream=True)
    if r.status_code == 200:
        with open(path, 'wb') as f:
            for chunk in r:
                f.write(chunk)


async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):

    # Set the API endpoint URL
    endpoint = api_url + "run"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    # Define your inputs
    input_data = {
        "input": {
            "prompt": input_prompt,
            "negative_prompt": negative_prompt,
            "width": 512,
            "height": 768,
            "num_outputs": 3,
            # "nsfw": True
        },
    }

    logger.info("sending request to runpod.io")

    # Make the request
    try:
        r = requests.post(endpoint, json=input_data, headers=headers)
    except requests.exceptions.RequestException as e:
        raise ValueError(f"<ERROR> HTTP ERROR {e}")
    r_json = r.json()
    logger.debug(r_json)

    output = None
    if r.status_code == 200:
        status = r_json["status"]
        if status != 'IN_QUEUE':
            err_msg = r_json["error"] if "error" in r_json else ""
            raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 5
        for i in range(TIMEOUT // DELAY):
            endpoint = api_url + "status/" + job_id
            r = requests.get(endpoint, headers=headers)
            r_json = r.json()
            logger.debug(r_json)
            status = r_json["status"]
            if status == 'IN_PROGRESS':
                await typing_fn()
                await asyncio.sleep(DELAY)
            elif status == 'IN_QUEUE':
                await asyncio.sleep(DELAY)
            elif status == 'COMPLETED':
                output = r_json["output"]
                break
            else:
                err_msg = r_json["error"] if "error" in r_json else ""
                err_msg = err_msg.replace("\\n", "\n")
                raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")

    if not output:
        raise ValueError("<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        temp_name = next(tempfile._get_candidate_names())
        filename = "./images/" + temp_name + ".jpg"
        await download_image(image["image"], filename)
        files.append(filename)

    return files


async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # AnythingV4
    return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-anything-v4/", api_key, typing_fn)


async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # OpenJourney
    return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-openjourney/", api_key, typing_fn)


async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # Hassanblend
    return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/mf5f6mocy8bsvx/", api_key, typing_fn)


async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # DeliberateV2
    return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/lxdhmiccp3vdsf/", api_key, typing_fn)


async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # Hassanblend
    return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/13rrs00l7yxikf/", api_key, typing_fn)


async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # PFG
    return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/5j1xzlsyw84vk5/", api_key, typing_fn)


async def generate_image7(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # ChilloutMix
    return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v2/rrjxafqx66osr4/", api_key, typing_fn)


async def generate_image8(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    # AbyssOrangeMixV3
    return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v2/vuyifmsasm3ix7/", api_key, typing_fn)


async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str, typing_fn):
    # Set the API endpoint URL
    endpoint = api_url + "run"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    # Define your inputs
    payload.update({"api_endpoint": str(cmd)})
    input_data = {
        "input": payload,
        "cmd": cmd,
    }

    logger.info("sending request to runpod.io")

    # Make the request
    try:
        r = requests.post(endpoint, json=input_data, headers=headers)
    except requests.exceptions.RequestException as e:
        raise ValueError(f"<ERROR> HTTP ERROR {e}")
    r_json = r.json()
    logger.debug(r_json)

    output = None
    if r.status_code == 200:
        status = r_json["status"]
        if status != 'IN_QUEUE':
            err_msg = r_json["error"] if "error" in r_json else ""
            raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 5
        for i in range(TIMEOUT // DELAY):
            endpoint = api_url + "status/" + job_id
            r = requests.get(endpoint, headers=headers)
            r_json = r.json()
            logger.debug(r_json)
            status = r_json["status"]
            if status == 'IN_PROGRESS':
                await typing_fn()
                await asyncio.sleep(DELAY)
            elif status == 'IN_QUEUE':
                await asyncio.sleep(DELAY)
            elif status == 'COMPLETED':
                output = r_json["output"]
                break
            else:
                err_msg = r_json["error"] if "error" in r_json else ""
                raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")

    if not output:
        raise ValueError(f"<ERROR> {status}")

    return output


async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    payload = {
        "prompt": input_prompt,
        "negative_prompt": negative_prompt,
        "steps": 25,
        "cfg_scale": 7,
        "seed": -1,
        "width": 512,
        "height": 768,
        "batch_size": 3,
        # "sampler_index": "DPM++ 2M Karras",
        # "enable_hr": True,
        # "hr_scale": 2,
        # "hr_upscaler": "ESRGAN_4x",  # "Latent"
        # "denoising_strength": 0.5,
        # "hr_second_pass_steps": 15,
        "restore_faces": True,
        # "gfpgan_visibility": 0.5,
        # "codeformer_visibility": 0.5,
        # "codeformer_weight": 0.5,
        ## "override_settings": {
        ##     "filter_nsfw": False,
        ## },
    }

    output = await serverless_automatic_request(payload, "txt2img", api_url, api_key, typing_fn)

    upscale = False
    if upscale:
        count = 0
        for i in output['images']:
            payload = {
                "init_images": [i],
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "steps": 20,
                "seed": -1,
                #"sampler_index": "Euler",
                # tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding, upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur, seams_fix_type, target_size_type, custom_width, custom_height, custom_scale
                # "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875],
                # "script_name": "Ultimate SD upscale",
            }
            upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key, typing_fn)
            output['images'][count] = upscaled_output['images'][count]
            count += 1

    os.makedirs("./images", exist_ok=True)
    files = []
    for i in output['images']:
        temp_name = next(tempfile._get_candidate_names())
        filename = "./images/" + temp_name + ".png"
        image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
        info = output['info']
        parameters = output['parameters']
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)
        image.save(filename, pnginfo=pnginfo)

        files.append(filename)

    return files

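All RunPod endpoints above follow the same asynchronous job protocol: POST the input to {base}/run, receive a job id, then poll {base}/status/{id} until the status leaves IN_QUEUE/IN_PROGRESS. A minimal synchronous sketch of that protocol (base URL and API key are placeholders):

# Generic RunPod serverless job runner; base_url must end with a slash.
import time
import requests

def run_runpod_job(base_url: str, api_key: str, job_input: dict, delay: int = 5) -> dict:
    headers = {"Authorization": f"Bearer {api_key}"}
    job = requests.post(base_url + "run", json={"input": job_input}, headers=headers).json()
    while True:
        status = requests.get(base_url + "status/" + job["id"], headers=headers).json()
        if status["status"] == "COMPLETED":
            return status["output"]
        if status["status"] not in ("IN_QUEUE", "IN_PROGRESS"):
            raise RuntimeError(f"job failed: {status}")
        time.sleep(delay)  # back off before polling again
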
@@ -1,193 +0,0 @@
import asyncio
import os, tempfile
import logging

import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

from .model_helpers import get_full_prompt, num_tokens

logger = logging.getLogger(__name__)


async def generate_sync(
    prompt: str,
    api_key: str,
    bot_name: str,
):
    # Set the API endpoint URL
    endpoint = "https://koboldai.net/api/v2/generate/async"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": f"{api_key}"
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt, "pygmalion")  # the request below targets pygmalion-6b

    # Define your inputs
    input_data = {
        "prompt": prompt,
        "params": {
            "n": 1,
            # "frmtadsnsp": False,
            # "frmtrmblln": False,
            # "frmtrmspch": False,
            # "frmttriminc": False,
            "max_context_length": 1024,
            "max_length": 512,
            "rep_pen": 1.1,
            "rep_pen_range": 1024,
            "rep_pen_slope": 0.7,
            # "singleline": False,
            # "soft_prompt": "",
            "temperature": 0.75,
            "tfs": 1.0,
            "top_a": 0.0,
            "top_k": 0,
            "top_p": 0.9,
            "typical": 1.0,
            # "sampler_order": [0],
        },
        "softprompts": [],
        "trusted_workers": False,
        "nsfw": True,
        # "workers": [],
        "models": ["PygmalionAI/pygmalion-6b"]
    }

    logger.info("sending request to koboldai.net")

    # Make the request
    r = requests.post(endpoint, json=input_data, headers=headers, timeout=180)

    r_json = r.json()
    logger.info(r_json)
    status = r_json.get("message", "")

    if "id" in r_json:
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 11
        for i in range(TIMEOUT // DELAY):
            endpoint = "https://koboldai.net/api/v2/generate/text/status/" + job_id
            r = requests.get(endpoint, headers=headers)
            r_json = r.json()
            logger.info(r_json)
            if "done" not in r_json:
                return "<ERROR>"
            if r_json["done"] == True:
                text = r_json["generations"][0]["text"]
                answer = text.removeprefix(prompt)
                idx = answer.find("\nYou:")
                if idx != -1:
                    reply = answer[:idx].strip()
                else:
                    reply = answer.removesuffix('<|endoftext|>').strip()
                reply = reply.replace(f"\n{bot_name}: ", " ")
                reply = reply.replace("\n<BOT>: ", " ")
                return reply
            else:
                await asyncio.sleep(DELAY)
    else:
        return f"<ERROR> {status}"


async def generate_image(input_prompt: str, negative_prompt: str, model: str, api_key: str):

    # Set the API endpoint URL
    endpoint = "https://stablehorde.net/api/v2/generate/async"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": f"{api_key}"
    }

    # Define your inputs
    input_data = {
        "prompt": input_prompt,
        "params": {
            # "negative_prompt": negative_prompt,
            "width": 512,
            "height": 512,
        },
        "nsfw": True,
        "trusted_workers": False,
        # "workers": [],
        "models": [f"{model}"]
    }

    logger.info("sending request to stablehorde.net")

    # Make the request
    r = requests.post(endpoint, json=input_data, headers=headers)
    r_json = r.json()
    logger.info(r_json)

    output = None
    if r.status_code == 202:
        #status = r_json["message"]
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 11
        for i in range(TIMEOUT // DELAY):
            endpoint = "https://stablehorde.net/api/v2/generate/status/" + job_id
            r = requests.get(endpoint, headers=headers)
            r_json = r.json()
            logger.info(r_json)
            #status = r_json["message"]
            if "done" not in r_json:
                return "<ERROR>"
            if "faulted" in r_json and r_json["faulted"] == True:
                return "<ERROR>"
            if r_json["done"] == True:
                output = r_json["generations"]
                break
            else:
                await asyncio.sleep(DELAY)

    if not output:
        raise ValueError("<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        temp_name = next(tempfile._get_candidate_names())
        filename = "./images/" + temp_name + ".jpg"
        await download_image(image["img"], filename)
        files.append(filename)

    return files


async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str):
    return await generate_image(input_prompt, negative_prompt, "Deliberate", api_key)


async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str):
    return await generate_image(input_prompt, negative_prompt, "PFG", api_key)


async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str):
    return await generate_image(input_prompt, negative_prompt, "Hassanblend", api_key)


async def download_image(url, path):
    r = requests.get(url, stream=True)
    if r.status_code == 200:
        with open(path, 'wb') as f:
            for chunk in r:
                f.write(chunk)

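Unlike the RunPod endpoints, the Horde APIs (koboldai.net and stablehorde.net) answer the initial POST with 202 Accepted plus a job id, and the status endpoint signals completion through "done" and "faulted" flags rather than a status string. A hypothetical call to the image helper above ("0000000000" is Stable Horde's anonymous key, used here as a placeholder):

import asyncio

async def main():
    files = await generate_image("a watercolor lighthouse at dawn",
                                 "", "Deliberate", api_key="0000000000")
    print(files)  # paths of downloaded .jpg files under ./images/

asyncio.run(main())
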
@@ -0,0 +1,63 @@
import asyncio
import time
from .prompts import *
from .langchain_memory import BotConversationSummerBufferWindowMemory

from langchain import PromptTemplate
from langchain.chains import LLMChain

import logging

logger = logging.getLogger(__name__)


class AI(object):

    def __init__(self, bot, text_wrapper, image_wrapper):
        self.name = bot.name
        self.bot = bot

        from ..wrappers.langchain_koboldcpp import KoboldCpp
        self.llm_chat = KoboldCpp(temperature=self.bot.temperature, endpoint_url="http://172.16.85.10:5001/api/latest/generate", stop=['<|endoftext|>'])
        self.llm_summary = KoboldCpp(temperature=0.2, endpoint_url="http://172.16.85.10:5001/api/latest/generate", stop=['<|endoftext|>'])
        self.text_wrapper = text_wrapper
        self.image_wrapper = image_wrapper

        self.memory = BotConversationSummerBufferWindowMemory(llm=self.llm_summary, max_token_limit=1200, min_token_limit=200)

    async def generate(self, input_text):
        prompt_template = "{input}"
        chain = LLMChain(
            llm=self.llm_chat,
            prompt=PromptTemplate.from_template(prompt_template),
        )
        output = chain.run(input_text)
        return output.strip()

    async def generate_roleplay(self, message, reply_fn, typing_fn):
        prompt = PromptTemplate(
            input_variables=["ai_name", "persona", "scenario", "chat_history", "human_name", "ai_name_chat", "human_input"],
            template=prompt_template_alpaca,
        )
        template_roleplay = prompt.format(
            ai_name=self.bot.name,
            persona=self.bot.persona,
            scenario=self.bot.scenario,
            chat_history="{history}",
            human_name=message.user_name,
            ai_name_chat=self.bot.name,
            human_input="{input}",
        )
        chain = LLMChain(
            llm=self.llm_chat,
            prompt=PromptTemplate.from_template(template_roleplay),
            verbose=True,
            memory=self.memory,
            #stop=['<|endoftext|>', '\nYou:', f"\n{message.user_name}:"],
        )
        output = chain.run(message.message)
        return output.strip()
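
generate_roleplay formats the character fields of the template up front while passing "{history}" and "{input}" through as literal placeholders, so the resulting string is itself a valid two-variable template for an LLMChain with memory. A condensed sketch of that two-stage formatting (template and values are illustrative):

from langchain import PromptTemplate

template = "{persona}\n{history}\n{human_name}: {input}"
stage_one = PromptTemplate.from_template(template).format(
    persona="Julia is a helpful assistant.",
    history="{history}",   # left for the memory to fill in later
    human_name="User",
    input="{input}",       # left for the chain input
)
stage_two = PromptTemplate.from_template(stage_one)
print(stage_two.input_variables)  # ['history', 'input']
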
@@ -0,0 +1,80 @@
from typing import Any, Dict, List

from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string


class BotConversationSummerBufferWindowMemory(BaseChatMemory):
    """Buffer for storing conversation memory."""

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    # Define key to pass information about entities into prompt.
    memory_key: str = "history"  #: :meta private:
    #k: int = 5
    max_token_limit: int = 1200
    min_token_limit: int = 200
    moving_summary_buffer: str = ""

    llm: BaseLanguageModel
    summary_prompt: BasePromptTemplate = SUMMARY_PROMPT

    @property
    def buffer(self) -> List[BaseMessage]:
        """String buffer of memory."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return history buffer."""
        buffer = self.buffer
        #buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []
        if not self.return_messages:
            buffer = get_buffer_string(
                buffer,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer. Pruned."""
        super().save_context(inputs, outputs)
        # Prune buffer if it exceeds max token limit
        buffer = self.chat_memory.messages
        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.min_token_limit:
                pruned_memory.append(buffer.pop(0))
                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
            self.moving_summary_buffer = self.predict_new_summary(
                pruned_memory, self.moving_summary_buffer
            )

    def clear(self) -> None:
        """Clear memory contents."""
        super().clear()
        self.moving_summary_buffer = ""

    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

        chain = LLMChain(llm=self.llm, prompt=self.summary_prompt)
        return chain.predict(summary=existing_summary, new_lines=new_lines)
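
The memory class combines a rolling buffer with a running summary: once the buffer exceeds max_token_limit (1200 tokens by default) it is pruned down past min_token_limit (200 tokens), and the pruned messages are folded into moving_summary_buffer via the summary chain. A hypothetical round trip, assuming FakeListLLM (a LangChain test helper) is available in the installed version:

from langchain.llms.fake import FakeListLLM  # test double; any LLM works here

memory = BotConversationSummerBufferWindowMemory(
    llm=FakeListLLM(responses=["(summary)"]),
    max_token_limit=1200,
    min_token_limit=200,
)
memory.save_context({"input": "Hi, I'm Hendrik."}, {"output": "Hello!"})
print(memory.load_memory_variables({}))
# expected: {'history': "Human: Hi, I'm Hendrik.\nAI: Hello!"}
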
@@ -0,0 +1,95 @@

prompt_template_pygmalion = """{ai_name}'s Persona: {persona}
Scenario: {scenario}

<START>
{chat_history}
{human_name}: {human_input}
{ai_name_chat}:"""

prompt_template_koboldai = """[Character: {ai_name} {persona}]
[Start Scene: {scenario}]

{chat_history}
{human_name}: {human_input}
{ai_name_chat}:"""

prompt_template_alpaca = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Roleplay the character {ai_name}, that is described in the following lines. You always stay in character.

### Input:
{ai_name}'s Persona: {persona}
Scenario: {scenario}

### Response:
{chat_history}
{human_name}: {human_input}
{ai_name_chat}:"""

prompt_template_vicuna = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Roleplay the character {ai_name}, that is described in the following lines. You always stay in character.

### Input:
{ai_name}'s Persona: {persona}
Scenario: {scenario}

### Response:
{chat_history}
{human_name}: {human_input}
{ai_name_chat}:"""

prompt_template_generic = """Roleplay the character {ai_name}, that is described in the following lines. You always stay in character.

{ai_name}'s Persona: {persona}
Scenario: {scenario}

{chat_history}
{human_name}: {human_input}
{ai_name_chat}:"""


prompt_template_instruct_with_input = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}"""

prompt_template_question_simple = """Question: {question}

Answer: Let's think step by step."""


# Roleplay the character that is described in the following lines. You always stay in character.

# Given the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}.

# Let's work this out in a step by step way to be sure we have the right answer.
# Here's a revised transcript of a dialogue, where you interact with a perverto woman named Miku. Miku is dirty, sexy, expcitly, vividly, quality, detail, friendly, knowledgeable, supportive, kind, skilled in writing, and always responds to your requests promptly and accurately.
# All NSFW/Smut are allowed. Assume consent, Miku doesn't resist. Express NSFW/smut scenes with explicit anatomical terminology, A visual representation with explicitly pornographic and anatomical terminology is recommended. Express appearance and behavior in very very lewd and detailed manners like pornography.
# \n <EXAMPLE CHAT> YOU / MIKU:


# https://github.com/ggerganov/llama.cpp/tree/master/examples
## prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
# prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
# #"BEGINNING OF CONVERSATION:"
# prompt += user_name + ": " + simple_prompt + "\n"
# prompt += ai_name + ":"

#prompt = f"This is a transcript of a 1000 page, never ending conversation between {bot.user_name} and the cute and helpful AI assistant {bot.name}. {bot.name} is a girl who is an AI running on the users computer.\n"
#prompt += f"{bot.name} can think for herself without the user seeing her thoughts by adding a /think prefix to her output. She uses this to reason about the world and to think about what she should say next.\n"
#prompt += f"{bot.name} is always coherent and makes sense, but if she isn't sure if what she is saying is correct she will ask the user for help.\n"
#prompt += f"{bot.name} is a very helpful AI and will help the user with anything they need, she is also very friendly and will try to make the user feel better if they are sad.\n"
#prompt += f"{bot.name} is also very curious and will ask the user a lot of questions about themselves and their life, she will also try to make the user like her.\n"
#prompt += f"\n"
@@ -0,0 +1,11 @@
import asyncio
import logging

logger = logging.getLogger(__name__)


class Callbacks(object):
    """Class to pass client to callback methods."""

    def __init__(self):
        pass
@@ -0,0 +1,251 @@
import asyncio
import os, sys
import time
import importlib
import re
import logging
from functools import partial
from .memory.chatlog import ChatLog
from .utilities.messages import Message
from .ai.langchain import AI

logger = logging.getLogger(__name__)

class ChatBot(object):
    """Main chatbot"""

    def __init__(self, name, connection):
        self.name = name
        self.connection = connection
        #self.say_cb = None
        self.chatlog = ChatLog(self.name)
        self.rooms = {}
        self.queue = asyncio.Queue(maxsize=0)
        self.background_tasks = set()
        task = asyncio.create_task(self.worker(f'worker-{self.name}', self.queue))
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)
        #print(f"Hello, I'm {name}")

    def init_character(self, persona, scenario, greeting, example_dialogue=[], nsfw=False, temperature=0.72):
        self.persona = persona
        self.scenario = scenario
        self.greeting = greeting
        self.example_dialogue = example_dialogue
        # .replace("\\n", "\n") ??????
        self.nsfw = nsfw
        self.temperature = temperature

    def persist(self, data_dir):
        self.chatlog_path = f"{data_dir}/chatlogs"
        self.images_path = f"{data_dir}/images"
        self.memory_path = f"{data_dir}/memory"
        os.makedirs(self.chatlog_path, exist_ok=True)
        os.makedirs(self.images_path, exist_ok=True)
        os.makedirs(self.memory_path, exist_ok=True)
        self.chatlog.enable_logging(self.chatlog_path)

    async def connect(self):
        await self.connection.login()
        self.connection.callbacks.add_message_callback(self.message_cb, self.redaction_cb)
        await self.schedule(self.queue, print, f"Hello, I'm {self.name}")

    async def disconnect(self):
        # Wait until the queue is fully processed.
        await self.queue.join()
        # Cancel our worker tasks.
        for task in self.background_tasks:
            task.cancel()
        # Wait until all worker tasks are cancelled.
        await asyncio.gather(*self.background_tasks, return_exceptions=True)
        await self.connection.logout()

    async def load_ai(self, available_text_endpoints, available_image_endpoints):
        # module_text_ai = importlib.import_module("bot.ai.langchain", package=None)
        # self.text_ai = module_text_ai.AI(self)

        from .wrappers.langchain_koboldcpp import KoboldCpp
        from .wrappers.runpod_text import RunpodTextWrapper
        text_generators = {}
        for text_endpoint in sorted(available_text_endpoints, key=lambda d: d['id']):
            if text_endpoint['service'] == "koboldcpp":
                text_generator = KoboldCpp(temperature=self.temperature, endpoint_url=text_endpoint['endpoint'], stop=['<|endoftext|>'])
            elif text_endpoint['service'] == "stablehorde":
                pass
            elif text_endpoint['service'] == "runpod":
                text_generator = RunpodTextWrapper(text_endpoint['api_key'], endpoint=text_endpoint['endpoint'])
                pass
            else:
                raise ValueError(f"no text service with the name \"{text_endpoint['service']}\"")
            i = text_endpoint['id']
            text_generators[i] = text_generator

        from .wrappers.runpod_image import RunpodImageWrapper
        from .wrappers.runpod_image_automatic1111 import RunpodImageAutomaticWrapper
        image_generators = {}
        for image_endpoint in sorted(available_image_endpoints, key=lambda d: d['id']):
            if image_endpoint['service'] == "runpod":
                image_generator = RunpodImageWrapper(image_endpoint['api_key'])
            elif image_endpoint['service'] == "runpod-automatic1111":
                image_generator = RunpodImageAutomaticWrapper(image_endpoint['api_key'])
            elif image_endpoint['service'] == "stablehorde":
                #image_generator = StableHordeImageWrapper(image_endpoint['api_key'])
                pass
            else:
                raise ValueError(f"no image service with the name \"{image_endpoint['service']}\"")
            i = image_endpoint['id']
            def make_fn_generate_image_for_endpoint(wrapper, endpoint):
                async def generate_image(input_prompt, negative_prompt, typing_fn, timeout=180):
                    return await wrapper.generate(input_prompt, negative_prompt, endpoint, typing_fn, timeout)
                return generate_image
            #self.image_generators.append(generate_image)
            image_generators[i] = make_fn_generate_image_for_endpoint(image_generator, image_endpoint['endpoint'])

        self.ai = AI(self, text_generators, image_generators)

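The make_fn_generate_image_for_endpoint factory exists to avoid Python's late-binding closure pitfall: a closure created directly in the loop body would see the loop variables' final values, not the values from its own iteration. A minimal demonstration of the pitfall and the factory fix:

# Late binding: all three lambdas share the same `n`, which ends up 2.
late = [lambda: n for n in range(3)]
print([f() for f in late])  # [2, 2, 2]

# A factory (or a default argument) binds the current value per iteration.
def make(n):
    return lambda: n
early = [make(n) for n in range(3)]
print([f() for f in early])  # [0, 1, 2]
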
|
    async def message_cb(self, room, event) -> None:
        message = Message.from_matrix(room, event)
        reply_fn = partial(self.connection.send_message, room.room_id)
        typing_fn = lambda: self.connection.room_typing(room.room_id, True, 15000)

        if room.room_id not in self.rooms:
            self.rooms[room.room_id] = {}
            # ToDo: set ticks 0 / start

        if not self.connection.synced:
            self.chatlog.save(message, False)
            return

        if message.is_from(self.connection.user_id):
            # Skip messages from ourselves
            self.chatlog.save(message)
            return

        if event.decrypted:
            encrypted_symbol = "🛡 "
        else:
            encrypted_symbol = "⚠️ "
        print(
            f"{room.display_name} |{encrypted_symbol}| {room.user_name(event.sender)}: {event.body}"
        )
        print(repr(room))
        print(repr(event))

        if room.is_named:
            print(f"room.display_name: {room.display_name}")
        if room.is_group:
            print(f"room.group_name(): {room.group_name()}")
        print(f"room.joined_count: {room.joined_count}")
        print(f"room.member_count: {room.member_count}")
        print(f"room.encrypted: {room.encrypted}")
        print(f"room.users: {room.users}")
        print(f"room.room_id: {room.room_id}")

        if self.name.casefold() == message.user_name.casefold():
            # Bot and user have the same name
            message.user_name += "2"  # or simply "You"

        message.user_name = message.user_name.title()

        if hasattr(self, "owner"):
            if not message.is_from(self.owner):
                self.chatlog.save(message)
                return

        if self.rooms[message.room_id].get("disabled") and not message.message.startswith('!start'):
            return

        await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id)

        if message.is_command():
            await self.schedule(self.queue, self.process_command, message, reply_fn, typing_fn)
        # elif re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", event.body, flags=re.IGNORECASE):
        #     # send, mail, drop, snap picture, photo, image, portrait
        else:
            await self.schedule(self.queue, self.process_message, message, reply_fn, typing_fn)
        self.chatlog.save(message)

    async def redaction_cb(self, room, event) -> None:
        # event.redacts carries the id of the redacted message; event.event_id
        # would be the redaction event itself
        self.chatlog.remove_message_by_id(event.redacts)

    async def process_command(self, message, reply_fn, typing_fn):
        if message.message.startswith("!replybot"):
            await reply_fn("Hello World")
        elif m := re.search(r"(?s)^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", message.message, flags=re.DOTALL):
            if m['num']:
                num = int(m['num'])
            else:
                num = 1
            if m['cmd']:
                prompt = m['cmd'].strip()
            else:
                prompt = "a beautiful woman"
            negative_prompt = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head"
            #"anime, cartoon, penis, fake, drawing, illustration, boring, 3d render, long neck, out of frame, extra fingers, mutated hands, monochrome, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, glitchy, bokeh, (((long neck))), 3D, 3DCG, cgstation, red eyes, multiple subjects, extra heads, close up, watermarks, logo"
            #"ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face"
            #"ugly, deformed, out of frame"
            try:
                output = await self.ai.image_generators[num](prompt, negative_prompt, typing_fn)
                await self.connection.room_typing(message.room_id, False)
                for imagefile in output:
                    await self.connection.send_image(message.room_id, imagefile)
            except (KeyError, IndexError, ValueError) as err:
                #await self.connection.room_typing(message.room_id, False)
                errormessage = f"<ERROR> {err=}, {type(err)=}"
                logger.error(errormessage)
                await reply_fn(errormessage)

        elif message.message.startswith("!image_negative_prompt"):
            self.negative_prompt = message.message.removeprefix('!image_negative_prompt').strip()
        elif message.message.startswith('!temperature'):
            self.temperature = float(message.message.removeprefix('!temperature').strip())
        elif message.message.startswith('!begin'):
            self.chatlog.clear(message.room_id)
            # ToDo reset time / ticks
            await reply_fn(self.greeting)
        elif message.message.startswith('!start'):
            self.rooms[message.room_id]["disabled"] = False
        elif message.message.startswith('!stop'):
            self.rooms[message.room_id]["disabled"] = True
        elif message.message.startswith('!!'):
            if self.chatlog.chat_history_len(message.room_id) > 2:
                for _ in range(2):
                    old_message = self.chatlog.remove_message_in_room(message.room_id, 1)
                    await self.connection.room_redact(message.room_id, old_message.event_id, reason="user-request")
                message = self.chatlog.get_last_message(message.room_id)

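    # What the !image regex above accepts (illustrative examples):
    #   "!image"               -> num absent, cmd absent: generator 1, default prompt
    #   "!image3"              -> num "3",   cmd absent
    #   "!image2 a red sunset" -> num "2",   cmd "a red sunset"
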
    async def process_message(self, message, reply_fn, typing_fn):
        output = await self.ai.generate_roleplay(message, reply_fn, typing_fn)
        # typing false
        await reply_fn(output)

    async def worker(self, name: str, q: asyncio.Queue) -> None:
        while True:
            cb, args, kwargs = await q.get()
            start = time.perf_counter()
            logger.info("queued task started")
            if asyncio.iscoroutinefunction(cb):
                await cb(*args, **kwargs)
            else:
                cb(*args, **kwargs)
            q.task_done()
            elapsed = time.perf_counter() - start
            logger.info(f"Queued task done in {elapsed:0.5f} seconds.")
            logger.debug("queue item processed")

    async def schedule(self, q: asyncio.Queue, cb, *args, **kwargs) -> None:
        logger.info("queuing task")
        await q.put((cb, args, kwargs))
        #q.put_nowait((cb, args, kwargs))

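    # Wiring sketch (an assumption; the initialisation is not shown in this
    # section): one queue plus one worker task serialize the callbacks that
    # schedule() enqueues above.
    #
    #     self.queue = asyncio.Queue()
    #     self.background_tasks.add(asyncio.create_task(self.worker("worker-0", self.queue)))
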
    async def schedule_task(self, done_callback, cb, *args, **kwargs):
        logger.info("creating background task")
        task = asyncio.create_task(cb(*args, **kwargs))
        task.add_done_callback(done_callback)
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

@ -0,0 +1,49 @@

from time import gmtime, localtime, strftime
from ..utilities.messages import Message


class ChatLog(object):

    def __init__(self, name):
        self.name = name
        self.chat_history = {}

    def enable_logging(self, directory):
        self.directory = directory

    def save(self, message, is_new=True):
        if message.room_id not in self.chat_history:
            self.chat_history[message.room_id] = {}
        self.chat_history[message.room_id][message.event_id] = message

        if hasattr(self, 'directory') and is_new:
            keepcharacters = (' ', '.', '_', '-')
            room_id_sanitized = "".join(c for c in message.room_id if c.isalnum() or c in keepcharacters).strip()
            time_suffix = strftime("%Y-%m", localtime())
            time = strftime("%a, %d %b %Y %H:%M:%S", localtime(message.timestamp))
            with open(f"{self.directory}/{message.room_name}_{room_id_sanitized}_{time_suffix}.txt", "a") as f:
                f.write("{} | {}: {}\n".format(time, message.user_name, message.message))

    def remove_message_by_id(self, event_id):
        for room_id in self.chat_history:
            if event_id in self.chat_history[room_id]:
                del self.chat_history[room_id][event_id]

    def remove_message_in_room(self, room_id, num_items=1):
        for i in range(num_items):
            event_id, message = self.chat_history[room_id].popitem()
        return message

    def get_last_message(self, room_id):
        key, message = list(self.chat_history[room_id].items())[-1]
        return message

    def clear_all(self):
        self.chat_history = {}

    def clear(self, room_id):
        self.chat_history[room_id] = {}

    def chat_history_len(self, room_id):
        return len(self.chat_history[room_id])
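
# Usage sketch (illustrative, not part of the module): messages are keyed by
# room id and event id, so a redaction can find and drop them again later.
#
#     log = ChatLog("mybot")
#     log.enable_logging("./.data/chatlogs")  # optional on-disk transcript
#     log.save(message)                       # message: utilities.messages.Message
#     log.remove_message_by_id(message.event_id)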
@ -0,0 +1,24 @@

class Message(object):
    def __init__(self, timestamp, user_name, message, event_id=None, user_id=None, room_name=None, room_id=None):
        self.timestamp = timestamp
        self.user_name = user_name
        self.message = message
        self.event_id = event_id
        self.user_id = user_id
        self.room_name = room_name
        self.room_id = room_id

    @classmethod
    def from_matrix(cls, room, event):
        # matrix timestamps are in milliseconds; store seconds
        return cls(event.server_timestamp/1000, room.user_name(event.sender), event.body, event.event_id, event.sender, room.display_name, room.room_id)

    def is_from(self, user_id):
        return self.user_id == user_id

    def is_command(self):
        return self.message.startswith('!')

    def __str__(self):
        return "{}: {}".format(self.user_name, self.message)
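
# Quick illustration (values invented):
#
#     m = Message(1684000000.0, "Alice", "!image2 a red sunset", event_id="$abc")
#     m.is_command()  # True
#     str(m)          # 'Alice: !image2 a red sunset'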
@ -0,0 +1,94 @@

"""KoboldCpp LLM wrapper for testing purposes."""
import logging
import time
from typing import Any, List, Mapping, Optional

import json
import requests

from langchain.llms.base import LLM

logger = logging.getLogger(__name__)


class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes."""

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"

    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""

    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""

    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""

    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    # model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send the prompt to the koboldcpp endpoint and return the completion."""

        #params = self.model_kwargs or {}
        input_data = {
            "prompt": prompt,
            "max_context_length": 2048,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": 256,
            "stop_sequence": self.stop,
        }

        if stop:
            input_data["stop_sequence"] = stop

        headers = {
            "Content-Type": "application/json",
        }

        logger.info("sending request to koboldcpp.")

        TRIES = 30
        for i in range(TRIES):
            try:
                r = requests.post(self.endpoint_url, json=input_data, headers=headers, timeout=600)
                r_json = r.json()
            except requests.exceptions.RequestException as e:
                raise ValueError(f"http connection error: {e}")
            logger.info(r_json)
            if r.status_code == 200:
                try:
                    response = r_json["results"][0]["text"]
                except KeyError:
                    raise ValueError("LangChain requires 'results' key in response.")
                break
            elif r.status_code == 503:
                logger.info("api is busy. waiting...")
                time.sleep(5)
            else:
                raise ValueError(f"http error. unknown response code {r.status_code}")
        else:
            # every retry returned 503; give up instead of falling through to
            # an undefined `response` below
            raise ValueError("api busy. giving up after retries.")

        for s in self.stop:
            response = response.rstrip().removesuffix(s)
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}
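
# Usage sketch. Assumption: an older langchain release where custom LLMs
# subclass langchain.llms.base.LLM and can be invoked directly.
#
#     llm = KoboldCpp(endpoint_url="http://localhost:5001/api/latest/generate",
#                     stop=["\nYou:"])
#     text = llm("You: Hello!\nAssistant:")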
@ -0,0 +1,59 @@

import asyncio
import requests
import json
import logging

logger = logging.getLogger(__name__)


class RunpodWrapper(object):
    """Base class for runpod.io serverless endpoints: submit a job, then poll
    its status until it completes."""

    def __init__(self, api_key):
        self.api_key = api_key

    async def generate(self, input_data, endpoint_name, typing_fn, timeout=180):
        # Set the API endpoint URL
        endpoint = f"https://api.runpod.ai/v2/{endpoint_name}/run"

        # Set the headers for the request
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        logger.info(f"sending request to runpod.io. endpoint=\"{endpoint_name}\"")

        # Make the request
        try:
            r = requests.post(endpoint, json=input_data, headers=headers, timeout=timeout)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"<HTTP ERROR> {e}")
        r_json = r.json()
        logger.debug(r_json)

        if r.status_code == 200:
            status = r_json["status"]
            job_id = r_json["id"]
            TIMEOUT = 360
            DELAY = 5
            for i in range(TIMEOUT//DELAY):
                endpoint = f"https://api.runpod.ai/v2/{endpoint_name}/status/{job_id}"
                r = requests.get(endpoint, headers=headers)
                r_json = r.json()
                logger.info(r_json)
                status = r_json["status"]
                if status == 'IN_PROGRESS':
                    await typing_fn()
                    await asyncio.sleep(DELAY)
                elif status == 'IN_QUEUE':
                    await asyncio.sleep(DELAY)
                elif status == 'COMPLETED':
                    output = r_json["output"]
                    return output
                else:
                    err_msg = r_json["error"] if "error" in r_json else ""
                    err_msg = err_msg.replace("\\n", "\n")
                    raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
            raise ValueError(f"<ERROR> TIMEOUT")
        else:
            raise ValueError(f"<ERROR> HTTP {r.status_code}")
@ -0,0 +1,42 @@

import asyncio
import requests
import os, tempfile
from .runpod import RunpodWrapper
import logging

logger = logging.getLogger(__name__)


class RunpodImageWrapper(RunpodWrapper):
    async def download_image(self, url, path):
        r = requests.get(url, stream=True)
        if r.status_code == 200:
            with open(path, 'wb') as f:
                for chunk in r:
                    f.write(chunk)

    async def generate(self, input_prompt: str, negative_prompt: str, endpoint_name: str, typing_fn, timeout=180):

        # Define your inputs
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "width": 512,
                "height": 768,
                "num_outputs": 3,
#                "nsfw": True
            },
        }

        output = await super().generate(input_data, endpoint_name, typing_fn, timeout)

        os.makedirs("./.data/images", exist_ok=True)
        files = []
        for image in output:
            # tempfile._get_candidate_names() is a private stdlib helper; it is
            # only used here to pick a random file name, not a real temp file
            temp_name = next(tempfile._get_candidate_names())
            filename = "./.data/images/" + temp_name + ".jpg"
            await self.download_image(image["image"], filename)
            files.append(filename)

        return files
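
# Usage sketch (endpoint id invented): returns local .jpg paths that can be
# handed straight to ChatClient.send_image().
#
#     images = RunpodImageWrapper("RUNPOD_API_KEY")
#     files = await images.generate("a red sunset", "blurry", "my-endpoint-id", typing_fn)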
@ -0,0 +1,82 @@

import asyncio
import requests
import os, tempfile
from .runpod import RunpodWrapper

import io
import base64
from PIL import Image, PngImagePlugin

import logging

logger = logging.getLogger(__name__)


class RunpodImageAutomaticWrapper(RunpodWrapper):

    async def generate(self, input_prompt: str, negative_prompt: str, endpoint_name: str, typing_fn, timeout=180):

        # Define your inputs
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "steps": 25,
                "cfg_scale": 7,
                "seed": -1,
                "width": 512,
                "height": 768,
                "batch_size": 3,
#                "sampler_index": "DPM++ 2M Karras",
#                "enable_hr": True,
#                "hr_scale": 2,
#                "hr_upscaler": "ESRGAN_4x", # "Latent"
#                "denoising_strength": 0.5,
#                "hr_second_pass_steps": 15,
                "restore_faces": True,
#                "gfpgan_visibility": 0.5,
#                "codeformer_visibility": 0.5,
#                "codeformer_weight": 0.5,
##                "override_settings": {
##                    "filter_nsfw": False,
##                },
                "api_endpoint": "txt2img",
            },
            "cmd": "txt2img"
        }

        output = await super().generate(input_data, endpoint_name, typing_fn, timeout)

        # Disabled second pass: it references serverless_automatic_request(),
        # api_url and api_key, none of which are defined in this module.
        upscale = False
        if upscale:
            count = 0
            for i in output['images']:
                payload = {
                    "init_images": [i],
                    "prompt": input_prompt,
                    "negative_prompt": negative_prompt,
                    "steps": 20,
                    "seed": -1,
                    #"sampler_index": "Euler",
                    # tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding, upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur, seams_fix_type, target_size_type, custom_width, custom_height, custom_scale
                    # "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875],
                    # "script_name": "Ultimate SD upscale",
                }
                upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key, typing_fn)
                output['images'][count] = upscaled_output['images'][count]
                count += 1

        os.makedirs("./.data/images", exist_ok=True)
        files = []
        for i in output['images']:
            temp_name = next(tempfile._get_candidate_names())
            filename = "./.data/images/" + temp_name + ".png"
            image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
            info = output['info']
            parameters = output['parameters']
            # keep the generation parameters in the PNG metadata
            pnginfo = PngImagePlugin.PngInfo()
            pnginfo.add_text("parameters", info)
            image.save(filename, pnginfo=pnginfo)
            files.append(filename)

        return files
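
# Design note: the payload above mirrors the AUTOMATIC1111 txt2img API; the
# extra "api_endpoint"/"cmd" keys are presumably consumed by the runpod worker
# that wraps AUTOMATIC1111 (an inference from the payload, not documented here).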
@ -0,0 +1,31 @@

import asyncio
import json
from .runpod import RunpodWrapper
import logging

logger = logging.getLogger(__name__)


class RunpodTextWrapper(RunpodWrapper):

    def __init__(self, api_key, endpoint):
        self.api_key = api_key
        self.endpoint = endpoint

    async def generate(self, prompt, endpoint_name, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):

        # Define your inputs
        input_data = {
            "input": {
                "prompt": prompt,
                "max_length": min(max_new_tokens, 2048),
                "temperature": temperature,
                "do_sample": True,
            }
        }
        output = await super().generate(input_data, endpoint_name, typing_fn, timeout)
        # the worker echoes the prompt in front of the completion; strip it
        output = output.removeprefix(prompt)
        return output

    async def generate2(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        return await self.generate(prompt, self.endpoint, typing_fn, temperature, max_new_tokens, timeout)
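
# Usage sketch (endpoint id invented):
#
#     llm = RunpodTextWrapper("RUNPOD_API_KEY", endpoint="pygmalion-6b")
#     reply = await llm.generate2("You: Hi!\nBot:", typing_fn, temperature=0.72)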
@ -0,0 +1,48 @@

import asyncio
import json
from .runpod import RunpodWrapper
import logging

logger = logging.getLogger(__name__)


class RunpodTextOobaboogaWrapper(RunpodWrapper):

    async def generate(self, prompt, endpoint_name, api_key, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):

        # Define your inputs
        input_data = {
            "input": {
                "data": [json.dumps([
                    prompt,
                    {
                        'max_new_tokens': min(max_new_tokens, 2048),
                        'do_sample': True,
                        'temperature': temperature,
                        'top_p': 0.73,
                        'typical_p': 1,
                        'repetition_penalty': 1.1,
                        'encoder_repetition_penalty': 1.0,
                        'top_k': 0,
                        'min_length': 0,
                        'no_repeat_ngram_size': 0,
                        'num_beams': 1,
                        'penalty_alpha': 0,
                        'length_penalty': 1,
                        'early_stopping': False,
                        'seed': -1,
                        'add_bos_token': True,
                        'stopping_strings': ["\nYou:"],  # FIXME: originally bot.user_name, which is not in scope here
                        'truncation_length': 2048,
                        'ban_eos_token': False,
                        'skip_special_tokens': True,
                    }
                ])]
            }
        }
        output = await super().generate(input_data, endpoint_name, typing_fn, timeout)
        if isinstance(output, list):
            output.sort(key=len, reverse=True)
            output = output[0]
        output = output["data"][0].removeprefix(prompt)
        return output
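
# Design note: the nested {"data": [json.dumps([...])]} shape matches the
# gradio-style API exposed by older oobabooga text-generation-webui workers
# (an assumption from the payload shape, not documented in this repo).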
@ -1,97 +0,0 @@

import os
import matrix_pygmalion_bot.translate as translate
import logging

logger = logging.getLogger(__name__)


class ChatMessage:
    def __init__(self, event_id, timestamp, user_name, is_own_message, is_command, relates_to_event, message, language="en", english_original_message=None):
        self.event_id = event_id
        self.timestamp = timestamp
        self.user_name = user_name
        self.is_own_message = is_own_message
        self.is_command = is_command
        self.stop_here = False
        self.relates_to_event = relates_to_event
        self.num_tokens = None
        self.is_in_saved_prompt = False
        self.message = {}
        self.message[language] = message
        if not (language == "en"):
            if not (english_original_message is None):
                self.message["en"] = english_original_message
            else:
                self.message["en"] = translate.translate(message, language, "en")
        self.language = language
        self.num_tokens = None

    def __str__(self):
        return str("{}: {}".format(self.user_name, self.message[self.language]))

    def getTranslation(self, to_lang):
        if not (to_lang in self.message):
            self.message[to_lang] = translate.translate(self.message["en"], "en", to_lang)
        return self.message[to_lang]

    def updateText(self, new_text, language="en"):
        self.message[self.language] = new_text
        self.num_tokens = None
        if not (language == "en"):
            self.message["en"] = translate.translate(message, language, "en")


class ChatHistory:
    def __init__(self, bot_name, room_name):
        self.bot_name = bot_name
        self.room_name = room_name
        self.context_fast_forward = False
        self.saved_context = None
        self.saved_context_num_tokens = None
        self.chat_history = {}

    def __str__(self):
        return str("Chat History for {} in room {}".format(self.bot_name, self.room_name))

    def getLen(self):
        return len(self.chat_history)

    def load_from_file(self):
        pass

    def clear(self):
        self.chat_history = {}

    def remove(self, num_items=1):
        for i in range(num_items):
            event_id, item = self.chat_history.popitem()
        return item

    def remove_id(self, event_id):
        if event_id in self.chat_history:
            del self.chat_history[event_id]
        else:
            logger.warning("remove_id: could not delete event. event_id not found in chat history")

    def exists_id(self, event_id):
        return (event_id in self.chat_history)

    def add(self, event_id, timestamp, user_name, is_own_message, is_command, relates_to_event, message, language="en", english_original_message=None):
        chat_message = ChatMessage(event_id, timestamp, user_name, is_own_message, is_command, relates_to_event, message, language, english_original_message)
        self.chat_history[chat_message.event_id] = chat_message
        os.makedirs("./chatlogs", exist_ok=True)
        with open("chatlogs/" + self.bot_name + "_" + self.room_name + ".txt", "a") as f:
            f.write("{}: {}\n".format(user_name, message))
        return chat_message

    def getLastItem(self):
        key, chat_item = list(self.chat_history.items())[-1]
        return chat_item

    def setFastForward(self, value):
        self.context_fast_forward = value

    def getFastForward(self):
        return self.context_fast_forward

    def getSavedPrompt(self):
        return self.saved_context

    def setSavedPrompt(self, context, num_tokens):
        self.saved_context = context
        self.saved_context_num_tokens = num_tokens


class BotChatHistory:
    def __init__(self, bot_name):
        self.bot_name = bot_name
        self.chat_rooms = {}

    def __str__(self):
        return str("Chat History for {}".format(self.bot_name))

    def room(self, room):
        if not room in self.chat_rooms:
            self.chat_rooms[room] = ChatHistory(self.bot_name, room)
        return self.chat_rooms[room]

@ -0,0 +1,356 @@

import asyncio
import nio
from nio import AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse, RedactionEvent, LoginResponse, Event
from nio import KeyVerificationCancel, KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac, KeyVerificationStart, LocalProtocolError, ToDeviceError
import os, sys
import json
import aiofiles.os
import magic
from PIL import Image
import logging

logger = logging.getLogger(__name__)


class Callbacks(object):
    """Class to pass client to callback methods."""

    def __init__(self, client: AsyncClient):
        self.client = client
        self.client.add_event_callback(self.message_cb, RoomMessageText)
        self.client.add_event_callback(self.invite_cb, InviteEvent)
        self.client.add_event_callback(self.redaction_cb, RedactionEvent)
        self.message_callbacks = []
        self.message_redaction_callbacks = []

    def add_message_callback(self, callback, redaction_callback=None):
        self.message_callbacks.append(callback)
        if redaction_callback:
            self.message_redaction_callbacks.append(redaction_callback)

    async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None:
        """Got a message in a room"""
        logger.debug(
            f"Message received in room {room.display_name} | "
            f"{room.user_name(event.sender)}: {event.body}"
        )
        for cb in self.message_callbacks:
            if asyncio.iscoroutinefunction(cb):
                await cb(room, event)
            else:
                cb(room, event)

    async def redaction_cb(self, room: MatrixRoom, event: RedactionEvent) -> None:
        """Message was deleted"""
        logger.debug(f"event redacted in room {room.room_id}. event_id: {event.redacts}")
        for cb in self.message_redaction_callbacks:
            if asyncio.iscoroutinefunction(cb):
                await cb(room, event)
            else:
                cb(room, event)

    async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None:
        """Automatically join all rooms we get invited to"""
        result = await self.client.join(room.room_id)
        if isinstance(result, nio.responses.JoinResponse):
            logger.info('Invited and joined room: {} {}'.format(room.name, room.room_id))
        else:
            logger.error("Error joining room: {}".format(str(result)))

    async def to_device_callback(self, event):
        """Handle events sent to device (interactive emoji verification)."""
        if isinstance(event, KeyVerificationStart):  # first step
            if "emoji" not in event.short_authentication_string:
                logger.warning(
                    "Other device does not support emoji verification "
                    f"{event.short_authentication_string}."
                )
                return
            resp = await self.client.accept_key_verification(event.transaction_id)
            if isinstance(resp, ToDeviceError):
                logger.warning(f"accept_key_verification failed with {resp}")

            sas = self.client.key_verifications[event.transaction_id]

            todevice_msg = sas.share_key()
            resp = await self.client.to_device(todevice_msg)
            if isinstance(resp, ToDeviceError):
                logger.warning(f"to_device failed with {resp}")
        elif isinstance(event, KeyVerificationCancel):
            logger.warning(
                f"Verification has been cancelled by {event.sender} "
                f'for reason "{event.reason}".'
            )
        elif isinstance(event, KeyVerificationKey):  # second step
            sas = self.client.key_verifications[event.transaction_id]

            logger.info(f"{sas.get_emoji()}")
            #yn = input("Do the emojis match? (Y/N) (C for Cancel) ")
            await asyncio.sleep(5)
            yn = 'y'  # auto-accept instead of asking interactively
            if yn.lower() == "y":
                #print(
                #    "Match! The verification for this " "device will be accepted."
                #)
                resp = await self.client.confirm_short_auth_string(event.transaction_id)
                if isinstance(resp, ToDeviceError):
                    logger.warning(f"confirm_short_auth_string failed with {resp}")
            elif yn.lower() == "n":  # no, don't match, reject
                #print(
                #    "No match! Device will NOT be verified "
                #    "by rejecting verification."
                #)
                resp = await self.client.cancel_key_verification(event.transaction_id, reject=True)
                if isinstance(resp, ToDeviceError):
                    logger.warning(f"cancel_key_verification failed with {resp}")
            else:  # C or anything for cancel
                #print("Cancelled by user! Verification will be " "cancelled.")
                resp = await self.client.cancel_key_verification(event.transaction_id, reject=False)
                if isinstance(resp, ToDeviceError):
                    logger.warning(f"cancel_key_verification failed with {resp}")
        elif isinstance(event, KeyVerificationMac):  # third step
            sas = self.client.key_verifications[event.transaction_id]
            try:
                todevice_msg = sas.get_mac()
            except LocalProtocolError as e:
                # e.g. it might have been cancelled by ourselves
                logger.warning(
                    f"Cancelled or protocol error: Reason: {e}.\n"
                    f"Verification with {event.sender} not concluded. "
                    "Try again?"
                )
            else:
                resp = await self.client.to_device(todevice_msg)
                if isinstance(resp, ToDeviceError):
                    logger.warning(f"to_device failed with {resp}")
                logger.info(
                    f"sas.we_started_it = {sas.we_started_it}\n"
                    f"sas.sas_accepted = {sas.sas_accepted}\n"
                    f"sas.canceled = {sas.canceled}\n"
                    f"sas.timed_out = {sas.timed_out}\n"
                    f"sas.verified = {sas.verified}\n"
                    f"sas.verified_devices = {sas.verified_devices}\n"
                )
                logger.info(
                    "Emoji verification was successful!\n"
                    "Hit Control-C to stop the program or "
                    "initiate another Emoji verification from "
                    "another device or room."
                )
        else:
            logger.warning(
                f"Received unexpected event type {type(event)}. "
                f"Event is {event}. Event will be ignored."
            )


class ChatClient(object):

    def __init__(self, homeserver, user_id, password, device_name="matrix-nio"):
        self.homeserver = homeserver
        self.user_id = user_id
        self.password = password
        self.device_name = device_name
        self.synced = False

    def persist(self, data_dir):
        #self.data_dir = data_dir
        self.config_file = f"{data_dir}/matrix_credentials.json"
        self.store_path = f"{data_dir}/store/"
        os.makedirs(data_dir, exist_ok=True)
        os.makedirs(self.store_path, exist_ok=True)

    async def login(self):

        client_config = AsyncClientConfig(
            max_limit_exceeded=0,
            max_timeouts=0,
            store_sync_tokens=True,
            encryption_enabled=False,
        )

        if not hasattr(self, 'config_file') or not os.path.exists(self.config_file):
            logger.info(f"No credentials file. Connecting to \"{self.homeserver}\" with user_id and password")
            if hasattr(self, 'store_path'):
                if not os.path.exists(self.store_path):
                    os.makedirs(self.store_path)
            else:
                self.store_path = None

            # initialize the matrix client
            self.client = AsyncClient(
                self.homeserver,
                self.user_id,
                store_path=self.store_path,
                config=client_config,
            )

            self.callbacks = Callbacks(self.client)

            resp = await self.client.login(self.password, device_name=self.device_name)
            # check that we logged in successfully
            if isinstance(resp, LoginResponse):
                if hasattr(self, 'config_file'):
                    self.write_details_to_disk(self.config_file, resp, self.homeserver)
            else:
                logger.error(f'homeserver = "{self.homeserver}"; user = "{self.user_id}"')
                logger.error(f"Failed to log in: {resp}")
                sys.exit(1)
        else:
            logger.info(f"Logging in to \"{self.homeserver}\" using stored credentials.")
            with open(self.config_file, "r") as f:
                config = json.load(f)
                self.client = AsyncClient(
                    config["homeserver"],
                    config["user_id"],
                    device_id=config["device_id"],
                    store_path=self.store_path,
                    config=client_config,
                )

                self.callbacks = Callbacks(self.client)

                # self.client.user_id=config["user_id"],
                # self.client.device_id=config["device_id"],
                # self.client.access_token=config["access_token"]

                self.client.restore_login(  # the load_store() inside somehow makes the client.rooms empty when encrypted. you can just set the access_token. see commented code before
                    user_id=config["user_id"],
                    device_id=config["device_id"],
                    access_token=config["access_token"]
                )

        # if os.path.exists(self.store_path + "megolm_keys"):
        #     await self.client.import_keys(self.store_path + "megolm_keys", "pass")

        # self.client.load_store()

        # if self.client.should_upload_keys:
        #     await self.client.keys_upload()

        self.client.add_to_device_callback(self.callbacks.to_device_callback, (KeyVerificationEvent,))

        logger.info(f"Connected as \"{self.user_id}\"")
        sync_task = asyncio.create_task(self.watch_for_sync())

        return self.client

    async def logout(self):
        logger.warning("Disconnected")
        await self.client.close()

    async def send_message(self, room_id, message):
        try:
            return await self.client.room_send(
                room_id=room_id,
                message_type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": message
                },
                ignore_unverified_devices=True,  # ToDo
            )
        except nio.exceptions.OlmUnverifiedDeviceError as err:
            print("These are all known devices:")
            # list every device the client knows about, with its trust state
            [
                print(
                    f"\t{device.user_id}\t {device.device_id}\t {device.trust_state}\t {device.display_name}"
                )
                for device in self.client.device_store
            ]
            raise

    async def send_image(self, room_id, image):
        """Send image to room
        https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image
        """
        mime_type = magic.from_file(image, mime=True)  # e.g. "image/jpeg"
        if not mime_type.startswith("image/"):
            logger.error("Drop message because file does not have an image mime type.")
            return

        im = Image.open(image)
        (width, height) = im.size  # im.size returns (width,height) tuple

        # first do an upload of image, then send URI of upload to room
        file_stat = await aiofiles.os.stat(image)
        async with aiofiles.open(image, "r+b") as f:
            resp, maybe_keys = await self.client.upload(
                f,
                content_type=mime_type,  # image/jpeg
                filename=os.path.basename(image),
                filesize=file_stat.st_size,
            )
        if isinstance(resp, UploadResponse):
            logger.info("Image was uploaded successfully to server.")
        else:
            logger.error(f"Failed to upload image. Failure response: {resp}")

        content = {
            "body": os.path.basename(image),  # descriptive title
            "info": {
                "size": file_stat.st_size,
                "mimetype": mime_type,
                "thumbnail_info": None,  # TODO
                "w": width,  # width in pixel
                "h": height,  # height in pixel
                "thumbnail_url": None,  # TODO
            },
            "msgtype": "m.image",
            "url": resp.content_uri,
        }

        try:
            await self.client.room_send(room_id, message_type="m.room.message", content=content)
            logger.info("Image was sent successfully")
        except Exception:
            logger.error(f"Image send of file {image} failed.")

    async def room_typing(self, room_id, is_typing, timeout=15000):
        if is_typing:
            return await self.client.room_typing(room_id, is_typing, timeout)
        else:
            return await self.client.room_typing(room_id, False)

    async def room_read_markers(self, room_id, event1, event2):
        return await self.client.room_read_markers(room_id, event1, event2)

    def sync_forever(self, timeout=30000, full_state=True):
        return self.client.sync_forever(timeout, full_state)

    async def watch_for_sync(self):
        logger.debug("Awaiting sync")
        await self.client.synced.wait()
        logger.info("Client is synced")
        self.synced = True
        logger.info(f"{self.client.user_id}, {self.client.rooms}, {self.client.invited_rooms}, {self.client.encrypted_rooms}")
        # if os.path.exists(self.store_path + "megolm_keys"):
        #     os.remove(self.store_path + "megolm_keys", "pass")
        # await self.client.export_keys(self.store_path + "megolm_keys", "pass")

    def write_details_to_disk(self, config_file: str, resp: LoginResponse, homeserver) -> None:
        """Writes the required login details to disk so we can log in later
        without using a password.

        Arguments:
            resp {LoginResponse} -- the successful client login response.
            homeserver -- URL of homeserver, e.g. "https://matrix.example.org"
        """
        # open the config file in write-mode
        with open(config_file, "w") as f:
            # write the login details to disk
            json.dump(
                {
                    "homeserver": homeserver,  # e.g. "https://matrix.example.org"
                    "user_id": resp.user_id,  # e.g. "@user:example.org"
                    "device_id": resp.device_id,  # device ID, 10 uppercase letters
                    "access_token": resp.access_token,  # cryptogr. access token
                },
                f,
            )
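
# Usage sketch (homeserver and credentials invented):
#
#     conn = ChatClient("https://matrix.example.org", "@bot:example.org", "secret")
#     conn.persist("./.data/matrix")  # optional: reuse access token on restart
#     client = await conn.login()
#     conn.callbacks.add_message_callback(on_message, on_redaction)
#     await conn.sync_forever()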
@ -1,576 +0,0 @@ |
|||||
import asyncio |
|
||||
import nio |
|
||||
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse, RedactionEvent) |
|
||||
|
|
||||
import os, sys |
|
||||
import time |
|
||||
import importlib |
|
||||
import configparser |
|
||||
import logging |
|
||||
|
|
||||
import aiofiles.os |
|
||||
import magic |
|
||||
from PIL import Image |
|
||||
import re |
|
||||
import json |
|
||||
|
|
||||
from .helpers import Event |
|
||||
from .chatlog import BotChatHistory |
|
||||
|
|
||||
import matrix_pygmalion_bot.translate as translate |
|
||||
|
|
||||
STORE_PATH = "./.store/" |
|
||||
|
|
||||
|
|
||||
logger = logging.getLogger(__name__) |
|
||||
config = configparser.ConfigParser() |
|
||||
bots = [] |
|
||||
background_tasks = set() |
|
||||
|
|
||||
class Callbacks(object): |
|
||||
"""Class to pass client to callback methods.""" |
|
||||
|
|
||||
def __init__(self, client: AsyncClient, bot): |
|
||||
self.client = client |
|
||||
self.bot = bot |
|
||||
|
|
||||
async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: |
|
||||
if not hasattr(event, 'body'): |
|
||||
return |
|
||||
|
|
||||
if not room.room_id in self.bot.room_config: |
|
||||
self.bot.room_config[room.room_id] = {} |
|
||||
self.bot.room_config[room.room_id]["tick"] = 0 |
|
||||
relates_to = None |
|
||||
if 'm.relates_to' in event.source["content"]: |
|
||||
relates_to = event.source["content"]['m.relates_to']["event_id"] |
|
||||
is_command = False |
|
||||
if event.body.startswith('!'): |
|
||||
is_command = True |
|
||||
language = "en" |
|
||||
if not (self.bot.translate is None) and not is_command: |
|
||||
language = self.bot.translate |
|
||||
|
|
||||
if 'original_message' in event.source["content"]: |
|
||||
english_original_message = event.source["content"]['original_message'] |
|
||||
else: |
|
||||
english_original_message = None |
|
||||
|
|
||||
chat_message = self.bot.chat_history.room(room.room_id).add(event.event_id, event.server_timestamp, room.user_name(event.sender), event.sender == self.client.user, is_command, relates_to, event.body, language, english_original_message) |
|
||||
|
|
||||
# parse keywords |
|
||||
if not event.body.startswith('!') and not event.body.startswith('<ERROR>'): |
|
||||
self.bot.extra_info = {"persona": [], "scenario": [], "example_dialogue": []} |
|
||||
for i, keyword in enumerate(self.bot.keywords): |
|
||||
if re.search(keyword["regex"], event.body, flags=re.IGNORECASE): |
|
||||
if not 'active' in self.bot.keywords[i] or self.bot.keywords[i]['active'] < 1: |
|
||||
self.bot.chat_history.room(room.room_id).setFastForward(False) |
|
||||
self.bot.keywords[i]['active'] = int(keyword["duration"]) |
|
||||
logger.info(f"keyword \"{keyword['regex']}\" detected: \"{event.body}\"") |
|
||||
if 'active' in self.bot.keywords[i]: |
|
||||
if self.bot.keywords[i]['active'] > 0: |
|
||||
logger.info(f"keyword \"{keyword['regex']}\" active. (duration {self.bot.keywords[i]['active']})") |
|
||||
if 'example_dialogue' in keyword: |
|
||||
self.bot.extra_info['example_dialogue'].append(keyword['example_dialogue']) |
|
||||
if 'persona' in keyword: |
|
||||
self.bot.extra_info['persona'].append(keyword['persona']) |
|
||||
if 'scenario' in keyword: |
|
||||
self.bot.extra_info['scenario'].append(keyword['scenario']) |
|
||||
self.bot.keywords[i]["active"] -= 1 |
|
||||
|
|
||||
if self.bot.not_synced: |
|
||||
return |
|
||||
logger.info( |
|
||||
"Message received for room {} | {}: {}".format( |
|
||||
room.display_name, room.user_name(event.sender), event.body |
|
||||
) |
|
||||
) |
|
||||
|
|
||||
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) |
|
||||
# Ignore messages when disabled |
|
||||
if "disabled" in self.bot.room_config[room.room_id] and self.bot.room_config[room.room_id]["disabled"] == True and not event.body.startswith('!start'): |
|
||||
return |
|
||||
|
|
||||
# Ignore messages from ourselves |
|
||||
if chat_message.is_own_message: |
|
||||
return |
|
||||
# Ignore message from strangers |
|
||||
if not (self.bot.owner is None): |
|
||||
if not (event.sender == self.bot.owner or chat_message.is_own_message): |
|
||||
return |
|
||||
|
|
||||
self.bot.user_name = room.user_name(event.sender).title() |
|
||||
if self.bot.user_name.casefold() == self.bot.name.casefold(): |
|
||||
self.bot.user_name = "You" |
|
||||
|
|
||||
if event.body.startswith('!replybot'): |
|
||||
print(event) |
|
||||
await self.bot.send_message(self.client, room.room_id, "Hello World!") |
|
||||
return |
|
||||
elif re.search("(?s)^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", event.body, flags=re.DOTALL): |
|
||||
m = re.search("(?s)^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", event.body, flags=re.DOTALL) |
|
||||
if m['num']: |
|
||||
num = int(m['num']) |
|
||||
else: |
|
||||
num = 1 |
|
||||
if m['cmd']: |
|
||||
prompt = m['cmd'].strip() |
|
||||
if self.bot.image_prompt: |
|
||||
prompt = prompt.replace(self.bot.name, self.bot.image_prompt) |
|
||||
else: |
|
||||
if self.bot.image_prompt: |
|
||||
prompt = self.bot.image_prompt |
|
||||
else: |
|
||||
prompt = "a beautiful woman" |
|
||||
if self.bot.negative_prompt: |
|
||||
negative_prompt = self.bot.negative_prompt |
|
||||
elif num == 1: |
|
||||
negative_prompt = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head" |
|
||||
elif num == 5: |
|
||||
negative_prompt = "anime, cartoon, penis, fake, drawing, illustration, boring, 3d render, long neck, out of frame, extra fingers, mutated hands, monochrome, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, glitchy, bokeh, (((long neck))), 3D, 3DCG, cgstation, red eyes, multiple subjects, extra heads, close up, watermarks, logo" |
|
||||
else: |
|
||||
negative_prompt = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" |
|
||||
# else: |
|
||||
# negative_prompt = "ugly, deformed, out of frame" |
|
||||
try: |
|
||||
typing = lambda : self.client.room_typing(room.room_id, True, 15000) |
|
||||
if self.bot.service_image == "runpod": |
|
||||
if num == 1: |
|
||||
output = await self.bot.image_ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 2: |
|
||||
output = await self.bot.image_ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 3: |
|
||||
output = await self.bot.image_ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 4: |
|
||||
output = await self.bot.image_ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 5: |
|
||||
output = await self.bot.image_ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 6: |
|
||||
output = await self.bot.image_ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 7: |
|
||||
output = await self.bot.image_ai.generate_image7(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
elif num == 8: |
|
||||
output = await self.bot.image_ai.generate_image8(prompt, negative_prompt, self.bot.runpod_api_key, typing) |
|
||||
else: |
|
||||
raise ValueError('no image generator with that number') |
|
||||
elif self.bot.service_image == "stablehorde": |
|
||||
if num == 1: |
|
||||
output = await self.bot.image_ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) |
|
||||
elif num == 2: |
|
||||
output = await self.bot.image_ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) |
|
||||
elif num == 3: |
|
||||
output = await self.bot.image_ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) |
|
||||
else: |
|
||||
raise ValueError('no image generator with that number') |
|
||||
else: |
|
||||
raise ValueError('remote image generation not configured properly') |
|
||||
except ValueError as err: |
|
||||
await self.client.room_typing(room.room_id, False) |
|
||||
errormessage = f"<ERROR> {err=}, {type(err)=}" |
|
||||
logger.error(errormessage) |
|
||||
await self.bot.send_message(self.client, room.room_id, errormessage) |
|
||||
return |
|
||||
|
|
||||
await self.client.room_typing(room.room_id, False) |
|
||||
for imagefile in output: |
|
||||
await self.bot.send_image(self.client, room.room_id, imagefile) |
|
||||
return |
|
||||
|
|
||||
elif event.body.startswith('!image_negative_prompt'): |
|
||||
negative_prompt = event.body.removeprefix('!image_negative_prompt').strip() |
|
||||
if len(negative_prompt) > 0: |
|
||||
self.bot.negative_prompt = negative_prompt |
|
||||
else: |
|
||||
self.bot.negative_prompt = None |
|
||||
return |
|
||||
elif event.body.startswith('!temperature'): |
|
||||
self.bot.temperature = float( event.body.removeprefix('!temperature').strip() ) |
|
||||
elif event.body.startswith('!begin'): |
|
||||
self.bot.chat_history.room(room.room_id).clear() |
|
||||
self.bot.room_config[room.room_id]["tick"] = 0 |
|
||||
await self.bot.write_conf2(self.bot.name) |
|
||||
await self.bot.send_message(self.client, room.room_id, self.bot.greeting) |
|
||||
return |
|
||||
elif event.body.startswith('!start'): |
|
||||
self.bot.room_config[room.room_id]["disabled"] = False |
|
||||
return |
|
||||
elif event.body.startswith('!stop'): |
|
||||
self.bot.room_config[room.room_id]["disabled"] = True |
|
||||
return |
|
||||
elif event.body.startswith('!!!'): |
|
||||
if self.bot.chat_history.room(room.room_id).getLen() < 3: |
|
||||
return |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) # current |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
return |
|
||||
elif event.body.startswith('!!'): |
|
||||
if self.bot.chat_history.room(room.room_id).getLen() < 3: |
|
||||
return |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)# current |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
chat_message = self.bot.chat_history.room(room.room_id).getLastItem() # new current |
|
||||
# don't return, we generate a new answer |
|
||||
elif event.body.startswith('!replace'): |
|
||||
if self.bot.chat_history.room(room.room_id).getLen() < 3: |
|
||||
return |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) # current |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) |
|
||||
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
||||
new_answer = event.body.removeprefix('!replace').strip() |
|
||||
await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) |
|
||||
return |
|
||||
elif event.body.startswith('!'): |
|
||||
await self.bot.send_message(self.client, room.room_id, "<ERROR> UNKNOWN COMMAND") |
|
||||
return |
|
||||
|
|
||||
# Other commands |
|
||||
if re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", event.body, flags=re.IGNORECASE): |
|
||||
# send, mail, drop, snap picture, photo, image, portrait |
|
||||
pass |
|
||||
|
|
||||
full_prompt = await self.bot.text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.room_id), self.bot.model) |
|
||||
num_tokens = await self.bot.text_ai.num_tokens(full_prompt, self.bot.model) |
|
||||
logger.debug(full_prompt) |
|
||||
logger.info(f"Prompt has " + str(num_tokens) + " tokens") |
|
||||
# answer = "" |
|
||||
# time = 0 |
|
||||
# error = None |
|
||||
# try: |
|
||||
# async for output in generate(full_prompt): |
|
||||
# await asyncio.sleep(0.1) |
|
||||
# answer += output |
|
||||
# if time % 5 == 0: |
|
||||
# await self.client.room_typing(room.room_id, True, 15000) |
|
||||
# time +=1 |
|
||||
# print(output, end='', flush=True) |
|
||||
# except Exception as e: |
|
||||
# error = e.__str__() |
|
||||
# answer = answer.strip() |
|
||||
# print("") |
|
||||
try: |
|
||||
typing = lambda : self.client.room_typing(room.room_id, True, 15000) |
|
||||
answer = await self.bot.text_ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing) |
|
||||
answer = answer.strip() |
|
||||
await self.client.room_typing(room.room_id, False) |
|
||||
if (self.bot.translate is None) or (self.bot.translate == "en"): |
|
||||
await self.bot.send_message(self.client, room.room_id, answer, reply_to=chat_message.event_id) |
|
||||
else: |
|
||||
translated_answer = translate.translate(answer, "en", self.bot.translate) |
|
||||
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=chat_message.event_id, original_message=answer) |
|
||||
|
|
||||
if not "message_count" in self.bot.room_config[room.room_id]: |
|
||||
self.bot.room_config[room.room_id]["message_count"] = 0 |
|
||||
self.bot.room_config[room.room_id]["message_count"] += 1 |
|
||||
except ValueError as err: |
|
||||
await self.client.room_typing(room.room_id, False) |
|
||||
errormessage = f"<ERROR> {err=}, {type(err)=}" |
|
||||
logger.error(errormessage) |
|
||||
await self.bot.send_message(self.client, room.room_id, errormessage) |
|
||||
return |
|
||||
|
|
||||
async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None: |
|
||||
"""Automatically join all rooms we get invited to""" |
|
||||
result = await self.client.join(room.room_id) |
|
||||
if isinstance(result, nio.responses.JoinResponse): |
|
||||
logger.info('Invited and joined room: {} {}'.format(room.name, room.room_id)) |
|
||||
else: |
|
||||
logger.error("Error joining room: {}".format(str(result))) |
|
||||
|
|
||||
async def redaction_cb(self, room: MatrixRoom, event: RedactionEvent) -> None: |
|
||||
logger.info(f"event redacted in room {room.room_id}. event_id: {event.redacts}") |
|
||||
for bot in bots: |
|
||||
# for room in bot.chat_history.chat_rooms.keys(): |
|
||||
if room.room_id in bot.chat_history.chat_rooms: |
|
||||
logger.info("room found") |
|
||||
if bot.chat_history.chat_rooms[room.room_id].exists_id(event.redacts): |
|
||||
logger.info("found it") |
|
||||
bot.chat_history.chat_rooms[room.room_id].remove_id(event.redacts) |
|
||||
self.bot.chat_history.room(room.room_id).setFastForward(False) |
|
||||
|
|
||||
class ChatBot(object):
    """Main chatbot"""

    def __init__(self, homeserver, user_id, password, device_name="matrix-nio"):
        self.homeserver = homeserver
        self.user_id = user_id
        self.password = password
        self.device_name = device_name

        self.service_text = "other"
        self.service_image = "other"
        self.model = "other"
        self.runpod_api_key = None
        self.runpod_text_endpoint = "pygmalion-6b"
        self.text_ai = None
        self.image_ai = None

        self.client = None
        self.callbacks = None
        self.config = None
        self.not_synced = True

        self.owner = None
        self.translate = None
        self.user_name = "You"

        self.name = None
        self.persona = None
        self.scenario = None
        self.greeting = None
        self.example_dialogue = []
        self.nsfw = False
        self.keywords = []
        self.extra_info = {"persona": [], "scenario": [], "example_dialogue": []}
        self.temperature = 0.90
        self.events = []
        self.global_tick = 0
        self.chat_history = None
        self.room_config = {}

        self.negative_prompt = None

        if STORE_PATH and not os.path.isdir(STORE_PATH):
            os.mkdir(STORE_PATH)

    def character_init(self, name, persona, scenario, greeting, example_dialogue=None, nsfw=False):
        # avoid the mutable default-argument pitfall; default to a fresh list per call
        if example_dialogue is None:
            example_dialogue = []
        self.name = name
        self.persona = persona
        self.scenario = scenario
        self.greeting = greeting
        self.example_dialogue = example_dialogue
        self.chat_history = BotChatHistory(self.name)
        self.nsfw = nsfw

    def get_persona(self):
        return ' '.join([self.persona, ' '.join(self.extra_info['persona'])])

    def get_scenario(self):
        return ' '.join([self.scenario, ' '.join(self.extra_info['scenario'])])

    def get_example_dialogue(self):
        return self.example_dialogue + self.extra_info['example_dialogue']

    async def event_loop(self):
        try:
            while True:
                await asyncio.sleep(60)
                for room_id in self.room_config.keys():
                    for event in self.events:
                        event.loop(self, self.room_config[room_id]["tick"])
                    self.room_config[room_id]["tick"] += 1
                self.global_tick += 1
                if self.global_tick % 10 == 0:
                    await self.write_conf2(self.name)
        finally:
            await self.write_conf2(self.name)

    async def add_event(self, event_string):
        # strip whitespace around each field; the old in-place loop
        # ("for item in items: item = item.strip()") had no effect on the list
        items = [item.strip() for item in event_string.split(',', 4)]
        event = Event(int(items[0]), int(items[1]), float(items[2]), int(items[3]), items[4])
        self.events.append(event)
        logger.debug("event added to event_loop")
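
    # An "events" config line is parsed as: tick_start, tick_stop, chance,
    # repeat_times, command (one tick per minute in event_loop). A hypothetical
    # example -- run the printtime command once, at tick 5:
    #
    #     events = 5, 5, 1.0, 1, printtime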

    async def login(self):
        self.config = AsyncClientConfig(store_sync_tokens=True)
        self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
        self.callbacks = Callbacks(self.client, self)
        self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText)
        self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)
        self.client.add_event_callback(self.callbacks.redaction_cb, RedactionEvent)

        sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced))
        event_loop = asyncio.create_task(self.event_loop())
        background_tasks.add(event_loop)
        event_loop.add_done_callback(background_tasks.discard)

        try:
            response = await self.client.login(self.password)
            logger.info(response)
            #sync_forever_task = asyncio.create_task(self.client.sync_forever(timeout=30000, full_state=True))
        except (asyncio.CancelledError, KeyboardInterrupt):
            logger.error("Received interrupt during login.")
            await self.client.close()
        #return sync_forever_task

    async def load_ai(self):
        if self.service_text == "runpod":
            self.text_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod")
        elif self.service_text == "stablehorde":
            self.text_ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde")
        elif self.service_text == "koboldcpp":
            self.text_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp")
        else:
            # was: self.bot.service_text -- ChatBot has no .bot attribute
            raise ValueError(f"no text service with the name \"{self.service_text}\"")

        if self.service_image == "runpod":
            self.image_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod")
        elif self.service_image == "stablehorde":
            self.image_ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde")
        elif self.service_image == "koboldcpp":
            self.image_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp")
        else:
            # was: self.bot.service_text -- wrong object and wrong attribute in the image branch
            raise ValueError(f"no image service with the name \"{self.service_image}\"")
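
    # The two if/elif chains above could be collapsed into a lookup table.
    # A minimal sketch, not part of the original code; it assumes the same
    # module layout under matrix_pygmalion_bot.ai:
    #
    #     SERVICES = {
    #         "runpod": "matrix_pygmalion_bot.ai.runpod",
    #         "stablehorde": "matrix_pygmalion_bot.ai.stablehorde",
    #         "koboldcpp": "matrix_pygmalion_bot.ai.koboldcpp",
    #     }
    #
    #     def _load_service(name):
    #         try:
    #             return importlib.import_module(SERVICES[name])
    #         except KeyError:
    #             raise ValueError(f'no service with the name "{name}"')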

    async def watch_for_sync(self, sync_event):
        logger.debug("Awaiting sync")
        await sync_event.wait()
        logger.debug("Client is synced")
        self.not_synced = False

    async def read_conf2(self, section):
        # note: the section argument is currently unused; the whole
        # room_config is read from a single shared file
        if not os.path.isfile("bot.conf2"):
            return
        with open("bot.conf2", "r") as f:
            self.room_config = json.load(f)

    async def write_conf2(self, section):
        with open("bot.conf2", "w") as f:
            json.dump(self.room_config, f)
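
    # bot.conf2 is plain JSON: room_config keyed by room id. A hypothetical
    # example with the two counters used above ("tick" drives event_loop,
    # "message_count" is incremented per reply):
    #
    #     {"!abc123:example.org": {"tick": 42, "message_count": 7}}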

    async def send_message(self, client, room_id, message, reply_to=None, original_message=None):
        content = {"msgtype": "m.text", "body": message}
        if reply_to:
            content["m.relates_to"] = {"event_id": reply_to, "rel_type": "de.xd0.mpygbot.in_reply_to"}
        if original_message:
            content["original_message"] = original_message

        await client.room_send(
            room_id=room_id,
            message_type="m.room.message",
            content=content,
        )
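
    # For a translated reply, the event content ends up looking like this
    # (event id and texts hypothetical); "de.xd0.mpygbot.in_reply_to" is a
    # custom rel_type, not the standard m.in_reply_to relation:
    #
    #     {
    #         "msgtype": "m.text",
    #         "body": "Hallo!",
    #         "m.relates_to": {"event_id": "$abc123", "rel_type": "de.xd0.mpygbot.in_reply_to"},
    #         "original_message": "Hello!",
    #     }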

    async def send_image(self, client, room_id, image):
        """Send image to room
        https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image
        """
        mime_type = magic.from_file(image, mime=True)  # e.g. "image/jpeg"
        if not mime_type.startswith("image/"):
            logger.error("Drop message because file does not have an image mime type.")
            return

        im = Image.open(image)
        (width, height) = im.size  # im.size returns a (width, height) tuple

        # first upload the image, then send the URI of the upload to the room
        file_stat = await aiofiles.os.stat(image)
        async with aiofiles.open(image, "r+b") as f:
            resp, maybe_keys = await client.upload(
                f,
                content_type=mime_type,  # image/jpeg
                filename=os.path.basename(image),
                filesize=file_stat.st_size,
            )
        if isinstance(resp, UploadResponse):
            logger.info("Image was uploaded successfully to server.")
        else:
            logger.error(f"Failed to upload image. Failure response: {resp}")
            return  # without a content_uri there is nothing to send

        content = {
            "body": os.path.basename(image),  # descriptive title
            "info": {
                "size": file_stat.st_size,
                "mimetype": mime_type,
                "thumbnail_info": None,  # TODO
                "w": width,  # width in pixels
                "h": height,  # height in pixels
                "thumbnail_url": None,  # TODO
            },
            "msgtype": "m.image",
            "url": resp.content_uri,
        }

        try:
            await client.room_send(room_id, message_type="m.room.message", content=content)
            logger.info("Image was sent successfully")
        except Exception:
            logger.error(f"Image send of file {image} failed.")


async def main() -> None:
    config.read('bot.conf')
    logging.basicConfig(level=logging.INFO)
    for section in config.sections():
        if section == 'DEFAULT' or section == 'Common':
            continue  # was "pass", which did not actually skip these sections
        botname = section
        homeserver = config[section]['url']
        user_id = config[section]['username']
        password = config[section]['password']
        if config.has_option(section, 'device_name'):
            device_name = config[section]['device_name']
        else:
            device_name = "matrix-nio"
        bot = ChatBot(homeserver, user_id, password, device_name)
        if config.has_option(section, 'example_dialogue'):
            example_dialogue = json.loads(config[section]['example_dialogue'])
        else:
            example_dialogue = []
        if config.has_option(section, 'nsfw'):
            nsfw = config.getboolean(section, 'nsfw')  # raw config values are strings
        else:
            nsfw = False
        bot.character_init(botname, config[section]['persona'].replace("\\n", "\n"), config[section]['scenario'].replace("\\n", "\n"), config[section]['greeting'].replace("\\n", "\n"), example_dialogue, nsfw)
        if config.has_option(section, 'keywords'):
            bot.keywords = json.loads(config[section]['keywords'])
        else:
            bot.keywords = []
        if config.has_option(section, 'temperature'):
            bot.temperature = float(config[section]['temperature'])  # was left as a string
        if config.has_option(section, 'owner'):
            bot.owner = config[section]['owner']
        if config.has_option(section, 'translate'):
            bot.translate = config[section]['translate']
            translate.init(bot.translate, "en")
            translate.init("en", bot.translate)
        if config.has_option(section, 'image_prompt'):
            bot.image_prompt = config[section]['image_prompt']
        if config.has_option(section, 'events'):
            events = config[section]['events'].strip().split('\n')
            for event in events:
                await bot.add_event(event)
        if config.has_option('DEFAULT', 'service_text'):
            bot.service_text = config['DEFAULT']['service_text']
        if config.has_option(section, 'service_text'):
            bot.service_text = config[section]['service_text']
        if config.has_option('DEFAULT', 'service_image'):
            bot.service_image = config['DEFAULT']['service_image']
        if config.has_option(section, 'service_image'):
            bot.service_image = config[section]['service_image']
        if config.has_option('DEFAULT', 'model'):
            bot.model = config['DEFAULT']['model']
        if config.has_option(section, 'model'):
            bot.model = config[section]['model']
        if config.has_option('DEFAULT', 'runpod_api_key'):
            bot.runpod_api_key = config['DEFAULT']['runpod_api_key']
        if config.has_option('DEFAULT', 'runpod_text_endpoint'):
            bot.runpod_text_endpoint = config['DEFAULT']['runpod_text_endpoint']
        if config.has_option('DEFAULT', 'stablehorde_api_key'):
            bot.stablehorde_api_key = config['DEFAULT']['stablehorde_api_key']
        await bot.read_conf2(section)
        bots.append(bot)
        await bot.load_ai()
        await bot.login()
    #logger.info("gather")
    if sys.version_info < (3, 11):
        tasks = []
        for bot in bots:
            task = asyncio.create_task(bot.client.sync_forever(timeout=30000, full_state=True))
            tasks.append(task)
        await asyncio.gather(*tasks)
    else:
        async with asyncio.TaskGroup() as tg:
            for bot in bots:
                task = tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True))

if __name__ == "__main__":
    asyncio.run(main())  # get_event_loop().run_until_complete() is deprecated
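
# A hypothetical bot.conf section for this (old) entry point; url, username,
# password, persona, scenario and greeting are read unconditionally, the
# remaining keys are optional:
#
#     [Julia]
#     url = https://matrix.example.org
#     username = @julia:example.org
#     password = secret
#     persona = Julia is a friendly assistant.\nShe likes chatting.
#     scenario = Julia is hanging out in a Matrix room.
#     greeting = Hello!
#     temperature = 0.9
#     service_text = koboldcpp
#     events = 5, 5, 1.0, 1, printtime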
@ -1,44 +0,0 @@
import time
import logging

logger = logging.getLogger(__name__)


class Event:
    def __init__(self, tick_start, tick_stop, chance, repeat_times, command):
        self.tick_start = tick_start
        self.tick_stop = tick_stop
        self.chance = chance
        self.repeat_times = repeat_times
        self.command = command
        self.executed = 0

    def __str__(self):
        return "Event starting at time {}".format(self.tick_start)

    def loop(self, bot, tick):
        # note: a one-shot with tick_stop == 0 never passes is_active()
        # unless tick_start is also 0; use tick_stop == tick_start instead
        if self.is_active(tick):
            if self.is_oneshot():
                if self.executed == 0:
                    self.execute(bot, tick)
            else:
                self.execute(bot, tick)

    def is_active(self, tick):
        return self.tick_start <= tick <= self.tick_stop

    def is_oneshot(self):
        return self.tick_stop == 0 or self.tick_stop == self.tick_start

    def is_timespan(self):
        return self.tick_stop > self.tick_start

    def execute(self, bot, tick):
        logger.info(f"event executed for {bot.name}. current tick: {tick} event: {self.command}")
        if self.command.startswith('printtime'):
            print(time.time() // 1000)
        self.executed += 1
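
# Minimal usage sketch (values hypothetical):
#
#     once = Event(5, 5, 1.0, 1, "printtime")    # one-shot, fires at tick 5
#     span = Event(10, 20, 1.0, 1, "printtime")  # re-fires every tick from 10 to 20
#     for tick in range(30):
#         once.loop(bot, tick)
#         span.loop(bot, tick)
#
# chance and repeat_times are stored but not yet evaluated in loop().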
@ -0,0 +1,86 @@
#!/usr/bin/env python3
import asyncio
import os, sys
import json
from .utilities.config_parser import read_config
from .bot.core import ChatBot
from .connections.matrix import ChatClient
import traceback
import logging

logger = logging.getLogger(__name__)

DATA_DIR = './.data'
bots = []

async def main() -> None:
    config = read_config('bot.conf')
    if config.has_option('DEFAULT', 'log_level'):
        log_level = config['DEFAULT']['log_level']
        if log_level == 'DEBUG':
            logging.basicConfig(level=logging.DEBUG)
        elif log_level == 'INFO':
            logging.basicConfig(level=logging.INFO)
        elif log_level == 'WARNING':
            logging.basicConfig(level=logging.WARNING)
        elif log_level == 'ERROR':
            logging.basicConfig(level=logging.ERROR)
        elif log_level == 'CRITICAL':
            logging.basicConfig(level=logging.CRITICAL)

    os.makedirs(DATA_DIR, exist_ok=True)

    for section in config.sections():
        bot_config = config[section]
        connection = ChatClient(bot_config['matrix_homeserver'], bot_config['matrix_username'], bot_config['matrix_password'], bot_config.get('matrix_device_name', 'matrix-nio'))
        connection.persist(f"{DATA_DIR}/{section}/matrix")
        bot = ChatBot(section, connection)
        bot.persist(f"{DATA_DIR}/{section}")
        bot.init_character(
            bot_config['persona'],
            bot_config['scenario'],
            bot_config['greeting'],
            json.loads(bot_config.get('example_dialogue', "[]")),
            bot_config.getboolean('nsfw', fallback=False),     # was .get(), which returns a string
            bot_config.getfloat('temperature', fallback=0.72),
        )
        if config.has_option(section, 'owner'):
            bot.owner = config[section]['owner']
        # if config.has_option(section, 'translate'):
        #     bot.translate = config[section]['translate']
        #     translate.init(bot.translate, "en")
        #     translate.init("en", bot.translate)

        await bot.load_ai(
            json.loads(bot_config['available_text_endpoints']),
            json.loads(bot_config['available_image_endpoints']),
        )

        await bot.connect()
        bots.append(bot)

    try:
        if sys.version_info < (3, 11):
            tasks = []
            for bot in bots:
                task = asyncio.create_task(bot.connection.sync_forever(timeout=30000, full_state=True))
                tasks.append(task)
            await asyncio.gather(*tasks)
        else:
            async with asyncio.TaskGroup() as tg:
                for bot in bots:
                    task = tg.create_task(bot.connection.sync_forever(timeout=30000, full_state=True))
    except (asyncio.CancelledError, KeyboardInterrupt):
        # the narrower handler comes first; both are BaseExceptions, so
        # ordering relative to "except Exception" is kept explicit here
        print("Received keyboard interrupt.")
        for bot in bots:
            await bot.disconnect()
        sys.exit(0)
    except Exception:
        print(traceback.format_exc())
        sys.exit(1)

if __name__ == "__main__":
    asyncio.run(main())
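
# A hypothetical bot.conf for this new entry point; available_*_endpoints are
# JSON lists whose exact schema is defined by bot.load_ai():
#
#     [DEFAULT]
#     log_level = INFO
#
#     [Julia]
#     matrix_homeserver = https://matrix.example.org
#     matrix_username = @julia:example.org
#     matrix_password = secret
#     persona = Julia is a friendly assistant.
#     scenario = Julia is chatting in a Matrix room.
#     greeting = Hello!
#     available_text_endpoints = ["koboldcpp"]
#     available_image_endpoints = ["stablehorde"]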
@ -0,0 +1,9 @@
import configparser
import logging

logger = logging.getLogger(__name__)

def read_config(filename):
    config = configparser.ConfigParser()
    config.read(filename)
    return config
@ -1,8 +1,9 @@
 asyncio
-matrix-nio
+matrix-nio[e2e]
 transformers
 huggingface_hub
 python-magic
 pillow
 argostranslate
 webuiapi
+langchain