# CalejoControl/src/protocols/modbus_server.py

"""
Modbus TCP Server for Calejo Control Adapter.
Provides Modbus TCP interface for SCADA systems to access setpoints and status.
"""
import asyncio
from typing import Dict, Optional, Tuple, Any
from datetime import datetime
import structlog
from pymodbus.server import StartAsyncTcpServer
from pymodbus.datastore import ModbusSequentialDataBlock
from pymodbus.datastore import ModbusSlaveContext, ModbusServerContext
from pymodbus.transaction import ModbusSocketFramer
from src.core.setpoint_manager import SetpointManager
from src.core.security import SecurityManager
from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
# Module-wide structlog logger; ModbusServer emits all of its structured
# events (start/stop, register-update failures, client cleanup) through it.
logger = structlog.get_logger()
class ModbusServer:
    """Modbus TCP server exposing Calejo Control Adapter state to SCADA systems.

    Register layout (addresses as used with ``zero_mode=True``):

    * Holding registers 0-99: per-pump setpoints in tenths of a Hz
      (``round(Hz * 10)``, clamped to the unsigned 16-bit range 0..65535).
    * Input registers 100-199: per-pump status codes
      (0 = normal, 1 = failsafe, 2 = emergency stop).
    * Input registers 200-299: per-pump safety codes
      (0 = normal, 1 = emergency stop, 2 = failsafe).
    * Input registers 300-304: security summary (security-enabled flag,
      tracked client count, rate limit, allowed-IP count, requests counted
      in the clients' current rate windows).
    * Coil 0: any emergency stop active; coil 1: any pump in failsafe.

    A background task refreshes these values every 5 seconds. Client writes
    to the holding registers / coils are not handled in this class and may be
    overwritten by the next refresh.
    """

    # Upper bound on mapped pumps. Each per-pump band (setpoint, status,
    # safety) is 100 registers wide, so this must not exceed 100 or the
    # status band would spill into the safety band.
    MAX_PUMPS = 100

    # Setpoints are published as integer tenths of a Hz.
    SETPOINT_SCALE = 10

    def __init__(
        self,
        setpoint_manager: SetpointManager,
        security_manager: SecurityManager,
        audit_logger: ComplianceAuditLogger,
        host: str = "0.0.0.0",
        port: int = 502,
        unit_id: int = 1,
        enable_security: bool = True,
        allowed_ips: Optional[list] = None,
        rate_limit_per_minute: int = 60
    ):
        """Store configuration; no I/O happens until ``start()``.

        Args:
            setpoint_manager: Source of pump setpoints, discovery data,
                emergency-stop and watchdog state.
            security_manager: Stored for use by callers/other code; not
                referenced elsewhere in this class.
            audit_logger: Receives compliance/security audit events.
            host: Interface to bind the TCP server to.
            port: TCP port to listen on.
            unit_id: Modbus unit id; only reported in logs here. The context
                is built with ``single=True`` (presumably served for every
                unit id — confirm against the pymodbus version in use).
            enable_security: When False, IP and rate-limit checks always pass.
            allowed_ips: IP allow-list; empty/None means all IPs are allowed.
            rate_limit_per_minute: Allowed requests per client per 60 s window.
        """
        self.setpoint_manager = setpoint_manager
        self.security_manager = security_manager
        self.audit_logger = audit_logger
        self.host = host
        self.port = port
        self.unit_id = unit_id
        self.enable_security = enable_security
        self.allowed_ips = allowed_ips or []
        self.rate_limit_per_minute = rate_limit_per_minute
        self.server = None
        self.server_task: Optional[asyncio.Task] = None
        self.context = None
        # Strong references to the background loops, so the event loop cannot
        # garbage-collect them mid-execution and so stop() can cancel them.
        self._background_tasks: set = set()

        # Security tracking
        self.connected_clients: Dict[str, Dict] = {}  # client_ip -> client_info
        self.request_counts: Dict[str, int] = {}  # client_ip -> requests in current window
        self.last_request_time: Dict[str, datetime] = {}  # client_ip -> last allowed request
        self.rate_window_start: Dict[str, datetime] = {}  # client_ip -> start of current window

        # Memory mapping (created in _initialize_datastore)
        self.holding_registers = None
        self.input_registers = None
        self.coils = None

        # Register mapping configuration
        self.REGISTER_CONFIG = {
            'SETPOINT_BASE': 0,  # Holding register 0-99: Setpoints (Hz * 10)
            'STATUS_BASE': 100,  # Input register 100-199: Status codes
            'SAFETY_BASE': 200,  # Input register 200-299: Safety status
            'EMERGENCY_STOP_COIL': 0,  # Coil 0: Emergency stop status
            'FAILSAFE_COIL': 1,  # Coil 1: Failsafe mode status
            'SECURITY_STATUS_BASE': 300,  # Input register 300-399: Security status
        }

        # (station_id, pump_id) -> {'setpoint_register', 'status_register',
        # 'safety_register'} absolute register addresses
        self.pump_addresses: Dict[Tuple[str, str], Dict[str, int]] = {}

    async def start(self):
        """Initialise the datastore, then launch the server and background loops.

        The TCP server runs inside ``self.server_task``; failures that happen
        after scheduling (e.g. the port is already in use) surface
        asynchronously and are logged by the task's done-callback rather
        than raised from here.

        Raises:
            Any exception from datastore initialisation or audit logging
            (logged as ``failed_to_start_modbus_server`` and re-raised).
        """
        try:
            await self._initialize_datastore()

            # Run the server as a task so stop() can cancel it.
            self.server_task = asyncio.create_task(
                StartAsyncTcpServer(
                    context=self.context,
                    framer=ModbusSocketFramer,
                    address=(self.host, self.port)
                )
            )
            self.server_task.add_done_callback(self._on_task_done)

            security_mode = "secure" if self.enable_security else "insecure"
            logger.info(
                "modbus_server_started",
                host=self.host,
                port=self.port,
                unit_id=self.unit_id,
                security_mode=security_mode,
                allowed_ips_count=len(self.allowed_ips),
                rate_limit=self.rate_limit_per_minute
            )

            self.audit_logger.log_security_event(
                event_type=AuditEventType.SYSTEM_START,
                severity=AuditSeverity.LOW,
                event_data={
                    "protocol": "MODBUS_TCP",
                    "host": self.host,
                    "port": self.port,
                    "security_enabled": self.enable_security,
                    "allowed_ips_count": len(self.allowed_ips),
                    "rate_limit_per_minute": self.rate_limit_per_minute
                }
            )

            self._spawn_background_task(self._update_registers_loop())
            self._spawn_background_task(self._security_monitoring_loop())

        except Exception as e:
            logger.error("failed_to_start_modbus_server", error=str(e))
            raise

    def _spawn_background_task(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Drop a finished task's reference and log it if it died with an error."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("modbus_task_failed", error=str(exc))

    def _check_ip_access(self, client_ip: str) -> bool:
        """Return True if ``client_ip`` may connect.

        Passes when security is disabled or the allow-list is empty.
        Otherwise, an IP outside the allow-list is audited as ACCESS_DENIED
        (HIGH) and rejected.

        NOTE(review): nothing in this class calls this check from pymodbus's
        request path; confirm a caller does, otherwise it is not enforced.
        """
        if not self.enable_security:
            return True

        if self.allowed_ips and client_ip not in self.allowed_ips:
            self.audit_logger.log_security_event(
                event_type=AuditEventType.ACCESS_DENIED,
                severity=AuditSeverity.HIGH,
                event_data={
                    "protocol": "MODBUS_TCP",
                    "client_ip": client_ip,
                    "allowed_ips": self.allowed_ips,
                    "reason": "IP not in allowed list"
                }
            )
            return False

        return True

    def _check_rate_limit(self, client_ip: str) -> bool:
        """Return True if ``client_ip`` is within its per-minute request budget.

        Uses a fixed 60-second window per client. The window opens at the
        first request and the counter resets once 60 s have passed since
        the window opened. Allowed requests are counted. A rejected request
        is audited as ACCESS_DENIED (MEDIUM) and is not counted. Always True
        when security is disabled.

        NOTE(review): like _check_ip_access, not wired into pymodbus here.
        """
        if not self.enable_security:
            return True

        current_time = datetime.now()

        # Measure from the window start, not from the previous request. A
        # window measured from the previous request never resets for a
        # client that polls steadily more often than once a minute. Such a
        # client would eventually be locked out despite staying under the limit.
        window_start = self.rate_window_start.get(client_ip)
        if window_start is None or (current_time - window_start).total_seconds() >= 60:
            self.rate_window_start[client_ip] = current_time
            self.request_counts[client_ip] = 0

        count = self.request_counts.setdefault(client_ip, 0)
        if count >= self.rate_limit_per_minute:
            self.audit_logger.log_security_event(
                event_type=AuditEventType.ACCESS_DENIED,
                severity=AuditSeverity.MEDIUM,
                event_data={
                    "protocol": "MODBUS_TCP",
                    "client_ip": client_ip,
                    "request_count": count,
                    "rate_limit": self.rate_limit_per_minute
                }
            )
            return False

        self.request_counts[client_ip] = count + 1
        self.last_request_time[client_ip] = current_time
        return True

    def _log_client_request(self, client_ip: str, function_code: int, register_address: int):
        """Record a client request in ``connected_clients`` and audit writes.

        Write-single-register (6) and write-multiple-registers (16) requests
        are additionally audited as SETPOINT_CHANGED.
        """
        now = datetime.now()

        client_info = self.connected_clients.setdefault(client_ip, {
            'first_seen': now,
            'last_seen': now,
            'request_count': 0,
            'function_codes': set()
        })
        client_info['last_seen'] = now
        client_info['request_count'] += 1
        client_info['function_codes'].add(function_code)

        sensitive_functions = {6, 16}  # Write single register, write multiple registers
        if function_code in sensitive_functions:
            self.audit_logger.log_security_event(
                event_type=AuditEventType.SETPOINT_CHANGED,
                severity=AuditSeverity.LOW,
                event_data={
                    "protocol": "MODBUS_TCP",
                    "client_ip": client_ip,
                    "function_code": function_code,
                    "register_address": register_address,
                    "timestamp": now.isoformat()
                }
            )

    async def _security_monitoring_loop(self):
        """Prune inactive clients every 60 s; after an error, retry in 10 s."""
        while True:
            try:
                await self._cleanup_old_clients()
                await asyncio.sleep(60)
            except Exception as e:
                logger.error("security_monitoring_error", error=str(e))
                await asyncio.sleep(10)

    def _last_activity(self, client_ip: str) -> datetime:
        """Most recent timestamp recorded for ``client_ip`` by any tracker."""
        stamps = [
            ts for ts in (
                self.rate_window_start.get(client_ip),
                self.last_request_time.get(client_ip),
            )
            if ts is not None
        ]
        info = self.connected_clients.get(client_ip)
        if info is not None:
            stamps.append(info['last_seen'])
        # An IP with no timestamp at all is treated as infinitely stale.
        return max(stamps, default=datetime.min)

    async def _cleanup_old_clients(self):
        """Forget clients with no activity for more than 30 minutes.

        Every per-client dict is scanned, not just ``connected_clients``. IPs
        that only passed through the rate limiter are therefore pruned too,
        instead of accumulating forever.
        """
        current_time = datetime.now()
        timeout_seconds = 30 * 60

        tracked_ips = (
            set(self.connected_clients)
            | set(self.request_counts)
            | set(self.last_request_time)
            | set(self.rate_window_start)
        )
        clients_to_remove = [
            client_ip for client_ip in tracked_ips
            if (current_time - self._last_activity(client_ip)).total_seconds() > timeout_seconds
        ]

        for client_ip in clients_to_remove:
            self.connected_clients.pop(client_ip, None)
            self.request_counts.pop(client_ip, None)
            self.last_request_time.pop(client_ip, None)
            self.rate_window_start.pop(client_ip, None)
            logger.info(
                "modbus_client_removed",
                client_ip=client_ip,
                reason="inactivity"
            )

    async def stop(self):
        """Cancel the server and background loops, then audit the shutdown.

        Cancellation results and earlier task failures are collected rather
        than raised. Failures were already logged by the done-callback.
        """
        tasks = list(self._background_tasks)
        if self.server_task is not None:
            tasks.append(self.server_task)

        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        self.server_task = None
        self._background_tasks.clear()

        self.audit_logger.log_security_event(
            event_type=AuditEventType.SYSTEM_STOP,
            severity=AuditSeverity.LOW,
            event_data={
                "protocol": "MODBUS_TCP",
                "host": self.host,
                "port": self.port,
                "connected_clients": len(self.connected_clients)
            }
        )

        logger.info("modbus_server_stopping", connected_clients=len(self.connected_clients))

    async def _initialize_datastore(self):
        """Create the data blocks and server context, then map pumps to registers."""
        # Holding registers 0-99: setpoints
        self.holding_registers = ModbusSequentialDataBlock(
            self.REGISTER_CONFIG['SETPOINT_BASE'],
            [0] * self.MAX_PUMPS
        )

        # Input registers 100-399: status, safety, and security bands
        self.input_registers = ModbusSequentialDataBlock(
            self.REGISTER_CONFIG['STATUS_BASE'],
            [0] * 300
        )

        # Coils 0-9: binary status flags
        self.coils = ModbusSequentialDataBlock(
            self.REGISTER_CONFIG['EMERGENCY_STOP_COIL'],
            [False] * 10
        )

        store = ModbusSlaveContext(
            hr=self.holding_registers,
            ir=self.input_registers,
            co=self.coils,
            zero_mode=True
        )
        self.context = ModbusServerContext(slaves=store, single=True)

        await self._initialize_pump_mapping()

    async def _initialize_pump_mapping(self):
        """Assign consecutive register offsets to discovered pumps.

        At most ``MAX_PUMPS`` pumps are mapped. Any further pumps are left
        unmapped and ``modbus_register_limit_reached`` is logged.
        """
        stations = self.setpoint_manager.discovery.get_stations()
        address_counter = 0

        for station_id in stations:
            for pump in self.setpoint_manager.discovery.get_pumps(station_id):
                if address_counter >= self.MAX_PUMPS:
                    # Stop mapping altogether. Breaking only the inner loop
                    # would still map one pump per later station. That pump
                    # would land past the setpoint block, and its status
                    # register would collide with another pump's safety register.
                    logger.warning("modbus_register_limit_reached")
                    return

                pump_id = pump['pump_id']
                self.pump_addresses[(station_id, pump_id)] = {
                    'setpoint_register': address_counter,
                    'status_register': address_counter + self.REGISTER_CONFIG['STATUS_BASE'],
                    'safety_register': address_counter + self.REGISTER_CONFIG['SAFETY_BASE']
                }
                address_counter += 1

    async def _update_registers_loop(self):
        """Refresh all registers every 5 s; after an error, wait 10 s."""
        while True:
            try:
                await self._update_registers()
                await asyncio.sleep(5)
            except Exception as e:
                logger.error("failed_to_update_registers", error=str(e))
                await asyncio.sleep(10)

    @staticmethod
    def _to_register(value: float) -> int:
        """Round ``value`` (ties to even) and clamp it to the 16-bit range 0..65535."""
        return max(0, min(0xFFFF, int(round(value))))

    async def _update_registers(self):
        """Refresh per-pump registers, global status coils and security registers.

        Each part is isolated, so a failure in one does not block the others.
        """
        for (station_id, pump_id), addresses in self.pump_addresses.items():
            try:
                self._update_pump_registers(station_id, pump_id, addresses)
            except Exception as e:
                logger.error(
                    "failed_to_update_pump_registers",
                    station_id=station_id,
                    pump_id=pump_id,
                    error=str(e)
                )

        try:
            self._update_status_coils()
        except Exception as e:
            logger.error("failed_to_update_status_coils", error=str(e))

        # Has its own error handling; run it even if the coil update failed.
        await self._update_security_registers()

    def _update_pump_registers(self, station_id: str, pump_id: str, addresses: Dict[str, int]):
        """Write one pump's setpoint, status and safety registers."""
        setpoint = self.setpoint_manager.get_current_setpoint(station_id, pump_id)
        if setpoint is not None:
            # Round rather than truncate: int(x * 10) biases values low
            # (e.g. a product of 42.299999... would publish 422).
            scaled = round(setpoint * self.SETPOINT_SCALE)
            setpoint_int = self._to_register(scaled)
            if setpoint_int != scaled:
                logger.warning(
                    "modbus_setpoint_clamped",
                    station_id=station_id,
                    pump_id=pump_id,
                    setpoint=setpoint
                )
            self.holding_registers.setValues(addresses['setpoint_register'], [setpoint_int])

        status_code = 0  # Normal operation
        safety_code = 0  # Normal safety
        if self.setpoint_manager.emergency_stop_manager.is_emergency_stop_active(station_id, pump_id):
            status_code = 2  # Emergency stop
            safety_code = 1
        elif self.setpoint_manager.watchdog.is_failsafe_active(station_id, pump_id):
            status_code = 1  # Failsafe mode
            safety_code = 2

        self.input_registers.setValues(addresses['status_register'], [status_code])
        self.input_registers.setValues(addresses['safety_register'], [safety_code])

    def _update_status_coils(self):
        """Publish the system-wide emergency-stop and failsafe flags to coils 0 and 1."""
        esm = self.setpoint_manager.emergency_stop_manager
        any_emergency_stop = bool(
            esm.system_emergency_stop
            or len(esm.emergency_stop_stations) > 0
            or len(esm.emergency_stop_pumps) > 0
        )

        any_failsafe = any(
            self.setpoint_manager.watchdog.is_failsafe_active(station_id, pump_id)
            for (station_id, pump_id) in self.pump_addresses
        )

        self.coils.setValues(self.REGISTER_CONFIG['EMERGENCY_STOP_COIL'], [any_emergency_stop])
        self.coils.setValues(self.REGISTER_CONFIG['FAILSAFE_COIL'], [any_failsafe])

    def get_pump_setpoint_address(self, station_id: str, pump_id: str) -> Optional[int]:
        """Return the holding-register address of a pump's setpoint, or None if unmapped."""
        addresses = self.pump_addresses.get((station_id, pump_id))
        return addresses['setpoint_register'] if addresses else None

    def get_pump_status_address(self, station_id: str, pump_id: str) -> Optional[int]:
        """Return the input-register address of a pump's status code, or None if unmapped."""
        addresses = self.pump_addresses.get((station_id, pump_id))
        return addresses['status_register'] if addresses else None

    async def _update_security_registers(self):
        """Write the security summary into input registers 300-304.

        Values are clamped to 0..65535 because a Modbus register is 16-bit.
        The last register holds requests counted in the clients' current
        rate windows, not a lifetime total.
        """
        try:
            values = [
                self._to_register(value)
                for value in (
                    1 if self.enable_security else 0,
                    len(self.connected_clients),
                    self.rate_limit_per_minute,
                    len(self.allowed_ips),
                    sum(self.request_counts.values()),
                )
            ]
            self.input_registers.setValues(self.REGISTER_CONFIG['SECURITY_STATUS_BASE'], values)
        except Exception as e:
            logger.error("failed_to_update_security_registers", error=str(e))

    def get_security_status(self) -> Dict[str, Any]:
        """Return a snapshot of the security configuration and tracked clients.

        Returned containers are copies, so callers cannot mutate server state.
        """
        return {
            "security_enabled": self.enable_security,
            "connected_clients": len(self.connected_clients),
            "allowed_ips": list(self.allowed_ips),
            "rate_limit_per_minute": self.rate_limit_per_minute,
            "client_details": [
                {
                    "client_ip": client_ip,
                    "first_seen": info['first_seen'].isoformat(),
                    "last_seen": info['last_seen'].isoformat(),
                    "request_count": info['request_count'],
                    "function_codes": sorted(info['function_codes'])
                }
                for client_ip, info in self.connected_clients.items()
            ]
        }