Skip to content

Commit 3a206e4

Browse files
Add validation for unsupported config file extensions
Co-authored-by: dlevy-msft-sql <[email protected]>
1 parent 2fd7e3f commit 3a206e4

3 files changed

Lines changed: 129 additions & 2 deletions

File tree

internal/config/config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ func SetFileName(name string) {
2727
filename = name
2828

2929
file.CreateEmptyIfNotExists(filename)
30-
configureViper(filename)
30+
err := configureViper(filename)
31+
checkErr(err)
3132
}
3233

3334
func SetFileNameForTest(t *testing.T) {

internal/config/viper.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ package config
55

66
import (
77
"bytes"
8+
"github.com/microsoft/go-sqlcmd/internal/localizer"
89
"github.com/microsoft/go-sqlcmd/internal/pal"
910
"github.com/spf13/viper"
1011
"gopkg.in/yaml.v2"
12+
"path/filepath"
13+
"strings"
1114
)
1215

1316
// Load loads the configuration from the file specified by the SetFileName() function.
@@ -56,16 +59,45 @@ func GetConfigFileUsed() string {
5659
return viper.ConfigFileUsed()
5760
}
5861

62+
// validateConfigFileExtension checks if the config file has a supported extension.
63+
// It allows .yaml, .yml, and no extension (for default sqlconfig file).
64+
// Returns an error if the extension is not supported.
65+
func validateConfigFileExtension(configFile string) error {
66+
ext := strings.ToLower(filepath.Ext(configFile))
67+
68+
// Allow no extension (for default sqlconfig file)
69+
if ext == "" {
70+
return nil
71+
}
72+
73+
// Allow .yaml and .yml extensions
74+
if ext == ".yaml" || ext == ".yml" {
75+
return nil
76+
}
77+
78+
// Return error for unsupported extensions
79+
return localizer.Errorf(
80+
"Configuration files must use YAML format with .yaml or .yml extension.\n"+
81+
"The file '%s' has an unsupported extension '%s'.",
82+
configFile, ext)
83+
}
84+
5985
// configureViper initializes the Viper library with the given configuration file.
6086
// This function sets the configuration file type to "yaml" and sets the environment variable prefix to "SQLCMD".
6187
// It also sets the configuration file to use to the one provided as an argument to the function.
6288
// This function is intended to be called at the start of the application to configure Viper before any other code uses it.
63-
func configureViper(configFile string) {
89+
func configureViper(configFile string) error {
6490
if configFile == "" {
6591
panic("Must provide configFile")
6692
}
6793

94+
// Validate file extension
95+
if err := validateConfigFileExtension(configFile); err != nil {
96+
return err
97+
}
98+
6899
viper.SetConfigType("yaml")
69100
viper.SetEnvPrefix("SQLCMD")
70101
viper.SetConfigFile(configFile)
102+
return nil
71103
}

internal/config/viper_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,100 @@ func Test_configureViper(t *testing.T) {
1414
})
1515
}
1616

17+
func Test_validateConfigFileExtension(t *testing.T) {
18+
tests := []struct {
19+
name string
20+
filename string
21+
wantErr bool
22+
}{
23+
{
24+
name: "valid yaml extension",
25+
filename: "config.yaml",
26+
wantErr: false,
27+
},
28+
{
29+
name: "valid yml extension",
30+
filename: "config.yml",
31+
wantErr: false,
32+
},
33+
{
34+
name: "no extension (default sqlconfig)",
35+
filename: "sqlconfig",
36+
wantErr: false,
37+
},
38+
{
39+
name: "no extension with path",
40+
filename: "/home/user/.sqlcmd/sqlconfig",
41+
wantErr: false,
42+
},
43+
{
44+
name: "invalid txt extension",
45+
filename: "config.txt",
46+
wantErr: true,
47+
},
48+
{
49+
name: "invalid json extension",
50+
filename: "config.json",
51+
wantErr: true,
52+
},
53+
{
54+
name: "invalid xml extension",
55+
filename: "config.xml",
56+
wantErr: true,
57+
},
58+
{
59+
name: "uppercase YAML extension",
60+
filename: "config.YAML",
61+
wantErr: false,
62+
},
63+
{
64+
name: "uppercase YML extension",
65+
filename: "config.YML",
66+
wantErr: false,
67+
},
68+
{
69+
name: "mixed case yaml extension",
70+
filename: "config.Yaml",
71+
wantErr: false,
72+
},
73+
}
74+
75+
for _, tt := range tests {
76+
t.Run(tt.name, func(t *testing.T) {
77+
err := validateConfigFileExtension(tt.filename)
78+
if tt.wantErr {
79+
assert.Error(t, err, "Expected error for filename: %s", tt.filename)
80+
assert.Contains(t, err.Error(), "Configuration files must use YAML format")
81+
} else {
82+
assert.NoError(t, err, "Expected no error for filename: %s", tt.filename)
83+
}
84+
})
85+
}
86+
}
87+
88+
func Test_configureViper_withInvalidExtension(t *testing.T) {
89+
err := configureViper("myconfig.txt")
90+
assert.Error(t, err)
91+
assert.Contains(t, err.Error(), "Configuration files must use YAML format")
92+
assert.Contains(t, err.Error(), ".txt")
93+
}
94+
95+
func Test_configureViper_withValidExtensions(t *testing.T) {
96+
testCases := []string{
97+
"config.yaml",
98+
"config.yml",
99+
"sqlconfig",
100+
"/path/to/config.yaml",
101+
}
102+
103+
for _, filename := range testCases {
104+
t.Run(filename, func(t *testing.T) {
105+
err := configureViper(filename)
106+
assert.NoError(t, err, "Expected no error for filename: %s", filename)
107+
})
108+
}
109+
}
110+
17111
func Test_Load(t *testing.T) {
18112
SetFileNameForTest(t)
19113
Clean()

0 commit comments

Comments
 (0)