From bc3ecdb9f8fac605660ca106d9ce331b36d989ab Mon Sep 17 00:00:00 2001 From: yihong0618 Date: Thu, 30 May 2024 19:25:53 +0800 Subject: [PATCH] fix: add seed Signed-off-by: yihong0618 --- .gitignore | 3 +- handlers/tts.py | 82 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 06d059f..79348ba 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,5 @@ nohup.out *.mp4 *.pdf .pdm-python -tts.wav \ No newline at end of file +tts.wav +tts_pro.wav \ No newline at end of file diff --git a/handlers/tts.py b/handlers/tts.py index 08514ba..7baf081 100644 --- a/handlers/tts.py +++ b/handlers/tts.py @@ -16,30 +16,49 @@ if USE_CHATTTS: chat.load_models() lock = threading.Lock() # Initialize a lock - def generate_tts_wav(prompt): - texts = [ - prompt, - ] - wavs = chat.infer(texts, use_decoder=True) - output_filename = "tts.wav" - audio_data = np.array( - wavs[0], dtype=np.float32 - ) # Ensure the data type is correct + def save_data_to_wav(filename, data): sample_rate = 24000 - # Normalize the audio data to 16-bit PCM range - audio_data = (audio_data * 32767).astype(np.int16) - # Open a .wav file to write into - with wave.open(output_filename, "w") as wf: + with wave.open(filename, "w") as wf: wf.setnchannels(1) # Mono channel wf.setsampwidth(2) # 2 bytes per sample wf.setframerate(sample_rate) - wf.writeframes(audio_data.tobytes()) + wf.writeframes(data.tobytes()) - print(f"Audio has been saved to {output_filename}") + def generate_tts_wav(prompt, seed=None): + texts = [ + prompt, + ] + if seed: + r = chat.sample_random_speaker(seed) + params_infer_code = { + "spk_emb": r, # add sampled speaker + "temperature": 0.3, # using custom temperature + "top_P": 0.7, # top P decode + "top_K": 20, # top K decode + } + wavs = chat.infer( + texts, use_decoder=True, params_infer_code=params_infer_code + ) + output_filename = "tts_pro.wav" + else: + wavs = chat.infer(texts, use_decoder=True) + output_filename = "tts.wav" + + audio_data = np.array( + wavs[0], dtype=np.float32 + ) # Ensure the data type is correct + # Normalize the audio data to 16-bit PCM range + audio_data = (audio_data * 32767).astype(np.int16) + save_data_to_wav(output_filename, audio_data) + + if seed: + print(f"Audio has been saved to {output_filename} with seed {seed}") + else: + print(f"Audio has been saved to {output_filename}") def tts_handler(message: Message, bot: TeleBot): - """pretty tts: /tts
""" + """pretty tts: /tts """ bot.reply_to( message, f"Generating ChatTTS may take some time please wait some time." ) @@ -59,6 +78,37 @@ if USE_CHATTTS: print(e) bot.reply_to(message, "tts error") + def tts_pro_handler(message: Message, bot: TeleBot): + """pretty tts_pro: /tts_pro ,""" + m = message.text.strip() + prompt = m.strip() + seed = prompt.split(",")[0] + bot.reply_to( + message, + f"Generating ChatTTS with seed: {seed} may take some time please wait some time.", + ) + if not seed.isdigit(): + bot.reply_to(message, "first argument must be a number") + return + prompt = prompt[len(str(seed)) + 1 :] + if len(prompt) > 150: + bot.reply_to(message, "prompt too long must length < 150") + return + try: + with lock: + generate_tts_wav(prompt, seed) + with open(f"tts_pro.wav", "rb") as audio: + bot.send_audio( + message.chat.id, audio, reply_to_message_id=message.message_id + ) + except Exception as e: + print(e) + bot.reply_to(message, "tts error") + def register(bot: TeleBot) -> None: bot.register_message_handler(tts_handler, commands=["tts"], pass_bot=True) bot.register_message_handler(tts_handler, regexp="^tts:", pass_bot=True) + bot.register_message_handler( + tts_pro_handler, commands=["tts_pro"], pass_bot=True + ) + bot.register_message_handler(tts_pro_handler, regexp="^tts_pro:", pass_bot=True)