feat: add summary and search commands (#54)

* feat: add summary and search commands

Signed-off-by: Frost Ming <me@frostming.com>

* fix formats

Signed-off-by: Frost Ming <me@frostming.com>

* fix: clean up

Signed-off-by: Frost Ming <me@frostming.com>
This commit is contained in:
Frost Ming
2025-07-08 11:41:57 +08:00
committed by GitHub
parent e91862a548
commit 0b60ae2fbe
31 changed files with 1279 additions and 2588 deletions

View File

@ -0,0 +1,139 @@
from __future__ import annotations
import logging
from datetime import datetime, timezone
from functools import partial
import telegramify_markdown
from telebot import TeleBot
from telebot.types import Message
from config import settings
from handlers._utils import non_llm_handler
from .messages import ChatMessage, MessageStore
from .utils import PROMPT, filter_message, parse_date
logger = logging.getLogger("bot")
store = MessageStore("data/messages.db")
@non_llm_handler
def handle_message(message: Message):
logger.debug(
"Received message: %s, chat_id=%d, from=%s",
message.text,
message.chat.id,
message.from_user.id,
)
# 这里可以添加处理消息的逻辑
store.add_message(
ChatMessage(
chat_id=message.chat.id,
message_id=message.id,
content=message.text or "",
user_id=message.from_user.id,
user_name=message.from_user.full_name,
timestamp=datetime.fromtimestamp(message.date, tz=timezone.utc),
)
)
@non_llm_handler
def summary_command(message: Message, bot: TeleBot):
"""生成消息摘要。示例:/summary today; /summary 2d"""
text_parts = message.text.split(maxsplit=1)
if len(text_parts) < 2:
date = "today"
else:
date = text_parts[1].strip()
since, now = parse_date(date, settings.timezone)
messages = store.get_messages_since(message.chat.id, since)
messages_text = "\n".join(
f"{msg.timestamp.isoformat()} - @{msg.user_name}: {msg.content}"
for msg in messages
)
if not messages_text:
bot.reply_to(message, "没有找到指定时间范围内的历史消息。")
return
new_message = bot.reply_to(message, "正在生成摘要,请稍候...")
response = settings.openai_client.chat.completions.create(
model=settings.openai_model,
messages=[
{"role": "user", "content": PROMPT.format(messages=messages_text)},
],
)
reply_text = f"""*👇 前情提要 👇 \\({since.strftime("%Y/%m/%d %H:%M")} \\- {now.strftime("%Y/%m/%d %H:%M")}\\)*
{telegramify_markdown.convert(response.choices[0].message.content)}
"""
logger.debug("Generated summary:\n%s", reply_text)
bot.edit_message_text(
chat_id=new_message.chat.id,
message_id=new_message.message_id,
text=reply_text,
parse_mode="MarkdownV2",
)
@non_llm_handler
def stats_command(message: Message, bot: TeleBot):
"""获取群组消息统计信息"""
stats = store.get_stats(message.chat.id)
if not stats:
bot.reply_to(message, "没有找到任何统计信息。")
return
stats_text = "\n".join(
f"{entry.date}: {entry.message_count} messages" for entry in stats
)
bot.reply_to(
message,
f"📊 群组消息统计信息:\n```\n{stats_text}\n```",
parse_mode="MarkdownV2",
)
@non_llm_handler
def search_command(message: Message, bot: TeleBot):
"""搜索群组消息(示例:/search 关键词 [N]"""
text_parts = message.text.split(maxsplit=2)
if len(text_parts) < 2:
bot.reply_to(message, "请提供要搜索的关键词。")
return
keyword = text_parts[1].strip()
if len(text_parts) > 2 and text_parts[2].isdigit():
limit = int(text_parts[2])
else:
limit = 10
messages = store.search_messages(message.chat.id, keyword, limit=limit)
if not messages:
bot.reply_to(message, "没有找到匹配的消息。")
return
chat_id = str(message.chat.id)
if chat_id.startswith("-100"):
chat_id = chat_id[4:]
items = []
for msg in messages:
link = f"https://t.me/c/{chat_id}/{msg.message_id}"
items.append(f"{link}\n```\n{msg.content}\n```")
message_text = telegramify_markdown.convert("\n".join(items))
bot.reply_to(
message,
f"🔍 *搜索结果(只显示前 {limit} 个):*\n{message_text}",
parse_mode="MarkdownV2",
)
load_priority = 5
if settings.openai_api_key:
def register(bot: TeleBot):
"""注册命令处理器"""
bot.register_message_handler(
summary_command, commands=["summary"], pass_bot=True
)
bot.register_message_handler(stats_command, commands=["stats"], pass_bot=True)
bot.register_message_handler(search_command, commands=["search"], pass_bot=True)
bot.register_message_handler(
handle_message, func=partial(filter_message, bot=bot)
)

View File

@ -0,0 +1,49 @@
from __future__ import annotations
import asyncio
import os
import sys
from .messages import ChatMessage, MessageStore
async def fetch_messages(chat_id: int) -> None:
from telethon import TelegramClient
from telethon.tl.types import Message
store = MessageStore("data/messages.db")
api_id = int(os.getenv("TELEGRAM_API_ID"))
api_hash = os.getenv("TELEGRAM_API_HASH")
async with TelegramClient("test", api_id, api_hash) as client:
assert isinstance(client, TelegramClient)
with store.connect() as conn:
async for message in client.iter_messages(chat_id, reverse=True):
if not isinstance(message, Message) or not message.message:
continue
if not message.from_id:
continue
print(message.pretty_format(message))
user = await client.get_entity(message.from_id)
fullname = user.first_name
if user.last_name:
fullname += f" {user.last_name}"
store.add_message(
ChatMessage(
chat_id=chat_id,
message_id=message.id,
content=message.message,
user_id=message.from_id.user_id,
user_name=fullname,
timestamp=message.date,
),
conn=conn,
)
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python -m handlers.summary <chat_id>")
sys.exit(1)
chat_id = int(sys.argv[1])
asyncio.run(fetch_messages(chat_id)) # 替换为实际的群组ID

View File

@ -0,0 +1,164 @@
import os
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
@dataclass(frozen=True)
class ChatMessage:
chat_id: int
message_id: int
content: str
user_id: int
user_name: str
timestamp: datetime
@dataclass(frozen=True)
class StatsEntry:
date: str
message_count: int
class MessageStore:
def __init__(self, db_file: str):
parent_folder = os.path.dirname(db_file)
if not os.path.exists(parent_folder):
os.makedirs(parent_folder)
self._db_file = db_file
self._init_db()
def connect(self) -> sqlite3.Connection:
"""Create a new database connection."""
return sqlite3.connect(self._db_file)
def _init_db(self):
with self.connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS messages (
chat_id INTEGER,
message_id INTEGER,
content TEXT,
user_id INTEGER,
user_name TEXT,
timestamp TEXT,
PRIMARY KEY (chat_id, message_id)
);
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chat_timestamp ON messages (chat_id, timestamp);
"""
)
conn.commit()
def add_message(
self, message: ChatMessage, conn: sqlite3.Connection | None = None
) -> None:
need_close = False
if conn is None:
conn = self.connect()
need_close = True
try:
conn.execute(
"""
INSERT OR REPLACE INTO messages (chat_id, message_id, content, user_id, user_name, timestamp)
VALUES (?, ?, ?, ?, ?, ?);
""",
(
message.chat_id,
message.message_id,
message.content,
message.user_id,
message.user_name,
message.timestamp.isoformat(),
),
)
self._clean_old_messages(message.chat_id, conn)
conn.commit()
finally:
if need_close:
conn.close()
def get_messages_since(self, chat_id: int, since: datetime) -> list[ChatMessage]:
with self.connect() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT chat_id, message_id, content, user_id, user_name, timestamp
FROM messages
WHERE chat_id = ? AND timestamp >= ?
ORDER BY timestamp ASC;
""",
(chat_id, since.isoformat()),
)
rows = cursor.fetchall()
return [
ChatMessage(
chat_id=row[0],
message_id=row[1],
content=row[2],
user_id=row[3],
user_name=row[4],
timestamp=datetime.fromisoformat(row[5]),
)
for row in rows
]
def get_stats(self, chat_id: int) -> list[StatsEntry]:
with self.connect() as conn:
self._clean_old_messages(chat_id, conn)
cursor = conn.cursor()
cursor.execute(
"""
SELECT DATE(timestamp), COUNT(*)
FROM messages
WHERE chat_id = ?
GROUP BY DATE(timestamp)
ORDER BY DATE(timestamp) ASC;
""",
(chat_id,),
)
rows = cursor.fetchall()
return [StatsEntry(date=row[0], message_count=row[1]) for row in rows]
def search_messages(
self, chat_id: int, keyword: str, limit: int = 10
) -> list[ChatMessage]:
# TODO: Fuzzy search with full-text search or similar
with self.connect() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT chat_id, message_id, content, user_id, user_name, timestamp
FROM messages
WHERE chat_id = ? AND content LIKE ?
ORDER BY timestamp DESC
LIMIT ?;
""",
(chat_id, f"%{keyword}%", limit),
)
rows = cursor.fetchall()
return [
ChatMessage(
chat_id=row[0],
message_id=row[1],
content=row[2],
user_id=row[3],
user_name=row[4],
timestamp=datetime.fromisoformat(row[5]),
)
for row in rows
]
def _clean_old_messages(
self, chat_id: int, conn: sqlite3.Connection, days: int = 7
) -> None:
cursor = conn.cursor()
threshold_date = datetime.now(tz=timezone.utc) - timedelta(days=days)
cursor.execute(
"DELETE FROM messages WHERE chat_id = ? AND timestamp < ?;",
(chat_id, threshold_date.isoformat()),
)

48
handlers/summary/utils.py Normal file
View File

@ -0,0 +1,48 @@
import re
import zoneinfo
from datetime import datetime, timedelta
from telebot import TeleBot
from telebot.types import Message
PROMPT = """\
请将下面的聊天记录进行总结,包含讨论了哪些话题,有哪些亮点发言和主要观点。
引用用户名请加粗。直接返回内容即可,不要包含引导词和标题。
--- Messages Start ---
{messages}
--- Messages End ---
"""
def filter_message(message: Message, bot: TeleBot) -> bool:
"""过滤消息,排除非文本消息和命令消息"""
if not message.text:
return False
if not message.from_user:
return False
if message.from_user.id == bot.get_me().id:
return False
if message.text.startswith("/"):
return False
return True
date_regex = re.compile(r"^(\d+)([dhm])$")
def parse_date(date_str: str, locale: str) -> tuple[datetime, datetime]:
date_str = date_str.strip().lower()
now = datetime.now(tz=zoneinfo.ZoneInfo(locale))
if date_str == "today":
return now.replace(hour=0, minute=0, second=0, microsecond=0), now
elif m := date_regex.match(date_str):
number = int(m.group(1))
unit = m.group(2)
match unit:
case "d":
return now - timedelta(days=number), now
case "h":
return now - timedelta(hours=number), now
case "m":
return now - timedelta(minutes=number), now
raise ValueError(f"Unsupported date format: {date_str}")