diff --git a/src/core/setpoint_manager.py b/src/core/setpoint_manager.py index 02b81a5..e7f5c8b 100644 --- a/src/core/setpoint_manager.py +++ b/src/core/setpoint_manager.py @@ -224,8 +224,7 @@ class SetpointManager: """ setpoints = {} - for station in self.discovery.get_stations(): - station_id = station['station_id'] + for station_id, station_data in self.discovery.get_stations().items(): setpoints[station_id] = {} for pump in self.discovery.get_pumps(station_id): diff --git a/src/protocols/rest_api.py b/src/protocols/rest_api.py index 0915b99..48869f8 100644 --- a/src/protocols/rest_api.py +++ b/src/protocols/rest_api.py @@ -5,6 +5,7 @@ Provides REST endpoints for emergency stop, status monitoring, and setpoint acce Enhanced with OpenAPI documentation and performance optimizations for Phase 5. """ +import asyncio from typing import Optional, Dict, Any, Tuple from datetime import datetime, timedelta import structlog @@ -180,6 +181,11 @@ class RESTAPIServer: self.enable_compression = enable_compression self.cache_ttl_seconds = cache_ttl_seconds + # Server state + self._server_task = None + self._server = None + self._is_running = False + # Performance tracking self.total_requests = 0 self.cache_hits = 0 @@ -660,9 +666,13 @@ class RESTAPIServer: ) async def start(self): - """Start the REST API server.""" + """Start the REST API server in a non-blocking background task.""" import uvicorn + if self._is_running: + logger.warning("rest_api_server_already_running") + return + # Get TLS configuration tls_manager = get_tls_manager() ssl_config = tls_manager.get_rest_api_ssl_config() @@ -690,12 +700,38 @@ class RESTAPIServer: }) config = uvicorn.Config(**config_kwargs) - server = uvicorn.Server(config) - await server.serve() + self._server = uvicorn.Server(config) + + # Start server in background task (non-blocking) + self._server_task = asyncio.create_task(self._server.serve()) + self._is_running = True + + # Wait briefly for server to start accepting connections + await asyncio.sleep(0.5) async def stop(self): """Stop the REST API server.""" + if not self._is_running: + logger.warning("rest_api_server_not_running") + return + logger.info("rest_api_server_stopping") + + if self._server: + self._server.should_exit = True + + if self._server_task: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + + self._server = None + self._server_task = None + self._is_running = False + + logger.info("rest_api_server_stopped") def get_performance_status(self) -> Dict[str, Any]: """Get performance status information.""" diff --git a/tests/integration/test_optimization_to_scada.py b/tests/integration/test_optimization_to_scada.py index f8262f8..4e54169 100644 --- a/tests/integration/test_optimization_to_scada.py +++ b/tests/integration/test_optimization_to_scada.py @@ -271,7 +271,6 @@ class TestOptimizationToSCADAIntegration: await modbus_server.stop() @pytest.mark.asyncio - @pytest.mark.skip(reason="REST API server implementation blocks on start() - needs architectural refactoring") async def test_rest_api_setpoint_exposure(self, system_components): """Test that REST API correctly exposes setpoints.""" setpoint_manager = system_components['setpoint_manager'] @@ -311,14 +310,26 @@ class TestOptimizationToSCADAIntegration: # Test setpoint retrieval via REST API async with httpx.AsyncClient() as client: + # First authenticate to get token + login_response = await client.post( + "http://127.0.0.1:8000/api/v1/auth/login", + json={"username": "admin", "password": "admin123"} + ) + assert login_response.status_code == 200 + login_data = login_response.json() + token = login_data["access_token"] + + # Set up headers with authentication + headers = {"Authorization": f"Bearer {token}"} + # Get all setpoints - response = await client.get("http://127.0.0.1:8000/api/v1/setpoints") + response = await client.get("http://127.0.0.1:8000/api/v1/setpoints", headers=headers) assert response.status_code == 200 setpoints = response.json() assert len(setpoints) > 0 # Get specific setpoint - response = await client.get("http://127.0.0.1:8000/api/v1/setpoints/STATION_001/PUMP_001") + response = await client.get("http://127.0.0.1:8000/api/v1/setpoints/STATION_001/PUMP_001", headers=headers) assert response.status_code == 200 setpoint_data = response.json() assert 'setpoint_hz' in setpoint_data