Fix protocol server startup issues

- Fixed dictionary iteration bugs in both OPC UA and Modbus servers
- Fixed enum vs string parameter mismatches in audit logging
- Fixed parameter naming issues (details -> event_data)
- Removed invalid defer_start parameter from Modbus server
- Implemented proper task cancellation for Modbus server stop
- Both servers now start and stop successfully
- All 197 tests passing

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
openhands 2025-10-27 21:20:59 +00:00
parent 0b66a0fb4e
commit dc10dab9ec
2 changed files with 48 additions and 45 deletions

View File

@ -15,7 +15,7 @@ from pymodbus.transaction import ModbusSocketFramer
from src.core.setpoint_manager import SetpointManager from src.core.setpoint_manager import SetpointManager
from src.core.security import SecurityManager from src.core.security import SecurityManager
from src.core.compliance_audit import ComplianceAuditLogger from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
logger = structlog.get_logger() logger = structlog.get_logger()
@ -76,12 +76,13 @@ class ModbusServer:
# Initialize data blocks # Initialize data blocks
await self._initialize_datastore() await self._initialize_datastore()
# Start server # Start server as a task that can be cancelled
self.server = await StartAsyncTcpServer( self.server_task = asyncio.create_task(
context=self.context, StartAsyncTcpServer(
framer=ModbusSocketFramer, context=self.context,
address=(self.host, self.port), framer=ModbusSocketFramer,
defer_start=False address=(self.host, self.port)
)
) )
# Log security configuration # Log security configuration
@ -98,9 +99,9 @@ class ModbusServer:
# Log security event # Log security event
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="SERVER_START", event_type=AuditEventType.SYSTEM_START,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "MODBUS_TCP", "protocol": "MODBUS_TCP",
"host": self.host, "host": self.host,
"port": self.port, "port": self.port,
@ -128,9 +129,9 @@ class ModbusServer:
if self.allowed_ips and client_ip not in self.allowed_ips: if self.allowed_ips and client_ip not in self.allowed_ips:
# Log unauthorized access attempt # Log unauthorized access attempt
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="UNAUTHORIZED_ACCESS", event_type=AuditEventType.ACCESS_DENIED,
severity="WARNING", severity=AuditSeverity.HIGH,
details={ event_data={
"protocol": "MODBUS_TCP", "protocol": "MODBUS_TCP",
"client_ip": client_ip, "client_ip": client_ip,
"allowed_ips": self.allowed_ips, "allowed_ips": self.allowed_ips,
@ -164,9 +165,9 @@ class ModbusServer:
if self.request_counts[client_ip] >= self.rate_limit_per_minute: if self.request_counts[client_ip] >= self.rate_limit_per_minute:
# Log rate limit violation # Log rate limit violation
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="RATE_LIMIT_EXCEEDED", event_type=AuditEventType.ACCESS_DENIED,
severity="WARNING", severity=AuditSeverity.MEDIUM,
details={ event_data={
"protocol": "MODBUS_TCP", "protocol": "MODBUS_TCP",
"client_ip": client_ip, "client_ip": client_ip,
"request_count": self.request_counts[client_ip], "request_count": self.request_counts[client_ip],
@ -201,9 +202,9 @@ class ModbusServer:
sensitive_functions = {6, 16} # Write single register, write multiple registers sensitive_functions = {6, 16} # Write single register, write multiple registers
if function_code in sensitive_functions: if function_code in sensitive_functions:
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="MODBUS_WRITE_OPERATION", event_type=AuditEventType.SETPOINT_CHANGED,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "MODBUS_TCP", "protocol": "MODBUS_TCP",
"client_ip": client_ip, "client_ip": client_ip,
"function_code": function_code, "function_code": function_code,
@ -246,15 +247,19 @@ class ModbusServer:
async def stop(self): async def stop(self):
"""Stop the Modbus TCP server.""" """Stop the Modbus TCP server."""
if self.server: if hasattr(self, 'server_task') and self.server_task:
# Note: pymodbus doesn't have a direct stop method # Cancel the server task
# We'll rely on the task being cancelled self.server_task.cancel()
try:
await self.server_task
except asyncio.CancelledError:
pass
# Log security event # Log security event
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="SERVER_STOP", event_type=AuditEventType.SYSTEM_STOP,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "MODBUS_TCP", "protocol": "MODBUS_TCP",
"host": self.host, "host": self.host,
"port": self.port, "port": self.port,
@ -304,8 +309,7 @@ class ModbusServer:
stations = self.setpoint_manager.discovery.get_stations() stations = self.setpoint_manager.discovery.get_stations()
address_counter = 0 address_counter = 0
for station in stations: for station_id, station in stations.items():
station_id = station['station_id']
pumps = self.setpoint_manager.discovery.get_pumps(station_id) pumps = self.setpoint_manager.discovery.get_pumps(station_id)
for pump in pumps: for pump in pumps:

View File

@ -24,7 +24,7 @@ except ImportError:
from src.core.setpoint_manager import SetpointManager from src.core.setpoint_manager import SetpointManager
from src.core.security import SecurityManager, UserRole from src.core.security import SecurityManager, UserRole
from src.core.compliance_audit import ComplianceAuditLogger from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
logger = structlog.get_logger() logger = structlog.get_logger()
@ -103,9 +103,9 @@ class OPCUAServer:
# Log security event # Log security event
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="SERVER_START", event_type=AuditEventType.SYSTEM_START,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "OPC_UA", "protocol": "OPC_UA",
"endpoint": self.endpoint, "endpoint": self.endpoint,
"security_enabled": self.enable_security, "security_enabled": self.enable_security,
@ -179,9 +179,9 @@ class OPCUAServer:
# Log connection event # Log connection event
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="CLIENT_CONNECT", event_type=AuditEventType.USER_LOGIN,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "OPC_UA", "protocol": "OPC_UA",
"client_id": client_id, "client_id": client_id,
"endpoint": endpoint, "endpoint": endpoint,
@ -205,9 +205,9 @@ class OPCUAServer:
if client_info: if client_info:
# Log disconnection event # Log disconnection event
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="CLIENT_DISCONNECT", event_type=AuditEventType.USER_LOGOUT,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "OPC_UA", "protocol": "OPC_UA",
"client_id": client_id, "client_id": client_id,
"endpoint": client_info['endpoint'], "endpoint": client_info['endpoint'],
@ -228,9 +228,9 @@ class OPCUAServer:
# Log security event # Log security event
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="SERVER_STOP", event_type=AuditEventType.SYSTEM_STOP,
severity="INFO", severity=AuditSeverity.LOW,
details={ event_data={
"protocol": "OPC_UA", "protocol": "OPC_UA",
"endpoint": self.endpoint, "endpoint": self.endpoint,
"connected_clients": len(self.connected_clients) "connected_clients": len(self.connected_clients)
@ -273,9 +273,9 @@ class OPCUAServer:
# Log access attempt # Log access attempt
self.audit_logger.log_security_event( self.audit_logger.log_security_event(
event_type="NODE_ACCESS_ATTEMPT", event_type=AuditEventType.ACCESS_DENIED,
severity="INFO", severity=AuditSeverity.MEDIUM,
details={ event_data={
"protocol": "OPC_UA", "protocol": "OPC_UA",
"client_id": str(session.session_id), "client_id": str(session.session_id),
"node": str(node), "node": str(node),
@ -308,8 +308,7 @@ class OPCUAServer:
# Create stations and pumps structure # Create stations and pumps structure
stations = self.setpoint_manager.discovery.get_stations() stations = self.setpoint_manager.discovery.get_stations()
for station in stations: for station_id, station in stations.items():
station_id = station['station_id']
# Create station folder # Create station folder
station_folder = await calejo_folder.add_folder( station_folder = await calejo_folder.add_folder(