Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 271 additions & 9 deletions cmd/xsql/command_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"syscall"
"testing"
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"

Expand Down Expand Up @@ -134,7 +137,7 @@ func TestRunQuery_MissingDB(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w)
err := runQuery([]string{"select 1"}, &QueryFlags{}, &w)
if err == nil {
t.Fatal("expected error for missing db type")
}
Expand All @@ -149,7 +152,7 @@ func TestRunQuery_UnsupportedDriver(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w)
err := runQuery([]string{"select 1"}, &QueryFlags{}, &w)
if err == nil {
t.Fatal("expected error for unsupported driver")
}
Expand All @@ -168,7 +171,7 @@ func TestRunQuery_PlaintextPasswordNotAllowed(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w)
err := runQuery([]string{"select 1"}, &QueryFlags{}, &w)
if err == nil {
t.Fatal("expected error for plaintext password not allowed")
}
Expand All @@ -183,7 +186,7 @@ func TestRunSchemaDump_UnsupportedDriver(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runSchemaDump(nil, nil, &SchemaFlags{}, &w)
err := runSchemaDump(&SchemaFlags{}, &w)
if err == nil {
t.Fatal("expected error for unsupported driver")
}
Expand All @@ -202,7 +205,7 @@ func TestRunSchemaDump_PlaintextPasswordNotAllowed(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runSchemaDump(nil, nil, &SchemaFlags{}, &w)
err := runSchemaDump(&SchemaFlags{}, &w)
if err == nil {
t.Fatal("expected error for plaintext password not allowed")
}
Expand All @@ -217,7 +220,7 @@ func TestRunQuery_InvalidFormat(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w)
err := runQuery([]string{"select 1"}, &QueryFlags{}, &w)
if err == nil {
t.Fatal("expected error for invalid format")
}
Expand All @@ -232,7 +235,7 @@ func TestRunSchemaDump_MissingDB(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runSchemaDump(nil, nil, &SchemaFlags{}, &w)
err := runSchemaDump(&SchemaFlags{}, &w)
if err == nil {
t.Fatal("expected error for missing db type")
}
Expand All @@ -247,7 +250,7 @@ func TestRunSchemaDump_InvalidFormat(t *testing.T) {

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runSchemaDump(nil, nil, &SchemaFlags{}, &w)
err := runSchemaDump(&SchemaFlags{}, &w)
if err == nil {
t.Fatal("expected error for invalid format")
}
Expand Down Expand Up @@ -348,7 +351,7 @@ func TestRunProxy_SSHConnectError(t *testing.T) {
}

func TestResolveSSH_NoConfig(t *testing.T) {
client, err := app.ResolveSSH(nil, config.Profile{}, false, false)
client, err := app.ResolveSSH(context.TODO(), config.Profile{}, false, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -909,3 +912,262 @@ func TestFirstNonEmpty(t *testing.T) {
func configProfile(dbType string) config.Profile {
return config.Profile{DB: dbType}
}

func TestHandlePortConflict_NonTTY(t *testing.T) {
_, err := handlePortConflict(3306, "127.0.0.1")
if err == nil {
t.Fatal("expected error for non-TTY port conflict")
}
if err.Code != errors.CodePortInUse {
t.Fatalf("expected CodePortInUse, got %s", err.Code)
}
}

func TestModeForWebCommand(t *testing.T) {
if got := modeForWebCommand(true); got != "web" {
t.Fatalf("expected 'web', got %q", got)
}
if got := modeForWebCommand(false); got != "serve" {
t.Fatalf("expected 'serve', got %q", got)
}
}

func TestResolveWebOptions_InvalidAddr(t *testing.T) {
_, xe := resolveWebOptions(&webCommandOptions{
addr: "not-a-valid-addr",
addrSet: true,
}, config.File{})
if xe == nil {
t.Fatal("expected error for invalid addr")
}
if xe.Code != errors.CodeCfgInvalid {
t.Fatalf("expected CodeCfgInvalid, got %s", xe.Code)
}
}

func TestResolveWebOptions_EnvVars(t *testing.T) {
t.Setenv("XSQL_WEB_HTTP_AUTH_TOKEN", "env-token")
resolved, xe := resolveWebOptions(&webCommandOptions{
addr: "0.0.0.0:9999",
addrSet: true,
}, config.File{})
if xe != nil {
t.Fatalf("unexpected error: %v", xe)
}
if resolved.authToken != "env-token" {
t.Fatalf("expected env-token, got %s", resolved.authToken)
}
if !resolved.authRequired {
t.Fatal("expected authRequired=true")
}
}

func TestResolveMCPServerOptions_HttpAddrEnv(t *testing.T) {
t.Setenv("XSQL_MCP_TRANSPORT", "streamable_http")
t.Setenv("XSQL_MCP_HTTP_AUTH_TOKEN", "token")
t.Setenv("XSQL_MCP_HTTP_ADDR", "127.0.0.1:5555")
cfg := config.File{
Profiles: map[string]config.Profile{},
SSHProxies: map[string]config.SSHProxy{},
}
resolved, xe := resolveMCPServerOptions(&mcpServerOptions{}, cfg)
if xe != nil {
t.Fatalf("unexpected error: %v", xe)
}
if resolved.httpAddr != "127.0.0.1:5555" {
t.Fatalf("expected 127.0.0.1:5555, got %s", resolved.httpAddr)
}
}

func TestRunMCPServer_InvalidConfigPath(t *testing.T) {
GlobalConfig.ConfigStr = "/nonexistent/path/config.yaml"
err := runMCPServer(&mcpServerOptions{})
if err == nil {
t.Fatal("expected error for nonexistent config")
}
}

func TestRunMCPServer_StreamableHTTPStarts(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping signal handling test on Windows")
}
configPath := filepath.Join(t.TempDir(), "xsql.yaml")
if err := os.WriteFile(configPath, []byte("profiles: {}\nmcp:\n transport: streamable_http\n http:\n addr: 127.0.0.1:0\n auth_token: test-token\n allow_plaintext_token: true\n"), 0644); err != nil {
t.Fatalf("failed to write config: %v", err)
}

GlobalConfig.ConfigStr = configPath

// This will start the HTTP server and then we need to stop it
// We'll use a goroutine to run it and cancel after a short time
done := make(chan error, 1)
go func() {
done <- runMCPServer(&mcpServerOptions{})
}()

// Give the server time to start
time.Sleep(100 * time.Millisecond)

// The server should be running, send SIGINT to stop it
p, _ := os.FindProcess(os.Getpid())
_ = p.Signal(syscall.SIGINT)

select {
case err := <-done:
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for server to stop")
}
}

func TestNewServeCommand(t *testing.T) {
var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
cmd := NewServeCommand(&w)
if cmd.Use != "serve" {
t.Fatalf("expected 'serve', got %s", cmd.Use)
}
}

func TestNewWebCommand(t *testing.T) {
var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
cmd := NewWebCommand(&w)
if cmd.Use != "web" {
t.Fatalf("expected 'web', got %s", cmd.Use)
}
}

func TestNewMCPCommand(t *testing.T) {
cmd := NewMCPCommand()
if cmd.Use != "mcp" {
t.Fatalf("expected 'mcp', got %s", cmd.Use)
}
}

func TestNewProfileCommand(t *testing.T) {
var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
cmd := NewProfileCommand(&w)
if cmd.Use != "profile" {
t.Fatalf("expected 'profile', got %s", cmd.Use)
}
}

func TestNewSchemaCommand(t *testing.T) {
var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
cmd := NewSchemaCommand(&w)
if cmd.Use != "schema" {
t.Fatalf("expected 'schema', got %s", cmd.Use)
}
}

func TestResolveWebOptions_NilOpts(t *testing.T) {
resolved, xe := resolveWebOptions(nil, config.File{})
if xe != nil {
t.Fatalf("unexpected error: %v", xe)
}
if resolved.addr != "127.0.0.1:8788" {
t.Fatalf("expected default addr, got %s", resolved.addr)
}
}

func TestResolveWebOptions_ConfigAddr(t *testing.T) {
resolved, xe := resolveWebOptions(&webCommandOptions{}, config.File{
Web: config.WebConfig{
HTTP: config.WebHTTPConfig{
Addr: "127.0.0.1:9999",
},
},
})
if xe != nil {
t.Fatalf("unexpected error: %v", xe)
}
if resolved.addr != "127.0.0.1:9999" {
t.Fatalf("expected 127.0.0.1:9999, got %s", resolved.addr)
}
}

func TestResolveWebOptions_NonLoopbackRequiresToken(t *testing.T) {
_, xe := resolveWebOptions(&webCommandOptions{
addr: "10.0.0.1:8788",
addrSet: true,
}, config.File{})
if xe == nil {
t.Fatal("expected error for non-loopback without token")
}
if xe.Code != errors.CodeCfgInvalid {
t.Fatalf("expected CodeCfgInvalid, got %s", xe.Code)
}
}

func TestRunProxy_InvalidFormat(t *testing.T) {
prev := GlobalConfig
GlobalConfig = &Config{ProfileStr: "dev", FormatStr: "invalid"}
t.Cleanup(func() { GlobalConfig = prev })

GlobalConfig.Resolved.Profile = config.Profile{DB: "mysql", SSHConfig: &config.SSHProxy{Host: "h", Port: 22, User: "u"}}

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runProxy(nil, &ProxyFlags{}, &w)
if err == nil {
t.Fatal("expected error for invalid format")
}
}

func TestRunProxy_PlaintextNotAllowed(t *testing.T) {
prev := GlobalConfig
GlobalConfig = &Config{ProfileStr: "dev", FormatStr: "json"}
t.Cleanup(func() { GlobalConfig = prev })

GlobalConfig.Resolved.Profile = config.Profile{
DB: "mysql",
Password: "plain",
AllowPlaintext: false,
SSHConfig: &config.SSHProxy{Host: "h", Port: 22, User: "u"},
}

var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
err := runProxy(nil, &ProxyFlags{}, &w)
if err == nil {
t.Fatal("expected error for plaintext not allowed")
}
}

func TestResolveProxyPort_AllPaths(t *testing.T) {
// Test with config port and no CLI flag
cmd := NewProxyCommand(nil)
port, fromConfig := resolveProxyPort(cmd, &ProxyFlags{}, 5555)
if port != 5555 || !fromConfig {
t.Errorf("expected port=5555, fromConfig=true, got port=%d, fromConfig=%v", port, fromConfig)
}

// Test with zero config port
port, fromConfig = resolveProxyPort(cmd, &ProxyFlags{}, 0)
if port != 0 || fromConfig {
t.Errorf("expected port=0, fromConfig=false, got port=%d, fromConfig=%v", port, fromConfig)
}
}

func TestNewQueryCommand(t *testing.T) {
var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
cmd := NewQueryCommand(&w)
if cmd.Use != "query [SQL]" {
t.Fatalf("expected 'query [SQL]', got %s", cmd.Use)
}
}

func TestNewProxyCommand(t *testing.T) {
var out bytes.Buffer
w := output.New(&out, &bytes.Buffer{})
cmd := NewProxyCommand(&w)
if cmd.Use != "proxy [flags]" {
t.Fatalf("expected 'proxy [flags]', got %s", cmd.Use)
}
}
4 changes: 2 additions & 2 deletions cmd/xsql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func NewQueryCommand(w *output.Writer) *cobra.Command {
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
flags.QueryTimeoutSet = cmd.Flags().Changed("query-timeout")
return runQuery(cmd, args, flags, w)
return runQuery(args, flags, w)
},
}

Expand All @@ -44,7 +44,7 @@ func NewQueryCommand(w *output.Writer) *cobra.Command {
}

// runQuery executes a SQL query
func runQuery(cmd *cobra.Command, args []string, flags *QueryFlags, w *output.Writer) error {
func runQuery(args []string, flags *QueryFlags, w *output.Writer) error {
sql := args[0]
format, err := parseOutputFormat(GlobalConfig.FormatStr)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions cmd/xsql/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func NewRootCommand() *cobra.Command {
CLIFormatSet: formatSet,
EnvProfile: os.Getenv("XSQL_PROFILE"),
EnvFormat: os.Getenv("XSQL_FORMAT"),
WorkDir: "",
HomeDir: "",
WorkDir: os.Getenv("XSQL_WORKDIR"),
HomeDir: os.Getenv("XSQL_HOMEDIR"),
})
if xe != nil {
return xe
Expand Down
Loading
Loading