Skip to content

Commit 3077847

Browse files
author
David Shiflet (from Dev Box)
committed
fix argument handling
1 parent 15ea586 commit 3077847

4 files changed

Lines changed: 32 additions & 8 deletions

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,8 @@ linux-s390x/sqlcmd
3636
# Build artifacts in root
3737
/sqlcmd
3838
/sqlcmd_binary
39+
40+
# certificates used for local testing
41+
*.der
42+
*.pem
43+
*.pfx

cmd/sqlcmd/sqlcmd.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ const (
128128
removeControlCharacters = "remove-control-characters"
129129
)
130130

131+
func encryptConnectionAllowsTLS(value string) bool {
132+
switch strings.ToLower(value) {
133+
case "s", "strict", "m", "mandatory", "true", "t", "yes", "1":
134+
return true
135+
default:
136+
return false
137+
}
138+
}
139+
131140
// Validate arguments for settings not describe
132141
func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
133142
if a.ListServers != "" {
@@ -145,6 +154,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
145154
err = mutuallyExclusiveError("-E", `-U/-P`)
146155
case a.UseAad && len(a.AuthenticationMethod) > 0:
147156
err = mutuallyExclusiveError("-G", "--authentication-method")
157+
case len(a.HostNameInCertificate) > 0 && len(a.ServerCertificate) > 0:
158+
err = mutuallyExclusiveError("-F", "-J")
148159
case a.PacketSize != 0 && (a.PacketSize < 512 || a.PacketSize > 32767):
149160
err = localizer.Errorf(`'-a %#v': Packet size has to be a number between 512 and 32767.`, a.PacketSize)
150161
// Ignore 0 even though it's technically an invalid input
@@ -158,8 +169,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
158169
err = rangeParameterError("-y", fmt.Sprint(*a.VariableTypeWidth), 0, 8000, true)
159170
case a.QueryTimeout < 0 || a.QueryTimeout > 65534:
160171
err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true)
161-
case a.ServerCertificate != "" && a.EncryptConnection != "s" && a.EncryptConnection != "strict":
162-
err = localizer.Errorf("The -J parameter can only be used with strict encryption mode (-N s or -N strict).")
172+
case a.ServerCertificate != "" && !encryptConnectionAllowsTLS(a.EncryptConnection):
173+
err = localizer.Errorf("The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict).")
163174
}
164175
}
165176
if err != nil {
@@ -432,7 +443,8 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
432443
rootCmd.Flags().StringVarP(&args.ApplicationIntent, applicationIntent, "K", "default", localizer.Sprintf("Declares the application workload type when connecting to a server. The only currently supported value is ReadOnly. If %s is not specified, the sqlcmd utility will not support connectivity to a secondary replica in an Always On availability group", localizer.ApplicationIntentFlagShort))
433444
rootCmd.Flags().StringVarP(&args.EncryptConnection, encryptConnection, "N", "default", localizer.Sprintf("This switch is used by the client to request an encrypted connection"))
434445
rootCmd.Flags().StringVarP(&args.HostNameInCertificate, "host-name-in-certificate", "F", "", localizer.Sprintf("Specifies the host name in the server certificate."))
435-
rootCmd.Flags().StringVarP(&args.ServerCertificate, "server-certificate", "J", "", localizer.Sprintf("Specifies the path to a server certificate file (PEM, DER, or CER) to match against the server's TLS certificate. Used with strict encryption mode (-N s or -N strict) for certificate pinning instead of standard certificate validation."))
446+
rootCmd.Flags().StringVarP(&args.ServerCertificate, "server-certificate", "J", "", localizer.Sprintf("Specifies the path to a server certificate file (PEM, DER, or CER) to match against the server's TLS certificate. Use when encryption is enabled (-N true, -N mandatory, or -N strict) for certificate pinning instead of standard certificate validation."))
447+
rootCmd.MarkFlagsMutuallyExclusive("host-name-in-certificate", "server-certificate")
436448
// Can't use NoOptDefVal until this fix: https://github.com/spf13/cobra/issues/866
437449
//rootCmd.Flags().Lookup(encryptConnection).NoOptDefVal = "true"
438450
rootCmd.Flags().BoolVarP(&args.Vertical, "vertical", "", false, localizer.Sprintf("Prints the output in vertical format. This option sets the sqlcmd scripting variable %s to '%s'. The default is false", sqlcmd.SQLCMDFORMAT, "vert"))

cmd/sqlcmd/sqlcmd_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ func TestValidCommandLineToArgsConversion(t *testing.T) {
117117
{[]string{"-N", "strict", "-J", "/path/to/cert.der"}, func(args SQLCmdArguments) bool {
118118
return args.EncryptConnection == "strict" && args.ServerCertificate == "/path/to/cert.der"
119119
}},
120+
{[]string{"-N", "m", "-J", "/path/to/cert.cer"}, func(args SQLCmdArguments) bool {
121+
return args.EncryptConnection == "m" && args.ServerCertificate == "/path/to/cert.cer"
122+
}},
123+
{[]string{"-N", "true", "-J", "/path/to/cert2.pem"}, func(args SQLCmdArguments) bool {
124+
return args.EncryptConnection == "true" && args.ServerCertificate == "/path/to/cert2.pem"
125+
}},
120126
}
121127

122128
for _, test := range commands {
@@ -160,17 +166,18 @@ func TestInvalidCommandLine(t *testing.T) {
160166
{[]string{"-E", "-U", "someuser"}, "The -E and the -U/-P options are mutually exclusive."},
161167
{[]string{"-L", "-q", `"select 1"`}, "The -L parameter can not be used in combination with other parameters."},
162168
{[]string{"-i", "foo.sql", "-q", `"select 1"`}, "The i and the -Q/-q options are mutually exclusive."},
163-
{[]string{"-r", "5"}, `'-r 5': Unexpected argument. Argument value has to be one of [0 1].`},
169+
{[]string{"-r", "5"}, "'-r 5': Unexpected argument. Argument value has to be one of [0 1]."},
164170
{[]string{"-w", "x"}, "'-w x': value must be greater than 8 and less than 65536."},
165171
{[]string{"-y", "111111"}, "'-y 111111': value must be greater than or equal to 0 and less than or equal to 8000."},
166172
{[]string{"-Y", "-2"}, "'-Y -2': value must be greater than or equal to 0 and less than or equal to 8000."},
167173
{[]string{"-P"}, "'-P': Missing argument. Enter '-?' for help."},
168174
{[]string{"-;"}, "';': Unknown Option. Enter '-?' for help."},
169175
{[]string{"-t", "-2"}, "'-t -2': value must be greater than or equal to 0 and less than or equal to 65534."},
170176
{[]string{"-N", "invalid"}, "'-N invalid': Unexpected argument. Argument value has to be one of [m[andatory] yes 1 t[rue] disable o[ptional] no 0 f[alse] s[trict]]."},
171-
{[]string{"-J", "/path/to/cert.pem"}, "The -J parameter can only be used with strict encryption mode (-N s or -N strict)."},
172-
{[]string{"-N", "m", "-J", "/path/to/cert.pem"}, "The -J parameter can only be used with strict encryption mode (-N s or -N strict)."},
173-
{[]string{"-N", "optional", "-J", "/path/to/cert.pem"}, "The -J parameter can only be used with strict encryption mode (-N s or -N strict)."},
177+
{[]string{"-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."},
178+
{[]string{"-N", "optional", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."},
179+
{[]string{"-N", "disable", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."},
180+
{[]string{"-N", "strict", "-F", "myserver.domain.com", "-J", "/path/to/cert.pem"}, "The -F and the -J options are mutually exclusive."},
174181
}
175182

176183
for _, test := range commands {

pkg/sqlcmd/connect.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
153153
query.Add(msdsn.HostNameInCertificate, connect.HostNameInCertificate)
154154
}
155155
if connect.ServerCertificate != "" {
156-
query.Add(msdsn.Certificate, connect.ServerCertificate)
156+
query.Add(msdsn.ServerCertificate, connect.ServerCertificate)
157157
}
158158
if connect.LogLevel > 0 {
159159
query.Add(msdsn.LogParam, fmt.Sprint(connect.LogLevel))

0 commit comments

Comments
 (0)