diff --git a/tests/unit/vertexai/genai/test_sandbox.py b/tests/unit/vertexai/genai/test_sandbox.py index cc8cd7968b..45949ac60d 100644 --- a/tests/unit/vertexai/genai/test_sandbox.py +++ b/tests/unit/vertexai/genai/test_sandbox.py @@ -73,8 +73,9 @@ def teardown_method(self): @mock.patch.object(client.Client, "_get_api_client") def test_send_command(self, mock_get_api_client): mock_sandbox = mock.Mock() - mock_sandbox.connection_info.load_balancer_ip = "127.0.0.1" - mock_sandbox.connection_info.load_balancer_hostname = None + mock_sandbox.connection_info.load_balancer_ip = None + mock_sandbox.connection_info.load_balancer_hostname = "test-us-central1.autopush-sandbox.vertexai.goog" + mock_sandbox.connection_info.routing_token = "test_routing_token" mock_http_client = mock_get_api_client.return_value mock_http_client.request.return_value = genai_types.HttpResponse( body=b"{}", headers={} @@ -91,7 +92,44 @@ def test_send_command(self, mock_get_api_client): assert call_args is not None _, kwargs = call_args http_options = kwargs["http_options"] - assert http_options.base_url == "http://127.0.0.1/test/path" + assert http_options.base_url == ( + "https://test-us-central1.autopush-sandbox.vertexai.goog/test/path" + ) assert http_options.headers["Authorization"] == "Bearer test_token" mock_http_client.request.assert_called_with("GET", "test/path", {}) + + @mock.patch( + "google.cloud.aiplatform.vertexai._genai.sandboxes.Sandboxes.generate_access_token" + ) + @mock.patch.object(client.Client, "_get_api_client") + def test_generate_browser_ws_headers( + self, mock_get_api_client, mock_generate_access_token + ): + mock_generate_access_token.return_value = "test_token" + + mock_sandbox = mock.Mock() + mock_sandbox.connection_info.load_balancer_ip = None + mock_sandbox.connection_info.load_balancer_hostname = ( + "test-us-central1.autopush-sandbox.vertexai.goog" + ) + mock_sandbox.connection_info.routing_token = "test_routing_token" + mock_http_client = mock_get_api_client.return_value + mock_http_client.request.return_value = genai_types.HttpResponse( + body=b'{"endpoint": "test/endpoint"}', headers={} + ) + ws_url, headers = ( + self.client.agent_engines.sandboxes.generate_browser_ws_headers( + sandbox_environment=mock_sandbox, + service_account_email=_TEST_SERVICE_ACCOUNT_EMAIL, + timeout=3600, + ) + ) + assert ( + ws_url + == "wss://test-us-central1.autopush-sandbox.vertexai.goog/test/endpoint" + ) + assert ( + headers["Sec-WebSocket-Protocol"] + == "v1.stream, test_token, test_routing_token, 9222" + ) diff --git a/vertexai/_genai/sandboxes.py b/vertexai/_genai/sandboxes.py index fbd7cc0a15..6b1ba20e45 100644 --- a/vertexai/_genai/sandboxes.py +++ b/vertexai/_genai/sandboxes.py @@ -56,6 +56,23 @@ def _CreateAgentEngineSandboxConfig_to_vertex( if getv(from_object, ["ttl"]) is not None: setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + if getv(from_object, ["sandbox_environment_template"]) is not None: + setv( + parent_object, + ["sandboxEnvironmentTemplate"], + getv(from_object, ["sandbox_environment_template"]), + ) + + if getv(from_object, ["sandbox_environment_snapshot"]) is not None: + setv( + parent_object, + ["sandboxEnvironmentSnapshot"], + getv(from_object, ["sandbox_environment_snapshot"]), + ) + + if getv(from_object, ["owner"]) is not None: + setv(parent_object, ["owner"], getv(from_object, ["owner"])) + return to_object @@ -837,7 +854,7 @@ def delete( def generate_access_token( self, service_account_email: str, - sandbox_id: str, + sandbox_hostname: str, port: str = "8080", timeout: int = 3600, ) -> str: @@ -846,8 +863,8 @@ def generate_access_token( Args: service_account_email (str): Required. The email of the service account to use for signing. - sandbox_id (str): - Required. The resource name of the sandbox to generate a token for. + sandbox_hostname (str): + Required. The hostname of the sandbox to generate a token for. port (str): Optional. The port to use for the token. Defaults to "8080". timeout (int): @@ -858,13 +875,14 @@ def generate_access_token( """ client = iam_credentials_v1.IAMCredentialsClient() name = f"projects/-/serviceAccounts/{service_account_email}" - custom_claims = {"port": port, "sandbox_id": sandbox_id} + custom_claims = {"hostname": sandbox_hostname, "port": port} payload = { "iat": int(time.time()), "exp": int(time.time()) + timeout, "iss": service_account_email, + "sub": service_account_email, "nonce": secrets.randbelow(1000000000) + 1, - "aud": "vmaas-proxy-api", # default audience for sandbox proxy + "aud": "https://aiplatform.googleapis.com/", # default audience for sandbox proxy **custom_claims, } request = iam_credentials_v1.SignJwtRequest( @@ -880,6 +898,7 @@ def send_command( http_method: str, access_token: str, sandbox_environment: types.SandboxEnvironment, + port: str = "8080", path: Optional[str] = None, query_params: Optional[dict[str, object]] = None, headers: Optional[dict[str, str]] = None, @@ -894,6 +913,8 @@ def send_command( Required. The access token to use for authorization. sandbox_environment (types.SandboxEnvironment): Required. The sandbox environment to send the command to. + port (str): + Optional. The port to use for the token. Defaults to "8080". This should be one of the ports specified during template creation. path (str): Optional. The path to send the command to. query_params (dict[str, object]): @@ -918,10 +939,16 @@ def send_command( else: raise ValueError("Load balancer hostname or ip is not available.") + routing_token = connection_info.routing_token + if not routing_token: + raise ValueError("Routing token is not available.") + path = path or "" if query_params: path = f"{path}?{urlencode(query_params)}" headers["Authorization"] = f"Bearer {access_token}" + headers["X-Sandbox-Routing-Token"] = routing_token + headers["X-Sandbox-Port"] = port endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) http_client = genai.Client(vertexai=True, http_options=http_options) @@ -937,6 +964,7 @@ def generate_browser_ws_headers( self, sandbox_environment: types.SandboxEnvironment, service_account_email: str, + port: str = "8080", timeout: int = 3600, ) -> tuple[str, dict[str, str]]: """Generates the websocket upgrade headers for the browser. @@ -946,47 +974,56 @@ def generate_browser_ws_headers( Required. The sandbox environment to generate websocket headers for. service_account_email (str): Required. The email of the service account to use for signing. + port (str): + Optional. The port to use for the CDP websocket endpoint url fetching. + Defaults to "8080". This should be one of the ports specified during template creation. timeout (int): Optional. The timeout in seconds for the token. Defaults to 3600. - Returns: tuple[str, dict[str, str]]: A tuple containing the websocket URL and the headers for websocket upgrade. """ - sandbox_id = sandbox_environment.name - # port 8080 is the default port for http endpoint. + if not sandbox_environment.connection_info: + raise ValueError("Connection info is not available.") + + connection_info = sandbox_environment.connection_info + if connection_info.load_balancer_hostname: + ws_base_url = "wss://" + connection_info.load_balancer_hostname + elif connection_info.load_balancer_ip: + ws_base_url = "ws://" + connection_info.load_balancer_ip + else: + raise ValueError("Load balancer hostname or ip is not available.") + http_access_token = self.generate_access_token( - service_account_email, sandbox_id, "8080", timeout + service_account_email, connection_info.load_balancer_hostname, port, timeout ) response = self.send_command( http_method="GET", access_token=http_access_token, sandbox_environment=sandbox_environment, + port=port, path="/cdp_ws_endpoint", ) if not response: raise ValueError("Failed to get the websocket endpoint.") body_dict = json.loads(response.body) ws_path = body_dict["endpoint"] - - ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog" - if sandbox_environment and sandbox_environment.connection_info: - connection_info = sandbox_environment.connection_info - if connection_info.load_balancer_hostname: - ws_url = "wss://" + connection_info.load_balancer_hostname - elif connection_info.load_balancer_ip: - ws_url = "ws://" + connection_info.load_balancer_ip - else: - raise ValueError("Load balancer hostname or ip is not available.") - ws_url = ws_url + "/" + ws_path + ws_url = ws_base_url + "/" + ws_path # port 9222 is the default port for the browser websocket endpoint. ws_access_token = self.generate_access_token( - service_account_email, sandbox_id, "9222", timeout + service_account_email, + connection_info.load_balancer_hostname, + "9222", + timeout, ) + routing_token = connection_info.routing_token + headers = {} - headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}" + headers["Sec-WebSocket-Protocol"] = ( + f"v1.stream, {ws_access_token}, {routing_token}, 9222" + ) return ws_url, headers diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index b2b26b8573..cf7b3f94cf 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -11733,6 +11733,20 @@ class CreateAgentEngineSandboxConfig(_common.BaseModel): default=None, description="""The TTL for this resource. The expiration time is computed: now + TTL.""", ) + sandbox_environment_template: Optional[str] = Field( + default=None, + description="""The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}""", + ) + sandbox_environment_snapshot: Optional[str] = Field( + default=None, + description="""The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}""", + ) + owner: Optional[str] = Field( + default=None, + description="""Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner.""", + ) class CreateAgentEngineSandboxConfigDict(TypedDict, total=False): @@ -11753,6 +11767,17 @@ class CreateAgentEngineSandboxConfigDict(TypedDict, total=False): ttl: Optional[str] """The TTL for this resource. The expiration time is computed: now + TTL.""" + sandbox_environment_template: Optional[str] + """The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}""" + + sandbox_environment_snapshot: Optional[str] + """The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}""" + + owner: Optional[str] + """Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner.""" + CreateAgentEngineSandboxConfigOrDict = Union[ CreateAgentEngineSandboxConfig, CreateAgentEngineSandboxConfigDict