@@ -321,16 +321,17 @@ def test_get_endpoint(self, mock_httpx, registry):
321321 assert server == {"name" : "test-endpoint" }
322322
323323 @pytest .mark .parametrize (
324- "url, expected_auth" ,
324+ "url, expected_auth, use_custom_provider " ,
325325 [
326- ("https://mcp.com" , False ),
327- ("https://mcp.googleapis.com/v1" , True ),
328- ("https://example.com/googleapis/v1" , False ),
326+ ("https://mcp.com" , False , False ),
327+ ("https://mcp.googleapis.com/v1" , True , False ),
328+ ("https://example.com/googleapis/v1" , False , False ),
329+ ("https://mcp.googleapis.com/v1" , True , True ),
329330 ],
330331 )
331332 @patch ("httpx.Client" )
332333 def test_get_mcp_toolset_auth_headers (
333- self , mock_httpx , registry , url , expected_auth
334+ self , mock_httpx , registry , url , expected_auth , use_custom_provider
334335 ):
335336 mock_response = MagicMock ()
336337 mock_response .json .return_value = {
@@ -345,19 +346,34 @@ def test_get_mcp_toolset_auth_headers(
345346 mock_response
346347 )
347348
349+ if use_custom_provider :
350+ custom_header_provider = lambda context : {
351+ "Authorization" : "Bearer custom_token"
352+ }
353+ with patch (
354+ "google.auth.default" , return_value = (MagicMock (), "project-id" )
355+ ):
356+ registry = AgentRegistry (
357+ project_id = "test-project" ,
358+ location = "global" ,
359+ header_provider = custom_header_provider ,
360+ )
361+
348362 registry ._credentials .token = "token"
349363 registry ._credentials .refresh = MagicMock ()
350364
351365 toolset = registry .get_mcp_toolset ("test-mcp" )
352366 assert isinstance (toolset , McpToolset )
353367 assert toolset .tool_name_prefix == "TestPrefix"
354- if expected_auth :
355- assert toolset ._connection_params .headers is not None
356- assert (
357- toolset ._connection_params .headers ["Authorization" ] == "Bearer token"
358- )
368+ assert toolset ._connection_params .headers is None
369+ headers = toolset ._header_provider (MagicMock ())
370+
371+ if use_custom_provider :
372+ assert headers .get ("Authorization" ) == "Bearer custom_token"
373+ elif expected_auth :
374+ assert headers .get ("Authorization" ) == "Bearer token"
359375 else :
360- assert toolset . _connection_params . headers is None
376+ assert "Authorization" not in headers
361377
362378 @patch ("httpx.Client" )
363379 def test_get_mcp_toolset_with_auth (self , mock_httpx , registry ):
@@ -392,6 +408,40 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry):
392408 assert auth_config .auth_scheme == auth_scheme
393409 assert auth_config .raw_auth_credential == auth_credential
394410
411+ @patch ("httpx.Client" )
412+ def test_get_mcp_toolset_with_auth_blocks_gcp_headers (
413+ self , mock_httpx , registry
414+ ):
415+ mock_response = MagicMock ()
416+ mock_response .json .return_value = {
417+ "displayName" : "TestPrefix" ,
418+ "interfaces" : [{
419+ "url" : "https://mcp.googleapis.com/v1" ,
420+ "protocolBinding" : "JSONRPC" ,
421+ }],
422+ }
423+ mock_response .raise_for_status = MagicMock ()
424+ mock_httpx .return_value .__enter__ .return_value .get .return_value = (
425+ mock_response
426+ )
427+
428+ registry ._credentials .token = "token"
429+ registry ._credentials .refresh = MagicMock ()
430+
431+ auth_scheme = OAuth2 (flows = {})
432+ auth_credential = AuthCredential (
433+ auth_type = "oauth2" ,
434+ oauth2 = OAuth2Auth (client_id = "test_id" , client_secret = "test_secret" ),
435+ )
436+
437+ toolset = registry .get_mcp_toolset (
438+ "test-mcp" , auth_scheme = auth_scheme , auth_credential = auth_credential
439+ )
440+ assert isinstance (toolset , McpToolset )
441+
442+ headers = toolset ._header_provider (MagicMock ())
443+ assert "Authorization" not in headers
444+
395445 @patch ("httpx.Client" )
396446 def test_get_remote_a2a_agent (self , mock_httpx , registry ):
397447 mock_response = MagicMock ()
0 commit comments