Fix protocol server startup issues
- Fixed dictionary iteration bugs in both OPC UA and Modbus servers - Fixed enum vs string parameter mismatches in audit logging - Fixed parameter naming issues (details -> event_data) - Removed invalid defer_start parameter from Modbus server - Implemented proper task cancellation for Modbus server stop - Both servers now start and stop successfully - All 197 tests passing Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
0b66a0fb4e
commit
dc10dab9ec
|
|
@ -15,7 +15,7 @@ from pymodbus.transaction import ModbusSocketFramer
|
||||||
|
|
||||||
from src.core.setpoint_manager import SetpointManager
|
from src.core.setpoint_manager import SetpointManager
|
||||||
from src.core.security import SecurityManager
|
from src.core.security import SecurityManager
|
||||||
from src.core.compliance_audit import ComplianceAuditLogger
|
from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
@ -76,12 +76,13 @@ class ModbusServer:
|
||||||
# Initialize data blocks
|
# Initialize data blocks
|
||||||
await self._initialize_datastore()
|
await self._initialize_datastore()
|
||||||
|
|
||||||
# Start server
|
# Start server as a task that can be cancelled
|
||||||
self.server = await StartAsyncTcpServer(
|
self.server_task = asyncio.create_task(
|
||||||
|
StartAsyncTcpServer(
|
||||||
context=self.context,
|
context=self.context,
|
||||||
framer=ModbusSocketFramer,
|
framer=ModbusSocketFramer,
|
||||||
address=(self.host, self.port),
|
address=(self.host, self.port)
|
||||||
defer_start=False
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log security configuration
|
# Log security configuration
|
||||||
|
|
@ -98,9 +99,9 @@ class ModbusServer:
|
||||||
|
|
||||||
# Log security event
|
# Log security event
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="SERVER_START",
|
event_type=AuditEventType.SYSTEM_START,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "MODBUS_TCP",
|
"protocol": "MODBUS_TCP",
|
||||||
"host": self.host,
|
"host": self.host,
|
||||||
"port": self.port,
|
"port": self.port,
|
||||||
|
|
@ -128,9 +129,9 @@ class ModbusServer:
|
||||||
if self.allowed_ips and client_ip not in self.allowed_ips:
|
if self.allowed_ips and client_ip not in self.allowed_ips:
|
||||||
# Log unauthorized access attempt
|
# Log unauthorized access attempt
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="UNAUTHORIZED_ACCESS",
|
event_type=AuditEventType.ACCESS_DENIED,
|
||||||
severity="WARNING",
|
severity=AuditSeverity.HIGH,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "MODBUS_TCP",
|
"protocol": "MODBUS_TCP",
|
||||||
"client_ip": client_ip,
|
"client_ip": client_ip,
|
||||||
"allowed_ips": self.allowed_ips,
|
"allowed_ips": self.allowed_ips,
|
||||||
|
|
@ -164,9 +165,9 @@ class ModbusServer:
|
||||||
if self.request_counts[client_ip] >= self.rate_limit_per_minute:
|
if self.request_counts[client_ip] >= self.rate_limit_per_minute:
|
||||||
# Log rate limit violation
|
# Log rate limit violation
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="RATE_LIMIT_EXCEEDED",
|
event_type=AuditEventType.ACCESS_DENIED,
|
||||||
severity="WARNING",
|
severity=AuditSeverity.MEDIUM,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "MODBUS_TCP",
|
"protocol": "MODBUS_TCP",
|
||||||
"client_ip": client_ip,
|
"client_ip": client_ip,
|
||||||
"request_count": self.request_counts[client_ip],
|
"request_count": self.request_counts[client_ip],
|
||||||
|
|
@ -201,9 +202,9 @@ class ModbusServer:
|
||||||
sensitive_functions = {6, 16} # Write single register, write multiple registers
|
sensitive_functions = {6, 16} # Write single register, write multiple registers
|
||||||
if function_code in sensitive_functions:
|
if function_code in sensitive_functions:
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="MODBUS_WRITE_OPERATION",
|
event_type=AuditEventType.SETPOINT_CHANGED,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "MODBUS_TCP",
|
"protocol": "MODBUS_TCP",
|
||||||
"client_ip": client_ip,
|
"client_ip": client_ip,
|
||||||
"function_code": function_code,
|
"function_code": function_code,
|
||||||
|
|
@ -246,15 +247,19 @@ class ModbusServer:
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Stop the Modbus TCP server."""
|
"""Stop the Modbus TCP server."""
|
||||||
if self.server:
|
if hasattr(self, 'server_task') and self.server_task:
|
||||||
# Note: pymodbus doesn't have a direct stop method
|
# Cancel the server task
|
||||||
# We'll rely on the task being cancelled
|
self.server_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.server_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Log security event
|
# Log security event
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="SERVER_STOP",
|
event_type=AuditEventType.SYSTEM_STOP,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "MODBUS_TCP",
|
"protocol": "MODBUS_TCP",
|
||||||
"host": self.host,
|
"host": self.host,
|
||||||
"port": self.port,
|
"port": self.port,
|
||||||
|
|
@ -304,8 +309,7 @@ class ModbusServer:
|
||||||
stations = self.setpoint_manager.discovery.get_stations()
|
stations = self.setpoint_manager.discovery.get_stations()
|
||||||
address_counter = 0
|
address_counter = 0
|
||||||
|
|
||||||
for station in stations:
|
for station_id, station in stations.items():
|
||||||
station_id = station['station_id']
|
|
||||||
pumps = self.setpoint_manager.discovery.get_pumps(station_id)
|
pumps = self.setpoint_manager.discovery.get_pumps(station_id)
|
||||||
|
|
||||||
for pump in pumps:
|
for pump in pumps:
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ except ImportError:
|
||||||
|
|
||||||
from src.core.setpoint_manager import SetpointManager
|
from src.core.setpoint_manager import SetpointManager
|
||||||
from src.core.security import SecurityManager, UserRole
|
from src.core.security import SecurityManager, UserRole
|
||||||
from src.core.compliance_audit import ComplianceAuditLogger
|
from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
@ -103,9 +103,9 @@ class OPCUAServer:
|
||||||
|
|
||||||
# Log security event
|
# Log security event
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="SERVER_START",
|
event_type=AuditEventType.SYSTEM_START,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "OPC_UA",
|
"protocol": "OPC_UA",
|
||||||
"endpoint": self.endpoint,
|
"endpoint": self.endpoint,
|
||||||
"security_enabled": self.enable_security,
|
"security_enabled": self.enable_security,
|
||||||
|
|
@ -179,9 +179,9 @@ class OPCUAServer:
|
||||||
|
|
||||||
# Log connection event
|
# Log connection event
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="CLIENT_CONNECT",
|
event_type=AuditEventType.USER_LOGIN,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "OPC_UA",
|
"protocol": "OPC_UA",
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"endpoint": endpoint,
|
"endpoint": endpoint,
|
||||||
|
|
@ -205,9 +205,9 @@ class OPCUAServer:
|
||||||
if client_info:
|
if client_info:
|
||||||
# Log disconnection event
|
# Log disconnection event
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="CLIENT_DISCONNECT",
|
event_type=AuditEventType.USER_LOGOUT,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "OPC_UA",
|
"protocol": "OPC_UA",
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"endpoint": client_info['endpoint'],
|
"endpoint": client_info['endpoint'],
|
||||||
|
|
@ -228,9 +228,9 @@ class OPCUAServer:
|
||||||
|
|
||||||
# Log security event
|
# Log security event
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="SERVER_STOP",
|
event_type=AuditEventType.SYSTEM_STOP,
|
||||||
severity="INFO",
|
severity=AuditSeverity.LOW,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "OPC_UA",
|
"protocol": "OPC_UA",
|
||||||
"endpoint": self.endpoint,
|
"endpoint": self.endpoint,
|
||||||
"connected_clients": len(self.connected_clients)
|
"connected_clients": len(self.connected_clients)
|
||||||
|
|
@ -273,9 +273,9 @@ class OPCUAServer:
|
||||||
|
|
||||||
# Log access attempt
|
# Log access attempt
|
||||||
self.audit_logger.log_security_event(
|
self.audit_logger.log_security_event(
|
||||||
event_type="NODE_ACCESS_ATTEMPT",
|
event_type=AuditEventType.ACCESS_DENIED,
|
||||||
severity="INFO",
|
severity=AuditSeverity.MEDIUM,
|
||||||
details={
|
event_data={
|
||||||
"protocol": "OPC_UA",
|
"protocol": "OPC_UA",
|
||||||
"client_id": str(session.session_id),
|
"client_id": str(session.session_id),
|
||||||
"node": str(node),
|
"node": str(node),
|
||||||
|
|
@ -308,8 +308,7 @@ class OPCUAServer:
|
||||||
# Create stations and pumps structure
|
# Create stations and pumps structure
|
||||||
stations = self.setpoint_manager.discovery.get_stations()
|
stations = self.setpoint_manager.discovery.get_stations()
|
||||||
|
|
||||||
for station in stations:
|
for station_id, station in stations.items():
|
||||||
station_id = station['station_id']
|
|
||||||
|
|
||||||
# Create station folder
|
# Create station folder
|
||||||
station_folder = await calejo_folder.add_folder(
|
station_folder = await calejo_folder.add_folder(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue