"""Unit tests for the compliance audit components.

Every test drives a ``ComplianceAuditLogger`` that is wired to a mocked
database client, so no real database is touched.
"""

import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timezone

from src.core.compliance_audit import (
    ComplianceAuditLogger,
    AuditEventType,
    AuditSeverity,
)
from config.settings import settings


# Values shared by several tests.
_CLIENT_IP = "192.168.1.100"
_WINDOW_START = datetime(2024, 1, 1, tzinfo=timezone.utc)
_WINDOW_END = datetime(2024, 1, 31, tzinfo=timezone.utc)


def _single_call_kwargs(spy):
    """Assert ``spy`` was called exactly once; return that call's keyword arguments."""
    spy.assert_called_once()
    _, keyword_args = spy.call_args
    return keyword_args


def _assert_fields(actual, expected):
    """Assert that ``actual[name] == value`` for each pair in ``expected``, in order."""
    for name, value in expected.items():
        assert actual[name] == value


class TestComplianceAuditLogger:
    """Test cases for ComplianceAuditLogger."""

    def setup_method(self):
        """Create the mocked client and the logger, then switch auditing on."""
        self.mock_db_client = Mock()
        self.audit_logger = ComplianceAuditLogger(self.mock_db_client)

        # Remember the global flag so teardown_method can put it back.
        self.original_audit_enabled = settings.audit_log_enabled
        settings.audit_log_enabled = True

    def teardown_method(self):
        """Put the global audit flag back to the value saved in setup_method."""
        settings.audit_log_enabled = self.original_audit_enabled

    def _capture_log(self, level):
        """Return a patcher replacing the structured logger's ``level`` method."""
        return patch.object(self.audit_logger.logger, level)

    def _capture_events(self):
        """Return a patcher replacing ``log_compliance_event`` on the logger under test."""
        return patch.object(self.audit_logger, 'log_compliance_event')

    def test_initialization(self):
        """The constructor stores the client and exposes a logger."""
        assert self.audit_logger.db_client == self.mock_db_client
        assert self.audit_logger.logger is not None

    def test_log_compliance_event_success(self):
        """An event is logged once via ``info`` and written once to the database."""
        with self._capture_log('info') as info_spy:
            self.audit_logger.log_compliance_event(
                event_type=AuditEventType.USER_LOGIN,
                severity=AuditSeverity.LOW,
                user_id="test_user",
                ip_address=_CLIENT_IP,
                action="login",
                resource="system",
                result="success",
            )

        # Structured log record
        _assert_fields(
            _single_call_kwargs(info_spy),
            {
                "event_type": "user_login",
                "severity": "low",
                "user_id": "test_user",
                "ip_address": "192.168.1.100",
            },
        )

        # Database record
        self.mock_db_client.execute.assert_called_once()

    def test_log_compliance_event_database_disabled(self):
        """With the audit flag off, ``info`` is still called but the database is not."""
        settings.audit_log_enabled = False

        with self._capture_log('info') as info_spy:
            self.audit_logger.log_compliance_event(
                event_type=AuditEventType.USER_LOGIN,
                severity=AuditSeverity.LOW,
                user_id="test_user",
            )

        info_spy.assert_called_once()
        self.mock_db_client.execute.assert_not_called()

    def test_log_compliance_event_database_error(self):
        """A failing ``execute`` is reported once through the logger's ``error``."""
        self.mock_db_client.execute.side_effect = Exception("Database error")

        with self._capture_log('error') as error_spy:
            self.audit_logger.log_compliance_event(
                event_type=AuditEventType.USER_LOGIN,
                severity=AuditSeverity.LOW,
                user_id="test_user",
            )

        logged = _single_call_kwargs(error_spy)
        assert "Database error" in logged["error"]
        assert logged["event_type"] == "user_login"

    def test_log_user_authentication_success(self):
        """A successful authentication is forwarded as a LOW severity USER_LOGIN."""
        request = {
            "user_id": "test_user",
            "result": "success",
            "ip_address": _CLIENT_IP,
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_user_authentication(**request)

        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.USER_LOGIN,
                "severity": AuditSeverity.LOW,
                "user_id": request["user_id"],
                "ip_address": request["ip_address"],
                "result": "success",
            },
        )

    def test_log_user_authentication_failure(self):
        """A failed authentication is forwarded as a HIGH severity INVALID_AUTHENTICATION."""
        request = {
            "user_id": "test_user",
            "result": "failure",
            "ip_address": _CLIENT_IP,
            "reason": "Invalid credentials",
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_user_authentication(**request)

        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.INVALID_AUTHENTICATION,
                "severity": AuditSeverity.HIGH,
                "reason": request["reason"],
            },
        )

    def test_log_access_control_granted(self):
        """A granted "read" is forwarded as a MEDIUM severity DATA_READ."""
        request = {
            "user_id": "test_user",
            "action": "read",
            "resource": "station_data",
            "result": "granted",
            "ip_address": _CLIENT_IP,
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_access_control(**request)

        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.DATA_READ,
                "severity": AuditSeverity.MEDIUM,
            },
        )

    def test_log_access_control_denied(self):
        """A denied request is forwarded as a HIGH severity ACCESS_DENIED with its reason."""
        request = {
            "user_id": "test_user",
            "action": "write",
            "resource": "system_config",
            "result": "denied",
            "ip_address": _CLIENT_IP,
            "reason": "Insufficient permissions",
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_access_control(**request)

        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.ACCESS_DENIED,
                "severity": AuditSeverity.HIGH,
                "reason": request["reason"],
            },
        )

    def test_log_control_operation_setpoint_change(self):
        """A "setpoint_change" is forwarded as a HIGH severity SETPOINT_CHANGED."""
        request = {
            "user_id": "operator_user",
            "station_id": "station_001",
            "pump_id": None,
            "action": "setpoint_change",
            "resource": "pump_speed",
            "result": "success",
            "ip_address": _CLIENT_IP,
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_control_operation(**request)

        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.SETPOINT_CHANGED,
                "severity": AuditSeverity.HIGH,
                "station_id": request["station_id"],
            },
        )

    def test_log_control_operation_emergency_stop(self):
        """An "emergency_stop" is forwarded as a CRITICAL EMERGENCY_STOP_ACTIVATED."""
        request = {
            "user_id": "operator_user",
            "station_id": "station_001",
            "pump_id": "pump_001",
            "action": "emergency_stop",
            "resource": "system",
            "result": "success",
            "ip_address": _CLIENT_IP,
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_control_operation(**request)

        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.EMERGENCY_STOP_ACTIVATED,
                "severity": AuditSeverity.CRITICAL,
                "pump_id": "pump_001",
            },
        )

    def test_log_security_event(self):
        """Security events are forwarded with the same values that were passed in."""
        request = {
            "event_type": AuditEventType.CERTIFICATE_EXPIRED,
            "severity": AuditSeverity.HIGH,
            "user_id": "system",
            "station_id": "station_001",
            "reason": "TLS certificate expired",
        }

        with self._capture_events() as event_spy:
            self.audit_logger.log_security_event(**request)

        # Each of the five inputs must appear unchanged in the forwarded call.
        _assert_fields(_single_call_kwargs(event_spy), request)

    def test_get_audit_trail_success(self):
        """Fetching the trail queries once, logs the access, and returns the rows."""
        rows = [
            {"event_type": "user_login", "user_id": "test_user"},
            {"event_type": "data_read", "user_id": "test_user"},
        ]
        self.mock_db_client.fetch_all.return_value = rows
        requester = "test_user"

        with self._capture_events() as event_spy:
            result = self.audit_logger.get_audit_trail(
                start_time=_WINDOW_START,
                end_time=_WINDOW_END,
                user_id=requester,
            )

        # Database query
        self.mock_db_client.fetch_all.assert_called_once()

        # Reading the trail is itself an audited event
        _assert_fields(
            _single_call_kwargs(event_spy),
            {
                "event_type": AuditEventType.AUDIT_LOG_ACCESSED,
                "user_id": requester,
            },
        )

        # Rows are handed back as fetched
        assert result == rows

    def test_get_audit_trail_error(self):
        """A failing query is logged once as an error and yields an empty list."""
        self.mock_db_client.fetch_all.side_effect = Exception("Query failed")

        with self._capture_log('error') as error_spy:
            result = self.audit_logger.get_audit_trail()

        error_spy.assert_called_once()
        assert result == []

    def test_generate_compliance_report_success(self):
        """The report echoes the standard and summarises the fetched rows."""
        rows = [
            {"event_type": "user_login", "severity": "low", "count": 10, "unique_users": 2},
            {"event_type": "access_denied", "severity": "high", "count": 2, "unique_users": 1},
        ]
        self.mock_db_client.fetch_all.return_value = rows
        standard = "IEC_62443"

        with self._capture_events() as event_spy:
            report = self.audit_logger.generate_compliance_report(
                start_time=_WINDOW_START,
                end_time=_WINDOW_END,
                compliance_standard=standard,
            )

        # Database query
        self.mock_db_client.fetch_all.assert_called_once()

        # Report generation is itself an audited event
        event_spy.assert_called_once()

        # Report structure
        assert report["compliance_standard"] == standard
        _assert_fields(
            report["summary"],
            {
                "total_events": 12,
                "unique_users": 3,
                "event_types": 2,
            },
        )
        assert report["events_by_type"] == rows

    def test_generate_compliance_report_error(self):
        """A failing query is logged once as an error and surfaces in the report."""
        self.mock_db_client.fetch_all.side_effect = Exception("Report generation failed")

        with self._capture_log('error') as error_spy:
            report = self.audit_logger.generate_compliance_report(
                start_time=_WINDOW_START,
                end_time=_WINDOW_END,
                compliance_standard="ISO_27001",
            )

        error_spy.assert_called_once()
        assert "error" in report
        assert "Report generation failed" in report["error"]