mirror of
https://github.com/cdryzun/tg_bot_collections.git
synced 2025-11-03 16:16:45 +08:00
Merge pull request #61 from yihong0618/feat/web-search
feat: support web search via ollama cloud web search for ChatGPT pro handler
This commit is contained in:
@ -63,6 +63,8 @@ Note, if you are using third party service, you need to `export ANTHROPIC_BASE_U
|
||||
3. use `gpt: ${message}` to ask
|
||||
|
||||
Note, if you are using third party service, you need to `export OPENAI_API_BASE=${the_url}` to change the url.
|
||||
Optional web search support:
|
||||
- export `OLLAMA_WEB_SEARCH_API_KEY=${the_ollama_web_search_api_key}` (and `OLLAMA_WEB_SEARCH_MAX_RESULTS` as needed)
|
||||
|
||||
## Bot -> llama3
|
||||
|
||||
|
||||
@ -17,6 +17,9 @@ class Settings(BaseSettings):
|
||||
google_gemini_api_key: str | None = None
|
||||
anthropic_api_key: str | None = None
|
||||
telegra_ph_token: str | None = None
|
||||
ollama_web_search_api_key: str | None = None
|
||||
ollama_web_search_max_results: int = 5
|
||||
ollama_web_search_timeout: int = 10
|
||||
|
||||
@cached_property
|
||||
def openai_client(self) -> openai.OpenAI:
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from expiringdict import ExpiringDict
|
||||
from telebot import TeleBot
|
||||
from telebot.types import Message
|
||||
from telegramify_markdown import markdownify
|
||||
|
||||
from config import settings
|
||||
|
||||
@ -15,7 +18,6 @@ from ._utils import (
|
||||
logger,
|
||||
)
|
||||
|
||||
|
||||
CHATGPT_MODEL = settings.openai_model
|
||||
CHATGPT_PRO_MODEL = settings.openai_model
|
||||
|
||||
@ -23,11 +25,301 @@ CHATGPT_PRO_MODEL = settings.openai_model
|
||||
client = settings.openai_client
|
||||
|
||||
|
||||
# Web search / tool-calling configuration
|
||||
WEB_SEARCH_TOOL_NAME = "web_search"
|
||||
WEB_SEARCH_SYSTEM_PROMPT = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that uses the Ollama Cloud Web Search API to fetch recent information "
|
||||
"from the public internet when needed. Always cite your sources using the format [number](URL) in your responses.\n\n",
|
||||
}
|
||||
OLLAMA_WEB_SEARCH_URL = "https://ollama.com/api/web_search"
|
||||
WEB_SEARCH_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": WEB_SEARCH_TOOL_NAME,
|
||||
"description": (
|
||||
"Use the Ollama Cloud Web Search API to fetch recent information"
|
||||
" from the public internet. Call this when you need up-to-date"
|
||||
" facts, news, or citations."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search keywords or question.",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of search results to fetch; defaults"
|
||||
" to the bot configuration if omitted."
|
||||
),
|
||||
"minimum": 1,
|
||||
"maximum": 10,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
STREAMING_UPDATE_INTERVAL = 1.2
|
||||
MAX_TOOL_ITERATIONS = 3
|
||||
|
||||
|
||||
# Global history cache
|
||||
chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
|
||||
chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
|
||||
|
||||
|
||||
def _web_search_available() -> bool:
|
||||
return bool(settings.ollama_web_search_api_key)
|
||||
|
||||
|
||||
def _format_web_search_results(payload: dict[str, Any]) -> str:
|
||||
results = payload.get("results") or payload.get("data") or []
|
||||
if not isinstance(results, list):
|
||||
results = []
|
||||
formatted: list[str] = []
|
||||
for idx, item in enumerate(results, start=1):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
title = (
|
||||
item.get("title") or item.get("name") or item.get("url") or f"Result {idx}"
|
||||
)
|
||||
url = item.get("url") or item.get("link") or item.get("source") or ""
|
||||
snippet = (
|
||||
item.get("snippet")
|
||||
or item.get("summary")
|
||||
or item.get("content")
|
||||
or item.get("description")
|
||||
or ""
|
||||
).strip()
|
||||
snippet = snippet.replace("\n", " ")
|
||||
if len(snippet) > 400:
|
||||
snippet = snippet[:397].rstrip() + "..."
|
||||
entry = f"[{idx}] {title}"
|
||||
if url:
|
||||
entry = f"{entry}\nURL: {url}"
|
||||
if snippet:
|
||||
entry = f"{entry}\n{snippet}"
|
||||
formatted.append(entry)
|
||||
if formatted:
|
||||
return "\n\n".join(formatted)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
def _call_ollama_web_search(query: str, max_results: int | None = None) -> str:
|
||||
if not _web_search_available():
|
||||
return "Web search is not configured."
|
||||
payload: dict[str, Any] = {"query": query.strip()}
|
||||
limit = max_results if isinstance(max_results, int) else None
|
||||
if limit is None or limit <= 0:
|
||||
limit = settings.ollama_web_search_max_results
|
||||
if limit:
|
||||
payload["max_results"] = int(limit)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {settings.ollama_web_search_api_key}",
|
||||
}
|
||||
try:
|
||||
response = requests.post(
|
||||
OLLAMA_WEB_SEARCH_URL,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=settings.ollama_web_search_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except requests.RequestException as exc:
|
||||
logger.exception("Ollama web search failed: %s", exc)
|
||||
return f"Web search error: {exc}"
|
||||
except ValueError:
|
||||
logger.exception("Invalid JSON payload from Ollama web search")
|
||||
return "Web search error: invalid payload."
|
||||
return _format_web_search_results(data)
|
||||
|
||||
|
||||
def _available_tools() -> list[dict[str, Any]]:
|
||||
if not _web_search_available():
|
||||
return []
|
||||
return [WEB_SEARCH_TOOL]
|
||||
|
||||
|
||||
def _accumulate_tool_call_deltas(
|
||||
buffer: dict[int, dict[str, Any]],
|
||||
deltas: list[Any],
|
||||
) -> None:
|
||||
for delta in deltas:
|
||||
idx = getattr(delta, "index", 0) or 0
|
||||
entry = buffer.setdefault(
|
||||
idx,
|
||||
{
|
||||
"id": getattr(delta, "id", None),
|
||||
"type": getattr(delta, "type", "function") or "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
},
|
||||
)
|
||||
if getattr(delta, "id", None):
|
||||
entry["id"] = delta.id
|
||||
if getattr(delta, "type", None):
|
||||
entry["type"] = delta.type
|
||||
func = getattr(delta, "function", None)
|
||||
if func is not None:
|
||||
if getattr(func, "name", None):
|
||||
entry["function"]["name"] = func.name
|
||||
if getattr(func, "arguments", None):
|
||||
entry["function"]["arguments"] += func.arguments
|
||||
|
||||
|
||||
def _finalize_tool_calls(buffer: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for idx in sorted(buffer):
|
||||
entry = buffer[idx]
|
||||
function_name = entry.get("function", {}).get("name")
|
||||
if not function_name:
|
||||
continue
|
||||
arguments = entry.get("function", {}).get("arguments", "{}")
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": entry.get("id") or str(uuid.uuid4()),
|
||||
"type": entry.get("type") or "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _execute_tool(function_name: str, arguments_json: str) -> str:
|
||||
try:
|
||||
arguments = json.loads(arguments_json or "{}")
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception("Invalid tool arguments for %s: %s", function_name, exc)
|
||||
return f"Invalid arguments for {function_name}: {exc}"
|
||||
|
||||
if function_name == WEB_SEARCH_TOOL_NAME:
|
||||
query = (arguments.get("query") or "").strip()
|
||||
if not query:
|
||||
return "Web search error: no query provided."
|
||||
max_results = arguments.get("max_results")
|
||||
if isinstance(max_results, str):
|
||||
max_results = int(max_results) if max_results.isdigit() else None
|
||||
elif not isinstance(max_results, int):
|
||||
max_results = None
|
||||
return _call_ollama_web_search(query, max_results)
|
||||
|
||||
return f"Function {function_name} is not implemented."
|
||||
|
||||
|
||||
def _append_tool_messages(
|
||||
conversation: list[dict[str, Any]], tool_calls: list[dict[str, Any]]
|
||||
) -> None:
|
||||
if not tool_calls:
|
||||
return
|
||||
conversation.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
)
|
||||
for call in tool_calls:
|
||||
result = _execute_tool(
|
||||
call["function"]["name"], call["function"].get("arguments", "{}")
|
||||
)
|
||||
conversation.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": call["id"],
|
||||
"content": result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _stream_chatgpt_pro_response(
|
||||
conversation: list[dict[str, Any]],
|
||||
reply_id: Message,
|
||||
who: str,
|
||||
bot: TeleBot,
|
||||
) -> str:
|
||||
tools = _available_tools()
|
||||
tool_loops_remaining = MAX_TOOL_ITERATIONS if tools else 0
|
||||
final_response = ""
|
||||
if tools:
|
||||
conversation.insert(0, WEB_SEARCH_SYSTEM_PROMPT)
|
||||
while True:
|
||||
request_payload: dict[str, Any] = {
|
||||
"messages": conversation,
|
||||
"model": CHATGPT_PRO_MODEL,
|
||||
"stream": True,
|
||||
}
|
||||
if tools:
|
||||
request_payload.update(tools=tools, tool_choice="auto")
|
||||
|
||||
stream = client.chat.completions.create(**request_payload)
|
||||
buffer = ""
|
||||
pending_tool_call = False
|
||||
tool_buffer: dict[int, dict[str, Any]] = {}
|
||||
last_update = time.time()
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if delta is None:
|
||||
continue
|
||||
if delta.tool_calls:
|
||||
pending_tool_call = True
|
||||
_accumulate_tool_call_deltas(tool_buffer, delta.tool_calls)
|
||||
continue
|
||||
content_piece = delta.content
|
||||
if isinstance(content_piece, list):
|
||||
content_piece = "".join(
|
||||
getattr(part, "text", "") for part in content_piece
|
||||
)
|
||||
if not content_piece:
|
||||
continue
|
||||
buffer += content_piece
|
||||
now = time.time()
|
||||
if not pending_tool_call and now - last_update > STREAMING_UPDATE_INTERVAL:
|
||||
last_update = now
|
||||
bot_reply_markdown(reply_id, who, buffer, bot, split_text=False)
|
||||
|
||||
if pending_tool_call and tools:
|
||||
if tool_loops_remaining <= 0:
|
||||
logger.warning(
|
||||
"chatgpt_pro_handler reached the maximum number of tool calls"
|
||||
)
|
||||
final_response = buffer or "Unable to finish after calling tools."
|
||||
break
|
||||
tool_calls = _finalize_tool_calls(tool_buffer)
|
||||
if any(
|
||||
call["function"]["name"] == WEB_SEARCH_TOOL_NAME for call in tool_calls
|
||||
):
|
||||
bot_reply_markdown(
|
||||
reply_id,
|
||||
who,
|
||||
"Searching the web for up-to-date information…",
|
||||
bot,
|
||||
split_text=False,
|
||||
disable_web_page_preview=True,
|
||||
)
|
||||
_append_tool_messages(conversation, tool_calls)
|
||||
tool_loops_remaining -= 1
|
||||
continue
|
||||
|
||||
final_response = buffer
|
||||
break
|
||||
|
||||
if not final_response:
|
||||
final_response = "I could not generate a response."
|
||||
bot_reply_markdown(reply_id, who, final_response, bot, split_text=True)
|
||||
return final_response
|
||||
|
||||
|
||||
def chatgpt_handler(message: Message, bot: TeleBot) -> None:
|
||||
"""gpt : /gpt <question>"""
|
||||
logger.debug(message)
|
||||
@ -125,32 +417,11 @@ def chatgpt_pro_handler(message: Message, bot: TeleBot) -> None:
|
||||
player_message = player_message[2:]
|
||||
|
||||
try:
|
||||
r = client.chat.completions.create(
|
||||
messages=player_message,
|
||||
model=CHATGPT_PRO_MODEL,
|
||||
stream=True,
|
||||
)
|
||||
s = ""
|
||||
start = time.time()
|
||||
for chunk in r:
|
||||
logger.debug(chunk)
|
||||
if chunk.choices:
|
||||
if chunk.choices[0].delta.content is None:
|
||||
break
|
||||
s += chunk.choices[0].delta.content
|
||||
if time.time() - start > 1.2:
|
||||
start = time.time()
|
||||
bot_reply_markdown(reply_id, who, s, bot, split_text=False)
|
||||
# maybe not complete
|
||||
try:
|
||||
bot_reply_markdown(reply_id, who, s, bot, split_text=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reply_text = _stream_chatgpt_pro_response(player_message[:], reply_id, who, bot)
|
||||
player_message.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": markdownify(s),
|
||||
"content": reply_text,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user