feat: support web search in the ChatGPT pro handler via the Ollama Cloud Web Search API

Signed-off-by: Frost Ming <me@frostming.com>
This commit is contained in:
Frost Ming
2025-10-27 11:39:04 +08:00
parent 5bdfe6fbad
commit 791fe0625f
3 changed files with 293 additions and 25 deletions

View File

@ -1,9 +1,12 @@
import json
import time
import uuid
from typing import Any
import requests
from expiringdict import ExpiringDict
from telebot import TeleBot
from telebot.types import Message
from telegramify_markdown import markdownify
from config import settings
@ -15,7 +18,6 @@ from ._utils import (
logger,
)
# Model names for the plain and "pro" handlers; both currently resolve to the
# same configured OpenAI model, but are kept as separate constants so they can
# diverge without touching the handlers.
CHATGPT_MODEL = settings.openai_model
CHATGPT_PRO_MODEL = settings.openai_model
@ -23,11 +25,293 @@ CHATGPT_PRO_MODEL = settings.openai_model
client = settings.openai_client

# Web search / tool-calling configuration
WEB_SEARCH_TOOL_NAME = "web_search"
OLLAMA_WEB_SEARCH_URL = "https://ollama.com/api/web_search"

# OpenAI-style function-tool schema advertised to the model so it can request
# a web search; the call is fulfilled by the Ollama Cloud Web Search API.
WEB_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": WEB_SEARCH_TOOL_NAME,
        "description": (
            "Use the Ollama Cloud Web Search API to fetch recent information"
            " from the public internet. Call this when you need up-to-date"
            " facts, news, or citations."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search keywords or question.",
                },
                "max_results": {
                    "type": "integer",
                    "description": (
                        "Maximum number of search results to fetch; defaults"
                        " to the bot configuration if omitted."
                    ),
                    "minimum": 1,
                    "maximum": 10,
                },
            },
            "required": ["query"],
        },
    },
}

# Minimum seconds between intermediate Telegram message edits while streaming.
STREAMING_UPDATE_INTERVAL = 1.2
# Upper bound on model -> tool -> model round trips for a single request.
MAX_TOOL_ITERATIONS = 3

# Global history cache
chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
def _web_search_available() -> bool:
    """Return True when an Ollama web-search API key is configured."""
    api_key = settings.ollama_web_search_api_key
    return bool(api_key)
def _format_web_search_results(payload: dict[str, Any]) -> str:
results = payload.get("results") or payload.get("data") or []
if not isinstance(results, list):
results = []
formatted: list[str] = []
for idx, item in enumerate(results, start=1):
if not isinstance(item, dict):
continue
title = (
item.get("title") or item.get("name") or item.get("url") or f"Result {idx}"
)
url = item.get("url") or item.get("link") or item.get("source") or ""
snippet = (
item.get("snippet")
or item.get("summary")
or item.get("content")
or item.get("description")
or ""
).strip()
snippet = snippet.replace("\n", " ")
if len(snippet) > 400:
snippet = snippet[:397].rstrip() + "..."
entry = f"[{idx}] {title}"
if url:
entry = f"{entry}\n{url}"
if snippet:
entry = f"{entry}\n{snippet}"
formatted.append(entry)
if formatted:
return "\n\n".join(formatted)
return json.dumps(payload, ensure_ascii=False)
def _call_ollama_web_search(query: str, max_results: int | None = None) -> str:
    """Run *query* against the Ollama Cloud Web Search API.

    Returns formatted result text on success, or a short human-readable
    error string on any failure (never raises), so the outcome can be fed
    straight back to the model as a tool message. A non-positive or missing
    *max_results* falls back to the configured default.
    """
    if not _web_search_available():
        return "Web search is not configured."

    body: dict[str, Any] = {"query": query.strip()}
    effective_limit = max_results if isinstance(max_results, int) else None
    if effective_limit is None or effective_limit <= 0:
        effective_limit = settings.ollama_web_search_max_results
    if effective_limit:
        body["max_results"] = int(effective_limit)

    auth_headers = {
        "Authorization": f"Bearer {settings.ollama_web_search_api_key}",
    }
    try:
        resp = requests.post(
            OLLAMA_WEB_SEARCH_URL,
            json=body,
            headers=auth_headers,
            timeout=settings.ollama_web_search_timeout,
        )
        resp.raise_for_status()
        parsed = resp.json()
    except requests.RequestException as exc:
        logger.exception("Ollama web search failed: %s", exc)
        return f"Web search error: {exc}"
    except ValueError:
        logger.exception("Invalid JSON payload from Ollama web search")
        return "Web search error: invalid payload."
    return _format_web_search_results(parsed)
def _available_tools() -> list[dict[str, Any]]:
    """Tool definitions to advertise to the model; empty when search is off."""
    return [WEB_SEARCH_TOOL] if _web_search_available() else []
def _accumulate_tool_call_deltas(
buffer: dict[int, dict[str, Any]],
deltas: list[Any],
) -> None:
for delta in deltas:
idx = getattr(delta, "index", 0) or 0
entry = buffer.setdefault(
idx,
{
"id": getattr(delta, "id", None),
"type": getattr(delta, "type", "function") or "function",
"function": {"name": "", "arguments": ""},
},
)
if getattr(delta, "id", None):
entry["id"] = delta.id
if getattr(delta, "type", None):
entry["type"] = delta.type
func = getattr(delta, "function", None)
if func is not None:
if getattr(func, "name", None):
entry["function"]["name"] = func.name
if getattr(func, "arguments", None):
entry["function"]["arguments"] += func.arguments
def _finalize_tool_calls(buffer: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
tool_calls: list[dict[str, Any]] = []
for idx in sorted(buffer):
entry = buffer[idx]
function_name = entry.get("function", {}).get("name")
if not function_name:
continue
arguments = entry.get("function", {}).get("arguments", "{}")
tool_calls.append(
{
"id": entry.get("id") or str(uuid.uuid4()),
"type": entry.get("type") or "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
)
return tool_calls
def _execute_tool(function_name: str, arguments_json: str) -> str:
    """Dispatch a model-requested tool call and return its textual result.

    Only the web-search tool is implemented. Malformed JSON arguments, a
    missing query, or an unknown tool name all produce an explanatory error
    string instead of raising, so the model always receives a tool reply.
    """
    try:
        args = json.loads(arguments_json or "{}")
    except json.JSONDecodeError as exc:
        logger.exception("Invalid tool arguments for %s: %s", function_name, exc)
        return f"Invalid arguments for {function_name}: {exc}"

    if function_name != WEB_SEARCH_TOOL_NAME:
        return f"Function {function_name} is not implemented."

    query = (args.get("query") or "").strip()
    if not query:
        return "Web search error: no query provided."

    raw_limit = args.get("max_results")
    if isinstance(raw_limit, str):
        limit = int(raw_limit) if raw_limit.isdigit() else None
    elif isinstance(raw_limit, int):
        limit = raw_limit
    else:
        limit = None
    return _call_ollama_web_search(query, limit)
def _append_tool_messages(
conversation: list[dict[str, Any]], tool_calls: list[dict[str, Any]]
) -> None:
if not tool_calls:
return
conversation.append(
{
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
}
)
for call in tool_calls:
result = _execute_tool(
call["function"]["name"], call["function"].get("arguments", "{}")
)
conversation.append(
{
"role": "tool",
"tool_call_id": call["id"],
"content": result,
}
)
def _stream_chatgpt_pro_response(
    conversation: list[dict[str, Any]],
    reply_id: Message,
    who: str,
    bot: TeleBot,
) -> str:
    """Stream a ChatGPT-pro completion to Telegram, resolving tool calls.

    Repeatedly requests a streamed completion for *conversation*; whenever
    the model emits tool calls, they are executed and their results appended
    to the conversation before asking again, up to MAX_TOOL_ITERATIONS
    rounds. Partial text is flushed to the chat at most once per
    STREAMING_UPDATE_INTERVAL seconds; the full text is sent once complete
    and returned so the caller can record it in the chat history.
    """
    tools = _available_tools()
    # Tool round trips are only budgeted when search is actually configured.
    tool_loops_remaining = MAX_TOOL_ITERATIONS if tools else 0
    final_response = ""
    while True:
        request_payload: dict[str, Any] = {
            "messages": conversation,
            "model": CHATGPT_PRO_MODEL,
            "stream": True,
        }
        if tools:
            request_payload["tools"] = tools
        stream = client.chat.completions.create(**request_payload)
        buffer = ""
        pending_tool_call = False
        tool_buffer: dict[int, dict[str, Any]] = {}
        last_update = time.time()
        for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta is None:
                continue
            if delta.tool_calls:
                # Collect tool-call fragments; this chunk carries no text.
                pending_tool_call = True
                _accumulate_tool_call_deltas(tool_buffer, delta.tool_calls)
                continue
            content_piece = delta.content
            if isinstance(content_piece, list):
                # Some providers stream content as a list of typed parts;
                # join their text fields into one string.
                content_piece = "".join(
                    getattr(part, "text", "") for part in content_piece
                )
            if not content_piece:
                continue
            buffer += content_piece
            now = time.time()
            # Throttle intermediate Telegram edits to avoid rate limiting;
            # suppress them entirely once a tool call is pending.
            if not pending_tool_call and now - last_update > STREAMING_UPDATE_INTERVAL:
                last_update = now
                bot_reply_markdown(reply_id, who, buffer, bot, split_text=False)
        if pending_tool_call and tools:
            if tool_loops_remaining <= 0:
                # Budget exhausted: give up on tools and return whatever
                # text (if any) the model produced so far.
                logger.warning(
                    "chatgpt_pro_handler reached the maximum number of tool calls"
                )
                final_response = buffer or "Unable to finish after calling tools."
                break
            tool_calls = _finalize_tool_calls(tool_buffer)
            if any(
                call["function"]["name"] == WEB_SEARCH_TOOL_NAME for call in tool_calls
            ):
                # Let the user know why the reply is taking longer.
                bot_reply_markdown(
                    reply_id,
                    who,
                    "Searching the web for up-to-date information…",
                    bot,
                    split_text=False,
                    disable_web_page_preview=True,
                )
            _append_tool_messages(conversation, tool_calls)
            tool_loops_remaining -= 1
            continue
        final_response = buffer
        break
    if not final_response:
        final_response = "I could not generate a response."
    # Final (possibly split) message with the complete text.
    bot_reply_markdown(reply_id, who, final_response, bot, split_text=True)
    return final_response
def chatgpt_handler(message: Message, bot: TeleBot) -> None:
"""gpt : /gpt <question>"""
logger.debug(message)
@ -125,32 +409,11 @@ def chatgpt_pro_handler(message: Message, bot: TeleBot) -> None:
player_message = player_message[2:]
try:
r = client.chat.completions.create(
messages=player_message,
model=CHATGPT_PRO_MODEL,
stream=True,
)
s = ""
start = time.time()
for chunk in r:
logger.debug(chunk)
if chunk.choices:
if chunk.choices[0].delta.content is None:
break
s += chunk.choices[0].delta.content
if time.time() - start > 1.2:
start = time.time()
bot_reply_markdown(reply_id, who, s, bot, split_text=False)
# maybe not complete
try:
bot_reply_markdown(reply_id, who, s, bot, split_text=True)
except Exception:
pass
reply_text = _stream_chatgpt_pro_response(player_message, reply_id, who, bot)
player_message.append(
{
"role": "assistant",
"content": markdownify(s),
"content": reply_text,
}
)