feat: use expire dict

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
This commit is contained in:
yihong0618 2024-06-14 18:55:20 +08:00
parent bc3ecdb9f8
commit 785b75e994
13 changed files with 127 additions and 37 deletions

3
.gitignore vendored
View File

@ -168,5 +168,4 @@ nohup.out
*.mp4 *.mp4
*.pdf *.pdf
.pdm-python .pdm-python
tts.wav *.wav
tts_pro.wav

View File

@ -4,6 +4,7 @@ import time
from openai import OpenAI from openai import OpenAI
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from expiringdict import ExpiringDict
from . import * from . import *
@ -23,8 +24,8 @@ client = OpenAI(api_key=CHATGPT_API_KEY, base_url=CHATGPT_BASE_URL, timeout=20)
# Global history cache # Global history cache
chatgpt_player_dict = {} chatgpt_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
chatgpt_pro_player_dict = {} chatgpt_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
def chatgpt_handler(message: Message, bot: TeleBot) -> None: def chatgpt_handler(message: Message, bot: TeleBot) -> None:

View File

@ -5,6 +5,7 @@ import time
from anthropic import Anthropic, APITimeoutError from anthropic import Anthropic, APITimeoutError
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from expiringdict import ExpiringDict
from . import * from . import *
@ -28,8 +29,8 @@ else:
# Global history cache # Global history cache
claude_player_dict = {} claude_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
claude_pro_player_dict = {} claude_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
def claude_handler(message: Message, bot: TeleBot) -> None: def claude_handler(message: Message, bot: TeleBot) -> None:

View File

@ -1,8 +1,8 @@
from os import environ from os import environ
import time
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from expiringdict import ExpiringDict
from . import * from . import *
@ -22,8 +22,10 @@ if DIFY_API_KEY:
client = ChatClient(api_key=DIFY_API_KEY) client = ChatClient(api_key=DIFY_API_KEY)
# Global history cache # Global history cache
dify_player_dict = {} dify_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
dify_player_c = {} # History cache is supported by dify cloud conversation_id. dify_player_c = ExpiringDict(
max_len=1000, max_age_seconds=300
) # History cache is supported by dify cloud conversation_id.
def dify_handler(message: Message, bot: TeleBot) -> None: def dify_handler(message: Message, bot: TeleBot) -> None:

View File

@ -7,8 +7,8 @@ from google.generativeai import ChatSession
from google.generativeai.types.generation_types import StopCandidateException from google.generativeai.types.generation_types import StopCandidateException
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from expiringdict import ExpiringDict
from telegramify_markdown import convert
from telegramify_markdown.customize import markdown_symbol from telegramify_markdown.customize import markdown_symbol
from . import * from . import *
@ -34,9 +34,9 @@ safety_settings = [
] ]
# Global history cache # Global history cache
gemini_player_dict = {} gemini_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
gemini_pro_player_dict = {} gemini_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
gemini_file_player_dict = {} gemini_file_player_dict = ExpiringDict(max_len=100, max_age_seconds=300)
def make_new_gemini_convo(is_pro=False) -> ChatSession: def make_new_gemini_convo(is_pro=False) -> ChatSession:

View File

@ -3,6 +3,7 @@ import time
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from expiringdict import ExpiringDict
from . import * from . import *
@ -21,8 +22,8 @@ if LLAMA_API_KEY:
client = Groq(api_key=LLAMA_API_KEY) client = Groq(api_key=LLAMA_API_KEY)
# Global history cache # Global history cache
llama_player_dict = {} llama_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
llama_pro_player_dict = {} llama_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
def llama_handler(message: Message, bot: TeleBot) -> None: def llama_handler(message: Message, bot: TeleBot) -> None:

View File

@ -4,6 +4,7 @@ import time
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from expiringdict import ExpiringDict
from . import * from . import *
@ -15,14 +16,14 @@ markdown_symbol.head_level_1 = "📌" # If you want, Customizing the head level
markdown_symbol.link = "🔗" # If you want, Customizing the link symbol markdown_symbol.link = "🔗" # If you want, Customizing the link symbol
QWEN_API_KEY = environ.get("TOGETHER_API_KEY") QWEN_API_KEY = environ.get("TOGETHER_API_KEY")
QWEN_MODEL = "Qwen/Qwen1.5-110B-Chat" QWEN_MODEL = "Qwen/Qwen2-72B-Instruct"
if QWEN_API_KEY: if QWEN_API_KEY:
client = Together(api_key=QWEN_API_KEY) client = Together(api_key=QWEN_API_KEY)
# Global history cache # Global history cache
qwen_player_dict = {} qwen_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
qwen_pro_player_dict = {} qwen_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
def qwen_handler(message: Message, bot: TeleBot) -> None: def qwen_handler(message: Message, bot: TeleBot) -> None:

View File

@ -1,5 +1,7 @@
import glob
import threading import threading
from os import environ import subprocess
from os import environ, remove
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
@ -10,6 +12,21 @@ import wave
import numpy as np import numpy as np
from ChatTTS import Chat from ChatTTS import Chat
def check_ffmpeg():
try:
subprocess.run(
["ffmpeg", "-version"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
HAS_FFMPEG = check_ffmpeg()
USE_CHATTTS = environ.get("USE_CHATTTS") USE_CHATTTS = environ.get("USE_CHATTTS")
if USE_CHATTTS: if USE_CHATTTS:
chat = Chat() chat = Chat()
@ -25,7 +42,7 @@ if USE_CHATTTS:
wf.setframerate(sample_rate) wf.setframerate(sample_rate)
wf.writeframes(data.tobytes()) wf.writeframes(data.tobytes())
def generate_tts_wav(prompt, seed=None): def generate_tts_wav(prompt, output_filename, seed=None):
texts = [ texts = [
prompt, prompt,
] ]
@ -40,10 +57,8 @@ if USE_CHATTTS:
wavs = chat.infer( wavs = chat.infer(
texts, use_decoder=True, params_infer_code=params_infer_code texts, use_decoder=True, params_infer_code=params_infer_code
) )
output_filename = "tts_pro.wav"
else: else:
wavs = chat.infer(texts, use_decoder=True) wavs = chat.infer(texts, use_decoder=True)
output_filename = "tts.wav"
audio_data = np.array( audio_data = np.array(
wavs[0], dtype=np.float32 wavs[0], dtype=np.float32
@ -69,7 +84,7 @@ if USE_CHATTTS:
return return
try: try:
with lock: with lock:
generate_tts_wav(prompt) generate_tts_wav(prompt, "tts.wav")
with open(f"tts.wav", "rb") as audio: with open(f"tts.wav", "rb") as audio:
bot.send_audio( bot.send_audio(
message.chat.id, audio, reply_to_message_id=message.message_id message.chat.id, audio, reply_to_message_id=message.message_id
@ -91,16 +106,72 @@ if USE_CHATTTS:
bot.reply_to(message, "first argument must be a number") bot.reply_to(message, "first argument must be a number")
return return
prompt = prompt[len(str(seed)) + 1 :] prompt = prompt[len(str(seed)) + 1 :]
if len(prompt) > 150: # split the prompt by 100 characters
bot.reply_to(message, "prompt too long must length < 150") prompt_split = [prompt[i : i + 50] for i in range(0, len(prompt), 50)]
return if not HAS_FFMPEG:
if len(prompt) > 150:
bot.reply_to(message, "prompt too long must length < 150")
return
try: try:
with lock: with lock:
generate_tts_wav(prompt, seed) if len(prompt_split) > 1:
with open(f"tts_pro.wav", "rb") as audio: bot.reply_to(
bot.send_audio( message,
message.chat.id, audio, reply_to_message_id=message.message_id "Will split the text and use the same to generate the audio and use ffmpeg to combin them pleas wait more time",
) )
for k, v in enumerate(prompt_split):
generate_tts_wav(v, f"{k}.wav", seed)
with open("input.txt", "a") as f:
f.write(f"file {k}.wav\n")
output_file = "tts_pro.wav"
# Run the FFmpeg command
try:
# make sure remove it
try:
remove("tts_pro.wav")
except:
pass
subprocess.run(
[
"ffmpeg",
"-f",
"concat",
"-safe",
"0",
"-i",
"input.txt",
"-c",
"copy",
"tts_pro.wav",
],
check=True,
)
except Exception as e:
print(f"Error combining audio files, {e}")
bot.reply_to(message, "tts error please check the log")
remove("input.txt")
return
print(f"Combined audio saved as {output_file}")
with open(f"tts_pro.wav", "rb") as audio:
bot.send_audio(
message.chat.id,
audio,
reply_to_message_id=message.message_id,
)
remove("input.txt")
for file in glob.glob("*.wav"):
try:
remove(file)
except OSError as e:
print(e)
else:
generate_tts_wav(prompt, "tts_pro.wav", seed)
with open(f"tts_pro.wav", "rb") as audio:
bot.send_audio(
message.chat.id,
audio,
reply_to_message_id=message.message_id,
)
except Exception as e: except Exception as e:
print(e) print(e)
bot.reply_to(message, "tts error") bot.reply_to(message, "tts error")

View File

@ -9,7 +9,7 @@ from . import *
def md_handler(message: Message, bot: TeleBot): def md_handler(message: Message, bot: TeleBot):
"""pretty md: /md <address>""" """pretty md: /md <address>"""
who = "Markdown" who = ""
reply_id = bot_reply_first(message, who, bot) reply_id = bot_reply_first(message, who, bot)
bot_reply_markdown(reply_id, who, message.text.strip(), bot) bot_reply_markdown(reply_id, who, message.text.strip(), bot)

View File

@ -6,6 +6,7 @@ import requests
from telebot import TeleBot from telebot import TeleBot
from telebot.types import Message from telebot.types import Message
from telegramify_markdown import convert from telegramify_markdown import convert
from expiringdict import ExpiringDict
from . import * from . import *
@ -21,7 +22,8 @@ client = OpenAI(
) )
# Global history cache # Global history cache
yi_player_dict = {} yi_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
yi_pro_player_dict = ExpiringDict(max_len=1000, max_age_seconds=300)
def yi_handler(message: Message, bot: TeleBot) -> None: def yi_handler(message: Message, bot: TeleBot) -> None:
@ -95,12 +97,12 @@ def yi_pro_handler(message: Message, bot: TeleBot) -> None:
player_message = [] player_message = []
# restart will lose all TODO # restart will lose all TODO
if str(message.from_user.id) not in yi_player_dict: if str(message.from_user.id) not in yi_pro_player_dict:
yi_player_dict[str(message.from_user.id)] = ( yi_pro_player_dict[str(message.from_user.id)] = (
player_message # for the imuutable list player_message # for the imuutable list
) )
else: else:
player_message = yi_player_dict[str(message.from_user.id)] player_message = yi_pro_player_dict[str(message.from_user.id)]
if m.strip() == "clear": if m.strip() == "clear":
bot.reply_to( bot.reply_to(
message, message,

12
pdm.lock generated
View File

@ -5,7 +5,7 @@
groups = ["default"] groups = ["default"]
strategy = ["cross_platform", "inherit_metadata"] strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1" lock_version = "4.4.1"
content_hash = "sha256:01d3d4c572074457bb3e879a335e17c0f3597b57aca720da6ab82370872b86c2" content_hash = "sha256:92559499879e6b4ab7bb1fd673848820016f25b06e5b87f38dc95cdb15215dcb"
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
@ -587,6 +587,16 @@ files = [
{file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"},
] ]
[[package]]
name = "expiringdict"
version = "1.2.2"
summary = "Dictionary with auto-expiring values for caching purposes"
groups = ["default"]
files = [
{file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"},
{file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"},
]
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.14.0" version = "3.14.0"

View File

@ -18,5 +18,6 @@ dependencies = [
"together>=1.1.5", "together>=1.1.5",
"dify-client>=0.1.10", "dify-client>=0.1.10",
"chattts-fork>=0.0.1", "chattts-fork>=0.0.1",
"expiringdict>=1.2.2",
] ]
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -33,6 +33,7 @@ emoji==2.11.1
encodec==0.1.1 encodec==0.1.1
eval-type-backport==0.2.0 eval-type-backport==0.2.0
exceptiongroup==1.2.1; python_version < "3.11" exceptiongroup==1.2.1; python_version < "3.11"
expiringdict==1.2.2
filelock==3.14.0 filelock==3.14.0
fiona==1.9.6 fiona==1.9.6
fonttools==4.51.0 fonttools==4.51.0