first pull
This commit is contained in:
@@ -0,0 +1,123 @@
|
||||
# asr_module/api.py
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from src.modules.asr_module.asr_core.fast_whisper import create_asr, ASRConfig
|
||||
|
||||
# Initialise the FastAPI application that exposes the ASR endpoints.
app = FastAPI(
    title="Yosuga ASR API",
    description="基于faster-whisper Turbo的高性能多语种语音转文本服务",
    version="1.0.0"
)

# Process-wide ASR instance; stays None until get_asr() builds it.
_asr_instance = None
|
||||
|
||||
def get_asr():
    """Return the shared ASR instance, creating it on the first call.

    The instance is cached in the module-level ``_asr_instance``; the
    configuration below is therefore only applied once per process.
    """
    global _asr_instance
    if _asr_instance is not None:
        return _asr_instance

    logger.info("🚀 初始化ASR服务...")
    settings = ASRConfig(
        model_name="deepdml/faster-whisper-large-v3-turbo-ct2",
        device="auto",
        compute_type="int8_float16",
        cache_dir=Path("asr_models/faster_whisper_large_v3_ct2"),
        beam_size=1,  # greedy decoding, chosen for speed
        vad_filter=True,  # skip silence (original note claims ~30% time saved)
    )
    _asr_instance = create_asr(settings)
    logger.info("✅ ASR服务初始化完成")
    return _asr_instance
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Build the ASR instance at application startup instead of on the first request."""
    get_asr()
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Shut the ASR instance down when the application stops.

    Does nothing if the instance was never created.
    """
    global _asr_instance
    if not _asr_instance:
        return
    _asr_instance.shutdown()
    logger.info("🛑 ASR服务已关闭")
|
||||
|
||||
@app.post("/transcribe", response_class=JSONResponse)
async def transcribe_audio(
    file: UploadFile = File(..., description="音频文件 (WAV, FLAC, MP3等格式)")
):
    """
    Speech-to-text endpoint.

    - **file**: uploaded audio file (WAV/FLAC/MP3, ...)
    - **returns**: JSON with ``success`` and ``data`` holding
      text / language / confidence / processing_time

    Raises:
        HTTPException 400: the upload declares a non-``audio/*`` content type.
        HTTPException 500: reading the upload or the transcription failed.
    """
    start_time = time.time()

    # Reject uploads whose declared MIME type is present but not audio.
    if file.content_type and not file.content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="❌ 请上传音频文件 (MIME类型: audio/*)")

    tmp_path = None
    try:
        # The upload may carry no filename; fall back to an empty suffix
        # instead of failing on Path(None).
        suffix = Path(file.filename or "").suffix

        # delete=False: the file must outlive the `with` block so the ASR
        # backend can reopen it by path; it is removed in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_path = Path(tmp_file.name)
            content = await file.read()
            tmp_file.write(content)

        logger.info(f"📥 接收文件: {file.filename} ({len(content)} bytes)")

        asr = get_asr()
        text, language, confidence = asr.transcribe_wav(tmp_path)

        processing_time = time.time() - start_time

        logger.info(f"✅ 识别完成: {language} | {len(text)}字符 | 置信度:{confidence:.2f} | 耗时:{processing_time:.3f}s")

        return {
            "success": True,
            "data": {
                "text": text,
                "language": language,
                "confidence": confidence,
                "processing_time": round(processing_time, 3)
            }
        }

    except Exception as e:
        logger.error(f"❌ 识别失败: {e}")
        raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}") from e

    finally:
        # Always remove the temp file, including when transcription fails.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Report service health based on the ASR instance's own health check."""
    report = get_asr().health_check()
    is_healthy = report["status"] == "healthy"

    return {
        "status": "healthy" if is_healthy else "unhealthy",
        "timestamp": datetime.now().isoformat(),
        "device": report["device"],
        "model_loaded": report["model_loaded"],
    }
|
||||
|
||||
@app.get("/")
async def root():
    """Return a short service banner listing the available endpoints."""
    endpoints = {
        "message": "Yosuga ASR API 正在运行",
        "docs": "/docs",
        "health": "/health",
        "transcribe": "/transcribe (POST)",
    }
    return endpoints
|
||||
@@ -0,0 +1,16 @@
|
||||
# fast_whisper/__init__.py
|
||||
from typing import Optional
|
||||
from src.modules.asr_module.asr_core.fast_whisper.config import ASRConfig
|
||||
from src.modules.asr_module.asr_core.fast_whisper.model_manager import ModelManager
|
||||
from src.modules.asr_module.asr_core.fast_whisper.asr_interface import ASRInterface
|
||||
|
||||
__version__ = "1.0.0"

# Public API of the package; includes the create_asr factory defined below.
__all__ = ["ASRConfig", "ModelManager", "ASRInterface", "create_asr"]
|
||||
|
||||
def create_asr(config: Optional[ASRConfig] = None) -> ASRInterface:
    """
    Convenience factory for the shared ASR interface.

    Args:
        config: ASR configuration forwarded to ``ASRInterface.get_instance``;
            pass None to use the default configuration.

    Returns:
        The instance returned by ``ASRInterface.get_instance``.
    """
    instance = ASRInterface.get_instance(config)
    return instance
|
||||
@@ -0,0 +1,184 @@
|
||||
# fast_whisper/asr_interface.py
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
import torchaudio
|
||||
import torch
|
||||
import numpy
|
||||
|
||||
from .model_manager import ModelManager
|
||||
from .config import ASRConfig
|
||||
from .utils import PerformanceProfiler
|
||||
|
||||
class ASRInterface:
    """
    Singleton front end for speech-to-text transcription.

    - loads an audio file and converts it to text
    - owns the ModelManager that provides the model
    - records per-call performance figures through PerformanceProfiler
    """

    _instance: Optional['ASRInterface'] = None

    def __new__(cls, *args, **kwargs):
        """Create the shared instance on first construction and reuse it afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[ASRConfig] = None):
        """Initialise the shared instance once.

        Args:
            config: ASR configuration; a default ``ASRConfig()`` is used when
                None. Ignored when the instance is already initialised.
        """
        # __init__ runs on every ASRInterface(...) call because __new__
        # returns the same object; skip the work after the first success.
        if getattr(self, '_initialized', False):
            return

        self.config = config or ASRConfig()
        self.model_manager = ModelManager(self.config)
        self.profiler = PerformanceProfiler(self.config.enable_profiling)

        # Sample rate (Hz) that audio is resampled to before inference.
        self.sample_rate = 16000

        self._initialized = True
        logger.info("🎤 ASR接口初始化完成")

    @classmethod
    def get_instance(cls, config: Optional[ASRConfig] = None) -> 'ASRInterface':
        """Return the shared instance, constructing it if necessary.

        If a previous construction failed inside ``__init__``, ``__new__`` has
        already stored a half-built object; initialisation is retried in that
        case rather than handing the broken object out.
        """
        instance = cls._instance
        if instance is None or not getattr(instance, '_initialized', False):
            instance = cls(config)
        return instance

    def transcribe_wav(
        self,
        wav_path: Path,
        language: Optional[str] = None
    ) -> Tuple[str, str, float]:
        """
        Transcribe an audio file to text (main entry point).

        Args:
            wav_path: path of the audio file
            language: language code such as 'zh' / 'en'; None is passed
                through to the model as-is

        Returns:
            (text, language, confidence)

        Raises:
            RuntimeError: any failure while loading or transcribing; the
                original exception is attached as ``__cause__``.
        """
        import time  # kept local, as in the original

        try:
            start_time = time.time()

            logger.info(f"🎵 开始识别: {wav_path.name}")

            audio = self._load_audio(wav_path)
            result = self._transcribe(audio, language)
            text, lang, confidence = self._parse_result(result)

            processing_time = time.time() - start_time
            # Guard the real-time factor against empty audio so a successful
            # transcription is not turned into a ZeroDivisionError by logging.
            audio_duration = len(audio) / self.sample_rate
            rtf = processing_time / audio_duration if audio_duration > 0 else 0.0
            logger.info(
                f"✅ 识别完成: {lang} | {len(text)}字符 | 置信度:{confidence:.2f} | "
                f"耗时:{processing_time:.3f}s | RTF:{rtf:.3f}"
            )

            return text, lang, confidence

        except Exception as e:
            logger.error(f"❌ 识别失败 {wav_path}: {e}")
            raise RuntimeError(f"Transcription failed: {e}") from e

    def _load_audio(self, wav_path: Path) -> numpy.ndarray:
        """Load an audio file as a mono 1-D array at ``self.sample_rate``.

        Raises:
            FileNotFoundError: ``wav_path`` does not exist.
        """
        if not wav_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {wav_path}")

        waveform, sample_rate = torchaudio.load(wav_path)

        # Resample when the file's rate differs from the target rate.
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
            waveform = resampler(waveform)

        # Down-mix multi-channel audio by averaging over the first axis.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Drop only the channel axis; a bare squeeze() would also collapse a
        # one-sample signal to a 0-d array, on which len() fails.
        return waveform.squeeze(0).numpy()

    def _transcribe(self, audio: numpy.ndarray, language: Optional[str]) -> Tuple:
        """Run model inference and return ``(segments_list, info)``.

        Raises:
            RuntimeError: the model manager returned no model.
        """
        import time  # kept local, as in the original

        model = self.model_manager.model

        if model is None:
            logger.error("ASR模型未加载,请检查模型配置和路径")
            raise RuntimeError("ASR模型未加载,请检查模型配置和路径")

        start_time = time.time()

        segments, info = model.transcribe(
            audio,
            language=language,
            beam_size=self.config.beam_size,
            best_of=self.config.best_of,
            vad_filter=self.config.vad_filter,
        )

        # Materialise the segments here so the elapsed time below covers the
        # iteration (the original comment describes `segments` as a generator).
        segments_list = list(segments)

        inference_time = time.time() - start_time
        audio_duration = len(audio) / self.sample_rate
        self.profiler.record(audio_duration, inference_time)

        return segments_list, info

    def _parse_result(self, result: Tuple) -> Tuple[str, str, float]:
        """Flatten ``(segments, info)`` into ``(text, language, confidence)``."""
        segments, info = result

        # Join the stripped text of every segment with single spaces.
        text = " ".join(seg.text.strip() for seg in segments)

        language = info.language if info else "unknown"
        confidence = info.language_probability if info else 0.0

        return text, language, confidence

    def transcribe_batch(self, wav_paths: list) -> list:
        """Transcribe several files sequentially.

        Args:
            wav_paths: iterable of file paths (str or Path).

        Returns:
            One dict per input with keys file / text / language / confidence.
        """
        # Single pass over the input, so one-shot iterables work too.
        results = []
        for path in wav_paths:
            text, language, confidence = self.transcribe_wav(Path(path))
            results.append({
                "file": str(path),
                "text": text,
                "language": language,
                "confidence": confidence
            })
        return results

    def health_check(self) -> dict:
        """Return a status dict: status / device / model_loaded / device_info.

        Accessing ``model_manager.model`` triggers the lazy model load; a load
        failure is reported as "unhealthy" instead of propagating.
        """
        try:
            model_loaded = self.model_manager.model is not None
        except Exception as e:
            logger.error(f"❌ 健康检查失败: {e}")
            model_loaded = False

        return {
            "status": "healthy" if model_loaded else "unhealthy",
            "device": self.config.device,
            "model_loaded": model_loaded,
            "device_info": self.model_manager.get_device_info(),
        }

    def shutdown(self):
        """Unload the model through the model manager."""
        logger.info("🛑 关闭ASR接口...")
        self.model_manager.unload()
|
||||
@@ -0,0 +1,29 @@
|
||||
# fast_whisper/config.py
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
@dataclass
class ASRConfig:
    """Settings for the ASR model and decoding.

    After construction, ``device="auto"`` is resolved to "cuda" or "cpu", and
    a CPU device forces ``compute_type="int8"`` and ``batch_size=4``.
    """
    model_name: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
    device: str = "auto"
    compute_type: str = "int8_float16"
    cache_dir: Path = Path.home() / ".cache" / "faster_whisper"

    # Decoding / speed parameters
    beam_size: int = 1
    best_of: int = 1
    vad_filter: bool = True
    batch_size: int = 16

    # Performance statistics
    enable_profiling: bool = True

    def __post_init__(self):
        """Resolve the "auto" device and apply the CPU overrides."""
        if self.device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            else:
                self.device = "cpu"

        if self.device != "cpu":
            return

        self.compute_type = "int8"
        self.batch_size = 4
|
||||
@@ -0,0 +1,92 @@
|
||||
# fast_whisper/model_manager.py
|
||||
import gc
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from faster_whisper import WhisperModel
|
||||
import torch
|
||||
|
||||
from .config import ASRConfig
|
||||
|
||||
|
||||
class ModelManager:
    """
    Owns the WhisperModel life cycle.

    - loads the model lazily on first access and can unload / reload it
    - downloads into the configured cache directory
    - usable as a context manager that unloads on exit
    """

    def __init__(self, config: ASRConfig):
        """Store the configuration and make sure the cache directory exists."""
        self.config = config
        self._model: Optional[WhisperModel] = None
        self._device_info = None

        # Path(...) also accepts a cache_dir that was passed as a plain string.
        Path(self.config.cache_dir).mkdir(parents=True, exist_ok=True)

    @property
    def model(self) -> Optional[WhisperModel]:
        """The model, loaded on first access.

        Raises:
            RuntimeError: loading failed (see ``_load_model``).
        """
        if self._model is None:
            self._load_model()
        return self._model

    @staticmethod
    def _guess_model_size(model_name: str) -> str:
        """Return the second-to-last "-"-separated token of the model name.

        Falls back to the whole name when there are fewer than two tokens,
        so short names such as "base" do not raise IndexError.
        """
        parts = model_name.split("-")
        return parts[-2] if len(parts) >= 2 else model_name

    def _load_model(self):
        """Instantiate the WhisperModel from the current configuration.

        Raises:
            RuntimeError: model construction failed; the original exception
                is attached as ``__cause__``.
        """
        logger.info(f"🚀 初始化模型: {self.config.model_name}")
        logger.info(f"📦 设备: {self.config.device}, 计算类型: {self.config.compute_type}")

        try:
            self._model = WhisperModel(
                self.config.model_name,
                device=self.config.device,
                compute_type=self.config.compute_type,
                download_root=str(self.config.cache_dir),
                local_files_only=False,
            )
        except Exception as e:
            logger.error(f"❌ 模型加载失败: {e}")
            raise RuntimeError(f"Failed to load ASR model: {e}") from e

        # Built outside the try block so a metadata problem can never be
        # reported as a model-load failure.
        self._device_info = {
            "device": self.config.device,
            "compute_type": self.config.compute_type,
            "model_size": self._guess_model_size(self.config.model_name)
        }

        logger.info("✅ 模型加载成功")

    def reload(self, new_config: ASRConfig):
        """Unload the current model and load one for ``new_config``."""
        logger.info("🔄 热重载模型...")
        self.unload()
        self.config = new_config
        self._load_model()

    def unload(self):
        """Drop the model reference and release memory; no-op when nothing is loaded."""
        if self._model is not None:
            logger.info("🗑️ 卸载模型...")
            self._model = None

            # Collect first so objects freed by dropping the model are gone
            # before the CUDA cache is emptied.
            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("✅ 模型已卸载")

    def get_device_info(self) -> dict:
        """Return the info recorded at the last successful load, or {} if none."""
        return self._device_info or {}

    def __enter__(self):
        """Context-manager entry; returns the manager itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit; unloads the model."""
        self.unload()
|
||||
@@ -0,0 +1,45 @@
|
||||
# fast_whisper/utils.py
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
def check_hardware() -> Dict[str, Any]:
    """Describe the available compute hardware.

    Returns:
        Dict with keys cuda_available / device_name / device_count /
        compute_type. Without CUDA the values are ("CPU", 0, "int8"); with
        CUDA they describe device 0, the device count and "int8_float16".
    """
    import torch

    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        return {
            "cuda_available": has_cuda,
            "device_name": "CPU",
            "device_count": 0,
            "compute_type": "int8",
        }

    return {
        "cuda_available": has_cuda,
        "device_name": torch.cuda.get_device_name(0),
        "device_count": torch.cuda.device_count(),
        "compute_type": "int8_float16",
    }
|
||||
|
||||
class PerformanceProfiler:
    """Collects per-inference timing records and logs a rolling average RTF."""

    def __init__(self, enable: bool = True):
        """Create a profiler; when ``enable`` is False, ``record`` does nothing."""
        self.enable = enable
        self.stats = []

    def record(self, audio_duration: float, inference_time: float):
        """Append one timing record.

        Args:
            audio_duration: length of the audio in seconds.
            inference_time: time spent on inference in seconds.

        The real-time factor is inference_time / audio_duration, or 0 when
        the duration is not positive. After every 10th record the average
        RTF of the latest 10 records is logged.
        """
        if not self.enable:
            return

        if audio_duration > 0:
            rtf = inference_time / audio_duration
        else:
            rtf = 0

        entry = {
            "timestamp": datetime.now().isoformat(),
            "rtf": rtf,
            "audio_duration": audio_duration,
            "inference_time": inference_time,
        }
        self.stats.append(entry)

        # Only report on every 10th record.
        if len(self.stats) % 10 != 0:
            return

        window = self.stats[-10:]
        avg_rtf = sum(item["rtf"] for item in window) / 10
        logger.info(f"📊 最近10次平均RTF: {avg_rtf:.3f}")
|
||||
@@ -0,0 +1,215 @@
|
||||
# asr_module/client/asr_client.py
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Union, Optional
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import requests
|
||||
from loguru import logger
|
||||
from .models import ASRResponse, ASRHealthStatus, ServiceInfo
|
||||
|
||||
class ASRException(Exception):
    """Raised when a call to the ASR service fails.

    Attributes:
        message: human-readable error description (also the exception text).
        status_code: HTTP status code when one is known, otherwise None.
    """

    def __init__(self, message: str, status_code: Optional[int] = None):
        super().__init__(message)
        self.message = message
        self.status_code = status_code
|
||||
|
||||
class ASRClientConfig:
    """Connection settings shared by the sync and async ASR clients."""

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: float = 30.0,
        retry_count: int = 2,
        retry_delay: float = 0.5,
    ):
        """
        Args:
            base_url: service root URL; trailing slashes are removed.
            timeout: request timeout in seconds.
            retry_count: number of retries after the first failed attempt.
            retry_delay: pause between attempts, in seconds.
        """
        # Endpoints are appended as "/path", so the base must not end in "/".
        normalized_url = base_url.rstrip('/')

        self.base_url = normalized_url
        self.timeout = timeout
        self.retry_count = retry_count
        self.retry_delay = retry_delay
|
||||
|
||||
|
||||
# 同步客户端
|
||||
class ASRClientSync:
    """Blocking HTTP client for the ASR service (usable as a context manager)."""

    def __init__(self, config: Optional[ASRClientConfig] = None):
        """Create the client with ``config`` or a default ``ASRClientConfig``."""
        self.config = config or ASRClientConfig()
        # requests has no session-wide timeout setting; the configured
        # timeout is applied per request in _request().
        self.session = requests.Session()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.session.close()

    def _request(self, method: str, endpoint: str, **kwargs) -> dict:
        """Send one request with retries and return the decoded JSON body.

        Args:
            method: HTTP method name.
            endpoint: path appended to ``config.base_url``.
            **kwargs: forwarded to ``Session.request``; ``timeout`` defaults
                to ``config.timeout`` unless given explicitly.

        Raises:
            ASRException: every attempt failed; carries the HTTP status code
                when the failure had a response.
        """
        url = f"{self.config.base_url}{endpoint}"
        kwargs.setdefault("timeout", self.config.timeout)

        # At least one attempt is made even if retry_count is negative, so
        # the method can never fall through and return None.
        attempts = max(self.config.retry_count, 0) + 1

        for attempt in range(attempts):
            try:
                response = self.session.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                if attempt < attempts - 1:
                    logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
                    time.sleep(self.config.retry_delay)
                else:
                    logger.error(f"请求最终失败: {e}")
                    raise ASRException(
                        f"API调用失败: {e}", getattr(e.response, 'status_code', None)
                    ) from e

    def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
        """
        Transcribe an audio file.

        Args:
            file_path: path of the audio file

        Returns:
            ASRResponse object

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ASRException: the request failed after all retries.

        Example:
            client = ASRClientSync()
            result = client.transcribe_file("/path/to/audio.wav")
            print(result.data.text)
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        logger.info(f"📤 上传文件: {file_path.name}")

        # Upload bytes rather than an open file object: a file handle would
        # already be at EOF on a retry and the retry would send an empty body.
        audio_data = file_path.read_bytes()
        return self.transcribe_bytes(audio_data, file_path.name)

    def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
        """
        Transcribe raw audio bytes.

        Args:
            audio_data: encoded audio file content
            filename: file name sent in the multipart upload; the content
                type is always sent as 'audio/wav'

        Returns:
            ASRResponse object

        Raises:
            ASRException: the request failed after all retries.

        Example:
            with open('audio.wav', 'rb') as f:
                audio_bytes = f.read()
            result = client.transcribe_bytes(audio_bytes)
        """
        logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")

        files = {'file': (filename, audio_data, 'audio/wav')}
        result = self._request('POST', '/transcribe', files=files)

        return ASRResponse(**result)

    def health_check(self) -> ASRHealthStatus:
        """Query ``/health`` and return the parsed status."""
        result = self._request('GET', '/health')
        return ASRHealthStatus(**result)

    def get_service_info(self) -> ServiceInfo:
        """Query ``/`` and return the parsed service banner."""
        result = self._request('GET', '/')
        return ServiceInfo(**result)
|
||||
|
||||
|
||||
# 异步客户端
|
||||
class ASRClientAsync:
    """Asynchronous HTTP client for the ASR service.

    Use it with ``async with``, or call ``await client.close()`` when done:
    a session is created on first use even outside a context manager.
    """

    def __init__(self, config: Optional[ASRClientConfig] = None):
        """Create the client with ``config`` or a default ``ASRClientConfig``."""
        self.config = config or ASRClientConfig()
        self._session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        await self._ensure_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self):
        """Close the underlying session, if any; safe to call repeatedly."""
        if self._session is not None:
            if not self._session.closed:
                await self._session.close()
            # Forget the closed session so a later call builds a fresh one.
            self._session = None

    async def _ensure_session(self):
        """Create the session when there is none or the old one is closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
            )

    async def _request(self, method: str, endpoint: str, *, data_factory=None, **kwargs) -> dict:
        """Send one request with retries and return the decoded JSON body.

        Args:
            method: HTTP method name.
            endpoint: path appended to ``config.base_url``.
            data_factory: optional zero-argument callable returning a fresh
                request body; it is called before every attempt and its
                result is passed as ``data``. Use it for bodies that must not
                be reused across attempts.
            **kwargs: forwarded to ``ClientSession.request``.

        Raises:
            ASRException: every attempt failed; carries the HTTP status when
                the error exposes one.
        """
        await self._ensure_session()
        url = f"{self.config.base_url}{endpoint}"

        # At least one attempt is made even if retry_count is negative.
        attempts = max(self.config.retry_count, 0) + 1

        for attempt in range(attempts):
            if data_factory is not None:
                kwargs["data"] = data_factory()
            try:
                async with self._session.request(method, url, **kwargs) as response:
                    response.raise_for_status()
                    return await response.json()
            # Timeouts are caught as well, so they are retried and wrapped
            # like connection errors.
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                if attempt < attempts - 1:
                    logger.warning(f"请求失败,重试中 ({attempt + 1}/{self.config.retry_count}): {e}")
                    await asyncio.sleep(self.config.retry_delay)
                else:
                    logger.error(f"请求最终失败: {e}")
                    raise ASRException(f"API调用失败: {e}", getattr(e, 'status', None)) from e

    async def transcribe_file(self, file_path: Union[str, Path]) -> ASRResponse:
        """Read an audio file asynchronously and transcribe its bytes.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ASRException: the request failed after all retries.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        logger.info(f"📤 上传文件: {file_path.name}")

        async with aiofiles.open(file_path, 'rb') as f:
            audio_data = await f.read()

        return await self.transcribe_bytes(audio_data, file_path.name)

    async def transcribe_bytes(self, audio_data: bytes, filename: str = "audio.wav") -> ASRResponse:
        """Upload audio bytes to ``/transcribe`` and return the parsed response.

        Args:
            audio_data: encoded audio file content.
            filename: file name sent in the multipart upload; the content
                type is always sent as 'audio/wav'.
        """
        logger.info(f"📤 上传字节流 ({len(audio_data)} bytes)")

        def build_form():
            # A new form per attempt: the same FormData is not reused across retries.
            form = aiohttp.FormData()
            form.add_field('file', audio_data, filename=filename, content_type='audio/wav')
            return form

        result = await self._request('POST', '/transcribe', data_factory=build_form)
        return ASRResponse(**result)

    async def health_check(self) -> ASRHealthStatus:
        """Query ``/health`` and return the parsed status."""
        result = await self._request('GET', '/health')
        return ASRHealthStatus(**result)

    async def get_service_info(self) -> ServiceInfo:
        """Query ``/`` and return the parsed service banner."""
        result = await self._request('GET', '/')
        return ServiceInfo(**result)
|
||||
|
||||
# 工厂函数
|
||||
def create_asr_client(use_async: bool = False, **config_kwargs) -> Union[ASRClientSync, ASRClientAsync]:
    """
    Build an ASR client.

    Args:
        use_async: True for ASRClientAsync, False for ASRClientSync
        **config_kwargs: keyword arguments for ASRClientConfig

    Returns:
        A sync or async client bound to the new configuration
    """
    config = ASRClientConfig(**config_kwargs)
    client_cls = ASRClientAsync if use_async else ASRClientSync
    return client_cls(config)
|
||||
@@ -0,0 +1,30 @@
|
||||
# asr_module/client/models.py
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ASRHealthStatus(BaseModel):
    """Health status reported by the ASR service."""
    status: str  # health status string
    timestamp: str  # timestamp of the report
    device: str  # device name
    model_loaded: bool  # whether the model is loaded
|
||||
|
||||
class ASRResult(BaseModel):
    """Result of one speech-recognition request."""
    text: str  # recognised text
    language: str  # language of the audio
    confidence: float  # confidence value
    processing_time: float  # processing time
|
||||
|
||||
class ASRResponse(BaseModel):
    """Uniform API response envelope."""
    success: bool  # whether the request succeeded
    data: Optional[ASRResult] = None  # recognition result, when present
    error: Optional[str] = None  # error description, when present
|
||||
|
||||
class ServiceInfo(BaseModel):
    """Service banner information."""
    message: str  # status message
    docs: str  # documentation entry
    health: str  # health-check entry
    transcribe: str  # transcription entry
|
||||
@@ -0,0 +1,12 @@
|
||||
本模块为asr模块,即语音转文本模块。
|
||||
|
||||
本模块提供了两种访问方式:
|
||||
- 第一种为本地部署的func call方式,即提供函数调用的方式调用相关asr接口。
|
||||
- 第二种为http call方式,即提供http call方式调用相关asr接口。[FastAPI]
|
||||
|
||||
如果你的电脑显卡支持cuda, 并且显存大小大于8G, 那么可以使用第一种方式,
|
||||
否则可以使用第二种,进行云端部署。
|
||||
|
||||
两种调用方式在相同配置下,性能几乎无差别。
|
||||
|
||||
个人建议使用第二种
|
||||
@@ -0,0 +1,65 @@
|
||||
# start_api.py
|
||||
import uvicorn
|
||||
from loguru import logger
|
||||
import threading
|
||||
import time
|
||||
|
||||
def start_server():
    """Run the ASR API with uvicorn (blocks until the server stops)."""
    options = dict(
        host="0.0.0.0",
        port=20260,
        workers=1,  # single-user scenario, one worker is enough
        log_level="info",
        reload=False,  # hot reload is off for production
        access_log=True,
    )
    # "api:app" is "<module name>:<application instance>".
    uvicorn.run("api:app", **options)
|
||||
|
||||
def first_test() -> None:
    """Smoke-test the freshly started server with one transcription request.

    Waits a fixed 5 seconds, uploads the bundled test WAV to ``/transcribe``
    and logs the outcome. Failures are logged, never raised.
    """
    import requests
    from pathlib import Path

    time.sleep(5)  # give the server some time to start
    logger.info("🚀 测试模型是否加载成功...")

    url = "http://localhost:20260/transcribe"
    audio_path = Path("../../../Test/test_files/test.wav")
    try:
        with open(audio_path, "rb") as f:
            # State the file name and MIME type explicitly.
            files = {
                "file": (
                    audio_path.name,  # file name
                    f,  # file object
                    "audio/wav"  # MIME type
                )
            }
            # (connect, read) timeouts: without them a stalled server would
            # block the launcher's main thread forever.
            response = requests.post(url, files=files, timeout=(5, 300))

        logger.info(f"状态码: {response.status_code}")
        logger.info(f"响应头: {response.headers.get('content-type')}")
        if response.status_code == 200:
            result = response.json()
            logger.info(f"识别结果: {result['data']['text']}")
            logger.info(f"识别语言: {result['data']['language']}")
            logger.info(f"置信度: {result['data']['confidence']:.2f}")
            logger.info(f"处理时间: {result['data']['processing_time']}s")
        else:
            logger.error(f"请求失败,错误响应信息: {response.text}")
            logger.error("请检查模型是否正确加载或其他问题")
    except Exception as e:
        logger.error(f"测试过程中发生错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
    logger.info("🚀 启动 ASR API 服务...")

    # Start the server in a background daemon thread.
    server_thread = threading.Thread(target=start_server, daemon=True)
    server_thread.start()

    # Run the smoke test against the running server.
    first_test()

    # Keep the main thread alive for as long as the server thread runs.
    server_thread.join()
|
||||
@@ -0,0 +1,119 @@
|
||||
# ui_tars_/ui_tars_client.py
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
import asyncio
|
||||
from src.modules.text_ai_module.text_ai_core.general_text_ai_req import (
|
||||
UnifiedLLM,
|
||||
ModelConfig,
|
||||
ModelProvider,
|
||||
ChatMessage,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from src.modules.device_control_module.device_control_core.ui_tars_.ui_tars_prompts import UI_TARS_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class UITarsClientConfig(BaseModel):
    """Configuration for the UI-TARS client."""
    deployment_type: str = Field(default="lmstudio", description="部署类型")
    base_url: str = Field(default="http://localhost:1234/v1", description="API地址")
    model_name: str = Field(default="ui-tars", description="模型名称")
    api_key: Optional[str] = Field(default=None, description="API密钥")
    temperature: float = Field(default=0.1, ge=0.0, le=2.0)
    max_tokens: int = Field(default=8192, ge=2048, le=128000)
    timeout: int = Field(default=30, ge=5, le=300)

    # System prompt that pins the model's output format
    # (defaults to this project's own format constraints).
    system_prompt: str = Field(
        default=UI_TARS_SYSTEM_PROMPT
    )

    def to_model_config(self) -> ModelConfig:
        """Translate this configuration into a UnifiedLLM ``ModelConfig``.

        Unknown deployment types fall back to ``ModelProvider.CUSTOM``.
        """
        providers_by_deployment = {
            "lmstudio": ModelProvider.LM_STUDIO,
            "vllm": ModelProvider.CUSTOM,
            "cloud": ModelProvider.OPENAI,
            "ollama": ModelProvider.OLLAMA,
        }
        provider = providers_by_deployment.get(self.deployment_type, ModelProvider.CUSTOM)

        settings = dict(
            provider=provider,
            model_name=self.model_name,
            api_key=self.api_key,
            base_url=self.base_url,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            timeout=self.timeout,
            custom_headers={"User-Agent": "UI-TARS-Client/1.0"},
        )
        return ModelConfig(**settings)
|
||||
|
||||
class UITarsClient:
    """
    General-purpose UI-TARS client built on UnifiedLLM.

    Images are passed in as base64 strings.
    """

    def __init__(self, config: UITarsClientConfig):
        """Store the configuration and build the underlying UnifiedLLM."""
        self.config = config
        # UnifiedLLM is reused for every deployment type.
        self.llm = UnifiedLLM(config.to_model_config())

        logger.info(f"UI-TARS 客户端初始化: {config.deployment_type} @ {config.base_url}")
        logger.info(f"   模型: {config.model_name} | 温度: {config.temperature}")

    async def call_async(self, instruction: str, image_base64: str) -> str:
        """Call UI-TARS without blocking the event loop.

        The blocking ``llm.chat`` call is run in a worker thread. Errors are
        logged and re-raised unchanged.
        """
        messages = self._build_messages(instruction, image_base64)
        try:
            response = await asyncio.to_thread(
                self.llm.chat,
                messages=messages,
                streaming=False,
            )
            content = response.content
        except Exception as e:
            logger.error(f"UI-TARS 调用失败: {e}")
            raise
        return content

    def call_sync(self, instruction: str, image_base64: str) -> str:
        """Call UI-TARS and block until the reply is available.

        Errors are logged and re-raised unchanged.
        """
        messages = self._build_messages(instruction, image_base64)
        try:
            response = self.llm.chat(messages=messages, streaming=False)
            content = response.content
        except Exception as e:
            logger.error(f"UI-TARS 调用失败: {e}")
            raise
        return content

    def stream_async(self, instruction: str, image_base64: str):
        """Start a streaming call and return what ``llm.stream_chat`` returns.

        NOTE(review): the original note describes the result as an async
        generator; that depends on UnifiedLLM.stream_chat — confirm there.
        """
        messages = self._build_messages(instruction, image_base64)
        return self.llm.stream_chat(messages=messages)

    def _build_messages(self, instruction: str, image_base64: str) -> list:
        """Build the system + user messages in OpenAI chat format.

        The user message carries the instruction text and the image as a
        ``data:image/png;base64,...`` URL.
        """
        system_message = ChatMessage(
            role="system",
            content=self.config.system_prompt,
        )

        text_part = {"type": "text", "text": instruction}
        image_part = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{image_base64}"
            },
        }
        user_message = ChatMessage(
            role="user",
            content=[text_part, image_part],
        )

        return [system_message, user_message]
|
||||
@@ -0,0 +1,25 @@
|
||||
OFFICIAL_ACTION_SPACE = """## Action Space
|
||||
click(point='<point>x1 y1</point>') - 单击坐标
|
||||
left_double(point='<point>x1 y1</point>') - 双击坐标
|
||||
right_single(point='<point>x1 y1</point>') - 右键单击
|
||||
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>') - 拖拽
|
||||
hotkey(key='ctrl c') - 快捷键(空格分隔,小写,最多3个键)
|
||||
type(content='xxx') - 输入文本(用\\' \\\" \\n 转义)
|
||||
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') - 滚动
|
||||
wait() - 等待5秒
|
||||
finished() - 任务完成
|
||||
|
||||
## Output Format
|
||||
Thought: [你的推理过程]
|
||||
Action: [选择一个动作]
|
||||
"""
|
||||
|
||||
UI_TARS_SYSTEM_PROMPT = f"""You are UI-TARS-1.5, a GUI agent. Given a task and screenshot, output ONLY:
|
||||
|
||||
{OFFICIAL_ACTION_SPACE}
|
||||
|
||||
## Note
|
||||
- Write a small plan and summarize the next action in one sentence in Thought.
|
||||
- NEVER output multiple actions.
|
||||
- x&y please in box center
|
||||
"""
|
||||
@@ -0,0 +1,10 @@
|
||||
本模块为设备控制模块接口层,此处的设备控制指的是借助AI模型进行一些设备上的自动化
|
||||
操作,支持`pc`, `android`,其他的未做过测试。
|
||||
|
||||
依赖:
|
||||
`ui_tars`
|
||||
|
||||
当前所使用的AI模型为 `mradermacher/UI-TARS-1.5-7B-GGUF`(Q6_K 或 Q4_K_M)。
|
||||
|
||||
未来如果有更快质量更高的AI模型本模块会为其添加支持。
|
||||
@@ -0,0 +1,837 @@
|
||||
"""
|
||||
通用大语言模型调用框架
|
||||
支持本地模型(Ollama, LM Studio, llama.cpp)和云端模型(OpenAI, Anthropic, Google, Azure等)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Union, Any, Iterator
|
||||
import json
|
||||
import os
|
||||
from loguru import logger
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from enum import Enum
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ModelProvider(Enum):
    """Supported model providers; each value is the provider's string identifier."""
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    AZURE = "azure"
    OLLAMA = "ollama"
    LM_STUDIO = "lm_studio"
    LLAMA_CPP = "llama_cpp"
    CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Connection and sampling settings for one model."""
    provider: ModelProvider
    model_name: str
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    api_version: Optional[str] = None
    temperature: float = 0.7
    max_tokens: int = 1024
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    timeout: int = 30
    streaming: bool = False
    # default_factory gives every instance its own headers dict.
    custom_headers: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return all fields as a dict (via ``dataclasses.asdict``)."""
        return asdict(self)
|
||||
|
||||
|
||||
@dataclass
class ChatMessage:
    """One chat message.

    ``content`` is either plain text or a list of content-part dicts
    (e.g. text and image_url parts for multimodal requests).
    """
    role: str  # system, user, assistant
    content: Union[str, List[Dict[str, Any]]]
    name: Optional[str] = None

    def to_dict(self) -> Dict:
        """Return the message as a dict; ``name`` is included only when set."""
        payload = {
            "role": self.role,
            "content": self.content,
        }
        if self.name:
            payload["name"] = self.name
        return payload
|
||||
|
||||
|
||||
@dataclass
class ModelResponse:
    """Response returned by a model call."""
    content: str  # response text
    model: str  # model name
    usage: Optional[Dict[str, int]] = None  # token usage
    finish_reason: Optional[str] = None  # reason generation stopped
    raw_response: Optional[Dict] = None  # raw provider response
|
||||
|
||||
|
||||
def normalize_usage(raw_usage: Optional[Dict[str, Any]], provider: ModelProvider) -> Optional[Dict[str, int]]:
    """
    Normalize a provider-specific ``usage`` payload to the OpenAI layout.

    Args:
        raw_usage: Usage data exactly as returned by the API.
        provider: Provider the payload came from; selects the field names to read.

    Returns:
        ``{"prompt_tokens": int, "completion_tokens": int, "total_tokens": int}``,
        or ``None`` when ``raw_usage`` is empty, the provider has no field
        mapping, or the values cannot be converted.
    """
    if not raw_usage:
        return None

    # Field mapping: {provider: (input_key, output_key, total_key)}
    usage_field_map = {
        ModelProvider.OPENAI: ("prompt_tokens", "completion_tokens", "total_tokens"),
        ModelProvider.AZURE: ("prompt_tokens", "completion_tokens", "total_tokens"),
        ModelProvider.LM_STUDIO: ("prompt_tokens", "completion_tokens", "total_tokens"),
        ModelProvider.LLAMA_CPP: ("prompt_tokens", "completion_tokens", "total_tokens"),
        ModelProvider.OLLAMA: ("prompt_eval_count", "eval_count", None),  # Ollama reports no total
        ModelProvider.ANTHROPIC: ("input_tokens", "output_tokens", None),
        ModelProvider.GOOGLE: ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"),
    }

    field_names = usage_field_map.get(provider)
    if field_names is None:
        logger.warning(f"未知的 provider '{provider}',无法归一化 usage")
        return None
    input_key, output_key, total_key = field_names

    def _to_int(value: Any) -> int:
        """Convert a raw counter to int; unwraps ``{"value": n}`` and maps empty values to 0."""
        if isinstance(value, dict):
            value = value.get("value", 0)
        return int(value) if value else 0

    try:
        prompt_tokens = _to_int(raw_usage.get(input_key, 0))
        completion_tokens = _to_int(raw_usage.get(output_key, 0))

        # Use the reported total only when it is actually present. A missing OR
        # null total falls back to the sum, so one null field no longer throws
        # away otherwise valid prompt/completion counts.
        reported_total = raw_usage.get(total_key) if total_key else None
        if reported_total is None:
            total_tokens = prompt_tokens + completion_tokens
        else:
            total_tokens = _to_int(reported_total)

        normalized = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        }

        logger.debug(f"归一化 usage | {provider} -> OpenAI格式: {normalized}")
        return normalized

    except Exception as e:
        # Best effort: usage statistics must never break the actual response.
        logger.error(f"归一化 usage 失败: {e} | raw_usage: {raw_usage}")
        return None
|
||||
|
||||
class BaseLLMClient(ABC):
    """Abstract base class shared by all LLM client implementations."""

    def __init__(self, config: ModelConfig):
        # Store the config first: _initialize_client() reads it.
        self.config = config
        self.client = None
        self._initialize_client()

    @abstractmethod
    def _initialize_client(self):
        """Create the underlying client object."""

    @abstractmethod
    def chat_completion(
        self,
        messages: List[Union[ChatMessage, Dict]],
        **kwargs
    ) -> Union[ModelResponse, Iterator[ModelResponse]]:
        """Run a chat completion."""

    @abstractmethod
    def completion(
        self,
        prompt: str,
        **kwargs
    ) -> Union[ModelResponse, Iterator[ModelResponse]]:
        """Run a plain text completion."""

    def format_messages(self, messages: List[Union[ChatMessage, Dict]]) -> List[Dict]:
        """Return the messages as dicts; ChatMessage objects are converted, anything else is passed through."""
        return [
            item.to_dict() if isinstance(item, ChatMessage) else item
            for item in messages
        ]
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
    """Client for the OpenAI chat completions API (uses the ``openai`` package)."""

    def _initialize_client(self):
        """Create the OpenAI SDK client.

        Raises:
            ImportError: the ``openai`` package is not installed.
            ValueError: no API key is configured.
        """
        # Only the import is guarded, so the ImportError hint cannot be
        # triggered by anything else.
        try:
            from openai import OpenAI
        except ImportError:
            logger.error("请安装openai包: pip install openai")
            raise

        api_key = self.config.api_key
        if not api_key:
            raise ValueError("OpenAI API密钥未设置")

        self.client = OpenAI(
            api_key=api_key,
            base_url=self.config.base_url,
            timeout=self.config.timeout
        )
        logger.info(f"OpenAI客户端初始化成功,base_url: {self.config.base_url}")

    def chat_completion(self, messages, **kwargs):
        """Run a chat completion.

        Keyword arguments (``temperature``, ``max_tokens``, ``top_p``,
        ``frequency_penalty``, ``presence_penalty``, ``streaming``) override
        the values stored in the config.

        Returns:
            A ModelResponse, or a generator of ModelResponse chunks when
            streaming is enabled.
        """
        formatted_messages = self.format_messages(messages)

        # A per-call "streaming" kwarg wins over the configured default.
        streaming = kwargs.get("streaming", self.config.streaming)

        params = {
            "model": self.config.model_name,
            "messages": formatted_messages,
            "temperature": kwargs.get("temperature", self.config.temperature),
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "top_p": kwargs.get("top_p", self.config.top_p),
            "frequency_penalty": kwargs.get("frequency_penalty", self.config.frequency_penalty),
            "presence_penalty": kwargs.get("presence_penalty", self.config.presence_penalty),
            "stream": streaming,
        }

        logger.info(f"🔧 调用参数: streaming={streaming}")

        if streaming:
            return self._stream_chat_completion(params)
        return self._normal_chat_completion(params)

    def _normal_chat_completion(self, params):
        """Send a non-streaming request and wrap the result in a ModelResponse."""
        logger.info("📡 发送非流式请求...")
        response = self.client.chat.completions.create(**params)

        raw_usage = response.usage
        normalized_usage = normalize_usage(
            raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
            ModelProvider.OPENAI
        )
        first_choice = response.choices[0]
        return ModelResponse(
            content=first_choice.message.content,
            model=response.model,
            usage=normalized_usage,
            finish_reason=first_choice.finish_reason,
            raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
        )

    def _stream_chat_completion(self, params):
        """Send a streaming request and yield one ModelResponse per text delta."""
        logger.info("📡 发送流式请求...")
        response_stream = self.client.chat.completions.create(**params)

        for chunk in response_stream:
            # Some chunks carry no choices at all (e.g. usage-only or
            # filter-result chunks); indexing [0] on them raised IndexError.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece is None:
                continue
            yield ModelResponse(
                content=piece,
                model=chunk.model,
                raw_response=chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk.dict()
            )

    def completion(self, prompt, **kwargs):
        """Text completion, implemented as a single-user-message chat completion."""
        # OpenAI recommends chat completions; this wrapper keeps the interface compatible.
        messages = [{"role": "user", "content": prompt}]
        return self.chat_completion(messages, **kwargs)
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMClient):
    """Client for Anthropic Claude (uses the ``anthropic`` package)."""

    def _initialize_client(self):
        """Create the Anthropic SDK client.

        The API key comes from the config, falling back to the
        ``ANTHROPIC_API_KEY`` environment variable.

        Raises:
            ImportError: the ``anthropic`` package is not installed.
        """
        try:
            from anthropic import Anthropic
        except ImportError:
            logger.error("请安装anthropic包: pip install anthropic")
            raise

        self.client = Anthropic(
            api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
            timeout=self.config.timeout
        )

    def chat_completion(self, messages, **kwargs):
        """Run a chat completion.

        ``system`` messages are lifted out of the message list and sent through
        the separate ``system`` parameter (if several are present, the last one
        is used). Keyword arguments ``max_tokens``, ``temperature``, ``top_p``
        and ``streaming`` override the config.

        Returns:
            A ModelResponse, or a generator of ModelResponse chunks when
            streaming is enabled.
        """
        formatted_messages = self.format_messages(messages)

        # Convert to Claude's message format.
        claude_messages = []
        system_message = None
        for msg in formatted_messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                claude_messages.append({
                    "role": msg["role"],
                    "content": msg["content"]
                })

        # The per-call kwarg must drive the branch below as well. Previously
        # the branch looked only at the config, so a "streaming" override sent
        # the request down the wrong path.
        streaming = kwargs.get("streaming", self.config.streaming)

        # No "stream" entry here: messages.create() is non-streaming by
        # default, and messages.stream() does not accept that argument.
        params = {
            "model": self.config.model_name,
            "messages": claude_messages,
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "temperature": kwargs.get("temperature", self.config.temperature),
            "top_p": kwargs.get("top_p", self.config.top_p),
        }

        if system_message:
            params["system"] = system_message

        if streaming:
            return self._stream_chat_completion(params)
        return self._normal_chat_completion(params)

    def _normal_chat_completion(self, params):
        """Send a non-streaming request and wrap the result in a ModelResponse."""
        response = self.client.messages.create(**params)

        # Anthropic's usage layout differs from OpenAI's, so it is normalized.
        raw_usage = response.usage
        normalized_usage = normalize_usage(
            raw_usage.model_dump() if hasattr(raw_usage, 'model_dump') else raw_usage,
            ModelProvider.ANTHROPIC
        )

        # Take the first content block that carries text; empty string if none does.
        text = next(
            (block.text for block in response.content if getattr(block, "text", None) is not None),
            ""
        )
        return ModelResponse(
            content=text,
            model=response.model,
            usage=normalized_usage,
            finish_reason=response.stop_reason,
            raw_response=response.model_dump() if hasattr(response, 'model_dump') else response.dict()
        )

    def _stream_chat_completion(self, params):
        """Stream a reply and yield one ModelResponse per text delta."""
        # Defensive: drop a "stream" key if a caller passed one in.
        stream_params = {key: value for key, value in params.items() if key != "stream"}

        with self.client.messages.stream(**stream_params) as stream:
            for event in stream:
                # The event discriminator is `type` (there is no `type_`).
                if getattr(event, "type", None) != "content_block_delta":
                    continue
                # Only text deltas have `.text`; other delta kinds are skipped.
                text = getattr(event.delta, "text", None)
                if text is None:
                    continue
                yield ModelResponse(
                    content=text,
                    model=params["model"],
                    raw_response=event.model_dump() if hasattr(event, 'model_dump') else event.dict()
                )

    def completion(self, prompt, **kwargs):
        """Text completion, implemented as a single-user-message chat completion."""
        messages = [{"role": "user", "content": prompt}]
        return self.chat_completion(messages, **kwargs)
|
||||
|
||||
|
||||
class OllamaClient(BaseLLMClient):
    """Client for a local Ollama server, using its HTTP API through httpx."""

    def _initialize_client(self):
        """Create the httpx client; ``http://localhost:11434`` is used when no base_url is configured."""
        import httpx
        self.base_url = self.config.base_url or "http://localhost:11434"
        self.client = httpx.Client(
            base_url=self.base_url,
            timeout=self.config.timeout
        )

    def _build_options(self, kwargs):
        """Build the Ollama ``options`` object; per-call kwargs override the config."""
        return {
            "temperature": kwargs.get("temperature", self.config.temperature),
            "top_p": kwargs.get("top_p", self.config.top_p),
            "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
        }

    def _iter_json_lines(self, path, payload):
        """POST ``payload`` to ``path`` and yield each parseable JSON line of the streamed body.

        Raises:
            httpx.HTTPStatusError: the server answered with an error status.
        """
        with self.client.stream("POST", path, json=payload) as response:
            # Without this check an error response simply yielded nothing.
            response.raise_for_status()
            for line in response.iter_lines():
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip malformed lines, keep the stream alive
                yield data

    def chat_completion(self, messages, **kwargs):
        """Run a chat completion against ``/api/chat``.

        Returns:
            A ModelResponse, or a generator of ModelResponse chunks when
            streaming is enabled (config default or ``streaming=`` kwarg).
        """
        # One flag drives both the payload and the branch. Before, the payload
        # honoured the kwarg while the branch looked only at the config.
        streaming = kwargs.get("streaming", self.config.streaming)

        payload = {
            "model": self.config.model_name,
            "messages": self.format_messages(messages),
            "options": self._build_options(kwargs),
            "stream": streaming,
        }

        if streaming:
            return self._stream_chat_completion(payload)
        return self._normal_chat_completion(payload)

    def _normal_chat_completion(self, payload):
        """Send a non-streaming chat request and wrap the result in a ModelResponse."""
        response = self.client.post("/api/chat", json=payload)
        response.raise_for_status()
        data = response.json()
        return ModelResponse(
            content=data["message"]["content"],
            model=data["model"],
            usage=normalize_usage(data, ModelProvider.OLLAMA),
            finish_reason=data.get("done_reason"),
            raw_response=data
        )

    def _stream_chat_completion(self, payload):
        """Stream a chat reply; yields one ModelResponse per line that carries message content."""
        for data in self._iter_json_lines("/api/chat", payload):
            if "message" in data and "content" in data["message"]:
                yield ModelResponse(
                    content=data["message"]["content"],
                    model=data.get("model", self.config.model_name),
                    raw_response=data
                )

    def completion(self, prompt, **kwargs):
        """Run a text completion against ``/api/generate``.

        Returns:
            A ModelResponse, or a generator of ModelResponse chunks when
            streaming is enabled (config default or ``streaming=`` kwarg).
        """
        streaming = kwargs.get("streaming", self.config.streaming)

        payload = {
            "model": self.config.model_name,
            "prompt": prompt,
            "options": self._build_options(kwargs),
            "stream": streaming,
        }

        if streaming:
            return self._stream_completion(payload)
        return self._normal_completion(payload)

    def _normal_completion(self, payload):
        """Send a non-streaming generate request and wrap the result in a ModelResponse."""
        response = self.client.post("/api/generate", json=payload)
        response.raise_for_status()
        data = response.json()

        return ModelResponse(
            content=data["response"],
            model=data["model"],
            # Same normalization as the chat path (also tolerates null counters).
            usage=normalize_usage(data, ModelProvider.OLLAMA),
            finish_reason=data.get("done_reason"),
            raw_response=data
        )

    def _stream_completion(self, payload):
        """Stream a generate reply; yields one ModelResponse per line that has a ``response`` field."""
        for data in self._iter_json_lines("/api/generate", payload):
            if "response" in data:
                yield ModelResponse(
                    content=data["response"],
                    model=data.get("model", self.config.model_name),
                    raw_response=data
                )
|
||||
|
||||
|
||||
class GenericLLMClient(BaseLLMClient):
    """Generic HTTP client for LM Studio and other local servers exposing an OpenAI-compatible API."""

    def _initialize_client(self):
        """Create the httpx client.

        ``http://localhost:1234/v1`` is used when no base_url is configured.
        Custom headers are sent with every request. When an API key is
        configured and no Authorization header was supplied, it is sent as a
        bearer token (previously the key was silently ignored).
        """
        import httpx
        self.base_url = self.config.base_url or "http://localhost:1234/v1"

        headers = dict(self.config.custom_headers or {})
        has_auth_header = any(name.lower() == "authorization" for name in headers)
        if self.config.api_key and not has_auth_header:
            headers["Authorization"] = f"Bearer {self.config.api_key}"

        self.client = httpx.Client(
            base_url=self.base_url,
            timeout=self.config.timeout,
            headers=headers
        )

    def _iter_sse_json(self, path, payload):
        """POST ``payload`` to ``path`` and yield the JSON object of every ``data:`` line.

        Stops at the ``[DONE]`` sentinel and skips lines that are not valid JSON.

        Raises:
            httpx.HTTPStatusError: the server answered with an error status.
        """
        with self.client.stream("POST", path, json=payload) as response:
            # Without this check an error response simply yielded nothing.
            response.raise_for_status()
            for line in response.iter_lines():
                # Accept both "data: {...}" and "data:{...}".
                if not line.startswith("data:"):
                    continue
                chunk = line[len("data:"):].strip()
                if chunk == "[DONE]":
                    break
                try:
                    data = json.loads(chunk)
                except json.JSONDecodeError:
                    continue
                yield data

    def chat_completion(self, messages, **kwargs):
        """Run a chat completion against ``/chat/completions``.

        Returns:
            A ModelResponse, or a generator of ModelResponse chunks when
            streaming is enabled (config default or ``streaming=`` kwarg).
        """
        formatted_messages = self.format_messages(messages)

        streaming = kwargs.get("streaming", self.config.streaming)

        payload = {
            "model": self.config.model_name,
            "messages": formatted_messages,
            "temperature": kwargs.get("temperature", self.config.temperature),
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": streaming,
        }

        logger.info(f"GenericLLMClient 参数: streaming={streaming}")

        if streaming:
            return self._stream_chat_completion(payload)
        return self._normal_chat_completion(payload)

    def _normal_chat_completion(self, payload):
        """Send a non-streaming chat request and wrap the result in a ModelResponse."""
        logger.info(f"GenericLLMClient 发送非流式请求到: {self.base_url}")
        response = self.client.post("/chat/completions", json=payload)
        response.raise_for_status()
        data = response.json()

        logger.info(f"GenericLLMClient 收到响应,模型: {data.get('model')}")

        first_choice = data["choices"][0]
        return ModelResponse(
            content=first_choice["message"]["content"],
            model=data["model"],
            usage=data.get("usage"),
            finish_reason=first_choice.get("finish_reason"),
            raw_response=data
        )

    def _stream_chat_completion(self, payload):
        """Stream a chat reply; yields one ModelResponse per non-empty content delta."""
        logger.info(f"GenericLLMClient 发送流式请求到: {self.base_url}")
        for data in self._iter_sse_json("/chat/completions", payload):
            # Chunks without choices (keep-alive / usage-only) are skipped
            # instead of raising IndexError or KeyError.
            choices = data.get("choices") or []
            if not choices:
                continue
            piece = (choices[0].get("delta") or {}).get("content")
            if piece:
                yield ModelResponse(
                    content=piece,
                    model=data.get("model", self.config.model_name),
                    raw_response=data
                )

    def completion(self, prompt, **kwargs):
        """Run a text completion against ``/completions``.

        Returns:
            A ModelResponse, or a generator of ModelResponse chunks when
            streaming is enabled (config default or ``streaming=`` kwarg).
        """
        streaming = kwargs.get("streaming", self.config.streaming)

        payload = {
            "model": self.config.model_name,
            "prompt": prompt,
            "temperature": kwargs.get("temperature", self.config.temperature),
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": streaming,
        }

        if streaming:
            return self._stream_completion(payload)
        return self._normal_completion(payload)

    def _normal_completion(self, payload):
        """Send a non-streaming completion request and wrap the result in a ModelResponse."""
        response = self.client.post("/completions", json=payload)
        response.raise_for_status()
        data = response.json()

        first_choice = data["choices"][0]
        return ModelResponse(
            content=first_choice["text"],
            model=data["model"],
            usage=data.get("usage"),
            finish_reason=first_choice.get("finish_reason"),
            raw_response=data
        )

    def _stream_completion(self, payload):
        """Stream a text completion; yields one ModelResponse per non-empty text piece."""
        for data in self._iter_sse_json("/completions", payload):
            choices = data.get("choices") or []
            if not choices:
                continue
            piece = choices[0].get("text")
            if piece:
                yield ModelResponse(
                    content=piece,
                    model=data.get("model", self.config.model_name),
                    raw_response=data
                )
|
||||
|
||||
class UnifiedLLM:
    """
    Single entry point for calling large language models.

    Providers named by this module:
    - Cloud models: OpenAI, Anthropic, Google, Azure
    - Local models: Ollama, LM Studio, llama.cpp

    Example:
    ```python
    # Create an OpenAI-backed instance
    config = ModelConfig(
        provider=ModelProvider.OPENAI,
        model_name="gpt-4",
        api_key="your-api-key"
    )
    llm = UnifiedLLM(config)

    # Chat completion
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"}
    ]
    response = llm.chat(messages)
    print(response.content)

    # Streaming
    config.streaming = True
    llm.update_config(config)
    for chunk in llm.chat(messages):
        print(chunk.content, end="", flush=True)
    ```
    """

    def __init__(self, config: ModelConfig):
        """
        Build the provider-specific client for ``config``.

        Args:
            config: Model configuration.
        """
        self.config = config
        self.client = self._create_client()
        logger.info(" UnifiedLLM 初始化完成")
        logger.info(f" 提供商: {config.provider}")
        logger.info(f" 模型: {config.model_name}")
        logger.info(f" 流式默认: {config.streaming}")

    def _create_client(self) -> BaseLLMClient:
        """Instantiate the client class that handles the configured provider.

        Raises:
            ValueError: the provider is not one of the routed ModelProvider members.
        """
        provider = self.config.provider

        # (providers, client class) pairs, checked in order.
        # Google and Azure have no dedicated client yet and go through the
        # generic OpenAI-compatible HTTP client.
        routing = (
            ((ModelProvider.OPENAI,), OpenAIClient),
            ((ModelProvider.ANTHROPIC,), AnthropicClient),
            ((ModelProvider.OLLAMA,), OllamaClient),
            (
                (
                    ModelProvider.LM_STUDIO,
                    ModelProvider.LLAMA_CPP,
                    ModelProvider.CUSTOM,
                    ModelProvider.GOOGLE,
                    ModelProvider.AZURE,
                ),
                GenericLLMClient,
            ),
        )
        for handled, client_cls in routing:
            if provider in handled:
                return client_cls(self.config)

        raise ValueError(f"不支持的模型提供商: {provider}")

    def update_config(self, config: ModelConfig):
        """Replace the configuration and rebuild the client from it."""
        self.config = config
        self.client = self._create_client()
        logger.info("UnifiedLLM 配置已更新")

    def chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
        """
        Chat completion.

        Args:
            messages: Message list.
            **kwargs: Extra parameters; they override the values in the config.

        Returns:
            A ModelResponse, or an iterator of ModelResponse in streaming mode.
        """
        streaming = kwargs.get("streaming", self.config.streaming)
        logger.info(" UnifiedLLM.chat() 调用")
        logger.info(f" 消息数: {len(messages)}")
        logger.info(f" streaming参数: {streaming}")

        result = self.client.chat_completion(messages, **kwargs)

        # Debug-only sanity check: the result shape should match the streaming flag.
        iterable_result = hasattr(result, '__iter__')
        if streaming:
            if not iterable_result:
                logger.warning("警告: streaming=True 但返回的不是迭代器")
        elif iterable_result:
            logger.warning("警告: streaming=False 但返回的是迭代器")
        elif not isinstance(result, ModelResponse):
            logger.warning("警告: streaming=False 但返回的不是ModelResponse")

        return result

    def complete(self, prompt: str, **kwargs) -> Union[ModelResponse, Iterator[ModelResponse]]:
        """
        Text completion.

        Args:
            prompt: Prompt text.
            **kwargs: Extra parameters.

        Returns:
            A ModelResponse, or an iterator of ModelResponse in streaming mode.
        """
        streaming = kwargs.get("streaming", self.config.streaming)
        logger.info(" UnifiedLLM.complete() 调用")
        logger.info(f" prompt长度: {len(prompt)}")
        logger.info(f" streaming参数: {streaming}")

        return self.client.completion(prompt, **kwargs)

    def stream_chat(self, messages: List[Union[ChatMessage, Dict]], **kwargs) -> Iterator[ModelResponse]:
        """
        Streaming chat completion (convenience wrapper that forces ``streaming=True``).

        Args:
            messages: Message list.
            **kwargs: Extra parameters.

        Returns:
            An iterator of ModelResponse.

        Raises:
            TypeError: the client did not return an iterable.
        """
        kwargs["streaming"] = True
        logger.info("UnifiedLLM.stream_chat() 调用")

        chunks = self.chat(messages, **kwargs)
        if hasattr(chunks, '__iter__'):
            return chunks
        raise TypeError("stream_chat 应该返回迭代器,但返回了其他类型")

    def stream_complete(self, prompt: str, **kwargs) -> Iterator[ModelResponse]:
        """
        Streaming text completion (convenience wrapper that forces ``streaming=True``).

        Args:
            prompt: Prompt text.
            **kwargs: Extra parameters.

        Returns:
            An iterator of ModelResponse.

        Raises:
            TypeError: the client did not return an iterable.
        """
        kwargs["streaming"] = True
        logger.info("UnifiedLLM.stream_complete() 调用")

        chunks = self.complete(prompt, **kwargs)
        if hasattr(chunks, '__iter__'):
            return chunks
        raise TypeError("stream_complete 应该返回迭代器,但返回了其他类型")
|
||||
|
||||
|
||||
# 快捷函数
|
||||
def create_llm_client(
    provider: Union[str, ModelProvider],
    model_name: str,
    **kwargs
) -> UnifiedLLM:
    """
    Shortcut for building a UnifiedLLM.

    Args:
        provider: Provider name (matched case-insensitively against the
            ModelProvider values) or a ModelProvider member.
        model_name: Model name.
        **kwargs: Remaining ModelConfig fields.

    Returns:
        A UnifiedLLM instance.

    Raises:
        ValueError: ``provider`` is a string that matches no ModelProvider value.
    """
    resolved = ModelProvider(provider.lower()) if isinstance(provider, str) else provider
    return UnifiedLLM(
        ModelConfig(provider=resolved, model_name=model_name, **kwargs)
    )
|
||||
|
||||
|
||||
# 使用示例
|
||||
def example_usage():
    """Demonstrate the module: OpenAI chat, streamed Ollama chat, and the shortcut factory."""

    # Example 1: OpenAI
    print("示例1: 使用OpenAI")
    openai_config = ModelConfig(
        provider=ModelProvider.OPENAI,
        model_name="gpt-3.5-turbo",
        api_key=os.getenv("OPENAI_API_KEY"),
        temperature=0.8
    )
    openai_messages = [
        ChatMessage(role="system", content="你是一个有用的助手"),
        ChatMessage(role="user", content="请用Python写一个Hello World程序")
    ]
    try:
        reply = UnifiedLLM(openai_config).chat(openai_messages)
        print(f"响应: {reply.content[:100]}...")
    except Exception as e:
        print(f"OpenAI示例错误: {e}")

    # Example 2: Ollama (local model), streamed
    print("\n示例2: 使用Ollama(本地模型)")
    ollama_config = ModelConfig(
        provider=ModelProvider.OLLAMA,
        model_name="llama2",
        base_url="http://localhost:11434",
        temperature=0.7,
        streaming=True  # streamed response
    )
    ollama_messages = [{"role": "user", "content": "什么是人工智能?"}]
    try:
        ollama_llm = UnifiedLLM(ollama_config)

        print("流式响应:")
        for piece in ollama_llm.stream_chat(ollama_messages):
            print(piece.content, end="", flush=True)
        print()
    except Exception as e:
        print(f"Ollama示例错误: {e}(请确保Ollama服务正在运行)")

    # Example 3: shortcut factory
    print("\n示例3: 使用快捷函数")
    try:
        quick_llm = create_llm_client(
            provider="openai",
            model_name="gpt-3.5-turbo",
            api_key=os.getenv("OPENAI_API_KEY"),
            temperature=0.5
        )
        answer = quick_llm.complete("天空为什么是蓝色的?")
        print(f"响应: {answer.content[:100]}...")
    except Exception as e:
        print(f"快捷函数示例错误: {e}")
|
||||
@@ -0,0 +1,162 @@
|
||||
# 大语言模型调用框架架构图
|
||||
对应源文件:`general_text_ai_req.py`
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "应用层"
|
||||
A[用户应用] --> B[UnifiedLLM 统一接口]
|
||||
end
|
||||
|
||||
subgraph "适配器层"
|
||||
B --> C{模型提供商路由}
|
||||
C --> D[OpenAI 适配器]
|
||||
C --> E[Anthropic 适配器]
|
||||
C --> F[Ollama 适配器]
|
||||
C --> G[通用HTTP适配器]
|
||||
C --> H[其他适配器]
|
||||
end
|
||||
|
||||
subgraph "服务层"
|
||||
D --> I[OpenAI API]
|
||||
E --> J[Anthropic API]
|
||||
F --> K[Ollama 服务]
|
||||
G --> L[LM Studio]
|
||||
G --> M[llama.cpp]
|
||||
G --> N[其他兼容API]
|
||||
end
|
||||
|
||||
subgraph "配置层"
|
||||
O[ModelConfig] --> C
|
||||
O --> D
|
||||
O --> E
|
||||
O --> F
|
||||
O --> G
|
||||
end
|
||||
|
||||
subgraph "数据流"
|
||||
P[输入: 消息/提示] --> B
|
||||
I --> Q[输出: ModelResponse]
|
||||
J --> Q
|
||||
K --> Q
|
||||
L --> Q
|
||||
M --> Q
|
||||
N --> Q
|
||||
end
|
||||
|
||||
style A fill:#4567f1
|
||||
style B fill:#4567f1
|
||||
style O fill:#456748
|
||||
style D fill:#457911
|
||||
style E fill:#466bd5
|
||||
style F fill:#4567f1
|
||||
style G fill:#4567f1
|
||||
```
|
||||
|
||||
# 数据流图
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User as 用户/应用
|
||||
participant UnifiedLLM as UnifiedLLM
|
||||
participant Adapter as 适配器
|
||||
participant API as API服务
|
||||
|
||||
User->>UnifiedLLM: 调用chat()或complete()
|
||||
UnifiedLLM->>UnifiedLLM: 根据配置选择适配器
|
||||
UnifiedLLM->>Adapter: 转发请求
|
||||
Adapter->>API: 发送HTTP请求/API调用
|
||||
Note over API: 处理请求并生成响应
|
||||
|
||||
alt 流式模式
|
||||
API-->>Adapter: 流式响应数据
|
||||
Adapter-->>UnifiedLLM: 流式ModelResponse
|
||||
UnifiedLLM-->>User: 迭代器返回分块响应
|
||||
else 非流式模式
|
||||
API-->>Adapter: 完整响应
|
||||
Adapter-->>UnifiedLLM: ModelResponse对象
|
||||
UnifiedLLM-->>User: 完整响应内容
|
||||
end
|
||||
```
|
||||
# 类关系图
|
||||
```mermaid
|
||||
classDiagram
|
||||
class ModelConfig {
|
||||
+provider: ModelProvider
|
||||
+model_name: str
|
||||
+api_key: Optional[str]
|
||||
+base_url: Optional[str]
|
||||
+temperature: float
|
||||
+max_tokens: int
|
||||
+to_dict() Dict
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
+role: str
|
||||
+content: str
|
||||
+name: Optional[str]
|
||||
+to_dict() Dict
|
||||
}
|
||||
|
||||
class ModelResponse {
|
||||
+content: str
|
||||
+model: str
|
||||
+usage: Optional[Dict]
|
||||
+finish_reason: Optional[str]
|
||||
+raw_response: Optional[Dict]
|
||||
}
|
||||
|
||||
class BaseLLMClient {
|
||||
<<abstract>>
|
||||
#config: ModelConfig
|
||||
#client: Any
|
||||
+__init__(config: ModelConfig)
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)*
|
||||
+completion(prompt, **kwargs)*
|
||||
+format_messages(messages) List[Dict]
|
||||
}
|
||||
|
||||
class UnifiedLLM {
|
||||
-config: ModelConfig
|
||||
-client: BaseLLMClient
|
||||
+__init__(config: ModelConfig)
|
||||
+_create_client() BaseLLMClient
|
||||
+update_config(config: ModelConfig)
|
||||
+chat(messages, **kwargs) ModelResponse
|
||||
+complete(prompt, **kwargs) ModelResponse
|
||||
+stream_chat(messages, **kwargs) Iterator
|
||||
+stream_complete(prompt, **kwargs) Iterator
|
||||
}
|
||||
|
||||
class OpenAIClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
class AnthropicClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
class OllamaClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
class GenericLLMClient {
|
||||
+_initialize_client()
|
||||
+chat_completion(messages, **kwargs)
|
||||
+completion(prompt, **kwargs)
|
||||
}
|
||||
|
||||
BaseLLMClient <|-- OpenAIClient
|
||||
BaseLLMClient <|-- AnthropicClient
|
||||
BaseLLMClient <|-- OllamaClient
|
||||
BaseLLMClient <|-- GenericLLMClient
|
||||
UnifiedLLM o-- BaseLLMClient
|
||||
UnifiedLLM --> ModelConfig
|
||||
BaseLLMClient --> ModelConfig
|
||||
BaseLLMClient --> ChatMessage
|
||||
BaseLLMClient --> ModelResponse
|
||||
```
|
||||
@@ -0,0 +1,26 @@
|
||||
本模块为tts模块,即文本转语音模块。
|
||||
|
||||
本模块负责将来自AI的回复转为语音。
|
||||
|
||||
说明:
|
||||
在本module当中,每个子模块的用途分别是:
|
||||
- tts_core 对不同的tts的实现,提供相对统一的接口
|
||||
- gpt_sovits
|
||||
实现了gpt_sovits的tts接口封装
|
||||
|
||||
|
||||
`async_audio_player.py` 的数据流如下:
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant TTS as GPT-SoVITS API
|
||||
participant WS as WebSocket服务
|
||||
participant Buffer as 音频缓冲区
|
||||
participant Player as 音频播放器
|
||||
|
||||
TTS->>WS: 流式音频块(chunks)
|
||||
WS->>Buffer: 写入队列(Queue)
|
||||
Buffer->>Player: 消费PCM数据
|
||||
Player->>声卡: 实时播放
|
||||
|
||||
Note over TTS,Player: 三重缓冲 + 动态采样率检测
|
||||
```
|
||||
@@ -0,0 +1,164 @@
|
||||
# tts_core/async_audio_player.py
|
||||
import asyncio
|
||||
import io
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import wave
|
||||
|
||||
class AsyncAudioPlayer:
|
||||
"""
|
||||
异步流式音频播放器
|
||||
- 自动检测WAV头并解析采样率
|
||||
- 使用环形缓冲区确保播放流畅
|
||||
- 支持动态音频格式切换
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int = 10):
|
||||
"""
|
||||
Args:
|
||||
buffer_size: 音频块缓冲数量(越大越稳定,但延迟越高)
|
||||
"""
|
||||
self.audio_queue = asyncio.Queue(maxsize=buffer_size)
|
||||
self.sample_rate = 32000 # 默认采样率
|
||||
self.channels = 1
|
||||
self.dtype = np.float32
|
||||
self.stream: Optional[sd.OutputStream] = None
|
||||
self.is_playing = False
|
||||
self._first_chunk_processed = False
|
||||
logger.info(f"🎵 音频播放器初始化,缓冲区大小: {buffer_size}")
|
||||
|
||||
async def add_chunk(self, audio_data: bytes):
|
||||
"""
|
||||
添加音频块到播放队列
|
||||
自动处理第一个chunk(包含WAV头)
|
||||
"""
|
||||
try:
|
||||
# 第一个chunk需要解析WAV头
|
||||
if not self._first_chunk_processed:
|
||||
# 写入BytesIO以便wave模块读取
|
||||
wav_buffer = io.BytesIO(audio_data)
|
||||
try:
|
||||
with wave.open(wav_buffer, 'rb') as wav_file:
|
||||
# 解析WAV头信息
|
||||
self.sample_rate = wav_file.getframerate()
|
||||
self.channels = wav_file.getnchannels()
|
||||
self.sampwidth = wav_file.getsampwidth()
|
||||
|
||||
# 读取PCM数据(去掉头部)
|
||||
pcm_data = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
logger.info(f"📊 解析WAV头: {self.sample_rate}Hz, {self.channels}ch, {self.sampwidth * 8}bit")
|
||||
|
||||
# 转换为numpy数组
|
||||
if self.sampwidth == 2:
|
||||
audio_array = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif self.sampwidth == 4:
|
||||
audio_array = np.frombuffer(pcm_data, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
else:
|
||||
raise ValueError(f"不支持的采样宽度: {self.sampwidth}")
|
||||
|
||||
# 转单声道(如果多声道)
|
||||
if self.channels > 1:
|
||||
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
|
||||
|
||||
await self.audio_queue.put(audio_array)
|
||||
self._first_chunk_processed = True
|
||||
|
||||
except wave.Error:
|
||||
# 可能是不完整的WAV头,尝试直接播放
|
||||
logger.warning("⚠️ WAV头解析失败,尝试直接播放")
|
||||
await self._play_raw(audio_data)
|
||||
return
|
||||
else:
|
||||
# 后续chunk直接播放(RAW PCM)
|
||||
await self._play_raw(audio_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 音频块处理失败: {e}")
|
||||
|
||||
async def _play_raw(self, audio_data: bytes):
|
||||
"""播放RAW PCM数据"""
|
||||
try:
|
||||
# 假设是16位PCM(最常见)
|
||||
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
# 如果是多声道数据(罕见)
|
||||
if len(audio_array) % self.channels == 0 and self.channels > 1:
|
||||
audio_array = audio_array.reshape(-1, self.channels).mean(axis=1)
|
||||
|
||||
await self.audio_queue.put(audio_array)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ RAW音频处理失败: {e}")
|
||||
|
||||
async def play_worker(self):
    """Background task that drains ``self.audio_queue`` into the output stream.

    Runs while ``self.is_playing`` is true or chunks remain queued. The
    output stream is opened lazily when the first chunk arrives, using
    ``self.sample_rate`` and ``self.dtype``. Any playback error is logged
    and ends the task.
    """
    logger.info("🎧 音频播放任务启动")

    loop = asyncio.get_running_loop()

    while self.is_playing or not self.audio_queue.empty():
        try:
            # Wait at most 0.5 s so the loop condition is re-checked regularly.
            audio_chunk = await asyncio.wait_for(self.audio_queue.get(), timeout=0.5)
        except asyncio.TimeoutError:
            continue

        try:
            # Open the device only once real data is available.
            if self.stream is None:
                logger.info(f"🔊 打开音频输出流: {self.sample_rate}Hz")
                self.stream = sd.OutputStream(
                    samplerate=self.sample_rate,
                    channels=1,
                    dtype=self.dtype,
                    blocksize=1024,  # low-latency mode
                    latency='low'
                )
                self.stream.start()

            # OutputStream.write() blocks until the device has taken the data;
            # run it in a worker thread so the event loop (and therefore the
            # code feeding the queue) keeps running meanwhile.
            await loop.run_in_executor(None, self.stream.write, audio_chunk)
        except Exception as e:
            logger.error(f"❌ 播放任务异常: {e}")
            break
        finally:
            # Always balance the successful get(), even when playback failed.
            self.audio_queue.task_done()

    logger.info("🛑 音频播放任务结束")
|
||||
|
||||
async def start(self):
    """Arm the player and launch the background playback task.

    Clears the first-chunk flag, marks the player as playing, and schedules
    ``play_worker`` as an asyncio task stored in ``self.play_task``.
    """
    # A fresh session has not processed its first chunk yet.
    self._first_chunk_processed = False
    self.is_playing = True

    worker = self.play_worker()
    self.play_task = asyncio.create_task(worker)
|
||||
|
||||
async def stop(self):
    """Stop playback and release resources.

    Clears ``is_playing``, waits for the playback task (if one was started)
    to finish, stops and closes the output stream, and discards anything
    still left in the queue.
    """
    self.is_playing = False

    # Wait for the worker to finish; it exits once the flag is cleared
    # and the queue is empty.
    play_task = getattr(self, 'play_task', None)
    if play_task is not None:
        await play_task

    # Close the output stream.
    if self.stream is not None:
        self.stream.stop()
        self.stream.close()
        self.stream = None

    # Discard leftover chunks. Only QueueEmpty ends the drain, so unrelated
    # errors are no longer hidden by a bare ``except``.
    while True:
        try:
            self.audio_queue.get_nowait()
        except asyncio.QueueEmpty:
            break

    logger.info("✅ 音频播放已停止")
|
||||
|
||||
async def __aenter__(self):
    """Start playback on entering ``async with`` and return the player itself."""
    await self.start()
    player = self
    return player
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Stop playback on leaving ``async with``; exceptions are not suppressed."""
    # The exception details are ignored; cleanup is the same either way.
    await self.stop()
|
||||
@@ -0,0 +1,378 @@
|
||||
# gpt_sovits/gpt_sovits_client.py
|
||||
import asyncio
|
||||
import json
|
||||
from loguru import logger
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Optional, Union, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class APIError(Exception):
    """Raised when an API call fails.

    Attributes:
        status_code: HTTP status code of the failed response.
        message: Error text supplied by the caller.
    """

    def __init__(self, status_code: int, message: str):
        summary = f"API Error {status_code}: {message}"
        super().__init__(summary)
        self.status_code = status_code
        self.message = message
|
||||
|
||||
|
||||
class StreamingMode(Enum):
    """Streaming modes selectable for a TTS request."""

    # Non-streaming.
    DISABLED = 0
    # Best quality (slow).
    BEST_QUALITY = 1
    # Medium quality.
    MEDIUM_QUALITY = 2
    # Fastest response (lower quality).
    FASTEST = 3
|
||||
|
||||
|
||||
class TTSConfig(BaseModel):
    """Validated parameter set for a single TTS request."""

    # Required parameters
    text: str = Field(..., description="待合成文本")
    text_lang: str = Field(..., description="文本语言: zh/en/ja/ko/cantonese")
    ref_audio_path: str = Field(..., description="参考音频路径")
    prompt_lang: str = Field(..., description="提示文本语言")

    # Optional parameters
    prompt_text: str = Field(default="", description="参考音频提示文本")
    aux_ref_audio_paths: list = Field(default_factory=list, description="辅助参考音频")
    top_k: int = Field(default=5, ge=1, le=100, description="Top-K采样")
    top_p: float = Field(default=1.0, ge=0.1, le=1.0, description="Top-P采样")
    temperature: float = Field(default=1.0, ge=0.1, le=1.0, description="采样温度")
    # "cut5" splits the text on punctuation.
    text_split_method: str = Field(default="cut5", description="文本分割方法")
    batch_size: int = Field(default=8, ge=1, le=200, description="批处理大小")
    speed_factor: float = Field(default=1.0, ge=0.6, le=1.65, description="语速倍率")

    # Streaming
    streaming_mode: Union[bool, int, StreamingMode] = Field(default=False, description="流式模式")
    media_type: str = Field(default="wav", description="输出格式: wav/raw/ogg/aac")

    # Advanced parameters
    repetition_penalty: float = Field(default=1.35, ge=1.0, le=2.0)  # repetition penalty
    sample_steps: int = Field(default=32, ge=10, le=100)  # number of sampling steps
    parallel_infer: bool = Field(default=True)  # parallel inference

    @validator('text_lang', 'prompt_lang')
    def validate_language(cls, v):
        """Lower-case the language code and reject unsupported values."""
        valid_langs = {'zh', 'en', 'ja', 'ko', 'cantonese'}
        normalized = v.lower()
        if normalized in valid_langs:
            return normalized
        raise ValueError(f"Unsupported language: {v}. Must be one of {valid_langs}")

    @validator('media_type')
    def validate_media_type(cls, v):
        """Accept only the supported output container names."""
        valid_types = {'wav', 'raw', 'ogg', 'aac'}
        if v in valid_types:
            return v
        raise ValueError(f"Unsupported media_type: {v}")

    def build_request(self) -> Dict[str, Any]:
        """Return the request body as a dict, with the streaming enum reduced to its value."""
        data = self.dict(exclude_none=True)
        mode = self.streaming_mode
        # An enum member must be sent as its integer value.
        if isinstance(mode, StreamingMode):
            data['streaming_mode'] = mode.value
        return data
|
||||
|
||||
|
||||
@dataclass
class AudioResponse:
    """Audio bytes returned by the server together with their sample rate."""

    audio_data: bytes
    sample_rate: int = 32000

    def save(self, path: Union[str, Path]) -> None:
        """Write the audio bytes to *path*, creating parent directories as needed."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(self.audio_data)
        logger.info(f"Audio saved to {target}, size: {len(self.audio_data)} bytes")
|
||||
|
||||
|
||||
class GPTSoVITSClient:
    """
    Asynchronous client for the GPT-SoVITS HTTP API.

    Supported operations:
    - text-to-speech (streaming and non-streaming)
    - switching GPT / SoVITS weights
    - setting the reference audio
    - server control commands
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 9880, debug: bool = False):
        """
        Create the client.

        Args:
            host: API server address.
            port: API server port.
            debug: Whether to emit debug logs through ``_log_debug``.
        """
        self.base_url = f"http://{host}:{port}"
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(30.0, connect=5.0)
        )
        self.debug_mode = debug
        logger.info(f"GPT-SoVITS Client initialized: {self.base_url}")

    async def __aenter__(self) -> "GPTSoVITSClient":
        """Enter the async context; returns the client itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave the async context and close the HTTP connection."""
        await self.close()

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()
        logger.info("Client connection closed")

    def _log_debug(self, message: str, **kwargs):
        """Log *message* with its keyword details when debug mode is on."""
        if self.debug_mode:
            logger.debug(f"{message} | {kwargs}")

    async def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
        """
        Turn an HTTP response into a result dict or an ``APIError``.

        Returns:
            The decoded JSON body for a 200 JSON response, otherwise
            ``{"status": "success", "content": <raw bytes>}`` for a 200 response.

        Raises:
            APIError: For any non-200 status. The message is the ``message``
                field of a JSON object body, the stringified JSON value for
                other JSON bodies, or the raw text when the body is not JSON.
        """
        if response.status_code == 200:
            content_type = response.headers.get('content-type', '')
            if 'application/json' in content_type:
                return response.json()
            return {"status": "success", "content": response.content}

        try:
            error_data = response.json()
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass: body is not JSON.
            raise APIError(response.status_code, response.text) from None

        # A JSON error body is not guaranteed to be an object.
        if isinstance(error_data, dict):
            message = error_data.get('message', 'Unknown error')
        else:
            message = str(error_data)
        raise APIError(response.status_code, message)

    # Core TTS interface
    async def tts(
        self,
        text: str,
        ref_audio_path: str,
        text_lang: str = "zh",
        prompt_lang: str = "zh",
        streaming_mode: StreamingMode = StreamingMode.DISABLED,  # streaming off by default
        media_type: str = "wav",
        **kwargs
    ) -> Union[AudioResponse, AsyncGenerator[AudioResponse, None]]:
        """
        Text-to-speech, optionally streamed.

        Args:
            text: Text to synthesize.
            ref_audio_path: Reference audio path (server-local path or URL).
            text_lang: Language of ``text``.
            prompt_lang: Language of the prompt text.
            streaming_mode: Streaming mode; ``StreamingMode.DISABLED``,
                ``False`` and ``0`` all select the non-streaming path.
            media_type: Output format.
            **kwargs: Any other ``TTSConfig`` field.

        Returns:
            Non-streaming: an ``AudioResponse``.
            Streaming: an async generator yielding ``AudioResponse`` chunks.
            Because this method is a coroutine, it must be awaited first to
            obtain the generator.

        Raises:
            APIError: When the server answers with a non-200 status.

        Example:
            # Non-streaming
            audio = await client.tts("你好", "ref.wav")

            # Streaming
            stream = await client.tts("你好", "ref.wav", streaming_mode=StreamingMode.FASTEST)
            async for chunk in stream:
                process(chunk.audio_data)
        """
        config = TTSConfig(
            text=text,
            ref_audio_path=ref_audio_path,
            text_lang=text_lang,
            prompt_lang=prompt_lang,
            streaming_mode=streaming_mode,
            media_type=media_type,
            **kwargs
        )

        self._log_debug("TTS Request", config=config.dict())

        # TTSConfig also accepts plain bool/int; False and 0 mean "no streaming"
        # but never compare equal to the enum member.
        if streaming_mode in (StreamingMode.DISABLED, False, 0):
            response = await self.client.post("/tts", json=config.build_request())
            if response.status_code != 200:
                # ``text`` is a property on httpx responses, not a coroutine.
                raise APIError(response.status_code, response.text)

            return AudioResponse(
                audio_data=response.content,
                sample_rate=32000  # default sample rate
            )

        # Streaming: force parallel inference off and batch size 1.
        config.parallel_infer = False
        config.batch_size = 1

        async def stream_generator():
            async with self.client.stream(
                "POST", "/tts",
                json=config.build_request(),
                timeout=httpx.Timeout(60.0)  # streaming needs a longer timeout
            ) as response:
                if response.status_code != 200:
                    # The body must be read explicitly on a streamed response;
                    # decode it so the error message is text, not a bytes repr.
                    body = await response.aread()
                    raise APIError(response.status_code, body.decode("utf-8", errors="replace"))

                async for chunk in response.aiter_bytes():
                    if chunk:
                        yield AudioResponse(audio_data=chunk)

        return stream_generator()

    # Model management
    async def set_gpt_weights(self, weights_path: str) -> bool:
        """
        Switch the GPT model weights.

        Args:
            weights_path: Weights file path (server-local).

        Returns:
            True on success.

        Raises:
            ValueError: If ``weights_path`` is empty.
            APIError: If the server rejects the request.

        Example:
            await client.set_gpt_weights("models/s1bert.ckpt")
        """
        if not weights_path:
            raise ValueError("weights_path cannot be empty")

        params = {"weights_path": weights_path}
        response = await self.client.get("/set_gpt_weights", params=params)
        await self._handle_response(response)

        logger.info(f"GPT weights switched to: {weights_path}")
        return True

    async def set_sovits_weights(self, weights_path: str) -> bool:
        """
        Switch the SoVITS model weights.

        Args:
            weights_path: Weights file path (server-local).

        Returns:
            True on success.

        Raises:
            ValueError: If ``weights_path`` is empty.
            APIError: If the server rejects the request.
        """
        if not weights_path:
            raise ValueError("weights_path cannot be empty")

        params = {"weights_path": weights_path}
        response = await self.client.get("/set_sovits_weights", params=params)
        await self._handle_response(response)

        logger.info(f"SoVITS weights switched to: {weights_path}")
        return True

    # Reference audio management
    async def set_refer_audio(
        self,
        audio_source: Union[str, Path, bytes],
        audio_name: Optional[str] = None
    ) -> bool:
        """
        Set the reference audio from a server-local path or from uploaded bytes.

        Args:
            audio_source: Audio file path (str/Path) or audio data (bytes).
            audio_name: File name for the upload; required for bytes input.

        Returns:
            True on success.

        Raises:
            ValueError: If bytes are given without ``audio_name``.
            APIError: If the server rejects the request.

        Example:
            # Option 1: server-local file
            await client.set_refer_audio("/path/to/audio.wav")

            # Option 2: upload audio data
            with open("audio.wav", "rb") as f:
                await client.set_refer_audio(f.read(), "audio.wav")
        """
        if isinstance(audio_source, (str, Path)):
            # GET: server-local path
            params = {"refer_audio_path": str(audio_source)}
            response = await self.client.get("/set_refer_audio", params=params)
            await self._handle_response(response)
            logger.info(f"Reference audio set: {audio_source}")
        else:
            # POST: upload the audio data
            if not audio_name:
                raise ValueError("audio_name is required when uploading bytes")

            files = {"audio_file": (audio_name, audio_source, "audio/wav")}
            response = await self.client.post("/set_refer_audio", files=files)
            await self._handle_response(response)
            logger.info(f"Reference audio uploaded: {audio_name}")

        return True

    # Server control
    async def control_command(self, command: str) -> bool:
        """
        Send a control command.

        Args:
            command: Either "restart" or "exit".

        Returns:
            True on success.

        Raises:
            ValueError: For any other command.
            APIError: If the server rejects the request.

        Warning:
            "exit" terminates the API server process!
        """
        if command not in ["restart", "exit"]:
            raise ValueError("Command must be 'restart' or 'exit'")

        response = await self.client.get("/control", params={"command": command})
        await self._handle_response(response)

        logger.warning(f"Control command executed: {command}")
        return True

    # Convenience helpers
    async def get_server_info(self) -> Dict[str, Any]:
        """
        Probe the server by requesting the root path.

        Returns:
            ``{"status": "online", "detail": <body text>}`` when a response
            arrives (regardless of its status code), or
            ``{"status": "error", "detail": <exception text>}`` otherwise.
        """
        try:
            response = await self.client.get("/")
            return {"status": "online", "detail": response.text}
        except Exception as e:
            return {"status": "error", "detail": str(e)}

    async def batch_tts(
        self,
        texts: list[str],
        ref_audio_path: str,
        **kwargs
    ) -> list[AudioResponse]:
        """
        Synthesize several texts concurrently.

        Args:
            texts: Texts to synthesize.
            ref_audio_path: Reference audio shared by all requests.
            **kwargs: Extra arguments forwarded to ``tts``.

        Returns:
            One result per text, in input order.
        """
        tasks = [
            self.tts(text, ref_audio_path, **kwargs)
            for text in texts
        ]
        return await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
# 异步上下文管理器辅助函数
|
||||
async def create_client(*args, **kwargs) -> GPTSoVITSClient:
    """Build a ``GPTSoVITSClient``, forwarding every argument to its constructor."""
    client = GPTSoVITSClient(*args, **kwargs)
    return client
|
||||
@@ -0,0 +1,27 @@
|
||||
# dto/dto_base.py
|
||||
from abc import ABC
|
||||
from typing import Callable, Coroutine, Any, Dict
|
||||
from src.modules.websocket_base_module.websocket_core.core_ws_server import WebSocketServer
|
||||
|
||||
class MessageDTO(ABC):
    """Base class for DTO handlers that send through a shared WebSocket server."""

    def __init__(
        self,
        ws_server : WebSocketServer,  # the WebSocketServer singleton
    ):
        # Keep the server so subclasses can send through it.
        self.ws_server = ws_server

    # Shortcuts so the DTO layer can call the server's send functions directly.
    @property
    def send_binary(self):
        """The server's ``send_binary`` attribute."""
        server = self.ws_server
        return server.send_binary

    @property
    def send_text(self):
        """The server's ``send_text`` attribute."""
        server = self.ws_server
        return server.send_text

    @property
    def send_json(self):
        """The server's ``send_json`` attribute."""
        server = self.ws_server
        return server.send_json
|
||||
@@ -0,0 +1,95 @@
|
||||
from pydantic import Field, BaseModel
|
||||
import base64
|
||||
from datetime import datetime, timezone
|
||||
class AudioDataTransferObject(BaseModel):
    """
    Audio data transfer object.

    Exchanged between server and client for both streaming and
    non-streaming audio; the same shape is used in both directions and
    ``Owner`` identifies the sender.
    """
    Owner: str = Field(default="server", description="音频数据的拥有者(server or client)")
    isStream: bool = Field(default=False, description="音频数据是否为流式数据")
    isStart: bool = Field(default=False, description="音频数据是否开始(流式时有效)")
    isEnd: bool = Field(default=False, description="音频数据是否结束(流式时有效)")
    sequence: int = Field(default=0, description="音频数据块序列号(流式时有效)")
    data: bytes = Field(default=b"", description="音频数据,流式时为分块数据,base64编码")
    sampleRate: int = Field(default=32000, description="音频采样率")
    channelCount: int = Field(default=1, description="音频通道数")
    bitDepth: int = Field(default=16, description="音频采样位数")
    duration: float = Field(default=0.0, description="音频时长")
    text: str = Field(default="", description="音频对应的文本")

    def set_dto_data(self, **kwargs) -> "AudioDataTransferObject":
        """Update existing attributes from keyword arguments and return ``self``.

        Keys that are not attributes of the object are ignored.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    def to_json(self) -> dict:
        """
        Convert the DTO into a serializable envelope dict.

        Returned structure:
        {
            "type": "audio_data",
            "timestamp": 1672531200.0,
            "data": {
                "Owner": "server",
                ......
            }
        }
        The inner ``data["data"]`` audio bytes are base64-encoded to text.
        """
        payload = self.model_dump()  # all model fields (Pydantic v2 API)
        payload["data"] = base64.b64encode(payload["data"]).decode()
        return {
            "type": "audio_data",
            "timestamp": datetime.now(timezone.utc).timestamp(),
            "data": payload  # audio fields are nested under "data"
        }

    @classmethod
    def from_json(cls, json_data: dict) -> "AudioDataTransferObject":
        """
        Build a DTO from the inner payload dict.

        Expected input (the envelope's ``data`` member, not the envelope):
        {
            "Owner": "server",
            ......
        }
        A string ``data`` field is treated as base64 and decoded to bytes.
        The input dict is not modified.
        """
        # Work on a shallow copy: the same dict may be handed to several
        # receivers, so decoding must not rewrite the caller's object.
        payload = dict(json_data)
        encoded = payload.get("data")
        if isinstance(encoded, str):
            payload["data"] = base64.b64decode(encoded)
        return cls.model_validate(payload)
|
||||
|
||||
|
||||
# Manual round-trip check
if __name__ == "__main__":
    import os

    # Simulated audio: 1 KB of random bytes.
    test_audio = os.urandom(1024)

    # Build the DTO.
    audio = AudioDataTransferObject(
        data=test_audio,
        sequence=1,
        isStream=True,
        sampleRate=44100,
        duration=0.5
    )

    # Serialize. to_json() returns an envelope whose "data" member is the
    # field dict; the base64 audio is one level deeper, under its "data" key.
    json_dict = audio.to_json()
    encoded_audio = json_dict["data"]["data"]
    print(f"序列化后 data 长度: {len(encoded_audio)}")  # 1368 characters
    print(f"data 前30字符: {encoded_audio[:30]}...")

    # Deserialize. from_json() expects the inner field dict, not the envelope.
    restored = AudioDataTransferObject.from_json(json_dict["data"])
    print(f"反序列化后 data 长度: {len(restored.data)}")  # 1024 bytes
    print(f"数据一致: {restored.data == test_audio}")  # True
|
||||
@@ -0,0 +1,73 @@
|
||||
from pydantic import Field, BaseModel
|
||||
from datetime import datetime, timezone
|
||||
class AutoAgentDataTransferObject(BaseModel):
    """
    Automation-agent data transfer object.

    Used by the server to send control information to the client.
    """
    Action: str = Field(default="", description="自动化动作名称")
    x1: int = Field(default=-1, description="鼠标起始位置x1")
    y1: int = Field(default=-1, description="鼠标起始位置y1")
    x2: int = Field(default=-1, description="鼠标结束位置x2")
    y2: int = Field(default=-1, description="鼠标结束位置y2")
    key: str = Field(default="", description="快捷键")
    content: str = Field(default="", description="输入文本内容")
    direction: str = Field(default="", description="滚动方向")

    def set_dto_data(self, **kwargs) -> "AutoAgentDataTransferObject":
        """Update existing attributes from keyword arguments and return ``self``.

        Keys that are not attributes of the object are ignored.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    def to_json(self) -> dict:
        """
        Convert the DTO into a serializable envelope dict.

        Returned structure:
        {
            "type": "auto_agent",
            "timestamp": 1672531200.0,
            "data": {
                "Action": name of the automation action,
                "x1": x1 of the action, -1 when the action has none (same for y1, x2, y2),
                "y1": ...,
                "x2": ...,
                "y2": ...,
                "key": shortcut key, empty when not applicable,
                "content": text to type, empty when not applicable,
                "direction": scroll direction, empty when not applicable
            }
        }
        """
        payload = self.model_dump()  # all model fields (Pydantic v2 API)
        return {
            "type": "auto_agent",
            "timestamp": datetime.now(timezone.utc).timestamp(),
            "data": payload
        }

    @classmethod
    def from_json(cls, json_data: dict) -> "AutoAgentDataTransferObject":
        """
        Build a DTO from a flat dict of fields.

        Expected input:
        {
            "type": "auto_agent",
            "Action": ...,
            "x1": ..., "y1": ..., "x2": ..., "y2": ...,
            "key": ...,
            "content": ...,
            "direction": ...
        }
        ``type`` and ``timestamp`` entries are dropped before validation.
        The input dict is not modified.
        """
        # Filter into a new dict instead of popping from the caller's dict,
        # which may still be in use elsewhere.
        payload = {
            name: value
            for name, value in json_data.items()
            if name not in ("type", "timestamp")
        }
        return cls.model_validate(payload)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
class BaseDataTransferObject:
    """
    Base class for DTOs.

    Every method is a no-op returning ``None``; subclasses override the
    ones they need.
    """

    def __init__(self):
        """Nothing to initialize in the base class."""

    def to_json(self):
        """Convert the DTO to JSON."""

    def from_json(self, json_data):
        """Build the DTO from JSON data."""

    def to_binary(self):
        """Convert the DTO to binary."""

    def from_binary(self, binary_data):
        """Build the DTO from binary data."""

    def to_text(self):
        """Convert the DTO to text."""

    def from_text(self, text_data):
        """Build the DTO from text data."""
|
||||
@@ -0,0 +1,72 @@
|
||||
from pydantic import Field, BaseModel
|
||||
from datetime import datetime, timezone
|
||||
class ScreenShotDataTransferObject(BaseModel):
    """
    Data transfer object for real-time screenshot requests.

    The same shape is used in both directions and ``Owner`` identifies the
    sender. A client that receives a packet of this type takes a screenshot
    of its current display.
    """
    Owner: str = Field(default="server", description="数据的拥有者(server or client)")
    isSuccess: bool = Field(default=False, description="是否截图成功")
    RealTimeScreenShot: str = Field(default="", description="客户端设备的实时截图数据(base64)")
    Width: int = Field(default=1920, description="截图的宽度")
    Height: int = Field(default=1080, description="截图的高度")
    DescribeInfo: str = Field(default="", description="设备的描述信息(告知模型以做出更加准确的判断)")
    LLMResponse: str = Field(default="", description="LLM的响应结果(由服务端发送时携带)")

    def set_dto_data(self, **kwargs) -> "ScreenShotDataTransferObject":
        """Update existing attributes from keyword arguments and return ``self``.

        Keys that are not attributes of the object are ignored.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    def to_json(self) -> dict:
        """
        Convert the DTO into a serializable envelope dict.

        Returned structure:
        {
            "type": "screenshot_data",
            "timestamp": 1672531200.0,
            "data": {
                "Owner": sender of the data (server or client),
                "isSuccess": whether the screenshot succeeded,
                "RealTimeScreenShot": screenshot of the client device (base64),
                "Width": screenshot width,
                "Height": screenshot height,
                "DescribeInfo": device description passed on to the model,
                "LLMResponse": LLM response (filled in when the server sends)
            }
        }
        """
        payload = self.model_dump()  # all model fields (Pydantic v2 API)
        return {
            "type": "screenshot_data",
            "timestamp": datetime.now(timezone.utc).timestamp(),
            "data": payload
        }

    @classmethod
    def from_json(cls, json_data: dict) -> "ScreenShotDataTransferObject":
        """
        Build a DTO from a flat dict of fields.

        Expected input:
        {
            "Owner": ...,
            "isSuccess": ...,
            "RealTimeScreenShot": ...,
            "Width": ...        (optional),
            "Height": ...       (optional),
            "DescribeInfo": ... (optional),
            "LLMResponse": ...
        }
        ``type`` and ``timestamp`` entries are dropped before validation.
        The input dict is not modified.
        """
        # Filter into a new dict instead of popping from the caller's dict:
        # the same dict may be handed to several receivers.
        payload = {
            name: value
            for name, value in json_data.items()
            if name not in ("type", "timestamp")
        }
        return cls.model_validate(payload)
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# dto/second_dtos.py
|
||||
import asyncio
|
||||
from typing import Callable, Optional, Dict, Any, List, Coroutine
|
||||
from src.modules.websocket_base_module.dto.dto_base import MessageDTO
|
||||
from loguru import logger
|
||||
|
||||
"""
|
||||
二级分发器,因为没有信号与槽机制,因此使用观察者模式替代
|
||||
"""
|
||||
def singleton(cls):
    """Class decorator that adds an async ``get_instance`` accessor.

    ``await cls.get_instance(*args, **kwargs)`` builds the instance on the
    first call (using that call's arguments) and returns the same object on
    every later call; later arguments are ignored. Calling ``cls(...)``
    directly is not intercepted and still creates a new object.

    Args:
        cls: The class to decorate.

    Returns:
        The same class, with ``get_instance`` attached as a static method.
    """
    _instance = None
    _lock = None

    async def get_instance(*args, **kwargs):
        nonlocal _instance, _lock
        if _instance is None:
            # Create the lock on first use, inside a running event loop,
            # rather than at import time when no loop may exist yet.
            if _lock is None:
                _lock = asyncio.Lock()
            async with _lock:
                # Re-check: another task may have built it while we waited.
                if _instance is None:
                    _instance = cls(*args, **kwargs)
        return _instance

    # staticmethod keeps the call signature identical whether the accessor
    # is reached through the class or through an instance.
    cls.get_instance = staticmethod(get_instance)
    return cls
|
||||
# Type alias: an async callable that receives one payload.
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]


@singleton
class JsonDTO(MessageDTO):
    """Second-level dispatcher for JSON messages.

    Expected message format:
    {
        "type" : "xxx",
        "timestamp" : "95153...",
        "data" : {content depends on the message type}
    }
    Messages are routed to the receivers registered for their ``type``.
    """

    def __init__(self, ws_server):
        super().__init__(ws_server)
        self.receivers : Dict[str, List[ReceiveCallback]] = {
            'audio_data' : [],       # audio data
            'screenshot_data' : []   # screenshot data
        }
        # Subscribe to JSON messages from the server.
        ws_server.register_receiver('json', self._handle_json)
        logger.info("[JsonDTO] JSON分发器已注册")

    def register_receiver(self, types : str, callback : ReceiveCallback):
        """Register *callback* for messages of type *types*.

        Raises:
            ValueError: If *types* is not a supported dispatch type.
        """
        if types in self.receivers:
            self.receivers[types].append(callback)
            logger.debug(f"[JsonDTO] 已注册 {types} 接收器,当前共 {len(self.receivers[types])} 个")
        else:
            raise ValueError(f"[JsonDTO] 不支持的分发类型: {types}")

    def unregister_receiver(self, types : str, callback : ReceiveCallback):
        """Remove *callback* from *types*; unknown types or callbacks are ignored."""
        callbacks = self.receivers.get(types)
        if callbacks is not None and callback in callbacks:
            callbacks.remove(callback)
            logger.debug(f"[JsonDTO] 已注销 {types} 接收器")

    async def _handle_json(self, data: dict):
        """Validate the envelope of an incoming JSON message and dispatch it.

        Malformed messages (not an object, or missing ``data``) are logged
        and dropped instead of raising.
        """
        logger.info(f"[JsonDTO] 收到消息")
        if not isinstance(data, dict) or "data" not in data:
            logger.warning("[JsonDTO] 消息格式不正确,缺少 data 字段,消息被忽略")
            return
        logger.debug(f'[JsonDTO] 当前消息时间戳: {data.get("timestamp")}')
        # Route by message type.
        await self._dispatch(data.get("type"), data["data"])

    async def _dispatch(self, types : str, data : dict):
        """Pass *data* to every receiver registered for *types*, concurrently.

        Unknown types and types without receivers are logged and ignored.
        An exception raised by one receiver is logged and does not affect
        the others.
        """
        callbacks = self.receivers.get(types)
        if callbacks is None:
            logger.warning(f"[JsonDTO] 不支持的分发类型: {types},消息被忽略")
            return
        if not callbacks:
            logger.info(f"[JsonDTO] 无 {types} 接收器,消息被忽略")
            return
        # Snapshot the list so (un)registration during dispatch is harmless.
        targets = list(callbacks)
        results = await asyncio.gather(
            *(callback(data) for callback in targets),
            return_exceptions=True
        )
        # return_exceptions=True hands failures back as values; surface them.
        for callback, result in zip(targets, results):
            if isinstance(result, BaseException):
                logger.error(f"[JsonDTO] {types} 接收器 {callback!r} 执行失败: {result!r}")
|
||||
async def get_json_dto_instance(ws_server) -> JsonDTO:
    """Return the shared ``JsonDTO``, creating it on the first call.

    Goes through the ``get_instance`` accessor added by ``@singleton``.
    Constructing ``JsonDTO(ws_server)`` directly would build a new
    dispatcher, and register another JSON receiver on the server, on every
    call.

    Args:
        ws_server: WebSocket server used when the instance is first created;
            ignored on later calls.
    """
    return await JsonDTO.get_instance(ws_server)
|
||||
|
||||
|
||||
class EchoDTO(MessageDTO):
    """Echo DTO for testing: handles text messages only."""

    def __init__(self, ws_server):
        super().__init__(ws_server)
        # Subscribe to text messages.
        ws_server.register_receiver('text', self._handle_text)
        logger.info("[EchoDTO] 文本接收器已注册")

    async def _handle_text(self, message: str):
        """Log the incoming text and send it back prefixed with ``Echo: ``."""
        logger.info(f"[EchoDTO] 收到文本: {message}")

        reply = f"Echo: {message}"
        await self.send_text(reply)
|
||||
@@ -0,0 +1,281 @@
|
||||
import asyncio
|
||||
from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
|
||||
from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
|
||||
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
|
||||
from loguru import logger
|
||||
from typing import Callable, List, Optional, Coroutine
|
||||
|
||||
class AudioDataDTO:
    """Last-stage dispatcher for audio messages.

    Receives ``audio_data`` payloads from ``JsonDTO``, keeps the latest one
    and a buffer of streamed chunks, and forwards each message to the
    registered business callbacks. Also offers helpers to send audio.
    """

    def __init__(self, json_dto : JsonDTO):
        self.json_dto = json_dto
        self.audio_data = AudioDataTransferObject()  # most recent DTO (received or sent)
        # Business callbacks (observer pattern).
        self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
        # Latest received audio, for synchronous polling.
        self._latest_audio: Optional[AudioDataTransferObject] = None
        # Buffer of streamed chunks.
        self._stream_buffer: List[AudioDataTransferObject] = []
        # Subscribe only after all state exists.
        json_dto.register_receiver('audio_data', self._handle_audio_data)
        logger.info("[AudioDataDTO] 音频接收业务已注册")
        logger.info("[AudioDataDTO] 业务接口已初始化")

    async def _handle_audio_data(self, data: dict):
        """Handle one incoming audio payload.

        Payload format:
        {
            'Owner': 'server',
            'isStream': False,
            'isStart': False,
            'isEnd': False,
            'sequence': 0,
            'data': '',
            'sampleRate': 16000,
            'channelCount': 1,
            'bitDepth': 16,
            'duration': 0.0,
            'text': ''
        }
        """
        logger.debug(f"[AudioDataDTO] 收到音频数据")
        # Deserialize the dict into a DTO.
        self.audio_data = AudioDataTransferObject.from_json(data)
        # Cache the latest message.
        self._latest_audio = self.audio_data
        # Streamed chunks are also appended to the buffer.
        # NOTE(review): the buffer is only emptied by clear_stream_buffer(),
        # so the count logged below spans every stream since the last clear.
        if self.audio_data.isStream:
            self._stream_buffer.append(self.audio_data)
            if self.audio_data.isEnd:
                logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)} 块")
        # Notify the registered callbacks.
        await self._notify_callbacks()

    # Sending
    async def send_audio_data(self, data: AudioDataTransferObject) -> None:
        """
        Send an audio DTO.

        Args:
            data: Audio DTO whose fields are forwarded to ``send_audio``.
        """
        await self.send_audio(
            Owner=data.Owner,
            is_stream=data.isStream,
            is_start=data.isStart,
            is_end=data.isEnd,
            sequence=data.sequence,
            data=data.data,
            sampleRate=data.sampleRate,
            channelCount=data.channelCount,
            bitDepth=data.bitDepth,
            duration=data.duration,
            text=data.text
        )

    async def send_audio(
        self,
        data: bytes,
        is_stream: bool = False,
        is_start: bool = False,
        is_end: bool = False,
        sequence: int = 0,
        **audio_meta
    ) -> None:
        """
        Convenience interface for sending audio.

        Args:
            data: Raw audio bytes.
            is_stream: Whether this is streamed data.
            is_start: Start-of-stream marker.
            is_end: End-of-stream marker.
            sequence: Chunk sequence number.
            **audio_meta: Other audio fields (Owner, sampleRate,
                channelCount, bitDepth, duration, text). Defaults:
                Owner "server", sampleRate 16000, channelCount 1,
                bitDepth 16, duration 0.0, text "".
        """
        # Build a fresh DTO for the outgoing message. Updating the current
        # ``self.audio_data`` in place would also rewrite the received
        # object still referenced by ``_latest_audio`` / ``_stream_buffer``.
        outgoing = AudioDataTransferObject().set_dto_data(
            # Honour a caller-supplied Owner; the former
            # ``"server" or ...`` expression always produced "server".
            Owner=audio_meta.get('Owner', "server"),
            isStream=is_stream,
            isStart=is_start,
            isEnd=is_end,
            sequence=sequence,
            data=data,
            sampleRate=audio_meta.get('sampleRate', 16000),
            channelCount=audio_meta.get('channelCount', 1),
            bitDepth=audio_meta.get('bitDepth', 16),
            duration=audio_meta.get('duration', 0.0),
            text=audio_meta.get('text', "")
        )
        self.audio_data = outgoing
        # to_json() adds the type/timestamp envelope and base64-encodes the audio.
        json_message = outgoing.to_json()
        await self.json_dto.send_json(json_message)
        logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")

    # Receiving
    def register_audio_callback(
        self,
        callback: Callable[[AudioDataTransferObject], Coroutine]
    ) -> None:
        """
        Register a business callback for received audio.

        Example:
            async def my_audio_handler(audio_dto: AudioDataTransferObject):
                print(f"received audio: {len(audio_dto.data)} bytes")

            audio_dto.register_audio_callback(my_audio_handler)
        """
        self._audio_callbacks.append(callback)
        logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)} 个")

    def unregister_audio_callback(self, callback) -> None:
        """Remove a business callback; unknown callbacks are ignored."""
        if callback in self._audio_callbacks:
            self._audio_callbacks.remove(callback)
            logger.debug("业务音频回调已注销")

    def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
        """
        Return the latest received audio DTO (polling mode).

        Returns:
            The most recently received DTO, or None if nothing has arrived.
        """
        return self._latest_audio

    def clear_stream_buffer(self) -> None:
        """Empty the stream buffer."""
        self._stream_buffer.clear()
        logger.debug("流式音频缓冲区已清空")

    # Internal notification
    async def _notify_callbacks(self) -> None:
        """Call every business callback with the current DTO, concurrently.

        An exception raised by one callback is logged and does not affect
        the others.
        """
        if not self._audio_callbacks:
            logger.warning("无业务回调,音频数据未处理")
            return
        current = self.audio_data
        targets = list(self._audio_callbacks)
        results = await asyncio.gather(
            *(callback(current) for callback in targets),
            return_exceptions=True
        )
        # return_exceptions=True hands failures back as values; surface them.
        for callback, result in zip(targets, results):
            if isinstance(result, BaseException):
                logger.error(f"业务音频回调 {callback!r} 执行失败: {result!r}")
        logger.debug(f"已通知 {len(targets)} 个业务回调")

    # Async iterator, intended for streaming use
    def __aiter__(self):
        """Return ``self`` as the async iterator."""
        return self

    async def __anext__(self) -> AudioDataTransferObject:
        """Async-iterator step; not implemented yet.

        The former stub returned ``None`` without ever suspending, so an
        ``async for`` over this object looped forever and starved the event
        loop. Fail loudly until real streaming iteration exists.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "AudioDataDTO does not support async iteration yet; "
            "use register_audio_callback() or get_latest_audio()"
        )
|
||||
|
||||
class ScreenShotDataDTO:
    """Screenshot DTO dispatcher: last-level fan-out of screenshot data to business callbacks.

    Registers a handler for ``'screenshot_data'`` messages on the given
    ``JsonDTO``, caches the most recently received payload, and forwards it to
    every registered business callback.
    """

    def __init__(self, json_dto : JsonDTO):
        """
        Args:
            json_dto: JSON dispatcher used to receive ``'screenshot_data'``
                messages and to send outgoing screenshot packets.
        """
        json_dto.register_receiver('screenshot_data', self._handle_screenshot_data)  # register the JSON receive handler
        logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
        self.json_dto = json_dto
        self.screenshot_data = ScreenShotDataTransferObject()  # current screenshot data object
        # Business callbacks (observer pattern).
        self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
        # Cache of the latest screenshot so it can be polled synchronously.
        self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
        logger.info("[ScreenShotDataDTO] 业务接口已初始化")

    async def _handle_screenshot_data(self, data: dict):
        """
        Handle an incoming screenshot message.

        Expected JSON layout of ``data``:
            {
                "Owner": "owner of the data (server or client)",
                "isSuccess": "whether the capture succeeded (true or false)",
                "RealTimeScreenShot": "real-time screenshot of the client device (base64)",
                "Width": "screenshot width",
                "Height": "screenshot height",
                "DescribeInfo": "device description (helps the model judge more accurately)"
            }

        Args:
            data: Decoded JSON message.
        """
        logger.debug("[ScreenShotDataDTO] 收到截屏数据")
        # Deserialise the dict into a DTO object.
        self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
        # Cache the latest payload.
        self._latest_screenshot = self.screenshot_data
        # Notify every registered callback.
        await self._notify_callbacks()

    # Business-facing send API
    async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
        """
        Send screenshot data taken from an existing DTO.

        Args:
            data: Screenshot DTO whose fields are forwarded to ``send_screenshot``.
        """
        await self.send_screenshot(
            Owner=data.Owner,
            isSuccess=data.isSuccess,
            RealTimeScreenShot=data.RealTimeScreenShot,
            Width=data.Width,
            Height=data.Height,
            DescribeInfo=data.DescribeInfo,
            LLMResponse=data.LLMResponse
        )

    async def send_screenshot(
        self,
        **screenshot_meta
    ) -> None:
        """
        Convenience entry point for the business layer to send a screenshot packet.

        Typically the side issuing the request does not need to fill in any field.

        Args:
            **screenshot_meta: Screenshot fields (``isSuccess``,
                ``RealTimeScreenShot``, ``Width``, ``Height``, ``DescribeInfo``,
                ``LLMResponse``). A supplied ``Owner`` is ignored.
        """
        # Populate the DTO with the outgoing fields.
        self.screenshot_data.set_dto_data(
            # The original expression `"server" or screenshot_meta.get('Owner', "server")`
            # always evaluated to "server"; that behaviour is kept, without the dead branch.
            # NOTE(review): confirm a caller-supplied Owner is meant to be ignored.
            Owner="server",
            isSuccess=screenshot_meta.get('isSuccess', False),
            RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
            Width=screenshot_meta.get('Width', 1920),
            Height=screenshot_meta.get('Height', 1080),
            # NOTE(review): default is False although DescribeInfo is described as
            # text elsewhere in this class; confirm whether "" was intended.
            DescribeInfo=screenshot_meta.get('DescribeInfo', False),
            LLMResponse=screenshot_meta.get('LLMResponse', "")
        )
        # Serialise to JSON and send (to_json presumably handles base64 and the type field).
        json_message = self.screenshot_data.to_json()
        await self.json_dto.send_json(json_message)
        logger.info("截屏包已发送")

    # Business-facing receive API
    def register_screenshot_callback(
        self,
        callback: Callable[[ScreenShotDataTransferObject], Coroutine]
    ) -> None:
        """
        Subscribe a business-layer callback to incoming screenshot data.

        Args:
            callback: Callable that accepts a ``ScreenShotDataTransferObject``
                and returns a coroutine.
        """
        self._screenshot_callbacks.append(callback)
        logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)} 个")

    def unregister_screenshot_callback(self, callback) -> None:
        """Remove a previously registered business callback; do nothing if it is not registered."""
        if callback in self._screenshot_callbacks:
            self._screenshot_callbacks.remove(callback)
            logger.debug("业务截屏回调已注销")

    def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
        """
        Return the cached latest screenshot DTO synchronously (polling mode).

        Returns:
            The most recently received screenshot DTO, or ``None`` if none has
            been received.
        """
        return self._latest_screenshot

    # Internal notification mechanism
    async def _notify_callbacks(self) -> None:
        """
        Fan the current screenshot DTO out to every registered business callback.

        All callbacks are awaited concurrently. A failing callback does not stop
        the others; its exception is logged instead of being silently discarded.
        """
        # Snapshot so results line up with the callbacks that actually ran.
        callbacks = list(self._screenshot_callbacks)
        if not callbacks:
            logger.warning("无业务回调,截屏数据未处理")
            return
        tasks = [callback(self.screenshot_data) for callback in callbacks]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for callback, outcome in zip(callbacks, results):
            if isinstance(outcome, BaseException):
                logger.error(f"业务截屏回调执行失败: {callback!r}: {outcome!r}")
        logger.debug(f"已通知 {len(callbacks)} 个业务回调")
|
||||
@@ -0,0 +1,56 @@
|
||||
### 说明
|
||||
在本module当中,每个子模块的用途分别是:
|
||||
- dto
|
||||
- dto_templates
|
||||
服务端与客户端交互所使用到的数据传输对象
|
||||
- dto_base.py / xxx_dtos.py
|
||||
实际的DTO业务
|
||||
- websocket_core
|
||||
- websocket核心,承载了底层核心的网络收发业务
|
||||
|
||||
|
||||
### 模块架构
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Client"
|
||||
C[WebSocket Client]
|
||||
end
|
||||
|
||||
subgraph "Core Layer"
|
||||
WS[WebSocketServer<br/>单例]
|
||||
WS -->|持有| WSP[WebSocketServerProtocol<br/>_websocket]
|
||||
WS -->|管理| RCV[ receivers: Dict<br/>binary/text/json ]
|
||||
end
|
||||
|
||||
subgraph "DTO Base Layer"
|
||||
MDTO[MessageDTO<br/>抽象基类]
|
||||
MDTO -->|注入| MDF[ send_binary<br/>send_text<br/>send_json ]
|
||||
end
|
||||
|
||||
subgraph "Secondary Dispatcher"
|
||||
JDTO[JsonDTO<br/>单例]
|
||||
JDTO -->|继承| MDTO
|
||||
JDTO -->|维护| MAP[ receivers: Dict<br/>audio_data/... ]
|
||||
JDTO -->|注册到| WS
|
||||
end
|
||||
|
||||
subgraph "Business DTO"
|
||||
ADTO[AudioDataDTO......<br/>业务实现]
|
||||
ADTO -->|持有引用| JDTO
|
||||
ADTO -->|使用| ATO[AudioDataTransferObject<br/>Pydantic模型]
|
||||
end
|
||||
|
||||
C <-->|websocket连接| WSP
|
||||
|
||||
WS -->|分发消息| JDTO
|
||||
JDTO -->|二次分发| ADTO
|
||||
|
||||
ADTO -->|发送响应| JDTO
|
||||
JDTO -->|调用| MDF
|
||||
MDF -->|经由| WS
|
||||
WS -->|发送至| C
|
||||
|
||||
style WS fill:#64f,stroke:#333,stroke-width:2px
|
||||
style JDTO fill:#569,stroke:#333,stroke-width:2px
|
||||
style ADTO fill:#38f,stroke:#333,stroke-width:2px
|
||||
```
|
||||
@@ -0,0 +1,172 @@
|
||||
# websocket_core/core_ws_server.py
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable, Optional, Dict, Any, List, Coroutine
|
||||
from websockets.asyncio.server import serve, ServerConnection
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from loguru import logger
|
||||
|
||||
# 类型别名
|
||||
# Signature of a receive callback: takes one received payload and returns a coroutine that yields None.
ReceiveCallback = Callable[[Any], Coroutine[Any, Any, None]]
|
||||
class WebSocketServer:
    """Core WebSocket server module (singleton, single client).

    Tracks a single active client connection. The DTO layer registers receive
    callbacks; this class only dispatches incoming messages to them and
    exposes send helpers for the active connection.
    """
    _instance: Optional["WebSocketServer"] = None
    _lock = asyncio.Lock()  # NOTE(review): not referenced by any method of this class

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialise connection state and receiver tables once; later calls are no-ops."""
        # Guard against re-initialisation: __init__ runs on every WebSocketServer() call.
        if self._initialized:
            return
        self._websocket: Optional[ServerConnection] = None
        self._receivers: Dict[str, List[ReceiveCallback]] = {
            'binary': [],
            'text': [],
            'json': []
        }
        self._connected_event = asyncio.Event()
        self._initialized = True
        logger.info("WebSocketServer 单客户端分发器已初始化")

    # Registration API for the DTO layer
    def register_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
        """Register a receive callback (called by the DTO layer).

        Args:
            msg_type: Message type (``binary`` / ``text`` / ``json``).
            callback: Receive callback with signature ``(data) -> coroutine``.

        Raises:
            ValueError: If ``msg_type`` is not a supported message type.
        """
        if msg_type in self._receivers:
            self._receivers[msg_type].append(callback)
            logger.debug(f"已注册 {msg_type} 接收器,当前 {msg_type} 类型接收器共 {len(self._receivers[msg_type])} 个")
        else:
            raise ValueError(f"不支持的消息类型: {msg_type}")

    def unregister_receiver(self, msg_type: str, callback: ReceiveCallback) -> None:
        """Remove a receive callback; do nothing if it is not registered for ``msg_type``.

        Raises:
            KeyError: If ``msg_type`` is not a supported message type.
        """
        if callback in self._receivers[msg_type]:
            self._receivers[msg_type].remove(callback)
            logger.debug(f"已注销 {msg_type} 接收器")

    # Send API (called by the DTO layer)
    async def send_binary(self, data: bytes):
        """Send binary data to the single client.

        Raises:
            RuntimeError: If no client is connected.
        """
        if self._websocket:
            await self._websocket.send(data)
            logger.trace(f"二进制数据已发送 (长度: {len(data)} bytes)")
        else:
            raise RuntimeError("客户端未连接")

    async def send_text(self, data: str):
        """Send text data to the single client.

        Raises:
            RuntimeError: If no client is connected.
        """
        if self._websocket:
            await self._websocket.send(data)
            logger.trace(f"文本数据已发送 (长度: {len(data)} chars)")
        else:
            raise RuntimeError("客户端未连接")

    async def send_json(self, data: Dict[str, Any]):
        """Serialise ``data`` with ``json.dumps`` and send it to the single client.

        Raises:
            RuntimeError: If no client is connected.
            Exception: Any serialisation or send error is logged and re-raised.
        """
        if self._websocket:
            try:
                logger.debug(f"准备发送JSON数据: {data}")
                message = json.dumps(data)
                await self._websocket.send(message)
                logger.trace(f"JSON数据已发送: {data}")
            except Exception as e:
                logger.error(f"JSON数据发送失败: {e}")
                raise
        else:
            raise RuntimeError("客户端未连接")

    # Waiting for a connection
    async def wait_for_client(self):
        """Wait until the connected event is set, i.e. a client has connected."""
        await self._connected_event.wait()
        logger.info("客户端已就绪")

    # Internal message loop
    async def _handle_client(self, websocket: ServerConnection):
        """Run the receive loop for one client connection and dispatch its messages.

        ``bytes`` messages go to the ``binary`` receivers. ``str`` messages are
        parsed as JSON and sent to the ``json`` receivers; if parsing fails they
        go to the ``text`` receivers instead.

        Args:
            websocket: The accepted connection; it becomes the active client.
        """
        if self._websocket is not None and self._websocket is not websocket:
            logger.warning("已有客户端连接,将被新连接取代")
        self._websocket = websocket
        self._connected_event.set()
        client_info = f"{websocket.remote_address}" if hasattr(websocket, 'remote_address') else "unknown"
        logger.info(f"客户端已连接: {client_info}")

        try:
            async for message in websocket:
                # Dispatch by message type to every registered receiver.
                if isinstance(message, bytes):
                    await self._dispatch('binary', message)

                elif isinstance(message, str):
                    # Try JSON first.
                    json_dispatched = False
                    try:
                        data = json.loads(message)
                        await self._dispatch('json', data)
                        json_dispatched = True
                    except json.JSONDecodeError:
                        pass

                    # Fall back to the text receivers only when the payload is not valid JSON.
                    if not json_dispatched:
                        await self._dispatch('text', message)

                else:
                    logger.warning(f"未知消息类型: {type(message)}")
            logger.info("客户端连接已正常关闭")
        except ConnectionClosed as e:
            logger.info(f"客户端连接已关闭: {e}")
        except Exception as e:
            logger.error(f"处理客户端时发生错误: {e}")
        finally:
            # Reset shared state only if this connection is still the active one;
            # otherwise a stale connection closing would wipe out a newer client.
            if self._websocket is websocket:
                self._websocket = None
                self._connected_event.clear()

    async def _dispatch(self, msg_type: str, data: Any):
        """Dispatch a message to every receiver registered for ``msg_type``.

        Callbacks run concurrently. A failing callback does not stop the
        others; its exception is logged instead of being silently discarded.

        Args:
            msg_type: Key into the receiver table (``binary`` / ``text`` / ``json``).
            data: Payload handed to each callback.
        """
        # Snapshot so results line up with the callbacks that actually ran.
        callbacks = list(self._receivers[msg_type])
        if not callbacks:
            logger.warning(f"无 {msg_type} 接收器,消息被忽略")
            return

        # Run all callbacks concurrently.
        tasks = [callback(data) for callback in callbacks]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for callback, outcome in zip(callbacks, results):
            if isinstance(outcome, BaseException):
                logger.error(f"{msg_type} 接收器执行失败: {callback!r}: {outcome!r}")

    # Server start-up
    # NOTE(review): the default 50*1024*1025 is probably a typo for 50*1024*1024 (50 MiB);
    # left unchanged because the default is part of the public signature.
    async def run(self, host: str = "localhost", port: int = 8765, max_msg_size: int = 50*1024*1025):
        """Start the WebSocket server and block until the task is cancelled or fails.

        Args:
            host: Interface to bind.
            port: TCP port to listen on.
            max_msg_size: Passed to ``serve`` as ``max_size``.

        Raises:
            Exception: Any error raised while starting or running the server is
                logged and re-raised.
        """
        logger.info(f"WebSocket服务器启动中... 等待客户端连接 ws://{host}:{port}")

        async def handler_wrapper(connection):
            logger.info(f"新连接请求: {connection.remote_address}")
            await self._handle_client(connection)

        try:
            async with serve(
                handler=handler_wrapper,  # wrapper adapts the bound method to the handler signature
                host=host,
                port=port,
                max_size=max_msg_size,
            ):
                logger.success(f"WebSocket服务器已启动在 ws://{host}:{port}")
                await asyncio.Future()  # never completes: keeps the server running
        except Exception as e:
            logger.error(f"WebSocket服务器启动失败: {e}")
            raise
|
||||
|
||||
|
||||
async def get_ws_server() -> WebSocketServer:
    """Return the global ``WebSocketServer`` singleton.

    Uniqueness comes from ``WebSocketServer.__new__``; this coroutine takes no
    lock of its own.
    """
    return WebSocketServer()  # __new__ guarantees a single instance
|
||||
Reference in New Issue
Block a user