diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index e1fdc1d571..277f115e47 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -668,13 +668,39 @@ class PublisherModelConfig(proto.Message): logging_config (google.cloud.aiplatform_v1beta1.types.PredictRequestResponseLoggingConfig): The prediction request/response logging config. + data_sharing_enabled_provider (google.cloud.aiplatform_v1beta1.types.PublisherModelConfig.ModelProvider): + Optional. The model provider (publisher) for + which the customer has enabled data sharing. For + publisher models that are configured to require + data sharing, a prediction request is only + allowed when the model's publisher matches this + provider. Otherwise, the request is rejected. """ + class ModelProvider(proto.Enum): + r"""A model provider (publisher) that prediction data may be + shared with. + + Values: + MODEL_PROVIDER_UNSPECIFIED (0): + Unspecified model provider. + ANTHROPIC (1): + Anthropic. + """ + + MODEL_PROVIDER_UNSPECIFIED = 0 + ANTHROPIC = 1 + logging_config: "PredictRequestResponseLoggingConfig" = proto.Field( proto.MESSAGE, number=3, message="PredictRequestResponseLoggingConfig", ) + data_sharing_enabled_provider: ModelProvider = proto.Field( + proto.ENUM, + number=4, + enum=ModelProvider, + ) class ClientConnectionConfig(proto.Message): diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 08e4ff30d8..53753f28da 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.158.0" + "version": "0.0.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 2eec3a1799..bb8afa1c8f 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.158.0" + "version": "0.0.0" }, "snippets": [ { diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index cdc156eacc..f34c67f287 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -4553,6 +4553,7 @@ def test_fetch_publisher_model_config(request_type, transport: str = 'grpc'): '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.PublisherModelConfig( + data_sharing_enabled_provider=endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC, ) response = client.fetch_publisher_model_config(request) @@ -4564,6 +4565,7 @@ def test_fetch_publisher_model_config(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, endpoint.PublisherModelConfig) + assert response.data_sharing_enabled_provider == endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC def test_fetch_publisher_model_config_non_empty_request_with_auto_populated_field(): @@ -4681,6 +4683,7 @@ async def test_fetch_publisher_model_config_async(request_type, transport: str = '__call__') as call: # Designate an appropriate return value for the call. call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(endpoint.PublisherModelConfig( + data_sharing_enabled_provider=endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC, )) response = await client.fetch_publisher_model_config(request) @@ -4692,6 +4695,7 @@ async def test_fetch_publisher_model_config_async(request_type, transport: str = # Establish that the response is the type that we expect. assert isinstance(response, endpoint.PublisherModelConfig) + assert response.data_sharing_enabled_provider == endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC def test_fetch_publisher_model_config_field_headers(): client = EndpointServiceClient( @@ -7413,6 +7417,7 @@ async def test_fetch_publisher_model_config_empty_call_grpc_asyncio(): '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.PublisherModelConfig( + data_sharing_enabled_provider=endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC, )) await client.fetch_publisher_model_config(request=None) @@ -8667,6 +8672,7 @@ def test_fetch_publisher_model_config_rest_call_success(request_type): with mock.patch.object(type(client.transport._session), 'request') as req: # Designate an appropriate value for the returned response. return_value = endpoint.PublisherModelConfig( + data_sharing_enabled_provider=endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC, ) # Wrap the value into a proper Response obj @@ -8683,6 +8689,7 @@ def test_fetch_publisher_model_config_rest_call_success(request_type): # Establish that the response is the type that we expect. assert isinstance(response, endpoint.PublisherModelConfig) + assert response.data_sharing_enabled_provider == endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -10836,6 +10843,7 @@ async def test_fetch_publisher_model_config_rest_asyncio_call_success(request_ty with mock.patch.object(type(client.transport._session), 'request') as req: # Designate an appropriate value for the returned response. return_value = endpoint.PublisherModelConfig( + data_sharing_enabled_provider=endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC, ) # Wrap the value into a proper Response obj @@ -10852,6 +10860,7 @@ async def test_fetch_publisher_model_config_rest_asyncio_call_success(request_ty # Establish that the response is the type that we expect. assert isinstance(response, endpoint.PublisherModelConfig) + assert response.data_sharing_enabled_provider == endpoint.PublisherModelConfig.ModelProvider.ANTHROPIC @pytest.mark.asyncio