From f9e740300e22bff10539f0b190ded84e5aebc227 Mon Sep 17 00:00:00 2001 From: F4ria Date: Tue, 23 Apr 2024 11:44:27 +0800 Subject: [PATCH] use 'del' to remove the player's history for gemini --- handlers/gemini.py | 90 ++++++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 42 deletions(-) diff --git a/handlers/gemini.py b/handlers/gemini.py index 5a6490c..34907fd 100644 --- a/handlers/gemini.py +++ b/handlers/gemini.py @@ -3,6 +3,7 @@ import re import time import google.generativeai as genai +from google.generativeai import ChatSession from google.generativeai.types.generation_types import StopCandidateException from telebot import TeleBot from telebot.types import Message @@ -38,7 +39,7 @@ gemini_pro_player_dict = {} gemini_file_player_dict = {} -def make_new_gemini_convo(is_pro=False): +def make_new_gemini_convo(is_pro=False) -> ChatSession: model_name = "models/gemini-1.0-pro-latest" if is_pro: model_name = "models/gemini-1.5-pro-latest" @@ -52,26 +53,46 @@ def make_new_gemini_convo(is_pro=False): return convo +def remove_gemini_player(player_id: str, is_pro: bool) -> None: + if is_pro: + if player_id in gemini_pro_player_dict: + del gemini_pro_player_dict[player_id] + if player_id in gemini_file_player_dict: + del gemini_file_player_dict[player_id] + else: + if player_id in gemini_player_dict: + del gemini_player_dict[player_id] + + +def get_gemini_player(player_id: str, is_pro: bool) -> ChatSession: + player = None + if is_pro: + if player_id not in gemini_pro_player_dict: + gemini_pro_player_dict[player_id] = make_new_gemini_convo(is_pro) + player = gemini_pro_player_dict[player_id] + else: + if player_id not in gemini_player_dict: + gemini_player_dict[player_id] = make_new_gemini_convo() + player = gemini_player_dict[player_id] + + return player + + def gemini_handler(message: Message, bot: TeleBot) -> None: """Gemini : /gemini """ m = message.text.strip() - player = None - # restart will lose all TODO - if str(message.from_user.id) not in gemini_player_dict: - player = make_new_gemini_convo() - gemini_player_dict[str(message.from_user.id)] = player - else: - player = gemini_player_dict[str(message.from_user.id)] + player_id = str(message.from_user.id) + is_pro = False if m.strip() == "clear": - bot.reply_to( - message, - "just clear you gemini messages history", - ) - player.history.clear() + bot.reply_to(message, "just clear you gemini messages history") + remove_gemini_player(player_id, is_pro) return if m[:4].lower() == "new ": m = m[4:].strip() - player.history.clear() + remove_gemini_player(player_id, is_pro) + + # restart will lose all TODO + player = get_gemini_player(player_id, is_pro) m = enrich_text_with_urls(m) who = "Gemini" @@ -105,28 +126,18 @@ def gemini_handler(message: Message, bot: TeleBot) -> None: def gemini_pro_handler(message: Message, bot: TeleBot) -> None: """Gemini : /gemini_pro """ m = message.text.strip() - player = None - # restart will lose all TODO - if str(message.from_user.id) not in gemini_pro_player_dict: - player = make_new_gemini_convo(is_pro=True) - gemini_pro_player_dict[str(message.from_user.id)] = player - else: - player = gemini_pro_player_dict[str(message.from_user.id)] + player_id = str(message.from_user.id) + is_pro = True if m.strip() == "clear": - bot.reply_to( - message, - "just clear you gemini messages history", - ) - player.history.clear() - # also need to clear the data file - if gemini_file_player_dict.get(str(message.from_user.id)): - del gemini_file_player_dict[str(message.from_user.id)] + bot.reply_to(message, "just clear you gemini messages history") + remove_gemini_player(player_id, is_pro) return if m[:4].lower() == "new ": m = m[4:].strip() - player.history.clear() - if gemini_file_player_dict.get(str(message.from_user.id)): - del gemini_file_player_dict[str(message.from_user.id)] + remove_gemini_player(player_id, is_pro) + + # restart will lose all TODO + player = get_gemini_player(player_id, is_pro) m = enrich_text_with_urls(m) who = "Gemini Pro" @@ -138,7 +149,7 @@ def gemini_pro_handler(message: Message, bot: TeleBot) -> None: player.history = player.history[2:] try: - if path := gemini_file_player_dict.get(str(message.from_user.id)): + if path := gemini_file_player_dict.get(player_id): m = [m, path] r = player.send_message(m, stream=True) s = "" @@ -200,26 +211,21 @@ def gemini_audio_handler(message: Message, bot: TeleBot) -> None: s = message.caption prompt = s.strip() who = "Gemini File Audio" - player = None + player_id = str(message.from_user.id) # restart will lose all TODO - if str(message.from_user.id) not in gemini_pro_player_dict: - player = make_new_gemini_convo(is_pro=True) - gemini_pro_player_dict[str(message.from_user.id)] = player - else: - player = gemini_pro_player_dict[str(message.from_user.id)] + player = get_gemini_player(player_id, is_pro=True) file_path = None - # restart will lose all TODO # for file handler like {user_id: [player, file_path], user_id2: [player, file_path]} reply_id = bot_reply_first(message, who, bot) file_path = bot.get_file(message.audio.file_id).file_path downloaded_file = bot.download_file(file_path) - path = f"{str(message.from_user.id)}_gemini.mp3" + path = f"{player_id}_gemini.mp3" with open(path, "wb") as temp_file: temp_file.write(downloaded_file) gemini_mp3_file = genai.upload_file(path=path) r = player.send_message([prompt, gemini_mp3_file], stream=True) # need set it for the conversation - gemini_file_player_dict[str(message.from_user.id)] = gemini_mp3_file + gemini_file_player_dict[player_id] = gemini_mp3_file try: s = "" start = time.time()