from os import environ
from pathlib import Path
import re

import google.generativeai as genai
from telebot import TeleBot
from telebot.types import Message

GOOGLE_GEMINI_KEY = environ.get("GOOGLE_GEMINI_KEY")

genai.configure(api_key=GOOGLE_GEMINI_KEY)
generation_config = {
    "temperature": 0.7,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}

safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]

# Global history cache: maps str(Telegram user id) -> Gemini chat session.
# It lives only in process memory, so a restart loses every conversation.
gemini_player_dict = {}


#### Utils for gemini ####
# Note: this code is copied from
# https://github.com/yym68686/md2tgmd/blob/main/src/md2tgmd.py
# Many thanks.


def find_all_index(str, pattern):
    """Return the boundaries of every capture-group-1 match of *pattern*.

    The result is ``[0, start1, end1, start2, end2, ..., len(str)]``, where
    each ``(start, end)`` pair is the span of group 1 in one match. Matches in
    which group 1 did not participate are skipped. This lets callers list
    "protected" alternatives (for example, code blocks) in the pattern: those
    regions are consumed by the regex but never reported. The pattern is
    applied with ``re.MULTILINE``.

    The parameter name ``str`` shadows the builtin inside this function. It is
    kept unchanged for backward compatibility with keyword callers.
    """
    index_list = [0]
    for match in re.finditer(pattern, str, re.MULTILINE):
        if match.group(1) is not None:
            index_list += [match.start(1), match.end(1)]
    index_list.append(len(str))
    return index_list


def replace_all(text, pattern, function):
    """Replace each group-1 span of *pattern* in *text* with ``function(span)``.

    Text outside the group-1 spans is copied unchanged. This includes whole
    matches in which group 1 did not participate.
    """
    poslist = find_all_index(text, pattern)
    # Pairs starting at odd indices are group-1 spans: transform each one.
    strlist = []
    for i in range(1, len(poslist) - 1, 2):
        start, end = poslist[i : i + 2]
        strlist.append(function(text[start:end]))
    # Pairs starting at even indices are the untouched segments between spans.
    originstr = []
    for i in range(0, len(poslist), 2):
        j, k = poslist[i : i + 2]
        originstr.append(text[j:k])
    # There is one more untouched segment than transformed spans. Pad the
    # shorter list so that zip() below does not drop the trailing segment.
    if len(strlist) < len(originstr):
        strlist.append("")
    else:
        originstr.append("")
    new_list = [item for pair in zip(originstr, strlist) for item in pair]
    return "".join(new_list)


def escapeshape(text):
    """Render a Markdown heading line ("## Title words") as "▎*Title words*".

    Every word after the leading run of '#' is kept. The previous version kept
    only the first word, which truncated multi-word headings.
    """
    return "▎*" + " ".join(text.split()[1:]) + "*"


def escapeminus(text):
    """Prefix a '-' with a backslash (MarkdownV2 escape)."""
    return "\\" + text


def escapebackquote(text):
    """Replace a "``" pair with two escaped backquotes."""
    return r"\`\`"


def escapeplus(text):
    """Prefix a '+' with a backslash (MarkdownV2 escape)."""
    return "\\" + text


def escape(text, flag=0):
    """Convert Markdown produced by Gemini into Telegram MarkdownV2.

    Outside entities, MarkdownV2 requires each of the characters
        _ * [ ] ( ) ~ ` > # + - = | { } . !
    to be escaped with a preceding backslash. The conversion runs as an
    ordered series of regex rewrites. Temporary placeholders ("@->@", "@@@",
    "^^^", ...) protect constructs from the generic escaping, so the order of
    the steps matters.

    Args:
        text: Markdown text to convert.
        flag: when truthy, existing double backslashes are preserved instead
            of being doubled again.

    Returns:
        The MarkdownV2-formatted text.
    """
    # Park LaTeX-style \[ \] \( \) delimiters in placeholders. They are
    # restored further down as escaped brackets/parentheses.
    text = re.sub(r"\\\[", "@->@", text)
    text = re.sub(r"\\\]", "@<-@", text)
    text = re.sub(r"\\\(", "@-->@", text)
    text = re.sub(r"\\\)", "@<--@", text)
    # Escape backslashes, optionally keeping existing double backslashes intact.
    if flag:
        text = re.sub(r"\\\\", "@@@", text)
    text = re.sub(r"\\", r"\\\\", text)
    if flag:
        text = re.sub(r"\@{3}", r"\\\\", text)
    text = re.sub(r"_", r"\_", text)
    # Convert **bold** to *bold*, turn "* " bullets into "• ", and escape any
    # remaining '*'.
    text = re.sub(r"\*{2}(.*?)\*{2}", "@@@\\1@@@", text)
    text = re.sub(r"\n{1,2}\*\s", "\n\n• ", text)
    text = re.sub(r"\*", r"\*", text)
    text = re.sub(r"\@{3}(.*?)\@{3}", "*\\1*", text)
    # Park [label](url) links in placeholders (an image's leading '!' is
    # dropped), escape every other bracket/parenthesis, then restore the
    # LaTeX delimiters and the links.
    text = re.sub(r"\!?\[(.*?)\]\((.*?)\)", "@@@\\1@@@^^^\\2^^^", text)
    text = re.sub(r"\[", r"\[", text)
    text = re.sub(r"\]", r"\]", text)
    text = re.sub(r"\(", r"\(", text)
    text = re.sub(r"\)", r"\)", text)
    text = re.sub(r"\@\-\>\@", r"\[", text)
    text = re.sub(r"\@\<\-\@", r"\]", text)
    text = re.sub(r"\@\-\-\>\@", r"\(", text)
    text = re.sub(r"\@\<\-\-\@", r"\)", text)
    text = re.sub(r"\@{3}(.*?)\@{3}\^{3}(.*?)\^{3}", "[\\1](\\2)", text)
    text = re.sub(r"~", r"\~", text)
    text = re.sub(r">", r"\>", text)
    # Headings become bold lines. ``` code blocks are matched without group 1,
    # so they are left untouched.
    text = replace_all(text, r"(^#+\s.+?$)|```[\D\d\s]+?```", escapeshape)
    text = re.sub(r"#", r"\#", text)
    # Escape '+' outside code spans/blocks and outside "- " list markers.
    text = replace_all(
        text, r"(\+)|\n[\s]*-\s|```[\D\d\s]+?```|`[\D\d\s]*?`", escapeplus
    )
    # "- item" list markers become "• item"; numbered items get a blank line.
    text = re.sub(r"\n{1,2}(\s*)-\s", "\n\n\\1• ", text)
    text = re.sub(r"\n{1,2}(\s*\d{1,2}\.\s)", "\n\n\\1", text)
    # Escape '-' outside code spans/blocks.
    text = replace_all(
        text, r"(-)|\n[\s]*-\s|```[\D\d\s]+?```|`[\D\d\s]*?`", escapeminus
    )
    # Temporarily hide ``` blocks so a stray "``" elsewhere can be escaped.
    text = re.sub(r"```([\D\d\s]+?)```", "@@@\\1@@@", text)
    text = replace_all(text, r"(``)", escapebackquote)
    text = re.sub(r"\@{3}([\D\d\s]+?)\@{3}", "```\\1```", text)
    text = re.sub(r"=", r"\=", text)
    text = re.sub(r"\|", r"\|", text)
    text = re.sub(r"{", r"\{", text)
    text = re.sub(r"}", r"\}", text)
    text = re.sub(r"\.", r"\.", text)
    text = re.sub(r"!", r"\!", text)
    return text


def make_new_gemini_convo():
    """Start a fresh ``gemini-pro`` chat session using the module-level config."""
    model = genai.GenerativeModel(
        model_name="gemini-pro",
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    convo = model.start_chat()
    return convo


def gemini_handler(message: Message, bot: TeleBot) -> None:
    """Gemini : /gemini <question>

    Answer a text message with Gemini, keeping a per-user chat session in
    ``gemini_player_dict``. The "please wait" placeholder is always deleted,
    even if the Gemini call fails.
    """
    reply_message = bot.reply_to(
        message,
        "Generating google gemini answer please wait, note, will only keep the last five messages:",
    )
    m = message.text.strip()
    user_id = str(message.from_user.id)
    try:
        # NOTE: sessions live only in memory; a restart loses them all. TODO: persist.
        player = gemini_player_dict.get(user_id)
        if player is None:
            player = make_new_gemini_convo()
            gemini_player_dict[user_id] = player
        # Keep the last 5 exchanges; each exchange adds two history entries
        # (question and answer). Dropping exactly one pair keeps the history
        # starting with a user turn.
        if len(player.history) > 10:
            player.history = player.history[2:]

        player.send_message(m)
        gemini_reply_text = player.last.text.strip()
        try:
            bot.reply_to(
                message,
                "Gemini answer:\n" + escape(gemini_reply_text),
                parse_mode="MarkdownV2",
            )
        except Exception:
            # Either escape() failed or Telegram rejected the MarkdownV2;
            # fall back to plain text.
            print("wrong markdown format")
            bot.reply_to(
                message,
                "Gemini answer:\n\n" + gemini_reply_text,
            )
    finally:
        bot.delete_message(reply_message.chat.id, reply_message.message_id)


def gemini_photo_handler(message: Message, bot: TeleBot) -> None:
    """Answer a captioned photo with ``gemini-pro-vision``.

    The image bytes are passed straight to Gemini. No shared temp file is
    written, so concurrent requests from different users cannot overwrite
    each other's image. The "please wait" placeholder is always deleted.
    """
    s = message.caption
    reply_message = bot.reply_to(
        message,
        "Generating google gemini vision answer please wait.",
    )
    try:
        prompt = s.strip()
        # Telegram sends several resolutions of the same photo; take the
        # highest-quality one. file_size is optional in the Bot API, so a
        # missing size is treated as 0.
        max_size_photo = max(message.photo, key=lambda p: p.file_size or 0)
        file_path = bot.get_file(max_size_photo.file_id).file_path
        image_data = bot.download_file(file_path)

        model = genai.GenerativeModel("gemini-pro-vision")
        contents = {
            "parts": [{"mime_type": "image/jpeg", "data": image_data}, {"text": prompt}]
        }
        response = model.generate_content(contents=contents)
        bot.reply_to(message, "Gemini vision answer:\n" + response.text)
    finally:
        bot.delete_message(reply_message.chat.id, reply_message.message_id)


def register(bot: TeleBot) -> None:
    """Register the Gemini text ("/gemini", "gemini:") and photo handlers on *bot*."""
    bot.register_message_handler(gemini_handler, commands=["gemini"], pass_bot=True)
    bot.register_message_handler(gemini_handler, regexp="^gemini:", pass_bot=True)
    bot.register_message_handler(
        gemini_photo_handler,
        content_types=["photo"],
        func=lambda m: m.caption and m.caption.startswith(("gemini:", "/gemini")),
        pass_bot=True,
    )