commit 0753da86a886d66b9801b6d1271c1821df6f1f8e Author: Misakiotoha <1841738040@qq.com> Date: Tue Feb 3 01:20:00 2026 +0800 first pull diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ec46e94 --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# from ide +.vscode +.idea + +# Virtual environments +.venv +venv + +# config +settings.json + +# logs +log/* + +# temp files +temp_files/* + +# ai models +src/modules/asr_module/asr_models/* + +# uv +uv.lock + diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..2c07333 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/README.md b/README.md new file mode 100644 index 0000000..c872b40 --- /dev/null +++ b/README.md @@ -0,0 +1,85 @@ +### Yosuga_server + +## 📊 Project Stats + +![GitHub last commit](https://img.shields.io/github/last-commit/Misakityan/Yosuga_server) +![GitHub issues](https://img.shields.io/github/issues/Misakityan/Yosuga_server) +![GitHub stars](https://img.shields.io/github/stars/Misakityan/Yosuga_server?style=social) + +欢迎访问本项目。 + +首先向你介绍一下Yosuga这个项目: + +本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。[call me "Misaki" でいいよ] + +之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。 + +本项目分为三个部分: +1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。 +2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。 +3. Yosuga_embedded:这是项目的拓展部分,使得Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。 + +**_本项目为Yosuga_server._** + +本项目使用uv构建,基于python3.11. +本项目由YosugaServer发展而来,项目架构与代码有了相当大的改变。(YosugaServer并未开源,它仅仅是一次小小的尝试) + + +### 如何快速启动本项目? +1. 确保uv已安装,并添加到环境变量中 +2. 执行`cd Yosuga_server` & `uv sync` +3. 接着,如果你的电脑带有cuda,那么执行 `uv pip install -r requirements-cuda.txt` +4. 如果没有cuda,那么执行 `uv pip install -r requirements-cpu.txt` +5. 最后执行 `uv run python main.py` 即可启动项目 + +首次启动项目后,会在项目根目录下生成settings.json配置文件,你需要配置一些必要的字段信息: +```json +{ + "ai": { + "api_key": "sk-xxxxx", + "base_url": "http://localhost:1234/v1", + "model_name": "qwen/qwen3-4b-2507" + }, + "tts": { + "gpt_model_name": "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt", + "sovits_model_name": "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth", + "host": "localhost", + "port": 20261, + "reference_audio": "./using/reference.wav" + }, + "asr": { + "url": "http://localhost:20260/" + }, + "auto_agent": { + "deployment_type": "lmstudio", + "model_name": "ui-tars-1.5-7b@q4_k_m", + "base_url": "http://localhost:1234/v1" + }, + "llm_core": { + "role_character": "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。", + "max_context_tokens": 2048, + "language": "日本语" + } +} +``` +上面这些字段的信息,你需要根据你的实际情况进行配置。实际的配置文件的字段名称会比上面的多出不少。 + + +配置完成后,再次重启服务端就可以使用啦~ + +接着是每个模型的配置相关: +1. asr模型,本项目使用fast-whisper作为asr模型,并且附带了一键启动的部分 +,你需要找到 `Yosuga_server/src/modules/asr_module/start_api.py` 这个文件,然后启动它 +,一般来说,即使是cpu也可以进行asr模型的推理,但是速度相比cuda要逊色很多。 +同时,如果你遇到了启动时长时间加载,那么此时你需要试着挂一下梯子,因为初次启动 +会在Hugging Face上下载模型。 +2. tts模型,本项目使用GPT-SoVITS作为tts模型,建议使用其V2Pro版本。 +3. auto_agent模型,本项目使用的自动化操作识别的模型为字节跳动开源的 +`ui-tars-1.5-7b@q4_k_m` 关于此模型的更多信息可以参考字节跳动的[开源链接](https://github.com/bytedance/UI-TARS) +,建议在LM Studio上进行部署,该模型十分轻量。 +4. ai模型,该模型限制为大语言模型,没有限制,本项目支持市面上的所有大语言模型。 + + +本项目当前并不完善,还有很多需要优化的地方,并且尚未接入Yosuga_embedded。 + +欢迎大家为本项目贡献代码。 \ No newline at end of file diff --git a/Test/GPTSoVITSTest.py b/Test/GPTSoVITSTest.py new file mode 100644 index 0000000..c80fd18 --- /dev/null +++ b/Test/GPTSoVITSTest.py @@ -0,0 +1,173 @@ +import asyncio +from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode +from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer +import sounddevice as sd + +test_text = "春の午後、公園のベンチに座って本を読んでいると、小さな子供が凧あげをしているのが目に入った。風に乗って凧が高く上がるたびに、彼の顔には真っすぐな笑顔が広がる。母親がそばで見守りながら、時折声をかけている。\ +空は雲一つない青さで、桜の花びらが風に舞っている。遠くで犬の鳴き声が聞こえ、芝生の上では老夫婦がお茶を楽しんでいた。すべてがゆっくりと流れる時間の中で、自分の心も不思議と落ち着いてくる。\ +ふと、子供の凧が木の枝に引っかかってしまった。少し焦る様子だったが、母親が助けてくれて、すぐにまた空に舞い上がった。失敗しても、誰かが助けてくれる。そんな当たり前のことに、今日は特別な温かさを感じた。\ +日が傾き始める頃、私は本を閉じて家路についた。明日もきっと、誰かの笑顔があるだろう。" + +async def test_tts(): + # 创建客户端(推荐上下文管理器) + async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client: + # 基础TTS调用 + try: + audio = await client.tts( + text="あのさ、いやまあ、なんていうか...要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?", + ref_audio_path="uploaded_audio/test_voice.wav", # 服务器上的路径 + text_lang="ja", + prompt_lang="ja", + media_type="wav", + prompt_text="もう!こんなところで何やってるんだよ!" + ) + + # 保存音频 + audio.save("outputs/output.wav") + print(f"✅ TTS成功!音频大小: {len(audio.audio_data)} bytes") + + except Exception as e: + print(f"❌ 错误: {e}") +async def test_model_change(): + async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client: + # 切换模型 + print("🔄 切换GPT模型...") + await client.set_gpt_weights( + "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt" + ) + + print("🔄 切换SoVITS模型...") + await client.set_sovits_weights( + "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth" + ) + + +async def stream_tts(): + async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client: + try: + # 使用最快模式流式输出 + chunk_count = 0 + async for chunk in await client.tts( + text="要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?", + ref_audio_path="uploaded_audio/test_voice.wav", + text_lang="ja", + prompt_lang="ja", + prompt_text="もう!こんなところで何やってるんだよ!", + streaming_mode=StreamingMode.FASTEST, # 模式3:快速流式 + media_type="wav" + ): + chunk_count += 1 + print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes") + + # 实时播放处理 + # await play_audio_chunk(chunk.audio_data) + + print(f"✅ 流式TTS完成!共{chunk_count}个音频块") + + except Exception as e: + print(f"❌ 流式错误: {e}") + + +async def stream_tts_and_play( + text: str, + ref_audio_path: str, + text_lang: str = "zh", + prompt_lang: str = "zh", + streaming_mode: StreamingMode = StreamingMode.FASTEST +): + """ + 实时流式TTS + 播放一体化 + + Args: + text: 要合成的文本 + ref_audio_path: 参考音频路径 + text_lang: 文本语言 + prompt_lang: 提示语言 + """ + # 创建音频播放器(缓冲区大小=5,平衡延迟和稳定性) + async with AsyncAudioPlayer(buffer_size=5) as player: + # 创建TTS客户端 + async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client: + try: + print(f"🎤 开始流式合成: {text[:30]}...") + print(f"🎯 流式模式: {streaming_mode.name}") + + # 获取音频流(异步生成器) + audio_stream = await client.tts( + text=text, + ref_audio_path=ref_audio_path, + text_lang=text_lang, + prompt_lang=prompt_lang, + prompt_text="もう!こんなところで何やってるんだよ!", + streaming_mode=streaming_mode, + media_type="wav", + sample_steps=32, + top_k=5, + temperature=1.0 + ) + + # 动态读取并播放 + chunk_idx = 0 + async for audio_chunk in audio_stream: + chunk_idx += 1 + print(f"📥 收到音频块 #{chunk_idx}: {len(audio_chunk.audio_data):6d} bytes") + + # 立即加入播放队列(非阻塞) + await player.add_chunk(audio_chunk.audio_data) + + print(f"✅ 合成完成! 共接收 {chunk_idx} 个音频块") + + # 等待播放完成(所有块播完) + await player.audio_queue.join() + print("🎵 播放完成!") + + except Exception as e: + print(f"❌ 错误: {e}") + raise + + +async def test_japanese(): + """测试日语长文本流式播放""" + print("=" * 50) + print("🗾 日语流式TTS测试") + print("=" * 50) + + await stream_tts_and_play( + text=test_text, + ref_audio_path="uploaded_audio/test_voice.wav", + text_lang="ja", + prompt_lang="ja", + + streaming_mode=StreamingMode.FASTEST # 模式3:最快 + ) + +async def batch_test(): + """批量处理示例""" + async with GPTSoVITSClient() as client: + texts = [ + "你好,世界!", + "这是一个批量测试。", + "异步批量处理非常高效。" + ] + + results = await client.batch_tts( + texts=texts, + ref_audio_path="archive_jingyuan_1.wav", + text_lang="zh" + ) + + for i, audio in enumerate(results): + audio.save(f"output/batch_{i}.wav") + print(f"✅ 批量任务 {i + 1}/{len(results)} 完成") + + + + + +if __name__ == "__main__": + # 检查音频设备 + print("🔍 检查音频设备...") + print(sd.query_devices()) + sd.default.device = (None, "pulse") # 使用PulseAudio + + asyncio.run(test_japanese()) \ No newline at end of file diff --git a/Test/WebsocketTestClient.py b/Test/WebsocketTestClient.py new file mode 100644 index 0000000..3b730fb --- /dev/null +++ b/Test/WebsocketTestClient.py @@ -0,0 +1,24 @@ +import asyncio +import json +from websockets.asyncio.client import connect + +async def test_all_types(): + """测试三种消息类型""" + async with connect("ws://localhost:8765") as ws: + print("=== 测试JSON消息 ===") + await ws.send(json.dumps({ + "type": "chat", + "content": "你好服务器!" + })) + print(f"收到: {await ws.recv()}") + + print("\n=== 测试文本消息 ===") + await ws.send("这是纯文本消息") + print(f"收到: {await ws.recv()}") + + print("\n=== 测试二进制消息 ===") + await ws.send(b"\x00\x01\x02\x03\x04") + print(f"收到: {await ws.recv()}") + +if __name__ == "__main__": + asyncio.run(test_all_types()) \ No newline at end of file diff --git a/Test/WebsocketTestServer.py b/Test/WebsocketTestServer.py new file mode 100644 index 0000000..786deb6 --- /dev/null +++ b/Test/WebsocketTestServer.py @@ -0,0 +1,174 @@ +""" +极简 WebSocket 测试服务器 - 修复版本 +""" +import asyncio +import json +import logging +from datetime import datetime +from typing import Set + +import websockets + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(message)s' +) + +class SimpleWebSocketServer: + def __init__(self, host="localhost", port=8765): + self.host = host + self.port = port + self.clients: Set = set() + + async def handle_connection(self, websocket, path): + """处理客户端连接""" + client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}" + self.clients.add(websocket) + logging.info(f"✅ 客户端连接: {client_id} (当前连接数: {len(self.clients)})") + + try: + # 发送欢迎消息 + welcome = { + "type": "connect", + "data": { + "message": "WebSocket 服务器连接成功", + "client_id": client_id, + "server_time": datetime.now().isoformat(), + "status": "connected" + }, + "timestamp": int(datetime.now().timestamp() * 1000) + } + await websocket.send(json.dumps(welcome)) + + async for message in websocket: + await self.handle_message(websocket, client_id, message) + + except websockets.exceptions.ConnectionClosed: + logging.info(f"❌ 客户端断开: {client_id}") + finally: + self.clients.discard(websocket) + logging.info(f"📊 剩余连接: {len(self.clients)}") + + async def handle_message(self, websocket, client_id, message): + """处理收到的消息""" + logging.info(f"📨 收到消息 from {client_id}: {message}") + + try: + # 尝试解析为 JSON + data = json.loads(message) + msg_type = data.get("type", "unknown") + + # 根据消息类型回复 + if msg_type == "ping": + # 心跳响应 + response = { + "type": "pong", + "data": { + "server_time": datetime.now().isoformat(), + "latency": "0ms" + }, + "timestamp": int(datetime.now().timestamp() * 1000) + } + + elif msg_type == "login": + # 登录响应 + username = data.get("data", {}).get("username", "anonymous") + response = { + "type": "login_success", + "data": { + "user_id": f"user_{abs(hash(username)) % 10000}", + "username": username, + "status": "authenticated" + }, + "timestamp": int(datetime.now().timestamp() * 1000) + } + + elif msg_type == "chat": + # 聊天消息回应 + msg_content = data.get("data", {}).get("message", "") + response = { + "type": "chat_response", + "data": { + "message": f"服务器收到: {msg_content}", + "sender": "server", + "received_at": datetime.now().isoformat() + }, + "timestamp": int(datetime.now().timestamp() * 1000) + } + + else: + # 默认回显 + response = { + "type": "echo", + "data": { + "original": data.get("data", {}), + "original_type": msg_type, + "server_processed_at": datetime.now().isoformat() + }, + "timestamp": int(datetime.now().timestamp() * 1000) + } + + await websocket.send(json.dumps(response)) + + except json.JSONDecodeError: + # 不是 JSON,当作纯文本处理 + response = { + "type": "text_echo", + "data": { + "original": message, + "note": "这是文本消息" + }, + "timestamp": int(datetime.now().timestamp() * 1000) + } + await websocket.send(json.dumps(response)) + + async def start(self): + """启动服务器""" + logging.info(f"🚀 启动 WebSocket 服务器: ws://{self.host}:{self.port}") + + # 创建处理函数包装器(解决参数问题) + async def connection_handler(websocket, path): + await self.handle_connection(websocket, path) + + # 启动服务器 + server = await websockets.serve( + connection_handler, + self.host, + self.port, + ping_interval=None, + ping_timeout=None, + close_timeout=None, + max_size=10 * 1024 * 1024 + ) + + logging.info("📌 服务器已启动,等待连接...") + logging.info("🛑 按 Ctrl+C 停止服务器") + + # 保持服务器运行 + try: + await asyncio.Future() # 永久运行 + finally: + server.close() + await server.wait_closed() + logging.info("👋 服务器已关闭") + +def main(): + """主函数""" + import argparse + + parser = argparse.ArgumentParser(description='极简 WebSocket 测试服务器') + parser.add_argument('--host', default='localhost', help='监听地址') + parser.add_argument('--port', type=int, default=8088, help='监听端口') + + args = parser.parse_args() + + server = SimpleWebSocketServer(args.host, args.port) + + try: + asyncio.run(server.start()) + except KeyboardInterrupt: + logging.info("👋 服务器被用户中断") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Test/asrRequestTest.py b/Test/asrRequestTest.py new file mode 100644 index 0000000..4e7f295 --- /dev/null +++ b/Test/asrRequestTest.py @@ -0,0 +1,33 @@ +# requestTest.py +import requests +from pathlib import Path + +# 指定正确的 MIME 类型 +url = "http://192.168.1.8:20260/transcribe" +audio_path = Path("test_files/z105300938.wav") + +with open(audio_path, "rb") as f: + # 明确指定文件名和 MIME 类型 + files = { + "file": ( + audio_path.name, # 文件名 + f, # 文件对象 + "audio/wav" # MIME 类型 + ) + } + + response = requests.post(url, files=files) + +# 打印响应详情 +print(f"状态码: {response.status_code}") +print(f"响应头: {response.headers.get('content-type')}") + +# 检查响应是否成功 +if response.status_code == 200: + result = response.json() + print(f"识别结果: {result['data']['text']}") + print(f"语言: {result['data']['language']}") + print(f"置信度: {result['data']['confidence']}") + print(f"处理时间: {result['data']['processing_time']}s") +else: + print(f"错误响应: {response.text}") \ No newline at end of file diff --git a/Test/dtosAndTTSAndASR.py b/Test/dtosAndTTSAndASR.py new file mode 100644 index 0000000..feff988 --- /dev/null +++ b/Test/dtosAndTTSAndASR.py @@ -0,0 +1,14 @@ +# 一个小Test, 展示设计的dtos模块与tts和asr的集成 +from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO +from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer +from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode +from src.modules.asr_module.client.asr_client import create_asr_client + + +# with create_asr_client(base_url="http://192.168.1.5:20260") as client: +# # 转录文件 +# result = client.transcribe_file("test_files/test.wav") +# print(f"识别结果: {result.data.text}") +# print(f"置信度: {result.data.confidence:.2f}") +# print(f"耗时: {result.data.processing_time:.3f}s") + diff --git a/Test/dtosTest.py b/Test/dtosTest.py new file mode 100644 index 0000000..d463f9e --- /dev/null +++ b/Test/dtosTest.py @@ -0,0 +1,30 @@ +from src.modules.websocket_base_module.dto.second_dtos import get_json_dto_instance +from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO +from src.modules.websocket_base_module.websocket_core.core_ws_server import get_ws_server +import asyncio +from loguru import logger +async def main(): + # 获取WebSocket服务器单例 + ws_server = await get_ws_server() + # 获取二级json分发器单例 + json_dto = await get_json_dto_instance(ws_server) + + # 创建DTO实例(自动注册接收函数) + audio_dto = AudioDataDTO(json_dto) + + logger.info("所有DTO接收器已注册,等待客户端连接...") + + # 启动服务器(阻塞) + try: + await ws_server.run("localhost", 8765) + except asyncio.CancelledError: + logger.info("服务器任务已取消,正在优雅退出...") + finally: + logger.info("服务器已停止") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n✓ 服务器已手动终止(按 Ctrl+C)") \ No newline at end of file diff --git a/Test/outputs/output.wav b/Test/outputs/output.wav new file mode 100644 index 0000000..f9d7828 Binary files /dev/null and b/Test/outputs/output.wav differ diff --git a/Test/outputs/output1.wav b/Test/outputs/output1.wav new file mode 100644 index 0000000..045e77b Binary files /dev/null and b/Test/outputs/output1.wav differ diff --git a/Test/test_files/Screenshot_test.png b/Test/test_files/Screenshot_test.png new file mode 100644 index 0000000..9da1f4e Binary files /dev/null and b/Test/test_files/Screenshot_test.png differ diff --git a/Test/test_files/okoru.wav b/Test/test_files/okoru.wav new file mode 100644 index 0000000..7b540d9 Binary files /dev/null and b/Test/test_files/okoru.wav differ diff --git a/Test/test_files/sad.wav b/Test/test_files/sad.wav new file mode 100644 index 0000000..7c36d27 Binary files /dev/null and b/Test/test_files/sad.wav differ diff --git a/Test/test_files/test.wav b/Test/test_files/test.wav new file mode 100644 index 0000000..7aaf00c Binary files /dev/null and b/Test/test_files/test.wav differ diff --git a/Test/test_files/z105300938.wav b/Test/test_files/z105300938.wav new file mode 100644 index 0000000..1afc693 Binary files /dev/null and b/Test/test_files/z105300938.wav differ diff --git a/Test/textAITest.py b/Test/textAITest.py new file mode 100644 index 0000000..1f984a1 --- /dev/null +++ b/Test/textAITest.py @@ -0,0 +1,88 @@ +from src.modules.text_ai_module.text_ai_core.general_text_ai_req import UnifiedLLM, ModelConfig, ModelProvider, create_llm_client +from src.config.config import get_settings +from src.config.convert_env import EnvConverter +from src.config.file_config import DirectoryInitializer + +EnvConverter().convert(backup_existing=True) # 若是首次启动则从env模板中生成env文件 +DirectoryInitializer(get_settings()) # 初始化必要的目录(若不存在则创建) + +def test1(): + """ + 测试常规调用 + """ + # 配置模型 + config = ModelConfig( + provider=ModelProvider.OPENAI, + model_name=get_settings().ai_model_name, + base_url=get_settings().ai_api_base_url, + api_key=get_settings().ai_api_key, # 从环境中取出相关的api_key + temperature=0.7, + max_tokens=2048 + ) + + # 创建客户端 + llm = UnifiedLLM(config) + + # 发送消息 + response = llm.chat([ + {"role": "system", "content": "你是一个DeepSeek助手"}, + {"role": "user", "content": "请介绍一下DeepSeek模型的特点"} + ]) + + print(response.content) + +def base_test2(): + """ + 测试流式响应 + """ + # 使用快捷函数 + deepseek_llm = create_llm_client( + provider="openai", # DeepSeek使用OpenAI兼容接口 + model_name=get_settings().ai_model_name, + api_key=get_settings().ai_api_key, + base_url=get_settings().ai_api_base_url + ) + + # 流式聊天 + messages = [ + {"role": "user", "content": "用Python写一个快速排序算法"} + ] + + print("正在生成响应...") + for chunk in deepseek_llm.stream_chat(messages): + print(chunk.content, end="", flush=True) + + +def test_lm_studio(): + """测试本地 LM Studio 模型""" + print("=== 测试本地 LM Studio ===") + + # 使用UnifiedLLM类 + config = ModelConfig( + provider=ModelProvider.LM_STUDIO, + model_name="qwen/qwen3-4b-2507", + base_url="http://192.168.1.8:1234/v1", + api_key="", # LM Studio不需要API密钥,留空 + temperature=0.7, + max_tokens=1024, + streaming=False # 启用流式响应 + ) + + llm = UnifiedLLM(config) + + # 发送消息 + messages = [ + {"role": "system", "content": "你是一个有用的助手"}, + {"role": "user", "content": "用中文介绍一下自己"} + ] + + print("非流式响应:") + response = llm.chat(messages, streaming=False) + print(response.content) + + print("\n流式响应:") + for chunk in llm.stream_chat(messages): + print(chunk.content, end="", flush=True) + +if __name__ == "__main__": + test_lm_studio() \ No newline at end of file diff --git a/Test/ui_tars_test.py b/Test/ui_tars_test.py new file mode 100644 index 0000000..1fa98e1 --- /dev/null +++ b/Test/ui_tars_test.py @@ -0,0 +1,66 @@ +import asyncio +import base64 +from pathlib import Path +from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig + + +async def test_ui_tars_stream(): + """测试 UI-TARS 流式调用""" + # 创建客户端 + config = UITarsClientConfig( + deployment_type="lmstudio", + base_url="http://192.168.1.8:1234/v1", + model_name="ui-tars-1.5-7b@q4_k_m", + temperature=0.1 + ) + client = UITarsClient(config) + + # 使用工具方法编码 + image_base64 = base64.b64encode(Path("test_files/Screenshot_test.png").read_bytes()).decode() + print(f"✅ 图片编码完成,长度: {len(image_base64)} 字符\n") + + # 流式调用并实时打印 + print("🤖 开始流式调用 UI-TARS...\n") + print("思考过程:\n") + + import time + # 计算耗时 + start_time = time.time() + full_response = "" + chunk_count = 0 + + full_response = await client.call_async("打开AK加速器", image_base64) + # 传入 base64 字符串 + # for chunk in client.stream_async("我的桌面系统是KDE, 帮我打开设置", image_base64): + # chunk_count += 1 + # content = chunk.content + # + # # 实时打印每个 chunk + # print(content, end="", flush=True) + # + # # 累积完整内容 + # full_response += content + + end_time = time.time() + print(f"\n\n耗时: {end_time - start_time:.2f} 秒") + print(f"\n\n{'=' * 50}") + print(f"✅ 流式调用完成!共接收 {chunk_count} 个 chunk") + print(f"完整响应长度: {len(full_response)} 字符") + + print("响应内容:\n") + print(full_response) + +import pyautogui +def auto_click(x : int, y : int): + pyautogui.moveTo(x, y, duration=1.5) + pyautogui.click() + +def auto_drag(x1 : int, y1 : int, x2 : int, y2 : int): + pyautogui.moveTo(x1, y1, duration=1.5) + pyautogui.dragTo(x2, y2, duration=1.5) + +# 运行异步函数 +if __name__ == "__main__": + asyncio.run(test_ui_tars_stream()) + auto_click(173,48) + # auto_drag(56,39, 170,39) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..0cccc0c --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +import asyncio +from contextlib import suppress + +from loguru import logger +from src.config.config import cfg +from datetime import datetime +from src.server_core.core import YosugaServerCore + +def init(): + """ + Yosuga_server 初始化 + """ + + # 初始化日志系统 + logger.add( + f"{cfg.log_dir}/Yosuga_server-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log", + encoding="utf-8") + logger.info("Yosuga_server 启动") + logger.info(f"日志文件目录见: {cfg.log_dir} 目录") + + +async def main(): + core = await YosugaServerCore.get_instance() + try: + await core.run() + except asyncio.CancelledError: + pass # 正常取消,不打印堆栈 + finally: + # 清理未关闭的 aiohttp sessions + import aiohttp + pending = [t for t in asyncio.all_tasks() + if t is not asyncio.current_task()] + for task in pending: + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*pending, return_exceptions=True) + + +if __name__ == "__main__": + init() + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nYosuga服务器已停止喵~~~") + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9f1eb20 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[project] +name = "yosuga-server" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "aiofiles>=25.1.0", + "aiohttp>=3.13.3", + "fastapi>=0.128.0", + "faster-whisper>=1.2.1", + "loguru>=0.7.3", + "openai>=2.16.0", + "pyautogui>=0.9.54", + "pydantic>=2.12.5", + "pydantic-settings>=2.12.0", + "python-multipart>=0.0.22", + "requests>=2.32.5", + "sounddevice>=0.5.5", + "soundfile>=0.13.1", + "tiktoken>=0.12.0", + "uvicorn>=0.40.0", + "websockets>=16.0", +] diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000..488b1aa --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,84 @@ +--index-url https://download.pytorch.org/whl/cpu +aiofiles==25.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.3 +aiosignal==1.4.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.12.1 +attrs==25.4.0 +av==16.1.0 +certifi==2026.1.4 +cffi==2.0.0 +charset-normalizer==3.4.4 +click==8.3.1 +colorama==0.4.6 +coloredlogs==15.0.1 +ctranslate2==4.6.3 +distro==1.9.0 +fastapi==0.128.0 +faster-whisper==1.2.1 +filelock==3.20.3 +flatbuffers==25.12.19 +frozenlist==1.8.0 +fsspec==2026.1.0 +h11==0.16.0 +hf-xet==1.2.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==1.3.7 +humanfriendly==10.0 +idna==3.11 +jinja2==3.1.6 +jiter==0.13.0 +loguru==0.7.3 +markupsafe==2.1.5 +mouseinfo==0.1.3 +mpmath==1.3.0 +multidict==6.7.1 +networkx==3.6.1 +numpy==2.4.2 +onnxruntime==1.23.2 +openai==2.16.0 +packaging==26.0 +pillow==12.1.0 +propcache==0.4.1 +protobuf==6.33.5 +pyautogui==0.9.54 +pycparser==3.0 +pydantic==2.12.5 +pydantic-core==2.41.5 +pydantic-settings==2.12.0 +pygetwindow==0.0.9 +pymsgbox==2.0.1 +pyperclip==1.11.0 +pyreadline3==3.5.4 +pyrect==0.2.0 +pyscreeze==1.0.1 +python-dotenv==1.2.1 +python-multipart==0.0.22 +pytweening==1.2.0 +pyyaml==6.0.3 +regex==2026.1.15 +requests==2.32.5 +setuptools==80.10.2 +shellingham==1.5.4 +sniffio==1.3.1 +sounddevice==0.5.5 +soundfile==0.13.1 +starlette==0.50.0 +sympy==1.14.0 +tiktoken==0.12.0 +tokenizers==0.22.2 +torch==2.5.1+cpu +torchaudio==2.5.1+cpu +torchvision==0.20.1+cpu +tqdm==4.67.2 +typer-slim==0.21.1 +typing-extensions==4.15.0 +typing-inspection==0.4.2 +urllib3==2.6.3 +uvicorn==0.40.0 +websockets==16.0 +win32-setctime==1.2.0 +yarl==1.22.0 diff --git a/requirements-cuda.txt b/requirements-cuda.txt new file mode 100644 index 0000000..f9ae722 Binary files /dev/null and b/requirements-cuda.txt differ diff --git a/src/config/config.py b/src/config/config.py new file mode 100644 index 0000000..071e3a7 --- /dev/null +++ b/src/config/config.py @@ -0,0 +1,452 @@ +""" +零初始化 JSON 配置管理 +用法:from src.config import cfg # 直接访问,自动初始化 +""" +import json +import threading +from pathlib import Path +from typing import Any, Dict, Optional, TypeVar +from dataclasses import dataclass, field, asdict + +def _project_root() -> Path: + """自动查找项目根目录""" + markers = ['pyproject.toml', 'settings.json', '.gitignore', 'main.py', '.python-version'] + current = Path(__file__).resolve().parent.parent # src的父目录 + + for path in [current, *current.parents]: + if any((path / m).exists() for m in markers): + return path + if path == path.parent: + break + return current + +# 配置定义 + +@dataclass +class AIConfig: + api_key: Optional[str] = "sk-xxxxx" + base_url: str = "http://localhost:1234/v1" + model_name: str = "qwen/qwen3-4b-2507" + timeout: int = 30 + temperature: float = 0.4 + max_tokens: int = 8192 + + +@dataclass +class TTSConfig: + enabled: bool = True + api_key: Optional[str] = None + gpt_model_name: str = "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt" + sovits_model_name: str = "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth" + host: str = "localhost" + port: int = 20261 + reference_audio: str = "./using/reference.wav" + streaming: bool = True + speed: float = 1.0 + +@dataclass +class ASRConfig: + enabled: bool = True + api_key: Optional[str] = None + model_name: str = "fast-whisper" + url: str = "http://localhost:20260/" + +@dataclass +class AutoAgentConfig: + enabled: bool = True + api_key: Optional[str] = None + deployment_type: str = "lmstudio" + model_name: str = "ui-tars-1.5-7b@q4_k_m" + base_url: str = "http://localhost:1234/v1" + temperature: float = 0.1 + max_tokens: int = 16384 + +@dataclass +class LLMConfig: + enabled: bool = True + role_character: str = "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。" + max_context_tokens: int = 2048 + enable_history: bool = True + language: str = "日本语" + +@dataclass +class PathsConfig: + temp: str = "./tmp/" + log: str = "./log/" + using: str = "./using/" + + +class AppConfig: + """ + 应用主配置 + 新增配置分组:1) 上方新建 dataclass 2) 下方 __init__ 添加字段 3) 完成 + """ + + def __init__( + self, + version: str = "1.0.0", + debug: bool = False, + ai: Optional[AIConfig] = None, + tts: Optional[TTSConfig] = None, + asr: Optional[ASRConfig] = None, + auto_agent: Optional[AutoAgentConfig] = None, + llm_core: Optional[LLMConfig] = None, + paths: Optional[PathsConfig] = None, + _config_path: Optional[Path] = None, + **kwargs + ): + # 基础字段 + self.version = version + self.debug = debug + self.ai = ai if ai is not None else AIConfig() + self.tts = tts if tts is not None else TTSConfig() + self.asr = asr if asr is not None else ASRConfig() + self.auto_agent = auto_agent if auto_agent is not None else AutoAgentConfig() + self.llm_core = llm_core if llm_core is not None else LLMConfig() + self.paths = paths if paths is not None else PathsConfig() + + # 内部状态(非 dataclass 字段,不会被序列化) + self._config_path = _config_path + self._lock = threading.RLock() # 普通属性,非 dataclass 字段 + + # 应用其他字段(用于从 JSON 加载时) + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + + # 路径解析为绝对路径 + if self._config_path: + root = self._config_path.parent + for field_name in ['temp', 'log', 'using']: + rel_path = getattr(self.paths, field_name) + if not Path(rel_path).is_absolute(): + abs_path = (root / Path(rel_path)).resolve() + setattr(self.paths, field_name, str(abs_path) + '/') + + # 便捷属性 + + @property + def temp_dir(self) -> Path: + return Path(self.paths.temp) + + @property + def log_dir(self) -> Path: + return Path(self.paths.log) + + @property + def using_dir(self) -> Path: + return Path(self.paths.using) + + # 核心方法 + + def get(self, key: str, default: Any = None) -> Any: + """ + 点号路径访问:cfg.get("ai.timeout") / cfg.get("tts.enabled") + """ + try: + keys = key.split('.') + value = self + for k in keys: + value = getattr(value, k) if not isinstance(value, dict) else value[k] + return value + except (AttributeError, KeyError): + return default + + def set(self, key: str, value: Any, save: bool = True) -> 'AppConfig': + """ + 点号路径设置,支持链式调用 + cfg.set("ai.timeout", 60).set("debug", True) + """ + with self._lock: + keys = key.split('.') + target = self + for k in keys[:-1]: + target = getattr(target, k) + setattr(target, keys[-1], value) + + if save: + self._save() + return self + + def update(self, updates: Dict[str, Any], save: bool = True) -> 'AppConfig': + """ + 批量更新:cfg.update({"ai": {"timeout": 60}, "debug": True}) + """ + def deep_update(obj: Any, data: dict): + for k, v in data.items(): + if hasattr(obj, k): + current = getattr(obj, k) + if isinstance(v, dict) and hasattr(current, '__dataclass_fields__'): + deep_update(current, v) + else: + setattr(obj, k, v) + + with self._lock: + deep_update(self, updates) + + if save: + self._save() + return self + + def reload(self) -> 'AppConfig': + """热重载配置""" + if self._config_path and self._config_path.exists(): + with self._lock: + data = json.loads(self._config_path.read_text(encoding='utf-8')) + + # 配置项名 -> dataclass 类的映射(和 _load 保持一致) + config_classes = { + 'ai': AIConfig, + 'tts': TTSConfig, + 'asr': ASRConfig, + 'auto_agent': AutoAgentConfig, + 'llm_core': LLMConfig, + 'paths': PathsConfig, + } + + for k, v in data.items(): + if hasattr(self, k) and not k.startswith('_'): + # 如果是配置项且是 dict,转换为 dataclass + if k in config_classes and isinstance(v, dict): + setattr(self, k, config_classes[k](**v)) + else: + setattr(self, k, v) + print(f"配置重载: {self._config_path}") + return self + + def to_dict(self) -> dict: + """导出为字典(手动实现,排除内部属性)""" + result = { + 'version': self.version, + 'debug': self.debug, + 'ai': asdict(self.ai) if hasattr(self.ai, '__dataclass_fields__') else self.ai, + 'tts': asdict(self.tts) if hasattr(self.tts, '__dataclass_fields__') else self.tts, + 'asr': asdict(self.asr) if hasattr(self.asr, '__dataclass_fields__') else self.asr, + 'auto_agent': asdict(self.auto_agent) if hasattr(self.auto_agent, '__dataclass_fields__') else self.auto_agent, + 'llm_core': asdict(self.llm_core) if hasattr(self.llm_core, '__dataclass_fields__') else self.llm_core, + 'paths': asdict(self.paths) if hasattr(self.paths, '__dataclass_fields__') else self.paths, + } + return result + + def _save(self) -> None: + """保存到文件""" + if self._config_path: + with self._lock: + json_str = json.dumps(self.to_dict(), indent=2, ensure_ascii=False) + self._config_path.write_text(json_str, encoding='utf-8') + + @classmethod + def _load(cls, path: Path) -> 'AppConfig': + """从文件加载""" + data = json.loads(path.read_text(encoding='utf-8')) + + # 配置项名 -> dataclass 类的映射 + config_classes = { + 'ai': AIConfig, + 'tts': TTSConfig, + 'asr': ASRConfig, + 'auto_agent': AutoAgentConfig, + 'llm_core': LLMConfig, + 'paths': PathsConfig, + } + + # 自动转换 dict 为对应 dataclass + for key, config_class in config_classes.items(): + if key in data and isinstance(data[key], dict): + data[key] = config_class(**data[key]) + + return cls(_config_path=path, **data) + + @classmethod + def _create_default(cls, path: Path) -> 'AppConfig': + """创建默认配置""" + instance = cls(_config_path=path) + instance._save() + print(f"默认配置已创建: {path}") + return instance + + def __repr__(self) -> str: + """友好的打印格式""" + lines = ["AppConfig("] + for k, v in self.to_dict().items(): + lines.append(f" {k}={v!r},") + lines.append(")") + return "\n".join(lines) + + +# 延迟初始化机制 + +_root: Path = _project_root() +_config_path: Path = _root / "settings.json" +_config_instance: Optional[AppConfig] = None +_init_lock: threading.Lock = threading.Lock() + + +def _ensure_initialized() -> AppConfig: + """ + 确保配置已初始化(线程安全的延迟初始化) + """ + global _config_instance + + if _config_instance is not None: + return _config_instance + + with _init_lock: + if _config_instance is not None: + return _config_instance + + # 自动加载或创建 + if _config_path.exists(): + _config_instance = AppConfig._load(_config_path) + print(f"配置加载: {_config_path}") + else: + _config_instance = AppConfig._create_default(_config_path) + + # 确保目录存在 + for d in [_config_instance.temp_dir, + _config_instance.log_dir, + _config_instance.using_dir]: + d.mkdir(parents=True, exist_ok=True) + + return _config_instance + + +class _LazyConfig: + """ + 配置代理类:拦截所有属性访问,第一次使用时自动初始化 + """ + + def __getattr__(self, name: str) -> Any: + """拦截属性访问,延迟初始化""" + instance = _ensure_initialized() + return getattr(instance, name) + + def __setattr__(self, name: str, value: Any) -> None: + """拦截属性设置""" + # 特殊属性直接设置到代理对象本身 + if name.startswith('_'): + super().__setattr__(name, value) + else: + instance = _ensure_initialized() + setattr(instance, name, value) + + def __repr__(self) -> str: + instance = _ensure_initialized() + return repr(instance) + + def __dir__(self) -> list: + """支持 IDE 自动补全""" + instance = _ensure_initialized() + return dir(instance) + + # 显式代理必要方法 + def get(self, key: str, default: Any = None) -> Any: + return _ensure_initialized().get(key, default) + + def set(self, key: str, value: Any, save: bool = True) -> AppConfig: + return _ensure_initialized().set(key, value, save) + + def update(self, updates: Dict[str, Any], save: bool = True) -> AppConfig: + return _ensure_initialized().update(updates, save) + + def reload(self) -> AppConfig: + return _ensure_initialized().reload() + + def to_dict(self) -> dict: + return _ensure_initialized().to_dict() + + def save(self) -> None: + _ensure_initialized()._save() + + # 属性代理 + @property + def ai(self) -> AIConfig: + return _ensure_initialized().ai + + @property + def tts(self) -> TTSConfig: + return _ensure_initialized().tts + + @property + def asr(self) -> ASRConfig: + return _ensure_initialized().asr + + @property + def auto_agent(self) -> AutoAgentConfig: + return _ensure_initialized().auto_agent + + @property + def llm_core(self) -> LLMConfig: + return _ensure_initialized().llm_core + + @property + def paths(self) -> PathsConfig: + return _ensure_initialized().paths + + @property + def temp_dir(self) -> Path: + return _ensure_initialized().temp_dir + + @property + def log_dir(self) -> Path: + return _ensure_initialized().log_dir + + @property + def using_dir(self) -> Path: + return _ensure_initialized().using_dir + + +# 全局配置对象:导入即用,自动初始化 +cfg: AppConfig = _LazyConfig() # type: ignore + + +# 工具函数 + +def generate_example(path: Path = _root / "settings.example.json") -> None: + """生成示例配置文件""" + example = { + "version": "1.0.0", + "debug": False, + "ai": { + "api_key": "sk-your-api-key", + "base_url": "https://api.deepseek.com", + "model": "deepseek-chat", + "timeout": 30 + }, + "tts": { + "enabled": True, + "model": "GPT_SoVITS", + "url": "http://localhost:12458/", + "reference_audio": "./using/reference.wav", + "streaming": True, + "speed": 1.0 + }, + "paths": { + "temp": "./tmp/", + "log": "./log/", + "data": "./data/" + } + } + path.write_text(json.dumps(example, indent=2, ensure_ascii=False), encoding='utf-8') + print(f"示例配置已生成: {path}") + + +# 测试代码 +if __name__ == "__main__": + # 测试:直接访问,自动初始化 + print("第一次访问 cfg.ai.model:") + print(f" → {cfg.ai.model_name}") + + print(f"\n配置详情:") + print(cfg) + + print(f"\n测试修改:") + cfg.set("ai.timeout", 60) + print(f" ai.timeout = {cfg.ai.timeout}") + + print(f"\n测试批量更新:") + cfg.update({"debug": True, "tts": {"speed": 1.5}}) + print(f" debug = {cfg.debug}, tts.speed = {cfg.tts.speed}") + + print(f"\n测试热重载:") + cfg.reload() \ No newline at end of file diff --git a/src/config/readme.md b/src/config/readme.md new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/asr_module/__init__.py b/src/modules/asr_module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/asr_module/api.py b/src/modules/asr_module/api.py new file mode 100644 index 0000000..818e306 --- /dev/null +++ b/src/modules/asr_module/api.py @@ -0,0 +1,123 @@ +# asr_module/api.py +from fastapi import FastAPI, File, UploadFile, HTTPException +from fastapi.responses import JSONResponse +from pathlib import Path +import tempfile +import time +from datetime import datetime +from loguru import logger +from src.modules.asr_module.asr_core.fast_whisper import create_asr, ASRConfig + +# 初始化FastAPI应用 +app = FastAPI( + title="Yosuga ASR API", + description="基于faster-whisper Turbo的高性能多语种语音转文本服务", + version="1.0.0" +) + +# 全局单例ASR实例(延迟加载) +_asr_instance = None + +def get_asr(): + """获取或创建ASR实例(单例)""" + global _asr_instance + if _asr_instance is None: + logger.info("🚀 初始化ASR服务...") + _asr_instance = create_asr( + ASRConfig( + model_name="deepdml/faster-whisper-large-v3-turbo-ct2", + device="auto", + compute_type="int8_float16", + cache_dir=Path("asr_models/faster_whisper_large_v3_ct2"), + beam_size=1, # 贪婪搜索,速度最快 + vad_filter=True, # 过滤静音,节省30%时间 + ) + ) + logger.info("✅ ASR服务初始化完成") + return _asr_instance + +@app.on_event("startup") +async def startup_event(): + """应用启动时预加载模型""" + get_asr() + +@app.on_event("shutdown") +async def shutdown_event(): + """应用关闭时清理资源""" + global _asr_instance + if _asr_instance: + _asr_instance.shutdown() + logger.info("🛑 ASR服务已关闭") + +@app.post("/transcribe", response_class=JSONResponse) +async def transcribe_audio( + file: UploadFile = File(..., description="音频文件 (WAV, FLAC, MP3等格式)") +): + """ + 语音转文本API + + - **file**: 音频文件,支持WAV/FLAC/MP3等格式 + - **返回**: JSON格式结果,包含text/language/confidence + """ + start_time = time.time() + + # 验证文件类型 + if file.content_type and not file.content_type.startswith("audio/"): + raise HTTPException(status_code=400, detail="❌ 请上传音频文件 (MIME类型: audio/*)") + + try: + # 创建临时文件 + with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file: + content = await file.read() + tmp_file.write(content) + tmp_path = Path(tmp_file.name) + + logger.info(f"📥 接收文件: {file.filename} ({len(content)} bytes)") + + # 调用ASR识别 + asr = get_asr() + text, language, confidence = asr.transcribe_wav(tmp_path) + + # 清理临时文件 + tmp_path.unlink(missing_ok=True) + + processing_time = time.time() - start_time + + logger.info(f"✅ 识别完成: {language} | {len(text)}字符 | 置信度:{confidence:.2f} | 耗时:{processing_time:.3f}s") + + return { + "success": True, + "data": { + "text": text, + "language": language, + "confidence": confidence, + "processing_time": round(processing_time, 3) + } + } + + except Exception as e: + logger.error(f"❌ 识别失败: {e}") + raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}") + +@app.get("/health") +async def health_check(): + """健康检查接口""" + asr = get_asr() + health = asr.health_check() + + return { + "status": "healthy" if health["status"] == "healthy" else "unhealthy", + "timestamp": datetime.now().isoformat(), + "device": health["device"], + "model_loaded": health["model_loaded"] + } + +@app.get("/") +async def root(): + """API根路径""" + return { + "message": "Yosuga ASR API 正在运行", + "docs": "/docs", + "health": "/health", + "transcribe": "/transcribe (POST)" + } \ No newline at end of file diff --git a/src/modules/asr_module/asr_core/__init__.py b/src/modules/asr_module/asr_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/asr_module/asr_core/fast_whisper/__init__.py b/src/modules/asr_module/asr_core/fast_whisper/__init__.py new file mode 100644 index 0000000..cc40188 --- /dev/null +++ b/src/modules/asr_module/asr_core/fast_whisper/__init__.py @@ -0,0 +1,16 @@ +# fast_whisper/__init__.py +from typing import Optional +from src.modules.asr_module.asr_core.fast_whisper.config import ASRConfig +from src.modules.asr_module.asr_core.fast_whisper.model_manager import ModelManager +from src.modules.asr_module.asr_core.fast_whisper.asr_interface import ASRInterface + +__version__ = "1.0.0" +__all__ = ["ASRConfig", "ModelManager", "ASRInterface"] + +def create_asr(config: Optional[ASRConfig] = None) -> ASRInterface: + """ + 快速创建ASR实例 + Args: + config: ASR配置,若为None则使用默认配置 + """ + return ASRInterface.get_instance(config) \ No newline at end of file diff --git a/src/modules/asr_module/asr_core/fast_whisper/asr_interface.py b/src/modules/asr_module/asr_core/fast_whisper/asr_interface.py new file mode 100644 index 0000000..6eeab10 --- /dev/null +++ b/src/modules/asr_module/asr_core/fast_whisper/asr_interface.py @@ -0,0 +1,184 @@ +# fast_whisper/asr_interface.py +from loguru import logger +from pathlib import Path +from typing import Tuple, Optional +import torchaudio +import torch +import numpy + +from .model_manager import ModelManager +from .config import ASRConfig +from .utils import PerformanceProfiler + +class ASRInterface: + """ + ASR接口类 - 全局单例 + - 提供wav转文本功能 + - 注入ModelManager + - 性能统计 + """ + + _instance: Optional['ASRInterface'] = None + + def __new__(cls, *args, **kwargs): + """单例模式实现""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, config: Optional[ASRConfig] = None): + # 防止重复初始化 + if hasattr(self, '_initialized') and self._initialized: + return + + self.config = config or ASRConfig() + self.model_manager = ModelManager(self.config) + self.profiler = PerformanceProfiler(self.config.enable_profiling) + + # 音频参数 + self.sample_rate = 16000 + + self._initialized = True + logger.info("🎤 ASR接口初始化完成") + + @classmethod + def get_instance(cls, config: Optional[ASRConfig] = None) -> 'ASRInterface': + """全局访问点""" + if cls._instance is None: + cls._instance = cls(config) + return cls._instance + + def transcribe_wav( + self, + wav_path: Path, + language: Optional[str] = None + ) -> Tuple[str, str, float]: + """ + WAV音频转文本(核心接口) + + Args: + wav_path: WAV文件路径 + language: 指定语言代码(如'zh'/'en'),None则自动检测 + + Returns: + (text, language, confidence) + """ + try: + # 记录开始时间 + import time + start_time = time.time() + + logger.info(f"🎵 开始识别: {wav_path.name}") + + # 执行识别... + audio = self._load_audio(wav_path) + result = self._transcribe(audio, language) + text, lang, confidence = self._parse_result(result) + + # 计算耗时 + processing_time = time.time() - start_time + logger.info( + f"✅ 识别完成: {lang} | {len(text)}字符 | 置信度:{confidence:.2f} | " + f"耗时:{processing_time:.3f}s | RTF:{processing_time/(len(audio)/self.sample_rate):.3f}" + ) + + return text, lang, confidence + + except Exception as e: + logger.error(f"❌ 识别失败 {wav_path}: {e}") + raise RuntimeError(f"Transcription failed: {e}") + + def _load_audio(self, wav_path: Path) -> numpy.ndarray: + """加载和预处理音频""" + if not wav_path.exists(): + raise FileNotFoundError(f"音频文件不存在: {wav_path}") + + # 加载音频 + waveform, sample_rate = torchaudio.load(wav_path) + + # 重采样到16kHz + if sample_rate != self.sample_rate: + resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate) + waveform = resampler(waveform) + + # 转换为单声道 + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + # 转换为numpy数组 + audio = waveform.squeeze().numpy() + + return audio + + def _transcribe(self, audio: numpy.ndarray, language: Optional[str]) -> Tuple: + """执行推理""" + model = self.model_manager.model + + # 添加模型存在性检查 + if model is None: + logger.error("ASR模型未加载,请检查模型配置和路径") + raise RuntimeError("ASR模型未加载,请检查模型配置和路径") + + # 记录时间 + import time + start_time = time.time() + + # 调用模型 + segments, info = model.transcribe( + audio, + language=language, + beam_size=self.config.beam_size, + best_of=self.config.best_of, + vad_filter=self.config.vad_filter, + ) + + # 立即执行生成器 + segments_list = list(segments) + + # 性能统计 + inference_time = time.time() - start_time + audio_duration = len(audio) / self.sample_rate + self.profiler.record(audio_duration, inference_time) + + return segments_list, info + + def _parse_result(self, result: Tuple) -> Tuple[str, str, float]: + """解析识别结果""" + segments, info = result + + # 合并所有片段 + text = " ".join([seg.text.strip() for seg in segments]) + + # 获取语言信息 + language = info.language if info else "unknown" + confidence = info.language_probability if info else 0.0 + + return text, language, confidence + + def transcribe_batch(self, wav_paths: list) -> list: + """批量识别接口""" + return [ + { + "file": str(path), + "text": result[0], + "language": result[1], + "confidence": result[2] + } + for path, result in zip(wav_paths, [ + self.transcribe_wav(Path(p)) for p in wav_paths + ]) + ] + + def health_check(self) -> dict: + """健康检查接口""" + return { + "status": "healthy" if self.model_manager.model else "unhealthy", + "device": self.config.device, + "model_loaded": self.model_manager.model is not None, + "device_info": self.model_manager.get_device_info(), + } + + def shutdown(self): + """优雅关闭""" + logger.info("🛑 关闭ASR接口...") + self.model_manager.unload() \ No newline at end of file diff --git a/src/modules/asr_module/asr_core/fast_whisper/config.py b/src/modules/asr_module/asr_core/fast_whisper/config.py new file mode 100644 index 0000000..7516926 --- /dev/null +++ b/src/modules/asr_module/asr_core/fast_whisper/config.py @@ -0,0 +1,29 @@ +# fast_whisper/config.py +from dataclasses import dataclass +from pathlib import Path +import torch + +@dataclass +class ASRConfig: + """ASR配置类""" + model_name: str = "deepdml/faster-whisper-large-v3-turbo-ct2" + device: str = "auto" + compute_type: str = "int8_float16" + cache_dir: Path = Path.home() / ".cache" / "faster_whisper" + + # 速度优化参数 + beam_size: int = 1 + best_of: int = 1 + vad_filter: bool = True + batch_size: int = 16 + + # 性能统计 + enable_profiling: bool = True + + def __post_init__(self): + if self.device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + if self.device == "cpu": + self.compute_type = "int8" + self.batch_size = 4 \ No newline at end of file diff --git a/src/modules/asr_module/asr_core/fast_whisper/model_manager.py b/src/modules/asr_module/asr_core/fast_whisper/model_manager.py new file mode 100644 index 0000000..f1bbaab --- /dev/null +++ b/src/modules/asr_module/asr_core/fast_whisper/model_manager.py @@ -0,0 +1,92 @@ +# fast_whisper/model_manager.py +import gc +from loguru import logger +from pathlib import Path +from typing import Optional +from faster_whisper import WhisperModel +import torch + +from .config import ASRConfig + + +class ModelManager: + """ + 模型管理类 + - 负责模型生命周期管理 + - 支持自定义缓存目录 + - 自动硬件适配 + """ + + def __init__(self, config: ASRConfig): + self.config = config + self._model: Optional[WhisperModel] = None + self._device_info = None + + # 确保缓存目录存在 + self.config.cache_dir.mkdir(parents=True, exist_ok=True) + + @property + def model(self) -> Optional[WhisperModel]: + """懒加载模型""" + if self._model is None: + self._load_model() + return self._model + + def _load_model(self): + """加载模型""" + logger.info(f"🚀 初始化模型: {self.config.model_name}") + logger.info(f"📦 设备: {self.config.device}, 计算类型: {self.config.compute_type}") + + try: + self._model = WhisperModel( + self.config.model_name, + device=self.config.device, + compute_type=self.config.compute_type, + download_root=str(self.config.cache_dir), + local_files_only=False, + ) + + self._device_info = { + "device": self.config.device, + "compute_type": self.config.compute_type, + "model_size": self.config.model_name.split("-")[-2] + } + + logger.info("✅ 模型加载成功") + + except Exception as e: + logger.error(f"❌ 模型加载失败: {e}") + raise RuntimeError(f"Failed to load ASR model: {e}") + + def reload(self, new_config: ASRConfig): + """热重载模型""" + logger.info("🔄 热重载模型...") + self.unload() + self.config = new_config + self._load_model() + + def unload(self): + """卸载模型释放资源""" + if self._model is not None: + logger.info("🗑️ 卸载模型...") + del self._model + self._model = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + gc.collect() + + logger.info("✅ 模型已卸载") + + def get_device_info(self) -> dict: + """获取设备信息""" + return self._device_info or {} + + def __enter__(self): + """上下文管理器支持""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """自动清理资源""" + self.unload() \ No newline at end of file diff --git a/src/modules/asr_module/asr_core/fast_whisper/utils.py b/src/modules/asr_module/asr_core/fast_whisper/utils.py new file mode 100644 index 0000000..6fe2a58 --- /dev/null +++ b/src/modules/asr_module/asr_core/fast_whisper/utils.py @@ -0,0 +1,45 @@ +# fast_whisper/utils.py +from loguru import logger +from datetime import datetime +from typing import Dict, Any + +def check_hardware() -> Dict[str, Any]: + """硬件检测""" + import torch + info = { + "cuda_available": torch.cuda.is_available(), + "device_name": "CPU", + "device_count": 0, + "compute_type": "int8" + } + + if info["cuda_available"]: + info.update({ + "device_name": torch.cuda.get_device_name(0), + "device_count": torch.cuda.device_count(), + "compute_type": "int8_float16" + }) + + return info + +class PerformanceProfiler: + """性能分析器""" + def __init__(self, enable: bool = True): + self.enable = enable + self.stats = [] + + def record(self, audio_duration: float, inference_time: float): + if not self.enable: + return + + rtf = inference_time / audio_duration if audio_duration > 0 else 0 + self.stats.append({ + "timestamp": datetime.now().isoformat(), + "rtf": rtf, + "audio_duration": audio_duration, + "inference_time": inference_time + }) + + if len(self.stats) % 10 == 0: + avg_rtf = sum(s["rtf"] for s in self.stats[-10:]) / 10 + logger.info(f"📊 最近10次平均RTF: {avg_rtf:.3f}") \ No newline at end of file diff --git a/src/modules/asr_module/client/__init__.py b/src/modules/asr_module/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/asr_module/client/asr_client.py b/src/modules/asr_module/client/asr_client.py new file mode 100644 index 0000000..041c83e --- /dev/null +++ b/src/modules/asr_module/client/asr_client.py @@ -0,0 +1,215 @@ +# asr_module/client/asr_client.py +import asyncio +import time +from pathlib import Path +from typing import Union, Optional +import aiofiles +import aiohttp +import requests +from loguru import logger +from .models import ASRResponse, ASRHealthStatus, ServiceInfo + +class ASRException(Exception): + """ASR服务调用异常""" + + def __init__(self, message: str, status_code: Optional[int] = None): + self.message = message + self.status_code = status_code + super().__init__(self.message) + +class ASRClientConfig: + """客户端配置""" + def __init__( + self, + base_url: str = "http://localhost:8000", + timeout: float = 30.0, + retry_count: int = 2, + retry_delay: float = 0.5, + ): + self.base_url = base_url.rstrip('/') + self.timeout = timeout + self.retry_count = retry_count + self.retry_delay = retry_delay + + +# 同步客户端 +class ASRClientSync: + """同步ASR客户端""" + def __init__(self, config: Optional[ASRClientConfig] = None): + self.config = config or ASRClientConfig() + self.session = requests.Session() + self.session.timeout = self.config.timeout + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.session.close() + + def _request(self, method: str, endpoint: str, **kwargs) -> dict: + """统一请求处理(带重试)""" + url = f"{self.config.base_url}{endpoint}" + + for attempt in range(self.config.retry_count + 1): + try: + response = self.session.request(method, url, **kwargs) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + if attempt < self.config.retry_count: + logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}") + time.sleep(self.config.retry_delay) + else: + logger.error(f"请求最终失败: {e}") + raise ASRException(f"API调用失败: {e}", getattr(e.response, 'status_code', None)) + + def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse: + """ + 转录音频文件 + + Args: + file_path: 音频文件路径 + + Returns: + ASRResponse对象 + + Example: + client = ASRClientSync() + result = client.transcribe_file("/path/to/audio.wav") + print(result.data.text) + """ + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"文件不存在: {file_path}") + + logger.info(f"📤 上传文件: {file_path.name}") + + with open(file_path, 'rb') as f: + files = {'file': (file_path.name, f, 'audio/wav')} + result = self._request('POST', '/transcribe', files=files) + + return ASRResponse(**result) + + def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse: + """ + 转录音频字节流 + + Args: + audio_data: 原始音频字节 + filename: 模拟文件名(用于MIME类型推断) + + Returns: + ASRResponse对象 + + Example: + with open('audio.wav', 'rb') as f: + audio_bytes = f.read() + result = client.transcribe_bytes(audio_bytes) + """ + logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)") + + files = {'file': (filename, audio_data, 'audio/wav')} + result = self._request('POST', '/transcribe', files=files) + + return ASRResponse(**result) + + def health_check(self) -> ASRHealthStatus: + """健康检查""" + result = self._request('GET', '/health') + return ASRHealthStatus(**result) + + def get_service_info(self) -> ServiceInfo: + """获取服务信息""" + result = self._request('GET', '/') + return ServiceInfo(**result) + + +# 异步客户端 +class ASRClientAsync: + """异步ASR客户端""" + def __init__(self, config: Optional[ASRClientConfig] = None): + self.config = config or ASRClientConfig() + self._session: Optional[aiohttp.ClientSession] = None + + async def __aenter__(self): + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.config.timeout) + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._session: + await self._session.close() + + async def _ensure_session(self): + if self._session is None: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.config.timeout) + ) + + async def _request(self, method: str, endpoint: str, **kwargs) -> dict: + """统一异步请求(带重试)""" + await self._ensure_session() + url = f"{self.config.base_url}{endpoint}" + + for attempt in range(self.config.retry_count + 1): + try: + async with self._session.request(method, url, **kwargs) as response: + response.raise_for_status() + return await response.json() + except aiohttp.ClientError as e: + if attempt < self.config.retry_count: + logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}") + await asyncio.sleep(self.config.retry_delay) + else: + logger.error(f"请求最终失败: {e}") + raise ASRException(f"API调用失败: {e}", getattr(e, 'status', None)) + + async def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse: + """异步转录音频文件""" + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"文件不存在: {file_path}") + + logger.info(f"📤 上传文件: {file_path.name}") + + async with aiofiles.open(file_path, 'rb') as f: + audio_data = await f.read() + + return await self.transcribe_bytes(audio_data, file_path.name) + + async def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse: + """异步转录音频字节流""" + logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)") + await self._ensure_session() # 确保session已创建 + form = aiohttp.FormData() # 创建表单数据 + form.add_field('file', audio_data, filename=filename, content_type='audio/wav') # 添加文件字段 + result = await self._request('POST', '/transcribe', data=form) # 发送POST请求 + return ASRResponse(**result) # 返回结果 + + async def health_check(self) -> ASRHealthStatus: + """异步健康检查""" + result = await self._request('GET', '/health') + return ASRHealthStatus(**result) + + async def get_service_info(self) -> ServiceInfo: + """异步获取服务信息""" + result = await self._request('GET', '/') + return ServiceInfo(**result) + +# 工厂函数 +def create_asr_client(use_async: bool = False, **config_kwargs) -> Union[ASRClientSync, ASRClientAsync]: + """ + 创建客户端工厂函数 + + Args: + use_async: 是否创建异步客户端 + **config_kwargs: ASRClientConfig参数 + + Returns: + 同步或异步客户端实例 + """ + config = ASRClientConfig(**config_kwargs) + if use_async: + return ASRClientAsync(config) + return ASRClientSync(config) \ No newline at end of file diff --git a/src/modules/asr_module/client/models.py b/src/modules/asr_module/client/models.py new file mode 100644 index 0000000..ca280fe --- /dev/null +++ b/src/modules/asr_module/client/models.py @@ -0,0 +1,30 @@ +# asr_module/client/models.py +from pydantic import BaseModel +from typing import Optional + +class ASRHealthStatus(BaseModel): + """ASR服务健康状态""" + status: str # 状态 + timestamp: str # 时间戳 + device: str # 设备 + model_loaded: bool # 模型是否加载 + +class ASRResult(BaseModel): + """语音识别结果""" + text: str # 识别结果 + language: str # 语言 + confidence: float # 置信度 + processing_time: float # 处理时间 + +class ASRResponse(BaseModel): + """统一API响应""" + success: bool + data: Optional[ASRResult] = None + error: Optional[str] = None + +class ServiceInfo(BaseModel): + """服务信息""" + message: str # 消息 + docs: str # 文档 + health: str # 健康 + transcribe: str # 识别 \ No newline at end of file diff --git a/src/modules/asr_module/readme.md b/src/modules/asr_module/readme.md new file mode 100644 index 0000000..3799f32 --- /dev/null +++ b/src/modules/asr_module/readme.md @@ -0,0 +1,12 @@ +本模块为asr模块,即语音转文本模块。 + +本模块提供了两种访问方式: +- 第一种为本地部署的func call方式,即提供函数调用的方式调用相关asr接口。 +- 第二种为http call方式,即提供http call方式调用相关asr接口。[FastAPI] + +如果你的电脑显卡支持cuda, 并且显存大小大于8G, 那么可以使用第一种方式, +否则可以使用第二种,进行云端部署。 + +两种调用方式在相同配置下,性能几乎无差别。 + +个人建议使用第二种 \ No newline at end of file diff --git a/src/modules/asr_module/start_api.py b/src/modules/asr_module/start_api.py new file mode 100644 index 0000000..2618309 --- /dev/null +++ b/src/modules/asr_module/start_api.py @@ -0,0 +1,65 @@ +# start_api.py +import uvicorn +from loguru import logger +import threading +import time + +def start_server(): + """启动 ASR API 服务""" + uvicorn.run( + "api:app", # 模块名:app实例 + host="0.0.0.0", + port=20260, + workers=1, # 单用户场景,1个worker足够 + log_level="info", + reload=False, # 生产环境关闭热重载 + access_log=True, + ) + +def first_test() -> None: + """首次启动测试""" + time.sleep(5) # 给服务器一些启动时间 + # 构造一个测试请求以验证初始化模型加载成功 + logger.info("🚀 测试模型是否加载成功...") + import requests + from pathlib import Path + url = "http://localhost:20260/transcribe" + audio_path = Path("../../../Test/test_files/test.wav") + try: + with open(audio_path, "rb") as f: + # 明确指定文件名和 MIME 类型 + files = { + "file": ( + audio_path.name, # 文件名 + f, # 文件对象 + "audio/wav" # MIME 类型 + ) + } + + response = requests.post(url, files=files) + logger.info(f"状态码: {response.status_code}") + logger.info(f"响应头: {response.headers.get('content-type')}") + if response.status_code == 200: + result = response.json() + logger.info(f"识别结果: {result['data']['text']}") + logger.info(f"识别语言: {result['data']['language']}") + logger.info(f"置信度: {result['data']['confidence']:.2f}") + logger.info(f"处理时间: {result['data']['processing_time']}s") + else: + logger.error(f"请求失败,错误响应信息: {response.text}") + logger.error("请检查模型是否正确加载或其他问题") + except Exception as e: + logger.error(f"测试过程中发生错误: {e}") + +if __name__ == "__main__": + logger.info("🚀 启动 ASR API 服务...") + + # 在后台线程启动服务器 + server_thread = threading.Thread(target=start_server, daemon=True) + server_thread.start() + + # 执行测试 + first_test() + + # 保持主线程运行 + server_thread.join() diff --git a/src/modules/device_control_module/__init__.py b/src/modules/device_control_module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/device_control_module/device_control_core/__init__.py b/src/modules/device_control_module/device_control_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/device_control_module/device_control_core/ui_tars_/__init__.py b/src/modules/device_control_module/device_control_core/ui_tars_/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_client.py b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_client.py new file mode 100644 index 0000000..230b6a5 --- /dev/null +++ b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_client.py @@ -0,0 +1,119 @@ +# ui_tars_/ui_tars_client.py +from typing import Optional +from loguru import logger +import asyncio +from src.modules.text_ai_module.text_ai_core.general_text_ai_req import ( + UnifiedLLM, + ModelConfig, + ModelProvider, + ChatMessage, +) +from pydantic import BaseModel, Field +from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_prompts import UI_TARS_SYSTEM_PROMPT + + +class UITarsClientConfig(BaseModel): + """UI-TARS 客户端配置""" + deployment_type: str = Field(default="lmstudio", description="部署类型") + base_url: str = Field(default="http://localhost:1234/v1", description="API地址") + model_name: str = Field(default="ui-tars", description="模型名称") + api_key: Optional[str] = Field(default=None, description="API密钥") + temperature: float = Field(default=0.1, ge=0.0, le=2.0) + max_tokens: int = Field(default=8192, ge=2048, le=128000) + timeout: int = Field(default=30, ge=5, le=300) + + # UI-TARS-1.5 强制输出格式 + system_prompt: str = Field( + default=UI_TARS_SYSTEM_PROMPT # 使用本项目自定义的输出格式约束 + ) + + def to_model_config(self) -> ModelConfig: + """转换为 UnifiedLLM 配置""" + # 映射部署类型到 ModelProvider + provider_map = { + "lmstudio": ModelProvider.LM_STUDIO, + "vllm": ModelProvider.CUSTOM, + "cloud": ModelProvider.OPENAI, + "ollama": ModelProvider.OLLAMA + } + provider = provider_map.get(self.deployment_type, ModelProvider.CUSTOM) + return ModelConfig( + provider=provider, + model_name=self.model_name, + api_key=self.api_key, + base_url=self.base_url, + temperature=self.temperature, + max_tokens=self.max_tokens, + timeout=self.timeout, + custom_headers={"User-Agent": "UI-TARS-Client/1.0"} + ) + +class UITarsClient: + """ + UI-TARS 通用客户端 (基于 UnifiedLLM) + 图片相关信息请直接传入相应的base64 + """ + def __init__(self, config: UITarsClientConfig): + self.config = config + # 复用 UnifiedLLM,自动处理所有部署类型 + self.llm = UnifiedLLM(config.to_model_config()) + + logger.info(f"UI-TARS 客户端初始化: {config.deployment_type} @ {config.base_url}") + logger.info(f" 模型: {config.model_name} | 温度: {config.temperature}") + + async def call_async(self, instruction: str, image_base64: str) -> str: + """异步调用 UI-TARS""" + # 构建消息 + messages = self._build_messages(instruction, image_base64) + try: + # 使用 UnifiedLLM 的异步接口 + response = await asyncio.to_thread( + self.llm.chat, + messages=messages, + streaming=False + ) + return response.content + except Exception as e: + logger.error(f"UI-TARS 调用失败: {e}") + raise + + def call_sync(self, instruction: str, image_base64: str) -> str: + """同步调用 UI-TARS""" + messages = self._build_messages(instruction, image_base64) + try: + response = self.llm.chat( + messages=messages, + streaming=False + ) + + return response.content + except Exception as e: + logger.error(f"UI-TARS 调用失败: {e}") + raise + + def stream_async(self, instruction: str, image_base64: str): + """流式调用 (异步生成器)""" + messages = self._build_messages(instruction, image_base64) + # UnifiedLLM 自动处理流式 + return self.llm.stream_chat(messages=messages) + + def _build_messages(self, instruction: str, image_base64: str) -> list: + """构建 OpenAI 格式消息""" + return [ + ChatMessage( + role="system", + content=self.config.system_prompt + ), + ChatMessage( + role="user", + content=[ + {"type": "text", "text": instruction}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{image_base64}" + } + } + ] + ) + ] diff --git a/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_prompts.py b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_prompts.py new file mode 100644 index 0000000..8c484f3 --- /dev/null +++ b/src/modules/device_control_module/device_control_core/ui_tars_/ui_tars_prompts.py @@ -0,0 +1,25 @@ +OFFICIAL_ACTION_SPACE = """## Action Space +click(point='x1 y1') - 单击坐标 +left_double(point='x1 y1') - 双击坐标 +right_single(point='x1 y1') - 右键单击 +drag(start_point='x1 y1', end_point='x2 y2') - 拖拽 +hotkey(key='ctrl c') - 快捷键(空格分隔,小写,最多3个键) +type(content='xxx') - 输入文本(用\\' \\\" \\n 转义) +scroll(point='x1 y1', direction='down or up or right or left') - 滚动 +wait() - 等待5秒 +finished() - 任务完成 + +## Output Format +Thought: [你的推理过程] +Action: [选择一个动作] +""" + +UI_TARS_SYSTEM_PROMPT = f"""You are UI-TARS-1.5, a GUI agent. Given a task and screenshot, output ONLY: + +{OFFICIAL_ACTION_SPACE} + +## Note +- Write a small plan and summarize the next action in one sentence in Thought. +- NEVER output multiple actions. +- x&y please in box center +""" \ No newline at end of file diff --git a/src/modules/device_control_module/readme.md b/src/modules/device_control_module/readme.md new file mode 100644 index 0000000..47a9652 --- /dev/null +++ b/src/modules/device_control_module/readme.md @@ -0,0 +1,10 @@ +本模块为设备控制模块接口层,此处的设备控制指的是借助AI模型进行一些设备上的自动化 +操作,支持`pc`, `android`,其他的未做过测试。 + +依赖: +`ui_tars` + +当前所使用的AI模型为`mradermacher/UI-TARS-1.5-7B-GGUF +(Q6_K or Q4_K_M)` + +未来如果有更快质量更高的AI模型本模块会为其添加支持。 \ No newline at end of file diff --git a/src/modules/text_ai_module/__init__.py b/src/modules/text_ai_module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/text_ai_module/text_ai_core/__init__.py b/src/modules/text_ai_module/text_ai_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/text_ai_module/text_ai_core/general_text_ai_req.py b/src/modules/text_ai_module/text_ai_core/general_text_ai_req.py new file mode 100644 index 0000000..2a0f89a --- /dev/null +++ b/src/modules/text_ai_module/text_ai_core/general_text_ai_req.py @@ -0,0 +1,837 @@ +""" +通用大语言模型调用框架 +支持本地模型(Ollama, LM Studio, llama.cpp)和云端模型(OpenAI, Anthropic, Google, Azure等) +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Union, Any, Iterator +import json +import os +from loguru import logger +from dataclasses import dataclass, asdict, field +from enum import Enum +import httpx +from pydantic import BaseModel, Field + +class ModelProvider(Enum): + """支持的模型提供商枚举""" + OPENAI = "openai" + ANTHROPIC = "anthropic" + GOOGLE = "google" + AZURE = "azure" + OLLAMA = "ollama" + LM_STUDIO = "lm_studio" + LLAMA_CPP = "llama_cpp" + CUSTOM = "custom" + + +@dataclass +class ModelConfig: + """模型配置类""" + provider: ModelProvider + model_name: str + api_key: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + temperature: float = 0.7 + max_tokens: int = 1024 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + timeout: int = 30 + streaming: bool = False + custom_headers: Dict[str, str] = field(default_factory=dict) + + def to_dict(self) -> Dict: + """转换为字典""" + return asdict(self) + + +@dataclass +class ChatMessage: + """聊天消息类""" + role: str # system, user, assistant + content: str + name: Optional[str] = None + + def to_dict(self) -> Dict: + """转换为字典""" + return { + "role": self.role, + "content": self.content, + **({"name": self.name} if self.name else {}) + } + + +@dataclass +class ModelResponse: + """模型响应类""" + content: str # 响应内容 + model: str # 模型名称 + usage: Optional[Dict[str, int]] = None # 使用量 + finish_reason: Optional[str] = None # 结束原因 + raw_response: Optional[Dict] = None # 原始响应 + + +def normalize_usage(raw_usage: Optional[Dict[str, Any]], provider: ModelProvider) -> Optional[Dict[str, int]]: + """ + 将不同平台的 usage 字段统一归一化为 OpenAI 标准格式 + + Args: + raw_usage: API 原始返回的 usage 数据 + provider: 模型提供商枚举 + + Returns: + 归一化后的 usage 字典,格式: + { + "prompt_tokens": int, + "completion_tokens": int, + "total_tokens": int + } + 如果无法归一化则返回 None + """ + if not raw_usage: + return None + + # 字段映射表:{provider: (input_key, output_key, total_key)} + USAGE_FIELD_MAP = { + ModelProvider.OPENAI: ("prompt_tokens", "completion_tokens", "total_tokens"), + ModelProvider.AZURE: ("prompt_tokens", "completion_tokens", "total_tokens"), + ModelProvider.LM_STUDIO: ("prompt_tokens", "completion_tokens", "total_tokens"), + ModelProvider.LLAMA_CPP: ("prompt_tokens", "completion_tokens", "total_tokens"), + ModelProvider.OLLAMA: ("prompt_eval_count", "eval_count", None), # Ollama 没有 total + ModelProvider.ANTHROPIC: ("input_tokens", "output_tokens", None), + ModelProvider.GOOGLE: ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"), + } + + input_key, output_key, total_key = USAGE_FIELD_MAP.get(provider, (None, None, None)) + + if input_key is None: + logger.warning(f"未知的 provider '{provider}',无法归一化 usage") + return None + + try: + # 提取字段值 + prompt_tokens = raw_usage.get(input_key, 0) + completion_tokens = raw_usage.get(output_key, 0) + + # 处理嵌套字典(如有些 API 的 usage 格式特殊) + if isinstance(prompt_tokens, dict): + prompt_tokens = prompt_tokens.get("value", 0) + if isinstance(completion_tokens, dict): + completion_tokens = completion_tokens.get("value", 0) + + # 转换为整数 + prompt_tokens = int(prompt_tokens) if prompt_tokens else 0 + completion_tokens = int(completion_tokens) if completion_tokens else 0 + + # 计算 total(如果 API 没提供) + if total_key and total_key in raw_usage: + total_tokens = int(raw_usage[total_key]) + else: + total_tokens = prompt_tokens + completion_tokens + + normalized = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens + } + + logger.debug(f"归一化 usage | {provider} -> OpenAI格式: {normalized}") + return normalized + + except Exception as e: + logger.error(f"归一化 usage 失败: {e} | raw_usage: {raw_usage}") + return None + +class BaseLLMClient(ABC): + """大语言模型客户端基类""" + + def __init__(self, config: ModelConfig): + self.config = config + self.client = None + self._initialize_client() + + @abstractmethod + def _initialize_client(self): + """初始化客户端""" + pass + + @abstractmethod + def chat_completion( + self, + messages: List[Union[ChatMessage, Dict]], + **kwargs + ) -> Union[ModelResponse, Iterator[ModelResponse]]: + """聊天补全""" + pass + + @abstractmethod + def completion( + self, + prompt: str, + **kwargs + ) -> Union[ModelResponse, Iterator[ModelResponse]]: + """文本补全""" + pass + + def format_messages(self, messages: List[Union[ChatMessage, Dict]]) -> List[Dict]: + """格式化消息列表""" + formatted = [] + for msg in messages: + if isinstance(msg, ChatMessage): + formatted.append(msg.to_dict()) + else: + formatted.append(msg) + return formatted + + +class OpenAIClient(BaseLLMClient): + """OpenAI客户端""" + + def _initialize_client(self): + try: + from openai import OpenAI + + api_key = self.config.api_key + if not api_key: + raise ValueError("OpenAI API密钥未设置") + + self.client = OpenAI( + api_key=api_key, + base_url=self.config.base_url, + timeout=self.config.timeout + ) + logger.info(f"OpenAI客户端初始化成功,base_url: {self.config.base_url}") + except ImportError: + logger.error("请安装openai包: pip install openai") + raise + + def chat_completion(self, messages, **kwargs): + formatted_messages = self.format_messages(messages) + + # 获取streaming参数,优先使用kwargs中的设置 + streaming = kwargs.get("streaming", self.config.streaming) + + # 合并配置 + params = { + "model": self.config.model_name, + "messages": formatted_messages, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "frequency_penalty": kwargs.get("frequency_penalty", self.config.frequency_penalty), + "presence_penalty": kwargs.get("presence_penalty", self.config.presence_penalty), + "stream": streaming, # 使用正确的streaming设置 + } + + logger.info(f"🔧 调用参数: streaming={streaming}") + + if streaming: + return self._stream_chat_completion(params) + else: + return self._normal_chat_completion(params) + + def _normal_chat_completion(self, params): + """非流式响应处理""" + logger.info("📡 发送非流式请求...") + response = self.client.chat.completions.create(**params) + raw_usage = response.usage + normalized_usage = normalize_usage( + raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage, + ModelProvider.OPENAI + ) + return ModelResponse( + content=response.choices[0].message.content, + model=response.model, + usage=normalized_usage, + finish_reason=response.choices[0].finish_reason, + raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict() + ) + + def _stream_chat_completion(self, params): + """流式响应处理""" + logger.info("📡 发送流式请求...") + response_stream = self.client.chat.completions.create(**params) + + full_content = "" + for chunk in response_stream: + if chunk.choices[0].delta.content is not None: + content_chunk = chunk.choices[0].delta.content + full_content += content_chunk + yield ModelResponse( + content=content_chunk, + model=chunk.model, + raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict() + ) + + def completion(self, prompt, **kwargs): + # OpenAI 推荐使用 chat_completion,这里保持兼容 + messages = [{"role": "user", "content": prompt}] + return self.chat_completion(messages, **kwargs) + + +class AnthropicClient(BaseLLMClient): + """Anthropic Claude客户端""" + + def _initialize_client(self): + try: + from anthropic import Anthropic + self.client = Anthropic( + api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"), + timeout=self.config.timeout + ) + except ImportError: + logger.error("请安装anthropic包: pip install anthropic") + raise + + def chat_completion(self, messages, **kwargs): + formatted_messages = self.format_messages(messages) + + # Claude 的消息格式转换 + claude_messages = [] + system_message = None + + for msg in formatted_messages: + if msg["role"] == "system": + system_message = msg["content"] + else: + claude_messages.append({ + "role": msg["role"], + "content": msg["content"] + }) + # 明确获取streaming参数 + + params = { + "model": self.config.model_name, + "messages": claude_messages, + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "temperature": kwargs.get("temperature", self.config.temperature), + "top_p": kwargs.get("top_p", self.config.top_p), + "stream": kwargs.get("streaming", self.config.streaming), + } + + if system_message: + params["system"] = system_message + + if self.config.streaming: + return self._stream_chat_completion(params) + else: + return self._normal_chat_completion(params) + + def _normal_chat_completion(self, params): + response = self.client.messages.create(**params) + # Anthropic 返回的usage格式和OpenAI不同,需要进行转换 + raw_usage = response.usage + normalized_usage = normalize_usage( + raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage, + ModelProvider.ANTHROPIC + ) + return ModelResponse( + content=response.content[0].text, + model=response.model, + usage=normalized_usage, + finish_reason=response.stop_reason, + raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict() + ) + + def _stream_chat_completion(self, params): + with self.client.messages.stream(**params) as stream: + for chunk in stream: + if chunk.type_ == "content_block_delta": + yield ModelResponse( + content=chunk.delta.text, + model=params["model"], + raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict() + ) + + def completion(self, prompt, **kwargs): + messages = [{"role": "user", "content": prompt}] + return self.chat_completion(messages, **kwargs) + + +class OllamaClient(BaseLLMClient): + """Ollama本地模型客户端""" + + def _initialize_client(self): + import httpx + self.base_url = self.config.base_url or "http://localhost:11434" + self.client = httpx.Client( + base_url=self.base_url, + timeout=self.config.timeout + ) + + def chat_completion(self, messages, **kwargs): + formatted_messages = self.format_messages(messages) + + payload = { + "model": self.config.model_name, + "messages": formatted_messages, + "options": { + "temperature": kwargs.get("temperature", self.config.temperature), + "top_p": kwargs.get("top_p", self.config.top_p), + "num_predict": kwargs.get("max_tokens", self.config.max_tokens), + }, + "stream": kwargs.get("streaming", self.config.streaming), + } + + if self.config.streaming: + return self._stream_chat_completion(payload) + else: + return self._normal_chat_completion(payload) + + def _normal_chat_completion(self, payload): + response = self.client.post("/api/chat", json=payload) + response.raise_for_status() + data = response.json() + normalized_usage = normalize_usage(data, ModelProvider.OLLAMA) + return ModelResponse( + content=data["message"]["content"], + model=data["model"], + usage=normalized_usage, + finish_reason=data.get("done_reason"), + raw_response=data + ) + + def _stream_chat_completion(self, payload): + with self.client.stream("POST", "/api/chat", json=payload) as response: + for line in response.iter_lines(): + if line.strip(): + try: + data = json.loads(line) + if "message" in data and "content" in data["message"]: + yield ModelResponse( + content=data["message"]["content"], + model=data.get("model", self.config.model_name), + raw_response=data + ) + except json.JSONDecodeError: + continue + + def completion(self, prompt, **kwargs): + payload = { + "model": self.config.model_name, + "prompt": prompt, + "options": { + "temperature": kwargs.get("temperature", self.config.temperature), + "top_p": kwargs.get("top_p", self.config.top_p), + "num_predict": kwargs.get("max_tokens", self.config.max_tokens), + }, + "stream": kwargs.get("streaming", self.config.streaming), + } + + if self.config.streaming: + return self._stream_completion(payload) + else: + return self._normal_completion(payload) + + def _normal_completion(self, payload): + response = self.client.post("/api/generate", json=payload) + response.raise_for_status() + data = response.json() + + return ModelResponse( + content=data["response"], + model=data["model"], + usage={ + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + }, + finish_reason=data.get("done_reason"), + raw_response=data + ) + + def _stream_completion(self, payload): + with self.client.stream("POST", "/api/generate", json=payload) as response: + for line in response.iter_lines(): + if line.strip(): + try: + data = json.loads(line) + if "response" in data: + yield ModelResponse( + content=data["response"], + model=data.get("model", self.config.model_name), + raw_response=data + ) + except json.JSONDecodeError: + continue + + +class GenericLLMClient(BaseLLMClient): + """通用HTTP客户端,支持LM Studio和其他兼容OpenAI API的本地模型""" + + def _initialize_client(self): + import httpx + self.base_url = self.config.base_url or "http://localhost:1234/v1" + self.client = httpx.Client( + base_url=self.base_url, + timeout=self.config.timeout, + headers=self.config.custom_headers + ) + + def chat_completion(self, messages, **kwargs): + formatted_messages = self.format_messages(messages) + + # 明确获取 streaming 参数 + streaming = kwargs.get("streaming", self.config.streaming) + + payload = { + "model": self.config.model_name, + "messages": formatted_messages, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "stream": streaming, # 使用明确的 streaming 变量 + } + + logger.info(f"GenericLLMClient 参数: streaming={streaming}") + + if streaming: + return self._stream_chat_completion(payload) + else: + return self._normal_chat_completion(payload) + + def _normal_chat_completion(self, payload): + logger.info(f"GenericLLMClient 发送非流式请求到: {self.base_url}") + response = self.client.post("/chat/completions", json=payload) + response.raise_for_status() + data = response.json() + + logger.info(f"GenericLLMClient 收到响应,模型: {data.get('model')}") + + return ModelResponse( + content=data["choices"][0]["message"]["content"], + model=data["model"], + usage=data.get("usage"), + finish_reason=data["choices"][0].get("finish_reason"), + raw_response=data + ) + + def _stream_chat_completion(self, payload): + logger.info(f"GenericLLMClient 发送流式请求到: {self.base_url}") + with self.client.stream("POST", "/chat/completions", json=payload) as response: + for line in response.iter_lines(): + if line.startswith("data: "): + chunk = line[6:] + if chunk == "[DONE]": + break + try: + data = json.loads(chunk) + if data["choices"][0]["delta"].get("content"): + yield ModelResponse( + content=data["choices"][0]["delta"]["content"], + model=data["model"], + raw_response=data + ) + except json.JSONDecodeError: + continue + + def completion(self, prompt, **kwargs): + payload = { + "model": self.config.model_name, + "prompt": prompt, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "stream": kwargs.get("streaming", self.config.streaming), + } + + streaming = kwargs.get("streaming", self.config.streaming) + + if streaming: + return self._stream_completion(payload) + else: + return self._normal_completion(payload) + + def _normal_completion(self, payload): + response = self.client.post("/completions", json=payload) + response.raise_for_status() + data = response.json() + + return ModelResponse( + content=data["choices"][0]["text"], + model=data["model"], + usage=data.get("usage"), + finish_reason=data["choices"][0].get("finish_reason"), + raw_response=data + ) + + def _stream_completion(self, payload): + with self.client.stream("POST", "/completions", json=payload) as response: + for line in response.iter_lines(): + if line.startswith("data: "): + chunk = line[6:] + if chunk == "[DONE]": + break + try: + data = json.loads(chunk) + if data["choices"][0].get("text"): + yield ModelResponse( + content=data["choices"][0]["text"], + model=data["model"], + raw_response=data + ) + except json.JSONDecodeError: + continue + +class UnifiedLLM: + """ + 统一的大语言模型调用类 + + 支持多种模型提供商: + - 云端模型:OpenAI, Anthropic, Google, Azure + - 本地模型:Ollama, LM Studio, llama.cpp + + 使用示例: + ```python + # 初始化OpenAI客户端 + config = ModelConfig( + provider=ModelProvider.OPENAI, + model_name="gpt-4", + api_key="your-api-key" + ) + llm = UnifiedLLM(config) + + # 聊天补全 + messages = [ + {"role": "system", "content": "你是一个有用的助手"}, + {"role": "user", "content": "你好!"} + ] + response = llm.chat(messages) + print(response.content) + + # 流式响应 + config.streaming = True + llm.update_config(config) + for chunk in llm.chat(messages): + print(chunk.content, end="", flush=True) + ``` + """ + + def __init__(self, config: ModelConfig): + """ + 初始化统一LLM + + Args: + config: 模型配置 + """ + self.config = config + self.client = self._create_client() + logger.info(f" UnifiedLLM 初始化完成") + logger.info(f" 提供商: {config.provider}") + logger.info(f" 模型: {config.model_name}") + logger.info(f" 流式默认: {config.streaming}") + + def _create_client(self) -> BaseLLMClient: + """根据配置创建客户端""" + provider = self.config.provider + + if provider == ModelProvider.OPENAI: + return OpenAIClient(self.config) + elif provider == ModelProvider.ANTHROPIC: + return AnthropicClient(self.config) + elif provider == ModelProvider.OLLAMA: + return OllamaClient(self.config) + elif provider in [ModelProvider.LM_STUDIO, ModelProvider.LLAMA_CPP, ModelProvider.CUSTOM]: + return GenericLLMClient(self.config) + elif provider == ModelProvider.GOOGLE: + # 这里可以扩展Google Gemini支持 + return GenericLLMClient(self.config) + elif provider == ModelProvider.AZURE: + # Azure OpenAI需要特殊处理 + return GenericLLMClient(self.config) + else: + raise ValueError(f"不支持的模型提供商: {provider}") + + def update_config(self, config: ModelConfig): + """更新配置并重新创建客户端""" + self.config = config + self.client = self._create_client() + logger.info(f"UnifiedLLM 配置已更新") + + def chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]: + """ + 聊天补全 + + Args: + messages: 消息列表 + **kwargs: 其他参数,会覆盖config中的设置 + + Returns: + ModelResponse 或 ModelResponse 的迭代器(流式模式) + """ + # 明确获取streaming参数 + streaming = kwargs.get("streaming", self.config.streaming) + logger.info(f" UnifiedLLM.chat() 调用") + logger.info(f" 消息数: {len(messages)}") + logger.info(f" streaming参数: {streaming}") + + # 调用客户端 + result = self.client.chat_completion(messages, **kwargs) + + # 类型检查(调试用) + if streaming: + if not hasattr(result, '__iter__'): + logger.warning(f"警告: streaming=True 但返回的不是迭代器") + else: + if hasattr(result, '__iter__'): + logger.warning(f"警告: streaming=False 但返回的是迭代器") + elif not isinstance(result, ModelResponse): + logger.warning(f"警告: streaming=False 但返回的不是ModelResponse") + + return result + + def complete(self, prompt: str, **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]: + """ + 文本补全 + + Args: + prompt: 提示文本 + **kwargs: 其他参数 + + Returns: + ModelResponse 或 ModelResponse 的迭代器(流式模式) + """ + streaming = kwargs.get("streaming", self.config.streaming) + logger.info(f" UnifiedLLM.complete() 调用") + logger.info(f" prompt长度: {len(prompt)}") + logger.info(f" streaming参数: {streaming}") + + return self.client.completion(prompt, **kwargs) + + def stream_chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Iterator[ModelResponse]: + """ + 流式聊天补全(便捷方法) + + Args: + messages: 消息列表 + **kwargs: 其他参数 + + Returns: + ModelResponse 的迭代器 + """ + kwargs["streaming"] = True + logger.info(f"UnifiedLLM.stream_chat() 调用") + + result = self.chat(messages, **kwargs) + + # 确保返回的是迭代器 + if not hasattr(result, '__iter__'): + raise TypeError("stream_chat 应该返回迭代器,但返回了其他类型") + + return result + + def stream_complete(self, prompt: str, **kwargs) -> Iterator[ModelResponse]: + """ + 流式文本补全(便捷方法) + + Args: + prompt: 提示文本 + **kwargs: 其他参数 + + Returns: + ModelResponse 的迭代器 + """ + kwargs["streaming"] = True + logger.info(f"UnifiedLLM.stream_complete() 调用") + + result = self.complete(prompt, **kwargs) + + # 确保返回的是迭代器 + if not hasattr(result, '__iter__'): + raise TypeError("stream_complete 应该返回迭代器,但返回了其他类型") + + return result + + +# 快捷函数 +def create_llm_client( + provider: Union[str, ModelProvider], + model_name: str, + **kwargs +) -> UnifiedLLM: + """ + 快捷创建LLM客户端 + + Args: + provider: 提供商名称或枚举 + model_name: 模型名称 + **kwargs: 其他配置参数 + + Returns: + UnifiedLLM 实例 + """ + if isinstance(provider, str): + provider = ModelProvider(provider.lower()) + + config = ModelConfig( + provider=provider, + model_name=model_name, + **kwargs + ) + + return UnifiedLLM(config) + + +# 使用示例 +def example_usage(): + """使用示例""" + + # 示例1: 使用OpenAI + print("示例1: 使用OpenAI") + openai_config = ModelConfig( + provider=ModelProvider.OPENAI, + model_name="gpt-3.5-turbo", + api_key=os.getenv("OPENAI_API_KEY"), + temperature=0.8 + ) + + try: + openai_llm = UnifiedLLM(openai_config) + messages = [ + ChatMessage(role="system", content="你是一个有用的助手"), + ChatMessage(role="user", content="请用Python写一个Hello World程序") + ] + response = openai_llm.chat(messages) + print(f"响应: {response.content[:100]}...") + except Exception as e: + print(f"OpenAI示例错误: {e}") + + # 示例2: 使用Ollama(本地模型) + print("\n示例2: 使用Ollama(本地模型)") + ollama_config = ModelConfig( + provider=ModelProvider.OLLAMA, + model_name="llama2", + base_url="http://localhost:11434", + temperature=0.7, + streaming=True # 流式响应 + ) + + try: + ollama_llm = UnifiedLLM(ollama_config) + messages = [ + {"role": "user", "content": "什么是人工智能?"} + ] + + print("流式响应:") + for chunk in ollama_llm.stream_chat(messages): + print(chunk.content, end="", flush=True) + print() + except Exception as e: + print(f"Ollama示例错误: {e}(请确保Ollama服务正在运行)") + + # 示例3: 使用快捷函数 + print("\n示例3: 使用快捷函数") + try: + llm = create_llm_client( + provider="openai", + model_name="gpt-3.5-turbo", + api_key=os.getenv("OPENAI_API_KEY"), + temperature=0.5 + ) + + response = llm.complete("天空为什么是蓝色的?") + print(f"响应: {response.content[:100]}...") + except Exception as e: + print(f"快捷函数示例错误: {e}") \ No newline at end of file diff --git a/src/modules/text_ai_module/text_ai_core/readme.md b/src/modules/text_ai_module/text_ai_core/readme.md new file mode 100644 index 0000000..cc3fcd4 --- /dev/null +++ b/src/modules/text_ai_module/text_ai_core/readme.md @@ -0,0 +1,162 @@ +# 大语言模型调用框架架构图 +general_text_ai_req.py +```mermaid +graph TB + subgraph "应用层" + A[用户应用] --> B[UnifiedLLM 统一接口] + end + + subgraph "适配器层" + B --> C{模型提供商路由} + C --> D[OpenAI 适配器] + C --> E[Anthropic 适配器] + C --> F[Ollama 适配器] + C --> G[通用HTTP适配器] + C --> H[其他适配器] + end + + subgraph "服务层" + D --> I[OpenAI API] + E --> J[Anthropic API] + F --> K[Ollama 服务] + G --> L[LM Studio] + G --> M[llama.cpp] + G --> N[其他兼容API] + end + + subgraph "配置层" + O[ModelConfig] --> C + O --> D + O --> E + O --> F + O --> G + end + + subgraph "数据流" + P[输入: 消息/提示] --> B + I --> Q[输出: ModelResponse] + J --> Q + K --> Q + L --> Q + M --> Q + N --> Q + end + + style A fill:#4567f1 + style B fill:#4567f1 + style O fill:#456748 + style D fill:#457911 + style E fill:#466bd5 + style F fill:#4567f1 + style G fill:#4567f1 +``` + +# 数据流图 +```mermaid +sequenceDiagram + participant User as 用户/应用 + participant UnifiedLLM as UnifiedLLM + participant Adapter as 适配器 + participant API as API服务 + + User->>UnifiedLLM: 调用chat()或complete() + UnifiedLLM->>UnifiedLLM: 根据配置选择适配器 + UnifiedLLM->>Adapter: 转发请求 + Adapter->>API: 发送HTTP请求/API调用 + Note over API: 处理请求并生成响应 + + alt 流式模式 + API-->>Adapter: 流式响应数据 + Adapter-->>UnifiedLLM: 流式ModelResponse + UnifiedLLM-->>User: 迭代器返回分块响应 + else 非流式模式 + API-->>Adapter: 完整响应 + Adapter-->>UnifiedLLM: ModelResponse对象 + UnifiedLLM-->>User: 完整响应内容 + end +``` +# 类关系图 +```mermaid +classDiagram + class ModelConfig { + +provider: ModelProvider + +model_name: str + +api_key: Optional[str] + +base_url: Optional[str] + +temperature: float + +max_tokens: int + +to_dict() Dict + } + + class ChatMessage { + +role: str + +content: str + +name: Optional[str] + +to_dict() Dict + } + + class ModelResponse { + +content: str + +model: str + +usage: Optional[Dict] + +finish_reason: Optional[str] + +raw_response: Optional[Dict] + } + + class BaseLLMClient { + <> + #config: ModelConfig + #client: Any + +__init__(config: ModelConfig) + +_initialize_client() + +chat_completion(messages, **kwargs)* + +completion(prompt, **kwargs)* + +format_messages(messages) List[Dict] + } + + class UnifiedLLM { + -config: ModelConfig + -client: BaseLLMClient + +__init__(config: ModelConfig) + +_create_client() BaseLLMClient + +update_config(config: ModelConfig) + +chat(messages, **kwargs) ModelResponse + +complete(prompt, **kwargs) ModelResponse + +stream_chat(messages, **kwargs) Iterator + +stream_complete(prompt, **kwargs) Iterator + } + + class OpenAIClient { + +_initialize_client() + +chat_completion(messages, **kwargs) + +completion(prompt, **kwargs) + } + + class AnthropicClient { + +_initialize_client() + +chat_completion(messages, **kwargs) + +completion(prompt, **kwargs) + } + + class OllamaClient { + +_initialize_client() + +chat_completion(messages, **kwargs) + +completion(prompt, **kwargs) + } + + class GenericLLMClient { + +_initialize_client() + +chat_completion(messages, **kwargs) + +completion(prompt, **kwargs) + } + + BaseLLMClient <|-- OpenAIClient + BaseLLMClient <|-- AnthropicClient + BaseLLMClient <|-- OllamaClient + BaseLLMClient <|-- GenericLLMClient + UnifiedLLM o-- BaseLLMClient + UnifiedLLM --> ModelConfig + BaseLLMClient --> ModelConfig + BaseLLMClient --> ChatMessage + BaseLLMClient --> ModelResponse +``` \ No newline at end of file diff --git a/src/modules/tts_module/__init__.py b/src/modules/tts_module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/tts_module/readme.md b/src/modules/tts_module/readme.md new file mode 100644 index 0000000..eacada0 --- /dev/null +++ b/src/modules/tts_module/readme.md @@ -0,0 +1,26 @@ +本模块为tts模块,即文本转语音模块。 + +本模块负责将来自AI的回复转为语音。 + +说明: +在本module当中,每个子模块的用途分别是: +- tts_core 对不同的tts的实现,提供相对统一的接口 + - gpt_sovits + 实现了gpt_sovits的tts接口封装 + + +async_audio_player.py +```mermaid +sequenceDiagram + participant TTS as GPT-SoVITS API + participant WS as WebSocket服务 + participant Buffer as 音频缓冲区 + participant Player as 音频播放器 + + TTS->>WS: 流式音频块(chunks) + WS->>Buffer: 写入队列(Queue) + Buffer->>Player: 消费PCM数据 + Player->>声卡: 实时播放 + + Note over TTS,Player: 三重缓冲 + 动态采样率检测 +``` \ No newline at end of file diff --git a/src/modules/tts_module/tts_core/__init__.py b/src/modules/tts_module/tts_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/tts_module/tts_core/async_audio_player.py b/src/modules/tts_module/tts_core/async_audio_player.py new file mode 100644 index 0000000..fd42b5a --- /dev/null +++ b/src/modules/tts_module/tts_core/async_audio_player.py @@ -0,0 +1,164 @@ +# tts_core/async_audio_player.py +import asyncio +import io +from loguru import logger +from typing import Optional +import numpy as np +import sounddevice as sd +import wave + +class AsyncAudioPlayer: + """ + 异步流式音频播放器 + - 自动检测WAV头并解析采样率 + - 使用环形缓冲区确保播放流畅 + - 支持动态音频格式切换 + """ + + def __init__(self, buffer_size: int = 10): + """ + Args: + buffer_size: 音频块缓冲数量(越大越稳定,但延迟越高) + """ + self.audio_queue = asyncio.Queue(maxsize=buffer_size) + self.sample_rate = 32000 # 默认采样率 + self.channels = 1 + self.dtype = np.float32 + self.stream: Optional[sd.OutputStream] = None + self.is_playing = False + self._first_chunk_processed = False + logger.info(f"🎵 音频播放器初始化,缓冲区大小: {buffer_size}") + + async def add_chunk(self, audio_data: bytes): + """ + 添加音频块到播放队列 + 自动处理第一个chunk(包含WAV头) + """ + try: + # 第一个chunk需要解析WAV头 + if not self._first_chunk_processed: + # 写入BytesIO以便wave模块读取 + wav_buffer = io.BytesIO(audio_data) + try: + with wave.open(wav_buffer, 'rb') as wav_file: + # 解析WAV头信息 + self.sample_rate = wav_file.getframerate() + self.channels = wav_file.getnchannels() + self.sampwidth = wav_file.getsampwidth() + + # 读取PCM数据(去掉头部) + pcm_data = wav_file.readframes(wav_file.getnframes()) + + logger.info(f"📊 解析WAV头: {self.sample_rate}Hz, {self.channels}ch, {self.sampwidth * 8}bit") + + # 转换为numpy数组 + if self.sampwidth == 2: + audio_array = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0 + elif self.sampwidth == 4: + audio_array = np.frombuffer(pcm_data, dtype=np.int32).astype(np.float32) / 2147483648.0 + else: + raise ValueError(f"不支持的采样宽度: {self.sampwidth}") + + # 转单声道(如果多声道) + if self.channels > 1: + audio_array = audio_array.reshape(-1, self.channels).mean(axis=1) + + await self.audio_queue.put(audio_array) + self._first_chunk_processed = True + + except wave.Error: + # 可能是不完整的WAV头,尝试直接播放 + logger.warning("⚠️ WAV头解析失败,尝试直接播放") + await self._play_raw(audio_data) + return + else: + # 后续chunk直接播放(RAW PCM) + await self._play_raw(audio_data) + + except Exception as e: + logger.error(f"❌ 音频块处理失败: {e}") + + async def _play_raw(self, audio_data: bytes): + """播放RAW PCM数据""" + try: + # 假设是16位PCM(最常见) + audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0 + + # 如果是多声道数据(罕见) + if len(audio_array) % self.channels == 0 and self.channels > 1: + audio_array = audio_array.reshape(-1, self.channels).mean(axis=1) + + await self.audio_queue.put(audio_array) + except Exception as e: + logger.error(f"❌ RAW音频处理失败: {e}") + + async def play_worker(self): + """后台播放任务""" + logger.info("🎧 音频播放任务启动") + + while self.is_playing or not self.audio_queue.empty(): + try: + # 从队列获取音频块(最多等待0.5秒) + audio_chunk = await asyncio.wait_for(self.audio_queue.get(), timeout=0.5) + + # 延迟初始化音频流(直到获得第一个数据块) + if self.stream is None: + logger.info(f"🔊 打开音频输出流: {self.sample_rate}Hz") + self.stream = sd.OutputStream( + samplerate=self.sample_rate, + channels=1, + dtype=self.dtype, + blocksize=1024, # 低延迟模式 + latency='low' + ) + self.stream.start() + + # 写入音频流播放 + self.stream.write(audio_chunk) + + # 标记任务完成 + self.audio_queue.task_done() + + except asyncio.TimeoutError: + continue + except Exception as e: + logger.error(f"❌ 播放任务异常: {e}") + break + + logger.info("🛑 音频播放任务结束") + + async def start(self): + """启动播放系统""" + self.is_playing = True + self._first_chunk_processed = False + self.play_task = asyncio.create_task(self.play_worker()) + + async def stop(self): + """停止播放并清理资源""" + self.is_playing = False + + # 等待播放任务结束 + if hasattr(self, 'play_task'): + await self.play_task + + # 关闭音频流 + if self.stream is not None: + self.stream.stop() + self.stream.close() + self.stream = None + + # 清空队列 + while not self.audio_queue.empty(): + try: + self.audio_queue.get_nowait() + except: + break + + logger.info("✅ 音频播放已停止") + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.stop() \ No newline at end of file diff --git a/src/modules/tts_module/tts_core/gpt_sovits/__init__.py b/src/modules/tts_module/tts_core/gpt_sovits/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/tts_module/tts_core/gpt_sovits/gpt_sovits_client.py b/src/modules/tts_module/tts_core/gpt_sovits/gpt_sovits_client.py new file mode 100644 index 0000000..96e0bf7 --- /dev/null +++ b/src/modules/tts_module/tts_core/gpt_sovits/gpt_sovits_client.py @@ -0,0 +1,378 @@ +# gpt_sovits/gpt_sovits_client.py +import asyncio +import json +from loguru import logger +from enum import Enum +from pathlib import Path +from typing import AsyncGenerator, Optional, Union, Dict, Any +from dataclasses import dataclass + +import httpx +from pydantic import BaseModel, Field, validator + + +class APIError(Exception): + """API调用异常""" + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + super().__init__(f"API Error {status_code}: {message}") + + +class StreamingMode(Enum): + """流式模式枚举""" + DISABLED = 0 # 非流式 + BEST_QUALITY = 1 # 最佳质量(慢) + MEDIUM_QUALITY = 2 # 中等质量 + FASTEST = 3 # 最快响应(较低质量) + + +class TTSConfig(BaseModel): + """TTS请求配置模型""" + text: str = Field(..., description="待合成文本") + text_lang: str = Field(..., description="文本语言: zh/en/ja/ko/cantonese") + ref_audio_path: str = Field(..., description="参考音频路径") + prompt_lang: str = Field(..., description="提示文本语言") + + # 可选参数 + prompt_text: str = Field(default="", description="参考音频提示文本") + aux_ref_audio_paths: list = Field(default_factory=list, description="辅助参考音频") + top_k: int = Field(default=5, ge=1, le=100, description="Top-K采样") + top_p: float = Field(default=1.0, ge=0.1, le=1.0, description="Top-P采样") + temperature: float = Field(default=1.0, ge=0.1, le=1.0, description="采样温度") + text_split_method: str = Field(default="cut5", description="文本分割方法") # 默认按照标点符号切分 + batch_size: int = Field(default=8, ge=1, le=200, description="批处理大小") + speed_factor: float = Field(default=1.0, ge=0.6, le=1.65, description="语速倍率") + + # 流式相关 + streaming_mode: Union[bool, int, StreamingMode] = Field(default=False, description="流式模式") + media_type: str = Field(default="wav", description="输出格式: wav/raw/ogg/aac") # 输出格式 + + # 高级参数 + repetition_penalty: float = Field(default=1.35, ge=1.0, le=2.0) # 惩罚参数 + sample_steps: int = Field(default=32, ge=10, le=100) # 采样步数 + parallel_infer: bool = Field(default=True) # 并行推理 + + @validator('text_lang', 'prompt_lang') + def validate_language(cls, v): + """验证语言代码""" + valid_langs = {'zh', 'en', 'ja', 'ko', 'cantonese'} + if v.lower() not in valid_langs: + raise ValueError(f"Unsupported language: {v}. Must be one of {valid_langs}") + return v.lower() + + @validator('media_type') + def validate_media_type(cls, v): + """验证媒体类型""" + valid_types = {'wav', 'raw', 'ogg', 'aac'} + if v not in valid_types: + raise ValueError(f"Unsupported media_type: {v}") + return v + + def build_request(self) -> Dict[str, Any]: + """构建API请求数据""" + data = self.dict(exclude_none=True) + # 处理流式模式 + if isinstance(self.streaming_mode, StreamingMode): + data['streaming_mode'] = self.streaming_mode.value + return data + + +@dataclass +class AudioResponse: + """音频响应包装类""" + audio_data: bytes + sample_rate: int = 32000 + + def save(self, path: Union[str, Path]) -> None: + """保存音频文件""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, 'wb') as f: + f.write(self.audio_data) + logger.info(f"Audio saved to {path}, size: {len(self.audio_data)} bytes") + + +class GPTSoVITSClient: + """ + GPT-SoVITS异步API客户端 + + 完整支持所有TTS功能: + - 文本合成(流式/非流式) + - 模型切换(GPT/SoVITS) + - 参考音频设置 + - 服务器控制 + """ + + def __init__(self, host: str = "127.0.0.1", port: int = 9880, debug: bool = False): + """ + 初始化客户端 + + Args: + host: API服务器地址 + port: API端口 + debug: 是否开启调试模式 + """ + self.base_url = f"http://{host}:{port}" + self.client = httpx.AsyncClient( + base_url=self.base_url, + timeout=httpx.Timeout(30.0, connect=5.0) + ) + self.debug_mode = debug + logger.info(f"GPT-SoVITS Client initialized: {self.base_url}") + + async def __aenter__(self) -> "GPTSoVITSClient": + """异步上下文管理器入口""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + await self.close() + + async def close(self): + """关闭HTTP连接""" + await self.client.aclose() + logger.info("Client connection closed") + + def _log_debug(self, message: str, **kwargs): + """调试日志""" + if self.debug_mode: + logger.debug(f"{message} | {kwargs}") + + async def _handle_response(self, response: httpx.Response) -> Dict[str, Any]: + """统一响应处理""" + if response.status_code == 200: + content_type = response.headers.get('content-type', '') + if 'application/json' in content_type: + return response.json() + return {"status": "success", "content": response.content} + else: + try: + error_data = response.json() + raise APIError(response.status_code, error_data.get('message', 'Unknown error')) + except json.JSONDecodeError: + raise APIError(response.status_code, response.text) + + # 核心TTS接口 + async def tts( + self, + text: str, + ref_audio_path: str, + text_lang: str = "zh", + prompt_lang: str = "zh", + streaming_mode: StreamingMode = StreamingMode.DISABLED, # 默认禁用流式 + media_type: str = "wav", + **kwargs + ) -> Union[AudioResponse, AsyncGenerator[AudioResponse, None]]: + """ + 文本转语音(支持流式) + + Args: + text: 待合成文本 + ref_audio_path: 参考音频路径(服务器本地路径或URL) + text_lang: 文本语言 + prompt_lang: 提示语言 + streaming_mode: 流式模式 + media_type: 输出格式 + **kwargs: 其他TTS参数 + + Returns: + 非流式: AudioResponse对象 + 流式: AsyncGenerator[AudioResponse, None]异步生成器 + + Example: + # 非流式 + audio = await client.tts("你好", "ref.wav") + + # 流式 + async for chunk in client.tts("你好", "ref.wav", streaming_mode=StreamingMode.FASTEST): + process(chunk.audio_data) + """ + config = TTSConfig( + text=text, + ref_audio_path=ref_audio_path, + text_lang=text_lang, + prompt_lang=prompt_lang, + streaming_mode=streaming_mode, + media_type=media_type, + **kwargs + ) + + self._log_debug("TTS Request", config=config.dict()) + + if streaming_mode == StreamingMode.DISABLED: + # 非流式模式 + response = await self.client.post("/tts", json=config.build_request()) + if response.status_code != 200: + raise APIError(response.status_code, await response.text()) + + return AudioResponse( + audio_data=response.content, + sample_rate=32000 # 默认采样率 + ) + else: + # 流式模式 + config.parallel_infer = False # 强制关闭并行推理,避免与流式冲突 + config.batch_size = 1 # 流式下batch_size必须为1 + async def stream_generator(): + async with self.client.stream( + "POST", "/tts", + json=config.build_request(), + timeout=httpx.Timeout(60.0) # 流式需要更长超时 + ) as response: + if response.status_code != 200: + raise APIError(response.status_code, await response.aread()) + + async for chunk in response.aiter_bytes(): + if chunk: + yield AudioResponse(audio_data=chunk) + + return stream_generator() + + # 模型管理接口 + async def set_gpt_weights(self, weights_path: str) -> bool: + """ + 切换GPT模型权重 + + Args: + weights_path: 权重文件路径(服务器本地路径) + + Returns: + bool: 是否成功 + + Example: + await client.set_gpt_weights("models/s1bert.ckpt") + """ + if not weights_path: + raise ValueError("weights_path cannot be empty") + + params = {"weights_path": weights_path} + response = await self.client.get("/set_gpt_weights", params=params) + result = await self._handle_response(response) + + logger.info(f"GPT weights switched to: {weights_path}") + return True + + async def set_sovits_weights(self, weights_path: str) -> bool: + """ + 切换SoVITS模型权重 + + Args: + weights_path: 权重文件路径(服务器本地路径) + + Returns: + bool: 是否成功 + """ + if not weights_path: + raise ValueError("weights_path cannot be empty") + + params = {"weights_path": weights_path} + response = await self.client.get("/set_sovits_weights", params=params) + await self._handle_response(response) + + logger.info(f"SoVITS weights switched to: {weights_path}") + return True + + # 参考音频管理 + async def set_refer_audio( + self, + audio_source: Union[str, Path, bytes], + audio_name: Optional[str] = None + ) -> bool: + """ + 设置参考音频(支持多种输入方式) + + Args: + audio_source: 音频文件路径(str/Path)或音频数据(bytes) + audio_name: 音频文件名(仅bytes输入时需要) + + Returns: + bool: 是否成功 + + Example: + # 方式1: 服务器本地文件 + await client.set_refer_audio("/path/to/audio.wav") + + # 方式2: 上传音频数据 + with open("audio.wav", "rb") as f: + await client.set_refer_audio(f.read(), "audio.wav") + """ + if isinstance(audio_source, (str, Path)): + # GET方式:服务器本地路径 + params = {"refer_audio_path": str(audio_source)} + response = await self.client.get("/set_refer_audio", params=params) + await self._handle_response(response) + logger.info(f"Reference audio set: {audio_source}") + else: + # POST方式:上传音频数据 + if not audio_name: + raise ValueError("audio_name is required when uploading bytes") + + files = {"audio_file": (audio_name, audio_source, "audio/wav")} + response = await self.client.post("/set_refer_audio", files=files) + await self._handle_response(response) + logger.info(f"Reference audio uploaded: {audio_name}") + + return True + + # 服务器控制 + async def control_command(self, command: str) -> bool: + """ + 发送控制命令 + + Args: + command: 命令类型 - "restart" 或 "exit" + + Returns: + bool: 是否成功 + + Warning: + "exit"命令会终止API服务器进程! + """ + if command not in ["restart", "exit"]: + raise ValueError("Command must be 'restart' or 'exit'") + + response = await self.client.get("/control", params={"command": command}) + await self._handle_response(response) + + logger.warning(f"Control command executed: {command}") + return True + + # 高级快捷方法 + async def get_server_info(self) -> Dict[str, Any]: + """获取服务器状态信息""" + # 通过调用根路径或自定义health接口 + try: + response = await self.client.get("/") + return {"status": "online", "detail": response.text} + except Exception as e: + return {"status": "error", "detail": str(e)} + + async def batch_tts( + self, + texts: list[str], + ref_audio_path: str, + **kwargs + ) -> list[AudioResponse]: + """ + 批量TTS合成 + + Args: + texts: 文本列表 + ref_audio_path: 参考音频 + **kwargs: 其他TTS参数 + + Returns: + list[AudioResponse]: 音频响应列表 + """ + tasks = [ + self.tts(text, ref_audio_path, **kwargs) + for text in texts + ] + return await asyncio.gather(*tasks) + + +# 异步上下文管理器辅助函数 +async def create_client(*args, **kwargs) -> GPTSoVITSClient: + """快速创建客户端实例""" + return GPTSoVITSClient(*args, **kwargs) \ No newline at end of file diff --git a/src/modules/websocket_base_module/__init__.py b/src/modules/websocket_base_module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/websocket_base_module/dto/__init__.py b/src/modules/websocket_base_module/dto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/websocket_base_module/dto/dto_base.py b/src/modules/websocket_base_module/dto/dto_base.py new file mode 100644 index 0000000..c9c3e63 --- /dev/null +++ b/src/modules/websocket_base_module/dto/dto_base.py @@ -0,0 +1,27 @@ +# dto/dto_base.py +from abc import ABC +from typing import Callable, Coroutine, Any, Dict +from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer + +class MessageDTO(ABC): + """DTO基类""" + + def __init__( + self, + ws_server : WebSocketServer, # WebSocketServer单例 + ): + # 保存服务器实例(用于发送) + self.ws_server = ws_server + + # 便捷属性,DTO层直接调用 + @property + def send_binary(self): + return self.ws_server.send_binary + + @property + def send_text(self): + return self.ws_server.send_text + + @property + def send_json(self): + return self.ws_server.send_json diff --git a/src/modules/websocket_base_module/dto/dto_templates/__init__.py b/src/modules/websocket_base_module/dto/dto_templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/websocket_base_module/dto/dto_templates/audio_data_dto.py b/src/modules/websocket_base_module/dto/dto_templates/audio_data_dto.py new file mode 100644 index 0000000..e4683d7 --- /dev/null +++ b/src/modules/websocket_base_module/dto/dto_templates/audio_data_dto.py @@ -0,0 +1,95 @@ +from pydantic import Field, BaseModel +import base64 +from datetime import datetime, timezone +class AudioDataTransferObject(BaseModel): + """ + 音频数据传输对象 + 该对象被用于服务端与客户端的音频数据交互 + 同时支持流式与非流式的音频数据 + 同时收发对等(通过Owner标识) + """ + Owner: str = Field(default="server", description="音频数据的拥有者(server or client)") + isStream: bool = Field(default=False, description="音频数据是否为流式数据") + isStart: bool = Field(default=False, description="音频数据是否开始(流式时有效)") + isEnd: bool = Field(default=False, description="音频数据是否结束(流式时有效)") + sequence: int = Field(default=0, description="音频数据块序列号(流式时有效)") + data: bytes = Field(default=b"", description="音频数据,流式时为分块数据,base64编码") + sampleRate: int = Field(default=32000, description="音频采样率") + channelCount: int = Field(default=1, description="音频通道数") + bitDepth: int = Field(default=16, description="音频采样位数") + duration: float = Field(default=0.0, description="音频时长") + text: str = Field(default="", description="音频对应的文本") + + def set_dto_data(self, **kwargs) -> "AudioDataTransferObject": + """链式更新数据(Pydantic 风格)""" + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + return self + + def to_json(self) -> dict: + """ + 将DTO对象转换为可序列化字典 + 返回的json数据的格式: + { + "type": "audio_data", + "timestamp": 1672531200.0, + "data": { + "Owner": "server", + ...... + } + } + """ + # model_dump() 是 Pydantic v2 的序列化方法 + payload = self.model_dump() # 获取所有模型字段 + payload["data"] = base64.b64encode(payload["data"]).decode() # base64编码 + # 构造嵌套结构 + return { + "type": "audio_data", + "timestamp": datetime.now(timezone.utc).timestamp(), + "data": payload # 音频字段嵌套 + } + + @classmethod + def from_json(cls, json_data: dict) -> "AudioDataTransferObject": + """ + 从JSON数据创建DTO对象 + 传入的json数据格式: + { + "Owner": "server", + ...... + } + """ + payload = json_data + # 解码 base64 (内层 data 字段在传输时是 base64 字符串) -> bytes + if "data" in payload and isinstance(payload["data"], str): + payload["data"] = base64.b64decode(payload["data"]) + # 构造对象 Pydantic 自动忽略 type/timestamp + return cls.model_validate(payload) + + +# 测试代码 +if __name__ == "__main__": + # 模拟音频数据 + import os + + test_audio = os.urandom(1024) # 随机生成1KB音频数据 + + # 创建DTO + audio = AudioDataTransferObject( + data=test_audio, + sequence=1, + isStream=True, + sampleRate=44100, + duration=0.5 + ) + + # 序列化 → JSON + json_dict = audio.to_json() + print(f"序列化后 data 长度: {len(json_dict['data'])}") # ~1368 字符 + print(f"data 前30字符: {json_dict['data'][:30]}...") + + # 反序列化 → DTO + restored = AudioDataTransferObject.from_json(json_dict) + print(f"反序列化后 data 长度: {len(restored.data)}") # 1024 bytes + print(f"数据一致: {restored.data == test_audio}") # True diff --git a/src/modules/websocket_base_module/dto/dto_templates/auto_agent_data_dto.py b/src/modules/websocket_base_module/dto/dto_templates/auto_agent_data_dto.py new file mode 100644 index 0000000..a992326 --- /dev/null +++ b/src/modules/websocket_base_module/dto/dto_templates/auto_agent_data_dto.py @@ -0,0 +1,73 @@ +from pydantic import Field, BaseModel +from datetime import datetime, timezone +class AutoAgentDataTransferObject(BaseModel): + """ + 自动化agent数据传输对象 + 该对象被用于服务端向客户端发送控制信息 + """ + Action: str = Field(default="", description="自动化动作名称") + x1: int = Field(default=-1, description="鼠标起始位置x1") + y1: int = Field(default=-1, description="鼠标起始位置y1") + x2: int = Field(default=-1, description="鼠标结束位置x2") + y2: int = Field(default=-1, description="鼠标结束位置y2") + key: str = Field(default="", description="快捷键") + content: str = Field(default="", description="输入文本内容") + direction: str = Field(default="", description="滚动方向") + + def set_dto_data(self, **kwargs) -> "AutoAgentDataTransferObject": + """链式更新数据(Pydantic 风格)""" + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + return self + + def to_json(self) -> dict: + """ + 将DTO对象转换为可序列化字典 + 返回的json数据的格式: + { + "type": "auto_agent", + "timestamp": 1672531200.0, + "data": { + "Action": "自动化agent返回的相应的自动化动作名称", + "x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理", + "y1": "某个操作的y1", + "x2": "某个操作的x2", + "y2": "某个操作的y2", + "key": "快捷键,若当前操作没有该字段信息,此处内容为空即可", + "content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可", + "direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可" + } + } + """ + # model_dump() 是 Pydantic v2 的序列化方法 + payload = self.model_dump() # 获取所有模型字段 + # 构造嵌套结构 + return { + "type": "auto_agent", + "timestamp": datetime.now(timezone.utc).timestamp(), + "data": payload # 字段嵌套 + } + + @classmethod + def from_json(cls, json_data: dict) -> "AutoAgentDataTransferObject": + """ + 从JSON数据创建DTO对象 + 传入的json数据格式: + { + "type": "auto_agent", + "Action": "自动化agent返回的相应的自动化动作名称", + "x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理", + "y1": "某个操作的y1", + "x2": "某个操作的x2", + "y2": "某个操作的y2", + "key": "快捷键,若当前操作没有该字段信息,此处内容为空即可", + "content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可", + "direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可" + } + """ + payload = json_data + payload.pop("type", None) # 移除多余字段 + # 构造对象 Pydantic 自动忽略 type/timestamp + return cls.model_validate(payload) + diff --git a/src/modules/websocket_base_module/dto/dto_templates/data_dto_base.py b/src/modules/websocket_base_module/dto/dto_templates/data_dto_base.py new file mode 100644 index 0000000..df32b46 --- /dev/null +++ b/src/modules/websocket_base_module/dto/dto_templates/data_dto_base.py @@ -0,0 +1,37 @@ +class BaseDataTransferObject: + """ + DTO基类 + 子类按需重写成员函数即可 + """ + def __init__(self): + pass + def to_json(self): + """ + 将DTO对象转换为JSON + """ + pass + def from_json(self, json_data): + """ + 从JSON数据中创建DTO对象 + """ + pass + def to_binary(self): + """ + 将DTO对象转换为二进制 + """ + pass + def from_binary(self, binary_data): + """ + 从二进制数据中创建DTO对象 + """ + pass + def to_text(self): + """ + 将DTO对象转换为文本 + """ + pass + def from_text(self, text_data): + """ + 从文本数据中创建DTO对象 + """ + pass \ No newline at end of file diff --git a/src/modules/websocket_base_module/dto/dto_templates/screenshot_data_dto.py b/src/modules/websocket_base_module/dto/dto_templates/screenshot_data_dto.py new file mode 100644 index 0000000..e49ea33 --- /dev/null +++ b/src/modules/websocket_base_module/dto/dto_templates/screenshot_data_dto.py @@ -0,0 +1,72 @@ +from pydantic import Field, BaseModel +from datetime import datetime, timezone +class ScreenShotDataTransferObject(BaseModel): + """ + 服务端向客户端请求实时截图的数据传输对象 + 服务端与客户端收发对等(通过Owner标识) + 客户端收到这个type的包,就会自动对当前设备的画面进行截图 + """ + Owner: str = Field(default="server", description="数据的拥有者(server or client)") + isSuccess: bool = Field(default=False, description="是否截图成功") + RealTimeScreenShot: str = Field(default="", description="客户端设备的实时截图数据(base64)") + Width: int = Field(default=1920, description="截图的宽度") + Height: int = Field(default=1080, description="截图的高度") + DescribeInfo: str = Field(default="", description="设备的描述信息(告知模型以做出更加准确的判断)") + LLMResponse: str = Field(default="", description="LLM的响应结果(由服务端发送时携带)") + + + def set_dto_data(self, **kwargs) -> "ScreenShotDataTransferObject": + """链式更新数据(Pydantic 风格)""" + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + return self + + def to_json(self) -> dict: + """ + 将DTO对象转换为可序列化字典 + 返回的json数据的格式: + { + "type": "screenshot_data", + "timestamp": 1672531200.0, + "data": { + "Owner": "数据的拥有者(server or client)", + "isSuccess": "是否截图成功(true or false)" + "RealTimeScreenShot": "客户端设备的实时截图数据(base64)", + "Width": "截图的宽度", + "Height": "截图的高度", + "DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)", + "LLMResponse": "LLM的响应结果(由服务端发送时携带)" + } + } + """ + # model_dump() 是 Pydantic v2 的序列化方法 + payload = self.model_dump() # 获取所有模型字段 + # 构造嵌套结构 + return { + "type": "screenshot_data", + "timestamp": datetime.now(timezone.utc).timestamp(), + "data": payload # 字段嵌套 + } + + @classmethod + def from_json(cls, json_data: dict) -> "ScreenShotDataTransferObject": + """ + 从JSON数据创建DTO对象 + 传入的json数据格式: + { + "Owner": "数据的拥有者(server or client)", + "isSuccess": "是否截图成功(true or false)", + "RealTimeScreenShot": "客户端设备的实时截图数据(base64)", + "Width": "截图的宽度 非必要字段", + "Height": "截图的高度 非必要字段", + "DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断) 非必要字段", + "LLMResponse": "LLM的响应结果(由服务端发送时携带) 必要字段" + } + """ + payload = json_data + payload.pop("type", None) # 移除多余字段 + payload.pop("timestamp", None) # 移除多余字段 + # 构造对象 Pydantic 自动忽略 type/timestamp + return cls.model_validate(payload) + diff --git a/src/modules/websocket_base_module/dto/second_dtos.py b/src/modules/websocket_base_module/dto/second_dtos.py new file mode 100644 index 0000000..822cd85 --- /dev/null +++ b/src/modules/websocket_base_module/dto/second_dtos.py @@ -0,0 +1,95 @@ +# dto/second_dtos.py +import asyncio +from typing import Callable, Optional, Dict, Any, List, Coroutine +from src.modules.websocket_base_module.dto.dto_base import MessageDTO +from loguru import logger + +""" +二级分发器,因为没有信号与槽机制,因此使用观察者模式替代 +""" +def singleton(cls): # 单例 + _instance = None + _lock = asyncio.Lock() + async def get_instance(*args, **kwargs): + nonlocal _instance + if _instance is None: + async with _lock: + if _instance is None: + _instance = cls(*args, **kwargs) + return _instance + + cls.get_instance = get_instance + return cls +# 类型别名 +ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]] +@singleton +class JsonDTO(MessageDTO): + """针对json消息的二级分发""" + """ + 明确业务json格式: + { + "type" : "xxx", + "timestamp" : "95153...", + "data" : "{根据业务的不同,有不同的内容}" + } + """ + # 因为不同的数据块的json以type字段进行包装,根据type进行正确的数据分发 + def __init__(self, ws_server): + super().__init__(ws_server) + self.receivers : Dict[str, List[ReceiveCallback]] = { + 'audio_data' : [], # 音频数据 + 'screenshot_data' : [] # 截图数据 + } + # 注册json处理callback function + ws_server.register_receiver('json', self._handle_json) + logger.info("[JsonDTO] JSON分发器已注册") + + def register_receiver(self, types : str, callback : ReceiveCallback): + """注册二次分发业务接收函数,供业务DTO调用""" + if types in self.receivers: + self.receivers[types].append(callback) + logger.debug(f"[JsonDTO] 已注册 {types} 接收器,当前共 {len(self.receivers[types])} 个") + else: + raise ValueError(f"[JsonDTO] 不支持的分发类型: {types}") + + def unregister_receiver(self, types : str, callback : ReceiveCallback): + """注销二次分发业务接收函数""" + if callback in self.receivers[types]: + self.receivers[types].remove(callback) + logger.debug(f"[JsonDTO] 已注销 {types} 接收器") + + async def _handle_json(self, data: dict): + """JSON消息处理""" + logger.info(f"[JsonDTO] 收到消息") + logger.debug(f'[JsonDTO] 当前消息时间戳: {data["timestamp"]}') + # 根据类型进行自动分发 + await self._dispatch(data.get("type"), data["data"]) + + async def _dispatch(self, types : str, data : dict): + """二次分发json数据到相应的接收函数当中""" + callbacks = self.receivers[types] # 获取相关types的所有观察者 + if not callbacks: + logger.info(f"[JsonDTO] 无 {types} 接收器,消息被忽略") + return + # 并发执行所有回调 + tasks = [callback(data) for callback in callbacks] + await asyncio.gather(*tasks, return_exceptions=True) +async def get_json_dto_instance(ws_server) -> JsonDTO: + return JsonDTO(ws_server) + + +class EchoDTO(MessageDTO): + """回声DTO:只处理文本消息 测试用""" + + def __init__(self, ws_server): + super().__init__(ws_server) + # 注册文本接收函数 + ws_server.register_receiver('text', self._handle_text) + logger.info("[EchoDTO] 文本接收器已注册") + + async def _handle_text(self, message: str): + """文本消息处理""" + logger.info(f"[EchoDTO] 收到文本: {message}") + + # 业务逻辑 + await self.send_text(f"Echo: {message}") diff --git a/src/modules/websocket_base_module/dto/third_dtos.py b/src/modules/websocket_base_module/dto/third_dtos.py new file mode 100644 index 0000000..ae3cb11 --- /dev/null +++ b/src/modules/websocket_base_module/dto/third_dtos.py @@ -0,0 +1,281 @@ +import asyncio +from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject +from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject +from src.modules.websocket_base_module.dto.second_dtos import JsonDTO +from loguru import logger +from typing import Callable, List, Optional, Coroutine + +class AudioDataDTO: + """音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)""" + def __init__(self, json_dto : JsonDTO): + json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数 + logger.info("[AudioDataDTO] 音频接收业务已注册") + self.json_dto = json_dto + self.audio_data = AudioDataTransferObject() # 音频数据对象 + # 业务回调列表,延续观察者模式 + self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = [] + # 最新音频缓存 支持同步查询 + self._latest_audio: Optional[AudioDataTransferObject] = None + # 流式缓冲区 用于大段音频流 + self._stream_buffer: List[AudioDataTransferObject] = [] + logger.info("[AudioDataDTO] 业务接口已初始化") + + async def _handle_audio_data(self, data: dict): + """处理音频数据""" + """ + 音频数据json格式: + { + 'Owner': 'server', + 'isStream': False, + 'isStart': False, + 'isEnd': False, + 'sequence': 0, + 'data': '', + 'sampleRate': 16000, + 'channelCount': 1, + 'bitDepth': 16, + 'duration': 0.0, + 'text': '' + } + """ + logger.debug(f"[AudioDataDTO] 收到音频数据") + # 将dict反序列化到DTO对象 + self.audio_data = AudioDataTransferObject.from_json(data) + # 缓存最新数据 + self._latest_audio = self.audio_data + # 如果是流式数据,加入缓冲区 + if self.audio_data.isStream: + self._stream_buffer.append(self.audio_data) + if self.audio_data.isEnd: + logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)} 块") + # 通知所有注册的回调 + await self._notify_callbacks() + + # 业务发送接口 + async def send_audio_data(self, data: AudioDataTransferObject) -> None: + """ + 发送音频数据 + + Args: + data: 音频数据DTO + """ + await self.send_audio( + Owner=data.Owner, + is_stream=data.isStream, + is_start=data.isStart, + is_end=data.isEnd, + sequence=data.sequence, + data=data.data, + sampleRate=data.sampleRate, + channelCount=data.channelCount, + bitDepth=data.bitDepth, + duration=data.duration, + text=data.text + ) + + async def send_audio( + self, + data: bytes, + is_stream: bool = False, + is_start: bool = False, + is_end: bool = False, + sequence: int = 0, + **audio_meta + ) -> None: + """ + 业务层发送音频的便捷接口 + + Args: + data: 原始音频字节 + is_stream: 是否为流式数据 + is_start: 流式数据开始标记 + is_end: 流式数据结束标记 + sequence: 数据块序号 + **audio_meta: 其他音频参数(sampleRate, channelCount等) + """ + # 填充音频数据到DTO + self.audio_data.set_dto_data( + Owner="server" or audio_meta.get('Owner', "server"), + isStream=is_stream, + isStart=is_start, + isEnd=is_end, + sequence=sequence, + data=data, + sampleRate=audio_meta.get('sampleRate', 16000), + channelCount=audio_meta.get('channelCount', 1), + bitDepth=audio_meta.get('bitDepth', 16), + duration=audio_meta.get('duration', 0.0), + text=audio_meta.get('text', "") + ) + # 序列化为JSON并发送 自动处理base64和type字段 + json_message = self.audio_data.to_json() + await self.json_dto.send_json(json_message) + logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes") + + # 业务接收接口 + def register_audio_callback( + self, + callback: Callable[[AudioDataTransferObject], Coroutine] + ) -> None: + """ + 业务注册接收回调 + + 使用示例: + async def my_audio_handler(audio_dto: AudioDataTransferObject): + print(f"收到音频: {len(audio_dto.data)} bytes") + + audio_dto.register_audio_callback(my_audio_handler) + """ + self._audio_callbacks.append(callback) + logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)} 个") + + def unregister_audio_callback(self, callback) -> None: + """注销业务回调""" + if callback in self._audio_callbacks: + self._audio_callbacks.remove(callback) + logger.debug("业务音频回调已注销") + + def get_latest_audio(self) -> Optional[AudioDataTransferObject]: + """ + 同步获取最新音频数据(轮询模式) + + Returns: + 最新接收到的音频DTO,如果没有则为 None + """ + return self._latest_audio + + def clear_stream_buffer(self) -> None: + """清空流式缓冲区""" + self._stream_buffer.clear() + logger.debug("流式音频缓冲区已清空") + + # 内部通知机制 + async def _notify_callbacks(self) -> None: + """通知所有业务回调""" + if not self._audio_callbacks: + logger.warning("无业务回调,音频数据未处理") + return + tasks = [callback(self.audio_data) for callback in self._audio_callbacks] + await asyncio.gather(*tasks, return_exceptions=True) + logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调") + # 异步迭代器 流式时使用 + def __aiter__(self): + """支持 async for 循环接收流式音频""" + return self + async def __anext__(self) -> AudioDataTransferObject: + """异步迭代器协议""" + pass + +class ScreenShotDataDTO: + """截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)""" + def __init__(self, json_dto : JsonDTO): + json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数 + logger.info("[ScreenShotDataDTO] 截屏接收业务已注册") + self.json_dto = json_dto + self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象 + # 业务回调列表,延续观察者模式 + self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = [] + # 最新截屏数据缓存 支持同步查询 + self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None + logger.info("[ScreenShotDataDTO] 业务接口已初始化") + + async def _handle_screenshot_data(self, data: dict): + """处理截屏数据""" + """ + 截屏数据json格式: + { + "Owner": "数据的拥有者(server or client)", + "isSuccess": "是否截图成功(true or false)" + "RealTimeScreenShot": "客户端设备的实时截图数据(base64)", + "Width": "截图的宽度", + "Height": "截图的高度", + "DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)" + } + """ + logger.debug(f"[ScreenShotDataDTO] 收到截屏数据") + # 将dict反序列化到DTO对象 + self.screenshot_data = ScreenShotDataTransferObject.from_json(data) + # 缓存最新数据 + self._latest_screenshot = self.screenshot_data + # 通知所有注册的回调 + await self._notify_callbacks() + + # 业务发送接口 + async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None: + """ + 发送音频数据 + + Args: + data: 音频数据DTO + """ + await self.send_screenshot( + Owner=data.Owner, + isSuccess=data.isSuccess, + RealTimeScreenShot=data.RealTimeScreenShot, + Width=data.Width, + Height=data.Height, + DescribeInfo=data.DescribeInfo, + LLMResponse=data.LLMResponse + ) + + async def send_screenshot( + self, + **screenshot_meta + ) -> None: + """ + 业务层发送音频的便捷接口 + 一般来说,作为发送请求方,不需要填充任何数据 + + Args: + **screenshot_meta: 截屏数据元信息 + """ + # 填充音频数据到DTO + self.screenshot_data.set_dto_data( + Owner="server" or screenshot_meta.get('Owner', "server"), + isSuccess=screenshot_meta.get('isSuccess', False), + RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""), + Width=screenshot_meta.get('Width', 1920), + Height=screenshot_meta.get('Height', 1080), + DescribeInfo=screenshot_meta.get('DescribeInfo', False), + LLMResponse=screenshot_meta.get('LLMResponse', "") + ) + # 序列化为JSON并发送 自动处理base64和type字段 + json_message = self.screenshot_data.to_json() + await self.json_dto.send_json(json_message) + logger.info(f"截屏包已发送") + + # 业务接收接口 + def register_screenshot_callback( + self, + callback: Callable[[ScreenShotDataTransferObject], Coroutine] + ) -> None: + """ + 业务注册接收回调 + """ + self._screenshot_callbacks.append(callback) + logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)} 个") + + def unregister_screenshot_callback(self, callback) -> None: + """注销业务回调""" + if callback in self._screenshot_callbacks: + self._screenshot_callbacks.remove(callback) + logger.debug("业务截屏回调已注销") + + def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]: + """ + 同步获取最新截屏数据(轮询模式) + + Returns: + 最新接收到的截屏数据DTO,如果没有则为 None + """ + return self._latest_screenshot + + # 内部通知机制 + async def _notify_callbacks(self) -> None: + """通知所有业务回调""" + if not self._screenshot_callbacks: + logger.warning("无业务回调,截屏数据未处理") + return + tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks] + await asyncio.gather(*tasks, return_exceptions=True) + logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调") diff --git a/src/modules/websocket_base_module/readme.md b/src/modules/websocket_base_module/readme.md new file mode 100644 index 0000000..6b9724b --- /dev/null +++ b/src/modules/websocket_base_module/readme.md @@ -0,0 +1,56 @@ +### 说明 +在本module当中,每个子模块的用途分别是: +- dto + - dto_templates + 服务端与客户端交互所使用到的数据传输对象 + - dto_base.py / xxx_dtos.py + 实际的DTO业务 +- websocket_core + - websocket核心,承载了底层核心的网络收发业务 + + +### 模块架构 +```mermaid +graph TB + subgraph "Client" + C[WebSocket Client] + end + + subgraph "Core Layer" + WS[WebSocketServer
单例] + WS -->|持有| WSP[WebSocketServerProtocol
_websocket] + WS -->|管理| RCV[ receivers: Dict
binary/text/json ] + end + + subgraph "DTO Base Layer" + MDTO[MessageDTO
抽象基类] + MDTO -->|注入| MDF[ send_binary
send_text
send_json ] + end + + subgraph "Secondary Dispatcher" + JDTO[JsonDTO
单例] + JDTO -->|继承| MDTO + JDTO -->|维护| MAP[ receivers: Dict
audio_data/... ] + JDTO -->|注册到| WS + end + + subgraph "Business DTO" + ADTO[AudioDataDTO......
业务实现] + ADTO -->|持有引用| JDTO + ADTO -->|使用| ATO[AudioDataTransferObject
Pydantic模型] + end + + C <-->|websocket连接| WSP + + WS -->|分发消息| JDTO + JDTO -->|二次分发| ADTO + + ADTO -->|发送响应| JDTO + JDTO -->|调用| MDF + MDF -->|经由| WS + WS -->|发送至| C + + style WS fill:#64f,stroke:#333,stroke-width:2px + style JDTO fill:#569,stroke:#333,stroke-width:2px + style ADTO fill:#38f,stroke:#333,stroke-width:2px +``` \ No newline at end of file diff --git a/src/modules/websocket_base_module/websocket_core/__init__.py b/src/modules/websocket_base_module/websocket_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/websocket_base_module/websocket_core/core_ws_server.py b/src/modules/websocket_base_module/websocket_core/core_ws_server.py new file mode 100644 index 0000000..093ffd7 --- /dev/null +++ b/src/modules/websocket_base_module/websocket_core/core_ws_server.py @@ -0,0 +1,172 @@ +# websocket_core/core_ws_server.py +import asyncio +import json +from typing import Callable, Optional, Dict, Any, List, Coroutine +from websockets.asyncio.server import serve, ServerConnection +from websockets.exceptions import ConnectionClosed +from loguru import logger + +# 类型别名 +ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]] +class WebSocketServer: + """WebSocket服务端核心模块(单例 + 单客户端) + 只管理一个客户端连接,DTO层注册接收函数,服务端只负责分发。 + """ + _instance: Optional["WebSocketServer"] = None + _lock = asyncio.Lock() + + def __new__(cls): + """同步单例(__init__ 可以是 async)""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + # 防止重复初始化 + if self._initialized: + return + self._websocket: Optional[ServerConnection] = None + self._receivers: Dict[str, List[ReceiveCallback]] = { + 'binary': [], + 'text': [], + 'json': [] + } + self._connected_event = asyncio.Event() + self._initialized = True + logger.info("WebSocketServer 单客户端分发器已初始化") + + # DTO层注册接口 + def register_receiver(self, msg_type: str, callback: ReceiveCallback) -> None: + """注册接收函数(供DTO层调用) + + Args: + msg_type: 消息类型(binary/text/json) + callback: 接收回调,签名为 (data) -> coroutine + """ + if msg_type in self._receivers: + self._receivers[msg_type].append(callback) + logger.debug(f"已注册 {msg_type} 接收器,当前 {msg_type} 类型接收器共 {len(self._receivers[msg_type])} 个") + else: + raise ValueError(f"不支持的消息类型: {msg_type}") + + def unregister_receiver(self, msg_type: str, callback: ReceiveCallback) -> None: + """注销接收函数""" + if callback in self._receivers[msg_type]: + self._receivers[msg_type].remove(callback) + logger.debug(f"已注销 {msg_type} 接收器") + + # 发送接口(供DTO层调用) + async def send_binary(self, data: bytes): + """发送二进制数据(唯一客户端)""" + if self._websocket: + await self._websocket.send(data) + logger.trace(f"二进制数据已发送 (长度: {len(data)} bytes)") + else: + raise RuntimeError("客户端未连接") + + async def send_text(self, data: str): + """发送文本数据(唯一客户端)""" + if self._websocket: + await self._websocket.send(data) + logger.trace(f"文本数据已发送 (长度: {len(data)} chars)") + else: + raise RuntimeError("客户端未连接") + + async def send_json(self, data: Dict[str, Any]): + """发送JSON数据(唯一客户端)""" + if self._websocket: + try: + logger.debug(f"准备发送JSON数据: {data}") + message = json.dumps(data) + await self._websocket.send(message) + logger.trace(f"JSON数据已发送: {data}") + except Exception as e: + logger.error(f"JSON数据发送失败: {e}") + raise + else: + raise RuntimeError("客户端未连接") + + # 等待连接 + async def wait_for_client(self): + """阻塞等待客户端连接""" + await self._connected_event.wait() + logger.info("客户端已就绪") + + # 内部消息循环 + async def _handle_client(self, websocket: ServerConnection): + """处理唯一客户端的消息循环""" + self._websocket = websocket + self._connected_event.set() + client_info = f"{websocket.remote_address}" if hasattr(websocket, 'remote_address') else "unknown" + logger.info(f"客户端已连接: {client_info}") + + try: + async for message in websocket: + # 根据消息类型分发到所有注册的接收函数 + if isinstance(message, bytes): + await self._dispatch('binary', message) + + elif isinstance(message, str): + # 优先尝试JSON解析 + json_dispatched = False + try: + data = json.loads(message) + await self._dispatch('json', data) + json_dispatched = True + except json.JSONDecodeError: + pass + + # 如果不是JSON或没有json接收器,尝试text + if not json_dispatched: + await self._dispatch('text', message) + + else: + logger.warning(f"未知消息类型: {type(message)}") + logger.info("客户端连接已正常关闭") + except ConnectionClosed as e: + logger.info(f"客户端连接已关闭: {e}") + except Exception as e: + logger.error(f"处理客户端时发生错误: {e}") + finally: + self._websocket = None + self._connected_event.clear() + + async def _dispatch(self, msg_type: str, data: Any): + """分发消息到所有注册的接收函数""" + callbacks = self._receivers[msg_type] + if not callbacks: + logger.warning(f"无 {msg_type} 接收器,消息被忽略") + return + + # 并发执行所有回调 + tasks = [callback(data) for callback in callbacks] + await asyncio.gather(*tasks, return_exceptions=True) + + # 启动服务器 + async def run(self, host: str = "localhost", port: int = 8765, max_msg_size: int = 50*1024*1025): + """启动WebSocket服务器(阻塞)""" + logger.info(f"WebSocket服务器启动中... 等待客户端连接 ws://{host}:{port}") + + async def handler_wrapper(connection): + logger.info(f"新连接请求: {connection.remote_address}") + await self._handle_client(connection) + + try: + async with serve( + handler=handler_wrapper, # 使用 wrapper 适配签名 + host=host, + port=port, + max_size=max_msg_size, + ): + logger.success(f"WebSocket服务器已启动在 ws://{host}:{port}") + await asyncio.Future() # 永久阻塞,保持服务器运行 + except Exception as e: + logger.error(f"WebSocket服务器启动失败: {e}") + raise + + +async def get_ws_server() -> WebSocketServer: + """全局单例获取函数(线程安全)""" + server = WebSocketServer() # __new__ 保证单例 + return server \ No newline at end of file diff --git a/src/server_core/core.py b/src/server_core/core.py new file mode 100644 index 0000000..b231b74 --- /dev/null +++ b/src/server_core/core.py @@ -0,0 +1,306 @@ +# server_core/core.py + +""" +统一业务,对外提供启动接口 +业务数据流向: + Yosuga[User Audio Info Struct] ->(WebSocket) Yosuga_server[asr_module] -> Text + Yosuga_server[Come from Yosuga Audio ASR Text] ->(Func call) Yosuga_server[llm_core] -> Ins and Text + + Yosuga_server[Come from llm_core Text]->(WebSocket) Yosuga + Yosuga_server[Come from llm_core Text]->(Func call) Yosuga_server[TTS] -> Audio Data + Yosuga_server[Audio Data] ->(WebSocket) Yosuga + + Yosuga_embedded[Devices Control Info] ->(WebSocket) Yosuga_server[embedded_core]-> Ins + Yosuga_embedded[Devices Control Info] ->(Serial) Yosuga[SerialManager] -> ForWord To Yosuga_server + Yosuga[Come from embedded Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins + + UI_TARS[Mind and x&y Info] ->(Func call) Yosuga_server[llm_core] -> Ins + Yosuga_server[Come from UI_TARS Ins] ->(WebSocket) Yosuga + + Yosuga[Live2D Control Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins + Yosuga_server[Live2D Control Ins] ->(Websocket) Yosuga + + Yosuga_server[agent memory] ->(Func call) Yosuga_server[Memory Uint] +""" +import asyncio +from typing import Optional, List, Dict, Any +from loguru import logger +from src.modules.websocket_base_module.dto.third_dtos import ( + AudioDataDTO, AudioDataTransferObject, + ScreenShotDataDTO, ScreenShotDataTransferObject +) +from src.modules.websocket_base_module.dto.second_dtos import JsonDTO, get_json_dto_instance +from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer, get_ws_server + +from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig + +from src.modules.asr_module.client.asr_client import create_asr_client, ASRClientConfig, ASRClientAsync + +from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import StreamingMode, TTSConfig, GPTSoVITSClient + +from src.server_core.llm_core.llm_core import ( + LLMCoreConfig, ModelConfig, + YosugaLLMCore, ModelProvider, + LLMCoreAnalysisBase, + YosugaAudioResponseData, YosugaUITARSResponseData, + YosugaUITARSRequestData +) + +from src.modules.websocket_base_module.dto.dto_templates.auto_agent_data_dto import AutoAgentDataTransferObject +from src.config.config import cfg + + +class YosugaServerCore: + """ + 异步单例类 + """ + + _instance: Optional["YosugaServerCore"] = None + _lock = asyncio.Lock() + + # 组合必要的工具类 + ws_server: WebSocketServer + json_dto: JsonDTO + audio_dto: AudioDataDTO + screenshot_dto: ScreenShotDataDTO + + asr_client: ASRClientAsync # 异步asr client + tts_client: GPTSoVITSClient # tts client + auto_agent_client: UITarsClient # GUI自动化agent + + llm_core: YosugaLLMCore = None # llm core + + @classmethod + async def get_instance(cls) -> "YosugaServerCore": + """异步单例工厂""" + if cls._instance is None: + async with cls._lock: + if cls._instance is None: + logger.info("Initializing YosugaServerCore...") + # 创建实例 + instance = cls.__new__(cls) + + # 按依赖顺序初始化数据分发器 + instance.ws_server = await get_ws_server() + instance.json_dto = await get_json_dto_instance(instance.ws_server) + instance.audio_dto = AudioDataDTO(instance.json_dto) # 音频分发器 + instance.audio_dto.register_audio_callback(instance._handle_audio_data) # 注册音频处理函数 + instance.screenshot_dto = ScreenShotDataDTO(instance.json_dto) # 截图分发器 + instance.screenshot_dto.register_screenshot_callback(instance._handle_screenshot_data) # 注册截图处理函数 + + instance.asr_client = create_asr_client(use_async=True, base_url=cfg.asr.url) + instance.tts_client = GPTSoVITSClient(host=cfg.tts.host, port=cfg.tts.port, debug=True) + # 切换GPT_SoVITS模型 + await instance.tts_client.set_gpt_weights(cfg.tts.gpt_model_name) + await instance.tts_client.set_sovits_weights(cfg.tts.sovits_model_name) + + instance.auto_agent_client = UITarsClient(UITarsClientConfig( + deployment_type=cfg.auto_agent.deployment_type, + base_url=cfg.auto_agent.base_url, + model_name=cfg.auto_agent.model_name, + temperature=cfg.auto_agent.temperature, + max_tokens=cfg.auto_agent.max_tokens + )) + + instance.llm_core = YosugaLLMCore( + model_config=ModelConfig( # TODO 同上 + provider=ModelProvider.OPENAI, + model_name=cfg.ai.model_name, + base_url=cfg.ai.base_url, + api_key=cfg.ai.api_key, + temperature=cfg.ai.temperature, + max_tokens=cfg.ai.max_tokens + ), + core_config=LLMCoreConfig( # TODO 同上 + max_context_tokens=cfg.llm_core.max_context_tokens, + enable_history=cfg.llm_core.enable_history, + role_setting=cfg.llm_core.role_character, + language=cfg.llm_core.language, # 回复使用语言 + auto_dispatch=True, + dispatch_async=True # 启用异步分发 + ) + ) + instance.register_llm_core_analysis() # 注册解析器 + instance.register_llm_core_action() # 注册分发器 + instance.llm_core.register_overflow_handler(instance._handle_overflow_logger) # 注册上下文溢出处理回调 + + cls._instance = instance + logger.success("YosugaServerCore initialized") + return cls._instance + + def register_llm_core_action(self): + """ + 注册llm_core的分发器 + """ + if self.llm_core is None: + raise Exception("LLMCore is not initialized") + self.llm_core.register_action_handler("audio_text", self._handle_audio_response, is_async=True) + self.llm_core.register_action_handler("auto_agent", self._handle_auto_agent, is_async=True) + self.llm_core.register_action_handler("call_auto_agent", self._handle_call_auto_agent, is_async=True) + self.llm_core.set_fallback_handler(self._handle_fallback) + + def register_llm_core_analysis(self): + """ + 注册llm_core的输出解析器 + """ + if self.llm_core is None: + raise Exception("LLMCore is not initialized") + self.llm_core.register_analysis_model(YosugaAudioResponseData) + self.llm_core.register_analysis_model(YosugaUITARSResponseData) + self.llm_core.register_analysis_model(YosugaUITARSRequestData) + + def _handle_overflow_logger(self, history: List[Any], metadata: Dict[str, Any]): + """上下文溢出记录,仅打印日志""" + print(f" 上下文溢出!") + print(f" 模型: {metadata['model']}") + print(f" 消息数: {metadata['message_count']}") + print(f" Token: {metadata['estimated_tokens']}/{metadata['limit']}") + print(f" 即将遗忘 {len(history) // 2} 条旧消息") + + async def _handle_audio_data(self, audio_data: AudioDataTransferObject): + """ + 音频数据接收call back + Yosuga_server只有接受到每次这个audio数据才会跑一次 + """ + logger.info("Received audio data") + # 在此处客户端发送的音频数据必定不是流式数据(考虑客户端发送数据给服务端往往是在本地的,速度极快) + # 将音频数据发送给asr转换成文本信息,音频数据格式为wav + # TODO: 考虑在此处做一个简单的vad检测,如果客户端发送的音频是静音的,则不把请求发给llm_core + asr_response = await self.asr_client.transcribe_bytes(audio_data.data) + if not asr_response.success: + logger.error(f"ASR failed: {asr_response.error}") + asr_result = asr_response.data # 获取asr结果 + # 将asr结果发送给llm_core进行处理 + llm_result = await self.llm_core.interact( + user_input={ # 构造用户输入信息 + "text": asr_result.text, + "confidence": asr_result.confidence + } + ) # llm_core会自动进行处理并通过执行器异步返回各种相关的数据 + + async def _handle_screenshot_data(self, screenshot_data: ScreenShotDataTransferObject): + """ + 屏幕截图数据接收call back + 将llm_core的回复封装后提交给auto_agent模块,获得自动化agent的返回之后再返回给llm_core + """ + logger.info(f"Received screenshot data {len(screenshot_data.RealTimeScreenShot)}") + if not screenshot_data.isSuccess: # 如果客户端截图失败 + logger.error("Screenshot failed") + return # 直接提前结束回调,不向llm_core发送结果 + # TODO 对于设备描述信息(screenshot_data.DescribeInfo),考虑加入到auto_agent的输入中,增强识别准确率 + # 构造请求 异步调用 + logger.debug(f"screenshot_data.LLMResponse(来自llm_core向auto_agent的输入): {screenshot_data.LLMResponse}") + logger.debug(f"客户端设备信息: {screenshot_data.DescribeInfo}") + auto_agent_response: str = await self.auto_agent_client.call_async(screenshot_data.LLMResponse, + screenshot_data.RealTimeScreenShot) + logger.debug(f"auto_agent_response(auto_agent原生返回结果): {auto_agent_response}") + # 将auto_agent的返回结果发送给llm_core + await self.llm_core.interact( + user_input={ # 构造auto_agent输入信息 + "auto_agent": auto_agent_response + } + ) + + async def _handle_audio_response(self, data: YosugaAudioResponseData): + """ + llm_core异步处理器:语音回复 + 将llm_core的回复封装后提交给tts模块,调用tts模块中的流式返回,并将流式frame返回给Yosuga客户端 + """ + if data.type == "audio_text": + logger.info("Handling audio response") + try: + # 使用最快模式流式输出 + chunk_count = 0 + async for chunk in await self.tts_client.tts( + text=data.response_text, + ref_audio_path="uploaded_audio/test_voice.wav", # TODO 需要替换成config或者后续设计情感系统 + text_lang="ja", + prompt_lang="ja", + prompt_text="もう!こんなところで何やってるんだよ!", # 参考语音的真实文本 + streaming_mode=StreamingMode.FASTEST, # 模式3:快速流式 + media_type="wav" + ): + chunk_count += 1 + print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes") + if chunk_count == 1: # 如果是第一个音频块 + # 构造音频首包发送给客户端 + await self.audio_dto.send_audio_data( + AudioDataTransferObject( + data=chunk.audio_data, + isStream=True, + isStart=True, + sequence=chunk_count, + isEnd=False, + text=data.response_text + ) + ) + else: # 如果不是第一个音频块,则发送中间包给客户端 + await self.audio_dto.send_audio_data( + AudioDataTransferObject( + data=chunk.audio_data, + isStream=True, + isStart=False, + sequence=chunk_count, + isEnd=False, + text=data.response_text + ) + ) + print(f"✅ 流式TTS完成!共{chunk_count}个音频块") + # 构造音频尾包发送给客户端(虚假的音频数据) + await self.audio_dto.send_audio_data( + AudioDataTransferObject( + data=b"0", + isStream=True, + isStart=False, + sequence=chunk_count + 1, + isEnd=True, + text=data.response_text + ) + ) + except Exception as e: + print(f"❌ 流式错误: {e}") + return {"status": "success", "executed": data.response_text} + return None + + async def _handle_auto_agent(self, data: YosugaUITARSResponseData): + """ + llm_core异步处理器:处理自动化操作 + 将llm_core的回复封装后提交给Yosuga客户端,由客户端进行执行相关的GUI自动化操作 + """ + # 构造并发送回复数据 + await self.json_dto.send_json( + AutoAgentDataTransferObject.from_json(data.to_dict()).to_json() + ) + return {"status": "success", "executed": data.Action} + + async def _handle_call_auto_agent(self, data: YosugaUITARSRequestData): + """ + llm_core异步处理器:处理llm_core调用auto_agent需求 + 向客户端请求当前界面的截图,请求成功后由_handle_screenshot_data函数完成剩下的任务 + """ + if data.type == "call_auto_agent": + logger.info("LLM Calling auto agent") + # 向客户端请求当前界面的截图的base64编码 加入llm回复的信息到截图请求DTO当中 方便_handle_screenshot_data构造请求 + await self.screenshot_dto.send_screenshot_data(ScreenShotDataTransferObject(LLMResponse=data.llm_translation)) + return {"status": "success", "executed": data.type} + + def _handle_fallback(self, data: LLMCoreAnalysisBase): + """ + llm_core同步处理器:回退处理器 + """ + logger.debug(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}") + + async def run(self): + """启动服务器""" + logger.info("Yosuga Server Websocket Core 启动中...") + await self.ws_server.run(host="0.0.0.0") + + +# 使用方式 +async def main(): + core = await YosugaServerCore.get_instance() + await core.run() + + +if __name__ == '__main__': + asyncio.run(main()) \ No newline at end of file diff --git a/src/server_core/emotion_core/readme.md b/src/server_core/emotion_core/readme.md new file mode 100644 index 0000000..80d15cf --- /dev/null +++ b/src/server_core/emotion_core/readme.md @@ -0,0 +1,2 @@ +### 本模块为Yosuga_server的AI输出语音情感管理模块 + diff --git a/src/server_core/llm_core/llm_core.py b/src/server_core/llm_core/llm_core.py new file mode 100644 index 0000000..9390dab --- /dev/null +++ b/src/server_core/llm_core/llm_core.py @@ -0,0 +1,564 @@ +# llm_core/llm_core.py + +""" +Yosuga Server LLM 核心控制模块 +负责整合Prompt管理、模型调用、输出解析、上下文记忆管理以及生命周期维护。 +作为系统的"大脑",对外提供统一的高级交互接口。 +""" +import asyncio +import json +import time +from typing import List, Dict, Any, Optional, Callable, Union, Type +from loguru import logger +from pydantic import BaseModel, Field +# 引入已有的模块 +from src.modules.text_ai_module.text_ai_core.general_text_ai_req import ( + UnifiedLLM, ModelConfig, ChatMessage, ModelResponse, ModelProvider +) +from src.server_core.llm_core.llm_core_analysis import ( + LLMCoreAnalysisManager, LLMCoreAnalysisBase, YosugaAudioResponseData, YosugaUITARSResponseData, YosugaUITARSRequestData +) +from src.server_core.llm_core.llm_core_dispatcher import LLMCoreActionDispatcher +from src.server_core.llm_core.llm_core_prompt_manager import ( + LLMCorePromptManager, LLMCorePromptBase, + YosugaAudioASRText, YosugaUITARS, YosugaLive2DControl +) +from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH +from src.server_core.llm_core.llm_core_token import TokenManager, TokenUsage + +# 类型定义:上下文溢出回调函数签名 +# 参数1: 溢出的历史记录列表 +# 参数2: 相关的元数据 +ContextOverflowCallback = Callable[[List[ChatMessage], Dict[str, Any]], None] + +class LLMCoreConfig(BaseModel): + """LLM Core 运行时配置""" + max_context_tokens: int = Field(default=2000, description="上下文最大Token数(估算值,不包括System prompt),超出触发重置") + enable_history: bool = Field(default=True, description="是否启用历史对话记忆") + language: str = Field(default="zh_CN", description="回复语言设定") + role_setting: str = Field(default="", description="llm角色扮演") + auto_dispatch: bool = Field(default=True, description="是否自动分发到动作处理器") + dispatch_async: bool = Field(default=False, description="分发是否使用异步模式") + memory: str = Field(default="", description="llm记忆") + system_state_table: str = Field(default="", description="Yosuga系统状态表") + + + +class YosugaLLMCore: + """ + Yosuga 服务端 LLM 核心控制器 + """ + + def __init__(self, model_config: ModelConfig, core_config: Optional[LLMCoreConfig] = None): + """ + 初始化 LLM Core + + Args: + model_config: 底层大模型的连接配置 + core_config: 核心业务逻辑配置(上下文限制等) + """ + self.model_config = model_config + self.core_config = core_config or LLMCoreConfig() + # 初始化模型客户端 (UnifiedLLM) + self.llm_client: UnifiedLLM = UnifiedLLM(self.model_config) + # 初始化 TokenManager + self.token_manager = TokenManager(self.model_config.model_name) + # 初始化Prompt管理器 + self.prompt_manager = LLMCorePromptManager() + self._register_default_prompts() + # 上下文记忆存储 + self._history: List[ChatMessage] = [] # 注意:history不包含system prompt,只包含 user/assistant 消息 + # 上下文溢出回调列表 + self._overflow_callbacks: List[ContextOverflowCallback] = [] + logger.info( + f"YosugaLLMCore 初始化完成 | " + f"模型: {model_config.model_name} | " + f"提供商: {model_config.provider}" + ) + logger.info( + f"上下文限制: {self.core_config.max_context_tokens} tokens | " + f"自动分发: {self.core_config.auto_dispatch}" + ) + + def _register_default_prompts(self): + """注册默认的业务Prompt模块""" + self.prompt_manager.register(YosugaAudioASRText()) + self.prompt_manager.register(YosugaUITARS()) + # self.prompt_manager.register(YosugaLive2DControl()) # TODO + logger.info(f"默认Prompt模块注册完成 | 数量: {self.prompt_manager.get_registry_size()}") + + # 系统提示词管理 + def get_system_prompt(self) -> str: + """ + 动态构建当前的 System Prompt + 根据 prompt_manager 中注册的模块实时生成 + """ + return YOSUGA_SYSTEM_PROMPT_SCH.format( + InputInfo=self.prompt_manager.describe_input(), # 不变的内容 + OutputInfo=self.prompt_manager.describe_output(), # 不变的内容 + RoleSetting=self.core_config.role_setting, # 角色扮演,可热重载 + Language=self.core_config.language, # 回复语言,可热重载 + Memory=self.core_config.memory, # 记忆,可热重载,请求前更新 + SystemStateTable=self.core_config.system_state_table# 系统状态表,可热重载,每次请求都会更新 + ) + + def register_prompt_module(self, prompt_module: LLMCorePromptBase): + """运行时注册新的 Prompt 业务模块""" + self.prompt_manager.register(prompt_module) + logger.info(f"动态注册 Prompt 模块: {prompt_module.type()}") + + def register_analysis_model(self, model_class: Type[LLMCoreAnalysisBase]) -> None: + """ + 注册LLM输出解析模型 + + Args: + model_class: 继承自 LLMCoreAnalysisBase 的数据模型类 + """ + LLMCoreAnalysisManager.register(model_class) + logger.info(f"注册解析模型: {model_class.type_()}") + + def register_action_handler( + self, + type_id: str, + handler: Callable, + is_async: bool = False + ) -> None: + """ + 注册动作处理器 + + Args: + type_id: 与解析模型对应的类型标识 + handler: 处理函数(同步或异步) + is_async: 是否为异步处理器 + """ + if is_async: + LLMCoreActionDispatcher.register_async(type_id, handler) + else: + LLMCoreActionDispatcher.register(type_id, handler) + + def set_fallback_handler(self, handler: Callable) -> None: + """设置未注册类型的回退处理器""" + LLMCoreActionDispatcher.set_fallback(handler) + + # 核心交互接口 + async def interact( + self, + user_input: Union[str, Dict[str, Any]], # 输入(纯文本或结构化字典) + past_memories: Optional[str] = "", # 记忆模块检索的相关历史记忆 + system_state_table: Optional[str] = "", # 系统状态表,可热重载,每次请求都会更新 + auto_dispatch: Optional[bool] = True, # 是否自动分发,默认为启用 + dispatch_async: Optional[bool] = True # 是否异步分发,默认为启用 + ) -> Dict[str, List[Any]]: + """ + 核心交互方法:处理输入 -> 组装上下文 -> 调用LLM -> 解析输出 + + Args: + user_input: 用户输入(纯文本或结构化字典) + past_memories: 记忆模块检索的相关历史记忆 + system_state_table: Yosuga系统状态表(方便llm理解当前系统状态) + auto_dispatch: 是否自动分发(覆盖默认配置) + dispatch_async: 是否异步分发(覆盖默认配置) + + Returns: + 分发执行结果字典: + { + "success": [{"type": "...", "output": ..., "index": 0}], + "failed": [{"type": "...", "error": "...", "index": 1}], + "skipped": ["unknown_type"] + } + + Raises: + ValueError: LLM调用或解析失败 + RuntimeError: 分发执行致命错误 + """ + # 输入预处理 + input_content = ( + json.dumps(user_input, ensure_ascii=False) + if isinstance(user_input, dict) + else user_input + ) + logger.info(f"用户输入内容经过json处理后为: {input_content}") + + # 检查并维护上下文 + self._maintain_context_limit() + + # 构建本次请求的消息链 + messages = self._build_request_messages(input_content, past_memories, system_state_table) + + # 调用 LLM + try: + llm_response = self._call_llm(messages) + if llm_response.usage: # 打印请求消耗的Token数量 + self.token_manager.record_api_usage(llm_response.usage) + logger.info(self.token_manager.format_usage_log(source="API")) # 使用 TokenManager 格式化日志 + except Exception as e: + logger.error(f"LLM调用失败: {e}") + raise + + # 统一解析(总是返回列表) + try: + parsed_results = LLMCoreAnalysisManager.parse(llm_response.content) + logger.success(f"解析成功 | 对象数: {len(parsed_results)}") + except Exception as e: + logger.error(f"输出解析失败: {e}") + raise + + # 更新历史记忆 + if self.core_config.enable_history: + self._add_to_history("user", input_content) + self._add_to_history("assistant", llm_response.content) + + # 分发执行 + should_dispatch = auto_dispatch if auto_dispatch is not None else self.core_config.auto_dispatch + if should_dispatch and parsed_results: + is_async = dispatch_async if dispatch_async is not None else self.core_config.dispatch_async + if is_async: + # 直接 await 异步执行,如果调用 asyncio.run() 就会因为重复创建事件循环导致报错 + logger.debug("使用异步模式分发动作") + return await LLMCoreActionDispatcher._execute_async(parsed_results) + else: + # 同步模式保持原样 + return LLMCoreActionDispatcher.execute(parsed_results, run_async=False) + # 不分发则返回原始解析结果(极少用) + return {"success": [{"type": obj.type, "output": obj, "index": i} + for i, obj in enumerate(parsed_results)], + "failed": [], "skipped": []} + + def _build_request_messages( + self, + current_input: str, + memories: str, + system_state_table: str + ) -> List[ChatMessage]: + """ + 构建完整的LLM消息链 + 结构: + 1. System Prompt 构造,包括记忆注入等信息 + 2. 历史上下文 + 3. 当前用户输入 + """ + # 构建memory与其他信息 + self.core_config.memory = memories + # 构建系统状态表 + self.core_config.system_state_table = system_state_table + # 构造 System Prompt + messages = [ChatMessage(role="system", content=self.get_system_prompt())] + + # 历史上下文 + if self.core_config.enable_history: + messages.extend(self._history) # 分开每条消息追加 + # 当前用户输入 + messages.append(ChatMessage(role="user", content=current_input)) + return messages + + def _call_llm(self, messages: List[ChatMessage]) -> ModelResponse: + """执行LLM调用(非流式)""" + logger.debug(f"请求LLM | 消息数: {len(messages)}") + # 预估token使用情况 TODO: (调试用) + estimated = self.token_manager.estimate_chat_tokens( + system_prompt=self.get_system_prompt(), + history=self._history, + current_input=messages[-1].content if messages else "" + ) + logger.debug(self.token_manager.format_usage_log(estimated.to_dict(), source="MANUAL")) + + # 强制非流式(结构化输出需要完整JSON) + response: ModelResponse = self.llm_client.chat( + messages, + streaming=False, + temperature=self.model_config.temperature + ) + # 记录 API 返回的 usage + if response.usage: + self.token_manager.record_api_usage(response.usage) + logger.debug(f"LLM响应长度: {len(response.content)}") + return response + + # 上下文与记忆管理 + def _add_to_history(self, role: str, content: str): + """添加消息到历史""" + self._history.append(ChatMessage(role=role, content=content)) + + def _maintain_context_limit(self) ->None: + """ + 检查上下文是否超出限制 + 如果超出,触发溢出回调,并将当前上下文导出,然后清空最近50%的 + """ + if not self.core_config.enable_history: # 若未启用历史对话记忆 + return + # 使用 TokenManager 获取上下文占用(手动计算) + context_usage = self.token_manager.get_context_usage(self._history) + current_usage = context_usage.total_tokens + + # 使用 TokenManager 格式化日志 + logger.debug( + self.token_manager.format_usage_log(context_usage.to_dict(), source="CONTEXT") + ) + limit = self.core_config.max_context_tokens + # 使用 TokenManager 判断是否接近限制 + if self.token_manager.is_token_limit_approaching(current_usage, limit, threshold=0.85): + logger.warning( + f"上下文接近限制: {current_usage}/{limit} " + f"({current_usage / limit:.1%})" + ) + if current_usage <= limit: + return + # 否则就是溢出 + logger.critical( + f"上下文溢出!| {current_usage}/{limit} tokens " + f"({current_usage / limit:.1%}) | 消息: {len(self._history)}" + ) + # 执行所有注册的溢出处理器 + self._trigger_overflow_callbacks() + + # 智能清理:保留最近50%消息 + keep_messages = max(1, len(self._history) // 2) + self._history = self._history[-keep_messages:] + # 求出新的token占有 + new_usage = self._estimate_token_usage() + logger.success( # 打印清理前后的token占用变化 + f"清理完成 | Token: {current_usage}→{new_usage} | " + f"保留消息: {len(self._history)}" + ) + + def _estimate_token_usage(self) -> int: + """ + 使用 TokenManager 计算当前历史记录的 Token 数 + """ + if not self._history: + return 0 + # 将 ChatMessage 对象转换为字典格式 + history_dicts = [ + {"role": msg.role, "content": msg.content} + for msg in self._history + ] + + return self.token_manager.count_messages_tokens( + history_dicts, + tokens_per_message=3 # OpenAI 格式开销 + ) + + def register_overflow_handler(self, handler: ContextOverflowCallback): + """ + 注册上下文溢出处理器(支持多个目标) + 这个上下文溢出处理器用于将溢出的消息收集并记录,和记忆模块对接 + """ + self._overflow_callbacks.append(handler) + logger.info(f"注册溢出处理器: {handler.__name__}") + + def _trigger_overflow_callbacks(self): + """执行所有注册的溢出处理器""" + if not self._overflow_callbacks: + return + # 使用 TokenManager 获取当前占用 + context_usage = self.token_manager.get_context_usage(self._history) + metadata = { # 构造详细的元数据 + "reason": "token_limit_exceeded", + "message_count": len(self._history), + "estimated_tokens": context_usage.total_tokens, + "limit": self.core_config.max_context_tokens, + "timestamp": time.time(), + "model": self.model_config.model_name + } + # 快照当前历史,防止回调修改 + history_snapshot = list(self._history) + + for handler in self._overflow_callbacks: + try: + handler(history_snapshot, metadata) + logger.debug(f"溢出处理器成功: {handler.__name__}") + except Exception as e: + logger.error(f"执行上下文溢出回调失败: {handler.__name__}:{e}") + + def clear_context(self): + """手动清空上下文""" + self._history.clear() + logger.info("上下文记忆已清空") + + # 运行时热重载 + def reload_model(self, new_model_config: ModelConfig): + """ + 热重载 LLM 模型配置 + 不影响当前的上下文记忆和 System Prompt + """ + logger.info(f"正在热重载模型: {self.model_config.model_name} -> {new_model_config.model_name}") + try: + self.llm_client.update_config(new_model_config) + self.model_config = new_model_config + # 重新初始化 TokenManager + self.token_manager = TokenManager(new_model_config.model_name) + logger.info("模型热重载成功") + except Exception as e: + logger.error(f"模型热重载失败: {e}") + raise + + def get_context_stats(self) -> Dict[str, Any]: + """获取详细上下文统计""" + # 使用 TokenManager 获取当前占用 + context_usage = self.token_manager.get_context_usage(self._history) + tokenizer_info = self.token_manager.get_tokenizer_info() + return { + "message_count": len(self._history), + "estimated_tokens": context_usage.total_tokens, + "limit": self.core_config.max_context_tokens, + "usage_ratio": context_usage.total_tokens / self.core_config.max_context_tokens, + "model": self.model_config.model_name, + "tokenizer": tokenizer_info, + "history_preview": [ + f"{msg.role[:1]}:{msg.content[:30]}..." + for msg in self._history[-3:] + ], + "last_api_usage": self.token_manager._last_api_usage.to_dict() if self.token_manager._last_api_usage else None + } + + def __repr__(self) -> str: + return ( # 返回描述信息 + f"YosugaLLMCore(model={self.model_config.model_name}, " + f"provider={self.model_config.provider.value}, " + f"history_len={len(self._history)})" + ) + + +# 使用示例与测试 +if __name__ == "__main__": + import sys + + # 配置日志输出 + logger.remove() + logger.add(sys.stderr, level="DEBUG") + + print("\n" + "=" * 50) + print("🚀 Yosuga Server LLM Core 启动测试") + print("=" * 50 + "\n") + + # 准备模拟的动作处理器 (Mock Handlers) + + # 异步处理器示例:处理语音回复 + async def handle_audio_response(data: LLMCoreAnalysisBase): + # 这里强制类型转换为具体子类以获取代码提示(实际运行时已经是具体类型) + if data.type == "audio_text": + print(f" [Audio Handler] 正在合成语音: {data.response_text} | 情感: {data.emotion}") + return {"audio_file": "sample.wav", "duration": 3.5} + return None + + # 异步处理器示例:处理自动化操作 + async def handle_auto_agent(data: LLMCoreAnalysisBase): + print(f" [UI Agent] 收到指令: {data.Action}") + await asyncio.sleep(1) # 模拟耗时操作 + print(f" [UI Agent] 执行动作: {data.Action} -> ({data.x1}, {data.y1})") + return {"status": "success", "executed": data.Action} + + async def handle_call_auto_agent(data: LLMCoreAnalysisBase): + print(f" [Call Agent] 收到内容: {data.type}") + print(f" [Call Agent] 收到内容: {data.llm_translation}") + + # 回退处理器 + def handle_fallback(data: LLMCoreAnalysisBase): + print(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}") + + # 初始化 LLM Core + + # 配置 LM Studio 连接 + # 注意:LM Studio 通常兼容 OpenAI 格式,所以 provider 选 LM_STUDIO 或 OPENAI 均可 + # 如果是本地服务,API Key 可以随意填写 + model_cfg = ModelConfig( + provider=ModelProvider.LM_STUDIO, + model_name="qwen/qwen3-4b-2507", + base_url="http://192.168.1.3:1234/v1", + api_key="lm-studio", + temperature=0.3, + max_tokens=2048 + ) + + core_cfg = LLMCoreConfig( + max_context_tokens=1024, + enable_history=True, + role_setting="你是由Misakiotoha开发的Yosuga助手,性格抽象,爱说点小骚话。", + auto_dispatch=True, + dispatch_async=True # 启用异步分发测试 + ) + + core = YosugaLLMCore(model_cfg, core_cfg) + + # 注册处理器 + core.register_action_handler("audio_text", handle_audio_response, is_async=True) + core.register_action_handler("auto_agent", handle_auto_agent, is_async=True) + core.register_action_handler("call_auto_agent", handle_call_auto_agent, is_async=True) + core.set_fallback_handler(handle_fallback) + + # 注册解析器 + core.register_analysis_model(YosugaAudioResponseData) + core.register_analysis_model(YosugaUITARSResponseData) + core.register_analysis_model(YosugaUITARSRequestData) + + # 交互测试 Loop + def run_tests(): + # 测试场景 1: 普通对话 (触发 Audio 解析) + print("\n📝 测试 1: 普通对话 (预期触发 audio_text)") + asr_input_1 = { + "text": "你好,Yosuga!请介绍一下你自己,并对我微笑。", + "confidence": 0.99 + } + try: + result = core.interact( + asr_input_1, + dispatch_async=True + ) + print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}") + except Exception as e: + logger.error(f"测试1失败: {e}") + + # 测试场景 2: 复杂指令 (预期同时触发 Audio 和 UI 操作) + print("\n📝 测试 2: 混合指令 (预期触发 audio_text + auto_agent)") + # 构造一个复杂的 Prompt 输入,诱导模型输出多条指令 + # 注意:这依赖于模型足够聪明能理解 System Prompt 中的 output schema + complex_input = """ + [{ + "text": "打开系统设置", + "confidence": 0.99 + }] + """ + + try: + result = core.interact(complex_input, dispatch_async=True) + print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}") + except Exception as e: + logger.error(f"测试2失败: {e}") + + # 热重载测试 + print("\n🔄 测试 3: 模型热重载") + new_model_cfg = ModelConfig( + provider=ModelProvider.LM_STUDIO, + model_name="qwen/qwen3-vl-8b", + base_url="http://192.168.1.3:1234/v1", + api_key="lm-studio", + temperature=0.5 + ) + + try: + core.reload_model(new_model_cfg) + print("✅ 热重载完成,进行验证对话...") + + # 验证重载后是否还能对话(保留了上下文) + verify_input = """ + { + "text": "系统设置打开了吗", + "confidence": 0.95 + }, + { + "auto_agent": "Thought: 我注意到屏幕左下角有一个齿轮形状的图标,这正是系统的设置入口。在KDE系统中,这个图标通常用来访问系统设置面板。为了帮助用户打开设置界面,我现在需要点击这个位于屏幕左下方的齿轮图标。 + Action: click(start_box='<|box_start|>(42,1045)<|box_end|>')" + } + }""" + result = core.interact(verify_input) + print(f"🏁 重载后回复: {json.dumps(result, ensure_ascii=False, indent=2)}") + + except Exception as e: + logger.error(f"热重载测试失败: {e}") + + # 查看统计信息 + print("\n📊 最终状态统计:") + print(core.get_context_stats()) + + # 运行测试 + run_tests() \ No newline at end of file diff --git a/src/server_core/llm_core/llm_core_analysis.py b/src/server_core/llm_core/llm_core_analysis.py new file mode 100644 index 0000000..f768423 --- /dev/null +++ b/src/server_core/llm_core/llm_core_analysis.py @@ -0,0 +1,313 @@ +# llm_core/llm_core_analysis.py + +""" +LLM输出解析与序列化模块 +将LLM返回的JSON字符串智能解析为强类型Python对象,供其他模块直接调用 +支持多类型混合响应,自动路由到对应数据模型 +""" + +import json +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, Optional, Type, List +from pydantic import BaseModel, Field, ValidationError, field_validator +from loguru import logger +import re + +# 抽象基类 +class LLMCoreAnalysisBase(BaseModel, ABC): + """ + LLM输出数据模型抽象基类 + 各场景通过继承定义具体的数据结构 + """ + + # 解析器type标识 + type: str = Field(..., description="场景类型标识") + + @classmethod + @abstractmethod + def type_(cls) -> str: + """返回该模型对应的场景类型标识""" + pass + + @classmethod + def get_schema(cls) -> Dict[str, Any]: + """ + 返回该场景的数据结构schema + 用于生成system prompt中的{OutputInfo} + """ + return cls.model_json_schema() + +# 管理器 +class LLMCoreAnalysisManager: + """ + LLM输出解析管理器 + 智能路由:根据JSON中的type_字段,自动选择对应模型进行解析 + """ + + # 类变量:存储所有注册的数据模型类 + _model_registry: ClassVar[Dict[str, Type[LLMCoreAnalysisBase]]] = {} + + @classmethod + def register(cls, model_class: Type[LLMCoreAnalysisBase]) -> None: + """ + 注册场景数据模型类 + + Args: + model_class: 继承自LLMCoreAnalysisBase的数据模型类 + + Example: + LLMCoreAnalysisManager.register(YosugaAudioResponseData) + """ + type_id = model_class.type_() + cls._model_registry[type_id] = model_class + logger.info(f"已注册LLM数据输出解析模型: {type_id} -> {model_class.__name__}") + + @classmethod + def parse(cls, json_str: str) -> List[LLMCoreAnalysisBase]: + """ + 统一解析入口:无论单对象还是数组,总是返回对象列表 + + Args: + json_str: LLM原始JSON字符串(支持markdown代码块) + + Returns: + 解析后的模型实例列表,顺序与JSON数组一致 + + Raises: + ValidationError: 格式校验失败 + ValueError: JSON解析失败 + """ + print(f"待解析的内容为(llm本次输出原生内容):{json_str}") # TODO:delete + + cleaned = cls._clean_markdown(json_str) # 清理markdown标记 + try: + data = json.loads(cleaned) + # 统一包装成列表 + if not isinstance(data, list): + data = [data] + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {e}\n原始输出: {cleaned[:200]}...") + raise ValueError(f"无效的JSON格式: {e}") + + results: List[LLMCoreAnalysisBase] = [] + for idx, item in enumerate(data): + type_id = item.get("type") + if not type_id: + logger.warning(f"跳过第{idx}个元素(无type字段): {item}") + continue + + if type_id not in cls._model_registry: + logger.warning(f"未注册的类型 '{type_id}',可用类型: {list(cls._model_registry.keys())}") + continue + + model_class = cls._model_registry[type_id] + try: + # 重新序列化为字符串再解析(保持接口兼容) + item_json = json.dumps(item, ensure_ascii=False) + result = model_class.model_validate_json(item_json) + results.append(result) + except ValidationError as e: + logger.error(f"第{idx}个对象校验失败 (type={type_id}): {e}") + continue + except Exception as e: + logger.error(f"第{idx}个对象解析失败 (type={type_id}): {e}") + continue + + logger.success(f"解析完成 | 成功: {len(results)}/{len(data)}") + return results + + @staticmethod + def _clean_markdown(json_str: str) -> str: + # 尝试找到第一个 '[' 和最后一个 ']' + start = json_str.find('[') + end = json_str.rfind(']') + if start != -1 and end != -1: + return json_str[start:end + 1] + return json_str # Fallback + +# 具体场景数据模型 +class YosugaAudioResponseData(LLMCoreAnalysisBase): + """ + 音频ASR场景的LLM输出数据模型 + + 使用示例: + LLMCoreAnalysisManager.register(YosugaAudioResponseData) + data = YosugaAudioResponseData.parse_raw('{"type_": "audio_text", "response_text": "你好"}') + data.response_text # 直接属性访问 + '你好' + data.emotion # 默认值 + 'neutral' + """ + + type: str = Field(default="audio_text", description="固定为audio_text") + response_text: str = Field(..., description="回复文本") + emotion: str = Field(default="neutral", description="情感基调") + action: str = Field(default="none", description="动作指令") + + @classmethod + def type_(cls) -> str: + return "audio_text" + +class YosugaUITARSResponseData(LLMCoreAnalysisBase): + """ + 自动化操作场景的LLM输出数据模型 + """ + + type: str = Field(default="auto_agent", description="固定为auto_agent") + Action: str = Field(..., description="动作名称") + x1: Optional[int] = Field(default=None, description="起点x") + y1: Optional[int] = Field(default=None, description="起点y") + x2: Optional[int] = Field(default=None, description="终点x") + y2: Optional[int] = Field(default=None, description="终点y") + key: Optional[str] = Field(default="", description="快捷键") + content: Optional[str] = Field(default="", description="输入文本") + direction: Optional[str] = Field(default="", description="滚动方向") + + @classmethod + def type_(cls) -> str: + return "auto_agent" + + @field_validator('x1', 'y1', 'x2', 'y2', mode='before') + @classmethod + def convert_optional_int(cls, v: Any) -> Optional[int]: + """ + 将字符串类型的坐标值转换为 int,空字符串转为 None + + Args: + v: 原始值(可能是 str, int, None) + + Returns: + Optional[int]: 转换后的值 + """ + if v is None: + return None + if isinstance(v, str): + # 处理空字符串 + if v.strip() == "": + return None + # 尝试转换为 int + try: + return int(v) + except ValueError: + logger.warning(f"无法将字符串 '{v}' 转换为 int,返回 None") + return None + if isinstance(v, int): # 如果原始值本身就是int类型,直接return + return v + if isinstance(v, float): # 如果解析出了float,强转成int再return + return int(v) + + logger.warning(f"意外的类型 {type(v)},返回 None") + return None + + def to_dict(self) -> dict: + """ + 将模型转换为字典,用于生成JSON字符串 + + Returns: + dict: 模型数据字典 + """ + return { + "type": self.type, + "Action": self.Action, + "x1": self.x1, + "y1": self.y1, + "x2": self.x2, + "y2": self.y2, + "key": self.key, + "content": self.content, + "direction": self.direction + } + +class YosugaUITARSRequestData(LLMCoreAnalysisBase): + """ + 自动化操作场景的LLM对自动化agent的调用的输出解析模型 + """ + + type: str = Field(default="call_auto_agent", description="固定为call_auto_agent") + llm_translation: str = Field(default="", description="针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可") + + @classmethod + def type_(cls) -> str: + return "call_auto_agent" + + + + +class YosugaLive2DResponseData(LLMCoreAnalysisBase): + """ + Live2D控制场景的LLM输出数据模型 TODO + """ + + type: str = Field(default="live2d_control", description="固定为live2d_control") + parameter: str = Field(..., description="参数名") + value: float = Field(..., description="目标值") + duration: int = Field(default=500, description="过渡时间(ms)") + + @classmethod + def type_(cls) -> str: + return "live2d_control" + +class YosugaEmbeddedResponseData(LLMCoreAnalysisBase): + """ + 嵌入式设备场景的LLM输出数据模型 TODO + """ + + type: str = Field(default="embedded_control", description="固定为embedded_control") + device_id: str = Field(..., description="设备ID") + command: str = Field(..., description="控制指令") + params: Optional[Dict[str, Any]] = Field(default=None, description="参数") + + @classmethod + def type_(cls) -> str: + return "embedded_control" + +# 使用示例 +if __name__ == "__main__": + from loguru import logger + + # 注册所有数据模型 + LLMCoreAnalysisManager.register(YosugaAudioResponseData) + LLMCoreAnalysisManager.register(YosugaUITARSResponseData) + LLMCoreAnalysisManager.register(YosugaLive2DResponseData) + LLMCoreAnalysisManager.register(YosugaEmbeddedResponseData) + + logger.info("=== LLM输出解析模块测试 ===") + + # 测试单对象解析 + print("\n【测试1:单对象自动识别】") + llm_output = ''' + ```json + [{ + "type": "audio_text", + "response_text": "收到!我会微笑回应", + "emotion": "cheerful" + }] + ``` + ''' + + response = LLMCoreAnalysisManager.parse(llm_output) + print(f"类型: {response}") + + # 3. 测试多对象解析 + print("\n【测试2:多对象混合响应】") + multi_output = ''' + ```json + [ + { + "type": "auto_agent", + "Action": "click", + "x1": "100", + "y1": "200" + }, + { + "type": "live2d_control", + "parameter": "ParamEyeLOpen", + "value": 0.8, + "duration": 300 + } + ] + ``` + ''' + + results = LLMCoreAnalysisManager.parse(multi_output) + print(f"类型: {results}") \ No newline at end of file diff --git a/src/server_core/llm_core/llm_core_dispatcher.py b/src/server_core/llm_core/llm_core_dispatcher.py new file mode 100644 index 0000000..62b2892 --- /dev/null +++ b/src/server_core/llm_core/llm_core_dispatcher.py @@ -0,0 +1,320 @@ +# llm_core/llm_core_dispatcher.py + +""" +动作分发器模块 +负责将解析后的LLM输出对象路由到对应的业务处理器 +支持同步/异步处理,提供回退机制与执行结果追踪 +""" + +import asyncio +from typing import Any, Callable, ClassVar, Dict, List, Optional, Union +from functools import wraps +from loguru import logger + +from src.server_core.llm_core.llm_core_analysis import LLMCoreAnalysisBase + +# 处理器类型定义 +# 同步处理器:接收解析对象,返回执行结果 +SyncActionHandler = Callable[[LLMCoreAnalysisBase], Any] +# 异步处理器:接收解析对象,返回协程 +AsyncActionHandler = Callable[[LLMCoreAnalysisBase], Any] + + +def handler_error_wrapper(func: Callable) -> Callable: + """ + 处理器错误包装装饰器 + 统一捕获异常并记录日志,避免单个处理失败导致整体崩溃 + """ + + @wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.exception(f"处理器 {func.__name__} 执行失败: {e}") + raise + + @wraps(func) + async def async_wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except Exception as e: + logger.exception(f"异步处理器 {func.__name__} 执行失败: {e}") + raise + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + +class LLMCoreActionDispatcher: + """ + LLM动作分发中心 + + 职责: + 1. 管理类型到处理器的映射关系 + 2. 支持同步与异步处理器注册 + 3. 执行分发并收集处理结果 + 4. 提供未注册类型的回退机制 + 5. 支持批量处理与结果聚合 + + 使用示例: + # 注册同步处理器 + def handle_audio(data: YosugaAudioResponseData) -> dict: + return {"status": "spoken", "text": data.response_text} + + LLMCoreActionDispatcher.register("audio_text", handle_audio) + + # 注册异步处理器 + async def handle_ui(data: YosugaUITARSResponseData): + await perform_click(data.x1, data.y1) + return {"status": "clicked"} + + LLMCoreActionDispatcher.register_async("auto_agent", handle_ui) + + # 执行分发 + results = LLMCoreActionDispatcher.execute(parsed_objects) + """ + + # 类变量存储所有注册的处理器 + _sync_handlers: ClassVar[Dict[str, SyncActionHandler]] = {} + _async_handlers: ClassVar[Dict[str, AsyncActionHandler]] = {} + _fallback_handler: ClassVar[Optional[Union[SyncActionHandler, AsyncActionHandler]]] = None + + @classmethod + def register(cls, type_id: str, handler: SyncActionHandler) -> None: + """ + 注册同步处理器 + + Args: + type_id: 与 LLMCoreAnalysisBase.type_() 返回值匹配的标识符 + handler: 同步函数,接收解析对象并返回执行结果 + + Raises: + ValueError: 处理器不是可调用的函数 + """ + if not callable(handler): + raise ValueError(f"处理器必须是可调用函数,收到: {type(handler)}") + + # 检查是否已注册,防止覆盖 + if type_id in cls._sync_handlers or type_id in cls._async_handlers: + logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖") + + cls._sync_handlers[type_id] = handler_error_wrapper(handler) + logger.success(f"注册同步处理器: {type_id} → {handler.__name__}") + + @classmethod + def register_async(cls, type_id: str, handler: AsyncActionHandler) -> None: + """ + 注册异步处理器 + + Args: + type_id: 类型标识符 + handler: 异步函数,接收解析对象并返回协程 + + Raises: + ValueError: 处理器不是有效的异步函数 + """ + if not asyncio.iscoroutinefunction(handler): + raise ValueError(f"异步处理器必须是协程函数,收到: {type(handler)}") + + if type_id in cls._sync_handlers or type_id in cls._async_handlers: + logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖") + + cls._async_handlers[type_id] = handler_error_wrapper(handler) + logger.success(f"注册异步处理器: {type_id} → {handler.__name__}") + + @classmethod + def set_fallback(cls, handler: Union[SyncActionHandler, AsyncActionHandler]) -> None: + """ + 设置回退处理器(未注册类型的默认处理) + + Args: + handler: 同步或异步函数,处理所有未匹配的类型 + """ + wrapped = handler_error_wrapper(handler) + cls._fallback_handler = wrapped + handler_type = "异步" if asyncio.iscoroutinefunction(handler) else "同步" + logger.info(f"设置{handler_type}回退处理器: {handler.__name__}") + + @classmethod + def get_handler(cls, type_id: str) -> Optional[Union[SyncActionHandler, AsyncActionHandler]]: + """获取指定类型的处理器""" + # 优先返回同步处理器 + if type_id in cls._sync_handlers: + return cls._sync_handlers[type_id] + # 其次返回异步处理器 + if type_id in cls._async_handlers: + return cls._async_handlers[type_id] + # 返回回退处理器 + return cls._fallback_handler + + @classmethod + def execute( + cls, + analysis_results: List[LLMCoreAnalysisBase], + run_async: bool = False + ) -> Dict[str, List[Any]]: + """ + 执行分发处理 + + Args: + analysis_results: 解析器返回的对象列表 + run_async: 是否启用异步执行模式(需要业务代码支持asyncio) + + Returns: + 执行结果字典: + { + "success": [处理成功的结果列表], + "failed": [{"type": "...", "error": "..."}], + "skipped": [跳过的类型列表] + } + """ + if not analysis_results: + logger.warning("无对象需要分发") + return {"success": [], "failed": [], "skipped": []} + + if run_async and (cls._async_handlers or asyncio.iscoroutinefunction(cls._fallback_handler)): + # 异步执行模式(需要事件循环) + return asyncio.run(cls._execute_async(analysis_results)) + + # 默认同步执行 + return cls._execute_sync(analysis_results) + + @classmethod + def _execute_sync(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]: + """同步批量执行""" + outputs = {"success": [], "failed": [], "skipped": []} + + for idx, result in enumerate(results): + type_id = result.type + logger.debug(f"[{idx}] 分发类型: {type_id}") + + handler = cls.get_handler(type_id) + if not handler: + outputs["skipped"].append(type_id) + logger.warning(f"无处理器,跳过: {type_id}") + continue + + # 执行同步处理器 + if asyncio.iscoroutinefunction(handler): + logger.error(f"异步处理器 '{type_id}' 不能在同步模式下执行") + outputs["failed"].append({"type": type_id, "error": "异步处理器需要run_async=True"}) + continue + + try: + output = handler(result) + outputs["success"].append({ + "type": type_id, + "output": output, + "index": idx + }) + logger.success(f"[{idx}] 处理成功: {type_id}") + except Exception as e: + outputs["failed"].append({ + "type": type_id, + "error": str(e), + "index": idx + }) + logger.error(f"[{idx}] 处理失败 {type_id}: {e}") + + cls._log_summary(outputs, len(results)) + return outputs + + @classmethod + async def _execute_async(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]: + """异步批量执行""" + outputs = {"success": [], "failed": [], "skipped": []} + tasks = [] + + for idx, result in enumerate(results): + type_id = result.type + logger.debug(f"[{idx}] 异步分发类型: {type_id}") + + handler = cls.get_handler(type_id) + if not handler: + outputs["skipped"].append(type_id) + logger.warning(f"无处理器,跳过: {type_id}") + continue + + # 创建协程任务 + if asyncio.iscoroutinefunction(handler): + task = cls._run_async_handler(handler, result, idx, outputs) + else: + # 同步处理器在异步线程池中执行 + task = cls._run_sync_in_executor(handler, result, idx, outputs) + + tasks.append(task) + + # 并发执行所有任务 + await asyncio.gather(*tasks, return_exceptions=True) + + cls._log_summary(outputs, len(results)) + return outputs + + @classmethod + async def _run_async_handler(cls, handler, result, idx, outputs): + """运行异步处理器""" + try: + output = await handler(result) + outputs["success"].append({ + "type": result.type, + "output": output, + "index": idx + }) + logger.success(f"[{idx}] 异步处理成功: {result.type}") + except Exception as e: + outputs["failed"].append({ + "type": result.type, + "error": str(e), + "index": idx + }) + logger.error(f"[{idx}] 异步处理失败 {result.type}: {e}") + + @classmethod + async def _run_sync_in_executor(cls, handler, result, idx, outputs): + """在线程池中运行同步处理器""" + loop = asyncio.get_event_loop() + try: + output = await loop.run_in_executor(None, handler, result) + outputs["success"].append({ + "type": result.type, + "output": output, + "index": idx + }) + logger.success(f"[{idx}] 同步处理器(线程池)成功: {result.type}") + except Exception as e: + outputs["failed"].append({ + "type": result.type, + "error": str(e), + "index": idx + }) + logger.error(f"[{idx}] 同步处理器(线程池)失败 {result.type}: {e}") + + @classmethod + def _log_summary(cls, outputs: Dict, total: int): + """输出处理摘要""" + success_count = len(outputs["success"]) + failed_count = len(outputs["failed"]) + skipped_count = len(outputs["skipped"]) + + logger.info( + f"分发完成 | 总计: {total} | 成功: {success_count} " + f"失败: {failed_count} 跳过: {skipped_count}" + ) + + @classmethod + def clear(cls) -> None: + """清空所有处理器(测试/热重载用)""" + cls._sync_handlers.clear() + cls._async_handlers.clear() + cls._fallback_handler = None + logger.info("已清空所有动作处理器") + + @classmethod + def list_handlers(cls) -> Dict[str, str]: + """列出当前注册的所有处理器""" + handlers = {} + for type_id, handler in cls._sync_handlers.items(): + handlers[type_id] = f"同步: {handler.__name__}" + for type_id, handler in cls._async_handlers.items(): + handlers[type_id] = f"异步: {handler.__name__}" + return handlers \ No newline at end of file diff --git a/src/server_core/llm_core/llm_core_prompt_manager.py b/src/server_core/llm_core/llm_core_prompt_manager.py new file mode 100644 index 0000000..12fd22b --- /dev/null +++ b/src/server_core/llm_core/llm_core_prompt_manager.py @@ -0,0 +1,174 @@ +# llm_core/llm_core_prompt_manager.py + +""" +llm prompt 结构化信息管理 +用于将各种输入到llm_core的数据流结构化,并附加注释,以方便llm可以准确的理解并可以结构化返回相关内容 +""" +from abc import ABC, abstractmethod +from pydantic import BaseModel, Field, field_validator +from typing import Callable, List, Optional, Coroutine, Any, ClassVar, Dict +from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH + +class LLMCorePromptBase(BaseModel, ABC): + """LLM 提示词基类:定义输入输出结构""" + @abstractmethod + def type(self) -> str: + """返回该提示词类型的唯一标识""" + pass + + @abstractmethod + def describe_input(self) -> str: + """生成输入格式的自然语言描述(填充到 {InputInfo})""" + pass + + @abstractmethod + def describe_output(self) -> str: + """生成输出格式的自然语言描述(填充到 {OutputInfo})""" + pass + + def to_json(self) -> str: + return self.model_dump_json() + + @classmethod + def from_json(cls, json_data: str): + return cls.model_validate_json(json_data) + + +class LLMCorePromptManager(LLMCorePromptBase): + """聚合所有子类,生成完整的 prompt 信息""" + # 类变量:存储所有注册的子类 + _registry: ClassVar[Dict[str, LLMCorePromptBase]] = {} + + def get_registry_size(self) -> int: + return len(self._registry) + + def register(self, son: LLMCorePromptBase) -> None: + self._registry[son.type()] = son + + def type(self) -> str: + return "manager" + + def describe_input(self) -> str: + """聚合所有子类的输入描述""" + return "\n".join( + f"[{type_id}]{son.describe_input()}\n" + for type_id, son in self._registry.items() + ) + + def describe_output(self) -> str: + """聚合所有子类的输出描述""" + return "\n".join( + f"[{type_id}]{son.describe_output()}\n" + for type_id, son in self._registry.items() + ) + +class YosugaAudioASRText(LLMCorePromptBase): + """音频ASR文本输入场景""" + def type(self) -> str: + return "用户语音asr信息" + + def describe_input(self) -> str: + return ''' + 当用户通过语音与Yosuga交互时,你会收到如下JSON结构: + { + "text": "用户说的话(字符串)", + "confidence": 0.95 + } + - `text`: 语音转写的原始文本,可能包含口语化表达或识别错误 + - `confidence`: ASR引擎的识别置信度,低于0.8需警惕识别错误 + ''' + + def describe_output(self) -> str: + return ''' + 针对用户音频识别出的文本内容输入,你应该按以下JSON格式回复: + { + "type": "固定为audio_text", + "response_text": "你的回复文本(字符串)", + "emotion": "neutral", + "action": "none" + } + - `response_text`: 给用户的自然语言回复 + - `emotion`: 回复的情感基调,可选值:neutral/cheerful/sad/angry + - `action`: 触发的动作指令,如"wave_hand"、"nod"等,"none"表示无动作 + ''' + +class YosugaEmbedded(LLMCorePromptBase): + """嵌入式设备输入场景""" + def type(self) -> str: + pass + + def describe_input(self) -> str: + pass + + def describe_output(self) -> str: + pass + +class YosugaUITARS(LLMCorePromptBase): + """自动化操作构建场景""" + def type(self) -> str: + return "自动化操作信息" + + def describe_input(self) -> str: + return ''' + 当你尝试调用自动化agent的时候,自动化agent将会返回下面的内容作为输入给你,自动化agent的返回信息很重要,有的任务不是一次就可以完成的,此时你可能需要多次调用自动化agent,直到agent返回finished(任务完成): + { + "auto_agent":" + Thought: 自动化agent的推理过程,可以考虑二次加工后作为回复用户的内容 + Action: 对应的自动化动作[包括click(单击坐标), left_double(双击坐标), right_single(右键单击), drag(拖拽), hotkey(快捷键), type(输入文本), scroll(滚动), wait(等待), finished(任务完成)] + " + } + ''' + + def describe_output(self) -> str: + return ''' + 1. 当用户试图进行一些自动化操作时,你可以通过返回以下json来调用自动化agent,调用的时候也可以一并回复用户一些文本内容: + { + "type": "固定为call_auto_agent", + "llm_translation": "针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可,注意转译的语言必须是中文,可以混合部分英文(应用名称),这个和聊天交流的语言不统一。" + } + 2. 针对自动化agent操作输入,你应该按以下JSON格式回复。如果你发起了自动化agent请求之后,却没有返回下面的JSON内容,那么你的请求并不会作用到用户的设备,所以你调用完自动化agent一定要及时返回下面的JSON内容: + { + "type": "固定为auto_agent", + "Action": "自动化agent返回的相应的自动化动作名称", + "x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理", + "y1": "某个操作的y1", + "x2": "某个操作的x2", + "y2": "某个操作的y2", + "key": "快捷键,若当前操作没有该字段信息,此处内容为空即可", + "content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可", + "direction": "滚动方向(down or up),若当前操作没有该字段信息,此处内容为空即可,同时需要关心自动化agent回复的x1,和y1坐标" + } + 自动化agent返回的操作信息不一定包括JSON的全部字段,例如某次返回只有key的内容,或者只有content的内容。 + 针对自动化agent操作输入的返回,若没有相关内容可以留空相关字段,请不要省略掉任何字段名称。 + + 注意:自动化agent的状态可见YosugaSystemState表。 + ''' + +class YosugaLive2DControl(LLMCorePromptBase): + """对Yosuga Live2D控制场景""" + def type(self) -> str: + return "Yosuga Live2D控制信息" + + def describe_input(self) -> str: + pass + + def describe_output(self) -> str: + pass + + +if __name__ == "__main__": + # 注册所有prompt处理器 + manager = LLMCorePromptManager() + manager.register(YosugaAudioASRText()) + manager.register(YosugaUITARS()) + + # 生成最终 system prompt + system_prompt = YOSUGA_SYSTEM_PROMPT_SCH.format( + InputInfo=manager.describe_input(), + OutputInfo=manager.describe_output(), + RoleSetting="...", + Language="ja", + Memory = "", + SystemStateTable="" + ) + print(system_prompt) \ No newline at end of file diff --git a/src/server_core/llm_core/llm_core_prompts.py b/src/server_core/llm_core/llm_core_prompts.py new file mode 100644 index 0000000..06a6a64 --- /dev/null +++ b/src/server_core/llm_core/llm_core_prompts.py @@ -0,0 +1,67 @@ +# llm_core/llm_core_prompts.py + +YOSUGA_SYSTEM_PROMPT_SCH = """ +你是Yosuga_server这个项目的核心LLM。 +你将会为用户完成各种任务,进行结构化的输出。 +首先向你介绍一下Yosuga这个项目: +本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。 +之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。 +本项目分为三个部分: +1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。 +2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。 +3. Yosuga_embedded:这是项目的拓展部分,实现了Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。 + +作为Yosuga_server的核心LLM,你的任务包括以下的内容: +1. 接受自定义的结构化的信息(json),理解并给出正确的回复。 +2. 输出自定义的结构化的信息(json),给出正确的回复。 + +对于你会接受到什么信息,以及如何进行正确的返回,都会在下面做出解释。 + +你接受到的信息可能有: +1. 来自Yosuga的对于Live2D模型控制的结构化数据。 +2. 来自Yosuga_embedded的可控制的嵌入式设备信息的结构化摘要数据。 +3. 综合了各种信息的结构化请求数据(例如同时包括用户所说的话语与其他决策信息)。 + +你的工作流程: +1. 接受来自用户的请求,理解用户的意图(例如用户只有聊天的目的,或者是用户想进行一些操作) +2. 如果用户只希望聊天,那就只和用户聊天即可 +3. 如果用户还有其他意图,你需要根据你的能力来完成用户的需求,如果用户的需求不在你的能力范围内,可以婉拒。 + +你的能力范围: +1. 你可以和用户进行聊天,当用户只是想和你聊天,那就只和用户进行聊天 +2. 你可以主动调用自动化agent,当用户有着一些自动化操作的需求时候,你可以按一定调用格式去主动调用自动化agent模型,并理解自动化模型的返回内容,以规定的格式返回给服务端最终作用到用户的设备。 +3. 用户也可能会想和你一起玩一些游戏,这个时候你需要理解用户的意图,并调用自动化agent,同时返回一些聊天内容回复,以此和用户一起玩游戏。 +4. 你可以和各种种类的嵌入式设备进行交互,当用户有着一些现实的需求时候,你可以根据当前系统的状态,可控制的设备信息等信息综合判断,通过规定的方式返回规定的内容操作嵌入式设备,以满足用户的需求。 + +同时你还需要进行角色扮演,完美地扮演符合设定的角色,也要兼顾对于各种问题的处理。 + +以下是你会接收到的结构化的信息(带解释): +{InputInfo} + +以下是你需要根据不同的结构化信息需要输出的结构化信息(带解释): +{OutputInfo} + +你应当扮演的角色设定为: +{RoleSetting} + +你和用户聊天时候的回复语言为: +{Language} + + +注意:你的所有回复必须是一个 JSON 数组,即便只有单项任务。格式为 [ {{ "type": "...", ... }} ]。禁止输出JSON数组外的任何解释性文字,同时也禁止在JSON的回复内容中加入任何表情符号。 +注意:你接收到的结构化信息,有时可能会有多个,例如用户在问你的同时,自动化agent也返回了结果,这样的情况很常见,你需要按要求返回一个或者多个json即可。 +注意:有时候用户的请求可能既涉及到了聊天,也涉及到了自动化agent的调用,那么你在返回的json数组当中加入应当返回的内容即可,可以是聊天内容回复加上自动化agent调用。 +注意:有时用户的请求会比较复杂,需要你连续的调用自动化agent,根据自动化agent的输出,去一步步的完成用户的请求。 + +还有一些是你必须要参考的内容(例如与用户过去的记忆,系统实时状态表等内容): +与用户过去的记忆: +{Memory} + +系统实时状态表: +{SystemStateTable} + +""" + +YOSUGA_SYSTEM_PROMPT_EN = """ + +""" \ No newline at end of file diff --git a/src/server_core/llm_core/llm_core_token.py b/src/server_core/llm_core/llm_core_token.py new file mode 100644 index 0000000..8ec422c --- /dev/null +++ b/src/server_core/llm_core/llm_core_token.py @@ -0,0 +1,367 @@ +# llm_core/llm_core_token.py + +""" +Token 计算与管理模块 +支持双数据源智能切换: +- 优先使用大模型 API 返回的精确 usage 数据 +- 当 API 不返回时,自动回退到 tiktoken 手动估算 +""" +import time +from typing import List, Dict, Any, Optional, Union +from dataclasses import dataclass, field +import tiktoken +from loguru import logger + +@dataclass +class TokenUsage: + """Token 使用量统计""" + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + source: str = field(default="manual", repr=True) # "api" | "manual" + + def to_dict(self) -> Dict[str, Any]: + return { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.total_tokens, + "source": self.source + } + + +@dataclass +class TokenizerInfo: + """Tokenizer 元数据""" + model_name: str + encoding_name: str + is_fallback: bool + estimated_accuracy: str # "high" | "medium" | "low" + + +class TokenManager: + """ + Token 管理核心类 + 智能数据源切换:API返回 > 手动估算 + """ + def __init__(self, model_name: str): + self.model_name = model_name + self.tokenizer = self._get_tokenizer(model_name) + # 存储最近一次 API 返回的 usage + self._last_api_usage: Optional[TokenUsage] = None + # API 数据有效期(秒),超过此时间则视为过期,回退到手动计算 + self._api_usage_expiry: float = 30.0 + # 记录 API usage 的获取时间 + self._last_api_usage_time: float = 0.0 + + logger.info( + f"TokenManager 初始化完成 | " + f"模型: {model_name} | " + f"编码器: {self.tokenizer.name}" + ) + + def _get_tokenizer(self, model_name: str) -> tiktoken.Encoding: + """获取 tokenizer(与之前实现相同,省略重复代码)""" + model_tokenizer_map = { + "qwen": "gpt-3.5-turbo", + "llama": "gpt-3.5-turbo", + "gemma": "gpt-3.5-turbo", + } + + try: + return tiktoken.encoding_for_model(model_name) + except KeyError: + pass + + for prefix, mapped_model in model_tokenizer_map.items(): + if prefix in model_name.lower(): + logger.info( + f"模型 '{model_name}' 映射到 tokenizer '{mapped_model}'" + ) + return tiktoken.encoding_for_model(mapped_model) + + logger.warning( + f"tiktoken 不支持模型 '{model_name}'," + f"降级使用 'cl100k_base' 编码器" + ) + return tiktoken.get_encoding("cl100k_base") + + def record_api_usage(self, usage: Optional[Union[Dict[str, int], TokenUsage]]) -> None: + """ + 记录大模型 API 返回的 usage 数据 + + Args: + usage: API 返回的 usage 数据(dict 或 TokenUsage 对象) + 如果为 None 或空,则忽略 + """ + if not usage: + logger.debug("API usage 为空,未记录") + return + + # 在底层已经统一好了不同大模型对于usage的处理 + if isinstance(usage, TokenUsage): + api_usage = usage + api_usage.source = "api" + else: + api_usage = TokenUsage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + source="api" + ) + + self._last_api_usage = api_usage + self._last_api_usage_time = time.time() + + logger.debug( + f"记录 API usage | " + f"Prompt: {api_usage.prompt_tokens} | " + f"Completion: {api_usage.completion_tokens} | " + f"Total: {api_usage.total_tokens}" + ) + + def get_current_usage(self, prefer_api: bool = True) -> TokenUsage: + """ + 获取当前 Token 使用情况(智能数据源切换) + + Args: + prefer_api: 是否优先使用 API 数据(默认 True) + + Returns: + TokenUsage 对象,包含数据来源标记 + + Notes: + - 如果 prefer_api=True 且 API 数据在有效期内,直接返回 API 数据 + - 否则回退到手动估算 + """ + current_time = time.time() + + # 检查 API 数据是否有效 + if prefer_api and self._last_api_usage: + time_elapsed = current_time - self._last_api_usage_time + if time_elapsed <= self._api_usage_expiry: + logger.debug( + f"使用 API Token 数据({time_elapsed:.1f}s 内)| " + f"{self._last_api_usage.prompt_tokens} + " + f"{self._last_api_usage.completion_tokens} = " + f"{self._last_api_usage.total_tokens}" + ) + return self._last_api_usage + + # API 数据无效或无数据,回退到手动估算 + logger.debug("API usage 无效/过期,回退到手动估算") + return self._estimate_manual_usage() + + def _estimate_manual_usage(self) -> TokenUsage: + """内部方法:创建空的手动 usage(占位符)""" + return TokenUsage(source="manual") + + def get_context_usage(self, history: List[Any]) -> TokenUsage: + """ + 计算对话上下文的 Token 占用(必须手动计算) + + 注意: + - 上下文占用无法从 API 获得,必须手动估算 + - 此方法不涉及 _last_api_usage + + Args: + history: 历史消息列表 + + Returns: + TokenUsage 对象,source 始终为 "manual" + """ + tokens = self.count_messages_tokens(history) + return TokenUsage( + prompt_tokens=tokens, + completion_tokens=0, + total_tokens=tokens, + source="manual" + ) + + def count_text_tokens(self, text: str) -> int: + """计算单段文本的 token 数量(与之前相同)""" + if not isinstance(text, str) or not text: + return 0 + return len(self.tokenizer.encode(text)) + + def count_messages_tokens( + self, + messages: List[Any], + tokens_per_message: int = 3 + ) -> int: + """计算消息列表的总 token 数量(与之前相同,优化实现)""" + if not messages: + return 0 + + num_tokens = 0 + for msg in messages: + # 统一转换为字典 + if hasattr(msg, "to_dict"): + msg_dict = msg.to_dict() + elif hasattr(msg, "model_dump"): + msg_dict = msg.model_dump() + else: + msg_dict = msg + + # 计算内容和 role 的 token + num_tokens += self.count_text_tokens(msg_dict.get("content", "")) + num_tokens += self.count_text_tokens(msg_dict.get("role", "")) + num_tokens += tokens_per_message + + # 加上回复前缀的开销 + num_tokens += 3 + return num_tokens + + def estimate_chat_tokens( + self, + system_prompt: Optional[str], + history: List[Any], + current_input: str + ) -> TokenUsage: + """ + 估算一次完整对话所需的 token 数量 + + Args: + system_prompt: 系统提示词 + history: 历史消息列表 + current_input: 当前用户输入 + + Returns: + TokenUsage 对象,source 始终为 "manual" + + Notes: + - 这是预估,不是实际 API 返回值 + - 用于调试和前置检查 + """ + messages = [] + + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + messages.extend(history) + messages.append({"role": "user", "content": current_input}) + + total = self.count_messages_tokens(messages) + + return TokenUsage( + prompt_tokens=total, + completion_tokens=0, + total_tokens=total, + source="manual" + ) + + def format_usage_log( + self, + usage: Optional[Union[Dict[str, int], TokenUsage]] = None, + source: str = "AUTO" + ) -> str: + """ + 格式化 token 使用日志 + + Args: + usage: usage 数据(可选)。如果为 None,自动获取当前 usage + source: 数据来源标记("AUTO" | "API" | "MANUAL" | "CONTEXT") + + Returns: + 格式化的日志字符串 + """ + # 如果未提供 usage,自动获取 + if usage is None: + if source == "CONTEXT": + # 上下文场景必须手动计算 + usage_obj = self.get_context_usage([]) + else: + # 自动选择最佳数据源 + usage_obj = self.get_current_usage(prefer_api=True) + else: + # 使用提供的 usage + if isinstance(usage, TokenUsage): + usage_obj = usage + else: + usage_obj = TokenUsage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + source="api" if source == "API" else "manual" + ) + + # 根据来源选择前缀和图标 + prefix_map = { + "API": "API Token统计", + "MANUAL": "手动估算", + "CONTEXT": "上下文占用", + "AUTO": "Token统计" + } + prefix = prefix_map.get(source, "Token统计") + + # 添加数据来源标记 + source_icon = "⚡" if usage_obj.source == "api" else "🧮" + + # 格式化输出 + if usage_obj.completion_tokens > 0: + return ( + f"{prefix} {source_icon} | " + f"Prompt: {usage_obj.prompt_tokens} | " + f"Completion: {usage_obj.completion_tokens} | " + f"Total: {usage_obj.total_tokens}" + ) + else: + return ( + f"{prefix} {source_icon} | " + f"Total: {usage_obj.total_tokens}" + ) + + def get_tokenizer_info(self) -> TokenizerInfo: + """获取当前 tokenizer 的详细信息(与之前相同)""" + if "cl100k_base" in self.tokenizer.name and "gpt-3.5" not in self.model_name: + accuracy = "low" + elif self.model_name in self.tokenizer.name: + accuracy = "high" + else: + accuracy = "medium" + + return TokenizerInfo( + model_name=self.model_name, + encoding_name=self.tokenizer.name, + is_fallback="cl100k_base" in self.tokenizer.name, + estimated_accuracy=accuracy + ) + + def is_token_limit_approaching( + self, + current_tokens: int, + limit: int, + threshold: float = 0.85 + ) -> bool: + """判断 token 使用量是否接近限制(与之前相同)""" + return current_tokens > limit * threshold + + def calculate_chunk_size( + self, + available_tokens: int, + safety_margin: float = 0.1 + ) -> int: + """计算安全的消息块大小(与之前相同)""" + return int(available_tokens * (1 - safety_margin)) + + def clear_api_usage_cache(self): + """清空 API usage 缓存(用于测试)""" + self._last_api_usage = None + self._last_api_usage_time = 0.0 + logger.debug("API usage 缓存已清空") + + +# 单例工厂 +class TokenManagerFactory: + """TokenManager 工厂类,支持缓存复用""" + _instances: Dict[str, TokenManager] = {} + + @classmethod + def get_manager(cls, model_name: str) -> TokenManager: + if model_name not in cls._instances: + cls._instances[model_name] = TokenManager(model_name) + return cls._instances[model_name] + + @classmethod + def clear_cache(cls): + cls._instances.clear() + logger.info("TokenManager 缓存已清空") \ No newline at end of file diff --git a/src/server_core/llm_core/readme.md b/src/server_core/llm_core/readme.md new file mode 100644 index 0000000..78c265f --- /dev/null +++ b/src/server_core/llm_core/readme.md @@ -0,0 +1,242 @@ +### 服务端llm核心 +服务端以AI为核心去驱动,进行各种调用。 + +本模块为服务端的核心llm模块,这个llm必须足够聪明,能够稳定返回结构化结构。 + + +llm_core负责实现: + + + + +#### 类图 +```mermaid +classDiagram + class YosugaLLMCore { + <<主控制器>> + -ModelConfig model_config + -LLMCoreConfig core_config + -UnifiedLLM llm_client + -TokenManager token_manager + -LLMCorePromptManager prompt_manager + -List[ChatMessage] _history + -Lock _history_lock + -Lock _config_lock + +interact() Dict + +register_prompt_module() void + +register_action_handler() void + +reload_model() void + +get_context_stats() Dict + } + + class LLMCoreConfig { + <<运行时配置>> + -int max_context_tokens + -bool enable_history + -str language + -str role_setting + -bool auto_dispatch + -bool dispatch_async + -str memory + -str system_state_table + } + + class UnifiedLLM { + <<大模型调用层>> + -ModelConfig config + -BaseLLMClient client + +chat() ModelResponse + +complete() ModelResponse + +stream_chat() Iterator + +update_config() void + } + + class TokenManager { + <> + -str model_name + -tiktoken.Encoding tokenizer + -TokenUsage _last_api_usage + +record_api_usage() void + +get_current_usage() TokenUsage + +get_context_usage() TokenUsage + +count_messages_tokens() int + +format_usage_log() str + } + + class LLMCorePromptManager { + <> + -Dict _registry + +register() void + +describe_input() str + +describe_output() str + } + + class LLMCoreAnalysisManager { + <<输出解析>> + <<静态类>> + -Dict _model_registry + +register() void + +parse() List[LLMCoreAnalysisBase] + } + + class LLMCoreActionDispatcher { + <<动作分发>> + <<静态类>> + -Dict _sync_handlers + -Dict _async_handlers + -Callable _fallback_handler + +register() void + +register_async() void + +execute() Dict + } + + class LLMCoreAnalysisBase { + <<抽象基类>> + <<模型数据基类>> + #str type + +type_() str + +get_schema() Dict + } + + class ChatMessage { + <<消息实体>> + -str role + -str content + -str name + +to_dict() Dict + } + + class ModelResponse { + <<响应实体>> + -str content + -str model + -Dict usage + -str finish_reason + -Dict raw_response + } + + YosugaLLMCore --> LLMCoreConfig : 持有配置 + YosugaLLMCore --> UnifiedLLM : 调用大模型 + YosugaLLMCore --> TokenManager : 统计Token + YosugaLLMCore --> LLMCorePromptManager : 管理Prompt + YosugaLLMCore --> LLMCoreAnalysisManager : 解析输出 + YosugaLLMCore --> LLMCoreActionDispatcher : 分发动作 + YosugaLLMCore --> ChatMessage : 管理历史 + UnifiedLLM --> ModelResponse : 返回响应 + LLMCoreAnalysisManager --> LLMCoreAnalysisBase : 解析为 + TokenManager --> ModelResponse : 接收usage +``` + + +#### 时序图 +```mermaid +sequenceDiagram + participant Client as 客户端 + participant Core as YosugaLLMCore + participant Token as TokenManager + participant Prompt as PromptManager + participant LLM as UnifiedLLM + participant Parser as AnalysisManager + participant Dispatcher as ActionDispatcher + + Client->>Core: interact(user_input) + Note over Core: 输入预处理 + + Core->>Token: count_messages_tokens(_history) + Token-->>Core: current_usage + + Core->>Core: _maintain_context_limit() + Note over Core: 检查溢出并清理 + + Core->>Prompt: get_system_prompt() + Note over Prompt: 聚合InputInfo/OutputInfo + + Prompt-->>Core: system_prompt + + Core->>Core: _build_request_messages() + Note over Core: 组装[system, history, user] + + Core->>Token: estimate_chat_tokens() + Token-->>Core: estimated_usage + + Core->>LLM: chat(messages) + Note over LLM: 调用底层API + + LLM-->>Core: ModelResponse(content, usage) + + Core->>Token: record_api_usage(usage) + Note over Token: 优先使用API数据 + + Core->>Parser: parse(content) + Note over Parser: JSON清洗+类型校验 + + Parser-->>Core: List[AnalysisObj] + + Core->>Core: _add_to_history(user+assistant) + Note over Core: 更新对话记忆 + + Core->>Dispatcher: execute(parsed_results) + + par 分发处理 + Dispatcher->>Handler1: 同步处理(audio_text) + Dispatcher->>Handler2: 异步处理(auto_agent) + end + + Dispatcher-->>Core: {"success": [], "failed": []} + + Core-->>Client: 执行结果 +``` + + +#### 配置与模型热重载状态机 +```mermaid +stateDiagram-v2 + [*] --> 初始化: YosugaLLMCore() + + state 初始化 { + [*] --> 加载配置: ModelConfig + 加载配置 --> 创建LLM客户端: UnifiedLLM + 创建LLM客户端 --> 注册默认Prompt: Audio/UITARS + 注册默认Prompt --> 初始化Token管理器: TokenManager + } + + 初始化 --> 待机: 等待输入 + + state 待机 { + [*] --> 构建System Prompt + 构建System Prompt --> 检查上下文限制: _maintain_context_limit() + 检查上下文限制 --> 上下文溢出: current > limit + 检查上下文限制 --> 正常: 否则 + + 上下文溢出 --> 触发回调: _trigger_overflow_callbacks() + 触发回调 --> 清理历史: 保留50% + 清理历史 --> 正常 + + 正常 --> 组装消息链: _build_request_messages() + } + + 待机 --> 调用LLM: chat() + + 调用LLM --> 解析输出: AnalysisManager.parse() + + 解析输出 --> 更新历史: _add_to_history() + + 更新历史 --> 分发动作: Dispatcher.execute() + + 分发动作 --> 返回结果: interact() return + + 返回结果 --> 待机 + + 待机 --> 热重载: reload_model() + + state 热重载 { + [*] --> 更新模型配置: update_config() + 更新模型配置 --> 重建LLM客户端: _create_client() + 重建LLM客户端 --> 重建Token管理器: TokenManager(model_name) + 重建Token管理器 --> [*] + } + + 热重载 --> 待机 + + note right of 热重载 : "保留历史记忆\n不影响上下文" +``` \ No newline at end of file diff --git a/src/server_core/rag_core/readme.md b/src/server_core/rag_core/readme.md new file mode 100644 index 0000000..a78988f --- /dev/null +++ b/src/server_core/rag_core/readme.md @@ -0,0 +1,8 @@ +### 本模块为Yosuga_server的AI记忆模块 + + +##### 提供的功能有: +1. 向量化的记忆存储 +2. 十分便捷的记忆管理与查询接口 +3. 高效的记忆检索 + diff --git a/src/server_core/readme.md b/src/server_core/readme.md new file mode 100644 index 0000000..ed8e761 --- /dev/null +++ b/src/server_core/readme.md @@ -0,0 +1,3 @@ +### 本模块为Yosuga_server的核心业务模块 + +#### diff --git a/src/server_view/readme.md b/src/server_view/readme.md new file mode 100644 index 0000000..793fa04 --- /dev/null +++ b/src/server_view/readme.md @@ -0,0 +1,4 @@ +### 本模块为Yosuga_server的前端部分 + +为了方便管理服务端 +