Skip to content

feat: GenAI SDK client - Add Vertex AI Prompt Optimizer to the Gen AI SDK (experimental) #5456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 25 additions & 63 deletions tests/unit/vertexai/genai/test_prompt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,14 @@
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation
import copy

import importlib
from unittest import mock

from google.cloud import aiplatform
import vertexai
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
custom_job as gca_custom_job_compat,
)
from google.cloud.aiplatform.compat.types import io as gca_io_compat
from google.cloud.aiplatform.compat.types import (
job_state as gca_job_state_compat,
)

# from google.cloud.aiplatform.utils import gcs_utils
# from google.genai import client
from vertexai._genai import prompt_optimizer
from vertexai._genai import types
from google.genai import client
import pytest


Expand All @@ -42,64 +32,36 @@
_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
_TEST_BASE_OUTPUT_DIR = "gs://test_bucket/test_base_output_dir"

_TEST_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob(
display_name=_TEST_DISPLAY_NAME,
job_spec={
"base_output_directory": gca_io_compat.GcsDestination(
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
),
},
labels={"trained_by_vertex_ai": "true"},
)


@pytest.fixture
def mock_create_custom_job():
with mock.patch.object(
job_service_client.JobServiceClient, "create_custom_job"
) as create_custom_job_mock:
custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO)
custom_job_proto.name = _TEST_DISPLAY_NAME
custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_PENDING
create_custom_job_mock.return_value = custom_job_proto
yield create_custom_job_mock


class TestPromptOptimizer:
"""Unit tests for the Prompt Optimizer client."""

def setup_method(self):
importlib.reload(aiplatform_initializer)
importlib.reload(aiplatform)
importlib.reload(vertexai)
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

# @pytest.mark.usefixtures("google_auth_mock")
# def test_prompt_optimizer_client(self):
# test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
# assert test_client is not None
# assert test_client._api_client.vertexai
# assert test_client._api_client.project == _TEST_PROJECT
# assert test_client._api_client.location == _TEST_LOCATION
@pytest.mark.usefixtures("google_auth_mock")
def test_prompt_optimizer_client(self):
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
assert test_client.prompt_optimizer is not None

@mock.patch.object(client.Client, "_get_api_client")
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_create_custom_job_resource")
def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
"""Test that prompt_optimizer.optimize method creates a custom job."""
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
test_client.prompt_optimizer.optimize(
method="vapo",
config=types.PromptOptimizerVAPOConfig(
config_path="gs://ssusie-vapo-sdk-test/config.json",
wait_for_completion=False,
service_account="test-service-account",
),
)
mock_client.assert_called_once()
mock_custom_job.assert_called_once()

# @mock.patch.object(client.Client, "_get_api_client")
# @mock.patch.object(
# gcs_utils.resource_manager_utils, "get_project_number", return_value=12345
# )
# def test_prompt_optimizer_optimize(
# self, mock_get_project_number, mock_client, mock_create_custom_job
# ):
# """Test that prompt_optimizer.optimize method creates a custom job."""
# test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
# test_client.prompt_optimizer.optimize(
# method="vapo",
# config={
# "config_path": "gs://ssusie-vapo-sdk-test/config.json",
# "wait_for_completion": False,
# },
# )
# mock_create_custom_job.assert_called_once()
# mock_get_project_number.assert_called_once()
# TODO(b/415060797): add more tests for prompt_optimizer.optimize
12 changes: 12 additions & 0 deletions vertexai/_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ def evals(self):
) from e
return self._evals.Evals(self._api_client)

@property
@_common.experimental_warning(
"The Vertex SDK GenAI prompt optimizer module is experimental, and may change in future "
"versions."
)
def prompt_optimizer(self):
if self._prompt_optimizer is None:
self._prompt_optimizer = importlib.import_module(
".prompt_optimizer", __package__
)
return self._prompt_optimizer.PromptOptimizer(self._api_client)

@property
@_common.experimental_warning(
"The Vertex SDK GenAI async client is experimental, "
Expand Down
Loading