From f65a0d236f1904aa63532ca2a606dee23c3190db Mon Sep 17 00:00:00 2001 From: Peter Evans Date: Wed, 20 May 2026 21:28:19 -0500 Subject: [PATCH 1/5] feat: Attribute-based access control (ABAC) --- README.md | 25 +- cmd/zepctl/main.go | 5 + docs/cli-specification.md | 167 ++++++- docs/cli.mdx | 188 +++++++- go.mod | 9 +- go.sum | 38 ++ internal/abac/client.go | 248 +++++++++++ internal/abac/client_test.go | 400 +++++++++++++++++ internal/abac/types.go | 165 +++++++ internal/auth/config.go | 50 +++ internal/auth/config_test.go | 59 +++ internal/auth/idtoken.go | 38 ++ internal/auth/idtoken_test.go | 66 +++ internal/auth/oauth.go | 414 ++++++++++++++++++ internal/auth/oauth_test.go | 454 +++++++++++++++++++ internal/auth/session.go | 117 +++++ internal/auth/session_test.go | 204 +++++++++ internal/auth/tokens.go | 39 ++ internal/auth/tokens_test.go | 118 +++++ internal/cli/api_key_abac.go | 364 ++++++++++++++++ internal/cli/api_key_abac_test.go | 625 +++++++++++++++++++++++++++ internal/cli/auth.go | 305 +++++++++++++ internal/cli/auth_test.go | 520 ++++++++++++++++++++++ internal/cli/config.go | 222 ++++++++-- internal/cli/config_test.go | 469 ++++++++++++++++++++ internal/cli/edge.go | 10 +- internal/cli/environment.go | 184 ++++++++ internal/cli/environment_test.go | 316 ++++++++++++++ internal/cli/episode.go | 10 +- internal/cli/graph.go | 34 +- internal/cli/node.go | 14 +- internal/cli/observation.go | 6 +- internal/cli/ontology.go | 4 +- internal/cli/policy_set.go | 304 +++++++++++++ internal/cli/policy_set_test.go | 481 +++++++++++++++++++++ internal/cli/project.go | 2 +- internal/cli/root.go | 4 + internal/cli/set_project.go | 141 ++++++ internal/cli/set_project_test.go | 123 ++++++ internal/cli/summary_instructions.go | 8 +- internal/cli/task.go | 4 +- internal/cli/thread.go | 16 +- internal/cli/thread_summary.go | 4 +- internal/cli/user.go | 16 +- internal/client/client.go | 195 ++++++++- internal/client/client_test.go | 364 ++++++++++++++++ internal/config/config.go | 189 +++++++- internal/config/config_test.go | 543 +++++++++++++++++++++++ internal/keyring/keyring.go | 120 ++++- internal/keyring/keyring_test.go | 290 +++++++++++++ internal/output/output.go | 27 ++ 51 files changed, 8566 insertions(+), 152 deletions(-) create mode 100644 internal/abac/client.go create mode 100644 internal/abac/client_test.go create mode 100644 internal/abac/types.go create mode 100644 internal/auth/config.go create mode 100644 internal/auth/config_test.go create mode 100644 internal/auth/idtoken.go create mode 100644 internal/auth/idtoken_test.go create mode 100644 internal/auth/oauth.go create mode 100644 internal/auth/oauth_test.go create mode 100644 internal/auth/session.go create mode 100644 internal/auth/session_test.go create mode 100644 internal/auth/tokens.go create mode 100644 internal/auth/tokens_test.go create mode 100644 internal/cli/api_key_abac.go create mode 100644 internal/cli/api_key_abac_test.go create mode 100644 internal/cli/auth.go create mode 100644 internal/cli/auth_test.go create mode 100644 internal/cli/config_test.go create mode 100644 internal/cli/environment.go create mode 100644 internal/cli/environment_test.go create mode 100644 internal/cli/policy_set.go create mode 100644 internal/cli/policy_set_test.go create mode 100644 internal/cli/set_project.go create mode 100644 internal/cli/set_project_test.go create mode 100644 internal/client/client_test.go create mode 100644 internal/config/config_test.go create mode 100644 internal/keyring/keyring_test.go diff --git a/README.md b/README.md index f897913..d777291 100644 --- a/README.md +++ b/README.md @@ -35,21 +35,37 @@ zepctl user list ## Authentication -Set environment variables or use profiles: +zepctl supports two authentication modes (which coexist on the same profile): + +- **API key** -- long-lived, set via `ZEP_API_KEY` or stored per-profile in the system keychain. Used for headless / CI scenarios. +- **Bearer token** -- obtained interactively via `zepctl auth login` (OAuth / Kinde) and stored as a refresh token in the system keychain. Required for ABAC management (`policy-set`, `api-key`) and the interactive `config set-project` flow. + +```bash +# Browser-based login; auto-selects a project after authentication +zepctl auth login + +# Headless mode prints the URL instead of opening a browser +zepctl auth login --no-browser + +# Inspect credentials and bearer expiration +zepctl auth status +``` | Variable | Description | |----------|-------------| | `ZEP_API_KEY` | API key for authentication | | `ZEP_API_URL` | API endpoint (default: `https://api.getzep.com`) | | `ZEP_PROFILE` | Override current profile | +| `ZEP_PROJECT` | Override active project UUID | -Configuration file location: `~/.zepctl/config.yaml` +Configuration file location: `~/.zepctl/config.yaml`. API keys and OAuth refresh tokens are stored in the system keychain. ## Commands | Command | Description | |---------|-------------| -| `config` | Manage profiles and settings | +| `config` | Manage profiles, environment presets, and the active project | +| `auth` | Bearer-token login / logout / status | | `project` | Get project information | | `user` | Manage users | | `thread` | Manage conversation threads | @@ -62,6 +78,8 @@ Configuration file location: `~/.zepctl/config.yaml` | `task` | Monitor async operations | | `ontology` | Manage graph schema | | `summary-instructions` | Manage user summary instructions | +| `policy-set` | Manage ABAC policy sets (bearer auth) | +| `api-key` | List API keys, configure ABAC, and dry-run policy decisions (bearer auth) | ## Global Flags @@ -69,6 +87,7 @@ Configuration file location: `~/.zepctl/config.yaml` |------|-------------| | `--api-key`, `-k` | Override API key | | `--profile`, `-p` | Use specific profile | +| `--project` | Override active project UUID for this command | | `--output`, `-o` | Output format: `table`, `json`, `yaml`, `wide` | | `--help`, `-h` | Display help | diff --git a/cmd/zepctl/main.go b/cmd/zepctl/main.go index 5ec67d4..c9b5310 100644 --- a/cmd/zepctl/main.go +++ b/cmd/zepctl/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "os" "github.com/getzep/zepctl/internal/cli" @@ -8,6 +9,10 @@ import ( func main() { if err := cli.Execute(); err != nil { + var exitErr *cli.ExitCodeError + if errors.As(err, &exitErr) { + os.Exit(exitErr.Code) + } os.Exit(1) } } diff --git a/docs/cli-specification.md b/docs/cli-specification.md index b8d4b2d..95172e1 100644 --- a/docs/cli-specification.md +++ b/docs/cli-specification.md @@ -14,6 +14,13 @@ ## Authentication & Configuration +zepctl supports two authentication modes: + +- **API key** (long-lived): the original mode, used for headless / CI scenarios. Pass via `ZEP_API_KEY` or store per-profile in the system keychain. +- **Bearer token** (OAuth / Kinde): obtained interactively via `zepctl auth login`. Required for the ABAC management commands (`policy-set`, `api-key`) and the `config set-project` interactive flow. Refresh tokens are stored in the system keychain. + +Both modes coexist on the same profile. Commands that require bearer auth use the bearer token even if an API key is also present. + ### Configuration File Location: `~/.zepctl/config.yaml` @@ -24,13 +31,26 @@ profiles: - name: production # API keys are stored securely in the system keychain - name: development - api-url: https://api.dev.getzep.com # Optional: only if using non-default URL + api-url: https://api.example.com + # Optional per-profile OAuth overrides; otherwise build-time defaults apply. + oauth-issuer: https://your-tenant.kinde.com + oauth-client-id: + oauth-audience: + project-uuid: + account-uuid: +environments: + # Named presets that can be applied to profiles via `--env`. + - name: dev + api-url: https://api.example.com + oauth-issuer: https://your-tenant.kinde.com + oauth-client-id: + oauth-audience: defaults: output: table page-size: 50 ``` -**Credential Storage**: API keys are stored in the system keychain (macOS Keychain, Windows Credential Manager, or Linux Secret Service) rather than in the config file. For CI/CD environments without keychain access, use the `ZEP_API_KEY` environment variable. +**Credential Storage**: API keys and OAuth refresh tokens are stored in the system keychain (macOS Keychain, Windows Credential Manager, or Linux Secret Service) rather than in the config file. For CI/CD environments without keychain access, use the `ZEP_API_KEY` environment variable. ### Environment Variables @@ -40,17 +60,55 @@ defaults: | `ZEP_API_URL` | API endpoint URL (default: `https://api.getzep.com`) | | `ZEP_PROFILE` | Override current profile | | `ZEP_OUTPUT` | Default output format | +| `ZEP_PROJECT` | Override active project UUID | ### Configuration Commands ```bash -zepctl config use-profile # Switch active profile -zepctl config get-profiles # List all profiles -zepctl config add-profile # Add a new profile -zepctl config delete-profile # Remove a profile -zepctl config view # Display current configuration +# Profiles +zepctl config view # Display current configuration +zepctl config get-profiles # List all profiles +zepctl config use-profile # Switch active profile +zepctl config add-profile # Add a new profile (prompts for API key) +zepctl config update-profile [name] # Update fields on an existing profile +zepctl config delete-profile # Remove a profile +zepctl config set-project [uuid] # Set the active project (interactive if UUID omitted) + +# Environment presets (reusable api-url + OAuth settings) +zepctl config get-environments # List all environment presets +zepctl config add-environment # Add a named preset +zepctl config update-environment # Update fields on a preset +zepctl config delete-environment # Remove a preset ``` +#### Profile Flags + +`add-profile` and `update-profile` accept the same field flags. Only flags explicitly passed are applied; omitted flags leave existing values untouched (on update). Pass an empty string to clear a field on update. + +| Flag | Description | +|------|-------------| +| `--api-key` | API key (stored in system keychain) | +| `--api-url` | API endpoint URL | +| `--env` | Apply a named environment preset (replaces `api-url`/OAuth fields). Per-field flags override the preset. | +| `--oauth-issuer` | OIDC issuer override for `auth login` | +| `--oauth-client-id` | OAuth client ID override for `auth login` | +| `--oauth-audience` | OAuth audience for the bearer token `aud` claim | +| `--project` | Project UUID (update only) | +| `--account` | Account UUID (update only) | +| `--no-api-key` | (add only) Create a bearer-only profile with no API key, skipping the prompt | + +#### Environment Preset Flags + +Environments are reusable bundles of endpoint + OAuth settings. Profiles can adopt them via `--env`. + +| Flag | Description | +|------|-------------| +| `--api-url` | API URL for the environment | +| `--oauth-issuer` | OIDC issuer | +| `--oauth-client-id` | OAuth client ID | +| `--oauth-audience` | OAuth audience for bearer token `aud` claim | +| `--force` | (delete only) Skip confirmation prompt | + ## Global Flags | Flag | Short | Description | @@ -58,9 +116,11 @@ zepctl config view # Display current configuration | `--api-key` | `-k` | Override API key | | `--api-url` | | Override API URL | | `--profile` | `-p` | Use specific profile | +| `--project` | | Override active project UUID for this command | | `--output` | `-o` | Output format: `table`, `json`, `yaml`, `wide` | | `--quiet` | `-q` | Suppress non-essential output | | `--verbose` | `-v` | Enable verbose output | +| `--config` | | Path to config file (default: `$HOME/.zepctl/config.yaml`) | | `--help` | `-h` | Display help | | `--version` | | Display version | @@ -68,6 +128,31 @@ zepctl config view # Display current configuration ## Command Reference +### Auth Commands + +Manage bearer-token authentication for the current profile. + +```bash +zepctl auth login [--no-browser] [--env ] +zepctl auth logout +zepctl auth status +``` + +- `auth login` opens a browser window for interactive OAuth authentication and stores the resulting refresh + access tokens in the system keychain. If no profile exists yet, one is created using `--profile`/`default` and (optionally) the named environment preset. +- `--no-browser` prints the authorization URL instead of opening a browser (useful for SSH / headless sessions). +- `auth login` auto-selects a project after authentication: if the account has exactly one project, it is selected automatically; otherwise the CLI prompts for a choice. +- `auth logout` revokes the refresh token at the OAuth provider (best-effort) and clears the bearer token from the keychain. +- `auth status` displays the active profile's API URL, OIDC issuer, masked API key (if any), and bearer-token expiration. + +**Example**: +```bash +# Bootstrap an isolated dev profile with browser-based auth +zepctl --profile dev auth login --env dev +zepctl auth status +``` + +--- + ### Project Commands ```bash @@ -644,6 +729,74 @@ zepctl summary-instructions delete [flags] --- +### Policy Set Commands + +ABAC policy sets are reusable bundles of access rules attached to API keys. Policy-set commands require a bearer token (`auth login`) and operate on the currently-selected project (including `validate`, which calls the server). + +```bash +zepctl policy-set list +zepctl policy-set get +zepctl policy-set create --file +zepctl policy-set update --file +zepctl policy-set delete [--force] +zepctl policy-set validate --file +``` + +| Flag | Description | +|------|-------------| +| `--file` | (create/update/validate) Path to policy set YAML file | +| `--force` | (delete) Skip confirmation prompt | + +- `validate` exits 0 if the spec is valid, 1 if validation fails (errors printed to stderr), and 2 on client or transport errors. +- `delete` requires `--force` in non-interactive contexts; in a TTY it prompts for confirmation. +- Table output for `get` shows the policy set metadata plus the spec rendered as indented YAML. + +--- + +### API Key ABAC Commands + +Configure ABAC enforcement on individual API keys and dry-run policy decisions. All `api-key` commands require a bearer token. + +```bash +# List API keys (UUID, name, masked key, role) +zepctl api-key list + +# Per-key ABAC settings +zepctl api-key settings get +zepctl api-key settings set --mode + +# Policy set attachments +zepctl api-key policy-sets list +zepctl api-key policy-sets attach +zepctl api-key policy-sets detach + +# Dry-run policy decisions (do not perform the action) +zepctl api-key evaluate --action # decision only +zepctl api-key explain --action # decision + full evaluator trace +``` + +| Flag | Description | +|------|-------------| +| `--mode` | (`settings set`) ABAC enforcement mode: `off`, `report_only`, `enforce` | +| `--action` | (`evaluate`, `explain`) Required. Action name to evaluate (e.g. `thread.get`) | + +- `settings set` requires at least one setting flag. +- `evaluate` returns the final outcome (`allow` / `deny`), the ABAC and ABAC-shadow decisions, whether the role would have allowed the action, and whether the disagreement would be logged. +- `explain` returns everything `evaluate` does, plus the registry entry's read-only flag and a list of evaluated and skipped policy sets with per-policy match reasons. +- Both commands always exit 0 on a successful API call regardless of allow/deny outcome -- inspect the JSON or table output to check the decision. + +**Examples**: +```bash +# Switch a key to report-only mode and attach a policy set +zepctl api-key settings set abcd1234-... --mode report_only +zepctl api-key policy-sets attach abcd1234-... efgh5678-... + +# Check what the evaluator would decide for thread.get +zepctl api-key evaluate abcd1234-... --action thread.get -o json +``` + +--- + ## Scripting Examples ### Export All Users diff --git a/docs/cli.mdx b/docs/cli.mdx index 3114265..39874fb 100644 --- a/docs/cli.mdx +++ b/docs/cli.mdx @@ -40,6 +40,13 @@ zepctl user list ## Authentication +zepctl supports two authentication modes that coexist on the same profile: + +- **API key** (long-lived): pass via `ZEP_API_KEY` or store per-profile in the system keychain. Used for headless / CI scenarios. +- **Bearer token** (OAuth / Kinde): obtained interactively via `zepctl auth login` and stored as a refresh token in the system keychain. Required for the ABAC management commands (`policy-set`, `api-key`) and the `config set-project` interactive flow. + +Commands that require bearer auth use the bearer token even if an API key is also present. + ### Environment Variables | Variable | Description | @@ -48,6 +55,7 @@ zepctl user list | `ZEP_API_URL` | API endpoint URL (default: `https://api.getzep.com`) | | `ZEP_PROFILE` | Override current profile | | `ZEP_OUTPUT` | Default output format | +| `ZEP_PROJECT` | Override active project UUID | ### Configuration File @@ -59,16 +67,48 @@ profiles: - name: production # API keys are stored securely in the system keychain - name: development - api-url: https://api.dev.getzep.com # Optional: only if using non-default URL + api-url: https://api.example.com + # Optional per-profile OAuth overrides; otherwise build-time defaults apply. + oauth-issuer: https://your-tenant.kinde.com + oauth-client-id: + oauth-audience: + project-uuid: + account-uuid: +environments: + # Named presets that can be applied to profiles via `--env`. + - name: dev + api-url: https://api.example.com + oauth-issuer: https://your-tenant.kinde.com + oauth-client-id: + oauth-audience: defaults: output: table page-size: 50 ``` -API keys are stored in the system keychain (macOS Keychain, Windows Credential Manager, or Linux Secret Service) rather than in the config file. For CI/CD environments without keychain access, use the `ZEP_API_KEY` environment variable. +API keys and OAuth refresh tokens are stored in the system keychain (macOS Keychain, Windows Credential Manager, or Linux Secret Service) rather than in the config file. For CI/CD environments without keychain access, use the `ZEP_API_KEY` environment variable. +### Interactive Login + +```bash +# Open a browser to authenticate, then auto-select a project +zepctl auth login + +# Bootstrap an isolated dev profile against a configured environment preset +zepctl --profile dev auth login --env dev + +# Print the authorization URL instead of opening a browser (headless / SSH) +zepctl auth login --no-browser + +# Check current credentials and bearer token expiration +zepctl auth status + +# Revoke refresh token and clear bearer token for the current profile +zepctl auth logout +``` + ## Global Flags | Flag | Short | Description | @@ -76,34 +116,89 @@ API keys are stored in the system keychain (macOS Keychain, Windows Credential M | `--api-key` | `-k` | Override API key | | `--api-url` | | Override API URL | | `--profile` | `-p` | Use specific profile | +| `--project` | | Override active project UUID for this command | | `--output` | `-o` | Output format: `table`, `json`, `yaml`, `wide` | | `--quiet` | `-q` | Suppress non-essential output | | `--verbose` | `-v` | Enable verbose output | +| `--config` | | Path to config file (default: `$HOME/.zepctl/config.yaml`) | | `--help` | `-h` | Display help | ## Commands ### config -Manage zepctl configuration including profiles and defaults. +Manage zepctl configuration including profiles, environment presets, and the active project. ```bash -# View current configuration +# Profiles zepctl config view - -# List all profiles zepctl config get-profiles - -# Switch active profile zepctl config use-profile +zepctl config add-profile [--api-url URL] [--env ] [--no-api-key] +zepctl config update-profile [name] [flags] +zepctl config delete-profile [--force] -# Add a new profile (prompts for API key) -zepctl config add-profile [--api-url URL] +# Active project for the current profile +zepctl config set-project [uuid] # interactive if uuid omitted -# Remove a profile -zepctl config delete-profile [--force] +# Environment presets (reusable api-url + OAuth settings) +zepctl config get-environments +zepctl config add-environment [flags] +zepctl config update-environment [flags] +zepctl config delete-environment [--force] ``` +#### Profile Flags + +`add-profile` and `update-profile` accept the same field flags. On `update`, only flags explicitly passed are applied; pass an empty string to clear a field. + +| Flag | Description | +|------|-------------| +| `--api-key` | API key (stored in system keychain) | +| `--api-url` | API endpoint URL | +| `--env` | Apply a named environment preset (replaces `api-url`/OAuth fields). Per-field flags override the preset. | +| `--oauth-issuer` | OIDC issuer override for `auth login` | +| `--oauth-client-id` | OAuth client ID override for `auth login` | +| `--oauth-audience` | OAuth audience for the bearer token `aud` claim | +| `--project` | Project UUID (update only) | +| `--account` | Account UUID (update only) | +| `--no-api-key` | (`add-profile` only) Create a bearer-only profile with no API key, skipping the prompt | + +#### Environment Preset Flags + +| Flag | Description | +|------|-------------| +| `--api-url` | API URL for the environment | +| `--oauth-issuer` | OIDC issuer | +| `--oauth-client-id` | OAuth client ID | +| `--oauth-audience` | OAuth audience for bearer token `aud` claim | +| `--force` | (`delete-environment` only) Skip confirmation prompt | + +### auth + +Manage bearer-token authentication for the current profile. + +```bash +# Open a browser to authenticate, store refresh + access tokens, auto-select a project +zepctl auth login + +# Headless: print the authorization URL instead of opening a browser +zepctl auth login --no-browser + +# Create a new profile against a named environment preset on first login +zepctl --profile dev auth login --env dev + +# Revoke refresh token and clear bearer token from the keychain +zepctl auth logout + +# Show profile, API URL, OIDC issuer, masked API key, and bearer expiration +zepctl auth status +``` + + +`auth login` auto-selects a project after authentication: if the account has exactly one project, it is selected automatically; otherwise the CLI prompts for a choice. ABAC management commands (`policy-set`, `api-key`) require a valid bearer token. + + ### project Get project information. @@ -504,6 +599,75 @@ zepctl summary-instructions add --name NAME --file instructions.txt [--user USER zepctl summary-instructions delete [--force] [--user USER_IDS] ``` +### policy-set + +Manage ABAC policy sets -- reusable bundles of access rules attached to API keys. Requires a bearer token (`auth login`) and operates on the currently-selected project (including `validate`, which calls the server). + +```bash +# List policy sets in the active project +zepctl policy-set list + +# Get a single policy set (table output renders the spec as indented YAML) +zepctl policy-set get + +# Create / update from a YAML file +zepctl policy-set create --file path/to/spec.yaml +zepctl policy-set update --file path/to/spec.yaml + +# Validate a YAML spec without persisting it +zepctl policy-set validate --file path/to/spec.yaml + +# Delete (prompts unless --force; --force is required in non-interactive contexts) +zepctl policy-set delete [--force] +``` + + +`validate` exits 0 on success, 1 if validation fails (errors printed to stderr), and 2 on client or transport errors. + + +### api-key + +Configure ABAC enforcement on individual API keys and dry-run policy decisions. All `api-key` commands require a bearer token. + +```bash +# List API keys (UUID, name, masked key, role) +zepctl api-key list + +# Per-key ABAC settings +zepctl api-key settings get +zepctl api-key settings set --mode + +# Policy set attachments +zepctl api-key policy-sets list +zepctl api-key policy-sets attach +zepctl api-key policy-sets detach + +# Dry-run a policy decision against the key's live configuration +zepctl api-key evaluate --action +zepctl api-key explain --action +``` + +#### Flags + +| Flag | Description | +|------|-------------| +| `--mode` | (`settings set`) ABAC enforcement mode: `off`, `report_only`, `enforce` | +| `--action` | (`evaluate`, `explain`) Required. Action name to evaluate (e.g. `thread.get`) | + +`evaluate` returns the final outcome (`allow` / `deny`), the ABAC and ABAC-shadow decisions, whether the role would have allowed the action, and whether the disagreement would be logged. `explain` returns everything `evaluate` does plus the registry entry's read-only flag and the full evaluator trace (matched policies per set, set modes, and skipped sets). + + +Both `evaluate` and `explain` always exit 0 on a successful API call regardless of allow/deny outcome -- inspect the JSON or table output to check the decision. + + +**Example**: +```bash +# Switch a key to report-only mode, attach a policy set, then check the decision +zepctl api-key settings set abcd1234-... --mode report_only +zepctl api-key policy-sets attach abcd1234-... efgh5678-... +zepctl api-key evaluate abcd1234-... --action thread.get -o json +``` + ## Examples ### Export All Users diff --git a/go.mod b/go.mod index a6d1c60..5c8807c 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,12 @@ go 1.25.5 require ( github.com/getzep/zep-go/v3 v3.21.0 + github.com/google/uuid v1.4.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 github.com/zalando/go-keyring v0.2.6 + golang.org/x/oauth2 v0.18.0 golang.org/x/term v0.38.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -14,14 +17,16 @@ require ( require ( al.essio.dev/pkg/shellescape v1.5.1 // indirect github.com/danieljoos/wincred v1.2.2 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect - github.com/google/uuid v1.4.0 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -34,5 +39,7 @@ require ( golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.14.0 // indirect + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index f38ec3f..16b814c 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,11 @@ github.com/getzep/zep-go/v3 v3.21.0 h1:+ymrLdC8zWjUkKW3LM6U26VMtY0GdAcPGMYEsIwbS github.com/getzep/zep-go/v3 v3.21.0/go.mod h1:gTP6uw5RPlcFSs5z0pGUzhOpx8+w/S2swSc08efsSyQ= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= @@ -70,20 +75,53 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI= +golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/abac/client.go b/internal/abac/client.go new file mode 100644 index 0000000..e37c667 --- /dev/null +++ b/internal/abac/client.go @@ -0,0 +1,248 @@ +package abac + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// Client provides methods for ABAC API endpoints. The caller obtains +// httpClient from client.NewBearerHTTPClient (which handles token +// refresh) and projectUUID from config.GetProjectUUID. +type Client struct { + HTTP *http.Client + BaseURL string + ProjectUUID string +} + +// NewClient creates an ABAC client. +func NewClient(httpClient *http.Client, baseURL, projectUUID string) *Client { + return &Client{ + HTTP: httpClient, + BaseURL: strings.TrimRight(baseURL, "/"), + ProjectUUID: projectUUID, + } +} + +// --- Policy Set CRUD --- + +func (c *Client) ListPolicySets(ctx context.Context) (*PolicySetList, error) { + var result PolicySetList + if err := c.doRequest(ctx, http.MethodGet, "/api/v2/abac/policy-sets", nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) GetPolicySet(ctx context.Context, uuid string) (*PolicySet, error) { + var result PolicySet + if err := c.doRequest(ctx, http.MethodGet, "/api/v2/abac/policy-sets/"+uuid, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) CreatePolicySet(ctx context.Context, yamlContent string) (*PolicySet, error) { + body := map[string]string{"yaml": yamlContent} + var result PolicySet + if err := c.doRequest(ctx, http.MethodPost, "/api/v2/abac/policy-sets", body, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) UpdatePolicySet(ctx context.Context, uuid, yamlContent string) (*PolicySet, error) { + body := map[string]string{"yaml": yamlContent} + var result PolicySet + if err := c.doRequest(ctx, http.MethodPatch, "/api/v2/abac/policy-sets/"+uuid, body, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) DeletePolicySet(ctx context.Context, uuid string) error { + return c.doRequest(ctx, http.MethodDelete, "/api/v2/abac/policy-sets/"+uuid, nil, nil) +} + +func (c *Client) ValidatePolicySet(ctx context.Context, yamlContent string) (*ValidationResult, error) { + body := map[string]string{"yaml": yamlContent} + var result ValidationResult + if err := c.doRequest(ctx, http.MethodPost, "/api/v2/abac/policy-sets/validate", body, &result); err != nil { + return nil, err + } + return &result, nil +} + +// --- API Key Settings --- + +func (c *Client) GetAPIKeySettings(ctx context.Context, keyUUID string) (*APIKeySettings, error) { + var result APIKeySettings + path := fmt.Sprintf("/api/v2/abac/api-keys/%s/settings", keyUUID) + if err := c.doRequest(ctx, http.MethodGet, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) SetAPIKeySettings(ctx context.Context, keyUUID, mode string) (*APIKeySettings, error) { + body := map[string]string{"abac_mode": mode} + var result APIKeySettings + path := fmt.Sprintf("/api/v2/abac/api-keys/%s/settings", keyUUID) + if err := c.doRequest(ctx, http.MethodPatch, path, body, &result); err != nil { + return nil, err + } + return &result, nil +} + +// --- API Key List --- + +// ListAPIKeys returns the API keys for the current project. +func (c *Client) ListAPIKeys(ctx context.Context) (*ProjectKeysResponse, error) { + var result ProjectKeysResponse + if err := c.doRequest(ctx, http.MethodGet, "/api/v2/abac/api-keys", nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// --- API Key Policy Attachments --- + +func (c *Client) ListAPIKeyPolicySets(ctx context.Context, keyUUID string) (*AttachedPolicySets, error) { + var result AttachedPolicySets + path := fmt.Sprintf("/api/v2/abac/api-keys/%s/policy-sets", keyUUID) + if err := c.doRequest(ctx, http.MethodGet, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) AttachPolicySet(ctx context.Context, keyUUID, policySetUUID string) (*AttachedPolicySets, error) { + body := map[string]string{"policy_set_uuid": policySetUUID} + var result AttachedPolicySets + path := fmt.Sprintf("/api/v2/abac/api-keys/%s/policy-sets", keyUUID) + if err := c.doRequest(ctx, http.MethodPost, path, body, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) DetachPolicySet(ctx context.Context, keyUUID, policySetUUID string) (*AttachedPolicySets, error) { + var result AttachedPolicySets + path := fmt.Sprintf("/api/v2/abac/api-keys/%s/policy-sets/%s", keyUUID, policySetUUID) + if err := c.doRequest(ctx, http.MethodDelete, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// --- Policy Evaluation (Dry-Run) --- + +// EvaluatePolicy posts to /api/v2/abac/policies/evaluate and returns the +// concise decision plus the unmodified server body in RawJSON. +func (c *Client) EvaluatePolicy(ctx context.Context, keyUUID, action string) (*EvaluateResponse, error) { + body := map[string]string{"api_key_uuid": keyUUID, "action": action} + raw, err := c.doRequestRaw(ctx, http.MethodPost, "/api/v2/abac/policies/evaluate", body) + if err != nil { + return nil, err + } + var result EvaluateResponse + if err := json.Unmarshal(raw, &result); err != nil { + return nil, fmt.Errorf("decoding response: %w", err) + } + result.RawJSON = raw + return &result, nil +} + +// ExplainPolicy posts to /api/v2/abac/policies/explain and returns the +// decision plus a structured trace. RawJSON preserves any +// trace fields the typed shape does not yet render. +func (c *Client) ExplainPolicy(ctx context.Context, keyUUID, action string) (*ExplainResponse, error) { + body := map[string]string{"api_key_uuid": keyUUID, "action": action} + raw, err := c.doRequestRaw(ctx, http.MethodPost, "/api/v2/abac/policies/explain", body) + if err != nil { + return nil, err + } + var result ExplainResponse + if err := json.Unmarshal(raw, &result); err != nil { + return nil, fmt.Errorf("decoding response: %w", err) + } + result.RawJSON = raw + return &result, nil +} + +// --- Internal --- + +func (c *Client) doRequest(ctx context.Context, method, path string, body, result any) error { + raw, err := c.doRequestRaw(ctx, method, path, body) + if err != nil { + return err + } + if len(raw) == 0 || result == nil { + return nil + } + if err := json.Unmarshal(raw, result); err != nil { + return fmt.Errorf("decoding response: %w", err) + } + return nil +} + +func (c *Client) doRequestRaw(ctx context.Context, method, path string, body any) ([]byte, error) { + var bodyReader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("encoding request: %w", err) + } + bodyReader = bytes.NewReader(b) + } + + req, err := http.NewRequestWithContext(ctx, method, c.BaseURL+path, bodyReader) + if err != nil { + return nil, fmt.Errorf("building request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if c.ProjectUUID != "" { + req.Header.Set("X-Zep-Project", c.ProjectUUID) + } + + resp, err := c.HTTP.Do(req) + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, decodeError(resp.StatusCode, raw) + } + if resp.StatusCode == http.StatusNoContent { + return nil, nil + } + return raw, nil +} + +func decodeError(statusCode int, raw []byte) error { + var errResp struct { + Message string `json:"message"` + } + if err := json.Unmarshal(raw, &errResp); err != nil || errResp.Message == "" { + return &APIError{ + StatusCode: statusCode, + Message: fmt.Sprintf("API error (HTTP %d)", statusCode), + } + } + return &APIError{ + StatusCode: statusCode, + Message: errResp.Message, + } +} diff --git a/internal/abac/client_test.go b/internal/abac/client_test.go new file mode 100644 index 0000000..25b47f3 --- /dev/null +++ b/internal/abac/client_test.go @@ -0,0 +1,400 @@ +package abac + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestClient creates a Client pointing at the given test server. +func newTestClient(srv *httptest.Server) *Client { + return NewClient(srv.Client(), srv.URL, "test-project-uuid") +} + +// --- Policy Set CRUD --- + +func TestClient_ListPolicySets(t *testing.T) { + var gotReq *http.Request + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReq = r + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(PolicySetList{ + PolicySets: []PolicySetSummary{ + {UUID: "ps-1", Name: "policy_one", Mode: "enforce", Version: 2}, + {UUID: "ps-2", Name: "policy_two", Mode: "off", Version: 1}, + }, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).ListPolicySets(context.Background()) + require.NoError(t, err) + assert.Len(t, result.PolicySets, 2) + assert.Equal(t, "policy_one", result.PolicySets[0].Name) + assert.Equal(t, http.MethodGet, gotReq.Method) + assert.Equal(t, "/api/v2/abac/policy-sets", gotReq.URL.Path) + assert.Equal(t, "test-project-uuid", gotReq.Header.Get("X-Zep-Project")) +} + +func TestClient_GetPolicySet(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v2/abac/policy-sets/ps-uuid-1", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(PolicySet{ + UUID: "ps-uuid-1", Name: "test_policy", Mode: "report_only", Version: 3, + Spec: map[string]any{"policies": []any{}}, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).GetPolicySet(context.Background(), "ps-uuid-1") + require.NoError(t, err) + assert.Equal(t, "test_policy", result.Name) + assert.Equal(t, 3, result.Version) + assert.NotNil(t, result.Spec) +} + +func TestClient_CreatePolicySet(t *testing.T) { + var gotBody map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(PolicySet{ + UUID: "new-uuid", Name: "created", Mode: "off", Version: 1, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).CreatePolicySet(context.Background(), "policy_set:\n name: created\n mode: off\n spec: {}") + require.NoError(t, err) + assert.Equal(t, "created", result.Name) + assert.Equal(t, 1, result.Version) + assert.Contains(t, gotBody["yaml"], "policy_set:") +} + +func TestClient_UpdatePolicySet(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPatch, r.Method) + assert.Equal(t, "/api/v2/abac/policy-sets/ps-1", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(PolicySet{ + UUID: "ps-1", Name: "updated", Mode: "enforce", Version: 4, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).UpdatePolicySet(context.Background(), "ps-1", "yaml content") + require.NoError(t, err) + assert.Equal(t, 4, result.Version) +} + +func TestClient_DeletePolicySet(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodDelete, r.Method) + assert.Equal(t, "/api/v2/abac/policy-sets/ps-1", r.URL.Path) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + err := newTestClient(srv).DeletePolicySet(context.Background(), "ps-1") + require.NoError(t, err) +} + +func TestClient_ValidatePolicySet_Valid(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/api/v2/abac/policy-sets/validate", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(ValidationResult{Valid: true, Errors: []ValidationError{}}) + })) + defer srv.Close() + + result, err := newTestClient(srv).ValidatePolicySet(context.Background(), "yaml content") + require.NoError(t, err) + assert.True(t, result.Valid) + assert.Empty(t, result.Errors) +} + +func TestClient_ValidatePolicySet_Invalid(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(ValidationResult{ + Valid: false, + Errors: []ValidationError{ + {PolicyID: "bad_policy", Message: "unrecognized action"}, + }, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).ValidatePolicySet(context.Background(), "yaml content") + require.NoError(t, err) + assert.False(t, result.Valid) + assert.Len(t, result.Errors, 1) + assert.Equal(t, "bad_policy", result.Errors[0].PolicyID) +} + +// --- API Key List --- + +func TestClient_ListAPIKeys(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v2/abac/api-keys", r.URL.Path) + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "test-project-uuid", r.Header.Get("X-Zep-Project")) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(ProjectKeysResponse{ + Keys: []ProjectKey{ + {UUID: "key-1", Name: "prod", FirstFour: "z_1d", LastFour: "a4be", Role: "default_allow"}, + }, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).ListAPIKeys(context.Background()) + require.NoError(t, err) + assert.Len(t, result.Keys, 1) + assert.Equal(t, "key-1", result.Keys[0].UUID) + assert.Equal(t, "z_1d", result.Keys[0].FirstFour) + assert.Equal(t, "default_allow", result.Keys[0].Role) +} + +// --- API Key Settings --- + +func TestClient_GetAPIKeySettings(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v2/abac/api-keys/key-1/settings", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(APIKeySettings{ABACMode: "off", Capabilities: "read_write"}) + })) + defer srv.Close() + + result, err := newTestClient(srv).GetAPIKeySettings(context.Background(), "key-1") + require.NoError(t, err) + assert.Equal(t, "off", result.ABACMode) + assert.Equal(t, "read_write", result.Capabilities) +} + +func TestClient_SetAPIKeySettings(t *testing.T) { + var gotBody map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPatch, r.Method) + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(APIKeySettings{ABACMode: "enforce", Capabilities: "read_write"}) + })) + defer srv.Close() + + result, err := newTestClient(srv).SetAPIKeySettings(context.Background(), "key-1", "enforce") + require.NoError(t, err) + assert.Equal(t, "enforce", result.ABACMode) + assert.Equal(t, "enforce", gotBody["abac_mode"]) +} + +// --- API Key Policy Attachments --- + +func TestClient_ListAPIKeyPolicySets(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v2/abac/api-keys/key-1/policy-sets", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(AttachedPolicySets{ + PolicySets: []PolicySetSummary{{UUID: "ps-1", Name: "test", Mode: "enforce", Version: 1}}, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).ListAPIKeyPolicySets(context.Background(), "key-1") + require.NoError(t, err) + assert.Len(t, result.PolicySets, 1) +} + +func TestClient_AttachPolicySet(t *testing.T) { + var gotBody map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/api/v2/abac/api-keys/key-1/policy-sets", r.URL.Path) + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(AttachedPolicySets{ + PolicySets: []PolicySetSummary{{UUID: "ps-1", Name: "attached", Mode: "enforce", Version: 1}}, + }) + })) + defer srv.Close() + + result, err := newTestClient(srv).AttachPolicySet(context.Background(), "key-1", "ps-1") + require.NoError(t, err) + assert.Len(t, result.PolicySets, 1) + assert.Equal(t, "ps-1", gotBody["policy_set_uuid"]) +} + +func TestClient_DetachPolicySet(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodDelete, r.Method) + assert.Equal(t, "/api/v2/abac/api-keys/key-1/policy-sets/ps-1", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(AttachedPolicySets{PolicySets: []PolicySetSummary{}}) + })) + defer srv.Close() + + result, err := newTestClient(srv).DetachPolicySet(context.Background(), "key-1", "ps-1") + require.NoError(t, err) + assert.Empty(t, result.PolicySets) +} + +// --- Policy Evaluation (Dry-Run) --- + +func TestClient_EvaluatePolicy(t *testing.T) { + var gotBody map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/api/v2/abac/policies/evaluate", r.URL.Path) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "test-project-uuid", r.Header.Get("X-Zep-Project")) + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"outcome":"allow","abac":"ALLOW","abac_shadow":"NEUTRAL","role_allows":false,"would_log_disagreement":true}`)) + })) + defer srv.Close() + + result, err := newTestClient(srv).EvaluatePolicy(context.Background(), "key-1", "thread.get") + require.NoError(t, err) + assert.Equal(t, "allow", result.Outcome) + assert.Equal(t, "ALLOW", result.ABAC) + assert.Equal(t, "NEUTRAL", result.ABACShadow) + assert.False(t, result.RoleAllows) + assert.True(t, result.WouldLogDisagreement) + assert.Equal(t, "key-1", gotBody["api_key_uuid"]) + assert.Equal(t, "thread.get", gotBody["action"]) + // RawJSON preserves the server response body verbatim. + assert.Contains(t, string(result.RawJSON), `"outcome":"allow"`) +} + +// EvaluatePolicy preserves server-added fields the typed shape doesn't +// model, for forward compatibility. +func TestClient_EvaluatePolicy_PreservesUnknownFields(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"outcome":"deny","abac":"NEUTRAL","abac_shadow":"NEUTRAL","role_allows":false,"would_log_disagreement":false,"future_field":"surprise"}`)) + })) + defer srv.Close() + + result, err := newTestClient(srv).EvaluatePolicy(context.Background(), "key-1", "thread.delete") + require.NoError(t, err) + assert.Equal(t, "deny", result.Outcome) + assert.Contains(t, string(result.RawJSON), `"future_field":"surprise"`) +} + +func TestClient_ExplainPolicy(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/api/v2/abac/policies/explain", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "outcome": "allow", + "abac": "ALLOW", + "abac_shadow": "NEUTRAL", + "role_allows": false, + "would_log_disagreement": true, + "trace": { + "action": "thread.get", + "registry_entry": {"read_only": true}, + "role_base_covers_action": false, + "evaluated_sets": [ + {"uuid": "set-1", "name": "read_everything", "set_mode": "enforce", + "matched": [{"policy_id": "allow_reads", "effect": "allow", "matched_via": "readonly"}]} + ], + "skipped_sets": [] + } + }`)) + })) + defer srv.Close() + + result, err := newTestClient(srv).ExplainPolicy(context.Background(), "key-1", "thread.get") + require.NoError(t, err) + assert.Equal(t, "ALLOW", result.ABAC) + assert.Equal(t, "thread.get", result.Trace.Action) + assert.True(t, result.Trace.RegistryEntry.ReadOnly) + require.Len(t, result.Trace.EvaluatedSets, 1) + assert.Equal(t, "read_everything", result.Trace.EvaluatedSets[0].Name) + require.Len(t, result.Trace.EvaluatedSets[0].Matched, 1) + assert.Equal(t, "readonly", result.Trace.EvaluatedSets[0].Matched[0].MatchedVia) + assert.Empty(t, result.Trace.SkippedSets) + // role_base_covers_action is intentionally not in the typed shape but is + // retained in RawJSON. + assert.Contains(t, string(result.RawJSON), `"role_base_covers_action": false`) +} + +// --- Error Handling --- + +func TestClient_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "policy set not found"}) + })) + defer srv.Close() + + _, err := newTestClient(srv).GetPolicySet(context.Background(), "nonexistent") + require.Error(t, err) + assert.True(t, IsNotFound(err)) + assert.Contains(t, err.Error(), "policy set not found") +} + +func TestClient_Conflict(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusConflict) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "duplicate name"}) + })) + defer srv.Close() + + _, err := newTestClient(srv).CreatePolicySet(context.Background(), "yaml") + require.Error(t, err) + assert.True(t, IsConflict(err)) +} + +func TestClient_BadRequest(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "invalid yaml"}) + })) + defer srv.Close() + + _, err := newTestClient(srv).CreatePolicySet(context.Background(), "bad") + require.Error(t, err) + var apiErr *APIError + require.ErrorAs(t, err, &apiErr) + assert.Equal(t, http.StatusBadRequest, apiErr.StatusCode) +} + +func TestClient_ErrorWithoutBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + _, err := newTestClient(srv).ListPolicySets(context.Background()) + require.Error(t, err) + var apiErr *APIError + require.ErrorAs(t, err, &apiErr) + assert.Equal(t, http.StatusInternalServerError, apiErr.StatusCode) + assert.Contains(t, apiErr.Message, "HTTP 500") +} + +func TestIsNotFound_NonAPIError(t *testing.T) { + assert.False(t, IsNotFound(assert.AnError)) +} + +func TestIsConflict_NonAPIError(t *testing.T) { + assert.False(t, IsConflict(assert.AnError)) +} diff --git a/internal/abac/types.go b/internal/abac/types.go new file mode 100644 index 0000000..63ab069 --- /dev/null +++ b/internal/abac/types.go @@ -0,0 +1,165 @@ +package abac + +import ( + "encoding/json" + "errors" + "net/http" +) + +// PolicySet is the full policy set response from the API. +type PolicySet struct { + UUID string `json:"uuid"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Mode string `json:"mode"` + Version int `json:"version"` + Spec map[string]any `json:"spec,omitempty"` + ProjectUUID string `json:"project_uuid"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// PolicySetSummary is the abbreviated form returned by list endpoints. +type PolicySetSummary struct { + UUID string `json:"uuid"` + Name string `json:"name"` + Mode string `json:"mode"` + Version int `json:"version"` +} + +// PolicySetList is the response for GET /api/v2/abac/policy-sets. +type PolicySetList struct { + PolicySets []PolicySetSummary `json:"policy_sets"` +} + +// AttachedPolicySets is the response for api-key policy-set list/attach/detach. +type AttachedPolicySets struct { + PolicySets []PolicySetSummary `json:"policy_sets"` +} + +// APIKeySettings is the response for api-key settings endpoints. +type APIKeySettings struct { + ABACMode string `json:"abac_mode"` + Capabilities string `json:"capabilities"` +} + +// ValidationResult is the response from the validate endpoint. +type ValidationResult struct { + Valid bool `json:"valid"` + Errors []ValidationError `json:"errors"` +} + +// ValidationError is a single validation error from the server. +type ValidationError struct { + PolicyID string `json:"policy_id,omitempty"` + Message string `json:"message"` +} + +// ProjectKey represents an API key in the project keys list response. +type ProjectKey struct { + UUID string `json:"uuid"` + ProjectUUID string `json:"project_uuid"` + AccountUUID string `json:"account_uuid"` + Name string `json:"name"` + FirstFour string `json:"first_four"` + LastFour string `json:"last_four"` + Role string `json:"role"` +} + +// ProjectKeysResponse is the response for the list API keys endpoint. +type ProjectKeysResponse struct { + Keys []ProjectKey `json:"keys"` +} + +// EvaluateResponse is the response from POST /api/v2/abac/policies/evaluate. +// +// RawJSON is the unmodified server response body. Table output uses the +// typed fields; JSON/YAML output prints RawJSON verbatim so any +// server-added fields are surfaced without a CLI release. Set by the +// client method, not by the JSON decoder. +type EvaluateResponse struct { + Outcome string `json:"outcome"` + ABAC string `json:"abac"` + ABACShadow string `json:"abac_shadow"` + RoleAllows bool `json:"role_allows"` + WouldLogDisagreement bool `json:"would_log_disagreement"` + RawJSON json.RawMessage `json:"-"` +} + +// ExplainResponse is the response from POST /api/v2/abac/policies/explain. +// +// The trace is documentation-shaped and forward-compatible: unknown fields +// inside the trace are preserved in RawJSON and surfaced in JSON/YAML +// output even if the typed shape does not yet render them. +type ExplainResponse struct { + Outcome string `json:"outcome"` + ABAC string `json:"abac"` + ABACShadow string `json:"abac_shadow"` + RoleAllows bool `json:"role_allows"` + WouldLogDisagreement bool `json:"would_log_disagreement"` + Trace ExplainTrace `json:"trace"` + RawJSON json.RawMessage `json:"-"` +} + +// ExplainTrace carries the structured decision trace. +// +// role_base_covers_action is intentionally absent: the same boolean is +// already exposed at the top level as RoleAllows. The raw field is +// preserved in ExplainResponse.RawJSON for JSON callers. +type ExplainTrace struct { + Action string `json:"action"` + RegistryEntry ExplainRegistryEntry `json:"registry_entry"` + EvaluatedSets []ExplainEvaluatedSet `json:"evaluated_sets"` + SkippedSets []ExplainSkippedSet `json:"skipped_sets"` +} + +type ExplainRegistryEntry struct { + ReadOnly bool `json:"read_only"` +} + +type ExplainEvaluatedSet struct { + UUID string `json:"uuid"` + Name string `json:"name"` + SetMode string `json:"set_mode"` + Matched []ExplainMatched `json:"matched"` +} + +type ExplainMatched struct { + PolicyID string `json:"policy_id"` + Effect string `json:"effect"` + MatchedVia string `json:"matched_via"` +} + +type ExplainSkippedSet struct { + UUID string `json:"uuid"` + Name string `json:"name"` + Reason string `json:"reason"` +} + +// APIError is returned when the ABAC API returns a non-2xx response. +type APIError struct { + StatusCode int + Message string +} + +func (e *APIError) Error() string { + return e.Message +} + +// IsNotFound reports whether err is a 404 API error. +func IsNotFound(err error) bool { + var apiErr *APIError + if errors.As(err, &apiErr) { + return apiErr.StatusCode == http.StatusNotFound + } + return false +} + +// IsConflict reports whether err is a 409 API error. +func IsConflict(err error) bool { + var apiErr *APIError + if errors.As(err, &apiErr) { + return apiErr.StatusCode == http.StatusConflict + } + return false +} diff --git a/internal/auth/config.go b/internal/auth/config.go new file mode 100644 index 0000000..b2a08e2 --- /dev/null +++ b/internal/auth/config.go @@ -0,0 +1,50 @@ +package auth + +// Default OAuth configuration values. These can be overridden at build +// time via -ldflags for development/testing: +// +// go build -ldflags "-X github.com/getzep/zepctl/internal/auth.defaultIssuer=https://dev.kinde.com \ +// -X github.com/getzep/zepctl/internal/auth.defaultClientID=dev-client-id" +// +// They can also be overridden per-profile via the `oauth-issuer` and +// `oauth-client-id` config fields, so a single binary can authenticate +// against multiple OAuth tenants. See OAuthConfigFor. +var ( + defaultIssuer = "https://getzep.kinde.com" + defaultClientID = "8b0f41c63c6141c282c3b4dfd740708d" +) + +// DefaultOAuthConfig returns the OAuth configuration for zepctl using +// build-time defaults. Prefer OAuthConfigFor when a profile is available. +// +// The client ID is NOT a secret. It is a public OAuth 2.0 client identifier +// for a "Front-end" application, which has no client secret. It is safe to +// commit to source control and compile into the binary. See spec Sections +// 2.2 and 7.5. +// +// Audience is intentionally empty in the build-time defaults. Profiles that +// target a backend which enforces the aud claim must provide one explicitly +// via the OAuth audience profile field (or via an environment preset). +func DefaultOAuthConfig() *OAuthConfig { + return &OAuthConfig{ + Issuer: defaultIssuer, + ClientID: defaultClientID, + } +} + +// OAuthConfigFor returns the OAuth configuration for the given profile, +// using profile-level overrides where present and falling back to the +// build-time defaults. Any subset of overrides may be set. +func OAuthConfigFor(issuer, clientID, audience string) *OAuthConfig { + cfg := DefaultOAuthConfig() + if issuer != "" { + cfg.Issuer = issuer + } + if clientID != "" { + cfg.ClientID = clientID + } + if audience != "" { + cfg.Audience = audience + } + return cfg +} diff --git a/internal/auth/config_test.go b/internal/auth/config_test.go new file mode 100644 index 0000000..32caf29 --- /dev/null +++ b/internal/auth/config_test.go @@ -0,0 +1,59 @@ +package auth + +import "testing" + +func TestOAuthConfigFor_FallsBackToDefaults(t *testing.T) { + cfg := OAuthConfigFor("", "", "") + if cfg.Issuer != defaultIssuer { + t.Errorf("Issuer = %q, want default %q", cfg.Issuer, defaultIssuer) + } + if cfg.ClientID != defaultClientID { + t.Errorf("ClientID = %q, want default %q", cfg.ClientID, defaultClientID) + } + if cfg.Audience != "" { + t.Errorf("Audience = %q, want empty default", cfg.Audience) + } +} + +func TestOAuthConfigFor_OverridesIssuerOnly(t *testing.T) { + cfg := OAuthConfigFor("https://custom.kinde.com", "", "") + if cfg.Issuer != "https://custom.kinde.com" { + t.Errorf("Issuer = %q, want override", cfg.Issuer) + } + if cfg.ClientID != defaultClientID { + t.Errorf("ClientID = %q, want default %q", cfg.ClientID, defaultClientID) + } +} + +func TestOAuthConfigFor_OverridesClientIDOnly(t *testing.T) { + cfg := OAuthConfigFor("", "custom-client", "") + if cfg.Issuer != defaultIssuer { + t.Errorf("Issuer = %q, want default %q", cfg.Issuer, defaultIssuer) + } + if cfg.ClientID != "custom-client" { + t.Errorf("ClientID = %q, want override", cfg.ClientID) + } +} + +func TestOAuthConfigFor_OverridesAudienceOnly(t *testing.T) { + cfg := OAuthConfigFor("", "", "https://api.example.com") + if cfg.Issuer != defaultIssuer { + t.Errorf("Issuer = %q, want default", cfg.Issuer) + } + if cfg.Audience != "https://api.example.com" { + t.Errorf("Audience = %q, want override", cfg.Audience) + } +} + +func TestOAuthConfigFor_OverridesAll(t *testing.T) { + cfg := OAuthConfigFor("https://custom.kinde.com", "custom-client", "https://api.example.com") + if cfg.Issuer != "https://custom.kinde.com" { + t.Errorf("Issuer = %q, want override", cfg.Issuer) + } + if cfg.ClientID != "custom-client" { + t.Errorf("ClientID = %q, want override", cfg.ClientID) + } + if cfg.Audience != "https://api.example.com" { + t.Errorf("Audience = %q, want override", cfg.Audience) + } +} diff --git a/internal/auth/idtoken.go b/internal/auth/idtoken.go new file mode 100644 index 0000000..c285c5d --- /dev/null +++ b/internal/auth/idtoken.go @@ -0,0 +1,38 @@ +package auth + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" +) + +// IDTokenClaims holds the claims extracted from a Kinde ID token. +// Only the fields needed for display are included. +type IDTokenClaims struct { + Email string `json:"email"` + Name string `json:"name"` + Sub string `json:"sub"` +} + +// ParseUnverifiedIDToken extracts claims from an ID token without validating the +// signature. The ID token is used only for display (user email/name) and +// is not stored long-term. +func ParseUnverifiedIDToken(idToken string) (*IDTokenClaims, error) { + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid ID token format") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("decoding ID token payload: %w", err) + } + + var claims IDTokenClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("parsing ID token claims: %w", err) + } + + return &claims, nil +} diff --git a/internal/auth/idtoken_test.go b/internal/auth/idtoken_test.go new file mode 100644 index 0000000..53e9b18 --- /dev/null +++ b/internal/auth/idtoken_test.go @@ -0,0 +1,66 @@ +package auth + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestParseUnverifiedIDToken(t *testing.T) { + // Build a valid JWT-shaped ID token (header.payload.signature). + claims := map[string]string{ + "email": "fred@frobozz.infocom", + "name": "Fred", + "sub": "kp_user123", + } + payload, _ := json.Marshal(claims) + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + body := base64.RawURLEncoding.EncodeToString(payload) + sig := base64.RawURLEncoding.EncodeToString([]byte("signature")) + + token := header + "." + body + "." + sig + + got, err := ParseUnverifiedIDToken(token) + if err != nil { + t.Fatalf("ParseUnverifiedIDToken: %v", err) + } + if got.Email != "fred@frobozz.infocom" { + t.Errorf("Email = %q, want %q", got.Email, "fred@frobozz.infocom") + } + if got.Name != "Fred" { + t.Errorf("Name = %q, want %q", got.Name, "Fred") + } + if got.Sub != "kp_user123" { + t.Errorf("Sub = %q, want %q", got.Sub, "kp_user123") + } +} + +func TestParseUnverifiedIDToken_InvalidFormat(t *testing.T) { + _, err := ParseUnverifiedIDToken("not-a-jwt") + if err == nil { + t.Error("expected error for invalid token format") + } +} + +func TestParseUnverifiedIDToken_InvalidPayload(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("sig")) + token := header + ".!!!invalid!!!" + "." + sig + + _, err := ParseUnverifiedIDToken(token) + if err == nil { + t.Error("expected error for invalid base64 payload") + } +} + +func TestParseUnverifiedIDToken_InvalidJSON(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + body := base64.RawURLEncoding.EncodeToString([]byte(`not json`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("sig")) + token := header + "." + body + "." + sig + + _, err := ParseUnverifiedIDToken(token) + if err == nil { + t.Error("expected error for invalid JSON payload") + } +} diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go new file mode 100644 index 0000000..94758f3 --- /dev/null +++ b/internal/auth/oauth.go @@ -0,0 +1,414 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "html" + "net" + "net/http" + "net/url" + "os/exec" + "runtime" + "strings" + "time" + + "golang.org/x/oauth2" +) + +const ( + callbackTimeout = 5 * time.Minute + callbackPath = "/callback" + + // DefaultAPIURL is the default Zep API base URL. + DefaultAPIURL = "https://api.getzep.com" + + // callbackPort is the fixed port for the OAuth callback server. + // This port is registered as an allowed callback URL in Kinde + // (http://127.0.0.1:18923/callback). Using a fixed port avoids + // the need for wildcard port registration in Kinde. + callbackPort = 18923 +) + +// OAuthConfig holds the OAuth configuration for zepctl. +type OAuthConfig struct { + Issuer string // OIDC issuer, e.g. "https://myapp.kinde.com". + ClientID string // OAuth client ID (public, not a secret). + Audience string // Optional API audience. +} + +// TokenResult holds the tokens returned from a successful OAuth exchange. +type TokenResult struct { + AccessToken string + RefreshToken string + ExpiresAt time.Time + IDToken string +} + +// Login performs the Authorization Code + PKCE flow. +// If noBrowser is true, the manual code-paste flow is used. +func Login(ctx context.Context, cfg *OAuthConfig, session *KeychainSession, noBrowser bool) (*TokenResult, error) { + if noBrowser { + return loginManual(ctx, cfg, session) + } + return loginBrowser(ctx, cfg, session) +} + +// newOAuth2Config returns a standard oauth2.Config for the given OAuthConfig. +func newOAuth2Config(cfg *OAuthConfig) *oauth2.Config { + return &oauth2.Config{ + ClientID: cfg.ClientID, + RedirectURL: callbackURL(), + // Kinde uses "offline" not "offline_access". + Scopes: []string{"openid", "profile", "email", "offline"}, + Endpoint: oauth2.Endpoint{ + AuthURL: cfg.Issuer + "/oauth2/auth", + TokenURL: cfg.Issuer + "/oauth2/token", + AuthStyle: oauth2.AuthStyleInParams, + }, + } +} + +// generateAuthURL builds the authorization URL with PKCE parameters and +// stores the state and code verifier in the session. +func generateAuthURL(cfg *OAuthConfig, session *KeychainSession) string { + oauthCfg := newOAuth2Config(cfg) + + state := generateState() + _ = session.SetState(state) + + verifier := oauth2.GenerateVerifier() + _ = session.SetCodeVerifier(verifier) + + opts := []oauth2.AuthCodeOption{ + oauth2.S256ChallengeOption(verifier), + oauth2.SetAuthURLParam("is_use_auth_success_page", "true"), + } + // Kinde populates the access token's aud claim only when an audience is requested. + if cfg.Audience != "" { + opts = append(opts, oauth2.SetAuthURLParam("audience", cfg.Audience)) + } + + return oauthCfg.AuthCodeURL(state, opts...) +} + +// generateState returns a cryptographically random URL-safe string for CSRF protection. +func generateState() string { + b := make([]byte, 32) + _, _ = rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +// callbackURL returns the fixed localhost callback URL registered in Kinde. +func callbackURL() string { + return fmt.Sprintf("http://127.0.0.1:%d%s", callbackPort, callbackPath) +} + +// callbackHandler returns an HTTP handler that receives the OAuth redirect, +// exchanges the authorization code for tokens, and signals completion on doneCh. +func callbackHandler(cfg *OAuthConfig, session *KeychainSession, doneCh chan<- error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if errParam := r.URL.Query().Get("error"); errParam != "" { + desc := r.URL.Query().Get("error_description") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Authentication failed: %s -- %s", html.EscapeString(errParam), html.EscapeString(desc)) + doneCh <- fmt.Errorf("OAuth error: %s: %s", errParam, desc) + return + } + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + if err := exchangeCode(r.Context(), cfg, session, code, state); err != nil { + http.Error(w, "Authentication failed", http.StatusInternalServerError) + doneCh <- fmt.Errorf("token exchange failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + doneCh <- nil + } +} + +func loginBrowser(ctx context.Context, cfg *OAuthConfig, session *KeychainSession) (*TokenResult, error) { + doneCh := make(chan error, 1) + + lc := net.ListenConfig{} + listener, err := lc.Listen(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", callbackPort)) + if err != nil { + // Port in use -- fall back to manual flow. + return loginManual(ctx, cfg, session) + } + + mux := http.NewServeMux() + mux.HandleFunc(callbackPath, callbackHandler(cfg, session, doneCh)) + srv := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + go func() { _ = srv.Serve(listener) }() + defer func() { + shutCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = srv.Shutdown(shutCtx) + }() + + authURL := generateAuthURL(cfg, session) + + if err := openBrowser(ctx, authURL); err != nil { + // Browser failed -- fall back to manual flow. + return loginManual(ctx, cfg, session) + } + + fmt.Println("Opening browser to authenticate...") + + timeout := time.NewTimer(callbackTimeout) + defer timeout.Stop() + + select { + case err := <-doneCh: + if err != nil { + return nil, err + } + return tokenResultFromSession(session) + case <-timeout.C: + return nil, fmt.Errorf("timed out waiting for authentication (5 minutes)") + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// loginManual performs the manual code-paste flow. +func loginManual(ctx context.Context, cfg *OAuthConfig, session *KeychainSession) (*TokenResult, error) { + doneCh := make(chan error, 1) + + // Try to bind the callback port. If it fails, the flow still works + // via manual code paste. + var srv *http.Server + lc := net.ListenConfig{} + listener, lerr := lc.Listen(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", callbackPort)) + if lerr == nil { + mux := http.NewServeMux() + mux.HandleFunc(callbackPath, callbackHandler(cfg, session, doneCh)) + srv = &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + go func() { _ = srv.Serve(listener) }() + } + + authURL := generateAuthURL(cfg, session) + printManualInstructions(authURL) + + if srv != nil { + defer func() { + shutCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = srv.Shutdown(shutCtx) + }() + } + + // Read code from stdin in a goroutine -- races with the callback server. + stdinCh := make(chan codeOrError, 1) + go readCodeFromStdin(stdinCh) + + timeout := time.NewTimer(callbackTimeout) + defer timeout.Stop() + + select { + case err := <-doneCh: + if err != nil { + return nil, err + } + fmt.Println("\nCallback received -- completing authentication.") + return tokenResultFromSession(session) + case result := <-stdinCh: + if result.err != nil { + return nil, result.err + } + return exchangeManualCode(ctx, cfg, session, result.code) + case <-timeout.C: + return nil, fmt.Errorf("timed out waiting for authentication (5 minutes)") + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func printManualInstructions(authURL string) { + fmt.Println("Open this URL in your browser to authenticate:") + fmt.Println() + fmt.Println(" " + authURL) + fmt.Println() + fmt.Print("Paste the authorization code here (or wait for redirect): ") +} + +// exchangeManualCode exchanges a manually pasted authorization code for +// tokens using the session's stored state (preserving the PKCE verifier). +func exchangeManualCode(ctx context.Context, cfg *OAuthConfig, session *KeychainSession, code string) (*TokenResult, error) { + state, err := session.GetState() + if err != nil { + return nil, fmt.Errorf("reading state: %w", err) + } + + if err := exchangeCode(ctx, cfg, session, code, state); err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + + return tokenResultFromSession(session) +} + +// exchangeCode performs the Authorization Code + PKCE token exchange. +func exchangeCode(ctx context.Context, cfg *OAuthConfig, session *KeychainSession, code, receivedState string) error { + storedState, err := session.GetState() + if err != nil { + return fmt.Errorf("reading state: %w", err) + } + if storedState == "" { + return fmt.Errorf("state not found in session") + } + if storedState != receivedState { + return fmt.Errorf("state mismatch: expected %s, got %s", storedState, receivedState) + } + + codeVerifier, err := session.GetCodeVerifier() + if err != nil { + return fmt.Errorf("reading code verifier: %w", err) + } + if codeVerifier == "" { + return fmt.Errorf("code verifier not found in session") + } + + oauthCfg := newOAuth2Config(cfg) + + token, err := oauthCfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + if err != nil { + return fmt.Errorf("exchanging authorization code: %w", err) + } + + return session.SetRawToken(token) +} + +// tokenResultFromSession reads the persisted token and returns it as a TokenResult. +func tokenResultFromSession(session *KeychainSession) (*TokenResult, error) { + tok, err := session.GetRawToken() + if err != nil { + return nil, fmt.Errorf("reading token after exchange: %w", err) + } + if tok == nil { + return nil, fmt.Errorf("no token stored after exchange") + } + return &TokenResult{ + AccessToken: tok.AccessToken, + RefreshToken: tok.RefreshToken, + ExpiresAt: tok.Expiry, + IDToken: session.LastIDToken(), + }, nil +} + +// NewAutoRefreshClient creates an HTTP client that automatically refreshes +// the bearer token when it expires. Uses golang.org/x/oauth2's TokenSource +// which handles refresh transparently. +func NewAutoRefreshClient(ctx context.Context, cfg *OAuthConfig, session *KeychainSession) (*http.Client, error) { + tok, err := session.GetRawToken() + if err != nil { + return nil, fmt.Errorf("reading stored token: %w", err) + } + if tok == nil { + return nil, fmt.Errorf("no bearer token stored") + } + + oauthCfg := newOAuth2Config(cfg) + + // oauth2.Config.TokenSource returns a ReuseTokenSource that auto-refreshes. + ts := oauthCfg.TokenSource(ctx, tok) + + // Wrap to persist refreshed tokens to keychain. + pts := &persistingTokenSource{base: ts, session: session} + return oauth2.NewClient(ctx, pts), nil +} + +// persistingTokenSource wraps an oauth2.TokenSource and persists refreshed +// tokens to the keychain via the session. It tracks the last-seen token +// pointer so that it only writes to the keychain when the underlying +// ReuseTokenSource actually refreshes (returns a new pointer). +type persistingTokenSource struct { + base oauth2.TokenSource + session *KeychainSession + lastTok *oauth2.Token +} + +func (p *persistingTokenSource) Token() (*oauth2.Token, error) { + tok, err := p.base.Token() + if err != nil { + return nil, err + } + // ReuseTokenSource returns the same pointer when the cached token is + // still valid. Only persist when we get a new pointer (i.e., a refresh + // actually happened), avoiding redundant keychain reads+writes. + if tok != p.lastTok { + if err := p.session.SetRawToken(tok); err != nil { + return nil, fmt.Errorf("persisting refreshed token: %w", err) + } + p.lastTok = tok + } + return tok, nil +} + +// RevokeToken revokes a refresh token at the OAuth revocation endpoint. +// Best-effort: callers should log and continue on failure. +func RevokeToken(ctx context.Context, cfg *OAuthConfig, refreshToken string) error { + revokeURL := cfg.Issuer + "/oauth2/revoke" + data := url.Values{ + "client_id": {cfg.ClientID}, + "token": {refreshToken}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, revokeURL, strings.NewReader(data.Encode())) + if err != nil { + return fmt.Errorf("building revocation request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("sending revocation request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("revocation returned status %d", resp.StatusCode) + } + return nil +} + +type codeOrError struct { + code string + err error +} + +func readCodeFromStdin(ch chan<- codeOrError) { + var code string + if _, err := fmt.Scanln(&code); err != nil { + ch <- codeOrError{err: fmt.Errorf("reading authorization code: %w", err)} + return + } + code = strings.TrimSpace(code) + if code == "" { + ch <- codeOrError{err: fmt.Errorf("authorization code cannot be empty")} + return + } + ch <- codeOrError{code: code} +} + +func openBrowser(ctx context.Context, targetURL string) error { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.CommandContext(ctx, "open", targetURL) + case "linux": + cmd = exec.CommandContext(ctx, "xdg-open", targetURL) + case "windows": + cmd = exec.CommandContext(ctx, "rundll32", "url.dll,FileProtocolHandler", targetURL) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + return cmd.Start() +} diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go new file mode 100644 index 0000000..f6db62d --- /dev/null +++ b/internal/auth/oauth_test.go @@ -0,0 +1,454 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + gokeyring "github.com/zalando/go-keyring" + "golang.org/x/oauth2" +) + +func init() { + gokeyring.MockInit() +} + +func TestRevokeToken(t *testing.T) { + var receivedClientID, receivedToken string + revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("parsing form: %v", err) + } + receivedClientID = r.Form.Get("client_id") + receivedToken = r.Form.Get("token") + w.WriteHeader(http.StatusOK) + })) + defer revokeServer.Close() + + cfg := &OAuthConfig{ + Issuer: revokeServer.URL, + ClientID: "test-client-id", + } + + err := RevokeToken(context.Background(), cfg, "refresh_to_revoke") + if err != nil { + t.Fatalf("RevokeToken: %v", err) + } + if receivedClientID != "test-client-id" { + t.Errorf("client_id = %q, want %q", receivedClientID, "test-client-id") + } + if receivedToken != "refresh_to_revoke" { + t.Errorf("token = %q, want %q", receivedToken, "refresh_to_revoke") + } +} + +func TestRevokeToken_ServerError(t *testing.T) { + revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer revokeServer.Close() + + cfg := &OAuthConfig{ + Issuer: revokeServer.URL, + ClientID: "test-client-id", + } + + err := RevokeToken(context.Background(), cfg, "refresh_tok") + if err == nil { + t.Error("expected error for 500 response") + } +} + +func TestRevokeToken_Accepts2xxCodes(t *testing.T) { + revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer revokeServer.Close() + + cfg := &OAuthConfig{ + Issuer: revokeServer.URL, + ClientID: "test-client-id", + } + + err := RevokeToken(context.Background(), cfg, "refresh_tok") + if err != nil { + t.Errorf("RevokeToken should accept 204, got error: %v", err) + } +} + +func TestGenerateAuthURL(t *testing.T) { + session := NewKeychainSession("test-auth-url") + cfg := &OAuthConfig{ + Issuer: "https://test.kinde.com", + ClientID: "test-client-id", + } + + authURL := generateAuthURL(cfg, session) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("parsing auth URL: %v", err) + } + + if parsed.Host != "test.kinde.com" { + t.Errorf("host = %q, want %q", parsed.Host, "test.kinde.com") + } + + q := parsed.Query() + checks := map[string]string{ + "client_id": "test-client-id", + "response_type": "code", + "code_challenge_method": "S256", + "is_use_auth_success_page": "true", + } + for key, want := range checks { + if got := q.Get(key); got != want { + t.Errorf("query param %q = %q, want %q", key, got, want) + } + } + + if q.Get("code_challenge") == "" { + t.Error("code_challenge should be present (PKCE)") + } + if q.Get("state") == "" { + t.Error("state should be present") + } + + scope := q.Get("scope") + for _, s := range []string{"openid", "profile", "email", "offline"} { + if !strings.Contains(scope, s) { + t.Errorf("scope missing %q in %q", s, scope) + } + } + + // State and code verifier should be stored in session. + state, _ := session.GetState() + if state == "" { + t.Error("session state should be set after generateAuthURL") + } + verifier, _ := session.GetCodeVerifier() + if verifier == "" { + t.Error("session code verifier should be set after generateAuthURL") + } + + if got, ok := q["audience"]; ok { + t.Errorf("audience param should be absent when cfg.Audience empty, got %v", got) + } +} + +func TestGenerateAuthURL_WithAudience(t *testing.T) { + session := NewKeychainSession("test-auth-url-audience") + cfg := &OAuthConfig{ + Issuer: "https://test.kinde.com", + ClientID: "test-client-id", + Audience: "https://api.example.com/api", + } + + authURL := generateAuthURL(cfg, session) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("parsing auth URL: %v", err) + } + if got := parsed.Query().Get("audience"); got != "https://api.example.com/api" { + t.Errorf("audience param = %q, want %q", got, "https://api.example.com/api") + } +} + +func TestTokenResultFromSession(t *testing.T) { + session := NewKeychainSession("test-token-result") + expiry := time.Now().Add(1 * time.Hour) + + if err := session.SetRawToken(&oauth2.Token{ + AccessToken: "access_xyz", + RefreshToken: "refresh_xyz", + Expiry: expiry, + }); err != nil { + t.Fatalf("SetRawToken: %v", err) + } + + result, err := tokenResultFromSession(session) + if err != nil { + t.Fatalf("tokenResultFromSession: %v", err) + } + if result.AccessToken != "access_xyz" { + t.Errorf("AccessToken = %q, want %q", result.AccessToken, "access_xyz") + } + if result.RefreshToken != "refresh_xyz" { + t.Errorf("RefreshToken = %q, want %q", result.RefreshToken, "refresh_xyz") + } +} + +func TestExchangeManualCode(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "manual_access_token", + "refresh_token": "manual_refresh_token", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + cfg := &OAuthConfig{ + Issuer: tokenServer.URL, + ClientID: "test-client-id", + } + session := NewKeychainSession("test-manual-exchange") + + // Populate state and code verifier as generateAuthURL would. + _ = session.SetState("test-state") + _ = session.SetCodeVerifier("test-verifier") + + result, err := exchangeManualCode(context.Background(), cfg, session, "test-auth-code") + if err != nil { + t.Fatalf("exchangeManualCode: %v", err) + } + if result.AccessToken != "manual_access_token" { + t.Errorf("AccessToken = %q, want %q", result.AccessToken, "manual_access_token") + } +} + +func TestExchangeManualCode_ServerError(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "authorization code expired", + }) + })) + defer tokenServer.Close() + + cfg := &OAuthConfig{ + Issuer: tokenServer.URL, + ClientID: "test-client-id", + } + session := NewKeychainSession("test-exchange-error") + + _ = session.SetState("test-state") + _ = session.SetCodeVerifier("test-verifier") + + _, err := exchangeManualCode(context.Background(), cfg, session, "expired-code") + if err == nil { + t.Fatal("expected error for expired authorization code") + } + errStr := err.Error() + if !strings.Contains(errStr, "token exchange failed") { + t.Errorf("error = %q, want to contain 'token exchange failed'", errStr) + } +} + +func TestExchangeManualCode_CapturesIDToken(t *testing.T) { + // Create a minimal JWT-shaped ID token for testing. + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte( + `{"email":"test@example.com","name":"Test User","sub":"user-123"}`, + )) + testIDToken := header + "." + payload + ".test-signature" + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access_tok", + "refresh_token": "refresh_tok", + "token_type": "Bearer", + "expires_in": 3600, + "id_token": testIDToken, + }) + })) + defer tokenServer.Close() + + cfg := &OAuthConfig{ + Issuer: tokenServer.URL, + ClientID: "test-client-id", + } + session := NewKeychainSession("test-idtoken-capture") + + _ = session.SetState("test-state") + _ = session.SetCodeVerifier("test-verifier") + + result, err := exchangeManualCode(context.Background(), cfg, session, "test-code") + if err != nil { + t.Fatalf("exchangeManualCode: %v", err) + } + + if result.IDToken != testIDToken { + t.Errorf("IDToken not captured; got %q", result.IDToken) + } + if session.LastIDToken() != testIDToken { + t.Errorf("session.LastIDToken() = %q, want test ID token", session.LastIDToken()) + } +} + +func TestExchangeCode_StateMismatch(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // The server should never be reached -- state validation happens + // client-side before the token exchange request is sent. + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "should_not_get_this", + "refresh_token": "should_not_get_this", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + cfg := &OAuthConfig{ + Issuer: tokenServer.URL, + ClientID: "test-client-id", + } + session := NewKeychainSession("test-state-mismatch") + + _ = session.SetState("correct-state") + _ = session.SetCodeVerifier("test-verifier") + + // Exchange with a wrong state value -- should be rejected. + err := exchangeCode(context.Background(), cfg, session, "test-auth-code", "wrong-state-value") + if err == nil { + t.Fatal("expected error when state does not match") + } + if !strings.Contains(err.Error(), "state mismatch") { + t.Errorf("error = %q, want to contain 'state mismatch'", err.Error()) + } + + // The token should NOT have been persisted. + tok, _ := session.GetRawToken() + if tok != nil && tok.AccessToken == "should_not_get_this" { + t.Error("token should not be stored after state mismatch") + } +} + +func TestRevokeToken_NetworkError(t *testing.T) { + cfg := &OAuthConfig{ + Issuer: "http://192.0.2.1", // TEST-NET-1, non-routable + ClientID: "test-client-id", + } + // Use a canceled context to guarantee fast failure. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := RevokeToken(ctx, cfg, "refresh_tok") + if err == nil { + t.Error("expected error for network failure") + } +} + +func TestGenerateState(t *testing.T) { + s1 := generateState() + s2 := generateState() + + if s1 == "" { + t.Error("state should not be empty") + } + if s1 == s2 { + t.Error("consecutive states should differ") + } + if len(s1) < 20 { + t.Errorf("state too short: %d chars", len(s1)) + } +} + +func TestPersistingTokenSource_SkipsRedundantWrites(t *testing.T) { + session := NewKeychainSession("test-persisting-ts") + + validTok := &oauth2.Token{ + AccessToken: "access_valid", + RefreshToken: "refresh_valid", + Expiry: time.Now().Add(1 * time.Hour), + TokenType: "Bearer", + } + + // fakeSource always returns the same pointer (simulating ReuseTokenSource + // returning its cached token). + calls := 0 + fake := tokenSourceFunc(func() (*oauth2.Token, error) { + calls++ + return validTok, nil + }) + + pts := &persistingTokenSource{base: fake, session: session} + + // First call: should persist (new pointer). + tok1, err := pts.Token() + if err != nil { + t.Fatalf("Token() #1: %v", err) + } + if tok1 != validTok { + t.Error("expected same pointer from first call") + } + + // Second call: same pointer, should skip keychain write. + // We verify by checking that the session still works, and that the + // optimization doesn't break anything. + tok2, err := pts.Token() + if err != nil { + t.Fatalf("Token() #2: %v", err) + } + if tok2 != tok1 { + t.Error("expected same pointer from second call") + } + if calls != 2 { + t.Errorf("underlying source called %d times, want 2", calls) + } + + // Simulate a refresh: underlying source returns a new pointer. + refreshedTok := &oauth2.Token{ + AccessToken: "access_refreshed", + RefreshToken: "refresh_refreshed", + Expiry: time.Now().Add(1 * time.Hour), + TokenType: "Bearer", + } + fake = tokenSourceFunc(func() (*oauth2.Token, error) { + calls++ + return refreshedTok, nil + }) + pts.base = fake + + tok3, err := pts.Token() + if err != nil { + t.Fatalf("Token() #3: %v", err) + } + if tok3 != refreshedTok { + t.Error("expected refreshed pointer") + } + + // Verify the refreshed token was persisted. + stored, err := session.GetRawToken() + if err != nil { + t.Fatalf("GetRawToken: %v", err) + } + if stored.AccessToken != "access_refreshed" { + t.Errorf("persisted AccessToken = %q, want %q", stored.AccessToken, "access_refreshed") + } +} + +// tokenSourceFunc is a helper that adapts a function to oauth2.TokenSource. +type tokenSourceFunc func() (*oauth2.Token, error) + +func (f tokenSourceFunc) Token() (*oauth2.Token, error) { return f() } + +func TestNewOAuth2Config(t *testing.T) { + cfg := &OAuthConfig{ + Issuer: "https://test.kinde.com", + ClientID: "test-client-id", + } + + oauthCfg := newOAuth2Config(cfg) + + if oauthCfg.ClientID != "test-client-id" { + t.Errorf("ClientID = %q, want %q", oauthCfg.ClientID, "test-client-id") + } + if oauthCfg.Endpoint.AuthURL != "https://test.kinde.com/oauth2/auth" { + t.Errorf("AuthURL = %q", oauthCfg.Endpoint.AuthURL) + } + if oauthCfg.Endpoint.TokenURL != "https://test.kinde.com/oauth2/token" { + t.Errorf("TokenURL = %q", oauthCfg.Endpoint.TokenURL) + } +} diff --git a/internal/auth/session.go b/internal/auth/session.go new file mode 100644 index 0000000..8cfd822 --- /dev/null +++ b/internal/auth/session.go @@ -0,0 +1,117 @@ +package auth + +import ( + "fmt" + "time" + + "golang.org/x/oauth2" + + "github.com/getzep/zepctl/internal/keyring" +) + +const bearerTokenType = "Bearer" + +// KeychainSession stores OAuth tokens in the OS keychain and keeps +// ephemeral OAuth state (CSRF state, PKCE code verifier) in memory. +type KeychainSession struct { + profile string + state string + codeVerifier string + // lastIDToken captures the ID token from the most recent SetRawToken + // call. The standard oauth2.Token struct has no dedicated field for + // it, but the token response includes it in Extra("id_token"). + lastIDToken string +} + +// NewKeychainSession creates a session for the given profile. Tokens are +// persisted to the OS keychain; OAuth flow state is held in memory. +func NewKeychainSession(profile string) *KeychainSession { + return &KeychainSession{profile: profile} +} + +// SetRawToken persists an oauth2.Token to the keychain, preserving any +// existing API key in the profile's credential entry. Also captures the +// ID token (if present) for later extraction by the caller. +func (s *KeychainSession) SetRawToken(token *oauth2.Token) error { + creds, err := keyring.GetCredentials(s.profile) + if err != nil { + creds = &keyring.Credentials{} + } + + if token == nil { + // Clear bearer fields but preserve any existing API key. + creds.AccessToken = "" + creds.RefreshToken = "" + creds.ExpiresAt = "" + return keyring.SetCredentials(s.profile, creds) + } + + creds.AccessToken = token.AccessToken + creds.RefreshToken = token.RefreshToken + if !token.Expiry.IsZero() { + creds.ExpiresAt = token.Expiry.Format(time.RFC3339) + } + + // Capture the ID token from the Extra fields. + if idToken, ok := token.Extra("id_token").(string); ok { + s.lastIDToken = idToken + } + + return keyring.SetCredentials(s.profile, creds) +} + +// LastIDToken returns the ID token captured from the most recent +// SetRawToken call. Returns empty string if no ID token was present. +func (s *KeychainSession) LastIDToken() string { + return s.lastIDToken +} + +// GetRawToken reads the stored oauth2.Token from the keychain. +// Returns nil (not an error) if no bearer token is stored. +func (s *KeychainSession) GetRawToken() (*oauth2.Token, error) { + creds, err := keyring.GetCredentials(s.profile) + if err != nil { + return nil, fmt.Errorf("reading credentials: %w", err) + } + + if !creds.HasBearerToken() { + return nil, nil + } + + tok := &oauth2.Token{ + AccessToken: creds.AccessToken, + RefreshToken: creds.RefreshToken, + TokenType: bearerTokenType, + } + + if creds.ExpiresAt != "" { + t, err := time.Parse(time.RFC3339, creds.ExpiresAt) + if err == nil { + tok.Expiry = t + } + } + + return tok, nil +} + +// SetState stores the CSRF state parameter for the current OAuth flow. +func (s *KeychainSession) SetState(state string) error { + s.state = state + return nil +} + +// GetState returns the stored CSRF state parameter. +func (s *KeychainSession) GetState() (string, error) { + return s.state, nil +} + +// SetCodeVerifier stores the PKCE code verifier for the current OAuth flow. +func (s *KeychainSession) SetCodeVerifier(codeVerifier string) error { + s.codeVerifier = codeVerifier + return nil +} + +// GetCodeVerifier returns the stored PKCE code verifier. +func (s *KeychainSession) GetCodeVerifier() (string, error) { + return s.codeVerifier, nil +} diff --git a/internal/auth/session_test.go b/internal/auth/session_test.go new file mode 100644 index 0000000..7953984 --- /dev/null +++ b/internal/auth/session_test.go @@ -0,0 +1,204 @@ +package auth + +import ( + "testing" + "time" + + "github.com/getzep/zepctl/internal/keyring" + gokeyring "github.com/zalando/go-keyring" + "golang.org/x/oauth2" +) + +func init() { + gokeyring.MockInit() +} + +func TestKeychainSession_SetGetRawToken(t *testing.T) { + session := NewKeychainSession("test-session-set-get") + expiry := time.Now().Add(1 * time.Hour).Truncate(time.Second) + + err := session.SetRawToken(&oauth2.Token{ + AccessToken: "access_123", + RefreshToken: "refresh_456", + TokenType: "Bearer", + Expiry: expiry, + }) + if err != nil { + t.Fatalf("SetRawToken: %v", err) + } + + tok, err := session.GetRawToken() + if err != nil { + t.Fatalf("GetRawToken: %v", err) + } + if tok == nil { + t.Fatal("GetRawToken returned nil") + } + if tok.AccessToken != "access_123" { + t.Errorf("AccessToken = %q, want %q", tok.AccessToken, "access_123") + } + if tok.RefreshToken != "refresh_456" { + t.Errorf("RefreshToken = %q, want %q", tok.RefreshToken, "refresh_456") + } + if tok.TokenType != "Bearer" { + t.Errorf("TokenType = %q, want %q", tok.TokenType, "Bearer") + } + if !tok.Expiry.Equal(expiry) { + t.Errorf("Expiry = %v, want %v", tok.Expiry, expiry) + } +} + +func TestKeychainSession_PreservesAPIKey(t *testing.T) { + profile := "test-session-preserve-key" + creds := &keyring.Credentials{APIKey: "z_my_api_key"} + if err := keyring.SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + session := NewKeychainSession(profile) + err := session.SetRawToken(&oauth2.Token{ + AccessToken: "bearer_tok", + RefreshToken: "refresh_tok", + }) + if err != nil { + t.Fatalf("SetRawToken: %v", err) + } + + got, err := keyring.GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if got.APIKey != "z_my_api_key" { + t.Errorf("API key should be preserved, got %q", got.APIKey) + } + if got.AccessToken != "bearer_tok" { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, "bearer_tok") + } +} + +func TestKeychainSession_SetRawToken_NilPreservesAPIKey(t *testing.T) { + profile := "test-session-nil-token" + creds := &keyring.Credentials{ + APIKey: "z_my_api_key", + AccessToken: "old_access", + RefreshToken: "old_refresh", + ExpiresAt: "2026-04-20T15:30:00Z", + } + if err := keyring.SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + session := NewKeychainSession(profile) + if err := session.SetRawToken(nil); err != nil { + t.Fatalf("SetRawToken(nil): %v", err) + } + + got, err := keyring.GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if got.APIKey != "z_my_api_key" { + t.Errorf("API key should be preserved, got %q", got.APIKey) + } + if got.HasBearerToken() { + t.Errorf("bearer token should be cleared, got AccessToken=%q", got.AccessToken) + } + if got.ExpiresAt != "" { + t.Errorf("ExpiresAt should be cleared, got %q", got.ExpiresAt) + } +} + +func TestKeychainSession_GetRawToken_NoBearerToken(t *testing.T) { + profile := "test-session-no-bearer" + creds := &keyring.Credentials{APIKey: "z_key_only"} + if err := keyring.SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + session := NewKeychainSession(profile) + tok, err := session.GetRawToken() + if err != nil { + t.Fatalf("GetRawToken: %v", err) + } + if tok != nil { + t.Errorf("expected nil token for profile without bearer token, got %+v", tok) + } +} + +func TestKeychainSession_EphemeralState(t *testing.T) { + session := NewKeychainSession("test-session-state") + + // State starts empty. + state, err := session.GetState() + if err != nil { + t.Fatalf("GetState: %v", err) + } + if state != "" { + t.Errorf("initial state = %q, want empty", state) + } + + if err := session.SetState("csrf_state_123"); err != nil { + t.Fatalf("SetState: %v", err) + } + state, _ = session.GetState() + if state != "csrf_state_123" { + t.Errorf("state = %q, want %q", state, "csrf_state_123") + } +} + +func TestKeychainSession_EphemeralCodeVerifier(t *testing.T) { + session := NewKeychainSession("test-session-verifier") + + if err := session.SetCodeVerifier("verifier_abc"); err != nil { + t.Fatalf("SetCodeVerifier: %v", err) + } + v, _ := session.GetCodeVerifier() + if v != "verifier_abc" { + t.Errorf("code verifier = %q, want %q", v, "verifier_abc") + } +} + +// TestKeychainSession_SetRawToken_ReplacesExisting verifies that storing a +// new token fully replaces the previous one (refresh token rotation). +func TestKeychainSession_SetRawToken_ReplacesExisting(t *testing.T) { + session := NewKeychainSession("test-session-replace") + + // Store first token. + if err := session.SetRawToken(&oauth2.Token{ + AccessToken: "first_access", + RefreshToken: "first_refresh", + }); err != nil { + t.Fatalf("SetRawToken (first): %v", err) + } + + // Store second token (simulates refresh rotation). + if err := session.SetRawToken(&oauth2.Token{ + AccessToken: "second_access", + RefreshToken: "second_refresh", + }); err != nil { + t.Fatalf("SetRawToken (second): %v", err) + } + + tok, err := session.GetRawToken() + if err != nil { + t.Fatalf("GetRawToken: %v", err) + } + if tok.AccessToken != "second_access" { + t.Errorf("AccessToken = %q, want %q", tok.AccessToken, "second_access") + } + if tok.RefreshToken != "second_refresh" { + t.Errorf("RefreshToken = %q, want %q", tok.RefreshToken, "second_refresh") + } + + // Verify the first token is fully gone from the keychain. + creds, err := keyring.GetCredentials("test-session-replace") + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if creds.AccessToken == "first_access" { + t.Error("first access token should have been replaced") + } + if creds.RefreshToken == "first_refresh" { + t.Error("first refresh token should have been replaced") + } +} diff --git a/internal/auth/tokens.go b/internal/auth/tokens.go new file mode 100644 index 0000000..9d4f8d4 --- /dev/null +++ b/internal/auth/tokens.go @@ -0,0 +1,39 @@ +package auth + +import ( + "time" + + "github.com/getzep/zepctl/internal/keyring" +) + +// StoreBearerToken saves a new bearer token to the profile's keychain entry. +// Preserves any existing API key. +func StoreBearerToken(profile string, result *TokenResult, email string) error { + creds, err := keyring.GetCredentials(profile) + if err != nil { + creds = &keyring.Credentials{} + } + + creds.AccessToken = result.AccessToken + creds.RefreshToken = result.RefreshToken + creds.ExpiresAt = result.ExpiresAt.Format(time.RFC3339) + creds.UserEmail = email + + return keyring.SetCredentials(profile, creds) +} + +// ClearBearerToken removes bearer token fields from the profile's keychain +// entry, preserving any existing API key. +func ClearBearerToken(profile string) error { + creds, err := keyring.GetCredentials(profile) + if err != nil { + return nil + } + + creds.AccessToken = "" + creds.RefreshToken = "" + creds.ExpiresAt = "" + creds.UserEmail = "" + + return keyring.SetCredentials(profile, creds) +} diff --git a/internal/auth/tokens_test.go b/internal/auth/tokens_test.go new file mode 100644 index 0000000..b2d92f1 --- /dev/null +++ b/internal/auth/tokens_test.go @@ -0,0 +1,118 @@ +package auth + +import ( + "testing" + "time" + + "github.com/getzep/zepctl/internal/keyring" + gokeyring "github.com/zalando/go-keyring" +) + +func init() { + gokeyring.MockInit() +} + +func TestStoreBearerToken_PreservesAPIKey(t *testing.T) { + profile := "test-store-preserve" + creds := &keyring.Credentials{APIKey: "z_existing_key"} + if err := keyring.SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + result := &TokenResult{ + AccessToken: "new_bearer", + RefreshToken: "new_refresh", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + if err := StoreBearerToken(profile, result, "user@example.com"); err != nil { + t.Fatalf("StoreBearerToken: %v", err) + } + + got, err := keyring.GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.APIKey != "z_existing_key" { + t.Errorf("APIKey = %q, want %q (should be preserved)", got.APIKey, "z_existing_key") + } + if got.AccessToken != "new_bearer" { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, "new_bearer") + } + if got.UserEmail != "user@example.com" { + t.Errorf("UserEmail = %q, want %q", got.UserEmail, "user@example.com") + } +} + +func TestClearBearerToken_PreservesAPIKey(t *testing.T) { + profile := "test-clear-preserve" + creds := &keyring.Credentials{ + APIKey: "z_keep_this", + AccessToken: "clear_this", + RefreshToken: "clear_this_too", + ExpiresAt: "2026-04-20T15:30:00Z", + UserEmail: "clear@example.com", + } + if err := keyring.SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + if err := ClearBearerToken(profile); err != nil { + t.Fatalf("ClearBearerToken: %v", err) + } + + got, err := keyring.GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.APIKey != "z_keep_this" { + t.Errorf("APIKey = %q, want %q (should be preserved)", got.APIKey, "z_keep_this") + } + if got.HasBearerToken() { + t.Error("bearer token should be cleared") + } + if got.UserEmail != "" { + t.Errorf("UserEmail = %q, want empty", got.UserEmail) + } +} + +func TestStoreBearerToken_NewProfile(t *testing.T) { + profile := "test-store-new-profile" + + result := &TokenResult{ + AccessToken: "new_bearer", + RefreshToken: "new_refresh", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + if err := StoreBearerToken(profile, result, "user@example.com"); err != nil { + t.Fatalf("StoreBearerToken: %v", err) + } + + got, err := keyring.GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if got.AccessToken != "new_bearer" { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, "new_bearer") + } + if got.RefreshToken != "new_refresh" { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, "new_refresh") + } + if got.UserEmail != "user@example.com" { + t.Errorf("UserEmail = %q, want %q", got.UserEmail, "user@example.com") + } + // No API key should be set on a fresh profile. + if got.HasAPIKey() { + t.Errorf("new profile should not have API key, got %q", got.APIKey) + } +} + +func TestClearBearerToken_NonexistentProfile(t *testing.T) { + // Clearing bearer token on a profile that doesn't exist should not error. + if err := ClearBearerToken("test-clear-nonexistent"); err != nil { + t.Errorf("ClearBearerToken on nonexistent profile: %v", err) + } +} diff --git a/internal/cli/api_key_abac.go b/internal/cli/api_key_abac.go new file mode 100644 index 0000000..cb428db --- /dev/null +++ b/internal/cli/api_key_abac.go @@ -0,0 +1,364 @@ +package cli + +import ( + "fmt" + "io" + "strings" + + "github.com/spf13/cobra" + + "github.com/getzep/zepctl/internal/abac" + "github.com/getzep/zepctl/internal/client" + "github.com/getzep/zepctl/internal/output" +) + +const abacModeEnforce = "enforce" + +var ( + validModesList = []string{"off", "report_only", abacModeEnforce} + validModes = func() map[string]bool { + m := make(map[string]bool, len(validModesList)) + for _, v := range validModesList { + m[v] = true + } + return m + }() +) + +var apiKeyCmd = &cobra.Command{ + Use: "api-key", + Short: "Manage API key ABAC configuration", + Long: "Configure ABAC settings and policy set attachments for API keys.", +} + +var apiKeyListCmd = &cobra.Command{ + Use: listCmdUse, + Short: "List API keys in the current project", + RunE: func(cmd *cobra.Command, _ []string) error { + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.ListAPIKeys(cmd.Context()) + if err != nil { + return fmt.Errorf("listing API keys: %w", err) + } + if output.GetFormat() == output.FormatTable { + tbl := output.NewTable("UUID", "NAME", "KEY", "ROLE") + tbl.WriteHeader() + for _, k := range result.Keys { + key := k.FirstFour + "..." + k.LastFour + tbl.WriteRow(k.UUID, k.Name, key, k.Role) + } + return tbl.Flush() + } + return output.Print(result) + }, +} + +// --- Settings subgroup --- + +var apiKeySettingsCmd = &cobra.Command{ + Use: "settings", + Short: "Manage ABAC settings for an API key", +} + +var apiKeySettingsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get ABAC settings for an API key", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + settings, err := ac.GetAPIKeySettings(cmd.Context(), args[0]) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("API key not found: %s", args[0]) + } + return fmt.Errorf("getting API key settings: %w", err) + } + if output.GetFormat() == output.FormatTable { + fmt.Fprintf(cmd.OutOrStdout(), "ABAC Mode: %s\n", settings.ABACMode) + fmt.Fprintf(cmd.OutOrStdout(), "Capabilities: %s\n", settings.Capabilities) + return nil + } + return output.Print(settings) + }, +} + +var apiKeySettingsSetCmd = &cobra.Command{ + Use: "set ", + Short: "Update ABAC settings for an API key", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + mode, _ := cmd.Flags().GetString("mode") + if mode == "" { + return fmt.Errorf("at least one setting flag is required (e.g., --mode)") + } + if !validModes[mode] { + return fmt.Errorf("invalid mode %q: must be one of: %s", mode, strings.Join(validModesList, ", ")) + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + settings, err := ac.SetAPIKeySettings(cmd.Context(), args[0], mode) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("API key not found: %s", args[0]) + } + return fmt.Errorf("updating API key settings: %w", err) + } + if output.GetFormat() == output.FormatTable { + output.Info("Updated API key settings:\n ABAC Mode: %s\n Capabilities: %s", + settings.ABACMode, settings.Capabilities) + return nil + } + return output.Print(settings) + }, +} + +// --- Policy Sets subgroup --- + +var apiKeyPolicySetsCmd = &cobra.Command{ + Use: "policy-sets", + Short: "Manage policy set attachments for an API key", +} + +var apiKeyPolicySetsListCmd = &cobra.Command{ + Use: "list ", + Short: "List policy sets attached to an API key", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.ListAPIKeyPolicySets(cmd.Context(), args[0]) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("API key not found: %s", args[0]) + } + return fmt.Errorf("listing attached policy sets: %w", err) + } + if output.GetFormat() == output.FormatTable { + tbl := output.NewTable("UUID", "NAME", "MODE", "VERSION") + tbl.WriteHeader() + for _, ps := range result.PolicySets { + tbl.WriteRow(ps.UUID, ps.Name, ps.Mode, fmt.Sprintf("%d", ps.Version)) + } + return tbl.Flush() + } + return output.Print(result) + }, +} + +var apiKeyPolicySetsAttachCmd = &cobra.Command{ + Use: "attach ", + Short: "Attach a policy set to an API key", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + if err := validateUUID(args[1], "policy set"); err != nil { + return err + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.AttachPolicySet(cmd.Context(), args[0], args[1]) + if err != nil { + return fmt.Errorf("attaching policy set: %w", err) + } + if output.GetFormat() == output.FormatTable { + for _, ps := range result.PolicySets { + if ps.UUID == args[1] { + output.Info("Attached policy set %q to API key %s", ps.Name, args[0]) + return nil + } + } + output.Info("Attached policy set to API key %s", args[0]) + return nil + } + return output.Print(result) + }, +} + +var apiKeyPolicySetsDetachCmd = &cobra.Command{ + Use: "detach ", + Short: "Detach a policy set from an API key", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + if err := validateUUID(args[1], "policy set"); err != nil { + return err + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.DetachPolicySet(cmd.Context(), args[0], args[1]) + if err != nil { + return fmt.Errorf("detaching policy set: %w", err) + } + if output.GetFormat() == output.FormatTable { + output.Info("Detached policy set from API key %s", args[0]) + return nil + } + return output.Print(result) + }, +} + +// --- Dry-Run subcommands (Sections 4.7, 4.8) --- + +var apiKeyEvaluateCmd = &cobra.Command{ + Use: "evaluate ", + Short: "Dry-run a policy decision against an API key's live configuration", + Long: "Asks the server what the evaluator would decide for the given action " + + "and API key, without performing the request. Always exits 0 on a successful " + + "API call regardless of allow/deny outcome.", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + action, _ := cmd.Flags().GetString("action") + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.EvaluatePolicy(cmd.Context(), args[0], action) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("API key not found: %s", args[0]) + } + return fmt.Errorf("evaluating policy: %w", err) + } + if output.GetFormat() == output.FormatTable { + renderEvaluateTable(cmd.OutOrStdout(), result) + return nil + } + return output.FprintRaw(cmd.OutOrStdout(), result.RawJSON) + }, +} + +var apiKeyExplainCmd = &cobra.Command{ + Use: "explain ", + Short: "Dry-run a policy decision and return the evaluator trace", + Long: "Like evaluate, but returns the full evaluator trace (matched policies, " + + "set modes, skipped sets) for human inspection. Always exits 0 on a successful " + + "API call regardless of allow/deny outcome.", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "API key"); err != nil { + return err + } + action, _ := cmd.Flags().GetString("action") + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.ExplainPolicy(cmd.Context(), args[0], action) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("API key not found: %s", args[0]) + } + return fmt.Errorf("explaining policy: %w", err) + } + if output.GetFormat() == output.FormatTable { + renderExplainTable(cmd.OutOrStdout(), result) + return nil + } + return output.FprintRaw(cmd.OutOrStdout(), result.RawJSON) + }, +} + +func renderEvaluateTable(w io.Writer, r *abac.EvaluateResponse) { + fmt.Fprintf(w, "Outcome: %s\n", r.Outcome) + fmt.Fprintf(w, "ABAC: %s\n", r.ABAC) + fmt.Fprintf(w, "ABAC shadow: %s\n", r.ABACShadow) + fmt.Fprintf(w, "Role allows: %t\n", r.RoleAllows) + fmt.Fprintf(w, "Would log disagreement: %t\n", r.WouldLogDisagreement) +} + +func renderExplainTable(w io.Writer, r *abac.ExplainResponse) { + fmt.Fprintf(w, "Action: %s\n", r.Trace.Action) + fmt.Fprintf(w, "Outcome: %s\n", r.Outcome) + fmt.Fprintf(w, "ABAC: %s\n", r.ABAC) + fmt.Fprintf(w, "ABAC shadow: %s\n", r.ABACShadow) + fmt.Fprintf(w, "Role allows: %t\n", r.RoleAllows) + fmt.Fprintf(w, "Registry read-only: %t\n", r.Trace.RegistryEntry.ReadOnly) + fmt.Fprintf(w, "Would log disagreement: %t\n", r.WouldLogDisagreement) + fmt.Fprintln(w) + + if len(r.Trace.EvaluatedSets) == 0 { + fmt.Fprintln(w, "Evaluated policy sets: (none)") + } else { + fmt.Fprintln(w, "Evaluated policy sets:") + for _, s := range r.Trace.EvaluatedSets { + header := fmt.Sprintf(" %s (%s, set mode: %s)", s.Name, truncateUUID(s.UUID), s.SetMode) + if len(s.Matched) == 0 { + fmt.Fprintln(w, header+" -- no policies matched") + continue + } + fmt.Fprintln(w, header) + for _, m := range s.Matched { + fmt.Fprintf(w, " [%s] %s -- matched via %s\n", m.Effect, m.PolicyID, m.MatchedVia) + } + } + } + fmt.Fprintln(w) + + if len(r.Trace.SkippedSets) == 0 { + fmt.Fprintln(w, "Skipped policy sets: (none)") + return + } + fmt.Fprintln(w, "Skipped policy sets:") + for _, s := range r.Trace.SkippedSets { + fmt.Fprintf(w, " %s (%s) -- %s\n", s.Name, truncateUUID(s.UUID), s.Reason) + } +} + +func init() { + rootCmd.AddCommand(apiKeyCmd) + apiKeyCmd.AddCommand(apiKeyListCmd) + apiKeyCmd.AddCommand(apiKeySettingsCmd) + apiKeyCmd.AddCommand(apiKeyPolicySetsCmd) + apiKeyCmd.AddCommand(apiKeyEvaluateCmd) + apiKeyCmd.AddCommand(apiKeyExplainCmd) + + apiKeySettingsCmd.AddCommand(apiKeySettingsGetCmd) + apiKeySettingsCmd.AddCommand(apiKeySettingsSetCmd) + apiKeyPolicySetsCmd.AddCommand(apiKeyPolicySetsListCmd) + apiKeyPolicySetsCmd.AddCommand(apiKeyPolicySetsAttachCmd) + apiKeyPolicySetsCmd.AddCommand(apiKeyPolicySetsDetachCmd) + + apiKeySettingsSetCmd.Flags().String("mode", "", "ABAC enforcement mode: off, report_only, enforce") + apiKeyEvaluateCmd.Flags().String("action", "", "Action name to evaluate (e.g. thread.get)") + apiKeyExplainCmd.Flags().String("action", "", "Action name to explain (e.g. thread.get)") + _ = apiKeyEvaluateCmd.MarkFlagRequired("action") + _ = apiKeyExplainCmd.MarkFlagRequired("action") + + for _, cmd := range []*cobra.Command{ + apiKeyListCmd, apiKeySettingsGetCmd, apiKeySettingsSetCmd, + apiKeyPolicySetsListCmd, apiKeyPolicySetsAttachCmd, apiKeyPolicySetsDetachCmd, + apiKeyEvaluateCmd, apiKeyExplainCmd, + } { + client.SetCredentialType(cmd, client.CredentialBearer) + } +} diff --git a/internal/cli/api_key_abac_test.go b/internal/cli/api_key_abac_test.go new file mode 100644 index 0000000..ea39e26 --- /dev/null +++ b/internal/cli/api_key_abac_test.go @@ -0,0 +1,625 @@ +package cli + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/getzep/zepctl/internal/abac" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testUUID2 is a second arbitrary valid UUID (testUUID is in policy_set_test.go). +const testUUID2 = "a1b2c3d4-5e6f-7a8b-9c0d-1e2f3a4b5c6d" + +// --- Credential Type --- + +func TestAPIKeyCommands_CredentialType(t *testing.T) { + cmds := map[string]*cobra.Command{ + "list": apiKeyListCmd, + "settings-get": apiKeySettingsGetCmd, + "settings-set": apiKeySettingsSetCmd, + "policy-sets-list": apiKeyPolicySetsListCmd, + "policy-sets-attach": apiKeyPolicySetsAttachCmd, + "policy-sets-detach": apiKeyPolicySetsDetachCmd, + "evaluate": apiKeyEvaluateCmd, + "explain": apiKeyExplainCmd, + } + for name, cmd := range cmds { + t.Run(name, func(t *testing.T) { + require.NotNil(t, cmd.Annotations, "command %q has no annotations", name) + assert.Equal(t, "bearer", cmd.Annotations["zepctl_credential_type"], + "command %q should declare CredentialBearer", name) + }) + } +} + +// --- UUID Validation (API key commands) --- + +func TestAPIKeySettingsGet_InvalidUUID(t *testing.T) { + err := apiKeySettingsGetCmd.RunE(apiKeySettingsGetCmd, []string{"bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +func TestAPIKeySettingsSet_InvalidUUID(t *testing.T) { + err := apiKeySettingsSetCmd.RunE(apiKeySettingsSetCmd, []string{"bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +func TestAPIKeyPolicySetsAttach_InvalidAPIKeyUUID(t *testing.T) { + err := apiKeyPolicySetsAttachCmd.RunE(apiKeyPolicySetsAttachCmd, + []string{"bad", testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +func TestAPIKeyPolicySetsAttach_InvalidPolicySetUUID(t *testing.T) { + err := apiKeyPolicySetsAttachCmd.RunE(apiKeyPolicySetsAttachCmd, + []string{testUUID, "bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid policy set UUID: "bad"`) +} + +func TestAPIKeyPolicySetsDetach_InvalidAPIKeyUUID(t *testing.T) { + err := apiKeyPolicySetsDetachCmd.RunE(apiKeyPolicySetsDetachCmd, + []string{"bad", testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +func TestAPIKeyPolicySetsDetach_InvalidPolicySetUUID(t *testing.T) { + err := apiKeyPolicySetsDetachCmd.RunE(apiKeyPolicySetsDetachCmd, + []string{testUUID, "bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid policy set UUID: "bad"`) +} + +// --- Mode Validation --- + +func TestAPIKeySettingsSet_NoFlags(t *testing.T) { + _ = apiKeySettingsSetCmd.Flags().Set("mode", "") + err := apiKeySettingsSetCmd.RunE(apiKeySettingsSetCmd, + []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one setting flag is required") +} + +func TestAPIKeySettingsSet_InvalidMode(t *testing.T) { + _ = apiKeySettingsSetCmd.Flags().Set("mode", "invalid") + defer func() { _ = apiKeySettingsSetCmd.Flags().Set("mode", "") }() + + err := apiKeySettingsSetCmd.RunE(apiKeySettingsSetCmd, + []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid mode "invalid": must be one of: off, report_only, enforce`) +} + +func TestAPIKeySettingsSet_ValidModes(t *testing.T) { + for _, mode := range []string{"off", "report_only", "enforce"} { + t.Run(mode, func(t *testing.T) { + assert.True(t, validModes[mode], "mode %q should be valid", mode) + }) + } +} + +// --- API Key Settings Get (Sections 4.1.3, 4.1.4) --- + +func TestAPIKeySettingsGet_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.APIKeySettings{ABACMode: "enforce", Capabilities: "read"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeySettingsGetCmd) + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + assert.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "ABAC Mode:") + assert.Contains(t, out, "enforce") + assert.Contains(t, out, "Capabilities:") + assert.Contains(t, out, "read") +} + +func TestAPIKeySettingsGet_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeySettingsGetCmd) + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "API key not found:") +} + +// --- API Key Settings Set (Sections 4.2.3, 4.2.5) --- + +func TestAPIKeySettingsSet_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.APIKeySettings{ABACMode: "enforce", Capabilities: "read"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeySettingsSetCmd) + _ = cmd.Flags().Set("mode", "enforce") + defer func() { _ = apiKeySettingsSetCmd.Flags().Set("mode", "") }() + + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, []string{testUUID}) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "Updated API key settings:") + assert.Contains(t, out, "ABAC Mode:") + assert.Contains(t, out, "enforce") + assert.Contains(t, out, "Capabilities:") + assert.Contains(t, out, "read") +} + +func TestAPIKeySettingsSet_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeySettingsSetCmd) + _ = cmd.Flags().Set("mode", "enforce") + defer func() { _ = apiKeySettingsSetCmd.Flags().Set("mode", "") }() + + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "API key not found:") +} + +// --- API Key Policy Sets List (Sections 4.3.2, 4.3.3, 4.3.4) --- + +func TestAPIKeyPolicySetsListCmd_InvalidUUID(t *testing.T) { + err := apiKeyPolicySetsListCmd.RunE(apiKeyPolicySetsListCmd, []string{"bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +func TestAPIKeyPolicySetsListCmd_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.AttachedPolicySets{ + PolicySets: []abac.PolicySetSummary{ + {UUID: "ps-1", Name: "policy_one", Mode: "enforce", Version: 2}, + {UUID: "ps-2", Name: "policy_two", Mode: "off", Version: 1}, + }, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyPolicySetsListCmd) + var runErr error + out := captureStdout(t, func() { + runErr = cmd.RunE(cmd, []string{testUUID}) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "UUID") + assert.Contains(t, out, "NAME") + assert.Contains(t, out, "MODE") + assert.Contains(t, out, "VERSION") + assert.Contains(t, out, "ps-1") + assert.Contains(t, out, "policy_one") + assert.Contains(t, out, "enforce") + assert.Contains(t, out, "ps-2") + assert.Contains(t, out, "policy_two") + + // Verify header + 2 data rows. + lines := strings.Split(strings.TrimSpace(out), "\n") + assert.Equal(t, 3, len(lines)) +} + +func TestAPIKeyPolicySetsListCmd_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyPolicySetsListCmd) + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "API key not found:") +} + +// --- API Key Policy Sets Attach (Sections 4.4.3, 4.4.4) --- + +func TestAPIKeyPolicySetsAttach_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.AttachedPolicySets{ + PolicySets: []abac.PolicySetSummary{ + {UUID: testUUID2, Name: "attached_policy", Mode: "enforce", Version: 1}, + }, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyPolicySetsAttachCmd) + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, []string{testUUID, testUUID2}) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "Attached policy set") + assert.Contains(t, out, `"attached_policy"`) + assert.Contains(t, out, testUUID) +} + +func TestAPIKeyPolicySetsAttach_Idempotent(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.AttachedPolicySets{ + PolicySets: []abac.PolicySetSummary{ + {UUID: testUUID2, Name: "policy", Mode: "enforce", Version: 1}, + }, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + // Call twice -- both should succeed (idempotent attach). + for i := 0; i < 2; i++ { + cmd := newTestCommand(apiKeyPolicySetsAttachCmd) + captureStderr(t, func() { + err := cmd.RunE(cmd, []string{testUUID, testUUID2}) + assert.NoError(t, err) + }) + } +} + +// --- API Key Policy Sets Detach (Sections 4.5.3, 4.5.4) --- + +func TestAPIKeyPolicySetsDetach_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.AttachedPolicySets{PolicySets: []abac.PolicySetSummary{}}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyPolicySetsDetachCmd) + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, []string{testUUID, testUUID2}) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "Detached policy set from API key") + assert.Contains(t, out, testUUID) +} + +func TestAPIKeyPolicySetsDetach_Idempotent(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.AttachedPolicySets{PolicySets: []abac.PolicySetSummary{}}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + // Call twice -- both should succeed (idempotent detach). + for i := 0; i < 2; i++ { + cmd := newTestCommand(apiKeyPolicySetsDetachCmd) + captureStderr(t, func() { + err := cmd.RunE(cmd, []string{testUUID, testUUID2}) + assert.NoError(t, err) + }) + } +} + +// --- API Key Evaluate --- + +func TestAPIKeyEvaluate_InvalidUUID(t *testing.T) { + err := apiKeyEvaluateCmd.RunE(apiKeyEvaluateCmd, []string{"bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +// Locks in MarkFlagRequired so a future refactor that drops it is caught. +// Cobra enforces required flags during Execute, which the RunE-direct tests +// in this file bypass; this assertion is the cheap belt-and-braces. +func TestAPIKeyEvaluate_ActionFlagRequired(t *testing.T) { + ann := apiKeyEvaluateCmd.Flags().Lookup("action").Annotations[cobra.BashCompOneRequiredFlag] + require.Len(t, ann, 1) + assert.Equal(t, "true", ann[0]) +} + +func TestAPIKeyEvaluate_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"outcome":"allow","abac":"ALLOW","abac_shadow":"NEUTRAL","role_allows":false,"would_log_disagreement":true}`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyEvaluateCmd) + _ = cmd.Flags().Set("action", "thread.get") + defer func() { _ = apiKeyEvaluateCmd.Flags().Set("action", "") }() + + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + out := buf.String() + assert.Contains(t, out, "Outcome: allow") + assert.Contains(t, out, "ABAC: ALLOW") + assert.Contains(t, out, "ABAC shadow: NEUTRAL") + assert.Contains(t, out, "Role allows: false") + assert.Contains(t, out, "Would log disagreement: true") +} + +// Allow/deny outcome must not change the exit code -- the operator +// asked a question and the server answered it. +func TestAPIKeyEvaluate_DenyOutcomeExitsZero(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"outcome":"deny","abac":"NEUTRAL","abac_shadow":"NEUTRAL","role_allows":false,"would_log_disagreement":false}`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyEvaluateCmd) + _ = cmd.Flags().Set("action", "thread.delete") + defer func() { _ = apiKeyEvaluateCmd.Flags().Set("action", "") }() + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + assert.Contains(t, buf.String(), "Outcome: deny") +} + +func TestAPIKeyEvaluate_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyEvaluateCmd) + _ = cmd.Flags().Set("action", "thread.get") + defer func() { _ = apiKeyEvaluateCmd.Flags().Set("action", "") }() + + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "API key not found:") +} + +func TestAPIKeyEvaluate_BadAction(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "action is not in the registry"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyEvaluateCmd) + _ = cmd.Flags().Set("action", "no.such.action") + defer func() { _ = apiKeyEvaluateCmd.Flags().Set("action", "") }() + + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "evaluating policy:") + assert.Contains(t, err.Error(), "action is not in the registry") +} + +// --- API Key Explain --- + +func TestAPIKeyExplain_InvalidUUID(t *testing.T) { + err := apiKeyExplainCmd.RunE(apiKeyExplainCmd, []string{"bad"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID: "bad"`) +} + +func TestAPIKeyExplain_ActionFlagRequired(t *testing.T) { + ann := apiKeyExplainCmd.Flags().Lookup("action").Annotations[cobra.BashCompOneRequiredFlag] + require.Len(t, ann, 1) + assert.Equal(t, "true", ann[0]) +} + +func TestAPIKeyExplain_TableOutput_AllowWithMatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "outcome": "allow", "abac": "ALLOW", "abac_shadow": "NEUTRAL", + "role_allows": false, "would_log_disagreement": true, + "trace": { + "action": "thread.get", + "registry_entry": {"read_only": true}, + "evaluated_sets": [ + {"uuid": "9f1a2b3c-set-uuid", "name": "read_everything", "set_mode": "enforce", + "matched": [{"policy_id": "allow_reads", "effect": "allow", "matched_via": "readonly"}]} + ], + "skipped_sets": [] + } + }`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyExplainCmd) + _ = cmd.Flags().Set("action", "thread.get") + defer func() { _ = apiKeyExplainCmd.Flags().Set("action", "") }() + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + out := buf.String() + + assert.Contains(t, out, "Action: thread.get") + assert.Contains(t, out, "Registry read-only: true") + assert.Contains(t, out, "Evaluated policy sets:") + assert.Contains(t, out, "read_everything (9f1a2b3c..., set mode: enforce)") + assert.Contains(t, out, " [allow] allow_reads -- matched via readonly") + assert.Contains(t, out, "Skipped policy sets: (none)") +} + +// Empty matched on an evaluated set must render the "no policies matched" +// suffix on the header. +func TestAPIKeyExplain_TableOutput_NoMatchedPolicies(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "outcome": "deny", "abac": "NEUTRAL", "abac_shadow": "NEUTRAL", + "role_allows": false, "would_log_disagreement": false, + "trace": { + "action": "thread.delete", + "registry_entry": {"read_only": false}, + "evaluated_sets": [ + {"uuid": "9f1a2b3c-set-uuid", "name": "read_everything", "set_mode": "enforce", "matched": []} + ], + "skipped_sets": [] + } + }`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyExplainCmd) + _ = cmd.Flags().Set("action", "thread.delete") + defer func() { _ = apiKeyExplainCmd.Flags().Set("action", "") }() + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + out := buf.String() + assert.Contains(t, out, "read_everything (9f1a2b3c..., set mode: enforce) -- no policies matched") +} + +// Empty evaluated_sets must render the "(none)" parallel form. +func TestAPIKeyExplain_TableOutput_NoEvaluatedSets(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "outcome": "deny", "abac": "NEUTRAL", "abac_shadow": "NEUTRAL", + "role_allows": false, "would_log_disagreement": false, + "trace": { + "action": "thread.get", + "registry_entry": {"read_only": true}, + "evaluated_sets": [], + "skipped_sets": [ + {"uuid": "1a2b3c4d-set-uuid", "name": "off_set", "reason": "set mode is off"} + ] + } + }`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyExplainCmd) + _ = cmd.Flags().Set("action", "thread.get") + defer func() { _ = apiKeyExplainCmd.Flags().Set("action", "") }() + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + out := buf.String() + assert.Contains(t, out, "Evaluated policy sets: (none)") + assert.Contains(t, out, "off_set (1a2b3c4d...) -- set mode is off") +} + +// An unknown trace field must not break table rendering and must be +// preserved by the JSON path (forward-compat). +func TestAPIKeyExplain_TableOutput_UnknownTraceField(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "outcome": "allow", "abac": "ALLOW", "abac_shadow": "NEUTRAL", + "role_allows": false, "would_log_disagreement": true, + "trace": { + "action": "thread.get", + "registry_entry": {"read_only": true, "future_capability": "redact"}, + "future_top_level": "surprise", + "evaluated_sets": [], + "skipped_sets": [] + } + }`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(apiKeyExplainCmd) + _ = cmd.Flags().Set("action", "thread.get") + defer func() { _ = apiKeyExplainCmd.Flags().Set("action", "") }() + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + out := buf.String() + // Table render still succeeds with the known fields. + assert.Contains(t, out, "Action: thread.get") + assert.Contains(t, out, "Registry read-only: true") +} + +// End-to-end forward-compat: --output json must surface server-added +// fields the typed shape doesn't model. This locks in the CLI -> +// FprintRaw -> stdout path that table-only tests don't exercise. +func TestAPIKeyExplain_JSONOutput_PreservesUnknownFields(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "outcome": "allow", "abac": "ALLOW", "abac_shadow": "NEUTRAL", + "role_allows": false, "would_log_disagreement": true, + "trace": { + "action": "thread.get", + "registry_entry": {"read_only": true, "future_capability": "redact"}, + "future_top_level": "surprise", + "evaluated_sets": [], + "skipped_sets": [] + } + }`)) + })) + defer srv.Close() + withTestABACClient(t, srv) + + viper.Set("output", "json") + defer viper.Set("output", "") + + cmd := newTestCommand(apiKeyExplainCmd) + _ = cmd.Flags().Set("action", "thread.get") + defer func() { _ = apiKeyExplainCmd.Flags().Set("action", "") }() + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + require.NoError(t, err) + out := buf.String() + assert.Contains(t, out, `"future_capability": "redact"`) + assert.Contains(t, out, `"future_top_level": "surprise"`) +} diff --git a/internal/cli/auth.go b/internal/cli/auth.go new file mode 100644 index 0000000..de1ebe3 --- /dev/null +++ b/internal/cli/auth.go @@ -0,0 +1,305 @@ +package cli + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/getzep/zepctl/internal/auth" + "github.com/getzep/zepctl/internal/client" + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/keyring" + "github.com/getzep/zepctl/internal/output" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var authCmd = &cobra.Command{ + Use: "auth", + Short: "Manage bearer token authentication", + Long: `Authenticate with Zep using your browser to obtain a bearer token for CLI access.`, +} + +var authLoginCmd = &cobra.Command{ + Use: "login", + Short: "Authenticate via browser to obtain a bearer token", + RunE: func(cmd *cobra.Command, args []string) error { + noBrowser, _ := cmd.Flags().GetBool("no-browser") + + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + profile := cfg.GetCurrentProfile() + if profile == nil { + // No profile exists -- create one. Use the name passed via + // --profile if provided, otherwise "default", so users can + // bootstrap an isolated profile via + // `zepctl --profile foo auth login`. + name := viper.GetString("profile") + if name == "" { + name = "default" + } + + // Validate --env before the interactive browser flow. + envName, _ := cmd.Flags().GetString("env") + newProfile := config.Profile{Name: name, APIURL: auth.DefaultAPIURL} + if envName != "" { + env := cfg.GetEnvironment(envName) + if env == nil { + return fmt.Errorf("environment %q not found; configure it with \"zepctl config add-environment\"", envName) + } + newProfile.APIURL = env.APIURL + newProfile.OAuthIssuer = env.OAuthIssuer + newProfile.OAuthClientID = env.OAuthClientID + newProfile.OAuthAudience = env.OAuthAudience + } + + cfg.Profiles = append(cfg.Profiles, newProfile) + cfg.CurrentProfile = name + if err := cfg.Save(); err != nil { + return fmt.Errorf("creating profile %q: %w", name, err) + } + profile = cfg.GetProfile(name) + } else if envName, _ := cmd.Flags().GetString("env"); envName != "" { + // --env on an existing profile is ambiguous (mutate? ignore?). Force + // the user through update-profile so the change is explicit. + return fmt.Errorf("profile %q already exists; use \"config update-profile --env %s\" to change its environment", profile.Name, envName) + } + + // If there's an existing bearer token, revoke its refresh token first. + oauthCfg := auth.OAuthConfigFor(profile.OAuthIssuer, profile.OAuthClientID, profile.OAuthAudience) + creds, err := keyring.GetCredentials(profile.Name) + if err == nil && creds.RefreshToken != "" { + if err := auth.RevokeToken(cmd.Context(), oauthCfg, creds.RefreshToken); err != nil { + output.Warn("Could not revoke existing token: %v", err) + } + } + + session := auth.NewKeychainSession(profile.Name) + result, err := auth.Login(cmd.Context(), oauthCfg, session, noBrowser) + if err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + + // The SDK persists tokens via SetRawToken, but we also need to + // store the user email (extracted from the ID token) in the + // keychain entry. + email := "" + if result.IDToken != "" { + if claims, err := auth.ParseUnverifiedIDToken(result.IDToken); err == nil { + email = claims.Email + } + } + if email != "" { + // Update the keychain entry with the email. + creds, err = keyring.GetCredentials(profile.Name) + if err == nil { + creds.UserEmail = email + _ = keyring.SetCredentials(profile.Name, creds) + } + } + + if email != "" { + output.Info("Authenticated as %s", email) + } else { + output.Info("Authenticated successfully") + } + + // Auto-select project if needed. + if profile.ProjectUUID == "" { + if err := autoSelectProject(cmd.Context(), cfg, profile, result.AccessToken); err != nil { + output.Warn("Could not auto-select project: %v", err) + output.Info("Run \"zepctl config set-project\" to select a project.") + } + } + + return nil + }, +} + +var authLogoutCmd = &cobra.Command{ + Use: "logout", + Short: "Clear bearer token for the current profile", + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + profile := cfg.GetCurrentProfile() + if profile == nil { + return fmt.Errorf("no active profile") + } + + // Revoke refresh token at Kinde (best-effort). + creds, err := keyring.GetCredentials(profile.Name) + if err == nil && creds.RefreshToken != "" { + oauthCfg := auth.OAuthConfigFor(profile.OAuthIssuer, profile.OAuthClientID, profile.OAuthAudience) + if err := auth.RevokeToken(cmd.Context(), oauthCfg, creds.RefreshToken); err != nil { + output.Warn("Could not revoke token at server: %v", err) + } + } + + if err := auth.ClearBearerToken(profile.Name); err != nil { + return fmt.Errorf("clearing token: %w", err) + } + + output.Info("Bearer token cleared for profile %q.", profile.Name) + return nil + }, +} + +var authStatusCmd = &cobra.Command{ + Use: "status", + Short: "Display authentication status for the current profile", + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + profile := cfg.GetCurrentProfile() + if profile == nil { + return fmt.Errorf("no active profile") + } + + creds, err := keyring.GetCredentials(profile.Name) + if err != nil { + return fmt.Errorf("reading credentials: %w", err) + } + + apiURL := profile.APIURL + if apiURL == "" { + apiURL = "(SDK default)" + } + + oauthCfg := auth.OAuthConfigFor(profile.OAuthIssuer, profile.OAuthClientID, profile.OAuthAudience) + + fmt.Printf("Profile: %s\n", profile.Name) + fmt.Printf("API URL: %s\n", apiURL) + fmt.Printf("OIDC issuer: %s\n", oauthCfg.Issuer) + + if creds.HasAPIKey() { + masked := maskKey(creds.APIKey) + fmt.Printf("API key: %s\n", masked) + } else { + fmt.Println("API key: not configured") + } + + if creds.HasBearerToken() { + if creds.IsExpired() { + fmt.Println("Bearer token: expired") + } else { + d := creds.ExpiresIn() + fmt.Printf("Bearer token: valid (expires in %s)\n", formatDuration(d)) + } + if creds.UserEmail != "" { + fmt.Printf("User: %s\n", creds.UserEmail) + } + } else { + fmt.Println("Bearer token: not configured") + } + + return nil + }, +} + +// autoSelectProject resolves the user's account and selects a project +// after auth login. +func autoSelectProject(ctx context.Context, cfg *config.Config, profile *config.Profile, accessToken string) error { + apiURL := config.GetAPIURL() + if apiURL == "" { + apiURL = auth.DefaultAPIURL + } + + httpClient := &http.Client{ + Transport: &client.BearerTransport{Token: accessToken, Base: http.DefaultTransport}, + } + + accountUUID, projects, err := authenticateAndGetProjects(ctx, httpClient, apiURL) + if err != nil { + return err + } + profile.AccountUUID = accountUUID + // Persist account UUID immediately so it survives even if the user has + // zero projects (early return below) or interactive selection is aborted. + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving account UUID: %w", err) + } + + if len(projects) == 0 { + return fmt.Errorf("no projects found") + } + + if len(projects) == 1 { + profile.ProjectUUID = projects[0].UUID + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + output.Info("Auto-selected project %q (%s)", projects[0].Name, projects[0].UUID) + return nil + } + + fmt.Println("Multiple projects found:") + for i, p := range projects { + fmt.Printf(" %d. %s (%s)\n", i+1, p.Name, p.UUID) + } + fmt.Printf("Select a project [1-%d]: ", len(projects)) + + var choice string + if _, err := fmt.Scanln(&choice); err != nil { + return fmt.Errorf("reading selection: %w", err) + } + + idx, err := strconv.Atoi(strings.TrimSpace(choice)) + if err != nil || idx < 1 || idx > len(projects) { + return fmt.Errorf("invalid selection: %q", choice) + } + + selected := projects[idx-1] + profile.ProjectUUID = selected.UUID + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + output.Info("Active project set to %s (%s)", selected.UUID, selected.Name) + return nil +} + +// maskKey returns the first 2 and last 4 characters of a key with "..." in between. +func maskKey(key string) string { + if len(key) <= 6 { + return key + } + return key[:2] + "..." + key[len(key)-4:] +} + +// formatDuration returns a human-readable duration string. +func formatDuration(d time.Duration) string { + if d < time.Minute { + return fmt.Sprintf("%d seconds", int(d.Seconds())) + } + if d < time.Hour { + return fmt.Sprintf("%d minutes", int(d.Minutes())) + } + h := int(d.Hours()) + m := int(d.Minutes()) % 60 + if m == 0 { + return fmt.Sprintf("%d hours", h) + } + return fmt.Sprintf("%dh %dm", h, m) +} + +func init() { + rootCmd.AddCommand(authCmd) + authCmd.AddCommand(authLoginCmd) + authCmd.AddCommand(authLogoutCmd) + authCmd.AddCommand(authStatusCmd) + + authLoginCmd.Flags().Bool("no-browser", false, "Print authorization URL instead of opening browser") + authLoginCmd.Flags().String("env", "", "When creating a new profile, apply a named environment preset (see \"config add-environment\")") +} diff --git a/internal/cli/auth_test.go b/internal/cli/auth_test.go new file mode 100644 index 0000000..8d6a7b7 --- /dev/null +++ b/internal/cli/auth_test.go @@ -0,0 +1,520 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/keyring" + "github.com/spf13/cobra" + "github.com/spf13/viper" + gokeyring "github.com/zalando/go-keyring" +) + +func init() { + gokeyring.MockInit() +} + +func TestMaskKey(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"z_abc123def456", "z_...f456"}, + {"short", "short"}, + {"z_7x2f", "z_7x2f"}, + {"z_abcdefghijklmnop", "z_...mnop"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := maskKey(tt.input) + if got != tt.want { + t.Errorf("maskKey(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFormatDuration(t *testing.T) { + tests := []struct { + input time.Duration + want string + }{ + {30 * time.Second, "30 seconds"}, + {47 * time.Minute, "47 minutes"}, + {90 * time.Minute, "1h 30m"}, + {1 * time.Second, "1 seconds"}, + {2 * time.Hour, "2 hours"}, + {time.Hour + 15*time.Minute, "1h 15m"}, + } + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := formatDuration(tt.input) + if got != tt.want { + t.Errorf("formatDuration(%v) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +// writeTestConfig writes a minimal config file to tmpDir so config.Load works. +func writeTestConfig(t *testing.T, tmpDir string) { + t.Helper() + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + cfg := "current-profile: test\nprofiles:\n - name: test\n api-url: https://api.getzep.com\n" + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(cfg), 0o600); err != nil { + t.Fatal(err) + } +} + +// resetCmdFlags clears named flag values to their registered defaults and +// resets the Changed bit. Cobra commands are package-level globals; +// without this, flag.Set state survives across tests. +func resetCmdFlags(t *testing.T, cmd *cobra.Command, names ...string) { + t.Helper() + for _, n := range names { + if f := cmd.Flags().Lookup(n); f != nil { + _ = f.Value.Set(f.DefValue) + f.Changed = false + } + } +} + +// captureStdout runs fn and returns everything written to os.Stdout. +// Uses defer to restore os.Stdout even if fn panics. +func captureStdout(t *testing.T, fn func()) string { + t.Helper() + old := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = w + defer func() { + w.Close() + os.Stdout = old + }() + + fn() + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + return buf.String() +} + +// setupAutoSelectTest sets up viper, HOME, and a mock server for +// autoSelectProject tests. The mock /api/web/v1/authenticate endpoint +// returns accountUUID alongside the projects produced by projectsFn. +func setupAutoSelectTest(t *testing.T, accountUUID string, projectsFn func() ([]projectInfo, int)) (*config.Config, *config.Profile) { + t.Helper() + t.Setenv("HOME", t.TempDir()) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + mux := http.NewServeMux() + mux.HandleFunc("/api/web/v1/authenticate", func(w http.ResponseWriter, _ *http.Request) { + projects, status := projectsFn() + if status != 0 && status != http.StatusOK { + w.WriteHeader(status) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "account_uuid": accountUUID, + "projects": projects, + }) + }) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + viper.Set("api-url", srv.URL) + + cfg := &config.Config{ + CurrentProfile: "test", + Profiles: []config.Profile{{Name: "test"}}, + } + profile := cfg.GetProfile("test") + return cfg, profile +} + +// runAuthStatus sets up a config with a "test" profile, seeds the keychain +// with creds, runs authStatusCmd, and returns the captured stdout. +func runAuthStatus(t *testing.T, creds *keyring.Credentials) string { + t.Helper() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeTestConfig(t, tmpDir) + _, _ = config.Reload() + + if err := keyring.SetCredentials("test", creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + cmd := &cobra.Command{Use: "test"} + var runErr error + got := captureStdout(t, func() { + runErr = authStatusCmd.RunE(cmd, nil) + }) + if runErr != nil { + t.Fatalf("auth status: %v", runErr) + } + return got +} + +func TestAutoSelectProject_SingleProject(t *testing.T) { + cfg, profile := setupAutoSelectTest(t, "acc-123", func() ([]projectInfo, int) { + return []projectInfo{{UUID: "proj-1", Name: "My Project"}}, http.StatusOK + }) + + err := autoSelectProject(context.Background(), cfg, profile, "test-token") + if err != nil { + t.Fatalf("autoSelectProject: %v", err) + } + if profile.ProjectUUID != "proj-1" { + t.Errorf("ProjectUUID = %q, want %q", profile.ProjectUUID, "proj-1") + } + if profile.AccountUUID != "acc-123" { + t.Errorf("AccountUUID = %q, want %q", profile.AccountUUID, "acc-123") + } +} + +// TestAutoSelectProject_NoProjects verifies that when authenticate returns +// zero projects, autoSelectProject errors out -- but the account UUID is +// still persisted to the profile for follow-up commands. +func TestAutoSelectProject_NoProjects(t *testing.T) { + cfg, profile := setupAutoSelectTest(t, "acc-456", func() ([]projectInfo, int) { + return []projectInfo{}, http.StatusOK + }) + + err := autoSelectProject(context.Background(), cfg, profile, "test-token") + if err == nil { + t.Fatal("expected error for no projects") + } + if !strings.Contains(err.Error(), "no projects found") { + t.Errorf("error = %q, want to contain %q", err.Error(), "no projects found") + } + if profile.AccountUUID != "acc-456" { + t.Errorf("AccountUUID = %q, want %q (should be persisted even when no projects)", + profile.AccountUUID, "acc-456") + } +} + +func TestAutoSelectProject_AuthenticateFails(t *testing.T) { + cfg, profile := setupAutoSelectTest(t, "acc-789", func() ([]projectInfo, int) { + return nil, http.StatusInternalServerError + }) + + err := autoSelectProject(context.Background(), cfg, profile, "test-token") + if err == nil { + t.Fatal("expected error from failed authenticate") + } + if profile.AccountUUID != "" { + t.Errorf("AccountUUID = %q, want empty (no save when authenticate fails)", profile.AccountUUID) + } +} + +func TestAuthStatusOutput_BearerOnly(t *testing.T) { + got := runAuthStatus(t, &keyring.Credentials{ + AccessToken: "eyJhbGci.valid_tok", + RefreshToken: "refresh_tok", + ExpiresAt: time.Now().Add(47 * time.Minute).Format(time.RFC3339), + UserEmail: "test@example.com", + }) + + checks := []string{ + "API key: not configured", + "valid (expires in", + "test@example.com", + } + for _, want := range checks { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestAuthStatusOutput_BothCredentials(t *testing.T) { + got := runAuthStatus(t, &keyring.Credentials{ + APIKey: "z_test_key_12345", + AccessToken: "eyJhbGci.test", + RefreshToken: "refresh_tok", + ExpiresAt: time.Now().Add(47 * time.Minute).Format(time.RFC3339), + UserEmail: "fred@frobozz.infocom", + }) + + checks := []string{ + "z_...2345", // masked API key + "valid (expires in", // bearer status + "fred@frobozz.infocom", // user email + } + for _, want := range checks { + if !strings.Contains(got, want) { + t.Errorf("output missing %q:\n%s", want, got) + } + } +} + +func TestAuthStatusOutput_APIKeyOnly(t *testing.T) { + got := runAuthStatus(t, &keyring.Credentials{ + APIKey: "z_only_key_9999", + }) + + if !strings.Contains(got, "z_...9999") { + t.Errorf("output missing masked API key:\n%s", got) + } + if !strings.Contains(got, "Bearer token: not configured") { + t.Errorf("output should show bearer as not configured:\n%s", got) + } +} + +// TestLogoutRevocationFailure_StillClearsToken verifies that auth logout +// clears bearer token fields even when the revocation request fails. +func TestLogoutRevocationFailure_StillClearsToken(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeTestConfig(t, tmpDir) + _, _ = config.Reload() + + // Seed keychain with both credentials and a refresh token that + // will trigger the revocation path. + if err := keyring.SetCredentials("test", &keyring.Credentials{ + APIKey: "z_preserved_key", + AccessToken: "old_access", + RefreshToken: "old_refresh", + ExpiresAt: time.Now().Add(time.Hour).Format(time.RFC3339), + UserEmail: "user@example.com", + }); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + // Use a canceled context so RevokeToken (which hits the real Kinde + // issuer URL) fails immediately. ClearBearerToken does not use the + // context, so the rest of logout proceeds. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + cmd := &cobra.Command{Use: "test"} + cmd.SetContext(ctx) + + var runErr error + captureStdout(t, func() { + runErr = authLogoutCmd.RunE(cmd, nil) + }) + if runErr != nil { + t.Fatalf("auth logout should succeed even when revocation fails: %v", runErr) + } + + // Verify bearer token was cleared but API key preserved. + creds, err := keyring.GetCredentials("test") + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if creds.HasBearerToken() { + t.Error("bearer token should be cleared after logout") + } + if creds.APIKey != "z_preserved_key" { + t.Errorf("API key should be preserved, got %q", creds.APIKey) + } + if creds.UserEmail != "" { + t.Errorf("UserEmail should be cleared, got %q", creds.UserEmail) + } +} + +func resetAuthLoginFlags(t *testing.T) { + t.Helper() + resetCmdFlags(t, authLoginCmd, "no-browser", "env") +} + +// TestAuthLogin_BootstrapWithEnv_AppliesEnvToNewProfile verifies that +// `zepctl --profile foo auth login --env development` creates the new +// profile with api-url/oauth-issuer/oauth-client-id taken from the +// environment preset. The OAuth flow itself is short-circuited via a +// canceled context -- by that point the profile has already been written +// to disk, which is what we're asserting. +func TestAuthLogin_BootstrapWithEnv_AppliesEnvToNewProfile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://api.dev.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: dev-client-id +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetAuthLoginFlags(t) + t.Cleanup(func() { resetAuthLoginFlags(t) }) + + if err := authLoginCmd.Flags().Set("env", "development"); err != nil { + t.Fatalf("setting flag: %v", err) + } + viper.Set("profile", "dev") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + cmd := &cobra.Command{Use: "test"} + cmd.SetContext(ctx) + // Mirror authLoginCmd's flagset onto our test cmd so RunE can read + // --env and --no-browser via cmd.Flags(). + cmd.Flags().AddFlagSet(authLoginCmd.Flags()) + + // Login will fail (canceled ctx + no callback server), but the profile + // is created and saved before the OAuth flow begins. + captureStdout(t, func() { + _ = authLoginCmd.RunE(cmd, nil) + }) + + cfg, _ := config.Reload() + p := cfg.GetProfile("dev") + if p == nil { + t.Fatal("profile 'dev' not created from env preset") + } + if p.APIURL != "https://api.dev.example.com" { + t.Errorf("APIURL = %q, want from env", p.APIURL) + } + if p.OAuthIssuer != "https://issuer.example.com" { + t.Errorf("OAuthIssuer = %q, want from env", p.OAuthIssuer) + } + if p.OAuthClientID != "dev-client-id" { + t.Errorf("OAuthClientID = %q, want from env", p.OAuthClientID) + } +} + +func TestAuthLogin_BootstrapWithUnknownEnv_ErrorsBeforeAuth(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + // Empty config, no environments defined. + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte("current-profile: \"\"\nprofiles: []\n"), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetAuthLoginFlags(t) + t.Cleanup(func() { resetAuthLoginFlags(t) }) + + if err := authLoginCmd.Flags().Set("env", "missing-env"); err != nil { + t.Fatalf("setting flag: %v", err) + } + viper.Set("profile", "x") + + cmd := &cobra.Command{Use: "test"} + cmd.SetContext(context.Background()) + cmd.Flags().AddFlagSet(authLoginCmd.Flags()) + + err := authLoginCmd.RunE(cmd, nil) + if err == nil { + t.Fatal("expected error for unknown environment") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want to mention 'not found'", err.Error()) + } + + // And the profile must NOT have been created -- we want the failure to + // happen before any persistent change to config. + cfg, _ := config.Reload() + if cfg.GetProfile("x") != nil { + t.Error("profile 'x' should not be created on unknown env") + } +} + +func TestAuthLogin_ExistingProfileWithEnv_Errors(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: existing +profiles: + - name: existing + api-url: https://existing.example.com +environments: + - name: development + api-url: https://api.dev.example.com +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetAuthLoginFlags(t) + t.Cleanup(func() { resetAuthLoginFlags(t) }) + + if err := authLoginCmd.Flags().Set("env", "development"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + cmd := &cobra.Command{Use: "test"} + cmd.SetContext(context.Background()) + cmd.Flags().AddFlagSet(authLoginCmd.Flags()) + + err := authLoginCmd.RunE(cmd, nil) + if err == nil { + t.Fatal("expected error: --env on existing profile is ambiguous") + } + if !strings.Contains(err.Error(), "update-profile") { + t.Errorf("error = %q, want to redirect to 'update-profile'", err.Error()) + } + + // Profile must be unchanged. + cfg, _ := config.Reload() + p := cfg.GetProfile("existing") + if p == nil || p.APIURL != "https://existing.example.com" { + t.Errorf("profile mutated: %+v", p) + } +} + +func TestAuthStatusOutput_BearerExpired(t *testing.T) { + got := runAuthStatus(t, &keyring.Credentials{ + AccessToken: "expired_tok", + RefreshToken: "refresh_tok", + ExpiresAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }) + + if !strings.Contains(got, "Bearer token: expired") { + t.Errorf("output should show bearer as expired:\n%s", got) + } + if !strings.Contains(got, "API key: not configured") { + t.Errorf("output should show API key as not configured:\n%s", got) + } +} diff --git a/internal/cli/config.go b/internal/cli/config.go index ed1b6b2..8c7335a 100644 --- a/internal/cli/config.go +++ b/internal/cli/config.go @@ -96,44 +96,83 @@ var configAddProfileCmd = &cobra.Command{ return fmt.Errorf("loading config: %w", err) } - if cfg.GetProfile(name) != nil { - return fmt.Errorf("profile %q already exists", name) + noAPIKeyFlag, _ := cmd.Flags().GetBool("no-api-key") + + // Check if profile already exists. + existing := cfg.GetProfile(name) + if existing != nil { + // --no-api-key on an existing profile is ambiguous: either it + // already has the right shape, or the user wants something else. + // Tell them to use update-profile instead. + if noAPIKeyFlag { + return fmt.Errorf("profile %q already exists; use \"config update-profile\" to modify it", name) + } + // If the existing profile has no API key (bearer-only from auth login), + // offer to add the key to the existing profile. + creds, err := keyring.GetCredentials(name) + if err != nil || creds.HasAPIKey() { + return fmt.Errorf("profile %q already exists", name) + } + + fmt.Printf("Profile %q has no API key. Add one to this profile? [Y/n]: ", name) + reader := bufio.NewReader(os.Stdin) + response, _ := reader.ReadString('\n') + response = strings.TrimSpace(strings.ToLower(response)) + if response != "" && response != "y" && response != "yes" { + output.Info("Aborted") + return nil + } } apiKey, _ := cmd.Flags().GetString("api-key") - apiURL, _ := cmd.Flags().GetString("api-url") - - if apiKey == "" { - fmt.Print("API Key: ") - if term.IsTerminal(int(os.Stdin.Fd())) { - keyBytes, err := term.ReadPassword(int(os.Stdin.Fd())) - fmt.Println() // newline after hidden input - if err != nil { - return fmt.Errorf("reading API key: %w", err) - } - apiKey = string(keyBytes) - } else { - // Fallback for non-terminal input (piped) - reader := bufio.NewReader(os.Stdin) - apiKey, _ = reader.ReadString('\n') + envName, _ := cmd.Flags().GetString("env") + noAPIKey, _ := cmd.Flags().GetBool("no-api-key") + + // Resolve --env up front so we fail fast on unknown names before we + // touch the keychain. + var env *config.Environment + if envName != "" { + env = cfg.GetEnvironment(envName) + if env == nil { + return fmt.Errorf("environment %q not found; configure it with \"zepctl config add-environment\"", envName) } - apiKey = strings.TrimSpace(apiKey) } - if apiKey == "" { - return fmt.Errorf("API key cannot be empty") - } + if !noAPIKey { + if apiKey == "" { + fmt.Print("API Key: ") + if term.IsTerminal(int(os.Stdin.Fd())) { + keyBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() // newline after hidden input + if err != nil { + return fmt.Errorf("reading API key: %w", err) + } + apiKey = string(keyBytes) + } else { + // Fallback for non-terminal input (piped) + reader := bufio.NewReader(os.Stdin) + apiKey, _ = reader.ReadString('\n') + } + apiKey = strings.TrimSpace(apiKey) + } - // Store API key in system keychain - if err := keyring.Set(name, apiKey); err != nil { - return fmt.Errorf("storing API key: %w", err) + if apiKey == "" { + return fmt.Errorf("API key cannot be empty (pass --no-api-key for a bearer-only profile)") + } + + // Store API key in system keychain (preserves bearer token if present). + if err := keyring.Set(name, apiKey); err != nil { + return fmt.Errorf("storing API key: %w", err) + } } - // apiURL can be empty - the SDK will use its default - cfg.Profiles = append(cfg.Profiles, config.Profile{ - Name: name, - APIURL: apiURL, - }) + if existing == nil { + newProfile := config.Profile{Name: name} + applyEnvAndOverrides(&newProfile, env, cmd) + cfg.Profiles = append(cfg.Profiles, newProfile) + } else { + applyEnvAndOverrides(existing, env, cmd) + } if cfg.CurrentProfile == "" { cfg.CurrentProfile = name @@ -143,7 +182,14 @@ var configAddProfileCmd = &cobra.Command{ return fmt.Errorf("saving config: %w", err) } - output.Info("Added profile %q (API key stored in system keychain)", name) + switch { + case existing != nil: + output.Info("Added API key to existing profile %q", name) + case noAPIKey: + output.Info("Added profile %q (no API key -- bearer auth only)", name) + default: + output.Info("Added profile %q (API key stored in system keychain)", name) + } return nil }, } @@ -205,15 +251,131 @@ var configDeleteProfileCmd = &cobra.Command{ }, } +var configUpdateProfileCmd = &cobra.Command{ + Use: "update-profile [name]", + Short: "Update a profile's settings", + Long: `Update fields on an existing profile. If no name is given, updates the +current active profile. Only the flags you provide are changed; other +fields are left as-is.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + name := cfg.CurrentProfile + if len(args) == 1 { + name = args[0] + } + + profile := cfg.GetProfile(name) + if profile == nil { + return fmt.Errorf("profile %q not found", name) + } + + var env *config.Environment + if cmd.Flags().Changed("env") { + envName, _ := cmd.Flags().GetString("env") + env = cfg.GetEnvironment(envName) + if env == nil { + return fmt.Errorf("environment %q not found; configure it with \"zepctl config add-environment\"", envName) + } + } + + changed := applyEnvAndOverrides(profile, env, cmd) + + if cmd.Flags().Changed("api-key") { + v, _ := cmd.Flags().GetString("api-key") + if v == "" { + return fmt.Errorf("API key cannot be empty") + } + if err := keyring.Set(name, v); err != nil { + return fmt.Errorf("storing API key: %w", err) + } + changed = true + } + if setIfChanged(cmd, "project", &profile.ProjectUUID) { + changed = true + } + if setIfChanged(cmd, "account", &profile.AccountUUID) { + changed = true + } + + if !changed { + return fmt.Errorf("no flags provided; use --env, --api-url, --api-key, --project, --account, --oauth-issuer, --oauth-client-id, or --oauth-audience to update a field") + } + + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + output.Info("Updated profile %q", name) + return nil + }, +} + +// applyEnvAndOverrides layers an environment's auth fields onto a profile, +// then lets any explicit per-field flag override the env's value. Returns +// true if any field was set. +func applyEnvAndOverrides(p *config.Profile, env *config.Environment, cmd *cobra.Command) bool { + changed := false + if env != nil { + p.APIURL = env.APIURL + p.OAuthIssuer = env.OAuthIssuer + p.OAuthClientID = env.OAuthClientID + p.OAuthAudience = env.OAuthAudience + changed = true + } + if setIfChanged(cmd, "api-url", &p.APIURL) { + changed = true + } + if setIfChanged(cmd, "oauth-issuer", &p.OAuthIssuer) { + changed = true + } + if setIfChanged(cmd, "oauth-client-id", &p.OAuthClientID) { + changed = true + } + if setIfChanged(cmd, "oauth-audience", &p.OAuthAudience) { + changed = true + } + return changed +} + +// setIfChanged copies the named string flag's value into dst if the user +// passed the flag. Returns true if the flag was set. +func setIfChanged(cmd *cobra.Command, name string, dst *string) bool { + if !cmd.Flags().Changed(name) { + return false + } + v, _ := cmd.Flags().GetString(name) + *dst = v + return true +} + func init() { rootCmd.AddCommand(configCmd) configCmd.AddCommand(configViewCmd) configCmd.AddCommand(configGetProfilesCmd) configCmd.AddCommand(configUseProfileCmd) configCmd.AddCommand(configAddProfileCmd) + configCmd.AddCommand(configUpdateProfileCmd) configCmd.AddCommand(configDeleteProfileCmd) configAddProfileCmd.Flags().String("api-key", "", "API key for the profile") configAddProfileCmd.Flags().String("api-url", "", "API URL for the profile (uses SDK default if not set)") + configAddProfileCmd.Flags().String("oauth-issuer", "", "Override OIDC issuer for `auth login` (uses build-time default if unset)") + configAddProfileCmd.Flags().String("oauth-client-id", "", "Override OAuth client ID for `auth login` (uses build-time default if unset)") + configAddProfileCmd.Flags().String("oauth-audience", "", "Override OAuth audience for `auth login` (uses build-time default if unset)") + configAddProfileCmd.Flags().String("env", "", "Apply a named environment preset (see \"config add-environment\"); explicit per-field flags override the preset") + configAddProfileCmd.Flags().Bool("no-api-key", false, "Create a bearer-only profile with no API key (skip prompt)") + configUpdateProfileCmd.Flags().String("api-key", "", "Update API key (stored in system keychain)") + configUpdateProfileCmd.Flags().String("api-url", "", "Update API URL") + configUpdateProfileCmd.Flags().String("project", "", "Update project UUID") + configUpdateProfileCmd.Flags().String("account", "", "Update account UUID") + configUpdateProfileCmd.Flags().String("oauth-issuer", "", "Update OIDC issuer override (empty string clears the override)") + configUpdateProfileCmd.Flags().String("oauth-client-id", "", "Update OAuth client ID override (empty string clears the override)") + configUpdateProfileCmd.Flags().String("oauth-audience", "", "Update OAuth audience override (empty string clears the override)") + configUpdateProfileCmd.Flags().String("env", "", "Apply a named environment preset, replacing api-url/oauth-issuer/oauth-client-id/oauth-audience; explicit per-field flags override the preset") configDeleteProfileCmd.Flags().Bool("force", false, "Skip confirmation prompt") } diff --git a/internal/cli/config_test.go b/internal/cli/config_test.go new file mode 100644 index 0000000..39a78db --- /dev/null +++ b/internal/cli/config_test.go @@ -0,0 +1,469 @@ +package cli + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/keyring" + "github.com/spf13/viper" +) + +// TestAddProfile_ExistingWithAPIKey_ReturnsError verifies that +// "config add-profile " returns an error when the profile already +// has an API key. +func TestAddProfile_ExistingWithAPIKey_ReturnsError(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeTestConfig(t, tmpDir) // creates "test" profile + _, _ = config.Reload() + + // Seed keychain with an API key for the existing profile. + if err := keyring.Set("test", "z_existing_key"); err != nil { + t.Fatalf("keyring.Set: %v", err) + } + + err := configAddProfileCmd.RunE(configAddProfileCmd, []string{"test"}) + if err == nil { + t.Fatal("expected error when profile already has API key") + } + if !strings.Contains(err.Error(), "already exists") { + t.Errorf("error = %q, want to contain %q", err.Error(), "already exists") + } +} + +// TestAddProfile_NoAPIKey verifies that --no-api-key creates a +// bearer-only profile without prompting for or storing an API key. +func TestAddProfile_NoAPIKey(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte("current-profile: \"\"\nprofiles: []\n"), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + + if err := configAddProfileCmd.Flags().Set("no-api-key", "true"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddProfileCmd.Flags().Set("api-url", "https://api.dev.example.com"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddProfileCmd.Flags().Set("oauth-issuer", "https://dev.kinde.com"); err != nil { + t.Fatalf("setting flag: %v", err) + } + t.Cleanup(func() { + _ = configAddProfileCmd.Flags().Set("no-api-key", "false") + _ = configAddProfileCmd.Flags().Set("api-url", "") + _ = configAddProfileCmd.Flags().Set("oauth-issuer", "") + }) + + if err := configAddProfileCmd.RunE(configAddProfileCmd, []string{"bearer-only"}); err != nil { + t.Fatalf("add-profile: %v", err) + } + + cfg, err := config.Reload() + if err != nil { + t.Fatalf("Reload: %v", err) + } + p := cfg.GetProfile("bearer-only") + if p == nil { + t.Fatal("profile 'bearer-only' missing") + } + if p.OAuthIssuer != "https://dev.kinde.com" { + t.Errorf("OAuthIssuer = %q, want %q", p.OAuthIssuer, "https://dev.kinde.com") + } + + creds, _ := keyring.GetCredentials("bearer-only") + if creds.HasAPIKey() { + t.Errorf("expected no API key in keychain, got %q", creds.APIKey) + } +} + +// TestUpdateProfile_OAuthFields verifies that --oauth-issuer and +// --oauth-client-id flags are persisted to the profile and that +// passing an empty string clears the override. +func TestUpdateProfile_OAuthFields(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeTestConfig(t, tmpDir) // creates "test" profile + _, _ = config.Reload() + + if err := configUpdateProfileCmd.Flags().Set("oauth-issuer", "https://dev.kinde.com"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configUpdateProfileCmd.Flags().Set("oauth-client-id", "dev-client"); err != nil { + t.Fatalf("setting flag: %v", err) + } + t.Cleanup(func() { + _ = configUpdateProfileCmd.Flags().Set("oauth-issuer", "") + _ = configUpdateProfileCmd.Flags().Set("oauth-client-id", "") + }) + + if err := configUpdateProfileCmd.RunE(configUpdateProfileCmd, []string{"test"}); err != nil { + t.Fatalf("update-profile: %v", err) + } + + cfg, err := config.Reload() + if err != nil { + t.Fatalf("Reload: %v", err) + } + p := cfg.GetProfile("test") + if p == nil { + t.Fatal("profile 'test' missing after update") + } + if p.OAuthIssuer != "https://dev.kinde.com" { + t.Errorf("OAuthIssuer = %q, want %q", p.OAuthIssuer, "https://dev.kinde.com") + } + if p.OAuthClientID != "dev-client" { + t.Errorf("OAuthClientID = %q, want %q", p.OAuthClientID, "dev-client") + } + + // Now clear both with explicit empty strings. + if err := configUpdateProfileCmd.Flags().Set("oauth-issuer", ""); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configUpdateProfileCmd.Flags().Set("oauth-client-id", ""); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configUpdateProfileCmd.RunE(configUpdateProfileCmd, []string{"test"}); err != nil { + t.Fatalf("update-profile (clear): %v", err) + } + cfg, _ = config.Reload() + p = cfg.GetProfile("test") + if p.OAuthIssuer != "" { + t.Errorf("OAuthIssuer = %q, want cleared", p.OAuthIssuer) + } + if p.OAuthClientID != "" { + t.Errorf("OAuthClientID = %q, want cleared", p.OAuthClientID) + } +} + +func resetAddProfileFlags(t *testing.T) { + t.Helper() + resetCmdFlags(t, configAddProfileCmd, + "api-key", "api-url", "oauth-issuer", "oauth-client-id", "oauth-audience", "env", "no-api-key") +} + +func resetUpdateProfileFlags(t *testing.T) { + t.Helper() + resetCmdFlags(t, configUpdateProfileCmd, + "api-key", "api-url", "project", "account", "oauth-issuer", "oauth-client-id", "oauth-audience", "env") +} + +// writeConfigYAML writes a verbatim YAML config to tmpDir/.zepctl/config.yaml. +func writeConfigYAML(t *testing.T, tmpDir, yaml string) { + t.Helper() + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } +} + +func TestAddProfile_WithEnv_AppliesPreset(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeConfigYAML(t, tmpDir, `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://api.dev.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: dev-client-id + oauth-audience: https://api.dev.example.com/api +`) + _, _ = config.Reload() + resetAddProfileFlags(t) + t.Cleanup(func() { resetAddProfileFlags(t) }) + + if err := configAddProfileCmd.Flags().Set("no-api-key", "true"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddProfileCmd.Flags().Set("env", "development"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configAddProfileCmd.RunE(configAddProfileCmd, []string{"dev"}); err != nil { + t.Fatalf("add-profile: %v", err) + } + + cfg, _ := config.Reload() + p := cfg.GetProfile("dev") + if p == nil { + t.Fatal("profile 'dev' missing") + } + if p.APIURL != "https://api.dev.example.com" { + t.Errorf("APIURL = %q, want from env", p.APIURL) + } + if p.OAuthIssuer != "https://issuer.example.com" { + t.Errorf("OAuthIssuer = %q, want from env", p.OAuthIssuer) + } + if p.OAuthClientID != "dev-client-id" { + t.Errorf("OAuthClientID = %q, want from env", p.OAuthClientID) + } + if p.OAuthAudience != "https://api.dev.example.com/api" { + t.Errorf("OAuthAudience = %q, want from env", p.OAuthAudience) + } +} + +func TestAddProfile_WithEnv_ExplicitFlagOverrides(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeConfigYAML(t, tmpDir, `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://api.dev.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: dev-client-id +`) + _, _ = config.Reload() + resetAddProfileFlags(t) + t.Cleanup(func() { resetAddProfileFlags(t) }) + + if err := configAddProfileCmd.Flags().Set("no-api-key", "true"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddProfileCmd.Flags().Set("env", "development"); err != nil { + t.Fatalf("setting flag: %v", err) + } + // Explicit --api-url should win over the env preset. + if err := configAddProfileCmd.Flags().Set("api-url", "http://localhost:8001"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configAddProfileCmd.RunE(configAddProfileCmd, []string{"local"}); err != nil { + t.Fatalf("add-profile: %v", err) + } + + cfg, _ := config.Reload() + p := cfg.GetProfile("local") + if p.APIURL != "http://localhost:8001" { + t.Errorf("APIURL = %q, want explicit override", p.APIURL) + } + // OAuth fields untouched -- still come from the env. + if p.OAuthIssuer != "https://issuer.example.com" { + t.Errorf("OAuthIssuer = %q, want from env", p.OAuthIssuer) + } + if p.OAuthClientID != "dev-client-id" { + t.Errorf("OAuthClientID = %q, want from env", p.OAuthClientID) + } +} + +func TestAddProfile_WithEnv_UnknownErrors(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + emptyConfigDir(t, tmpDir) + _, _ = config.Reload() + resetAddProfileFlags(t) + t.Cleanup(func() { resetAddProfileFlags(t) }) + + if err := configAddProfileCmd.Flags().Set("no-api-key", "true"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddProfileCmd.Flags().Set("env", "missing-env"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + err := configAddProfileCmd.RunE(configAddProfileCmd, []string{"x"}) + if err == nil { + t.Fatal("expected error for unknown environment") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want to mention 'not found'", err.Error()) + } +} + +func TestUpdateProfile_WithEnv_ReplacesAllThreeFields(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeConfigYAML(t, tmpDir, `current-profile: dev +profiles: + - name: dev + api-url: https://stale.example.com + oauth-issuer: https://stale-issuer.example.com + oauth-client-id: stale-client +environments: + - name: development + api-url: https://api.dev.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: fresh-client-id +`) + _, _ = config.Reload() + resetUpdateProfileFlags(t) + t.Cleanup(func() { resetUpdateProfileFlags(t) }) + + if err := configUpdateProfileCmd.Flags().Set("env", "development"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configUpdateProfileCmd.RunE(configUpdateProfileCmd, []string{"dev"}); err != nil { + t.Fatalf("update-profile: %v", err) + } + + cfg, _ := config.Reload() + p := cfg.GetProfile("dev") + if p.APIURL != "https://api.dev.example.com" { + t.Errorf("APIURL = %q, want replaced", p.APIURL) + } + if p.OAuthIssuer != "https://issuer.example.com" { + t.Errorf("OAuthIssuer = %q, want replaced", p.OAuthIssuer) + } + if p.OAuthClientID != "fresh-client-id" { + t.Errorf("OAuthClientID = %q, want replaced", p.OAuthClientID) + } +} + +func TestUpdateProfile_WithEnv_ExplicitFlagOverrides(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeConfigYAML(t, tmpDir, `current-profile: dev +profiles: + - name: dev + api-url: https://stale.example.com +environments: + - name: development + api-url: https://api.dev.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: dev-client-id +`) + _, _ = config.Reload() + resetUpdateProfileFlags(t) + t.Cleanup(func() { resetUpdateProfileFlags(t) }) + + if err := configUpdateProfileCmd.Flags().Set("env", "development"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configUpdateProfileCmd.Flags().Set("oauth-client-id", "override-client"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configUpdateProfileCmd.RunE(configUpdateProfileCmd, []string{"dev"}); err != nil { + t.Fatalf("update-profile: %v", err) + } + + cfg, _ := config.Reload() + p := cfg.GetProfile("dev") + if p.APIURL != "https://api.dev.example.com" { + t.Errorf("APIURL = %q, want from env", p.APIURL) + } + if p.OAuthClientID != "override-client" { + t.Errorf("OAuthClientID = %q, want explicit override", p.OAuthClientID) + } +} + +func TestUpdateProfile_WithEnv_UnknownErrors(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + writeConfigYAML(t, tmpDir, `current-profile: dev +profiles: + - name: dev +`) + _, _ = config.Reload() + resetUpdateProfileFlags(t) + t.Cleanup(func() { resetUpdateProfileFlags(t) }) + + if err := configUpdateProfileCmd.Flags().Set("env", "missing-env"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + err := configUpdateProfileCmd.RunE(configUpdateProfileCmd, []string{"dev"}) + if err == nil { + t.Fatal("expected error for unknown environment") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want to mention 'not found'", err.Error()) + } +} + +// TestAddProfile_NewProfile verifies that "config add-profile " +// creates a new profile with an API key stored in the keychain. +func TestAddProfile_NewProfile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + // Start with a config that has no profiles. + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + cfgContent := "current-profile: \"\"\nprofiles: []\n" + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(cfgContent), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + + // Set the api-key flag on the command so it doesn't prompt on stdin. + if err := configAddProfileCmd.Flags().Set("api-key", "z_brand_new_key"); err != nil { + t.Fatalf("setting flag: %v", err) + } + t.Cleanup(func() { + _ = configAddProfileCmd.Flags().Set("api-key", "") + _ = configAddProfileCmd.Flags().Set("api-url", "") + }) + + err := configAddProfileCmd.RunE(configAddProfileCmd, []string{"brandnew"}) + if err != nil { + t.Fatalf("add-profile: %v", err) + } + + // Verify profile was created in config. + cfg, err := config.Reload() + if err != nil { + t.Fatalf("Reload: %v", err) + } + p := cfg.GetProfile("brandnew") + if p == nil { + t.Fatal("profile 'brandnew' not found after add-profile") + } + + // Verify the new profile was set as current (since no profile was active). + if cfg.CurrentProfile != "brandnew" { + t.Errorf("CurrentProfile = %q, want %q", cfg.CurrentProfile, "brandnew") + } + + // Verify API key in keychain. + creds, err := keyring.GetCredentials("brandnew") + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if creds.APIKey != "z_brand_new_key" { + t.Errorf("APIKey = %q, want %q", creds.APIKey, "z_brand_new_key") + } +} diff --git a/internal/cli/edge.go b/internal/cli/edge.go index a40feca..0a7bef3 100644 --- a/internal/cli/edge.go +++ b/internal/cli/edge.go @@ -22,7 +22,7 @@ var edgeCmd = &cobra.Command{ } var edgeListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List edges", RunE: func(cmd *cobra.Command, args []string) error { userID, _ := cmd.Flags().GetString("user") @@ -32,7 +32,7 @@ var edgeListCmd = &cobra.Command{ return fmt.Errorf("either --user or --graph is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -96,7 +96,7 @@ var edgeGetCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -141,7 +141,7 @@ var edgeUpdateCmd = &cobra.Command{ expiredAt, _ := cmd.Flags().GetString("expired-at") attrsStr, _ := cmd.Flags().GetString("attrs") - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -200,7 +200,7 @@ var edgeDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/environment.go b/internal/cli/environment.go new file mode 100644 index 0000000..070585b --- /dev/null +++ b/internal/cli/environment.go @@ -0,0 +1,184 @@ +package cli + +import ( + "bufio" + "fmt" + "os" + "strings" + + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/output" + "github.com/spf13/cobra" +) + +var configGetEnvironmentsCmd = &cobra.Command{ + Use: "get-environments", + Short: "List all environment presets", + RunE: func(_ *cobra.Command, _ []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + if output.GetFormat() == output.FormatTable { + tbl := output.NewTable("NAME", "API URL", "OAUTH ISSUER", "OAUTH CLIENT ID", "OAUTH AUDIENCE") + tbl.WriteHeader() + for _, e := range cfg.Environments { + tbl.WriteRow(e.Name, e.APIURL, e.OAuthIssuer, e.OAuthClientID, e.OAuthAudience) + } + return tbl.Flush() + } + + return output.Print(cfg.Environments) + }, +} + +var configAddEnvironmentCmd = &cobra.Command{ + Use: "add-environment ", + Short: "Add a named environment preset (api-url, oauth-issuer, oauth-client-id)", + Long: `Add a reusable environment preset that can be applied to profiles via +"config add-profile --env " or "config update-profile --env ". + +Environments are stored in the user's config file, not the binary, so +non-default endpoints stay out of distributed builds.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + if cfg.GetEnvironment(name) != nil { + return fmt.Errorf("environment %q already exists; use \"config update-environment\" to modify it", name) + } + + apiURL, _ := cmd.Flags().GetString("api-url") + oauthIssuer, _ := cmd.Flags().GetString("oauth-issuer") + oauthClientID, _ := cmd.Flags().GetString("oauth-client-id") + oauthAudience, _ := cmd.Flags().GetString("oauth-audience") + + cfg.Environments = append(cfg.Environments, config.Environment{ + Name: name, + APIURL: apiURL, + OAuthIssuer: oauthIssuer, + OAuthClientID: oauthClientID, + OAuthAudience: oauthAudience, + }) + + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + output.Info("Added environment %q", name) + return nil + }, +} + +var configUpdateEnvironmentCmd = &cobra.Command{ + Use: "update-environment ", + Short: "Update an environment's settings", + Long: `Update fields on an existing environment. Only the flags you provide are +changed; other fields are left as-is. Pass an empty string to clear a field.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + env := cfg.GetEnvironment(name) + if env == nil { + return fmt.Errorf("environment %q not found", name) + } + + changed := setIfChanged(cmd, "api-url", &env.APIURL) + if setIfChanged(cmd, "oauth-issuer", &env.OAuthIssuer) { + changed = true + } + if setIfChanged(cmd, "oauth-client-id", &env.OAuthClientID) { + changed = true + } + if setIfChanged(cmd, "oauth-audience", &env.OAuthAudience) { + changed = true + } + + if !changed { + return fmt.Errorf("no flags provided; use --api-url, --oauth-issuer, --oauth-client-id, or --oauth-audience to update a field") + } + + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + output.Info("Updated environment %q", name) + return nil + }, +} + +var configDeleteEnvironmentCmd = &cobra.Command{ + Use: "delete-environment ", + Short: "Remove an environment preset", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + force, _ := cmd.Flags().GetBool("force") + + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + if cfg.GetEnvironment(name) == nil { + return fmt.Errorf("environment %q not found", name) + } + + if !force { + fmt.Printf("Delete environment %q? [y/N]: ", name) + reader := bufio.NewReader(os.Stdin) + response, _ := reader.ReadString('\n') + response = strings.TrimSpace(strings.ToLower(response)) + if response != "y" && response != "yes" { + output.Info("Aborted") + return nil + } + } + + var kept []config.Environment + for _, e := range cfg.Environments { + if e.Name != name { + kept = append(kept, e) + } + } + cfg.Environments = kept + + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + output.Info("Deleted environment %q", name) + return nil + }, +} + +func init() { + configCmd.AddCommand(configGetEnvironmentsCmd) + configCmd.AddCommand(configAddEnvironmentCmd) + configCmd.AddCommand(configUpdateEnvironmentCmd) + configCmd.AddCommand(configDeleteEnvironmentCmd) + + configAddEnvironmentCmd.Flags().String("api-url", "", "API URL for the environment") + configAddEnvironmentCmd.Flags().String("oauth-issuer", "", "OIDC issuer for the environment") + configAddEnvironmentCmd.Flags().String("oauth-client-id", "", "OAuth client ID for the environment") + configAddEnvironmentCmd.Flags().String("oauth-audience", "", "OAuth audience for bearer token aud claim") + + configUpdateEnvironmentCmd.Flags().String("api-url", "", "Update API URL (empty string clears it)") + configUpdateEnvironmentCmd.Flags().String("oauth-issuer", "", "Update OIDC issuer (empty string clears it)") + configUpdateEnvironmentCmd.Flags().String("oauth-client-id", "", "Update OAuth client ID (empty string clears it)") + configUpdateEnvironmentCmd.Flags().String("oauth-audience", "", "Update OAuth audience (empty string clears it)") + + configDeleteEnvironmentCmd.Flags().Bool("force", false, "Skip confirmation prompt") +} diff --git a/internal/cli/environment_test.go b/internal/cli/environment_test.go new file mode 100644 index 0000000..598d191 --- /dev/null +++ b/internal/cli/environment_test.go @@ -0,0 +1,316 @@ +package cli + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/getzep/zepctl/internal/config" + "github.com/spf13/viper" +) + +// emptyConfigDir writes a config.yaml with no profiles or environments to +// tmpDir/.zepctl/, so config.Load() returns a clean Config. +func emptyConfigDir(t *testing.T, tmpDir string) { + t.Helper() + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + cfgContent := "current-profile: \"\"\nprofiles: []\n" + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(cfgContent), 0o600); err != nil { + t.Fatal(err) + } +} + +func resetEnvFlags(t *testing.T) { + t.Helper() + envFlags := []string{"api-url", "oauth-issuer", "oauth-client-id", "oauth-audience"} + resetCmdFlags(t, configAddEnvironmentCmd, envFlags...) + resetCmdFlags(t, configUpdateEnvironmentCmd, envFlags...) + resetCmdFlags(t, configDeleteEnvironmentCmd, "force") +} + +func TestAddEnvironment_New(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + emptyConfigDir(t, tmpDir) + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + if err := configAddEnvironmentCmd.Flags().Set("api-url", "https://api.dev.example.com"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddEnvironmentCmd.Flags().Set("oauth-issuer", "https://issuer.example.com"); err != nil { + t.Fatalf("setting flag: %v", err) + } + if err := configAddEnvironmentCmd.Flags().Set("oauth-client-id", "client-xyz"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configAddEnvironmentCmd.RunE(configAddEnvironmentCmd, []string{"development"}); err != nil { + t.Fatalf("add-environment: %v", err) + } + + cfg, err := config.Reload() + if err != nil { + t.Fatalf("Reload: %v", err) + } + env := cfg.GetEnvironment("development") + if env == nil { + t.Fatal("environment 'development' missing after add-environment") + } + if env.APIURL != "https://api.dev.example.com" { + t.Errorf("APIURL = %q, want %q", env.APIURL, "https://api.dev.example.com") + } + if env.OAuthIssuer != "https://issuer.example.com" { + t.Errorf("OAuthIssuer = %q", env.OAuthIssuer) + } + if env.OAuthClientID != "client-xyz" { + t.Errorf("OAuthClientID = %q", env.OAuthClientID) + } +} + +func TestAddEnvironment_Duplicate_Errors(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://api.dev.example.com +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + err := configAddEnvironmentCmd.RunE(configAddEnvironmentCmd, []string{"development"}) + if err == nil { + t.Fatal("expected error for duplicate environment") + } + if !strings.Contains(err.Error(), "already exists") { + t.Errorf("error = %q, want to mention 'already exists'", err.Error()) + } +} + +func TestUpdateEnvironment_PartialUpdate(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://old.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: old-client +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + if err := configUpdateEnvironmentCmd.Flags().Set("api-url", "https://new.example.com"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configUpdateEnvironmentCmd.RunE(configUpdateEnvironmentCmd, []string{"development"}); err != nil { + t.Fatalf("update-environment: %v", err) + } + + cfg, _ := config.Reload() + env := cfg.GetEnvironment("development") + if env == nil { + t.Fatal("environment missing") + } + if env.APIURL != "https://new.example.com" { + t.Errorf("APIURL = %q, want updated", env.APIURL) + } + // Untouched fields preserved. + if env.OAuthIssuer != "https://issuer.example.com" { + t.Errorf("OAuthIssuer = %q, want preserved", env.OAuthIssuer) + } + if env.OAuthClientID != "old-client" { + t.Errorf("OAuthClientID = %q, want preserved", env.OAuthClientID) + } +} + +func TestUpdateEnvironment_ClearField(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: "" +profiles: [] +environments: + - name: development + oauth-client-id: to-be-cleared +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + // Explicit empty string clears the field (mirrors update-profile semantics). + if err := configUpdateEnvironmentCmd.Flags().Set("oauth-client-id", ""); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configUpdateEnvironmentCmd.RunE(configUpdateEnvironmentCmd, []string{"development"}); err != nil { + t.Fatalf("update-environment: %v", err) + } + + cfg, _ := config.Reload() + env := cfg.GetEnvironment("development") + if env.OAuthClientID != "" { + t.Errorf("OAuthClientID = %q, want cleared", env.OAuthClientID) + } +} + +func TestUpdateEnvironment_NotFound(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + emptyConfigDir(t, tmpDir) + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + if err := configUpdateEnvironmentCmd.Flags().Set("api-url", "x"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + err := configUpdateEnvironmentCmd.RunE(configUpdateEnvironmentCmd, []string{"missing"}) + if err == nil { + t.Fatal("expected error for unknown environment") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want to mention 'not found'", err.Error()) + } +} + +func TestUpdateEnvironment_NoFlagsProvided(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://x.example.com +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + err := configUpdateEnvironmentCmd.RunE(configUpdateEnvironmentCmd, []string{"development"}) + if err == nil { + t.Fatal("expected error when no flags provided") + } + if !strings.Contains(err.Error(), "no flags provided") { + t.Errorf("error = %q", err.Error()) + } +} + +func TestDeleteEnvironment_Force(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + yaml := `current-profile: "" +profiles: [] +environments: + - name: development + api-url: https://x.example.com + - name: local +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatal(err) + } + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + if err := configDeleteEnvironmentCmd.Flags().Set("force", "true"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + if err := configDeleteEnvironmentCmd.RunE(configDeleteEnvironmentCmd, []string{"development"}); err != nil { + t.Fatalf("delete-environment: %v", err) + } + + cfg, _ := config.Reload() + if cfg.GetEnvironment("development") != nil { + t.Error("development environment should be deleted") + } + if cfg.GetEnvironment("local") == nil { + t.Error("local environment should be preserved") + } +} + +func TestDeleteEnvironment_NotFound(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + emptyConfigDir(t, tmpDir) + _, _ = config.Reload() + resetEnvFlags(t) + t.Cleanup(func() { resetEnvFlags(t) }) + + if err := configDeleteEnvironmentCmd.Flags().Set("force", "true"); err != nil { + t.Fatalf("setting flag: %v", err) + } + + err := configDeleteEnvironmentCmd.RunE(configDeleteEnvironmentCmd, []string{"missing"}) + if err == nil { + t.Fatal("expected error for unknown environment") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q", err.Error()) + } +} diff --git a/internal/cli/episode.go b/internal/cli/episode.go index baca772..46b7a96 100644 --- a/internal/cli/episode.go +++ b/internal/cli/episode.go @@ -21,7 +21,7 @@ var episodeCmd = &cobra.Command{ } var episodeListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List episodes", RunE: func(cmd *cobra.Command, args []string) error { userID, _ := cmd.Flags().GetString("user") @@ -32,7 +32,7 @@ var episodeListCmd = &cobra.Command{ return fmt.Errorf("either --user or --graph is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -95,7 +95,7 @@ var episodeGetCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -138,7 +138,7 @@ var episodeMentionsCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -171,7 +171,7 @@ var episodeDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/graph.go b/internal/cli/graph.go index c7eb6f1..e18b9e1 100644 --- a/internal/cli/graph.go +++ b/internal/cli/graph.go @@ -22,10 +22,10 @@ var graphCmd = &cobra.Command{ } var graphListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List graphs", RunE: func(cmd *cobra.Command, args []string) error { - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -87,7 +87,7 @@ var graphCreateCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { graphID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -124,7 +124,7 @@ var graphDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -164,7 +164,7 @@ var graphCloneCmd = &cobra.Command{ return fmt.Errorf("--target-user cannot be used with --source-graph; use --target-graph instead") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -232,7 +232,7 @@ var graphAddCmd = &cobra.Command{ return fmt.Errorf("either graph-id argument or --user flag is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -371,7 +371,7 @@ Example: --source-attrs '{"type": "Person", "age": 30}'`, return fmt.Errorf("--target-node is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -398,10 +398,18 @@ Example: --source-attrs '{"type": "Person", "age": 30}'`, } if sourceLabelsStr != "" { - req.SourceNodeLabels = strings.Split(sourceLabelsStr, ",") + labels := strings.Split(sourceLabelsStr, ",") + if len(labels) > 1 { + return fmt.Errorf("--source-label accepts a single value; got %d", len(labels)) + } + req.SourceNodeLabels = labels } if targetLabelsStr != "" { - req.TargetNodeLabels = strings.Split(targetLabelsStr, ",") + labels := strings.Split(targetLabelsStr, ",") + if len(labels) > 1 { + return fmt.Errorf("--target-label accepts a single value; got %d", len(labels)) + } + req.TargetNodeLabels = labels } // Parse source node attributes @@ -488,7 +496,7 @@ Date filters allow filtering by date fields (created_at, valid_at, invalid_at, e return fmt.Errorf("either --user or --graph is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -673,7 +681,7 @@ mutually exclusive with seed selection flags (--node-labels/--edge-types/--node- return fmt.Errorf("--query is mutually exclusive with --node-labels/--edge-types/--node-uuids") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -981,8 +989,8 @@ func init() { graphAddFactCmd.Flags().String("source-attrs", "", "Source node attributes as JSON") graphAddFactCmd.Flags().String("edge-attrs", "", "Edge attributes as JSON") graphAddFactCmd.Flags().String("target-attrs", "", "Target node attributes as JSON") - graphAddFactCmd.Flags().String("source-labels", "", "Comma-separated labels for the source node") - graphAddFactCmd.Flags().String("target-labels", "", "Comma-separated labels for the target node") + graphAddFactCmd.Flags().String("source-labels", "", "Label for the source node (single value; API accepts at most one)") + graphAddFactCmd.Flags().String("target-labels", "", "Label for the target node (single value; API accepts at most one)") // Search flags graphSearchCmd.Flags().String("user", "", "Search user graph") diff --git a/internal/cli/node.go b/internal/cli/node.go index 17c2e41..a673479 100644 --- a/internal/cli/node.go +++ b/internal/cli/node.go @@ -22,7 +22,7 @@ var nodeCmd = &cobra.Command{ } var nodeListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List nodes", RunE: func(cmd *cobra.Command, args []string) error { userID, _ := cmd.Flags().GetString("user") @@ -32,7 +32,7 @@ var nodeListCmd = &cobra.Command{ return fmt.Errorf("either --user or --graph is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -92,7 +92,7 @@ var nodeGetCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -127,7 +127,7 @@ var nodeEdgesCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -162,7 +162,7 @@ var nodeEpisodesCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -204,7 +204,7 @@ var nodeUpdateCmd = &cobra.Command{ labelsStr, _ := cmd.Flags().GetString("labels") attrsStr, _ := cmd.Flags().GetString("attrs") - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -257,7 +257,7 @@ var nodeDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/observation.go b/internal/cli/observation.go index e6dda7c..94534b4 100644 --- a/internal/cli/observation.go +++ b/internal/cli/observation.go @@ -18,7 +18,7 @@ var observationCmd = &cobra.Command{ } var observationListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List observations", RunE: func(cmd *cobra.Command, args []string) error { userID, _ := cmd.Flags().GetString("user") @@ -30,7 +30,7 @@ var observationListCmd = &cobra.Command{ return fmt.Errorf("either --user or --graph is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -86,7 +86,7 @@ var observationGetCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { uuid := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/ontology.go b/internal/cli/ontology.go index c104bed..4e5b0a5 100644 --- a/internal/cli/ontology.go +++ b/internal/cli/ontology.go @@ -24,7 +24,7 @@ var ontologyGetCmd = &cobra.Command{ Use: "get", Short: "Get ontology definitions", RunE: func(cmd *cobra.Command, args []string) error { - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -66,7 +66,7 @@ var ontologySetCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/policy_set.go b/internal/cli/policy_set.go new file mode 100644 index 0000000..ec8024f --- /dev/null +++ b/internal/cli/policy_set.go @@ -0,0 +1,304 @@ +package cli + +import ( + "bufio" + "fmt" + "os" + "strings" + + "github.com/google/uuid" + "github.com/spf13/cobra" + "golang.org/x/term" + "gopkg.in/yaml.v3" + + "github.com/getzep/zepctl/internal/abac" + "github.com/getzep/zepctl/internal/client" + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/output" +) + +// ExitCodeError carries a specific exit code. The main function inspects +// this to exit with a code other than 1 (used by policy-set validate). +type ExitCodeError struct { + Code int + Err error +} + +func (e *ExitCodeError) Error() string { return e.Err.Error() } +func (e *ExitCodeError) Unwrap() error { return e.Err } + +// newABACClient creates an ABAC API client from the current command context. +// It is a variable so tests can replace it with a version that returns a +// pre-configured client pointing at an httptest server. +var newABACClient = newABACClientDefault + +func newABACClientDefault(cmd *cobra.Command) (*abac.Client, error) { + httpClient, baseURL, err := client.NewBearerHTTPClient(cmd.Context()) + if err != nil { + return nil, err + } + projectUUID := config.GetProjectUUID() + if projectUUID == "" { + return nil, fmt.Errorf("no project configured; run \"zepctl config set-project\" to select a project") + } + return abac.NewClient(httpClient, baseURL, projectUUID), nil +} + +// validateUUID parses value as a UUID and returns a formatted error on failure. +func validateUUID(value, label string) error { + if _, err := uuid.Parse(value); err != nil { + return fmt.Errorf("invalid %s UUID: %q", label, value) + } + return nil +} + +// --- Commands --- + +var policySetCmd = &cobra.Command{ + Use: "policy-set", + Short: "Manage ABAC policy sets", + Long: "Create, list, update, delete, and validate ABAC policy sets.", +} + +var policySetListCmd = &cobra.Command{ + Use: listCmdUse, + Short: "List policy sets in the current project", + RunE: func(cmd *cobra.Command, _ []string) error { + ac, err := newABACClient(cmd) + if err != nil { + return err + } + result, err := ac.ListPolicySets(cmd.Context()) + if err != nil { + return fmt.Errorf("listing policy sets: %w", err) + } + if output.GetFormat() == output.FormatTable { + tbl := output.NewTable("UUID", "NAME", "MODE", "VERSION") + tbl.WriteHeader() + for _, ps := range result.PolicySets { + tbl.WriteRow(ps.UUID, ps.Name, ps.Mode, fmt.Sprintf("%d", ps.Version)) + } + return tbl.Flush() + } + return output.Print(result) + }, +} + +var policySetGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get a policy set by UUID", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "policy set"); err != nil { + return err + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + ps, err := ac.GetPolicySet(cmd.Context(), args[0]) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("policy set not found: %s", args[0]) + } + return fmt.Errorf("getting policy set: %w", err) + } + if output.GetFormat() == output.FormatTable { + fmt.Fprintf(cmd.OutOrStdout(), "UUID: %s\n", ps.UUID) + fmt.Fprintf(cmd.OutOrStdout(), "Name: %s\n", ps.Name) + if ps.Description != "" { + fmt.Fprintf(cmd.OutOrStdout(), "Description: %s\n", ps.Description) + } + fmt.Fprintf(cmd.OutOrStdout(), "Mode: %s\n", ps.Mode) + fmt.Fprintf(cmd.OutOrStdout(), "Version: %d\n", ps.Version) + fmt.Fprintf(cmd.OutOrStdout(), "Project: %s\n", ps.ProjectUUID) + fmt.Fprintf(cmd.OutOrStdout(), "Created: %s\n", ps.CreatedAt) + fmt.Fprintf(cmd.OutOrStdout(), "Updated: %s\n", ps.UpdatedAt) + if ps.Spec != nil { + specYAML, err := yaml.Marshal(ps.Spec) + if err == nil { + fmt.Fprintf(cmd.OutOrStdout(), "\nSpec:\n") + for _, line := range strings.Split(strings.TrimRight(string(specYAML), "\n"), "\n") { + fmt.Fprintf(cmd.OutOrStdout(), " %s\n", line) + } + } + } + return nil + } + return output.Print(ps) + }, +} + +var policySetCreateCmd = &cobra.Command{ + Use: "create", + Short: "Create a policy set from a YAML file", + RunE: func(cmd *cobra.Command, _ []string) error { + filePath, _ := cmd.Flags().GetString("file") + data, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("reading file: %w", err) + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + ps, err := ac.CreatePolicySet(cmd.Context(), string(data)) + if err != nil { + return fmt.Errorf("creating policy set: %w", err) + } + if output.GetFormat() == output.FormatTable { + output.Info("Created policy set %q (%s..., version %d)", ps.Name, truncateUUID(ps.UUID), ps.Version) + return nil + } + return output.Print(ps) + }, +} + +var policySetUpdateCmd = &cobra.Command{ + Use: "update ", + Short: "Update a policy set from a YAML file", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "policy set"); err != nil { + return err + } + filePath, _ := cmd.Flags().GetString("file") + data, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("reading file: %w", err) + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + ps, err := ac.UpdatePolicySet(cmd.Context(), args[0], string(data)) + if err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("policy set not found: %s", args[0]) + } + return fmt.Errorf("updating policy set: %w", err) + } + if output.GetFormat() == output.FormatTable { + output.Info("Updated policy set %q (version %d)", ps.Name, ps.Version) + return nil + } + return output.Print(ps) + }, +} + +var policySetDeleteCmd = &cobra.Command{ + Use: "delete ", + Short: "Delete a policy set", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateUUID(args[0], "policy set"); err != nil { + return err + } + force, _ := cmd.Flags().GetBool("force") + if !force { + if !term.IsTerminal(int(os.Stdin.Fd())) { + return fmt.Errorf("use --force to delete without confirmation") + } + fmt.Fprintf(cmd.ErrOrStderr(), + "Delete policy set %s? This will also remove all attachments. [y/N]: ", args[0]) + reader := bufio.NewReader(os.Stdin) + response, _ := reader.ReadString('\n') + response = strings.TrimSpace(strings.ToLower(response)) + if response != "y" && response != "yes" { + return nil + } + } + ac, err := newABACClient(cmd) + if err != nil { + return err + } + if err := ac.DeletePolicySet(cmd.Context(), args[0]); err != nil { + if abac.IsNotFound(err) { + return fmt.Errorf("policy set not found: %s", args[0]) + } + return fmt.Errorf("deleting policy set: %w", err) + } + output.Info("Deleted policy set %s", args[0]) + return nil + }, +} + +var policySetValidateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate a policy set YAML file", + RunE: func(cmd *cobra.Command, _ []string) error { + filePath, _ := cmd.Flags().GetString("file") + data, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("reading file: %w", err) + } + ac, err := newABACClient(cmd) + if err != nil { + return &ExitCodeError{Code: 2, Err: err} + } + result, err := ac.ValidatePolicySet(cmd.Context(), string(data)) + if err != nil { + return &ExitCodeError{Code: 2, Err: fmt.Errorf("validating policy set: %w", err)} + } + + if output.GetFormat() != output.FormatTable { + if err := output.Print(result); err != nil { + return err + } + } + + if result.Valid { + if output.GetFormat() == output.FormatTable { + output.Info("Validation passed.") + } + return nil + } + + // Invalid -- print errors and exit 1. + if output.GetFormat() == output.FormatTable { + fmt.Fprintln(cmd.ErrOrStderr(), "Validation failed:") + for _, ve := range result.Errors { + if ve.PolicyID != "" { + fmt.Fprintf(cmd.ErrOrStderr(), " - %s (policy: %s)\n", ve.Message, ve.PolicyID) + } else { + fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", ve.Message) + } + } + } + cmd.SilenceErrors = true + return &ExitCodeError{Code: 1, Err: fmt.Errorf("validation failed")} + }, +} + +// truncateUUID returns the first 8 characters of a UUID followed by "...". +func truncateUUID(id string) string { + if len(id) > 8 { + return id[:8] + "..." + } + return id +} + +func init() { + rootCmd.AddCommand(policySetCmd) + policySetCmd.AddCommand(policySetListCmd) + policySetCmd.AddCommand(policySetGetCmd) + policySetCmd.AddCommand(policySetCreateCmd) + policySetCmd.AddCommand(policySetUpdateCmd) + policySetCmd.AddCommand(policySetDeleteCmd) + policySetCmd.AddCommand(policySetValidateCmd) + + policySetCreateCmd.Flags().String("file", "", "Path to policy set YAML file") + _ = policySetCreateCmd.MarkFlagRequired("file") + policySetUpdateCmd.Flags().String("file", "", "Path to updated policy set YAML file") + _ = policySetUpdateCmd.MarkFlagRequired("file") + policySetValidateCmd.Flags().String("file", "", "Path to policy set YAML file to validate") + _ = policySetValidateCmd.MarkFlagRequired("file") + policySetDeleteCmd.Flags().Bool("force", false, "Skip confirmation prompt") + + for _, cmd := range []*cobra.Command{ + policySetListCmd, policySetGetCmd, policySetCreateCmd, + policySetUpdateCmd, policySetDeleteCmd, policySetValidateCmd, + } { + client.SetCredentialType(cmd, client.CredentialBearer) + } +} diff --git a/internal/cli/policy_set_test.go b/internal/cli/policy_set_test.go new file mode 100644 index 0000000..92dbf37 --- /dev/null +++ b/internal/cli/policy_set_test.go @@ -0,0 +1,481 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + + "github.com/getzep/zepctl/internal/abac" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Output capture helpers --- + +// captureStderr redirects os.Stderr during fn and returns what was written. +// (captureStdout is defined in auth_test.go and shared across this package.) +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + old := os.Stderr + r, w, err := os.Pipe() + require.NoError(t, err) + os.Stderr = w + defer func() { + w.Close() + os.Stderr = old + }() + + fn() + + w.Close() + os.Stderr = old + + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + return buf.String() +} + +// testUUID is an arbitrary valid UUID used throughout tests. +const testUUID = "9f1a2b3c-4d5e-6f7a-8b9c-0d1e2f3a4b5c" + +// --- UUID Validation --- + +func TestValidateUUID_Valid(t *testing.T) { + err := validateUUID(testUUID, "policy set") + assert.NoError(t, err) +} + +func TestValidateUUID_Invalid(t *testing.T) { + err := validateUUID("not-a-uuid", "policy set") + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid policy set UUID: "not-a-uuid"`) +} + +func TestValidateUUID_Empty(t *testing.T) { + err := validateUUID("", "API key") + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid API key UUID`) +} + +// --- ExitCodeError --- + +func TestExitCodeError(t *testing.T) { + inner := errors.New("something failed") + err := &ExitCodeError{Code: 2, Err: inner} + + assert.Equal(t, "something failed", err.Error()) + assert.Equal(t, 2, err.Code) + assert.True(t, errors.Is(err, inner)) +} + +// --- Truncate UUID --- + +func TestTruncateUUID(t *testing.T) { + assert.Equal(t, "9f1a2b3c...", truncateUUID(testUUID)) + assert.Equal(t, "short", truncateUUID("short")) +} + +// --- Credential Type --- + +func TestPolicySetCommands_CredentialType(t *testing.T) { + cmds := map[string]*cobra.Command{ + "list": policySetListCmd, + "get": policySetGetCmd, + "create": policySetCreateCmd, + "update": policySetUpdateCmd, + "delete": policySetDeleteCmd, + "validate": policySetValidateCmd, + } + for name, cmd := range cmds { + t.Run(name, func(t *testing.T) { + require.NotNil(t, cmd.Annotations, "command %q has no annotations", name) + assert.Equal(t, "bearer", cmd.Annotations["zepctl_credential_type"], + "command %q should declare CredentialBearer", name) + }) + } +} + +// --- Policy Set List (Sections 3.1.2, 3.1.3) --- + +func TestPolicySetList_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.PolicySetList{ + PolicySets: []abac.PolicySetSummary{ + {UUID: "uuid-1", Name: "policy_one", Mode: "enforce", Version: 2}, + {UUID: "uuid-2", Name: "policy_two", Mode: "off", Version: 1}, + }, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetListCmd) + var runErr error + out := captureStdout(t, func() { + runErr = cmd.RunE(cmd, nil) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "UUID") + assert.Contains(t, out, "NAME") + assert.Contains(t, out, "MODE") + assert.Contains(t, out, "VERSION") + assert.Contains(t, out, "uuid-1") + assert.Contains(t, out, "policy_one") + assert.Contains(t, out, "enforce") + assert.Contains(t, out, "uuid-2") + assert.Contains(t, out, "policy_two") + assert.Contains(t, out, "off") +} + +func TestPolicySetList_EmptyResult(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.PolicySetList{PolicySets: []abac.PolicySetSummary{}}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetListCmd) + var runErr error + out := captureStdout(t, func() { + runErr = cmd.RunE(cmd, nil) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "UUID") + assert.Contains(t, out, "NAME") + lines := strings.Split(strings.TrimSpace(out), "\n") + assert.Equal(t, 1, len(lines), "expected only header row for empty result") +} + +// --- Policy Set Get --- + +func TestPolicySetGet_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.PolicySet{ + UUID: testUUID, + Name: "test_policy", Description: "A test policy", + Mode: "report_only", Version: 3, + ProjectUUID: "proj-uuid", + CreatedAt: "2026-04-15T10:30:00Z", + UpdatedAt: "2026-04-20T14:15:00Z", + Spec: map[string]any{"policies": []any{}}, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetGetCmd) + var buf bytes.Buffer + cmd.SetOut(&buf) + + err := cmd.RunE(cmd, []string{testUUID}) + assert.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "UUID:") + assert.Contains(t, out, "9f1a2b3c-4d5e-6f7a-8b9c-0d1e2f3a4b5c") + assert.Contains(t, out, "Name:") + assert.Contains(t, out, "test_policy") + assert.Contains(t, out, "Description:") + assert.Contains(t, out, "A test policy") + assert.Contains(t, out, "Mode:") + assert.Contains(t, out, "report_only") + assert.Contains(t, out, "Version:") + assert.Contains(t, out, "3") + assert.Contains(t, out, "Project:") + assert.Contains(t, out, "proj-uuid") + assert.Contains(t, out, "Created:") + assert.Contains(t, out, "Updated:") + assert.Contains(t, out, "Spec:") + assert.Contains(t, out, "policies") +} + +func TestPolicySetGet_InvalidUUID(t *testing.T) { + err := policySetGetCmd.RunE(policySetGetCmd, []string{"not-valid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid policy set UUID: "not-valid"`) +} + +// --- Policy Set Create --- + +func TestPolicySetCreate_FileNotFound(t *testing.T) { + cmd := newTestCommand(policySetCreateCmd) + _ = cmd.Flags().Set("file", "/nonexistent/path.yaml") + + err := cmd.RunE(cmd, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "reading file:") +} + +func TestPolicySetCreate_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(abac.PolicySet{ + UUID: testUUID, + Name: "new_policy", Mode: "off", Version: 1, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetCreateCmd) + _ = cmd.Flags().Set("file", writeTestFile(t, "yaml content")) + + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, nil) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "Created policy set") + assert.Contains(t, out, `"new_policy"`) + assert.Contains(t, out, "9f1a2b3c...") + assert.Contains(t, out, "version 1") +} + +// --- Policy Set Update (Sections 3.4.2, 3.4.3, 3.4.4) --- + +func TestPolicySetUpdate_InvalidUUID(t *testing.T) { + err := policySetUpdateCmd.RunE(policySetUpdateCmd, []string{"bad-uuid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid policy set UUID: "bad-uuid"`) +} + +func TestPolicySetUpdate_FileNotFound(t *testing.T) { + cmd := newTestCommand(policySetUpdateCmd) + _ = cmd.Flags().Set("file", "/nonexistent/path.yaml") + + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "reading file:") +} + +func TestPolicySetUpdate_TableOutput(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.PolicySet{ + UUID: testUUID, + Name: "updated_policy", Mode: "enforce", Version: 4, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetUpdateCmd) + _ = cmd.Flags().Set("file", writeTestFile(t, "yaml content")) + + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, []string{testUUID}) + }) + assert.NoError(t, runErr) + + assert.Contains(t, out, "Updated policy set") + assert.Contains(t, out, `"updated_policy"`) + assert.Contains(t, out, "version 4") +} + +func TestPolicySetUpdate_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetUpdateCmd) + _ = cmd.Flags().Set("file", writeTestFile(t, "yaml content")) + + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "policy set not found: 9f1a2b3c") +} + +// --- Policy Set Delete (Sections 3.5.3, 3.5.4, 3.5.5) --- + +func TestPolicySetDelete_NonInteractiveWithoutForce(t *testing.T) { + // Replace stdin with a closed pipe (non-terminal, no data). + oldStdin := os.Stdin + r, w, _ := os.Pipe() + os.Stdin = r + w.Close() + defer func() { os.Stdin = oldStdin }() + + err := policySetDeleteCmd.RunE(policySetDeleteCmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "use --force to delete without confirmation") +} + +func TestPolicySetDelete_Force(t *testing.T) { + var gotMethod, gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetDeleteCmd) + _ = cmd.Flags().Set("force", "true") + defer func() { _ = policySetDeleteCmd.Flags().Set("force", "false") }() + + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, []string{testUUID}) + }) + assert.NoError(t, runErr) + assert.Equal(t, http.MethodDelete, gotMethod) + assert.Equal(t, "/api/v2/abac/policy-sets/"+testUUID, gotPath) + assert.Contains(t, out, "Deleted policy set") + assert.Contains(t, out, testUUID) +} + +func TestPolicySetDelete_InvalidUUID(t *testing.T) { + err := policySetDeleteCmd.RunE(policySetDeleteCmd, []string{"bad-uuid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid policy set UUID: "bad-uuid"`) +} + +func TestPolicySetDelete_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetDeleteCmd) + _ = cmd.Flags().Set("force", "true") + defer func() { _ = policySetDeleteCmd.Flags().Set("force", "false") }() + + err := cmd.RunE(cmd, []string{testUUID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "policy set not found: 9f1a2b3c") +} + +// --- Policy Set Validate (Sections 3.6.3, 3.6.4, 3.6.5) --- + +// withTestABACClient replaces newABACClient with one that returns a client +// pointing at srv. Restores the original on cleanup. +func withTestABACClient(t *testing.T, srv *httptest.Server) { + t.Helper() + orig := newABACClient + newABACClient = func(_ *cobra.Command) (*abac.Client, error) { + return abac.NewClient(srv.Client(), srv.URL, "test-project"), nil + } + t.Cleanup(func() { newABACClient = orig }) +} + +func TestPolicySetValidate_FileNotFound(t *testing.T) { + cmd := newTestCommand(policySetValidateCmd) + _ = cmd.Flags().Set("file", "/nonexistent/path.yaml") + + err := cmd.RunE(cmd, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "reading file:") + // File errors should NOT be ExitCodeError (default exit 1). + var exitErr *ExitCodeError + assert.False(t, errors.As(err, &exitErr)) +} + +func TestPolicySetValidate_TableOutput_Passed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.ValidationResult{Valid: true, Errors: []abac.ValidationError{}}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetValidateCmd) + _ = cmd.Flags().Set("file", writeTestFile(t, "valid yaml")) + + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, nil) + }) + assert.NoError(t, runErr) + assert.Contains(t, out, "Validation passed.") +} + +func TestPolicySetValidate_TableOutput_Failed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(abac.ValidationResult{ + Valid: false, + Errors: []abac.ValidationError{ + {PolicyID: "bad_policy", Message: "unrecognized action"}, + {Message: "missing required field"}, + }, + }) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetValidateCmd) + _ = cmd.Flags().Set("file", writeTestFile(t, "bad yaml")) + + var runErr error + out := captureStderr(t, func() { + runErr = cmd.RunE(cmd, nil) + }) + require.Error(t, runErr) + + var exitErr *ExitCodeError + require.True(t, errors.As(runErr, &exitErr)) + assert.Equal(t, 1, exitErr.Code) + + assert.Contains(t, out, "Validation failed:") + assert.Contains(t, out, "unrecognized action (policy: bad_policy)") + assert.Contains(t, out, "missing required field") +} + +func TestPolicySetValidate_APIError_ExitCode2(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "server error"}) + })) + defer srv.Close() + withTestABACClient(t, srv) + + cmd := newTestCommand(policySetValidateCmd) + _ = cmd.Flags().Set("file", writeTestFile(t, "yaml content")) + + err := cmd.RunE(cmd, nil) + require.Error(t, err) + var exitErr *ExitCodeError + require.True(t, errors.As(err, &exitErr)) + assert.Equal(t, 2, exitErr.Code) +} + +// --- Helpers --- + +// newTestCommand creates a copy of cmd with a background context set. +func newTestCommand(cmd *cobra.Command) *cobra.Command { + clone := *cmd + clone.SetContext(context.Background()) + return &clone +} + +// writeTestFile creates a temp file with the given content and returns its path. +func writeTestFile(t *testing.T, content string) string { + t.Helper() + f := filepath.Join(t.TempDir(), "policy.yaml") + require.NoError(t, os.WriteFile(f, []byte(content), 0o600)) + return f +} diff --git a/internal/cli/project.go b/internal/cli/project.go index 30ddc06..6cb0756 100644 --- a/internal/cli/project.go +++ b/internal/cli/project.go @@ -19,7 +19,7 @@ var projectGetCmd = &cobra.Command{ Use: "get", Short: "Get project information", RunE: func(cmd *cobra.Command, args []string) error { - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/root.go b/internal/cli/root.go index 85e7425..53cb6a7 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -9,6 +9,8 @@ import ( "github.com/spf13/viper" ) +const listCmdUse = "list" + var ( // Version information set by goreleaser. version = "dev" @@ -42,11 +44,13 @@ func init() { rootCmd.PersistentFlags().StringP("output", "o", "table", "Output format: table, json, yaml, wide") rootCmd.PersistentFlags().BoolP("quiet", "q", false, "Suppress non-essential output") rootCmd.PersistentFlags().BoolP("verbose", "v", false, "Enable verbose output") + rootCmd.PersistentFlags().String("project", "", "Project UUID for this command") _ = viper.BindPFlag("api-key", rootCmd.PersistentFlags().Lookup("api-key")) _ = viper.BindPFlag("api-url", rootCmd.PersistentFlags().Lookup("api-url")) _ = viper.BindPFlag("profile", rootCmd.PersistentFlags().Lookup("profile")) _ = viper.BindPFlag("output", rootCmd.PersistentFlags().Lookup("output")) + _ = viper.BindPFlag("project", rootCmd.PersistentFlags().Lookup("project")) } func initConfig() { diff --git a/internal/cli/set_project.go b/internal/cli/set_project.go new file mode 100644 index 0000000..a911689 --- /dev/null +++ b/internal/cli/set_project.go @@ -0,0 +1,141 @@ +package cli + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "strconv" + "strings" + + "github.com/getzep/zepctl/internal/client" + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/output" + "github.com/spf13/cobra" +) + +var configSetProjectCmd = &cobra.Command{ + Use: "set-project [uuid]", + Short: "Set the active project for the current profile", + Long: `Set the active project for the current profile. If a UUID is provided, +it is set directly. Otherwise, the accessible projects are fetched and +you are prompted to choose.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + profile := cfg.GetCurrentProfile() + if profile == nil { + return fmt.Errorf("no active profile") + } + + if len(args) == 1 { + // Direct UUID assignment. + profile.ProjectUUID = args[0] + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + output.Info("Active project set to %s", args[0]) + return nil + } + + // Interactive project selection. + httpClient, baseURL, err := client.NewBearerHTTPClient(cmd.Context()) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + + accountUUID, projects, err := authenticateAndGetProjects(cmd.Context(), httpClient, baseURL) + if err != nil { + return fmt.Errorf("fetching account and projects: %w", err) + } + profile.AccountUUID = accountUUID + // Persist the resolved account UUID immediately so it survives even + // if there are zero projects or the user aborts interactive + // selection. Mirrors the auth login path in auth.go. + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving account UUID: %w", err) + } + + if len(projects) == 0 { + return fmt.Errorf("no projects found for this account") + } + + var selected projectInfo + if len(projects) == 1 { + selected = projects[0] + output.Info("Auto-selected project %q (%s)", selected.Name, selected.UUID) + } else { + fmt.Println("Select a project:") + for i, p := range projects { + fmt.Printf(" %d. %s (%s)\n", i+1, p.Name, p.UUID) + } + fmt.Printf("Select a project [1-%d]: ", len(projects)) + + reader := bufio.NewReader(os.Stdin) + response, _ := reader.ReadString('\n') + response = strings.TrimSpace(response) + + idx, err := strconv.Atoi(response) + if err != nil || idx < 1 || idx > len(projects) { + return fmt.Errorf("invalid selection: %q", response) + } + selected = projects[idx-1] + } + + profile.ProjectUUID = selected.UUID + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + output.Info("Active project set to %s (%s)", selected.UUID, selected.Name) + return nil + }, +} + +type projectInfo struct { + UUID string `json:"uuid"` + Name string `json:"name"` +} + +// authenticateAndGetProjects calls POST /api/web/v1/authenticate to resolve +// the bearer-authenticated user's account UUID and accessible projects in a +// single round trip. The endpoint returns AccountMemberResponse, which carries +// both account_uuid and the projects array (src/api/apidata/account.go). +func authenticateAndGetProjects(ctx context.Context, httpClient *http.Client, baseURL string) (string, []projectInfo, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/web/v1/authenticate", http.NoBody) + if err != nil { + return "", nil, fmt.Errorf("building authenticate request: %w", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + return "", nil, fmt.Errorf("calling authenticate: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", nil, fmt.Errorf("authenticate returned %d", resp.StatusCode) + } + + var result struct { + AccountUUID string `json:"account_uuid"` + Projects []projectInfo `json:"projects"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", nil, fmt.Errorf("parsing authenticate response: %w", err) + } + if result.AccountUUID == "" { + return "", nil, fmt.Errorf("no account_uuid in authenticate response") + } + return result.AccountUUID, result.Projects, nil +} + +func init() { + configCmd.AddCommand(configSetProjectCmd) +} diff --git a/internal/cli/set_project_test.go b/internal/cli/set_project_test.go new file mode 100644 index 0000000..70249ad --- /dev/null +++ b/internal/cli/set_project_test.go @@ -0,0 +1,123 @@ +package cli + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/getzep/zepctl/internal/config" + "github.com/spf13/viper" +) + +// TestConfigSetProject_DirectUUID verifies that "config set-project " +// sets the project without interactive prompt. +func TestConfigSetProject_DirectUUID(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + viper.Reset() + t.Cleanup(func() { viper.Reset() }) + + // Write config file so Load() and Save() work. + writeTestConfig(t, tmpDir) + _, _ = config.Reload() + + cmd := configSetProjectCmd + cmd.SetContext(context.Background()) + + err := cmd.RunE(cmd, []string{"proj-uuid-456"}) + if err != nil { + t.Fatalf("set-project direct UUID: %v", err) + } + + // Reload and verify project was set. + cfg, err := config.Reload() + if err != nil { + t.Fatalf("Reload: %v", err) + } + profile := cfg.GetProfile("test") + if profile == nil { + t.Fatal("profile 'test' not found after set-project") + } + if profile.ProjectUUID != "proj-uuid-456" { + t.Errorf("ProjectUUID = %q, want %q", profile.ProjectUUID, "proj-uuid-456") + } +} + +func TestAuthenticateAndGetProjects(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/api/web/v1/authenticate" { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "account_uuid": "acc-uuid-123", + "projects": []projectInfo{ + {UUID: "proj-1", Name: "Project One"}, + {UUID: "proj-2", Name: "Project Two"}, + }, + }) + })) + defer srv.Close() + + accountUUID, projects, err := authenticateAndGetProjects(context.Background(), srv.Client(), srv.URL) + if err != nil { + t.Fatalf("authenticateAndGetProjects: %v", err) + } + if accountUUID != "acc-uuid-123" { + t.Errorf("account UUID = %q, want %q", accountUUID, "acc-uuid-123") + } + if len(projects) != 2 { + t.Fatalf("expected 2 projects, got %d", len(projects)) + } + if projects[0].UUID != "proj-1" || projects[0].Name != "Project One" { + t.Errorf("first project = %+v, want {UUID:proj-1, Name:Project One}", projects[0]) + } +} + +func TestAuthenticateAndGetProjects_Empty(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "account_uuid": "acc-uuid-123", + "projects": []projectInfo{}, + }) + })) + defer srv.Close() + + _, projects, err := authenticateAndGetProjects(context.Background(), srv.Client(), srv.URL) + if err != nil { + t.Fatalf("authenticateAndGetProjects: %v", err) + } + if len(projects) != 0 { + t.Errorf("expected 0 projects, got %d", len(projects)) + } +} + +func TestAuthenticateAndGetProjects_Error(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + _, _, err := authenticateAndGetProjects(context.Background(), srv.Client(), srv.URL) + if err == nil { + t.Error("expected error for 401 response") + } +} + +func TestAuthenticateAndGetProjects_MissingAccountUUID(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []projectInfo{}}) + })) + defer srv.Close() + + _, _, err := authenticateAndGetProjects(context.Background(), srv.Client(), srv.URL) + if err == nil { + t.Error("expected error when account_uuid is missing") + } +} diff --git a/internal/cli/summary_instructions.go b/internal/cli/summary_instructions.go index a1addc8..bf6f359 100644 --- a/internal/cli/summary_instructions.go +++ b/internal/cli/summary_instructions.go @@ -24,12 +24,12 @@ var summaryInstructionsCmd = &cobra.Command{ } var summaryInstructionsListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List summary instructions", RunE: func(cmd *cobra.Command, args []string) error { userID, _ := cmd.Flags().GetString("user") - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -92,7 +92,7 @@ var summaryInstructionsAddCmd = &cobra.Command{ return fmt.Errorf("instruction text exceeds maximum length of %d characters (got %d)", maxInstructionLength, len(instructionText)) } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -148,7 +148,7 @@ var summaryInstructionsDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/task.go b/internal/cli/task.go index 7d25122..bf28dc4 100644 --- a/internal/cli/task.go +++ b/internal/cli/task.go @@ -23,7 +23,7 @@ var taskGetCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { taskID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -74,7 +74,7 @@ var taskWaitCmd = &cobra.Command{ timeout, _ := cmd.Flags().GetDuration("timeout") pollInterval, _ := cmd.Flags().GetDuration("poll-interval") - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/thread.go b/internal/cli/thread.go index 6ebe499..6c5dab9 100644 --- a/internal/cli/thread.go +++ b/internal/cli/thread.go @@ -22,10 +22,10 @@ var threadCmd = &cobra.Command{ } var threadListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List all threads", RunE: func(cmd *cobra.Command, args []string) error { - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -88,7 +88,7 @@ var threadCreateCmd = &cobra.Command{ return fmt.Errorf("--user flag is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -116,7 +116,7 @@ var threadGetCmd = &cobra.Command{ threadID := args[0] lastN, _ := cmd.Flags().GetInt("last") - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -175,7 +175,7 @@ var threadDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -198,7 +198,7 @@ var threadMessagesCmd = &cobra.Command{ lastN, _ := cmd.Flags().GetInt("last") limit, _ := cmd.Flags().GetInt("limit") - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -287,7 +287,7 @@ var threadAddMessagesCmd = &cobra.Command{ return fmt.Errorf("parsing messages: %w", err) } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -348,7 +348,7 @@ var threadContextCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { threadID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/thread_summary.go b/internal/cli/thread_summary.go index 3f2e518..754868f 100644 --- a/internal/cli/thread_summary.go +++ b/internal/cli/thread_summary.go @@ -17,7 +17,7 @@ var threadSummaryCmd = &cobra.Command{ } var threadSummaryListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List thread summaries", RunE: func(cmd *cobra.Command, args []string) error { userID, _ := cmd.Flags().GetString("user") @@ -29,7 +29,7 @@ var threadSummaryListCmd = &cobra.Command{ return fmt.Errorf("either --user or --graph is required") } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/cli/user.go b/internal/cli/user.go index 4cb44dc..4599a4c 100644 --- a/internal/cli/user.go +++ b/internal/cli/user.go @@ -21,10 +21,10 @@ var userCmd = &cobra.Command{ } var userListCmd = &cobra.Command{ - Use: "list", + Use: listCmdUse, Short: "List users", RunE: func(cmd *cobra.Command, args []string) error { - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -94,7 +94,7 @@ var userGetCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { userID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -141,7 +141,7 @@ var userCreateCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { userID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -199,7 +199,7 @@ var userUpdateCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { userID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -268,7 +268,7 @@ var userDeleteCmd = &cobra.Command{ } } - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -289,7 +289,7 @@ var userThreadsCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { userID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } @@ -327,7 +327,7 @@ var userNodeCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { userID := args[0] - c, err := client.New() + c, err := client.NewForCommand(cmd) if err != nil { return err } diff --git a/internal/client/client.go b/internal/client/client.go index 129d72f..58443cb 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -1,31 +1,210 @@ package client import ( + "context" + "errors" "fmt" + "net/http" zepclient "github.com/getzep/zep-go/v3/client" "github.com/getzep/zep-go/v3/option" + "github.com/getzep/zepctl/internal/auth" "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/keyring" + "github.com/spf13/cobra" + "golang.org/x/oauth2" ) +// CredentialType indicates which credential a command requires. +type CredentialType int + +const ( + // CredentialAPIKey means the command uses API key authentication. + CredentialAPIKey CredentialType = iota + // CredentialBearer means the command uses bearer token authentication. + CredentialBearer +) + +// credentialTypeAnnotation is the Cobra annotation key for declaring a +// command's credential type at registration time. +const credentialTypeAnnotation = "zepctl_credential_type" //nolint:gosec // Annotation key, not a credential + +// projectHeader is the HTTP header used to specify the target project +// for bearer-authenticated requests. +const projectHeader = "X-Zep-Project" + // Client is an alias for the Zep client. type Client = zepclient.Client -// New creates a new Zep client using the current configuration. -func New() (*Client, error) { - apiKey := config.GetAPIKey() - if apiKey == "" { - return nil, fmt.Errorf("no API key configured; set ZEP_API_KEY or configure a profile") +// SetCredentialType declares the credential type a command requires. Call +// this at command registration time (typically in an init function). +func SetCredentialType(cmd *cobra.Command, ct CredentialType) { + if cmd.Annotations == nil { + cmd.Annotations = map[string]string{} + } + switch ct { + case CredentialBearer: + cmd.Annotations[credentialTypeAnnotation] = "bearer" + default: + cmd.Annotations[credentialTypeAnnotation] = "api-key" + } +} + +// credentialTypeFromCommand reads the credential type declared on cmd via +// SetCredentialType. Returns CredentialAPIKey if no annotation is set. +func credentialTypeFromCommand(cmd *cobra.Command) CredentialType { + if cmd.Annotations != nil { + if v, ok := cmd.Annotations[credentialTypeAnnotation]; ok && v == "bearer" { + return CredentialBearer + } } + return CredentialAPIKey +} - opts := []option.RequestOption{ - option.WithAPIKey(apiKey), +// NewForCommand creates a Zep client using the credential type declared on +// cmd. If the --api-key flag or ZEP_API_KEY env var is set, API key auth +// is used regardless of the command's declaration. +func NewForCommand(cmd *cobra.Command) (*Client, error) { + // Explicit API key override bypasses the command's declared type. + if config.GetAPIKeyOverride() != "" { + return NewWithCredential(cmd.Context(), CredentialAPIKey) + } + return NewWithCredential(cmd.Context(), credentialTypeFromCommand(cmd)) +} + +// NewWithCredential creates a new Zep client using the specified credential type. +func NewWithCredential(ctx context.Context, credType CredentialType) (*Client, error) { + var opts []option.RequestOption + + switch credType { + case CredentialBearer: + httpClient, err := newBearerClient(ctx) + if err != nil { + return nil, err + } + headers := http.Header{} + if projectUUID := config.GetProjectUUID(); projectUUID != "" { + headers.Set(projectHeader, projectUUID) + } + opts = append(opts, + option.WithHTTPClient(httpClient), + option.WithHTTPHeader(headers), + ) + default: + apiKey := config.GetAPIKey() + if apiKey == "" { + return nil, fmt.Errorf("no API key configured; set ZEP_API_KEY or run \"zepctl config add-profile\"") + } + opts = append(opts, option.WithAPIKey(apiKey)) } - // Only set base URL if explicitly configured; otherwise use SDK default if apiURL := config.GetAPIURL(); apiURL != "" { opts = append(opts, option.WithBaseURL(apiURL)) } return zepclient.NewClient(opts...), nil } + +// newBearerClient returns an *http.Client that automatically attaches +// bearer tokens and refreshes them when expired via golang.org/x/oauth2. +// The returned client's transport is wrapped to handle refresh failures: +// on a token retrieval error the bearer token fields are cleared from the +// keychain and a user-facing message is returned. +func newBearerClient(ctx context.Context) (*http.Client, error) { + cfg, err := config.Load() + if err != nil { + return nil, fmt.Errorf("loading config: %w", err) + } + + profile := cfg.GetCurrentProfile() + if profile == nil { + return nil, fmt.Errorf("no active profile; run \"zepctl auth login\" to authenticate") + } + + oauthCfg := auth.OAuthConfigFor(profile.OAuthIssuer, profile.OAuthClientID, profile.OAuthAudience) + session := auth.NewKeychainSession(profile.Name) + + httpClient, err := auth.NewAutoRefreshClient(ctx, oauthCfg, session) + if err != nil { + return nil, err + } + + // Wrap the transport to detect refresh failures and clear stale tokens. + httpClient.Transport = &refreshFailureTransport{ + base: httpClient.Transport, + profile: profile.Name, + } + + return httpClient, nil +} + +// NewBearerHTTPClient returns a raw *http.Client that attaches bearer token +// auth headers. Refreshes the token if expired. Used for web API calls +// (account resolution, project listing) that go through the web middleware +// rather than the SDK client. +func NewBearerHTTPClient(ctx context.Context) (*http.Client, string, error) { + httpClient, err := newBearerClient(ctx) + if err != nil { + return nil, "", err + } + + apiURL := config.GetAPIURL() + if apiURL == "" { + apiURL = auth.DefaultAPIURL + } + + return httpClient, apiURL, nil +} + +// BearerTransport is an http.RoundTripper that adds a Bearer Authorization header. +// Used by autoSelectProject during login when the SDK client is not yet available. +type BearerTransport struct { + Token string + Base http.RoundTripper +} + +func (t *BearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := req.Clone(req.Context()) + req2.Header.Set("Authorization", "Bearer "+t.Token) + return t.Base.RoundTrip(req2) +} + +// refreshFailureTransport wraps the oauth2 auto-refresh transport to +// detect token refresh failures. When a refresh fails (expired/revoked +// refresh token), it clears the bearer token fields from the keychain and +// returns a user-facing error message. +type refreshFailureTransport struct { + base http.RoundTripper + profile string +} + +func (t *refreshFailureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.base.RoundTrip(req) + if err == nil { + return resp, nil + } + + var retrieveErr *oauth2.RetrieveError + if !errors.As(err, &retrieveErr) { + return nil, err + } + + // If this is an invalid_grant error, another process may have already + // rotated the refresh token. Re-read from keychain and retry once with + // the (possibly updated) access token before giving up. The retry + // goes through http.DefaultTransport directly (not the oauth2 chain) + // to avoid recursive refresh attempts. If the retry also fails, the + // error is returned without the friendly "session expired" message -- + // acceptable for a CLI where this is an edge case. + if retrieveErr.ErrorCode == "invalid_grant" { + creds, kerr := keyring.GetCredentials(t.profile) + if kerr == nil && creds.HasBearerToken() && !creds.IsExpired() { + retry := req.Clone(req.Context()) + retry.Header.Set("Authorization", "Bearer "+creds.AccessToken) + return http.DefaultTransport.RoundTrip(retry) + } + } + + _ = auth.ClearBearerToken(t.profile) + return nil, fmt.Errorf("session expired; run \"zepctl auth login\" to re-authenticate") +} diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..cc7523e --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,364 @@ +package client + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/getzep/zepctl/internal/config" + "github.com/getzep/zepctl/internal/keyring" + "github.com/spf13/cobra" + "github.com/spf13/viper" + gokeyring "github.com/zalando/go-keyring" + "golang.org/x/oauth2" +) + +func init() { + // Use in-memory mock keyring for tests. + gokeyring.MockInit() +} + +func TestBearerTransport_SetsHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + })) + defer srv.Close() + + transport := &BearerTransport{ + Token: "test_bearer_token", + Base: http.DefaultTransport, + } + + httpClient := &http.Client{Transport: transport} + resp, err := httpClient.Get(srv.URL) //nolint:noctx // test-only, no context needed + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if gotAuth != "Bearer test_bearer_token" { + t.Errorf("Authorization header = %q, want %q", gotAuth, "Bearer test_bearer_token") + } +} + +func TestCredentialType_Constants(t *testing.T) { + if CredentialAPIKey == CredentialBearer { + t.Error("CredentialAPIKey and CredentialBearer should be different") + } +} + +func TestSetCredentialType_RoundTrip(t *testing.T) { + tests := []struct { + name string + set CredentialType + want CredentialType + }{ + {"api key", CredentialAPIKey, CredentialAPIKey}, + {"bearer", CredentialBearer, CredentialBearer}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + SetCredentialType(cmd, tt.set) + got := credentialTypeFromCommand(cmd) + if got != tt.want { + t.Errorf("credentialTypeFromCommand() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestCredentialTypeFromCommand_DefaultsToAPIKey(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + if got := credentialTypeFromCommand(cmd); got != CredentialAPIKey { + t.Errorf("credentialTypeFromCommand() = %d, want CredentialAPIKey", got) + } +} + +func TestRefreshFailureTransport_PassesThrough(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + transport := &refreshFailureTransport{ + base: http.DefaultTransport, + profile: "test-refresh-passthrough", + } + + httpClient := &http.Client{Transport: transport} + resp, err := httpClient.Get(srv.URL) //nolint:noctx // test-only + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } +} + +func TestRefreshFailureTransport_DetectsRetrieveError(t *testing.T) { + // Simulate the base transport returning an oauth2.RetrieveError, + // which happens when the Kinde SDK tries to refresh an expired token. + retrieveErr := &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: http.StatusUnauthorized}, + } + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return nil, retrieveErr + }) + + transport := &refreshFailureTransport{ + base: base, + profile: "test-refresh-detect", + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/test", http.NoBody) + resp, err := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Fatal("expected error, got nil") + } + want := `session expired; run "zepctl auth login" to re-authenticate` + if err.Error() != want { + t.Errorf("error = %q, want %q", err.Error(), want) + } +} + +func TestRefreshFailureTransport_NonRetrieveErrorPassesThrough(t *testing.T) { + // Non-oauth2 errors should pass through unmodified. + origErr := fmt.Errorf("network timeout") + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return nil, origErr + }) + + transport := &refreshFailureTransport{ + base: base, + profile: "test-refresh-other", + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/test", http.NoBody) + resp, err := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + if !errors.Is(err, origErr) { + t.Errorf("expected original error %v, got %v", origErr, err) + } +} + +func TestRefreshFailureTransport_InvalidGrantRetry(t *testing.T) { + // Simulate the concurrent refresh race: the base transport returns + // invalid_grant (another process already rotated the refresh token), + // but fresh credentials are available in the keychain. + profile := "test-invalid-grant-retry" + + // Pre-load fresh credentials into the mock keychain. + freshCreds := &keyring.Credentials{ + AccessToken: "fresh_access_token", + RefreshToken: "fresh_refresh_token", + ExpiresAt: time.Now().Add(time.Hour).Format(time.RFC3339), + } + if err := keyring.SetCredentials(profile, freshCreds); err != nil { + t.Fatalf("seeding keychain: %v", err) + } + + // Backend server that expects the fresh token. + var gotAuth string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + invalidGrant := &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: 400}, + ErrorCode: "invalid_grant", + } + var calls atomic.Int32 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + if calls.Add(1) == 1 { + // First call: simulate the SDK's failed refresh. + return nil, invalidGrant + } + // Should not reach here -- the retry bypasses this transport. + return nil, fmt.Errorf("unexpected second call to base transport") + }) + + transport := &refreshFailureTransport{ + base: base, + profile: profile, + } + + // The request must target the real backend so the retry succeeds. + req, _ := http.NewRequest(http.MethodGet, backend.URL+"/test", http.NoBody) //nolint:noctx // test-only + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want 200", resp.StatusCode) + } + if gotAuth != "Bearer fresh_access_token" { + t.Errorf("retry Authorization = %q, want %q", gotAuth, "Bearer fresh_access_token") + } +} + +func TestRefreshFailureTransport_InvalidGrantExpiredFallback(t *testing.T) { + // When invalid_grant fires but keychain creds are also expired, + // it should fall through to the clear-and-error path. + profile := "test-invalid-grant-expired" + + expiredCreds := &keyring.Credentials{ + AccessToken: "stale_token", + RefreshToken: "stale_refresh", + ExpiresAt: time.Now().Add(-time.Hour).Format(time.RFC3339), + } + if err := keyring.SetCredentials(profile, expiredCreds); err != nil { + t.Fatalf("seeding keychain: %v", err) + } + + invalidGrant := &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: 400}, + ErrorCode: "invalid_grant", + } + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return nil, invalidGrant + }) + + transport := &refreshFailureTransport{ + base: base, + profile: profile, + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/test", http.NoBody) + resp, err := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Fatal("expected error, got nil") + } + want := `session expired; run "zepctl auth login" to re-authenticate` + if err.Error() != want { + t.Errorf("error = %q, want %q", err.Error(), want) + } +} + +func TestNewWithCredential_APIKey_MissingKey(t *testing.T) { + viper.Reset() + defer viper.Reset() + + t.Setenv("HOME", t.TempDir()) + _, _ = config.Reload() + + _, err := NewWithCredential(context.Background(), CredentialAPIKey) + if err == nil { + t.Fatal("expected error for missing API key") + } + if !strings.Contains(err.Error(), "no API key configured") { + t.Errorf("error = %q, want to contain %q", err.Error(), "no API key configured") + } +} + +func TestNewForCommand_APIKeyOverrideForcesAPIKey(t *testing.T) { + viper.Reset() + defer viper.Reset() + + t.Setenv("HOME", t.TempDir()) + _, _ = config.Reload() + + viper.Set("api-key", "z_override_key") + + cmd := &cobra.Command{Use: "test"} + cmd.SetContext(context.Background()) + SetCredentialType(cmd, CredentialBearer) // command declares bearer + + // Should succeed using API key despite bearer declaration. + c, err := NewForCommand(cmd) + if err != nil { + t.Fatalf("NewForCommand with API key override: %v", err) + } + if c == nil { + t.Fatal("expected non-nil client") + } +} + +func TestRefreshFailureTransport_ClearsKeychainOnFailure(t *testing.T) { + profile := "test-refresh-clear-keychain" + + creds := &keyring.Credentials{ + APIKey: "z_preserved", + AccessToken: "old_access", + RefreshToken: "old_refresh", + ExpiresAt: time.Now().Add(time.Hour).Format(time.RFC3339), + } + if err := keyring.SetCredentials(profile, creds); err != nil { + t.Fatalf("seeding keychain: %v", err) + } + + // Non-invalid_grant RetrieveError triggers clear-and-error path. + retrieveErr := &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: http.StatusUnauthorized}, + } + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return nil, retrieveErr + }) + + transport := &refreshFailureTransport{ + base: base, + profile: profile, + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/test", http.NoBody) + resp, _ := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + + // Verify bearer token was cleared but API key preserved. + got, err := keyring.GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if got.HasBearerToken() { + t.Error("bearer token should be cleared after refresh failure") + } + if got.APIKey != "z_preserved" { + t.Errorf("API key should be preserved, got %q", got.APIKey) + } +} + +func TestNewWithCredential_Bearer_NoProfile(t *testing.T) { + viper.Reset() + defer viper.Reset() + + t.Setenv("HOME", t.TempDir()) + _, _ = config.Reload() + + _, err := NewWithCredential(context.Background(), CredentialBearer) + if err == nil { + t.Fatal("expected error for missing profile with bearer credential") + } + if !strings.Contains(err.Error(), "auth login") { + t.Errorf("error = %q, want actionable message mentioning %q", err.Error(), "auth login") + } +} + +// roundTripFunc adapts a function to http.RoundTripper. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/internal/config/config.go b/internal/config/config.go index aa119cf..10bb276 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,31 +4,58 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "github.com/getzep/zepctl/internal/keyring" + "github.com/getzep/zepctl/internal/output" "github.com/spf13/viper" "gopkg.in/yaml.v3" ) var ( + configMu sync.Mutex cachedConfig *Config - configOnce sync.Once configErr error ) // Profile represents a named configuration profile. -// API keys are stored in the system keychain, not in this config file. +// Credentials (API key and/or bearer token) are stored in the system keychain, +// not in this config file. +// +// OAuthIssuer and OAuthClientID are optional per-profile overrides for the +// OIDC issuer and OAuth client ID used by `auth login`. When empty, the +// binary's build-time defaults are used (see internal/auth/config.go). +// Setting these per profile lets a single binary authenticate against +// multiple OAuth tenants -- e.g. one profile for production and another +// for development. type Profile struct { - Name string `yaml:"name"` - APIURL string `yaml:"api-url,omitempty"` + Name string `yaml:"name"` + APIURL string `yaml:"api-url,omitempty"` + AccountUUID string `yaml:"account-uuid,omitempty"` + ProjectUUID string `yaml:"project-uuid,omitempty"` + OAuthIssuer string `yaml:"oauth-issuer,omitempty"` + OAuthClientID string `yaml:"oauth-client-id,omitempty"` + OAuthAudience string `yaml:"oauth-audience,omitempty"` +} + +// Environment is a named preset of auth fields (api-url plus the OAuth +// issuer/client-id/audience triple) stored in the user's config file. +// See `zepctl config add-environment`. +type Environment struct { + Name string `yaml:"name"` + APIURL string `yaml:"api-url,omitempty"` + OAuthIssuer string `yaml:"oauth-issuer,omitempty"` + OAuthClientID string `yaml:"oauth-client-id,omitempty"` + OAuthAudience string `yaml:"oauth-audience,omitempty"` } // Config represents the zepctl configuration. type Config struct { - CurrentProfile string `yaml:"current-profile"` - Profiles []Profile `yaml:"profiles"` - Defaults Defaults `yaml:"defaults"` + CurrentProfile string `yaml:"current-profile"` + Profiles []Profile `yaml:"profiles"` + Environments []Environment `yaml:"environments,omitempty"` + Defaults Defaults `yaml:"defaults"` } // Defaults represents default settings. @@ -49,9 +76,12 @@ func GetConfigPath() (string, error) { // Load loads the configuration from disk. // The config is cached after the first load for efficiency. func Load() (*Config, error) { - configOnce.Do(func() { - cachedConfig, configErr = loadFromDisk() - }) + configMu.Lock() + defer configMu.Unlock() + if cachedConfig != nil || configErr != nil { + return cachedConfig, configErr + } + cachedConfig, configErr = loadFromDisk() return cachedConfig, configErr } @@ -80,20 +110,106 @@ func loadFromDisk() (*Config, error) { return nil, fmt.Errorf("parsing config file: %w", err) } + if dropped := dedupeProfiles(&cfg); len(dropped) > 0 { + output.Warn("config contained duplicate profile names: %s. Keeping the last entry (by file position) for each duplicate name; earlier duplicates will be removed on next save.", strings.Join(dropped, ", ")) + } + + if dropped := dedupeEnvironments(&cfg); len(dropped) > 0 { + output.Warn("config contained duplicate environment names: %s. Keeping the last entry (by file position) for each duplicate name; earlier duplicates will be removed on next save.", strings.Join(dropped, ", ")) + } + return &cfg, nil } +// dedupeByName collapses entries that share a name. When duplicates are +// found, the LAST occurrence by position wins (later writes typically +// reflect the most recent intent). Returns the unique set of duplicate +// names that were collapsed; order of remaining entries is preserved. +func dedupeByName[T any](items *[]T, name func(T) string) []string { + if len(*items) < 2 { + return nil + } + + lastIndex := make(map[string]int, len(*items)) + for i, v := range *items { + lastIndex[name(v)] = i + } + + kept := make([]T, 0, len(*items)) + var dupes []string + reported := make(map[string]bool, len(*items)) + for i, v := range *items { + n := name(v) + if lastIndex[n] != i { + if !reported[n] { + dupes = append(dupes, n) + reported[n] = true + } + continue + } + kept = append(kept, v) + } + + if len(dupes) == 0 { + return nil + } + *items = kept + return dupes +} + +func dedupeProfiles(cfg *Config) []string { + return dedupeByName(&cfg.Profiles, func(p Profile) string { return p.Name }) +} + +func dedupeEnvironments(cfg *Config) []string { + return dedupeByName(&cfg.Environments, func(e Environment) string { return e.Name }) +} + +// validateUniqueNames errors if any two items share a name. Defense-in-depth +// for Save: Load already dedupes on parse, but in-memory mutations could in +// theory introduce a duplicate before Save is called. +func validateUniqueNames[T any](items []T, name func(T) string, kind string) error { + seen := make(map[string]bool, len(items)) + for _, v := range items { + n := name(v) + if seen[n] { + return fmt.Errorf("duplicate %s name %q in config", kind, n) + } + seen[n] = true + } + return nil +} + +// findByName returns a pointer into items whose name matches target, or nil. +// The pointer is into the underlying slice so callers can mutate in place. +func findByName[T any](items []T, target string, name func(T) string) *T { + for i := range items { + if name(items[i]) == target { + return &items[i] + } + } + return nil +} + // Reload forces a reload of the configuration from disk. // This is useful after modifying the config file (e.g., adding a profile). func Reload() (*Config, error) { - configOnce = sync.Once{} + configMu.Lock() cachedConfig = nil configErr = nil + configMu.Unlock() return Load() } // Save writes the configuration to disk and updates the cache. func (c *Config) Save() error { + if err := validateUniqueNames(c.Profiles, func(p Profile) string { return p.Name }, "profile"); err != nil { + return err + } + if err := validateUniqueNames(c.Environments, func(e Environment) string { return e.Name }, "environment"); err != nil { + return err + } + path, err := GetConfigPath() if err != nil { return err @@ -113,19 +229,23 @@ func (c *Config) Save() error { return fmt.Errorf("writing config file: %w", err) } - // Update cache to reflect saved changes + // Update cache to reflect saved changes. + configMu.Lock() cachedConfig = c + configErr = nil + configMu.Unlock() return nil } // GetProfile returns the profile with the given name. func (c *Config) GetProfile(name string) *Profile { - for i := range c.Profiles { - if c.Profiles[i].Name == name { - return &c.Profiles[i] - } - } - return nil + return findByName(c.Profiles, name, func(p Profile) string { return p.Name }) +} + +// GetEnvironment returns the environment preset with the given name, or nil +// if not configured. +func (c *Config) GetEnvironment(name string) *Environment { + return findByName(c.Environments, name, func(e Environment) string { return e.Name }) } // GetCurrentProfile returns the current active profile. @@ -137,6 +257,14 @@ func (c *Config) GetCurrentProfile() *Profile { return c.GetProfile(c.CurrentProfile) } +// GetAPIKeyOverride returns the API key if explicitly set via the --api-key +// flag or ZEP_API_KEY environment variable. Returns empty string if neither +// is set (does not check the profile keychain). Used to detect explicit +// overrides. +func GetAPIKeyOverride() string { + return viper.GetString("api-key") +} + // GetAPIKey returns the API key to use, checking flags, env, and profile keychain. func GetAPIKey() string { // Flag/env takes precedence @@ -151,9 +279,30 @@ func GetAPIKey() string { } if profile := cfg.GetCurrentProfile(); profile != nil { - if key, err := keyring.Get(profile.Name); err == nil && key != "" { - return key + creds, err := keyring.GetCredentials(profile.Name) + if err != nil { + return "" } + return creds.APIKey + } + + return "" +} + +// GetProjectUUID returns the project UUID to use, checking flags, env, and profile. +func GetProjectUUID() string { + // Flag/env takes precedence + if p := viper.GetString("project"); p != "" { + return p + } + + cfg, err := Load() + if err != nil { + return "" + } + + if profile := cfg.GetCurrentProfile(); profile != nil { + return profile.ProjectUUID } return "" diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..e0e840e --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,543 @@ +package config + +import ( + "bytes" + "io" + "os" + "path/filepath" + "testing" + + "github.com/getzep/zepctl/internal/keyring" + "github.com/spf13/viper" + gokeyring "github.com/zalando/go-keyring" +) + +func init() { + gokeyring.MockInit() +} + +// resetConfig clears the cached config and viper state for test isolation. +func resetConfig(t *testing.T) { + t.Helper() + configMu.Lock() + cachedConfig = nil + configErr = nil + configMu.Unlock() + viper.Reset() +} + +// setTestConfig injects a config into the cache so Load() returns it +// without reading from disk. +func setTestConfig(t *testing.T, cfg *Config) { + t.Helper() + configMu.Lock() + cachedConfig = cfg + configErr = nil + configMu.Unlock() +} + +func TestGetProjectUUID_FlagTakesPrecedence(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + setTestConfig(t, &Config{ + CurrentProfile: "default", + Profiles: []Profile{ + {Name: "default", ProjectUUID: "profile-uuid"}, + }, + }) + + viper.Set("project", "flag-uuid") + t.Setenv("ZEP_PROJECT", "env-uuid") + + got := GetProjectUUID() + if got != "flag-uuid" { + t.Errorf("GetProjectUUID() = %q, want %q (flag should take precedence)", got, "flag-uuid") + } +} + +func TestGetProjectUUID_EnvOverridesProfile(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + setTestConfig(t, &Config{ + CurrentProfile: "default", + Profiles: []Profile{ + {Name: "default", ProjectUUID: "profile-uuid"}, + }, + }) + + // Viper reads ZEP_PROJECT env var when bound. Since we're testing + // without flag binding, simulate env by setting the viper key to empty + // (no flag) and using the env var via AutomaticEnv. + viper.SetEnvPrefix("ZEP") + viper.AutomaticEnv() + t.Setenv("ZEP_PROJECT", "env-uuid") + + got := GetProjectUUID() + if got != "env-uuid" { + t.Errorf("GetProjectUUID() = %q, want %q (env should override profile)", got, "env-uuid") + } +} + +func TestGetProjectUUID_FallsBackToProfile(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + setTestConfig(t, &Config{ + CurrentProfile: "default", + Profiles: []Profile{ + {Name: "default", ProjectUUID: "profile-uuid"}, + }, + }) + + got := GetProjectUUID() + if got != "profile-uuid" { + t.Errorf("GetProjectUUID() = %q, want %q", got, "profile-uuid") + } +} + +func TestGetProjectUUID_EmptyWhenNothingSet(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + setTestConfig(t, &Config{ + CurrentProfile: "default", + Profiles: []Profile{ + {Name: "default"}, + }, + }) + + got := GetProjectUUID() + if got != "" { + t.Errorf("GetProjectUUID() = %q, want empty", got) + } +} + +func TestLoad_ReturnsDefaultsWhenNoFile(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + // Point HOME to a temp dir with no config file. + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.Defaults.Output != "table" { + t.Errorf("default output = %q, want %q", cfg.Defaults.Output, "table") + } + if cfg.Defaults.PageSize != 50 { + t.Errorf("default page size = %d, want 50", cfg.Defaults.PageSize) + } +} + +func TestLoad_ParsesConfigFile(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + + configDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatal(err) + } + + configContent := `current-profile: myprofile +profiles: + - name: myprofile + api-url: https://api.example.com + project-uuid: abc-123 +` + if err := os.WriteFile(filepath.Join(configDir, "config.yaml"), []byte(configContent), 0o600); err != nil { + t.Fatal(err) + } + + cfg, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.CurrentProfile != "myprofile" { + t.Errorf("CurrentProfile = %q, want %q", cfg.CurrentProfile, "myprofile") + } + p := cfg.GetProfile("myprofile") + if p == nil { + t.Fatal("profile 'myprofile' not found") + } + if p.ProjectUUID != "abc-123" { + t.Errorf("ProjectUUID = %q, want %q", p.ProjectUUID, "abc-123") + } +} + +func TestGetAPIKey_FlagTakesPrecedence(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + setTestConfig(t, &Config{ + CurrentProfile: "default", + Profiles: []Profile{{Name: "default"}}, + }) + // Seed the profile's keychain with an API key. + if err := keyring.Set("default", "z_profile_key"); err != nil { + t.Fatalf("keyring.Set: %v", err) + } + + viper.Set("api-key", "z_flag_key") + + got := GetAPIKey() + if got != "z_flag_key" { + t.Errorf("GetAPIKey() = %q, want %q (flag should take precedence)", got, "z_flag_key") + } +} + +func TestGetAPIKey_FallsBackToProfile(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + setTestConfig(t, &Config{ + CurrentProfile: "default", + Profiles: []Profile{{Name: "default"}}, + }) + if err := keyring.Set("default", "z_profile_key"); err != nil { + t.Fatalf("keyring.Set: %v", err) + } + + got := GetAPIKey() + if got != "z_profile_key" { + t.Errorf("GetAPIKey() = %q, want %q", got, "z_profile_key") + } +} + +func TestGetAPIKey_EmptyWhenNothingSet(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + // Use a profile name that has no keychain entry. + setTestConfig(t, &Config{ + CurrentProfile: "empty-profile", + Profiles: []Profile{{Name: "empty-profile"}}, + }) + + got := GetAPIKey() + if got != "" { + t.Errorf("GetAPIKey() = %q, want empty", got) + } +} + +func TestGetAPIKeyOverride_ReturnsOverride(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + viper.Set("api-key", "z_explicit_override") + + got := GetAPIKeyOverride() + if got != "z_explicit_override" { + t.Errorf("GetAPIKeyOverride() = %q, want %q", got, "z_explicit_override") + } + + // Empty when not set. + viper.Reset() + got = GetAPIKeyOverride() + if got != "" { + t.Errorf("GetAPIKeyOverride() = %q, want empty", got) + } +} + +func TestDedupeProfiles_KeepsLastOccurrence(t *testing.T) { + cfg := &Config{ + Profiles: []Profile{ + {Name: "default", APIURL: "https://api.example.com"}, + {Name: "other"}, + {Name: "default", APIURL: "https://api.getzep.com"}, + }, + } + dupes := dedupeProfiles(cfg) + if len(dupes) != 1 || dupes[0] != "default" { + t.Errorf("dupes = %v, want [default]", dupes) + } + if len(cfg.Profiles) != 2 { + t.Fatalf("profiles = %d, want 2", len(cfg.Profiles)) + } + // Order is preserved: "other" stays at its original index, and the kept + // "default" copy is the LAST one (the prod URL). + if cfg.Profiles[0].Name != "other" { + t.Errorf("Profiles[0].Name = %q, want %q", cfg.Profiles[0].Name, "other") + } + if cfg.Profiles[1].Name != "default" || cfg.Profiles[1].APIURL != "https://api.getzep.com" { + t.Errorf("Profiles[1] = %+v, want last default with prod URL", cfg.Profiles[1]) + } +} + +// TestDedupeProfiles_ReportsEachNameOnce verifies the unique-names guarantee +// of the returned slice: a name appearing N>2 times is still reported once +// so the user-facing warning doesn't repeat the same name multiple times. +func TestDedupeProfiles_ReportsEachNameOnce(t *testing.T) { + cfg := &Config{ + Profiles: []Profile{ + {Name: "default", APIURL: "a"}, + {Name: "default", APIURL: "b"}, + {Name: "default", APIURL: "c"}, + }, + } + dupes := dedupeProfiles(cfg) + if len(dupes) != 1 || dupes[0] != "default" { + t.Errorf("dupes = %v, want exactly one [default] entry", dupes) + } + if len(cfg.Profiles) != 1 || cfg.Profiles[0].APIURL != "c" { + t.Errorf("expected single profile with last APIURL, got %+v", cfg.Profiles) + } +} + +func TestDedupeProfiles_NoDuplicates(t *testing.T) { + cfg := &Config{ + Profiles: []Profile{ + {Name: "a"}, + {Name: "b"}, + }, + } + if dupes := dedupeProfiles(cfg); dupes != nil { + t.Errorf("dupes = %v, want nil", dupes) + } + if len(cfg.Profiles) != 2 { + t.Errorf("profiles mutated when no dupes present") + } +} + +// TestLoad_DedupesAndWarnsOnDuplicates writes a YAML config with two +// "default" profiles to disk, calls Load, and verifies (a) the cached +// config has the duplicates collapsed, keeping the last by file position, +// and (b) a warning naming the duplicate was emitted to stderr. +func TestLoad_DedupesAndWarnsOnDuplicates(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + cfgDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(cfgDir, 0o700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + yaml := `current-profile: default +profiles: + - name: default + api-url: https://api.example.com + - name: keep + api-url: https://other.example.com + - name: default + api-url: https://api.getzep.com +defaults: + output: table + page-size: 50 +` + if err := os.WriteFile(filepath.Join(cfgDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + // Capture stderr for the duration of Load() so we can assert on the warning. + origStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + os.Stderr = w + + cfg, loadErr := Load() + + _ = w.Close() + os.Stderr = origStderr + var stderr bytes.Buffer + _, _ = io.Copy(&stderr, r) + + if loadErr != nil { + t.Fatalf("Load: %v", loadErr) + } + if len(cfg.Profiles) != 2 { + t.Fatalf("Profiles after Load: got %d, want 2 (deduped)", len(cfg.Profiles)) + } + // "keep" stays in its original position (index 1 -> 0 after dedupe); + // "default" survives at the position of its last occurrence. + if cfg.Profiles[0].Name != "keep" { + t.Errorf("Profiles[0].Name = %q, want %q", cfg.Profiles[0].Name, "keep") + } + if cfg.Profiles[1].Name != "default" || cfg.Profiles[1].APIURL != "https://api.getzep.com" { + t.Errorf("Profiles[1] = %+v, want last default with prod URL", cfg.Profiles[1]) + } + if !bytes.Contains(stderr.Bytes(), []byte("default")) { + t.Errorf("stderr did not mention duplicate name 'default'; got: %q", stderr.String()) + } + if !bytes.Contains(stderr.Bytes(), []byte("Warning")) { + t.Errorf("stderr did not include a Warning prefix; got: %q", stderr.String()) + } +} + +func TestGetEnvironment_FoundAndNotFound(t *testing.T) { + cfg := &Config{ + Environments: []Environment{ + {Name: "development", APIURL: "https://dev.example.com"}, + {Name: "local", APIURL: "http://localhost:8000"}, + }, + } + + if env := cfg.GetEnvironment("development"); env == nil || env.APIURL != "https://dev.example.com" { + t.Errorf("GetEnvironment(development) = %+v, want development env", env) + } + if env := cfg.GetEnvironment("missing"); env != nil { + t.Errorf("GetEnvironment(missing) = %+v, want nil", env) + } +} + +func TestGetEnvironment_ReturnsPointerToConfigSlice(t *testing.T) { + // Mutating the returned pointer must update the underlying config so + // callers can do `env := cfg.GetEnvironment(...); env.APIURL = "..."` and + // have Save persist the change. This mirrors GetProfile's contract. + cfg := &Config{ + Environments: []Environment{{Name: "development", APIURL: "old"}}, + } + env := cfg.GetEnvironment("development") + env.APIURL = "new" + if cfg.Environments[0].APIURL != "new" { + t.Errorf("Environments[0].APIURL = %q, want %q (mutation must be visible)", + cfg.Environments[0].APIURL, "new") + } +} + +func TestDedupeEnvironments_KeepsLastOccurrence(t *testing.T) { + cfg := &Config{ + Environments: []Environment{ + {Name: "development", APIURL: "first"}, + {Name: "local"}, + {Name: "development", APIURL: "last"}, + }, + } + dupes := dedupeEnvironments(cfg) + if len(dupes) != 1 || dupes[0] != "development" { + t.Errorf("dupes = %v, want [development]", dupes) + } + if len(cfg.Environments) != 2 { + t.Fatalf("Environments = %d, want 2", len(cfg.Environments)) + } + if cfg.Environments[0].Name != "local" { + t.Errorf("Environments[0].Name = %q, want %q", cfg.Environments[0].Name, "local") + } + if cfg.Environments[1].Name != "development" || cfg.Environments[1].APIURL != "last" { + t.Errorf("Environments[1] = %+v, want last development entry", cfg.Environments[1]) + } +} + +func TestDedupeEnvironments_NoDuplicates(t *testing.T) { + cfg := &Config{ + Environments: []Environment{ + {Name: "development"}, + {Name: "local"}, + }, + } + if dupes := dedupeEnvironments(cfg); dupes != nil { + t.Errorf("dupes = %v, want nil", dupes) + } + if len(cfg.Environments) != 2 { + t.Errorf("environments mutated when no dupes present") + } +} + +func TestSave_RejectsDuplicateEnvironments(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + cfg := &Config{ + Environments: []Environment{ + {Name: "development"}, + {Name: "development"}, + }, + } + if err := cfg.Save(); err == nil { + t.Fatal("expected error from Save with duplicate environment names") + } +} + +func TestLoad_ParsesEnvironmentsAndDedupes(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + cfgDir := filepath.Join(tmpDir, ".zepctl") + if err := os.MkdirAll(cfgDir, 0o700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + yaml := `current-profile: dev +profiles: + - name: dev + api-url: https://api.development.example.com +environments: + - name: development + api-url: https://first.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: first-client-id + - name: local + api-url: http://localhost:8000 + - name: development + api-url: https://second.example.com + oauth-issuer: https://issuer.example.com + oauth-client-id: second-client-id +` + if err := os.WriteFile(filepath.Join(cfgDir, "config.yaml"), []byte(yaml), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + // Discard stderr so the dedupe warning doesn't pollute test output. + origStderr := os.Stderr + devnull, err := os.Open(os.DevNull) + if err != nil { + t.Fatalf("open /dev/null: %v", err) + } + os.Stderr = devnull + defer func() { + os.Stderr = origStderr + _ = devnull.Close() + }() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + + if len(cfg.Environments) != 2 { + t.Fatalf("Environments after dedupe = %d, want 2", len(cfg.Environments)) + } + dev := cfg.GetEnvironment("development") + if dev == nil { + t.Fatal("development environment missing after dedupe") + } + if dev.APIURL != "https://second.example.com" { + t.Errorf("development.APIURL = %q, want last entry %q", + dev.APIURL, "https://second.example.com") + } + if dev.OAuthClientID != "second-client-id" { + t.Errorf("development.OAuthClientID = %q, want %q", + dev.OAuthClientID, "second-client-id") + } + if local := cfg.GetEnvironment("local"); local == nil || local.APIURL != "http://localhost:8000" { + t.Errorf("local environment = %+v, want APIURL=http://localhost:8000", local) + } +} + +func TestSave_RejectsDuplicateProfiles(t *testing.T) { + resetConfig(t) + defer resetConfig(t) + + cfg := &Config{ + Profiles: []Profile{ + {Name: "default"}, + {Name: "default"}, + }, + } + // validateNoDuplicateProfiles errors before any disk I/O, so HOME setup + // is unnecessary here. + if err := cfg.Save(); err == nil { + t.Fatal("expected error from Save with duplicate profile names") + } +} diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go index 4f5fdec..3ea34a1 100644 --- a/internal/keyring/keyring.go +++ b/internal/keyring/keyring.go @@ -1,8 +1,10 @@ package keyring import ( + "encoding/json" "errors" "fmt" + "time" "github.com/zalando/go-keyring" ) @@ -11,33 +13,129 @@ const ( serviceName = "zepctl" ) +// Credentials holds all authentication credentials for a profile. +// All fields are optional -- a profile may have just an API key, +// just a bearer token, or both. +type Credentials struct { + APIKey string `json:"api_key,omitempty"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` + UserEmail string `json:"user_email,omitempty"` +} + +// IsExpired returns true if the access token has expired. +func (c *Credentials) IsExpired() bool { + if c.ExpiresAt == "" { + return true + } + t, err := time.Parse(time.RFC3339, c.ExpiresAt) + if err != nil { + return true + } + return time.Now().After(t) +} + +// ExpiresIn returns the duration until the access token expires. +// Returns 0 if already expired or unparseable. +func (c *Credentials) ExpiresIn() time.Duration { + if c.ExpiresAt == "" { + return 0 + } + t, err := time.Parse(time.RFC3339, c.ExpiresAt) + if err != nil { + return 0 + } + d := time.Until(t) + if d < 0 { + return 0 + } + return d +} + +// HasBearerToken returns true if bearer token fields are populated. +func (c *Credentials) HasBearerToken() bool { + return c.AccessToken != "" +} + +// HasAPIKey returns true if an API key is configured. +func (c *Credentials) HasAPIKey() bool { + return c.APIKey != "" +} + +// GetCredentials retrieves structured credentials for a profile. +// Transparently migrates legacy raw API key strings to JSON format. +func GetCredentials(profile string) (*Credentials, error) { + raw, err := keyring.Get(serviceName, profile) + if err != nil { + if errors.Is(err, keyring.ErrNotFound) { + return &Credentials{}, nil + } + return nil, fmt.Errorf("retrieving credentials from keychain: %w", err) + } + + if raw == "" { + return &Credentials{}, nil + } + + // Try parsing as JSON (new format) + var creds Credentials + if err := json.Unmarshal([]byte(raw), &creds); err == nil { + return &creds, nil + } + + // JSON parse failed -- treat as legacy raw API key string. + // Migrate in-place to JSON format. + creds = Credentials{APIKey: raw} + if err := SetCredentials(profile, &creds); err != nil { + // Migration failed, but we still have the credentials in memory. + // Return them without persisting the upgrade. + return &creds, nil + } + return &creds, nil +} + +// SetCredentials stores structured credentials for a profile. +func SetCredentials(profile string, creds *Credentials) error { + data, err := json.Marshal(creds) //nolint:gosec // G117: intentionally marshaling credentials for keychain storage + if err != nil { + return fmt.Errorf("marshaling credentials: %w", err) + } + if err := keyring.Set(serviceName, profile, string(data)); err != nil { + return fmt.Errorf("storing credentials in keychain: %w", err) + } + return nil +} + // Set stores an API key for a profile in the system keychain. +// Uses the new JSON format. func Set(profile, apiKey string) error { - if err := keyring.Set(serviceName, profile, apiKey); err != nil { - return fmt.Errorf("storing API key in keychain: %w", err) + creds, err := GetCredentials(profile) + if err != nil { + // If we can't read existing creds, start fresh + creds = &Credentials{} } - return nil + creds.APIKey = apiKey + return SetCredentials(profile, creds) } // Get retrieves an API key for a profile from the system keychain. +// Handles both legacy raw strings and JSON format. func Get(profile string) (string, error) { - apiKey, err := keyring.Get(serviceName, profile) + creds, err := GetCredentials(profile) if err != nil { - if errors.Is(err, keyring.ErrNotFound) { - return "", nil - } - return "", fmt.Errorf("retrieving API key from keychain: %w", err) + return "", err } - return apiKey, nil + return creds.APIKey, nil } -// Delete removes an API key for a profile from the system keychain. +// Delete removes all credentials for a profile from the system keychain. func Delete(profile string) error { if err := keyring.Delete(serviceName, profile); err != nil { if errors.Is(err, keyring.ErrNotFound) { return nil } - return fmt.Errorf("deleting API key from keychain: %w", err) + return fmt.Errorf("deleting credentials from keychain: %w", err) } return nil } diff --git a/internal/keyring/keyring_test.go b/internal/keyring/keyring_test.go new file mode 100644 index 0000000..1698055 --- /dev/null +++ b/internal/keyring/keyring_test.go @@ -0,0 +1,290 @@ +package keyring + +import ( + "encoding/json" + "testing" + "time" + + gokeyring "github.com/zalando/go-keyring" +) + +func init() { + // Use in-memory mock keyring for tests. + gokeyring.MockInit() +} + +func TestCredentials_IsExpired(t *testing.T) { + tests := []struct { + name string + expiresAt string + want bool + }{ + {"empty expiry", "", true}, + {"invalid format", "not-a-date", true}, + {"past time", time.Now().Add(-1 * time.Hour).Format(time.RFC3339), true}, + {"future time", time.Now().Add(1 * time.Hour).Format(time.RFC3339), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Credentials{ExpiresAt: tt.expiresAt} + if got := c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCredentials_ExpiresIn(t *testing.T) { + t.Run("empty expiry returns zero", func(t *testing.T) { + c := &Credentials{} + if d := c.ExpiresIn(); d != 0 { + t.Errorf("ExpiresIn() = %v, want 0", d) + } + }) + + t.Run("past expiry returns zero", func(t *testing.T) { + c := &Credentials{ExpiresAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339)} + if d := c.ExpiresIn(); d != 0 { + t.Errorf("ExpiresIn() = %v, want 0", d) + } + }) + + t.Run("future expiry returns positive duration", func(t *testing.T) { + c := &Credentials{ExpiresAt: time.Now().Add(30 * time.Minute).Format(time.RFC3339)} + d := c.ExpiresIn() + if d <= 29*time.Minute || d > 30*time.Minute { + t.Errorf("ExpiresIn() = %v, want ~30m", d) + } + }) +} + +func TestCredentials_HasBearerToken(t *testing.T) { + c := &Credentials{} + if c.HasBearerToken() { + t.Error("empty credentials should not have bearer token") + } + c.AccessToken = "tok" + if !c.HasBearerToken() { + t.Error("should have bearer token when access_token is set") + } +} + +func TestCredentials_HasAPIKey(t *testing.T) { + c := &Credentials{} + if c.HasAPIKey() { + t.Error("empty credentials should not have API key") + } + c.APIKey = "z_key" + if !c.HasAPIKey() { + t.Error("should have API key when api_key is set") + } +} + +func TestSetAndGetCredentials(t *testing.T) { + profile := "test-json-profile" + + creds := &Credentials{ + APIKey: "z_test123", + AccessToken: "eyJhbGci.test", + RefreshToken: "refresh_abc", + ExpiresAt: "2026-04-20T15:30:00Z", + UserEmail: "fred@frobozz.infocom", + } + + if err := SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + got, err := GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.APIKey != creds.APIKey { + t.Errorf("APIKey = %q, want %q", got.APIKey, creds.APIKey) + } + if got.AccessToken != creds.AccessToken { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, creds.AccessToken) + } + if got.RefreshToken != creds.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, creds.RefreshToken) + } + if got.ExpiresAt != creds.ExpiresAt { + t.Errorf("ExpiresAt = %q, want %q", got.ExpiresAt, creds.ExpiresAt) + } + if got.UserEmail != creds.UserEmail { + t.Errorf("UserEmail = %q, want %q", got.UserEmail, creds.UserEmail) + } +} + +func TestGetCredentials_OnlyAPIKey(t *testing.T) { + profile := "test-apikey-only" + + creds := &Credentials{APIKey: "z_onlykey"} + if err := SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + got, err := GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.APIKey != "z_onlykey" { + t.Errorf("APIKey = %q, want %q", got.APIKey, "z_onlykey") + } + if got.HasBearerToken() { + t.Error("should not have bearer token") + } +} + +func TestGetCredentials_OnlyBearerToken(t *testing.T) { + profile := "test-bearer-only" + + creds := &Credentials{ + AccessToken: "bearer_tok", + RefreshToken: "refresh_tok", + ExpiresAt: "2026-04-20T15:30:00Z", + UserEmail: "user@example.com", + } + if err := SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + got, err := GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.HasAPIKey() { + t.Error("should not have API key") + } + if !got.HasBearerToken() { + t.Error("should have bearer token") + } + if got.UserEmail != "user@example.com" { + t.Errorf("UserEmail = %q, want %q", got.UserEmail, "user@example.com") + } +} + +func TestGetCredentials_LegacyMigration(t *testing.T) { + profile := "test-legacy" + + // Store a raw API key string (legacy format) + if err := gokeyring.Set(serviceName, profile, "z_legacy_key_123"); err != nil { + t.Fatalf("setting raw keyring value: %v", err) + } + + // GetCredentials should parse it as legacy and migrate + got, err := GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.APIKey != "z_legacy_key_123" { + t.Errorf("APIKey = %q, want %q", got.APIKey, "z_legacy_key_123") + } + + // Verify the migration happened: re-read raw value, should be JSON now + raw, err := gokeyring.Get(serviceName, profile) + if err != nil { + t.Fatalf("reading raw keyring after migration: %v", err) + } + + var migrated Credentials + if err := json.Unmarshal([]byte(raw), &migrated); err != nil { + t.Fatalf("migrated value is not valid JSON: %v (raw: %q)", err, raw) + } + if migrated.APIKey != "z_legacy_key_123" { + t.Errorf("migrated APIKey = %q, want %q", migrated.APIKey, "z_legacy_key_123") + } +} + +func TestGetCredentials_NotFound(t *testing.T) { + got, err := GetCredentials("nonexistent-profile") + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + if got.HasAPIKey() || got.HasBearerToken() { + t.Error("expected empty credentials for nonexistent profile") + } +} + +func TestSet_PreservesBearerToken(t *testing.T) { + profile := "test-set-preserve" + + // Start with bearer token only + creds := &Credentials{ + AccessToken: "bearer_tok", + RefreshToken: "refresh_tok", + ExpiresAt: "2026-04-20T15:30:00Z", + UserEmail: "user@example.com", + } + if err := SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + // Use Set() to add an API key -- should preserve bearer token + if err := Set(profile, "z_new_key"); err != nil { + t.Fatalf("Set: %v", err) + } + + got, err := GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials: %v", err) + } + + if got.APIKey != "z_new_key" { + t.Errorf("APIKey = %q, want %q", got.APIKey, "z_new_key") + } + if got.AccessToken != "bearer_tok" { + t.Errorf("AccessToken = %q, want %q (should be preserved)", got.AccessToken, "bearer_tok") + } +} + +func TestGet_BackwardsCompatible(t *testing.T) { + profile := "test-get-compat" + + creds := &Credentials{APIKey: "z_compat_key"} + if err := SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + key, err := Get(profile) + if err != nil { + t.Fatalf("Get: %v", err) + } + if key != "z_compat_key" { + t.Errorf("Get() = %q, want %q", key, "z_compat_key") + } +} + +func TestDelete(t *testing.T) { + profile := "test-delete" + + creds := &Credentials{ + APIKey: "z_del_key", + AccessToken: "del_tok", + } + if err := SetCredentials(profile, creds); err != nil { + t.Fatalf("SetCredentials: %v", err) + } + + if err := Delete(profile); err != nil { + t.Fatalf("Delete: %v", err) + } + + got, err := GetCredentials(profile) + if err != nil { + t.Fatalf("GetCredentials after delete: %v", err) + } + if got.HasAPIKey() || got.HasBearerToken() { + t.Error("expected empty credentials after delete") + } +} + +func TestDelete_NotFound(t *testing.T) { + if err := Delete("nonexistent-for-delete"); err != nil { + t.Errorf("Delete of nonexistent profile should not error: %v", err) + } +} diff --git a/internal/output/output.go b/internal/output/output.go index f7729d2..9d0ba18 100644 --- a/internal/output/output.go +++ b/internal/output/output.go @@ -1,6 +1,7 @@ package output import ( + "bytes" "encoding/json" "fmt" "io" @@ -56,6 +57,32 @@ func Fprint(w io.Writer, data any) error { return printJSON(w, data) } +// FprintRaw outputs a raw JSON document in the configured format, +// preserving server-added fields the CLI doesn't model. JSON output +// pretty-prints the bytes through json.Indent (preserving server key +// order). YAML output parses into a generic value and emits as YAML +// (key order is not preserved through this path). +func FprintRaw(w io.Writer, rawJSON []byte) error { + switch GetFormat() { + case FormatYAML: + var v any + if err := json.Unmarshal(rawJSON, &v); err != nil { + return fmt.Errorf("decoding response: %w", err) + } + encoder := yaml.NewEncoder(w) + encoder.SetIndent(2) + return encoder.Encode(v) + default: + var buf bytes.Buffer + if err := json.Indent(&buf, rawJSON, "", " "); err != nil { + return fmt.Errorf("indenting response: %w", err) + } + buf.WriteByte('\n') + _, err := w.Write(buf.Bytes()) + return err + } +} + func printJSON(w io.Writer, data any) error { encoder := json.NewEncoder(w) encoder.SetIndent("", " ") From b9b1bd436a4dc751205cbfeb679160b588d905a9 Mon Sep 17 00:00:00 2001 From: Peter Evans Date: Wed, 20 May 2026 22:00:38 -0500 Subject: [PATCH 2/5] chore: Update dependencies --- go.mod | 9 +++------ go.sum | 47 ++++++----------------------------------------- 2 files changed, 9 insertions(+), 47 deletions(-) diff --git a/go.mod b/go.mod index 5c8807c..f6f911f 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,12 @@ go 1.25.5 require ( github.com/getzep/zep-go/v3 v3.21.0 - github.com/google/uuid v1.4.0 + github.com/google/uuid v1.6.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/zalando/go-keyring v0.2.6 - golang.org/x/oauth2 v0.18.0 + golang.org/x/oauth2 v0.36.0 golang.org/x/term v0.38.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -20,7 +20,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -39,7 +38,5 @@ require ( golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/appengine v1.6.8 // indirect - google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index 16b814c..59d3aad 100644 --- a/go.sum +++ b/go.sum @@ -15,17 +15,12 @@ github.com/getzep/zep-go/v3 v3.21.0 h1:+ymrLdC8zWjUkKW3LM6U26VMtY0GdAcPGMYEsIwbS github.com/getzep/zep-go/v3 v3.21.0/go.mod h1:gTP6uw5RPlcFSs5z0pGUzhOpx8+w/S2swSc08efsSyQ= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -71,57 +66,27 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI= -golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From fbecbe39c095a414ab86cf970bed2c010b4d297b Mon Sep 17 00:00:00 2001 From: Peter Evans Date: Wed, 20 May 2026 22:12:33 -0500 Subject: [PATCH 3/5] chore: Remove dangling spec reference --- internal/auth/config.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/auth/config.go b/internal/auth/config.go index b2a08e2..c9f87ec 100644 --- a/internal/auth/config.go +++ b/internal/auth/config.go @@ -19,8 +19,7 @@ var ( // // The client ID is NOT a secret. It is a public OAuth 2.0 client identifier // for a "Front-end" application, which has no client secret. It is safe to -// commit to source control and compile into the binary. See spec Sections -// 2.2 and 7.5. +// commit to source control and compile into the binary. // // Audience is intentionally empty in the build-time defaults. Profiles that // target a backend which enforces the aud claim must provide one explicitly From 2ce7e535ecf62a3c2a03900d78ba652f2ac7eca7 Mon Sep 17 00:00:00 2001 From: Peter Evans Date: Wed, 20 May 2026 22:46:33 -0500 Subject: [PATCH 4/5] chore: Update a few more dependencies --- go.mod | 9 ++++----- go.sum | 33 ++++++++++----------------------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/go.mod b/go.mod index f6f911f..c01b102 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -33,10 +33,9 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.27.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index 59d3aad..f62f40a 100644 --- a/go.sum +++ b/go.sum @@ -3,7 +3,6 @@ al.essio.dev/pkg/shellescape v1.5.1/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWt github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -15,8 +14,8 @@ github.com/getzep/zep-go/v3 v3.21.0 h1:+ymrLdC8zWjUkKW3LM6U26VMtY0GdAcPGMYEsIwbS github.com/getzep/zep-go/v3 v3.21.0/go.mod h1:gTP6uw5RPlcFSs5z0pGUzhOpx8+w/S2swSc08efsSyQ= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -33,9 +32,8 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= -github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= @@ -57,41 +55,30 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= -go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= -go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From c8393844afe67667632ae5a9ddefbb8ac46f2392 Mon Sep 17 00:00:00 2001 From: Peter Evans Date: Thu, 21 May 2026 10:33:38 -0500 Subject: [PATCH 5/5] chore: Apply gofumpt formatting Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/client/client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/client/client.go b/internal/client/client.go index 58443cb..86ba321 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -86,7 +86,8 @@ func NewWithCredential(ctx context.Context, credType CredentialType) (*Client, e if projectUUID := config.GetProjectUUID(); projectUUID != "" { headers.Set(projectHeader, projectUUID) } - opts = append(opts, + opts = append( + opts, option.WithHTTPClient(httpClient), option.WithHTTPHeader(headers), )