""" Unit tests for DatabaseClient class. """ import pytest import pytest_asyncio from unittest.mock import Mock, patch, MagicMock from typing import Dict, Any from src.database.client import DatabaseClient class TestDatabaseClient: """Test cases for DatabaseClient.""" @pytest_asyncio.fixture async def db_client(self): """Create a test database client.""" client = DatabaseClient( database_url="postgresql://test:test@localhost:5432/test", min_connections=1, max_connections=2 ) # Mock the connection pool client.connection_pool = Mock() client.connection_pool.getconn.return_value = Mock() client.connection_pool.putconn = Mock() return client @pytest.mark.asyncio async def test_connect_success(self, db_client): """Test successful database connection.""" with patch('psycopg2.pool.ThreadedConnectionPool') as mock_pool: mock_pool_instance = Mock() mock_pool.return_value = mock_pool_instance mock_conn = MagicMock() mock_cursor = MagicMock() mock_cursor.fetchone.return_value = {'version': 'PostgreSQL 14.0'} mock_conn.cursor.return_value.__enter__.return_value = mock_cursor mock_conn.cursor.return_value.__exit__.return_value = None mock_pool_instance.getconn.return_value = mock_conn await db_client.connect() assert db_client.connection_pool is not None mock_pool.assert_called_once() @pytest.mark.asyncio async def test_connect_failure(self, db_client): """Test database connection failure.""" with patch('psycopg2.pool.ThreadedConnectionPool') as mock_pool: mock_pool.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await db_client.connect() def test_execute_query_success(self, db_client): """Test successful query execution.""" # Mock the cursor mock_cursor = MagicMock() mock_cursor.fetchall.return_value = [{'id': 1, 'name': 'test'}] # Mock the cursor context manager db_client._get_cursor = MagicMock() db_client._get_cursor.return_value.__enter__.return_value = mock_cursor db_client._get_cursor.return_value.__exit__.return_value = None result = db_client.execute_query("SELECT * FROM test", (1,)) assert result == [{'id': 1, 'name': 'test'}] mock_cursor.execute.assert_called_once_with("SELECT * FROM test", (1,)) def test_execute_query_failure(self, db_client): """Test query execution failure.""" # Mock the cursor to raise an exception mock_cursor = MagicMock() mock_cursor.execute.side_effect = Exception("Query failed") # Mock the cursor context manager db_client._get_cursor = MagicMock() db_client._get_cursor.return_value.__enter__.return_value = mock_cursor db_client._get_cursor.return_value.__exit__.return_value = None with pytest.raises(Exception, match="Query failed"): db_client.execute_query("SELECT * FROM test") def test_execute_success(self, db_client): """Test successful execute operation.""" # Mock the cursor mock_cursor = MagicMock() # Mock the cursor context manager db_client._get_cursor = MagicMock() db_client._get_cursor.return_value.__enter__.return_value = mock_cursor db_client._get_cursor.return_value.__exit__.return_value = None db_client.execute("INSERT INTO test VALUES (%s)", (1,)) mock_cursor.execute.assert_called_once_with("INSERT INTO test VALUES (%s)", (1,)) def test_health_check_success(self, db_client): """Test successful health check.""" # Mock the cursor mock_cursor = MagicMock() mock_cursor.fetchone.return_value = {'health_check': 1} # Mock the cursor context manager db_client._get_cursor = MagicMock() db_client._get_cursor.return_value.__enter__.return_value = mock_cursor db_client._get_cursor.return_value.__exit__.return_value = None result = db_client.health_check() assert result is True mock_cursor.execute.assert_called_once_with("SELECT 1 as health_check;") def test_health_check_failure(self, db_client): """Test health check failure.""" # Mock the cursor to raise an exception mock_cursor = MagicMock() mock_cursor.execute.side_effect = Exception("Health check failed") # Mock the cursor context manager db_client._get_cursor = MagicMock() db_client._get_cursor.return_value.__enter__.return_value = mock_cursor db_client._get_cursor.return_value.__exit__.return_value = None result = db_client.health_check() assert result is False def test_get_connection_stats(self, db_client): """Test connection statistics retrieval.""" stats = db_client.get_connection_stats() expected = { "min_connections": 1, "max_connections": 2, "pool_status": "active" } assert stats == expected def test_get_connection_stats_no_pool(self): """Test connection statistics when pool is not initialized.""" client = DatabaseClient("test_url") stats = client.get_connection_stats() assert stats == {"status": "pool_not_initialized"} @pytest.mark.asyncio async def test_disconnect(self, db_client): """Test database disconnection.""" db_client.connection_pool.closeall = Mock() await db_client.disconnect() db_client.connection_pool.closeall.assert_called_once() @pytest.mark.asyncio async def test_disconnect_no_pool(self, db_client): """Test disconnection when no pool exists.""" db_client.connection_pool = None # Should not raise an exception await db_client.disconnect()