diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py index 331fe6d..a872943 100644 --- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py +++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) async def generate_sync( prompt: str, api_key: str, - bot_name: str, + bot, ): # Set the API endpoint URL endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync" @@ -37,7 +37,7 @@ async def generate_sync( "input": { "prompt": prompt, "max_length": min(prompt_num_tokens+max_new_tokens, 2048), - "temperature": 0.80, + "temperature": bot.temperature, "do_sample": True, } } @@ -61,7 +61,7 @@ async def generate_sync( reply = answer[:idx].strip() else: reply = answer.removesuffix('<|endoftext|>').strip() - reply = reply.replace("\n{bot_name}: ", " ") + reply = reply.replace("\n{bot.name}: ", " ") reply = reply.replace("\n: ", " ") return reply elif status == 'IN_PROGRESS' or status == 'IN_QUEUE': @@ -88,7 +88,7 @@ async def generate_sync( reply = answer[:idx].strip() else: reply = answer.removesuffix('<|endoftext|>').strip() - reply = reply.replace("\n{bot_name}: ", " ") + reply = reply.replace("\n{bot.name}: ", " ") reply = reply.replace("\n: ", " ") return reply else: diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py index cafe8c9..17cbd9d 100644 --- a/matrix_pygmalion_bot/core.py +++ b/matrix_pygmalion_bot/core.py @@ -222,7 +222,7 @@ class Callbacks(object): # answer = answer.strip() # print("") await self.client.room_typing(room.room_id, True, 15000) - answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name) + answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot) answer = answer.strip() await self.client.room_typing(room.room_id, False) if not (self.bot.translate is None): @@ -266,6 +266,7 @@ class ChatBot(object): self.persona = None self.scenario = None self.greeting = None + self.temperature = 0.80 self.events = [] self.global_tick = 0 self.chat_history = None @@ -418,6 +419,8 @@ async def main() -> None: device_name = "matrix-nio" bot = ChatBot(homeserver, user_id, password, device_name) bot.character_init(botname, config[section]['persona'], config[section]['scenario'], config[section]['greeting']) + if config.has_option(section, 'temperature'): + bot.temperature = config[section]['temperature'] if config.has_option(section, 'owner'): bot.owner = config[section]['owner'] if config.has_option(section, 'translate'):