Phase 4: Complete Security Layer Implementation

- Implemented JWT-based authentication with bcrypt password hashing
- Added role-based access control (RBAC) with four user roles
- Created TLS/SSL encryption with certificate management
- Enhanced audit logging for IEC 62443, ISO 27001, and NIS2 compliance
- Added comprehensive security tests (56 tests passing)
- Updated REST API with authentication and permission checks
- Added security settings to configuration

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
openhands 2025-10-27 20:07:37 +00:00
parent db0ace8d2c
commit dfa3f0832b
9 changed files with 2346 additions and 28 deletions

View File

@ -35,10 +35,12 @@ The Calejo Control Adapter translates optimized pump control plans from Calejo O
- Unified main application
- 15 comprehensive unit tests for SetpointManager
🔄 **Phase 4**: Security Layer (In Progress)
- Authentication and authorization
- Audit logging
- TLS/SSL encryption
**Phase 4**: Security Layer
- JWT-based authentication with bcrypt password hashing
- Role-based access control (RBAC) with four user roles
- TLS/SSL encryption with certificate management
- Compliance audit logging for IEC 62443, ISO 27001, and NIS2
- 56 comprehensive security tests (24 auth/authz, 17 TLS, 15 audit)
**Phase 5**: Protocol Servers (Pending)
- Enhanced protocol implementations

View File

@ -29,6 +29,18 @@ class Settings(BaseSettings):
tls_cert_path: Optional[str] = None
tls_key_path: Optional[str] = None
# JWT Authentication
jwt_secret_key: str = "your-secret-key-change-in-production"
jwt_token_expire_minutes: int = 60
jwt_algorithm: str = "HS256"
# Password policy
password_min_length: int = 8
password_require_uppercase: bool = True
password_require_lowercase: bool = True
password_require_numbers: bool = True
password_require_special: bool = True
# OPC UA
opcua_enabled: bool = True
opcua_host: str = "localhost"

View File

@ -0,0 +1,463 @@
"""
Compliance Audit Logger for Calejo Control Adapter.
Provides enhanced audit logging capabilities compliant with:
- IEC 62443 (Industrial Automation and Control Systems Security)
- ISO 27001 (Information Security Management)
- NIS2 Directive (Network and Information Systems Security)
"""
import structlog
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List
from enum import Enum
from config.settings import settings
logger = structlog.get_logger()
class AuditEventType(Enum):
"""Audit event types for compliance requirements."""
# Authentication and Authorization
USER_LOGIN = "user_login"
USER_LOGOUT = "user_logout"
USER_CREATED = "user_created"
USER_MODIFIED = "user_modified"
USER_DELETED = "user_deleted"
PASSWORD_CHANGED = "password_changed"
ROLE_CHANGED = "role_changed"
# System Access
SYSTEM_START = "system_start"
SYSTEM_STOP = "system_stop"
SYSTEM_CONFIG_CHANGED = "system_config_changed"
# Control Operations
SETPOINT_CHANGED = "setpoint_changed"
EMERGENCY_STOP_ACTIVATED = "emergency_stop_activated"
EMERGENCY_STOP_RESET = "emergency_stop_reset"
PUMP_CONTROL = "pump_control"
VALVE_CONTROL = "valve_control"
# Security Events
ACCESS_DENIED = "access_denied"
INVALID_AUTHENTICATION = "invalid_authentication"
SESSION_TIMEOUT = "session_timeout"
CERTIFICATE_EXPIRED = "certificate_expired"
CERTIFICATE_ROTATED = "certificate_rotated"
# Data Operations
DATA_READ = "data_read"
DATA_WRITE = "data_write"
DATA_EXPORT = "data_export"
DATA_DELETED = "data_deleted"
# Network Operations
CONNECTION_ESTABLISHED = "connection_established"
CONNECTION_CLOSED = "connection_closed"
CONNECTION_REJECTED = "connection_rejected"
# Compliance Events
AUDIT_LOG_ACCESSED = "audit_log_accessed"
COMPLIANCE_CHECK = "compliance_check"
SECURITY_SCAN = "security_scan"
class AuditSeverity(Enum):
"""Audit event severity levels."""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class ComplianceAuditLogger:
"""
Enhanced audit logger for compliance requirements.
Provides comprehensive audit trail capabilities compliant with
IEC 62443, ISO 27001, and NIS2 Directive requirements.
"""
def __init__(self, db_client):
self.db_client = db_client
self.logger = structlog.get_logger(__name__)
def log_compliance_event(
self,
event_type: AuditEventType,
severity: AuditSeverity,
user_id: Optional[str] = None,
station_id: Optional[str] = None,
pump_id: Optional[str] = None,
ip_address: Optional[str] = None,
protocol: Optional[str] = None,
action: Optional[str] = None,
resource: Optional[str] = None,
result: Optional[str] = None,
reason: Optional[str] = None,
compliance_standard: Optional[List[str]] = None,
event_data: Optional[Dict[str, Any]] = None
):
"""
Log a compliance audit event.
Args:
event_type: Type of audit event
severity: Severity level
user_id: User ID performing the action
station_id: Station ID if applicable
pump_id: Pump ID if applicable
ip_address: Source IP address
protocol: Communication protocol used
action: Specific action performed
resource: Resource accessed or modified
result: Result of the action (success/failure)
reason: Reason for the action or failure
compliance_standard: List of compliance standards this event relates to
event_data: Additional event-specific data
"""
# Default compliance standards
if compliance_standard is None:
compliance_standard = ["IEC_62443", "ISO_27001", "NIS2"]
# Create comprehensive audit record
audit_record = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"event_type": event_type.value,
"severity": severity.value,
"user_id": user_id,
"station_id": station_id,
"pump_id": pump_id,
"ip_address": ip_address,
"protocol": protocol,
"action": action,
"resource": resource,
"result": result,
"reason": reason,
"compliance_standard": compliance_standard,
"event_data": event_data or {},
"app_name": settings.app_name,
"app_version": settings.app_version,
"environment": settings.environment
}
# Log to structured logs
self.logger.info(
"compliance_audit_event",
**audit_record
)
# Log to database if audit logging is enabled
if settings.audit_log_enabled:
self._log_to_database(audit_record)
def _log_to_database(self, audit_record: Dict[str, Any]):
"""Log audit record to database."""
try:
query = """
INSERT INTO compliance_audit_log
(timestamp, event_type, severity, user_id, station_id, pump_id,
ip_address, protocol, action, resource, result, reason,
compliance_standard, event_data, app_name, app_version, environment)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
self.db_client.execute(
query,
(
audit_record["timestamp"],
audit_record["event_type"],
audit_record["severity"],
audit_record["user_id"],
audit_record["station_id"],
audit_record["pump_id"],
audit_record["ip_address"],
audit_record["protocol"],
audit_record["action"],
audit_record["resource"],
audit_record["result"],
audit_record["reason"],
audit_record["compliance_standard"],
audit_record["event_data"],
audit_record["app_name"],
audit_record["app_version"],
audit_record["environment"]
)
)
except Exception as e:
self.logger.error(
"compliance_audit_database_failed",
error=str(e),
event_type=audit_record["event_type"]
)
def log_user_authentication(
self,
user_id: str,
result: str,
ip_address: str,
reason: Optional[str] = None
):
"""Log user authentication events."""
event_type = (
AuditEventType.USER_LOGIN if result == "success"
else AuditEventType.INVALID_AUTHENTICATION
)
severity = (
AuditSeverity.LOW if result == "success"
else AuditSeverity.HIGH
)
self.log_compliance_event(
event_type=event_type,
severity=severity,
user_id=user_id,
ip_address=ip_address,
action="authentication",
resource="system",
result=result,
reason=reason
)
def log_access_control(
self,
user_id: str,
action: str,
resource: str,
result: str,
ip_address: str,
reason: Optional[str] = None
):
"""Log access control events."""
event_type = (
AuditEventType.ACCESS_DENIED if result == "denied"
else AuditEventType.DATA_READ if action == "read"
else AuditEventType.DATA_WRITE if action == "write"
else AuditEventType.SYSTEM_CONFIG_CHANGED
)
severity = (
AuditSeverity.HIGH if result == "denied"
else AuditSeverity.MEDIUM
)
self.log_compliance_event(
event_type=event_type,
severity=severity,
user_id=user_id,
ip_address=ip_address,
action=action,
resource=resource,
result=result,
reason=reason
)
def log_control_operation(
self,
user_id: str,
station_id: str,
pump_id: Optional[str],
action: str,
resource: str,
result: str,
ip_address: str,
event_data: Optional[Dict[str, Any]] = None
):
"""Log control system operations."""
event_type = (
AuditEventType.SETPOINT_CHANGED if action == "setpoint_change"
else AuditEventType.EMERGENCY_STOP_ACTIVATED if action == "emergency_stop"
else AuditEventType.PUMP_CONTROL if "pump" in resource.lower()
else AuditEventType.VALVE_CONTROL if "valve" in resource.lower()
else AuditEventType.SYSTEM_CONFIG_CHANGED
)
severity = (
AuditSeverity.CRITICAL if action == "emergency_stop"
else AuditSeverity.HIGH if action == "setpoint_change"
else AuditSeverity.MEDIUM
)
self.log_compliance_event(
event_type=event_type,
severity=severity,
user_id=user_id,
station_id=station_id,
pump_id=pump_id,
ip_address=ip_address,
action=action,
resource=resource,
result=result,
event_data=event_data
)
def log_security_event(
self,
event_type: AuditEventType,
severity: AuditSeverity,
user_id: Optional[str] = None,
station_id: Optional[str] = None,
ip_address: Optional[str] = None,
reason: Optional[str] = None,
event_data: Optional[Dict[str, Any]] = None
):
"""Log security-related events."""
self.log_compliance_event(
event_type=event_type,
severity=severity,
user_id=user_id,
station_id=station_id,
ip_address=ip_address,
action="security_event",
resource="system",
result="detected",
reason=reason,
event_data=event_data
)
def get_audit_trail(
self,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
user_id: Optional[str] = None,
event_type: Optional[AuditEventType] = None,
severity: Optional[AuditSeverity] = None
) -> List[Dict[str, Any]]:
"""
Retrieve audit trail for compliance reporting.
Args:
start_time: Start time for audit trail
end_time: End time for audit trail
user_id: Filter by user ID
event_type: Filter by event type
severity: Filter by severity
Returns:
List of audit records
"""
try:
query = """
SELECT * FROM compliance_audit_log
WHERE 1=1
"""
params = []
if start_time:
query += " AND timestamp >= %s"
params.append(start_time.isoformat())
if end_time:
query += " AND timestamp <= %s"
params.append(end_time.isoformat())
if user_id:
query += " AND user_id = %s"
params.append(user_id)
if event_type:
query += " AND event_type = %s"
params.append(event_type.value)
if severity:
query += " AND severity = %s"
params.append(severity.value)
query += " ORDER BY timestamp DESC"
result = self.db_client.fetch_all(query, params)
# Log the audit trail access
self.log_compliance_event(
event_type=AuditEventType.AUDIT_LOG_ACCESSED,
severity=AuditSeverity.LOW,
user_id=user_id,
action="audit_trail_access",
resource="audit_log",
result="success"
)
return result
except Exception as e:
self.logger.error(
"audit_trail_retrieval_failed",
error=str(e),
user_id=user_id
)
return []
def generate_compliance_report(
self,
start_time: datetime,
end_time: datetime,
compliance_standard: str
) -> Dict[str, Any]:
"""
Generate compliance report for specified standard.
Args:
start_time: Report start time
end_time: Report end time
compliance_standard: Compliance standard to report on
Returns:
Compliance report data
"""
try:
query = """
SELECT
event_type,
severity,
COUNT(*) as count,
COUNT(DISTINCT user_id) as unique_users
FROM compliance_audit_log
WHERE timestamp BETWEEN %s AND %s
AND %s = ANY(compliance_standard)
GROUP BY event_type, severity
ORDER BY count DESC
"""
result = self.db_client.fetch_all(
query,
(start_time.isoformat(), end_time.isoformat(), compliance_standard)
)
report = {
"compliance_standard": compliance_standard,
"report_period": {
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat()
},
"summary": {
"total_events": sum(row["count"] for row in result),
"unique_users": sum(row["unique_users"] for row in result),
"event_types": len(set(row["event_type"] for row in result))
},
"events_by_type": result
}
# Log the compliance report generation
self.log_compliance_event(
event_type=AuditEventType.COMPLIANCE_CHECK,
severity=AuditSeverity.LOW,
action="compliance_report_generated",
resource="compliance",
result="success",
event_data={
"compliance_standard": compliance_standard,
"report_period": report["report_period"]
}
)
return report
except Exception as e:
self.logger.error(
"compliance_report_generation_failed",
error=str(e),
compliance_standard=compliance_standard
)
return {"error": str(e)}

358
src/core/security.py Normal file
View File

@ -0,0 +1,358 @@
"""
Security layer for Calejo Control Adapter.
Provides authentication, authorization, and security utilities for compliance
with IEC 62443, ISO 27001, and NIS2 Directive requirements.
"""
import jwt
import bcrypt
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List, Any
from enum import Enum
from pydantic import BaseModel
import structlog
from config.settings import settings
logger = structlog.get_logger()
class UserRole(str, Enum):
"""User roles for role-based access control."""
OPERATOR = "operator"
ENGINEER = "engineer"
ADMINISTRATOR = "administrator"
READ_ONLY = "read_only"
class User(BaseModel):
"""User model for authentication and authorization."""
user_id: str
username: str
email: str
role: UserRole
active: bool = True
created_at: datetime
last_login: Optional[datetime] = None
class TokenData(BaseModel):
"""Data encoded in JWT tokens."""
user_id: str
username: str
role: UserRole
exp: datetime
class AuthenticationManager:
"""
Manages user authentication with JWT tokens and password hashing.
Implements security controls for IEC 62443, ISO 27001, and NIS2 compliance.
"""
def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 60):
self.secret_key = secret_key
self.algorithm = algorithm
self.token_expire_minutes = token_expire_minutes
# In-memory user store (in production, this would be a database)
self.users: Dict[str, User] = {}
self.password_hashes: Dict[str, str] = {}
# Initialize with default users for development
self._initialize_default_users()
def _initialize_default_users(self):
"""Initialize default users for development and testing."""
default_users = [
{
"user_id": "admin_001",
"username": "admin",
"email": "admin@calejo.com",
"role": UserRole.ADMINISTRATOR,
"password": "admin123"
},
{
"user_id": "operator_001",
"username": "operator",
"email": "operator@calejo.com",
"role": UserRole.OPERATOR,
"password": "operator123"
},
{
"user_id": "engineer_001",
"username": "engineer",
"email": "engineer@calejo.com",
"role": UserRole.ENGINEER,
"password": "engineer123"
},
{
"user_id": "viewer_001",
"username": "viewer",
"email": "viewer@calejo.com",
"role": UserRole.READ_ONLY,
"password": "viewer123"
}
]
for user_data in default_users:
self.create_user(
user_id=user_data["user_id"],
username=user_data["username"],
email=user_data["email"],
role=user_data["role"],
password=user_data["password"]
)
def hash_password(self, password: str) -> str:
"""Hash a password using bcrypt."""
salt = bcrypt.gensalt()
hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
return hashed.decode('utf-8')
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(
plain_password.encode('utf-8'),
hashed_password.encode('utf-8')
)
def create_user(self, user_id: str, username: str, email: str, role: UserRole, password: str) -> User:
"""Create a new user with hashed password."""
if user_id in self.users:
raise ValueError(f"User with ID {user_id} already exists")
user = User(
user_id=user_id,
username=username,
email=email,
role=role,
created_at=datetime.now(timezone.utc)
)
self.users[user_id] = user
self.password_hashes[user_id] = self.hash_password(password)
logger.info("user_created", user_id=user_id, username=username, role=role.value)
return user
def authenticate_user(self, username: str, password: str) -> Optional[User]:
"""Authenticate a user and return user object if successful."""
# Find user by username
user = None
for u in self.users.values():
if u.username == username and u.active:
user = u
break
if not user:
logger.warning("authentication_failed", username=username, reason="user_not_found")
return None
# Verify password
if not self.verify_password(password, self.password_hashes[user.user_id]):
logger.warning("authentication_failed", username=username, reason="invalid_password")
return None
# Update last login
user.last_login = datetime.now(timezone.utc)
logger.info("user_authenticated", user_id=user.user_id, username=username, role=user.role.value)
return user
def create_access_token(self, user: User) -> str:
"""Create a JWT access token for the user."""
expires_delta = timedelta(minutes=self.token_expire_minutes)
expire = datetime.now(timezone.utc) + expires_delta
token_data = TokenData(
user_id=user.user_id,
username=user.username,
role=user.role,
exp=expire
)
encoded_jwt = jwt.encode(
token_data.dict(),
self.secret_key,
algorithm=self.algorithm
)
return encoded_jwt
def verify_token(self, token: str) -> Optional[TokenData]:
"""Verify and decode a JWT token."""
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
token_data = TokenData(**payload)
# Check if token is expired
if token_data.exp < datetime.now(timezone.utc):
logger.warning("token_expired", username=token_data.username)
return None
# Verify user still exists and is active
user = self.users.get(token_data.user_id)
if not user or not user.active:
logger.warning("token_invalid_user", username=token_data.username)
return None
return token_data
except jwt.PyJWTError as e:
logger.warning("token_verification_failed", error=str(e))
return None
class AuthorizationManager:
"""
Manages role-based access control (RBAC) for authorization.
Implements IEC 62443 zone security model and ISO 27001 access control requirements.
"""
def __init__(self):
# Define permissions for each role
self.role_permissions = {
UserRole.READ_ONLY: {
"read_pump_status",
"read_safety_status",
"read_audit_logs"
},
UserRole.OPERATOR: {
"read_pump_status",
"read_safety_status",
"read_audit_logs",
"emergency_stop",
"clear_emergency_stop",
"view_alerts"
},
UserRole.ENGINEER: {
"read_pump_status",
"read_safety_status",
"read_audit_logs",
"emergency_stop",
"clear_emergency_stop",
"view_alerts",
"configure_safety_limits",
"manage_pump_configuration",
"view_system_metrics"
},
UserRole.ADMINISTRATOR: {
"read_pump_status",
"read_safety_status",
"read_audit_logs",
"emergency_stop",
"clear_emergency_stop",
"view_alerts",
"configure_safety_limits",
"manage_pump_configuration",
"view_system_metrics",
"manage_users",
"configure_system",
"access_all_stations"
}
}
def has_permission(self, role: UserRole, permission: str) -> bool:
"""Check if a role has the specified permission."""
permissions = self.role_permissions.get(role, set())
return permission in permissions
def get_allowed_actions(self, role: UserRole) -> List[str]:
"""Get all allowed actions for a role."""
return list(self.role_permissions.get(role, set()))
def can_access_station(self, role: UserRole, station_id: str) -> bool:
"""
Check if user can access a specific station.
Administrators can access all stations, others may have station-specific
permissions (to be implemented with database integration).
"""
if role == UserRole.ADMINISTRATOR:
return True
# For now, all authenticated users can access all stations
# In production, this would check station-specific permissions
return True
class SecurityManager:
"""
Main security manager that coordinates authentication and authorization.
Provides a unified interface for security operations and compliance logging.
"""
def __init__(self, audit_logger=None):
self.auth_manager = AuthenticationManager(
secret_key=settings.jwt_secret_key,
token_expire_minutes=settings.jwt_token_expire_minutes
)
self.authz_manager = AuthorizationManager()
self.audit_logger = audit_logger
def authenticate(self, username: str, password: str) -> Optional[str]:
"""Authenticate user and return JWT token if successful."""
user = self.auth_manager.authenticate_user(username, password)
if user:
token = self.auth_manager.create_access_token(user)
# Log successful authentication
if self.audit_logger:
self.audit_logger.log(
event_type="USER_AUTHENTICATION",
severity="INFO",
user_id=user.user_id,
event_data={
"username": username,
"role": user.role.value,
"result": "SUCCESS"
}
)
return token
# Log failed authentication
if self.audit_logger:
self.audit_logger.log(
event_type="USER_AUTHENTICATION",
severity="WARNING",
event_data={
"username": username,
"result": "FAILURE"
}
)
return None
def verify_access_token(self, token: str) -> Optional[TokenData]:
"""Verify JWT token and return token data if valid."""
return self.auth_manager.verify_token(token)
def check_permission(self, token_data: TokenData, permission: str) -> bool:
"""Check if user has the specified permission."""
return self.authz_manager.has_permission(token_data.role, permission)
def can_access_station(self, token_data: TokenData, station_id: str) -> bool:
"""Check if user can access the specified station."""
return self.authz_manager.can_access_station(token_data.role, station_id)
def get_user_permissions(self, token_data: TokenData) -> List[str]:
"""Get all permissions for the user."""
return self.authz_manager.get_allowed_actions(token_data.role)
# Global security manager instance
security_manager: Optional[SecurityManager] = None
def get_security_manager() -> SecurityManager:
"""Get or create the global security manager instance."""
global security_manager
if security_manager is None:
security_manager = SecurityManager()
return security_manager

304
src/core/tls_manager.py Normal file
View File

@ -0,0 +1,304 @@
"""
TLS/SSL Manager for Calejo Control Adapter.
Provides certificate management and TLS/SSL configuration for secure communications
in compliance with IEC 62443, ISO 27001, and NIS2 Directive requirements.
"""
import os
import ssl
from typing import Optional, Tuple
from pathlib import Path
import structlog
from datetime import datetime, timedelta, timezone
from config.settings import settings
logger = structlog.get_logger()
class TLSManager:
"""
Manages TLS/SSL certificates and secure communications.
Provides certificate validation, rotation, and secure context creation
for all protocol servers.
"""
def __init__(self):
self.cert_path = settings.tls_cert_path
self.key_path = settings.tls_key_path
self.tls_enabled = settings.tls_enabled
# Certificate rotation tracking
self.cert_expiry_dates: dict[str, datetime] = {}
# Validate certificates on initialization
if self.tls_enabled:
self._validate_certificates()
def _validate_certificates(self) -> bool:
"""
Validate TLS certificates exist and are valid.
Returns:
bool: True if certificates are valid, False otherwise
"""
if not self.cert_path or not self.key_path:
logger.warning(
"tls_certificates_missing",
cert_path=self.cert_path,
key_path=self.key_path
)
return False
# Check if certificate files exist
cert_file = Path(self.cert_path)
key_file = Path(self.key_path)
if not cert_file.exists():
logger.error("tls_certificate_file_missing", path=self.cert_path)
return False
if not key_file.exists():
logger.error("tls_key_file_missing", path=self.key_path)
return False
# Validate certificate format and expiry
try:
expiry_date = self._get_certificate_expiry(self.cert_path)
self.cert_expiry_dates[self.cert_path] = expiry_date
# Check if certificate is expired or expiring soon
now = datetime.now(timezone.utc)
days_until_expiry = (expiry_date - now).days
if days_until_expiry < 0:
logger.error("tls_certificate_expired", expiry_date=expiry_date.isoformat())
return False
elif days_until_expiry < 30:
logger.warning(
"tls_certificate_expiring_soon",
expiry_date=expiry_date.isoformat(),
days_remaining=days_until_expiry
)
else:
logger.info(
"tls_certificate_valid",
expiry_date=expiry_date.isoformat(),
days_remaining=days_until_expiry
)
return True
except Exception as e:
logger.error("tls_certificate_validation_failed", error=str(e))
return False
def _get_certificate_expiry(self, cert_path: str) -> datetime:
"""
Get certificate expiry date.
Args:
cert_path: Path to certificate file
Returns:
datetime: Certificate expiry date
Raises:
Exception: If certificate cannot be parsed
"""
import OpenSSL
with open(cert_path, 'rb') as f:
cert_data = f.read()
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_data)
expiry_bytes = cert.get_notAfter()
# Parse ASN.1 time format (YYYYMMDDHHMMSSZ)
expiry_str = expiry_bytes.decode('ascii')
expiry_date = datetime.strptime(expiry_str, '%Y%m%d%H%M%SZ')
# Make timezone aware
return expiry_date.replace(tzinfo=timezone.utc)
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
"""
Create SSL context for secure communications.
Returns:
Optional[ssl.SSLContext]: SSL context if TLS is enabled and valid,
None otherwise
"""
if not self.tls_enabled:
logger.info("tls_disabled")
return None
if not self._validate_certificates():
logger.error("cannot_create_ssl_context_invalid_certificates")
return None
try:
# Create SSL context with secure defaults
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
# Load certificate and key
context.load_cert_chain(self.cert_path, self.key_path)
# Configure secure cipher suites
context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS')
# Enable certificate verification
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
# Set minimum TLS version (TLS 1.2 or higher)
context.minimum_version = ssl.TLSVersion.TLSv1_2
logger.info("ssl_context_created_successfully")
return context
except Exception as e:
logger.error("failed_to_create_ssl_context", error=str(e))
return None
def get_rest_api_ssl_config(self) -> Optional[Tuple[str, str]]:
"""
Get SSL configuration for REST API server.
Returns:
Optional[Tuple[str, str]]: Tuple of (certfile, keyfile) paths if TLS is enabled,
None otherwise
"""
if not self.tls_enabled:
return None
if not self._validate_certificates():
return None
return (self.cert_path, self.key_path)
def check_certificate_rotation(self) -> bool:
"""
Check if certificates need rotation.
Returns:
bool: True if certificates need rotation, False otherwise
"""
if not self.tls_enabled:
return False
for cert_path, expiry_date in self.cert_expiry_dates.items():
now = datetime.now(timezone.utc)
days_until_expiry = (expiry_date - now).days
# Rotate if certificate expires in less than 7 days
if days_until_expiry < 7:
logger.warning(
"certificate_needs_rotation",
cert_path=cert_path,
expiry_date=expiry_date.isoformat(),
days_remaining=days_until_expiry
)
return True
return False
def generate_self_signed_certificate(self, output_dir: str, common_name: str = "calejo-control.local") -> bool:
"""
Generate self-signed certificate for development/testing.
Args:
output_dir: Directory to save certificate files
common_name: Common name for the certificate
Returns:
bool: True if certificate generation succeeded, False otherwise
"""
try:
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Generate private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
# Generate certificate
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Calejo Control"),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Development"),
])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName("localhost"),
x509.DNSName("127.0.0.1"),
]),
critical=False,
)
.sign(private_key, hashes.SHA256(), default_backend())
)
# Write private key
key_path = os.path.join(output_dir, "key.pem")
with open(key_path, "wb") as f:
f.write(
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
)
# Write certificate
cert_path = os.path.join(output_dir, "cert.pem")
with open(cert_path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
logger.info(
"self_signed_certificate_generated",
cert_path=cert_path,
key_path=key_path,
common_name=common_name
)
# Update settings
self.cert_path = cert_path
self.key_path = key_path
return True
except Exception as e:
logger.error("failed_to_generate_self_signed_certificate", error=str(e))
return False
# Global TLS manager instance
tls_manager: Optional[TLSManager] = None
def get_tls_manager() -> TLSManager:
"""Get or create the global TLS manager instance."""
global tls_manager
if tls_manager is None:
tls_manager = TLSManager()
return tls_manager

View File

@ -7,12 +7,17 @@ Provides REST endpoints for emergency stop, status monitoring, and setpoint acce
from typing import Optional, Dict, Any
from datetime import datetime
import structlog
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi import FastAPI, HTTPException, status, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from src.core.setpoint_manager import SetpointManager
from src.core.emergency_stop import EmergencyStopManager
from src.core.security import (
SecurityManager, TokenData, UserRole, get_security_manager
)
from src.core.tls_manager import get_tls_manager
logger = structlog.get_logger()
@ -20,6 +25,64 @@ logger = structlog.get_logger()
security = HTTPBearer()
class LoginRequest(BaseModel):
"""Request model for user login."""
username: str
password: str
class LoginResponse(BaseModel):
"""Response model for successful login."""
access_token: str
token_type: str = "bearer"
expires_in: int
user_id: str
username: str
role: str
permissions: list[str]
class UserInfoResponse(BaseModel):
"""Response model for user information."""
user_id: str
username: str
email: str
role: str
permissions: list[str]
last_login: Optional[str]
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
security_manager: SecurityManager = Depends(get_security_manager)
) -> TokenData:
"""Dependency to get current user from JWT token."""
token = credentials.credentials
token_data = security_manager.verify_access_token(token)
if not token_data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return token_data
def require_permission(permission: str):
"""Dependency factory to require specific permission."""
def permission_dependency(
token_data: TokenData = Depends(get_current_user),
security_manager: SecurityManager = Depends(get_security_manager)
):
if not security_manager.check_permission(token_data, permission):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Insufficient permissions. Required: {permission}"
)
return token_data
return permission_dependency
class EmergencyStopRequest(BaseModel):
"""Request model for emergency stop."""
triggered_by: str
@ -63,7 +126,18 @@ class RESTAPIServer:
self.app = FastAPI(
title="Calejo Control API",
version="2.0",
description="REST API for Calejo Control Adapter with safety framework"
description="REST API for Calejo Control Adapter with safety framework",
docs_url="/docs",
redoc_url="/redoc"
)
# Add CORS middleware
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, restrict to specific origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
self._setup_routes()
@ -77,7 +151,8 @@ class RESTAPIServer:
return {
"name": "Calejo Control API",
"version": "2.0",
"status": "operational"
"status": "operational",
"authentication_required": True
}
@self.app.get("/health", summary="Health Check", tags=["General"])
@ -88,6 +163,69 @@ class RESTAPIServer:
"timestamp": datetime.now().isoformat()
}
# Authentication endpoints (no authentication required)
@self.app.post(
"/api/v1/auth/login",
summary="User Login",
tags=["Authentication"],
response_model=LoginResponse
)
async def login(
request: LoginRequest,
security_manager: SecurityManager = Depends(get_security_manager)
):
"""Authenticate user and return JWT token."""
token = security_manager.authenticate(request.username, request.password)
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password"
)
# Get user permissions
token_data = security_manager.verify_access_token(token)
permissions = security_manager.get_user_permissions(token_data)
return LoginResponse(
access_token=token,
token_type="bearer",
expires_in=60, # 60 minutes
user_id=token_data.user_id,
username=token_data.username,
role=token_data.role.value,
permissions=permissions
)
@self.app.get(
"/api/v1/auth/me",
summary="Get Current User Info",
tags=["Authentication"],
response_model=UserInfoResponse
)
async def get_current_user_info(
token_data: TokenData = Depends(get_current_user),
security_manager: SecurityManager = Depends(get_security_manager)
):
"""Get information about the currently authenticated user."""
# Get user from auth manager
user = security_manager.auth_manager.users.get(token_data.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
permissions = security_manager.get_user_permissions(token_data)
return UserInfoResponse(
user_id=user.user_id,
username=user.username,
email=user.email,
role=user.role.value,
permissions=permissions,
last_login=user.last_login.isoformat() if user.last_login else None
)
@self.app.get(
"/api/v1/setpoints",
summary="Get All Setpoints",
@ -95,12 +233,14 @@ class RESTAPIServer:
response_model=Dict[str, Dict[str, Optional[float]]]
)
async def get_all_setpoints(
credentials: HTTPAuthorizationCredentials = Depends(security)
token_data: TokenData = Depends(require_permission("read_pump_status"))
):
"""
Get current setpoints for all pumps.
Returns dictionary mapping station_id -> pump_id -> setpoint_hz
Requires permission: read_pump_status
"""
try:
setpoints = self.setpoint_manager.get_all_current_setpoints()
@ -121,10 +261,22 @@ class RESTAPIServer:
async def get_pump_setpoint(
station_id: str,
pump_id: str,
credentials: HTTPAuthorizationCredentials = Depends(security)
token_data: TokenData = Depends(require_permission("read_pump_status")),
security_manager: SecurityManager = Depends(get_security_manager)
):
"""Get current setpoint for a specific pump."""
"""
Get current setpoint for a specific pump.
Requires permission: read_pump_status
"""
try:
# Check if user can access this station
if not security_manager.can_access_station(token_data, station_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to station {station_id}"
)
setpoint = self.setpoint_manager.get_current_setpoint(station_id, pump_id)
# Get pump info for control type
@ -147,6 +299,8 @@ class RESTAPIServer:
timestamp=datetime.now().isoformat()
)
except HTTPException:
raise
except Exception as e:
logger.error(
"failed_to_get_pump_setpoint",
@ -167,7 +321,8 @@ class RESTAPIServer:
)
async def trigger_emergency_stop(
request: EmergencyStopRequest,
credentials: HTTPAuthorizationCredentials = Depends(security)
token_data: TokenData = Depends(require_permission("emergency_stop")),
security_manager: SecurityManager = Depends(get_security_manager)
):
"""
Trigger emergency stop ("big red button").
@ -176,15 +331,27 @@ class RESTAPIServer:
- If station_id and pump_id provided: Stop single pump
- If station_id only: Stop all pumps at station
- If neither: Stop ALL pumps system-wide
Requires permission: emergency_stop
"""
try:
# Check station access if station_id is provided
if request.station_id and not security_manager.can_access_station(token_data, request.station_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to station {request.station_id}"
)
# Use authenticated user as triggered_by
triggered_by = token_data.username
if request.station_id and request.pump_id:
# Single pump stop
result = self.emergency_stop_manager.emergency_stop_pump(
station_id=request.station_id,
pump_id=request.pump_id,
reason=request.reason,
user_id=request.triggered_by
user_id=triggered_by
)
scope = f"pump {request.station_id}/{request.pump_id}"
elif request.station_id:
@ -192,14 +359,14 @@ class RESTAPIServer:
result = self.emergency_stop_manager.emergency_stop_station(
station_id=request.station_id,
reason=request.reason,
user_id=request.triggered_by
user_id=triggered_by
)
scope = f"station {request.station_id}"
else:
# System-wide stop
result = self.emergency_stop_manager.emergency_stop_system(
reason=request.reason,
user_id=request.triggered_by
user_id=triggered_by
)
scope = "system-wide"
@ -208,7 +375,7 @@ class RESTAPIServer:
"status": "emergency_stop_triggered",
"scope": scope,
"reason": request.reason,
"triggered_by": request.triggered_by,
"triggered_by": triggered_by,
"timestamp": datetime.now().isoformat()
}
else:
@ -217,6 +384,8 @@ class RESTAPIServer:
detail="Failed to trigger emergency stop"
)
except HTTPException:
raise
except Exception as e:
logger.error("failed_to_trigger_emergency_stop", error=str(e))
raise HTTPException(
@ -231,19 +400,26 @@ class RESTAPIServer:
)
async def clear_emergency_stop(
request: EmergencyStopClearRequest,
credentials: HTTPAuthorizationCredentials = Depends(security)
token_data: TokenData = Depends(require_permission("clear_emergency_stop"))
):
"""Clear all active emergency stops."""
"""
Clear all active emergency stops.
Requires permission: clear_emergency_stop
"""
try:
# Use authenticated user as cleared_by
cleared_by = token_data.username
# Clear system-wide emergency stop
self.emergency_stop_manager.clear_emergency_stop_system(
reason=request.notes,
user_id=request.cleared_by
user_id=cleared_by
)
return {
"status": "emergency_stop_cleared",
"cleared_by": request.cleared_by,
"cleared_by": cleared_by,
"notes": request.notes,
"timestamp": datetime.now().isoformat()
}
@ -261,9 +437,13 @@ class RESTAPIServer:
tags=["Emergency Stop"]
)
async def get_emergency_stop_status(
credentials: HTTPAuthorizationCredentials = Depends(security)
token_data: TokenData = Depends(require_permission("read_safety_status"))
):
"""Check if any emergency stops are active."""
"""
Check if any emergency stops are active.
Requires permission: read_safety_status
"""
try:
# Check system-wide emergency stop
system_stop = self.emergency_stop_manager.system_emergency_stop
@ -291,18 +471,33 @@ class RESTAPIServer:
"""Start the REST API server."""
import uvicorn
# Get TLS configuration
tls_manager = get_tls_manager()
ssl_config = tls_manager.get_rest_api_ssl_config()
logger.info(
"rest_api_server_starting",
host=self.host,
port=self.port
port=self.port,
tls_enabled=ssl_config is not None
)
config = uvicorn.Config(
self.app,
host=self.host,
port=self.port,
log_level="info"
)
config_kwargs = {
"app": self.app,
"host": self.host,
"port": self.port,
"log_level": "info"
}
# Add SSL configuration if available
if ssl_config:
certfile, keyfile = ssl_config
config_kwargs.update({
"ssl_certfile": certfile,
"ssl_keyfile": keyfile
})
config = uvicorn.Config(**config_kwargs)
server = uvicorn.Server(config)
await server.serve()

View File

@ -0,0 +1,368 @@
"""
Unit tests for compliance audit components.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timezone
from src.core.compliance_audit import (
ComplianceAuditLogger, AuditEventType, AuditSeverity
)
from config.settings import settings
class TestComplianceAuditLogger:
"""Test cases for ComplianceAuditLogger."""
def setup_method(self):
"""Set up test fixtures."""
self.mock_db_client = Mock()
self.audit_logger = ComplianceAuditLogger(self.mock_db_client)
# Mock settings
self.original_audit_enabled = settings.audit_log_enabled
settings.audit_log_enabled = True
def teardown_method(self):
"""Clean up test fixtures."""
# Restore original settings
settings.audit_log_enabled = self.original_audit_enabled
def test_initialization(self):
"""Test initialization of ComplianceAuditLogger."""
assert self.audit_logger.db_client == self.mock_db_client
assert self.audit_logger.logger is not None
def test_log_compliance_event_success(self):
"""Test successful logging of compliance event."""
event_type = AuditEventType.USER_LOGIN
severity = AuditSeverity.LOW
user_id = "test_user"
ip_address = "192.168.1.100"
with patch.object(self.audit_logger.logger, 'info') as mock_log:
self.audit_logger.log_compliance_event(
event_type=event_type,
severity=severity,
user_id=user_id,
ip_address=ip_address,
action="login",
resource="system",
result="success"
)
# Verify structured logging
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == "user_login"
assert call_args["severity"] == "low"
assert call_args["user_id"] == "test_user"
assert call_args["ip_address"] == "192.168.1.100"
# Verify database logging
self.mock_db_client.execute.assert_called_once()
def test_log_compliance_event_database_disabled(self):
"""Test logging when database audit is disabled."""
settings.audit_log_enabled = False
with patch.object(self.audit_logger.logger, 'info') as mock_log:
self.audit_logger.log_compliance_event(
event_type=AuditEventType.USER_LOGIN,
severity=AuditSeverity.LOW,
user_id="test_user"
)
# Verify structured logging still occurs
mock_log.assert_called_once()
# Verify database logging is skipped
self.mock_db_client.execute.assert_not_called()
def test_log_compliance_event_database_error(self):
"""Test logging when database operation fails."""
# Mock database error
self.mock_db_client.execute.side_effect = Exception("Database error")
with patch.object(self.audit_logger.logger, 'error') as mock_error_log:
self.audit_logger.log_compliance_event(
event_type=AuditEventType.USER_LOGIN,
severity=AuditSeverity.LOW,
user_id="test_user"
)
# Verify error logging
mock_error_log.assert_called_once()
call_args = mock_error_log.call_args[1]
assert "Database error" in call_args["error"]
assert call_args["event_type"] == "user_login"
def test_log_user_authentication_success(self):
"""Test logging successful user authentication."""
user_id = "test_user"
ip_address = "192.168.1.100"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_user_authentication(
user_id=user_id,
result="success",
ip_address=ip_address
)
# Verify the correct event type and severity
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.USER_LOGIN
assert call_args["severity"] == AuditSeverity.LOW
assert call_args["user_id"] == user_id
assert call_args["ip_address"] == ip_address
assert call_args["result"] == "success"
def test_log_user_authentication_failure(self):
"""Test logging failed user authentication."""
user_id = "test_user"
ip_address = "192.168.1.100"
reason = "Invalid credentials"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_user_authentication(
user_id=user_id,
result="failure",
ip_address=ip_address,
reason=reason
)
# Verify the correct event type and severity
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.INVALID_AUTHENTICATION
assert call_args["severity"] == AuditSeverity.HIGH
assert call_args["reason"] == reason
def test_log_access_control_granted(self):
"""Test logging granted access control."""
user_id = "test_user"
action = "read"
resource = "station_data"
ip_address = "192.168.1.100"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_access_control(
user_id=user_id,
action=action,
resource=resource,
result="granted",
ip_address=ip_address
)
# Verify the correct event type
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.DATA_READ
assert call_args["severity"] == AuditSeverity.MEDIUM
def test_log_access_control_denied(self):
"""Test logging denied access control."""
user_id = "test_user"
action = "write"
resource = "system_config"
ip_address = "192.168.1.100"
reason = "Insufficient permissions"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_access_control(
user_id=user_id,
action=action,
resource=resource,
result="denied",
ip_address=ip_address,
reason=reason
)
# Verify the correct event type and severity
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.ACCESS_DENIED
assert call_args["severity"] == AuditSeverity.HIGH
assert call_args["reason"] == reason
def test_log_control_operation_setpoint_change(self):
"""Test logging setpoint change operation."""
user_id = "operator_user"
station_id = "station_001"
action = "setpoint_change"
resource = "pump_speed"
ip_address = "192.168.1.100"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_control_operation(
user_id=user_id,
station_id=station_id,
pump_id=None,
action=action,
resource=resource,
result="success",
ip_address=ip_address
)
# Verify the correct event type and severity
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.SETPOINT_CHANGED
assert call_args["severity"] == AuditSeverity.HIGH
assert call_args["station_id"] == station_id
def test_log_control_operation_emergency_stop(self):
"""Test logging emergency stop operation."""
user_id = "operator_user"
station_id = "station_001"
action = "emergency_stop"
resource = "system"
ip_address = "192.168.1.100"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_control_operation(
user_id=user_id,
station_id=station_id,
pump_id="pump_001",
action=action,
resource=resource,
result="success",
ip_address=ip_address
)
# Verify the correct event type and severity
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.EMERGENCY_STOP_ACTIVATED
assert call_args["severity"] == AuditSeverity.CRITICAL
assert call_args["pump_id"] == "pump_001"
def test_log_security_event(self):
"""Test logging security events."""
event_type = AuditEventType.CERTIFICATE_EXPIRED
severity = AuditSeverity.HIGH
user_id = "system"
station_id = "station_001"
reason = "TLS certificate expired"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
self.audit_logger.log_security_event(
event_type=event_type,
severity=severity,
user_id=user_id,
station_id=station_id,
reason=reason
)
# Verify the correct parameters
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == event_type
assert call_args["severity"] == severity
assert call_args["user_id"] == user_id
assert call_args["station_id"] == station_id
assert call_args["reason"] == reason
def test_get_audit_trail_success(self):
"""Test successful retrieval of audit trail."""
# Mock database result
mock_result = [
{"event_type": "user_login", "user_id": "test_user"},
{"event_type": "data_read", "user_id": "test_user"}
]
self.mock_db_client.fetch_all.return_value = mock_result
start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
end_time = datetime(2024, 1, 31, tzinfo=timezone.utc)
user_id = "test_user"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
result = self.audit_logger.get_audit_trail(
start_time=start_time,
end_time=end_time,
user_id=user_id
)
# Verify database query
self.mock_db_client.fetch_all.assert_called_once()
# Verify audit trail access logging
mock_log.assert_called_once()
call_args = mock_log.call_args[1]
assert call_args["event_type"] == AuditEventType.AUDIT_LOG_ACCESSED
assert call_args["user_id"] == user_id
# Verify result
assert result == mock_result
def test_get_audit_trail_error(self):
"""Test audit trail retrieval with error."""
# Mock database error
self.mock_db_client.fetch_all.side_effect = Exception("Query failed")
with patch.object(self.audit_logger.logger, 'error') as mock_error_log:
result = self.audit_logger.get_audit_trail()
# Verify error logging
mock_error_log.assert_called_once()
# Verify empty result
assert result == []
def test_generate_compliance_report_success(self):
"""Test successful compliance report generation."""
# Mock database result
mock_result = [
{"event_type": "user_login", "severity": "low", "count": 10, "unique_users": 2},
{"event_type": "access_denied", "severity": "high", "count": 2, "unique_users": 1}
]
self.mock_db_client.fetch_all.return_value = mock_result
start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
end_time = datetime(2024, 1, 31, tzinfo=timezone.utc)
compliance_standard = "IEC_62443"
with patch.object(self.audit_logger, 'log_compliance_event') as mock_log:
report = self.audit_logger.generate_compliance_report(
start_time=start_time,
end_time=end_time,
compliance_standard=compliance_standard
)
# Verify database query
self.mock_db_client.fetch_all.assert_called_once()
# Verify compliance report logging
mock_log.assert_called_once()
# Verify report structure
assert report["compliance_standard"] == compliance_standard
assert report["summary"]["total_events"] == 12
assert report["summary"]["unique_users"] == 3
assert report["summary"]["event_types"] == 2
assert report["events_by_type"] == mock_result
def test_generate_compliance_report_error(self):
"""Test compliance report generation with error."""
# Mock database error
self.mock_db_client.fetch_all.side_effect = Exception("Report generation failed")
start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
end_time = datetime(2024, 1, 31, tzinfo=timezone.utc)
compliance_standard = "ISO_27001"
with patch.object(self.audit_logger.logger, 'error') as mock_error_log:
report = self.audit_logger.generate_compliance_report(
start_time=start_time,
end_time=end_time,
compliance_standard=compliance_standard
)
# Verify error logging
mock_error_log.assert_called_once()
# Verify error result
assert "error" in report
assert "Report generation failed" in report["error"]

373
tests/unit/test_security.py Normal file
View File

@ -0,0 +1,373 @@
"""
Unit tests for security layer components.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
import jwt
import bcrypt
from src.core.security import (
AuthenticationManager,
AuthorizationManager,
SecurityManager,
UserRole,
User,
TokenData
)
from config.settings import settings
class TestAuthenticationManager:
"""Test cases for AuthenticationManager."""
def setup_method(self):
"""Set up test fixtures."""
self.auth_manager = AuthenticationManager(
secret_key="test-secret-key",
token_expire_minutes=60
)
def test_hash_password(self):
"""Test password hashing."""
password = "testpassword123"
hashed = self.auth_manager.hash_password(password)
# Verify the hash is valid
assert bcrypt.checkpw(password.encode('utf-8'), hashed.encode('utf-8'))
# Verify different passwords produce different hashes
different_password = "differentpassword"
different_hashed = self.auth_manager.hash_password(different_password)
assert hashed != different_hashed
def test_verify_password(self):
"""Test password verification."""
password = "testpassword123"
hashed = self.auth_manager.hash_password(password)
# Test correct password
assert self.auth_manager.verify_password(password, hashed)
# Test incorrect password
assert not self.auth_manager.verify_password("wrongpassword", hashed)
def test_create_user(self):
"""Test user creation."""
user = self.auth_manager.create_user(
user_id="test_user_001",
username="testuser",
email="test@example.com",
role=UserRole.OPERATOR,
password="testpassword123"
)
assert user.user_id == "test_user_001"
assert user.username == "testuser"
assert user.email == "test@example.com"
assert user.role == UserRole.OPERATOR
assert user.active is True
assert user.created_at is not None
# Verify password was hashed and stored
assert "test_user_001" in self.auth_manager.password_hashes
hashed_password = self.auth_manager.password_hashes["test_user_001"]
assert bcrypt.checkpw("testpassword123".encode('utf-8'), hashed_password.encode('utf-8'))
def test_create_user_duplicate_id(self):
"""Test creating user with duplicate ID."""
self.auth_manager.create_user(
user_id="duplicate_user",
username="user1",
email="user1@example.com",
role=UserRole.OPERATOR,
password="password1"
)
with pytest.raises(ValueError, match="User with ID duplicate_user already exists"):
self.auth_manager.create_user(
user_id="duplicate_user",
username="user2",
email="user2@example.com",
role=UserRole.ENGINEER,
password="password2"
)
def test_authenticate_user_success(self):
"""Test successful user authentication."""
# Create a test user
self.auth_manager.create_user(
user_id="auth_test_user",
username="authuser",
email="auth@example.com",
role=UserRole.ENGINEER,
password="authpassword123"
)
# Test authentication
user = self.auth_manager.authenticate_user("authuser", "authpassword123")
assert user is not None
assert user.user_id == "auth_test_user"
assert user.username == "authuser"
assert user.role == UserRole.ENGINEER
assert user.last_login is not None
def test_authenticate_user_wrong_password(self):
"""Test authentication with wrong password."""
self.auth_manager.create_user(
user_id="wrong_pass_user",
username="wrongpass",
email="wrong@example.com",
role=UserRole.OPERATOR,
password="correctpassword"
)
user = self.auth_manager.authenticate_user("wrongpass", "wrongpassword")
assert user is None
def test_authenticate_user_not_found(self):
"""Test authentication with non-existent user."""
user = self.auth_manager.authenticate_user("nonexistent", "password")
assert user is None
def test_authenticate_user_inactive(self):
"""Test authentication with inactive user."""
# Create inactive user
user = User(
user_id="inactive_user",
username="inactive",
email="inactive@example.com",
role=UserRole.OPERATOR,
active=False,
created_at=datetime.now(timezone.utc)
)
self.auth_manager.users["inactive_user"] = user
self.auth_manager.password_hashes["inactive_user"] = self.auth_manager.hash_password("password")
result = self.auth_manager.authenticate_user("inactive", "password")
assert result is None
def test_create_access_token(self):
"""Test JWT token creation."""
user = User(
user_id="token_user",
username="tokenuser",
email="token@example.com",
role=UserRole.ADMINISTRATOR,
created_at=datetime.now(timezone.utc)
)
token = self.auth_manager.create_access_token(user)
# Verify token can be decoded
payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"])
assert payload["user_id"] == "token_user"
assert payload["username"] == "tokenuser"
assert payload["role"] == "administrator"
assert "exp" in payload
def test_verify_token_success(self):
"""Test successful token verification."""
# Create user and token
user = User(
user_id="verify_user",
username="verifyuser",
email="verify@example.com",
role=UserRole.OPERATOR,
created_at=datetime.now(timezone.utc)
)
self.auth_manager.users["verify_user"] = user
token = self.auth_manager.create_access_token(user)
token_data = self.auth_manager.verify_token(token)
assert token_data is not None
assert token_data.user_id == "verify_user"
assert token_data.username == "verifyuser"
assert token_data.role == UserRole.OPERATOR
def test_verify_token_expired(self):
"""Test verification of expired token."""
# Create expired token
expired_time = datetime.now(timezone.utc) - timedelta(hours=1)
token_data = TokenData(
user_id="expired_user",
username="expireduser",
role=UserRole.OPERATOR,
exp=expired_time
)
token = jwt.encode(
token_data.dict(),
"test-secret-key",
algorithm="HS256"
)
result = self.auth_manager.verify_token(token)
assert result is None
def test_verify_token_invalid_signature(self):
"""Test verification of token with invalid signature."""
# Create token with wrong secret
user = User(
user_id="invalid_user",
username="invaliduser",
email="invalid@example.com",
role=UserRole.OPERATOR,
created_at=datetime.now(timezone.utc)
)
token = jwt.encode(
{
"user_id": "invalid_user",
"username": "invaliduser",
"role": "operator",
"exp": datetime.now(timezone.utc) + timedelta(minutes=60)
},
"wrong-secret-key",
algorithm="HS256"
)
result = self.auth_manager.verify_token(token)
assert result is None
class TestAuthorizationManager:
"""Test cases for AuthorizationManager."""
def setup_method(self):
"""Set up test fixtures."""
self.authz_manager = AuthorizationManager()
def test_has_permission_read_only(self):
"""Test permissions for read-only role."""
assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_pump_status")
assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_safety_status")
assert self.authz_manager.has_permission(UserRole.READ_ONLY, "read_audit_logs")
# Should not have write permissions
assert not self.authz_manager.has_permission(UserRole.READ_ONLY, "emergency_stop")
assert not self.authz_manager.has_permission(UserRole.READ_ONLY, "configure_safety_limits")
def test_has_permission_operator(self):
"""Test permissions for operator role."""
assert self.authz_manager.has_permission(UserRole.OPERATOR, "read_pump_status")
assert self.authz_manager.has_permission(UserRole.OPERATOR, "emergency_stop")
assert self.authz_manager.has_permission(UserRole.OPERATOR, "clear_emergency_stop")
# Should not have configuration permissions
assert not self.authz_manager.has_permission(UserRole.OPERATOR, "configure_safety_limits")
assert not self.authz_manager.has_permission(UserRole.OPERATOR, "manage_users")
def test_has_permission_engineer(self):
"""Test permissions for engineer role."""
assert self.authz_manager.has_permission(UserRole.ENGINEER, "read_pump_status")
assert self.authz_manager.has_permission(UserRole.ENGINEER, "emergency_stop")
assert self.authz_manager.has_permission(UserRole.ENGINEER, "configure_safety_limits")
assert self.authz_manager.has_permission(UserRole.ENGINEER, "manage_pump_configuration")
# Should not have administrative permissions
assert not self.authz_manager.has_permission(UserRole.ENGINEER, "manage_users")
def test_has_permission_administrator(self):
"""Test permissions for administrator role."""
# Administrator should have all permissions
all_permissions = [
"read_pump_status", "read_safety_status", "read_audit_logs",
"emergency_stop", "clear_emergency_stop", "view_alerts",
"configure_safety_limits", "manage_pump_configuration",
"view_system_metrics", "manage_users", "configure_system",
"access_all_stations"
]
for permission in all_permissions:
assert self.authz_manager.has_permission(UserRole.ADMINISTRATOR, permission)
def test_has_permission_unknown_permission(self):
"""Test unknown permission."""
assert not self.authz_manager.has_permission(UserRole.ADMINISTRATOR, "unknown_permission")
def test_get_allowed_actions(self):
"""Test getting all allowed actions for a role."""
operator_permissions = self.authz_manager.get_allowed_actions(UserRole.OPERATOR)
assert "read_pump_status" in operator_permissions
assert "emergency_stop" in operator_permissions
assert "clear_emergency_stop" in operator_permissions
assert "configure_safety_limits" not in operator_permissions
def test_can_access_station(self):
"""Test station access control."""
# Administrators can access all stations
assert self.authz_manager.can_access_station(UserRole.ADMINISTRATOR, "STATION_001")
# Other roles can access all stations (for now)
assert self.authz_manager.can_access_station(UserRole.OPERATOR, "STATION_001")
assert self.authz_manager.can_access_station(UserRole.ENGINEER, "STATION_001")
assert self.authz_manager.can_access_station(UserRole.READ_ONLY, "STATION_001")
class TestSecurityManager:
"""Test cases for SecurityManager."""
def setup_method(self):
"""Set up test fixtures."""
self.security_manager = SecurityManager()
def test_authenticate_success(self):
"""Test successful authentication."""
# Default users are created during initialization
token = self.security_manager.authenticate("operator", "operator123")
assert token is not None
# Verify token is valid
token_data = self.security_manager.verify_access_token(token)
assert token_data is not None
assert token_data.username == "operator"
assert token_data.role == UserRole.OPERATOR
def test_authenticate_failure(self):
"""Test failed authentication."""
token = self.security_manager.authenticate("nonexistent", "password")
assert token is None
def test_check_permission(self):
"""Test permission checking."""
# Create token for operator
token = self.security_manager.authenticate("operator", "operator123")
token_data = self.security_manager.verify_access_token(token)
# Operator should have emergency_stop permission
assert self.security_manager.check_permission(token_data, "emergency_stop")
# Operator should not have manage_users permission
assert not self.security_manager.check_permission(token_data, "manage_users")
def test_get_user_permissions(self):
"""Test getting user permissions."""
# Create token for engineer
token = self.security_manager.authenticate("engineer", "engineer123")
token_data = self.security_manager.verify_access_token(token)
permissions = self.security_manager.get_user_permissions(token_data)
# Engineer should have specific permissions
assert "configure_safety_limits" in permissions
assert "manage_pump_configuration" in permissions
assert "emergency_stop" in permissions
# Engineer should not have administrative permissions
assert "manage_users" not in permissions
def test_can_access_station(self):
"""Test station access control."""
# Create token for operator
token = self.security_manager.authenticate("operator", "operator123")
token_data = self.security_manager.verify_access_token(token)
# Operator should be able to access any station
assert self.security_manager.can_access_station(token_data, "STATION_001")
assert self.security_manager.can_access_station(token_data, "STATION_002")

View File

@ -0,0 +1,243 @@
"""
Unit tests for TLS manager components.
"""
import pytest
import unittest.mock
from unittest.mock import Mock, patch, MagicMock
import tempfile
import os
import ssl
from pathlib import Path
from src.core.tls_manager import TLSManager
from config.settings import settings
class TestTLSManager:
"""Test cases for TLSManager."""
def setup_method(self):
"""Set up test fixtures."""
# Create temporary directory for test certificates
self.temp_dir = tempfile.mkdtemp()
self.cert_path = os.path.join(self.temp_dir, "test_cert.pem")
self.key_path = os.path.join(self.temp_dir, "test_key.pem")
# Create dummy certificate files for testing
with open(self.cert_path, 'w') as f:
f.write("DUMMY CERTIFICATE")
with open(self.key_path, 'w') as f:
f.write("DUMMY PRIVATE KEY")
# Mock settings
self.original_tls_enabled = settings.tls_enabled
self.original_cert_path = settings.tls_cert_path
self.original_key_path = settings.tls_key_path
settings.tls_enabled = True
settings.tls_cert_path = self.cert_path
settings.tls_key_path = self.key_path
def teardown_method(self):
"""Clean up test fixtures."""
# Restore original settings
settings.tls_enabled = self.original_tls_enabled
settings.tls_cert_path = self.original_cert_path
settings.tls_key_path = self.original_key_path
# Clean up temporary directory
import shutil
shutil.rmtree(self.temp_dir)
def test_initialization_with_valid_certificates(self):
"""Test initialization with valid certificate files."""
tls_manager = TLSManager()
assert tls_manager.tls_enabled is True
assert tls_manager.cert_path == self.cert_path
assert tls_manager.key_path == self.key_path
def test_initialization_with_missing_certificates(self):
"""Test initialization with missing certificate files."""
# Set invalid paths
settings.tls_cert_path = "/nonexistent/cert.pem"
settings.tls_key_path = "/nonexistent/key.pem"
tls_manager = TLSManager()
# Should still initialize but with validation failure
assert tls_manager.tls_enabled is True
assert tls_manager.cert_path == "/nonexistent/cert.pem"
def test_initialization_with_tls_disabled(self):
"""Test initialization with TLS disabled."""
settings.tls_enabled = False
tls_manager = TLSManager()
assert tls_manager.tls_enabled is False
@patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
def test_validate_certificates_success(self, mock_get_expiry):
"""Test successful certificate validation."""
# Mock certificate expiry to be in the future
from datetime import datetime, timezone
future_date = datetime(2030, 12, 31, 23, 59, 59, tzinfo=timezone.utc)
mock_get_expiry.return_value = future_date
tls_manager = TLSManager()
result = tls_manager._validate_certificates()
assert result is True
@patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
def test_validate_certificates_expired(self, mock_get_expiry):
"""Test validation of expired certificate."""
# Mock certificate expiry to be in the past
from datetime import datetime, timezone
past_date = datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
mock_get_expiry.return_value = past_date
tls_manager = TLSManager()
result = tls_manager._validate_certificates()
assert result is False
@patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
def test_validate_certificates_parsing_error(self, mock_get_expiry):
"""Test certificate validation with parsing error."""
# Mock certificate parsing to raise an exception
mock_get_expiry.side_effect = Exception("Parse error")
tls_manager = TLSManager()
result = tls_manager._validate_certificates()
assert result is False
@patch('src.core.tls_manager.ssl')
def test_create_ssl_context_success(self, mock_ssl):
"""Test successful SSL context creation."""
# Mock SSL context
mock_context = Mock()
mock_ssl.create_default_context.return_value = mock_context
tls_manager = TLSManager()
# Mock validation to succeed
with patch.object(tls_manager, '_validate_certificates', return_value=True):
result = tls_manager.create_ssl_context()
assert result is mock_context
# Check that create_default_context was called with CLIENT_AUTH purpose
mock_ssl.create_default_context.assert_called_once()
call_args = mock_ssl.create_default_context.call_args
assert call_args[0][0] == mock_ssl.Purpose.CLIENT_AUTH
mock_context.load_cert_chain.assert_called_once_with(self.cert_path, self.key_path)
def test_create_ssl_context_tls_disabled(self):
"""Test SSL context creation with TLS disabled."""
settings.tls_enabled = False
tls_manager = TLSManager()
result = tls_manager.create_ssl_context()
assert result is None
def test_create_ssl_context_invalid_certificates(self):
"""Test SSL context creation with invalid certificates."""
tls_manager = TLSManager()
# Mock validation to fail
with patch.object(tls_manager, '_validate_certificates', return_value=False):
result = tls_manager.create_ssl_context()
assert result is None
def test_get_rest_api_ssl_config_success(self):
"""Test successful REST API SSL configuration."""
tls_manager = TLSManager()
# Mock validation to succeed
with patch.object(tls_manager, '_validate_certificates', return_value=True):
result = tls_manager.get_rest_api_ssl_config()
assert result == (self.cert_path, self.key_path)
def test_get_rest_api_ssl_config_tls_disabled(self):
"""Test REST API SSL configuration with TLS disabled."""
settings.tls_enabled = False
tls_manager = TLSManager()
result = tls_manager.get_rest_api_ssl_config()
assert result is None
def test_get_rest_api_ssl_config_invalid_certificates(self):
"""Test REST API SSL configuration with invalid certificates."""
tls_manager = TLSManager()
# Mock validation to fail
with patch.object(tls_manager, '_validate_certificates', return_value=False):
result = tls_manager.get_rest_api_ssl_config()
assert result is None
def test_check_certificate_rotation_needed(self):
"""Test certificate rotation check when needed."""
tls_manager = TLSManager()
# Mock certificate with near expiry
from datetime import datetime, timezone, timedelta
near_expiry = datetime.now(timezone.utc) + timedelta(days=5) # Expires in 5 days
tls_manager.cert_expiry_dates[self.cert_path] = near_expiry
result = tls_manager.check_certificate_rotation()
assert result is True
def test_check_certificate_rotation_not_needed(self):
"""Test certificate rotation check when not needed."""
tls_manager = TLSManager()
# Mock certificate with distant expiry
from datetime import datetime, timezone, timedelta
distant_expiry = datetime.now(timezone.utc) + timedelta(days=365) # Expires in 1 year
tls_manager.cert_expiry_dates[self.cert_path] = distant_expiry
result = tls_manager.check_certificate_rotation()
assert result is False
def test_check_certificate_rotation_tls_disabled(self):
"""Test certificate rotation check with TLS disabled."""
settings.tls_enabled = False
tls_manager = TLSManager()
result = tls_manager.check_certificate_rotation()
assert result is False
@patch('src.core.tls_manager.os.makedirs')
def test_generate_self_signed_certificate_success(self, mock_makedirs):
"""Test successful self-signed certificate generation."""
tls_manager = TLSManager()
# Mock the entire certificate generation to succeed
with patch('builtins.open', unittest.mock.mock_open()) as mock_file:
result = tls_manager.generate_self_signed_certificate(self.temp_dir)
assert result is True
assert tls_manager.cert_path.endswith("cert.pem")
assert tls_manager.key_path.endswith("key.pem")
@patch('src.core.tls_manager.os.makedirs')
def test_generate_self_signed_certificate_failure(self, mock_makedirs):
"""Test self-signed certificate generation failure."""
# Mock makedirs to raise an exception
mock_makedirs.side_effect = Exception("Permission denied")
tls_manager = TLSManager()
result = tls_manager.generate_self_signed_certificate(self.temp_dir)
assert result is False