258 lines
9.2 KiB
Python
258 lines
9.2 KiB
Python
"""
|
|
Protocol Discovery Service - Persistent version with database storage
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import List, Dict, Any, Optional
|
|
from enum import Enum
|
|
from dataclasses import dataclass, asdict
|
|
|
|
from sqlalchemy import text
|
|
from config.settings import settings
|
|
from src.database.flexible_client import FlexibleDatabaseClient
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DiscoveryStatus(Enum):
    """Status of a discovery scan.

    Each member's value is the lowercase string form of its name.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
|
|
|
|
|
class ProtocolType(Enum):
    """Protocol identifiers; each value is the lowercase string form of its name."""

    MODBUS_TCP = "modbus_tcp"
    MODBUS_RTU = "modbus_rtu"
    OPC_UA = "opc_ua"
    REST_API = "rest_api"
|
|
|
|
|
|
@dataclass
class DiscoveredEndpoint:
    """A single endpoint found during discovery.

    Only ``protocol_type`` and ``address`` are required. Every other field
    defaults to ``None``, except that a ``None`` ``capabilities`` value is
    replaced by an empty list once the instance has been constructed.
    """

    protocol_type: ProtocolType
    address: str
    port: Optional[int] = None
    device_id: Optional[str] = None
    device_name: Optional[str] = None
    capabilities: Optional[List[str]] = None
    response_time: Optional[float] = None
    discovered_at: Optional[datetime] = None

    def __post_init__(self):
        """Replace a ``None`` capabilities value with a new empty list."""
        if self.capabilities is not None:
            return
        # A fresh list per instance, so instances never share one list.
        self.capabilities = []
|
|
|
|
|
|
@dataclass
class DiscoveryResult:
    """Record describing one discovery scan.

    ``scan_completed_at`` and ``error_message`` are optional and default to
    ``None``; the remaining four fields must be supplied.
    """

    # Required fields.
    scan_id: str
    status: DiscoveryStatus
    discovered_endpoints: List[DiscoveredEndpoint]
    scan_started_at: datetime

    # Optional fields.
    scan_completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
|
|
|
|
|
|
class PersistentProtocolDiscoveryService:
    """
    Protocol discovery service with database persistence.

    Scan records are read from and written to the ``discovery_results``
    table through ``self._db_client.engine``. The only in-memory state is
    ``_current_scan_id``, the id of the scan that ``discover_all_protocols``
    is currently executing (``None`` when no scan is running).
    """

    def __init__(self):
        """Build the database client from ``settings.database_url``.

        The connect call itself is made by ``initialize``.
        """
        self._current_scan_id: Optional[str] = None
        self._db_client = FlexibleDatabaseClient(settings.database_url)

    async def initialize(self):
        """Initialize database connection.

        A connection failure is logged and deliberately not re-raised, so
        the caller is never interrupted by it.
        """
        try:
            await self._db_client.connect()
            logger.info("Discovery service database initialized")
        except Exception as e:
            logger.error(f"Failed to initialize discovery service database: {e}")

    def get_discovery_status(self) -> Dict[str, Any]:
        """Get current discovery service status.

        Returns:
            A dict with four keys:

            * ``current_scan_id`` / ``is_scanning`` - taken from in-memory
              state, so they stay accurate even if the database is down.
            * ``recent_scans`` - up to five most recently started scans.
            * ``total_discovered_endpoints`` - number of distinct
              ``device_id`` values across all completed scans.

            If a database query fails, the error is logged and the two
            database-derived entries fall back to ``[]`` and ``0``.
        """
        # Capture in-memory state first: a database failure must not hide
        # the fact that a scan is currently running.
        current_scan_id = self._current_scan_id
        status: Dict[str, Any] = {
            "current_scan_id": current_scan_id,
            "is_scanning": current_scan_id is not None,
            "recent_scans": [],
            "total_discovered_endpoints": 0,
        }

        try:
            # Get recent scans from database
            recent_query = text("""
                SELECT scan_id, status, scan_started_at, scan_completed_at
                FROM discovery_results
                ORDER BY scan_started_at DESC
                LIMIT 5
            """)

            # Get total discovered endpoints (count unique endpoints across all scans)
            count_query = text("""
                SELECT COUNT(DISTINCT endpoint->>'device_id')
                FROM discovery_results dr,
                     jsonb_array_elements(dr.discovered_endpoints) AS endpoint
                WHERE dr.status = 'completed'
            """)

            # Both queries are read-only, so one connection serves both.
            with self._db_client.engine.connect() as conn:
                recent_scans = [
                    {
                        'scan_id': row[0],
                        'status': row[1],
                        'scan_started_at': row[2].isoformat() if row[2] else None,
                        'scan_completed_at': row[3].isoformat() if row[3] else None,
                    }
                    for row in conn.execute(recent_query)
                ]
                total_endpoints = conn.execute(count_query).scalar() or 0

            # Only publish database results once both queries succeeded.
            status["recent_scans"] = recent_scans
            status["total_discovered_endpoints"] = total_endpoints
        except Exception as e:
            logger.error(f"Error getting discovery status: {e}")

        return status

    def get_scan_result(self, scan_id: str) -> Optional[Dict[str, Any]]:
        """Get result for a specific scan from database.

        Args:
            scan_id: Identifier of the scan to look up.

        Returns:
            A dict with the scan's stored columns (timestamps as ISO-8601
            strings, a missing endpoint list as ``[]``), or ``None`` when no
            row matches or the query fails (the failure is logged).
        """
        try:
            query = text("""
                SELECT scan_id, status, discovered_endpoints,
                       scan_started_at, scan_completed_at, error_message
                FROM discovery_results
                WHERE scan_id = :scan_id
            """)

            with self._db_client.engine.connect() as conn:
                row = conn.execute(query, {"scan_id": scan_id}).fetchone()

            if not row:
                return None

            return {
                "scan_id": row[0],
                "status": row[1],
                "discovered_endpoints": row[2] if row[2] else [],
                "scan_started_at": row[3].isoformat() if row[3] else None,
                "scan_completed_at": row[4].isoformat() if row[4] else None,
                "error_message": row[5],
            }
        except Exception as e:
            logger.error(f"Error getting scan result {scan_id}: {e}")
            return None

    async def discover_all_protocols(self, scan_id: str) -> None:
        """
        Discover all available protocols (simulated for now).

        Records the scan as running, waits two seconds to simulate work,
        then records two mock endpoints as the completed result. Any
        exception is logged and recorded as a failed scan.

        While this coroutine runs, ``scan_id`` is exposed as the current
        scan so that ``get_discovery_status`` reports ``is_scanning``.

        Args:
            scan_id: Identifier under which the scan is stored.
        """
        # One start timestamp for every write of this scan, so the stored
        # start time is correct even if the initial RUNNING insert failed.
        started_at = datetime.now()
        self._current_scan_id = scan_id

        try:
            # Store scan as started
            await self._store_scan_result(
                scan_id=scan_id,
                status=DiscoveryStatus.RUNNING,
                discovered_endpoints=[],
                scan_started_at=started_at,
                scan_completed_at=None,
                error_message=None,
            )

            # Simulate discovery process
            await asyncio.sleep(2)

            # Create mock discovered endpoints
            discovered_endpoints = [
                {
                    "protocol_type": "modbus_tcp",
                    "address": "192.168.1.100",
                    "port": 502,
                    "device_id": "pump_controller_001",
                    "device_name": "Main Pump Controller",
                    "capabilities": ["read_coils", "read_holding_registers"],
                    "response_time": 0.15,
                    "discovered_at": datetime.now().isoformat(),
                },
                {
                    "protocol_type": "opc_ua",
                    "address": "192.168.1.101",
                    "port": 4840,
                    "device_id": "scada_server_001",
                    "device_name": "SCADA Server",
                    "capabilities": ["browse", "read", "write"],
                    "response_time": 0.25,
                    "discovered_at": datetime.now().isoformat(),
                },
            ]

            # Store completed scan
            await self._store_scan_result(
                scan_id=scan_id,
                status=DiscoveryStatus.COMPLETED,
                discovered_endpoints=discovered_endpoints,
                scan_started_at=started_at,
                scan_completed_at=datetime.now(),
                error_message=None,
            )

            logger.info(f"Discovery scan {scan_id} completed with {len(discovered_endpoints)} endpoints")

        except Exception as e:
            logger.error(f"Discovery scan {scan_id} failed: {e}")
            await self._store_scan_result(
                scan_id=scan_id,
                status=DiscoveryStatus.FAILED,
                discovered_endpoints=[],
                scan_started_at=started_at,
                scan_completed_at=datetime.now(),
                error_message=str(e),
            )
        finally:
            # Clear the marker only if it still belongs to this scan, so an
            # overlapping scan that started later is not wiped out.
            if self._current_scan_id == scan_id:
                self._current_scan_id = None

    async def _store_scan_result(
        self,
        scan_id: str,
        status: DiscoveryStatus,
        discovered_endpoints: List[Dict[str, Any]],
        scan_started_at: datetime,
        scan_completed_at: Optional[datetime] = None,
        error_message: Optional[str] = None
    ) -> None:
        """Store scan result in database.

        Inserts a row for ``scan_id``; if one already exists, its status,
        endpoints, completion time and error message are overwritten while
        the stored ``scan_started_at`` is left unchanged.

        Args:
            scan_id: Identifier of the scan (the conflict key).
            status: Scan status; its ``.value`` string is what is stored.
            discovered_endpoints: Endpoint dicts, stored as a JSON string.
            scan_started_at: Start time, used only when the row is inserted.
            scan_completed_at: Completion time, or ``None`` if still running.
            error_message: Failure description, or ``None``.

        Failures are logged and not re-raised, so callers never see them.

        NOTE: the database call below is made synchronously (no ``await``),
        so it occupies the event loop for the duration of the write.
        """
        try:
            query = text("""
                INSERT INTO discovery_results
                (scan_id, status, discovered_endpoints, scan_started_at, scan_completed_at, error_message)
                VALUES (:scan_id, :status, :discovered_endpoints, :scan_started_at, :scan_completed_at, :error_message)
                ON CONFLICT (scan_id) DO UPDATE SET
                    status = EXCLUDED.status,
                    discovered_endpoints = EXCLUDED.discovered_endpoints,
                    scan_completed_at = EXCLUDED.scan_completed_at,
                    error_message = EXCLUDED.error_message
            """)

            params = {
                "scan_id": scan_id,
                "status": status.value,
                "discovered_endpoints": json.dumps(discovered_endpoints),
                "scan_started_at": scan_started_at,
                "scan_completed_at": scan_completed_at,
                "error_message": error_message,
            }

            with self._db_client.engine.connect() as conn:
                conn.execute(query, params)
                conn.commit()

        except Exception as e:
            logger.error(f"Failed to store scan result {scan_id}: {e}")
|
|
|
|
|
|
# Global instance, created when this module is imported.
persistent_discovery_service = PersistentProtocolDiscoveryService()