first pull
This commit is contained in:
@@ -0,0 +1,281 @@
|
||||
import asyncio
|
||||
from src.modules.websocket_base_module.dto.dto_templates.audio_data_dto import AudioDataTransferObject
|
||||
from src.modules.websocket_base_module.dto.dto_templates.screenshot_data_dto import ScreenShotDataTransferObject
|
||||
from src.modules.websocket_base_module.dto.second_dtos import JsonDTO
|
||||
from loguru import logger
|
||||
from typing import Callable, List, Optional, Coroutine
|
||||
|
||||
class AudioDataDTO:
|
||||
"""音频数据交互DTO 再次分发给所有使用到了音频数据的相关业务(最后一级分发)"""
|
||||
def __init__(self, json_dto : JsonDTO):
|
||||
json_dto.register_receiver('audio_data', self._handle_audio_data) # 注册JSON接收函数
|
||||
logger.info("[AudioDataDTO] 音频接收业务已注册")
|
||||
self.json_dto = json_dto
|
||||
self.audio_data = AudioDataTransferObject() # 音频数据对象
|
||||
# 业务回调列表,延续观察者模式
|
||||
self._audio_callbacks: List[Callable[[AudioDataTransferObject], Coroutine]] = []
|
||||
# 最新音频缓存 支持同步查询
|
||||
self._latest_audio: Optional[AudioDataTransferObject] = None
|
||||
# 流式缓冲区 用于大段音频流
|
||||
self._stream_buffer: List[AudioDataTransferObject] = []
|
||||
logger.info("[AudioDataDTO] 业务接口已初始化")
|
||||
|
||||
async def _handle_audio_data(self, data: dict):
|
||||
"""处理音频数据"""
|
||||
"""
|
||||
音频数据json格式:
|
||||
{
|
||||
'Owner': 'server',
|
||||
'isStream': False,
|
||||
'isStart': False,
|
||||
'isEnd': False,
|
||||
'sequence': 0,
|
||||
'data': '',
|
||||
'sampleRate': 16000,
|
||||
'channelCount': 1,
|
||||
'bitDepth': 16,
|
||||
'duration': 0.0,
|
||||
'text': ''
|
||||
}
|
||||
"""
|
||||
logger.debug(f"[AudioDataDTO] 收到音频数据")
|
||||
# 将dict反序列化到DTO对象
|
||||
self.audio_data = AudioDataTransferObject.from_json(data)
|
||||
# 缓存最新数据
|
||||
self._latest_audio = self.audio_data
|
||||
# 如果是流式数据,加入缓冲区
|
||||
if self.audio_data.isStream:
|
||||
self._stream_buffer.append(self.audio_data)
|
||||
if self.audio_data.isEnd:
|
||||
logger.info(f"流式音频接收完成,共 {len(self._stream_buffer)} 块")
|
||||
# 通知所有注册的回调
|
||||
await self._notify_callbacks()
|
||||
|
||||
# 业务发送接口
|
||||
async def send_audio_data(self, data: AudioDataTransferObject) -> None:
|
||||
"""
|
||||
发送音频数据
|
||||
|
||||
Args:
|
||||
data: 音频数据DTO
|
||||
"""
|
||||
await self.send_audio(
|
||||
Owner=data.Owner,
|
||||
is_stream=data.isStream,
|
||||
is_start=data.isStart,
|
||||
is_end=data.isEnd,
|
||||
sequence=data.sequence,
|
||||
data=data.data,
|
||||
sampleRate=data.sampleRate,
|
||||
channelCount=data.channelCount,
|
||||
bitDepth=data.bitDepth,
|
||||
duration=data.duration,
|
||||
text=data.text
|
||||
)
|
||||
|
||||
async def send_audio(
|
||||
self,
|
||||
data: bytes,
|
||||
is_stream: bool = False,
|
||||
is_start: bool = False,
|
||||
is_end: bool = False,
|
||||
sequence: int = 0,
|
||||
**audio_meta
|
||||
) -> None:
|
||||
"""
|
||||
业务层发送音频的便捷接口
|
||||
|
||||
Args:
|
||||
data: 原始音频字节
|
||||
is_stream: 是否为流式数据
|
||||
is_start: 流式数据开始标记
|
||||
is_end: 流式数据结束标记
|
||||
sequence: 数据块序号
|
||||
**audio_meta: 其他音频参数(sampleRate, channelCount等)
|
||||
"""
|
||||
# 填充音频数据到DTO
|
||||
self.audio_data.set_dto_data(
|
||||
Owner="server" or audio_meta.get('Owner', "server"),
|
||||
isStream=is_stream,
|
||||
isStart=is_start,
|
||||
isEnd=is_end,
|
||||
sequence=sequence,
|
||||
data=data,
|
||||
sampleRate=audio_meta.get('sampleRate', 16000),
|
||||
channelCount=audio_meta.get('channelCount', 1),
|
||||
bitDepth=audio_meta.get('bitDepth', 16),
|
||||
duration=audio_meta.get('duration', 0.0),
|
||||
text=audio_meta.get('text', "")
|
||||
)
|
||||
# 序列化为JSON并发送 自动处理base64和type字段
|
||||
json_message = self.audio_data.to_json()
|
||||
await self.json_dto.send_json(json_message)
|
||||
logger.info(f"音频已发送: sequence={sequence}, 大小={len(data)} bytes")
|
||||
|
||||
# 业务接收接口
|
||||
def register_audio_callback(
|
||||
self,
|
||||
callback: Callable[[AudioDataTransferObject], Coroutine]
|
||||
) -> None:
|
||||
"""
|
||||
业务注册接收回调
|
||||
|
||||
使用示例:
|
||||
async def my_audio_handler(audio_dto: AudioDataTransferObject):
|
||||
print(f"收到音频: {len(audio_dto.data)} bytes")
|
||||
|
||||
audio_dto.register_audio_callback(my_audio_handler)
|
||||
"""
|
||||
self._audio_callbacks.append(callback)
|
||||
logger.debug(f"业务音频回调已注册,当前共 {len(self._audio_callbacks)} 个")
|
||||
|
||||
def unregister_audio_callback(self, callback) -> None:
|
||||
"""注销业务回调"""
|
||||
if callback in self._audio_callbacks:
|
||||
self._audio_callbacks.remove(callback)
|
||||
logger.debug("业务音频回调已注销")
|
||||
|
||||
def get_latest_audio(self) -> Optional[AudioDataTransferObject]:
|
||||
"""
|
||||
同步获取最新音频数据(轮询模式)
|
||||
|
||||
Returns:
|
||||
最新接收到的音频DTO,如果没有则为 None
|
||||
"""
|
||||
return self._latest_audio
|
||||
|
||||
def clear_stream_buffer(self) -> None:
|
||||
"""清空流式缓冲区"""
|
||||
self._stream_buffer.clear()
|
||||
logger.debug("流式音频缓冲区已清空")
|
||||
|
||||
# 内部通知机制
|
||||
async def _notify_callbacks(self) -> None:
|
||||
"""通知所有业务回调"""
|
||||
if not self._audio_callbacks:
|
||||
logger.warning("无业务回调,音频数据未处理")
|
||||
return
|
||||
tasks = [callback(self.audio_data) for callback in self._audio_callbacks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.debug(f"已通知 {len(self._audio_callbacks)} 个业务回调")
|
||||
# 异步迭代器 流式时使用
|
||||
def __aiter__(self):
|
||||
"""支持 async for 循环接收流式音频"""
|
||||
return self
|
||||
async def __anext__(self) -> AudioDataTransferObject:
|
||||
"""异步迭代器协议"""
|
||||
pass
|
||||
|
||||
class ScreenShotDataDTO:
|
||||
"""截屏数据交互DTO 分发给所有使用到了截屏数据的相关业务(最后一级分发)"""
|
||||
def __init__(self, json_dto : JsonDTO):
|
||||
json_dto.register_receiver('screenshot_data', self._handle_screenshot_data) # 注册JSON接收函数
|
||||
logger.info("[ScreenShotDataDTO] 截屏接收业务已注册")
|
||||
self.json_dto = json_dto
|
||||
self.screenshot_data = ScreenShotDataTransferObject() # 截屏数据对象
|
||||
# 业务回调列表,延续观察者模式
|
||||
self._screenshot_callbacks: List[Callable[[ScreenShotDataTransferObject], Coroutine]] = []
|
||||
# 最新截屏数据缓存 支持同步查询
|
||||
self._latest_screenshot: Optional[ScreenShotDataTransferObject] = None
|
||||
logger.info("[ScreenShotDataDTO] 业务接口已初始化")
|
||||
|
||||
async def _handle_screenshot_data(self, data: dict):
|
||||
"""处理截屏数据"""
|
||||
"""
|
||||
截屏数据json格式:
|
||||
{
|
||||
"Owner": "数据的拥有者(server or client)",
|
||||
"isSuccess": "是否截图成功(true or false)"
|
||||
"RealTimeScreenShot": "客户端设备的实时截图数据(base64)",
|
||||
"Width": "截图的宽度",
|
||||
"Height": "截图的高度",
|
||||
"DescribeInfo": "设备的描述信息(告知模型以做出更加准确的判断)"
|
||||
}
|
||||
"""
|
||||
logger.debug(f"[ScreenShotDataDTO] 收到截屏数据")
|
||||
# 将dict反序列化到DTO对象
|
||||
self.screenshot_data = ScreenShotDataTransferObject.from_json(data)
|
||||
# 缓存最新数据
|
||||
self._latest_screenshot = self.screenshot_data
|
||||
# 通知所有注册的回调
|
||||
await self._notify_callbacks()
|
||||
|
||||
# 业务发送接口
|
||||
async def send_screenshot_data(self, data: ScreenShotDataTransferObject) -> None:
|
||||
"""
|
||||
发送音频数据
|
||||
|
||||
Args:
|
||||
data: 音频数据DTO
|
||||
"""
|
||||
await self.send_screenshot(
|
||||
Owner=data.Owner,
|
||||
isSuccess=data.isSuccess,
|
||||
RealTimeScreenShot=data.RealTimeScreenShot,
|
||||
Width=data.Width,
|
||||
Height=data.Height,
|
||||
DescribeInfo=data.DescribeInfo,
|
||||
LLMResponse=data.LLMResponse
|
||||
)
|
||||
|
||||
async def send_screenshot(
|
||||
self,
|
||||
**screenshot_meta
|
||||
) -> None:
|
||||
"""
|
||||
业务层发送音频的便捷接口
|
||||
一般来说,作为发送请求方,不需要填充任何数据
|
||||
|
||||
Args:
|
||||
**screenshot_meta: 截屏数据元信息
|
||||
"""
|
||||
# 填充音频数据到DTO
|
||||
self.screenshot_data.set_dto_data(
|
||||
Owner="server" or screenshot_meta.get('Owner', "server"),
|
||||
isSuccess=screenshot_meta.get('isSuccess', False),
|
||||
RealTimeScreenShot=screenshot_meta.get('RealTimeScreenShot', ""),
|
||||
Width=screenshot_meta.get('Width', 1920),
|
||||
Height=screenshot_meta.get('Height', 1080),
|
||||
DescribeInfo=screenshot_meta.get('DescribeInfo', False),
|
||||
LLMResponse=screenshot_meta.get('LLMResponse', "")
|
||||
)
|
||||
# 序列化为JSON并发送 自动处理base64和type字段
|
||||
json_message = self.screenshot_data.to_json()
|
||||
await self.json_dto.send_json(json_message)
|
||||
logger.info(f"截屏包已发送")
|
||||
|
||||
# 业务接收接口
|
||||
def register_screenshot_callback(
|
||||
self,
|
||||
callback: Callable[[ScreenShotDataTransferObject], Coroutine]
|
||||
) -> None:
|
||||
"""
|
||||
业务注册接收回调
|
||||
"""
|
||||
self._screenshot_callbacks.append(callback)
|
||||
logger.debug(f"业务截屏回调已注册,当前共 {len(self._screenshot_callbacks)} 个")
|
||||
|
||||
def unregister_screenshot_callback(self, callback) -> None:
|
||||
"""注销业务回调"""
|
||||
if callback in self._screenshot_callbacks:
|
||||
self._screenshot_callbacks.remove(callback)
|
||||
logger.debug("业务截屏回调已注销")
|
||||
|
||||
def get_latest_screenshot(self) -> Optional[ScreenShotDataTransferObject]:
|
||||
"""
|
||||
同步获取最新截屏数据(轮询模式)
|
||||
|
||||
Returns:
|
||||
最新接收到的截屏数据DTO,如果没有则为 None
|
||||
"""
|
||||
return self._latest_screenshot
|
||||
|
||||
# 内部通知机制
|
||||
async def _notify_callbacks(self) -> None:
|
||||
"""通知所有业务回调"""
|
||||
if not self._screenshot_callbacks:
|
||||
logger.warning("无业务回调,截屏数据未处理")
|
||||
return
|
||||
tasks = [callback(self.screenshot_data) for callback in self._screenshot_callbacks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.debug(f"已通知 {len(self._screenshot_callbacks)} 个业务回调")
|
||||
Reference in New Issue
Block a user