Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 80 additions & 16 deletions apps/cli-go/pkg/migration/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ type MigrationFile struct {
var (
migrateFilePattern = regexp.MustCompile(`^([0-9]+)_(.*)\.sql$`)
typeNamePattern = regexp.MustCompile(`type "([^"]+)" does not exist`)
createIndexPattern = regexp.MustCompile(`^CREATE\s+(UNIQUE\s+)?INDEX\s+CONCURRENTLY(\s|\z)`)
reindexPattern = regexp.MustCompile(`^REINDEX(\s|\().*\sCONCURRENTLY(\s|\z)`)
vacuumPattern = regexp.MustCompile(`^VACUUM(\s|\(|\z)`)
alterSystemPattern = regexp.MustCompile(`^ALTER\s+SYSTEM(\s|\z)`)
clusterPattern = regexp.MustCompile(`^CLUSTER(\s|\z)`)
)

func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) {
Expand Down Expand Up @@ -72,23 +77,46 @@ func NewMigrationFromReader(sql io.Reader) (*MigrationFile, error) {
return &MigrationFile{Statements: lines}, nil
}

func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
// Batch migration commands, without using statement cache
batch := &pgconn.Batch{}
for _, line := range m.Statements {
batch.ExecParams(line, nil, nil, nil, nil)
}
// Insert into migration history
if len(m.Version) > 0 {
if err := m.insertVersionSQL(conn, batch); err != nil {
return err
func isPipelineIncompatible(sql string) bool {
upper := strings.ToUpper(trimLeadingSQLComments(sql))
return createIndexPattern.MatchString(upper) ||
reindexPattern.MatchString(upper) ||
vacuumPattern.MatchString(upper) ||
alterSystemPattern.MatchString(upper) ||
clusterPattern.MatchString(upper)
}

func trimLeadingSQLComments(sql string) string {
trimmed := strings.TrimLeftFunc(sql, func(r rune) bool {
return r == '\ufeff' || r == ' ' || r == '\t' || r == '\n' || r == '\r'
})
for {
switch {
case strings.HasPrefix(trimmed, "--"):
if idx := strings.IndexByte(trimmed, '\n'); idx >= 0 {
trimmed = strings.TrimLeft(trimmed[idx+1:], " \t\n\r")
continue
}
return ""
case strings.HasPrefix(trimmed, "/*"):
if idx := strings.Index(trimmed, "*/"); idx >= 0 {
trimmed = strings.TrimLeft(trimmed[idx+2:], " \t\n\r")
continue
}
return trimmed
default:
return strings.TrimSpace(trimmed)
}
}
// ExecBatch is implicitly transactional
if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
// Defaults to printing the last statement on error
}

func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
batch := &pgconn.Batch{}
batchSize := 0
executed := 0

formatError := func(err error, i int) error {
stat := INSERT_MIGRATION_VERSION
i := len(result)
if i < len(m.Statements) {
stat = m.Statements[i]
}
Expand All @@ -99,7 +127,6 @@ func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
if len(pgErr.Detail) > 0 {
msg = append(msg, pgErr.Detail)
}
// Provide helpful hint for extension type errors (SQLSTATE 42704: undefined_object)
if typeName := extractTypeName(pgErr.Message); len(typeName) > 0 && pgErr.Code == "42704" && !IsSchemaQualified(typeName) {
msg = append(msg, "")
msg = append(msg, "Hint: This type may be defined in a schema that's not in your search_path.")
Expand All @@ -111,7 +138,44 @@ func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
msg = append(msg, fmt.Sprintf("At statement: %d", i), stat)
return errors.Errorf("%w\n%s", err, strings.Join(msg, "\n"))
}
return nil

flushBatch := func() error {
if batchSize == 0 {
return nil
}
if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
return formatError(err, executed+len(result))
}
executed += batchSize
batch = &pgconn.Batch{}
batchSize = 0
return nil
}

for _, line := range m.Statements {
if isPipelineIncompatible(line) {
if err := flushBatch(); err != nil {
return err
}
if _, err := conn.PgConn().Exec(ctx, line).ReadAll(); err != nil {
return formatError(err, executed)
}
executed++
} else {
batch.ExecParams(line, nil, nil, nil, nil)
batchSize++
}
}

// Insert into migration history
if len(m.Version) > 0 {
if err := m.insertVersionSQL(conn, batch); err != nil {
return err
}
batchSize++
}

return flushBatch()
}

func markError(stat string, pos int) string {
Expand Down
147 changes: 147 additions & 0 deletions apps/cli-go/pkg/migration/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,73 @@ func TestMigrationFile(t *testing.T) {
assert.NoError(t, err)
})

t.Run("executes pipeline incompatible statements outside batch", func(t *testing.T) {
migration := MigrationFile{
Statements: []string{
"create table public.widgets(id bigint primary key)",
"CREATE UNIQUE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)",
"alter table public.widgets enable row level security",
},
Version: "20260101000000",
Name: "create_widgets",
}
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(migration.Statements[0]).
Reply("CREATE TABLE").
SimpleQuery(migration.Statements[1]).
Reply("CREATE INDEX").
Query(migration.Statements[2]).
Reply("ALTER TABLE").
Query(INSERT_MIGRATION_VERSION, migration.Version, migration.Name, migration.Statements).
Reply("INSERT 0 1")
// Run test
err := migration.ExecBatch(context.Background(), conn.MockClient(t))
// Check error
assert.NoError(t, err)
})

t.Run("records migration version when file has no statements", func(t *testing.T) {
migration := MigrationFile{
Version: "20260101000000",
Name: "empty_migration",
}
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(INSERT_MIGRATION_VERSION, migration.Version, migration.Name, migration.Statements).
Reply("INSERT 0 1")
// Run test
err := migration.ExecBatch(context.Background(), conn.MockClient(t))
// Check error
assert.NoError(t, err)
})

t.Run("reports pipeline incompatible statement errors with statement index", func(t *testing.T) {
migration := MigrationFile{
Statements: []string{
"create table public.widgets(id bigint primary key)",
"CREATE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)",
"alter table public.widgets enable row level security",
},
Version: "20260101000000",
Name: "create_widgets",
}
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(migration.Statements[0]).
Reply("CREATE TABLE").
SimpleQuery(migration.Statements[1]).
ReplyError("25001", "CREATE INDEX CONCURRENTLY cannot be executed within a pipeline")
// Run test
err := migration.ExecBatch(context.Background(), conn.MockClient(t))
// Check error
assert.ErrorContains(t, err, "ERROR: CREATE INDEX CONCURRENTLY cannot be executed within a pipeline (SQLSTATE 25001)")
assert.ErrorContains(t, err, "At statement: 1\nCREATE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)")
})

t.Run("throws error on insert failure", func(t *testing.T) {
migration := MigrationFile{
Statements: []string{"create schema public"},
Expand Down Expand Up @@ -152,6 +219,86 @@ func TestExtractTypeName(t *testing.T) {
})
}

func TestIsPipelineIncompatible(t *testing.T) {
cases := []struct {
name string
sql string
want bool
}{
{
name: "create index concurrently",
sql: "CREATE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)",
want: true,
},
{
name: "create unique index concurrently",
sql: "CREATE UNIQUE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)",
want: true,
},
{
name: "create index concurrently after comments",
sql: "-- cannot run in a transaction\n/* generated */\nCREATE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)",
want: true,
},
{
name: "reindex table concurrently",
sql: "REINDEX TABLE CONCURRENTLY public.widgets",
want: true,
},
{
name: "reindex with options concurrently",
sql: "REINDEX (VERBOSE) INDEX CONCURRENTLY widgets_id_idx",
want: true,
},
{
name: "vacuum bare",
sql: "VACUUM",
want: true,
},
{
name: "vacuum with options",
sql: "VACUUM (ANALYZE) public.widgets",
want: true,
},
{
name: "alter system",
sql: "ALTER SYSTEM SET log_statement = 'all'",
want: true,
},
{
name: "cluster",
sql: "CLUSTER public.widgets USING widgets_id_idx",
want: true,
},
{
name: "ordinary create index",
sql: "CREATE INDEX widgets_id_idx ON public.widgets(id)",
want: false,
},
{
name: "concurrently in string literal",
sql: "SELECT 'CREATE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)'",
want: false,
},
{
name: "concurrently in leading comment only",
sql: "-- CREATE INDEX CONCURRENTLY widgets_id_idx ON public.widgets(id)\nSELECT 1",
want: false,
},
{
name: "word prefix",
sql: "VACUUMING public.widgets",
want: false,
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, isPipelineIncompatible(tt.sql))
})
}
}

func TestIsSchemaQualified(t *testing.T) {
assert.True(t, IsSchemaQualified("extensions.ltree"))
assert.True(t, IsSchemaQualified("public.my_type"))
Expand Down
6 changes: 6 additions & 0 deletions apps/cli-go/pkg/pgtest/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ func (r *MockConn) Query(sql string, args ...any) *MockConn {
return r
}

// SimpleQuery adds a simple-protocol query to the mock connection.
func (r *MockConn) SimpleQuery(sql string) *MockConn {
r.script.Steps = append(r.script.Steps, ExpectSimpleQuery(sql))
return r
}

func (r *MockConn) encodeValueArg(v any) (value []byte, oid uint32) {
if v == nil {
return nil, pgtype.TextArrayOID
Expand Down
31 changes: 20 additions & 11 deletions apps/cli-go/pkg/pgtest/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import (
var ci = pgtype.NewConnInfo()

type extendedQueryStep struct {
sql string
params [][]byte
oids []uint32
reply pgmock.Script
sql string
params [][]byte
oids []uint32
simpleOnly bool
reply pgmock.Script
}

func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
Expand All @@ -24,6 +25,16 @@ func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
return err
}

// Handle simple query
want := &pgproto3.Query{String: e.sql}
if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) {
e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
return e.reply.Run(backend)
}
if e.simpleOnly {
return errors.Errorf("expected => %#v\nactual => %#v", want, msg)
}

// Handle prepared statements, name can be dynamic: lrupsc_5_0
if m, ok := msg.(*pgproto3.Parse); ok {
want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs}
Expand Down Expand Up @@ -75,13 +86,6 @@ func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
return e.reply.Run(backend)
}

// Handle simple query
want := &pgproto3.Query{String: e.sql}
if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) {
e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
return e.reply.Run(backend)
}

return errors.Errorf("expected => %#v\nactual => %#v", want, msg)
}

Expand All @@ -90,6 +94,11 @@ func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step {
return &extendedQueryStep{sql: sql, params: params, oids: oids}
}

// ExpectSimpleQuery expects SQL through the simple query protocol.
func ExpectSimpleQuery(sql string) pgmock.Step {
return &extendedQueryStep{sql: sql, simpleOnly: true}
}

type terminateStep struct{}

func (e *terminateStep) Step(backend *pgproto3.Backend) error {
Expand Down
Loading