Skip to content
Merged

Dev #23

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 0 additions & 193 deletions Book1(Sheet1).csv

This file was deleted.

73 changes: 73 additions & 0 deletions internal/httpx/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package httpx_test

import (
"net/http"
"testing"
"time"

"sentinelgo/internal/httpx"
)

func TestNewClient_ReturnsNonNil(t *testing.T) {
c := httpx.NewClient(30 * time.Second)
if c == nil {
t.Fatal("NewClient returned nil")
}
}

func TestNewClient_TimeoutIsSet(t *testing.T) {
timeout := 45 * time.Second
c := httpx.NewClient(timeout)
if c.Timeout != timeout {
t.Errorf("Timeout: got %v, want %v", c.Timeout, timeout)
}
}

func TestNewClient_ZeroTimeout(t *testing.T) {
c := httpx.NewClient(0)
if c == nil {
t.Fatal("NewClient(0) returned nil")
}
if c.Timeout != 0 {
t.Errorf("Timeout with zero: got %v, want 0", c.Timeout)
}
}

func TestNewClient_HasCustomTransport(t *testing.T) {
c := httpx.NewClient(10 * time.Second)
if c.Transport == nil {
t.Fatal("expected a custom Transport, got nil (would use default which lacks pool tuning)")
}
}

func TestNewClient_TransportIsHTTPTransport(t *testing.T) {
c := httpx.NewClient(10 * time.Second)
tr, ok := c.Transport.(*http.Transport)
if !ok {
t.Fatalf("expected *http.Transport, got %T", c.Transport)
}
if tr.MaxIdleConns == 0 {
t.Error("MaxIdleConns should be non-zero for connection reuse")
}
if tr.MaxIdleConnsPerHost == 0 {
t.Error("MaxIdleConnsPerHost should be non-zero for connection reuse")
}
if tr.IdleConnTimeout == 0 {
t.Error("IdleConnTimeout should be non-zero")
}
}

func TestNewClient_DifferentTimeouts_IndependentClients(t *testing.T) {
c1 := httpx.NewClient(10 * time.Second)
c2 := httpx.NewClient(60 * time.Second)

if c1.Timeout != 10*time.Second {
t.Errorf("c1 timeout: got %v, want 10s", c1.Timeout)
}
if c2.Timeout != 60*time.Second {
t.Errorf("c2 timeout: got %v, want 60s", c2.Timeout)
}
if c1 == c2 {
t.Error("NewClient should return independent client instances")
}
}
91 changes: 91 additions & 0 deletions internal/sanitize/sanitize_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package sanitize_test

import (
"strings"
"testing"

"sentinelgo/internal/sanitize"
)

func TestForLog_PlainString_Unchanged(t *testing.T) {
input := "v1.2.3"
got := sanitize.ForLog(input)
if got != input {
t.Errorf("ForLog(%q) = %q, want %q", input, got, input)
}
}

func TestForLog_Empty_ReturnsEmpty(t *testing.T) {
got := sanitize.ForLog("")
if got != "" {
t.Errorf("ForLog(\"\") = %q, want empty string", got)
}
}

func TestForLog_StripNewline(t *testing.T) {
input := "line1\nline2"
got := sanitize.ForLog(input)
if strings.Contains(got, "\n") {
t.Errorf("ForLog should strip newlines; got: %q", got)
}
}

func TestForLog_StripCarriageReturn(t *testing.T) {
input := "value\r\ninjected"
got := sanitize.ForLog(input)
if strings.Contains(got, "\r") || strings.Contains(got, "\n") {
t.Errorf("ForLog should strip \\r and \\n; got: %q", got)
}
}

func TestForLog_StripCRLF(t *testing.T) {
input := "INFO level\r\nERROR injected"
got := sanitize.ForLog(input)
if strings.Contains(got, "\r") || strings.Contains(got, "\n") {
t.Errorf("ForLog should strip CRLF; got: %q", got)
}
if !strings.Contains(got, "INFO level") || !strings.Contains(got, "ERROR injected") {
t.Errorf("ForLog should preserve non-whitespace content; got: %q", got)
}
}

func TestForLog_MultipleNewlines(t *testing.T) {
input := "a\nb\nc\n\n"
got := sanitize.ForLog(input)
if strings.Contains(got, "\n") {
t.Errorf("ForLog should strip all newlines; got: %q", got)
}
}

func TestForLog_OnlyNewlines(t *testing.T) {
got := sanitize.ForLog("\n\r\n\r")
if got != "" {
t.Errorf("ForLog with only whitespace characters: got %q, want empty string", got)
}
}

func TestForLog_PreservesOtherContent(t *testing.T) {
input := "version=v2.0.0 arch=amd64 os=linux"
got := sanitize.ForLog(input)
if got != input {
t.Errorf("ForLog should preserve content without newlines; got %q, want %q", got, input)
}
}

func TestForLog_LogInjectionPrevention(t *testing.T) {
// Simulate an attacker-controlled value that tries to inject a log line.
malicious := "v1.0.0\n2025-01-01 00:00:00 ERROR: injected log entry"
got := sanitize.ForLog(malicious)
lines := strings.Split(got, "\n")
if len(lines) > 1 {
t.Errorf("ForLog should prevent log injection; got %d lines: %q", len(lines), got)
}
}

func TestForLog_TabAndSpacePreserved(t *testing.T) {
input := "key\tvalue spaced"
got := sanitize.ForLog(input)
if got != input {
t.Errorf("ForLog should preserve tabs and spaces; got %q, want %q", got, input)
}
}
115 changes: 103 additions & 12 deletions internal/service/task/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,47 @@ import (
"os"
"path/filepath"
"runtime"
"sync"
"time"

"sentinelgo/internal/config"
"sentinelgo/internal/httpx"
"sentinelgo/internal/taskstore"
)

const (
defaultTaskTimeout = 30 * time.Minute
maxTaskTimeout = 24 * time.Hour
watchdogInterval = 2 * time.Minute
watchdogGrace = 5 * time.Minute
)

// activeTask tracks an in-flight task for the watchdog.
type activeTask struct {
cancel context.CancelFunc
deadline time.Time
}

// TaskExecutorService handles the execution of tasks assigned to the agent.
type TaskExecutorService struct {
cfg *config.Config
pollingSvc *TaskPollingService
client *http.Client
runningTasks map[string]bool
nativeHandlers map[string]NativeTaskHandler

activeMu sync.Mutex
activeRunning map[string]activeTask // task ID → {cancel, deadline}
}

// NewTaskExecutorService creates a new task execution service.
func NewTaskExecutorService(cfg *config.Config, pollingSvc *TaskPollingService) *TaskExecutorService {
s := &TaskExecutorService{
cfg: cfg,
pollingSvc: pollingSvc,
client: httpx.NewClient(2 * time.Minute),
runningTasks: make(map[string]bool),
cfg: cfg,
pollingSvc: pollingSvc,
client: httpx.NewClient(2 * time.Minute),
runningTasks: make(map[string]bool),
activeRunning: make(map[string]activeTask),
}
s.registerNativeHandlers()
return s
Expand All @@ -53,6 +71,53 @@ func (s *TaskExecutorService) RunExecutionLoop(ctx context.Context) {
}
}

// RunWatchdog polls activeRunning every watchdogInterval and cancels any task
// that has been running past its deadline plus watchdogGrace. The cancelled
// context fires taskCtx.Done() in runTask, which uses the existing timeout
// machinery to mark the task failed and report it to the server.
func (s *TaskExecutorService) RunWatchdog(ctx context.Context) {
log.Printf("Watchdog: started (interval=%v, grace=%v)", watchdogInterval, watchdogGrace)
ticker := time.NewTicker(watchdogInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
log.Printf("Watchdog: stopped")
return
case <-ticker.C:
s.cancelOverdueTasks()
}
}
}

func (s *TaskExecutorService) cancelOverdueTasks() {
now := time.Now()
s.activeMu.Lock()
defer s.activeMu.Unlock()

for id, t := range s.activeRunning {
if now.After(t.deadline.Add(watchdogGrace)) {
log.Printf("Watchdog: cancelling stuck task %s (deadline was %v ago)",
id, now.Sub(t.deadline).Round(time.Second))
t.cancel()
// deregisterRunning will be called by the defer in runTask
}
}
}

func (s *TaskExecutorService) registerRunning(id string, cancel context.CancelFunc, deadline time.Time) {
s.activeMu.Lock()
s.activeRunning[id] = activeTask{cancel: cancel, deadline: deadline}
s.activeMu.Unlock()
}

func (s *TaskExecutorService) deregisterRunning(id string) {
s.activeMu.Lock()
delete(s.activeRunning, id)
s.activeMu.Unlock()
}

// ExecutePendingTasks finds locally stored 'assigned' tasks and runs them.
func (s *TaskExecutorService) ExecutePendingTasks(ctx context.Context) {
tasks, err := s.pollingSvc.GetLocalTasks()
Expand Down Expand Up @@ -106,14 +171,37 @@ func (s *TaskExecutorService) executeTask(ctx context.Context, task taskstore.Ta
}
}

func (s *TaskExecutorService) runTask(ctx context.Context, task taskstore.Task) (string, error) {
if handler, ok := s.nativeHandlers[task.Slug]; ok {
return handler(ctx, task)
// resolveTaskTimeout returns the timeout for a task. If the task payload
// contains a "timeout_minutes" key (float64 > 0), that value is used, capped
// at maxTaskTimeout. Otherwise defaultTaskTimeout applies.
func resolveTaskTimeout(task taskstore.Task) time.Duration {
if v, ok := task.Payload["timeout_minutes"]; ok {
if mins, ok := v.(float64); ok && mins > 0 {
d := time.Duration(mins * float64(time.Minute))
if d > maxTaskTimeout {
d = maxTaskTimeout
}
return d
}
}
return defaultTaskTimeout
}

timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
func (s *TaskExecutorService) runTask(ctx context.Context, task taskstore.Task) (string, error) {
timeout := resolveTaskTimeout(task)
taskCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// Register with the watchdog. If the task exceeds its deadline plus the
// grace period the watchdog will call cancel(), firing taskCtx.Done() and
// causing the select below (or the native handler) to return an error.
s.registerRunning(task.ID, cancel, time.Now().Add(timeout))
defer s.deregisterRunning(task.ID)

if handler, ok := s.nativeHandlers[task.Slug]; ok {
return handler(taskCtx, task)
}

scriptPath, scriptName, err := s.resolveScript(task)
if err != nil {
return "", err
Expand All @@ -130,7 +218,7 @@ func (s *TaskExecutorService) runTask(ctx context.Context, task taskstore.Task)
}()

localScriptPath := filepath.Join(tempDir, scriptName)
if err := s.downloadScript(timeoutCtx, scriptPath, localScriptPath); err != nil {
if err := s.downloadScript(taskCtx, scriptPath, localScriptPath); err != nil {
return "", fmt.Errorf("download script: %w", err)
}

Expand All @@ -147,16 +235,19 @@ func (s *TaskExecutorService) runTask(ctx context.Context, task taskstore.Task)
}, 1)

go func() {
output, err := s.executeLocalScript(timeoutCtx, localScriptPath, payloadPath)
output, err := s.executeLocalScript(taskCtx, localScriptPath, payloadPath)
resultChan <- struct {
output string
err error
}{output, err}
}()

select {
case <-timeoutCtx.Done():
timeoutNote := fmt.Sprintf("Task execution timed out after 10 minutes. Task ID: %s, Slug: %s. The task was forcefully stopped.", task.ID, task.Slug)
case <-taskCtx.Done():
timeoutNote := fmt.Sprintf(
"Task exceeded timeout of %v. Task ID: %s, Slug: %s. Forcefully stopped.",
timeout.Round(time.Second), task.ID, task.Slug,
)
log.Printf("Executor: %s", timeoutNote)
return timeoutNote, fmt.Errorf("task execution timeout")
case result := <-resultChan:
Expand Down
3 changes: 3 additions & 0 deletions internal/service/task/executor_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"path/filepath"
"time"
)

// executeLocalScript runs the task script on macOS.
Expand All @@ -29,6 +30,8 @@ func (s *TaskExecutorService) executeLocalScript(ctx context.Context, scriptPath
cmd = exec.CommandContext(ctx, scriptPath, payloadPath)
}

cmd.WaitDelay = 30 * time.Second

output, err := cmd.CombinedOutput()
return string(output), err
}
Loading
Loading