# 230 lines, 9.8 KiB, Python
"""
|
|
Unit tests for SafetyLimitEnforcer class.
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, patch
|
|
from typing import Dict, Any
|
|
|
|
from src.core.safety import SafetyLimitEnforcer, SafetyLimits
|
|
|
|
|
|
class TestSafetyLimitEnforcer:
    """Test cases for SafetyLimitEnforcer.

    Each test gets a fresh enforcer whose database client is a ``Mock``.
    Most tests seed ``safety_limits_cache`` / ``previous_setpoints``
    directly instead of loading limits through the database.
    """

    # Station/pump identifiers shared by most tests. KEY is the tuple under
    # which these tests index the enforcer's caches.
    STATION_ID = 'STATION_001'
    PUMP_ID = 'PUMP_001'
    KEY = (STATION_ID, PUMP_ID)

    @pytest.fixture
    def safety_enforcer(self):
        """Create a test safety enforcer backed by a Mock database client."""
        mock_db_client = Mock()
        return SafetyLimitEnforcer(mock_db_client)

    @pytest.fixture
    def mock_safety_limits(self):
        """Create safety limits with a 20-50 Hz hard speed window.

        Despite the fixture name, this returns a real ``SafetyLimits``
        instance, not a Mock.
        """
        return SafetyLimits(
            hard_min_speed_hz=20.0,
            hard_max_speed_hz=50.0,
            hard_min_level_m=1.0,
            hard_max_level_m=4.0,
            hard_max_power_kw=60.0,
            max_speed_change_hz_per_min=5.0,
        )

    @pytest.mark.asyncio
    async def test_load_safety_limits_success(self, safety_enforcer):
        """Test successful loading of safety limits."""
        # Database response: one row of limits for a single station/pump.
        mock_limits = [
            {
                'station_id': self.STATION_ID,
                'pump_id': self.PUMP_ID,
                'hard_min_speed_hz': 20.0,
                'hard_max_speed_hz': 50.0,
                'hard_min_level_m': 1.0,
                'hard_max_level_m': 4.0,
                'hard_max_power_kw': 60.0,
                'max_speed_change_hz_per_min': 5.0,
            }
        ]
        safety_enforcer.db_client.get_safety_limits.return_value = mock_limits

        await safety_enforcer.load_safety_limits()

        # The loader must actually have queried the database...
        safety_enforcer.db_client.get_safety_limits.assert_called_once()

        # ...and cached exactly one entry under the (station, pump) key.
        assert len(safety_enforcer.safety_limits_cache) == 1
        assert self.KEY in safety_enforcer.safety_limits_cache

        limits = safety_enforcer.safety_limits_cache[self.KEY]
        assert limits.hard_min_speed_hz == 20.0
        assert limits.hard_max_speed_hz == 50.0

    @pytest.mark.asyncio
    async def test_load_safety_limits_failure(self, safety_enforcer):
        """Test that a database error while loading limits propagates."""
        safety_enforcer.db_client.get_safety_limits.side_effect = Exception("Database error")

        with pytest.raises(Exception, match="Database error"):
            await safety_enforcer.load_safety_limits()

    def test_enforce_setpoint_within_limits(self, safety_enforcer, mock_safety_limits):
        """Test setpoint within limits is not modified."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        enforced, violations = safety_enforcer.enforce_setpoint(self.STATION_ID, self.PUMP_ID, 35.0)

        assert enforced == 35.0
        assert violations == []

    def test_enforce_setpoint_below_min(self, safety_enforcer, mock_safety_limits):
        """Test setpoint below minimum is clamped."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        enforced, violations = safety_enforcer.enforce_setpoint(self.STATION_ID, self.PUMP_ID, 15.0)

        assert enforced == 20.0  # Clamped to hard_min_speed_hz
        assert len(violations) == 1
        assert "BELOW_MIN_SPEED" in violations[0]

    def test_enforce_setpoint_above_max(self, safety_enforcer, mock_safety_limits):
        """Test setpoint above maximum is clamped."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        enforced, violations = safety_enforcer.enforce_setpoint(self.STATION_ID, self.PUMP_ID, 55.0)

        assert enforced == 50.0  # Clamped to hard_max_speed_hz
        assert len(violations) == 1
        assert "ABOVE_MAX_SPEED" in violations[0]

    def test_enforce_setpoint_no_limits(self, safety_enforcer):
        """Test setpoint without safety limits defined."""
        enforced, violations = safety_enforcer.enforce_setpoint(self.STATION_ID, self.PUMP_ID, 35.0)

        assert enforced == 0.0  # Fails safe to 0 when no limits are cached
        assert violations == ["NO_SAFETY_LIMITS_DEFINED"]

    def test_enforce_setpoint_with_rate_limit(self, safety_enforcer, mock_safety_limits):
        """Test a large step from a previous setpoint when the hard max also applies."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        # Record a previous setpoint so the rate-of-change check has a baseline.
        safety_enforcer.previous_setpoints[self.KEY] = 30.0

        # 56 Hz is above hard_max_speed_hz (50 Hz), so the hard clamp applies.
        enforced, violations = safety_enforcer.enforce_setpoint(self.STATION_ID, self.PUMP_ID, 56.0)

        # The result is clamped to 50 Hz. The 20 Hz step from 30 Hz to the
        # clamped 50 Hz does not trip the rate limiter, so only the max-speed
        # violation is reported.
        # NOTE(review): the allowed step is presumably 5 Hz/min over a 5 min
        # window (25 Hz) -- confirm. With a 30 Hz baseline every in-range
        # target is within 25 Hz, so this test never produces a rate-limit
        # violation. TODO: add a case that does (e.g. a 20 Hz baseline with a
        # 50 Hz target) once the rate-limit violation format is confirmed.
        assert enforced == 50.0
        assert len(violations) == 1
        violation_types = [v.split(':')[0] for v in violations]
        assert "ABOVE_MAX_SPEED" in violation_types

    def test_enforce_setpoint_multiple_violations(self, safety_enforcer, mock_safety_limits):
        """Test a far-above-max setpoint with a previous setpoint recorded.

        Despite the test name, only the max-speed violation is expected here.
        """
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        # Record a previous setpoint so the rate-of-change check has a baseline.
        safety_enforcer.previous_setpoints[self.KEY] = 30.0

        # 60 Hz violates hard_max_speed_hz (50 Hz).
        enforced, violations = safety_enforcer.enforce_setpoint(self.STATION_ID, self.PUMP_ID, 60.0)

        # Clamped to hard_max_speed_hz; a single max-speed violation.
        assert enforced == 50.0
        assert len(violations) == 1
        violation_types = [v.split(':')[0] for v in violations]
        assert "ABOVE_MAX_SPEED" in violation_types

    def test_get_safety_limits_exists(self, safety_enforcer, mock_safety_limits):
        """Test getting existing safety limits."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        limits = safety_enforcer.get_safety_limits(self.STATION_ID, self.PUMP_ID)

        assert limits is not None
        assert limits.hard_min_speed_hz == 20.0

    def test_get_safety_limits_not_exists(self, safety_enforcer):
        """Test getting non-existent safety limits."""
        limits = safety_enforcer.get_safety_limits(self.STATION_ID, self.PUMP_ID)

        assert limits is None

    def test_has_safety_limits_exists(self, safety_enforcer, mock_safety_limits):
        """Test checking for existing safety limits."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits

        assert safety_enforcer.has_safety_limits(self.STATION_ID, self.PUMP_ID) is True

    def test_has_safety_limits_not_exists(self, safety_enforcer):
        """Test checking for non-existent safety limits."""
        assert safety_enforcer.has_safety_limits(self.STATION_ID, self.PUMP_ID) is False

    def test_clear_safety_limits(self, safety_enforcer, mock_safety_limits):
        """Test clearing limits (and the previous setpoint) for one pump."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits
        safety_enforcer.previous_setpoints[self.KEY] = 35.0

        safety_enforcer.clear_safety_limits(self.STATION_ID, self.PUMP_ID)

        assert self.KEY not in safety_enforcer.safety_limits_cache
        assert self.KEY not in safety_enforcer.previous_setpoints

    def test_clear_all_safety_limits(self, safety_enforcer, mock_safety_limits):
        """Test clearing all safety limits and previous setpoints."""
        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits
        safety_enforcer.safety_limits_cache[('STATION_002', self.PUMP_ID)] = mock_safety_limits
        safety_enforcer.previous_setpoints[self.KEY] = 35.0

        safety_enforcer.clear_all_safety_limits()

        assert len(safety_enforcer.safety_limits_cache) == 0
        assert len(safety_enforcer.previous_setpoints) == 0

    def test_get_loaded_limits_count(self, safety_enforcer, mock_safety_limits):
        """Test getting count of loaded safety limits."""
        assert safety_enforcer.get_loaded_limits_count() == 0

        safety_enforcer.safety_limits_cache[self.KEY] = mock_safety_limits
        safety_enforcer.safety_limits_cache[('STATION_002', self.PUMP_ID)] = mock_safety_limits

        assert safety_enforcer.get_loaded_limits_count() == 2

    def test_safety_limits_dataclass(self):
        """Test SafetyLimits stores every field it is constructed with."""
        limits = SafetyLimits(
            hard_min_speed_hz=20.0,
            hard_max_speed_hz=50.0,
            hard_min_level_m=1.0,
            hard_max_level_m=4.0,
            hard_max_power_kw=60.0,
            max_speed_change_hz_per_min=5.0,
        )

        assert limits.hard_min_speed_hz == 20.0
        assert limits.hard_max_speed_hz == 50.0
        assert limits.hard_min_level_m == 1.0
        assert limits.hard_max_level_m == 4.0
        assert limits.hard_max_power_kw == 60.0
        assert limits.max_speed_change_hz_per_min == 5.0

    def test_safety_limits_with_optional_fields(self):
        """Test SafetyLimits accepts None for level and power limits."""
        limits = SafetyLimits(
            hard_min_speed_hz=20.0,
            hard_max_speed_hz=50.0,
            hard_min_level_m=None,
            hard_max_level_m=None,
            hard_max_power_kw=None,
            max_speed_change_hz_per_min=5.0,
        )

        assert limits.hard_min_speed_hz == 20.0
        assert limits.hard_max_speed_hz == 50.0
        assert limits.hard_min_level_m is None
        assert limits.hard_max_level_m is None
        assert limits.hard_max_power_kw is None
        assert limits.max_speed_change_hz_per_min == 5.0