Skip to content

feat: Add import_embeddings method in MatchingEngineIndex resource #5473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from google.auth import credentials as auth_credentials
from google.protobuf import field_mask_pb2
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform.compat.types import (
index_service as gca_index_service,
index_service_v1beta1 as gca_index_service_v1beta1,
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
matching_engine_index as gca_matching_engine_index,
encryption_spec as gca_encryption_spec,
Expand Down Expand Up @@ -393,6 +395,58 @@ def update_embeddings(

return self

def import_embeddings(
    self,
    config: gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig,
    is_complete_overwrite: Optional[bool] = None,
    request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
    import_request_timeout: Optional[float] = None,
) -> "MatchingEngineIndex":
    """Imports embeddings into this index from an external source, e.g. BigQuery.

    Blocks until the underlying long-running import operation completes,
    then refreshes this resource object from the operation result.

    Args:
        config (gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig):
            Required. The configuration for importing data from an external
            source (e.g. a BigQuery table and its datapoint field mapping).
        is_complete_overwrite (bool):
            Optional. If true, completely replace existing index data. Must be
            true for streaming update indexes.
        request_metadata (Sequence[Tuple[str, str]]):
            Optional. Strings which should be sent along with the request as metadata.
        import_request_timeout (float):
            Optional. The timeout for the request in seconds.

    Returns:
        MatchingEngineIndex - The updated index resource object.
    """
    # Ensure any in-flight LRO on this resource has finished before importing.
    self.wait()

    _LOGGER.log_action_start_against_resource(
        "Importing embeddings",
        "index",
        self,
    )

    # ImportIndex is only exposed on the v1beta1 surface, so explicitly
    # select the v1beta1 client regardless of the SDK's default version.
    api_v1beta1_client = self.api_client.select_version(compat.V1BETA1)
    import_lro = api_v1beta1_client.import_index(
        request=gca_index_service_v1beta1.ImportIndexRequest(
            name=self.resource_name,
            config=config,
            is_complete_overwrite=is_complete_overwrite,
        ),
        metadata=request_metadata,
        timeout=import_request_timeout,
    )

    _LOGGER.log_action_started_against_resource_with_lro(
        "Import", "index", self.__class__, import_lro
    )

    # Block until the import finishes; `import_request_timeout` bounds only
    # the RPC that starts the operation, not the operation itself.
    self._gca_resource = import_lro.result(timeout=None)

    _LOGGER.log_action_completed_against_resource("index", "Imported", self)

    return self

@property
def deployed_indexes(
self,
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.compat.services import (
index_service_client,
index_service_client_v1beta1,
)

from google.cloud.aiplatform.matching_engine import (
Expand All @@ -40,6 +41,7 @@
index as gca_index,
encryption_spec as gca_encryption_spec,
index_service as gca_index_service,
index_service_v1beta1 as gca_index_service_v1beta1,
)
import constants as test_constants

Expand All @@ -66,6 +68,11 @@
_TEST_CONTENTS_DELTA_URI_UPDATE = "gs://contents_update"
_TEST_IS_COMPLETE_OVERWRITE_UPDATE = True

_TEST_BQ_SOURCE_PATH = "bq://my-project.my-dataset.my-table"
_TEST_ID_COLUMN = "id"
_TEST_EMBEDDING_COLUMN = "embedding"


_TEST_INDEX_CONFIG_DIMENSIONS = 100
_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT = 150
_TEST_LEAF_NODE_EMBEDDING_COUNT = 123
Expand Down Expand Up @@ -208,6 +215,19 @@ def update_index_embeddings_mock():
yield update_index_mock


@pytest.fixture
def import_index_mock():
    """Patch IndexServiceClient.import_index to return an LRO resolving to a bare Index."""
    with patch.object(
        index_service_client_v1beta1.IndexServiceClient, "import_index"
    ) as mocked_import:
        lro = mock.Mock(operation.Operation)
        lro.result.return_value = gca_index.Index(name=_TEST_INDEX_NAME)
        mocked_import.return_value = lro
        yield mocked_import


@pytest.fixture
def list_indexes_mock():
with patch.object(
Expand Down Expand Up @@ -337,6 +357,42 @@ def test_update_index_embeddings(self, update_index_embeddings_mock):
# The service only returns the name of the Index
assert updated_index.gca_resource == gca_index.Index(name=_TEST_INDEX_NAME)

@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("is_complete_overwrite", [True, False, None])
def test_import_embeddings(self, import_index_mock, is_complete_overwrite):
    """import_embeddings should send a v1beta1 ImportIndexRequest and refresh the resource."""
    aiplatform.init(project=_TEST_PROJECT)

    index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)

    # Build the connector config from named intermediates for readability.
    field_mapping = gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig.DatapointFieldMapping(
        id_column=_TEST_ID_COLUMN,
        embedding_column=_TEST_EMBEDDING_COLUMN,
    )
    bq_source = gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig.BigQuerySourceConfig(
        table_path=_TEST_BQ_SOURCE_PATH,
        datapoint_field_mapping=field_mapping,
    )
    connector_config = gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig(
        big_query_source_config=bq_source
    )

    result = index.import_embeddings(
        config=connector_config,
        is_complete_overwrite=is_complete_overwrite,
        import_request_timeout=_TEST_TIMEOUT,
    )

    import_index_mock.assert_called_once_with(
        request=gca_index_service_v1beta1.ImportIndexRequest(
            name=_TEST_INDEX_NAME,
            config=connector_config,
            is_complete_overwrite=is_complete_overwrite,
        ),
        metadata=_TEST_REQUEST_METADATA,
        timeout=_TEST_TIMEOUT,
    )
    # The service only returns the name of the Index.
    assert result.gca_resource == gca_index.Index(name=_TEST_INDEX_NAME)

def test_list_indexes(self, list_indexes_mock):
aiplatform.init(project=_TEST_PROJECT)

Expand Down