39c54a4452
2. 增加了对嵌入式设备完全自定义控制的功能
269 lines
9.8 KiB
Python
269 lines
9.8 KiB
Python
|
|
import socket
|
|
import struct
|
|
import threading
|
|
import json
|
|
import os
|
|
from typing import Optional
|
|
|
|
from openai import OpenAI # 确保 pip install openai
|
|
|
|
from src.server_core.yosuga_embedded_server.device_manager import DeviceManager, DeviceInfo
|
|
from src.server_core.yosuga_embedded_server.function_registry import FunctionRegistry
|
|
from src.server_core.yosuga_embedded_server.ai_prompt import AIPromptBuilder
|
|
from src.server_core.yosuga_embedded_server.json_rpc import JSONRPCHandler, RPCError
|
|
|
|
|
|
class DeviceConnection:
    """Manage one TCP connection to a device.

    Every message on the wire is framed as a 4-byte big-endian unsigned
    length (``struct`` format ``!I``) followed by that many bytes of UTF-8
    text.
    """

    # Upper bound for a single ``recv`` call. ``socket.recv(n)`` allocates a
    # buffer of ``n`` bytes up front, so the peer-declared message length must
    # never be passed to it directly.
    _RECV_CHUNK = 65536

    def __init__(self, sock, addr, server):
        self.sock = sock
        self.addr = addr
        self.server = server
        # Set by handle() once the device has registered successfully.
        self.device_id: Optional[str] = None

    def send_msg(self, data: str):
        """Send ``data`` as one length-prefixed UTF-8 message.

        Raises:
            OSError: if the underlying socket fails while sending.
        """
        payload = data.encode('utf-8')
        # Header and payload are written with a single sendall() so a frame
        # is never split across two separate send calls.
        self.sock.sendall(struct.pack('!I', len(payload)) + payload)

    def _recv_exact(self, size: int) -> Optional[bytes]:
        """Read exactly ``size`` bytes, or return None if the peer closes first."""
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.sock.recv(min(remaining, self._RECV_CHUNK))
            if not chunk:  # peer closed the connection
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def recv_msg(self) -> Optional[str]:
        """Receive one length-prefixed message.

        Returns:
            The decoded message text, or None when the peer closed the
            connection, a socket error (including a timeout) occurred, or the
            payload was not valid UTF-8.
        """
        try:
            # Read the full 4-byte length prefix.
            header = self._recv_exact(4)
            if header is None:
                return None
            (msg_len,) = struct.unpack('!I', header)

            # Read the message body.
            body = self._recv_exact(msg_len)
            if body is None:
                return None
            return body.decode('utf-8')
        except (OSError, UnicodeDecodeError) as e:
            print(f"recv error: {e}")
            return None

    def handle(self):
        """Receive the device's capability advertisement and register it.

        Only the first message is read here. On success the socket is left
        open, because ``server.call_device`` keeps using it after this thread
        ends. On any failure the socket is closed, since nothing else holds a
        reference that could close it later.
        """
        registered = False
        try:
            # Only the first message (the capability advertisement) is read.
            caps_str = self.recv_msg()
            if not caps_str:
                return
            print(f"[{self.addr}] Capabilities received")
            device = self.server.register_device(caps_str)
            self.device_id = device.device_id
            self.server.set_connection(device.device_id, self)
            registered = True
            print(f"[{self.addr}] Registered as {device.name} (id={device.device_id})")
            # The thread ends here; later traffic is driven by call_device.
        except Exception as e:
            print(f"Registration error: {e}")
        finally:
            # Keep the socket open only when it was handed over to the server.
            if not registered:
                try:
                    self.sock.close()
                except OSError:
                    # Best-effort cleanup of a connection that is already
                    # being abandoned.
                    pass
|
|
|
|
|
|
class YosugaServer:
    """Server that ties together device management, the function registry and AI interaction."""

    def __init__(self, deepseek_api_key: str):
        self.device_manager = DeviceManager()
        self.function_registry = FunctionRegistry()
        self.ai_prompt = AIPromptBuilder()
        # Guards every access to self._connections.
        self._lock = threading.Lock()
        self._call_id_counter = 0

        # Connection table: device_id -> DeviceConnection
        self._connections: dict[str, "DeviceConnection"] = {}

        # Subscribe to device and function-registry changes.
        self.device_manager.on_device_change = self._on_device_change
        self.function_registry.on_change = self._on_functions_change

        # DeepSeek client (OpenAI-compatible endpoint).
        self.ai_client = OpenAI(
            api_key=deepseek_api_key,
            base_url="https://api.deepseek.com",
        )

    @staticmethod
    def _close_connection(conn) -> None:
        """Close the socket of ``conn`` on a best-effort basis."""
        if conn and conn.sock:
            try:
                conn.sock.close()
            except OSError as e:
                # The connection is being dropped anyway; report and move on.
                print(f"Socket close error: {e}")

    def _on_device_change(self, event: str, device: "DeviceInfo"):
        """React to a device-manager event.

        ``"removed"`` drops the device's functions and closes its connection;
        ``"added"``/``"updated"`` publish the device's functions when its
        state value is ``"registered"``.
        """
        print(f"Device event: {event} {device.device_id}")
        if event == "removed":
            self.function_registry.remove_device_functions(device.device_id)
            with self._lock:
                conn = self._connections.pop(device.device_id, None)
            # Close outside the lock so socket teardown never blocks other
            # users of the connection table.
            self._close_connection(conn)
        elif event in ("added", "updated"):
            if device.state.value == "registered":
                self.function_registry.add_device_functions(
                    device.device_id,
                    device.name,
                    device.functions or [],
                )

    def _on_functions_change(self):
        """Log the new function count whenever the registry changes."""
        print(f"Functions updated, total: {self.function_registry.function_count()}")

    def register_device(self, caps_json: str) -> "DeviceInfo":
        """Parse a capability advertisement (JSON text) and register the device.

        Raises:
            json.JSONDecodeError: if ``caps_json`` is not valid JSON.
        """
        data = json.loads(caps_json)
        return self.device_manager.register_from_json(data)

    def remove_device(self, device_id: str):
        """Ask the device manager to remove ``device_id``."""
        self.device_manager.remove_device(device_id)

    def set_connection(self, device_id: str, conn: "DeviceConnection"):
        """Associate ``conn`` with ``device_id``.

        If a different connection was already stored for the same id (for
        example the device reconnected), its socket is closed so it does not
        linger.
        """
        with self._lock:
            previous = self._connections.get(device_id)
            self._connections[device_id] = conn
        if (
            previous is not None
            and previous is not conn
            and previous.sock is not conn.sock
        ):
            self._close_connection(previous)

    def call_device(self, device_id: str, method: str, params: dict, call_id: int) -> dict:
        """Send one JSON-RPC call to a device and wait for its reply.

        Returns:
            ``{"result": ...}`` on success, otherwise
            ``{"error": {"code": ..., "message": ...}}``. A timeout, a closed
            connection or any exception also removes the device.
        """
        # The lock is held only for the lookup: remove_device() below ends up
        # in _on_device_change(), which takes the same non-reentrant lock.
        with self._lock:
            conn = self._connections.get(device_id)
        if not conn:
            return {"error": {"code": RPCError.DEVICE_NOT_FOUND,
                              "message": f"Device {device_id} not connected"}}

        req = JSONRPCHandler.build_call(method, params, call_id)
        try:
            conn.send_msg(req)
            resp_str = conn.recv_msg_with_timeout(timeout=5.0)
            if resp_str is None:
                # Timed out or disconnected: clean the device up automatically.
                self.remove_device(device_id)
                return {"error": {"code": RPCError.TIMEOUT, "message": "Device timeout or disconnected"}}

            resp = JSONRPCHandler.parse_response(resp_str)
            if resp and resp.is_success():
                return {"result": resp.result}
            elif resp and resp.error:
                return {"error": resp.error.to_dict()}
            else:
                return {"error": {"code": RPCError.PARSE_ERROR, "message": "Invalid response"}}
        except Exception as e:
            self.remove_device(device_id)
            return {"error": {"code": RPCError.DEVICE_ERROR, "message": str(e)}}

    def process_device_message(self, device_id: str, message: str) -> str:
        """Handle a message sent spontaneously by a device (e.g. an RPC response).

        In this architecture every exchange is initiated by the server, so
        this is only a placeholder that returns an empty string.
        """
        return ""

    def handle_user_request(self, user_input: str) -> str:
        """Handle a natural-language request: ask the AI, run the calls, return a summary.

        Returns:
            One line per executed call joined with newlines, or an error
            message when the AI call fails or its reply cannot be parsed.
        """
        # 1. Build the system prompt from all currently registered functions.
        func_list = self.function_registry.to_function_list()
        system_prompt = self.ai_prompt.build_system_prompt(func_list)

        # 2. Call DeepSeek.
        print("Calling DeepSeek...")
        try:
            response = self.ai_client.chat.completions.create(
                model="deepseek-chat",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_input},
                ],
                temperature=0.0,  # keep the output stable
            )
            # ``content`` may be None; treat that as an empty reply.
            ai_text = (response.choices[0].message.content or "").strip()
            print(f"AI response:\n{ai_text}")
        except Exception as e:
            return f"AI error: {e}"

        # 3. Parse the calls out of the AI reply.
        calls = self.ai_prompt.parse_ai_response(ai_text)
        if not calls:
            return "Failed to parse AI response as JSON-RPC calls"

        # 4. Execute the calls one by one; ids are simply 1, 2, 3, ...
        results = []
        for call_id, call in enumerate(calls, start=1):
            method = call.get("method")
            # An explicit ``"params": null`` is treated like a missing key.
            params = call.get("params") or {}

            # Find the device that owns this function.
            func_info = self.function_registry.get_function(method)
            if not func_info:
                results.append(f"❌ Unknown function: {method}")
                continue

            device_id = func_info.device_id
            print(f"Routing '{method}' to device {device_id}")

            # Send the call to the device.
            dev_resp = self.call_device(device_id, method, params, call_id)
            if "error" in dev_resp:
                err = dev_resp["error"]
                # Device-supplied error objects are not guaranteed to carry a
                # "message" key; fall back to the whole error value.
                message = err.get("message", err) if isinstance(err, dict) else err
                results.append(f"❌ {method}: {message}")
            else:
                results.append(f"✅ {method}: {dev_resp.get('result', 'ok')}")

        # 5. Summarise and return.
        summary = "\n".join(results)
        print(f"Task result:\n{summary}")
        return summary
|
|
|
|
|
|
# ========== 修改 DeviceConnection 的 recv_msg 支持超时 ==========
|
|
|
|
def recv_msg_with_timeout(self, timeout: float = 5.0) -> Optional[str]:
    """Receive one length-prefixed message, waiting at most ``timeout`` seconds.

    The socket timeout is applied only for the duration of this call and is
    reset to blocking mode (``None``) afterwards, whatever the outcome.

    Returns:
        Whatever ``self.recv_msg()`` returns, or None if a ``socket.timeout``
        propagates out of it.
    """
    connection_socket = self.sock
    connection_socket.settimeout(timeout)
    try:
        message = self.recv_msg()
    except socket.timeout:
        message = None
    finally:
        # Always restore blocking mode for later users of the socket.
        connection_socket.settimeout(None)
    return message
|
|
|
|
# Attach the module-level helper to DeviceConnection as a method, so that
# connection instances can be called as conn.recv_msg_with_timeout(...).
DeviceConnection.recv_msg_with_timeout = recv_msg_with_timeout  # monkey-patch
|
|
|
|
|
|
# ========== 主函数 ==========
|
|
def main():
    """Start the device listener thread and run the interactive request loop.

    The DeepSeek API key is read from the ``DEEPSEEK_API_KEY`` environment
    variable; when it is unset an empty key is used, as before.
    """
    server = YosugaServer(os.environ.get("DEEPSEEK_API_KEY", ""))

    def accept_connections():
        """Accept device connections on TCP port 9555, one handler thread each."""
        # The ``with`` block closes the listening socket if the loop ever
        # exits through an exception.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listener.bind(('0.0.0.0', 9555))
            listener.listen(5)
            print("Server listening on port 9555...")
            while True:
                sock, addr = listener.accept()
                print(f"New connection from {addr}")
                conn = DeviceConnection(sock, addr, server)
                # The device_id is only known once the capability
                # advertisement arrives, so handle() is the one that calls
                # server.set_connection.
                threading.Thread(target=conn.handle, daemon=True).start()

    # Listen for device connections in the background.
    threading.Thread(target=accept_connections, daemon=True).start()

    # Simple interactive loop.
    print("\nEnter your requests (type 'quit' to exit):")
    while True:
        try:
            user_input = input("> ")
        except EOFError:
            break
        # Strip first so surrounding whitespace does not hide the command.
        command = user_input.strip()
        if command.lower() in ('quit', 'exit'):
            break
        if not command:
            continue
        result = server.handle_user_request(user_input)
        print("---\n" + result + "\n---")
|
|
|
|
# Run the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()