first pull

This commit is contained in:
2026-02-03 01:20:00 +08:00
commit 0753da86a8
79 changed files with 7134 additions and 0 deletions
View File
View File
+123
View File
@@ -0,0 +1,123 @@
# asr_module/api.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pathlib import Path
import tempfile
import time
from datetime import datetime
from loguru import logger
from src.modules.asr_module.asr_core.fast_whisper import create_asr, ASRConfig
# 初始化FastAPI应用
app = FastAPI(
title="Yosuga ASR API",
description="基于faster-whisper Turbo的高性能多语种语音转文本服务",
version="1.0.0"
)
# 全局单例ASR实例(延迟加载)
_asr_instance = None
def get_asr():
"""获取或创建ASR实例(单例)"""
global _asr_instance
if _asr_instance is None:
logger.info("🚀 初始化ASR服务...")
_asr_instance = create_asr(
ASRConfig(
model_name="deepdml/faster-whisper-large-v3-turbo-ct2",
device="auto",
compute_type="int8_float16",
cache_dir=Path("asr_models/faster_whisper_large_v3_ct2"),
beam_size=1, # 贪婪搜索,速度最快
vad_filter=True, # 过滤静音,节省30%时间
)
)
logger.info("✅ ASR服务初始化完成")
return _asr_instance
@app.on_event("startup")
async def startup_event():
"""应用启动时预加载模型"""
get_asr()
@app.on_event("shutdown")
async def shutdown_event():
"""应用关闭时清理资源"""
global _asr_instance
if _asr_instance:
_asr_instance.shutdown()
logger.info("🛑 ASR服务已关闭")
@app.post("/transcribe", response_class=JSONResponse)
async def transcribe_audio(
file: UploadFile = File(..., description="音频文件 (WAV, FLAC, MP3等格式)")
):
"""
语音转文本API
- **file**: 音频文件,支持WAV/FLAC/MP3等格式
- **返回**: JSON格式结果,包含text/language/confidence
"""
start_time = time.time()
# 验证文件类型
if file.content_type and not file.content_type.startswith("audio/"):
raise HTTPException(status_code=400, detail="❌ 请上传音频文件 (MIME类型: audio/*)")
try:
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file:
content = await file.read()
tmp_file.write(content)
tmp_path = Path(tmp_file.name)
logger.info(f"📥 接收文件: {file.filename} ({len(content)} bytes)")
# 调用ASR识别
asr = get_asr()
text, language, confidence = asr.transcribe_wav(tmp_path)
# 清理临时文件
tmp_path.unlink(missing_ok=True)
processing_time = time.time() - start_time
logger.info(f"✅ 识别完成: {language} | {len(text)}字符 | 置信度:{confidence:.2f} | 耗时:{processing_time:.3f}s")
return {
"success": True,
"data": {
"text": text,
"language": language,
"confidence": confidence,
"processing_time": round(processing_time, 3)
}
}
except Exception as e:
logger.error(f"❌ 识别失败: {e}")
raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}")
@app.get("/health")
async def health_check():
"""健康检查接口"""
asr = get_asr()
health = asr.health_check()
return {
"status": "healthy" if health["status"] == "healthy" else "unhealthy",
"timestamp": datetime.now().isoformat(),
"device": health["device"],
"model_loaded": health["model_loaded"]
}
@app.get("/")
async def root():
"""API根路径"""
return {
"message": "Yosuga ASR API 正在运行",
"docs": "/docs",
"health": "/health",
"transcribe": "/transcribe (POST)"
}
@@ -0,0 +1,16 @@
# fast_whisper/__init__.py
from typing import Optional
from src.modules.asr_module.asr_core.fast_whisper.config import ASRConfig
from src.modules.asr_module.asr_core.fast_whisper.model_manager import ModelManager
from src.modules.asr_module.asr_core.fast_whisper.asr_interface import ASRInterface
__version__ = "1.0.0"
__all__ = ["ASRConfig", "ModelManager", "ASRInterface"]
def create_asr(config: Optional[ASRConfig] = None) -> ASRInterface:
"""
快速创建ASR实例
Args:
config: ASR配置,若为None则使用默认配置
"""
return ASRInterface.get_instance(config)
@@ -0,0 +1,184 @@
# fast_whisper/asr_interface.py
from loguru import logger
from pathlib import Path
from typing import Tuple, Optional
import torchaudio
import torch
import numpy
from .model_manager import ModelManager
from .config import ASRConfig
from .utils import PerformanceProfiler
class ASRInterface:
"""
ASR接口类 - 全局单例
- 提供wav转文本功能
- 注入ModelManager
- 性能统计
"""
_instance: Optional['ASRInterface'] = None
def __new__(cls, *args, **kwargs):
"""单例模式实现"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: Optional[ASRConfig] = None):
# 防止重复初始化
if hasattr(self, '_initialized') and self._initialized:
return
self.config = config or ASRConfig()
self.model_manager = ModelManager(self.config)
self.profiler = PerformanceProfiler(self.config.enable_profiling)
# 音频参数
self.sample_rate = 16000
self._initialized = True
logger.info("🎤 ASR接口初始化完成")
@classmethod
def get_instance(cls, config: Optional[ASRConfig] = None) -> 'ASRInterface':
"""全局访问点"""
if cls._instance is None:
cls._instance = cls(config)
return cls._instance
def transcribe_wav(
self,
wav_path: Path,
language: Optional[str] = None
) -> Tuple[str, str, float]:
"""
WAV音频转文本(核心接口)
Args:
wav_path: WAV文件路径
language: 指定语言代码(如'zh'/'en'),None则自动检测
Returns:
(text, language, confidence)
"""
try:
# 记录开始时间
import time
start_time = time.time()
logger.info(f"🎵 开始识别: {wav_path.name}")
# 执行识别...
audio = self._load_audio(wav_path)
result = self._transcribe(audio, language)
text, lang, confidence = self._parse_result(result)
# 计算耗时
processing_time = time.time() - start_time
logger.info(
f"✅ 识别完成: {lang} | {len(text)}字符 | 置信度:{confidence:.2f} | "
f"耗时:{processing_time:.3f}s | RTF:{processing_time/(len(audio)/self.sample_rate):.3f}"
)
return text, lang, confidence
except Exception as e:
logger.error(f"❌ 识别失败 {wav_path}: {e}")
raise RuntimeError(f"Transcription failed: {e}")
def _load_audio(self, wav_path: Path) -> numpy.ndarray:
"""加载和预处理音频"""
if not wav_path.exists():
raise FileNotFoundError(f"音频文件不存在: {wav_path}")
# 加载音频
waveform, sample_rate = torchaudio.load(wav_path)
# 重采样到16kHz
if sample_rate != self.sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
waveform = resampler(waveform)
# 转换为单声道
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# 转换为numpy数组
audio = waveform.squeeze().numpy()
return audio
def _transcribe(self, audio: numpy.ndarray, language: Optional[str]) -> Tuple:
"""执行推理"""
model = self.model_manager.model
# 添加模型存在性检查
if model is None:
logger.error("ASR模型未加载,请检查模型配置和路径")
raise RuntimeError("ASR模型未加载,请检查模型配置和路径")
# 记录时间
import time
start_time = time.time()
# 调用模型
segments, info = model.transcribe(
audio,
language=language,
beam_size=self.config.beam_size,
best_of=self.config.best_of,
vad_filter=self.config.vad_filter,
)
# 立即执行生成器
segments_list = list(segments)
# 性能统计
inference_time = time.time() - start_time
audio_duration = len(audio) / self.sample_rate
self.profiler.record(audio_duration, inference_time)
return segments_list, info
def _parse_result(self, result: Tuple) -> Tuple[str, str, float]:
"""解析识别结果"""
segments, info = result
# 合并所有片段
text = " ".join([seg.text.strip() for seg in segments])
# 获取语言信息
language = info.language if info else "unknown"
confidence = info.language_probability if info else 0.0
return text, language, confidence
def transcribe_batch(self, wav_paths: list) -> list:
"""批量识别接口"""
return [
{
"file": str(path),
"text": result[0],
"language": result[1],
"confidence": result[2]
}
for path, result in zip(wav_paths, [
self.transcribe_wav(Path(p)) for p in wav_paths
])
]
def health_check(self) -> dict:
"""健康检查接口"""
return {
"status": "healthy" if self.model_manager.model else "unhealthy",
"device": self.config.device,
"model_loaded": self.model_manager.model is not None,
"device_info": self.model_manager.get_device_info(),
}
def shutdown(self):
"""优雅关闭"""
logger.info("🛑 关闭ASR接口...")
self.model_manager.unload()
@@ -0,0 +1,29 @@
# fast_whisper/config.py
from dataclasses import dataclass
from pathlib import Path
import torch
@dataclass
class ASRConfig:
"""ASR配置类"""
model_name: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
device: str = "auto"
compute_type: str = "int8_float16"
cache_dir: Path = Path.home() / ".cache" / "faster_whisper"
# 速度优化参数
beam_size: int = 1
best_of: int = 1
vad_filter: bool = True
batch_size: int = 16
# 性能统计
enable_profiling: bool = True
def __post_init__(self):
if self.device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if self.device == "cpu":
self.compute_type = "int8"
self.batch_size = 4
@@ -0,0 +1,92 @@
# fast_whisper/model_manager.py
import gc
from loguru import logger
from pathlib import Path
from typing import Optional
from faster_whisper import WhisperModel
import torch
from .config import ASRConfig
class ModelManager:
"""
模型管理类
- 负责模型生命周期管理
- 支持自定义缓存目录
- 自动硬件适配
"""
def __init__(self, config: ASRConfig):
self.config = config
self._model: Optional[WhisperModel] = None
self._device_info = None
# 确保缓存目录存在
self.config.cache_dir.mkdir(parents=True, exist_ok=True)
@property
def model(self) -> Optional[WhisperModel]:
"""懒加载模型"""
if self._model is None:
self._load_model()
return self._model
def _load_model(self):
"""加载模型"""
logger.info(f"🚀 初始化模型: {self.config.model_name}")
logger.info(f"📦 设备: {self.config.device}, 计算类型: {self.config.compute_type}")
try:
self._model = WhisperModel(
self.config.model_name,
device=self.config.device,
compute_type=self.config.compute_type,
download_root=str(self.config.cache_dir),
local_files_only=False,
)
self._device_info = {
"device": self.config.device,
"compute_type": self.config.compute_type,
"model_size": self.config.model_name.split("-")[-2]
}
logger.info("✅ 模型加载成功")
except Exception as e:
logger.error(f"❌ 模型加载失败: {e}")
raise RuntimeError(f"Failed to load ASR model: {e}")
def reload(self, new_config: ASRConfig):
"""热重载模型"""
logger.info("🔄 热重载模型...")
self.unload()
self.config = new_config
self._load_model()
def unload(self):
"""卸载模型释放资源"""
if self._model is not None:
logger.info("🗑️ 卸载模型...")
del self._model
self._model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.info("✅ 模型已卸载")
def get_device_info(self) -> dict:
"""获取设备信息"""
return self._device_info or {}
def __enter__(self):
"""上下文管理器支持"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""自动清理资源"""
self.unload()
@@ -0,0 +1,45 @@
# fast_whisper/utils.py
from loguru import logger
from datetime import datetime
from typing import Dict, Any
def check_hardware() -> Dict[str, Any]:
"""硬件检测"""
import torch
info = {
"cuda_available": torch.cuda.is_available(),
"device_name": "CPU",
"device_count": 0,
"compute_type": "int8"
}
if info["cuda_available"]:
info.update({
"device_name": torch.cuda.get_device_name(0),
"device_count": torch.cuda.device_count(),
"compute_type": "int8_float16"
})
return info
class PerformanceProfiler:
"""性能分析器"""
def __init__(self, enable: bool = True):
self.enable = enable
self.stats = []
def record(self, audio_duration: float, inference_time: float):
if not self.enable:
return
rtf = inference_time / audio_duration if audio_duration > 0 else 0
self.stats.append({
"timestamp": datetime.now().isoformat(),
"rtf": rtf,
"audio_duration": audio_duration,
"inference_time": inference_time
})
if len(self.stats) % 10 == 0:
avg_rtf = sum(s["rtf"] for s in self.stats[-10:]) / 10
logger.info(f"📊 最近10次平均RTF: {avg_rtf:.3f}")
+215
View File
@@ -0,0 +1,215 @@
# asr_module/client/asr_client.py
import asyncio
import time
from pathlib import Path
from typing import Union, Optional
import aiofiles
import aiohttp
import requests
from loguru import logger
from .models import ASRResponse, ASRHealthStatus, ServiceInfo
class ASRException(Exception):
"""ASR服务调用异常"""
def __init__(self, message: str, status_code: Optional[int] = None):
self.message = message
self.status_code = status_code
super().__init__(self.message)
class ASRClientConfig:
"""客户端配置"""
def __init__(
self,
base_url: str = "http://localhost:8000",
timeout: float = 30.0,
retry_count: int = 2,
retry_delay: float = 0.5,
):
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self.retry_count = retry_count
self.retry_delay = retry_delay
# 同步客户端
class ASRClientSync:
"""同步ASR客户端"""
def __init__(self, config: Optional[ASRClientConfig] = None):
self.config = config or ASRClientConfig()
self.session = requests.Session()
self.session.timeout = self.config.timeout
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.session.close()
def _request(self, method: str, endpoint: str, **kwargs) -> dict:
"""统一请求处理(带重试)"""
url = f"{self.config.base_url}{endpoint}"
for attempt in range(self.config.retry_count + 1):
try:
response = self.session.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if attempt < self.config.retry_count:
logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
time.sleep(self.config.retry_delay)
else:
logger.error(f"请求最终失败: {e}")
raise ASRException(f"API调用失败: {e}", getattr(e.response, 'status_code', None))
def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
"""
转录音频文件
Args:
file_path: 音频文件路径
Returns:
ASRResponse对象
Example:
client = ASRClientSync()
result = client.transcribe_file("/path/to/audio.wav")
print(result.data.text)
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
logger.info(f"📤 上传文件: {file_path.name}")
with open(file_path, 'rb') as f:
files = {'file': (file_path.name, f, 'audio/wav')}
result = self._request('POST', '/transcribe', files=files)
return ASRResponse(**result)
def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
"""
转录音频字节流
Args:
audio_data: 原始音频字节
filename: 模拟文件名(用于MIME类型推断)
Returns:
ASRResponse对象
Example:
with open('audio.wav', 'rb') as f:
audio_bytes = f.read()
result = client.transcribe_bytes(audio_bytes)
"""
logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
files = {'file': (filename, audio_data, 'audio/wav')}
result = self._request('POST', '/transcribe', files=files)
return ASRResponse(**result)
def health_check(self) -> ASRHealthStatus:
"""健康检查"""
result = self._request('GET', '/health')
return ASRHealthStatus(**result)
def get_service_info(self) -> ServiceInfo:
"""获取服务信息"""
result = self._request('GET', '/')
return ServiceInfo(**result)
# 异步客户端
class ASRClientAsync:
"""异步ASR客户端"""
def __init__(self, config: Optional[ASRClientConfig] = None):
self.config = config or ASRClientConfig()
self._session: Optional[aiohttp.ClientSession] = None
async def __aenter__(self):
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session:
await self._session.close()
async def _ensure_session(self):
if self._session is None:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
)
async def _request(self, method: str, endpoint: str, **kwargs) -> dict:
"""统一异步请求(带重试)"""
await self._ensure_session()
url = f"{self.config.base_url}{endpoint}"
for attempt in range(self.config.retry_count + 1):
try:
async with self._session.request(method, url, **kwargs) as response:
response.raise_for_status()
return await response.json()
except aiohttp.ClientError as e:
if attempt < self.config.retry_count:
logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
await asyncio.sleep(self.config.retry_delay)
else:
logger.error(f"请求最终失败: {e}")
raise ASRException(f"API调用失败: {e}", getattr(e, 'status', None))
async def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
"""异步转录音频文件"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
logger.info(f"📤 上传文件: {file_path.name}")
async with aiofiles.open(file_path, 'rb') as f:
audio_data = await f.read()
return await self.transcribe_bytes(audio_data, file_path.name)
async def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
"""异步转录音频字节流"""
logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")
await self._ensure_session() # 确保session已创建
form = aiohttp.FormData() # 创建表单数据
form.add_field('file', audio_data, filename=filename, content_type='audio/wav') # 添加文件字段
result = await self._request('POST', '/transcribe', data=form) # 发送POST请求
return ASRResponse(**result) # 返回结果
async def health_check(self) -> ASRHealthStatus:
"""异步健康检查"""
result = await self._request('GET', '/health')
return ASRHealthStatus(**result)
async def get_service_info(self) -> ServiceInfo:
"""异步获取服务信息"""
result = await self._request('GET', '/')
return ServiceInfo(**result)
# 工厂函数
def create_asr_client(use_async: bool = False, **config_kwargs) -> Union[ASRClientSync, ASRClientAsync]:
"""
创建客户端工厂函数
Args:
use_async: 是否创建异步客户端
**config_kwargs: ASRClientConfig参数
Returns:
同步或异步客户端实例
"""
config = ASRClientConfig(**config_kwargs)
if use_async:
return ASRClientAsync(config)
return ASRClientSync(config)
+30
View File
@@ -0,0 +1,30 @@
# asr_module/client/models.py
from pydantic import BaseModel
from typing import Optional
class ASRHealthStatus(BaseModel):
"""ASR服务健康状态"""
status: str # 状态
timestamp: str # 时间戳
device: str # 设备
model_loaded: bool # 模型是否加载
class ASRResult(BaseModel):
"""语音识别结果"""
text: str # 识别结果
language: str # 语言
confidence: float # 置信度
processing_time: float # 处理时间
class ASRResponse(BaseModel):
"""统一API响应"""
success: bool
data: Optional[ASRResult] = None
error: Optional[str] = None
class ServiceInfo(BaseModel):
"""服务信息"""
message: str # 消息
docs: str # 文档
health: str # 健康
transcribe: str # 识别
+12
View File
@@ -0,0 +1,12 @@
本模块为asr模块,即语音转文本模块。
本模块提供了两种访问方式:
- 第一种为本地部署的func call方式,即提供函数调用的方式调用相关asr接口。
- 第二种为http call方式,即提供http call方式调用相关asr接口。[FastAPI]
如果你的电脑显卡支持cuda, 并且显存大小大于8G, 那么可以使用第一种方式,
否则可以使用第二种,进行云端部署。
两种调用方式在相同配置下,性能几乎无差别。
个人建议使用第二种
+65
View File
@@ -0,0 +1,65 @@
# start_api.py
import uvicorn
from loguru import logger
import threading
import time
def start_server():
"""启动 ASR API 服务"""
uvicorn.run(
"api:app", # 模块名:app实例
host="0.0.0.0",
port=20260,
workers=1, # 单用户场景,1个worker足够
log_level="info",
reload=False, # 生产环境关闭热重载
access_log=True,
)
def first_test() -> None:
"""首次启动测试"""
time.sleep(5) # 给服务器一些启动时间
# 构造一个测试请求以验证初始化模型加载成功
logger.info("🚀 测试模型是否加载成功...")
import requests
from pathlib import Path
url = "http://localhost:20260/transcribe"
audio_path = Path("../../../Test/test_files/test.wav")
try:
with open(audio_path, "rb") as f:
# 明确指定文件名和 MIME 类型
files = {
"file": (
audio_path.name, # 文件名
f, # 文件对象
"audio/wav" # MIME 类型
)
}
response = requests.post(url, files=files)
logger.info(f"状态码: {response.status_code}")
logger.info(f"响应头: {response.headers.get('content-type')}")
if response.status_code == 200:
result = response.json()
logger.info(f"识别结果: {result['data']['text']}")
logger.info(f"识别语言: {result['data']['language']}")
logger.info(f"置信度: {result['data']['confidence']:.2f}")
logger.info(f"处理时间: {result['data']['processing_time']}s")
else:
logger.error(f"请求失败,错误响应信息: {response.text}")
logger.error("请检查模型是否正确加载或其他问题")
except Exception as e:
logger.error(f"测试过程中发生错误: {e}")
if __name__ == "__main__":
logger.info("🚀 启动 ASR API 服务...")
# 在后台线程启动服务器
server_thread = threading.Thread(target=start_server, daemon=True)
server_thread.start()
# 执行测试
first_test()
# 保持主线程运行
server_thread.join()
@@ -0,0 +1,119 @@
# ui_tars_/ui_tars_client.py
from typing import Optional
from loguru import logger
import asyncio
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
UnifiedLLM,
ModelConfig,
ModelProvider,
ChatMessage,
)
from pydantic import BaseModel, Field
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_prompts import UI_TARS_SYSTEM_PROMPT
class UITarsClientConfig(BaseModel):
"""UI-TARS 客户端配置"""
deployment_type: str = Field(default="lmstudio", description="部署类型")
base_url: str = Field(default="http://localhost:1234/v1", description="API地址")
model_name: str = Field(default="ui-tars", description="模型名称")
api_key: Optional[str] = Field(default=None, description="API密钥")
temperature: float = Field(default=0.1, ge=0.0, le=2.0)
max_tokens: int = Field(default=8192, ge=2048, le=128000)
timeout: int = Field(default=30, ge=5, le=300)
# UI-TARS-1.5 强制输出格式
system_prompt: str = Field(
default=UI_TARS_SYSTEM_PROMPT # 使用本项目自定义的输出格式约束
)
def to_model_config(self) -> ModelConfig:
"""转换为 UnifiedLLM 配置"""
# 映射部署类型到 ModelProvider
provider_map = {
"lmstudio": ModelProvider.LM_STUDIO,
"vllm": ModelProvider.CUSTOM,
"cloud": ModelProvider.OPENAI,
"ollama": ModelProvider.OLLAMA
}
provider = provider_map.get(self.deployment_type, ModelProvider.CUSTOM)
return ModelConfig(
provider=provider,
model_name=self.model_name,
api_key=self.api_key,
base_url=self.base_url,
temperature=self.temperature,
max_tokens=self.max_tokens,
timeout=self.timeout,
custom_headers={"User-Agent": "UI-TARS-Client/1.0"}
)
class UITarsClient:
"""
UI-TARS 通用客户端 (基于 UnifiedLLM)
图片相关信息请直接传入相应的base64
"""
def __init__(self, config: UITarsClientConfig):
self.config = config
# 复用 UnifiedLLM,自动处理所有部署类型
self.llm = UnifiedLLM(config.to_model_config())
logger.info(f"UI-TARS 客户端初始化: {config.deployment_type} @ {config.base_url}")
logger.info(f" 模型: {config.model_name} | 温度: {config.temperature}")
async def call_async(self, instruction: str, image_base64: str) -> str:
"""异步调用 UI-TARS"""
# 构建消息
messages = self._build_messages(instruction, image_base64)
try:
# 使用 UnifiedLLM 的异步接口
response = await asyncio.to_thread(
self.llm.chat,
messages=messages,
streaming=False
)
return response.content
except Exception as e:
logger.error(f"UI-TARS 调用失败: {e}")
raise
def call_sync(self, instruction: str, image_base64: str) -> str:
"""同步调用 UI-TARS"""
messages = self._build_messages(instruction, image_base64)
try:
response = self.llm.chat(
messages=messages,
streaming=False
)
return response.content
except Exception as e:
logger.error(f"UI-TARS 调用失败: {e}")
raise
def stream_async(self, instruction: str, image_base64: str):
"""流式调用 (异步生成器)"""
messages = self._build_messages(instruction, image_base64)
# UnifiedLLM 自动处理流式
return self.llm.stream_chat(messages=messages)
def _build_messages(self, instruction: str, image_base64: str) -> list:
"""构建 OpenAI 格式消息"""
return [
ChatMessage(
role="system",
content=self.config.system_prompt
),
ChatMessage(
role="user",
content=[
{"type": "text", "text": instruction},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
}
}
]
)
]
@@ -0,0 +1,25 @@
OFFICIAL_ACTION_SPACE = """## Action Space
click(point='<point>x1 y1</point>') - 单击坐标
left_double(point='<point>x1 y1</point>') - 双击坐标
right_single(point='<point>x1 y1</point>') - 右键单击
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>') - 拖拽
hotkey(key='ctrl c') - 快捷键(空格分隔,小写,最多3个键)
type(content='xxx') - 输入文本(用\\' \\\" \\n 转义)
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') - 滚动
wait() - 等待5秒
finished() - 任务完成
## Output Format
Thought: [你的推理过程]
Action: [选择一个动作]
"""
UI_TARS_SYSTEM_PROMPT = f"""You are UI-TARS-1.5, a GUI agent. Given a task and screenshot, output ONLY:
{OFFICIAL_ACTION_SPACE}
## Note
- Write a small plan and summarize the next action in one sentence in Thought.
- NEVER output multiple actions.
- x&y please in box center
"""
@@ -0,0 +1,10 @@
本模块为设备控制模块接口层,此处的设备控制指的是借助AI模型进行一些设备上的自动化
操作,支持`pc`, `android`,其他的未做过测试。
依赖:
`ui_tars`
当前所使用的AI模型为`mradermacher/UI-TARS-1.5-7B-GGUF
(Q6_K or Q4_K_M)`
未来如果有更快质量更高的AI模型本模块会为其添加支持。
@@ -0,0 +1,837 @@
"""
通用大语言模型调用框架
支持本地模型(Ollama, LM Studio, llama.cpp)和云端模型(OpenAI, Anthropic, Google, Azure等)
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, Any, Iterator
import json
import os
from loguru import logger
from dataclasses import dataclass, asdict, field
from enum import Enum
import httpx
from pydantic import BaseModel, Field
class ModelProvider(Enum):
"""支持的模型提供商枚举"""
OPENAI = "openai"
ANTHROPIC = "anthropic"
GOOGLE = "google"
AZURE = "azure"
OLLAMA = "ollama"
LM_STUDIO = "lm_studio"
LLAMA_CPP = "llama_cpp"
CUSTOM = "custom"
@dataclass
class ModelConfig:
"""模型配置类"""
provider: ModelProvider
model_name: str
api_key: Optional[str] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
temperature: float = 0.7
max_tokens: int = 1024
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
timeout: int = 30
streaming: bool = False
custom_headers: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict:
"""转换为字典"""
return asdict(self)
@dataclass
class ChatMessage:
"""聊天消息类"""
role: str # system, user, assistant
content: str
name: Optional[str] = None
def to_dict(self) -> Dict:
"""转换为字典"""
return {
"role": self.role,
"content": self.content,
**({"name": self.name} if self.name else {})
}
@dataclass
class ModelResponse:
"""模型响应类"""
content: str # 响应内容
model: str # 模型名称
usage: Optional[Dict[str, int]] = None # 使用量
finish_reason: Optional[str] = None # 结束原因
raw_response: Optional[Dict] = None # 原始响应
def normalize_usage(raw_usage: Optional[Dict[str, Any]], provider: ModelProvider) -> Optional[Dict[str, int]]:
"""
将不同平台的 usage 字段统一归一化为 OpenAI 标准格式
Args:
raw_usage: API 原始返回的 usage 数据
provider: 模型提供商枚举
Returns:
归一化后的 usage 字典,格式:
{
"prompt_tokens": int,
"completion_tokens": int,
"total_tokens": int
}
如果无法归一化则返回 None
"""
if not raw_usage:
return None
# 字段映射表:{provider: (input_key, output_key, total_key)}
USAGE_FIELD_MAP = {
ModelProvider.OPENAI: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.AZURE: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.LM_STUDIO: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.LLAMA_CPP: ("prompt_tokens", "completion_tokens", "total_tokens"),
ModelProvider.OLLAMA: ("prompt_eval_count", "eval_count", None), # Ollama 没有 total
ModelProvider.ANTHROPIC: ("input_tokens", "output_tokens", None),
ModelProvider.GOOGLE: ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"),
}
input_key, output_key, total_key = USAGE_FIELD_MAP.get(provider, (None, None, None))
if input_key is None:
logger.warning(f"未知的 provider '{provider}',无法归一化 usage")
return None
try:
# 提取字段值
prompt_tokens = raw_usage.get(input_key, 0)
completion_tokens = raw_usage.get(output_key, 0)
# 处理嵌套字典(如有些 API 的 usage 格式特殊)
if isinstance(prompt_tokens, dict):
prompt_tokens = prompt_tokens.get("value", 0)
if isinstance(completion_tokens, dict):
completion_tokens = completion_tokens.get("value", 0)
# 转换为整数
prompt_tokens = int(prompt_tokens) if prompt_tokens else 0
completion_tokens = int(completion_tokens) if completion_tokens else 0
# 计算 total(如果 API 没提供)
if total_key and total_key in raw_usage:
total_tokens = int(raw_usage[total_key])
else:
total_tokens = prompt_tokens + completion_tokens
normalized = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}
logger.debug(f"归一化 usage | {provider} -> OpenAI格式: {normalized}")
return normalized
except Exception as e:
logger.error(f"归一化 usage 失败: {e} | raw_usage: {raw_usage}")
return None
class BaseLLMClient(ABC):
"""大语言模型客户端基类"""
def __init__(self, config: ModelConfig):
self.config = config
self.client = None
self._initialize_client()
@abstractmethod
def _initialize_client(self):
"""初始化客户端"""
pass
@abstractmethod
def chat_completion(
self,
messages: List[Union[ChatMessage, Dict]],
**kwargs
) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""聊天补全"""
pass
@abstractmethod
def completion(
self,
prompt: str,
**kwargs
) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""文本补全"""
pass
def format_messages(self, messages: List[Union[ChatMessage, Dict]]) -> List[Dict]:
"""格式化消息列表"""
formatted = []
for msg in messages:
if isinstance(msg, ChatMessage):
formatted.append(msg.to_dict())
else:
formatted.append(msg)
return formatted
class OpenAIClient(BaseLLMClient):
"""OpenAI客户端"""
def _initialize_client(self):
try:
from openai import OpenAI
api_key = self.config.api_key
if not api_key:
raise ValueError("OpenAI API密钥未设置")
self.client = OpenAI(
api_key=api_key,
base_url=self.config.base_url,
timeout=self.config.timeout
)
logger.info(f"OpenAI客户端初始化成功,base_url: {self.config.base_url}")
except ImportError:
logger.error("请安装openai包: pip install openai")
raise
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
# 获取streaming参数,优先使用kwargs中的设置
streaming = kwargs.get("streaming", self.config.streaming)
# 合并配置
params = {
"model": self.config.model_name,
"messages": formatted_messages,
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
"frequency_penalty": kwargs.get("frequency_penalty", self.config.frequency_penalty),
"presence_penalty": kwargs.get("presence_penalty", self.config.presence_penalty),
"stream": streaming, # 使用正确的streaming设置
}
logger.info(f"🔧 调用参数: streaming={streaming}")
if streaming:
return self._stream_chat_completion(params)
else:
return self._normal_chat_completion(params)
def _normal_chat_completion(self, params):
"""非流式响应处理"""
logger.info("📡 发送非流式请求...")
response = self.client.chat.completions.create(**params)
raw_usage = response.usage
normalized_usage = normalize_usage(
raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
ModelProvider.OPENAI
)
return ModelResponse(
content=response.choices[0].message.content,
model=response.model,
usage=normalized_usage,
finish_reason=response.choices[0].finish_reason,
raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
)
def _stream_chat_completion(self, params):
"""流式响应处理"""
logger.info("📡 发送流式请求...")
response_stream = self.client.chat.completions.create(**params)
full_content = ""
for chunk in response_stream:
if chunk.choices[0].delta.content is not None:
content_chunk = chunk.choices[0].delta.content
full_content += content_chunk
yield ModelResponse(
content=content_chunk,
model=chunk.model,
raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
)
def completion(self, prompt, **kwargs):
# OpenAI 推荐使用 chat_completion,这里保持兼容
messages = [{"role": "user", "content": prompt}]
return self.chat_completion(messages, **kwargs)
class AnthropicClient(BaseLLMClient):
"""Anthropic Claude客户端"""
def _initialize_client(self):
try:
from anthropic import Anthropic
self.client = Anthropic(
api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
timeout=self.config.timeout
)
except ImportError:
logger.error("请安装anthropic包: pip install anthropic")
raise
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
# Claude 的消息格式转换
claude_messages = []
system_message = None
for msg in formatted_messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
claude_messages.append({
"role": msg["role"],
"content": msg["content"]
})
# 明确获取streaming参数
params = {
"model": self.config.model_name,
"messages": claude_messages,
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": kwargs.get("streaming", self.config.streaming),
}
if system_message:
params["system"] = system_message
if self.config.streaming:
return self._stream_chat_completion(params)
else:
return self._normal_chat_completion(params)
def _normal_chat_completion(self, params):
response = self.client.messages.create(**params)
# Anthropic 返回的usage格式和OpenAI不同,需要进行转换
raw_usage = response.usage
normalized_usage = normalize_usage(
raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
ModelProvider.ANTHROPIC
)
return ModelResponse(
content=response.content[0].text,
model=response.model,
usage=normalized_usage,
finish_reason=response.stop_reason,
raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
)
def _stream_chat_completion(self, params):
with self.client.messages.stream(**params) as stream:
for chunk in stream:
if chunk.type_ == "content_block_delta":
yield ModelResponse(
content=chunk.delta.text,
model=params["model"],
raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
)
def completion(self, prompt, **kwargs):
messages = [{"role": "user", "content": prompt}]
return self.chat_completion(messages, **kwargs)
class OllamaClient(BaseLLMClient):
"""Ollama本地模型客户端"""
def _initialize_client(self):
import httpx
self.base_url = self.config.base_url or "http://localhost:11434"
self.client = httpx.Client(
base_url=self.base_url,
timeout=self.config.timeout
)
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
payload = {
"model": self.config.model_name,
"messages": formatted_messages,
"options": {
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
"num_predict": kwargs.get("max_tokens", self.config.max_tokens),
},
"stream": kwargs.get("streaming", self.config.streaming),
}
if self.config.streaming:
return self._stream_chat_completion(payload)
else:
return self._normal_chat_completion(payload)
def _normal_chat_completion(self, payload):
response = self.client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
normalized_usage = normalize_usage(data, ModelProvider.OLLAMA)
return ModelResponse(
content=data["message"]["content"],
model=data["model"],
usage=normalized_usage,
finish_reason=data.get("done_reason"),
raw_response=data
)
def _stream_chat_completion(self, payload):
with self.client.stream("POST", "/api/chat", json=payload) as response:
for line in response.iter_lines():
if line.strip():
try:
data = json.loads(line)
if "message" in data and "content" in data["message"]:
yield ModelResponse(
content=data["message"]["content"],
model=data.get("model", self.config.model_name),
raw_response=data
)
except json.JSONDecodeError:
continue
def completion(self, prompt, **kwargs):
payload = {
"model": self.config.model_name,
"prompt": prompt,
"options": {
"temperature": kwargs.get("temperature", self.config.temperature),
"top_p": kwargs.get("top_p", self.config.top_p),
"num_predict": kwargs.get("max_tokens", self.config.max_tokens),
},
"stream": kwargs.get("streaming", self.config.streaming),
}
if self.config.streaming:
return self._stream_completion(payload)
else:
return self._normal_completion(payload)
def _normal_completion(self, payload):
response = self.client.post("/api/generate", json=payload)
response.raise_for_status()
data = response.json()
return ModelResponse(
content=data["response"],
model=data["model"],
usage={
"prompt_tokens": data.get("prompt_eval_count", 0),
"completion_tokens": data.get("eval_count", 0),
"total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
},
finish_reason=data.get("done_reason"),
raw_response=data
)
def _stream_completion(self, payload):
with self.client.stream("POST", "/api/generate", json=payload) as response:
for line in response.iter_lines():
if line.strip():
try:
data = json.loads(line)
if "response" in data:
yield ModelResponse(
content=data["response"],
model=data.get("model", self.config.model_name),
raw_response=data
)
except json.JSONDecodeError:
continue
class GenericLLMClient(BaseLLMClient):
"""通用HTTP客户端,支持LM Studio和其他兼容OpenAI API的本地模型"""
def _initialize_client(self):
import httpx
self.base_url = self.config.base_url or "http://localhost:1234/v1"
self.client = httpx.Client(
base_url=self.base_url,
timeout=self.config.timeout,
headers=self.config.custom_headers
)
def chat_completion(self, messages, **kwargs):
formatted_messages = self.format_messages(messages)
# 明确获取 streaming 参数
streaming = kwargs.get("streaming", self.config.streaming)
payload = {
"model": self.config.model_name,
"messages": formatted_messages,
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": streaming, # 使用明确的 streaming 变量
}
logger.info(f"GenericLLMClient 参数: streaming={streaming}")
if streaming:
return self._stream_chat_completion(payload)
else:
return self._normal_chat_completion(payload)
def _normal_chat_completion(self, payload):
logger.info(f"GenericLLMClient 发送非流式请求到: {self.base_url}")
response = self.client.post("/chat/completions", json=payload)
response.raise_for_status()
data = response.json()
logger.info(f"GenericLLMClient 收到响应,模型: {data.get('model')}")
return ModelResponse(
content=data["choices"][0]["message"]["content"],
model=data["model"],
usage=data.get("usage"),
finish_reason=data["choices"][0].get("finish_reason"),
raw_response=data
)
def _stream_chat_completion(self, payload):
logger.info(f"GenericLLMClient 发送流式请求到: {self.base_url}")
with self.client.stream("POST", "/chat/completions", json=payload) as response:
for line in response.iter_lines():
if line.startswith("data: "):
chunk = line[6:]
if chunk == "[DONE]":
break
try:
data = json.loads(chunk)
if data["choices"][0]["delta"].get("content"):
yield ModelResponse(
content=data["choices"][0]["delta"]["content"],
model=data["model"],
raw_response=data
)
except json.JSONDecodeError:
continue
def completion(self, prompt, **kwargs):
payload = {
"model": self.config.model_name,
"prompt": prompt,
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": kwargs.get("streaming", self.config.streaming),
}
streaming = kwargs.get("streaming", self.config.streaming)
if streaming:
return self._stream_completion(payload)
else:
return self._normal_completion(payload)
def _normal_completion(self, payload):
response = self.client.post("/completions", json=payload)
response.raise_for_status()
data = response.json()
return ModelResponse(
content=data["choices"][0]["text"],
model=data["model"],
usage=data.get("usage"),
finish_reason=data["choices"][0].get("finish_reason"),
raw_response=data
)
def _stream_completion(self, payload):
with self.client.stream("POST", "/completions", json=payload) as response:
for line in response.iter_lines():
if line.startswith("data: "):
chunk = line[6:]
if chunk == "[DONE]":
break
try:
data = json.loads(chunk)
if data["choices"][0].get("text"):
yield ModelResponse(
content=data["choices"][0]["text"],
model=data["model"],
raw_response=data
)
except json.JSONDecodeError:
continue
class UnifiedLLM:
"""
统一的大语言模型调用类
支持多种模型提供商:
- 云端模型:OpenAI, Anthropic, Google, Azure
- 本地模型:Ollama, LM Studio, llama.cpp
使用示例:
```python
# 初始化OpenAI客户端
config = ModelConfig(
provider=ModelProvider.OPENAI,
model_name="gpt-4",
api_key="your-api-key"
)
llm = UnifiedLLM(config)
# 聊天补全
messages = [
{"role": "system", "content": "你是一个有用的助手"},
{"role": "user", "content": "你好!"}
]
response = llm.chat(messages)
print(response.content)
# 流式响应
config.streaming = True
llm.update_config(config)
for chunk in llm.chat(messages):
print(chunk.content, end="", flush=True)
```
"""
def __init__(self, config: ModelConfig):
"""
初始化统一LLM
Args:
config: 模型配置
"""
self.config = config
self.client = self._create_client()
logger.info(f" UnifiedLLM 初始化完成")
logger.info(f" 提供商: {config.provider}")
logger.info(f" 模型: {config.model_name}")
logger.info(f" 流式默认: {config.streaming}")
def _create_client(self) -> BaseLLMClient:
"""根据配置创建客户端"""
provider = self.config.provider
if provider == ModelProvider.OPENAI:
return OpenAIClient(self.config)
elif provider == ModelProvider.ANTHROPIC:
return AnthropicClient(self.config)
elif provider == ModelProvider.OLLAMA:
return OllamaClient(self.config)
elif provider in [ModelProvider.LM_STUDIO, ModelProvider.LLAMA_CPP, ModelProvider.CUSTOM]:
return GenericLLMClient(self.config)
elif provider == ModelProvider.GOOGLE:
# 这里可以扩展Google Gemini支持
return GenericLLMClient(self.config)
elif provider == ModelProvider.AZURE:
# Azure OpenAI需要特殊处理
return GenericLLMClient(self.config)
else:
raise ValueError(f"不支持的模型提供商: {provider}")
def update_config(self, config: ModelConfig):
"""更新配置并重新创建客户端"""
self.config = config
self.client = self._create_client()
logger.info(f"UnifiedLLM 配置已更新")
def chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""
聊天补全
Args:
messages: 消息列表
**kwargs: 其他参数,会覆盖config中的设置
Returns:
ModelResponse 或 ModelResponse 的迭代器(流式模式)
"""
# 明确获取streaming参数
streaming = kwargs.get("streaming", self.config.streaming)
logger.info(f" UnifiedLLM.chat() 调用")
logger.info(f" 消息数: {len(messages)}")
logger.info(f" streaming参数: {streaming}")
# 调用客户端
result = self.client.chat_completion(messages, **kwargs)
# 类型检查(调试用)
if streaming:
if not hasattr(result, '__iter__'):
logger.warning(f"警告: streaming=True 但返回的不是迭代器")
else:
if hasattr(result, '__iter__'):
logger.warning(f"警告: streaming=False 但返回的是迭代器")
elif not isinstance(result, ModelResponse):
logger.warning(f"警告: streaming=False 但返回的不是ModelResponse")
return result
def complete(self, prompt: str, **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
"""
文本补全
Args:
prompt: 提示文本
**kwargs: 其他参数
Returns:
ModelResponse 或 ModelResponse 的迭代器(流式模式)
"""
streaming = kwargs.get("streaming", self.config.streaming)
logger.info(f" UnifiedLLM.complete() 调用")
logger.info(f" prompt长度: {len(prompt)}")
logger.info(f" streaming参数: {streaming}")
return self.client.completion(prompt, **kwargs)
def stream_chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Iterator[ModelResponse]:
"""
流式聊天补全(便捷方法)
Args:
messages: 消息列表
**kwargs: 其他参数
Returns:
ModelResponse 的迭代器
"""
kwargs["streaming"] = True
logger.info(f"UnifiedLLM.stream_chat() 调用")
result = self.chat(messages, **kwargs)
# 确保返回的是迭代器
if not hasattr(result, '__iter__'):
raise TypeError("stream_chat 应该返回迭代器,但返回了其他类型")
return result
def stream_complete(self, prompt: str, **kwargs) -> Iterator[ModelResponse]:
"""
流式文本补全(便捷方法)
Args:
prompt: 提示文本
**kwargs: 其他参数
Returns:
ModelResponse 的迭代器
"""
kwargs["streaming"] = True
logger.info(f"UnifiedLLM.stream_complete() 调用")
result = self.complete(prompt, **kwargs)
# 确保返回的是迭代器
if not hasattr(result, '__iter__'):
raise TypeError("stream_complete 应该返回迭代器,但返回了其他类型")
return result
# 快捷函数
def create_llm_client(
provider: Union[str, ModelProvider],
model_name: str,
**kwargs
) -> UnifiedLLM:
"""
快捷创建LLM客户端
Args:
provider: 提供商名称或枚举
model_name: 模型名称
**kwargs: 其他配置参数
Returns:
UnifiedLLM 实例
"""
if isinstance(provider, str):
provider = ModelProvider(provider.lower())
config = ModelConfig(
provider=provider,
model_name=model_name,
**kwargs
)
return UnifiedLLM(config)
# 使用示例
def example_usage():
"""使用示例"""
# 示例1: 使用OpenAI
print("示例1: 使用OpenAI")
openai_config = ModelConfig(
provider=ModelProvider.OPENAI,
model_name="gpt-3.5-turbo",
api_key=os.getenv("OPENAI_API_KEY"),
temperature=0.8
)
try:
openai_llm = UnifiedLLM(openai_config)
messages = [
ChatMessage(role="system", content="你是一个有用的助手"),
ChatMessage(role="user", content="请用Python写一个Hello World程序")
]
response = openai_llm.chat(messages)
print(f"响应: {response.content[:100]}...")
except Exception as e:
print(f"OpenAI示例错误: {e}")
# 示例2: 使用Ollama(本地模型)
print("\n示例2: 使用Ollama(本地模型)")
ollama_config = ModelConfig(
provider=ModelProvider.OLLAMA,
model_name="llama2",
base_url="http://localhost:11434",
temperature=0.7,
streaming=True # 流式响应
)
try:
ollama_llm = UnifiedLLM(ollama_config)
messages = [
{"role": "user", "content": "什么是人工智能?"}
]
print("流式响应:")
for chunk in ollama_llm.stream_chat(messages):
print(chunk.content, end="", flush=True)
print()
except Exception as e:
print(f"Ollama示例错误: {e}(请确保Ollama服务正在运行)")
# 示例3: 使用快捷函数
print("\n示例3: 使用快捷函数")
try:
llm = create_llm_client(
provider="openai",
model_name="gpt-3.5-turbo",
api_key=os.getenv("OPENAI_API_KEY"),
temperature=0.5
)
response = llm.complete("天空为什么是蓝色的?")
print(f"响应: {response.content[:100]}...")
except Exception as e:
print(f"快捷函数示例错误: {e}")
@@ -0,0 +1,162 @@
# 大语言模型调用框架架构图
general_text_ai_req.py
```mermaid
graph TB
subgraph "应用层"
A[用户应用] --> B[UnifiedLLM 统一接口]
end
subgraph "适配器层"
B --> C{模型提供商路由}
C --> D[OpenAI 适配器]
C --> E[Anthropic 适配器]
C --> F[Ollama 适配器]
C --> G[通用HTTP适配器]
C --> H[其他适配器]
end
subgraph "服务层"
D --> I[OpenAI API]
E --> J[Anthropic API]
F --> K[Ollama 服务]
G --> L[LM Studio]
G --> M[llama.cpp]
G --> N[其他兼容API]
end
subgraph "配置层"
O[ModelConfig] --> C
O --> D
O --> E
O --> F
O --> G
end
subgraph "数据流"
P[输入: 消息/提示] --> B
I --> Q[输出: ModelResponse]
J --> Q
K --> Q
L --> Q
M --> Q
N --> Q
end
style A fill:#4567f1
style B fill:#4567f1
style O fill:#456748
style D fill:#457911
style E fill:#466bd5
style F fill:#4567f1
style G fill:#4567f1
```
# 数据流图
```mermaid
sequenceDiagram
participant User as 用户/应用
participant UnifiedLLM as UnifiedLLM
participant Adapter as 适配器
participant API as API服务
User->>UnifiedLLM: 调用chat()或complete()
UnifiedLLM->>UnifiedLLM: 根据配置选择适配器
UnifiedLLM->>Adapter: 转发请求
Adapter->>API: 发送HTTP请求/API调用
Note over API: 处理请求并生成响应
alt 流式模式
API-->>Adapter: 流式响应数据
Adapter-->>UnifiedLLM: 流式ModelResponse
UnifiedLLM-->>User: 迭代器返回分块响应
else 非流式模式
API-->>Adapter: 完整响应
Adapter-->>UnifiedLLM: ModelResponse对象
UnifiedLLM-->>User: 完整响应内容
end
```
# 类关系图
```mermaid
classDiagram
class ModelConfig {
+provider: ModelProvider
+model_name: str
+api_key: Optional[str]
+base_url: Optional[str]
+temperature: float
+max_tokens: int
+to_dict() Dict
}
class ChatMessage {
+role: str
+content: str
+name: Optional[str]
+to_dict() Dict
}
class ModelResponse {
+content: str
+model: str
+usage: Optional[Dict]
+finish_reason: Optional[str]
+raw_response: Optional[Dict]
}
class BaseLLMClient {
<<abstract>>
#config: ModelConfig
#client: Any
+__init__(config: ModelConfig)
+_initialize_client()
+chat_completion(messages, **kwargs)*
+completion(prompt, **kwargs)*
+format_messages(messages) List[Dict]
}
class UnifiedLLM {
-config: ModelConfig
-client: BaseLLMClient
+__init__(config: ModelConfig)
+_create_client() BaseLLMClient
+update_config(config: ModelConfig)
+chat(messages, **kwargs) ModelResponse
+complete(prompt, **kwargs) ModelResponse
+stream_chat(messages, **kwargs) Iterator
+stream_complete(prompt, **kwargs) Iterator
}
class OpenAIClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
class AnthropicClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
class OllamaClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
class GenericLLMClient {
+_initialize_client()
+chat_completion(messages, **kwargs)
+completion(prompt, **kwargs)
}
BaseLLMClient <|-- OpenAIClient
BaseLLMClient <|-- AnthropicClient
BaseLLMClient <|-- OllamaClient
BaseLLMClient <|-- GenericLLMClient
UnifiedLLM o-- BaseLLMClient
UnifiedLLM --> ModelConfig
BaseLLMClient --> ModelConfig
BaseLLMClient --> ChatMessage
BaseLLMClient --> ModelResponse
```
View File
+26
View File
@@ -0,0 +1,26 @@
本模块为tts模块,即文本转语音模块。
本模块负责将来自AI的回复转为语音。
说明:
在本module当中,每个子模块的用途分别是:
- tts_core 对不同的tts的实现,提供相对统一的接口
- gpt_sovits
实现了gpt_sovits的tts接口封装
async_audio_player.py
```mermaid
sequenceDiagram
participant TTS as GPT-SoVITS API
participant WS as WebSocket服务
participant Buffer as 音频缓冲区
participant Player as 音频播放器
TTS->>WS: 流式音频块(chunks)
WS->>Buffer: 写入队列(Queue)
Buffer->>Player: 消费PCM数据
Player->>声卡: 实时播放
Note over TTS,Player: 三重缓冲 + 动态采样率检测
```
@@ -0,0 +1,164 @@
# tts_core/async_audio_player.py
import asyncio
import io
from loguru import logger
from typing import Optional
import numpy as np
import sounddevice as sd
import wave
class AsyncAudioPlayer:
"""
异步流式音频播放器
- 自动检测WAV头并解析采样率
- 使用环形缓冲区确保播放流畅
- 支持动态音频格式切换
"""
def __init__(self, buffer_size: int = 10):
"""
Args:
buffer_size: 音频块缓冲数量(越大越稳定,但延迟越高)
"""
self.audio_queue = asyncio.Queue(maxsize=buffer_size)
self.sample_rate = 32000 # 默认采样率
self.channels = 1
self.dtype = np.float32
self.stream: Optional[sd.OutputStream] = None
self.is_playing = False
self._first_chunk_processed = False
logger.info(f"🎵 音频播放器初始化,缓冲区大小: {buffer_size}")
async def add_chunk(self, audio_data: bytes):
"""
添加音频块到播放队列
自动处理第一个chunk(包含WAV头)
"""
try:
# 第一个chunk需要解析WAV头
if not self._first_chunk_processed:
# 写入BytesIO以便wave模块读取
wav_buffer = io.BytesIO(audio_data)
try:
with wave.open(wav_buffer, 'rb') as wav_file:
# 解析WAV头信息
self.sample_rate = wav_file.getframerate()
self.channels = wav_file.getnchannels()
self.sampwidth = wav_file.getsampwidth()
# 读取PCM数据(去掉头部)
pcm_data = wav_file.readframes(wav_file.getnframes())
logger.info(f"📊 解析WAV头: {self.sample_rate}Hz, {self.channels}ch, {self.sampwidth * 8}bit")
# 转换为numpy数组
if self.sampwidth == 2:
audio_array = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
elif self.sampwidth == 4:
audio_array = np.frombuffer(pcm_data, dtype=np.int32).astype(np.float32) / 2147483648.0
else:
raise ValueError(f"不支持的采样宽度: {self.sampwidth}")
# 转单声道(如果多声道)
if self.channels > 1:
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
await self.audio_queue.put(audio_array)
self._first_chunk_processed = True
except wave.Error:
# 可能是不完整的WAV头,尝试直接播放
logger.warning("⚠️ WAV头解析失败,尝试直接播放")
await self._play_raw(audio_data)
return
else:
# 后续chunk直接播放(RAW PCM
await self._play_raw(audio_data)
except Exception as e:
logger.error(f"❌ 音频块处理失败: {e}")
async def _play_raw(self, audio_data: bytes):
"""播放RAW PCM数据"""
try:
# 假设是16位PCM(最常见)
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
# 如果是多声道数据(罕见)
if len(audio_array) % self.channels == 0 and self.channels > 1:
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
await self.audio_queue.put(audio_array)
except Exception as e:
logger.error(f"❌ RAW音频处理失败: {e}")
async def play_worker(self):
"""后台播放任务"""
logger.info("🎧 音频播放任务启动")
while self.is_playing or not self.audio_queue.empty():
try:
# 从队列获取音频块(最多等待0.5秒)
audio_chunk = await asyncio.wait_for(self.audio_queue.get(), timeout=0.5)
# 延迟初始化音频流(直到获得第一个数据块)
if self.stream is None:
logger.info(f"🔊 打开音频输出流: {self.sample_rate}Hz")
self.stream = sd.OutputStream(
samplerate=self.sample_rate,
channels=1,
dtype=self.dtype,
blocksize=1024, # 低延迟模式
latency='low'
)
self.stream.start()
# 写入音频流播放
self.stream.write(audio_chunk)
# 标记任务完成
self.audio_queue.task_done()
except asyncio.TimeoutError:
continue
except Exception as e:
logger.error(f"❌ 播放任务异常: {e}")
break
logger.info("🛑 音频播放任务结束")
async def start(self):
"""启动播放系统"""
self.is_playing = True
self._first_chunk_processed = False
self.play_task = asyncio.create_task(self.play_worker())
async def stop(self):
"""停止播放并清理资源"""
self.is_playing = False
# 等待播放任务结束
if hasattr(self, 'play_task'):
await self.play_task
# 关闭音频流
if self.stream is not None:
self.stream.stop()
self.stream.close()
self.stream = None
# 清空队列
while not self.audio_queue.empty():
try:
self.audio_queue.get_nowait()
except:
break
logger.info("✅ 音频播放已停止")
async def __aenter__(self):
await self.start()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.stop()
@@ -0,0 +1,378 @@
# gpt_sovits/gpt_sovits_client.py
import asyncio
import json
from loguru import logger
from enum import Enum
from pathlib import Path
from typing import AsyncGenerator, Optional, Union, Dict, Any
from dataclasses import dataclass
import httpx
from pydantic import BaseModel, Field, validator
class APIError(Exception):
"""API调用异常"""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(f"API Error {status_code}: {message}")
class StreamingMode(Enum):
"""流式模式枚举"""
DISABLED = 0 # 非流式
BEST_QUALITY = 1 # 最佳质量(慢)
MEDIUM_QUALITY = 2 # 中等质量
FASTEST = 3 # 最快响应(较低质量)
class TTSConfig(BaseModel):
"""TTS请求配置模型"""
text: str = Field(..., description="待合成文本")
text_lang: str = Field(..., description="文本语言: zh/en/ja/ko/cantonese")
ref_audio_path: str = Field(..., description="参考音频路径")
prompt_lang: str = Field(..., description="提示文本语言")
# 可选参数
prompt_text: str = Field(default="", description="参考音频提示文本")
aux_ref_audio_paths: list = Field(default_factory=list, description="辅助参考音频")
top_k: int = Field(default=5, ge=1, le=100, description="Top-K采样")
top_p: float = Field(default=1.0, ge=0.1, le=1.0, description="Top-P采样")
temperature: float = Field(default=1.0, ge=0.1, le=1.0, description="采样温度")
text_split_method: str = Field(default="cut5", description="文本分割方法") # 默认按照标点符号切分
batch_size: int = Field(default=8, ge=1, le=200, description="批处理大小")
speed_factor: float = Field(default=1.0, ge=0.6, le=1.65, description="语速倍率")
# 流式相关
streaming_mode: Union[bool, int, StreamingMode] = Field(default=False, description="流式模式")
media_type: str = Field(default="wav", description="输出格式: wav/raw/ogg/aac") # 输出格式
# 高级参数
repetition_penalty: float = Field(default=1.35, ge=1.0, le=2.0) # 惩罚参数
sample_steps: int = Field(default=32, ge=10, le=100) # 采样步数
parallel_infer: bool = Field(default=True) # 并行推理
@validator('text_lang', 'prompt_lang')
def validate_language(cls, v):
"""验证语言代码"""
valid_langs = {'zh', 'en', 'ja', 'ko', 'cantonese'}
if v.lower() not in valid_langs:
raise ValueError(f"Unsupported language: {v}. Must be one of {valid_langs}")
return v.lower()
@validator('media_type')
def validate_media_type(cls, v):
"""验证媒体类型"""
valid_types = {'wav', 'raw', 'ogg', 'aac'}
if v not in valid_types:
raise ValueError(f"Unsupported media_type: {v}")
return v
def build_request(self) -> Dict[str, Any]:
"""构建API请求数据"""
data = self.dict(exclude_none=True)
# 处理流式模式
if isinstance(self.streaming_mode, StreamingMode):
data['streaming_mode'] = self.streaming_mode.value
return data
@dataclass
class AudioResponse:
"""音频响应包装类"""
audio_data: bytes
sample_rate: int = 32000
def save(self, path: Union[str, Path]) -> None:
"""保存音频文件"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, 'wb') as f:
f.write(self.audio_data)
logger.info(f"Audio saved to {path}, size: {len(self.audio_data)} bytes")
class GPTSoVITSClient:
"""
GPT-SoVITS异步API客户端
完整支持所有TTS功能:
- 文本合成(流式/非流式)
- 模型切换(GPT/SoVITS
- 参考音频设置
- 服务器控制
"""
def __init__(self, host: str = "127.0.0.1", port: int = 9880, debug: bool = False):
"""
初始化客户端
Args:
host: API服务器地址
port: API端口
debug: 是否开启调试模式
"""
self.base_url = f"http://{host}:{port}"
self.client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(30.0, connect=5.0)
)
self.debug_mode = debug
logger.info(f"GPT-SoVITS Client initialized: {self.base_url}")
async def __aenter__(self) -> "GPTSoVITSClient":
"""异步上下文管理器入口"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.close()
async def close(self):
"""关闭HTTP连接"""
await self.client.aclose()
logger.info("Client connection closed")
def _log_debug(self, message: str, **kwargs):
"""调试日志"""
if self.debug_mode:
logger.debug(f"{message} | {kwargs}")
async def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
"""统一响应处理"""
if response.status_code == 200:
content_type = response.headers.get('content-type', '')
if 'application/json' in content_type:
return response.json()
return {"status": "success", "content": response.content}
else:
try:
error_data = response.json()
raise APIError(response.status_code, error_data.get('message', 'Unknown error'))
except json.JSONDecodeError:
raise APIError(response.status_code, response.text)
# 核心TTS接口
async def tts(
self,
text: str,
ref_audio_path: str,
text_lang: str = "zh",
prompt_lang: str = "zh",
streaming_mode: StreamingMode = StreamingMode.DISABLED, # 默认禁用流式
media_type: str = "wav",
**kwargs
) -> Union[AudioResponse, AsyncGenerator[AudioResponse, None]]:
"""
文本转语音(支持流式)
Args:
text: 待合成文本
ref_audio_path: 参考音频路径(服务器本地路径或URL)
text_lang: 文本语言
prompt_lang: 提示语言
streaming_mode: 流式模式
media_type: 输出格式
**kwargs: 其他TTS参数
Returns:
非流式: AudioResponse对象
流式: AsyncGenerator[AudioResponse, None]异步生成器
Example:
# 非流式
audio = await client.tts("你好", "ref.wav")
# 流式
async for chunk in client.tts("你好", "ref.wav", streaming_mode=StreamingMode.FASTEST):
process(chunk.audio_data)
"""
config = TTSConfig(
text=text,
ref_audio_path=ref_audio_path,
text_lang=text_lang,
prompt_lang=prompt_lang,
streaming_mode=streaming_mode,
media_type=media_type,
**kwargs
)
self._log_debug("TTS Request", config=config.dict())
if streaming_mode == StreamingMode.DISABLED:
# 非流式模式
response = await self.client.post("/tts", json=config.build_request())
if response.status_code != 200:
raise APIError(response.status_code, await response.text())
return AudioResponse(
audio_data=response.content,
sample_rate=32000 # 默认采样率
)
else:
# 流式模式
config.parallel_infer = False # 强制关闭并行推理,避免与流式冲突
config.batch_size = 1 # 流式下batch_size必须为1
async def stream_generator():
async with self.client.stream(
"POST", "/tts",
json=config.build_request(),
timeout=httpx.Timeout(60.0) # 流式需要更长超时
) as response:
if response.status_code != 200:
raise APIError(response.status_code, await response.aread())
async for chunk in response.aiter_bytes():
if chunk:
yield AudioResponse(audio_data=chunk)
return stream_generator()
# 模型管理接口
async def set_gpt_weights(self, weights_path: str) -> bool:
"""
切换GPT模型权重
Args:
weights_path: 权重文件路径(服务器本地路径)
Returns:
bool: 是否成功
Example:
await client.set_gpt_weights("models/s1bert.ckpt")
"""
if not weights_path:
raise ValueError("weights_path cannot be empty")
params = {"weights_path": weights_path}
response = await self.client.get("/set_gpt_weights", params=params)
result = await self._handle_response(response)
logger.info(f"GPT weights switched to: {weights_path}")
return True
async def set_sovits_weights(self, weights_path: str) -> bool:
"""
切换SoVITS模型权重
Args:
weights_path: 权重文件路径(服务器本地路径)
Returns:
bool: 是否成功
"""
if not weights_path:
raise ValueError("weights_path cannot be empty")
params = {"weights_path": weights_path}
response = await self.client.get("/set_sovits_weights", params=params)
await self._handle_response(response)
logger.info(f"SoVITS weights switched to: {weights_path}")
return True
# 参考音频管理
async def set_refer_audio(
self,
audio_source: Union[str, Path, bytes],
audio_name: Optional[str] = None
) -> bool:
"""
设置参考音频(支持多种输入方式)
Args:
audio_source: 音频文件路径(str/Path)或音频数据(bytes
audio_name: 音频文件名(仅bytes输入时需要)
Returns:
bool: 是否成功
Example:
# 方式1: 服务器本地文件
await client.set_refer_audio("/path/to/audio.wav")
# 方式2: 上传音频数据
with open("audio.wav", "rb") as f:
await client.set_refer_audio(f.read(), "audio.wav")
"""
if isinstance(audio_source, (str, Path)):
# GET方式:服务器本地路径
params = {"refer_audio_path": str(audio_source)}
response = await self.client.get("/set_refer_audio", params=params)
await self._handle_response(response)
logger.info(f"Reference audio set: {audio_source}")
else:
# POST方式:上传音频数据
if not audio_name:
raise ValueError("audio_name is required when uploading bytes")
files = {"audio_file": (audio_name, audio_source, "audio/wav")}
response = await self.client.post("/set_refer_audio", files=files)
await self._handle_response(response)
logger.info(f"Reference audio uploaded: {audio_name}")
return True
# 服务器控制
async def control_command(self, command: str) -> bool:
"""
发送控制命令
Args:
command: 命令类型 - "restart""exit"
Returns:
bool: 是否成功
Warning:
"exit"命令会终止API服务器进程!
"""
if command not in ["restart", "exit"]:
raise ValueError("Command must be 'restart' or 'exit'")
response = await self.client.get("/control", params={"command": command})
await self._handle_response(response)
logger.warning(f"Control command executed: {command}")
return True
# 高级快捷方法
async def get_server_info(self) -> Dict[str, Any]:
"""获取服务器状态信息"""
# 通过调用根路径或自定义health接口
try:
response = await self.client.get("/")
return {"status": "online", "detail": response.text}
except Exception as e:
return {"status": "error", "detail": str(e)}
async def batch_tts(
self,
texts: list[str],
ref_audio_path: str,
**kwargs
) -> list[AudioResponse]:
"""
批量TTS合成
Args:
texts: 文本列表
ref_audio_path: 参考音频
**kwargs: 其他TTS参数
Returns:
list[AudioResponse]: 音频响应列表
"""
tasks = [
self.tts(text, ref_audio_path, **kwargs)
for text in texts
]
return await asyncio.gather(*tasks)
# 异步上下文管理器辅助函数
async def create_client(*args, **kwargs) -> GPTSoVITSClient:
"""快速创建客户端实例"""
return GPTSoVITSClient(*args, **kwargs)
@@ -0,0 +1,27 @@
# dto/dto_base.py
from abc import ABC
from typing import Callable, Coroutine, Any, Dict
from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer
class MessageDTO(ABC):
"""DTO基类"""
def __init__(
self,
ws_server : WebSocketServer, # WebSocketServer单例
):
# 保存服务器实例(用于发送)
self.ws_server = ws_server
# 便捷属性,DTO层直接调用
@property
def send_binary(self):
return self.ws_server.send_binary
@property
def send_text(self):
return self.ws_server.send_text
@property
def send_json(self):
return self.ws_server.send_json
@@ -0,0 +1,95 @@
from pydantic import Field, BaseModel
import base64
from datetime import datetime, timezone
class AudioDataTransferObject(BaseModel):
"""
音频数据传输对象
该对象被用于服务端与客户端的音频数据交互
同时支持流式与非流式的音频数据
同时收发对等(通过Owner标识)
"""
Owner: str = Field(default="server", description="音频数据的拥有者(server or client)")
isStream: bool = Field(default=False, description="音频数据是否为流式数据")
isStart: bool = Field(default=False, description="音频数据是否开始(流式时有效)")
isEnd: bool = Field(default=False, description="音频数据是否结束(流式时有效)")
sequence: int = Field(default=0, description="音频数据块序列号(流式时有效)")
data: bytes = Field(default=b"", description="音频数据,流式时为分块数据,base64编码")
sampleRate: int = Field(default=32000, description="音频采样率")
channelCount: int = Field(default=1, description="音频通道数")
bitDepth: int = Field(default=16, description="音频采样位数")
duration: float = Field(default=0.0, description="音频时长")
text: str = Field(default="", description="音频对应的文本")
def set_dto_data(self, **kwargs) -> "AudioDataTransferObject":
"""链式更新数据(Pydantic 风格)"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def to_json(self) -> dict:
"""
将DTO对象转换为可序列化字典
返回的json数据的格式:
{
"type": "audio_data",
"timestamp": 1672531200.0,
"data": {
"Owner": "server",
......
}
}
"""
# model_dump() 是 Pydantic v2 的序列化方法
payload = self.model_dump() # 获取所有模型字段
payload["data"] = base64.b64encode(payload["data"]).decode() # base64编码
# 构造嵌套结构
return {
"type": "audio_data",
"timestamp": datetime.now(timezone.utc).timestamp(),
"data": payload # 音频字段嵌套
}
@classmethod
def from_json(cls, json_data: dict) -> "AudioDataTransferObject":
"""
从JSON数据创建DTO对象
传入的json数据格式:
{
"Owner": "server",
......
}
"""
payload = json_data
# 解码 base64 (内层 data 字段在传输时是 base64 字符串) -> bytes
if "data" in payload and isinstance(payload["data"], str):
payload["data"] = base64.b64decode(payload["data"])
# 构造对象 Pydantic 自动忽略 type/timestamp
return cls.model_validate(payload)
# 测试代码
if __name__ == "__main__":
# 模拟音频数据
import os
test_audio = os.urandom(1024) # 随机生成1KB音频数据
# 创建DTO
audio = AudioDataTransferObject(
data=test_audio,
sequence=1,
isStream=True,
sampleRate=44100,
duration=0.5
)
# 序列化 → JSON
json_dict = audio.to_json()
print(f"序列化后 data 长度: {len(json_dict['data'])}") # ~1368 字符
print(f"data 前30字符: {json_dict['data'][:30]}...")
# 反序列化 → DTO
restored = AudioDataTransferObject.from_json(json_dict)
print(f"反序列化后 data 长度: {len(restored.data)}") # 1024 bytes
print(f"数据一致: {restored.data == test_audio}") # True
@@ -0,0 +1,73 @@
from pydantic import Field, BaseModel
from datetime import datetime, timezone
class AutoAgentDataTransferObject(BaseModel):
"""
自动化agent数据传输对象
该对象被用于服务端向客户端发送控制信息
"""
Action: str = Field(default="", description="自动化动作名称")
x1: int = Field(default=-1, description="鼠标起始位置x1")
y1: int = Field(default=-1, description="鼠标起始位置y1")
x2: int = Field(default=-1, description="鼠标结束位置x2")
y2: int = Field(default=-1, description="鼠标结束位置y2")
key: str = Field(default="", description="快捷键")
content: str = Field(default="", description="输入文本内容")
direction: str = Field(default="", description="滚动方向")
def set_dto_data(self, **kwargs) -> "AutoAgentDataTransferObject":
"""链式更新数据(Pydantic 风格)"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def to_json(self) -> dict:
"""
将DTO对象转换为可序列化字典
返回的json数据的格式:
{
"type": "auto_agent",
"timestamp": 1672531200.0,
"data": {
"Action": "自动化agent返回的相应的自动化动作名称",
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
"y1": "某个操作的y1",
"x2": "某个操作的x2",
"y2": "某个操作的y2",
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
"direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
}
}
"""
# model_dump() 是 Pydantic v2 的序列化方法
payload = self.model_dump() # 获取所有模型字段
# 构造嵌套结构
return {
"type": "auto_agent",
"timestamp": datetime.now(timezone.utc).timestamp(),
"data": payload # 字段嵌套
}
@classmethod
def from_json(cls, json_data: dict) -> "AutoAgentDataTransferObject":
"""
从JSON数据创建DTO对象
传入的json数据格式:
{
"type": "auto_agent",
"Action": "自动化agent返回的相应的自动化动作名称",
"x1": "某个操作的x1,不是所有的操作都有,如果相关的操作没有,写成-1即可,y1,x2,y2也同理",
"y1": "某个操作的y1",
"x2": "某个操作的x2",
"y2": "某个操作的y2",
"key": "快捷键,若当前操作没有该字段信息,此处内容为空即可",
"content": "输入文本的内容,若当前操作没有该字段信息,此处内容为空即可",
"direction": "滚动方向,若当前操作没有该字段信息,此处内容为空即可"
}
"""
payload = json_data
payload.pop("type", None) # 移除多余字段
# 构造对象 Pydantic 自动忽略 type/timestamp
return cls.model_validate(payload)
@@ -0,0 +1,37 @@
class BaseDataTransferObject:
"""
DTO基类
子类按需重写成员函数即可
"""
def __init__(self):
pass
def to_json(self):
"""
将DTO对象转换为JSON
"""
pass
def from_json(self, json_data):
"""
从JSON数据中创建DTO对象
"""
pass
def to_binary(self):
"""
将DTO对象转换为二进制
"""
pass
def from_binary(self, binary_data):
"""
从二进制数据中创建DTO对象
"""
pass
def to_text(self):
"""
将DTO对象转换为文本
"""
pass
def from_text(self, text_data):
"""
从文本数据中创建DTO对象
"""
pass
@@ -0,0 +1,72 @@
from pydantic import Field, BaseModel
from datetime import datetime, timezone
class ScreenShotDataTransferObject(BaseModel):
"""
服务端向客户端请求实时截图的数据传输对象
服务端与客户端收发对等(通过Owner标识)
客户端收到这个type的包,就会自动对当前设备的画面进行截图
"""
Owner: str = Field(default="server", description="数据的拥有者(server or client)")
isSuccess: bool = Field(default=False, description="是否截图成功")
RealTimeScreenShot: str = Field(default="", description="客户端设备的实时截图数据(base64)")
Width: int = Field(default=1920, description="截图的宽度")
Height: int = Field(default=1080, description="截图的高度")
DescribeInfo: str = Field(default="", description="设备的描述信息(告知模型以做出更加准确的判断)")
LLMResponse: str = Field(default="", description="LLM的响应结果(由服务端发送时携带)")
def set_dto_data(self, **kwargs) -> "ScreenShotDataTransferObject":
"""链式更新数据(Pydantic 风格)"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def to_json(self) -> dict:
"""
将DTO对象转换为可序列化字典
返回的json数据的格式:
{
"type": "screenshot_data",
"timestamp": 1672531200.0,
"data": {
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)"
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度",
"Height": "截图的高度",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)",
"LLMResponse": "LLM的响应结果(由服务端发送时携带)"
}
}
"""
# model_dump() 是 Pydantic v2 的序列化方法
payload = self.model_dump() # 获取所有模型字段
# 构造嵌套结构
return {
"type": "screenshot_data",
"timestamp": datetime.now(timezone.utc).timestamp(),
"data": payload # 字段嵌套
}
@classmethod
def from_json(cls, json_data: dict) -> "ScreenShotDataTransferObject":
"""
从JSON数据创建DTO对象
传入的json数据格式:
{
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)",
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度 非必要字段",
"Height": "截图的高度 非必要字段",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断) 非必要字段",
"LLMResponse": "LLM的响应结果(由服务端发送时携带) 必要字段"
}
"""
payload = json_data
payload.pop("type", None) # 移除多余字段
payload.pop("timestamp", None) # 移除多余字段
# 构造对象 Pydantic 自动忽略 type/timestamp
return cls.model_validate(payload)
@@ -0,0 +1,95 @@
# dto/second_dtos.py
import asyncio
from typing import Callable, Optional, Dict, Any, List, Coroutine
from src.modules.websocket_base_module.dto.dto_base import MessageDTO
from loguru import logger
"""
二级分发器,因为没有信号与槽机制,因此使用观察者模式替代
"""
def singleton(cls): # 单例
_instance = None
_lock = asyncio.Lock()
async def get_instance(*args, **kwargs):
nonlocal _instance
if _instance is None:
async with _lock:
if _instance is None:
_instance = cls(*args, **kwargs)
return _instance
cls.get_instance = get_instance
return cls
# 类型别名
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
@singleton
class JsonDTO(MessageDTO):
"""针对json消息的二级分发"""
"""
明确业务json格式:
{
"type" : "xxx",
"timestamp" : "95153...",
"data" : "{根据业务的不同,有不同的内容}"
}
"""
# 因为不同的数据块的json以type字段进行包装,根据type进行正确的数据分发
def __init__(self, ws_server):
super().__init__(ws_server)
self.receivers : Dict[str, List[ReceiveCallback]] = {
'audio_data' : [], # 音频数据
'screenshot_data' : [] # 截图数据
}
# 注册json处理callback function
ws_server.register_receiver('json', self._handle_json)
logger.info("[JsonDTO] JSON分发器已注册")
def register_receiver(self, types : str, callback : ReceiveCallback):
"""注册二次分发业务接收函数,供业务DTO调用"""
if types in self.receivers:
self.receivers[types].append(callback)
logger.debug(f"[JsonDTO] 已注册 {types} 接收器,当前共 {len(self.receivers[types])}")
else:
raise ValueError(f"[JsonDTO] 不支持的分发类型: {types}")
def unregister_receiver(self, types : str, callback : ReceiveCallback):
"""注销二次分发业务接收函数"""
if callback in self.receivers[types]:
self.receivers[types].remove(callback)
logger.debug(f"[JsonDTO] 已注销 {types} 接收器")
async def _handle_json(self, data: dict):
"""JSON消息处理"""
logger.info(f"[JsonDTO] 收到消息")
logger.debug(f'[JsonDTO] 当前消息时间戳: {data["timestamp"]}')
# 根据类型进行自动分发
await self._dispatch(data.get("type"), data["data"])
async def _dispatch(self, types : str, data : dict):
"""二次分发json数据到相应的接收函数当中"""
callbacks = self.receivers[types] # 获取相关types的所有观察者
if not callbacks:
logger.info(f"[JsonDTO] 无 {types} 接收器,消息被忽略")
return
# 并发执行所有回调
tasks = [callback(data) for callback in callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
async def get_json_dto_instance(ws_server) -> JsonDTO:
return JsonDTO(ws_server)
class EchoDTO(MessageDTO):
"""回声DTO:只处理文本消息 测试用"""
def __init__(self, ws_server):
super().__init__(ws_server)
# 注册文本接收函数
ws_server.register_receiver('text', self._handle_text)
logger.info("[EchoDTO] 文本接收器已注册")
async def _handle_text(self, message: str):
"""文本消息处理"""
logger.info(f"[EchoDTO] 收到文本: {message}")
# 业务逻辑
await self.send_text(f"Echo: {message}")
@@ -0,0 +1,281 @@
import asyncio
from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
from loguru import logger
from typing import Callable, List, Optional, Coroutine
class AudioDataDTO:
"""音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)"""
def __init__(self, json_dto : JsonDTO):
json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数
logger.info("[AudioDataDTO] 音频接收业务已注册")
self.json_dto = json_dto
self.audio_data = AudioDataTransferObject() # 音频数据对象
# 业务回调列表,延续观察者模式
self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
# 最新音频缓存 支持同步查询
self._latest_audio: Optional[AudioDataTransferObject] = None
# 流式缓冲区 用于大段音频流
self._stream_buffer: List[AudioDataTransferObject] = []
logger.info("[AudioDataDTO] 业务接口已初始化")
async def _handle_audio_data(self, data: dict):
"""处理音频数据"""
"""
音频数据json格式:
{
'Owner': 'server',
'isStream': False,
'isStart': False,
'isEnd': False,
'sequence': 0,
'data': '',
'sampleRate': 16000,
'channelCount': 1,
'bitDepth': 16,
'duration': 0.0,
'text': ''
}
"""
logger.debug(f"[AudioDataDTO] 收到音频数据")
# 将dict反序列化到DTO对象
self.audio_data = AudioDataTransferObject.from_json(data)
# 缓存最新数据
self._latest_audio = self.audio_data
# 如果是流式数据,加入缓冲区
if self.audio_data.isStream:
self._stream_buffer.append(self.audio_data)
if self.audio_data.isEnd:
logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)}")
# 通知所有注册的回调
await self._notify_callbacks()
# 业务发送接口
async def send_audio_data(self, data: AudioDataTransferObject) -> None:
"""
发送音频数据
Args:
data: 音频数据DTO
"""
await self.send_audio(
Owner=data.Owner,
is_stream=data.isStream,
is_start=data.isStart,
is_end=data.isEnd,
sequence=data.sequence,
data=data.data,
sampleRate=data.sampleRate,
channelCount=data.channelCount,
bitDepth=data.bitDepth,
duration=data.duration,
text=data.text
)
async def send_audio(
self,
data: bytes,
is_stream: bool = False,
is_start: bool = False,
is_end: bool = False,
sequence: int = 0,
**audio_meta
) -> None:
"""
业务层发送音频的便捷接口
Args:
data: 原始音频字节
is_stream: 是否为流式数据
is_start: 流式数据开始标记
is_end: 流式数据结束标记
sequence: 数据块序号
**audio_meta: 其他音频参数(sampleRate, channelCount等)
"""
# 填充音频数据到DTO
self.audio_data.set_dto_data(
Owner="server" or audio_meta.get('Owner', "server"),
isStream=is_stream,
isStart=is_start,
isEnd=is_end,
sequence=sequence,
data=data,
sampleRate=audio_meta.get('sampleRate', 16000),
channelCount=audio_meta.get('channelCount', 1),
bitDepth=audio_meta.get('bitDepth', 16),
duration=audio_meta.get('duration', 0.0),
text=audio_meta.get('text', "")
)
# 序列化为JSON并发送 自动处理base64和type字段
json_message = self.audio_data.to_json()
await self.json_dto.send_json(json_message)
logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")
# 业务接收接口
def register_audio_callback(
self,
callback: Callable[[AudioDataTransferObject], Coroutine]
) -> None:
"""
业务注册接收回调
使用示例:
async def my_audio_handler(audio_dto: AudioDataTransferObject):
print(f"收到音频: {len(audio_dto.data)} bytes")
audio_dto.register_audio_callback(my_audio_handler)
"""
self._audio_callbacks.append(callback)
logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)}")
def unregister_audio_callback(self, callback) -> None:
"""注销业务回调"""
if callback in self._audio_callbacks:
self._audio_callbacks.remove(callback)
logger.debug("业务音频回调已注销")
def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
"""
同步获取最新音频数据(轮询模式)
Returns:
最新接收到的音频DTO,如果没有则为 None
"""
return self._latest_audio
def clear_stream_buffer(self) -> None:
"""清空流式缓冲区"""
self._stream_buffer.clear()
logger.debug("流式音频缓冲区已清空")
# 内部通知机制
async def _notify_callbacks(self) -> None:
"""通知所有业务回调"""
if not self._audio_callbacks:
logger.warning("无业务回调,音频数据未处理")
return
tasks = [callback(self.audio_data) for callback in self._audio_callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调")
# 异步迭代器 流式时使用
def __aiter__(self):
"""支持 async for 循环接收流式音频"""
return self
async def __anext__(self) -> AudioDataTransferObject:
"""异步迭代器协议"""
pass
class ScreenShotDataDTO:
"""截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)"""
def __init__(self, json_dto : JsonDTO):
json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数
logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
self.json_dto = json_dto
self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象
# 业务回调列表,延续观察者模式
self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
# 最新截屏数据缓存 支持同步查询
self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
logger.info("[ScreenShotDataDTO] 业务接口已初始化")
async def _handle_screenshot_data(self, data: dict):
"""处理截屏数据"""
"""
截屏数据json格式:
{
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)"
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度",
"Height": "截图的高度",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)"
}
"""
logger.debug(f"[ScreenShotDataDTO] 收到截屏数据")
# 将dict反序列化到DTO对象
self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
# 缓存最新数据
self._latest_screenshot = self.screenshot_data
# 通知所有注册的回调
await self._notify_callbacks()
# 业务发送接口
async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
"""
发送音频数据
Args:
data: 音频数据DTO
"""
await self.send_screenshot(
Owner=data.Owner,
isSuccess=data.isSuccess,
RealTimeScreenShot=data.RealTimeScreenShot,
Width=data.Width,
Height=data.Height,
DescribeInfo=data.DescribeInfo,
LLMResponse=data.LLMResponse
)
async def send_screenshot(
self,
**screenshot_meta
) -> None:
"""
业务层发送音频的便捷接口
一般来说,作为发送请求方,不需要填充任何数据
Args:
**screenshot_meta: 截屏数据元信息
"""
# 填充音频数据到DTO
self.screenshot_data.set_dto_data(
Owner="server" or screenshot_meta.get('Owner', "server"),
isSuccess=screenshot_meta.get('isSuccess', False),
RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
Width=screenshot_meta.get('Width', 1920),
Height=screenshot_meta.get('Height', 1080),
DescribeInfo=screenshot_meta.get('DescribeInfo', False),
LLMResponse=screenshot_meta.get('LLMResponse', "")
)
# 序列化为JSON并发送 自动处理base64和type字段
json_message = self.screenshot_data.to_json()
await self.json_dto.send_json(json_message)
logger.info(f"截屏包已发送")
# 业务接收接口
def register_screenshot_callback(
self,
callback: Callable[[ScreenShotDataTransferObject], Coroutine]
) -> None:
"""
业务注册接收回调
"""
self._screenshot_callbacks.append(callback)
logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)}")
def unregister_screenshot_callback(self, callback) -> None:
"""注销业务回调"""
if callback in self._screenshot_callbacks:
self._screenshot_callbacks.remove(callback)
logger.debug("业务截屏回调已注销")
def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
"""
同步获取最新截屏数据(轮询模式)
Returns:
最新接收到的截屏数据DTO,如果没有则为 None
"""
return self._latest_screenshot
# 内部通知机制
async def _notify_callbacks(self) -> None:
"""通知所有业务回调"""
if not self._screenshot_callbacks:
logger.warning("无业务回调,截屏数据未处理")
return
tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调")
@@ -0,0 +1,56 @@
### 说明
在本module当中,每个子模块的用途分别是:
- dto
- dto_templates
服务端与客户端交互所使用到的数据传输对象
- dto_base.py / xxx_dtos.py
实际的DTO业务
- websocket_core
- websocket核心,承载了底层核心的网络收发业务
### 模块架构
```mermaid
graph TB
subgraph "Client"
C[WebSocket Client]
end
subgraph "Core Layer"
WS[WebSocketServer<br/>单例]
WS -->|持有| WSP[WebSocketServerProtocol<br/>_websocket]
WS -->|管理| RCV[ receivers: Dict<br/>binary/text/json ]
end
subgraph "DTO Base Layer"
MDTO[MessageDTO<br/>抽象基类]
MDTO -->|注入| MDF[ send_binary<br/>send_text<br/>send_json ]
end
subgraph "Secondary Dispatcher"
JDTO[JsonDTO<br/>单例]
JDTO -->|继承| MDTO
JDTO -->|维护| MAP[ receivers: Dict<br/>audio_data/... ]
JDTO -->|注册到| WS
end
subgraph "Business DTO"
ADTO[AudioDataDTO......<br/>业务实现]
ADTO -->|持有引用| JDTO
ADTO -->|使用| ATO[AudioDataTransferObject<br/>Pydantic模型]
end
C <-->|websocket连接| WSP
WS -->|分发消息| JDTO
JDTO -->|二次分发| ADTO
ADTO -->|发送响应| JDTO
JDTO -->|调用| MDF
MDF -->|经由| WS
WS -->|发送至| C
style WS fill:#64f,stroke:#333,stroke-width:2px
style JDTO fill:#569,stroke:#333,stroke-width:2px
style ADTO fill:#38f,stroke:#333,stroke-width:2px
```
@@ -0,0 +1,172 @@
# websocket_core/core_ws_server.py
import asyncio
import json
from typing import Callable, Optional, Dict, Any, List, Coroutine
from websockets.asyncio.server import serve, ServerConnection
from websockets.exceptions import ConnectionClosed
from loguru import logger
# 类型别名
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
class WebSocketServer:
"""WebSocket服务端核心模块(单例 + 单客户端)
只管理一个客户端连接,DTO层注册接收函数,服务端只负责分发。
"""
_instance: Optional["WebSocketServer"] = None
_lock = asyncio.Lock()
def __new__(cls):
"""同步单例(__init__ 可以是 async"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
# 防止重复初始化
if self._initialized:
return
self._websocket: Optional[ServerConnection] = None
self._receivers: Dict[str, List[ReceiveCallback]] = {
'binary': [],
'text': [],
'json': []
}
self._connected_event = asyncio.Event()
self._initialized = True
logger.info("WebSocketServer 单客户端分发器已初始化")
# DTO层注册接口
def register_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
"""注册接收函数(供DTO层调用)
Args:
msg_type: 消息类型(binary/text/json
callback: 接收回调,签名为 (data) -> coroutine
"""
if msg_type in self._receivers:
self._receivers[msg_type].append(callback)
logger.debug(f"已注册 {msg_type} 接收器,当前 {msg_type} 类型接收器共 {len(self._receivers[msg_type])}")
else:
raise ValueError(f"不支持的消息类型: {msg_type}")
def unregister_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
"""注销接收函数"""
if callback in self._receivers[msg_type]:
self._receivers[msg_type].remove(callback)
logger.debug(f"已注销 {msg_type} 接收器")
# 发送接口(供DTO层调用)
async def send_binary(self, data: bytes):
"""发送二进制数据(唯一客户端)"""
if self._websocket:
await self._websocket.send(data)
logger.trace(f"二进制数据已发送 (长度: {len(data)} bytes)")
else:
raise RuntimeError("客户端未连接")
async def send_text(self, data: str):
"""发送文本数据(唯一客户端)"""
if self._websocket:
await self._websocket.send(data)
logger.trace(f"文本数据已发送 (长度: {len(data)} chars)")
else:
raise RuntimeError("客户端未连接")
async def send_json(self, data: Dict[str, Any]):
"""发送JSON数据(唯一客户端)"""
if self._websocket:
try:
logger.debug(f"准备发送JSON数据: {data}")
message = json.dumps(data)
await self._websocket.send(message)
logger.trace(f"JSON数据已发送: {data}")
except Exception as e:
logger.error(f"JSON数据发送失败: {e}")
raise
else:
raise RuntimeError("客户端未连接")
# 等待连接
async def wait_for_client(self):
"""阻塞等待客户端连接"""
await self._connected_event.wait()
logger.info("客户端已就绪")
# 内部消息循环
async def _handle_client(self, websocket: ServerConnection):
"""处理唯一客户端的消息循环"""
self._websocket = websocket
self._connected_event.set()
client_info = f"{websocket.remote_address}" if hasattr(websocket, 'remote_address') else "unknown"
logger.info(f"客户端已连接: {client_info}")
try:
async for message in websocket:
# 根据消息类型分发到所有注册的接收函数
if isinstance(message, bytes):
await self._dispatch('binary', message)
elif isinstance(message, str):
# 优先尝试JSON解析
json_dispatched = False
try:
data = json.loads(message)
await self._dispatch('json', data)
json_dispatched = True
except json.JSONDecodeError:
pass
# 如果不是JSON或没有json接收器,尝试text
if not json_dispatched:
await self._dispatch('text', message)
else:
logger.warning(f"未知消息类型: {type(message)}")
logger.info("客户端连接已正常关闭")
except ConnectionClosed as e:
logger.info(f"客户端连接已关闭: {e}")
except Exception as e:
logger.error(f"处理客户端时发生错误: {e}")
finally:
self._websocket = None
self._connected_event.clear()
async def _dispatch(self, msg_type: str, data: Any):
"""分发消息到所有注册的接收函数"""
callbacks = self._receivers[msg_type]
if not callbacks:
logger.warning(f"{msg_type} 接收器,消息被忽略")
return
# 并发执行所有回调
tasks = [callback(data) for callback in callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
# 启动服务器
async def run(self, host: str = "localhost", port: int = 8765, max_msg_size: int = 50*1024*1025):
"""启动WebSocket服务器(阻塞)"""
logger.info(f"WebSocket服务器启动中... 等待客户端连接 ws://{host}:{port}")
async def handler_wrapper(connection):
logger.info(f"新连接请求: {connection.remote_address}")
await self._handle_client(connection)
try:
async with serve(
handler=handler_wrapper, # 使用 wrapper 适配签名
host=host,
port=port,
max_size=max_msg_size,
):
logger.success(f"WebSocket服务器已启动在 ws://{host}:{port}")
await asyncio.Future() # 永久阻塞,保持服务器运行
except Exception as e:
logger.error(f"WebSocket服务器启动失败: {e}")
raise
async def get_ws_server() -> WebSocketServer:
"""全局单例获取函数(线程安全)"""
server = WebSocketServer() # __new__ 保证单例
return server