From 785b75e9940a75e14a381ef06872368deafd737f Mon Sep 17 00:00:00 2001 From: yihong0618 Date: Fri, 14 Jun 2024 18:55:20 +0800 Subject: [PATCH] feat: use expire dict Signed-off-by: yihong0618 --- .gitignore | 3 +- handlers/chatgpt.py | 5 ++- handlers/claude.py | 5 ++- handlers/dify.py | 8 ++-- handlers/gemini.py | 8 ++-- handlers/llama.py | 5 ++- handlers/qwen.py | 7 ++-- handlers/tts.py | 97 +++++++++++++++++++++++++++++++++++++++------ handlers/useful.py | 2 +- handlers/yi.py | 10 +++-- pdm.lock | 12 +++++- pyproject.toml | 1 + requirements.txt | 1 + 13 files changed, 127 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 79348ba..77d4e40 100644 --- a/.gitignore +++ b/.gitignore @@ -168,5 +168,4 @@ nohup.out *.mp4 *.pdf .pdm-python -tts.wav -tts_pro.wav \ No newline at end of file +*.wav \ No newline at end of file diff --git a/handlers/chatgpt.py b/handlers/chatgpt.py index 0c19ee2..4c17f76 100644 --- a/handlers/chatgpt.py +++ b/handlers/chatgpt.py @@ -4,6 +4,7 @@ import time from openai import OpenAI from telebot import TeleBot from telebot.types import Message +from expiringdict import ExpiringDict from . import * @@ -23,8 +24,8 @@ client = OpenAI(api_key=CHATGPT_API_KEY, base_url=CHATGPT_BASE_URL, timeout=20) # Global history cache -chatgpt_player_dict = {} -chatgpt_pro_player_dict = {} +chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) def chatgpt_handler(message: Message, bot: TeleBot) -> None: diff --git a/handlers/claude.py b/handlers/claude.py index ebe8582..30d7997 100644 --- a/handlers/claude.py +++ b/handlers/claude.py @@ -5,6 +5,7 @@ import time from anthropic import Anthropic, APITimeoutError from telebot import TeleBot from telebot.types import Message +from expiringdict import ExpiringDict from . import * @@ -28,8 +29,8 @@ else: # Global history cache -claude_player_dict = {} -claude_pro_player_dict = {} +claude_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +claude_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) def claude_handler(message: Message, bot: TeleBot) -> None: diff --git a/handlers/dify.py b/handlers/dify.py index cf3f3a5..a1cd013 100644 --- a/handlers/dify.py +++ b/handlers/dify.py @@ -1,8 +1,8 @@ from os import environ -import time from telebot import TeleBot from telebot.types import Message +from expiringdict import ExpiringDict from . import * @@ -22,8 +22,10 @@ if DIFY_API_KEY: client = ChatClient(api_key=DIFY_API_KEY) # Global history cache -dify_player_dict = {} -dify_player_c = {} # History cache is supported by dify cloud conversation_id. +dify_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +dify_player_c = ExpiringDict( + max_len=1000, max_age_seconds=300 +) # History cache is supported by dify cloud conversation_id. def dify_handler(message: Message, bot: TeleBot) -> None: diff --git a/handlers/gemini.py b/handlers/gemini.py index ef62e03..4140579 100644 --- a/handlers/gemini.py +++ b/handlers/gemini.py @@ -7,8 +7,8 @@ from google.generativeai import ChatSession from google.generativeai.types.generation_types import StopCandidateException from telebot import TeleBot from telebot.types import Message +from expiringdict import ExpiringDict -from telegramify_markdown import convert from telegramify_markdown.customize import markdown_symbol from . import * @@ -34,9 +34,9 @@ safety_settings = [ ] # Global history cache -gemini_player_dict = {} -gemini_pro_player_dict = {} -gemini_file_player_dict = {} +gemini_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +gemini_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +gemini_file_player_dict = ExpiringDict(max_len=100, max_age_seconds=300) def make_new_gemini_convo(is_pro=False) -> ChatSession: diff --git a/handlers/llama.py b/handlers/llama.py index 8e64fa0..5dd880d 100644 --- a/handlers/llama.py +++ b/handlers/llama.py @@ -3,6 +3,7 @@ import time from telebot import TeleBot from telebot.types import Message +from expiringdict import ExpiringDict from . import * @@ -21,8 +22,8 @@ if LLAMA_API_KEY: client = Groq(api_key=LLAMA_API_KEY) # Global history cache -llama_player_dict = {} -llama_pro_player_dict = {} +llama_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +llama_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) def llama_handler(message: Message, bot: TeleBot) -> None: diff --git a/handlers/qwen.py b/handlers/qwen.py index 40c0bd8..2cbc94a 100644 --- a/handlers/qwen.py +++ b/handlers/qwen.py @@ -4,6 +4,7 @@ import time from telebot import TeleBot from telebot.types import Message +from expiringdict import ExpiringDict from . import * @@ -15,14 +16,14 @@ markdown_symbol.head_level_1 = "📌" # If you want, Customizing the head level markdown_symbol.link = "🔗" # If you want, Customizing the link symbol QWEN_API_KEY = environ.get("TOGETHER_API_KEY") -QWEN_MODEL = "Qwen/Qwen1.5-110B-Chat" +QWEN_MODEL = "Qwen/Qwen2-72B-Instruct" if QWEN_API_KEY: client = Together(api_key=QWEN_API_KEY) # Global history cache -qwen_player_dict = {} -qwen_pro_player_dict = {} +qwen_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +qwen_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) def qwen_handler(message: Message, bot: TeleBot) -> None: diff --git a/handlers/tts.py b/handlers/tts.py index 7baf081..99ff32a 100644 --- a/handlers/tts.py +++ b/handlers/tts.py @@ -1,5 +1,7 @@ +import glob import threading -from os import environ +import subprocess +from os import environ, remove from telebot import TeleBot from telebot.types import Message @@ -10,6 +12,21 @@ import wave import numpy as np from ChatTTS import Chat + +def check_ffmpeg(): + try: + subprocess.run( + ["ffmpeg", "-version"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +HAS_FFMPEG = check_ffmpeg() USE_CHATTTS = environ.get("USE_CHATTTS") if USE_CHATTTS: chat = Chat() @@ -25,7 +42,7 @@ if USE_CHATTTS: wf.setframerate(sample_rate) wf.writeframes(data.tobytes()) - def generate_tts_wav(prompt, seed=None): + def generate_tts_wav(prompt, output_filename, seed=None): texts = [ prompt, ] @@ -40,10 +57,8 @@ if USE_CHATTTS: wavs = chat.infer( texts, use_decoder=True, params_infer_code=params_infer_code ) - output_filename = "tts_pro.wav" else: wavs = chat.infer(texts, use_decoder=True) - output_filename = "tts.wav" audio_data = np.array( wavs[0], dtype=np.float32 @@ -69,7 +84,7 @@ if USE_CHATTTS: return try: with lock: - generate_tts_wav(prompt) + generate_tts_wav(prompt, "tts.wav") with open(f"tts.wav", "rb") as audio: bot.send_audio( message.chat.id, audio, reply_to_message_id=message.message_id @@ -91,16 +106,72 @@ if USE_CHATTTS: bot.reply_to(message, "first argument must be a number") return prompt = prompt[len(str(seed)) + 1 :] - if len(prompt) > 150: - bot.reply_to(message, "prompt too long must length < 150") - return + # split the prompt by 100 characters + prompt_split = [prompt[i : i + 50] for i in range(0, len(prompt), 50)] + if not HAS_FFMPEG: + if len(prompt) > 150: + bot.reply_to(message, "prompt too long must length < 150") + return try: with lock: - generate_tts_wav(prompt, seed) - with open(f"tts_pro.wav", "rb") as audio: - bot.send_audio( - message.chat.id, audio, reply_to_message_id=message.message_id - ) + if len(prompt_split) > 1: + bot.reply_to( + message, + "Will split the text and use the same to generate the audio and use ffmpeg to combin them pleas wait more time", + ) + for k, v in enumerate(prompt_split): + generate_tts_wav(v, f"{k}.wav", seed) + with open("input.txt", "a") as f: + f.write(f"file {k}.wav\n") + output_file = "tts_pro.wav" + # Run the FFmpeg command + try: + # make sure remove it + try: + remove("tts_pro.wav") + except: + pass + subprocess.run( + [ + "ffmpeg", + "-f", + "concat", + "-safe", + "0", + "-i", + "input.txt", + "-c", + "copy", + "tts_pro.wav", + ], + check=True, + ) + except Exception as e: + print(f"Error combining audio files, {e}") + bot.reply_to(message, "tts error please check the log") + remove("input.txt") + return + print(f"Combined audio saved as {output_file}") + with open(f"tts_pro.wav", "rb") as audio: + bot.send_audio( + message.chat.id, + audio, + reply_to_message_id=message.message_id, + ) + remove("input.txt") + for file in glob.glob("*.wav"): + try: + remove(file) + except OSError as e: + print(e) + else: + generate_tts_wav(prompt, "tts_pro.wav", seed) + with open(f"tts_pro.wav", "rb") as audio: + bot.send_audio( + message.chat.id, + audio, + reply_to_message_id=message.message_id, + ) except Exception as e: print(e) bot.reply_to(message, "tts error") diff --git a/handlers/useful.py b/handlers/useful.py index 25e50fd..b4e7665 100644 --- a/handlers/useful.py +++ b/handlers/useful.py @@ -9,7 +9,7 @@ from . import * def md_handler(message: Message, bot: TeleBot): """pretty md: /md
""" - who = "Markdown" + who = "" reply_id = bot_reply_first(message, who, bot) bot_reply_markdown(reply_id, who, message.text.strip(), bot) diff --git a/handlers/yi.py b/handlers/yi.py index ebdc00a..f5a844b 100644 --- a/handlers/yi.py +++ b/handlers/yi.py @@ -6,6 +6,7 @@ import requests from telebot import TeleBot from telebot.types import Message from telegramify_markdown import convert +from expiringdict import ExpiringDict from . import * @@ -21,7 +22,8 @@ client = OpenAI( ) # Global history cache -yi_player_dict = {} +yi_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) +yi_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300) def yi_handler(message: Message, bot: TeleBot) -> None: @@ -95,12 +97,12 @@ def yi_pro_handler(message: Message, bot: TeleBot) -> None: player_message = [] # restart will lose all TODO - if str(message.from_user.id) not in yi_player_dict: - yi_player_dict[str(message.from_user.id)] = ( + if str(message.from_user.id) not in yi_pro_player_dict: + yi_pro_player_dict[str(message.from_user.id)] = ( player_message # for the imuutable list ) else: - player_message = yi_player_dict[str(message.from_user.id)] + player_message = yi_pro_player_dict[str(message.from_user.id)] if m.strip() == "clear": bot.reply_to( message, diff --git a/pdm.lock b/pdm.lock index cf3e221..0a332e3 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:01d3d4c572074457bb3e879a335e17c0f3597b57aca720da6ab82370872b86c2" +content_hash = "sha256:92559499879e6b4ab7bb1fd673848820016f25b06e5b87f38dc95cdb15215dcb" [[package]] name = "aiohttp" @@ -587,6 +587,16 @@ files = [ {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, ] +[[package]] +name = "expiringdict" +version = "1.2.2" +summary = "Dictionary with auto-expiring values for caching purposes" +groups = ["default"] +files = [ + {file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"}, + {file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"}, +] + [[package]] name = "filelock" version = "3.14.0" diff --git a/pyproject.toml b/pyproject.toml index fb7a592..153d7bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,5 +18,6 @@ dependencies = [ "together>=1.1.5", "dify-client>=0.1.10", "chattts-fork>=0.0.1", + "expiringdict>=1.2.2", ] requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index b6e7388..cc1c506 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,6 +33,7 @@ emoji==2.11.1 encodec==0.1.1 eval-type-backport==0.2.0 exceptiongroup==1.2.1; python_version < "3.11" +expiringdict==1.2.2 filelock==3.14.0 fiona==1.9.6 fonttools==4.51.0