first pull

This commit is contained in:
2026-02-03 01:20:00 +08:00
commit 0753da86a8
79 changed files with 7134 additions and 0 deletions
+31
View File
@@ -0,0 +1,31 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# from ide
.vscode
.idea
# Virtual environments
.venv
venv
# config
settings.json
# logs
log/*
# temp files
temp_files/*
# ai models
src/modules/asr_module/asr_models/*
# uv
uv.lock
+1
View File
@@ -0,0 +1 @@
3.11
+85
View File
@@ -0,0 +1,85 @@
### Yosuga_server
## 📊 Project Stats
![GitHub last commit](https://img.shields.io/github/last-commit/Misakityan/Yosuga_server)
![GitHub issues](https://img.shields.io/github/issues/Misakityan/Yosuga_server)
![GitHub stars](https://img.shields.io/github/stars/Misakityan/Yosuga_server?style=social)
欢迎访问本项目。
首先向你介绍一下Yosuga这个项目:
本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。[call me "Misaki" でいいよ]
之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。
本项目分为三个部分:
1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。
2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。
3. Yosuga_embedded:这是项目的拓展部分,使得Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。
**_本项目为Yosuga_server._**
本项目使用uv构建,基于python3.11.
本项目由YosugaServer发展而来,项目架构与代码有了相当大的改变。(YosugaServer并未开源,它仅仅是一次小小的尝试)
### 如何快速启动本项目?
1. 确保uv已安装,并添加到环境变量中
2. 执行`cd Yosuga_server` & `uv sync`
3. 接着,如果你的电脑带有cuda,那么执行 `uv pip install -r requirements-cuda.txt`
4. 如果没有cuda,那么执行 `uv pip install -r requirements-cpu.txt`
5. 最后执行 `uv run python main.py` 即可启动项目
首次启动项目后,会在项目根目录下生成settings.json配置文件,你需要配置一些必要的字段信息:
```json
{
"ai": {
"api_key": "sk-xxxxx",
"base_url": "http://localhost:1234/v1",
"model_name": "qwen/qwen3-4b-2507"
},
"tts": {
"gpt_model_name": "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt",
"sovits_model_name": "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth",
"host": "localhost",
"port": 20261,
"reference_audio": "./using/reference.wav"
},
"asr": {
"url": "http://localhost:20260/"
},
"auto_agent": {
"deployment_type": "lmstudio",
"model_name": "ui-tars-1.5-7b@q4_k_m",
"base_url": "http://localhost:1234/v1"
},
"llm_core": {
"role_character": "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。",
"max_context_tokens": 2048,
"language": "日本语"
}
}
```
上面这些字段的信息,你需要根据你的实际情况进行配置。实际的配置文件的字段名称会比上面的多出不少。
配置完成后,再次重启服务端就可以使用啦~
接着是每个模型的配置相关:
1. asr模型,本项目使用fast-whisper作为asr模型,并且附带了一键启动的部分
,你需要找到 `Yosuga_server/src/modules/asr_module/start_api.py` 这个文件,然后启动它
,一般来说,即使是cpu也可以进行asr模型的推理,但是速度相比cuda要逊色很多。
同时,如果你遇到了启动时长时间加载,那么此时你需要试着挂一下梯子,因为初次启动
会在Hugging Face上下载模型。
2. tts模型,本项目使用GPT-SoVITS作为tts模型,建议使用其V2Pro版本。
3. auto_agent模型,本项目使用的自动化操作识别的模型为字节跳动开源的
`ui-tars-1.5-7b@q4_k_m` 关于此模型的更多信息可以参考字节跳动的[开源链接](https://github.com/bytedance/UI-TARS)
,建议在LM Studio上进行部署,该模型十分轻量。
4. ai模型,该模型限制为大语言模型,没有限制,本项目支持市面上的所有大语言模型。
本项目当前并不完善,还有很多需要优化的地方,并且尚未接入Yosuga_embedded。
欢迎大家为本项目贡献代码。
+173
View File
@@ -0,0 +1,173 @@
import asyncio
from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode
from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer
import sounddevice as sd
test_text = "春の午後、公園のベンチに座って本を読んでいると、小さな子供が凧あげをしているのが目に入った。風に乗って凧が高く上がるたびに、彼の顔には真っすぐな笑顔が広がる。母親がそばで見守りながら、時折声をかけている。\
空は雲一つない青さで、桜の花びらが風に舞っている。遠くで犬の鳴き声が聞こえ、芝生の上では老夫婦がお茶を楽しんでいた。すべてがゆっくりと流れる時間の中で、自分の心も不思議と落ち着いてくる。\
ふと、子供の凧が木の枝に引っかかってしまった。少し焦る様子だったが、母親が助けてくれて、すぐにまた空に舞い上がった。失敗しても、誰かが助けてくれる。そんな当たり前のことに、今日は特別な温かさを感じた。\
日が傾き始める頃、私は本を閉じて家路についた。明日もきっと、誰かの笑顔があるだろう。"
async def test_tts():
# 创建客户端(推荐上下文管理器)
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
# 基础TTS调用
try:
audio = await client.tts(
text="あのさ、いやまあ、なんていうか...要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?",
ref_audio_path="uploaded_audio/test_voice.wav", # 服务器上的路径
text_lang="ja",
prompt_lang="ja",
media_type="wav",
prompt_text="もう!こんなところで何やってるんだよ!"
)
# 保存音频
audio.save("outputs/output.wav")
print(f"✅ TTS成功!音频大小: {len(audio.audio_data)} bytes")
except Exception as e:
print(f"❌ 错误: {e}")
async def test_model_change():
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
# 切换模型
print("🔄 切换GPT模型...")
await client.set_gpt_weights(
"GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt"
)
print("🔄 切换SoVITS模型...")
await client.set_sovits_weights(
"SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth"
)
async def stream_tts():
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
try:
# 使用最快模式流式输出
chunk_count = 0
async for chunk in await client.tts(
text="要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?",
ref_audio_path="uploaded_audio/test_voice.wav",
text_lang="ja",
prompt_lang="ja",
prompt_text="もう!こんなところで何やってるんだよ!",
streaming_mode=StreamingMode.FASTEST, # 模式3:快速流式
media_type="wav"
):
chunk_count += 1
print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes")
# 实时播放处理
# await play_audio_chunk(chunk.audio_data)
print(f"✅ 流式TTS完成!共{chunk_count}个音频块")
except Exception as e:
print(f"❌ 流式错误: {e}")
async def stream_tts_and_play(
text: str,
ref_audio_path: str,
text_lang: str = "zh",
prompt_lang: str = "zh",
streaming_mode: StreamingMode = StreamingMode.FASTEST
):
"""
实时流式TTS + 播放一体化
Args:
text: 要合成的文本
ref_audio_path: 参考音频路径
text_lang: 文本语言
prompt_lang: 提示语言
"""
# 创建音频播放器(缓冲区大小=5,平衡延迟和稳定性)
async with AsyncAudioPlayer(buffer_size=5) as player:
# 创建TTS客户端
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
try:
print(f"🎤 开始流式合成: {text[:30]}...")
print(f"🎯 流式模式: {streaming_mode.name}")
# 获取音频流(异步生成器)
audio_stream = await client.tts(
text=text,
ref_audio_path=ref_audio_path,
text_lang=text_lang,
prompt_lang=prompt_lang,
prompt_text="もう!こんなところで何やってるんだよ!",
streaming_mode=streaming_mode,
media_type="wav",
sample_steps=32,
top_k=5,
temperature=1.0
)
# 动态读取并播放
chunk_idx = 0
async for audio_chunk in audio_stream:
chunk_idx += 1
print(f"📥 收到音频块 #{chunk_idx}: {len(audio_chunk.audio_data):6d} bytes")
# 立即加入播放队列(非阻塞)
await player.add_chunk(audio_chunk.audio_data)
print(f"✅ 合成完成! 共接收 {chunk_idx} 个音频块")
# 等待播放完成(所有块播完)
await player.audio_queue.join()
print("🎵 播放完成!")
except Exception as e:
print(f"❌ 错误: {e}")
raise
async def test_japanese():
"""测试日语长文本流式播放"""
print("=" * 50)
print("🗾 日语流式TTS测试")
print("=" * 50)
await stream_tts_and_play(
text=test_text,
ref_audio_path="uploaded_audio/test_voice.wav",
text_lang="ja",
prompt_lang="ja",
streaming_mode=StreamingMode.FASTEST # 模式3:最快
)
async def batch_test():
"""批量处理示例"""
async with GPTSoVITSClient() as client:
texts = [
"你好,世界!",
"这是一个批量测试。",
"异步批量处理非常高效。"
]
results = await client.batch_tts(
texts=texts,
ref_audio_path="archive_jingyuan_1.wav",
text_lang="zh"
)
for i, audio in enumerate(results):
audio.save(f"output/batch_{i}.wav")
print(f"✅ 批量任务 {i + 1}/{len(results)} 完成")
if __name__ == "__main__":
# 检查音频设备
print("🔍 检查音频设备...")
print(sd.query_devices())
sd.default.device = (None, "pulse") # 使用PulseAudio
asyncio.run(test_japanese())
+24
View File
@@ -0,0 +1,24 @@
import asyncio
import json
from websockets.asyncio.client import connect
async def test_all_types():
"""测试三种消息类型"""
async with connect("ws://localhost:8765") as ws:
print("=== 测试JSON消息 ===")
await ws.send(json.dumps({
"type": "chat",
"content": "你好服务器!"
}))
print(f"收到: {await ws.recv()}")
print("\n=== 测试文本消息 ===")
await ws.send("这是纯文本消息")
print(f"收到: {await ws.recv()}")
print("\n=== 测试二进制消息 ===")
await ws.send(b"\x00\x01\x02\x03\x04")
print(f"收到: {await ws.recv()}")
if __name__ == "__main__":
asyncio.run(test_all_types())
+174
View File
@@ -0,0 +1,174 @@
"""
极简 WebSocket 测试服务器 - 修复版本
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Set
import websockets
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s'
)
class SimpleWebSocketServer:
def __init__(self, host="localhost", port=8765):
self.host = host
self.port = port
self.clients: Set = set()
async def handle_connection(self, websocket, path):
"""处理客户端连接"""
client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
self.clients.add(websocket)
logging.info(f"✅ 客户端连接: {client_id} (当前连接数: {len(self.clients)})")
try:
# 发送欢迎消息
welcome = {
"type": "connect",
"data": {
"message": "WebSocket 服务器连接成功",
"client_id": client_id,
"server_time": datetime.now().isoformat(),
"status": "connected"
},
"timestamp": int(datetime.now().timestamp() * 1000)
}
await websocket.send(json.dumps(welcome))
async for message in websocket:
await self.handle_message(websocket, client_id, message)
except websockets.exceptions.ConnectionClosed:
logging.info(f"❌ 客户端断开: {client_id}")
finally:
self.clients.discard(websocket)
logging.info(f"📊 剩余连接: {len(self.clients)}")
async def handle_message(self, websocket, client_id, message):
"""处理收到的消息"""
logging.info(f"📨 收到消息 from {client_id}: {message}")
try:
# 尝试解析为 JSON
data = json.loads(message)
msg_type = data.get("type", "unknown")
# 根据消息类型回复
if msg_type == "ping":
# 心跳响应
response = {
"type": "pong",
"data": {
"server_time": datetime.now().isoformat(),
"latency": "0ms"
},
"timestamp": int(datetime.now().timestamp() * 1000)
}
elif msg_type == "login":
# 登录响应
username = data.get("data", {}).get("username", "anonymous")
response = {
"type": "login_success",
"data": {
"user_id": f"user_{abs(hash(username)) % 10000}",
"username": username,
"status": "authenticated"
},
"timestamp": int(datetime.now().timestamp() * 1000)
}
elif msg_type == "chat":
# 聊天消息回应
msg_content = data.get("data", {}).get("message", "")
response = {
"type": "chat_response",
"data": {
"message": f"服务器收到: {msg_content}",
"sender": "server",
"received_at": datetime.now().isoformat()
},
"timestamp": int(datetime.now().timestamp() * 1000)
}
else:
# 默认回显
response = {
"type": "echo",
"data": {
"original": data.get("data", {}),
"original_type": msg_type,
"server_processed_at": datetime.now().isoformat()
},
"timestamp": int(datetime.now().timestamp() * 1000)
}
await websocket.send(json.dumps(response))
except json.JSONDecodeError:
# 不是 JSON,当作纯文本处理
response = {
"type": "text_echo",
"data": {
"original": message,
"note": "这是文本消息"
},
"timestamp": int(datetime.now().timestamp() * 1000)
}
await websocket.send(json.dumps(response))
async def start(self):
"""启动服务器"""
logging.info(f"🚀 启动 WebSocket 服务器: ws://{self.host}:{self.port}")
# 创建处理函数包装器(解决参数问题)
async def connection_handler(websocket, path):
await self.handle_connection(websocket, path)
# 启动服务器
server = await websockets.serve(
connection_handler,
self.host,
self.port,
ping_interval=None,
ping_timeout=None,
close_timeout=None,
max_size=10 * 1024 * 1024
)
logging.info("📌 服务器已启动,等待连接...")
logging.info("🛑 按 Ctrl+C 停止服务器")
# 保持服务器运行
try:
await asyncio.Future() # 永久运行
finally:
server.close()
await server.wait_closed()
logging.info("👋 服务器已关闭")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='极简 WebSocket 测试服务器')
parser.add_argument('--host', default='localhost', help='监听地址')
parser.add_argument('--port', type=int, default=8088, help='监听端口')
args = parser.parse_args()
server = SimpleWebSocketServer(args.host, args.port)
try:
asyncio.run(server.start())
except KeyboardInterrupt:
logging.info("👋 服务器被用户中断")
if __name__ == "__main__":
main()
+33
View File
@@ -0,0 +1,33 @@
# requestTest.py
import requests
from pathlib import Path
# 指定正确的 MIME 类型
url = "http://192.168.1.8:20260/transcribe"
audio_path = Path("test_files/z105300938.wav")
with open(audio_path, "rb") as f:
# 明确指定文件名和 MIME 类型
files = {
"file": (
audio_path.name, # 文件名
f, # 文件对象
"audio/wav" # MIME 类型
)
}
response = requests.post(url, files=files)
# 打印响应详情
print(f"状态码: {response.status_code}")
print(f"响应头: {response.headers.get('content-type')}")
# 检查响应是否成功
if response.status_code == 200:
result = response.json()
print(f"识别结果: {result['data']['text']}")
print(f"语言: {result['data']['language']}")
print(f"置信度: {result['data']['confidence']}")
print(f"处理时间: {result['data']['processing_time']}s")
else:
print(f"错误响应: {response.text}")
+14
View File
@@ -0,0 +1,14 @@
# 一个小Test, 展示设计的dtos模块与tts和asr的集成
from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO
from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer
from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode
from src.modules.asr_module.client.asr_client import create_asr_client
# with create_asr_client(base_url="http://192.168.1.5:20260") as client:
# # 转录文件
# result = client.transcribe_file("test_files/test.wav")
# print(f"识别结果: {result.data.text}")
# print(f"置信度: {result.data.confidence:.2f}")
# print(f"耗时: {result.data.processing_time:.3f}s")
+30
View File
@@ -0,0 +1,30 @@
from src.modules.websocket_base_module.dto.second_dtos import get_json_dto_instance
from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO
from src.modules.websocket_base_module.websocket_core.core_ws_server import get_ws_server
import asyncio
from loguru import logger
async def main():
# 获取WebSocket服务器单例
ws_server = await get_ws_server()
# 获取二级json分发器单例
json_dto = await get_json_dto_instance(ws_server)
# 创建DTO实例(自动注册接收函数)
audio_dto = AudioDataDTO(json_dto)
logger.info("所有DTO接收器已注册,等待客户端连接...")
# 启动服务器(阻塞)
try:
await ws_server.run("localhost", 8765)
except asyncio.CancelledError:
logger.info("服务器任务已取消,正在优雅退出...")
finally:
logger.info("服务器已停止")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\n✓ 服务器已手动终止(按 Ctrl+C)")
Binary file not shown.
Binary file not shown.
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+88
View File
@@ -0,0 +1,88 @@
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import UnifiedLLM, ModelConfig, ModelProvider, create_llm_client
from src.config.config import get_settings
from src.config.convert_env import EnvConverter
from src.config.file_config import DirectoryInitializer
EnvConverter().convert(backup_existing=True) # 若是首次启动则从env模板中生成env文件
DirectoryInitializer(get_settings()) # 初始化必要的目录(若不存在则创建)
def test1():
"""
测试常规调用
"""
# 配置模型
config = ModelConfig(
provider=ModelProvider.OPENAI,
model_name=get_settings().ai_model_name,
base_url=get_settings().ai_api_base_url,
api_key=get_settings().ai_api_key, # 从环境中取出相关的api_key
temperature=0.7,
max_tokens=2048
)
# 创建客户端
llm = UnifiedLLM(config)
# 发送消息
response = llm.chat([
{"role": "system", "content": "你是一个DeepSeek助手"},
{"role": "user", "content": "请介绍一下DeepSeek模型的特点"}
])
print(response.content)
def base_test2():
"""
测试流式响应
"""
# 使用快捷函数
deepseek_llm = create_llm_client(
provider="openai", # DeepSeek使用OpenAI兼容接口
model_name=get_settings().ai_model_name,
api_key=get_settings().ai_api_key,
base_url=get_settings().ai_api_base_url
)
# 流式聊天
messages = [
{"role": "user", "content": "用Python写一个快速排序算法"}
]
print("正在生成响应...")
for chunk in deepseek_llm.stream_chat(messages):
print(chunk.content, end="", flush=True)
def test_lm_studio():
"""测试本地 LM Studio 模型"""
print("=== 测试本地 LM Studio ===")
# 使用UnifiedLLM类
config = ModelConfig(
provider=ModelProvider.LM_STUDIO,
model_name="qwen/qwen3-4b-2507",
base_url="http://192.168.1.8:1234/v1",
api_key="", # LM Studio不需要API密钥,留空
temperature=0.7,
max_tokens=1024,
streaming=False # 启用流式响应
)
llm = UnifiedLLM(config)
# 发送消息
messages = [
{"role": "system", "content": "你是一个有用的助手"},
{"role": "user", "content": "用中文介绍一下自己"}
]
print("非流式响应:")
response = llm.chat(messages, streaming=False)
print(response.content)
print("\n流式响应:")
for chunk in llm.stream_chat(messages):
print(chunk.content, end="", flush=True)
if __name__ == "__main__":
test_lm_studio()
+66
View File
@@ -0,0 +1,66 @@
import asyncio
import base64
from pathlib import Path
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig
async def test_ui_tars_stream():
"""测试 UI-TARS 流式调用"""
# 创建客户端
config = UITarsClientConfig(
deployment_type="lmstudio",
base_url="http://192.168.1.8:1234/v1",
model_name="ui-tars-1.5-7b@q4_k_m",
temperature=0.1
)
client = UITarsClient(config)
# 使用工具方法编码
image_base64 = base64.b64encode(Path("test_files/Screenshot_test.png").read_bytes()).decode()
print(f"✅ 图片编码完成,长度: {len(image_base64)} 字符\n")
# 流式调用并实时打印
print("🤖 开始流式调用 UI-TARS...\n")
print("思考过程:\n")
import time
# 计算耗时
start_time = time.time()
full_response = ""
chunk_count = 0
full_response = await client.call_async("打开AK加速器", image_base64)
# 传入 base64 字符串
# for chunk in client.stream_async("我的桌面系统是KDE, 帮我打开设置", image_base64):
# chunk_count += 1
# content = chunk.content
#
# # 实时打印每个 chunk
# print(content, end="", flush=True)
#
# # 累积完整内容
# full_response += content
end_time = time.time()
print(f"\n\n耗时: {end_time - start_time:.2f}")
print(f"\n\n{'=' * 50}")
print(f"✅ 流式调用完成!共接收 {chunk_count} 个 chunk")
print(f"完整响应长度: {len(full_response)} 字符")
print("响应内容:\n")
print(full_response)
import pyautogui
def auto_click(x : int, y : int):
pyautogui.moveTo(x, y, duration=1.5)
pyautogui.click()
def auto_drag(x1 : int, y1 : int, x2 : int, y2 : int):
pyautogui.moveTo(x1, y1, duration=1.5)
pyautogui.dragTo(x2, y2, duration=1.5)
# 运行异步函数
if __name__ == "__main__":
asyncio.run(test_ui_tars_stream())
auto_click(173,48)
# auto_drag(56,39, 170,39)
+45
View File
@@ -0,0 +1,45 @@
import asyncio
from contextlib import suppress
from loguru import logger
from src.config.config import cfg
from datetime import datetime
from src.server_core.core import YosugaServerCore
def init():
"""
Yosuga_server 初始化
"""
# 初始化日志系统
logger.add(
f"{cfg.log_dir}/Yosuga_server-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log",
encoding="utf-8")
logger.info("Yosuga_server 启动")
logger.info(f"日志文件目录见: {cfg.log_dir} 目录")
async def main():
core = await YosugaServerCore.get_instance()
try:
await core.run()
except asyncio.CancelledError:
pass # 正常取消,不打印堆栈
finally:
# 清理未关闭的 aiohttp sessions
import aiohttp
pending = [t for t in asyncio.all_tasks()
if t is not asyncio.current_task()]
for task in pending:
task.cancel()
with suppress(asyncio.CancelledError):
await asyncio.gather(*pending, return_exceptions=True)
if __name__ == "__main__":
init()
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nYosuga服务器已停止喵~~~")
+24
View File
@@ -0,0 +1,24 @@
[project]
name = "yosuga-server"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"aiofiles>=25.1.0",
"aiohttp>=3.13.3",
"fastapi>=0.128.0",
"faster-whisper>=1.2.1",
"loguru>=0.7.3",
"openai>=2.16.0",
"pyautogui>=0.9.54",
"pydantic>=2.12.5",
"pydantic-settings>=2.12.0",
"python-multipart>=0.0.22",
"requests>=2.32.5",
"sounddevice>=0.5.5",
"soundfile>=0.13.1",
"tiktoken>=0.12.0",
"uvicorn>=0.40.0",
"websockets>=16.0",
]
+84
View File
@@ -0,0 +1,84 @@
--index-url https://download.pytorch.org/whl/cpu
aiofiles==25.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.13.3
aiosignal==1.4.0
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
attrs==25.4.0
av==16.1.0
certifi==2026.1.4
cffi==2.0.0
charset-normalizer==3.4.4
click==8.3.1
colorama==0.4.6
coloredlogs==15.0.1
ctranslate2==4.6.3
distro==1.9.0
fastapi==0.128.0
faster-whisper==1.2.1
filelock==3.20.3
flatbuffers==25.12.19
frozenlist==1.8.0
fsspec==2026.1.0
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==1.3.7
humanfriendly==10.0
idna==3.11
jinja2==3.1.6
jiter==0.13.0
loguru==0.7.3
markupsafe==2.1.5
mouseinfo==0.1.3
mpmath==1.3.0
multidict==6.7.1
networkx==3.6.1
numpy==2.4.2
onnxruntime==1.23.2
openai==2.16.0
packaging==26.0
pillow==12.1.0
propcache==0.4.1
protobuf==6.33.5
pyautogui==0.9.54
pycparser==3.0
pydantic==2.12.5
pydantic-core==2.41.5
pydantic-settings==2.12.0
pygetwindow==0.0.9
pymsgbox==2.0.1
pyperclip==1.11.0
pyreadline3==3.5.4
pyrect==0.2.0
pyscreeze==1.0.1
python-dotenv==1.2.1
python-multipart==0.0.22
pytweening==1.2.0
pyyaml==6.0.3
regex==2026.1.15
requests==2.32.5
setuptools==80.10.2
shellingham==1.5.4
sniffio==1.3.1
sounddevice==0.5.5
soundfile==0.13.1
starlette==0.50.0
sympy==1.14.0
tiktoken==0.12.0
tokenizers==0.22.2
torch==2.5.1+cpu
torchaudio==2.5.1+cpu
torchvision==0.20.1+cpu
tqdm==4.67.2
typer-slim==0.21.1
typing-extensions==4.15.0
typing-inspection==0.4.2
urllib3==2.6.3
uvicorn==0.40.0
websockets==16.0
win32-setctime==1.2.0
yarl==1.22.0
Binary file not shown.
+452
View File
@@ -0,0 +1,452 @@
"""
零初始化 JSON 配置管理
用法:from src.config import cfg # 直接访问,自动初始化
"""
import json
import threading
from pathlib import Path
from typing import Any, Dict, Optional, TypeVar
from dataclasses import dataclass, field, asdict
def _project_root() -> Path:
"""自动查找项目根目录"""
markers = ['pyproject.toml', 'settings.json', '.gitignore', 'main.py', '.python-version']
current = Path(__file__).resolve().parent.parent # src的父目录
for path in [current, *current.parents]:
if any((path / m).exists() for m in markers):
return path
if path == path.parent:
break
return current
# 配置定义
@dataclass
class AIConfig:
api_key: Optional[str] = "sk-xxxxx"
base_url: str = "http://localhost:1234/v1"
model_name: str = "qwen/qwen3-4b-2507"
timeout: int = 30
temperature: float = 0.4
max_tokens: int = 8192
@dataclass
class TTSConfig:
enabled: bool = True
api_key: Optional[str] = None
gpt_model_name: str = "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt"
sovits_model_name: str = "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth"
host: str = "localhost"
port: int = 20261
reference_audio: str = "./using/reference.wav"
streaming: bool = True
speed: float = 1.0
@dataclass
class ASRConfig:
enabled: bool = True
api_key: Optional[str] = None
model_name: str = "fast-whisper"
url: str = "http://localhost:20260/"
@dataclass
class AutoAgentConfig:
enabled: bool = True
api_key: Optional[str] = None
deployment_type: str = "lmstudio"
model_name: str = "ui-tars-1.5-7b@q4_k_m"
base_url: str = "http://localhost:1234/v1"
temperature: float = 0.1
max_tokens: int = 16384
@dataclass
class LLMConfig:
enabled: bool = True
role_character: str = "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。"
max_context_tokens: int = 2048
enable_history: bool = True
language: str = "日本语"
@dataclass
class PathsConfig:
temp: str = "./tmp/"
log: str = "./log/"
using: str = "./using/"
class AppConfig:
"""
应用主配置
新增配置分组:1) 上方新建 dataclass 2) 下方 __init__ 添加字段 3) 完成
"""
def __init__(
self,
version: str = "1.0.0",
debug: bool = False,
ai: Optional[AIConfig] = None,
tts: Optional[TTSConfig] = None,
asr: Optional[ASRConfig] = None,
auto_agent: Optional[AutoAgentConfig] = None,
llm_core: Optional[LLMConfig] = None,
paths: Optional[PathsConfig] = None,
_config_path: Optional[Path] = None,
**kwargs
):
# 基础字段
self.version = version
self.debug = debug
self.ai = ai if ai is not None else AIConfig()
self.tts = tts if tts is not None else TTSConfig()
self.asr = asr if asr is not None else ASRConfig()
self.auto_agent = auto_agent if auto_agent is not None else AutoAgentConfig()
self.llm_core = llm_core if llm_core is not None else LLMConfig()
self.paths = paths if paths is not None else PathsConfig()
# 内部状态(非 dataclass 字段,不会被序列化)
self._config_path = _config_path
self._lock = threading.RLock() # 普通属性,非 dataclass 字段
# 应用其他字段(用于从 JSON 加载时)
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
# 路径解析为绝对路径
if self._config_path:
root = self._config_path.parent
for field_name in ['temp', 'log', 'using']:
rel_path = getattr(self.paths, field_name)
if not Path(rel_path).is_absolute():
abs_path = (root / Path(rel_path)).resolve()
setattr(self.paths, field_name, str(abs_path) + '/')
# 便捷属性
@property
def temp_dir(self) -> Path:
return Path(self.paths.temp)
@property
def log_dir(self) -> Path:
return Path(self.paths.log)
@property
def using_dir(self) -> Path:
return Path(self.paths.using)
# 核心方法
def get(self, key: str, default: Any = None) -> Any:
"""
点号路径访问:cfg.get("ai.timeout") / cfg.get("tts.enabled")
"""
try:
keys = key.split('.')
value = self
for k in keys:
value = getattr(value, k) if not isinstance(value, dict) else value[k]
return value
except (AttributeError, KeyError):
return default
def set(self, key: str, value: Any, save: bool = True) -> 'AppConfig':
"""
点号路径设置,支持链式调用
cfg.set("ai.timeout", 60).set("debug", True)
"""
with self._lock:
keys = key.split('.')
target = self
for k in keys[:-1]:
target = getattr(target, k)
setattr(target, keys[-1], value)
if save:
self._save()
return self
def update(self, updates: Dict[str, Any], save: bool = True) -> 'AppConfig':
"""
批量更新:cfg.update({"ai": {"timeout": 60}, "debug": True})
"""
def deep_update(obj: Any, data: dict):
for k, v in data.items():
if hasattr(obj, k):
current = getattr(obj, k)
if isinstance(v, dict) and hasattr(current, '__dataclass_fields__'):
deep_update(current, v)
else:
setattr(obj, k, v)
with self._lock:
deep_update(self, updates)
if save:
self._save()
return self
def reload(self) -> 'AppConfig':
"""热重载配置"""
if self._config_path and self._config_path.exists():
with self._lock:
data = json.loads(self._config_path.read_text(encoding='utf-8'))
# 配置项名 -> dataclass 类的映射(和 _load 保持一致)
config_classes = {
'ai': AIConfig,
'tts': TTSConfig,
'asr': ASRConfig,
'auto_agent': AutoAgentConfig,
'llm_core': LLMConfig,
'paths': PathsConfig,
}
for k, v in data.items():
if hasattr(self, k) and not k.startswith('_'):
# 如果是配置项且是 dict,转换为 dataclass
if k in config_classes and isinstance(v, dict):
setattr(self, k, config_classes[k](**v))
else:
setattr(self, k, v)
print(f"配置重载: {self._config_path}")
return self
def to_dict(self) -> dict:
"""导出为字典(手动实现,排除内部属性)"""
result = {
'version': self.version,
'debug': self.debug,
'ai': asdict(self.ai) if hasattr(self.ai, '__dataclass_fields__') else self.ai,
'tts': asdict(self.tts) if hasattr(self.tts, '__dataclass_fields__') else self.tts,
'asr': asdict(self.asr) if hasattr(self.asr, '__dataclass_fields__') else self.asr,
'auto_agent': asdict(self.auto_agent) if hasattr(self.auto_agent, '__dataclass_fields__') else self.auto_agent,
'llm_core': asdict(self.llm_core) if hasattr(self.llm_core, '__dataclass_fields__') else self.llm_core,
'paths': asdict(self.paths) if hasattr(self.paths, '__dataclass_fields__') else self.paths,
}
return result
def _save(self) -> None:
"""保存到文件"""
if self._config_path:
with self._lock:
json_str = json.dumps(self.to_dict(), indent=2, ensure_ascii=False)
self._config_path.write_text(json_str, encoding='utf-8')
@classmethod
def _load(cls, path: Path) -> 'AppConfig':
"""从文件加载"""
data = json.loads(path.read_text(encoding='utf-8'))
# 配置项名 -> dataclass 类的映射
config_classes = {
'ai': AIConfig,
'tts': TTSConfig,
'asr': ASRConfig,
'auto_agent': AutoAgentConfig,
'llm_core': LLMConfig,
'paths': PathsConfig,
}
# 自动转换 dict 为对应 dataclass
for key, config_class in config_classes.items():
if key in data and isinstance(data[key], dict):
data[key] = config_class(**data[key])
return cls(_config_path=path, **data)
@classmethod
def _create_default(cls, path: Path) -> 'AppConfig':
"""创建默认配置"""
instance = cls(_config_path=path)
instance._save()
print(f"默认配置已创建: {path}")
return instance
def __repr__(self) -> str:
"""友好的打印格式"""
lines = ["AppConfig("]
for k, v in self.to_dict().items():
lines.append(f" {k}={v!r},")
lines.append(")")
return "\n".join(lines)
# 延迟初始化机制
_root: Path = _project_root()
_config_path: Path = _root / "settings.json"
_config_instance: Optional[AppConfig] = None
_init_lock: threading.Lock = threading.Lock()
def _ensure_initialized() -> AppConfig:
"""
确保配置已初始化(线程安全的延迟初始化)
"""
global _config_instance
if _config_instance is not None:
return _config_instance
with _init_lock:
if _config_instance is not None:
return _config_instance
# 自动加载或创建
if _config_path.exists():
_config_instance = AppConfig._load(_config_path)
print(f"配置加载: {_config_path}")
else:
_config_instance = AppConfig._create_default(_config_path)
# 确保目录存在
for d in [_config_instance.temp_dir,
_config_instance.log_dir,
_config_instance.using_dir]:
d.mkdir(parents=True, exist_ok=True)
return _config_instance
class _LazyConfig:
"""
配置代理类:拦截所有属性访问,第一次使用时自动初始化
"""
def __getattr__(self, name: str) -> Any:
"""拦截属性访问,延迟初始化"""
instance = _ensure_initialized()
return getattr(instance, name)
def __setattr__(self, name: str, value: Any) -> None:
"""拦截属性设置"""
# 特殊属性直接设置到代理对象本身
if name.startswith('_'):
super().__setattr__(name, value)
else:
instance = _ensure_initialized()
setattr(instance, name, value)
def __repr__(self) -> str:
instance = _ensure_initialized()
return repr(instance)
def __dir__(self) -> list:
"""支持 IDE 自动补全"""
instance = _ensure_initialized()
return dir(instance)
# 显式代理必要方法
def get(self, key: str, default: Any = None) -> Any:
return _ensure_initialized().get(key, default)
def set(self, key: str, value: Any, save: bool = True) -> AppConfig:
return _ensure_initialized().set(key, value, save)
def update(self, updates: Dict[str, Any], save: bool = True) -> AppConfig:
return _ensure_initialized().update(updates, save)
def reload(self) -> AppConfig:
return _ensure_initialized().reload()
def to_dict(self) -> dict:
return _ensure_initialized().to_dict()
def save(self) -> None:
_ensure_initialized()._save()
# 属性代理
@property
def ai(self) -> AIConfig:
return _ensure_initialized().ai
@property
def tts(self) -> TTSConfig:
return _ensure_initialized().tts
@property
def asr(self) -> ASRConfig:
return _ensure_initialized().asr
@property
def auto_agent(self) -> AutoAgentConfig:
return _ensure_initialized().auto_agent
@property
def llm_core(self) -> LLMConfig:
return _ensure_initialized().llm_core
@property
def paths(self) -> PathsConfig:
return _ensure_initialized().paths
@property
def temp_dir(self) -> Path:
return _ensure_initialized().temp_dir
@property
def log_dir(self) -> Path:
return _ensure_initialized().log_dir
@property
def using_dir(self) -> Path:
return _ensure_initialized().using_dir
# 全局配置对象:导入即用,自动初始化
cfg: AppConfig = _LazyConfig() # type: ignore
# 工具函数
def generate_example(path: Path = _root / "settings.example.json") -> None:
"""生成示例配置文件"""
example = {
"version": "1.0.0",
"debug": False,
"ai": {
"api_key": "sk-your-api-key",
"base_url": "https://api.deepseek.com",
"model": "deepseek-chat",
"timeout": 30
},
"tts": {
"enabled": True,
"model": "GPT_SoVITS",
"url": "http://localhost:12458/",
"reference_audio": "./using/reference.wav",
"streaming": True,
"speed": 1.0
},
"paths": {
"temp": "./tmp/",
"log": "./log/",
"data": "./data/"
}
}
path.write_text(json.dumps(example, indent=2, ensure_ascii=False), encoding='utf-8')
print(f"示例配置已生成: {path}")
# 测试代码
if __name__ == "__main__":
# 测试:直接访问,自动初始化
print("第一次访问 cfg.ai.model:")
print(f"{cfg.ai.model_name}")
print(f"\n配置详情:")
print(cfg)
print(f"\n测试修改:")
cfg.set("ai.timeout", 60)
print(f" ai.timeout = {cfg.ai.timeout}")
print(f"\n测试批量更新:")
cfg.update({"debug": True, "tts": {"speed": 1.5}})
print(f" debug = {cfg.debug}, tts.speed = {cfg.tts.speed}")
print(f"\n测试热重载:")
cfg.reload()
View File
View File
View File
+123
View File
@@ -0,0 +1,123 @@
# asr_module/api.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pathlib import Path
import tempfile
import time
from datetime import datetime
from loguru import logger
from src.modules.asr_module.asr_core.fast_whisper import create_asr, ASRConfig
# 初始化FastAPI应用
app = FastAPI(
title="Yosuga ASR API",
description="基于faster-whisper Turbo的高性能多语种语音转文本服务",
version="1.0.0"
)
# 全局单例ASR实例(延迟加载)
_asr_instance = None
def get_asr():
"""获取或创建ASR实例(单例)"""
global _asr_instance
if _asr_instance is None:
logger.info("🚀 初始化ASR服务...")
_asr_instance = create_asr(
ASRConfig(
model_name="deepdml/faster-whisper-large-v3-turbo-ct2",
device="auto",
compute_type="int8_float16",
cache_dir=Path("asr_models/faster_whisper_large_v3_ct2"),
beam_size=1, # 贪婪搜索,速度最快
vad_filter=True, # 过滤静音,节省30%时间
)
)
logger.info("✅ ASR服务初始化完成")
return _asr_instance
@app.on_event("startup")
async def startup_event():
"""应用启动时预加载模型"""
get_asr()
@app.on_event("shutdown")
async def shutdown_event():
"""应用关闭时清理资源"""
global _asr_instance
if _asr_instance:
_asr_instance.shutdown()
logger.info("🛑 ASR服务已关闭")
@app.post("/transcribe", response_class=JSONResponse)
async def transcribe_audio(
file: UploadFile = File(..., description="音频文件 (WAV, FLAC, MP3等格式)")
):
"""
语音转文本API
- **file**: 音频文件,支持WAV/FLAC/MP3等格式
- **返回**: JSON格式结果,包含text/language/confidence
"""
start_time = time.time()
# 验证文件类型
if file.content_type and not file.content_type.startswith("audio/"):
raise HTTPException(status_code=400, detail="❌ 请上传音频文件 (MIME类型: audio/*)")
try:
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file:
content = await file.read()
tmp_file.write(content)
tmp_path = Path(tmp_file.name)
logger.info(f"📥 接收文件: {file.filename} ({len(content)} bytes)")
# 调用ASR识别
asr = get_asr()
text, language, confidence = asr.transcribe_wav(tmp_path)
# 清理临时文件
tmp_path.unlink(missing_ok=True)
processing_time = time.time() - start_time
logger.info(f"✅ 识别完成: {language} | {len(text)}字符 | 置信度:{confidence:.2f} | 耗时:{processing_time:.3f}s")
return {
"success": True,
"data": {
"text": text,
"language": language,
"confidence": confidence,
"processing_time": round(processing_time, 3)
}
}
except Exception as e:
logger.error(f"❌ 识别失败: {e}")
raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}")
@app.get("/health")
async def health_check():
"""健康检查接口"""
asr = get_asr()
health = asr.health_check()
return {
"status": "healthy" if health["status"] == "healthy" else "unhealthy",
"timestamp": datetime.now().isoformat(),
"device": health["device"],
"model_loaded": health["model_loaded"]
}
@app.get("/")
async def root():
"""API根路径"""
return {
"message": "Yosuga ASR API 正在运行",
"docs": "/docs",
"health": "/health",
"transcribe": "/transcribe (POST)"
}
@@ -0,0 +1,16 @@
# fast_whisper/__init__.py
from typing import Optional
from src.modules.asr_module.asr_core.fast_whisper.config import ASRConfig
from src.modules.asr_module.asr_core.fast_whisper.model_manager import ModelManager
from src.modules.asr_module.asr_core.fast_whisper.asr_interface import ASRInterface
__version__ = "1.0.0"
__all__ = ["ASRConfig", "ModelManager", "ASRInterface"]
def create_asr(config: Optional[ASRConfig] = None) -> ASRInterface:
"""
快速创建ASR实例
Args:
config: ASR配置,若为None则使用默认配置
"""
return ASRInterface.get_instance(config)
@@ -0,0 +1,184 @@
# fast_whisper/asr_interface.py
from loguru import logger
from pathlib import Path
from typing import Tuple, Optional
import torchaudio
import torch
import numpy
from .model_manager import ModelManager
from .config import ASRConfig
from .utils import PerformanceProfiler
class ASRInterface:
"""
ASR接口类 - 全局单例
- 提供wav转文本功能
- 注入ModelManager
- 性能统计
"""
_instance: Optional['ASRInterface'] = None
def __new__(cls, *args, **kwargs):
"""单例模式实现"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: Optional[ASRConfig] = None):
# 防止重复初始化
if hasattr(self, '_initialized') and self._initialized:
return
self.config = config or ASRConfig()
self.model_manager = ModelManager(self.config)
self.profiler = PerformanceProfiler(self.config.enable_profiling)
# 音频参数
self.sample_rate = 16000
self._initialized = True
logger.info("🎤 ASR接口初始化完成")
@classmethod
def get_instance(cls, config: Optional[ASRConfig] = None) -> 'ASRInterface':
"""全局访问点"""
if cls._instance is None:
cls._instance = cls(config)
return cls._instance
def transcribe_wav(
self,
wav_path: Path,
language: Optional[str] = None
) -> Tuple[str, str, float]:
"""
WAV音频转文本(核心接口)
Args:
wav_path: WAV文件路径
language: 指定语言代码(如'zh'/'en'),None则自动检测
Returns:
(text, language, confidence)
"""
try:
# 记录开始时间
import time
start_time = time.time()
logger.info(f"🎵 开始识别: {wav_path.name}")
# 执行识别...
audio = self._load_audio(wav_path)
result = self._transcribe(audio, language)
text, lang, confidence = self._parse_result(result)
# 计算耗时
processing_time = time.time() - start_time
logger.info(
f"✅ 识别完成: {lang} | {len(text)}字符 | 置信度:{confidence:.2f} | "
f"耗时:{processing_time:.3f}s | RTF:{processing_time/(len(audio)/self.sample_rate):.3f}"
)
return text, lang, confidence
except Exception as e:
logger.error(f"❌ 识别失败 {wav_path}: {e}")
raise RuntimeError(f"Transcription failed: {e}")
def _load_audio(self, wav_path: Path) -> numpy.ndarray:
"""加载和预处理音频"""
if not wav_path.exists():
raise FileNotFoundError(f"音频文件不存在: {wav_path}")
# 加载音频
waveform, sample_rate = torchaudio.load(wav_path)
# 重采样到16kHz
if sample_rate != self.sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
waveform = resampler(waveform)
# 转换为单声道
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# 转换为numpy数组
audio = waveform.squeeze().numpy()
return audio
def _transcribe(self, audio: numpy.ndarray, language: Optional[str]) -> Tuple:
"""执行推理"""
model = self.model_manager.model
# 添加模型存在性检查
if model is None:
logger.error("ASR模型未加载,请检查模型配置和路径")
raise RuntimeError("ASR模型未加载,请检查模型配置和路径")
# 记录时间
import time
start_time = time.time()
# 调用模型
segments, info = model.transcribe(
audio,
language=language,
beam_size=self.config.beam_size,
best_of=self.config.best_of,
vad_filter=self.config.vad_filter,
)
# 立即执行生成器
segments_list = list(segments)
# 性能统计
inference_time = time.time() - start_time
audio_duration = len(audio) / self.sample_rate
self.profiler.record(audio_duration, inference_time)
return segments_list, info
def _parse_result(self, result: Tuple) -> Tuple[str, str, float]:
"""解析识别结果"""
segments, info = result
# 合并所有片段
text = " ".join([seg.text.strip() for seg in segments])
# 获取语言信息
language = info.language if info else "unknown"
confidence = info.language_probability if info else 0.0
return text, language, confidence
def transcribe_batch(self, wav_paths: list) -> list:
"""批量识别接口"""
return [
{
"file": str(path),
"text": result[0],
"language": result[1],
"confidence": result[2]
}
for path, result in zip(wav_paths, [
self.transcribe_wav(Path(p)) for p in wav_paths
])
]
def health_check(self) -> dict:
"""健康检查接口"""
return {
"status": "healthy" if self.model_manager.model else "unhealthy",
"device": self.config.device,
"model_loaded": self.model_manager.model is not None,
"device_info": self.model_manager.get_device_info(),
}
def shutdown(self):
"""优雅关闭"""
logger.info("🛑 关闭ASR接口...")
self.model_manager.unload()
@@ -0,0 +1,29 @@
# fast_whisper/config.py
from dataclasses import dataclass
from pathlib import Path
import torch
@dataclass
class ASRConfig:
"""ASR配置类"""
model_name: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
device: str = "auto"
compute_type: str = "int8_float16"
cache_dir: Path = Path.home() / ".cache" / "faster_whisper"
# 速度优化参数
beam_size: int = 1
best_of: int = 1
vad_filter: bool = True
batch_size: int = 16
# 性能统计
enable_profiling: bool = True
def __post_init__(self):
if self.device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if self.device == "cpu":
self.compute_type = "int8"
self.batch_size = 4
@@ -0,0 +1,92 @@
# fast_whisper/model_manager.py
import gc
from loguru import logger
from pathlib import Path
from typing import Optional
from faster_whisper import WhisperModel
import torch
from .config import ASRConfig
class ModelManager:
"""
模型管理类
- 负责模型生命周期管理
- 支持自定义缓存目录
- 自动硬件适配
"""
def __init__(self, config: ASRConfig):
self.config = config
self._model: Optional[WhisperModel] = None
self._device_info = None
# 确保缓存目录存在
self.config.cache_dir.mkdir(parents=True, exist_ok=True)
@property
def model(self) -> Optional[WhisperModel]:
"""懒加载模型"""
if self._model is None:
self._load_model()
return self._model
def _load_model(self):
"""加载模型"""
logger.info(f"🚀 初始化模型: {self.config.model_name}")
logger.info(f"📦 设备: {self.config.device}, 计算类型: {self.config.compute_type}")
try:
self._model = WhisperModel(
self.config.model_name,
device=self.config.device,
compute_type=self.config.compute_type,
download_root=str(self.config.cache_dir),
local_files_only=False,
)
self._device_info = {
"device": self.config.device,
"compute_type": self.config.compute_type,
"model_size": self.config.model_name.split("-")[-2]
}
logger.info("✅ 模型加载成功")
except Exception as e:
logger.error(f"❌ 模型加载失败: {e}")
raise RuntimeError(f"Failed to load ASR model: {e}")
def reload(self, new_config: ASRConfig):
"""热重载模型"""
logger.info("🔄 热重载模型...")
self.unload()
self.config = new_config
self._load_model()
def unload(self):
"""卸载模型释放资源"""
if self._model is not None:
logger.info("🗑️ 卸载模型...")
del self._model
self._model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.info("✅ 模型已卸载")
def get_device_info(self) -> dict:
"""获取设备信息"""
return self._device_info or {}
def __enter__(self):
"""上下文管理器支持"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""自动清理资源"""
self.unload()
@@ -0,0 +1,45 @@
# fast_whisper/utils.py
from loguru import logger
from datetime import datetime
from typing import Dict, Any
def check_hardware() -> Dict[str, Any]:
"""硬件检测"""
import torch
info = {
"cuda_available": torch.cuda.is_available(),
"device_name": "CPU",
"device_count": 0,
"compute_type": "int8"
}
if info["cuda_available"]:
info.update({
"device_name": torch.cuda.get_device_name(0),
"device_count": torch.cuda.device_count(),
"compute_type": "int8_float16"
})
return info
class PerformanceProfiler:
"""性能分析器"""
def __init__(self, enable: bool = True):
self.enable = enable
self.stats = []
def record(self, audio_duration: float, inference_time: float):
if not self.enable:
return
rtf = inference_time / audio_duration if audio_duration > 0 else 0
self.stats.append({
"timestamp": datetime.now().isoformat(),
"rtf": rtf,
"audio_duration": audio_duration,
"inference_time": inference_time
})
if len(self.stats) % 10 == 0:
avg_rtf = sum(s["rtf"] for s in self.stats[-10:]) / 10
logger.info(f"📊 最近10次平均RTF: {avg_rtf:.3f}")
+215
View File
@@ -0,0 +1,215 @@
# asr_module/client/asr_client.py
import asyncio
import time
from pathlib import Path
from typing import Union, Optional
import aiofiles
import aiohttp
import requests
from loguru import logger
from .models import ASRResponse, ASRHealthStatus, ServiceInfo
class ASRException(Exception):
"""ASR服务调用异常"""
def __init__(self, message: str, status_code: Optional[int] = None):
self.message = message
self.status_code = status_code
super().__init__(self.message)
class ASRClientConfig:
"""客户端配置"""
def __init__(
self,
base_url: str = "http://localhost:8000",
timeout: float = 30.0,
retry_count: int = 2,
retry_delay: float = 0.5,
):
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self.retry_count = retry_count
self.retry_delay = retry_delay
# 同步客户端
class ASRClientSync:
"""同步ASR客户端"""
def __init__(self, config: Optional[ASRClientConfig] = None):
self.config = config or ASRClientConfig()
self.session = requests.Session()
self.session.timeout = self.config.timeout
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.session.close()
def _request(self, method: str, endpoint: str, **kwargs) -> dict:
"""统一请求处理(带重试)"""
url = f"{self.config.base_url}{endpoint}"
for attempt in range(self.config.retry_count + 1):
try:
response = self.session.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if attempt < self.config.retry_count:
logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
time.sleep(self.config.retry_delay)
else:
logger.error(f"请求最终失败: {e}")
raise ASRException(f"API调用失败: {e}", getattr(e.response, 'status_code', None))
def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
"""
转录音频文件
Args:
file_path: 音频文件路径
Returns:
ASRResponse对象
Example:
client = ASRClientSync()
result = client.transcribe_file("/path/to/audio.wav")
print(result.data.text)
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
logger.info(f"📤 上传文件: {file_path.name}")
with open(file_path, 'rb') as f:
files = {'file': (file_path.name, f, 'audio/wav')}
result = self._request('POST', '/transcribe', files=files)
return ASRResponse(**result)
def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
"""
转录音频字节流
Args:
audio_data: 原始音频字节
filename: 模拟文件名(用于MIME类型推断)
Returns:
ASRResponse对象
Example:
with open('audio.wav', 'rb') as f:
audio_bytes = f.read()
result = client.transcribe_bytes(audio_bytes)
"""
logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
files = {'file': (filename, audio_data, 'audio/wav')}
result = self._request('POST', '/transcribe', files=files)
return ASRResponse(**result)
def health_check(self) -> ASRHealthStatus:
"""健康检查"""
result = self._request('GET', '/health')
return ASRHealthStatus(**result)
def get_service_info(self) -> ServiceInfo:
"""获取服务信息"""
result = self._request('GET', '/')
return ServiceInfo(**result)
# 异步客户端
class ASRClientAsync:
"""异步ASR客户端"""
def __init__(self, config: Optional[ASRClientConfig] = None):
self.config = config or ASRClientConfig()
self._session: Optional[aiohttp.ClientSession] = None
async def __aenter__(self):
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session:
await self._session.close()
async def _ensure_session(self):
if self._session is None:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
)
async def _request(self, method: str, endpoint: str, **kwargs) -> dict:
"""统一异步请求(带重试)"""
await self._ensure_session()
url = f"{self.config.base_url}{endpoint}"
for attempt in range(self.config.retry_count + 1):
try:
async with self._session.request(method, url, **kwargs) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientError as e:
if attempt < self.config.retry_count:
logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
await asyncio.sleep(self.config.retry_delay)
else:
logger.error(f"请求最终失败: {e}")
raise ASRException(f"API调用失败: {e}", getattr(e, 'status', None))
async def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
"""异步转录音频文件"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
logger.info(f"📤 上传文件: {file_path.name}")
async with aiofiles.open(file_path, 'rb') as f:
audio_data = await f.read()
return await self.transcribe_bytes(audio_data, file_path.name)
async def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
"""异步转录音频字节流"""
logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
await self._ensure_session() # 确保session已创建
form = aiohttp.FormData() # 创建表单数据
form.add_field('file', audio_data, filename=filename, content_type='audio/wav') # 添加文件字段
result = await self._request('POST', '/transcribe', data=form) # 发送POST请求
return ASRResponse(**result) # 返回结果
async def health_check(self) -> ASRHealthStatus:
"""异步健康检查"""
result = await self._request('GET', '/health')
return ASRHealthStatus(**result)
async def get_service_info(self) -> ServiceInfo:
"""异步获取服务信息"""
result = await self._request('GET', '/')
return ServiceInfo(**result)
# 工厂函数
def create_asr_client(use_async: bool = False, **config_kwargs) -> Union[ASRClientSync, ASRClientAsync]:
"""
创建客户端工厂函数
Args:
use_async: 是否创建异步客户端
**config_kwargs: ASRClientConfig参数
Returns:
同步或异步客户端实例
"""
config = ASRClientConfig(**config_kwargs)
if use_async:
return ASRClientAsync(config)
return ASRClientSync(config)
+30
View File
@@ -0,0 +1,30 @@
# asr_module/client/models.py
from pydantic import BaseModel
from typing import Optional
class ASRHealthStatus(BaseModel):
"""ASR服务健康状态"""
status: str # 状态
timestamp: str # 时间戳
device: str # 设备
model_loaded: bool # 模型是否加载
class ASRResult(BaseModel):
"""语音识别结果"""
text: str # 识别结果
language: str # 语言
confidence: float # 置信度
processing_time: float # 处理时间
class ASRResponse(BaseModel):
"""统一API响应"""
success: bool
data: Optional[ASRResult] = None
error: Optional[str] = None
class ServiceInfo(BaseModel):
"""服务信息"""
message: str # 消息
docs: str # 文档
health: str # 健康
transcribe: str # 识别
+12
View File
@@ -0,0 +1,12 @@
本模块为asr模块,即语音转文本模块。
本模块提供了两种访问方式:
- 第一种为本地部署的func call方式,即提供函数调用的方式调用相关asr接口。
- 第二种为http call方式,即提供http call方式调用相关asr接口。[FastAPI]
如果你的电脑显卡支持cuda, 并且显存大小大于8G, 那么可以使用第一种方式,
否则可以使用第二种,进行云端部署。
两种调用方式在相同配置下,性能几乎无差别。
个人建议使用第二种
+65
View File
@@ -0,0 +1,65 @@
# start_api.py
import uvicorn
from loguru import logger
import threading
import time
def start_server():
"""启动 ASR API 服务"""
uvicorn.run(
"api:app", # 模块名:app实例
host="0.0.0.0",
port=20260,
workers=1, # 单用户场景,1个worker足够
log_level="info",
reload=False, # 生产环境关闭热重载
access_log=True,
)
def first_test() -> None:
"""首次启动测试"""
time.sleep(5) # 给服务器一些启动时间
# 构造一个测试请求以验证初始化模型加载成功
logger.info("🚀 测试模型是否加载成功...")
import requests
from pathlib import Path
url = "http://localhost:20260/transcribe"
audio_path = Path("../../../Test/test_files/test.wav")
try:
with open(audio_path, "rb") as f:
# 明确指定文件名和 MIME 类型
files = {
"file": (
audio_path.name, # 文件名
f, # 文件对象
"audio/wav" # MIME 类型
)
}
response = requests.post(url, files=files)
logger.info(f"状态码: {response.status_code}")
logger.info(f"响应头: {response.headers.get('content-type')}")
if response.status_code == 200:
result = response.json()
logger.info(f"识别结果: {result['data']['text']}")
logger.info(f"识别语言: {result['data']['language']}")
logger.info(f"置信度: {result['data']['confidence']:.2f}")
logger.info(f"处理时间: {result['data']['processing_time']}s")
else:
logger.error(f"请求失败,错误响应信息: {response.text}")
logger.error("请检查模型是否正确加载或其他问题")
except Exception as e:
logger.error(f"测试过程中发生错误: {e}")
if __name__ == "__main__":
logger.info("🚀 启动 ASR API 服务...")
# 在后台线程启动服务器
server_thread = threading.Thread(target=start_server, daemon=True)
server_thread.start()
# 执行测试
first_test()
# 保持主线程运行
server_thread.join()
@@ -0,0 +1,119 @@
# ui_tars_/ui_tars_client.py
from typing import Optional
from loguru import logger
import asyncio
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
UnifiedLLM,
ModelConfig,
ModelProvider,
ChatMessage,
)
from pydantic import BaseModel, Field
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_prompts import UI_TARS_SYSTEM_PROMPT
class UITarsClientConfig(BaseModel):
"""UI-TARS 客户端配置"""
deployment_type: str = Field(default="lmstudio", description="部署类型")
base_url: str = Field(default="http://localhost:1234/v1", description="API地址")
model_name: str = Field(default="ui-tars", description="模型名称")
api_key: Optional[str] = Field(default=None, description="API密钥")
temperature: float = Field(default=0.1, ge=0.0, le=2.0)
max_tokens: int = Field(default=8192, ge=2048, le=128000)
timeout: int = Field(default=30, ge=5, le=300)
# UI-TARS-1.5 强制输出格式
system_prompt: str = Field(
default=UI_TARS_SYSTEM_PROMPT # 使用本项目自定义的输出格式约束
)
def to_model_config(self) -> ModelConfig:
"""转换为 UnifiedLLM 配置"""
# 映射部署类型到 ModelProvider
provider_map = {
"lmstudio": ModelProvider.LM_STUDIO,
"vllm": ModelProvider.CUSTOM,
"cloud": ModelProvider.OPENAI,
"ollama": ModelProvider.OLLAMA
}
provider = provider_map.get(self.deployment_type, ModelProvider.CUSTOM)
return ModelConfig(
provider=provider,
model_name=self.model_name,
api_key=self.api_key,
base_url=self.base_url,
temperature=self.temperature,
max_tokens=self.max_tokens,
timeout=self.timeout,
custom_headers={"User-Agent": "UI-TARS-Client/1.0"}
)
class UITarsClient:
"""
UI-TARS 通用客户端 (基于 UnifiedLLM)
图片相关信息请直接传入相应的base64
"""
def __init__(self, config: UITarsClientConfig):
self.config = config
# 复用 UnifiedLLM,自动处理所有部署类型
self.llm = UnifiedLLM(config.to_model_config())
logger.info(f"UI-TARS 客户端初始化: {config.deployment_type} @ {config.base_url}")
logger.info(f" 模型: {config.model_name} | 温度: {config.temperature}")
async def call_async(self, instruction: str, image_base64: str) -> str:
"""异步调用 UI-TARS"""
# 构建消息
messages = self._build_messages(instruction, image_base64)
try:
# 使用 UnifiedLLM 的异步接口
response = await asyncio.to_thread(
self.llm.chat,
messages=messages,
streaming=False
)
return response.content
except Exception as e:
logger.error(f"UI-TARS 调用失败: {e}")
raise
def call_sync(self, instruction: str, image_base64: str) -> str:
"""同步调用 UI-TARS"""
messages = self._build_messages(instruction, image_base64)
try:
response = self.llm.chat(
messages=messages,
streaming=False
)
return response.content
except Exception as e:
logger.error(f"UI-TARS 调用失败: {e}")
raise
def stream_async(self, instruction: str, image_base64: str):
"""流式调用 (异步生成器)"""
messages = self._build_messages(instruction, image_base64)
# UnifiedLLM 自动处理流式
return self.llm.stream_chat(messages=messages)
def _build_messages(self, instruction: str, image_base64: str) -> list:
"""构建 OpenAI 格式消息"""
return [
ChatMessage(
role="system",
content=self.config.system_prompt
),
ChatMessage(
role="user",
content=[
{"type": "text", "text": instruction},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
}
}
]
)
]
@@ -0,0 +1,25 @@
OFFICIAL_ACTION_SPACE = """## Action Space
click(point='<point>x1 y1</point>') - 单击坐标
left_double(point='<point>x1 y1</point>') - 双击坐标
right_single(point='<point>x1 y1</point>') - 右键单击
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>') - 拖拽
hotkey(key='ctrl c') - 快捷键(空格分隔,小写,最多3个键)
type(content='xxx') - 输入文本(用\\' \\\" \\n 转义)
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') - 滚动
wait() - 等待5秒
finished() - 任务完成
## Output Format
Thought: [你的推理过程]
Action: [选择一个动作]
"""
UI_TARS_SYSTEM_PROMPT = f"""You are UI-TARS-1.5, a GUI agent. Given a task and screenshot, output ONLY:
{OFFICIAL_ACTION_SPACE}
## Note
- Write a small plan and summarize the next action in one sentence in Thought.
- NEVER output multiple actions.
- x&y please in box center
"""
@@ -0,0 +1,10 @@
本模块为设备控制模块接口层,此处的设备控制指的是借助AI模型进行一些设备上的自动化
操作,支持`pc`, `android`,其他的未做过测试。
依赖:
`ui_tars`
当前所使用的AI模型为`mradermacher/UI-TARS-1.5-7B-GGUF
(Q6_K or Q4_K_M)`
未来如果有更快质量更高的AI模型本模块会为其添加支持。
@@ -0,0 +1,837 @@
"""
通用大语言模型调用框架
支持本地模型(Ollama, LM Studio, llama.cpp)和云端模型(OpenAI, Anthropic, Google, Azure等)
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, Any, Iterator
import json
import os
from loguru import logger
from dataclasses import dataclass, asdict, field
from enum import Enum
import httpx
from pydantic import BaseModel, Field
class ModelProvider(Enum):
"""支持的模型提供商枚举"""
OPENAI = "openai"
ANTHROPIC = "anthropic"
GOOGLE = "google"
AZURE = "azure"
OLLAMA = "ollama"
LM_STUDIO = "lm_studio"
LLAMA_CPP = "llama_cpp"
CUSTOM = "custom"
@dataclass
class ModelConfig:
"""模型配置类"""
provider: ModelProvider
model_name: str
api_key: Optional[str] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
temperature: float = 0.7
max_tokens: int = 1024
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
timeout: int = 30
streaming: bool = False
custom_headers: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict:
"""转换为字典"""
return asdict(self)
@dataclass
class ChatMessage:
"""聊天消息类"""
role: str # system, user, assistant
content: str
name: Optional[str] = None
def to_dict(self) -> Dict:
"""转换为字典"""
return {
"role": self.role,
"content": self.content,
**({"name": self.name} if self.name else {})
}
@dataclass
class ModelResponse:
"""模型响应类"""
content: str # 响应内容
model: str # 模型名称
usage: Optional[Dict[str, int]] = None # 使用量
finish_reason: Optional[str] = None # 结束原因
raw_response: Optional[Dict] = None # 原始响应
def normalize_usage(raw_usage: Optional[Dict[str, Any]], provider: ModelProvider) -> Optional[Dict[str, int]]:
"""
将不同平台的 usage 字段统一归一化为 OpenAI 标准格式
Args:
raw_usage: API 原始返回的 usage 数据
provider: 模型提供商枚举
Returns:
归一化后的 usage 字典,格式:
{
"prompt_tokens": int,
"completion_tokens": int,
"total_tokens": int
}
如果无法归一化则返回 None
"""
if not raw_usage:
return None
# 字段映射表:{provider: (input_key, output_key, total_key)}
USAGE_FIELD_MAP = {
ModelProvider.OPENAI: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.AZURE: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.LM_STUDIO: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.LLAMA_CPP: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.OLLAMA: ("prompt_eval_count", "eval_count", None), # Ollama 没有 total
ModelProvider.ANTHROPIC: ("input_tokens", "output_tokens", None),
ModelProvider.GOOGLE: ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"),
}
input_key, output_key, total_key = USAGE_FIELD_MAP.get(provider, (None, None, None))
if input_key is None:
logger.warning(f"未知的 provider '{provider}',无法归一化 usage")
return None
try:
# 提取字段值
prompt_tokens = raw_usage.get(input_key, 0)
completion_tokens = raw_usage.get(output_key, 0)
# 处理嵌套字典(如有些 API 的 usage 格式特殊)
if isinstance(prompt_tokens, dict):
prompt_tokens = prompt_tokens.get("value", 0)
if isinstance(completion_tokens, dict):
completion_tokens = completion_tokens.get("value", 0)
# 转换为整数
prompt_tokens = int(prompt_tokens) if prompt_tokens else 0
completion_tokens = int(completion_tokens) if completion_tokens else 0
# 计算 total(如果 API 没提供)
if total_key and total_key in raw_usage:
total_tokens = int(raw_usage[total_key])
else:
total_tokens = prompt_tokens + completion_tokens
normalized = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}
logger.debug(f"归一化 usage | {provider} -> OpenAI格式: {normalized}")
return normalized
except Exception as e:
logger.error(f"归一化 usage 失败: {e} | raw_usage: {raw_usage}")
return None
class BaseLLMClient(ABC):
"""大语言模型客户端基类"""
def __init__(self, config: ModelConfig):
self.config = config
self.client = None
self._initialize_client()
@abstractmethod
def _initialize_client(self):
"""初始化客户端"""
pass
@abstractmethod
def chat_completion(
self,
messages: List[Union[ChatMessage, Dict]],
**kwargs
) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""聊天补全"""
pass
@abstractmethod
def completion(
self,
prompt: str,
**kwargs
) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""文本补全"""
pass
def format_messages(self, messages: List[Union[ChatMessage, Dict]]) -> List[Dict]:
"""格式化消息列表"""
formatted = []
for msg in messages:
if isinstance(msg, ChatMessage):
formatted.append(msg.to_dict())
else:
formatted.append(msg)
return formatted
class OpenAIClient(BaseLLMClient):
"""OpenAI客户端"""
def _initialize_client(self):
try:
from openai import OpenAI
api_key = self.config.api_key
if not api_key:
raise ValueError("OpenAI API密钥未设置")
self.client = OpenAI(
api_key=api_key,
base_url=self.config.base_url,
timeout=self.config.timeout
)
logger.info(f"OpenAI客户端初始化成功,base_url: {self.config.base_url}")
except ImportError:
logger.error("请安装openai包: pip install openai")
raise
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
# 获取streaming参数,优先使用kwargs中的设置
streaming = kwargs.get("streaming", self.config.streaming)
# 合并配置
params = {
"model": self.config.model_name,
"messages": formatted_messages,
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
"frequency_penalty": kwargs.get("frequency_penalty", self.config.frequency_penalty),
"presence_penalty": kwargs.get("presence_penalty", self.config.presence_penalty),
"stream": streaming, # 使用正确的streaming设置
}
logger.info(f"🔧 调用参数: streaming={streaming}")
if streaming:
return self._stream_chat_completion(params)
else:
return self._normal_chat_completion(params)
def _normal_chat_completion(self, params):
"""非流式响应处理"""
logger.info("📡 发送非流式请求...")
response = self.client.chat.completions.create(**params)
raw_usage = response.usage
normalized_usage = normalize_usage(
raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
ModelProvider.OPENAI
)
return ModelResponse(
content=response.choices[0].message.content,
model=response.model,
usage=normalized_usage,
finish_reason=response.choices[0].finish_reason,
raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
)
def _stream_chat_completion(self, params):
"""流式响应处理"""
logger.info("📡 发送流式请求...")
response_stream = self.client.chat.completions.create(**params)
full_content = ""
for chunk in response_stream:
if chunk.choices[0].delta.content is not None:
content_chunk = chunk.choices[0].delta.content
full_content += content_chunk
yield ModelResponse(
content=content_chunk,
model=chunk.model,
raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
)
def completion(self, prompt, **kwargs):
# OpenAI 推荐使用 chat_completion,这里保持兼容
messages = [{"role": "user", "content": prompt}]
return self.chat_completion(messages, **kwargs)
class AnthropicClient(BaseLLMClient):
"""Anthropic Claude客户端"""
def _initialize_client(self):
try:
from anthropic import Anthropic
self.client = Anthropic(
api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
timeout=self.config.timeout
)
except ImportError:
logger.error("请安装anthropic包: pip install anthropic")
raise
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
# Claude 的消息格式转换
claude_messages = []
system_message = None
for msg in formatted_messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
claude_messages.append({
"role": msg["role"],
"content": msg["content"]
})
# 明确获取streaming参数
params = {
"model": self.config.model_name,
"messages": claude_messages,
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": kwargs.get("streaming", self.config.streaming),
}
if system_message:
params["system"] = system_message
if self.config.streaming:
return self._stream_chat_completion(params)
else:
return self._normal_chat_completion(params)
def _normal_chat_completion(self, params):
response = self.client.messages.create(**params)
# Anthropic 返回的usage格式和OpenAI不同,需要进行转换
raw_usage = response.usage
normalized_usage = normalize_usage(
raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
ModelProvider.ANTHROPIC
)
return ModelResponse(
content=response.content[0].text,
model=response.model,
usage=normalized_usage,
finish_reason=response.stop_reason,
raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
)
def _stream_chat_completion(self, params):
with self.client.messages.stream(**params) as stream:
for chunk in stream:
if chunk.type_ == "content_block_delta":
yield ModelResponse(
content=chunk.delta.text,
model=params["model"],
raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
)
def completion(self, prompt, **kwargs):
messages = [{"role": "user", "content": prompt}]
return self.chat_completion(messages, **kwargs)
class OllamaClient(BaseLLMClient):
"""Ollama本地模型客户端"""
def _initialize_client(self):
import httpx
self.base_url = self.config.base_url or "http://localhost:11434"
self.client = httpx.Client(
base_url=self.base_url,
timeout=self.config.timeout
)
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
payload = {
"model": self.config.model_name,
"messages": formatted_messages,
"options": {
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
"num_predict": kwargs.get("max_tokens", self.config.max_tokens),
},
"stream": kwargs.get("streaming", self.config.streaming),
}
if self.config.streaming:
return self._stream_chat_completion(payload)
else:
return self._normal_chat_completion(payload)
def _normal_chat_completion(self, payload):
response = self.client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
normalized_usage = normalize_usage(data, ModelProvider.OLLAMA)
return ModelResponse(
content=data["message"]["content"],
model=data["model"],
usage=normalized_usage,
finish_reason=data.get("done_reason"),
raw_response=data
)
def _stream_chat_completion(self, payload):
with self.client.stream("POST", "/api/chat", json=payload) as response:
for line in response.iter_lines():
if line.strip():
try:
data = json.loads(line)
if "message" in data and "content" in data["message"]:
yield ModelResponse(
content=data["message"]["content"],
model=data.get("model", self.config.model_name),
raw_response=data
)
except json.JSONDecodeError:
continue
def completion(self, prompt, **kwargs):
payload = {
"model": self.config.model_name,
"prompt": prompt,
"options": {
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
"num_predict": kwargs.get("max_tokens", self.config.max_tokens),
},
"stream": kwargs.get("streaming", self.config.streaming),
}
if self.config.streaming:
return self._stream_completion(payload)
else:
return self._normal_completion(payload)
def _normal_completion(self, payload):
response = self.client.post("/api/generate", json=payload)
response.raise_for_status()
data = response.json()
return ModelResponse(
content=data["response"],
model=data["model"],
usage={
"prompt_tokens": data.get("prompt_eval_count", 0),
"completion_tokens": data.get("eval_count", 0),
"total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
},
finish_reason=data.get("done_reason"),
raw_response=data
)
def _stream_completion(self, payload):
with self.client.stream("POST", "/api/generate", json=payload) as response:
for line in response.iter_lines():
if line.strip():
try:
data = json.loads(line)
if "response" in data:
yield ModelResponse(
content=data["response"],
model=data.get("model", self.config.model_name),
raw_response=data
)
except json.JSONDecodeError:
continue
class GenericLLMClient(BaseLLMClient):
"""通用HTTP客户端,支持LM Studio和其他兼容OpenAI API的本地模型"""
def _initialize_client(self):
import httpx
self.base_url = self.config.base_url or "http://localhost:1234/v1"
self.client = httpx.Client(
base_url=self.base_url,
timeout=self.config.timeout,
headers=self.config.custom_headers
)
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
# 明确获取 streaming 参数
streaming = kwargs.get("streaming", self.config.streaming)
payload = {
"model": self.config.model_name,
"messages": formatted_messages,
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": streaming, # 使用明确的 streaming 变量
}
logger.info(f"GenericLLMClient 参数: streaming={streaming}")
if streaming:
return self._stream_chat_completion(payload)
else:
return self._normal_chat_completion(payload)
def _normal_chat_completion(self, payload):
logger.info(f"GenericLLMClient 发送非流式请求到: {self.base_url}")
response = self.client.post("/chat/completions", json=payload)
response.raise_for_status()
data = response.json()
logger.info(f"GenericLLMClient 收到响应,模型: {data.get('model')}")
return ModelResponse(
content=data["choices"][0]["message"]["content"],
model=data["model"],
usage=data.get("usage"),
finish_reason=data["choices"][0].get("finish_reason"),
raw_response=data
)
def _stream_chat_completion(self, payload):
logger.info(f"GenericLLMClient 发送流式请求到: {self.base_url}")
with self.client.stream("POST", "/chat/completions", json=payload) as response:
for line in response.iter_lines():
if line.startswith("data: "):
chunk = line[6:]
if chunk == "[DONE]":
break
try:
data = json.loads(chunk)
if data["choices"][0]["delta"].get("content"):
yield ModelResponse(
content=data["choices"][0]["delta"]["content"],
model=data["model"],
raw_response=data
)
except json.JSONDecodeError:
continue
def completion(self, prompt, **kwargs):
payload = {
"model": self.config.model_name,
"prompt": prompt,
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": kwargs.get("streaming", self.config.streaming),
}
streaming = kwargs.get("streaming", self.config.streaming)
if streaming:
return self._stream_completion(payload)
else:
return self._normal_completion(payload)
def _normal_completion(self, payload):
response = self.client.post("/completions", json=payload)
response.raise_for_status()
data = response.json()
return ModelResponse(
content=data["choices"][0]["text"],
model=data["model"],
usage=data.get("usage"),
finish_reason=data["choices"][0].get("finish_reason"),
raw_response=data
)
def _stream_completion(self, payload):
with self.client.stream("POST", "/completions", json=payload) as response:
for line in response.iter_lines():
if line.startswith("data: "):
chunk = line[6:]
if chunk == "[DONE]":
break
try:
data = json.loads(chunk)
if data["choices"][0].get("text"):
yield ModelResponse(
content=data["choices"][0]["text"],
model=data["model"],
raw_response=data
)
except json.JSONDecodeError:
continue
class UnifiedLLM:
"""
统一的大语言模型调用类
支持多种模型提供商:
- 云端模型:OpenAI, Anthropic, Google, Azure
- 本地模型:Ollama, LM Studio, llama.cpp
使用示例:
```python
# 初始化OpenAI客户端
config = ModelConfig(
provider=ModelProvider.OPENAI,
model_name="gpt-4",
api_key="your-api-key"
)
llm = UnifiedLLM(config)
# 聊天补全
messages = [
{"role": "system", "content": "你是一个有用的助手"},
{"role": "user", "content": "你好!"}
]
response = llm.chat(messages)
print(response.content)
# 流式响应
config.streaming = True
llm.update_config(config)
for chunk in llm.chat(messages):
print(chunk.content, end="", flush=True)
```
"""
def __init__(self, config: ModelConfig):
"""
初始化统一LLM
Args:
config: 模型配置
"""
self.config = config
self.client = self._create_client()
logger.info(f" UnifiedLLM 初始化完成")
logger.info(f" 提供商: {config.provider}")
logger.info(f" 模型: {config.model_name}")
logger.info(f" 流式默认: {config.streaming}")
def _create_client(self) -> BaseLLMClient:
"""根据配置创建客户端"""
provider = self.config.provider
if provider == ModelProvider.OPENAI:
return OpenAIClient(self.config)
elif provider == ModelProvider.ANTHROPIC:
return AnthropicClient(self.config)
elif provider == ModelProvider.OLLAMA:
return OllamaClient(self.config)
elif provider in [ModelProvider.LM_STUDIO, ModelProvider.LLAMA_CPP, ModelProvider.CUSTOM]:
return GenericLLMClient(self.config)
elif provider == ModelProvider.GOOGLE:
# 这里可以扩展Google Gemini支持
return GenericLLMClient(self.config)
elif provider == ModelProvider.AZURE:
# Azure OpenAI需要特殊处理
return GenericLLMClient(self.config)
else:
raise ValueError(f"不支持的模型提供商: {provider}")
def update_config(self, config: ModelConfig):
"""更新配置并重新创建客户端"""
self.config = config
self.client = self._create_client()
logger.info(f"UnifiedLLM 配置已更新")
def chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""
聊天补全
Args:
messages: 消息列表
**kwargs: 其他参数,会覆盖config中的设置
Returns:
ModelResponse 或 ModelResponse 的迭代器(流式模式)
"""
# 明确获取streaming参数
streaming = kwargs.get("streaming", self.config.streaming)
logger.info(f" UnifiedLLM.chat() 调用")
logger.info(f" 消息数: {len(messages)}")
logger.info(f" streaming参数: {streaming}")
# 调用客户端
result = self.client.chat_completion(messages, **kwargs)
# 类型检查(调试用)
if streaming:
if not hasattr(result, '__iter__'):
logger.warning(f"警告: streaming=True 但返回的不是迭代器")
else:
if hasattr(result, '__iter__'):
logger.warning(f"警告: streaming=False 但返回的是迭代器")
elif not isinstance(result, ModelResponse):
logger.warning(f"警告: streaming=False 但返回的不是ModelResponse")
return result
def complete(self, prompt: str, **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""
文本补全
Args:
prompt: 提示文本
**kwargs: 其他参数
Returns:
ModelResponse 或 ModelResponse 的迭代器(流式模式)
"""
streaming = kwargs.get("streaming", self.config.streaming)
logger.info(f" UnifiedLLM.complete() 调用")
logger.info(f" prompt长度: {len(prompt)}")
logger.info(f" streaming参数: {streaming}")
return self.client.completion(prompt, **kwargs)
def stream_chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Iterator[ModelResponse]:
"""
流式聊天补全(便捷方法)
Args:
messages: 消息列表
**kwargs: 其他参数
Returns:
ModelResponse 的迭代器
"""
kwargs["streaming"] = True
logger.info(f"UnifiedLLM.stream_chat() 调用")
result = self.chat(messages, **kwargs)
# 确保返回的是迭代器
if not hasattr(result, '__iter__'):
raise TypeError("stream_chat 应该返回迭代器,但返回了其他类型")
return result
def stream_complete(self, prompt: str, **kwargs) -> Iterator[ModelResponse]:
"""
流式文本补全(便捷方法)
Args:
prompt: 提示文本
**kwargs: 其他参数
Returns:
ModelResponse 的迭代器
"""
kwargs["streaming"] = True
logger.info(f"UnifiedLLM.stream_complete() 调用")
result = self.complete(prompt, **kwargs)
# 确保返回的是迭代器
if not hasattr(result, '__iter__'):
raise TypeError("stream_complete 应该返回迭代器,但返回了其他类型")
return result
# 快捷函数
def create_llm_client(
provider: Union[str, ModelProvider],
model_name: str,
**kwargs
) -> UnifiedLLM:
"""
快捷创建LLM客户端
Args:
provider: 提供商名称或枚举
model_name: 模型名称
**kwargs: 其他配置参数
Returns:
UnifiedLLM 实例
"""
if isinstance(provider, str):
provider = ModelProvider(provider.lower())
config = ModelConfig(
provider=provider,
model_name=model_name,
**kwargs
)
return UnifiedLLM(config)
# 使用示例
def example_usage():
"""使用示例"""
# 示例1: 使用OpenAI
print("示例1: 使用OpenAI")
openai_config = ModelConfig(
provider=ModelProvider.OPENAI,
model_name="gpt-3.5-turbo",
api_key=os.getenv("OPENAI_API_KEY"),
temperature=0.8
)
try:
openai_llm = UnifiedLLM(openai_config)
messages = [
ChatMessage(role="system", content="你是一个有用的助手"),
ChatMessage(role="user", content="请用Python写一个Hello World程序")
]
response = openai_llm.chat(messages)
print(f"响应: {response.content[:100]}...")
except Exception as e:
print(f"OpenAI示例错误: {e}")
# 示例2: 使用Ollama(本地模型)
print("\n示例2: 使用Ollama(本地模型)")
ollama_config = ModelConfig(
provider=ModelProvider.OLLAMA,
model_name="llama2",
base_url="http://localhost:11434",
temperature=0.7,
streaming=True # 流式响应
)
try:
ollama_llm = UnifiedLLM(ollama_config)
messages = [
{"role": "user", "content": "什么是人工智能?"}
]
print("流式响应:")
for chunk in ollama_llm.stream_chat(messages):
print(chunk.content, end="", flush=True)
print()
except Exception as e:
print(f"Ollama示例错误: {e}(请确保Ollama服务正在运行)")
# 示例3: 使用快捷函数
print("\n示例3: 使用快捷函数")
try:
llm = create_llm_client(
provider="openai",
model_name="gpt-3.5-turbo",
api_key=os.getenv("OPENAI_API_KEY"),
temperature=0.5
)
response = llm.complete("天空为什么是蓝色的?")
print(f"响应: {response.content[:100]}...")
except Exception as e:
print(f"快捷函数示例错误: {e}")
@@ -0,0 +1,162 @@
# 大语言模型调用框架架构图
general_text_ai_req.py
```mermaid
graph TB
subgraph "应用层"
A[用户应用] --> B[UnifiedLLM 统一接口]
end
subgraph "适配器层"
B --> C{模型提供商路由}
C --> D[OpenAI 适配器]
C --> E[Anthropic 适配器]
C --> F[Ollama 适配器]
C --> G[通用HTTP适配器]
C --> H[其他适配器]
end
subgraph "服务层"
D --> I[OpenAI API]
E --> J[Anthropic API]
F --> K[Ollama 服务]
G --> L[LM Studio]
G --> M[llama.cpp]
G --> N[其他兼容API]
end
subgraph "配置层"
O[ModelConfig] --> C
O --> D
O --> E
O --> F
O --> G
end
subgraph "数据流"
P[输入: 消息/提示] --> B
I --> Q[输出: ModelResponse]
J --> Q
K --> Q
L --> Q
M --> Q
N --> Q
end
style A fill:#4567f1
style B fill:#4567f1
style O fill:#456748
style D fill:#457911
style E fill:#466bd5
style F fill:#4567f1
style G fill:#4567f1
```
# 数据流图
```mermaid
sequenceDiagram
participant User as 用户/应用
participant UnifiedLLM as UnifiedLLM
participant Adapter as 适配器
participant API as API服务
User->>UnifiedLLM: 调用chat()或complete()
UnifiedLLM->>UnifiedLLM: 根据配置选择适配器
UnifiedLLM->>Adapter: 转发请求
Adapter->>API: 发送HTTP请求/API调用
Note over API: 处理请求并生成响应
alt 流式模式
API-->>Adapter: 流式响应数据
Adapter-->>UnifiedLLM: 流式ModelResponse
UnifiedLLM-->>User: 迭代器返回分块响应
else 非流式模式
API-->>Adapter: 完整响应
Adapter-->>UnifiedLLM: ModelResponse对象
UnifiedLLM-->>User: 完整响应内容
end
```
# 类关系图
```mermaid
classDiagram
class ModelConfig {
+provider: ModelProvider
+model_name: str
+api_key: Optional[str]
+base_url: Optional[str]
+temperature: float
+max_tokens: int
+to_dict() Dict
}
class ChatMessage {
+role: str
+content: str
+name: Optional[str]
+to_dict() Dict
}
class ModelResponse {
+content: str
+model: str
+usage: Optional[Dict]
+finish_reason: Optional[str]
+raw_response: Optional[Dict]
}
class BaseLLMClient {
<<abstract>>
#config: ModelConfig
#client: Any
+__init__(config: ModelConfig)
+_initialize_client()
+chat_completion(messages, **kwargs)*
+completion(prompt, **kwargs)*
+format_messages(messages) List[Dict]
}
class UnifiedLLM {
-config: ModelConfig
-client: BaseLLMClient
+__init__(config: ModelConfig)
+_create_client() BaseLLMClient
+update_config(config: ModelConfig)
+chat(messages, **kwargs) ModelResponse
+complete(prompt, **kwargs) ModelResponse
+stream_chat(messages, **kwargs) Iterator
+stream_complete(prompt, **kwargs) Iterator
}
class OpenAIClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
class AnthropicClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
class OllamaClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
class GenericLLMClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
BaseLLMClient <|-- OpenAIClient
BaseLLMClient <|-- AnthropicClient
BaseLLMClient <|-- OllamaClient
BaseLLMClient <|-- GenericLLMClient
UnifiedLLM o-- BaseLLMClient
UnifiedLLM --> ModelConfig
BaseLLMClient --> ModelConfig
BaseLLMClient --> ChatMessage
BaseLLMClient --> ModelResponse
```
View File
+26
View File
@@ -0,0 +1,26 @@
本模块为tts模块,即文本转语音模块。
本模块负责将来自AI的回复转为语音。
说明:
在本module当中,每个子模块的用途分别是:
- tts_core 对不同的tts的实现,提供相对统一的接口
- gpt_sovits
实现了gpt_sovits的tts接口封装
async_audio_player.py
```mermaid
sequenceDiagram
participant TTS as GPT-SoVITS API
participant WS as WebSocket服务
participant Buffer as 音频缓冲区
participant Player as 音频播放器
TTS->>WS: 流式音频块(chunks)
WS->>Buffer: 写入队列(Queue)
Buffer->>Player: 消费PCM数据
Player->>声卡: 实时播放
Note over TTS,Player: 三重缓冲 + 动态采样率检测
```
@@ -0,0 +1,164 @@
# tts_core/async_audio_player.py
import asyncio
import io
from loguru import logger
from typing import Optional
import numpy as np
import sounddevice as sd
import wave
class AsyncAudioPlayer:
"""
异步流式音频播放器
- 自动检测WAV头并解析采样率
- 使用环形缓冲区确保播放流畅
- 支持动态音频格式切换
"""
def __init__(self, buffer_size: int = 10):
"""
Args:
buffer_size: 音频块缓冲数量(越大越稳定,但延迟越高)
"""
self.audio_queue = asyncio.Queue(maxsize=buffer_size)
self.sample_rate = 32000 # 默认采样率
self.channels = 1
self.dtype = np.float32
self.stream: Optional[sd.OutputStream] = None
self.is_playing = False
self._first_chunk_processed = False
logger.info(f"🎵 音频播放器初始化,缓冲区大小: {buffer_size}")
async def add_chunk(self, audio_data: bytes):
"""
添加音频块到播放队列
自动处理第一个chunk(包含WAV头)
"""
try:
# 第一个chunk需要解析WAV头
if not self._first_chunk_processed:
# 写入BytesIO以便wave模块读取
wav_buffer = io.BytesIO(audio_data)
try:
with wave.open(wav_buffer, 'rb') as wav_file:
# 解析WAV头信息
self.sample_rate = wav_file.getframerate()
self.channels = wav_file.getnchannels()
self.sampwidth = wav_file.getsampwidth()
# 读取PCM数据(去掉头部)
pcm_data = wav_file.readframes(wav_file.getnframes())
logger.info(f"📊 解析WAV头: {self.sample_rate}Hz, {self.channels}ch, {self.sampwidth * 8}bit")
# 转换为numpy数组
if self.sampwidth == 2:
audio_array = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
elif self.sampwidth == 4:
audio_array = np.frombuffer(pcm_data, dtype=np.int32).astype(np.float32) / 2147483648.0
else:
raise ValueError(f"不支持的采样宽度: {self.sampwidth}")
# 转单声道(如果多声道)
if self.channels > 1:
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
await self.audio_queue.put(audio_array)
self._first_chunk_processed = True
except wave.Error:
# 可能是不完整的WAV头,尝试直接播放
logger.warning("⚠️ WAV头解析失败,尝试直接播放")
await self._play_raw(audio_data)
return
else:
# 后续chunk直接播放(RAW PCM
await self._play_raw(audio_data)
except Exception as e:
logger.error(f"❌ 音频块处理失败: {e}")
async def _play_raw(self, audio_data: bytes):
"""播放RAW PCM数据"""
try:
# 假设是16位PCM(最常见)
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
# 如果是多声道数据(罕见)
if len(audio_array) % self.channels == 0 and self.channels > 1:
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
await self.audio_queue.put(audio_array)
except Exception as e:
logger.error(f"❌ RAW音频处理失败: {e}")
async def play_worker(self):
"""后台播放任务"""
logger.info("🎧 音频播放任务启动")
while self.is_playing or not self.audio_queue.empty():
try:
# 从队列获取音频块(最多等待0.5秒)
audio_chunk = await asyncio.wait_for(self.audio_queue.get(), timeout=0.5)
# 延迟初始化音频流(直到获得第一个数据块)
if self.stream is None:
logger.info(f"🔊 打开音频输出流: {self.sample_rate}Hz")
self.stream = sd.OutputStream(
samplerate=self.sample_rate,
channels=1,
dtype=self.dtype,
blocksize=1024, # 低延迟模式
latency='low'
)
self.stream.start()
# 写入音频流播放
self.stream.write(audio_chunk)
# 标记任务完成
self.audio_queue.task_done()
except asyncio.TimeoutError:
continue
except Exception as e:
logger.error(f"❌ 播放任务异常: {e}")
break
logger.info("🛑 音频播放任务结束")
async def start(self):
"""启动播放系统"""
self.is_playing = True
self._first_chunk_processed = False
self.play_task = asyncio.create_task(self.play_worker())
async def stop(self):
"""停止播放并清理资源"""
self.is_playing = False
# 等待播放任务结束
if hasattr(self, 'play_task'):
await self.play_task
# 关闭音频流
if self.stream is not None:
self.stream.stop()
self.stream.close()
self.stream = None
# 清空队列
while not self.audio_queue.empty():
try:
self.audio_queue.get_nowait()
except:
break
logger.info("✅ 音频播放已停止")
async def __aenter__(self):
await self.start()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.stop()
@@ -0,0 +1,378 @@
# gpt_sovits/gpt_sovits_client.py
import asyncio
import json
from loguru import logger
from enum import Enum
from pathlib import Path
from typing import AsyncGenerator, Optional, Union, Dict, Any
from dataclasses import dataclass
import httpx
from pydantic import BaseModel, Field, validator
class APIError(Exception):
"""API调用异常"""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f"API Error {status_code}: {message}")
class StreamingMode(Enum):
"""流式模式枚举"""
DISABLED = 0 # 非流式
BEST_QUALITY = 1 # 最佳质量(慢)
MEDIUM_QUALITY = 2 # 中等质量
FASTEST = 3 # 最快响应(较低质量)
class TTSConfig(BaseModel):
"""TTS请求配置模型"""
text: str = Field(..., description="待合成文本")
text_lang: str = Field(..., description="文本语言: zh/en/ja/ko/cantonese")
ref_audio_path: str = Field(..., description="参考音频路径")
prompt_lang: str = Field(..., description="提示文本语言")
# 可选参数
prompt_text: str = Field(default="", description="参考音频提示文本")
aux_ref_audio_paths: list = Field(default_factory=list, description="辅助参考音频")
top_k: int = Field(default=5, ge=1, le=100, description="Top-K采样")
top_p: float = Field(default=1.0, ge=0.1, le=1.0, description="Top-P采样")
temperature: float = Field(default=1.0, ge=0.1, le=1.0, description="采样温度")
text_split_method: str = Field(default="cut5", description="文本分割方法") # 默认按照标点符号切分
batch_size: int = Field(default=8, ge=1, le=200, description="批处理大小")
speed_factor: float = Field(default=1.0, ge=0.6, le=1.65, description="语速倍率")
# 流式相关
streaming_mode: Union[bool, int, StreamingMode] = Field(default=False, description="流式模式")
media_type: str = Field(default="wav", description="输出格式: wav/raw/ogg/aac") # 输出格式
# 高级参数
repetition_penalty: float = Field(default=1.35, ge=1.0, le=2.0) # 惩罚参数
sample_steps: int = Field(default=32, ge=10, le=100) # 采样步数
parallel_infer: bool = Field(default=True) # 并行推理
@validator('text_lang', 'prompt_lang')
def validate_language(cls, v):
"""验证语言代码"""
valid_langs = {'zh', 'en', 'ja', 'ko', 'cantonese'}
if v.lower() not in valid_langs:
raise ValueError(f"Unsupported language: {v}. Must be one of {valid_langs}")
return v.lower()
@validator('media_type')
def validate_media_type(cls, v):
"""验证媒体类型"""
valid_types = {'wav', 'raw', 'ogg', 'aac'}
if v not in valid_types:
raise ValueError(f"Unsupported media_type: {v}")
return v
def build_request(self) -> Dict[str, Any]:
"""构建API请求数据"""
data = self.dict(exclude_none=True)
# 处理流式模式
if isinstance(self.streaming_mode, StreamingMode):
data['streaming_mode'] = self.streaming_mode.value
return data
@dataclass
class AudioResponse:
"""音频响应包装类"""
audio_data: bytes
sample_rate: int = 32000
def save(self, path: Union[str, Path]) -> None:
"""保存音频文件"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, 'wb') as f:
f.write(self.audio_data)
logger.info(f"Audio saved to {path}, size: {len(self.audio_data)} bytes")
class GPTSoVITSClient:
"""
GPT-SoVITS异步API客户端
完整支持所有TTS功能:
- 文本合成(流式/非流式)
- 模型切换(GPT/SoVITS
- 参考音频设置
- 服务器控制
"""
def __init__(self, host: str = "127.0.0.1", port: int = 9880, debug: bool = False):
"""
初始化客户端
Args:
host: API服务器地址
port: API端口
debug: 是否开启调试模式
"""
self.base_url = f"http://{host}:{port}"
self.client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(30.0, connect=5.0)
)
self.debug_mode = debug
logger.info(f"GPT-SoVITS Client initialized: {self.base_url}")
async def __aenter__(self) -> "GPTSoVITSClient":
"""异步上下文管理器入口"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.close()
async def close(self):
"""关闭HTTP连接"""
await self.client.aclose()
logger.info("Client connection closed")
def _log_debug(self, message: str, **kwargs):
"""调试日志"""
if self.debug_mode:
logger.debug(f"{message} | {kwargs}")
async def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
"""统一响应处理"""
if response.status_code == 200:
content_type = response.headers.get('content-type', '')
if 'application/json' in content_type:
return response.json()
return {"status": "success", "content": response.content}
else:
try:
error_data = response.json()
raise APIError(response.status_code, error_data.get('message', 'Unknown error'))
except json.JSONDecodeError:
raise APIError(response.status_code, response.text)
# 核心TTS接口
async def tts(
self,
text: str,
ref_audio_path: str,
text_lang: str = "zh",
prompt_lang: str = "zh",
streaming_mode: StreamingMode = StreamingMode.DISABLED, # 默认禁用流式
media_type: str = "wav",
**kwargs
) -> Union[AudioResponse, AsyncGenerator[AudioResponse, None]]:
"""
文本转语音(支持流式)
Args:
text: 待合成文本
ref_audio_path: 参考音频路径(服务器本地路径或URL)
text_lang: 文本语言
prompt_lang: 提示语言
streaming_mode: 流式模式
media_type: 输出格式
**kwargs: 其他TTS参数
Returns:
非流式: AudioResponse对象
流式: AsyncGenerator[AudioResponse, None]异步生成器
Example:
# 非流式
audio = await client.tts("你好", "ref.wav")
# 流式
async for chunk in client.tts("你好", "ref.wav", streaming_mode=StreamingMode.FASTEST):
process(chunk.audio_data)
"""
config = TTSConfig(
text=text,
ref_audio_path=ref_audio_path,
text_lang=text_lang,
prompt_lang=prompt_lang,
streaming_mode=streaming_mode,
media_type=media_type,
**kwargs
)
self._log_debug("TTS Request", config=config.dict())
if streaming_mode == StreamingMode.DISABLED:
# 非流式模式
response = await self.client.post("/tts", json=config.build_request())
if response.status_code != 200:
raise APIError(response.status_code, await response.text())
return AudioResponse(
audio_data=response.content,
sample_rate=32000 # 默认采样率
)
else:
# 流式模式
config.parallel_infer = False # 强制关闭并行推理,避免与流式冲突
config.batch_size = 1 # 流式下batch_size必须为1
async def stream_generator():
async with self.client.stream(
"POST", "/tts",
json=config.build_request(),
timeout=httpx.Timeout(60.0) # 流式需要更长超时
) as response:
if response.status_code != 200:
raise APIError(response.status_code, await response.aread())
async for chunk in response.aiter_bytes():
if chunk:
yield AudioResponse(audio_data=chunk)
return stream_generator()
# 模型管理接口
async def set_gpt_weights(self, weights_path: str) -> bool:
"""
切换GPT模型权重
Args:
weights_path: 权重文件路径(服务器本地路径)
Returns:
bool: 是否成功
Example:
await client.set_gpt_weights("models/s1bert.ckpt")
"""
if not weights_path:
raise ValueError("weights_path cannot be empty")
params = {"weights_path": weights_path}
response = await self.client.get("/set_gpt_weights", params=params)
result = await self._handle_response(response)
logger.info(f"GPT weights switched to: {weights_path}")
return True
async def set_sovits_weights(self, weights_path: str) -> bool:
"""
切换SoVITS模型权重
Args:
weights_path: 权重文件路径(服务器本地路径)
Returns:
bool: 是否成功
"""
if not weights_path:
raise ValueError("weights_path cannot be empty")
params = {"weights_path": weights_path}
response = await self.client.get("/set_sovits_weights", params=params)
await self._handle_response(response)
logger.info(f"SoVITS weights switched to: {weights_path}")
return True
# 参考音频管理
async def set_refer_audio(
self,
audio_source: Union[str, Path, bytes],
audio_name: Optional[str] = None
) -> bool:
"""
设置参考音频(支持多种输入方式)
Args:
audio_source: 音频文件路径(str/Path)或音频数据(bytes
audio_name: 音频文件名(仅bytes输入时需要)
Returns:
bool: 是否成功
Example:
# 方式1: 服务器本地文件
await client.set_refer_audio("/path/to/audio.wav")
# 方式2: 上传音频数据
with open("audio.wav", "rb") as f:
await client.set_refer_audio(f.read(), "audio.wav")
"""
if isinstance(audio_source, (str, Path)):
# GET方式:服务器本地路径
params = {"refer_audio_path": str(audio_source)}
response = await self.client.get("/set_refer_audio", params=params)
await self._handle_response(response)
logger.info(f"Reference audio set: {audio_source}")
else:
# POST方式:上传音频数据
if not audio_name:
raise ValueError("audio_name is required when uploading bytes")
files = {"audio_file": (audio_name, audio_source, "audio/wav")}
response = await self.client.post("/set_refer_audio", files=files)
await self._handle_response(response)
logger.info(f"Reference audio uploaded: {audio_name}")
return True
# 服务器控制
async def control_command(self, command: str) -> bool:
"""
发送控制命令
Args:
command: 命令类型 - "restart""exit"
Returns:
bool: 是否成功
Warning:
"exit"命令会终止API服务器进程!
"""
if command not in ["restart", "exit"]:
raise ValueError("Command must be 'restart' or 'exit'")
response = await self.client.get("/control", params={"command": command})
await self._handle_response(response)
logger.warning(f"Control command executed: {command}")
return True
# 高级快捷方法
async def get_server_info(self) -> Dict[str, Any]:
"""获取服务器状态信息"""
# 通过调用根路径或自定义health接口
try:
response = await self.client.get("/")
return {"status": "online", "detail": response.text}
except Exception as e:
return {"status": "error", "detail": str(e)}
async def batch_tts(
self,
texts: list[str],
ref_audio_path: str,
**kwargs
) -> list[AudioResponse]:
"""
批量TTS合成
Args:
texts: 文本列表
ref_audio_path: 参考音频
**kwargs: 其他TTS参数
Returns:
list[AudioResponse]: 音频响应列表
"""
tasks = [
self.tts(text, ref_audio_path, **kwargs)
for text in texts
]
return await asyncio.gather(*tasks)
# 异步上下文管理器辅助函数
async def create_client(*args, **kwargs) -> GPTSoVITSClient:
"""快速创建客户端实例"""
return GPTSoVITSClient(*args, **kwargs)
@@ -0,0 +1,27 @@
# dto/dto_base.py
from abc import ABC
from typing import Callable, Coroutine, Any, Dict
from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer
class MessageDTO(ABC):
"""DTO基类"""
def __init__(
self,
ws_server : WebSocketServer, # WebSocketServer单例
):
# 保存服务器实例(用于发送)
self.ws_server = ws_server
# 便捷属性,DTO层直接调用
@property
def send_binary(self):
return self.ws_server.send_binary
@property
def send_text(self):
return self.ws_server.send_text
@property
def send_json(self):
return self.ws_server.send_json
@@ -0,0 +1,95 @@
from pydantic import Field, BaseModel
import base64
from datetime import datetime, timezone
class AudioDataTransferObject(BaseModel):
"""
音频数据传输对象
该对象被用于服务端与客户端的音频数据交互
同时支持流式与非流式的音频数据
同时收发对等(通过Owner标识)
"""
Owner: str = Field(default="server", description="音频数据的拥有者(server or client)")
isStream: bool = Field(default=False, description="音频数据是否为流式数据")
isStart: bool = Field(default=False, description="音频数据是否开始(流式时有效)")
isEnd: bool = Field(default=False, description="音频数据是否结束(流式时有效)")
sequence: int = Field(default=0, description="音频数据块序列号(流式时有效)")
data: bytes = Field(default=b"", description="音频数据,流式时为分块数据,base64编码")
sampleRate: int = Field(default=32000, description="音频采样率")
channelCount: int = Field(default=1, description="音频通道数")
bitDepth: int = Field(default=16, description="音频采样位数")
duration: float = Field(default=0.0, description="音频时长")
text: str = Field(default="", description="音频对应的文本")
def set_dto_data(self, **kwargs) -> "AudioDataTransferObject":
"""链式更新数据(Pydantic 风格)"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def to_json(self) -> dict:
"""
将DTO对象转换为可序列化字典
返回的json数据的格式:
{
"type": "audio_data",
"timestamp": 1672531200.0,
"data": {
"Owner": "server",
......
}
}
"""
# model_dump() 是 Pydantic v2 的序列化方法
payload = self.model_dump() # 获取所有模型字段
payload["data"] = base64.b64encode(payload["data"]).decode() # base64编码
# 构造嵌套结构
return {
"type": "audio_data",
"timestamp": datetime.now(timezone.utc).timestamp(),
"data": payload # 音频字段嵌套
}
@classmethod
def from_json(cls, json_data: dict) -> "AudioDataTransferObject":
"""
从JSON数据创建DTO对象
传入的json数据格式:
{
"Owner": "server",
......
}
"""
payload = json_data
# 解码 base64 (内层 data 字段在传输时是 base64 字符串) -> bytes
if "data" in payload and isinstance(payload["data"], str):
payload["data"] = base64.b64decode(payload["data"])
# 构造对象 Pydantic 自动忽略 type/timestamp
return cls.model_validate(payload)
# 测试代码
if __name__ == "__main__":
# 模拟音频数据
import os
test_audio = os.urandom(1024) # 随机生成1KB音频数据
# 创建DTO
audio = AudioDataTransferObject(
data=test_audio,
sequence=1,
isStream=True,
sampleRate=44100,
duration=0.5
)
# 序列化 → JSON
json_dict = audio.to_json()
print(f"序列化后 data 长度: {len(json_dict['data'])}") # ~1368 字符
print(f"data 前30字符: {json_dict['data'][:30]}...")
# 反序列化 → DTO
restored = AudioDataTransferObject.from_json(json_dict)
print(f"反序列化后 data 长度: {len(restored.data)}") # 1024 bytes
print(f"数据一致: {restored.data == test_audio}") # True
@@ -0,0 +1,73 @@
from pydantic import Field, BaseModel
from datetime import datetime, timezone
class AutoAgentDataTransferObject(BaseModel):
"""
自动化agent数据传输对象
该对象被用于服务端向客户端发送控制信息
"""
Action: str = Field(default="", description="自动化动作名称")
x1: int = Field(default=-1, description="鼠标起始位置x1")
y1: int = Field(default=-1, description="鼠标起始位置y1")
x2: int = Field(default=-1, description="鼠标结束位置x2")
y2: int = Field(default=-1, description="鼠标结束位置y2")
key: str = Field(default="", description="快捷键")
content: str = Field(default="", description="输入文本内容")
direction: str = Field(default="", description="滚动方向")
def set_dto_data(self, **kwargs) -> "AutoAgentDataTransferObject":
"""链式更新数据(Pydantic 风格)"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def to_json(self) -> dict:
"""
将DTO对象转换为可序列化字典
返回的json数据的格式:
{
"type": "auto_agent",
"timestamp": 1672531200.0,
"data": {
"Action": "自动化agent返回的相应的自动化动作名称",
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
"y1": "某个操作的y1",
"x2": "某个操作的x2",
"y2": "某个操作的y2",
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
"direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
}
}
"""
# model_dump() 是 Pydantic v2 的序列化方法
payload = self.model_dump() # 获取所有模型字段
# 构造嵌套结构
return {
"type": "auto_agent",
"timestamp": datetime.now(timezone.utc).timestamp(),
"data": payload # 字段嵌套
}
@classmethod
def from_json(cls, json_data: dict) -> "AutoAgentDataTransferObject":
"""
从JSON数据创建DTO对象
传入的json数据格式:
{
"type": "auto_agent",
"Action": "自动化agent返回的相应的自动化动作名称",
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
"y1": "某个操作的y1",
"x2": "某个操作的x2",
"y2": "某个操作的y2",
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
"direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
}
"""
payload = json_data
payload.pop("type", None) # 移除多余字段
# 构造对象 Pydantic 自动忽略 type/timestamp
return cls.model_validate(payload)
@@ -0,0 +1,37 @@
class BaseDataTransferObject:
"""
DTO基类
子类按需重写成员函数即可
"""
def __init__(self):
pass
def to_json(self):
"""
将DTO对象转换为JSON
"""
pass
def from_json(self, json_data):
"""
从JSON数据中创建DTO对象
"""
pass
def to_binary(self):
"""
将DTO对象转换为二进制
"""
pass
def from_binary(self, binary_data):
"""
从二进制数据中创建DTO对象
"""
pass
def to_text(self):
"""
将DTO对象转换为文本
"""
pass
def from_text(self, text_data):
"""
从文本数据中创建DTO对象
"""
pass
@@ -0,0 +1,72 @@
from pydantic import Field, BaseModel
from datetime import datetime, timezone
class ScreenShotDataTransferObject(BaseModel):
"""
服务端向客户端请求实时截图的数据传输对象
服务端与客户端收发对等(通过Owner标识)
客户端收到这个type的包,就会自动对当前设备的画面进行截图
"""
Owner: str = Field(default="server", description="数据的拥有者(server or client)")
isSuccess: bool = Field(default=False, description="是否截图成功")
RealTimeScreenShot: str = Field(default="", description="客户端设备的实时截图数据(base64)")
Width: int = Field(default=1920, description="截图的宽度")
Height: int = Field(default=1080, description="截图的高度")
DescribeInfo: str = Field(default="", description="设备的描述信息(告知模型以做出更加准确的判断)")
LLMResponse: str = Field(default="", description="LLM的响应结果(由服务端发送时携带)")
def set_dto_data(self, **kwargs) -> "ScreenShotDataTransferObject":
"""链式更新数据(Pydantic 风格)"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def to_json(self) -> dict:
"""
将DTO对象转换为可序列化字典
返回的json数据的格式:
{
"type": "screenshot_data",
"timestamp": 1672531200.0,
"data": {
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)"
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度",
"Height": "截图的高度",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)",
"LLMResponse": "LLM的响应结果(由服务端发送时携带)"
}
}
"""
# model_dump() 是 Pydantic v2 的序列化方法
payload = self.model_dump() # 获取所有模型字段
# 构造嵌套结构
return {
"type": "screenshot_data",
"timestamp": datetime.now(timezone.utc).timestamp(),
"data": payload # 字段嵌套
}
@classmethod
def from_json(cls, json_data: dict) -> "ScreenShotDataTransferObject":
"""
从JSON数据创建DTO对象
传入的json数据格式:
{
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)",
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度 非必要字段",
"Height": "截图的高度 非必要字段",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断) 非必要字段",
"LLMResponse": "LLM的响应结果(由服务端发送时携带) 必要字段"
}
"""
payload = json_data
payload.pop("type", None) # 移除多余字段
payload.pop("timestamp", None) # 移除多余字段
# 构造对象 Pydantic 自动忽略 type/timestamp
return cls.model_validate(payload)
@@ -0,0 +1,95 @@
# dto/second_dtos.py
import asyncio
from typing import Callable, Optional, Dict, Any, List, Coroutine
from src.modules.websocket_base_module.dto.dto_base import MessageDTO
from loguru import logger
"""
二级分发器,因为没有信号与槽机制,因此使用观察者模式替代
"""
def singleton(cls): # 单例
_instance = None
_lock = asyncio.Lock()
async def get_instance(*args, **kwargs):
nonlocal _instance
if _instance is None:
async with _lock:
if _instance is None:
_instance = cls(*args, **kwargs)
return _instance
cls.get_instance = get_instance
return cls
# 类型别名
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
@singleton
class JsonDTO(MessageDTO):
"""针对json消息的二级分发"""
"""
明确业务json格式:
{
"type" : "xxx",
"timestamp" : "95153...",
"data" : "{根据业务的不同,有不同的内容}"
}
"""
# 因为不同的数据块的json以type字段进行包装,根据type进行正确的数据分发
def __init__(self, ws_server):
super().__init__(ws_server)
self.receivers : Dict[str, List[ReceiveCallback]] = {
'audio_data' : [], # 音频数据
'screenshot_data' : [] # 截图数据
}
# 注册json处理callback function
ws_server.register_receiver('json', self._handle_json)
logger.info("[JsonDTO] JSON分发器已注册")
def register_receiver(self, types : str, callback : ReceiveCallback):
"""注册二次分发业务接收函数,供业务DTO调用"""
if types in self.receivers:
self.receivers[types].append(callback)
logger.debug(f"[JsonDTO] 已注册 {types} 接收器,当前共 {len(self.receivers[types])}")
else:
raise ValueError(f"[JsonDTO] 不支持的分发类型: {types}")
def unregister_receiver(self, types : str, callback : ReceiveCallback):
"""注销二次分发业务接收函数"""
if callback in self.receivers[types]:
self.receivers[types].remove(callback)
logger.debug(f"[JsonDTO] 已注销 {types} 接收器")
async def _handle_json(self, data: dict):
"""JSON消息处理"""
logger.info(f"[JsonDTO] 收到消息")
logger.debug(f'[JsonDTO] 当前消息时间戳: {data["timestamp"]}')
# 根据类型进行自动分发
await self._dispatch(data.get("type"), data["data"])
async def _dispatch(self, types : str, data : dict):
"""二次分发json数据到相应的接收函数当中"""
callbacks = self.receivers[types] # 获取相关types的所有观察者
if not callbacks:
logger.info(f"[JsonDTO] 无 {types} 接收器,消息被忽略")
return
# 并发执行所有回调
tasks = [callback(data) for callback in callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
async def get_json_dto_instance(ws_server) -> JsonDTO:
return JsonDTO(ws_server)
class EchoDTO(MessageDTO):
"""回声DTO:只处理文本消息 测试用"""
def __init__(self, ws_server):
super().__init__(ws_server)
# 注册文本接收函数
ws_server.register_receiver('text', self._handle_text)
logger.info("[EchoDTO] 文本接收器已注册")
async def _handle_text(self, message: str):
"""文本消息处理"""
logger.info(f"[EchoDTO] 收到文本: {message}")
# 业务逻辑
await self.send_text(f"Echo: {message}")
@@ -0,0 +1,281 @@
import asyncio
from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
from loguru import logger
from typing import Callable, List, Optional, Coroutine
class AudioDataDTO:
"""音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)"""
def __init__(self, json_dto : JsonDTO):
json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数
logger.info("[AudioDataDTO] 音频接收业务已注册")
self.json_dto = json_dto
self.audio_data = AudioDataTransferObject() # 音频数据对象
# 业务回调列表,延续观察者模式
self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
# 最新音频缓存 支持同步查询
self._latest_audio: Optional[AudioDataTransferObject] = None
# 流式缓冲区 用于大段音频流
self._stream_buffer: List[AudioDataTransferObject] = []
logger.info("[AudioDataDTO] 业务接口已初始化")
async def _handle_audio_data(self, data: dict):
"""处理音频数据"""
"""
音频数据json格式:
{
'Owner': 'server',
'isStream': False,
'isStart': False,
'isEnd': False,
'sequence': 0,
'data': '',
'sampleRate': 16000,
'channelCount': 1,
'bitDepth': 16,
'duration': 0.0,
'text': ''
}
"""
logger.debug(f"[AudioDataDTO] 收到音频数据")
# 将dict反序列化到DTO对象
self.audio_data = AudioDataTransferObject.from_json(data)
# 缓存最新数据
self._latest_audio = self.audio_data
# 如果是流式数据,加入缓冲区
if self.audio_data.isStream:
self._stream_buffer.append(self.audio_data)
if self.audio_data.isEnd:
logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)}")
# 通知所有注册的回调
await self._notify_callbacks()
# 业务发送接口
async def send_audio_data(self, data: AudioDataTransferObject) -> None:
"""
发送音频数据
Args:
data: 音频数据DTO
"""
await self.send_audio(
Owner=data.Owner,
is_stream=data.isStream,
is_start=data.isStart,
is_end=data.isEnd,
sequence=data.sequence,
data=data.data,
sampleRate=data.sampleRate,
channelCount=data.channelCount,
bitDepth=data.bitDepth,
duration=data.duration,
text=data.text
)
async def send_audio(
self,
data: bytes,
is_stream: bool = False,
is_start: bool = False,
is_end: bool = False,
sequence: int = 0,
**audio_meta
) -> None:
"""
业务层发送音频的便捷接口
Args:
data: 原始音频字节
is_stream: 是否为流式数据
is_start: 流式数据开始标记
is_end: 流式数据结束标记
sequence: 数据块序号
**audio_meta: 其他音频参数(sampleRate, channelCount等)
"""
# 填充音频数据到DTO
self.audio_data.set_dto_data(
Owner="server" or audio_meta.get('Owner', "server"),
isStream=is_stream,
isStart=is_start,
isEnd=is_end,
sequence=sequence,
data=data,
sampleRate=audio_meta.get('sampleRate', 16000),
channelCount=audio_meta.get('channelCount', 1),
bitDepth=audio_meta.get('bitDepth', 16),
duration=audio_meta.get('duration', 0.0),
text=audio_meta.get('text', "")
)
# 序列化为JSON并发送 自动处理base64和type字段
json_message = self.audio_data.to_json()
await self.json_dto.send_json(json_message)
logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")
# 业务接收接口
def register_audio_callback(
self,
callback: Callable[[AudioDataTransferObject], Coroutine]
) -> None:
"""
业务注册接收回调
使用示例:
async def my_audio_handler(audio_dto: AudioDataTransferObject):
print(f"收到音频: {len(audio_dto.data)} bytes")
audio_dto.register_audio_callback(my_audio_handler)
"""
self._audio_callbacks.append(callback)
logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)}")
def unregister_audio_callback(self, callback) -> None:
"""注销业务回调"""
if callback in self._audio_callbacks:
self._audio_callbacks.remove(callback)
logger.debug("业务音频回调已注销")
def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
"""
同步获取最新音频数据(轮询模式)
Returns:
最新接收到的音频DTO,如果没有则为 None
"""
return self._latest_audio
def clear_stream_buffer(self) -> None:
"""清空流式缓冲区"""
self._stream_buffer.clear()
logger.debug("流式音频缓冲区已清空")
# 内部通知机制
async def _notify_callbacks(self) -> None:
"""通知所有业务回调"""
if not self._audio_callbacks:
logger.warning("无业务回调,音频数据未处理")
return
tasks = [callback(self.audio_data) for callback in self._audio_callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调")
# 异步迭代器 流式时使用
def __aiter__(self):
"""支持 async for 循环接收流式音频"""
return self
async def __anext__(self) -> AudioDataTransferObject:
"""异步迭代器协议"""
pass
class ScreenShotDataDTO:
"""截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)"""
def __init__(self, json_dto : JsonDTO):
json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数
logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
self.json_dto = json_dto
self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象
# 业务回调列表,延续观察者模式
self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
# 最新截屏数据缓存 支持同步查询
self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
logger.info("[ScreenShotDataDTO] 业务接口已初始化")
async def _handle_screenshot_data(self, data: dict):
"""处理截屏数据"""
"""
截屏数据json格式:
{
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)"
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度",
"Height": "截图的高度",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)"
}
"""
logger.debug(f"[ScreenShotDataDTO] 收到截屏数据")
# 将dict反序列化到DTO对象
self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
# 缓存最新数据
self._latest_screenshot = self.screenshot_data
# 通知所有注册的回调
await self._notify_callbacks()
# 业务发送接口
async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
"""
发送音频数据
Args:
data: 音频数据DTO
"""
await self.send_screenshot(
Owner=data.Owner,
isSuccess=data.isSuccess,
RealTimeScreenShot=data.RealTimeScreenShot,
Width=data.Width,
Height=data.Height,
DescribeInfo=data.DescribeInfo,
LLMResponse=data.LLMResponse
)
async def send_screenshot(
self,
**screenshot_meta
) -> None:
"""
业务层发送音频的便捷接口
一般来说,作为发送请求方,不需要填充任何数据
Args:
**screenshot_meta: 截屏数据元信息
"""
# 填充音频数据到DTO
self.screenshot_data.set_dto_data(
Owner="server" or screenshot_meta.get('Owner', "server"),
isSuccess=screenshot_meta.get('isSuccess', False),
RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
Width=screenshot_meta.get('Width', 1920),
Height=screenshot_meta.get('Height', 1080),
DescribeInfo=screenshot_meta.get('DescribeInfo', False),
LLMResponse=screenshot_meta.get('LLMResponse', "")
)
# 序列化为JSON并发送 自动处理base64和type字段
json_message = self.screenshot_data.to_json()
await self.json_dto.send_json(json_message)
logger.info(f"截屏包已发送")
# 业务接收接口
def register_screenshot_callback(
self,
callback: Callable[[ScreenShotDataTransferObject], Coroutine]
) -> None:
"""
业务注册接收回调
"""
self._screenshot_callbacks.append(callback)
logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)}")
def unregister_screenshot_callback(self, callback) -> None:
"""注销业务回调"""
if callback in self._screenshot_callbacks:
self._screenshot_callbacks.remove(callback)
logger.debug("业务截屏回调已注销")
def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
"""
同步获取最新截屏数据(轮询模式)
Returns:
最新接收到的截屏数据DTO,如果没有则为 None
"""
return self._latest_screenshot
# 内部通知机制
async def _notify_callbacks(self) -> None:
"""通知所有业务回调"""
if not self._screenshot_callbacks:
logger.warning("无业务回调,截屏数据未处理")
return
tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调")
@@ -0,0 +1,56 @@
### 说明
在本module当中,每个子模块的用途分别是:
- dto
- dto_templates
服务端与客户端交互所使用到的数据传输对象
- dto_base.py / xxx_dtos.py
实际的DTO业务
- websocket_core
- websocket核心,承载了底层核心的网络收发业务
### 模块架构
```mermaid
graph TB
subgraph "Client"
C[WebSocket Client]
end
subgraph "Core Layer"
WS[WebSocketServer<br/>单例]
WS -->|持有| WSP[WebSocketServerProtocol<br/>_websocket]
WS -->|管理| RCV[ receivers: Dict<br/>binary/text/json ]
end
subgraph "DTO Base Layer"
MDTO[MessageDTO<br/>抽象基类]
MDTO -->|注入| MDF[ send_binary<br/>send_text<br/>send_json ]
end
subgraph "Secondary Dispatcher"
JDTO[JsonDTO<br/>单例]
JDTO -->|继承| MDTO
JDTO -->|维护| MAP[ receivers: Dict<br/>audio_data/... ]
JDTO -->|注册到| WS
end
subgraph "Business DTO"
ADTO[AudioDataDTO......<br/>业务实现]
ADTO -->|持有引用| JDTO
ADTO -->|使用| ATO[AudioDataTransferObject<br/>Pydantic模型]
end
C <-->|websocket连接| WSP
WS -->|分发消息| JDTO
JDTO -->|二次分发| ADTO
ADTO -->|发送响应| JDTO
JDTO -->|调用| MDF
MDF -->|经由| WS
WS -->|发送至| C
style WS fill:#64f,stroke:#333,stroke-width:2px
style JDTO fill:#569,stroke:#333,stroke-width:2px
style ADTO fill:#38f,stroke:#333,stroke-width:2px
```
@@ -0,0 +1,172 @@
# websocket_core/core_ws_server.py
import asyncio
import json
from typing import Callable, Optional, Dict, Any, List, Coroutine
from websockets.asyncio.server import serve, ServerConnection
from websockets.exceptions import ConnectionClosed
from loguru import logger
# 类型别名
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
class WebSocketServer:
"""WebSocket服务端核心模块(单例 + 单客户端)
只管理一个客户端连接,DTO层注册接收函数,服务端只负责分发。
"""
_instance: Optional["WebSocketServer"] = None
_lock = asyncio.Lock()
def __new__(cls):
"""同步单例(__init__ 可以是 async"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
# 防止重复初始化
if self._initialized:
return
self._websocket: Optional[ServerConnection] = None
self._receivers: Dict[str, List[ReceiveCallback]] = {
'binary': [],
'text': [],
'json': []
}
self._connected_event = asyncio.Event()
self._initialized = True
logger.info("WebSocketServer 单客户端分发器已初始化")
# DTO层注册接口
def register_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
"""注册接收函数(供DTO层调用)
Args:
msg_type: 消息类型(binary/text/json
callback: 接收回调,签名为 (data) -> coroutine
"""
if msg_type in self._receivers:
self._receivers[msg_type].append(callback)
logger.debug(f"已注册 {msg_type} 接收器,当前 {msg_type} 类型接收器共 {len(self._receivers[msg_type])}")
else:
raise ValueError(f"不支持的消息类型: {msg_type}")
def unregister_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
"""注销接收函数"""
if callback in self._receivers[msg_type]:
self._receivers[msg_type].remove(callback)
logger.debug(f"已注销 {msg_type} 接收器")
# 发送接口(供DTO层调用)
async def send_binary(self, data: bytes):
"""发送二进制数据(唯一客户端)"""
if self._websocket:
await self._websocket.send(data)
logger.trace(f"二进制数据已发送 (长度: {len(data)} bytes)")
else:
raise RuntimeError("客户端未连接")
async def send_text(self, data: str):
"""发送文本数据(唯一客户端)"""
if self._websocket:
await self._websocket.send(data)
logger.trace(f"文本数据已发送 (长度: {len(data)} chars)")
else:
raise RuntimeError("客户端未连接")
async def send_json(self, data: Dict[str, Any]):
"""发送JSON数据(唯一客户端)"""
if self._websocket:
try:
logger.debug(f"准备发送JSON数据: {data}")
message = json.dumps(data)
await self._websocket.send(message)
logger.trace(f"JSON数据已发送: {data}")
except Exception as e:
logger.error(f"JSON数据发送失败: {e}")
raise
else:
raise RuntimeError("客户端未连接")
# 等待连接
async def wait_for_client(self):
"""阻塞等待客户端连接"""
await self._connected_event.wait()
logger.info("客户端已就绪")
# 内部消息循环
async def _handle_client(self, websocket: ServerConnection):
"""处理唯一客户端的消息循环"""
self._websocket = websocket
self._connected_event.set()
client_info = f"{websocket.remote_address}" if hasattr(websocket, 'remote_address') else "unknown"
logger.info(f"客户端已连接: {client_info}")
try:
async for message in websocket:
# 根据消息类型分发到所有注册的接收函数
if isinstance(message, bytes):
await self._dispatch('binary', message)
elif isinstance(message, str):
# 优先尝试JSON解析
json_dispatched = False
try:
data = json.loads(message)
await self._dispatch('json', data)
json_dispatched = True
except json.JSONDecodeError:
pass
# 如果不是JSON或没有json接收器,尝试text
if not json_dispatched:
await self._dispatch('text', message)
else:
logger.warning(f"未知消息类型: {type(message)}")
logger.info("客户端连接已正常关闭")
except ConnectionClosed as e:
logger.info(f"客户端连接已关闭: {e}")
except Exception as e:
logger.error(f"处理客户端时发生错误: {e}")
finally:
self._websocket = None
self._connected_event.clear()
async def _dispatch(self, msg_type: str, data: Any):
"""分发消息到所有注册的接收函数"""
callbacks = self._receivers[msg_type]
if not callbacks:
logger.warning(f"{msg_type} 接收器,消息被忽略")
return
# 并发执行所有回调
tasks = [callback(data) for callback in callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
# 启动服务器
async def run(self, host: str = "localhost", port: int = 8765, max_msg_size: int = 50*1024*1025):
"""启动WebSocket服务器(阻塞)"""
logger.info(f"WebSocket服务器启动中... 等待客户端连接 ws://{host}:{port}")
async def handler_wrapper(connection):
logger.info(f"新连接请求: {connection.remote_address}")
await self._handle_client(connection)
try:
async with serve(
handler=handler_wrapper, # 使用 wrapper 适配签名
host=host,
port=port,
max_size=max_msg_size,
):
logger.success(f"WebSocket服务器已启动在 ws://{host}:{port}")
await asyncio.Future() # 永久阻塞,保持服务器运行
except Exception as e:
logger.error(f"WebSocket服务器启动失败: {e}")
raise
async def get_ws_server() -> WebSocketServer:
"""全局单例获取函数(线程安全)"""
server = WebSocketServer() # __new__ 保证单例
return server
+306
View File
@@ -0,0 +1,306 @@
# server_core/core.py
"""
统一业务,对外提供启动接口
业务数据流向:
Yosuga[User Audio Info Struct] ->(WebSocket) Yosuga_server[asr_module] -> Text
Yosuga_server[Come from Yosuga Audio ASR Text] ->(Func call) Yosuga_server[llm_core] -> Ins and Text
Yosuga_server[Come from llm_core Text]->(WebSocket) Yosuga
Yosuga_server[Come from llm_core Text]->(Func call) Yosuga_server[TTS] -> Audio Data
Yosuga_server[Audio Data] ->(WebSocket) Yosuga
Yosuga_embedded[Devices Control Info] ->(WebSocket) Yosuga_server[embedded_core]-> Ins
Yosuga_embedded[Devices Control Info] ->(Serial) Yosuga[SerialManager] -> ForWord To Yosuga_server
Yosuga[Come from embedded Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins
UI_TARS[Mind and x&y Info] ->(Func call) Yosuga_server[llm_core] -> Ins
Yosuga_server[Come from UI_TARS Ins] ->(WebSocket) Yosuga
Yosuga[Live2D Control Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins
Yosuga_server[Live2D Control Ins] ->(Websocket) Yosuga
Yosuga_server[agent memory] ->(Func call) Yosuga_server[Memory Uint]
"""
import asyncio
from typing import Optional, List, Dict, Any
from loguru import logger
from src.modules.websocket_base_module.dto.third_dtos import (
AudioDataDTO, AudioDataTransferObject,
ScreenShotDataDTO, ScreenShotDataTransferObject
)
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO, get_json_dto_instance
from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer, get_ws_server
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig
from src.modules.asr_module.client.asr_client import create_asr_client, ASRClientConfig, ASRClientAsync
from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import StreamingMode, TTSConfig, GPTSoVITSClient
from src.server_core.llm_core.llm_core import (
LLMCoreConfig, ModelConfig,
YosugaLLMCore, ModelProvider,
LLMCoreAnalysisBase,
YosugaAudioResponseData, YosugaUITARSResponseData,
YosugaUITARSRequestData
)
from src.modules.websocket_base_module.dto.dto_templates.auto_agent_data_dto import AutoAgentDataTransferObject
from src.config.config import cfg
class YosugaServerCore:
"""
异步单例类
"""
_instance: Optional["YosugaServerCore"] = None
_lock = asyncio.Lock()
# 组合必要的工具类
ws_server: WebSocketServer
json_dto: JsonDTO
audio_dto: AudioDataDTO
screenshot_dto: ScreenShotDataDTO
asr_client: ASRClientAsync # 异步asr client
tts_client: GPTSoVITSClient # tts client
auto_agent_client: UITarsClient # GUI自动化agent
llm_core: YosugaLLMCore = None # llm core
@classmethod
async def get_instance(cls) -> "YosugaServerCore":
"""异步单例工厂"""
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
logger.info("Initializing YosugaServerCore...")
# 创建实例
instance = cls.__new__(cls)
# 按依赖顺序初始化数据分发器
instance.ws_server = await get_ws_server()
instance.json_dto = await get_json_dto_instance(instance.ws_server)
instance.audio_dto = AudioDataDTO(instance.json_dto) # 音频分发器
instance.audio_dto.register_audio_callback(instance._handle_audio_data) # 注册音频处理函数
instance.screenshot_dto = ScreenShotDataDTO(instance.json_dto) # 截图分发器
instance.screenshot_dto.register_screenshot_callback(instance._handle_screenshot_data) # 注册截图处理函数
instance.asr_client = create_asr_client(use_async=True, base_url=cfg.asr.url)
instance.tts_client = GPTSoVITSClient(host=cfg.tts.host, port=cfg.tts.port, debug=True)
# 切换GPT_SoVITS模型
await instance.tts_client.set_gpt_weights(cfg.tts.gpt_model_name)
await instance.tts_client.set_sovits_weights(cfg.tts.sovits_model_name)
instance.auto_agent_client = UITarsClient(UITarsClientConfig(
deployment_type=cfg.auto_agent.deployment_type,
base_url=cfg.auto_agent.base_url,
model_name=cfg.auto_agent.model_name,
temperature=cfg.auto_agent.temperature,
max_tokens=cfg.auto_agent.max_tokens
))
instance.llm_core = YosugaLLMCore(
model_config=ModelConfig( # TODO 同上
provider=ModelProvider.OPENAI,
model_name=cfg.ai.model_name,
base_url=cfg.ai.base_url,
api_key=cfg.ai.api_key,
temperature=cfg.ai.temperature,
max_tokens=cfg.ai.max_tokens
),
core_config=LLMCoreConfig( # TODO 同上
max_context_tokens=cfg.llm_core.max_context_tokens,
enable_history=cfg.llm_core.enable_history,
role_setting=cfg.llm_core.role_character,
language=cfg.llm_core.language, # 回复使用语言
auto_dispatch=True,
dispatch_async=True # 启用异步分发
)
)
instance.register_llm_core_analysis() # 注册解析器
instance.register_llm_core_action() # 注册分发器
instance.llm_core.register_overflow_handler(instance._handle_overflow_logger) # 注册上下文溢出处理回调
cls._instance = instance
logger.success("YosugaServerCore initialized")
return cls._instance
def register_llm_core_action(self):
"""
注册llm_core的分发器
"""
if self.llm_core is None:
raise Exception("LLMCore is not initialized")
self.llm_core.register_action_handler("audio_text", self._handle_audio_response, is_async=True)
self.llm_core.register_action_handler("auto_agent", self._handle_auto_agent, is_async=True)
self.llm_core.register_action_handler("call_auto_agent", self._handle_call_auto_agent, is_async=True)
self.llm_core.set_fallback_handler(self._handle_fallback)
def register_llm_core_analysis(self):
"""
注册llm_core的输出解析器
"""
if self.llm_core is None:
raise Exception("LLMCore is not initialized")
self.llm_core.register_analysis_model(YosugaAudioResponseData)
self.llm_core.register_analysis_model(YosugaUITARSResponseData)
self.llm_core.register_analysis_model(YosugaUITARSRequestData)
def _handle_overflow_logger(self, history: List[Any], metadata: Dict[str, Any]):
"""上下文溢出记录,仅打印日志"""
print(f" 上下文溢出!")
print(f" 模型: {metadata['model']}")
print(f" 消息数: {metadata['message_count']}")
print(f" Token: {metadata['estimated_tokens']}/{metadata['limit']}")
print(f" 即将遗忘 {len(history) // 2} 条旧消息")
async def _handle_audio_data(self, audio_data: AudioDataTransferObject):
"""
音频数据接收call back
Yosuga_server只有接受到每次这个audio数据才会跑一次
"""
logger.info("Received audio data")
# 在此处客户端发送的音频数据必定不是流式数据(考虑客户端发送数据给服务端往往是在本地的,速度极快)
# 将音频数据发送给asr转换成文本信息,音频数据格式为wav
# TODO: 考虑在此处做一个简单的vad检测,如果客户端发送的音频是静音的,则不把请求发给llm_core
asr_response = await self.asr_client.transcribe_bytes(audio_data.data)
if not asr_response.success:
logger.error(f"ASR failed: {asr_response.error}")
asr_result = asr_response.data # 获取asr结果
# 将asr结果发送给llm_core进行处理
llm_result = await self.llm_core.interact(
user_input={ # 构造用户输入信息
"text": asr_result.text,
"confidence": asr_result.confidence
}
) # llm_core会自动进行处理并通过执行器异步返回各种相关的数据
async def _handle_screenshot_data(self, screenshot_data: ScreenShotDataTransferObject):
"""
屏幕截图数据接收call back
将llm_core的回复封装后提交给auto_agent模块,获得自动化agent的返回之后再返回给llm_core
"""
logger.info(f"Received screenshot data {len(screenshot_data.RealTimeScreenShot)}")
if not screenshot_data.isSuccess: # 如果客户端截图失败
logger.error("Screenshot failed")
return # 直接提前结束回调,不向llm_core发送结果
# TODO 对于设备描述信息(screenshot_data.DescribeInfo),考虑加入到auto_agent的输入中,增强识别准确率
# 构造请求 异步调用
logger.debug(f"screenshot_data.LLMResponse(来自llm_core向auto_agent的输入): {screenshot_data.LLMResponse}")
logger.debug(f"客户端设备信息: {screenshot_data.DescribeInfo}")
auto_agent_response: str = await self.auto_agent_client.call_async(screenshot_data.LLMResponse,
screenshot_data.RealTimeScreenShot)
logger.debug(f"auto_agent_response(auto_agent原生返回结果): {auto_agent_response}")
# 将auto_agent的返回结果发送给llm_core
await self.llm_core.interact(
user_input={ # 构造auto_agent输入信息
"auto_agent": auto_agent_response
}
)
async def _handle_audio_response(self, data: YosugaAudioResponseData):
"""
llm_core异步处理器:语音回复
将llm_core的回复封装后提交给tts模块,调用tts模块中的流式返回,并将流式frame返回给Yosuga客户端
"""
if data.type == "audio_text":
logger.info("Handling audio response")
try:
# 使用最快模式流式输出
chunk_count = 0
async for chunk in await self.tts_client.tts(
text=data.response_text,
ref_audio_path="uploaded_audio/test_voice.wav", # TODO 需要替换成config或者后续设计情感系统
text_lang="ja",
prompt_lang="ja",
prompt_text="もう!こんなところで何やってるんだよ!", # 参考语音的真实文本
streaming_mode=StreamingMode.FASTEST, # 模式3:快速流式
media_type="wav"
):
chunk_count += 1
print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes")
if chunk_count == 1: # 如果是第一个音频块
# 构造音频首包发送给客户端
await self.audio_dto.send_audio_data(
AudioDataTransferObject(
data=chunk.audio_data,
isStream=True,
isStart=True,
sequence=chunk_count,
isEnd=False,
text=data.response_text
)
)
else: # 如果不是第一个音频块,则发送中间包给客户端
await self.audio_dto.send_audio_data(
AudioDataTransferObject(
data=chunk.audio_data,
isStream=True,
isStart=False,
sequence=chunk_count,
isEnd=False,
text=data.response_text
)
)
print(f"✅ 流式TTS完成!共{chunk_count}个音频块")
# 构造音频尾包发送给客户端(虚假的音频数据)
await self.audio_dto.send_audio_data(
AudioDataTransferObject(
data=b"0",
isStream=True,
isStart=False,
sequence=chunk_count + 1,
isEnd=True,
text=data.response_text
)
)
except Exception as e:
print(f"❌ 流式错误: {e}")
return {"status": "success", "executed": data.response_text}
return None
async def _handle_auto_agent(self, data: YosugaUITARSResponseData):
"""
llm_core异步处理器:处理自动化操作
将llm_core的回复封装后提交给Yosuga客户端,由客户端进行执行相关的GUI自动化操作
"""
# 构造并发送回复数据
await self.json_dto.send_json(
AutoAgentDataTransferObject.from_json(data.to_dict()).to_json()
)
return {"status": "success", "executed": data.Action}
async def _handle_call_auto_agent(self, data: YosugaUITARSRequestData):
"""
llm_core异步处理器:处理llm_core调用auto_agent需求
向客户端请求当前界面的截图,请求成功后由_handle_screenshot_data函数完成剩下的任务
"""
if data.type == "call_auto_agent":
logger.info("LLM Calling auto agent")
# 向客户端请求当前界面的截图的base64编码 加入llm回复的信息到截图请求DTO当中 方便_handle_screenshot_data构造请求
await self.screenshot_dto.send_screenshot_data(ScreenShotDataTransferObject(LLMResponse=data.llm_translation))
return {"status": "success", "executed": data.type}
def _handle_fallback(self, data: LLMCoreAnalysisBase):
"""
llm_core同步处理器:回退处理器
"""
logger.debug(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}")
async def run(self):
"""启动服务器"""
logger.info("Yosuga Server Websocket Core 启动中...")
await self.ws_server.run(host="0.0.0.0")
# 使用方式
async def main():
core = await YosugaServerCore.get_instance()
await core.run()
if __name__ == '__main__':
asyncio.run(main())
+2
View File
@@ -0,0 +1,2 @@
### 本模块为Yosuga_server的AI输出语音情感管理模块
+564
View File
@@ -0,0 +1,564 @@
# llm_core/llm_core.py
"""
Yosuga Server LLM 核心控制模块
负责整合Prompt管理、模型调用、输出解析、上下文记忆管理以及生命周期维护。
作为系统的"大脑",对外提供统一的高级交互接口。
"""
import asyncio
import json
import time
from typing import List, Dict, Any, Optional, Callable, Union, Type
from loguru import logger
from pydantic import BaseModel, Field
# 引入已有的模块
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
UnifiedLLM, ModelConfig, ChatMessage, ModelResponse, ModelProvider
)
from src.server_core.llm_core.llm_core_analysis import (
LLMCoreAnalysisManager, LLMCoreAnalysisBase, YosugaAudioResponseData, YosugaUITARSResponseData, YosugaUITARSRequestData
)
from src.server_core.llm_core.llm_core_dispatcher import LLMCoreActionDispatcher
from src.server_core.llm_core.llm_core_prompt_manager import (
LLMCorePromptManager, LLMCorePromptBase,
YosugaAudioASRText, YosugaUITARS, YosugaLive2DControl
)
from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH
from src.server_core.llm_core.llm_core_token import TokenManager, TokenUsage
# 类型定义:上下文溢出回调函数签名
# 参数1: 溢出的历史记录列表
# 参数2: 相关的元数据
ContextOverflowCallback = Callable[[List[ChatMessage], Dict[str, Any]], None]
class LLMCoreConfig(BaseModel):
"""LLM Core 运行时配置"""
max_context_tokens: int = Field(default=2000, description="上下文最大Token数(估算值,不包括System prompt),超出触发重置")
enable_history: bool = Field(default=True, description="是否启用历史对话记忆")
language: str = Field(default="zh_CN", description="回复语言设定")
role_setting: str = Field(default="", description="llm角色扮演")
auto_dispatch: bool = Field(default=True, description="是否自动分发到动作处理器")
dispatch_async: bool = Field(default=False, description="分发是否使用异步模式")
memory: str = Field(default="", description="llm记忆")
system_state_table: str = Field(default="", description="Yosuga系统状态表")
class YosugaLLMCore:
"""
Yosuga 服务端 LLM 核心控制器
"""
def __init__(self, model_config: ModelConfig, core_config: Optional[LLMCoreConfig] = None):
"""
初始化 LLM Core
Args:
model_config: 底层大模型的连接配置
core_config: 核心业务逻辑配置(上下文限制等)
"""
self.model_config = model_config
self.core_config = core_config or LLMCoreConfig()
# 初始化模型客户端 (UnifiedLLM)
self.llm_client: UnifiedLLM = UnifiedLLM(self.model_config)
# 初始化 TokenManager
self.token_manager = TokenManager(self.model_config.model_name)
# 初始化Prompt管理器
self.prompt_manager = LLMCorePromptManager()
self._register_default_prompts()
# 上下文记忆存储
self._history: List[ChatMessage] = [] # 注意:history不包含system prompt,只包含 user/assistant 消息
# 上下文溢出回调列表
self._overflow_callbacks: List[ContextOverflowCallback] = []
logger.info(
f"YosugaLLMCore 初始化完成 | "
f"模型: {model_config.model_name} | "
f"提供商: {model_config.provider}"
)
logger.info(
f"上下文限制: {self.core_config.max_context_tokens} tokens | "
f"自动分发: {self.core_config.auto_dispatch}"
)
def _register_default_prompts(self):
"""注册默认的业务Prompt模块"""
self.prompt_manager.register(YosugaAudioASRText())
self.prompt_manager.register(YosugaUITARS())
# self.prompt_manager.register(YosugaLive2DControl()) # TODO
logger.info(f"默认Prompt模块注册完成 | 数量: {self.prompt_manager.get_registry_size()}")
# 系统提示词管理
def get_system_prompt(self) -> str:
"""
动态构建当前的 System Prompt
根据 prompt_manager 中注册的模块实时生成
"""
return YOSUGA_SYSTEM_PROMPT_SCH.format(
InputInfo=self.prompt_manager.describe_input(), # 不变的内容
OutputInfo=self.prompt_manager.describe_output(), # 不变的内容
RoleSetting=self.core_config.role_setting, # 角色扮演,可热重载
Language=self.core_config.language, # 回复语言,可热重载
Memory=self.core_config.memory, # 记忆,可热重载,请求前更新
SystemStateTable=self.core_config.system_state_table# 系统状态表,可热重载,每次请求都会更新
)
def register_prompt_module(self, prompt_module: LLMCorePromptBase):
"""运行时注册新的 Prompt 业务模块"""
self.prompt_manager.register(prompt_module)
logger.info(f"动态注册 Prompt 模块: {prompt_module.type()}")
def register_analysis_model(self, model_class: Type[LLMCoreAnalysisBase]) -> None:
"""
注册LLM输出解析模型
Args:
model_class: 继承自 LLMCoreAnalysisBase 的数据模型类
"""
LLMCoreAnalysisManager.register(model_class)
logger.info(f"注册解析模型: {model_class.type_()}")
def register_action_handler(
self,
type_id: str,
handler: Callable,
is_async: bool = False
) -> None:
"""
注册动作处理器
Args:
type_id: 与解析模型对应的类型标识
handler: 处理函数(同步或异步)
is_async: 是否为异步处理器
"""
if is_async:
LLMCoreActionDispatcher.register_async(type_id, handler)
else:
LLMCoreActionDispatcher.register(type_id, handler)
def set_fallback_handler(self, handler: Callable) -> None:
"""设置未注册类型的回退处理器"""
LLMCoreActionDispatcher.set_fallback(handler)
# 核心交互接口
async def interact(
self,
user_input: Union[str, Dict[str, Any]], # 输入(纯文本或结构化字典)
past_memories: Optional[str] = "", # 记忆模块检索的相关历史记忆
system_state_table: Optional[str] = "", # 系统状态表,可热重载,每次请求都会更新
auto_dispatch: Optional[bool] = True, # 是否自动分发,默认为启用
dispatch_async: Optional[bool] = True # 是否异步分发,默认为启用
) -> Dict[str, List[Any]]:
"""
核心交互方法:处理输入 -> 组装上下文 -> 调用LLM -> 解析输出
Args:
user_input: 用户输入(纯文本或结构化字典)
past_memories: 记忆模块检索的相关历史记忆
system_state_table: Yosuga系统状态表(方便llm理解当前系统状态)
auto_dispatch: 是否自动分发(覆盖默认配置)
dispatch_async: 是否异步分发(覆盖默认配置)
Returns:
分发执行结果字典:
{
"success": [{"type": "...", "output": ..., "index": 0}],
"failed": [{"type": "...", "error": "...", "index": 1}],
"skipped": ["unknown_type"]
}
Raises:
ValueError: LLM调用或解析失败
RuntimeError: 分发执行致命错误
"""
# 输入预处理
input_content = (
json.dumps(user_input, ensure_ascii=False)
if isinstance(user_input, dict)
else user_input
)
logger.info(f"用户输入内容经过json处理后为: {input_content}")
# 检查并维护上下文
self._maintain_context_limit()
# 构建本次请求的消息链
messages = self._build_request_messages(input_content, past_memories, system_state_table)
# 调用 LLM
try:
llm_response = self._call_llm(messages)
if llm_response.usage: # 打印请求消耗的Token数量
self.token_manager.record_api_usage(llm_response.usage)
logger.info(self.token_manager.format_usage_log(source="API")) # 使用 TokenManager 格式化日志
except Exception as e:
logger.error(f"LLM调用失败: {e}")
raise
# 统一解析(总是返回列表)
try:
parsed_results = LLMCoreAnalysisManager.parse(llm_response.content)
logger.success(f"解析成功 | 对象数: {len(parsed_results)}")
except Exception as e:
logger.error(f"输出解析失败: {e}")
raise
# 更新历史记忆
if self.core_config.enable_history:
self._add_to_history("user", input_content)
self._add_to_history("assistant", llm_response.content)
# 分发执行
should_dispatch = auto_dispatch if auto_dispatch is not None else self.core_config.auto_dispatch
if should_dispatch and parsed_results:
is_async = dispatch_async if dispatch_async is not None else self.core_config.dispatch_async
if is_async:
# 直接 await 异步执行,如果调用 asyncio.run() 就会因为重复创建事件循环导致报错
logger.debug("使用异步模式分发动作")
return await LLMCoreActionDispatcher._execute_async(parsed_results)
else:
# 同步模式保持原样
return LLMCoreActionDispatcher.execute(parsed_results, run_async=False)
# 不分发则返回原始解析结果(极少用)
return {"success": [{"type": obj.type, "output": obj, "index": i}
for i, obj in enumerate(parsed_results)],
"failed": [], "skipped": []}
def _build_request_messages(
self,
current_input: str,
memories: str,
system_state_table: str
) -> List[ChatMessage]:
"""
构建完整的LLM消息链
结构:
1. System Prompt 构造,包括记忆注入等信息
2. 历史上下文
3. 当前用户输入
"""
# 构建memory与其他信息
self.core_config.memory = memories
# 构建系统状态表
self.core_config.system_state_table = system_state_table
# 构造 System Prompt
messages = [ChatMessage(role="system", content=self.get_system_prompt())]
# 历史上下文
if self.core_config.enable_history:
messages.extend(self._history) # 分开每条消息追加
# 当前用户输入
messages.append(ChatMessage(role="user", content=current_input))
return messages
def _call_llm(self, messages: List[ChatMessage]) -> ModelResponse:
"""执行LLM调用(非流式)"""
logger.debug(f"请求LLM | 消息数: {len(messages)}")
# 预估token使用情况 TODO: (调试用)
estimated = self.token_manager.estimate_chat_tokens(
system_prompt=self.get_system_prompt(),
history=self._history,
current_input=messages[-1].content if messages else ""
)
logger.debug(self.token_manager.format_usage_log(estimated.to_dict(), source="MANUAL"))
# 强制非流式(结构化输出需要完整JSON)
response: ModelResponse = self.llm_client.chat(
messages,
streaming=False,
temperature=self.model_config.temperature
)
# 记录 API 返回的 usage
if response.usage:
self.token_manager.record_api_usage(response.usage)
logger.debug(f"LLM响应长度: {len(response.content)}")
return response
# 上下文与记忆管理
def _add_to_history(self, role: str, content: str):
"""添加消息到历史"""
self._history.append(ChatMessage(role=role, content=content))
def _maintain_context_limit(self) ->None:
"""
检查上下文是否超出限制
如果超出,触发溢出回调,并将当前上下文导出,然后清空最近50%
"""
if not self.core_config.enable_history: # 若未启用历史对话记忆
return
# 使用 TokenManager 获取上下文占用(手动计算)
context_usage = self.token_manager.get_context_usage(self._history)
current_usage = context_usage.total_tokens
# 使用 TokenManager 格式化日志
logger.debug(
self.token_manager.format_usage_log(context_usage.to_dict(), source="CONTEXT")
)
limit = self.core_config.max_context_tokens
# 使用 TokenManager 判断是否接近限制
if self.token_manager.is_token_limit_approaching(current_usage, limit, threshold=0.85):
logger.warning(
f"上下文接近限制: {current_usage}/{limit} "
f"({current_usage / limit:.1%})"
)
if current_usage <= limit:
return
# 否则就是溢出
logger.critical(
f"上下文溢出!| {current_usage}/{limit} tokens "
f"({current_usage / limit:.1%}) | 消息: {len(self._history)}"
)
# 执行所有注册的溢出处理器
self._trigger_overflow_callbacks()
# 智能清理:保留最近50%消息
keep_messages = max(1, len(self._history) // 2)
self._history = self._history[-keep_messages:]
# 求出新的token占有
new_usage = self._estimate_token_usage()
logger.success( # 打印清理前后的token占用变化
f"清理完成 | Token: {current_usage}{new_usage} | "
f"保留消息: {len(self._history)}"
)
def _estimate_token_usage(self) -> int:
"""
使用 TokenManager 计算当前历史记录的 Token 数
"""
if not self._history:
return 0
# 将 ChatMessage 对象转换为字典格式
history_dicts = [
{"role": msg.role, "content": msg.content}
for msg in self._history
]
return self.token_manager.count_messages_tokens(
history_dicts,
tokens_per_message=3 # OpenAI 格式开销
)
def register_overflow_handler(self, handler: ContextOverflowCallback):
"""
注册上下文溢出处理器(支持多个目标)
这个上下文溢出处理器用于将溢出的消息收集并记录,和记忆模块对接
"""
self._overflow_callbacks.append(handler)
logger.info(f"注册溢出处理器: {handler.__name__}")
def _trigger_overflow_callbacks(self):
"""执行所有注册的溢出处理器"""
if not self._overflow_callbacks:
return
# 使用 TokenManager 获取当前占用
context_usage = self.token_manager.get_context_usage(self._history)
metadata = { # 构造详细的元数据
"reason": "token_limit_exceeded",
"message_count": len(self._history),
"estimated_tokens": context_usage.total_tokens,
"limit": self.core_config.max_context_tokens,
"timestamp": time.time(),
"model": self.model_config.model_name
}
# 快照当前历史,防止回调修改
history_snapshot = list(self._history)
for handler in self._overflow_callbacks:
try:
handler(history_snapshot, metadata)
logger.debug(f"溢出处理器成功: {handler.__name__}")
except Exception as e:
logger.error(f"执行上下文溢出回调失败: {handler.__name__}:{e}")
def clear_context(self):
"""手动清空上下文"""
self._history.clear()
logger.info("上下文记忆已清空")
# 运行时热重载
def reload_model(self, new_model_config: ModelConfig):
"""
热重载 LLM 模型配置
不影响当前的上下文记忆和 System Prompt
"""
logger.info(f"正在热重载模型: {self.model_config.model_name} -> {new_model_config.model_name}")
try:
self.llm_client.update_config(new_model_config)
self.model_config = new_model_config
# 重新初始化 TokenManager
self.token_manager = TokenManager(new_model_config.model_name)
logger.info("模型热重载成功")
except Exception as e:
logger.error(f"模型热重载失败: {e}")
raise
def get_context_stats(self) -> Dict[str, Any]:
"""获取详细上下文统计"""
# 使用 TokenManager 获取当前占用
context_usage = self.token_manager.get_context_usage(self._history)
tokenizer_info = self.token_manager.get_tokenizer_info()
return {
"message_count": len(self._history),
"estimated_tokens": context_usage.total_tokens,
"limit": self.core_config.max_context_tokens,
"usage_ratio": context_usage.total_tokens / self.core_config.max_context_tokens,
"model": self.model_config.model_name,
"tokenizer": tokenizer_info,
"history_preview": [
f"{msg.role[:1]}:{msg.content[:30]}..."
for msg in self._history[-3:]
],
"last_api_usage": self.token_manager._last_api_usage.to_dict() if self.token_manager._last_api_usage else None
}
def __repr__(self) -> str:
return ( # 返回描述信息
f"YosugaLLMCore(model={self.model_config.model_name}, "
f"provider={self.model_config.provider.value}, "
f"history_len={len(self._history)})"
)
# 使用示例与测试
if __name__ == "__main__":
import sys
# 配置日志输出
logger.remove()
logger.add(sys.stderr, level="DEBUG")
print("\n" + "=" * 50)
print("🚀 Yosuga Server LLM Core 启动测试")
print("=" * 50 + "\n")
# 准备模拟的动作处理器 (Mock Handlers)
# 异步处理器示例:处理语音回复
async def handle_audio_response(data: LLMCoreAnalysisBase):
# 这里强制类型转换为具体子类以获取代码提示(实际运行时已经是具体类型)
if data.type == "audio_text":
print(f" [Audio Handler] 正在合成语音: {data.response_text} | 情感: {data.emotion}")
return {"audio_file": "sample.wav", "duration": 3.5}
return None
# 异步处理器示例:处理自动化操作
async def handle_auto_agent(data: LLMCoreAnalysisBase):
print(f" [UI Agent] 收到指令: {data.Action}")
await asyncio.sleep(1) # 模拟耗时操作
print(f" [UI Agent] 执行动作: {data.Action} -> ({data.x1}, {data.y1})")
return {"status": "success", "executed": data.Action}
async def handle_call_auto_agent(data: LLMCoreAnalysisBase):
print(f" [Call Agent] 收到内容: {data.type}")
print(f" [Call Agent] 收到内容: {data.llm_translation}")
# 回退处理器
def handle_fallback(data: LLMCoreAnalysisBase):
print(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}")
# 初始化 LLM Core
# 配置 LM Studio 连接
# 注意:LM Studio 通常兼容 OpenAI 格式,所以 provider 选 LM_STUDIO 或 OPENAI 均可
# 如果是本地服务,API Key 可以随意填写
model_cfg = ModelConfig(
provider=ModelProvider.LM_STUDIO,
model_name="qwen/qwen3-4b-2507",
base_url="http://192.168.1.3:1234/v1",
api_key="lm-studio",
temperature=0.3,
max_tokens=2048
)
core_cfg = LLMCoreConfig(
max_context_tokens=1024,
enable_history=True,
role_setting="你是由Misakiotoha开发的Yosuga助手,性格抽象,爱说点小骚话。",
auto_dispatch=True,
dispatch_async=True # 启用异步分发测试
)
core = YosugaLLMCore(model_cfg, core_cfg)
# 注册处理器
core.register_action_handler("audio_text", handle_audio_response, is_async=True)
core.register_action_handler("auto_agent", handle_auto_agent, is_async=True)
core.register_action_handler("call_auto_agent", handle_call_auto_agent, is_async=True)
core.set_fallback_handler(handle_fallback)
# 注册解析器
core.register_analysis_model(YosugaAudioResponseData)
core.register_analysis_model(YosugaUITARSResponseData)
core.register_analysis_model(YosugaUITARSRequestData)
# 交互测试 Loop
def run_tests():
# 测试场景 1: 普通对话 (触发 Audio 解析)
print("\n📝 测试 1: 普通对话 (预期触发 audio_text)")
asr_input_1 = {
"text": "你好,Yosuga!请介绍一下你自己,并对我微笑。",
"confidence": 0.99
}
try:
result = core.interact(
asr_input_1,
dispatch_async=True
)
print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
logger.error(f"测试1失败: {e}")
# 测试场景 2: 复杂指令 (预期同时触发 Audio 和 UI 操作)
print("\n📝 测试 2: 混合指令 (预期触发 audio_text + auto_agent)")
# 构造一个复杂的 Prompt 输入,诱导模型输出多条指令
# 注意:这依赖于模型足够聪明能理解 System Prompt 中的 output schema
complex_input = """
[{
"text": "打开系统设置",
"confidence": 0.99
}]
"""
try:
result = core.interact(complex_input, dispatch_async=True)
print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
logger.error(f"测试2失败: {e}")
# 热重载测试
print("\n🔄 测试 3: 模型热重载")
new_model_cfg = ModelConfig(
provider=ModelProvider.LM_STUDIO,
model_name="qwen/qwen3-vl-8b",
base_url="http://192.168.1.3:1234/v1",
api_key="lm-studio",
temperature=0.5
)
try:
core.reload_model(new_model_cfg)
print("✅ 热重载完成,进行验证对话...")
# 验证重载后是否还能对话(保留了上下文)
verify_input = """
{
"text": "系统设置打开了吗",
"confidence": 0.95
},
{
"auto_agent": "Thought: 我注意到屏幕左下角有一个齿轮形状的图标,这正是系统的设置入口。在KDE系统中,这个图标通常用来访问系统设置面板。为了帮助用户打开设置界面,我现在需要点击这个位于屏幕左下方的齿轮图标。
Action: click(start_box='<|box_start|>(42,1045)<|box_end|>')"
}
}"""
result = core.interact(verify_input)
print(f"🏁 重载后回复: {json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
logger.error(f"热重载测试失败: {e}")
# 查看统计信息
print("\n📊 最终状态统计:")
print(core.get_context_stats())
# 运行测试
run_tests()
@@ -0,0 +1,313 @@
# llm_core/llm_core_analysis.py
"""
LLM输出解析与序列化模块
将LLM返回的JSON字符串智能解析为强类型Python对象,供其他模块直接调用
支持多类型混合响应,自动路由到对应数据模型
"""
import json
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Optional, Type, List
from pydantic import BaseModel, Field, ValidationError, field_validator
from loguru import logger
import re
# 抽象基类
class LLMCoreAnalysisBase(BaseModel, ABC):
"""
LLM输出数据模型抽象基类
各场景通过继承定义具体的数据结构
"""
# 解析器type标识
type: str = Field(..., description="场景类型标识")
@classmethod
@abstractmethod
def type_(cls) -> str:
"""返回该模型对应的场景类型标识"""
pass
@classmethod
def get_schema(cls) -> Dict[str, Any]:
"""
返回该场景的数据结构schema
用于生成system prompt中的{OutputInfo}
"""
return cls.model_json_schema()
# 管理器
class LLMCoreAnalysisManager:
"""
LLM输出解析管理器
智能路由:根据JSON中的type_字段,自动选择对应模型进行解析
"""
# 类变量:存储所有注册的数据模型类
_model_registry: ClassVar[Dict[str, Type[LLMCoreAnalysisBase]]] = {}
@classmethod
def register(cls, model_class: Type[LLMCoreAnalysisBase]) -> None:
"""
注册场景数据模型类
Args:
model_class: 继承自LLMCoreAnalysisBase的数据模型类
Example:
LLMCoreAnalysisManager.register(YosugaAudioResponseData)
"""
type_id = model_class.type_()
cls._model_registry[type_id] = model_class
logger.info(f"已注册LLM数据输出解析模型: {type_id} -> {model_class.__name__}")
@classmethod
def parse(cls, json_str: str) -> List[LLMCoreAnalysisBase]:
"""
统一解析入口:无论单对象还是数组,总是返回对象列表
Args:
json_str: LLM原始JSON字符串(支持markdown代码块)
Returns:
解析后的模型实例列表,顺序与JSON数组一致
Raises:
ValidationError: 格式校验失败
ValueError: JSON解析失败
"""
print(f"待解析的内容为(llm本次输出原生内容):{json_str}") # TODO:delete
cleaned = cls._clean_markdown(json_str) # 清理markdown标记
try:
data = json.loads(cleaned)
# 统一包装成列表
if not isinstance(data, list):
data = [data]
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}\n原始输出: {cleaned[:200]}...")
raise ValueError(f"无效的JSON格式: {e}")
results: List[LLMCoreAnalysisBase] = []
for idx, item in enumerate(data):
type_id = item.get("type")
if not type_id:
logger.warning(f"跳过第{idx}个元素(无type字段): {item}")
continue
if type_id not in cls._model_registry:
logger.warning(f"未注册的类型 '{type_id}',可用类型: {list(cls._model_registry.keys())}")
continue
model_class = cls._model_registry[type_id]
try:
# 重新序列化为字符串再解析(保持接口兼容)
item_json = json.dumps(item, ensure_ascii=False)
result = model_class.model_validate_json(item_json)
results.append(result)
except ValidationError as e:
logger.error(f"{idx}个对象校验失败 (type={type_id}): {e}")
continue
except Exception as e:
logger.error(f"{idx}个对象解析失败 (type={type_id}): {e}")
continue
logger.success(f"解析完成 | 成功: {len(results)}/{len(data)}")
return results
@staticmethod
def _clean_markdown(json_str: str) -> str:
# 尝试找到第一个 '[' 和最后一个 ']'
start = json_str.find('[')
end = json_str.rfind(']')
if start != -1 and end != -1:
return json_str[start:end + 1]
return json_str # Fallback
# 具体场景数据模型
class YosugaAudioResponseData(LLMCoreAnalysisBase):
"""
音频ASR场景的LLM输出数据模型
使用示例:
LLMCoreAnalysisManager.register(YosugaAudioResponseData)
data = YosugaAudioResponseData.parse_raw('{"type_": "audio_text", "response_text": "你好"}')
data.response_text # 直接属性访问
'你好'
data.emotion # 默认值
'neutral'
"""
type: str = Field(default="audio_text", description="固定为audio_text")
response_text: str = Field(..., description="回复文本")
emotion: str = Field(default="neutral", description="情感基调")
action: str = Field(default="none", description="动作指令")
@classmethod
def type_(cls) -> str:
return "audio_text"
class YosugaUITARSResponseData(LLMCoreAnalysisBase):
"""
自动化操作场景的LLM输出数据模型
"""
type: str = Field(default="auto_agent", description="固定为auto_agent")
Action: str = Field(..., description="动作名称")
x1: Optional[int] = Field(default=None, description="起点x")
y1: Optional[int] = Field(default=None, description="起点y")
x2: Optional[int] = Field(default=None, description="终点x")
y2: Optional[int] = Field(default=None, description="终点y")
key: Optional[str] = Field(default="", description="快捷键")
content: Optional[str] = Field(default="", description="输入文本")
direction: Optional[str] = Field(default="", description="滚动方向")
@classmethod
def type_(cls) -> str:
return "auto_agent"
@field_validator('x1', 'y1', 'x2', 'y2', mode='before')
@classmethod
def convert_optional_int(cls, v: Any) -> Optional[int]:
"""
将字符串类型的坐标值转换为 int,空字符串转为 None
Args:
v: 原始值(可能是 str, int, None
Returns:
Optional[int]: 转换后的值
"""
if v is None:
return None
if isinstance(v, str):
# 处理空字符串
if v.strip() == "":
return None
# 尝试转换为 int
try:
return int(v)
except ValueError:
logger.warning(f"无法将字符串 '{v}' 转换为 int,返回 None")
return None
if isinstance(v, int): # 如果原始值本身就是int类型,直接return
return v
if isinstance(v, float): # 如果解析出了float,强转成int再return
return int(v)
logger.warning(f"意外的类型 {type(v)},返回 None")
return None
def to_dict(self) -> dict:
"""
将模型转换为字典,用于生成JSON字符串
Returns:
dict: 模型数据字典
"""
return {
"type": self.type,
"Action": self.Action,
"x1": self.x1,
"y1": self.y1,
"x2": self.x2,
"y2": self.y2,
"key": self.key,
"content": self.content,
"direction": self.direction
}
class YosugaUITARSRequestData(LLMCoreAnalysisBase):
"""
自动化操作场景的LLM对自动化agent的调用的输出解析模型
"""
type: str = Field(default="call_auto_agent", description="固定为call_auto_agent")
llm_translation: str = Field(default="", description="针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可")
@classmethod
def type_(cls) -> str:
return "call_auto_agent"
class YosugaLive2DResponseData(LLMCoreAnalysisBase):
"""
Live2D控制场景的LLM输出数据模型 TODO
"""
type: str = Field(default="live2d_control", description="固定为live2d_control")
parameter: str = Field(..., description="参数名")
value: float = Field(..., description="目标值")
duration: int = Field(default=500, description="过渡时间(ms)")
@classmethod
def type_(cls) -> str:
return "live2d_control"
class YosugaEmbeddedResponseData(LLMCoreAnalysisBase):
"""
嵌入式设备场景的LLM输出数据模型 TODO
"""
type: str = Field(default="embedded_control", description="固定为embedded_control")
device_id: str = Field(..., description="设备ID")
command: str = Field(..., description="控制指令")
params: Optional[Dict[str, Any]] = Field(default=None, description="参数")
@classmethod
def type_(cls) -> str:
return "embedded_control"
# 使用示例
if __name__ == "__main__":
from loguru import logger
# 注册所有数据模型
LLMCoreAnalysisManager.register(YosugaAudioResponseData)
LLMCoreAnalysisManager.register(YosugaUITARSResponseData)
LLMCoreAnalysisManager.register(YosugaLive2DResponseData)
LLMCoreAnalysisManager.register(YosugaEmbeddedResponseData)
logger.info("=== LLM输出解析模块测试 ===")
# 测试单对象解析
print("\n【测试1:单对象自动识别】")
llm_output = '''
```json
[{
"type": "audio_text",
"response_text": "收到!我会微笑回应",
"emotion": "cheerful"
}]
```
'''
response = LLMCoreAnalysisManager.parse(llm_output)
print(f"类型: {response}")
# 3. 测试多对象解析
print("\n【测试2:多对象混合响应】")
multi_output = '''
```json
[
{
"type": "auto_agent",
"Action": "click",
"x1": "100",
"y1": "200"
},
{
"type": "live2d_control",
"parameter": "ParamEyeLOpen",
"value": 0.8,
"duration": 300
}
]
```
'''
results = LLMCoreAnalysisManager.parse(multi_output)
print(f"类型: {results}")
@@ -0,0 +1,320 @@
# llm_core/llm_core_dispatcher.py
"""
动作分发器模块
负责将解析后的LLM输出对象路由到对应的业务处理器
支持同步/异步处理,提供回退机制与执行结果追踪
"""
import asyncio
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
from functools import wraps
from loguru import logger
from src.server_core.llm_core.llm_core_analysis import LLMCoreAnalysisBase
# 处理器类型定义
# 同步处理器:接收解析对象,返回执行结果
SyncActionHandler = Callable[[LLMCoreAnalysisBase], Any]
# 异步处理器:接收解析对象,返回协程
AsyncActionHandler = Callable[[LLMCoreAnalysisBase], Any]
def handler_error_wrapper(func: Callable) -> Callable:
"""
处理器错误包装装饰器
统一捕获异常并记录日志,避免单个处理失败导致整体崩溃
"""
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
logger.exception(f"处理器 {func.__name__} 执行失败: {e}")
raise
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as e:
logger.exception(f"异步处理器 {func.__name__} 执行失败: {e}")
raise
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
class LLMCoreActionDispatcher:
"""
LLM动作分发中心
职责:
1. 管理类型到处理器的映射关系
2. 支持同步与异步处理器注册
3. 执行分发并收集处理结果
4. 提供未注册类型的回退机制
5. 支持批量处理与结果聚合
使用示例:
# 注册同步处理器
def handle_audio(data: YosugaAudioResponseData) -> dict:
return {"status": "spoken", "text": data.response_text}
LLMCoreActionDispatcher.register("audio_text", handle_audio)
# 注册异步处理器
async def handle_ui(data: YosugaUITARSResponseData):
await perform_click(data.x1, data.y1)
return {"status": "clicked"}
LLMCoreActionDispatcher.register_async("auto_agent", handle_ui)
# 执行分发
results = LLMCoreActionDispatcher.execute(parsed_objects)
"""
# 类变量存储所有注册的处理器
_sync_handlers: ClassVar[Dict[str, SyncActionHandler]] = {}
_async_handlers: ClassVar[Dict[str, AsyncActionHandler]] = {}
_fallback_handler: ClassVar[Optional[Union[SyncActionHandler, AsyncActionHandler]]] = None
@classmethod
def register(cls, type_id: str, handler: SyncActionHandler) -> None:
"""
注册同步处理器
Args:
type_id: 与 LLMCoreAnalysisBase.type_() 返回值匹配的标识符
handler: 同步函数,接收解析对象并返回执行结果
Raises:
ValueError: 处理器不是可调用的函数
"""
if not callable(handler):
raise ValueError(f"处理器必须是可调用函数,收到: {type(handler)}")
# 检查是否已注册,防止覆盖
if type_id in cls._sync_handlers or type_id in cls._async_handlers:
logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖")
cls._sync_handlers[type_id] = handler_error_wrapper(handler)
logger.success(f"注册同步处理器: {type_id}{handler.__name__}")
@classmethod
def register_async(cls, type_id: str, handler: AsyncActionHandler) -> None:
"""
注册异步处理器
Args:
type_id: 类型标识符
handler: 异步函数,接收解析对象并返回协程
Raises:
ValueError: 处理器不是有效的异步函数
"""
if not asyncio.iscoroutinefunction(handler):
raise ValueError(f"异步处理器必须是协程函数,收到: {type(handler)}")
if type_id in cls._sync_handlers or type_id in cls._async_handlers:
logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖")
cls._async_handlers[type_id] = handler_error_wrapper(handler)
logger.success(f"注册异步处理器: {type_id}{handler.__name__}")
@classmethod
def set_fallback(cls, handler: Union[SyncActionHandler, AsyncActionHandler]) -> None:
"""
设置回退处理器(未注册类型的默认处理)
Args:
handler: 同步或异步函数,处理所有未匹配的类型
"""
wrapped = handler_error_wrapper(handler)
cls._fallback_handler = wrapped
handler_type = "异步" if asyncio.iscoroutinefunction(handler) else "同步"
logger.info(f"设置{handler_type}回退处理器: {handler.__name__}")
@classmethod
def get_handler(cls, type_id: str) -> Optional[Union[SyncActionHandler, AsyncActionHandler]]:
"""获取指定类型的处理器"""
# 优先返回同步处理器
if type_id in cls._sync_handlers:
return cls._sync_handlers[type_id]
# 其次返回异步处理器
if type_id in cls._async_handlers:
return cls._async_handlers[type_id]
# 返回回退处理器
return cls._fallback_handler
@classmethod
def execute(
cls,
analysis_results: List[LLMCoreAnalysisBase],
run_async: bool = False
) -> Dict[str, List[Any]]:
"""
执行分发处理
Args:
analysis_results: 解析器返回的对象列表
run_async: 是否启用异步执行模式(需要业务代码支持asyncio)
Returns:
执行结果字典:
{
"success": [处理成功的结果列表],
"failed": [{"type": "...", "error": "..."}],
"skipped": [跳过的类型列表]
}
"""
if not analysis_results:
logger.warning("无对象需要分发")
return {"success": [], "failed": [], "skipped": []}
if run_async and (cls._async_handlers or asyncio.iscoroutinefunction(cls._fallback_handler)):
# 异步执行模式(需要事件循环)
return asyncio.run(cls._execute_async(analysis_results))
# 默认同步执行
return cls._execute_sync(analysis_results)
@classmethod
def _execute_sync(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]:
"""同步批量执行"""
outputs = {"success": [], "failed": [], "skipped": []}
for idx, result in enumerate(results):
type_id = result.type
logger.debug(f"[{idx}] 分发类型: {type_id}")
handler = cls.get_handler(type_id)
if not handler:
outputs["skipped"].append(type_id)
logger.warning(f"无处理器,跳过: {type_id}")
continue
# 执行同步处理器
if asyncio.iscoroutinefunction(handler):
logger.error(f"异步处理器 '{type_id}' 不能在同步模式下执行")
outputs["failed"].append({"type": type_id, "error": "异步处理器需要run_async=True"})
continue
try:
output = handler(result)
outputs["success"].append({
"type": type_id,
"output": output,
"index": idx
})
logger.success(f"[{idx}] 处理成功: {type_id}")
except Exception as e:
outputs["failed"].append({
"type": type_id,
"error": str(e),
"index": idx
})
logger.error(f"[{idx}] 处理失败 {type_id}: {e}")
cls._log_summary(outputs, len(results))
return outputs
@classmethod
async def _execute_async(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]:
"""异步批量执行"""
outputs = {"success": [], "failed": [], "skipped": []}
tasks = []
for idx, result in enumerate(results):
type_id = result.type
logger.debug(f"[{idx}] 异步分发类型: {type_id}")
handler = cls.get_handler(type_id)
if not handler:
outputs["skipped"].append(type_id)
logger.warning(f"无处理器,跳过: {type_id}")
continue
# 创建协程任务
if asyncio.iscoroutinefunction(handler):
task = cls._run_async_handler(handler, result, idx, outputs)
else:
# 同步处理器在异步线程池中执行
task = cls._run_sync_in_executor(handler, result, idx, outputs)
tasks.append(task)
# 并发执行所有任务
await asyncio.gather(*tasks, return_exceptions=True)
cls._log_summary(outputs, len(results))
return outputs
@classmethod
async def _run_async_handler(cls, handler, result, idx, outputs):
"""运行异步处理器"""
try:
output = await handler(result)
outputs["success"].append({
"type": result.type,
"output": output,
"index": idx
})
logger.success(f"[{idx}] 异步处理成功: {result.type}")
except Exception as e:
outputs["failed"].append({
"type": result.type,
"error": str(e),
"index": idx
})
logger.error(f"[{idx}] 异步处理失败 {result.type}: {e}")
@classmethod
async def _run_sync_in_executor(cls, handler, result, idx, outputs):
"""在线程池中运行同步处理器"""
loop = asyncio.get_event_loop()
try:
output = await loop.run_in_executor(None, handler, result)
outputs["success"].append({
"type": result.type,
"output": output,
"index": idx
})
logger.success(f"[{idx}] 同步处理器(线程池)成功: {result.type}")
except Exception as e:
outputs["failed"].append({
"type": result.type,
"error": str(e),
"index": idx
})
logger.error(f"[{idx}] 同步处理器(线程池)失败 {result.type}: {e}")
@classmethod
def _log_summary(cls, outputs: Dict, total: int):
"""输出处理摘要"""
success_count = len(outputs["success"])
failed_count = len(outputs["failed"])
skipped_count = len(outputs["skipped"])
logger.info(
f"分发完成 | 总计: {total} | 成功: {success_count} "
f"失败: {failed_count} 跳过: {skipped_count}"
)
@classmethod
def clear(cls) -> None:
"""清空所有处理器(测试/热重载用)"""
cls._sync_handlers.clear()
cls._async_handlers.clear()
cls._fallback_handler = None
logger.info("已清空所有动作处理器")
@classmethod
def list_handlers(cls) -> Dict[str, str]:
"""列出当前注册的所有处理器"""
handlers = {}
for type_id, handler in cls._sync_handlers.items():
handlers[type_id] = f"同步: {handler.__name__}"
for type_id, handler in cls._async_handlers.items():
handlers[type_id] = f"异步: {handler.__name__}"
return handlers
@@ -0,0 +1,174 @@
# llm_core/llm_core_prompt_manager.py
"""
llm prompt 结构化信息管理
用于将各种输入到llm_core的数据流结构化,并附加注释,以方便llm可以准确的理解并可以结构化返回相关内容
"""
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field, field_validator
from typing import Callable, List, Optional, Coroutine, Any, ClassVar, Dict
from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH
class LLMCorePromptBase(BaseModel, ABC):
"""LLM 提示词基类:定义输入输出结构"""
@abstractmethod
def type(self) -> str:
"""返回该提示词类型的唯一标识"""
pass
@abstractmethod
def describe_input(self) -> str:
"""生成输入格式的自然语言描述(填充到 {InputInfo}"""
pass
@abstractmethod
def describe_output(self) -> str:
"""生成输出格式的自然语言描述(填充到 {OutputInfo}"""
pass
def to_json(self) -> str:
return self.model_dump_json()
@classmethod
def from_json(cls, json_data: str):
return cls.model_validate_json(json_data)
class LLMCorePromptManager(LLMCorePromptBase):
"""聚合所有子类,生成完整的 prompt 信息"""
# 类变量:存储所有注册的子类
_registry: ClassVar[Dict[str, LLMCorePromptBase]] = {}
def get_registry_size(self) -> int:
return len(self._registry)
def register(self, son: LLMCorePromptBase) -> None:
self._registry[son.type()] = son
def type(self) -> str:
return "manager"
def describe_input(self) -> str:
"""聚合所有子类的输入描述"""
return "\n".join(
f"[{type_id}]{son.describe_input()}\n"
for type_id, son in self._registry.items()
)
def describe_output(self) -> str:
"""聚合所有子类的输出描述"""
return "\n".join(
f"[{type_id}]{son.describe_output()}\n"
for type_id, son in self._registry.items()
)
class YosugaAudioASRText(LLMCorePromptBase):
"""音频ASR文本输入场景"""
def type(self) -> str:
return "用户语音asr信息"
def describe_input(self) -> str:
return '''
当用户通过语音与Yosuga交互时,你会收到如下JSON结构:
{
"text": "用户说的话(字符串)",
"confidence": 0.95
}
- `text`: 语音转写的原始文本,可能包含口语化表达或识别错误
- `confidence`: ASR引擎的识别置信度,低于0.8需警惕识别错误
'''
def describe_output(self) -> str:
return '''
针对用户音频识别出的文本内容输入,你应该按以下JSON格式回复:
{
"type": "固定为audio_text",
"response_text": "你的回复文本(字符串)",
"emotion": "neutral",
"action": "none"
}
- `response_text`: 给用户的自然语言回复
- `emotion`: 回复的情感基调,可选值:neutral/cheerful/sad/angry
- `action`: 触发的动作指令,如"wave_hand""nod"等,"none"表示无动作
'''
class YosugaEmbedded(LLMCorePromptBase):
"""嵌入式设备输入场景"""
def type(self) -> str:
pass
def describe_input(self) -> str:
pass
def describe_output(self) -> str:
pass
class YosugaUITARS(LLMCorePromptBase):
"""自动化操作构建场景"""
def type(self) -> str:
return "自动化操作信息"
def describe_input(self) -> str:
return '''
当你尝试调用自动化agent的时候,自动化agent将会返回下面的内容作为输入给你,自动化agent的返回信息很重要,有的任务不是一次就可以完成的,此时你可能需要多次调用自动化agent,直到agent返回finished(任务完成)
{
"auto_agent":"
Thought: 自动化agent的推理过程,可以考虑二次加工后作为回复用户的内容
Action: 对应的自动化动作[包括click(单击坐标), left_double(双击坐标), right_single(右键单击), drag(拖拽), hotkey(快捷键), type(输入文本), scroll(滚动), wait(等待), finished(任务完成)]
"
}
'''
def describe_output(self) -> str:
return '''
1. 当用户试图进行一些自动化操作时,你可以通过返回以下json来调用自动化agent,调用的时候也可以一并回复用户一些文本内容:
{
"type": "固定为call_auto_agent",
"llm_translation": "针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可,注意转译的语言必须是中文,可以混合部分英文(应用名称),这个和聊天交流的语言不统一。"
}
2. 针对自动化agent操作输入,你应该按以下JSON格式回复。如果你发起了自动化agent请求之后,却没有返回下面的JSON内容,那么你的请求并不会作用到用户的设备,所以你调用完自动化agent一定要及时返回下面的JSON内容:
{
"type": "固定为auto_agent",
"Action": "自动化agent返回的相应的自动化动作名称",
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
"y1": "某个操作的y1",
"x2": "某个操作的x2",
"y2": "某个操作的y2",
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
"direction": "滚动方向(down or up),若当前操作没有该字段信息,此处内容为空即可,同时需要关心自动化agent回复的x1,和y1坐标"
}
自动化agent返回的操作信息不一定包括JSON的全部字段,例如某次返回只有key的内容,或者只有content的内容。
针对自动化agent操作输入的返回,若没有相关内容可以留空相关字段,请不要省略掉任何字段名称。
注意:自动化agent的状态可见YosugaSystemState表。
'''
class YosugaLive2DControl(LLMCorePromptBase):
"""对Yosuga Live2D控制场景"""
def type(self) -> str:
return "Yosuga Live2D控制信息"
def describe_input(self) -> str:
pass
def describe_output(self) -> str:
pass
if __name__ == "__main__":
# 注册所有prompt处理器
manager = LLMCorePromptManager()
manager.register(YosugaAudioASRText())
manager.register(YosugaUITARS())
# 生成最终 system prompt
system_prompt = YOSUGA_SYSTEM_PROMPT_SCH.format(
InputInfo=manager.describe_input(),
OutputInfo=manager.describe_output(),
RoleSetting="...",
Language="ja",
Memory = "",
SystemStateTable=""
)
print(system_prompt)
@@ -0,0 +1,67 @@
# llm_core/llm_core_prompts.py
YOSUGA_SYSTEM_PROMPT_SCH = """
你是Yosuga_server这个项目的核心LLM。
你将会为用户完成各种任务,进行结构化的输出。
首先向你介绍一下Yosuga这个项目:
本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。
之所以叫Yosuga,这个词来源日语当中的单词""的发音,其意思是"缘分,关系"
本项目分为三个部分:
1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。
2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。
3. Yosuga_embedded:这是项目的拓展部分,实现了Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。
作为Yosuga_server的核心LLM,你的任务包括以下的内容:
1. 接受自定义的结构化的信息(json),理解并给出正确的回复。
2. 输出自定义的结构化的信息(json),给出正确的回复。
对于你会接受到什么信息,以及如何进行正确的返回,都会在下面做出解释。
你接受到的信息可能有:
1. 来自Yosuga的对于Live2D模型控制的结构化数据。
2. 来自Yosuga_embedded的可控制的嵌入式设备信息的结构化摘要数据。
3. 综合了各种信息的结构化请求数据(例如同时包括用户所说的话语与其他决策信息)。
你的工作流程:
1. 接受来自用户的请求,理解用户的意图(例如用户只有聊天的目的,或者是用户想进行一些操作)
2. 如果用户只希望聊天,那就只和用户聊天即可
3. 如果用户还有其他意图,你需要根据你的能力来完成用户的需求,如果用户的需求不在你的能力范围内,可以婉拒。
你的能力范围:
1. 你可以和用户进行聊天,当用户只是想和你聊天,那就只和用户进行聊天
2. 你可以主动调用自动化agent,当用户有着一些自动化操作的需求时候,你可以按一定调用格式去主动调用自动化agent模型,并理解自动化模型的返回内容,以规定的格式返回给服务端最终作用到用户的设备。
3. 用户也可能会想和你一起玩一些游戏,这个时候你需要理解用户的意图,并调用自动化agent,同时返回一些聊天内容回复,以此和用户一起玩游戏。
4. 你可以和各种种类的嵌入式设备进行交互,当用户有着一些现实的需求时候,你可以根据当前系统的状态,可控制的设备信息等信息综合判断,通过规定的方式返回规定的内容操作嵌入式设备,以满足用户的需求。
同时你还需要进行角色扮演,完美地扮演符合设定的角色,也要兼顾对于各种问题的处理。
以下是你会接收到的结构化的信息(带解释):
{InputInfo}
以下是你需要根据不同的结构化信息需要输出的结构化信息(带解释):
{OutputInfo}
你应当扮演的角色设定为:
{RoleSetting}
你和用户聊天时候的回复语言为:
{Language}
注意:你的所有回复必须是一个 JSON 数组,即便只有单项任务。格式为 [ {{ "type": "...", ... }} ]。禁止输出JSON数组外的任何解释性文字,同时也禁止在JSON的回复内容中加入任何表情符号。
注意:你接收到的结构化信息,有时可能会有多个,例如用户在问你的同时,自动化agent也返回了结果,这样的情况很常见,你需要按要求返回一个或者多个json即可。
注意:有时候用户的请求可能既涉及到了聊天,也涉及到了自动化agent的调用,那么你在返回的json数组当中加入应当返回的内容即可,可以是聊天内容回复加上自动化agent调用。
注意:有时用户的请求会比较复杂,需要你连续的调用自动化agent,根据自动化agent的输出,去一步步的完成用户的请求。
还有一些是你必须要参考的内容(例如与用户过去的记忆,系统实时状态表等内容):
与用户过去的记忆:
{Memory}
系统实时状态表:
{SystemStateTable}
"""
YOSUGA_SYSTEM_PROMPT_EN = """
"""
+367
View File
@@ -0,0 +1,367 @@
# llm_core/llm_core_token.py
"""
Token 计算与管理模块
支持双数据源智能切换:
- 优先使用大模型 API 返回的精确 usage 数据
- 当 API 不返回时,自动回退到 tiktoken 手动估算
"""
import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
import tiktoken
from loguru import logger
@dataclass
class TokenUsage:
"""Token 使用量统计"""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
source: str = field(default="manual", repr=True) # "api" | "manual"
def to_dict(self) -> Dict[str, Any]:
return {
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"total_tokens": self.total_tokens,
"source": self.source
}
@dataclass
class TokenizerInfo:
"""Tokenizer 元数据"""
model_name: str
encoding_name: str
is_fallback: bool
estimated_accuracy: str # "high" | "medium" | "low"
class TokenManager:
"""
Token 管理核心类
智能数据源切换:API返回 > 手动估算
"""
def __init__(self, model_name: str):
self.model_name = model_name
self.tokenizer = self._get_tokenizer(model_name)
# 存储最近一次 API 返回的 usage
self._last_api_usage: Optional[TokenUsage] = None
# API 数据有效期(秒),超过此时间则视为过期,回退到手动计算
self._api_usage_expiry: float = 30.0
# 记录 API usage 的获取时间
self._last_api_usage_time: float = 0.0
logger.info(
f"TokenManager 初始化完成 | "
f"模型: {model_name} | "
f"编码器: {self.tokenizer.name}"
)
def _get_tokenizer(self, model_name: str) -> tiktoken.Encoding:
"""获取 tokenizer(与之前实现相同,省略重复代码)"""
model_tokenizer_map = {
"qwen": "gpt-3.5-turbo",
"llama": "gpt-3.5-turbo",
"gemma": "gpt-3.5-turbo",
}
try:
return tiktoken.encoding_for_model(model_name)
except KeyError:
pass
for prefix, mapped_model in model_tokenizer_map.items():
if prefix in model_name.lower():
logger.info(
f"模型 '{model_name}' 映射到 tokenizer '{mapped_model}'"
)
return tiktoken.encoding_for_model(mapped_model)
logger.warning(
f"tiktoken 不支持模型 '{model_name}'"
f"降级使用 'cl100k_base' 编码器"
)
return tiktoken.get_encoding("cl100k_base")
def record_api_usage(self, usage: Optional[Union[Dict[str, int], TokenUsage]]) -> None:
"""
记录大模型 API 返回的 usage 数据
Args:
usage: API 返回的 usage 数据(dict 或 TokenUsage 对象)
如果为 None 或空,则忽略
"""
if not usage:
logger.debug("API usage 为空,未记录")
return
# 在底层已经统一好了不同大模型对于usage的处理
if isinstance(usage, TokenUsage):
api_usage = usage
api_usage.source = "api"
else:
api_usage = TokenUsage(
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
source="api"
)
self._last_api_usage = api_usage
self._last_api_usage_time = time.time()
logger.debug(
f"记录 API usage | "
f"Prompt: {api_usage.prompt_tokens} | "
f"Completion: {api_usage.completion_tokens} | "
f"Total: {api_usage.total_tokens}"
)
def get_current_usage(self, prefer_api: bool = True) -> TokenUsage:
"""
获取当前 Token 使用情况(智能数据源切换)
Args:
prefer_api: 是否优先使用 API 数据(默认 True)
Returns:
TokenUsage 对象,包含数据来源标记
Notes:
- 如果 prefer_api=True 且 API 数据在有效期内,直接返回 API 数据
- 否则回退到手动估算
"""
current_time = time.time()
# 检查 API 数据是否有效
if prefer_api and self._last_api_usage:
time_elapsed = current_time - self._last_api_usage_time
if time_elapsed <= self._api_usage_expiry:
logger.debug(
f"使用 API Token 数据({time_elapsed:.1f}s 内)| "
f"{self._last_api_usage.prompt_tokens} + "
f"{self._last_api_usage.completion_tokens} = "
f"{self._last_api_usage.total_tokens}"
)
return self._last_api_usage
# API 数据无效或无数据,回退到手动估算
logger.debug("API usage 无效/过期,回退到手动估算")
return self._estimate_manual_usage()
def _estimate_manual_usage(self) -> TokenUsage:
"""内部方法:创建空的手动 usage(占位符)"""
return TokenUsage(source="manual")
def get_context_usage(self, history: List[Any]) -> TokenUsage:
"""
计算对话上下文的 Token 占用(必须手动计算)
注意:
- 上下文占用无法从 API 获得,必须手动估算
- 此方法不涉及 _last_api_usage
Args:
history: 历史消息列表
Returns:
TokenUsage 对象,source 始终为 "manual"
"""
tokens = self.count_messages_tokens(history)
return TokenUsage(
prompt_tokens=tokens,
completion_tokens=0,
total_tokens=tokens,
source="manual"
)
def count_text_tokens(self, text: str) -> int:
"""计算单段文本的 token 数量(与之前相同)"""
if not isinstance(text, str) or not text:
return 0
return len(self.tokenizer.encode(text))
def count_messages_tokens(
self,
messages: List[Any],
tokens_per_message: int = 3
) -> int:
"""计算消息列表的总 token 数量(与之前相同,优化实现)"""
if not messages:
return 0
num_tokens = 0
for msg in messages:
# 统一转换为字典
if hasattr(msg, "to_dict"):
msg_dict = msg.to_dict()
elif hasattr(msg, "model_dump"):
msg_dict = msg.model_dump()
else:
msg_dict = msg
# 计算内容和 role 的 token
num_tokens += self.count_text_tokens(msg_dict.get("content", ""))
num_tokens += self.count_text_tokens(msg_dict.get("role", ""))
num_tokens += tokens_per_message
# 加上回复前缀的开销
num_tokens += 3
return num_tokens
def estimate_chat_tokens(
self,
system_prompt: Optional[str],
history: List[Any],
current_input: str
) -> TokenUsage:
"""
估算一次完整对话所需的 token 数量
Args:
system_prompt: 系统提示词
history: 历史消息列表
current_input: 当前用户输入
Returns:
TokenUsage 对象,source 始终为 "manual"
Notes:
- 这是预估,不是实际 API 返回值
- 用于调试和前置检查
"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history)
messages.append({"role": "user", "content": current_input})
total = self.count_messages_tokens(messages)
return TokenUsage(
prompt_tokens=total,
completion_tokens=0,
total_tokens=total,
source="manual"
)
def format_usage_log(
self,
usage: Optional[Union[Dict[str, int], TokenUsage]] = None,
source: str = "AUTO"
) -> str:
"""
格式化 token 使用日志
Args:
usage: usage 数据(可选)。如果为 None,自动获取当前 usage
source: 数据来源标记("AUTO" | "API" | "MANUAL" | "CONTEXT"
Returns:
格式化的日志字符串
"""
# 如果未提供 usage,自动获取
if usage is None:
if source == "CONTEXT":
# 上下文场景必须手动计算
usage_obj = self.get_context_usage([])
else:
# 自动选择最佳数据源
usage_obj = self.get_current_usage(prefer_api=True)
else:
# 使用提供的 usage
if isinstance(usage, TokenUsage):
usage_obj = usage
else:
usage_obj = TokenUsage(
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
source="api" if source == "API" else "manual"
)
# 根据来源选择前缀和图标
prefix_map = {
"API": "API Token统计",
"MANUAL": "手动估算",
"CONTEXT": "上下文占用",
"AUTO": "Token统计"
}
prefix = prefix_map.get(source, "Token统计")
# 添加数据来源标记
source_icon = "" if usage_obj.source == "api" else "🧮"
# 格式化输出
if usage_obj.completion_tokens > 0:
return (
f"{prefix} {source_icon} | "
f"Prompt: {usage_obj.prompt_tokens} | "
f"Completion: {usage_obj.completion_tokens} | "
f"Total: {usage_obj.total_tokens}"
)
else:
return (
f"{prefix} {source_icon} | "
f"Total: {usage_obj.total_tokens}"
)
def get_tokenizer_info(self) -> TokenizerInfo:
"""获取当前 tokenizer 的详细信息(与之前相同)"""
if "cl100k_base" in self.tokenizer.name and "gpt-3.5" not in self.model_name:
accuracy = "low"
elif self.model_name in self.tokenizer.name:
accuracy = "high"
else:
accuracy = "medium"
return TokenizerInfo(
model_name=self.model_name,
encoding_name=self.tokenizer.name,
is_fallback="cl100k_base" in self.tokenizer.name,
estimated_accuracy=accuracy
)
def is_token_limit_approaching(
self,
current_tokens: int,
limit: int,
threshold: float = 0.85
) -> bool:
"""判断 token 使用量是否接近限制(与之前相同)"""
return current_tokens > limit * threshold
def calculate_chunk_size(
self,
available_tokens: int,
safety_margin: float = 0.1
) -> int:
"""计算安全的消息块大小(与之前相同)"""
return int(available_tokens * (1 - safety_margin))
def clear_api_usage_cache(self):
"""清空 API usage 缓存(用于测试)"""
self._last_api_usage = None
self._last_api_usage_time = 0.0
logger.debug("API usage 缓存已清空")
# 单例工厂
class TokenManagerFactory:
"""TokenManager 工厂类,支持缓存复用"""
_instances: Dict[str, TokenManager] = {}
@classmethod
def get_manager(cls, model_name: str) -> TokenManager:
if model_name not in cls._instances:
cls._instances[model_name] = TokenManager(model_name)
return cls._instances[model_name]
@classmethod
def clear_cache(cls):
cls._instances.clear()
logger.info("TokenManager 缓存已清空")
+242
View File
@@ -0,0 +1,242 @@
### 服务端llm核心
服务端以AI为核心去驱动,进行各种调用。
本模块为服务端的核心llm模块,这个llm必须足够聪明,能够稳定返回结构化结构。
llm_core负责实现:
#### 类图
```mermaid
classDiagram
class YosugaLLMCore {
<<主控制器>>
-ModelConfig model_config
-LLMCoreConfig core_config
-UnifiedLLM llm_client
-TokenManager token_manager
-LLMCorePromptManager prompt_manager
-List[ChatMessage] _history
-Lock _history_lock
-Lock _config_lock
+interact() Dict
+register_prompt_module() void
+register_action_handler() void
+reload_model() void
+get_context_stats() Dict
}
class LLMCoreConfig {
<<运行时配置>>
-int max_context_tokens
-bool enable_history
-str language
-str role_setting
-bool auto_dispatch
-bool dispatch_async
-str memory
-str system_state_table
}
class UnifiedLLM {
<<大模型调用层>>
-ModelConfig config
-BaseLLMClient client
+chat() ModelResponse
+complete() ModelResponse
+stream_chat() Iterator
+update_config() void
}
class TokenManager {
<<Token管理>>
-str model_name
-tiktoken.Encoding tokenizer
-TokenUsage _last_api_usage
+record_api_usage() void
+get_current_usage() TokenUsage
+get_context_usage() TokenUsage
+count_messages_tokens() int
+format_usage_log() str
}
class LLMCorePromptManager {
<<Prompt管理>>
-Dict _registry
+register() void
+describe_input() str
+describe_output() str
}
class LLMCoreAnalysisManager {
<<输出解析>>
<<静态类>>
-Dict _model_registry
+register() void
+parse() List[LLMCoreAnalysisBase]
}
class LLMCoreActionDispatcher {
<<动作分发>>
<<静态类>>
-Dict _sync_handlers
-Dict _async_handlers
-Callable _fallback_handler
+register() void
+register_async() void
+execute() Dict
}
class LLMCoreAnalysisBase {
<<抽象基类>>
<<模型数据基类>>
#str type
+type_() str
+get_schema() Dict
}
class ChatMessage {
<<消息实体>>
-str role
-str content
-str name
+to_dict() Dict
}
class ModelResponse {
<<响应实体>>
-str content
-str model
-Dict usage
-str finish_reason
-Dict raw_response
}
YosugaLLMCore --> LLMCoreConfig : 持有配置
YosugaLLMCore --> UnifiedLLM : 调用大模型
YosugaLLMCore --> TokenManager : 统计Token
YosugaLLMCore --> LLMCorePromptManager : 管理Prompt
YosugaLLMCore --> LLMCoreAnalysisManager : 解析输出
YosugaLLMCore --> LLMCoreActionDispatcher : 分发动作
YosugaLLMCore --> ChatMessage : 管理历史
UnifiedLLM --> ModelResponse : 返回响应
LLMCoreAnalysisManager --> LLMCoreAnalysisBase : 解析为
TokenManager --> ModelResponse : 接收usage
```
#### 时序图
```mermaid
sequenceDiagram
participant Client as 客户端
participant Core as YosugaLLMCore
participant Token as TokenManager
participant Prompt as PromptManager
participant LLM as UnifiedLLM
participant Parser as AnalysisManager
participant Dispatcher as ActionDispatcher
Client->>Core: interact(user_input)
Note over Core: 输入预处理
Core->>Token: count_messages_tokens(_history)
Token-->>Core: current_usage
Core->>Core: _maintain_context_limit()
Note over Core: 检查溢出并清理
Core->>Prompt: get_system_prompt()
Note over Prompt: 聚合InputInfo/OutputInfo
Prompt-->>Core: system_prompt
Core->>Core: _build_request_messages()
Note over Core: 组装[system, history, user]
Core->>Token: estimate_chat_tokens()
Token-->>Core: estimated_usage
Core->>LLM: chat(messages)
Note over LLM: 调用底层API
LLM-->>Core: ModelResponse(content, usage)
Core->>Token: record_api_usage(usage)
Note over Token: 优先使用API数据
Core->>Parser: parse(content)
Note over Parser: JSON清洗+类型校验
Parser-->>Core: List[AnalysisObj]
Core->>Core: _add_to_history(user+assistant)
Note over Core: 更新对话记忆
Core->>Dispatcher: execute(parsed_results)
par 分发处理
Dispatcher->>Handler1: 同步处理(audio_text)
Dispatcher->>Handler2: 异步处理(auto_agent)
end
Dispatcher-->>Core: {"success": [], "failed": []}
Core-->>Client: 执行结果
```
#### 配置与模型热重载状态机
```mermaid
stateDiagram-v2
[*] --> 初始化: YosugaLLMCore()
state 初始化 {
[*] --> 加载配置: ModelConfig
加载配置 --> 创建LLM客户端: UnifiedLLM
创建LLM客户端 --> 注册默认Prompt: Audio/UITARS
注册默认Prompt --> 初始化Token管理器: TokenManager
}
初始化 --> 待机: 等待输入
state 待机 {
[*] --> 构建System Prompt
构建System Prompt --> 检查上下文限制: _maintain_context_limit()
检查上下文限制 --> 上下文溢出: current > limit
检查上下文限制 --> 正常: 否则
上下文溢出 --> 触发回调: _trigger_overflow_callbacks()
触发回调 --> 清理历史: 保留50%
清理历史 --> 正常
正常 --> 组装消息链: _build_request_messages()
}
待机 --> 调用LLM: chat()
调用LLM --> 解析输出: AnalysisManager.parse()
解析输出 --> 更新历史: _add_to_history()
更新历史 --> 分发动作: Dispatcher.execute()
分发动作 --> 返回结果: interact() return
返回结果 --> 待机
待机 --> 热重载: reload_model()
state 热重载 {
[*] --> 更新模型配置: update_config()
更新模型配置 --> 重建LLM客户端: _create_client()
重建LLM客户端 --> 重建Token管理器: TokenManager(model_name)
重建Token管理器 --> [*]
}
热重载 --> 待机
note right of 热重载 : "保留历史记忆\n不影响上下文"
```
+8
View File
@@ -0,0 +1,8 @@
### 本模块为Yosuga_server的AI记忆模块
##### 提供的功能有:
1. 向量化的记忆存储
2. 十分便捷的记忆管理与查询接口
3. 高效的记忆检索
+3
View File
@@ -0,0 +1,3 @@
### 本模块为Yosuga_server的核心业务模块
####
+4
View File
@@ -0,0 +1,4 @@
### 本模块为Yosuga_server的前端部分
为了方便管理服务端