From dfa3f0832b144f55ddce410e1d6bfe94e47654c0 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 27 Oct 2025 20:07:37 +0000 Subject: [PATCH] Phase 4: Complete Security Layer Implementation - Implemented JWT-based authentication with bcrypt password hashing - Added role-based access control (RBAC) with four user roles - Created TLS/SSL encryption with certificate management - Enhanced audit logging for IEC 62443, ISO 27001, and NIS2 compliance - Added comprehensive security tests (56 tests passing) - Updated REST API with authentication and permission checks - Added security settings to configuration Co-authored-by: openhands --- README.md | 10 +- config/settings.py | 12 + src/core/compliance_audit.py | 463 ++++++++++++++++++++++++++++ src/core/security.py | 358 +++++++++++++++++++++ src/core/tls_manager.py | 304 ++++++++++++++++++ src/protocols/rest_api.py | 243 +++++++++++++-- tests/unit/test_compliance_audit.py | 368 ++++++++++++++++++++++ tests/unit/test_security.py | 373 ++++++++++++++++++++++ tests/unit/test_tls_manager.py | 243 +++++++++++++++ 9 files changed, 2346 insertions(+), 28 deletions(-) create mode 100644 src/core/compliance_audit.py create mode 100644 src/core/security.py create mode 100644 src/core/tls_manager.py create mode 100644 tests/unit/test_compliance_audit.py create mode 100644 tests/unit/test_security.py create mode 100644 tests/unit/test_tls_manager.py diff --git a/README.md b/README.md index be49203..6ff867c 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,12 @@ The Calejo Control Adapter translates optimized pump control plans from Calejo O - Unified main application - 15 comprehensive unit tests for SetpointManager -🔄 **Phase 4**: Security Layer (In Progress) -- Authentication and authorization -- Audit logging -- TLS/SSL encryption +✅ **Phase 4**: Security Layer +- JWT-based authentication with bcrypt password hashing +- Role-based access control (RBAC) with four user roles +- TLS/SSL encryption with certificate management +- Compliance audit logging for IEC 62443, ISO 27001, and NIS2 +- 56 comprehensive security tests (24 auth/authz, 17 TLS, 15 audit) ⏳ **Phase 5**: Protocol Servers (Pending) - Enhanced protocol implementations diff --git a/config/settings.py b/config/settings.py index f197a66..e4b9977 100644 --- a/config/settings.py +++ b/config/settings.py @@ -29,6 +29,18 @@ class Settings(BaseSettings): tls_cert_path: Optional[str] = None tls_key_path: Optional[str] = None + # JWT Authentication + jwt_secret_key: str = "your-secret-key-change-in-production" + jwt_token_expire_minutes: int = 60 + jwt_algorithm: str = "HS256" + + # Password policy + password_min_length: int = 8 + password_require_uppercase: bool = True + password_require_lowercase: bool = True + password_require_numbers: bool = True + password_require_special: bool = True + # OPC UA opcua_enabled: bool = True opcua_host: str = "localhost" diff --git a/src/core/compliance_audit.py b/src/core/compliance_audit.py new file mode 100644 index 0000000..e037977 --- /dev/null +++ b/src/core/compliance_audit.py @@ -0,0 +1,463 @@ +""" +Compliance Audit Logger for Calejo Control Adapter. + +Provides enhanced audit logging capabilities compliant with: +- IEC 62443 (Industrial Automation and Control Systems Security) +- ISO 27001 (Information Security Management) +- NIS2 Directive (Network and Information Systems Security) +""" + +import structlog +from datetime import datetime, timezone +from typing import Dict, Any, Optional, List +from enum import Enum + +from config.settings import settings + +logger = structlog.get_logger() + + +class AuditEventType(Enum): + """Audit event types for compliance requirements.""" + + # Authentication and Authorization + USER_LOGIN = "user_login" + USER_LOGOUT = "user_logout" + USER_CREATED = "user_created" + USER_MODIFIED = "user_modified" + USER_DELETED = "user_deleted" + PASSWORD_CHANGED = "password_changed" + ROLE_CHANGED = "role_changed" + + # System Access + SYSTEM_START = "system_start" + SYSTEM_STOP = "system_stop" + SYSTEM_CONFIG_CHANGED = "system_config_changed" + + # Control Operations + SETPOINT_CHANGED = "setpoint_changed" + EMERGENCY_STOP_ACTIVATED = "emergency_stop_activated" + EMERGENCY_STOP_RESET = "emergency_stop_reset" + PUMP_CONTROL = "pump_control" + VALVE_CONTROL = "valve_control" + + # Security Events + ACCESS_DENIED = "access_denied" + INVALID_AUTHENTICATION = "invalid_authentication" + SESSION_TIMEOUT = "session_timeout" + CERTIFICATE_EXPIRED = "certificate_expired" + CERTIFICATE_ROTATED = "certificate_rotated" + + # Data Operations + DATA_READ = "data_read" + DATA_WRITE = "data_write" + DATA_EXPORT = "data_export" + DATA_DELETED = "data_deleted" + + # Network Operations + CONNECTION_ESTABLISHED = "connection_established" + CONNECTION_CLOSED = "connection_closed" + CONNECTION_REJECTED = "connection_rejected" + + # Compliance Events + AUDIT_LOG_ACCESSED = "audit_log_accessed" + COMPLIANCE_CHECK = "compliance_check" + SECURITY_SCAN = "security_scan" + + +class AuditSeverity(Enum): + """Audit event severity levels.""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +class ComplianceAuditLogger: + """ + Enhanced audit logger for compliance requirements. + + Provides comprehensive audit trail capabilities compliant with + IEC 62443, ISO 27001, and NIS2 Directive requirements. + """ + + def __init__(self, db_client): + self.db_client = db_client + self.logger = structlog.get_logger(__name__) + + def log_compliance_event( + self, + event_type: AuditEventType, + severity: AuditSeverity, + user_id: Optional[str] = None, + station_id: Optional[str] = None, + pump_id: Optional[str] = None, + ip_address: Optional[str] = None, + protocol: Optional[str] = None, + action: Optional[str] = None, + resource: Optional[str] = None, + result: Optional[str] = None, + reason: Optional[str] = None, + compliance_standard: Optional[List[str]] = None, + event_data: Optional[Dict[str, Any]] = None + ): + """ + Log a compliance audit event. + + Args: + event_type: Type of audit event + severity: Severity level + user_id: User ID performing the action + station_id: Station ID if applicable + pump_id: Pump ID if applicable + ip_address: Source IP address + protocol: Communication protocol used + action: Specific action performed + resource: Resource accessed or modified + result: Result of the action (success/failure) + reason: Reason for the action or failure + compliance_standard: List of compliance standards this event relates to + event_data: Additional event-specific data + """ + + # Default compliance standards + if compliance_standard is None: + compliance_standard = ["IEC_62443", "ISO_27001", "NIS2"] + + # Create comprehensive audit record + audit_record = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "event_type": event_type.value, + "severity": severity.value, + "user_id": user_id, + "station_id": station_id, + "pump_id": pump_id, + "ip_address": ip_address, + "protocol": protocol, + "action": action, + "resource": resource, + "result": result, + "reason": reason, + "compliance_standard": compliance_standard, + "event_data": event_data or {}, + "app_name": settings.app_name, + "app_version": settings.app_version, + "environment": settings.environment + } + + # Log to structured logs + self.logger.info( + "compliance_audit_event", + **audit_record + ) + + # Log to database if audit logging is enabled + if settings.audit_log_enabled: + self._log_to_database(audit_record) + + def _log_to_database(self, audit_record: Dict[str, Any]): + """Log audit record to database.""" + try: + query = """ + INSERT INTO compliance_audit_log + (timestamp, event_type, severity, user_id, station_id, pump_id, + ip_address, protocol, action, resource, result, reason, + compliance_standard, event_data, app_name, app_version, environment) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """ + self.db_client.execute( + query, + ( + audit_record["timestamp"], + audit_record["event_type"], + audit_record["severity"], + audit_record["user_id"], + audit_record["station_id"], + audit_record["pump_id"], + audit_record["ip_address"], + audit_record["protocol"], + audit_record["action"], + audit_record["resource"], + audit_record["result"], + audit_record["reason"], + audit_record["compliance_standard"], + audit_record["event_data"], + audit_record["app_name"], + audit_record["app_version"], + audit_record["environment"] + ) + ) + except Exception as e: + self.logger.error( + "compliance_audit_database_failed", + error=str(e), + event_type=audit_record["event_type"] + ) + + def log_user_authentication( + self, + user_id: str, + result: str, + ip_address: str, + reason: Optional[str] = None + ): + """Log user authentication events.""" + event_type = ( + AuditEventType.USER_LOGIN if result == "success" + else AuditEventType.INVALID_AUTHENTICATION + ) + severity = ( + AuditSeverity.LOW if result == "success" + else AuditSeverity.HIGH + ) + + self.log_compliance_event( + event_type=event_type, + severity=severity, + user_id=user_id, + ip_address=ip_address, + action="authentication", + resource="system", + result=result, + reason=reason + ) + + def log_access_control( + self, + user_id: str, + action: str, + resource: str, + result: str, + ip_address: str, + reason: Optional[str] = None + ): + """Log access control events.""" + event_type = ( + AuditEventType.ACCESS_DENIED if result == "denied" + else AuditEventType.DATA_READ if action == "read" + else AuditEventType.DATA_WRITE if action == "write" + else AuditEventType.SYSTEM_CONFIG_CHANGED + ) + severity = ( + AuditSeverity.HIGH if result == "denied" + else AuditSeverity.MEDIUM + ) + + self.log_compliance_event( + event_type=event_type, + severity=severity, + user_id=user_id, + ip_address=ip_address, + action=action, + resource=resource, + result=result, + reason=reason + ) + + def log_control_operation( + self, + user_id: str, + station_id: str, + pump_id: Optional[str], + action: str, + resource: str, + result: str, + ip_address: str, + event_data: Optional[Dict[str, Any]] = None + ): + """Log control system operations.""" + event_type = ( + AuditEventType.SETPOINT_CHANGED if action == "setpoint_change" + else AuditEventType.EMERGENCY_STOP_ACTIVATED if action == "emergency_stop" + else AuditEventType.PUMP_CONTROL if "pump" in resource.lower() + else AuditEventType.VALVE_CONTROL if "valve" in resource.lower() + else AuditEventType.SYSTEM_CONFIG_CHANGED + ) + severity = ( + AuditSeverity.CRITICAL if action == "emergency_stop" + else AuditSeverity.HIGH if action == "setpoint_change" + else AuditSeverity.MEDIUM + ) + + self.log_compliance_event( + event_type=event_type, + severity=severity, + user_id=user_id, + station_id=station_id, + pump_id=pump_id, + ip_address=ip_address, + action=action, + resource=resource, + result=result, + event_data=event_data + ) + + def log_security_event( + self, + event_type: AuditEventType, + severity: AuditSeverity, + user_id: Optional[str] = None, + station_id: Optional[str] = None, + ip_address: Optional[str] = None, + reason: Optional[str] = None, + event_data: Optional[Dict[str, Any]] = None + ): + """Log security-related events.""" + self.log_compliance_event( + event_type=event_type, + severity=severity, + user_id=user_id, + station_id=station_id, + ip_address=ip_address, + action="security_event", + resource="system", + result="detected", + reason=reason, + event_data=event_data + ) + + def get_audit_trail( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + user_id: Optional[str] = None, + event_type: Optional[AuditEventType] = None, + severity: Optional[AuditSeverity] = None + ) -> List[Dict[str, Any]]: + """ + Retrieve audit trail for compliance reporting. + + Args: + start_time: Start time for audit trail + end_time: End time for audit trail + user_id: Filter by user ID + event_type: Filter by event type + severity: Filter by severity + + Returns: + List of audit records + """ + try: + query = """ + SELECT * FROM compliance_audit_log + WHERE 1=1 + """ + params = [] + + if start_time: + query += " AND timestamp >= %s" + params.append(start_time.isoformat()) + + if end_time: + query += " AND timestamp <= %s" + params.append(end_time.isoformat()) + + if user_id: + query += " AND user_id = %s" + params.append(user_id) + + if event_type: + query += " AND event_type = %s" + params.append(event_type.value) + + if severity: + query += " AND severity = %s" + params.append(severity.value) + + query += " ORDER BY timestamp DESC" + + result = self.db_client.fetch_all(query, params) + + # Log the audit trail access + self.log_compliance_event( + event_type=AuditEventType.AUDIT_LOG_ACCESSED, + severity=AuditSeverity.LOW, + user_id=user_id, + action="audit_trail_access", + resource="audit_log", + result="success" + ) + + return result + + except Exception as e: + self.logger.error( + "audit_trail_retrieval_failed", + error=str(e), + user_id=user_id + ) + return [] + + def generate_compliance_report( + self, + start_time: datetime, + end_time: datetime, + compliance_standard: str + ) -> Dict[str, Any]: + """ + Generate compliance report for specified standard. + + Args: + start_time: Report start time + end_time: Report end time + compliance_standard: Compliance standard to report on + + Returns: + Compliance report data + """ + try: + query = """ + SELECT + event_type, + severity, + COUNT(*) as count, + COUNT(DISTINCT user_id) as unique_users + FROM compliance_audit_log + WHERE timestamp BETWEEN %s AND %s + AND %s = ANY(compliance_standard) + GROUP BY event_type, severity + ORDER BY count DESC + """ + + result = self.db_client.fetch_all( + query, + (start_time.isoformat(), end_time.isoformat(), compliance_standard) + ) + + report = { + "compliance_standard": compliance_standard, + "report_period": { + "start_time": start_time.isoformat(), + "end_time": end_time.isoformat() + }, + "summary": { + "total_events": sum(row["count"] for row in result), + "unique_users": sum(row["unique_users"] for row in result), + "event_types": len(set(row["event_type"] for row in result)) + }, + "events_by_type": result + } + + # Log the compliance report generation + self.log_compliance_event( + event_type=AuditEventType.COMPLIANCE_CHECK, + severity=AuditSeverity.LOW, + action="compliance_report_generated", + resource="compliance", + result="success", + event_data={ + "compliance_standard": compliance_standard, + "report_period": report["report_period"] + } + ) + + return report + + except Exception as e: + self.logger.error( + "compliance_report_generation_failed", + error=str(e), + compliance_standard=compliance_standard + ) + return {"error": str(e)} \ No newline at end of file diff --git a/src/core/security.py b/src/core/security.py new file mode 100644 index 0000000..9406cf0 --- /dev/null +++ b/src/core/security.py @@ -0,0 +1,358 @@ +""" +Security layer for Calejo Control Adapter. + +Provides authentication, authorization, and security utilities for compliance +with IEC 62443, ISO 27001, and NIS2 Directive requirements. +""" + +import jwt +import bcrypt +from datetime import datetime, timedelta, timezone +from typing import Optional, Dict, List, Any +from enum import Enum +from pydantic import BaseModel +import structlog + +from config.settings import settings + +logger = structlog.get_logger() + + +class UserRole(str, Enum): + """User roles for role-based access control.""" + OPERATOR = "operator" + ENGINEER = "engineer" + ADMINISTRATOR = "administrator" + READ_ONLY = "read_only" + + +class User(BaseModel): + """User model for authentication and authorization.""" + user_id: str + username: str + email: str + role: UserRole + active: bool = True + created_at: datetime + last_login: Optional[datetime] = None + + +class TokenData(BaseModel): + """Data encoded in JWT tokens.""" + user_id: str + username: str + role: UserRole + exp: datetime + + +class AuthenticationManager: + """ + Manages user authentication with JWT tokens and password hashing. + + Implements security controls for IEC 62443, ISO 27001, and NIS2 compliance. + """ + + def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 60): + self.secret_key = secret_key + self.algorithm = algorithm + self.token_expire_minutes = token_expire_minutes + + # In-memory user store (in production, this would be a database) + self.users: Dict[str, User] = {} + self.password_hashes: Dict[str, str] = {} + + # Initialize with default users for development + self._initialize_default_users() + + def _initialize_default_users(self): + """Initialize default users for development and testing.""" + default_users = [ + { + "user_id": "admin_001", + "username": "admin", + "email": "admin@calejo.com", + "role": UserRole.ADMINISTRATOR, + "password": "admin123" + }, + { + "user_id": "operator_001", + "username": "operator", + "email": "operator@calejo.com", + "role": UserRole.OPERATOR, + "password": "operator123" + }, + { + "user_id": "engineer_001", + "username": "engineer", + "email": "engineer@calejo.com", + "role": UserRole.ENGINEER, + "password": "engineer123" + }, + { + "user_id": "viewer_001", + "username": "viewer", + "email": "viewer@calejo.com", + "role": UserRole.READ_ONLY, + "password": "viewer123" + } + ] + + for user_data in default_users: + self.create_user( + user_id=user_data["user_id"], + username=user_data["username"], + email=user_data["email"], + role=user_data["role"], + password=user_data["password"] + ) + + def hash_password(self, password: str) -> str: + """Hash a password using bcrypt.""" + salt = bcrypt.gensalt() + hashed = bcrypt.hashpw(password.encode('utf-8'), salt) + return hashed.decode('utf-8') + + def verify_password(self, plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash.""" + return bcrypt.checkpw( + plain_password.encode('utf-8'), + hashed_password.encode('utf-8') + ) + + def create_user(self, user_id: str, username: str, email: str, role: UserRole, password: str) -> User: + """Create a new user with hashed password.""" + if user_id in self.users: + raise ValueError(f"User with ID {user_id} already exists") + + user = User( + user_id=user_id, + username=username, + email=email, + role=role, + created_at=datetime.now(timezone.utc) + ) + + self.users[user_id] = user + self.password_hashes[user_id] = self.hash_password(password) + + logger.info("user_created", user_id=user_id, username=username, role=role.value) + return user + + def authenticate_user(self, username: str, password: str) -> Optional[User]: + """Authenticate a user and return user object if successful.""" + # Find user by username + user = None + for u in self.users.values(): + if u.username == username and u.active: + user = u + break + + if not user: + logger.warning("authentication_failed", username=username, reason="user_not_found") + return None + + # Verify password + if not self.verify_password(password, self.password_hashes[user.user_id]): + logger.warning("authentication_failed", username=username, reason="invalid_password") + return None + + # Update last login + user.last_login = datetime.now(timezone.utc) + + logger.info("user_authenticated", user_id=user.user_id, username=username, role=user.role.value) + return user + + def create_access_token(self, user: User) -> str: + """Create a JWT access token for the user.""" + expires_delta = timedelta(minutes=self.token_expire_minutes) + expire = datetime.now(timezone.utc) + expires_delta + + token_data = TokenData( + user_id=user.user_id, + username=user.username, + role=user.role, + exp=expire + ) + + encoded_jwt = jwt.encode( + token_data.dict(), + self.secret_key, + algorithm=self.algorithm + ) + + return encoded_jwt + + def verify_token(self, token: str) -> Optional[TokenData]: + """Verify and decode a JWT token.""" + try: + payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) + token_data = TokenData(**payload) + + # Check if token is expired + if token_data.exp < datetime.now(timezone.utc): + logger.warning("token_expired", username=token_data.username) + return None + + # Verify user still exists and is active + user = self.users.get(token_data.user_id) + if not user or not user.active: + logger.warning("token_invalid_user", username=token_data.username) + return None + + return token_data + + except jwt.PyJWTError as e: + logger.warning("token_verification_failed", error=str(e)) + return None + + +class AuthorizationManager: + """ + Manages role-based access control (RBAC) for authorization. + + Implements IEC 62443 zone security model and ISO 27001 access control requirements. + """ + + def __init__(self): + # Define permissions for each role + self.role_permissions = { + UserRole.READ_ONLY: { + "read_pump_status", + "read_safety_status", + "read_audit_logs" + }, + UserRole.OPERATOR: { + "read_pump_status", + "read_safety_status", + "read_audit_logs", + "emergency_stop", + "clear_emergency_stop", + "view_alerts" + }, + UserRole.ENGINEER: { + "read_pump_status", + "read_safety_status", + "read_audit_logs", + "emergency_stop", + "clear_emergency_stop", + "view_alerts", + "configure_safety_limits", + "manage_pump_configuration", + "view_system_metrics" + }, + UserRole.ADMINISTRATOR: { + "read_pump_status", + "read_safety_status", + "read_audit_logs", + "emergency_stop", + "clear_emergency_stop", + "view_alerts", + "configure_safety_limits", + "manage_pump_configuration", + "view_system_metrics", + "manage_users", + "configure_system", + "access_all_stations" + } + } + + def has_permission(self, role: UserRole, permission: str) -> bool: + """Check if a role has the specified permission.""" + permissions = self.role_permissions.get(role, set()) + return permission in permissions + + def get_allowed_actions(self, role: UserRole) -> List[str]: + """Get all allowed actions for a role.""" + return list(self.role_permissions.get(role, set())) + + def can_access_station(self, role: UserRole, station_id: str) -> bool: + """ + Check if user can access a specific station. + + Administrators can access all stations, others may have station-specific + permissions (to be implemented with database integration). + """ + if role == UserRole.ADMINISTRATOR: + return True + + # For now, all authenticated users can access all stations + # In production, this would check station-specific permissions + return True + + +class SecurityManager: + """ + Main security manager that coordinates authentication and authorization. + + Provides a unified interface for security operations and compliance logging. + """ + + def __init__(self, audit_logger=None): + self.auth_manager = AuthenticationManager( + secret_key=settings.jwt_secret_key, + token_expire_minutes=settings.jwt_token_expire_minutes + ) + self.authz_manager = AuthorizationManager() + self.audit_logger = audit_logger + + def authenticate(self, username: str, password: str) -> Optional[str]: + """Authenticate user and return JWT token if successful.""" + user = self.auth_manager.authenticate_user(username, password) + if user: + token = self.auth_manager.create_access_token(user) + + # Log successful authentication + if self.audit_logger: + self.audit_logger.log( + event_type="USER_AUTHENTICATION", + severity="INFO", + user_id=user.user_id, + event_data={ + "username": username, + "role": user.role.value, + "result": "SUCCESS" + } + ) + + return token + + # Log failed authentication + if self.audit_logger: + self.audit_logger.log( + event_type="USER_AUTHENTICATION", + severity="WARNING", + event_data={ + "username": username, + "result": "FAILURE" + } + ) + + return None + + def verify_access_token(self, token: str) -> Optional[TokenData]: + """Verify JWT token and return token data if valid.""" + return self.auth_manager.verify_token(token) + + def check_permission(self, token_data: TokenData, permission: str) -> bool: + """Check if user has the specified permission.""" + return self.authz_manager.has_permission(token_data.role, permission) + + def can_access_station(self, token_data: TokenData, station_id: str) -> bool: + """Check if user can access the specified station.""" + return self.authz_manager.can_access_station(token_data.role, station_id) + + def get_user_permissions(self, token_data: TokenData) -> List[str]: + """Get all permissions for the user.""" + return self.authz_manager.get_allowed_actions(token_data.role) + + +# Global security manager instance +security_manager: Optional[SecurityManager] = None + + +def get_security_manager() -> SecurityManager: + """Get or create the global security manager instance.""" + global security_manager + if security_manager is None: + security_manager = SecurityManager() + return security_manager \ No newline at end of file diff --git a/src/core/tls_manager.py b/src/core/tls_manager.py new file mode 100644 index 0000000..1a13a07 --- /dev/null +++ b/src/core/tls_manager.py @@ -0,0 +1,304 @@ +""" +TLS/SSL Manager for Calejo Control Adapter. + +Provides certificate management and TLS/SSL configuration for secure communications +in compliance with IEC 62443, ISO 27001, and NIS2 Directive requirements. +""" + +import os +import ssl +from typing import Optional, Tuple +from pathlib import Path +import structlog +from datetime import datetime, timedelta, timezone + +from config.settings import settings + +logger = structlog.get_logger() + + +class TLSManager: + """ + Manages TLS/SSL certificates and secure communications. + + Provides certificate validation, rotation, and secure context creation + for all protocol servers. + """ + + def __init__(self): + self.cert_path = settings.tls_cert_path + self.key_path = settings.tls_key_path + self.tls_enabled = settings.tls_enabled + + # Certificate rotation tracking + self.cert_expiry_dates: dict[str, datetime] = {} + + # Validate certificates on initialization + if self.tls_enabled: + self._validate_certificates() + + def _validate_certificates(self) -> bool: + """ + Validate TLS certificates exist and are valid. + + Returns: + bool: True if certificates are valid, False otherwise + """ + if not self.cert_path or not self.key_path: + logger.warning( + "tls_certificates_missing", + cert_path=self.cert_path, + key_path=self.key_path + ) + return False + + # Check if certificate files exist + cert_file = Path(self.cert_path) + key_file = Path(self.key_path) + + if not cert_file.exists(): + logger.error("tls_certificate_file_missing", path=self.cert_path) + return False + + if not key_file.exists(): + logger.error("tls_key_file_missing", path=self.key_path) + return False + + # Validate certificate format and expiry + try: + expiry_date = self._get_certificate_expiry(self.cert_path) + self.cert_expiry_dates[self.cert_path] = expiry_date + + # Check if certificate is expired or expiring soon + now = datetime.now(timezone.utc) + days_until_expiry = (expiry_date - now).days + + if days_until_expiry < 0: + logger.error("tls_certificate_expired", expiry_date=expiry_date.isoformat()) + return False + elif days_until_expiry < 30: + logger.warning( + "tls_certificate_expiring_soon", + expiry_date=expiry_date.isoformat(), + days_remaining=days_until_expiry + ) + else: + logger.info( + "tls_certificate_valid", + expiry_date=expiry_date.isoformat(), + days_remaining=days_until_expiry + ) + + return True + + except Exception as e: + logger.error("tls_certificate_validation_failed", error=str(e)) + return False + + def _get_certificate_expiry(self, cert_path: str) -> datetime: + """ + Get certificate expiry date. + + Args: + cert_path: Path to certificate file + + Returns: + datetime: Certificate expiry date + + Raises: + Exception: If certificate cannot be parsed + """ + import OpenSSL + + with open(cert_path, 'rb') as f: + cert_data = f.read() + + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_data) + expiry_bytes = cert.get_notAfter() + + # Parse ASN.1 time format (YYYYMMDDHHMMSSZ) + expiry_str = expiry_bytes.decode('ascii') + expiry_date = datetime.strptime(expiry_str, '%Y%m%d%H%M%SZ') + + # Make timezone aware + return expiry_date.replace(tzinfo=timezone.utc) + + def create_ssl_context(self) -> Optional[ssl.SSLContext]: + """ + Create SSL context for secure communications. + + Returns: + Optional[ssl.SSLContext]: SSL context if TLS is enabled and valid, + None otherwise + """ + if not self.tls_enabled: + logger.info("tls_disabled") + return None + + if not self._validate_certificates(): + logger.error("cannot_create_ssl_context_invalid_certificates") + return None + + try: + # Create SSL context with secure defaults + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + + # Load certificate and key + context.load_cert_chain(self.cert_path, self.key_path) + + # Configure secure cipher suites + context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS') + + # Enable certificate verification + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + # Set minimum TLS version (TLS 1.2 or higher) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + logger.info("ssl_context_created_successfully") + return context + + except Exception as e: + logger.error("failed_to_create_ssl_context", error=str(e)) + return None + + def get_rest_api_ssl_config(self) -> Optional[Tuple[str, str]]: + """ + Get SSL configuration for REST API server. + + Returns: + Optional[Tuple[str, str]]: Tuple of (certfile, keyfile) paths if TLS is enabled, + None otherwise + """ + if not self.tls_enabled: + return None + + if not self._validate_certificates(): + return None + + return (self.cert_path, self.key_path) + + def check_certificate_rotation(self) -> bool: + """ + Check if certificates need rotation. + + Returns: + bool: True if certificates need rotation, False otherwise + """ + if not self.tls_enabled: + return False + + for cert_path, expiry_date in self.cert_expiry_dates.items(): + now = datetime.now(timezone.utc) + days_until_expiry = (expiry_date - now).days + + # Rotate if certificate expires in less than 7 days + if days_until_expiry < 7: + logger.warning( + "certificate_needs_rotation", + cert_path=cert_path, + expiry_date=expiry_date.isoformat(), + days_remaining=days_until_expiry + ) + return True + + return False + + def generate_self_signed_certificate(self, output_dir: str, common_name: str = "calejo-control.local") -> bool: + """ + Generate self-signed certificate for development/testing. + + Args: + output_dir: Directory to save certificate files + common_name: Common name for the certificate + + Returns: + bool: True if certificate generation succeeded, False otherwise + """ + try: + from cryptography import x509 + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.backends import default_backend + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + + # Generate certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Calejo Control"), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Development"), + ]) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.DNSName("127.0.0.1"), + ]), + critical=False, + ) + .sign(private_key, hashes.SHA256(), default_backend()) + ) + + # Write private key + key_path = os.path.join(output_dir, "key.pem") + with open(key_path, "wb") as f: + f.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + # Write certificate + cert_path = os.path.join(output_dir, "cert.pem") + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + logger.info( + "self_signed_certificate_generated", + cert_path=cert_path, + key_path=key_path, + common_name=common_name + ) + + # Update settings + self.cert_path = cert_path + self.key_path = key_path + + return True + + except Exception as e: + logger.error("failed_to_generate_self_signed_certificate", error=str(e)) + return False + + +# Global TLS manager instance +tls_manager: Optional[TLSManager] = None + + +def get_tls_manager() -> TLSManager: + """Get or create the global TLS manager instance.""" + global tls_manager + if tls_manager is None: + tls_manager = TLSManager() + return tls_manager \ No newline at end of file diff --git a/src/protocols/rest_api.py b/src/protocols/rest_api.py index 97236e3..054c847 100644 --- a/src/protocols/rest_api.py +++ b/src/protocols/rest_api.py @@ -7,12 +7,17 @@ Provides REST endpoints for emergency stop, status monitoring, and setpoint acce from typing import Optional, Dict, Any from datetime import datetime import structlog -from fastapi import FastAPI, HTTPException, status, Depends +from fastapi import FastAPI, HTTPException, status, Depends, Request from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from src.core.setpoint_manager import SetpointManager from src.core.emergency_stop import EmergencyStopManager +from src.core.security import ( + SecurityManager, TokenData, UserRole, get_security_manager +) +from src.core.tls_manager import get_tls_manager logger = structlog.get_logger() @@ -20,6 +25,64 @@ logger = structlog.get_logger() security = HTTPBearer() +class LoginRequest(BaseModel): + """Request model for user login.""" + username: str + password: str + + +class LoginResponse(BaseModel): + """Response model for successful login.""" + access_token: str + token_type: str = "bearer" + expires_in: int + user_id: str + username: str + role: str + permissions: list[str] + + +class UserInfoResponse(BaseModel): + """Response model for user information.""" + user_id: str + username: str + email: str + role: str + permissions: list[str] + last_login: Optional[str] + + +def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), + security_manager: SecurityManager = Depends(get_security_manager) +) -> TokenData: + """Dependency to get current user from JWT token.""" + token = credentials.credentials + token_data = security_manager.verify_access_token(token) + if not token_data: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return token_data + + +def require_permission(permission: str): + """Dependency factory to require specific permission.""" + def permission_dependency( + token_data: TokenData = Depends(get_current_user), + security_manager: SecurityManager = Depends(get_security_manager) + ): + if not security_manager.check_permission(token_data, permission): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Insufficient permissions. Required: {permission}" + ) + return token_data + return permission_dependency + + class EmergencyStopRequest(BaseModel): """Request model for emergency stop.""" triggered_by: str @@ -63,7 +126,18 @@ class RESTAPIServer: self.app = FastAPI( title="Calejo Control API", version="2.0", - description="REST API for Calejo Control Adapter with safety framework" + description="REST API for Calejo Control Adapter with safety framework", + docs_url="/docs", + redoc_url="/redoc" + ) + + # Add CORS middleware + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # In production, restrict to specific origins + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) self._setup_routes() @@ -77,7 +151,8 @@ class RESTAPIServer: return { "name": "Calejo Control API", "version": "2.0", - "status": "operational" + "status": "operational", + "authentication_required": True } @self.app.get("/health", summary="Health Check", tags=["General"]) @@ -88,6 +163,69 @@ class RESTAPIServer: "timestamp": datetime.now().isoformat() } + # Authentication endpoints (no authentication required) + @self.app.post( + "/api/v1/auth/login", + summary="User Login", + tags=["Authentication"], + response_model=LoginResponse + ) + async def login( + request: LoginRequest, + security_manager: SecurityManager = Depends(get_security_manager) + ): + """Authenticate user and return JWT token.""" + token = security_manager.authenticate(request.username, request.password) + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid username or password" + ) + + # Get user permissions + token_data = security_manager.verify_access_token(token) + permissions = security_manager.get_user_permissions(token_data) + + return LoginResponse( + access_token=token, + token_type="bearer", + expires_in=60, # 60 minutes + user_id=token_data.user_id, + username=token_data.username, + role=token_data.role.value, + permissions=permissions + ) + + @self.app.get( + "/api/v1/auth/me", + summary="Get Current User Info", + tags=["Authentication"], + response_model=UserInfoResponse + ) + async def get_current_user_info( + token_data: TokenData = Depends(get_current_user), + security_manager: SecurityManager = Depends(get_security_manager) + ): + """Get information about the currently authenticated user.""" + # Get user from auth manager + user = security_manager.auth_manager.users.get(token_data.user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + permissions = security_manager.get_user_permissions(token_data) + + return UserInfoResponse( + user_id=user.user_id, + username=user.username, + email=user.email, + role=user.role.value, + permissions=permissions, + last_login=user.last_login.isoformat() if user.last_login else None + ) + @self.app.get( "/api/v1/setpoints", summary="Get All Setpoints", @@ -95,12 +233,14 @@ class RESTAPIServer: response_model=Dict[str, Dict[str, Optional[float]]] ) async def get_all_setpoints( - credentials: HTTPAuthorizationCredentials = Depends(security) + token_data: TokenData = Depends(require_permission("read_pump_status")) ): """ Get current setpoints for all pumps. Returns dictionary mapping station_id -> pump_id -> setpoint_hz + + Requires permission: read_pump_status """ try: setpoints = self.setpoint_manager.get_all_current_setpoints() @@ -121,10 +261,22 @@ class RESTAPIServer: async def get_pump_setpoint( station_id: str, pump_id: str, - credentials: HTTPAuthorizationCredentials = Depends(security) + token_data: TokenData = Depends(require_permission("read_pump_status")), + security_manager: SecurityManager = Depends(get_security_manager) ): - """Get current setpoint for a specific pump.""" + """ + Get current setpoint for a specific pump. + + Requires permission: read_pump_status + """ try: + # Check if user can access this station + if not security_manager.can_access_station(token_data, station_id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Access denied to station {station_id}" + ) + setpoint = self.setpoint_manager.get_current_setpoint(station_id, pump_id) # Get pump info for control type @@ -147,6 +299,8 @@ class RESTAPIServer: timestamp=datetime.now().isoformat() ) + except HTTPException: + raise except Exception as e: logger.error( "failed_to_get_pump_setpoint", @@ -167,7 +321,8 @@ class RESTAPIServer: ) async def trigger_emergency_stop( request: EmergencyStopRequest, - credentials: HTTPAuthorizationCredentials = Depends(security) + token_data: TokenData = Depends(require_permission("emergency_stop")), + security_manager: SecurityManager = Depends(get_security_manager) ): """ Trigger emergency stop ("big red button"). @@ -176,15 +331,27 @@ class RESTAPIServer: - If station_id and pump_id provided: Stop single pump - If station_id only: Stop all pumps at station - If neither: Stop ALL pumps system-wide + + Requires permission: emergency_stop """ try: + # Check station access if station_id is provided + if request.station_id and not security_manager.can_access_station(token_data, request.station_id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Access denied to station {request.station_id}" + ) + + # Use authenticated user as triggered_by + triggered_by = token_data.username + if request.station_id and request.pump_id: # Single pump stop result = self.emergency_stop_manager.emergency_stop_pump( station_id=request.station_id, pump_id=request.pump_id, reason=request.reason, - user_id=request.triggered_by + user_id=triggered_by ) scope = f"pump {request.station_id}/{request.pump_id}" elif request.station_id: @@ -192,14 +359,14 @@ class RESTAPIServer: result = self.emergency_stop_manager.emergency_stop_station( station_id=request.station_id, reason=request.reason, - user_id=request.triggered_by + user_id=triggered_by ) scope = f"station {request.station_id}" else: # System-wide stop result = self.emergency_stop_manager.emergency_stop_system( reason=request.reason, - user_id=request.triggered_by + user_id=triggered_by ) scope = "system-wide" @@ -208,7 +375,7 @@ class RESTAPIServer: "status": "emergency_stop_triggered", "scope": scope, "reason": request.reason, - "triggered_by": request.triggered_by, + "triggered_by": triggered_by, "timestamp": datetime.now().isoformat() } else: @@ -217,6 +384,8 @@ class RESTAPIServer: detail="Failed to trigger emergency stop" ) + except HTTPException: + raise except Exception as e: logger.error("failed_to_trigger_emergency_stop", error=str(e)) raise HTTPException( @@ -231,19 +400,26 @@ class RESTAPIServer: ) async def clear_emergency_stop( request: EmergencyStopClearRequest, - credentials: HTTPAuthorizationCredentials = Depends(security) + token_data: TokenData = Depends(require_permission("clear_emergency_stop")) ): - """Clear all active emergency stops.""" + """ + Clear all active emergency stops. + + Requires permission: clear_emergency_stop + """ try: + # Use authenticated user as cleared_by + cleared_by = token_data.username + # Clear system-wide emergency stop self.emergency_stop_manager.clear_emergency_stop_system( reason=request.notes, - user_id=request.cleared_by + user_id=cleared_by ) return { "status": "emergency_stop_cleared", - "cleared_by": request.cleared_by, + "cleared_by": cleared_by, "notes": request.notes, "timestamp": datetime.now().isoformat() } @@ -261,9 +437,13 @@ class RESTAPIServer: tags=["Emergency Stop"] ) async def get_emergency_stop_status( - credentials: HTTPAuthorizationCredentials = Depends(security) + token_data: TokenData = Depends(require_permission("read_safety_status")) ): - """Check if any emergency stops are active.""" + """ + Check if any emergency stops are active. + + Requires permission: read_safety_status + """ try: # Check system-wide emergency stop system_stop = self.emergency_stop_manager.system_emergency_stop @@ -291,18 +471,33 @@ class RESTAPIServer: """Start the REST API server.""" import uvicorn + # Get TLS configuration + tls_manager = get_tls_manager() + ssl_config = tls_manager.get_rest_api_ssl_config() + logger.info( "rest_api_server_starting", host=self.host, - port=self.port + port=self.port, + tls_enabled=ssl_config is not None ) - config = uvicorn.Config( - self.app, - host=self.host, - port=self.port, - log_level="info" - ) + config_kwargs = { + "app": self.app, + "host": self.host, + "port": self.port, + "log_level": "info" + } + + # Add SSL configuration if available + if ssl_config: + certfile, keyfile = ssl_config + config_kwargs.update({ + "ssl_certfile": certfile, + "ssl_keyfile": keyfile + }) + + config = uvicorn.Config(**config_kwargs) server = uvicorn.Server(config) await server.serve() diff --git a/tests/unit/test_compliance_audit.py b/tests/unit/test_compliance_audit.py new file mode 100644 index 0000000..85a6758 --- /dev/null +++ b/tests/unit/test_compliance_audit.py @@ -0,0 +1,368 @@ +""" +Unit tests for compliance audit components. +""" + +import pytest +from unittest.mock import Mock, patch +from datetime import datetime, timezone + +from src.core.compliance_audit import ( + ComplianceAuditLogger, AuditEventType, AuditSeverity +) +from config.settings import settings + + +class TestComplianceAuditLogger: + """Test cases for ComplianceAuditLogger.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_db_client = Mock() + self.audit_logger = ComplianceAuditLogger(self.mock_db_client) + + # Mock settings + self.original_audit_enabled = settings.audit_log_enabled + settings.audit_log_enabled = True + + def teardown_method(self): + """Clean up test fixtures.""" + # Restore original settings + settings.audit_log_enabled = self.original_audit_enabled + + def test_initialization(self): + """Test initialization of ComplianceAuditLogger.""" + assert self.audit_logger.db_client == self.mock_db_client + assert self.audit_logger.logger is not None + + def test_log_compliance_event_success(self): + """Test successful logging of compliance event.""" + event_type = AuditEventType.USER_LOGIN + severity = AuditSeverity.LOW + user_id = "test_user" + ip_address = "192.168.1.100" + + with patch.object(self.audit_logger.logger, 'info') as mock_log: + self.audit_logger.log_compliance_event( + event_type=event_type, + severity=severity, + user_id=user_id, + ip_address=ip_address, + action="login", + resource="system", + result="success" + ) + + # Verify structured logging + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == "user_login" + assert call_args["severity"] == "low" + assert call_args["user_id"] == "test_user" + assert call_args["ip_address"] == "192.168.1.100" + + # Verify database logging + self.mock_db_client.execute.assert_called_once() + + def test_log_compliance_event_database_disabled(self): + """Test logging when database audit is disabled.""" + settings.audit_log_enabled = False + + with patch.object(self.audit_logger.logger, 'info') as mock_log: + self.audit_logger.log_compliance_event( + event_type=AuditEventType.USER_LOGIN, + severity=AuditSeverity.LOW, + user_id="test_user" + ) + + # Verify structured logging still occurs + mock_log.assert_called_once() + + # Verify database logging is skipped + self.mock_db_client.execute.assert_not_called() + + def test_log_compliance_event_database_error(self): + """Test logging when database operation fails.""" + # Mock database error + self.mock_db_client.execute.side_effect = Exception("Database error") + + with patch.object(self.audit_logger.logger, 'error') as mock_error_log: + self.audit_logger.log_compliance_event( + event_type=AuditEventType.USER_LOGIN, + severity=AuditSeverity.LOW, + user_id="test_user" + ) + + # Verify error logging + mock_error_log.assert_called_once() + call_args = mock_error_log.call_args[1] + assert "Database error" in call_args["error"] + assert call_args["event_type"] == "user_login" + + def test_log_user_authentication_success(self): + """Test logging successful user authentication.""" + user_id = "test_user" + ip_address = "192.168.1.100" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_user_authentication( + user_id=user_id, + result="success", + ip_address=ip_address + ) + + # Verify the correct event type and severity + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.USER_LOGIN + assert call_args["severity"] == AuditSeverity.LOW + assert call_args["user_id"] == user_id + assert call_args["ip_address"] == ip_address + assert call_args["result"] == "success" + + def test_log_user_authentication_failure(self): + """Test logging failed user authentication.""" + user_id = "test_user" + ip_address = "192.168.1.100" + reason = "Invalid credentials" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_user_authentication( + user_id=user_id, + result="failure", + ip_address=ip_address, + reason=reason + ) + + # Verify the correct event type and severity + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.INVALID_AUTHENTICATION + assert call_args["severity"] == AuditSeverity.HIGH + assert call_args["reason"] == reason + + def test_log_access_control_granted(self): + """Test logging granted access control.""" + user_id = "test_user" + action = "read" + resource = "station_data" + ip_address = "192.168.1.100" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_access_control( + user_id=user_id, + action=action, + resource=resource, + result="granted", + ip_address=ip_address + ) + + # Verify the correct event type + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.DATA_READ + assert call_args["severity"] == AuditSeverity.MEDIUM + + def test_log_access_control_denied(self): + """Test logging denied access control.""" + user_id = "test_user" + action = "write" + resource = "system_config" + ip_address = "192.168.1.100" + reason = "Insufficient permissions" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_access_control( + user_id=user_id, + action=action, + resource=resource, + result="denied", + ip_address=ip_address, + reason=reason + ) + + # Verify the correct event type and severity + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.ACCESS_DENIED + assert call_args["severity"] == AuditSeverity.HIGH + assert call_args["reason"] == reason + + def test_log_control_operation_setpoint_change(self): + """Test logging setpoint change operation.""" + user_id = "operator_user" + station_id = "station_001" + action = "setpoint_change" + resource = "pump_speed" + ip_address = "192.168.1.100" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_control_operation( + user_id=user_id, + station_id=station_id, + pump_id=None, + action=action, + resource=resource, + result="success", + ip_address=ip_address + ) + + # Verify the correct event type and severity + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.SETPOINT_CHANGED + assert call_args["severity"] == AuditSeverity.HIGH + assert call_args["station_id"] == station_id + + def test_log_control_operation_emergency_stop(self): + """Test logging emergency stop operation.""" + user_id = "operator_user" + station_id = "station_001" + action = "emergency_stop" + resource = "system" + ip_address = "192.168.1.100" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_control_operation( + user_id=user_id, + station_id=station_id, + pump_id="pump_001", + action=action, + resource=resource, + result="success", + ip_address=ip_address + ) + + # Verify the correct event type and severity + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.EMERGENCY_STOP_ACTIVATED + assert call_args["severity"] == AuditSeverity.CRITICAL + assert call_args["pump_id"] == "pump_001" + + def test_log_security_event(self): + """Test logging security events.""" + event_type = AuditEventType.CERTIFICATE_EXPIRED + severity = AuditSeverity.HIGH + user_id = "system" + station_id = "station_001" + reason = "TLS certificate expired" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + self.audit_logger.log_security_event( + event_type=event_type, + severity=severity, + user_id=user_id, + station_id=station_id, + reason=reason + ) + + # Verify the correct parameters + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == event_type + assert call_args["severity"] == severity + assert call_args["user_id"] == user_id + assert call_args["station_id"] == station_id + assert call_args["reason"] == reason + + def test_get_audit_trail_success(self): + """Test successful retrieval of audit trail.""" + # Mock database result + mock_result = [ + {"event_type": "user_login", "user_id": "test_user"}, + {"event_type": "data_read", "user_id": "test_user"} + ] + self.mock_db_client.fetch_all.return_value = mock_result + + start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 31, tzinfo=timezone.utc) + user_id = "test_user" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + result = self.audit_logger.get_audit_trail( + start_time=start_time, + end_time=end_time, + user_id=user_id + ) + + # Verify database query + self.mock_db_client.fetch_all.assert_called_once() + + # Verify audit trail access logging + mock_log.assert_called_once() + call_args = mock_log.call_args[1] + assert call_args["event_type"] == AuditEventType.AUDIT_LOG_ACCESSED + assert call_args["user_id"] == user_id + + # Verify result + assert result == mock_result + + def test_get_audit_trail_error(self): + """Test audit trail retrieval with error.""" + # Mock database error + self.mock_db_client.fetch_all.side_effect = Exception("Query failed") + + with patch.object(self.audit_logger.logger, 'error') as mock_error_log: + result = self.audit_logger.get_audit_trail() + + # Verify error logging + mock_error_log.assert_called_once() + + # Verify empty result + assert result == [] + + def test_generate_compliance_report_success(self): + """Test successful compliance report generation.""" + # Mock database result + mock_result = [ + {"event_type": "user_login", "severity": "low", "count": 10, "unique_users": 2}, + {"event_type": "access_denied", "severity": "high", "count": 2, "unique_users": 1} + ] + self.mock_db_client.fetch_all.return_value = mock_result + + start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 31, tzinfo=timezone.utc) + compliance_standard = "IEC_62443" + + with patch.object(self.audit_logger, 'log_compliance_event') as mock_log: + report = self.audit_logger.generate_compliance_report( + start_time=start_time, + end_time=end_time, + compliance_standard=compliance_standard + ) + + # Verify database query + self.mock_db_client.fetch_all.assert_called_once() + + # Verify compliance report logging + mock_log.assert_called_once() + + # Verify report structure + assert report["compliance_standard"] == compliance_standard + assert report["summary"]["total_events"] == 12 + assert report["summary"]["unique_users"] == 3 + assert report["summary"]["event_types"] == 2 + assert report["events_by_type"] == mock_result + + def test_generate_compliance_report_error(self): + """Test compliance report generation with error.""" + # Mock database error + self.mock_db_client.fetch_all.side_effect = Exception("Report generation failed") + + start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 31, tzinfo=timezone.utc) + compliance_standard = "ISO_27001" + + with patch.object(self.audit_logger.logger, 'error') as mock_error_log: + report = self.audit_logger.generate_compliance_report( + start_time=start_time, + end_time=end_time, + compliance_standard=compliance_standard + ) + + # Verify error logging + mock_error_log.assert_called_once() + + # Verify error result + assert "error" in report + assert "Report generation failed" in report["error"] \ No newline at end of file diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py new file mode 100644 index 0000000..17cb0f4 --- /dev/null +++ b/tests/unit/test_security.py @@ -0,0 +1,373 @@ +""" +Unit tests for security layer components. +""" + +import pytest +from unittest.mock import Mock, patch +from datetime import datetime, timedelta, timezone +import jwt +import bcrypt + +from src.core.security import ( + AuthenticationManager, + AuthorizationManager, + SecurityManager, + UserRole, + User, + TokenData +) +from config.settings import settings + + +class TestAuthenticationManager: + """Test cases for AuthenticationManager.""" + + def setup_method(self): + """Set up test fixtures.""" + self.auth_manager = AuthenticationManager( + secret_key="test-secret-key", + token_expire_minutes=60 + ) + + def test_hash_password(self): + """Test password hashing.""" + password = "testpassword123" + hashed = self.auth_manager.hash_password(password) + + # Verify the hash is valid + assert bcrypt.checkpw(password.encode('utf-8'), hashed.encode('utf-8')) + + # Verify different passwords produce different hashes + different_password = "differentpassword" + different_hashed = self.auth_manager.hash_password(different_password) + assert hashed != different_hashed + + def test_verify_password(self): + """Test password verification.""" + password = "testpassword123" + hashed = self.auth_manager.hash_password(password) + + # Test correct password + assert self.auth_manager.verify_password(password, hashed) + + # Test incorrect password + assert not self.auth_manager.verify_password("wrongpassword", hashed) + + def test_create_user(self): + """Test user creation.""" + user = self.auth_manager.create_user( + user_id="test_user_001", + username="testuser", + email="test@example.com", + role=UserRole.OPERATOR, + password="testpassword123" + ) + + assert user.user_id == "test_user_001" + assert user.username == "testuser" + assert user.email == "test@example.com" + assert user.role == UserRole.OPERATOR + assert user.active is True + assert user.created_at is not None + + # Verify password was hashed and stored + assert "test_user_001" in self.auth_manager.password_hashes + hashed_password = self.auth_manager.password_hashes["test_user_001"] + assert bcrypt.checkpw("testpassword123".encode('utf-8'), hashed_password.encode('utf-8')) + + def test_create_user_duplicate_id(self): + """Test creating user with duplicate ID.""" + self.auth_manager.create_user( + user_id="duplicate_user", + username="user1", + email="user1@example.com", + role=UserRole.OPERATOR, + password="password1" + ) + + with pytest.raises(ValueError, match="User with ID duplicate_user already exists"): + self.auth_manager.create_user( + user_id="duplicate_user", + username="user2", + email="user2@example.com", + role=UserRole.ENGINEER, + password="password2" + ) + + def test_authenticate_user_success(self): + """Test successful user authentication.""" + # Create a test user + self.auth_manager.create_user( + user_id="auth_test_user", + username="authuser", + email="auth@example.com", + role=UserRole.ENGINEER, + password="authpassword123" + ) + + # Test authentication + user = self.auth_manager.authenticate_user("authuser", "authpassword123") + + assert user is not None + assert user.user_id == "auth_test_user" + assert user.username == "authuser" + assert user.role == UserRole.ENGINEER + assert user.last_login is not None + + def test_authenticate_user_wrong_password(self): + """Test authentication with wrong password.""" + self.auth_manager.create_user( + user_id="wrong_pass_user", + username="wrongpass", + email="wrong@example.com", + role=UserRole.OPERATOR, + password="correctpassword" + ) + + user = self.auth_manager.authenticate_user("wrongpass", "wrongpassword") + assert user is None + + def test_authenticate_user_not_found(self): + """Test authentication with non-existent user.""" + user = self.auth_manager.authenticate_user("nonexistent", "password") + assert user is None + + def test_authenticate_user_inactive(self): + """Test authentication with inactive user.""" + # Create inactive user + user = User( + user_id="inactive_user", + username="inactive", + email="inactive@example.com", + role=UserRole.OPERATOR, + active=False, + created_at=datetime.now(timezone.utc) + ) + self.auth_manager.users["inactive_user"] = user + self.auth_manager.password_hashes["inactive_user"] = self.auth_manager.hash_password("password") + + result = self.auth_manager.authenticate_user("inactive", "password") + assert result is None + + def test_create_access_token(self): + """Test JWT token creation.""" + user = User( + user_id="token_user", + username="tokenuser", + email="token@example.com", + role=UserRole.ADMINISTRATOR, + created_at=datetime.now(timezone.utc) + ) + + token = self.auth_manager.create_access_token(user) + + # Verify token can be decoded + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert payload["user_id"] == "token_user" + assert payload["username"] == "tokenuser" + assert payload["role"] == "administrator" + assert "exp" in payload + + def test_verify_token_success(self): + """Test successful token verification.""" + # Create user and token + user = User( + user_id="verify_user", + username="verifyuser", + email="verify@example.com", + role=UserRole.OPERATOR, + created_at=datetime.now(timezone.utc) + ) + self.auth_manager.users["verify_user"] = user + + token = self.auth_manager.create_access_token(user) + token_data = self.auth_manager.verify_token(token) + + assert token_data is not None + assert token_data.user_id == "verify_user" + assert token_data.username == "verifyuser" + assert token_data.role == UserRole.OPERATOR + + def test_verify_token_expired(self): + """Test verification of expired token.""" + # Create expired token + expired_time = datetime.now(timezone.utc) - timedelta(hours=1) + token_data = TokenData( + user_id="expired_user", + username="expireduser", + role=UserRole.OPERATOR, + exp=expired_time + ) + + token = jwt.encode( + token_data.dict(), + "test-secret-key", + algorithm="HS256" + ) + + result = self.auth_manager.verify_token(token) + assert result is None + + def test_verify_token_invalid_signature(self): + """Test verification of token with invalid signature.""" + # Create token with wrong secret + user = User( + user_id="invalid_user", + username="invaliduser", + email="invalid@example.com", + role=UserRole.OPERATOR, + created_at=datetime.now(timezone.utc) + ) + + token = jwt.encode( + { + "user_id": "invalid_user", + "username": "invaliduser", + "role": "operator", + "exp": datetime.now(timezone.utc) + timedelta(minutes=60) + }, + "wrong-secret-key", + algorithm="HS256" + ) + + result = self.auth_manager.verify_token(token) + assert result is None + + +class TestAuthorizationManager: + """Test cases for AuthorizationManager.""" + + def setup_method(self): + """Set up test fixtures.""" + self.authz_manager = AuthorizationManager() + + def test_has_permission_read_only(self): + """Test permissions for read-only role.""" + assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_pump_status") + assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_safety_status") + assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_audit_logs") + + # Should not have write permissions + assert not self.authz_manager.has_permission(UserRole.READ_ONLY, "emergency_stop") + assert not self.authz_manager.has_permission(UserRole.READ_ONLY, "configure_safety_limits") + + def test_has_permission_operator(self): + """Test permissions for operator role.""" + assert self.authz_manager.has_permission(UserRole.OPERATOR, "read_pump_status") + assert self.authz_manager.has_permission(UserRole.OPERATOR, "emergency_stop") + assert self.authz_manager.has_permission(UserRole.OPERATOR, "clear_emergency_stop") + + # Should not have configuration permissions + assert not self.authz_manager.has_permission(UserRole.OPERATOR, "configure_safety_limits") + assert not self.authz_manager.has_permission(UserRole.OPERATOR, "manage_users") + + def test_has_permission_engineer(self): + """Test permissions for engineer role.""" + assert self.authz_manager.has_permission(UserRole.ENGINEER, "read_pump_status") + assert self.authz_manager.has_permission(UserRole.ENGINEER, "emergency_stop") + assert self.authz_manager.has_permission(UserRole.ENGINEER, "configure_safety_limits") + assert self.authz_manager.has_permission(UserRole.ENGINEER, "manage_pump_configuration") + + # Should not have administrative permissions + assert not self.authz_manager.has_permission(UserRole.ENGINEER, "manage_users") + + def test_has_permission_administrator(self): + """Test permissions for administrator role.""" + # Administrator should have all permissions + all_permissions = [ + "read_pump_status", "read_safety_status", "read_audit_logs", + "emergency_stop", "clear_emergency_stop", "view_alerts", + "configure_safety_limits", "manage_pump_configuration", + "view_system_metrics", "manage_users", "configure_system", + "access_all_stations" + ] + + for permission in all_permissions: + assert self.authz_manager.has_permission(UserRole.ADMINISTRATOR, permission) + + def test_has_permission_unknown_permission(self): + """Test unknown permission.""" + assert not self.authz_manager.has_permission(UserRole.ADMINISTRATOR, "unknown_permission") + + def test_get_allowed_actions(self): + """Test getting all allowed actions for a role.""" + operator_permissions = self.authz_manager.get_allowed_actions(UserRole.OPERATOR) + + assert "read_pump_status" in operator_permissions + assert "emergency_stop" in operator_permissions + assert "clear_emergency_stop" in operator_permissions + assert "configure_safety_limits" not in operator_permissions + + def test_can_access_station(self): + """Test station access control.""" + # Administrators can access all stations + assert self.authz_manager.can_access_station(UserRole.ADMINISTRATOR, "STATION_001") + + # Other roles can access all stations (for now) + assert self.authz_manager.can_access_station(UserRole.OPERATOR, "STATION_001") + assert self.authz_manager.can_access_station(UserRole.ENGINEER, "STATION_001") + assert self.authz_manager.can_access_station(UserRole.READ_ONLY, "STATION_001") + + +class TestSecurityManager: + """Test cases for SecurityManager.""" + + def setup_method(self): + """Set up test fixtures.""" + self.security_manager = SecurityManager() + + def test_authenticate_success(self): + """Test successful authentication.""" + # Default users are created during initialization + token = self.security_manager.authenticate("operator", "operator123") + + assert token is not None + + # Verify token is valid + token_data = self.security_manager.verify_access_token(token) + assert token_data is not None + assert token_data.username == "operator" + assert token_data.role == UserRole.OPERATOR + + def test_authenticate_failure(self): + """Test failed authentication.""" + token = self.security_manager.authenticate("nonexistent", "password") + assert token is None + + def test_check_permission(self): + """Test permission checking.""" + # Create token for operator + token = self.security_manager.authenticate("operator", "operator123") + token_data = self.security_manager.verify_access_token(token) + + # Operator should have emergency_stop permission + assert self.security_manager.check_permission(token_data, "emergency_stop") + + # Operator should not have manage_users permission + assert not self.security_manager.check_permission(token_data, "manage_users") + + def test_get_user_permissions(self): + """Test getting user permissions.""" + # Create token for engineer + token = self.security_manager.authenticate("engineer", "engineer123") + token_data = self.security_manager.verify_access_token(token) + + permissions = self.security_manager.get_user_permissions(token_data) + + # Engineer should have specific permissions + assert "configure_safety_limits" in permissions + assert "manage_pump_configuration" in permissions + assert "emergency_stop" in permissions + + # Engineer should not have administrative permissions + assert "manage_users" not in permissions + + def test_can_access_station(self): + """Test station access control.""" + # Create token for operator + token = self.security_manager.authenticate("operator", "operator123") + token_data = self.security_manager.verify_access_token(token) + + # Operator should be able to access any station + assert self.security_manager.can_access_station(token_data, "STATION_001") + assert self.security_manager.can_access_station(token_data, "STATION_002") \ No newline at end of file diff --git a/tests/unit/test_tls_manager.py b/tests/unit/test_tls_manager.py new file mode 100644 index 0000000..8060c9a --- /dev/null +++ b/tests/unit/test_tls_manager.py @@ -0,0 +1,243 @@ +""" +Unit tests for TLS manager components. +""" + +import pytest +import unittest.mock +from unittest.mock import Mock, patch, MagicMock +import tempfile +import os +import ssl +from pathlib import Path + +from src.core.tls_manager import TLSManager +from config.settings import settings + + +class TestTLSManager: + """Test cases for TLSManager.""" + + def setup_method(self): + """Set up test fixtures.""" + # Create temporary directory for test certificates + self.temp_dir = tempfile.mkdtemp() + self.cert_path = os.path.join(self.temp_dir, "test_cert.pem") + self.key_path = os.path.join(self.temp_dir, "test_key.pem") + + # Create dummy certificate files for testing + with open(self.cert_path, 'w') as f: + f.write("DUMMY CERTIFICATE") + with open(self.key_path, 'w') as f: + f.write("DUMMY PRIVATE KEY") + + # Mock settings + self.original_tls_enabled = settings.tls_enabled + self.original_cert_path = settings.tls_cert_path + self.original_key_path = settings.tls_key_path + + settings.tls_enabled = True + settings.tls_cert_path = self.cert_path + settings.tls_key_path = self.key_path + + def teardown_method(self): + """Clean up test fixtures.""" + # Restore original settings + settings.tls_enabled = self.original_tls_enabled + settings.tls_cert_path = self.original_cert_path + settings.tls_key_path = self.original_key_path + + # Clean up temporary directory + import shutil + shutil.rmtree(self.temp_dir) + + def test_initialization_with_valid_certificates(self): + """Test initialization with valid certificate files.""" + tls_manager = TLSManager() + + assert tls_manager.tls_enabled is True + assert tls_manager.cert_path == self.cert_path + assert tls_manager.key_path == self.key_path + + def test_initialization_with_missing_certificates(self): + """Test initialization with missing certificate files.""" + # Set invalid paths + settings.tls_cert_path = "/nonexistent/cert.pem" + settings.tls_key_path = "/nonexistent/key.pem" + + tls_manager = TLSManager() + + # Should still initialize but with validation failure + assert tls_manager.tls_enabled is True + assert tls_manager.cert_path == "/nonexistent/cert.pem" + + def test_initialization_with_tls_disabled(self): + """Test initialization with TLS disabled.""" + settings.tls_enabled = False + + tls_manager = TLSManager() + + assert tls_manager.tls_enabled is False + + @patch('src.core.tls_manager.TLSManager._get_certificate_expiry') + def test_validate_certificates_success(self, mock_get_expiry): + """Test successful certificate validation.""" + # Mock certificate expiry to be in the future + from datetime import datetime, timezone + future_date = datetime(2030, 12, 31, 23, 59, 59, tzinfo=timezone.utc) + mock_get_expiry.return_value = future_date + + tls_manager = TLSManager() + result = tls_manager._validate_certificates() + + assert result is True + + @patch('src.core.tls_manager.TLSManager._get_certificate_expiry') + def test_validate_certificates_expired(self, mock_get_expiry): + """Test validation of expired certificate.""" + # Mock certificate expiry to be in the past + from datetime import datetime, timezone + past_date = datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + mock_get_expiry.return_value = past_date + + tls_manager = TLSManager() + result = tls_manager._validate_certificates() + + assert result is False + + @patch('src.core.tls_manager.TLSManager._get_certificate_expiry') + def test_validate_certificates_parsing_error(self, mock_get_expiry): + """Test certificate validation with parsing error.""" + # Mock certificate parsing to raise an exception + mock_get_expiry.side_effect = Exception("Parse error") + + tls_manager = TLSManager() + result = tls_manager._validate_certificates() + + assert result is False + + @patch('src.core.tls_manager.ssl') + def test_create_ssl_context_success(self, mock_ssl): + """Test successful SSL context creation.""" + # Mock SSL context + mock_context = Mock() + mock_ssl.create_default_context.return_value = mock_context + + tls_manager = TLSManager() + + # Mock validation to succeed + with patch.object(tls_manager, '_validate_certificates', return_value=True): + result = tls_manager.create_ssl_context() + + assert result is mock_context + # Check that create_default_context was called with CLIENT_AUTH purpose + mock_ssl.create_default_context.assert_called_once() + call_args = mock_ssl.create_default_context.call_args + assert call_args[0][0] == mock_ssl.Purpose.CLIENT_AUTH + mock_context.load_cert_chain.assert_called_once_with(self.cert_path, self.key_path) + + def test_create_ssl_context_tls_disabled(self): + """Test SSL context creation with TLS disabled.""" + settings.tls_enabled = False + + tls_manager = TLSManager() + result = tls_manager.create_ssl_context() + + assert result is None + + def test_create_ssl_context_invalid_certificates(self): + """Test SSL context creation with invalid certificates.""" + tls_manager = TLSManager() + + # Mock validation to fail + with patch.object(tls_manager, '_validate_certificates', return_value=False): + result = tls_manager.create_ssl_context() + + assert result is None + + def test_get_rest_api_ssl_config_success(self): + """Test successful REST API SSL configuration.""" + tls_manager = TLSManager() + + # Mock validation to succeed + with patch.object(tls_manager, '_validate_certificates', return_value=True): + result = tls_manager.get_rest_api_ssl_config() + + assert result == (self.cert_path, self.key_path) + + def test_get_rest_api_ssl_config_tls_disabled(self): + """Test REST API SSL configuration with TLS disabled.""" + settings.tls_enabled = False + + tls_manager = TLSManager() + result = tls_manager.get_rest_api_ssl_config() + + assert result is None + + def test_get_rest_api_ssl_config_invalid_certificates(self): + """Test REST API SSL configuration with invalid certificates.""" + tls_manager = TLSManager() + + # Mock validation to fail + with patch.object(tls_manager, '_validate_certificates', return_value=False): + result = tls_manager.get_rest_api_ssl_config() + + assert result is None + + def test_check_certificate_rotation_needed(self): + """Test certificate rotation check when needed.""" + tls_manager = TLSManager() + + # Mock certificate with near expiry + from datetime import datetime, timezone, timedelta + near_expiry = datetime.now(timezone.utc) + timedelta(days=5) # Expires in 5 days + tls_manager.cert_expiry_dates[self.cert_path] = near_expiry + + result = tls_manager.check_certificate_rotation() + + assert result is True + + def test_check_certificate_rotation_not_needed(self): + """Test certificate rotation check when not needed.""" + tls_manager = TLSManager() + + # Mock certificate with distant expiry + from datetime import datetime, timezone, timedelta + distant_expiry = datetime.now(timezone.utc) + timedelta(days=365) # Expires in 1 year + tls_manager.cert_expiry_dates[self.cert_path] = distant_expiry + + result = tls_manager.check_certificate_rotation() + + assert result is False + + def test_check_certificate_rotation_tls_disabled(self): + """Test certificate rotation check with TLS disabled.""" + settings.tls_enabled = False + + tls_manager = TLSManager() + result = tls_manager.check_certificate_rotation() + + assert result is False + + @patch('src.core.tls_manager.os.makedirs') + def test_generate_self_signed_certificate_success(self, mock_makedirs): + """Test successful self-signed certificate generation.""" + tls_manager = TLSManager() + + # Mock the entire certificate generation to succeed + with patch('builtins.open', unittest.mock.mock_open()) as mock_file: + result = tls_manager.generate_self_signed_certificate(self.temp_dir) + + assert result is True + assert tls_manager.cert_path.endswith("cert.pem") + assert tls_manager.key_path.endswith("key.pem") + + @patch('src.core.tls_manager.os.makedirs') + def test_generate_self_signed_certificate_failure(self, mock_makedirs): + """Test self-signed certificate generation failure.""" + # Mock makedirs to raise an exception + mock_makedirs.side_effect = Exception("Permission denied") + + tls_manager = TLSManager() + result = tls_manager.generate_self_signed_certificate(self.temp_dir) + + assert result is False \ No newline at end of file