mirror of
https://github.com/cdryzun/tg_bot_collections.git
synced 2025-08-05 13:16:42 +08:00
feat: add summary and search commands (#54)
* feat: add summary and search commands Signed-off-by: Frost Ming <me@frostming.com> * fix formats Signed-off-by: Frost Ming <me@frostming.com> * fix: clean up Signed-off-by: Frost Ming <me@frostming.com>
This commit is contained in:
139
handlers/summary/__init__.py
Normal file
139
handlers/summary/__init__.py
Normal file
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
|
||||
import telegramify_markdown
|
||||
from telebot import TeleBot
|
||||
from telebot.types import Message
|
||||
|
||||
from config import settings
|
||||
from handlers._utils import non_llm_handler
|
||||
|
||||
from .messages import ChatMessage, MessageStore
|
||||
from .utils import PROMPT, filter_message, parse_date
|
||||
|
||||
logger = logging.getLogger("bot")
|
||||
store = MessageStore("data/messages.db")
|
||||
|
||||
|
||||
@non_llm_handler
|
||||
def handle_message(message: Message):
|
||||
logger.debug(
|
||||
"Received message: %s, chat_id=%d, from=%s",
|
||||
message.text,
|
||||
message.chat.id,
|
||||
message.from_user.id,
|
||||
)
|
||||
# 这里可以添加处理消息的逻辑
|
||||
store.add_message(
|
||||
ChatMessage(
|
||||
chat_id=message.chat.id,
|
||||
message_id=message.id,
|
||||
content=message.text or "",
|
||||
user_id=message.from_user.id,
|
||||
user_name=message.from_user.full_name,
|
||||
timestamp=datetime.fromtimestamp(message.date, tz=timezone.utc),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@non_llm_handler
|
||||
def summary_command(message: Message, bot: TeleBot):
|
||||
"""生成消息摘要。示例:/summary today; /summary 2d"""
|
||||
text_parts = message.text.split(maxsplit=1)
|
||||
if len(text_parts) < 2:
|
||||
date = "today"
|
||||
else:
|
||||
date = text_parts[1].strip()
|
||||
since, now = parse_date(date, settings.timezone)
|
||||
messages = store.get_messages_since(message.chat.id, since)
|
||||
messages_text = "\n".join(
|
||||
f"{msg.timestamp.isoformat()} - @{msg.user_name}: {msg.content}"
|
||||
for msg in messages
|
||||
)
|
||||
if not messages_text:
|
||||
bot.reply_to(message, "没有找到指定时间范围内的历史消息。")
|
||||
return
|
||||
new_message = bot.reply_to(message, "正在生成摘要,请稍候...")
|
||||
response = settings.openai_client.chat.completions.create(
|
||||
model=settings.openai_model,
|
||||
messages=[
|
||||
{"role": "user", "content": PROMPT.format(messages=messages_text)},
|
||||
],
|
||||
)
|
||||
reply_text = f"""*👇 前情提要 👇 \\({since.strftime("%Y/%m/%d %H:%M")} \\- {now.strftime("%Y/%m/%d %H:%M")}\\)*
|
||||
|
||||
{telegramify_markdown.convert(response.choices[0].message.content)}
|
||||
"""
|
||||
logger.debug("Generated summary:\n%s", reply_text)
|
||||
bot.edit_message_text(
|
||||
chat_id=new_message.chat.id,
|
||||
message_id=new_message.message_id,
|
||||
text=reply_text,
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
|
||||
|
||||
@non_llm_handler
|
||||
def stats_command(message: Message, bot: TeleBot):
|
||||
"""获取群组消息统计信息"""
|
||||
stats = store.get_stats(message.chat.id)
|
||||
if not stats:
|
||||
bot.reply_to(message, "没有找到任何统计信息。")
|
||||
return
|
||||
stats_text = "\n".join(
|
||||
f"{entry.date}: {entry.message_count} messages" for entry in stats
|
||||
)
|
||||
bot.reply_to(
|
||||
message,
|
||||
f"📊 群组消息统计信息:\n```\n{stats_text}\n```",
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
|
||||
|
||||
@non_llm_handler
|
||||
def search_command(message: Message, bot: TeleBot):
|
||||
"""搜索群组消息(示例:/search 关键词 [N])"""
|
||||
text_parts = message.text.split(maxsplit=2)
|
||||
if len(text_parts) < 2:
|
||||
bot.reply_to(message, "请提供要搜索的关键词。")
|
||||
return
|
||||
keyword = text_parts[1].strip()
|
||||
if len(text_parts) > 2 and text_parts[2].isdigit():
|
||||
limit = int(text_parts[2])
|
||||
else:
|
||||
limit = 10
|
||||
messages = store.search_messages(message.chat.id, keyword, limit=limit)
|
||||
if not messages:
|
||||
bot.reply_to(message, "没有找到匹配的消息。")
|
||||
return
|
||||
chat_id = str(message.chat.id)
|
||||
if chat_id.startswith("-100"):
|
||||
chat_id = chat_id[4:]
|
||||
items = []
|
||||
for msg in messages:
|
||||
link = f"https://t.me/c/{chat_id}/{msg.message_id}"
|
||||
items.append(f"{link}\n```\n{msg.content}\n```")
|
||||
message_text = telegramify_markdown.convert("\n".join(items))
|
||||
bot.reply_to(
|
||||
message,
|
||||
f"🔍 *搜索结果(只显示前 {limit} 个):*\n{message_text}",
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
|
||||
|
||||
load_priority = 5
|
||||
if settings.openai_api_key:
|
||||
|
||||
def register(bot: TeleBot):
|
||||
"""注册命令处理器"""
|
||||
bot.register_message_handler(
|
||||
summary_command, commands=["summary"], pass_bot=True
|
||||
)
|
||||
bot.register_message_handler(stats_command, commands=["stats"], pass_bot=True)
|
||||
bot.register_message_handler(search_command, commands=["search"], pass_bot=True)
|
||||
bot.register_message_handler(
|
||||
handle_message, func=partial(filter_message, bot=bot)
|
||||
)
|
49
handlers/summary/__main__.py
Normal file
49
handlers/summary/__main__.py
Normal file
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from .messages import ChatMessage, MessageStore
|
||||
|
||||
|
||||
async def fetch_messages(chat_id: int) -> None:
|
||||
from telethon import TelegramClient
|
||||
from telethon.tl.types import Message
|
||||
|
||||
store = MessageStore("data/messages.db")
|
||||
|
||||
api_id = int(os.getenv("TELEGRAM_API_ID"))
|
||||
api_hash = os.getenv("TELEGRAM_API_HASH")
|
||||
async with TelegramClient("test", api_id, api_hash) as client:
|
||||
assert isinstance(client, TelegramClient)
|
||||
with store.connect() as conn:
|
||||
async for message in client.iter_messages(chat_id, reverse=True):
|
||||
if not isinstance(message, Message) or not message.message:
|
||||
continue
|
||||
if not message.from_id:
|
||||
continue
|
||||
print(message.pretty_format(message))
|
||||
user = await client.get_entity(message.from_id)
|
||||
fullname = user.first_name
|
||||
if user.last_name:
|
||||
fullname += f" {user.last_name}"
|
||||
store.add_message(
|
||||
ChatMessage(
|
||||
chat_id=chat_id,
|
||||
message_id=message.id,
|
||||
content=message.message,
|
||||
user_id=message.from_id.user_id,
|
||||
user_name=fullname,
|
||||
timestamp=message.date,
|
||||
),
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python -m handlers.summary <chat_id>")
|
||||
sys.exit(1)
|
||||
chat_id = int(sys.argv[1])
|
||||
asyncio.run(fetch_messages(chat_id)) # 替换为实际的群组ID
|
164
handlers/summary/messages.py
Normal file
164
handlers/summary/messages.py
Normal file
@ -0,0 +1,164 @@
|
||||
import os
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatMessage:
|
||||
chat_id: int
|
||||
message_id: int
|
||||
content: str
|
||||
user_id: int
|
||||
user_name: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StatsEntry:
|
||||
date: str
|
||||
message_count: int
|
||||
|
||||
|
||||
class MessageStore:
|
||||
def __init__(self, db_file: str):
|
||||
parent_folder = os.path.dirname(db_file)
|
||||
if not os.path.exists(parent_folder):
|
||||
os.makedirs(parent_folder)
|
||||
self._db_file = db_file
|
||||
self._init_db()
|
||||
|
||||
def connect(self) -> sqlite3.Connection:
|
||||
"""Create a new database connection."""
|
||||
return sqlite3.connect(self._db_file)
|
||||
|
||||
def _init_db(self):
|
||||
with self.connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
chat_id INTEGER,
|
||||
message_id INTEGER,
|
||||
content TEXT,
|
||||
user_id INTEGER,
|
||||
user_name TEXT,
|
||||
timestamp TEXT,
|
||||
PRIMARY KEY (chat_id, message_id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_timestamp ON messages (chat_id, timestamp);
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def add_message(
|
||||
self, message: ChatMessage, conn: sqlite3.Connection | None = None
|
||||
) -> None:
|
||||
need_close = False
|
||||
if conn is None:
|
||||
conn = self.connect()
|
||||
need_close = True
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO messages (chat_id, message_id, content, user_id, user_name, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
message.chat_id,
|
||||
message.message_id,
|
||||
message.content,
|
||||
message.user_id,
|
||||
message.user_name,
|
||||
message.timestamp.isoformat(),
|
||||
),
|
||||
)
|
||||
self._clean_old_messages(message.chat_id, conn)
|
||||
conn.commit()
|
||||
finally:
|
||||
if need_close:
|
||||
conn.close()
|
||||
|
||||
def get_messages_since(self, chat_id: int, since: datetime) -> list[ChatMessage]:
|
||||
with self.connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT chat_id, message_id, content, user_id, user_name, timestamp
|
||||
FROM messages
|
||||
WHERE chat_id = ? AND timestamp >= ?
|
||||
ORDER BY timestamp ASC;
|
||||
""",
|
||||
(chat_id, since.isoformat()),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
return [
|
||||
ChatMessage(
|
||||
chat_id=row[0],
|
||||
message_id=row[1],
|
||||
content=row[2],
|
||||
user_id=row[3],
|
||||
user_name=row[4],
|
||||
timestamp=datetime.fromisoformat(row[5]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def get_stats(self, chat_id: int) -> list[StatsEntry]:
|
||||
with self.connect() as conn:
|
||||
self._clean_old_messages(chat_id, conn)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT DATE(timestamp), COUNT(*)
|
||||
FROM messages
|
||||
WHERE chat_id = ?
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY DATE(timestamp) ASC;
|
||||
""",
|
||||
(chat_id,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
return [StatsEntry(date=row[0], message_count=row[1]) for row in rows]
|
||||
|
||||
def search_messages(
|
||||
self, chat_id: int, keyword: str, limit: int = 10
|
||||
) -> list[ChatMessage]:
|
||||
# TODO: Fuzzy search with full-text search or similar
|
||||
with self.connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT chat_id, message_id, content, user_id, user_name, timestamp
|
||||
FROM messages
|
||||
WHERE chat_id = ? AND content LIKE ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?;
|
||||
""",
|
||||
(chat_id, f"%{keyword}%", limit),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
return [
|
||||
ChatMessage(
|
||||
chat_id=row[0],
|
||||
message_id=row[1],
|
||||
content=row[2],
|
||||
user_id=row[3],
|
||||
user_name=row[4],
|
||||
timestamp=datetime.fromisoformat(row[5]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def _clean_old_messages(
|
||||
self, chat_id: int, conn: sqlite3.Connection, days: int = 7
|
||||
) -> None:
|
||||
cursor = conn.cursor()
|
||||
threshold_date = datetime.now(tz=timezone.utc) - timedelta(days=days)
|
||||
cursor.execute(
|
||||
"DELETE FROM messages WHERE chat_id = ? AND timestamp < ?;",
|
||||
(chat_id, threshold_date.isoformat()),
|
||||
)
|
48
handlers/summary/utils.py
Normal file
48
handlers/summary/utils.py
Normal file
@ -0,0 +1,48 @@
|
||||
import re
|
||||
import zoneinfo
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from telebot import TeleBot
|
||||
from telebot.types import Message
|
||||
|
||||
PROMPT = """\
|
||||
请将下面的聊天记录进行总结,包含讨论了哪些话题,有哪些亮点发言和主要观点。
|
||||
引用用户名请加粗。直接返回内容即可,不要包含引导词和标题。
|
||||
--- Messages Start ---
|
||||
{messages}
|
||||
--- Messages End ---
|
||||
"""
|
||||
|
||||
|
||||
def filter_message(message: Message, bot: TeleBot) -> bool:
|
||||
"""过滤消息,排除非文本消息和命令消息"""
|
||||
if not message.text:
|
||||
return False
|
||||
if not message.from_user:
|
||||
return False
|
||||
if message.from_user.id == bot.get_me().id:
|
||||
return False
|
||||
if message.text.startswith("/"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
date_regex = re.compile(r"^(\d+)([dhm])$")
|
||||
|
||||
|
||||
def parse_date(date_str: str, locale: str) -> tuple[datetime, datetime]:
|
||||
date_str = date_str.strip().lower()
|
||||
now = datetime.now(tz=zoneinfo.ZoneInfo(locale))
|
||||
if date_str == "today":
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0), now
|
||||
elif m := date_regex.match(date_str):
|
||||
number = int(m.group(1))
|
||||
unit = m.group(2)
|
||||
match unit:
|
||||
case "d":
|
||||
return now - timedelta(days=number), now
|
||||
case "h":
|
||||
return now - timedelta(hours=number), now
|
||||
case "m":
|
||||
return now - timedelta(minutes=number), now
|
||||
raise ValueError(f"Unsupported date format: {date_str}")
|
Reference in New Issue
Block a user