diff --git a/.gitignore b/.gitignore index f713869e..d8da873d 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ linux-s390x/sqlcmd # Build artifacts in root /sqlcmd /sqlcmd_binary +/modern # certificates used for local testing *.der diff --git a/README.md b/README.md index 576439da..f34209a3 100644 --- a/README.md +++ b/README.md @@ -61,18 +61,51 @@ The Homebrew package manager may be used on Linux and Windows Subsystem for Linu Use `sqlcmd` to create SQL Server and Azure SQL Edge instances using a local container runtime (e.g. [Docker][] or [Podman][]) -### Create SQL Server instance using local container runtime and connect using Azure Data Studio +### Create SQL Server instance using local container runtime -To create a local SQL Server instance with the AdventureWorksLT database restored, query it, and connect to it using Azure Data Studio, run: +To create a local SQL Server instance with the AdventureWorksLT database restored, run: ``` sqlcmd create mssql --accept-eula --using https://aka.ms/AdventureWorksLT.bak sqlcmd query "SELECT DB_NAME()" -sqlcmd open ads ``` Use `sqlcmd --help` to view all the available sub-commands. Use `sqlcmd -?` to view the original ODBC `sqlcmd` flags. +### Connect using Visual Studio Code + +Use `sqlcmd open vscode` to open Visual Studio Code with a connection profile configured for the current context: + +``` +sqlcmd open vscode +``` + +This command will: +1. **Create a connection profile** in VS Code's user settings with the current context name +2. **Copy the password to clipboard** so you can paste it when prompted +3. **Launch VS Code** ready to connect + +To also install the MSSQL extension (if not already installed), add the `--install-extension` flag: + +``` +sqlcmd open vscode --install-extension +``` + +Once VS Code opens, use the MSSQL extension's Object Explorer to connect using the profile. When you connect to the container, VS Code will automatically detect it as a Docker container and provide additional container management features (start/stop/delete) directly from the Object Explorer. + +### Connect using SQL Server Management Studio (Windows) + +On Windows, use `sqlcmd open ssms` to open SQL Server Management Studio pre-configured to connect to the current context: + +``` +sqlcmd open ssms +``` + +This command will: +1. **Copy the password to clipboard** so you can paste it in the login dialog +2. **Launch SSMS** with the server and username pre-filled +3. You'll be prompted for the password - just paste from clipboard (Ctrl+V) + ### The ~/.sqlcmd/sqlconfig file Each time `sqlcmd create` completes, a new context is created (e.g. mssql, mssql2, mssql3 etc.). A context contains the endpoint and user configuration detail. To switch between contexts, run `sqlcmd config use `, to view name of the current context, run `sqlcmd config current-context`, to list all contexts, run `sqlcmd config get-contexts`. diff --git a/cmd/modern/root/open.go b/cmd/modern/root/open.go index d209db81..3e1ca5c9 100644 --- a/cmd/modern/root/open.go +++ b/cmd/modern/root/open.go @@ -17,7 +17,7 @@ type Open struct { func (c *Open) DefineCommand(...cmdparser.CommandOptions) { options := cmdparser.CommandOptions{ Use: "open", - Short: localizer.Sprintf("Open tools (e.g Azure Data Studio) for current context"), + Short: localizer.Sprintf("Open tools (e.g., Azure Data Studio, VS Code, SSMS) for current context"), SubCommands: c.SubCommands(), } @@ -25,11 +25,13 @@ func (c *Open) DefineCommand(...cmdparser.CommandOptions) { } // SubCommands sets up the sub-commands for `sqlcmd open` such as -// `sqlcmd open ads` +// `sqlcmd open ads`, `sqlcmd open vscode`, and `sqlcmd open ssms` func (c *Open) SubCommands() []cmdparser.Command { dependencies := c.Dependencies() return []cmdparser.Command{ cmdparser.New[*open.Ads](dependencies), + cmdparser.New[*open.VSCode](dependencies), + cmdparser.New[*open.Ssms](dependencies), } } diff --git a/cmd/modern/root/open/ads.go b/cmd/modern/root/open/ads.go index 10731ecf..c009d949 100644 --- a/cmd/modern/root/open/ads.go +++ b/cmd/modern/root/open/ads.go @@ -37,6 +37,44 @@ func (c *Ads) DefineCommand(...cmdparser.CommandOptions) { // specific credential store, e.g. on Windows we use the Windows Credential // Manager. func (c *Ads) run() { + output := c.Output() + output.Warn(localizer.Sprintf("Azure Data Studio is being retired. This command will be removed in a future release.")) + + switch runtime.GOOS { + case "windows": + output.Info(localizer.Sprintf(`Alternatives: + + VS Code: + winget install Microsoft.VisualStudioCode + sqlcmd open vscode --install-extension + + SSMS: + winget install Microsoft.SQLServerManagementStudio + sqlcmd open ssms +`)) + case "darwin": + output.Info(localizer.Sprintf(`Alternatives: + + VS Code: + brew install --cask visual-studio-code + sqlcmd open vscode --install-extension + Or download: https://code.visualstudio.com/download +`)) + default: + output.Info(localizer.Sprintf(`Alternatives: + + VS Code: + snap install code --classic + sqlcmd open vscode --install-extension + Or download: https://code.visualstudio.com/download +`)) + } + + tool := tools.NewTool("ads") + if !tool.IsInstalled() { + output.Fatal(localizer.Sprintf("Azure Data Studio is not installed.")) + } + endpoint, user := config.CurrentContext() // If the context has a local container, ensure it is running, otherwise bail out @@ -66,7 +104,6 @@ func (c *Ads) ensureContainerIsRunning(endpoint sqlconfig.Endpoint) { // launchAds launches the Azure Data Studio using the specified server and username. func (c *Ads) launchAds(host string, port int, username string) { - output := c.Output() args := []string{ "-r", fmt.Sprintf( @@ -89,9 +126,7 @@ func (c *Ads) launchAds(host string, port int, username string) { } tool := tools.NewTool("ads") - if !tool.IsInstalled() { - output.Fatal(tool.HowToInstall()) - } + tool.IsInstalled() // precondition for Run; already verified in run() c.displayPreLaunchInfo() diff --git a/cmd/modern/root/open/ads_test.go b/cmd/modern/root/open/ads_test.go index 29f50369..68c2b77c 100644 --- a/cmd/modern/root/open/ads_test.go +++ b/cmd/modern/root/open/ads_test.go @@ -4,17 +4,24 @@ package open import ( + "runtime" + "testing" + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" - "runtime" - "testing" + "github.com/microsoft/go-sqlcmd/internal/tools" ) -// TestOpen runs a sanity test of `sqlcmd open` +// TestAds runs a sanity test of `sqlcmd open ads` func TestAds(t *testing.T) { if runtime.GOOS != "windows" { - t.Skip("Ads support only on Windows at this time") + t.Skip("ADS support only on Windows at this time") + } + + tool := tools.NewTool("ads") + if !tool.IsInstalled() { + t.Skip("Azure Data Studio is not installed") } cmdparser.TestSetup(t) diff --git a/cmd/modern/root/open/ads_windows.go b/cmd/modern/root/open/ads_windows.go index 1c475251..64a3dc0d 100644 --- a/cmd/modern/root/open/ads_windows.go +++ b/cmd/modern/root/open/ads_windows.go @@ -26,8 +26,7 @@ type Ads struct { // Ctrl+C here. func (c *Ads) displayPreLaunchInfo() { output := c.Output() - - output.Info(localizer.Sprintf("Press Ctrl+C to exit this process...")) + output.Info(localizer.Sprintf("Launching Azure Data Studio...")) } // persistCredentialForAds stores a SQL password in the Windows Credential Manager @@ -77,7 +76,12 @@ func (c *Ads) adsKey(instance, database, authType, user string) string { // the same target name as the current instance's credential. func (c *Ads) removePreviousCredential() { credentials, err := credman.EnumerateCredentials("", true) - c.CheckErr(err) + if err != nil { + // ERROR_NOT_FOUND (element not found) is expected when no + // credentials exist yet. Any other error is non-fatal here + // since we're only trying to clean up a previous entry. + return + } for _, cred := range credentials { if cred.TargetName == c.credential.TargetName { diff --git a/cmd/modern/root/open/clipboard.go b/cmd/modern/root/open/clipboard.go new file mode 100644 index 00000000..74e9d51c --- /dev/null +++ b/cmd/modern/root/open/clipboard.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/localizer" + "github.com/microsoft/go-sqlcmd/internal/output" + "github.com/microsoft/go-sqlcmd/internal/pal" +) + +// copyPasswordToClipboard copies the SQL password to the system clipboard. +// The password remains on the clipboard until the user or another application +// clears it; callers rely on the user heeding the "clear your clipboard" message. +func copyPasswordToClipboard(user *sqlconfig.User, out *output.Output) bool { + if user == nil || user.AuthenticationType != "basic" { + return false + } + + _, _, password := config.GetCurrentContextInfo() + + if password == "" { + return false + } + + err := pal.CopyToClipboard(password) + if err != nil { + out.Warn(localizer.Sprintf("Could not copy password to clipboard: %s", err.Error())) + return false + } + + out.Info(localizer.Sprintf("Password copied to clipboard - paste it when prompted, then clear your clipboard")) + return true +} diff --git a/cmd/modern/root/open/clipboard_test.go b/cmd/modern/root/open/clipboard_test.go new file mode 100644 index 00000000..67f93fc3 --- /dev/null +++ b/cmd/modern/root/open/clipboard_test.go @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "runtime" + "testing" + + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" +) + +func TestCopyPasswordToClipboardWithNoUser(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + result := copyPasswordToClipboard(nil, nil) + if result { + t.Error("Expected false when user is nil") + } +} + +func TestCopyPasswordToClipboardWithNonBasicAuth(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + user := &sqlconfig.User{ + AuthenticationType: "windows", + Name: "test-user", + } + + result := copyPasswordToClipboard(user, nil) + if result { + t.Error("Expected false when auth type is not 'basic'") + } +} diff --git a/cmd/modern/root/open/jsonc_patch.go b/cmd/modern/root/open/jsonc_patch.go new file mode 100644 index 00000000..5aea7dd9 --- /dev/null +++ b/cmd/modern/root/open/jsonc_patch.go @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "bytes" + "encoding/json" + "fmt" +) + +// patchJSONCKey replaces the value of a top-level key in a JSONC document, +// preserving all comments and formatting of other keys. If the key does not +// exist, it is appended before the closing brace. If data is nil or empty, +// a new JSON object is created. +func patchJSONCKey(data []byte, key string, value interface{}) ([]byte, error) { + encoded, err := json.MarshalIndent(value, " ", " ") + if err != nil { + return nil, err + } + + if len(bytes.TrimSpace(data)) == 0 { + return []byte(fmt.Sprintf("{\n %q: %s\n}\n", key, encoded)), nil + } + + quotedKey := []byte(`"` + key + `"`) + valStart, valEnd := findTopLevelJSONCValue(data, quotedKey) + if valStart >= 0 { + var buf bytes.Buffer + buf.Write(data[:valStart]) + buf.Write(encoded) + buf.Write(data[valEnd:]) + return buf.Bytes(), nil + } + + return insertJSONCKey(data, quotedKey, encoded) +} + +// findTopLevelJSONCValue locates a top-level key in JSONC data and returns +// the byte range [start, end) of its value. Returns (-1, -1) if not found. +func findTopLevelJSONCValue(data, quotedKey []byte) (int, int) { + i := jsoncSkipWS(data, 0) + if i >= len(data) || data[i] != '{' { + return -1, -1 + } + i++ + + for { + i = jsoncSkipWS(data, i) + if i >= len(data) || data[i] == '}' { + return -1, -1 + } + if data[i] == ',' { + i++ + continue + } + if data[i] != '"' { + return -1, -1 + } + + keyStart := i + i = jsoncSkipString(data, i) + isTarget := bytes.Equal(data[keyStart:i], quotedKey) + + i = jsoncSkipWS(data, i) + if i >= len(data) || data[i] != ':' { + return -1, -1 + } + i++ + + i = jsoncSkipWS(data, i) + valStart := i + i = jsoncSkipValue(data, i) + + if isTarget { + return valStart, i + } + } +} + +// insertJSONCKey inserts a new key-value pair before the top-level closing brace. +func insertJSONCKey(data, quotedKey, encoded []byte) ([]byte, error) { + var closingPos int + i := jsoncSkipWS(data, 0) + if i >= len(data) || data[i] != '{' { + return nil, fmt.Errorf("no top-level object found") + } + depth := 0 + for i < len(data) { + c := data[i] + switch { + case c == '"': + i = jsoncSkipString(data, i) + continue + case i+1 < len(data) && c == '/' && data[i+1] == '/': + i += 2 + for i < len(data) && data[i] != '\n' { + i++ + } + continue + case i+1 < len(data) && c == '/' && data[i+1] == '*': + i += 2 + for i+1 < len(data) { + if data[i] == '*' && data[i+1] == '/' { + i += 2 + break + } + i++ + } + continue + case c == '{' || c == '[': + depth++ + case c == '}' || c == ']': + depth-- + if depth == 0 && c == '}' { + closingPos = i + goto found + } + } + i++ + } + return nil, fmt.Errorf("no closing brace found") + +found: + needsComma := true + for j := closingPos - 1; j >= 0; j-- { + c := data[j] + if c == ' ' || c == '\t' || c == '\n' || c == '\r' { + continue + } + if c == ',' || c == '{' { + needsComma = false + } + break + } + + var buf bytes.Buffer + buf.Write(data[:closingPos]) + if needsComma { + buf.WriteByte(',') + } + buf.WriteString("\n ") + buf.Write(quotedKey) + buf.WriteString(": ") + buf.Write(encoded) + buf.WriteByte('\n') + buf.Write(data[closingPos:]) + return buf.Bytes(), nil +} + +// jsoncSkipWS advances past whitespace and JSONC comments. +func jsoncSkipWS(data []byte, i int) int { + n := len(data) + for i < n { + c := data[i] + if c == ' ' || c == '\t' || c == '\n' || c == '\r' { + i++ + continue + } + if i+1 < n && c == '/' { + if data[i+1] == '/' { + i += 2 + for i < n && data[i] != '\n' { + i++ + } + continue + } + if data[i+1] == '*' { + i += 2 + for i+1 < n { + if data[i] == '*' && data[i+1] == '/' { + i += 2 + break + } + i++ + } + continue + } + } + break + } + return i +} + +// jsoncSkipString advances past a quoted string including escape sequences. +func jsoncSkipString(data []byte, i int) int { + n := len(data) + if i >= n || data[i] != '"' { + return i + } + i++ + for i < n { + if data[i] == '\\' && i+1 < n { + i += 2 + continue + } + if data[i] == '"' { + return i + 1 + } + i++ + } + return i +} + +// jsoncSkipValue advances past a JSONC value (string, number, object, array, bool, null). +func jsoncSkipValue(data []byte, i int) int { + n := len(data) + if i >= n { + return i + } + switch data[i] { + case '"': + return jsoncSkipString(data, i) + case '{', '[': + return jsoncSkipBracket(data, i) + default: + for i < n { + c := data[i] + if c == ',' || c == '}' || c == ']' || c == ' ' || c == '\t' || c == '\n' || c == '\r' { + break + } + if c == '/' && i+1 < n && (data[i+1] == '/' || data[i+1] == '*') { + break + } + i++ + } + return i + } +} + +// jsoncSkipBracket advances past a bracket-delimited structure ({} or []), +// respecting nesting, strings, and comments. +func jsoncSkipBracket(data []byte, i int) int { + n := len(data) + if i >= n { + return i + } + depth := 0 + for i < n { + c := data[i] + switch { + case c == '"': + i = jsoncSkipString(data, i) + continue + case i+1 < n && c == '/' && data[i+1] == '/': + i += 2 + for i < n && data[i] != '\n' { + i++ + } + continue + case i+1 < n && c == '/' && data[i+1] == '*': + i += 2 + for i+1 < n { + if data[i] == '*' && data[i+1] == '/' { + i += 2 + break + } + i++ + } + continue + case c == '{' || c == '[': + depth++ + case c == '}' || c == ']': + depth-- + if depth == 0 { + return i + 1 + } + } + i++ + } + return i +} diff --git a/cmd/modern/root/open/jsonc_patch_test.go b/cmd/modern/root/open/jsonc_patch_test.go new file mode 100644 index 00000000..58f756aa --- /dev/null +++ b/cmd/modern/root/open/jsonc_patch_test.go @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/tidwall/jsonc" +) + +func TestPatchJSONCKey_PreservesComments(t *testing.T) { + input := []byte(`{ + // Editor settings + "editor.fontSize": 14, + "editor.tabSize": 2, + + /* Database connections */ + "mssql.connections": [ + { + "server": "old-server,1433", + "profileName": "old-profile" + } + ], + + // Terminal settings + "terminal.integrated.fontSize": 12 +}`) + + newConns := []interface{}{ + map[string]interface{}{ + "server": "new-server,1433", + "profileName": "new-profile", + }, + } + + result, err := patchJSONCKey(input, "mssql.connections", newConns) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + s := string(result) + + if !strings.Contains(s, "// Editor settings") { + t.Error("Line comment before key was destroyed") + } + if !strings.Contains(s, "/* Database connections */") { + t.Error("Block comment before key was destroyed") + } + if !strings.Contains(s, "// Terminal settings") { + t.Error("Line comment after key was destroyed") + } + if !strings.Contains(s, `"editor.fontSize": 14`) { + t.Error("Other key was modified") + } + if !strings.Contains(s, `"terminal.integrated.fontSize": 12`) { + t.Error("Other key was modified") + } + + // Verify the patched file is valid JSONC (strip comments, then parse) + clean := jsonc.ToJSON(result) + var m map[string]interface{} + if err := json.Unmarshal(clean, &m); err != nil { + t.Fatalf("Result is not valid JSONC: %v\n%s", err, result) + } + + conns, ok := m["mssql.connections"].([]interface{}) + if !ok || len(conns) != 1 { + t.Fatalf("Expected 1 connection, got %v", m["mssql.connections"]) + } + conn := conns[0].(map[string]interface{}) + if conn["server"] != "new-server,1433" { + t.Errorf("Expected new-server, got %v", conn["server"]) + } +} + +func TestPatchJSONCKey_InsertsNewKey(t *testing.T) { + input := []byte(`{ + // Editor settings + "editor.fontSize": 14 +}`) + + newConns := []interface{}{ + map[string]interface{}{ + "server": "localhost,1433", + "profileName": "test", + }, + } + + result, err := patchJSONCKey(input, "mssql.connections", newConns) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + s := string(result) + if !strings.Contains(s, "// Editor settings") { + t.Error("Comment was destroyed during insert") + } + if !strings.Contains(s, `"editor.fontSize": 14`) { + t.Error("Existing key was modified during insert") + } + + clean := jsonc.ToJSON(result) + var m map[string]interface{} + if err := json.Unmarshal(clean, &m); err != nil { + t.Fatalf("Result is not valid JSONC: %v\n%s", err, result) + } + + conns, ok := m["mssql.connections"].([]interface{}) + if !ok || len(conns) != 1 { + t.Fatalf("Expected 1 connection, got %v", m["mssql.connections"]) + } +} + +func TestPatchJSONCKey_EmptyFile(t *testing.T) { + result, err := patchJSONCKey(nil, "mssql.connections", []interface{}{}) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + var m map[string]interface{} + if err := json.Unmarshal(result, &m); err != nil { + t.Fatalf("Result is not valid JSON: %v\n%s", err, result) + } + conns, ok := m["mssql.connections"].([]interface{}) + if !ok || len(conns) != 0 { + t.Errorf("Expected empty connections array, got %v", m["mssql.connections"]) + } +} + +func TestPatchJSONCKey_TrailingComma(t *testing.T) { + input := []byte(`{ + "editor.fontSize": 14, + "mssql.connections": [], +}`) + + newConns := []interface{}{ + map[string]interface{}{"profileName": "test"}, + } + + result, err := patchJSONCKey(input, "mssql.connections", newConns) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + clean := jsonc.ToJSON(result) + var m map[string]interface{} + if err := json.Unmarshal(clean, &m); err != nil { + t.Fatalf("Result is not valid JSONC: %v\n%s", err, result) + } + conns, ok := m["mssql.connections"].([]interface{}) + if !ok || len(conns) != 1 { + t.Fatalf("Expected 1 connection, got %v", m["mssql.connections"]) + } +} + +func TestPatchJSONCKey_InsertIntoEmptyObject(t *testing.T) { + input := []byte(`{}`) + + result, err := patchJSONCKey(input, "mssql.connections", []interface{}{}) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + clean := jsonc.ToJSON(result) + var m map[string]interface{} + if err := json.Unmarshal(clean, &m); err != nil { + t.Fatalf("Result is not valid JSONC: %v\n%s", err, result) + } + if _, ok := m["mssql.connections"]; !ok { + t.Error("Key was not inserted") + } +} + +func TestPatchJSONCKey_InlineCommentAfterValue(t *testing.T) { + input := []byte(`{ + "mssql.connections": [] // old connections +}`) + + newConns := []interface{}{ + map[string]interface{}{"profileName": "new"}, + } + + result, err := patchJSONCKey(input, "mssql.connections", newConns) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + if !strings.Contains(string(result), "// old connections") { + t.Error("Inline comment after value was destroyed") + } + + clean := jsonc.ToJSON(result) + var m map[string]interface{} + if err := json.Unmarshal(clean, &m); err != nil { + t.Fatalf("Result is not valid JSONC: %v\n%s", err, result) + } +} + +func TestPatchJSONCKey_RealWorldVSCodeSettings(t *testing.T) { + input := []byte(`{ + // General editor preferences + "editor.fontSize": 14, + "editor.wordWrap": "on", + + /* + * Extensions + */ + "extensions.autoUpdate": true, + + // SQL connections managed by sqlcmd + "mssql.connections": [ + { + "server": "localhost,1433", + "profileName": "local-dev", + "encrypt": "Optional", + "trustServerCertificate": true, + }, + ], + + // Python settings + "python.linting.enabled": true, + "python.formatting.provider": "black" +}`) + + newConns := []interface{}{ + map[string]interface{}{ + "server": "localhost,1433", + "profileName": "local-dev", + "encrypt": "Optional", + "trustServerCertificate": true, + }, + map[string]interface{}{ + "server": "prod-server,1433", + "profileName": "production", + "encrypt": "Mandatory", + "trustServerCertificate": false, + "authenticationType": "SqlLogin", + "user": "admin", + }, + } + + result, err := patchJSONCKey(input, "mssql.connections", newConns) + if err != nil { + t.Fatalf("patchJSONCKey failed: %v", err) + } + + s := string(result) + + // All comments preserved + for _, comment := range []string{ + "// General editor preferences", + "* Extensions", + "// SQL connections managed by sqlcmd", + "// Python settings", + } { + if !strings.Contains(s, comment) { + t.Errorf("Comment %q was destroyed", comment) + } + } + + // All non-connection keys preserved verbatim + for _, key := range []string{ + `"editor.fontSize": 14`, + `"editor.wordWrap": "on"`, + `"extensions.autoUpdate": true`, + `"python.linting.enabled": true`, + `"python.formatting.provider": "black"`, + } { + if !strings.Contains(s, key) { + t.Errorf("Key %q was modified or lost", key) + } + } + + // Valid JSONC with correct data + clean := jsonc.ToJSON(result) + var m map[string]interface{} + if err := json.Unmarshal(clean, &m); err != nil { + t.Fatalf("Result is not valid JSONC: %v\n%s", err, result) + } + + conns, ok := m["mssql.connections"].([]interface{}) + if !ok || len(conns) != 2 { + t.Fatalf("Expected 2 connections, got %v", m["mssql.connections"]) + } +} diff --git a/cmd/modern/root/open/ssms.go b/cmd/modern/root/open/ssms.go new file mode 100644 index 00000000..c7c79526 --- /dev/null +++ b/cmd/modern/root/open/ssms.go @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build windows + +package open + +import ( + "fmt" + "os" + "regexp" + "strconv" + "strings" + + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/localizer" + "github.com/microsoft/go-sqlcmd/internal/tools" +) + +func (c *Ssms) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ + Use: "ssms", + Short: localizer.Sprintf("Open SQL Server Management Studio and connect to current context"), + Examples: []cmdparser.ExampleOptions{{ + Description: localizer.Sprintf("Open SSMS and connect using the current context"), + Steps: []string{"sqlcmd open ssms"}}}, + Run: c.run, + } + + c.Cmd.DefineCommand(options) +} + +func (c *Ssms) run() { + endpoint, user := config.CurrentContext() + isLocalConnection := isLocalEndpoint(endpoint) + + if asset := endpoint.AssetDetails; asset != nil && asset.ContainerDetails != nil { + c.ensureContainerIsRunning(asset.Id) + } + + c.launchSsms(endpoint.Address, endpoint.Port, user, isLocalConnection) +} + +func (c *Ssms) ensureContainerIsRunning(containerID string) { + output := c.Output() + controller := container.NewController() + if !controller.ContainerRunning(containerID) { + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("To start the container"), localizer.Sprintf("sqlcmd start")}, + }, localizer.Sprintf("Container is not running")) + } +} + +func (c *Ssms) launchSsms(host string, port int, user *sqlconfig.User, isLocalConnection bool) { + output := c.Output() + + if user != nil && user.AuthenticationType == "basic" && user.BasicAuth != nil { + copyPasswordToClipboard(user, output) + } + + c.displayPreLaunchInfo() + + serverArg := fmt.Sprintf("%s,%d", host, port) + + args := []string{ + "-S", serverArg, + "-nosplash", + } + + if db := os.Getenv("SQLCMDDATABASE"); db != "" { + args = append(args, "-d", db) + } + + if user != nil && user.AuthenticationType == "basic" && user.BasicAuth != nil { + username := strings.ReplaceAll(user.BasicAuth.Username, `"`, `\"`) + args = append(args, "-U", username) + } + + tool := tools.NewTool("ssms") + if !tool.IsInstalled() { + output.Fatal(tool.HowToInstall()) + } + + // -C (trust server certificate) for self-signed certs on local containers. + // Only supported by SSMS 21+. + if isLocalConnection && ssmsVersion(tool.ExePath()) >= 21 { + args = append(args, "-C") + } + + _, err := tool.Run(args) + c.CheckErr(err) +} + +var ssmsVersionRe = regexp.MustCompile(`Management Studio (\d+)`) + +// ssmsVersion returns 0 if the version cannot be determined from the path. +func ssmsVersion(exePath string) int { + m := ssmsVersionRe.FindStringSubmatch(exePath) + if len(m) < 2 { + return 0 + } + v, _ := strconv.Atoi(m[1]) + return v +} diff --git a/cmd/modern/root/open/ssms_test.go b/cmd/modern/root/open/ssms_test.go new file mode 100644 index 00000000..ef445d02 --- /dev/null +++ b/cmd/modern/root/open/ssms_test.go @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build windows + +package open + +import ( + "testing" + + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/tools" +) + +// TestSsms runs a sanity test of `sqlcmd open ssms` +func TestSsms(t *testing.T) { + // Skip if SSMS is not installed + tool := tools.NewTool("ssms") + if !tool.IsInstalled() { + t.Skip("SSMS is not installed") + } + + cmdparser.TestSetup(t) + config.AddEndpoint(sqlconfig.Endpoint{ + AssetDetails: nil, + EndpointDetails: sqlconfig.EndpointDetails{ + Address: "localhost", + Port: 1433, + }, + Name: "endpoint", + }) + config.AddContext(sqlconfig.Context{ + ContextDetails: sqlconfig.ContextDetails{ + Endpoint: "endpoint", + User: nil, + }, + Name: "context", + }) + config.SetCurrentContextName("context") + + cmdparser.TestCmd[*Ssms]() +} + +// TestSsmsContextWithUser tests SSMS setup with user credentials +func TestSsmsContextWithUser(t *testing.T) { + cmdparser.TestSetup(t) + + // Set up context with SQL authentication user + config.AddEndpoint(sqlconfig.Endpoint{ + AssetDetails: nil, + EndpointDetails: sqlconfig.EndpointDetails{ + Address: "localhost", + Port: 1433, + }, + Name: "ssms-test-endpoint", + }) + + config.AddUser(sqlconfig.User{ + AuthenticationType: "basic", + BasicAuth: &sqlconfig.BasicAuthDetails{ + Username: "sa", + PasswordEncryption: "", + Password: "TestPassword123", + }, + Name: "ssms-test-user", + }) + + config.AddContext(sqlconfig.Context{ + ContextDetails: sqlconfig.ContextDetails{ + Endpoint: "ssms-test-endpoint", + User: strPtr("ssms-test-user"), + }, + Name: "ssms-test-context", + }) + config.SetCurrentContextName("ssms-test-context") + + // Verify context is set up correctly + endpoint, user := config.CurrentContext() + + if endpoint.Address != "localhost" { + t.Errorf("Expected address 'localhost', got '%s'", endpoint.Address) + } + + if endpoint.Port != 1433 { + t.Errorf("Expected port 1433, got %d", endpoint.Port) + } + + if user == nil { + t.Fatal("Expected user to be set") + } + + if user.AuthenticationType != "basic" { + t.Errorf("Expected auth type 'basic', got '%s'", user.AuthenticationType) + } + + if user.BasicAuth.Username != "sa" { + t.Errorf("Expected username 'sa', got '%s'", user.BasicAuth.Username) + } +} + +// TestSsmsContextWithoutUser tests SSMS setup without user credentials +func TestSsmsContextWithoutUser(t *testing.T) { + cmdparser.TestSetup(t) + + // Set up context without user (e.g., for Windows authentication scenarios) + config.AddEndpoint(sqlconfig.Endpoint{ + AssetDetails: nil, + EndpointDetails: sqlconfig.EndpointDetails{ + Address: "myserver", + Port: 1433, + }, + Name: "ssms-no-user-endpoint", + }) + + config.AddContext(sqlconfig.Context{ + ContextDetails: sqlconfig.ContextDetails{ + Endpoint: "ssms-no-user-endpoint", + User: nil, + }, + Name: "ssms-no-user-context", + }) + config.SetCurrentContextName("ssms-no-user-context") + + // Verify context is set up correctly + endpoint, user := config.CurrentContext() + + if endpoint.Address != "myserver" { + t.Errorf("Expected address 'myserver', got '%s'", endpoint.Address) + } + + if user != nil { + t.Error("Expected user to be nil") + } +} diff --git a/cmd/modern/root/open/ssms_unix.go b/cmd/modern/root/open/ssms_unix.go new file mode 100644 index 00000000..3eb42952 --- /dev/null +++ b/cmd/modern/root/open/ssms_unix.go @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build !windows + +package open + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/localizer" +) + +// Type Ssms is used to implement the "open ssms" which launches SQL Server +// Management Studio and establishes a connection to the SQL Server for the current +// context +type Ssms struct { + cmdparser.Cmd +} + +// DefineCommand sets up the ssms command for non-Windows platforms +func (c *Ssms) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ + Use: "ssms", + Short: localizer.Sprintf("Open SQL Server Management Studio and connect to current context"), + Examples: []cmdparser.ExampleOptions{{ + Description: localizer.Sprintf("Open SSMS and connect using the current context"), + Steps: []string{"sqlcmd open ssms"}}}, + Run: c.run, + } + + c.Cmd.DefineCommand(options) +} + +// run fails immediately on non-Windows platforms +func (c *Ssms) run() { + output := c.Output() + output.Fatal(localizer.Sprintf("SSMS is only available on Windows. Use 'sqlcmd open vscode' instead.")) +} diff --git a/cmd/modern/root/open/ssms_windows.go b/cmd/modern/root/open/ssms_windows.go new file mode 100644 index 00000000..f0d7462b --- /dev/null +++ b/cmd/modern/root/open/ssms_windows.go @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/localizer" +) + +// Type Ssms is used to implement the "open ssms" which launches SQL Server +// Management Studio and establishes a connection to the SQL Server for the current +// context +type Ssms struct { + cmdparser.Cmd +} + +// On Windows, display info before launching +func (c *Ssms) displayPreLaunchInfo() { + output := c.Output() + output.Info(localizer.Sprintf("Launching SQL Server Management Studio...")) +} diff --git a/cmd/modern/root/open/uri_darwin.go b/cmd/modern/root/open/uri_darwin.go new file mode 100644 index 00000000..73df4b32 --- /dev/null +++ b/cmd/modern/root/open/uri_darwin.go @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build darwin + +package open + +import "os/exec" + +// openURI opens a URI via the macOS protocol handler. +func openURI(uri string) error { + return exec.Command("open", uri).Run() +} diff --git a/cmd/modern/root/open/uri_linux.go b/cmd/modern/root/open/uri_linux.go new file mode 100644 index 00000000..1c603a26 --- /dev/null +++ b/cmd/modern/root/open/uri_linux.go @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build linux + +package open + +import "os/exec" + +// openURI opens a URI via the Linux protocol handler. +func openURI(uri string) error { + return exec.Command("xdg-open", uri).Run() +} diff --git a/cmd/modern/root/open/uri_windows.go b/cmd/modern/root/open/uri_windows.go new file mode 100644 index 00000000..5119e4f4 --- /dev/null +++ b/cmd/modern/root/open/uri_windows.go @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "os/exec" + "syscall" +) + +// openURI opens a URI via the Windows shell protocol handler. +func openURI(uri string) error { + cmd := exec.Command("rundll32", "url.dll,FileProtocolHandler", uri) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + return cmd.Run() +} diff --git a/cmd/modern/root/open/vscode.go b/cmd/modern/root/open/vscode.go new file mode 100644 index 00000000..a5329237 --- /dev/null +++ b/cmd/modern/root/open/vscode.go @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "encoding/json" + "fmt" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/localizer" + "github.com/microsoft/go-sqlcmd/internal/test" + "github.com/microsoft/go-sqlcmd/internal/tools" + "github.com/tidwall/jsonc" +) + +// testSettingsEnvVar overrides getVSCodeSettingsPath in tests so they +// never touch the real VS Code settings.json. Set via t.Setenv. +const testSettingsEnvVar = "SQLCMD_TEST_VSCODE_SETTINGS_PATH" + +// VSCode implements the `sqlcmd open vscode` command. It opens +// Visual Studio Code and configures a connection profile for the +// current context using the MSSQL extension. +func (c *VSCode) DefineCommand(...cmdparser.CommandOptions) { + options := cmdparser.CommandOptions{ + Use: "vscode", + Short: localizer.Sprintf("Open Visual Studio Code and configure connection for current context"), + Examples: []cmdparser.ExampleOptions{ + { + Description: localizer.Sprintf("Open VS Code and configure connection using the current context"), + Steps: []string{"sqlcmd open vscode"}, + }, + { + Description: localizer.Sprintf("Open VS Code and install the MSSQL extension if needed"), + Steps: []string{"sqlcmd open vscode --install-extension"}, + }, + }, + Run: c.run, + } + + c.Cmd.DefineCommand(options) + + c.AddFlag(cmdparser.FlagOptions{ + Bool: &c.installExtension, + Name: "install-extension", + Usage: localizer.Sprintf("Install the MSSQL extension in VS Code if not already installed"), + }) +} + +func (c *VSCode) run() { + endpoint, user := config.CurrentContext() + isLocalConnection := isLocalEndpoint(endpoint) + + if asset := endpoint.AssetDetails; asset != nil && asset.ContainerDetails != nil { + c.ensureContainerIsRunning(asset.Id) + } + + c.createConnectionProfile(endpoint, user, isLocalConnection) + copyPasswordToClipboard(user, c.Output()) + c.launchVSCode(endpoint, user, isLocalConnection) +} + +func (c *VSCode) ensureContainerIsRunning(containerID string) { + output := c.Output() + controller := container.NewController() + if !controller.ContainerRunning(containerID) { + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("To start the container"), localizer.Sprintf("sqlcmd start")}, + }, localizer.Sprintf("Container is not running")) + } +} + +func (c *VSCode) launchVSCode(endpoint sqlconfig.Endpoint, user *sqlconfig.User, isLocalConnection bool) { + output := c.Output() + + tool := tools.NewTool("vscode") + if !tool.IsInstalled() { + output.Fatal(tool.HowToInstall()) + } + + // Install the MSSQL extension if explicitly requested + if c.installExtension { + output.Info(localizer.Sprintf("Installing MSSQL extension...")) + _, err := tool.Run([]string{"--install-extension", "ms-mssql.mssql", "--force"}) + if err != nil { + output.Warn(localizer.Sprintf("Could not install MSSQL extension: %s", err.Error())) + } else { + output.Info(localizer.Sprintf("MSSQL extension installed successfully")) + } + } else { + if !c.isMssqlExtensionInstalled() { + output.Warn(localizer.Sprintf("The MSSQL extension (ms-mssql.mssql) is not installed in VS Code")) + output.Info(localizer.Sprintf("To install: sqlcmd open vscode --install-extension")) + } + } + + c.displayPreLaunchInfo() + + if test.IsRunningInTestExecutor() { + return + } + + // Build a vscode:// URI that triggers the mssql extension's protocol + // handler, the same mechanism Fabric uses. The OS protocol handler routes + // it to VS Code without opening a second window. + connectURI := c.buildConnectURI(endpoint, user, isLocalConnection) + if err := openURI(connectURI); err != nil { + output.Warn(localizer.Sprintf("Could not open connection URI: %s", err.Error())) + // Fall back to just opening VS Code + _, err = tool.Run(nil) + c.CheckErr(err) + } +} + +// buildConnectURI creates a vscode://ms-mssql.mssql/connect URI with query +// params matching the connection profile. The mssql extension's protocol +// handler parses these to find a matching profile or open the connect dialog. +func (c *VSCode) buildConnectURI(endpoint sqlconfig.Endpoint, user *sqlconfig.User, isLocalConnection bool) string { + params := url.Values{} + params.Set("server", fmt.Sprintf("%s,%d", endpoint.Address, endpoint.Port)) + params.Set("profileName", config.CurrentContextName()) + + if isLocalConnection { + params.Set("encrypt", "Optional") + params.Set("trustServerCertificate", "true") + } else { + params.Set("encrypt", "Mandatory") + params.Set("trustServerCertificate", "false") + } + + if user != nil && user.AuthenticationType == "basic" && user.BasicAuth != nil { + params.Set("user", user.BasicAuth.Username) + params.Set("authenticationType", "SqlLogin") + } + + return "vscode://ms-mssql.mssql/connect?" + params.Encode() +} + +func (c *VSCode) createConnectionProfile(endpoint sqlconfig.Endpoint, user *sqlconfig.User, isLocalConnection bool) { + output := c.Output() + + settingsPath := c.getVSCodeSettingsPath() + + dir := filepath.Dir(settingsPath) + if err := os.MkdirAll(dir, 0755); err != nil { + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("Error"), err.Error()}, + }, localizer.Sprintf("Failed to create VS Code settings directory")) + } + + raw, readErr := os.ReadFile(settingsPath) + if readErr != nil && !os.IsNotExist(readErr) { + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("Error"), readErr.Error()}, + }, localizer.Sprintf("Failed to read VS Code settings")) + } + + settings := c.parseSettings(raw) + profile := c.createProfile(endpoint, user, isLocalConnection) + connections := c.getConnectionsArray(settings) + connections = c.updateOrAddProfile(connections, profile) + + // Patch only mssql.connections, preserving comments + patched, err := patchJSONCKey(raw, "mssql.connections", connections) + if err != nil { + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("Error"), err.Error()}, + }, localizer.Sprintf("Failed to update VS Code settings")) + } + + c.writeSettingsRaw(settingsPath, patched) + + output.Info(localizer.Sprintf("Connection profile created in VS Code settings")) +} + +func (c *VSCode) parseSettings(data []byte) map[string]interface{} { + settings := make(map[string]interface{}) + if len(data) > 0 { + clean := jsonc.ToJSON(data) + if err := json.Unmarshal(clean, &settings); err != nil { + output := c.Output() + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("Error"), err.Error()}, + }, localizer.Sprintf("Failed to parse VS Code settings")) + } + } + return settings +} + +func (c *VSCode) writeSettingsRaw(path string, data []byte) { + output := c.Output() + + // Preserve existing file permissions, or use 0600 for new files. + mode := os.FileMode(0600) + if info, err := os.Stat(path); err == nil { + mode = info.Mode() + } + + // Atomic write: temp file + rename, with direct-write fallback. + dir := filepath.Dir(path) + tmp, tmpErr := os.CreateTemp(dir, ".settings-*.tmp") + if tmpErr == nil { + tmpPath := tmp.Name() + _ = tmp.Chmod(mode) + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + _ = os.Remove(tmpPath) + } else if renameErr := os.Rename(tmpPath, path); renameErr != nil { + _ = os.Remove(tmpPath) + } else { + return // atomic write succeeded + } + } + + // Fallback: direct write + if err := os.WriteFile(path, data, mode); err != nil { + output.FatalWithHintExamples([][]string{ + {localizer.Sprintf("Error"), err.Error()}, + }, localizer.Sprintf("Failed to write VS Code settings")) + } +} + +func (c *VSCode) getConnectionsArray(settings map[string]interface{}) []interface{} { + connections := []interface{}{} + if existing, ok := settings["mssql.connections"]; ok { + if arr, ok := existing.([]interface{}); ok { + connections = arr + } + } + return connections +} + +func (c *VSCode) createProfile(endpoint sqlconfig.Endpoint, user *sqlconfig.User, isLocalConnection bool) map[string]interface{} { + contextName := config.CurrentContextName() + + encrypt := "Mandatory" + trustServerCertificate := false + + // Local connections (containers, localhost) commonly use self-signed certs + if isLocalConnection { + encrypt = "Optional" + trustServerCertificate = true + } + + profile := map[string]interface{}{ + "server": fmt.Sprintf("%s,%d", endpoint.Address, endpoint.Port), + "profileName": contextName, + "encrypt": encrypt, + "trustServerCertificate": trustServerCertificate, + } + + if user != nil && user.AuthenticationType == "basic" && user.BasicAuth != nil { + profile["user"] = user.BasicAuth.Username + profile["authenticationType"] = "SqlLogin" + profile["savePassword"] = true + } + + return profile +} + +func (c *VSCode) updateOrAddProfile(connections []interface{}, newProfile map[string]interface{}) []interface{} { + profileName, ok := newProfile["profileName"].(string) + if !ok { + return append(connections, newProfile) + } + + for i, conn := range connections { + if connMap, ok := conn.(map[string]interface{}); ok { + if name, ok := connMap["profileName"].(string); ok && name == profileName { + connections[i] = newProfile + return connections + } + } + } + + return append(connections, newProfile) +} + +func (c *VSCode) getVSCodeSettingsPath() string { + if override := os.Getenv(testSettingsEnvVar); override != "" { + return override + } + + var stableDir string + var insidersDir string + + homeDir := func() string { + if home, err := os.UserHomeDir(); err == nil { + return home + } + return "." + } + + switch runtime.GOOS { + case "windows": + base := os.Getenv("APPDATA") + if base == "" { + base = filepath.Join(homeDir(), "AppData", "Roaming") + } + stableDir = filepath.Join(base, "Code", "User") + insidersDir = filepath.Join(base, "Code - Insiders", "User") + case "darwin": + base := filepath.Join(homeDir(), "Library", "Application Support") + stableDir = filepath.Join(base, "Code", "User") + insidersDir = filepath.Join(base, "Code - Insiders", "User") + default: // linux and others + base := filepath.Join(homeDir(), ".config") + stableDir = filepath.Join(base, "Code", "User") + insidersDir = filepath.Join(base, "Code - Insiders", "User") + } + + // Prefer VS Code Insiders settings if the directory exists, since the tool + // searches for and launches Insiders first. Fall back to stable Code. + configDir := stableDir + if info, err := os.Stat(insidersDir); err == nil && info.IsDir() { + configDir = insidersDir + } + + return filepath.Join(configDir, "settings.json") +} + +// isMssqlExtensionInstalled checks the VS Code extensions directory on disk +// instead of running Code.exe --list-extensions (which opens a window). +func (c *VSCode) isMssqlExtensionInstalled() bool { + home, err := os.UserHomeDir() + if err != nil { + return true // assume installed if we can't check + } + + for _, dir := range []string{".vscode-insiders", ".vscode"} { + ext := filepath.Join(home, dir, "extensions") + entries, err := os.ReadDir(ext) + if err != nil { + continue + } + for _, e := range entries { + if strings.HasPrefix(strings.ToLower(e.Name()), "ms-mssql.mssql-") { + return true + } + } + } + return false +} + +func isLocalEndpoint(endpoint sqlconfig.Endpoint) bool { + if asset := endpoint.AssetDetails; asset != nil && asset.ContainerDetails != nil { + return true + } + + addr := strings.ToLower(endpoint.Address) + return addr == "localhost" || addr == "127.0.0.1" || addr == "::1" || addr == "host.docker.internal" +} diff --git a/cmd/modern/root/open/vscode_platform.go b/cmd/modern/root/open/vscode_platform.go new file mode 100644 index 00000000..522803b7 --- /dev/null +++ b/cmd/modern/root/open/vscode_platform.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/localizer" +) + +// Type VSCode is used to implement the "open vscode" which launches Visual +// Studio Code and establishes a connection to the SQL Server for the current +// context +type VSCode struct { + cmdparser.Cmd + installExtension bool +} + +func (c *VSCode) displayPreLaunchInfo() { + output := c.Output() + + output.Info(localizer.Sprintf("Opening VS Code...")) + output.Info(localizer.Sprintf("Use the '%s' connection profile to connect", config.CurrentContextName())) +} diff --git a/cmd/modern/root/open/vscode_test.go b/cmd/modern/root/open/vscode_test.go new file mode 100644 index 00000000..793862a7 --- /dev/null +++ b/cmd/modern/root/open/vscode_test.go @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package open + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/tools" +) + +// TestVSCode runs a sanity test of `sqlcmd open vscode` +func TestVSCode(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue") + } + + tool := tools.NewTool("vscode") + if !tool.IsInstalled() { + t.Skip("VS Code is not installed") + } + + // Redirect settings writes to a temp directory so the test never + // touches the real VS Code settings.json. + t.Setenv(testSettingsEnvVar, filepath.Join(t.TempDir(), "settings.json")) + + cmdparser.TestSetup(t) + config.AddEndpoint(sqlconfig.Endpoint{ + AssetDetails: nil, + EndpointDetails: sqlconfig.EndpointDetails{ + Address: "localhost", + Port: 1433, + }, + Name: "endpoint", + }) + config.AddContext(sqlconfig.Context{ + ContextDetails: sqlconfig.ContextDetails{ + Endpoint: "endpoint", + User: nil, + }, + Name: "context", + }) + config.SetCurrentContextName("context") + + cmdparser.TestCmd[*VSCode]() +} + +// TestVSCodeCreateProfile tests that createProfile generates correct profile structure +func TestVSCodeCreateProfile(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + // Set up a context with user credentials + config.AddEndpoint(sqlconfig.Endpoint{ + AssetDetails: nil, + EndpointDetails: sqlconfig.EndpointDetails{ + Address: "localhost", + Port: 1433, + }, + Name: "test-endpoint", + }) + + config.AddUser(sqlconfig.User{ + AuthenticationType: "basic", + BasicAuth: &sqlconfig.BasicAuthDetails{ + Username: "sa", + PasswordEncryption: "", + Password: "testpassword", + }, + Name: "test-user", + }) + + config.AddContext(sqlconfig.Context{ + ContextDetails: sqlconfig.ContextDetails{ + Endpoint: "test-endpoint", + User: strPtr("test-user"), + }, + Name: "my-database", + }) + config.SetCurrentContextName("my-database") + + // Create a VSCode command instance and test profile creation + vscode := &VSCode{} + endpoint, user := config.CurrentContext() + + profile := vscode.createProfile(endpoint, user, true) // true for local connection + + // Verify profile structure + if profile["server"] != "localhost,1433" { + t.Errorf("Expected server 'localhost,1433', got '%v'", profile["server"]) + } + + if profile["profileName"] != "my-database" { + t.Errorf("Expected profileName 'my-database', got '%v'", profile["profileName"]) + } + + if profile["authenticationType"] != "SqlLogin" { + t.Errorf("Expected authenticationType 'SqlLogin', got '%v'", profile["authenticationType"]) + } + + if profile["user"] != "sa" { + t.Errorf("Expected user 'sa', got '%v'", profile["user"]) + } + + if profile["encrypt"] != "Optional" { + t.Errorf("Expected encrypt 'Optional', got '%v'", profile["encrypt"]) + } + + if profile["trustServerCertificate"] != true { + t.Errorf("Expected trustServerCertificate true, got '%v'", profile["trustServerCertificate"]) + } + + if profile["savePassword"] != true { + t.Errorf("Expected savePassword true, got '%v'", profile["savePassword"]) + } +} + +// TestVSCodeUpdateOrAddProfile tests profile update and add logic +func TestVSCodeUpdateOrAddProfile(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + vscode := &VSCode{} + + // Test adding a new profile to empty list + connections := []interface{}{} + newProfile := map[string]interface{}{ + "profileName": "test-profile", + "server": "localhost,1433", + } + + result := vscode.updateOrAddProfile(connections, newProfile) + if len(result) != 1 { + t.Errorf("Expected 1 connection, got %d", len(result)) + } + + // Test adding a second profile with different name + secondProfile := map[string]interface{}{ + "profileName": "another-profile", + "server": "server2,1434", + } + + result = vscode.updateOrAddProfile(result, secondProfile) + if len(result) != 2 { + t.Errorf("Expected 2 connections, got %d", len(result)) + } + + // Test updating existing profile (same name) + updatedProfile := map[string]interface{}{ + "profileName": "test-profile", + "server": "localhost,2000", + "user": "newuser", + } + + result = vscode.updateOrAddProfile(result, updatedProfile) + if len(result) != 2 { + t.Errorf("Expected 2 connections after update, got %d", len(result)) + } + + // Verify the profile was updated, not duplicated + found := false + for _, conn := range result { + if connMap, ok := conn.(map[string]interface{}); ok { + if connMap["profileName"] == "test-profile" { + found = true + if connMap["server"] != "localhost,2000" { + t.Errorf("Expected updated server 'localhost,2000', got '%v'", connMap["server"]) + } + if connMap["user"] != "newuser" { + t.Errorf("Expected updated user 'newuser', got '%v'", connMap["user"]) + } + } + } + } + if !found { + t.Error("Updated profile not found in connections") + } +} + +func TestVSCodeReadWriteSettings(t *testing.T) { + // Create a temporary directory for test settings + tempDir := t.TempDir() + settingsPath := filepath.Join(tempDir, "settings.json") + + // Test reading non-existent file (should not exist yet) + _, err := os.ReadFile(settingsPath) + if !os.IsNotExist(err) { + t.Error("Expected file to not exist") + } + + // Write some settings using direct JSON + settings := map[string]interface{}{ + "mssql.connections": []interface{}{ + map[string]interface{}{ + "profileName": "test", + "server": "localhost,1433", + }, + }, + "other.setting": "value", + } + + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + t.Fatalf("Failed to marshal settings: %v", err) + } + + if err := os.WriteFile(settingsPath, data, 0644); err != nil { + t.Fatalf("Failed to write settings: %v", err) + } + + // Verify file was created + if _, err := os.Stat(settingsPath); os.IsNotExist(err) { + t.Error("Settings file was not created") + } + + // Read settings back + readData, err := os.ReadFile(settingsPath) + if err != nil { + t.Fatalf("Failed to read settings: %v", err) + } + + var readSettings map[string]interface{} + if err := json.Unmarshal(readData, &readSettings); err != nil { + t.Fatalf("Failed to unmarshal settings: %v", err) + } + + if readSettings["other.setting"] != "value" { + t.Errorf("Expected 'other.setting' to be 'value', got '%v'", readSettings["other.setting"]) + } + + connections, ok := readSettings["mssql.connections"].([]interface{}) + if !ok || len(connections) != 1 { + t.Error("Expected 1 mssql connection in read settings") + } +} + +// TestVSCodeGetConnectionsArray tests extracting connections array from settings +func TestVSCodeGetConnectionsArray(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + vscode := &VSCode{} + + // Test with no connections key + settings := map[string]interface{}{} + connections := vscode.getConnectionsArray(settings) + if len(connections) != 0 { + t.Errorf("Expected empty array, got %d items", len(connections)) + } + + // Test with connections array + settings["mssql.connections"] = []interface{}{ + map[string]interface{}{"profileName": "test1"}, + map[string]interface{}{"profileName": "test2"}, + } + connections = vscode.getConnectionsArray(settings) + if len(connections) != 2 { + t.Errorf("Expected 2 connections, got %d", len(connections)) + } + + // Test with wrong type (should return empty array) + settings["mssql.connections"] = "not an array" + connections = vscode.getConnectionsArray(settings) + if len(connections) != 0 { + t.Errorf("Expected empty array for invalid type, got %d items", len(connections)) + } +} + +// TestVSCodeGetSettingsPath tests that settings path is correctly determined +func TestVSCodeGetSettingsPath(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + vscode := &VSCode{} + path := vscode.getVSCodeSettingsPath() + + // Verify path ends with settings.json + if filepath.Base(path) != "settings.json" { + t.Errorf("Expected path to end with 'settings.json', got '%s'", filepath.Base(path)) + } + + // Verify path contains expected directory components + switch runtime.GOOS { + case "windows": + if !strings.Contains(path, "Code") { + t.Errorf("Expected path to contain 'Code' on Windows, got '%s'", path) + } + case "darwin": + if !strings.Contains(path, "Application Support") { + t.Errorf("Expected path to contain 'Application Support' on macOS, got '%s'", path) + } + } +} + +// TestVSCodeProfileWithoutUser tests profile creation when no user is configured +func TestVSCodeProfileWithoutUser(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + config.AddEndpoint(sqlconfig.Endpoint{ + AssetDetails: nil, + EndpointDetails: sqlconfig.EndpointDetails{ + Address: "myserver", + Port: 1433, + }, + Name: "no-user-endpoint", + }) + + config.AddContext(sqlconfig.Context{ + ContextDetails: sqlconfig.ContextDetails{ + Endpoint: "no-user-endpoint", + User: nil, + }, + Name: "no-user-context", + }) + config.SetCurrentContextName("no-user-context") + + vscode := &VSCode{} + endpoint, user := config.CurrentContext() + + profile := vscode.createProfile(endpoint, user, false) // false for non-local connection + + // Verify profile doesn't have user field when no user is configured + if _, hasUser := profile["user"]; hasUser { + t.Error("Expected profile to not have 'user' field when no user configured") + } + + // Verify other fields are still set correctly + if profile["profileName"] != "no-user-context" { + t.Errorf("Expected profileName 'no-user-context', got '%v'", profile["profileName"]) + } + + // Verify secure TLS settings for non-local connections + if profile["encrypt"] != "Mandatory" { + t.Errorf("Expected encrypt 'Mandatory' for non-local connection, got '%v'", profile["encrypt"]) + } + + if profile["trustServerCertificate"] != false { + t.Errorf("Expected trustServerCertificate false for non-local connection, got '%v'", profile["trustServerCertificate"]) + } +} + +func TestVSCodeSettingsPreservesOtherKeys(t *testing.T) { + if runtime.GOOS == "linux" { + t.Skip("Skipping on Linux due to ADS tool initialization issue in tools factory") + } + + cmdparser.TestSetup(t) + + vscode := &VSCode{} + tempDir := t.TempDir() + settingsPath := filepath.Join(tempDir, "settings.json") + + // Write initial settings with various keys + initialSettings := map[string]interface{}{ + "editor.fontSize": 14, + "workbench.theme": "Dark+", + "mssql.connections": []interface{}{}, + } + + data, err := json.MarshalIndent(initialSettings, "", " ") + if err != nil { + t.Fatalf("Failed to marshal initial settings: %v", err) + } + if err := os.WriteFile(settingsPath, data, 0644); err != nil { + t.Fatalf("Failed to write settings: %v", err) + } + + // Read settings back using direct JSON (simulating what readSettings does) + readData, err := os.ReadFile(settingsPath) + if err != nil { + t.Fatalf("Failed to read settings: %v", err) + } + var settings map[string]interface{} + if err := json.Unmarshal(readData, &settings); err != nil { + t.Fatalf("Failed to unmarshal settings: %v", err) + } + + // Get connections and add a new profile + connections := vscode.getConnectionsArray(settings) + newProfile := map[string]interface{}{ + "profileName": "new-profile", + "server": "localhost,1433", + } + connections = vscode.updateOrAddProfile(connections, newProfile) + settings["mssql.connections"] = connections + + // Write back using direct JSON (simulating what writeSettings does) + writeData, err := json.MarshalIndent(settings, "", " ") + if err != nil { + t.Fatalf("Failed to marshal settings: %v", err) + } + if err := os.WriteFile(settingsPath, writeData, 0644); err != nil { + t.Fatalf("Failed to write settings: %v", err) + } + + // Read back and verify other keys are preserved + finalData, err := os.ReadFile(settingsPath) + if err != nil { + t.Fatalf("Failed to read final settings: %v", err) + } + var finalSettings map[string]interface{} + if err := json.Unmarshal(finalData, &finalSettings); err != nil { + t.Fatalf("Failed to unmarshal final settings: %v", err) + } + + if finalSettings["editor.fontSize"].(float64) != 14 { + t.Errorf("Expected editor.fontSize to be preserved as 14, got %v", finalSettings["editor.fontSize"]) + } + + if finalSettings["workbench.theme"] != "Dark+" { + t.Errorf("Expected workbench.theme to be preserved as 'Dark+', got %v", finalSettings["workbench.theme"]) + } +} + +// Helper to create string pointer +func strPtr(s string) *string { + return &s +} diff --git a/go.mod b/go.mod index 5e1e4b1f..8e62abf2 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/spf13/pflag v1.0.10 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 + github.com/tidwall/jsonc v0.3.3 golang.org/x/sys v0.43.0 golang.org/x/term v0.42.0 golang.org/x/text v0.36.0 diff --git a/go.sum b/go.sum index 7208c831..c5395db4 100644 --- a/go.sum +++ b/go.sum @@ -193,6 +193,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tidwall/jsonc v0.3.3 h1:RVQqL3xFfDkKKXIDsrBiVQiEpBtxoKbmMXONb2H/y2w= +github.com/tidwall/jsonc v0.3.3/go.mod h1:dw+3CIxqHi+t8eFSpzzMlcVYxKp08UP5CD8/uSFCyJE= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= diff --git a/internal/pal/clipboard.go b/internal/pal/clipboard.go new file mode 100644 index 00000000..d3e78e0b --- /dev/null +++ b/internal/pal/clipboard.go @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package pal + +// CopyToClipboard copies the given text to the system clipboard. +// Returns an error if the clipboard operation fails. +func CopyToClipboard(text string) error { + return copyToClipboard(text) +} diff --git a/internal/pal/clipboard_darwin.go b/internal/pal/clipboard_darwin.go new file mode 100644 index 00000000..d6012f22 --- /dev/null +++ b/internal/pal/clipboard_darwin.go @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package pal + +import ( + "os/exec" + "strings" +) + +func copyToClipboard(text string) error { + cmd := exec.Command("pbcopy") + cmd.Stdin = strings.NewReader(text) + return cmd.Run() +} diff --git a/internal/pal/clipboard_linux.go b/internal/pal/clipboard_linux.go new file mode 100644 index 00000000..52025f56 --- /dev/null +++ b/internal/pal/clipboard_linux.go @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package pal + +import ( + "fmt" + "os/exec" + "strings" +) + +func copyToClipboard(text string) error { + var attempts []string + + tryCmd := func(name string, args ...string) bool { + if _, err := exec.LookPath(name); err != nil { + attempts = append(attempts, fmt.Sprintf("%s not found", name)) + return false + } + + cmd := exec.Command(name, args...) + cmd.Stdin = strings.NewReader(text) + if err := cmd.Run(); err != nil { + attempts = append(attempts, fmt.Sprintf("%s failed: %v", name, err)) + return false + } + + return true + } + + if tryCmd("xclip", "-selection", "clipboard") { + return nil + } + + if tryCmd("xsel", "--clipboard", "--input") { + return nil + } + + if tryCmd("wl-copy") { + return nil + } + + return fmt.Errorf("failed to copy to clipboard; tried xclip, xsel, wl-copy: %s", strings.Join(attempts, "; ")) +} diff --git a/internal/pal/clipboard_test.go b/internal/pal/clipboard_test.go new file mode 100644 index 00000000..96b73971 --- /dev/null +++ b/internal/pal/clipboard_test.go @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package pal + +import ( + "testing" +) + +func TestCopyToClipboard(t *testing.T) { + // This test just ensures the function doesn't panic + // Actual clipboard testing would require platform-specific validation + err := CopyToClipboard("test password") + if err != nil { + // Don't fail on Linux headless environments where clipboard tools may not exist + t.Logf("CopyToClipboard returned error (may be expected in headless environment): %v", err) + } +} diff --git a/internal/pal/clipboard_windows.go b/internal/pal/clipboard_windows.go new file mode 100644 index 00000000..79b42d06 --- /dev/null +++ b/internal/pal/clipboard_windows.go @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package pal + +import ( + "os/exec" + "strings" +) + +func copyToClipboard(text string) error { + cmd := exec.Command("clip") + cmd.Stdin = strings.NewReader(text) + return cmd.Run() +} diff --git a/internal/tools/tool/ads.go b/internal/tools/tool/ads.go index f9295e7b..55ed211e 100644 --- a/internal/tools/tool/ads.go +++ b/internal/tools/tool/ads.go @@ -28,8 +28,7 @@ func (t *AzureDataStudio) Init() { func (t *AzureDataStudio) Run(args []string) (int, error) { if !test.IsRunningInTestExecutor() { - return t.tool.Run(args) - } else { - return 0, nil + return t.Launch(args) } + return 0, nil } diff --git a/internal/tools/tool/interface.go b/internal/tools/tool/interface.go index a8910175..146a50bc 100644 --- a/internal/tools/tool/interface.go +++ b/internal/tools/tool/interface.go @@ -6,6 +6,7 @@ package tool type Tool interface { Init() Name() (name string) + ExePath() string Run(args []string) (exitCode int, err error) IsInstalled() bool HowToInstall() string diff --git a/internal/tools/tool/ssms.go b/internal/tools/tool/ssms.go new file mode 100644 index 00000000..0e038fdd --- /dev/null +++ b/internal/tools/tool/ssms.go @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import ( + "github.com/microsoft/go-sqlcmd/internal/io/file" + "github.com/microsoft/go-sqlcmd/internal/test" +) + +type SSMS struct { + tool +} + +func (t *SSMS) Init() { + t.SetToolDescription(Description{ + Name: "ssms", + Purpose: "SQL Server Management Studio (SSMS) is an integrated environment for managing SQL Server infrastructure.", + InstallText: t.installText()}) + + for _, location := range t.searchLocations() { + if file.Exists(location) { + t.SetExePathAndName(location) + break + } + } +} + +func (t *SSMS) Run(args []string) (int, error) { + if !test.IsRunningInTestExecutor() { + return t.Launch(args) + } + return 0, nil +} diff --git a/internal/tools/tool/ssms_test.go b/internal/tools/tool/ssms_test.go new file mode 100644 index 00000000..a60343aa --- /dev/null +++ b/internal/tools/tool/ssms_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import "testing" + +func TestSSMS(t *testing.T) { + tool := SSMS{} + tool.Init() + + if tool.Name() != "ssms" { + t.Errorf("Expected name to be 'ssms', got %s", tool.Name()) + } +} diff --git a/internal/tools/tool/ssms_unix.go b/internal/tools/tool/ssms_unix.go new file mode 100644 index 00000000..e8fcb2dc --- /dev/null +++ b/internal/tools/tool/ssms_unix.go @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build !windows + +package tool + +func (t *SSMS) searchLocations() []string { + return []string{} +} + +func (t *SSMS) installText() string { + return `SQL Server Management Studio (SSMS) is only available on Windows. + +Please use: +- Visual Studio Code with the MSSQL extension: sqlcmd open vscode +- Azure Data Studio: sqlcmd open ads` +} diff --git a/internal/tools/tool/ssms_windows.go b/internal/tools/tool/ssms_windows.go new file mode 100644 index 00000000..5ff58071 --- /dev/null +++ b/internal/tools/tool/ssms_windows.go @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import ( + "os" + "path/filepath" + "regexp" + "sort" + "strconv" +) + +const ssmsPrefix = "Microsoft SQL Server Management Studio " + +// SSMS 21+ moved the exe under Release\; older versions use Common7\IDE directly. +var ssmsExeSubPaths = []string{ + `Release\Common7\IDE\Ssms.exe`, + `Common7\IDE\Ssms.exe`, +} + +func (t *SSMS) searchLocations() []string { + programFiles := os.Getenv("ProgramFiles") + programFilesX86 := os.Getenv("ProgramFiles(x86)") + + roots := []string{programFiles, programFilesX86} + dirs := discoverSSMSDirs(roots) + + var paths []string + for _, dir := range dirs { + for _, sub := range ssmsExeSubPaths { + paths = append(paths, filepath.Join(dir, sub)) + } + } + return paths +} + +// discoverSSMSDirs globs for SSMS install directories under the given root +// folders and returns them sorted by version number descending (newest first). +func discoverSSMSDirs(roots []string) []string { + versionRe := regexp.MustCompile(`(\d+)$`) + type entry struct { + path string + version int + } + + var entries []entry + for _, root := range roots { + if root == "" { + continue + } + matches, err := filepath.Glob(filepath.Join(root, ssmsPrefix+"*")) + if err != nil { + continue + } + for _, m := range matches { + base := filepath.Base(m) + if sub := versionRe.FindString(base); sub != "" { + if v, err := strconv.Atoi(sub); err == nil { + entries = append(entries, entry{m, v}) + } + } + } + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].version > entries[j].version + }) + + dirs := make([]string, len(entries)) + for i, e := range entries { + dirs[i] = e.path + } + return dirs +} + +func (t *SSMS) installText() string { + return `Install using a package manager: + + winget install Microsoft.SQLServerManagementStudio + # or + choco install sql-server-management-studio + +Or download the latest version from: + + https://aka.ms/ssmsfullsetup + +Note: SSMS is only available on Windows.` +} diff --git a/internal/tools/tool/tool.go b/internal/tools/tool/tool.go index ee4d5db4..10169055 100644 --- a/internal/tools/tool/tool.go +++ b/internal/tools/tool/tool.go @@ -22,6 +22,10 @@ func (t *tool) SetExePathAndName(exeName string) { t.exeName = exeName } +func (t *tool) ExePath() string { + return t.exeName +} + func (t *tool) SetToolDescription(description Description) { t.description = description } @@ -32,7 +36,8 @@ func (t *tool) IsInstalled() bool { } t.installed = new(bool) - if file.Exists(t.exeName) { + // Handle case where tool wasn't found during Init (exeName is empty) + if t.exeName != "" && file.Exists(t.exeName) { *t.installed = true } else { *t.installed = false @@ -54,11 +59,27 @@ func (t *tool) HowToInstall() string { func (t *tool) Run(args []string) (int, error) { if t.installed == nil { - panic("Call IsInstalled before Run") + return 1, fmt.Errorf("internal error: Call IsInstalled before Run") } cmd := t.generateCommandLine(args) err := cmd.Run() - return cmd.ProcessState.ExitCode(), err + exitCode := 0 + if cmd.ProcessState != nil { + exitCode = cmd.ProcessState.ExitCode() + } + + return exitCode, err +} + +// Launch starts the process without waiting for it to exit. +func (t *tool) Launch(args []string) (int, error) { + if t.installed == nil { + return 1, fmt.Errorf("internal error: Call IsInstalled before Launch") + } + + cmd := t.generateCommandLine(args) + err := cmd.Start() + return 0, err } diff --git a/internal/tools/tool/tool_linux.go b/internal/tools/tool/tool_linux.go index 4344e37b..a5658959 100644 --- a/internal/tools/tool/tool_linux.go +++ b/internal/tools/tool/tool_linux.go @@ -4,9 +4,17 @@ package tool import ( + "bytes" "os/exec" ) func (t *tool) generateCommandLine(args []string) *exec.Cmd { - panic("Not yet implemented") + var stdout, stderr bytes.Buffer + cmd := &exec.Cmd{ + Path: t.exeName, + Args: append([]string{t.exeName}, args...), + Stdout: &stdout, + Stderr: &stderr, + } + return cmd } diff --git a/internal/tools/tool/tool_test.go b/internal/tools/tool/tool_test.go index 659b8fa1..d869f931 100644 --- a/internal/tools/tool/tool_test.go +++ b/internal/tools/tool/tool_test.go @@ -4,11 +4,12 @@ package tool import ( - "github.com/stretchr/testify/assert" "os" "runtime" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestInit(t *testing.T) { @@ -94,12 +95,9 @@ func TestHowToInstall(t *testing.T) { func TestRunWhenNotInstalled(t *testing.T) { tool := &tool{} - assert.Panics(t, func() { - _, err := tool.Run([]string{}) - if err != nil { - return - } - }) + _, err := tool.Run([]string{}) + assert.Error(t, err, "Run should return error when IsInstalled was not called first") + assert.Contains(t, err.Error(), "Call IsInstalled before Run") } func TestRun(t *testing.T) { diff --git a/internal/tools/tool/tool_windows.go b/internal/tools/tool/tool_windows.go index 3e3aeaa5..a5658959 100644 --- a/internal/tools/tool/tool_windows.go +++ b/internal/tools/tool/tool_windows.go @@ -12,7 +12,7 @@ func (t *tool) generateCommandLine(args []string) *exec.Cmd { var stdout, stderr bytes.Buffer cmd := &exec.Cmd{ Path: t.exeName, - Args: args, + Args: append([]string{t.exeName}, args...), Stdout: &stdout, Stderr: &stderr, } diff --git a/internal/tools/tool/vscode.go b/internal/tools/tool/vscode.go new file mode 100644 index 00000000..29ca5a08 --- /dev/null +++ b/internal/tools/tool/vscode.go @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import ( + "github.com/microsoft/go-sqlcmd/internal/io/file" + "github.com/microsoft/go-sqlcmd/internal/test" +) + +type VSCode struct { + tool +} + +func (t *VSCode) Init() { + t.SetToolDescription(Description{ + Name: "vscode", + Purpose: "Visual Studio Code is a code editor with support for database management through the MSSQL extension.", + InstallText: t.installText()}) + + for _, location := range t.searchLocations() { + if file.Exists(location) { + t.SetExePathAndName(location) + break + } + } +} + +func (t *VSCode) Run(args []string) (int, error) { + if !test.IsRunningInTestExecutor() { + return t.tool.Run(args) + } + return 0, nil +} diff --git a/internal/tools/tool/vscode_darwin.go b/internal/tools/tool/vscode_darwin.go new file mode 100644 index 00000000..eeba8097 --- /dev/null +++ b/internal/tools/tool/vscode_darwin.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import ( + "os" + "path/filepath" +) + +func (t *VSCode) searchLocations() []string { + userProfile := os.Getenv("HOME") + + return []string{ + filepath.Join("/", "Applications", "Visual Studio Code - Insiders.app"), + filepath.Join(userProfile, "Downloads", "Visual Studio Code - Insiders.app"), + filepath.Join("/", "Applications", "Visual Studio Code.app"), + filepath.Join(userProfile, "Downloads", "Visual Studio Code.app"), + } +} + +func (t *VSCode) installText() string { + return `Install using Homebrew: + + brew install --cask visual-studio-code + +Or download the latest version from: + + https://code.visualstudio.com/download + +After installation, install the MSSQL extension: + + sqlcmd open vscode --install-extension + +Or install it directly in VS Code via Extensions (Cmd+Shift+X) and search for "SQL Server (mssql)"` +} diff --git a/internal/tools/tool/vscode_linux.go b/internal/tools/tool/vscode_linux.go new file mode 100644 index 00000000..bccbeefa --- /dev/null +++ b/internal/tools/tool/vscode_linux.go @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import ( + "os" + "path/filepath" +) + +func (t *VSCode) searchLocations() []string { + userProfile := os.Getenv("HOME") + + return []string{ + filepath.Join("/", "usr", "bin", "code-insiders"), + filepath.Join("/", "usr", "bin", "code"), + filepath.Join(userProfile, ".local", "bin", "code-insiders"), + filepath.Join(userProfile, ".local", "bin", "code"), + filepath.Join("/", "snap", "bin", "code"), + } +} + +func (t *VSCode) installText() string { + return `Install using a package manager: + + # Debian/Ubuntu + sudo apt install code + + # Fedora/RHEL + sudo dnf install code + + # Snap + sudo snap install code --classic + +Or download the latest version from: + + https://code.visualstudio.com/download + +After installation, install the MSSQL extension: + + sqlcmd open vscode --install-extension + +Or install it directly in VS Code via Extensions (Ctrl+Shift+X) and search for "SQL Server (mssql)"` +} diff --git a/internal/tools/tool/vscode_test.go b/internal/tools/tool/vscode_test.go new file mode 100644 index 00000000..2c35beeb --- /dev/null +++ b/internal/tools/tool/vscode_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import "testing" + +func TestVSCode(t *testing.T) { + tool := VSCode{} + tool.Init() + + if tool.Name() != "vscode" { + t.Errorf("Expected name to be 'vscode', got %s", tool.Name()) + } +} diff --git a/internal/tools/tool/vscode_windows.go b/internal/tools/tool/vscode_windows.go new file mode 100644 index 00000000..106a8b8c --- /dev/null +++ b/internal/tools/tool/vscode_windows.go @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package tool + +import ( + "os" + "path/filepath" +) + +// Search in this order +// +// User Insiders Install +// System Insiders Install +// User non-Insiders install +// System non-Insiders install +func (t *VSCode) searchLocations() []string { + userProfile := os.Getenv("USERPROFILE") + programFiles := os.Getenv("ProgramFiles") + + return []string{ + filepath.Join(userProfile, "AppData\\Local\\Programs\\Microsoft VS Code Insiders\\Code - Insiders.exe"), + filepath.Join(programFiles, "Microsoft VS Code Insiders\\Code - Insiders.exe"), + filepath.Join(userProfile, "AppData\\Local\\Programs\\Microsoft VS Code\\Code.exe"), + filepath.Join(programFiles, "Microsoft VS Code\\Code.exe"), + } +} + +func (t *VSCode) installText() string { + return `Install using a package manager: + + winget install Microsoft.VisualStudioCode + # or + choco install vscode + +Or download the latest version from: + + https://code.visualstudio.com/download + +After installation, install the MSSQL extension: + + sqlcmd open vscode --install-extension + +Or install it directly in VS Code via Extensions (Ctrl+Shift+X) and search for "SQL Server (mssql)"` +} diff --git a/internal/tools/tools.go b/internal/tools/tools.go index d60d7fee..cb4431e7 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -9,4 +9,6 @@ import ( var tools = []tool.Tool{ &tool.AzureDataStudio{}, + &tool.VSCode{}, + &tool.SSMS{}, }