Compare commits
3 Commits
0b66a0fb4e
...
84edcb14ff
| Author | SHA1 | Date |
|---|---|---|
|
|
84edcb14ff | |
|
|
58ba34b230 | |
|
|
dc10dab9ec |
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Calejo Control Adapter - Test Runner with Better Output Formatting
|
||||
|
||||
This script runs the standard test suite but provides better output formatting
|
||||
and organization of results by test file and system.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Colors for output
|
||||
class Colors:
|
||||
RED = '\033[91m'
|
||||
GREEN = '\033[92m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
MAGENTA = '\033[95m'
|
||||
CYAN = '\033[96m'
|
||||
WHITE = '\033[97m'
|
||||
BOLD = '\033[1m'
|
||||
END = '\033[0m'
|
||||
|
||||
def print_color(color, message):
|
||||
print(f"{color}{message}{Colors.END}")
|
||||
|
||||
def print_info(message):
|
||||
print_color(Colors.BLUE, f"[INFO] {message}")
|
||||
|
||||
def print_success(message):
|
||||
print_color(Colors.GREEN, f"[SUCCESS] {message}")
|
||||
|
||||
def print_warning(message):
|
||||
print_color(Colors.YELLOW, f"[WARNING] {message}")
|
||||
|
||||
def print_error(message):
|
||||
print_color(Colors.RED, f"[ERROR] {message}")
|
||||
|
||||
def print_header(message):
|
||||
print_color(Colors.CYAN + Colors.BOLD, f"\n{'='*80}")
|
||||
print_color(Colors.CYAN + Colors.BOLD, f" {message}")
|
||||
print_color(Colors.CYAN + Colors.BOLD, f"{'='*80}\n")
|
||||
|
||||
def main():
|
||||
"""Main function."""
|
||||
print_header("CALEJO CONTROL ADAPTER - TEST SUITE WITH BETTER OUTPUT")
|
||||
|
||||
print_info(f"Test Run Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Run tests with better organization
|
||||
test_sections = [
|
||||
("UNIT TESTS", "tests/unit/"),
|
||||
("INTEGRATION TESTS", "tests/integration/"),
|
||||
("SAFETY FRAMEWORK TESTS", "tests/test_safety.py"),
|
||||
("SAFETY FRAMEWORK UNIT TESTS", "tests/unit/test_safety_framework.py")
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
start_time = time.time()
|
||||
|
||||
for section_name, test_path in test_sections:
|
||||
print_header(f"RUNNING {section_name}")
|
||||
|
||||
cmd = [
|
||||
'python', '-m', 'pytest', test_path,
|
||||
'-v',
|
||||
'--tb=short',
|
||||
'--color=yes'
|
||||
]
|
||||
|
||||
print_info(f"Running: {' '.join(cmd)}")
|
||||
section_start = time.time()
|
||||
|
||||
result = subprocess.run(cmd)
|
||||
|
||||
section_duration = time.time() - section_start
|
||||
|
||||
if result.returncode == 0:
|
||||
print_success(f"{section_name} PASSED in {section_duration:.2f}s")
|
||||
else:
|
||||
print_error(f"{section_name} FAILED in {section_duration:.2f}s")
|
||||
all_passed = False
|
||||
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
print_header("TEST SUMMARY")
|
||||
print_info(f"Test Run Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print_info(f"Total Duration: {total_duration:.2f} seconds")
|
||||
|
||||
if all_passed:
|
||||
print_success("🎉 ALL TEST SECTIONS PASSED! 🎉")
|
||||
print_info("\nTest Results by System:")
|
||||
print(" ✅ Safety Framework - All tests passed")
|
||||
print(" ✅ Protocol Servers - All tests passed")
|
||||
print(" ✅ Database Systems - All tests passed")
|
||||
print(" ✅ Security Systems - All tests passed")
|
||||
print(" ✅ Monitoring Systems - All tests passed")
|
||||
print(" ✅ Integration Tests - All tests passed")
|
||||
else:
|
||||
print_error("❌ SOME TEST SECTIONS FAILED ❌")
|
||||
|
||||
sys.exit(0 if all_passed else 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print_warning("\nTest run interrupted by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print_error(f"Unexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
|
@ -15,7 +15,7 @@ from pymodbus.transaction import ModbusSocketFramer
|
|||
|
||||
from src.core.setpoint_manager import SetpointManager
|
||||
from src.core.security import SecurityManager
|
||||
from src.core.compliance_audit import ComplianceAuditLogger
|
||||
from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
|
@ -76,12 +76,13 @@ class ModbusServer:
|
|||
# Initialize data blocks
|
||||
await self._initialize_datastore()
|
||||
|
||||
# Start server
|
||||
self.server = await StartAsyncTcpServer(
|
||||
# Start server as a task that can be cancelled
|
||||
self.server_task = asyncio.create_task(
|
||||
StartAsyncTcpServer(
|
||||
context=self.context,
|
||||
framer=ModbusSocketFramer,
|
||||
address=(self.host, self.port),
|
||||
defer_start=False
|
||||
address=(self.host, self.port)
|
||||
)
|
||||
)
|
||||
|
||||
# Log security configuration
|
||||
|
|
@ -98,9 +99,9 @@ class ModbusServer:
|
|||
|
||||
# Log security event
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="SERVER_START",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.SYSTEM_START,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "MODBUS_TCP",
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
|
|
@ -128,9 +129,9 @@ class ModbusServer:
|
|||
if self.allowed_ips and client_ip not in self.allowed_ips:
|
||||
# Log unauthorized access attempt
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="UNAUTHORIZED_ACCESS",
|
||||
severity="WARNING",
|
||||
details={
|
||||
event_type=AuditEventType.ACCESS_DENIED,
|
||||
severity=AuditSeverity.HIGH,
|
||||
event_data={
|
||||
"protocol": "MODBUS_TCP",
|
||||
"client_ip": client_ip,
|
||||
"allowed_ips": self.allowed_ips,
|
||||
|
|
@ -164,9 +165,9 @@ class ModbusServer:
|
|||
if self.request_counts[client_ip] >= self.rate_limit_per_minute:
|
||||
# Log rate limit violation
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="RATE_LIMIT_EXCEEDED",
|
||||
severity="WARNING",
|
||||
details={
|
||||
event_type=AuditEventType.ACCESS_DENIED,
|
||||
severity=AuditSeverity.MEDIUM,
|
||||
event_data={
|
||||
"protocol": "MODBUS_TCP",
|
||||
"client_ip": client_ip,
|
||||
"request_count": self.request_counts[client_ip],
|
||||
|
|
@ -201,9 +202,9 @@ class ModbusServer:
|
|||
sensitive_functions = {6, 16} # Write single register, write multiple registers
|
||||
if function_code in sensitive_functions:
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="MODBUS_WRITE_OPERATION",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.SETPOINT_CHANGED,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "MODBUS_TCP",
|
||||
"client_ip": client_ip,
|
||||
"function_code": function_code,
|
||||
|
|
@ -246,15 +247,19 @@ class ModbusServer:
|
|||
|
||||
async def stop(self):
|
||||
"""Stop the Modbus TCP server."""
|
||||
if self.server:
|
||||
# Note: pymodbus doesn't have a direct stop method
|
||||
# We'll rely on the task being cancelled
|
||||
if hasattr(self, 'server_task') and self.server_task:
|
||||
# Cancel the server task
|
||||
self.server_task.cancel()
|
||||
try:
|
||||
await self.server_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Log security event
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="SERVER_STOP",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.SYSTEM_STOP,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "MODBUS_TCP",
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
|
|
@ -304,8 +309,7 @@ class ModbusServer:
|
|||
stations = self.setpoint_manager.discovery.get_stations()
|
||||
address_counter = 0
|
||||
|
||||
for station in stations:
|
||||
station_id = station['station_id']
|
||||
for station_id, station in stations.items():
|
||||
pumps = self.setpoint_manager.discovery.get_pumps(station_id)
|
||||
|
||||
for pump in pumps:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ except ImportError:
|
|||
|
||||
from src.core.setpoint_manager import SetpointManager
|
||||
from src.core.security import SecurityManager, UserRole
|
||||
from src.core.compliance_audit import ComplianceAuditLogger
|
||||
from src.core.compliance_audit import ComplianceAuditLogger, AuditEventType, AuditSeverity
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
|
@ -103,9 +103,9 @@ class OPCUAServer:
|
|||
|
||||
# Log security event
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="SERVER_START",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.SYSTEM_START,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "OPC_UA",
|
||||
"endpoint": self.endpoint,
|
||||
"security_enabled": self.enable_security,
|
||||
|
|
@ -179,9 +179,9 @@ class OPCUAServer:
|
|||
|
||||
# Log connection event
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="CLIENT_CONNECT",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.USER_LOGIN,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "OPC_UA",
|
||||
"client_id": client_id,
|
||||
"endpoint": endpoint,
|
||||
|
|
@ -205,9 +205,9 @@ class OPCUAServer:
|
|||
if client_info:
|
||||
# Log disconnection event
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="CLIENT_DISCONNECT",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.USER_LOGOUT,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "OPC_UA",
|
||||
"client_id": client_id,
|
||||
"endpoint": client_info['endpoint'],
|
||||
|
|
@ -228,9 +228,9 @@ class OPCUAServer:
|
|||
|
||||
# Log security event
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="SERVER_STOP",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.SYSTEM_STOP,
|
||||
severity=AuditSeverity.LOW,
|
||||
event_data={
|
||||
"protocol": "OPC_UA",
|
||||
"endpoint": self.endpoint,
|
||||
"connected_clients": len(self.connected_clients)
|
||||
|
|
@ -273,9 +273,9 @@ class OPCUAServer:
|
|||
|
||||
# Log access attempt
|
||||
self.audit_logger.log_security_event(
|
||||
event_type="NODE_ACCESS_ATTEMPT",
|
||||
severity="INFO",
|
||||
details={
|
||||
event_type=AuditEventType.ACCESS_DENIED,
|
||||
severity=AuditSeverity.MEDIUM,
|
||||
event_data={
|
||||
"protocol": "OPC_UA",
|
||||
"client_id": str(session.session_id),
|
||||
"node": str(node),
|
||||
|
|
@ -308,8 +308,7 @@ class OPCUAServer:
|
|||
# Create stations and pumps structure
|
||||
stations = self.setpoint_manager.discovery.get_stations()
|
||||
|
||||
for station in stations:
|
||||
station_id = station['station_id']
|
||||
for station_id, station in stations.items():
|
||||
|
||||
# Create station folder
|
||||
station_folder = await calejo_folder.add_folder(
|
||||
|
|
|
|||
|
|
@ -1,261 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for Phase 1 implementation.
|
||||
|
||||
Tests core infrastructure components:
|
||||
- Database connectivity and queries
|
||||
- Auto-discovery functionality
|
||||
- Configuration management
|
||||
- Safety framework loading
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from src.database.flexible_client import FlexibleDatabaseClient
|
||||
from src.core.auto_discovery import AutoDiscovery
|
||||
from src.core.safety import SafetyLimitEnforcer
|
||||
from src.core.logging import setup_logging
|
||||
from config.settings import settings
|
||||
|
||||
|
||||
class Phase1Tester:
|
||||
"""Test class for Phase 1 components."""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = setup_logging()
|
||||
self.db_client = FlexibleDatabaseClient(
|
||||
database_url=settings.database_url,
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow
|
||||
)
|
||||
self.auto_discovery = AutoDiscovery(self.db_client)
|
||||
self.safety_enforcer = SafetyLimitEnforcer(self.db_client)
|
||||
|
||||
self.tests_passed = 0
|
||||
self.tests_failed = 0
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all Phase 1 tests."""
|
||||
self.logger.info("starting_phase1_tests")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("Calejo Control Adapter - Phase 1 Test Suite")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Test 1: Database connection
|
||||
await self.test_database_connection()
|
||||
|
||||
# Test 2: Database queries
|
||||
await self.test_database_queries()
|
||||
|
||||
# Test 3: Auto-discovery
|
||||
await self.test_auto_discovery()
|
||||
|
||||
# Test 4: Safety framework
|
||||
await self.test_safety_framework()
|
||||
|
||||
# Test 5: Configuration
|
||||
await self.test_configuration()
|
||||
|
||||
# Print summary
|
||||
self.print_test_summary()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("test_suite_failed", error=str(e))
|
||||
raise
|
||||
finally:
|
||||
await self.db_client.disconnect()
|
||||
|
||||
async def test_database_connection(self):
|
||||
"""Test database connectivity."""
|
||||
print("\n1. Testing Database Connection...")
|
||||
|
||||
try:
|
||||
await self.db_client.connect()
|
||||
|
||||
# Test health check
|
||||
is_healthy = self.db_client.health_check()
|
||||
if is_healthy:
|
||||
print(" ✓ Database connection successful")
|
||||
self.tests_passed += 1
|
||||
else:
|
||||
print(" ✗ Database health check failed")
|
||||
self.tests_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Database connection failed: {e}")
|
||||
self.tests_failed += 1
|
||||
raise
|
||||
|
||||
async def test_database_queries(self):
|
||||
"""Test database queries."""
|
||||
print("\n2. Testing Database Queries...")
|
||||
|
||||
try:
|
||||
# Test pump stations query
|
||||
stations = self.db_client.get_pump_stations()
|
||||
print(f" ✓ Found {len(stations)} pump stations")
|
||||
|
||||
# Test pumps query
|
||||
pumps = self.db_client.get_pumps()
|
||||
print(f" ✓ Found {len(pumps)} pumps")
|
||||
|
||||
# Test safety limits query
|
||||
safety_limits = self.db_client.get_safety_limits()
|
||||
print(f" ✓ Found {len(safety_limits)} safety limits")
|
||||
|
||||
# Test pump plans query
|
||||
pump_plans = self.db_client.get_latest_pump_plans()
|
||||
print(f" ✓ Found {len(pump_plans)} active pump plans")
|
||||
|
||||
self.tests_passed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Database queries failed: {e}")
|
||||
self.tests_failed += 1
|
||||
raise
|
||||
|
||||
async def test_auto_discovery(self):
|
||||
"""Test auto-discovery functionality."""
|
||||
print("\n3. Testing Auto-Discovery...")
|
||||
|
||||
try:
|
||||
await self.auto_discovery.discover()
|
||||
|
||||
stations = self.auto_discovery.get_stations()
|
||||
pumps = self.auto_discovery.get_pumps()
|
||||
|
||||
print(f" ✓ Discovered {len(stations)} stations")
|
||||
print(f" ✓ Discovered {len(pumps)} pumps")
|
||||
|
||||
# Test individual station/pump retrieval
|
||||
if stations:
|
||||
station_id = list(stations.keys())[0]
|
||||
station = self.auto_discovery.get_station(station_id)
|
||||
if station:
|
||||
print(f" ✓ Station retrieval successful: {station['station_name']}")
|
||||
|
||||
station_pumps = self.auto_discovery.get_pumps(station_id)
|
||||
if station_pumps:
|
||||
pump_id = station_pumps[0]['pump_id']
|
||||
pump = self.auto_discovery.get_pump(station_id, pump_id)
|
||||
if pump:
|
||||
print(f" ✓ Pump retrieval successful: {pump['pump_name']}")
|
||||
|
||||
# Test validation
|
||||
validation = self.auto_discovery.validate_discovery()
|
||||
if validation['valid']:
|
||||
print(" ✓ Auto-discovery validation passed")
|
||||
else:
|
||||
print(f" ⚠ Auto-discovery validation issues: {validation['issues']}")
|
||||
|
||||
self.tests_passed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Auto-discovery failed: {e}")
|
||||
self.tests_failed += 1
|
||||
raise
|
||||
|
||||
async def test_safety_framework(self):
|
||||
"""Test safety framework loading."""
|
||||
print("\n4. Testing Safety Framework...")
|
||||
|
||||
try:
|
||||
await self.safety_enforcer.load_safety_limits()
|
||||
|
||||
limits_count = len(self.safety_enforcer.safety_limits_cache)
|
||||
print(f" ✓ Loaded {limits_count} safety limits")
|
||||
|
||||
# Test setpoint enforcement
|
||||
if limits_count > 0:
|
||||
# Get first pump with safety limits
|
||||
pumps = self.auto_discovery.get_pumps()
|
||||
if pumps:
|
||||
pump = pumps[0]
|
||||
station_id = pump['station_id']
|
||||
pump_id = pump['pump_id']
|
||||
|
||||
# Test within limits
|
||||
enforced, violations = self.safety_enforcer.enforce_setpoint(
|
||||
station_id, pump_id, 35.0
|
||||
)
|
||||
if enforced == 35.0 and not violations:
|
||||
print(" ✓ Setpoint enforcement within limits successful")
|
||||
|
||||
# Test below minimum
|
||||
enforced, violations = self.safety_enforcer.enforce_setpoint(
|
||||
station_id, pump_id, 10.0
|
||||
)
|
||||
if enforced > 10.0 and violations:
|
||||
print(" ✓ Setpoint enforcement below minimum successful")
|
||||
|
||||
self.tests_passed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Safety framework failed: {e}")
|
||||
self.tests_failed += 1
|
||||
raise
|
||||
|
||||
async def test_configuration(self):
|
||||
"""Test configuration management."""
|
||||
print("\n5. Testing Configuration Management...")
|
||||
|
||||
try:
|
||||
# Test database URL generation
|
||||
db_url = settings.database_url
|
||||
if db_url:
|
||||
print(" ✓ Database URL generation successful")
|
||||
|
||||
# Test safe settings dict
|
||||
safe_settings = settings.get_safe_dict()
|
||||
if 'db_password' in safe_settings and safe_settings['db_password'] == '***MASKED***':
|
||||
print(" ✓ Sensitive field masking successful")
|
||||
|
||||
# Test configuration validation
|
||||
print(f" ✓ Configuration loaded: {settings.app_name} v{settings.app_version}")
|
||||
print(f" ✓ Environment: {settings.environment}")
|
||||
|
||||
self.tests_passed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Configuration test failed: {e}")
|
||||
self.tests_failed += 1
|
||||
raise
|
||||
|
||||
def print_test_summary(self):
|
||||
"""Print test summary."""
|
||||
print("\n" + "="*60)
|
||||
print("TEST SUMMARY")
|
||||
print("="*60)
|
||||
print(f"Tests Passed: {self.tests_passed}")
|
||||
print(f"Tests Failed: {self.tests_failed}")
|
||||
|
||||
total_tests = self.tests_passed + self.tests_failed
|
||||
if total_tests > 0:
|
||||
success_rate = (self.tests_passed / total_tests) * 100
|
||||
print(f"Success Rate: {success_rate:.1f}%")
|
||||
|
||||
if self.tests_failed == 0:
|
||||
print("\n🎉 All Phase 1 tests passed!")
|
||||
print("Phase 1 implementation is ready for development.")
|
||||
else:
|
||||
print(f"\n⚠ {self.tests_failed} test(s) failed.")
|
||||
print("Please review the failed tests before proceeding.")
|
||||
|
||||
print("="*60)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run Phase 1 tests."""
|
||||
tester = Phase1Tester()
|
||||
await tester.run_all_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Reference in New Issue