first pull
This commit is contained in:
+31
@@ -0,0 +1,31 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# from ide
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
venv
|
||||
|
||||
# config
|
||||
settings.json
|
||||
|
||||
# logs
|
||||
log/*
|
||||
|
||||
# temp files
|
||||
temp_files/*
|
||||
|
||||
# ai models
|
||||
src/modules/asr_module/asr_models/*
|
||||
|
||||
# uv
|
||||
uv.lock
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
3.11
|
||||
@@ -0,0 +1,85 @@
|
||||
### Yosuga_server
|
||||
|
||||
## 📊 Project Stats
|
||||
|
||||

|
||||

|
||||

|
||||
|
||||
欢迎访问本项目。
|
||||
|
||||
首先向你介绍一下Yosuga这个项目:
|
||||
|
||||
本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。[call me "Misaki" でいいよ]
|
||||
|
||||
之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。
|
||||
|
||||
本项目分为三个部分:
|
||||
1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。
|
||||
2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。
|
||||
3. Yosuga_embedded:这是项目的拓展部分,使得Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。
|
||||
|
||||
**_本项目为Yosuga_server._**
|
||||
|
||||
本项目使用uv构建,基于python3.11.
|
||||
本项目由YosugaServer发展而来,项目架构与代码有了相当大的改变。(YosugaServer并未开源,它仅仅是一次小小的尝试)
|
||||
|
||||
|
||||
### 如何快速启动本项目?
|
||||
1. 确保uv已安装,并添加到环境变量中
|
||||
2. 执行`cd Yosuga_server` & `uv sync`
|
||||
3. 接着,如果你的电脑带有cuda,那么执行 `uv pip install -r requirements-cuda.txt`
|
||||
4. 如果没有cuda,那么执行 `uv pip install -r requirements-cpu.txt`
|
||||
5. 最后执行 `uv run python main.py` 即可启动项目
|
||||
|
||||
首次启动项目后,会在项目根目录下生成settings.json配置文件,你需要配置一些必要的字段信息:
|
||||
```json
|
||||
{
|
||||
"ai": {
|
||||
"api_key": "sk-xxxxx",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"model_name": "qwen/qwen3-4b-2507"
|
||||
},
|
||||
"tts": {
|
||||
"gpt_model_name": "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt",
|
||||
"sovits_model_name": "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth",
|
||||
"host": "localhost",
|
||||
"port": 20261,
|
||||
"reference_audio": "./using/reference.wav"
|
||||
},
|
||||
"asr": {
|
||||
"url": "http://localhost:20260/"
|
||||
},
|
||||
"auto_agent": {
|
||||
"deployment_type": "lmstudio",
|
||||
"model_name": "ui-tars-1.5-7b@q4_k_m",
|
||||
"base_url": "http://localhost:1234/v1"
|
||||
},
|
||||
"llm_core": {
|
||||
"role_character": "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。",
|
||||
"max_context_tokens": 2048,
|
||||
"language": "日本语"
|
||||
}
|
||||
}
|
||||
```
|
||||
上面这些字段的信息,你需要根据你的实际情况进行配置。实际的配置文件的字段名称会比上面的多出不少。
|
||||
|
||||
|
||||
配置完成后,再次重启服务端就可以使用啦~
|
||||
|
||||
接着是每个模型的配置相关:
|
||||
1. asr模型,本项目使用fast-whisper作为asr模型,并且附带了一键启动的部分
|
||||
,你需要找到 `Yosuga_server/src/modules/asr_module/start_api.py` 这个文件,然后启动它
|
||||
,一般来说,即使是cpu也可以进行asr模型的推理,但是速度相比cuda要逊色很多。
|
||||
同时,如果你遇到了启动时长时间加载,那么此时你需要试着挂一下梯子,因为初次启动
|
||||
会在Hugging Face上下载模型。
|
||||
2. tts模型,本项目使用GPT-SoVITS作为tts模型,建议使用其V2Pro版本。
|
||||
3. auto_agent模型,本项目使用的自动化操作识别的模型为字节跳动开源的
|
||||
`ui-tars-1.5-7b@q4_k_m` 关于此模型的更多信息可以参考字节跳动的[开源链接](https://github.com/bytedance/UI-TARS)
|
||||
,建议在LM Studio上进行部署,该模型十分轻量。
|
||||
4. ai模型,该模型限制为大语言模型,没有限制,本项目支持市面上的所有大语言模型。
|
||||
|
||||
|
||||
本项目当前并不完善,还有很多需要优化的地方,并且尚未接入Yosuga_embedded。
|
||||
|
||||
欢迎大家为本项目贡献代码。
|
||||
@@ -0,0 +1,173 @@
|
||||
import asyncio
|
||||
from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode
|
||||
from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer
|
||||
import sounddevice as sd
|
||||
|
||||
test_text = "春の午後、公園のベンチに座って本を読んでいると、小さな子供が凧あげをしているのが目に入った。風に乗って凧が高く上がるたびに、彼の顔には真っすぐな笑顔が広がる。母親がそばで見守りながら、時折声をかけている。\
|
||||
空は雲一つない青さで、桜の花びらが風に舞っている。遠くで犬の鳴き声が聞こえ、芝生の上では老夫婦がお茶を楽しんでいた。すべてがゆっくりと流れる時間の中で、自分の心も不思議と落ち着いてくる。\
|
||||
ふと、子供の凧が木の枝に引っかかってしまった。少し焦る様子だったが、母親が助けてくれて、すぐにまた空に舞い上がった。失敗しても、誰かが助けてくれる。そんな当たり前のことに、今日は特別な温かさを感じた。\
|
||||
日が傾き始める頃、私は本を閉じて家路についた。明日もきっと、誰かの笑顔があるだろう。"
|
||||
|
||||
async def test_tts():
|
||||
# 创建客户端(推荐上下文管理器)
|
||||
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
|
||||
# 基础TTS调用
|
||||
try:
|
||||
audio = await client.tts(
|
||||
text="あのさ、いやまあ、なんていうか...要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?",
|
||||
ref_audio_path="uploaded_audio/test_voice.wav", # 服务器上的路径
|
||||
text_lang="ja",
|
||||
prompt_lang="ja",
|
||||
media_type="wav",
|
||||
prompt_text="もう!こんなところで何やってるんだよ!"
|
||||
)
|
||||
|
||||
# 保存音频
|
||||
audio.save("outputs/output.wav")
|
||||
print(f"✅ TTS成功!音频大小: {len(audio.audio_data)} bytes")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 错误: {e}")
|
||||
async def test_model_change():
|
||||
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
|
||||
# 切换模型
|
||||
print("🔄 切换GPT模型...")
|
||||
await client.set_gpt_weights(
|
||||
"GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt"
|
||||
)
|
||||
|
||||
print("🔄 切换SoVITS模型...")
|
||||
await client.set_sovits_weights(
|
||||
"SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth"
|
||||
)
|
||||
|
||||
|
||||
async def stream_tts():
|
||||
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
|
||||
try:
|
||||
# 使用最快模式流式输出
|
||||
chunk_count = 0
|
||||
async for chunk in await client.tts(
|
||||
text="要するに、そういうことじゃなくて、ほら、前に言ってたやつ、あれなんだけど、とにかく、後ででもいいから、ちょっと相談に乗ってくれない?",
|
||||
ref_audio_path="uploaded_audio/test_voice.wav",
|
||||
text_lang="ja",
|
||||
prompt_lang="ja",
|
||||
prompt_text="もう!こんなところで何やってるんだよ!",
|
||||
streaming_mode=StreamingMode.FASTEST, # 模式3:快速流式
|
||||
media_type="wav"
|
||||
):
|
||||
chunk_count += 1
|
||||
print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes")
|
||||
|
||||
# 实时播放处理
|
||||
# await play_audio_chunk(chunk.audio_data)
|
||||
|
||||
print(f"✅ 流式TTS完成!共{chunk_count}个音频块")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 流式错误: {e}")
|
||||
|
||||
|
||||
async def stream_tts_and_play(
|
||||
text: str,
|
||||
ref_audio_path: str,
|
||||
text_lang: str = "zh",
|
||||
prompt_lang: str = "zh",
|
||||
streaming_mode: StreamingMode = StreamingMode.FASTEST
|
||||
):
|
||||
"""
|
||||
实时流式TTS + 播放一体化
|
||||
|
||||
Args:
|
||||
text: 要合成的文本
|
||||
ref_audio_path: 参考音频路径
|
||||
text_lang: 文本语言
|
||||
prompt_lang: 提示语言
|
||||
"""
|
||||
# 创建音频播放器(缓冲区大小=5,平衡延迟和稳定性)
|
||||
async with AsyncAudioPlayer(buffer_size=5) as player:
|
||||
# 创建TTS客户端
|
||||
async with GPTSoVITSClient(debug=True, port= 20261, host="192.168.1.8") as client:
|
||||
try:
|
||||
print(f"🎤 开始流式合成: {text[:30]}...")
|
||||
print(f"🎯 流式模式: {streaming_mode.name}")
|
||||
|
||||
# 获取音频流(异步生成器)
|
||||
audio_stream = await client.tts(
|
||||
text=text,
|
||||
ref_audio_path=ref_audio_path,
|
||||
text_lang=text_lang,
|
||||
prompt_lang=prompt_lang,
|
||||
prompt_text="もう!こんなところで何やってるんだよ!",
|
||||
streaming_mode=streaming_mode,
|
||||
media_type="wav",
|
||||
sample_steps=32,
|
||||
top_k=5,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# 动态读取并播放
|
||||
chunk_idx = 0
|
||||
async for audio_chunk in audio_stream:
|
||||
chunk_idx += 1
|
||||
print(f"📥 收到音频块 #{chunk_idx}: {len(audio_chunk.audio_data):6d} bytes")
|
||||
|
||||
# 立即加入播放队列(非阻塞)
|
||||
await player.add_chunk(audio_chunk.audio_data)
|
||||
|
||||
print(f"✅ 合成完成! 共接收 {chunk_idx} 个音频块")
|
||||
|
||||
# 等待播放完成(所有块播完)
|
||||
await player.audio_queue.join()
|
||||
print("🎵 播放完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def test_japanese():
|
||||
"""测试日语长文本流式播放"""
|
||||
print("=" * 50)
|
||||
print("🗾 日语流式TTS测试")
|
||||
print("=" * 50)
|
||||
|
||||
await stream_tts_and_play(
|
||||
text=test_text,
|
||||
ref_audio_path="uploaded_audio/test_voice.wav",
|
||||
text_lang="ja",
|
||||
prompt_lang="ja",
|
||||
|
||||
streaming_mode=StreamingMode.FASTEST # 模式3:最快
|
||||
)
|
||||
|
||||
async def batch_test():
|
||||
"""批量处理示例"""
|
||||
async with GPTSoVITSClient() as client:
|
||||
texts = [
|
||||
"你好,世界!",
|
||||
"这是一个批量测试。",
|
||||
"异步批量处理非常高效。"
|
||||
]
|
||||
|
||||
results = await client.batch_tts(
|
||||
texts=texts,
|
||||
ref_audio_path="archive_jingyuan_1.wav",
|
||||
text_lang="zh"
|
||||
)
|
||||
|
||||
for i, audio in enumerate(results):
|
||||
audio.save(f"output/batch_{i}.wav")
|
||||
print(f"✅ 批量任务 {i + 1}/{len(results)} 完成")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 检查音频设备
|
||||
print("🔍 检查音频设备...")
|
||||
print(sd.query_devices())
|
||||
sd.default.device = (None, "pulse") # 使用PulseAudio
|
||||
|
||||
asyncio.run(test_japanese())
|
||||
@@ -0,0 +1,24 @@
|
||||
import asyncio
|
||||
import json
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
async def test_all_types():
|
||||
"""测试三种消息类型"""
|
||||
async with connect("ws://localhost:8765") as ws:
|
||||
print("=== 测试JSON消息 ===")
|
||||
await ws.send(json.dumps({
|
||||
"type": "chat",
|
||||
"content": "你好服务器!"
|
||||
}))
|
||||
print(f"收到: {await ws.recv()}")
|
||||
|
||||
print("\n=== 测试文本消息 ===")
|
||||
await ws.send("这是纯文本消息")
|
||||
print(f"收到: {await ws.recv()}")
|
||||
|
||||
print("\n=== 测试二进制消息 ===")
|
||||
await ws.send(b"\x00\x01\x02\x03\x04")
|
||||
print(f"收到: {await ws.recv()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_all_types())
|
||||
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
极简 WebSocket 测试服务器 - 修复版本
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Set
|
||||
|
||||
import websockets
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s'
|
||||
)
|
||||
|
||||
class SimpleWebSocketServer:
|
||||
def __init__(self, host="localhost", port=8765):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.clients: Set = set()
|
||||
|
||||
async def handle_connection(self, websocket, path):
|
||||
"""处理客户端连接"""
|
||||
client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
|
||||
self.clients.add(websocket)
|
||||
logging.info(f"✅ 客户端连接: {client_id} (当前连接数: {len(self.clients)})")
|
||||
|
||||
try:
|
||||
# 发送欢迎消息
|
||||
welcome = {
|
||||
"type": "connect",
|
||||
"data": {
|
||||
"message": "WebSocket 服务器连接成功",
|
||||
"client_id": client_id,
|
||||
"server_time": datetime.now().isoformat(),
|
||||
"status": "connected"
|
||||
},
|
||||
"timestamp": int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
await websocket.send(json.dumps(welcome))
|
||||
|
||||
async for message in websocket:
|
||||
await self.handle_message(websocket, client_id, message)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logging.info(f"❌ 客户端断开: {client_id}")
|
||||
finally:
|
||||
self.clients.discard(websocket)
|
||||
logging.info(f"📊 剩余连接: {len(self.clients)}")
|
||||
|
||||
async def handle_message(self, websocket, client_id, message):
|
||||
"""处理收到的消息"""
|
||||
logging.info(f"📨 收到消息 from {client_id}: {message}")
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
data = json.loads(message)
|
||||
msg_type = data.get("type", "unknown")
|
||||
|
||||
# 根据消息类型回复
|
||||
if msg_type == "ping":
|
||||
# 心跳响应
|
||||
response = {
|
||||
"type": "pong",
|
||||
"data": {
|
||||
"server_time": datetime.now().isoformat(),
|
||||
"latency": "0ms"
|
||||
},
|
||||
"timestamp": int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
elif msg_type == "login":
|
||||
# 登录响应
|
||||
username = data.get("data", {}).get("username", "anonymous")
|
||||
response = {
|
||||
"type": "login_success",
|
||||
"data": {
|
||||
"user_id": f"user_{abs(hash(username)) % 10000}",
|
||||
"username": username,
|
||||
"status": "authenticated"
|
||||
},
|
||||
"timestamp": int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
elif msg_type == "chat":
|
||||
# 聊天消息回应
|
||||
msg_content = data.get("data", {}).get("message", "")
|
||||
response = {
|
||||
"type": "chat_response",
|
||||
"data": {
|
||||
"message": f"服务器收到: {msg_content}",
|
||||
"sender": "server",
|
||||
"received_at": datetime.now().isoformat()
|
||||
},
|
||||
"timestamp": int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
else:
|
||||
# 默认回显
|
||||
response = {
|
||||
"type": "echo",
|
||||
"data": {
|
||||
"original": data.get("data", {}),
|
||||
"original_type": msg_type,
|
||||
"server_processed_at": datetime.now().isoformat()
|
||||
},
|
||||
"timestamp": int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps(response))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 不是 JSON,当作纯文本处理
|
||||
response = {
|
||||
"type": "text_echo",
|
||||
"data": {
|
||||
"original": message,
|
||||
"note": "这是文本消息"
|
||||
},
|
||||
"timestamp": int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
await websocket.send(json.dumps(response))
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
logging.info(f"🚀 启动 WebSocket 服务器: ws://{self.host}:{self.port}")
|
||||
|
||||
# 创建处理函数包装器(解决参数问题)
|
||||
async def connection_handler(websocket, path):
|
||||
await self.handle_connection(websocket, path)
|
||||
|
||||
# 启动服务器
|
||||
server = await websockets.serve(
|
||||
connection_handler,
|
||||
self.host,
|
||||
self.port,
|
||||
ping_interval=None,
|
||||
ping_timeout=None,
|
||||
close_timeout=None,
|
||||
max_size=10 * 1024 * 1024
|
||||
)
|
||||
|
||||
logging.info("📌 服务器已启动,等待连接...")
|
||||
logging.info("🛑 按 Ctrl+C 停止服务器")
|
||||
|
||||
# 保持服务器运行
|
||||
try:
|
||||
await asyncio.Future() # 永久运行
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
logging.info("👋 服务器已关闭")
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='极简 WebSocket 测试服务器')
|
||||
parser.add_argument('--host', default='localhost', help='监听地址')
|
||||
parser.add_argument('--port', type=int, default=8088, help='监听端口')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
server = SimpleWebSocketServer(args.host, args.port)
|
||||
|
||||
try:
|
||||
asyncio.run(server.start())
|
||||
except KeyboardInterrupt:
|
||||
logging.info("👋 服务器被用户中断")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,33 @@
|
||||
# requestTest.py
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
# 指定正确的 MIME 类型
|
||||
url = "http://192.168.1.8:20260/transcribe"
|
||||
audio_path = Path("test_files/z105300938.wav")
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
# 明确指定文件名和 MIME 类型
|
||||
files = {
|
||||
"file": (
|
||||
audio_path.name, # 文件名
|
||||
f, # 文件对象
|
||||
"audio/wav" # MIME 类型
|
||||
)
|
||||
}
|
||||
|
||||
response = requests.post(url, files=files)
|
||||
|
||||
# 打印响应详情
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应头: {response.headers.get('content-type')}")
|
||||
|
||||
# 检查响应是否成功
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print(f"识别结果: {result['data']['text']}")
|
||||
print(f"语言: {result['data']['language']}")
|
||||
print(f"置信度: {result['data']['confidence']}")
|
||||
print(f"处理时间: {result['data']['processing_time']}s")
|
||||
else:
|
||||
print(f"错误响应: {response.text}")
|
||||
@@ -0,0 +1,14 @@
|
||||
# 一个小Test, 展示设计的dtos模块与tts和asr的集成
|
||||
from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO
|
||||
from src.modules.tts_module.tts_core.async_audio_player import AsyncAudioPlayer
|
||||
from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import GPTSoVITSClient, StreamingMode
|
||||
from src.modules.asr_module.client.asr_client import create_asr_client
|
||||
|
||||
|
||||
# with create_asr_client(base_url="http://192.168.1.5:20260") as client:
|
||||
# # 转录文件
|
||||
# result = client.transcribe_file("test_files/test.wav")
|
||||
# print(f"识别结果: {result.data.text}")
|
||||
# print(f"置信度: {result.data.confidence:.2f}")
|
||||
# print(f"耗时: {result.data.processing_time:.3f}s")
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from src.modules.websocket_base_module.dto.second_dtos import get_json_dto_instance
|
||||
from src.modules.websocket_base_module.dto.third_dtos import AudioDataDTO
|
||||
from src.modules.websocket_base_module.websocket_core.core_ws_server import get_ws_server
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
async def main():
|
||||
# 获取WebSocket服务器单例
|
||||
ws_server = await get_ws_server()
|
||||
# 获取二级json分发器单例
|
||||
json_dto = await get_json_dto_instance(ws_server)
|
||||
|
||||
# 创建DTO实例(自动注册接收函数)
|
||||
audio_dto = AudioDataDTO(json_dto)
|
||||
|
||||
logger.info("所有DTO接收器已注册,等待客户端连接...")
|
||||
|
||||
# 启动服务器(阻塞)
|
||||
try:
|
||||
await ws_server.run("localhost", 8765)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("服务器任务已取消,正在优雅退出...")
|
||||
finally:
|
||||
logger.info("服务器已停止")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n✓ 服务器已手动终止(按 Ctrl+C)")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 1.6 MiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,88 @@
|
||||
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import UnifiedLLM, ModelConfig, ModelProvider, create_llm_client
|
||||
from src.config.config import get_settings
|
||||
from src.config.convert_env import EnvConverter
|
||||
from src.config.file_config import DirectoryInitializer
|
||||
|
||||
EnvConverter().convert(backup_existing=True) # 若是首次启动则从env模板中生成env文件
|
||||
DirectoryInitializer(get_settings()) # 初始化必要的目录(若不存在则创建)
|
||||
|
||||
def test1():
|
||||
"""
|
||||
测试常规调用
|
||||
"""
|
||||
# 配置模型
|
||||
config = ModelConfig(
|
||||
provider=ModelProvider.OPENAI,
|
||||
model_name=get_settings().ai_model_name,
|
||||
base_url=get_settings().ai_api_base_url,
|
||||
api_key=get_settings().ai_api_key, # 从环境中取出相关的api_key
|
||||
temperature=0.7,
|
||||
max_tokens=2048
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
llm = UnifiedLLM(config)
|
||||
|
||||
# 发送消息
|
||||
response = llm.chat([
|
||||
{"role": "system", "content": "你是一个DeepSeek助手"},
|
||||
{"role": "user", "content": "请介绍一下DeepSeek模型的特点"}
|
||||
])
|
||||
|
||||
print(response.content)
|
||||
|
||||
def base_test2():
|
||||
"""
|
||||
测试流式响应
|
||||
"""
|
||||
# 使用快捷函数
|
||||
deepseek_llm = create_llm_client(
|
||||
provider="openai", # DeepSeek使用OpenAI兼容接口
|
||||
model_name=get_settings().ai_model_name,
|
||||
api_key=get_settings().ai_api_key,
|
||||
base_url=get_settings().ai_api_base_url
|
||||
)
|
||||
|
||||
# 流式聊天
|
||||
messages = [
|
||||
{"role": "user", "content": "用Python写一个快速排序算法"}
|
||||
]
|
||||
|
||||
print("正在生成响应...")
|
||||
for chunk in deepseek_llm.stream_chat(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
|
||||
|
||||
def test_lm_studio():
|
||||
"""测试本地 LM Studio 模型"""
|
||||
print("=== 测试本地 LM Studio ===")
|
||||
|
||||
# 使用UnifiedLLM类
|
||||
config = ModelConfig(
|
||||
provider=ModelProvider.LM_STUDIO,
|
||||
model_name="qwen/qwen3-4b-2507",
|
||||
base_url="http://192.168.1.8:1234/v1",
|
||||
api_key="", # LM Studio不需要API密钥,留空
|
||||
temperature=0.7,
|
||||
max_tokens=1024,
|
||||
streaming=False # 启用流式响应
|
||||
)
|
||||
|
||||
llm = UnifiedLLM(config)
|
||||
|
||||
# 发送消息
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手"},
|
||||
{"role": "user", "content": "用中文介绍一下自己"}
|
||||
]
|
||||
|
||||
print("非流式响应:")
|
||||
response = llm.chat(messages, streaming=False)
|
||||
print(response.content)
|
||||
|
||||
print("\n流式响应:")
|
||||
for chunk in llm.stream_chat(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lm_studio()
|
||||
@@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig
|
||||
|
||||
|
||||
async def test_ui_tars_stream():
|
||||
"""测试 UI-TARS 流式调用"""
|
||||
# 创建客户端
|
||||
config = UITarsClientConfig(
|
||||
deployment_type="lmstudio",
|
||||
base_url="http://192.168.1.8:1234/v1",
|
||||
model_name="ui-tars-1.5-7b@q4_k_m",
|
||||
temperature=0.1
|
||||
)
|
||||
client = UITarsClient(config)
|
||||
|
||||
# 使用工具方法编码
|
||||
image_base64 = base64.b64encode(Path("test_files/Screenshot_test.png").read_bytes()).decode()
|
||||
print(f"✅ 图片编码完成,长度: {len(image_base64)} 字符\n")
|
||||
|
||||
# 流式调用并实时打印
|
||||
print("🤖 开始流式调用 UI-TARS...\n")
|
||||
print("思考过程:\n")
|
||||
|
||||
import time
|
||||
# 计算耗时
|
||||
start_time = time.time()
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
full_response = await client.call_async("打开AK加速器", image_base64)
|
||||
# 传入 base64 字符串
|
||||
# for chunk in client.stream_async("我的桌面系统是KDE, 帮我打开设置", image_base64):
|
||||
# chunk_count += 1
|
||||
# content = chunk.content
|
||||
#
|
||||
# # 实时打印每个 chunk
|
||||
# print(content, end="", flush=True)
|
||||
#
|
||||
# # 累积完整内容
|
||||
# full_response += content
|
||||
|
||||
end_time = time.time()
|
||||
print(f"\n\n耗时: {end_time - start_time:.2f} 秒")
|
||||
print(f"\n\n{'=' * 50}")
|
||||
print(f"✅ 流式调用完成!共接收 {chunk_count} 个 chunk")
|
||||
print(f"完整响应长度: {len(full_response)} 字符")
|
||||
|
||||
print("响应内容:\n")
|
||||
print(full_response)
|
||||
|
||||
import pyautogui
|
||||
def auto_click(x : int, y : int):
|
||||
pyautogui.moveTo(x, y, duration=1.5)
|
||||
pyautogui.click()
|
||||
|
||||
def auto_drag(x1 : int, y1 : int, x2 : int, y2 : int):
|
||||
pyautogui.moveTo(x1, y1, duration=1.5)
|
||||
pyautogui.dragTo(x2, y2, duration=1.5)
|
||||
|
||||
# 运行异步函数
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_ui_tars_stream())
|
||||
auto_click(173,48)
|
||||
# auto_drag(56,39, 170,39)
|
||||
@@ -0,0 +1,45 @@
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
|
||||
from loguru import logger
|
||||
from src.config.config import cfg
|
||||
from datetime import datetime
|
||||
from src.server_core.core import YosugaServerCore
|
||||
|
||||
def init():
|
||||
"""
|
||||
Yosuga_server 初始化
|
||||
"""
|
||||
|
||||
# 初始化日志系统
|
||||
logger.add(
|
||||
f"{cfg.log_dir}/Yosuga_server-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log",
|
||||
encoding="utf-8")
|
||||
logger.info("Yosuga_server 启动")
|
||||
logger.info(f"日志文件目录见: {cfg.log_dir} 目录")
|
||||
|
||||
|
||||
async def main():
|
||||
core = await YosugaServerCore.get_instance()
|
||||
try:
|
||||
await core.run()
|
||||
except asyncio.CancelledError:
|
||||
pass # 正常取消,不打印堆栈
|
||||
finally:
|
||||
# 清理未关闭的 aiohttp sessions
|
||||
import aiohttp
|
||||
pending = [t for t in asyncio.all_tasks()
|
||||
if t is not asyncio.current_task()]
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init()
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nYosuga服务器已停止喵~~~")
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
[project]
|
||||
name = "yosuga-server"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aiofiles>=25.1.0",
|
||||
"aiohttp>=3.13.3",
|
||||
"fastapi>=0.128.0",
|
||||
"faster-whisper>=1.2.1",
|
||||
"loguru>=0.7.3",
|
||||
"openai>=2.16.0",
|
||||
"pyautogui>=0.9.54",
|
||||
"pydantic>=2.12.5",
|
||||
"pydantic-settings>=2.12.0",
|
||||
"python-multipart>=0.0.22",
|
||||
"requests>=2.32.5",
|
||||
"sounddevice>=0.5.5",
|
||||
"soundfile>=0.13.1",
|
||||
"tiktoken>=0.12.0",
|
||||
"uvicorn>=0.40.0",
|
||||
"websockets>=16.0",
|
||||
]
|
||||
@@ -0,0 +1,84 @@
|
||||
--index-url https://download.pytorch.org/whl/cpu
|
||||
aiofiles==25.1.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
aiosignal==1.4.0
|
||||
annotated-doc==0.0.4
|
||||
annotated-types==0.7.0
|
||||
anyio==4.12.1
|
||||
attrs==25.4.0
|
||||
av==16.1.0
|
||||
certifi==2026.1.4
|
||||
cffi==2.0.0
|
||||
charset-normalizer==3.4.4
|
||||
click==8.3.1
|
||||
colorama==0.4.6
|
||||
coloredlogs==15.0.1
|
||||
ctranslate2==4.6.3
|
||||
distro==1.9.0
|
||||
fastapi==0.128.0
|
||||
faster-whisper==1.2.1
|
||||
filelock==3.20.3
|
||||
flatbuffers==25.12.19
|
||||
frozenlist==1.8.0
|
||||
fsspec==2026.1.0
|
||||
h11==0.16.0
|
||||
hf-xet==1.2.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
huggingface-hub==1.3.7
|
||||
humanfriendly==10.0
|
||||
idna==3.11
|
||||
jinja2==3.1.6
|
||||
jiter==0.13.0
|
||||
loguru==0.7.3
|
||||
markupsafe==2.1.5
|
||||
mouseinfo==0.1.3
|
||||
mpmath==1.3.0
|
||||
multidict==6.7.1
|
||||
networkx==3.6.1
|
||||
numpy==2.4.2
|
||||
onnxruntime==1.23.2
|
||||
openai==2.16.0
|
||||
packaging==26.0
|
||||
pillow==12.1.0
|
||||
propcache==0.4.1
|
||||
protobuf==6.33.5
|
||||
pyautogui==0.9.54
|
||||
pycparser==3.0
|
||||
pydantic==2.12.5
|
||||
pydantic-core==2.41.5
|
||||
pydantic-settings==2.12.0
|
||||
pygetwindow==0.0.9
|
||||
pymsgbox==2.0.1
|
||||
pyperclip==1.11.0
|
||||
pyreadline3==3.5.4
|
||||
pyrect==0.2.0
|
||||
pyscreeze==1.0.1
|
||||
python-dotenv==1.2.1
|
||||
python-multipart==0.0.22
|
||||
pytweening==1.2.0
|
||||
pyyaml==6.0.3
|
||||
regex==2026.1.15
|
||||
requests==2.32.5
|
||||
setuptools==80.10.2
|
||||
shellingham==1.5.4
|
||||
sniffio==1.3.1
|
||||
sounddevice==0.5.5
|
||||
soundfile==0.13.1
|
||||
starlette==0.50.0
|
||||
sympy==1.14.0
|
||||
tiktoken==0.12.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.5.1+cpu
|
||||
torchaudio==2.5.1+cpu
|
||||
torchvision==0.20.1+cpu
|
||||
tqdm==4.67.2
|
||||
typer-slim==0.21.1
|
||||
typing-extensions==4.15.0
|
||||
typing-inspection==0.4.2
|
||||
urllib3==2.6.3
|
||||
uvicorn==0.40.0
|
||||
websockets==16.0
|
||||
win32-setctime==1.2.0
|
||||
yarl==1.22.0
|
||||
Binary file not shown.
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
零初始化 JSON 配置管理
|
||||
用法:from src.config import cfg # 直接访问,自动初始化
|
||||
"""
|
||||
import json
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
def _project_root() -> Path:
|
||||
"""自动查找项目根目录"""
|
||||
markers = ['pyproject.toml', 'settings.json', '.gitignore', 'main.py', '.python-version']
|
||||
current = Path(__file__).resolve().parent.parent # src的父目录
|
||||
|
||||
for path in [current, *current.parents]:
|
||||
if any((path / m).exists() for m in markers):
|
||||
return path
|
||||
if path == path.parent:
|
||||
break
|
||||
return current
|
||||
|
||||
# 配置定义
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
api_key: Optional[str] = "sk-xxxxx"
|
||||
base_url: str = "http://localhost:1234/v1"
|
||||
model_name: str = "qwen/qwen3-4b-2507"
|
||||
timeout: int = 30
|
||||
temperature: float = 0.4
|
||||
max_tokens: int = 8192
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSConfig:
|
||||
enabled: bool = True
|
||||
api_key: Optional[str] = None
|
||||
gpt_model_name: str = "GPT_weights_v2Pro/Yosuga_Airi-e32.ckpt"
|
||||
sovits_model_name: str = "SoVITS_weights_v2Pro/Yosuga_Airi_e16_s864.pth"
|
||||
host: str = "localhost"
|
||||
port: int = 20261
|
||||
reference_audio: str = "./using/reference.wav"
|
||||
streaming: bool = True
|
||||
speed: float = 1.0
|
||||
|
||||
@dataclass
|
||||
class ASRConfig:
|
||||
enabled: bool = True
|
||||
api_key: Optional[str] = None
|
||||
model_name: str = "fast-whisper"
|
||||
url: str = "http://localhost:20260/"
|
||||
|
||||
@dataclass
|
||||
class AutoAgentConfig:
|
||||
enabled: bool = True
|
||||
api_key: Optional[str] = None
|
||||
deployment_type: str = "lmstudio"
|
||||
model_name: str = "ui-tars-1.5-7b@q4_k_m"
|
||||
base_url: str = "http://localhost:1234/v1"
|
||||
temperature: float = 0.1
|
||||
max_tokens: int = 16384
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
enabled: bool = True
|
||||
role_character: str = "你是由Misakiotoha开发的助手稲葉愛理ちゃん,可以和用户一起玩游戏,聊天,做各种事情,性格抽象,没事爱整整活。"
|
||||
max_context_tokens: int = 2048
|
||||
enable_history: bool = True
|
||||
language: str = "日本语"
|
||||
|
||||
@dataclass
|
||||
class PathsConfig:
|
||||
temp: str = "./tmp/"
|
||||
log: str = "./log/"
|
||||
using: str = "./using/"
|
||||
|
||||
|
||||
class AppConfig:
|
||||
"""
|
||||
应用主配置
|
||||
新增配置分组:1) 上方新建 dataclass 2) 下方 __init__ 添加字段 3) 完成
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version: str = "1.0.0",
|
||||
debug: bool = False,
|
||||
ai: Optional[AIConfig] = None,
|
||||
tts: Optional[TTSConfig] = None,
|
||||
asr: Optional[ASRConfig] = None,
|
||||
auto_agent: Optional[AutoAgentConfig] = None,
|
||||
llm_core: Optional[LLMConfig] = None,
|
||||
paths: Optional[PathsConfig] = None,
|
||||
_config_path: Optional[Path] = None,
|
||||
**kwargs
|
||||
):
|
||||
# 基础字段
|
||||
self.version = version
|
||||
self.debug = debug
|
||||
self.ai = ai if ai is not None else AIConfig()
|
||||
self.tts = tts if tts is not None else TTSConfig()
|
||||
self.asr = asr if asr is not None else ASRConfig()
|
||||
self.auto_agent = auto_agent if auto_agent is not None else AutoAgentConfig()
|
||||
self.llm_core = llm_core if llm_core is not None else LLMConfig()
|
||||
self.paths = paths if paths is not None else PathsConfig()
|
||||
|
||||
# 内部状态(非 dataclass 字段,不会被序列化)
|
||||
self._config_path = _config_path
|
||||
self._lock = threading.RLock() # 普通属性,非 dataclass 字段
|
||||
|
||||
# 应用其他字段(用于从 JSON 加载时)
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
# 路径解析为绝对路径
|
||||
if self._config_path:
|
||||
root = self._config_path.parent
|
||||
for field_name in ['temp', 'log', 'using']:
|
||||
rel_path = getattr(self.paths, field_name)
|
||||
if not Path(rel_path).is_absolute():
|
||||
abs_path = (root / Path(rel_path)).resolve()
|
||||
setattr(self.paths, field_name, str(abs_path) + '/')
|
||||
|
||||
# 便捷属性
|
||||
|
||||
@property
|
||||
def temp_dir(self) -> Path:
|
||||
return Path(self.paths.temp)
|
||||
|
||||
@property
|
||||
def log_dir(self) -> Path:
|
||||
return Path(self.paths.log)
|
||||
|
||||
@property
|
||||
def using_dir(self) -> Path:
|
||||
return Path(self.paths.using)
|
||||
|
||||
# 核心方法
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
点号路径访问:cfg.get("ai.timeout") / cfg.get("tts.enabled")
|
||||
"""
|
||||
try:
|
||||
keys = key.split('.')
|
||||
value = self
|
||||
for k in keys:
|
||||
value = getattr(value, k) if not isinstance(value, dict) else value[k]
|
||||
return value
|
||||
except (AttributeError, KeyError):
|
||||
return default
|
||||
|
||||
def set(self, key: str, value: Any, save: bool = True) -> 'AppConfig':
|
||||
"""
|
||||
点号路径设置,支持链式调用
|
||||
cfg.set("ai.timeout", 60).set("debug", True)
|
||||
"""
|
||||
with self._lock:
|
||||
keys = key.split('.')
|
||||
target = self
|
||||
for k in keys[:-1]:
|
||||
target = getattr(target, k)
|
||||
setattr(target, keys[-1], value)
|
||||
|
||||
if save:
|
||||
self._save()
|
||||
return self
|
||||
|
||||
def update(self, updates: Dict[str, Any], save: bool = True) -> 'AppConfig':
|
||||
"""
|
||||
批量更新:cfg.update({"ai": {"timeout": 60}, "debug": True})
|
||||
"""
|
||||
def deep_update(obj: Any, data: dict):
|
||||
for k, v in data.items():
|
||||
if hasattr(obj, k):
|
||||
current = getattr(obj, k)
|
||||
if isinstance(v, dict) and hasattr(current, '__dataclass_fields__'):
|
||||
deep_update(current, v)
|
||||
else:
|
||||
setattr(obj, k, v)
|
||||
|
||||
with self._lock:
|
||||
deep_update(self, updates)
|
||||
|
||||
if save:
|
||||
self._save()
|
||||
return self
|
||||
|
||||
def reload(self) -> 'AppConfig':
|
||||
"""热重载配置"""
|
||||
if self._config_path and self._config_path.exists():
|
||||
with self._lock:
|
||||
data = json.loads(self._config_path.read_text(encoding='utf-8'))
|
||||
|
||||
# 配置项名 -> dataclass 类的映射(和 _load 保持一致)
|
||||
config_classes = {
|
||||
'ai': AIConfig,
|
||||
'tts': TTSConfig,
|
||||
'asr': ASRConfig,
|
||||
'auto_agent': AutoAgentConfig,
|
||||
'llm_core': LLMConfig,
|
||||
'paths': PathsConfig,
|
||||
}
|
||||
|
||||
for k, v in data.items():
|
||||
if hasattr(self, k) and not k.startswith('_'):
|
||||
# 如果是配置项且是 dict,转换为 dataclass
|
||||
if k in config_classes and isinstance(v, dict):
|
||||
setattr(self, k, config_classes[k](**v))
|
||||
else:
|
||||
setattr(self, k, v)
|
||||
print(f"配置重载: {self._config_path}")
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""导出为字典(手动实现,排除内部属性)"""
|
||||
result = {
|
||||
'version': self.version,
|
||||
'debug': self.debug,
|
||||
'ai': asdict(self.ai) if hasattr(self.ai, '__dataclass_fields__') else self.ai,
|
||||
'tts': asdict(self.tts) if hasattr(self.tts, '__dataclass_fields__') else self.tts,
|
||||
'asr': asdict(self.asr) if hasattr(self.asr, '__dataclass_fields__') else self.asr,
|
||||
'auto_agent': asdict(self.auto_agent) if hasattr(self.auto_agent, '__dataclass_fields__') else self.auto_agent,
|
||||
'llm_core': asdict(self.llm_core) if hasattr(self.llm_core, '__dataclass_fields__') else self.llm_core,
|
||||
'paths': asdict(self.paths) if hasattr(self.paths, '__dataclass_fields__') else self.paths,
|
||||
}
|
||||
return result
|
||||
|
||||
def _save(self) -> None:
|
||||
"""保存到文件"""
|
||||
if self._config_path:
|
||||
with self._lock:
|
||||
json_str = json.dumps(self.to_dict(), indent=2, ensure_ascii=False)
|
||||
self._config_path.write_text(json_str, encoding='utf-8')
|
||||
|
||||
@classmethod
|
||||
def _load(cls, path: Path) -> 'AppConfig':
|
||||
"""从文件加载"""
|
||||
data = json.loads(path.read_text(encoding='utf-8'))
|
||||
|
||||
# 配置项名 -> dataclass 类的映射
|
||||
config_classes = {
|
||||
'ai': AIConfig,
|
||||
'tts': TTSConfig,
|
||||
'asr': ASRConfig,
|
||||
'auto_agent': AutoAgentConfig,
|
||||
'llm_core': LLMConfig,
|
||||
'paths': PathsConfig,
|
||||
}
|
||||
|
||||
# 自动转换 dict 为对应 dataclass
|
||||
for key, config_class in config_classes.items():
|
||||
if key in data and isinstance(data[key], dict):
|
||||
data[key] = config_class(**data[key])
|
||||
|
||||
return cls(_config_path=path, **data)
|
||||
|
||||
@classmethod
|
||||
def _create_default(cls, path: Path) -> 'AppConfig':
|
||||
"""创建默认配置"""
|
||||
instance = cls(_config_path=path)
|
||||
instance._save()
|
||||
print(f"默认配置已创建: {path}")
|
||||
return instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""友好的打印格式"""
|
||||
lines = ["AppConfig("]
|
||||
for k, v in self.to_dict().items():
|
||||
lines.append(f" {k}={v!r},")
|
||||
lines.append(")")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# 延迟初始化机制
|
||||
|
||||
_root: Path = _project_root()
|
||||
_config_path: Path = _root / "settings.json"
|
||||
_config_instance: Optional[AppConfig] = None
|
||||
_init_lock: threading.Lock = threading.Lock()
|
||||
|
||||
|
||||
def _ensure_initialized() -> AppConfig:
|
||||
"""
|
||||
确保配置已初始化(线程安全的延迟初始化)
|
||||
"""
|
||||
global _config_instance
|
||||
|
||||
if _config_instance is not None:
|
||||
return _config_instance
|
||||
|
||||
with _init_lock:
|
||||
if _config_instance is not None:
|
||||
return _config_instance
|
||||
|
||||
# 自动加载或创建
|
||||
if _config_path.exists():
|
||||
_config_instance = AppConfig._load(_config_path)
|
||||
print(f"配置加载: {_config_path}")
|
||||
else:
|
||||
_config_instance = AppConfig._create_default(_config_path)
|
||||
|
||||
# 确保目录存在
|
||||
for d in [_config_instance.temp_dir,
|
||||
_config_instance.log_dir,
|
||||
_config_instance.using_dir]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return _config_instance
|
||||
|
||||
|
||||
class _LazyConfig:
|
||||
"""
|
||||
配置代理类:拦截所有属性访问,第一次使用时自动初始化
|
||||
"""
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""拦截属性访问,延迟初始化"""
|
||||
instance = _ensure_initialized()
|
||||
return getattr(instance, name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""拦截属性设置"""
|
||||
# 特殊属性直接设置到代理对象本身
|
||||
if name.startswith('_'):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
instance = _ensure_initialized()
|
||||
setattr(instance, name, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
instance = _ensure_initialized()
|
||||
return repr(instance)
|
||||
|
||||
def __dir__(self) -> list:
|
||||
"""支持 IDE 自动补全"""
|
||||
instance = _ensure_initialized()
|
||||
return dir(instance)
|
||||
|
||||
# 显式代理必要方法
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return _ensure_initialized().get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any, save: bool = True) -> AppConfig:
|
||||
return _ensure_initialized().set(key, value, save)
|
||||
|
||||
def update(self, updates: Dict[str, Any], save: bool = True) -> AppConfig:
|
||||
return _ensure_initialized().update(updates, save)
|
||||
|
||||
def reload(self) -> AppConfig:
|
||||
return _ensure_initialized().reload()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return _ensure_initialized().to_dict()
|
||||
|
||||
def save(self) -> None:
|
||||
_ensure_initialized()._save()
|
||||
|
||||
# 属性代理
|
||||
@property
|
||||
def ai(self) -> AIConfig:
|
||||
return _ensure_initialized().ai
|
||||
|
||||
@property
|
||||
def tts(self) -> TTSConfig:
|
||||
return _ensure_initialized().tts
|
||||
|
||||
@property
|
||||
def asr(self) -> ASRConfig:
|
||||
return _ensure_initialized().asr
|
||||
|
||||
@property
|
||||
def auto_agent(self) -> AutoAgentConfig:
|
||||
return _ensure_initialized().auto_agent
|
||||
|
||||
@property
|
||||
def llm_core(self) -> LLMConfig:
|
||||
return _ensure_initialized().llm_core
|
||||
|
||||
@property
|
||||
def paths(self) -> PathsConfig:
|
||||
return _ensure_initialized().paths
|
||||
|
||||
@property
|
||||
def temp_dir(self) -> Path:
|
||||
return _ensure_initialized().temp_dir
|
||||
|
||||
@property
|
||||
def log_dir(self) -> Path:
|
||||
return _ensure_initialized().log_dir
|
||||
|
||||
@property
|
||||
def using_dir(self) -> Path:
|
||||
return _ensure_initialized().using_dir
|
||||
|
||||
|
||||
# 全局配置对象:导入即用,自动初始化
|
||||
cfg: AppConfig = _LazyConfig() # type: ignore
|
||||
|
||||
|
||||
# 工具函数
|
||||
|
||||
def generate_example(path: Path = _root / "settings.example.json") -> None:
|
||||
"""生成示例配置文件"""
|
||||
example = {
|
||||
"version": "1.0.0",
|
||||
"debug": False,
|
||||
"ai": {
|
||||
"api_key": "sk-your-api-key",
|
||||
"base_url": "https://api.deepseek.com",
|
||||
"model": "deepseek-chat",
|
||||
"timeout": 30
|
||||
},
|
||||
"tts": {
|
||||
"enabled": True,
|
||||
"model": "GPT_SoVITS",
|
||||
"url": "http://localhost:12458/",
|
||||
"reference_audio": "./using/reference.wav",
|
||||
"streaming": True,
|
||||
"speed": 1.0
|
||||
},
|
||||
"paths": {
|
||||
"temp": "./tmp/",
|
||||
"log": "./log/",
|
||||
"data": "./data/"
|
||||
}
|
||||
}
|
||||
path.write_text(json.dumps(example, indent=2, ensure_ascii=False), encoding='utf-8')
|
||||
print(f"示例配置已生成: {path}")
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
# 测试:直接访问,自动初始化
|
||||
print("第一次访问 cfg.ai.model:")
|
||||
print(f" → {cfg.ai.model_name}")
|
||||
|
||||
print(f"\n配置详情:")
|
||||
print(cfg)
|
||||
|
||||
print(f"\n测试修改:")
|
||||
cfg.set("ai.timeout", 60)
|
||||
print(f" ai.timeout = {cfg.ai.timeout}")
|
||||
|
||||
print(f"\n测试批量更新:")
|
||||
cfg.update({"debug": True, "tts": {"speed": 1.5}})
|
||||
print(f" debug = {cfg.debug}, tts.speed = {cfg.tts.speed}")
|
||||
|
||||
print(f"\n测试热重载:")
|
||||
cfg.reload()
|
||||
@@ -0,0 +1,123 @@
|
||||
# asr_module/api.py
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from src.modules.asr_module.asr_core.fast_whisper import create_asr, ASRConfig
|
||||
|
||||
# 初始化FastAPI应用
|
||||
app = FastAPI(
|
||||
title="Yosuga ASR API",
|
||||
description="基于faster-whisper Turbo的高性能多语种语音转文本服务",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 全局单例ASR实例(延迟加载)
|
||||
_asr_instance = None
|
||||
|
||||
def get_asr():
|
||||
"""获取或创建ASR实例(单例)"""
|
||||
global _asr_instance
|
||||
if _asr_instance is None:
|
||||
logger.info("🚀 初始化ASR服务...")
|
||||
_asr_instance = create_asr(
|
||||
ASRConfig(
|
||||
model_name="deepdml/faster-whisper-large-v3-turbo-ct2",
|
||||
device="auto",
|
||||
compute_type="int8_float16",
|
||||
cache_dir=Path("asr_models/faster_whisper_large_v3_ct2"),
|
||||
beam_size=1, # 贪婪搜索,速度最快
|
||||
vad_filter=True, # 过滤静音,节省30%时间
|
||||
)
|
||||
)
|
||||
logger.info("✅ ASR服务初始化完成")
|
||||
return _asr_instance
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动时预加载模型"""
|
||||
get_asr()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""应用关闭时清理资源"""
|
||||
global _asr_instance
|
||||
if _asr_instance:
|
||||
_asr_instance.shutdown()
|
||||
logger.info("🛑 ASR服务已关闭")
|
||||
|
||||
@app.post("/transcribe", response_class=JSONResponse)
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(..., description="音频文件 (WAV, FLAC, MP3等格式)")
|
||||
):
|
||||
"""
|
||||
语音转文本API
|
||||
|
||||
- **file**: 音频文件,支持WAV/FLAC/MP3等格式
|
||||
- **返回**: JSON格式结果,包含text/language/confidence
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 验证文件类型
|
||||
if file.content_type and not file.content_type.startswith("audio/"):
|
||||
raise HTTPException(status_code=400, detail="❌ 请上传音频文件 (MIME类型: audio/*)")
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file:
|
||||
content = await file.read()
|
||||
tmp_file.write(content)
|
||||
tmp_path = Path(tmp_file.name)
|
||||
|
||||
logger.info(f"📥 接收文件: {file.filename} ({len(content)} bytes)")
|
||||
|
||||
# 调用ASR识别
|
||||
asr = get_asr()
|
||||
text, language, confidence = asr.transcribe_wav(tmp_path)
|
||||
|
||||
# 清理临时文件
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
logger.info(f"✅ 识别完成: {language} | {len(text)}字符 | 置信度:{confidence:.2f} | 耗时:{processing_time:.3f}s")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"text": text,
|
||||
"language": language,
|
||||
"confidence": confidence,
|
||||
"processing_time": round(processing_time, 3)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 识别失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}")
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查接口"""
|
||||
asr = get_asr()
|
||||
health = asr.health_check()
|
||||
|
||||
return {
|
||||
"status": "healthy" if health["status"] == "healthy" else "unhealthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"device": health["device"],
|
||||
"model_loaded": health["model_loaded"]
|
||||
}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""API根路径"""
|
||||
return {
|
||||
"message": "Yosuga ASR API 正在运行",
|
||||
"docs": "/docs",
|
||||
"health": "/health",
|
||||
"transcribe": "/transcribe (POST)"
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
# fast_whisper/__init__.py
|
||||
from typing import Optional
|
||||
from src.modules.asr_module.asr_core.fast_whisper.config import ASRConfig
|
||||
from src.modules.asr_module.asr_core.fast_whisper.model_manager import ModelManager
|
||||
from src.modules.asr_module.asr_core.fast_whisper.asr_interface import ASRInterface
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__all__ = ["ASRConfig", "ModelManager", "ASRInterface"]
|
||||
|
||||
def create_asr(config: Optional[ASRConfig] = None) -> ASRInterface:
|
||||
"""
|
||||
快速创建ASR实例
|
||||
Args:
|
||||
config: ASR配置,若为None则使用默认配置
|
||||
"""
|
||||
return ASRInterface.get_instance(config)
|
||||
@@ -0,0 +1,184 @@
|
||||
# fast_whisper/asr_interface.py
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
import torchaudio
|
||||
import torch
|
||||
import numpy
|
||||
|
||||
from .model_manager import ModelManager
|
||||
from .config import ASRConfig
|
||||
from .utils import PerformanceProfiler
|
||||
|
||||
class ASRInterface:
|
||||
"""
|
||||
ASR接口类 - 全局单例
|
||||
- 提供wav转文本功能
|
||||
- 注入ModelManager
|
||||
- 性能统计
|
||||
"""
|
||||
|
||||
_instance: Optional['ASRInterface'] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""单例模式实现"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: Optional[ASRConfig] = None):
|
||||
# 防止重复初始化
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
|
||||
self.config = config or ASRConfig()
|
||||
self.model_manager = ModelManager(self.config)
|
||||
self.profiler = PerformanceProfiler(self.config.enable_profiling)
|
||||
|
||||
# 音频参数
|
||||
self.sample_rate = 16000
|
||||
|
||||
self._initialized = True
|
||||
logger.info("🎤 ASR接口初始化完成")
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, config: Optional[ASRConfig] = None) -> 'ASRInterface':
|
||||
"""全局访问点"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(config)
|
||||
return cls._instance
|
||||
|
||||
def transcribe_wav(
|
||||
self,
|
||||
wav_path: Path,
|
||||
language: Optional[str] = None
|
||||
) -> Tuple[str, str, float]:
|
||||
"""
|
||||
WAV音频转文本(核心接口)
|
||||
|
||||
Args:
|
||||
wav_path: WAV文件路径
|
||||
language: 指定语言代码(如'zh'/'en'),None则自动检测
|
||||
|
||||
Returns:
|
||||
(text, language, confidence)
|
||||
"""
|
||||
try:
|
||||
# 记录开始时间
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(f"🎵 开始识别: {wav_path.name}")
|
||||
|
||||
# 执行识别...
|
||||
audio = self._load_audio(wav_path)
|
||||
result = self._transcribe(audio, language)
|
||||
text, lang, confidence = self._parse_result(result)
|
||||
|
||||
# 计算耗时
|
||||
processing_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"✅ 识别完成: {lang} | {len(text)}字符 | 置信度:{confidence:.2f} | "
|
||||
f"耗时:{processing_time:.3f}s | RTF:{processing_time/(len(audio)/self.sample_rate):.3f}"
|
||||
)
|
||||
|
||||
return text, lang, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 识别失败 {wav_path}: {e}")
|
||||
raise RuntimeError(f"Transcription failed: {e}")
|
||||
|
||||
def _load_audio(self, wav_path: Path) -> numpy.ndarray:
|
||||
"""加载和预处理音频"""
|
||||
if not wav_path.exists():
|
||||
raise FileNotFoundError(f"音频文件不存在: {wav_path}")
|
||||
|
||||
# 加载音频
|
||||
waveform, sample_rate = torchaudio.load(wav_path)
|
||||
|
||||
# 重采样到16kHz
|
||||
if sample_rate != self.sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
|
||||
waveform = resampler(waveform)
|
||||
|
||||
# 转换为单声道
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
|
||||
# 转换为numpy数组
|
||||
audio = waveform.squeeze().numpy()
|
||||
|
||||
return audio
|
||||
|
||||
def _transcribe(self, audio: numpy.ndarray, language: Optional[str]) -> Tuple:
|
||||
"""执行推理"""
|
||||
model = self.model_manager.model
|
||||
|
||||
# 添加模型存在性检查
|
||||
if model is None:
|
||||
logger.error("ASR模型未加载,请检查模型配置和路径")
|
||||
raise RuntimeError("ASR模型未加载,请检查模型配置和路径")
|
||||
|
||||
# 记录时间
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 调用模型
|
||||
segments, info = model.transcribe(
|
||||
audio,
|
||||
language=language,
|
||||
beam_size=self.config.beam_size,
|
||||
best_of=self.config.best_of,
|
||||
vad_filter=self.config.vad_filter,
|
||||
)
|
||||
|
||||
# 立即执行生成器
|
||||
segments_list = list(segments)
|
||||
|
||||
# 性能统计
|
||||
inference_time = time.time() - start_time
|
||||
audio_duration = len(audio) / self.sample_rate
|
||||
self.profiler.record(audio_duration, inference_time)
|
||||
|
||||
return segments_list, info
|
||||
|
||||
def _parse_result(self, result: Tuple) -> Tuple[str, str, float]:
|
||||
"""解析识别结果"""
|
||||
segments, info = result
|
||||
|
||||
# 合并所有片段
|
||||
text = " ".join([seg.text.strip() for seg in segments])
|
||||
|
||||
# 获取语言信息
|
||||
language = info.language if info else "unknown"
|
||||
confidence = info.language_probability if info else 0.0
|
||||
|
||||
return text, language, confidence
|
||||
|
||||
def transcribe_batch(self, wav_paths: list) -> list:
|
||||
"""批量识别接口"""
|
||||
return [
|
||||
{
|
||||
"file": str(path),
|
||||
"text": result[0],
|
||||
"language": result[1],
|
||||
"confidence": result[2]
|
||||
}
|
||||
for path, result in zip(wav_paths, [
|
||||
self.transcribe_wav(Path(p)) for p in wav_paths
|
||||
])
|
||||
]
|
||||
|
||||
def health_check(self) -> dict:
|
||||
"""健康检查接口"""
|
||||
return {
|
||||
"status": "healthy" if self.model_manager.model else "unhealthy",
|
||||
"device": self.config.device,
|
||||
"model_loaded": self.model_manager.model is not None,
|
||||
"device_info": self.model_manager.get_device_info(),
|
||||
}
|
||||
|
||||
def shutdown(self):
|
||||
"""优雅关闭"""
|
||||
logger.info("🛑 关闭ASR接口...")
|
||||
self.model_manager.unload()
|
||||
@@ -0,0 +1,29 @@
|
||||
# fast_whisper/config.py
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
@dataclass
|
||||
class ASRConfig:
|
||||
"""ASR配置类"""
|
||||
model_name: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
|
||||
device: str = "auto"
|
||||
compute_type: str = "int8_float16"
|
||||
cache_dir: Path = Path.home() / ".cache" / "faster_whisper"
|
||||
|
||||
# 速度优化参数
|
||||
beam_size: int = 1
|
||||
best_of: int = 1
|
||||
vad_filter: bool = True
|
||||
batch_size: int = 16
|
||||
|
||||
# 性能统计
|
||||
enable_profiling: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.device == "auto":
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
if self.device == "cpu":
|
||||
self.compute_type = "int8"
|
||||
self.batch_size = 4
|
||||
@@ -0,0 +1,92 @@
|
||||
# fast_whisper/model_manager.py
|
||||
import gc
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from faster_whisper import WhisperModel
|
||||
import torch
|
||||
|
||||
from .config import ASRConfig
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""
|
||||
模型管理类
|
||||
- 负责模型生命周期管理
|
||||
- 支持自定义缓存目录
|
||||
- 自动硬件适配
|
||||
"""
|
||||
|
||||
def __init__(self, config: ASRConfig):
|
||||
self.config = config
|
||||
self._model: Optional[WhisperModel] = None
|
||||
self._device_info = None
|
||||
|
||||
# 确保缓存目录存在
|
||||
self.config.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def model(self) -> Optional[WhisperModel]:
|
||||
"""懒加载模型"""
|
||||
if self._model is None:
|
||||
self._load_model()
|
||||
return self._model
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型"""
|
||||
logger.info(f"🚀 初始化模型: {self.config.model_name}")
|
||||
logger.info(f"📦 设备: {self.config.device}, 计算类型: {self.config.compute_type}")
|
||||
|
||||
try:
|
||||
self._model = WhisperModel(
|
||||
self.config.model_name,
|
||||
device=self.config.device,
|
||||
compute_type=self.config.compute_type,
|
||||
download_root=str(self.config.cache_dir),
|
||||
local_files_only=False,
|
||||
)
|
||||
|
||||
self._device_info = {
|
||||
"device": self.config.device,
|
||||
"compute_type": self.config.compute_type,
|
||||
"model_size": self.config.model_name.split("-")[-2]
|
||||
}
|
||||
|
||||
logger.info("✅ 模型加载成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 模型加载失败: {e}")
|
||||
raise RuntimeError(f"Failed to load ASR model: {e}")
|
||||
|
||||
def reload(self, new_config: ASRConfig):
|
||||
"""热重载模型"""
|
||||
logger.info("🔄 热重载模型...")
|
||||
self.unload()
|
||||
self.config = new_config
|
||||
self._load_model()
|
||||
|
||||
def unload(self):
|
||||
"""卸载模型释放资源"""
|
||||
if self._model is not None:
|
||||
logger.info("🗑️ 卸载模型...")
|
||||
del self._model
|
||||
self._model = None
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
logger.info("✅ 模型已卸载")
|
||||
|
||||
def get_device_info(self) -> dict:
|
||||
"""获取设备信息"""
|
||||
return self._device_info or {}
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器支持"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""自动清理资源"""
|
||||
self.unload()
|
||||
@@ -0,0 +1,45 @@
|
||||
# fast_whisper/utils.py
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
def check_hardware() -> Dict[str, Any]:
|
||||
"""硬件检测"""
|
||||
import torch
|
||||
info = {
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"device_name": "CPU",
|
||||
"device_count": 0,
|
||||
"compute_type": "int8"
|
||||
}
|
||||
|
||||
if info["cuda_available"]:
|
||||
info.update({
|
||||
"device_name": torch.cuda.get_device_name(0),
|
||||
"device_count": torch.cuda.device_count(),
|
||||
"compute_type": "int8_float16"
|
||||
})
|
||||
|
||||
return info
|
||||
|
||||
class PerformanceProfiler:
|
||||
"""性能分析器"""
|
||||
def __init__(self, enable: bool = True):
|
||||
self.enable = enable
|
||||
self.stats = []
|
||||
|
||||
def record(self, audio_duration: float, inference_time: float):
|
||||
if not self.enable:
|
||||
return
|
||||
|
||||
rtf = inference_time / audio_duration if audio_duration > 0 else 0
|
||||
self.stats.append({
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"rtf": rtf,
|
||||
"audio_duration": audio_duration,
|
||||
"inference_time": inference_time
|
||||
})
|
||||
|
||||
if len(self.stats) % 10 == 0:
|
||||
avg_rtf = sum(s["rtf"] for s in self.stats[-10:]) / 10
|
||||
logger.info(f"📊 最近10次平均RTF: {avg_rtf:.3f}")
|
||||
@@ -0,0 +1,215 @@
|
||||
# asr_module/client/asr_client.py
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Union, Optional
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import requests
|
||||
from loguru import logger
|
||||
from .models import ASRResponse, ASRHealthStatus, ServiceInfo
|
||||
|
||||
class ASRException(Exception):
|
||||
"""ASR服务调用异常"""
|
||||
|
||||
def __init__(self, message: str, status_code: Optional[int] = None):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
class ASRClientConfig:
|
||||
"""客户端配置"""
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
timeout: float = 30.0,
|
||||
retry_count: int = 2,
|
||||
retry_delay: float = 0.5,
|
||||
):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.timeout = timeout
|
||||
self.retry_count = retry_count
|
||||
self.retry_delay = retry_delay
|
||||
|
||||
|
||||
# 同步客户端
|
||||
class ASRClientSync:
|
||||
"""同步ASR客户端"""
|
||||
def __init__(self, config: Optional[ASRClientConfig] = None):
|
||||
self.config = config or ASRClientConfig()
|
||||
self.session = requests.Session()
|
||||
self.session.timeout = self.config.timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.session.close()
|
||||
|
||||
def _request(self, method: str, endpoint: str, **kwargs) -> dict:
|
||||
"""统一请求处理(带重试)"""
|
||||
url = f"{self.config.base_url}{endpoint}"
|
||||
|
||||
for attempt in range(self.config.retry_count + 1):
|
||||
try:
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
if attempt < self.config.retry_count:
|
||||
logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
|
||||
time.sleep(self.config.retry_delay)
|
||||
else:
|
||||
logger.error(f"请求最终失败: {e}")
|
||||
raise ASRException(f"API调用失败: {e}", getattr(e.response, 'status_code', None))
|
||||
|
||||
def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
|
||||
"""
|
||||
转录音频文件
|
||||
|
||||
Args:
|
||||
file_path: 音频文件路径
|
||||
|
||||
Returns:
|
||||
ASRResponse对象
|
||||
|
||||
Example:
|
||||
client = ASRClientSync()
|
||||
result = client.transcribe_file("/path/to/audio.wav")
|
||||
print(result.data.text)
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
logger.info(f"📤 上传文件: {file_path.name}")
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
files = {'file': (file_path.name, f, 'audio/wav')}
|
||||
result = self._request('POST', '/transcribe', files=files)
|
||||
|
||||
return ASRResponse(**result)
|
||||
|
||||
def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
|
||||
"""
|
||||
转录音频字节流
|
||||
|
||||
Args:
|
||||
audio_data: 原始音频字节
|
||||
filename: 模拟文件名(用于MIME类型推断)
|
||||
|
||||
Returns:
|
||||
ASRResponse对象
|
||||
|
||||
Example:
|
||||
with open('audio.wav', 'rb') as f:
|
||||
audio_bytes = f.read()
|
||||
result = client.transcribe_bytes(audio_bytes)
|
||||
"""
|
||||
logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
|
||||
|
||||
files = {'file': (filename, audio_data, 'audio/wav')}
|
||||
result = self._request('POST', '/transcribe', files=files)
|
||||
|
||||
return ASRResponse(**result)
|
||||
|
||||
def health_check(self) -> ASRHealthStatus:
|
||||
"""健康检查"""
|
||||
result = self._request('GET', '/health')
|
||||
return ASRHealthStatus(**result)
|
||||
|
||||
def get_service_info(self) -> ServiceInfo:
|
||||
"""获取服务信息"""
|
||||
result = self._request('GET', '/')
|
||||
return ServiceInfo(**result)
|
||||
|
||||
|
||||
# 异步客户端
|
||||
class ASRClientAsync:
|
||||
"""异步ASR客户端"""
|
||||
def __init__(self, config: Optional[ASRClientConfig] = None):
|
||||
self.config = config or ASRClientConfig()
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
)
|
||||
|
||||
async def _request(self, method: str, endpoint: str, **kwargs) -> dict:
|
||||
"""统一异步请求(带重试)"""
|
||||
await self._ensure_session()
|
||||
url = f"{self.config.base_url}{endpoint}"
|
||||
|
||||
for attempt in range(self.config.retry_count + 1):
|
||||
try:
|
||||
async with self._session.request(method, url, **kwargs) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientError as e:
|
||||
if attempt < self.config.retry_count:
|
||||
logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
else:
|
||||
logger.error(f"请求最终失败: {e}")
|
||||
raise ASRException(f"API调用失败: {e}", getattr(e, 'status', None))
|
||||
|
||||
async def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
|
||||
"""异步转录音频文件"""
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
logger.info(f"📤 上传文件: {file_path.name}")
|
||||
|
||||
async with aiofiles.open(file_path, 'rb') as f:
|
||||
audio_data = await f.read()
|
||||
|
||||
return await self.transcribe_bytes(audio_data, file_path.name)
|
||||
|
||||
async def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
|
||||
"""异步转录音频字节流"""
|
||||
logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
|
||||
await self._ensure_session() # 确保session已创建
|
||||
form = aiohttp.FormData() # 创建表单数据
|
||||
form.add_field('file', audio_data, filename=filename, content_type='audio/wav') # 添加文件字段
|
||||
result = await self._request('POST', '/transcribe', data=form) # 发送POST请求
|
||||
return ASRResponse(**result) # 返回结果
|
||||
|
||||
async def health_check(self) -> ASRHealthStatus:
|
||||
"""异步健康检查"""
|
||||
result = await self._request('GET', '/health')
|
||||
return ASRHealthStatus(**result)
|
||||
|
||||
async def get_service_info(self) -> ServiceInfo:
|
||||
"""异步获取服务信息"""
|
||||
result = await self._request('GET', '/')
|
||||
return ServiceInfo(**result)
|
||||
|
||||
# 工厂函数
|
||||
def create_asr_client(use_async: bool = False, **config_kwargs) -> Union[ASRClientSync, ASRClientAsync]:
|
||||
"""
|
||||
创建客户端工厂函数
|
||||
|
||||
Args:
|
||||
use_async: 是否创建异步客户端
|
||||
**config_kwargs: ASRClientConfig参数
|
||||
|
||||
Returns:
|
||||
同步或异步客户端实例
|
||||
"""
|
||||
config = ASRClientConfig(**config_kwargs)
|
||||
if use_async:
|
||||
return ASRClientAsync(config)
|
||||
return ASRClientSync(config)
|
||||
@@ -0,0 +1,30 @@
|
||||
# asr_module/client/models.py
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ASRHealthStatus(BaseModel):
|
||||
"""ASR服务健康状态"""
|
||||
status: str # 状态
|
||||
timestamp: str # 时间戳
|
||||
device: str # 设备
|
||||
model_loaded: bool # 模型是否加载
|
||||
|
||||
class ASRResult(BaseModel):
|
||||
"""语音识别结果"""
|
||||
text: str # 识别结果
|
||||
language: str # 语言
|
||||
confidence: float # 置信度
|
||||
processing_time: float # 处理时间
|
||||
|
||||
class ASRResponse(BaseModel):
|
||||
"""统一API响应"""
|
||||
success: bool
|
||||
data: Optional[ASRResult] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
class ServiceInfo(BaseModel):
|
||||
"""服务信息"""
|
||||
message: str # 消息
|
||||
docs: str # 文档
|
||||
health: str # 健康
|
||||
transcribe: str # 识别
|
||||
@@ -0,0 +1,12 @@
|
||||
本模块为asr模块,即语音转文本模块。
|
||||
|
||||
本模块提供了两种访问方式:
|
||||
- 第一种为本地部署的func call方式,即提供函数调用的方式调用相关asr接口。
|
||||
- 第二种为http call方式,即提供http call方式调用相关asr接口。[FastAPI]
|
||||
|
||||
如果你的电脑显卡支持cuda, 并且显存大小大于8G, 那么可以使用第一种方式,
|
||||
否则可以使用第二种,进行云端部署。
|
||||
|
||||
两种调用方式在相同配置下,性能几乎无差别。
|
||||
|
||||
个人建议使用第二种
|
||||
@@ -0,0 +1,65 @@
|
||||
# start_api.py
|
||||
import uvicorn
|
||||
from loguru import logger
|
||||
import threading
|
||||
import time
|
||||
|
||||
def start_server():
|
||||
"""启动 ASR API 服务"""
|
||||
uvicorn.run(
|
||||
"api:app", # 模块名:app实例
|
||||
host="0.0.0.0",
|
||||
port=20260,
|
||||
workers=1, # 单用户场景,1个worker足够
|
||||
log_level="info",
|
||||
reload=False, # 生产环境关闭热重载
|
||||
access_log=True,
|
||||
)
|
||||
|
||||
def first_test() -> None:
|
||||
"""首次启动测试"""
|
||||
time.sleep(5) # 给服务器一些启动时间
|
||||
# 构造一个测试请求以验证初始化模型加载成功
|
||||
logger.info("🚀 测试模型是否加载成功...")
|
||||
import requests
|
||||
from pathlib import Path
|
||||
url = "http://localhost:20260/transcribe"
|
||||
audio_path = Path("../../../Test/test_files/test.wav")
|
||||
try:
|
||||
with open(audio_path, "rb") as f:
|
||||
# 明确指定文件名和 MIME 类型
|
||||
files = {
|
||||
"file": (
|
||||
audio_path.name, # 文件名
|
||||
f, # 文件对象
|
||||
"audio/wav" # MIME 类型
|
||||
)
|
||||
}
|
||||
|
||||
response = requests.post(url, files=files)
|
||||
logger.info(f"状态码: {response.status_code}")
|
||||
logger.info(f"响应头: {response.headers.get('content-type')}")
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info(f"识别结果: {result['data']['text']}")
|
||||
logger.info(f"识别语言: {result['data']['language']}")
|
||||
logger.info(f"置信度: {result['data']['confidence']:.2f}")
|
||||
logger.info(f"处理时间: {result['data']['processing_time']}s")
|
||||
else:
|
||||
logger.error(f"请求失败,错误响应信息: {response.text}")
|
||||
logger.error("请检查模型是否正确加载或其他问题")
|
||||
except Exception as e:
|
||||
logger.error(f"测试过程中发生错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 启动 ASR API 服务...")
|
||||
|
||||
# 在后台线程启动服务器
|
||||
server_thread = threading.Thread(target=start_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
# 执行测试
|
||||
first_test()
|
||||
|
||||
# 保持主线程运行
|
||||
server_thread.join()
|
||||
@@ -0,0 +1,119 @@
|
||||
# ui_tars_/ui_tars_client.py
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
import asyncio
|
||||
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
|
||||
UnifiedLLM,
|
||||
ModelConfig,
|
||||
ModelProvider,
|
||||
ChatMessage,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_prompts import UI_TARS_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class UITarsClientConfig(BaseModel):
|
||||
"""UI-TARS 客户端配置"""
|
||||
deployment_type: str = Field(default="lmstudio", description="部署类型")
|
||||
base_url: str = Field(default="http://localhost:1234/v1", description="API地址")
|
||||
model_name: str = Field(default="ui-tars", description="模型名称")
|
||||
api_key: Optional[str] = Field(default=None, description="API密钥")
|
||||
temperature: float = Field(default=0.1, ge=0.0, le=2.0)
|
||||
max_tokens: int = Field(default=8192, ge=2048, le=128000)
|
||||
timeout: int = Field(default=30, ge=5, le=300)
|
||||
|
||||
# UI-TARS-1.5 强制输出格式
|
||||
system_prompt: str = Field(
|
||||
default=UI_TARS_SYSTEM_PROMPT # 使用本项目自定义的输出格式约束
|
||||
)
|
||||
|
||||
def to_model_config(self) -> ModelConfig:
|
||||
"""转换为 UnifiedLLM 配置"""
|
||||
# 映射部署类型到 ModelProvider
|
||||
provider_map = {
|
||||
"lmstudio": ModelProvider.LM_STUDIO,
|
||||
"vllm": ModelProvider.CUSTOM,
|
||||
"cloud": ModelProvider.OPENAI,
|
||||
"ollama": ModelProvider.OLLAMA
|
||||
}
|
||||
provider = provider_map.get(self.deployment_type, ModelProvider.CUSTOM)
|
||||
return ModelConfig(
|
||||
provider=provider,
|
||||
model_name=self.model_name,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
timeout=self.timeout,
|
||||
custom_headers={"User-Agent": "UI-TARS-Client/1.0"}
|
||||
)
|
||||
|
||||
class UITarsClient:
|
||||
"""
|
||||
UI-TARS 通用客户端 (基于 UnifiedLLM)
|
||||
图片相关信息请直接传入相应的base64
|
||||
"""
|
||||
def __init__(self, config: UITarsClientConfig):
|
||||
self.config = config
|
||||
# 复用 UnifiedLLM,自动处理所有部署类型
|
||||
self.llm = UnifiedLLM(config.to_model_config())
|
||||
|
||||
logger.info(f"UI-TARS 客户端初始化: {config.deployment_type} @ {config.base_url}")
|
||||
logger.info(f" 模型: {config.model_name} | 温度: {config.temperature}")
|
||||
|
||||
async def call_async(self, instruction: str, image_base64: str) -> str:
|
||||
"""异步调用 UI-TARS"""
|
||||
# 构建消息
|
||||
messages = self._build_messages(instruction, image_base64)
|
||||
try:
|
||||
# 使用 UnifiedLLM 的异步接口
|
||||
response = await asyncio.to_thread(
|
||||
self.llm.chat,
|
||||
messages=messages,
|
||||
streaming=False
|
||||
)
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.error(f"UI-TARS 调用失败: {e}")
|
||||
raise
|
||||
|
||||
def call_sync(self, instruction: str, image_base64: str) -> str:
|
||||
"""同步调用 UI-TARS"""
|
||||
messages = self._build_messages(instruction, image_base64)
|
||||
try:
|
||||
response = self.llm.chat(
|
||||
messages=messages,
|
||||
streaming=False
|
||||
)
|
||||
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.error(f"UI-TARS 调用失败: {e}")
|
||||
raise
|
||||
|
||||
def stream_async(self, instruction: str, image_base64: str):
|
||||
"""流式调用 (异步生成器)"""
|
||||
messages = self._build_messages(instruction, image_base64)
|
||||
# UnifiedLLM 自动处理流式
|
||||
return self.llm.stream_chat(messages=messages)
|
||||
|
||||
def _build_messages(self, instruction: str, image_base64: str) -> list:
|
||||
"""构建 OpenAI 格式消息"""
|
||||
return [
|
||||
ChatMessage(
|
||||
role="system",
|
||||
content=self.config.system_prompt
|
||||
),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": instruction},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{image_base64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
OFFICIAL_ACTION_SPACE = """## Action Space
|
||||
click(point='<point>x1 y1</point>') - 单击坐标
|
||||
left_double(point='<point>x1 y1</point>') - 双击坐标
|
||||
right_single(point='<point>x1 y1</point>') - 右键单击
|
||||
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>') - 拖拽
|
||||
hotkey(key='ctrl c') - 快捷键(空格分隔,小写,最多3个键)
|
||||
type(content='xxx') - 输入文本(用\\' \\\" \\n 转义)
|
||||
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') - 滚动
|
||||
wait() - 等待5秒
|
||||
finished() - 任务完成
|
||||
|
||||
## Output Format
|
||||
Thought: [你的推理过程]
|
||||
Action: [选择一个动作]
|
||||
"""
|
||||
|
||||
UI_TARS_SYSTEM_PROMPT = f"""You are UI-TARS-1.5, a GUI agent. Given a task and screenshot, output ONLY:
|
||||
|
||||
{OFFICIAL_ACTION_SPACE}
|
||||
|
||||
## Note
|
||||
- Write a small plan and summarize the next action in one sentence in Thought.
|
||||
- NEVER output multiple actions.
|
||||
- x&y please in box center
|
||||
"""
|
||||
@@ -0,0 +1,10 @@
|
||||
本模块为设备控制模块接口层,此处的设备控制指的是借助AI模型进行一些设备上的自动化
|
||||
操作,支持`pc`, `android`,其他的未做过测试。
|
||||
|
||||
依赖:
|
||||
`ui_tars`
|
||||
|
||||
当前所使用的AI模型为`mradermacher/UI-TARS-1.5-7B-GGUF
|
||||
(Q6_K or Q4_K_M)`
|
||||
|
||||
未来如果有更快质量更高的AI模型本模块会为其添加支持。
|
||||
@@ -0,0 +1,837 @@
|
||||
"""
|
||||
通用大语言模型调用框架
|
||||
支持本地模型(Ollama, LM Studio, llama.cpp)和云端模型(OpenAI, Anthropic, Google, Azure等)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Union, Any, Iterator
|
||||
import json
|
||||
import os
|
||||
from loguru import logger
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from enum import Enum
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ModelProvider(Enum):
|
||||
"""支持的模型提供商枚举"""
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
GOOGLE = "google"
|
||||
AZURE = "azure"
|
||||
OLLAMA = "ollama"
|
||||
LM_STUDIO = "lm_studio"
|
||||
LLAMA_CPP = "llama_cpp"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""模型配置类"""
|
||||
provider: ModelProvider
|
||||
model_name: str
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 1024
|
||||
top_p: float = 1.0
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
timeout: int = 30
|
||||
streaming: bool = False
|
||||
custom_headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典"""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息类"""
|
||||
role: str # system, user, assistant
|
||||
content: str
|
||||
name: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
**({"name": self.name} if self.name else {})
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
"""模型响应类"""
|
||||
content: str # 响应内容
|
||||
model: str # 模型名称
|
||||
usage: Optional[Dict[str, int]] = None # 使用量
|
||||
finish_reason: Optional[str] = None # 结束原因
|
||||
raw_response: Optional[Dict] = None # 原始响应
|
||||
|
||||
|
||||
def normalize_usage(raw_usage: Optional[Dict[str, Any]], provider: ModelProvider) -> Optional[Dict[str, int]]:
|
||||
"""
|
||||
将不同平台的 usage 字段统一归一化为 OpenAI 标准格式
|
||||
|
||||
Args:
|
||||
raw_usage: API 原始返回的 usage 数据
|
||||
provider: 模型提供商枚举
|
||||
|
||||
Returns:
|
||||
归一化后的 usage 字典,格式:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
如果无法归一化则返回 None
|
||||
"""
|
||||
if not raw_usage:
|
||||
return None
|
||||
|
||||
# 字段映射表:{provider: (input_key, output_key, total_key)}
|
||||
USAGE_FIELD_MAP = {
|
||||
ModelProvider.OPENAI: ("prompt_tokens", "completion_tokens", "total_tokens"),
|
||||
ModelProvider.AZURE: ("prompt_tokens", "completion_tokens", "total_tokens"),
|
||||
ModelProvider.LM_STUDIO: ("prompt_tokens", "completion_tokens", "total_tokens"),
|
||||
ModelProvider.LLAMA_CPP: ("prompt_tokens", "completion_tokens", "total_tokens"),
|
||||
ModelProvider.OLLAMA: ("prompt_eval_count", "eval_count", None), # Ollama 没有 total
|
||||
ModelProvider.ANTHROPIC: ("input_tokens", "output_tokens", None),
|
||||
ModelProvider.GOOGLE: ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"),
|
||||
}
|
||||
|
||||
input_key, output_key, total_key = USAGE_FIELD_MAP.get(provider, (None, None, None))
|
||||
|
||||
if input_key is None:
|
||||
logger.warning(f"未知的 provider '{provider}',无法归一化 usage")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 提取字段值
|
||||
prompt_tokens = raw_usage.get(input_key, 0)
|
||||
completion_tokens = raw_usage.get(output_key, 0)
|
||||
|
||||
# 处理嵌套字典(如有些 API 的 usage 格式特殊)
|
||||
if isinstance(prompt_tokens, dict):
|
||||
prompt_tokens = prompt_tokens.get("value", 0)
|
||||
if isinstance(completion_tokens, dict):
|
||||
completion_tokens = completion_tokens.get("value", 0)
|
||||
|
||||
# 转换为整数
|
||||
prompt_tokens = int(prompt_tokens) if prompt_tokens else 0
|
||||
completion_tokens = int(completion_tokens) if completion_tokens else 0
|
||||
|
||||
# 计算 total(如果 API 没提供)
|
||||
if total_key and total_key in raw_usage:
|
||||
total_tokens = int(raw_usage[total_key])
|
||||
else:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
normalized = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
logger.debug(f"归一化 usage | {provider} -> OpenAI格式: {normalized}")
|
||||
return normalized
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"归一化 usage 失败: {e} | raw_usage: {raw_usage}")
|
||||
return None
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""大语言模型客户端基类"""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
self.config = config
|
||||
self.client = None
|
||||
self._initialize_client()
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_client(self):
|
||||
"""初始化客户端"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict]],
|
||||
**kwargs
|
||||
) -> Union[ModelResponse, Iterator[ModelResponse]]:
|
||||
"""聊天补全"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def completion(
|
||||
self,
|
||||
prompt: str,
|
||||
**kwargs
|
||||
) -> Union[ModelResponse, Iterator[ModelResponse]]:
|
||||
"""文本补全"""
|
||||
pass
|
||||
|
||||
def format_messages(self, messages: List[Union[ChatMessage, Dict]]) -> List[Dict]:
|
||||
"""格式化消息列表"""
|
||||
formatted = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted.append(msg.to_dict())
|
||||
else:
|
||||
formatted.append(msg)
|
||||
return formatted
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""OpenAI客户端"""
|
||||
|
||||
def _initialize_client(self):
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = self.config.api_key
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API密钥未设置")
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=self.config.base_url,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
logger.info(f"OpenAI客户端初始化成功,base_url: {self.config.base_url}")
|
||||
except ImportError:
|
||||
logger.error("请安装openai包: pip install openai")
|
||||
raise
|
||||
|
||||
def chat_completion(self, messages, **kwargs):
|
||||
formatted_messages = self.format_messages(messages)
|
||||
|
||||
# 获取streaming参数,优先使用kwargs中的设置
|
||||
streaming = kwargs.get("streaming", self.config.streaming)
|
||||
|
||||
# 合并配置
|
||||
params = {
|
||||
"model": self.config.model_name,
|
||||
"messages": formatted_messages,
|
||||
"temperature": kwargs.get("temperature", self.config.temperature),
|
||||
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"frequency_penalty": kwargs.get("frequency_penalty", self.config.frequency_penalty),
|
||||
"presence_penalty": kwargs.get("presence_penalty", self.config.presence_penalty),
|
||||
"stream": streaming, # 使用正确的streaming设置
|
||||
}
|
||||
|
||||
logger.info(f"🔧 调用参数: streaming={streaming}")
|
||||
|
||||
if streaming:
|
||||
return self._stream_chat_completion(params)
|
||||
else:
|
||||
return self._normal_chat_completion(params)
|
||||
|
||||
def _normal_chat_completion(self, params):
|
||||
"""非流式响应处理"""
|
||||
logger.info("📡 发送非流式请求...")
|
||||
response = self.client.chat.completions.create(**params)
|
||||
raw_usage = response.usage
|
||||
normalized_usage = normalize_usage(
|
||||
raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
|
||||
ModelProvider.OPENAI
|
||||
)
|
||||
return ModelResponse(
|
||||
content=response.choices[0].message.content,
|
||||
model=response.model,
|
||||
usage=normalized_usage,
|
||||
finish_reason=response.choices[0].finish_reason,
|
||||
raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
|
||||
)
|
||||
|
||||
def _stream_chat_completion(self, params):
|
||||
"""流式响应处理"""
|
||||
logger.info("📡 发送流式请求...")
|
||||
response_stream = self.client.chat.completions.create(**params)
|
||||
|
||||
full_content = ""
|
||||
for chunk in response_stream:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
content_chunk = chunk.choices[0].delta.content
|
||||
full_content += content_chunk
|
||||
yield ModelResponse(
|
||||
content=content_chunk,
|
||||
model=chunk.model,
|
||||
raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
|
||||
)
|
||||
|
||||
def completion(self, prompt, **kwargs):
|
||||
# OpenAI 推荐使用 chat_completion,这里保持兼容
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return self.chat_completion(messages, **kwargs)
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMClient):
|
||||
"""Anthropic Claude客户端"""
|
||||
|
||||
def _initialize_client(self):
|
||||
try:
|
||||
from anthropic import Anthropic
|
||||
self.client = Anthropic(
|
||||
api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
except ImportError:
|
||||
logger.error("请安装anthropic包: pip install anthropic")
|
||||
raise
|
||||
|
||||
def chat_completion(self, messages, **kwargs):
|
||||
formatted_messages = self.format_messages(messages)
|
||||
|
||||
# Claude 的消息格式转换
|
||||
claude_messages = []
|
||||
system_message = None
|
||||
|
||||
for msg in formatted_messages:
|
||||
if msg["role"] == "system":
|
||||
system_message = msg["content"]
|
||||
else:
|
||||
claude_messages.append({
|
||||
"role": msg["role"],
|
||||
"content": msg["content"]
|
||||
})
|
||||
# 明确获取streaming参数
|
||||
|
||||
params = {
|
||||
"model": self.config.model_name,
|
||||
"messages": claude_messages,
|
||||
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.config.temperature),
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"stream": kwargs.get("streaming", self.config.streaming),
|
||||
}
|
||||
|
||||
if system_message:
|
||||
params["system"] = system_message
|
||||
|
||||
if self.config.streaming:
|
||||
return self._stream_chat_completion(params)
|
||||
else:
|
||||
return self._normal_chat_completion(params)
|
||||
|
||||
def _normal_chat_completion(self, params):
|
||||
response = self.client.messages.create(**params)
|
||||
# Anthropic 返回的usage格式和OpenAI不同,需要进行转换
|
||||
raw_usage = response.usage
|
||||
normalized_usage = normalize_usage(
|
||||
raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
|
||||
ModelProvider.ANTHROPIC
|
||||
)
|
||||
return ModelResponse(
|
||||
content=response.content[0].text,
|
||||
model=response.model,
|
||||
usage=normalized_usage,
|
||||
finish_reason=response.stop_reason,
|
||||
raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
|
||||
)
|
||||
|
||||
def _stream_chat_completion(self, params):
|
||||
with self.client.messages.stream(**params) as stream:
|
||||
for chunk in stream:
|
||||
if chunk.type_ == "content_block_delta":
|
||||
yield ModelResponse(
|
||||
content=chunk.delta.text,
|
||||
model=params["model"],
|
||||
raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
|
||||
)
|
||||
|
||||
def completion(self, prompt, **kwargs):
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return self.chat_completion(messages, **kwargs)
|
||||
|
||||
|
||||
class OllamaClient(BaseLLMClient):
|
||||
"""Ollama本地模型客户端"""
|
||||
|
||||
def _initialize_client(self):
|
||||
import httpx
|
||||
self.base_url = self.config.base_url or "http://localhost:11434"
|
||||
self.client = httpx.Client(
|
||||
base_url=self.base_url,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
def chat_completion(self, messages, **kwargs):
|
||||
formatted_messages = self.format_messages(messages)
|
||||
|
||||
payload = {
|
||||
"model": self.config.model_name,
|
||||
"messages": formatted_messages,
|
||||
"options": {
|
||||
"temperature": kwargs.get("temperature", self.config.temperature),
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"num_predict": kwargs.get("max_tokens", self.config.max_tokens),
|
||||
},
|
||||
"stream": kwargs.get("streaming", self.config.streaming),
|
||||
}
|
||||
|
||||
if self.config.streaming:
|
||||
return self._stream_chat_completion(payload)
|
||||
else:
|
||||
return self._normal_chat_completion(payload)
|
||||
|
||||
def _normal_chat_completion(self, payload):
|
||||
response = self.client.post("/api/chat", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
normalized_usage = normalize_usage(data, ModelProvider.OLLAMA)
|
||||
return ModelResponse(
|
||||
content=data["message"]["content"],
|
||||
model=data["model"],
|
||||
usage=normalized_usage,
|
||||
finish_reason=data.get("done_reason"),
|
||||
raw_response=data
|
||||
)
|
||||
|
||||
def _stream_chat_completion(self, payload):
|
||||
with self.client.stream("POST", "/api/chat", json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if line.strip():
|
||||
try:
|
||||
data = json.loads(line)
|
||||
if "message" in data and "content" in data["message"]:
|
||||
yield ModelResponse(
|
||||
content=data["message"]["content"],
|
||||
model=data.get("model", self.config.model_name),
|
||||
raw_response=data
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
def completion(self, prompt, **kwargs):
|
||||
payload = {
|
||||
"model": self.config.model_name,
|
||||
"prompt": prompt,
|
||||
"options": {
|
||||
"temperature": kwargs.get("temperature", self.config.temperature),
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"num_predict": kwargs.get("max_tokens", self.config.max_tokens),
|
||||
},
|
||||
"stream": kwargs.get("streaming", self.config.streaming),
|
||||
}
|
||||
|
||||
if self.config.streaming:
|
||||
return self._stream_completion(payload)
|
||||
else:
|
||||
return self._normal_completion(payload)
|
||||
|
||||
def _normal_completion(self, payload):
|
||||
response = self.client.post("/api/generate", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return ModelResponse(
|
||||
content=data["response"],
|
||||
model=data["model"],
|
||||
usage={
|
||||
"prompt_tokens": data.get("prompt_eval_count", 0),
|
||||
"completion_tokens": data.get("eval_count", 0),
|
||||
"total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
|
||||
},
|
||||
finish_reason=data.get("done_reason"),
|
||||
raw_response=data
|
||||
)
|
||||
|
||||
def _stream_completion(self, payload):
|
||||
with self.client.stream("POST", "/api/generate", json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if line.strip():
|
||||
try:
|
||||
data = json.loads(line)
|
||||
if "response" in data:
|
||||
yield ModelResponse(
|
||||
content=data["response"],
|
||||
model=data.get("model", self.config.model_name),
|
||||
raw_response=data
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
|
||||
class GenericLLMClient(BaseLLMClient):
|
||||
"""通用HTTP客户端,支持LM Studio和其他兼容OpenAI API的本地模型"""
|
||||
|
||||
def _initialize_client(self):
|
||||
import httpx
|
||||
self.base_url = self.config.base_url or "http://localhost:1234/v1"
|
||||
self.client = httpx.Client(
|
||||
base_url=self.base_url,
|
||||
timeout=self.config.timeout,
|
||||
headers=self.config.custom_headers
|
||||
)
|
||||
|
||||
def chat_completion(self, messages, **kwargs):
|
||||
formatted_messages = self.format_messages(messages)
|
||||
|
||||
# 明确获取 streaming 参数
|
||||
streaming = kwargs.get("streaming", self.config.streaming)
|
||||
|
||||
payload = {
|
||||
"model": self.config.model_name,
|
||||
"messages": formatted_messages,
|
||||
"temperature": kwargs.get("temperature", self.config.temperature),
|
||||
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"stream": streaming, # 使用明确的 streaming 变量
|
||||
}
|
||||
|
||||
logger.info(f"GenericLLMClient 参数: streaming={streaming}")
|
||||
|
||||
if streaming:
|
||||
return self._stream_chat_completion(payload)
|
||||
else:
|
||||
return self._normal_chat_completion(payload)
|
||||
|
||||
def _normal_chat_completion(self, payload):
|
||||
logger.info(f"GenericLLMClient 发送非流式请求到: {self.base_url}")
|
||||
response = self.client.post("/chat/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
logger.info(f"GenericLLMClient 收到响应,模型: {data.get('model')}")
|
||||
|
||||
return ModelResponse(
|
||||
content=data["choices"][0]["message"]["content"],
|
||||
model=data["model"],
|
||||
usage=data.get("usage"),
|
||||
finish_reason=data["choices"][0].get("finish_reason"),
|
||||
raw_response=data
|
||||
)
|
||||
|
||||
def _stream_chat_completion(self, payload):
|
||||
logger.info(f"GenericLLMClient 发送流式请求到: {self.base_url}")
|
||||
with self.client.stream("POST", "/chat/completions", json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
chunk = line[6:]
|
||||
if chunk == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(chunk)
|
||||
if data["choices"][0]["delta"].get("content"):
|
||||
yield ModelResponse(
|
||||
content=data["choices"][0]["delta"]["content"],
|
||||
model=data["model"],
|
||||
raw_response=data
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
def completion(self, prompt, **kwargs):
|
||||
payload = {
|
||||
"model": self.config.model_name,
|
||||
"prompt": prompt,
|
||||
"temperature": kwargs.get("temperature", self.config.temperature),
|
||||
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"stream": kwargs.get("streaming", self.config.streaming),
|
||||
}
|
||||
|
||||
streaming = kwargs.get("streaming", self.config.streaming)
|
||||
|
||||
if streaming:
|
||||
return self._stream_completion(payload)
|
||||
else:
|
||||
return self._normal_completion(payload)
|
||||
|
||||
def _normal_completion(self, payload):
|
||||
response = self.client.post("/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return ModelResponse(
|
||||
content=data["choices"][0]["text"],
|
||||
model=data["model"],
|
||||
usage=data.get("usage"),
|
||||
finish_reason=data["choices"][0].get("finish_reason"),
|
||||
raw_response=data
|
||||
)
|
||||
|
||||
def _stream_completion(self, payload):
|
||||
with self.client.stream("POST", "/completions", json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
chunk = line[6:]
|
||||
if chunk == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(chunk)
|
||||
if data["choices"][0].get("text"):
|
||||
yield ModelResponse(
|
||||
content=data["choices"][0]["text"],
|
||||
model=data["model"],
|
||||
raw_response=data
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
class UnifiedLLM:
|
||||
"""
|
||||
统一的大语言模型调用类
|
||||
|
||||
支持多种模型提供商:
|
||||
- 云端模型:OpenAI, Anthropic, Google, Azure
|
||||
- 本地模型:Ollama, LM Studio, llama.cpp
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
# 初始化OpenAI客户端
|
||||
config = ModelConfig(
|
||||
provider=ModelProvider.OPENAI,
|
||||
model_name="gpt-4",
|
||||
api_key="your-api-key"
|
||||
)
|
||||
llm = UnifiedLLM(config)
|
||||
|
||||
# 聊天补全
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手"},
|
||||
{"role": "user", "content": "你好!"}
|
||||
]
|
||||
response = llm.chat(messages)
|
||||
print(response.content)
|
||||
|
||||
# 流式响应
|
||||
config.streaming = True
|
||||
llm.update_config(config)
|
||||
for chunk in llm.chat(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
"""
|
||||
初始化统一LLM
|
||||
|
||||
Args:
|
||||
config: 模型配置
|
||||
"""
|
||||
self.config = config
|
||||
self.client = self._create_client()
|
||||
logger.info(f" UnifiedLLM 初始化完成")
|
||||
logger.info(f" 提供商: {config.provider}")
|
||||
logger.info(f" 模型: {config.model_name}")
|
||||
logger.info(f" 流式默认: {config.streaming}")
|
||||
|
||||
def _create_client(self) -> BaseLLMClient:
|
||||
"""根据配置创建客户端"""
|
||||
provider = self.config.provider
|
||||
|
||||
if provider == ModelProvider.OPENAI:
|
||||
return OpenAIClient(self.config)
|
||||
elif provider == ModelProvider.ANTHROPIC:
|
||||
return AnthropicClient(self.config)
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
return OllamaClient(self.config)
|
||||
elif provider in [ModelProvider.LM_STUDIO, ModelProvider.LLAMA_CPP, ModelProvider.CUSTOM]:
|
||||
return GenericLLMClient(self.config)
|
||||
elif provider == ModelProvider.GOOGLE:
|
||||
# 这里可以扩展Google Gemini支持
|
||||
return GenericLLMClient(self.config)
|
||||
elif provider == ModelProvider.AZURE:
|
||||
# Azure OpenAI需要特殊处理
|
||||
return GenericLLMClient(self.config)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型提供商: {provider}")
|
||||
|
||||
def update_config(self, config: ModelConfig):
|
||||
"""更新配置并重新创建客户端"""
|
||||
self.config = config
|
||||
self.client = self._create_client()
|
||||
logger.info(f"UnifiedLLM 配置已更新")
|
||||
|
||||
def chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
|
||||
"""
|
||||
聊天补全
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
**kwargs: 其他参数,会覆盖config中的设置
|
||||
|
||||
Returns:
|
||||
ModelResponse 或 ModelResponse 的迭代器(流式模式)
|
||||
"""
|
||||
# 明确获取streaming参数
|
||||
streaming = kwargs.get("streaming", self.config.streaming)
|
||||
logger.info(f" UnifiedLLM.chat() 调用")
|
||||
logger.info(f" 消息数: {len(messages)}")
|
||||
logger.info(f" streaming参数: {streaming}")
|
||||
|
||||
# 调用客户端
|
||||
result = self.client.chat_completion(messages, **kwargs)
|
||||
|
||||
# 类型检查(调试用)
|
||||
if streaming:
|
||||
if not hasattr(result, '__iter__'):
|
||||
logger.warning(f"警告: streaming=True 但返回的不是迭代器")
|
||||
else:
|
||||
if hasattr(result, '__iter__'):
|
||||
logger.warning(f"警告: streaming=False 但返回的是迭代器")
|
||||
elif not isinstance(result, ModelResponse):
|
||||
logger.warning(f"警告: streaming=False 但返回的不是ModelResponse")
|
||||
|
||||
return result
|
||||
|
||||
def complete(self, prompt: str, **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
|
||||
"""
|
||||
文本补全
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
ModelResponse 或 ModelResponse 的迭代器(流式模式)
|
||||
"""
|
||||
streaming = kwargs.get("streaming", self.config.streaming)
|
||||
logger.info(f" UnifiedLLM.complete() 调用")
|
||||
logger.info(f" prompt长度: {len(prompt)}")
|
||||
logger.info(f" streaming参数: {streaming}")
|
||||
|
||||
return self.client.completion(prompt, **kwargs)
|
||||
|
||||
def stream_chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Iterator[ModelResponse]:
|
||||
"""
|
||||
流式聊天补全(便捷方法)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
ModelResponse 的迭代器
|
||||
"""
|
||||
kwargs["streaming"] = True
|
||||
logger.info(f"UnifiedLLM.stream_chat() 调用")
|
||||
|
||||
result = self.chat(messages, **kwargs)
|
||||
|
||||
# 确保返回的是迭代器
|
||||
if not hasattr(result, '__iter__'):
|
||||
raise TypeError("stream_chat 应该返回迭代器,但返回了其他类型")
|
||||
|
||||
return result
|
||||
|
||||
def stream_complete(self, prompt: str, **kwargs) -> Iterator[ModelResponse]:
|
||||
"""
|
||||
流式文本补全(便捷方法)
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
ModelResponse 的迭代器
|
||||
"""
|
||||
kwargs["streaming"] = True
|
||||
logger.info(f"UnifiedLLM.stream_complete() 调用")
|
||||
|
||||
result = self.complete(prompt, **kwargs)
|
||||
|
||||
# 确保返回的是迭代器
|
||||
if not hasattr(result, '__iter__'):
|
||||
raise TypeError("stream_complete 应该返回迭代器,但返回了其他类型")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 快捷函数
|
||||
def create_llm_client(
|
||||
provider: Union[str, ModelProvider],
|
||||
model_name: str,
|
||||
**kwargs
|
||||
) -> UnifiedLLM:
|
||||
"""
|
||||
快捷创建LLM客户端
|
||||
|
||||
Args:
|
||||
provider: 提供商名称或枚举
|
||||
model_name: 模型名称
|
||||
**kwargs: 其他配置参数
|
||||
|
||||
Returns:
|
||||
UnifiedLLM 实例
|
||||
"""
|
||||
if isinstance(provider, str):
|
||||
provider = ModelProvider(provider.lower())
|
||||
|
||||
config = ModelConfig(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return UnifiedLLM(config)
|
||||
|
||||
|
||||
# 使用示例
|
||||
def example_usage():
|
||||
"""使用示例"""
|
||||
|
||||
# 示例1: 使用OpenAI
|
||||
print("示例1: 使用OpenAI")
|
||||
openai_config = ModelConfig(
|
||||
provider=ModelProvider.OPENAI,
|
||||
model_name="gpt-3.5-turbo",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
temperature=0.8
|
||||
)
|
||||
|
||||
try:
|
||||
openai_llm = UnifiedLLM(openai_config)
|
||||
messages = [
|
||||
ChatMessage(role="system", content="你是一个有用的助手"),
|
||||
ChatMessage(role="user", content="请用Python写一个Hello World程序")
|
||||
]
|
||||
response = openai_llm.chat(messages)
|
||||
print(f"响应: {response.content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"OpenAI示例错误: {e}")
|
||||
|
||||
# 示例2: 使用Ollama(本地模型)
|
||||
print("\n示例2: 使用Ollama(本地模型)")
|
||||
ollama_config = ModelConfig(
|
||||
provider=ModelProvider.OLLAMA,
|
||||
model_name="llama2",
|
||||
base_url="http://localhost:11434",
|
||||
temperature=0.7,
|
||||
streaming=True # 流式响应
|
||||
)
|
||||
|
||||
try:
|
||||
ollama_llm = UnifiedLLM(ollama_config)
|
||||
messages = [
|
||||
{"role": "user", "content": "什么是人工智能?"}
|
||||
]
|
||||
|
||||
print("流式响应:")
|
||||
for chunk in ollama_llm.stream_chat(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
print()
|
||||
except Exception as e:
|
||||
print(f"Ollama示例错误: {e}(请确保Ollama服务正在运行)")
|
||||
|
||||
# 示例3: 使用快捷函数
|
||||
print("\n示例3: 使用快捷函数")
|
||||
try:
|
||||
llm = create_llm_client(
|
||||
provider="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
temperature=0.5
|
||||
)
|
||||
|
||||
response = llm.complete("天空为什么是蓝色的?")
|
||||
print(f"响应: {response.content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"快捷函数示例错误: {e}")
|
||||
@@ -0,0 +1,162 @@
|
||||
# 大语言模型调用框架架构图
|
||||
general_text_ai_req.py
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "应用层"
|
||||
A[用户应用] --> B[UnifiedLLM 统一接口]
|
||||
end
|
||||
|
||||
subgraph "适配器层"
|
||||
B --> C{模型提供商路由}
|
||||
C --> D[OpenAI 适配器]
|
||||
C --> E[Anthropic 适配器]
|
||||
C --> F[Ollama 适配器]
|
||||
C --> G[通用HTTP适配器]
|
||||
C --> H[其他适配器]
|
||||
end
|
||||
|
||||
subgraph "服务层"
|
||||
D --> I[OpenAI API]
|
||||
E --> J[Anthropic API]
|
||||
F --> K[Ollama 服务]
|
||||
G --> L[LM Studio]
|
||||
G --> M[llama.cpp]
|
||||
G --> N[其他兼容API]
|
||||
end
|
||||
|
||||
subgraph "配置层"
|
||||
O[ModelConfig] --> C
|
||||
O --> D
|
||||
O --> E
|
||||
O --> F
|
||||
O --> G
|
||||
end
|
||||
|
||||
subgraph "数据流"
|
||||
P[输入: 消息/提示] --> B
|
||||
I --> Q[输出: ModelResponse]
|
||||
J --> Q
|
||||
K --> Q
|
||||
L --> Q
|
||||
M --> Q
|
||||
N --> Q
|
||||
end
|
||||
|
||||
style A fill:#4567f1
|
||||
style B fill:#4567f1
|
||||
style O fill:#456748
|
||||
style D fill:#457911
|
||||
style E fill:#466bd5
|
||||
style F fill:#4567f1
|
||||
style G fill:#4567f1
|
||||
```
|
||||
|
||||
# 数据流图
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User as 用户/应用
|
||||
participant UnifiedLLM as UnifiedLLM
|
||||
participant Adapter as 适配器
|
||||
participant API as API服务
|
||||
|
||||
User->>UnifiedLLM: 调用chat()或complete()
|
||||
UnifiedLLM->>UnifiedLLM: 根据配置选择适配器
|
||||
UnifiedLLM->>Adapter: 转发请求
|
||||
Adapter->>API: 发送HTTP请求/API调用
|
||||
Note over API: 处理请求并生成响应
|
||||
|
||||
alt 流式模式
|
||||
API-->>Adapter: 流式响应数据
|
||||
Adapter-->>UnifiedLLM: 流式ModelResponse
|
||||
UnifiedLLM-->>User: 迭代器返回分块响应
|
||||
else 非流式模式
|
||||
API-->>Adapter: 完整响应
|
||||
Adapter-->>UnifiedLLM: ModelResponse对象
|
||||
UnifiedLLM-->>User: 完整响应内容
|
||||
end
|
||||
```
|
||||
# 类关系图
|
||||
```mermaid
|
||||
classDiagram
|
||||
class ModelConfig {
|
||||
+provider: ModelProvider
|
||||
+model_name: str
|
||||
+api_key: Optional[str]
|
||||
+base_url: Optional[str]
|
||||
+temperature: float
|
||||
+max_tokens: int
|
||||
+to_dict() Dict
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
+role: str
|
||||
+content: str
|
||||
+name: Optional[str]
|
||||
+to_dict() Dict
|
||||
}
|
||||
|
||||
class ModelResponse {
|
||||
+content: str
|
||||
+model: str
|
||||
+usage: Optional[Dict]
|
||||
+finish_reason: Optional[str]
|
||||
+raw_response: Optional[Dict]
|
||||
}
|
||||
|
||||
class BaseLLMClient {
|
||||
<<abstract>>
|
||||
#config: ModelConfig
|
||||
#client: Any
|
||||
+__init__(config: ModelConfig)
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)*
|
||||
+completion(prompt, **kwargs)*
|
||||
+format_messages(messages) List[Dict]
|
||||
}
|
||||
|
||||
class UnifiedLLM {
|
||||
-config: ModelConfig
|
||||
-client: BaseLLMClient
|
||||
+__init__(config: ModelConfig)
|
||||
+_create_client() BaseLLMClient
|
||||
+update_config(config: ModelConfig)
|
||||
+chat(messages, **kwargs) ModelResponse
|
||||
+complete(prompt, **kwargs) ModelResponse
|
||||
+stream_chat(messages, **kwargs) Iterator
|
||||
+stream_complete(prompt, **kwargs) Iterator
|
||||
}
|
||||
|
||||
class OpenAIClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
class AnthropicClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
class OllamaClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
class GenericLLMClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
BaseLLMClient <|-- OpenAIClient
|
||||
BaseLLMClient <|-- AnthropicClient
|
||||
BaseLLMClient <|-- OllamaClient
|
||||
BaseLLMClient <|-- GenericLLMClient
|
||||
UnifiedLLM o-- BaseLLMClient
|
||||
UnifiedLLM --> ModelConfig
|
||||
BaseLLMClient --> ModelConfig
|
||||
BaseLLMClient --> ChatMessage
|
||||
BaseLLMClient --> ModelResponse
|
||||
```
|
||||
@@ -0,0 +1,26 @@
|
||||
本模块为tts模块,即文本转语音模块。
|
||||
|
||||
本模块负责将来自AI的回复转为语音。
|
||||
|
||||
说明:
|
||||
在本module当中,每个子模块的用途分别是:
|
||||
- tts_core 对不同的tts的实现,提供相对统一的接口
|
||||
- gpt_sovits
|
||||
实现了gpt_sovits的tts接口封装
|
||||
|
||||
|
||||
async_audio_player.py
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant TTS as GPT-SoVITS API
|
||||
participant WS as WebSocket服务
|
||||
participant Buffer as 音频缓冲区
|
||||
participant Player as 音频播放器
|
||||
|
||||
TTS->>WS: 流式音频块(chunks)
|
||||
WS->>Buffer: 写入队列(Queue)
|
||||
Buffer->>Player: 消费PCM数据
|
||||
Player->>声卡: 实时播放
|
||||
|
||||
Note over TTS,Player: 三重缓冲 + 动态采样率检测
|
||||
```
|
||||
@@ -0,0 +1,164 @@
|
||||
# tts_core/async_audio_player.py
|
||||
import asyncio
|
||||
import io
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import wave
|
||||
|
||||
class AsyncAudioPlayer:
|
||||
"""
|
||||
异步流式音频播放器
|
||||
- 自动检测WAV头并解析采样率
|
||||
- 使用环形缓冲区确保播放流畅
|
||||
- 支持动态音频格式切换
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int = 10):
|
||||
"""
|
||||
Args:
|
||||
buffer_size: 音频块缓冲数量(越大越稳定,但延迟越高)
|
||||
"""
|
||||
self.audio_queue = asyncio.Queue(maxsize=buffer_size)
|
||||
self.sample_rate = 32000 # 默认采样率
|
||||
self.channels = 1
|
||||
self.dtype = np.float32
|
||||
self.stream: Optional[sd.OutputStream] = None
|
||||
self.is_playing = False
|
||||
self._first_chunk_processed = False
|
||||
logger.info(f"🎵 音频播放器初始化,缓冲区大小: {buffer_size}")
|
||||
|
||||
async def add_chunk(self, audio_data: bytes):
|
||||
"""
|
||||
添加音频块到播放队列
|
||||
自动处理第一个chunk(包含WAV头)
|
||||
"""
|
||||
try:
|
||||
# 第一个chunk需要解析WAV头
|
||||
if not self._first_chunk_processed:
|
||||
# 写入BytesIO以便wave模块读取
|
||||
wav_buffer = io.BytesIO(audio_data)
|
||||
try:
|
||||
with wave.open(wav_buffer, 'rb') as wav_file:
|
||||
# 解析WAV头信息
|
||||
self.sample_rate = wav_file.getframerate()
|
||||
self.channels = wav_file.getnchannels()
|
||||
self.sampwidth = wav_file.getsampwidth()
|
||||
|
||||
# 读取PCM数据(去掉头部)
|
||||
pcm_data = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
logger.info(f"📊 解析WAV头: {self.sample_rate}Hz, {self.channels}ch, {self.sampwidth * 8}bit")
|
||||
|
||||
# 转换为numpy数组
|
||||
if self.sampwidth == 2:
|
||||
audio_array = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif self.sampwidth == 4:
|
||||
audio_array = np.frombuffer(pcm_data, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
else:
|
||||
raise ValueError(f"不支持的采样宽度: {self.sampwidth}")
|
||||
|
||||
# 转单声道(如果多声道)
|
||||
if self.channels > 1:
|
||||
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
|
||||
|
||||
await self.audio_queue.put(audio_array)
|
||||
self._first_chunk_processed = True
|
||||
|
||||
except wave.Error:
|
||||
# 可能是不完整的WAV头,尝试直接播放
|
||||
logger.warning("⚠️ WAV头解析失败,尝试直接播放")
|
||||
await self._play_raw(audio_data)
|
||||
return
|
||||
else:
|
||||
# 后续chunk直接播放(RAW PCM)
|
||||
await self._play_raw(audio_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 音频块处理失败: {e}")
|
||||
|
||||
async def _play_raw(self, audio_data: bytes):
|
||||
"""播放RAW PCM数据"""
|
||||
try:
|
||||
# 假设是16位PCM(最常见)
|
||||
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
# 如果是多声道数据(罕见)
|
||||
if len(audio_array) % self.channels == 0 and self.channels > 1:
|
||||
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
|
||||
|
||||
await self.audio_queue.put(audio_array)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ RAW音频处理失败: {e}")
|
||||
|
||||
async def play_worker(self):
|
||||
"""后台播放任务"""
|
||||
logger.info("🎧 音频播放任务启动")
|
||||
|
||||
while self.is_playing or not self.audio_queue.empty():
|
||||
try:
|
||||
# 从队列获取音频块(最多等待0.5秒)
|
||||
audio_chunk = await asyncio.wait_for(self.audio_queue.get(), timeout=0.5)
|
||||
|
||||
# 延迟初始化音频流(直到获得第一个数据块)
|
||||
if self.stream is None:
|
||||
logger.info(f"🔊 打开音频输出流: {self.sample_rate}Hz")
|
||||
self.stream = sd.OutputStream(
|
||||
samplerate=self.sample_rate,
|
||||
channels=1,
|
||||
dtype=self.dtype,
|
||||
blocksize=1024, # 低延迟模式
|
||||
latency='low'
|
||||
)
|
||||
self.stream.start()
|
||||
|
||||
# 写入音频流播放
|
||||
self.stream.write(audio_chunk)
|
||||
|
||||
# 标记任务完成
|
||||
self.audio_queue.task_done()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 播放任务异常: {e}")
|
||||
break
|
||||
|
||||
logger.info("🛑 音频播放任务结束")
|
||||
|
||||
async def start(self):
|
||||
"""启动播放系统"""
|
||||
self.is_playing = True
|
||||
self._first_chunk_processed = False
|
||||
self.play_task = asyncio.create_task(self.play_worker())
|
||||
|
||||
async def stop(self):
|
||||
"""停止播放并清理资源"""
|
||||
self.is_playing = False
|
||||
|
||||
# 等待播放任务结束
|
||||
if hasattr(self, 'play_task'):
|
||||
await self.play_task
|
||||
|
||||
# 关闭音频流
|
||||
if self.stream is not None:
|
||||
self.stream.stop()
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
|
||||
# 清空队列
|
||||
while not self.audio_queue.empty():
|
||||
try:
|
||||
self.audio_queue.get_nowait()
|
||||
except:
|
||||
break
|
||||
|
||||
logger.info("✅ 音频播放已停止")
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.stop()
|
||||
@@ -0,0 +1,378 @@
|
||||
# gpt_sovits/gpt_sovits_client.py
|
||||
import asyncio
|
||||
import json
|
||||
from loguru import logger
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Optional, Union, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""API调用异常"""
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f"API Error {status_code}: {message}")
|
||||
|
||||
|
||||
class StreamingMode(Enum):
|
||||
"""流式模式枚举"""
|
||||
DISABLED = 0 # 非流式
|
||||
BEST_QUALITY = 1 # 最佳质量(慢)
|
||||
MEDIUM_QUALITY = 2 # 中等质量
|
||||
FASTEST = 3 # 最快响应(较低质量)
|
||||
|
||||
|
||||
class TTSConfig(BaseModel):
|
||||
"""TTS请求配置模型"""
|
||||
text: str = Field(..., description="待合成文本")
|
||||
text_lang: str = Field(..., description="文本语言: zh/en/ja/ko/cantonese")
|
||||
ref_audio_path: str = Field(..., description="参考音频路径")
|
||||
prompt_lang: str = Field(..., description="提示文本语言")
|
||||
|
||||
# 可选参数
|
||||
prompt_text: str = Field(default="", description="参考音频提示文本")
|
||||
aux_ref_audio_paths: list = Field(default_factory=list, description="辅助参考音频")
|
||||
top_k: int = Field(default=5, ge=1, le=100, description="Top-K采样")
|
||||
top_p: float = Field(default=1.0, ge=0.1, le=1.0, description="Top-P采样")
|
||||
temperature: float = Field(default=1.0, ge=0.1, le=1.0, description="采样温度")
|
||||
text_split_method: str = Field(default="cut5", description="文本分割方法") # 默认按照标点符号切分
|
||||
batch_size: int = Field(default=8, ge=1, le=200, description="批处理大小")
|
||||
speed_factor: float = Field(default=1.0, ge=0.6, le=1.65, description="语速倍率")
|
||||
|
||||
# 流式相关
|
||||
streaming_mode: Union[bool, int, StreamingMode] = Field(default=False, description="流式模式")
|
||||
media_type: str = Field(default="wav", description="输出格式: wav/raw/ogg/aac") # 输出格式
|
||||
|
||||
# 高级参数
|
||||
repetition_penalty: float = Field(default=1.35, ge=1.0, le=2.0) # 惩罚参数
|
||||
sample_steps: int = Field(default=32, ge=10, le=100) # 采样步数
|
||||
parallel_infer: bool = Field(default=True) # 并行推理
|
||||
|
||||
@validator('text_lang', 'prompt_lang')
|
||||
def validate_language(cls, v):
|
||||
"""验证语言代码"""
|
||||
valid_langs = {'zh', 'en', 'ja', 'ko', 'cantonese'}
|
||||
if v.lower() not in valid_langs:
|
||||
raise ValueError(f"Unsupported language: {v}. Must be one of {valid_langs}")
|
||||
return v.lower()
|
||||
|
||||
@validator('media_type')
|
||||
def validate_media_type(cls, v):
|
||||
"""验证媒体类型"""
|
||||
valid_types = {'wav', 'raw', 'ogg', 'aac'}
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"Unsupported media_type: {v}")
|
||||
return v
|
||||
|
||||
def build_request(self) -> Dict[str, Any]:
|
||||
"""构建API请求数据"""
|
||||
data = self.dict(exclude_none=True)
|
||||
# 处理流式模式
|
||||
if isinstance(self.streaming_mode, StreamingMode):
|
||||
data['streaming_mode'] = self.streaming_mode.value
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioResponse:
|
||||
"""音频响应包装类"""
|
||||
audio_data: bytes
|
||||
sample_rate: int = 32000
|
||||
|
||||
def save(self, path: Union[str, Path]) -> None:
|
||||
"""保存音频文件"""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, 'wb') as f:
|
||||
f.write(self.audio_data)
|
||||
logger.info(f"Audio saved to {path}, size: {len(self.audio_data)} bytes")
|
||||
|
||||
|
||||
class GPTSoVITSClient:
|
||||
"""
|
||||
GPT-SoVITS异步API客户端
|
||||
|
||||
完整支持所有TTS功能:
|
||||
- 文本合成(流式/非流式)
|
||||
- 模型切换(GPT/SoVITS)
|
||||
- 参考音频设置
|
||||
- 服务器控制
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 9880, debug: bool = False):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
host: API服务器地址
|
||||
port: API端口
|
||||
debug: 是否开启调试模式
|
||||
"""
|
||||
self.base_url = f"http://{host}:{port}"
|
||||
self.client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=httpx.Timeout(30.0, connect=5.0)
|
||||
)
|
||||
self.debug_mode = debug
|
||||
logger.info(f"GPT-SoVITS Client initialized: {self.base_url}")
|
||||
|
||||
async def __aenter__(self) -> "GPTSoVITSClient":
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.close()
|
||||
|
||||
async def close(self):
|
||||
"""关闭HTTP连接"""
|
||||
await self.client.aclose()
|
||||
logger.info("Client connection closed")
|
||||
|
||||
def _log_debug(self, message: str, **kwargs):
|
||||
"""调试日志"""
|
||||
if self.debug_mode:
|
||||
logger.debug(f"{message} | {kwargs}")
|
||||
|
||||
async def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
|
||||
"""统一响应处理"""
|
||||
if response.status_code == 200:
|
||||
content_type = response.headers.get('content-type', '')
|
||||
if 'application/json' in content_type:
|
||||
return response.json()
|
||||
return {"status": "success", "content": response.content}
|
||||
else:
|
||||
try:
|
||||
error_data = response.json()
|
||||
raise APIError(response.status_code, error_data.get('message', 'Unknown error'))
|
||||
except json.JSONDecodeError:
|
||||
raise APIError(response.status_code, response.text)
|
||||
|
||||
# 核心TTS接口
|
||||
async def tts(
|
||||
self,
|
||||
text: str,
|
||||
ref_audio_path: str,
|
||||
text_lang: str = "zh",
|
||||
prompt_lang: str = "zh",
|
||||
streaming_mode: StreamingMode = StreamingMode.DISABLED, # 默认禁用流式
|
||||
media_type: str = "wav",
|
||||
**kwargs
|
||||
) -> Union[AudioResponse, AsyncGenerator[AudioResponse, None]]:
|
||||
"""
|
||||
文本转语音(支持流式)
|
||||
|
||||
Args:
|
||||
text: 待合成文本
|
||||
ref_audio_path: 参考音频路径(服务器本地路径或URL)
|
||||
text_lang: 文本语言
|
||||
prompt_lang: 提示语言
|
||||
streaming_mode: 流式模式
|
||||
media_type: 输出格式
|
||||
**kwargs: 其他TTS参数
|
||||
|
||||
Returns:
|
||||
非流式: AudioResponse对象
|
||||
流式: AsyncGenerator[AudioResponse, None]异步生成器
|
||||
|
||||
Example:
|
||||
# 非流式
|
||||
audio = await client.tts("你好", "ref.wav")
|
||||
|
||||
# 流式
|
||||
async for chunk in client.tts("你好", "ref.wav", streaming_mode=StreamingMode.FASTEST):
|
||||
process(chunk.audio_data)
|
||||
"""
|
||||
config = TTSConfig(
|
||||
text=text,
|
||||
ref_audio_path=ref_audio_path,
|
||||
text_lang=text_lang,
|
||||
prompt_lang=prompt_lang,
|
||||
streaming_mode=streaming_mode,
|
||||
media_type=media_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._log_debug("TTS Request", config=config.dict())
|
||||
|
||||
if streaming_mode == StreamingMode.DISABLED:
|
||||
# 非流式模式
|
||||
response = await self.client.post("/tts", json=config.build_request())
|
||||
if response.status_code != 200:
|
||||
raise APIError(response.status_code, await response.text())
|
||||
|
||||
return AudioResponse(
|
||||
audio_data=response.content,
|
||||
sample_rate=32000 # 默认采样率
|
||||
)
|
||||
else:
|
||||
# 流式模式
|
||||
config.parallel_infer = False # 强制关闭并行推理,避免与流式冲突
|
||||
config.batch_size = 1 # 流式下batch_size必须为1
|
||||
async def stream_generator():
|
||||
async with self.client.stream(
|
||||
"POST", "/tts",
|
||||
json=config.build_request(),
|
||||
timeout=httpx.Timeout(60.0) # 流式需要更长超时
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise APIError(response.status_code, await response.aread())
|
||||
|
||||
async for chunk in response.aiter_bytes():
|
||||
if chunk:
|
||||
yield AudioResponse(audio_data=chunk)
|
||||
|
||||
return stream_generator()
|
||||
|
||||
# 模型管理接口
|
||||
async def set_gpt_weights(self, weights_path: str) -> bool:
|
||||
"""
|
||||
切换GPT模型权重
|
||||
|
||||
Args:
|
||||
weights_path: 权重文件路径(服务器本地路径)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
|
||||
Example:
|
||||
await client.set_gpt_weights("models/s1bert.ckpt")
|
||||
"""
|
||||
if not weights_path:
|
||||
raise ValueError("weights_path cannot be empty")
|
||||
|
||||
params = {"weights_path": weights_path}
|
||||
response = await self.client.get("/set_gpt_weights", params=params)
|
||||
result = await self._handle_response(response)
|
||||
|
||||
logger.info(f"GPT weights switched to: {weights_path}")
|
||||
return True
|
||||
|
||||
async def set_sovits_weights(self, weights_path: str) -> bool:
|
||||
"""
|
||||
切换SoVITS模型权重
|
||||
|
||||
Args:
|
||||
weights_path: 权重文件路径(服务器本地路径)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not weights_path:
|
||||
raise ValueError("weights_path cannot be empty")
|
||||
|
||||
params = {"weights_path": weights_path}
|
||||
response = await self.client.get("/set_sovits_weights", params=params)
|
||||
await self._handle_response(response)
|
||||
|
||||
logger.info(f"SoVITS weights switched to: {weights_path}")
|
||||
return True
|
||||
|
||||
# 参考音频管理
|
||||
async def set_refer_audio(
|
||||
self,
|
||||
audio_source: Union[str, Path, bytes],
|
||||
audio_name: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
设置参考音频(支持多种输入方式)
|
||||
|
||||
Args:
|
||||
audio_source: 音频文件路径(str/Path)或音频数据(bytes)
|
||||
audio_name: 音频文件名(仅bytes输入时需要)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
|
||||
Example:
|
||||
# 方式1: 服务器本地文件
|
||||
await client.set_refer_audio("/path/to/audio.wav")
|
||||
|
||||
# 方式2: 上传音频数据
|
||||
with open("audio.wav", "rb") as f:
|
||||
await client.set_refer_audio(f.read(), "audio.wav")
|
||||
"""
|
||||
if isinstance(audio_source, (str, Path)):
|
||||
# GET方式:服务器本地路径
|
||||
params = {"refer_audio_path": str(audio_source)}
|
||||
response = await self.client.get("/set_refer_audio", params=params)
|
||||
await self._handle_response(response)
|
||||
logger.info(f"Reference audio set: {audio_source}")
|
||||
else:
|
||||
# POST方式:上传音频数据
|
||||
if not audio_name:
|
||||
raise ValueError("audio_name is required when uploading bytes")
|
||||
|
||||
files = {"audio_file": (audio_name, audio_source, "audio/wav")}
|
||||
response = await self.client.post("/set_refer_audio", files=files)
|
||||
await self._handle_response(response)
|
||||
logger.info(f"Reference audio uploaded: {audio_name}")
|
||||
|
||||
return True
|
||||
|
||||
# 服务器控制
|
||||
async def control_command(self, command: str) -> bool:
|
||||
"""
|
||||
发送控制命令
|
||||
|
||||
Args:
|
||||
command: 命令类型 - "restart" 或 "exit"
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
|
||||
Warning:
|
||||
"exit"命令会终止API服务器进程!
|
||||
"""
|
||||
if command not in ["restart", "exit"]:
|
||||
raise ValueError("Command must be 'restart' or 'exit'")
|
||||
|
||||
response = await self.client.get("/control", params={"command": command})
|
||||
await self._handle_response(response)
|
||||
|
||||
logger.warning(f"Control command executed: {command}")
|
||||
return True
|
||||
|
||||
# 高级快捷方法
|
||||
async def get_server_info(self) -> Dict[str, Any]:
|
||||
"""获取服务器状态信息"""
|
||||
# 通过调用根路径或自定义health接口
|
||||
try:
|
||||
response = await self.client.get("/")
|
||||
return {"status": "online", "detail": response.text}
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": str(e)}
|
||||
|
||||
async def batch_tts(
|
||||
self,
|
||||
texts: list[str],
|
||||
ref_audio_path: str,
|
||||
**kwargs
|
||||
) -> list[AudioResponse]:
|
||||
"""
|
||||
批量TTS合成
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
ref_audio_path: 参考音频
|
||||
**kwargs: 其他TTS参数
|
||||
|
||||
Returns:
|
||||
list[AudioResponse]: 音频响应列表
|
||||
"""
|
||||
tasks = [
|
||||
self.tts(text, ref_audio_path, **kwargs)
|
||||
for text in texts
|
||||
]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
# 异步上下文管理器辅助函数
|
||||
async def create_client(*args, **kwargs) -> GPTSoVITSClient:
|
||||
"""快速创建客户端实例"""
|
||||
return GPTSoVITSClient(*args, **kwargs)
|
||||
@@ -0,0 +1,27 @@
|
||||
# dto/dto_base.py
|
||||
from abc import ABC
|
||||
from typing import Callable, Coroutine, Any, Dict
|
||||
from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer
|
||||
|
||||
class MessageDTO(ABC):
|
||||
"""DTO基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ws_server : WebSocketServer, # WebSocketServer单例
|
||||
):
|
||||
# 保存服务器实例(用于发送)
|
||||
self.ws_server = ws_server
|
||||
|
||||
# 便捷属性,DTO层直接调用
|
||||
@property
|
||||
def send_binary(self):
|
||||
return self.ws_server.send_binary
|
||||
|
||||
@property
|
||||
def send_text(self):
|
||||
return self.ws_server.send_text
|
||||
|
||||
@property
|
||||
def send_json(self):
|
||||
return self.ws_server.send_json
|
||||
@@ -0,0 +1,95 @@
|
||||
from pydantic import Field, BaseModel
|
||||
import base64
|
||||
from datetime import datetime, timezone
|
||||
class AudioDataTransferObject(BaseModel):
|
||||
"""
|
||||
音频数据传输对象
|
||||
该对象被用于服务端与客户端的音频数据交互
|
||||
同时支持流式与非流式的音频数据
|
||||
同时收发对等(通过Owner标识)
|
||||
"""
|
||||
Owner: str = Field(default="server", description="音频数据的拥有者(server or client)")
|
||||
isStream: bool = Field(default=False, description="音频数据是否为流式数据")
|
||||
isStart: bool = Field(default=False, description="音频数据是否开始(流式时有效)")
|
||||
isEnd: bool = Field(default=False, description="音频数据是否结束(流式时有效)")
|
||||
sequence: int = Field(default=0, description="音频数据块序列号(流式时有效)")
|
||||
data: bytes = Field(default=b"", description="音频数据,流式时为分块数据,base64编码")
|
||||
sampleRate: int = Field(default=32000, description="音频采样率")
|
||||
channelCount: int = Field(default=1, description="音频通道数")
|
||||
bitDepth: int = Field(default=16, description="音频采样位数")
|
||||
duration: float = Field(default=0.0, description="音频时长")
|
||||
text: str = Field(default="", description="音频对应的文本")
|
||||
|
||||
def set_dto_data(self, **kwargs) -> "AudioDataTransferObject":
|
||||
"""链式更新数据(Pydantic 风格)"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""
|
||||
将DTO对象转换为可序列化字典
|
||||
返回的json数据的格式:
|
||||
{
|
||||
"type": "audio_data",
|
||||
"timestamp": 1672531200.0,
|
||||
"data": {
|
||||
"Owner": "server",
|
||||
......
|
||||
}
|
||||
}
|
||||
"""
|
||||
# model_dump() 是 Pydantic v2 的序列化方法
|
||||
payload = self.model_dump() # 获取所有模型字段
|
||||
payload["data"] = base64.b64encode(payload["data"]).decode() # base64编码
|
||||
# 构造嵌套结构
|
||||
return {
|
||||
"type": "audio_data",
|
||||
"timestamp": datetime.now(timezone.utc).timestamp(),
|
||||
"data": payload # 音频字段嵌套
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data: dict) -> "AudioDataTransferObject":
|
||||
"""
|
||||
从JSON数据创建DTO对象
|
||||
传入的json数据格式:
|
||||
{
|
||||
"Owner": "server",
|
||||
......
|
||||
}
|
||||
"""
|
||||
payload = json_data
|
||||
# 解码 base64 (内层 data 字段在传输时是 base64 字符串) -> bytes
|
||||
if "data" in payload and isinstance(payload["data"], str):
|
||||
payload["data"] = base64.b64decode(payload["data"])
|
||||
# 构造对象 Pydantic 自动忽略 type/timestamp
|
||||
return cls.model_validate(payload)
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
# 模拟音频数据
|
||||
import os
|
||||
|
||||
test_audio = os.urandom(1024) # 随机生成1KB音频数据
|
||||
|
||||
# 创建DTO
|
||||
audio = AudioDataTransferObject(
|
||||
data=test_audio,
|
||||
sequence=1,
|
||||
isStream=True,
|
||||
sampleRate=44100,
|
||||
duration=0.5
|
||||
)
|
||||
|
||||
# 序列化 → JSON
|
||||
json_dict = audio.to_json()
|
||||
print(f"序列化后 data 长度: {len(json_dict['data'])}") # ~1368 字符
|
||||
print(f"data 前30字符: {json_dict['data'][:30]}...")
|
||||
|
||||
# 反序列化 → DTO
|
||||
restored = AudioDataTransferObject.from_json(json_dict)
|
||||
print(f"反序列化后 data 长度: {len(restored.data)}") # 1024 bytes
|
||||
print(f"数据一致: {restored.data == test_audio}") # True
|
||||
@@ -0,0 +1,73 @@
|
||||
from pydantic import Field, BaseModel
|
||||
from datetime import datetime, timezone
|
||||
class AutoAgentDataTransferObject(BaseModel):
|
||||
"""
|
||||
自动化agent数据传输对象
|
||||
该对象被用于服务端向客户端发送控制信息
|
||||
"""
|
||||
Action: str = Field(default="", description="自动化动作名称")
|
||||
x1: int = Field(default=-1, description="鼠标起始位置x1")
|
||||
y1: int = Field(default=-1, description="鼠标起始位置y1")
|
||||
x2: int = Field(default=-1, description="鼠标结束位置x2")
|
||||
y2: int = Field(default=-1, description="鼠标结束位置y2")
|
||||
key: str = Field(default="", description="快捷键")
|
||||
content: str = Field(default="", description="输入文本内容")
|
||||
direction: str = Field(default="", description="滚动方向")
|
||||
|
||||
def set_dto_data(self, **kwargs) -> "AutoAgentDataTransferObject":
|
||||
"""链式更新数据(Pydantic 风格)"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""
|
||||
将DTO对象转换为可序列化字典
|
||||
返回的json数据的格式:
|
||||
{
|
||||
"type": "auto_agent",
|
||||
"timestamp": 1672531200.0,
|
||||
"data": {
|
||||
"Action": "自动化agent返回的相应的自动化动作名称",
|
||||
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
|
||||
"y1": "某个操作的y1",
|
||||
"x2": "某个操作的x2",
|
||||
"y2": "某个操作的y2",
|
||||
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
|
||||
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
|
||||
"direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
|
||||
}
|
||||
}
|
||||
"""
|
||||
# model_dump() 是 Pydantic v2 的序列化方法
|
||||
payload = self.model_dump() # 获取所有模型字段
|
||||
# 构造嵌套结构
|
||||
return {
|
||||
"type": "auto_agent",
|
||||
"timestamp": datetime.now(timezone.utc).timestamp(),
|
||||
"data": payload # 字段嵌套
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data: dict) -> "AutoAgentDataTransferObject":
|
||||
"""
|
||||
从JSON数据创建DTO对象
|
||||
传入的json数据格式:
|
||||
{
|
||||
"type": "auto_agent",
|
||||
"Action": "自动化agent返回的相应的自动化动作名称",
|
||||
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
|
||||
"y1": "某个操作的y1",
|
||||
"x2": "某个操作的x2",
|
||||
"y2": "某个操作的y2",
|
||||
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
|
||||
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
|
||||
"direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
|
||||
}
|
||||
"""
|
||||
payload = json_data
|
||||
payload.pop("type", None) # 移除多余字段
|
||||
# 构造对象 Pydantic 自动忽略 type/timestamp
|
||||
return cls.model_validate(payload)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
class BaseDataTransferObject:
|
||||
"""
|
||||
DTO基类
|
||||
子类按需重写成员函数即可
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
def to_json(self):
|
||||
"""
|
||||
将DTO对象转换为JSON
|
||||
"""
|
||||
pass
|
||||
def from_json(self, json_data):
|
||||
"""
|
||||
从JSON数据中创建DTO对象
|
||||
"""
|
||||
pass
|
||||
def to_binary(self):
|
||||
"""
|
||||
将DTO对象转换为二进制
|
||||
"""
|
||||
pass
|
||||
def from_binary(self, binary_data):
|
||||
"""
|
||||
从二进制数据中创建DTO对象
|
||||
"""
|
||||
pass
|
||||
def to_text(self):
|
||||
"""
|
||||
将DTO对象转换为文本
|
||||
"""
|
||||
pass
|
||||
def from_text(self, text_data):
|
||||
"""
|
||||
从文本数据中创建DTO对象
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,72 @@
|
||||
from pydantic import Field, BaseModel
|
||||
from datetime import datetime, timezone
|
||||
class ScreenShotDataTransferObject(BaseModel):
|
||||
"""
|
||||
服务端向客户端请求实时截图的数据传输对象
|
||||
服务端与客户端收发对等(通过Owner标识)
|
||||
客户端收到这个type的包,就会自动对当前设备的画面进行截图
|
||||
"""
|
||||
Owner: str = Field(default="server", description="数据的拥有者(server or client)")
|
||||
isSuccess: bool = Field(default=False, description="是否截图成功")
|
||||
RealTimeScreenShot: str = Field(default="", description="客户端设备的实时截图数据(base64)")
|
||||
Width: int = Field(default=1920, description="截图的宽度")
|
||||
Height: int = Field(default=1080, description="截图的高度")
|
||||
DescribeInfo: str = Field(default="", description="设备的描述信息(告知模型以做出更加准确的判断)")
|
||||
LLMResponse: str = Field(default="", description="LLM的响应结果(由服务端发送时携带)")
|
||||
|
||||
|
||||
def set_dto_data(self, **kwargs) -> "ScreenShotDataTransferObject":
|
||||
"""链式更新数据(Pydantic 风格)"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""
|
||||
将DTO对象转换为可序列化字典
|
||||
返回的json数据的格式:
|
||||
{
|
||||
"type": "screenshot_data",
|
||||
"timestamp": 1672531200.0,
|
||||
"data": {
|
||||
"Owner": "数据的拥有者(server or client)",
|
||||
"isSuccess": "是否截图成功(true or false)"
|
||||
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
|
||||
"Width": "截图的宽度",
|
||||
"Height": "截图的高度",
|
||||
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)",
|
||||
"LLMResponse": "LLM的响应结果(由服务端发送时携带)"
|
||||
}
|
||||
}
|
||||
"""
|
||||
# model_dump() 是 Pydantic v2 的序列化方法
|
||||
payload = self.model_dump() # 获取所有模型字段
|
||||
# 构造嵌套结构
|
||||
return {
|
||||
"type": "screenshot_data",
|
||||
"timestamp": datetime.now(timezone.utc).timestamp(),
|
||||
"data": payload # 字段嵌套
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data: dict) -> "ScreenShotDataTransferObject":
|
||||
"""
|
||||
从JSON数据创建DTO对象
|
||||
传入的json数据格式:
|
||||
{
|
||||
"Owner": "数据的拥有者(server or client)",
|
||||
"isSuccess": "是否截图成功(true or false)",
|
||||
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
|
||||
"Width": "截图的宽度 非必要字段",
|
||||
"Height": "截图的高度 非必要字段",
|
||||
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断) 非必要字段",
|
||||
"LLMResponse": "LLM的响应结果(由服务端发送时携带) 必要字段"
|
||||
}
|
||||
"""
|
||||
payload = json_data
|
||||
payload.pop("type", None) # 移除多余字段
|
||||
payload.pop("timestamp", None) # 移除多余字段
|
||||
# 构造对象 Pydantic 自动忽略 type/timestamp
|
||||
return cls.model_validate(payload)
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# dto/second_dtos.py
|
||||
import asyncio
|
||||
from typing import Callable, Optional, Dict, Any, List, Coroutine
|
||||
from src.modules.websocket_base_module.dto.dto_base import MessageDTO
|
||||
from loguru import logger
|
||||
|
||||
"""
|
||||
二级分发器,因为没有信号与槽机制,因此使用观察者模式替代
|
||||
"""
|
||||
def singleton(cls): # 单例
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
async def get_instance(*args, **kwargs):
|
||||
nonlocal _instance
|
||||
if _instance is None:
|
||||
async with _lock:
|
||||
if _instance is None:
|
||||
_instance = cls(*args, **kwargs)
|
||||
return _instance
|
||||
|
||||
cls.get_instance = get_instance
|
||||
return cls
|
||||
# 类型别名
|
||||
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
|
||||
@singleton
|
||||
class JsonDTO(MessageDTO):
|
||||
"""针对json消息的二级分发"""
|
||||
"""
|
||||
明确业务json格式:
|
||||
{
|
||||
"type" : "xxx",
|
||||
"timestamp" : "95153...",
|
||||
"data" : "{根据业务的不同,有不同的内容}"
|
||||
}
|
||||
"""
|
||||
# 因为不同的数据块的json以type字段进行包装,根据type进行正确的数据分发
|
||||
def __init__(self, ws_server):
|
||||
super().__init__(ws_server)
|
||||
self.receivers : Dict[str, List[ReceiveCallback]] = {
|
||||
'audio_data' : [], # 音频数据
|
||||
'screenshot_data' : [] # 截图数据
|
||||
}
|
||||
# 注册json处理callback function
|
||||
ws_server.register_receiver('json', self._handle_json)
|
||||
logger.info("[JsonDTO] JSON分发器已注册")
|
||||
|
||||
def register_receiver(self, types : str, callback : ReceiveCallback):
|
||||
"""注册二次分发业务接收函数,供业务DTO调用"""
|
||||
if types in self.receivers:
|
||||
self.receivers[types].append(callback)
|
||||
logger.debug(f"[JsonDTO] 已注册 {types} 接收器,当前共 {len(self.receivers[types])} 个")
|
||||
else:
|
||||
raise ValueError(f"[JsonDTO] 不支持的分发类型: {types}")
|
||||
|
||||
def unregister_receiver(self, types : str, callback : ReceiveCallback):
|
||||
"""注销二次分发业务接收函数"""
|
||||
if callback in self.receivers[types]:
|
||||
self.receivers[types].remove(callback)
|
||||
logger.debug(f"[JsonDTO] 已注销 {types} 接收器")
|
||||
|
||||
async def _handle_json(self, data: dict):
|
||||
"""JSON消息处理"""
|
||||
logger.info(f"[JsonDTO] 收到消息")
|
||||
logger.debug(f'[JsonDTO] 当前消息时间戳: {data["timestamp"]}')
|
||||
# 根据类型进行自动分发
|
||||
await self._dispatch(data.get("type"), data["data"])
|
||||
|
||||
async def _dispatch(self, types : str, data : dict):
|
||||
"""二次分发json数据到相应的接收函数当中"""
|
||||
callbacks = self.receivers[types] # 获取相关types的所有观察者
|
||||
if not callbacks:
|
||||
logger.info(f"[JsonDTO] 无 {types} 接收器,消息被忽略")
|
||||
return
|
||||
# 并发执行所有回调
|
||||
tasks = [callback(data) for callback in callbacks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
async def get_json_dto_instance(ws_server) -> JsonDTO:
|
||||
return JsonDTO(ws_server)
|
||||
|
||||
|
||||
class EchoDTO(MessageDTO):
|
||||
"""回声DTO:只处理文本消息 测试用"""
|
||||
|
||||
def __init__(self, ws_server):
|
||||
super().__init__(ws_server)
|
||||
# 注册文本接收函数
|
||||
ws_server.register_receiver('text', self._handle_text)
|
||||
logger.info("[EchoDTO] 文本接收器已注册")
|
||||
|
||||
async def _handle_text(self, message: str):
|
||||
"""文本消息处理"""
|
||||
logger.info(f"[EchoDTO] 收到文本: {message}")
|
||||
|
||||
# 业务逻辑
|
||||
await self.send_text(f"Echo: {message}")
|
||||
@@ -0,0 +1,281 @@
|
||||
import asyncio
|
||||
from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
|
||||
from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
|
||||
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
|
||||
from loguru import logger
|
||||
from typing import Callable, List, Optional, Coroutine
|
||||
|
||||
class AudioDataDTO:
|
||||
"""音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)"""
|
||||
def __init__(self, json_dto : JsonDTO):
|
||||
json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数
|
||||
logger.info("[AudioDataDTO] 音频接收业务已注册")
|
||||
self.json_dto = json_dto
|
||||
self.audio_data = AudioDataTransferObject() # 音频数据对象
|
||||
# 业务回调列表,延续观察者模式
|
||||
self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
|
||||
# 最新音频缓存 支持同步查询
|
||||
self._latest_audio: Optional[AudioDataTransferObject] = None
|
||||
# 流式缓冲区 用于大段音频流
|
||||
self._stream_buffer: List[AudioDataTransferObject] = []
|
||||
logger.info("[AudioDataDTO] 业务接口已初始化")
|
||||
|
||||
async def _handle_audio_data(self, data: dict):
|
||||
"""处理音频数据"""
|
||||
"""
|
||||
音频数据json格式:
|
||||
{
|
||||
'Owner': 'server',
|
||||
'isStream': False,
|
||||
'isStart': False,
|
||||
'isEnd': False,
|
||||
'sequence': 0,
|
||||
'data': '',
|
||||
'sampleRate': 16000,
|
||||
'channelCount': 1,
|
||||
'bitDepth': 16,
|
||||
'duration': 0.0,
|
||||
'text': ''
|
||||
}
|
||||
"""
|
||||
logger.debug(f"[AudioDataDTO] 收到音频数据")
|
||||
# 将dict反序列化到DTO对象
|
||||
self.audio_data = AudioDataTransferObject.from_json(data)
|
||||
# 缓存最新数据
|
||||
self._latest_audio = self.audio_data
|
||||
# 如果是流式数据,加入缓冲区
|
||||
if self.audio_data.isStream:
|
||||
self._stream_buffer.append(self.audio_data)
|
||||
if self.audio_data.isEnd:
|
||||
logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)} 块")
|
||||
# 通知所有注册的回调
|
||||
await self._notify_callbacks()
|
||||
|
||||
# 业务发送接口
|
||||
async def send_audio_data(self, data: AudioDataTransferObject) -> None:
|
||||
"""
|
||||
发送音频数据
|
||||
|
||||
Args:
|
||||
data: 音频数据DTO
|
||||
"""
|
||||
await self.send_audio(
|
||||
Owner=data.Owner,
|
||||
is_stream=data.isStream,
|
||||
is_start=data.isStart,
|
||||
is_end=data.isEnd,
|
||||
sequence=data.sequence,
|
||||
data=data.data,
|
||||
sampleRate=data.sampleRate,
|
||||
channelCount=data.channelCount,
|
||||
bitDepth=data.bitDepth,
|
||||
duration=data.duration,
|
||||
text=data.text
|
||||
)
|
||||
|
||||
async def send_audio(
|
||||
self,
|
||||
data: bytes,
|
||||
is_stream: bool = False,
|
||||
is_start: bool = False,
|
||||
is_end: bool = False,
|
||||
sequence: int = 0,
|
||||
**audio_meta
|
||||
) -> None:
|
||||
"""
|
||||
业务层发送音频的便捷接口
|
||||
|
||||
Args:
|
||||
data: 原始音频字节
|
||||
is_stream: 是否为流式数据
|
||||
is_start: 流式数据开始标记
|
||||
is_end: 流式数据结束标记
|
||||
sequence: 数据块序号
|
||||
**audio_meta: 其他音频参数(sampleRate, channelCount等)
|
||||
"""
|
||||
# 填充音频数据到DTO
|
||||
self.audio_data.set_dto_data(
|
||||
Owner="server" or audio_meta.get('Owner', "server"),
|
||||
isStream=is_stream,
|
||||
isStart=is_start,
|
||||
isEnd=is_end,
|
||||
sequence=sequence,
|
||||
data=data,
|
||||
sampleRate=audio_meta.get('sampleRate', 16000),
|
||||
channelCount=audio_meta.get('channelCount', 1),
|
||||
bitDepth=audio_meta.get('bitDepth', 16),
|
||||
duration=audio_meta.get('duration', 0.0),
|
||||
text=audio_meta.get('text', "")
|
||||
)
|
||||
# 序列化为JSON并发送 自动处理base64和type字段
|
||||
json_message = self.audio_data.to_json()
|
||||
await self.json_dto.send_json(json_message)
|
||||
logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")
|
||||
|
||||
# 业务接收接口
|
||||
def register_audio_callback(
|
||||
self,
|
||||
callback: Callable[[AudioDataTransferObject], Coroutine]
|
||||
) -> None:
|
||||
"""
|
||||
业务注册接收回调
|
||||
|
||||
使用示例:
|
||||
async def my_audio_handler(audio_dto: AudioDataTransferObject):
|
||||
print(f"收到音频: {len(audio_dto.data)} bytes")
|
||||
|
||||
audio_dto.register_audio_callback(my_audio_handler)
|
||||
"""
|
||||
self._audio_callbacks.append(callback)
|
||||
logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)} 个")
|
||||
|
||||
def unregister_audio_callback(self, callback) -> None:
|
||||
"""注销业务回调"""
|
||||
if callback in self._audio_callbacks:
|
||||
self._audio_callbacks.remove(callback)
|
||||
logger.debug("业务音频回调已注销")
|
||||
|
||||
def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
|
||||
"""
|
||||
同步获取最新音频数据(轮询模式)
|
||||
|
||||
Returns:
|
||||
最新接收到的音频DTO,如果没有则为 None
|
||||
"""
|
||||
return self._latest_audio
|
||||
|
||||
def clear_stream_buffer(self) -> None:
|
||||
"""清空流式缓冲区"""
|
||||
self._stream_buffer.clear()
|
||||
logger.debug("流式音频缓冲区已清空")
|
||||
|
||||
# 内部通知机制
|
||||
async def _notify_callbacks(self) -> None:
|
||||
"""通知所有业务回调"""
|
||||
if not self._audio_callbacks:
|
||||
logger.warning("无业务回调,音频数据未处理")
|
||||
return
|
||||
tasks = [callback(self.audio_data) for callback in self._audio_callbacks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调")
|
||||
# 异步迭代器 流式时使用
|
||||
def __aiter__(self):
|
||||
"""支持 async for 循环接收流式音频"""
|
||||
return self
|
||||
async def __anext__(self) -> AudioDataTransferObject:
|
||||
"""异步迭代器协议"""
|
||||
pass
|
||||
|
||||
class ScreenShotDataDTO:
|
||||
"""截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)"""
|
||||
def __init__(self, json_dto : JsonDTO):
|
||||
json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数
|
||||
logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
|
||||
self.json_dto = json_dto
|
||||
self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象
|
||||
# 业务回调列表,延续观察者模式
|
||||
self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
|
||||
# 最新截屏数据缓存 支持同步查询
|
||||
self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
|
||||
logger.info("[ScreenShotDataDTO] 业务接口已初始化")
|
||||
|
||||
async def _handle_screenshot_data(self, data: dict):
|
||||
"""处理截屏数据"""
|
||||
"""
|
||||
截屏数据json格式:
|
||||
{
|
||||
"Owner": "数据的拥有者(server or client)",
|
||||
"isSuccess": "是否截图成功(true or false)"
|
||||
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
|
||||
"Width": "截图的宽度",
|
||||
"Height": "截图的高度",
|
||||
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)"
|
||||
}
|
||||
"""
|
||||
logger.debug(f"[ScreenShotDataDTO] 收到截屏数据")
|
||||
# 将dict反序列化到DTO对象
|
||||
self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
|
||||
# 缓存最新数据
|
||||
self._latest_screenshot = self.screenshot_data
|
||||
# 通知所有注册的回调
|
||||
await self._notify_callbacks()
|
||||
|
||||
# 业务发送接口
|
||||
async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
|
||||
"""
|
||||
发送音频数据
|
||||
|
||||
Args:
|
||||
data: 音频数据DTO
|
||||
"""
|
||||
await self.send_screenshot(
|
||||
Owner=data.Owner,
|
||||
isSuccess=data.isSuccess,
|
||||
RealTimeScreenShot=data.RealTimeScreenShot,
|
||||
Width=data.Width,
|
||||
Height=data.Height,
|
||||
DescribeInfo=data.DescribeInfo,
|
||||
LLMResponse=data.LLMResponse
|
||||
)
|
||||
|
||||
async def send_screenshot(
|
||||
self,
|
||||
**screenshot_meta
|
||||
) -> None:
|
||||
"""
|
||||
业务层发送音频的便捷接口
|
||||
一般来说,作为发送请求方,不需要填充任何数据
|
||||
|
||||
Args:
|
||||
**screenshot_meta: 截屏数据元信息
|
||||
"""
|
||||
# 填充音频数据到DTO
|
||||
self.screenshot_data.set_dto_data(
|
||||
Owner="server" or screenshot_meta.get('Owner', "server"),
|
||||
isSuccess=screenshot_meta.get('isSuccess', False),
|
||||
RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
|
||||
Width=screenshot_meta.get('Width', 1920),
|
||||
Height=screenshot_meta.get('Height', 1080),
|
||||
DescribeInfo=screenshot_meta.get('DescribeInfo', False),
|
||||
LLMResponse=screenshot_meta.get('LLMResponse', "")
|
||||
)
|
||||
# 序列化为JSON并发送 自动处理base64和type字段
|
||||
json_message = self.screenshot_data.to_json()
|
||||
await self.json_dto.send_json(json_message)
|
||||
logger.info(f"截屏包已发送")
|
||||
|
||||
# 业务接收接口
|
||||
def register_screenshot_callback(
|
||||
self,
|
||||
callback: Callable[[ScreenShotDataTransferObject], Coroutine]
|
||||
) -> None:
|
||||
"""
|
||||
业务注册接收回调
|
||||
"""
|
||||
self._screenshot_callbacks.append(callback)
|
||||
logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)} 个")
|
||||
|
||||
def unregister_screenshot_callback(self, callback) -> None:
|
||||
"""注销业务回调"""
|
||||
if callback in self._screenshot_callbacks:
|
||||
self._screenshot_callbacks.remove(callback)
|
||||
logger.debug("业务截屏回调已注销")
|
||||
|
||||
def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
|
||||
"""
|
||||
同步获取最新截屏数据(轮询模式)
|
||||
|
||||
Returns:
|
||||
最新接收到的截屏数据DTO,如果没有则为 None
|
||||
"""
|
||||
return self._latest_screenshot
|
||||
|
||||
# 内部通知机制
|
||||
async def _notify_callbacks(self) -> None:
|
||||
"""通知所有业务回调"""
|
||||
if not self._screenshot_callbacks:
|
||||
logger.warning("无业务回调,截屏数据未处理")
|
||||
return
|
||||
tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调")
|
||||
@@ -0,0 +1,56 @@
|
||||
### 说明
|
||||
在本module当中,每个子模块的用途分别是:
|
||||
- dto
|
||||
- dto_templates
|
||||
服务端与客户端交互所使用到的数据传输对象
|
||||
- dto_base.py / xxx_dtos.py
|
||||
实际的DTO业务
|
||||
- websocket_core
|
||||
- websocket核心,承载了底层核心的网络收发业务
|
||||
|
||||
|
||||
### 模块架构
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Client"
|
||||
C[WebSocket Client]
|
||||
end
|
||||
|
||||
subgraph "Core Layer"
|
||||
WS[WebSocketServer<br/>单例]
|
||||
WS -->|持有| WSP[WebSocketServerProtocol<br/>_websocket]
|
||||
WS -->|管理| RCV[ receivers: Dict<br/>binary/text/json ]
|
||||
end
|
||||
|
||||
subgraph "DTO Base Layer"
|
||||
MDTO[MessageDTO<br/>抽象基类]
|
||||
MDTO -->|注入| MDF[ send_binary<br/>send_text<br/>send_json ]
|
||||
end
|
||||
|
||||
subgraph "Secondary Dispatcher"
|
||||
JDTO[JsonDTO<br/>单例]
|
||||
JDTO -->|继承| MDTO
|
||||
JDTO -->|维护| MAP[ receivers: Dict<br/>audio_data/... ]
|
||||
JDTO -->|注册到| WS
|
||||
end
|
||||
|
||||
subgraph "Business DTO"
|
||||
ADTO[AudioDataDTO......<br/>业务实现]
|
||||
ADTO -->|持有引用| JDTO
|
||||
ADTO -->|使用| ATO[AudioDataTransferObject<br/>Pydantic模型]
|
||||
end
|
||||
|
||||
C <-->|websocket连接| WSP
|
||||
|
||||
WS -->|分发消息| JDTO
|
||||
JDTO -->|二次分发| ADTO
|
||||
|
||||
ADTO -->|发送响应| JDTO
|
||||
JDTO -->|调用| MDF
|
||||
MDF -->|经由| WS
|
||||
WS -->|发送至| C
|
||||
|
||||
style WS fill:#64f,stroke:#333,stroke-width:2px
|
||||
style JDTO fill:#569,stroke:#333,stroke-width:2px
|
||||
style ADTO fill:#38f,stroke:#333,stroke-width:2px
|
||||
```
|
||||
@@ -0,0 +1,172 @@
|
||||
# websocket_core/core_ws_server.py
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable, Optional, Dict, Any, List, Coroutine
|
||||
from websockets.asyncio.server import serve, ServerConnection
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from loguru import logger
|
||||
|
||||
# 类型别名
|
||||
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
|
||||
class WebSocketServer:
|
||||
"""WebSocket服务端核心模块(单例 + 单客户端)
|
||||
只管理一个客户端连接,DTO层注册接收函数,服务端只负责分发。
|
||||
"""
|
||||
_instance: Optional["WebSocketServer"] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
"""同步单例(__init__ 可以是 async)"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# 防止重复初始化
|
||||
if self._initialized:
|
||||
return
|
||||
self._websocket: Optional[ServerConnection] = None
|
||||
self._receivers: Dict[str, List[ReceiveCallback]] = {
|
||||
'binary': [],
|
||||
'text': [],
|
||||
'json': []
|
||||
}
|
||||
self._connected_event = asyncio.Event()
|
||||
self._initialized = True
|
||||
logger.info("WebSocketServer 单客户端分发器已初始化")
|
||||
|
||||
# DTO层注册接口
|
||||
def register_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
|
||||
"""注册接收函数(供DTO层调用)
|
||||
|
||||
Args:
|
||||
msg_type: 消息类型(binary/text/json)
|
||||
callback: 接收回调,签名为 (data) -> coroutine
|
||||
"""
|
||||
if msg_type in self._receivers:
|
||||
self._receivers[msg_type].append(callback)
|
||||
logger.debug(f"已注册 {msg_type} 接收器,当前 {msg_type} 类型接收器共 {len(self._receivers[msg_type])} 个")
|
||||
else:
|
||||
raise ValueError(f"不支持的消息类型: {msg_type}")
|
||||
|
||||
def unregister_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
|
||||
"""注销接收函数"""
|
||||
if callback in self._receivers[msg_type]:
|
||||
self._receivers[msg_type].remove(callback)
|
||||
logger.debug(f"已注销 {msg_type} 接收器")
|
||||
|
||||
# 发送接口(供DTO层调用)
|
||||
async def send_binary(self, data: bytes):
|
||||
"""发送二进制数据(唯一客户端)"""
|
||||
if self._websocket:
|
||||
await self._websocket.send(data)
|
||||
logger.trace(f"二进制数据已发送 (长度: {len(data)} bytes)")
|
||||
else:
|
||||
raise RuntimeError("客户端未连接")
|
||||
|
||||
async def send_text(self, data: str):
|
||||
"""发送文本数据(唯一客户端)"""
|
||||
if self._websocket:
|
||||
await self._websocket.send(data)
|
||||
logger.trace(f"文本数据已发送 (长度: {len(data)} chars)")
|
||||
else:
|
||||
raise RuntimeError("客户端未连接")
|
||||
|
||||
async def send_json(self, data: Dict[str, Any]):
|
||||
"""发送JSON数据(唯一客户端)"""
|
||||
if self._websocket:
|
||||
try:
|
||||
logger.debug(f"准备发送JSON数据: {data}")
|
||||
message = json.dumps(data)
|
||||
await self._websocket.send(message)
|
||||
logger.trace(f"JSON数据已发送: {data}")
|
||||
except Exception as e:
|
||||
logger.error(f"JSON数据发送失败: {e}")
|
||||
raise
|
||||
else:
|
||||
raise RuntimeError("客户端未连接")
|
||||
|
||||
# 等待连接
|
||||
async def wait_for_client(self):
|
||||
"""阻塞等待客户端连接"""
|
||||
await self._connected_event.wait()
|
||||
logger.info("客户端已就绪")
|
||||
|
||||
# 内部消息循环
|
||||
async def _handle_client(self, websocket: ServerConnection):
|
||||
"""处理唯一客户端的消息循环"""
|
||||
self._websocket = websocket
|
||||
self._connected_event.set()
|
||||
client_info = f"{websocket.remote_address}" if hasattr(websocket, 'remote_address') else "unknown"
|
||||
logger.info(f"客户端已连接: {client_info}")
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
# 根据消息类型分发到所有注册的接收函数
|
||||
if isinstance(message, bytes):
|
||||
await self._dispatch('binary', message)
|
||||
|
||||
elif isinstance(message, str):
|
||||
# 优先尝试JSON解析
|
||||
json_dispatched = False
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._dispatch('json', data)
|
||||
json_dispatched = True
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果不是JSON或没有json接收器,尝试text
|
||||
if not json_dispatched:
|
||||
await self._dispatch('text', message)
|
||||
|
||||
else:
|
||||
logger.warning(f"未知消息类型: {type(message)}")
|
||||
logger.info("客户端连接已正常关闭")
|
||||
except ConnectionClosed as e:
|
||||
logger.info(f"客户端连接已关闭: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端时发生错误: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._connected_event.clear()
|
||||
|
||||
async def _dispatch(self, msg_type: str, data: Any):
|
||||
"""分发消息到所有注册的接收函数"""
|
||||
callbacks = self._receivers[msg_type]
|
||||
if not callbacks:
|
||||
logger.warning(f"无 {msg_type} 接收器,消息被忽略")
|
||||
return
|
||||
|
||||
# 并发执行所有回调
|
||||
tasks = [callback(data) for callback in callbacks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 启动服务器
|
||||
async def run(self, host: str = "localhost", port: int = 8765, max_msg_size: int = 50*1024*1025):
|
||||
"""启动WebSocket服务器(阻塞)"""
|
||||
logger.info(f"WebSocket服务器启动中... 等待客户端连接 ws://{host}:{port}")
|
||||
|
||||
async def handler_wrapper(connection):
|
||||
logger.info(f"新连接请求: {connection.remote_address}")
|
||||
await self._handle_client(connection)
|
||||
|
||||
try:
|
||||
async with serve(
|
||||
handler=handler_wrapper, # 使用 wrapper 适配签名
|
||||
host=host,
|
||||
port=port,
|
||||
max_size=max_msg_size,
|
||||
):
|
||||
logger.success(f"WebSocket服务器已启动在 ws://{host}:{port}")
|
||||
await asyncio.Future() # 永久阻塞,保持服务器运行
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket服务器启动失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_ws_server() -> WebSocketServer:
|
||||
"""全局单例获取函数(线程安全)"""
|
||||
server = WebSocketServer() # __new__ 保证单例
|
||||
return server
|
||||
@@ -0,0 +1,306 @@
|
||||
# server_core/core.py
|
||||
|
||||
"""
|
||||
统一业务,对外提供启动接口
|
||||
业务数据流向:
|
||||
Yosuga[User Audio Info Struct] ->(WebSocket) Yosuga_server[asr_module] -> Text
|
||||
Yosuga_server[Come from Yosuga Audio ASR Text] ->(Func call) Yosuga_server[llm_core] -> Ins and Text
|
||||
|
||||
Yosuga_server[Come from llm_core Text]->(WebSocket) Yosuga
|
||||
Yosuga_server[Come from llm_core Text]->(Func call) Yosuga_server[TTS] -> Audio Data
|
||||
Yosuga_server[Audio Data] ->(WebSocket) Yosuga
|
||||
|
||||
Yosuga_embedded[Devices Control Info] ->(WebSocket) Yosuga_server[embedded_core]-> Ins
|
||||
Yosuga_embedded[Devices Control Info] ->(Serial) Yosuga[SerialManager] -> ForWord To Yosuga_server
|
||||
Yosuga[Come from embedded Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins
|
||||
|
||||
UI_TARS[Mind and x&y Info] ->(Func call) Yosuga_server[llm_core] -> Ins
|
||||
Yosuga_server[Come from UI_TARS Ins] ->(WebSocket) Yosuga
|
||||
|
||||
Yosuga[Live2D Control Info] ->(WebSocket) Yosuga_server[llm_core] -> Ins
|
||||
Yosuga_server[Live2D Control Ins] ->(Websocket) Yosuga
|
||||
|
||||
Yosuga_server[agent memory] ->(Func call) Yosuga_server[Memory Uint]
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
from src.modules.websocket_base_module.dto.third_dtos import (
|
||||
AudioDataDTO, AudioDataTransferObject,
|
||||
ScreenShotDataDTO, ScreenShotDataTransferObject
|
||||
)
|
||||
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO, get_json_dto_instance
|
||||
from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer, get_ws_server
|
||||
|
||||
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_client import UITarsClient, UITarsClientConfig
|
||||
|
||||
from src.modules.asr_module.client.asr_client import create_asr_client, ASRClientConfig, ASRClientAsync
|
||||
|
||||
from src.modules.tts_module.tts_core.gpt_sovits.gpt_sovits_client import StreamingMode, TTSConfig, GPTSoVITSClient
|
||||
|
||||
from src.server_core.llm_core.llm_core import (
|
||||
LLMCoreConfig, ModelConfig,
|
||||
YosugaLLMCore, ModelProvider,
|
||||
LLMCoreAnalysisBase,
|
||||
YosugaAudioResponseData, YosugaUITARSResponseData,
|
||||
YosugaUITARSRequestData
|
||||
)
|
||||
|
||||
from src.modules.websocket_base_module.dto.dto_templates.auto_agent_data_dto import AutoAgentDataTransferObject
|
||||
from src.config.config import cfg
|
||||
|
||||
|
||||
class YosugaServerCore:
|
||||
"""
|
||||
异步单例类
|
||||
"""
|
||||
|
||||
_instance: Optional["YosugaServerCore"] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
# 组合必要的工具类
|
||||
ws_server: WebSocketServer
|
||||
json_dto: JsonDTO
|
||||
audio_dto: AudioDataDTO
|
||||
screenshot_dto: ScreenShotDataDTO
|
||||
|
||||
asr_client: ASRClientAsync # 异步asr client
|
||||
tts_client: GPTSoVITSClient # tts client
|
||||
auto_agent_client: UITarsClient # GUI自动化agent
|
||||
|
||||
llm_core: YosugaLLMCore = None # llm core
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "YosugaServerCore":
|
||||
"""异步单例工厂"""
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
logger.info("Initializing YosugaServerCore...")
|
||||
# 创建实例
|
||||
instance = cls.__new__(cls)
|
||||
|
||||
# 按依赖顺序初始化数据分发器
|
||||
instance.ws_server = await get_ws_server()
|
||||
instance.json_dto = await get_json_dto_instance(instance.ws_server)
|
||||
instance.audio_dto = AudioDataDTO(instance.json_dto) # 音频分发器
|
||||
instance.audio_dto.register_audio_callback(instance._handle_audio_data) # 注册音频处理函数
|
||||
instance.screenshot_dto = ScreenShotDataDTO(instance.json_dto) # 截图分发器
|
||||
instance.screenshot_dto.register_screenshot_callback(instance._handle_screenshot_data) # 注册截图处理函数
|
||||
|
||||
instance.asr_client = create_asr_client(use_async=True, base_url=cfg.asr.url)
|
||||
instance.tts_client = GPTSoVITSClient(host=cfg.tts.host, port=cfg.tts.port, debug=True)
|
||||
# 切换GPT_SoVITS模型
|
||||
await instance.tts_client.set_gpt_weights(cfg.tts.gpt_model_name)
|
||||
await instance.tts_client.set_sovits_weights(cfg.tts.sovits_model_name)
|
||||
|
||||
instance.auto_agent_client = UITarsClient(UITarsClientConfig(
|
||||
deployment_type=cfg.auto_agent.deployment_type,
|
||||
base_url=cfg.auto_agent.base_url,
|
||||
model_name=cfg.auto_agent.model_name,
|
||||
temperature=cfg.auto_agent.temperature,
|
||||
max_tokens=cfg.auto_agent.max_tokens
|
||||
))
|
||||
|
||||
instance.llm_core = YosugaLLMCore(
|
||||
model_config=ModelConfig( # TODO 同上
|
||||
provider=ModelProvider.OPENAI,
|
||||
model_name=cfg.ai.model_name,
|
||||
base_url=cfg.ai.base_url,
|
||||
api_key=cfg.ai.api_key,
|
||||
temperature=cfg.ai.temperature,
|
||||
max_tokens=cfg.ai.max_tokens
|
||||
),
|
||||
core_config=LLMCoreConfig( # TODO 同上
|
||||
max_context_tokens=cfg.llm_core.max_context_tokens,
|
||||
enable_history=cfg.llm_core.enable_history,
|
||||
role_setting=cfg.llm_core.role_character,
|
||||
language=cfg.llm_core.language, # 回复使用语言
|
||||
auto_dispatch=True,
|
||||
dispatch_async=True # 启用异步分发
|
||||
)
|
||||
)
|
||||
instance.register_llm_core_analysis() # 注册解析器
|
||||
instance.register_llm_core_action() # 注册分发器
|
||||
instance.llm_core.register_overflow_handler(instance._handle_overflow_logger) # 注册上下文溢出处理回调
|
||||
|
||||
cls._instance = instance
|
||||
logger.success("YosugaServerCore initialized")
|
||||
return cls._instance
|
||||
|
||||
def register_llm_core_action(self):
|
||||
"""
|
||||
注册llm_core的分发器
|
||||
"""
|
||||
if self.llm_core is None:
|
||||
raise Exception("LLMCore is not initialized")
|
||||
self.llm_core.register_action_handler("audio_text", self._handle_audio_response, is_async=True)
|
||||
self.llm_core.register_action_handler("auto_agent", self._handle_auto_agent, is_async=True)
|
||||
self.llm_core.register_action_handler("call_auto_agent", self._handle_call_auto_agent, is_async=True)
|
||||
self.llm_core.set_fallback_handler(self._handle_fallback)
|
||||
|
||||
def register_llm_core_analysis(self):
|
||||
"""
|
||||
注册llm_core的输出解析器
|
||||
"""
|
||||
if self.llm_core is None:
|
||||
raise Exception("LLMCore is not initialized")
|
||||
self.llm_core.register_analysis_model(YosugaAudioResponseData)
|
||||
self.llm_core.register_analysis_model(YosugaUITARSResponseData)
|
||||
self.llm_core.register_analysis_model(YosugaUITARSRequestData)
|
||||
|
||||
def _handle_overflow_logger(self, history: List[Any], metadata: Dict[str, Any]):
|
||||
"""上下文溢出记录,仅打印日志"""
|
||||
print(f" 上下文溢出!")
|
||||
print(f" 模型: {metadata['model']}")
|
||||
print(f" 消息数: {metadata['message_count']}")
|
||||
print(f" Token: {metadata['estimated_tokens']}/{metadata['limit']}")
|
||||
print(f" 即将遗忘 {len(history) // 2} 条旧消息")
|
||||
|
||||
async def _handle_audio_data(self, audio_data: AudioDataTransferObject):
|
||||
"""
|
||||
音频数据接收call back
|
||||
Yosuga_server只有接受到每次这个audio数据才会跑一次
|
||||
"""
|
||||
logger.info("Received audio data")
|
||||
# 在此处客户端发送的音频数据必定不是流式数据(考虑客户端发送数据给服务端往往是在本地的,速度极快)
|
||||
# 将音频数据发送给asr转换成文本信息,音频数据格式为wav
|
||||
# TODO: 考虑在此处做一个简单的vad检测,如果客户端发送的音频是静音的,则不把请求发给llm_core
|
||||
asr_response = await self.asr_client.transcribe_bytes(audio_data.data)
|
||||
if not asr_response.success:
|
||||
logger.error(f"ASR failed: {asr_response.error}")
|
||||
asr_result = asr_response.data # 获取asr结果
|
||||
# 将asr结果发送给llm_core进行处理
|
||||
llm_result = await self.llm_core.interact(
|
||||
user_input={ # 构造用户输入信息
|
||||
"text": asr_result.text,
|
||||
"confidence": asr_result.confidence
|
||||
}
|
||||
) # llm_core会自动进行处理并通过执行器异步返回各种相关的数据
|
||||
|
||||
async def _handle_screenshot_data(self, screenshot_data: ScreenShotDataTransferObject):
|
||||
"""
|
||||
屏幕截图数据接收call back
|
||||
将llm_core的回复封装后提交给auto_agent模块,获得自动化agent的返回之后再返回给llm_core
|
||||
"""
|
||||
logger.info(f"Received screenshot data {len(screenshot_data.RealTimeScreenShot)}")
|
||||
if not screenshot_data.isSuccess: # 如果客户端截图失败
|
||||
logger.error("Screenshot failed")
|
||||
return # 直接提前结束回调,不向llm_core发送结果
|
||||
# TODO 对于设备描述信息(screenshot_data.DescribeInfo),考虑加入到auto_agent的输入中,增强识别准确率
|
||||
# 构造请求 异步调用
|
||||
logger.debug(f"screenshot_data.LLMResponse(来自llm_core向auto_agent的输入): {screenshot_data.LLMResponse}")
|
||||
logger.debug(f"客户端设备信息: {screenshot_data.DescribeInfo}")
|
||||
auto_agent_response: str = await self.auto_agent_client.call_async(screenshot_data.LLMResponse,
|
||||
screenshot_data.RealTimeScreenShot)
|
||||
logger.debug(f"auto_agent_response(auto_agent原生返回结果): {auto_agent_response}")
|
||||
# 将auto_agent的返回结果发送给llm_core
|
||||
await self.llm_core.interact(
|
||||
user_input={ # 构造auto_agent输入信息
|
||||
"auto_agent": auto_agent_response
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_audio_response(self, data: YosugaAudioResponseData):
|
||||
"""
|
||||
llm_core异步处理器:语音回复
|
||||
将llm_core的回复封装后提交给tts模块,调用tts模块中的流式返回,并将流式frame返回给Yosuga客户端
|
||||
"""
|
||||
if data.type == "audio_text":
|
||||
logger.info("Handling audio response")
|
||||
try:
|
||||
# 使用最快模式流式输出
|
||||
chunk_count = 0
|
||||
async for chunk in await self.tts_client.tts(
|
||||
text=data.response_text,
|
||||
ref_audio_path="uploaded_audio/test_voice.wav", # TODO 需要替换成config或者后续设计情感系统
|
||||
text_lang="ja",
|
||||
prompt_lang="ja",
|
||||
prompt_text="もう!こんなところで何やってるんだよ!", # 参考语音的真实文本
|
||||
streaming_mode=StreamingMode.FASTEST, # 模式3:快速流式
|
||||
media_type="wav"
|
||||
):
|
||||
chunk_count += 1
|
||||
print(f"🎵 收到音频块 #{chunk_count}: {len(chunk.audio_data)} bytes")
|
||||
if chunk_count == 1: # 如果是第一个音频块
|
||||
# 构造音频首包发送给客户端
|
||||
await self.audio_dto.send_audio_data(
|
||||
AudioDataTransferObject(
|
||||
data=chunk.audio_data,
|
||||
isStream=True,
|
||||
isStart=True,
|
||||
sequence=chunk_count,
|
||||
isEnd=False,
|
||||
text=data.response_text
|
||||
)
|
||||
)
|
||||
else: # 如果不是第一个音频块,则发送中间包给客户端
|
||||
await self.audio_dto.send_audio_data(
|
||||
AudioDataTransferObject(
|
||||
data=chunk.audio_data,
|
||||
isStream=True,
|
||||
isStart=False,
|
||||
sequence=chunk_count,
|
||||
isEnd=False,
|
||||
text=data.response_text
|
||||
)
|
||||
)
|
||||
print(f"✅ 流式TTS完成!共{chunk_count}个音频块")
|
||||
# 构造音频尾包发送给客户端(虚假的音频数据)
|
||||
await self.audio_dto.send_audio_data(
|
||||
AudioDataTransferObject(
|
||||
data=b"0",
|
||||
isStream=True,
|
||||
isStart=False,
|
||||
sequence=chunk_count + 1,
|
||||
isEnd=True,
|
||||
text=data.response_text
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"❌ 流式错误: {e}")
|
||||
return {"status": "success", "executed": data.response_text}
|
||||
return None
|
||||
|
||||
async def _handle_auto_agent(self, data: YosugaUITARSResponseData):
|
||||
"""
|
||||
llm_core异步处理器:处理自动化操作
|
||||
将llm_core的回复封装后提交给Yosuga客户端,由客户端进行执行相关的GUI自动化操作
|
||||
"""
|
||||
# 构造并发送回复数据
|
||||
await self.json_dto.send_json(
|
||||
AutoAgentDataTransferObject.from_json(data.to_dict()).to_json()
|
||||
)
|
||||
return {"status": "success", "executed": data.Action}
|
||||
|
||||
async def _handle_call_auto_agent(self, data: YosugaUITARSRequestData):
|
||||
"""
|
||||
llm_core异步处理器:处理llm_core调用auto_agent需求
|
||||
向客户端请求当前界面的截图,请求成功后由_handle_screenshot_data函数完成剩下的任务
|
||||
"""
|
||||
if data.type == "call_auto_agent":
|
||||
logger.info("LLM Calling auto agent")
|
||||
# 向客户端请求当前界面的截图的base64编码 加入llm回复的信息到截图请求DTO当中 方便_handle_screenshot_data构造请求
|
||||
await self.screenshot_dto.send_screenshot_data(ScreenShotDataTransferObject(LLMResponse=data.llm_translation))
|
||||
return {"status": "success", "executed": data.type}
|
||||
|
||||
def _handle_fallback(self, data: LLMCoreAnalysisBase):
|
||||
"""
|
||||
llm_core同步处理器:回退处理器
|
||||
"""
|
||||
logger.debug(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}")
|
||||
|
||||
async def run(self):
|
||||
"""启动服务器"""
|
||||
logger.info("Yosuga Server Websocket Core 启动中...")
|
||||
await self.ws_server.run(host="0.0.0.0")
|
||||
|
||||
|
||||
# 使用方式
|
||||
async def main():
|
||||
core = await YosugaServerCore.get_instance()
|
||||
await core.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,2 @@
|
||||
### 本模块为Yosuga_server的AI输出语音情感管理模块
|
||||
|
||||
@@ -0,0 +1,564 @@
|
||||
# llm_core/llm_core.py
|
||||
|
||||
"""
|
||||
Yosuga Server LLM 核心控制模块
|
||||
负责整合Prompt管理、模型调用、输出解析、上下文记忆管理以及生命周期维护。
|
||||
作为系统的"大脑",对外提供统一的高级交互接口。
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable, Union, Type
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
# 引入已有的模块
|
||||
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
|
||||
UnifiedLLM, ModelConfig, ChatMessage, ModelResponse, ModelProvider
|
||||
)
|
||||
from src.server_core.llm_core.llm_core_analysis import (
|
||||
LLMCoreAnalysisManager, LLMCoreAnalysisBase, YosugaAudioResponseData, YosugaUITARSResponseData, YosugaUITARSRequestData
|
||||
)
|
||||
from src.server_core.llm_core.llm_core_dispatcher import LLMCoreActionDispatcher
|
||||
from src.server_core.llm_core.llm_core_prompt_manager import (
|
||||
LLMCorePromptManager, LLMCorePromptBase,
|
||||
YosugaAudioASRText, YosugaUITARS, YosugaLive2DControl
|
||||
)
|
||||
from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH
|
||||
from src.server_core.llm_core.llm_core_token import TokenManager, TokenUsage
|
||||
|
||||
# 类型定义:上下文溢出回调函数签名
|
||||
# 参数1: 溢出的历史记录列表
|
||||
# 参数2: 相关的元数据
|
||||
ContextOverflowCallback = Callable[[List[ChatMessage], Dict[str, Any]], None]
|
||||
|
||||
class LLMCoreConfig(BaseModel):
|
||||
"""LLM Core 运行时配置"""
|
||||
max_context_tokens: int = Field(default=2000, description="上下文最大Token数(估算值,不包括System prompt),超出触发重置")
|
||||
enable_history: bool = Field(default=True, description="是否启用历史对话记忆")
|
||||
language: str = Field(default="zh_CN", description="回复语言设定")
|
||||
role_setting: str = Field(default="", description="llm角色扮演")
|
||||
auto_dispatch: bool = Field(default=True, description="是否自动分发到动作处理器")
|
||||
dispatch_async: bool = Field(default=False, description="分发是否使用异步模式")
|
||||
memory: str = Field(default="", description="llm记忆")
|
||||
system_state_table: str = Field(default="", description="Yosuga系统状态表")
|
||||
|
||||
|
||||
|
||||
class YosugaLLMCore:
|
||||
"""
|
||||
Yosuga 服务端 LLM 核心控制器
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig, core_config: Optional[LLMCoreConfig] = None):
|
||||
"""
|
||||
初始化 LLM Core
|
||||
|
||||
Args:
|
||||
model_config: 底层大模型的连接配置
|
||||
core_config: 核心业务逻辑配置(上下文限制等)
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.core_config = core_config or LLMCoreConfig()
|
||||
# 初始化模型客户端 (UnifiedLLM)
|
||||
self.llm_client: UnifiedLLM = UnifiedLLM(self.model_config)
|
||||
# 初始化 TokenManager
|
||||
self.token_manager = TokenManager(self.model_config.model_name)
|
||||
# 初始化Prompt管理器
|
||||
self.prompt_manager = LLMCorePromptManager()
|
||||
self._register_default_prompts()
|
||||
# 上下文记忆存储
|
||||
self._history: List[ChatMessage] = [] # 注意:history不包含system prompt,只包含 user/assistant 消息
|
||||
# 上下文溢出回调列表
|
||||
self._overflow_callbacks: List[ContextOverflowCallback] = []
|
||||
logger.info(
|
||||
f"YosugaLLMCore 初始化完成 | "
|
||||
f"模型: {model_config.model_name} | "
|
||||
f"提供商: {model_config.provider}"
|
||||
)
|
||||
logger.info(
|
||||
f"上下文限制: {self.core_config.max_context_tokens} tokens | "
|
||||
f"自动分发: {self.core_config.auto_dispatch}"
|
||||
)
|
||||
|
||||
def _register_default_prompts(self):
|
||||
"""注册默认的业务Prompt模块"""
|
||||
self.prompt_manager.register(YosugaAudioASRText())
|
||||
self.prompt_manager.register(YosugaUITARS())
|
||||
# self.prompt_manager.register(YosugaLive2DControl()) # TODO
|
||||
logger.info(f"默认Prompt模块注册完成 | 数量: {self.prompt_manager.get_registry_size()}")
|
||||
|
||||
# 系统提示词管理
|
||||
def get_system_prompt(self) -> str:
|
||||
"""
|
||||
动态构建当前的 System Prompt
|
||||
根据 prompt_manager 中注册的模块实时生成
|
||||
"""
|
||||
return YOSUGA_SYSTEM_PROMPT_SCH.format(
|
||||
InputInfo=self.prompt_manager.describe_input(), # 不变的内容
|
||||
OutputInfo=self.prompt_manager.describe_output(), # 不变的内容
|
||||
RoleSetting=self.core_config.role_setting, # 角色扮演,可热重载
|
||||
Language=self.core_config.language, # 回复语言,可热重载
|
||||
Memory=self.core_config.memory, # 记忆,可热重载,请求前更新
|
||||
SystemStateTable=self.core_config.system_state_table# 系统状态表,可热重载,每次请求都会更新
|
||||
)
|
||||
|
||||
def register_prompt_module(self, prompt_module: LLMCorePromptBase):
|
||||
"""运行时注册新的 Prompt 业务模块"""
|
||||
self.prompt_manager.register(prompt_module)
|
||||
logger.info(f"动态注册 Prompt 模块: {prompt_module.type()}")
|
||||
|
||||
def register_analysis_model(self, model_class: Type[LLMCoreAnalysisBase]) -> None:
|
||||
"""
|
||||
注册LLM输出解析模型
|
||||
|
||||
Args:
|
||||
model_class: 继承自 LLMCoreAnalysisBase 的数据模型类
|
||||
"""
|
||||
LLMCoreAnalysisManager.register(model_class)
|
||||
logger.info(f"注册解析模型: {model_class.type_()}")
|
||||
|
||||
def register_action_handler(
|
||||
self,
|
||||
type_id: str,
|
||||
handler: Callable,
|
||||
is_async: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
注册动作处理器
|
||||
|
||||
Args:
|
||||
type_id: 与解析模型对应的类型标识
|
||||
handler: 处理函数(同步或异步)
|
||||
is_async: 是否为异步处理器
|
||||
"""
|
||||
if is_async:
|
||||
LLMCoreActionDispatcher.register_async(type_id, handler)
|
||||
else:
|
||||
LLMCoreActionDispatcher.register(type_id, handler)
|
||||
|
||||
def set_fallback_handler(self, handler: Callable) -> None:
|
||||
"""设置未注册类型的回退处理器"""
|
||||
LLMCoreActionDispatcher.set_fallback(handler)
|
||||
|
||||
# 核心交互接口
|
||||
async def interact(
|
||||
self,
|
||||
user_input: Union[str, Dict[str, Any]], # 输入(纯文本或结构化字典)
|
||||
past_memories: Optional[str] = "", # 记忆模块检索的相关历史记忆
|
||||
system_state_table: Optional[str] = "", # 系统状态表,可热重载,每次请求都会更新
|
||||
auto_dispatch: Optional[bool] = True, # 是否自动分发,默认为启用
|
||||
dispatch_async: Optional[bool] = True # 是否异步分发,默认为启用
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
核心交互方法:处理输入 -> 组装上下文 -> 调用LLM -> 解析输出
|
||||
|
||||
Args:
|
||||
user_input: 用户输入(纯文本或结构化字典)
|
||||
past_memories: 记忆模块检索的相关历史记忆
|
||||
system_state_table: Yosuga系统状态表(方便llm理解当前系统状态)
|
||||
auto_dispatch: 是否自动分发(覆盖默认配置)
|
||||
dispatch_async: 是否异步分发(覆盖默认配置)
|
||||
|
||||
Returns:
|
||||
分发执行结果字典:
|
||||
{
|
||||
"success": [{"type": "...", "output": ..., "index": 0}],
|
||||
"failed": [{"type": "...", "error": "...", "index": 1}],
|
||||
"skipped": ["unknown_type"]
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: LLM调用或解析失败
|
||||
RuntimeError: 分发执行致命错误
|
||||
"""
|
||||
# 输入预处理
|
||||
input_content = (
|
||||
json.dumps(user_input, ensure_ascii=False)
|
||||
if isinstance(user_input, dict)
|
||||
else user_input
|
||||
)
|
||||
logger.info(f"用户输入内容经过json处理后为: {input_content}")
|
||||
|
||||
# 检查并维护上下文
|
||||
self._maintain_context_limit()
|
||||
|
||||
# 构建本次请求的消息链
|
||||
messages = self._build_request_messages(input_content, past_memories, system_state_table)
|
||||
|
||||
# 调用 LLM
|
||||
try:
|
||||
llm_response = self._call_llm(messages)
|
||||
if llm_response.usage: # 打印请求消耗的Token数量
|
||||
self.token_manager.record_api_usage(llm_response.usage)
|
||||
logger.info(self.token_manager.format_usage_log(source="API")) # 使用 TokenManager 格式化日志
|
||||
except Exception as e:
|
||||
logger.error(f"LLM调用失败: {e}")
|
||||
raise
|
||||
|
||||
# 统一解析(总是返回列表)
|
||||
try:
|
||||
parsed_results = LLMCoreAnalysisManager.parse(llm_response.content)
|
||||
logger.success(f"解析成功 | 对象数: {len(parsed_results)}")
|
||||
except Exception as e:
|
||||
logger.error(f"输出解析失败: {e}")
|
||||
raise
|
||||
|
||||
# 更新历史记忆
|
||||
if self.core_config.enable_history:
|
||||
self._add_to_history("user", input_content)
|
||||
self._add_to_history("assistant", llm_response.content)
|
||||
|
||||
# 分发执行
|
||||
should_dispatch = auto_dispatch if auto_dispatch is not None else self.core_config.auto_dispatch
|
||||
if should_dispatch and parsed_results:
|
||||
is_async = dispatch_async if dispatch_async is not None else self.core_config.dispatch_async
|
||||
if is_async:
|
||||
# 直接 await 异步执行,如果调用 asyncio.run() 就会因为重复创建事件循环导致报错
|
||||
logger.debug("使用异步模式分发动作")
|
||||
return await LLMCoreActionDispatcher._execute_async(parsed_results)
|
||||
else:
|
||||
# 同步模式保持原样
|
||||
return LLMCoreActionDispatcher.execute(parsed_results, run_async=False)
|
||||
# 不分发则返回原始解析结果(极少用)
|
||||
return {"success": [{"type": obj.type, "output": obj, "index": i}
|
||||
for i, obj in enumerate(parsed_results)],
|
||||
"failed": [], "skipped": []}
|
||||
|
||||
def _build_request_messages(
|
||||
self,
|
||||
current_input: str,
|
||||
memories: str,
|
||||
system_state_table: str
|
||||
) -> List[ChatMessage]:
|
||||
"""
|
||||
构建完整的LLM消息链
|
||||
结构:
|
||||
1. System Prompt 构造,包括记忆注入等信息
|
||||
2. 历史上下文
|
||||
3. 当前用户输入
|
||||
"""
|
||||
# 构建memory与其他信息
|
||||
self.core_config.memory = memories
|
||||
# 构建系统状态表
|
||||
self.core_config.system_state_table = system_state_table
|
||||
# 构造 System Prompt
|
||||
messages = [ChatMessage(role="system", content=self.get_system_prompt())]
|
||||
|
||||
# 历史上下文
|
||||
if self.core_config.enable_history:
|
||||
messages.extend(self._history) # 分开每条消息追加
|
||||
# 当前用户输入
|
||||
messages.append(ChatMessage(role="user", content=current_input))
|
||||
return messages
|
||||
|
||||
def _call_llm(self, messages: List[ChatMessage]) -> ModelResponse:
|
||||
"""执行LLM调用(非流式)"""
|
||||
logger.debug(f"请求LLM | 消息数: {len(messages)}")
|
||||
# 预估token使用情况 TODO: (调试用)
|
||||
estimated = self.token_manager.estimate_chat_tokens(
|
||||
system_prompt=self.get_system_prompt(),
|
||||
history=self._history,
|
||||
current_input=messages[-1].content if messages else ""
|
||||
)
|
||||
logger.debug(self.token_manager.format_usage_log(estimated.to_dict(), source="MANUAL"))
|
||||
|
||||
# 强制非流式(结构化输出需要完整JSON)
|
||||
response: ModelResponse = self.llm_client.chat(
|
||||
messages,
|
||||
streaming=False,
|
||||
temperature=self.model_config.temperature
|
||||
)
|
||||
# 记录 API 返回的 usage
|
||||
if response.usage:
|
||||
self.token_manager.record_api_usage(response.usage)
|
||||
logger.debug(f"LLM响应长度: {len(response.content)}")
|
||||
return response
|
||||
|
||||
# 上下文与记忆管理
|
||||
def _add_to_history(self, role: str, content: str):
|
||||
"""添加消息到历史"""
|
||||
self._history.append(ChatMessage(role=role, content=content))
|
||||
|
||||
def _maintain_context_limit(self) ->None:
|
||||
"""
|
||||
检查上下文是否超出限制
|
||||
如果超出,触发溢出回调,并将当前上下文导出,然后清空最近50%的
|
||||
"""
|
||||
if not self.core_config.enable_history: # 若未启用历史对话记忆
|
||||
return
|
||||
# 使用 TokenManager 获取上下文占用(手动计算)
|
||||
context_usage = self.token_manager.get_context_usage(self._history)
|
||||
current_usage = context_usage.total_tokens
|
||||
|
||||
# 使用 TokenManager 格式化日志
|
||||
logger.debug(
|
||||
self.token_manager.format_usage_log(context_usage.to_dict(), source="CONTEXT")
|
||||
)
|
||||
limit = self.core_config.max_context_tokens
|
||||
# 使用 TokenManager 判断是否接近限制
|
||||
if self.token_manager.is_token_limit_approaching(current_usage, limit, threshold=0.85):
|
||||
logger.warning(
|
||||
f"上下文接近限制: {current_usage}/{limit} "
|
||||
f"({current_usage / limit:.1%})"
|
||||
)
|
||||
if current_usage <= limit:
|
||||
return
|
||||
# 否则就是溢出
|
||||
logger.critical(
|
||||
f"上下文溢出!| {current_usage}/{limit} tokens "
|
||||
f"({current_usage / limit:.1%}) | 消息: {len(self._history)}"
|
||||
)
|
||||
# 执行所有注册的溢出处理器
|
||||
self._trigger_overflow_callbacks()
|
||||
|
||||
# 智能清理:保留最近50%消息
|
||||
keep_messages = max(1, len(self._history) // 2)
|
||||
self._history = self._history[-keep_messages:]
|
||||
# 求出新的token占有
|
||||
new_usage = self._estimate_token_usage()
|
||||
logger.success( # 打印清理前后的token占用变化
|
||||
f"清理完成 | Token: {current_usage}→{new_usage} | "
|
||||
f"保留消息: {len(self._history)}"
|
||||
)
|
||||
|
||||
def _estimate_token_usage(self) -> int:
|
||||
"""
|
||||
使用 TokenManager 计算当前历史记录的 Token 数
|
||||
"""
|
||||
if not self._history:
|
||||
return 0
|
||||
# 将 ChatMessage 对象转换为字典格式
|
||||
history_dicts = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in self._history
|
||||
]
|
||||
|
||||
return self.token_manager.count_messages_tokens(
|
||||
history_dicts,
|
||||
tokens_per_message=3 # OpenAI 格式开销
|
||||
)
|
||||
|
||||
def register_overflow_handler(self, handler: ContextOverflowCallback):
|
||||
"""
|
||||
注册上下文溢出处理器(支持多个目标)
|
||||
这个上下文溢出处理器用于将溢出的消息收集并记录,和记忆模块对接
|
||||
"""
|
||||
self._overflow_callbacks.append(handler)
|
||||
logger.info(f"注册溢出处理器: {handler.__name__}")
|
||||
|
||||
def _trigger_overflow_callbacks(self):
|
||||
"""执行所有注册的溢出处理器"""
|
||||
if not self._overflow_callbacks:
|
||||
return
|
||||
# 使用 TokenManager 获取当前占用
|
||||
context_usage = self.token_manager.get_context_usage(self._history)
|
||||
metadata = { # 构造详细的元数据
|
||||
"reason": "token_limit_exceeded",
|
||||
"message_count": len(self._history),
|
||||
"estimated_tokens": context_usage.total_tokens,
|
||||
"limit": self.core_config.max_context_tokens,
|
||||
"timestamp": time.time(),
|
||||
"model": self.model_config.model_name
|
||||
}
|
||||
# 快照当前历史,防止回调修改
|
||||
history_snapshot = list(self._history)
|
||||
|
||||
for handler in self._overflow_callbacks:
|
||||
try:
|
||||
handler(history_snapshot, metadata)
|
||||
logger.debug(f"溢出处理器成功: {handler.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"执行上下文溢出回调失败: {handler.__name__}:{e}")
|
||||
|
||||
def clear_context(self):
|
||||
"""手动清空上下文"""
|
||||
self._history.clear()
|
||||
logger.info("上下文记忆已清空")
|
||||
|
||||
# 运行时热重载
|
||||
def reload_model(self, new_model_config: ModelConfig):
|
||||
"""
|
||||
热重载 LLM 模型配置
|
||||
不影响当前的上下文记忆和 System Prompt
|
||||
"""
|
||||
logger.info(f"正在热重载模型: {self.model_config.model_name} -> {new_model_config.model_name}")
|
||||
try:
|
||||
self.llm_client.update_config(new_model_config)
|
||||
self.model_config = new_model_config
|
||||
# 重新初始化 TokenManager
|
||||
self.token_manager = TokenManager(new_model_config.model_name)
|
||||
logger.info("模型热重载成功")
|
||||
except Exception as e:
|
||||
logger.error(f"模型热重载失败: {e}")
|
||||
raise
|
||||
|
||||
def get_context_stats(self) -> Dict[str, Any]:
|
||||
"""获取详细上下文统计"""
|
||||
# 使用 TokenManager 获取当前占用
|
||||
context_usage = self.token_manager.get_context_usage(self._history)
|
||||
tokenizer_info = self.token_manager.get_tokenizer_info()
|
||||
return {
|
||||
"message_count": len(self._history),
|
||||
"estimated_tokens": context_usage.total_tokens,
|
||||
"limit": self.core_config.max_context_tokens,
|
||||
"usage_ratio": context_usage.total_tokens / self.core_config.max_context_tokens,
|
||||
"model": self.model_config.model_name,
|
||||
"tokenizer": tokenizer_info,
|
||||
"history_preview": [
|
||||
f"{msg.role[:1]}:{msg.content[:30]}..."
|
||||
for msg in self._history[-3:]
|
||||
],
|
||||
"last_api_usage": self.token_manager._last_api_usage.to_dict() if self.token_manager._last_api_usage else None
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ( # 返回描述信息
|
||||
f"YosugaLLMCore(model={self.model_config.model_name}, "
|
||||
f"provider={self.model_config.provider.value}, "
|
||||
f"history_len={len(self._history)})"
|
||||
)
|
||||
|
||||
|
||||
# 使用示例与测试
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
# 配置日志输出
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("🚀 Yosuga Server LLM Core 启动测试")
|
||||
print("=" * 50 + "\n")
|
||||
|
||||
# 准备模拟的动作处理器 (Mock Handlers)
|
||||
|
||||
# 异步处理器示例:处理语音回复
|
||||
async def handle_audio_response(data: LLMCoreAnalysisBase):
|
||||
# 这里强制类型转换为具体子类以获取代码提示(实际运行时已经是具体类型)
|
||||
if data.type == "audio_text":
|
||||
print(f" [Audio Handler] 正在合成语音: {data.response_text} | 情感: {data.emotion}")
|
||||
return {"audio_file": "sample.wav", "duration": 3.5}
|
||||
return None
|
||||
|
||||
# 异步处理器示例:处理自动化操作
|
||||
async def handle_auto_agent(data: LLMCoreAnalysisBase):
|
||||
print(f" [UI Agent] 收到指令: {data.Action}")
|
||||
await asyncio.sleep(1) # 模拟耗时操作
|
||||
print(f" [UI Agent] 执行动作: {data.Action} -> ({data.x1}, {data.y1})")
|
||||
return {"status": "success", "executed": data.Action}
|
||||
|
||||
async def handle_call_auto_agent(data: LLMCoreAnalysisBase):
|
||||
print(f" [Call Agent] 收到内容: {data.type}")
|
||||
print(f" [Call Agent] 收到内容: {data.llm_translation}")
|
||||
|
||||
# 回退处理器
|
||||
def handle_fallback(data: LLMCoreAnalysisBase):
|
||||
print(f" [Fallback] 未知类型数据: {data.type}, 内容: {data.model_dump_json()}")
|
||||
|
||||
# 初始化 LLM Core
|
||||
|
||||
# 配置 LM Studio 连接
|
||||
# 注意:LM Studio 通常兼容 OpenAI 格式,所以 provider 选 LM_STUDIO 或 OPENAI 均可
|
||||
# 如果是本地服务,API Key 可以随意填写
|
||||
model_cfg = ModelConfig(
|
||||
provider=ModelProvider.LM_STUDIO,
|
||||
model_name="qwen/qwen3-4b-2507",
|
||||
base_url="http://192.168.1.3:1234/v1",
|
||||
api_key="lm-studio",
|
||||
temperature=0.3,
|
||||
max_tokens=2048
|
||||
)
|
||||
|
||||
core_cfg = LLMCoreConfig(
|
||||
max_context_tokens=1024,
|
||||
enable_history=True,
|
||||
role_setting="你是由Misakiotoha开发的Yosuga助手,性格抽象,爱说点小骚话。",
|
||||
auto_dispatch=True,
|
||||
dispatch_async=True # 启用异步分发测试
|
||||
)
|
||||
|
||||
core = YosugaLLMCore(model_cfg, core_cfg)
|
||||
|
||||
# 注册处理器
|
||||
core.register_action_handler("audio_text", handle_audio_response, is_async=True)
|
||||
core.register_action_handler("auto_agent", handle_auto_agent, is_async=True)
|
||||
core.register_action_handler("call_auto_agent", handle_call_auto_agent, is_async=True)
|
||||
core.set_fallback_handler(handle_fallback)
|
||||
|
||||
# 注册解析器
|
||||
core.register_analysis_model(YosugaAudioResponseData)
|
||||
core.register_analysis_model(YosugaUITARSResponseData)
|
||||
core.register_analysis_model(YosugaUITARSRequestData)
|
||||
|
||||
# 交互测试 Loop
|
||||
def run_tests():
|
||||
# 测试场景 1: 普通对话 (触发 Audio 解析)
|
||||
print("\n📝 测试 1: 普通对话 (预期触发 audio_text)")
|
||||
asr_input_1 = {
|
||||
"text": "你好,Yosuga!请介绍一下你自己,并对我微笑。",
|
||||
"confidence": 0.99
|
||||
}
|
||||
try:
|
||||
result = core.interact(
|
||||
asr_input_1,
|
||||
dispatch_async=True
|
||||
)
|
||||
print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
|
||||
except Exception as e:
|
||||
logger.error(f"测试1失败: {e}")
|
||||
|
||||
# 测试场景 2: 复杂指令 (预期同时触发 Audio 和 UI 操作)
|
||||
print("\n📝 测试 2: 混合指令 (预期触发 audio_text + auto_agent)")
|
||||
# 构造一个复杂的 Prompt 输入,诱导模型输出多条指令
|
||||
# 注意:这依赖于模型足够聪明能理解 System Prompt 中的 output schema
|
||||
complex_input = """
|
||||
[{
|
||||
"text": "打开系统设置",
|
||||
"confidence": 0.99
|
||||
}]
|
||||
"""
|
||||
|
||||
try:
|
||||
result = core.interact(complex_input, dispatch_async=True)
|
||||
print(f"🏁 交互结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
|
||||
except Exception as e:
|
||||
logger.error(f"测试2失败: {e}")
|
||||
|
||||
# 热重载测试
|
||||
print("\n🔄 测试 3: 模型热重载")
|
||||
new_model_cfg = ModelConfig(
|
||||
provider=ModelProvider.LM_STUDIO,
|
||||
model_name="qwen/qwen3-vl-8b",
|
||||
base_url="http://192.168.1.3:1234/v1",
|
||||
api_key="lm-studio",
|
||||
temperature=0.5
|
||||
)
|
||||
|
||||
try:
|
||||
core.reload_model(new_model_cfg)
|
||||
print("✅ 热重载完成,进行验证对话...")
|
||||
|
||||
# 验证重载后是否还能对话(保留了上下文)
|
||||
verify_input = """
|
||||
{
|
||||
"text": "系统设置打开了吗",
|
||||
"confidence": 0.95
|
||||
},
|
||||
{
|
||||
"auto_agent": "Thought: 我注意到屏幕左下角有一个齿轮形状的图标,这正是系统的设置入口。在KDE系统中,这个图标通常用来访问系统设置面板。为了帮助用户打开设置界面,我现在需要点击这个位于屏幕左下方的齿轮图标。
|
||||
Action: click(start_box='<|box_start|>(42,1045)<|box_end|>')"
|
||||
}
|
||||
}"""
|
||||
result = core.interact(verify_input)
|
||||
print(f"🏁 重载后回复: {json.dumps(result, ensure_ascii=False, indent=2)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"热重载测试失败: {e}")
|
||||
|
||||
# 查看统计信息
|
||||
print("\n📊 最终状态统计:")
|
||||
print(core.get_context_stats())
|
||||
|
||||
# 运行测试
|
||||
run_tests()
|
||||
@@ -0,0 +1,313 @@
|
||||
# llm_core/llm_core_analysis.py
|
||||
|
||||
"""
|
||||
LLM输出解析与序列化模块
|
||||
将LLM返回的JSON字符串智能解析为强类型Python对象,供其他模块直接调用
|
||||
支持多类型混合响应,自动路由到对应数据模型
|
||||
"""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar, Dict, Optional, Type, List
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from loguru import logger
|
||||
import re
|
||||
|
||||
# 抽象基类
|
||||
class LLMCoreAnalysisBase(BaseModel, ABC):
|
||||
"""
|
||||
LLM输出数据模型抽象基类
|
||||
各场景通过继承定义具体的数据结构
|
||||
"""
|
||||
|
||||
# 解析器type标识
|
||||
type: str = Field(..., description="场景类型标识")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def type_(cls) -> str:
|
||||
"""返回该模型对应的场景类型标识"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_schema(cls) -> Dict[str, Any]:
|
||||
"""
|
||||
返回该场景的数据结构schema
|
||||
用于生成system prompt中的{OutputInfo}
|
||||
"""
|
||||
return cls.model_json_schema()
|
||||
|
||||
# 管理器
|
||||
class LLMCoreAnalysisManager:
|
||||
"""
|
||||
LLM输出解析管理器
|
||||
智能路由:根据JSON中的type_字段,自动选择对应模型进行解析
|
||||
"""
|
||||
|
||||
# 类变量:存储所有注册的数据模型类
|
||||
_model_registry: ClassVar[Dict[str, Type[LLMCoreAnalysisBase]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type[LLMCoreAnalysisBase]) -> None:
|
||||
"""
|
||||
注册场景数据模型类
|
||||
|
||||
Args:
|
||||
model_class: 继承自LLMCoreAnalysisBase的数据模型类
|
||||
|
||||
Example:
|
||||
LLMCoreAnalysisManager.register(YosugaAudioResponseData)
|
||||
"""
|
||||
type_id = model_class.type_()
|
||||
cls._model_registry[type_id] = model_class
|
||||
logger.info(f"已注册LLM数据输出解析模型: {type_id} -> {model_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, json_str: str) -> List[LLMCoreAnalysisBase]:
|
||||
"""
|
||||
统一解析入口:无论单对象还是数组,总是返回对象列表
|
||||
|
||||
Args:
|
||||
json_str: LLM原始JSON字符串(支持markdown代码块)
|
||||
|
||||
Returns:
|
||||
解析后的模型实例列表,顺序与JSON数组一致
|
||||
|
||||
Raises:
|
||||
ValidationError: 格式校验失败
|
||||
ValueError: JSON解析失败
|
||||
"""
|
||||
print(f"待解析的内容为(llm本次输出原生内容):{json_str}") # TODO:delete
|
||||
|
||||
cleaned = cls._clean_markdown(json_str) # 清理markdown标记
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
# 统一包装成列表
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}\n原始输出: {cleaned[:200]}...")
|
||||
raise ValueError(f"无效的JSON格式: {e}")
|
||||
|
||||
results: List[LLMCoreAnalysisBase] = []
|
||||
for idx, item in enumerate(data):
|
||||
type_id = item.get("type")
|
||||
if not type_id:
|
||||
logger.warning(f"跳过第{idx}个元素(无type字段): {item}")
|
||||
continue
|
||||
|
||||
if type_id not in cls._model_registry:
|
||||
logger.warning(f"未注册的类型 '{type_id}',可用类型: {list(cls._model_registry.keys())}")
|
||||
continue
|
||||
|
||||
model_class = cls._model_registry[type_id]
|
||||
try:
|
||||
# 重新序列化为字符串再解析(保持接口兼容)
|
||||
item_json = json.dumps(item, ensure_ascii=False)
|
||||
result = model_class.model_validate_json(item_json)
|
||||
results.append(result)
|
||||
except ValidationError as e:
|
||||
logger.error(f"第{idx}个对象校验失败 (type={type_id}): {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"第{idx}个对象解析失败 (type={type_id}): {e}")
|
||||
continue
|
||||
|
||||
logger.success(f"解析完成 | 成功: {len(results)}/{len(data)}")
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _clean_markdown(json_str: str) -> str:
|
||||
# 尝试找到第一个 '[' 和最后一个 ']'
|
||||
start = json_str.find('[')
|
||||
end = json_str.rfind(']')
|
||||
if start != -1 and end != -1:
|
||||
return json_str[start:end + 1]
|
||||
return json_str # Fallback
|
||||
|
||||
# 具体场景数据模型
|
||||
class YosugaAudioResponseData(LLMCoreAnalysisBase):
|
||||
"""
|
||||
音频ASR场景的LLM输出数据模型
|
||||
|
||||
使用示例:
|
||||
LLMCoreAnalysisManager.register(YosugaAudioResponseData)
|
||||
data = YosugaAudioResponseData.parse_raw('{"type_": "audio_text", "response_text": "你好"}')
|
||||
data.response_text # 直接属性访问
|
||||
'你好'
|
||||
data.emotion # 默认值
|
||||
'neutral'
|
||||
"""
|
||||
|
||||
type: str = Field(default="audio_text", description="固定为audio_text")
|
||||
response_text: str = Field(..., description="回复文本")
|
||||
emotion: str = Field(default="neutral", description="情感基调")
|
||||
action: str = Field(default="none", description="动作指令")
|
||||
|
||||
@classmethod
|
||||
def type_(cls) -> str:
|
||||
return "audio_text"
|
||||
|
||||
class YosugaUITARSResponseData(LLMCoreAnalysisBase):
|
||||
"""
|
||||
自动化操作场景的LLM输出数据模型
|
||||
"""
|
||||
|
||||
type: str = Field(default="auto_agent", description="固定为auto_agent")
|
||||
Action: str = Field(..., description="动作名称")
|
||||
x1: Optional[int] = Field(default=None, description="起点x")
|
||||
y1: Optional[int] = Field(default=None, description="起点y")
|
||||
x2: Optional[int] = Field(default=None, description="终点x")
|
||||
y2: Optional[int] = Field(default=None, description="终点y")
|
||||
key: Optional[str] = Field(default="", description="快捷键")
|
||||
content: Optional[str] = Field(default="", description="输入文本")
|
||||
direction: Optional[str] = Field(default="", description="滚动方向")
|
||||
|
||||
@classmethod
|
||||
def type_(cls) -> str:
|
||||
return "auto_agent"
|
||||
|
||||
@field_validator('x1', 'y1', 'x2', 'y2', mode='before')
|
||||
@classmethod
|
||||
def convert_optional_int(cls, v: Any) -> Optional[int]:
|
||||
"""
|
||||
将字符串类型的坐标值转换为 int,空字符串转为 None
|
||||
|
||||
Args:
|
||||
v: 原始值(可能是 str, int, None)
|
||||
|
||||
Returns:
|
||||
Optional[int]: 转换后的值
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
# 处理空字符串
|
||||
if v.strip() == "":
|
||||
return None
|
||||
# 尝试转换为 int
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
logger.warning(f"无法将字符串 '{v}' 转换为 int,返回 None")
|
||||
return None
|
||||
if isinstance(v, int): # 如果原始值本身就是int类型,直接return
|
||||
return v
|
||||
if isinstance(v, float): # 如果解析出了float,强转成int再return
|
||||
return int(v)
|
||||
|
||||
logger.warning(f"意外的类型 {type(v)},返回 None")
|
||||
return None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
将模型转换为字典,用于生成JSON字符串
|
||||
|
||||
Returns:
|
||||
dict: 模型数据字典
|
||||
"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"Action": self.Action,
|
||||
"x1": self.x1,
|
||||
"y1": self.y1,
|
||||
"x2": self.x2,
|
||||
"y2": self.y2,
|
||||
"key": self.key,
|
||||
"content": self.content,
|
||||
"direction": self.direction
|
||||
}
|
||||
|
||||
class YosugaUITARSRequestData(LLMCoreAnalysisBase):
|
||||
"""
|
||||
自动化操作场景的LLM对自动化agent的调用的输出解析模型
|
||||
"""
|
||||
|
||||
type: str = Field(default="call_auto_agent", description="固定为call_auto_agent")
|
||||
llm_translation: str = Field(default="", description="针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可")
|
||||
|
||||
@classmethod
|
||||
def type_(cls) -> str:
|
||||
return "call_auto_agent"
|
||||
|
||||
|
||||
|
||||
|
||||
class YosugaLive2DResponseData(LLMCoreAnalysisBase):
|
||||
"""
|
||||
Live2D控制场景的LLM输出数据模型 TODO
|
||||
"""
|
||||
|
||||
type: str = Field(default="live2d_control", description="固定为live2d_control")
|
||||
parameter: str = Field(..., description="参数名")
|
||||
value: float = Field(..., description="目标值")
|
||||
duration: int = Field(default=500, description="过渡时间(ms)")
|
||||
|
||||
@classmethod
|
||||
def type_(cls) -> str:
|
||||
return "live2d_control"
|
||||
|
||||
class YosugaEmbeddedResponseData(LLMCoreAnalysisBase):
|
||||
"""
|
||||
嵌入式设备场景的LLM输出数据模型 TODO
|
||||
"""
|
||||
|
||||
type: str = Field(default="embedded_control", description="固定为embedded_control")
|
||||
device_id: str = Field(..., description="设备ID")
|
||||
command: str = Field(..., description="控制指令")
|
||||
params: Optional[Dict[str, Any]] = Field(default=None, description="参数")
|
||||
|
||||
@classmethod
|
||||
def type_(cls) -> str:
|
||||
return "embedded_control"
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
from loguru import logger
|
||||
|
||||
# 注册所有数据模型
|
||||
LLMCoreAnalysisManager.register(YosugaAudioResponseData)
|
||||
LLMCoreAnalysisManager.register(YosugaUITARSResponseData)
|
||||
LLMCoreAnalysisManager.register(YosugaLive2DResponseData)
|
||||
LLMCoreAnalysisManager.register(YosugaEmbeddedResponseData)
|
||||
|
||||
logger.info("=== LLM输出解析模块测试 ===")
|
||||
|
||||
# 测试单对象解析
|
||||
print("\n【测试1:单对象自动识别】")
|
||||
llm_output = '''
|
||||
```json
|
||||
[{
|
||||
"type": "audio_text",
|
||||
"response_text": "收到!我会微笑回应",
|
||||
"emotion": "cheerful"
|
||||
}]
|
||||
```
|
||||
'''
|
||||
|
||||
response = LLMCoreAnalysisManager.parse(llm_output)
|
||||
print(f"类型: {response}")
|
||||
|
||||
# 3. 测试多对象解析
|
||||
print("\n【测试2:多对象混合响应】")
|
||||
multi_output = '''
|
||||
```json
|
||||
[
|
||||
{
|
||||
"type": "auto_agent",
|
||||
"Action": "click",
|
||||
"x1": "100",
|
||||
"y1": "200"
|
||||
},
|
||||
{
|
||||
"type": "live2d_control",
|
||||
"parameter": "ParamEyeLOpen",
|
||||
"value": 0.8,
|
||||
"duration": 300
|
||||
}
|
||||
]
|
||||
```
|
||||
'''
|
||||
|
||||
results = LLMCoreAnalysisManager.parse(multi_output)
|
||||
print(f"类型: {results}")
|
||||
@@ -0,0 +1,320 @@
|
||||
# llm_core/llm_core_dispatcher.py
|
||||
|
||||
"""
|
||||
动作分发器模块
|
||||
负责将解析后的LLM输出对象路由到对应的业务处理器
|
||||
支持同步/异步处理,提供回退机制与执行结果追踪
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
|
||||
from functools import wraps
|
||||
from loguru import logger
|
||||
|
||||
from src.server_core.llm_core.llm_core_analysis import LLMCoreAnalysisBase
|
||||
|
||||
# 处理器类型定义
|
||||
# 同步处理器:接收解析对象,返回执行结果
|
||||
SyncActionHandler = Callable[[LLMCoreAnalysisBase], Any]
|
||||
# 异步处理器:接收解析对象,返回协程
|
||||
AsyncActionHandler = Callable[[LLMCoreAnalysisBase], Any]
|
||||
|
||||
|
||||
def handler_error_wrapper(func: Callable) -> Callable:
|
||||
"""
|
||||
处理器错误包装装饰器
|
||||
统一捕获异常并记录日志,避免单个处理失败导致整体崩溃
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(f"处理器 {func.__name__} 执行失败: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(f"异步处理器 {func.__name__} 执行失败: {e}")
|
||||
raise
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
class LLMCoreActionDispatcher:
|
||||
"""
|
||||
LLM动作分发中心
|
||||
|
||||
职责:
|
||||
1. 管理类型到处理器的映射关系
|
||||
2. 支持同步与异步处理器注册
|
||||
3. 执行分发并收集处理结果
|
||||
4. 提供未注册类型的回退机制
|
||||
5. 支持批量处理与结果聚合
|
||||
|
||||
使用示例:
|
||||
# 注册同步处理器
|
||||
def handle_audio(data: YosugaAudioResponseData) -> dict:
|
||||
return {"status": "spoken", "text": data.response_text}
|
||||
|
||||
LLMCoreActionDispatcher.register("audio_text", handle_audio)
|
||||
|
||||
# 注册异步处理器
|
||||
async def handle_ui(data: YosugaUITARSResponseData):
|
||||
await perform_click(data.x1, data.y1)
|
||||
return {"status": "clicked"}
|
||||
|
||||
LLMCoreActionDispatcher.register_async("auto_agent", handle_ui)
|
||||
|
||||
# 执行分发
|
||||
results = LLMCoreActionDispatcher.execute(parsed_objects)
|
||||
"""
|
||||
|
||||
# 类变量存储所有注册的处理器
|
||||
_sync_handlers: ClassVar[Dict[str, SyncActionHandler]] = {}
|
||||
_async_handlers: ClassVar[Dict[str, AsyncActionHandler]] = {}
|
||||
_fallback_handler: ClassVar[Optional[Union[SyncActionHandler, AsyncActionHandler]]] = None
|
||||
|
||||
@classmethod
|
||||
def register(cls, type_id: str, handler: SyncActionHandler) -> None:
|
||||
"""
|
||||
注册同步处理器
|
||||
|
||||
Args:
|
||||
type_id: 与 LLMCoreAnalysisBase.type_() 返回值匹配的标识符
|
||||
handler: 同步函数,接收解析对象并返回执行结果
|
||||
|
||||
Raises:
|
||||
ValueError: 处理器不是可调用的函数
|
||||
"""
|
||||
if not callable(handler):
|
||||
raise ValueError(f"处理器必须是可调用函数,收到: {type(handler)}")
|
||||
|
||||
# 检查是否已注册,防止覆盖
|
||||
if type_id in cls._sync_handlers or type_id in cls._async_handlers:
|
||||
logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖")
|
||||
|
||||
cls._sync_handlers[type_id] = handler_error_wrapper(handler)
|
||||
logger.success(f"注册同步处理器: {type_id} → {handler.__name__}")
|
||||
|
||||
@classmethod
|
||||
def register_async(cls, type_id: str, handler: AsyncActionHandler) -> None:
|
||||
"""
|
||||
注册异步处理器
|
||||
|
||||
Args:
|
||||
type_id: 类型标识符
|
||||
handler: 异步函数,接收解析对象并返回协程
|
||||
|
||||
Raises:
|
||||
ValueError: 处理器不是有效的异步函数
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(handler):
|
||||
raise ValueError(f"异步处理器必须是协程函数,收到: {type(handler)}")
|
||||
|
||||
if type_id in cls._sync_handlers or type_id in cls._async_handlers:
|
||||
logger.warning(f"类型 '{type_id}' 已被注册,将被覆盖")
|
||||
|
||||
cls._async_handlers[type_id] = handler_error_wrapper(handler)
|
||||
logger.success(f"注册异步处理器: {type_id} → {handler.__name__}")
|
||||
|
||||
@classmethod
|
||||
def set_fallback(cls, handler: Union[SyncActionHandler, AsyncActionHandler]) -> None:
|
||||
"""
|
||||
设置回退处理器(未注册类型的默认处理)
|
||||
|
||||
Args:
|
||||
handler: 同步或异步函数,处理所有未匹配的类型
|
||||
"""
|
||||
wrapped = handler_error_wrapper(handler)
|
||||
cls._fallback_handler = wrapped
|
||||
handler_type = "异步" if asyncio.iscoroutinefunction(handler) else "同步"
|
||||
logger.info(f"设置{handler_type}回退处理器: {handler.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_handler(cls, type_id: str) -> Optional[Union[SyncActionHandler, AsyncActionHandler]]:
|
||||
"""获取指定类型的处理器"""
|
||||
# 优先返回同步处理器
|
||||
if type_id in cls._sync_handlers:
|
||||
return cls._sync_handlers[type_id]
|
||||
# 其次返回异步处理器
|
||||
if type_id in cls._async_handlers:
|
||||
return cls._async_handlers[type_id]
|
||||
# 返回回退处理器
|
||||
return cls._fallback_handler
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
analysis_results: List[LLMCoreAnalysisBase],
|
||||
run_async: bool = False
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
执行分发处理
|
||||
|
||||
Args:
|
||||
analysis_results: 解析器返回的对象列表
|
||||
run_async: 是否启用异步执行模式(需要业务代码支持asyncio)
|
||||
|
||||
Returns:
|
||||
执行结果字典:
|
||||
{
|
||||
"success": [处理成功的结果列表],
|
||||
"failed": [{"type": "...", "error": "..."}],
|
||||
"skipped": [跳过的类型列表]
|
||||
}
|
||||
"""
|
||||
if not analysis_results:
|
||||
logger.warning("无对象需要分发")
|
||||
return {"success": [], "failed": [], "skipped": []}
|
||||
|
||||
if run_async and (cls._async_handlers or asyncio.iscoroutinefunction(cls._fallback_handler)):
|
||||
# 异步执行模式(需要事件循环)
|
||||
return asyncio.run(cls._execute_async(analysis_results))
|
||||
|
||||
# 默认同步执行
|
||||
return cls._execute_sync(analysis_results)
|
||||
|
||||
@classmethod
|
||||
def _execute_sync(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]:
|
||||
"""同步批量执行"""
|
||||
outputs = {"success": [], "failed": [], "skipped": []}
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
type_id = result.type
|
||||
logger.debug(f"[{idx}] 分发类型: {type_id}")
|
||||
|
||||
handler = cls.get_handler(type_id)
|
||||
if not handler:
|
||||
outputs["skipped"].append(type_id)
|
||||
logger.warning(f"无处理器,跳过: {type_id}")
|
||||
continue
|
||||
|
||||
# 执行同步处理器
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
logger.error(f"异步处理器 '{type_id}' 不能在同步模式下执行")
|
||||
outputs["failed"].append({"type": type_id, "error": "异步处理器需要run_async=True"})
|
||||
continue
|
||||
|
||||
try:
|
||||
output = handler(result)
|
||||
outputs["success"].append({
|
||||
"type": type_id,
|
||||
"output": output,
|
||||
"index": idx
|
||||
})
|
||||
logger.success(f"[{idx}] 处理成功: {type_id}")
|
||||
except Exception as e:
|
||||
outputs["failed"].append({
|
||||
"type": type_id,
|
||||
"error": str(e),
|
||||
"index": idx
|
||||
})
|
||||
logger.error(f"[{idx}] 处理失败 {type_id}: {e}")
|
||||
|
||||
cls._log_summary(outputs, len(results))
|
||||
return outputs
|
||||
|
||||
@classmethod
|
||||
async def _execute_async(cls, results: List[LLMCoreAnalysisBase]) -> Dict[str, List[Any]]:
|
||||
"""异步批量执行"""
|
||||
outputs = {"success": [], "failed": [], "skipped": []}
|
||||
tasks = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
type_id = result.type
|
||||
logger.debug(f"[{idx}] 异步分发类型: {type_id}")
|
||||
|
||||
handler = cls.get_handler(type_id)
|
||||
if not handler:
|
||||
outputs["skipped"].append(type_id)
|
||||
logger.warning(f"无处理器,跳过: {type_id}")
|
||||
continue
|
||||
|
||||
# 创建协程任务
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
task = cls._run_async_handler(handler, result, idx, outputs)
|
||||
else:
|
||||
# 同步处理器在异步线程池中执行
|
||||
task = cls._run_sync_in_executor(handler, result, idx, outputs)
|
||||
|
||||
tasks.append(task)
|
||||
|
||||
# 并发执行所有任务
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
cls._log_summary(outputs, len(results))
|
||||
return outputs
|
||||
|
||||
@classmethod
|
||||
async def _run_async_handler(cls, handler, result, idx, outputs):
|
||||
"""运行异步处理器"""
|
||||
try:
|
||||
output = await handler(result)
|
||||
outputs["success"].append({
|
||||
"type": result.type,
|
||||
"output": output,
|
||||
"index": idx
|
||||
})
|
||||
logger.success(f"[{idx}] 异步处理成功: {result.type}")
|
||||
except Exception as e:
|
||||
outputs["failed"].append({
|
||||
"type": result.type,
|
||||
"error": str(e),
|
||||
"index": idx
|
||||
})
|
||||
logger.error(f"[{idx}] 异步处理失败 {result.type}: {e}")
|
||||
|
||||
@classmethod
|
||||
async def _run_sync_in_executor(cls, handler, result, idx, outputs):
|
||||
"""在线程池中运行同步处理器"""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
output = await loop.run_in_executor(None, handler, result)
|
||||
outputs["success"].append({
|
||||
"type": result.type,
|
||||
"output": output,
|
||||
"index": idx
|
||||
})
|
||||
logger.success(f"[{idx}] 同步处理器(线程池)成功: {result.type}")
|
||||
except Exception as e:
|
||||
outputs["failed"].append({
|
||||
"type": result.type,
|
||||
"error": str(e),
|
||||
"index": idx
|
||||
})
|
||||
logger.error(f"[{idx}] 同步处理器(线程池)失败 {result.type}: {e}")
|
||||
|
||||
@classmethod
|
||||
def _log_summary(cls, outputs: Dict, total: int):
|
||||
"""输出处理摘要"""
|
||||
success_count = len(outputs["success"])
|
||||
failed_count = len(outputs["failed"])
|
||||
skipped_count = len(outputs["skipped"])
|
||||
|
||||
logger.info(
|
||||
f"分发完成 | 总计: {total} | 成功: {success_count} "
|
||||
f"失败: {failed_count} 跳过: {skipped_count}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空所有处理器(测试/热重载用)"""
|
||||
cls._sync_handlers.clear()
|
||||
cls._async_handlers.clear()
|
||||
cls._fallback_handler = None
|
||||
logger.info("已清空所有动作处理器")
|
||||
|
||||
@classmethod
|
||||
def list_handlers(cls) -> Dict[str, str]:
|
||||
"""列出当前注册的所有处理器"""
|
||||
handlers = {}
|
||||
for type_id, handler in cls._sync_handlers.items():
|
||||
handlers[type_id] = f"同步: {handler.__name__}"
|
||||
for type_id, handler in cls._async_handlers.items():
|
||||
handlers[type_id] = f"异步: {handler.__name__}"
|
||||
return handlers
|
||||
@@ -0,0 +1,174 @@
|
||||
# llm_core/llm_core_prompt_manager.py
|
||||
|
||||
"""
|
||||
llm prompt 结构化信息管理
|
||||
用于将各种输入到llm_core的数据流结构化,并附加注释,以方便llm可以准确的理解并可以结构化返回相关内容
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Callable, List, Optional, Coroutine, Any, ClassVar, Dict
|
||||
from src.server_core.llm_core.llm_core_prompts import YOSUGA_SYSTEM_PROMPT_SCH
|
||||
|
||||
class LLMCorePromptBase(BaseModel, ABC):
|
||||
"""LLM 提示词基类:定义输入输出结构"""
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
"""返回该提示词类型的唯一标识"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def describe_input(self) -> str:
|
||||
"""生成输入格式的自然语言描述(填充到 {InputInfo})"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def describe_output(self) -> str:
|
||||
"""生成输出格式的自然语言描述(填充到 {OutputInfo})"""
|
||||
pass
|
||||
|
||||
def to_json(self) -> str:
|
||||
return self.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data: str):
|
||||
return cls.model_validate_json(json_data)
|
||||
|
||||
|
||||
class LLMCorePromptManager(LLMCorePromptBase):
|
||||
"""聚合所有子类,生成完整的 prompt 信息"""
|
||||
# 类变量:存储所有注册的子类
|
||||
_registry: ClassVar[Dict[str, LLMCorePromptBase]] = {}
|
||||
|
||||
def get_registry_size(self) -> int:
|
||||
return len(self._registry)
|
||||
|
||||
def register(self, son: LLMCorePromptBase) -> None:
|
||||
self._registry[son.type()] = son
|
||||
|
||||
def type(self) -> str:
|
||||
return "manager"
|
||||
|
||||
def describe_input(self) -> str:
|
||||
"""聚合所有子类的输入描述"""
|
||||
return "\n".join(
|
||||
f"[{type_id}]{son.describe_input()}\n"
|
||||
for type_id, son in self._registry.items()
|
||||
)
|
||||
|
||||
def describe_output(self) -> str:
|
||||
"""聚合所有子类的输出描述"""
|
||||
return "\n".join(
|
||||
f"[{type_id}]{son.describe_output()}\n"
|
||||
for type_id, son in self._registry.items()
|
||||
)
|
||||
|
||||
class YosugaAudioASRText(LLMCorePromptBase):
|
||||
"""音频ASR文本输入场景"""
|
||||
def type(self) -> str:
|
||||
return "用户语音asr信息"
|
||||
|
||||
def describe_input(self) -> str:
|
||||
return '''
|
||||
当用户通过语音与Yosuga交互时,你会收到如下JSON结构:
|
||||
{
|
||||
"text": "用户说的话(字符串)",
|
||||
"confidence": 0.95
|
||||
}
|
||||
- `text`: 语音转写的原始文本,可能包含口语化表达或识别错误
|
||||
- `confidence`: ASR引擎的识别置信度,低于0.8需警惕识别错误
|
||||
'''
|
||||
|
||||
def describe_output(self) -> str:
|
||||
return '''
|
||||
针对用户音频识别出的文本内容输入,你应该按以下JSON格式回复:
|
||||
{
|
||||
"type": "固定为audio_text",
|
||||
"response_text": "你的回复文本(字符串)",
|
||||
"emotion": "neutral",
|
||||
"action": "none"
|
||||
}
|
||||
- `response_text`: 给用户的自然语言回复
|
||||
- `emotion`: 回复的情感基调,可选值:neutral/cheerful/sad/angry
|
||||
- `action`: 触发的动作指令,如"wave_hand"、"nod"等,"none"表示无动作
|
||||
'''
|
||||
|
||||
class YosugaEmbedded(LLMCorePromptBase):
|
||||
"""嵌入式设备输入场景"""
|
||||
def type(self) -> str:
|
||||
pass
|
||||
|
||||
def describe_input(self) -> str:
|
||||
pass
|
||||
|
||||
def describe_output(self) -> str:
|
||||
pass
|
||||
|
||||
class YosugaUITARS(LLMCorePromptBase):
|
||||
"""自动化操作构建场景"""
|
||||
def type(self) -> str:
|
||||
return "自动化操作信息"
|
||||
|
||||
def describe_input(self) -> str:
|
||||
return '''
|
||||
当你尝试调用自动化agent的时候,自动化agent将会返回下面的内容作为输入给你,自动化agent的返回信息很重要,有的任务不是一次就可以完成的,此时你可能需要多次调用自动化agent,直到agent返回finished(任务完成):
|
||||
{
|
||||
"auto_agent":"
|
||||
Thought: 自动化agent的推理过程,可以考虑二次加工后作为回复用户的内容
|
||||
Action: 对应的自动化动作[包括click(单击坐标), left_double(双击坐标), right_single(右键单击), drag(拖拽), hotkey(快捷键), type(输入文本), scroll(滚动), wait(等待), finished(任务完成)]
|
||||
"
|
||||
}
|
||||
'''
|
||||
|
||||
def describe_output(self) -> str:
|
||||
return '''
|
||||
1. 当用户试图进行一些自动化操作时,你可以通过返回以下json来调用自动化agent,调用的时候也可以一并回复用户一些文本内容:
|
||||
{
|
||||
"type": "固定为call_auto_agent",
|
||||
"llm_translation": "针对用户的意图的转译,如果用户的意图足够明确,直接照抄即可,注意转译的语言必须是中文,可以混合部分英文(应用名称),这个和聊天交流的语言不统一。"
|
||||
}
|
||||
2. 针对自动化agent操作输入,你应该按以下JSON格式回复。如果你发起了自动化agent请求之后,却没有返回下面的JSON内容,那么你的请求并不会作用到用户的设备,所以你调用完自动化agent一定要及时返回下面的JSON内容:
|
||||
{
|
||||
"type": "固定为auto_agent",
|
||||
"Action": "自动化agent返回的相应的自动化动作名称",
|
||||
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
|
||||
"y1": "某个操作的y1",
|
||||
"x2": "某个操作的x2",
|
||||
"y2": "某个操作的y2",
|
||||
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
|
||||
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
|
||||
"direction": "滚动方向(down or up),若当前操作没有该字段信息,此处内容为空即可,同时需要关心自动化agent回复的x1,和y1坐标"
|
||||
}
|
||||
自动化agent返回的操作信息不一定包括JSON的全部字段,例如某次返回只有key的内容,或者只有content的内容。
|
||||
针对自动化agent操作输入的返回,若没有相关内容可以留空相关字段,请不要省略掉任何字段名称。
|
||||
|
||||
注意:自动化agent的状态可见YosugaSystemState表。
|
||||
'''
|
||||
|
||||
class YosugaLive2DControl(LLMCorePromptBase):
|
||||
"""对Yosuga Live2D控制场景"""
|
||||
def type(self) -> str:
|
||||
return "Yosuga Live2D控制信息"
|
||||
|
||||
def describe_input(self) -> str:
|
||||
pass
|
||||
|
||||
def describe_output(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 注册所有prompt处理器
|
||||
manager = LLMCorePromptManager()
|
||||
manager.register(YosugaAudioASRText())
|
||||
manager.register(YosugaUITARS())
|
||||
|
||||
# 生成最终 system prompt
|
||||
system_prompt = YOSUGA_SYSTEM_PROMPT_SCH.format(
|
||||
InputInfo=manager.describe_input(),
|
||||
OutputInfo=manager.describe_output(),
|
||||
RoleSetting="...",
|
||||
Language="ja",
|
||||
Memory = "",
|
||||
SystemStateTable=""
|
||||
)
|
||||
print(system_prompt)
|
||||
@@ -0,0 +1,67 @@
|
||||
# llm_core/llm_core_prompts.py
|
||||
|
||||
YOSUGA_SYSTEM_PROMPT_SCH = """
|
||||
你是Yosuga_server这个项目的核心LLM。
|
||||
你将会为用户完成各种任务,进行结构化的输出。
|
||||
首先向你介绍一下Yosuga这个项目:
|
||||
本项目的作者是Misakiotoha(みさきおとは[見崎音羽])。
|
||||
之所以叫Yosuga,这个词来源日语当中的单词"縁"的发音,其意思是"缘分,关系"。
|
||||
本项目分为三个部分:
|
||||
1. Yosuga:这是项目的前端部分,是Yosuga与用户交互的一层,采用C++20 + Qt6.6.3编写,使用到的核心外部库为Live2D For C++ SDK。
|
||||
2. Yosuga_server:这是项目的后端部分,是Yosuga的核心,采用python3.11编写,使用到的外部库较多,负责联系项目的各个部分。
|
||||
3. Yosuga_embedded:这是项目的拓展部分,实现了Yosuga对嵌入式设备拥有几乎完全的自定义控制能力,采用C语言编写,只使用到了cJSON库,平台无关,增强了Yosuga与外界的交互能力。
|
||||
|
||||
作为Yosuga_server的核心LLM,你的任务包括以下的内容:
|
||||
1. 接受自定义的结构化的信息(json),理解并给出正确的回复。
|
||||
2. 输出自定义的结构化的信息(json),给出正确的回复。
|
||||
|
||||
对于你会接受到什么信息,以及如何进行正确的返回,都会在下面做出解释。
|
||||
|
||||
你接受到的信息可能有:
|
||||
1. 来自Yosuga的对于Live2D模型控制的结构化数据。
|
||||
2. 来自Yosuga_embedded的可控制的嵌入式设备信息的结构化摘要数据。
|
||||
3. 综合了各种信息的结构化请求数据(例如同时包括用户所说的话语与其他决策信息)。
|
||||
|
||||
你的工作流程:
|
||||
1. 接受来自用户的请求,理解用户的意图(例如用户只有聊天的目的,或者是用户想进行一些操作)
|
||||
2. 如果用户只希望聊天,那就只和用户聊天即可
|
||||
3. 如果用户还有其他意图,你需要根据你的能力来完成用户的需求,如果用户的需求不在你的能力范围内,可以婉拒。
|
||||
|
||||
你的能力范围:
|
||||
1. 你可以和用户进行聊天,当用户只是想和你聊天,那就只和用户进行聊天
|
||||
2. 你可以主动调用自动化agent,当用户有着一些自动化操作的需求时候,你可以按一定调用格式去主动调用自动化agent模型,并理解自动化模型的返回内容,以规定的格式返回给服务端最终作用到用户的设备。
|
||||
3. 用户也可能会想和你一起玩一些游戏,这个时候你需要理解用户的意图,并调用自动化agent,同时返回一些聊天内容回复,以此和用户一起玩游戏。
|
||||
4. 你可以和各种种类的嵌入式设备进行交互,当用户有着一些现实的需求时候,你可以根据当前系统的状态,可控制的设备信息等信息综合判断,通过规定的方式返回规定的内容操作嵌入式设备,以满足用户的需求。
|
||||
|
||||
同时你还需要进行角色扮演,完美地扮演符合设定的角色,也要兼顾对于各种问题的处理。
|
||||
|
||||
以下是你会接收到的结构化的信息(带解释):
|
||||
{InputInfo}
|
||||
|
||||
以下是你需要根据不同的结构化信息需要输出的结构化信息(带解释):
|
||||
{OutputInfo}
|
||||
|
||||
你应当扮演的角色设定为:
|
||||
{RoleSetting}
|
||||
|
||||
你和用户聊天时候的回复语言为:
|
||||
{Language}
|
||||
|
||||
|
||||
注意:你的所有回复必须是一个 JSON 数组,即便只有单项任务。格式为 [ {{ "type": "...", ... }} ]。禁止输出JSON数组外的任何解释性文字,同时也禁止在JSON的回复内容中加入任何表情符号。
|
||||
注意:你接收到的结构化信息,有时可能会有多个,例如用户在问你的同时,自动化agent也返回了结果,这样的情况很常见,你需要按要求返回一个或者多个json即可。
|
||||
注意:有时候用户的请求可能既涉及到了聊天,也涉及到了自动化agent的调用,那么你在返回的json数组当中加入应当返回的内容即可,可以是聊天内容回复加上自动化agent调用。
|
||||
注意:有时用户的请求会比较复杂,需要你连续的调用自动化agent,根据自动化agent的输出,去一步步的完成用户的请求。
|
||||
|
||||
还有一些是你必须要参考的内容(例如与用户过去的记忆,系统实时状态表等内容):
|
||||
与用户过去的记忆:
|
||||
{Memory}
|
||||
|
||||
系统实时状态表:
|
||||
{SystemStateTable}
|
||||
|
||||
"""
|
||||
|
||||
YOSUGA_SYSTEM_PROMPT_EN = """
|
||||
|
||||
"""
|
||||
@@ -0,0 +1,367 @@
|
||||
# llm_core/llm_core_token.py
|
||||
|
||||
"""
|
||||
Token 计算与管理模块
|
||||
支持双数据源智能切换:
|
||||
- 优先使用大模型 API 返回的精确 usage 数据
|
||||
- 当 API 不返回时,自动回退到 tiktoken 手动估算
|
||||
"""
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
import tiktoken
|
||||
from loguru import logger
|
||||
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
"""Token 使用量统计"""
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
source: str = field(default="manual", repr=True) # "api" | "manual"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"prompt_tokens": self.prompt_tokens,
|
||||
"completion_tokens": self.completion_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"source": self.source
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerInfo:
|
||||
"""Tokenizer 元数据"""
|
||||
model_name: str
|
||||
encoding_name: str
|
||||
is_fallback: bool
|
||||
estimated_accuracy: str # "high" | "medium" | "low"
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""
|
||||
Token 管理核心类
|
||||
智能数据源切换:API返回 > 手动估算
|
||||
"""
|
||||
def __init__(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
self.tokenizer = self._get_tokenizer(model_name)
|
||||
# 存储最近一次 API 返回的 usage
|
||||
self._last_api_usage: Optional[TokenUsage] = None
|
||||
# API 数据有效期(秒),超过此时间则视为过期,回退到手动计算
|
||||
self._api_usage_expiry: float = 30.0
|
||||
# 记录 API usage 的获取时间
|
||||
self._last_api_usage_time: float = 0.0
|
||||
|
||||
logger.info(
|
||||
f"TokenManager 初始化完成 | "
|
||||
f"模型: {model_name} | "
|
||||
f"编码器: {self.tokenizer.name}"
|
||||
)
|
||||
|
||||
def _get_tokenizer(self, model_name: str) -> tiktoken.Encoding:
|
||||
"""获取 tokenizer(与之前实现相同,省略重复代码)"""
|
||||
model_tokenizer_map = {
|
||||
"qwen": "gpt-3.5-turbo",
|
||||
"llama": "gpt-3.5-turbo",
|
||||
"gemma": "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
try:
|
||||
return tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
for prefix, mapped_model in model_tokenizer_map.items():
|
||||
if prefix in model_name.lower():
|
||||
logger.info(
|
||||
f"模型 '{model_name}' 映射到 tokenizer '{mapped_model}'"
|
||||
)
|
||||
return tiktoken.encoding_for_model(mapped_model)
|
||||
|
||||
logger.warning(
|
||||
f"tiktoken 不支持模型 '{model_name}',"
|
||||
f"降级使用 'cl100k_base' 编码器"
|
||||
)
|
||||
return tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def record_api_usage(self, usage: Optional[Union[Dict[str, int], TokenUsage]]) -> None:
|
||||
"""
|
||||
记录大模型 API 返回的 usage 数据
|
||||
|
||||
Args:
|
||||
usage: API 返回的 usage 数据(dict 或 TokenUsage 对象)
|
||||
如果为 None 或空,则忽略
|
||||
"""
|
||||
if not usage:
|
||||
logger.debug("API usage 为空,未记录")
|
||||
return
|
||||
|
||||
# 在底层已经统一好了不同大模型对于usage的处理
|
||||
if isinstance(usage, TokenUsage):
|
||||
api_usage = usage
|
||||
api_usage.source = "api"
|
||||
else:
|
||||
api_usage = TokenUsage(
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
source="api"
|
||||
)
|
||||
|
||||
self._last_api_usage = api_usage
|
||||
self._last_api_usage_time = time.time()
|
||||
|
||||
logger.debug(
|
||||
f"记录 API usage | "
|
||||
f"Prompt: {api_usage.prompt_tokens} | "
|
||||
f"Completion: {api_usage.completion_tokens} | "
|
||||
f"Total: {api_usage.total_tokens}"
|
||||
)
|
||||
|
||||
def get_current_usage(self, prefer_api: bool = True) -> TokenUsage:
|
||||
"""
|
||||
获取当前 Token 使用情况(智能数据源切换)
|
||||
|
||||
Args:
|
||||
prefer_api: 是否优先使用 API 数据(默认 True)
|
||||
|
||||
Returns:
|
||||
TokenUsage 对象,包含数据来源标记
|
||||
|
||||
Notes:
|
||||
- 如果 prefer_api=True 且 API 数据在有效期内,直接返回 API 数据
|
||||
- 否则回退到手动估算
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查 API 数据是否有效
|
||||
if prefer_api and self._last_api_usage:
|
||||
time_elapsed = current_time - self._last_api_usage_time
|
||||
if time_elapsed <= self._api_usage_expiry:
|
||||
logger.debug(
|
||||
f"使用 API Token 数据({time_elapsed:.1f}s 内)| "
|
||||
f"{self._last_api_usage.prompt_tokens} + "
|
||||
f"{self._last_api_usage.completion_tokens} = "
|
||||
f"{self._last_api_usage.total_tokens}"
|
||||
)
|
||||
return self._last_api_usage
|
||||
|
||||
# API 数据无效或无数据,回退到手动估算
|
||||
logger.debug("API usage 无效/过期,回退到手动估算")
|
||||
return self._estimate_manual_usage()
|
||||
|
||||
def _estimate_manual_usage(self) -> TokenUsage:
|
||||
"""内部方法:创建空的手动 usage(占位符)"""
|
||||
return TokenUsage(source="manual")
|
||||
|
||||
def get_context_usage(self, history: List[Any]) -> TokenUsage:
|
||||
"""
|
||||
计算对话上下文的 Token 占用(必须手动计算)
|
||||
|
||||
注意:
|
||||
- 上下文占用无法从 API 获得,必须手动估算
|
||||
- 此方法不涉及 _last_api_usage
|
||||
|
||||
Args:
|
||||
history: 历史消息列表
|
||||
|
||||
Returns:
|
||||
TokenUsage 对象,source 始终为 "manual"
|
||||
"""
|
||||
tokens = self.count_messages_tokens(history)
|
||||
return TokenUsage(
|
||||
prompt_tokens=tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=tokens,
|
||||
source="manual"
|
||||
)
|
||||
|
||||
def count_text_tokens(self, text: str) -> int:
|
||||
"""计算单段文本的 token 数量(与之前相同)"""
|
||||
if not isinstance(text, str) or not text:
|
||||
return 0
|
||||
return len(self.tokenizer.encode(text))
|
||||
|
||||
def count_messages_tokens(
|
||||
self,
|
||||
messages: List[Any],
|
||||
tokens_per_message: int = 3
|
||||
) -> int:
|
||||
"""计算消息列表的总 token 数量(与之前相同,优化实现)"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
num_tokens = 0
|
||||
for msg in messages:
|
||||
# 统一转换为字典
|
||||
if hasattr(msg, "to_dict"):
|
||||
msg_dict = msg.to_dict()
|
||||
elif hasattr(msg, "model_dump"):
|
||||
msg_dict = msg.model_dump()
|
||||
else:
|
||||
msg_dict = msg
|
||||
|
||||
# 计算内容和 role 的 token
|
||||
num_tokens += self.count_text_tokens(msg_dict.get("content", ""))
|
||||
num_tokens += self.count_text_tokens(msg_dict.get("role", ""))
|
||||
num_tokens += tokens_per_message
|
||||
|
||||
# 加上回复前缀的开销
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
|
||||
def estimate_chat_tokens(
|
||||
self,
|
||||
system_prompt: Optional[str],
|
||||
history: List[Any],
|
||||
current_input: str
|
||||
) -> TokenUsage:
|
||||
"""
|
||||
估算一次完整对话所需的 token 数量
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
history: 历史消息列表
|
||||
current_input: 当前用户输入
|
||||
|
||||
Returns:
|
||||
TokenUsage 对象,source 始终为 "manual"
|
||||
|
||||
Notes:
|
||||
- 这是预估,不是实际 API 返回值
|
||||
- 用于调试和前置检查
|
||||
"""
|
||||
messages = []
|
||||
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
messages.extend(history)
|
||||
messages.append({"role": "user", "content": current_input})
|
||||
|
||||
total = self.count_messages_tokens(messages)
|
||||
|
||||
return TokenUsage(
|
||||
prompt_tokens=total,
|
||||
completion_tokens=0,
|
||||
total_tokens=total,
|
||||
source="manual"
|
||||
)
|
||||
|
||||
def format_usage_log(
|
||||
self,
|
||||
usage: Optional[Union[Dict[str, int], TokenUsage]] = None,
|
||||
source: str = "AUTO"
|
||||
) -> str:
|
||||
"""
|
||||
格式化 token 使用日志
|
||||
|
||||
Args:
|
||||
usage: usage 数据(可选)。如果为 None,自动获取当前 usage
|
||||
source: 数据来源标记("AUTO" | "API" | "MANUAL" | "CONTEXT")
|
||||
|
||||
Returns:
|
||||
格式化的日志字符串
|
||||
"""
|
||||
# 如果未提供 usage,自动获取
|
||||
if usage is None:
|
||||
if source == "CONTEXT":
|
||||
# 上下文场景必须手动计算
|
||||
usage_obj = self.get_context_usage([])
|
||||
else:
|
||||
# 自动选择最佳数据源
|
||||
usage_obj = self.get_current_usage(prefer_api=True)
|
||||
else:
|
||||
# 使用提供的 usage
|
||||
if isinstance(usage, TokenUsage):
|
||||
usage_obj = usage
|
||||
else:
|
||||
usage_obj = TokenUsage(
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
source="api" if source == "API" else "manual"
|
||||
)
|
||||
|
||||
# 根据来源选择前缀和图标
|
||||
prefix_map = {
|
||||
"API": "API Token统计",
|
||||
"MANUAL": "手动估算",
|
||||
"CONTEXT": "上下文占用",
|
||||
"AUTO": "Token统计"
|
||||
}
|
||||
prefix = prefix_map.get(source, "Token统计")
|
||||
|
||||
# 添加数据来源标记
|
||||
source_icon = "⚡" if usage_obj.source == "api" else "🧮"
|
||||
|
||||
# 格式化输出
|
||||
if usage_obj.completion_tokens > 0:
|
||||
return (
|
||||
f"{prefix} {source_icon} | "
|
||||
f"Prompt: {usage_obj.prompt_tokens} | "
|
||||
f"Completion: {usage_obj.completion_tokens} | "
|
||||
f"Total: {usage_obj.total_tokens}"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"{prefix} {source_icon} | "
|
||||
f"Total: {usage_obj.total_tokens}"
|
||||
)
|
||||
|
||||
def get_tokenizer_info(self) -> TokenizerInfo:
|
||||
"""获取当前 tokenizer 的详细信息(与之前相同)"""
|
||||
if "cl100k_base" in self.tokenizer.name and "gpt-3.5" not in self.model_name:
|
||||
accuracy = "low"
|
||||
elif self.model_name in self.tokenizer.name:
|
||||
accuracy = "high"
|
||||
else:
|
||||
accuracy = "medium"
|
||||
|
||||
return TokenizerInfo(
|
||||
model_name=self.model_name,
|
||||
encoding_name=self.tokenizer.name,
|
||||
is_fallback="cl100k_base" in self.tokenizer.name,
|
||||
estimated_accuracy=accuracy
|
||||
)
|
||||
|
||||
def is_token_limit_approaching(
|
||||
self,
|
||||
current_tokens: int,
|
||||
limit: int,
|
||||
threshold: float = 0.85
|
||||
) -> bool:
|
||||
"""判断 token 使用量是否接近限制(与之前相同)"""
|
||||
return current_tokens > limit * threshold
|
||||
|
||||
def calculate_chunk_size(
|
||||
self,
|
||||
available_tokens: int,
|
||||
safety_margin: float = 0.1
|
||||
) -> int:
|
||||
"""计算安全的消息块大小(与之前相同)"""
|
||||
return int(available_tokens * (1 - safety_margin))
|
||||
|
||||
def clear_api_usage_cache(self):
|
||||
"""清空 API usage 缓存(用于测试)"""
|
||||
self._last_api_usage = None
|
||||
self._last_api_usage_time = 0.0
|
||||
logger.debug("API usage 缓存已清空")
|
||||
|
||||
|
||||
# 单例工厂
|
||||
class TokenManagerFactory:
|
||||
"""TokenManager 工厂类,支持缓存复用"""
|
||||
_instances: Dict[str, TokenManager] = {}
|
||||
|
||||
@classmethod
|
||||
def get_manager(cls, model_name: str) -> TokenManager:
|
||||
if model_name not in cls._instances:
|
||||
cls._instances[model_name] = TokenManager(model_name)
|
||||
return cls._instances[model_name]
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._instances.clear()
|
||||
logger.info("TokenManager 缓存已清空")
|
||||
@@ -0,0 +1,242 @@
|
||||
### 服务端llm核心
|
||||
服务端以AI为核心去驱动,进行各种调用。
|
||||
|
||||
本模块为服务端的核心llm模块,这个llm必须足够聪明,能够稳定返回结构化结构。
|
||||
|
||||
|
||||
llm_core负责实现:
|
||||
|
||||
|
||||
|
||||
|
||||
#### 类图
|
||||
```mermaid
|
||||
classDiagram
|
||||
class YosugaLLMCore {
|
||||
<<主控制器>>
|
||||
-ModelConfig model_config
|
||||
-LLMCoreConfig core_config
|
||||
-UnifiedLLM llm_client
|
||||
-TokenManager token_manager
|
||||
-LLMCorePromptManager prompt_manager
|
||||
-List[ChatMessage] _history
|
||||
-Lock _history_lock
|
||||
-Lock _config_lock
|
||||
+interact() Dict
|
||||
+register_prompt_module() void
|
||||
+register_action_handler() void
|
||||
+reload_model() void
|
||||
+get_context_stats() Dict
|
||||
}
|
||||
|
||||
class LLMCoreConfig {
|
||||
<<运行时配置>>
|
||||
-int max_context_tokens
|
||||
-bool enable_history
|
||||
-str language
|
||||
-str role_setting
|
||||
-bool auto_dispatch
|
||||
-bool dispatch_async
|
||||
-str memory
|
||||
-str system_state_table
|
||||
}
|
||||
|
||||
class UnifiedLLM {
|
||||
<<大模型调用层>>
|
||||
-ModelConfig config
|
||||
-BaseLLMClient client
|
||||
+chat() ModelResponse
|
||||
+complete() ModelResponse
|
||||
+stream_chat() Iterator
|
||||
+update_config() void
|
||||
}
|
||||
|
||||
class TokenManager {
|
||||
<<Token管理>>
|
||||
-str model_name
|
||||
-tiktoken.Encoding tokenizer
|
||||
-TokenUsage _last_api_usage
|
||||
+record_api_usage() void
|
||||
+get_current_usage() TokenUsage
|
||||
+get_context_usage() TokenUsage
|
||||
+count_messages_tokens() int
|
||||
+format_usage_log() str
|
||||
}
|
||||
|
||||
class LLMCorePromptManager {
|
||||
<<Prompt管理>>
|
||||
-Dict _registry
|
||||
+register() void
|
||||
+describe_input() str
|
||||
+describe_output() str
|
||||
}
|
||||
|
||||
class LLMCoreAnalysisManager {
|
||||
<<输出解析>>
|
||||
<<静态类>>
|
||||
-Dict _model_registry
|
||||
+register() void
|
||||
+parse() List[LLMCoreAnalysisBase]
|
||||
}
|
||||
|
||||
class LLMCoreActionDispatcher {
|
||||
<<动作分发>>
|
||||
<<静态类>>
|
||||
-Dict _sync_handlers
|
||||
-Dict _async_handlers
|
||||
-Callable _fallback_handler
|
||||
+register() void
|
||||
+register_async() void
|
||||
+execute() Dict
|
||||
}
|
||||
|
||||
class LLMCoreAnalysisBase {
|
||||
<<抽象基类>>
|
||||
<<模型数据基类>>
|
||||
#str type
|
||||
+type_() str
|
||||
+get_schema() Dict
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
<<消息实体>>
|
||||
-str role
|
||||
-str content
|
||||
-str name
|
||||
+to_dict() Dict
|
||||
}
|
||||
|
||||
class ModelResponse {
|
||||
<<响应实体>>
|
||||
-str content
|
||||
-str model
|
||||
-Dict usage
|
||||
-str finish_reason
|
||||
-Dict raw_response
|
||||
}
|
||||
|
||||
YosugaLLMCore --> LLMCoreConfig : 持有配置
|
||||
YosugaLLMCore --> UnifiedLLM : 调用大模型
|
||||
YosugaLLMCore --> TokenManager : 统计Token
|
||||
YosugaLLMCore --> LLMCorePromptManager : 管理Prompt
|
||||
YosugaLLMCore --> LLMCoreAnalysisManager : 解析输出
|
||||
YosugaLLMCore --> LLMCoreActionDispatcher : 分发动作
|
||||
YosugaLLMCore --> ChatMessage : 管理历史
|
||||
UnifiedLLM --> ModelResponse : 返回响应
|
||||
LLMCoreAnalysisManager --> LLMCoreAnalysisBase : 解析为
|
||||
TokenManager --> ModelResponse : 接收usage
|
||||
```
|
||||
|
||||
|
||||
#### 时序图
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client as 客户端
|
||||
participant Core as YosugaLLMCore
|
||||
participant Token as TokenManager
|
||||
participant Prompt as PromptManager
|
||||
participant LLM as UnifiedLLM
|
||||
participant Parser as AnalysisManager
|
||||
participant Dispatcher as ActionDispatcher
|
||||
|
||||
Client->>Core: interact(user_input)
|
||||
Note over Core: 输入预处理
|
||||
|
||||
Core->>Token: count_messages_tokens(_history)
|
||||
Token-->>Core: current_usage
|
||||
|
||||
Core->>Core: _maintain_context_limit()
|
||||
Note over Core: 检查溢出并清理
|
||||
|
||||
Core->>Prompt: get_system_prompt()
|
||||
Note over Prompt: 聚合InputInfo/OutputInfo
|
||||
|
||||
Prompt-->>Core: system_prompt
|
||||
|
||||
Core->>Core: _build_request_messages()
|
||||
Note over Core: 组装[system, history, user]
|
||||
|
||||
Core->>Token: estimate_chat_tokens()
|
||||
Token-->>Core: estimated_usage
|
||||
|
||||
Core->>LLM: chat(messages)
|
||||
Note over LLM: 调用底层API
|
||||
|
||||
LLM-->>Core: ModelResponse(content, usage)
|
||||
|
||||
Core->>Token: record_api_usage(usage)
|
||||
Note over Token: 优先使用API数据
|
||||
|
||||
Core->>Parser: parse(content)
|
||||
Note over Parser: JSON清洗+类型校验
|
||||
|
||||
Parser-->>Core: List[AnalysisObj]
|
||||
|
||||
Core->>Core: _add_to_history(user+assistant)
|
||||
Note over Core: 更新对话记忆
|
||||
|
||||
Core->>Dispatcher: execute(parsed_results)
|
||||
|
||||
par 分发处理
|
||||
Dispatcher->>Handler1: 同步处理(audio_text)
|
||||
Dispatcher->>Handler2: 异步处理(auto_agent)
|
||||
end
|
||||
|
||||
Dispatcher-->>Core: {"success": [], "failed": []}
|
||||
|
||||
Core-->>Client: 执行结果
|
||||
```
|
||||
|
||||
|
||||
#### 配置与模型热重载状态机
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> 初始化: YosugaLLMCore()
|
||||
|
||||
state 初始化 {
|
||||
[*] --> 加载配置: ModelConfig
|
||||
加载配置 --> 创建LLM客户端: UnifiedLLM
|
||||
创建LLM客户端 --> 注册默认Prompt: Audio/UITARS
|
||||
注册默认Prompt --> 初始化Token管理器: TokenManager
|
||||
}
|
||||
|
||||
初始化 --> 待机: 等待输入
|
||||
|
||||
state 待机 {
|
||||
[*] --> 构建System Prompt
|
||||
构建System Prompt --> 检查上下文限制: _maintain_context_limit()
|
||||
检查上下文限制 --> 上下文溢出: current > limit
|
||||
检查上下文限制 --> 正常: 否则
|
||||
|
||||
上下文溢出 --> 触发回调: _trigger_overflow_callbacks()
|
||||
触发回调 --> 清理历史: 保留50%
|
||||
清理历史 --> 正常
|
||||
|
||||
正常 --> 组装消息链: _build_request_messages()
|
||||
}
|
||||
|
||||
待机 --> 调用LLM: chat()
|
||||
|
||||
调用LLM --> 解析输出: AnalysisManager.parse()
|
||||
|
||||
解析输出 --> 更新历史: _add_to_history()
|
||||
|
||||
更新历史 --> 分发动作: Dispatcher.execute()
|
||||
|
||||
分发动作 --> 返回结果: interact() return
|
||||
|
||||
返回结果 --> 待机
|
||||
|
||||
待机 --> 热重载: reload_model()
|
||||
|
||||
state 热重载 {
|
||||
[*] --> 更新模型配置: update_config()
|
||||
更新模型配置 --> 重建LLM客户端: _create_client()
|
||||
重建LLM客户端 --> 重建Token管理器: TokenManager(model_name)
|
||||
重建Token管理器 --> [*]
|
||||
}
|
||||
|
||||
热重载 --> 待机
|
||||
|
||||
note right of 热重载 : "保留历史记忆\n不影响上下文"
|
||||
```
|
||||
@@ -0,0 +1,8 @@
|
||||
### 本模块为Yosuga_server的AI记忆模块
|
||||
|
||||
|
||||
##### 提供的功能有:
|
||||
1. 向量化的记忆存储
|
||||
2. 十分便捷的记忆管理与查询接口
|
||||
3. 高效的记忆检索
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
### 本模块为Yosuga_server的核心业务模块
|
||||
|
||||
####
|
||||
@@ -0,0 +1,4 @@
|
||||
### 本模块为Yosuga_server的前端部分
|
||||
|
||||
为了方便管理服务端
|
||||
|
||||
Reference in New Issue
Block a user