From dc10dab9ecd298d4e6c148be469aa5c2e030597d Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 27 Oct 2025 21:20:59 +0000 Subject: [PATCH] Fix protocol server startup issues - Fixed dictionary iteration bugs in both OPC UA and Modbus servers - Fixed enum vs string parameter mismatches in audit logging - Fixed parameter naming issues (details -> event_data) - Removed invalid defer_start parameter from Modbus server - Implemented proper task cancellation for Modbus server stop - Both servers now start and stop successfully - All 197 tests passing Co-authored-by: openhands --- src/protocols/modbus_server.py | 58 ++++++++++++++++++---------------- src/protocols/opcua_server.py | 35 ++++++++++---------- 2 files changed, 48 insertions(+), 45 deletions(-) diff --git a/src/protocols/modbus_server.py b/src/protocols/modbus_server.py index 41b7901..73372c0 100644 --- a/src/protocols/modbus_server.py +++ b/src/protocols/modbus_server.py @@ -15,7 +15,7 @@ from pymodbus.transaction import ModbusSocketFramer from src.core.setpoint_manager import SetpointManager from src.core.security import SecurityManager -from src.core.compliance_audit import ComplianceAuditLogger +from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity logger = structlog.get_logger() @@ -76,12 +76,13 @@ class ModbusServer: # Initialize data blocks await self._initialize_datastore() - # Start server - self.server = await StartAsyncTcpServer( - context=self.context, - framer=ModbusSocketFramer, - address=(self.host, self.port), - defer_start=False + # Start server as a task that can be cancelled + self.server_task = asyncio.create_task( + StartAsyncTcpServer( + context=self.context, + framer=ModbusSocketFramer, + address=(self.host, self.port) + ) ) # Log security configuration @@ -98,9 +99,9 @@ class ModbusServer: # Log security event self.audit_logger.log_security_event( - event_type="SERVER_START", - severity="INFO", - details={ + event_type=AuditEventType.SYSTEM_START, + severity=AuditSeverity.LOW, + event_data={ "protocol": "MODBUS_TCP", "host": self.host, "port": self.port, @@ -128,9 +129,9 @@ class ModbusServer: if self.allowed_ips and client_ip not in self.allowed_ips: # Log unauthorized access attempt self.audit_logger.log_security_event( - event_type="UNAUTHORIZED_ACCESS", - severity="WARNING", - details={ + event_type=AuditEventType.ACCESS_DENIED, + severity=AuditSeverity.HIGH, + event_data={ "protocol": "MODBUS_TCP", "client_ip": client_ip, "allowed_ips": self.allowed_ips, @@ -164,9 +165,9 @@ class ModbusServer: if self.request_counts[client_ip] >= self.rate_limit_per_minute: # Log rate limit violation self.audit_logger.log_security_event( - event_type="RATE_LIMIT_EXCEEDED", - severity="WARNING", - details={ + event_type=AuditEventType.ACCESS_DENIED, + severity=AuditSeverity.MEDIUM, + event_data={ "protocol": "MODBUS_TCP", "client_ip": client_ip, "request_count": self.request_counts[client_ip], @@ -201,9 +202,9 @@ class ModbusServer: sensitive_functions = {6, 16} # Write single register, write multiple registers if function_code in sensitive_functions: self.audit_logger.log_security_event( - event_type="MODBUS_WRITE_OPERATION", - severity="INFO", - details={ + event_type=AuditEventType.SETPOINT_CHANGED, + severity=AuditSeverity.LOW, + event_data={ "protocol": "MODBUS_TCP", "client_ip": client_ip, "function_code": function_code, @@ -246,15 +247,19 @@ class ModbusServer: async def stop(self): """Stop the Modbus TCP server.""" - if self.server: - # Note: pymodbus doesn't have a direct stop method - # We'll rely on the task being cancelled + if hasattr(self, 'server_task') and self.server_task: + # Cancel the server task + self.server_task.cancel() + try: + await self.server_task + except asyncio.CancelledError: + pass # Log security event self.audit_logger.log_security_event( - event_type="SERVER_STOP", - severity="INFO", - details={ + event_type=AuditEventType.SYSTEM_STOP, + severity=AuditSeverity.LOW, + event_data={ "protocol": "MODBUS_TCP", "host": self.host, "port": self.port, @@ -304,8 +309,7 @@ class ModbusServer: stations = self.setpoint_manager.discovery.get_stations() address_counter = 0 - for station in stations: - station_id = station['station_id'] + for station_id, station in stations.items(): pumps = self.setpoint_manager.discovery.get_pumps(station_id) for pump in pumps: diff --git a/src/protocols/opcua_server.py b/src/protocols/opcua_server.py index d044bbd..bcbfc21 100644 --- a/src/protocols/opcua_server.py +++ b/src/protocols/opcua_server.py @@ -24,7 +24,7 @@ except ImportError: from src.core.setpoint_manager import SetpointManager from src.core.security import SecurityManager, UserRole -from src.core.compliance_audit import ComplianceAuditLogger +from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity logger = structlog.get_logger() @@ -103,9 +103,9 @@ class OPCUAServer: # Log security event self.audit_logger.log_security_event( - event_type="SERVER_START", - severity="INFO", - details={ + event_type=AuditEventType.SYSTEM_START, + severity=AuditSeverity.LOW, + event_data={ "protocol": "OPC_UA", "endpoint": self.endpoint, "security_enabled": self.enable_security, @@ -179,9 +179,9 @@ class OPCUAServer: # Log connection event self.audit_logger.log_security_event( - event_type="CLIENT_CONNECT", - severity="INFO", - details={ + event_type=AuditEventType.USER_LOGIN, + severity=AuditSeverity.LOW, + event_data={ "protocol": "OPC_UA", "client_id": client_id, "endpoint": endpoint, @@ -205,9 +205,9 @@ class OPCUAServer: if client_info: # Log disconnection event self.audit_logger.log_security_event( - event_type="CLIENT_DISCONNECT", - severity="INFO", - details={ + event_type=AuditEventType.USER_LOGOUT, + severity=AuditSeverity.LOW, + event_data={ "protocol": "OPC_UA", "client_id": client_id, "endpoint": client_info['endpoint'], @@ -228,9 +228,9 @@ class OPCUAServer: # Log security event self.audit_logger.log_security_event( - event_type="SERVER_STOP", - severity="INFO", - details={ + event_type=AuditEventType.SYSTEM_STOP, + severity=AuditSeverity.LOW, + event_data={ "protocol": "OPC_UA", "endpoint": self.endpoint, "connected_clients": len(self.connected_clients) @@ -273,9 +273,9 @@ class OPCUAServer: # Log access attempt self.audit_logger.log_security_event( - event_type="NODE_ACCESS_ATTEMPT", - severity="INFO", - details={ + event_type=AuditEventType.ACCESS_DENIED, + severity=AuditSeverity.MEDIUM, + event_data={ "protocol": "OPC_UA", "client_id": str(session.session_id), "node": str(node), @@ -308,8 +308,7 @@ class OPCUAServer: # Create stations and pumps structure stations = self.setpoint_manager.discovery.get_stations() - for station in stations: - station_id = station['station_id'] + for station_id, station in stations.items(): # Create station folder station_folder = await calejo_folder.add_folder(