Files
Yosuga_server/src/modules/websocket_base_module/dto/third_dtos.py
T
2026-03-01 14:55:45 +08:00

282 lines
11 KiB
Python

import asyncio
from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
from loguru import logger
from typing import Callable, List, Optional, Coroutine
class AudioDataDTO:
"""音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)"""
def __init__(self, json_dto : JsonDTO):
json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数
logger.info("[AudioDataDTO] 音频接收业务已注册")
self.json_dto = json_dto
self.audio_data = AudioDataTransferObject() # 音频数据对象
# 业务回调列表,延续观察者模式
self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
# 最新音频缓存 支持同步查询
self._latest_audio: Optional[AudioDataTransferObject] = None
# 流式缓冲区 用于大段音频流
self._stream_buffer: List[AudioDataTransferObject] = []
logger.info("[AudioDataDTO] 业务接口已初始化")
async def _handle_audio_data(self, data: dict):
"""处理音频数据"""
"""
音频数据json格式:
{
'Owner': 'server',
'isStream': False,
'isStart': False,
'isEnd': False,
'sequence': 0,
'data': '',
'sampleRate': 16000,
'channelCount': 1,
'bitDepth': 16,
'duration': 0.0,
'text': ''
}
"""
logger.debug(f"[AudioDataDTO] 收到音频数据")
# 将dict反序列化到DTO对象
self.audio_data = AudioDataTransferObject.from_json(data)
# 缓存最新数据
self._latest_audio = self.audio_data
# 如果是流式数据,加入缓冲区
if self.audio_data.isStream:
self._stream_buffer.append(self.audio_data)
if self.audio_data.isEnd:
logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)}")
# 通知所有注册的回调
await self._notify_callbacks()
# 业务发送接口
async def send_audio_data(self, data: AudioDataTransferObject) -> None:
"""
发送音频数据
Args:
data: 音频数据DTO
"""
await self.send_audio(
Owner=data.Owner,
is_stream=data.isStream,
is_start=data.isStart,
is_end=data.isEnd,
sequence=data.sequence,
data=data.data,
sampleRate=data.sampleRate,
channelCount=data.channelCount,
bitDepth=data.bitDepth,
duration=data.duration,
text=data.text
)
async def send_audio(
self,
data: bytes,
is_stream: bool = False,
is_start: bool = False,
is_end: bool = False,
sequence: int = 0,
**audio_meta
) -> None:
"""
业务层发送音频的便捷接口
Args:
data: 原始音频字节
is_stream: 是否为流式数据
is_start: 流式数据开始标记
is_end: 流式数据结束标记
sequence: 数据块序号
**audio_meta: 其他音频参数(sampleRate, channelCount等)
"""
# 填充音频数据到DTO
self.audio_data.set_dto_data(
Owner="server" or audio_meta.get('Owner', "server"),
isStream=is_stream,
isStart=is_start,
isEnd=is_end,
sequence=sequence,
data=data,
sampleRate=audio_meta.get('sampleRate', 16000),
channelCount=audio_meta.get('channelCount', 1),
bitDepth=audio_meta.get('bitDepth', 16),
duration=audio_meta.get('duration', 0.0),
text=audio_meta.get('text', "")
)
# 序列化为JSON并发送 自动处理base64和type字段
json_message = self.audio_data.to_json()
await self.json_dto.send_json(json_message)
# logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")
# 业务接收接口
def register_audio_callback(
self,
callback: Callable[[AudioDataTransferObject], Coroutine]
) -> None:
"""
业务注册接收回调
使用示例:
async def my_audio_handler(audio_dto: AudioDataTransferObject):
print(f"收到音频: {len(audio_dto.data)} bytes")
audio_dto.register_audio_callback(my_audio_handler)
"""
self._audio_callbacks.append(callback)
logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)}")
def unregister_audio_callback(self, callback) -> None:
"""注销业务回调"""
if callback in self._audio_callbacks:
self._audio_callbacks.remove(callback)
logger.debug("业务音频回调已注销")
def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
"""
同步获取最新音频数据(轮询模式)
Returns:
最新接收到的音频DTO,如果没有则为 None
"""
return self._latest_audio
def clear_stream_buffer(self) -> None:
"""清空流式缓冲区"""
self._stream_buffer.clear()
logger.debug("流式音频缓冲区已清空")
# 内部通知机制
async def _notify_callbacks(self) -> None:
"""通知所有业务回调"""
if not self._audio_callbacks:
logger.warning("无业务回调,音频数据未处理")
return
tasks = [callback(self.audio_data) for callback in self._audio_callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调")
# 异步迭代器 流式时使用
def __aiter__(self):
"""支持 async for 循环接收流式音频"""
return self
async def __anext__(self) -> AudioDataTransferObject:
"""异步迭代器协议"""
pass
class ScreenShotDataDTO:
"""截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)"""
def __init__(self, json_dto : JsonDTO):
json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数
logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
self.json_dto = json_dto
self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象
# 业务回调列表,延续观察者模式
self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
# 最新截屏数据缓存 支持同步查询
self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
logger.info("[ScreenShotDataDTO] 业务接口已初始化")
async def _handle_screenshot_data(self, data: dict):
"""处理截屏数据"""
"""
截屏数据json格式:
{
"Owner": "数据的拥有者(server or client)",
"isSuccess": "是否截图成功(true or false)"
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
"Width": "截图的宽度",
"Height": "截图的高度",
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)"
}
"""
logger.debug(f"[ScreenShotDataDTO] 收到截屏数据")
# 将dict反序列化到DTO对象
self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
# 缓存最新数据
self._latest_screenshot = self.screenshot_data
# 通知所有注册的回调
await self._notify_callbacks()
# 业务发送接口
async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
"""
发送音频数据
Args:
data: 音频数据DTO
"""
await self.send_screenshot(
Owner=data.Owner,
isSuccess=data.isSuccess,
RealTimeScreenShot=data.RealTimeScreenShot,
Width=data.Width,
Height=data.Height,
DescribeInfo=data.DescribeInfo,
LLMResponse=data.LLMResponse
)
async def send_screenshot(
self,
**screenshot_meta
) -> None:
"""
业务层发送音频的便捷接口
一般来说,作为发送请求方,不需要填充任何数据
Args:
**screenshot_meta: 截屏数据元信息
"""
# 填充音频数据到DTO
self.screenshot_data.set_dto_data(
Owner="server" or screenshot_meta.get('Owner', "server"),
isSuccess=screenshot_meta.get('isSuccess', False),
RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
Width=screenshot_meta.get('Width', 1920),
Height=screenshot_meta.get('Height', 1080),
DescribeInfo=screenshot_meta.get('DescribeInfo', False),
LLMResponse=screenshot_meta.get('LLMResponse', "")
)
# 序列化为JSON并发送 自动处理base64和type字段
json_message = self.screenshot_data.to_json()
await self.json_dto.send_json(json_message)
logger.info(f"截屏包已发送")
# 业务接收接口
def register_screenshot_callback(
self,
callback: Callable[[ScreenShotDataTransferObject], Coroutine]
) -> None:
"""
业务注册接收回调
"""
self._screenshot_callbacks.append(callback)
logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)}")
def unregister_screenshot_callback(self, callback) -> None:
"""注销业务回调"""
if callback in self._screenshot_callbacks:
self._screenshot_callbacks.remove(callback)
logger.debug("业务截屏回调已注销")
def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
"""
同步获取最新截屏数据(轮询模式)
Returns:
最新接收到的截屏数据DTO,如果没有则为 None
"""
return self._latest_screenshot
# 内部通知机制
async def _notify_callbacks(self) -> None:
"""通知所有业务回调"""
if not self._screenshot_callbacks:
logger.warning("无业务回调,截屏数据未处理")
return
tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调")