mirror of
				https://github.com/cdryzun/tg_bot_collections.git
				synced 2025-11-04 08:46:44 +08:00 
			
		
		
		
	feat: support web search via ollama cloud web search for ChatGPT pro handler
Signed-off-by: Frost Ming <me@frostming.com>
This commit is contained in:
		@ -1,9 +1,12 @@
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
import uuid
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
from expiringdict import ExpiringDict
 | 
			
		||||
from telebot import TeleBot
 | 
			
		||||
from telebot.types import Message
 | 
			
		||||
from telegramify_markdown import markdownify
 | 
			
		||||
 | 
			
		||||
from config import settings
 | 
			
		||||
 | 
			
		||||
@ -15,7 +18,6 @@ from ._utils import (
 | 
			
		||||
    logger,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
CHATGPT_MODEL = settings.openai_model
 | 
			
		||||
CHATGPT_PRO_MODEL = settings.openai_model
 | 
			
		||||
 | 
			
		||||
@ -23,11 +25,293 @@ CHATGPT_PRO_MODEL = settings.openai_model
 | 
			
		||||
client = settings.openai_client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Web search / tool-calling configuration
 | 
			
		||||
WEB_SEARCH_TOOL_NAME = "web_search"
 | 
			
		||||
OLLAMA_WEB_SEARCH_URL = "https://ollama.com/api/web_search"
 | 
			
		||||
WEB_SEARCH_TOOL = {
 | 
			
		||||
    "type": "function",
 | 
			
		||||
    "function": {
 | 
			
		||||
        "name": WEB_SEARCH_TOOL_NAME,
 | 
			
		||||
        "description": (
 | 
			
		||||
            "Use the Ollama Cloud Web Search API to fetch recent information"
 | 
			
		||||
            " from the public internet. Call this when you need up-to-date"
 | 
			
		||||
            " facts, news, or citations."
 | 
			
		||||
        ),
 | 
			
		||||
        "parameters": {
 | 
			
		||||
            "type": "object",
 | 
			
		||||
            "properties": {
 | 
			
		||||
                "query": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "description": "Search keywords or question.",
 | 
			
		||||
                },
 | 
			
		||||
                "max_results": {
 | 
			
		||||
                    "type": "integer",
 | 
			
		||||
                    "description": (
 | 
			
		||||
                        "Maximum number of search results to fetch; defaults"
 | 
			
		||||
                        " to the bot configuration if omitted."
 | 
			
		||||
                    ),
 | 
			
		||||
                    "minimum": 1,
 | 
			
		||||
                    "maximum": 10,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            "required": ["query"],
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
STREAMING_UPDATE_INTERVAL = 1.2
 | 
			
		||||
MAX_TOOL_ITERATIONS = 3
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Global history cache
 | 
			
		||||
chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
 | 
			
		||||
chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _web_search_available() -> bool:
    """Report whether an Ollama web-search API key has been configured."""
    if settings.ollama_web_search_api_key:
        return True
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _format_web_search_results(payload: dict[str, Any]) -> str:
 | 
			
		||||
    results = payload.get("results") or payload.get("data") or []
 | 
			
		||||
    if not isinstance(results, list):
 | 
			
		||||
        results = []
 | 
			
		||||
    formatted: list[str] = []
 | 
			
		||||
    for idx, item in enumerate(results, start=1):
 | 
			
		||||
        if not isinstance(item, dict):
 | 
			
		||||
            continue
 | 
			
		||||
        title = (
 | 
			
		||||
            item.get("title") or item.get("name") or item.get("url") or f"Result {idx}"
 | 
			
		||||
        )
 | 
			
		||||
        url = item.get("url") or item.get("link") or item.get("source") or ""
 | 
			
		||||
        snippet = (
 | 
			
		||||
            item.get("snippet")
 | 
			
		||||
            or item.get("summary")
 | 
			
		||||
            or item.get("content")
 | 
			
		||||
            or item.get("description")
 | 
			
		||||
            or ""
 | 
			
		||||
        ).strip()
 | 
			
		||||
        snippet = snippet.replace("\n", " ")
 | 
			
		||||
        if len(snippet) > 400:
 | 
			
		||||
            snippet = snippet[:397].rstrip() + "..."
 | 
			
		||||
        entry = f"[{idx}] {title}"
 | 
			
		||||
        if url:
 | 
			
		||||
            entry = f"{entry}\n{url}"
 | 
			
		||||
        if snippet:
 | 
			
		||||
            entry = f"{entry}\n{snippet}"
 | 
			
		||||
        formatted.append(entry)
 | 
			
		||||
    if formatted:
 | 
			
		||||
        return "\n\n".join(formatted)
 | 
			
		||||
    return json.dumps(payload, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _call_ollama_web_search(query: str, max_results: int | None = None) -> str:
    """Query the Ollama Cloud Web Search API and return formatted results.

    Never raises: when the key is missing, the HTTP request fails, or the
    response body is not valid JSON, a human-readable error string is
    returned instead so it can be fed back to the model as a tool result.
    """
    if not _web_search_available():
        return "Web search is not configured."

    body: dict[str, Any] = {"query": query.strip()}
    requested = max_results if isinstance(max_results, int) else None
    if requested is None or requested <= 0:
        # Fall back to the configured default when the model omitted the
        # count or supplied something unusable.
        requested = settings.ollama_web_search_max_results
    if requested:
        body["max_results"] = int(requested)

    auth_header = {
        "Authorization": f"Bearer {settings.ollama_web_search_api_key}",
    }
    try:
        resp = requests.post(
            OLLAMA_WEB_SEARCH_URL,
            json=body,
            headers=auth_header,
            timeout=settings.ollama_web_search_timeout,
        )
        resp.raise_for_status()
        data = resp.json()
    except requests.RequestException as exc:
        logger.exception("Ollama web search failed: %s", exc)
        return f"Web search error: {exc}"
    except ValueError:
        # resp.json() raises a ValueError subclass on malformed bodies.
        logger.exception("Invalid JSON payload from Ollama web search")
        return "Web search error: invalid payload."
    return _format_web_search_results(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _available_tools() -> list[dict[str, Any]]:
    """Return tool schemas to expose to the model (empty when unconfigured)."""
    return [WEB_SEARCH_TOOL] if _web_search_available() else []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _accumulate_tool_call_deltas(
 | 
			
		||||
    buffer: dict[int, dict[str, Any]],
 | 
			
		||||
    deltas: list[Any],
 | 
			
		||||
) -> None:
 | 
			
		||||
    for delta in deltas:
 | 
			
		||||
        idx = getattr(delta, "index", 0) or 0
 | 
			
		||||
        entry = buffer.setdefault(
 | 
			
		||||
            idx,
 | 
			
		||||
            {
 | 
			
		||||
                "id": getattr(delta, "id", None),
 | 
			
		||||
                "type": getattr(delta, "type", "function") or "function",
 | 
			
		||||
                "function": {"name": "", "arguments": ""},
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        if getattr(delta, "id", None):
 | 
			
		||||
            entry["id"] = delta.id
 | 
			
		||||
        if getattr(delta, "type", None):
 | 
			
		||||
            entry["type"] = delta.type
 | 
			
		||||
        func = getattr(delta, "function", None)
 | 
			
		||||
        if func is not None:
 | 
			
		||||
            if getattr(func, "name", None):
 | 
			
		||||
                entry["function"]["name"] = func.name
 | 
			
		||||
            if getattr(func, "arguments", None):
 | 
			
		||||
                entry["function"]["arguments"] += func.arguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _finalize_tool_calls(buffer: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
 | 
			
		||||
    tool_calls: list[dict[str, Any]] = []
 | 
			
		||||
    for idx in sorted(buffer):
 | 
			
		||||
        entry = buffer[idx]
 | 
			
		||||
        function_name = entry.get("function", {}).get("name")
 | 
			
		||||
        if not function_name:
 | 
			
		||||
            continue
 | 
			
		||||
        arguments = entry.get("function", {}).get("arguments", "{}")
 | 
			
		||||
        tool_calls.append(
 | 
			
		||||
            {
 | 
			
		||||
                "id": entry.get("id") or str(uuid.uuid4()),
 | 
			
		||||
                "type": entry.get("type") or "function",
 | 
			
		||||
                "function": {
 | 
			
		||||
                    "name": function_name,
 | 
			
		||||
                    "arguments": arguments,
 | 
			
		||||
                },
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
    return tool_calls
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _execute_tool(function_name: str, arguments_json: str) -> str:
    """Run the named tool with JSON-encoded arguments and return its output.

    Always returns a string (error text on failure) so the result can be
    sent straight back to the model as a ``role="tool"`` message.
    """
    try:
        parsed = json.loads(arguments_json or "{}")
    except json.JSONDecodeError as exc:
        logger.exception("Invalid tool arguments for %s: %s", function_name, exc)
        return f"Invalid arguments for {function_name}: {exc}"

    if function_name != WEB_SEARCH_TOOL_NAME:
        # Only web search is wired up; anything else the model invents
        # gets a polite refusal.
        return f"Function {function_name} is not implemented."

    query = (parsed.get("query") or "").strip()
    if not query:
        return "Web search error: no query provided."
    limit = parsed.get("max_results")
    if isinstance(limit, str):
        limit = int(limit) if limit.isdigit() else None
    elif not isinstance(limit, int):
        limit = None
    return _call_ollama_web_search(query, limit)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _append_tool_messages(
    conversation: list[dict[str, Any]], tool_calls: list[dict[str, Any]]
) -> None:
    """Execute *tool_calls* and append the resulting turns to *conversation*.

    Appends one assistant message carrying the tool calls, followed by one
    ``role="tool"`` message per call holding its execution result — the
    shape the OpenAI chat-completions tool protocol expects.  Mutates
    *conversation* in place; a no-op when *tool_calls* is empty.
    """
    if not tool_calls:
        return
    conversation.append(
        {"role": "assistant", "content": None, "tool_calls": tool_calls}
    )
    for invocation in tool_calls:
        output = _execute_tool(
            invocation["function"]["name"],
            invocation["function"].get("arguments", "{}"),
        )
        conversation.append(
            {
                "role": "tool",
                "tool_call_id": invocation["id"],
                "content": output,
            }
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _stream_chatgpt_pro_response(
    conversation: list[dict[str, Any]],
    reply_id: Message,
    who: str,
    bot: TeleBot,
) -> str:
    """Stream a chat completion to Telegram, resolving tool calls en route.

    Calls the model repeatedly: when a stream ends with pending tool calls,
    the tools are executed, their results appended to *conversation* (which
    is mutated in place), and the model is queried again — at most
    MAX_TOOL_ITERATIONS extra rounds.  Partial text is flushed to the chat
    no more often than every STREAMING_UPDATE_INTERVAL seconds; the final
    text is sent once more with message splitting enabled and returned.

    NOTE(review): despite its name, ``reply_id`` appears to be the
    placeholder Message being edited — confirm against bot_reply_markdown.
    """
    tools = _available_tools()
    # No tools configured -> never enter the tool-resolution loop.
    tool_loops_remaining = MAX_TOOL_ITERATIONS if tools else 0
    final_response = ""
    while True:
        request_payload: dict[str, Any] = {
            "messages": conversation,
            "model": CHATGPT_PRO_MODEL,
            "stream": True,
        }
        if tools:
            request_payload["tools"] = tools

        stream = client.chat.completions.create(**request_payload)
        buffer = ""  # accumulated assistant text for this round
        pending_tool_call = False
        tool_buffer: dict[int, dict[str, Any]] = {}  # index -> partial call
        last_update = time.time()

        for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta is None:
                continue
            if delta.tool_calls:
                # Tool-call fragments arrive piecewise; collect them and
                # skip content handling for this chunk.
                pending_tool_call = True
                _accumulate_tool_call_deltas(tool_buffer, delta.tool_calls)
                continue
            content_piece = delta.content
            if isinstance(content_piece, list):
                # Some providers send content as a list of typed parts.
                content_piece = "".join(
                    getattr(part, "text", "") for part in content_piece
                )
            if not content_piece:
                continue
            buffer += content_piece
            now = time.time()
            # Throttle Telegram edits; suppress them once a tool call is
            # pending so the user does not see half-finished text.
            if not pending_tool_call and now - last_update > STREAMING_UPDATE_INTERVAL:
                last_update = now
                bot_reply_markdown(reply_id, who, buffer, bot, split_text=False)

        if pending_tool_call and tools:
            if tool_loops_remaining <= 0:
                # Budget exhausted: give up on tools and return whatever
                # text we have instead of looping forever.
                logger.warning(
                    "chatgpt_pro_handler reached the maximum number of tool calls"
                )
                final_response = buffer or "Unable to finish after calling tools."
                break
            tool_calls = _finalize_tool_calls(tool_buffer)
            if any(
                call["function"]["name"] == WEB_SEARCH_TOOL_NAME for call in tool_calls
            ):
                # Let the user know why the reply is taking longer.
                bot_reply_markdown(
                    reply_id,
                    who,
                    "Searching the web for up-to-date information…",
                    bot,
                    split_text=False,
                    disable_web_page_preview=True,
                )
            _append_tool_messages(conversation, tool_calls)
            tool_loops_remaining -= 1
            continue  # re-query the model with the tool results appended

        final_response = buffer
        break

    if not final_response:
        final_response = "I could not generate a response."
    # Final flush with splitting enabled for over-length messages.
    bot_reply_markdown(reply_id, who, final_response, bot, split_text=True)
    return final_response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatgpt_handler(message: Message, bot: TeleBot) -> None:
 | 
			
		||||
    """gpt : /gpt <question>"""
 | 
			
		||||
    logger.debug(message)
 | 
			
		||||
@ -125,32 +409,11 @@ def chatgpt_pro_handler(message: Message, bot: TeleBot) -> None:
 | 
			
		||||
        player_message = player_message[2:]
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        r = client.chat.completions.create(
 | 
			
		||||
            messages=player_message,
 | 
			
		||||
            model=CHATGPT_PRO_MODEL,
 | 
			
		||||
            stream=True,
 | 
			
		||||
        )
 | 
			
		||||
        s = ""
 | 
			
		||||
        start = time.time()
 | 
			
		||||
        for chunk in r:
 | 
			
		||||
            logger.debug(chunk)
 | 
			
		||||
            if chunk.choices:
 | 
			
		||||
                if chunk.choices[0].delta.content is None:
 | 
			
		||||
                    break
 | 
			
		||||
                s += chunk.choices[0].delta.content
 | 
			
		||||
                if time.time() - start > 1.2:
 | 
			
		||||
                    start = time.time()
 | 
			
		||||
                    bot_reply_markdown(reply_id, who, s, bot, split_text=False)
 | 
			
		||||
        # maybe not complete
 | 
			
		||||
        try:
 | 
			
		||||
            bot_reply_markdown(reply_id, who, s, bot, split_text=True)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        reply_text = _stream_chatgpt_pro_response(player_message, reply_id, who, bot)
 | 
			
		||||
        player_message.append(
 | 
			
		||||
            {
 | 
			
		||||
                "role": "assistant",
 | 
			
		||||
                "content": markdownify(s),
 | 
			
		||||
                "content": reply_text,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user