243 lines
9.2 KiB
Python
243 lines
9.2 KiB
Python
"""
|
|
Unit tests for TLS manager components.
|
|
"""
|
|
|
|
import pytest
|
|
import unittest.mock
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
import tempfile
|
|
import os
|
|
import ssl
|
|
from pathlib import Path
|
|
|
|
from src.core.tls_manager import TLSManager
|
|
from config.settings import settings
|
|
|
|
|
|
class TestTLSManager:
    """Test cases for TLSManager.

    Every test runs against a fresh temporary directory that holds placeholder
    certificate/key files. The global ``settings`` TLS fields are pointed at
    those files in ``setup_method`` and restored in ``teardown_method`` so that
    tests do not leak configuration into one another.
    """

    def setup_method(self):
        """Create placeholder cert/key files and point ``settings`` at them."""
        # Create temporary directory for test certificates
        self.temp_dir = tempfile.mkdtemp()
        self.cert_path = os.path.join(self.temp_dir, "test_cert.pem")
        self.key_path = os.path.join(self.temp_dir, "test_key.pem")

        # Placeholder contents only -- these are not real PEM documents.
        with open(self.cert_path, "w", encoding="utf-8") as f:
            f.write("DUMMY CERTIFICATE")
        with open(self.key_path, "w", encoding="utf-8") as f:
            f.write("DUMMY PRIVATE KEY")

        # Snapshot the global settings before overriding them so that
        # teardown_method can put them back.
        self.original_tls_enabled = settings.tls_enabled
        self.original_cert_path = settings.tls_cert_path
        self.original_key_path = settings.tls_key_path

        settings.tls_enabled = True
        settings.tls_cert_path = self.cert_path
        settings.tls_key_path = self.key_path

    def teardown_method(self):
        """Restore the original ``settings`` values and remove the temp directory."""
        import shutil

        # Restore original settings first so later tests see a clean config
        # even if directory removal below has problems.
        settings.tls_enabled = self.original_tls_enabled
        settings.tls_cert_path = self.original_cert_path
        settings.tls_key_path = self.original_key_path

        # ignore_errors: a cleanup error must not mask the real test outcome.
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_initialization_with_valid_certificates(self):
        """Test initialization with valid certificate files."""
        tls_manager = TLSManager()

        assert tls_manager.tls_enabled is True
        assert tls_manager.cert_path == self.cert_path
        assert tls_manager.key_path == self.key_path

    def test_initialization_with_missing_certificates(self):
        """Test initialization with missing certificate files."""
        # Point settings at files that do not exist.
        settings.tls_cert_path = "/nonexistent/cert.pem"
        settings.tls_key_path = "/nonexistent/key.pem"

        tls_manager = TLSManager()

        # Construction must not raise; the configured paths are kept as-is.
        assert tls_manager.tls_enabled is True
        assert tls_manager.cert_path == "/nonexistent/cert.pem"
        assert tls_manager.key_path == "/nonexistent/key.pem"

    def test_initialization_with_tls_disabled(self):
        """Test initialization with TLS disabled."""
        settings.tls_enabled = False

        tls_manager = TLSManager()

        assert tls_manager.tls_enabled is False

    @patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
    def test_validate_certificates_success(self, mock_get_expiry):
        """Test successful certificate validation."""
        from datetime import datetime, timedelta, timezone

        # Expiry is computed relative to "now": a fixed calendar date would
        # silently turn this into an expired-certificate test once it passes.
        mock_get_expiry.return_value = datetime.now(timezone.utc) + timedelta(days=365)

        tls_manager = TLSManager()
        result = tls_manager._validate_certificates()

        assert result is True

    @patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
    def test_validate_certificates_expired(self, mock_get_expiry):
        """Test validation of expired certificate."""
        from datetime import datetime, timedelta, timezone

        # Expiry one year in the past, relative to now.
        mock_get_expiry.return_value = datetime.now(timezone.utc) - timedelta(days=365)

        tls_manager = TLSManager()
        result = tls_manager._validate_certificates()

        assert result is False

    @patch('src.core.tls_manager.TLSManager._get_certificate_expiry')
    def test_validate_certificates_parsing_error(self, mock_get_expiry):
        """Test certificate validation with parsing error."""
        # Simulate the certificate parser blowing up.
        mock_get_expiry.side_effect = Exception("Parse error")

        tls_manager = TLSManager()
        result = tls_manager._validate_certificates()

        assert result is False

    @patch('src.core.tls_manager.ssl')
    def test_create_ssl_context_success(self, mock_ssl):
        """Test successful SSL context creation."""
        # The ssl module seen by tls_manager is replaced wholesale, so the
        # context returned by create_default_context is a plain Mock.
        mock_context = Mock()
        mock_ssl.create_default_context.return_value = mock_context

        tls_manager = TLSManager()

        # Force validation to succeed so context creation proceeds.
        with patch.object(tls_manager, '_validate_certificates', return_value=True):
            result = tls_manager.create_ssl_context()

        assert result is mock_context

        # The context must be created for the CLIENT_AUTH purpose (a
        # server-side context). Accept the purpose whether it was passed
        # positionally or as the ``purpose=`` keyword.
        mock_ssl.create_default_context.assert_called_once()
        call_args = mock_ssl.create_default_context.call_args
        if call_args.args:
            purpose = call_args.args[0]
        else:
            purpose = call_args.kwargs.get("purpose")
        assert purpose == mock_ssl.Purpose.CLIENT_AUTH

        mock_context.load_cert_chain.assert_called_once_with(self.cert_path, self.key_path)

    def test_create_ssl_context_tls_disabled(self):
        """Test SSL context creation with TLS disabled."""
        settings.tls_enabled = False

        tls_manager = TLSManager()
        result = tls_manager.create_ssl_context()

        assert result is None

    def test_create_ssl_context_invalid_certificates(self):
        """Test SSL context creation with invalid certificates."""
        tls_manager = TLSManager()

        # Force validation to fail.
        with patch.object(tls_manager, '_validate_certificates', return_value=False):
            result = tls_manager.create_ssl_context()

        assert result is None

    def test_get_rest_api_ssl_config_success(self):
        """Test successful REST API SSL configuration."""
        tls_manager = TLSManager()

        # Force validation to succeed.
        with patch.object(tls_manager, '_validate_certificates', return_value=True):
            result = tls_manager.get_rest_api_ssl_config()

        assert result == (self.cert_path, self.key_path)

    def test_get_rest_api_ssl_config_tls_disabled(self):
        """Test REST API SSL configuration with TLS disabled."""
        settings.tls_enabled = False

        tls_manager = TLSManager()
        result = tls_manager.get_rest_api_ssl_config()

        assert result is None

    def test_get_rest_api_ssl_config_invalid_certificates(self):
        """Test REST API SSL configuration with invalid certificates."""
        tls_manager = TLSManager()

        # Force validation to fail.
        with patch.object(tls_manager, '_validate_certificates', return_value=False):
            result = tls_manager.get_rest_api_ssl_config()

        assert result is None

    def test_check_certificate_rotation_needed(self):
        """Test certificate rotation check when needed."""
        from datetime import datetime, timedelta, timezone

        tls_manager = TLSManager()

        # Seed the manager's expiry cache with a certificate that expires soon.
        near_expiry = datetime.now(timezone.utc) + timedelta(days=5)
        tls_manager.cert_expiry_dates[self.cert_path] = near_expiry

        result = tls_manager.check_certificate_rotation()

        assert result is True

    def test_check_certificate_rotation_not_needed(self):
        """Test certificate rotation check when not needed."""
        from datetime import datetime, timedelta, timezone

        tls_manager = TLSManager()

        # Seed the manager's expiry cache with a certificate valid for a year.
        distant_expiry = datetime.now(timezone.utc) + timedelta(days=365)
        tls_manager.cert_expiry_dates[self.cert_path] = distant_expiry

        result = tls_manager.check_certificate_rotation()

        assert result is False

    def test_check_certificate_rotation_tls_disabled(self):
        """Test certificate rotation check with TLS disabled."""
        settings.tls_enabled = False

        tls_manager = TLSManager()
        result = tls_manager.check_certificate_rotation()

        assert result is False

    @patch('src.core.tls_manager.os.makedirs')
    def test_generate_self_signed_certificate_success(self, mock_makedirs):
        """Test successful self-signed certificate generation."""
        tls_manager = TLSManager()

        # Only os.makedirs (decorator) and builtins.open (mock_open) are
        # stubbed; everything else in the generation path runs for real.
        with patch('builtins.open', unittest.mock.mock_open()):
            result = tls_manager.generate_self_signed_certificate(self.temp_dir)

        assert result is True
        # The fixture files ("test_cert.pem"/"test_key.pem") already satisfy
        # the suffix checks below, so also require that the manager now
        # points somewhere other than the fixture files.
        assert tls_manager.cert_path != self.cert_path
        assert tls_manager.key_path != self.key_path
        assert tls_manager.cert_path.endswith("cert.pem")
        assert tls_manager.key_path.endswith("key.pem")

    @patch('src.core.tls_manager.os.makedirs')
    def test_generate_self_signed_certificate_failure(self, mock_makedirs):
        """Test self-signed certificate generation failure."""
        # Simulate the output directory being unwritable.
        mock_makedirs.side_effect = PermissionError("Permission denied")

        tls_manager = TLSManager()
        result = tls_manager.generate_self_signed_certificate(self.temp_dir)

        assert result is False