diff --git a/backend/core/tests/test_services/__init__.py b/backend/core/tests/test_services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/core/tests/test_services/test_ai_client_service.py b/backend/core/tests/test_services/test_ai_client_service.py new file mode 100644 index 0000000..c19c94a --- /dev/null +++ b/backend/core/tests/test_services/test_ai_client_service.py @@ -0,0 +1,272 @@ +import pytest +from unittest.mock import Mock, patch + +from core.services.ai_client_service import AIClientService + + +@pytest.mark.unit +class TestAIClientService: + def test_init(self): + service = AIClientService() + + assert service.provider_factory is not None + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_with_ai_provider_id(self, mock_factory_class): + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + mock_provider = Mock() + mock_factory.create_provider.return_value = mock_provider + + mock_ai_provider = Mock() + mock_ai_provider.provider = 'gemini' + mock_ai_provider.provider_api_key = 'test_key' + mock_ai_provider.metadata = {'model': 'gemini-pro'} + + with patch('core.services.ai_client_service.AIProvider.objects.get', return_value=mock_ai_provider): + service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + ai_provider_id=1, + model='custom-model' + ) + + mock_factory.create_provider.assert_called_once_with( + provider_type='gemini', + api_key='test_key', + config={'model': 'gemini-pro'} + ) + assert provider == mock_provider + assert model == 'custom-model' + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_with_invalid_ai_provider_id(self, mock_factory_class): + from core.models import AIProvider + + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + with patch('core.services.ai_client_service.AIProvider.objects.get', side_effect=AIProvider.DoesNotExist()): 
+ service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + ai_provider_id=999 + ) + + assert provider is None + assert model is None + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_without_ai_provider_id_with_config(self, mock_factory_class): + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + mock_provider = Mock() + mock_provider.get_models.return_value = [{'name': 'gemini-pro'}] + mock_factory.create_provider.return_value = mock_provider + + mock_ai_provider = Mock() + mock_ai_provider.provider = 'gemini' + mock_ai_provider.provider_api_key = 'test_key' + mock_ai_provider.metadata = {} + + mock_config = Mock() + mock_config.ai_provider = mock_ai_provider + mock_config.external_model_id = 'gemini-pro' + + with patch.object(AIClientService, '_get_app_provider_config', return_value=mock_config): + service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + context='response', + capability='text' + ) + + mock_factory.create_provider.assert_called_once() + assert provider == mock_provider + assert model == 'gemini-pro' + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_without_ai_provider_id_without_config(self, mock_factory_class): + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + with patch.object(AIClientService, '_get_app_provider_config', return_value=None): + service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + context='response', + capability='text' + ) + + assert provider is None + assert model is None + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_fallback_to_default_model(self, mock_factory_class): + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + mock_provider = Mock() + mock_provider.get_models.return_value = [] + 
mock_factory.create_provider.return_value = mock_provider + + mock_ai_provider = Mock() + mock_ai_provider.provider = 'gemini' + mock_ai_provider.provider_api_key = 'test_key' + mock_ai_provider.metadata = {} + + mock_config = Mock() + mock_config.ai_provider = mock_ai_provider + mock_config.external_model_id = None + + with patch.object(AIClientService, '_get_app_provider_config', return_value=mock_config): + service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + context='response', + capability='text' + ) + + assert provider == mock_provider + assert model == 'default' + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_get_models_exception(self, mock_factory_class): + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + mock_provider = Mock() + mock_provider.get_models.side_effect = Exception("API error") + mock_factory.create_provider.return_value = mock_provider + + mock_ai_provider = Mock() + mock_ai_provider.provider = 'gemini' + mock_ai_provider.provider_api_key = 'test_key' + mock_ai_provider.metadata = {} + + mock_config = Mock() + mock_config.ai_provider = mock_ai_provider + mock_config.external_model_id = None + + with patch.object(AIClientService, '_get_app_provider_config', return_value=mock_config): + service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + context='response', + capability='text' + ) + + assert provider == mock_provider + assert model == 'default' + + @patch('core.services.ai_client_service.AIProviderFactory') + def test_get_client_and_model_uses_first_available_model(self, mock_factory_class): + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + mock_provider = Mock() + mock_provider.get_models.return_value = [{'name': 'gemini-pro'}, {'name': 'gemini-flash'}] + mock_factory.create_provider.return_value = mock_provider + + mock_ai_provider = Mock() + 
mock_ai_provider.provider = 'gemini' + mock_ai_provider.provider_api_key = 'test_key' + mock_ai_provider.metadata = {} + + mock_config = Mock() + mock_config.ai_provider = mock_ai_provider + mock_config.external_model_id = None + + with patch.object(AIClientService, '_get_app_provider_config', return_value=mock_config): + service = AIClientService() + provider, model = service.get_client_and_model( + app=Mock(), + context='response', + capability='text' + ) + + assert provider == mock_provider + assert model == 'gemini-pro' + + @patch('core.services.ai_client_service.AppAIProvider.objects') + def test_get_app_provider_config_builtin_first(self, mock_objects): + mock_app = Mock() + mock_config = Mock() + mock_config.ai_provider = Mock() + + mock_queryset = Mock() + mock_queryset.select_related.return_value.first.return_value = mock_config + mock_objects.filter.return_value = mock_queryset + + service = AIClientService() + config = service._get_app_provider_config(mock_app, 'response', 'text') + + mock_objects.filter.assert_called_once_with( + application=mock_app, + context='response', + capability='text', + is_active=True, + ai_provider__is_builtin=True + ) + assert config == mock_config + + @patch('core.services.ai_client_service.AppAIProvider.objects') + def test_get_app_provider_config_fallback_to_non_builtin(self, mock_objects): + mock_app = Mock() + mock_config = Mock() + mock_config.ai_provider = Mock() + + mock_queryset_builtin = Mock() + mock_queryset_builtin.select_related.return_value.first.return_value = None + + mock_queryset_fallback = Mock() + mock_queryset_fallback.select_related.return_value.order_by.return_value.first.return_value = mock_config + + mock_objects.filter.side_effect = [mock_queryset_builtin, mock_queryset_fallback] + + service = AIClientService() + config = service._get_app_provider_config(mock_app, 'response', 'text') + + assert mock_objects.filter.call_count == 2 + assert config == mock_config + + 
@patch('core.services.ai_client_service.AppAIProvider.objects') + def test_get_app_provider_config_no_config_found(self, mock_objects): + mock_app = Mock() + + mock_queryset_builtin = Mock() + mock_queryset_builtin.select_related.return_value.first.return_value = None + + mock_queryset_fallback = Mock() + mock_queryset_fallback.select_related.return_value.order_by.return_value.first.return_value = None + + mock_objects.filter.side_effect = [mock_queryset_builtin, mock_queryset_fallback] + + service = AIClientService() + config = service._get_app_provider_config(mock_app, 'response', 'text') + + assert config is None + + @patch('core.services.ai_client_service.AppAIProvider.objects') + def test_get_app_provider_config_with_different_context_and_capability(self, mock_objects): + mock_app = Mock() + mock_config = Mock() + mock_config.ai_provider = Mock() + + mock_queryset = Mock() + mock_queryset.select_related.return_value.first.return_value = mock_config + mock_objects.filter.return_value = mock_queryset + + service = AIClientService() + config = service._get_app_provider_config(mock_app, 'embedding', 'embedding') + + mock_objects.filter.assert_called_once_with( + application=mock_app, + context='embedding', + capability='embedding', + is_active=True, + ai_provider__is_builtin=True + ) + assert config == mock_config diff --git a/backend/core/tests/test_services/test_ai_provider_factory.py b/backend/core/tests/test_services/test_ai_provider_factory.py new file mode 100644 index 0000000..9d3737a --- /dev/null +++ b/backend/core/tests/test_services/test_ai_provider_factory.py @@ -0,0 +1,136 @@ +import pytest +from unittest.mock import Mock, patch + +from core.services.factories.ai_provider_factory import AIProviderFactory + + +@pytest.mark.unit +class TestAIProviderFactory: + def test_provider_classes_contains_expected_providers(self): + assert 'gemini' in AIProviderFactory.PROVIDER_CLASSES + assert 'custom' in AIProviderFactory.PROVIDER_CLASSES + + def 
test_create_provider_gemini(self): + mock_provider = Mock() + mock_gemini_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + provider = AIProviderFactory.create_provider('gemini', 'test_api_key', {'model': 'gemini-pro'}) + + mock_gemini_class.assert_called_once_with(api_key='test_api_key', config={'model': 'gemini-pro'}) + assert provider == mock_provider + + def test_create_provider_custom(self): + mock_provider = Mock() + mock_custom_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'custom': mock_custom_class}): + provider = AIProviderFactory.create_provider('custom', 'test_api_key', {'base_url': 'https://api.example.com'}) + + mock_custom_class.assert_called_once_with(api_key='test_api_key', config={'base_url': 'https://api.example.com'}) + assert provider == mock_provider + + def test_create_provider_case_insensitive(self): + mock_provider = Mock() + mock_gemini_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + provider = AIProviderFactory.create_provider('GEMINI', 'test_api_key') + + mock_gemini_class.assert_called_once_with(api_key='test_api_key', config={}) + assert provider == mock_provider + + def test_create_provider_with_none_config(self): + mock_provider = Mock() + mock_gemini_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + provider = AIProviderFactory.create_provider('gemini', 'test_api_key', None) + + mock_gemini_class.assert_called_once_with(api_key='test_api_key', config={}) + assert provider == mock_provider + + def test_create_provider_unsupported_type(self): + with pytest.raises(ValueError) as exc_info: + AIProviderFactory.create_provider('unsupported_provider', 'test_api_key') + + assert "Unsupported provider type: unsupported_provider" 
in str(exc_info.value) + assert "Supported providers:" in str(exc_info.value) + + def test_create_provider_initialization_error(self): + mock_gemini_class = Mock(side_effect=Exception("Initialization failed")) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + with pytest.raises(ValueError) as exc_info: + AIProviderFactory.create_provider('gemini', 'test_api_key') + + assert "Failed to create gemini provider: Initialization failed" in str(exc_info.value) + + def test_validate_provider_gemini(self): + mock_provider = Mock() + mock_provider.validate_connection.return_value = (True, [{'test': 'result'}]) + mock_gemini_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + is_valid, result = AIProviderFactory.validate_provider('gemini', 'test_api_key', {'model': 'gemini-pro'}) + + assert is_valid is True + assert result == [{'test': 'result'}] + mock_provider.validate_connection.assert_called_once() + + def test_validate_provider_custom(self): + mock_provider = Mock() + mock_provider.validate_connection.return_value = (False, []) + mock_custom_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'custom': mock_custom_class}): + is_valid, result = AIProviderFactory.validate_provider('custom', 'test_api_key') + + assert is_valid is False + assert result == [] + mock_provider.validate_connection.assert_called_once() + + def test_validate_provider_case_insensitive(self): + mock_provider = Mock() + mock_provider.validate_connection.return_value = (True, []) + mock_gemini_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + is_valid, result = AIProviderFactory.validate_provider('GEMINI', 'test_api_key') + + assert is_valid is True + assert result == [] + + def test_validate_provider_with_none_config(self): + mock_provider 
= Mock() + mock_provider.validate_connection.return_value = (True, []) + mock_gemini_class = Mock(return_value=mock_provider) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + is_valid, result = AIProviderFactory.validate_provider('gemini', 'test_api_key', None) + + assert is_valid is True + assert result == [] + + def test_validate_provider_unsupported_type(self): + with pytest.raises(ValueError) as exc_info: + AIProviderFactory.validate_provider('unsupported_provider', 'test_api_key') + + assert "Unsupported provider type: unsupported_provider" in str(exc_info.value) + assert "Supported providers:" in str(exc_info.value) + + def test_validate_provider_exception_handling(self): + mock_gemini_class = Mock(side_effect=Exception("Validation failed")) + + with patch.object(AIProviderFactory, 'PROVIDER_CLASSES', {'gemini': mock_gemini_class}): + is_valid, result = AIProviderFactory.validate_provider('gemini', 'test_api_key') + + assert is_valid is False + assert result == [] + + def test_get_supported_providers(self): + providers = AIProviderFactory.get_supported_providers() + + assert isinstance(providers, list) + assert 'gemini' in providers + assert 'custom' in providers diff --git a/backend/core/tests/test_services/test_content_quality_filter.py b/backend/core/tests/test_services/test_content_quality_filter.py new file mode 100644 index 0000000..0d7e704 --- /dev/null +++ b/backend/core/tests/test_services/test_content_quality_filter.py @@ -0,0 +1,331 @@ +import pytest +import emoji +from unittest.mock import Mock, patch + +from core.services.content_quality_filter import ( + EmojiDetector, + BotCommentDetector, + BoilerplateDetector, + ContentTypeHandler, + ContentQualityFilter +) + + +@pytest.mark.unit +class TestEmojiDetector: + def test_is_emoji_only_with_emoji_only(self): + detector = EmojiDetector() + assert detector.is_emoji_only('👍') is True + + def test_is_emoji_only_with_plus_one(self): + detector = 
EmojiDetector() + assert detector.is_emoji_only('+1') is True + + def test_is_emoji_only_with_text(self): + detector = EmojiDetector() + assert detector.is_emoji_only('Hello world') is False + + def test_is_emoji_only_with_mixed_content(self): + detector = EmojiDetector() + assert detector.is_emoji_only('Hello 👍') is False + + def test_is_emoji_only_with_empty_string(self): + detector = EmojiDetector() + assert detector.is_emoji_only('') is False + + def test_is_emoji_only_with_whitespace(self): + detector = EmojiDetector() + assert detector.is_emoji_only(' ') is False + + def test_is_emoji_only_with_multiple_emojis(self): + detector = EmojiDetector() + assert detector.is_emoji_only('👍🎉❤️') is True + + def test_remove_emojis(self): + detector = EmojiDetector() + result = detector.remove_emojis('Hello 👍 world') + assert result == 'Hello world' + + def test_remove_emojis_removes_plus_one(self): + detector = EmojiDetector() + result = detector.remove_emojis('Thanks +1') + assert result == 'Thanks' + + def test_remove_emojis_with_empty_string(self): + detector = EmojiDetector() + result = detector.remove_emojis('') + assert result == '' + + def test_remove_emojis_with_whitespace(self): + detector = EmojiDetector() + result = detector.remove_emojis(' ') + assert result == ' ' + + def test_remove_emojis_normalizes_whitespace(self): + detector = EmojiDetector() + result = detector.remove_emojis('Hello 👍 world') + assert result == 'Hello world' + + +@pytest.mark.unit +class TestBotCommentDetector: + def test_is_bot_comment_with_bot_author(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(author='dependabot[bot]') is True + + def test_is_bot_comment_with_codecov_author(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(author='codecov[bot]') is True + + def test_is_bot_comment_with_github_actions_author(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(author='github-actions[bot]') is True + + def 
test_is_bot_comment_with_bot_content(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(content='This PR was merged') is True + + def test_is_bot_comment_with_build_passed_content(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(content='Build passed') is True + + def test_is_bot_comment_with_ci_cd_content(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(content='CI/CD pipeline') is True + + def test_is_bot_comment_with_normal_author(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(author='john-doe') is False + + def test_is_bot_comment_with_normal_content(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(content='This looks good to me') is False + + def test_is_bot_comment_with_none_author_and_content(self): + detector = BotCommentDetector() + assert detector.is_bot_comment() is False + + def test_is_bot_comment_case_insensitive(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(author='DEPENDABOT[bot]') is True + + def test_is_bot_comment_with_merge_conflict(self): + detector = BotCommentDetector() + assert detector.is_bot_comment(content='Merge conflict detected') is True + + +@pytest.mark.unit +class TestBoilerplateDetector: + def test_is_boilerplate_with_copyright(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('Copyright 2024 Company') is True + + def test_is_boilerplate_with_all_rights_reserved(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('All rights reserved') is True + + def test_is_boilerplate_with_licensed_under(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('Licensed under MIT') is True + + def test_is_boilerplate_with_generated_by(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('Generated by tool') is True + + def test_is_boilerplate_with_do_not_edit(self): + detector 
= BoilerplateDetector() + assert detector.is_boilerplate_content('DO NOT EDIT THIS FILE') is True + + def test_is_boilerplate_with_template_comment(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('# This is a template file') is True + + def test_is_boilerplate_with_normal_content(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('This is normal code') is False + + def test_is_boilerplate_with_empty_string(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('') is False + + def test_is_boilerplate_with_whitespace(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content(' ') is False + + def test_is_boilerplate_with_java_import(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('import java.util.List;') is True + + def test_is_boilerplate_with_python_import(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('from typing import List') is True + + def test_is_boilerplate_case_insensitive(self): + detector = BoilerplateDetector() + assert detector.is_boilerplate_content('COPYRIGHT 2024') is True + + +@pytest.mark.unit +class TestContentTypeHandler: + def test_init(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler.boilerplate_detector == boilerplate_detector + assert 'text' in handler._content_handlers + assert 'file' in handler._content_handlers + assert 'github_issue' in handler._content_handlers + assert 'github_pr' in handler._content_handlers + + def test_should_ingest_text_normal_content(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler._should_ingest_text('This is normal text') is True + + def test_should_ingest_text_boilerplate_content(self): + boilerplate_detector = BoilerplateDetector() + handler = 
ContentTypeHandler(boilerplate_detector) + + assert handler._should_ingest_text('Copyright 2024') is False + + def test_should_ingest_file_normal_content(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler._should_ingest_file('def function():\n pass') is True + + def test_should_ingest_file_boilerplate_content(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler._should_ingest_file('Copyright 2024') is False + + def test_should_ingest_file_header_comment_block(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + content = '/**\n * @file\n * @brief\n */' + assert handler._should_ingest_file(content) is False + + def test_should_ingest_file_python_docstring(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + content = '"""Module docstring"""' + assert handler._should_ingest_file(content) is False + + def test_should_ingest_file_generated_marker(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + content = '# This file is generated' + assert handler._should_ingest_file(content) is False + + def test_should_ingest_github_issue(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler._should_ingest_github_issue('Any content') is True + + def test_should_ingest_github_pr(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler._should_ingest_github_pr('Any content') is True + + def test_should_ingest_content_type_text(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler.should_ingest_content_type('Normal text', 'text') is True + + def 
test_should_ingest_content_type_file(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler.should_ingest_content_type('def foo():', 'file') is True + + def test_should_ingest_content_type_unknown_defaults_to_text(self): + boilerplate_detector = BoilerplateDetector() + handler = ContentTypeHandler(boilerplate_detector) + + assert handler.should_ingest_content_type('Normal text', 'unknown') is True + + +@pytest.mark.unit +class TestContentQualityFilter: + def test_init(self): + filter = ContentQualityFilter() + + assert filter.emoji_detector is not None + assert filter.bot_detector is not None + assert filter.boilerplate_detector is not None + assert filter.content_type_handler is not None + + def test_calculate_quality_score_empty_content(self): + filter = ContentQualityFilter() + assert filter.calculate_quality_score('') == 0.0 + + def test_calculate_quality_score_short_content(self): + filter = ContentQualityFilter() + score = filter.calculate_quality_score('Hi') + assert score == 0.3 + + def test_calculate_quality_score_long_content(self): + filter = ContentQualityFilter() + score = filter.calculate_quality_score('This is a longer message') + assert score >= 0.4 + + def test_calculate_quality_score_with_filler_words(self): + filter = ContentQualityFilter() + score = filter.calculate_quality_score('thanks thank you lgmt ack nice good great') + assert score == 0.4 + + def test_calculate_quality_score_with_meaningful_words(self): + filter = ContentQualityFilter() + score = filter.calculate_quality_score('This function fixes the API error') + assert score >= 0.7 + + def test_calculate_quality_score_max_score(self): + filter = ContentQualityFilter() + score = filter.calculate_quality_score('This function fixes the API error in the code implementation') + assert score == 1.0 + + def test_should_ingest_emoji_only(self): + filter = ContentQualityFilter() + assert filter.should_ingest('👍') is False + 
+ def test_should_ingest_bot_comment(self): + filter = ContentQualityFilter() + assert filter.should_ingest('This PR was merged') is False + + def test_should_ingest_bot_author(self): + filter = ContentQualityFilter() + assert filter.should_ingest('Normal comment', author='dependabot[bot]') is False + + def test_should_ingest_normal_content(self): + filter = ContentQualityFilter() + assert filter.should_ingest('This is a normal comment') is True + + def test_should_ingest_boilerplate_content(self): + filter = ContentQualityFilter() + assert filter.should_ingest('Copyright 2024', content_type='text') is False + + def test_should_ingest_file_header(self): + filter = ContentQualityFilter() + content = '/**\n * @file\n */' + assert filter.should_ingest(content, content_type='file') is False + + def test_should_ingest_github_issue(self): + filter = ContentQualityFilter() + assert filter.should_ingest('Issue description', content_type='github_issue') is True + + def test_should_ingest_github_pr(self): + filter = ContentQualityFilter() + assert filter.should_ingest('PR description', content_type='github_pr') is True + + def test_remove_emojis(self): + filter = ContentQualityFilter() + result = filter.remove_emojis('Hello 👍 world') + assert result == 'Hello world' diff --git a/backend/core/tests/test_services/test_custom_provider.py b/backend/core/tests/test_services/test_custom_provider.py new file mode 100644 index 0000000..34ef019 --- /dev/null +++ b/backend/core/tests/test_services/test_custom_provider.py @@ -0,0 +1,142 @@ +import pytest +from unittest.mock import Mock, patch +from pydantic import BaseModel + +from core.services.providers.ai.custom_provider import CustomProvider + + +class MockResponseSchema(BaseModel): + answer: str + + +@pytest.mark.unit +class TestCustomProvider: + def test_init_raises_not_implemented(self): + with pytest.raises(NotImplementedError) as exc_info: + CustomProvider('test_api_key') + + assert "Not implemented" in str(exc_info.value) + + 
def test_init_with_config_raises_not_implemented(self): + with pytest.raises(NotImplementedError) as exc_info: + CustomProvider('test_api_key', {'model': 'custom'}) + + assert "Not implemented" in str(exc_info.value) + + def test_generate_text_raises_not_implemented(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + + with pytest.raises(NotImplementedError) as exc_info: + provider.generate_text('custom-model', 'test content') + + assert "Not implemented" in str(exc_info.value) + + def test_generate_with_conversation_with_user_messages(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + provider.generate_text = Mock(return_value='{"answer": "response"}') + + messages = [ + {'role': 'system', 'content': 'You are helpful'}, + {'role': 'user', 'content': 'Hello'}, + {'role': 'assistant', 'content': 'Hi there'}, + {'role': 'user', 'content': 'How are you?'} + ] + + result, tool_calls = provider.generate_with_conversation( + 'custom-model', + messages, + None, + MockResponseSchema + ) + + provider.generate_text.assert_called_once_with('custom-model', 'How are you?') + assert result == '{"answer": "response"}' + assert tool_calls == [] + + def test_generate_with_conversation_single_user_message(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + provider.generate_text = Mock(return_value='{"answer": "response"}') + + messages = [ + {'role': 'user', 'content': 'Hello'} + ] + + result, tool_calls = provider.generate_with_conversation( + 'custom-model', + messages, + None, + MockResponseSchema + ) + + provider.generate_text.assert_called_once_with('custom-model', 'Hello') + assert result == '{"answer": "response"}' + assert tool_calls == [] + + def test_generate_with_conversation_no_user_messages(self): + 
with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + provider.generate_text = Mock(return_value='{"answer": "response"}') + + messages = [ + {'role': 'system', 'content': 'You are helpful'}, + {'role': 'assistant', 'content': 'Hi there'} + ] + + result, tool_calls = provider.generate_with_conversation( + 'custom-model', + messages, + None, + MockResponseSchema + ) + + provider.generate_text.assert_called_once_with('custom-model', '') + assert result == '{"answer": "response"}' + assert tool_calls == [] + + def test_generate_with_conversation_empty_messages(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + provider.generate_text = Mock(return_value='{"answer": "response"}') + + messages = [] + + result, tool_calls = provider.generate_with_conversation( + 'custom-model', + messages, + None, + MockResponseSchema + ) + + provider.generate_text.assert_called_once_with('custom-model', '') + assert result == '{"answer": "response"}' + assert tool_calls == [] + + def test_validate_connection_raises_not_implemented(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + + with pytest.raises(NotImplementedError) as exc_info: + provider.validate_connection() + + assert "Not implemented" in str(exc_info.value) + + def test_get_models_raises_not_implemented(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + + with pytest.raises(NotImplementedError) as exc_info: + provider.get_models() + + assert "Not implemented" in str(exc_info.value) + + def test_embed_raises_not_implemented(self): + with patch.object(CustomProvider, '__init__', lambda self, api_key, config=None: None): + provider = CustomProvider('test_api_key') + + with 
pytest.raises(NotImplementedError) as exc_info: + provider.embed('embedding-model', ['test text']) + + assert "Not implemented" in str(exc_info.value) diff --git a/backend/core/tests/test_services/test_duplicate_detector.py b/backend/core/tests/test_services/test_duplicate_detector.py new file mode 100644 index 0000000..35609c7 --- /dev/null +++ b/backend/core/tests/test_services/test_duplicate_detector.py @@ -0,0 +1,533 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from django.core.exceptions import ObjectDoesNotExist + +from core.services.duplicate_detector import DuplicateDetector +from core.models.content_hash import ContentHash +from core.models.knowledge_base import KnowledgeBase + + +@pytest.mark.unit +class TestDuplicateDetector: + def test_init(self): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + assert detector._ai_client_service is not None + assert detector._quality_filter is not None + assert detector._replacement_triggered is False + + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_get_content_fingerprint(self, mock_generate_hash): + mock_generate_hash.return_value = 'test_hash_123' + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + result = detector.get_content_fingerprint('test content') + + mock_generate_hash.assert_called_once_with('test content') + assert result == 'test_hash_123' + + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_get_embedding_cached(self, mock_generate_hash, mock_get): + mock_cached_hash = Mock() + mock_cached_hash.embedding = [0.1, 0.2, 0.3] + mock_generate_hash.return_value = 'hash123' + mock_get.return_value = 
mock_cached_hash + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector._get_embedding('test content', mock_app) + + assert result == [0.1, 0.2, 0.3] + mock_get.assert_called_once() + + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.update_or_create') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_get_embedding_not_cached(self, mock_generate_hash, mock_update_or_create, mock_get): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + mock_update_or_create.return_value = (Mock(), True) + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_provider = Mock() + mock_provider.embed.return_value = [[0.1, 0.2, 0.3]] + mock_ai_service.return_value.get_client_and_model.return_value = (mock_provider, 'model') + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector._get_embedding('test content', mock_app) + + assert result == [0.1, 0.2, 0.3] + mock_update_or_create.assert_called_once() + + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_get_embedding_no_provider(self, mock_generate_hash, mock_get): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_ai_service.return_value.get_client_and_model.return_value = (None, None) + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector._get_embedding('test 
content', mock_app) + + assert result is None + + @patch('core.services.duplicate_detector.ContentHash') + def test_generate_new_embedding_success(self, mock_content_hash): + mock_content_hash.generate_content_hash.return_value = 'hash123' + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_provider = Mock() + mock_provider.embed.return_value = [[0.1, 0.2, 0.3]] + mock_ai_service.return_value.get_client_and_model.return_value = (mock_provider, 'model') + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector._generate_new_embedding('test content', mock_app) + + assert result == [0.1, 0.2, 0.3] + + @patch('core.services.duplicate_detector.ContentHash') + def test_generate_new_embedding_failure(self, mock_content_hash): + mock_content_hash.generate_content_hash.return_value = 'hash123' + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_ai_service.return_value.get_client_and_model.side_effect = Exception("API error") + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector._generate_new_embedding('test content', mock_app) + + assert result is None + + def test_cosine_similarity_identical_vectors(self): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + vec1 = [1.0, 2.0, 3.0] + vec2 = [1.0, 2.0, 3.0] + + result = detector._cosine_similarity(vec1, vec2) + + assert result == 1.0 + + def test_cosine_similarity_orthogonal_vectors(self): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + vec1 = [1.0, 0.0] + vec2 = [0.0, 1.0] + + result = detector._cosine_similarity(vec1, vec2) + + 
assert result == 0.0 + + def test_cosine_similarity_empty_vectors(self): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + result = detector._cosine_similarity([], []) + + assert result == 0.0 + + def test_cosine_similarity_mismatched_lengths(self): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + result = detector._cosine_similarity([1.0, 2.0], [1.0]) + + assert result == 0.0 + + + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_find_similar_content_no_embedding(self, mock_generate_hash, mock_get): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_ai_service.return_value.get_client_and_model.return_value = (None, None) + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.find_similar_content('test content', mock_app) + + assert result == [] + + @patch('core.services.duplicate_detector.KnowledgeBase') + @patch('core.services.duplicate_detector.ContentHash') + def test_find_similar_content_empty_content(self, mock_content_hash, mock_kb): + detector = DuplicateDetector() + + result = detector.find_similar_content('', Mock()) + + assert result == [] + + + @patch('core.services.duplicate_detector.KnowledgeBase.objects.filter') + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.update_or_create') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def 
test_is_semantic_duplicate_no_similar(self, mock_generate_hash, mock_update_or_create, mock_get, mock_kb_filter): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + mock_update_or_create.return_value = (Mock(), True) + + mock_kb_filter.return_value.exclude.return_value = [] + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_provider = Mock() + mock_provider.embed.return_value = [[0.1, 0.2, 0.3]] + mock_ai_service.return_value.get_client_and_model.return_value = (mock_provider, 'model') + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.is_semantic_duplicate('test content', mock_app) + + assert result is False + + @patch('core.services.duplicate_detector.KnowledgeBase') + @patch('core.services.duplicate_detector.ContentHash') + def test_should_replace_content_quality_improvement(self, mock_content_hash, mock_kb): + mock_kb_obj = Mock() + mock_kb_obj.metadata = {'content': 'old low quality content'} + mock_kb.objects.get.return_value = mock_kb_obj + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter') as mock_quality: + mock_quality.return_value.calculate_quality_score.side_effect = [0.8, 0.5] + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.should_replace_content('new high quality content', 'kb-uuid', mock_app) + + assert result is True + + @patch('core.services.duplicate_detector.KnowledgeBase') + @patch('core.services.duplicate_detector.ContentHash') + def test_should_replace_content_no_improvement(self, mock_content_hash, mock_kb): + mock_kb_obj = Mock() + mock_kb_obj.metadata = {'content': 'old high quality content'} + mock_kb.objects.get.return_value = mock_kb_obj + + with patch('core.services.duplicate_detector.AIClientService'), \ + 
patch('core.services.duplicate_detector.ContentQualityFilter') as mock_quality: + mock_quality.return_value.calculate_quality_score.side_effect = [0.5, 0.8] + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.should_replace_content('new low quality content', 'kb-uuid', mock_app) + + assert result is False + + @patch('core.services.duplicate_detector.KnowledgeBase.objects.get') + def test_should_replace_content_kb_not_found(self, mock_get): + mock_get.side_effect = KnowledgeBase.DoesNotExist() + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.should_replace_content('new content', 'kb-uuid', mock_app) + + assert result is True + + @patch('core.services.duplicate_detector.KnowledgeBase') + @patch('core.services.duplicate_detector.ContentHash') + def test_replace_content_success(self, mock_content_hash, mock_kb): + mock_kb_obj = Mock() + mock_kb_obj.metadata = {} + mock_kb_obj.chunks.all.return_value = [] + mock_kb.objects.get.return_value = mock_kb_obj + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.replace_content('kb-uuid', 'new content', mock_app) + + assert result is True + mock_kb_obj.save.assert_called_once() + + @patch('core.services.duplicate_detector.KnowledgeBase.objects.get') + def test_replace_content_kb_not_found(self, mock_get): + mock_get.side_effect = Exception("DoesNotExist") + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.replace_content('kb-uuid', 'new content', mock_app) + + assert result is False + + def test_cleanup_old_content(self): + 
mock_kb_obj = Mock() + mock_kb_obj.uuid = 'kb-uuid' + mock_kb_obj.metadata = {'content': 'old content'} + mock_kb_obj.application = Mock() + mock_chunks_qs = Mock() + mock_chunks_qs.delete.return_value = (0, {}) + mock_chunks_qs.__iter__ = Mock(return_value=iter([])) + mock_kb_obj.chunks.all.return_value = mock_chunks_qs + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'), \ + patch('core.services.ingestion.delete_vectors_from_qdrant'): + detector = DuplicateDetector() + detector.remove_content_hash = Mock() + + result = detector._cleanup_old_content(mock_kb_obj) + + assert result is True + + def test_cleanup_old_content_error(self): + mock_kb_obj = Mock() + mock_kb_obj.metadata = {} + mock_kb_obj.chunks.all.side_effect = Exception("Error") + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + result = detector._cleanup_old_content(mock_kb_obj) + + assert result is False + + @patch('core.services.duplicate_detector.KnowledgeBase.objects.filter') + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.update_or_create') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_handle_semantic_duplicate_no_similar(self, mock_generate_hash, mock_update_or_create, mock_get, mock_kb_filter): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + mock_update_or_create.return_value = (Mock(), True) + + mock_kb_filter.return_value.exclude.return_value = [] + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + mock_provider = Mock() + mock_provider.embed.return_value = [[0.1, 0.2, 0.3]] + 
mock_ai_service.return_value.get_client_and_model.return_value = (mock_provider, 'model') + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.handle_semantic_duplicate('new content', mock_app, 'kb-uuid') + + assert result is True + + @patch('core.services.duplicate_detector.KnowledgeBase.objects.filter') + @patch('core.services.duplicate_detector.KnowledgeBase.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.update_or_create') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_handle_semantic_duplicate_replace(self, mock_generate_hash, mock_update_or_create, mock_get, mock_kb_get, mock_kb_filter): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + mock_update_or_create.return_value = (Mock(), True) + + mock_kb_obj = Mock() + mock_kb_obj.uuid = 'similar-kb-uuid' + mock_kb_obj.metadata = {'content': 'old content'} + mock_kb_obj.chunks.all.return_value = [] + mock_kb_filter.return_value.exclude.return_value = [mock_kb_obj] + mock_kb_get.return_value = mock_kb_obj + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter') as mock_quality, \ + patch('core.services.ingestion.delete_vectors_from_qdrant'): + mock_provider = Mock() + mock_provider.embed.side_effect = [[[0.1, 0.2, 0.3]], [[0.1, 0.2, 0.3]]] + mock_ai_service.return_value.get_client_and_model.return_value = (mock_provider, 'model') + mock_quality.return_value.calculate_quality_score.side_effect = [0.8, 0.5] + + detector = DuplicateDetector() + detector.remove_content_hash = Mock() + mock_app = Mock() + + result = detector.handle_semantic_duplicate('new content', mock_app, 'kb-uuid') + + assert result is False + assert detector._was_replacement_triggered() is True + + 
@patch('core.services.duplicate_detector.KnowledgeBase.objects.filter') + @patch('core.services.duplicate_detector.KnowledgeBase.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.objects.update_or_create') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_handle_semantic_duplicate_keep_existing(self, mock_generate_hash, mock_update_or_create, mock_get, mock_kb_get, mock_kb_filter): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + mock_update_or_create.return_value = (Mock(), True) + + mock_kb_obj = Mock() + mock_kb_obj.uuid = 'similar-kb-uuid' + mock_kb_obj.metadata = {'content': 'high quality content'} + mock_kb_filter.return_value.exclude.return_value = [mock_kb_obj] + mock_kb_get.return_value = mock_kb_obj + + with patch('core.services.duplicate_detector.AIClientService') as mock_ai_service, \ + patch('core.services.duplicate_detector.ContentQualityFilter') as mock_quality: + mock_provider = Mock() + mock_provider.embed.side_effect = [[[0.1, 0.2, 0.3]], [[0.9, 0.8, 0.7]]] + mock_ai_service.return_value.get_client_and_model.return_value = (mock_provider, 'model') + mock_quality.return_value.calculate_quality_score.side_effect = [0.5, 0.8] + + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.handle_semantic_duplicate('new content', mock_app, 'kb-uuid') + + assert result is False + assert detector._was_replacement_triggered() is False + + @patch('core.services.duplicate_detector.ContentHash') + def test_is_duplicate(self, mock_content_hash): + mock_content_hash.generate_content_hash.return_value = 'hash123' + mock_content_hash.objects.get.return_value = Mock() + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = 
detector.is_duplicate('test content', mock_app) + + assert result is True + + @patch('core.services.duplicate_detector.ContentHash.objects.get') + @patch('core.services.duplicate_detector.ContentHash.generate_content_hash') + def test_is_duplicate_not_found(self, mock_generate_hash, mock_get): + mock_generate_hash.return_value = 'hash123' + mock_get.side_effect = ContentHash.DoesNotExist() + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.is_duplicate('test content', mock_app) + + assert result is False + + @patch('core.services.duplicate_detector.ContentHash') + def test_is_duplicate_empty_content(self, mock_content_hash): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + result = detector.is_duplicate('', Mock()) + + assert result is False + + @patch('core.services.duplicate_detector.ContentHash') + def test_store_content_hash(self, mock_content_hash): + mock_content_hash.generate_content_hash.return_value = 'hash123' + mock_content_hash.objects.create.return_value = Mock() + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.store_content_hash('test content', mock_app) + + assert result is True + mock_content_hash.objects.create.assert_called_once() + + @patch('core.services.duplicate_detector.ContentHash') + def test_store_content_hash_empty(self, mock_content_hash): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + + result = detector.store_content_hash('', Mock()) + + assert result is False + + 
@patch('core.services.duplicate_detector.ContentHash') + def test_remove_content_hash(self, mock_content_hash): + mock_content_hash.generate_content_hash.return_value = 'hash123' + mock_content_hash.objects.filter.return_value.delete.return_value = (1, {}) + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.remove_content_hash('test content', mock_app) + + assert result is True + + @patch('core.services.duplicate_detector.ContentHash') + def test_remove_content_hash_not_found(self, mock_content_hash): + mock_content_hash.generate_content_hash.return_value = 'hash123' + mock_content_hash.objects.filter.return_value.delete.return_value = (0, {}) + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.remove_content_hash('test content', mock_app) + + assert result is False + + @patch('core.services.duplicate_detector.ContentHash') + def test_get_duplicate_stats(self, mock_content_hash): + mock_content_hash.objects.filter.return_value.count.return_value = 10 + mock_content_hash.objects.filter.return_value.values_list.return_value.distinct.return_value = ['text', 'file'] + + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = DuplicateDetector() + mock_app = Mock() + + result = detector.get_duplicate_stats(mock_app) + + assert result['total_unique_content'] == 10 + assert 'content_type_breakdown' in result + + @patch('core.services.duplicate_detector.ContentHash') + def test_get_duplicate_stats_no_app(self, mock_content_hash): + with patch('core.services.duplicate_detector.AIClientService'), \ + patch('core.services.duplicate_detector.ContentQualityFilter'): + detector = 
DuplicateDetector() + + result = detector.get_duplicate_stats(None) + + assert result == {} diff --git a/backend/core/tests/test_services/test_encryption.py b/backend/core/tests/test_services/test_encryption.py new file mode 100644 index 0000000..1d6c2a2 --- /dev/null +++ b/backend/core/tests/test_services/test_encryption.py @@ -0,0 +1,237 @@ +import pytest +from unittest.mock import Mock, patch +from datetime import datetime, timedelta +import json +import base64 +from cryptography.fernet import Fernet + +from core.services.encryption import ( + encrypt, + decrypt, + _get_fernet, + generate_verification_token, + verify_verification_token +) + +VALID_FERNET_KEY = Fernet.generate_key().decode() + +@pytest.mark.unit +class TestEncrypt: + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_encrypt_string_success(self): + result = encrypt('test string') + + assert isinstance(result, str) + assert result != 'test string' + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_encrypt_dict_success(self): + data = {'key1': 'value1', 'key2': 'value2'} + result = encrypt(data) + + assert isinstance(result, str) + result_dict = json.loads(result) + assert 'key1' in result_dict + assert 'key2' in result_dict + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_encrypt_empty_dict(self): + result = encrypt({}) + + assert result == json.dumps({}) + + @patch('core.services.encryption.key', None) + def test_encrypt_missing_key(self): + with pytest.raises(ValueError, match="Missing SECRET_ENCRYPTION_KEY in settings"): + encrypt('test string') + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_encrypt_invalid_type(self): + with pytest.raises(TypeError, match="Data must be a dict or str"): + encrypt(123) + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_encrypt_dict_with_numeric_values(self): + data = 
{'key1': 123, 'key2': 456.78} + result = encrypt(data) + + assert isinstance(result, str) + result_dict = json.loads(result) + assert 'key1' in result_dict + assert 'key2' in result_dict + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_encrypt_empty_string(self): + result = encrypt('') + + assert isinstance(result, str) + assert result != '' + + +@pytest.mark.unit +class TestDecrypt: + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_decrypt_string_success(self): + encrypted = encrypt('test string') + result = decrypt(encrypted) + + assert result == 'test string' + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_decrypt_dict_success(self): + data = {'key1': 'value1', 'key2': 'value2'} + encrypted = encrypt(data) + result = decrypt(encrypted) + + assert result == data + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_decrypt_empty_dict(self): + encrypted = encrypt({}) + result = decrypt(encrypted) + + assert result == {} + + @patch('core.services.encryption.key', None) + def test_decrypt_missing_key(self): + with pytest.raises(ValueError, match="Missing SECRET_ENCRYPTION_KEY in settings"): + decrypt('encrypted_string') + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_decrypt_invalid_string(self): + result = decrypt('invalid_encrypted_string') + + assert result == {} + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_decrypt_round_trip_dict(self): + original = {'username': 'testuser', 'email': 'test@example.com', 'age': '30'} + encrypted = encrypt(original) + decrypted = decrypt(encrypted) + + assert decrypted == original + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_decrypt_round_trip_string(self): + original = 'This is a test string with special characters: !@#$%^&*()' + encrypted = 
encrypt(original) + decrypted = decrypt(encrypted) + + assert decrypted == original + + +@pytest.mark.unit +class TestGetFernet: + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_get_fernet_success(self): + fernet = _get_fernet() + + assert fernet is not None + + @patch('core.services.encryption.key', None) + def test_get_fernet_missing_key(self): + with pytest.raises(ValueError, match="Missing SECRET_ENCRYPTION_KEY in settings"): + _get_fernet() + + +@pytest.mark.unit +class TestGenerateVerificationToken: + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_generate_verification_token_success(self): + token = generate_verification_token(123, 'test@example.com') + + assert isinstance(token, str) + assert len(token) > 0 + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_generate_verification_token_includes_payload(self): + token = generate_verification_token(456, 'user@test.com') + + decoded = base64.urlsafe_b64decode(token.encode()) + assert decoded is not None + + @patch('core.services.encryption.key', None) + def test_generate_verification_token_missing_key(self): + with pytest.raises(ValueError, match="Missing SECRET_ENCRYPTION_KEY in settings"): + generate_verification_token(123, 'test@example.com') + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_generate_verification_token_different_inputs(self): + token1 = generate_verification_token(1, 'user1@test.com') + token2 = generate_verification_token(2, 'user2@test.com') + + assert token1 != token2 + + +@pytest.mark.unit +class TestVerifyVerificationToken: + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_verify_verification_token_success(self): + token = generate_verification_token(123, 'test@example.com') + payload, error = verify_verification_token(token) + + assert error is None + assert payload is not None + assert 
payload['user_id'] == 123 + assert payload['email'] == 'test@example.com' + assert 'exp' in payload + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_verify_verification_token_invalid_token(self): + payload, error = verify_verification_token('invalid_token') + + assert payload is None + assert error is not None + assert 'Invalid token' in error + + @patch('core.services.encryption.key', VALID_FERNET_KEY) + def test_verify_verification_token_expired(self): + with patch('core.services.encryption.datetime') as mock_datetime: + past_time = datetime.now() - timedelta(hours=25) + mock_datetime.now.return_value = past_time + mock_datetime.fromisoformat = datetime.fromisoformat + + token = generate_verification_token(123, 'test@example.com') + + payload, error = verify_verification_token(token) + + assert payload is None + assert error == 'Token has expired' + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_verify_verification_token_malformed(self): + payload, error = verify_verification_token('not_a_valid_token') + + assert payload is None + assert error is not None + + @patch('core.services.encryption.key', None) + def test_verify_verification_token_missing_key(self): + token = 'some_token' + payload, error = verify_verification_token(token) + + assert payload is None + assert error is not None + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_verify_verification_token_round_trip(self): + user_id = 999 + email = 'roundtrip@test.com' + + token = generate_verification_token(user_id, email) + payload, error = verify_verification_token(token) + + assert error is None + assert payload['user_id'] == user_id + assert payload['email'] == email + + @patch('core.services.encryption.SECRET_ENCRYPTION_KEY', VALID_FERNET_KEY) + def test_verify_verification_token_within_expiry(self): + token = generate_verification_token(123, 'test@example.com') + payload, error = 
verify_verification_token(token) + + assert error is None + assert payload is not None + + exp_time = datetime.fromisoformat(payload['exp']) + assert datetime.now() < exp_time diff --git a/backend/core/tests/test_services/test_escalation_service.py b/backend/core/tests/test_services/test_escalation_service.py new file mode 100644 index 0000000..1e97e50 --- /dev/null +++ b/backend/core/tests/test_services/test_escalation_service.py @@ -0,0 +1,379 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime, timezone, timedelta + +from core.services.escalation_service import EscalationService + + +@pytest.mark.unit +class TestEscalationServiceShouldEscalate: + def test_should_escalate_user_requested_escalation(self): + from core.agent_response_schema import ResponseStatus + + service = EscalationService() + chatroom = Mock() + agent_response = Mock() + agent_response.status = ResponseStatus.USER_REQUESTED_ESCALATION + + result = service.should_escalate(chatroom, agent_response, 80) + + assert result is True + + def test_should_escalate_score_above_threshold_not_escalated(self): + service = EscalationService() + chatroom = Mock() + chatroom.is_escalated = False + agent_response = Mock() + agent_response.escalation_score = 85 + agent_response.status = Mock() + + result = service.should_escalate(chatroom, agent_response, 80) + + assert result is True + + def test_should_escalate_score_above_threshold_escalated_not_in_cooldown(self): + service = EscalationService() + chatroom = Mock() + chatroom.is_escalated = True + chatroom.escalated_at = datetime.now(timezone.utc) - timedelta(hours=25) + chatroom.escalation_cooldown_hours = 24 + agent_response = Mock() + agent_response.escalation_score = 85 + agent_response.status = Mock() + + with patch.object(service, '_within_cooldown', return_value=False): + result = service.should_escalate(chatroom, agent_response, 80) + + assert result is True + + def 
test_should_escalate_score_above_threshold_escalated_in_cooldown(self): + service = EscalationService() + chatroom = Mock() + chatroom.is_escalated = True + chatroom.escalated_at = datetime.now(timezone.utc) - timedelta(hours=1) + chatroom.escalation_cooldown_hours = 24 + agent_response = Mock() + agent_response.escalation_score = 85 + agent_response.status = Mock() + + with patch.object(service, '_within_cooldown', return_value=True): + result = service.should_escalate(chatroom, agent_response, 80) + + assert result is False + + def test_should_escalate_score_below_threshold(self): + service = EscalationService() + chatroom = Mock() + agent_response = Mock() + agent_response.escalation_score = 70 + agent_response.status = Mock() + + result = service.should_escalate(chatroom, agent_response, 80) + + assert result is False + + def test_should_escalate_score_equals_threshold(self): + service = EscalationService() + chatroom = Mock() + chatroom.is_escalated = False + agent_response = Mock() + agent_response.escalation_score = 80 + agent_response.status = Mock() + + result = service.should_escalate(chatroom, agent_response, 80) + + assert result is True + + +@pytest.mark.unit +class TestEscalationServiceEscalate: + @patch('core.services.escalation_service.AppNotificationProfile') + @patch('core.services.escalation_service.send_notification_task') + def test_escalate_success(self, mock_send_task, mock_app_profile): + service = EscalationService() + + chatroom = Mock() + chatroom.name = 'test-chatroom' + chatroom.is_escalated = False + chatroom.escalated_at = None + + application = Mock() + application.name = 'test-app' + + agent_response = Mock() + agent_response.reason_for_escalation = 'Complex issue' + agent_response.status = 'FAILED' + agent_response.answer = 'This is the agent answer' + + user_message = Mock() + user_message.message = 'Help me please' + user_message.platform = 'slack' + + mock_profile = Mock() + mock_profile.name = 'test-profile' + mock_profile.type 
= 'email' + mock_profile.config = {'email': 'test@example.com'} + + mock_app_profile_instance = Mock() + mock_app_profile_instance.notification_profile = mock_profile + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [mock_app_profile_instance] + mock_app_profile.objects = mock_queryset + + result = service.escalate(chatroom, application, agent_response, user_message) + + assert chatroom.is_escalated is True + assert chatroom.escalated_at is not None + chatroom.save.assert_called_once_with(update_fields=["is_escalated", "escalated_at"]) + assert result['escalation_reason'] == 'Complex issue' + assert len(result['notified_profiles']) == 1 + assert result['notified_profiles'][0]['name'] == 'test-profile' + mock_send_task.delay.assert_called_once() + + @patch('core.services.escalation_service.AppNotificationProfile') + @patch('core.services.escalation_service.send_notification_task') + def test_escalate_without_user_message(self, mock_send_task, mock_app_profile): + service = EscalationService() + + chatroom = Mock() + chatroom.name = 'test-chatroom' + chatroom.is_escalated = False + chatroom.escalated_at = None + + application = Mock() + application.name = 'test-app' + + agent_response = Mock() + agent_response.reason_for_escalation = None + agent_response.status = 'ESCALATED' + agent_response.answer = 'Agent response' + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [] + mock_app_profile.objects = mock_queryset + + result = service.escalate(chatroom, application, agent_response, None) + + assert chatroom.is_escalated is True + assert result['escalation_reason'] == 'ESCALATED' + assert result['notified_profiles'] == [] + + @patch('core.services.escalation_service.AppNotificationProfile') + @patch('core.services.escalation_service.send_notification_task') + def test_escalate_agent_answer_truncated(self, mock_send_task, mock_app_profile): + service = EscalationService() + + chatroom 
= Mock() + chatroom.name = 'test-chatroom' + chatroom.is_escalated = False + + application = Mock() + application.name = 'test-app' + + agent_response = Mock() + agent_response.reason_for_escalation = 'Test reason' + agent_response.status = 'FAILED' + agent_response.answer = 'a' * 400 # Long answer + + user_message = Mock() + user_message.message = 'Test message' + user_message.platform = 'slack' + + mock_profile = Mock() + mock_profile.name = 'test-profile' + mock_profile.type = 'email' + mock_profile.config = {} + + mock_app_profile_instance = Mock() + mock_app_profile_instance.notification_profile = mock_profile + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [mock_app_profile_instance] + mock_app_profile.objects = mock_queryset + + result = service.escalate(chatroom, application, agent_response, user_message) + + assert result['escalation_reason'] == 'Test reason' + + @patch('core.services.escalation_service.AppNotificationProfile') + @patch('core.services.escalation_service.send_notification_task') + def test_escalate_multiple_notification_profiles(self, mock_send_task, mock_app_profile): + service = EscalationService() + + chatroom = Mock() + chatroom.name = 'test-chatroom' + chatroom.is_escalated = False + + application = Mock() + application.name = 'test-app' + + agent_response = Mock() + agent_response.reason_for_escalation = 'Test' + agent_response.status = 'FAILED' + agent_response.answer = 'Answer' + + user_message = Mock() + user_message.message = 'Message' + user_message.platform = 'slack' + + mock_profile1 = Mock() + mock_profile1.name = 'profile1' + mock_profile1.type = 'email' + mock_profile1.config = {} + + mock_profile2 = Mock() + mock_profile2.name = 'profile2' + mock_profile2.type = 'slack' + mock_profile2.config = {} + + mock_app_profile_instance1 = Mock() + mock_app_profile_instance1.notification_profile = mock_profile1 + + mock_app_profile_instance2 = Mock() + 
mock_app_profile_instance2.notification_profile = mock_profile2 + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [ + mock_app_profile_instance1, mock_app_profile_instance2 + ] + mock_app_profile.objects = mock_queryset + + result = service.escalate(chatroom, application, agent_response, user_message) + + assert len(result['notified_profiles']) == 2 + assert mock_send_task.delay.call_count == 2 + + @patch('core.services.escalation_service.AppNotificationProfile') + @patch('core.services.escalation_service.send_notification_task') + def test_escalate_notification_error_handling(self, mock_send_task, mock_app_profile): + service = EscalationService() + + chatroom = Mock() + chatroom.name = 'test-chatroom' + chatroom.is_escalated = False + + application = Mock() + application.name = 'test-app' + + agent_response = Mock() + agent_response.reason_for_escalation = 'Test' + agent_response.status = 'FAILED' + agent_response.answer = 'Answer' + + user_message = Mock() + user_message.message = 'Message' + user_message.platform = 'slack' + + mock_profile = Mock() + mock_profile.name = 'failing-profile' + mock_profile.type = 'email' + mock_profile.config = {} + + mock_app_profile_instance = Mock() + mock_app_profile_instance.notification_profile = mock_profile + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [mock_app_profile_instance] + mock_app_profile.objects = mock_queryset + + mock_send_task.delay.side_effect = Exception('Notification failed') + + result = service.escalate(chatroom, application, agent_response, user_message) + + assert result['escalation_reason'] == 'Test' + assert result['notified_profiles'] == [] + + @patch('core.services.escalation_service.AppNotificationProfile') + @patch('core.services.escalation_service.send_notification_task') + def test_escalate_agent_without_answer(self, mock_send_task, mock_app_profile): + service = EscalationService() + + chatroom = Mock() + 
chatroom.name = 'test-chatroom' + chatroom.is_escalated = False + + application = Mock() + application.name = 'test-app' + + agent_response = Mock() + agent_response.reason_for_escalation = 'Test' + agent_response.status = 'FAILED' + del agent_response.answer # Simulate missing answer attribute + + user_message = Mock() + user_message.message = 'Message' + user_message.platform = 'slack' + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [] + mock_app_profile.objects = mock_queryset + + result = service.escalate(chatroom, application, agent_response, user_message) + + assert result['escalation_reason'] == 'Test' + assert result['notified_profiles'] == [] + + +@pytest.mark.unit +class TestEscalationServiceWithinCooldown: + def test_within_cooldown_no_escalation_time(self): + service = EscalationService() + chatroom = Mock() + chatroom.escalated_at = None + chatroom.escalation_cooldown_hours = 24 + + result = service._within_cooldown(chatroom) + + assert result is False + + def test_within_cooldown_true(self): + service = EscalationService() + chatroom = Mock() + chatroom.escalated_at = datetime.now(timezone.utc) - timedelta(hours=1) + chatroom.escalation_cooldown_hours = 24 + + result = service._within_cooldown(chatroom) + + assert result is True + + def test_within_cooldown_false(self): + service = EscalationService() + chatroom = Mock() + chatroom.escalated_at = datetime.now(timezone.utc) - timedelta(hours=25) + chatroom.escalation_cooldown_hours = 24 + + result = service._within_cooldown(chatroom) + + assert result is False + + def test_within_cooldown_exactly_at_threshold(self): + service = EscalationService() + chatroom = Mock() + chatroom.escalated_at = datetime.now(timezone.utc) - timedelta(hours=24) + chatroom.escalation_cooldown_hours = 24 + + result = service._within_cooldown(chatroom) + + assert result is False + + def test_within_cooldown_naive_datetime(self): + service = EscalationService() + chatroom = Mock() + 
naive_time = datetime.now() - timedelta(hours=1) + chatroom.escalated_at = naive_time.replace(tzinfo=None) + chatroom.escalation_cooldown_hours = 24 + + result = service._within_cooldown(chatroom) + + assert result is True + + def test_within_cooldown_zero_cooldown(self): + service = EscalationService() + chatroom = Mock() + chatroom.escalated_at = datetime.now(timezone.utc) - timedelta(seconds=1) + chatroom.escalation_cooldown_hours = 0 + + result = service._within_cooldown(chatroom) + + assert result is False diff --git a/backend/core/tests/test_services/test_file_extractors.py b/backend/core/tests/test_services/test_file_extractors.py new file mode 100644 index 0000000..65e9561 --- /dev/null +++ b/backend/core/tests/test_services/test_file_extractors.py @@ -0,0 +1,243 @@ +import pytest +from unittest.mock import Mock, patch, mock_open +import os + +from core.services.file_extractors import ( + extract_text_from_file, + extract_pdf, + extract_docx, + extract_txt +) + + +@pytest.mark.unit +class TestExtractTextFromFile: + @patch('core.services.file_extractors.default_storage') + @patch('core.services.file_extractors.extract_pdf') + def test_extract_text_from_file_pdf(self, mock_extract_pdf, mock_storage): + mock_storage.path.return_value = '/path/to/file.pdf' + mock_extract_pdf.return_value = 'pdf content' + + result = extract_text_from_file('file.pdf') + + mock_extract_pdf.assert_called_once_with('/path/to/file.pdf') + assert result == 'pdf content' + + @patch('core.services.file_extractors.default_storage') + @patch('core.services.file_extractors.extract_docx') + def test_extract_text_from_file_docx(self, mock_extract_docx, mock_storage): + mock_storage.path.return_value = '/path/to/file.docx' + mock_extract_docx.return_value = 'docx content' + + result = extract_text_from_file('file.docx') + + mock_extract_docx.assert_called_once_with('/path/to/file.docx') + assert result == 'docx content' + + @patch('core.services.file_extractors.default_storage') + 
@patch('core.services.file_extractors.extract_txt') + def test_extract_text_from_file_txt(self, mock_extract_txt, mock_storage): + mock_storage.path.return_value = '/path/to/file.txt' + mock_extract_txt.return_value = 'txt content' + + result = extract_text_from_file('file.txt') + + mock_extract_txt.assert_called_once_with('/path/to/file.txt') + assert result == 'txt content' + + @patch('core.services.file_extractors.default_storage') + @patch('core.services.file_extractors.extract_txt') + def test_extract_text_from_file_md(self, mock_extract_txt, mock_storage): + mock_storage.path.return_value = '/path/to/file.md' + mock_extract_txt.return_value = 'md content' + + result = extract_text_from_file('file.md') + + mock_extract_txt.assert_called_once_with('/path/to/file.md') + assert result == 'md content' + + @patch('core.services.file_extractors.default_storage') + def test_extract_text_from_file_unsupported_type(self, mock_storage): + mock_storage.path.return_value = '/path/to/file.xyz' + + with pytest.raises(ValueError, match="Unsupported file type: .xyz"): + extract_text_from_file('file.xyz') + + @patch('core.services.file_extractors.default_storage') + @patch('core.services.file_extractors.extract_pdf') + def test_extract_text_from_file_uppercase_extension(self, mock_extract_pdf, mock_storage): + mock_storage.path.return_value = '/path/to/file.PDF' + mock_extract_pdf.return_value = 'pdf content' + + result = extract_text_from_file('file.PDF') + + mock_extract_pdf.assert_called_once_with('/path/to/file.PDF') + assert result == 'pdf content' + + +@pytest.mark.unit +class TestExtractPdf: + @patch('core.services.file_extractors.PdfReader') + @patch('core.services.file_extractors._quality_filter') + def test_extract_pdf_success(self, mock_quality_filter, mock_pdf_reader): + mock_page1 = Mock() + mock_page1.extract_text.return_value = 'Page 1 content' + mock_page2 = Mock() + mock_page2.extract_text.return_value = 'Page 2 content' + + mock_pdf_reader.return_value.pages 
= [mock_page1, mock_page2] + mock_quality_filter.remove_emojis.return_value = 'Page 1 content\nPage 2 content' + + result = extract_pdf('/path/to/file.pdf') + + assert result == 'Page 1 content\nPage 2 content' + mock_quality_filter.remove_emojis.assert_called_once() + + @patch('core.services.file_extractors.PdfReader') + @patch('core.services.file_extractors._quality_filter') + def test_extract_pdf_empty_page(self, mock_quality_filter, mock_pdf_reader): + mock_page = Mock() + mock_page.extract_text.return_value = None + + mock_pdf_reader.return_value.pages = [mock_page] + mock_quality_filter.remove_emojis.return_value = '' + + result = extract_pdf('/path/to/file.pdf') + + assert result == '' + + @patch('core.services.file_extractors.PdfReader') + @patch('core.services.file_extractors._quality_filter') + def test_extract_pdf_single_page(self, mock_quality_filter, mock_pdf_reader): + mock_page = Mock() + mock_page.extract_text.return_value = 'Single page content' + + mock_pdf_reader.return_value.pages = [mock_page] + mock_quality_filter.remove_emojis.return_value = 'Single page content' + + result = extract_pdf('/path/to/file.pdf') + + assert result == 'Single page content' + + @patch('core.services.file_extractors.PdfReader') + @patch('core.services.file_extractors._quality_filter') + def test_extract_pdf_with_emojis(self, mock_quality_filter, mock_pdf_reader): + mock_page = Mock() + mock_page.extract_text.return_value = 'Content with 👍 emoji' + + mock_pdf_reader.return_value.pages = [mock_page] + mock_quality_filter.remove_emojis.return_value = 'Content with emoji' + + result = extract_pdf('/path/to/file.pdf') + + assert result == 'Content with emoji' + mock_quality_filter.remove_emojis.assert_called_once_with('Content with 👍 emoji') + + +@pytest.mark.unit +class TestExtractDocx: + @patch('core.services.file_extractors.Document') + @patch('core.services.file_extractors._quality_filter') + def test_extract_docx_success(self, mock_quality_filter, mock_document): + 
mock_para1 = Mock() + mock_para1.text = 'Paragraph 1' + mock_para2 = Mock() + mock_para2.text = 'Paragraph 2' + + mock_document.return_value.paragraphs = [mock_para1, mock_para2] + mock_quality_filter.remove_emojis.return_value = 'Paragraph 1\nParagraph 2' + + result = extract_docx('/path/to/file.docx') + + assert result == 'Paragraph 1\nParagraph 2' + mock_document.assert_called_once_with('/path/to/file.docx') + + @patch('core.services.file_extractors.Document') + @patch('core.services.file_extractors._quality_filter') + def test_extract_docx_empty_document(self, mock_quality_filter, mock_document): + mock_document.return_value.paragraphs = [] + mock_quality_filter.remove_emojis.return_value = '' + + result = extract_docx('/path/to/file.docx') + + assert result == '' + + @patch('core.services.file_extractors.Document') + @patch('core.services.file_extractors._quality_filter') + def test_extract_docx_single_paragraph(self, mock_quality_filter, mock_document): + mock_para = Mock() + mock_para.text = 'Single paragraph' + + mock_document.return_value.paragraphs = [mock_para] + mock_quality_filter.remove_emojis.return_value = 'Single paragraph' + + result = extract_docx('/path/to/file.docx') + + assert result == 'Single paragraph' + + @patch('core.services.file_extractors.Document') + @patch('core.services.file_extractors._quality_filter') + def test_extract_docx_with_emojis(self, mock_quality_filter, mock_document): + mock_para = Mock() + mock_para.text = 'Text with 🎉 emoji' + + mock_document.return_value.paragraphs = [mock_para] + mock_quality_filter.remove_emojis.return_value = 'Text with emoji' + + result = extract_docx('/path/to/file.docx') + + assert result == 'Text with emoji' + mock_quality_filter.remove_emojis.assert_called_once_with('Text with 🎉 emoji') + + +@pytest.mark.unit +class TestExtractTxt: + @patch('core.services.file_extractors._quality_filter') + def test_extract_txt_success(self, mock_quality_filter): + with patch('builtins.open', 
mock_open(read_data='File content')): + mock_quality_filter.remove_emojis.return_value = 'File content' + + result = extract_txt('/path/to/file.txt') + + assert result == 'File content' + mock_quality_filter.remove_emojis.assert_called_once_with('File content') + + @patch('core.services.file_extractors._quality_filter') + def test_extract_txt_empty_file(self, mock_quality_filter): + with patch('builtins.open', mock_open(read_data='')): + mock_quality_filter.remove_emojis.return_value = '' + + result = extract_txt('/path/to/file.txt') + + assert result == '' + + @patch('core.services.file_extractors._quality_filter') + def test_extract_txt_multiline(self, mock_quality_filter): + content = 'Line 1\nLine 2\nLine 3' + with patch('builtins.open', mock_open(read_data=content)): + mock_quality_filter.remove_emojis.return_value = content + + result = extract_txt('/path/to/file.txt') + + assert result == content + + @patch('core.services.file_extractors._quality_filter') + def test_extract_txt_with_emojis(self, mock_quality_filter): + content = 'Text with 😊 emoji' + with patch('builtins.open', mock_open(read_data=content)): + mock_quality_filter.remove_emojis.return_value = 'Text with emoji' + + result = extract_txt('/path/to/file.txt') + + assert result == 'Text with emoji' + mock_quality_filter.remove_emojis.assert_called_once_with('Text with 😊 emoji') + + @patch('core.services.file_extractors._quality_filter') + def test_extract_txt_utf8_encoding(self, mock_quality_filter): + content = 'Unicode content: 你好世界' + with patch('builtins.open', mock_open(read_data=content)): + mock_quality_filter.remove_emojis.return_value = content + + result = extract_txt('/path/to/file.txt') + + assert result == content diff --git a/backend/core/tests/test_services/test_gemini_provider.py b/backend/core/tests/test_services/test_gemini_provider.py new file mode 100644 index 0000000..a4b9d95 --- /dev/null +++ b/backend/core/tests/test_services/test_gemini_provider.py @@ -0,0 +1,380 @@ +import 
pytest +from unittest.mock import Mock, patch, MagicMock +from pydantic import BaseModel + +from core.services.providers.ai.gemini_provider import GeminiProvider + + +class MockResponseSchema(BaseModel): + answer: str + + +@pytest.mark.unit +class TestGeminiProvider: + def test_init_with_api_key(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + provider = GeminiProvider('test_api_key') + + mock_client_class.assert_called_once_with(api_key='test_api_key') + assert provider.client == mock_client + + def test_init_with_config(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + config = {'model': 'gemini-pro'} + provider = GeminiProvider('test_api_key', config) + + mock_client_class.assert_called_once_with(api_key='test_api_key') + assert provider.client == mock_client + + def test_init_failure(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client_class.side_effect = Exception("API key invalid") + + with pytest.raises(ValueError) as exc_info: + GeminiProvider('test_api_key') + + assert "Failed to initialize Gemini client: API key invalid" in str(exc_info.value) + + def test_generate_text(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_response.text = '{"answer": "test response", "status": "ANSWERED", "escalation": false, "reason_for_escalation": "", "sentiment_score": 50, "escalation_score": 0, "criticality_score": 0}' + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + result = provider.generate_text('gemini-pro', 'test content') + + 
mock_client.models.generate_content.assert_called_once() + assert result.answer == "test response" + + def test_generate_text_api_error(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.models.generate_content.side_effect = Exception("API error") + + provider = GeminiProvider('test_api_key') + + with pytest.raises(ValueError) as exc_info: + provider.generate_text('gemini-pro', 'test content') + + assert "Gemini API error: API error" in str(exc_info.value) + + def test_validate_connection_success(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_models = [Mock()] + mock_client.models.list.return_value = mock_models + + provider = GeminiProvider('test_api_key') + is_valid, models = provider.validate_connection() + + assert is_valid is True + assert len(models) > 0 + + def test_validate_connection_failure(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.models.list.side_effect = Exception("Connection failed") + + provider = GeminiProvider('test_api_key') + is_valid, models = provider.validate_connection() + + assert is_valid is False + assert models == [] + + def test_get_models_success(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_model = Mock() + mock_model.name = 'gemini-pro' + mock_client.models.list.return_value = [mock_model] + + provider = GeminiProvider('test_api_key') + models = provider.get_models() + + assert len(models) == 1 + assert 'name' in models[0] + + def test_get_models_failure(self): + with 
patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.models.list.side_effect = Exception("API error") + + provider = GeminiProvider('test_api_key') + + with pytest.raises(ValueError) as exc_info: + provider.get_models() + + assert "Failed to retrieve models from Gemini API: API error" in str(exc_info.value) + + def test_embed(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_embedding = Mock() + mock_embedding.values = [0.1, 0.2, 0.3] + mock_result = Mock() + mock_result.embeddings = [mock_embedding] + mock_client.models.embed_content.return_value = mock_result + + provider = GeminiProvider('test_api_key') + embeddings = provider.embed('embedding-model', ['test text']) + + mock_client.models.embed_content.assert_called_once_with(model='embedding-model', contents=['test text']) + assert len(embeddings) == 1 + assert embeddings[0] == [0.1, 0.2, 0.3] + + def test_embed_error(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.models.embed_content.side_effect = Exception("Embedding failed") + + provider = GeminiProvider('test_api_key') + + with pytest.raises(ValueError) as exc_info: + provider.embed('embedding-model', ['test text']) + + assert "Gemini embedding error: Embedding failed" in str(exc_info.value) + + def test_extract_usage_with_metadata(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + provider = GeminiProvider('test_api_key') + + mock_response = Mock() + mock_usage = Mock() + mock_usage.prompt_token_count = 10 + mock_usage.candidates_token_count = 5 + 
mock_usage.total_token_count = 15 + mock_response.usage_metadata = mock_usage + + usage = provider._extract_usage(mock_response) + + assert usage['prompt_tokens'] == 10 + assert usage['completion_tokens'] == 5 + assert usage['total_tokens'] == 15 + + def test_extract_usage_without_metadata(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + provider = GeminiProvider('test_api_key') + + mock_response = Mock() + delattr(mock_response, 'usage_metadata') + + usage = provider._extract_usage(mock_response) + + assert usage == {} + + def test_generate_with_conversation_simple(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_response.text = '{"answer": "response"}' + mock_response.candidates = None + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + messages = [{'role': 'user', 'content': 'Hello'}] + + result, tool_calls, usage = provider.generate_with_conversation( + 'gemini-pro', + messages, + None, + MockResponseSchema + ) + + assert result.answer == "response" + assert tool_calls == [] + + def test_generate_with_conversation_with_tools(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_response.text = '{"answer": "response"}' + mock_response.candidates = None + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + messages = [{'role': 'user', 'content': 'Hello'}] + tools = [{'function': {'name': 'test_tool', 'parameters': {'type': 'object'}}}] + + result, tool_calls, usage = provider.generate_with_conversation( + 
'gemini-pro', + messages, + tools, + MockResponseSchema + ) + + assert result.answer == "response" + assert tool_calls == [] + + def test_generate_with_conversation_with_system_message(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_response.text = '{"answer": "response"}' + mock_response.candidates = None + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant'}, + {'role': 'user', 'content': 'Hello'} + ] + + result, tool_calls, usage = provider.generate_with_conversation( + 'gemini-pro', + messages, + None, + MockResponseSchema + ) + + assert result.answer == "response" + + def test_generate_with_conversation_with_tool_response(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_response.text = '{"answer": "response"}' + mock_response.candidates = None + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + messages = [ + {'role': 'tool', 'tool_call_id': '123', 'name': 'test_tool', 'content': 'tool result'} + ] + + result, tool_calls, usage = provider.generate_with_conversation( + 'gemini-pro', + messages, + None, + MockResponseSchema + ) + + assert result.answer == "response" + + def test_generate_with_conversation_with_function_call(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_fc = Mock() + mock_fc.name = 'test_function' + mock_fc.args = {'param': 'value'} + + mock_part = Mock() + mock_part.function_call = mock_fc + mock_part.text = 
None + + mock_content = Mock() + mock_content.parts = [mock_part] + + mock_candidate = Mock() + mock_candidate.content = mock_content + + mock_response = Mock() + mock_response.candidates = [mock_candidate] + mock_response.text = None + mock_response.usage_metadata = None + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + messages = [{'role': 'user', 'content': 'Hello'}] + + result, tool_calls, usage = provider.generate_with_conversation( + 'gemini-pro', + messages, + None, + MockResponseSchema + ) + + assert len(tool_calls) == 1 + assert tool_calls[0]['name'] == 'test_function' + assert tool_calls[0]['args'] == {'param': 'value'} + + def test_generate_with_conversation_api_error(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.models.generate_content.side_effect = Exception("API error") + + provider = GeminiProvider('test_api_key') + messages = [{'role': 'user', 'content': 'Hello'}] + + with pytest.raises(ValueError) as exc_info: + provider.generate_with_conversation('gemini-pro', messages, None, MockResponseSchema) + + assert "Gemini API error: API error" in str(exc_info.value) + + def test_generate_with_tools(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + contents = [{'role': 'user', 'content': 'Hello'}] + tool_schemas = [{'function': {'name': 'test_tool', 'parameters': {'type': 'object'}}}] + + result = provider.generate_with_tools('gemini-pro', contents, tool_schemas) + + mock_client.models.generate_content.assert_called_once() + assert result == mock_response + + def 
test_generate_with_tools_with_system_message(self): + with patch('core.services.providers.ai.gemini_provider.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_response = Mock() + mock_client.models.generate_content.return_value = mock_response + + provider = GeminiProvider('test_api_key') + contents = [ + {'role': 'system', 'parts': [{'text': 'You are helpful'}]}, + {'role': 'user', 'content': 'Hello'} + ] + tool_schemas = [{'function': {'name': 'test_tool', 'parameters': {'type': 'object'}}}] + + result = provider.generate_with_tools('gemini-pro', contents, tool_schemas) + + mock_client.models.generate_content.assert_called_once() + assert result == mock_response diff --git a/backend/core/tests/test_services/test_github_client.py b/backend/core/tests/test_services/test_github_client.py new file mode 100644 index 0000000..bc601b4 --- /dev/null +++ b/backend/core/tests/test_services/test_github_client.py @@ -0,0 +1,481 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import requests +import time + +from core.services.github_client import GitHubAPIClient + + +@pytest.mark.unit +class TestGitHubAPIClientInit: + def test_init_sets_token_and_headers(self): + client = GitHubAPIClient('test_token') + + assert client.token == 'test_token' + assert client.session is not None + assert client.session.headers['Authorization'] == 'Bearer test_token' + assert client.session.headers['Accept'] == 'application/vnd.github+json' + assert client.session.headers['X-GitHub-Api-Version'] == '2022-11-28' + assert client.session.headers['User-Agent'] == 'Ch8r-GitHub-Ingestion/1.0' + + +@pytest.mark.unit +class TestMakeRequest: + @patch('core.services.github_client.time.sleep') + def test_make_request_success(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {'data': 'test'} + client.session.request = 
Mock(return_value=mock_response) + + result = client._make_request('GET', '/test') + + assert result == {'data': 'test'} + client.session.request.assert_called_once() + + @patch('core.services.github_client.time.sleep') + def test_make_request_rate_limit_hit(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 403 + mock_response.headers = { + 'X-RateLimit-Remaining': '0', + 'X-RateLimit-Reset': str(int(time.time()) + 2) + } + mock_response.json.return_value = {'message': 'Rate limit exceeded'} + + client.session.request = Mock(side_effect=[ + mock_response, + Mock(status_code=200, json=Mock(return_value={'data': 'success'})) + ]) + + result = client._make_request('GET', '/test') + + assert result == {'data': 'success'} + assert mock_sleep.called + + @patch('core.services.github_client.time.sleep') + def test_make_request_401_error_raises(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 401 + mock_response.content = b'{"message": "Unauthorized"}' + mock_response.json.return_value = {'message': 'Unauthorized'} + client.session.request = Mock(return_value=mock_response) + + with pytest.raises(requests.exceptions.HTTPError, match="GitHub API error: Unauthorized"): + client._make_request('GET', '/test') + + @patch('core.services.github_client.time.sleep') + def test_make_request_404_error_raises(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 404 + mock_response.content = b'{"message": "Not found"}' + mock_response.json.return_value = {'message': 'Not found'} + client.session.request = Mock(return_value=mock_response) + + with pytest.raises(requests.exceptions.HTTPError, match="GitHub API error: Not found"): + client._make_request('GET', '/test') + + @patch('core.services.github_client.time.sleep') + def test_make_request_500_error_retries(self, mock_sleep): + client = 
GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 500 + mock_response.content = b'{"message": "Server error"}' + mock_response.json.return_value = {'message': 'Server error'} + + client.session.request = Mock(side_effect=[ + mock_response, + mock_response, + Mock(status_code=200, json=Mock(return_value={'data': 'success'})) + ]) + + result = client._make_request('GET', '/test') + + assert result == {'data': 'success'} + assert client.session.request.call_count == 3 + + @patch('core.services.github_client.time.sleep') + def test_make_request_timeout_retries(self, mock_sleep): + client = GitHubAPIClient('test_token') + + client.session.request = Mock(side_effect=[ + requests.exceptions.Timeout('Timeout'), + requests.exceptions.Timeout('Timeout'), + Mock(status_code=200, json=Mock(return_value={'data': 'success'})) + ]) + + result = client._make_request('GET', '/test') + + assert result == {'data': 'success'} + assert client.session.request.call_count == 3 + + @patch('core.services.github_client.time.sleep') + def test_make_request_timeout_max_retries_exceeded(self, mock_sleep): + client = GitHubAPIClient('test_token') + + client.session.request = Mock(side_effect=requests.exceptions.Timeout('Timeout')) + + with pytest.raises(requests.exceptions.Timeout): + client._make_request('GET', '/test') + + assert client.session.request.call_count == 3 + + +@pytest.mark.unit +class TestGetRepositoryInfo: + @patch('core.services.github_client.time.sleep') + def test_get_repository_info(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {'name': 'test-repo', 'owner': 'test-owner'} + client.session.request = Mock(return_value=mock_response) + + result = client.get_repository_info('owner', 'repo') + + assert result == {'name': 'test-repo', 'owner': 'test-owner'} + client.session.request.assert_called_once_with('GET', 
'https://api.github.com/repos/owner/repo', timeout=30) + + +@pytest.mark.unit +class TestGetIssues: + @patch('core.services.github_client.time.sleep') + def test_get_issues_single_page(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'id': 1, 'title': 'Issue 1'}, + {'id': 2, 'title': 'Issue 2'} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_issues('owner', 'repo') + + assert len(result) == 2 + assert result[0]['title'] == 'Issue 1' + + @patch('core.services.github_client.time.sleep') + def test_get_issues_filters_pull_requests(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'id': 1, 'title': 'Issue 1'}, + {'id': 2, 'title': 'PR 1', 'pull_request': {}}, + {'id': 3, 'title': 'Issue 2'} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_issues('owner', 'repo') + + assert len(result) == 2 + assert all('pull_request' not in issue for issue in result) + + @patch('core.services.github_client.time.sleep') + def test_get_issues_with_since_parameter(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + client.session.request = Mock(return_value=mock_response) + + client.get_issues('owner', 'repo', since='2024-01-01') + + call_kwargs = client.session.request.call_args[1] + assert 'since' in call_kwargs['params'] + assert call_kwargs['params']['since'] == '2024-01-01' + + @patch('core.services.github_client.time.sleep') + def test_get_issues_pagination(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response1 = Mock() + mock_response1.status_code = 200 + mock_response1.json.return_value = [{'id': i} for i in range(100)] + + mock_response2 = Mock() + 
mock_response2.status_code = 200 + mock_response2.json.return_value = [{'id': i} for i in range(100, 150)] + + client.session.request = Mock(side_effect=[mock_response1, mock_response2]) + + result = client.get_issues('owner', 'repo', per_page=100) + + assert len(result) == 150 + + +@pytest.mark.unit +class TestGetIssueComments: + @patch('core.services.github_client.time.sleep') + def test_get_issue_comments(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'id': 1, 'body': 'Comment 1'}, + {'id': 2, 'body': 'Comment 2'} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_issue_comments('owner', 'repo', 1) + + assert len(result) == 2 + assert result[0]['body'] == 'Comment 1' + + @patch('core.services.github_client.time.sleep') + def test_get_issue_comments_pagination(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response1 = Mock() + mock_response1.status_code = 200 + mock_response1.json.return_value = [{'id': i} for i in range(100)] + + mock_response2 = Mock() + mock_response2.status_code = 200 + mock_response2.json.return_value = [{'id': i} for i in range(100, 150)] + + client.session.request = Mock(side_effect=[mock_response1, mock_response2]) + + result = client.get_issue_comments('owner', 'repo', 1) + + assert len(result) == 150 + + +@pytest.mark.unit +class TestGetPullRequests: + @patch('core.services.github_client.time.sleep') + def test_get_pull_requests(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'id': 1, 'title': 'PR 1'}, + {'id': 2, 'title': 'PR 2'} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_pull_requests('owner', 'repo') + + assert len(result) == 2 + assert result[0]['title'] == 'PR 1' + + 
@patch('core.services.github_client.time.sleep') + def test_get_pull_requests_with_since(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + client.session.request = Mock(return_value=mock_response) + + client.get_pull_requests('owner', 'repo', since='2024-01-01') + + call_kwargs = client.session.request.call_args[1] + assert 'since' in call_kwargs['params'] + assert call_kwargs['params']['since'] == '2024-01-01' + + +@pytest.mark.unit +class TestGetPullRequestComments: + @patch('core.services.github_client.time.sleep') + def test_get_pull_request_comments(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'id': 1, 'body': 'Comment 1'}, + {'id': 2, 'body': 'Comment 2'} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_pull_request_comments('owner', 'repo', 1) + + assert len(result) == 2 + + @patch('core.services.github_client.time.sleep') + def test_get_pull_request_comments_error_handling(self, mock_sleep): + client = GitHubAPIClient('test_token') + client.session.request = Mock(side_effect=Exception('API error')) + + result = client.get_pull_request_comments('owner', 'repo', 1) + + assert result == [] + + +@pytest.mark.unit +class TestGetPullRequestFiles: + @patch('core.services.github_client.time.sleep') + def test_get_pull_request_files(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'filename': 'file1.py'}, + {'filename': 'file2.py'} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_pull_request_files('owner', 'repo', 1) + + assert len(result) == 2 + assert result[0]['filename'] == 'file1.py' + + @patch('core.services.github_client.time.sleep') + def 
test_get_pull_request_files_error_handling(self, mock_sleep): + client = GitHubAPIClient('test_token') + client.session.request = Mock(side_effect=Exception('API error')) + + result = client.get_pull_request_files('owner', 'repo', 1) + + assert result == [] + + +@pytest.mark.unit +class TestGetDiscussions: + @patch('core.services.github_client.time.sleep') + def test_get_discussions(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'items': [ + {'id': 1, 'title': 'Discussion 1'}, + {'id': 2, 'title': 'Discussion 2'} + ] + } + client.session.request = Mock(return_value=mock_response) + + result = client.get_discussions('owner', 'repo') + + assert len(result) == 2 + assert result[0]['title'] == 'Discussion 1' + + @patch('core.services.github_client.time.sleep') + def test_get_discussions_with_since(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {'items': []} + client.session.request = Mock(return_value=mock_response) + + client.get_discussions('owner', 'repo', since='2024-01-01') + + call_kwargs = client.session.request.call_args[1] + assert 'q' in call_kwargs['params'] + assert 'created:>2024-01-01' in call_kwargs['params']['q'] + + +@pytest.mark.unit +class TestGetDiscussionComments: + def test_get_discussion_comments_not_implemented(self): + client = GitHubAPIClient('test_token') + + result = client.get_discussion_comments('owner', 'repo', 1) + + assert result == [] + + +@pytest.mark.unit +class TestGetWikiPages: + def test_get_wiki_pages_not_available(self): + client = GitHubAPIClient('test_token') + + result = client.get_wiki_pages('owner', 'repo') + + assert result == [] + + +@pytest.mark.unit +class TestGetRepositoryFile: + @patch('core.services.github_client.time.sleep') + def test_get_repository_file_success(self, mock_sleep): + client = 
GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {'content': 'file content'} + client.session.request = Mock(return_value=mock_response) + + result = client.get_repository_file('owner', 'repo', 'path/to/file.txt') + + assert result == {'content': 'file content'} + + @patch('core.services.github_client.time.sleep') + def test_get_repository_file_with_ref(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {'content': 'file content'} + client.session.request = Mock(return_value=mock_response) + + client.get_repository_file('owner', 'repo', 'path/to/file.txt', ref='main') + + call_kwargs = client.session.request.call_args[1] + assert 'ref' in call_kwargs['params'] + assert call_kwargs['params']['ref'] == 'main' + + @patch('core.services.github_client.time.sleep') + def test_get_repository_file_not_found(self, mock_sleep): + client = GitHubAPIClient('test_token') + + with patch.object(client, '_make_request', side_effect=requests.exceptions.HTTPError('404 Not found')): + result = client.get_repository_file('owner', 'repo', 'path/to/file.txt') + + assert result is None + + +@pytest.mark.unit +class TestGetCodeComments: + @patch('core.services.github_client.time.sleep') + def test_get_code_comments_success(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {'sha': 'abc123', 'commit': {'message': 'Fix bug'}} + ] + client.session.request = Mock(return_value=mock_response) + + result = client.get_code_comments('owner', 'repo', 'path/to/file.py') + + assert len(result) == 1 + + @patch('core.services.github_client.time.sleep') + def test_get_code_comments_error_handling(self, mock_sleep): + client = GitHubAPIClient('test_token') + client.session.request = Mock(side_effect=Exception('API 
error')) + + result = client.get_code_comments('owner', 'repo', 'path/to/file.py') + + assert result == [] + + +@pytest.mark.unit +class TestGetRateLimitStatus: + @patch('core.services.github_client.time.sleep') + def test_get_rate_limit_status(self, mock_sleep): + client = GitHubAPIClient('test_token') + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'resources': { + 'core': {'remaining': 4999, 'limit': 5000} + } + } + client.session.request = Mock(return_value=mock_response) + + result = client.get_rate_limit_status() + + assert result['resources']['core']['remaining'] == 4999 + + +@pytest.mark.unit +class TestClose: + def test_close_session(self): + client = GitHubAPIClient('test_token') + client.session.close = Mock() + + client.close() + + client.session.close.assert_called_once() diff --git a/backend/core/tests/test_services/test_github_graphql.py b/backend/core/tests/test_services/test_github_graphql.py new file mode 100644 index 0000000..2862089 --- /dev/null +++ b/backend/core/tests/test_services/test_github_graphql.py @@ -0,0 +1,431 @@ +import pytest +from unittest.mock import Mock, patch + +from core.services.providers.version_control.github_graphql import GitHubGraphQLProvider + + +@pytest.mark.unit +class TestGitHubGraphQLProvider: + def test_init(self): + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + assert provider.credentials == credentials + assert provider._graphql_client is None + assert provider._rest_client is None + + def test_provider_name(self): + assert GitHubGraphQLProvider.provider_name == 'github_graphql' + + def test_get_graphql_client_with_token(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + client = 
provider._get_graphql_client() + + mock_client_class.assert_called_once_with('test_token') + assert client == mock_client + assert provider._graphql_client == mock_client + + def test_get_graphql_client_without_token(self): + credentials = {} + provider = GitHubGraphQLProvider(credentials) + + with pytest.raises(ValueError) as exc_info: + provider._get_graphql_client() + + assert "GitHub token not found in credentials" in str(exc_info.value) + + def test_get_graphql_client_caches_client(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + client1 = provider._get_graphql_client() + client2 = provider._get_graphql_client() + + mock_client_class.assert_called_once() + assert client1 == client2 + + def test_get_rest_client_with_token(self): + with patch('core.services.providers.version_control.github_graphql.GitHubAPIClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + client = provider._get_rest_client() + + mock_client_class.assert_called_once_with('test_token') + assert client == mock_client + assert provider._rest_client == mock_client + + def test_get_rest_client_without_token(self): + credentials = {} + provider = GitHubGraphQLProvider(credentials) + + with pytest.raises(ValueError) as exc_info: + provider._get_rest_client() + + assert "GitHub token not found in credentials" in str(exc_info.value) + + def test_validate_credentials_success(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.execute_query.return_value = { + 'data': { + 'viewer': { 
+ 'login': 'testuser', + 'name': 'Test User', + 'avatarUrl': 'https://example.com/avatar.png', + 'url': 'https://github.com/testuser' + } + } + } + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + is_valid, error, metadata = provider.validate_credentials() + + assert is_valid is True + assert error == '' + assert metadata['login'] == 'testuser' + assert metadata['name'] == 'Test User' + assert metadata['avatar_url'] == 'https://example.com/avatar.png' + assert metadata['html_url'] == 'https://github.com/testuser' + + def test_validate_credentials_failure(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.execute_query.side_effect = Exception("Invalid token") + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + is_valid, error, metadata = provider.validate_credentials() + + assert is_valid is False + assert "Invalid token" in error + assert metadata == {} + + def test_get_repository_info(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_repository_info.return_value = { + 'id': '123', + 'name': 'test-repo', + 'description': 'Test repository', + 'url': 'https://github.com/owner/test-repo', + 'isPrivate': False, + 'defaultBranchRef': {'name': 'main'}, + 'isFork': False, + 'stargazerCount': 10, + 'forkCount': 5, + 'issues': {'totalCount': 3} + } + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + repo_info = provider.get_repository_info('owner', 'test-repo') + + mock_client.get_repository_info.assert_called_once_with('owner', 'test-repo') + assert repo_info['id'] == '123' + assert repo_info['name'] == 'test-repo' + assert 
repo_info['owner'] == 'owner' + assert repo_info['full_name'] == 'owner/test-repo' + assert repo_info['description'] == 'Test repository' + assert repo_info['default_branch'] == 'main' + assert repo_info['stars'] == 10 + assert repo_info['forks'] == 5 + assert repo_info['open_issues'] == 3 + + def test_get_issues_all(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_issues_with_comments.return_value = [ + { + 'number': 1, + 'title': 'Test Issue', + 'body': 'Issue body', + 'state': 'OPEN', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'assignees': {'nodes': [{'login': 'assignee1'}]}, + 'labels': {'nodes': [{'name': 'bug'}]}, + 'milestone': {'title': 'v1.0'}, + 'locked': False, + 'createdAt': '2024-01-01T00:00:00Z', + 'updatedAt': '2024-01-02T00:00:00Z', + 'closedAt': None, + 'url': 'https://github.com/owner/repo/issues/1' + } + ] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + issues = provider.get_issues('owner', 'repo', state='all') + + mock_client.get_all_issues_with_comments.assert_called_once_with( + 'owner', 'repo', states=['OPEN', 'CLOSED'], since=None + ) + assert len(issues) == 1 + assert issues[0]['id'] == '1' + assert issues[0]['title'] == 'Test Issue' + assert issues[0]['state'] == 'open' + assert issues[0]['author'] == 'testuser' + + def test_get_issues_open(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_issues_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + issues = provider.get_issues('owner', 'repo', state='open') + + 
mock_client.get_all_issues_with_comments.assert_called_once_with( + 'owner', 'repo', states=['OPEN'], since=None + ) + + def test_get_issues_closed(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_issues_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + issues = provider.get_issues('owner', 'repo', state='closed') + + mock_client.get_all_issues_with_comments.assert_called_once_with( + 'owner', 'repo', states=['CLOSED'], since=None + ) + + def test_get_issues_with_since(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_issues_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + issues = provider.get_issues('owner', 'repo', state='all', since='2024-01-01') + + mock_client.get_all_issues_with_comments.assert_called_once_with( + 'owner', 'repo', states=['OPEN', 'CLOSED'], since='2024-01-01' + ) + + def test_get_issue_comments(self): + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + comments = provider.get_issue_comments('owner', 'repo', 1) + + assert comments == [] + + def test_get_pull_requests_all(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_pull_requests_with_comments.return_value = [ + { + 'number': 1, + 'title': 'Test PR', + 'body': 'PR body', + 'state': 'OPEN', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'assignees': {'nodes': [{'login': 'assignee1'}]}, + 
'labels': {'nodes': [{'name': 'enhancement'}]}, + 'milestone': {'title': 'v1.0'}, + 'headRefName': 'feature', + 'baseRefName': 'main', + 'merged': False, + 'mergedAt': None, + 'mergeCommit': None, + 'additions': 10, + 'deletions': 5, + 'changedFiles': 2, + 'createdAt': '2024-01-01T00:00:00Z', + 'updatedAt': '2024-01-02T00:00:00Z', + 'closedAt': None, + 'url': 'https://github.com/owner/repo/pull/1' + } + ] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + prs = provider.get_pull_requests('owner', 'repo', state='all') + + mock_client.get_all_pull_requests_with_comments.assert_called_once_with( + 'owner', 'repo', states=['OPEN', 'CLOSED', 'MERGED'], since=None + ) + assert len(prs) == 1 + assert prs[0]['id'] == '1' + assert prs[0]['title'] == 'Test PR' + assert prs[0]['state'] == 'open' + assert prs[0]['head_branch'] == 'feature' + assert prs[0]['base_branch'] == 'main' + + def test_get_pull_requests_open(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_pull_requests_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + prs = provider.get_pull_requests('owner', 'repo', state='open') + + mock_client.get_all_pull_requests_with_comments.assert_called_once_with( + 'owner', 'repo', states=['OPEN'], since=None + ) + + def test_get_pull_requests_closed(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_pull_requests_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + prs = provider.get_pull_requests('owner', 'repo', state='closed') + + 
mock_client.get_all_pull_requests_with_comments.assert_called_once_with( + 'owner', 'repo', states=['CLOSED', 'MERGED'], since=None + ) + + def test_get_pull_requests_merged(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_pull_requests_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + prs = provider.get_pull_requests('owner', 'repo', state='merged') + + mock_client.get_all_pull_requests_with_comments.assert_called_once_with( + 'owner', 'repo', states=['MERGED'], since=None + ) + + def test_get_pull_requests_with_since(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_all_pull_requests_with_comments.return_value = [] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + prs = provider.get_pull_requests('owner', 'repo', state='all', since='2024-01-01') + + mock_client.get_all_pull_requests_with_comments.assert_called_once_with( + 'owner', 'repo', states=['OPEN', 'CLOSED', 'MERGED'], since='2024-01-01' + ) + + def test_get_pull_request_comments(self): + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + comments = provider.get_pull_request_comments('owner', 'repo', 1) + + assert comments == [] + + def test_get_pull_request_files(self): + with patch('core.services.providers.version_control.github_graphql.GitHubAPIClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + mock_client.get_pull_request_files.return_value = [ + { + 'filename': 'test.py', + 'status': 'modified', + 'additions': 10, + 'deletions': 5, + 'changes': 15, + 'patch': 'diff 
content', + 'blob_url': 'https://github.com/blob', + 'raw_url': 'https://github.com/raw' + } + ] + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + files = provider.get_pull_request_files('owner', 'repo', 1) + + mock_client.get_pull_request_files.assert_called_once_with('owner', 'repo', 1) + assert len(files) == 1 + assert files[0]['filename'] == 'test.py' + assert files[0]['status'] == 'modified' + assert files[0]['additions'] == 10 + assert files[0]['deletions'] == 5 + + def test_close(self): + with patch('core.services.providers.version_control.github_graphql.GitHubGraphQLClient') as mock_graphql_class: + with patch('core.services.providers.version_control.github_graphql.GitHubAPIClient') as mock_rest_class: + mock_graphql_client = Mock() + mock_graphql_class.return_value = mock_graphql_client + + mock_rest_client = Mock() + mock_rest_class.return_value = mock_rest_client + + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + # Initialize both clients + provider._get_graphql_client() + provider._get_rest_client() + + provider.close() + + mock_graphql_client.close.assert_called_once() + mock_rest_client.close.assert_called_once() + assert provider._graphql_client is None + assert provider._rest_client is None + + def test_close_with_no_clients(self): + credentials = {'token': 'test_token'} + provider = GitHubGraphQLProvider(credentials) + + # Should not raise any errors + provider.close() + + assert provider._graphql_client is None + assert provider._rest_client is None diff --git a/backend/core/tests/test_services/test_github_graphql_client.py b/backend/core/tests/test_services/test_github_graphql_client.py new file mode 100644 index 0000000..c8614f9 --- /dev/null +++ b/backend/core/tests/test_services/test_github_graphql_client.py @@ -0,0 +1,455 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from gql.transport.exceptions import TransportQueryError, 
TransportServerError +import time + +from core.services.github_graphql_client import GitHubGraphQLClient + + +@pytest.mark.unit +class TestGitHubGraphQLClientInit: + def test_init_sets_token_and_transport(self): + client = GitHubGraphQLClient('test_token') + + assert client.token == 'test_token' + assert client.endpoint == 'https://api.github.com/graphql' + assert client.transport is not None + assert client.client is not None + assert client.transport.headers['Authorization'] == 'Bearer test_token' + assert client.transport.headers['Accept'] == 'application/vnd.github.v4+json' + assert client.transport.headers['User-Agent'] == 'Ch8r-GitHub-GraphQL/1.0' + + +@pytest.mark.unit +class TestExecuteQuery: + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_success(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'data': 'test'}) + + result = client._execute_query('query string') + + assert result == {'data': 'test'} + mock_gql.assert_called_once_with('query string') + client.client.execute.assert_called_once() + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_with_variables(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'data': 'test'}) + + variables = {'owner': 'test', 'repo': 'repo'} + result = client._execute_query('query string', variables) + + assert result == {'data': 'test'} + client.client.execute.assert_called_once_with('query_doc', variable_values=variables) + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_transport_query_error_raises(self, mock_gql, mock_sleep): + client = 
GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(side_effect=TransportQueryError('Syntax error')) + + with pytest.raises(TransportQueryError): + client._execute_query('query string') + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_rate_limit_retries(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + client.client.execute = Mock(side_effect=[ + TransportServerError('API rate limit exceeded'), + {'data': 'success'} + ]) + + result = client._execute_query('query string') + + assert result == {'data': 'success'} + assert mock_sleep.called + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_authentication_error_raises(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(side_effect=TransportServerError('Bad credentials')) + + with pytest.raises(TransportServerError): + client._execute_query('query string') + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_server_error_retries(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + client.client.execute = Mock(side_effect=[ + TransportServerError('Server error'), + TransportServerError('Server error'), + {'data': 'success'} + ]) + + result = client._execute_query('query string') + + assert result == {'data': 'success'} + assert client.client.execute.call_count == 3 + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_generic_exception_retries(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') 
+ mock_gql.return_value = 'query_doc' + + client.client.execute = Mock(side_effect=[ + Exception('Connection error'), + Exception('Connection error'), + {'data': 'success'} + ]) + + result = client._execute_query('query string') + + assert result == {'data': 'success'} + assert client.client.execute.call_count == 3 + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_execute_query_max_retries_exceeded(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(side_effect=Exception('Connection error')) + + with pytest.raises(Exception): + client._execute_query('query string') + + assert client.client.execute.call_count == 3 + + +@pytest.mark.unit +class TestGetIssuesWithComments: + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_issues_with_comments_default_states(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'issues': {'edges': []}}}) + + result = client.get_issues_with_comments('owner', 'repo') + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['states'] == ["OPEN", "CLOSED"] + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_issues_with_comments_custom_states(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'issues': {'edges': []}}}) + + result = client.get_issues_with_comments('owner', 'repo', states=['OPEN']) + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['states'] 
== ['OPEN'] + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_issues_with_comments_with_cursor(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'issues': {'edges': []}}}) + + result = client.get_issues_with_comments('owner', 'repo', after_cursor='cursor123') + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['after'] == 'cursor123' + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_issues_with_comments_custom_first(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'issues': {'edges': []}}}) + + result = client.get_issues_with_comments('owner', 'repo', first=50) + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['first'] == 50 + + +@pytest.mark.unit +class TestGetAllIssuesWithComments: + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_issues_with_comments_single_page(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + mock_response = { + 'repository': { + 'issues': { + 'edges': [ + {'node': {'id': 1, 'title': 'Issue 1'}}, + {'node': {'id': 2, 'title': 'Issue 2'}} + ], + 'pageInfo': { + 'hasNextPage': False, + 'endCursor': None + } + } + } + } + client.client.execute = Mock(return_value=mock_response) + + result = client.get_all_issues_with_comments('owner', 'repo') + + assert len(result) == 2 + assert result[0]['title'] == 'Issue 1' + + 
@patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_issues_with_comments_pagination(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + mock_response1 = { + 'repository': { + 'issues': { + 'edges': [{'node': {'id': i}} for i in range(100)], + 'pageInfo': { + 'hasNextPage': True, + 'endCursor': 'cursor1' + } + } + } + } + + mock_response2 = { + 'repository': { + 'issues': { + 'edges': [{'node': {'id': i}} for i in range(100, 150)], + 'pageInfo': { + 'hasNextPage': False, + 'endCursor': None + } + } + } + } + + client.client.execute = Mock(side_effect=[mock_response1, mock_response2]) + + result = client.get_all_issues_with_comments('owner', 'repo') + + assert len(result) == 150 + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_issues_with_comments_no_repository(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={}) + + result = client.get_all_issues_with_comments('owner', 'repo') + + assert result == [] + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_issues_with_comments_custom_states(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + mock_response = { + 'repository': { + 'issues': { + 'edges': [], + 'pageInfo': {'hasNextPage': False} + } + } + } + client.client.execute = Mock(return_value=mock_response) + + result = client.get_all_issues_with_comments('owner', 'repo', states=['OPEN']) + + assert result == [] + + +@pytest.mark.unit +class TestGetPullRequestsWithComments: + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def 
test_get_pull_requests_with_comments_default_states(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'pullRequests': {'edges': []}}}) + + result = client.get_pull_requests_with_comments('owner', 'repo') + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['states'] == ["OPEN", "CLOSED", "MERGED"] + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_pull_requests_with_comments_custom_states(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'pullRequests': {'edges': []}}}) + + result = client.get_pull_requests_with_comments('owner', 'repo', states=['OPEN']) + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['states'] == ['OPEN'] + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_pull_requests_with_comments_with_cursor(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={'repository': {'pullRequests': {'edges': []}}}) + + result = client.get_pull_requests_with_comments('owner', 'repo', after_cursor='cursor123') + + assert 'repository' in result + call_kwargs = client.client.execute.call_args[1] + assert call_kwargs['variable_values']['after'] == 'cursor123' + + +@pytest.mark.unit +class TestGetAllPullRequestsWithComments: + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_pull_requests_with_comments_single_page(self, mock_gql, mock_sleep): + client = 
GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + mock_response = { + 'repository': { + 'pullRequests': { + 'edges': [ + {'node': {'id': 1, 'title': 'PR 1'}}, + {'node': {'id': 2, 'title': 'PR 2'}} + ], + 'pageInfo': { + 'hasNextPage': False, + 'endCursor': None + } + } + } + } + client.client.execute = Mock(return_value=mock_response) + + result = client.get_all_pull_requests_with_comments('owner', 'repo') + + assert len(result) == 2 + assert result[0]['title'] == 'PR 1' + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_pull_requests_with_comments_pagination(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + mock_response1 = { + 'repository': { + 'pullRequests': { + 'edges': [{'node': {'id': i}} for i in range(100)], + 'pageInfo': { + 'hasNextPage': True, + 'endCursor': 'cursor1' + } + } + } + } + + mock_response2 = { + 'repository': { + 'pullRequests': { + 'edges': [{'node': {'id': i}} for i in range(100, 150)], + 'pageInfo': { + 'hasNextPage': False, + 'endCursor': None + } + } + } + } + + client.client.execute = Mock(side_effect=[mock_response1, mock_response2]) + + result = client.get_all_pull_requests_with_comments('owner', 'repo') + + assert len(result) == 150 + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_all_pull_requests_with_comments_no_repository(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={}) + + result = client.get_all_pull_requests_with_comments('owner', 'repo') + + assert result == [] + + +@pytest.mark.unit +class TestGetRepositoryInfo: + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_repository_info_success(self, 
mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + + mock_response = { + 'repository': { + 'name': 'test-repo', + 'nameWithOwner': 'owner/test-repo', + 'description': 'Test repository', + 'url': 'https://github.com/owner/test-repo', + 'stargazerCount': 100 + } + } + client.client.execute = Mock(return_value=mock_response) + + result = client.get_repository_info('owner', 'repo') + + assert result['name'] == 'test-repo' + assert result['stargazerCount'] == 100 + + @patch('core.services.github_graphql_client.time.sleep') + @patch('core.services.github_graphql_client.gql') + def test_get_repository_info_no_repository_key(self, mock_gql, mock_sleep): + client = GitHubGraphQLClient('test_token') + mock_gql.return_value = 'query_doc' + client.client.execute = Mock(return_value={}) + + result = client.get_repository_info('owner', 'repo') + + assert result == {} + + +@pytest.mark.unit +class TestClose: + def test_close_session(self): + client = GitHubGraphQLClient('test_token') + client.client.close_session = Mock() + + client.close() + + client.client.close_session.assert_called_once() + + def test_close_no_close_session_method(self): + client = GitHubGraphQLClient('test_token') + client.client = Mock(spec=[]) + + client.close() diff --git a/backend/core/tests/test_services/test_github_graphql_ingestion.py b/backend/core/tests/test_services/test_github_graphql_ingestion.py new file mode 100644 index 0000000..bd25d5f --- /dev/null +++ b/backend/core/tests/test_services/test_github_graphql_ingestion.py @@ -0,0 +1,453 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime +from django.utils import timezone + +from core.services.github_graphql_ingestion import ( + GitHubGraphQLIngestionService, + _extract_numeric_id_from_global_id +) + + +@pytest.mark.unit +class TestExtractNumericIdFromGlobalId: + def test_extract_numeric_id_from_global_id(self): + result = 
_extract_numeric_id_from_global_id('test_global_id_123') + + assert isinstance(result, int) + assert 0 <= result <= 0xFFFFFFFF + + def test_extract_numeric_id_consistency(self): + id1 = _extract_numeric_id_from_global_id('same_id') + id2 = _extract_numeric_id_from_global_id('same_id') + + assert id1 == id2 + + def test_extract_numeric_id_different_inputs(self): + id1 = _extract_numeric_id_from_global_id('id_1') + id2 = _extract_numeric_id_from_global_id('id_2') + + assert id1 != id2 + + +@pytest.mark.unit +class TestGitHubGraphQLIngestionServiceInit: + def test_init(self): + mock_app_integration = Mock() + service = GitHubGraphQLIngestionService(mock_app_integration) + + assert service.app_integration == mock_app_integration + assert service.graphql_client is None + assert service.rest_client is None + assert service.repository is None + assert service.quality_filter is not None + + +@pytest.mark.unit +class TestGetGraphQLClient: + @patch('core.services.github_graphql_ingestion.GitHubGraphQLClient') + def test_get_graphql_client_creates_new_client(self, mock_client_class): + mock_app_integration = Mock() + mock_integration = Mock() + mock_integration.credentials = '{"token": "test_token"}' + mock_app_integration.integration = mock_integration + + service = GitHubGraphQLIngestionService(mock_app_integration) + + client = service._get_graphql_client() + + assert client is not None + mock_client_class.assert_called_once_with('test_token') + + @patch('core.services.github_graphql_ingestion.GitHubGraphQLClient') + def test_get_graphql_client_reuses_cached_client(self, mock_client_class): + mock_app_integration = Mock() + mock_integration = Mock() + mock_integration.credentials = '{"token": "test_token"}' + mock_app_integration.integration = mock_integration + + service = GitHubGraphQLIngestionService(mock_app_integration) + + client1 = service._get_graphql_client() + client2 = service._get_graphql_client() + + assert client1 is client2 + 
mock_client_class.assert_called_once() + + def test_get_graphql_client_no_token(self): + mock_app_integration = Mock() + mock_integration = Mock() + mock_integration.credentials = '{}' + mock_app_integration.integration = mock_integration + + service = GitHubGraphQLIngestionService(mock_app_integration) + + with pytest.raises(ValueError, match="GitHub token not found"): + service._get_graphql_client() + + +@pytest.mark.unit +class TestGetRestClient: + @patch('core.services.github_graphql_ingestion.GitHubAPIClient') + def test_get_rest_client_creates_new_client(self, mock_client_class): + mock_app_integration = Mock() + mock_integration = Mock() + mock_integration.credentials = '{"token": "test_token"}' + mock_app_integration.integration = mock_integration + + service = GitHubGraphQLIngestionService(mock_app_integration) + + client = service._get_rest_client() + + assert client is not None + mock_client_class.assert_called_once_with('test_token') + + @patch('core.services.github_graphql_ingestion.GitHubAPIClient') + def test_get_rest_client_reuses_cached_client(self, mock_client_class): + mock_app_integration = Mock() + mock_integration = Mock() + mock_integration.credentials = '{"token": "test_token"}' + mock_app_integration.integration = mock_integration + + service = GitHubGraphQLIngestionService(mock_app_integration) + + client1 = service._get_rest_client() + client2 = service._get_rest_client() + + assert client1 is client2 + mock_client_class.assert_called_once() + + +@pytest.mark.unit +class TestParseDatetime: + def test_parse_datetime_with_z_suffix(self): + service = GitHubGraphQLIngestionService(Mock()) + + result = service._parse_datetime('2024-01-01T12:00:00Z') + + assert result is not None + assert result.tzinfo is not None + + def test_parse_datetime_without_z_suffix(self): + service = GitHubGraphQLIngestionService(Mock()) + + result = service._parse_datetime('2024-01-01T12:00:00') + + assert result is not None + + def test_parse_datetime_none(self): + 
service = GitHubGraphQLIngestionService(Mock()) + + result = service._parse_datetime(None) + + assert result is None + + def test_parse_datetime_invalid_string(self): + service = GitHubGraphQLIngestionService(Mock()) + + result = service._parse_datetime('invalid_datetime') + + assert result is not None + + +@pytest.mark.unit +class TestCreateKnowledgeBaseContent: + @patch('core.services.github_graphql_ingestion.VCRepository') + def test_create_knowledge_base_content_no_repository(self, mock_repo_class): + service = GitHubGraphQLIngestionService(Mock()) + service.repository = None + + result = service._create_knowledge_base_content() + + assert result == "" + + @patch('core.services.github_graphql_ingestion.VCRepository') + def test_create_knowledge_base_content_with_repository(self, mock_repo_class): + service = GitHubGraphQLIngestionService(Mock()) + + mock_repo = Mock() + mock_repo.full_name = 'owner/repo' + mock_repo.description = 'Test repository' + mock_repo.issues.all.return_value = [] + mock_repo.pull_requests.all.return_value = [] + service.repository = mock_repo + + result = service._create_knowledge_base_content() + + assert 'owner/repo' in result + assert 'Test repository' in result + + @patch('core.services.github_graphql_ingestion.VCRepository') + def test_create_knowledge_base_content_with_issues(self, mock_repo_class): + service = GitHubGraphQLIngestionService(Mock()) + + mock_issue = Mock() + mock_issue.number = 1 + mock_issue.title = 'Test Issue' + mock_issue.state = 'open' + mock_issue.author = 'testuser' + mock_issue.body = 'Issue body' + mock_issue.labels = ['bug', 'enhancement'] + mock_issue.comments.all.return_value = [] + + mock_repo = Mock() + mock_repo.full_name = 'owner/repo' + mock_repo.description = '' + mock_repo.issues.all.return_value = [mock_issue] + mock_repo.pull_requests.all.return_value = [] + service.repository = mock_repo + + result = service._create_knowledge_base_content() + + assert 'Issue #1' in result + assert 'Test Issue' 
in result + + +@pytest.mark.unit +class TestIngestIssueCommentFromGraphQL: + @patch('core.services.github_graphql_ingestion.VCIssueComment') + @patch('core.services.github_graphql_ingestion._extract_numeric_id_from_global_id') + def test_ingest_issue_comment_success(self, mock_extract_id, mock_comment_class): + service = GitHubGraphQLIngestionService(Mock()) + service.quality_filter.should_ingest = Mock(return_value=True) + service.quality_filter.remove_emojis = Mock(return_value='cleaned body') + + mock_issue = Mock() + comment_data = { + 'id': 'comment_id_123', + 'body': 'Test comment', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'url': 'https://github.com/test' + } + mock_extract_id.return_value = 123 + + service._ingest_issue_comment_from_graphql(mock_issue, comment_data) + + mock_comment_class.objects.update_or_create.assert_called_once() + + @patch('core.services.github_graphql_ingestion.VCIssueComment') + @patch('core.services.github_graphql_ingestion._extract_numeric_id_from_global_id') + def test_ingest_issue_comment_filtered(self, mock_extract_id, mock_comment_class): + service = GitHubGraphQLIngestionService(Mock()) + service.quality_filter.should_ingest = Mock(return_value=False) + + mock_issue = Mock() + comment_data = { + 'id': 'comment_id_123', + 'body': '👍', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'url': 'https://github.com/test' + } + + service._ingest_issue_comment_from_graphql(mock_issue, comment_data) + + mock_comment_class.objects.update_or_create.assert_not_called() + + +@pytest.mark.unit +class TestIngestPRCommentFromGraphQL: + @patch('core.services.github_graphql_ingestion.VCPRComment') + @patch('core.services.github_graphql_ingestion._extract_numeric_id_from_global_id') + def test_ingest_pr_comment_success(self, mock_extract_id, 
mock_comment_class): + service = GitHubGraphQLIngestionService(Mock()) + service.quality_filter.should_ingest = Mock(return_value=True) + service.quality_filter.remove_emojis = Mock(return_value='cleaned body') + + mock_pr = Mock() + comment_data = { + 'id': 'comment_id_123', + 'body': 'Test comment', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'url': 'https://github.com/test' + } + mock_extract_id.return_value = 123 + + service._ingest_pr_comment_from_graphql(mock_pr, comment_data) + + mock_comment_class.objects.update_or_create.assert_called_once() + + @patch('core.services.github_graphql_ingestion.VCPRComment') + @patch('core.services.github_graphql_ingestion._extract_numeric_id_from_global_id') + def test_ingest_pr_comment_filtered(self, mock_extract_id, mock_comment_class): + service = GitHubGraphQLIngestionService(Mock()) + service.quality_filter.should_ingest = Mock(return_value=False) + + mock_pr = Mock() + comment_data = { + 'id': 'comment_id_123', + 'body': '👍', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'url': 'https://github.com/test' + } + + service._ingest_pr_comment_from_graphql(mock_pr, comment_data) + + mock_comment_class.objects.update_or_create.assert_not_called() + + +@pytest.mark.unit +class TestIngestSingleIssueFromGraphQL: + @patch('core.services.github_graphql_ingestion.transaction.atomic') + @patch('core.services.github_graphql_ingestion.VCIssue') + def test_ingest_single_issue_success(self, mock_issue_class, mock_atomic): + service = GitHubGraphQLIngestionService(Mock()) + service.repository = Mock() + + issue_data = { + 'number': 1, + 'title': 'Test Issue', + 'state': 'OPEN', + 'body': 'Issue body', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'assignees': {'nodes': [{'login': 'assignee1'}]}, + 'labels': 
{'nodes': [{'name': 'bug'}]}, + 'milestone': {'title': 'v1.0'}, + 'locked': False, + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'closedAt': None, + 'url': 'https://github.com/test' + } + + mock_issue = Mock() + mock_issue_class.objects.update_or_create.return_value = (mock_issue, True) + + service._ingest_single_issue_from_graphql(issue_data) + + mock_issue_class.objects.update_or_create.assert_called_once() + + @patch('core.services.github_graphql_ingestion.transaction.atomic') + @patch('core.services.github_graphql_ingestion.VCIssue') + def test_ingest_single_issue_with_comments(self, mock_issue_class, mock_atomic): + service = GitHubGraphQLIngestionService(Mock()) + service.repository = Mock() + service._ingest_issue_comment_from_graphql = Mock() + + issue_data = { + 'number': 1, + 'title': 'Test Issue', + 'state': 'OPEN', + 'body': 'Issue body', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'assignees': {'nodes': []}, + 'labels': {'nodes': []}, + 'milestone': None, + 'locked': False, + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'closedAt': None, + 'url': 'https://github.com/test', + 'comments': { + 'edges': [ + {'node': {'id': 'comment1', 'body': 'Comment 1', 'author': {'login': 'user1'}, 'authorAssociation': 'NONE', 'createdAt': '2024-01-01T12:00:00Z', 'updatedAt': '2024-01-01T12:00:00Z', 'url': 'https://github.com/test'}} + ] + } + } + + mock_issue = Mock() + mock_issue_class.objects.update_or_create.return_value = (mock_issue, True) + + service._ingest_single_issue_from_graphql(issue_data) + + service._ingest_issue_comment_from_graphql.assert_called_once() + + +@pytest.mark.unit +class TestIngestSinglePRFromGraphQL: + @patch('core.services.github_graphql_ingestion.transaction.atomic') + @patch('core.services.github_graphql_ingestion.VCPullRequest') + def test_ingest_single_pr_success(self, mock_pr_class, mock_atomic): + service = GitHubGraphQLIngestionService(Mock()) 
+ service.repository = Mock() + + pr_data = { + 'number': 1, + 'title': 'Test PR', + 'state': 'OPEN', + 'body': 'PR body', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'assignees': {'nodes': [{'login': 'assignee1'}]}, + 'labels': {'nodes': [{'name': 'enhancement'}]}, + 'milestone': {'title': 'v1.0'}, + 'headRefName': 'feature', + 'baseRefName': 'main', + 'merged': False, + 'mergedAt': None, + 'mergeCommit': {'oid': 'abc123'}, + 'additions': 10, + 'deletions': 5, + 'changedFiles': 2, + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'closedAt': None, + 'url': 'https://github.com/test' + } + + mock_pr = Mock() + mock_pr_class.objects.update_or_create.return_value = (mock_pr, True) + + mock_rest_client = Mock() + mock_rest_client.get_pull_request_files.return_value = [] + + service._ingest_single_pull_request_from_graphql(pr_data, 'owner', 'repo', mock_rest_client) + + mock_pr_class.objects.update_or_create.assert_called_once() + + @patch('core.services.github_graphql_ingestion.transaction.atomic') + @patch('core.services.github_graphql_ingestion.VCPullRequest') + @patch('core.services.github_graphql_ingestion.VCPRFile') + def test_ingest_single_pr_with_files(self, mock_file_class, mock_pr_class, mock_atomic): + service = GitHubGraphQLIngestionService(Mock()) + service.repository = Mock() + + pr_data = { + 'number': 1, + 'title': 'Test PR', + 'state': 'OPEN', + 'body': 'PR body', + 'author': {'login': 'testuser'}, + 'authorAssociation': 'OWNER', + 'assignees': {'nodes': []}, + 'labels': {'nodes': []}, + 'milestone': None, + 'headRefName': 'feature', + 'baseRefName': 'main', + 'merged': False, + 'mergedAt': None, + 'mergeCommit': None, + 'additions': 10, + 'deletions': 5, + 'changedFiles': 2, + 'createdAt': '2024-01-01T12:00:00Z', + 'updatedAt': '2024-01-01T12:00:00Z', + 'closedAt': None, + 'url': 'https://github.com/test', + 'comments': {'edges': []} + } + + mock_pr = Mock() + 
mock_pr_class.objects.update_or_create.return_value = (mock_pr, True) + + mock_rest_client = Mock() + mock_rest_client.get_pull_request_files.return_value = [ + {'filename': 'file1.py', 'status': 'modified', 'additions': 5, 'deletions': 2, 'changes': 7, 'patch': 'diff', 'blob_url': 'url1', 'raw_url': 'url2', 'contents_url': 'url3'} + ] + + service._ingest_single_pull_request_from_graphql(pr_data, 'owner', 'repo', mock_rest_client) + + mock_file_class.objects.update_or_create.assert_called_once() diff --git a/backend/core/tests/test_services/test_ingestion.py b/backend/core/tests/test_services/test_ingestion.py new file mode 100644 index 0000000..1c85068 --- /dev/null +++ b/backend/core/tests/test_services/test_ingestion.py @@ -0,0 +1,536 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid + +from core.services.ingestion import ( + chunk_text, + embed_text, + embed_sparse, + get_chunks, + ingest_kb, + delete_vectors_from_qdrant, + _extract_content, + _clean_content, + _generate_embeddings, + _cleanup_existing_chunks, + _upsert_chunks_to_qdrant, + _finalize_ingestion, + _handle_duplicate_checks +) + + +@pytest.mark.unit +class TestChunkText: + def test_chunk_text_basic(self): + text = "This is a test text that should be chunked into smaller pieces" + chunks = chunk_text(text, chunk_size=20, overlap=5) + + assert len(chunks) > 1 + assert all(len(chunk) <= 20 for chunk in chunks) + + def test_chunk_text_shorter_than_chunk_size(self): + text = "Short text" + chunks = chunk_text(text, chunk_size=100, overlap=10) + + assert len(chunks) == 1 + assert chunks[0] == text + + def test_chunk_text_empty_string(self): + chunks = chunk_text("", chunk_size=100, overlap=10) + + assert chunks == [] + + def test_chunk_text_default_parameters(self): + text = "A" * 400 + chunks = chunk_text(text) + + assert len(chunks) > 1 + assert all(len(chunk) <= 300 for chunk in chunks) + + +@pytest.mark.unit +class TestEmbedText: + 
@patch('core.services.ingestion._generate_single_embedding') + @patch('core.models.content_hash.ContentHash') + def test_embed_text_with_cache(self, mock_content_hash, mock_generate_embedding): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_hash_instance = Mock() + mock_hash_instance.embedding = [0.1, 0.2, 0.3] + mock_content_hash.objects.get.return_value = mock_hash_instance + mock_generate_embedding.return_value = [0.4, 0.5, 0.6] + + chunks = ["chunk1", "chunk2"] + result = embed_text(chunks, mock_app) + + assert len(result) == 2 + assert result[0] == [0.1, 0.2, 0.3] + + @patch('core.services.ingestion._generate_single_embedding') + @patch('core.models.content_hash.ContentHash') + def test_embed_text_without_cache(self, mock_content_hash, mock_generate_embedding): + class MockDoesNotExist(Exception): + pass + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_content_hash.DoesNotExist = MockDoesNotExist + mock_content_hash.objects.get.side_effect = MockDoesNotExist + mock_generate_embedding.return_value = [0.1, 0.2, 0.3] + + chunks = ["chunk1"] + result = embed_text(chunks, mock_app) + + assert len(result) == 1 + assert result[0] == [0.1, 0.2, 0.3] + mock_content_hash.objects.update_or_create.assert_called_once() + + @patch('core.services.ingestion._generate_single_embedding') + @patch('core.models.content_hash.ContentHash') + def test_embed_text_empty_embedding(self, mock_content_hash, mock_generate_embedding): + class MockDoesNotExist(Exception): + pass + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_content_hash.DoesNotExist = MockDoesNotExist + mock_content_hash.objects.get.side_effect = MockDoesNotExist + mock_generate_embedding.return_value = [] + + chunks = ["chunk1"] + result = embed_text(chunks, mock_app) + + assert len(result) == 1 + assert result[0] == [] + + +@pytest.mark.unit +class TestEmbedSparse: + @patch('core.services.ingestion._get_sparse_model') + def test_embed_sparse_success(self, mock_get_model): + mock_model 
= Mock() + mock_embedding = Mock() + mock_embedding.indices.tolist.return_value = [0, 1, 2] + mock_embedding.values.tolist.return_value = [0.1, 0.2, 0.3] + mock_model.embed.return_value = [mock_embedding] + mock_get_model.return_value = mock_model + + chunks = ["chunk1", "chunk2"] + result = embed_sparse(chunks) + + assert len(result) == 1 + assert result[0].indices == [0, 1, 2] + assert result[0].values == [0.1, 0.2, 0.3] + + @patch('core.services.ingestion._get_sparse_model') + def test_embed_sparse_multiple_chunks(self, mock_get_model): + mock_model = Mock() + mock_embedding1 = Mock() + mock_embedding1.indices.tolist.return_value = [0, 1] + mock_embedding1.values.tolist.return_value = [0.1, 0.2] + mock_embedding2 = Mock() + mock_embedding2.indices.tolist.return_value = [2, 3] + mock_embedding2.values.tolist.return_value = [0.3, 0.4] + mock_model.embed.return_value = [mock_embedding1, mock_embedding2] + mock_get_model.return_value = mock_model + + chunks = ["chunk1", "chunk2"] + result = embed_sparse(chunks) + + assert len(result) == 2 + + +@pytest.mark.unit +class TestExtractContent: + def test_extract_content_success(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {'content': 'Test content'} + + result = _extract_content(mock_kb) + + assert result == 'Test content' + + def test_extract_content_no_content(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {} + + result = _extract_content(mock_kb) + + assert result is None + + def test_extract_content_empty_content(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {'content': ''} + + result = _extract_content(mock_kb) + + assert result is None + + +@pytest.mark.unit +class TestCleanContent: + @patch('core.services.ingestion._quality_filter') + def test_clean_content_success(self, mock_quality_filter): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + + 
mock_quality_filter.should_ingest.return_value = True + mock_quality_filter.remove_emojis.return_value = 'Cleaned content' + + result = _clean_content('Original content', mock_kb) + + assert result == 'Cleaned content' + + @patch('core.services.ingestion._quality_filter') + def test_clean_content_filtered(self, mock_quality_filter): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + + mock_quality_filter.should_ingest.return_value = False + + result = _clean_content('Original content', mock_kb) + + assert result is None + assert mock_kb.status == 'completed' + + @patch('core.services.ingestion._quality_filter') + def test_clean_content_empty_after_cleaning(self, mock_quality_filter): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + + mock_quality_filter.should_ingest.return_value = True + mock_quality_filter.remove_emojis.return_value = '' + + result = _clean_content('Original content', mock_kb) + + assert result is None + assert mock_kb.status == 'completed' + + +@pytest.mark.unit +class TestHandleDuplicateChecks: + @patch('core.services.ingestion._duplicate_detector') + def test_handle_duplicate_checks_url_source(self, mock_detector): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.status = 'pending' + mock_app = Mock() + + result = _handle_duplicate_checks('content', mock_kb, mock_app) + + assert result is True + + @patch('core.services.ingestion._duplicate_detector') + def test_handle_duplicate_checks_is_duplicate(self, mock_detector): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + mock_app = Mock() + + mock_detector.is_duplicate.return_value = True + + result = _handle_duplicate_checks('content', mock_kb, mock_app) + + assert result is False + assert mock_kb.status == 'duplicate' + + @patch('core.services.ingestion._duplicate_detector') 
+ def test_handle_duplicate_checks_semantic_duplicate(self, mock_detector): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + mock_app = Mock() + + mock_detector.is_duplicate.return_value = False + mock_detector.handle_semantic_duplicate.return_value = False + mock_detector._was_replacement_triggered.return_value = False + + result = _handle_duplicate_checks('content', mock_kb, mock_app) + + assert result is False + assert mock_kb.status == 'duplicate' + + @patch('core.services.ingestion._duplicate_detector') + def test_handle_duplicate_checks_replacement_triggered(self, mock_detector): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + mock_app = Mock() + + mock_detector.is_duplicate.return_value = False + mock_detector.handle_semantic_duplicate.return_value = False + mock_detector._was_replacement_triggered.return_value = True + + result = _handle_duplicate_checks('content', mock_kb, mock_app) + + assert result is False + mock_kb.delete.assert_called_once() + + +@pytest.mark.unit +class TestGenerateEmbeddings: + @patch('core.services.ingestion.embed_text') + @patch('core.services.ingestion.embed_sparse') + def test_generate_embeddings_success(self, mock_embed_sparse, mock_embed_text): + mock_app = Mock() + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + + mock_embed_text.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_sparse_vector = Mock() + mock_sparse_vector.indices = [0, 1] + mock_sparse_vector.values = [0.1, 0.2] + mock_embed_sparse.return_value = [mock_sparse_vector, mock_sparse_vector] + + chunks = ["chunk1", "chunk2"] + result = _generate_embeddings(chunks, mock_app, mock_kb) + + assert result is not None + assert len(result) == 2 + + @patch('core.services.ingestion.embed_text') + @patch('core.services.ingestion.embed_sparse') + def test_generate_embeddings_all_empty(self, mock_embed_sparse, mock_embed_text): + mock_app = Mock() + 
mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + + mock_embed_text.return_value = [[], []] + + chunks = ["chunk1", "chunk2"] + result = _generate_embeddings(chunks, mock_app, mock_kb) + + assert result is None + + +@pytest.mark.unit +class TestCleanupExistingChunks: + @patch('core.services.ingestion.qdrant') + @patch('core.services.ingestion.IngestedChunk') + def test_cleanup_existing_chunks(self, mock_chunk_class, mock_qdrant): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + + mock_chunk1 = Mock() + mock_chunk1.uuid = uuid.uuid4() + mock_chunk2 = Mock() + mock_chunk2.uuid = uuid.uuid4() + mock_queryset = Mock() + mock_queryset.__iter__ = Mock(return_value=iter([mock_chunk1, mock_chunk2])) + mock_queryset.delete = Mock() + mock_chunk_class.objects.filter.return_value = mock_queryset + + _cleanup_existing_chunks(mock_kb) + + mock_chunk_class.objects.filter.assert_called_once() + mock_queryset.delete.assert_called_once() + + +@pytest.mark.unit +class TestFinalizeIngestion: + @patch('core.services.ingestion._duplicate_detector') + def test_finalize_ingestion(self, mock_detector): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.status = 'processing' + mock_kb.source_type = 'text' + mock_app = Mock() + + _finalize_ingestion(mock_kb, mock_app, 'content', 5, 10) + + assert mock_kb.status == 'completed' + mock_detector.store_content_hash.assert_called_once() + + +@pytest.mark.unit +class TestDeleteVectorsFromQdrant: + @patch('core.services.ingestion.qdrant') + def test_delete_vectors_empty_ids(self, mock_qdrant): + delete_vectors_from_qdrant([]) + + mock_qdrant.delete.assert_not_called() + + @patch('core.services.ingestion.qdrant') + def test_delete_vectors_success(self, mock_qdrant): + ids = ['id1', 'id2'] + + delete_vectors_from_qdrant(ids) + + mock_qdrant.delete.assert_called_once() + + +@pytest.mark.unit +class TestGetChunks: + @patch('core.services.ingestion.embed_text') + @patch('core.services.ingestion._get_sparse_model') + 
@patch('core.services.ingestion.qdrant') + @patch('core.services.ingestion.IngestedChunk') + def test_get_chunks_success(self, mock_chunk_class, mock_qdrant, mock_get_model, mock_embed_text): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_embed_text.return_value = [[0.1, 0.2, 0.3]] + + mock_model = Mock() + mock_sparse = Mock() + mock_sparse.indices = [0, 1] + mock_sparse.values = [0.1, 0.2] + mock_model.embed.return_value = [mock_sparse] + mock_get_model.return_value = mock_model + + mock_point = Mock() + mock_point.id = uuid.uuid4() + mock_point.score = 0.5 + mock_result = Mock() + mock_result.points = [mock_point] + mock_qdrant.query_points.return_value = mock_result + + mock_chunk = Mock() + mock_chunk.content = 'Test content' + mock_chunk.chunk_index = 0 + mock_chunk.knowledge_base_id = uuid.uuid4() + mock_chunk_class.objects.filter.return_value.order_by.return_value = [mock_chunk] + + result = get_chunks('query', mock_app) + + assert len(result) == 1 + assert result[0]['content'] == 'Test content' + + @patch('core.services.ingestion.embed_text') + def test_get_chunks_empty_embedding(self, mock_embed_text): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_embed_text.return_value = [[]] + + result = get_chunks('query', mock_app) + + assert result == [] + + @patch('core.services.ingestion.embed_text') + @patch('core.services.ingestion._get_sparse_model') + @patch('core.services.ingestion.qdrant') + @patch('core.services.ingestion.IngestedChunk') + def test_get_chunks_low_score(self, mock_chunk_class, mock_qdrant, mock_get_model, mock_embed_text): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_embed_text.return_value = [[0.1, 0.2, 0.3]] + + mock_model = Mock() + mock_sparse = Mock() + mock_sparse.indices = [0, 1] + mock_sparse.values = [0.1, 0.2] + mock_model.embed.return_value = [mock_sparse] + mock_get_model.return_value = mock_model + + mock_point = Mock() + mock_point.id = uuid.uuid4() + mock_point.score = 0.2 + 
mock_result = Mock() + mock_result.points = [mock_point] + mock_qdrant.query_points.return_value = mock_result + + result = get_chunks('query', mock_app) + + assert result == [] + + +@pytest.mark.unit +class TestIngestKB: + @patch('core.services.ingestion._finalize_ingestion') + @patch('core.services.ingestion._upsert_chunks_to_qdrant') + @patch('core.services.ingestion._cleanup_existing_chunks') + @patch('core.services.ingestion._generate_embeddings') + @patch('core.services.ingestion.chunk_text') + @patch('core.services.ingestion._handle_duplicate_checks') + @patch('core.services.ingestion._clean_content') + @patch('core.services.ingestion._extract_content') + def test_ingest_kb_success(self, mock_extract, mock_clean, mock_duplicate, mock_chunk, + mock_generate, mock_cleanup, mock_upsert, mock_finalize): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + mock_kb.metadata = {'content': 'Test content'} + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_extract.return_value = 'content' + mock_clean.return_value = 'cleaned content' + mock_duplicate.return_value = True + mock_chunk.return_value = ['chunk1', 'chunk2'] + mock_generate.return_value = ([[0.1, 0.2], [0.3, 0.4]], [Mock(), Mock()]) + mock_upsert.return_value = 2 + + ingest_kb(mock_kb, mock_app) + + assert mock_finalize.called + + @patch('core.services.ingestion._extract_content') + def test_ingest_kb_no_content(self, mock_extract): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {} + mock_app = Mock() + + mock_extract.return_value = None + + ingest_kb(mock_kb, mock_app) + + mock_extract.assert_called_once() + + @patch('core.services.ingestion._clean_content') + @patch('core.services.ingestion._extract_content') + @patch('core.services.ingestion._handle_duplicate_checks') + def test_ingest_kb_content_filtered(self, mock_duplicate, mock_extract, mock_clean): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + 
mock_kb.source_type = 'text' + mock_kb.status = 'pending' + mock_kb.metadata = {'content': 'content'} + mock_app = Mock() + + mock_extract.return_value = 'content' + + def clean_side_effect(content, kb): + kb.status = 'completed' + return None + mock_clean.side_effect = clean_side_effect + + ingest_kb(mock_kb, mock_app) + + assert mock_kb.status == 'completed' + + @patch('core.services.ingestion._handle_duplicate_checks') + @patch('core.services.ingestion._clean_content') + @patch('core.services.ingestion._extract_content') + def test_ingest_kb_duplicate(self, mock_extract, mock_clean, mock_duplicate): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + mock_kb.status = 'pending' + mock_kb.metadata = {'content': 'content'} + mock_app = Mock() + + mock_extract.return_value = 'content' + mock_clean.return_value = 'cleaned content' + mock_duplicate.return_value = False + + ingest_kb(mock_kb, mock_app) + + mock_duplicate.assert_called_once() diff --git a/backend/core/tests/test_services/test_kb_utils.py b/backend/core/tests/test_services/test_kb_utils.py new file mode 100644 index 0000000..f9260e8 --- /dev/null +++ b/backend/core/tests/test_services/test_kb_utils.py @@ -0,0 +1,317 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid + +from core.services.kb_utils import ( + create_kb_records, + parse_kb_from_request, + format_text_uri +) + + +@pytest.mark.unit +class TestFormatTextUri: + def test_format_text_uri(self): + result = format_text_uri("This is a test text value") + + assert result == "text://This is a test text value" + + def test_format_text_uri_long_text(self): + long_text = "A" * 100 + result = format_text_uri(long_text) + + assert result == f"text://{'A' * 50}" + assert len(result) == len("text://") + 50 + + def test_format_text_uri_empty_string(self): + result = format_text_uri("") + + assert result == "text://" + + +@pytest.mark.unit +class TestCreateKBRecords: + 
@patch('core.services.kb_utils.KnowledgeBase')
@patch('core.services.kb_utils.default_storage')
def test_create_kb_records_file_type(self, mock_storage, mock_kb_class):
    """A 'file' item is saved through default_storage and bulk-created once.

    Checks that the upload is stored under its own ``name`` and that
    ``KnowledgeBase.objects.bulk_create`` is invoked exactly once.
    """
    mock_app = Mock()
    mock_app.uuid = uuid.uuid4()

    mock_file = Mock()
    mock_file.name = "test_file.txt"

    mock_storage.save.return_value = "saved_test_file.txt"
    mock_kb_class.return_value = Mock()
    mock_kb_class.objects.bulk_create = Mock()
    # Subscripting the filter(...).order_by(...) result yields one record.
    # NOTE(review): presumably create_kb_records slices this queryset to
    # re-read the rows it just created -- confirm against kb_utils.
    fetched = mock_kb_class.objects.filter.return_value.order_by.return_value
    fetched.__getitem__ = Mock(side_effect=lambda x: [Mock()])

    items = [
        {
            'type': 'file',
            'file': mock_file,
        }
    ]

    # The return value is intentionally not bound: this test only verifies
    # the storage and bulk_create interactions.
    create_kb_records(mock_app, items)

    mock_storage.save.assert_called_once_with("test_file.txt", mock_file)
    mock_kb_class.objects.bulk_create.assert_called_once()


@patch('core.services.kb_utils.KnowledgeBase')
def test_create_kb_records_text_type(self, mock_kb_class):
    """A single 'text' item triggers exactly one bulk_create call."""
    mock_app = Mock()
    mock_app.uuid = uuid.uuid4()

    mock_kb_class.return_value = Mock()
    mock_kb_class.objects.bulk_create = Mock()
    # Subscripting the filter(...).order_by(...) result yields one record.
    fetched = mock_kb_class.objects.filter.return_value.order_by.return_value
    fetched.__getitem__ = Mock(side_effect=lambda x: [Mock()])

    items = [
        {
            'type': 'text',
            'value': 'Test text content',
        }
    ]

    # Only the bulk_create interaction is verified; the result is unused.
    create_kb_records(mock_app, items)

    mock_kb_class.objects.bulk_create.assert_called_once()


@patch('core.services.kb_utils.KnowledgeBase')
def test_create_kb_records_url_type(self, mock_kb_class):
    """A single 'url' item triggers exactly one bulk_create call."""
    mock_app = Mock()
    mock_app.uuid = uuid.uuid4()

    mock_kb_class.return_value = Mock()
    mock_kb_class.objects.bulk_create = Mock()
    # Subscripting the filter(...).order_by(...) result yields one record.
    fetched = mock_kb_class.objects.filter.return_value.order_by.return_value
    fetched.__getitem__ = Mock(side_effect=lambda x: [Mock()])

    items = [
        {
            'type': 'url',
            'value': 'https://example.com',
        }
    ]

    # Only the bulk_create interaction is verified; the result is unused.
    create_kb_records(mock_app, items)

    mock_kb_class.objects.bulk_create.assert_called_once()

@patch('core.services.kb_utils.KnowledgeBase') + def test_create_kb_records_url_with_crawling_config(self, mock_kb_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_kb_instance = Mock() + mock_kb_class.return_value = mock_kb_instance + mock_kb_class.objects.bulk_create = Mock() + mock_kb_class.objects.filter.return_value.order_by.return_value.__getitem__ = Mock(side_effect=lambda x: [Mock()]) + + items = [ + { + 'type': 'url', + 'value': 'https://example.com', + 'crawling_config': { + 'enable_crawling': True, + 'max_depth': 3, + 'max_pages': 100 + } + } + ] + + result = create_kb_records(mock_app, items) + + mock_kb_class.objects.bulk_create.assert_called_once() + + @patch('core.services.kb_utils.KnowledgeBase') + @patch('core.services.kb_utils.default_storage') + def test_create_kb_records_multiple_items(self, mock_storage, mock_kb_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_file = Mock() + mock_file.name = "test_file.txt" + mock_storage.save.return_value = "saved_test_file.txt" + + mock_kb_instance = Mock() + mock_kb_class.return_value = mock_kb_instance + mock_kb_class.objects.bulk_create = Mock() + mock_kb_class.objects.filter.return_value.order_by.return_value.__getitem__ = Mock(side_effect=lambda x: [Mock(), Mock()]) + + items = [ + { + 'type': 'text', + 'value': 'Test text' + }, + { + 'type': 'url', + 'value': 'https://example.com' + } + ] + + result = create_kb_records(mock_app, items) + + assert mock_kb_class.call_count == 2 + + @patch('core.services.kb_utils.KnowledgeBase') + @patch('core.services.kb_utils.default_storage') + def test_create_kb_records_file_without_file(self, mock_storage, mock_kb_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_kb_instance = Mock() + mock_kb_class.return_value = mock_kb_instance + mock_kb_class.objects.bulk_create = Mock() + mock_kb_class.objects.filter.return_value.order_by.return_value.__getitem__ = Mock(side_effect=lambda x: []) + + items = [ + { + 'type': 
'file', + 'file': None + } + ] + + result = create_kb_records(mock_app, items) + + assert mock_kb_class.call_count == 0 + + @patch('core.services.kb_utils.KnowledgeBase') + def test_create_kb_records_empty_items(self, mock_kb_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_kb_class.objects.bulk_create = Mock() + mock_kb_class.objects.filter.return_value.order_by.return_value.__getitem__ = Mock(side_effect=lambda x: []) + + items = [] + + result = create_kb_records(mock_app, items) + + mock_kb_class.objects.bulk_create.assert_called_once_with([]) + + +@pytest.mark.unit +class TestParseKBFromRequest: + def test_parse_kb_from_request_text_item(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'text', + 'items[0].value': 'Test text content' + } + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert len(result) == 1 + assert result[0]['type'] == 'text' + assert result[0]['value'] == 'Test text content' + + def test_parse_kb_from_request_url_item(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'url', + 'items[0].value': 'https://example.com' + } + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert len(result) == 1 + assert result[0]['type'] == 'url' + assert result[0]['value'] == 'https://example.com' + assert result[0]['crawling_config'] is None + + def test_parse_kb_from_request_url_with_crawling(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'url', + 'items[0].value': 'https://example.com', + 'items[0].crawling_config.enable_crawling': 'true', + 'items[0].crawling_config.max_depth': '3', + 'items[0].crawling_config.max_pages': '100' + } + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert len(result) == 1 + assert result[0]['type'] == 'url' + assert result[0]['crawling_config']['enable_crawling'] is True + assert result[0]['crawling_config']['max_depth'] == 3 + assert 
result[0]['crawling_config']['max_pages'] == 100 + + def test_parse_kb_from_request_file_item(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'file', + 'items[0].value': '' + } + mock_file = Mock() + mock_file.name = 'test_file.txt' + mock_request.FILES = { + 'items[0].file': mock_file + } + + result = parse_kb_from_request(mock_request) + + assert len(result) == 1 + assert result[0]['type'] == 'file' + assert result[0]['file'] == mock_file + + def test_parse_kb_from_request_multiple_items(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'text', + 'items[0].value': 'Test text', + 'items[1].type': 'url', + 'items[1].value': 'https://example.com' + } + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert len(result) == 2 + assert result[0]['type'] == 'text' + assert result[1]['type'] == 'url' + + def test_parse_kb_from_request_empty(self): + mock_request = Mock() + mock_request.data = {} + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert result == [] + + def test_parse_kb_from_request_url_crawling_disabled(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'url', + 'items[0].value': 'https://example.com', + 'items[0].crawling_config.enable_crawling': 'false' + } + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert len(result) == 1 + assert result[0]['crawling_config'] is None + + def test_parse_kb_from_request_url_crawling_defaults(self): + mock_request = Mock() + mock_request.data = { + 'items[0].type': 'url', + 'items[0].value': 'https://example.com', + 'items[0].crawling_config.enable_crawling': 'true' + } + mock_request.FILES = {} + + result = parse_kb_from_request(mock_request) + + assert len(result) == 1 + assert result[0]['crawling_config']['enable_crawling'] is True + assert result[0]['crawling_config']['max_depth'] == 1 # default + assert result[0]['crawling_config']['max_pages'] == 50 # 
default diff --git a/backend/core/tests/test_services/test_notifications.py b/backend/core/tests/test_services/test_notifications.py new file mode 100644 index 0000000..e81b5eb --- /dev/null +++ b/backend/core/tests/test_services/test_notifications.py @@ -0,0 +1,208 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid + +from core.services.notifications import ( + find_channels, + notify_users +) + + +@pytest.mark.unit +class TestFindChannels: + @patch('core.models.AppNotificationProfile') + def test_find_channels_success(self, mock_app_profile_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_notification_profile = Mock() + mock_notification_profile.type = 'email' + mock_notification_profile.config = {'email': 'test@example.com'} + + mock_app_profile = Mock() + mock_app_profile.notification_profile = mock_notification_profile + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [mock_app_profile] + mock_app_profile_class.objects = mock_queryset + + result = find_channels(mock_app) + + assert len(result) == 1 + assert result[0] == mock_notification_profile + + @patch('core.models.AppNotificationProfile') + def test_find_channels_multiple(self, mock_app_profile_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_notification_profile1 = Mock() + mock_notification_profile1.type = 'email' + mock_notification_profile1.config = {'email': 'test@example.com'} + + mock_notification_profile2 = Mock() + mock_notification_profile2.type = 'slack' + mock_notification_profile2.config = {'webhook': 'https://slack.com/webhook'} + + mock_app_profile1 = Mock() + mock_app_profile1.notification_profile = mock_notification_profile1 + + mock_app_profile2 = Mock() + mock_app_profile2.notification_profile = mock_notification_profile2 + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [mock_app_profile1, mock_app_profile2] + 
mock_app_profile_class.objects = mock_queryset + + result = find_channels(mock_app) + + assert len(result) == 2 + assert result[0] == mock_notification_profile1 + assert result[1] == mock_notification_profile2 + + @patch('core.models.AppNotificationProfile') + def test_find_channels_empty(self, mock_app_profile_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [] + mock_app_profile_class.objects = mock_queryset + + result = find_channels(mock_app) + + assert result == [] + + @patch('core.models.AppNotificationProfile') + def test_find_channels_select_related_filter(self, mock_app_profile_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_notification_profile = Mock() + mock_app_profile = Mock() + mock_app_profile.notification_profile = mock_notification_profile + + mock_queryset = Mock() + mock_queryset.select_related.return_value.filter.return_value = [mock_app_profile] + mock_app_profile_class.objects = mock_queryset + + find_channels(mock_app) + + mock_queryset.select_related.assert_called_once_with("notification_profile") + mock_queryset.select_related.return_value.filter.assert_called_once_with(application=mock_app) + + +@pytest.mark.unit +class TestNotifyUsers: + @patch('core.services.notifications.send_notification_task') + @patch('core.services.notifications.find_channels') + @patch('core.services.notifications.render_template') + def test_notify_users_success(self, mock_render_template, mock_find_channels, mock_send_task): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_channel = Mock() + mock_channel.type = 'email' + mock_channel.config = {'email': 'test@example.com'} + mock_find_channels.return_value = [mock_channel] + + mock_render_template.return_value = 'Rendered message' + + mock_send_task.delay = Mock() + + notify_users(mock_app, 'template_str', {'key': 'value'}) + + 
mock_render_template.assert_called_once_with('template_str', {'key': 'value'}) + mock_find_channels.assert_called_once_with(mock_app) + mock_send_task.delay.assert_called_once() + + @patch('core.services.notifications.send_notification_task') + @patch('core.services.notifications.find_channels') + @patch('core.services.notifications.render_template') + def test_notify_users_multiple_channels(self, mock_render_template, mock_find_channels, mock_send_task): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_channel1 = Mock() + mock_channel1.type = 'email' + mock_channel1.config = {'email': 'test@example.com'} + + mock_channel2 = Mock() + mock_channel2.type = 'slack' + mock_channel2.config = {'webhook': 'https://slack.com/webhook'} + + mock_find_channels.return_value = [mock_channel1, mock_channel2] + + mock_render_template.return_value = 'Rendered message' + + mock_send_task.delay = Mock() + + notify_users(mock_app, 'template_str', {'key': 'value'}) + + assert mock_send_task.delay.call_count == 2 + + @patch('core.services.notifications.find_channels') + @patch('core.services.notifications.render_template') + def test_notify_users_no_channels(self, mock_render_template, mock_find_channels, capsys): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_find_channels.return_value = [] + + mock_render_template.return_value = 'Rendered message' + + notify_users(mock_app, 'template_str', {'key': 'value'}) + + captured = capsys.readouterr() + assert 'No notification channels found' in captured.out + mock_render_template.assert_called_once_with('template_str', {'key': 'value'}) + + @patch('core.services.notifications.send_notification_task') + @patch('core.services.notifications.find_channels') + @patch('core.services.notifications.render_template') + def test_notify_users_channel_data_format(self, mock_render_template, mock_find_channels, mock_send_task): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_channel = Mock() + mock_channel.type = 
'email' + mock_channel.config = {'email': 'test@example.com'} + mock_find_channels.return_value = [mock_channel] + + mock_render_template.return_value = 'Rendered message' + + mock_send_task.delay = Mock() + + notify_users(mock_app, 'template_str', {'key': 'value'}) + + expected_channel_data = { + 'type': 'email', + 'config': {'email': 'test@example.com'} + } + mock_send_task.delay.assert_called_once_with(expected_channel_data, 'Rendered message') + + @patch('core.services.notifications.send_notification_task') + @patch('core.services.notifications.find_channels') + @patch('core.services.notifications.render_template') + def test_notify_users_with_context(self, mock_render_template, mock_find_channels, mock_send_task): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + + mock_channel = Mock() + mock_channel.type = 'slack' + mock_channel.config = {'webhook': 'https://slack.com/webhook'} + mock_find_channels.return_value = [mock_channel] + + mock_render_template.return_value = 'Rendered message with context' + + mock_send_task.delay = Mock() + + context = { + 'user_name': 'John', + 'action': 'created' + } + + notify_users(mock_app, 'user_action_template', context) + + mock_render_template.assert_called_once_with('user_action_template', context) + mock_send_task.delay.assert_called_once() diff --git a/backend/core/tests/test_services/test_private_key_encryption.py b/backend/core/tests/test_services/test_private_key_encryption.py new file mode 100644 index 0000000..4170bf8 --- /dev/null +++ b/backend/core/tests/test_services/test_private_key_encryption.py @@ -0,0 +1,85 @@ +import pytest +from unittest.mock import Mock, patch +import base64 + +from core.services.private_key_encryption import decrypt_with_private_key + + +@pytest.mark.unit +class TestDecryptWithPrivateKey: + @patch('core.services.private_key_encryption.PRIVATE_KEY_PEM') + @patch('core.services.private_key_encryption.serialization.load_pem_private_key') + def test_decrypt_with_private_key_success(self, 
mock_load_key, mock_private_key): + mock_private_key.return_value = '-----BEGIN PRIVATE KEY-----\ntest_key\n-----END PRIVATE KEY-----' + + mock_key_instance = Mock() + mock_key_instance.decrypt.return_value = b'decrypted_data' + mock_load_key.return_value = mock_key_instance + + encrypted_data = base64.b64encode(b'encrypted_bytes').decode() + result = decrypt_with_private_key(encrypted_data) + + assert result == 'decrypted_data' + mock_key_instance.decrypt.assert_called_once() + + @patch('core.services.private_key_encryption.PRIVATE_KEY_PEM', None) + def test_decrypt_with_private_key_missing_key(self): + encrypted_data = base64.b64encode(b'encrypted_bytes').decode() + + with pytest.raises(ValueError, match="Missing PRIVATE_KEY in environment"): + decrypt_with_private_key(encrypted_data) + + @patch('core.services.private_key_encryption.PRIVATE_KEY_PEM') + @patch('core.services.private_key_encryption.serialization.load_pem_private_key') + def test_decrypt_with_private_key_invalid_base64(self, mock_load_key, mock_private_key): + mock_private_key.return_value = '-----BEGIN PRIVATE KEY-----\ntest_key\n-----END PRIVATE KEY-----' + + mock_key_instance = Mock() + mock_load_key.return_value = mock_key_instance + + with pytest.raises(Exception): + decrypt_with_private_key('invalid_base64!') + + @patch('core.services.private_key_encryption.PRIVATE_KEY_PEM') + @patch('core.services.private_key_encryption.serialization.load_pem_private_key') + def test_decrypt_with_private_key_empty_string(self, mock_load_key, mock_private_key): + mock_private_key.return_value = '-----BEGIN PRIVATE KEY-----\ntest_key\n-----END PRIVATE KEY-----' + + mock_key_instance = Mock() + mock_key_instance.decrypt.return_value = b'' + mock_load_key.return_value = mock_key_instance + + encrypted_data = base64.b64encode(b'encrypted_bytes').decode() + result = decrypt_with_private_key(encrypted_data) + + assert result == '' + + @patch('core.services.private_key_encryption.PRIVATE_KEY_PEM') + 
@patch('core.services.private_key_encryption.serialization.load_pem_private_key') + def test_decrypt_with_private_key_unicode_content(self, mock_load_key, mock_private_key): + mock_private_key.return_value = '-----BEGIN PRIVATE KEY-----\ntest_key\n-----END PRIVATE KEY-----' + + mock_key_instance = Mock() + mock_key_instance.decrypt.return_value = 'Hello there'.encode('utf-8') + mock_load_key.return_value = mock_key_instance + + encrypted_data = base64.b64encode(b'encrypted_bytes').decode() + result = decrypt_with_private_key(encrypted_data) + + assert result == 'Hello there' + + @patch('core.services.private_key_encryption.PRIVATE_KEY_PEM') + @patch('core.services.private_key_encryption.serialization.load_pem_private_key') + def test_decrypt_with_private_key_long_content(self, mock_load_key, mock_private_key): + mock_private_key.return_value = '-----BEGIN PRIVATE KEY-----\ntest_key\n-----END PRIVATE KEY-----' + + long_content = 'a' * 1000 + mock_key_instance = Mock() + mock_key_instance.decrypt.return_value = long_content.encode('utf-8') + mock_load_key.return_value = mock_key_instance + + encrypted_data = base64.b64encode(b'encrypted_bytes').decode() + result = decrypt_with_private_key(encrypted_data) + + assert result == long_content + assert len(result) == 1000 diff --git a/backend/core/tests/test_services/test_template_loader.py b/backend/core/tests/test_services/test_template_loader.py new file mode 100644 index 0000000..6dd6f5f --- /dev/null +++ b/backend/core/tests/test_services/test_template_loader.py @@ -0,0 +1,183 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock + +from core.services.template_loader import TemplateLoader + + +@pytest.fixture(autouse=True) +def reset_template_loader_cache(): + yield + TemplateLoader._env = None + TemplateLoader._templates_cache = {} + + +@pytest.mark.unit +class TestGetEnvironment: + @patch('core.services.template_loader.Environment') + @patch('core.services.template_loader.settings') + def 
test_get_environment_first_call(self, mock_settings, mock_env_class): + mock_settings.JINJA_TEMPLATE_DIR = '/templates' + mock_env_instance = Mock() + mock_env_class.return_value = mock_env_instance + + result = TemplateLoader.get_environment() + + assert mock_env_class.called + assert result == mock_env_instance + + @patch('core.services.template_loader.Environment') + @patch('core.services.template_loader.settings') + def test_get_environment_cached(self, mock_settings, mock_env_class): + mock_settings.JINJA_TEMPLATE_DIR = '/templates' + mock_env_instance = Mock() + mock_env_class.return_value = mock_env_instance + + result1 = TemplateLoader.get_environment() + result2 = TemplateLoader.get_environment() + + assert mock_env_class.call_count == 1 + assert result1 == result2 + + @patch('core.services.template_loader.Environment') + def test_get_environment_custom_dir(self, mock_env_class): + mock_env_instance = Mock() + mock_env_class.return_value = mock_env_instance + + result = TemplateLoader.get_environment(template_dir='/custom/templates') + + assert mock_env_class.called + assert result == mock_env_instance + + @patch('core.services.template_loader.Environment') + @patch('core.services.template_loader.settings') + def test_get_environment_default_dir(self, mock_settings, mock_env_class): + mock_settings.JINJA_TEMPLATE_DIR = '/default/templates' + mock_env_instance = Mock() + mock_env_class.return_value = mock_env_instance + + result = TemplateLoader.get_environment() + + assert mock_env_class.called + + +@pytest.mark.unit +class TestGetTemplate: + @patch('core.services.template_loader.TemplateLoader.get_environment') + def test_get_template_first_call(self, mock_get_env): + mock_env = Mock() + mock_template = Mock() + mock_env.get_template.return_value = mock_template + mock_get_env.return_value = mock_env + + result = TemplateLoader.get_template('test_template.html') + + mock_get_env.assert_called_once_with(None) + 
mock_env.get_template.assert_called_once_with('test_template.html') + assert result == mock_template + + @patch('core.services.template_loader.TemplateLoader.get_environment') + def test_get_template_cached(self, mock_get_env): + mock_env = Mock() + mock_template = Mock() + mock_env.get_template.return_value = mock_template + mock_get_env.return_value = mock_env + + result1 = TemplateLoader.get_template('test_template.html') + result2 = TemplateLoader.get_template('test_template.html') + + assert mock_get_env.call_count == 1 + assert mock_env.get_template.call_count == 1 + assert result1 == result2 + + @patch('core.services.template_loader.TemplateLoader.get_environment') + def test_get_template_custom_dir(self, mock_get_env): + mock_env = Mock() + mock_template = Mock() + mock_env.get_template.return_value = mock_template + mock_get_env.return_value = mock_env + + result = TemplateLoader.get_template('test_template.html', template_dir='/custom') + + mock_get_env.assert_called_once_with('/custom') + mock_env.get_template.assert_called_once_with('test_template.html') + assert result == mock_template + + @patch('core.services.template_loader.TemplateLoader.get_environment') + def test_get_template_different_dirs_separate_cache(self, mock_get_env): + mock_env = Mock() + mock_template1 = Mock() + mock_template2 = Mock() + mock_env.get_template.side_effect = [mock_template1, mock_template2] + mock_get_env.return_value = mock_env + + result1 = TemplateLoader.get_template('test_template.html', template_dir='/dir1') + result2 = TemplateLoader.get_template('test_template.html', template_dir='/dir2') + + assert mock_get_env.call_count == 2 + assert mock_env.get_template.call_count == 2 + + @patch('core.services.template_loader.TemplateLoader.get_environment') + def test_get_template_different_names_separate_cache(self, mock_get_env): + mock_env = Mock() + mock_template1 = Mock() + mock_template2 = Mock() + mock_env.get_template.side_effect = [mock_template1, mock_template2] 
+ mock_get_env.return_value = mock_env + + result1 = TemplateLoader.get_template('template1.html') + result2 = TemplateLoader.get_template('template2.html') + + assert mock_get_env.call_count == 2 + assert mock_env.get_template.call_count == 2 + + +@pytest.mark.unit +class TestRenderTemplate: + @patch('core.services.template_loader.TemplateLoader.get_template') + def test_render_template_success(self, mock_get_template): + mock_template = Mock() + mock_template.render.return_value = 'Rendered content' + mock_get_template.return_value = mock_template + + context = {'name': 'John', 'action': 'created'} + result = TemplateLoader.render_template('test_template.html', context) + + mock_get_template.assert_called_once_with('test_template.html', None) + mock_template.render.assert_called_once_with(name='John', action='created') + assert result == 'Rendered content' + + @patch('core.services.template_loader.TemplateLoader.get_template') + def test_render_template_with_custom_dir(self, mock_get_template): + mock_template = Mock() + mock_template.render.return_value = 'Rendered content' + mock_get_template.return_value = mock_template + + context = {'key': 'value'} + result = TemplateLoader.render_template('test_template.html', context, template_dir='/custom') + + mock_get_template.assert_called_once_with('test_template.html', '/custom') + mock_template.render.assert_called_once_with(key='value') + assert result == 'Rendered content' + + @patch('core.services.template_loader.TemplateLoader.get_template') + def test_render_template_empty_context(self, mock_get_template): + mock_template = Mock() + mock_template.render.return_value = 'Rendered content' + mock_get_template.return_value = mock_template + + result = TemplateLoader.render_template('test_template.html', {}) + + mock_template.render.assert_called_once_with() + assert result == 'Rendered content' + + @patch('core.services.template_loader.TemplateLoader.get_template') + def test_render_template_uses_cache(self, 
mock_get_template): + mock_template = Mock() + mock_template.render.return_value = 'Rendered content' + mock_get_template.return_value = mock_template + + context = {'key': 'value'} + TemplateLoader.render_template('test_template.html', context) + TemplateLoader.render_template('test_template.html', context) + + assert mock_get_template.call_count == 2 diff --git a/backend/core/tests/test_services/test_tool_call_executor.py b/backend/core/tests/test_services/test_tool_call_executor.py new file mode 100644 index 0000000..402ee94 --- /dev/null +++ b/backend/core/tests/test_services/test_tool_call_executor.py @@ -0,0 +1,305 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid + +from core.services.tool_call_executor import ( + _extract_url_from_schema, + _get_tool_url, + ToolCallExecutor, + _monotonic_ns, + _elapsed_ms +) + + +@pytest.mark.unit +class TestExtractUrlFromSchema: + def test_extract_url_from_schema_with_url(self): + schema = 'https://example.com/api/endpoint' + result = _extract_url_from_schema(schema) + + assert result == 'https://example.com/api/endpoint' + + def test_extract_url_from_schema_with_text(self): + schema = 'The URL is https://example.com/api and some text' + result = _extract_url_from_schema(schema) + + assert result == 'https://example.com/api' + + def test_extract_url_from_schema_http(self): + schema = 'http://example.com' + result = _extract_url_from_schema(schema) + + assert result == 'http://example.com' + + def test_extract_url_from_schema_no_url(self): + schema = 'No URL here' + result = _extract_url_from_schema(schema) + + assert result == '' + + def test_extract_url_from_schema_empty_string(self): + result = _extract_url_from_schema('') + + assert result == '' + + def test_extract_url_from_schema_none(self): + result = _extract_url_from_schema(None) + + assert result == '' + + +@pytest.mark.unit +class TestGetToolUrl: + @patch('core.services.tool_call_executor._extract_url_from_schema') + 
@patch('core.models.ToolConfig') + @patch('core.models.AppIntegration') + @patch('core.integrations.custom_tool_parser._derive_name') + def test_get_tool_url_success(self, mock_derive_name, mock_app_integration_class, mock_tool_config_class, mock_extract_url): + mock_app_uuid = str(uuid.uuid4()) + mock_tool_name = 'test_tool' + + mock_derive_name.return_value = 'test_tool' + mock_extract_url.return_value = 'https://example.com/api' + + mock_tc = Mock() + mock_tc.url_schema = 'https://example.com/api' + mock_tc.title = 'Test Tool' + mock_tool_config_class.objects.filter.return_value.first.return_value = mock_tc + + mock_ai = Mock() + mock_ai.integration = Mock() + mock_app_integration_class.objects.filter.return_value.select_related.return_value = [mock_ai] + + result = _get_tool_url(mock_app_uuid, mock_tool_name) + + assert result == 'https://example.com/api' + mock_app_integration_class.objects.filter.assert_called_once_with( + application__uuid=mock_app_uuid, is_active=True + ) + + @patch('core.services.tool_call_executor._extract_url_from_schema') + @patch('core.models.ToolConfig') + @patch('core.models.AppIntegration') + @patch('core.integrations.custom_tool_parser._derive_name') + def test_get_tool_url_name_mismatch(self, mock_derive_name, mock_app_integration_class, mock_tool_config_class, mock_extract_url): + mock_app_uuid = str(uuid.uuid4()) + mock_tool_name = 'test_tool' + + mock_derive_name.return_value = 'different_tool' + + mock_tc = Mock() + mock_tc.url_schema = 'https://example.com/api' + mock_tc.title = 'Different Tool' + mock_tool_config_class.objects.filter.return_value.first.return_value = mock_tc + + mock_ai = Mock() + mock_ai.integration = Mock() + mock_app_integration_class.objects.filter.return_value.select_related.return_value = [mock_ai] + + result = _get_tool_url(mock_app_uuid, mock_tool_name) + + assert result == '' + + @patch('core.models.ToolConfig') + @patch('core.models.AppIntegration') + def test_get_tool_url_no_tool_config(self, 
mock_app_integration_class, mock_tool_config_class): + mock_app_uuid = str(uuid.uuid4()) + mock_tool_name = 'test_tool' + + mock_tool_config_class.objects.filter.return_value.first.return_value = None + + mock_ai = Mock() + mock_ai.integration = Mock() + mock_app_integration_class.objects.filter.return_value.select_related.return_value = [mock_ai] + + result = _get_tool_url(mock_app_uuid, mock_tool_name) + + assert result == '' + + @patch('core.models.AppIntegration') + def test_get_tool_url_no_integrations(self, mock_app_integration_class): + mock_app_uuid = str(uuid.uuid4()) + mock_tool_name = 'test_tool' + + mock_app_integration_class.objects.filter.return_value.select_related.return_value = [] + + result = _get_tool_url(mock_app_uuid, mock_tool_name) + + assert result == '' + + @patch('core.models.AppIntegration') + def test_get_tool_url_exception(self, mock_app_integration_class): + mock_app_uuid = str(uuid.uuid4()) + mock_tool_name = 'test_tool' + + mock_app_integration_class.objects.filter.side_effect = Exception('DB Error') + + result = _get_tool_url(mock_app_uuid, mock_tool_name) + + assert result == '' + + +@pytest.mark.unit +class TestToolCallExecutor: + @patch('core.services.tool_call_executor.execute_tool_call') + @patch('core.services.tool_call_executor._get_tool_url') + @patch('core.services.tool_call_executor._monotonic_ns') + @patch('core.services.tool_call_executor._elapsed_ms') + def test_execute_all_success(self, mock_elapsed_ms, mock_monotonic_ns, mock_get_tool_url, mock_execute_tool_call): + mock_app_uuid = str(uuid.uuid4()) + mock_monotonic_ns.return_value = 1000000000 + mock_elapsed_ms.return_value = 100 + mock_get_tool_url.return_value = 'https://example.com/api' + mock_execute_tool_call.return_value = {'result': 'success'} + + tool_calls = [ + { + 'name': 'test_tool', + 'args': {'param': 'value'}, + 'id': 'call_123' + } + ] + + executor = ToolCallExecutor() + records, messages = executor.execute_all(mock_app_uuid, tool_calls) + + assert 
len(records) == 1 + assert len(messages) == 1 + assert records[0]['name'] == 'test_tool' + assert records[0]['input_parameters'] == {'param': 'value'} + assert records[0]['raw_result'] == {'result': 'success'} + assert records[0]['duration_ms'] == 100 + assert messages[0]['role'] == 'tool' + assert messages[0]['tool_call_id'] == 'call_123' + assert messages[0]['content'] == "{'result': 'success'}" + + @patch('core.services.tool_call_executor.execute_tool_call') + @patch('core.services.tool_call_executor._get_tool_url') + @patch('core.services.tool_call_executor._monotonic_ns') + @patch('core.services.tool_call_executor._elapsed_ms') + def test_execute_all_error(self, mock_elapsed_ms, mock_monotonic_ns, mock_get_tool_url, mock_execute_tool_call): + mock_app_uuid = str(uuid.uuid4()) + mock_monotonic_ns.return_value = 1000000000 + mock_elapsed_ms.return_value = 100 + mock_get_tool_url.return_value = 'https://example.com/api' + mock_execute_tool_call.side_effect = Exception('Tool failed') + + tool_calls = [ + { + 'name': 'test_tool', + 'args': {'param': 'value'}, + 'id': 'call_123' + } + ] + + executor = ToolCallExecutor() + records, messages = executor.execute_all(mock_app_uuid, tool_calls) + + assert len(records) == 1 + assert len(messages) == 1 + assert records[0]['name'] == 'test_tool' + assert 'error' in records[0] + assert records[0]['error']['message'] == 'Tool failed' + assert messages[0]['content'] == 'Error: Tool failed' + + @patch('core.services.tool_call_executor.execute_tool_call') + @patch('core.services.tool_call_executor._get_tool_url') + @patch('core.services.tool_call_executor._monotonic_ns') + @patch('core.services.tool_call_executor._elapsed_ms') + def test_execute_all_multiple_calls(self, mock_elapsed_ms, mock_monotonic_ns, mock_get_tool_url, mock_execute_tool_call): + mock_app_uuid = str(uuid.uuid4()) + mock_monotonic_ns.return_value = 1000000000 + mock_elapsed_ms.return_value = 100 + mock_get_tool_url.return_value = 'https://example.com/api' + 
mock_execute_tool_call.side_effect = [ + {'result': 'success1'}, + {'result': 'success2'} + ] + + tool_calls = [ + { + 'name': 'tool1', + 'args': {'param': 'value1'}, + 'id': 'call_1' + }, + { + 'name': 'tool2', + 'args': {'param': 'value2'}, + 'id': 'call_2' + } + ] + + executor = ToolCallExecutor() + records, messages = executor.execute_all(mock_app_uuid, tool_calls) + + assert len(records) == 2 + assert len(messages) == 2 + assert records[0]['name'] == 'tool1' + assert records[1]['name'] == 'tool2' + + @patch('core.services.tool_call_executor.execute_tool_call') + @patch('core.services.tool_call_executor._get_tool_url') + @patch('core.services.tool_call_executor._monotonic_ns') + @patch('core.services.tool_call_executor._elapsed_ms') + def test_execute_all_empty_args(self, mock_elapsed_ms, mock_monotonic_ns, mock_get_tool_url, mock_execute_tool_call): + mock_app_uuid = str(uuid.uuid4()) + mock_monotonic_ns.return_value = 1000000000 + mock_elapsed_ms.return_value = 100 + mock_get_tool_url.return_value = 'https://example.com/api' + mock_execute_tool_call.return_value = {'result': 'success'} + + tool_calls = [ + { + 'name': 'test_tool', + 'args': None, + 'id': 'call_123' + } + ] + + executor = ToolCallExecutor() + records, messages = executor.execute_all(mock_app_uuid, tool_calls) + + assert len(records) == 1 + assert records[0]['input_parameters'] == {} + + @patch('core.services.tool_call_executor.execute_tool_call') + @patch('core.services.tool_call_executor._get_tool_url') + @patch('core.services.tool_call_executor._monotonic_ns') + @patch('core.services.tool_call_executor._elapsed_ms') + def test_execute_all_missing_fields(self, mock_elapsed_ms, mock_monotonic_ns, mock_get_tool_url, mock_execute_tool_call): + mock_app_uuid = str(uuid.uuid4()) + mock_monotonic_ns.return_value = 1000000000 + mock_elapsed_ms.return_value = 100 + mock_get_tool_url.return_value = '' + mock_execute_tool_call.return_value = {'result': 'success'} + + tool_calls = [ + { + 'name': 
'test_tool', + 'args': {'param': 'value'} + } + ] + + executor = ToolCallExecutor() + records, messages = executor.execute_all(mock_app_uuid, tool_calls) + + assert len(records) == 1 + assert records[0]['name'] == 'test_tool' + assert records[0]['url'] == '' + assert messages[0]['tool_call_id'] == '' + + @patch('core.services.tool_call_executor.execute_tool_call') + @patch('core.services.tool_call_executor._get_tool_url') + @patch('core.services.tool_call_executor._monotonic_ns') + @patch('core.services.tool_call_executor._elapsed_ms') + def test_execute_all_empty_list(self, mock_elapsed_ms, mock_monotonic_ns, mock_get_tool_url, mock_execute_tool_call): + mock_app_uuid = str(uuid.uuid4()) + + tool_calls = [] + + executor = ToolCallExecutor() + records, messages = executor.execute_all(mock_app_uuid, tool_calls) + + assert records == [] + assert messages == [] + mock_execute_tool_call.assert_not_called() diff --git a/backend/core/tests/test_services/test_unread.py b/backend/core/tests/test_services/test_unread.py new file mode 100644 index 0000000..3796023 --- /dev/null +++ b/backend/core/tests/test_services/test_unread.py @@ -0,0 +1,197 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid + +from core.services.unread import ( + mark_unread_for_participants, + broadcast_unread_update, + mark_read_for_participant +) + + +@pytest.mark.unit +class TestMarkUnreadForParticipants: + @patch('core.services.unread.ChatroomParticipant') + def test_mark_unread_for_participants_success(self, mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + sender_identifier = 'user_123' + + mock_queryset = Mock() + mock_queryset.exclude.return_value = mock_queryset + mock_queryset.filter.return_value = mock_queryset + mock_queryset.update.return_value = 3 + mock_queryset.values_list.return_value = ['user_456', 'user_789', 'user_abc'] + mock_participant_class.objects.filter.return_value = mock_queryset + + result = 
mark_unread_for_participants(mock_chatroom, sender_identifier) + + mock_participant_class.objects.filter.assert_called_once_with(chatroom=mock_chatroom) + mock_queryset.exclude.assert_called_once_with(user_identifier=sender_identifier) + mock_queryset.update.assert_called_once_with(has_unread=True) + assert result == ['user_456', 'user_789', 'user_abc'] + + @patch('core.services.unread.ChatroomParticipant') + def test_mark_unread_for_participants_internal(self, mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + sender_identifier = 'dashboard_admin' + is_internal = True + + mock_queryset = Mock() + mock_queryset.exclude.return_value = mock_queryset + mock_queryset.filter.return_value = mock_queryset + mock_queryset.update.return_value = 2 + mock_queryset.values_list.return_value = ['dashboard_user1', 'dashboard_user2'] + mock_participant_class.objects.filter.return_value = mock_queryset + + result = mark_unread_for_participants(mock_chatroom, sender_identifier, is_internal) + + mock_queryset.exclude.assert_called_once_with(user_identifier=sender_identifier) + mock_queryset.filter.assert_called_once_with(user_identifier__startswith='dashboard_') + assert result == ['dashboard_user1', 'dashboard_user2'] + + @patch('core.services.unread.ChatroomParticipant') + def test_mark_unread_for_participants_no_participants(self, mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + sender_identifier = 'user_123' + + mock_queryset = Mock() + mock_queryset.exclude.return_value = mock_queryset + mock_queryset.filter.return_value = mock_queryset + mock_queryset.update.return_value = 0 + mock_queryset.values_list.return_value = [] + mock_participant_class.objects.filter.return_value = mock_queryset + + result = mark_unread_for_participants(mock_chatroom, sender_identifier) + + assert result == [] + + @patch('core.services.unread.ChatroomParticipant') + def test_mark_unread_for_participants_only_sender(self, 
mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + sender_identifier = 'user_123' + + mock_queryset = Mock() + mock_queryset.exclude.return_value = mock_queryset + mock_queryset.filter.return_value = mock_queryset + mock_queryset.update.return_value = 0 + mock_queryset.values_list.return_value = [] + mock_participant_class.objects.filter.return_value = mock_queryset + + result = mark_unread_for_participants(mock_chatroom, sender_identifier) + + mock_queryset.exclude.assert_called_once_with(user_identifier=sender_identifier) + assert result == [] + + +@pytest.mark.unit +class TestBroadcastUnreadUpdate: + @patch('core.services.unread.async_to_sync') + @patch('core.services.unread.get_channel_layer') + def test_broadcast_unread_update_success(self, mock_get_channel_layer, mock_async_to_sync): + mock_channel_layer = Mock() + mock_get_channel_layer.return_value = mock_channel_layer + mock_async_to_sync.return_value = Mock() + + user_identifier = 'user_123' + chatroom_uuid = str(uuid.uuid4()) + has_unread = True + sender_identifier = 'user_456' + + broadcast_unread_update(user_identifier, chatroom_uuid, has_unread, sender_identifier) + + mock_get_channel_layer.assert_called_once() + mock_async_to_sync.assert_called_once() + assert mock_async_to_sync.return_value.called + + @patch('core.services.unread.async_to_sync') + @patch('core.services.unread.get_channel_layer') + def test_broadcast_unread_update_exception(self, mock_get_channel_layer, mock_async_to_sync): + mock_channel_layer = Mock() + mock_get_channel_layer.return_value = mock_channel_layer + mock_async_to_sync.side_effect = Exception('WebSocket error') + + user_identifier = 'user_123' + chatroom_uuid = str(uuid.uuid4()) + has_unread = True + sender_identifier = 'user_456' + + broadcast_unread_update(user_identifier, chatroom_uuid, has_unread, sender_identifier) + + mock_get_channel_layer.assert_called_once() + mock_async_to_sync.assert_called_once() + + 
@patch('core.services.unread.async_to_sync') + @patch('core.services.unread.get_channel_layer') + @patch('core.services.unread.LIVE_UPDATES_PREFIX', 'live_updates') + def test_broadcast_unread_update_group_name(self, mock_get_channel_layer, mock_async_to_sync): + mock_channel_layer = Mock() + mock_get_channel_layer.return_value = mock_channel_layer + mock_async_to_sync.return_value = Mock() + + user_identifier = 'user_123' + chatroom_uuid = str(uuid.uuid4()) + has_unread = False + sender_identifier = 'user_456' + + broadcast_unread_update(user_identifier, chatroom_uuid, has_unread, sender_identifier) + + mock_async_to_sync.assert_called_once() + + +@pytest.mark.unit +class TestMarkReadForParticipant: + @patch('core.services.unread.ChatroomParticipant') + def test_mark_read_for_participant_success(self, mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + user_identifier = 'user_123' + + mock_queryset = Mock() + mock_queryset.update.return_value = 1 + mock_participant_class.objects.filter.return_value = mock_queryset + + mark_read_for_participant(mock_chatroom, user_identifier) + + mock_participant_class.objects.filter.assert_called_once_with( + chatroom=mock_chatroom, + user_identifier=user_identifier + ) + mock_queryset.update.assert_called_once_with(has_unread=False) + + @patch('core.services.unread.ChatroomParticipant') + def test_mark_read_for_participant_not_found(self, mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + user_identifier = 'user_123' + + mock_queryset = Mock() + mock_queryset.update.return_value = 0 + mock_participant_class.objects.filter.return_value = mock_queryset + + mark_read_for_participant(mock_chatroom, user_identifier) + + mock_participant_class.objects.filter.assert_called_once_with( + chatroom=mock_chatroom, + user_identifier=user_identifier + ) + mock_queryset.update.assert_called_once_with(has_unread=False) + + @patch('core.services.unread.ChatroomParticipant') 
+ def test_mark_read_for_participant_multiple_updates(self, mock_participant_class): + mock_chatroom = Mock() + mock_chatroom.uuid = uuid.uuid4() + user_identifier = 'user_123' + + mock_queryset = Mock() + mock_queryset.update.return_value = 1 + mock_participant_class.objects.filter.return_value = mock_queryset + + mark_read_for_participant(mock_chatroom, user_identifier) + mark_read_for_participant(mock_chatroom, user_identifier) + + assert mock_participant_class.objects.filter.call_count == 2 + assert mock_queryset.update.call_count == 2 diff --git a/backend/core/tests/test_services/test_url_ingestion.py b/backend/core/tests/test_services/test_url_ingestion.py new file mode 100644 index 0000000..e833a1e --- /dev/null +++ b/backend/core/tests/test_services/test_url_ingestion.py @@ -0,0 +1,409 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid +from datetime import datetime + +from core.services.url_ingestion import URLIngestionService + + +@pytest.mark.unit +class TestURLIngestionServiceInit: + @patch('core.services.url_ingestion.URLExtractor') + def test_init(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + + assert service.extractor == mock_extractor + + +@pytest.mark.unit +class TestCreateUrlKb: + @patch('core.models.KnowledgeBase') + def test_create_url_kb_success(self, mock_kb_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + mock_user = Mock() + mock_user.uuid = uuid.uuid4() + url = 'https://example.com' + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {} + mock_kb_class.objects.create.return_value = mock_kb + + service = URLIngestionService() + result = service.create_url_kb(mock_app, url, mock_user) + + mock_kb_class.objects.create.assert_called_once() + assert result == mock_kb + + @patch('core.models.KnowledgeBase') + def test_create_url_kb_with_crawling_config(self, mock_kb_class): + mock_app = 
Mock() + mock_app.uuid = uuid.uuid4() + mock_user = Mock() + mock_user.uuid = uuid.uuid4() + url = 'https://example.com' + crawling_config = {'enable_crawling': True, 'max_depth': 3, 'max_pages': 50} + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {} + mock_kb_class.objects.create.return_value = mock_kb + + service = URLIngestionService() + result = service.create_url_kb(mock_app, url, mock_user, crawling_config) + + mock_kb_class.objects.create.assert_called_once() + assert result == mock_kb + + @patch('core.models.KnowledgeBase') + def test_create_url_kb_without_crawling_config(self, mock_kb_class): + mock_app = Mock() + mock_app.uuid = uuid.uuid4() + mock_user = Mock() + mock_user.uuid = uuid.uuid4() + url = 'https://example.com' + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.metadata = {} + mock_kb_class.objects.create.return_value = mock_kb + + service = URLIngestionService() + result = service.create_url_kb(mock_app, url, mock_user, None) + + mock_kb_class.objects.create.assert_called_once() + assert result == mock_kb + + +@pytest.mark.unit +class TestExtractUrlContent: + @patch('core.services.url_ingestion.URLExtractor') + def test_extract_url_content_success(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.extract_content.return_value = { + 'url': 'https://example.com', + 'title': 'Example', + 'description': 'Test description', + 'content': 'Test content', + 'links': ['https://example.com/page1'], + 'content_type': 'text/html' + } + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.path = 'https://example.com' + mock_kb.metadata = {} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + result = service.extract_url_content(mock_kb) + + assert result is True + assert mock_kb.status == 'processing' + mock_kb.save.assert_called() + + @patch('core.services.url_ingestion.URLExtractor') + 
def test_extract_url_content_not_url_type(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + + service = URLIngestionService() + result = service.extract_url_content(mock_kb) + + assert result is False + + @patch('core.services.url_ingestion.URLExtractor') + def test_extract_url_content_no_path(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.path = None + + service = URLIngestionService() + result = service.extract_url_content(mock_kb) + + assert result is False + + @patch('core.services.url_ingestion.URLExtractor') + def test_extract_url_content_extraction_failed(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.extract_content.return_value = None + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.path = 'https://example.com' + mock_kb.metadata = {} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + result = service.extract_url_content(mock_kb) + + assert result is False + assert mock_kb.status == 'failed' + + @patch('core.services.url_ingestion.URLExtractor') + def test_extract_url_content_exception(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.extract_content.side_effect = Exception('Network error') + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.path = 'https://example.com' + mock_kb.metadata = {} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + result = service.extract_url_content(mock_kb) + + assert result is False + assert mock_kb.status == 'failed' + + +@pytest.mark.unit +class 
TestEnableCrawlingForKb: + def test_enable_crawling_for_kb_success(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.metadata = {} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + service.enable_crawling_for_kb(mock_kb, max_depth=3, max_pages=50) + + assert mock_kb.metadata['crawling_enabled'] is True + assert mock_kb.metadata['crawling_config']['max_depth'] == 3 + assert mock_kb.metadata['crawling_config']['max_pages'] == 50 + mock_kb.save.assert_called_once() + + def test_enable_crawling_for_kb_not_url_type(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + + service = URLIngestionService() + + with pytest.raises(ValueError, match="Crawling can only be enabled for URL knowledge base items"): + service.enable_crawling_for_kb(mock_kb) + + def test_enable_crawling_for_kb_default_params(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.metadata = {} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + service.enable_crawling_for_kb(mock_kb) + + assert mock_kb.metadata['crawling_enabled'] is True + assert mock_kb.metadata['crawling_config']['max_depth'] == 2 + assert mock_kb.metadata['crawling_config']['max_pages'] == 25 + + +@pytest.mark.unit +class TestDisableCrawlingForKb: + def test_disable_crawling_for_kb_success(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.metadata = {'crawling_enabled': True, 'crawling_config': {}} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + service.disable_crawling_for_kb(mock_kb) + + assert mock_kb.metadata['crawling_enabled'] is False + assert 'disabled_at' in mock_kb.metadata['crawling_config'] + mock_kb.save.assert_called_once() + + def test_disable_crawling_for_kb_not_url_type(self): + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + + service = 
URLIngestionService() + service.disable_crawling_for_kb(mock_kb) + + mock_kb.save.assert_not_called() + + +@pytest.mark.unit +class TestValidateUrlBeforeIngestion: + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_invalid_format(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = False + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('not-a-url') + + assert result['valid'] is False + assert 'error' in result + + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_simple_validation(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com', simple_validation=True) + + assert result['valid'] is True + assert result['message'] == 'URL format is valid. Content will be extracted during processing.' 
+ + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_full_validation_success(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_extractor.extract_content.return_value = { + 'title': 'Example', + 'description': 'Test', + 'content': 'Content', + 'links': [], + 'content_type': 'text/html' + } + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com') + + assert result['valid'] is True + assert result['title'] == 'Example' + + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_extraction_failed(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_extractor.extract_content.return_value = None + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com') + + assert result['valid'] is False + assert 'error' in result + + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_timeout(self, mock_extractor_class): + import requests + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_extractor.extract_content.side_effect = requests.exceptions.Timeout() + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com') + + assert result['valid'] is False + assert 'timed out' in result['error'] + + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_connection_error(self, mock_extractor_class): + import requests + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_extractor.extract_content.side_effect = requests.exceptions.ConnectionError() + mock_extractor_class.return_value = mock_extractor + + service = 
URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com') + + assert result['valid'] is False + assert 'connect' in result['error'] + + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_403_error(self, mock_extractor_class): + import requests + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_response = Mock() + mock_response.status_code = 403 + mock_response.reason = 'Forbidden' + mock_extractor.extract_content.side_effect = requests.exceptions.HTTPError(response=mock_response) + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com') + + assert result['valid'] is False + assert '403' in result['error'] + + @patch('core.services.url_ingestion.URLExtractor') + def test_validate_url_404_error(self, mock_extractor_class): + import requests + mock_extractor = Mock() + mock_extractor.is_valid_url.return_value = True + mock_response = Mock() + mock_response.status_code = 404 + mock_response.reason = 'Not Found' + mock_extractor.extract_content.side_effect = requests.exceptions.HTTPError(response=mock_response) + mock_extractor_class.return_value = mock_extractor + + service = URLIngestionService() + result = service.validate_url_before_ingestion('https://example.com') + + assert result['valid'] is False + assert '404' in result['error'] + + +@pytest.mark.unit +class TestReprocessUrl: + @patch('core.services.url_ingestion.URLExtractor') + def test_reprocess_url_success(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor.extract_content.return_value = { + 'url': 'https://example.com', + 'title': 'Example', + 'description': 'Test', + 'content': 'Content', + 'links': [], + 'content_type': 'text/html' + } + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'url' + mock_kb.path = 
'https://example.com' + mock_kb.metadata = {'extraction_error': 'Previous error'} + mock_kb.updated_at = datetime.now() + + service = URLIngestionService() + result = service.reprocess_url(mock_kb) + + assert result is True + assert 'extraction_error' not in mock_kb.metadata + + @patch('core.services.url_ingestion.URLExtractor') + def test_reprocess_url_not_url_type(self, mock_extractor_class): + mock_extractor = Mock() + mock_extractor_class.return_value = mock_extractor + + mock_kb = Mock() + mock_kb.uuid = uuid.uuid4() + mock_kb.source_type = 'text' + + service = URLIngestionService() + result = service.reprocess_url(mock_kb) + + assert result is False diff --git a/backend/core/tests/test_services/test_vc_ingestion.py b/backend/core/tests/test_services/test_vc_ingestion.py new file mode 100644 index 0000000..55218a1 --- /dev/null +++ b/backend/core/tests/test_services/test_vc_ingestion.py @@ -0,0 +1,333 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import uuid +from datetime import datetime + +from core.services.vc_ingestion import VCIngestionService + + +@pytest.mark.unit +class TestVCIngestionServiceInit: + @patch('core.services.vc_ingestion.VCProviderRegistry') + def test_init_with_provider_name(self, mock_registry): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_registry.get_provider.return_value = mock_provider + + service = VCIngestionService(mock_app_integration, provider_name='github') + + assert service.app_integration == mock_app_integration + assert service.provider_name == 'github' + + @patch('core.services.vc_ingestion.VCProviderRegistry') + def test_init_without_provider_name(self, mock_registry): + mock_app_integration = Mock() + mock_app_integration.metadata = {'provider': 'gitlab'} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_registry.get_provider.return_value = 
mock_provider + + service = VCIngestionService(mock_app_integration) + + assert service.provider_name == 'gitlab' + + +@pytest.mark.unit +class TestDetectProvider: + def test_detect_provider_from_metadata(self): + mock_app_integration = Mock() + mock_app_integration.metadata = {'provider': 'bitbucket'} + + service = VCIngestionService.__new__(VCIngestionService) + service.app_integration = mock_app_integration + + result = service._detect_provider() + + assert result == 'bitbucket' + + def test_detect_provider_default(self): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + + service = VCIngestionService.__new__(VCIngestionService) + service.app_integration = mock_app_integration + + result = service._detect_provider() + + assert result == 'github_graphql' + + +@pytest.mark.unit +class TestGetOrCreateRepository: + @patch('core.services.vc_ingestion.VCRepository') + def test_get_or_create_repository_new(self, mock_vc_repo_class): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_provider.get_repository_info.return_value = { + 'id': '123', + 'description': 'Test repo', + 'url': 'https://github.com/test/repo', + 'is_private': False, + 'default_branch': 'main' + } + + mock_repo = Mock() + mock_repo.id = 1 + mock_repo.full_name = 'test/repo' + mock_vc_repo_class.objects.get_or_create.return_value = (mock_repo, True) + + service = VCIngestionService(mock_app_integration) + service.provider = mock_provider + + result = service._get_or_create_repository('test', 'repo') + + assert result == mock_repo + mock_vc_repo_class.objects.get_or_create.assert_called_once() + + @patch('core.services.vc_ingestion.VCRepository') + def test_get_or_create_repository_existing(self, mock_vc_repo_class): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + + mock_repo 
= Mock() + mock_repo.id = 1 + mock_repo.full_name = 'test/repo' + mock_vc_repo_class.objects.get_or_create.return_value = (mock_repo, False) + + service = VCIngestionService(mock_app_integration) + service.provider = mock_provider + + result = service._get_or_create_repository('test', 'repo') + + assert result == mock_repo + mock_provider.get_repository_info.assert_not_called() + + +@pytest.mark.unit +class TestIngestSingleIssue: + @patch('core.services.vc_ingestion.transaction') + @patch('core.services.vc_ingestion.VCIssue') + def test_ingest_single_issue_success(self, mock_issue_class, mock_transaction): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_provider._parse_datetime.return_value = datetime.now() + + mock_repository = Mock() + mock_repository.repo_owner = 'test' + mock_repository.name = 'repo' + + mock_issue = Mock() + mock_issue.number = 1 + mock_issue.comments = Mock() + mock_issue.comments.all.return_value = [] + mock_issue_class.objects.update_or_create.return_value = (mock_issue, True) + + mock_quality_filter = Mock() + mock_quality_filter.remove_emojis.return_value = 'cleaned body' + + service = VCIngestionService(mock_app_integration) + service.provider = mock_provider + service.repository = mock_repository + service.quality_filter = mock_quality_filter + + issue_data = { + 'id': '123', + 'number': 1, + 'title': 'Test issue', + 'body': 'Test body', + 'state': 'open', + 'author': 'testuser', + 'created_at': '2024-01-01T00:00:00Z', + 'updated_at': '2024-01-01T00:00:00Z', + 'url': 'https://github.com/test/repo/issues/1' + } + + service._ingest_single_issue(issue_data) + + mock_issue_class.objects.update_or_create.assert_called_once() + + +@pytest.mark.unit +class TestIngestSinglePullRequest: + @patch('core.services.vc_ingestion.transaction') + @patch('core.services.vc_ingestion.VCPullRequest') + def 
test_ingest_single_pull_request_success(self, mock_pr_class, mock_transaction): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_provider._parse_datetime.return_value = datetime.now() + mock_provider.get_pull_request_comments.return_value = [] + mock_provider.get_pull_request_files.return_value = [] + + mock_repository = Mock() + mock_repository.repo_owner = 'test' + mock_repository.name = 'repo' + + mock_pr = Mock() + mock_pr.number = 1 + mock_pr_class.objects.update_or_create.return_value = (mock_pr, True) + + mock_quality_filter = Mock() + mock_quality_filter.remove_emojis.return_value = 'cleaned body' + + service = VCIngestionService(mock_app_integration) + service.provider = mock_provider + service.repository = mock_repository + service.quality_filter = mock_quality_filter + + pr_data = { + 'id': '123', + 'number': 1, + 'title': 'Test PR', + 'body': 'Test body', + 'state': 'open', + 'author': 'testuser', + 'created_at': '2024-01-01T00:00:00Z', + 'updated_at': '2024-01-01T00:00:00Z', + 'url': 'https://github.com/test/repo/pull/1' + } + + service._ingest_single_pull_request(pr_data) + + mock_pr_class.objects.update_or_create.assert_called_once() + + +@pytest.mark.unit +class TestCreateKnowledgeBaseContent: + def test_create_knowledge_base_content_no_repository(self): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + service = VCIngestionService(mock_app_integration) + service.repository = None + + result = service._create_knowledge_base_content() + + assert result == "" + + def test_create_knowledge_base_content_with_repository(self): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_repository = Mock() + mock_repository.full_name = 'test/repo' + mock_repository.description = 'Test description' + 
mock_repository.issues = Mock() + mock_repository.issues.all.return_value = [] + mock_repository.pull_requests = Mock() + mock_repository.pull_requests.all.return_value = [] + + service = VCIngestionService(mock_app_integration) + service.repository = mock_repository + + result = service._create_knowledge_base_content() + + assert 'Repository: test/repo' in result + assert 'Description: Test description' in result + + def test_create_knowledge_base_content_with_issues(self): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_issue = Mock() + mock_issue.number = 1 + mock_issue.title = 'Test issue' + mock_issue.state = 'open' + mock_issue.author = 'testuser' + mock_issue.body = 'Test body' + mock_issue.labels = ['bug'] + mock_issue.comments = Mock() + mock_issue.comments.all.return_value = [] + + mock_repository = Mock() + mock_repository.full_name = 'test/repo' + mock_repository.description = 'Test description' + mock_repository.issues = Mock() + mock_repository.issues.all.return_value = [mock_issue] + mock_repository.pull_requests = Mock() + mock_repository.pull_requests.all.return_value = [] + + service = VCIngestionService(mock_app_integration) + service.repository = mock_repository + + result = service._create_knowledge_base_content() + + assert '## Issues' in result + assert 'Issue #1: Test issue' in result + + +@pytest.mark.unit +class TestIngestRepository: + @patch('core.services.vc_ingestion.VCIngestionService._ingest_to_knowledge_base') + @patch('core.services.vc_ingestion.VCIngestionService._ingest_pull_requests') + @patch('core.services.vc_ingestion.VCIngestionService._ingest_issues') + @patch('core.services.vc_ingestion.VCIngestionService._get_or_create_repository') + def test_ingest_repository_success(self, mock_get_repo, mock_ingest_issues, mock_ingest_prs, mock_ingest_kb): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + 
mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_provider.close = Mock() + + mock_repository = Mock() + mock_repository.id = 1 + mock_repository.full_name = 'test/repo' + mock_repository.ingestion_status = 'completed' + + mock_get_repo.return_value = mock_repository + + service = VCIngestionService(mock_app_integration) + service.provider = mock_provider + service.provider_name = 'github' + + service.ingest_repository('test', 'repo') + + mock_get_repo.assert_called_once_with('test', 'repo') + mock_ingest_issues.assert_called_once() + mock_ingest_prs.assert_called_once() + mock_ingest_kb.assert_called_once() + assert mock_repository.ingestion_status == 'completed' + + @patch('core.services.vc_ingestion.VCIngestionService._get_or_create_repository') + def test_ingest_repository_exception(self, mock_get_repo): + mock_app_integration = Mock() + mock_app_integration.metadata = {} + mock_app_integration.integration.credentials = '{}' + + mock_provider = Mock() + mock_provider.close = Mock() + + mock_repository = Mock() + mock_repository.ingestion_status = 'running' + mock_get_repo.return_value = mock_repository + mock_get_repo.side_effect = Exception('Test error') + + service = VCIngestionService(mock_app_integration) + service.provider = mock_provider + service.provider_name = 'github' + service.repository = mock_repository + + with pytest.raises(Exception): + service.ingest_repository('test', 'repo') + + assert mock_repository.ingestion_status == 'failed' + mock_provider.close.assert_called_once()