diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index 4ca30dd..aef2b65 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -6,7 +6,10 @@ import ( "encoding/json" "os" "path/filepath" + "runtime" + "syscall" "testing" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -134,7 +137,7 @@ func TestRunQuery_MissingDB(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + err := runQuery([]string{"select 1"}, &QueryFlags{}, &w) if err == nil { t.Fatal("expected error for missing db type") } @@ -149,7 +152,7 @@ func TestRunQuery_UnsupportedDriver(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + err := runQuery([]string{"select 1"}, &QueryFlags{}, &w) if err == nil { t.Fatal("expected error for unsupported driver") } @@ -168,7 +171,7 @@ func TestRunQuery_PlaintextPasswordNotAllowed(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + err := runQuery([]string{"select 1"}, &QueryFlags{}, &w) if err == nil { t.Fatal("expected error for plaintext password not allowed") } @@ -183,7 +186,7 @@ func TestRunSchemaDump_UnsupportedDriver(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + err := runSchemaDump(&SchemaFlags{}, &w) if err == nil { t.Fatal("expected error for unsupported driver") } @@ -202,7 +205,7 @@ func TestRunSchemaDump_PlaintextPasswordNotAllowed(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + err := runSchemaDump(&SchemaFlags{}, &w) if err == nil { t.Fatal("expected error for plaintext password not allowed") } @@ -217,7 +220,7 @@ func TestRunQuery_InvalidFormat(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + err := runQuery([]string{"select 1"}, &QueryFlags{}, &w) if err == nil { t.Fatal("expected error for invalid format") } @@ -232,7 +235,7 @@ func TestRunSchemaDump_MissingDB(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + err := runSchemaDump(&SchemaFlags{}, &w) if err == nil { t.Fatal("expected error for missing db type") } @@ -247,7 +250,7 @@ func TestRunSchemaDump_InvalidFormat(t *testing.T) { var out bytes.Buffer w := output.New(&out, &bytes.Buffer{}) - err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + err := runSchemaDump(&SchemaFlags{}, &w) if err == nil { t.Fatal("expected error for invalid format") } @@ -348,7 +351,7 @@ func TestRunProxy_SSHConnectError(t *testing.T) { } func TestResolveSSH_NoConfig(t *testing.T) { - client, err := app.ResolveSSH(nil, config.Profile{}, false, false) + client, err := app.ResolveSSH(context.TODO(), config.Profile{}, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -909,3 +912,262 @@ func TestFirstNonEmpty(t *testing.T) { func configProfile(dbType string) config.Profile { return config.Profile{DB: dbType} } + +func TestHandlePortConflict_NonTTY(t *testing.T) { + _, err := handlePortConflict(3306, "127.0.0.1") + if err == nil { + t.Fatal("expected error for non-TTY port conflict") + } + if err.Code != errors.CodePortInUse { + t.Fatalf("expected CodePortInUse, got %s", err.Code) + } +} + +func TestModeForWebCommand(t *testing.T) { + if got := modeForWebCommand(true); got != "web" { + t.Fatalf("expected 'web', got %q", got) + } + if got := modeForWebCommand(false); got != "serve" { + t.Fatalf("expected 'serve', got %q", got) + } +} + +func TestResolveWebOptions_InvalidAddr(t *testing.T) { + _, xe := resolveWebOptions(&webCommandOptions{ + addr: "not-a-valid-addr", + addrSet: true, + }, config.File{}) + if xe == nil { + t.Fatal("expected error for invalid addr") + } + if xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %s", xe.Code) + } +} + +func TestResolveWebOptions_EnvVars(t *testing.T) { + t.Setenv("XSQL_WEB_HTTP_AUTH_TOKEN", "env-token") + resolved, xe := resolveWebOptions(&webCommandOptions{ + addr: "0.0.0.0:9999", + addrSet: true, + }, config.File{}) + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + if resolved.authToken != "env-token" { + t.Fatalf("expected env-token, got %s", resolved.authToken) + } + if !resolved.authRequired { + t.Fatal("expected authRequired=true") + } +} + +func TestResolveMCPServerOptions_HttpAddrEnv(t *testing.T) { + t.Setenv("XSQL_MCP_TRANSPORT", "streamable_http") + t.Setenv("XSQL_MCP_HTTP_AUTH_TOKEN", "token") + t.Setenv("XSQL_MCP_HTTP_ADDR", "127.0.0.1:5555") + cfg := config.File{ + Profiles: map[string]config.Profile{}, + SSHProxies: map[string]config.SSHProxy{}, + } + resolved, xe := resolveMCPServerOptions(&mcpServerOptions{}, cfg) + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + if resolved.httpAddr != "127.0.0.1:5555" { + t.Fatalf("expected 127.0.0.1:5555, got %s", resolved.httpAddr) + } +} + +func TestRunMCPServer_InvalidConfigPath(t *testing.T) { + GlobalConfig.ConfigStr = "/nonexistent/path/config.yaml" + err := runMCPServer(&mcpServerOptions{}) + if err == nil { + t.Fatal("expected error for nonexistent config") + } +} + +func TestRunMCPServer_StreamableHTTPStarts(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping signal handling test on Windows") + } + configPath := filepath.Join(t.TempDir(), "xsql.yaml") + if err := os.WriteFile(configPath, []byte("profiles: {}\nmcp:\n transport: streamable_http\n http:\n addr: 127.0.0.1:0\n auth_token: test-token\n allow_plaintext_token: true\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + GlobalConfig.ConfigStr = configPath + + // This will start the HTTP server and then we need to stop it + // We'll use a goroutine to run it and cancel after a short time + done := make(chan error, 1) + go func() { + done <- runMCPServer(&mcpServerOptions{}) + }() + + // Give the server time to start + time.Sleep(100 * time.Millisecond) + + // The server should be running, send SIGINT to stop it + p, _ := os.FindProcess(os.Getpid()) + _ = p.Signal(syscall.SIGINT) + + select { + case err := <-done: + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for server to stop") + } +} + +func TestNewServeCommand(t *testing.T) { + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := NewServeCommand(&w) + if cmd.Use != "serve" { + t.Fatalf("expected 'serve', got %s", cmd.Use) + } +} + +func TestNewWebCommand(t *testing.T) { + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := NewWebCommand(&w) + if cmd.Use != "web" { + t.Fatalf("expected 'web', got %s", cmd.Use) + } +} + +func TestNewMCPCommand(t *testing.T) { + cmd := NewMCPCommand() + if cmd.Use != "mcp" { + t.Fatalf("expected 'mcp', got %s", cmd.Use) + } +} + +func TestNewProfileCommand(t *testing.T) { + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := NewProfileCommand(&w) + if cmd.Use != "profile" { + t.Fatalf("expected 'profile', got %s", cmd.Use) + } +} + +func TestNewSchemaCommand(t *testing.T) { + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := NewSchemaCommand(&w) + if cmd.Use != "schema" { + t.Fatalf("expected 'schema', got %s", cmd.Use) + } +} + +func TestResolveWebOptions_NilOpts(t *testing.T) { + resolved, xe := resolveWebOptions(nil, config.File{}) + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + if resolved.addr != "127.0.0.1:8788" { + t.Fatalf("expected default addr, got %s", resolved.addr) + } +} + +func TestResolveWebOptions_ConfigAddr(t *testing.T) { + resolved, xe := resolveWebOptions(&webCommandOptions{}, config.File{ + Web: config.WebConfig{ + HTTP: config.WebHTTPConfig{ + Addr: "127.0.0.1:9999", + }, + }, + }) + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + if resolved.addr != "127.0.0.1:9999" { + t.Fatalf("expected 127.0.0.1:9999, got %s", resolved.addr) + } +} + +func TestResolveWebOptions_NonLoopbackRequiresToken(t *testing.T) { + _, xe := resolveWebOptions(&webCommandOptions{ + addr: "10.0.0.1:8788", + addrSet: true, + }, config.File{}) + if xe == nil { + t.Fatal("expected error for non-loopback without token") + } + if xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %s", xe.Code) + } +} + +func TestRunProxy_InvalidFormat(t *testing.T) { + prev := GlobalConfig + GlobalConfig = &Config{ProfileStr: "dev", FormatStr: "invalid"} + t.Cleanup(func() { GlobalConfig = prev }) + + GlobalConfig.Resolved.Profile = config.Profile{DB: "mysql", SSHConfig: &config.SSHProxy{Host: "h", Port: 22, User: "u"}} + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runProxy(nil, &ProxyFlags{}, &w) + if err == nil { + t.Fatal("expected error for invalid format") + } +} + +func TestRunProxy_PlaintextNotAllowed(t *testing.T) { + prev := GlobalConfig + GlobalConfig = &Config{ProfileStr: "dev", FormatStr: "json"} + t.Cleanup(func() { GlobalConfig = prev }) + + GlobalConfig.Resolved.Profile = config.Profile{ + DB: "mysql", + Password: "plain", + AllowPlaintext: false, + SSHConfig: &config.SSHProxy{Host: "h", Port: 22, User: "u"}, + } + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runProxy(nil, &ProxyFlags{}, &w) + if err == nil { + t.Fatal("expected error for plaintext not allowed") + } +} + +func TestResolveProxyPort_AllPaths(t *testing.T) { + // Test with config port and no CLI flag + cmd := NewProxyCommand(nil) + port, fromConfig := resolveProxyPort(cmd, &ProxyFlags{}, 5555) + if port != 5555 || !fromConfig { + t.Errorf("expected port=5555, fromConfig=true, got port=%d, fromConfig=%v", port, fromConfig) + } + + // Test with zero config port + port, fromConfig = resolveProxyPort(cmd, &ProxyFlags{}, 0) + if port != 0 || fromConfig { + t.Errorf("expected port=0, fromConfig=false, got port=%d, fromConfig=%v", port, fromConfig) + } +} + +func TestNewQueryCommand(t *testing.T) { + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := NewQueryCommand(&w) + if cmd.Use != "query [SQL]" { + t.Fatalf("expected 'query [SQL]', got %s", cmd.Use) + } +} + +func TestNewProxyCommand(t *testing.T) { + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := NewProxyCommand(&w) + if cmd.Use != "proxy [flags]" { + t.Fatalf("expected 'proxy [flags]', got %s", cmd.Use) + } +} diff --git a/cmd/xsql/query.go b/cmd/xsql/query.go index 2217c2f..b2685cf 100644 --- a/cmd/xsql/query.go +++ b/cmd/xsql/query.go @@ -31,7 +31,7 @@ func NewQueryCommand(w *output.Writer) *cobra.Command { Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { flags.QueryTimeoutSet = cmd.Flags().Changed("query-timeout") - return runQuery(cmd, args, flags, w) + return runQuery(args, flags, w) }, } @@ -44,7 +44,7 @@ func NewQueryCommand(w *output.Writer) *cobra.Command { } // runQuery executes a SQL query -func runQuery(cmd *cobra.Command, args []string, flags *QueryFlags, w *output.Writer) error { +func runQuery(args []string, flags *QueryFlags, w *output.Writer) error { sql := args[0] format, err := parseOutputFormat(GlobalConfig.FormatStr) if err != nil { diff --git a/cmd/xsql/root.go b/cmd/xsql/root.go index 7f2855d..5ba096a 100644 --- a/cmd/xsql/root.go +++ b/cmd/xsql/root.go @@ -50,8 +50,8 @@ func NewRootCommand() *cobra.Command { CLIFormatSet: formatSet, EnvProfile: os.Getenv("XSQL_PROFILE"), EnvFormat: os.Getenv("XSQL_FORMAT"), - WorkDir: "", - HomeDir: "", + WorkDir: os.Getenv("XSQL_WORKDIR"), + HomeDir: os.Getenv("XSQL_HOMEDIR"), }) if xe != nil { return xe diff --git a/cmd/xsql/schema.go b/cmd/xsql/schema.go index dac0946..c9778f2 100644 --- a/cmd/xsql/schema.go +++ b/cmd/xsql/schema.go @@ -44,7 +44,7 @@ func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { Short: "Dump database schema (tables, columns, indexes, foreign keys)", RunE: func(cmd *cobra.Command, args []string) error { flags.SchemaTimeoutSet = cmd.Flags().Changed("schema-timeout") - return runSchemaDump(cmd, args, flags, w) + return runSchemaDump(flags, w) }, } @@ -58,7 +58,7 @@ func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { } // runSchemaDump executes the schema dump command -func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *output.Writer) error { +func runSchemaDump(flags *SchemaFlags, w *output.Writer) error { format, err := parseOutputFormat(GlobalConfig.FormatStr) if err != nil { return err diff --git a/cmd/xsql/web_test.go b/cmd/xsql/web_test.go index 0a358ef..c76bb36 100644 --- a/cmd/xsql/web_test.go +++ b/cmd/xsql/web_test.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "os" + "runtime" + "syscall" "testing" "time" @@ -572,3 +574,91 @@ func TestRunWebCommand_ListenerCreationError(t *testing.T) { t.Error("test timed out - runWebCommand likely blocked waiting for signals") } } + +func TestRunServerWithSignalHandling_StopsOnSignal(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping signal handling test on Windows") + } + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + + handler := webpkg.NewHandler(webpkg.HandlerOptions{}) + server := webpkg.NewServer(listener, handler) + + done := make(chan error, 1) + go func() { + done <- runServerWithSignalHandling(server) + }() + + // Give the server time to start + time.Sleep(50 * time.Millisecond) + + // Send SIGINT to the process to trigger shutdown + p, _ := os.FindProcess(os.Getpid()) + _ = p.Signal(syscall.SIGINT) + + select { + case err := <-done: + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for server to stop") + } +} + +func TestRunWebCommand_StartAndStop(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping signal handling test on Windows") + } + configDir := t.TempDir() + configPath := configDir + "/config.yaml" + configContent := `profiles: + default: + db: mysql + host: localhost + port: 3306 + user: root + password: test +` + if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil { + t.Fatalf("failed to create temp config: %v", err) + } + + oldConfig := GlobalConfig + defer func() { GlobalConfig = oldConfig }() + GlobalConfig.ConfigStr = configPath + GlobalConfig.ProfileStr = "default" + GlobalConfig.FormatStr = "json" + + opts := &webCommandOptions{ + addr: "127.0.0.1:0", + addrSet: true, + } + + var buf bytes.Buffer + w := output.New(&buf, &bytes.Buffer{}) + + done := make(chan error, 1) + go func() { + done <- runWebCommand(opts, &w) + }() + + // Give the server time to start + time.Sleep(100 * time.Millisecond) + + // Send SIGINT + p, _ := os.FindProcess(os.Getpid()) + _ = p.Signal(syscall.SIGINT) + + select { + case err := <-done: + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for web command to stop") + } +} diff --git a/go.mod b/go.mod index ed2ea3a..b644ce3 100644 --- a/go.mod +++ b/go.mod @@ -3,34 +3,33 @@ module github.com/zx06/xsql go 1.25.0 require ( - github.com/go-sql-driver/mysql v1.8.1 - github.com/google/jsonschema-go v0.4.2 + github.com/go-sql-driver/mysql v1.10.0 + github.com/google/jsonschema-go v0.4.3 github.com/jackc/pgx/v5 v5.9.2 - github.com/modelcontextprotocol/go-sdk v1.4.0 - github.com/spf13/cobra v1.8.1 - github.com/zalando/go-keyring v0.2.6 - golang.org/x/crypto v0.47.0 - golang.org/x/sync v0.19.0 - golang.org/x/term v0.39.0 + github.com/modelcontextprotocol/go-sdk v1.6.0 + github.com/spf13/cobra v1.10.2 + github.com/zalando/go-keyring v0.2.8 + golang.org/x/crypto v0.51.0 + golang.org/x/sync v0.20.0 + golang.org/x/term v0.43.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - al.essio.dev/pkg/shellescape v1.5.1 // indirect - filippo.io/edwards25519 v1.1.1 // indirect - github.com/danieljoos/wincred v1.2.2 // indirect - github.com/godbus/dbus/v5 v5.1.0 // indirect + filippo.io/edwards25519 v1.2.0 // indirect + github.com/danieljoos/wincred v1.2.3 // indirect + github.com/godbus/dbus/v5 v5.2.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kr/text v0.2.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect - github.com/segmentio/asm v1.1.3 // indirect - github.com/segmentio/encoding v0.5.3 // indirect - github.com/spf13/pflag v1.0.5 // indirect + github.com/segmentio/asm v1.2.1 // indirect + github.com/segmentio/encoding v0.5.4 // indirect + github.com/spf13/pflag v1.0.10 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect ) diff --git a/go.sum b/go.sum index f32e1e8..ad63137 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,22 @@ -al.essio.dev/pkg/shellescape v1.5.1 h1:86HrALUujYS/h+GtqoB26SBEdkWfmMI6FubjXlsXyho= -al.essio.dev/pkg/shellescape v1.5.1/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= -filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= -filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= -github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= +github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ= +github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= -github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= -github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw= +github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk= +github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= +github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= -github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= -github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -35,21 +31,22 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/modelcontextprotocol/go-sdk v1.4.0 h1:u0kr8lbJc1oBcawK7Df+/ajNMpIDFE41OEPxdeTLOn8= -github.com/modelcontextprotocol/go-sdk v1.4.0/go.mod h1:Nxc2n+n/GdCebUaqCOhTetptS17SXXNu9IfNTaLDi1E= +github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= +github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= -github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= -github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w= -github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= -github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= -github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= +github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= @@ -59,22 +56,23 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= -github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= -golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +github.com/zalando/go-keyring v0.2.8 h1:6sD/Ucpl7jNq10rM2pgqTs0sZ9V3qMrqfIIy5YPccHs= +github.com/zalando/go-keyring v0.2.8/go.mod h1:tsMo+VpRq5NGyKfxoBVjCuMrG47yj8cmakZDO5QGii0= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= +golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/output/writer.go b/internal/output/writer.go index ae8802f..df15fc1 100644 --- a/internal/output/writer.go +++ b/internal/output/writer.go @@ -276,7 +276,7 @@ func tryAsProfileList(data any) ([]ProfileListItem, bool) { for i := 0; i < v.Len(); i++ { elem := v.Index(i) // Dereference pointer - if elem.Kind() == reflect.Ptr { + if elem.Kind() == reflect.Pointer { elem = elem.Elem() } // Only handle structs @@ -350,7 +350,7 @@ func tryAsQueryResultReflect(data any) (*queryResultLike, bool) { } v := reflect.ValueOf(data) - if v.Kind() == reflect.Ptr { + if v.Kind() == reflect.Pointer { v = v.Elem() } if v.Kind() != reflect.Struct { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 8c324b4..a37f406 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -486,3 +486,148 @@ func TestProxy_PortInUse_ReturnsCorrectErrorCode(t *testing.T) { t.Errorf("expected port=%d in details, got %v", port, xe.Details["port"]) } } + +func TestProxy_HandleConnection_BidirectionalCopy(t *testing.T) { + // Create a mock echo server as the remote target + echoListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = echoListener.Close() }() + + go func() { + for { + conn, err := echoListener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { _ = c.Close() }() + buf := make([]byte, 1024) + for { + n, err := c.Read(buf) + if err != nil { + return + } + _, _ = c.Write(buf[:n]) + } + }(conn) + } + }() + + echoPort := echoListener.Addr().(*net.TCPAddr).Port + + // Create a dialer that connects to the echo server + dialer := &directDialer{addr: echoListener.Addr().String()} + defer func() { _ = dialer.Close() }() + + ctx := context.Background() + opts := Options{ + LocalHost: "127.0.0.1", + LocalPort: 0, + RemoteHost: "127.0.0.1", + RemotePort: echoPort, + Dialer: dialer, + } + + proxy, result, xe := Start(ctx, opts) + if xe != nil { + t.Fatalf("failed to start proxy: %v", xe) + } + defer func() { _ = proxy.Stop() }() + + // Connect to the proxy and send/receive data + conn, err := net.DialTimeout("tcp", result.LocalAddress, 2*time.Second) + if err != nil { + t.Fatalf("failed to connect to proxy: %v", err) + } + defer func() { _ = conn.Close() }() + + testData := []byte("hello proxy") + if _, err := conn.Write(testData); err != nil { + t.Fatalf("failed to write: %v", err) + } + + buf := make([]byte, len(testData)) + if _, err := conn.Read(buf); err != nil { + t.Fatalf("failed to read: %v", err) + } + + if string(buf) != string(testData) { + t.Errorf("expected %q, got %q", testData, buf) + } +} + +func TestProxy_HandleConnection_ContextCancelled(t *testing.T) { + // Create a mock server that holds connections open + blockListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = blockListener.Close() }() + + go func() { + for { + conn, err := blockListener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { _ = c.Close() }() + buf := make([]byte, 1024) + for { + if _, err := c.Read(buf); err != nil { + return + } + } + }(conn) + } + }() + + blockAddr := blockListener.Addr().String() + + dialer := &directDialer{addr: blockAddr} + defer func() { _ = dialer.Close() }() + + ctx, cancel := context.WithCancel(context.Background()) + opts := Options{ + LocalHost: "127.0.0.1", + LocalPort: 0, + RemoteHost: "127.0.0.1", + RemotePort: blockListener.Addr().(*net.TCPAddr).Port, + Dialer: dialer, + } + + proxy, result, xe := Start(ctx, opts) + if xe != nil { + t.Fatalf("failed to start proxy: %v", xe) + } + + // Connect to the proxy to create a handleConnection goroutine + conn, err := net.DialTimeout("tcp", result.LocalAddress, 2*time.Second) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + _, _ = conn.Write([]byte("test")) + + // Cancel context to trigger the context.Done path in handleConnection + cancel() + + // Stop should complete without hanging + if err := proxy.Stop(); err != nil { + t.Errorf("failed to stop proxy: %v", err) + } + _ = conn.Close() +} + +// directDialer connects directly to a TCP address (no SSH tunnel). +type directDialer struct { + addr string +} + +func (d *directDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, "tcp", d.addr) +} + +func (d *directDialer) Close() error { return nil }