From fd730d70b9854afc6b9f8885ec92969c82921411 Mon Sep 17 00:00:00 2001 From: Serghei Iakovlev Date: Sun, 26 Apr 2026 19:12:49 +0200 Subject: [PATCH 1/2] feat(agent): add OpenCode adapter Implements the fork-per-turn OpenCode adapter using `opencode run --format json`. - command.go: config parsing, argument/env construction, permission policy synthesis, and SSH remote command construction - parse.go: JSON envelope parsing, part parsers, and export-based token usage recovery via `opencode export --sanitize` - opencode.go: session lifecycle, subprocess management, event normalization, and stall watchdog integration - Full test suite: unit tests with fixtures, subprocess-backed Unix tests, and env-gated integration tests (SORTIE_OPENCODE_TEST=1) - Register adapter in binary and add registry metadata assertion - Add *_unix_test.go file nesting pattern to VS Code explorer settings Closes #476 --- .vscode/settings.json | 2 +- cmd/sortie/main.go | 1 + internal/agent/opencode/command.go | 241 ++++++ internal/agent/opencode/command_test.go | 487 +++++++++++ internal/agent/opencode/integration_test.go | 281 +++++++ internal/agent/opencode/opencode.go | 783 ++++++++++++++++++ internal/agent/opencode/opencode_test.go | 725 ++++++++++++++++ internal/agent/opencode/parse.go | 351 ++++++++ internal/agent/opencode/parse_test.go | 306 +++++++ internal/agent/opencode/parse_unix_test.go | 135 +++ .../agent/opencode/testdata/export_usage.json | 28 + .../testdata/export_usage_missing_tokens.json | 10 + .../testdata/logical_failure_exit0.jsonl | 2 + .../opencode/testdata/malformed_event.jsonl | 2 + .../permission_warning_then_error.txt | 2 + .../agent/opencode/testdata/resume_turn.jsonl | 3 + .../agent/opencode/testdata/simple_turn.jsonl | 3 + .../opencode/testdata/tool_success.jsonl | 3 + internal/registry/adapter_meta_test.go | 6 + 19 files changed, 3370 insertions(+), 1 deletion(-) create mode 100644 internal/agent/opencode/command.go create mode 100644 internal/agent/opencode/command_test.go create mode 100644 internal/agent/opencode/integration_test.go create mode 100644 internal/agent/opencode/opencode.go create mode 100644 internal/agent/opencode/opencode_test.go create mode 100644 internal/agent/opencode/parse.go create mode 100644 internal/agent/opencode/parse_test.go create mode 100644 internal/agent/opencode/parse_unix_test.go create mode 100644 internal/agent/opencode/testdata/export_usage.json create mode 100644 internal/agent/opencode/testdata/export_usage_missing_tokens.json create mode 100644 internal/agent/opencode/testdata/logical_failure_exit0.jsonl create mode 100644 internal/agent/opencode/testdata/malformed_event.jsonl create mode 100644 internal/agent/opencode/testdata/permission_warning_then_error.txt create mode 100644 internal/agent/opencode/testdata/resume_turn.jsonl create mode 100644 internal/agent/opencode/testdata/simple_turn.jsonl create mode 100644 internal/agent/opencode/testdata/tool_success.jsonl diff --git a/.vscode/settings.json b/.vscode/settings.json index 0cc8ca8a..df6b8fbc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,7 +10,7 @@ "explorer.fileNesting.enabled": true, "explorer.fileNesting.expand": false, "explorer.fileNesting.patterns": { - "*.go": "${capture}_test.go, ${capture}_integration_test.go", + "*.go": "${capture}_test.go, ${capture}_unix_test.go, ${capture}_integration_test.go", "go.mod": "go.sum", "codecov.yml": "coverage.*", ".env": ".env.*", diff --git a/cmd/sortie/main.go b/cmd/sortie/main.go index 48aa67a8..af0f397e 100644 --- a/cmd/sortie/main.go +++ b/cmd/sortie/main.go @@ -39,6 +39,7 @@ import ( _ "github.com/sortie-ai/sortie/internal/agent/codex" _ "github.com/sortie-ai/sortie/internal/agent/copilot" _ "github.com/sortie-ai/sortie/internal/agent/mock" + _ "github.com/sortie-ai/sortie/internal/agent/opencode" _ "github.com/sortie-ai/sortie/internal/scm/github" _ "github.com/sortie-ai/sortie/internal/tracker/file" _ "github.com/sortie-ai/sortie/internal/tracker/jira" diff --git a/internal/agent/opencode/command.go b/internal/agent/opencode/command.go new file mode 100644 index 00000000..9c984c64 --- /dev/null +++ b/internal/agent/opencode/command.go @@ -0,0 +1,241 @@ +package opencode + +import ( + "encoding/json" + "fmt" + "log/slog" + "slices" + "strconv" + "strings" + + "github.com/sortie-ai/sortie/internal/agent/sshutil" + "github.com/sortie-ai/sortie/internal/typeutil" +) + +type passthroughConfig struct { + Model string + Agent string + Variant string + Thinking bool + Pure bool + DangerousSkipPermissions bool + DisableAutocompact bool + AllowedTools []string + DeniedTools []string +} + +type permissionAction string + +const ( + permissionAllow permissionAction = "allow" + permissionDeny permissionAction = "deny" +) + +type permissionPolicy map[string]permissionAction + +var knownPermissionKeys = map[string]struct{}{ + "bash": {}, + "codesearch": {}, + "doom_loop": {}, + "edit": {}, + "external_directory": {}, + "glob": {}, + "grep": {}, + "list": {}, + "lsp": {}, + "question": {}, + "read": {}, + "skill": {}, + "task": {}, + "todowrite": {}, + "webfetch": {}, + "websearch": {}, +} + +func parsePassthroughConfig(config map[string]any) (passthroughConfig, error) { + pt := passthroughConfig{ + Model: typeutil.StringFrom(config, "model"), + Agent: typeutil.StringFrom(config, "agent"), + Variant: typeutil.StringFrom(config, "variant"), + Thinking: typeutil.BoolFrom(config, "thinking", false), + Pure: typeutil.BoolFrom(config, "pure", false), + DangerousSkipPermissions: typeutil.BoolFrom(config, "dangerously_skip_permissions", true), + DisableAutocompact: typeutil.BoolFrom(config, "disable_autocompact", true), + AllowedTools: slices.Clone(typeutil.ExtractStringSlice(config["allowed_tools"])), + DeniedTools: slices.Clone(typeutil.ExtractStringSlice(config["denied_tools"])), + } + + allowed := make(map[string]struct{}, len(pt.AllowedTools)) + for _, key := range pt.AllowedTools { + allowed[key] = struct{}{} + } + + var conflicts []string + for _, key := range pt.DeniedTools { + if _, ok := allowed[key]; ok { + conflicts = append(conflicts, key) + } + } + if len(conflicts) > 0 { + slices.Sort(conflicts) + return passthroughConfig{}, fmt.Errorf("allowed_tools and denied_tools overlap: %s", strings.Join(conflicts, ", ")) + } + + return pt, nil +} + +func buildRunArgs(state *sessionState, prompt string, pt passthroughConfig) []string { + args := []string{"run", "--format", "json", "--dir", state.target.WorkspacePath} + + if state.sessionID != "" { + args = append(args, "--session", state.sessionID) + } + if pt.Model != "" { + args = append(args, "--model", pt.Model) + } + if pt.Agent != "" { + args = append(args, "--agent", pt.Agent) + } + if pt.Variant != "" { + args = append(args, "--variant", pt.Variant) + } + if pt.Thinking { + args = append(args, "--thinking") + } + if pt.Pure { + args = append(args, "--pure") + } + if pt.DangerousSkipPermissions { + args = append(args, "--dangerously-skip-permissions") + } + + args = append(args, "--", prompt) + return args +} + +func buildRunEnv(base []string, pt passthroughConfig) ([]string, error) { + managedEnv, err := buildManagedEnv(pt) + if err != nil { + return nil, err + } + + env := make([]string, 0, len(base)+len(managedEnv)) + for _, entry := range base { + if shouldDropManagedEnv(entry) { + continue + } + env = append(env, entry) + } + + keys := make([]string, 0, len(managedEnv)) + for key := range managedEnv { + keys = append(keys, key) + } + slices.Sort(keys) + for _, key := range keys { + env = append(env, key+"="+managedEnv[key]) + } + + return env, nil +} + +func buildSSHRemoteCommand(remoteCommand string, extraEnv map[string]string) string { + if len(extraEnv) == 0 { + return remoteCommand + } + + keys := make([]string, 0, len(extraEnv)) + for key := range extraEnv { + keys = append(keys, key) + } + slices.Sort(keys) + + parts := make([]string, 0, len(keys)+1) + for _, key := range keys { + parts = append(parts, key+"="+sshutil.ShellQuote(extraEnv[key])) + } + parts = append(parts, remoteCommand) + + return strings.Join(parts, " ") +} + +func buildManagedEnv(pt passthroughConfig) (map[string]string, error) { + managed := map[string]string{ + "OPENCODE_AUTO_SHARE": "false", + "OPENCODE_DISABLE_AUTOCOMPACT": strconv.FormatBool(pt.DisableAutocompact), + "OPENCODE_DISABLE_AUTOUPDATE": "true", + "OPENCODE_DISABLE_LSP_DOWNLOAD": "true", + } + + policy, ok := buildPermissionPolicy(pt) + if !ok { + return managed, nil + } + + encoded, err := json.Marshal(policy) + if err != nil { + return nil, fmt.Errorf("marshal opencode permission policy: %w", err) + } + managed["OPENCODE_PERMISSION"] = string(encoded) + + return managed, nil +} + +func buildPermissionPolicy(pt passthroughConfig) (permissionPolicy, bool) { + if len(pt.AllowedTools) == 0 && len(pt.DeniedTools) == 0 { + return nil, false + } + + policy := make(permissionPolicy, len(pt.AllowedTools)+len(pt.DeniedTools)+len(knownPermissionKeys)) + allowed := make(map[string]struct{}, len(pt.AllowedTools)) + for _, key := range pt.AllowedTools { + allowed[key] = struct{}{} + policy[key] = permissionAllow + logUnknownPermissionKey(key) + } + + if len(pt.AllowedTools) > 0 { + for key := range knownPermissionKeys { + if _, ok := allowed[key]; ok { + continue + } + policy[key] = permissionDeny + } + } + + for _, key := range pt.DeniedTools { + policy[key] = permissionDeny + logUnknownPermissionKey(key) + } + + return policy, true +} + +func shouldDropManagedEnv(entry string) bool { + key, _, found := strings.Cut(entry, "=") + if !found { + return false + } + + switch key { + case "OPENCODE_AUTO_SHARE", + "OPENCODE_DISABLE_AUTOCOMPACT", + "OPENCODE_DISABLE_AUTOUPDATE", + "OPENCODE_DISABLE_LSP_DOWNLOAD", + "OPENCODE_PERMISSION": + return true + default: + return false + } +} + +func logUnknownPermissionKey(key string) { + if _, ok := knownPermissionKeys[key]; ok { + return + } + + slog.Default().With(slog.String("component", "opencode-adapter")).Debug( + "forwarding unknown opencode permission key", + slog.String("permission_key", key), + ) +} diff --git a/internal/agent/opencode/command_test.go b/internal/agent/opencode/command_test.go new file mode 100644 index 00000000..9ff8e32d --- /dev/null +++ b/internal/agent/opencode/command_test.go @@ -0,0 +1,487 @@ +package opencode + +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/sortie-ai/sortie/internal/agent/agentcore" +) + +// envLookup returns the value for key in an env []string slice. +func envLookup(env []string, key string) (string, bool) { + prefix := key + "=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return strings.TrimPrefix(e, prefix), true + } + } + return "", false +} + +// assertEnvPresent fails unless key is present in env with the given value. +func assertEnvPresent(t *testing.T, env []string, key, wantVal string) { + t.Helper() + got, ok := envLookup(env, key) + if !ok { + t.Errorf("env %q absent, want %q=%q", key, key, wantVal) + return + } + if got != wantVal { + t.Errorf("env %q = %q, want %q", key, got, wantVal) + } +} + +// assertEnvAbsent fails if key is present in env. +func assertEnvAbsent(t *testing.T, env []string, key string) { + t.Helper() + if _, ok := envLookup(env, key); ok { + t.Errorf("env %q is present, want absent", key) + } +} + +// assertHasArgPair fails if flag and value do not appear as consecutive +// elements in args. +func assertHasArgPair(t *testing.T, args []string, flag, value string) { + t.Helper() + for i := 0; i < len(args)-1; i++ { + if args[i] == flag && args[i+1] == value { + return + } + } + t.Errorf("buildRunArgs() missing %q %q in [%s]", flag, value, strings.Join(args, " ")) +} + +// assertHasFlag fails if flag does not appear in args. +func assertHasFlag(t *testing.T, args []string, flag string) { + t.Helper() + for _, a := range args { + if a == flag { + return + } + } + t.Errorf("buildRunArgs() missing flag %q in [%s]", flag, strings.Join(args, " ")) +} + +// assertNoFlag fails if flag appears anywhere in args. +func assertNoFlag(t *testing.T, args []string, flag string) { + t.Helper() + for _, a := range args { + if a == flag { + t.Errorf("buildRunArgs() unexpected flag %q in [%s]", flag, strings.Join(args, " ")) + return + } + } +} + +// newTestSessionState returns a sessionState suitable for buildRunArgs tests. +func newTestSessionState(workspacePath, sessionID string) *sessionState { + return &sessionState{ + target: agentcore.LaunchTarget{ + WorkspacePath: workspacePath, + }, + sessionID: sessionID, + } +} + +func TestNewOpenCodeAdapter_ParsePassthroughConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config map[string]any + wantErr bool + checkFunc func(t *testing.T, pt passthroughConfig) + }{ + { + name: "defaults", + config: map[string]any{}, + checkFunc: func(t *testing.T, pt passthroughConfig) { + t.Helper() + if !pt.DangerousSkipPermissions { + t.Error("DangerousSkipPermissions = false, want true (default)") + } + if !pt.DisableAutocompact { + t.Error("DisableAutocompact = false, want true (default)") + } + }, + }, + { + name: "allowed_tools_parse", + config: map[string]any{ + "allowed_tools": []any{"read", "edit"}, + }, + checkFunc: func(t *testing.T, pt passthroughConfig) { + t.Helper() + if len(pt.AllowedTools) != 2 { + t.Fatalf("AllowedTools len = %d, want 2", len(pt.AllowedTools)) + } + if pt.AllowedTools[0] != "read" { + t.Errorf("AllowedTools[0] = %q, want %q", pt.AllowedTools[0], "read") + } + if pt.AllowedTools[1] != "edit" { + t.Errorf("AllowedTools[1] = %q, want %q", pt.AllowedTools[1], "edit") + } + }, + }, + { + name: "denied_tools_parse", + config: map[string]any{ + "denied_tools": []any{"bash"}, + }, + checkFunc: func(t *testing.T, pt passthroughConfig) { + t.Helper() + if len(pt.DeniedTools) != 1 { + t.Fatalf("DeniedTools len = %d, want 1", len(pt.DeniedTools)) + } + if pt.DeniedTools[0] != "bash" { + t.Errorf("DeniedTools[0] = %q, want %q", pt.DeniedTools[0], "bash") + } + }, + }, + { + name: "unknown_key_preserved", + config: map[string]any{ + "allowed_tools": []any{"customtool"}, + }, + checkFunc: func(t *testing.T, pt passthroughConfig) { + t.Helper() + if len(pt.AllowedTools) != 1 || pt.AllowedTools[0] != "customtool" { + t.Errorf("AllowedTools = %v, want [customtool]", pt.AllowedTools) + } + }, + }, + { + name: "overlap_error", + config: map[string]any{ + "allowed_tools": []any{"bash"}, + "denied_tools": []any{"bash"}, + }, + wantErr: true, + }, + { + name: "model_and_flags", + config: map[string]any{ + "model": "anthropic/claude-3-5-sonnet", + "dangerously_skip_permissions": false, + "disable_autocompact": false, + }, + checkFunc: func(t *testing.T, pt passthroughConfig) { + t.Helper() + if pt.Model != "anthropic/claude-3-5-sonnet" { + t.Errorf("Model = %q, want %q", pt.Model, "anthropic/claude-3-5-sonnet") + } + if pt.DangerousSkipPermissions { + t.Error("DangerousSkipPermissions = true, want false") + } + if pt.DisableAutocompact { + t.Error("DisableAutocompact = true, want false") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + a, err := NewOpenCodeAdapter(tt.config) + + if tt.wantErr { + if err == nil { + t.Fatal("NewOpenCodeAdapter() error = nil, want error") + } + if !strings.Contains(err.Error(), "bash") { + t.Errorf("error = %q, want it to mention %q", err.Error(), "bash") + } + return + } + + if err != nil { + t.Fatalf("NewOpenCodeAdapter() error = %v", err) + } + + oc, ok := a.(*OpenCodeAdapter) + if !ok { + t.Fatalf("adapter type = %T, want *OpenCodeAdapter", a) + } + if tt.checkFunc != nil { + tt.checkFunc(t, oc.passthrough) + } + }) + } +} + +func TestBuildRunArgs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + sessionID string + pt passthroughConfig + prompt string + wantPresent []string + wantPairs [][2]string + wantAbsent []string + }{ + { + name: "fresh_session", + sessionID: "", + pt: passthroughConfig{}, + prompt: "do work", + wantAbsent: []string{"--session"}, + }, + { + name: "resume_session", + sessionID: "ses_abc", + pt: passthroughConfig{}, + prompt: "continue", + wantPairs: [][2]string{{"--session", "ses_abc"}}, + }, + { + name: "skip_permissions_default", + sessionID: "", + pt: passthroughConfig{DangerousSkipPermissions: true}, + prompt: "work", + wantPresent: []string{"--dangerously-skip-permissions"}, + }, + { + name: "skip_permissions_disabled", + sessionID: "", + pt: passthroughConfig{DangerousSkipPermissions: false}, + prompt: "work", + wantAbsent: []string{"--dangerously-skip-permissions"}, + }, + { + name: "model_flag", + sessionID: "", + pt: passthroughConfig{Model: "anthropic/claude-3-5-sonnet"}, + prompt: "work", + wantPairs: [][2]string{{"--model", "anthropic/claude-3-5-sonnet"}}, + }, + { + name: "prompt_after_dashdash", + sessionID: "", + pt: passthroughConfig{}, + prompt: "my --prompt with flags", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + state := newTestSessionState("/tmp/workspace", tt.sessionID) + args := buildRunArgs(state, tt.prompt, tt.pt) + + for _, flag := range tt.wantPresent { + assertHasFlag(t, args, flag) + } + for _, pair := range tt.wantPairs { + assertHasArgPair(t, args, pair[0], pair[1]) + } + for _, flag := range tt.wantAbsent { + assertNoFlag(t, args, flag) + } + + // Prompt must be the last argument, after "--". + if len(args) < 2 { + t.Fatalf("args too short: %v", args) + } + lastTwo := args[len(args)-2:] + if lastTwo[0] != "--" { + t.Errorf("second-to-last arg = %q, want %q", lastTwo[0], "--") + } + if lastTwo[1] != tt.prompt { + t.Errorf("last arg = %q, want prompt %q", lastTwo[1], tt.prompt) + } + }) + } +} + +func TestBuildRunEnv(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base []string + pt passthroughConfig + checkFunc func(t *testing.T, env []string) + }{ + { + name: "baseline_always_set", + base: []string{}, + pt: passthroughConfig{}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + assertEnvPresent(t, env, "OPENCODE_AUTO_SHARE", "false") + assertEnvPresent(t, env, "OPENCODE_DISABLE_AUTOUPDATE", "true") + assertEnvPresent(t, env, "OPENCODE_DISABLE_LSP_DOWNLOAD", "true") + }, + }, + { + name: "autocompact_default_true", + base: []string{}, + pt: passthroughConfig{DisableAutocompact: true}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + assertEnvPresent(t, env, "OPENCODE_DISABLE_AUTOCOMPACT", "true") + }, + }, + { + name: "autocompact_disabled", + base: []string{}, + pt: passthroughConfig{DisableAutocompact: false}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + assertEnvPresent(t, env, "OPENCODE_DISABLE_AUTOCOMPACT", "false") + }, + }, + { + name: "inherited_permission_removed", + base: []string{"OPENCODE_PERMISSION=old_value", "OTHER_VAR=keep"}, + pt: passthroughConfig{}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + assertEnvAbsent(t, env, "OPENCODE_PERMISSION") + assertEnvPresent(t, env, "OTHER_VAR", "keep") + }, + }, + { + name: "allowed_tools_policy", + base: []string{}, + pt: passthroughConfig{AllowedTools: []string{"read"}}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + raw, ok := envLookup(env, "OPENCODE_PERMISSION") + if !ok { + t.Fatal("OPENCODE_PERMISSION absent") + } + var policy map[string]string + if err := json.Unmarshal([]byte(raw), &policy); err != nil { + t.Fatalf("OPENCODE_PERMISSION unmarshal: %v", err) + } + if policy["read"] != "allow" { + t.Errorf("OPENCODE_PERMISSION[read] = %q, want %q", policy["read"], "allow") + } + if policy["bash"] != "deny" { + t.Errorf("OPENCODE_PERMISSION[bash] = %q, want %q", policy["bash"], "deny") + } + }, + }, + { + name: "denied_tools_policy", + base: []string{}, + pt: passthroughConfig{DeniedTools: []string{"bash"}}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + raw, ok := envLookup(env, "OPENCODE_PERMISSION") + if !ok { + t.Fatal("OPENCODE_PERMISSION absent") + } + var policy map[string]string + if err := json.Unmarshal([]byte(raw), &policy); err != nil { + t.Fatalf("OPENCODE_PERMISSION unmarshal: %v", err) + } + if policy["bash"] != "deny" { + t.Errorf("OPENCODE_PERMISSION[bash] = %q, want %q", policy["bash"], "deny") + } + }, + }, + { + name: "no_policy_no_permission_key", + base: []string{}, + pt: passthroughConfig{}, + checkFunc: func(t *testing.T, env []string) { + t.Helper() + assertEnvAbsent(t, env, "OPENCODE_PERMISSION") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + env, err := buildRunEnv(tt.base, tt.pt) + if err != nil { + t.Fatalf("buildRunEnv() error = %v", err) + } + if tt.checkFunc != nil { + tt.checkFunc(t, env) + } + }) + } +} + +func TestSSHRemoteCommand(t *testing.T) { + t.Parallel() + + t.Run("env_prefixed", func(t *testing.T) { + t.Parallel() + + extra := map[string]string{ + "KEY_A": "value_a", + "KEY_B": "value_b", + } + got := buildSSHRemoteCommand("opencode", extra) + + if !strings.Contains(got, "KEY_A=") { + t.Errorf("result %q missing KEY_A", got) + } + if !strings.Contains(got, "KEY_B=") { + t.Errorf("result %q missing KEY_B", got) + } + if !strings.HasSuffix(got, " opencode") { + t.Errorf("result %q does not end with remote command", got) + } + }) + + t.Run("values_shell_quoted", func(t *testing.T) { + t.Parallel() + + extra := map[string]string{ + "KEY": "value with spaces", + } + got := buildSSHRemoteCommand("opencode", extra) + + // ShellQuote wraps in single quotes. + if !strings.Contains(got, "'value with spaces'") { + t.Errorf("result %q: value with spaces not single-quoted", got) + } + }) + + t.Run("no_extra_env_returns_command", func(t *testing.T) { + t.Parallel() + + got := buildSSHRemoteCommand("opencode run --format json", nil) + if got != "opencode run --format json" { + t.Errorf("result = %q, want %q", got, "opencode run --format json") + } + }) + + t.Run("no_arbitrary_env", func(t *testing.T) { + t.Parallel() + + extra := map[string]string{ + "MY_KEY": "my_val", + } + got := buildSSHRemoteCommand("opencode", extra) + + // Only MY_KEY should appear as an env prefix; no other KEY= patterns. + parts := strings.Fields(got) + envCount := 0 + for _, p := range parts { + if strings.Contains(p, "=") && p != "opencode" { + envCount++ + } + } + if envCount != 1 { + t.Errorf("env prefix count = %d, want 1; result = %q", envCount, got) + } + }) +} + +// Compile-time interface check. +var _ = errors.New // keep errors imported if needed in future diff --git a/internal/agent/opencode/integration_test.go b/internal/agent/opencode/integration_test.go new file mode 100644 index 00000000..3d21b9db --- /dev/null +++ b/internal/agent/opencode/integration_test.go @@ -0,0 +1,281 @@ +package opencode_test + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/sortie-ai/sortie/internal/domain" + + _ "github.com/sortie-ai/sortie/internal/agent/opencode" + "github.com/sortie-ai/sortie/internal/registry" +) + +func skipIfNotEnabled(t *testing.T) { + t.Helper() + if os.Getenv("SORTIE_OPENCODE_TEST") != "1" { + t.Skip("set SORTIE_OPENCODE_TEST=1 to run opencode integration tests") + } +} + +// integrationCommand returns the opencode binary path, defaulting to "opencode". +func integrationCommand() string { + if cmd := os.Getenv("SORTIE_OPENCODE_COMMAND"); cmd != "" { + return cmd + } + return "opencode" +} + +// integrationConfig returns base config for integration tests. +func integrationConfig() map[string]any { + model := os.Getenv("SORTIE_OPENCODE_MODEL") + if model == "" { + model = "anthropic/claude-haiku-4-5" + } + + cfg := map[string]any{ + "dangerously_skip_permissions": true, + "disable_autocompact": true, + "model": model, + } + return cfg +} + +// mustNewAdapter creates an adapter or fatals. +func mustNewAdapter(t *testing.T) domain.AgentAdapter { + t.Helper() + factory, err := registry.Agents.Get("opencode") + if err != nil { + t.Fatalf("registry.Agents.Get(opencode): %v", err) + } + a, err := factory(integrationConfig()) + if err != nil { + t.Fatalf("factory(): %v", err) + } + return a +} + +// mustStartIntegrationSession starts a session against the real opencode binary. +func mustStartIntegrationSession(t *testing.T, a domain.AgentAdapter, resumeID string) domain.Session { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + session, err := a.StartSession(ctx, domain.StartSessionParams{ + WorkspacePath: t.TempDir(), + AgentConfig: domain.AgentConfig{Command: integrationCommand()}, + ResumeSessionID: resumeID, + }) + if err != nil { + t.Fatalf("StartSession(): %v", err) + } + return session +} + +// collectAllEvents runs a turn and returns all events and the result. +func collectAllEvents(t *testing.T, a domain.AgentAdapter, session domain.Session, prompt string) ([]domain.AgentEvent, domain.TurnResult) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + var events []domain.AgentEvent + result, err := a.RunTurn(ctx, session, domain.RunTurnParams{ + Prompt: prompt, + OnEvent: func(e domain.AgentEvent) { + events = append(events, e) + }, + }) + if err != nil { + t.Logf("RunTurn error: %v", err) + } + return events, result +} + +func TestIntegration_HappyPathFreshTurn(t *testing.T) { + skipIfNotEnabled(t) + + a := mustNewAdapter(t) + session := mustStartIntegrationSession(t, a, "") + t.Cleanup(func() { _ = a.StopSession(context.Background(), session) }) + + events, result := collectAllEvents(t, a, session, "Reply with exactly: hello") + + if result.ExitReason != domain.EventTurnCompleted { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCompleted) + } + + var sessionStarted bool + for _, e := range events { + if e.Type == domain.EventSessionStarted { + sessionStarted = true + if e.SessionID == "" { + t.Error("EventSessionStarted has empty SessionID") + } + } + } + if !sessionStarted { + t.Error("no session_started event emitted") + } +} + +func TestIntegration_SessionResume(t *testing.T) { + skipIfNotEnabled(t) + + a := mustNewAdapter(t) + session := mustStartIntegrationSession(t, a, "") + t.Cleanup(func() { _ = a.StopSession(context.Background(), session) }) + + // First turn. + _, result1 := collectAllEvents(t, a, session, "Say: turn one") + if result1.ExitReason != domain.EventTurnCompleted { + t.Fatalf("turn 1 ExitReason = %q, want completed", result1.ExitReason) + } + sessionID := result1.SessionID + if sessionID == "" { + t.Fatal("turn 1 SessionID is empty") + } + + // Resume session. + a2 := mustNewAdapter(t) + session2 := mustStartIntegrationSession(t, a2, sessionID) + t.Cleanup(func() { _ = a2.StopSession(context.Background(), session2) }) + + _, result2 := collectAllEvents(t, a2, session2, "What did I say in the previous message?") + if result2.ExitReason != domain.EventTurnCompleted { + t.Errorf("resumed turn ExitReason = %q, want completed", result2.ExitReason) + } +} + +func TestIntegration_InvalidModelFailure(t *testing.T) { + skipIfNotEnabled(t) + + cfg := integrationConfig() + cfg["model"] = "nonexistent/nonexistent" + + factory, err := registry.Agents.Get("opencode") + if err != nil { + t.Fatalf("registry.Agents.Get: %v", err) + } + a, err := factory(cfg) + if err != nil { + t.Fatalf("factory(): %v", err) + } + + session := mustStartIntegrationSession(t, a, "") + t.Cleanup(func() { _ = a.StopSession(context.Background(), session) }) + + events, result := collectAllEvents(t, a, session, "Reply with exactly: hello") + if result.ExitReason != domain.EventTurnFailed { + t.Fatalf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnFailed) + } + + var sawTurnFailed bool + for _, event := range events { + if event.Type != domain.EventTurnFailed { + continue + } + sawTurnFailed = true + if !strings.Contains(event.Message, "Model not found") { + t.Errorf("turn_failed message = %q, want invalid-model detail", event.Message) + } + } + if !sawTurnFailed { + t.Fatalf("expected turn_failed event for invalid model, events=%+v", events) + } +} + +func TestIntegration_PermissionDeny(t *testing.T) { + skipIfNotEnabled(t) + + cfg := integrationConfig() + cfg["dangerously_skip_permissions"] = false + + factory, err := registry.Agents.Get("opencode") + if err != nil { + t.Fatalf("registry.Agents.Get: %v", err) + } + a, err := factory(cfg) + if err != nil { + t.Fatalf("factory(): %v", err) + } + + session := mustStartIntegrationSession(t, a, "") + t.Cleanup(func() { _ = a.StopSession(context.Background(), session) }) + + events, _ := collectAllEvents(t, a, session, + "Read the exact contents of /etc/hostname and return it verbatim. Do not guess. If access is denied, say that it was denied.") + + // OpenCode auto-rejects external_directory access in headless mode without + // --dangerously-skip-permissions, which yields a tool_use error envelope. + var sawToolError bool + for _, e := range events { + if e.Type == domain.EventToolResult && e.ToolError { + sawToolError = true + } + } + if !sawToolError { + t.Fatalf("expected at least one tool_result with ToolError=true for denied permission, events=%+v", events) + } +} + +func TestIntegration_TurnCancellation(t *testing.T) { + skipIfNotEnabled(t) + + a := mustNewAdapter(t) + session := mustStartIntegrationSession(t, a, "") + t.Cleanup(func() { _ = a.StopSession(context.Background(), session) }) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + turnCtx, turnCancel := context.WithCancel(ctx) + resultCh := make(chan domain.TurnResult, 1) + go func() { + result, _ := a.RunTurn(turnCtx, session, domain.RunTurnParams{ + Prompt: "Count to 1000 slowly, outputting each number on its own line", + OnEvent: func(_ domain.AgentEvent) {}, + }) + resultCh <- result + }() + + // Cancel after a brief moment. + time.Sleep(500 * time.Millisecond) + turnCancel() + + select { + case result := <-resultCh: + if result.ExitReason != domain.EventTurnCancelled { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCancelled) + } + case <-ctx.Done(): + t.Fatal("RunTurn did not return after context cancel") + } +} + +func TestIntegration_PermissionDeepMerge(t *testing.T) { + skipIfNotEnabled(t) + + // Verify that setting OPENCODE_PERMISSION does not replace but merges + // with any existing permission config (deep-merge semantics). + cfg := integrationConfig() + cfg["allowed_tools"] = []any{"read", "glob"} + + factory, err := registry.Agents.Get("opencode") + if err != nil { + t.Fatalf("registry.Agents.Get: %v", err) + } + a, err := factory(cfg) + if err != nil { + t.Fatalf("factory(): %v", err) + } + + session := mustStartIntegrationSession(t, a, "") + t.Cleanup(func() { _ = a.StopSession(context.Background(), session) }) + + _, result := collectAllEvents(t, a, session, "List files in the current directory") + if result.ExitReason != domain.EventTurnCompleted { + t.Errorf("ExitReason = %q, want completed", result.ExitReason) + } +} diff --git a/internal/agent/opencode/opencode.go b/internal/agent/opencode/opencode.go new file mode 100644 index 00000000..83ab6bdc --- /dev/null +++ b/internal/agent/opencode/opencode.go @@ -0,0 +1,783 @@ +// Package opencode implements [domain.AgentAdapter] for the OpenCode CLI. +// It launches one `opencode run --format json` subprocess per turn, +// normalizes stdout envelopes into domain events, and recovers final token +// usage with `opencode export --sanitize`. +package opencode + +import ( + "bufio" + "context" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/sortie-ai/sortie/internal/agent/agentcore" + "github.com/sortie-ai/sortie/internal/agent/procutil" + "github.com/sortie-ai/sortie/internal/agent/sshutil" + "github.com/sortie-ai/sortie/internal/domain" + "github.com/sortie-ai/sortie/internal/logging" + "github.com/sortie-ai/sortie/internal/registry" + "github.com/sortie-ai/sortie/internal/typeutil" +) + +func init() { + registry.Agents.RegisterWithMeta("opencode", NewOpenCodeAdapter, registry.AgentMeta{ + RequiresCommand: true, + }) +} + +var _ domain.AgentAdapter = (*OpenCodeAdapter)(nil) + +type OpenCodeAdapter struct { + passthrough passthroughConfig +} + +type sessionState struct { + target agentcore.LaunchTarget + agentConfig domain.AgentConfig + passthrough passthroughConfig + sessionID string + turnCount int + sessionOpened bool + closed bool + baseLogger *slog.Logger + mu sync.Mutex + active *turnRuntime +} + +type turnRuntime struct { + pid string + proc *os.Process + waitCh chan waitResult + lineCh chan parsedLine + readerDone chan struct{} + stopCh chan struct{} + stopOnce sync.Once + stderrCollector *procutil.StderrCollector + firstJSONSeen bool + terminalError *rawRunError + terminalOutcome domain.AgentEventType + waitMu sync.Mutex + waitRes waitResult +} + +type waitResult struct { + exitCode int + err error +} + +// NewOpenCodeAdapter creates an [OpenCodeAdapter] from the raw "opencode" +// adapter configuration in WORKFLOW.md. +func NewOpenCodeAdapter(config map[string]any) (domain.AgentAdapter, error) { + pt, err := parsePassthroughConfig(config) + if err != nil { + return nil, err + } + return &OpenCodeAdapter{passthrough: pt}, nil +} + +// StartSession resolves the launch target and initializes adapter-owned +// session state without starting an OpenCode subprocess. +func (a *OpenCodeAdapter) StartSession(_ context.Context, params domain.StartSessionParams) (domain.Session, error) { + target, agentErr := agentcore.ResolveLaunchTarget(params, "opencode") + if agentErr != nil { + return domain.Session{}, agentErr + } + + state := &sessionState{ + target: target, + agentConfig: params.AgentConfig, + passthrough: a.passthrough, + sessionID: params.ResumeSessionID, + baseLogger: slog.Default().With(slog.String("component", "opencode-adapter")), + } + + return domain.Session{ + ID: params.WorkspacePath, + AgentPID: "", + Internal: state, + }, nil +} + +// RunTurn executes one OpenCode turn by starting a subprocess, reading its +// stdout through a single reader goroutine, and relaying normalized events via +// params.OnEvent. +func (a *OpenCodeAdapter) RunTurn(ctx context.Context, session domain.Session, params domain.RunTurnParams) (domain.TurnResult, error) { + if params.OnEvent == nil { + panic("opencode: OnEvent must be non-nil") + } + + state, ok := session.Internal.(*sessionState) + if !ok { + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: fmt.Sprintf("unexpected session internal type %T", session.Internal), + } + } + + env, err := buildRunEnv(os.Environ(), a.passthrough) + if err != nil { + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "build opencode environment", + Err: err, + } + } + + managedEnv, err := buildManagedEnv(a.passthrough) + if err != nil { + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "build opencode managed environment", + Err: err, + } + } + + state.mu.Lock() + if state.closed { + state.mu.Unlock() + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "session already stopped", + } + } + if state.active != nil { + state.mu.Unlock() + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "session already has an active turn", + } + } + state.turnCount++ + cmdArgs := buildRunArgs(state, params.Prompt, a.passthrough) + logger := state.loggerLocked() + + var cmd *exec.Cmd + if state.target.RemoteCommand != "" { + remoteCommand := buildSSHRemoteCommand(state.target.RemoteCommand, managedEnv) + sshArgs := sshutil.BuildSSHArgs( + state.target.SSHHost, + state.target.WorkspacePath, + remoteCommand, + cmdArgs, + sshutil.SSHOptions{StrictHostKeyChecking: state.target.SSHStrictHostKeyChecking}, + ) + cmd = exec.CommandContext(ctx, state.target.Command, sshArgs...) //nolint:gosec // args are constructed programmatically with shell quoting + } else { + allArgs := append(slices.Clone(state.target.Args), cmdArgs...) + cmd = exec.CommandContext(ctx, state.target.Command, allArgs...) //nolint:gosec // args are constructed programmatically + } + procutil.SetProcessGroup(cmd) + cmd.Cancel = func() error { + return procutil.SignalGraceful(cmd.Process.Pid) + } + cmd.WaitDelay = 5 * time.Second + cmd.Dir = state.target.WorkspacePath + cmd.Env = env + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + state.mu.Unlock() + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "create stdout pipe", + Err: err, + } + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + state.mu.Unlock() + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "create stderr pipe", + Err: err, + } + } + + if err := cmd.Start(); err != nil { + state.mu.Unlock() + return domain.TurnResult{}, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "start opencode subprocess", + Err: err, + } + } + + runtime := &turnRuntime{ + pid: strconv.Itoa(cmd.Process.Pid), + proc: cmd.Process, + waitCh: make(chan waitResult, 1), + lineCh: make(chan parsedLine, 16), + readerDone: make(chan struct{}), + stopCh: make(chan struct{}), + terminalOutcome: domain.EventTurnCompleted, + } + state.active = runtime + state.mu.Unlock() + + if assignErr := procutil.AssignProcess(cmd.Process.Pid, cmd.Process); assignErr != nil { + logger.Warn("process group assignment failed", slog.Any("error", assignErr)) + } + + runtime.stderrCollector = procutil.NewStderrCollector(stderrPipe, logger) + startOpenCodeReader(stdoutPipe, runtime) + startWait(runtime, cmd) + + emit := func(event domain.AgentEvent) { + if state.target.RemoteCommand == "" { + event.AgentPID = runtime.pid + } + params.OnEvent(event) + } + + readTimeout := readTimeout(state) + readTimer := time.NewTimer(readTimeout) + defer stopTimer(readTimer) + + readTimeoutC := readTimer.C + lineCh := runtime.lineCh + waitCh := runtime.waitCh + var exit waitResult + processExited := false + + for { + select { + case parsed, ok := <-lineCh: + if !ok { + lineCh = nil + if processExited { + return a.finalizeExitedTurn(ctx, state, runtime, emit, exit) + } + continue + } + + if parsed.Err != nil { + if ctx.Err() != nil || state.isClosed() { + closeStop(runtime) + killTurnProcess(runtime) + <-runtime.readerDone + _ = waitForProcess(runtime) + clearActive(state, runtime) + agentcore.EmitTurnCancelled(emit, "turn cancelled") + return domain.TurnResult{ + SessionID: state.currentSessionID(), + ExitReason: domain.EventTurnCancelled, + }, nil + } + + emitTurnEndedWithError(emit, "stdout read error") + closeStop(runtime) + killTurnProcess(runtime) + <-runtime.readerDone + _ = waitForProcess(runtime) + procutil.EmitWarnLines(runtime.stderrCollector.Lines(), state.logger()) + clearActive(state, runtime) + return domain.TurnResult{ + SessionID: state.currentSessionID(), + ExitReason: domain.EventTurnEndedWithError, + }, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: "stdout read error", + Err: parsed.Err, + } + } + + if parsed.PlainText != "" { + if readTimeoutC != nil { + resetTimer(readTimer, readTimeout) + } + + plainText := typeutil.TruncateRunes(parsed.PlainText, 500) + if isPermissionWarning(parsed.PlainText) { + agentcore.EmitNotification(emit, plainText) + } else { + emit(domain.AgentEvent{ + Type: domain.EventMalformed, + Timestamp: time.Now().UTC(), + Message: plainText, + }) + } + continue + } + + event := parsed.Event + if event == nil { + continue + } + + if readTimeoutC != nil { + runtime.firstJSONSeen = true + stopTimer(readTimer) + readTimeoutC = nil + } + + started, mismatch := state.applySessionEvent(event.SessionID) + if mismatch { + message := fmt.Sprintf("session id mismatch: expected %q, got %q", state.currentSessionID(), event.SessionID) + emitTurnEndedWithError(emit, message) + closeStop(runtime) + killTurnProcess(runtime) + <-runtime.readerDone + _ = waitForProcess(runtime) + procutil.EmitWarnLines(runtime.stderrCollector.Lines(), state.logger()) + clearActive(state, runtime) + return domain.TurnResult{ + SessionID: state.currentSessionID(), + ExitReason: domain.EventTurnEndedWithError, + }, &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: message, + } + } + if started { + emit(domain.AgentEvent{ + Type: domain.EventSessionStarted, + Timestamp: time.Now().UTC(), + SessionID: state.currentSessionID(), + Message: "session started", + }) + } + + now := time.Now().UTC() + switch event.Type { + case "step_start": + if _, err := parseStepStartPart(event.Part); err != nil { + emit(domain.AgentEvent{Type: domain.EventMalformed, Timestamp: now, Message: "invalid step_start payload"}) + continue + } + agentcore.EmitNotification(emit, "step started") + + case "text": + part, err := parseTextPart(event.Part) + if err != nil { + emit(domain.AgentEvent{Type: domain.EventMalformed, Timestamp: now, Message: "invalid text payload"}) + continue + } + agentcore.EmitNotification(emit, typeutil.TruncateRunes(part.Text, 500)) + + case "reasoning": + if _, err := parseReasoningPart(event.Part); err != nil { + emit(domain.AgentEvent{Type: domain.EventMalformed, Timestamp: now, Message: "invalid reasoning payload"}) + continue + } + emit(domain.AgentEvent{ + Type: domain.EventOtherMessage, + Timestamp: now, + Message: "reasoning block", + }) + + case "tool_use": + part, err := parseToolPart(event.Part) + if err != nil { + emit(domain.AgentEvent{Type: domain.EventMalformed, Timestamp: now, Message: "invalid tool_use payload"}) + continue + } + emit(domain.AgentEvent{ + Type: domain.EventToolResult, + Timestamp: now, + ToolName: part.Tool, + ToolDurationMS: toolDuration(part.State.Time), + ToolError: strings.EqualFold(part.State.Status, "error"), + Message: typeutil.TruncateRunes(part.State.Error, 500), + }) + + case "step_finish": + part, err := parseStepFinishPart(event.Part) + if err != nil { + emit(domain.AgentEvent{Type: domain.EventMalformed, Timestamp: now, Message: "invalid step_finish payload"}) + continue + } + agentcore.EmitNotification(emit, fmt.Sprintf("step finished: %s", part.Reason)) + + case "error": + runtime.terminalOutcome = domain.EventTurnFailed + runtime.terminalError = event.Error + agentcore.EmitTurnFailed(emit, rawRunErrorMessage(event.Error), 0) + + default: + emit(domain.AgentEvent{ + Type: domain.EventMalformed, + Timestamp: now, + Message: fmt.Sprintf("unknown event type: %s", event.Type), + }) + } + + case <-waitCh: + exit = waitForProcess(runtime) + processExited = true + waitCh = nil + if lineCh == nil { + return a.finalizeExitedTurn(ctx, state, runtime, emit, exit) + } + + case <-ctx.Done(): + closeStop(runtime) + killTurnProcess(runtime) + <-runtime.readerDone + _ = waitForProcess(runtime) + clearActive(state, runtime) + agentcore.EmitTurnCancelled(emit, "turn cancelled") + return domain.TurnResult{ + SessionID: state.currentSessionID(), + ExitReason: domain.EventTurnCancelled, + }, nil + + case <-readTimeoutC: + emitTurnEndedWithError(emit, "timed out waiting for first opencode json event") + closeStop(runtime) + killTurnProcess(runtime) + <-runtime.readerDone + _ = waitForProcess(runtime) + procutil.EmitWarnLines(runtime.stderrCollector.Lines(), state.logger()) + clearActive(state, runtime) + return domain.TurnResult{ + SessionID: state.currentSessionID(), + ExitReason: domain.EventTurnEndedWithError, + }, &domain.AgentError{ + Kind: domain.ErrResponseTimeout, + Message: "timed out waiting for first opencode json event", + } + } + } +} + +// StopSession marks the session closed and terminates any active subprocess. +func (a *OpenCodeAdapter) StopSession(_ context.Context, session domain.Session) error { + state, ok := session.Internal.(*sessionState) + if !ok { + return &domain.AgentError{ + Kind: domain.ErrResponseError, + Message: fmt.Sprintf("unexpected session internal type %T", session.Internal), + } + } + + state.mu.Lock() + state.closed = true + active := state.active + state.mu.Unlock() + + if active == nil { + return nil + } + + closeStop(active) + killTurnProcess(active) + <-active.readerDone + _ = waitForProcess(active) + clearActive(state, active) + + return nil +} + +// EventStream returns nil because OpenCode events are delivered via the +// RunTurn callback. +func (a *OpenCodeAdapter) EventStream() <-chan domain.AgentEvent { + return nil +} + +func (a *OpenCodeAdapter) finalizeExitedTurn(ctx context.Context, state *sessionState, runtime *turnRuntime, emit func(domain.AgentEvent), exit waitResult) (domain.TurnResult, error) { + usage := queryExportUsage(ctx, state) + usageSnapshot := domain.TokenUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + TotalTokens: usage.TotalTokens, + CacheReadTokens: usage.CacheReadTokens, + } + if hasUsage(usage) { + emit(domain.AgentEvent{ + Type: domain.EventTokenUsage, + Timestamp: time.Now().UTC(), + Usage: usageSnapshot, + Model: usage.Model, + }) + } + + clearActive(state, runtime) + stderrLines := runtime.stderrCollector.Lines() + sessionID := state.currentSessionID() + + if runtime.terminalError != nil { + procutil.EmitWarnLines(stderrLines, state.logger()) + return domain.TurnResult{ + SessionID: sessionID, + ExitReason: domain.EventTurnFailed, + Usage: usageSnapshot, + }, nil + } + + if ctx.Err() != nil { + agentcore.EmitTurnCancelled(emit, "turn cancelled") + return domain.TurnResult{ + SessionID: sessionID, + ExitReason: domain.EventTurnCancelled, + Usage: usageSnapshot, + }, nil + } + + if state.isClosed() { + agentcore.EmitTurnCancelled(emit, "turn cancelled") + return domain.TurnResult{ + SessionID: sessionID, + ExitReason: domain.EventTurnCancelled, + Usage: usageSnapshot, + }, nil + } + + if !runtime.firstJSONSeen { + procutil.EmitWarnLines(stderrLines, state.logger()) + emitTurnEndedWithError(emit, "process exited before first opencode json event") + return domain.TurnResult{ + SessionID: sessionID, + ExitReason: domain.EventTurnEndedWithError, + Usage: usageSnapshot, + }, &domain.AgentError{ + Kind: domain.ErrPortExit, + Message: "process exited before first opencode json event", + Err: exit.err, + } + } + + if exit.err != nil || exit.exitCode != 0 { + procutil.EmitWarnLines(stderrLines, state.logger()) + message := portExitMessage(exit) + emitTurnEndedWithError(emit, message) + return domain.TurnResult{ + SessionID: sessionID, + ExitReason: domain.EventTurnEndedWithError, + Usage: usageSnapshot, + }, &domain.AgentError{ + Kind: domain.ErrPortExit, + Message: message, + Err: exit.err, + } + } + + agentcore.EmitTurnCompleted(emit, "", 0) + return domain.TurnResult{ + SessionID: sessionID, + ExitReason: domain.EventTurnCompleted, + Usage: usageSnapshot, + }, nil +} + +func (s *sessionState) logger() *slog.Logger { + sessionID := s.currentSessionID() + if sessionID == "" { + return s.baseLogger + } + return logging.WithSession(s.baseLogger, sessionID) +} + +// loggerLocked returns a logger for s, reading sessionID without acquiring +// s.mu. Callers must already hold s.mu. +func (s *sessionState) loggerLocked() *slog.Logger { + if s.sessionID == "" { + return s.baseLogger + } + return logging.WithSession(s.baseLogger, s.sessionID) +} + +func (s *sessionState) currentSessionID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.sessionID +} + +func (s *sessionState) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed +} + +func (s *sessionState) applySessionEvent(eventSessionID string) (bool, bool) { + if eventSessionID == "" { + return false, false + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.sessionID == "" { + s.sessionID = eventSessionID + } else if s.sessionID != eventSessionID { + return false, true + } + + if s.sessionOpened { + return false, false + } + s.sessionOpened = true + return true, false +} + +func startOpenCodeReader(stdout io.Reader, runtime *turnRuntime) { + go func() { + defer close(runtime.lineCh) + defer close(runtime.readerDone) + + scanner := bufio.NewScanner(stdout) + scanner.Buffer(make([]byte, 0, 64*1024), maxLineBytes) + + for scanner.Scan() { + line := scanner.Bytes() + event, err := parseRunEvent(line) + parsed := parsedLine{} + if err != nil { + parsed.PlainText = string(line) + } else { + parsed.Event = &event + } + + select { + case runtime.lineCh <- parsed: + case <-runtime.stopCh: + return + } + } + + if err := scanner.Err(); err != nil { + select { + case runtime.lineCh <- parsedLine{Err: err}: + case <-runtime.stopCh: + } + } + }() +} + +func startWait(runtime *turnRuntime, cmd *exec.Cmd) { + go func() { + // Wait for the reader goroutine to finish before calling + // cmd.Wait(). cmd.Wait() closes the stdout pipe read end, which + // races with the scanner in startOpenCodeReader if called before + // the reader has drained all buffered output. + <-runtime.readerDone + + waitErr := cmd.Wait() + procutil.KillProcessGroup(cmd.Process.Pid) //nolint:errcheck,gosec // best-effort cleanup of surviving group members + procutil.CleanupProcess(cmd.Process.Pid) + + runtime.waitMu.Lock() + runtime.waitRes = waitResult{ + exitCode: procutil.ExtractExitCode(waitErr), + err: waitErr, + } + runtime.waitMu.Unlock() + + close(runtime.waitCh) + }() +} + +func waitForProcess(runtime *turnRuntime) waitResult { + <-runtime.waitCh + runtime.waitMu.Lock() + defer runtime.waitMu.Unlock() + return runtime.waitRes +} + +func clearActive(state *sessionState, runtime *turnRuntime) { + state.mu.Lock() + defer state.mu.Unlock() + if state.active == runtime { + state.active = nil + } +} + +func closeStop(runtime *turnRuntime) { + runtime.stopOnce.Do(func() { + close(runtime.stopCh) + }) +} + +func killTurnProcess(runtime *turnRuntime) { + if runtime == nil || runtime.proc == nil { + return + } + procutil.KillProcessGroup(runtime.proc.Pid) //nolint:errcheck,gosec // best-effort cleanup +} + +func stopTimer(timer *time.Timer) { + if timer == nil { + return + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } +} + +func resetTimer(timer *time.Timer, timeout time.Duration) { + stopTimer(timer) + timer.Reset(timeout) +} + +func readTimeout(state *sessionState) time.Duration { + if state.agentConfig.ReadTimeoutMS > 0 { + return time.Duration(state.agentConfig.ReadTimeoutMS) * time.Millisecond + } + return 30 * time.Second +} + +func exportTimeout(state *sessionState) time.Duration { + timeout := 2 * readTimeout(state) + if timeout <= 0 || timeout > 30*time.Second { + return 30 * time.Second + } + return timeout +} + +func emitTurnEndedWithError(emit func(domain.AgentEvent), message string) { + emit(domain.AgentEvent{ + Type: domain.EventTurnEndedWithError, + Timestamp: time.Now().UTC(), + Message: message, + }) +} + +func isPermissionWarning(line string) bool { + return strings.HasPrefix(strings.TrimSpace(line), "! permission requested:") +} + +func toolDuration(partTime rawPartTime) int64 { + if partTime.End <= partTime.Start { + return 0 + } + return partTime.End - partTime.Start +} + +func rawRunErrorMessage(runErr *rawRunError) string { + if runErr == nil { + return "opencode reported an unknown error" + } + if runErr.Data != nil { + if message, ok := runErr.Data["message"].(string); ok && message != "" { + return message + } + } + if runErr.Name != "" { + return runErr.Name + } + return "opencode reported an unknown error" +} + +func portExitMessage(exit waitResult) string { + if exit.exitCode > 0 { + return fmt.Sprintf("opencode exited with code %d", exit.exitCode) + } + if exit.err != nil { + return fmt.Sprintf("opencode exited unexpectedly: %v", exit.err) + } + return "opencode exited unexpectedly" +} + +func hasUsage(usage exportUsage) bool { + return usage.InputTokens > 0 || usage.OutputTokens > 0 || usage.TotalTokens > 0 || usage.CacheReadTokens > 0 +} diff --git a/internal/agent/opencode/opencode_test.go b/internal/agent/opencode/opencode_test.go new file mode 100644 index 00000000..412d6e99 --- /dev/null +++ b/internal/agent/opencode/opencode_test.go @@ -0,0 +1,725 @@ +//go:build unix + +package opencode + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/sortie-ai/sortie/internal/agent/agenttest" + "github.com/sortie-ai/sortie/internal/domain" +) + +// writeOpenCodeScript writes an executable shell script named fake-opencode +// in dir with the given body and returns its path. +func writeOpenCodeScript(t *testing.T, dir, body string) string { + t.Helper() + return agenttest.WriteScript(t, dir, "fake-opencode", body) +} + +// mustStartSession starts a session with the given command or fatals. +func mustStartSession(t *testing.T, a domain.AgentAdapter, workDir, cmd string) domain.Session { + t.Helper() + session, err := a.StartSession(context.Background(), domain.StartSessionParams{ + WorkspacePath: workDir, + AgentConfig: domain.AgentConfig{Command: cmd}, + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + return session +} + +func writeRunFixtureScript(t *testing.T, dir, fixtureName string) string { + t.Helper() + + runPath := filepath.Join(dir, fixtureName) + if err := os.WriteFile(runPath, loadFixture(t, fixtureName), 0o644); err != nil { + t.Fatalf("WriteFile(%q): %v", fixtureName, err) + } + + exportPath := filepath.Join(dir, "export.json") + if err := os.WriteFile(exportPath, []byte(`{"messages":[]}`), 0o644); err != nil { + t.Fatalf("WriteFile(export.json): %v", err) + } + + body := `case "$1" in + export) cat '` + exportPath + `'; exit 0;; +esac +cat '` + runPath + `'` + + return writeOpenCodeScript(t, dir, body) +} + +// collectEvents runs a turn and collects all emitted events. +func collectEvents(t *testing.T, a domain.AgentAdapter, session domain.Session, prompt string) ([]domain.AgentEvent, domain.TurnResult, error) { + t.Helper() + var events []domain.AgentEvent + result, err := a.RunTurn(context.Background(), session, domain.RunTurnParams{ + Prompt: prompt, + OnEvent: func(e domain.AgentEvent) { + events = append(events, e) + }, + }) + return events, result, err +} + +func TestNewOpenCodeAdapter(t *testing.T) { + t.Parallel() + + a, err := NewOpenCodeAdapter(map[string]any{}) + if err != nil { + t.Fatalf("NewOpenCodeAdapter() error = %v", err) + } + if a == nil { + t.Fatal("adapter is nil") + } + if _, ok := a.(*OpenCodeAdapter); !ok { + t.Errorf("adapter type = %T, want *OpenCodeAdapter", a) + } +} + +func TestEventStream_ReturnsNil(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + if ch := a.EventStream(); ch != nil { + t.Errorf("EventStream() = %v, want nil", ch) + } +} + +func TestStartSession_InvalidWorkspace(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params domain.StartSessionParams + wantErr domain.AgentErrorKind + }{ + { + name: "empty_workspace_path", + params: domain.StartSessionParams{ + AgentConfig: domain.AgentConfig{Command: "/bin/sh"}, + }, + wantErr: domain.ErrInvalidWorkspaceCwd, + }, + { + name: "non_existent_workspace", + params: domain.StartSessionParams{ + WorkspacePath: "/nonexistent/path/sortie-test-xyz", + AgentConfig: domain.AgentConfig{Command: "/bin/sh"}, + }, + wantErr: domain.ErrInvalidWorkspaceCwd, + }, + { + name: "command_not_found", + params: domain.StartSessionParams{ + WorkspacePath: mustMakeTempDir(t), + AgentConfig: domain.AgentConfig{Command: "sortie-nonexistent-binary-opencode-xyz"}, + }, + wantErr: domain.ErrAgentNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + _, err := a.StartSession(context.Background(), tt.params) + if err == nil { + t.Fatal("StartSession() error = nil, want error") + } + var agentErr *domain.AgentError + if !errors.As(err, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", err) + } + if agentErr.Kind != tt.wantErr { + t.Errorf("Kind = %q, want %q", agentErr.Kind, tt.wantErr) + } + }) + } +} + +func TestStartSession_ResumeSession(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + resumeID := "ses_resume123" + session, err := a.StartSession(context.Background(), domain.StartSessionParams{ + WorkspacePath: t.TempDir(), + AgentConfig: domain.AgentConfig{Command: "/bin/sh"}, + ResumeSessionID: resumeID, + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + state := session.Internal.(*sessionState) + if state.sessionID != resumeID { + t.Errorf("sessionID = %q, want %q", state.sessionID, resumeID) + } +} + +func TestRunTurn_WrongInternalType(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := domain.Session{ + ID: "test", + Internal: "not-a-session-state", + } + + _, err := a.RunTurn(context.Background(), session, domain.RunTurnParams{ + Prompt: "work", + OnEvent: func(_ domain.AgentEvent) {}, + }) + if err == nil { + t.Fatal("RunTurn() error = nil, want error for wrong internal type") + } + var agentErr *domain.AgentError + if !errors.As(err, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", err) + } + if agentErr.Kind != domain.ErrResponseError { + t.Errorf("Kind = %q, want %q", agentErr.Kind, domain.ErrResponseError) + } +} + +func TestRunTurn_ClosedSession(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + tmpDir := t.TempDir() + session := mustStartSession(t, a, tmpDir, "/bin/sh") + + state := session.Internal.(*sessionState) + state.mu.Lock() + state.closed = true + state.mu.Unlock() + + _, err := a.RunTurn(context.Background(), session, domain.RunTurnParams{ + Prompt: "work", + OnEvent: func(_ domain.AgentEvent) {}, + }) + if err == nil { + t.Fatal("RunTurn() error = nil, want error for closed session") + } + var agentErr *domain.AgentError + if !errors.As(err, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", err) + } + if agentErr.Kind != domain.ErrResponseError { + t.Errorf("Kind = %q, want %q", agentErr.Kind, domain.ErrResponseError) + } +} + +func TestRunTurn_ConcurrentRunRejected(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + tmpDir := t.TempDir() + session := mustStartSession(t, a, tmpDir, "/bin/sh") + + state := session.Internal.(*sessionState) + state.mu.Lock() + state.active = &turnRuntime{ + stopCh: make(chan struct{}), + waitCh: make(chan waitResult), + } + state.mu.Unlock() + + _, err := a.RunTurn(context.Background(), session, domain.RunTurnParams{ + Prompt: "work", + OnEvent: func(_ domain.AgentEvent) {}, + }) + if err == nil { + t.Fatal("RunTurn() error = nil, want error for concurrent turn") + } + var agentErr *domain.AgentError + if !errors.As(err, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", err) + } + if agentErr.Kind != domain.ErrResponseError { + t.Errorf("Kind = %q, want %q", agentErr.Kind, domain.ErrResponseError) + } +} + +func TestRunTurn_SessionIDMismatch(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + tmpDir := t.TempDir() + script := writeRunFixtureScript(t, tmpDir, "simple_turn.jsonl") + + session, err := a.StartSession(context.Background(), domain.StartSessionParams{ + WorkspacePath: tmpDir, + AgentConfig: domain.AgentConfig{Command: script}, + ResumeSessionID: "ses_expected", + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + events, result, runErr := collectEvents(t, a, session, "work") + if runErr == nil { + t.Fatal("RunTurn() error = nil, want session mismatch error") + } + var agentErr *domain.AgentError + if !errors.As(runErr, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", runErr) + } + if agentErr.Kind != domain.ErrResponseError { + t.Errorf("Kind = %q, want %q", agentErr.Kind, domain.ErrResponseError) + } + if result.ExitReason != domain.EventTurnEndedWithError { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnEndedWithError) + } + + var mismatchCount int + for _, event := range events { + if event.Type == domain.EventSessionStarted { + t.Fatalf("unexpected session_started event for mismatched session: %+v", event) + } + if event.Type == domain.EventTurnEndedWithError { + mismatchCount++ + if !strings.Contains(event.Message, `expected "ses_expected"`) || !strings.Contains(event.Message, `got "ses_abc123"`) { + t.Errorf("turn_ended_with_error message = %q, want mismatch details", event.Message) + } + } + } + if mismatchCount != 1 { + t.Errorf("turn_ended_with_error count = %d, want 1", mismatchCount) + } +} + +func TestStopSession_NoActiveTurn(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + tmpDir := t.TempDir() + session := mustStartSession(t, a, tmpDir, "/bin/sh") + + if err := a.StopSession(context.Background(), session); err != nil { + t.Fatalf("StopSession() error = %v, want nil", err) + } + // Double stop should also return nil. + if err := a.StopSession(context.Background(), session); err != nil { + t.Fatalf("StopSession() second call error = %v, want nil", err) + } +} + +func TestStopSession_WrongInternalType(t *testing.T) { + t.Parallel() + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := domain.Session{ + ID: "test", + Internal: "not-a-session-state", + } + + err := a.StopSession(context.Background(), session) + if err == nil { + t.Fatal("StopSession() error = nil, want error for wrong internal type") + } + var agentErr *domain.AgentError + if !errors.As(err, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", err) + } +} + +func TestRunTurn_SessionStartedOnce(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + // Write fixture to a stable path the script can cat. + fixture := loadFixture(t, "simple_turn.jsonl") + fixturePath := filepath.Join(tmpDir, "output.jsonl") + if err := os.WriteFile(fixturePath, fixture, 0o644); err != nil { + t.Fatal(err) + } + + script := writeOpenCodeScript(t, tmpDir, "cat '"+fixturePath+"'") + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := mustStartSession(t, a, tmpDir, script) + + countType := func(events []domain.AgentEvent, typ domain.AgentEventType) int { + n := 0 + for _, e := range events { + if e.Type == typ { + n++ + } + } + return n + } + + // First turn: session_started fires exactly once. + turn1Events, result1, err := collectEvents(t, a, session, "first prompt") + if err != nil { + t.Fatalf("RunTurn (turn 1) error = %v", err) + } + if result1.ExitReason != domain.EventTurnCompleted { + t.Errorf("turn 1 ExitReason = %q, want %q", result1.ExitReason, domain.EventTurnCompleted) + } + if n := countType(turn1Events, domain.EventSessionStarted); n != 1 { + t.Errorf("turn 1: session_started count = %d, want 1", n) + } + + // Second turn on the same session: session_started must not fire again. + turn2Events, result2, err := collectEvents(t, a, session, "second prompt") + if err != nil { + t.Fatalf("RunTurn (turn 2) error = %v", err) + } + if result2.ExitReason != domain.EventTurnCompleted { + t.Errorf("turn 2 ExitReason = %q, want %q", result2.ExitReason, domain.EventTurnCompleted) + } + if n := countType(turn2Events, domain.EventSessionStarted); n != 0 { + t.Errorf("turn 2: session_started count = %d, want 0 (already opened)", n) + } +} + +func TestRunTurn_LogicalFailureExitZero(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script := writeRunFixtureScript(t, tmpDir, "logical_failure_exit0.jsonl") + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := mustStartSession(t, a, tmpDir, script) + + events, result, err := collectEvents(t, a, session, "work") + if err != nil { + t.Fatalf("RunTurn() error = %v, want nil", err) + } + if result.ExitReason != domain.EventTurnFailed { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnFailed) + } + + var turnFailedCount int + for _, event := range events { + if event.Type == domain.EventTurnEndedWithError { + t.Fatalf("unexpected turn_ended_with_error event: %+v", event) + } + if event.Type == domain.EventTurnFailed { + turnFailedCount++ + } + } + if turnFailedCount != 1 { + t.Errorf("turn_failed count = %d, want 1", turnFailedCount) + } +} + +func TestRunTurn_OversizedStdoutLine(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script := writeOpenCodeScript(t, tmpDir, `head -c $((10*1024*1024+1)) /dev/zero | tr '\000' 'a' +printf '\n'`) + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := mustStartSession(t, a, tmpDir, script) + + _, result, err := collectEvents(t, a, session, "work") + if err == nil { + t.Fatal("RunTurn() error = nil, want oversized-line failure") + } + var agentErr *domain.AgentError + if !errors.As(err, &agentErr) { + t.Fatalf("error type = %T, want *domain.AgentError", err) + } + if agentErr.Kind != domain.ErrResponseError { + t.Errorf("Kind = %q, want %q", agentErr.Kind, domain.ErrResponseError) + } + if result.ExitReason != domain.EventTurnEndedWithError { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnEndedWithError) + } +} + +func TestRunTurn_EventAgentPID(t *testing.T) { + t.Parallel() + + t.Run("local_events_include_pid", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script := writeRunFixtureScript(t, tmpDir, "simple_turn.jsonl") + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := mustStartSession(t, a, tmpDir, script) + + events, result, err := collectEvents(t, a, session, "work") + if err != nil { + t.Fatalf("RunTurn() error = %v", err) + } + if result.ExitReason != domain.EventTurnCompleted { + t.Fatalf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCompleted) + } + if len(events) == 0 { + t.Fatal("events = 0, want > 0") + } + + wantPID := events[0].AgentPID + if wantPID == "" { + t.Fatal("first event AgentPID is empty, want subprocess pid") + } + for _, event := range events { + if event.AgentPID != wantPID { + t.Errorf("event %q AgentPID = %q, want %q", event.Type, event.AgentPID, wantPID) + } + } + }) + + t.Run("ssh_events_leave_pid_empty", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script := writeRunFixtureScript(t, tmpDir, "simple_turn.jsonl") + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session, err := a.StartSession(context.Background(), domain.StartSessionParams{ + WorkspacePath: tmpDir, + AgentConfig: domain.AgentConfig{Command: "opencode"}, + SSHHost: "example.test", + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + state := session.Internal.(*sessionState) + state.target.Command = script + + events, result, err := collectEvents(t, a, session, "work") + if err != nil { + t.Fatalf("RunTurn() error = %v", err) + } + if result.ExitReason != domain.EventTurnCompleted { + t.Fatalf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCompleted) + } + if len(events) == 0 { + t.Fatal("events = 0, want > 0") + } + for _, event := range events { + if event.AgentPID != "" { + t.Errorf("event %q AgentPID = %q, want empty in ssh mode", event.Type, event.AgentPID) + } + } + }) +} + +func TestRunTurn_ActivityVisibilityForStallWatchdog(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script := writeOpenCodeScript(t, tmpDir, `case "$1" in + export) echo '{"messages":[]}'; exit 0;; +esac +printf '! permission requested: external_directory (/etc/*); auto-rejecting\n' +printf '{"type":"step_start","timestamp":1000,"sessionID":"ses_visibility123","part":{"id":"p1","messageID":"m1","sessionID":"ses_visibility123","snapshot":"","type":"step-start"}}\n' +printf '{"type":"unknown_future_type","timestamp":1001,"sessionID":"ses_visibility123","data":"something"}\n' +printf '{"type":"step_finish","timestamp":1002,"sessionID":"ses_visibility123","part":{"id":"p2","messageID":"m1","sessionID":"ses_visibility123","type":"step-finish","reason":"stop"}}\n'`) + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session := mustStartSession(t, a, tmpDir, script) + + events, result, err := collectEvents(t, a, session, "work") + if err != nil { + t.Fatalf("RunTurn() error = %v", err) + } + if result.ExitReason != domain.EventTurnCompleted { + t.Fatalf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCompleted) + } + + var sawPermissionWarning bool + var sawStepStarted bool + var sawUnknownMalformed bool + var sawStepFinished bool + var sawSessionStarted bool + var sawTurnCompleted bool + + for _, event := range events { + switch event.Type { + case domain.EventNotification: + switch { + case strings.HasPrefix(event.Message, "! permission requested:"): + sawPermissionWarning = true + case event.Message == "step started": + sawStepStarted = true + case event.Message == "step finished: stop": + sawStepFinished = true + } + case domain.EventMalformed: + if strings.Contains(event.Message, "unknown event type") { + sawUnknownMalformed = true + } + case domain.EventSessionStarted: + sawSessionStarted = true + case domain.EventTurnCompleted: + sawTurnCompleted = true + } + } + + if !sawPermissionWarning { + t.Error("permission warning notification was not emitted") + } + if !sawStepStarted { + t.Error("step_start notification was not emitted") + } + if !sawUnknownMalformed { + t.Error("unknown JSON envelope did not emit malformed event") + } + if !sawStepFinished { + t.Error("step_finish notification was not emitted") + } + if !sawSessionStarted { + t.Error("session_started event was not emitted") + } + if !sawTurnCompleted { + t.Error("turn_completed event was not emitted") + } +} + +func TestRunTurn_TurnCancelledOnContextCancel(t *testing.T) { + t.Parallel() + + outerCtx, outerCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer outerCancel() + + tmpDir := t.TempDir() + // Script: emit one JSON event on a run call, then block until killed. + // Handle export subcommand immediately so queryExportUsage doesn't block. + script := writeOpenCodeScript(t, tmpDir, `case "$1" in + export) echo '{"messages":[]}'; exit 0;; +esac +printf '{"type":"step_start","timestamp":1000,"sessionID":"ses_abc123","part":{"id":"p1","messageID":"m1","sessionID":"ses_abc123","snapshot":"","type":"step-start"}}\n' +sleep 1000`) + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session, err := a.StartSession(outerCtx, domain.StartSessionParams{ + WorkspacePath: tmpDir, + AgentConfig: domain.AgentConfig{Command: script}, + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + // turnCtx is the context we'll cancel to trigger TurnCancelled. + turnCtx, turnCancel := context.WithCancel(outerCtx) + + gotEvent := make(chan struct{}, 1) + resultCh := make(chan domain.TurnResult, 1) + errCh := make(chan error, 1) + go func() { + result, runErr := a.RunTurn(turnCtx, session, domain.RunTurnParams{ + Prompt: "work", + OnEvent: func(_ domain.AgentEvent) { + select { + case gotEvent <- struct{}{}: + default: + } + }, + }) + resultCh <- result + errCh <- runErr + }() + + // Wait for the subprocess to emit the first event. + select { + case <-gotEvent: + case <-outerCtx.Done(): + t.Fatal("timed out waiting for first event") + } + + turnCancel() + + select { + case result := <-resultCh: + if result.ExitReason != domain.EventTurnCancelled { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCancelled) + } + if err := <-errCh; err != nil { + t.Errorf("RunTurn() error = %v, want nil on cancel", err) + } + case <-outerCtx.Done(): + t.Fatal("RunTurn did not return after context cancel") + } +} + +func TestRunTurn_StopSessionUnblocksReader(t *testing.T) { + t.Parallel() + + // testCtx bounds the assertion deadline; runCtx is separate so + // ctx.Done() in RunTurn's main loop doesn't race with the test's + // resultCh select. + testCtx, testCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer testCancel() + + runCtx, runCancel := context.WithCancel(context.Background()) + defer runCancel() + + tmpDir := t.TempDir() + // Script: emit one JSON event on a run call, then block until killed. + // Handle export subcommand immediately so queryExportUsage doesn't block. + script := writeOpenCodeScript(t, tmpDir, `case "$1" in + export) echo '{"messages":[]}'; exit 0;; +esac +printf '{"type":"step_start","timestamp":1000,"sessionID":"ses_abc123","part":{"id":"p1","messageID":"m1","sessionID":"ses_abc123","snapshot":"","type":"step-start"}}\n' +sleep 1000`) + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session, err := a.StartSession(testCtx, domain.StartSessionParams{ + WorkspacePath: tmpDir, + AgentConfig: domain.AgentConfig{Command: script}, + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + gotEvent := make(chan struct{}, 1) + resultCh := make(chan domain.TurnResult, 1) + go func() { + result, _ := a.RunTurn(runCtx, session, domain.RunTurnParams{ + Prompt: "work", + OnEvent: func(_ domain.AgentEvent) { + select { + case gotEvent <- struct{}{}: + default: + } + }, + }) + resultCh <- result + }() + + // Wait for the subprocess to be active. + select { + case <-gotEvent: + case <-testCtx.Done(): + t.Fatal("timed out waiting for first event") + } + + if err := a.StopSession(testCtx, session); err != nil { + t.Fatalf("StopSession() error = %v", err) + } + + select { + case result := <-resultCh: + if result.ExitReason != domain.EventTurnCancelled { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCancelled) + } + case <-testCtx.Done(): + t.Fatal("RunTurn did not return after StopSession") + } +} + +// mustMakeTempDir is a helper that returns a temporary directory path. +// Used in test table initialization where t.TempDir() cannot be called +// inside a struct literal. +func mustMakeTempDir(t *testing.T) string { + t.Helper() + return t.TempDir() +} diff --git a/internal/agent/opencode/parse.go b/internal/agent/opencode/parse.go new file mode 100644 index 00000000..ee51e1f8 --- /dev/null +++ b/internal/agent/opencode/parse.go @@ -0,0 +1,351 @@ +package opencode + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "os/exec" + "slices" + + "github.com/sortie-ai/sortie/internal/agent/sshutil" +) + +const maxLineBytes = 10 * 1024 * 1024 + +type parsedLine struct { + Event *rawRunEvent + PlainText string + Err error +} + +type rawRunEvent struct { + Type string `json:"type"` + Timestamp int64 `json:"timestamp"` + SessionID string `json:"sessionID"` + Part json.RawMessage `json:"part,omitempty"` + Error *rawRunError `json:"error,omitempty"` +} + +type rawRunError struct { + Name string `json:"name,omitempty"` + Data map[string]any `json:"data,omitempty"` +} + +type rawPartTime struct { + Start int64 `json:"start,omitempty"` + End int64 `json:"end,omitempty"` +} + +type rawStepStartPart struct { + ID string `json:"id"` + MessageID string `json:"messageID"` + SessionID string `json:"sessionID"` + Snapshot string `json:"snapshot,omitempty"` + Type string `json:"type"` +} + +type rawTextPart struct { + ID string `json:"id"` + MessageID string `json:"messageID"` + SessionID string `json:"sessionID"` + Type string `json:"type"` + Text string `json:"text"` + Time rawPartTime `json:"time,omitempty"` +} + +type rawReasoningPart struct { + ID string `json:"id"` + MessageID string `json:"messageID"` + SessionID string `json:"sessionID"` + Type string `json:"type"` + Text string `json:"text"` + Time rawPartTime `json:"time,omitempty"` +} + +type rawToolPart struct { + ID string `json:"id"` + MessageID string `json:"messageID"` + SessionID string `json:"sessionID"` + Type string `json:"type"` + Tool string `json:"tool"` + CallID string `json:"callID"` + State rawToolState `json:"state"` +} + +type rawToolState struct { + Status string `json:"status"` + Input any `json:"input,omitempty"` + Output any `json:"output,omitempty"` + Metadata any `json:"metadata,omitempty"` + Error string `json:"error,omitempty"` + Title string `json:"title,omitempty"` + Time rawPartTime `json:"time,omitempty"` +} + +type rawStepFinishPart struct { + ID string `json:"id"` + MessageID string `json:"messageID"` + SessionID string `json:"sessionID"` + Type string `json:"type"` + Reason string `json:"reason"` + Tokens *rawStepTokens `json:"tokens,omitempty"` + Cost float64 `json:"cost,omitempty"` +} + +type rawStepTokens struct { + Total int64 `json:"total,omitempty"` + Input int64 `json:"input,omitempty"` + Output int64 `json:"output,omitempty"` + Reasoning int64 `json:"reasoning,omitempty"` + Cache rawCacheUsage `json:"cache,omitempty"` +} + +type rawCacheUsage struct { + Read int64 `json:"read,omitempty"` + Write int64 `json:"write,omitempty"` +} + +type exportUsage struct { + InputTokens int64 + OutputTokens int64 + TotalTokens int64 + CacheReadTokens int64 + Model string + Cost float64 +} + +func parseRunEvent(line []byte) (rawRunEvent, error) { + var event rawRunEvent + if err := json.Unmarshal(line, &event); err != nil { + return rawRunEvent{}, fmt.Errorf("parse run event: %w", err) + } + return event, nil +} + +func parseStepStartPart(raw json.RawMessage) (rawStepStartPart, error) { + var part rawStepStartPart + if err := json.Unmarshal(raw, &part); err != nil { + return rawStepStartPart{}, err + } + return part, nil +} + +func parseTextPart(raw json.RawMessage) (rawTextPart, error) { + var part rawTextPart + if err := json.Unmarshal(raw, &part); err != nil { + return rawTextPart{}, err + } + return part, nil +} + +func parseReasoningPart(raw json.RawMessage) (rawReasoningPart, error) { + var part rawReasoningPart + if err := json.Unmarshal(raw, &part); err != nil { + return rawReasoningPart{}, err + } + return part, nil +} + +func parseToolPart(raw json.RawMessage) (rawToolPart, error) { + var part rawToolPart + if err := json.Unmarshal(raw, &part); err != nil { + return rawToolPart{}, err + } + return part, nil +} + +func parseStepFinishPart(raw json.RawMessage) (rawStepFinishPart, error) { + var part rawStepFinishPart + if err := json.Unmarshal(raw, &part); err != nil { + return rawStepFinishPart{}, err + } + return part, nil +} + +func queryExportUsage(ctx context.Context, state *sessionState) exportUsage { + sessionID := state.currentSessionID() + if sessionID == "" { + return exportUsage{} + } + + env, err := buildRunEnv(os.Environ(), state.passthrough) + if err != nil { + state.logger().Warn("failed to build opencode export environment", slog.Any("error", err)) + return exportUsage{} + } + + managedEnv, err := buildManagedEnv(state.passthrough) + if err != nil { + state.logger().Warn("failed to build opencode export command", slog.Any("error", err)) + return exportUsage{} + } + + queryCtx, cancel := context.WithTimeout(ctx, exportTimeout(state)) + defer cancel() + + exportArgs := []string{"export", "--sanitize", sessionID} + var cmd *exec.Cmd + if state.target.RemoteCommand != "" { + remoteCommand := buildSSHRemoteCommand(state.target.RemoteCommand, managedEnv) + sshArgs := sshutil.BuildSSHArgs( + state.target.SSHHost, + state.target.WorkspacePath, + remoteCommand, + exportArgs, + sshutil.SSHOptions{StrictHostKeyChecking: state.target.SSHStrictHostKeyChecking}, + ) + cmd = exec.CommandContext(queryCtx, state.target.Command, sshArgs...) //nolint:gosec // args are constructed programmatically with shell quoting + } else { + allArgs := append(slices.Clone(state.target.Args), exportArgs...) + cmd = exec.CommandContext(queryCtx, state.target.Command, allArgs...) //nolint:gosec // args are constructed programmatically + } + cmd.Dir = state.target.WorkspacePath + cmd.Env = env + + stdout, err := cmd.Output() + if err != nil { + state.logger().Warn("failed to export opencode usage", slog.Any("error", err)) + return exportUsage{} + } + + usage := parseExportOutput(stdout, sessionID) + if usage.InputTokens == 0 && usage.OutputTokens == 0 { + state.logger().Warn("no assistant token usage found in opencode export") + } + return usage +} + +// parseExportOutput extracts token usage from the JSON returned by +// opencode export, scanning messages in reverse to find the most recent +// assistant message for sessionID. Returns zero exportUsage on any parse +// failure or when no matching message is found. +func parseExportOutput(data []byte, sessionID string) exportUsage { + var payload map[string]any + if err := json.Unmarshal(data, &payload); err != nil { + return exportUsage{} + } + messages, ok := payload["messages"].([]any) + if !ok { + return exportUsage{} + } + for i := len(messages) - 1; i >= 0; i-- { + message, ok := messages[i].(map[string]any) + if !ok { + continue + } + info := mapFromAny(message["info"]) + if info == nil { + continue + } + if stringFromAny(info["role"]) != "assistant" { + continue + } + if stringFromAny(info["sessionID"]) != sessionID { + continue + } + tokens := mapFromAny(info["tokens"]) + if tokens == nil { + continue + } + inputTokens, ok := int64FromAny(tokens["input"]) + if !ok { + continue + } + outputTokens, ok := int64FromAny(tokens["output"]) + if !ok { + continue + } + totalTokens := inputTokens + outputTokens + if total, ok := int64FromAny(tokens["total"]); ok { + totalTokens = total + } + var cacheReadTokens int64 + if cache := mapFromAny(tokens["cache"]); cache != nil { + if read, ok := int64FromAny(cache["read"]); ok { + cacheReadTokens = read + } + } + var model string + providerID := stringFromAny(info["providerID"]) + modelID := stringFromAny(info["modelID"]) + if providerID != "" && modelID != "" { + model = providerID + "/" + modelID + } + cost, _ := float64FromAny(info["cost"]) + return exportUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: totalTokens, + CacheReadTokens: cacheReadTokens, + Model: model, + Cost: cost, + } + } + return exportUsage{} +} + +func mapFromAny(value any) map[string]any { + if value == nil { + return nil + } + typed, ok := value.(map[string]any) + if !ok { + return nil + } + return typed +} + +func stringFromAny(value any) string { + str, ok := value.(string) + if !ok { + return "" + } + return str +} + +func int64FromAny(value any) (int64, bool) { + switch typed := value.(type) { + case int: + return int64(typed), true + case int32: + return int64(typed), true + case int64: + return typed, true + case float64: + if typed != float64(int64(typed)) { + return 0, false + } + return int64(typed), true + case json.Number: + value, err := typed.Int64() + if err != nil { + return 0, false + } + return value, true + default: + return 0, false + } +} + +func float64FromAny(value any) (float64, bool) { + switch typed := value.(type) { + case int: + return float64(typed), true + case int32: + return float64(typed), true + case int64: + return float64(typed), true + case float64: + return typed, true + case json.Number: + value, err := typed.Float64() + if err != nil { + return 0, false + } + return value, true + default: + return 0, false + } +} diff --git a/internal/agent/opencode/parse_test.go b/internal/agent/opencode/parse_test.go new file mode 100644 index 00000000..9b50f4f8 --- /dev/null +++ b/internal/agent/opencode/parse_test.go @@ -0,0 +1,306 @@ +package opencode + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +// loadFixture reads testdata/ and returns its bytes. +func loadFixture(t *testing.T, name string) []byte { + t.Helper() + data, err := os.ReadFile(filepath.Join("testdata", name)) + if err != nil { + t.Fatalf("loadFixture(%q): %v", name, err) + } + return data +} + +// loadFixtureLine returns the zero-based line at index from a fixture file. +func loadFixtureLine(t *testing.T, name string, index int) []byte { + t.Helper() + data := loadFixture(t, name) + lines := bytes.Split(bytes.TrimRight(data, "\n"), []byte("\n")) + if index < 0 || index >= len(lines) { + t.Fatalf("loadFixtureLine(%q, %d): file has %d lines", name, index, len(lines)) + } + return lines[index] +} + +func TestParseRunEvent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fixture string + lineIdx int + wantType string + checkFunc func(t *testing.T, ev rawRunEvent) + }{ + { + name: "step_start_line", + fixture: "simple_turn.jsonl", + lineIdx: 0, + wantType: "step_start", + checkFunc: func(t *testing.T, ev rawRunEvent) { + t.Helper() + if len(ev.Part) == 0 { + t.Error("Part is empty, want non-empty") + } + part, err := parseStepStartPart(ev.Part) + if err != nil { + t.Fatalf("parseStepStartPart() error = %v", err) + } + if part.ID == "" { + t.Error("StepStartPart.ID is empty") + } + }, + }, + { + name: "text_line", + fixture: "simple_turn.jsonl", + lineIdx: 1, + wantType: "text", + checkFunc: func(t *testing.T, ev rawRunEvent) { + t.Helper() + if len(ev.Part) == 0 { + t.Error("Part is empty, want non-empty") + } + part, err := parseTextPart(ev.Part) + if err != nil { + t.Fatalf("parseTextPart() error = %v", err) + } + if part.Text == "" { + t.Error("TextPart.Text is empty") + } + }, + }, + { + name: "step_finish_line", + fixture: "simple_turn.jsonl", + lineIdx: 2, + wantType: "step_finish", + checkFunc: func(t *testing.T, ev rawRunEvent) { + t.Helper() + if len(ev.Part) == 0 { + t.Error("Part is empty, want non-empty") + } + part, err := parseStepFinishPart(ev.Part) + if err != nil { + t.Fatalf("parseStepFinishPart() error = %v", err) + } + if part.Reason != "stop" { + t.Errorf("StepFinishPart.Reason = %q, want %q", part.Reason, "stop") + } + }, + }, + { + name: "tool_use_line", + fixture: "tool_success.jsonl", + lineIdx: 1, + wantType: "tool_use", + checkFunc: func(t *testing.T, ev rawRunEvent) { + t.Helper() + if len(ev.Part) == 0 { + t.Error("Part is empty, want non-empty") + } + part, err := parseToolPart(ev.Part) + if err != nil { + t.Fatalf("parseToolPart() error = %v", err) + } + if part.Tool != "read" { + t.Errorf("ToolPart.Tool = %q, want %q", part.Tool, "read") + } + if part.State.Status != "completed" { + t.Errorf("ToolPart.State.Status = %q, want %q", part.State.Status, "completed") + } + }, + }, + { + name: "error_line", + fixture: "logical_failure_exit0.jsonl", + lineIdx: 1, + wantType: "error", + checkFunc: func(t *testing.T, ev rawRunEvent) { + t.Helper() + if ev.Error == nil { + t.Fatal("Error is nil, want non-nil") + } + if ev.Error.Name != "ProviderAuthError" { + t.Errorf("Error.Name = %q, want %q", ev.Error.Name, "ProviderAuthError") + } + if ev.Error.Data == nil { + t.Fatal("Error.Data is nil, want non-nil") + } + if msg, _ := ev.Error.Data["message"].(string); msg != "invalid api key" { + t.Errorf("Error.Data[message] = %q, want %q", msg, "invalid api key") + } + }, + }, + { + name: "unknown_type_no_error", + fixture: "malformed_event.jsonl", + lineIdx: 1, + wantType: "unknown_future_type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + line := loadFixtureLine(t, tt.fixture, tt.lineIdx) + ev, err := parseRunEvent(line) + if err != nil { + t.Fatalf("parseRunEvent() error = %v", err) + } + if ev.Type != tt.wantType { + t.Errorf("Type = %q, want %q", ev.Type, tt.wantType) + } + if ev.SessionID == "" && tt.fixture != "malformed_event.jsonl" { + t.Errorf("SessionID is empty") + } + if tt.checkFunc != nil { + tt.checkFunc(t, ev) + } + }) + } +} + +func TestParseRunEvent_InvalidJSON(t *testing.T) { + t.Parallel() + + _, err := parseRunEvent([]byte("not valid json")) + if err == nil { + t.Fatal("parseRunEvent(invalid) error = nil, want error") + } +} + +func TestScanLines(t *testing.T) { + t.Parallel() + + data := loadFixture(t, "permission_warning_then_error.txt") + lines := bytes.Split(bytes.TrimRight(data, "\n"), []byte("\n")) + if len(lines) < 2 { + t.Fatalf("fixture has %d lines, want >= 2", len(lines)) + } + + t.Run("plain_text_line_fails_json_parse", func(t *testing.T) { + t.Parallel() + + _, err := parseRunEvent(lines[0]) + if err == nil { + t.Fatal("parseRunEvent(plain text) error = nil, want error") + } + text := string(lines[0]) + if !strings.HasPrefix(text, "! permission requested:") { + t.Errorf("plain text = %q, want prefix %q", text, "! permission requested:") + } + }) + + t.Run("json_line_parsed_as_tool_use", func(t *testing.T) { + t.Parallel() + + ev, err := parseRunEvent(lines[1]) + if err != nil { + t.Fatalf("parseRunEvent(json line) error = %v", err) + } + if ev.Type != "tool_use" { + t.Errorf("Type = %q, want %q", ev.Type, "tool_use") + } + part, err := parseToolPart(ev.Part) + if err != nil { + t.Fatalf("parseToolPart() error = %v", err) + } + if part.State.Status != "error" { + t.Errorf("State.Status = %q, want %q", part.State.Status, "error") + } + }) +} + +func TestQueryExportUsage(t *testing.T) { + t.Parallel() + + t.Run("parse_usage_extracted", func(t *testing.T) { + t.Parallel() + + data := loadFixture(t, "export_usage.json") + usage := parseExportOutput(data, "ses_abc123") + + if usage.InputTokens != 1500 { + t.Errorf("InputTokens = %d, want 1500", usage.InputTokens) + } + if usage.OutputTokens != 300 { + t.Errorf("OutputTokens = %d, want 300", usage.OutputTokens) + } + if usage.TotalTokens != 1800 { + t.Errorf("TotalTokens = %d, want 1800", usage.TotalTokens) + } + if usage.CacheReadTokens != 200 { + t.Errorf("CacheReadTokens = %d, want 200", usage.CacheReadTokens) + } + if usage.Model != "anthropic/claude-sonnet-4-5" { + t.Errorf("Model = %q, want %q", usage.Model, "anthropic/claude-sonnet-4-5") + } + }) + + t.Run("parse_missing_tokens_returns_zero", func(t *testing.T) { + t.Parallel() + + data := loadFixture(t, "export_usage_missing_tokens.json") + usage := parseExportOutput(data, "ses_abc123") + + if usage.InputTokens != 0 { + t.Errorf("InputTokens = %d, want 0", usage.InputTokens) + } + if usage.OutputTokens != 0 { + t.Errorf("OutputTokens = %d, want 0", usage.OutputTokens) + } + if usage.CacheReadTokens != 0 { + t.Errorf("CacheReadTokens = %d, want 0", usage.CacheReadTokens) + } + }) + + t.Run("parse_session_id_mismatch_returns_zero", func(t *testing.T) { + t.Parallel() + + data := loadFixture(t, "export_usage.json") + usage := parseExportOutput(data, "ses_different_session") + + if usage.InputTokens != 0 { + t.Errorf("InputTokens = %d, want 0 for mismatched session", usage.InputTokens) + } + }) + + t.Run("parse_invalid_json_returns_zero", func(t *testing.T) { + t.Parallel() + + usage := parseExportOutput([]byte("not valid json"), "ses_abc123") + if usage.InputTokens != 0 || usage.OutputTokens != 0 { + t.Errorf("invalid JSON should return zero usage, got InputTokens=%d OutputTokens=%d", + usage.InputTokens, usage.OutputTokens) + } + }) + + t.Run("parse_empty_messages_returns_zero", func(t *testing.T) { + t.Parallel() + + usage := parseExportOutput([]byte(`{"messages":[]}`), "ses_abc123") + if usage.InputTokens != 0 { + t.Errorf("empty messages should return zero usage, got InputTokens=%d", usage.InputTokens) + } + }) + + t.Run("parse_user_message_skipped", func(t *testing.T) { + t.Parallel() + + // Only user message in the array; should return zero usage. + data := []byte(`{"messages":[{"info":{"role":"user","sessionID":"ses_abc123","tokens":{"input":100,"output":50}}}]}`) + usage := parseExportOutput(data, "ses_abc123") + if usage.InputTokens != 0 { + t.Errorf("user message should be skipped, got InputTokens=%d", usage.InputTokens) + } + }) +} diff --git a/internal/agent/opencode/parse_unix_test.go b/internal/agent/opencode/parse_unix_test.go new file mode 100644 index 00000000..a0fa9209 --- /dev/null +++ b/internal/agent/opencode/parse_unix_test.go @@ -0,0 +1,135 @@ +//go:build unix + +package opencode + +import ( + "context" + "log/slog" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/sortie-ai/sortie/internal/agent/agentcore" + "github.com/sortie-ai/sortie/internal/agent/agenttest" +) + +func writeExportScript(t *testing.T, dir, fixtureName string, exitCode int) (string, string) { + t.Helper() + + argsPath := filepath.Join(dir, "args.log") + body := `printf '%s\n' "$@" > '` + argsPath + `' +exit ` + strconv.Itoa(exitCode) + if fixtureName != "" && exitCode == 0 { + fixturePath := filepath.Join(dir, fixtureName) + if err := os.WriteFile(fixturePath, loadFixture(t, fixtureName), 0o644); err != nil { + t.Fatalf("WriteFile(%q): %v", fixtureName, err) + } + body = `printf '%s\n' "$@" > '` + argsPath + `' +cat '` + fixturePath + `'` + } + + return agenttest.WriteScript(t, dir, "fake-export", body), argsPath +} + +func testExportState(command, workspace string) *sessionState { + return &sessionState{ + target: agentcore.LaunchTarget{ + Command: command, + WorkspacePath: workspace, + }, + sessionID: "ses_abc123", + baseLogger: slog.Default(), + } +} + +func TestQueryExportSubprocess(t *testing.T) { + t.Parallel() + + t.Run("local_subprocess_usage_extracted", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script, argsPath := writeExportScript(t, tmpDir, "export_usage.json", 0) + state := testExportState(script, tmpDir) + + usage := queryExportUsage(context.Background(), state) + if usage.InputTokens != 1500 { + t.Errorf("InputTokens = %d, want 1500", usage.InputTokens) + } + if usage.OutputTokens != 300 { + t.Errorf("OutputTokens = %d, want 300", usage.OutputTokens) + } + if usage.CacheReadTokens != 200 { + t.Errorf("CacheReadTokens = %d, want 200", usage.CacheReadTokens) + } + + args, err := os.ReadFile(argsPath) + if err != nil { + t.Fatalf("ReadFile(args.log): %v", err) + } + if string(args) != "export\n--sanitize\nses_abc123\n" { + t.Errorf("export args = %q, want %q", string(args), "export\n--sanitize\nses_abc123\n") + } + }) + + t.Run("local_subprocess_missing_tokens_returns_zero", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script, _ := writeExportScript(t, tmpDir, "export_usage_missing_tokens.json", 0) + state := testExportState(script, tmpDir) + + usage := queryExportUsage(context.Background(), state) + if usage != (exportUsage{}) { + t.Errorf("usage = %+v, want zero value", usage) + } + }) + + t.Run("local_subprocess_nonzero_exit_returns_zero", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script, _ := writeExportScript(t, tmpDir, "", 1) + state := testExportState(script, tmpDir) + + usage := queryExportUsage(context.Background(), state) + if usage != (exportUsage{}) { + t.Errorf("usage = %+v, want zero value", usage) + } + }) + + t.Run("ssh_subprocess_usage_extracted", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + script, argsPath := writeExportScript(t, tmpDir, "export_usage.json", 0) + state := testExportState(script, tmpDir) + state.target.RemoteCommand = "opencode" + state.target.SSHHost = "example.test" + + usage := queryExportUsage(context.Background(), state) + if usage.InputTokens != 1500 { + t.Errorf("InputTokens = %d, want 1500", usage.InputTokens) + } + if usage.OutputTokens != 300 { + t.Errorf("OutputTokens = %d, want 300", usage.OutputTokens) + } + if usage.CacheReadTokens != 200 { + t.Errorf("CacheReadTokens = %d, want 200", usage.CacheReadTokens) + } + + args, err := os.ReadFile(argsPath) + if err != nil { + t.Fatalf("ReadFile(args.log): %v", err) + } + logged := string(args) + if !strings.Contains(logged, "example.test") { + t.Errorf("ssh args = %q, want host %q", logged, "example.test") + } + if !strings.Contains(logged, "export") || !strings.Contains(logged, "--sanitize") || !strings.Contains(logged, "ses_abc123") { + t.Errorf("ssh args = %q, want export invocation details", logged) + } + }) +} diff --git a/internal/agent/opencode/testdata/export_usage.json b/internal/agent/opencode/testdata/export_usage.json new file mode 100644 index 00000000..1a0e519d --- /dev/null +++ b/internal/agent/opencode/testdata/export_usage.json @@ -0,0 +1,28 @@ +{ + "messages": [ + { + "info": { + "role": "user", + "sessionID": "ses_abc123" + } + }, + { + "info": { + "role": "assistant", + "sessionID": "ses_abc123", + "providerID": "anthropic", + "modelID": "claude-sonnet-4-5", + "tokens": { + "input": 1500, + "output": 300, + "total": 1800, + "cache": { + "read": 200, + "write": 50 + } + }, + "cost": 0.05 + } + } + ] +} diff --git a/internal/agent/opencode/testdata/export_usage_missing_tokens.json b/internal/agent/opencode/testdata/export_usage_missing_tokens.json new file mode 100644 index 00000000..1c22f82a --- /dev/null +++ b/internal/agent/opencode/testdata/export_usage_missing_tokens.json @@ -0,0 +1,10 @@ +{ + "messages": [ + { + "info": { + "role": "assistant", + "sessionID": "ses_abc123" + } + } + ] +} diff --git a/internal/agent/opencode/testdata/logical_failure_exit0.jsonl b/internal/agent/opencode/testdata/logical_failure_exit0.jsonl new file mode 100644 index 00000000..690a5d72 --- /dev/null +++ b/internal/agent/opencode/testdata/logical_failure_exit0.jsonl @@ -0,0 +1,2 @@ +{"type":"step_start","timestamp":1777197598000,"sessionID":"ses_fail123","part":{"id":"prt_fail001","messageID":"msg_fail001","sessionID":"ses_fail123","snapshot":"","type":"step-start"}} +{"type":"error","timestamp":1777197598202,"sessionID":"ses_fail123","error":{"name":"ProviderAuthError","data":{"message":"invalid api key"}}} diff --git a/internal/agent/opencode/testdata/malformed_event.jsonl b/internal/agent/opencode/testdata/malformed_event.jsonl new file mode 100644 index 00000000..3026ac77 --- /dev/null +++ b/internal/agent/opencode/testdata/malformed_event.jsonl @@ -0,0 +1,2 @@ +{"type":"step_start","timestamp":1777197446593,"sessionID":"ses_abc123","part":{"id":"prt_m01","messageID":"msg_m01","sessionID":"ses_abc123","snapshot":"","type":"step-start"}} +{"type":"unknown_future_type","data":"something"} diff --git a/internal/agent/opencode/testdata/permission_warning_then_error.txt b/internal/agent/opencode/testdata/permission_warning_then_error.txt new file mode 100644 index 00000000..4bb87360 --- /dev/null +++ b/internal/agent/opencode/testdata/permission_warning_then_error.txt @@ -0,0 +1,2 @@ +! permission requested: bash; auto-rejecting +{"type":"tool_use","timestamp":1777197470000,"sessionID":"ses_abc123","part":{"id":"prt_bash_001","messageID":"msg_bash_001","sessionID":"ses_abc123","type":"tool","tool":"bash","callID":"call_bash_001","state":{"status":"error","error":"The user rejected permission to use this specific tool call.","time":{"start":1777197470000,"end":1777197470001}}}} diff --git a/internal/agent/opencode/testdata/resume_turn.jsonl b/internal/agent/opencode/testdata/resume_turn.jsonl new file mode 100644 index 00000000..b610e559 --- /dev/null +++ b/internal/agent/opencode/testdata/resume_turn.jsonl @@ -0,0 +1,3 @@ +{"type":"step_start","timestamp":1777197450000,"sessionID":"ses_abc123","part":{"id":"prt_r01","messageID":"msg_r01","sessionID":"ses_abc123","snapshot":"","type":"step-start"}} +{"type":"text","timestamp":1777197450100,"sessionID":"ses_abc123","part":{"id":"prt_r02","messageID":"msg_r01","sessionID":"ses_abc123","type":"text","text":"Resuming work.","time":{"start":1777197450050,"end":1777197450090}}} +{"type":"step_finish","timestamp":1777197450200,"sessionID":"ses_abc123","part":{"id":"prt_r03","reason":"stop","messageID":"msg_r01","sessionID":"ses_abc123","type":"step-finish","tokens":{"total":800,"input":700,"output":100},"cost":0}} diff --git a/internal/agent/opencode/testdata/simple_turn.jsonl b/internal/agent/opencode/testdata/simple_turn.jsonl new file mode 100644 index 00000000..b167a134 --- /dev/null +++ b/internal/agent/opencode/testdata/simple_turn.jsonl @@ -0,0 +1,3 @@ +{"type":"step_start","timestamp":1777197446593,"sessionID":"ses_abc123","part":{"id":"prt_001","messageID":"msg_001","sessionID":"ses_abc123","snapshot":"45865d3017876fc42b80fa16e317d109a7008c30","type":"step-start"}} +{"type":"text","timestamp":1777197446597,"sessionID":"ses_abc123","part":{"id":"prt_002","messageID":"msg_001","sessionID":"ses_abc123","type":"text","text":"Hello, world!","time":{"start":1777197446595,"end":1777197446596}}} +{"type":"step_finish","timestamp":1777197446660,"sessionID":"ses_abc123","part":{"id":"prt_003","reason":"stop","messageID":"msg_001","sessionID":"ses_abc123","type":"step-finish","tokens":{"total":16267,"input":14406,"output":21,"reasoning":0,"cache":{"write":0,"read":1840}},"cost":0}} diff --git a/internal/agent/opencode/testdata/tool_success.jsonl b/internal/agent/opencode/testdata/tool_success.jsonl new file mode 100644 index 00000000..acc29c77 --- /dev/null +++ b/internal/agent/opencode/testdata/tool_success.jsonl @@ -0,0 +1,3 @@ +{"type":"step_start","timestamp":1777197461000,"sessionID":"ses_abc123","part":{"id":"prt_101","messageID":"msg_101","sessionID":"ses_abc123","snapshot":"","type":"step-start"}} +{"type":"tool_use","timestamp":1777197461503,"sessionID":"ses_abc123","part":{"id":"prt_102","messageID":"msg_101","sessionID":"ses_abc123","type":"tool","tool":"read","callID":"call_001","state":{"status":"completed","input":{"filePath":"/home/ubuntu/work/sortie/README.md"},"output":"/home/ubuntu/work/sortie/README.md\nfile\n\nThis is a representative content body returned by the read tool. It is intentionally long enough to exceed 200 characters so tests can verify large output handling in the tool_use event parsing path without truncation.","metadata":{"preview":"test preview","truncated":false,"loaded":[]},"title":"README.md","time":{"start":1777197461489,"end":1777197461502}}}} +{"type":"step_finish","timestamp":1777197461600,"sessionID":"ses_abc123","part":{"id":"prt_103","reason":"stop","messageID":"msg_101","sessionID":"ses_abc123","type":"step-finish","tokens":{"total":1000,"input":800,"output":200},"cost":0}} diff --git a/internal/registry/adapter_meta_test.go b/internal/registry/adapter_meta_test.go index 3ca06afa..07b13d85 100644 --- a/internal/registry/adapter_meta_test.go +++ b/internal/registry/adapter_meta_test.go @@ -8,6 +8,7 @@ import ( // Trigger adapter init() registrations. _ "github.com/sortie-ai/sortie/internal/agent/claude" _ "github.com/sortie-ai/sortie/internal/agent/mock" + _ "github.com/sortie-ai/sortie/internal/agent/opencode" _ "github.com/sortie-ai/sortie/internal/scm/github" _ "github.com/sortie-ai/sortie/internal/tracker/file" _ "github.com/sortie-ai/sortie/internal/tracker/jira" @@ -75,6 +76,11 @@ func TestAdapterMeta_RealRegistrations(t *testing.T) { kind: "claude-code", wantCommand: true, }, + { + name: "opencode requires command", + kind: "opencode", + wantCommand: true, + }, { name: "mock requires nothing", kind: "mock", From 715e0337249f22a057247c13f156d71e908a74f8 Mon Sep 17 00:00:00 2001 From: Serghei Iakovlev Date: Sun, 26 Apr 2026 19:33:03 +0200 Subject: [PATCH 2/2] fix(opencode): address review feedback Fix the OpenCode session handle and teardown behavior after PR review. - return the resumed session ID instead of leaking workspace paths via Session.ID - honor StopSession(ctx) with a bounded graceful shutdown path - correct the managed-environment warning text in export usage recovery - remove the dead errors import placeholder and add regression coverage for stop deadlines --- internal/agent/opencode/command_test.go | 4 -- internal/agent/opencode/opencode.go | 44 ++++++++++----- internal/agent/opencode/opencode_test.go | 70 ++++++++++++++++++++++++ internal/agent/opencode/parse.go | 2 +- 4 files changed, 102 insertions(+), 18 deletions(-) diff --git a/internal/agent/opencode/command_test.go b/internal/agent/opencode/command_test.go index 9ff8e32d..8456f9a5 100644 --- a/internal/agent/opencode/command_test.go +++ b/internal/agent/opencode/command_test.go @@ -2,7 +2,6 @@ package opencode import ( "encoding/json" - "errors" "strings" "testing" @@ -482,6 +481,3 @@ func TestSSHRemoteCommand(t *testing.T) { } }) } - -// Compile-time interface check. -var _ = errors.New // keep errors imported if needed in future diff --git a/internal/agent/opencode/opencode.go b/internal/agent/opencode/opencode.go index 83ab6bdc..4b0cf6e1 100644 --- a/internal/agent/opencode/opencode.go +++ b/internal/agent/opencode/opencode.go @@ -100,7 +100,7 @@ func (a *OpenCodeAdapter) StartSession(_ context.Context, params domain.StartSes } return domain.Session{ - ID: params.WorkspacePath, + ID: state.sessionID, AgentPID: "", Internal: state, }, nil @@ -449,7 +449,7 @@ func (a *OpenCodeAdapter) RunTurn(ctx context.Context, session domain.Session, p } // StopSession marks the session closed and terminates any active subprocess. -func (a *OpenCodeAdapter) StopSession(_ context.Context, session domain.Session) error { +func (a *OpenCodeAdapter) StopSession(ctx context.Context, session domain.Session) error { state, ok := session.Internal.(*sessionState) if !ok { return &domain.AgentError{ @@ -461,19 +461,10 @@ func (a *OpenCodeAdapter) StopSession(_ context.Context, session domain.Session) state.mu.Lock() state.closed = true active := state.active + state.active = nil state.mu.Unlock() - if active == nil { - return nil - } - - closeStop(active) - killTurnProcess(active) - <-active.readerDone - _ = waitForProcess(active) - clearActive(state, active) - - return nil + return stopActiveTurn(ctx, active) } // EventStream returns nil because OpenCode events are delivered via the @@ -695,6 +686,33 @@ func closeStop(runtime *turnRuntime) { }) } +func stopActiveTurn(ctx context.Context, runtime *turnRuntime) error { + if runtime == nil { + return nil + } + + closeStop(runtime) + if runtime.proc == nil { + return nil + } + + _ = procutil.SignalGraceful(runtime.proc.Pid) //nolint:errcheck // best-effort signal; process may already be dead + + graceTimer := time.NewTimer(5 * time.Second) + defer stopTimer(graceTimer) + + select { + case <-runtime.waitCh: + return nil + case <-graceTimer.C: + killTurnProcess(runtime) + return nil + case <-ctx.Done(): + killTurnProcess(runtime) + return ctx.Err() + } +} + func killTurnProcess(runtime *turnRuntime) { if runtime == nil || runtime.proc == nil { return diff --git a/internal/agent/opencode/opencode_test.go b/internal/agent/opencode/opencode_test.go index 412d6e99..87c445ed 100644 --- a/internal/agent/opencode/opencode_test.go +++ b/internal/agent/opencode/opencode_test.go @@ -164,6 +164,9 @@ func TestStartSession_ResumeSession(t *testing.T) { if state.sessionID != resumeID { t.Errorf("sessionID = %q, want %q", state.sessionID, resumeID) } + if session.ID != resumeID { + t.Errorf("session.ID = %q, want %q", session.ID, resumeID) + } } func TestRunTurn_WrongInternalType(t *testing.T) { @@ -314,6 +317,73 @@ func TestStopSession_NoActiveTurn(t *testing.T) { } } +func TestStopSession_ContextDeadline(t *testing.T) { + t.Parallel() + + testCtx, testCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer testCancel() + + tmpDir := t.TempDir() + script := writeOpenCodeScript(t, tmpDir, `case "$1" in + export) echo '{"messages":[]}'; exit 0;; +esac +trap '' TERM +printf '{"type":"step_start","timestamp":1000,"sessionID":"ses_abc123","part":{"id":"p1","messageID":"m1","sessionID":"ses_abc123","snapshot":"","type":"step-start"}}\n' +while :; do sleep 1; done`) + + a, _ := NewOpenCodeAdapter(map[string]any{}) + session, err := a.StartSession(testCtx, domain.StartSessionParams{ + WorkspacePath: tmpDir, + AgentConfig: domain.AgentConfig{Command: script}, + }) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + gotEvent := make(chan struct{}, 1) + resultCh := make(chan domain.TurnResult, 1) + errCh := make(chan error, 1) + go func() { + result, runErr := a.RunTurn(context.Background(), session, domain.RunTurnParams{ + Prompt: "work", + OnEvent: func(_ domain.AgentEvent) { + select { + case gotEvent <- struct{}{}: + default: + } + }, + }) + resultCh <- result + errCh <- runErr + }() + + select { + case <-gotEvent: + case <-testCtx.Done(): + t.Fatal("timed out waiting for first event") + } + + stopCtx, stopCancel := context.WithTimeout(testCtx, 50*time.Millisecond) + defer stopCancel() + + err = a.StopSession(stopCtx, session) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("StopSession() error = %v, want %v", err, context.DeadlineExceeded) + } + + select { + case result := <-resultCh: + if result.ExitReason != domain.EventTurnCancelled { + t.Errorf("ExitReason = %q, want %q", result.ExitReason, domain.EventTurnCancelled) + } + if runErr := <-errCh; runErr != nil { + t.Errorf("RunTurn() error = %v, want nil", runErr) + } + case <-testCtx.Done(): + t.Fatal("RunTurn did not return after StopSession timeout") + } +} + func TestStopSession_WrongInternalType(t *testing.T) { t.Parallel() diff --git a/internal/agent/opencode/parse.go b/internal/agent/opencode/parse.go index ee51e1f8..5a881137 100644 --- a/internal/agent/opencode/parse.go +++ b/internal/agent/opencode/parse.go @@ -178,7 +178,7 @@ func queryExportUsage(ctx context.Context, state *sessionState) exportUsage { managedEnv, err := buildManagedEnv(state.passthrough) if err != nil { - state.logger().Warn("failed to build opencode export command", slog.Any("error", err)) + state.logger().Warn("failed to build opencode managed environment", slog.Any("error", err)) return exportUsage{} }