diff --git a/internal/api/api.go b/internal/api/api.go index 4bda6893..170e73c4 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -24,6 +24,7 @@ type RestAPIServer struct { } func Init(opts config.RestAPIOpts, logger log.LoggerIface) *RestAPIServer { + mux := http.NewServeMux() s := &RestAPIServer{ nil, logger, @@ -32,12 +33,13 @@ func Init(opts config.RestAPIOpts, logger log.LoggerIface) *RestAPIServer { ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, MaxHeaderBytes: 1 << 20, + Handler: mux, }, } - http.HandleFunc("/liveness", s.livenessHandler) - http.HandleFunc("/readiness", s.readinessHandler) - http.HandleFunc("/startchain", s.chainHandler) - http.HandleFunc("/stopchain", s.chainHandler) + mux.HandleFunc("/liveness", s.livenessHandler) + mux.HandleFunc("/readiness", s.readinessHandler) + mux.HandleFunc("/startchain", s.chainHandler) + mux.HandleFunc("/stopchain", s.chainHandler) if opts.Port != 0 { logger.WithField("port", opts.Port).Info("Starting REST API server...") go func() { logger.Error(s.ListenAndServe()) }() diff --git a/main.go b/main.go index 375b72db..bb71122e 100644 --- a/main.go +++ b/main.go @@ -66,47 +66,21 @@ func printVersion() { `, version, dbapi, commit, date) } -func main() { - cmdOpts, err := config.NewConfig(os.Stdout) - if err != nil { - if cmdOpts != nil && cmdOpts.VersionOnly() { - printVersion() - return - } - fmt.Println("Configuration error: ", err) - exitCode = ExitCodeConfigError - return - } - if cmdOpts.Version { - printVersion() - } - - logger := log.Init(cmdOpts.Logging) - ctx, cancel := context.WithCancel(context.Background()) - SetupCloseHandler(cancel) - defer func() { - cancel() - if err := recover(); err != nil { - exitCode = ExitCodeFatalError - logger.WithField("callstack", string(debug.Stack())).Error(err) - } - os.Exit(exitCode) - }() - +// run contains the core application logic and returns an exit code. +func run(ctx context.Context, cmdOpts *config.CmdOptions, logger log.LoggerHookerIface) int { apiserver := api.Init(cmdOpts.RESTApi, logger) + var err error if pge, err = pgengine.New(ctx, *cmdOpts, logger); err != nil { logger.WithError(err).Error("Connection failed") - exitCode = ExitCodeDBEngineError - return + return ExitCodeDBEngineError } defer pge.Finalize() if cmdOpts.Start.Upgrade { if err := pge.MigrateDb(ctx); err != nil { logger.WithError(err).Error("Upgrade failed") - exitCode = ExitCodeUpgradeError - return + return ExitCodeUpgradeError } } else { if upgrade, err := pge.CheckNeedMigrateDb(ctx); upgrade || err != nil { @@ -116,17 +90,47 @@ func main() { if err != nil { logger.WithError(err).Error("Migration check failed") } - exitCode = ExitCodeUpgradeError - return + return ExitCodeUpgradeError } } if cmdOpts.Start.Init { - return + return ExitCodeOK } sch := scheduler.New(pge, logger) apiserver.APIHandler = sch if sch.Run(ctx) == scheduler.ShutdownStatus { - exitCode = ExitCodeShutdownCommand + return ExitCodeShutdownCommand + } + return ExitCodeOK +} + +func main() { + cmdOpts, err := config.NewConfig(os.Stdout) + if err != nil { + if cmdOpts != nil && cmdOpts.VersionOnly() { + printVersion() + return + } + fmt.Println("Configuration error: ", err) + exitCode = ExitCodeConfigError + return + } + if cmdOpts.Version { + printVersion() } + + logger := log.Init(cmdOpts.Logging) + ctx, cancel := context.WithCancel(context.Background()) + SetupCloseHandler(cancel) + defer func() { + cancel() + if err := recover(); err != nil { + exitCode = ExitCodeFatalError + logger.WithField("callstack", string(debug.Stack())).Error(err) + } + os.Exit(exitCode) + }() + + exitCode = run(ctx, cmdOpts, logger) } diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..5e179eb7 --- /dev/null +++ b/main_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "bytes" + "context" + "io" + "os" + "runtime" + "syscall" + "testing" + "time" + + "github.com/cybertec-postgresql/pg_timetable/internal/config" + "github.com/cybertec-postgresql/pg_timetable/internal/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" +) + +// newTestLogger returns a silent logger suitable for use in tests. +func newTestLogger() log.LoggerHookerIface { + return log.Init(config.LoggingOpts{LogLevel: "panic", LogDBLevel: "none"}) +} + +// setupTestContainer starts a bare PostgreSQL container and returns the +// connection string along with a cleanup function. Unlike the shared +// testutils helper, it does NOT initialise the pg_timetable schema so that +// run() can perform that step itself. +func setupTestContainer(t *testing.T) (connStr string, cleanup func()) { + t.Helper() + ctx := context.Background() + c, err := postgres.Run( + ctx, + "postgres:18-alpine", + postgres.WithDatabase("timetable"), + postgres.WithUsername("scheduler"), + postgres.WithPassword("somestrong"), + testcontainers.WithWaitStrategyAndDeadline( + 60*time.Second, + wait.ForLog("database system is ready to accept connections").WithOccurrence(2), + ), + ) + require.NoError(t, err, "Failed to start PostgreSQL container") + cs, err := c.ConnectionString(ctx, "sslmode=disable") + if err != nil { + _ = c.Terminate(ctx) + t.Fatalf("Failed to get connection string: %v", err) + } + return cs, func() { _ = c.Terminate(ctx) } +} + +// TestPrintVersion verifies that printVersion writes the expected fields to +// stdout. +func TestPrintVersion(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + oldStdout := os.Stdout + os.Stdout = w + + printVersion() + + w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + out := buf.String() + + assert.Contains(t, out, "pg_timetable:") + assert.Contains(t, out, "Version:") + assert.Contains(t, out, "DB Schema:") + assert.Contains(t, out, "Git Commit:") + assert.Contains(t, out, "Built:") +} + +// TestSetupCloseHandler verifies that sending SIGTERM causes the provided +// cancel function to be called. Skipped on Windows where signal delivery to +// the current process works differently. +func TestSetupCloseHandler(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("SIGTERM delivery to self is not supported on Windows") + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + SetupCloseHandler(func() { + cancel() + close(done) + }) + + p, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + require.NoError(t, p.Signal(syscall.SIGTERM)) + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("cancel was not called within 3 s of receiving SIGTERM") + } + assert.ErrorIs(t, ctx.Err(), context.Canceled) +} + +// TestRunDBConnectionFailure verifies that run returns ExitCodeDBEngineError +// when the database is unreachable. +func TestRunDBConnectionFailure(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + + cmdOpts := config.NewCmdOptions( + "--clientname=test_conn_fail", + // port 1 is almost universally refused immediately + "--connstr=postgres://invalid:invalid@localhost:1/invalid?sslmode=disable", + ) + code := run(ctx, cmdOpts, newTestLogger()) + assert.Equal(t, ExitCodeDBEngineError, code) +} + +// TestRunInitOnly verifies that run initialises the database schema and exits +// cleanly when the --init flag is supplied. +func TestRunInitOnly(t *testing.T) { + connStr, cleanup := setupTestContainer(t) + defer cleanup() + + cmdOpts := config.NewCmdOptions( + "--clientname=test_main_init", + "--connstr="+connStr, + "--init", + ) + code := run(context.Background(), cmdOpts, newTestLogger()) + assert.Equal(t, ExitCodeOK, code) +} + +// TestRunUpgrade verifies that run performs a schema upgrade and exits cleanly +// when the --upgrade flag is combined with --init. +func TestRunUpgrade(t *testing.T) { + connStr, cleanup := setupTestContainer(t) + defer cleanup() + + cmdOpts := config.NewCmdOptions( + "--clientname=test_main_upgrade", + "--connstr="+connStr, + "--upgrade", + "--init", + ) + code := run(context.Background(), cmdOpts, newTestLogger()) + assert.Equal(t, ExitCodeOK, code) +} + +// TestRunContextCancellation verifies that run returns ExitCodeOK (not +// ExitCodeShutdownCommand) when the context is cancelled while the scheduler +// is running. +func TestRunContextCancellation(t *testing.T) { + connStr, cleanup := setupTestContainer(t) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + cmdOpts := config.NewCmdOptions( + "--clientname=test_main_cancel", + "--connstr="+connStr, + ) + code := run(ctx, cmdOpts, newTestLogger()) + assert.Equal(t, ExitCodeOK, code) +}