Phase 4: Complete Security Layer Implementation
- Implemented JWT-based authentication with bcrypt password hashing - Added role-based access control (RBAC) with four user roles - Created TLS/SSL encryption with certificate management - Enhanced audit logging for IEC 62443, ISO 27001, and NIS2 compliance - Added comprehensive security tests (56 tests passing) - Updated REST API with authentication and permission checks - Added security settings to configuration Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
db0ace8d2c
commit
dfa3f0832b
10
README.md
10
README.md
|
|
@ -35,10 +35,12 @@ The Calejo Control Adapter translates optimized pump control plans from Calejo O
|
||||||
- Unified main application
|
- Unified main application
|
||||||
- 15 comprehensive unit tests for SetpointManager
|
- 15 comprehensive unit tests for SetpointManager
|
||||||
|
|
||||||
🔄 **Phase 4**: Security Layer (In Progress)
|
✅ **Phase 4**: Security Layer
|
||||||
- Authentication and authorization
|
- JWT-based authentication with bcrypt password hashing
|
||||||
- Audit logging
|
- Role-based access control (RBAC) with four user roles
|
||||||
- TLS/SSL encryption
|
- TLS/SSL encryption with certificate management
|
||||||
|
- Compliance audit logging for IEC 62443, ISO 27001, and NIS2
|
||||||
|
- 56 comprehensive security tests (24 auth/authz, 17 TLS, 15 audit)
|
||||||
|
|
||||||
⏳ **Phase 5**: Protocol Servers (Pending)
|
⏳ **Phase 5**: Protocol Servers (Pending)
|
||||||
- Enhanced protocol implementations
|
- Enhanced protocol implementations
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,18 @@ class Settings(BaseSettings):
|
||||||
tls_cert_path: Optional[str] = None
|
tls_cert_path: Optional[str] = None
|
||||||
tls_key_path: Optional[str] = None
|
tls_key_path: Optional[str] = None
|
||||||
|
|
||||||
|
# JWT Authentication
|
||||||
|
jwt_secret_key: str = "your-secret-key-change-in-production"
|
||||||
|
jwt_token_expire_minutes: int = 60
|
||||||
|
jwt_algorithm: str = "HS256"
|
||||||
|
|
||||||
|
# Password policy
|
||||||
|
password_min_length: int = 8
|
||||||
|
password_require_uppercase: bool = True
|
||||||
|
password_require_lowercase: bool = True
|
||||||
|
password_require_numbers: bool = True
|
||||||
|
password_require_special: bool = True
|
||||||
|
|
||||||
# OPC UA
|
# OPC UA
|
||||||
opcua_enabled: bool = True
|
opcua_enabled: bool = True
|
||||||
opcua_host: str = "localhost"
|
opcua_host: str = "localhost"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,463 @@
|
||||||
|
"""
|
||||||
|
Compliance Audit Logger for Calejo Control Adapter.
|
||||||
|
|
||||||
|
Provides enhanced audit logging capabilities compliant with:
|
||||||
|
- IEC 62443 (Industrial Automation and Control Systems Security)
|
||||||
|
- ISO 27001 (Information Security Management)
|
||||||
|
- NIS2 Directive (Network and Information Systems Security)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from config.settings import settings
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AuditEventType(Enum):
|
||||||
|
"""Audit event types for compliance requirements."""
|
||||||
|
|
||||||
|
# Authentication and Authorization
|
||||||
|
USER_LOGIN = "user_login"
|
||||||
|
USER_LOGOUT = "user_logout"
|
||||||
|
USER_CREATED = "user_created"
|
||||||
|
USER_MODIFIED = "user_modified"
|
||||||
|
USER_DELETED = "user_deleted"
|
||||||
|
PASSWORD_CHANGED = "password_changed"
|
||||||
|
ROLE_CHANGED = "role_changed"
|
||||||
|
|
||||||
|
# System Access
|
||||||
|
SYSTEM_START = "system_start"
|
||||||
|
SYSTEM_STOP = "system_stop"
|
||||||
|
SYSTEM_CONFIG_CHANGED = "system_config_changed"
|
||||||
|
|
||||||
|
# Control Operations
|
||||||
|
SETPOINT_CHANGED = "setpoint_changed"
|
||||||
|
EMERGENCY_STOP_ACTIVATED = "emergency_stop_activated"
|
||||||
|
EMERGENCY_STOP_RESET = "emergency_stop_reset"
|
||||||
|
PUMP_CONTROL = "pump_control"
|
||||||
|
VALVE_CONTROL = "valve_control"
|
||||||
|
|
||||||
|
# Security Events
|
||||||
|
ACCESS_DENIED = "access_denied"
|
||||||
|
INVALID_AUTHENTICATION = "invalid_authentication"
|
||||||
|
SESSION_TIMEOUT = "session_timeout"
|
||||||
|
CERTIFICATE_EXPIRED = "certificate_expired"
|
||||||
|
CERTIFICATE_ROTATED = "certificate_rotated"
|
||||||
|
|
||||||
|
# Data Operations
|
||||||
|
DATA_READ = "data_read"
|
||||||
|
DATA_WRITE = "data_write"
|
||||||
|
DATA_EXPORT = "data_export"
|
||||||
|
DATA_DELETED = "data_deleted"
|
||||||
|
|
||||||
|
# Network Operations
|
||||||
|
CONNECTION_ESTABLISHED = "connection_established"
|
||||||
|
CONNECTION_CLOSED = "connection_closed"
|
||||||
|
CONNECTION_REJECTED = "connection_rejected"
|
||||||
|
|
||||||
|
# Compliance Events
|
||||||
|
AUDIT_LOG_ACCESSED = "audit_log_accessed"
|
||||||
|
COMPLIANCE_CHECK = "compliance_check"
|
||||||
|
SECURITY_SCAN = "security_scan"
|
||||||
|
|
||||||
|
|
||||||
|
class AuditSeverity(Enum):
|
||||||
|
"""Audit event severity levels."""
|
||||||
|
|
||||||
|
LOW = "low"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
HIGH = "high"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceAuditLogger:
|
||||||
|
"""
|
||||||
|
Enhanced audit logger for compliance requirements.
|
||||||
|
|
||||||
|
Provides comprehensive audit trail capabilities compliant with
|
||||||
|
IEC 62443, ISO 27001, and NIS2 Directive requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_client):
|
||||||
|
self.db_client = db_client
|
||||||
|
self.logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
def log_compliance_event(
|
||||||
|
self,
|
||||||
|
event_type: AuditEventType,
|
||||||
|
severity: AuditSeverity,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
station_id: Optional[str] = None,
|
||||||
|
pump_id: Optional[str] = None,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
protocol: Optional[str] = None,
|
||||||
|
action: Optional[str] = None,
|
||||||
|
resource: Optional[str] = None,
|
||||||
|
result: Optional[str] = None,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
compliance_standard: Optional[List[str]] = None,
|
||||||
|
event_data: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Log a compliance audit event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: Type of audit event
|
||||||
|
severity: Severity level
|
||||||
|
user_id: User ID performing the action
|
||||||
|
station_id: Station ID if applicable
|
||||||
|
pump_id: Pump ID if applicable
|
||||||
|
ip_address: Source IP address
|
||||||
|
protocol: Communication protocol used
|
||||||
|
action: Specific action performed
|
||||||
|
resource: Resource accessed or modified
|
||||||
|
result: Result of the action (success/failure)
|
||||||
|
reason: Reason for the action or failure
|
||||||
|
compliance_standard: List of compliance standards this event relates to
|
||||||
|
event_data: Additional event-specific data
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default compliance standards
|
||||||
|
if compliance_standard is None:
|
||||||
|
compliance_standard = ["IEC_62443", "ISO_27001", "NIS2"]
|
||||||
|
|
||||||
|
# Create comprehensive audit record
|
||||||
|
audit_record = {
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"event_type": event_type.value,
|
||||||
|
"severity": severity.value,
|
||||||
|
"user_id": user_id,
|
||||||
|
"station_id": station_id,
|
||||||
|
"pump_id": pump_id,
|
||||||
|
"ip_address": ip_address,
|
||||||
|
"protocol": protocol,
|
||||||
|
"action": action,
|
||||||
|
"resource": resource,
|
||||||
|
"result": result,
|
||||||
|
"reason": reason,
|
||||||
|
"compliance_standard": compliance_standard,
|
||||||
|
"event_data": event_data or {},
|
||||||
|
"app_name": settings.app_name,
|
||||||
|
"app_version": settings.app_version,
|
||||||
|
"environment": settings.environment
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log to structured logs
|
||||||
|
self.logger.info(
|
||||||
|
"compliance_audit_event",
|
||||||
|
**audit_record
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log to database if audit logging is enabled
|
||||||
|
if settings.audit_log_enabled:
|
||||||
|
self._log_to_database(audit_record)
|
||||||
|
|
||||||
|
def _log_to_database(self, audit_record: Dict[str, Any]):
|
||||||
|
"""Log audit record to database."""
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
INSERT INTO compliance_audit_log
|
||||||
|
(timestamp, event_type, severity, user_id, station_id, pump_id,
|
||||||
|
ip_address, protocol, action, resource, result, reason,
|
||||||
|
compliance_standard, event_data, app_name, app_version, environment)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
"""
|
||||||
|
self.db_client.execute(
|
||||||
|
query,
|
||||||
|
(
|
||||||
|
audit_record["timestamp"],
|
||||||
|
audit_record["event_type"],
|
||||||
|
audit_record["severity"],
|
||||||
|
audit_record["user_id"],
|
||||||
|
audit_record["station_id"],
|
||||||
|
audit_record["pump_id"],
|
||||||
|
audit_record["ip_address"],
|
||||||
|
audit_record["protocol"],
|
||||||
|
audit_record["action"],
|
||||||
|
audit_record["resource"],
|
||||||
|
audit_record["result"],
|
||||||
|
audit_record["reason"],
|
||||||
|
audit_record["compliance_standard"],
|
||||||
|
audit_record["event_data"],
|
||||||
|
audit_record["app_name"],
|
||||||
|
audit_record["app_version"],
|
||||||
|
audit_record["environment"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(
|
||||||
|
"compliance_audit_database_failed",
|
||||||
|
error=str(e),
|
||||||
|
event_type=audit_record["event_type"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_user_authentication(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
result: str,
|
||||||
|
ip_address: str,
|
||||||
|
reason: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Log user authentication events."""
|
||||||
|
event_type = (
|
||||||
|
AuditEventType.USER_LOGIN if result == "success"
|
||||||
|
else AuditEventType.INVALID_AUTHENTICATION
|
||||||
|
)
|
||||||
|
severity = (
|
||||||
|
AuditSeverity.LOW if result == "success"
|
||||||
|
else AuditSeverity.HIGH
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_compliance_event(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
action="authentication",
|
||||||
|
resource="system",
|
||||||
|
result=result,
|
||||||
|
reason=reason
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_access_control(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
action: str,
|
||||||
|
resource: str,
|
||||||
|
result: str,
|
||||||
|
ip_address: str,
|
||||||
|
reason: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Log access control events."""
|
||||||
|
event_type = (
|
||||||
|
AuditEventType.ACCESS_DENIED if result == "denied"
|
||||||
|
else AuditEventType.DATA_READ if action == "read"
|
||||||
|
else AuditEventType.DATA_WRITE if action == "write"
|
||||||
|
else AuditEventType.SYSTEM_CONFIG_CHANGED
|
||||||
|
)
|
||||||
|
severity = (
|
||||||
|
AuditSeverity.HIGH if result == "denied"
|
||||||
|
else AuditSeverity.MEDIUM
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_compliance_event(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
action=action,
|
||||||
|
resource=resource,
|
||||||
|
result=result,
|
||||||
|
reason=reason
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_control_operation(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
station_id: str,
|
||||||
|
pump_id: Optional[str],
|
||||||
|
action: str,
|
||||||
|
resource: str,
|
||||||
|
result: str,
|
||||||
|
ip_address: str,
|
||||||
|
event_data: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
"""Log control system operations."""
|
||||||
|
event_type = (
|
||||||
|
AuditEventType.SETPOINT_CHANGED if action == "setpoint_change"
|
||||||
|
else AuditEventType.EMERGENCY_STOP_ACTIVATED if action == "emergency_stop"
|
||||||
|
else AuditEventType.PUMP_CONTROL if "pump" in resource.lower()
|
||||||
|
else AuditEventType.VALVE_CONTROL if "valve" in resource.lower()
|
||||||
|
else AuditEventType.SYSTEM_CONFIG_CHANGED
|
||||||
|
)
|
||||||
|
severity = (
|
||||||
|
AuditSeverity.CRITICAL if action == "emergency_stop"
|
||||||
|
else AuditSeverity.HIGH if action == "setpoint_change"
|
||||||
|
else AuditSeverity.MEDIUM
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_compliance_event(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
station_id=station_id,
|
||||||
|
pump_id=pump_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
action=action,
|
||||||
|
resource=resource,
|
||||||
|
result=result,
|
||||||
|
event_data=event_data
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_security_event(
|
||||||
|
self,
|
||||||
|
event_type: AuditEventType,
|
||||||
|
severity: AuditSeverity,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
station_id: Optional[str] = None,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
reason: Optional[str] = None,
|
||||||
|
event_data: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
"""Log security-related events."""
|
||||||
|
self.log_compliance_event(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
station_id=station_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
action="security_event",
|
||||||
|
resource="system",
|
||||||
|
result="detected",
|
||||||
|
reason=reason,
|
||||||
|
event_data=event_data
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_audit_trail(
|
||||||
|
self,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
end_time: Optional[datetime] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
event_type: Optional[AuditEventType] = None,
|
||||||
|
severity: Optional[AuditSeverity] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieve audit trail for compliance reporting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_time: Start time for audit trail
|
||||||
|
end_time: End time for audit trail
|
||||||
|
user_id: Filter by user ID
|
||||||
|
event_type: Filter by event type
|
||||||
|
severity: Filter by severity
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of audit records
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
SELECT * FROM compliance_audit_log
|
||||||
|
WHERE 1=1
|
||||||
|
"""
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if start_time:
|
||||||
|
query += " AND timestamp >= %s"
|
||||||
|
params.append(start_time.isoformat())
|
||||||
|
|
||||||
|
if end_time:
|
||||||
|
query += " AND timestamp <= %s"
|
||||||
|
params.append(end_time.isoformat())
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
query += " AND user_id = %s"
|
||||||
|
params.append(user_id)
|
||||||
|
|
||||||
|
if event_type:
|
||||||
|
query += " AND event_type = %s"
|
||||||
|
params.append(event_type.value)
|
||||||
|
|
||||||
|
if severity:
|
||||||
|
query += " AND severity = %s"
|
||||||
|
params.append(severity.value)
|
||||||
|
|
||||||
|
query += " ORDER BY timestamp DESC"
|
||||||
|
|
||||||
|
result = self.db_client.fetch_all(query, params)
|
||||||
|
|
||||||
|
# Log the audit trail access
|
||||||
|
self.log_compliance_event(
|
||||||
|
event_type=AuditEventType.AUDIT_LOG_ACCESSED,
|
||||||
|
severity=AuditSeverity.LOW,
|
||||||
|
user_id=user_id,
|
||||||
|
action="audit_trail_access",
|
||||||
|
resource="audit_log",
|
||||||
|
result="success"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(
|
||||||
|
"audit_trail_retrieval_failed",
|
||||||
|
error=str(e),
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def generate_compliance_report(
|
||||||
|
self,
|
||||||
|
start_time: datetime,
|
||||||
|
end_time: datetime,
|
||||||
|
compliance_standard: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate compliance report for specified standard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_time: Report start time
|
||||||
|
end_time: Report end time
|
||||||
|
compliance_standard: Compliance standard to report on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compliance report data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
event_type,
|
||||||
|
severity,
|
||||||
|
COUNT(*) as count,
|
||||||
|
COUNT(DISTINCT user_id) as unique_users
|
||||||
|
FROM compliance_audit_log
|
||||||
|
WHERE timestamp BETWEEN %s AND %s
|
||||||
|
AND %s = ANY(compliance_standard)
|
||||||
|
GROUP BY event_type, severity
|
||||||
|
ORDER BY count DESC
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = self.db_client.fetch_all(
|
||||||
|
query,
|
||||||
|
(start_time.isoformat(), end_time.isoformat(), compliance_standard)
|
||||||
|
)
|
||||||
|
|
||||||
|
report = {
|
||||||
|
"compliance_standard": compliance_standard,
|
||||||
|
"report_period": {
|
||||||
|
"start_time": start_time.isoformat(),
|
||||||
|
"end_time": end_time.isoformat()
|
||||||
|
},
|
||||||
|
"summary": {
|
||||||
|
"total_events": sum(row["count"] for row in result),
|
||||||
|
"unique_users": sum(row["unique_users"] for row in result),
|
||||||
|
"event_types": len(set(row["event_type"] for row in result))
|
||||||
|
},
|
||||||
|
"events_by_type": result
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log the compliance report generation
|
||||||
|
self.log_compliance_event(
|
||||||
|
event_type=AuditEventType.COMPLIANCE_CHECK,
|
||||||
|
severity=AuditSeverity.LOW,
|
||||||
|
action="compliance_report_generated",
|
||||||
|
resource="compliance",
|
||||||
|
result="success",
|
||||||
|
event_data={
|
||||||
|
"compliance_standard": compliance_standard,
|
||||||
|
"report_period": report["report_period"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(
|
||||||
|
"compliance_report_generation_failed",
|
||||||
|
error=str(e),
|
||||||
|
compliance_standard=compliance_standard
|
||||||
|
)
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
@ -0,0 +1,358 @@
|
||||||
|
"""
|
||||||
|
Security layer for Calejo Control Adapter.
|
||||||
|
|
||||||
|
Provides authentication, authorization, and security utilities for compliance
|
||||||
|
with IEC 62443, ISO 27001, and NIS2 Directive requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import bcrypt
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Optional, Dict, List, Any
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from config.settings import settings
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(str, Enum):
|
||||||
|
"""User roles for role-based access control."""
|
||||||
|
OPERATOR = "operator"
|
||||||
|
ENGINEER = "engineer"
|
||||||
|
ADMINISTRATOR = "administrator"
|
||||||
|
READ_ONLY = "read_only"
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
"""User model for authentication and authorization."""
|
||||||
|
user_id: str
|
||||||
|
username: str
|
||||||
|
email: str
|
||||||
|
role: UserRole
|
||||||
|
active: bool = True
|
||||||
|
created_at: datetime
|
||||||
|
last_login: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
"""Data encoded in JWT tokens."""
|
||||||
|
user_id: str
|
||||||
|
username: str
|
||||||
|
role: UserRole
|
||||||
|
exp: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationManager:
|
||||||
|
"""
|
||||||
|
Manages user authentication with JWT tokens and password hashing.
|
||||||
|
|
||||||
|
Implements security controls for IEC 62443, ISO 27001, and NIS2 compliance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 60):
|
||||||
|
self.secret_key = secret_key
|
||||||
|
self.algorithm = algorithm
|
||||||
|
self.token_expire_minutes = token_expire_minutes
|
||||||
|
|
||||||
|
# In-memory user store (in production, this would be a database)
|
||||||
|
self.users: Dict[str, User] = {}
|
||||||
|
self.password_hashes: Dict[str, str] = {}
|
||||||
|
|
||||||
|
# Initialize with default users for development
|
||||||
|
self._initialize_default_users()
|
||||||
|
|
||||||
|
def _initialize_default_users(self):
|
||||||
|
"""Initialize default users for development and testing."""
|
||||||
|
default_users = [
|
||||||
|
{
|
||||||
|
"user_id": "admin_001",
|
||||||
|
"username": "admin",
|
||||||
|
"email": "admin@calejo.com",
|
||||||
|
"role": UserRole.ADMINISTRATOR,
|
||||||
|
"password": "admin123"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": "operator_001",
|
||||||
|
"username": "operator",
|
||||||
|
"email": "operator@calejo.com",
|
||||||
|
"role": UserRole.OPERATOR,
|
||||||
|
"password": "operator123"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": "engineer_001",
|
||||||
|
"username": "engineer",
|
||||||
|
"email": "engineer@calejo.com",
|
||||||
|
"role": UserRole.ENGINEER,
|
||||||
|
"password": "engineer123"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": "viewer_001",
|
||||||
|
"username": "viewer",
|
||||||
|
"email": "viewer@calejo.com",
|
||||||
|
"role": UserRole.READ_ONLY,
|
||||||
|
"password": "viewer123"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for user_data in default_users:
|
||||||
|
self.create_user(
|
||||||
|
user_id=user_data["user_id"],
|
||||||
|
username=user_data["username"],
|
||||||
|
email=user_data["email"],
|
||||||
|
role=user_data["role"],
|
||||||
|
password=user_data["password"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def hash_password(self, password: str) -> str:
|
||||||
|
"""Hash a password using bcrypt."""
|
||||||
|
salt = bcrypt.gensalt()
|
||||||
|
hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
|
||||||
|
return hashed.decode('utf-8')
|
||||||
|
|
||||||
|
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""Verify a password against its hash."""
|
||||||
|
return bcrypt.checkpw(
|
||||||
|
plain_password.encode('utf-8'),
|
||||||
|
hashed_password.encode('utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_user(self, user_id: str, username: str, email: str, role: UserRole, password: str) -> User:
|
||||||
|
"""Create a new user with hashed password."""
|
||||||
|
if user_id in self.users:
|
||||||
|
raise ValueError(f"User with ID {user_id} already exists")
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
user_id=user_id,
|
||||||
|
username=username,
|
||||||
|
email=email,
|
||||||
|
role=role,
|
||||||
|
created_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.users[user_id] = user
|
||||||
|
self.password_hashes[user_id] = self.hash_password(password)
|
||||||
|
|
||||||
|
logger.info("user_created", user_id=user_id, username=username, role=role.value)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def authenticate_user(self, username: str, password: str) -> Optional[User]:
|
||||||
|
"""Authenticate a user and return user object if successful."""
|
||||||
|
# Find user by username
|
||||||
|
user = None
|
||||||
|
for u in self.users.values():
|
||||||
|
if u.username == username and u.active:
|
||||||
|
user = u
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
logger.warning("authentication_failed", username=username, reason="user_not_found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Verify password
|
||||||
|
if not self.verify_password(password, self.password_hashes[user.user_id]):
|
||||||
|
logger.warning("authentication_failed", username=username, reason="invalid_password")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update last login
|
||||||
|
user.last_login = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
logger.info("user_authenticated", user_id=user.user_id, username=username, role=user.role.value)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def create_access_token(self, user: User) -> str:
|
||||||
|
"""Create a JWT access token for the user."""
|
||||||
|
expires_delta = timedelta(minutes=self.token_expire_minutes)
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
|
||||||
|
token_data = TokenData(
|
||||||
|
user_id=user.user_id,
|
||||||
|
username=user.username,
|
||||||
|
role=user.role,
|
||||||
|
exp=expire
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
token_data.dict(),
|
||||||
|
self.secret_key,
|
||||||
|
algorithm=self.algorithm
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
def verify_token(self, token: str) -> Optional[TokenData]:
|
||||||
|
"""Verify and decode a JWT token."""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||||
|
token_data = TokenData(**payload)
|
||||||
|
|
||||||
|
# Check if token is expired
|
||||||
|
if token_data.exp < datetime.now(timezone.utc):
|
||||||
|
logger.warning("token_expired", username=token_data.username)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Verify user still exists and is active
|
||||||
|
user = self.users.get(token_data.user_id)
|
||||||
|
if not user or not user.active:
|
||||||
|
logger.warning("token_invalid_user", username=token_data.username)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return token_data
|
||||||
|
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
logger.warning("token_verification_failed", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationManager:
|
||||||
|
"""
|
||||||
|
Manages role-based access control (RBAC) for authorization.
|
||||||
|
|
||||||
|
Implements IEC 62443 zone security model and ISO 27001 access control requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Define permissions for each role
|
||||||
|
self.role_permissions = {
|
||||||
|
UserRole.READ_ONLY: {
|
||||||
|
"read_pump_status",
|
||||||
|
"read_safety_status",
|
||||||
|
"read_audit_logs"
|
||||||
|
},
|
||||||
|
UserRole.OPERATOR: {
|
||||||
|
"read_pump_status",
|
||||||
|
"read_safety_status",
|
||||||
|
"read_audit_logs",
|
||||||
|
"emergency_stop",
|
||||||
|
"clear_emergency_stop",
|
||||||
|
"view_alerts"
|
||||||
|
},
|
||||||
|
UserRole.ENGINEER: {
|
||||||
|
"read_pump_status",
|
||||||
|
"read_safety_status",
|
||||||
|
"read_audit_logs",
|
||||||
|
"emergency_stop",
|
||||||
|
"clear_emergency_stop",
|
||||||
|
"view_alerts",
|
||||||
|
"configure_safety_limits",
|
||||||
|
"manage_pump_configuration",
|
||||||
|
"view_system_metrics"
|
||||||
|
},
|
||||||
|
UserRole.ADMINISTRATOR: {
|
||||||
|
"read_pump_status",
|
||||||
|
"read_safety_status",
|
||||||
|
"read_audit_logs",
|
||||||
|
"emergency_stop",
|
||||||
|
"clear_emergency_stop",
|
||||||
|
"view_alerts",
|
||||||
|
"configure_safety_limits",
|
||||||
|
"manage_pump_configuration",
|
||||||
|
"view_system_metrics",
|
||||||
|
"manage_users",
|
||||||
|
"configure_system",
|
||||||
|
"access_all_stations"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def has_permission(self, role: UserRole, permission: str) -> bool:
|
||||||
|
"""Check if a role has the specified permission."""
|
||||||
|
permissions = self.role_permissions.get(role, set())
|
||||||
|
return permission in permissions
|
||||||
|
|
||||||
|
def get_allowed_actions(self, role: UserRole) -> List[str]:
|
||||||
|
"""Get all allowed actions for a role."""
|
||||||
|
return list(self.role_permissions.get(role, set()))
|
||||||
|
|
||||||
|
def can_access_station(self, role: UserRole, station_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if user can access a specific station.
|
||||||
|
|
||||||
|
Administrators can access all stations, others may have station-specific
|
||||||
|
permissions (to be implemented with database integration).
|
||||||
|
"""
|
||||||
|
if role == UserRole.ADMINISTRATOR:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# For now, all authenticated users can access all stations
|
||||||
|
# In production, this would check station-specific permissions
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityManager:
|
||||||
|
"""
|
||||||
|
Main security manager that coordinates authentication and authorization.
|
||||||
|
|
||||||
|
Provides a unified interface for security operations and compliance logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, audit_logger=None):
|
||||||
|
self.auth_manager = AuthenticationManager(
|
||||||
|
secret_key=settings.jwt_secret_key,
|
||||||
|
token_expire_minutes=settings.jwt_token_expire_minutes
|
||||||
|
)
|
||||||
|
self.authz_manager = AuthorizationManager()
|
||||||
|
self.audit_logger = audit_logger
|
||||||
|
|
||||||
|
def authenticate(self, username: str, password: str) -> Optional[str]:
|
||||||
|
"""Authenticate user and return JWT token if successful."""
|
||||||
|
user = self.auth_manager.authenticate_user(username, password)
|
||||||
|
if user:
|
||||||
|
token = self.auth_manager.create_access_token(user)
|
||||||
|
|
||||||
|
# Log successful authentication
|
||||||
|
if self.audit_logger:
|
||||||
|
self.audit_logger.log(
|
||||||
|
event_type="USER_AUTHENTICATION",
|
||||||
|
severity="INFO",
|
||||||
|
user_id=user.user_id,
|
||||||
|
event_data={
|
||||||
|
"username": username,
|
||||||
|
"role": user.role.value,
|
||||||
|
"result": "SUCCESS"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
# Log failed authentication
|
||||||
|
if self.audit_logger:
|
||||||
|
self.audit_logger.log(
|
||||||
|
event_type="USER_AUTHENTICATION",
|
||||||
|
severity="WARNING",
|
||||||
|
event_data={
|
||||||
|
"username": username,
|
||||||
|
"result": "FAILURE"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def verify_access_token(self, token: str) -> Optional[TokenData]:
|
||||||
|
"""Verify JWT token and return token data if valid."""
|
||||||
|
return self.auth_manager.verify_token(token)
|
||||||
|
|
||||||
|
def check_permission(self, token_data: TokenData, permission: str) -> bool:
|
||||||
|
"""Check if user has the specified permission."""
|
||||||
|
return self.authz_manager.has_permission(token_data.role, permission)
|
||||||
|
|
||||||
|
def can_access_station(self, token_data: TokenData, station_id: str) -> bool:
|
||||||
|
"""Check if user can access the specified station."""
|
||||||
|
return self.authz_manager.can_access_station(token_data.role, station_id)
|
||||||
|
|
||||||
|
def get_user_permissions(self, token_data: TokenData) -> List[str]:
|
||||||
|
"""Get all permissions for the user."""
|
||||||
|
return self.authz_manager.get_allowed_actions(token_data.role)
|
||||||
|
|
||||||
|
|
||||||
|
# Global security manager instance
|
||||||
|
security_manager: Optional[SecurityManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_security_manager() -> SecurityManager:
|
||||||
|
"""Get or create the global security manager instance."""
|
||||||
|
global security_manager
|
||||||
|
if security_manager is None:
|
||||||
|
security_manager = SecurityManager()
|
||||||
|
return security_manager
|
||||||
|
|
@ -0,0 +1,304 @@
|
||||||
|
"""
|
||||||
|
TLS/SSL Manager for Calejo Control Adapter.
|
||||||
|
|
||||||
|
Provides certificate management and TLS/SSL configuration for secure communications
|
||||||
|
in compliance with IEC 62443, ISO 27001, and NIS2 Directive requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from pathlib import Path
|
||||||
|
import structlog
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from config.settings import settings
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class TLSManager:
|
||||||
|
"""
|
||||||
|
Manages TLS/SSL certificates and secure communications.
|
||||||
|
|
||||||
|
Provides certificate validation, rotation, and secure context creation
|
||||||
|
for all protocol servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.cert_path = settings.tls_cert_path
|
||||||
|
self.key_path = settings.tls_key_path
|
||||||
|
self.tls_enabled = settings.tls_enabled
|
||||||
|
|
||||||
|
# Certificate rotation tracking
|
||||||
|
self.cert_expiry_dates: dict[str, datetime] = {}
|
||||||
|
|
||||||
|
# Validate certificates on initialization
|
||||||
|
if self.tls_enabled:
|
||||||
|
self._validate_certificates()
|
||||||
|
|
||||||
|
def _validate_certificates(self) -> bool:
|
||||||
|
"""
|
||||||
|
Validate TLS certificates exist and are valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if certificates are valid, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.cert_path or not self.key_path:
|
||||||
|
logger.warning(
|
||||||
|
"tls_certificates_missing",
|
||||||
|
cert_path=self.cert_path,
|
||||||
|
key_path=self.key_path
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if certificate files exist
|
||||||
|
cert_file = Path(self.cert_path)
|
||||||
|
key_file = Path(self.key_path)
|
||||||
|
|
||||||
|
if not cert_file.exists():
|
||||||
|
logger.error("tls_certificate_file_missing", path=self.cert_path)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not key_file.exists():
|
||||||
|
logger.error("tls_key_file_missing", path=self.key_path)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Validate certificate format and expiry
|
||||||
|
try:
|
||||||
|
expiry_date = self._get_certificate_expiry(self.cert_path)
|
||||||
|
self.cert_expiry_dates[self.cert_path] = expiry_date
|
||||||
|
|
||||||
|
# Check if certificate is expired or expiring soon
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
days_until_expiry = (expiry_date - now).days
|
||||||
|
|
||||||
|
if days_until_expiry < 0:
|
||||||
|
logger.error("tls_certificate_expired", expiry_date=expiry_date.isoformat())
|
||||||
|
return False
|
||||||
|
elif days_until_expiry < 30:
|
||||||
|
logger.warning(
|
||||||
|
"tls_certificate_expiring_soon",
|
||||||
|
expiry_date=expiry_date.isoformat(),
|
||||||
|
days_remaining=days_until_expiry
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"tls_certificate_valid",
|
||||||
|
expiry_date=expiry_date.isoformat(),
|
||||||
|
days_remaining=days_until_expiry
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("tls_certificate_validation_failed", error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_certificate_expiry(self, cert_path: str) -> datetime:
|
||||||
|
"""
|
||||||
|
Get certificate expiry date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cert_path: Path to certificate file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
datetime: Certificate expiry date
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If certificate cannot be parsed
|
||||||
|
"""
|
||||||
|
import OpenSSL
|
||||||
|
|
||||||
|
with open(cert_path, 'rb') as f:
|
||||||
|
cert_data = f.read()
|
||||||
|
|
||||||
|
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_data)
|
||||||
|
expiry_bytes = cert.get_notAfter()
|
||||||
|
|
||||||
|
# Parse ASN.1 time format (YYYYMMDDHHMMSSZ)
|
||||||
|
expiry_str = expiry_bytes.decode('ascii')
|
||||||
|
expiry_date = datetime.strptime(expiry_str, '%Y%m%d%H%M%SZ')
|
||||||
|
|
||||||
|
# Make timezone aware
|
||||||
|
return expiry_date.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
|
||||||
|
"""
|
||||||
|
Create SSL context for secure communications.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[ssl.SSLContext]: SSL context if TLS is enabled and valid,
|
||||||
|
None otherwise
|
||||||
|
"""
|
||||||
|
if not self.tls_enabled:
|
||||||
|
logger.info("tls_disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self._validate_certificates():
|
||||||
|
logger.error("cannot_create_ssl_context_invalid_certificates")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create SSL context with secure defaults
|
||||||
|
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
|
|
||||||
|
# Load certificate and key
|
||||||
|
context.load_cert_chain(self.cert_path, self.key_path)
|
||||||
|
|
||||||
|
# Configure secure cipher suites
|
||||||
|
context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS')
|
||||||
|
|
||||||
|
# Enable certificate verification
|
||||||
|
context.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
context.check_hostname = True
|
||||||
|
|
||||||
|
# Set minimum TLS version (TLS 1.2 or higher)
|
||||||
|
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
|
|
||||||
|
logger.info("ssl_context_created_successfully")
|
||||||
|
return context
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("failed_to_create_ssl_context", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_rest_api_ssl_config(self) -> Optional[Tuple[str, str]]:
|
||||||
|
"""
|
||||||
|
Get SSL configuration for REST API server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Tuple[str, str]]: Tuple of (certfile, keyfile) paths if TLS is enabled,
|
||||||
|
None otherwise
|
||||||
|
"""
|
||||||
|
if not self.tls_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self._validate_certificates():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (self.cert_path, self.key_path)
|
||||||
|
|
||||||
|
def check_certificate_rotation(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if certificates need rotation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if certificates need rotation, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.tls_enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for cert_path, expiry_date in self.cert_expiry_dates.items():
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
days_until_expiry = (expiry_date - now).days
|
||||||
|
|
||||||
|
# Rotate if certificate expires in less than 7 days
|
||||||
|
if days_until_expiry < 7:
|
||||||
|
logger.warning(
|
||||||
|
"certificate_needs_rotation",
|
||||||
|
cert_path=cert_path,
|
||||||
|
expiry_date=expiry_date.isoformat(),
|
||||||
|
days_remaining=days_until_expiry
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def generate_self_signed_certificate(self, output_dir: str, common_name: str = "calejo-control.local") -> bool:
|
||||||
|
"""
|
||||||
|
Generate self-signed certificate for development/testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Directory to save certificate files
|
||||||
|
common_name: Common name for the certificate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if certificate generation succeeded, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.x509.oid import NameOID
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Generate private key
|
||||||
|
private_key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537,
|
||||||
|
key_size=2048,
|
||||||
|
backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate certificate
|
||||||
|
subject = issuer = x509.Name([
|
||||||
|
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||||
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Calejo Control"),
|
||||||
|
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Development"),
|
||||||
|
])
|
||||||
|
|
||||||
|
cert = (
|
||||||
|
x509.CertificateBuilder()
|
||||||
|
.subject_name(subject)
|
||||||
|
.issuer_name(issuer)
|
||||||
|
.public_key(private_key.public_key())
|
||||||
|
.serial_number(x509.random_serial_number())
|
||||||
|
.not_valid_before(datetime.now(timezone.utc))
|
||||||
|
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
|
||||||
|
.add_extension(
|
||||||
|
x509.SubjectAlternativeName([
|
||||||
|
x509.DNSName("localhost"),
|
||||||
|
x509.DNSName("127.0.0.1"),
|
||||||
|
]),
|
||||||
|
critical=False,
|
||||||
|
)
|
||||||
|
.sign(private_key, hashes.SHA256(), default_backend())
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write private key
|
||||||
|
key_path = os.path.join(output_dir, "key.pem")
|
||||||
|
with open(key_path, "wb") as f:
|
||||||
|
f.write(
|
||||||
|
private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write certificate
|
||||||
|
cert_path = os.path.join(output_dir, "cert.pem")
|
||||||
|
with open(cert_path, "wb") as f:
|
||||||
|
f.write(cert.public_bytes(serialization.Encoding.PEM))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"self_signed_certificate_generated",
|
||||||
|
cert_path=cert_path,
|
||||||
|
key_path=key_path,
|
||||||
|
common_name=common_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update settings
|
||||||
|
self.cert_path = cert_path
|
||||||
|
self.key_path = key_path
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("failed_to_generate_self_signed_certificate", error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Global TLS manager instance
|
||||||
|
tls_manager: Optional[TLSManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tls_manager() -> TLSManager:
|
||||||
|
"""Get or create the global TLS manager instance."""
|
||||||
|
global tls_manager
|
||||||
|
if tls_manager is None:
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
return tls_manager
|
||||||
|
|
@ -7,12 +7,17 @@ Provides REST endpoints for emergency stop, status monitoring, and setpoint acce
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import structlog
|
import structlog
|
||||||
from fastapi import FastAPI, HTTPException, status, Depends
|
from fastapi import FastAPI, HTTPException, status, Depends, Request
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from src.core.setpoint_manager import SetpointManager
|
from src.core.setpoint_manager import SetpointManager
|
||||||
from src.core.emergency_stop import EmergencyStopManager
|
from src.core.emergency_stop import EmergencyStopManager
|
||||||
|
from src.core.security import (
|
||||||
|
SecurityManager, TokenData, UserRole, get_security_manager
|
||||||
|
)
|
||||||
|
from src.core.tls_manager import get_tls_manager
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
@ -20,6 +25,64 @@ logger = structlog.get_logger()
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
"""Request model for user login."""
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoginResponse(BaseModel):
|
||||||
|
"""Response model for successful login."""
|
||||||
|
access_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
expires_in: int
|
||||||
|
user_id: str
|
||||||
|
username: str
|
||||||
|
role: str
|
||||||
|
permissions: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class UserInfoResponse(BaseModel):
|
||||||
|
"""Response model for user information."""
|
||||||
|
user_id: str
|
||||||
|
username: str
|
||||||
|
email: str
|
||||||
|
role: str
|
||||||
|
permissions: list[str]
|
||||||
|
last_login: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||||
|
security_manager: SecurityManager = Depends(get_security_manager)
|
||||||
|
) -> TokenData:
|
||||||
|
"""Dependency to get current user from JWT token."""
|
||||||
|
token = credentials.credentials
|
||||||
|
token_data = security_manager.verify_access_token(token)
|
||||||
|
if not token_data:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authentication credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
return token_data
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(permission: str):
|
||||||
|
"""Dependency factory to require specific permission."""
|
||||||
|
def permission_dependency(
|
||||||
|
token_data: TokenData = Depends(get_current_user),
|
||||||
|
security_manager: SecurityManager = Depends(get_security_manager)
|
||||||
|
):
|
||||||
|
if not security_manager.check_permission(token_data, permission):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Insufficient permissions. Required: {permission}"
|
||||||
|
)
|
||||||
|
return token_data
|
||||||
|
return permission_dependency
|
||||||
|
|
||||||
|
|
||||||
class EmergencyStopRequest(BaseModel):
|
class EmergencyStopRequest(BaseModel):
|
||||||
"""Request model for emergency stop."""
|
"""Request model for emergency stop."""
|
||||||
triggered_by: str
|
triggered_by: str
|
||||||
|
|
@ -63,7 +126,18 @@ class RESTAPIServer:
|
||||||
self.app = FastAPI(
|
self.app = FastAPI(
|
||||||
title="Calejo Control API",
|
title="Calejo Control API",
|
||||||
version="2.0",
|
version="2.0",
|
||||||
description="REST API for Calejo Control Adapter with safety framework"
|
description="REST API for Calejo Control Adapter with safety framework",
|
||||||
|
docs_url="/docs",
|
||||||
|
redoc_url="/redoc"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
self.app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # In production, restrict to specific origins
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
|
|
@ -77,7 +151,8 @@ class RESTAPIServer:
|
||||||
return {
|
return {
|
||||||
"name": "Calejo Control API",
|
"name": "Calejo Control API",
|
||||||
"version": "2.0",
|
"version": "2.0",
|
||||||
"status": "operational"
|
"status": "operational",
|
||||||
|
"authentication_required": True
|
||||||
}
|
}
|
||||||
|
|
||||||
@self.app.get("/health", summary="Health Check", tags=["General"])
|
@self.app.get("/health", summary="Health Check", tags=["General"])
|
||||||
|
|
@ -88,6 +163,69 @@ class RESTAPIServer:
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Authentication endpoints (no authentication required)
|
||||||
|
@self.app.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
summary="User Login",
|
||||||
|
tags=["Authentication"],
|
||||||
|
response_model=LoginResponse
|
||||||
|
)
|
||||||
|
async def login(
|
||||||
|
request: LoginRequest,
|
||||||
|
security_manager: SecurityManager = Depends(get_security_manager)
|
||||||
|
):
|
||||||
|
"""Authenticate user and return JWT token."""
|
||||||
|
token = security_manager.authenticate(request.username, request.password)
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid username or password"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user permissions
|
||||||
|
token_data = security_manager.verify_access_token(token)
|
||||||
|
permissions = security_manager.get_user_permissions(token_data)
|
||||||
|
|
||||||
|
return LoginResponse(
|
||||||
|
access_token=token,
|
||||||
|
token_type="bearer",
|
||||||
|
expires_in=60, # 60 minutes
|
||||||
|
user_id=token_data.user_id,
|
||||||
|
username=token_data.username,
|
||||||
|
role=token_data.role.value,
|
||||||
|
permissions=permissions
|
||||||
|
)
|
||||||
|
|
||||||
|
@self.app.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
summary="Get Current User Info",
|
||||||
|
tags=["Authentication"],
|
||||||
|
response_model=UserInfoResponse
|
||||||
|
)
|
||||||
|
async def get_current_user_info(
|
||||||
|
token_data: TokenData = Depends(get_current_user),
|
||||||
|
security_manager: SecurityManager = Depends(get_security_manager)
|
||||||
|
):
|
||||||
|
"""Get information about the currently authenticated user."""
|
||||||
|
# Get user from auth manager
|
||||||
|
user = security_manager.auth_manager.users.get(token_data.user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
permissions = security_manager.get_user_permissions(token_data)
|
||||||
|
|
||||||
|
return UserInfoResponse(
|
||||||
|
user_id=user.user_id,
|
||||||
|
username=user.username,
|
||||||
|
email=user.email,
|
||||||
|
role=user.role.value,
|
||||||
|
permissions=permissions,
|
||||||
|
last_login=user.last_login.isoformat() if user.last_login else None
|
||||||
|
)
|
||||||
|
|
||||||
@self.app.get(
|
@self.app.get(
|
||||||
"/api/v1/setpoints",
|
"/api/v1/setpoints",
|
||||||
summary="Get All Setpoints",
|
summary="Get All Setpoints",
|
||||||
|
|
@ -95,12 +233,14 @@ class RESTAPIServer:
|
||||||
response_model=Dict[str, Dict[str, Optional[float]]]
|
response_model=Dict[str, Dict[str, Optional[float]]]
|
||||||
)
|
)
|
||||||
async def get_all_setpoints(
|
async def get_all_setpoints(
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
token_data: TokenData = Depends(require_permission("read_pump_status"))
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get current setpoints for all pumps.
|
Get current setpoints for all pumps.
|
||||||
|
|
||||||
Returns dictionary mapping station_id -> pump_id -> setpoint_hz
|
Returns dictionary mapping station_id -> pump_id -> setpoint_hz
|
||||||
|
|
||||||
|
Requires permission: read_pump_status
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
setpoints = self.setpoint_manager.get_all_current_setpoints()
|
setpoints = self.setpoint_manager.get_all_current_setpoints()
|
||||||
|
|
@ -121,10 +261,22 @@ class RESTAPIServer:
|
||||||
async def get_pump_setpoint(
|
async def get_pump_setpoint(
|
||||||
station_id: str,
|
station_id: str,
|
||||||
pump_id: str,
|
pump_id: str,
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
token_data: TokenData = Depends(require_permission("read_pump_status")),
|
||||||
|
security_manager: SecurityManager = Depends(get_security_manager)
|
||||||
):
|
):
|
||||||
"""Get current setpoint for a specific pump."""
|
"""
|
||||||
|
Get current setpoint for a specific pump.
|
||||||
|
|
||||||
|
Requires permission: read_pump_status
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Check if user can access this station
|
||||||
|
if not security_manager.can_access_station(token_data, station_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Access denied to station {station_id}"
|
||||||
|
)
|
||||||
|
|
||||||
setpoint = self.setpoint_manager.get_current_setpoint(station_id, pump_id)
|
setpoint = self.setpoint_manager.get_current_setpoint(station_id, pump_id)
|
||||||
|
|
||||||
# Get pump info for control type
|
# Get pump info for control type
|
||||||
|
|
@ -147,6 +299,8 @@ class RESTAPIServer:
|
||||||
timestamp=datetime.now().isoformat()
|
timestamp=datetime.now().isoformat()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"failed_to_get_pump_setpoint",
|
"failed_to_get_pump_setpoint",
|
||||||
|
|
@ -167,7 +321,8 @@ class RESTAPIServer:
|
||||||
)
|
)
|
||||||
async def trigger_emergency_stop(
|
async def trigger_emergency_stop(
|
||||||
request: EmergencyStopRequest,
|
request: EmergencyStopRequest,
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
token_data: TokenData = Depends(require_permission("emergency_stop")),
|
||||||
|
security_manager: SecurityManager = Depends(get_security_manager)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Trigger emergency stop ("big red button").
|
Trigger emergency stop ("big red button").
|
||||||
|
|
@ -176,15 +331,27 @@ class RESTAPIServer:
|
||||||
- If station_id and pump_id provided: Stop single pump
|
- If station_id and pump_id provided: Stop single pump
|
||||||
- If station_id only: Stop all pumps at station
|
- If station_id only: Stop all pumps at station
|
||||||
- If neither: Stop ALL pumps system-wide
|
- If neither: Stop ALL pumps system-wide
|
||||||
|
|
||||||
|
Requires permission: emergency_stop
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Check station access if station_id is provided
|
||||||
|
if request.station_id and not security_manager.can_access_station(token_data, request.station_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Access denied to station {request.station_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use authenticated user as triggered_by
|
||||||
|
triggered_by = token_data.username
|
||||||
|
|
||||||
if request.station_id and request.pump_id:
|
if request.station_id and request.pump_id:
|
||||||
# Single pump stop
|
# Single pump stop
|
||||||
result = self.emergency_stop_manager.emergency_stop_pump(
|
result = self.emergency_stop_manager.emergency_stop_pump(
|
||||||
station_id=request.station_id,
|
station_id=request.station_id,
|
||||||
pump_id=request.pump_id,
|
pump_id=request.pump_id,
|
||||||
reason=request.reason,
|
reason=request.reason,
|
||||||
user_id=request.triggered_by
|
user_id=triggered_by
|
||||||
)
|
)
|
||||||
scope = f"pump {request.station_id}/{request.pump_id}"
|
scope = f"pump {request.station_id}/{request.pump_id}"
|
||||||
elif request.station_id:
|
elif request.station_id:
|
||||||
|
|
@ -192,14 +359,14 @@ class RESTAPIServer:
|
||||||
result = self.emergency_stop_manager.emergency_stop_station(
|
result = self.emergency_stop_manager.emergency_stop_station(
|
||||||
station_id=request.station_id,
|
station_id=request.station_id,
|
||||||
reason=request.reason,
|
reason=request.reason,
|
||||||
user_id=request.triggered_by
|
user_id=triggered_by
|
||||||
)
|
)
|
||||||
scope = f"station {request.station_id}"
|
scope = f"station {request.station_id}"
|
||||||
else:
|
else:
|
||||||
# System-wide stop
|
# System-wide stop
|
||||||
result = self.emergency_stop_manager.emergency_stop_system(
|
result = self.emergency_stop_manager.emergency_stop_system(
|
||||||
reason=request.reason,
|
reason=request.reason,
|
||||||
user_id=request.triggered_by
|
user_id=triggered_by
|
||||||
)
|
)
|
||||||
scope = "system-wide"
|
scope = "system-wide"
|
||||||
|
|
||||||
|
|
@ -208,7 +375,7 @@ class RESTAPIServer:
|
||||||
"status": "emergency_stop_triggered",
|
"status": "emergency_stop_triggered",
|
||||||
"scope": scope,
|
"scope": scope,
|
||||||
"reason": request.reason,
|
"reason": request.reason,
|
||||||
"triggered_by": request.triggered_by,
|
"triggered_by": triggered_by,
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
@ -217,6 +384,8 @@ class RESTAPIServer:
|
||||||
detail="Failed to trigger emergency stop"
|
detail="Failed to trigger emergency stop"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("failed_to_trigger_emergency_stop", error=str(e))
|
logger.error("failed_to_trigger_emergency_stop", error=str(e))
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -231,19 +400,26 @@ class RESTAPIServer:
|
||||||
)
|
)
|
||||||
async def clear_emergency_stop(
|
async def clear_emergency_stop(
|
||||||
request: EmergencyStopClearRequest,
|
request: EmergencyStopClearRequest,
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
token_data: TokenData = Depends(require_permission("clear_emergency_stop"))
|
||||||
):
|
):
|
||||||
"""Clear all active emergency stops."""
|
"""
|
||||||
|
Clear all active emergency stops.
|
||||||
|
|
||||||
|
Requires permission: clear_emergency_stop
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Use authenticated user as cleared_by
|
||||||
|
cleared_by = token_data.username
|
||||||
|
|
||||||
# Clear system-wide emergency stop
|
# Clear system-wide emergency stop
|
||||||
self.emergency_stop_manager.clear_emergency_stop_system(
|
self.emergency_stop_manager.clear_emergency_stop_system(
|
||||||
reason=request.notes,
|
reason=request.notes,
|
||||||
user_id=request.cleared_by
|
user_id=cleared_by
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "emergency_stop_cleared",
|
"status": "emergency_stop_cleared",
|
||||||
"cleared_by": request.cleared_by,
|
"cleared_by": cleared_by,
|
||||||
"notes": request.notes,
|
"notes": request.notes,
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
|
|
@ -261,9 +437,13 @@ class RESTAPIServer:
|
||||||
tags=["Emergency Stop"]
|
tags=["Emergency Stop"]
|
||||||
)
|
)
|
||||||
async def get_emergency_stop_status(
|
async def get_emergency_stop_status(
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
token_data: TokenData = Depends(require_permission("read_safety_status"))
|
||||||
):
|
):
|
||||||
"""Check if any emergency stops are active."""
|
"""
|
||||||
|
Check if any emergency stops are active.
|
||||||
|
|
||||||
|
Requires permission: read_safety_status
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Check system-wide emergency stop
|
# Check system-wide emergency stop
|
||||||
system_stop = self.emergency_stop_manager.system_emergency_stop
|
system_stop = self.emergency_stop_manager.system_emergency_stop
|
||||||
|
|
@ -291,18 +471,33 @@ class RESTAPIServer:
|
||||||
"""Start the REST API server."""
|
"""Start the REST API server."""
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
# Get TLS configuration
|
||||||
|
tls_manager = get_tls_manager()
|
||||||
|
ssl_config = tls_manager.get_rest_api_ssl_config()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"rest_api_server_starting",
|
"rest_api_server_starting",
|
||||||
host=self.host,
|
host=self.host,
|
||||||
port=self.port
|
port=self.port,
|
||||||
|
tls_enabled=ssl_config is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
config = uvicorn.Config(
|
config_kwargs = {
|
||||||
self.app,
|
"app": self.app,
|
||||||
host=self.host,
|
"host": self.host,
|
||||||
port=self.port,
|
"port": self.port,
|
||||||
log_level="info"
|
"log_level": "info"
|
||||||
)
|
}
|
||||||
|
|
||||||
|
# Add SSL configuration if available
|
||||||
|
if ssl_config:
|
||||||
|
certfile, keyfile = ssl_config
|
||||||
|
config_kwargs.update({
|
||||||
|
"ssl_certfile": certfile,
|
||||||
|
"ssl_keyfile": keyfile
|
||||||
|
})
|
||||||
|
|
||||||
|
config = uvicorn.Config(**config_kwargs)
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
await server.serve()
|
await server.serve()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,368 @@
|
||||||
|
"""
|
||||||
|
Unit tests for compliance audit components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from src.core.compliance_audit import (
|
||||||
|
ComplianceAuditLogger, AuditEventType, AuditSeverity
|
||||||
|
)
|
||||||
|
from config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestComplianceAuditLogger:
|
||||||
|
"""Test cases for ComplianceAuditLogger."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.mock_db_client = Mock()
|
||||||
|
self.audit_logger = ComplianceAuditLogger(self.mock_db_client)
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
self.original_audit_enabled = settings.audit_log_enabled
|
||||||
|
settings.audit_log_enabled = True
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Clean up test fixtures."""
|
||||||
|
# Restore original settings
|
||||||
|
settings.audit_log_enabled = self.original_audit_enabled
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
"""Test initialization of ComplianceAuditLogger."""
|
||||||
|
assert self.audit_logger.db_client == self.mock_db_client
|
||||||
|
assert self.audit_logger.logger is not None
|
||||||
|
|
||||||
|
def test_log_compliance_event_success(self):
|
||||||
|
"""Test successful logging of compliance event."""
|
||||||
|
event_type = AuditEventType.USER_LOGIN
|
||||||
|
severity = AuditSeverity.LOW
|
||||||
|
user_id = "test_user"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger.logger, 'info') as mock_log:
|
||||||
|
self.audit_logger.log_compliance_event(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
action="login",
|
||||||
|
resource="system",
|
||||||
|
result="success"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify structured logging
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == "user_login"
|
||||||
|
assert call_args["severity"] == "low"
|
||||||
|
assert call_args["user_id"] == "test_user"
|
||||||
|
assert call_args["ip_address"] == "192.168.1.100"
|
||||||
|
|
||||||
|
# Verify database logging
|
||||||
|
self.mock_db_client.execute.assert_called_once()
|
||||||
|
|
||||||
|
def test_log_compliance_event_database_disabled(self):
|
||||||
|
"""Test logging when database audit is disabled."""
|
||||||
|
settings.audit_log_enabled = False
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger.logger, 'info') as mock_log:
|
||||||
|
self.audit_logger.log_compliance_event(
|
||||||
|
event_type=AuditEventType.USER_LOGIN,
|
||||||
|
severity=AuditSeverity.LOW,
|
||||||
|
user_id="test_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify structured logging still occurs
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
|
||||||
|
# Verify database logging is skipped
|
||||||
|
self.mock_db_client.execute.assert_not_called()
|
||||||
|
|
||||||
|
def test_log_compliance_event_database_error(self):
|
||||||
|
"""Test logging when database operation fails."""
|
||||||
|
# Mock database error
|
||||||
|
self.mock_db_client.execute.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger.logger, 'error') as mock_error_log:
|
||||||
|
self.audit_logger.log_compliance_event(
|
||||||
|
event_type=AuditEventType.USER_LOGIN,
|
||||||
|
severity=AuditSeverity.LOW,
|
||||||
|
user_id="test_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify error logging
|
||||||
|
mock_error_log.assert_called_once()
|
||||||
|
call_args = mock_error_log.call_args[1]
|
||||||
|
assert "Database error" in call_args["error"]
|
||||||
|
assert call_args["event_type"] == "user_login"
|
||||||
|
|
||||||
|
def test_log_user_authentication_success(self):
|
||||||
|
"""Test logging successful user authentication."""
|
||||||
|
user_id = "test_user"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_user_authentication(
|
||||||
|
user_id=user_id,
|
||||||
|
result="success",
|
||||||
|
ip_address=ip_address
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct event type and severity
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.USER_LOGIN
|
||||||
|
assert call_args["severity"] == AuditSeverity.LOW
|
||||||
|
assert call_args["user_id"] == user_id
|
||||||
|
assert call_args["ip_address"] == ip_address
|
||||||
|
assert call_args["result"] == "success"
|
||||||
|
|
||||||
|
def test_log_user_authentication_failure(self):
|
||||||
|
"""Test logging failed user authentication."""
|
||||||
|
user_id = "test_user"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
reason = "Invalid credentials"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_user_authentication(
|
||||||
|
user_id=user_id,
|
||||||
|
result="failure",
|
||||||
|
ip_address=ip_address,
|
||||||
|
reason=reason
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct event type and severity
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.INVALID_AUTHENTICATION
|
||||||
|
assert call_args["severity"] == AuditSeverity.HIGH
|
||||||
|
assert call_args["reason"] == reason
|
||||||
|
|
||||||
|
def test_log_access_control_granted(self):
|
||||||
|
"""Test logging granted access control."""
|
||||||
|
user_id = "test_user"
|
||||||
|
action = "read"
|
||||||
|
resource = "station_data"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_access_control(
|
||||||
|
user_id=user_id,
|
||||||
|
action=action,
|
||||||
|
resource=resource,
|
||||||
|
result="granted",
|
||||||
|
ip_address=ip_address
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct event type
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.DATA_READ
|
||||||
|
assert call_args["severity"] == AuditSeverity.MEDIUM
|
||||||
|
|
||||||
|
def test_log_access_control_denied(self):
|
||||||
|
"""Test logging denied access control."""
|
||||||
|
user_id = "test_user"
|
||||||
|
action = "write"
|
||||||
|
resource = "system_config"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
reason = "Insufficient permissions"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_access_control(
|
||||||
|
user_id=user_id,
|
||||||
|
action=action,
|
||||||
|
resource=resource,
|
||||||
|
result="denied",
|
||||||
|
ip_address=ip_address,
|
||||||
|
reason=reason
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct event type and severity
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.ACCESS_DENIED
|
||||||
|
assert call_args["severity"] == AuditSeverity.HIGH
|
||||||
|
assert call_args["reason"] == reason
|
||||||
|
|
||||||
|
def test_log_control_operation_setpoint_change(self):
|
||||||
|
"""Test logging setpoint change operation."""
|
||||||
|
user_id = "operator_user"
|
||||||
|
station_id = "station_001"
|
||||||
|
action = "setpoint_change"
|
||||||
|
resource = "pump_speed"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_control_operation(
|
||||||
|
user_id=user_id,
|
||||||
|
station_id=station_id,
|
||||||
|
pump_id=None,
|
||||||
|
action=action,
|
||||||
|
resource=resource,
|
||||||
|
result="success",
|
||||||
|
ip_address=ip_address
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct event type and severity
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.SETPOINT_CHANGED
|
||||||
|
assert call_args["severity"] == AuditSeverity.HIGH
|
||||||
|
assert call_args["station_id"] == station_id
|
||||||
|
|
||||||
|
def test_log_control_operation_emergency_stop(self):
|
||||||
|
"""Test logging emergency stop operation."""
|
||||||
|
user_id = "operator_user"
|
||||||
|
station_id = "station_001"
|
||||||
|
action = "emergency_stop"
|
||||||
|
resource = "system"
|
||||||
|
ip_address = "192.168.1.100"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_control_operation(
|
||||||
|
user_id=user_id,
|
||||||
|
station_id=station_id,
|
||||||
|
pump_id="pump_001",
|
||||||
|
action=action,
|
||||||
|
resource=resource,
|
||||||
|
result="success",
|
||||||
|
ip_address=ip_address
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct event type and severity
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.EMERGENCY_STOP_ACTIVATED
|
||||||
|
assert call_args["severity"] == AuditSeverity.CRITICAL
|
||||||
|
assert call_args["pump_id"] == "pump_001"
|
||||||
|
|
||||||
|
def test_log_security_event(self):
|
||||||
|
"""Test logging security events."""
|
||||||
|
event_type = AuditEventType.CERTIFICATE_EXPIRED
|
||||||
|
severity = AuditSeverity.HIGH
|
||||||
|
user_id = "system"
|
||||||
|
station_id = "station_001"
|
||||||
|
reason = "TLS certificate expired"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
self.audit_logger.log_security_event(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
station_id=station_id,
|
||||||
|
reason=reason
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct parameters
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == event_type
|
||||||
|
assert call_args["severity"] == severity
|
||||||
|
assert call_args["user_id"] == user_id
|
||||||
|
assert call_args["station_id"] == station_id
|
||||||
|
assert call_args["reason"] == reason
|
||||||
|
|
||||||
|
def test_get_audit_trail_success(self):
|
||||||
|
"""Test successful retrieval of audit trail."""
|
||||||
|
# Mock database result
|
||||||
|
mock_result = [
|
||||||
|
{"event_type": "user_login", "user_id": "test_user"},
|
||||||
|
{"event_type": "data_read", "user_id": "test_user"}
|
||||||
|
]
|
||||||
|
self.mock_db_client.fetch_all.return_value = mock_result
|
||||||
|
|
||||||
|
start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
end_time = datetime(2024, 1, 31, tzinfo=timezone.utc)
|
||||||
|
user_id = "test_user"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
result = self.audit_logger.get_audit_trail(
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify database query
|
||||||
|
self.mock_db_client.fetch_all.assert_called_once()
|
||||||
|
|
||||||
|
# Verify audit trail access logging
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_args = mock_log.call_args[1]
|
||||||
|
assert call_args["event_type"] == AuditEventType.AUDIT_LOG_ACCESSED
|
||||||
|
assert call_args["user_id"] == user_id
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result == mock_result
|
||||||
|
|
||||||
|
def test_get_audit_trail_error(self):
|
||||||
|
"""Test audit trail retrieval with error."""
|
||||||
|
# Mock database error
|
||||||
|
self.mock_db_client.fetch_all.side_effect = Exception("Query failed")
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger.logger, 'error') as mock_error_log:
|
||||||
|
result = self.audit_logger.get_audit_trail()
|
||||||
|
|
||||||
|
# Verify error logging
|
||||||
|
mock_error_log.assert_called_once()
|
||||||
|
|
||||||
|
# Verify empty result
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_generate_compliance_report_success(self):
|
||||||
|
"""Test successful compliance report generation."""
|
||||||
|
# Mock database result
|
||||||
|
mock_result = [
|
||||||
|
{"event_type": "user_login", "severity": "low", "count": 10, "unique_users": 2},
|
||||||
|
{"event_type": "access_denied", "severity": "high", "count": 2, "unique_users": 1}
|
||||||
|
]
|
||||||
|
self.mock_db_client.fetch_all.return_value = mock_result
|
||||||
|
|
||||||
|
start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
end_time = datetime(2024, 1, 31, tzinfo=timezone.utc)
|
||||||
|
compliance_standard = "IEC_62443"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
|
||||||
|
report = self.audit_logger.generate_compliance_report(
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
compliance_standard=compliance_standard
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify database query
|
||||||
|
self.mock_db_client.fetch_all.assert_called_once()
|
||||||
|
|
||||||
|
# Verify compliance report logging
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
|
||||||
|
# Verify report structure
|
||||||
|
assert report["compliance_standard"] == compliance_standard
|
||||||
|
assert report["summary"]["total_events"] == 12
|
||||||
|
assert report["summary"]["unique_users"] == 3
|
||||||
|
assert report["summary"]["event_types"] == 2
|
||||||
|
assert report["events_by_type"] == mock_result
|
||||||
|
|
||||||
|
def test_generate_compliance_report_error(self):
|
||||||
|
"""Test compliance report generation with error."""
|
||||||
|
# Mock database error
|
||||||
|
self.mock_db_client.fetch_all.side_effect = Exception("Report generation failed")
|
||||||
|
|
||||||
|
start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
end_time = datetime(2024, 1, 31, tzinfo=timezone.utc)
|
||||||
|
compliance_standard = "ISO_27001"
|
||||||
|
|
||||||
|
with patch.object(self.audit_logger.logger, 'error') as mock_error_log:
|
||||||
|
report = self.audit_logger.generate_compliance_report(
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
compliance_standard=compliance_standard
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify error logging
|
||||||
|
mock_error_log.assert_called_once()
|
||||||
|
|
||||||
|
# Verify error result
|
||||||
|
assert "error" in report
|
||||||
|
assert "Report generation failed" in report["error"]
|
||||||
|
|
@ -0,0 +1,373 @@
|
||||||
|
"""
|
||||||
|
Unit tests for security layer components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import jwt
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
from src.core.security import (
|
||||||
|
AuthenticationManager,
|
||||||
|
AuthorizationManager,
|
||||||
|
SecurityManager,
|
||||||
|
UserRole,
|
||||||
|
User,
|
||||||
|
TokenData
|
||||||
|
)
|
||||||
|
from config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthenticationManager:
|
||||||
|
"""Test cases for AuthenticationManager."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.auth_manager = AuthenticationManager(
|
||||||
|
secret_key="test-secret-key",
|
||||||
|
token_expire_minutes=60
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_hash_password(self):
|
||||||
|
"""Test password hashing."""
|
||||||
|
password = "testpassword123"
|
||||||
|
hashed = self.auth_manager.hash_password(password)
|
||||||
|
|
||||||
|
# Verify the hash is valid
|
||||||
|
assert bcrypt.checkpw(password.encode('utf-8'), hashed.encode('utf-8'))
|
||||||
|
|
||||||
|
# Verify different passwords produce different hashes
|
||||||
|
different_password = "differentpassword"
|
||||||
|
different_hashed = self.auth_manager.hash_password(different_password)
|
||||||
|
assert hashed != different_hashed
|
||||||
|
|
||||||
|
def test_verify_password(self):
|
||||||
|
"""Test password verification."""
|
||||||
|
password = "testpassword123"
|
||||||
|
hashed = self.auth_manager.hash_password(password)
|
||||||
|
|
||||||
|
# Test correct password
|
||||||
|
assert self.auth_manager.verify_password(password, hashed)
|
||||||
|
|
||||||
|
# Test incorrect password
|
||||||
|
assert not self.auth_manager.verify_password("wrongpassword", hashed)
|
||||||
|
|
||||||
|
def test_create_user(self):
|
||||||
|
"""Test user creation."""
|
||||||
|
user = self.auth_manager.create_user(
|
||||||
|
user_id="test_user_001",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
password="testpassword123"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user.user_id == "test_user_001"
|
||||||
|
assert user.username == "testuser"
|
||||||
|
assert user.email == "test@example.com"
|
||||||
|
assert user.role == UserRole.OPERATOR
|
||||||
|
assert user.active is True
|
||||||
|
assert user.created_at is not None
|
||||||
|
|
||||||
|
# Verify password was hashed and stored
|
||||||
|
assert "test_user_001" in self.auth_manager.password_hashes
|
||||||
|
hashed_password = self.auth_manager.password_hashes["test_user_001"]
|
||||||
|
assert bcrypt.checkpw("testpassword123".encode('utf-8'), hashed_password.encode('utf-8'))
|
||||||
|
|
||||||
|
def test_create_user_duplicate_id(self):
|
||||||
|
"""Test creating user with duplicate ID."""
|
||||||
|
self.auth_manager.create_user(
|
||||||
|
user_id="duplicate_user",
|
||||||
|
username="user1",
|
||||||
|
email="user1@example.com",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
password="password1"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="User with ID duplicate_user already exists"):
|
||||||
|
self.auth_manager.create_user(
|
||||||
|
user_id="duplicate_user",
|
||||||
|
username="user2",
|
||||||
|
email="user2@example.com",
|
||||||
|
role=UserRole.ENGINEER,
|
||||||
|
password="password2"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_authenticate_user_success(self):
|
||||||
|
"""Test successful user authentication."""
|
||||||
|
# Create a test user
|
||||||
|
self.auth_manager.create_user(
|
||||||
|
user_id="auth_test_user",
|
||||||
|
username="authuser",
|
||||||
|
email="auth@example.com",
|
||||||
|
role=UserRole.ENGINEER,
|
||||||
|
password="authpassword123"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test authentication
|
||||||
|
user = self.auth_manager.authenticate_user("authuser", "authpassword123")
|
||||||
|
|
||||||
|
assert user is not None
|
||||||
|
assert user.user_id == "auth_test_user"
|
||||||
|
assert user.username == "authuser"
|
||||||
|
assert user.role == UserRole.ENGINEER
|
||||||
|
assert user.last_login is not None
|
||||||
|
|
||||||
|
def test_authenticate_user_wrong_password(self):
|
||||||
|
"""Test authentication with wrong password."""
|
||||||
|
self.auth_manager.create_user(
|
||||||
|
user_id="wrong_pass_user",
|
||||||
|
username="wrongpass",
|
||||||
|
email="wrong@example.com",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
password="correctpassword"
|
||||||
|
)
|
||||||
|
|
||||||
|
user = self.auth_manager.authenticate_user("wrongpass", "wrongpassword")
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_authenticate_user_not_found(self):
|
||||||
|
"""Test authentication with non-existent user."""
|
||||||
|
user = self.auth_manager.authenticate_user("nonexistent", "password")
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_authenticate_user_inactive(self):
|
||||||
|
"""Test authentication with inactive user."""
|
||||||
|
# Create inactive user
|
||||||
|
user = User(
|
||||||
|
user_id="inactive_user",
|
||||||
|
username="inactive",
|
||||||
|
email="inactive@example.com",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
active=False,
|
||||||
|
created_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
self.auth_manager.users["inactive_user"] = user
|
||||||
|
self.auth_manager.password_hashes["inactive_user"] = self.auth_manager.hash_password("password")
|
||||||
|
|
||||||
|
result = self.auth_manager.authenticate_user("inactive", "password")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_create_access_token(self):
|
||||||
|
"""Test JWT token creation."""
|
||||||
|
user = User(
|
||||||
|
user_id="token_user",
|
||||||
|
username="tokenuser",
|
||||||
|
email="token@example.com",
|
||||||
|
role=UserRole.ADMINISTRATOR,
|
||||||
|
created_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
token = self.auth_manager.create_access_token(user)
|
||||||
|
|
||||||
|
# Verify token can be decoded
|
||||||
|
payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"])
|
||||||
|
assert payload["user_id"] == "token_user"
|
||||||
|
assert payload["username"] == "tokenuser"
|
||||||
|
assert payload["role"] == "administrator"
|
||||||
|
assert "exp" in payload
|
||||||
|
|
||||||
|
def test_verify_token_success(self):
|
||||||
|
"""Test successful token verification."""
|
||||||
|
# Create user and token
|
||||||
|
user = User(
|
||||||
|
user_id="verify_user",
|
||||||
|
username="verifyuser",
|
||||||
|
email="verify@example.com",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
created_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
self.auth_manager.users["verify_user"] = user
|
||||||
|
|
||||||
|
token = self.auth_manager.create_access_token(user)
|
||||||
|
token_data = self.auth_manager.verify_token(token)
|
||||||
|
|
||||||
|
assert token_data is not None
|
||||||
|
assert token_data.user_id == "verify_user"
|
||||||
|
assert token_data.username == "verifyuser"
|
||||||
|
assert token_data.role == UserRole.OPERATOR
|
||||||
|
|
||||||
|
def test_verify_token_expired(self):
|
||||||
|
"""Test verification of expired token."""
|
||||||
|
# Create expired token
|
||||||
|
expired_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||||
|
token_data = TokenData(
|
||||||
|
user_id="expired_user",
|
||||||
|
username="expireduser",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
exp=expired_time
|
||||||
|
)
|
||||||
|
|
||||||
|
token = jwt.encode(
|
||||||
|
token_data.dict(),
|
||||||
|
"test-secret-key",
|
||||||
|
algorithm="HS256"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.auth_manager.verify_token(token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_verify_token_invalid_signature(self):
|
||||||
|
"""Test verification of token with invalid signature."""
|
||||||
|
# Create token with wrong secret
|
||||||
|
user = User(
|
||||||
|
user_id="invalid_user",
|
||||||
|
username="invaliduser",
|
||||||
|
email="invalid@example.com",
|
||||||
|
role=UserRole.OPERATOR,
|
||||||
|
created_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
token = jwt.encode(
|
||||||
|
{
|
||||||
|
"user_id": "invalid_user",
|
||||||
|
"username": "invaliduser",
|
||||||
|
"role": "operator",
|
||||||
|
"exp": datetime.now(timezone.utc) + timedelta(minutes=60)
|
||||||
|
},
|
||||||
|
"wrong-secret-key",
|
||||||
|
algorithm="HS256"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.auth_manager.verify_token(token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthorizationManager:
|
||||||
|
"""Test cases for AuthorizationManager."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.authz_manager = AuthorizationManager()
|
||||||
|
|
||||||
|
def test_has_permission_read_only(self):
|
||||||
|
"""Test permissions for read-only role."""
|
||||||
|
assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_pump_status")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_safety_status")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_audit_logs")
|
||||||
|
|
||||||
|
# Should not have write permissions
|
||||||
|
assert not self.authz_manager.has_permission(UserRole.READ_ONLY, "emergency_stop")
|
||||||
|
assert not self.authz_manager.has_permission(UserRole.READ_ONLY, "configure_safety_limits")
|
||||||
|
|
||||||
|
def test_has_permission_operator(self):
|
||||||
|
"""Test permissions for operator role."""
|
||||||
|
assert self.authz_manager.has_permission(UserRole.OPERATOR, "read_pump_status")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.OPERATOR, "emergency_stop")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.OPERATOR, "clear_emergency_stop")
|
||||||
|
|
||||||
|
# Should not have configuration permissions
|
||||||
|
assert not self.authz_manager.has_permission(UserRole.OPERATOR, "configure_safety_limits")
|
||||||
|
assert not self.authz_manager.has_permission(UserRole.OPERATOR, "manage_users")
|
||||||
|
|
||||||
|
def test_has_permission_engineer(self):
|
||||||
|
"""Test permissions for engineer role."""
|
||||||
|
assert self.authz_manager.has_permission(UserRole.ENGINEER, "read_pump_status")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.ENGINEER, "emergency_stop")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.ENGINEER, "configure_safety_limits")
|
||||||
|
assert self.authz_manager.has_permission(UserRole.ENGINEER, "manage_pump_configuration")
|
||||||
|
|
||||||
|
# Should not have administrative permissions
|
||||||
|
assert not self.authz_manager.has_permission(UserRole.ENGINEER, "manage_users")
|
||||||
|
|
||||||
|
def test_has_permission_administrator(self):
|
||||||
|
"""Test permissions for administrator role."""
|
||||||
|
# Administrator should have all permissions
|
||||||
|
all_permissions = [
|
||||||
|
"read_pump_status", "read_safety_status", "read_audit_logs",
|
||||||
|
"emergency_stop", "clear_emergency_stop", "view_alerts",
|
||||||
|
"configure_safety_limits", "manage_pump_configuration",
|
||||||
|
"view_system_metrics", "manage_users", "configure_system",
|
||||||
|
"access_all_stations"
|
||||||
|
]
|
||||||
|
|
||||||
|
for permission in all_permissions:
|
||||||
|
assert self.authz_manager.has_permission(UserRole.ADMINISTRATOR, permission)
|
||||||
|
|
||||||
|
def test_has_permission_unknown_permission(self):
|
||||||
|
"""Test unknown permission."""
|
||||||
|
assert not self.authz_manager.has_permission(UserRole.ADMINISTRATOR, "unknown_permission")
|
||||||
|
|
||||||
|
def test_get_allowed_actions(self):
|
||||||
|
"""Test getting all allowed actions for a role."""
|
||||||
|
operator_permissions = self.authz_manager.get_allowed_actions(UserRole.OPERATOR)
|
||||||
|
|
||||||
|
assert "read_pump_status" in operator_permissions
|
||||||
|
assert "emergency_stop" in operator_permissions
|
||||||
|
assert "clear_emergency_stop" in operator_permissions
|
||||||
|
assert "configure_safety_limits" not in operator_permissions
|
||||||
|
|
||||||
|
def test_can_access_station(self):
|
||||||
|
"""Test station access control."""
|
||||||
|
# Administrators can access all stations
|
||||||
|
assert self.authz_manager.can_access_station(UserRole.ADMINISTRATOR, "STATION_001")
|
||||||
|
|
||||||
|
# Other roles can access all stations (for now)
|
||||||
|
assert self.authz_manager.can_access_station(UserRole.OPERATOR, "STATION_001")
|
||||||
|
assert self.authz_manager.can_access_station(UserRole.ENGINEER, "STATION_001")
|
||||||
|
assert self.authz_manager.can_access_station(UserRole.READ_ONLY, "STATION_001")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecurityManager:
|
||||||
|
"""Test cases for SecurityManager."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.security_manager = SecurityManager()
|
||||||
|
|
||||||
|
def test_authenticate_success(self):
|
||||||
|
"""Test successful authentication."""
|
||||||
|
# Default users are created during initialization
|
||||||
|
token = self.security_manager.authenticate("operator", "operator123")
|
||||||
|
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
|
# Verify token is valid
|
||||||
|
token_data = self.security_manager.verify_access_token(token)
|
||||||
|
assert token_data is not None
|
||||||
|
assert token_data.username == "operator"
|
||||||
|
assert token_data.role == UserRole.OPERATOR
|
||||||
|
|
||||||
|
def test_authenticate_failure(self):
|
||||||
|
"""Test failed authentication."""
|
||||||
|
token = self.security_manager.authenticate("nonexistent", "password")
|
||||||
|
assert token is None
|
||||||
|
|
||||||
|
def test_check_permission(self):
|
||||||
|
"""Test permission checking."""
|
||||||
|
# Create token for operator
|
||||||
|
token = self.security_manager.authenticate("operator", "operator123")
|
||||||
|
token_data = self.security_manager.verify_access_token(token)
|
||||||
|
|
||||||
|
# Operator should have emergency_stop permission
|
||||||
|
assert self.security_manager.check_permission(token_data, "emergency_stop")
|
||||||
|
|
||||||
|
# Operator should not have manage_users permission
|
||||||
|
assert not self.security_manager.check_permission(token_data, "manage_users")
|
||||||
|
|
||||||
|
def test_get_user_permissions(self):
|
||||||
|
"""Test getting user permissions."""
|
||||||
|
# Create token for engineer
|
||||||
|
token = self.security_manager.authenticate("engineer", "engineer123")
|
||||||
|
token_data = self.security_manager.verify_access_token(token)
|
||||||
|
|
||||||
|
permissions = self.security_manager.get_user_permissions(token_data)
|
||||||
|
|
||||||
|
# Engineer should have specific permissions
|
||||||
|
assert "configure_safety_limits" in permissions
|
||||||
|
assert "manage_pump_configuration" in permissions
|
||||||
|
assert "emergency_stop" in permissions
|
||||||
|
|
||||||
|
# Engineer should not have administrative permissions
|
||||||
|
assert "manage_users" not in permissions
|
||||||
|
|
||||||
|
def test_can_access_station(self):
|
||||||
|
"""Test station access control."""
|
||||||
|
# Create token for operator
|
||||||
|
token = self.security_manager.authenticate("operator", "operator123")
|
||||||
|
token_data = self.security_manager.verify_access_token(token)
|
||||||
|
|
||||||
|
# Operator should be able to access any station
|
||||||
|
assert self.security_manager.can_access_station(token_data, "STATION_001")
|
||||||
|
assert self.security_manager.can_access_station(token_data, "STATION_002")
|
||||||
|
|
@ -0,0 +1,243 @@
|
||||||
|
"""
|
||||||
|
Unit tests for TLS manager components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import unittest.mock
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.core.tls_manager import TLSManager
|
||||||
|
from config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestTLSManager:
|
||||||
|
"""Test cases for TLSManager."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
# Create temporary directory for test certificates
|
||||||
|
self.temp_dir = tempfile.mkdtemp()
|
||||||
|
self.cert_path = os.path.join(self.temp_dir, "test_cert.pem")
|
||||||
|
self.key_path = os.path.join(self.temp_dir, "test_key.pem")
|
||||||
|
|
||||||
|
# Create dummy certificate files for testing
|
||||||
|
with open(self.cert_path, 'w') as f:
|
||||||
|
f.write("DUMMY CERTIFICATE")
|
||||||
|
with open(self.key_path, 'w') as f:
|
||||||
|
f.write("DUMMY PRIVATE KEY")
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
self.original_tls_enabled = settings.tls_enabled
|
||||||
|
self.original_cert_path = settings.tls_cert_path
|
||||||
|
self.original_key_path = settings.tls_key_path
|
||||||
|
|
||||||
|
settings.tls_enabled = True
|
||||||
|
settings.tls_cert_path = self.cert_path
|
||||||
|
settings.tls_key_path = self.key_path
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Clean up test fixtures."""
|
||||||
|
# Restore original settings
|
||||||
|
settings.tls_enabled = self.original_tls_enabled
|
||||||
|
settings.tls_cert_path = self.original_cert_path
|
||||||
|
settings.tls_key_path = self.original_key_path
|
||||||
|
|
||||||
|
# Clean up temporary directory
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(self.temp_dir)
|
||||||
|
|
||||||
|
def test_initialization_with_valid_certificates(self):
|
||||||
|
"""Test initialization with valid certificate files."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
assert tls_manager.tls_enabled is True
|
||||||
|
assert tls_manager.cert_path == self.cert_path
|
||||||
|
assert tls_manager.key_path == self.key_path
|
||||||
|
|
||||||
|
def test_initialization_with_missing_certificates(self):
|
||||||
|
"""Test initialization with missing certificate files."""
|
||||||
|
# Set invalid paths
|
||||||
|
settings.tls_cert_path = "/nonexistent/cert.pem"
|
||||||
|
settings.tls_key_path = "/nonexistent/key.pem"
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Should still initialize but with validation failure
|
||||||
|
assert tls_manager.tls_enabled is True
|
||||||
|
assert tls_manager.cert_path == "/nonexistent/cert.pem"
|
||||||
|
|
||||||
|
def test_initialization_with_tls_disabled(self):
|
||||||
|
"""Test initialization with TLS disabled."""
|
||||||
|
settings.tls_enabled = False
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
assert tls_manager.tls_enabled is False
|
||||||
|
|
||||||
|
@patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
|
||||||
|
def test_validate_certificates_success(self, mock_get_expiry):
|
||||||
|
"""Test successful certificate validation."""
|
||||||
|
# Mock certificate expiry to be in the future
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
future_date = datetime(2030, 12, 31, 23, 59, 59, tzinfo=timezone.utc)
|
||||||
|
mock_get_expiry.return_value = future_date
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager._validate_certificates()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
|
||||||
|
def test_validate_certificates_expired(self, mock_get_expiry):
|
||||||
|
"""Test validation of expired certificate."""
|
||||||
|
# Mock certificate expiry to be in the past
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
past_date = datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||||
|
mock_get_expiry.return_value = past_date
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager._validate_certificates()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
|
||||||
|
def test_validate_certificates_parsing_error(self, mock_get_expiry):
|
||||||
|
"""Test certificate validation with parsing error."""
|
||||||
|
# Mock certificate parsing to raise an exception
|
||||||
|
mock_get_expiry.side_effect = Exception("Parse error")
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager._validate_certificates()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@patch('src.core.tls_manager.ssl')
|
||||||
|
def test_create_ssl_context_success(self, mock_ssl):
|
||||||
|
"""Test successful SSL context creation."""
|
||||||
|
# Mock SSL context
|
||||||
|
mock_context = Mock()
|
||||||
|
mock_ssl.create_default_context.return_value = mock_context
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock validation to succeed
|
||||||
|
with patch.object(tls_manager, '_validate_certificates', return_value=True):
|
||||||
|
result = tls_manager.create_ssl_context()
|
||||||
|
|
||||||
|
assert result is mock_context
|
||||||
|
# Check that create_default_context was called with CLIENT_AUTH purpose
|
||||||
|
mock_ssl.create_default_context.assert_called_once()
|
||||||
|
call_args = mock_ssl.create_default_context.call_args
|
||||||
|
assert call_args[0][0] == mock_ssl.Purpose.CLIENT_AUTH
|
||||||
|
mock_context.load_cert_chain.assert_called_once_with(self.cert_path, self.key_path)
|
||||||
|
|
||||||
|
def test_create_ssl_context_tls_disabled(self):
|
||||||
|
"""Test SSL context creation with TLS disabled."""
|
||||||
|
settings.tls_enabled = False
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager.create_ssl_context()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_create_ssl_context_invalid_certificates(self):
|
||||||
|
"""Test SSL context creation with invalid certificates."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock validation to fail
|
||||||
|
with patch.object(tls_manager, '_validate_certificates', return_value=False):
|
||||||
|
result = tls_manager.create_ssl_context()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_rest_api_ssl_config_success(self):
|
||||||
|
"""Test successful REST API SSL configuration."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock validation to succeed
|
||||||
|
with patch.object(tls_manager, '_validate_certificates', return_value=True):
|
||||||
|
result = tls_manager.get_rest_api_ssl_config()
|
||||||
|
|
||||||
|
assert result == (self.cert_path, self.key_path)
|
||||||
|
|
||||||
|
def test_get_rest_api_ssl_config_tls_disabled(self):
|
||||||
|
"""Test REST API SSL configuration with TLS disabled."""
|
||||||
|
settings.tls_enabled = False
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager.get_rest_api_ssl_config()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_rest_api_ssl_config_invalid_certificates(self):
|
||||||
|
"""Test REST API SSL configuration with invalid certificates."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock validation to fail
|
||||||
|
with patch.object(tls_manager, '_validate_certificates', return_value=False):
|
||||||
|
result = tls_manager.get_rest_api_ssl_config()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_check_certificate_rotation_needed(self):
|
||||||
|
"""Test certificate rotation check when needed."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock certificate with near expiry
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
near_expiry = datetime.now(timezone.utc) + timedelta(days=5) # Expires in 5 days
|
||||||
|
tls_manager.cert_expiry_dates[self.cert_path] = near_expiry
|
||||||
|
|
||||||
|
result = tls_manager.check_certificate_rotation()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_check_certificate_rotation_not_needed(self):
|
||||||
|
"""Test certificate rotation check when not needed."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock certificate with distant expiry
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
distant_expiry = datetime.now(timezone.utc) + timedelta(days=365) # Expires in 1 year
|
||||||
|
tls_manager.cert_expiry_dates[self.cert_path] = distant_expiry
|
||||||
|
|
||||||
|
result = tls_manager.check_certificate_rotation()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_check_certificate_rotation_tls_disabled(self):
|
||||||
|
"""Test certificate rotation check with TLS disabled."""
|
||||||
|
settings.tls_enabled = False
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager.check_certificate_rotation()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@patch('src.core.tls_manager.os.makedirs')
|
||||||
|
def test_generate_self_signed_certificate_success(self, mock_makedirs):
|
||||||
|
"""Test successful self-signed certificate generation."""
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
|
||||||
|
# Mock the entire certificate generation to succeed
|
||||||
|
with patch('builtins.open', unittest.mock.mock_open()) as mock_file:
|
||||||
|
result = tls_manager.generate_self_signed_certificate(self.temp_dir)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert tls_manager.cert_path.endswith("cert.pem")
|
||||||
|
assert tls_manager.key_path.endswith("key.pem")
|
||||||
|
|
||||||
|
@patch('src.core.tls_manager.os.makedirs')
|
||||||
|
def test_generate_self_signed_certificate_failure(self, mock_makedirs):
|
||||||
|
"""Test self-signed certificate generation failure."""
|
||||||
|
# Mock makedirs to raise an exception
|
||||||
|
mock_makedirs.side_effect = Exception("Permission denied")
|
||||||
|
|
||||||
|
tls_manager = TLSManager()
|
||||||
|
result = tls_manager.generate_self_signed_certificate(self.temp_dir)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
Loading…
Reference in New Issue