@@ -5,7 +5,9 @@ package sqlcmd
55
66import (
77 "bytes"
8+ "errors"
89 "fmt"
10+ "io"
911 "os"
1012 "strings"
1113 "testing"
@@ -476,8 +478,12 @@ func TestIsExitParenBalanced(t *testing.T) {
476478 {"" , true }, // empty string is balanced
477479 {"no parens" , true }, // no parens is balanced
478480 {"(" , false },
479- {")" , false }, // depth goes -1, not balanced
480- {"(test))" , false }, // depth goes -1 at end
481+ {")" , false }, // depth goes -1, not balanced
482+ {"(test))" , false }, // depth goes -1 at end
483+ {"(select 'can''t')" , true }, // escaped single quote
484+ {"(select [col]]name])" , true }, // escaped bracket identifier
485+ {"(select 'it''s a )test')" , true }, // escaped quote with paren
486+ {"(select [a]]])" , true }, // escaped bracket with paren
481487 }
482488 for _ , test := range tests {
483489 t .Run (test .input , func (t * testing.T ) {
@@ -486,3 +492,93 @@ func TestIsExitParenBalanced(t *testing.T) {
486492 })
487493 }
488494}
495+
496+ func TestReadExitContinuation (t * testing.T ) {
497+ t .Run ("reads continuation lines until balanced" , func (t * testing.T ) {
498+ s := & Sqlcmd {}
499+ lines := []string {"+ 2)" , "" }
500+ lineIndex := 0
501+ promptSet := ""
502+ s .lineIo = & testConsole {
503+ OnReadLine : func () (string , error ) {
504+ if lineIndex >= len (lines ) {
505+ return "" , io .EOF
506+ }
507+ line := lines [lineIndex ]
508+ lineIndex ++
509+ return line , nil
510+ },
511+ OnPasswordPrompt : func (prompt string ) ([]byte , error ) {
512+ return nil , nil
513+ },
514+ }
515+ s .lineIo .SetPrompt ("" )
516+
517+ result , err := readExitContinuation (s , "(select 1" )
518+ assert .NoError (t , err )
519+ assert .Equal (t , "(select 1\r \n + 2)" , result )
520+
521+ // Verify prompt was set
522+ tc := s .lineIo .(* testConsole )
523+ promptSet = tc .PromptText
524+ assert .Equal (t , " -> " , promptSet )
525+ })
526+
527+ t .Run ("returns error on readline failure" , func (t * testing.T ) {
528+ s := & Sqlcmd {}
529+ expectedErr := errors .New ("readline error" )
530+ s .lineIo = & testConsole {
531+ OnReadLine : func () (string , error ) {
532+ return "" , expectedErr
533+ },
534+ OnPasswordPrompt : func (prompt string ) ([]byte , error ) {
535+ return nil , nil
536+ },
537+ }
538+
539+ _ , err := readExitContinuation (s , "(select 1" )
540+ assert .Equal (t , expectedErr , err )
541+ })
542+
543+ t .Run ("handles multiple continuation lines" , func (t * testing.T ) {
544+ s := & Sqlcmd {}
545+ lines := []string {"+ 2" , "+ 3" , ")" }
546+ lineIndex := 0
547+ s .lineIo = & testConsole {
548+ OnReadLine : func () (string , error ) {
549+ if lineIndex >= len (lines ) {
550+ return "" , io .EOF
551+ }
552+ line := lines [lineIndex ]
553+ lineIndex ++
554+ return line , nil
555+ },
556+ OnPasswordPrompt : func (prompt string ) ([]byte , error ) {
557+ return nil , nil
558+ },
559+ }
560+
561+ result , err := readExitContinuation (s , "(select 1" )
562+ assert .NoError (t , err )
563+ assert .Equal (t , "(select 1\r \n + 2\r \n + 3\r \n )" , result )
564+ })
565+
566+ t .Run ("returns immediately if already balanced" , func (t * testing.T ) {
567+ s := & Sqlcmd {}
568+ readLineCalled := false
569+ s .lineIo = & testConsole {
570+ OnReadLine : func () (string , error ) {
571+ readLineCalled = true
572+ return "" , nil
573+ },
574+ OnPasswordPrompt : func (prompt string ) ([]byte , error ) {
575+ return nil , nil
576+ },
577+ }
578+
579+ result , err := readExitContinuation (s , "(select 1)" )
580+ assert .NoError (t , err )
581+ assert .Equal (t , "(select 1)" , result )
582+ assert .False (t , readLineCalled , "Readline should not be called for balanced input" )
583+ })
584+ }
0 commit comments