commit 0753da86a886d66b9801b6d1271c1821df6f1f8e
Author: Misakiotoha <1841738040@qq.com>
Date: Tue Feb 3 01:20:00 2026 +0800
first pull
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ec46e94
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,31 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# from ide
+.vscode
+.idea
+
+# Virtual environments
+.venv
+venv
+
+# config
+settings.json
+
+# logs
+log/*
+
+# temp files
+temp_files/*
+
+# ai models
+src/modules/asr_module/asr_models/*
+
+# uv
+uv.lock
+
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..2c07333
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.11
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c872b40
--- /dev/null
+++ b/README.md
@@ -0,0 +1,85 @@
+### Yosuga_server
+
+## 📊 Project Stats
+
+
+
+
+
+欢迎访问本项目。
+
+首先向你介绍一下Yosuga这个项目:
+
+本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。[call me "Misaki" でいいよ]
+
+之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。
+
+本项目分为三个部分:
+1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。
+2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。
+3. Yosuga_embedded:这是项目的拓展部分,使得Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。
+
+**_本项目为Yosuga_server._**
+
+本项目使用uv构建,基于python3.11.
+本项目由YosugaServer发展而来,项目架构与代码有了相当大的改变。(YosugaServer并未开源,它仅仅是一次小小的尝试)
+
+
+### 如何快速启动本项目?
+1. 确保uv已安装,并添加到环境变量中
+2. 依次执行 `cd Yosuga_server` 和 `uv sync`
+3. 接着,如果你的电脑带有cuda,那么执行 `uv pip install -r requirements-cuda.txt`
+4. 如果没有cuda,那么执行 `uv pip install -r requirements-cpu.txt`
+5. 最后执行 `uv run python main.py` 即可启动项目
+
+首次启动项目后,会在项目根目录下生成settings.json配置文件,你需要配置一些必要的字段信息:
+```json
+{
+ "ai": {
+ "api_key": "sk-xxxxx",
+ "base_url": "http://localhost:1234/v1",
+ "model_name": "qwen/qwen3-4b-2507"
+ },
+ "tts": {
+ "gpt_model_name": "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt",
+ "sovits_model_name": "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth",
+ "host": "localhost",
+ "port": 20261,
+ "reference_audio": "./using/reference.wav"
+ },
+ "asr": {
+ "url": "http://localhost:20260/"
+ },
+ "auto_agent": {
+ "deployment_type": "lmstudio",
+ "model_name": "ui-tars-1.5-7b@q4_k_m",
+ "base_url": "http://localhost:1234/v1"
+ },
+ "llm_core": {
+ "role_character": "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。",
+ "max_context_tokens": 2048,
+ "language": "日本语"
+ }
+}
+```
+上面这些字段的信息,你需要根据你的实际情况进行配置。实际的配置文件的字段名称会比上面的多出不少。
+
+
+配置完成后,再次重启服务端就可以使用啦~
+
+接着是每个模型的配置相关:
+1. asr模型:本项目使用faster-whisper作为asr模型,并且附带了一键启动的部分。
+你需要找到 `Yosuga_server/src/modules/asr_module/start_api.py` 这个文件,然后启动它。
+一般来说,即使只有cpu也可以进行asr模型的推理,但是速度相比cuda要逊色很多。
+同时,如果启动时长时间停留在加载阶段,可以试着挂一下梯子,因为初次启动
+会从Hugging Face上下载模型。
+2. tts模型,本项目使用GPT-SoVITS作为tts模型,建议使用其V2Pro版本。
+3. auto_agent模型:本项目使用的自动化操作识别模型为字节跳动开源的
+`ui-tars-1.5-7b@q4_k_m`。关于此模型的更多信息可以参考字节跳动的[开源链接](https://github.com/bytedance/UI-TARS)。
+建议在LM Studio上进行部署,该模型十分轻量。
+4. ai模型:该模型必须是大语言模型,除此之外没有其他限制,本项目支持市面上的所有大语言模型。
+
+
+本项目当前并不完善,还有很多需要优化的地方,并且尚未接入Yosuga_embedded。
+
+欢迎大家为本项目贡献代码。
\ No newline at end of file
diff --git a/Test/GPTSoVITSTest.py b/Test/GPTSoVITSTest.py
new file mode 100644
index 0000000..c80fd18
--- /dev/null
+++ b/Test/GPTSoVITSTest.py
@@ -0,0 +1,173 @@
+import asyncio
+from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode
+from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer
+import sounddevice as sd
+
+test_text = "春の午後、公園のベンチに座って本を読んでいると、小さな子供が凧あげをしているのが目に入った。風に乗って凧が高く上がるたびに、彼の顔には真っすぐな笑顔が広がる。母親がそばで見守りながら、時折声をかけている。\
+空は雲一つない青さで、桜の花びらが風に舞っている。遠くで犬の鳴き声が聞こえ、芝生の上では老夫婦がお茶を楽しんでいた。すべてがゆっくりと流れる時間の中で、自分の心も不思議と落ち着いてくる。\
+ふと、子供の凧が木の枝に引っかかってしまった。少し焦る様子だったが、母親が助けてくれて、すぐにまた空に舞い上がった。失敗しても、誰かが助けてくれる。そんな当たり前のことに、今日は特別な温かさを感じた。\
+日が傾き始める頃、私は本を閉じて家路についた。明日もきっと、誰かの笑顔があるだろう。"
+
+# Connection settings of the GPT-SoVITS API server used by the manual tests below.
+TTS_HOST = "192.168.1.8"
+TTS_PORT = 20261
+# Reference audio (a path on the TTS server) and the prompt text sent with it.
+REF_AUDIO_PATH = "uploaded_audio/test_voice.wav"
+PROMPT_TEXT = "もう!こんなところで何やってるんだよ!"
+# Directory that receives synthesized audio files.
+OUTPUT_DIR = "outputs"
+
+
+def _make_client() -> GPTSoVITSClient:
+    """Create a debug-enabled client pointed at the shared test server."""
+    return GPTSoVITSClient(debug=True, port=TTS_PORT, host=TTS_HOST)
+
+
+async def test_tts():
+    """Synthesize one utterance in a single request and save it as a WAV file."""
+    # The client is used as an async context manager so it is closed on exit.
+    async with _make_client() as client:
+        try:
+            audio = await client.tts(
+                text="あのさ、いやまあ、なんていうか...要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?",
+                ref_audio_path=REF_AUDIO_PATH,
+                text_lang="ja",
+                prompt_lang="ja",
+                media_type="wav",
+                prompt_text=PROMPT_TEXT,
+            )
+
+            audio.save(f"{OUTPUT_DIR}/output.wav")
+            print(f"✅ TTS成功!音频大小: {len(audio.audio_data)} bytes")
+
+        except Exception as e:
+            # Manual test script: report the failure instead of crashing.
+            print(f"❌ 错误: {e}")
+
+
+async def test_model_change():
+    """Switch the server to the project's GPT and SoVITS weights."""
+    async with _make_client() as client:
+        print("🔄 切换GPT模型...")
+        await client.set_gpt_weights(
+            "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt"
+        )
+
+        print("🔄 切换SoVITS模型...")
+        await client.set_sovits_weights(
+            "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth"
+        )
+
+
+async def stream_tts():
+    """Request streamed synthesis and print the size of every received chunk."""
+    async with _make_client() as client:
+        try:
+            chunk_count = 0
+            # With a streaming mode set, ``client.tts`` is awaited to obtain an
+            # async iterator of audio chunks.
+            async for chunk in await client.tts(
+                text="要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?",
+                ref_audio_path=REF_AUDIO_PATH,
+                text_lang="ja",
+                prompt_lang="ja",
+                prompt_text=PROMPT_TEXT,
+                streaming_mode=StreamingMode.FASTEST,
+                media_type="wav",
+            ):
+                chunk_count += 1
+                print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes")
+
+            print(f"✅ 流式TTS完成!共{chunk_count}个音频块")
+
+        except Exception as e:
+            print(f"❌ 流式错误: {e}")
+
+
+async def stream_tts_and_play(
+    text: str,
+    ref_audio_path: str,
+    text_lang: str = "zh",
+    prompt_lang: str = "zh",
+    streaming_mode: StreamingMode = StreamingMode.FASTEST,
+    prompt_text: str = PROMPT_TEXT,
+):
+    """
+    Stream synthesized audio from the TTS server and play it as it arrives.
+
+    Args:
+        text: Text to synthesize.
+        ref_audio_path: Reference audio path passed to the TTS server.
+        text_lang: Language code of ``text``.
+        prompt_lang: Language code of the prompt.
+        streaming_mode: Streaming mode requested from the server.
+        prompt_text: Prompt text sent together with the reference audio.
+
+    Raises:
+        Exception: Any error from synthesis or playback is printed and re-raised.
+    """
+    # Player created with buffer_size=5, as in the original script.
+    async with AsyncAudioPlayer(buffer_size=5) as player:
+        async with _make_client() as client:
+            try:
+                print(f"🎤 开始流式合成: {text[:30]}...")
+                print(f"🎯 流式模式: {streaming_mode.name}")
+
+                audio_stream = await client.tts(
+                    text=text,
+                    ref_audio_path=ref_audio_path,
+                    text_lang=text_lang,
+                    prompt_lang=prompt_lang,
+                    prompt_text=prompt_text,
+                    streaming_mode=streaming_mode,
+                    media_type="wav",
+                    sample_steps=32,
+                    top_k=5,
+                    temperature=1.0,
+                )
+
+                chunk_idx = 0
+                async for audio_chunk in audio_stream:
+                    chunk_idx += 1
+                    print(f"📥 收到音频块 #{chunk_idx}: {len(audio_chunk.audio_data):6d} bytes")
+
+                    # Hand the chunk to the player's queue.
+                    await player.add_chunk(audio_chunk.audio_data)
+
+                print(f"✅ 合成完成! 共接收 {chunk_idx} 个音频块")
+
+                # Wait until every queued chunk has been consumed by the player.
+                await player.audio_queue.join()
+                print("🎵 播放完成!")
+
+            except Exception as e:
+                print(f"❌ 错误: {e}")
+                raise
+
+
+async def test_japanese():
+    """Stream and play the long Japanese sample text."""
+    print("=" * 50)
+    print("🗾 日语流式TTS测试")
+    print("=" * 50)
+
+    await stream_tts_and_play(
+        text=test_text,
+        ref_audio_path=REF_AUDIO_PATH,
+        text_lang="ja",
+        prompt_lang="ja",
+        streaming_mode=StreamingMode.FASTEST,
+    )
+
+
+async def batch_test():
+    """Synthesize several Chinese sentences in one batch and save each result."""
+    # Uses the client's default host/port, as in the original script.
+    async with GPTSoVITSClient() as client:
+        texts = [
+            "你好,世界!",
+            "这是一个批量测试。",
+            "异步批量处理非常高效。",
+        ]
+
+        results = await client.batch_tts(
+            texts=texts,
+            ref_audio_path="archive_jingyuan_1.wav",
+            text_lang="zh",
+        )
+
+        for i, audio in enumerate(results):
+            # Same directory as test_tts; the original wrote to "output/",
+            # which is not the directory used elsewhere in this script.
+            audio.save(f"{OUTPUT_DIR}/batch_{i}.wav")
+            print(f"✅ 批量任务 {i + 1}/{len(results)} 完成")
+
+
+
+
+
+if __name__ == "__main__":
+ # 检查音频设备
+ print("🔍 检查音频设备...")
+ print(sd.query_devices())
+ sd.default.device = (None, "pulse") # 使用PulseAudio
+
+ asyncio.run(test_japanese())
\ No newline at end of file
diff --git a/Test/WebsocketTestClient.py b/Test/WebsocketTestClient.py
new file mode 100644
index 0000000..3b730fb
--- /dev/null
+++ b/Test/WebsocketTestClient.py
@@ -0,0 +1,24 @@
+import asyncio
+import json
+from websockets.asyncio.client import connect
+
+async def test_all_types():
+    """Send one JSON, one text and one binary message and print each reply."""
+    json_payload = json.dumps({
+        "type": "chat",
+        "content": "你好服务器!"
+    })
+    # (banner printed before sending, payload) in the order they are exercised.
+    cases = (
+        ("=== 测试JSON消息 ===", json_payload),
+        ("\n=== 测试文本消息 ===", "这是纯文本消息"),
+        ("\n=== 测试二进制消息 ===", b"\x00\x01\x02\x03\x04"),
+    )
+
+    async with connect("ws://localhost:8765") as ws:
+        for banner, payload in cases:
+            print(banner)
+            await ws.send(payload)
+            reply = await ws.recv()
+            print(f"收到: {reply}")
+
+if __name__ == "__main__":
+ asyncio.run(test_all_types())
\ No newline at end of file
diff --git a/Test/WebsocketTestServer.py b/Test/WebsocketTestServer.py
new file mode 100644
index 0000000..786deb6
--- /dev/null
+++ b/Test/WebsocketTestServer.py
@@ -0,0 +1,174 @@
+"""
+极简 WebSocket 测试服务器 - 修复版本
+"""
+import asyncio
+import json
+import logging
+from datetime import datetime
+from typing import Set
+
+import websockets
+
+# 配置日志
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(message)s'
+)
+
+class SimpleWebSocketServer:
+    """
+    Minimal WebSocket test server that answers every message with a JSON reply.
+
+    JSON objects are dispatched on their ``type`` field (``ping``, ``login``,
+    ``chat``; anything else is echoed). Text that is not a JSON object gets a
+    ``text_echo`` reply and binary frames that are not a JSON object get a
+    ``binary_echo`` reply.
+    """
+
+    def __init__(self, host="localhost", port=8765):
+        self.host = host
+        self.port = port
+        # Currently connected client connections.
+        self.clients: Set = set()
+
+    @staticmethod
+    def _timestamp_ms() -> int:
+        """Return the current local time as integer milliseconds since the epoch."""
+        return int(datetime.now().timestamp() * 1000)
+
+    async def handle_connection(self, websocket, path=None):
+        """
+        Serve one client connection until it closes.
+
+        Args:
+            websocket: The connection object supplied by ``websockets``.
+            path: Request path. Optional because the current asyncio server of
+                ``websockets`` calls handlers with the connection only, while
+                the legacy server also passed the path.
+        """
+        remote = websocket.remote_address
+        client_id = f"{remote[0]}:{remote[1]}"
+        self.clients.add(websocket)
+        logging.info(f"✅ 客户端连接: {client_id} (当前连接数: {len(self.clients)})")
+
+        try:
+            # Greet the client before entering the receive loop.
+            welcome = {
+                "type": "connect",
+                "data": {
+                    "message": "WebSocket 服务器连接成功",
+                    "client_id": client_id,
+                    "server_time": datetime.now().isoformat(),
+                    "status": "connected"
+                },
+                "timestamp": self._timestamp_ms()
+            }
+            await websocket.send(json.dumps(welcome))
+
+            async for message in websocket:
+                await self.handle_message(websocket, client_id, message)
+
+        except websockets.exceptions.ConnectionClosed:
+            logging.info(f"❌ 客户端断开: {client_id}")
+        finally:
+            self.clients.discard(websocket)
+            logging.info(f"📊 剩余连接: {len(self.clients)}")
+
+    def _build_json_response(self, data: dict) -> dict:
+        """Build the reply for a decoded JSON object, dispatching on ``type``."""
+        msg_type = data.get("type", "unknown")
+        payload = data.get("data", {})
+        # login/chat read fields from the payload; tolerate a non-object payload.
+        fields = payload if isinstance(payload, dict) else {}
+
+        if msg_type == "ping":
+            body_type = "pong"
+            body = {
+                "server_time": datetime.now().isoformat(),
+                "latency": "0ms"
+            }
+        elif msg_type == "login":
+            username = fields.get("username", "anonymous")
+            body_type = "login_success"
+            body = {
+                "user_id": f"user_{abs(hash(username)) % 10000}",
+                "username": username,
+                "status": "authenticated"
+            }
+        elif msg_type == "chat":
+            msg_content = fields.get("message", "")
+            body_type = "chat_response"
+            body = {
+                "message": f"服务器收到: {msg_content}",
+                "sender": "server",
+                "received_at": datetime.now().isoformat()
+            }
+        else:
+            # Unknown type: echo the payload back unchanged.
+            body_type = "echo"
+            body = {
+                "original": payload,
+                "original_type": msg_type,
+                "server_processed_at": datetime.now().isoformat()
+            }
+
+        return {"type": body_type, "data": body, "timestamp": self._timestamp_ms()}
+
+    async def handle_message(self, websocket, client_id, message):
+        """
+        Reply to a single incoming message.
+
+        Args:
+            websocket: Connection to send the reply on.
+            client_id: Identifier used in the log line.
+            message: Received frame, ``str`` for text or ``bytes`` for binary.
+        """
+        logging.info(f"📨 收到消息 from {client_id}: {message}")
+
+        is_binary = isinstance(message, (bytes, bytearray))
+        try:
+            data = json.loads(message)
+        except ValueError:
+            # JSONDecodeError for malformed text, UnicodeDecodeError for binary
+            # frames that are not valid text; both are ValueError subclasses.
+            data = None
+
+        if isinstance(data, dict):
+            response = self._build_json_response(data)
+        elif is_binary:
+            # bytes cannot be JSON-serialised, so report size and hex dump.
+            raw = bytes(message)
+            response = {
+                "type": "binary_echo",
+                "data": {
+                    "size": len(raw),
+                    "hex": raw.hex()
+                },
+                "timestamp": self._timestamp_ms()
+            }
+        else:
+            # Plain text, or JSON that is not an object (number, string, list).
+            response = {
+                "type": "text_echo",
+                "data": {
+                    "original": message,
+                    "note": "这是文本消息"
+                },
+                "timestamp": self._timestamp_ms()
+            }
+
+        await websocket.send(json.dumps(response))
+
+    async def start(self):
+        """Start listening and run until the task is cancelled."""
+        logging.info(f"🚀 启动 WebSocket 服务器: ws://{self.host}:{self.port}")
+
+        # handle_connection accepts the connection with or without a path, so
+        # the bound method can be registered directly.
+        server = await websockets.serve(
+            self.handle_connection,
+            self.host,
+            self.port,
+            ping_interval=None,
+            ping_timeout=None,
+            close_timeout=None,
+            max_size=10 * 1024 * 1024
+        )
+
+        logging.info("📌 服务器已启动,等待连接...")
+        logging.info("🛑 按 Ctrl+C 停止服务器")
+
+        try:
+            # A future that is never resolved keeps the coroutine alive.
+            await asyncio.Future()
+        finally:
+            server.close()
+            await server.wait_closed()
+            logging.info("👋 服务器已关闭")
+
+def main():
+    """Parse ``--host``/``--port`` and run the test server until interrupted."""
+    import argparse
+
+    cli = argparse.ArgumentParser(description='极简 WebSocket 测试服务器')
+    cli.add_argument('--host', default='localhost', help='监听地址')
+    cli.add_argument('--port', type=int, default=8088, help='监听端口')
+    options = cli.parse_args()
+
+    server = SimpleWebSocketServer(options.host, options.port)
+
+    # Ctrl+C surfaces here as KeyboardInterrupt; log it instead of a traceback.
+    try:
+        asyncio.run(server.start())
+    except KeyboardInterrupt:
+        logging.info("👋 服务器被用户中断")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/Test/asrRequestTest.py b/Test/asrRequestTest.py
new file mode 100644
index 0000000..4e7f295
--- /dev/null
+++ b/Test/asrRequestTest.py
@@ -0,0 +1,33 @@
+# requestTest.py
+import requests
+from pathlib import Path
+
+# 指定正确的 MIME 类型
+# Endpoint of the ASR service and the sample recording to upload.
+url = "http://192.168.1.8:20260/transcribe"
+audio_path = Path("test_files/z105300938.wav")
+# (connect, read) timeouts in seconds. Without a timeout, requests waits
+# indefinitely when the server is unreachable or stalls.
+REQUEST_TIMEOUT = (10, 300)
+
+with open(audio_path, "rb") as f:
+    # Send the file name and MIME type explicitly with the upload.
+    files = {
+        "file": (
+            audio_path.name,  # file name
+            f,                # file object
+            "audio/wav"       # MIME type
+        )
+    }
+
+    response = requests.post(url, files=files, timeout=REQUEST_TIMEOUT)
+
+# Response details
+print(f"状态码: {response.status_code}")
+print(f"响应头: {response.headers.get('content-type')}")
+
+# Only a 200 response is parsed as a transcription result.
+if response.status_code == 200:
+    result = response.json()
+    print(f"识别结果: {result['data']['text']}")
+    print(f"语言: {result['data']['language']}")
+    print(f"置信度: {result['data']['confidence']}")
+    print(f"处理时间: {result['data']['processing_time']}s")
+else:
+    print(f"错误响应: {response.text}")
\ No newline at end of file
diff --git a/Test/dtosAndTTSAndASR.py b/Test/dtosAndTTSAndASR.py
new file mode 100644
index 0000000..feff988
--- /dev/null
+++ b/Test/dtosAndTTSAndASR.py
@@ -0,0 +1,14 @@
+# 一个小Test, 展示设计的dtos模块与tts和asr的集成
+from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO
+from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer
+from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode
+from src.modules.asr_module.client.asr_client import create_asr_client
+
+
+# with create_asr_client(base_url="http://192.168.1.5:20260") as client:
+# # 转录文件
+# result = client.transcribe_file("test_files/test.wav")
+# print(f"识别结果: {result.data.text}")
+# print(f"置信度: {result.data.confidence:.2f}")
+# print(f"耗时: {result.data.processing_time:.3f}s")
+
diff --git a/Test/dtosTest.py b/Test/dtosTest.py
new file mode 100644
index 0000000..d463f9e
--- /dev/null
+++ b/Test/dtosTest.py
@@ -0,0 +1,30 @@
+from src.modules.websocket_base_module.dto.second_dtos import get_json_dto_instance
+from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO
+from src.modules.websocket_base_module.websocket_core.core_ws_server import get_ws_server
+import asyncio
+from loguru import logger
+async def main():
+    """Register the audio DTO receiver and run the WebSocket server on localhost:8765."""
+    ws_server = await get_ws_server()
+    json_dto = await get_json_dto_instance(ws_server)
+
+    # Constructing the DTO registers its receive function (per the original
+    # comment); the reference is kept for the lifetime of this coroutine.
+    audio_dto = AudioDataDTO(json_dto)
+
+    logger.info("所有DTO接收器已注册,等待客户端连接...")
+
+    # Runs until the server stops or the task is cancelled.
+    try:
+        await ws_server.run("localhost", 8765)
+    except asyncio.CancelledError:
+        logger.info("服务器任务已取消,正在优雅退出...")
+    finally:
+        logger.info("服务器已停止")
+
+
+if __name__ == "__main__":
+ try:
+ asyncio.run(main())
+ except KeyboardInterrupt:
+ print("\n✓ 服务器已手动终止(按 Ctrl+C)")
\ No newline at end of file
diff --git a/Test/outputs/output.wav b/Test/outputs/output.wav
new file mode 100644
index 0000000..f9d7828
Binary files /dev/null and b/Test/outputs/output.wav differ
diff --git a/Test/outputs/output1.wav b/Test/outputs/output1.wav
new file mode 100644
index 0000000..045e77b
Binary files /dev/null and b/Test/outputs/output1.wav differ
diff --git a/Test/test_files/Screenshot_test.png b/Test/test_files/Screenshot_test.png
new file mode 100644
index 0000000..9da1f4e
Binary files /dev/null and b/Test/test_files/Screenshot_test.png differ
diff --git a/Test/test_files/okoru.wav b/Test/test_files/okoru.wav
new file mode 100644
index 0000000..7b540d9
Binary files /dev/null and b/Test/test_files/okoru.wav differ
diff --git a/Test/test_files/sad.wav b/Test/test_files/sad.wav
new file mode 100644
index 0000000..7c36d27
Binary files /dev/null and b/Test/test_files/sad.wav differ
diff --git a/Test/test_files/test.wav b/Test/test_files/test.wav
new file mode 100644
index 0000000..7aaf00c
Binary files /dev/null and b/Test/test_files/test.wav differ
diff --git a/Test/test_files/z105300938.wav b/Test/test_files/z105300938.wav
new file mode 100644
index 0000000..1afc693
Binary files /dev/null and b/Test/test_files/z105300938.wav differ
diff --git a/Test/textAITest.py b/Test/textAITest.py
new file mode 100644
index 0000000..1f984a1
--- /dev/null
+++ b/Test/textAITest.py
@@ -0,0 +1,88 @@
+from src.modules.text_ai_module.text_ai_core.general_text_ai_req import UnifiedLLM, ModelConfig, ModelProvider, create_llm_client
+from src.config.config import get_settings
+from src.config.convert_env import EnvConverter
+from src.config.file_config import DirectoryInitializer
+
+EnvConverter().convert(backup_existing=True) # 若是首次启动则从env模板中生成env文件
+DirectoryInitializer(get_settings()) # 初始化必要的目录(若不存在则创建)
+
+def test1():
+    """Send one non-streaming chat request using the model from the settings."""
+    # Model name, endpoint and API key all come from the project settings.
+    config = ModelConfig(
+        provider=ModelProvider.OPENAI,
+        model_name=get_settings().ai_model_name,
+        base_url=get_settings().ai_api_base_url,
+        api_key=get_settings().ai_api_key,
+        temperature=0.7,
+        max_tokens=2048
+    )
+    llm = UnifiedLLM(config)
+
+    conversation = [
+        {"role": "system", "content": "你是一个DeepSeek助手"},
+        {"role": "user", "content": "请介绍一下DeepSeek模型的特点"}
+    ]
+    reply = llm.chat(conversation)
+
+    print(reply.content)
+
+def base_test2():
+    """Stream a chat completion through the ``create_llm_client`` shortcut."""
+    # The "openai" provider is used for the OpenAI-compatible endpoint.
+    client_options = {
+        "provider": "openai",
+        "model_name": get_settings().ai_model_name,
+        "api_key": get_settings().ai_api_key,
+        "base_url": get_settings().ai_api_base_url,
+    }
+    deepseek_llm = create_llm_client(**client_options)
+
+    messages = [
+        {"role": "user", "content": "用Python写一个快速排序算法"}
+    ]
+
+    print("正在生成响应...")
+    # Print each chunk as soon as it arrives, without a trailing newline.
+    for piece in deepseek_llm.stream_chat(messages):
+        print(piece.content, end="", flush=True)
+
+
+def test_lm_studio():
+    """Call a local LM Studio model once without and once with streaming."""
+    print("=== 测试本地 LM Studio ===")
+
+    llm = UnifiedLLM(ModelConfig(
+        provider=ModelProvider.LM_STUDIO,
+        model_name="qwen/qwen3-4b-2507",
+        base_url="http://192.168.1.8:1234/v1",
+        api_key="",  # left empty: LM Studio needs no API key
+        temperature=0.7,
+        max_tokens=1024,
+        streaming=False  # streaming is off by default for this client
+    ))
+
+    messages = [
+        {"role": "system", "content": "你是一个有用的助手"},
+        {"role": "user", "content": "用中文介绍一下自己"}
+    ]
+
+    print("非流式响应:")
+    full_reply = llm.chat(messages, streaming=False)
+    print(full_reply.content)
+
+    print("\n流式响应:")
+    for piece in llm.stream_chat(messages):
+        print(piece.content, end="", flush=True)
+
+if __name__ == "__main__":
+ test_lm_studio()
\ No newline at end of file
diff --git a/Test/ui_tars_test.py b/Test/ui_tars_test.py
new file mode 100644
index 0000000..1fa98e1
--- /dev/null
+++ b/Test/ui_tars_test.py
@@ -0,0 +1,66 @@
+import asyncio
+import base64
+from pathlib import Path
+from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig
+
+
+async def test_ui_tars_stream():
+    """
+    Send a screenshot and an instruction to UI-TARS and print the answer.
+
+    Despite its name, this test currently uses the non-streaming
+    ``client.call_async``; the earlier streaming loop was disabled.
+    """
+    config = UITarsClientConfig(
+        deployment_type="lmstudio",
+        base_url="http://192.168.1.8:1234/v1",
+        model_name="ui-tars-1.5-7b@q4_k_m",
+        temperature=0.1
+    )
+    client = UITarsClient(config)
+
+    # The screenshot is passed to the model as a base64 string.
+    image_base64 = base64.b64encode(Path("test_files/Screenshot_test.png").read_bytes()).decode()
+    print(f"✅ 图片编码完成,长度: {len(image_base64)} 字符\n")
+
+    print("🤖 开始调用 UI-TARS...\n")
+
+    import time
+    # perf_counter is monotonic, so the elapsed time cannot go negative when
+    # the system clock is adjusted during the call.
+    start_time = time.perf_counter()
+    full_response = await client.call_async("打开AK加速器", image_base64)
+    elapsed = time.perf_counter() - start_time
+
+    print(f"\n\n耗时: {elapsed:.2f} 秒")
+    print(f"\n\n{'=' * 50}")
+    # No chunk count is reported: nothing is streamed in this code path.
+    print("✅ 调用完成!")
+    print(f"完整响应长度: {len(full_response)} 字符")
+
+    print("响应内容:\n")
+    print(full_response)
+
+import pyautogui
+def auto_click(x : int, y : int, duration: float = 1.5):
+    """
+    Move the mouse pointer to ``(x, y)`` and click there.
+
+    Args:
+        x: Target x coordinate passed to ``pyautogui.moveTo``.
+        y: Target y coordinate passed to ``pyautogui.moveTo``.
+        duration: Seconds the pointer movement takes (default 1.5).
+    """
+    pyautogui.moveTo(x, y, duration=duration)
+    pyautogui.click()
+
+def auto_drag(x1 : int, y1 : int, x2 : int, y2 : int, duration: float = 1.5):
+    """
+    Move the pointer to ``(x1, y1)`` and drag from there to ``(x2, y2)``.
+
+    Args:
+        x1: Start x coordinate.
+        y1: Start y coordinate.
+        x2: End x coordinate.
+        y2: End y coordinate.
+        duration: Seconds used for the move and again for the drag (default 1.5).
+    """
+    pyautogui.moveTo(x1, y1, duration=duration)
+    pyautogui.dragTo(x2, y2, duration=duration)
+
+# 运行异步函数
+if __name__ == "__main__":
+ asyncio.run(test_ui_tars_stream())
+ auto_click(173,48)
+ # auto_drag(56,39, 170,39)
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..0cccc0c
--- /dev/null
+++ b/main.py
@@ -0,0 +1,45 @@
+import asyncio
+from contextlib import suppress
+
+from loguru import logger
+from src.config.config import cfg
+from datetime import datetime
+from src.server_core.core import YosugaServerCore
+
+def init():
+    """
+    Initialise Yosuga_server logging.
+
+    Adds a log file named after the start time inside the configured log
+    directory and writes the start-up messages.
+    """
+    log_dir = cfg.log_dir
+    started_at = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    log_file = f"{log_dir}/Yosuga_server-{started_at}.log"
+
+    logger.add(log_file, encoding="utf-8")
+    logger.info("Yosuga_server 启动")
+    logger.info(f"日志文件目录见: {cfg.log_dir} 目录")
+
+
+async def main():
+    """
+    Run the server core, then cancel every task still pending on the loop.
+
+    Cancellation of ``core.run()`` is treated as a normal shutdown and is not
+    propagated.
+    """
+    core = await YosugaServerCore.get_instance()
+    try:
+        await core.run()
+    except asyncio.CancelledError:
+        pass  # normal cancellation: do not print a traceback
+    finally:
+        # Cancel whatever is still scheduled besides this coroutine so the
+        # loop can close without pending tasks.
+        current = asyncio.current_task()
+        pending = [t for t in asyncio.all_tasks() if t is not current]
+        for task in pending:
+            task.cancel()
+        if pending:
+            # return_exceptions=True collects each task's CancelledError
+            # instead of raising it here.
+            with suppress(asyncio.CancelledError):
+                await asyncio.gather(*pending, return_exceptions=True)
+
+
+if __name__ == "__main__":
+ init()
+ try:
+ asyncio.run(main())
+ except KeyboardInterrupt:
+ print("\nYosuga服务器已停止喵~~~")
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..9f1eb20
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,24 @@
+[project]
+name = "yosuga-server"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+ "aiofiles>=25.1.0",
+ "aiohttp>=3.13.3",
+ "fastapi>=0.128.0",
+ "faster-whisper>=1.2.1",
+ "loguru>=0.7.3",
+ "openai>=2.16.0",
+ "pyautogui>=0.9.54",
+ "pydantic>=2.12.5",
+ "pydantic-settings>=2.12.0",
+ "python-multipart>=0.0.22",
+ "requests>=2.32.5",
+ "sounddevice>=0.5.5",
+ "soundfile>=0.13.1",
+ "tiktoken>=0.12.0",
+ "uvicorn>=0.40.0",
+ "websockets>=16.0",
+]
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
new file mode 100644
index 0000000..488b1aa
--- /dev/null
+++ b/requirements-cpu.txt
@@ -0,0 +1,84 @@
+--index-url https://download.pytorch.org/whl/cpu
+aiofiles==25.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.3
+aiosignal==1.4.0
+annotated-doc==0.0.4
+annotated-types==0.7.0
+anyio==4.12.1
+attrs==25.4.0
+av==16.1.0
+certifi==2026.1.4
+cffi==2.0.0
+charset-normalizer==3.4.4
+click==8.3.1
+colorama==0.4.6
+coloredlogs==15.0.1
+ctranslate2==4.6.3
+distro==1.9.0
+fastapi==0.128.0
+faster-whisper==1.2.1
+filelock==3.20.3
+flatbuffers==25.12.19
+frozenlist==1.8.0
+fsspec==2026.1.0
+h11==0.16.0
+hf-xet==1.2.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==1.3.7
+humanfriendly==10.0
+idna==3.11
+jinja2==3.1.6
+jiter==0.13.0
+loguru==0.7.3
+markupsafe==2.1.5
+mouseinfo==0.1.3
+mpmath==1.3.0
+multidict==6.7.1
+networkx==3.6.1
+numpy==2.4.2
+onnxruntime==1.23.2
+openai==2.16.0
+packaging==26.0
+pillow==12.1.0
+propcache==0.4.1
+protobuf==6.33.5
+pyautogui==0.9.54
+pycparser==3.0
+pydantic==2.12.5
+pydantic-core==2.41.5
+pydantic-settings==2.12.0
+pygetwindow==0.0.9
+pymsgbox==2.0.1
+pyperclip==1.11.0
+pyreadline3==3.5.4
+pyrect==0.2.0
+pyscreeze==1.0.1
+python-dotenv==1.2.1
+python-multipart==0.0.22
+pytweening==1.2.0
+pyyaml==6.0.3
+regex==2026.1.15
+requests==2.32.5
+setuptools==80.10.2
+shellingham==1.5.4
+sniffio==1.3.1
+sounddevice==0.5.5
+soundfile==0.13.1
+starlette==0.50.0
+sympy==1.14.0
+tiktoken==0.12.0
+tokenizers==0.22.2
+torch==2.5.1+cpu
+torchaudio==2.5.1+cpu
+torchvision==0.20.1+cpu
+tqdm==4.67.2
+typer-slim==0.21.1
+typing-extensions==4.15.0
+typing-inspection==0.4.2
+urllib3==2.6.3
+uvicorn==0.40.0
+websockets==16.0
+win32-setctime==1.2.0
+yarl==1.22.0
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
new file mode 100644
index 0000000..f9ae722
Binary files /dev/null and b/requirements-cuda.txt differ
diff --git a/src/config/config.py b/src/config/config.py
new file mode 100644
index 0000000..071e3a7
--- /dev/null
+++ b/src/config/config.py
@@ -0,0 +1,452 @@
+"""
+零初始化 JSON 配置管理
+用法:from src.config import cfg # 直接访问,自动初始化
+"""
+import json
+import threading
+from pathlib import Path
+from typing import Any, Dict, Optional, TypeVar
+from dataclasses import dataclass, field, asdict
+
+def _project_root() -> Path:
+    """
+    Locate the project root by searching upwards for a marker file.
+
+    Returns:
+        The first directory, starting two levels above this file and moving
+        towards the filesystem root, that contains any marker. If none is
+        found, the starting directory itself.
+    """
+    markers = ['pyproject.toml', 'settings.json', '.gitignore', 'main.py', '.python-version']
+    # Two levels up from this file (the directory containing this file's package).
+    start = Path(__file__).resolve().parent.parent
+
+    for candidate in (start, *start.parents):
+        for marker in markers:
+            if (candidate / marker).exists():
+                return candidate
+        # The filesystem root is its own parent: nothing left to search.
+        if candidate.parent == candidate:
+            break
+    return start
+
+# 配置定义
+
+@dataclass
+class AIConfig:
+    """Settings of the ``ai`` group (the language-model endpoint)."""
+    # Placeholder key by default; replace it in settings.json.
+    api_key: Optional[str] = "sk-xxxxx"
+    base_url: str = "http://localhost:1234/v1"
+    model_name: str = "qwen/qwen3-4b-2507"
+    # NOTE(review): presumably seconds; confirm against the consumer.
+    timeout: int = 30
+    temperature: float = 0.4
+    max_tokens: int = 8192
+
+
+@dataclass
+class TTSConfig:
+    """Settings of the ``tts`` group (text-to-speech service)."""
+    enabled: bool = True
+    api_key: Optional[str] = None
+    # Weight files for the GPT and SoVITS models.
+    gpt_model_name: str = "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt"
+    sovits_model_name: str = "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth"
+    # Address of the TTS service.
+    host: str = "localhost"
+    port: int = 20261
+    reference_audio: str = "./using/reference.wav"
+    streaming: bool = True
+    speed: float = 1.0
+
+@dataclass
+class ASRConfig:
+    """Settings of the ``asr`` group (speech-recognition service)."""
+    enabled: bool = True
+    api_key: Optional[str] = None
+    model_name: str = "fast-whisper"
+    # Base URL of the ASR HTTP service.
+    url: str = "http://localhost:20260/"
+
+@dataclass
+class AutoAgentConfig:
+    """Settings of the ``auto_agent`` group (UI automation model)."""
+    enabled: bool = True
+    api_key: Optional[str] = None
+    deployment_type: str = "lmstudio"
+    model_name: str = "ui-tars-1.5-7b@q4_k_m"
+    base_url: str = "http://localhost:1234/v1"
+    temperature: float = 0.1
+    max_tokens: int = 16384
+
+@dataclass
+class LLMConfig:
+    """Settings of the ``llm_core`` group (persona and context handling)."""
+    enabled: bool = True
+    # Persona text for the assistant.
+    role_character: str = "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。"
+    max_context_tokens: int = 2048
+    enable_history: bool = True
+    language: str = "日本语"
+
+@dataclass
+class PathsConfig:
+    """Working directories of the ``paths`` group."""
+    # Relative values are resolved against the settings file's directory by
+    # AppConfig.__init__ when a config path is known.
+    temp: str = "./tmp/"
+    log: str = "./log/"
+    using: str = "./using/"
+
+
+class AppConfig:
+    """
+    Root application configuration, persisted as a JSON file.
+
+    Holds two scalar fields (``version``, ``debug``) plus one dataclass per
+    configuration group. To add a group: 1) define its dataclass above,
+    2) add the parameter and attribute in ``__init__``, 3) register it in
+    ``_section_classes``.
+    """
+
+    # Scalar top-level fields, in the order they are serialized.
+    _SCALAR_NAMES = ('version', 'debug')
+
+    def __init__(
+            self,
+            version: str = "1.0.0",
+            debug: bool = False,
+            ai: Optional[AIConfig] = None,
+            tts: Optional[TTSConfig] = None,
+            asr: Optional[ASRConfig] = None,
+            auto_agent: Optional[AutoAgentConfig] = None,
+            llm_core: Optional[LLMConfig] = None,
+            paths: Optional[PathsConfig] = None,
+            _config_path: Optional[Path] = None,
+            **kwargs
+    ):
+        """
+        Build a configuration, filling every omitted group with its defaults.
+
+        Args:
+            version: Configuration version string.
+            debug: Debug flag.
+            ai, tts, asr, auto_agent, llm_core, paths: Group objects; ``None``
+                selects the group's default instance.
+            _config_path: Location of the backing JSON file. When set,
+                relative entries of ``paths`` are resolved against its parent
+                directory.
+            **kwargs: Unrecognised top-level keys, e.g. from a settings file
+                written by another version. They are accepted and ignored:
+                every real field is a named parameter, so applying them could
+                only overwrite methods, properties or internal state.
+        """
+        self.version = version
+        self.debug = debug
+        self.ai = ai if ai is not None else AIConfig()
+        self.tts = tts if tts is not None else TTSConfig()
+        self.asr = asr if asr is not None else ASRConfig()
+        self.auto_agent = auto_agent if auto_agent is not None else AutoAgentConfig()
+        self.llm_core = llm_core if llm_core is not None else LLMConfig()
+        self.paths = paths if paths is not None else PathsConfig()
+
+        # Internal state; never serialized (see to_dict).
+        self._config_path = _config_path
+        self._lock = threading.RLock()
+
+        self._resolve_paths()
+
+    # Internal helpers
+
+    @staticmethod
+    def _section_classes() -> Dict[str, type]:
+        """Map each group name to its dataclass, in serialization order."""
+        return {
+            'ai': AIConfig,
+            'tts': TTSConfig,
+            'asr': ASRConfig,
+            'auto_agent': AutoAgentConfig,
+            'llm_core': LLMConfig,
+            'paths': PathsConfig,
+        }
+
+    @staticmethod
+    def _build_section(section_class: type, raw: dict) -> Any:
+        """
+        Instantiate ``section_class`` from ``raw``, dropping unknown keys.
+
+        Missing keys fall back to the dataclass defaults. Unknown keys are
+        ignored so a stale or hand-edited settings file does not raise
+        TypeError at start-up.
+        """
+        known = section_class.__dataclass_fields__
+        return section_class(**{k: v for k, v in raw.items() if k in known})
+
+    def _resolve_paths(self) -> None:
+        """Turn relative ``paths`` entries into absolute ones next to the config file."""
+        if not self._config_path or not isinstance(self.paths, PathsConfig):
+            return
+        root = self._config_path.parent
+        for field_name in ('temp', 'log', 'using'):
+            rel_path = getattr(self.paths, field_name)
+            if not Path(rel_path).is_absolute():
+                abs_path = (root / Path(rel_path)).resolve()
+                # Stored as a string with a trailing slash.
+                setattr(self.paths, field_name, str(abs_path) + '/')
+
+    # Convenience properties
+
+    @property
+    def temp_dir(self) -> Path:
+        return Path(self.paths.temp)
+
+    @property
+    def log_dir(self) -> Path:
+        return Path(self.paths.log)
+
+    @property
+    def using_dir(self) -> Path:
+        return Path(self.paths.using)
+
+    # Core methods
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """
+        Read a value by dotted path, e.g. ``cfg.get("ai.timeout")``.
+
+        Returns ``default`` when any path segment is missing.
+        """
+        try:
+            value = self
+            for k in key.split('.'):
+                value = getattr(value, k) if not isinstance(value, dict) else value[k]
+            return value
+        except (AttributeError, KeyError):
+            return default
+
+    def set(self, key: str, value: Any, save: bool = True) -> 'AppConfig':
+        """
+        Assign a value by dotted path; returns ``self`` so calls can be chained.
+
+        Example: ``cfg.set("ai.timeout", 60).set("debug", True)``.
+
+        Raises:
+            AttributeError: If an intermediate path segment does not exist.
+        """
+        with self._lock:
+            keys = key.split('.')
+            target = self
+            for k in keys[:-1]:
+                target = getattr(target, k)
+            setattr(target, keys[-1], value)
+
+        if save:
+            self._save()
+        return self
+
+    def update(self, updates: Dict[str, Any], save: bool = True) -> 'AppConfig':
+        """
+        Apply a nested dict of updates, e.g.
+        ``cfg.update({"ai": {"timeout": 60}, "debug": True})``.
+
+        Keys that do not match an existing attribute are skipped.
+        """
+        def deep_update(obj: Any, data: dict):
+            for k, v in data.items():
+                if hasattr(obj, k):
+                    current = getattr(obj, k)
+                    # Recurse into dataclass groups; replace everything else.
+                    if isinstance(v, dict) and hasattr(current, '__dataclass_fields__'):
+                        deep_update(current, v)
+                    else:
+                        setattr(obj, k, v)
+
+        with self._lock:
+            deep_update(self, updates)
+
+        if save:
+            self._save()
+        return self
+
+    def reload(self) -> 'AppConfig':
+        """
+        Re-read the backing file into this instance (hot reload).
+
+        Only the known scalar fields and groups are applied; other keys in the
+        file are ignored. Does nothing when there is no backing file.
+        """
+        if self._config_path and self._config_path.exists():
+            with self._lock:
+                data = json.loads(self._config_path.read_text(encoding='utf-8'))
+                section_classes = self._section_classes()
+
+                for k, v in data.items():
+                    if k in section_classes:
+                        if isinstance(v, dict):
+                            v = self._build_section(section_classes[k], v)
+                        setattr(self, k, v)
+                    elif k in self._SCALAR_NAMES:
+                        setattr(self, k, v)
+
+                # Same path handling as on initial load.
+                self._resolve_paths()
+            print(f"配置重载: {self._config_path}")
+        return self
+
+    def to_dict(self) -> dict:
+        """Export the public fields as a plain dict (internal state excluded)."""
+        result = {name: getattr(self, name) for name in self._SCALAR_NAMES}
+        for name in self._section_classes():
+            section = getattr(self, name)
+            # Groups replaced by non-dataclass values are exported as-is.
+            result[name] = asdict(section) if hasattr(section, '__dataclass_fields__') else section
+        return result
+
+    def _save(self) -> None:
+        """Write the configuration to the backing file, if there is one."""
+        if self._config_path:
+            with self._lock:
+                json_str = json.dumps(self.to_dict(), indent=2, ensure_ascii=False)
+                self._config_path.write_text(json_str, encoding='utf-8')
+
+    @classmethod
+    def _load(cls, path: Path) -> 'AppConfig':
+        """
+        Create an instance from the JSON file at ``path``.
+
+        Group dicts are converted to their dataclasses with unknown keys
+        dropped. Top-level keys starting with ``_`` are skipped so file
+        content cannot collide with internal parameters.
+        """
+        data = json.loads(path.read_text(encoding='utf-8'))
+        section_classes = cls._section_classes()
+
+        init_kwargs = {}
+        for key, value in data.items():
+            if key.startswith('_'):
+                continue
+            if key in section_classes and isinstance(value, dict):
+                value = cls._build_section(section_classes[key], value)
+            init_kwargs[key] = value
+
+        return cls(_config_path=path, **init_kwargs)
+
+    @classmethod
+    def _create_default(cls, path: Path) -> 'AppConfig':
+        """Create a default configuration and write it to ``path``."""
+        instance = cls(_config_path=path)
+        instance._save()
+        print(f"默认配置已创建: {path}")
+        return instance
+
+    def __repr__(self) -> str:
+        """Multi-line representation listing every exported field."""
+        lines = ["AppConfig("]
+        for k, v in self.to_dict().items():
+            lines.append(f" {k}={v!r},")
+        lines.append(")")
+        return "\n".join(lines)
+
+
+# 延迟初始化机制
+
+_root: Path = _project_root()
+_config_path: Path = _root / "settings.json"
+_config_instance: Optional[AppConfig] = None
+_init_lock: threading.Lock = threading.Lock()
+
+
+def _ensure_initialized() -> AppConfig:
+    """
+    Return the shared AppConfig, loading or creating it on first use (guarded by _init_lock).
+    """
+    global _config_instance
+
+    if _config_instance is not None:
+        return _config_instance
+
+    with _init_lock:
+        if _config_instance is not None:  # another thread finished while we waited for the lock
+            return _config_instance
+
+        # Load the settings file when present, otherwise write one with default values
+        if _config_path.exists():
+            instance = AppConfig._load(_config_path)
+            print(f"配置加载: {_config_path}")
+        else:
+            instance = AppConfig._create_default(_config_path)
+
+        # Create the working directories before publishing the instance: the lock-free check
+        # above must never hand out a config whose directories are missing, and a failed
+        # mkdir leaves the global unset so the next call retries instead of skipping this step.
+        for d in (instance.temp_dir, instance.log_dir, instance.using_dir):
+            d.mkdir(parents=True, exist_ok=True)
+        _config_instance = instance
+        return instance
+
+
+class _LazyConfig:
+    """
+    Config proxy: forwards attribute access to the shared AppConfig, which is created on first use
+    """
+
+    def __getattr__(self, name: str) -> Any:
+        """Fallback attribute lookup: resolve ``name`` on the real config (initialized on demand)."""
+        instance = _ensure_initialized()
+        return getattr(instance, name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Forward attribute assignment to the real config."""
+        # Names starting with "_" are stored on the proxy object itself
+        if name.startswith('_'):
+            super().__setattr__(name, value)
+        else:
+            instance = _ensure_initialized()
+            setattr(instance, name, value)
+
+    def __repr__(self) -> str:
+        instance = _ensure_initialized()
+        return repr(instance)
+
+    def __dir__(self) -> list:
+        """List the real config's attributes (helps IDE auto-completion)."""
+        instance = _ensure_initialized()
+        return dir(instance)
+
+    # Explicit proxies for the methods callers need
+    def get(self, key: str, default: Any = None) -> Any:
+        return _ensure_initialized().get(key, default)
+
+    def set(self, key: str, value: Any, save: bool = True) -> AppConfig:
+        return _ensure_initialized().set(key, value, save)
+
+    def update(self, updates: Dict[str, Any], save: bool = True) -> AppConfig:
+        return _ensure_initialized().update(updates, save)
+
+    def reload(self) -> AppConfig:
+        return _ensure_initialized().reload()
+
+    def to_dict(self) -> dict:
+        return _ensure_initialized().to_dict()
+
+    def save(self) -> None:
+        _ensure_initialized()._save()
+
+    # Property proxies (each one reads the same-named attribute of the real config)
+    @property
+    def ai(self) -> AIConfig:
+        return _ensure_initialized().ai
+
+    @property
+    def tts(self) -> TTSConfig:
+        return _ensure_initialized().tts
+
+    @property
+    def asr(self) -> ASRConfig:
+        return _ensure_initialized().asr
+
+    @property
+    def auto_agent(self) -> AutoAgentConfig:
+        return _ensure_initialized().auto_agent
+
+    @property
+    def llm_core(self) -> LLMConfig:
+        return _ensure_initialized().llm_core
+
+    @property
+    def paths(self) -> PathsConfig:
+        return _ensure_initialized().paths
+
+    @property
+    def temp_dir(self) -> Path:
+        return _ensure_initialized().temp_dir
+
+    @property
+    def log_dir(self) -> Path:
+        return _ensure_initialized().log_dir
+
+    @property
+    def using_dir(self) -> Path:
+        return _ensure_initialized().using_dir
+
+
+# 全局配置对象:导入即用,自动初始化
+cfg: AppConfig = _LazyConfig() # type: ignore
+
+
+# 工具函数
+
+def generate_example(path: Path = _root / "settings.example.json") -> None:
+    """Write a sample configuration file to ``path`` (default: settings.example.json under ``_root``)."""
+    example = {
+        "version": "1.0.0",
+        "debug": False,
+        "ai": {
+            "api_key": "sk-your-api-key",
+            "base_url": "https://api.deepseek.com",
+            "model": "deepseek-chat",  # NOTE(review): the __main__ demo reads cfg.ai.model_name — confirm this key matches AIConfig
+            "timeout": 30
+        },
+        "tts": {
+            "enabled": True,
+            "model": "GPT_SoVITS",
+            "url": "http://localhost:12458/",
+            "reference_audio": "./using/reference.wav",
+            "streaming": True,
+            "speed": 1.0
+        },
+        "paths": {
+            "temp": "./tmp/",
+            "log": "./log/",
+            "data": "./data/"
+        }
+    }
+    path.write_text(json.dumps(example, indent=2, ensure_ascii=False), encoding='utf-8')
+    print(f"示例配置已生成: {path}")
+
+
+# 测试代码
+if __name__ == "__main__":
+    # Demo: the first attribute access triggers initialization
+    print("第一次访问 cfg.ai.model:")
+    print(f" → {cfg.ai.model_name}")
+
+    print(f"\n配置详情:")
+    print(cfg)
+
+    print(f"\n测试修改:")
+    cfg.set("ai.timeout", 60)  # NOTE(review): save defaults to True, so this presumably rewrites settings.json — confirm
+    print(f" ai.timeout = {cfg.ai.timeout}")
+
+    print(f"\n测试批量更新:")
+    cfg.update({"debug": True, "tts": {"speed": 1.5}})
+    print(f" debug = {cfg.debug}, tts.speed = {cfg.tts.speed}")
+
+    print(f"\n测试热重载:")
+    cfg.reload()
\ No newline at end of file
diff --git a/src/config/readme.md b/src/config/readme.md
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/__init__.py b/src/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/asr_module/__init__.py b/src/modules/asr_module/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/asr_module/api.py b/src/modules/asr_module/api.py
new file mode 100644
index 0000000..818e306
--- /dev/null
+++ b/src/modules/asr_module/api.py
@@ -0,0 +1,123 @@
+# asr_module/api.py
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import JSONResponse
+from pathlib import Path
+import tempfile
+import time
+from datetime import datetime
+from loguru import logger
+from src.modules.asr_module.asr_core.fast_whisper import create_asr, ASRConfig
+
+# 初始化FastAPI应用
+app = FastAPI(
+ title="Yosuga ASR API",
+ description="基于faster-whisper Turbo的高性能多语种语音转文本服务",
+ version="1.0.0"
+)
+
+# 全局单例ASR实例(延迟加载)
+_asr_instance = None
+
+def get_asr():
+    """Return the shared ASR instance, building it on the first call."""
+    global _asr_instance
+    if _asr_instance is not None:
+        return _asr_instance
+    logger.info("🚀 初始化ASR服务...")
+    settings = ASRConfig(
+        model_name="deepdml/faster-whisper-large-v3-turbo-ct2",
+        device="auto",
+        compute_type="int8_float16",
+        cache_dir=Path("asr_models/faster_whisper_large_v3_ct2"),
+        beam_size=1,  # beam of 1 = greedy search, chosen for speed
+        vad_filter=True,  # filter out silence before decoding
+    )
+    _asr_instance = create_asr(settings)
+    logger.info("✅ ASR服务初始化完成")
+    return _asr_instance
+
+@app.on_event("startup")
+async def startup_event():
+    """Create the shared ASR instance at startup. NOTE(review): ModelManager loads weights lazily, so this may not preload the model — confirm."""
+    get_asr()
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Release ASR resources when the application shuts down."""
+    global _asr_instance
+    if _asr_instance:
+        _asr_instance.shutdown()
+        logger.info("🛑 ASR服务已关闭")
+
+@app.post("/transcribe", response_class=JSONResponse)
+async def transcribe_audio(
+    file: UploadFile = File(..., description="音频文件 (WAV, FLAC, MP3等格式)")
+):
+    """
+    Speech-to-text endpoint.
+
+    - **file**: audio upload (WAV/FLAC/MP3 and similar formats)
+    - **returns**: JSON result containing text/language/confidence/processing_time
+
+    Responds 400 for a non-audio MIME type and 500 when recognition fails.
+    """
+    start_time = time.time()
+
+    # Validate the declared MIME type (uploads without one are let through)
+    if file.content_type and not file.content_type.startswith("audio/"):
+        raise HTTPException(status_code=400, detail="❌ 请上传音频文件 (MIME类型: audio/*)")
+
+    tmp_path = None
+    try:
+        # Spool the upload to a temp file; file.filename may be None, hence the "" fallback
+        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename or "").suffix) as tmp_file:
+            tmp_path = Path(tmp_file.name)
+            content = await file.read()
+            tmp_file.write(content)
+
+        logger.info(f"📥 接收文件: {file.filename} ({len(content)} bytes)")
+
+        text, language, confidence = get_asr().transcribe_wav(tmp_path)
+        processing_time = time.time() - start_time
+
+        logger.info(f"✅ 识别完成: {language} | {len(text)}字符 | 置信度:{confidence:.2f} | 耗时:{processing_time:.3f}s")
+
+        return {
+            "success": True,
+            "data": {
+                "text": text,
+                "language": language,
+                "confidence": confidence,
+                "processing_time": round(processing_time, 3)
+            }
+        }
+    except Exception as e:
+        logger.error(f"❌ 识别失败: {e}")
+        raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}") from e
+    finally:
+        # Remove the temp file on every path so a failed recognition no longer leaks it
+        if tmp_path is not None:
+            tmp_path.unlink(missing_ok=True)
+
+@app.get("/health")
+async def health_check():
+    """Health-check endpoint: reports the ASR core's status, device and model state."""
+    asr = get_asr()
+    health = asr.health_check()
+
+    return {
+        "status": "healthy" if health["status"] == "healthy" else "unhealthy",
+        "timestamp": datetime.now().isoformat(),
+        "device": health["device"],
+        "model_loaded": health["model_loaded"]
+    }
+
+@app.get("/")
+async def root():
+    """API root: returns a short message plus the paths of the other endpoints."""
+    return {
+        "message": "Yosuga ASR API 正在运行",
+        "docs": "/docs",
+        "health": "/health",
+        "transcribe": "/transcribe (POST)"
+    }
\ No newline at end of file
diff --git a/src/modules/asr_module/asr_core/__init__.py b/src/modules/asr_module/asr_core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/asr_module/asr_core/fast_whisper/__init__.py b/src/modules/asr_module/asr_core/fast_whisper/__init__.py
new file mode 100644
index 0000000..cc40188
--- /dev/null
+++ b/src/modules/asr_module/asr_core/fast_whisper/__init__.py
@@ -0,0 +1,16 @@
+# fast_whisper/__init__.py
+from typing import Optional
+from src.modules.asr_module.asr_core.fast_whisper.config import ASRConfig
+from src.modules.asr_module.asr_core.fast_whisper.model_manager import ModelManager
+from src.modules.asr_module.asr_core.fast_whisper.asr_interface import ASRInterface
+
+__version__ = "1.0.0"
+__all__ = ["ASRConfig", "ModelManager", "ASRInterface"]
+
+def create_asr(config: Optional[ASRConfig] = None) -> ASRInterface:
+    """
+    Convenience factory: returns the shared ASRInterface via ASRInterface.get_instance
+    Args:
+        config: ASR configuration; None means the default configuration is used
+    """
+    return ASRInterface.get_instance(config)
\ No newline at end of file
diff --git a/src/modules/asr_module/asr_core/fast_whisper/asr_interface.py b/src/modules/asr_module/asr_core/fast_whisper/asr_interface.py
new file mode 100644
index 0000000..6eeab10
--- /dev/null
+++ b/src/modules/asr_module/asr_core/fast_whisper/asr_interface.py
@@ -0,0 +1,184 @@
+# fast_whisper/asr_interface.py
+from loguru import logger
+from pathlib import Path
+from typing import Tuple, Optional
+import torchaudio
+import torch
+import numpy
+
+from .model_manager import ModelManager
+from .config import ASRConfig
+from .utils import PerformanceProfiler
+
+class ASRInterface:
+    """
+    ASR interface - process-wide singleton
+    - converts audio files to text
+    - owns a ModelManager built from the configuration
+    - records performance statistics
+    """
+
+    _instance: Optional['ASRInterface'] = None
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton: always hand back the one shared instance."""
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self, config: Optional[ASRConfig] = None):
+        # Guard against re-initialization: __init__ runs on every ASRInterface(...) call
+        if hasattr(self, '_initialized') and self._initialized:
+            return
+
+        self.config = config or ASRConfig()
+        self.model_manager = ModelManager(self.config)
+        self.profiler = PerformanceProfiler(self.config.enable_profiling)
+
+        # Audio parameters: every input is converted to this sample rate
+        self.sample_rate = 16000
+
+        self._initialized = True
+        logger.info("🎤 ASR接口初始化完成")
+
+    @classmethod
+    def get_instance(cls, config: Optional[ASRConfig] = None) -> 'ASRInterface':
+        """Global access point; ``config`` only matters when the instance is first created."""
+        if cls._instance is None:
+            cls._instance = cls(config)
+        return cls._instance
+
+    def transcribe_wav(
+        self,
+        wav_path: Path,
+        language: Optional[str] = None
+    ) -> Tuple[str, str, float]:
+        """
+        Transcribe an audio file to text (core entry point).
+
+        Args:
+            wav_path: path of the audio file
+            language: language code (e.g. 'zh'/'en'); None lets the model detect it
+
+        Returns:
+            (text, language, confidence); any failure is re-raised as RuntimeError
+        """
+        try:
+            # Record the start time
+            import time
+            start_time = time.time()
+
+            logger.info(f"🎵 开始识别: {wav_path.name}")
+
+            # Load the audio, run inference, parse the result
+            audio = self._load_audio(wav_path)
+            result = self._transcribe(audio, language)
+            text, lang, confidence = self._parse_result(result)
+
+            # Elapsed time and real-time factor; empty audio must not turn a finished run into an error
+            processing_time = time.time() - start_time
+            audio_duration = len(audio) / self.sample_rate
+            rtf = processing_time / audio_duration if audio_duration > 0 else 0.0
+            logger.info(f"✅ 识别完成: {lang} | {len(text)}字符 | 置信度:{confidence:.2f} | "
+                        f"耗时:{processing_time:.3f}s | RTF:{rtf:.3f}")
+
+            return text, lang, confidence
+
+        except Exception as e:
+            logger.error(f"❌ 识别失败 {wav_path}: {e}")
+            raise RuntimeError(f"Transcription failed: {e}") from e
+
+    def _load_audio(self, wav_path: Path) -> numpy.ndarray:
+        """Load the file and convert it to mono samples at ``self.sample_rate``."""
+        if not wav_path.exists():
+            raise FileNotFoundError(f"音频文件不存在: {wav_path}")
+
+        # Load the audio
+        waveform, sample_rate = torchaudio.load(wav_path)
+
+        # Resample to the target rate
+        if sample_rate != self.sample_rate:
+            resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
+            waveform = resampler(waveform)
+
+        # Down-mix to mono
+        if waveform.shape[0] > 1:
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+        # Drop only the channel axis: a bare squeeze() would collapse a 1-sample clip to 0-d and break len()
+        audio = waveform.squeeze(0).numpy()
+
+        return audio
+
+    def _transcribe(self, audio: numpy.ndarray, language: Optional[str]) -> Tuple:
+        """Run model inference and return ``(segments_list, info)``."""
+        model = self.model_manager.model
+
+        # Make sure a model is available
+        if model is None:
+            logger.error("ASR模型未加载,请检查模型配置和路径")
+            raise RuntimeError("ASR模型未加载,请检查模型配置和路径")
+
+        # Start timing
+        import time
+        start_time = time.time()
+
+        # Call the model
+        segments, info = model.transcribe(
+            audio,
+            language=language,
+            beam_size=self.config.beam_size,
+            best_of=self.config.best_of,
+            vad_filter=self.config.vad_filter,
+        )
+
+        # Consume the generator right away so the work is included in the timing below
+        segments_list = list(segments)
+
+        # Performance statistics
+        inference_time = time.time() - start_time
+        audio_duration = len(audio) / self.sample_rate
+        self.profiler.record(audio_duration, inference_time)
+
+        return segments_list, info
+
+    def _parse_result(self, result: Tuple) -> Tuple[str, str, float]:
+        """Turn ``(segments, info)`` into ``(text, language, confidence)``."""
+        segments, info = result
+
+        # Join all segments with single spaces
+        text = " ".join(seg.text.strip() for seg in segments)
+
+        # Language information ("unknown" / 0.0 when info is missing)
+        language = info.language if info else "unknown"
+        confidence = info.language_probability if info else 0.0
+
+        return text, language, confidence
+
+    def transcribe_batch(self, wav_paths: list) -> list:
+        """Transcribe several files one after another; returns one result dict per path."""
+        results = []
+        for path in wav_paths:  # the first failure propagates as RuntimeError
+            text, lang, confidence = self.transcribe_wav(Path(path))
+            results.append({
+                "file": str(path),
+                "text": text,
+                "language": lang,
+                "confidence": confidence
+            })
+
+        return results
+
+    def health_check(self) -> dict:
+        """Health report; reading ``model_manager.model`` loads the model if it is not loaded yet."""
+        return {
+            "status": "healthy" if self.model_manager.model else "unhealthy",
+            "device": self.config.device,
+            "model_loaded": self.model_manager.model is not None,
+            "device_info": self.model_manager.get_device_info(),
+        }
+
+    def shutdown(self):
+        """Graceful shutdown: unload the model."""
+        logger.info("🛑 关闭ASR接口...")
+        self.model_manager.unload()
\ No newline at end of file
diff --git a/src/modules/asr_module/asr_core/fast_whisper/config.py b/src/modules/asr_module/asr_core/fast_whisper/config.py
new file mode 100644
index 0000000..7516926
--- /dev/null
+++ b/src/modules/asr_module/asr_core/fast_whisper/config.py
@@ -0,0 +1,29 @@
+# fast_whisper/config.py
+from dataclasses import dataclass
+from pathlib import Path
+import torch
+
+@dataclass
+class ASRConfig:
+    """ASR settings; ``device="auto"`` is resolved to "cuda" or "cpu" on construction."""
+    model_name: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
+    device: str = "auto"
+    compute_type: str = "int8_float16"
+    cache_dir: Path = Path.home() / ".cache" / "faster_whisper"
+
+    # Speed-oriented decoding parameters
+    beam_size: int = 1
+    best_of: int = 1
+    vad_filter: bool = True
+    batch_size: int = 16
+
+    # Performance statistics
+    enable_profiling: bool = True
+
+    def __post_init__(self):
+        self.cache_dir = Path(self.cache_dir)  # accept str paths too; consumers call .mkdir() on it
+        if self.device == "auto":
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if self.device == "cpu":  # CPU runs always use int8 and a smaller batch
+            self.compute_type = "int8"
+            self.batch_size = 4
\ No newline at end of file
diff --git a/src/modules/asr_module/asr_core/fast_whisper/model_manager.py b/src/modules/asr_module/asr_core/fast_whisper/model_manager.py
new file mode 100644
index 0000000..f1bbaab
--- /dev/null
+++ b/src/modules/asr_module/asr_core/fast_whisper/model_manager.py
@@ -0,0 +1,92 @@
+# fast_whisper/model_manager.py
+import gc
+from loguru import logger
+from pathlib import Path
+from typing import Optional
+from faster_whisper import WhisperModel
+import torch
+
+from .config import ASRConfig
+
+
+class ModelManager:
+    """
+    Model manager
+    - owns the model life cycle (lazy load, reload, unload)
+    - supports a custom cache directory
+    - device and compute type are taken from the ASRConfig
+    """
+
+    def __init__(self, config: ASRConfig):
+        self.config = config
+        self._model: Optional[WhisperModel] = None
+        self._device_info = None
+
+        # Make sure the cache directory exists
+        self.config.cache_dir.mkdir(parents=True, exist_ok=True)
+
+    @property
+    def model(self) -> Optional[WhisperModel]:
+        """The model, loaded on first access."""
+        if self._model is None:
+            self._load_model()
+        return self._model
+
+    def _load_model(self):
+        """Load the model described by the config; raises RuntimeError on failure."""
+        logger.info(f"🚀 初始化模型: {self.config.model_name}")
+        logger.info(f"📦 设备: {self.config.device}, 计算类型: {self.config.compute_type}")
+
+        try:
+            self._model = WhisperModel(
+                self.config.model_name,
+                device=self.config.device,
+                compute_type=self.config.compute_type,
+                download_root=str(self.config.cache_dir),
+                local_files_only=False,
+            )
+            name_parts = self.config.model_name.split("-")  # short names such as "base" have no [-2] element
+            self._device_info = {
+                "device": self.config.device,
+                "compute_type": self.config.compute_type,
+                "model_size": name_parts[-2] if len(name_parts) >= 2 else name_parts[-1]
+            }
+
+            logger.info("✅ 模型加载成功")
+
+        except Exception as e:
+            logger.error(f"❌ 模型加载失败: {e}")
+            raise RuntimeError(f"Failed to load ASR model: {e}") from e
+
+    def reload(self, new_config: ASRConfig):
+        """Hot reload: unload the current model, switch to ``new_config`` and load again."""
+        logger.info("🔄 热重载模型...")
+        self.unload()
+        self.config = new_config
+        self._load_model()
+
+    def unload(self):
+        """Unload the model and release resources (no-op when nothing is loaded)."""
+        if self._model is not None:
+            logger.info("🗑️ 卸载模型...")
+            del self._model
+            self._model = None
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            gc.collect()
+
+            logger.info("✅ 模型已卸载")
+
+    def get_device_info(self) -> dict:
+        """Device info recorded by the last successful load ({} before the first load)."""
+        return self._device_info or {}
+
+    def __enter__(self):
+        """Context-manager support."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Unload the model when the ``with`` block ends."""
+        self.unload()
\ No newline at end of file
diff --git a/src/modules/asr_module/asr_core/fast_whisper/utils.py b/src/modules/asr_module/asr_core/fast_whisper/utils.py
new file mode 100644
index 0000000..6fe2a58
--- /dev/null
+++ b/src/modules/asr_module/asr_core/fast_whisper/utils.py
@@ -0,0 +1,45 @@
+# fast_whisper/utils.py
+from loguru import logger
+from datetime import datetime
+from typing import Dict, Any
+
+def check_hardware() -> Dict[str, Any]:
+    """Report whether CUDA is usable and which device / compute type applies."""
+    import torch
+
+    has_cuda = torch.cuda.is_available()
+    if has_cuda:
+        # GPU present: describe device 0 and use the int8_float16 compute type
+        device_name = torch.cuda.get_device_name(0)
+        device_count = torch.cuda.device_count()
+        compute_type = "int8_float16"
+    else:
+        device_name, device_count, compute_type = "CPU", 0, "int8"
+    return {
+        "cuda_available": has_cuda,
+        "device_name": device_name,
+        "device_count": device_count,
+        "compute_type": compute_type
+    }
+
+class PerformanceProfiler:
+    """Keeps per-call timing samples and logs the mean RTF after every 10th sample."""
+    def __init__(self, enable: bool = True):
+        self.enable = enable
+        self.stats = []
+
+    def record(self, audio_duration: float, inference_time: float):
+        """Append one sample (no-op when disabled); RTF = inference_time / audio_duration."""
+        if self.enable:
+            # A duration that is not positive gives an RTF of 0 instead of a division error
+            rtf = inference_time / audio_duration if audio_duration > 0 else 0
+            sample = {
+                "timestamp": datetime.now().isoformat(),
+                "rtf": rtf,
+                "audio_duration": audio_duration,
+                "inference_time": inference_time
+            }
+            self.stats.append(sample)
+            if len(self.stats) % 10 == 0:
+                avg_rtf = sum(s["rtf"] for s in self.stats[-10:]) / 10
+                logger.info(f"📊 最近10次平均RTF: {avg_rtf:.3f}")
\ No newline at end of file
diff --git a/src/modules/asr_module/client/__init__.py b/src/modules/asr_module/client/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/asr_module/client/asr_client.py b/src/modules/asr_module/client/asr_client.py
new file mode 100644
index 0000000..041c83e
--- /dev/null
+++ b/src/modules/asr_module/client/asr_client.py
@@ -0,0 +1,215 @@
+# asr_module/client/asr_client.py
+import asyncio
+import time
+from pathlib import Path
+from typing import Union, Optional
+import aiofiles
+import aiohttp
+import requests
+from loguru import logger
+from .models import ASRResponse, ASRHealthStatus, ServiceInfo
+
+class ASRException(Exception):
+    """Raised when a call to the ASR service fails."""
+
+    def __init__(self, message: str, status_code: Optional[int] = None):
+        super().__init__(message)
+        self.message = message
+        self.status_code = status_code
+
+class ASRClientConfig:
+    """Connection settings shared by the sync and async ASR clients."""
+    def __init__(
+        self,
+        base_url: str = "http://localhost:8000",
+        timeout: float = 30.0,
+        retry_count: int = 2,
+        retry_delay: float = 0.5,
+    ):
+        # Trailing slashes are dropped so endpoints such as "/health" can be appended directly
+        self.base_url = base_url.rstrip('/')
+        # retry_count = extra attempts after the first one; retry_delay = pause between attempts
+        self.timeout, self.retry_count, self.retry_delay = timeout, retry_count, retry_delay
+
+
+# 同步客户端
+class ASRClientSync:
+    """Synchronous ASR client (requests based); usable as a context manager."""
+    def __init__(self, config: Optional[ASRClientConfig] = None):
+        self.config = config or ASRClientConfig()
+        self.session = requests.Session()
+        self.session.timeout = self.config.timeout  # informational only; the real timeout is passed in _request
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.session.close()
+
+    def _request(self, method: str, endpoint: str, **kwargs) -> dict:
+        """Send one request with retries; returns the decoded JSON body or raises ASRException."""
+        url = f"{self.config.base_url}{endpoint}"
+        kwargs.setdefault("timeout", self.config.timeout)  # requests does not honour Session.timeout, so pass it per call
+        for attempt in range(self.config.retry_count + 1):
+            try:
+                response = self.session.request(method, url, **kwargs)
+                response.raise_for_status()
+                return response.json()
+            except requests.exceptions.RequestException as e:
+                if attempt < self.config.retry_count:
+                    logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
+                    time.sleep(self.config.retry_delay)
+                else:
+                    logger.error(f"请求最终失败: {e}")
+                    raise ASRException(f"API调用失败: {e}", getattr(e.response, 'status_code', None)) from e
+
+    def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
+        """
+        Transcribe an audio file.
+
+        Args:
+            file_path: path of the audio file
+
+        Returns:
+            ASRResponse object (raises FileNotFoundError when the path does not exist)
+
+        Example:
+            client = ASRClientSync()
+            result = client.transcribe_file("/path/to/audio.wav")
+            print(result.data.text)
+        """
+        file_path = Path(file_path)
+        if not file_path.exists():
+            raise FileNotFoundError(f"文件不存在: {file_path}")
+
+        logger.info(f"📤 上传文件: {file_path.name}")
+        # Send bytes, not an open file: a file object is consumed by the first attempt, so retries would upload nothing
+        audio_data = file_path.read_bytes()
+        files = {'file': (file_path.name, audio_data, 'audio/wav')}
+        result = self._request('POST', '/transcribe', files=files)
+
+        return ASRResponse(**result)
+
+    def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
+        """
+        Transcribe raw audio bytes.
+
+        Args:
+            audio_data: raw audio bytes
+            filename: file name sent with the upload (the content type is always audio/wav)
+
+        Returns:
+            ASRResponse object
+
+        Example:
+            with open('audio.wav', 'rb') as f:
+                audio_bytes = f.read()
+            result = client.transcribe_bytes(audio_bytes)
+        """
+        logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
+
+        files = {'file': (filename, audio_data, 'audio/wav')}
+        result = self._request('POST', '/transcribe', files=files)
+
+        return ASRResponse(**result)
+
+    def health_check(self) -> ASRHealthStatus:
+        """Query the /health endpoint."""
+        result = self._request('GET', '/health')
+        return ASRHealthStatus(**result)
+
+    def get_service_info(self) -> ServiceInfo:
+        """Query the service root endpoint."""
+        result = self._request('GET', '/')
+        return ServiceInfo(**result)
+
+
+# 异步客户端
+class ASRClientAsync:
+    """Asynchronous ASR client (aiohttp based); usable as an async context manager."""
+    def __init__(self, config: Optional[ASRClientConfig] = None):
+        self.config = config or ASRClientConfig()
+        self._session: Optional[aiohttp.ClientSession] = None
+
+    async def __aenter__(self):
+        await self._ensure_session()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if self._session:
+            await self._session.close()
+            self._session = None  # a later call re-creates the session instead of reusing a closed one
+
+    async def _ensure_session(self):
+        if self._session is None:
+            self._session = aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
+            )
+
+    async def _request(self, method: str, endpoint: str, *, form_factory=None, **kwargs) -> dict:
+        """Send one request with retries; ``form_factory`` builds a fresh FormData for every attempt."""
+        await self._ensure_session()
+        url = f"{self.config.base_url}{endpoint}"
+        for attempt in range(self.config.retry_count + 1):
+            try:
+                if form_factory is not None:  # aiohttp.FormData is single-use, so never resend the same object
+                    kwargs["data"] = form_factory()
+                async with self._session.request(method, url, **kwargs) as response:
+                    response.raise_for_status()
+                    return await response.json()
+            except (aiohttp.ClientError, asyncio.TimeoutError) as e:  # total-timeout expiry is a TimeoutError
+                if attempt < self.config.retry_count:
+                    logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
+                    await asyncio.sleep(self.config.retry_delay)
+                else:
+                    logger.error(f"请求最终失败: {e}")
+                    raise ASRException(f"API调用失败: {e}", getattr(e, 'status', None)) from e
+
+    async def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
+        """Transcribe an audio file asynchronously (raises FileNotFoundError for a missing path)."""
+        file_path = Path(file_path)
+        if not file_path.exists():
+            raise FileNotFoundError(f"文件不存在: {file_path}")
+
+        logger.info(f"📤 上传文件: {file_path.name}")
+        async with aiofiles.open(file_path, 'rb') as f:
+            audio_data = await f.read()
+
+        return await self.transcribe_bytes(audio_data, file_path.name)
+
+    async def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
+        """Transcribe raw audio bytes asynchronously."""
+        logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
+        def build_form() -> aiohttp.FormData:  # called once per attempt by _request
+            form = aiohttp.FormData()
+            form.add_field('file', audio_data, filename=filename, content_type='audio/wav')
+            return form
+        result = await self._request('POST', '/transcribe', form_factory=build_form)
+        return ASRResponse(**result)
+
+    async def health_check(self) -> ASRHealthStatus:
+        """Query the /health endpoint asynchronously."""
+        result = await self._request('GET', '/health')
+        return ASRHealthStatus(**result)
+
+    async def get_service_info(self) -> ServiceInfo:
+        """Query the service root endpoint asynchronously."""
+        result = await self._request('GET', '/')
+        return ServiceInfo(**result)
+
+# 工厂函数
+def create_asr_client(use_async: bool = False, **config_kwargs) -> Union[ASRClientSync, ASRClientAsync]:
+    """
+    Factory for ASR clients.
+
+    Args:
+        use_async: build the asynchronous client when True
+        **config_kwargs: forwarded to ASRClientConfig
+
+    Returns:
+        A synchronous or asynchronous client instance
+    """
+    # The config is built first, so bad keyword arguments fail before any client is created
+    settings = ASRClientConfig(**config_kwargs)
+    client_cls = ASRClientAsync if use_async else ASRClientSync
+    return client_cls(settings)
\ No newline at end of file
diff --git a/src/modules/asr_module/client/models.py b/src/modules/asr_module/client/models.py
new file mode 100644
index 0000000..ca280fe
--- /dev/null
+++ b/src/modules/asr_module/client/models.py
@@ -0,0 +1,30 @@
+# asr_module/client/models.py
+from pydantic import BaseModel
+from typing import Optional
+
+class ASRHealthStatus(BaseModel):
+    """Health status reported by the ASR service."""
+    status: str  # status string
+    timestamp: str  # timestamp
+    device: str  # device
+    model_loaded: bool  # whether the model is loaded
+
+class ASRResult(BaseModel):
+    """Speech recognition result."""
+    text: str  # recognized text
+    language: str  # language
+    confidence: float  # confidence
+    processing_time: float  # processing time
+
+class ASRResponse(BaseModel):
+    """Unified API response envelope."""
+    success: bool
+    data: Optional[ASRResult] = None
+    error: Optional[str] = None
+
+class ServiceInfo(BaseModel):
+    """Service information."""
+    message: str  # message
+    docs: str  # documentation
+    health: str  # health check
+    transcribe: str  # transcription
\ No newline at end of file
diff --git a/src/modules/asr_module/readme.md b/src/modules/asr_module/readme.md
new file mode 100644
index 0000000..3799f32
--- /dev/null
+++ b/src/modules/asr_module/readme.md
@@ -0,0 +1,12 @@
+本模块为asr模块,即语音转文本模块。
+
+本模块提供了两种访问方式:
+- 第一种为本地部署的func call方式,即提供函数调用的方式调用相关asr接口。
+- 第二种为http call方式,即提供http call方式调用相关asr接口。[FastAPI]
+
+如果你的电脑显卡支持cuda, 并且显存大小大于8G, 那么可以使用第一种方式,
+否则可以使用第二种,进行云端部署。
+
+两种调用方式在相同配置下,性能几乎无差别。
+
+个人建议使用第二种方式。
\ No newline at end of file
diff --git a/src/modules/asr_module/start_api.py b/src/modules/asr_module/start_api.py
new file mode 100644
index 0000000..2618309
--- /dev/null
+++ b/src/modules/asr_module/start_api.py
@@ -0,0 +1,65 @@
+# start_api.py
+import uvicorn
+from loguru import logger
+import threading
+import time
+
+def start_server():
+    """Start the ASR API service with uvicorn."""
+    uvicorn.run(
+        "api:app",  # module name : app instance
+        host="0.0.0.0",
+        port=20260,
+        workers=1,  # single-user scenario, one worker is enough
+        log_level="info",
+        reload=False,  # hot reload is turned off for production
+        access_log=True,
+    )
+
+def first_test() -> None:
+    """Smoke test run once after startup: posts a sample WAV to /transcribe and logs the outcome."""
+    time.sleep(5)  # give the server some time to start
+    # Build one test request to verify that the model was loaded successfully
+    logger.info("🚀 测试模型是否加载成功...")
+    import requests
+    from pathlib import Path
+    url = "http://localhost:20260/transcribe"
+    audio_path = Path("../../../Test/test_files/test.wav")  # relative path, resolved against the current working directory
+    try:
+        with open(audio_path, "rb") as f:
+            # State the file name and MIME type explicitly
+            files = {
+                "file": (
+                    audio_path.name,  # file name
+                    f,  # file object
+                    "audio/wav"  # MIME type
+                )
+            }
+
+            response = requests.post(url, files=files)
+            logger.info(f"状态码: {response.status_code}")
+            logger.info(f"响应头: {response.headers.get('content-type')}")
+            if response.status_code == 200:
+                result = response.json()
+                logger.info(f"识别结果: {result['data']['text']}")
+                logger.info(f"识别语言: {result['data']['language']}")
+                logger.info(f"置信度: {result['data']['confidence']:.2f}")
+                logger.info(f"处理时间: {result['data']['processing_time']}s")
+            else:
+                logger.error(f"请求失败,错误响应信息: {response.text}")
+                logger.error("请检查模型是否正确加载或其他问题")
+    except Exception as e:
+        logger.error(f"测试过程中发生错误: {e}")
+
+if __name__ == "__main__":
+    logger.info("🚀 启动 ASR API 服务...")
+
+    # Start the server in a background (daemon) thread
+    server_thread = threading.Thread(target=start_server, daemon=True)
+    server_thread.start()
+
+    # Run the smoke test
+    first_test()
+
+    # Keep the main thread alive for as long as the server thread runs
+    server_thread.join()
diff --git a/src/modules/device_control_module/__init__.py b/src/modules/device_control_module/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/device_control_module/device_control_core/__init__.py b/src/modules/device_control_module/device_control_core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/device_control_module/device_control_core/ui_tars_/__init__.py b/src/modules/device_control_module/device_control_core/ui_tars_/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_client.py b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_client.py
new file mode 100644
index 0000000..230b6a5
--- /dev/null
+++ b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_client.py
@@ -0,0 +1,119 @@
+# ui_tars_/ui_tars_client.py
+from typing import Optional
+from loguru import logger
+import asyncio
+from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
+ UnifiedLLM,
+ ModelConfig,
+ ModelProvider,
+ ChatMessage,
+)
+from pydantic import BaseModel, Field
+from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_prompts import UI_TARS_SYSTEM_PROMPT
+
+
+class UITarsClientConfig(BaseModel):
+    """UI-TARS client configuration."""
+    deployment_type: str = Field(default="lmstudio", description="部署类型")
+    base_url: str = Field(default="http://localhost:1234/v1", description="API地址")
+    model_name: str = Field(default="ui-tars", description="模型名称")
+    api_key: Optional[str] = Field(default=None, description="API密钥")
+    temperature: float = Field(default=0.1, ge=0.0, le=2.0)
+    max_tokens: int = Field(default=8192, ge=2048, le=128000)
+    timeout: int = Field(default=30, ge=5, le=300)
+
+    # System prompt that pins the UI-TARS-1.5 output format
+    system_prompt: str = Field(
+        default=UI_TARS_SYSTEM_PROMPT  # this project's own output-format constraints
+    )
+
+    def to_model_config(self) -> ModelConfig:
+        """Convert to a UnifiedLLM ModelConfig."""
+        # Map the deployment type to a ModelProvider (unknown types fall back to CUSTOM)
+        provider_map = {
+            "lmstudio": ModelProvider.LM_STUDIO,
+            "vllm": ModelProvider.CUSTOM,
+            "cloud": ModelProvider.OPENAI,
+            "ollama": ModelProvider.OLLAMA
+        }
+        provider = provider_map.get(self.deployment_type, ModelProvider.CUSTOM)
+        return ModelConfig(
+            provider=provider,
+            model_name=self.model_name,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            timeout=self.timeout,
+            custom_headers={"User-Agent": "UI-TARS-Client/1.0"}
+        )
+
+class UITarsClient:
+    """
+    Generic UI-TARS client (built on UnifiedLLM)
+    Images are passed in directly as base64 strings
+    """
+    def __init__(self, config: UITarsClientConfig):
+        self.config = config
+        # Reuse UnifiedLLM, configured from the deployment settings
+        self.llm = UnifiedLLM(config.to_model_config())
+
+        logger.info(f"UI-TARS 客户端初始化: {config.deployment_type} @ {config.base_url}")
+        logger.info(f" 模型: {config.model_name} | 温度: {config.temperature}")
+
+    async def call_async(self, instruction: str, image_base64: str) -> str:
+        """Call UI-TARS without blocking the event loop; returns the response content."""
+        # Build the messages
+        messages = self._build_messages(instruction, image_base64)
+        try:
+            # Run the synchronous UnifiedLLM.chat in a worker thread via asyncio.to_thread
+            response = await asyncio.to_thread(
+                self.llm.chat,
+                messages=messages,
+                streaming=False
+            )
+            return response.content
+        except Exception as e:
+            logger.error(f"UI-TARS 调用失败: {e}")
+            raise
+
+    def call_sync(self, instruction: str, image_base64: str) -> str:
+        """Call UI-TARS synchronously; returns the response content."""
+        messages = self._build_messages(instruction, image_base64)
+        try:
+            response = self.llm.chat(
+                messages=messages,
+                streaming=False
+            )
+
+            return response.content
+        except Exception as e:
+            logger.error(f"UI-TARS 调用失败: {e}")
+            raise
+
+    def stream_async(self, instruction: str, image_base64: str):
+        """Streaming call; returns llm.stream_chat(...) as-is (presumably an async generator — confirm)."""
+        messages = self._build_messages(instruction, image_base64)
+        # Streaming itself is delegated to UnifiedLLM
+        return self.llm.stream_chat(messages=messages)
+
+    def _build_messages(self, instruction: str, image_base64: str) -> list:
+        """Build the message list: system prompt, then a user message with text and a PNG data URL."""
+        return [
+            ChatMessage(
+                role="system",
+                content=self.config.system_prompt
+            ),
+            ChatMessage(
+                role="user",
+                content=[
+                    {"type": "text", "text": instruction},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{image_base64}"
+                        }
+                    }
+                ]
+            )
+        ]
diff --git a/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_prompts.py b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_prompts.py
new file mode 100644
index 0000000..8c484f3
--- /dev/null
+++ b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_prompts.py
@@ -0,0 +1,25 @@
+# Text block listing the GUI actions and the Thought/Action output format.
+# It is interpolated into UI_TARS_SYSTEM_PROMPT below; everything inside the
+# triple quotes is prompt content, not commentary.
+OFFICIAL_ACTION_SPACE = """## Action Space
+click(point='x1 y1') - 单击坐标
+left_double(point='x1 y1') - 双击坐标
+right_single(point='x1 y1') - 右键单击
+drag(start_point='x1 y1', end_point='x2 y2') - 拖拽
+hotkey(key='ctrl c') - 快捷键(空格分隔,小写,最多3个键)
+type(content='xxx') - 输入文本(用\\' \\\" \\n 转义)
+scroll(point='x1 y1', direction='down or up or right or left') - 滚动
+wait() - 等待5秒
+finished() - 任务完成
+
+## Output Format
+Thought: [你的推理过程]
+Action: [选择一个动作]
+"""
+
+# Complete prompt: a role line, the action space above, then extra notes.
+# Built with an f-string, so OFFICIAL_ACTION_SPACE is substituted at import time.
+UI_TARS_SYSTEM_PROMPT = f"""You are UI-TARS-1.5, a GUI agent. Given a task and screenshot, output ONLY:
+
+{OFFICIAL_ACTION_SPACE}
+
+## Note
+- Write a small plan and summarize the next action in one sentence in Thought.
+- NEVER output multiple actions.
+- x&y please in box center
+"""
\ No newline at end of file
diff --git a/src/modules/device_control_module/readme.md b/src/modules/device_control_module/readme.md
new file mode 100644
index 0000000..47a9652
--- /dev/null
+++ b/src/modules/device_control_module/readme.md
@@ -0,0 +1,10 @@
+本模块为设备控制模块的接口层。此处的"设备控制"指借助AI模型在设备上执行自动化操作,目前支持`pc`和`android`,其他平台尚未测试。
+
+依赖:
+`ui_tars`
+
+当前所使用的AI模型为`mradermacher/UI-TARS-1.5-7B-GGUF`(量化版本:`Q6_K`或`Q4_K_M`)。
+
+未来如果有更快质量更高的AI模型本模块会为其添加支持。
\ No newline at end of file
diff --git a/src/modules/text_ai_module/__init__.py b/src/modules/text_ai_module/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/text_ai_module/text_ai_core/__init__.py b/src/modules/text_ai_module/text_ai_core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/text_ai_module/text_ai_core/general_text_ai_req.py b/src/modules/text_ai_module/text_ai_core/general_text_ai_req.py
new file mode 100644
index 0000000..2a0f89a
--- /dev/null
+++ b/src/modules/text_ai_module/text_ai_core/general_text_ai_req.py
@@ -0,0 +1,837 @@
+"""
+通用大语言模型调用框架
+支持本地模型(Ollama, LM Studio, llama.cpp)和云端模型(OpenAI, Anthropic, Google, Azure等)
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Union, Any, Iterator
+import json
+import os
+from loguru import logger
+from dataclasses import dataclass, asdict, field
+from enum import Enum
+import httpx
+from pydantic import BaseModel, Field
+
+class ModelProvider(Enum):
+ """Supported model providers; each value is the provider's lowercase name."""
+ OPENAI = "openai"
+ ANTHROPIC = "anthropic"
+ GOOGLE = "google"
+ AZURE = "azure"
+ OLLAMA = "ollama"
+ LM_STUDIO = "lm_studio"
+ LLAMA_CPP = "llama_cpp"
+ CUSTOM = "custom"
+
+
+@dataclass
+class ModelConfig:
+ """Connection and sampling settings handed to the LLM clients in this module."""
+ # Which backend to talk to, and the model identifier sent with each request.
+ provider: ModelProvider
+ model_name: str
+ api_key: Optional[str] = None
+ base_url: Optional[str] = None
+ # NOTE(review): no client visible in this file reads api_version.
+ api_version: Optional[str] = None
+ # Sampling defaults; the clients let per-call kwargs override them.
+ temperature: float = 0.7
+ max_tokens: int = 1024
+ top_p: float = 1.0
+ frequency_penalty: float = 0.0
+ presence_penalty: float = 0.0
+ # Passed as the client timeout - presumably seconds; confirm.
+ timeout: int = 30
+ # Default for the per-call "streaming" kwarg.
+ streaming: bool = False
+ # default_factory gives every instance its own dict.
+ custom_headers: Dict[str, str] = field(default_factory=dict)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Return the fields as a dict via dataclasses.asdict.
+
+ ``provider`` stays a ModelProvider member, so the result is not directly
+ JSON-serializable.
+ """
+ return asdict(self)
+
+
+@dataclass
+class ChatMessage:
+ """One chat message in role/content form."""
+ role: str # system, user, assistant
+ # Plain text, or a list of content-part dicts. The UI-TARS client added in
+ # this patch builds a ChatMessage with text + image_url parts (presumably
+ # this same class - confirm), so the annotation allows both.
+ content: Union[str, List[Dict[str, Any]]]
+ name: Optional[str] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Return the message as a dict; "name" is included only when it is truthy."""
+ return {
+ "role": self.role,
+ "content": self.content,
+ **({"name": self.name} if self.name else {})
+ }
+
+
+@dataclass
+class ModelResponse:
+ """Result of one model call, or one chunk of a streamed call."""
+ content: str # response text
+ model: str # model name
+ usage: Optional[Dict[str, int]] = None # token usage
+ finish_reason: Optional[str] = None # why generation stopped
+ raw_response: Optional[Dict] = None # unprocessed provider response
+
+
+def normalize_usage(raw_usage: Optional[Dict[str, Any]], provider: ModelProvider) -> Optional[Dict[str, int]]:
+ """
+ 将不同平台的 usage 字段统一归一化为 OpenAI 标准格式
+
+ Args:
+ raw_usage: API 原始返回的 usage 数据
+ provider: 模型提供商枚举
+
+ Returns:
+ 归一化后的 usage 字典,格式:
+ {
+ "prompt_tokens": int,
+ "completion_tokens": int,
+ "total_tokens": int
+ }
+ 如果无法归一化则返回 None
+ """
+ if not raw_usage:
+ return None
+
+ # 字段映射表:{provider: (input_key, output_key, total_key)}
+ USAGE_FIELD_MAP = {
+ ModelProvider.OPENAI: ("prompt_tokens", "completion_tokens", "total_tokens"),
+ ModelProvider.AZURE: ("prompt_tokens", "completion_tokens", "total_tokens"),
+ ModelProvider.LM_STUDIO: ("prompt_tokens", "completion_tokens", "total_tokens"),
+ ModelProvider.LLAMA_CPP: ("prompt_tokens", "completion_tokens", "total_tokens"),
+ ModelProvider.OLLAMA: ("prompt_eval_count", "eval_count", None), # Ollama 没有 total
+ ModelProvider.ANTHROPIC: ("input_tokens", "output_tokens", None),
+ ModelProvider.GOOGLE: ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"),
+ }
+
+ input_key, output_key, total_key = USAGE_FIELD_MAP.get(provider, (None, None, None))
+
+ if input_key is None:
+ logger.warning(f"未知的 provider '{provider}',无法归一化 usage")
+ return None
+
+ try:
+ # 提取字段值
+ prompt_tokens = raw_usage.get(input_key, 0)
+ completion_tokens = raw_usage.get(output_key, 0)
+
+ # 处理嵌套字典(如有些 API 的 usage 格式特殊)
+ if isinstance(prompt_tokens, dict):
+ prompt_tokens = prompt_tokens.get("value", 0)
+ if isinstance(completion_tokens, dict):
+ completion_tokens = completion_tokens.get("value", 0)
+
+ # 转换为整数
+ prompt_tokens = int(prompt_tokens) if prompt_tokens else 0
+ completion_tokens = int(completion_tokens) if completion_tokens else 0
+
+ # 计算 total(如果 API 没提供)
+ if total_key and total_key in raw_usage:
+ total_tokens = int(raw_usage[total_key])
+ else:
+ total_tokens = prompt_tokens + completion_tokens
+
+ normalized = {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": total_tokens
+ }
+
+ logger.debug(f"归一化 usage | {provider} -> OpenAI格式: {normalized}")
+ return normalized
+
+ except Exception as e:
+ logger.error(f"归一化 usage 失败: {e} | raw_usage: {raw_usage}")
+ return None
+
+class BaseLLMClient(ABC):
+ """大语言模型客户端基类"""
+
+ def __init__(self, config: ModelConfig):
+ self.config = config
+ self.client = None
+ self._initialize_client()
+
+ @abstractmethod
+ def _initialize_client(self):
+ """初始化客户端"""
+ pass
+
+ @abstractmethod
+ def chat_completion(
+ self,
+ messages: List[Union[ChatMessage, Dict]],
+ **kwargs
+ ) -> Union[ModelResponse, Iterator[ModelResponse]]:
+ """聊天补全"""
+ pass
+
+ @abstractmethod
+ def completion(
+ self,
+ prompt: str,
+ **kwargs
+ ) -> Union[ModelResponse, Iterator[ModelResponse]]:
+ """文本补全"""
+ pass
+
+ def format_messages(self, messages: List[Union[ChatMessage, Dict]]) -> List[Dict]:
+ """格式化消息列表"""
+ formatted = []
+ for msg in messages:
+ if isinstance(msg, ChatMessage):
+ formatted.append(msg.to_dict())
+ else:
+ formatted.append(msg)
+ return formatted
+
+
+class OpenAIClient(BaseLLMClient):
+ """OpenAI客户端"""
+
+ def _initialize_client(self):
+ try:
+ from openai import OpenAI
+
+ api_key = self.config.api_key
+ if not api_key:
+ raise ValueError("OpenAI API密钥未设置")
+
+ self.client = OpenAI(
+ api_key=api_key,
+ base_url=self.config.base_url,
+ timeout=self.config.timeout
+ )
+ logger.info(f"OpenAI客户端初始化成功,base_url: {self.config.base_url}")
+ except ImportError:
+ logger.error("请安装openai包: pip install openai")
+ raise
+
+ def chat_completion(self, messages, **kwargs):
+ formatted_messages = self.format_messages(messages)
+
+ # 获取streaming参数,优先使用kwargs中的设置
+ streaming = kwargs.get("streaming", self.config.streaming)
+
+ # 合并配置
+ params = {
+ "model": self.config.model_name,
+ "messages": formatted_messages,
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "frequency_penalty": kwargs.get("frequency_penalty", self.config.frequency_penalty),
+ "presence_penalty": kwargs.get("presence_penalty", self.config.presence_penalty),
+ "stream": streaming, # 使用正确的streaming设置
+ }
+
+ logger.info(f"🔧 调用参数: streaming={streaming}")
+
+ if streaming:
+ return self._stream_chat_completion(params)
+ else:
+ return self._normal_chat_completion(params)
+
+ def _normal_chat_completion(self, params):
+ """非流式响应处理"""
+ logger.info("📡 发送非流式请求...")
+ response = self.client.chat.completions.create(**params)
+ raw_usage = response.usage
+ normalized_usage = normalize_usage(
+ raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
+ ModelProvider.OPENAI
+ )
+ return ModelResponse(
+ content=response.choices[0].message.content,
+ model=response.model,
+ usage=normalized_usage,
+ finish_reason=response.choices[0].finish_reason,
+ raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
+ )
+
+ def _stream_chat_completion(self, params):
+ """流式响应处理"""
+ logger.info("📡 发送流式请求...")
+ response_stream = self.client.chat.completions.create(**params)
+
+ full_content = ""
+ for chunk in response_stream:
+ if chunk.choices[0].delta.content is not None:
+ content_chunk = chunk.choices[0].delta.content
+ full_content += content_chunk
+ yield ModelResponse(
+ content=content_chunk,
+ model=chunk.model,
+ raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
+ )
+
+ def completion(self, prompt, **kwargs):
+ # OpenAI 推荐使用 chat_completion,这里保持兼容
+ messages = [{"role": "user", "content": prompt}]
+ return self.chat_completion(messages, **kwargs)
+
+
+class AnthropicClient(BaseLLMClient):
+ """Anthropic Claude客户端"""
+
+ def _initialize_client(self):
+ try:
+ from anthropic import Anthropic
+ self.client = Anthropic(
+ api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
+ timeout=self.config.timeout
+ )
+ except ImportError:
+ logger.error("请安装anthropic包: pip install anthropic")
+ raise
+
+ def chat_completion(self, messages, **kwargs):
+ formatted_messages = self.format_messages(messages)
+
+ # Claude 的消息格式转换
+ claude_messages = []
+ system_message = None
+
+ for msg in formatted_messages:
+ if msg["role"] == "system":
+ system_message = msg["content"]
+ else:
+ claude_messages.append({
+ "role": msg["role"],
+ "content": msg["content"]
+ })
+ # 明确获取streaming参数
+
+ params = {
+ "model": self.config.model_name,
+ "messages": claude_messages,
+ "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "stream": kwargs.get("streaming", self.config.streaming),
+ }
+
+ if system_message:
+ params["system"] = system_message
+
+ if self.config.streaming:
+ return self._stream_chat_completion(params)
+ else:
+ return self._normal_chat_completion(params)
+
+ def _normal_chat_completion(self, params):
+ response = self.client.messages.create(**params)
+ # Anthropic 返回的usage格式和OpenAI不同,需要进行转换
+ raw_usage = response.usage
+ normalized_usage = normalize_usage(
+ raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
+ ModelProvider.ANTHROPIC
+ )
+ return ModelResponse(
+ content=response.content[0].text,
+ model=response.model,
+ usage=normalized_usage,
+ finish_reason=response.stop_reason,
+ raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
+ )
+
+ def _stream_chat_completion(self, params):
+ with self.client.messages.stream(**params) as stream:
+ for chunk in stream:
+ if chunk.type_ == "content_block_delta":
+ yield ModelResponse(
+ content=chunk.delta.text,
+ model=params["model"],
+ raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
+ )
+
+ def completion(self, prompt, **kwargs):
+ messages = [{"role": "user", "content": prompt}]
+ return self.chat_completion(messages, **kwargs)
+
+
+class OllamaClient(BaseLLMClient):
+ """Ollama本地模型客户端"""
+
+ def _initialize_client(self):
+ import httpx
+ self.base_url = self.config.base_url or "http://localhost:11434"
+ self.client = httpx.Client(
+ base_url=self.base_url,
+ timeout=self.config.timeout
+ )
+
+ def chat_completion(self, messages, **kwargs):
+ formatted_messages = self.format_messages(messages)
+
+ payload = {
+ "model": self.config.model_name,
+ "messages": formatted_messages,
+ "options": {
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
+ },
+ "stream": kwargs.get("streaming", self.config.streaming),
+ }
+
+ if self.config.streaming:
+ return self._stream_chat_completion(payload)
+ else:
+ return self._normal_chat_completion(payload)
+
+ def _normal_chat_completion(self, payload):
+ response = self.client.post("/api/chat", json=payload)
+ response.raise_for_status()
+ data = response.json()
+ normalized_usage = normalize_usage(data, ModelProvider.OLLAMA)
+ return ModelResponse(
+ content=data["message"]["content"],
+ model=data["model"],
+ usage=normalized_usage,
+ finish_reason=data.get("done_reason"),
+ raw_response=data
+ )
+
+ def _stream_chat_completion(self, payload):
+ with self.client.stream("POST", "/api/chat", json=payload) as response:
+ for line in response.iter_lines():
+ if line.strip():
+ try:
+ data = json.loads(line)
+ if "message" in data and "content" in data["message"]:
+ yield ModelResponse(
+ content=data["message"]["content"],
+ model=data.get("model", self.config.model_name),
+ raw_response=data
+ )
+ except json.JSONDecodeError:
+ continue
+
+ def completion(self, prompt, **kwargs):
+ payload = {
+ "model": self.config.model_name,
+ "prompt": prompt,
+ "options": {
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
+ },
+ "stream": kwargs.get("streaming", self.config.streaming),
+ }
+
+ if self.config.streaming:
+ return self._stream_completion(payload)
+ else:
+ return self._normal_completion(payload)
+
+ def _normal_completion(self, payload):
+ response = self.client.post("/api/generate", json=payload)
+ response.raise_for_status()
+ data = response.json()
+
+ return ModelResponse(
+ content=data["response"],
+ model=data["model"],
+ usage={
+ "prompt_tokens": data.get("prompt_eval_count", 0),
+ "completion_tokens": data.get("eval_count", 0),
+ "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
+ },
+ finish_reason=data.get("done_reason"),
+ raw_response=data
+ )
+
+ def _stream_completion(self, payload):
+ with self.client.stream("POST", "/api/generate", json=payload) as response:
+ for line in response.iter_lines():
+ if line.strip():
+ try:
+ data = json.loads(line)
+ if "response" in data:
+ yield ModelResponse(
+ content=data["response"],
+ model=data.get("model", self.config.model_name),
+ raw_response=data
+ )
+ except json.JSONDecodeError:
+ continue
+
+
+class GenericLLMClient(BaseLLMClient):
+ """通用HTTP客户端,支持LM Studio和其他兼容OpenAI API的本地模型"""
+
+ def _initialize_client(self):
+ import httpx
+ self.base_url = self.config.base_url or "http://localhost:1234/v1"
+ self.client = httpx.Client(
+ base_url=self.base_url,
+ timeout=self.config.timeout,
+ headers=self.config.custom_headers
+ )
+
+ def chat_completion(self, messages, **kwargs):
+ formatted_messages = self.format_messages(messages)
+
+ # 明确获取 streaming 参数
+ streaming = kwargs.get("streaming", self.config.streaming)
+
+ payload = {
+ "model": self.config.model_name,
+ "messages": formatted_messages,
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "stream": streaming, # 使用明确的 streaming 变量
+ }
+
+ logger.info(f"GenericLLMClient 参数: streaming={streaming}")
+
+ if streaming:
+ return self._stream_chat_completion(payload)
+ else:
+ return self._normal_chat_completion(payload)
+
+ def _normal_chat_completion(self, payload):
+ logger.info(f"GenericLLMClient 发送非流式请求到: {self.base_url}")
+ response = self.client.post("/chat/completions", json=payload)
+ response.raise_for_status()
+ data = response.json()
+
+ logger.info(f"GenericLLMClient 收到响应,模型: {data.get('model')}")
+
+ return ModelResponse(
+ content=data["choices"][0]["message"]["content"],
+ model=data["model"],
+ usage=data.get("usage"),
+ finish_reason=data["choices"][0].get("finish_reason"),
+ raw_response=data
+ )
+
+ def _stream_chat_completion(self, payload):
+ logger.info(f"GenericLLMClient 发送流式请求到: {self.base_url}")
+ with self.client.stream("POST", "/chat/completions", json=payload) as response:
+ for line in response.iter_lines():
+ if line.startswith("data: "):
+ chunk = line[6:]
+ if chunk == "[DONE]":
+ break
+ try:
+ data = json.loads(chunk)
+ if data["choices"][0]["delta"].get("content"):
+ yield ModelResponse(
+ content=data["choices"][0]["delta"]["content"],
+ model=data["model"],
+ raw_response=data
+ )
+ except json.JSONDecodeError:
+ continue
+
+ def completion(self, prompt, **kwargs):
+ payload = {
+ "model": self.config.model_name,
+ "prompt": prompt,
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "stream": kwargs.get("streaming", self.config.streaming),
+ }
+
+ streaming = kwargs.get("streaming", self.config.streaming)
+
+ if streaming:
+ return self._stream_completion(payload)
+ else:
+ return self._normal_completion(payload)
+
+ def _normal_completion(self, payload):
+ response = self.client.post("/completions", json=payload)
+ response.raise_for_status()
+ data = response.json()
+
+ return ModelResponse(
+ content=data["choices"][0]["text"],
+ model=data["model"],
+ usage=data.get("usage"),
+ finish_reason=data["choices"][0].get("finish_reason"),
+ raw_response=data
+ )
+
+ def _stream_completion(self, payload):
+ with self.client.stream("POST", "/completions", json=payload) as response:
+ for line in response.iter_lines():
+ if line.startswith("data: "):
+ chunk = line[6:]
+ if chunk == "[DONE]":
+ break
+ try:
+ data = json.loads(chunk)
+ if data["choices"][0].get("text"):
+ yield ModelResponse(
+ content=data["choices"][0]["text"],
+ model=data["model"],
+ raw_response=data
+ )
+ except json.JSONDecodeError:
+ continue
+
+class UnifiedLLM:
+ """
+ 统一的大语言模型调用类
+
+ 支持多种模型提供商:
+ - 云端模型:OpenAI, Anthropic, Google, Azure
+ - 本地模型:Ollama, LM Studio, llama.cpp
+
+ 使用示例:
+ ```python
+ # 初始化OpenAI客户端
+ config = ModelConfig(
+ provider=ModelProvider.OPENAI,
+ model_name="gpt-4",
+ api_key="your-api-key"
+ )
+ llm = UnifiedLLM(config)
+
+ # 聊天补全
+ messages = [
+ {"role": "system", "content": "你是一个有用的助手"},
+ {"role": "user", "content": "你好!"}
+ ]
+ response = llm.chat(messages)
+ print(response.content)
+
+ # 流式响应
+ config.streaming = True
+ llm.update_config(config)
+ for chunk in llm.chat(messages):
+ print(chunk.content, end="", flush=True)
+ ```
+ """
+
+ def __init__(self, config: ModelConfig):
+ """
+ 初始化统一LLM
+
+ Args:
+ config: 模型配置
+ """
+ self.config = config
+ self.client = self._create_client()
+ logger.info(f" UnifiedLLM 初始化完成")
+ logger.info(f" 提供商: {config.provider}")
+ logger.info(f" 模型: {config.model_name}")
+ logger.info(f" 流式默认: {config.streaming}")
+
+ def _create_client(self) -> BaseLLMClient:
+ """根据配置创建客户端"""
+ provider = self.config.provider
+
+ if provider == ModelProvider.OPENAI:
+ return OpenAIClient(self.config)
+ elif provider == ModelProvider.ANTHROPIC:
+ return AnthropicClient(self.config)
+ elif provider == ModelProvider.OLLAMA:
+ return OllamaClient(self.config)
+ elif provider in [ModelProvider.LM_STUDIO, ModelProvider.LLAMA_CPP, ModelProvider.CUSTOM]:
+ return GenericLLMClient(self.config)
+ elif provider == ModelProvider.GOOGLE:
+ # 这里可以扩展Google Gemini支持
+ return GenericLLMClient(self.config)
+ elif provider == ModelProvider.AZURE:
+ # Azure OpenAI需要特殊处理
+ return GenericLLMClient(self.config)
+ else:
+ raise ValueError(f"不支持的模型提供商: {provider}")
+
+ def update_config(self, config: ModelConfig):
+ """更新配置并重新创建客户端"""
+ self.config = config
+ self.client = self._create_client()
+ logger.info(f"UnifiedLLM 配置已更新")
+
+ def chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
+ """
+ 聊天补全
+
+ Args:
+ messages: 消息列表
+ **kwargs: 其他参数,会覆盖config中的设置
+
+ Returns:
+ ModelResponse 或 ModelResponse 的迭代器(流式模式)
+ """
+ # 明确获取streaming参数
+ streaming = kwargs.get("streaming", self.config.streaming)
+ logger.info(f" UnifiedLLM.chat() 调用")
+ logger.info(f" 消息数: {len(messages)}")
+ logger.info(f" streaming参数: {streaming}")
+
+ # 调用客户端
+ result = self.client.chat_completion(messages, **kwargs)
+
+ # 类型检查(调试用)
+ if streaming:
+ if not hasattr(result, '__iter__'):
+ logger.warning(f"警告: streaming=True 但返回的不是迭代器")
+ else:
+ if hasattr(result, '__iter__'):
+ logger.warning(f"警告: streaming=False 但返回的是迭代器")
+ elif not isinstance(result, ModelResponse):
+ logger.warning(f"警告: streaming=False 但返回的不是ModelResponse")
+
+ return result
+
+ def complete(self, prompt: str, **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
+ """
+ 文本补全
+
+ Args:
+ prompt: 提示文本
+ **kwargs: 其他参数
+
+ Returns:
+ ModelResponse 或 ModelResponse 的迭代器(流式模式)
+ """
+ streaming = kwargs.get("streaming", self.config.streaming)
+ logger.info(f" UnifiedLLM.complete() 调用")
+ logger.info(f" prompt长度: {len(prompt)}")
+ logger.info(f" streaming参数: {streaming}")
+
+ return self.client.completion(prompt, **kwargs)
+
+ def stream_chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Iterator[ModelResponse]:
+ """
+ 流式聊天补全(便捷方法)
+
+ Args:
+ messages: 消息列表
+ **kwargs: 其他参数
+
+ Returns:
+ ModelResponse 的迭代器
+ """
+ kwargs["streaming"] = True
+ logger.info(f"UnifiedLLM.stream_chat() 调用")
+
+ result = self.chat(messages, **kwargs)
+
+ # 确保返回的是迭代器
+ if not hasattr(result, '__iter__'):
+ raise TypeError("stream_chat 应该返回迭代器,但返回了其他类型")
+
+ return result
+
+ def stream_complete(self, prompt: str, **kwargs) -> Iterator[ModelResponse]:
+ """
+ 流式文本补全(便捷方法)
+
+ Args:
+ prompt: 提示文本
+ **kwargs: 其他参数
+
+ Returns:
+ ModelResponse 的迭代器
+ """
+ kwargs["streaming"] = True
+ logger.info(f"UnifiedLLM.stream_complete() 调用")
+
+ result = self.complete(prompt, **kwargs)
+
+ # 确保返回的是迭代器
+ if not hasattr(result, '__iter__'):
+ raise TypeError("stream_complete 应该返回迭代器,但返回了其他类型")
+
+ return result
+
+
+# 快捷函数
+def create_llm_client(
+ provider: Union[str, ModelProvider],
+ model_name: str,
+ **kwargs
+) -> UnifiedLLM:
+ """
+ 快捷创建LLM客户端
+
+ Args:
+ provider: 提供商名称或枚举
+ model_name: 模型名称
+ **kwargs: 其他配置参数
+
+ Returns:
+ UnifiedLLM 实例
+ """
+ if isinstance(provider, str):
+ provider = ModelProvider(provider.lower())
+
+ config = ModelConfig(
+ provider=provider,
+ model_name=model_name,
+ **kwargs
+ )
+
+ return UnifiedLLM(config)
+
+
+# 使用示例
+def example_usage():
+ """Run three demo calls and print the results.
+
+ Each example is wrapped in its own try/except so that one failing
+ backend only prints an error and the remaining examples still run.
+ """
+
+ # Example 1: OpenAI
+ print("示例1: 使用OpenAI")
+ openai_config = ModelConfig(
+ provider=ModelProvider.OPENAI,
+ model_name="gpt-3.5-turbo",
+ api_key=os.getenv("OPENAI_API_KEY"),
+ temperature=0.8
+ )
+
+ try:
+ openai_llm = UnifiedLLM(openai_config)
+ messages = [
+ ChatMessage(role="system", content="你是一个有用的助手"),
+ ChatMessage(role="user", content="请用Python写一个Hello World程序")
+ ]
+ response = openai_llm.chat(messages)
+ print(f"响应: {response.content[:100]}...")
+ except Exception as e:
+ print(f"OpenAI示例错误: {e}")
+
+ # Example 2: Ollama (local model)
+ print("\n示例2: 使用Ollama(本地模型)")
+ ollama_config = ModelConfig(
+ provider=ModelProvider.OLLAMA,
+ model_name="llama2",
+ base_url="http://localhost:11434",
+ temperature=0.7,
+ streaming=True # streaming responses
+ )
+
+ try:
+ ollama_llm = UnifiedLLM(ollama_config)
+ messages = [
+ {"role": "user", "content": "什么是人工智能?"}
+ ]
+
+ print("流式响应:")
+ for chunk in ollama_llm.stream_chat(messages):
+ print(chunk.content, end="", flush=True)
+ print()
+ except Exception as e:
+ print(f"Ollama示例错误: {e}(请确保Ollama服务正在运行)")
+
+ # Example 3: the create_llm_client shortcut
+ print("\n示例3: 使用快捷函数")
+ try:
+ llm = create_llm_client(
+ provider="openai",
+ model_name="gpt-3.5-turbo",
+ api_key=os.getenv("OPENAI_API_KEY"),
+ temperature=0.5
+ )
+
+ response = llm.complete("天空为什么是蓝色的?")
+ print(f"响应: {response.content[:100]}...")
+ except Exception as e:
+ print(f"快捷函数示例错误: {e}")
\ No newline at end of file
diff --git a/src/modules/text_ai_module/text_ai_core/readme.md b/src/modules/text_ai_module/text_ai_core/readme.md
new file mode 100644
index 0000000..cc3fcd4
--- /dev/null
+++ b/src/modules/text_ai_module/text_ai_core/readme.md
@@ -0,0 +1,162 @@
+# 大语言模型调用框架架构图
+general_text_ai_req.py
+```mermaid
+graph TB
+ subgraph "应用层"
+ A[用户应用] --> B[UnifiedLLM 统一接口]
+ end
+
+ subgraph "适配器层"
+ B --> C{模型提供商路由}
+ C --> D[OpenAI 适配器]
+ C --> E[Anthropic 适配器]
+ C --> F[Ollama 适配器]
+ C --> G[通用HTTP适配器]
+ C --> H[其他适配器]
+ end
+
+ subgraph "服务层"
+ D --> I[OpenAI API]
+ E --> J[Anthropic API]
+ F --> K[Ollama 服务]
+ G --> L[LM Studio]
+ G --> M[llama.cpp]
+ G --> N[其他兼容API]
+ end
+
+ subgraph "配置层"
+ O[ModelConfig] --> C
+ O --> D
+ O --> E
+ O --> F
+ O --> G
+ end
+
+ subgraph "数据流"
+ P[输入: 消息/提示] --> B
+ I --> Q[输出: ModelResponse]
+ J --> Q
+ K --> Q
+ L --> Q
+ M --> Q
+ N --> Q
+ end
+
+ style A fill:#4567f1
+ style B fill:#4567f1
+ style O fill:#456748
+ style D fill:#457911
+ style E fill:#466bd5
+ style F fill:#4567f1
+ style G fill:#4567f1
+```
+
+# 数据流图
+```mermaid
+sequenceDiagram
+ participant User as 用户/应用
+ participant UnifiedLLM as UnifiedLLM
+ participant Adapter as 适配器
+ participant API as API服务
+
+ User->>UnifiedLLM: 调用chat()或complete()
+ UnifiedLLM->>UnifiedLLM: 根据配置选择适配器
+ UnifiedLLM->>Adapter: 转发请求
+ Adapter->>API: 发送HTTP请求/API调用
+ Note over API: 处理请求并生成响应
+
+ alt 流式模式
+ API-->>Adapter: 流式响应数据
+ Adapter-->>UnifiedLLM: 流式ModelResponse
+ UnifiedLLM-->>User: 迭代器返回分块响应
+ else 非流式模式
+ API-->>Adapter: 完整响应
+ Adapter-->>UnifiedLLM: ModelResponse对象
+ UnifiedLLM-->>User: 完整响应内容
+ end
+```
+# 类关系图
+```mermaid
+classDiagram
+ class ModelConfig {
+ +provider: ModelProvider
+ +model_name: str
+ +api_key: Optional[str]
+ +base_url: Optional[str]
+ +temperature: float
+ +max_tokens: int
+ +to_dict() Dict
+ }
+
+ class ChatMessage {
+ +role: str
+ +content: str
+ +name: Optional[str]
+ +to_dict() Dict
+ }
+
+ class ModelResponse {
+ +content: str
+ +model: str
+ +usage: Optional[Dict]
+ +finish_reason: Optional[str]
+ +raw_response: Optional[Dict]
+ }
+
+ class BaseLLMClient {
+ <<abstract>>
+ #config: ModelConfig
+ #client: Any
+ +__init__(config: ModelConfig)
+ +_initialize_client()
+ +chat_completion(messages, **kwargs)*
+ +completion(prompt, **kwargs)*
+ +format_messages(messages) List[Dict]
+ }
+
+ class UnifiedLLM {
+ -config: ModelConfig
+ -client: BaseLLMClient
+ +__init__(config: ModelConfig)
+ +_create_client() BaseLLMClient
+ +update_config(config: ModelConfig)
+ +chat(messages, **kwargs) ModelResponse
+ +complete(prompt, **kwargs) ModelResponse
+ +stream_chat(messages, **kwargs) Iterator
+ +stream_complete(prompt, **kwargs) Iterator
+ }
+
+ class OpenAIClient {
+ +_initialize_client()
+ +chat_completion(messages, **kwargs)
+ +completion(prompt, **kwargs)
+ }
+
+ class AnthropicClient {
+ +_initialize_client()
+ +chat_completion(messages, **kwargs)
+ +completion(prompt, **kwargs)
+ }
+
+ class OllamaClient {
+ +_initialize_client()
+ +chat_completion(messages, **kwargs)
+ +completion(prompt, **kwargs)
+ }
+
+ class GenericLLMClient {
+ +_initialize_client()
+ +chat_completion(messages, **kwargs)
+ +completion(prompt, **kwargs)
+ }
+
+ BaseLLMClient <|-- OpenAIClient
+ BaseLLMClient <|-- AnthropicClient
+ BaseLLMClient <|-- OllamaClient
+ BaseLLMClient <|-- GenericLLMClient
+ UnifiedLLM o-- BaseLLMClient
+ UnifiedLLM --> ModelConfig
+ BaseLLMClient --> ModelConfig
+ BaseLLMClient --> ChatMessage
+ BaseLLMClient --> ModelResponse
+```
\ No newline at end of file
diff --git a/src/modules/tts_module/__init__.py b/src/modules/tts_module/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/tts_module/readme.md b/src/modules/tts_module/readme.md
new file mode 100644
index 0000000..eacada0
--- /dev/null
+++ b/src/modules/tts_module/readme.md
@@ -0,0 +1,26 @@
+本模块为tts模块,即文本转语音模块。
+
+本模块负责将来自AI的回复转为语音。
+
+说明:
+在本module当中,每个子模块的用途分别是:
+- tts_core 对不同的tts的实现,提供相对统一的接口
+ - gpt_sovits
+ 实现了gpt_sovits的tts接口封装
+
+
+async_audio_player.py
+```mermaid
+sequenceDiagram
+ participant TTS as GPT-SoVITS API
+ participant WS as WebSocket服务
+ participant Buffer as 音频缓冲区
+ participant Player as 音频播放器
+
+ TTS->>WS: 流式音频块(chunks)
+ WS->>Buffer: 写入队列(Queue)
+ Buffer->>Player: 消费PCM数据
+ Player->>声卡: 实时播放
+
+ Note over TTS,Player: 三重缓冲 + 动态采样率检测
+```
\ No newline at end of file
diff --git a/src/modules/tts_module/tts_core/__init__.py b/src/modules/tts_module/tts_core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/tts_module/tts_core/async_audio_player.py b/src/modules/tts_module/tts_core/async_audio_player.py
new file mode 100644
index 0000000..fd42b5a
--- /dev/null
+++ b/src/modules/tts_module/tts_core/async_audio_player.py
@@ -0,0 +1,164 @@
+# tts_core/async_audio_player.py
+import asyncio
+import io
+from loguru import logger
+from typing import Optional
+import numpy as np
+import sounddevice as sd
+import wave
+
+class AsyncAudioPlayer:
+ """
+ 异步流式音频播放器
+ - 自动检测WAV头并解析采样率
+ - 使用环形缓冲区确保播放流畅
+ - 支持动态音频格式切换
+ """
+
+ def __init__(self, buffer_size: int = 10):
+ """
+ Args:
+ buffer_size: 音频块缓冲数量(越大越稳定,但延迟越高)
+ """
+ self.audio_queue = asyncio.Queue(maxsize=buffer_size)
+ self.sample_rate = 32000 # 默认采样率
+ self.channels = 1
+ self.dtype = np.float32
+ self.stream: Optional[sd.OutputStream] = None
+ self.is_playing = False
+ self._first_chunk_processed = False
+ logger.info(f"🎵 音频播放器初始化,缓冲区大小: {buffer_size}")
+
+ async def add_chunk(self, audio_data: bytes):
+ """
+ 添加音频块到播放队列
+ 自动处理第一个chunk(包含WAV头)
+ """
+ try:
+ # 第一个chunk需要解析WAV头
+ if not self._first_chunk_processed:
+ # 写入BytesIO以便wave模块读取
+ wav_buffer = io.BytesIO(audio_data)
+ try:
+ with wave.open(wav_buffer, 'rb') as wav_file:
+ # 解析WAV头信息
+ self.sample_rate = wav_file.getframerate()
+ self.channels = wav_file.getnchannels()
+ self.sampwidth = wav_file.getsampwidth()
+
+ # 读取PCM数据(去掉头部)
+ pcm_data = wav_file.readframes(wav_file.getnframes())
+
+ logger.info(f"📊 解析WAV头: {self.sample_rate}Hz, {self.channels}ch, {self.sampwidth * 8}bit")
+
+ # 转换为numpy数组
+ if self.sampwidth == 2:
+ audio_array = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
+ elif self.sampwidth == 4:
+ audio_array = np.frombuffer(pcm_data, dtype=np.int32).astype(np.float32) / 2147483648.0
+ else:
+ raise ValueError(f"不支持的采样宽度: {self.sampwidth}")
+
+ # 转单声道(如果多声道)
+ if self.channels > 1:
+ audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
+
+ await self.audio_queue.put(audio_array)
+ self._first_chunk_processed = True
+
+ except wave.Error:
+ # 可能是不完整的WAV头,尝试直接播放
+ logger.warning("⚠️ WAV头解析失败,尝试直接播放")
+ await self._play_raw(audio_data)
+ return
+ else:
+ # 后续chunk直接播放(RAW PCM)
+ await self._play_raw(audio_data)
+
+ except Exception as e:
+ logger.error(f"❌ 音频块处理失败: {e}")
+
+ async def _play_raw(self, audio_data: bytes):
+ """播放RAW PCM数据"""
+ try:
+ # 假设是16位PCM(最常见)
+ audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+
+ # 如果是多声道数据(罕见)
+ if len(audio_array) % self.channels == 0 and self.channels > 1:
+ audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
+
+ await self.audio_queue.put(audio_array)
+ except Exception as e:
+ logger.error(f"❌ RAW音频处理失败: {e}")
+
+ async def play_worker(self):
+ """后台播放任务"""
+ logger.info("🎧 音频播放任务启动")
+
+ while self.is_playing or not self.audio_queue.empty():
+ try:
+ # 从队列获取音频块(最多等待0.5秒)
+ audio_chunk = await asyncio.wait_for(self.audio_queue.get(), timeout=0.5)
+
+ # 延迟初始化音频流(直到获得第一个数据块)
+ if self.stream is None:
+ logger.info(f"🔊 打开音频输出流: {self.sample_rate}Hz")
+ self.stream = sd.OutputStream(
+ samplerate=self.sample_rate,
+ channels=1,
+ dtype=self.dtype,
+ blocksize=1024, # 低延迟模式
+ latency='low'
+ )
+ self.stream.start()
+
+ # 写入音频流播放
+ self.stream.write(audio_chunk)
+
+ # 标记任务完成
+ self.audio_queue.task_done()
+
+ except asyncio.TimeoutError:
+ continue
+ except Exception as e:
+ logger.error(f"❌ 播放任务异常: {e}")
+ break
+
+ logger.info("🛑 音频播放任务结束")
+
+ async def start(self):
+ """启动播放系统"""
+ self.is_playing = True
+ self._first_chunk_processed = False
+ self.play_task = asyncio.create_task(self.play_worker())
+
+ async def stop(self):
+ """停止播放并清理资源"""
+ self.is_playing = False
+
+ # 等待播放任务结束
+ if hasattr(self, 'play_task'):
+ await self.play_task
+
+ # 关闭音频流
+ if self.stream is not None:
+ self.stream.stop()
+ self.stream.close()
+ self.stream = None
+
+ # 清空队列
+ while not self.audio_queue.empty():
+ try:
+ self.audio_queue.get_nowait()
+ except:
+ break
+
+ logger.info("✅ 音频播放已停止")
+
+ async def __aenter__(self):
+ await self.start()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.stop()
\ No newline at end of file
diff --git a/src/modules/tts_module/tts_core/gpt_sovits/__init__.py b/src/modules/tts_module/tts_core/gpt_sovits/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/tts_module/tts_core/gpt_sovits/gpt_sovits_client.py b/src/modules/tts_module/tts_core/gpt_sovits/gpt_sovits_client.py
new file mode 100644
index 0000000..96e0bf7
--- /dev/null
+++ b/src/modules/tts_module/tts_core/gpt_sovits/gpt_sovits_client.py
@@ -0,0 +1,378 @@
+# gpt_sovits/gpt_sovits_client.py
+import asyncio
+import json
+from loguru import logger
+from enum import Enum
+from pathlib import Path
+from typing import AsyncGenerator, Optional, Union, Dict, Any
+from dataclasses import dataclass
+
+import httpx
+from pydantic import BaseModel, Field, validator
+
+
+class APIError(Exception):
+ """API调用异常"""
+ def __init__(self, status_code: int, message: str):
+ self.status_code = status_code
+ self.message = message
+ super().__init__(f"API Error {status_code}: {message}")
+
+
+class StreamingMode(Enum):
+    """Streaming modes for TTS requests; the integer value is what gets sent in the request body."""
+    DISABLED = 0        # non-streaming
+    BEST_QUALITY = 1    # best quality (slow)
+    MEDIUM_QUALITY = 2  # medium quality
+    FASTEST = 3         # fastest response (lower quality)
+
+
+class TTSConfig(BaseModel):
+ """TTS请求配置模型"""
+ text: str = Field(..., description="待合成文本")
+ text_lang: str = Field(..., description="文本语言: zh/en/ja/ko/cantonese")
+ ref_audio_path: str = Field(..., description="参考音频路径")
+ prompt_lang: str = Field(..., description="提示文本语言")
+
+ # 可选参数
+ prompt_text: str = Field(default="", description="参考音频提示文本")
+ aux_ref_audio_paths: list = Field(default_factory=list, description="辅助参考音频")
+ top_k: int = Field(default=5, ge=1, le=100, description="Top-K采样")
+ top_p: float = Field(default=1.0, ge=0.1, le=1.0, description="Top-P采样")
+ temperature: float = Field(default=1.0, ge=0.1, le=1.0, description="采样温度")
+ text_split_method: str = Field(default="cut5", description="文本分割方法") # 默认按照标点符号切分
+ batch_size: int = Field(default=8, ge=1, le=200, description="批处理大小")
+ speed_factor: float = Field(default=1.0, ge=0.6, le=1.65, description="语速倍率")
+
+ # 流式相关
+ streaming_mode: Union[bool, int, StreamingMode] = Field(default=False, description="流式模式")
+ media_type: str = Field(default="wav", description="输出格式: wav/raw/ogg/aac") # 输出格式
+
+ # 高级参数
+ repetition_penalty: float = Field(default=1.35, ge=1.0, le=2.0) # 惩罚参数
+ sample_steps: int = Field(default=32, ge=10, le=100) # 采样步数
+ parallel_infer: bool = Field(default=True) # 并行推理
+
+ @validator('text_lang', 'prompt_lang')
+ def validate_language(cls, v):
+ """验证语言代码"""
+ valid_langs = {'zh', 'en', 'ja', 'ko', 'cantonese'}
+ if v.lower() not in valid_langs:
+ raise ValueError(f"Unsupported language: {v}. Must be one of {valid_langs}")
+ return v.lower()
+
+ @validator('media_type')
+ def validate_media_type(cls, v):
+ """验证媒体类型"""
+ valid_types = {'wav', 'raw', 'ogg', 'aac'}
+ if v not in valid_types:
+ raise ValueError(f"Unsupported media_type: {v}")
+ return v
+
+ def build_request(self) -> Dict[str, Any]:
+ """构建API请求数据"""
+ data = self.dict(exclude_none=True)
+ # 处理流式模式
+ if isinstance(self.streaming_mode, StreamingMode):
+ data['streaming_mode'] = self.streaming_mode.value
+ return data
+
+
+@dataclass
+class AudioResponse:
+ """音频响应包装类"""
+ audio_data: bytes
+ sample_rate: int = 32000
+
+ def save(self, path: Union[str, Path]) -> None:
+ """保存音频文件"""
+ path = Path(path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, 'wb') as f:
+ f.write(self.audio_data)
+ logger.info(f"Audio saved to {path}, size: {len(self.audio_data)} bytes")
+
+
+class GPTSoVITSClient:
+ """
+ GPT-SoVITS异步API客户端
+
+ 完整支持所有TTS功能:
+ - 文本合成(流式/非流式)
+ - 模型切换(GPT/SoVITS)
+ - 参考音频设置
+ - 服务器控制
+ """
+
+ def __init__(self, host: str = "127.0.0.1", port: int = 9880, debug: bool = False):
+ """
+ 初始化客户端
+
+ Args:
+ host: API服务器地址
+ port: API端口
+ debug: 是否开启调试模式
+ """
+ self.base_url = f"http://{host}:{port}"
+ self.client = httpx.AsyncClient(
+ base_url=self.base_url,
+ timeout=httpx.Timeout(30.0, connect=5.0)
+ )
+ self.debug_mode = debug
+ logger.info(f"GPT-SoVITS Client initialized: {self.base_url}")
+
+ async def __aenter__(self) -> "GPTSoVITSClient":
+ """异步上下文管理器入口"""
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """异步上下文管理器出口"""
+ await self.close()
+
+ async def close(self):
+ """关闭HTTP连接"""
+ await self.client.aclose()
+ logger.info("Client connection closed")
+
+ def _log_debug(self, message: str, **kwargs):
+ """调试日志"""
+ if self.debug_mode:
+ logger.debug(f"{message} | {kwargs}")
+
+ async def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
+ """统一响应处理"""
+ if response.status_code == 200:
+ content_type = response.headers.get('content-type', '')
+ if 'application/json' in content_type:
+ return response.json()
+ return {"status": "success", "content": response.content}
+ else:
+ try:
+ error_data = response.json()
+ raise APIError(response.status_code, error_data.get('message', 'Unknown error'))
+ except json.JSONDecodeError:
+ raise APIError(response.status_code, response.text)
+
+ # 核心TTS接口
+ async def tts(
+ self,
+ text: str,
+ ref_audio_path: str,
+ text_lang: str = "zh",
+ prompt_lang: str = "zh",
+ streaming_mode: StreamingMode = StreamingMode.DISABLED, # 默认禁用流式
+ media_type: str = "wav",
+ **kwargs
+ ) -> Union[AudioResponse, AsyncGenerator[AudioResponse, None]]:
+ """
+ 文本转语音(支持流式)
+
+ Args:
+ text: 待合成文本
+ ref_audio_path: 参考音频路径(服务器本地路径或URL)
+ text_lang: 文本语言
+ prompt_lang: 提示语言
+ streaming_mode: 流式模式
+ media_type: 输出格式
+ **kwargs: 其他TTS参数
+
+ Returns:
+ 非流式: AudioResponse对象
+ 流式: AsyncGenerator[AudioResponse, None]异步生成器
+
+ Example:
+ # 非流式
+ audio = await client.tts("你好", "ref.wav")
+
+ # 流式
+ async for chunk in client.tts("你好", "ref.wav", streaming_mode=StreamingMode.FASTEST):
+ process(chunk.audio_data)
+ """
+ config = TTSConfig(
+ text=text,
+ ref_audio_path=ref_audio_path,
+ text_lang=text_lang,
+ prompt_lang=prompt_lang,
+ streaming_mode=streaming_mode,
+ media_type=media_type,
+ **kwargs
+ )
+
+ self._log_debug("TTS Request", config=config.dict())
+
+ if streaming_mode == StreamingMode.DISABLED:
+ # 非流式模式
+ response = await self.client.post("/tts", json=config.build_request())
+ if response.status_code != 200:
+ raise APIError(response.status_code, await response.text())
+
+ return AudioResponse(
+ audio_data=response.content,
+ sample_rate=32000 # 默认采样率
+ )
+ else:
+ # 流式模式
+ config.parallel_infer = False # 强制关闭并行推理,避免与流式冲突
+ config.batch_size = 1 # 流式下batch_size必须为1
+ async def stream_generator():
+ async with self.client.stream(
+ "POST", "/tts",
+ json=config.build_request(),
+ timeout=httpx.Timeout(60.0) # 流式需要更长超时
+ ) as response:
+ if response.status_code != 200:
+ raise APIError(response.status_code, await response.aread())
+
+ async for chunk in response.aiter_bytes():
+ if chunk:
+ yield AudioResponse(audio_data=chunk)
+
+ return stream_generator()
+
+ # 模型管理接口
+ async def set_gpt_weights(self, weights_path: str) -> bool:
+ """
+ 切换GPT模型权重
+
+ Args:
+ weights_path: 权重文件路径(服务器本地路径)
+
+ Returns:
+ bool: 是否成功
+
+ Example:
+ await client.set_gpt_weights("models/s1bert.ckpt")
+ """
+ if not weights_path:
+ raise ValueError("weights_path cannot be empty")
+
+ params = {"weights_path": weights_path}
+ response = await self.client.get("/set_gpt_weights", params=params)
+ result = await self._handle_response(response)
+
+ logger.info(f"GPT weights switched to: {weights_path}")
+ return True
+
+ async def set_sovits_weights(self, weights_path: str) -> bool:
+ """
+ 切换SoVITS模型权重
+
+ Args:
+ weights_path: 权重文件路径(服务器本地路径)
+
+ Returns:
+ bool: 是否成功
+ """
+ if not weights_path:
+ raise ValueError("weights_path cannot be empty")
+
+ params = {"weights_path": weights_path}
+ response = await self.client.get("/set_sovits_weights", params=params)
+ await self._handle_response(response)
+
+ logger.info(f"SoVITS weights switched to: {weights_path}")
+ return True
+
+ # 参考音频管理
+ async def set_refer_audio(
+ self,
+ audio_source: Union[str, Path, bytes],
+ audio_name: Optional[str] = None
+ ) -> bool:
+ """
+ 设置参考音频(支持多种输入方式)
+
+ Args:
+ audio_source: 音频文件路径(str/Path)或音频数据(bytes)
+ audio_name: 音频文件名(仅bytes输入时需要)
+
+ Returns:
+ bool: 是否成功
+
+ Example:
+ # 方式1: 服务器本地文件
+ await client.set_refer_audio("/path/to/audio.wav")
+
+ # 方式2: 上传音频数据
+ with open("audio.wav", "rb") as f:
+ await client.set_refer_audio(f.read(), "audio.wav")
+ """
+ if isinstance(audio_source, (str, Path)):
+ # GET方式:服务器本地路径
+ params = {"refer_audio_path": str(audio_source)}
+ response = await self.client.get("/set_refer_audio", params=params)
+ await self._handle_response(response)
+ logger.info(f"Reference audio set: {audio_source}")
+ else:
+ # POST方式:上传音频数据
+ if not audio_name:
+ raise ValueError("audio_name is required when uploading bytes")
+
+ files = {"audio_file": (audio_name, audio_source, "audio/wav")}
+ response = await self.client.post("/set_refer_audio", files=files)
+ await self._handle_response(response)
+ logger.info(f"Reference audio uploaded: {audio_name}")
+
+ return True
+
+ # 服务器控制
+ async def control_command(self, command: str) -> bool:
+ """
+ 发送控制命令
+
+ Args:
+ command: 命令类型 - "restart" 或 "exit"
+
+ Returns:
+ bool: 是否成功
+
+ Warning:
+ "exit"命令会终止API服务器进程!
+ """
+ if command not in ["restart", "exit"]:
+ raise ValueError("Command must be 'restart' or 'exit'")
+
+ response = await self.client.get("/control", params={"command": command})
+ await self._handle_response(response)
+
+ logger.warning(f"Control command executed: {command}")
+ return True
+
+ # 高级快捷方法
+ async def get_server_info(self) -> Dict[str, Any]:
+ """获取服务器状态信息"""
+ # 通过调用根路径或自定义health接口
+ try:
+ response = await self.client.get("/")
+ return {"status": "online", "detail": response.text}
+ except Exception as e:
+ return {"status": "error", "detail": str(e)}
+
+ async def batch_tts(
+ self,
+ texts: list[str],
+ ref_audio_path: str,
+ **kwargs
+ ) -> list[AudioResponse]:
+ """
+ 批量TTS合成
+
+ Args:
+ texts: 文本列表
+ ref_audio_path: 参考音频
+ **kwargs: 其他TTS参数
+
+ Returns:
+ list[AudioResponse]: 音频响应列表
+ """
+ tasks = [
+ self.tts(text, ref_audio_path, **kwargs)
+ for text in texts
+ ]
+ return await asyncio.gather(*tasks)
+
+
+# Helper for building a client (the returned client can be used with ``async with``)
+async def create_client(*args, **kwargs) -> GPTSoVITSClient:
+    """Create a GPTSoVITSClient, forwarding all arguments to its constructor."""
+    return GPTSoVITSClient(*args, **kwargs)
\ No newline at end of file
diff --git a/src/modules/websocket_base_module/__init__.py b/src/modules/websocket_base_module/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/websocket_base_module/dto/__init__.py b/src/modules/websocket_base_module/dto/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/websocket_base_module/dto/dto_base.py b/src/modules/websocket_base_module/dto/dto_base.py
new file mode 100644
index 0000000..c9c3e63
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/dto_base.py
@@ -0,0 +1,27 @@
+# dto/dto_base.py
+from abc import ABC
+from typing import Callable, Coroutine, Any, Dict
+from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer
+
+class MessageDTO(ABC):
+    """Base class for DTOs: stores the WebSocket server and exposes its send methods."""
+
+    def __init__(
+            self,
+            ws_server : WebSocketServer,  # the WebSocketServer singleton
+    ):
+        # Keep the server instance; it is what the send properties below delegate to
+        self.ws_server = ws_server
+
+    # Convenience properties so the DTO layer can call the server's send methods directly
+    @property
+    def send_binary(self):
+        return self.ws_server.send_binary
+
+    @property
+    def send_text(self):
+        return self.ws_server.send_text
+
+    @property
+    def send_json(self):
+        return self.ws_server.send_json
diff --git a/src/modules/websocket_base_module/dto/dto_templates/__init__.py b/src/modules/websocket_base_module/dto/dto_templates/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/websocket_base_module/dto/dto_templates/audio_data_dto.py b/src/modules/websocket_base_module/dto/dto_templates/audio_data_dto.py
new file mode 100644
index 0000000..e4683d7
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/dto_templates/audio_data_dto.py
@@ -0,0 +1,95 @@
+from pydantic import Field, BaseModel
+import base64
+from datetime import datetime, timezone
+class AudioDataTransferObject(BaseModel):
+ """
+ 音频数据传输对象
+ 该对象被用于服务端与客户端的音频数据交互
+ 同时支持流式与非流式的音频数据
+ 同时收发对等(通过Owner标识)
+ """
+ Owner: str = Field(default="server", description="音频数据的拥有者(server or client)")
+ isStream: bool = Field(default=False, description="音频数据是否为流式数据")
+ isStart: bool = Field(default=False, description="音频数据是否开始(流式时有效)")
+ isEnd: bool = Field(default=False, description="音频数据是否结束(流式时有效)")
+ sequence: int = Field(default=0, description="音频数据块序列号(流式时有效)")
+ data: bytes = Field(default=b"", description="音频数据,流式时为分块数据,base64编码")
+ sampleRate: int = Field(default=32000, description="音频采样率")
+ channelCount: int = Field(default=1, description="音频通道数")
+ bitDepth: int = Field(default=16, description="音频采样位数")
+ duration: float = Field(default=0.0, description="音频时长")
+ text: str = Field(default="", description="音频对应的文本")
+
+ def set_dto_data(self, **kwargs) -> "AudioDataTransferObject":
+ """链式更新数据(Pydantic 风格)"""
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ return self
+
+ def to_json(self) -> dict:
+ """
+ 将DTO对象转换为可序列化字典
+ 返回的json数据的格式:
+ {
+ "type": "audio_data",
+ "timestamp": 1672531200.0,
+ "data": {
+ "Owner": "server",
+ ......
+ }
+ }
+ """
+ # model_dump() 是 Pydantic v2 的序列化方法
+ payload = self.model_dump() # 获取所有模型字段
+ payload["data"] = base64.b64encode(payload["data"]).decode() # base64编码
+ # 构造嵌套结构
+ return {
+ "type": "audio_data",
+ "timestamp": datetime.now(timezone.utc).timestamp(),
+ "data": payload # 音频字段嵌套
+ }
+
+ @classmethod
+ def from_json(cls, json_data: dict) -> "AudioDataTransferObject":
+ """
+ 从JSON数据创建DTO对象
+ 传入的json数据格式:
+ {
+ "Owner": "server",
+ ......
+ }
+ """
+ payload = json_data
+ # 解码 base64 (内层 data 字段在传输时是 base64 字符串) -> bytes
+ if "data" in payload and isinstance(payload["data"], str):
+ payload["data"] = base64.b64decode(payload["data"])
+ # 构造对象 Pydantic 自动忽略 type/timestamp
+ return cls.model_validate(payload)
+
+
+# 测试代码
+if __name__ == "__main__":
+ # 模拟音频数据
+ import os
+
+ test_audio = os.urandom(1024) # 随机生成1KB音频数据
+
+ # 创建DTO
+ audio = AudioDataTransferObject(
+ data=test_audio,
+ sequence=1,
+ isStream=True,
+ sampleRate=44100,
+ duration=0.5
+ )
+
+ # 序列化 → JSON
+ json_dict = audio.to_json()
+ print(f"序列化后 data 长度: {len(json_dict['data'])}") # ~1368 字符
+ print(f"data 前30字符: {json_dict['data'][:30]}...")
+
+ # 反序列化 → DTO
+ restored = AudioDataTransferObject.from_json(json_dict)
+ print(f"反序列化后 data 长度: {len(restored.data)}") # 1024 bytes
+ print(f"数据一致: {restored.data == test_audio}") # True
diff --git a/src/modules/websocket_base_module/dto/dto_templates/auto_agent_data_dto.py b/src/modules/websocket_base_module/dto/dto_templates/auto_agent_data_dto.py
new file mode 100644
index 0000000..a992326
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/dto_templates/auto_agent_data_dto.py
@@ -0,0 +1,73 @@
+from pydantic import Field, BaseModel
+from datetime import datetime, timezone
+class AutoAgentDataTransferObject(BaseModel):
+ """
+ 自动化agent数据传输对象
+ 该对象被用于服务端向客户端发送控制信息
+ """
+ Action: str = Field(default="", description="自动化动作名称")
+ x1: int = Field(default=-1, description="鼠标起始位置x1")
+ y1: int = Field(default=-1, description="鼠标起始位置y1")
+ x2: int = Field(default=-1, description="鼠标结束位置x2")
+ y2: int = Field(default=-1, description="鼠标结束位置y2")
+ key: str = Field(default="", description="快捷键")
+ content: str = Field(default="", description="输入文本内容")
+ direction: str = Field(default="", description="滚动方向")
+
+ def set_dto_data(self, **kwargs) -> "AutoAgentDataTransferObject":
+ """链式更新数据(Pydantic 风格)"""
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ return self
+
+ def to_json(self) -> dict:
+ """
+ 将DTO对象转换为可序列化字典
+ 返回的json数据的格式:
+ {
+ "type": "auto_agent",
+ "timestamp": 1672531200.0,
+ "data": {
+ "Action": "自动化agent返回的相应的自动化动作名称",
+ "x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
+ "y1": "某个操作的y1",
+ "x2": "某个操作的x2",
+ "y2": "某个操作的y2",
+ "key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
+ "content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
+ "direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
+ }
+ }
+ """
+ # model_dump() 是 Pydantic v2 的序列化方法
+ payload = self.model_dump() # 获取所有模型字段
+ # 构造嵌套结构
+ return {
+ "type": "auto_agent",
+ "timestamp": datetime.now(timezone.utc).timestamp(),
+ "data": payload # 字段嵌套
+ }
+
+ @classmethod
+ def from_json(cls, json_data: dict) -> "AutoAgentDataTransferObject":
+ """
+ 从JSON数据创建DTO对象
+ 传入的json数据格式:
+ {
+ "type": "auto_agent",
+ "Action": "自动化agent返回的相应的自动化动作名称",
+ "x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
+ "y1": "某个操作的y1",
+ "x2": "某个操作的x2",
+ "y2": "某个操作的y2",
+ "key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
+ "content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
+ "direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
+ }
+ """
+ payload = json_data
+ payload.pop("type", None) # 移除多余字段
+ # 构造对象 Pydantic 自动忽略 type/timestamp
+ return cls.model_validate(payload)
+
diff --git a/src/modules/websocket_base_module/dto/dto_templates/data_dto_base.py b/src/modules/websocket_base_module/dto/dto_templates/data_dto_base.py
new file mode 100644
index 0000000..df32b46
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/dto_templates/data_dto_base.py
@@ -0,0 +1,37 @@
+class BaseDataTransferObject:
+    """
+    Base class for DTOs.
+    Subclasses override the methods they need; every method here is a no-op returning None.
+    """
+    def __init__(self):
+        pass
+    def to_json(self):
+        """
+        Convert the DTO to JSON.
+        """
+        pass
+    def from_json(self, json_data):
+        """
+        Create a DTO from JSON data.
+        """
+        pass
+    def to_binary(self):
+        """
+        Convert the DTO to binary.
+        """
+        pass
+    def from_binary(self, binary_data):
+        """
+        Create a DTO from binary data.
+        """
+        pass
+    def to_text(self):
+        """
+        Convert the DTO to text.
+        """
+        pass
+    def from_text(self, text_data):
+        """
+        Create a DTO from text data.
+        """
+        pass
\ No newline at end of file
diff --git a/src/modules/websocket_base_module/dto/dto_templates/screenshot_data_dto.py b/src/modules/websocket_base_module/dto/dto_templates/screenshot_data_dto.py
new file mode 100644
index 0000000..e49ea33
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/dto_templates/screenshot_data_dto.py
@@ -0,0 +1,72 @@
+from pydantic import Field, BaseModel
+from datetime import datetime, timezone
+class ScreenShotDataTransferObject(BaseModel):
+ """
+ 服务端向客户端请求实时截图的数据传输对象
+ 服务端与客户端收发对等(通过Owner标识)
+ 客户端收到这个type的包,就会自动对当前设备的画面进行截图
+ """
+ Owner: str = Field(default="server", description="数据的拥有者(server or client)")
+ isSuccess: bool = Field(default=False, description="是否截图成功")
+ RealTimeScreenShot: str = Field(default="", description="客户端设备的实时截图数据(base64)")
+ Width: int = Field(default=1920, description="截图的宽度")
+ Height: int = Field(default=1080, description="截图的高度")
+ DescribeInfo: str = Field(default="", description="设备的描述信息(告知模型以做出更加准确的判断)")
+ LLMResponse: str = Field(default="", description="LLM的响应结果(由服务端发送时携带)")
+
+
+ def set_dto_data(self, **kwargs) -> "ScreenShotDataTransferObject":
+ """链式更新数据(Pydantic 风格)"""
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ return self
+
+ def to_json(self) -> dict:
+ """
+ 将DTO对象转换为可序列化字典
+ 返回的json数据的格式:
+ {
+ "type": "screenshot_data",
+ "timestamp": 1672531200.0,
+ "data": {
+ "Owner": "数据的拥有者(server or client)",
+ "isSuccess": "是否截图成功(true or false)"
+ "RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
+ "Width": "截图的宽度",
+ "Height": "截图的高度",
+ "DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)",
+ "LLMResponse": "LLM的响应结果(由服务端发送时携带)"
+ }
+ }
+ """
+ # model_dump() 是 Pydantic v2 的序列化方法
+ payload = self.model_dump() # 获取所有模型字段
+ # 构造嵌套结构
+ return {
+ "type": "screenshot_data",
+ "timestamp": datetime.now(timezone.utc).timestamp(),
+ "data": payload # 字段嵌套
+ }
+
+ @classmethod
+ def from_json(cls, json_data: dict) -> "ScreenShotDataTransferObject":
+ """
+ 从JSON数据创建DTO对象
+ 传入的json数据格式:
+ {
+ "Owner": "数据的拥有者(server or client)",
+ "isSuccess": "是否截图成功(true or false)",
+ "RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
+ "Width": "截图的宽度 非必要字段",
+ "Height": "截图的高度 非必要字段",
+ "DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断) 非必要字段",
+ "LLMResponse": "LLM的响应结果(由服务端发送时携带) 必要字段"
+ }
+ """
+ payload = json_data
+ payload.pop("type", None) # 移除多余字段
+ payload.pop("timestamp", None) # 移除多余字段
+ # 构造对象 Pydantic 自动忽略 type/timestamp
+ return cls.model_validate(payload)
+
diff --git a/src/modules/websocket_base_module/dto/second_dtos.py b/src/modules/websocket_base_module/dto/second_dtos.py
new file mode 100644
index 0000000..822cd85
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/second_dtos.py
@@ -0,0 +1,95 @@
+# dto/second_dtos.py
+import asyncio
+from typing import Callable, Optional, Dict, Any, List, Coroutine
+from src.modules.websocket_base_module.dto.dto_base import MessageDTO
+from loguru import logger
+
+"""
+二级分发器,因为没有信号与槽机制,因此使用观察者模式替代
+"""
+def singleton(cls): # 单例
+ _instance = None
+ _lock = asyncio.Lock()
+ async def get_instance(*args, **kwargs):
+ nonlocal _instance
+ if _instance is None:
+ async with _lock:
+ if _instance is None:
+ _instance = cls(*args, **kwargs)
+ return _instance
+
+ cls.get_instance = get_instance
+ return cls
+# 类型别名
+ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
+@singleton
+class JsonDTO(MessageDTO):
+ """针对json消息的二级分发"""
+ """
+ 明确业务json格式:
+ {
+ "type" : "xxx",
+ "timestamp" : "95153...",
+ "data" : "{根据业务的不同,有不同的内容}"
+ }
+ """
+ # 因为不同的数据块的json以type字段进行包装,根据type进行正确的数据分发
+ def __init__(self, ws_server):
+ super().__init__(ws_server)
+ self.receivers : Dict[str, List[ReceiveCallback]] = {
+ 'audio_data' : [], # 音频数据
+ 'screenshot_data' : [] # 截图数据
+ }
+ # 注册json处理callback function
+ ws_server.register_receiver('json', self._handle_json)
+ logger.info("[JsonDTO] JSON分发器已注册")
+
+ def register_receiver(self, types : str, callback : ReceiveCallback):
+ """注册二次分发业务接收函数,供业务DTO调用"""
+ if types in self.receivers:
+ self.receivers[types].append(callback)
+ logger.debug(f"[JsonDTO] 已注册 {types} 接收器,当前共 {len(self.receivers[types])} 个")
+ else:
+ raise ValueError(f"[JsonDTO] 不支持的分发类型: {types}")
+
+ def unregister_receiver(self, types : str, callback : ReceiveCallback):
+ """注销二次分发业务接收函数"""
+ if callback in self.receivers[types]:
+ self.receivers[types].remove(callback)
+ logger.debug(f"[JsonDTO] 已注销 {types} 接收器")
+
+ async def _handle_json(self, data: dict):
+ """JSON消息处理"""
+ logger.info(f"[JsonDTO] 收到消息")
+ logger.debug(f'[JsonDTO] 当前消息时间戳: {data["timestamp"]}')
+ # 根据类型进行自动分发
+ await self._dispatch(data.get("type"), data["data"])
+
+ async def _dispatch(self, types : str, data : dict):
+ """二次分发json数据到相应的接收函数当中"""
+ callbacks = self.receivers[types] # 获取相关types的所有观察者
+ if not callbacks:
+ logger.info(f"[JsonDTO] 无 {types} 接收器,消息被忽略")
+ return
+ # 并发执行所有回调
+ tasks = [callback(data) for callback in callbacks]
+ await asyncio.gather(*tasks, return_exceptions=True)
+async def get_json_dto_instance(ws_server) -> JsonDTO:
+ return JsonDTO(ws_server)
+
+
+class EchoDTO(MessageDTO):
+    """Echo DTO: handles text messages only; meant for testing."""
+
+    def __init__(self, ws_server):
+        super().__init__(ws_server)
+        # Register the text handler with the server
+        ws_server.register_receiver('text', self._handle_text)
+        logger.info("[EchoDTO] 文本接收器已注册")
+
+    async def _handle_text(self, message: str):
+        """Log the received text and send it back prefixed with "Echo: "."""
+        logger.info(f"[EchoDTO] 收到文本: {message}")
+
+        # Business logic: echo the message back
+        await self.send_text(f"Echo: {message}")
diff --git a/src/modules/websocket_base_module/dto/third_dtos.py b/src/modules/websocket_base_module/dto/third_dtos.py
new file mode 100644
index 0000000..ae3cb11
--- /dev/null
+++ b/src/modules/websocket_base_module/dto/third_dtos.py
@@ -0,0 +1,281 @@
+import asyncio
+from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
+from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
+from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
+from loguru import logger
+from typing import Callable, List, Optional, Coroutine
+
+class AudioDataDTO:
+ """音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)"""
+ def __init__(self, json_dto : JsonDTO):
+ json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数
+ logger.info("[AudioDataDTO] 音频接收业务已注册")
+ self.json_dto = json_dto
+ self.audio_data = AudioDataTransferObject() # 音频数据对象
+ # 业务回调列表,延续观察者模式
+ self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
+ # 最新音频缓存 支持同步查询
+ self._latest_audio: Optional[AudioDataTransferObject] = None
+ # 流式缓冲区 用于大段音频流
+ self._stream_buffer: List[AudioDataTransferObject] = []
+ logger.info("[AudioDataDTO] 业务接口已初始化")
+
+ async def _handle_audio_data(self, data: dict):
+ """处理音频数据"""
+ """
+ 音频数据json格式:
+ {
+ 'Owner': 'server',
+ 'isStream': False,
+ 'isStart': False,
+ 'isEnd': False,
+ 'sequence': 0,
+ 'data': '',
+ 'sampleRate': 16000,
+ 'channelCount': 1,
+ 'bitDepth': 16,
+ 'duration': 0.0,
+ 'text': ''
+ }
+ """
+ logger.debug(f"[AudioDataDTO] 收到音频数据")
+ # 将dict反序列化到DTO对象
+ self.audio_data = AudioDataTransferObject.from_json(data)
+ # 缓存最新数据
+ self._latest_audio = self.audio_data
+ # 如果是流式数据,加入缓冲区
+ if self.audio_data.isStream:
+ self._stream_buffer.append(self.audio_data)
+ if self.audio_data.isEnd:
+ logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)} 块")
+ # 通知所有注册的回调
+ await self._notify_callbacks()
+
+ # 业务发送接口
+ async def send_audio_data(self, data: AudioDataTransferObject) -> None:
+ """
+ 发送音频数据
+
+ Args:
+ data: 音频数据DTO
+ """
+ await self.send_audio(
+ Owner=data.Owner,
+ is_stream=data.isStream,
+ is_start=data.isStart,
+ is_end=data.isEnd,
+ sequence=data.sequence,
+ data=data.data,
+ sampleRate=data.sampleRate,
+ channelCount=data.channelCount,
+ bitDepth=data.bitDepth,
+ duration=data.duration,
+ text=data.text
+ )
+
+ async def send_audio(
+ self,
+ data: bytes,
+ is_stream: bool = False,
+ is_start: bool = False,
+ is_end: bool = False,
+ sequence: int = 0,
+ **audio_meta
+ ) -> None:
+ """
+ 业务层发送音频的便捷接口
+
+ Args:
+ data: 原始音频字节
+ is_stream: 是否为流式数据
+ is_start: 流式数据开始标记
+ is_end: 流式数据结束标记
+ sequence: 数据块序号
+ **audio_meta: 其他音频参数(sampleRate, channelCount等)
+ """
+ # 填充音频数据到DTO
+ self.audio_data.set_dto_data(
+ Owner="server" or audio_meta.get('Owner', "server"),
+ isStream=is_stream,
+ isStart=is_start,
+ isEnd=is_end,
+ sequence=sequence,
+ data=data,
+ sampleRate=audio_meta.get('sampleRate', 16000),
+ channelCount=audio_meta.get('channelCount', 1),
+ bitDepth=audio_meta.get('bitDepth', 16),
+ duration=audio_meta.get('duration', 0.0),
+ text=audio_meta.get('text', "")
+ )
+ # 序列化为JSON并发送 自动处理base64和type字段
+ json_message = self.audio_data.to_json()
+ await self.json_dto.send_json(json_message)
+ logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")
+
+ # 业务接收接口
+ def register_audio_callback(
+ self,
+ callback: Callable[[AudioDataTransferObject], Coroutine]
+ ) -> None:
+ """
+ 业务注册接收回调
+
+ 使用示例:
+ async def my_audio_handler(audio_dto: AudioDataTransferObject):
+ print(f"收到音频: {len(audio_dto.data)} bytes")
+
+ audio_dto.register_audio_callback(my_audio_handler)
+ """
+ self._audio_callbacks.append(callback)
+ logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)} 个")
+
+ def unregister_audio_callback(self, callback) -> None:
+ """注销业务回调"""
+ if callback in self._audio_callbacks:
+ self._audio_callbacks.remove(callback)
+ logger.debug("业务音频回调已注销")
+
+ def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
+ """
+ 同步获取最新音频数据(轮询模式)
+
+ Returns:
+ 最新接收到的音频DTO,如果没有则为 None
+ """
+ return self._latest_audio
+
+ def clear_stream_buffer(self) -> None:
+ """清空流式缓冲区"""
+ self._stream_buffer.clear()
+ logger.debug("流式音频缓冲区已清空")
+
+ # 内部通知机制
+ async def _notify_callbacks(self) -> None:
+ """通知所有业务回调"""
+ if not self._audio_callbacks:
+ logger.warning("无业务回调,音频数据未处理")
+ return
+ tasks = [callback(self.audio_data) for callback in self._audio_callbacks]
+ await asyncio.gather(*tasks, return_exceptions=True)
+ logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调")
+ # 异步迭代器 流式时使用
+ def __aiter__(self):
+ """支持 async for 循环接收流式音频"""
+ return self
+ async def __anext__(self) -> AudioDataTransferObject:
+ """异步迭代器协议"""
+ pass
+
+class ScreenShotDataDTO:
+ """截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)"""
+ def __init__(self, json_dto : JsonDTO):
+ json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数
+ logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
+ self.json_dto = json_dto
+ self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象
+ # 业务回调列表,延续观察者模式
+ self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
+ # 最新截屏数据缓存 支持同步查询
+ self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
+ logger.info("[ScreenShotDataDTO] 业务接口已初始化")
+
+ async def _handle_screenshot_data(self, data: dict):
+ """处理截屏数据"""
+ """
+ 截屏数据json格式:
+ {
+ "Owner": "数据的拥有者(server or client)",
+ "isSuccess": "是否截图成功(true or false)"
+ "RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
+ "Width": "截图的宽度",
+ "Height": "截图的高度",
+ "DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)"
+ }
+ """
+ logger.debug(f"[ScreenShotDataDTO] 收到截屏数据")
+ # 将dict反序列化到DTO对象
+ self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
+ # 缓存最新数据
+ self._latest_screenshot = self.screenshot_data
+ # 通知所有注册的回调
+ await self._notify_callbacks()
+
+ # 业务发送接口
+ async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
+ """
+ 发送音频数据
+
+ Args:
+ data: 音频数据DTO
+ """
+ await self.send_screenshot(
+ Owner=data.Owner,
+ isSuccess=data.isSuccess,
+ RealTimeScreenShot=data.RealTimeScreenShot,
+ Width=data.Width,
+ Height=data.Height,
+ DescribeInfo=data.DescribeInfo,
+ LLMResponse=data.LLMResponse
+ )
+
+ async def send_screenshot(
+ self,
+ **screenshot_meta
+ ) -> None:
+ """
+ 业务层发送音频的便捷接口
+ 一般来说,作为发送请求方,不需要填充任何数据
+
+ Args:
+ **screenshot_meta: 截屏数据元信息
+ """
+ # 填充音频数据到DTO
+ self.screenshot_data.set_dto_data(
+ Owner="server" or screenshot_meta.get('Owner', "server"),
+ isSuccess=screenshot_meta.get('isSuccess', False),
+ RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
+ Width=screenshot_meta.get('Width', 1920),
+ Height=screenshot_meta.get('Height', 1080),
+ DescribeInfo=screenshot_meta.get('DescribeInfo', False),
+ LLMResponse=screenshot_meta.get('LLMResponse', "")
+ )
+ # 序列化为JSON并发送 自动处理base64和type字段
+ json_message = self.screenshot_data.to_json()
+ await self.json_dto.send_json(json_message)
+ logger.info(f"截屏包已发送")
+
+ # 业务接收接口
+ def register_screenshot_callback(
+ self,
+ callback: Callable[[ScreenShotDataTransferObject], Coroutine]
+ ) -> None:
+ """
+ 业务注册接收回调
+ """
+ self._screenshot_callbacks.append(callback)
+ logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)} 个")
+
+ def unregister_screenshot_callback(self, callback) -> None:
+ """注销业务回调"""
+ if callback in self._screenshot_callbacks:
+ self._screenshot_callbacks.remove(callback)
+ logger.debug("业务截屏回调已注销")
+
+ def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
+ """
+ 同步获取最新截屏数据(轮询模式)
+
+ Returns:
+ 最新接收到的截屏数据DTO,如果没有则为 None
+ """
+ return self._latest_screenshot
+
+ # 内部通知机制
+ async def _notify_callbacks(self) -> None:
+ """通知所有业务回调"""
+ if not self._screenshot_callbacks:
+ logger.warning("无业务回调,截屏数据未处理")
+ return
+ tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks]
+ await asyncio.gather(*tasks, return_exceptions=True)
+ logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调")
diff --git a/src/modules/websocket_base_module/readme.md b/src/modules/websocket_base_module/readme.md
new file mode 100644
index 0000000..6b9724b
--- /dev/null
+++ b/src/modules/websocket_base_module/readme.md
@@ -0,0 +1,56 @@
+### 说明
+在本module当中,每个子模块的用途分别是:
+- dto
+ - dto_templates
+ 服务端与客户端交互所使用到的数据传输对象
+ - dto_base.py / xxx_dtos.py
+ 实际的DTO业务
+- websocket_core
+ - websocket核心,承载了底层核心的网络收发业务
+
+
+### 模块架构
+```mermaid
+graph TB
+ subgraph "Client"
+ C[WebSocket Client]
+ end
+
+ subgraph "Core Layer"
+        WS[WebSocketServer<br/>单例]
+        WS -->|持有| WSP[WebSocketServerProtocol<br/>_websocket]
+        WS -->|管理| RCV[ receivers: Dict<br/>binary/text/json ]
+ end
+
+ subgraph "DTO Base Layer"
+        MDTO[MessageDTO<br/>抽象基类]
+        MDTO -->|注入| MDF[ send_binary<br/>send_text<br/>send_json ]
+ end
+
+ subgraph "Secondary Dispatcher"
+        JDTO[JsonDTO<br/>单例]
+ JDTO -->|继承| MDTO
+        JDTO -->|维护| MAP[ receivers: Dict<br/>audio_data/... ]
+ JDTO -->|注册到| WS
+ end
+
+ subgraph "Business DTO"
+        ADTO[AudioDataDTO......<br/>业务实现]
+ ADTO -->|持有引用| JDTO
+        ADTO -->|使用| ATO[AudioDataTransferObject<br/>Pydantic模型]
+ end
+
+ C <-->|websocket连接| WSP
+
+ WS -->|分发消息| JDTO
+ JDTO -->|二次分发| ADTO
+
+ ADTO -->|发送响应| JDTO
+ JDTO -->|调用| MDF
+ MDF -->|经由| WS
+ WS -->|发送至| C
+
+ style WS fill:#64f,stroke:#333,stroke-width:2px
+ style JDTO fill:#569,stroke:#333,stroke-width:2px
+ style ADTO fill:#38f,stroke:#333,stroke-width:2px
+```
\ No newline at end of file
diff --git a/src/modules/websocket_base_module/websocket_core/__init__.py b/src/modules/websocket_base_module/websocket_core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/modules/websocket_base_module/websocket_core/core_ws_server.py b/src/modules/websocket_base_module/websocket_core/core_ws_server.py
new file mode 100644
index 0000000..093ffd7
--- /dev/null
+++ b/src/modules/websocket_base_module/websocket_core/core_ws_server.py
@@ -0,0 +1,172 @@
+# websocket_core/core_ws_server.py
+import asyncio
+import json
+from typing import Callable, Optional, Dict, Any, List, Coroutine
+from websockets.asyncio.server import serve, ServerConnection
+from websockets.exceptions import ConnectionClosed
+from loguru import logger
+
+# 类型别名
+ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
+class WebSocketServer:
+ """WebSocket服务端核心模块(单例 + 单客户端)
+ 只管理一个客户端连接,DTO层注册接收函数,服务端只负责分发。
+ """
+ _instance: Optional["WebSocketServer"] = None
+ _lock = asyncio.Lock()
+
+    def __new__(cls):
+        """Return the shared instance, creating it (with ``_initialized = False``) on first use."""
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        # __new__ hands back the same object every time, so skip setup after the first call
+        if self._initialized:
+            return
+        self._websocket: Optional[ServerConnection] = None
+        self._receivers: Dict[str, List[ReceiveCallback]] = {
+            'binary': [],
+            'text': [],
+            'json': []
+        }
+        self._connected_event = asyncio.Event()
+        self._initialized = True
+        logger.info("WebSocketServer 单客户端分发器已初始化")
+
+ # DTO层注册接口
+    def register_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
+        """Register a receive callback for one message channel (called by the DTO layer).
+
+        Args:
+            msg_type: Channel name (binary/text/json); any other value raises ValueError.
+            callback: Receiver invoked as ``callback(data)``; it must return a coroutine.
+        """
+        if msg_type not in self._receivers:
+            raise ValueError(f"不支持的消息类型: {msg_type}")
+        # Known channel: append to its receiver list and report the new total.
+        self._receivers[msg_type].append(callback)
+        logger.debug(f"已注册 {msg_type} 接收器,当前 {msg_type} 类型接收器共 {len(self._receivers[msg_type])} 个")
+
+    def unregister_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
+        """Remove a registered callback; an unknown msg_type or callback is ignored (no KeyError)."""
+        if callback in self._receivers.get(msg_type, ()):
+            self._receivers[msg_type].remove(callback)
+            logger.debug(f"已注销 {msg_type} 接收器")
+
+ # 发送接口(供DTO层调用)
+    async def send_binary(self, data: bytes):
+        """Send binary data to the single client; raise RuntimeError when none is connected."""
+        connection = self._websocket
+        if not connection:
+            raise RuntimeError("客户端未连接")
+        await connection.send(data)
+        logger.trace(f"二进制数据已发送 (长度: {len(data)} bytes)")
+
+    async def send_text(self, data: str):
+        """Send text data to the single client; raise RuntimeError when none is connected."""
+        connection = self._websocket
+        if not connection:
+            raise RuntimeError("客户端未连接")
+        await connection.send(data)
+        logger.trace(f"文本数据已发送 (长度: {len(data)} chars)")
+
+    async def send_json(self, data: Dict[str, Any]):
+        """Serialize ``data`` with json.dumps and send it to the single client; errors are logged and re-raised."""
+        if self._websocket:
+            try:
+                logger.debug(f"准备发送JSON数据: {data}")
+                message = json.dumps(data)
+                await self._websocket.send(message)
+                logger.trace(f"JSON数据已发送: {data}")
+            except Exception as e:
+                logger.error(f"JSON数据发送失败: {e}")
+                raise
+        else:
+            raise RuntimeError("客户端未连接")
+
+ # 等待连接
+    async def wait_for_client(self):
+        """Block until the connected event is set, i.e. until a client connection is being handled."""
+        await self._connected_event.wait()
+        logger.info("客户端已就绪")
+
+ # 内部消息循环
+    async def _handle_client(self, websocket: ServerConnection):
+        """Run the receive loop for one client connection and route every message.
+
+        Bytes go to the 'binary' receivers; text that parses as JSON goes to the 'json'
+        receivers and any other text to the 'text' receivers. On exit the connection state
+        is reset only if this connection is still the active one, so an older connection
+        closing cannot disconnect a newer client that has replaced it.
+        """
+        self._websocket = websocket
+        self._connected_event.set()
+        client_info = f"{websocket.remote_address}" if hasattr(websocket, 'remote_address') else "unknown"
+        logger.info(f"客户端已连接: {client_info}")
+
+        try:
+            async for message in websocket:
+                # Route by payload type to every registered receiver of that channel.
+                if isinstance(message, bytes):
+                    await self._dispatch('binary', message)
+                elif isinstance(message, str):
+                    try:
+                        data = json.loads(message)
+                    except json.JSONDecodeError:
+                        # Not JSON: deliver the raw string to the text receivers.
+                        await self._dispatch('text', message)
+                    else:
+                        await self._dispatch('json', data)
+                else:
+                    logger.warning(f"未知消息类型: {type(message)}")
+            logger.info("客户端连接已正常关闭")
+        except ConnectionClosed as e:
+            logger.info(f"客户端连接已关闭: {e}")
+        except Exception as e:
+            logger.error(f"处理客户端时发生错误: {e}")
+        finally:
+            if self._websocket is websocket:
+                self._websocket = None
+                self._connected_event.clear()
+
+    async def _dispatch(self, msg_type: str, data: Any):
+        """Run every receiver registered for ``msg_type`` concurrently and log the ones that fail."""
+        callbacks = self._receivers[msg_type]
+        if not callbacks:
+            logger.warning(f"无 {msg_type} 接收器,消息被忽略")
+            return
+        # return_exceptions=True collects receiver errors instead of raising the first one.
+        results = await asyncio.gather(*(cb(data) for cb in callbacks), return_exceptions=True)
+        for failure in (r for r in results if isinstance(r, BaseException)):
+            logger.error(f"{msg_type} 接收器执行失败: {failure!r}")
+
+ # 启动服务器
+    async def run(self, host: str = "localhost", port: int = 8765, max_msg_size: int = 50*1024*1024):
+        """Start the WebSocket server and block forever; max_msg_size (default 50 MiB) is passed to serve() as max_size."""
+        logger.info(f"WebSocket服务器启动中... 等待客户端连接 ws://{host}:{port}")
+
+        async def handler_wrapper(connection):
+            logger.info(f"新连接请求: {connection.remote_address}")
+            await self._handle_client(connection)
+
+        try:
+            async with serve(
+                handler=handler_wrapper,  # wrapper adapts the handler signature
+                host=host,
+                port=port,
+                max_size=max_msg_size,
+            ):
+                logger.success(f"WebSocket服务器已启动在 ws://{host}:{port}")
+                await asyncio.Future()  # never completes: keeps the server running
+        except Exception as e:
+            logger.error(f"WebSocket服务器启动失败: {e}")
+            raise
+
+
+async def get_ws_server() -> WebSocketServer:
+    """Return the global WebSocketServer singleton."""
+    # WebSocketServer.__new__ hands back the same object on every construction.
+    return WebSocketServer()
\ No newline at end of file
diff --git a/src/server_core/core.py b/src/server_core/core.py
new file mode 100644
index 0000000..b231b74
--- /dev/null
+++ b/src/server_core/core.py
@@ -0,0 +1,306 @@
+# server_core/core.py
+
+"""
+统一业务,对外提供启动接口
+业务数据流向:
+ Yosuga[User Audio Info Struct] ->(WebSocket) Yosuga_server[asr_module] -> Text
+ Yosuga_server[Come from Yosuga Audio ASR Text] ->(Func call) Yosuga_server[llm_core] -> Ins and Text
+
+ Yosuga_server[Come from llm_core Text]->(WebSocket) Yosuga
+ Yosuga_server[Come from llm_core Text]->(Func call) Yosuga_server[TTS] -> Audio Data
+ Yosuga_server[Audio Data] ->(WebSocket) Yosuga
+
+ Yosuga_embedded[Devices Control Info] ->(WebSocket) Yosuga_server[embedded_core]-> Ins
+ Yosuga_embedded[Devices Control Info] ->(Serial) Yosuga[SerialManager] -> Forward to Yosuga_server
+ Yosuga[Come from embedded Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins
+
+ UI_TARS[Mind and x&y Info] ->(Func call) Yosuga_server[llm_core] -> Ins
+ Yosuga_server[Come from UI_TARS Ins] ->(WebSocket) Yosuga
+
+ Yosuga[Live2D Control Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins
+ Yosuga_server[Live2D Control Ins] ->(Websocket) Yosuga
+
+ Yosuga_server[agent memory] ->(Func call) Yosuga_server[Memory Unit]
+"""
+import asyncio
+from typing import Optional, List, Dict, Any
+from loguru import logger
+from src.modules.websocket_base_module.dto.third_dtos import (
+ AudioDataDTO, AudioDataTransferObject,
+ ScreenShotDataDTO, ScreenShotDataTransferObject
+)
+from src.modules.websocket_base_module.dto.second_dtos import JsonDTO, get_json_dto_instance
+from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer, get_ws_server
+
+from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig
+
+from src.modules.asr_module.client.asr_client import create_asr_client, ASRClientConfig, ASRClientAsync
+
+from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import StreamingMode, TTSConfig, GPTSoVITSClient
+
+from src.server_core.llm_core.llm_core import (
+ LLMCoreConfig, ModelConfig,
+ YosugaLLMCore, ModelProvider,
+ LLMCoreAnalysisBase,
+ YosugaAudioResponseData, YosugaUITARSResponseData,
+ YosugaUITARSRequestData
+)
+
+from src.modules.websocket_base_module.dto.dto_templates.auto_agent_data_dto import AutoAgentDataTransferObject
+from src.config.config import cfg
+
+
+class YosugaServerCore:
+ """
+ 异步单例类
+ """
+
+ _instance: Optional["YosugaServerCore"] = None
+ _lock = asyncio.Lock()
+
+ # 组合必要的工具类
+ ws_server: WebSocketServer
+ json_dto: JsonDTO
+ audio_dto: AudioDataDTO
+ screenshot_dto: ScreenShotDataDTO
+
+ asr_client: ASRClientAsync # 异步asr client
+ tts_client: GPTSoVITSClient # tts client
+ auto_agent_client: UITarsClient # GUI自动化agent
+
+ llm_core: YosugaLLMCore = None # llm core
+
+    @classmethod
+    async def get_instance(cls) -> "YosugaServerCore":
+        """Async singleton factory: on first call build and wire every component under the class lock, then return the shared instance."""
+        if cls._instance is None:
+            async with cls._lock:
+                if cls._instance is None:
+                    logger.info("Initializing YosugaServerCore...")
+                    # Create the bare instance (cls.__new__ does not run __init__)
+                    instance = cls.__new__(cls)
+
+                    # Initialise the message dispatchers in dependency order
+                    instance.ws_server = await get_ws_server()
+                    instance.json_dto = await get_json_dto_instance(instance.ws_server)
+                    instance.audio_dto = AudioDataDTO(instance.json_dto)  # audio dispatcher
+                    instance.audio_dto.register_audio_callback(instance._handle_audio_data)  # register the audio handler
+                    instance.screenshot_dto = ScreenShotDataDTO(instance.json_dto)  # screenshot dispatcher
+                    instance.screenshot_dto.register_screenshot_callback(instance._handle_screenshot_data)  # register the screenshot handler
+
+                    instance.asr_client = create_asr_client(use_async=True, base_url=cfg.asr.url)
+                    instance.tts_client = GPTSoVITSClient(host=cfg.tts.host, port=cfg.tts.port, debug=True)
+                    # Switch the GPT_SoVITS models to the configured weights
+                    await instance.tts_client.set_gpt_weights(cfg.tts.gpt_model_name)
+                    await instance.tts_client.set_sovits_weights(cfg.tts.sovits_model_name)
+
+                    instance.auto_agent_client = UITarsClient(UITarsClientConfig(
+                        deployment_type=cfg.auto_agent.deployment_type,
+                        base_url=cfg.auto_agent.base_url,
+                        model_name=cfg.auto_agent.model_name,
+                        temperature=cfg.auto_agent.temperature,
+                        max_tokens=cfg.auto_agent.max_tokens
+                    ))
+
+                    instance.llm_core = YosugaLLMCore(
+                        model_config=ModelConfig(  # TODO same as above
+                            provider=ModelProvider.OPENAI,
+                            model_name=cfg.ai.model_name,
+                            base_url=cfg.ai.base_url,
+                            api_key=cfg.ai.api_key,
+                            temperature=cfg.ai.temperature,
+                            max_tokens=cfg.ai.max_tokens
+                        ),
+                        core_config=LLMCoreConfig(  # TODO same as above
+                            max_context_tokens=cfg.llm_core.max_context_tokens,
+                            enable_history=cfg.llm_core.enable_history,
+                            role_setting=cfg.llm_core.role_character,
+                            language=cfg.llm_core.language,  # reply language
+                            auto_dispatch=True,
+                            dispatch_async=True  # enable async dispatch
+                        )
+                    )
+                    instance.register_llm_core_analysis()  # register the output parsers
+                    instance.register_llm_core_action()  # register the action handlers
+                    instance.llm_core.register_overflow_handler(instance._handle_overflow_logger)  # register the context-overflow callback
+
+                    cls._instance = instance
+                    logger.success("YosugaServerCore initialized")
+        return cls._instance
+
+    def register_llm_core_action(self):
+        """
+        Register this core's action handlers (and the fallback handler) on llm_core; raises if llm_core is unset.
+        """
+        if self.llm_core is None:
+            raise Exception("LLMCore is not initialized")
+        self.llm_core.register_action_handler("audio_text", self._handle_audio_response, is_async=True)
+        self.llm_core.register_action_handler("auto_agent", self._handle_auto_agent, is_async=True)
+        self.llm_core.register_action_handler("call_auto_agent", self._handle_call_auto_agent, is_async=True)
+        self.llm_core.set_fallback_handler(self._handle_fallback)
+
+    def register_llm_core_analysis(self):
+        """
+        Register the LLM output parser models on llm_core; raises if llm_core is unset.
+        """
+        if self.llm_core is None:
+            raise Exception("LLMCore is not initialized")
+        self.llm_core.register_analysis_model(YosugaAudioResponseData)
+        self.llm_core.register_analysis_model(YosugaUITARSResponseData)
+        self.llm_core.register_analysis_model(YosugaUITARSRequestData)
+
+    def _handle_overflow_logger(self, history: List[Any], metadata: Dict[str, Any]):
+        """Context-overflow callback: only prints the overflow details read from ``metadata`` and ``history``."""
+        print(f" 上下文溢出!")
+        print(f" 模型: {metadata['model']}")
+        print(f" 消息数: {metadata['message_count']}")
+        print(f" Token: {metadata['estimated_tokens']}/{metadata['limit']}")
+        print(f" 即将遗忘 {len(history) // 2} 条旧消息")
+
+    async def _handle_audio_data(self, audio_data: AudioDataTransferObject):
+        """Audio callback: transcribe one client audio payload with ASR and pass the text to llm_core.
+
+        Returns early (after logging) when ASR reports failure, since there is no transcript to forward.
+        """
+        logger.info("Received audio data")
+        # Per the original author's note: client audio arrives whole (not streamed), in WAV format.
+        # TODO: consider a simple VAD check here so silent audio is not forwarded to llm_core
+        asr_response = await self.asr_client.transcribe_bytes(audio_data.data)
+        if not asr_response.success:
+            logger.error(f"ASR failed: {asr_response.error}")
+            return
+        asr_result = asr_response.data  # ASR result
+        # llm_core dispatches its parsed output to the registered handlers, so the result is not used here
+        await self.llm_core.interact(
+            user_input={  # build the user input payload
+                "text": asr_result.text,
+                "confidence": asr_result.confidence
+            }
+        )
+
+    async def _handle_screenshot_data(self, screenshot_data: ScreenShotDataTransferObject):
+        """
+        Screenshot callback.
+        Sends the LLM text carried in the DTO plus the screenshot to the auto agent, then feeds the agent's reply back to llm_core.
+        """
+        logger.info(f"Received screenshot data {len(screenshot_data.RealTimeScreenShot)}")
+        if not screenshot_data.isSuccess:  # the client failed to take the screenshot
+            logger.error("Screenshot failed")
+            return  # stop here; nothing is sent to llm_core
+        # TODO consider adding the device description (screenshot_data.DescribeInfo) to the auto_agent input to improve accuracy
+        # Build the request and call the agent asynchronously
+        logger.debug(f"screenshot_data.LLMResponse(来自llm_core向auto_agent的输入): {screenshot_data.LLMResponse}")
+        logger.debug(f"客户端设备信息: {screenshot_data.DescribeInfo}")
+        auto_agent_response: str = await self.auto_agent_client.call_async(screenshot_data.LLMResponse,
+                                                                          screenshot_data.RealTimeScreenShot)
+        logger.debug(f"auto_agent_response(auto_agent原生返回结果): {auto_agent_response}")
+        # Send the auto_agent reply to llm_core
+        await self.llm_core.interact(
+            user_input={  # build the auto_agent input payload
+                "auto_agent": auto_agent_response
+            }
+        )
+
+    async def _handle_audio_response(self, data: YosugaAudioResponseData):
+        """llm_core async handler for spoken replies.
+
+        Streams ``data.response_text`` through the TTS client and forwards every audio
+        chunk to the Yosuga client as it arrives, followed by a terminating packet
+        with ``isEnd=True``.
+
+        Args:
+            data: Parsed LLM output; only handled when ``data.type == "audio_text"``.
+
+        Returns:
+            ``{"status": "success", "executed": <text>}`` when streaming completed,
+            ``{"status": "failed", "error": <message>}`` when TTS or sending raised,
+            or ``None`` when ``data.type`` is not ``"audio_text"``.
+        """
+        if data.type != "audio_text":
+            return None
+        logger.info("Handling audio response")
+        chunk_count = 0
+        try:
+            # Stream the audio using the fastest streaming mode
+            async for chunk in await self.tts_client.tts(
+                text=data.response_text,
+                ref_audio_path="uploaded_audio/test_voice.wav",  # TODO move to config or a later emotion system
+                text_lang="ja",
+                prompt_lang="ja",
+                prompt_text="もう!こんなところで何やってるんだよ!",  # transcript of the reference audio
+                streaming_mode=StreamingMode.FASTEST,  # mode 3: fast streaming
+                media_type="wav"
+            ):
+                chunk_count += 1
+                print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes")
+                # The first chunk is flagged isStart=True; later chunks are middle packets.
+                await self.audio_dto.send_audio_data(
+                    AudioDataTransferObject(
+                        data=chunk.audio_data,
+                        isStream=True,
+                        isStart=(chunk_count == 1),
+                        sequence=chunk_count,
+                        isEnd=False,
+                        text=data.response_text
+                    )
+                )
+            print(f"✅ 流式TTS完成!共{chunk_count}个音频块")
+            # End-of-stream packet; its payload is a placeholder, not real audio.
+            await self.audio_dto.send_audio_data(
+                AudioDataTransferObject(
+                    data=b"0",
+                    isStream=True,
+                    isStart=False,
+                    sequence=chunk_count + 1,
+                    isEnd=True,
+                    text=data.response_text
+                )
+            )
+        except Exception as e:
+            print(f"❌ 流式错误: {e}")
+            # Report the failure instead of claiming success.
+            return {"status": "failed", "error": str(e)}
+        return {"status": "success", "executed": data.response_text}
+
+    async def _handle_auto_agent(self, data: YosugaUITARSResponseData):
+        """
+        llm_core async handler for GUI automation actions.
+        Converts the parsed reply to an AutoAgentDataTransferObject and sends it as JSON to the Yosuga client, which performs the action.
+        """
+        # Build and send the reply payload
+        await self.json_dto.send_json(
+            AutoAgentDataTransferObject.from_json(data.to_dict()).to_json()
+        )
+        return {"status": "success", "executed": data.Action}
+
+    async def _handle_call_auto_agent(self, data: YosugaUITARSRequestData):
+        """
+        llm_core async handler for the LLM's request to call the auto agent.
+        Asks the client for a screenshot; per the original note, _handle_screenshot_data completes the task once the screenshot arrives.
+        """
+        if data.type == "call_auto_agent":
+            logger.info("LLM Calling auto agent")
+            # Request a screenshot (base64, per the original note); the LLM text rides along in the DTO so _handle_screenshot_data can build its request
+            await self.screenshot_dto.send_screenshot_data(ScreenShotDataTransferObject(LLMResponse=data.llm_translation))
+            return {"status": "success", "executed": data.type}
+
+    def _handle_fallback(self, data: LLMCoreAnalysisBase):
+        """
+        llm_core synchronous fallback handler: logs the type and JSON dump of data that has no registered handler.
+        """
+        logger.debug(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}")
+
+    async def run(self):
+        """Start the WebSocket server bound to 0.0.0.0 and wait on it."""
+        logger.info("Yosuga Server Websocket Core 启动中...")
+        await self.ws_server.run(host="0.0.0.0")
+
+
+# Usage example: obtain the singleton core and run it
+async def main():
+    core = await YosugaServerCore.get_instance()
+    await core.run()
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
\ No newline at end of file
diff --git a/src/server_core/emotion_core/readme.md b/src/server_core/emotion_core/readme.md
new file mode 100644
index 0000000..80d15cf
--- /dev/null
+++ b/src/server_core/emotion_core/readme.md
@@ -0,0 +1,2 @@
+### 本模块为Yosuga_server的AI输出语音情感管理模块
+
diff --git a/src/server_core/llm_core/llm_core.py b/src/server_core/llm_core/llm_core.py
new file mode 100644
index 0000000..9390dab
--- /dev/null
+++ b/src/server_core/llm_core/llm_core.py
@@ -0,0 +1,564 @@
+# llm_core/llm_core.py
+
+"""
+Yosuga Server LLM 核心控制模块
+负责整合Prompt管理、模型调用、输出解析、上下文记忆管理以及生命周期维护。
+作为系统的"大脑",对外提供统一的高级交互接口。
+"""
+import asyncio
+import json
+import time
+from typing import List, Dict, Any, Optional, Callable, Union, Type
+from loguru import logger
+from pydantic import BaseModel, Field
+# 引入已有的模块
+from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
+ UnifiedLLM, ModelConfig, ChatMessage, ModelResponse, ModelProvider
+)
+from src.server_core.llm_core.llm_core_analysis import (
+ LLMCoreAnalysisManager, LLMCoreAnalysisBase, YosugaAudioResponseData, YosugaUITARSResponseData, YosugaUITARSRequestData
+)
+from src.server_core.llm_core.llm_core_dispatcher import LLMCoreActionDispatcher
+from src.server_core.llm_core.llm_core_prompt_manager import (
+ LLMCorePromptManager, LLMCorePromptBase,
+ YosugaAudioASRText, YosugaUITARS, YosugaLive2DControl
+)
+from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH
+from src.server_core.llm_core.llm_core_token import TokenManager, TokenUsage
+
+# 类型定义:上下文溢出回调函数签名
+# 参数1: 溢出的历史记录列表
+# 参数2: 相关的元数据
+ContextOverflowCallback = Callable[[List[ChatMessage], Dict[str, Any]], None]
+
+class LLMCoreConfig(BaseModel):
+    """Runtime configuration for the LLM core."""
+    max_context_tokens: int = Field(default=2000, description="上下文最大Token数(估算值,不包括System prompt),超出触发重置")  # estimated history token cap (system prompt excluded)
+    enable_history: bool = Field(default=True, description="是否启用历史对话记忆")  # keep conversation history
+    language: str = Field(default="zh_CN", description="回复语言设定")  # reply language
+    role_setting: str = Field(default="", description="llm角色扮演")  # role-play persona text
+    auto_dispatch: bool = Field(default=True, description="是否自动分发到动作处理器")  # dispatch parsed output to action handlers
+    dispatch_async: bool = Field(default=False, description="分发是否使用异步模式")  # use async dispatch
+    memory: str = Field(default="", description="llm记忆")  # memory text injected into the system prompt
+    system_state_table: str = Field(default="", description="Yosuga系统状态表")  # Yosuga system state table
+
+
+
+class YosugaLLMCore:
+ """
+ Yosuga 服务端 LLM 核心控制器
+ """
+
+    def __init__(self, model_config: ModelConfig, core_config: Optional[LLMCoreConfig] = None):
+        """
+        Initialise the LLM core.
+
+        Args:
+            model_config: Connection configuration for the underlying model
+            core_config: Core business configuration (context limit etc.); defaults to LLMCoreConfig()
+        """
+        self.model_config = model_config
+        self.core_config = core_config or LLMCoreConfig()
+        # Model client (UnifiedLLM)
+        self.llm_client: UnifiedLLM = UnifiedLLM(self.model_config)
+        # TokenManager for the configured model name
+        self.token_manager = TokenManager(self.model_config.model_name)
+        # Prompt manager with the default prompt modules
+        self.prompt_manager = LLMCorePromptManager()
+        self._register_default_prompts()
+        # Conversation history store
+        self._history: List[ChatMessage] = []  # note: holds user/assistant messages only, never the system prompt
+        # Context-overflow callbacks
+        self._overflow_callbacks: List[ContextOverflowCallback] = []
+        logger.info(
+            f"YosugaLLMCore 初始化完成 | "
+            f"模型: {model_config.model_name} | "
+            f"提供商: {model_config.provider}"
+        )
+        logger.info(
+            f"上下文限制: {self.core_config.max_context_tokens} tokens | "
+            f"自动分发: {self.core_config.auto_dispatch}"
+        )
+
+    def _register_default_prompts(self):
+        """Register the default business prompt modules on the prompt manager."""
+        self.prompt_manager.register(YosugaAudioASRText())
+        self.prompt_manager.register(YosugaUITARS())
+        # self.prompt_manager.register(YosugaLive2DControl()) # TODO
+        logger.info(f"默认Prompt模块注册完成 | 数量: {self.prompt_manager.get_registry_size()}")
+
+ # 系统提示词管理
+    def get_system_prompt(self) -> str:
+        """
+        Build the current system prompt.
+        Formats YOSUGA_SYSTEM_PROMPT_SCH from the registered prompt modules and the current core_config values.
+        """
+        return YOSUGA_SYSTEM_PROMPT_SCH.format(
+            InputInfo=self.prompt_manager.describe_input(),  # static content
+            OutputInfo=self.prompt_manager.describe_output(),  # static content
+            RoleSetting=self.core_config.role_setting,  # role play, hot-reloadable
+            Language=self.core_config.language,  # reply language, hot-reloadable
+            Memory=self.core_config.memory,  # memory, hot-reloadable, updated before each request
+            SystemStateTable=self.core_config.system_state_table  # system state table, hot-reloadable, updated on every request
+        )
+
+    def register_prompt_module(self, prompt_module: LLMCorePromptBase):
+        """Register an additional prompt module at runtime."""
+        self.prompt_manager.register(prompt_module)
+        logger.info(f"动态注册 Prompt 模块: {prompt_module.type()}")
+
+    def register_analysis_model(self, model_class: Type[LLMCoreAnalysisBase]) -> None:
+        """
+        Register an LLM output parser model with LLMCoreAnalysisManager.
+
+        Args:
+            model_class: Data model class derived from LLMCoreAnalysisBase
+        """
+        LLMCoreAnalysisManager.register(model_class)
+        logger.info(f"注册解析模型: {model_class.type_()}")
+
+    def register_action_handler(
+            self,
+            type_id: str,
+            handler: Callable,
+            is_async: bool = False
+    ) -> None:
+        """Register an action handler with LLMCoreActionDispatcher.
+
+        Args:
+            type_id: Type identifier matching the corresponding parser model
+            handler: Handler function (sync or async)
+            is_async: True to register through ``register_async``, False through ``register``
+        """
+        # Pick the dispatcher registration that matches the handler kind.
+        register = (
+            LLMCoreActionDispatcher.register_async if is_async else LLMCoreActionDispatcher.register
+        )
+        register(type_id, handler)
+
+    def set_fallback_handler(self, handler: Callable) -> None:
+        """Set the fallback handler used for types without a registered handler."""
+        LLMCoreActionDispatcher.set_fallback(handler)
+
+ # 核心交互接口
+    async def interact(
+            self,
+            user_input: Union[str, Dict[str, Any]],  # input (plain text or structured dict)
+            past_memories: Optional[str] = "",  # related memories retrieved by the memory module
+            system_state_table: Optional[str] = "",  # system state table, refreshed on every request
+            auto_dispatch: Optional[bool] = True,  # whether to dispatch automatically, on by default
+            dispatch_async: Optional[bool] = True  # whether to dispatch asynchronously, on by default
+    ) -> Dict[str, List[Any]]:
+        """
+        Core interaction: preprocess input -> build context -> call the LLM -> parse output -> dispatch.
+
+        Args:
+            user_input: User input (plain text, or a dict that is serialized to JSON)
+            past_memories: Related memories retrieved by the memory module
+            system_state_table: Yosuga system state table (helps the LLM understand the current system state)
+            auto_dispatch: Whether to dispatch; core_config.auto_dispatch is used only when None is passed explicitly
+            dispatch_async: Whether to dispatch asynchronously; core_config.dispatch_async is used only when None is passed
+
+        Returns:
+            Dispatch result dict (shape as documented by the original author):
+            {
+                "success": [{"type": "...", "output": ..., "index": 0}],
+                "failed": [{"type": "...", "error": "...", "index": 1}],
+                "skipped": ["unknown_type"]
+            }
+
+        Raises:
+            ValueError: LLM call or parsing failed (per the original note; the code re-raises whatever was raised)
+            RuntimeError: Fatal dispatch error (per the original note; not raised in this method itself)
+        """
+        # Input preprocessing
+        input_content = (
+            json.dumps(user_input, ensure_ascii=False)
+            if isinstance(user_input, dict)
+            else user_input
+        )
+        logger.info(f"用户输入内容经过json处理后为: {input_content}")
+
+        # Check and maintain the context size
+        self._maintain_context_limit()
+
+        # Build the message chain for this request
+        messages = self._build_request_messages(input_content, past_memories, system_state_table)
+
+        # Call the LLM
+        try:
+            llm_response = self._call_llm(messages)
+            if llm_response.usage:  # NOTE(review): _call_llm already records response.usage; it is recorded again here
+                self.token_manager.record_api_usage(llm_response.usage)
+                logger.info(self.token_manager.format_usage_log(source="API"))  # format the log via TokenManager
+        except Exception as e:
+            logger.error(f"LLM调用失败: {e}")
+            raise
+
+        # Unified parsing (always yields a list)
+        try:
+            parsed_results = LLMCoreAnalysisManager.parse(llm_response.content)
+            logger.success(f"解析成功 | 对象数: {len(parsed_results)}")
+        except Exception as e:
+            logger.error(f"输出解析失败: {e}")
+            raise
+
+        # Update the conversation history
+        if self.core_config.enable_history:
+            self._add_to_history("user", input_content)
+            self._add_to_history("assistant", llm_response.content)
+
+        # Dispatch
+        should_dispatch = auto_dispatch if auto_dispatch is not None else self.core_config.auto_dispatch
+        if should_dispatch and parsed_results:
+            is_async = dispatch_async if dispatch_async is not None else self.core_config.dispatch_async
+            if is_async:
+                # Await directly; per the original note, calling asyncio.run() here would fail by creating a second event loop
+                logger.debug("使用异步模式分发动作")
+                return await LLMCoreActionDispatcher._execute_async(parsed_results)
+            else:
+                # Synchronous mode is left unchanged
+                return LLMCoreActionDispatcher.execute(parsed_results, run_async=False)
+        # Without dispatch, return the raw parsed results (rarely used)
+        return {"success": [{"type": obj.type, "output": obj, "index": i}
+                            for i, obj in enumerate(parsed_results)],
+                "failed": [], "skipped": []}
+
+    def _build_request_messages(
+            self,
+            current_input: str,
+            memories: str,
+            system_state_table: str
+    ) -> List[ChatMessage]:
+        """
+        Build the complete LLM message chain.
+        Structure:
+        1. System prompt, including the injected memory and state table
+        2. History context
+        3. Current user input
+        """
+        # Store the memory text on core_config so get_system_prompt() picks it up
+        self.core_config.memory = memories
+        # Store the system state table the same way
+        self.core_config.system_state_table = system_state_table
+        # Build the system prompt
+        messages = [ChatMessage(role="system", content=self.get_system_prompt())]
+
+        # History context
+        if self.core_config.enable_history:
+            messages.extend(self._history)  # append each message separately
+        # Current user input
+        messages.append(ChatMessage(role="user", content=current_input))
+        return messages
+
+    def _call_llm(self, messages: List[ChatMessage]) -> ModelResponse:
+        """Run the LLM call (non-streaming) and record the usage it reports."""
+        logger.debug(f"请求LLM | 消息数: {len(messages)}")
+        # Estimate token usage beforehand TODO: (debug only)
+        estimated = self.token_manager.estimate_chat_tokens(
+            system_prompt=self.get_system_prompt(),
+            history=self._history,
+            current_input=messages[-1].content if messages else ""
+        )
+        logger.debug(self.token_manager.format_usage_log(estimated.to_dict(), source="MANUAL"))
+
+        # Force non-streaming (structured output needs the complete JSON)
+        response: ModelResponse = self.llm_client.chat(
+            messages,
+            streaming=False,
+            temperature=self.model_config.temperature
+        )
+        # Record the usage returned by the API
+        if response.usage:
+            self.token_manager.record_api_usage(response.usage)
+        logger.debug(f"LLM响应长度: {len(response.content)}")
+        return response
+
+ # 上下文与记忆管理
+    def _add_to_history(self, role: str, content: str):
+        """Append one message with the given role and content to the history."""
+        self._history.append(ChatMessage(role=role, content=content))
+
+    def _maintain_context_limit(self) ->None:
+        """
+        Check whether the history exceeds the configured token limit.
+        On overflow, run the overflow callbacks on the full history, then keep only the most recent half of the messages.
+        """
+        if not self.core_config.enable_history:  # history is disabled
+            return
+        # Get the context usage from TokenManager (manual count)
+        context_usage = self.token_manager.get_context_usage(self._history)
+        current_usage = context_usage.total_tokens
+
+        # Format the log via TokenManager
+        logger.debug(
+            self.token_manager.format_usage_log(context_usage.to_dict(), source="CONTEXT")
+        )
+        limit = self.core_config.max_context_tokens
+        # Ask TokenManager whether the limit is being approached
+        if self.token_manager.is_token_limit_approaching(current_usage, limit, threshold=0.85):
+            logger.warning(
+                f"上下文接近限制: {current_usage}/{limit} "
+                f"({current_usage / limit:.1%})"
+            )
+        if current_usage <= limit:
+            return
+        # Otherwise the context has overflowed
+        logger.critical(
+            f"上下文溢出!| {current_usage}/{limit} tokens "
+            f"({current_usage / limit:.1%}) | 消息: {len(self._history)}"
+        )
+        # Run every registered overflow handler
+        self._trigger_overflow_callbacks()
+
+        # Trim: keep the most recent 50% of the messages (at least one)
+        keep_messages = max(1, len(self._history) // 2)
+        self._history = self._history[-keep_messages:]
+        # Compute the new token usage
+        new_usage = self._estimate_token_usage()
+        logger.success(  # log the token usage before and after trimming
+            f"清理完成 | Token: {current_usage}→{new_usage} | "
+            f"保留消息: {len(self._history)}"
+        )
+
+    def _estimate_token_usage(self) -> int:
+        """Count the tokens held by the current history via the TokenManager.
+
+        Returns 0 for an empty history without calling the TokenManager.
+        """
+        history = self._history
+        if not history:
+            return 0
+        # count_messages_tokens is fed plain role/content dicts, not ChatMessage objects.
+        as_dicts = []
+        for entry in history:
+            as_dicts.append({"role": entry.role, "content": entry.content})
+        # tokens_per_message=3: OpenAI message-format overhead (per the original note).
+        return self.token_manager.count_messages_tokens(
+            as_dicts, tokens_per_message=3
+        )
+
+    def register_overflow_handler(self, handler: ContextOverflowCallback):
+        """Register a context-overflow callback (several may be registered).
+
+        Handlers receive a history snapshot and a metadata dict; callables without ``__name__`` are logged by repr.
+        """
+        self._overflow_callbacks.append(handler)
+        logger.info(f"注册溢出处理器: {getattr(handler, '__name__', repr(handler))}")
+
+    def _trigger_overflow_callbacks(self):
+        """Run every registered overflow handler; a failing handler is logged and does not stop the others."""
+        if not self._overflow_callbacks:
+            return
+        context_usage = self.token_manager.get_context_usage(self._history)
+        metadata = {  # detailed metadata handed to each handler
+            "reason": "token_limit_exceeded",
+            "message_count": len(self._history),
+            "estimated_tokens": context_usage.total_tokens,
+            "limit": self.core_config.max_context_tokens,
+            "timestamp": time.time(),
+            "model": self.model_config.model_name
+        }
+        # Snapshot the history so handlers cannot modify the live list
+        history_snapshot = list(self._history)
+        for handler in self._overflow_callbacks:
+            # Callable objects such as functools.partial have no __name__; fall back to repr.
+            name = getattr(handler, "__name__", repr(handler))
+            try:
+                handler(history_snapshot, metadata)
+                logger.debug(f"溢出处理器成功: {name}")
+            except Exception as e:
+                logger.error(f"执行上下文溢出回调失败: {name}:{e}")
+
+    def clear_context(self):
+        """Manually clear the conversation history."""
+        self._history.clear()
+        logger.info("上下文记忆已清空")
+
+ # 运行时热重载
+    def reload_model(self, new_model_config: ModelConfig):
+        """
+        Hot-reload the LLM model configuration.
+        The conversation history is not touched; the client config and TokenManager are replaced. Errors are logged and re-raised.
+        """
+        logger.info(f"正在热重载模型: {self.model_config.model_name} -> {new_model_config.model_name}")
+        try:
+            self.llm_client.update_config(new_model_config)
+            self.model_config = new_model_config
+            # Re-create the TokenManager for the new model name
+            self.token_manager = TokenManager(new_model_config.model_name)
+            logger.info("模型热重载成功")
+        except Exception as e:
+            logger.error(f"模型热重载失败: {e}")
+            raise
+
+    def get_context_stats(self) -> Dict[str, Any]:
+        """Return detailed context statistics (counts, token usage, tokenizer info, preview of the last 3 messages)."""
+        # Get the current usage from TokenManager
+        context_usage = self.token_manager.get_context_usage(self._history)
+        tokenizer_info = self.token_manager.get_tokenizer_info()
+        return {
+            "message_count": len(self._history),
+            "estimated_tokens": context_usage.total_tokens,
+            "limit": self.core_config.max_context_tokens,
+            "usage_ratio": context_usage.total_tokens / self.core_config.max_context_tokens,
+            "model": self.model_config.model_name,
+            "tokenizer": tokenizer_info,
+            "history_preview": [
+                f"{msg.role[:1]}:{msg.content[:30]}..."
+                for msg in self._history[-3:]
+            ],
+            "last_api_usage": self.token_manager._last_api_usage.to_dict() if self.token_manager._last_api_usage else None
+        }
+
+    def __repr__(self) -> str:
+        """Return a short description: model name, provider value and history length."""
+        mc = self.model_config
+        model, provider = mc.model_name, mc.provider.value
+        history_len = len(self._history)
+        return f"YosugaLLMCore(model={model}, provider={provider}, history_len={history_len})"
+
+
+# 使用示例与测试
+if __name__ == "__main__":
+    import sys
+
+    # Configure log output
+    logger.remove()
+    logger.add(sys.stderr, level="DEBUG")
+
+    print("\n" + "=" * 50)
+    print("🚀 Yosuga Server LLM Core 启动测试")
+    print("=" * 50 + "\n")
+
+    # Mock action handlers
+
+    # Async handler example: spoken reply
+    async def handle_audio_response(data: LLMCoreAnalysisBase):
+        # Annotated with the base class; at runtime ``data`` is already the concrete subclass
+        if data.type == "audio_text":
+            print(f" [Audio Handler] 正在合成语音: {data.response_text} | 情感: {data.emotion}")
+            return {"audio_file": "sample.wav", "duration": 3.5}
+        return None
+
+    # Async handler example: GUI automation action
+    async def handle_auto_agent(data: LLMCoreAnalysisBase):
+        print(f" [UI Agent] 收到指令: {data.Action}")
+        await asyncio.sleep(1)  # simulate a slow operation
+        print(f" [UI Agent] 执行动作: {data.Action} -> ({data.x1}, {data.y1})")
+        return {"status": "success", "executed": data.Action}
+
+    async def handle_call_auto_agent(data: LLMCoreAnalysisBase):
+        print(f" [Call Agent] 收到内容: {data.type}")
+        print(f" [Call Agent] 收到内容: {data.llm_translation}")
+
+    # Fallback handler
+    def handle_fallback(data: LLMCoreAnalysisBase):
+        print(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}")
+
+    # Initialise the LLM core
+
+    # Configure the LM Studio connection
+    # Note: LM Studio is usually OpenAI-compatible, so provider LM_STUDIO or OPENAI both work
+    # For a local service the API key can be any value
+    model_cfg = ModelConfig(
+        provider=ModelProvider.LM_STUDIO,
+        model_name="qwen/qwen3-4b-2507",
+        base_url="http://192.168.1.3:1234/v1",
+        api_key="lm-studio",
+        temperature=0.3,
+        max_tokens=2048
+    )
+
+    core_cfg = LLMCoreConfig(
+        max_context_tokens=1024,
+        enable_history=True,
+        role_setting="你是由Misakiotoha开发的Yosuga助手,性格抽象,爱说点小骚话。",
+        auto_dispatch=True,
+        dispatch_async=True  # exercise async dispatch
+    )
+
+    core = YosugaLLMCore(model_cfg, core_cfg)
+
+    # Register the handlers
+    core.register_action_handler("audio_text", handle_audio_response, is_async=True)
+    core.register_action_handler("auto_agent", handle_auto_agent, is_async=True)
+    core.register_action_handler("call_auto_agent", handle_call_auto_agent, is_async=True)
+    core.set_fallback_handler(handle_fallback)
+
+    # Register the parsers
+    core.register_analysis_model(YosugaAudioResponseData)
+    core.register_analysis_model(YosugaUITARSResponseData)
+    core.register_analysis_model(YosugaUITARSRequestData)
+
+    # Interaction tests; async because core.interact is a coroutine and must be awaited
+    async def run_tests():
+        # Scenario 1: plain conversation (expected to trigger audio parsing)
+        print("\n📝 测试 1: 普通对话 (预期触发 audio_text)")
+        asr_input_1 = {
+            "text": "你好,Yosuga!请介绍一下你自己,并对我微笑。",
+            "confidence": 0.99
+        }
+        try:
+            result = await core.interact(
+                asr_input_1,
+                dispatch_async=True
+            )
+            print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
+        except Exception as e:
+            logger.error(f"测试1失败: {e}")
+
+        # Scenario 2: complex instruction (expected to trigger both audio and UI actions)
+        print("\n📝 测试 2: 混合指令 (预期触发 audio_text + auto_agent)")
+        # Build a complex prompt input to induce the model to emit several instructions
+        # Note: this relies on the model understanding the output schema in the system prompt
+        complex_input = """
+        [{
+            "text": "打开系统设置",
+            "confidence": 0.99
+        }]
+        """
+
+        try:
+            result = await core.interact(complex_input, dispatch_async=True)
+            print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
+        except Exception as e:
+            logger.error(f"测试2失败: {e}")
+
+        # Hot-reload test
+        print("\n🔄 测试 3: 模型热重载")
+        new_model_cfg = ModelConfig(
+            provider=ModelProvider.LM_STUDIO,
+            model_name="qwen/qwen3-vl-8b",
+            base_url="http://192.168.1.3:1234/v1",
+            api_key="lm-studio",
+            temperature=0.5
+        )
+
+        try:
+            core.reload_model(new_model_cfg)
+            print("✅ 热重载完成,进行验证对话...")
+
+            # Verify the conversation still works after the reload (context is kept)
+            verify_input = """
+            {
+                "text": "系统设置打开了吗",
+                "confidence": 0.95
+            },
+            {
+                "auto_agent": "Thought: 我注意到屏幕左下角有一个齿轮形状的图标,这正是系统的设置入口。在KDE系统中,这个图标通常用来访问系统设置面板。为了帮助用户打开设置界面,我现在需要点击这个位于屏幕左下方的齿轮图标。
+Action: click(start_box='<|box_start|>(42,1045)<|box_end|>')"
+            }
+            }"""
+            result = await core.interact(verify_input)
+            print(f"🏁 重载后回复: {json.dumps(result, ensure_ascii=False, indent=2)}")
+
+        except Exception as e:
+            logger.error(f"热重载测试失败: {e}")
+
+        # Show the statistics
+        print("\n📊 最终状态统计:")
+        print(core.get_context_stats())
+
+    # Run the tests on a fresh event loop
+    asyncio.run(run_tests())
\ No newline at end of file
diff --git a/src/server_core/llm_core/llm_core_analysis.py b/src/server_core/llm_core/llm_core_analysis.py
new file mode 100644
index 0000000..f768423
--- /dev/null
+++ b/src/server_core/llm_core/llm_core_analysis.py
@@ -0,0 +1,313 @@
+# llm_core/llm_core_analysis.py
+
+"""
+LLM输出解析与序列化模块
+将LLM返回的JSON字符串智能解析为强类型Python对象,供其他模块直接调用
+支持多类型混合响应,自动路由到对应数据模型
+"""
+
+import json
+from abc import ABC, abstractmethod
+from typing import Any, ClassVar, Dict, Optional, Type, List
+from pydantic import BaseModel, Field, ValidationError, field_validator
+from loguru import logger
+import re
+
+# 抽象基类
+class LLMCoreAnalysisBase(BaseModel, ABC):
+ """
+ Abstract base class for LLM output data models.
+ Each scenario subclasses it to declare its own data structure.
+ """
+
+ # Type identifier carried by every parsed object
+ type: str = Field(..., description="场景类型标识")
+
+ @classmethod
+ @abstractmethod
+ def type_(cls) -> str:
+ """Return the scenario type identifier this model corresponds to."""
+ pass
+
+ @classmethod
+ def get_schema(cls) -> Dict[str, Any]:
+ """
+ Return the JSON schema of this model (pydantic ``model_json_schema``).
+ Intended for building the {OutputInfo} section of the system prompt.
+ """
+ return cls.model_json_schema()
+
+# 管理器
+class LLMCoreAnalysisManager:
+ """
+ LLM输出解析管理器
+ 智能路由:根据JSON中的type_字段,自动选择对应模型进行解析
+ """
+
+ # 类变量:存储所有注册的数据模型类
+ _model_registry: ClassVar[Dict[str, Type[LLMCoreAnalysisBase]]] = {}
+
+ @classmethod
+ def register(cls, model_class: Type[LLMCoreAnalysisBase]) -> None:
+ """
+ 注册场景数据模型类
+
+ Args:
+ model_class: 继承自LLMCoreAnalysisBase的数据模型类
+
+ Example:
+ LLMCoreAnalysisManager.register(YosugaAudioResponseData)
+ """
+ type_id = model_class.type_()
+ cls._model_registry[type_id] = model_class
+ logger.info(f"已注册LLM数据输出解析模型: {type_id} -> {model_class.__name__}")
+
+ @classmethod
+ def parse(cls, json_str: str) -> List[LLMCoreAnalysisBase]:
+ """
+ 统一解析入口:无论单对象还是数组,总是返回对象列表
+
+ Args:
+ json_str: LLM原始JSON字符串(支持markdown代码块)
+
+ Returns:
+ 解析后的模型实例列表,顺序与JSON数组一致
+
+ Raises:
+ ValidationError: 格式校验失败
+ ValueError: JSON解析失败
+ """
+ print(f"待解析的内容为(llm本次输出原生内容):{json_str}") # TODO:delete
+
+ cleaned = cls._clean_markdown(json_str) # 清理markdown标记
+ try:
+ data = json.loads(cleaned)
+ # 统一包装成列表
+ if not isinstance(data, list):
+ data = [data]
+ except json.JSONDecodeError as e:
+ logger.error(f"JSON解析失败: {e}\n原始输出: {cleaned[:200]}...")
+ raise ValueError(f"无效的JSON格式: {e}")
+
+ results: List[LLMCoreAnalysisBase] = []
+ for idx, item in enumerate(data):
+ type_id = item.get("type")
+ if not type_id:
+ logger.warning(f"跳过第{idx}个元素(无type字段): {item}")
+ continue
+
+ if type_id not in cls._model_registry:
+ logger.warning(f"未注册的类型 '{type_id}',可用类型: {list(cls._model_registry.keys())}")
+ continue
+
+ model_class = cls._model_registry[type_id]
+ try:
+ # 重新序列化为字符串再解析(保持接口兼容)
+ item_json = json.dumps(item, ensure_ascii=False)
+ result = model_class.model_validate_json(item_json)
+ results.append(result)
+ except ValidationError as e:
+ logger.error(f"第{idx}个对象校验失败 (type={type_id}): {e}")
+ continue
+ except Exception as e:
+ logger.error(f"第{idx}个对象解析失败 (type={type_id}): {e}")
+ continue
+
+ logger.success(f"解析完成 | 成功: {len(results)}/{len(data)}")
+ return results
+
+ @staticmethod
+ def _clean_markdown(json_str: str) -> str:
+ # 尝试找到第一个 '[' 和最后一个 ']'
+ start = json_str.find('[')
+ end = json_str.rfind(']')
+ if start != -1 and end != -1:
+ return json_str[start:end + 1]
+ return json_str # Fallback
+
+# 具体场景数据模型
+class YosugaAudioResponseData(LLMCoreAnalysisBase):
+ """
+ LLM output data model for the audio ASR scenario.
+
+ Usage example:
+ LLMCoreAnalysisManager.register(YosugaAudioResponseData)
+ data = YosugaAudioResponseData.model_validate_json('{"type": "audio_text", "response_text": "你好"}')
+ data.response_text # plain attribute access
+ '你好'
+ data.emotion # default value
+ 'neutral'
+ """
+
+ type: str = Field(default="audio_text", description="固定为audio_text")
+ response_text: str = Field(..., description="回复文本")
+ emotion: str = Field(default="neutral", description="情感基调")
+ action: str = Field(default="none", description="动作指令")
+
+ @classmethod
+ def type_(cls) -> str:
+ return "audio_text"
+
+class YosugaUITARSResponseData(LLMCoreAnalysisBase):
+ """
+ 自动化操作场景的LLM输出数据模型
+ """
+
+ type: str = Field(default="auto_agent", description="固定为auto_agent")
+ Action: str = Field(..., description="动作名称")
+ x1: Optional[int] = Field(default=None, description="起点x")
+ y1: Optional[int] = Field(default=None, description="起点y")
+ x2: Optional[int] = Field(default=None, description="终点x")
+ y2: Optional[int] = Field(default=None, description="终点y")
+ key: Optional[str] = Field(default="", description="快捷键")
+ content: Optional[str] = Field(default="", description="输入文本")
+ direction: Optional[str] = Field(default="", description="滚动方向")
+
+ @classmethod
+ def type_(cls) -> str:
+ return "auto_agent"
+
+ @field_validator('x1', 'y1', 'x2', 'y2', mode='before')
+ @classmethod
+ def convert_optional_int(cls, v: Any) -> Optional[int]:
+ """
+ 将字符串类型的坐标值转换为 int,空字符串转为 None
+
+ Args:
+ v: 原始值(可能是 str, int, None)
+
+ Returns:
+ Optional[int]: 转换后的值
+ """
+ if v is None:
+ return None
+ if isinstance(v, str):
+ # 处理空字符串
+ if v.strip() == "":
+ return None
+ # 尝试转换为 int
+ try:
+ return int(v)
+ except ValueError:
+ logger.warning(f"无法将字符串 '{v}' 转换为 int,返回 None")
+ return None
+ if isinstance(v, int): # 如果原始值本身就是int类型,直接return
+ return v
+ if isinstance(v, float): # 如果解析出了float,强转成int再return
+ return int(v)
+
+ logger.warning(f"意外的类型 {type(v)},返回 None")
+ return None
+
+ def to_dict(self) -> dict:
+ """
+ 将模型转换为字典,用于生成JSON字符串
+
+ Returns:
+ dict: 模型数据字典
+ """
+ return {
+ "type": self.type,
+ "Action": self.Action,
+ "x1": self.x1,
+ "y1": self.y1,
+ "x2": self.x2,
+ "y2": self.y2,
+ "key": self.key,
+ "content": self.content,
+ "direction": self.direction
+ }
+
+class YosugaUITARSRequestData(LLMCoreAnalysisBase):
+ """
+ Parsing model for the LLM output that requests a call to the automation agent.
+ """
+
+ type: str = Field(default="call_auto_agent", description="固定为call_auto_agent")
+ llm_translation: str = Field(default="", description="针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可")
+
+ @classmethod
+ def type_(cls) -> str:
+ return "call_auto_agent"
+
+
+
+
+class YosugaLive2DResponseData(LLMCoreAnalysisBase):
+ """
+ LLM output data model for the Live2D control scenario. TODO
+ """
+
+ type: str = Field(default="live2d_control", description="固定为live2d_control")
+ parameter: str = Field(..., description="参数名")
+ value: float = Field(..., description="目标值")
+ duration: int = Field(default=500, description="过渡时间(ms)")
+
+ @classmethod
+ def type_(cls) -> str:
+ return "live2d_control"
+
+class YosugaEmbeddedResponseData(LLMCoreAnalysisBase):
+ """
+ LLM output data model for the embedded-device scenario. TODO
+ """
+
+ type: str = Field(default="embedded_control", description="固定为embedded_control")
+ device_id: str = Field(..., description="设备ID")
+ command: str = Field(..., description="控制指令")
+ params: Optional[Dict[str, Any]] = Field(default=None, description="参数")
+
+ @classmethod
+ def type_(cls) -> str:
+ return "embedded_control"
+
+# 使用示例
+if __name__ == "__main__":
+ from loguru import logger
+
+ # Register the data models exercised by the demo below
+ LLMCoreAnalysisManager.register(YosugaAudioResponseData)
+ LLMCoreAnalysisManager.register(YosugaUITARSResponseData)
+ LLMCoreAnalysisManager.register(YosugaLive2DResponseData)
+ LLMCoreAnalysisManager.register(YosugaEmbeddedResponseData)
+
+ logger.info("=== LLM输出解析模块测试 ===")
+
+ # Test 1: one object inside a JSON array wrapped in a markdown code fence
+ print("\n【测试1:单对象自动识别】")
+ llm_output = '''
+ ```json
+ [{
+ "type": "audio_text",
+ "response_text": "收到!我会微笑回应",
+ "emotion": "cheerful"
+ }]
+ ```
+ '''
+
+ response = LLMCoreAnalysisManager.parse(llm_output)
+ print(f"类型: {response}")
+
+ # Test 2: several objects of different types in one array
+ print("\n【测试2:多对象混合响应】")
+ multi_output = '''
+ ```json
+ [
+ {
+ "type": "auto_agent",
+ "Action": "click",
+ "x1": "100",
+ "y1": "200"
+ },
+ {
+ "type": "live2d_control",
+ "parameter": "ParamEyeLOpen",
+ "value": 0.8,
+ "duration": 300
+ }
+ ]
+ ```
+ '''
+
+ results = LLMCoreAnalysisManager.parse(multi_output)
+ print(f"类型: {results}")
\ No newline at end of file
diff --git a/src/server_core/llm_core/llm_core_dispatcher.py b/src/server_core/llm_core/llm_core_dispatcher.py
new file mode 100644
index 0000000..62b2892
--- /dev/null
+++ b/src/server_core/llm_core/llm_core_dispatcher.py
@@ -0,0 +1,320 @@
+# llm_core/llm_core_dispatcher.py
+
+"""
+动作分发器模块
+负责将解析后的LLM输出对象路由到对应的业务处理器
+支持同步/异步处理,提供回退机制与执行结果追踪
+"""
+
+import asyncio
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
+from functools import wraps
+from loguru import logger
+
+from src.server_core.llm_core.llm_core_analysis import LLMCoreAnalysisBase
+
+# 处理器类型定义
+# 同步处理器:接收解析对象,返回执行结果
+SyncActionHandler = Callable[[LLMCoreAnalysisBase], Any]
+# 异步处理器:接收解析对象,返回协程
+AsyncActionHandler = Callable[[LLMCoreAnalysisBase], Any]
+
+
+def handler_error_wrapper(func: Callable) -> Callable:
+ """
+ 处理器错误包装装饰器
+ 统一捕获异常并记录日志,避免单个处理失败导致整体崩溃
+ """
+
+ @wraps(func)
+ def sync_wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ logger.exception(f"处理器 {func.__name__} 执行失败: {e}")
+ raise
+
+ @wraps(func)
+ async def async_wrapper(*args, **kwargs):
+ try:
+ return await func(*args, **kwargs)
+ except Exception as e:
+ logger.exception(f"异步处理器 {func.__name__} 执行失败: {e}")
+ raise
+
+ return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
+
+
+class LLMCoreActionDispatcher:
+ """
+ LLM动作分发中心
+
+ 职责:
+ 1. 管理类型到处理器的映射关系
+ 2. 支持同步与异步处理器注册
+ 3. 执行分发并收集处理结果
+ 4. 提供未注册类型的回退机制
+ 5. 支持批量处理与结果聚合
+
+ 使用示例:
+ # 注册同步处理器
+ def handle_audio(data: YosugaAudioResponseData) -> dict:
+ return {"status": "spoken", "text": data.response_text}
+
+ LLMCoreActionDispatcher.register("audio_text", handle_audio)
+
+ # 注册异步处理器
+ async def handle_ui(data: YosugaUITARSResponseData):
+ await perform_click(data.x1, data.y1)
+ return {"status": "clicked"}
+
+ LLMCoreActionDispatcher.register_async("auto_agent", handle_ui)
+
+ # 执行分发
+ results = LLMCoreActionDispatcher.execute(parsed_objects)
+ """
+
+ # 类变量存储所有注册的处理器
+ _sync_handlers: ClassVar[Dict[str, SyncActionHandler]] = {}
+ _async_handlers: ClassVar[Dict[str, AsyncActionHandler]] = {}
+ _fallback_handler: ClassVar[Optional[Union[SyncActionHandler, AsyncActionHandler]]] = None
+
+ @classmethod
+ def register(cls, type_id: str, handler: SyncActionHandler) -> None:
+ """
+ 注册同步处理器
+
+ Args:
+ type_id: 与 LLMCoreAnalysisBase.type_() 返回值匹配的标识符
+ handler: 同步函数,接收解析对象并返回执行结果
+
+ Raises:
+ ValueError: 处理器不是可调用的函数
+ """
+ if not callable(handler):
+ raise ValueError(f"处理器必须是可调用函数,收到: {type(handler)}")
+
+ # 检查是否已注册,防止覆盖
+ if type_id in cls._sync_handlers or type_id in cls._async_handlers:
+ logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖")
+
+ cls._sync_handlers[type_id] = handler_error_wrapper(handler)
+ logger.success(f"注册同步处理器: {type_id} → {handler.__name__}")
+
+ @classmethod
+ def register_async(cls, type_id: str, handler: AsyncActionHandler) -> None:
+ """
+ 注册异步处理器
+
+ Args:
+ type_id: 类型标识符
+ handler: 异步函数,接收解析对象并返回协程
+
+ Raises:
+ ValueError: 处理器不是有效的异步函数
+ """
+ if not asyncio.iscoroutinefunction(handler):
+ raise ValueError(f"异步处理器必须是协程函数,收到: {type(handler)}")
+
+ if type_id in cls._sync_handlers or type_id in cls._async_handlers:
+ logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖")
+
+ cls._async_handlers[type_id] = handler_error_wrapper(handler)
+ logger.success(f"注册异步处理器: {type_id} → {handler.__name__}")
+
+ @classmethod
+ def set_fallback(cls, handler: Union[SyncActionHandler, AsyncActionHandler]) -> None:
+ """
+ 设置回退处理器(未注册类型的默认处理)
+
+ Args:
+ handler: 同步或异步函数,处理所有未匹配的类型
+ """
+ wrapped = handler_error_wrapper(handler)
+ cls._fallback_handler = wrapped
+ handler_type = "异步" if asyncio.iscoroutinefunction(handler) else "同步"
+ logger.info(f"设置{handler_type}回退处理器: {handler.__name__}")
+
+ @classmethod
+ def get_handler(cls, type_id: str) -> Optional[Union[SyncActionHandler, AsyncActionHandler]]:
+ """获取指定类型的处理器"""
+ # 优先返回同步处理器
+ if type_id in cls._sync_handlers:
+ return cls._sync_handlers[type_id]
+ # 其次返回异步处理器
+ if type_id in cls._async_handlers:
+ return cls._async_handlers[type_id]
+ # 返回回退处理器
+ return cls._fallback_handler
+
+ @classmethod
+ def execute(
+ cls,
+ analysis_results: List[LLMCoreAnalysisBase],
+ run_async: bool = False
+ ) -> Dict[str, List[Any]]:
+ """
+ Dispatch every parsed object to its handler and collect the results.
+
+ Args:
+ analysis_results: list of objects returned by the parser
+ run_async: use the asyncio path; only taken when an async handler or async fallback exists
+
+ Returns:
+ Result dict; the async path uses asyncio.run(), which raises RuntimeError inside a running loop:
+ {
+ "success": [results of the handlers that succeeded],
+ "failed": [{"type": "...", "error": "..."}],
+ "skipped": [type ids that had no handler]
+ }
+ """
+ if not analysis_results:
+ logger.warning("无对象需要分发")
+ return {"success": [], "failed": [], "skipped": []}
+
+ if run_async and (cls._async_handlers or asyncio.iscoroutinefunction(cls._fallback_handler)):
+ # Async mode: asyncio.run() drives the whole batch on an event loop of its own
+ return asyncio.run(cls._execute_async(analysis_results))
+
+ # Default: synchronous execution
+ return cls._execute_sync(analysis_results)
+
+ @classmethod
+ def _execute_sync(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]:
+ """Run the handlers one after another in the calling thread and collect the results."""
+ outputs = {"success": [], "failed": [], "skipped": []}
+
+ for idx, result in enumerate(results):
+ type_id = result.type
+ logger.debug(f"[{idx}] 分发类型: {type_id}")
+
+ handler = cls.get_handler(type_id)
+ if not handler:
+ outputs["skipped"].append(type_id)
+ logger.warning(f"无处理器,跳过: {type_id}")
+ continue
+
+ # Coroutine-function handlers cannot be awaited here, so they are recorded as failed
+ if asyncio.iscoroutinefunction(handler):
+ logger.error(f"异步处理器 '{type_id}' 不能在同步模式下执行")
+ outputs["failed"].append({"type": type_id, "error": "异步处理器需要run_async=True"})
+ continue
+
+ try:
+ output = handler(result)
+ outputs["success"].append({
+ "type": type_id,
+ "output": output,
+ "index": idx
+ })
+ logger.success(f"[{idx}] 处理成功: {type_id}")
+ except Exception as e:
+ outputs["failed"].append({
+ "type": type_id,
+ "error": str(e),
+ "index": idx
+ })
+ logger.error(f"[{idx}] 处理失败 {type_id}: {e}")
+
+ cls._log_summary(outputs, len(results))
+ return outputs
+
+ @classmethod
+ async def _execute_async(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]:
+ """Run all handlers concurrently with asyncio.gather and collect the results."""
+ outputs = {"success": [], "failed": [], "skipped": []}
+ tasks = []
+
+ for idx, result in enumerate(results):
+ type_id = result.type
+ logger.debug(f"[{idx}] 异步分发类型: {type_id}")
+
+ handler = cls.get_handler(type_id)
+ if not handler:
+ outputs["skipped"].append(type_id)
+ logger.warning(f"无处理器,跳过: {type_id}")
+ continue
+
+ # Build the coroutine for this item (it is awaited by gather below)
+ if asyncio.iscoroutinefunction(handler):
+ task = cls._run_async_handler(handler, result, idx, outputs)
+ else:
+ # Sync handlers are pushed to the executor via _run_sync_in_executor
+ task = cls._run_sync_in_executor(handler, result, idx, outputs)
+
+ tasks.append(task)
+
+ # Run everything concurrently; return_exceptions=True keeps one failure from cancelling the rest
+ await asyncio.gather(*tasks, return_exceptions=True)
+
+ cls._log_summary(outputs, len(results))
+ return outputs
+
+ @classmethod
+ async def _run_async_handler(cls, handler, result, idx, outputs):
+ """Await an async handler and append the outcome to outputs["success"] or outputs["failed"]."""
+ try:
+ output = await handler(result)
+ outputs["success"].append({
+ "type": result.type,
+ "output": output,
+ "index": idx
+ })
+ logger.success(f"[{idx}] 异步处理成功: {result.type}")
+ except Exception as e:
+ outputs["failed"].append({
+ "type": result.type,
+ "error": str(e),
+ "index": idx
+ })
+ logger.error(f"[{idx}] 异步处理失败 {result.type}: {e}")
+
+ @classmethod
+ async def _run_sync_in_executor(cls, handler, result, idx, outputs):
+ """在线程池中运行同步处理器"""
+ loop = asyncio.get_event_loop()
+ try:
+ output = await loop.run_in_executor(None, handler, result)
+ outputs["success"].append({
+ "type": result.type,
+ "output": output,
+ "index": idx
+ })
+ logger.success(f"[{idx}] 同步处理器(线程池)成功: {result.type}")
+ except Exception as e:
+ outputs["failed"].append({
+ "type": result.type,
+ "error": str(e),
+ "index": idx
+ })
+ logger.error(f"[{idx}] 同步处理器(线程池)失败 {result.type}: {e}")
+
+ @classmethod
+ def _log_summary(cls, outputs: Dict, total: int):
+ """输出处理摘要"""
+ success_count = len(outputs["success"])
+ failed_count = len(outputs["failed"])
+ skipped_count = len(outputs["skipped"])
+
+ logger.info(
+ f"分发完成 | 总计: {total} | 成功: {success_count} "
+ f"失败: {failed_count} 跳过: {skipped_count}"
+ )
+
+ @classmethod
+ def clear(cls) -> None:
+ """Remove every registered handler and the fallback (for tests / hot reload)."""
+ cls._sync_handlers.clear()
+ cls._async_handlers.clear()
+ cls._fallback_handler = None
+ logger.info("已清空所有动作处理器")
+
+ @classmethod
+ def list_handlers(cls) -> Dict[str, str]:
+ """列出当前注册的所有处理器"""
+ handlers = {}
+ for type_id, handler in cls._sync_handlers.items():
+ handlers[type_id] = f"同步: {handler.__name__}"
+ for type_id, handler in cls._async_handlers.items():
+ handlers[type_id] = f"异步: {handler.__name__}"
+ return handlers
\ No newline at end of file
diff --git a/src/server_core/llm_core/llm_core_prompt_manager.py b/src/server_core/llm_core/llm_core_prompt_manager.py
new file mode 100644
index 0000000..12fd22b
--- /dev/null
+++ b/src/server_core/llm_core/llm_core_prompt_manager.py
@@ -0,0 +1,174 @@
+# llm_core/llm_core_prompt_manager.py
+
+"""
+llm prompt 结构化信息管理
+用于将各种输入到llm_core的数据流结构化,并附加注释,以方便llm可以准确的理解并可以结构化返回相关内容
+"""
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Field, field_validator
+from typing import Callable, List, Optional, Coroutine, Any, ClassVar, Dict
+from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH
+
+class LLMCorePromptBase(BaseModel, ABC):
+ """Base class for LLM prompt modules: each one describes an input and an output format."""
+ @abstractmethod
+ def type(self) -> str:
+ """Return the unique identifier of this prompt type."""
+ pass
+
+ @abstractmethod
+ def describe_input(self) -> str:
+ """Return the natural-language description of the input format (fills {InputInfo})."""
+ pass
+
+ @abstractmethod
+ def describe_output(self) -> str:
+ """Return the natural-language description of the output format (fills {OutputInfo})."""
+ pass
+
+ def to_json(self) -> str:
+ return self.model_dump_json()
+
+ @classmethod
+ def from_json(cls, json_data: str):
+ return cls.model_validate_json(json_data)
+
+
+class LLMCorePromptManager(LLMCorePromptBase):
+ """聚合所有子类,生成完整的 prompt 信息"""
+ # 类变量:存储所有注册的子类
+ _registry: ClassVar[Dict[str, LLMCorePromptBase]] = {}
+
+ def get_registry_size(self) -> int:
+ return len(self._registry)
+
+ def register(self, son: LLMCorePromptBase) -> None:
+ self._registry[son.type()] = son
+
+ def type(self) -> str:
+ return "manager"
+
+ def describe_input(self) -> str:
+ """聚合所有子类的输入描述"""
+ return "\n".join(
+ f"[{type_id}]{son.describe_input()}\n"
+ for type_id, son in self._registry.items()
+ )
+
+ def describe_output(self) -> str:
+ """聚合所有子类的输出描述"""
+ return "\n".join(
+ f"[{type_id}]{son.describe_output()}\n"
+ for type_id, son in self._registry.items()
+ )
+
+class YosugaAudioASRText(LLMCorePromptBase):
+ """Prompt module for the audio ASR text input scenario."""
+ def type(self) -> str:
+ return "用户语音asr信息"
+
+ def describe_input(self) -> str:
+ return '''
+ 当用户通过语音与Yosuga交互时,你会收到如下JSON结构:
+ {
+ "text": "用户说的话(字符串)",
+ "confidence": 0.95
+ }
+ - `text`: 语音转写的原始文本,可能包含口语化表达或识别错误
+ - `confidence`: ASR引擎的识别置信度,低于0.8需警惕识别错误
+ '''
+
+ def describe_output(self) -> str:
+ return '''
+ 针对用户音频识别出的文本内容输入,你应该按以下JSON格式回复:
+ {
+ "type": "固定为audio_text",
+ "response_text": "你的回复文本(字符串)",
+ "emotion": "neutral",
+ "action": "none"
+ }
+ - `response_text`: 给用户的自然语言回复
+ - `emotion`: 回复的情感基调,可选值:neutral/cheerful/sad/angry
+ - `action`: 触发的动作指令,如"wave_hand"、"nod"等,"none"表示无动作
+ '''
+
+class YosugaEmbedded(LLMCorePromptBase):
+ """Embedded-device input scenario. Placeholder: all three methods currently return None."""
+ def type(self) -> str:
+ pass
+
+ def describe_input(self) -> str:
+ pass
+
+ def describe_output(self) -> str:
+ pass
+
+class YosugaUITARS(LLMCorePromptBase):
+ """Prompt module for the automation-operation scenario."""
+ def type(self) -> str:
+ return "自动化操作信息"
+
+ def describe_input(self) -> str:
+ return '''
+ 当你尝试调用自动化agent的时候,自动化agent将会返回下面的内容作为输入给你,自动化agent的返回信息很重要,有的任务不是一次就可以完成的,此时你可能需要多次调用自动化agent,直到agent返回finished(任务完成):
+ {
+ "auto_agent":"
+ Thought: 自动化agent的推理过程,可以考虑二次加工后作为回复用户的内容
+ Action: 对应的自动化动作[包括click(单击坐标), left_double(双击坐标), right_single(右键单击), drag(拖拽), hotkey(快捷键), type(输入文本), scroll(滚动), wait(等待), finished(任务完成)]
+ "
+ }
+ '''
+
+ def describe_output(self) -> str:
+ return '''
+ 1. 当用户试图进行一些自动化操作时,你可以通过返回以下json来调用自动化agent,调用的时候也可以一并回复用户一些文本内容:
+ {
+ "type": "固定为call_auto_agent",
+ "llm_translation": "针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可,注意转译的语言必须是中文,可以混合部分英文(应用名称),这个和聊天交流的语言不统一。"
+ }
+ 2. 针对自动化agent操作输入,你应该按以下JSON格式回复。如果你发起了自动化agent请求之后,却没有返回下面的JSON内容,那么你的请求并不会作用到用户的设备,所以你调用完自动化agent一定要及时返回下面的JSON内容:
+ {
+ "type": "固定为auto_agent",
+ "Action": "自动化agent返回的相应的自动化动作名称",
+ "x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
+ "y1": "某个操作的y1",
+ "x2": "某个操作的x2",
+ "y2": "某个操作的y2",
+ "key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
+ "content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
+ "direction": "滚动方向(down or up),若当前操作没有该字段信息,此处内容为空即可,同时需要关心自动化agent回复的x1,和y1坐标"
+ }
+ 自动化agent返回的操作信息不一定包括JSON的全部字段,例如某次返回只有key的内容,或者只有content的内容。
+ 针对自动化agent操作输入的返回,若没有相关内容可以留空相关字段,请不要省略掉任何字段名称。
+
+ 注意:自动化agent的状态可见YosugaSystemState表。
+ '''
+
+class YosugaLive2DControl(LLMCorePromptBase):
+ """Yosuga Live2D control scenario. Placeholder: describe_input/describe_output return None."""
+ def type(self) -> str:
+ return "Yosuga Live2D控制信息"
+
+ def describe_input(self) -> str:
+ pass
+
+ def describe_output(self) -> str:
+ pass
+
+
+if __name__ == "__main__":
+ # Register the prompt modules used by this demo
+ manager = LLMCorePromptManager()
+ manager.register(YosugaAudioASRText())
+ manager.register(YosugaUITARS())
+
+ # Render the final system prompt from the template
+ system_prompt = YOSUGA_SYSTEM_PROMPT_SCH.format(
+ InputInfo=manager.describe_input(),
+ OutputInfo=manager.describe_output(),
+ RoleSetting="...",
+ Language="ja",
+ Memory = "",
+ SystemStateTable=""
+ )
+ print(system_prompt)
\ No newline at end of file
diff --git a/src/server_core/llm_core/llm_core_prompts.py b/src/server_core/llm_core/llm_core_prompts.py
new file mode 100644
index 0000000..06a6a64
--- /dev/null
+++ b/src/server_core/llm_core/llm_core_prompts.py
@@ -0,0 +1,67 @@
+# llm_core/llm_core_prompts.py
+
+YOSUGA_SYSTEM_PROMPT_SCH = """
+你是Yosuga_server这个项目的核心LLM。
+你将会为用户完成各种任务,进行结构化的输出。
+首先向你介绍一下Yosuga这个项目:
+本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。
+之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。
+本项目分为三个部分:
+1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。
+2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。
+3. Yosuga_embedded:这是项目的拓展部分,实现了Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。
+
+作为Yosuga_server的核心LLM,你的任务包括以下的内容:
+1. 接受自定义的结构化的信息(json),理解并给出正确的回复。
+2. 输出自定义的结构化的信息(json),给出正确的回复。
+
+对于你会接受到什么信息,以及如何进行正确的返回,都会在下面做出解释。
+
+你接受到的信息可能有:
+1. 来自Yosuga的对于Live2D模型控制的结构化数据。
+2. 来自Yosuga_embedded的可控制的嵌入式设备信息的结构化摘要数据。
+3. 综合了各种信息的结构化请求数据(例如同时包括用户所说的话语与其他决策信息)。
+
+你的工作流程:
+1. 接受来自用户的请求,理解用户的意图(例如用户只有聊天的目的,或者是用户想进行一些操作)
+2. 如果用户只希望聊天,那就只和用户聊天即可
+3. 如果用户还有其他意图,你需要根据你的能力来完成用户的需求,如果用户的需求不在你的能力范围内,可以婉拒。
+
+你的能力范围:
+1. 你可以和用户进行聊天,当用户只是想和你聊天,那就只和用户进行聊天
+2. 你可以主动调用自动化agent,当用户有着一些自动化操作的需求时候,你可以按一定调用格式去主动调用自动化agent模型,并理解自动化模型的返回内容,以规定的格式返回给服务端最终作用到用户的设备。
+3. 用户也可能会想和你一起玩一些游戏,这个时候你需要理解用户的意图,并调用自动化agent,同时返回一些聊天内容回复,以此和用户一起玩游戏。
+4. 你可以和各种种类的嵌入式设备进行交互,当用户有着一些现实的需求时候,你可以根据当前系统的状态,可控制的设备信息等信息综合判断,通过规定的方式返回规定的内容操作嵌入式设备,以满足用户的需求。
+
+同时你还需要进行角色扮演,完美地扮演符合设定的角色,也要兼顾对于各种问题的处理。
+
+以下是你会接收到的结构化的信息(带解释):
+{InputInfo}
+
+以下是你需要根据不同的结构化信息需要输出的结构化信息(带解释):
+{OutputInfo}
+
+你应当扮演的角色设定为:
+{RoleSetting}
+
+你和用户聊天时候的回复语言为:
+{Language}
+
+
+注意:你的所有回复必须是一个 JSON 数组,即便只有单项任务。格式为 [ {{ "type": "...", ... }} ]。禁止输出JSON数组外的任何解释性文字,同时也禁止在JSON的回复内容中加入任何表情符号。
+注意:你接收到的结构化信息,有时可能会有多个,例如用户在问你的同时,自动化agent也返回了结果,这样的情况很常见,你需要按要求返回一个或者多个json即可。
+注意:有时候用户的请求可能既涉及到了聊天,也涉及到了自动化agent的调用,那么你在返回的json数组当中加入应当返回的内容即可,可以是聊天内容回复加上自动化agent调用。
+注意:有时用户的请求会比较复杂,需要你连续的调用自动化agent,根据自动化agent的输出,去一步步的完成用户的请求。
+
+还有一些是你必须要参考的内容(例如与用户过去的记忆,系统实时状态表等内容):
+与用户过去的记忆:
+{Memory}
+
+系统实时状态表:
+{SystemStateTable}
+
+"""
+
+YOSUGA_SYSTEM_PROMPT_EN = """
+
+"""
\ No newline at end of file
diff --git a/src/server_core/llm_core/llm_core_token.py b/src/server_core/llm_core/llm_core_token.py
new file mode 100644
index 0000000..8ec422c
--- /dev/null
+++ b/src/server_core/llm_core/llm_core_token.py
@@ -0,0 +1,367 @@
+# llm_core/llm_core_token.py
+
+"""
+Token 计算与管理模块
+支持双数据源智能切换:
+- 优先使用大模型 API 返回的精确 usage 数据
+- 当 API 不返回时,自动回退到 tiktoken 手动估算
+"""
+import time
+from typing import List, Dict, Any, Optional, Union
+from dataclasses import dataclass, field
+import tiktoken
+from loguru import logger
+
+@dataclass
+class TokenUsage:
+ """Token 使用量统计"""
+ prompt_tokens: int = 0
+ completion_tokens: int = 0
+ total_tokens: int = 0
+ source: str = field(default="manual", repr=True) # "api" | "manual"
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "prompt_tokens": self.prompt_tokens,
+ "completion_tokens": self.completion_tokens,
+ "total_tokens": self.total_tokens,
+ "source": self.source
+ }
+
+
+@dataclass
+class TokenizerInfo:
+ """Metadata describing the tokenizer in use."""
+ model_name: str
+ encoding_name: str
+ is_fallback: bool
+ estimated_accuracy: str # "high" | "medium" | "low"
+
+
+class TokenManager:
+ """
+ Token 管理核心类
+ 智能数据源切换:API返回 > 手动估算
+ """
+ def __init__(self, model_name: str):
+ self.model_name = model_name
+ self.tokenizer = self._get_tokenizer(model_name)
+ # Most recent usage reported by the API (None until one is recorded)
+ self._last_api_usage: Optional[TokenUsage] = None
+ # Lifetime of the API data in seconds; older data is treated as expired and manual estimation is used
+ self._api_usage_expiry: float = 30.0
+ # time.time() value at which the API usage was recorded
+ self._last_api_usage_time: float = 0.0
+
+ logger.info(
+ f"TokenManager 初始化完成 | "
+ f"模型: {model_name} | "
+ f"编码器: {self.tokenizer.name}"
+ )
+
+ def _get_tokenizer(self, model_name: str) -> tiktoken.Encoding:
+ """Resolve an encoding: exact tiktoken model, then family-name substring mapping, then cl100k_base."""
+ model_tokenizer_map = {
+ "qwen": "gpt-3.5-turbo",
+ "llama": "gpt-3.5-turbo",
+ "gemma": "gpt-3.5-turbo",
+ }
+
+ try:
+ return tiktoken.encoding_for_model(model_name)
+ except KeyError:
+ pass
+
+ for prefix, mapped_model in model_tokenizer_map.items():
+ if prefix in model_name.lower():
+ logger.info(
+ f"模型 '{model_name}' 映射到 tokenizer '{mapped_model}'"
+ )
+ return tiktoken.encoding_for_model(mapped_model)
+
+ logger.warning(
+ f"tiktoken 不支持模型 '{model_name}',"
+ f"降级使用 'cl100k_base' 编码器"
+ )
+ return tiktoken.get_encoding("cl100k_base")
+
+ def record_api_usage(self, usage: Optional[Union[Dict[str, int], TokenUsage]]) -> None:
+ """
+ 记录大模型 API 返回的 usage 数据
+
+ Args:
+ usage: API 返回的 usage 数据(dict 或 TokenUsage 对象)
+ 如果为 None 或空,则忽略
+ """
+ if not usage:
+ logger.debug("API usage 为空,未记录")
+ return
+
+ # 在底层已经统一好了不同大模型对于usage的处理
+ if isinstance(usage, TokenUsage):
+ api_usage = usage
+ api_usage.source = "api"
+ else:
+ api_usage = TokenUsage(
+ prompt_tokens=usage.get("prompt_tokens", 0),
+ completion_tokens=usage.get("completion_tokens", 0),
+ total_tokens=usage.get("total_tokens", 0),
+ source="api"
+ )
+
+ self._last_api_usage = api_usage
+ self._last_api_usage_time = time.time()
+
+ logger.debug(
+ f"记录 API usage | "
+ f"Prompt: {api_usage.prompt_tokens} | "
+ f"Completion: {api_usage.completion_tokens} | "
+ f"Total: {api_usage.total_tokens}"
+ )
+
+ def get_current_usage(self, prefer_api: bool = True) -> TokenUsage:
+ """
+ Return the current token usage, choosing the data source automatically.
+
+ Args:
+ prefer_api: whether to prefer the API data (default True)
+
+ Returns:
+ TokenUsage whose ``source`` field tells where the numbers came from
+
+ Notes:
+ - With prefer_api=True and API data within its lifetime, the recorded API usage object is returned
+ - Otherwise the manual placeholder from _estimate_manual_usage() is returned
+ """
+ current_time = time.time()
+
+ # Use the API data only while it is still within its lifetime
+ if prefer_api and self._last_api_usage:
+ time_elapsed = current_time - self._last_api_usage_time
+ if time_elapsed <= self._api_usage_expiry:
+ logger.debug(
+ f"使用 API Token 数据({time_elapsed:.1f}s 内)| "
+ f"{self._last_api_usage.prompt_tokens} + "
+ f"{self._last_api_usage.completion_tokens} = "
+ f"{self._last_api_usage.total_tokens}"
+ )
+ return self._last_api_usage
+
+ # No API data, or it has expired: fall back to the manual estimate
+ logger.debug("API usage 无效/过期,回退到手动估算")
+ return self._estimate_manual_usage()
+
+ def _estimate_manual_usage(self) -> TokenUsage:
+ """Return an all-zero TokenUsage tagged "manual" (placeholder: no estimation is performed)."""
+ return TokenUsage(source="manual")
+
+ def get_context_usage(self, history: List[Any]) -> TokenUsage:
+ """
+ Compute the token footprint of the conversation context (always counted manually).
+
+ Notes:
+ - The count comes from count_messages_tokens(), not from the API
+ - This method does not touch _last_api_usage
+
+ Args:
+ history: list of history messages
+
+ Returns:
+ TokenUsage whose source is always "manual"
+ """
+ tokens = self.count_messages_tokens(history)
+ return TokenUsage(
+ prompt_tokens=tokens,
+ completion_tokens=0,
+ total_tokens=tokens,
+ source="manual"
+ )
+
+ def count_text_tokens(self, text: str) -> int:
+ """计算单段文本的 token 数量(与之前相同)"""
+ if not isinstance(text, str) or not text:
+ return 0
+ return len(self.tokenizer.encode(text))
+
+ def count_messages_tokens(
+ self,
+ messages: List[Any],
+ tokens_per_message: int = 3
+ ) -> int:
+ """Return the estimated token total for a message list (content + role + per-message overhead)."""
+ if not messages:
+ return 0
+
+ num_tokens = 0
+ for msg in messages:
+ # Normalise to a dict: to_dict(), then pydantic model_dump(), else use the message as-is
+ if hasattr(msg, "to_dict"):
+ msg_dict = msg.to_dict()
+ elif hasattr(msg, "model_dump"):
+ msg_dict = msg.model_dump()
+ else:
+ msg_dict = msg
+
+ # Count the tokens of the content and of the role
+ num_tokens += self.count_text_tokens(msg_dict.get("content", ""))
+ num_tokens += self.count_text_tokens(msg_dict.get("role", ""))
+ num_tokens += tokens_per_message
+
+ # Add the overhead of the reply prefix
+ num_tokens += 3
+ return num_tokens
+
+ def estimate_chat_tokens(
+ self,
+ system_prompt: Optional[str],
+ history: List[Any],
+ current_input: str
+ ) -> TokenUsage:
+ """
+ Estimate the prompt tokens needed for one complete chat request.
+
+ Args:
+ system_prompt: system prompt (skipped when empty or None)
+ history: list of history messages
+ current_input: current user input
+
+ Returns:
+ TokenUsage whose source is always "manual"
+
+ Notes:
+ - This is an estimate, not a value returned by the API
+ - Meant for debugging and pre-flight checks
+ """
+ messages = []
+
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+
+ messages.extend(history)
+ messages.append({"role": "user", "content": current_input})
+
+ total = self.count_messages_tokens(messages)
+
+ return TokenUsage(
+ prompt_tokens=total,
+ completion_tokens=0,
+ total_tokens=total,
+ source="manual"
+ )
+
+ def format_usage_log(
+ self,
+ usage: Optional[Union[Dict[str, int], TokenUsage]] = None,
+ source: str = "AUTO"
+ ) -> str:
+ """
+ Format a token-usage log line.
+
+ Args:
+ usage: usage data (optional); when None the current usage is looked up
+ source: data-source label ("AUTO" | "API" | "MANUAL" | "CONTEXT")
+
+ Returns:
+ The formatted log string
+ """
+ # No usage supplied: obtain one
+ if usage is None:
+ if source == "CONTEXT":
+ # NOTE(review): counted over an empty history, so the total is always 0 here - confirm intended
+ usage_obj = self.get_context_usage([])
+ else:
+ # Let get_current_usage() pick the data source
+ usage_obj = self.get_current_usage(prefer_api=True)
+ else:
+ # Use the supplied usage
+ if isinstance(usage, TokenUsage):
+ usage_obj = usage
+ else:
+ usage_obj = TokenUsage(
+ prompt_tokens=usage.get("prompt_tokens", 0),
+ completion_tokens=usage.get("completion_tokens", 0),
+ total_tokens=usage.get("total_tokens", 0),
+ source="api" if source == "API" else "manual"
+ )
+
+ # Pick the log prefix from the source label
+ prefix_map = {
+ "API": "API Token统计",
+ "MANUAL": "手动估算",
+ "CONTEXT": "上下文占用",
+ "AUTO": "Token统计"
+ }
+ prefix = prefix_map.get(source, "Token统计")
+
+ # Icon marking where the data came from
+ source_icon = "⚡" if usage_obj.source == "api" else "🧮"
+
+ # Detailed line when completion tokens are present, otherwise the total only
+ if usage_obj.completion_tokens > 0:
+ return (
+ f"{prefix} {source_icon} | "
+ f"Prompt: {usage_obj.prompt_tokens} | "
+ f"Completion: {usage_obj.completion_tokens} | "
+ f"Total: {usage_obj.total_tokens}"
+ )
+ else:
+ return (
+ f"{prefix} {source_icon} | "
+ f"Total: {usage_obj.total_tokens}"
+ )
+
+ def get_tokenizer_info(self) -> TokenizerInfo:
+ """Return a TokenizerInfo for the current tokenizer, with an accuracy label derived from the names."""
+ if "cl100k_base" in self.tokenizer.name and "gpt-3.5" not in self.model_name:
+ accuracy = "low"
+ elif self.model_name in self.tokenizer.name:
+ accuracy = "high"
+ else:
+ accuracy = "medium"
+
+ return TokenizerInfo(
+ model_name=self.model_name,
+ encoding_name=self.tokenizer.name,
+ is_fallback="cl100k_base" in self.tokenizer.name,
+ estimated_accuracy=accuracy
+ )
+
+ def is_token_limit_approaching(
+ self,
+ current_tokens: int,
+ limit: int,
+ threshold: float = 0.85
+ ) -> bool:
+ """Return True when current_tokens is strictly greater than limit * threshold."""
+ return current_tokens > limit * threshold
+
+ def calculate_chunk_size(
+ self,
+ available_tokens: int,
+ safety_margin: float = 0.1
+ ) -> int:
+ """Return available_tokens reduced by the safety_margin fraction, truncated to int."""
+ return int(available_tokens * (1 - safety_margin))
+
+ def clear_api_usage_cache(self):
+ """Reset the cached API usage and its timestamp (for tests)."""
+ self._last_api_usage = None
+ self._last_api_usage_time = 0.0
+ logger.debug("API usage 缓存已清空")
+
+
+# 单例工厂
+class TokenManagerFactory:
+ """TokenManager 工厂类,支持缓存复用"""
+ _instances: Dict[str, TokenManager] = {}
+
+ @classmethod
+ def get_manager(cls, model_name: str) -> TokenManager:
+ if model_name not in cls._instances:
+ cls._instances[model_name] = TokenManager(model_name)
+ return cls._instances[model_name]
+
+ @classmethod
+ def clear_cache(cls):
+ cls._instances.clear()
+ logger.info("TokenManager 缓存已清空")
\ No newline at end of file
diff --git a/src/server_core/llm_core/readme.md b/src/server_core/llm_core/readme.md
new file mode 100644
index 0000000..78c265f
--- /dev/null
+++ b/src/server_core/llm_core/readme.md
@@ -0,0 +1,242 @@
+### 服务端llm核心
+服务端以AI为核心去驱动,进行各种调用。
+
+本模块为服务端的核心llm模块,这个llm必须足够聪明,能够稳定返回结构化输出。
+
+
+llm_core负责实现:大模型调用、Token统计、Prompt管理、输出解析、动作分发与对话历史管理(详见下方类图)。
+
+
+
+
+#### 类图
+```mermaid
+classDiagram
+ class YosugaLLMCore {
+ <<主控制器>>
+ -ModelConfig model_config
+ -LLMCoreConfig core_config
+ -UnifiedLLM llm_client
+ -TokenManager token_manager
+ -LLMCorePromptManager prompt_manager
+ -List[ChatMessage] _history
+ -Lock _history_lock
+ -Lock _config_lock
+ +interact() Dict
+ +register_prompt_module() void
+ +register_action_handler() void
+ +reload_model() void
+ +get_context_stats() Dict
+ }
+
+ class LLMCoreConfig {
+ <<运行时配置>>
+ -int max_context_tokens
+ -bool enable_history
+ -str language
+ -str role_setting
+ -bool auto_dispatch
+ -bool dispatch_async
+ -str memory
+ -str system_state_table
+ }
+
+ class UnifiedLLM {
+ <<大模型调用层>>
+ -ModelConfig config
+ -BaseLLMClient client
+ +chat() ModelResponse
+ +complete() ModelResponse
+ +stream_chat() Iterator
+ +update_config() void
+ }
+
+ class TokenManager {
+        <<Token统计>>
+ -str model_name
+ -tiktoken.Encoding tokenizer
+ -TokenUsage _last_api_usage
+ +record_api_usage() void
+ +get_current_usage() TokenUsage
+ +get_context_usage() TokenUsage
+ +count_messages_tokens() int
+ +format_usage_log() str
+ }
+
+ class LLMCorePromptManager {
+        <<Prompt管理>>
+ -Dict _registry
+ +register() void
+ +describe_input() str
+ +describe_output() str
+ }
+
+ class LLMCoreAnalysisManager {
+ <<输出解析>>
+ <<静态类>>
+ -Dict _model_registry
+ +register() void
+ +parse() List[LLMCoreAnalysisBase]
+ }
+
+ class LLMCoreActionDispatcher {
+ <<动作分发>>
+ <<静态类>>
+ -Dict _sync_handlers
+ -Dict _async_handlers
+ -Callable _fallback_handler
+ +register() void
+ +register_async() void
+ +execute() Dict
+ }
+
+ class LLMCoreAnalysisBase {
+ <<抽象基类>>
+ <<模型数据基类>>
+ #str type
+ +type_() str
+ +get_schema() Dict
+ }
+
+ class ChatMessage {
+ <<消息实体>>
+ -str role
+ -str content
+ -str name
+ +to_dict() Dict
+ }
+
+ class ModelResponse {
+ <<响应实体>>
+ -str content
+ -str model
+ -Dict usage
+ -str finish_reason
+ -Dict raw_response
+ }
+
+ YosugaLLMCore --> LLMCoreConfig : 持有配置
+ YosugaLLMCore --> UnifiedLLM : 调用大模型
+ YosugaLLMCore --> TokenManager : 统计Token
+ YosugaLLMCore --> LLMCorePromptManager : 管理Prompt
+ YosugaLLMCore --> LLMCoreAnalysisManager : 解析输出
+ YosugaLLMCore --> LLMCoreActionDispatcher : 分发动作
+ YosugaLLMCore --> ChatMessage : 管理历史
+ UnifiedLLM --> ModelResponse : 返回响应
+ LLMCoreAnalysisManager --> LLMCoreAnalysisBase : 解析为
+ TokenManager --> ModelResponse : 接收usage
+```
+
+
+#### 时序图
+```mermaid
+sequenceDiagram
+ participant Client as 客户端
+ participant Core as YosugaLLMCore
+ participant Token as TokenManager
+ participant Prompt as PromptManager
+ participant LLM as UnifiedLLM
+ participant Parser as AnalysisManager
+ participant Dispatcher as ActionDispatcher
+
+ Client->>Core: interact(user_input)
+ Note over Core: 输入预处理
+
+ Core->>Token: count_messages_tokens(_history)
+ Token-->>Core: current_usage
+
+ Core->>Core: _maintain_context_limit()
+ Note over Core: 检查溢出并清理
+
+ Core->>Prompt: get_system_prompt()
+ Note over Prompt: 聚合InputInfo/OutputInfo
+
+ Prompt-->>Core: system_prompt
+
+ Core->>Core: _build_request_messages()
+ Note over Core: 组装[system, history, user]
+
+ Core->>Token: estimate_chat_tokens()
+ Token-->>Core: estimated_usage
+
+ Core->>LLM: chat(messages)
+ Note over LLM: 调用底层API
+
+ LLM-->>Core: ModelResponse(content, usage)
+
+ Core->>Token: record_api_usage(usage)
+ Note over Token: 优先使用API数据
+
+ Core->>Parser: parse(content)
+ Note over Parser: JSON清洗+类型校验
+
+ Parser-->>Core: List[AnalysisObj]
+
+ Core->>Core: _add_to_history(user+assistant)
+ Note over Core: 更新对话记忆
+
+ Core->>Dispatcher: execute(parsed_results)
+
+ par 分发处理
+ Dispatcher->>Handler1: 同步处理(audio_text)
+ Dispatcher->>Handler2: 异步处理(auto_agent)
+ end
+
+ Dispatcher-->>Core: {"success": [], "failed": []}
+
+ Core-->>Client: 执行结果
+```
+
+
+#### 配置与模型热重载状态机
+```mermaid
+stateDiagram-v2
+ [*] --> 初始化: YosugaLLMCore()
+
+ state 初始化 {
+ [*] --> 加载配置: ModelConfig
+ 加载配置 --> 创建LLM客户端: UnifiedLLM
+ 创建LLM客户端 --> 注册默认Prompt: Audio/UITARS
+ 注册默认Prompt --> 初始化Token管理器: TokenManager
+ }
+
+ 初始化 --> 待机: 等待输入
+
+ state 待机 {
+        [*] --> 构建SystemPrompt
+        构建SystemPrompt --> 检查上下文限制: _maintain_context_limit()
+ 检查上下文限制 --> 上下文溢出: current > limit
+ 检查上下文限制 --> 正常: 否则
+
+ 上下文溢出 --> 触发回调: _trigger_overflow_callbacks()
+ 触发回调 --> 清理历史: 保留50%
+ 清理历史 --> 正常
+
+ 正常 --> 组装消息链: _build_request_messages()
+ }
+
+ 待机 --> 调用LLM: chat()
+
+ 调用LLM --> 解析输出: AnalysisManager.parse()
+
+ 解析输出 --> 更新历史: _add_to_history()
+
+ 更新历史 --> 分发动作: Dispatcher.execute()
+
+ 分发动作 --> 返回结果: interact() return
+
+ 返回结果 --> 待机
+
+ 待机 --> 热重载: reload_model()
+
+ state 热重载 {
+ [*] --> 更新模型配置: update_config()
+ 更新模型配置 --> 重建LLM客户端: _create_client()
+ 重建LLM客户端 --> 重建Token管理器: TokenManager(model_name)
+ 重建Token管理器 --> [*]
+ }
+
+ 热重载 --> 待机
+
+ note right of 热重载 : "保留历史记忆\n不影响上下文"
+```
\ No newline at end of file
diff --git a/src/server_core/rag_core/readme.md b/src/server_core/rag_core/readme.md
new file mode 100644
index 0000000..a78988f
--- /dev/null
+++ b/src/server_core/rag_core/readme.md
@@ -0,0 +1,8 @@
+### 本模块为Yosuga_server的AI记忆模块
+
+
+##### 提供的功能有:
+1. 向量化的记忆存储
+2. 十分便捷的记忆管理与查询接口
+3. 高效的记忆检索
+
diff --git a/src/server_core/readme.md b/src/server_core/readme.md
new file mode 100644
index 0000000..ed8e761
--- /dev/null
+++ b/src/server_core/readme.md
@@ -0,0 +1,3 @@
+### 本模块为Yosuga_server的核心业务模块
+
+####
diff --git a/src/server_view/readme.md b/src/server_view/readme.md
new file mode 100644
index 0000000..793fa04
--- /dev/null
+++ b/src/server_view/readme.md
@@ -0,0 +1,4 @@
+### 本模块为Yosuga_server的前端部分
+
+本模块用于方便地管理服务端。
+