diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml new file mode 100644 index 0000000..01c6ef0 --- /dev/null +++ b/.github/workflows/CI.yaml @@ -0,0 +1,27 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + workflow_dispatch: + +concurrency: + group: ${{ github.event.number || github.run_id }} + cancel-in-progress: true + +jobs: + testing: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: install python 3.9 + uses: actions/setup-python@v4 + with: + python-version: "3.9" + cache: "pip" # caching pip dependencies + - name: Check formatting (black) + run: | + pip install black + black . --check diff --git a/tg.py b/tg.py index 2f7ca7a..ecf2789 100644 --- a/tg.py +++ b/tg.py @@ -7,6 +7,8 @@ import traceback from tempfile import SpooledTemporaryFile from os import environ from pathlib import Path +from typing import Optional + import numpy as np import PIL @@ -47,6 +49,36 @@ safety_settings = [ }, ] +BOT_STARTS_WORDS_LIST = ["map:", "/map", "github:", "/github", "gemini", "/gemini"] + + +#### Utils #### +def extract_prompt(message: Message, bot_name: str) -> Optional[str]: + """ + This function filters messages for prompts. + + Returns: + str: If it is not a prompt, return None. Otherwise, return the trimmed prefix of the actual prompt. + """ + msg_text: str = message.text.strip() + if msg_text.startswith("@"): + if not msg_text.startswith(f"@{bot_name} "): + return None + s = msg_text[len(bot_name) + 2 :] + else: + prefix = next( + (w for w in BOT_STARTS_WORDS_LIST if msg_text.startswith(w)), None + ) + if not prefix: + return None + s = msg_text[len(prefix) :] + # If the first word is '@bot_name', remove it as it is considered part of the command when in a group chat. + if s.startswith("@"): + if not s.startswith(f"@{bot_name} "): + return None + s = " ".join(s.split(" ")[1:]) + return s + def make_new_gemini_convo(): model = genai.GenerativeModel( @@ -143,13 +175,17 @@ def main(): BotCommand("gemini", "Gemini : /gemini "), ] ) + bot_name = bot.get_me().username print("Bot init done.") @bot.message_handler(commands=["github"]) @bot.message_handler(regexp="^github:") def github_poster_handler(message: Message): reply_message = bot.reply_to(message, "Generating poster please wait:") - m = message.text.strip().split(maxsplit=1)[1].strip() + m = extract_prompt(message, bot_name) + if not m: + bot.reply_to(message, "Please provide info after start words.") + return message_list = m.split(",") name = message_list[0].strip() cmd_list = ["github_poster", "github", "--github_user_name", name, "--me", name] @@ -179,7 +215,10 @@ def main(): reply_message = bot.reply_to( message, "Generating pretty map may take some time please wait:" ) - m = message.text.strip().split(maxsplit=1)[1].strip() + m = extract_prompt(message, bot_name) + if not m: + bot.reply_to(message, "Please provide info after start words.") + return location = m.strip() styles_list = list(STYLES.keys()) style = random.choice(styles_list) @@ -236,7 +275,10 @@ def main(): message, "Generating google gemini answer please wait, note, will only keep the last five messages:", ) - m = message.text.strip().split(maxsplit=1)[1].strip() + m = extract_prompt(message, bot_name) + if not m: + bot.reply_to(message, "Please provide info after start words.") + return player = None # restart will lose all TODO if str(message.from_user.id) not in gemini_player_dict: @@ -290,7 +332,6 @@ def main(): } try: response = model.generate_content(contents=contents) - print(response.text) bot.reply_to(message, "Gemini vision answer:\n" + response.text) except Exception as e: traceback.print_exc()