From 5aab6b454589be84ac2b04fb9a3950628b2dab6d Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Mon, 25 May 2026 11:59:37 +0200 Subject: [PATCH 1/6] feat: add mcp command --- README.md | 32 ++ .../https-wrench-k3s-anchor-and-aliases.yaml | 27 ++ go.mod | 6 + go.sum | 12 + https-wrench.schema.json | 103 ++-- internal/certinfo/certinfo_handlers.go | 11 + internal/cmd/mcp.go | 41 ++ internal/cmd/mcp_test.go | 36 ++ internal/cmd/root.go | 16 + internal/cmd/root_test.go | 39 ++ internal/jwtinfo/jwtinfo_test.go | 15 + .../mcp/assets/examples/https-wrench-k3s.yaml | 56 +++ .../https-wrench-proxyProtocolV2.yaml | 31 ++ ...s-wrench-response-certificates-filter.yaml | 33 ++ internal/mcp/assets/sample-config.yaml | 99 ++++ internal/mcp/assets/schema.json | 235 +++++++++ internal/mcp/coverage_test.go | 361 ++++++++++++++ internal/mcp/embed.go | 63 +++ internal/mcp/prompts.go | 87 ++++ internal/mcp/resources.go | 81 ++++ internal/mcp/server.go | 64 +++ internal/mcp/server_test.go | 334 +++++++++++++ internal/mcp/tools.go | 401 ++++++++++++++++ internal/mcp/tools_exec.go | 444 ++++++++++++++++++ internal/mcp/tools_exec_test.go | 385 +++++++++++++++ .../requests/requests_handlers_print_test.go | 76 +++ internal/requests/requests_test.go | 11 + 27 files changed, 3072 insertions(+), 27 deletions(-) create mode 100644 assets/examples/https-wrench-k3s-anchor-and-aliases.yaml create mode 100644 internal/cmd/mcp.go create mode 100644 internal/cmd/mcp_test.go create mode 100644 internal/mcp/assets/examples/https-wrench-k3s.yaml create mode 100644 internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml create mode 100644 internal/mcp/assets/examples/https-wrench-response-certificates-filter.yaml create mode 100644 internal/mcp/assets/sample-config.yaml create mode 100644 internal/mcp/assets/schema.json create mode 100644 internal/mcp/coverage_test.go create mode 100644 internal/mcp/embed.go create mode 100644 internal/mcp/prompts.go create mode 100644 internal/mcp/resources.go create mode 100644 internal/mcp/server.go create mode 100644 internal/mcp/server_test.go create mode 100644 internal/mcp/tools.go create mode 100644 internal/mcp/tools_exec.go create mode 100644 internal/mcp/tools_exec_test.go create mode 100644 internal/requests/requests_handlers_print_test.go diff --git a/README.md b/README.md index a5661e2..bf917af 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,8 @@ or remote JWKS endpoints. jwks: Generate pretty-printed JSON Web Key Sets (JWKS) from public keys for exposure on well-known endpoints. +mcp: Run a Model Context Protocol server on stdin/stdout for AI agent integration. + Distributed under an open-source license: https://github.com/xenOs76/https-wrench Usage: @@ -57,6 +59,7 @@ Available Commands: help Help about any command jwks Generate a JSON Web Key Set (JWKS) from a public key jwtinfo Inspect and validate JSON Web Tokens (JWT) + mcp Run an MCP server for AI agent integration requests Execute YAML-defined HTTPS requests Flags: @@ -375,6 +378,35 @@ Generate a JWKS with a SHA-256-derived KID: ❯ https-wrench jwks --public-key-file public.pem ``` +### HTTPS Wrench mcp + +`mcp` runs a Model Context Protocol server on stdin/stdout. Connect it from Cursor, Claude Desktop, or other MCP clients to author `requests` YAML, validate configs, build CLI commands, and run https-wrench operations directly. + +```shell +https-wrench mcp +``` + +Cursor configuration example: + +```json +{ + "mcpServers": { + "https-wrench": { + "command": "https-wrench", + "args": ["mcp"] + } + } +} +``` + +**Resources:** JSON schema, sample config, example YAML files, and a requests cheat sheet (`https-wrench://schema`, `https-wrench://sample-config`, `https-wrench://examples/{name}`, `https-wrench://docs/requests`). + +**Prompts:** `author_requests_config` — parameterized guidance for writing requests YAML. + +**Tools (assist):** `validate_requests_config`, `requests_config_template`, `build_cli_command`. + +**Tools (execution):** `run_requests`, `certinfo`, `jwtinfo`, `generate_jwks`. Encrypted private keys for `certinfo` require the `CERTINFO_PKEY_PW` environment variable (no interactive prompt under MCP). + ## Sample output
diff --git a/assets/examples/https-wrench-k3s-anchor-and-aliases.yaml b/assets/examples/https-wrench-k3s-anchor-and-aliases.yaml new file mode 100644 index 0000000..6c58406 --- /dev/null +++ b/assets/examples/https-wrench-k3s-anchor-and-aliases.yaml @@ -0,0 +1,27 @@ +# yaml-language-server: $schema=../../https-wrench.schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj +--- +verbose: true + +baseRequest: &base + printResponseHeaders: true + responseHeadersFilter: + - Server + hosts: + - name: httpbingo.k3s.os76.xyz + +requests: + - name: k3sOs76ViaCaddyDnsDirect + <<: *base + + - name: k3sOs76NoLbShouldFail + transportOverrideUrl: https://rpi501.home.arpa + <<: *base + + - name: k3sOs76ViaIstioDnsOverride + transportOverrideUrl: https://192.168.1.114:30443 + <<: *base + + - name: k3sOs76ViaNginxDnsOverride + transportOverrideUrl: https://argo.home.arpa + <<: *base diff --git a/go.mod b/go.mod index fbaaa18..378729b 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/go-cmp v0.7.0 github.com/gookit/goutil v0.7.4 + github.com/modelcontextprotocol/go-sdk v1.6.1 github.com/pires/go-proxyproto v0.12.0 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 @@ -37,6 +38,7 @@ require ( github.com/dlclark/regexp2 v1.12.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.4.0 // indirect @@ -48,12 +50,16 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.50.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sys v0.43.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.15.0 // indirect diff --git a/go.sum b/go.sum index 071bf91..6917991 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63Y github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/gookit/goutil v0.7.4 h1:OWgUngToNz+bPlX5aP+EMG31DraEU63uvKMwwT3vseM= github.com/gookit/goutil v0.7.4/go.mod h1:vJS9HXctYTCLtCsZot5L5xF+O1oR17cDYO9R0HxBmnU= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -67,6 +69,8 @@ github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/modelcontextprotocol/go-sdk v1.6.1 h1:0zOSupjKUxPKSocPT1Wtago+mUHU2/uZ4xSOY0FGReU= +github.com/modelcontextprotocol/go-sdk v1.6.1/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM= @@ -83,6 +87,10 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= @@ -100,6 +108,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -108,6 +118,8 @@ golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= diff --git a/https-wrench.schema.json b/https-wrench.schema.json index c951a9b..597447c 100644 --- a/https-wrench.schema.json +++ b/https-wrench.schema.json @@ -5,18 +5,27 @@ "HttpsWrenchConfiguration": { "type": "object", "additionalProperties": false, + "description": "Top-level configuration for https-wrench requests.", "properties": { "debug": { - "type": "boolean" + "type": "boolean", + "description": "Enables global debug mode to print additional diagnostic information." }, "verbose": { - "type": "boolean" + "type": "boolean", + "description": "Enables verbose output, showing more details during execution." }, "caBundle": { - "type": "string" + "type": "string", + "description": "PEM-encoded CA certificate bundle used to verify server certificates. When omitted, the system trust store is used." + }, + "baseRequest": { + "$ref": "#/definitions/RequestDefaults", + "description": "YAML-only shared request template. Define an anchor (e.g. baseRequest: &base) and merge into requests with <<: *base. Ignored by https-wrench at runtime." }, "requests": { "type": "array", + "description": "List of HTTPS requests to execute.", "items": { "$ref": "#/definitions/Request" } @@ -28,12 +37,14 @@ ], "title": "HttpsWrenchConfiguration" }, - "Request": { + "RequestDefaults": { "type": "object", "additionalProperties": false, + "description": "Shared request fields with no required keys. Used by baseRequest anchors and merged into requests[] entries.", "properties": { "name": { - "type": "string" + "type": "string", + "description": "Descriptive name for this request, used in output. Required on each requests[] entry." }, "transportOverrideUrl": { "type": "string", @@ -41,39 +52,50 @@ "qt-uri-protocols": [ "https" ], - "pattern": "^https://" + "pattern": "^https://", + "description": "TLS/TCP dial address (https://host or https://ip:port). The logical hostname remains hosts[].name for Host header and SNI." }, "clientTimeout": { - "type": "number" + "type": "number", + "description": "HTTP client timeout in seconds." }, "requestDebug": { - "type": "boolean" + "type": "boolean", + "description": "If true, dumps the raw outgoing HTTP request for debugging." }, "responseDebug": { - "type": "boolean" + "type": "boolean", + "description": "If true, dumps the raw HTTP response, including TLS connection details, for debugging." }, "responseBodyMatchRegexp": { - "type": "string" + "type": "string", + "description": "Regular expression that the response body must match." }, "printResponseBody": { - "type": "boolean" + "type": "boolean", + "description": "If true, prints the HTTP response body." }, "printResponseHeaders": { - "type": "boolean" + "type": "boolean", + "description": "If true, prints HTTP response headers." }, "printResponseCertificates": { - "type": "boolean" + "type": "boolean", + "description": "If true, prints TLS certificates from the peer chain." }, "responseCertificatesFilter": { "type": "array", "description": "Filter to display only specific certificates from the peer chain and/or only subset of fields for each certificate. Each item in the array is a map of certificate index (0-indexed, where 0 is the leaf certificate) to a list of certificate fields to render (e.g. Subject, DNSNames, Issuer, NotBefore, NotAfter, Expiration). If the list of fields is empty, all fields for that certificate are printed.", "items": { "type": "object", + "description": "Map of certificate chain index to the list of certificate fields to print.", "patternProperties": { "^[0-9]+$": { "type": "array", + "description": "Certificate fields to print for the given chain index.", "items": { "type": "string", + "description": "Certificate field name to include in output.", "enum": [ "Subject", "DNSNames", @@ -96,45 +118,50 @@ } }, "enableProxyProtocolV2": { - "type": "boolean" + "type": "boolean", + "description": "If true, sends an HAProxy PROXY protocol v2 header on connect. Requires transportOverrideUrl." }, "insecure": { - "type": "boolean" + "type": "boolean", + "description": "If true, skips TLS server certificate verification (InsecureSkipVerify)." }, "responseHeadersFilter": { "type": "array", + "description": "Response header names to include when printResponseHeaders is enabled.", "items": { "type": "string", - "pattern": "^[A-Z]" + "pattern": "^[A-Z]", + "description": "HTTP response header name. Must start with an uppercase letter." } }, "requestBody": { - "type": "string" + "type": "string", + "description": "HTTP request body payload." }, "requestMethod": { "type": "string", - "pattern": "^(POST|GET|HEAD|PUT|DELETE|PATCH|HEAD|OPTIONS|TRACE)$" + "pattern": "^(POST|GET|HEAD|PUT|DELETE|PATCH|HEAD|OPTIONS|TRACE)$", + "description": "HTTP method for the request. Defaults to GET when omitted." }, "requestHeaders": { "type": "array", + "description": "Custom HTTP request headers to send.", "items": { "$ref": "#/definitions/RequestHeader" } }, "userAgent": { - "type": "string" + "type": "string", + "description": "Custom User-Agent string for the request." }, "hosts": { "type": "array", + "description": "Target hostnames and paths to request. Required on each requests[] entry.", "items": { "$ref": "#/definitions/Host" } } }, - "required": [ - "hosts", - "name" - ], "dependencies": { "enableProxyProtocolV2": [ "transportOverrideUrl" @@ -143,20 +170,39 @@ "printResponseCertificates" ] }, + "title": "RequestDefaults" + }, + "Request": { + "description": "A single HTTPS probe definition. Each entry must include name and hosts.", + "allOf": [ + { + "$ref": "#/definitions/RequestDefaults" + }, + { + "required": [ + "name", + "hosts" + ] + } + ], "title": "Request" }, "Host": { "type": "object", "additionalProperties": false, + "description": "A logical hostname and the URI paths to request on it.", "properties": { "name": { - "type": "string" + "type": "string", + "description": "Hostname used for the request URL, Host header, and TLS ServerName indication." }, "uriList": { "type": "array", + "description": "URI paths to request on this host. Each path must start with /. When omitted, tool defaults apply.", "items": { "type": "string", - "pattern": "^/" + "pattern": "^/", + "description": "Request path starting with /." } } }, @@ -168,12 +214,15 @@ "RequestHeader": { "type": "object", "additionalProperties": false, + "description": "A single HTTP request header key-value pair.", "properties": { "key": { - "type": "string" + "type": "string", + "description": "HTTP header name." }, "value": { - "type": "string" + "type": "string", + "description": "HTTP header value." } }, "required": [ diff --git a/internal/certinfo/certinfo_handlers.go b/internal/certinfo/certinfo_handlers.go index e6b1cd2..3cb2232 100644 --- a/internal/certinfo/certinfo_handlers.go +++ b/internal/certinfo/certinfo_handlers.go @@ -308,6 +308,17 @@ func CertsToTables(w io.Writer, certs []*x509.Certificate, filter ...[]map[int][ addRow(sl("DNSNames"), sv(dnsNames)) } + if hasField("IPAddresses") { + var ipStrs []string + + for _, ip := range cert.IPAddresses { + ipStrs = append(ipStrs, ip.String()) + } + + ips := strings.Join(ipStrs, "\n") + addRow(sl("IPAddresses"), sv(ips)) + } + if hasField("Issuer") { issuer := cert.Issuer.String() addRow(sl("Issuer"), sv(issuer)) diff --git a/internal/cmd/mcp.go b/internal/cmd/mcp.go new file mode 100644 index 0000000..b3438da --- /dev/null +++ b/internal/cmd/mcp.go @@ -0,0 +1,41 @@ +/* +Copyright © 2026 Zeno Belli +*/ + +package cmd + +import ( + "github.com/spf13/cobra" + mcpserver "github.com/xenos76/https-wrench/internal/mcp" +) + +var mcpCmd = &cobra.Command{ + Use: "mcp", + Short: "Run an MCP server for AI agent integration", + Long: `Start a Model Context Protocol server on stdin/stdout. + +The server exposes reference resources (JSON schema, sample config, examples), +prompts for authoring requests YAML, and assist tools to validate configs and +build CLI invocations for other https-wrench subcommands. + +Configure in Cursor or Claude Desktop: + + { + "mcpServers": { + "https-wrench": { + "command": "https-wrench", + "args": ["mcp"] + } + } + } +`, + SilenceUsage: true, + SilenceErrors: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return mcpserver.Run(cmd.Context(), version) + }, +} + +func init() { + rootCmd.AddCommand(mcpCmd) +} diff --git a/internal/cmd/mcp_test.go b/internal/cmd/mcp_test.go new file mode 100644 index 0000000..6fcdebc --- /dev/null +++ b/internal/cmd/mcp_test.go @@ -0,0 +1,36 @@ +package cmd + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsMCPCommand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + want bool + }{ + {name: "mcp subcommand", args: []string{"https-wrench", "mcp"}, want: true}, + {name: "mcp with config flag", args: []string{"https-wrench", "--config", "x.yaml", "mcp"}, want: true}, + {name: "requests subcommand", args: []string{"https-wrench", "requests", "--config", "x.yaml"}, want: false}, + {name: "root help", args: []string{"https-wrench", "-h"}, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + oldArgs := os.Args + + t.Cleanup(func() { os.Args = oldArgs }) + + os.Args = tt.args + require.Equal(t, tt.want, isMCPCommand()) + }) + } +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 5156f63..a449360 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -69,6 +69,8 @@ or remote JWKS endpoints. jwks: Generate pretty-printed JSON Web Key Sets (JWKS) from public keys for exposure on well-known endpoints. +mcp: Run a Model Context Protocol server on stdin/stdout for AI agent integration. + Distributed under an open-source license: https://github.com/xenOs76/https-wrench`, Run: func(cmd *cobra.Command, _ []string) { @@ -118,6 +120,10 @@ func init() { } func initConfig() { + if isMCPCommand() { + return + } + if cfgFile != "" { viper.SetConfigFile(cfgFile) } else { @@ -149,6 +155,16 @@ func LoadConfig() (*HTTPSWrenchConfig, error) { return config, nil } +func isMCPCommand() bool { + for _, arg := range os.Args[1:] { + if arg == "mcp" { + return true + } + } + + return false +} + func addCaBundleFlag(cmd *cobra.Command) { cmd.Flags().StringVar(&caBundlePath, "ca-bundle", "", `Path to bundle file with CA certificates to use for validation`) diff --git a/internal/cmd/root_test.go b/internal/cmd/root_test.go index f8008d5..14664e5 100644 --- a/internal/cmd/root_test.go +++ b/internal/cmd/root_test.go @@ -74,6 +74,45 @@ func TestRootCmd_LoadConfig(t *testing.T) { require.Equal(t, "SampleRequestAgainstLocalWebserver", config.Requests[0].Name) require.Equal(t, "https://127.0.0.1:9443", config.Requests[0].TransportOverrideURL) }) + + t.Run("LoadConfig YAML anchor and merge keys", func(t *testing.T) { + oldCfg := cfgFile + + t.Cleanup(func() { + cfgFile = oldCfg + + viper.Reset() + }) + + cfgFile = "../../assets/examples/https-wrench-k3s-anchor-and-aliases.yaml" + + initConfig() + + config, err := LoadConfig() + require.NoError(t, err) + require.True(t, config.Verbose) + require.Len(t, config.Requests, 4) + + for _, req := range config.Requests { + require.True(t, req.PrintResponseHeaders) + require.Equal(t, []string{"Server"}, req.ResponseHeadersFilter) + require.Len(t, req.Hosts, 1) + require.Equal(t, "httpbingo.k3s.os76.xyz", req.Hosts[0].Name) + } + + require.Equal(t, "k3sOs76ViaCaddyDnsDirect", config.Requests[0].Name) + require.Empty(t, config.Requests[0].TransportOverrideURL) + + require.Equal(t, "k3sOs76NoLbShouldFail", config.Requests[1].Name) + require.Equal(t, "https://rpi501.home.arpa", config.Requests[1].TransportOverrideURL) + + require.Equal(t, "k3sOs76ViaIstioDnsOverride", config.Requests[2].Name) + require.Equal(t, "https://192.168.1.114:30443", config.Requests[2].TransportOverrideURL) + + require.Equal(t, "k3sOs76ViaNginxDnsOverride", config.Requests[3].Name) + require.Equal(t, "https://argo.home.arpa", config.Requests[3].TransportOverrideURL) + }) + t.Run("LoadConfig unmarshal error", func(t *testing.T) { oldCfg := cfgFile diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index edbe77c..3549df6 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "maps" + "net/http" "os" "strings" "testing" @@ -256,6 +257,20 @@ func TestRequestToken(t *testing.T) { //nolint:revive } +func TestRequestToken_nilReadAll(t *testing.T) { + t.Parallel() + + _, err := RequestToken( + context.Background(), + "http://example.com/token", + map[string]string{"grant_type": "client_credentials"}, + &http.Client{}, + nil, + ) + require.Error(t, err) + require.ErrorContains(t, err, "nil body reader") +} + type requestTokenTestCase struct { name string user string diff --git a/internal/mcp/assets/examples/https-wrench-k3s.yaml b/internal/mcp/assets/examples/https-wrench-k3s.yaml new file mode 100644 index 0000000..9a6a01d --- /dev/null +++ b/internal/mcp/assets/examples/https-wrench-k3s.yaml @@ -0,0 +1,56 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj +--- +debug: false +verbose: true +requests: + - name: k3sOs76ViaCaddyDnsDirect + printResponseHeaders: true + responseHeadersFilter: + - Server + + hosts: + - name: httpbingo.k3s.os76.xyz + + - name: k3sOs76NoLbShouldFail + transportOverrideUrl: https://rpi501.home.arpa + + # requestDebug: true + # responseDebug: true + hosts: + - name: httpbingo.k3s.os76.xyz + uriList: + - /headers + + - name: k3sOs76ViaIstioDnsOverride + transportOverrideUrl: https://192.168.1.114:30443 + # requestDebug: true + # responseDebug: true + # printResponseBody: true + printResponseHeaders: true + responseHeadersFilter: + - Server + + userAgent: wrench-transport-override + requestHeaders: + - key: x-api-key + value: aaa-bbb-ccc + + hosts: + - name: httpbingo.k3s.os76.xyz + uriList: + - /headers + # - /status/404 + # - /status/503 + + - name: k3sOs76ViaNginxDnsOverride + transportOverrideUrl: https://argo.home.arpa + # printResponseBody: true + printResponseHeaders: true + responseHeadersFilter: + - Server + + hosts: + - name: httpbingo.k3s.os76.xyz + uriList: + - /headers diff --git a/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml b/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml new file mode 100644 index 0000000..c06c49b --- /dev/null +++ b/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml @@ -0,0 +1,31 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json +--- +## +## HTTPS Wrench - ProxyProtocolV2 test +## +## Test to validate the connection against a webserver implementing ProxyProtocolV2. +## This configuration file can be run against the Devenv testing environment defined +## in devenv.nix. + +debug: false +verbose: true + +requests: + - name: RequestOverProxyPtorocolV2 + transportOverrideUrl: https://127.0.0.1:9444 + enableProxyProtocolV2: true + clientTimeout: 2 + insecure: true + printResponseBody: true + responseBodyMatchRegexp: ".*https-wrench-request.*" + printResponseHeaders: true + responseHeadersFilter: + - Content-Type + - Server + + requestMethod: POST + + hosts: + - name: example.com + uriList: + - /post diff --git a/internal/mcp/assets/examples/https-wrench-response-certificates-filter.yaml b/internal/mcp/assets/examples/https-wrench-response-certificates-filter.yaml new file mode 100644 index 0000000..efe48fe --- /dev/null +++ b/internal/mcp/assets/examples/https-wrench-response-certificates-filter.yaml @@ -0,0 +1,33 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj +--- +## HTTPS Wrench — Response Certificates Filter Example +## +## This example demonstrates how to use the 'responseCertificatesFilter' option +## to selectively print certificate chains in your HTTP responses. +## +## Note: This option depends on 'printResponseCertificates: true' being enabled. + +debug: false +verbose: true + +requests: + - name: selective-certificate-display + printResponseCertificates: true + responseCertificatesFilter: + # Filter for the leaf certificate (index 0) in the peer chain + - 0: + - Subject + - DNSNames + - Issuer + - NotAfter + - Expiration + # Filter for the intermediate/CA certificate (index 1) in the peer chain + - 1: + - Subject + - Issuer + - IsCA + + hosts: + - name: google.com + - name: github.com diff --git a/internal/mcp/assets/sample-config.yaml b/internal/mcp/assets/sample-config.yaml new file mode 100644 index 0000000..f3a82ee --- /dev/null +++ b/internal/mcp/assets/sample-config.yaml @@ -0,0 +1,99 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json +--- +## +## HTTPS Wrench - sample configuration file +## + +## debug - Enables global debug mode to print additional diagnostic information. +debug: false + +## verbose - Enables verbose output, showing more details during execution. Required option. +verbose: true + +## caBundle - A PEM-encoded CA certificate bundle as a multiline string to be used for verifying server certificates. +## When testing inside the devenv environment, the 'devenv up' command will create new self-signed certificates and +## start a local, HTTPS-enabled Nginx server. +## The server will take the certificate from $CAROOT/full-cert.pem. +## If caBundle is not set, the requests made using this configuration file will fail with a TLS certificate verification error. +## Add the content of $CAROOT/rootCA.pem to the variable caBundle to test the sample configuration against the local +## webserver. +caBundle: | + -----BEGIN CERTIFICATE----- + MIIEbTCCAtWgAwIBAgIQJdy/eKgQx9G54MUxW+ow5zANBgkqhkiG9w0BAQsFADBP + ... + +## requests - List of HTTP requests to execute. Required option. +requests: + ## name - The name of the request, used for display purposes. Required option. + - name: SampleRequestAgainstLocalWebserver + + ## transportOverrideUrl - Override URL for the transport layer. Can be used to force a connection to a specific IP or proxy. Must start with https:// + transportOverrideUrl: https://127.0.0.1:9443 + + ## enableProxyProtocolV2 - Enables sending an HAProxy PROXY protocol v2 header. Requires 'transportOverrideUrl' to be set. + enableProxyProtocolV2: false + + ## clientTimeout - The timeout for the HTTP client in seconds. + clientTimeout: 5 + + ## insecure - If true, skips TLS certificate verification (InsecureSkipVerify). + insecure: false + + ## requestDebug - If true, dumps the raw HTTP request to the output for debugging. + requestDebug: false + + ## responseDebug - If true, dumps the raw HTTP response, including TLS connection details, for debugging. + responseDebug: false + + ## printResponseBody - If true, prints the body of the HTTP response. + printResponseBody: true + + ## responseBodyMatchRegexp - A regular expression to match against the response body. + responseBodyMatchRegexp: ".*https-wrench-agent.*" + + ## printResponseHeaders - If true, prints the headers of the HTTP response. + printResponseHeaders: true + + ## responseHeadersFilter - A list of specific response headers to filter and display. + responseHeadersFilter: + - Content-Type + - Server + + ## printResponseCertificates - If true, prints the TLS certificates returned in the response. + printResponseCertificates: true + + ## responseCertificatesFilter - Filter the printed TLS certificates to show only specific certificates in the chain (0-indexed) and/or a subset of fields. + # responseCertificatesFilter: + # - 0: + # - Subject + # - DNSNames + # - Issuer + # - NotAfter + # - Expiration + + ## requestMethod - The HTTP method to use for the request (e.g., GET, POST, PUT, DELETE). Defaults to GET. + requestMethod: POST + + ## requestHeaders - A list of custom headers to send with the HTTP request. + requestHeaders: + ## key - The name of the header. + - key: Content-Type + ## value - The value of the header. + value: application/json + - key: X-Custom-Header + value: custom-value + + ## requestBody - The body payload to send with the HTTP request. + requestBody: "{\"key\": \"value\"}" + + ## userAgent - A custom User-Agent string to send with the request. + userAgent: custom-https-wrench-agent/1.0 + + ## hosts - The target hosts to send the request to. Required option. + hosts: + ## name - The hostname (used for the Host header and TLS ServerName indication). Required option. + - name: example.com + ## uriList - A list of URIs (paths) to request on this host. Must start with a forward slash (/). + uriList: + - /post + - /status/503 diff --git a/internal/mcp/assets/schema.json b/internal/mcp/assets/schema.json new file mode 100644 index 0000000..597447c --- /dev/null +++ b/internal/mcp/assets/schema.json @@ -0,0 +1,235 @@ +{ + "$schema": "http://json-schema.org/draft-06/schema#", + "$ref": "#/definitions/HttpsWrenchConfiguration", + "definitions": { + "HttpsWrenchConfiguration": { + "type": "object", + "additionalProperties": false, + "description": "Top-level configuration for https-wrench requests.", + "properties": { + "debug": { + "type": "boolean", + "description": "Enables global debug mode to print additional diagnostic information." + }, + "verbose": { + "type": "boolean", + "description": "Enables verbose output, showing more details during execution." + }, + "caBundle": { + "type": "string", + "description": "PEM-encoded CA certificate bundle used to verify server certificates. When omitted, the system trust store is used." + }, + "baseRequest": { + "$ref": "#/definitions/RequestDefaults", + "description": "YAML-only shared request template. Define an anchor (e.g. baseRequest: &base) and merge into requests with <<: *base. Ignored by https-wrench at runtime." + }, + "requests": { + "type": "array", + "description": "List of HTTPS requests to execute.", + "items": { + "$ref": "#/definitions/Request" + } + } + }, + "required": [ + "requests", + "verbose" + ], + "title": "HttpsWrenchConfiguration" + }, + "RequestDefaults": { + "type": "object", + "additionalProperties": false, + "description": "Shared request fields with no required keys. Used by baseRequest anchors and merged into requests[] entries.", + "properties": { + "name": { + "type": "string", + "description": "Descriptive name for this request, used in output. Required on each requests[] entry." + }, + "transportOverrideUrl": { + "type": "string", + "format": "uri", + "qt-uri-protocols": [ + "https" + ], + "pattern": "^https://", + "description": "TLS/TCP dial address (https://host or https://ip:port). The logical hostname remains hosts[].name for Host header and SNI." + }, + "clientTimeout": { + "type": "number", + "description": "HTTP client timeout in seconds." + }, + "requestDebug": { + "type": "boolean", + "description": "If true, dumps the raw outgoing HTTP request for debugging." + }, + "responseDebug": { + "type": "boolean", + "description": "If true, dumps the raw HTTP response, including TLS connection details, for debugging." + }, + "responseBodyMatchRegexp": { + "type": "string", + "description": "Regular expression that the response body must match." + }, + "printResponseBody": { + "type": "boolean", + "description": "If true, prints the HTTP response body." + }, + "printResponseHeaders": { + "type": "boolean", + "description": "If true, prints HTTP response headers." + }, + "printResponseCertificates": { + "type": "boolean", + "description": "If true, prints TLS certificates from the peer chain." + }, + "responseCertificatesFilter": { + "type": "array", + "description": "Filter to display only specific certificates from the peer chain and/or only subset of fields for each certificate. Each item in the array is a map of certificate index (0-indexed, where 0 is the leaf certificate) to a list of certificate fields to render (e.g. Subject, DNSNames, Issuer, NotBefore, NotAfter, Expiration). If the list of fields is empty, all fields for that certificate are printed.", + "items": { + "type": "object", + "description": "Map of certificate chain index to the list of certificate fields to print.", + "patternProperties": { + "^[0-9]+$": { + "type": "array", + "description": "Certificate fields to print for the given chain index.", + "items": { + "type": "string", + "description": "Certificate field name to include in output.", + "enum": [ + "Subject", + "DNSNames", + "Issuer", + "NotBefore", + "NotAfter", + "Expiration", + "IsCA", + "AuthorityKeyId", + "SubjectKeyId", + "PublicKeyAlgorithm", + "SignatureAlgorithm", + "SerialNumber", + "Fingerprint SHA-256" + ] + } + } + }, + "additionalProperties": false + } + }, + "enableProxyProtocolV2": { + "type": "boolean", + "description": "If true, sends an HAProxy PROXY protocol v2 header on connect. Requires transportOverrideUrl." + }, + "insecure": { + "type": "boolean", + "description": "If true, skips TLS server certificate verification (InsecureSkipVerify)." + }, + "responseHeadersFilter": { + "type": "array", + "description": "Response header names to include when printResponseHeaders is enabled.", + "items": { + "type": "string", + "pattern": "^[A-Z]", + "description": "HTTP response header name. Must start with an uppercase letter." + } + }, + "requestBody": { + "type": "string", + "description": "HTTP request body payload." + }, + "requestMethod": { + "type": "string", + "pattern": "^(POST|GET|HEAD|PUT|DELETE|PATCH|HEAD|OPTIONS|TRACE)$", + "description": "HTTP method for the request. Defaults to GET when omitted." + }, + "requestHeaders": { + "type": "array", + "description": "Custom HTTP request headers to send.", + "items": { + "$ref": "#/definitions/RequestHeader" + } + }, + "userAgent": { + "type": "string", + "description": "Custom User-Agent string for the request." + }, + "hosts": { + "type": "array", + "description": "Target hostnames and paths to request. Required on each requests[] entry.", + "items": { + "$ref": "#/definitions/Host" + } + } + }, + "dependencies": { + "enableProxyProtocolV2": [ + "transportOverrideUrl" + ], + "responseCertificatesFilter": [ + "printResponseCertificates" + ] + }, + "title": "RequestDefaults" + }, + "Request": { + "description": "A single HTTPS probe definition. Each entry must include name and hosts.", + "allOf": [ + { + "$ref": "#/definitions/RequestDefaults" + }, + { + "required": [ + "name", + "hosts" + ] + } + ], + "title": "Request" + }, + "Host": { + "type": "object", + "additionalProperties": false, + "description": "A logical hostname and the URI paths to request on it.", + "properties": { + "name": { + "type": "string", + "description": "Hostname used for the request URL, Host header, and TLS ServerName indication." + }, + "uriList": { + "type": "array", + "description": "URI paths to request on this host. Each path must start with /. When omitted, tool defaults apply.", + "items": { + "type": "string", + "pattern": "^/", + "description": "Request path starting with /." + } + } + }, + "required": [ + "name" + ], + "title": "Host" + }, + "RequestHeader": { + "type": "object", + "additionalProperties": false, + "description": "A single HTTP request header key-value pair.", + "properties": { + "key": { + "type": "string", + "description": "HTTP header name." + }, + "value": { + "type": "string", + "description": "HTTP header value." + } + }, + "required": [ + "key", + "value" + ], + "title": "RequestHeader" + } + } +} diff --git a/internal/mcp/coverage_test.go b/internal/mcp/coverage_test.go new file mode 100644 index 0000000..425c45a --- /dev/null +++ b/internal/mcp/coverage_test.go @@ -0,0 +1,361 @@ +package mcp + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" +) + +func TestShellQuote(t *testing.T) { + t.Parallel() + + require.Equal(t, "''", shellQuote("")) + require.Equal(t, "plain", shellQuote("plain")) + require.Equal(t, `"has space"`, shellQuote("has space")) + require.Equal(t, `"say \"hi\""`, shellQuote(`say "hi"`)) +} + +func TestParsePaths(t *testing.T) { + t.Parallel() + + require.Equal(t, []string{"/"}, parsePaths("")) + require.Equal(t, []string{"/a", "/b"}, parsePaths(" /a , /b ")) + require.Equal(t, []string{"/"}, parsePaths(" , , ")) +} + +func TestLoadConfigYAML(t *testing.T) { + t.Parallel() + + _, err := loadConfigYAML("", "") + require.Error(t, err) + + _, err = loadConfigYAML("a: 1", "/tmp/x.yaml") + require.Error(t, err) + require.Contains(t, err.Error(), "exactly one") + + yaml, err := loadConfigYAML("verbose: true\n", "") + require.NoError(t, err) + require.Contains(t, yaml, "verbose") + + dir := t.TempDir() + path := filepath.Join(dir, "cfg.yaml") + require.NoError(t, os.WriteFile(path, []byte("verbose: true\n"), 0o600)) + + yaml, err = loadConfigYAML("", path) + require.NoError(t, err) + require.Contains(t, yaml, "verbose") + + _, err = loadConfigYAML("", filepath.Join(dir, "missing.yaml")) + require.Error(t, err) + require.Contains(t, err.Error(), "read config file") +} + +func TestValidateRequestsConfig_errors(t *testing.T) { + t.Parallel() + + valid, errs := validateRequestsConfig("not: [valid: yaml") + require.False(t, valid) + require.NotEmpty(t, errs) + + valid, errs = validateRequestsConfig("verbose: true\nrequests: not-a-list\n") + require.False(t, valid) + require.NotEmpty(t, errs) + + valid, errs = validateRequestsConfig(`verbose: true +requests: + - name: x + transportOverrideUrl: http://bad + hosts: + - name: h + uriList: + - / +`) + require.False(t, valid) + require.Contains(t, strings.Join(errs, " "), "https://") + + valid, errs = validateRequestsConfig(`verbose: true +requests: + - name: x + enableProxyProtocolV2: true + hosts: + - name: h + uriList: + - / +`) + require.False(t, valid) + require.Contains(t, strings.Join(errs, " "), "enableProxyProtocolV2") + + valid, errs = validateRequestsConfig(`verbose: true +requests: + - name: "" + hosts: [] +`) + require.False(t, valid) + require.NotEmpty(t, errs) +} + +func TestBuildRequestsConfigYAML_errors(t *testing.T) { + t.Parallel() + + _, errs := buildRequestsConfigYAML(requestsConfigTemplateInput{}) + require.NotEmpty(t, errs) + + _, errs = buildRequestsConfigYAML(requestsConfigTemplateInput{ + Hostname: "app.example.com", + Paths: "no-slash", + }) + require.NotEmpty(t, errs) +} + +func TestBuildCLICommand(t *testing.T) { + t.Parallel() + + _, errs := buildCLICommand("unknown", nil) + require.NotEmpty(t, errs) + + cmd, errs := buildCLICommand("jwks", map[string]string{ + "public-key-file": "/keys/pub.pem", + "kid": "my-kid", + }) + require.Empty(t, errs) + require.Contains(t, cmd, "https-wrench jwks") + + _, errs = buildCLICommand("jwtinfo", map[string]string{"token-file": "t.jwt"}) + require.Empty(t, errs) + + _, errs = buildCLICommand("requests", map[string]string{"config": "cfg.yaml"}) + require.Empty(t, errs) + + _, errs = buildCLICommand("certinfo", map[string]string{"unknown-flag": "x"}) + require.NotEmpty(t, errs) + + _, errs = buildCLICommand("jwtinfo", map[string]string{}) + require.NotEmpty(t, errs) +} + +func TestExecToolTimeout(t *testing.T) { + t.Parallel() + + require.Equal(t, defaultExecToolTimeout, execToolTimeout(0)) + require.Equal(t, 5*time.Second, execToolTimeout(5)) +} + +func TestCertinfoInputProvided(t *testing.T) { + t.Parallel() + + require.False(t, certinfoInputProvided(certinfoInput{})) + require.True(t, certinfoInputProvided(certinfoInput{CertBundle: "x.pem"})) +} + +func TestLoadJwtTokenData_errors(t *testing.T) { + t.Parallel() + + _, err := loadJwtTokenData(context.Background(), jwtinfoInput{}) + require.Error(t, err) + + _, err = loadJwtTokenData(context.Background(), jwtinfoInput{ + TokenFile: "a.jwt", + RequestURL: "https://example.com/token", + }) + require.Error(t, err) + + _, err = loadJwtTokenData(context.Background(), jwtinfoInput{ + RequestURL: "https://example.com/token", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "requestValues") +} + +func TestCaptureOutput_error(t *testing.T) { + t.Parallel() + + _, err := captureOutput(func(_ io.Writer) error { + return errors.New("boom") + }) + require.Error(t, err) +} + +func TestExampleResourceHints(t *testing.T) { + t.Parallel() + + hints := exampleResourceHints(requestsConfigTemplateInput{Hostname: "app.example.com"}) + require.Contains(t, hints, "k3s") + + hints = exampleResourceHints(requestsConfigTemplateInput{ + Hostname: "app.example.com", + TransportOverrideURL: "https://edge.example.net", + }) + require.Contains(t, hints, "proxy-protocol-v2") + + hints = exampleResourceHints(requestsConfigTemplateInput{ + Hostname: "app.example.com", + Insecure: true, + }) + require.Contains(t, hints, "k3s") +} + +func TestAuthorRequestsConfigPrompt_insecureAndErrors(t *testing.T) { + t.Parallel() + + _, err := authorRequestsConfigPrompt(context.Background(), &sdkmcp.GetPromptRequest{ + Params: &sdkmcp.GetPromptParams{Arguments: map[string]string{}}, + }) + require.Error(t, err) + + res, err := authorRequestsConfigPrompt(context.Background(), &sdkmcp.GetPromptRequest{ + Params: &sdkmcp.GetPromptParams{Arguments: map[string]string{ + "hostname": "app.example.com", + "insecure": "true", + }}, + }) + require.NoError(t, err) + require.NotEmpty(t, res.Messages) +} + +func TestReadExampleResource_notFound(t *testing.T) { + t.Parallel() + + _, err := readExampleResource(context.Background(), &sdkmcp.ReadResourceRequest{ + Params: &sdkmcp.ReadResourceParams{URI: "https-wrench://examples/"}, + }) + require.Error(t, err) + + _, err = readExampleResource(context.Background(), &sdkmcp.ReadResourceRequest{ + Params: &sdkmcp.ReadResourceParams{URI: "https-wrench://examples/unknown-example"}, + }) + require.Error(t, err) +} + +func TestTextResource(t *testing.T) { + t.Parallel() + + res := textResource("https-wrench://test", "hello") + require.Equal(t, "hello", res.Contents[0].Text) +} + +func TestNewServer_emptyVersion(t *testing.T) { + t.Parallel() + + server := NewServer("") + require.NotNil(t, server) +} + +func TestExecuteHelpers(t *testing.T) { + t.Parallel() + + _, err := executeCertinfo(context.Background(), certinfoInput{}) + require.Error(t, err) + + _, err = executeCertinfo(context.Background(), certinfoInput{ + TLSEndpoint: "example.com:443", + TLSInfo: true, + }) + require.NoError(t, err) + + _, err = executeGenerateJWKS(context.Background(), generateJWKSInput{}) + require.Error(t, err) + + _, err = executeRunRequests(context.Background(), runRequestsInput{}) + require.Error(t, err) + + _, err = executeJwtinfo(context.Background(), jwtinfoInput{TokenFile: "/no/such/token.jwt"}) + require.Error(t, err) +} + +func TestLoadRequestsConfigYAML_invalid(t *testing.T) { + t.Parallel() + + _, _, err := loadRequestsConfigYAML("requests: [") + require.Error(t, err) + + _, _, err = loadRequestsConfigYAML("verbose: true\nrequests: not-a-list\n") + require.Error(t, err) +} + +func TestBuildRequestsMetaConfig_caBundle(t *testing.T) { + t.Parallel() + + pubFile := filepath.Join("..", "certinfo", "testdata", "rsa-pkcs8-crt.pem") + loaded := loadedRequestsConfig{Verbose: true} + + meta, err := buildRequestsMetaConfig(loaded, pubFile) + require.NoError(t, err) + require.NotNil(t, meta) +} + +func TestJwtinfoHandler_requestURL(t *testing.T) { + t.Parallel() + + tokenString := writeTestJWT(t) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"` + tokenString + `"}`)) + })) + t.Cleanup(ts.Close) + + out, err := executeJwtinfo(context.Background(), jwtinfoInput{ + RequestURL: ts.URL, + RequestValues: map[string]string{"grant_type": "client_credentials"}, + }) + require.NoError(t, err) + require.Contains(t, out.Output, "JwtInfo") +} + +func TestRequestsConfigTemplateHandler_error(t *testing.T) { + t.Parallel() + + _, _, err := requestsConfigTemplateHandler( + context.Background(), + nil, + requestsConfigTemplateInput{}, + ) + require.Error(t, err) +} + +func TestMCPFileReader(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "data.txt") + require.NoError(t, os.WriteFile(path, []byte("ok"), 0o600)) + + reader := mcpFileReader{} + data, err := reader.ReadFile(path) + require.NoError(t, err) + require.Equal(t, "ok", string(data)) + + _, err = reader.ReadPassword(0) + require.Error(t, err) +} + +func writeTestJWT(t *testing.T) string { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + Subject: "mcp-test", + }) + + tokenString, err := token.SignedString(priv) + require.NoError(t, err) + + return tokenString +} diff --git a/internal/mcp/embed.go b/internal/mcp/embed.go new file mode 100644 index 0000000..0fb15a8 --- /dev/null +++ b/internal/mcp/embed.go @@ -0,0 +1,63 @@ +package mcp + +import "embed" + +//go:generate cp ../../https-wrench.schema.json assets/schema.json +//go:generate cp ../cmd/embedded/config-example.yaml assets/sample-config.yaml +//go:generate cp ../../assets/examples/https-wrench-k3s.yaml assets/examples/ +//go:generate cp ../../assets/examples/https-wrench-response-certificates-filter.yaml assets/examples/ +//go:generate cp ../../assets/examples/https-wrench-proxyProtocolV2.yaml assets/examples/ + +//go:embed assets/schema.json +//go:embed assets/sample-config.yaml +//go:embed assets/examples/* +var assets embed.FS + +const ( + uriSchema = "https-wrench://schema" + uriSampleConfig = "https-wrench://sample-config" + uriDocsRequests = "https-wrench://docs/requests" + uriExampleTmpl = "https-wrench://examples/{name}" +) + +var exampleFiles = map[string]string{ + "k3s": "assets/examples/https-wrench-k3s.yaml", + "response-certificates-filter": "assets/examples/https-wrench-response-certificates-filter.yaml", + "proxy-protocol-v2": "assets/examples/https-wrench-proxyProtocolV2.yaml", +} + +const schemaCommentHeader = "# yaml-language-server: $schema=" + + "https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json" + +const requestsDocsMarkdown = `# https-wrench requests + +Run probes with: ` + "`https-wrench requests --config `" + ` + +## Top level + +| Field | Notes | +|-------|-------| +| ` + "`verbose`" + ` | Required boolean. | +| ` + "`debug`" + ` | Optional. | +| ` + "`caBundle`" + ` | Optional PEM CA bundle path or inline PEM. | +| ` + "`requests`" + ` | Required array of request objects. | + +## Each request + +**Required:** ` + "`name`" + `, ` + "`hosts`" + ` (non-empty). + +- ` + "`hosts[].name`" + `: hostname for URL and SNI. +- ` + "`hosts[].uriList`" + `: paths must start with ` + "`/`" + `. +- ` + "`transportOverrideUrl`" + `: optional dial URL (` + "`https://...`" + `); ` + + `logical host stays in ` + "`hosts[].name`" + `. +- ` + "`insecure`" + `: skip TLS verification when dial address does not match cert. +- ` + "`requestMethod`" + `: GET, HEAD, POST, etc. +- ` + "`printResponseBody`" + `, ` + "`printResponseHeaders`" + `, ` + + "`printResponseCertificates`" + `: response inspection toggles. + +## MCP resources + +- ` + uriSchema + ` +- ` + uriSampleConfig + ` +- ` + uriExampleTmpl + ` (names: k3s, response-certificates-filter, proxy-protocol-v2) +` diff --git a/internal/mcp/prompts.go b/internal/mcp/prompts.go new file mode 100644 index 0000000..f88d6cc --- /dev/null +++ b/internal/mcp/prompts.go @@ -0,0 +1,87 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func registerPrompts(server *sdkmcp.Server) { + server.AddPrompt(&sdkmcp.Prompt{ + Name: "author_requests_config", + Description: "Guide for authoring a https-wrench requests YAML configuration", + Arguments: []*sdkmcp.PromptArgument{ + {Name: "hostname", Description: "Application hostname (hosts[].name)", Required: true}, + {Name: "paths", Description: "Comma-separated URI paths starting with / (default /)"}, + {Name: "transport_override_url", Description: "Optional https:// dial URL for transportOverrideUrl"}, + {Name: "insecure", Description: "Set to true to skip TLS verification for this request"}, + {Name: "method", Description: "HTTP method (default HEAD)"}, + }, + }, authorRequestsConfigPrompt) +} + +func authorRequestsConfigPrompt(_ context.Context, req *sdkmcp.GetPromptRequest) (*sdkmcp.GetPromptResult, error) { + args := req.Params.Arguments + input := requestsConfigTemplateInput{ + Hostname: args["hostname"], + Paths: args["paths"], + TransportOverrideURL: args["transport_override_url"], + Method: args["method"], + } + + if strings.EqualFold(args["insecure"], "true") { + input.Insecure = true + } + + yaml, errs := buildRequestsConfigYAML(input) + if len(errs) > 0 { + return nil, fmt.Errorf("%s", strings.Join(errs, "; ")) + } + + exampleHints := exampleResourceHints(input) + text := strings.Join([]string{ + "Author a https-wrench requests configuration using the resources below.", + "", + "Reference resources:", + "- " + uriSchema, + "- " + uriSampleConfig, + "- " + uriDocsRequests, + exampleHints, + "", + "Starting skeleton YAML:", + "```yaml", + strings.TrimRight(yaml, "\n"), + "```", + "", + "After editing, validate with the validate_requests_config tool.", + }, "\n") + + return &sdkmcp.GetPromptResult{ + Description: "Author a requests YAML config for https-wrench", + Messages: []*sdkmcp.PromptMessage{{ + Role: "user", + Content: &sdkmcp.TextContent{Text: text}, + }}, + }, nil +} + +func exampleResourceHints(input requestsConfigTemplateInput) string { + var hints []string + + if strings.TrimSpace(input.TransportOverrideURL) != "" { + hints = append(hints, "- https-wrench://examples/k3s") + hints = append(hints, "- https-wrench://examples/proxy-protocol-v2") + } + + if input.Insecure { + hints = append(hints, "- https-wrench://examples/k3s") + } + + if len(hints) == 0 { + hints = append(hints, "- https-wrench://examples/k3s") + } + + return strings.Join(hints, "\n") +} diff --git a/internal/mcp/resources.go b/internal/mcp/resources.go new file mode 100644 index 0000000..eba5cbf --- /dev/null +++ b/internal/mcp/resources.go @@ -0,0 +1,81 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func registerResources(server *sdkmcp.Server) { + server.AddResource(&sdkmcp.Resource{ + URI: uriSchema, + Name: "https-wrench JSON Schema", + Description: "JSON Schema for requests configuration files", + MIMEType: "application/json", + }, readStaticResource("assets/schema.json", uriSchema)) + + server.AddResource(&sdkmcp.Resource{ + URI: uriSampleConfig, + Name: "Sample requests config", + Description: "Starter YAML configuration for https-wrench requests", + MIMEType: "text/yaml", + }, readStaticResource("assets/sample-config.yaml", uriSampleConfig)) + + server.AddResource(&sdkmcp.Resource{ + URI: uriDocsRequests, + Name: "requests command reference", + Description: "Markdown cheat sheet for authoring requests YAML", + MIMEType: "text/markdown", + }, func(_ context.Context, _ *sdkmcp.ReadResourceRequest) (*sdkmcp.ReadResourceResult, error) { + return textResource(uriDocsRequests, requestsDocsMarkdown), nil + }) + + server.AddResourceTemplate(&sdkmcp.ResourceTemplate{ + URITemplate: uriExampleTmpl, + Name: "Example requests config", + Description: "Example YAML configs from assets/examples " + + "(names: k3s, response-certificates-filter, proxy-protocol-v2)", + MIMEType: "text/yaml", + }, readExampleResource) +} + +func readStaticResource(assetPath, uri string) sdkmcp.ResourceHandler { + return func(_ context.Context, _ *sdkmcp.ReadResourceRequest) (*sdkmcp.ReadResourceResult, error) { + data, err := assets.ReadFile(assetPath) + if err != nil { + return nil, fmt.Errorf("read embedded asset %q: %w", assetPath, err) + } + + return textResource(uri, string(data)), nil + } +} + +func readExampleResource(_ context.Context, req *sdkmcp.ReadResourceRequest) (*sdkmcp.ReadResourceResult, error) { + name := strings.TrimPrefix(req.Params.URI, "https-wrench://examples/") + if name == req.Params.URI || name == "" { + return nil, sdkmcp.ResourceNotFoundError(req.Params.URI) + } + + assetPath, ok := exampleFiles[name] + if !ok { + return nil, sdkmcp.ResourceNotFoundError(req.Params.URI) + } + + data, err := assets.ReadFile(assetPath) + if err != nil { + return nil, fmt.Errorf("read example %q: %w", name, err) + } + + return textResource(req.Params.URI, string(data)), nil +} + +func textResource(uri, text string) *sdkmcp.ReadResourceResult { + return &sdkmcp.ReadResourceResult{ + Contents: []*sdkmcp.ResourceContents{{ + URI: uri, + Text: text, + }}, + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go new file mode 100644 index 0000000..5e19488 --- /dev/null +++ b/internal/mcp/server.go @@ -0,0 +1,64 @@ +package mcp + +import ( + "context" + "fmt" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const serverInstructions = "MCP server for https-wrench. " + + "Use resources and prompts to author requests YAML, validate configs, and build CLI commands. " + + "Execution tools run requests probes, certinfo, jwtinfo, and JWKS generation directly." + +// Run starts the https-wrench MCP server on stdin/stdout until the client disconnects. +func Run(ctx context.Context, version string) error { + server := NewServer(version) + + return server.Run(ctx, &sdkmcp.StdioTransport{}) +} + +// NewServer builds an MCP server with assist and execution resources, prompts, and tools. +func NewServer(version string) *sdkmcp.Server { + if version == "" { + version = "development" + } + + server := sdkmcp.NewServer(&sdkmcp.Implementation{ + Name: "https-wrench", + Version: version, + }, &sdkmcp.ServerOptions{ + Instructions: serverInstructions, + }) + + registerResources(server) + registerPrompts(server) + registerTools(server) + + return server +} + +// RunInMemory connects the server to an in-memory transport pair for tests. +func RunInMemory(ctx context.Context, version string) (*sdkmcp.ClientSession, func(), error) { + server := NewServer(version) + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "https-wrench-test-client", + Version: "test", + }, nil) + + t1, t2 := sdkmcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + return nil, nil, fmt.Errorf("connect server: %w", err) + } + + session, err := client.Connect(ctx, t2, nil) + if err != nil { + return nil, nil, fmt.Errorf("connect client: %w", err) + } + + cleanup := func() { + _ = session.Close() + } + + return session, cleanup, nil +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go new file mode 100644 index 0000000..e164bcc --- /dev/null +++ b/internal/mcp/server_test.go @@ -0,0 +1,334 @@ +package mcp_test + +import ( + "context" + "encoding/json" + "testing" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" + mcpserver "github.com/xenos76/https-wrench/internal/mcp" +) + +func TestMCPServer_listsFeatures(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + var toolNames []string + + for tool, err := range session.Tools(ctx, nil) { + require.NoError(t, err) + + toolNames = append(toolNames, tool.Name) + } + + require.ElementsMatch(t, []string{ + "validate_requests_config", + "requests_config_template", + "build_cli_command", + "run_requests", + "certinfo", + "jwtinfo", + "generate_jwks", + }, toolNames) + + var resourceURIs []string + + for res, err := range session.Resources(ctx, nil) { + require.NoError(t, err) + + resourceURIs = append(resourceURIs, res.URI) + } + + require.Contains(t, resourceURIs, "https-wrench://schema") + require.Contains(t, resourceURIs, "https-wrench://sample-config") + require.Contains(t, resourceURIs, "https-wrench://docs/requests") + + var promptNames []string + + for prompt, err := range session.Prompts(ctx, nil) { + require.NoError(t, err) + + promptNames = append(promptNames, prompt.Name) + } + + require.Contains(t, promptNames, "author_requests_config") +} + +func TestResources_readSchema(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.ReadResource(ctx, &sdkmcp.ReadResourceParams{ + URI: "https-wrench://schema", + }) + require.NoError(t, err) + require.NotEmpty(t, res.Contents) + require.Contains(t, res.Contents[0].Text, "HttpsWrenchConfiguration") +} + +func TestResources_readExample(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.ReadResource(ctx, &sdkmcp.ReadResourceParams{ + URI: "https-wrench://examples/k3s", + }) + require.NoError(t, err) + require.NotEmpty(t, res.Contents) + require.Contains(t, res.Contents[0].Text, "requests:") +} + +func TestValidateRequestsConfig_emptyRequests(t *testing.T) { + t.Parallel() + + out := callValidateTool(t, "verbose: true\nrequests: []\n") + require.False(t, out["valid"].(bool)) + require.NotEmpty(t, out["errors"]) +} + +func TestValidateRequestsConfig_emptyHostName(t *testing.T) { + t.Parallel() + + yaml := `verbose: true +requests: + - name: example + hosts: + - name: "" + uriList: + - / +` + + out := callValidateTool(t, yaml) + require.False(t, out["valid"].(bool)) + require.NotEmpty(t, out["errors"]) +} + +func TestRequestsConfigTemplate_missingHostname(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: "requests_config_template", + Arguments: map[string]any{}, + }) + require.NoError(t, err) + require.True(t, res.IsError) +} + +func TestBuildCLICommand_quotedFlag(t *testing.T) { + t.Parallel() + + out := callBuildCLITool(t, map[string]any{ + "command": "certinfo", + "flags": map[string]any{ + "tls-endpoint": "host with spaces:443", + }, + }) + require.Empty(t, out["errors"]) + require.Contains(t, out["command"], `"host with spaces:443"`) +} + +func TestValidateRequestsConfig_valid(t *testing.T) { + t.Parallel() + + yaml := `verbose: true +requests: + - name: example + requestMethod: HEAD + hosts: + - name: www.example.com + uriList: + - / +` + + out := callValidateTool(t, yaml) + require.True(t, out["valid"].(bool)) + require.Empty(t, out["errors"]) +} + +func TestValidateRequestsConfig_missingVerbose(t *testing.T) { + t.Parallel() + + yaml := `requests: + - name: example + hosts: + - name: www.example.com + uriList: + - / +` + + out := callValidateTool(t, yaml) + require.False(t, out["valid"].(bool)) + require.NotEmpty(t, out["errors"]) +} + +func TestValidateRequestsConfig_badURIList(t *testing.T) { + t.Parallel() + + yaml := `verbose: true +requests: + - name: example + hosts: + - name: www.example.com + uriList: + - no-leading-slash +` + + out := callValidateTool(t, yaml) + require.False(t, out["valid"].(bool)) + require.NotEmpty(t, out["errors"]) +} + +func TestBuildCLICommand_certinfo(t *testing.T) { + t.Parallel() + + out := callBuildCLITool(t, map[string]any{ + "command": "certinfo", + "flags": map[string]any{ + "tls-endpoint": "example.com:443", + "tls-info": "true", + }, + }) + + require.Empty(t, out["errors"]) + require.Equal(t, "https-wrench certinfo --tls-endpoint example.com:443 --tls-info true", out["command"]) +} + +func TestBuildCLICommand_jwksMissingRequired(t *testing.T) { + t.Parallel() + + out := callBuildCLITool(t, map[string]any{ + "command": "jwks", + "flags": map[string]any{}, + }) + + require.NotEmpty(t, out["errors"]) + require.Empty(t, out["command"]) +} + +func TestRequestsConfigTemplate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: "requests_config_template", + Arguments: map[string]any{ + "hostname": "www.example.com", + "paths": "/,/health", + "transportOverrideUrl": "https://edge.example.net", + "method": "GET", + }, + }) + require.NoError(t, err) + require.False(t, res.IsError) + + out := decodeStructuredOutput(t, res) + yaml, ok := out["configYaml"].(string) + require.True(t, ok) + require.Contains(t, yaml, "transportOverrideUrl: https://edge.example.net") + require.Contains(t, yaml, "www.example.com") + require.Contains(t, yaml, "/health") +} + +func TestAuthorRequestsConfigPrompt(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.GetPrompt(ctx, &sdkmcp.GetPromptParams{ + Name: "author_requests_config", + Arguments: map[string]string{ + "hostname": "app.example.com", + "paths": "/", + }, + }) + require.NoError(t, err) + require.NotEmpty(t, res.Messages) + text := res.Messages[0].Content.(*sdkmcp.TextContent).Text + require.Contains(t, text, "app.example.com") + require.Contains(t, text, "validate_requests_config") +} + +func callValidateTool(t *testing.T, yaml string) map[string]any { + t.Helper() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: "validate_requests_config", + Arguments: map[string]any{ + "configYaml": yaml, + }, + }) + require.NoError(t, err) + require.False(t, res.IsError) + + return decodeStructuredOutput(t, res) +} + +func callBuildCLITool(t *testing.T, args map[string]any) map[string]any { + t.Helper() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: "build_cli_command", + Arguments: args, + }) + require.NoError(t, err) + require.False(t, res.IsError) + + return decodeStructuredOutput(t, res) +} + +func decodeStructuredOutput(t *testing.T, res *sdkmcp.CallToolResult) map[string]any { + t.Helper() + + require.NotNil(t, res.StructuredContent) + + data, err := json.Marshal(res.StructuredContent) + require.NoError(t, err) + + var out map[string]any + require.NoError(t, json.Unmarshal(data, &out)) + + return out +} diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go new file mode 100644 index 0000000..fccc073 --- /dev/null +++ b/internal/mcp/tools.go @@ -0,0 +1,401 @@ +package mcp + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/spf13/viper" + "github.com/xenos76/https-wrench/internal/requests" +) + +type validateRequestsConfigInput struct { + ConfigYAML string `json:"configYaml" jsonschema:"YAML content of a https-wrench requests configuration file"` +} + +type validateRequestsConfigOutput struct { + Valid bool `json:"valid"` + Errors []string `json:"errors,omitempty"` +} + +type requestsConfigTemplateInput struct { + Hostname string `json:"hostname" jsonschema:"Application hostname (hosts[].name)"` + Paths string `json:"paths,omitempty" jsonschema:"Comma-separated URI paths starting with /"` + TransportOverrideURL string `json:"transportOverrideUrl,omitempty" jsonschema:"Optional https:// dial URL"` + Insecure bool `json:"insecure,omitempty" jsonschema:"Set insecure true on the request"` + Method string `json:"method,omitempty" jsonschema:"HTTP method (default HEAD)"` + RequestName string `json:"requestName,omitempty" jsonschema:"Display name for the request entry"` +} + +type requestsConfigTemplateOutput struct { + ConfigYAML string `json:"configYaml"` +} + +type buildCLICommandInput struct { + Command string `json:"command" jsonschema:"Subcommand: certinfo, jwtinfo, jwks, or requests"` + Flags map[string]string `json:"flags" jsonschema:"Flag names (without leading dashes) to values"` +} + +type buildCLICommandOutput struct { + Command string `json:"command"` + Errors []string `json:"errors,omitempty"` +} + +type cliCommandDef struct { + requiredFlags []string + oneOfGroups [][]string + allowedFlags map[string]struct{} +} + +type parsedRequestsConfig struct { + Verbose bool + Requests []requests.RequestConfig +} + +func registerTools(server *sdkmcp.Server) { + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "validate_requests_config", + Description: "Parse and structurally validate a https-wrench requests YAML configuration", + }, validateRequestsConfigHandler) + + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "requests_config_template", + Description: "Generate a starter requests YAML configuration from high-level parameters", + }, requestsConfigTemplateHandler) + + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "build_cli_command", + Description: "Build a shell-ready https-wrench CLI command for certinfo, jwtinfo, jwks, or requests", + }, buildCLICommandHandler) + + registerExecTools(server) +} + +func validateRequestsConfigHandler( + _ context.Context, + _ *sdkmcp.CallToolRequest, + input validateRequestsConfigInput, +) (*sdkmcp.CallToolResult, validateRequestsConfigOutput, error) { + valid, errs := validateRequestsConfig(input.ConfigYAML) + + return nil, validateRequestsConfigOutput{Valid: valid, Errors: errs}, nil +} + +func requestsConfigTemplateHandler( + _ context.Context, + _ *sdkmcp.CallToolRequest, + input requestsConfigTemplateInput, +) (*sdkmcp.CallToolResult, requestsConfigTemplateOutput, error) { + yaml, errs := buildRequestsConfigYAML(input) + if len(errs) > 0 { + return nil, requestsConfigTemplateOutput{}, fmt.Errorf("%s", strings.Join(errs, "; ")) + } + + return nil, requestsConfigTemplateOutput{ConfigYAML: yaml}, nil +} + +func buildCLICommandHandler( + _ context.Context, + _ *sdkmcp.CallToolRequest, + input buildCLICommandInput, +) (*sdkmcp.CallToolResult, buildCLICommandOutput, error) { + cmd, errs := buildCLICommand(input.Command, input.Flags) + + out := buildCLICommandOutput{Command: cmd, Errors: errs} + if len(errs) > 0 { + return nil, out, nil + } + + return nil, out, nil +} + +func validateRequestsConfig(yamlContent string) (bool, []string) { + cfg, verboseSet, err := parseRequestsConfigYAML(yamlContent) + if err != nil { + return false, []string{err.Error()} + } + + var errs []string + + if !verboseSet { + errs = append(errs, "verbose is required") + } + + if len(cfg.Requests) == 0 { + errs = append(errs, "requests must contain at least one entry") + } + + for i, req := range cfg.Requests { + errs = append(errs, validateRequestConfig(i, req)...) + } + + return len(errs) == 0, errs +} + +func parseRequestsConfigYAML(yamlContent string) (parsedRequestsConfig, bool, error) { + v := viper.New() + v.SetConfigType("yaml") + + if err := v.ReadConfig(strings.NewReader(yamlContent)); err != nil { + return parsedRequestsConfig{}, false, fmt.Errorf("yaml parse: %w", err) + } + + cfg := struct { + Verbose bool `mapstructure:"verbose"` + requests.RequestsMetaConfig `mapstructure:",squash"` + }{} + + if err := v.Unmarshal(&cfg); err != nil { + return parsedRequestsConfig{}, false, fmt.Errorf("config unmarshal: %w", err) + } + + return parsedRequestsConfig{ + Verbose: cfg.Verbose, + Requests: cfg.Requests, + }, v.IsSet("verbose"), nil +} + +func validateRequestConfig(index int, req requests.RequestConfig) []string { + prefix := fmt.Sprintf("requests[%d]", index) + + var errs []string + + if req.Name == "" { + errs = append(errs, prefix+": name is required") + } + + if len(req.Hosts) == 0 { + errs = append(errs, prefix+": hosts must be non-empty") + } + + if req.TransportOverrideURL != "" && !strings.HasPrefix(req.TransportOverrideURL, "https://") { + errs = append(errs, prefix+": transportOverrideUrl must start with https://") + } + + if req.EnableProxyProtocolV2 && req.TransportOverrideURL == "" { + errs = append(errs, prefix+": enableProxyProtocolV2 requires transportOverrideUrl") + } + + for hi, host := range req.Hosts { + errs = append(errs, validateRequestHost(prefix, hi, host)...) + } + + return errs +} + +func validateRequestHost(prefix string, index int, host requests.Host) []string { + hostPrefix := fmt.Sprintf("%s.hosts[%d]", prefix, index) + + var errs []string + + if host.Name == "" { + errs = append(errs, hostPrefix+": name is required") + } + + for ui, uri := range host.URIList { + if !uri.Parse() { + errs = append(errs, fmt.Sprintf("%s.uriList[%d]: path %q must start with /", hostPrefix, ui, uri)) + } + } + + return errs +} + +func buildRequestsConfigYAML(input requestsConfigTemplateInput) (string, []string) { + var errs []string + + hostname := strings.TrimSpace(input.Hostname) + if hostname == "" { + errs = append(errs, "hostname is required") + } + + method := strings.ToUpper(strings.TrimSpace(input.Method)) + if method == "" { + method = "HEAD" + } + + name := strings.TrimSpace(input.RequestName) + if name == "" { + name = "example-" + strings.ReplaceAll(hostname, ".", "-") + } + + paths := parsePaths(input.Paths) + + for _, p := range paths { + if !strings.HasPrefix(p, "/") { + errs = append(errs, fmt.Sprintf("path %q must start with /", p)) + } + } + + if len(errs) > 0 { + return "", errs + } + + var b strings.Builder + fmt.Fprintln(&b, schemaCommentHeader) + fmt.Fprintln(&b, "---") + fmt.Fprintln(&b, "verbose: true") + fmt.Fprintln(&b, "requests:") + fmt.Fprintf(&b, " - name: %s\n", name) + fmt.Fprintf(&b, " requestMethod: %s\n", method) + + if transport := strings.TrimSpace(input.TransportOverrideURL); transport != "" { + fmt.Fprintf(&b, " transportOverrideUrl: %s\n", transport) + } + + if input.Insecure { + fmt.Fprintln(&b, " insecure: true") + } + + fmt.Fprintln(&b, " hosts:") + fmt.Fprintf(&b, " - name: %s\n", hostname) + fmt.Fprintln(&b, " uriList:") + + for _, p := range paths { + fmt.Fprintf(&b, " - %s\n", p) + } + + return strings.TrimRight(b.String(), "\n") + "\n", nil +} + +func parsePaths(paths string) []string { + paths = strings.TrimSpace(paths) + if paths == "" { + return []string{"/"} + } + + parts := strings.Split(paths, ",") + + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + + if len(out) == 0 { + return []string{"/"} + } + + return out +} + +var allowedCLICommands = map[string]cliCommandDef{ + "certinfo": { + oneOfGroups: [][]string{{"tls-endpoint", "cert-bundle", "key-file", "ca-bundle"}}, + allowedFlags: map[string]struct{}{ + "ca-bundle": {}, "cert-bundle": {}, "key-file": {}, + "tls-endpoint": {}, "tls-servername": {}, "tls-insecure": {}, "tls-info": {}, + }, + }, + "jwtinfo": { + oneOfGroups: [][]string{{"token-file", "request-url"}}, + allowedFlags: map[string]struct{}{ + "token-file": {}, "request-url": {}, "request-values": {}, + "request-values-json": {}, "request-values-file": {}, + "validation-url": {}, "refresh": {}, "token-output-file": {}, "renew-threshold": {}, + }, + }, + "jwks": { + requiredFlags: []string{"public-key-file"}, + allowedFlags: map[string]struct{}{ + "public-key-file": {}, "kid": {}, + }, + }, + "requests": { + oneOfGroups: [][]string{{"config", "show-sample-config"}}, + allowedFlags: map[string]struct{}{ + "config": {}, "show-sample-config": {}, "ca-bundle": {}, + }, + }, +} + +func buildCLICommand(command string, flags map[string]string) (string, []string) { + command = strings.ToLower(strings.TrimSpace(command)) + + def, ok := allowedCLICommands[command] + if !ok { + return "", []string{ + fmt.Sprintf("unsupported command %q (use certinfo, jwtinfo, jwks, or requests)", command), + } + } + + errs := validateCLIFlags(command, def, flags) + if len(errs) > 0 { + return "", errs + } + + names := sortedAllowedFlagNames(def, flags) + + var parts []string + + parts = append(parts, "https-wrench", command) + for _, name := range names { + parts = append(parts, "--"+name, shellQuote(flags[name])) + } + + return strings.Join(parts, " "), nil +} + +func validateCLIFlags(command string, def cliCommandDef, flags map[string]string) []string { + var errs []string + + for _, req := range def.requiredFlags { + if _, set := flags[req]; !set { + errs = append(errs, fmt.Sprintf("missing required flag %q", req)) + } + } + + for _, group := range def.oneOfGroups { + if !oneOfFlagsSet(group, flags) { + errs = append(errs, fmt.Sprintf("one of flags %v is required", group)) + } + } + + for name := range flags { + if _, allowed := def.allowedFlags[name]; !allowed { + errs = append(errs, fmt.Sprintf("unknown flag %q for command %q", name, command)) + } + } + + return errs +} + +func oneOfFlagsSet(group []string, flags map[string]string) bool { + for _, name := range group { + if _, set := flags[name]; set { + return true + } + } + + return false +} + +func sortedAllowedFlagNames(def cliCommandDef, flags map[string]string) []string { + names := make([]string, 0, len(flags)) + for name := range flags { + if _, allowed := def.allowedFlags[name]; allowed { + names = append(names, name) + } + } + + slices.Sort(names) + + return names +} + +func shellQuote(value string) string { + if value == "" { + return "''" + } + + if !strings.ContainsAny(value, " \t\n\"'\\$`!#&|;<>()*?[]{}~") { + return value + } + + return strconv.Quote(value) +} diff --git a/internal/mcp/tools_exec.go b/internal/mcp/tools_exec.go new file mode 100644 index 0000000..ef562b0 --- /dev/null +++ b/internal/mcp/tools_exec.go @@ -0,0 +1,444 @@ +package mcp + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/MicahParks/keyfunc/v3" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/spf13/viper" + "github.com/xenos76/https-wrench/internal/certinfo" + "github.com/xenos76/https-wrench/internal/jwks" + "github.com/xenos76/https-wrench/internal/jwtinfo" + "github.com/xenos76/https-wrench/internal/requests" +) + +const ( + defaultExecToolTimeout = 60 * time.Second + certinfoKeyPasswordEnv = "CERTINFO_PKEY_PW" +) + +type execToolOutput struct { + Output string `json:"output"` + Error string `json:"error,omitempty"` +} + +type runRequestsInput struct { + ConfigYAML string `json:"configYaml,omitempty" jsonschema:"Inline requests YAML configuration"` + ConfigPath string `json:"configPath,omitempty" jsonschema:"Path to requests YAML on the MCP server host"` + CaBundlePath string `json:"caBundlePath,omitempty" jsonschema:"Optional CA bundle PEM file path"` + TimeoutSec int `json:"timeoutSec,omitempty" jsonschema:"Overall operation timeout in seconds (default 60)"` +} + +type certinfoInput struct { + CaBundle string `json:"caBundle,omitempty"` + CertBundle string `json:"certBundle,omitempty"` + KeyFile string `json:"keyFile,omitempty"` + TLSEndpoint string `json:"tlsEndpoint,omitempty"` + TLSServername string `json:"tlsServername,omitempty"` + TLSInsecure bool `json:"tlsInsecure,omitempty"` + TLSInfo bool `json:"tlsInfo,omitempty"` + TimeoutSec int `json:"timeoutSec,omitempty"` +} + +type jwtinfoInput struct { + TokenFile string `json:"tokenFile,omitempty"` + RequestURL string `json:"requestUrl,omitempty"` + RequestValues map[string]string `json:"requestValues,omitempty"` + ValidationURL string `json:"validationUrl,omitempty"` + TimeoutSec int `json:"timeoutSec,omitempty"` +} + +type generateJWKSInput struct { + PublicKeyFile string `json:"publicKeyFile" jsonschema:"Path to PEM-encoded public key file"` + Kid string `json:"kid,omitempty" jsonschema:"Optional key ID"` +} + +type loadedRequestsConfig struct { + Debug bool + Verbose bool + CaBundle string + Requests []requests.RequestConfig +} + +type mcpFileReader struct{} + +func (mcpFileReader) ReadFile(name string) ([]byte, error) { + return os.ReadFile(name) +} + +func (mcpFileReader) ReadPassword(_ int) ([]byte, error) { + return nil, errors.New("encrypted private keys require CERTINFO_PKEY_PW under MCP") +} + +func registerExecTools(server *sdkmcp.Server) { + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "run_requests", + Description: "Execute https-wrench requests from inline YAML or a config file path", + }, runRequestsHandler) + + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "certinfo", + Description: "Inspect x.509 certificates and keys from local files or a TLS endpoint", + }, certinfoHandler) + + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "jwtinfo", + Description: "Inspect JWT tokens from a file or token endpoint (no refresh loop)", + }, jwtinfoHandler) + + sdkmcp.AddTool(server, &sdkmcp.Tool{ + Name: "generate_jwks", + Description: "Generate a JSON Web Key Set from a PEM public key file", + }, generateJWKSHandler) +} + +func runRequestsHandler( + ctx context.Context, + _ *sdkmcp.CallToolRequest, + input runRequestsInput, +) (*sdkmcp.CallToolResult, execToolOutput, error) { + ctx, cancel := toolContext(ctx, input.TimeoutSec) + defer cancel() + + out, err := executeRunRequests(ctx, input) + if err != nil { + return nil, execToolOutput{Error: err.Error()}, nil + } + + return nil, out, nil +} + +func certinfoHandler( + ctx context.Context, + _ *sdkmcp.CallToolRequest, + input certinfoInput, +) (*sdkmcp.CallToolResult, execToolOutput, error) { + ctx, cancel := toolContext(ctx, input.TimeoutSec) + defer cancel() + + out, err := executeCertinfo(ctx, input) + if err != nil { + return nil, execToolOutput{Error: err.Error()}, nil + } + + return nil, out, nil +} + +func jwtinfoHandler( + ctx context.Context, + _ *sdkmcp.CallToolRequest, + input jwtinfoInput, +) (*sdkmcp.CallToolResult, execToolOutput, error) { + ctx, cancel := toolContext(ctx, input.TimeoutSec) + defer cancel() + + out, err := executeJwtinfo(ctx, input) + if err != nil { + return nil, execToolOutput{Error: err.Error()}, nil + } + + return nil, out, nil +} + +func generateJWKSHandler( + ctx context.Context, + _ *sdkmcp.CallToolRequest, + input generateJWKSInput, +) (*sdkmcp.CallToolResult, execToolOutput, error) { + ctx, cancel := toolContext(ctx, 0) + defer cancel() + + out, err := executeGenerateJWKS(ctx, input) + if err != nil { + return nil, execToolOutput{Error: err.Error()}, nil + } + + return nil, out, nil +} + +func executeRunRequests(_ context.Context, input runRequestsInput) (execToolOutput, error) { + yamlContent, err := loadConfigYAML(input.ConfigYAML, input.ConfigPath) + if err != nil { + return execToolOutput{}, err + } + + valid, errs := validateRequestsConfig(yamlContent) + if !valid { + return execToolOutput{}, fmt.Errorf("invalid config: %s", strings.Join(errs, "; ")) + } + + loaded, _, err := loadRequestsConfigYAML(yamlContent) + if err != nil { + return execToolOutput{}, err + } + + meta, err := buildRequestsMetaConfig(loaded, input.CaBundlePath) + if err != nil { + return execToolOutput{}, err + } + + output, err := captureWithStdout(func(w io.Writer) error { + _, handleErr := requests.HandleRequests(w, meta) + + return handleErr + }) + if err != nil { + return execToolOutput{}, err + } + + return execToolOutput{Output: output}, nil +} + +func executeCertinfo(_ context.Context, input certinfoInput) (execToolOutput, error) { + if !certinfoInputProvided(input) { + return execToolOutput{}, errors.New( + "one of tlsEndpoint, certBundle, keyFile, or caBundle is required", + ) + } + + if input.TLSInfo && input.TLSEndpoint == "" { + return execToolOutput{}, errors.New("tlsInfo requires tlsEndpoint") + } + + cfg, err := certinfo.New() + if err != nil { + return execToolOutput{}, err + } + + reader := mcpFileReader{} + + if err = cfg.SetCaPoolFromFile(input.CaBundle, reader); err != nil { + return execToolOutput{}, err + } + + if err = cfg.SetCertsFromFile(input.CertBundle, reader); err != nil { + return execToolOutput{}, err + } + + cfg.SetTLSInsecure(input.TLSInsecure). + SetTLSServerName(input.TLSServername). + SetTLSInfoRequested(input.TLSInfo) + + if err = cfg.SetTLSEndpoint(input.TLSEndpoint); err != nil { + return execToolOutput{}, err + } + + if err = cfg.SetPrivateKeyFromFile(input.KeyFile, certinfoKeyPasswordEnv, reader); err != nil { + return execToolOutput{}, err + } + + output, err := captureOutput(cfg.PrintData) + if err != nil { + return execToolOutput{}, err + } + + return execToolOutput{Output: output}, nil +} + +func executeJwtinfo(ctx context.Context, input jwtinfoInput) (execToolOutput, error) { + tokenData, err := loadJwtTokenData(ctx, input) + if err != nil { + return execToolOutput{}, err + } + + if tokenData == nil || tokenData.AccessTokenRaw == "" { + return execToolOutput{}, errors.New("no JWT token data available") + } + + if err = tokenData.DecodeBase64(); err != nil { + return execToolOutput{}, err + } + + if input.ValidationURL != "" { + if err = tokenData.ParseWithJWKS(ctx, input.ValidationURL, keyfunc.Override{}); err != nil { + return execToolOutput{}, err + } + } + + output, err := captureOutput(func(w io.Writer) error { + return jwtinfo.PrintTokenInfo(tokenData, w) + }) + if err != nil { + return execToolOutput{}, err + } + + return execToolOutput{Output: output}, nil +} + +func executeGenerateJWKS(ctx context.Context, input generateJWKSInput) (execToolOutput, error) { + if strings.TrimSpace(input.PublicKeyFile) == "" { + return execToolOutput{}, errors.New("publicKeyFile is required") + } + + jwksJSON, err := jwks.GenerateJWKS(ctx, input.PublicKeyFile, input.Kid) + if err != nil { + return execToolOutput{}, err + } + + return execToolOutput{Output: jwksJSON}, nil +} + +func loadConfigYAML(configYAML, configPath string) (string, error) { + hasYAML := strings.TrimSpace(configYAML) != "" + hasPath := strings.TrimSpace(configPath) != "" + + switch { + case hasYAML && hasPath: + return "", errors.New("provide exactly one of configYaml or configPath") + case !hasYAML && !hasPath: + return "", errors.New("configYaml or configPath is required") + case hasPath: + data, err := os.ReadFile(configPath) + if err != nil { + return "", fmt.Errorf("read config file: %w", err) + } + + return string(data), nil + default: + return configYAML, nil + } +} + +func loadRequestsConfigYAML(yamlContent string) (loadedRequestsConfig, bool, error) { + v := viper.New() + v.SetConfigType("yaml") + + if err := v.ReadConfig(strings.NewReader(yamlContent)); err != nil { + return loadedRequestsConfig{}, false, fmt.Errorf("yaml parse: %w", err) + } + + cfg := struct { + Debug bool `mapstructure:"debug"` + Verbose bool `mapstructure:"verbose"` + CaBundle string `mapstructure:"caBundle"` + requests.RequestsMetaConfig `mapstructure:",squash"` + }{} + + if err := v.Unmarshal(&cfg); err != nil { + return loadedRequestsConfig{}, false, fmt.Errorf("config unmarshal: %w", err) + } + + return loadedRequestsConfig{ + Debug: cfg.Debug, + Verbose: cfg.Verbose, + CaBundle: cfg.CaBundle, + Requests: cfg.Requests, + }, v.IsSet("verbose"), nil +} + +func buildRequestsMetaConfig(loaded loadedRequestsConfig, caBundlePath string) (*requests.RequestsMetaConfig, error) { + meta, err := requests.NewRequestsMetaConfig() + if err != nil { + return nil, err + } + + meta.SetVerbose(loaded.Verbose). + SetDebug(loaded.Debug). + SetRequests(loaded.Requests) + + if err = meta.SetCaPoolFromYAML(loaded.CaBundle); err != nil { + return nil, err + } + + if err = meta.SetCaPoolFromFile(caBundlePath, mcpFileReader{}); err != nil { + return nil, err + } + + return meta, nil +} + +func loadJwtTokenData(ctx context.Context, input jwtinfoInput) (*jwtinfo.JwtTokenData, error) { + hasFile := strings.TrimSpace(input.TokenFile) != "" + hasURL := strings.TrimSpace(input.RequestURL) != "" + + switch { + case hasFile && hasURL: + return nil, errors.New("provide exactly one of tokenFile or requestUrl") + case !hasFile && !hasURL: + return nil, errors.New("tokenFile or requestUrl is required") + case hasFile: + return jwtinfo.ReadTokenFromFile(input.TokenFile) + default: + if len(input.RequestValues) == 0 { + return nil, errors.New("requestValues is required with requestUrl") + } + + client := &http.Client{Timeout: execToolTimeout(input.TimeoutSec)} + + return jwtinfo.RequestToken(ctx, input.RequestURL, input.RequestValues, client, io.ReadAll) + } +} + +func certinfoInputProvided(input certinfoInput) bool { + return input.CaBundle != "" || + input.CertBundle != "" || + input.KeyFile != "" || + input.TLSEndpoint != "" +} + +func toolContext(parent context.Context, timeoutSec int) (context.Context, context.CancelFunc) { + timeout := execToolTimeout(timeoutSec) + + return context.WithTimeout(parent, timeout) +} + +func execToolTimeout(timeoutSec int) time.Duration { + if timeoutSec <= 0 { + return defaultExecToolTimeout + } + + return time.Duration(timeoutSec) * time.Second +} + +func captureOutput(fn func(io.Writer) error) (string, error) { + var buf bytes.Buffer + + if err := fn(&buf); err != nil { + return buf.String(), err + } + + return buf.String(), nil +} + +func captureWithStdout(fn func(io.Writer) error) (string, error) { + pipeR, pipeW, err := os.Pipe() + if err != nil { + return "", err + } + + oldStdout := os.Stdout + os.Stdout = pipeW + + var writerBuf bytes.Buffer + + fnErr := fn(&writerBuf) + + if closeErr := pipeW.Close(); closeErr != nil && fnErr == nil { + fnErr = closeErr + } + + os.Stdout = oldStdout + + var stdoutBuf bytes.Buffer + + if _, copyErr := io.Copy(&stdoutBuf, pipeR); copyErr != nil && fnErr == nil { + fnErr = copyErr + } + + _ = pipeR.Close() + + var combined bytes.Buffer + + _, _ = combined.Write(stdoutBuf.Bytes()) + _, _ = combined.Write(writerBuf.Bytes()) + + return combined.String(), fnErr +} diff --git a/internal/mcp/tools_exec_test.go b/internal/mcp/tools_exec_test.go new file mode 100644 index 0000000..c7db381 --- /dev/null +++ b/internal/mcp/tools_exec_test.go @@ -0,0 +1,385 @@ +package mcp_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" + mcpserver "github.com/xenos76/https-wrench/internal/mcp" +) + +func TestGenerateJwksTool(t *testing.T) { + t.Parallel() + + pubFile := writeRSAPublicKeyPEM(t) + + out := callExecTool(t, "generate_jwks", map[string]any{ + "publicKeyFile": pubFile, + }) + require.Empty(t, out["error"]) + require.Contains(t, out["output"], `"keys"`) +} + +func TestCertinfoTool_localBundle(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "certinfo", map[string]any{ + "certBundle": filepath.Join("..", "certinfo", "testdata", "rsa-pkcs8-crt.pem"), + }) + require.Empty(t, out["error"]) + require.Contains(t, out["output"], "Certinfo") +} + +func TestCertinfoTool_encryptedKeyNoPassword(t *testing.T) { + t.Setenv("CERTINFO_PKEY_PW", "") + + out := callExecTool(t, "certinfo", map[string]any{ + "keyFile": filepath.Join("..", "certinfo", "testdata", "rsa-pkcs8-encrypted-private-key.pem"), + }) + require.NotEmpty(t, out["error"]) + require.Contains(t, out["error"], "CERTINFO_PKEY_PW") +} + +func TestJwtinfoTool_tokenFile(t *testing.T) { + t.Parallel() + + tokenFile := writeJWTTokenFile(t) + + out := callExecTool(t, "jwtinfo", map[string]any{ + "tokenFile": tokenFile, + }) + require.Empty(t, out["error"]) + require.Contains(t, out["output"], "JwtInfo") +} + +func TestRunRequestsTool_invalidConfig(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "run_requests", map[string]any{ + "configYaml": "requests:\n - name: x\n hosts:\n - name: example.com\n", + }) + require.NotEmpty(t, out["error"]) + require.Contains(t, out["error"], "verbose") +} + +func TestRunRequestsTool_configPath(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "requests.yaml") + yaml := `verbose: true +requests: + - name: from-file + requestMethod: HEAD + hosts: + - name: example.com + uriList: + - / +` + require.NoError(t, os.WriteFile(cfgPath, []byte(yaml), 0o600)) + + out := callExecTool(t, "run_requests", map[string]any{ + "configPath": cfgPath, + }) + require.Empty(t, out["error"]) + require.Contains(t, out["output"], "from-file") +} + +func TestRunRequestsTool_bothConfigSources(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "run_requests", map[string]any{ + "configYaml": "verbose: true", + "configPath": "/tmp/x.yaml", + }) + require.NotEmpty(t, out["error"]) + require.Contains(t, out["error"], "exactly one") +} + +func TestCertinfoTool_noInput(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "certinfo", map[string]any{}) + require.NotEmpty(t, out["error"]) +} + +func TestCertinfoTool_tlsInfoRequiresEndpoint(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "certinfo", map[string]any{ + "certBundle": filepath.Join("..", "certinfo", "testdata", "rsa-pkcs8-crt.pem"), + "tlsInfo": true, + }) + require.NotEmpty(t, out["error"]) + require.Contains(t, out["error"], "tlsInfo requires tlsEndpoint") +} + +func TestGenerateJwksTool_missingFile(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "generate_jwks", map[string]any{ + "publicKeyFile": "", + }) + require.NotEmpty(t, out["error"]) +} + +func TestJwtinfoTool_invalidToken(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "bad.jwt") + require.NoError(t, os.WriteFile(path, []byte("not-a-jwt"), 0o600)) + + out := callExecTool(t, "jwtinfo", map[string]any{ + "tokenFile": path, + }) + require.NotEmpty(t, out["error"]) +} + +func TestRunRequestsTool_invalidCaBundle(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "run_requests", map[string]any{ + "configYaml": `verbose: true +caBundle: "not-a-pem-bundle" +requests: + - name: bad-ca + requestMethod: HEAD + hosts: + - name: example.com + uriList: + - / +`, + }) + require.NotEmpty(t, out["error"]) +} + +func TestRunRequestsTool_invalidCaBundlePath(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "run_requests", map[string]any{ + "configYaml": `verbose: true +requests: + - name: bad-ca-path + requestMethod: HEAD + hosts: + - name: example.com + uriList: + - / +`, + "caBundlePath": filepath.Join(t.TempDir(), "missing-ca.pem"), + }) + require.NotEmpty(t, out["error"]) +} + +func TestJwtinfoTool_emptyAccessToken(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "empty.jwt") + require.NoError(t, os.WriteFile(path, []byte(`{"access_token":""}`), 0o600)) + + out := callExecTool(t, "jwtinfo", map[string]any{ + "tokenFile": path, + }) + require.NotEmpty(t, out["error"]) +} + +func TestCertinfoTool_invalidCertBundle(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "bad.pem") + require.NoError(t, os.WriteFile(path, []byte("not a cert"), 0o600)) + + out := callExecTool(t, "certinfo", map[string]any{ + "certBundle": path, + }) + require.NotEmpty(t, out["error"]) +} + +func TestCertinfoTool_invalidCaBundle(t *testing.T) { + t.Parallel() + + out := callExecTool(t, "certinfo", map[string]any{ + "caBundle": filepath.Join(t.TempDir(), "missing-ca.pem"), + }) + require.NotEmpty(t, out["error"]) +} + +func TestBuildCLICommand_jwtinfo(t *testing.T) { + t.Parallel() + + out := callBuildCLITool(t, map[string]any{ + "command": "jwtinfo", + "flags": map[string]any{ + "token-file": "token.jwt", + }, + }) + require.Empty(t, out["errors"]) + require.Contains(t, out["command"], "jwtinfo") +} + +func TestBuildCLICommand_requests(t *testing.T) { + t.Parallel() + + out := callBuildCLITool(t, map[string]any{ + "command": "requests", + "flags": map[string]any{ + "show-sample-config": "true", + }, + }) + require.Empty(t, out["errors"]) + require.Contains(t, out["command"], "requests") +} + +func TestResources_readDocs(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.ReadResource(ctx, &sdkmcp.ReadResourceParams{ + URI: "https-wrench://docs/requests", + }) + require.NoError(t, err) + require.NotEmpty(t, res.Contents) + require.Contains(t, res.Contents[0].Text, "requests") +} + +func TestResources_readSampleConfig(t *testing.T) { + t.Parallel() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.ReadResource(ctx, &sdkmcp.ReadResourceParams{ + URI: "https-wrench://sample-config", + }) + require.NoError(t, err) + require.Contains(t, res.Contents[0].Text, "requests:") +} + +func TestValidateRequestsConfig_invalidYAML(t *testing.T) { + t.Parallel() + + out := callValidateTool(t, "verbose: true\nrequests: [") + require.False(t, out["valid"].(bool)) +} + +func TestRunRequestsTool_inlineYaml(t *testing.T) { + t.Parallel() + + ts := httptest.NewTLSServer(httpHandlerOK()) + t.Cleanup(ts.Close) + + hostPort := ts.Listener.Addr().String() + yaml := `verbose: true +requests: + - name: mcp-test + requestMethod: GET + insecure: true + transportOverrideUrl: https://` + hostPort + ` + printResponseBody: true + hosts: + - name: example.com + uriList: + - / +` + + out := callExecTool(t, "run_requests", map[string]any{ + "configYaml": yaml, + }) + require.Empty(t, out["error"]) + require.Contains(t, out["output"], "mcp-test") +} + +func callExecTool(t *testing.T, name string, args map[string]any) map[string]any { + t.Helper() + + ctx := context.Background() + session, cleanup, err := mcpserver.RunInMemory(ctx, "test") + require.NoError(t, err) + + defer cleanup() + + res, err := session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: name, + Arguments: args, + }) + require.NoError(t, err) + require.False(t, res.IsError) + + return decodeStructuredOutput(t, res) +} + +func writeRSAPublicKeyPEM(t *testing.T) string { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + + path := filepath.Join(t.TempDir(), "rsa-public.pem") + file, err := os.Create(path) + require.NoError(t, err) + + err = pem.Encode(file, &pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}) + require.NoError(t, err) + require.NoError(t, file.Close()) + + return path +} + +func writeJWTTokenFile(t *testing.T) string { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + Subject: "mcp-test", + }) + + tokenString, err := token.SignedString(priv) + require.NoError(t, err) + + path := filepath.Join(t.TempDir(), "token.jwt") + require.NoError(t, os.WriteFile(path, []byte(tokenString), 0o600)) + + return path +} + +func httpHandlerOK() http.Handler { + return httpHandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) +} + +type httpHandlerFunc func(http.ResponseWriter, *http.Request) + +func (f httpHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { + f(w, r) +} diff --git a/internal/requests/requests_handlers_print_test.go b/internal/requests/requests_handlers_print_test.go new file mode 100644 index 0000000..74df96f --- /dev/null +++ b/internal/requests/requests_handlers_print_test.go @@ -0,0 +1,76 @@ +package requests + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "io" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrintResponseData(t *testing.T) { + t.Parallel() + + rd := ResponseData{ + URL: "https://example.com/", + Request: RequestConfig{ + PrintResponseHeaders: true, + PrintResponseBody: true, + ResponseBodyMatchRegexp: "ok", + }, + Response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Server": []string{"test"}}, + }, + ResponseBody: "ok", + ResponseBodyRegexpMatched: true, + } + + require.Empty(t, capturePrintResponseData(rd, false)) + + rd.Error = errors.New("dial failed") + out := capturePrintResponseData(rd, true) + require.Contains(t, out, "dial failed") + require.Contains(t, out, "Error:") + + rd.Error = nil + rd.Request.PrintResponseCertificates = true + rd.Response.TLS = &tls.ConnectionState{ + Version: tls.VersionTLS13, + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + PeerCertificates: []*x509.Certificate{ + {Subject: pkix.Name{CommonName: "example.com"}}, + }, + } + out = capturePrintResponseData(rd, true) + require.Contains(t, out, "TLS:") + require.Contains(t, out, "https://example.com/") + require.Contains(t, out, "200") + require.Contains(t, out, "Headers:") + require.Contains(t, out, "Body:") + require.Contains(t, out, "BodyRegexpMatch:") +} + +func capturePrintResponseData(rd ResponseData, verbose bool) string { + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + rd.PrintResponseData(verbose) + + _ = w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + + _, _ = io.Copy(&buf, r) + _ = r.Close() + + return buf.String() +} diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index 4777113..67c2ede 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -1309,6 +1309,17 @@ func TestImportResponseBody_Errors(t *testing.T) { require.False(t, rd.ResponseBodyRegexpMatched) require.Equal(t, "test body", rd.ResponseBody) }) + + t.Run("html content type highlighting", func(t *testing.T) { + rd := ResponseData{ + Response: &http.Response{ + Header: http.Header{"Content-Type": []string{"text/html; charset=utf-8"}}, + Body: io.NopCloser(bytes.NewBufferString("hello")), + }, + } + rd.ImportResponseBody() + require.Contains(t, rd.ResponseBody, "hello") + }) } type newHTTPClientFromRequestConfigTestCase struct { From 51316cd68cc1424ddd0b010488341398921dfff3 Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Mon, 25 May 2026 12:03:06 +0200 Subject: [PATCH 2/6] chore: tidy modes --- go.sum | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go.sum b/go.sum index 6917991..78a92f3 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From ed8ef3f2df9869b891fb780572feeacd646902e5 Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Mon, 25 May 2026 12:21:03 +0200 Subject: [PATCH 3/6] test: comment parallel execution in mcp test it is the probable cause of random test failure in Github CI --- internal/cmd/mcp_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/cmd/mcp_test.go b/internal/cmd/mcp_test.go index 6fcdebc..d24b914 100644 --- a/internal/cmd/mcp_test.go +++ b/internal/cmd/mcp_test.go @@ -8,8 +8,8 @@ import ( ) func TestIsMCPCommand(t *testing.T) { - t.Parallel() - + // t.Parallel() + // tests := []struct { name string args []string From f7d6bf08e6a7168d9aea109c560f943a2732ca7b Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Mon, 25 May 2026 12:49:02 +0200 Subject: [PATCH 4/6] fix: eliminate MCP test stdout races and encrypted-key prompt flake Route requests verbose output through the HandleRequests writer instead of replacing os.Stdout, and skip terminal passphrase prompts for readers that declare NoPasswordPrompt (MCP and test mocks). Co-authored-by: Cursor --- .testcoverage.yml | 2 +- internal/certinfo/certinfo.go | 6 ++ internal/certinfo/common_handlers.go | 9 +++ internal/certinfo/main_test.go | 2 + internal/mcp/coverage_test.go | 1 + internal/mcp/tools_exec.go | 39 +---------- internal/requests/requests.go | 25 +++---- internal/requests/requests_handlers.go | 66 +++++++++---------- .../requests/requests_handlers_print_test.go | 14 +--- internal/requests/requests_handlers_test.go | 2 + internal/requests/requests_test.go | 3 +- 11 files changed, 74 insertions(+), 95 deletions(-) diff --git a/.testcoverage.yml b/.testcoverage.yml index 3e1bfb4..c54a53b 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -3,7 +3,7 @@ profile: cover.out threshold: file: 70 package: 80 - total: 85 + total: 95 exclude: paths: diff --git a/internal/certinfo/certinfo.go b/internal/certinfo/certinfo.go index a62026f..73ddb63 100644 --- a/internal/certinfo/certinfo.go +++ b/internal/certinfo/certinfo.go @@ -69,6 +69,12 @@ type ( ReadPassword(fd int) ([]byte, error) } + // NoPasswordPromptReader is implemented by readers that must not use an + // interactive terminal prompt (for example MCP or automated tests). + NoPasswordPromptReader interface { + NoPasswordPrompt() bool + } + // InputReader implements the Reader interface using standard OS calls. InputReader struct{} ) diff --git a/internal/certinfo/common_handlers.go b/internal/certinfo/common_handlers.go index 477bb44..dd74204 100644 --- a/internal/certinfo/common_handlers.go +++ b/internal/certinfo/common_handlers.go @@ -189,6 +189,15 @@ func getPassphraseIfNeeded(isEncrypted bool, pwEnvKey string, pwReader Reader) ( return []byte(pkeyEnvPw), nil } + if noPrompt, ok := pwReader.(NoPasswordPromptReader); ok && noPrompt.NoPasswordPrompt() { + pw, trErr := pwReader.ReadPassword(int(os.Stdin.Fd())) + if trErr != nil { + return nil, fmt.Errorf("error reading passphrase: %w", trErr) + } + + return pw, nil + } + fmt.Print("Private key is encrypted, please enter passphrase: ") pw, trErr := pwReader.ReadPassword(int(os.Stdin.Fd())) diff --git a/internal/certinfo/main_test.go b/internal/certinfo/main_test.go index 6c799ca..818a34f 100644 --- a/internal/certinfo/main_test.go +++ b/internal/certinfo/main_test.go @@ -151,6 +151,8 @@ func (MockErrReader) ReadFile(name string) ([]byte, error) { return nil, fmt.Errorf("unable to read file %s", name) } +func (MockErrReader) NoPasswordPrompt() bool { return true } + func (MockErrReader) ReadPassword(fd int) ([]byte, error) { return func(_ int) ([]byte, error) { return []byte{}, errors.New("mockErrReader: unable to read password") diff --git a/internal/mcp/coverage_test.go b/internal/mcp/coverage_test.go index 425c45a..7db978d 100644 --- a/internal/mcp/coverage_test.go +++ b/internal/mcp/coverage_test.go @@ -337,6 +337,7 @@ func TestMCPFileReader(t *testing.T) { data, err := reader.ReadFile(path) require.NoError(t, err) require.Equal(t, "ok", string(data)) + require.True(t, reader.NoPasswordPrompt()) _, err = reader.ReadPassword(0) require.Error(t, err) diff --git a/internal/mcp/tools_exec.go b/internal/mcp/tools_exec.go index ef562b0..c12a4e2 100644 --- a/internal/mcp/tools_exec.go +++ b/internal/mcp/tools_exec.go @@ -74,6 +74,8 @@ func (mcpFileReader) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) } +func (mcpFileReader) NoPasswordPrompt() bool { return true } + func (mcpFileReader) ReadPassword(_ int) ([]byte, error) { return nil, errors.New("encrypted private keys require CERTINFO_PKEY_PW under MCP") } @@ -185,7 +187,7 @@ func executeRunRequests(_ context.Context, input runRequestsInput) (execToolOutp return execToolOutput{}, err } - output, err := captureWithStdout(func(w io.Writer) error { + output, err := captureOutput(func(w io.Writer) error { _, handleErr := requests.HandleRequests(w, meta) return handleErr @@ -407,38 +409,3 @@ func captureOutput(fn func(io.Writer) error) (string, error) { return buf.String(), nil } - -func captureWithStdout(fn func(io.Writer) error) (string, error) { - pipeR, pipeW, err := os.Pipe() - if err != nil { - return "", err - } - - oldStdout := os.Stdout - os.Stdout = pipeW - - var writerBuf bytes.Buffer - - fnErr := fn(&writerBuf) - - if closeErr := pipeW.Close(); closeErr != nil && fnErr == nil { - fnErr = closeErr - } - - os.Stdout = oldStdout - - var stdoutBuf bytes.Buffer - - if _, copyErr := io.Copy(&stdoutBuf, pipeR); copyErr != nil && fnErr == nil { - fnErr = copyErr - } - - _ = pipeR.Close() - - var combined bytes.Buffer - - _, _ = combined.Write(stdoutBuf.Bytes()) - _, _ = combined.Write(writerBuf.Bytes()) - - return combined.String(), fnErr -} diff --git a/internal/requests/requests.go b/internal/requests/requests.go index e03a9c8..6349241 100644 --- a/internal/requests/requests.go +++ b/internal/requests/requests.go @@ -247,14 +247,14 @@ func (r *RequestsMetaConfig) PrintCmd(w io.Writer) { // PrintTitle prints the request name and transport override information if verbose mode is enabled. // //nolint:revive -func (r *RequestConfig) PrintTitle(isVerbose bool) { +func (r *RequestConfig) PrintTitle(w io.Writer, isVerbose bool) { if isVerbose { - fmt.Print(style.LgSprintf(style.TitleKey, "Request:")) - fmt.Println(style.LgSprintf(style.Title, "%s", r.Name)) + fmt.Fprint(w, style.LgSprintf(style.TitleKey, "Request:")) + fmt.Fprintln(w, style.LgSprintf(style.Title, "%s", r.Name)) if r.TransportOverrideURL != "" { - fmt.Print(style.LgSprintf(style.ItemKey, "Via:")) - fmt.Println(style.LgSprintf(style.Via, "%s", r.TransportOverrideURL)) + fmt.Fprint(w, style.LgSprintf(style.ItemKey, "Via:")) + fmt.Fprintln(w, style.LgSprintf(style.Via, "%s", r.TransportOverrideURL)) } } } @@ -633,16 +633,17 @@ func NewHTTPClientFromRequestConfig( // //nolint:revive func processHTTPRequestsByHost( + w io.Writer, r RequestConfig, caPool *x509.CertPool, isVerbose bool, ) ([]ResponseData, error) { var responseDataList []ResponseData - r.PrintTitle(isVerbose) + r.PrintTitle(w, isVerbose) for _, host := range r.Hosts { - hostResults, err := processRequestsForHost(r, host, caPool, isVerbose) + hostResults, err := processRequestsForHost(w, r, host, caPool, isVerbose) if err != nil { return nil, err } @@ -655,6 +656,7 @@ func processHTTPRequestsByHost( // processRequestsForHost initializes the HTTP client and executes all configured URIs for a single host. func processRequestsForHost( + w io.Writer, r RequestConfig, host Host, caPool *x509.CertPool, @@ -675,9 +677,9 @@ func processRequestsForHost( requestBodyBytes := []byte(r.RequestBody) for _, reqURL := range urlList { - responseData := executeSingleRequest(r, reqClient, reqURL, requestBodyBytes, isVerbose) + responseData := executeSingleRequest(w, r, reqClient, reqURL, requestBodyBytes, isVerbose) responseDataList = append(responseDataList, responseData) - responseData.PrintResponseData(isVerbose) + responseData.PrintResponseData(w, isVerbose) } return responseDataList, nil @@ -685,6 +687,7 @@ func processRequestsForHost( // executeSingleRequest performs a single HTTP request and returns the collected response data. func executeSingleRequest( + w io.Writer, r RequestConfig, reqClient *RequestHTTPClient, reqURL string, @@ -716,7 +719,7 @@ func executeSingleRequest( req.Header.Set("User-Agent", ua) - if err := r.PrintRequestDebug(os.Stdout, req); err != nil { + if err := r.PrintRequestDebug(w, req); err != nil { fmt.Fprintf(os.Stderr, "Warning: PrintRequestDebug failed: %v\n", err) } @@ -726,7 +729,7 @@ func executeSingleRequest( return responseData } - r.PrintResponseDebug(os.Stdout, resp) + r.PrintResponseDebug(w, resp) responseData.Response = resp diff --git a/internal/requests/requests_handlers.go b/internal/requests/requests_handlers.go index 6e76a77..aaa6971 100644 --- a/internal/requests/requests_handlers.go +++ b/internal/requests/requests_handlers.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "net/url" - "os" "regexp" "slices" "strconv" @@ -219,6 +218,7 @@ func HandleRequests(w io.Writer, cfg *RequestsMetaConfig) (map[string][]Response for _, r := range cfg.Requests { responseDataList, err := processHTTPRequestsByHost( + w, r, cfg.CACertsPool, cfg.RequestVerbose, @@ -287,58 +287,58 @@ func (rd *ResponseData) ImportResponseBody() { // PrintResponseData prints the collected response data (status, headers, body) if verbose mode is enabled. // //nolint:revive -func (rd ResponseData) PrintResponseData(isVerbose bool) { +func (rd ResponseData) PrintResponseData(w io.Writer, isVerbose bool) { if !isVerbose { return } - fmt.Println(style.LgSprintf(style.ItemKey, + fmt.Fprintln(w, style.LgSprintf(style.ItemKey, "- Url: %s", style.URL.Render(rd.URL)), ) - fmt.Print(style.LgSprintf(style.ItemKeyP3, "StatusCode: ")) + fmt.Fprint(w, style.LgSprintf(style.ItemKeyP3, "StatusCode: ")) if rd.Error != nil { - fmt.Println(style.LgSprintf(style.StatusError, "0")) - fmt.Println(style.LgSprintf( + fmt.Fprintln(w, style.LgSprintf(style.StatusError, "0")) + fmt.Fprintln(w, style.LgSprintf( style.ItemKeyP3, "Error: %s", - style.Error.Render(rd.Error.Error())), - ) - fmt.Println() - } + style.Error.Render(rd.Error.Error()), + )) + fmt.Fprintln(w) - if rd.Error == nil { - fmt.Println(style.LgSprintf(style.Status, - "%v", - style.StatusCodeParse(rd.Response.StatusCode))) + return + } - if rd.Request.PrintResponseCertificates { - RenderTLSData(os.Stdout, rd.Response, rd.Request.ResponseCertificatesFilter) - } + fmt.Fprintln(w, style.LgSprintf(style.Status, + "%v", + style.StatusCodeParse(rd.Response.StatusCode))) - if rd.Request.PrintResponseHeaders { - headersStr := filterResponseHeaders( - rd.Response.Header, - rd.Request.ResponseHeadersFilter) + if rd.Request.PrintResponseCertificates { + RenderTLSData(w, rd.Response, rd.Request.ResponseCertificatesFilter) + } - fmt.Println(style.LgSprintf(style.ItemKeyP3, "Headers: ")) - fmt.Println(headersStr) - } + if rd.Request.PrintResponseHeaders { + headersStr := filterResponseHeaders( + rd.Response.Header, + rd.Request.ResponseHeadersFilter) - if rd.Request.ResponseBodyMatchRegexp != "" { - fmt.Print(style.LgSprintf(style.ItemKeyP3, "BodyRegexpMatch: ")) - fmt.Println(rd.ResponseBodyRegexpMatched) - } + fmt.Fprintln(w, style.LgSprintf(style.ItemKeyP3, "Headers: ")) + fmt.Fprintln(w, headersStr) + } - if rd.Request.PrintResponseBody { - fmt.Println(style.LgSprintf(style.ItemKeyP3, "Body:")) - fmt.Println(rd.ResponseBody) - } + if rd.Request.ResponseBodyMatchRegexp != "" { + fmt.Fprint(w, style.LgSprintf(style.ItemKeyP3, "BodyRegexpMatch: ")) + fmt.Fprintln(w, rd.ResponseBodyRegexpMatched) + } - fmt.Println() + if rd.Request.PrintResponseBody { + fmt.Fprintln(w, style.LgSprintf(style.ItemKeyP3, "Body:")) + fmt.Fprintln(w, rd.ResponseBody) } + + fmt.Fprintln(w) } // RenderTLSData prints TLS version, cipher suite, and peer certificates for an HTTP response. diff --git a/internal/requests/requests_handlers_print_test.go b/internal/requests/requests_handlers_print_test.go index 74df96f..6f280fb 100644 --- a/internal/requests/requests_handlers_print_test.go +++ b/internal/requests/requests_handlers_print_test.go @@ -6,9 +6,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "errors" - "io" "net/http" - "os" "testing" "github.com/stretchr/testify/require" @@ -58,19 +56,9 @@ func TestPrintResponseData(t *testing.T) { } func capturePrintResponseData(rd ResponseData, verbose bool) string { - oldStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - rd.PrintResponseData(verbose) - - _ = w.Close() - os.Stdout = oldStdout - var buf bytes.Buffer - _, _ = io.Copy(&buf, r) - _ = r.Close() + rd.PrintResponseData(&buf, verbose) return buf.String() } diff --git a/internal/requests/requests_handlers_test.go b/internal/requests/requests_handlers_test.go index 44a3f04..efc6cf4 100644 --- a/internal/requests/requests_handlers_test.go +++ b/internal/requests/requests_handlers_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "net/http" "testing" @@ -395,6 +396,7 @@ func TestRenderTLSData(t *testing.T) { defer ts.Close() respList, err := processHTTPRequestsByHost( + io.Discard, tt.reqConf, tt.pool, false, diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index 67c2ede..63b114e 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -1247,7 +1247,7 @@ func TestProcessHTTPRequestsByHost_Errors(t *testing.T) { {Name: "localhost", URIList: []URI{"invalid"}}, }, } - _, err := processHTTPRequestsByHost(reqConf, nil, false) + _, err := processHTTPRequestsByHost(io.Discard, reqConf, nil, false) require.Error(t, err) require.ErrorContains(t, err, "invalid uri") }) @@ -1582,6 +1582,7 @@ func runProcessHTTPRequestsByHostSubtest(t *testing.T, tt processHTTPRequestsByH defer ts.Close() respList, err := processHTTPRequestsByHost( + io.Discard, tt.reqConf, tt.pool, tt.verbose, From aa22215e366c865370ffd4041b69ca9d98c16df8 Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Mon, 25 May 2026 13:19:04 +0200 Subject: [PATCH 5/6] fix: address CodeRabbit review feedback on MCP PR Enforce handler timeouts via runWithContext, quote generated YAML scalars, tighten clientTimeout schema to integer, fix example typo, stabilize the embedded $schema URI, and harden tests (safe type assertion, local jwt URL). Co-authored-by: Cursor --- .../https-wrench-proxyProtocolV2.yaml | 2 +- https-wrench.schema.json | 3 +- internal/jwtinfo/jwtinfo_test.go | 2 +- .../https-wrench-proxyProtocolV2.yaml | 2 +- internal/mcp/assets/schema.json | 3 +- internal/mcp/embed.go | 3 +- internal/mcp/server_test.go | 9 ++-- internal/mcp/tools.go | 15 ++++-- internal/mcp/tools_exec.go | 52 ++++++++++++++++++- 9 files changed, 73 insertions(+), 18 deletions(-) diff --git a/assets/examples/https-wrench-proxyProtocolV2.yaml b/assets/examples/https-wrench-proxyProtocolV2.yaml index c06c49b..acf7978 100644 --- a/assets/examples/https-wrench-proxyProtocolV2.yaml +++ b/assets/examples/https-wrench-proxyProtocolV2.yaml @@ -11,7 +11,7 @@ debug: false verbose: true requests: - - name: RequestOverProxyPtorocolV2 + - name: RequestOverProxyProtocolV2 transportOverrideUrl: https://127.0.0.1:9444 enableProxyProtocolV2: true clientTimeout: 2 diff --git a/https-wrench.schema.json b/https-wrench.schema.json index 597447c..4b4f18d 100644 --- a/https-wrench.schema.json +++ b/https-wrench.schema.json @@ -56,7 +56,8 @@ "description": "TLS/TCP dial address (https://host or https://ip:port). The logical hostname remains hosts[].name for Host header and SNI." }, "clientTimeout": { - "type": "number", + "type": "integer", + "minimum": 0, "description": "HTTP client timeout in seconds." }, "requestDebug": { diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index 3549df6..cbe034d 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -262,7 +262,7 @@ func TestRequestToken_nilReadAll(t *testing.T) { _, err := RequestToken( context.Background(), - "http://example.com/token", + "http://127.0.0.1/token", map[string]string{"grant_type": "client_credentials"}, &http.Client{}, nil, diff --git a/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml b/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml index c06c49b..acf7978 100644 --- a/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml +++ b/internal/mcp/assets/examples/https-wrench-proxyProtocolV2.yaml @@ -11,7 +11,7 @@ debug: false verbose: true requests: - - name: RequestOverProxyPtorocolV2 + - name: RequestOverProxyProtocolV2 transportOverrideUrl: https://127.0.0.1:9444 enableProxyProtocolV2: true clientTimeout: 2 diff --git a/internal/mcp/assets/schema.json b/internal/mcp/assets/schema.json index 597447c..4b4f18d 100644 --- a/internal/mcp/assets/schema.json +++ b/internal/mcp/assets/schema.json @@ -56,7 +56,8 @@ "description": "TLS/TCP dial address (https://host or https://ip:port). The logical hostname remains hosts[].name for Host header and SNI." }, "clientTimeout": { - "type": "number", + "type": "integer", + "minimum": 0, "description": "HTTP client timeout in seconds." }, "requestDebug": { diff --git a/internal/mcp/embed.go b/internal/mcp/embed.go index 0fb15a8..c8c9871 100644 --- a/internal/mcp/embed.go +++ b/internal/mcp/embed.go @@ -26,8 +26,7 @@ var exampleFiles = map[string]string{ "proxy-protocol-v2": "assets/examples/https-wrench-proxyProtocolV2.yaml", } -const schemaCommentHeader = "# yaml-language-server: $schema=" + - "https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json" +const schemaCommentHeader = "# yaml-language-server: $schema=" + uriSchema const requestsDocsMarkdown = `# https-wrench requests diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index e164bcc..6ed88eb 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -251,7 +251,7 @@ func TestRequestsConfigTemplate(t *testing.T) { out := decodeStructuredOutput(t, res) yaml, ok := out["configYaml"].(string) require.True(t, ok) - require.Contains(t, yaml, "transportOverrideUrl: https://edge.example.net") + require.Contains(t, yaml, `transportOverrideUrl: "https://edge.example.net"`) require.Contains(t, yaml, "www.example.com") require.Contains(t, yaml, "/health") } @@ -274,9 +274,10 @@ func TestAuthorRequestsConfigPrompt(t *testing.T) { }) require.NoError(t, err) require.NotEmpty(t, res.Messages) - text := res.Messages[0].Content.(*sdkmcp.TextContent).Text - require.Contains(t, text, "app.example.com") - require.Contains(t, text, "validate_requests_config") + content, ok := res.Messages[0].Content.(*sdkmcp.TextContent) + require.Truef(t, ok, "expected *sdkmcp.TextContent, got %T", res.Messages[0].Content) + require.Contains(t, content.Text, "app.example.com") + require.Contains(t, content.Text, "validate_requests_config") } func callValidateTool(t *testing.T, yaml string) map[string]any { diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index fccc073..92957b0 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -239,11 +239,11 @@ func buildRequestsConfigYAML(input requestsConfigTemplateInput) (string, []strin fmt.Fprintln(&b, "---") fmt.Fprintln(&b, "verbose: true") fmt.Fprintln(&b, "requests:") - fmt.Fprintf(&b, " - name: %s\n", name) - fmt.Fprintf(&b, " requestMethod: %s\n", method) + fmt.Fprintf(&b, " - name: %s\n", yamlQuotedScalar(name)) + fmt.Fprintf(&b, " requestMethod: %s\n", yamlQuotedScalar(method)) if transport := strings.TrimSpace(input.TransportOverrideURL); transport != "" { - fmt.Fprintf(&b, " transportOverrideUrl: %s\n", transport) + fmt.Fprintf(&b, " transportOverrideUrl: %s\n", yamlQuotedScalar(transport)) } if input.Insecure { @@ -251,11 +251,11 @@ func buildRequestsConfigYAML(input requestsConfigTemplateInput) (string, []strin } fmt.Fprintln(&b, " hosts:") - fmt.Fprintf(&b, " - name: %s\n", hostname) + fmt.Fprintf(&b, " - name: %s\n", yamlQuotedScalar(hostname)) fmt.Fprintln(&b, " uriList:") for _, p := range paths { - fmt.Fprintf(&b, " - %s\n", p) + fmt.Fprintf(&b, " - %s\n", yamlQuotedScalar(p)) } return strings.TrimRight(b.String(), "\n") + "\n", nil @@ -284,6 +284,11 @@ func parsePaths(paths string) []string { return out } +// yamlQuotedScalar returns a YAML-safe double-quoted scalar. +func yamlQuotedScalar(s string) string { + return strconv.Quote(s) +} + var allowedCLICommands = map[string]cliCommandDef{ "certinfo": { oneOfGroups: [][]string{{"tls-endpoint", "cert-bundle", "key-file", "ca-bundle"}}, diff --git a/internal/mcp/tools_exec.go b/internal/mcp/tools_exec.go index c12a4e2..c65ee0b 100644 --- a/internal/mcp/tools_exec.go +++ b/internal/mcp/tools_exec.go @@ -166,7 +166,17 @@ func generateJWKSHandler( return nil, out, nil } -func executeRunRequests(_ context.Context, input runRequestsInput) (execToolOutput, error) { +func executeRunRequests(ctx context.Context, input runRequestsInput) (execToolOutput, error) { + return runWithContext(ctx, func(ctx context.Context) (execToolOutput, error) { + if err := ctx.Err(); err != nil { + return execToolOutput{}, err + } + + return runRequestsExec(input) + }) +} + +func runRequestsExec(input runRequestsInput) (execToolOutput, error) { yamlContent, err := loadConfigYAML(input.ConfigYAML, input.ConfigPath) if err != nil { return execToolOutput{}, err @@ -199,7 +209,17 @@ func executeRunRequests(_ context.Context, input runRequestsInput) (execToolOutp return execToolOutput{Output: output}, nil } -func executeCertinfo(_ context.Context, input certinfoInput) (execToolOutput, error) { +func executeCertinfo(ctx context.Context, input certinfoInput) (execToolOutput, error) { + return runWithContext(ctx, func(ctx context.Context) (execToolOutput, error) { + if err := ctx.Err(); err != nil { + return execToolOutput{}, err + } + + return certinfoExec(input) + }) +} + +func certinfoExec(input certinfoInput) (execToolOutput, error) { if !certinfoInputProvided(input) { return execToolOutput{}, errors.New( "one of tlsEndpoint, certBundle, keyFile, or caBundle is required", @@ -409,3 +429,31 @@ func captureOutput(fn func(io.Writer) error) (string, error) { return buf.String(), nil } + +func runWithContext[T any](ctx context.Context, fn func(context.Context) (T, error)) (T, error) { + if ctx == nil { + ctx = context.Background() + } + + done := make(chan struct { + v T + err error + }, 1) + + go func() { + v, err := fn(ctx) + done <- struct { + v T + err error + }{v: v, err: err} + }() + + select { + case <-ctx.Done(): + var zero T + + return zero, ctx.Err() + case r := <-done: + return r.v, r.err + } +} From 3ee3b4d25749ec12e97c576db1f035a8a4a058c0 Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Mon, 25 May 2026 13:56:00 +0200 Subject: [PATCH 6/6] fix: propagate context through MCP requests and certinfo TLS paths Thread cancellation from MCP tool timeouts into HandleRequests, http.NewRequestWithContext, and certinfo DialContext handshakes. Remove parallel os.Args mutation in TestIsMCPCommand subtests. Co-authored-by: Cursor --- internal/certinfo/certinfo.go | 5 +- internal/certinfo/certinfo_handlers.go | 90 +++++++++++++-------- internal/certinfo/certinfo_handlers_test.go | 15 ++-- internal/certinfo/certinfo_test.go | 17 ++-- internal/cmd/certinfo.go | 6 +- internal/cmd/mcp_test.go | 2 - internal/cmd/requests.go | 3 +- internal/mcp/coverage_test.go | 6 +- internal/mcp/tools_exec.go | 36 +++++++-- internal/requests/requests.go | 17 +++- internal/requests/requests_handlers.go | 12 ++- internal/requests/requests_handlers_test.go | 4 +- internal/requests/requests_test.go | 4 +- 13 files changed, 146 insertions(+), 71 deletions(-) diff --git a/internal/certinfo/certinfo.go b/internal/certinfo/certinfo.go index 73ddb63..e8196d8 100644 --- a/internal/certinfo/certinfo.go +++ b/internal/certinfo/certinfo.go @@ -1,6 +1,7 @@ package certinfo import ( + "context" "crypto" "crypto/x509" "fmt" @@ -180,7 +181,7 @@ func (c *Config) SetPrivateKeyFromFile( } // SetTLSEndpoint parses a host:port string and fetches the remote certificates from that endpoint. -func (c *Config) SetTLSEndpoint(hostport string) error { +func (c *Config) SetTLSEndpoint(ctx context.Context, hostport string) error { if hostport != emptyString { eHost, ePort, err := net.SplitHostPort(hostport) if err != nil { @@ -191,7 +192,7 @@ func (c *Config) SetTLSEndpoint(hostport string) error { c.TLSEndpointHost = eHost c.TLSEndpointPort = ePort - err = c.GetRemoteCerts() + err = c.GetRemoteCerts(ctx) if err != nil { return fmt.Errorf("unable to get endpoint certificates: %w", err) } diff --git a/internal/certinfo/certinfo_handlers.go b/internal/certinfo/certinfo_handlers.go index 3cb2232..2c8bed8 100644 --- a/internal/certinfo/certinfo_handlers.go +++ b/internal/certinfo/certinfo_handlers.go @@ -6,6 +6,7 @@ package certinfo import ( "cmp" + "context" "crypto/sha256" "crypto/tls" "crypto/x509" @@ -29,7 +30,7 @@ import ( // to the provided writer in a human-readable format. // //nolint:revive -func (c *Config) PrintData(w io.Writer) error { +func (c *Config) PrintData(ctx context.Context, w io.Writer) error { ks := style.ItemKey.PaddingBottom(0).PaddingTop(1).PaddingLeft(1) sl := style.CertKeyP4.Bold(true) sv := style.CertValue.Bold(false) @@ -49,7 +50,7 @@ func (c *Config) PrintData(w io.Writer) error { } if c.TLSInfoRequested { - _ = c.ProbeTLSInfo() + _ = c.ProbeTLSInfo(ctx) c.printTLSInfo(w, ks, sl, sv) } @@ -174,30 +175,47 @@ func (c *Config) printCACerts(w io.Writer, ks, sl, sv lipgloss.Style) error { return nil } +// dialTLS connects to serverAddr and completes a TLS handshake using ctx for cancellation. +func dialTLS(ctx context.Context, serverAddr string, tlsConfig *tls.Config) (*tls.Conn, error) { + dialer := &net.Dialer{Timeout: TLSTimeout} + + rawConn, err := dialer.DialContext(ctx, "tcp", serverAddr) + if err != nil { + return nil, err + } + + conn := tls.Client(rawConn, tlsConfig) + + if err = conn.HandshakeContext(ctx); err != nil { + _ = rawConn.Close() + + return nil, err + } + + return conn, nil +} + // GetRemoteCerts establishes a TLS connection to the configured endpoint and retrieves // the peer certificate chain. It also performs certificate verification unless TLSInsecure is true. -func (c *Config) GetRemoteCerts() error { +func (c *Config) GetRemoteCerts(ctx context.Context) error { tlsConfig := &tls.Config{ RootCAs: c.CACertsPool, InsecureSkipVerify: c.TLSInsecure, } - if c.TLSServerName != emptyString { + verifyName := c.TLSServerName + switch { + case c.TLSServerName != emptyString: tlsConfig.ServerName = c.TLSServerName + case c.TLSEndpointHost != emptyString: + tlsConfig.ServerName = c.TLSEndpointHost + verifyName = c.TLSEndpointHost + default: } serverAddr := net.JoinHostPort(c.TLSEndpointHost, c.TLSEndpointPort) - dialer := &net.Dialer{ - Timeout: TLSTimeout, - } - - conn, err := tls.DialWithDialer( - dialer, - "tcp", - serverAddr, - tlsConfig, - ) + conn, err := dialTLS(ctx, serverAddr, tlsConfig) if err != nil { return fmt.Errorf("TLS handshake failed: %w", err) } @@ -214,7 +232,7 @@ func (c *Config) GetRemoteCerts() error { } opts := x509.VerifyOptions{ - DNSName: c.TLSServerName, + DNSName: verifyName, Roots: c.CACertsPool, Intermediates: x509.NewCertPool(), } @@ -425,7 +443,7 @@ func tlsVersionToString(version uint16) string { } // probeProtocol tests whether the TLS endpoint supports a specific TLS protocol version. -func (c *Config) probeProtocol(version uint16) bool { +func (c *Config) probeProtocol(ctx context.Context, version uint16) bool { tlsConfig := &tls.Config{ MinVersion: version, MaxVersion: version, @@ -434,17 +452,15 @@ func (c *Config) probeProtocol(version uint16) bool { if c.TLSServerName != emptyString { tlsConfig.ServerName = c.TLSServerName + } else if c.TLSEndpointHost != emptyString { + tlsConfig.ServerName = c.TLSEndpointHost } serverAddr := net.JoinHostPort(c.TLSEndpointHost, c.TLSEndpointPort) - dialer := &net.Dialer{ - Timeout: TLSTimeout, - } - - conn, err := tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig) + conn, err := dialTLS(ctx, serverAddr, tlsConfig) if err == nil { - conn.Close() + _ = conn.Close() return true } @@ -453,7 +469,7 @@ func (c *Config) probeProtocol(version uint16) bool { } // probeCipher tests whether a specific TLS 1.0-1.2 cipher suite is supported. -func (c *Config) probeCipher(suite *tls.CipherSuite) (bool, string) { +func (c *Config) probeCipher(ctx context.Context, suite *tls.CipherSuite) (bool, string) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS10, MaxVersion: tls.VersionTLS12, @@ -463,19 +479,17 @@ func (c *Config) probeCipher(suite *tls.CipherSuite) (bool, string) { if c.TLSServerName != emptyString { tlsConfig.ServerName = c.TLSServerName + } else if c.TLSEndpointHost != emptyString { + tlsConfig.ServerName = c.TLSEndpointHost } serverAddr := net.JoinHostPort(c.TLSEndpointHost, c.TLSEndpointPort) - dialer := &net.Dialer{ - Timeout: TLSTimeout, - } - - conn, err := tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig) + conn, err := dialTLS(ctx, serverAddr, tlsConfig) if err == nil { state := conn.ConnectionState() - conn.Close() + _ = conn.Close() return true, tlsVersionToString(state.Version) } @@ -484,7 +498,7 @@ func (c *Config) probeCipher(suite *tls.CipherSuite) (bool, string) { } // ProbeTLSInfo concurrently scans the endpoint for supported TLS versions and cipher suites. -func (c *Config) ProbeTLSInfo() error { +func (c *Config) ProbeTLSInfo(ctx context.Context) error { if c.TLSEndpoint == emptyString { return nil } @@ -495,7 +509,11 @@ func (c *Config) ProbeTLSInfo() error { versions := []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} for _, v := range versions { - supported := c.probeProtocol(v) + if err := ctx.Err(); err != nil { + return err + } + + supported := c.probeProtocol(ctx, v) c.ProbedProtocols[tlsVersionToString(v)] = supported } @@ -503,7 +521,7 @@ func (c *Config) ProbeTLSInfo() error { // 2. Probe ciphers concurrently suites := append(tls.CipherSuites(), tls.InsecureCipherSuites()...) - c.ProbedCiphers = c.probeCiphersConcurrently(suites) + c.ProbedCiphers = c.probeCiphersConcurrently(ctx, suites) return nil } @@ -511,7 +529,7 @@ func (c *Config) ProbeTLSInfo() error { // probeCiphersConcurrently manages the worker pool to concurrently scan cipher suites. // //nolint:gocognit,revive,wsl -func (c *Config) probeCiphersConcurrently(suites []*tls.CipherSuite) []ProbedCipher { +func (c *Config) probeCiphersConcurrently(ctx context.Context, suites []*tls.CipherSuite) []ProbedCipher { type job struct { suite *tls.CipherSuite } @@ -539,6 +557,10 @@ func (c *Config) probeCiphersConcurrently(suites []*tls.CipherSuite) []ProbedCip defer wg.Done() for j := range jobs { + if err := ctx.Err(); err != nil { + return + } + suite := j.suite isTLS13 := false @@ -559,7 +581,7 @@ func (c *Config) probeCiphersConcurrently(suites []*tls.CipherSuite) []ProbedCip supported = c.ProbedProtocols["TLS 1.3"] protoName = "TLS 1.3" } else { - ok, name := c.probeCipher(suite) + ok, name := c.probeCipher(ctx, suite) supported = ok protoName = name diff --git a/internal/certinfo/certinfo_handlers_test.go b/internal/certinfo/certinfo_handlers_test.go index 186ec83..522241d 100644 --- a/internal/certinfo/certinfo_handlers_test.go +++ b/internal/certinfo/certinfo_handlers_test.go @@ -2,6 +2,7 @@ package certinfo import ( "bytes" + "context" "crypto/x509" "crypto/x509/pkix" "math/big" @@ -145,10 +146,10 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { cc.SetTLSServerName(tt.srvCfg.serverName) cc.SetCaPoolFromFile(tt.caCertFile, inputReader) - cc.SetTLSEndpoint(tt.srvCfg.serverAddr) + cc.SetTLSEndpoint(context.Background(), tt.srvCfg.serverAddr) cc.SetTLSInsecure(tt.insecure) - err = cc.GetRemoteCerts() + err = cc.GetRemoteCerts(context.Background()) if !tt.expectError { require.NoError(t, err, "check error not expected") require.Equal(t, tt.srvCfg.serverName, cc.TLSServerName, "check TLSServerName") @@ -491,7 +492,7 @@ func TestCertinfo_PrintData(t *testing.T) { }) cc.CertsBundleFilePath = "dummy" - errPrint := cc.PrintData(&buffer) + errPrint := cc.PrintData(context.Background(), &buffer) require.Error(t, errPrint) require.ErrorContains(t, errPrint, "unable to check if private key matches local certificate") }) @@ -508,7 +509,7 @@ func TestCertinfo_PrintData(t *testing.T) { cc.TLSEndpointHost = "localhost" cc.TLSEndpointPort = "443" - errPrint := cc.PrintData(&buffer) + errPrint := cc.PrintData(context.Background(), &buffer) require.Error(t, errPrint) require.ErrorContains(t, errPrint, "unable to check if private key matches remote TLS Endpoint certificate") }) @@ -520,7 +521,7 @@ func TestCertinfo_PrintData(t *testing.T) { cc.CACertsFilePath = "non_existent_file.pem" - errPrint := cc.PrintData(&buffer) + errPrint := cc.PrintData(context.Background(), &buffer) require.Error(t, errPrint) require.ErrorContains(t, errPrint, "unable for read Root certificates") }) @@ -559,7 +560,7 @@ func runPrintDataSubtest(t *testing.T, tt printDataTestCase) { cc.SetTLSServerName(tt.tlsServerName) cc.SetTLSInsecure(tt.tlsInsecure) - err = cc.SetTLSEndpoint(tt.tlsEndpoint) + err = cc.SetTLSEndpoint(context.Background(), tt.tlsEndpoint) if tt.expectCertsFetchErr { require.EqualError(t, err, tt.expectCertsFetcMsg) } else { @@ -567,7 +568,7 @@ func runPrintDataSubtest(t *testing.T, tt printDataTestCase) { } } - errPrint := cc.PrintData(&buffer) + errPrint := cc.PrintData(context.Background(), &buffer) require.NoError(t, errPrint) got := buffer.String() diff --git a/internal/certinfo/certinfo_test.go b/internal/certinfo/certinfo_test.go index fd33470..362729a 100644 --- a/internal/certinfo/certinfo_test.go +++ b/internal/certinfo/certinfo_test.go @@ -2,6 +2,7 @@ package certinfo import ( "bytes" + "context" "crypto/tls" "fmt" "net/http" @@ -400,7 +401,7 @@ func TestCertinfo_SetTLSEndpoint(t *testing.T) { cc, errNew := New() require.NoError(t, errNew) - err := cc.SetTLSEndpoint(tt.endpoint) + err := cc.SetTLSEndpoint(context.Background(), tt.endpoint) if !tt.processErr { // skip requiring NoError since SetTLSEndpoint will always return network errors @@ -440,10 +441,10 @@ func TestCertinfo_ProbeTLSInfo(t *testing.T) { cc.SetTLSInsecure(true) cc.SetTLSServerName("example.com") - err = cc.SetTLSEndpoint(u.Host) + err = cc.SetTLSEndpoint(context.Background(), u.Host) require.NoError(t, err) - err = cc.ProbeTLSInfo() + err = cc.ProbeTLSInfo(context.Background()) require.NoError(t, err) // Since it's a local TLS server run by Go's httptest, it supports TLS 1.3 or TLS 1.2 @@ -470,7 +471,7 @@ func TestCertinfo_ProbeTLSInfo_NotRequested(t *testing.T) { cc.SetTLSInfoRequested(false) require.False(t, cc.TLSInfoRequested) - err = cc.ProbeTLSInfo() + err = cc.ProbeTLSInfo(context.Background()) require.NoError(t, err) require.Empty(t, cc.NegotiatedProtocol) } @@ -483,7 +484,7 @@ func TestCertinfo_ProbeTLSInfo_NoEndpoint(t *testing.T) { cc.SetTLSInfoRequested(true) - err = cc.ProbeTLSInfo() + err = cc.ProbeTLSInfo(context.Background()) require.NoError(t, err) require.Empty(t, cc.ProbedProtocols) } @@ -502,7 +503,7 @@ func TestCertinfo_ProbeTLSInfo_Unreachable(t *testing.T) { cc.TLSEndpointHost = "127.0.0.1" cc.TLSEndpointPort = "54321" - err = cc.ProbeTLSInfo() + err = cc.ProbeTLSInfo(context.Background()) require.NoError(t, err) // When unreachable, all scanned protocols should be unsupported @@ -655,7 +656,7 @@ func TestCertinfo_ProbeTLSInfo_SingleCipher(t *testing.T) { }, } - res := cc.probeCiphersConcurrently(ciphers) + res := cc.probeCiphersConcurrently(context.Background(), ciphers) require.Len(t, res, 1) require.Equal(t, "TLS_AES_128_GCM_SHA256", res[0].Name) require.False(t, res[0].Supported) @@ -673,7 +674,7 @@ func TestCertinfo_PrintData_WithTLSInfo(t *testing.T) { var buf bytes.Buffer - err = cc.PrintData(&buf) + err = cc.PrintData(context.Background(), &buf) require.NoError(t, err) require.Contains(t, buf.String(), "Negotiated TLS Connection") } diff --git a/internal/cmd/certinfo.go b/internal/cmd/certinfo.go index 35cdc45..5820c74 100644 --- a/internal/cmd/certinfo.go +++ b/internal/cmd/certinfo.go @@ -5,6 +5,8 @@ Copyright © 2025 Zeno Belli xeno@os76.xyz package cmd import ( + "context" + "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/xenos76/https-wrench/internal/certinfo" @@ -99,7 +101,7 @@ Examples: // SetTLSEndpoint may need the SNI/ServerName and insecure options to be set // before being able to ask details about the certificate we want to a // webserver using self-signed and valid certificates - if err = certinfoCfg.SetTLSEndpoint(tlsEndpoint); err != nil { + if err = certinfoCfg.SetTLSEndpoint(context.Background(), tlsEndpoint); err != nil { cmd.Printf("Error setting TLS endpoint: %s", err) return } @@ -113,7 +115,7 @@ Examples: } // dump.Print(certinfoCfg) - if err = certinfoCfg.PrintData(cmd.OutOrStdout()); err != nil { + if err = certinfoCfg.PrintData(context.Background(), cmd.OutOrStdout()); err != nil { cmd.Printf("error printing Certinfo data: %s", err) } }, diff --git a/internal/cmd/mcp_test.go b/internal/cmd/mcp_test.go index d24b914..94b368f 100644 --- a/internal/cmd/mcp_test.go +++ b/internal/cmd/mcp_test.go @@ -23,8 +23,6 @@ func TestIsMCPCommand(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() - oldArgs := os.Args t.Cleanup(func() { os.Args = oldArgs }) diff --git a/internal/cmd/requests.go b/internal/cmd/requests.go index e3e65f3..414df0c 100644 --- a/internal/cmd/requests.go +++ b/internal/cmd/requests.go @@ -5,6 +5,7 @@ Copyright © 2025 Zeno Belli xeno@os76.xyz package cmd import ( + "context" _ "embed" "fmt" "os" @@ -95,7 +96,7 @@ Examples: cmd.Print(err) } - responseMap, err := requests.HandleRequests(cmd.OutOrStdout(), requestsCfg) + responseMap, err := requests.HandleRequests(context.Background(), cmd.OutOrStdout(), requestsCfg) if err != nil { cmd.Print(err) } diff --git a/internal/mcp/coverage_test.go b/internal/mcp/coverage_test.go index 7db978d..d44d0ad 100644 --- a/internal/mcp/coverage_test.go +++ b/internal/mcp/coverage_test.go @@ -260,8 +260,10 @@ func TestExecuteHelpers(t *testing.T) { require.Error(t, err) _, err = executeCertinfo(context.Background(), certinfoInput{ - TLSEndpoint: "example.com:443", - TLSInfo: true, + TLSEndpoint: "example.com:443", + TLSInsecure: true, + TLSInfo: true, + TLSServername: "example.com", }) require.NoError(t, err) diff --git a/internal/mcp/tools_exec.go b/internal/mcp/tools_exec.go index c65ee0b..907648b 100644 --- a/internal/mcp/tools_exec.go +++ b/internal/mcp/tools_exec.go @@ -172,11 +172,15 @@ func executeRunRequests(ctx context.Context, input runRequestsInput) (execToolOu return execToolOutput{}, err } - return runRequestsExec(input) + return runRequestsExec(ctx, input) }) } -func runRequestsExec(input runRequestsInput) (execToolOutput, error) { +func runRequestsExec(ctx context.Context, input runRequestsInput) (execToolOutput, error) { + if err := ctx.Err(); err != nil { + return execToolOutput{}, err + } + yamlContent, err := loadConfigYAML(input.ConfigYAML, input.ConfigPath) if err != nil { return execToolOutput{}, err @@ -198,7 +202,7 @@ func runRequestsExec(input runRequestsInput) (execToolOutput, error) { } output, err := captureOutput(func(w io.Writer) error { - _, handleErr := requests.HandleRequests(w, meta) + _, handleErr := requests.HandleRequests(ctx, w, meta) return handleErr }) @@ -215,11 +219,15 @@ func executeCertinfo(ctx context.Context, input certinfoInput) (execToolOutput, return execToolOutput{}, err } - return certinfoExec(input) + return certinfoExec(ctx, input) }) } -func certinfoExec(input certinfoInput) (execToolOutput, error) { +func certinfoExec(ctx context.Context, input certinfoInput) (execToolOutput, error) { + if err := ctx.Err(); err != nil { + return execToolOutput{}, err + } + if !certinfoInputProvided(input) { return execToolOutput{}, errors.New( "one of tlsEndpoint, certBundle, keyFile, or caBundle is required", @@ -237,6 +245,10 @@ func certinfoExec(input certinfoInput) (execToolOutput, error) { reader := mcpFileReader{} + if err = ctx.Err(); err != nil { + return execToolOutput{}, err + } + if err = cfg.SetCaPoolFromFile(input.CaBundle, reader); err != nil { return execToolOutput{}, err } @@ -245,11 +257,19 @@ func certinfoExec(input certinfoInput) (execToolOutput, error) { return execToolOutput{}, err } + if err = ctx.Err(); err != nil { + return execToolOutput{}, err + } + cfg.SetTLSInsecure(input.TLSInsecure). SetTLSServerName(input.TLSServername). SetTLSInfoRequested(input.TLSInfo) - if err = cfg.SetTLSEndpoint(input.TLSEndpoint); err != nil { + if err = cfg.SetTLSEndpoint(ctx, input.TLSEndpoint); err != nil { + return execToolOutput{}, err + } + + if err = ctx.Err(); err != nil { return execToolOutput{}, err } @@ -257,7 +277,9 @@ func certinfoExec(input certinfoInput) (execToolOutput, error) { return execToolOutput{}, err } - output, err := captureOutput(cfg.PrintData) + output, err := captureOutput(func(w io.Writer) error { + return cfg.PrintData(ctx, w) + }) if err != nil { return execToolOutput{}, err } diff --git a/internal/requests/requests.go b/internal/requests/requests.go index 6349241..114042c 100644 --- a/internal/requests/requests.go +++ b/internal/requests/requests.go @@ -633,6 +633,7 @@ func NewHTTPClientFromRequestConfig( // //nolint:revive func processHTTPRequestsByHost( + ctx context.Context, w io.Writer, r RequestConfig, caPool *x509.CertPool, @@ -643,7 +644,11 @@ func processHTTPRequestsByHost( r.PrintTitle(w, isVerbose) for _, host := range r.Hosts { - hostResults, err := processRequestsForHost(w, r, host, caPool, isVerbose) + if err := ctx.Err(); err != nil { + return nil, err + } + + hostResults, err := processRequestsForHost(ctx, w, r, host, caPool, isVerbose) if err != nil { return nil, err } @@ -656,6 +661,7 @@ func processHTTPRequestsByHost( // processRequestsForHost initializes the HTTP client and executes all configured URIs for a single host. func processRequestsForHost( + ctx context.Context, w io.Writer, r RequestConfig, host Host, @@ -677,7 +683,11 @@ func processRequestsForHost( requestBodyBytes := []byte(r.RequestBody) for _, reqURL := range urlList { - responseData := executeSingleRequest(w, r, reqClient, reqURL, requestBodyBytes, isVerbose) + if err := ctx.Err(); err != nil { + return nil, err + } + + responseData := executeSingleRequest(ctx, w, r, reqClient, reqURL, requestBodyBytes, isVerbose) responseDataList = append(responseDataList, responseData) responseData.PrintResponseData(w, isVerbose) } @@ -687,6 +697,7 @@ func processRequestsForHost( // executeSingleRequest performs a single HTTP request and returns the collected response data. func executeSingleRequest( + ctx context.Context, w io.Writer, r RequestConfig, reqClient *RequestHTTPClient, @@ -702,7 +713,7 @@ func executeSingleRequest( requestBodyReader := bytes.NewReader(requestBodyBytes) - req, err := http.NewRequest(reqClient.method, reqURL, requestBodyReader) + req, err := http.NewRequestWithContext(ctx, reqClient.method, reqURL, requestBodyReader) if err != nil { responseData.Error = fmt.Errorf("failed to create request: %w", err) return responseData diff --git a/internal/requests/requests_handlers.go b/internal/requests/requests_handlers.go index aaa6971..71ed86d 100644 --- a/internal/requests/requests_handlers.go +++ b/internal/requests/requests_handlers.go @@ -2,6 +2,7 @@ package requests import ( "bytes" + "context" "crypto/tls" "encoding/json" "errors" @@ -211,13 +212,22 @@ func proxyProtoHeaderFromRequest(r RequestConfig, serverName string) (proxyproto } // HandleRequests iterates through all configured requests and processes them, returning a map of response data. -func HandleRequests(w io.Writer, cfg *RequestsMetaConfig) (map[string][]ResponseData, error) { +func HandleRequests( + ctx context.Context, + w io.Writer, + cfg *RequestsMetaConfig, +) (map[string][]ResponseData, error) { responseDataMap := make(map[string][]ResponseData) cfg.PrintCmd(w) for _, r := range cfg.Requests { + if err := ctx.Err(); err != nil { + return nil, err + } + responseDataList, err := processHTTPRequestsByHost( + ctx, w, r, cfg.CACertsPool, diff --git a/internal/requests/requests_handlers_test.go b/internal/requests/requests_handlers_test.go index efc6cf4..735ab21 100644 --- a/internal/requests/requests_handlers_test.go +++ b/internal/requests/requests_handlers_test.go @@ -2,6 +2,7 @@ package requests import ( "bytes" + "context" "crypto/tls" "crypto/x509" "fmt" @@ -396,6 +397,7 @@ func TestRenderTLSData(t *testing.T) { defer ts.Close() respList, err := processHTTPRequestsByHost( + context.Background(), io.Discard, tt.reqConf, tt.pool, @@ -519,7 +521,7 @@ func runHandleRequestsSubtest(t *testing.T, tt handleRequestsTestCase) { defer ts.Close() buffer := bytes.Buffer{} - respMap, err := HandleRequests(&buffer, &tt.reqMeta) + respMap, err := HandleRequests(context.Background(), &buffer, &tt.reqMeta) if tt.expectErr { require.Error(t, err) diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index 63b114e..39facff 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -2,6 +2,7 @@ package requests import ( "bytes" + "context" "crypto/tls" "crypto/x509" "errors" @@ -1247,7 +1248,7 @@ func TestProcessHTTPRequestsByHost_Errors(t *testing.T) { {Name: "localhost", URIList: []URI{"invalid"}}, }, } - _, err := processHTTPRequestsByHost(io.Discard, reqConf, nil, false) + _, err := processHTTPRequestsByHost(context.Background(), io.Discard, reqConf, nil, false) require.Error(t, err) require.ErrorContains(t, err, "invalid uri") }) @@ -1582,6 +1583,7 @@ func runProcessHTTPRequestsByHostSubtest(t *testing.T, tt processHTTPRequestsByH defer ts.Close() respList, err := processHTTPRequestsByHost( + context.Background(), io.Discard, tt.reqConf, tt.pool,