diff --git a/chatbot-core/api/tools/utils.py b/chatbot-core/api/tools/utils.py index 36beee175..bbbb11a7f 100644 --- a/chatbot-core/api/tools/utils.py +++ b/chatbot-core/api/tools/utils.py @@ -2,6 +2,7 @@ Utilities for the tools package. """ +import functools import json import os import re @@ -193,29 +194,45 @@ def extract_chunks_content(chunks: List[Dict], logger) -> str: else retrieval_config["empty_context_message"] ) +def tokenize_plugin_name(name: str) -> str: + """Normalize a plugin name for case/separator-insensitive comparison.""" + return name.replace('-', '').replace(' ', '').lower() + + +@functools.lru_cache(maxsize=1) +def load_plugin_names() -> frozenset: + """ + Load and cache the set of known plugin names (tokenized) from disk. + + The JSON file is static data that never changes at runtime, so it is + read once on first access and kept in memory for O(1) lookups. + + Returns: + frozenset: A set of tokenized plugin names. + """ + list_plugin_names_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", "data", "raw", "plugin_names.json" + ) + with open(list_plugin_names_path, "r", encoding="utf-8") as f: + list_plugin_names = json.load(f) + return frozenset(tokenize_plugin_name(name) for name in list_plugin_names) + + def is_valid_plugin(plugin_name: str) -> bool: """ Checks whether the given plugin name exists in the list of known plugin names. + Uses a cached frozenset for O(1) membership checks instead of + re-reading from disk and performing a linear scan on every call. + Args: plugin_name (str): The name of the plugin to validate. Returns: bool: True if the plugin exists in the list, False otherwise. """ - def tokenize(item: str) -> str: - item = item.replace('-', '') - return item.replace(' ', '').lower() - list_plugin_names_path = os.path.join(os.path.abspath(__file__), - "..", "..", "data", "raw", "plugin_names.json") - with open(list_plugin_names_path, "r", encoding="utf-8") as f: - list_plugin_names = json.load(f) - - for name in list_plugin_names: - if tokenize(plugin_name) == tokenize(name): - return True - - return False + return tokenize_plugin_name(plugin_name) in load_plugin_names() def filter_retrieved_data( semantic_data: List[Dict], @@ -234,14 +251,12 @@ def filter_retrieved_data( Returns: Tuple[List[Dict], List[Dict]]: Filtered semantic and keyword data. """ - def tokenize(item: str) -> str: - item = item.replace('-', '') - return item.replace(' ', '').lower() - semantic_filtered_data = [item for item in semantic_data - if tokenize(item["metadata"]["title"]) == tokenize(plugin_name)] + if tokenize_plugin_name(item["metadata"]["title"]) + == tokenize_plugin_name(plugin_name)] keyword_filtered_data = [item for item in keyword_data - if tokenize(item["metadata"]["title"]) == tokenize(plugin_name)] + if tokenize_plugin_name(item["metadata"]["title"]) + == tokenize_plugin_name(plugin_name)] return semantic_filtered_data, keyword_filtered_data diff --git a/chatbot-core/tests/unit/tools/test_utils.py b/chatbot-core/tests/unit/tools/test_utils.py new file mode 100644 index 000000000..47582ab60 --- /dev/null +++ b/chatbot-core/tests/unit/tools/test_utils.py @@ -0,0 +1,156 @@ +"""Unit tests for plugin name caching and validation in tools/utils.py.""" +import json +import unittest +from unittest.mock import patch, mock_open + +from api.tools.utils import ( + is_valid_plugin, + load_plugin_names, + filter_retrieved_data, +) + +SAMPLE_PLUGINS = json.dumps( + ["git", "blue-ocean", "credentials", "github-branch-source"] +) + + +class TestIsValidPlugin(unittest.TestCase): + """Tests for the is_valid_plugin function.""" + + def setUp(self): + load_plugin_names.cache_clear() + + def tearDown(self): + load_plugin_names.cache_clear() + + @patch( + "builtins.open", + mock_open(read_data=SAMPLE_PLUGINS), + ) + @patch("os.path.dirname", return_value="/fake/dir") + def test_exact_match(self, _mock_dir): + """Exact plugin name should be valid.""" + self.assertTrue(is_valid_plugin("git")) + + @patch( + "builtins.open", + mock_open(read_data=SAMPLE_PLUGINS), + ) + @patch("os.path.dirname", return_value="/fake/dir") + def test_case_insensitive_match(self, _mock_dir): + """Plugin names should match case-insensitively.""" + self.assertTrue(is_valid_plugin("Git")) + self.assertTrue(is_valid_plugin("GIT")) + + @patch( + "builtins.open", + mock_open(read_data=SAMPLE_PLUGINS), + ) + @patch("os.path.dirname", return_value="/fake/dir") + def test_hyphen_insensitive_match(self, _mock_dir): + """Hyphens should be ignored during matching.""" + self.assertTrue(is_valid_plugin("blue ocean")) + self.assertTrue(is_valid_plugin("blueocean")) + self.assertTrue(is_valid_plugin("Blue-Ocean")) + + @patch( + "builtins.open", + mock_open(read_data=SAMPLE_PLUGINS), + ) + @patch("os.path.dirname", return_value="/fake/dir") + def test_invalid_plugin_returns_false(self, _mock_dir): + """Non-existent plugin name should return False.""" + self.assertFalse(is_valid_plugin("nonexistent-plugin")) + self.assertFalse(is_valid_plugin("")) + + +class TestFilterRetrievedData(unittest.TestCase): + """Tests for filter_retrieved_data using the shared tokenizer.""" + + def test_filters_matching_entries(self): + """Only entries whose title matches the plugin name should remain.""" + semantic_data = [ + {"metadata": {"title": "blue-ocean"}, "chunk_text": "a"}, + {"metadata": {"title": "credentials"}, "chunk_text": "b"}, + ] + keyword_data = [ + {"metadata": {"title": "Blue Ocean"}, "chunk_text": "c"}, + {"metadata": {"title": "git"}, "chunk_text": "d"}, + ] + sem, kw = filter_retrieved_data( + semantic_data, keyword_data, "blue-ocean" + ) + self.assertEqual(len(sem), 1) + self.assertEqual(sem[0]["chunk_text"], "a") + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0]["chunk_text"], "c") + + def test_returns_empty_when_no_match(self): + """No results should be returned when nothing matches.""" + data = [{"metadata": {"title": "git"}, "chunk_text": "x"}] + sem, kw = filter_retrieved_data(data, data, "nonexistent") + self.assertEqual(len(sem), 0) + self.assertEqual(len(kw), 0) + + def test_empty_input(self): + """Empty input lists should return empty lists.""" + sem, kw = filter_retrieved_data([], [], "git") + self.assertEqual(sem, []) + self.assertEqual(kw, []) + + +class TestPluginNameCacheIntegration(unittest.TestCase): + """Integration tests verifying lru_cache behaviour through the public API. + + These tests confirm that plugin_names.json is read from disk exactly + once, regardless of how many times public functions that depend on the + cached data are called. + """ + + def setUp(self): + load_plugin_names.cache_clear() + + def tearDown(self): + load_plugin_names.cache_clear() + + @patch("os.path.dirname", return_value="/fake/dir") + def test_multiple_is_valid_plugin_calls_read_file_once(self, _mock_dir): + """Repeated is_valid_plugin() calls should hit the cache, not disk.""" + with patch( + "builtins.open", mock_open(read_data=SAMPLE_PLUGINS) + ) as mocked_file: + is_valid_plugin("git") + is_valid_plugin("blue-ocean") + is_valid_plugin("nonexistent") + mocked_file.assert_called_once() + + @patch("os.path.dirname", return_value="/fake/dir") + def test_cache_shared_across_public_functions(self, _mock_dir): + """is_valid_plugin and load_plugin_names should share the same cache.""" + with patch( + "builtins.open", mock_open(read_data=SAMPLE_PLUGINS) + ) as mocked_file: + # First access via is_valid_plugin populates the cache + is_valid_plugin("git") + # Direct call should reuse the cached result + result = load_plugin_names() + self.assertIsInstance(result, frozenset) + mocked_file.assert_called_once() + + @patch("os.path.dirname", return_value="/fake/dir") + def test_cache_clear_forces_reload(self, _mock_dir): + """After cache_clear(), the next call should re-read the file.""" + with patch( + "builtins.open", mock_open(read_data=SAMPLE_PLUGINS) + ) as mocked_file: + is_valid_plugin("git") + self.assertEqual(mocked_file.call_count, 1) + + load_plugin_names.cache_clear() + + is_valid_plugin("git") + self.assertEqual(mocked_file.call_count, 2) + + +if __name__ == "__main__": + unittest.main()