Complete REST API architectural refactoring for testability

- Refactored REST API server to use non-blocking background task
- Fixed setpoint manager bug in get_all_current_setpoints() method
- Added proper authentication to REST API test
- All 6 optimization-to-SCADA integration tests now passing
- All 5 safety workflow tests continue to pass

Key changes:
1. REST API server now starts in background task using asyncio.create_task()
2. Added proper server state management (_is_running, _server_task, _server)
3. Implemented proper shutdown mechanism with task cancellation
4. Fixed dictionary iteration bug in setpoint manager
5. Updated test to use correct admin password (admin123)
6. Test now authenticates before accessing protected endpoints
This commit is contained in:
openhands 2025-10-28 18:06:18 +00:00
parent ab890f923d
commit 20b781feac
3 changed files with 54 additions and 8 deletions

View File

@ -224,8 +224,7 @@ class SetpointManager:
""" """
setpoints = {} setpoints = {}
for station in self.discovery.get_stations(): for station_id, station_data in self.discovery.get_stations().items():
station_id = station['station_id']
setpoints[station_id] = {} setpoints[station_id] = {}
for pump in self.discovery.get_pumps(station_id): for pump in self.discovery.get_pumps(station_id):

View File

@ -5,6 +5,7 @@ Provides REST endpoints for emergency stop, status monitoring, and setpoint acce
Enhanced with OpenAPI documentation and performance optimizations for Phase 5. Enhanced with OpenAPI documentation and performance optimizations for Phase 5.
""" """
import asyncio
from typing import Optional, Dict, Any, Tuple from typing import Optional, Dict, Any, Tuple
from datetime import datetime, timedelta from datetime import datetime, timedelta
import structlog import structlog
@ -180,6 +181,11 @@ class RESTAPIServer:
self.enable_compression = enable_compression self.enable_compression = enable_compression
self.cache_ttl_seconds = cache_ttl_seconds self.cache_ttl_seconds = cache_ttl_seconds
# Server state
self._server_task = None
self._server = None
self._is_running = False
# Performance tracking # Performance tracking
self.total_requests = 0 self.total_requests = 0
self.cache_hits = 0 self.cache_hits = 0
@ -660,9 +666,13 @@ class RESTAPIServer:
) )
async def start(self): async def start(self):
"""Start the REST API server.""" """Start the REST API server in a non-blocking background task."""
import uvicorn import uvicorn
if self._is_running:
logger.warning("rest_api_server_already_running")
return
# Get TLS configuration # Get TLS configuration
tls_manager = get_tls_manager() tls_manager = get_tls_manager()
ssl_config = tls_manager.get_rest_api_ssl_config() ssl_config = tls_manager.get_rest_api_ssl_config()
@ -690,13 +700,39 @@ class RESTAPIServer:
}) })
config = uvicorn.Config(**config_kwargs) config = uvicorn.Config(**config_kwargs)
server = uvicorn.Server(config) self._server = uvicorn.Server(config)
await server.serve()
# Start server in background task (non-blocking)
self._server_task = asyncio.create_task(self._server.serve())
self._is_running = True
# Wait briefly for server to start accepting connections
await asyncio.sleep(0.5)
async def stop(self): async def stop(self):
"""Stop the REST API server.""" """Stop the REST API server."""
if not self._is_running:
logger.warning("rest_api_server_not_running")
return
logger.info("rest_api_server_stopping") logger.info("rest_api_server_stopping")
if self._server:
self._server.should_exit = True
if self._server_task:
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass
self._server = None
self._server_task = None
self._is_running = False
logger.info("rest_api_server_stopped")
def get_performance_status(self) -> Dict[str, Any]: def get_performance_status(self) -> Dict[str, Any]:
"""Get performance status information.""" """Get performance status information."""
cache_stats = self.response_cache.get_stats() if self.response_cache else { cache_stats = self.response_cache.get_stats() if self.response_cache else {

View File

@ -271,7 +271,6 @@ class TestOptimizationToSCADAIntegration:
await modbus_server.stop() await modbus_server.stop()
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skip(reason="REST API server implementation blocks on start() - needs architectural refactoring")
async def test_rest_api_setpoint_exposure(self, system_components): async def test_rest_api_setpoint_exposure(self, system_components):
"""Test that REST API correctly exposes setpoints.""" """Test that REST API correctly exposes setpoints."""
setpoint_manager = system_components['setpoint_manager'] setpoint_manager = system_components['setpoint_manager']
@ -311,14 +310,26 @@ class TestOptimizationToSCADAIntegration:
# Test setpoint retrieval via REST API # Test setpoint retrieval via REST API
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
# First authenticate to get token
login_response = await client.post(
"http://127.0.0.1:8000/api/v1/auth/login",
json={"username": "admin", "password": "admin123"}
)
assert login_response.status_code == 200
login_data = login_response.json()
token = login_data["access_token"]
# Set up headers with authentication
headers = {"Authorization": f"Bearer {token}"}
# Get all setpoints # Get all setpoints
response = await client.get("http://127.0.0.1:8000/api/v1/setpoints") response = await client.get("http://127.0.0.1:8000/api/v1/setpoints", headers=headers)
assert response.status_code == 200 assert response.status_code == 200
setpoints = response.json() setpoints = response.json()
assert len(setpoints) > 0 assert len(setpoints) > 0
# Get specific setpoint # Get specific setpoint
response = await client.get("http://127.0.0.1:8000/api/v1/setpoints/STATION_001/PUMP_001") response = await client.get("http://127.0.0.1:8000/api/v1/setpoints/STATION_001/PUMP_001", headers=headers)
assert response.status_code == 200 assert response.status_code == 200
setpoint_data = response.json() setpoint_data = response.json()
assert 'setpoint_hz' in setpoint_data assert 'setpoint_hz' in setpoint_data