From f5909b2289e203d8a41021d42515cce45b91038c Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 6 May 2026 17:57:06 -0700 Subject: [PATCH] feat: Add RetrieveSkills semantic search method in Vertex AI Skill Registry SDK PiperOrigin-RevId: 911650461 --- .../genai/replays/test_skills_retrieve.py | 26 +++ .../unit/vertexai/genai/test_genai_skills.py | 110 +++++++++++- vertexai/_genai/skills.py | 170 ++++++++++++++++++ vertexai/_genai/types/__init__.py | 20 +++ vertexai/_genai/types/common.py | 97 ++++++++++ 5 files changed, 419 insertions(+), 4 deletions(-) create mode 100644 tests/unit/vertexai/genai/replays/test_skills_retrieve.py diff --git a/tests/unit/vertexai/genai/replays/test_skills_retrieve.py b/tests/unit/vertexai/genai/replays/test_skills_retrieve.py new file mode 100644 index 0000000000..b761fb7e36 --- /dev/null +++ b/tests/unit/vertexai/genai/replays/test_skills_retrieve.py @@ -0,0 +1,26 @@ +"""Tests the skills.retrieve() method against the autopush endpoint.""" + +from tests.unit.vertexai.genai.replays import pytest_helper +from vertexai._genai import types + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + + +def test_retrieve_skills(client): + # Target the prod endpoint for the Skill Registry API + client._api_client._http_options.base_url = ( + "https://us-central1-aiplatform.googleapis.com" + ) + + response = client.skills.retrieve(query="stubby", config={"top_k": 2}) + + assert isinstance(response, types.RetrieveSkillsResponse) + assert response.retrieved_skills is not None + + for retrieved in response.retrieved_skills: + assert isinstance(retrieved, types.RetrievedSkill) + assert retrieved.skill_name is not None + assert retrieved.description is not None diff --git a/tests/unit/vertexai/genai/test_genai_skills.py b/tests/unit/vertexai/genai/test_genai_skills.py index 45d83a5fff..cc6a54ecaf 100644 --- a/tests/unit/vertexai/genai/test_genai_skills.py +++ b/tests/unit/vertexai/genai/test_genai_skills.py @@ -1,6 +1,7 @@ # //third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/test_genai_skills.py import json from unittest import mock +import google.auth.credentials from vertexai import _genai as genai from vertexai._genai import client as vertexai_client from google.genai import types as genai_types @@ -9,7 +10,7 @@ @pytest.fixture def skills_client(): - creds = mock.MagicMock() + creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True) creds.token = "test_token" client = vertexai_client.Client( project="test-project", location="test-location", credentials=creds @@ -17,6 +18,16 @@ def skills_client(): return client.skills +@pytest.fixture +def async_skills_client(): + creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True) + creds.token = "test_token" + client = vertexai_client.Client( + project="test-project", location="test-location", credentials=creds + ) + return client.aio.skills + + class TestGenaiSkills: mock_get_skill_response = { "name": "projects/test-project/locations/test-location/skills/test-skill", @@ -24,8 +35,9 @@ class TestGenaiSkills: } def test_get_skill(self, skills_client): - """Tests the get_skill method.""" - with mock.patch.object(skills_client._api_client, "request") as request_mock: + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: request_mock.return_value = genai_types.HttpResponse( body=json.dumps(self.mock_get_skill_response) ) @@ -33,7 +45,7 @@ def test_get_skill(self, skills_client): "projects/test-project/locations/test-location/skills/test-skill" ) skill = skills_client.get(name=skill_name) - request_mock.assert_called_with( + request_mock.assert_called_once_with( "get", skill_name, {"_url": {"name": skill_name}}, @@ -42,3 +54,93 @@ def test_get_skill(self, skills_client): assert isinstance(skill, genai.types.Skill) assert skill.name == skill_name assert skill.display_name == "My Test Skill" + + def test_retrieve_skills_response(self, skills_client): + mock_retrieve_response = { + "retrievedSkills": [ + { + "skillName": ( + "projects/test-project/locations/test-location/skills/skill-1" + ), + "description": "Skill 1 Description", + }, + { + "skillName": ( + "projects/test-project/locations/test-location/skills/skill-2" + ), + "description": "Skill 2 Description", + }, + ] + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(mock_retrieve_response) + ) + + response = skills_client.retrieve(query="test query", config={"top_k": 5}) + + assert isinstance(response, genai.types.RetrieveSkillsResponse) + assert len(response.retrieved_skills) == 2 + assert response.retrieved_skills[0].skill_name == ( + "projects/test-project/locations/test-location/skills/skill-1" + ) + assert response.retrieved_skills[0].description == "Skill 1 Description" + + def test_retrieve_skills_request_params(self, skills_client): + mock_retrieve_response = {"retrievedSkills": []} + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(mock_retrieve_response) + ) + + skills_client.retrieve(query="test query", config={"top_k": 5}) + + request_mock.assert_called_once_with( + "get", + "skills:retrieve?query=test+query&topK=5", + {"_query": {"query": "test query", "topK": 5}}, + None, + ) + + @pytest.mark.asyncio + async def test_retrieve_skills_async(self, async_skills_client): + mock_retrieve_response = { + "retrievedSkills": [ + { + "skillName": ( + "projects/test-project/locations/test-location/skills/skill-1" + ), + "description": "Skill 1 Description", + } + ] + } + + with mock.patch.object( + async_skills_client._api_client, "async_request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(mock_retrieve_response) + ) + + response = await async_skills_client.retrieve( + query="test query", config={"top_k": 1} + ) + + assert isinstance(response, genai.types.RetrieveSkillsResponse) + assert len(response.retrieved_skills) == 1 + assert response.retrieved_skills[0].skill_name == ( + "projects/test-project/locations/test-location/skills/skill-1" + ) + + request_mock.assert_called_once_with( + "get", + "skills:retrieve?query=test+query&topK=1", + {"_query": {"query": "test query", "topK": 1}}, + None, + ) diff --git a/vertexai/_genai/skills.py b/vertexai/_genai/skills.py index 1eb145afe7..8489cfb3b1 100644 --- a/vertexai/_genai/skills.py +++ b/vertexai/_genai/skills.py @@ -44,6 +44,36 @@ def _GetSkillRequestParameters_to_vertex( return to_object +def _RetrieveSkillsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["top_k"]) is not None: + setv(parent_object, ["_query", "topK"], getv(from_object, ["top_k"])) + + return to_object + + +def _RetrieveSkillsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["query"]) is not None: + setv(to_object, ["_query", "query"], getv(from_object, ["query"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _RetrieveSkillsConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + class Skills(_api_module.BaseModule): """Class for managing Skills in the Skill Registry.""" @@ -116,6 +146,75 @@ def get( self._api_client._verify_response(return_value) return return_value + def retrieve( + self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None + ) -> types.RetrieveSkillsResponse: + """ + Retrieves skills semantically matched to a query. + """ + + parameter_model = types._RetrieveSkillsRequestParameters( + query=query, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills:retrieve".format_map(request_url_dict) + else: + path = "skills:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveSkillsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + class AsyncSkills(_api_module.BaseModule): """Class for managing Skills in the Skill Registry.""" @@ -190,3 +289,74 @@ async def get( self._api_client._verify_response(return_value) return return_value + + async def retrieve( + self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None + ) -> types.RetrieveSkillsResponse: + """ + Retrieves skills semantically matched to a query. + """ + + parameter_model = types._RetrieveSkillsRequestParameters( + query=query, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills:retrieve".format_map(request_url_dict) + else: + path = "skills:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveSkillsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index baeb22b20b..e550ae7285 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -115,6 +115,7 @@ from .common import _RestoreVersionRequestParameters from .common import _RetrieveAgentEngineMemoriesRequestParameters from .common import _RetrieveMemoryProfilesRequestParameters +from .common import _RetrieveSkillsRequestParameters from .common import _RollbackAgentEngineMemoryRequestParameters from .common import _RunQueryJobAgentEngineConfig from .common import _RunQueryJobAgentEngineConfigDict @@ -1046,6 +1047,9 @@ from .common import RetrieveAgentEngineMemoriesConfig from .common import RetrieveAgentEngineMemoriesConfigDict from .common import RetrieveAgentEngineMemoriesConfigOrDict +from .common import RetrievedSkill +from .common import RetrievedSkillDict +from .common import RetrievedSkillOrDict from .common import RetrieveMemoriesRequestSimilaritySearchParams from .common import RetrieveMemoriesRequestSimilaritySearchParamsDict from .common import RetrieveMemoriesRequestSimilaritySearchParamsOrDict @@ -1064,6 +1068,12 @@ from .common import RetrieveProfilesResponse from .common import RetrieveProfilesResponseDict from .common import RetrieveProfilesResponseOrDict +from .common import RetrieveSkillsConfig +from .common import RetrieveSkillsConfigDict +from .common import RetrieveSkillsConfigOrDict +from .common import RetrieveSkillsResponse +from .common import RetrieveSkillsResponseDict +from .common import RetrieveSkillsResponseOrDict from .common import RollbackAgentEngineMemoryConfig from .common import RollbackAgentEngineMemoryConfigDict from .common import RollbackAgentEngineMemoryConfigOrDict @@ -2492,6 +2502,15 @@ "Skill", "SkillDict", "SkillOrDict", + "RetrieveSkillsConfig", + "RetrieveSkillsConfigDict", + "RetrieveSkillsConfigOrDict", + "RetrievedSkill", + "RetrievedSkillDict", + "RetrievedSkillOrDict", + "RetrieveSkillsResponse", + "RetrieveSkillsResponseDict", + "RetrieveSkillsResponseOrDict", "PromptOptimizerConfig", "PromptOptimizerConfigDict", "PromptOptimizerConfigOrDict", @@ -2730,6 +2749,7 @@ "_GetCustomJobParameters", "_OptimizeRequestParameters", "_GetSkillRequestParameters", + "_RetrieveSkillsRequestParameters", "evals", "agent_engines", "prompts", diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 8f5532d721..f624a65cc5 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -17933,6 +17933,103 @@ class SkillDict(TypedDict, total=False): SkillOrDict = Union[Skill, SkillDict] +class RetrieveSkillsConfig(_common.BaseModel): + """Config for retrieving skills.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + top_k: Optional[int] = Field( + default=None, + description="""Optional. The maximum number of skills to return. The service may + return fewer than this value. If unspecified, at most 10 skills will be + returned. The maximum value is 100. + """, + ) + + +class RetrieveSkillsConfigDict(TypedDict, total=False): + """Config for retrieving skills.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + top_k: Optional[int] + """Optional. The maximum number of skills to return. The service may + return fewer than this value. If unspecified, at most 10 skills will be + returned. The maximum value is 100. + """ + + +RetrieveSkillsConfigOrDict = Union[RetrieveSkillsConfig, RetrieveSkillsConfigDict] + + +class _RetrieveSkillsRequestParameters(_common.BaseModel): + """Parameters for retrieving skills.""" + + query: Optional[str] = Field( + default=None, description="""Required. The query to find matching skills.""" + ) + config: Optional[RetrieveSkillsConfig] = Field(default=None, description="""""") + + +class _RetrieveSkillsRequestParametersDict(TypedDict, total=False): + """Parameters for retrieving skills.""" + + query: Optional[str] + """Required. The query to find matching skills.""" + + config: Optional[RetrieveSkillsConfigDict] + """""" + + +_RetrieveSkillsRequestParametersOrDict = Union[ + _RetrieveSkillsRequestParameters, _RetrieveSkillsRequestParametersDict +] + + +class RetrievedSkill(_common.BaseModel): + """A retrieved skill from semantic search.""" + + skill_name: Optional[str] = Field( + default=None, description="""The resource name of the skill.""" + ) + description: Optional[str] = Field( + default=None, description="""The description of the skill.""" + ) + + +class RetrievedSkillDict(TypedDict, total=False): + """A retrieved skill from semantic search.""" + + skill_name: Optional[str] + """The resource name of the skill.""" + + description: Optional[str] + """The description of the skill.""" + + +RetrievedSkillOrDict = Union[RetrievedSkill, RetrievedSkillDict] + + +class RetrieveSkillsResponse(_common.BaseModel): + """Response for retrieving skills.""" + + retrieved_skills: Optional[list[RetrievedSkill]] = Field( + default=None, description="""List of retrieved skills ranked by similarity.""" + ) + + +class RetrieveSkillsResponseDict(TypedDict, total=False): + """Response for retrieving skills.""" + + retrieved_skills: Optional[list[RetrievedSkillDict]] + """List of retrieved skills ranked by similarity.""" + + +RetrieveSkillsResponseOrDict = Union[RetrieveSkillsResponse, RetrieveSkillsResponseDict] + + class PromptOptimizerConfig(_common.BaseModel): """VAPO Prompt Optimizer Config."""