Merge pull request #61 from yihong0618/feat/web-search

feat: support web search via ollama cloud web search for ChatGPT pro handler
This commit is contained in:
yihong
2025-10-27 19:15:47 +08:00
committed by GitHub
3 changed files with 301 additions and 25 deletions

View File

@ -63,6 +63,8 @@ Note, if you are using third party service, you need to `export ANTHROPIC_BASE_U
3. use `gpt: ${message}` to ask 3. use `gpt: ${message}` to ask
Note, if you are using third party service, you need to `export OPENAI_API_BASE=${the_url}` to change the url. Note, if you are using third party service, you need to `export OPENAI_API_BASE=${the_url}` to change the url.
Optional web search support:
- export `OLLAMA_WEB_SEARCH_API_KEY=${the_ollama_web_search_api_key}` (optionally set `OLLAMA_WEB_SEARCH_MAX_RESULTS` and `OLLAMA_WEB_SEARCH_TIMEOUT` as needed)
## Bot -> llama3 ## Bot -> llama3

View File

@ -17,6 +17,9 @@ class Settings(BaseSettings):
google_gemini_api_key: str | None = None google_gemini_api_key: str | None = None
anthropic_api_key: str | None = None anthropic_api_key: str | None = None
telegra_ph_token: str | None = None telegra_ph_token: str | None = None
ollama_web_search_api_key: str | None = None
ollama_web_search_max_results: int = 5
ollama_web_search_timeout: int = 10
@cached_property @cached_property
def openai_client(self) -> openai.OpenAI: def openai_client(self) -> openai.OpenAI:

View File

@ -1,9 +1,12 @@
import json
import time import time
import uuid
from typing import Any
import requests
from expiringdict import ExpiringDict from expiringdict import ExpiringDict
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from telegramify_markdown import markdownify
from config import settings from config import settings
@ -15,7 +18,6 @@ from ._utils import (
logger, logger,
) )
CHATGPT_MODEL = settings.openai_model CHATGPT_MODEL = settings.openai_model
CHATGPT_PRO_MODEL = settings.openai_model CHATGPT_PRO_MODEL = settings.openai_model
@ -23,11 +25,301 @@ CHATGPT_PRO_MODEL = settings.openai_model
client = settings.openai_client client = settings.openai_client
# Web search / tool-calling configuration
# Function name the model must emit to request a search via tool calling.
WEB_SEARCH_TOOL_NAME = "web_search"
# System message prepended to the conversation when web search is enabled;
# instructs the model to cite sources as [number](URL).
WEB_SEARCH_SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are a helpful assistant that uses the Ollama Cloud Web Search API to fetch recent information "
    "from the public internet when needed. Always cite your sources using the format [number](URL) in your responses.\n\n",
}
# Ollama Cloud endpoint POSTed to by _call_ollama_web_search.
OLLAMA_WEB_SEARCH_URL = "https://ollama.com/api/web_search"
# OpenAI-style function-calling schema advertised to the model.
WEB_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": WEB_SEARCH_TOOL_NAME,
        "description": (
            "Use the Ollama Cloud Web Search API to fetch recent information"
            " from the public internet. Call this when you need up-to-date"
            " facts, news, or citations."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search keywords or question.",
                },
                "max_results": {
                    "type": "integer",
                    "description": (
                        "Maximum number of search results to fetch; defaults"
                        " to the bot configuration if omitted."
                    ),
                    "minimum": 1,
                    "maximum": 10,
                },
            },
            "required": ["query"],
        },
    },
}
# Minimum seconds between Telegram message edits while streaming partial text.
STREAMING_UPDATE_INTERVAL = 1.2
# Maximum rounds of tool calls before the reply is forced to finish.
MAX_TOOL_ITERATIONS = 3
# Global history cache # Global history cache
chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600) chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600) chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
def _web_search_available() -> bool:
    """Report whether an Ollama web-search API key has been configured."""
    api_key = settings.ollama_web_search_api_key
    return api_key is not None and api_key != ""
def _format_web_search_results(payload: dict[str, Any]) -> str:
results = payload.get("results") or payload.get("data") or []
if not isinstance(results, list):
results = []
formatted: list[str] = []
for idx, item in enumerate(results, start=1):
if not isinstance(item, dict):
continue
title = (
item.get("title") or item.get("name") or item.get("url") or f"Result {idx}"
)
url = item.get("url") or item.get("link") or item.get("source") or ""
snippet = (
item.get("snippet")
or item.get("summary")
or item.get("content")
or item.get("description")
or ""
).strip()
snippet = snippet.replace("\n", " ")
if len(snippet) > 400:
snippet = snippet[:397].rstrip() + "..."
entry = f"[{idx}] {title}"
if url:
entry = f"{entry}\nURL: {url}"
if snippet:
entry = f"{entry}\n{snippet}"
formatted.append(entry)
if formatted:
return "\n\n".join(formatted)
return json.dumps(payload, ensure_ascii=False)
def _call_ollama_web_search(query: str, max_results: int | None = None) -> str:
    """POST *query* to the Ollama Cloud web-search endpoint and format the hits.

    A non-positive or non-int *max_results* falls back to the configured
    default.  Network/HTTP failures and undecodable JSON are logged and
    reported as an error string rather than raised.
    """
    if not _web_search_available():
        return "Web search is not configured."
    body: dict[str, Any] = {"query": query.strip()}
    limit = max_results if isinstance(max_results, int) else None
    if limit is None or limit <= 0:
        limit = settings.ollama_web_search_max_results
    if limit:
        body["max_results"] = int(limit)
    try:
        resp = requests.post(
            OLLAMA_WEB_SEARCH_URL,
            json=body,
            headers={
                "Authorization": f"Bearer {settings.ollama_web_search_api_key}",
            },
            timeout=settings.ollama_web_search_timeout,
        )
        resp.raise_for_status()
        data = resp.json()
    except requests.RequestException as exc:
        logger.exception("Ollama web search failed: %s", exc)
        return f"Web search error: {exc}"
    except ValueError:
        # resp.json() raises ValueError on bodies that are not valid JSON.
        logger.exception("Invalid JSON payload from Ollama web search")
        return "Web search error: invalid payload."
    return _format_web_search_results(data)
def _available_tools() -> list[dict[str, Any]]:
    """Return tool schemas to advertise to the model; empty when unconfigured."""
    return [WEB_SEARCH_TOOL] if _web_search_available() else []
def _accumulate_tool_call_deltas(
buffer: dict[int, dict[str, Any]],
deltas: list[Any],
) -> None:
for delta in deltas:
idx = getattr(delta, "index", 0) or 0
entry = buffer.setdefault(
idx,
{
"id": getattr(delta, "id", None),
"type": getattr(delta, "type", "function") or "function",
"function": {"name": "", "arguments": ""},
},
)
if getattr(delta, "id", None):
entry["id"] = delta.id
if getattr(delta, "type", None):
entry["type"] = delta.type
func = getattr(delta, "function", None)
if func is not None:
if getattr(func, "name", None):
entry["function"]["name"] = func.name
if getattr(func, "arguments", None):
entry["function"]["arguments"] += func.arguments
def _finalize_tool_calls(buffer: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
tool_calls: list[dict[str, Any]] = []
for idx in sorted(buffer):
entry = buffer[idx]
function_name = entry.get("function", {}).get("name")
if not function_name:
continue
arguments = entry.get("function", {}).get("arguments", "{}")
tool_calls.append(
{
"id": entry.get("id") or str(uuid.uuid4()),
"type": entry.get("type") or "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
)
return tool_calls
def _execute_tool(function_name: str, arguments_json: str) -> str:
    """Dispatch one model-requested tool call and return its textual result.

    Only the web-search tool is implemented; malformed JSON arguments or an
    unknown function name produce an error string instead of raising.
    """
    try:
        args = json.loads(arguments_json or "{}")
    except json.JSONDecodeError as exc:
        logger.exception("Invalid tool arguments for %s: %s", function_name, exc)
        return f"Invalid arguments for {function_name}: {exc}"
    if function_name != WEB_SEARCH_TOOL_NAME:
        return f"Function {function_name} is not implemented."
    query = (args.get("query") or "").strip()
    if not query:
        return "Web search error: no query provided."
    raw_limit = args.get("max_results")
    if isinstance(raw_limit, str):
        # The model sometimes sends numbers as strings; coerce when numeric.
        limit = int(raw_limit) if raw_limit.isdigit() else None
    elif isinstance(raw_limit, int):
        limit = raw_limit
    else:
        limit = None
    return _call_ollama_web_search(query, limit)
def _append_tool_messages(
    conversation: list[dict[str, Any]], tool_calls: list[dict[str, Any]]
) -> None:
    """Record the assistant's tool calls and each tool's output on *conversation*.

    Appends one assistant message carrying *tool_calls*, then one ``tool``
    message per call with the execution result, as the chat-completions
    tool-calling protocol requires.  No-op for an empty call list.
    """
    if not tool_calls:
        return
    conversation.append(
        {"role": "assistant", "content": None, "tool_calls": tool_calls}
    )
    for tool_call in tool_calls:
        function = tool_call["function"]
        output = _execute_tool(function["name"], function.get("arguments", "{}"))
        conversation.append(
            {
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": output,
            }
        )
def _stream_chatgpt_pro_response(
    conversation: list[dict[str, Any]],
    reply_id: Message,
    who: str,
    bot: TeleBot,
) -> str:
    """Stream a chat completion for *conversation*, resolving tool calls.

    Streams partial assistant text into the Telegram message *reply_id*
    (throttled by STREAMING_UPDATE_INTERVAL).  When the model emits tool
    calls and web search is configured, the calls are executed, their
    results appended to *conversation*, and the model is re-queried — up to
    MAX_TOOL_ITERATIONS rounds.  Mutates *conversation* in place (system
    prompt insertion, tool messages).  Returns the final assistant text.
    """
    tools = _available_tools()
    # Tool rounds are only budgeted when web search is actually configured.
    tool_loops_remaining = MAX_TOOL_ITERATIONS if tools else 0
    final_response = ""
    if tools:
        # Prepend the citation-instructing system prompt (mutates the list).
        conversation.insert(0, WEB_SEARCH_SYSTEM_PROMPT)
    while True:
        request_payload: dict[str, Any] = {
            "messages": conversation,
            "model": CHATGPT_PRO_MODEL,
            "stream": True,
        }
        if tools:
            request_payload.update(tools=tools, tool_choice="auto")
        stream = client.chat.completions.create(**request_payload)
        buffer = ""  # assistant text accumulated during this round
        pending_tool_call = False
        tool_buffer: dict[int, dict[str, Any]] = {}  # index -> partial call
        last_update = time.time()
        for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta is None:
                continue
            if delta.tool_calls:
                # Tool-call fragments arrive incrementally; merge them and
                # stop treating this round's output as user-facing text.
                pending_tool_call = True
                _accumulate_tool_call_deltas(tool_buffer, delta.tool_calls)
                continue
            content_piece = delta.content
            if isinstance(content_piece, list):
                # Some backends stream content as a list of text parts;
                # flatten to a single string.
                content_piece = "".join(
                    getattr(part, "text", "") for part in content_piece
                )
            if not content_piece:
                continue
            buffer += content_piece
            now = time.time()
            # Throttle Telegram message edits to avoid hitting rate limits.
            if not pending_tool_call and now - last_update > STREAMING_UPDATE_INTERVAL:
                last_update = now
                bot_reply_markdown(reply_id, who, buffer, bot, split_text=False)
        if pending_tool_call and tools:
            if tool_loops_remaining <= 0:
                # Budget exhausted: return whatever text we have rather
                # than looping on further tool requests.
                logger.warning(
                    "chatgpt_pro_handler reached the maximum number of tool calls"
                )
                final_response = buffer or "Unable to finish after calling tools."
                break
            tool_calls = _finalize_tool_calls(tool_buffer)
            if any(
                call["function"]["name"] == WEB_SEARCH_TOOL_NAME for call in tool_calls
            ):
                # Tell the user why the reply is paused while we search.
                bot_reply_markdown(
                    reply_id,
                    who,
                    "Searching the web for up-to-date information…",
                    bot,
                    split_text=False,
                    disable_web_page_preview=True,
                )
            _append_tool_messages(conversation, tool_calls)
            tool_loops_remaining -= 1
            continue  # re-query the model with the tool results appended
        final_response = buffer
        break
    if not final_response:
        final_response = "I could not generate a response."
    # Final render, with splitting enabled for replies over Telegram's limit.
    bot_reply_markdown(reply_id, who, final_response, bot, split_text=True)
    return final_response
def chatgpt_handler(message: Message, bot: TeleBot) -> None: def chatgpt_handler(message: Message, bot: TeleBot) -> None:
"""gpt : /gpt <question>""" """gpt : /gpt <question>"""
logger.debug(message) logger.debug(message)
@ -125,32 +417,11 @@ def chatgpt_pro_handler(message: Message, bot: TeleBot) -> None:
player_message = player_message[2:] player_message = player_message[2:]
try: try:
r = client.chat.completions.create( reply_text = _stream_chatgpt_pro_response(player_message[:], reply_id, who, bot)
messages=player_message,
model=CHATGPT_PRO_MODEL,
stream=True,
)
s = ""
start = time.time()
for chunk in r:
logger.debug(chunk)
if chunk.choices:
if chunk.choices[0].delta.content is None:
break
s += chunk.choices[0].delta.content
if time.time() - start > 1.2:
start = time.time()
bot_reply_markdown(reply_id, who, s, bot, split_text=False)
# maybe not complete
try:
bot_reply_markdown(reply_id, who, s, bot, split_text=True)
except Exception:
pass
player_message.append( player_message.append(
{ {
"role": "assistant", "role": "assistant",
"content": markdownify(s), "content": reply_text,
} }
) )