diff --git a/handlers/__init__.py b/handlers/__init__.py new file mode 100644 index 0000000..8aaab27 --- /dev/null +++ b/handlers/__init__.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import importlib +import re +import traceback +from functools import update_wrapper +from pathlib import Path +from typing import Any, Callable, TypeVar + +from telebot import TeleBot +from telebot.types import BotCommand, Message + +T = TypeVar("T", bound=Callable) + + +def extract_prompt(message: str, bot_name: str) -> str: + """ + This function filters messages for prompts. + + Returns: + str: If it is not a prompt, return None. Otherwise, return the trimmed prefix of the actual prompt. + """ + # remove '@bot_name' as it is considered part of the command when in a group chat. + message = re.sub(re.escape(f"@{bot_name}"), "", message).strip() + # add a whitespace after the first colon as we separate the prompt from the command by the first whitespace. + message = re.sub(":", ": ", message, count=1).strip() + try: + left, message = message.split(maxsplit=1) + except ValueError: + return "" + if ":" not in left: + # the replacement happens in the right part, restore it. + message = message.replace(": ", ":", 1) + return message.strip() + + +def wrap_handler(handler: T, bot: TeleBot) -> T: + def wrapper(message: Message, *args: Any, **kwargs: Any) -> None: + try: + m = "" + if message.text is not None: + m = message.text = extract_prompt(message.text, bot.get_me().username) + elif message.caption is not None: + m = message.caption = extract_prompt( + message.caption, bot.get_me().username + ) + if not m: + bot.reply_to(message, "Please provide info after start words.") + return + return handler(message, *args, **kwargs) + except Exception: + traceback.print_exc() + bot.reply_to(message, "Something wrong, please check the log") + + return update_wrapper(wrapper, handler) + + +def load_handlers(bot: TeleBot) -> None: + # import all submodules + this_path = Path(__file__).parent + for child in this_path.iterdir(): + if child.name.startswith("_"): + continue + module = importlib.import_module(f".{child.stem}", __package__) + if hasattr(module, "register"): + print(f"Loading {child.stem} handlers.") + module.register(bot) + print("Loading handlers done.") + + all_commands: list[BotCommand] = [] + for handler in bot.message_handlers: + help_text = getattr(handler["function"], "__doc__", "") + # Add pre-processing and error handling to all callbacks + handler["function"] = wrap_handler(handler["function"], bot) + for command in handler["filters"].get("commands", []): + all_commands.append(BotCommand(command, help_text)) + + if all_commands: + bot.set_my_commands(all_commands) + print("Setting commands done.") diff --git a/handlers/gemini.py b/handlers/gemini.py new file mode 100644 index 0000000..d36ccd4 --- /dev/null +++ b/handlers/gemini.py @@ -0,0 +1,109 @@ +from os import environ +from pathlib import Path + +import google.generativeai as genai +from telebot import TeleBot +from telebot.types import Message + +GOOGLE_GEMINI_KEY = environ.get("GOOGLE_GEMINI_KEY") + +genai.configure(api_key=GOOGLE_GEMINI_KEY) +generation_config = { + "temperature": 0.9, + "top_p": 1, + "top_k": 1, + "max_output_tokens": 2048, +} + +safety_settings = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, +] + +# Global history cache +gemini_player_dict = {} + + +def make_new_gemini_convo(): + model = genai.GenerativeModel( + model_name="gemini-pro", + generation_config=generation_config, + safety_settings=safety_settings, + ) + convo = model.start_chat() + return convo + + +def gemini_handler(message: Message, bot: TeleBot) -> None: + """Gemini : /gemini """ + reply_message = bot.reply_to( + message, + "Generating google gemini answer please wait, note, will only keep the last five messages:", + ) + m = message.text.strip() + player = None + # restart will lose all TODO + if str(message.from_user.id) not in gemini_player_dict: + player = make_new_gemini_convo() + gemini_player_dict[str(message.from_user.id)] = player + else: + player = gemini_player_dict[str(message.from_user.id)] + if len(player.history) > 10: + player.history = player.history[2:] + player.send_message(m) + try: + bot.reply_to( + message, + "Gemini answer:\n" + player.last.text, + parse_mode="MarkdownV2", + ) + except: + bot.reply_to( + message, + "Gemini answer:\n" + player.last.text, + ) + finally: + bot.delete_message(reply_message.chat.id, reply_message.message_id) + + +def gemini_photo_handler(message: Message, bot: TeleBot) -> None: + s = message.caption + bot.reply_to( + message, + "Generating google gemini vision answer please wait,", + ) + prompt = s.strip() + max_size_photo = max(message.photo, key=lambda p: p.file_size) + file_path = bot.get_file(max_size_photo.file_id).file_path + downloaded_file = bot.download_file(file_path) + with open("gemini_temp.jpg", "wb") as temp_file: + temp_file.write(downloaded_file) + + model = genai.GenerativeModel("gemini-pro-vision") + image_path = Path("gemini_temp.jpg") + image_data = image_path.read_bytes() + contents = { + "parts": [{"mime_type": "image/jpeg", "data": image_data}, {"text": prompt}] + } + response = model.generate_content(contents=contents) + print(response.text) + bot.reply_to(message, "Gemini vision answer:\n" + response.text) + + +def register(bot: TeleBot) -> None: + bot.register_message_handler(gemini_handler, commands=["gemini"], pass_bot=True) + bot.register_message_handler(gemini_handler, regexp="^gemini:", pass_bot=True) + bot.register_message_handler( + gemini_photo_handler, + content_types=["photo"], + func=lambda m: m.caption and m.caption.startswith("gemini:"), + pass_bot=True, + ) diff --git a/handlers/github.py b/handlers/github.py new file mode 100644 index 0000000..8724726 --- /dev/null +++ b/handlers/github.py @@ -0,0 +1,39 @@ +import subprocess + +from telebot import TeleBot +from telebot.types import Message + + +def github_poster_handler(message: Message, bot: TeleBot): + """github poster: /github [-]""" + reply_message = bot.reply_to(message, "Generating poster please wait:") + m = message.text.strip() + message_list = m.split(",") + name = message_list[0].strip() + cmd_list = ["github_poster", "github", "--github_user_name", name, "--me", name] + if len(message_list) > 1: + years = message_list[1] + cmd_list.append("--year") + cmd_list.append(years.strip()) + r = subprocess.check_output(cmd_list).decode("utf-8") + try: + if "done" in r: + # TODO windows path + r = subprocess.check_output( + ["cairosvg", "OUT_FOLDER/github.svg", "-o", f"github_{name}.png"] + ).decode("utf-8") + with open(f"github_{name}.png", "rb") as photo: + bot.send_photo( + message.chat.id, photo, reply_to_message_id=message.message_id + ) + finally: + bot.delete_message(reply_message.chat.id, reply_message.message_id) + + +def register(bot: TeleBot) -> None: + bot.register_message_handler( + github_poster_handler, commands=["github"], pass_bot=True + ) + bot.register_message_handler( + github_poster_handler, regexp="^github:", pass_bot=True + ) diff --git a/handlers/map.py b/handlers/map.py new file mode 100644 index 0000000..247b788 --- /dev/null +++ b/handlers/map.py @@ -0,0 +1,141 @@ +import gc +import shutil +import random +from tempfile import SpooledTemporaryFile + +import numpy as np +import PIL +from matplotlib import figure +from PIL import Image +from prettymapp.geo import get_aoi +from prettymapp.osm import get_osm_geometries +from prettymapp.plotting import Plot as PrettyPlot +from prettymapp.settings import STYLES +from telebot import TeleBot +from telebot.types import Message + +MAX_IN_MEMORY = 10 * 1024 * 1024 # 10MiB +PIL.Image.MAX_IMAGE_PIXELS = 933120000 + + +class Plot(PrettyPlot): + # memory leak fix for Plot. thanks @higuoxing https://github.com/higuoxing + # refer to: https://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg11809.html + def __post_init__(self): + ( + self.xmin, + self.ymin, + self.xmax, + self.ymax, + ) = self.aoi_bounds + # take from aoi geometry bounds, otherwise probelematic if unequal geometry distribution over plot. + self.xmid = (self.xmin + self.xmax) / 2 + self.ymid = (self.ymin + self.ymax) / 2 + self.xdif = self.xmax - self.xmin + self.ydif = self.ymax - self.ymin + + self.bg_buffer_x = (self.bg_buffer / 100) * self.xdif + self.bg_buffer_y = (self.bg_buffer / 100) * self.ydif + + # self.fig, self.ax = subplots( + # 1, 1, figsize=(12, 12), constrained_layout=True, dpi=1200 + # ) + self.fig = figure.Figure(figsize=(12, 12), constrained_layout=True, dpi=1200) + self.ax = self.fig.subplots(1, 1) + self.ax.set_aspect(1 / np.cos(self.ymid * np.pi / 180)) + + self.ax.axis("off") + self.ax.set_xlim(self.xmin - self.bg_buffer_x, self.xmax + self.bg_buffer_x) + self.ax.set_ylim(self.ymin - self.bg_buffer_y, self.ymax + self.bg_buffer_y) + + +def sizeof_image(image): + with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as f: + image.save(f, format="JPEG", quality=95) + return f.tell() + + +def compress_image(input_image, output_image, target_size): + quality = 95 + factor = 1.0 + with Image.open(input_image) as img: + while sizeof_image(img) > target_size: + factor -= 0.05 + width, height = img.size + img = img.resize( + (int(width * factor), int(height * factor)), + PIL.Image.Resampling.LANCZOS, + ) + img.save(output_image, format="JPEG", quality=quality) + output_image.seek(0) + + +def draw_pretty_map(location, style, output_file): + aoi = get_aoi(address=location, radius=1100, rectangular=True) + df = get_osm_geometries(aoi=aoi) + fig = Plot(df=df, aoi_bounds=aoi.bounds, draw_settings=STYLES[style]).plot_all() + with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as buffer: + fig.savefig(buffer, format="jpeg") + buffer.seek(0) + compress_image( + buffer, + output_file, + 10 * 1024 * 1024, # telegram tog need png less than 10MB + ) + + +def map_handler(message: Message, bot: TeleBot): + """pretty map: /map
""" + reply_message = bot.reply_to( + message, "Generating pretty map may take some time please wait:" + ) + m = message.text.strip() + location = m.strip() + styles_list = list(STYLES.keys()) + style = random.choice(styles_list) + with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image: + try: + draw_pretty_map(location, style, out_image) + # tg can only send image less than 10MB + with open("map_out.jpg", "wb") as f: # for debug + shutil.copyfileobj(out_image, f) + out_image.seek(0) + bot.send_photo( + message.chat.id, out_image, reply_to_message_id=message.message_id + ) + finally: + bot.delete_message(reply_message.chat.id, reply_message.message_id) + gc.collect() + + +def map_location_handler(message: Message, bot: TeleBot): + # TODO refactor the function + reply_message = bot.reply_to( + message, + "Generating pretty map using location now, may take some time please wait:", + ) + location = "{0}, {1}".format(message.location.latitude, message.location.longitude) + styles_list = list(STYLES.keys()) + style = random.choice(styles_list) + try: + with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image: + draw_pretty_map(location, style, out_image) + # tg can only send image less than 10MB + with open("map_out.jpg", "wb") as f: # for debug + shutil.copyfileobj(out_image, f) + out_image.seek(0) + bot.send_photo( + message.chat.id, out_image, reply_to_message_id=message.message_id + ) + + finally: + bot.delete_message(reply_message.chat.id, reply_message.message_id) + gc.collect() + + +def register(bot: TeleBot) -> None: + bot.register_message_handler(map_handler, commands=["map"], pass_bot=True) + bot.register_message_handler(map_handler, regexp="^map:", pass_bot=True) + bot.register_message_handler( + map_location_handler, content_types=["location", "venue"], pass_bot=True + ) diff --git a/tg.py b/tg.py index ecf2789..9aed42f 100644 --- a/tg.py +++ b/tg.py @@ -1,159 +1,8 @@ import argparse -import gc -import random -import shutil -import subprocess -import traceback -from tempfile import SpooledTemporaryFile -from os import environ -from pathlib import Path -from typing import Optional +from telebot import TeleBot -import numpy as np -import PIL -from matplotlib import figure -from PIL import Image -from prettymapp.geo import get_aoi -from prettymapp.osm import get_osm_geometries -from prettymapp.plotting import Plot as PrettyPlot -from prettymapp.settings import STYLES -from telebot import TeleBot # type: ignore -from telebot.types import BotCommand, Message # type: ignore -import google.generativeai as genai - -PIL.Image.MAX_IMAGE_PIXELS = 933120000 -MAX_IN_MEMORY = 10 * 1024 * 1024 # 10MiB - -GOOGLE_GEMINI_KEY = environ.get("GOOGLE_GEMINI_KEY") - - -genai.configure(api_key=GOOGLE_GEMINI_KEY) -generation_config = { - "temperature": 0.9, - "top_p": 1, - "top_k": 1, - "max_output_tokens": 2048, -} - -safety_settings = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}, - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "threshold": "BLOCK_MEDIUM_AND_ABOVE", - }, - { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_MEDIUM_AND_ABOVE", - }, -] - -BOT_STARTS_WORDS_LIST = ["map:", "/map", "github:", "/github", "gemini", "/gemini"] - - -#### Utils #### -def extract_prompt(message: Message, bot_name: str) -> Optional[str]: - """ - This function filters messages for prompts. - - Returns: - str: If it is not a prompt, return None. Otherwise, return the trimmed prefix of the actual prompt. - """ - msg_text: str = message.text.strip() - if msg_text.startswith("@"): - if not msg_text.startswith(f"@{bot_name} "): - return None - s = msg_text[len(bot_name) + 2 :] - else: - prefix = next( - (w for w in BOT_STARTS_WORDS_LIST if msg_text.startswith(w)), None - ) - if not prefix: - return None - s = msg_text[len(prefix) :] - # If the first word is '@bot_name', remove it as it is considered part of the command when in a group chat. - if s.startswith("@"): - if not s.startswith(f"@{bot_name} "): - return None - s = " ".join(s.split(" ")[1:]) - return s - - -def make_new_gemini_convo(): - model = genai.GenerativeModel( - model_name="gemini-pro", - generation_config=generation_config, - safety_settings=safety_settings, - ) - convo = model.start_chat() - return convo - - -class Plot(PrettyPlot): - # memory leak fix for Plot. thanks @higuoxing https://github.com/higuoxing - # refer to: https://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg11809.html - def __post_init__(self): - ( - self.xmin, - self.ymin, - self.xmax, - self.ymax, - ) = self.aoi_bounds - # take from aoi geometry bounds, otherwise probelematic if unequal geometry distribution over plot. - self.xmid = (self.xmin + self.xmax) / 2 - self.ymid = (self.ymin + self.ymax) / 2 - self.xdif = self.xmax - self.xmin - self.ydif = self.ymax - self.ymin - - self.bg_buffer_x = (self.bg_buffer / 100) * self.xdif - self.bg_buffer_y = (self.bg_buffer / 100) * self.ydif - - # self.fig, self.ax = subplots( - # 1, 1, figsize=(12, 12), constrained_layout=True, dpi=1200 - # ) - self.fig = figure.Figure(figsize=(12, 12), constrained_layout=True, dpi=1200) - self.ax = self.fig.subplots(1, 1) - self.ax.set_aspect(1 / np.cos(self.ymid * np.pi / 180)) - - self.ax.axis("off") - self.ax.set_xlim(self.xmin - self.bg_buffer_x, self.xmax + self.bg_buffer_x) - self.ax.set_ylim(self.ymin - self.bg_buffer_y, self.ymax + self.bg_buffer_y) - - -def sizeof_image(image): - with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as f: - image.save(f, format="JPEG", quality=95) - return f.tell() - - -def compress_image(input_image, output_image, target_size): - quality = 95 - factor = 1.0 - with Image.open(input_image) as img: - while sizeof_image(img) > target_size: - factor -= 0.05 - width, height = img.size - img = img.resize( - (int(width * factor), int(height * factor)), - PIL.Image.Resampling.LANCZOS, - ) - img.save(output_image, format="JPEG", quality=quality) - output_image.seek(0) - - -def draw_pretty_map(location, style, output_file): - aoi = get_aoi(address=location, radius=1100, rectangular=True) - df = get_osm_geometries(aoi=aoi) - fig = Plot(df=df, aoi_bounds=aoi.bounds, draw_settings=STYLES[style]).plot_all() - with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as buffer: - fig.savefig(buffer, format="jpeg") - buffer.seek(0) - compress_image( - buffer, - output_file, - 10 * 1024 * 1024, # telegram tog need png less than 10MB - ) +from handlers import load_handlers def main(): @@ -162,181 +11,12 @@ def main(): parser.add_argument("tg_token", help="tg token") options = parser.parse_args() print("Arg parse done.") - gemini_player_dict = {} # Init bot bot = TeleBot(options.tg_token) - bot.set_my_commands( - [ - BotCommand( - "github", "github poster: /github [-]" - ), - BotCommand("map", "pretty map: /map
"), - BotCommand("gemini", "Gemini : /gemini "), - ] - ) - bot_name = bot.get_me().username + load_handlers(bot) print("Bot init done.") - @bot.message_handler(commands=["github"]) - @bot.message_handler(regexp="^github:") - def github_poster_handler(message: Message): - reply_message = bot.reply_to(message, "Generating poster please wait:") - m = extract_prompt(message, bot_name) - if not m: - bot.reply_to(message, "Please provide info after start words.") - return - message_list = m.split(",") - name = message_list[0].strip() - cmd_list = ["github_poster", "github", "--github_user_name", name, "--me", name] - if len(message_list) > 1: - years = message_list[1] - cmd_list.append("--year") - cmd_list.append(years.strip()) - r = subprocess.check_output(cmd_list).decode("utf-8") - if "done" in r: - try: - # TODO windows path - r = subprocess.check_output( - ["cairosvg", "OUT_FOLDER/github.svg", "-o", f"github_{name}.png"] - ).decode("utf-8") - with open(f"github_{name}.png", "rb") as photo: - bot.send_photo( - message.chat.id, photo, reply_to_message_id=message.message_id - ) - except Exception as e: - print(e) - bot.reply_to(message, "Something wrong please check") - bot.delete_message(reply_message.chat.id, reply_message.message_id) - - @bot.message_handler(commands=["map"]) - @bot.message_handler(regexp="^map:") - def map_handler(message: Message): - reply_message = bot.reply_to( - message, "Generating pretty map may take some time please wait:" - ) - m = extract_prompt(message, bot_name) - if not m: - bot.reply_to(message, "Please provide info after start words.") - return - location = m.strip() - styles_list = list(STYLES.keys()) - style = random.choice(styles_list) - try: - with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image: - draw_pretty_map(location, style, out_image) - # tg can only send image less than 10MB - with open("map_out.jpg", "wb") as f: # for debug - shutil.copyfileobj(out_image, f) - out_image.seek(0) - bot.send_photo( - message.chat.id, out_image, reply_to_message_id=message.message_id - ) - - except Exception: - traceback.print_exc() - bot.reply_to(message, "Something wrong please check") - bot.delete_message(reply_message.chat.id, reply_message.message_id) - gc.collect() - - @bot.message_handler(content_types=["location", "venue"]) - def map_location_handler(message: Message): - # TODO refactor the function - reply_message = bot.reply_to( - message, - "Generating pretty map using location now, may take some time please wait:", - ) - location = "{0}, {1}".format( - message.location.latitude, message.location.longitude - ) - styles_list = list(STYLES.keys()) - style = random.choice(styles_list) - try: - with SpooledTemporaryFile(max_size=MAX_IN_MEMORY) as out_image: - draw_pretty_map(location, style, out_image) - # tg can only send image less than 10MB - with open("map_out.jpg", "wb") as f: # for debug - shutil.copyfileobj(out_image, f) - out_image.seek(0) - bot.send_photo( - message.chat.id, out_image, reply_to_message_id=message.message_id - ) - - except Exception: - traceback.print_exc() - bot.reply_to(message, "Something wrong please check") - bot.delete_message(reply_message.chat.id, reply_message.message_id) - gc.collect() - - @bot.message_handler(commands=["gemini"]) - @bot.message_handler(regexp="^gemini:") - def gemini_handler(message: Message): - reply_message = bot.reply_to( - message, - "Generating google gemini answer please wait, note, will only keep the last five messages:", - ) - m = extract_prompt(message, bot_name) - if not m: - bot.reply_to(message, "Please provide info after start words.") - return - player = None - # restart will lose all TODO - if str(message.from_user.id) not in gemini_player_dict: - player = make_new_gemini_convo() - gemini_player_dict[str(message.from_user.id)] = player - else: - player = gemini_player_dict[str(message.from_user.id)] - if len(player.history) > 10: - player.history = player.history[2:] - try: - player.send_message(m) - try: - bot.reply_to( - message, - "Gemini answer:\n" + player.last.text, - parse_mode="MarkdownV2", - ) - except: - bot.reply_to(message, "Gemini answer:\n" + player.last.text) - - except Exception as e: - traceback.print_exc() - bot.reply_to(message, "Something wrong please check the log") - bot.delete_message(reply_message.chat.id, reply_message.message_id) - - @bot.message_handler(content_types=["photo"]) - def gemini_photo_handler(message: Message) -> None: - s = message.caption - if not s or not s.startswith("gemini:"): - return - reply_message = bot.reply_to( - message, - "Generating google gemini vision answer please wait,", - ) - try: - prompt = s.strip().split(maxsplit=1)[1].strip() - - max_size_photo = max(message.photo, key=lambda p: p.file_size) - file_path = bot.get_file(max_size_photo.file_id).file_path - downloaded_file = bot.download_file(file_path) - with open("gemini_temp.jpg", "wb") as temp_file: - temp_file.write(downloaded_file) - except Exception as e: - traceback.print_exc() - bot.reply_to(message, "Something is wrong reading your photo or prompt") - model = genai.GenerativeModel("gemini-pro-vision") - image_path = Path("gemini_temp.jpg") - image_data = image_path.read_bytes() - contents = { - "parts": [{"mime_type": "image/jpeg", "data": image_data}, {"text": prompt}] - } - try: - response = model.generate_content(contents=contents) - bot.reply_to(message, "Gemini vision answer:\n" + response.text) - except Exception as e: - traceback.print_exc() - bot.reply_to(message, "Something wrong please check the log") - # Start bot print("Starting tg collections bot.") bot.infinity_polling()