diff --git a/README.md b/README.md
index 09169ea..fb893ca 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,8 @@ Note, if you are using third party service, you need to `export ANTHROPIC_BASE_U
 3. use `gpt: ${message}` to ask
 
 Note, if you are using third party service, you need to `export OPENAI_API_BASE=${the_url}` to change the url.
+Optional web search support:
+- export `OLLAMA_WEB_SEARCH_API_KEY=${the_ollama_web_search_api_key}` (set `OLLAMA_WEB_SEARCH_MAX_RESULTS` and `OLLAMA_WEB_SEARCH_TIMEOUT` to override the defaults)
 
 ## Bot -> llama3
 
diff --git a/config.py b/config.py
index d67ee8d..862d11b 100644
--- a/config.py
+++ b/config.py
@@ -17,6 +17,9 @@ class Settings(BaseSettings):
     google_gemini_api_key: str | None = None
     anthropic_api_key: str | None = None
     telegra_ph_token: str | None = None
+    ollama_web_search_api_key: str | None = None
+    ollama_web_search_max_results: int = 5
+    ollama_web_search_timeout: int = 10  # seconds
 
     @cached_property
     def openai_client(self) -> openai.OpenAI:
diff --git a/handlers/chatgpt.py b/handlers/chatgpt.py
index 542816b..f499a3a 100644
--- a/handlers/chatgpt.py
+++ b/handlers/chatgpt.py
@@ -1,9 +1,12 @@
+import json
 import time
+import uuid
+from typing import Any
 
+import requests
 from expiringdict import ExpiringDict
 from telebot import TeleBot
 from telebot.types import Message
-from telegramify_markdown import markdownify
 
 from config import settings
 
@@ -15,7 +18,6 @@
 from ._utils import (
     logger,
 )
 
-
 CHATGPT_MODEL = settings.openai_model
 CHATGPT_PRO_MODEL = settings.openai_model
@@ -23,11 +25,301 @@ CHATGPT_PRO_MODEL = settings.openai_model
 client = settings.openai_client
 
 
+# Web search / tool-calling configuration
+WEB_SEARCH_TOOL_NAME = "web_search"
+WEB_SEARCH_SYSTEM_PROMPT = {
+    "role": "system",
+    "content": "You are a helpful assistant that uses the Ollama Cloud Web Search API to fetch recent information "
+    "from the public internet when needed. Always cite your sources using the format [number](URL) in your responses.\n\n",
+}
+OLLAMA_WEB_SEARCH_URL = "https://ollama.com/api/web_search"
+WEB_SEARCH_TOOL = {
+    "type": "function",
+    "function": {
+        "name": WEB_SEARCH_TOOL_NAME,
+        "description": (
+            "Use the Ollama Cloud Web Search API to fetch recent information"
+            " from the public internet. Call this when you need up-to-date"
+            " facts, news, or citations."
+        ),
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "query": {
+                    "type": "string",
+                    "description": "Search keywords or question.",
+                },
+                "max_results": {
+                    "type": "integer",
+                    "description": (
+                        "Maximum number of search results to fetch; defaults"
+                        " to the bot configuration if omitted."
+                    ),
+                    "minimum": 1,
+                    "maximum": 10,
+                },
+            },
+            "required": ["query"],
+        },
+    },
+}
+STREAMING_UPDATE_INTERVAL = 1.2  # seconds between streamed Telegram edits
+MAX_TOOL_ITERATIONS = 3  # cap on tool-call round trips per user message
+
+
 # Global history cache
 chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
 chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=600)
 
 
+def _web_search_available() -> bool:
+    return bool(settings.ollama_web_search_api_key)
+
+
+def _format_web_search_results(payload: dict[str, Any]) -> str:
+    # Tolerate both {"results": [...]} and {"data": [...]} payload shapes.
+    results = payload.get("results") or payload.get("data") or []
+    if not isinstance(results, list):
+        results = []
+    formatted: list[str] = []
+    for idx, item in enumerate(results, start=1):
+        if not isinstance(item, dict):
+            continue
+        title = (
+            item.get("title") or item.get("name") or item.get("url") or f"Result {idx}"
+        )
+        url = item.get("url") or item.get("link") or item.get("source") or ""
+        snippet = (
+            item.get("snippet")
+            or item.get("summary")
+            or item.get("content")
+            or item.get("description")
+            or ""
+        ).strip()
+        # Keep snippets single-line and roughly 400 characters at most.
+        snippet = snippet.replace("\n", " ")
+        if len(snippet) > 400:
+            snippet = snippet[:397].rstrip() + "..."
+        entry = f"[{idx}] {title}"
+        if url:
+            entry = f"{entry}\nURL: {url}"
+        if snippet:
+            entry = f"{entry}\n{snippet}"
+        formatted.append(entry)
+    if formatted:
+        return "\n\n".join(formatted)
+    # Nothing recognizable: hand the raw payload back for debugging.
+    return json.dumps(payload, ensure_ascii=False)
+
+
+def _call_ollama_web_search(query: str, max_results: int | None = None) -> str:
+    if not _web_search_available():
+        return "Web search is not configured."
+    payload: dict[str, Any] = {"query": query.strip()}
+    # Fall back to the configured default when the caller omits max_results.
+    limit = max_results if isinstance(max_results, int) else None
+    if limit is None or limit <= 0:
+        limit = settings.ollama_web_search_max_results
+    if limit:
+        payload["max_results"] = int(limit)
+    headers = {
+        "Authorization": f"Bearer {settings.ollama_web_search_api_key}",
+    }
+    try:
+        response = requests.post(
+            OLLAMA_WEB_SEARCH_URL,
+            json=payload,
+            headers=headers,
+            timeout=settings.ollama_web_search_timeout,
+        )
+        response.raise_for_status()
+        data = response.json()
+    except requests.RequestException as exc:
+        logger.exception("Ollama web search failed: %s", exc)
+        return f"Web search error: {exc}"
+    except ValueError:
+        logger.exception("Invalid JSON payload from Ollama web search")
+        return "Web search error: invalid payload."
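+    # Parsed successfully; collapse the payload into numbered citations.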
+    return _format_web_search_results(data)
+
+
+def _available_tools() -> list[dict[str, Any]]:
+    if not _web_search_available():
+        return []
+    return [WEB_SEARCH_TOOL]
+
+
+def _accumulate_tool_call_deltas(
+    buffer: dict[int, dict[str, Any]],
+    deltas: list[Any],
+) -> None:
+    # Streaming delivers each tool call in fragments keyed by index: the id
+    # and type arrive once, while the JSON arguments trickle in and must be
+    # concatenated in order.
+    for delta in deltas:
+        idx = getattr(delta, "index", 0) or 0
+        entry = buffer.setdefault(
+            idx,
+            {
+                "id": getattr(delta, "id", None),
+                "type": getattr(delta, "type", "function") or "function",
+                "function": {"name": "", "arguments": ""},
+            },
+        )
+        if getattr(delta, "id", None):
+            entry["id"] = delta.id
+        if getattr(delta, "type", None):
+            entry["type"] = delta.type
+        func = getattr(delta, "function", None)
+        if func is not None:
+            if getattr(func, "name", None):
+                entry["function"]["name"] = func.name
+            if getattr(func, "arguments", None):
+                entry["function"]["arguments"] += func.arguments
+
+
+def _finalize_tool_calls(buffer: dict[int, dict[str, Any]]) -> list[dict[str, Any]]:
+    tool_calls: list[dict[str, Any]] = []
+    for idx in sorted(buffer):
+        entry = buffer[idx]
+        function_name = entry.get("function", {}).get("name")
+        if not function_name:
+            continue
+        arguments = entry.get("function", {}).get("arguments", "{}")
+        tool_calls.append(
+            {
+                "id": entry.get("id") or str(uuid.uuid4()),
+                "type": entry.get("type") or "function",
+                "function": {
+                    "name": function_name,
+                    "arguments": arguments,
+                },
+            }
+        )
+    return tool_calls
+
+
+def _execute_tool(function_name: str, arguments_json: str) -> str:
+    try:
+        arguments = json.loads(arguments_json or "{}")
+    except json.JSONDecodeError as exc:
+        logger.exception("Invalid tool arguments for %s: %s", function_name, exc)
+        return f"Invalid arguments for {function_name}: {exc}"
+
+    if function_name == WEB_SEARCH_TOOL_NAME:
+        query = (arguments.get("query") or "").strip()
+        if not query:
+            return "Web search error: no query provided."
+        max_results = arguments.get("max_results")
+        if isinstance(max_results, str):
+            max_results = int(max_results) if max_results.isdigit() else None
+        elif not isinstance(max_results, int):
+            max_results = None
+        return _call_ollama_web_search(query, max_results)
+
+    return f"Function {function_name} is not implemented."
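+
+
+# Chat-completions tool protocol: replay the assistant turn that requested the
+# tools (tool_calls intact), then append one "tool" message per call carrying
+# the matching tool_call_id before issuing the next completion request.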
+def _append_tool_messages(
+    conversation: list[dict[str, Any]], tool_calls: list[dict[str, Any]]
+) -> None:
+    if not tool_calls:
+        return
+    conversation.append(
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": tool_calls,
+        }
+    )
+    for call in tool_calls:
+        result = _execute_tool(
+            call["function"]["name"], call["function"].get("arguments", "{}")
+        )
+        conversation.append(
+            {
+                "role": "tool",
+                "tool_call_id": call["id"],
+                "content": result,
+            }
+        )
+
+
+def _stream_chatgpt_pro_response(
+    conversation: list[dict[str, Any]],
+    reply_id: Message,
+    who: str,
+    bot: TeleBot,
+) -> str:
+    tools = _available_tools()
+    tool_loops_remaining = MAX_TOOL_ITERATIONS if tools else 0
+    final_response = ""
+    if tools:
+        conversation.insert(0, WEB_SEARCH_SYSTEM_PROMPT)
+    # Each pass streams one completion; a tool call triggers another pass.
+    while True:
+        request_payload: dict[str, Any] = {
+            "messages": conversation,
+            "model": CHATGPT_PRO_MODEL,
+            "stream": True,
+        }
+        if tools:
+            request_payload.update(tools=tools, tool_choice="auto")
+
+        stream = client.chat.completions.create(**request_payload)
+        buffer = ""
+        pending_tool_call = False
+        tool_buffer: dict[int, dict[str, Any]] = {}
+        last_update = time.time()
+
+        for chunk in stream:
+            if not chunk.choices:
+                continue
+            delta = chunk.choices[0].delta
+            if delta is None:
+                continue
+            if delta.tool_calls:
+                pending_tool_call = True
+                _accumulate_tool_call_deltas(tool_buffer, delta.tool_calls)
+                continue
+            content_piece = delta.content
+            if isinstance(content_piece, list):
+                content_piece = "".join(
+                    getattr(part, "text", "") for part in content_piece
+                )
+            if not content_piece:
+                continue
+            buffer += content_piece
+            now = time.time()
+            if not pending_tool_call and now - last_update > STREAMING_UPDATE_INTERVAL:
+                last_update = now
+                bot_reply_markdown(reply_id, who, buffer, bot, split_text=False)
+
+        if pending_tool_call and tools:
+            if tool_loops_remaining <= 0:
+                logger.warning(
+                    "chatgpt_pro_handler reached the maximum number of tool calls"
+                )
+                final_response = buffer or "Unable to finish after calling tools."
+                break
+            tool_calls = _finalize_tool_calls(tool_buffer)
+            if any(
+                call["function"]["name"] == WEB_SEARCH_TOOL_NAME for call in tool_calls
+            ):
+                bot_reply_markdown(
+                    reply_id,
+                    who,
+                    "Searching the web for up-to-date information…",
+                    bot,
+                    split_text=False,
+                    disable_web_page_preview=True,
+                )
+            _append_tool_messages(conversation, tool_calls)
+            tool_loops_remaining -= 1
+            continue
+
+        final_response = buffer
+        break
+
+    if not final_response:
+        final_response = "I could not generate a response."
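+    # Final flush: send the complete reply once, splitting long messages.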
+    bot_reply_markdown(reply_id, who, final_response, bot, split_text=True)
+    return final_response
+
+
 def chatgpt_handler(message: Message, bot: TeleBot) -> None:
     """gpt : /gpt"""
     logger.debug(message)
@@ -125,32 +417,11 @@ def chatgpt_pro_handler(message: Message, bot: TeleBot) -> None:
         player_message = player_message[2:]
 
     try:
-        r = client.chat.completions.create(
-            messages=player_message,
-            model=CHATGPT_PRO_MODEL,
-            stream=True,
-        )
-        s = ""
-        start = time.time()
-        for chunk in r:
-            logger.debug(chunk)
-            if chunk.choices:
-                if chunk.choices[0].delta.content is None:
-                    break
-                s += chunk.choices[0].delta.content
-                if time.time() - start > 1.2:
-                    start = time.time()
-                    bot_reply_markdown(reply_id, who, s, bot, split_text=False)
-        # maybe not complete
-        try:
-            bot_reply_markdown(reply_id, who, s, bot, split_text=True)
-        except Exception:
-            pass
-
+        # Pass a copy of the history so the injected system prompt and tool
+        # transcripts never leak into the cached conversation.
+        reply_text = _stream_chatgpt_pro_response(player_message[:], reply_id, who, bot)
         player_message.append(
             {
                 "role": "assistant",
-                "content": markdownify(s),
+                "content": reply_text,
             }
         )
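
Quick smoke test for the web search wiring before deploying the bot. This is a sketch, not part of the patch: it mirrors the request that `_call_ollama_web_search` sends and assumes the endpoint returns a JSON object with a `results` list of `title`/`url` items, the same shape `_format_web_search_results` tolerates.

    import os

    import requests

    resp = requests.post(
        "https://ollama.com/api/web_search",
        json={"query": "python 3.13 release notes", "max_results": 3},
        headers={"Authorization": f"Bearer {os.environ['OLLAMA_WEB_SEARCH_API_KEY']}"},
        timeout=10,
    )
    resp.raise_for_status()
    # Number the hits the same way the bot cites them.
    for idx, item in enumerate(resp.json().get("results", []), start=1):
        print(f"[{idx}] {item.get('title')} - {item.get('url')}")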