mirror of
https://github.com/cdryzun/tg_bot_collections.git
synced 2025-11-04 00:28:47 +08:00
feat: support web search via ollama cloud web search for ChatGPT pro handler
Signed-off-by: Frost Ming <me@frostming.com>
This commit is contained in:
@ -63,6 +63,8 @@ Note, if you are using third party service, you need to `export ANTHROPIC_BASE_U
|
|||||||
3. use `gpt: ${message}` to ask
|
3. use `gpt: ${message}` to ask
|
||||||
|
|
||||||
Note, if you are using third party service, you need to `export OPENAI_API_BASE=${the_url}` to change the url.
|
Note, if you are using third party service, you need to `export OPENAI_API_BASE=${the_url}` to change the url.
|
||||||
|
Optional web search support:
|
||||||
|
- export `OLLAMA_WEB_SEARCH_API_KEY=${the_ollama_web_search_api_key}` (and `OLLAMA_WEB_SEARCH_MAX_RESULTS` as needed)
|
||||||
|
|
||||||
## Bot -> llama3
|
## Bot -> llama3
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,9 @@ class Settings(BaseSettings):
|
|||||||
google_gemini_api_key: str | None = None
|
google_gemini_api_key: str | None = None
|
||||||
anthropic_api_key: str | None = None
|
anthropic_api_key: str | None = None
|
||||||
telegra_ph_token: str | None = None
|
telegra_ph_token: str | None = None
|
||||||
|
ollama_web_search_api_key: str | None = None
|
||||||
|
ollama_web_search_max_results: int = 5
|
||||||
|
ollama_web_search_timeout: int = 10
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def openai_client(self) -> openai.OpenAI:
|
def openai_client(self) -> openai.OpenAI:
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
|
import json
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
from expiringdict import ExpiringDict
|
from expiringdict import ExpiringDict
|
||||||
from telebot import TeleBot
|
from telebot import TeleBot
|
||||||
from telebot.types import Message
|
from telebot.types import Message
|
||||||
from telegramify_markdown import markdownify
|
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
@ -15,7 +18,6 @@ from ._utils import (
|
|||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
CHATGPT_MODEL = settings.openai_model
|
CHATGPT_MODEL = settings.openai_model
|
||||||
CHATGPT_PRO_MODEL = settings.openai_model
|
CHATGPT_PRO_MODEL = settings.openai_model
|
||||||
|
|
||||||
@ -23,11 +25,293 @@ CHATGPT_PRO_MODEL = settings.openai_model
|
|||||||
client = settings.openai_client
|
client = settings.openai_client
|
||||||
|
|
||||||
|
|
||||||
|
# Web search / tool-calling configuration
|
||||||
|
WEB_SEARCH_TOOL_NAME = "web_search"
|
||||||
|
OLLAMA_WEB_SEARCH_URL = "https://ollama.com/api/web_search"
|
||||||
|
WEB_SEARCH_TOOL = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": WEB_SEARCH_TOOL_NAME,
|
||||||
|
"description": (
|
||||||
|
"Use the Ollama Cloud Web Search API to fetch recent information"
|
||||||
|
" from the public internet. Call this when you need up-to-date"
|
||||||
|
" facts, news, or citations."
|
||||||
|
),
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search keywords or question.",
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": (
|
||||||
|
"Maximum number of search results to fetch; defaults"
|
||||||
|
" to the bot configuration if omitted."
|
||||||
|
),
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
STREAMING_UPDATE_INTERVAL = 1.2
|
||||||
|
MAX_TOOL_ITERATIONS = 3
|
||||||
|
|
||||||
|
|
||||||
# Global history cache
|
# Global history cache
|
||||||
chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
|
chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
|
||||||
chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
|
chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
|
||||||
|
|
||||||
|
|
||||||
|
def _web_search_available() -> bool:
|
||||||
|
return bool(settings.ollama_web_search_api_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_web_search_results(payload: dict[str, Any]) -> str:
|
||||||
|
results = payload.get("results") or payload.get("data") or []
|
||||||
|
if not isinstance(results, list):
|
||||||
|
results = []
|
||||||
|
formatted: list[str] = []
|
||||||
|
for idx, item in enumerate(results, start=1):
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
title = (
|
||||||
|
item.get("title") or item.get("name") or item.get("url") or f"Result {idx}"
|
||||||
|
)
|
||||||
|
url = item.get("url") or item.get("link") or item.get("source") or ""
|
||||||
|
snippet = (
|
||||||
|
item.get("snippet")
|
||||||
|
or item.get("summary")
|
||||||
|
or item.get("content")
|
||||||
|
or item.get("description")
|
||||||
|
or ""
|
||||||
|
).strip()
|
||||||
|
snippet = snippet.replace("\n", " ")
|
||||||
|
if len(snippet) > 400:
|
||||||
|
snippet = snippet[:397].rstrip() + "..."
|
||||||
|
entry = f"[{idx}] {title}"
|
||||||
|
if url:
|
||||||
|
entry = f"{entry}\n{url}"
|
||||||
|
if snippet:
|
||||||
|
entry = f"{entry}\n{snippet}"
|
||||||
|
formatted.append(entry)
|
||||||
|
if formatted:
|
||||||
|
return "\n\n".join(formatted)
|
||||||
|
return json.dumps(payload, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_ollama_web_search(query: str, max_results: int | None = None) -> str:
|
||||||
|
if not _web_search_available():
|
||||||
|
return "Web search is not configured."
|
||||||
|
payload: dict[str, Any] = {"query": query.strip()}
|
||||||
|
limit = max_results if isinstance(max_results, int) else None
|
||||||
|
if limit is None or limit <= 0:
|
||||||
|
limit = settings.ollama_web_search_max_results
|
||||||
|
if limit:
|
||||||
|
payload["max_results"] = int(limit)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {settings.ollama_web_search_api_key}",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
OLLAMA_WEB_SEARCH_URL,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=settings.ollama_web_search_timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
except requests.RequestException as exc:
|
||||||
|
logger.exception("Ollama web search failed: %s", exc)
|
||||||
|
return f"Web search error: {exc}"
|
||||||
|
except ValueError:
|
||||||
|
logger.exception("Invalid JSON payload from Ollama web search")
|
||||||
|
return "Web search error: invalid payload."
|
||||||
|
return _format_web_search_results(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _available_tools() -> list[dict[str, Any]]:
|
||||||
|
if not _web_search_available():
|
||||||
|
return []
|
||||||
|
return [WEB_SEARCH_TOOL]
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_tool_call_deltas(
|
||||||
|
buffer: dict[int, dict[str, Any]],
|
||||||
|
deltas: list[Any],
|
||||||
|
) -> None:
|
||||||
|
for delta in deltas:
|
||||||
|
idx = getattr(delta, "index", 0) or 0
|
||||||
|
entry = buffer.setdefault(
|
||||||
|
idx,
|
||||||
|
{
|
||||||
|
"id": getattr(delta, "id", None),
|
||||||
|
"type": getattr(delta, "type", "function") or "function",
|
||||||
|
"function": {"name": "", "arguments": ""},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if getattr(delta, "id", None):
|
||||||
|
entry["id"] = delta.id
|
||||||
|
if getattr(delta, "type", None):
|
||||||
|
entry["type"] = delta.type
|
||||||
|
func = getattr(delta, "function", None)
|
||||||
|
if func is not None:
|
||||||
|
if getattr(func, "name", None):
|
||||||
|
entry["function"]["name"] = func.name
|
||||||
|
if getattr(func, "arguments", None):
|
||||||
|
entry["function"]["arguments"] += func.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def _finalize_tool_calls(buffer: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
tool_calls: list[dict[str, Any]] = []
|
||||||
|
for idx in sorted(buffer):
|
||||||
|
entry = buffer[idx]
|
||||||
|
function_name = entry.get("function", {}).get("name")
|
||||||
|
if not function_name:
|
||||||
|
continue
|
||||||
|
arguments = entry.get("function", {}).get("arguments", "{}")
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": entry.get("id") or str(uuid.uuid4()),
|
||||||
|
"type": entry.get("type") or "function",
|
||||||
|
"function": {
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_tool(function_name: str, arguments_json: str) -> str:
|
||||||
|
try:
|
||||||
|
arguments = json.loads(arguments_json or "{}")
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
logger.exception("Invalid tool arguments for %s: %s", function_name, exc)
|
||||||
|
return f"Invalid arguments for {function_name}: {exc}"
|
||||||
|
|
||||||
|
if function_name == WEB_SEARCH_TOOL_NAME:
|
||||||
|
query = (arguments.get("query") or "").strip()
|
||||||
|
if not query:
|
||||||
|
return "Web search error: no query provided."
|
||||||
|
max_results = arguments.get("max_results")
|
||||||
|
if isinstance(max_results, str):
|
||||||
|
max_results = int(max_results) if max_results.isdigit() else None
|
||||||
|
elif not isinstance(max_results, int):
|
||||||
|
max_results = None
|
||||||
|
return _call_ollama_web_search(query, max_results)
|
||||||
|
|
||||||
|
return f"Function {function_name} is not implemented."
|
||||||
|
|
||||||
|
|
||||||
|
def _append_tool_messages(
|
||||||
|
conversation: list[dict[str, Any]], tool_calls: list[dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
if not tool_calls:
|
||||||
|
return
|
||||||
|
conversation.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for call in tool_calls:
|
||||||
|
result = _execute_tool(
|
||||||
|
call["function"]["name"], call["function"].get("arguments", "{}")
|
||||||
|
)
|
||||||
|
conversation.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call["id"],
|
||||||
|
"content": result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_chatgpt_pro_response(
|
||||||
|
conversation: list[dict[str, Any]],
|
||||||
|
reply_id: Message,
|
||||||
|
who: str,
|
||||||
|
bot: TeleBot,
|
||||||
|
) -> str:
|
||||||
|
tools = _available_tools()
|
||||||
|
tool_loops_remaining = MAX_TOOL_ITERATIONS if tools else 0
|
||||||
|
final_response = ""
|
||||||
|
while True:
|
||||||
|
request_payload: dict[str, Any] = {
|
||||||
|
"messages": conversation,
|
||||||
|
"model": CHATGPT_PRO_MODEL,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
if tools:
|
||||||
|
request_payload["tools"] = tools
|
||||||
|
|
||||||
|
stream = client.chat.completions.create(**request_payload)
|
||||||
|
buffer = ""
|
||||||
|
pending_tool_call = False
|
||||||
|
tool_buffer: dict[int, dict[str, Any]] = {}
|
||||||
|
last_update = time.time()
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta is None:
|
||||||
|
continue
|
||||||
|
if delta.tool_calls:
|
||||||
|
pending_tool_call = True
|
||||||
|
_accumulate_tool_call_deltas(tool_buffer, delta.tool_calls)
|
||||||
|
continue
|
||||||
|
content_piece = delta.content
|
||||||
|
if isinstance(content_piece, list):
|
||||||
|
content_piece = "".join(
|
||||||
|
getattr(part, "text", "") for part in content_piece
|
||||||
|
)
|
||||||
|
if not content_piece:
|
||||||
|
continue
|
||||||
|
buffer += content_piece
|
||||||
|
now = time.time()
|
||||||
|
if not pending_tool_call and now - last_update > STREAMING_UPDATE_INTERVAL:
|
||||||
|
last_update = now
|
||||||
|
bot_reply_markdown(reply_id, who, buffer, bot, split_text=False)
|
||||||
|
|
||||||
|
if pending_tool_call and tools:
|
||||||
|
if tool_loops_remaining <= 0:
|
||||||
|
logger.warning(
|
||||||
|
"chatgpt_pro_handler reached the maximum number of tool calls"
|
||||||
|
)
|
||||||
|
final_response = buffer or "Unable to finish after calling tools."
|
||||||
|
break
|
||||||
|
tool_calls = _finalize_tool_calls(tool_buffer)
|
||||||
|
if any(
|
||||||
|
call["function"]["name"] == WEB_SEARCH_TOOL_NAME for call in tool_calls
|
||||||
|
):
|
||||||
|
bot_reply_markdown(
|
||||||
|
reply_id,
|
||||||
|
who,
|
||||||
|
"Searching the web for up-to-date information…",
|
||||||
|
bot,
|
||||||
|
split_text=False,
|
||||||
|
disable_web_page_preview=True,
|
||||||
|
)
|
||||||
|
_append_tool_messages(conversation, tool_calls)
|
||||||
|
tool_loops_remaining -= 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
final_response = buffer
|
||||||
|
break
|
||||||
|
|
||||||
|
if not final_response:
|
||||||
|
final_response = "I could not generate a response."
|
||||||
|
bot_reply_markdown(reply_id, who, final_response, bot, split_text=True)
|
||||||
|
return final_response
|
||||||
|
|
||||||
|
|
||||||
def chatgpt_handler(message: Message, bot: TeleBot) -> None:
|
def chatgpt_handler(message: Message, bot: TeleBot) -> None:
|
||||||
"""gpt : /gpt <question>"""
|
"""gpt : /gpt <question>"""
|
||||||
logger.debug(message)
|
logger.debug(message)
|
||||||
@ -125,32 +409,11 @@ def chatgpt_pro_handler(message: Message, bot: TeleBot) -> None:
|
|||||||
player_message = player_message[2:]
|
player_message = player_message[2:]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = client.chat.completions.create(
|
reply_text = _stream_chatgpt_pro_response(player_message, reply_id, who, bot)
|
||||||
messages=player_message,
|
|
||||||
model=CHATGPT_PRO_MODEL,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
s = ""
|
|
||||||
start = time.time()
|
|
||||||
for chunk in r:
|
|
||||||
logger.debug(chunk)
|
|
||||||
if chunk.choices:
|
|
||||||
if chunk.choices[0].delta.content is None:
|
|
||||||
break
|
|
||||||
s += chunk.choices[0].delta.content
|
|
||||||
if time.time() - start > 1.2:
|
|
||||||
start = time.time()
|
|
||||||
bot_reply_markdown(reply_id, who, s, bot, split_text=False)
|
|
||||||
# maybe not complete
|
|
||||||
try:
|
|
||||||
bot_reply_markdown(reply_id, who, s, bot, split_text=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
player_message.append(
|
player_message.append(
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": markdownify(s),
|
"content": reply_text,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user