Skip to content

Commit 2389362

Browse files
Address Copilot review: handle SQL quote escaping and add tests
- Fix isExitParenBalanced to handle SQL Server quote escaping rules: - Escaped single quotes ('') inside string literals - Escaped bracket identifiers (]]) - Add test cases for escaped quotes in TestIsExitParenBalanced - Add comprehensive TestReadExitContinuation tests: - Continuation lines until balanced - Error handling on readline failure - Multiple continuation lines - Early return for already balanced input
1 parent 724059a commit 2389362

2 files changed

Lines changed: 108 additions & 4 deletions

File tree

pkg/sqlcmd/commands.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,23 @@ func (c Commands) SetBatchTerminator(terminator string) error {
210210

211211
// isExitParenBalanced checks if the parentheses in an EXIT command argument are balanced.
212212
// It tracks quotes to avoid counting parens inside string literals.
213+
// It handles SQL Server's quote escaping: ” inside strings and ]] inside bracket identifiers.
213214
func isExitParenBalanced(s string) bool {
214215
depth := 0
215216
var quote rune
216-
for _, c := range s {
217+
runes := []rune(s)
218+
for i := 0; i < len(runes); i++ {
219+
c := runes[i]
217220
switch {
218221
case quote != 0:
219222
// Inside a quoted string
220223
if c == quote {
221-
quote = 0
224+
// Check for escaped quote ('' or ]])
225+
if i+1 < len(runes) && runes[i+1] == quote {
226+
i++ // skip the escaped quote
227+
} else {
228+
quote = 0
229+
}
222230
}
223231
case c == '\'' || c == '"':
224232
quote = c

pkg/sqlcmd/commands_test.go

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ package sqlcmd
55

66
import (
77
"bytes"
8+
"errors"
89
"fmt"
10+
"io"
911
"os"
1012
"strings"
1113
"testing"
@@ -476,8 +478,12 @@ func TestIsExitParenBalanced(t *testing.T) {
476478
{"", true}, // empty string is balanced
477479
{"no parens", true}, // no parens is balanced
478480
{"(", false},
479-
{")", false}, // depth goes -1, not balanced
480-
{"(test))", false}, // depth goes -1 at end
481+
{")", false}, // depth goes -1, not balanced
482+
{"(test))", false}, // depth goes -1 at end
483+
{"(select 'can''t')", true}, // escaped single quote
484+
{"(select [col]]name])", true}, // escaped bracket identifier
485+
{"(select 'it''s a )test')", true}, // escaped quote with paren
486+
{"(select [a]]])", true}, // escaped bracket with paren
481487
}
482488
for _, test := range tests {
483489
t.Run(test.input, func(t *testing.T) {
@@ -486,3 +492,93 @@ func TestIsExitParenBalanced(t *testing.T) {
486492
})
487493
}
488494
}
495+
496+
func TestReadExitContinuation(t *testing.T) {
497+
t.Run("reads continuation lines until balanced", func(t *testing.T) {
498+
s := &Sqlcmd{}
499+
lines := []string{"+ 2)", ""}
500+
lineIndex := 0
501+
promptSet := ""
502+
s.lineIo = &testConsole{
503+
OnReadLine: func() (string, error) {
504+
if lineIndex >= len(lines) {
505+
return "", io.EOF
506+
}
507+
line := lines[lineIndex]
508+
lineIndex++
509+
return line, nil
510+
},
511+
OnPasswordPrompt: func(prompt string) ([]byte, error) {
512+
return nil, nil
513+
},
514+
}
515+
s.lineIo.SetPrompt("")
516+
517+
result, err := readExitContinuation(s, "(select 1")
518+
assert.NoError(t, err)
519+
assert.Equal(t, "(select 1\r\n+ 2)", result)
520+
521+
// Verify prompt was set
522+
tc := s.lineIo.(*testConsole)
523+
promptSet = tc.PromptText
524+
assert.Equal(t, " -> ", promptSet)
525+
})
526+
527+
t.Run("returns error on readline failure", func(t *testing.T) {
528+
s := &Sqlcmd{}
529+
expectedErr := errors.New("readline error")
530+
s.lineIo = &testConsole{
531+
OnReadLine: func() (string, error) {
532+
return "", expectedErr
533+
},
534+
OnPasswordPrompt: func(prompt string) ([]byte, error) {
535+
return nil, nil
536+
},
537+
}
538+
539+
_, err := readExitContinuation(s, "(select 1")
540+
assert.Equal(t, expectedErr, err)
541+
})
542+
543+
t.Run("handles multiple continuation lines", func(t *testing.T) {
544+
s := &Sqlcmd{}
545+
lines := []string{"+ 2", "+ 3", ")"}
546+
lineIndex := 0
547+
s.lineIo = &testConsole{
548+
OnReadLine: func() (string, error) {
549+
if lineIndex >= len(lines) {
550+
return "", io.EOF
551+
}
552+
line := lines[lineIndex]
553+
lineIndex++
554+
return line, nil
555+
},
556+
OnPasswordPrompt: func(prompt string) ([]byte, error) {
557+
return nil, nil
558+
},
559+
}
560+
561+
result, err := readExitContinuation(s, "(select 1")
562+
assert.NoError(t, err)
563+
assert.Equal(t, "(select 1\r\n+ 2\r\n+ 3\r\n)", result)
564+
})
565+
566+
t.Run("returns immediately if already balanced", func(t *testing.T) {
567+
s := &Sqlcmd{}
568+
readLineCalled := false
569+
s.lineIo = &testConsole{
570+
OnReadLine: func() (string, error) {
571+
readLineCalled = true
572+
return "", nil
573+
},
574+
OnPasswordPrompt: func(prompt string) ([]byte, error) {
575+
return nil, nil
576+
},
577+
}
578+
579+
result, err := readExitContinuation(s, "(select 1)")
580+
assert.NoError(t, err)
581+
assert.Equal(t, "(select 1)", result)
582+
assert.False(t, readLineCalled, "Readline should not be called for balanced input")
583+
})
584+
}

0 commit comments

Comments
 (0)