From 7cd510c3c4f008bc9110eb43dea1819942b947cb Mon Sep 17 00:00:00 2001 From: Jon Gallant <2163001+jongio@users.noreply.github.com> Date: Sun, 1 Mar 2026 20:18:47 -0800 Subject: [PATCH 1/6] feat(azdext): add integration helpers for keyvault and config Implements #6945 (P1-5/P1-6). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cli/azd/pkg/azdext/config_helper.go | 405 ++++++++ cli/azd/pkg/azdext/config_helper_test.go | 967 +++++++++++++++++++ cli/azd/pkg/azdext/keyvault_resolver.go | 298 ++++++ cli/azd/pkg/azdext/keyvault_resolver_test.go | 575 +++++++++++ 4 files changed, 2245 insertions(+) create mode 100644 cli/azd/pkg/azdext/config_helper.go create mode 100644 cli/azd/pkg/azdext/config_helper_test.go create mode 100644 cli/azd/pkg/azdext/keyvault_resolver.go create mode 100644 cli/azd/pkg/azdext/keyvault_resolver_test.go diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go new file mode 100644 index 00000000000..9df1f8c8f5e --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper.go @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" +) + +// ConfigHelper provides typed, ergonomic access to azd configuration through +// the gRPC UserConfig and Environment services. It eliminates the boilerplate +// of raw gRPC calls and JSON marshaling that extension authors otherwise need. +// +// Configuration sources (in merge priority, lowest to highest): +// 1. User config (global azd config) — via UserConfigService +// 2. Environment config (per-env) — via EnvironmentService +// +// Usage: +// +// ch := azdext.NewConfigHelper(client) +// port, err := ch.GetUserString(ctx, "extensions.myext.port") +// var cfg MyConfig +// err = ch.GetUserJSON(ctx, "extensions.myext", &cfg) +type ConfigHelper struct { + client *AzdClient +} + +// NewConfigHelper creates a [ConfigHelper] for the given AZD client. +func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) { + if client == nil { + return nil, errors.New("azdext.NewConfigHelper: client must not be nil") + } + + return &ConfigHelper{client: client}, nil +} + +// --- User Config (global) --- + +// GetUserString retrieves a string value from the global user config at the +// given dot-separated path. Returns ("", false, nil) when the path does not +// exist, and ("", false, err) on gRPC errors. +func (ch *ConfigHelper) GetUserString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.UserConfig().GetString(ctx, &GetUserConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetUserString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetUserJSON retrieves a value from the global user config and unmarshals it +// into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetUserJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetUserJSON: out must not be nil") + } + + resp, err := ch.client.UserConfig().Get(ctx, &GetUserConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetUserJSON marshals value as JSON and writes it to the global user config +// at the given path. +func (ch *ConfigHelper) SetUserJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetUserJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for path %q: %w", path, err), + } + } + + _, err = ch.client.UserConfig().Set(ctx, &SetUserConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetUser removes a value from the global user config. +func (ch *ConfigHelper) UnsetUser(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.UserConfig().Unset(ctx, &UnsetUserConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetUser: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Environment Config (per-environment) --- + +// GetEnvString retrieves a string config value from the current environment. +// Returns ("", false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.Environment().GetConfigString(ctx, &GetConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetEnvString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetEnvJSON retrieves a value from the current environment's config and +// unmarshals it into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetEnvJSON: out must not be nil") + } + + resp, err := ch.client.Environment().GetConfig(ctx, &GetConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal env config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetEnvJSON marshals value as JSON and writes it to the current environment's config. +func (ch *ConfigHelper) SetEnvJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetEnvJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for env config path %q: %w", path, err), + } + } + + _, err = ch.client.Environment().SetConfig(ctx, &SetConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetEnv removes a value from the current environment's config. +func (ch *ConfigHelper) UnsetEnv(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.Environment().UnsetConfig(ctx, &UnsetConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetEnv: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Merge --- + +// MergeJSON performs a shallow merge of override into base, returning a new map. +// Both inputs must be JSON-compatible maps (map[string]any). Keys in override +// take precedence over keys in base. +// +// This is NOT a deep merge — nested maps are replaced entirely by the override +// value. For predictable extension config behavior, keep config structures flat +// or use explicit path-based Set operations for nested values. +func MergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + merged[k] = v + } + + return merged +} + +// DeepMergeJSON performs a recursive merge of override into base. +// When both base and override have a map value for the same key, those maps +// are merged recursively. Otherwise the override value replaces the base value. +func DeepMergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + baseVal, exists := merged[k] + if !exists { + merged[k] = v + continue + } + + baseMap, baseIsMap := baseVal.(map[string]any) + overMap, overIsMap := v.(map[string]any) + + if baseIsMap && overIsMap { + merged[k] = DeepMergeJSON(baseMap, overMap) + } else { + merged[k] = v + } + } + + return merged +} + +// --- Validation --- + +// ConfigValidator defines a function that validates a config value. +// It returns nil if valid, or an error describing the validation failure. +type ConfigValidator func(value any) error + +// ValidateConfig unmarshals the raw JSON data and runs all supplied validators. +// Returns the first validation error encountered, wrapped in a [*ConfigError]. +func ValidateConfig(path string, data []byte, validators ...ConfigValidator) error { + if len(data) == 0 { + return &ConfigError{ + Path: path, + Reason: ConfigReasonMissing, + Err: fmt.Errorf("config at path %q is empty", path), + } + } + + var value any + if err := json.Unmarshal(data, &value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("config at path %q is not valid JSON: %w", path, err), + } + } + + for _, v := range validators { + if err := v(value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonValidationFailed, + Err: fmt.Errorf("config validation failed at path %q: %w", path, err), + } + } + } + + return nil +} + +// RequiredKeys returns a [ConfigValidator] that checks for the presence of +// the specified keys in a map value. +func RequiredKeys(keys ...string) ConfigValidator { + return func(value any) error { + m, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("expected object, got %T", value) + } + + for _, key := range keys { + if _, exists := m[key]; !exists { + return fmt.Errorf("required key %q is missing", key) + } + } + + return nil + } +} + +// --- Error types --- + +// ConfigReason classifies the cause of a [ConfigError]. +type ConfigReason int + +const ( + // ConfigReasonMissing indicates the config path does not exist or is empty. + ConfigReasonMissing ConfigReason = iota + + // ConfigReasonInvalidFormat indicates the config value is not valid JSON + // or cannot be unmarshaled into the target type. + ConfigReasonInvalidFormat + + // ConfigReasonValidationFailed indicates a validator rejected the config value. + ConfigReasonValidationFailed +) + +// String returns a human-readable label. +func (r ConfigReason) String() string { + switch r { + case ConfigReasonMissing: + return "missing" + case ConfigReasonInvalidFormat: + return "invalid_format" + case ConfigReasonValidationFailed: + return "validation_failed" + default: + return "unknown" + } +} + +// ConfigError is returned by [ConfigHelper] methods on domain-level failures. +type ConfigError struct { + // Path is the config path that was being accessed. + Path string + + // Reason classifies the failure. + Reason ConfigReason + + // Err is the underlying error. + Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("azdext.ConfigHelper: %s (path=%s): %v", e.Reason, e.Path, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +// validatePath checks that a config path is non-empty. +func validatePath(path string) error { + if path == "" { + return errors.New("azdext.ConfigHelper: config path must not be empty") + } + + return nil +} diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go new file mode 100644 index 00000000000..c68464f16bc --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -0,0 +1,967 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "google.golang.org/grpc" +) + +// --- Stub UserConfigService --- + +type stubUserConfigService struct { + getResp *GetUserConfigResponse + getStringResp *GetUserConfigStringResponse + getSectionErr error + getErr error + getStringErr error + setErr error + unsetErr error +} + +func (s *stubUserConfigService) Get( + _ context.Context, _ *GetUserConfigRequest, _ ...grpc.CallOption, +) (*GetUserConfigResponse, error) { + return s.getResp, s.getErr +} + +func (s *stubUserConfigService) GetString( + _ context.Context, _ *GetUserConfigStringRequest, _ ...grpc.CallOption, +) (*GetUserConfigStringResponse, error) { + return s.getStringResp, s.getStringErr +} + +func (s *stubUserConfigService) GetSection( + _ context.Context, _ *GetUserConfigSectionRequest, _ ...grpc.CallOption, +) (*GetUserConfigSectionResponse, error) { + return nil, s.getSectionErr +} + +func (s *stubUserConfigService) Set( + _ context.Context, _ *SetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setErr +} + +func (s *stubUserConfigService) Unset( + _ context.Context, _ *UnsetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetErr +} + +// --- Stub EnvironmentService --- + +type stubEnvironmentService struct { + getConfigResp *GetConfigResponse + getConfigStringResp *GetConfigStringResponse + getConfigErr error + getConfigStringErr error + setConfigErr error + unsetConfigErr error +} + +func (s *stubEnvironmentService) GetCurrent( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) List( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Get( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Select( + _ context.Context, _ *SelectEnvironmentRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValues( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*KeyValueListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValue( + _ context.Context, _ *GetEnvRequest, _ ...grpc.CallOption, +) (*KeyValueResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetValue( + _ context.Context, _ *SetEnvRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetConfig( + _ context.Context, _ *GetConfigRequest, _ ...grpc.CallOption, +) (*GetConfigResponse, error) { + return s.getConfigResp, s.getConfigErr +} + +func (s *stubEnvironmentService) GetConfigString( + _ context.Context, _ *GetConfigStringRequest, _ ...grpc.CallOption, +) (*GetConfigStringResponse, error) { + return s.getConfigStringResp, s.getConfigStringErr +} + +func (s *stubEnvironmentService) GetConfigSection( + _ context.Context, _ *GetConfigSectionRequest, _ ...grpc.CallOption, +) (*GetConfigSectionResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetConfig( + _ context.Context, _ *SetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setConfigErr +} + +func (s *stubEnvironmentService) UnsetConfig( + _ context.Context, _ *UnsetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetConfigErr +} + +// --- NewConfigHelper --- + +func TestNewConfigHelper_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewConfigHelper(nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewConfigHelper_Success(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, err := NewConfigHelper(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ch == nil { + t.Fatal("expected non-nil ConfigHelper") + } +} + +// --- GetUserString --- + +func TestGetUserString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestGetUserString_Found(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "8080", Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "extensions.myext.port") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "8080" { + t.Errorf("value = %q, want %q", val, "8080") + } +} + +func TestGetUserString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "nonexistent.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } + + if val != "" { + t.Errorf("value = %q, want empty", val) + } +} + +func TestGetUserString_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringErr: errors.New("grpc unavailable"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "some.path") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +// --- GetUserJSON --- + +func TestGetUserJSON_Found(t *testing.T) { + t.Parallel() + + type myConfig struct { + Port int `json:"port"` + Host string `json:"host"` + } + + data, _ := json.Marshal(myConfig{Port: 3000, Host: "localhost"}) + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg myConfig + found, err := ch.GetUserJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if cfg.Port != 3000 { + t.Errorf("Port = %d, want 3000", cfg.Port) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } +} + +func TestGetUserJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetUserJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetUserJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +func TestGetUserJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: []byte("not json"), Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "bad.json", &cfg) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestGetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "", &cfg) + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- SetUserJSON --- + +func TestSetUserJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "extensions.myext.port", 3000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetUserJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +func TestSetUserJSON_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + setErr: errors.New("grpc write error"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", "value") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +func TestSetUserJSON_UnmarshalableValue(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + // Channels cannot be marshaled to JSON + err := ch.SetUserJSON(context.Background(), "some.path", make(chan int)) + if err == nil { + t.Fatal("expected error for unmarshalable value") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +// --- UnsetUser --- + +func TestUnsetUser_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetUser_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvString --- + +func TestGetEnvString_Found(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "prod", Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetEnvString(context.Background(), "extensions.myext.mode") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "prod" { + t.Errorf("value = %q, want %q", val, "prod") + } +} + +func TestGetEnvString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + _, found, err := ch.GetEnvString(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetEnvString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvJSON --- + +func TestGetEnvJSON_Found(t *testing.T) { + t.Parallel() + + type envConfig struct { + Debug bool `json:"debug"` + } + + data, _ := json.Marshal(envConfig{Debug: true}) + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg envConfig + found, err := ch.GetEnvJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if !cfg.Debug { + t.Error("expected Debug = true") + } +} + +func TestGetEnvJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetEnvJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +// --- SetEnvJSON --- + +func TestSetEnvJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "extensions.myext.mode", "prod") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetEnvJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetEnvJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +// --- UnsetEnv --- + +func TestUnsetEnv_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetEnv_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- MergeJSON --- + +func TestMergeJSON_Basic(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1, "b": 2} + override := map[string]any{"b": 3, "c": 4} + + result := MergeJSON(base, override) + + if result["a"] != 1 { + t.Errorf("a = %v, want 1", result["a"]) + } + + if result["b"] != 3 { + t.Errorf("b = %v, want 3 (override wins)", result["b"]) + } + + if result["c"] != 4 { + t.Errorf("c = %v, want 4", result["c"]) + } +} + +func TestMergeJSON_EmptyBase(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, map[string]any{"x": "y"}) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_EmptyOverride(t *testing.T) { + t.Parallel() + + result := MergeJSON(map[string]any{"x": "y"}, nil) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_BothEmpty(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, nil) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1} + override := map[string]any{"b": 2} + + _ = MergeJSON(base, override) + + if _, ok := base["b"]; ok { + t.Error("MergeJSON mutated base map") + } + + if _, ok := override["a"]; ok { + t.Error("MergeJSON mutated override map") + } +} + +// --- DeepMergeJSON --- + +func TestDeepMergeJSON_RecursiveMerge(t *testing.T) { + t.Parallel() + + base := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 3000, + }, + "debug": false, + } + + override := map[string]any{ + "server": map[string]any{ + "port": 8080, + "tls": true, + }, + "version": "1.0", + } + + result := DeepMergeJSON(base, override) + + server, ok := result["server"].(map[string]any) + if !ok { + t.Fatal("server should be a map") + } + + if server["host"] != "localhost" { + t.Errorf("server.host = %v, want localhost", server["host"]) + } + + if server["port"] != 8080 { + t.Errorf("server.port = %v, want 8080 (override wins)", server["port"]) + } + + if server["tls"] != true { + t.Errorf("server.tls = %v, want true", server["tls"]) + } + + if result["debug"] != false { + t.Errorf("debug = %v, want false", result["debug"]) + } + + if result["version"] != "1.0" { + t.Errorf("version = %v, want 1.0", result["version"]) + } +} + +func TestDeepMergeJSON_OverrideReplacesNonMap(t *testing.T) { + t.Parallel() + + base := map[string]any{"x": "string-value"} + override := map[string]any{"x": map[string]any{"nested": true}} + + result := DeepMergeJSON(base, override) + + nested, ok := result["x"].(map[string]any) + if !ok { + t.Fatal("override should replace string with map") + } + + if nested["nested"] != true { + t.Errorf("x.nested = %v, want true", nested["nested"]) + } +} + +func TestDeepMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": map[string]any{"x": 1}} + override := map[string]any{"a": map[string]any{"y": 2}} + + _ = DeepMergeJSON(base, override) + + baseA := base["a"].(map[string]any) + if _, ok := baseA["y"]; ok { + t.Error("DeepMergeJSON mutated base nested map") + } +} + +// --- ValidateConfig --- + +func TestValidateConfig_EmptyData(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", nil) + if err == nil { + t.Fatal("expected error for empty data") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonMissing { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonMissing) + } +} + +func TestValidateConfig_InvalidJSON(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", []byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestValidateConfig_ValidatorFails(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + failValidator := func(_ any) error { return errors.New("validation failed") } + + err := ValidateConfig("test.path", data, failValidator) + if err == nil { + t.Fatal("expected error from failing validator") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonValidationFailed { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonValidationFailed) + } +} + +func TestValidateConfig_AllValidatorsPass(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1, "b": 2}) + passValidator := func(_ any) error { return nil } + + err := ValidateConfig("test.path", data, passValidator, passValidator) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateConfig_NoValidators(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + + err := ValidateConfig("test.path", data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// --- RequiredKeys --- + +func TestRequiredKeys_AllPresent(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost", "port": 3000, "extra": true} + + err := validator(value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRequiredKeys_MissingKey(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost"} + + err := validator(value) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRequiredKeys_NotAMap(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("key") + + err := validator("not a map") + if err == nil { + t.Fatal("expected error for non-map value") + } +} + +// --- ConfigError --- + +func TestConfigError_Error(t *testing.T) { + t.Parallel() + + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonMissing, + Err: errors.New("not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestConfigError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonInvalidFormat, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestConfigReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ConfigReason + want string + }{ + {ConfigReasonMissing, "missing"}, + {ConfigReasonInvalidFormat, "invalid_format"}, + {ConfigReasonValidationFailed, "validation_failed"}, + {ConfigReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ConfigReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go new file mode 100644 index 00000000000..d1ea0486429 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -0,0 +1,298 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" +) + +// KeyVaultResolver resolves Azure Key Vault secret references for extension +// scenarios. It uses the extension's [TokenProvider] for authentication and +// the Azure SDK data-plane client for secret retrieval. +// +// Secret references use the akvs:// URI scheme: +// +// akvs://// +// +// Usage: +// +// tp, _ := azdext.NewTokenProvider(ctx, client, nil) +// resolver, _ := azdext.NewKeyVaultResolver(tp, nil) +// value, err := resolver.Resolve(ctx, "akvs://sub-id/my-vault/my-secret") +type KeyVaultResolver struct { + credential azcore.TokenCredential + clientFactory secretClientFactory + opts KeyVaultResolverOptions +} + +// secretClientFactory abstracts secret client creation for testability. +type secretClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) + +// secretGetter abstracts the Azure SDK secret client's GetSecret method. +type secretGetter interface { + GetSecret(ctx context.Context, name string, version string, options *azsecrets.GetSecretOptions) (azsecrets.GetSecretResponse, error) +} + +// KeyVaultResolverOptions configures a [KeyVaultResolver]. +type KeyVaultResolverOptions struct { + // VaultSuffix overrides the default Key Vault DNS suffix. + // Defaults to "vault.azure.net" (Azure public cloud). + VaultSuffix string + + // ClientFactory overrides the default secret client constructor. + // Useful for testing. When nil, the production [azsecrets.NewClient] is used. + ClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) +} + +// NewKeyVaultResolver creates a [KeyVaultResolver] with the given credential. +// +// credential must not be nil; it is typically a [*TokenProvider] from P1-1. +// If opts is nil, production defaults are used. +func NewKeyVaultResolver(credential azcore.TokenCredential, opts *KeyVaultResolverOptions) (*KeyVaultResolver, error) { + if credential == nil { + return nil, errors.New("azdext.NewKeyVaultResolver: credential must not be nil") + } + + if opts == nil { + opts = &KeyVaultResolverOptions{} + } + + if opts.VaultSuffix == "" { + opts.VaultSuffix = "vault.azure.net" + } + + factory := defaultSecretClientFactory + if opts.ClientFactory != nil { + factory = opts.ClientFactory + } + + return &KeyVaultResolver{ + credential: credential, + clientFactory: factory, + opts: *opts, + }, nil +} + +// defaultSecretClientFactory creates a real Azure SDK secrets client. +func defaultSecretClientFactory(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) { + client, err := azsecrets.NewClient(vaultURL, credential, nil) + if err != nil { + return nil, err + } + + return client, nil +} + +// Resolve fetches the secret value for an akvs:// reference. +// +// The reference must match the format: akvs://// +// +// Returns a [*KeyVaultResolveError] for all domain errors (invalid reference, +// secret not found, authentication failure). No silent fallbacks or hidden retries. +func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, error) { + if ctx == nil { + return "", errors.New("azdext.KeyVaultResolver.Resolve: context must not be nil") + } + + parsed, err := ParseSecretReference(ref) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonInvalidReference, + Err: err, + } + } + + vaultURL := fmt.Sprintf("https://%s.%s", parsed.VaultName, r.opts.VaultSuffix) + + client, err := r.clientFactory(vaultURL, r.credential) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonClientCreation, + Err: fmt.Errorf("failed to create Key Vault client for %s: %w", vaultURL, err), + } + } + + resp, err := client.GetSecret(ctx, parsed.SecretName, "", nil) + if err != nil { + reason := ResolveReasonAccessDenied + + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + switch respErr.StatusCode { + case http.StatusNotFound: + reason = ResolveReasonNotFound + case http.StatusForbidden, http.StatusUnauthorized: + reason = ResolveReasonAccessDenied + default: + reason = ResolveReasonServiceError + } + } + + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: reason, + Err: fmt.Errorf("failed to retrieve secret %q from vault %q: %w", parsed.SecretName, parsed.VaultName, err), + } + } + + if resp.Value == nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonNotFound, + Err: fmt.Errorf("secret %q in vault %q has a nil value", parsed.SecretName, parsed.VaultName), + } + } + + return *resp.Value, nil +} + +// ResolveMap resolves a map of key → akvs:// references, returning a map of +// key → resolved secret values. Processing stops at the first error. +// +// Non-akvs:// values are passed through unchanged, so callers can safely +// resolve a mixed map of plain values and secret references. +func (r *KeyVaultResolver) ResolveMap(ctx context.Context, refs map[string]string) (map[string]string, error) { + if ctx == nil { + return nil, errors.New("azdext.KeyVaultResolver.ResolveMap: context must not be nil") + } + + result := make(map[string]string, len(refs)) + + for key, value := range refs { + if !IsSecretReference(value) { + result[key] = value + continue + } + + resolved, err := r.Resolve(ctx, value) + if err != nil { + return nil, fmt.Errorf("azdext.KeyVaultResolver.ResolveMap: key %q: %w", key, err) + } + + result[key] = resolved + } + + return result, nil +} + +// SecretReference represents a parsed akvs:// URI. +type SecretReference struct { + // SubscriptionID is the Azure subscription containing the Key Vault. + SubscriptionID string + + // VaultName is the Key Vault name (not the full URL). + VaultName string + + // SecretName is the name of the secret within the vault. + SecretName string +} + +const secretScheme = "akvs://" + +// IsSecretReference reports whether s uses the akvs:// scheme. +func IsSecretReference(s string) bool { + return strings.HasPrefix(s, secretScheme) +} + +// ParseSecretReference parses an akvs:// URI into its components. +// +// Expected format: akvs://// +func ParseSecretReference(ref string) (*SecretReference, error) { + if !IsSecretReference(ref) { + return nil, fmt.Errorf("not an akvs:// reference: %s", ref) + } + + body := strings.TrimPrefix(ref, secretScheme) + parts := strings.Split(body, "/") + + if len(parts) != 3 { + return nil, fmt.Errorf( + "invalid akvs:// reference %q: expected format %s//", + ref, secretScheme, + ) + } + + for i, part := range parts { + if strings.TrimSpace(part) == "" { + labels := []string{"subscription-id", "vault-name", "secret-name"} + return nil, fmt.Errorf("invalid akvs:// reference %q: %s must not be empty", ref, labels[i]) + } + } + + return &SecretReference{ + SubscriptionID: parts[0], + VaultName: parts[1], + SecretName: parts[2], + }, nil +} + +// ResolveReason classifies the cause of a [KeyVaultResolveError]. +type ResolveReason int + +const ( + // ResolveReasonInvalidReference indicates the akvs:// URI is malformed. + ResolveReasonInvalidReference ResolveReason = iota + + // ResolveReasonClientCreation indicates failure to create the Key Vault client. + ResolveReasonClientCreation + + // ResolveReasonNotFound indicates the secret does not exist. + ResolveReasonNotFound + + // ResolveReasonAccessDenied indicates an authentication or authorization failure. + ResolveReasonAccessDenied + + // ResolveReasonServiceError indicates an unexpected Key Vault service error. + ResolveReasonServiceError +) + +// String returns a human-readable label for the reason. +func (r ResolveReason) String() string { + switch r { + case ResolveReasonInvalidReference: + return "invalid_reference" + case ResolveReasonClientCreation: + return "client_creation" + case ResolveReasonNotFound: + return "not_found" + case ResolveReasonAccessDenied: + return "access_denied" + case ResolveReasonServiceError: + return "service_error" + default: + return "unknown" + } +} + +// KeyVaultResolveError is returned when [KeyVaultResolver.Resolve] fails. +type KeyVaultResolveError struct { + // Reference is the original akvs:// URI that was being resolved. + Reference string + + // Reason classifies the failure. + Reason ResolveReason + + // Err is the underlying error. + Err error +} + +func (e *KeyVaultResolveError) Error() string { + return fmt.Sprintf( + "azdext.KeyVaultResolver: %s (ref=%s): %v", + e.Reason, e.Reference, e.Err, + ) +} + +func (e *KeyVaultResolveError) Unwrap() error { + return e.Err +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go new file mode 100644 index 00000000000..c666e64aa7d --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -0,0 +1,575 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" +) + +// stubSecretGetter is a test double for the Key Vault data-plane client. +type stubSecretGetter struct { + resp azsecrets.GetSecretResponse + err error +} + +func (s *stubSecretGetter) GetSecret( + _ context.Context, _ string, _ string, _ *azsecrets.GetSecretOptions, +) (azsecrets.GetSecretResponse, error) { + return s.resp, s.err +} + +// stubSecretFactory returns a factory that always returns the given stubSecretGetter. +func stubSecretFactory(g secretGetter, factoryErr error) func(string, azcore.TokenCredential) (secretGetter, error) { + return func(_ string, _ azcore.TokenCredential) (secretGetter, error) { + if factoryErr != nil { + return nil, factoryErr + } + return g, nil + } +} + +// --- NewKeyVaultResolver --- + +func TestNewKeyVaultResolver_NilCredential(t *testing.T) { + t.Parallel() + + _, err := NewKeyVaultResolver(nil, nil) + if err == nil { + t.Fatal("expected error for nil credential") + } +} + +func TestNewKeyVaultResolver_Defaults(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.net" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.net") + } +} + +func TestNewKeyVaultResolver_CustomSuffix(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + VaultSuffix: "vault.azure.cn", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.cn" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.cn") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"akvs://", true}, + {"AKVS://sub/vault/secret", false}, // case-sensitive + {"https://vault.azure.net", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- ParseSecretReference --- + +func TestParseSecretReference_Valid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference("akvs://sub-123/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.SubscriptionID != "sub-123" { + t.Errorf("SubscriptionID = %q, want %q", ref.SubscriptionID, "sub-123") + } + if ref.VaultName != "my-vault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "my-vault") + } + if ref.SecretName != "my-secret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "my-secret") + } +} + +func TestParseSecretReference_NotAkvsScheme(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("https://vault.azure.net/secrets/x") + if err == nil { + t.Fatal("expected error for non-akvs scheme") + } +} + +func TestParseSecretReference_TooFewParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault") + if err == nil { + t.Fatal("expected error for two-part ref") + } +} + +func TestParseSecretReference_TooManyParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault/secret/extra") + if err == nil { + t.Fatal("expected error for four-part ref") + } +} + +func TestParseSecretReference_EmptyComponent(t *testing.T) { + t.Parallel() + + cases := []string{ + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace secret + } + + for _, ref := range cases { + _, err := ParseSecretReference(ref) + if err == nil { + t.Errorf("ParseSecretReference(%q) expected error, got nil", ref) + } + } +} + +// --- Resolve --- + +func TestResolve_Success(t *testing.T) { + t.Parallel() + + secretValue := "super-secret-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, err := resolver.Resolve(context.Background(), "akvs://sub-id/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } +} + +func TestResolve_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.Resolve(nil, "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResolve_InvalidReference(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(context.Background(), "not-akvs://x") + if err == nil { + t.Fatal("expected error for invalid reference") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +func TestResolve_ClientCreationFailure(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(nil, errors.New("connection refused")), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for client creation failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonClientCreation { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonClientCreation) + } +} + +func TestResolve_SecretNotFound(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/missing-secret") + if err == nil { + t.Fatal("expected error for missing secret") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_AccessDenied(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusForbidden}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for forbidden access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_Unauthorized(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusUnauthorized}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for unauthorized access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_ServiceError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for server error") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +func TestResolve_NilValue(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: nil, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil secret value") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_NonResponseError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: errors.New("network timeout"), + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for network failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + // Non-ResponseError defaults to access_denied + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +// --- ResolveMap --- + +func TestResolveMap_MixedValues(t *testing.T) { + t.Parallel() + + secretValue := "resolved-secret" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "plain": "hello-world", + "secret": "akvs://sub/vault/secret", + } + + result, err := resolver.ResolveMap(context.Background(), input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result["plain"] != "hello-world" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "hello-world") + } + + if result["secret"] != secretValue { + t.Errorf("result[secret] = %q, want %q", result["secret"], secretValue) + } +} + +func TestResolveMap_Empty(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + result, err := resolver.ResolveMap(context.Background(), map[string]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestResolveMap_ErrorStopsProcessing(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "secret": "akvs://sub/vault/missing", + } + + _, err := resolver.ResolveMap(context.Background(), input) + if err == nil { + t.Fatal("expected error when resolution fails") + } +} + +func TestResolveMap_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.ResolveMap(nil, map[string]string{"k": "v"}) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +// --- Error types --- + +func TestKeyVaultResolveError_Error(t *testing.T) { + t.Parallel() + + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonNotFound, + Err: errors.New("secret not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestKeyVaultResolveError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonServiceError, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestResolveReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ResolveReason + want string + }{ + {ResolveReasonInvalidReference, "invalid_reference"}, + {ResolveReasonClientCreation, "client_creation"}, + {ResolveReasonNotFound, "not_found"}, + {ResolveReasonAccessDenied, "access_denied"}, + {ResolveReasonServiceError, "service_error"}, + {ResolveReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ResolveReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} From 7506603858efa4f82c2737a92fd3867cb3a28c02 Mon Sep 17 00:00:00 2001 From: Jon Gallant <2163001+jongio@users.noreply.github.com> Date: Sun, 1 Mar 2026 20:56:38 -0800 Subject: [PATCH 2/6] fix: resolve PR2 lint issues Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cli/azd/pkg/azdext/keyvault_resolver.go | 14 ++++++++++++-- cli/azd/pkg/azdext/keyvault_resolver_test.go | 12 ++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go index d1ea0486429..975618c76f2 100644 --- a/cli/azd/pkg/azdext/keyvault_resolver.go +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -38,7 +38,12 @@ type secretClientFactory func(vaultURL string, credential azcore.TokenCredential // secretGetter abstracts the Azure SDK secret client's GetSecret method. type secretGetter interface { - GetSecret(ctx context.Context, name string, version string, options *azsecrets.GetSecretOptions) (azsecrets.GetSecretResponse, error) + GetSecret( + ctx context.Context, + name string, + version string, + options *azsecrets.GetSecretOptions, + ) (azsecrets.GetSecretResponse, error) } // KeyVaultResolverOptions configures a [KeyVaultResolver]. @@ -141,7 +146,12 @@ func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, err return "", &KeyVaultResolveError{ Reference: ref, Reason: reason, - Err: fmt.Errorf("failed to retrieve secret %q from vault %q: %w", parsed.SecretName, parsed.VaultName, err), + Err: fmt.Errorf( + "failed to retrieve secret %q from vault %q: %w", + parsed.SecretName, + parsed.VaultName, + err, + ), } } diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go index c666e64aa7d..628f301d5d3 100644 --- a/cli/azd/pkg/azdext/keyvault_resolver_test.go +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -151,12 +151,12 @@ func TestParseSecretReference_EmptyComponent(t *testing.T) { t.Parallel() cases := []string{ - "akvs:///vault/secret", // empty subscription - "akvs://sub//secret", // empty vault - "akvs://sub/vault/", // empty secret - "akvs:// /vault/secret", // whitespace subscription - "akvs://sub/ /secret", // whitespace vault - "akvs://sub/vault/ ", // whitespace secret + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace secret } for _, ref := range cases { From d6e39f85a6afedc829770e3495aba97b4646579e Mon Sep 17 00:00:00 2001 From: Jon Gallant <2163001+jongio@users.noreply.github.com> Date: Mon, 2 Mar 2026 08:12:10 -0800 Subject: [PATCH 3/6] security: harden keyvault/config helpers against hack scan findings - config_helper: sanitize config key inputs, add bounds validation - config_helper_test: test coverage for sanitization paths - keyvault_resolver: tighten secret name validation - Propagate core fixes: mcp_security, pagination, resilient_http_client --- cli/azd/pkg/azdext/config_helper.go | 25 ++++- cli/azd/pkg/azdext/config_helper_test.go | 10 ++ cli/azd/pkg/azdext/keyvault_resolver.go | 52 ++++++---- cli/azd/pkg/azdext/mcp_security.go | 109 ++++++++++++++++++-- cli/azd/pkg/azdext/pagination.go | 66 +++++++++++- cli/azd/pkg/azdext/resilient_http_client.go | 5 +- 6 files changed, 233 insertions(+), 34 deletions(-) diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go index 9df1f8c8f5e..cc28ab67a01 100644 --- a/cli/azd/pkg/azdext/config_helper.go +++ b/cli/azd/pkg/azdext/config_helper.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" ) // ConfigHelper provides typed, ergonomic access to azd configuration through @@ -256,10 +257,23 @@ func MergeJSON(base, override map[string]any) map[string]any { return merged } +// deepMergeMaxDepth is the maximum recursion depth for [DeepMergeJSON]. +// This prevents stack overflow from deeply nested or adversarial JSON +// structures. 32 levels is far deeper than any legitimate config hierarchy. +const deepMergeMaxDepth = 32 + // DeepMergeJSON performs a recursive merge of override into base. // When both base and override have a map value for the same key, those maps // are merged recursively. Otherwise the override value replaces the base value. +// +// Recursion is bounded to [deepMergeMaxDepth] levels to prevent stack overflow +// from deeply nested or adversarial inputs. Beyond the limit, the override +// value replaces the base value (merge degrades to shallow at that level). func DeepMergeJSON(base, override map[string]any) map[string]any { + return deepMergeJSON(base, override, 0) +} + +func deepMergeJSON(base, override map[string]any, depth int) map[string]any { merged := make(map[string]any, len(base)+len(override)) for k, v := range base { @@ -276,8 +290,8 @@ func DeepMergeJSON(base, override map[string]any) map[string]any { baseMap, baseIsMap := baseVal.(map[string]any) overMap, overIsMap := v.(map[string]any) - if baseIsMap && overIsMap { - merged[k] = DeepMergeJSON(baseMap, overMap) + if baseIsMap && overIsMap && depth < deepMergeMaxDepth { + merged[k] = deepMergeJSON(baseMap, overMap, depth+1) } else { merged[k] = v } @@ -395,11 +409,18 @@ func (e *ConfigError) Unwrap() error { return e.Err } +var configPathRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) + // validatePath checks that a config path is non-empty. func validatePath(path string) error { if path == "" { return errors.New("azdext.ConfigHelper: config path must not be empty") } + if !configPathRe.MatchString(path) { + return errors.New( + "azdext.ConfigHelper: config path must start with alphanumeric and contain only [a-zA-Z0-9._-]", + ) + } return nil } diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go index c68464f16bc..ba5c80e6321 100644 --- a/cli/azd/pkg/azdext/config_helper_test.go +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -965,3 +965,13 @@ func TestConfigReason_String(t *testing.T) { } } } + +func TestValidatePath_InvalidCharacters(t *testing.T) { + t.Parallel() + + for _, path := range []string{"extensions/myext", "extensions myext", ".badstart"} { + if err := validatePath(path); err == nil { + t.Errorf("validatePath(%q) expected error, got nil", path) + } + } +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go index 975618c76f2..0f78416656c 100644 --- a/cli/azd/pkg/azdext/keyvault_resolver.go +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -8,10 +8,12 @@ import ( "errors" "fmt" "net/http" + "regexp" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" ) // KeyVaultResolver resolves Azure Key Vault secret references for extension @@ -207,42 +209,52 @@ type SecretReference struct { SecretName string } -const secretScheme = "akvs://" - // IsSecretReference reports whether s uses the akvs:// scheme. func IsSecretReference(s string) bool { - return strings.HasPrefix(s, secretScheme) + return keyvault.IsAzureKeyVaultSecret(s) } +// vaultNameRe validates Azure Key Vault names per Azure naming rules: +// - 3–24 characters +// - starts with a letter +// - contains only alphanumeric and hyphens +// - does not end with a hyphen +var vaultNameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$`) + // ParseSecretReference parses an akvs:// URI into its components. // // Expected format: akvs://// +// +// The vault name is validated against Azure Key Vault naming rules (3–24 +// characters, starts with letter, alphanumeric and hyphens only, does not +// end with a hyphen). func ParseSecretReference(ref string) (*SecretReference, error) { - if !IsSecretReference(ref) { - return nil, fmt.Errorf("not an akvs:// reference: %s", ref) + parsed, err := keyvault.ParseAzureKeyVaultSecret(ref) + if err != nil { + return nil, err } - body := strings.TrimPrefix(ref, secretScheme) - parts := strings.Split(body, "/") - - if len(parts) != 3 { + if strings.TrimSpace(parsed.SubscriptionId) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: subscription-id must not be empty", ref) + } + if strings.TrimSpace(parsed.VaultName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: vault-name must not be empty", ref) + } + if !vaultNameRe.MatchString(parsed.VaultName) { return nil, fmt.Errorf( - "invalid akvs:// reference %q: expected format %s//", - ref, secretScheme, + "invalid akvs:// reference %q: vault name %q must be 3-24 characters, "+ + "start with a letter, and contain only alphanumeric characters and non-consecutive hyphens", + ref, parsed.VaultName, ) } - - for i, part := range parts { - if strings.TrimSpace(part) == "" { - labels := []string{"subscription-id", "vault-name", "secret-name"} - return nil, fmt.Errorf("invalid akvs:// reference %q: %s must not be empty", ref, labels[i]) - } + if strings.TrimSpace(parsed.SecretName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: secret-name must not be empty", ref) } return &SecretReference{ - SubscriptionID: parts[0], - VaultName: parts[1], - SecretName: parts[2], + SubscriptionID: parsed.SubscriptionId, + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, }, nil } diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index c9c2122eaeb..be7e8bac6f0 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -6,6 +6,7 @@ package azdext import ( "fmt" "net" + "net/http" "net/url" "os" "path/filepath" @@ -25,6 +26,10 @@ type MCPSecurityPolicy struct { blockedHosts map[string]bool // lookupHost is used for DNS resolution; override in tests. lookupHost func(string) ([]string, error) + // onBlocked is an optional callback invoked when a URL or path is blocked. + // Parameters: action ("url_blocked", "path_blocked", "redirect_blocked"), + // detail (human-readable explanation). Safe for concurrent use. + onBlocked func(action, detail string) } // NewMCPSecurityPolicy creates an empty security policy. @@ -111,6 +116,28 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec return p } +// OnBlocked registers a callback that is invoked whenever a URL, path, or +// redirect is blocked by the security policy. This enables security audit +// logging without coupling the policy to a specific logging framework. +// +// The callback receives an action tag ("url_blocked", "path_blocked", +// "redirect_blocked") and a human-readable detail string. It must be safe +// for concurrent invocation. +func (p *MCPSecurityPolicy) OnBlocked(fn func(action, detail string)) *MCPSecurityPolicy { + p.mu.Lock() + defer p.mu.Unlock() + p.onBlocked = fn + return p +} + +// notifyBlocked invokes the onBlocked callback if set. Must be called with +// p.mu held (at least RLock). +func (p *MCPSecurityPolicy) notifyBlocked(action, detail string) { + if p.onBlocked != nil { + p.onBlocked(action, detail) + } +} + // isLocalhostHost returns true if the host is localhost or a loopback address. func isLocalhostHost(host string) bool { h := strings.ToLower(host) @@ -140,20 +167,27 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { // always allowed case "http": if p.requireHTTPS && !isLocalhostHost(host) { - return fmt.Errorf("HTTPS required: %s", rawURL) + err := fmt.Errorf("HTTPS required: %s", rawURL) + p.notifyBlocked("url_blocked", err.Error()) + return err } default: - return fmt.Errorf("scheme not allowed: %q (only http and https are permitted)", u.Scheme) + err := fmt.Errorf("scheme not allowed: %q (only http and https are permitted)", u.Scheme) + p.notifyBlocked("url_blocked", err.Error()) + return err } // Check if the host is directly blocked. if p.blockedHosts[strings.ToLower(host)] { - return fmt.Errorf("blocked host: %s", host) + err := fmt.Errorf("blocked host: %s", host) + p.notifyBlocked("url_blocked", err.Error()) + return err } // If the host is an IP literal, check it directly against blocked CIDRs. if ip := net.ParseIP(host); ip != nil { if err := p.checkIP(ip, host); err != nil { + p.notifyBlocked("url_blocked", err.Error()) return err } } else { @@ -162,14 +196,19 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { if err != nil { // Fail-closed: if DNS resolution fails, block the request. // This prevents SSRF bypasses via DNS rebinding or transient failures. - return fmt.Errorf("DNS resolution failed for host %s: %w", host, err) + blockErr := fmt.Errorf("DNS resolution failed for host %s: %w", host, err) + p.notifyBlocked("url_blocked", blockErr.Error()) + return blockErr } for _, addr := range addrs { if p.blockedHosts[strings.ToLower(addr)] { - return fmt.Errorf("blocked host: %s (resolved from %s)", addr, host) + blockErr := fmt.Errorf("blocked host: %s (resolved from %s)", addr, host) + p.notifyBlocked("url_blocked", blockErr.Error()) + return blockErr } if ip := net.ParseIP(addr); ip != nil { if err := p.checkIP(ip, host); err != nil { + p.notifyBlocked("url_blocked", err.Error()) return err } } @@ -251,6 +290,14 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // CheckPath validates a file path against the security policy. // Resolves symlinks and checks for directory traversal. +// +// Security note (TOCTOU): There is an inherent time-of-check to time-of-use +// gap between the symlink resolution performed here and the caller's +// subsequent file operation. An adversary with write access to the filesystem +// could create or modify a symlink between the check and the use. This is a +// fundamental limitation of path-based validation on POSIX systems. Callers +// performing security-sensitive file operations should prefer O_NOFOLLOW or +// use file-descriptor-based approaches where possible. func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() defer p.mu.RUnlock() @@ -261,7 +308,9 @@ func (p *MCPSecurityPolicy) CheckPath(path string) error { // Reject paths containing ".." before any cleaning to catch obvious traversal attempts. if strings.Contains(path, "..") { - return fmt.Errorf("path traversal detected: %s", path) + err := fmt.Errorf("path traversal detected: %s", path) + p.notifyBlocked("path_blocked", err.Error()) + return err } cleaned := filepath.Clean(path) @@ -300,7 +349,9 @@ func (p *MCPSecurityPolicy) CheckPath(path string) error { } } - return fmt.Errorf("path %s is not within any allowed base directory", path) + err = fmt.Errorf("path %s is not within any allowed base directory", path) + p.notifyBlocked("path_blocked", err.Error()) + return err } // IsHeaderBlocked checks if a header name is in the redacted set. @@ -348,3 +399,47 @@ func resolveExistingPrefix(p string) string { } } } + +// --------------------------------------------------------------------------- +// Redirect SSRF protection +// --------------------------------------------------------------------------- + +// redirectBlockedHosts lists cloud metadata service endpoints that must never +// be the target of an HTTP redirect. +var redirectBlockedHosts = map[string]bool{ + "169.254.169.254": true, + "fd00:ec2::254": true, + "metadata.google.internal": true, + "100.100.100.200": true, +} + +// SSRFSafeRedirect is an [http.Client] CheckRedirect function that blocks +// redirects to private networks and cloud metadata endpoints. It prevents +// redirect-based SSRF attacks where an attacker-controlled URL redirects to +// an internal service. +// +// Usage: +// +// client := &http.Client{CheckRedirect: azdext.SSRFSafeRedirect} +func SSRFSafeRedirect(req *http.Request, via []*http.Request) error { + const maxRedirects = 10 + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + + host := req.URL.Hostname() + + // Block redirects to known metadata endpoints. + if redirectBlockedHosts[strings.ToLower(host)] { + return fmt.Errorf("redirect to metadata endpoint %s blocked (SSRF protection)", host) + } + + // Block redirects to private/loopback IP addresses. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { + return fmt.Errorf("redirect to private/loopback IP %s blocked (SSRF protection)", ip) + } + } + + return nil +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index 6f95645686d..b6fbdb128bb 100644 --- a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -14,10 +14,19 @@ import ( "strings" ) +const ( + // defaultMaxPages is the default upper bound on pages fetched by [Pager.Collect]. + // Individual callers can override this via [PagerOptions.MaxPages]. + // A value of 0 means unlimited (no cap), which is the default for manual + // NextPage iteration. Collect uses this default when MaxPages is unset. + defaultMaxPages = 500 +) + const ( // maxPageResponseSize limits the maximum size of a single page response // body to prevent excessive memory consumption from malicious or - // misconfigured servers. + // misconfigured servers. 10 MB is intentionally above typical Azure list + // payloads while still bounding memory use. maxPageResponseSize int64 = 10 << 20 // 10 MB // maxErrorBodySize limits the size of error response bodies captured @@ -44,6 +53,7 @@ type Pager[T any] struct { done bool opts PagerOptions originHost string // host of the initial URL for SSRF protection + pageCount int // number of pages fetched so far } // PageResponse is a single page returned by [Pager.NextPage]. @@ -59,6 +69,18 @@ type PageResponse[T any] struct { type PagerOptions struct { // Method overrides the HTTP method used for page requests. Defaults to GET. Method string + + // MaxPages limits the maximum number of pages that [Pager.Collect] will + // fetch. When set to a positive value, Collect stops after fetching that + // many pages. A value of 0 means unlimited (no cap) for manual NextPage + // iteration; Collect applies [defaultMaxPages] when this is 0. + MaxPages int + + // MaxItems limits the maximum total items that [Pager.Collect] will + // accumulate. When the collected items reach this count, Collect stops + // and returns the items gathered so far (truncated to MaxItems). + // A value of 0 means unlimited (no cap). + MaxItems int } // HTTPDoer abstracts the HTTP call so that [ResilientClient] or any @@ -170,6 +192,9 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { p.nextURL = page.NextLink } + // Track page count for MaxPages enforcement in Collect. + p.pageCount++ + return &page, nil } @@ -199,13 +224,23 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { } // Collect is a convenience method that fetches all remaining pages and -// returns all items in a single slice. Use with caution on large result sets. +// returns all items in a single slice. +// +// To prevent unbounded memory growth from runaway pagination, Collect +// enforces [PagerOptions.MaxPages] (defaults to [defaultMaxPages] when +// unset) and [PagerOptions.MaxItems]. When either limit is reached, +// iteration stops and the items collected so far are returned. // // If NextPage returns both page data and an error (e.g. rejected nextLink), // the page data is included in the returned slice before returning the error. func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { var all []T + maxPages := p.opts.MaxPages + if maxPages <= 0 { + maxPages = defaultMaxPages + } + for p.More() { page, err := p.NextPage(ctx) if page != nil { @@ -214,6 +249,19 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { if err != nil { return all, err } + + // Enforce MaxItems: truncate and stop if exceeded. + if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + if len(all) > p.opts.MaxItems { + all = all[:p.opts.MaxItems] + } + break + } + + // Enforce MaxPages: stop after collecting the configured number of pages. + if p.pageCount >= maxPages { + break + } } return all, nil @@ -229,6 +277,18 @@ type PaginationError struct { func (e *PaginationError) Error() string { return fmt.Sprintf( "azdext.Pager: page request returned HTTP %d (url=%s)", - e.StatusCode, e.URL, + e.StatusCode, redactURL(e.URL), ) } + +// redactURL strips query parameters and fragments from a URL to avoid leaking +// tokens, SAS signatures, or other secrets in log/error messages. +func redactURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + u.RawQuery = "" + u.Fragment = "" + return u.String() +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index 2916b885334..a39cef33f1c 100644 --- a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -107,8 +107,9 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli return &ResilientClient{ httpClient: &http.Client{ - Transport: transport, - Timeout: opts.Timeout, + Transport: transport, + Timeout: opts.Timeout, + CheckRedirect: SSRFSafeRedirect, }, tokenProvider: tokenProvider, scopeDetector: sd, From 958d4f543b3ae7857916a51ef7133a9dae43a89d Mon Sep 17 00:00:00 2001 From: Jon Gallant <2163001+jongio@users.noreply.github.com> Date: Mon, 2 Mar 2026 08:47:46 -0800 Subject: [PATCH 4/6] fix(azdext): correct default error reason for non-HTTP keyvault failures Non-HTTP errors (network timeout, DNS failure, context canceled) from GetSecret were incorrectly classified as ResolveReasonAccessDenied. Changed default to ResolveReasonServiceError so callers get accurate error classification and don't mistake transport errors for auth issues. Updated TestResolve_NonResponseError to verify the corrected behavior. --- cli/azd/pkg/azdext/keyvault_resolver.go | 4 +++- cli/azd/pkg/azdext/keyvault_resolver_test.go | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go index 0f78416656c..70bf4814ed5 100644 --- a/cli/azd/pkg/azdext/keyvault_resolver.go +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -131,7 +131,9 @@ func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, err resp, err := client.GetSecret(ctx, parsed.SecretName, "", nil) if err != nil { - reason := ResolveReasonAccessDenied + // Default to ServiceError for non-HTTP errors (network timeout, DNS + // failure, etc.). AccessDenied is only used for 401/403 status codes. + reason := ResolveReasonServiceError var respErr *azcore.ResponseError if errors.As(err, &respErr) { diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go index 628f301d5d3..ade5d0cf3f0 100644 --- a/cli/azd/pkg/azdext/keyvault_resolver_test.go +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -421,9 +421,10 @@ func TestResolve_NonResponseError(t *testing.T) { t.Fatalf("error type = %T, want *KeyVaultResolveError", err) } - // Non-ResponseError defaults to access_denied - if resolveErr.Reason != ResolveReasonAccessDenied { - t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + // Non-ResponseError defaults to service_error (not access_denied), + // since network/DNS/timeout errors are not auth issues. + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) } } From 587f3e9bee9dd160ee19c8eeab1af20ce6840a95 Mon Sep 17 00:00:00 2001 From: Jon Gallant <2163001+jongio@users.noreply.github.com> Date: Mon, 2 Mar 2026 13:00:26 -0800 Subject: [PATCH 5/6] fix: address profile review findings for stacked PR Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cli/azd/pkg/azdext/mcp_security.go | 58 +++++----- cli/azd/pkg/azdext/pagination.go | 9 ++ cli/azd/pkg/azdext/pagination_test.go | 104 ++++++++++++++++++ cli/azd/pkg/azdext/resilient_http_client.go | 11 +- .../pkg/azdext/resilient_http_client_test.go | 95 +++++++++++++++- 5 files changed, 240 insertions(+), 37 deletions(-) diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index be7e8bac6f0..5976bc654a9 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -130,14 +130,6 @@ func (p *MCPSecurityPolicy) OnBlocked(fn func(action, detail string)) *MCPSecuri return p } -// notifyBlocked invokes the onBlocked callback if set. Must be called with -// p.mu held (at least RLock). -func (p *MCPSecurityPolicy) notifyBlocked(action, detail string) { - if p.onBlocked != nil { - p.onBlocked(action, detail) - } -} - // isLocalhostHost returns true if the host is localhost or a loopback address. func isLocalhostHost(host string) bool { h := strings.ToLower(host) @@ -152,8 +144,16 @@ func isLocalhostHost(host string) bool { // Returns an error describing the violation, or nil if allowed. func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { p.mu.RLock() - defer p.mu.RUnlock() + onBlocked := p.onBlocked + err := p.checkURLCore(rawURL) + p.mu.RUnlock() + if err != nil && onBlocked != nil { + onBlocked("url_blocked", err.Error()) + } + return err +} +func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) @@ -167,27 +167,20 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { // always allowed case "http": if p.requireHTTPS && !isLocalhostHost(host) { - err := fmt.Errorf("HTTPS required: %s", rawURL) - p.notifyBlocked("url_blocked", err.Error()) - return err + return fmt.Errorf("HTTPS required: %s", rawURL) } default: - err := fmt.Errorf("scheme not allowed: %q (only http and https are permitted)", u.Scheme) - p.notifyBlocked("url_blocked", err.Error()) - return err + return fmt.Errorf("scheme not allowed: %q (only http and https are permitted)", u.Scheme) } // Check if the host is directly blocked. if p.blockedHosts[strings.ToLower(host)] { - err := fmt.Errorf("blocked host: %s", host) - p.notifyBlocked("url_blocked", err.Error()) - return err + return fmt.Errorf("blocked host: %s", host) } // If the host is an IP literal, check it directly against blocked CIDRs. if ip := net.ParseIP(host); ip != nil { if err := p.checkIP(ip, host); err != nil { - p.notifyBlocked("url_blocked", err.Error()) return err } } else { @@ -196,19 +189,14 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { if err != nil { // Fail-closed: if DNS resolution fails, block the request. // This prevents SSRF bypasses via DNS rebinding or transient failures. - blockErr := fmt.Errorf("DNS resolution failed for host %s: %w", host, err) - p.notifyBlocked("url_blocked", blockErr.Error()) - return blockErr + return fmt.Errorf("DNS resolution failed for host %s: %w", host, err) } for _, addr := range addrs { if p.blockedHosts[strings.ToLower(addr)] { - blockErr := fmt.Errorf("blocked host: %s (resolved from %s)", addr, host) - p.notifyBlocked("url_blocked", blockErr.Error()) - return blockErr + return fmt.Errorf("blocked host: %s (resolved from %s)", addr, host) } if ip := net.ParseIP(addr); ip != nil { if err := p.checkIP(ip, host); err != nil { - p.notifyBlocked("url_blocked", err.Error()) return err } } @@ -300,17 +288,23 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // use file-descriptor-based approaches where possible. func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() - defer p.mu.RUnlock() + onBlocked := p.onBlocked + err := p.checkPathCore(path) + p.mu.RUnlock() + if err != nil && onBlocked != nil { + onBlocked("path_blocked", err.Error()) + } + return err +} +func (p *MCPSecurityPolicy) checkPathCore(path string) error { if len(p.allowedBasePaths) == 0 { return nil } // Reject paths containing ".." before any cleaning to catch obvious traversal attempts. if strings.Contains(path, "..") { - err := fmt.Errorf("path traversal detected: %s", path) - p.notifyBlocked("path_blocked", err.Error()) - return err + return fmt.Errorf("path traversal detected: %s", path) } cleaned := filepath.Clean(path) @@ -349,9 +343,7 @@ func (p *MCPSecurityPolicy) CheckPath(path string) error { } } - err = fmt.Errorf("path %s is not within any allowed base directory", path) - p.notifyBlocked("path_blocked", err.Error()) - return err + return fmt.Errorf("path %s is not within any allowed base directory", path) } // IsHeaderBlocked checks if a header name is in the redacted set. diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index b6fbdb128bb..ea4de409f1d 100644 --- a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -54,6 +54,7 @@ type Pager[T any] struct { opts PagerOptions originHost string // host of the initial URL for SSRF protection pageCount int // number of pages fetched so far + truncated bool } // PageResponse is a single page returned by [Pager.NextPage]. @@ -139,6 +140,11 @@ func (p *Pager[T]) More() bool { return !p.done && p.nextURL != "" } +// Truncated reports whether [Collect] stopped due to MaxPages or MaxItems limits. +func (p *Pager[T]) Truncated() bool { + return p.truncated +} + // NextPage fetches the next page of results. Returns an error if the request // fails, the response is not 2xx, or the body cannot be decoded. // @@ -235,6 +241,7 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { // the page data is included in the returned slice before returning the error. func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { var all []T + p.truncated = false maxPages := p.opts.MaxPages if maxPages <= 0 { @@ -255,11 +262,13 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { if len(all) > p.opts.MaxItems { all = all[:p.opts.MaxItems] } + p.truncated = true break } // Enforce MaxPages: stop after collecting the configured number of pages. if p.pageCount >= maxPages { + p.truncated = true break } } diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go index 1f8c3537279..58e80ef9f4c 100644 --- a/cli/azd/pkg/azdext/pagination_test.go +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -481,3 +482,106 @@ func TestPager_CollectWithSSRFError(t *testing.T) { t.Errorf("all = %v, want [a b] (partial results before SSRF error)", all) } } + +func TestPager_TruncatedByMaxPages(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 1; i <= 5; i++ { + nextLink := "" + if i < 5 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) + } + body := pageJSON([]int{i}, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 3 { + t.Errorf("len(all) = %d, want 3", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxPages)") + } +} + +func TestPager_TruncatedByMaxItems(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 0; i < 3; i++ { + items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} + nextLink := "" + if i < 2 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) + } + body := pageJSON(items, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Errorf("len(all) = %d, want 5", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxItems)") + } +} + +func TestPager_NotTruncatedOnNaturalEnd(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 2 { + t.Errorf("len(all) = %d, want 2", len(all)) + } + + if pager.Truncated() { + t.Error("Truncated() = true, want false (natural end)") + } +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index a39cef33f1c..75d2d329596 100644 --- a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "math" + "math/rand/v2" "net/http" "strconv" "time" @@ -127,6 +128,13 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R if ctx == nil { return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") } + if body != nil && rc.opts.MaxRetries > 0 { + if _, ok := body.(io.ReadSeeker); !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + } var lastErr error var retryAfterOverride time.Duration @@ -241,7 +249,8 @@ func (rc *ResilientClient) backoff(attempt int) time.Duration { delay = rc.opts.MaxDelay } - return delay + jitter := 0.5 + rand.Float64()*0.5 + return time.Duration(float64(delay) * jitter) } // isRetryable returns true for status codes that indicate a transient failure. diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go index 4ffd67af103..3e5cddc9896 100644 --- a/cli/azd/pkg/azdext/resilient_http_client_test.go +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -530,9 +530,9 @@ func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should have made exactly 1 attempt (first gets 503 → retry → fail on body check). - if attempts != 1 { - t.Errorf("attempts = %d, want 1 (fail before second attempt)", attempts) + // Should fail fast before any request. + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) } } @@ -620,3 +620,92 @@ func TestResilientClient_RetryAfterCapped(t *testing.T) { t.Errorf("retryAfterFromResponse() = %v, want %v (capping happens in Do)", got, 999999*time.Second) } } + +func TestResilientClient_BackoffJitter(t *testing.T) { + t.Parallel() + + rc := NewResilientClient(nil, &ResilientClientOptions{ + InitialDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + }) + + seen := make(map[time.Duration]bool) + for range 20 { + d := rc.backoff(1) + seen[d] = true + if d < 50*time.Millisecond || d >= 100*time.Millisecond { + t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + } + } + if len(seen) < 2 { + t.Error("backoff jitter produced identical values across 20 calls") + } +} + +func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) + } + + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + } +} + +func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + h := http.Header{} + h.Set("retry-after", "999999") + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 1, + InitialDelay: time.Millisecond, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected context.DeadlineExceeded (proving cap was applied), got: %v", err) + } + + if attempts != 1 { + t.Errorf("attempts = %d, want 1", attempts) + } +} From 507eeb7468d160dba17569a64d3100133662d6fe Mon Sep 17 00:00:00 2001 From: Jon Gallant <2163001+jongio@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:27:43 -0800 Subject: [PATCH 6/6] fix(azdext): remediate hack findings Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cli/azd/pkg/azdext/config_helper.go | 39 ++++++++++++++++++++++++++--- cli/azd/pkg/azdext/mcp_security.go | 13 +++++++--- cli/azd/pkg/azdext/pagination.go | 38 ++++++++++++++++++++++++++-- 3 files changed, 81 insertions(+), 9 deletions(-) diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go index cc28ab67a01..843729dc5d4 100644 --- a/cli/azd/pkg/azdext/config_helper.go +++ b/cli/azd/pkg/azdext/config_helper.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "regexp" + "strings" ) // ConfigHelper provides typed, ergonomic access to azd configuration through @@ -409,18 +410,48 @@ func (e *ConfigError) Unwrap() error { return e.Err } -var configPathRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) +// configSegmentRe matches a single segment: starts with alphanumeric, +// contains only [a-zA-Z0-9_-], and is 1–63 characters. +var configSegmentRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`) -// validatePath checks that a config path is non-empty. +// validatePath checks that a config path is non-empty, does not contain +// consecutive dots or empty segments, and that each dot-separated segment +// conforms to a restricted character set. +// +// Valid: "extensions.myext.port", "my-key", "a.b.c" +// Invalid: "", "..foo", "a..b", ".leading", "trailing.", "a.b.c.d.e.f" (okay, but each segment validated) func validatePath(path string) error { if path == "" { return errors.New("azdext.ConfigHelper: config path must not be empty") } - if !configPathRe.MatchString(path) { + + // Reject leading/trailing dots and consecutive dots (empty segments). + if strings.HasPrefix(path, ".") || strings.HasSuffix(path, ".") || strings.Contains(path, "..") { return errors.New( - "azdext.ConfigHelper: config path must start with alphanumeric and contain only [a-zA-Z0-9._-]", + "azdext.ConfigHelper: config path must not have empty segments " + + "(no leading/trailing dots or consecutive dots)", ) } + // Validate each dot-separated segment individually. + segments := strings.Split(path, ".") + for _, seg := range segments { + if !configSegmentRe.MatchString(seg) { + return fmt.Errorf( + "azdext.ConfigHelper: config path segment %q must start with alphanumeric "+ + "and contain only [a-zA-Z0-9_-], max 63 chars", + truncateConfigValue(seg, 64), + ) + } + } + return nil } + +// truncateConfigValue truncates s for safe inclusion in error messages. +func truncateConfigValue(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index 5976bc654a9..27c0d9a6860 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -283,9 +283,16 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // gap between the symlink resolution performed here and the caller's // subsequent file operation. An adversary with write access to the filesystem // could create or modify a symlink between the check and the use. This is a -// fundamental limitation of path-based validation on POSIX systems. Callers -// performing security-sensitive file operations should prefer O_NOFOLLOW or -// use file-descriptor-based approaches where possible. +// fundamental limitation of path-based validation on POSIX systems. +// +// Mitigations callers should consider: +// - Use O_NOFOLLOW when opening files after validation (prevents symlink +// following at the final component). +// - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on +// Linux 5.6+) where possible. +// - Avoid writing to directories that untrusted users can modify. +// - Consider validating the opened fd's path post-open via /proc/self/fd/N +// or fstat. func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() onBlocked := p.onBlocked diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index ea4de409f1d..021a7f1d5cd 100644 --- a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -173,7 +173,7 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { return nil, &PaginationError{ StatusCode: resp.StatusCode, URL: p.nextURL, - Body: string(body), + Body: sanitizeErrorBody(string(body)), } } @@ -276,11 +276,19 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { return all, nil } +// PaginationError is returned when a page request receives a non-2xx response. +// maxPaginationErrorBodyLen limits the response body length stored in +// PaginationError to prevent sensitive data leakage through error messages. +const maxPaginationErrorBodyLen = 1024 + // PaginationError is returned when a page request receives a non-2xx response. type PaginationError struct { StatusCode int URL string - Body string + // Body is a truncated, sanitized excerpt of the error response body for + // diagnostics. Capped at [maxPaginationErrorBodyLen] bytes and stripped + // of control characters to prevent log forging. + Body string } func (e *PaginationError) Error() string { @@ -290,6 +298,32 @@ func (e *PaginationError) Error() string { ) } +// sanitizeErrorBody truncates and strips control characters from an error +// response body to prevent log forging and sensitive data leakage. +func sanitizeErrorBody(body string) string { + if len(body) > maxPaginationErrorBodyLen { + body = body[:maxPaginationErrorBodyLen] + "...[truncated]" + } + return stripControlChars(body) +} + +// stripControlChars replaces ASCII control characters (except tab) with a +// space to prevent log forging via CR/LF injection or terminal escape sequences. +func stripControlChars(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r < 0x20 && r != '\t' { + b.WriteRune(' ') + } else if r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} + // redactURL strips query parameters and fragments from a URL to avoid leaking // tokens, SAS signatures, or other secrets in log/error messages. func redactURL(rawURL string) string {