# Mirror of https://github.com/cdryzun/tg_bot_collections.git
# (synced 2025-04-29 00:27:09 +08:00; 234 lines, 7.5 KiB, Python)
from os import environ
|
|
from pathlib import Path
|
|
import re
|
|
|
|
import google.generativeai as genai
|
|
from google.generativeai.types.generation_types import StopCandidateException
|
|
from telebot import TeleBot
|
|
from telebot.types import Message
|
|
|
|
# API key is read from the environment; environ.get yields None when unset.
GOOGLE_GEMINI_KEY = environ.get("GOOGLE_GEMINI_KEY")

genai.configure(api_key=GOOGLE_GEMINI_KEY)

# Sampling parameters used by every text chat session.
generation_config = dict(
    temperature=0.7,
    top_p=1,
    top_k=1,
    max_output_tokens=2048,
)

# Same threshold ("BLOCK_MEDIUM_AND_ABOVE") applied to each harm category.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Global history cache: str(Telegram user id) -> chat session returned by
# make_new_gemini_convo(). Held in memory only.
gemini_player_dict = {}
|
|
|
|
|
|
#### Utils for gemini ####
# The Markdown -> Telegram MarkdownV2 helpers below are copied from
# https://github.com/yym68686/md2tgmd/blob/main/src/md2tgmd.py -- many thanks.
def find_all_index(str, pattern):
    """Return the boundary offsets of every match whose group 1 participated.

    The result has the shape ``[0, s1, e1, s2, e2, ..., len(str)]``, where
    each ``(sN, eN)`` pair is the start/end offset of capture group 1 of one
    match, in order of appearance. Matches where group 1 did not take part
    (another alternative of *pattern* matched) are skipped. The search runs
    with ``re.MULTILINE`` so ``^``/``$`` anchor at line boundaries.

    The parameter name shadows the built-in ``str``; it is kept unchanged so
    existing callers are unaffected.
    """
    group_spans = [
        found.span(1)
        for found in re.finditer(pattern, str, re.MULTILINE)
        if found.group(1) is not None
    ]
    offsets = [0]
    for span_start, span_end in group_spans:
        offsets.extend((span_start, span_end))
    offsets.append(len(str))
    return offsets
|
|
|
|
|
|
def replace_all(text, pattern, function):
    """Return *text* with each group-1 capture of *pattern* run through *function*.

    Only text captured by group 1 is rewritten. Text matched by other
    alternatives of *pattern* (for example code fences) is copied through
    unchanged, which lets callers shield regions from the substitution.
    """
    bounds = find_all_index(text, pattern)
    # bounds == [0, s1, e1, ..., sk, ek, len(text)]: pairs starting at even
    # indexes are untouched gaps, pairs starting at odd indexes are captures.
    pieces = []
    for idx in range(0, len(bounds) - 1, 2):
        pieces.append(text[bounds[idx] : bounds[idx + 1]])
        if idx + 2 < len(bounds):
            pieces.append(function(text[bounds[idx + 1] : bounds[idx + 2]]))
    return "".join(pieces)
|
|
|
|
|
|
def escapeshape(text):
    """Render a Markdown heading line as a bold MarkdownV2 line.

    ``"## Some Title"`` becomes ``"▎*Some Title*"``. Every word after the
    leading ``#`` run is kept; the previous version kept only the first word.
    Runs of whitespace between words collapse to a single space.

    If nothing follows the hashes (e.g. ``"#  "``), *text* is returned
    unchanged instead of raising IndexError.
    """
    title_words = text.split()[1:]
    if not title_words:
        return text
    return "▎*" + " ".join(title_words) + "*"
|
|
|
|
|
|
def escapeminus(text):
    """Prefix *text* with a backslash so MarkdownV2 shows it literally."""
    return f"\\{text}"
|
|
|
|
|
|
def escapebackquote(text):
    """Return two backslash-escaped backticks, whatever *text* is.

    Used as the replacement for a standalone double backtick, so MarkdownV2
    renders both backticks literally.
    """
    escaped_tick = "\\`"
    return escaped_tick * 2
|
|
|
|
|
|
def escapeplus(text):
    """Prefix *text* with a backslash so MarkdownV2 shows it literally."""
    return "".join(("\\", text))
|
|
|
|
|
|
def escape(text, flag=0):
    """Convert Markdown text into Telegram MarkdownV2 text.

    The conversion is an ordered chain of regex rewrites, and the order
    matters:

    * Already-escaped brackets/parentheses are masked with placeholder tokens.
    * Markdown the bot keeps is converted or shielded: ``**bold**``, links,
      headings, bullet and numbered lists, inline code and code fences.
    * Every remaining MarkdownV2 special character is backslash-escaped.

    Args:
        text: Markdown text to convert.
        flag: when truthy, an existing double backslash is kept as exactly
            two backslashes instead of being doubled with every other one.

    Returns:
        The converted string.
    """
    # In all other places characters
    # _ * [ ] ( ) ~ ` > # + - = | { } . !
    # must be escaped with the preceding character '\'.
    #
    # Replacement strings below are raw literals (r"\_" etc.). They hold the
    # same characters as the old "\_" spellings, without the invalid-escape
    # SyntaxWarning raised on Python 3.12+.

    # Mask already-escaped brackets/parentheses so later steps skip them;
    # they are restored as escaped characters further down.
    text = re.sub(r"\\\[", "@->@", text)
    text = re.sub(r"\\\]", "@<-@", text)
    text = re.sub(r"\\\(", "@-->@", text)
    text = re.sub(r"\\\)", "@<--@", text)
    if flag:
        text = re.sub(r"\\\\", "@@@", text)
    # Double every backslash.
    text = re.sub(r"\\", r"\\\\", text)
    if flag:
        text = re.sub(r"\@{3}", r"\\\\", text)
    text = re.sub(r"_", r"\_", text)
    # **bold** -> placeholder; "* " bullets -> "•"; escape stray "*";
    # then turn the placeholder into MarkdownV2 *bold*.
    text = re.sub(r"\*{2}(.*?)\*{2}", "@@@\\1@@@", text)
    text = re.sub(r"\n{1,2}\*\s", "\n\n• ", text)
    text = re.sub(r"\*", r"\*", text)
    text = re.sub(r"\@{3}(.*?)\@{3}", "*\\1*", text)
    # Shield links/images, escape all remaining brackets, restore the masked
    # pre-escaped ones, then rebuild the links.
    text = re.sub(r"\!?\[(.*?)\]\((.*?)\)", "@@@\\1@@@^^^\\2^^^", text)
    text = re.sub(r"\[", r"\[", text)
    text = re.sub(r"\]", r"\]", text)
    text = re.sub(r"\(", r"\(", text)
    text = re.sub(r"\)", r"\)", text)
    text = re.sub(r"\@\-\>\@", r"\[", text)
    text = re.sub(r"\@\<\-\@", r"\]", text)
    text = re.sub(r"\@\-\-\>\@", r"\(", text)
    text = re.sub(r"\@\<\-\-\@", r"\)", text)
    text = re.sub(r"\@{3}(.*?)\@{3}\^{3}(.*?)\^{3}", "[\\1](\\2)", text)
    text = re.sub(r"~", r"\~", text)
    text = re.sub(r">", r"\>", text)
    # Headings become bold lines; code fences are matched only to shield them.
    text = replace_all(text, r"(^#+\s.+?$)|```[\D\d\s]+?```", escapeshape)
    text = re.sub(r"#", r"\#", text)
    # Escape "+" outside list markers and code spans.
    text = replace_all(
        text, r"(\+)|\n[\s]*-\s|```[\D\d\s]+?```|`[\D\d\s]*?`", escapeplus
    )
    # "- " list items -> "•"; numbered items get a blank line before them.
    text = re.sub(r"\n{1,2}(\s*)-\s", "\n\n\\1• ", text)
    text = re.sub(r"\n{1,2}(\s*\d{1,2}\.\s)", "\n\n\\1", text)
    # Escape "-" outside list markers and code spans.
    text = replace_all(
        text, r"(-)|\n[\s]*-\s|```[\D\d\s]+?```|`[\D\d\s]*?`", escapeminus
    )
    # Temporarily mask code fences so a leftover "``" pair can be escaped,
    # then restore the fences.
    text = re.sub(r"```([\D\d\s]+?)```", "@@@\\1@@@", text)
    text = replace_all(text, r"(``)", escapebackquote)
    text = re.sub(r"\@{3}([\D\d\s]+?)\@{3}", "```\\1```", text)
    text = re.sub(r"=", r"\=", text)
    text = re.sub(r"\|", r"\|", text)
    text = re.sub(r"{", r"\{", text)
    text = re.sub(r"}", r"\}", text)
    text = re.sub(r"\.", r"\.", text)
    text = re.sub(r"!", r"\!", text)
    return text
|
|
|
|
|
|
def make_new_gemini_convo():
    """Start and return a fresh ``gemini-pro`` chat session.

    The session uses the module-level ``generation_config`` and
    ``safety_settings``.
    """
    return genai.GenerativeModel(
        model_name="gemini-pro",
        safety_settings=safety_settings,
        generation_config=generation_config,
    ).start_chat()
|
|
|
|
|
|
def gemini_handler(message: Message, bot: TeleBot) -> None:
    """Gemini : /gemini <question>"""
    # Answer a text prompt using the sender's cached Gemini chat session.
    #
    # Steps:
    # 1. Post a "please wait" placeholder.
    # 2. Send the message text to the session.
    # 3. Reply with the answer: MarkdownV2 first, plain text as fallback.
    #
    # The placeholder is deleted in the outer `finally`, so it is also removed
    # on the early-return error path and when send_message raises an
    # unexpected exception. Previously it was left in the chat in both cases.
    reply_message = bot.reply_to(
        message,
        "Generating google gemini answer please wait, note, will only keep the last five messages:",
    )
    try:
        m = message.text.strip()
        user_key = str(message.from_user.id)
        # Sessions are cached in memory only; a restart loses all of them. TODO
        player = gemini_player_dict.get(user_key)
        if player is None:
            player = make_new_gemini_convo()
            gemini_player_dict[user_key] = player
        # Keep roughly the last five exchanges. Each exchange adds a question
        # and an answer entry, so drop the oldest pair once more than 10
        # entries are stored.
        if len(player.history) > 10:
            player.history = player.history[2:]

        try:
            player.send_message(m)
            gemini_reply_text = player.last.text.strip()
        except StopCandidateException as e:
            # Generation stopped early: try to recover the partial answer text
            # from the exception's string representation.
            match = re.search(r'content\s*{\s*parts\s*{\s*text:\s*"([^"]+)"', str(e))
            if not match:
                print("No meaningful text was extracted from the exception.")
                bot.reply_to(
                    message,
                    "Google gemini encountered an error while generating an answer. Please check the log.",
                )
                return
            # The extracted text contains literal backslash-n sequences;
            # turn them into real newlines.
            gemini_reply_text = re.sub(r"\\n", "\n", match.group(1))

        try:
            bot.reply_to(
                message,
                "Gemini answer:\n" + escape(gemini_reply_text),
                parse_mode="MarkdownV2",
            )
        except Exception:
            # escape() failed or Telegram rejected the MarkdownV2: send plain text.
            print("wrong markdown format")
            bot.reply_to(
                message,
                "Gemini answer:\n\n" + gemini_reply_text,
            )
    finally:
        bot.delete_message(reply_message.chat.id, reply_message.message_id)
|
|
|
|
|
|
def gemini_photo_handler(message: Message, bot: TeleBot) -> None:
    # Answer a photo whose caption is a Gemini prompt: send the largest photo
    # size plus the caption to the gemini-pro-vision model and reply with the
    # model's text.
    #
    # - The downloaded bytes go to the model directly. They no longer pass
    #   through the fixed file "gemini_temp.jpg", which concurrent handler
    #   calls could overwrite.
    # - The placeholder is deleted in `finally`, even when download or
    #   generation fails.
    s = message.caption
    reply_message = bot.reply_to(
        message,
        "Generating google gemini vision answer please wait.",
    )
    try:
        prompt = s.strip()
        # Get the highest-quality picture. PhotoSize.file_size is optional, so
        # a missing size counts as 0 instead of breaking max() with TypeError.
        max_size_photo = max(message.photo, key=lambda p: p.file_size or 0)
        file_path = bot.get_file(max_size_photo.file_id).file_path
        image_data = bot.download_file(file_path)

        model = genai.GenerativeModel("gemini-pro-vision")
        contents = {
            "parts": [{"mime_type": "image/jpeg", "data": image_data}, {"text": prompt}]
        }
        response = model.generate_content(contents=contents)
        bot.reply_to(message, "Gemini vision answer:\n" + response.text)
    finally:
        bot.delete_message(reply_message.chat.id, reply_message.message_id)
|
|
|
|
|
|
def register(bot: TeleBot) -> None:
    """Attach the Gemini text and photo handlers to *bot*."""

    def wants_gemini_photo(m):
        # Truthy only for photos captioned "gemini:..." or "/gemini...".
        return m.caption and m.caption.startswith(("gemini:", "/gemini"))

    # Text prompts arrive as the /gemini command or as "gemini: ..." messages.
    for trigger in ({"commands": ["gemini"]}, {"regexp": "^gemini:"}):
        bot.register_message_handler(gemini_handler, pass_bot=True, **trigger)

    bot.register_message_handler(
        gemini_photo_handler,
        content_types=["photo"],
        func=wants_gemini_photo,
        pass_bot=True,
    )
|