diff --git a/cmd/non_interactive_test.go b/cmd/non_interactive_test.go index 11709969..2ac1e524 100644 --- a/cmd/non_interactive_test.go +++ b/cmd/non_interactive_test.go @@ -13,6 +13,7 @@ func TestNonInteractiveFlagIsRegistered(t *testing.T) { flag := root.PersistentFlags().Lookup("non-interactive") if flag == nil { t.Fatal("expected --non-interactive flag to be registered on root command") + return } if flag.DefValue != "false" { t.Fatalf("expected default value to be false, got %q", flag.DefValue) diff --git a/cmd/root.go b/cmd/root.go index e1ed3731..5a104f0d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -80,11 +80,17 @@ func runStart(ctx context.Context, rt runtime.Runtime, cfg *env.Env, tel *teleme // TODO: replace map with a typed payload struct once event schema is finalised tel.Emit(ctx, "cli_cmd", map[string]any{"cmd": "lstk start", "params": []string{}}) - platformClient := api.NewPlatformClient(cfg.APIEndpoint) + opts := container.StartOptions{ + PlatformClient: api.NewPlatformClient(cfg.APIEndpoint), + AuthToken: cfg.AuthToken, + ForceFileKeyring: cfg.ForceFileKeyring, + WebAppURL: cfg.WebAppURL, + LocalStackHost: cfg.LocalStackHost, + } if isInteractiveMode(cfg) { - return ui.Run(ctx, rt, version.Version(), platformClient, cfg.AuthToken, cfg.ForceFileKeyring, cfg.WebAppURL) + return ui.Run(ctx, rt, version.Version(), opts) } - return container.Start(ctx, rt, output.NewPlainSink(os.Stdout), platformClient, cfg.AuthToken, cfg.ForceFileKeyring, cfg.WebAppURL, false) + return container.Start(ctx, rt, output.NewPlainSink(os.Stdout), opts, false) } func isInteractiveMode(cfg *env.Env) bool { diff --git a/go.mod b/go.mod index 8270dd2e..35253aa6 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/stretchr/testify v1.11.1 go.uber.org/mock v0.6.0 golang.org/x/term v0.40.0 + gopkg.in/ini.v1 v1.67.1 gotest.tools/v3 v3.5.2 ) diff --git a/go.sum b/go.sum index f8df7a69..78ce4a3a 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,7 @@ github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3 github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ= github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -151,8 +152,14 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= @@ -210,6 +217,9 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k= +gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/internal/awsconfig/awsconfig.go b/internal/awsconfig/awsconfig.go new file mode 100644 index 00000000..bc47cbe9 --- /dev/null +++ b/internal/awsconfig/awsconfig.go @@ -0,0 +1,256 @@ +package awsconfig + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "os" + "path/filepath" + "strings" + + "gopkg.in/ini.v1" + + "github.com/localstack/lstk/internal/endpoint" + "github.com/localstack/lstk/internal/output" +) + +const ( + profileName = "localstack" + configSectionName = "profile localstack" // ~/.aws/config uses "profile " as section header + credsSectionName = "localstack" // ~/.aws/credentials uses just the profile name + // TODO: make region configurable (e.g. from container env or lstk config) + defaultRegion = "us-east-1" +) + +func credentialsDefaults() map[string]string { + return map[string]string{ + "aws_access_key_id": "test", + "aws_secret_access_key": "test", + } +} + +// isValidLocalStackEndpoint returns true if endpoint_url in ~/.aws/config points to +// the same LocalStack instance as resolvedHost. localhost, 127.0.0.1, and +// localhost.localstack.cloud are treated as interchangeable since all three +// resolve to the local machine. +func isValidLocalStackEndpoint(endpointURL, resolvedHost string) bool { + u, err := url.Parse(endpointURL) + if err != nil { + return false + } + if u.Scheme != "http" && u.Scheme != "https" { + return false + } + if u.Host == resolvedHost { + return true + } + // If the resolved host is one of the two known local hostnames, accept the + // other as equally valid — they both reach the same local service. + resolvedHostname, resolvedPort, err := net.SplitHostPort(resolvedHost) + if err != nil || !isLocalStackLocalHost(resolvedHostname) { + return false + } + return u.Port() == resolvedPort && isLocalStackLocalHost(u.Hostname()) +} + +func isLocalStackLocalHost(host string) bool { + return host == "127.0.0.1" || host == "localhost" || host == endpoint.Hostname +} + +func awsPaths() (configPath, credentialsPath string, err error) { + home, err := os.UserHomeDir() + if err != nil { + return "", "", err + } + return filepath.Join(home, ".aws", "config"), filepath.Join(home, ".aws", "credentials"), nil +} + +// profileStatus holds which AWS profile files need to be written or updated. +type profileStatus struct { + configNeeded bool + credsNeeded bool +} + +func (s profileStatus) anyNeeded() bool { + return s.configNeeded || s.credsNeeded +} + +func (s profileStatus) filesToModify() []string { + var files []string + if s.configNeeded { + files = append(files, "~/.aws/config") + } + if s.credsNeeded { + files = append(files, "~/.aws/credentials") + } + return files +} + +// checkProfileStatus determines which AWS profile files need to be written or updated. +func checkProfileStatus(configPath, credsPath, resolvedHost string) (profileStatus, error) { + configNeeded, err := configNeedsWrite(configPath, resolvedHost) + if err != nil { + return profileStatus{}, err + } + credsNeeded, err := credsNeedWrite(credsPath) + if err != nil { + return profileStatus{}, err + } + return profileStatus{configNeeded: configNeeded, credsNeeded: credsNeeded}, nil +} + +func configNeedsWrite(path, resolvedHost string) (bool, error) { + f, err := ini.Load(path) + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + if err != nil { + return false, err + } + section, err := f.GetSection(configSectionName) + if err != nil { + return true, nil // section doesn't exist + } + endpointKey, err := section.GetKey("endpoint_url") + if err != nil || !isValidLocalStackEndpoint(endpointKey.Value(), resolvedHost) { + return true, nil + } + if !section.HasKey("region") { + return true, nil + } + return false, nil +} + +func credsNeedWrite(path string) (bool, error) { + f, err := ini.Load(path) + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + if err != nil { + return false, err + } + section, err := f.GetSection(credsSectionName) + if err != nil { + return true, nil // section doesn't exist + } + for k, expected := range credentialsDefaults() { + key, err := section.GetKey(k) + if err != nil || key.Value() != expected { + return true, nil + } + } + return false, nil +} + +// profileExists reports whether the localstack profile section is present in both +// ~/.aws/config and ~/.aws/credentials. +func profileExists() (bool, error) { + configPath, credsPath, err := awsPaths() + if err != nil { + return false, err + } + configOK, err := sectionExists(configPath, configSectionName) + if err != nil { + return false, err + } + credsOK, err := sectionExists(credsPath, credsSectionName) + if err != nil { + return false, err + } + return configOK && credsOK, nil +} + +// writeProfile writes the localstack profile to ~/.aws/config and ~/.aws/credentials, +// creating or updating sections as needed. +func writeProfile(host string) error { + configPath, credsPath, err := awsPaths() + if err != nil { + return err + } + configKeys := map[string]string{ + "region": defaultRegion, + "output": "json", + "endpoint_url": "http://" + host, + } + if err := upsertSection(configPath, configSectionName, configKeys); err != nil { + return fmt.Errorf("failed to write %s: %w", configPath, err) + } + if err := upsertSection(credsPath, credsSectionName, credentialsDefaults()); err != nil { + return fmt.Errorf("failed to write %s: %w", credsPath, err) + } + return nil +} + +func writeConfigProfile(configPath, host string) error { + keys := map[string]string{ + "region": defaultRegion, + "output": "json", + "endpoint_url": "http://" + host, + } + return upsertSection(configPath, configSectionName, keys) +} + +func writeCredsProfile(credsPath string) error { + return upsertSection(credsPath, credsSectionName, credentialsDefaults()) +} + +// Setup checks for the localstack AWS profile and prompts to create or update it if needed. +// resolvedHost must be a host:port string (e.g. "localhost.localstack.cloud:4566"). +// In non-interactive mode, emits a note instead of prompting. +func Setup(ctx context.Context, sink output.Sink, interactive bool, resolvedHost string) error { + configPath, credsPath, err := awsPaths() + if err != nil { + output.EmitWarning(sink, fmt.Sprintf("could not determine AWS config paths: %v", err)) + return nil + } + + status, err := checkProfileStatus(configPath, credsPath, resolvedHost) + if err != nil { + output.EmitWarning(sink, fmt.Sprintf("could not check AWS profile: %v", err)) + return nil + } + if !status.anyNeeded() { + return nil + } + + if !interactive { + output.EmitNote(sink, fmt.Sprintf("No complete LocalStack AWS profile found. Run lstk interactively to configure one, or add a [profile %s] section to ~/.aws/config manually.", profileName)) + return nil + } + + files := strings.Join(status.filesToModify(), " and ") + responseCh := make(chan output.InputResponse, 1) + output.EmitUserInputRequest(sink, output.UserInputRequestEvent{ + Prompt: fmt.Sprintf("Set up LocalStack AWS profile in %s?", files), + Options: []output.InputOption{{Key: "y", Label: "Y"}, {Key: "n", Label: "n"}}, + ResponseCh: responseCh, + }) + + select { + case resp := <-responseCh: + if resp.Cancelled || resp.SelectedKey == "n" { + return nil + } + if status.configNeeded { + if err := writeConfigProfile(configPath, resolvedHost); err != nil { + output.EmitWarning(sink, fmt.Sprintf("could not update ~/.aws/config: %v", err)) + return nil + } + } + if status.credsNeeded { + if err := writeCredsProfile(credsPath); err != nil { + output.EmitWarning(sink, fmt.Sprintf("could not update ~/.aws/credentials: %v", err)) + return nil + } + } + output.EmitSuccess(sink, fmt.Sprintf("LocalStack AWS profile written to %s", files)) + output.EmitNote(sink, fmt.Sprintf("Try: aws s3 mb s3://test --profile %s", profileName)) + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} + diff --git a/internal/awsconfig/awsconfig_test.go b/internal/awsconfig/awsconfig_test.go new file mode 100644 index 00000000..52630829 --- /dev/null +++ b/internal/awsconfig/awsconfig_test.go @@ -0,0 +1,379 @@ +package awsconfig + +import ( + "os" + "path/filepath" + "testing" +) + +func writeFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0600); err != nil { + t.Fatal(err) + } +} + +func TestProfileExists(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, dir string) + want bool + }{ + { + name: "both present", + setup: func(t *testing.T, dir string) { + writeFile(t, filepath.Join(dir, ".aws", "config"), "[profile localstack]\nregion = us-east-1\n") + writeFile(t, filepath.Join(dir, ".aws", "credentials"), "[localstack]\naws_access_key_id = test\n") + }, + want: true, + }, + { + name: "config missing", + setup: func(t *testing.T, dir string) { + writeFile(t, filepath.Join(dir, ".aws", "credentials"), "[localstack]\naws_access_key_id = test\n") + }, + want: false, + }, + { + name: "credentials missing", + setup: func(t *testing.T, dir string) { + writeFile(t, filepath.Join(dir, ".aws", "config"), "[profile localstack]\nregion = us-east-1\n") + }, + want: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + tc.setup(t, dir) + ok, err := profileExists() + if err != nil { + t.Fatal(err) + } + if ok != tc.want { + t.Errorf("got %v, want %v", ok, tc.want) + } + }) + } +} + +func TestWriteProfile(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, dir string) + check func(t *testing.T, dir string) + }{ + { + name: "creates files when absent", + setup: func(t *testing.T, dir string) {}, + check: func(t *testing.T, dir string) { + ok, err := profileExists() + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected profile to exist after writeProfile") + } + }, + }, + { + name: "preserves existing profiles", + setup: func(t *testing.T, dir string) { + writeFile(t, filepath.Join(dir, ".aws", "config"), "[profile default]\nregion = eu-west-1\n") + }, + check: func(t *testing.T, dir string) { + ok, err := sectionExists(filepath.Join(dir, ".aws", "config"), "profile default") + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("existing profile was lost after writeProfile") + } + }, + }, + { + name: "updates stale localstack section", + setup: func(t *testing.T, dir string) { + writeFile(t, filepath.Join(dir, ".aws", "config"), "[profile localstack]\nregion = eu-west-1\nendpoint_url = http://wrong.host:1234\n") + writeFile(t, filepath.Join(dir, ".aws", "credentials"), "[localstack]\naws_access_key_id = old\naws_secret_access_key = old\n") + }, + check: func(t *testing.T, dir string) { + configNeeded, err := configNeedsWrite(filepath.Join(dir, ".aws", "config"), "localhost.localstack.cloud:4566") + if err != nil { + t.Fatal(err) + } + if configNeeded { + t.Error("config should not need a write after writeProfile") + } + credsNeeded, err := credsNeedWrite(filepath.Join(dir, ".aws", "credentials")) + if err != nil { + t.Fatal(err) + } + if credsNeeded { + t.Error("credentials should not need a write after writeProfile") + } + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + tc.setup(t, dir) + if err := writeProfile("localhost.localstack.cloud:4566"); err != nil { + t.Fatal(err) + } + tc.check(t, dir) + }) + } +} + +func TestCheckProfileStatus(t *testing.T) { + tests := []struct { + name string + configContent string + credsContent string + resolvedHost string + wantConfig bool + wantCreds bool + }{ + { + name: "both files missing", + resolvedHost: "localhost.localstack.cloud:4566", + wantConfig: true, + wantCreds: true, + }, + { + name: "valid profile needs nothing", + configContent: "[profile localstack]\nregion = us-east-1\noutput = json\nendpoint_url = http://localhost.localstack.cloud:4566\n", + credsContent: "[localstack]\naws_access_key_id = test\naws_secret_access_key = test\n", + resolvedHost: "localhost.localstack.cloud:4566", + wantConfig: false, + wantCreds: false, + }, + { + name: "missing endpoint_url", + configContent: "[profile localstack]\nregion = us-east-1\n", + credsContent: "[localstack]\naws_access_key_id = test\naws_secret_access_key = test\n", + resolvedHost: "localhost.localstack.cloud:4566", + wantConfig: true, + wantCreds: false, + }, + { + name: "invalid endpoint_url", + configContent: "[profile localstack]\nregion = us-east-1\nendpoint_url = http://some-other-host:4566\n", + credsContent: "[localstack]\naws_access_key_id = test\naws_secret_access_key = test\n", + resolvedHost: "localhost.localstack.cloud:4566", + wantConfig: true, + wantCreds: false, + }, + { + name: "wrong credentials", + configContent: "[profile localstack]\nregion = us-east-1\noutput = json\nendpoint_url = http://127.0.0.1:4566\n", + credsContent: "[localstack]\naws_access_key_id = wrong\naws_secret_access_key = wrong\n", + resolvedHost: "127.0.0.1:4566", + wantConfig: false, + wantCreds: true, + }, + { + name: "127.0.0.1 profile valid when DNS now resolves to localhost.localstack.cloud", + configContent: "[profile localstack]\nregion = us-east-1\noutput = json\nendpoint_url = http://127.0.0.1:4566\n", + credsContent: "[localstack]\naws_access_key_id = test\naws_secret_access_key = test\n", + resolvedHost: "localhost.localstack.cloud:4566", + wantConfig: false, + wantCreds: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, ".aws", "config") + credsPath := filepath.Join(dir, ".aws", "credentials") + if tc.configContent != "" { + writeFile(t, configPath, tc.configContent) + } + if tc.credsContent != "" { + writeFile(t, credsPath, tc.credsContent) + } + status, err := checkProfileStatus(configPath, credsPath, tc.resolvedHost) + if err != nil { + t.Fatal(err) + } + if status.configNeeded != tc.wantConfig { + t.Errorf("configNeeded: got %v, want %v", status.configNeeded, tc.wantConfig) + } + if status.credsNeeded != tc.wantCreds { + t.Errorf("credsNeeded: got %v, want %v", status.credsNeeded, tc.wantCreds) + } + }) + } +} + +func TestCheckProfileStatusMalformedFile(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, ".aws", "config") + credsPath := filepath.Join(dir, ".aws", "credentials") + + writeFile(t, configPath, "this is not valid \x00\x01\x02 ini content [[[") + writeFile(t, credsPath, "[localstack]\naws_access_key_id = test\naws_secret_access_key = test\n") + + _, err := checkProfileStatus(configPath, credsPath, "127.0.0.1:4566") + if err == nil { + t.Error("expected error for malformed config file, got nil") + } +} + +func TestIsValidLocalStackEndpoint(t *testing.T) { + tests := []struct { + name string + endpointURL string + resolvedHost string + want bool + }{ + { + name: "valid http", + endpointURL: "http://localhost.localstack.cloud:4566", + resolvedHost: "localhost.localstack.cloud:4566", + want: true, + }, + { + name: "valid https", + endpointURL: "https://localhost.localstack.cloud:4566", + resolvedHost: "localhost.localstack.cloud:4566", + want: true, + }, + { + name: "valid fallback ip", + endpointURL: "http://127.0.0.1:4566", + resolvedHost: "127.0.0.1:4566", + want: true, + }, + { + name: "wrong host", + endpointURL: "http://some-other-host:4566", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "wrong port", + endpointURL: "http://localhost.localstack.cloud:9999", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "missing port", + endpointURL: "http://localhost.localstack.cloud", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "trailing slash", + endpointURL: "http://localhost.localstack.cloud:4566/", + resolvedHost: "localhost.localstack.cloud:4566", + want: true, // trailing slash is functionally equivalent; host still matches + }, + { + name: "unsupported scheme", + endpointURL: "ftp://localhost.localstack.cloud:4566", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "unparseable url", + endpointURL: "://bad-url", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "empty string", + endpointURL: "", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "127.0.0.1 accepted when resolved host is localhost.localstack.cloud", + endpointURL: "http://127.0.0.1:4566", + resolvedHost: "localhost.localstack.cloud:4566", + want: true, + }, + { + name: "localhost.localstack.cloud accepted when resolved host is 127.0.0.1", + endpointURL: "http://localhost.localstack.cloud:4566", + resolvedHost: "127.0.0.1:4566", + want: true, + }, + { + name: "127.0.0.1 with wrong port rejected", + endpointURL: "http://127.0.0.1:9999", + resolvedHost: "localhost.localstack.cloud:4566", + want: false, + }, + { + name: "localhost accepted when resolved host is localhost.localstack.cloud", + endpointURL: "http://localhost:4566", + resolvedHost: "localhost.localstack.cloud:4566", + want: true, + }, + { + name: "localhost accepted when resolved host is 127.0.0.1", + endpointURL: "http://localhost:4566", + resolvedHost: "127.0.0.1:4566", + want: true, + }, + { + name: "custom host not interchangeable with 127.0.0.1", + endpointURL: "http://127.0.0.1:4566", + resolvedHost: "myhost.internal:4566", + want: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isValidLocalStackEndpoint(tc.endpointURL, tc.resolvedHost) + if got != tc.want { + t.Errorf("isValidLocalStackEndpoint(%q, %q) = %v, want %v", tc.endpointURL, tc.resolvedHost, got, tc.want) + } + }) + } +} + +func TestFilesToModify(t *testing.T) { + tests := []struct { + name string + status profileStatus + wantFiles []string + }{ + { + name: "both needed", + status: profileStatus{configNeeded: true, credsNeeded: true}, + wantFiles: []string{"~/.aws/config", "~/.aws/credentials"}, + }, + { + name: "config only", + status: profileStatus{configNeeded: true, credsNeeded: false}, + wantFiles: []string{"~/.aws/config"}, + }, + { + name: "credentials only", + status: profileStatus{configNeeded: false, credsNeeded: true}, + wantFiles: []string{"~/.aws/credentials"}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.status.filesToModify() + if len(got) != len(tc.wantFiles) { + t.Fatalf("got %v, want %v", got, tc.wantFiles) + } + for i, want := range tc.wantFiles { + if got[i] != want { + t.Errorf("files[%d]: got %q, want %q", i, got[i], want) + } + } + }) + } +} diff --git a/internal/awsconfig/ini.go b/internal/awsconfig/ini.go new file mode 100644 index 00000000..b6beaa86 --- /dev/null +++ b/internal/awsconfig/ini.go @@ -0,0 +1,53 @@ +package awsconfig + +import ( + "errors" + "os" + "path/filepath" + "strings" + + "gopkg.in/ini.v1" +) + +func sectionExists(path, sectionName string) (bool, error) { + f, err := ini.Load(path) + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + if err != nil { + return false, err + } + for _, s := range f.Sections() { + if strings.TrimSpace(s.Name()) == sectionName { + return true, nil + } + } + return false, nil +} + +func upsertSection(path, sectionName string, keys map[string]string) error { + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return err + } + + var f *ini.File + if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { + f = ini.Empty() + } else { + var err error + f, err = ini.Load(path) + if err != nil { + return err + } + } + + section := f.Section(sectionName) // gets or creates the section + for k, v := range keys { + section.Key(k).SetValue(v) + } + + if err := f.SaveTo(path); err != nil { + return err + } + return os.Chmod(path, 0600) +} diff --git a/internal/container/start.go b/internal/container/start.go index 4fdd81ae..1de34561 100644 --- a/internal/container/start.go +++ b/internal/container/start.go @@ -6,28 +6,42 @@ import ( "net/http" "os" stdruntime "runtime" + "slices" "time" "github.com/containerd/errdefs" "github.com/localstack/lstk/internal/api" "github.com/localstack/lstk/internal/auth" + "github.com/localstack/lstk/internal/awsconfig" "github.com/localstack/lstk/internal/config" + "github.com/localstack/lstk/internal/endpoint" "github.com/localstack/lstk/internal/output" "github.com/localstack/lstk/internal/ports" "github.com/localstack/lstk/internal/runtime" ) -func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, platformClient api.PlatformAPI, authToken string, forceFileKeyring bool, webAppURL string, interactive bool) error { +type postStartSetupFunc func(ctx context.Context, sink output.Sink, interactive bool, resolvedHost string) error + +// StartOptions groups the user-provided options for starting an emulator. +type StartOptions struct { + PlatformClient api.PlatformAPI + AuthToken string + ForceFileKeyring bool + WebAppURL string + LocalStackHost string +} + +func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts StartOptions, interactive bool) error { if err := rt.IsHealthy(ctx); err != nil { rt.EmitUnhealthyError(sink, err) return output.NewSilentError(fmt.Errorf("runtime not healthy: %w", err)) } - tokenStorage, err := auth.NewTokenStorage(forceFileKeyring) + tokenStorage, err := auth.NewTokenStorage(opts.ForceFileKeyring) if err != nil { return fmt.Errorf("failed to initialize token storage: %w", err) } - a := auth.New(sink, platformClient, tokenStorage, authToken, webAppURL, interactive) + a := auth.New(sink, opts.PlatformClient, tokenStorage, opts.AuthToken, opts.WebAppURL, interactive) token, err := a.GetToken(ctx) if err != nil { @@ -84,11 +98,44 @@ func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, platformCl return err } - if err := validateLicenses(ctx, rt, sink, platformClient, containers, token); err != nil { + if err := validateLicenses(ctx, rt, sink, opts.PlatformClient, containers, token); err != nil { return err } - return startContainers(ctx, rt, sink, containers) + if err := startContainers(ctx, rt, sink, containers); err != nil { + return err + } + + // Maps emulator types to their post-start setup functions. + // Add an entry here to run setup for a new emulator type (e.g. Azure, Snowflake). + setups := map[config.EmulatorType]postStartSetupFunc{ + config.EmulatorAWS: awsconfig.Setup, + } + return runPostStartSetups(ctx, sink, cfg.Containers, interactive, opts.LocalStackHost, setups) +} + +func runPostStartSetups(ctx context.Context, sink output.Sink, containers []config.ContainerConfig, interactive bool, localStackHost string, setups map[config.EmulatorType]postStartSetupFunc) error { + // build ordered list of unique types, keeping the first container config for each + firstByType := map[config.EmulatorType]config.ContainerConfig{} + var uniqueEmulatorTypes []config.EmulatorType + for _, c := range containers { + if !slices.Contains(uniqueEmulatorTypes, c.Type) { + uniqueEmulatorTypes = append(uniqueEmulatorTypes, c.Type) + firstByType[c.Type] = c + } + } + for _, t := range uniqueEmulatorTypes { + if setup, ok := setups[t]; ok { + resolvedHost, dnsOK := endpoint.ResolveHost(firstByType[t].Port, localStackHost) + if !dnsOK { + output.EmitNote(sink, `Could not resolve "localhost.localstack.cloud" — your system may have DNS rebind protection enabled. Using 127.0.0.1 as the endpoint.`) + } + if err := setup(ctx, sink, interactive, resolvedHost); err != nil { + return err + } + } + } + return nil } func pullImages(ctx context.Context, rt runtime.Runtime, sink output.Sink, containers []runtime.ContainerConfig) error { diff --git a/internal/container/start_test.go b/internal/container/start_test.go index 97004f71..f422baa2 100644 --- a/internal/container/start_test.go +++ b/internal/container/start_test.go @@ -21,7 +21,7 @@ func TestStart_ReturnsEarlyIfRuntimeUnhealthy(t *testing.T) { sink := output.NewPlainSink(io.Discard) - err := Start(context.Background(), mockRT, sink, nil, "", false, "", false) + err := Start(context.Background(), mockRT, sink, StartOptions{}, false) require.Error(t, err) assert.Contains(t, err.Error(), "runtime not healthy") diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go new file mode 100644 index 00000000..04d62de9 --- /dev/null +++ b/internal/endpoint/endpoint.go @@ -0,0 +1,27 @@ +package endpoint + +import "net" + +const Hostname = "localhost.localstack.cloud" + +// ResolveHost returns the best host:port for reaching LocalStack on the given port. +// If override is non-empty it is returned as-is. Otherwise a DNS check is performed; +// if Hostname does not resolve to 127.0.0.1 (e.g. DNS rebind protection is active), +// it falls back to 127.0.0.1 directly. +func ResolveHost(port, override string) (host string, dnsOK bool) { + if override != "" { + return override, true + } + // Use a "test." subdomain: *.localhost.localstack.cloud has wildcard DNS that resolves + // to 127.0.0.1, so any subdomain works as a probe without hitting the actual service. + addrs, err := net.LookupHost("test." + Hostname) + if err != nil { + return "127.0.0.1:" + port, false + } + for _, addr := range addrs { + if addr == "127.0.0.1" { + return Hostname + ":" + port, true + } + } + return "127.0.0.1:" + port, false +} diff --git a/internal/env/env.go b/internal/env/env.go index c2b53c2f..c3703b55 100644 --- a/internal/env/env.go +++ b/internal/env/env.go @@ -8,12 +8,15 @@ import ( ) type Env struct { - AuthToken string + AuthToken string + LocalStackHost string + DisableEvents bool + APIEndpoint string WebAppURL string ForceFileKeyring bool AnalyticsEndpoint string - DisableEvents bool + NonInteractive bool } @@ -26,15 +29,16 @@ func Init() *Env { viper.SetDefault("api_endpoint", "https://api.localstack.cloud") viper.SetDefault("web_app_url", "https://app.localstack.cloud") viper.SetDefault("analytics_endpoint", "https://analytics.localstack.cloud/v1/events") - // LOCALSTACK_AUTH_TOKEN and LOCALSTACK_DISABLE_EVENTS are not prefixed with LSTK_ - // so they work seamlessly across all LocalStack tools without per-tool configuration + // LOCALSTACK_* variables are not prefixed with LSTK_ so they work seamlessly + // across all LocalStack tools without per-tool configuration return &Env{ AuthToken: os.Getenv("LOCALSTACK_AUTH_TOKEN"), + LocalStackHost: os.Getenv("LOCALSTACK_HOST"), + DisableEvents: os.Getenv("LOCALSTACK_DISABLE_EVENTS") == "1", APIEndpoint: viper.GetString("api_endpoint"), WebAppURL: viper.GetString("web_app_url"), ForceFileKeyring: viper.GetString("keyring") == "file", AnalyticsEndpoint: viper.GetString("analytics_endpoint"), - DisableEvents: os.Getenv("LOCALSTACK_DISABLE_EVENTS") == "1", } } diff --git a/internal/ui/app_test.go b/internal/ui/app_test.go index 594678e3..c3df373a 100644 --- a/internal/ui/app_test.go +++ b/internal/ui/app_test.go @@ -510,10 +510,10 @@ func TestResolveOption(t *testing.T) { keyYUpper := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'Y'}} tests := []struct { - name string - options []output.InputOption - press tea.KeyMsg - wantOptionKey string // key of the expected returned option; empty means nil + name string + options []output.InputOption + press tea.KeyMsg + wantOptionKey string // key of the expected returned option; empty means nil }{ // "any" option { @@ -614,6 +614,7 @@ func TestResolveOption(t *testing.T) { } else { if got == nil { t.Fatal("expected non-nil option, got nil") + return } if got.Key != tc.wantOptionKey { t.Fatalf("got key %q, want %q", got.Key, tc.wantOptionKey) diff --git a/internal/ui/run.go b/internal/ui/run.go index 9749fd8a..0881f482 100644 --- a/internal/ui/run.go +++ b/internal/ui/run.go @@ -3,13 +3,12 @@ package ui import ( "context" "errors" - "fmt" "os" tea "github.com/charmbracelet/bubbletea" - "github.com/localstack/lstk/internal/api" "github.com/localstack/lstk/internal/config" "github.com/localstack/lstk/internal/container" + "github.com/localstack/lstk/internal/endpoint" "github.com/localstack/lstk/internal/output" "github.com/localstack/lstk/internal/runtime" "golang.org/x/term" @@ -26,28 +25,28 @@ func (s programSender) Send(msg any) { s.p.Send(msg) } -func Run(parentCtx context.Context, rt runtime.Runtime, version string, platformClient api.PlatformAPI, authToken string, forceFileKeyring bool, webAppURL string) error { +func Run(parentCtx context.Context, rt runtime.Runtime, version string, opts container.StartOptions) error { ctx, cancel := context.WithCancel(parentCtx) defer cancel() // FIXME: This assumes a single emulator; revisit for proper multi-emulator support emulatorName := "LocalStack Emulator" - endpoint := "localhost.localstack.cloud" + host := endpoint.Hostname if cfg, err := config.Get(); err == nil && len(cfg.Containers) > 0 { emulatorName = cfg.Containers[0].DisplayName() if cfg.Containers[0].Port != "" { - endpoint = fmt.Sprintf("localhost.localstack.cloud:%s", cfg.Containers[0].Port) + host, _ = endpoint.ResolveHost(cfg.Containers[0].Port, opts.LocalStackHost) } } - app := NewApp(version, emulatorName, endpoint, cancel) + app := NewApp(version, emulatorName, host, cancel) p := tea.NewProgram(app) runErrCh := make(chan error, 1) go func() { var err error defer func() { runErrCh <- err }() - err = container.Start(ctx, rt, output.NewTUISink(programSender{p: p}), platformClient, authToken, forceFileKeyring, webAppURL, true) + err = container.Start(ctx, rt, output.NewTUISink(programSender{p: p}), opts, true) if err != nil { if errors.Is(err, context.Canceled) { return diff --git a/test/integration/awsconfig_test.go b/test/integration/awsconfig_test.go new file mode 100644 index 00000000..e4552478 --- /dev/null +++ b/test/integration/awsconfig_test.go @@ -0,0 +1,157 @@ +package integration_test + +import ( + "bytes" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/creack/pty" + "github.com/localstack/lstk/test/integration/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// awsConfigEnv returns a base environment with HOME set to an isolated temp +// directory, so tests never touch the real ~/.aws files. +func awsConfigEnv(t *testing.T) (env.Environ, string) { + t.Helper() + tmpHome := t.TempDir() + e := env.With(env.AuthToken, env.Get(env.AuthToken)).With(env.Home, tmpHome) + return e, tmpHome +} + +func TestStartPromptsToCreateAWSProfileWhenMissing(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + requireDocker(t) + _ = env.Require(t, env.AuthToken) + + cleanup() + t.Cleanup(cleanup) + + baseEnv, tmpHome := awsConfigEnv(t) + mockServer := createMockLicenseServer(true) + defer mockServer.Close() + + ctx := testContext(t) + cmd := exec.CommandContext(ctx, binaryPath(), "start") + cmd.Env = baseEnv.With(env.APIEndpoint, mockServer.URL) + + ptmx, err := pty.Start(cmd) + require.NoError(t, err, "failed to start command in PTY") + defer func() { _ = ptmx.Close() }() + + out := &syncBuffer{} + outputCh := make(chan struct{}) + go func() { + _, _ = io.Copy(out, ptmx) + close(outputCh) + }() + + // Wait for the AWS profile prompt. + require.Eventually(t, func() bool { + return bytes.Contains(out.Bytes(), []byte("Set up LocalStack AWS profile")) + }, 2*time.Minute, 200*time.Millisecond, "AWS profile prompt should appear") + + // Press Y to confirm. + _, err = ptmx.Write([]byte("y")) + require.NoError(t, err) + + // Wait for the success message. + require.Eventually(t, func() bool { + return bytes.Contains(out.Bytes(), []byte("LocalStack AWS profile written")) + }, 10*time.Second, 200*time.Millisecond, "success message should appear") + + // Verify files were written to the isolated home dir, not the real one. + configContent, err := os.ReadFile(filepath.Join(tmpHome, ".aws", "config")) + require.NoError(t, err, "~/.aws/config should have been created") + assert.Contains(t, string(configContent), "[profile localstack]") + assert.Contains(t, string(configContent), "endpoint_url") + + credsContent, err := os.ReadFile(filepath.Join(tmpHome, ".aws", "credentials")) + require.NoError(t, err, "~/.aws/credentials should have been created") + normalizedCreds := strings.Join(strings.Fields(string(credsContent)), " ") + assert.Contains(t, normalizedCreds, "[localstack]") + assert.Contains(t, normalizedCreds, "aws_access_key_id = test") + assert.Contains(t, normalizedCreds, "aws_secret_access_key = test") + + _ = cmd.Wait() + <-outputCh +} + +func TestStartSkipsAWSProfilePromptWhenAlreadyConfigured(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + requireDocker(t) + _ = env.Require(t, env.AuthToken) + + cleanup() + t.Cleanup(cleanup) + + baseEnv, tmpHome := awsConfigEnv(t) + mockServer := createMockLicenseServer(true) + defer mockServer.Close() + + // Pre-write a valid LocalStack AWS profile in the isolated home. + awsDir := filepath.Join(tmpHome, ".aws") + require.NoError(t, os.MkdirAll(awsDir, 0700)) + require.NoError(t, os.WriteFile(filepath.Join(awsDir, "config"), + []byte("[profile localstack]\nregion = us-east-1\noutput = json\nendpoint_url = http://127.0.0.1:4566\n"), 0600)) + require.NoError(t, os.WriteFile(filepath.Join(awsDir, "credentials"), + []byte("[localstack]\naws_access_key_id = test\naws_secret_access_key = test\n"), 0600)) + + ctx := testContext(t) + cmd := exec.CommandContext(ctx, binaryPath(), "start") + cmd.Env = baseEnv.With(env.APIEndpoint, mockServer.URL) + + ptmx, err := pty.Start(cmd) + require.NoError(t, err, "failed to start command in PTY") + defer func() { _ = ptmx.Close() }() + + out := &syncBuffer{} + outputCh := make(chan struct{}) + go func() { + _, _ = io.Copy(out, ptmx) + close(outputCh) + }() + + // Wait until the container is ready — that's the point at which post-start setup + // runs, so if the prompt were going to appear it would already be in the output. + require.Eventually(t, func() bool { + return bytes.Contains(out.Bytes(), []byte(" ready")) + }, 2*time.Minute, 200*time.Millisecond, "container should become ready") + + _ = cmd.Process.Kill() + _ = cmd.Wait() + <-outputCh + + assert.NotContains(t, out.String(), "Set up LocalStack AWS profile", + "profile prompt should not appear when profile is already correctly configured") +} + +func TestStartNonInteractiveEmitsNoteWhenAWSProfileMissing(t *testing.T) { + requireDocker(t) + _ = env.Require(t, env.AuthToken) + + cleanup() + t.Cleanup(cleanup) + + baseEnv, _ := awsConfigEnv(t) + mockServer := createMockLicenseServer(true) + defer mockServer.Close() + + stdout, _, err := runLstk(t, testContext(t), "", + baseEnv.With(env.APIEndpoint, mockServer.URL), + "start", + ) + require.NoError(t, err) + assert.Contains(t, stdout, "No complete LocalStack AWS profile found") +} diff --git a/test/integration/env/env.go b/test/integration/env/env.go index 2f7cb88a..5146fd7c 100644 --- a/test/integration/env/env.go +++ b/test/integration/env/env.go @@ -15,6 +15,7 @@ const ( CI Key = "CI" AnalyticsEndpoint Key = "LSTK_ANALYTICS_ENDPOINT" DisableEvents Key = "LOCALSTACK_DISABLE_EVENTS" + Home Key = "HOME" ) func Get(key Key) string {