""" Unit tests for Phase 5 Protocol Server Enhancements. Tests the performance optimizations and security enhancements added to the protocol servers in Phase 5. """ import pytest import asyncio from datetime import datetime, timedelta from unittest.mock import Mock, AsyncMock, patch from src.protocols.opcua_server import OPCUAServer, NodeCache from src.protocols.modbus_server import ModbusServer, ConnectionPool from src.protocols.rest_api import RESTAPIServer, ResponseCache class TestNodeCache: """Test OPC UA Node Cache functionality.""" def test_cache_initialization(self): """Test cache initialization with default parameters.""" cache = NodeCache() assert cache.max_size == 1000 assert cache.ttl_seconds == 300 assert len(cache._cache) == 0 def test_cache_set_and_get(self): """Test setting and getting values from cache.""" cache = NodeCache() mock_node = Mock() # Set value cache.set("test_node", mock_node) # Get value retrieved = cache.get("test_node") assert retrieved == mock_node assert len(cache._cache) == 1 def test_cache_expiration(self): """Test cache expiration functionality.""" cache = NodeCache(ttl_seconds=1) # Very short TTL for testing mock_node = Mock() # Set value cache.set("test_node", mock_node) # Should be available immediately assert cache.get("test_node") == mock_node # After expiration, should return None with patch('src.protocols.opcua_server.datetime') as mock_datetime: mock_datetime.now.return_value = datetime.now() + timedelta(seconds=2) assert cache.get("test_node") is None def test_cache_eviction(self): """Test cache eviction when max size is reached.""" cache = NodeCache(max_size=2) # Fill cache cache.set("node1", Mock()) cache.set("node2", Mock()) assert len(cache._cache) == 2 # Add third node, should evict oldest cache.set("node3", Mock()) assert len(cache._cache) == 2 assert "node1" not in cache._cache # First node should be evicted class TestConnectionPool: """Test Modbus Connection Pool functionality.""" def test_pool_initialization(self): """Test connection pool initialization.""" pool = ConnectionPool(max_connections=50) assert pool.max_connections == 50 assert len(pool.active_connections) == 0 def test_can_accept_connection(self): """Test connection acceptance logic.""" pool = ConnectionPool(max_connections=2) # Should accept first connection assert pool.can_accept_connection("192.168.1.1") is True # Add connection pool.add_connection("192.168.1.1", {"client_ip": "192.168.1.1"}) # Should accept second connection assert pool.can_accept_connection("192.168.1.2") is True # Add second connection pool.add_connection("192.168.1.2", {"client_ip": "192.168.1.2"}) # Should reject third connection assert pool.can_accept_connection("192.168.1.3") is False def test_stale_connection_removal(self): """Test removal of stale connections.""" pool = ConnectionPool(max_connections=2) # Add a stale connection with patch('src.protocols.modbus_server.datetime') as mock_datetime: mock_datetime.now.return_value = datetime.now() - timedelta(minutes=10) pool.add_connection("192.168.1.1", {"client_ip": "192.168.1.1"}) # Should accept new connection (stale one should be removed) assert pool.can_accept_connection("192.168.1.2") is True def test_connection_stats(self): """Test connection pool statistics.""" pool = ConnectionPool(max_connections=10) # Add some connections pool.add_connection("192.168.1.1", {"client_ip": "192.168.1.1"}) pool.add_connection("192.168.1.2", {"client_ip": "192.168.1.2"}) stats = pool.get_stats() assert stats["active_connections"] == 2 assert stats["max_connections"] == 10 assert "192.168.1.1" in stats["connection_details"] class TestResponseCache: """Test REST API Response Cache functionality.""" def test_cache_initialization(self): """Test response cache initialization.""" cache = ResponseCache() assert cache.max_size == 1000 assert cache.ttl_seconds == 60 assert len(cache._cache) == 0 def test_cache_set_and_get(self): """Test setting and getting responses from cache.""" cache = ResponseCache() test_response = {"data": "test"} # Set response cache.set("test_key", test_response) # Get response retrieved = cache.get("test_key") assert retrieved == test_response assert len(cache._cache) == 1 def test_cache_stats(self): """Test cache statistics.""" cache = ResponseCache(max_size=50, ttl_seconds=30) # Add some responses cache.set("key1", {"data": "test1"}) cache.set("key2", {"data": "test2"}) stats = cache.get_stats() assert stats["cache_size"] == 2 assert stats["max_size"] == 50 assert stats["ttl_seconds"] == 30 class TestOPCUAServerEnhancements: """Test OPC UA Server performance enhancements.""" @pytest.fixture def mock_components(self): """Create mock components for OPC UA server.""" setpoint_manager = Mock() security_manager = Mock() audit_logger = Mock() return setpoint_manager, security_manager, audit_logger def test_opcua_server_initialization_with_cache(self, mock_components): """Test OPC UA server initialization with caching enabled.""" setpoint_manager, security_manager, audit_logger = mock_components server = OPCUAServer( setpoint_manager=setpoint_manager, security_manager=security_manager, audit_logger=audit_logger, enable_caching=True, cache_ttl_seconds=300 ) assert server.enable_caching is True assert server.cache_ttl_seconds == 300 assert server.node_cache is not None assert server._setpoint_cache == {} def test_opcua_server_initialization_without_cache(self, mock_components): """Test OPC UA server initialization with caching disabled.""" setpoint_manager, security_manager, audit_logger = mock_components server = OPCUAServer( setpoint_manager=setpoint_manager, security_manager=security_manager, audit_logger=audit_logger, enable_caching=False ) assert server.enable_caching is False assert server.node_cache is None def test_opcua_performance_status(self, mock_components): """Test OPC UA server performance status reporting.""" setpoint_manager, security_manager, audit_logger = mock_components server = OPCUAServer( setpoint_manager=setpoint_manager, security_manager=security_manager, audit_logger=audit_logger, enable_caching=True ) performance_status = server.get_performance_status() assert "caching" in performance_status assert "last_setpoint_update" in performance_status assert "connected_clients" in performance_status assert performance_status["caching"]["enabled"] is True class TestModbusServerEnhancements: """Test Modbus Server performance enhancements.""" @pytest.fixture def mock_components(self): """Create mock components for Modbus server.""" setpoint_manager = Mock() security_manager = Mock() audit_logger = Mock() return setpoint_manager, security_manager, audit_logger def test_modbus_server_initialization_with_pooling(self, mock_components): """Test Modbus server initialization with connection pooling.""" setpoint_manager, security_manager, audit_logger = mock_components server = ModbusServer( setpoint_manager=setpoint_manager, security_manager=security_manager, audit_logger=audit_logger, enable_connection_pooling=True, max_connections=50 ) assert server.enable_connection_pooling is True assert server.max_connections == 50 assert server.connection_pool is not None assert server.total_requests == 0 assert server.failed_requests == 0 def test_modbus_server_initialization_without_pooling(self, mock_components): """Test Modbus server initialization without connection pooling.""" setpoint_manager, security_manager, audit_logger = mock_components server = ModbusServer( setpoint_manager=setpoint_manager, security_manager=security_manager, audit_logger=audit_logger, enable_connection_pooling=False ) assert server.enable_connection_pooling is False assert server.connection_pool is None def test_modbus_performance_status(self, mock_components): """Test Modbus server performance status reporting.""" setpoint_manager, security_manager, audit_logger = mock_components server = ModbusServer( setpoint_manager=setpoint_manager, security_manager=security_manager, audit_logger=audit_logger, enable_connection_pooling=True ) # Simulate some activity server.total_requests = 100 server.failed_requests = 5 performance_status = server.get_performance_status() assert "total_requests" in performance_status assert "failed_requests" in performance_status assert "success_rate" in performance_status assert "connection_pool" in performance_status assert "rate_limiting" in performance_status assert performance_status["total_requests"] == 100 assert performance_status["failed_requests"] == 5 assert performance_status["success_rate"] == 95.0 class TestRESTAPIServerEnhancements: """Test REST API Server performance enhancements.""" @pytest.fixture def mock_components(self): """Create mock components for REST API server.""" setpoint_manager = Mock() emergency_stop_manager = Mock() return setpoint_manager, emergency_stop_manager def test_rest_api_server_initialization_with_cache(self, mock_components): """Test REST API server initialization with caching.""" setpoint_manager, emergency_stop_manager = mock_components server = RESTAPIServer( setpoint_manager=setpoint_manager, emergency_stop_manager=emergency_stop_manager, enable_caching=True, enable_compression=True, cache_ttl_seconds=120 ) assert server.enable_caching is True assert server.enable_compression is True assert server.cache_ttl_seconds == 120 assert server.response_cache is not None assert server.total_requests == 0 assert server.cache_hits == 0 assert server.cache_misses == 0 def test_rest_api_server_initialization_without_cache(self, mock_components): """Test REST API server initialization without caching.""" setpoint_manager, emergency_stop_manager = mock_components server = RESTAPIServer( setpoint_manager=setpoint_manager, emergency_stop_manager=emergency_stop_manager, enable_caching=False, enable_compression=False ) assert server.enable_caching is False assert server.enable_compression is False assert server.response_cache is None def test_rest_api_performance_status(self, mock_components): """Test REST API server performance status reporting.""" setpoint_manager, emergency_stop_manager = mock_components server = RESTAPIServer( setpoint_manager=setpoint_manager, emergency_stop_manager=emergency_stop_manager, enable_caching=True ) # Simulate some activity server.total_requests = 200 server.cache_hits = 150 server.cache_misses = 50 performance_status = server.get_performance_status() assert "total_requests" in performance_status assert "caching" in performance_status assert "compression" in performance_status assert performance_status["total_requests"] == 200 assert performance_status["caching"]["hits"] == 150 assert performance_status["caching"]["misses"] == 50 assert performance_status["caching"]["hit_rate"] == 75.0 class TestProtocolSecurityEnhancements: """Test protocol-specific security enhancements.""" def test_opcua_security_enhancements(self): """Test OPC UA security enhancements.""" # Verify OPC UA server has enhanced security features assert hasattr(OPCUAServer, 'get_security_status') assert hasattr(OPCUAServer, 'get_performance_status') def test_modbus_security_enhancements(self): """Test Modbus security enhancements.""" # Verify Modbus server has enhanced security features assert hasattr(ModbusServer, 'get_security_status') assert hasattr(ModbusServer, 'get_performance_status') assert hasattr(ModbusServer, '_check_connection_limit') def test_rest_api_security_enhancements(self): """Test REST API security enhancements.""" # Verify REST API server has enhanced security features assert hasattr(RESTAPIServer, 'get_performance_status') assert hasattr(RESTAPIServer, '_get_cache_key') assert hasattr(RESTAPIServer, '_get_cached_response') assert hasattr(RESTAPIServer, '_cache_response')