Skip to content

Commit 424cc32

Browse files
Address Copilot review comments for codepage support
- Use localizer.Errorf for all user-facing error messages - Fix UTF-16 BOM handling using ExpectBOM for input decoding - Add transformWriteCloser to properly close underlying file handles - Use transformWriteCloser in outCommand and errorCommand for both UnicodeOutputFile and CodePage transforms to prevent file handle leaks - Add integration tests for output/error codepage encoding
1 parent faa945c commit 424cc32

3 files changed

Lines changed: 138 additions & 13 deletions

File tree

pkg/sqlcmd/codepage.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
package sqlcmd
55

66
import (
7-
"fmt"
87
"strconv"
98
"strings"
109

10+
"github.com/microsoft/go-sqlcmd/internal/localizer"
1111
"golang.org/x/text/encoding"
1212
"golang.org/x/text/encoding/charmap"
1313
"golang.org/x/text/encoding/japanese"
@@ -43,21 +43,21 @@ func ParseCodePage(arg string) (*CodePageSettings, error) {
4343
// Input codepage
4444
cp, err := strconv.Atoi(strings.TrimPrefix(strings.ToLower(part), "i:"))
4545
if err != nil {
46-
return nil, fmt.Errorf("invalid input codepage: %s", part)
46+
return nil, localizer.Errorf("invalid input codepage: %s", part)
4747
}
4848
settings.InputCodePage = cp
4949
} else if strings.HasPrefix(strings.ToLower(part), "o:") {
5050
// Output codepage
5151
cp, err := strconv.Atoi(strings.TrimPrefix(strings.ToLower(part), "o:"))
5252
if err != nil {
53-
return nil, fmt.Errorf("invalid output codepage: %s", part)
53+
return nil, localizer.Errorf("invalid output codepage: %s", part)
5454
}
5555
settings.OutputCodePage = cp
5656
} else {
5757
// Both input and output
5858
cp, err := strconv.Atoi(part)
5959
if err != nil {
60-
return nil, fmt.Errorf("invalid codepage: %s", part)
60+
return nil, localizer.Errorf("invalid codepage: %s", part)
6161
}
6262
settings.InputCodePage = cp
6363
settings.OutputCodePage = cp
@@ -88,11 +88,11 @@ func GetEncoding(codepage int) (encoding.Encoding, error) {
8888
// UTF-8 - Go's native encoding, return nil to indicate no transformation needed
8989
return nil, nil
9090
case 1200:
91-
// UTF-16LE
92-
return unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), nil
91+
// UTF-16LE - Use ExpectBOM to strip BOM if present during input
92+
return unicode.UTF16(unicode.LittleEndian, unicode.ExpectBOM), nil
9393
case 1201:
94-
// UTF-16BE
95-
return unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM), nil
94+
// UTF-16BE - Use ExpectBOM to strip BOM if present during input
95+
return unicode.UTF16(unicode.BigEndian, unicode.ExpectBOM), nil
9696

9797
// OEM/DOS codepages
9898
case 437:
@@ -224,7 +224,7 @@ func GetEncoding(codepage int) (encoding.Encoding, error) {
224224
return traditionalchinese.Big5, nil
225225

226226
default:
227-
return nil, fmt.Errorf("unsupported codepage %d", codepage)
227+
return nil, localizer.Errorf("unsupported codepage %d", codepage)
228228
}
229229
}
230230

pkg/sqlcmd/commands.go

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,36 @@ package sqlcmd
66
import (
77
"flag"
88
"fmt"
9+
"io"
910
"os"
1011
"regexp"
1112
"sort"
1213
"strconv"
1314
"strings"
1415

1516
"github.com/microsoft/go-sqlcmd/internal/color"
17+
"github.com/microsoft/go-sqlcmd/internal/localizer"
1618
"golang.org/x/text/encoding/unicode"
1719
"golang.org/x/text/transform"
1820
)
1921

22+
// transformWriteCloser wraps a transform.Writer and ensures the underlying
23+
// file is closed when Close() is called.
24+
type transformWriteCloser struct {
25+
*transform.Writer
26+
underlying io.Closer
27+
}
28+
29+
// Close flushes the transform writer and closes the underlying file.
30+
func (t *transformWriteCloser) Close() error {
31+
// Close the transform writer (flushes pending data)
32+
if err := t.Writer.Close(); err != nil {
33+
_ = t.underlying.Close()
34+
return err
35+
}
36+
return t.underlying.Close()
37+
}
38+
2039
// Command defines a sqlcmd action which can be intermixed with the SQL batch
2140
// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands
2241
type Command struct {
@@ -324,7 +343,10 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
324343
// ODBC sqlcmd doesn't write a BOM but we will.
325344
// Maybe the endian-ness should be configurable.
326345
win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
327-
encoder := transform.NewWriter(o, win16le.NewEncoder())
346+
encoder := &transformWriteCloser{
347+
Writer: transform.NewWriter(o, win16le.NewEncoder()),
348+
underlying: o,
349+
}
328350
s.SetOutput(encoder)
329351
} else if s.CodePage != nil && s.CodePage.OutputCodePage != 0 {
330352
// Use specified output codepage
@@ -335,7 +357,10 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
335357
}
336358
if enc != nil {
337359
// Transform from UTF-8 to specified encoding
338-
encoder := transform.NewWriter(o, enc.NewEncoder())
360+
encoder := &transformWriteCloser{
361+
Writer: transform.NewWriter(o, enc.NewEncoder()),
362+
underlying: o,
363+
}
339364
s.SetOutput(encoder)
340365
} else {
341366
// UTF-8, no transformation needed
@@ -372,15 +397,18 @@ func errorCommand(s *Sqlcmd, args []string, line uint) error {
372397
enc, err := GetEncoding(s.CodePage.OutputCodePage)
373398
if err != nil {
374399
if cerr := o.Close(); cerr != nil {
375-
return fmt.Errorf("%v; additionally, closing error file %q failed: %w", err, args[0], cerr)
400+
return localizer.Errorf("%v; additionally, closing error file %q failed: %v", err, args[0], cerr)
376401
}
377402
return err
378403
}
379404
if enc == nil {
380405
// UTF-8 (or default) encoding: write directly without transform
381406
s.SetError(o)
382407
} else {
383-
encoder := transform.NewWriter(o, enc.NewEncoder())
408+
encoder := &transformWriteCloser{
409+
Writer: transform.NewWriter(o, enc.NewEncoder()),
410+
underlying: o,
411+
}
384412
s.SetError(encoder)
385413
}
386414
} else {

pkg/sqlcmd/commands_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,3 +458,100 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) {
458458
}
459459

460460
}
461+
462+
func TestOutputCodePageCommand(t *testing.T) {
463+
tests := []struct {
464+
name string
465+
codepage int
466+
expectedBytes []byte
467+
inputText string
468+
skipOnEncError bool
469+
}{
470+
{
471+
name: "UTF-8 output",
472+
codepage: 65001,
473+
inputText: "café",
474+
expectedBytes: []byte("café"),
475+
},
476+
{
477+
name: "Windows-1252 output",
478+
codepage: 1252,
479+
inputText: "café",
480+
expectedBytes: []byte{0x63, 0x61, 0x66, 0xe9}, // "café" in Windows-1252
481+
},
482+
}
483+
484+
for _, tt := range tests {
485+
t.Run(tt.name, func(t *testing.T) {
486+
s, buf := setupSqlCmdWithMemoryOutput(t)
487+
defer buf.Close()
488+
489+
// Set up codepage
490+
s.CodePage = &CodePageSettings{
491+
OutputCodePage: tt.codepage,
492+
}
493+
494+
// Create temp file for output
495+
file, err := os.CreateTemp("", "sqlcmdout")
496+
require.NoError(t, err, "os.CreateTemp")
497+
defer os.Remove(file.Name())
498+
fileName := file.Name()
499+
_ = file.Close()
500+
501+
// Run the OUT command
502+
err = outCommand(s, []string{fileName}, 1)
503+
require.NoError(t, err, "outCommand")
504+
505+
// Write some text
506+
_, err = s.GetOutput().Write([]byte(tt.inputText))
507+
require.NoError(t, err, "Write")
508+
509+
// Close to flush
510+
if closer, ok := s.GetOutput().(interface{ Close() error }); ok {
511+
require.NoError(t, closer.Close(), "Close output")
512+
}
513+
514+
// Read the file and check encoding
515+
content, err := os.ReadFile(fileName)
516+
require.NoError(t, err, "ReadFile")
517+
assert.Equal(t, tt.expectedBytes, content, "Output encoding mismatch")
518+
})
519+
}
520+
}
521+
522+
func TestErrorCodePageCommand(t *testing.T) {
523+
s, buf := setupSqlCmdWithMemoryOutput(t)
524+
defer buf.Close()
525+
526+
// Set up codepage for Windows-1252
527+
s.CodePage = &CodePageSettings{
528+
OutputCodePage: 1252,
529+
}
530+
531+
// Create temp file for error output
532+
file, err := os.CreateTemp("", "sqlcmderr")
533+
require.NoError(t, err, "os.CreateTemp")
534+
defer os.Remove(file.Name())
535+
fileName := file.Name()
536+
_ = file.Close()
537+
538+
// Run the ERROR command
539+
err = errorCommand(s, []string{fileName}, 1)
540+
require.NoError(t, err, "errorCommand")
541+
542+
// Write some text with special characters
543+
_, err = s.err.Write([]byte("Error: café"))
544+
require.NoError(t, err, "Write")
545+
546+
// Close to flush
547+
if closer, ok := s.err.(interface{ Close() error }); ok {
548+
require.NoError(t, closer.Close(), "Close error")
549+
}
550+
551+
// Read the file and check encoding
552+
content, err := os.ReadFile(fileName)
553+
require.NoError(t, err, "ReadFile")
554+
// "Error: café" in Windows-1252
555+
expected := []byte{0x45, 0x72, 0x72, 0x6f, 0x72, 0x3a, 0x20, 0x63, 0x61, 0x66, 0xe9}
556+
assert.Equal(t, expected, content, "Error output encoding mismatch")
557+
}

0 commit comments

Comments
 (0)