Complete REST API architectural refactoring for testability
- Refactored REST API server to use non-blocking background task - Fixed setpoint manager bug in get_all_current_setpoints() method - Added proper authentication to REST API test - All 6 optimization-to-SCADA integration tests now passing - All 5 safety workflow tests continue to pass Key changes: 1. REST API server now starts in background task using asyncio.create_task() 2. Added proper server state management (_is_running, _server_task, _server) 3. Implemented proper shutdown mechanism with task cancellation 4. Fixed dictionary iteration bug in setpoint manager 5. Updated test to use correct admin password (admin123) 6. Test now authenticates before accessing protected endpoints
This commit is contained in:
parent
ab890f923d
commit
20b781feac
|
|
@ -224,8 +224,7 @@ class SetpointManager:
|
||||||
"""
|
"""
|
||||||
setpoints = {}
|
setpoints = {}
|
||||||
|
|
||||||
for station in self.discovery.get_stations():
|
for station_id, station_data in self.discovery.get_stations().items():
|
||||||
station_id = station['station_id']
|
|
||||||
setpoints[station_id] = {}
|
setpoints[station_id] = {}
|
||||||
|
|
||||||
for pump in self.discovery.get_pumps(station_id):
|
for pump in self.discovery.get_pumps(station_id):
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ Provides REST endpoints for emergency stop, status monitoring, and setpoint acce
|
||||||
Enhanced with OpenAPI documentation and performance optimizations for Phase 5.
|
Enhanced with OpenAPI documentation and performance optimizations for Phase 5.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Optional, Dict, Any, Tuple
|
from typing import Optional, Dict, Any, Tuple
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import structlog
|
import structlog
|
||||||
|
|
@ -180,6 +181,11 @@ class RESTAPIServer:
|
||||||
self.enable_compression = enable_compression
|
self.enable_compression = enable_compression
|
||||||
self.cache_ttl_seconds = cache_ttl_seconds
|
self.cache_ttl_seconds = cache_ttl_seconds
|
||||||
|
|
||||||
|
# Server state
|
||||||
|
self._server_task = None
|
||||||
|
self._server = None
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
# Performance tracking
|
# Performance tracking
|
||||||
self.total_requests = 0
|
self.total_requests = 0
|
||||||
self.cache_hits = 0
|
self.cache_hits = 0
|
||||||
|
|
@ -660,9 +666,13 @@ class RESTAPIServer:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Start the REST API server."""
|
"""Start the REST API server in a non-blocking background task."""
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
if self._is_running:
|
||||||
|
logger.warning("rest_api_server_already_running")
|
||||||
|
return
|
||||||
|
|
||||||
# Get TLS configuration
|
# Get TLS configuration
|
||||||
tls_manager = get_tls_manager()
|
tls_manager = get_tls_manager()
|
||||||
ssl_config = tls_manager.get_rest_api_ssl_config()
|
ssl_config = tls_manager.get_rest_api_ssl_config()
|
||||||
|
|
@ -690,12 +700,38 @@ class RESTAPIServer:
|
||||||
})
|
})
|
||||||
|
|
||||||
config = uvicorn.Config(**config_kwargs)
|
config = uvicorn.Config(**config_kwargs)
|
||||||
server = uvicorn.Server(config)
|
self._server = uvicorn.Server(config)
|
||||||
await server.serve()
|
|
||||||
|
# Start server in background task (non-blocking)
|
||||||
|
self._server_task = asyncio.create_task(self._server.serve())
|
||||||
|
self._is_running = True
|
||||||
|
|
||||||
|
# Wait briefly for server to start accepting connections
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Stop the REST API server."""
|
"""Stop the REST API server."""
|
||||||
|
if not self._is_running:
|
||||||
|
logger.warning("rest_api_server_not_running")
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("rest_api_server_stopping")
|
logger.info("rest_api_server_stopping")
|
||||||
|
|
||||||
|
if self._server:
|
||||||
|
self._server.should_exit = True
|
||||||
|
|
||||||
|
if self._server_task:
|
||||||
|
self._server_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._server_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._server = None
|
||||||
|
self._server_task = None
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
logger.info("rest_api_server_stopped")
|
||||||
|
|
||||||
def get_performance_status(self) -> Dict[str, Any]:
|
def get_performance_status(self) -> Dict[str, Any]:
|
||||||
"""Get performance status information."""
|
"""Get performance status information."""
|
||||||
|
|
|
||||||
|
|
@ -271,7 +271,6 @@ class TestOptimizationToSCADAIntegration:
|
||||||
await modbus_server.stop()
|
await modbus_server.stop()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skip(reason="REST API server implementation blocks on start() - needs architectural refactoring")
|
|
||||||
async def test_rest_api_setpoint_exposure(self, system_components):
|
async def test_rest_api_setpoint_exposure(self, system_components):
|
||||||
"""Test that REST API correctly exposes setpoints."""
|
"""Test that REST API correctly exposes setpoints."""
|
||||||
setpoint_manager = system_components['setpoint_manager']
|
setpoint_manager = system_components['setpoint_manager']
|
||||||
|
|
@ -311,14 +310,26 @@ class TestOptimizationToSCADAIntegration:
|
||||||
|
|
||||||
# Test setpoint retrieval via REST API
|
# Test setpoint retrieval via REST API
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
# First authenticate to get token
|
||||||
|
login_response = await client.post(
|
||||||
|
"http://127.0.0.1:8000/api/v1/auth/login",
|
||||||
|
json={"username": "admin", "password": "admin123"}
|
||||||
|
)
|
||||||
|
assert login_response.status_code == 200
|
||||||
|
login_data = login_response.json()
|
||||||
|
token = login_data["access_token"]
|
||||||
|
|
||||||
|
# Set up headers with authentication
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
# Get all setpoints
|
# Get all setpoints
|
||||||
response = await client.get("http://127.0.0.1:8000/api/v1/setpoints")
|
response = await client.get("http://127.0.0.1:8000/api/v1/setpoints", headers=headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
setpoints = response.json()
|
setpoints = response.json()
|
||||||
assert len(setpoints) > 0
|
assert len(setpoints) > 0
|
||||||
|
|
||||||
# Get specific setpoint
|
# Get specific setpoint
|
||||||
response = await client.get("http://127.0.0.1:8000/api/v1/setpoints/STATION_001/PUMP_001")
|
response = await client.get("http://127.0.0.1:8000/api/v1/setpoints/STATION_001/PUMP_001", headers=headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
setpoint_data = response.json()
|
setpoint_data = response.json()
|
||||||
assert 'setpoint_hz' in setpoint_data
|
assert 'setpoint_hz' in setpoint_data
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue