feat: Modularize handlers

Signed-off-by: Frost Ming <me@frostming.com>
This commit is contained in:
Frost Ming 2023-12-15 14:00:18 +08:00
parent dbacdd60f0
commit 7a46a22262
No known key found for this signature in database
GPG Key ID: 5BFA9CB4DDA943BF
5 changed files with 366 additions and 323 deletions

79
handlers/__init__.py Normal file
View File

@ -0,0 +1,79 @@
from __future__ import annotations
import importlib
import re
import traceback
from functools import update_wrapper
from pathlib import Path
from typing import Any, Callable, TypeVar
from telebot import TeleBot
from telebot.types import BotCommand, Message
T = TypeVar("T", bound=Callable)
def extract_prompt(message: str, bot_name: str) -> str:
"""
This function filters messages for prompts.
Returns:
str: If it is not a prompt, return None. Otherwise, return the trimmed prefix of the actual prompt.
"""
# remove '@bot_name' as it is considered part of the command when in a group chat.
message = re.sub(re.escape(f"@{bot_name}"), "", message).strip()
# add a whitespace after the first colon to make sure it is separated from the prompt.
message = re.sub(":", ": ", message, count=1).strip()
try:
left, message = message.split(maxsplit=1)
except ValueError:
return ""
if ":" not in left:
# restore the added space
message = message.replace(": ", ":", 1)
return message.strip()
def wrap_handler(handler: T, bot: TeleBot) -> T:
def wrapper(message: Message, *args: Any, **kwargs: Any) -> None:
try:
m = ""
if message.text is not None:
m = message.text = extract_prompt(message.text, bot.get_me().username)
elif message.caption is not None:
m = message.caption = extract_prompt(
message.caption, bot.get_me().username
)
if not m:
bot.reply_to(message, "Please provide info after start words.")
return
return handler(m, *args, **kwargs)
except Exception:
traceback.print_exc()
bot.reply_to(message, "Something wrong, please check the log")
return update_wrapper(wrapper, handler)
def load_handlers(bot: TeleBot) -> None:
# import all submodules
this_path = Path(__file__).parent
for child in this_path.iterdir():
if child.name.startswith("_"):
continue
module = importlib.import_module(f".{child.stem}", __package__)
if hasattr(module, "register"):
print(f"Loading {child.stem} handlers.")
module.register(bot)
print("Loading handlers done.")
all_commands: list[BotCommand] = []
for handler in bot.message_handlers:
help_text = getattr(handler["function"], "__doc__", "")
handler["function"] = wrap_handler(handler["function"], bot)
for command in handler["filters"].get("commands", []):
all_commands.append(BotCommand(command, help_text))
if all_commands:
bot.set_my_commands(all_commands)
print("Setting commands done.")

104
handlers/gemini.py Normal file
View File

@ -0,0 +1,104 @@
from os import environ
from pathlib import Path
import google.generativeai as genai
from telebot import TeleBot
from telebot.types import Message
GOOGLE_GEMINI_KEY = environ.get("GOOGLE_GEMINI_KEY")
genai.configure(api_key=GOOGLE_GEMINI_KEY)
generation_config = {
"temperature": 0.9,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
}
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
]
# Global history cache
gemini_player_dict = {}
def make_new_gemini_convo():
model = genai.GenerativeModel(
model_name="gemini-pro",
generation_config=generation_config,
safety_settings=safety_settings,
)
convo = model.start_chat()
return convo
def gemini_handler(message: Message, bot: TeleBot) -> None:
"""Gemini : /gemini <question>"""
reply_message = bot.reply_to(
message,
"Generating google gemini answer please wait, note, will only keep the last five messages:",
)
m = message.text.strip()
player = None
# restart will lose all TODO
if str(message.from_user.id) not in gemini_player_dict:
player = make_new_gemini_convo()
gemini_player_dict[str(message.from_user.id)] = player
else:
player = gemini_player_dict[str(message.from_user.id)]
if len(player.history) > 10:
player.history = player.history[2:]
player.send_message(m)
try:
bot.reply_to(
message,
"Gemini answer:\n" + player.last.text,
parse_mode="MarkdownV2",
)
finally:
bot.delete_message(reply_message.chat.id, reply_message.message_id)
def gemini_photo_handler(message: Message, bot: TeleBot) -> None:
s = message.caption
bot.reply_to(
message,
"Generating google gemini vision answer please wait,",
)
prompt = s.strip()
max_size_photo = max(message.photo, key=lambda p: p.file_size)
file_path = bot.get_file(max_size_photo.file_id).file_path
downloaded_file = bot.download_file(file_path)
with open("gemini_temp.jpg", "wb") as temp_file:
temp_file.write(downloaded_file)
model = genai.GenerativeModel("gemini-pro-vision")
image_path = Path("gemini_temp.jpg")
image_data = image_path.read_bytes()
contents = {
"parts": [{"mime_type": "image/jpeg", "data": image_data}, {"text": prompt}]
}
response = model.generate_content(contents=contents)
print(response.text)
bot.reply_to(message, "Gemini vision answer:\n" + response.text)
def register(bot: TeleBot) -> None:
bot.register_message_handler(gemini_handler, commands=["gemini"], pass_bot=True)
bot.register_message_handler(gemini_handler, regexp="^gemini:", pass_bot=True)
bot.register_message_handler(
gemini_photo_handler,
content_types=["photo"],
func=lambda m: m.caption and m.caption.startswith("gemini:"),
pass_bot=True,
)

39
handlers/github.py Normal file
View File

@ -0,0 +1,39 @@
import subprocess
from telebot import TeleBot
from telebot.types import Message
def github_poster_handler(message: Message, bot: TeleBot):
"""github poster: /github <github_user_name> [<start>-<end>]"""
reply_message = bot.reply_to(message, "Generating poster please wait:")
m = message.text.strip()
message_list = m.split(",")
name = message_list[0].strip()
cmd_list = ["github_poster", "github", "--github_user_name", name, "--me", name]
if len(message_list) > 1:
years = message_list[1]
cmd_list.append("--year")
cmd_list.append(years.strip())
r = subprocess.check_output(cmd_list).decode("utf-8")
try:
if "done" in r:
# TODO windows path
r = subprocess.check_output(
["cairosvg", "OUT_FOLDER/github.svg", "-o", f"github_{name}.png"]
).decode("utf-8")
with open(f"github_{name}.png", "rb") as photo:
bot.send_photo(
message.chat.id, photo, reply_to_message_id=message.message_id
)
finally:
bot.delete_message(reply_message.chat.id, reply_message.message_id)
def register(bot: TeleBot) -> None:
bot.register_message_handler(
github_poster_handler, commands=["github"], pass_bot=True
)
bot.register_message_handler(
github_poster_handler, regexp="^github:", pass_bot=True
)

141
handlers/map.py Normal file
View File

@ -0,0 +1,141 @@
import gc
import shutil
from random import random
from tempfile import SpooledTemporaryFile
import numpy as np
import PIL
from matplotlib import figure
from PIL import Image
from prettymapp.geo import get_aoi
from prettymapp.osm import get_osm_geometries
from prettymapp.plotting import Plot as PrettyPlot
from prettymapp.settings import STYLES
from telebot import TeleBot
from telebot.types import Message
MAX_IN_MEMORY = 10 * 1024 * 1024 # 10MiB
PIL.Image.MAX_IMAGE_PIXELS = 933120000
class Plot(PrettyPlot):
# memory leak fix for Plot. thanks @higuoxing https://github.com/higuoxing
# refer to: https://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg11809.html
def __post_init__(self):
(
self.xmin,
self.ymin,
self.xmax,
self.ymax,
) = self.aoi_bounds
# take from aoi geometry bounds, otherwise probelematic if unequal geometry distribution over plot.
self.xmid = (self.xmin + self.xmax) / 2
self.ymid = (self.ymin + self.ymax) / 2
self.xdif = self.xmax - self.xmin
self.ydif = self.ymax - self.ymin
self.bg_buffer_x = (self.bg_buffer / 100) * self.xdif
self.bg_buffer_y = (self.bg_buffer / 100) * self.ydif
# self.fig, self.ax = subplots(
# 1, 1, figsize=(12, 12), constrained_layout=True, dpi=1200
# )
self.fig = figure.Figure(figsize=(12, 12), constrained_layout=True, dpi=1200)
self.ax = self.fig.subplots(1, 1)
self.ax.set_aspect(1 / np.cos(self.ymid * np.pi / 180))
self.ax.axis("off")
self.ax.set_xlim(self.xmin - self.bg_buffer_x, self.xmax + self.bg_buffer_x)
self.ax.set_ylim(self.ymin - self.bg_buffer_y, self.ymax + self.bg_buffer_y)
def sizeof_image(image):
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as f:
image.save(f, format="JPEG", quality=95)
return f.tell()
def compress_image(input_image, output_image, target_size):
quality = 95
factor = 1.0
with Image.open(input_image) as img:
while sizeof_image(img) > target_size:
factor -= 0.05
width, height = img.size
img = img.resize(
(int(width * factor), int(height * factor)),
PIL.Image.Resampling.LANCZOS,
)
img.save(output_image, format="JPEG", quality=quality)
output_image.seek(0)
def draw_pretty_map(location, style, output_file):
aoi = get_aoi(address=location, radius=1100, rectangular=True)
df = get_osm_geometries(aoi=aoi)
fig = Plot(df=df, aoi_bounds=aoi.bounds, draw_settings=STYLES[style]).plot_all()
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as buffer:
fig.savefig(buffer, format="jpeg")
buffer.seek(0)
compress_image(
buffer,
output_file,
10 * 1024 * 1024, # telegram tog need png less than 10MB
)
def map_handler(message: Message, bot: TeleBot):
"""pretty map: /map <address>"""
reply_message = bot.reply_to(
message, "Generating pretty map may take some time please wait:"
)
m = message.text.strip()
location = m.strip()
styles_list = list(STYLES.keys())
style = random.choice(styles_list)
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image:
try:
draw_pretty_map(location, style, out_image)
# tg can only send image less than 10MB
with open("map_out.jpg", "wb") as f: # for debug
shutil.copyfileobj(out_image, f)
out_image.seek(0)
bot.send_photo(
message.chat.id, out_image, reply_to_message_id=message.message_id
)
finally:
bot.delete_message(reply_message.chat.id, reply_message.message_id)
gc.collect()
def map_location_handler(message: Message, bot: TeleBot):
# TODO refactor the function
reply_message = bot.reply_to(
message,
"Generating pretty map using location now, may take some time please wait:",
)
location = "{0}, {1}".format(message.location.latitude, message.location.longitude)
styles_list = list(STYLES.keys())
style = random.choice(styles_list)
try:
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image:
draw_pretty_map(location, style, out_image)
# tg can only send image less than 10MB
with open("map_out.jpg", "wb") as f: # for debug
shutil.copyfileobj(out_image, f)
out_image.seek(0)
bot.send_photo(
message.chat.id, out_image, reply_to_message_id=message.message_id
)
finally:
bot.delete_message(reply_message.chat.id, reply_message.message_id)
gc.collect()
def register(bot: TeleBot) -> None:
bot.register_message_handler(map_handler, commands=["map"], pass_bot=True)
bot.register_message_handler(map_handler, regexp="^map:", pass_bot=True)
bot.register_message_handler(
map_location_handler, content_types=["location", "venue"], pass_bot=True
)

326
tg.py
View File

@ -1,159 +1,8 @@
import argparse
import gc
import random
import shutil
import subprocess
import traceback
from tempfile import SpooledTemporaryFile
from os import environ
from pathlib import Path
from typing import Optional
from telebot import TeleBot
import numpy as np
import PIL
from matplotlib import figure
from PIL import Image
from prettymapp.geo import get_aoi
from prettymapp.osm import get_osm_geometries
from prettymapp.plotting import Plot as PrettyPlot
from prettymapp.settings import STYLES
from telebot import TeleBot # type: ignore
from telebot.types import BotCommand, Message # type: ignore
import google.generativeai as genai
PIL.Image.MAX_IMAGE_PIXELS = 933120000
MAX_IN_MEMORY = 10 * 1024 * 1024 # 10MiB
GOOGLE_GEMINI_KEY = environ.get("GOOGLE_GEMINI_KEY")
genai.configure(api_key=GOOGLE_GEMINI_KEY)
generation_config = {
"temperature": 0.9,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
}
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
]
BOT_STARTS_WORDS_LIST = ["map:", "/map", "github:", "/github", "gemini", "/gemini"]
#### Utils ####
def extract_prompt(message: Message, bot_name: str) -> Optional[str]:
"""
This function filters messages for prompts.
Returns:
str: If it is not a prompt, return None. Otherwise, return the trimmed prefix of the actual prompt.
"""
msg_text: str = message.text.strip()
if msg_text.startswith("@"):
if not msg_text.startswith(f"@{bot_name} "):
return None
s = msg_text[len(bot_name) + 2 :]
else:
prefix = next(
(w for w in BOT_STARTS_WORDS_LIST if msg_text.startswith(w)), None
)
if not prefix:
return None
s = msg_text[len(prefix) :]
# If the first word is '@bot_name', remove it as it is considered part of the command when in a group chat.
if s.startswith("@"):
if not s.startswith(f"@{bot_name} "):
return None
s = " ".join(s.split(" ")[1:])
return s
def make_new_gemini_convo():
model = genai.GenerativeModel(
model_name="gemini-pro",
generation_config=generation_config,
safety_settings=safety_settings,
)
convo = model.start_chat()
return convo
class Plot(PrettyPlot):
# memory leak fix for Plot. thanks @higuoxing https://github.com/higuoxing
# refer to: https://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg11809.html
def __post_init__(self):
(
self.xmin,
self.ymin,
self.xmax,
self.ymax,
) = self.aoi_bounds
# take from aoi geometry bounds, otherwise probelematic if unequal geometry distribution over plot.
self.xmid = (self.xmin + self.xmax) / 2
self.ymid = (self.ymin + self.ymax) / 2
self.xdif = self.xmax - self.xmin
self.ydif = self.ymax - self.ymin
self.bg_buffer_x = (self.bg_buffer / 100) * self.xdif
self.bg_buffer_y = (self.bg_buffer / 100) * self.ydif
# self.fig, self.ax = subplots(
# 1, 1, figsize=(12, 12), constrained_layout=True, dpi=1200
# )
self.fig = figure.Figure(figsize=(12, 12), constrained_layout=True, dpi=1200)
self.ax = self.fig.subplots(1, 1)
self.ax.set_aspect(1 / np.cos(self.ymid * np.pi / 180))
self.ax.axis("off")
self.ax.set_xlim(self.xmin - self.bg_buffer_x, self.xmax + self.bg_buffer_x)
self.ax.set_ylim(self.ymin - self.bg_buffer_y, self.ymax + self.bg_buffer_y)
def sizeof_image(image):
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as f:
image.save(f, format="JPEG", quality=95)
return f.tell()
def compress_image(input_image, output_image, target_size):
quality = 95
factor = 1.0
with Image.open(input_image) as img:
while sizeof_image(img) > target_size:
factor -= 0.05
width, height = img.size
img = img.resize(
(int(width * factor), int(height * factor)),
PIL.Image.Resampling.LANCZOS,
)
img.save(output_image, format="JPEG", quality=quality)
output_image.seek(0)
def draw_pretty_map(location, style, output_file):
aoi = get_aoi(address=location, radius=1100, rectangular=True)
df = get_osm_geometries(aoi=aoi)
fig = Plot(df=df, aoi_bounds=aoi.bounds, draw_settings=STYLES[style]).plot_all()
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as buffer:
fig.savefig(buffer, format="jpeg")
buffer.seek(0)
compress_image(
buffer,
output_file,
10 * 1024 * 1024, # telegram tog need png less than 10MB
)
from handlers import load_handlers
def main():
@ -162,181 +11,12 @@ def main():
parser.add_argument("tg_token", help="tg token")
options = parser.parse_args()
print("Arg parse done.")
gemini_player_dict = {}
# Init bot
bot = TeleBot(options.tg_token)
bot.set_my_commands(
[
BotCommand(
"github", "github poster: /github <github_user_name> [<start>-<end>]"
),
BotCommand("map", "pretty map: /map <address>"),
BotCommand("gemini", "Gemini : /gemini <question>"),
]
)
bot_name = bot.get_me().username
load_handlers(bot)
print("Bot init done.")
@bot.message_handler(commands=["github"])
@bot.message_handler(regexp="^github:")
def github_poster_handler(message: Message):
reply_message = bot.reply_to(message, "Generating poster please wait:")
m = extract_prompt(message, bot_name)
if not m:
bot.reply_to(message, "Please provide info after start words.")
return
message_list = m.split(",")
name = message_list[0].strip()
cmd_list = ["github_poster", "github", "--github_user_name", name, "--me", name]
if len(message_list) > 1:
years = message_list[1]
cmd_list.append("--year")
cmd_list.append(years.strip())
r = subprocess.check_output(cmd_list).decode("utf-8")
if "done" in r:
try:
# TODO windows path
r = subprocess.check_output(
["cairosvg", "OUT_FOLDER/github.svg", "-o", f"github_{name}.png"]
).decode("utf-8")
with open(f"github_{name}.png", "rb") as photo:
bot.send_photo(
message.chat.id, photo, reply_to_message_id=message.message_id
)
except Exception as e:
print(e)
bot.reply_to(message, "Something wrong please check")
bot.delete_message(reply_message.chat.id, reply_message.message_id)
@bot.message_handler(commands=["map"])
@bot.message_handler(regexp="^map:")
def map_handler(message: Message):
reply_message = bot.reply_to(
message, "Generating pretty map may take some time please wait:"
)
m = extract_prompt(message, bot_name)
if not m:
bot.reply_to(message, "Please provide info after start words.")
return
location = m.strip()
styles_list = list(STYLES.keys())
style = random.choice(styles_list)
try:
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image:
draw_pretty_map(location, style, out_image)
# tg can only send image less than 10MB
with open("map_out.jpg", "wb") as f: # for debug
shutil.copyfileobj(out_image, f)
out_image.seek(0)
bot.send_photo(
message.chat.id, out_image, reply_to_message_id=message.message_id
)
except Exception:
traceback.print_exc()
bot.reply_to(message, "Something wrong please check")
bot.delete_message(reply_message.chat.id, reply_message.message_id)
gc.collect()
@bot.message_handler(content_types=["location", "venue"])
def map_location_handler(message: Message):
# TODO refactor the function
reply_message = bot.reply_to(
message,
"Generating pretty map using location now, may take some time please wait:",
)
location = "{0}, {1}".format(
message.location.latitude, message.location.longitude
)
styles_list = list(STYLES.keys())
style = random.choice(styles_list)
try:
with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image:
draw_pretty_map(location, style, out_image)
# tg can only send image less than 10MB
with open("map_out.jpg", "wb") as f: # for debug
shutil.copyfileobj(out_image, f)
out_image.seek(0)
bot.send_photo(
message.chat.id, out_image, reply_to_message_id=message.message_id
)
except Exception:
traceback.print_exc()
bot.reply_to(message, "Something wrong please check")
bot.delete_message(reply_message.chat.id, reply_message.message_id)
gc.collect()
@bot.message_handler(commands=["gemini"])
@bot.message_handler(regexp="^gemini:")
def gemini_handler(message: Message):
reply_message = bot.reply_to(
message,
"Generating google gemini answer please wait, note, will only keep the last five messages:",
)
m = extract_prompt(message, bot_name)
if not m:
bot.reply_to(message, "Please provide info after start words.")
return
player = None
# restart will lose all TODO
if str(message.from_user.id) not in gemini_player_dict:
player = make_new_gemini_convo()
gemini_player_dict[str(message.from_user.id)] = player
else:
player = gemini_player_dict[str(message.from_user.id)]
if len(player.history) > 10:
player.history = player.history[2:]
try:
player.send_message(m)
try:
bot.reply_to(
message,
"Gemini answer:\n" + player.last.text,
parse_mode="MarkdownV2",
)
except:
bot.reply_to(message, "Gemini answer:\n" + player.last.text)
except Exception as e:
traceback.print_exc()
bot.reply_to(message, "Something wrong please check the log")
bot.delete_message(reply_message.chat.id, reply_message.message_id)
@bot.message_handler(content_types=["photo"])
def gemini_photo_handler(message: Message) -> None:
s = message.caption
if not s or not s.startswith("gemini:"):
return
reply_message = bot.reply_to(
message,
"Generating google gemini vision answer please wait,",
)
try:
prompt = s.strip().split(maxsplit=1)[1].strip()
max_size_photo = max(message.photo, key=lambda p: p.file_size)
file_path = bot.get_file(max_size_photo.file_id).file_path
downloaded_file = bot.download_file(file_path)
with open("gemini_temp.jpg", "wb") as temp_file:
temp_file.write(downloaded_file)
except Exception as e:
traceback.print_exc()
bot.reply_to(message, "Something is wrong reading your photo or prompt")
model = genai.GenerativeModel("gemini-pro-vision")
image_path = Path("gemini_temp.jpg")
image_data = image_path.read_bytes()
contents = {
"parts": [{"mime_type": "image/jpeg", "data": image_data}, {"text": prompt}]
}
try:
response = model.generate_content(contents=contents)
bot.reply_to(message, "Gemini vision answer:\n" + response.text)
except Exception as e:
traceback.print_exc()
bot.reply_to(message, "Something wrong please check the log")
# Start bot
print("Starting tg collections bot.")
bot.infinity_polling()