diff --git a/addr.go b/addr.go index db90b8d..a65b360 100644 --- a/addr.go +++ b/addr.go @@ -4,12 +4,17 @@ import ( "context" "encoding/json" "errors" + "fmt" + "io" "log/slog" "net/http" "strconv" + "strings" "time" ) +const maxAddressErrorBody = 4 << 10 + var ( // defined as a variable so it can be overridden in tests. addrURI = `https://experience.aucklandcouncil.govt.nz/nextapi/property` @@ -51,6 +56,11 @@ func MatchingPropertyAddresses(ctx context.Context, addrReq *AddrRequest) (*Addr return cachedAr, nil } + token, err := addrTokenProvider(ctx) + if err != nil { + return nil, fmt.Errorf("get address API token: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addrURI, nil) if err != nil { return nil, err @@ -61,6 +71,7 @@ func MatchingPropertyAddresses(ctx context.Context, addrReq *AddrRequest) (*Addr q.Add("pageSize", strconv.Itoa(addrReq.PageSize)) } req.URL.RawQuery = q.Encode() + req.Header.Set("Authorization", "Bearer "+token) start := time.Now() resp, err := addrHTTPClient.Do(req) @@ -71,7 +82,7 @@ func MatchingPropertyAddresses(ctx context.Context, addrReq *AddrRequest) (*Addr slog.DebugContext(ctx, "address call complete", "duration", time.Since(start)) if resp.StatusCode != http.StatusOK { - return nil, errors.New("address API returned status code: " + strconv.Itoa(resp.StatusCode)) + return nil, addressStatusError(resp) } dec := json.NewDecoder(resp.Body) @@ -95,3 +106,67 @@ func oneAddress(ctx context.Context, addr string) (*Address, error) { } return &resp.Items[0], nil } + +func addressStatusError(resp *http.Response) error { + body, err := io.ReadAll(io.LimitReader(resp.Body, maxAddressErrorBody)) + if err != nil { + return fmt.Errorf("address API returned status code: %d", resp.StatusCode) + } + + msg := sanitizeAddressErrorBody(string(body)) + if msg == "" { + return fmt.Errorf("address API returned status code: %d", resp.StatusCode) + } + return fmt.Errorf("address API returned status code: %d: %s", resp.StatusCode, msg) +} + +func sanitizeAddressErrorBody(body string) string { + body = strings.TrimSpace(body) + if body == "" { + return "" + } + body = strings.Map(func(r rune) rune { + if r == '\n' || r == '\r' || r == '\t' { + return ' ' + } + if r < ' ' { + return -1 + } + return r + }, body) + body = strings.Join(strings.Fields(body), " ") + if body == "" { + return "" + } + + const marker = `"error"` + if strings.Contains(body, marker) { + var payload struct { + Error string `json:"error"` + } + if err := json.Unmarshal([]byte(body), &payload); err == nil && payload.Error != "" { + body = payload.Error + } + } + + body = redactSecretLikeText(body) + if len(body) > maxAddressErrorBody { + body = body[:maxAddressErrorBody] + } + return body +} + +func redactSecretLikeText(s string) string { + words := strings.Fields(s) + for i, word := range words { + trimmed := strings.Trim(word, `"'.,;:()[]{}<>`) + if strings.EqualFold(trimmed, "Bearer") && i+1 < len(words) { + words[i+1] = "" + continue + } + if strings.Count(trimmed, ".") >= 2 && len(trimmed) > 40 { + words[i] = strings.Replace(word, trimmed, "", 1) + } + } + return strings.Join(words, " ") +} diff --git a/addr_test.go b/addr_test.go index 3999713..b505eb7 100644 --- a/addr_test.go +++ b/addr_test.go @@ -1,12 +1,18 @@ package aklapi import ( + "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "reflect" + "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var testAddr = &Address{ @@ -64,10 +70,12 @@ func TestMatchingPropertyAddresses(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer tt.testSrv.Close() - oldURI := addrURI - defer func() { addrURI = oldURI }() + resetAddressTestGlobals(t) + t.Cleanup(tt.testSrv.Close) addrURI = tt.testSrv.URL + addrTokenProvider = func(context.Context) (string, error) { + return "test-token", nil + } got, err := MatchingPropertyAddresses(t.Context(), tt.args.addrReq) if (err != nil) != tt.wantErr { t.Errorf("MatchingPropertyAddresses() error = %v, wantErr %v", err, tt.wantErr) @@ -102,10 +110,12 @@ func TestAddress(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer tt.testSrv.Close() - oldURI := addrURI - defer func() { addrURI = oldURI }() + resetAddressTestGlobals(t) + t.Cleanup(tt.testSrv.Close) addrURI = tt.testSrv.URL + addrTokenProvider = func(context.Context) (string, error) { + return "test-token", nil + } got, err := AddressLookup(t.Context(), tt.args.addr) if (err != nil) != tt.wantErr { t.Errorf("Address() error = %v, wantErr %v", err, tt.wantErr) @@ -156,10 +166,12 @@ func Test_oneAddress(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer tt.testSrv.Close() - oldURI := addrURI - defer func() { addrURI = oldURI }() + resetAddressTestGlobals(t) + t.Cleanup(tt.testSrv.Close) addrURI = tt.testSrv.URL + addrTokenProvider = func(context.Context) (string, error) { + return "test-token", nil + } got, err := oneAddress(t.Context(), tt.args.addr) if (err != nil) != tt.wantErr { t.Errorf("oneAddress() error = %v, wantErr %v", err, tt.wantErr) @@ -172,6 +184,119 @@ func Test_oneAddress(t *testing.T) { } } +func TestMatchingPropertyAddressesAuthorization(t *testing.T) { + resetAddressTestGlobals(t) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "red sq", r.URL.Query().Get("query")) + assert.Equal(t, "1", r.URL.Query().Get("pageSize")) + writeAddrJSON(w, AddrResponse{Items: []Address{*testAddr}}) + })) + t.Cleanup(ts.Close) + + addrURI = ts.URL + addrTokenProvider = func(context.Context) (string, error) { + return "test-token", nil + } + + got, err := MatchingPropertyAddresses(t.Context(), &AddrRequest{PageSize: 1, SearchText: "red sq"}) + require.NoError(t, err) + assert.Equal(t, &AddrResponse{Items: []Address{*testAddr}}, got) +} + +func TestMatchingPropertyAddressesTokenFailure(t *testing.T) { + resetAddressTestGlobals(t) + + var called bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(ts.Close) + + addrURI = ts.URL + addrTokenProvider = func(context.Context) (string, error) { + return "", errors.New("token unavailable") + } + + got, err := MatchingPropertyAddresses(t.Context(), &AddrRequest{SearchText: "red sq"}) + require.Error(t, err) + assert.Nil(t, got) + assert.False(t, called) + assert.Contains(t, err.Error(), "get address API token: token unavailable") +} + +func TestMatchingPropertyAddressesStatusError(t *testing.T) { + resetAddressTestGlobals(t) + + tests := []struct { + name string + body string + want string + forbidText string + }{ + { + name: "json error", + body: `{"error":"Authorisation error. Current IP address logged."}`, + want: "address API returned status code: 401: Authorisation error. Current IP address logged.", + }, + { + name: "secret-looking body is redacted", + body: `{"error":"bad token Bearer header.eyJleHAiOjE3Nzk0NTQ4MDB9.signature"}`, + want: "address API returned status code: 401: bad token Bearer ", + forbidText: "header.eyJ", + }, + { + name: "large body is bounded", + body: strings.Repeat("x", maxAddressErrorBody+100), + want: "address API returned status code: 401: " + strings.Repeat("x", maxAddressErrorBody), + forbidText: strings.Repeat("x", maxAddressErrorBody+1), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetAddressTestGlobals(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(tt.body)) + })) + t.Cleanup(ts.Close) + + addrURI = ts.URL + addrTokenProvider = func(context.Context) (string, error) { + return "test-token", nil + } + + _, err := MatchingPropertyAddresses(t.Context(), &AddrRequest{SearchText: "red sq"}) + require.Error(t, err) + assert.Equal(t, tt.want, err.Error()) + if tt.forbidText != "" { + assert.NotContains(t, err.Error(), tt.forbidText) + } + }) + } +} + +func resetAddressTestGlobals(t *testing.T) { + t.Helper() + oldURI := addrURI + oldClient := addrHTTPClient + oldProvider := addrTokenProvider + oldCache := addrCache + oldNoCache := NoCache + t.Cleanup(func() { + addrURI = oldURI + addrHTTPClient = oldClient + addrTokenProvider = oldProvider + addrCache = oldCache + NoCache = oldNoCache + }) + addrCache = newLRUCache[string, *AddrResponse](defCacheSz) + NoCache = false +} + func writeAddrJSON(w io.Writer, r AddrResponse) { data, err := json.Marshal(r) if err != nil { diff --git a/token.go b/token.go new file mode 100644 index 0000000..f23f2db --- /dev/null +++ b/token.go @@ -0,0 +1,146 @@ +package aklapi + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "regexp" + "strings" + "sync" + "time" +) + +const ( + addrTokenFallbackTTL = 5 * time.Minute + addrTokenRefreshMargin = 30 * time.Second + maxTokenPageSize = 16 << 20 +) + +var ( + addrTokenURI = `https://www.aucklandcouncil.govt.nz/en/rubbish-recycling/rubbish-recycling-collections/rubbish-recycling-collection-days.html` + // Kept replaceable in case Auckland Council moves token acquisition back to + // a server action. The embedded initialToken path does not currently use it. + addrTokenActionID = "" + addrTokenClient = &http.Client{Timeout: 15 * time.Second, Transport: &browserTransport{wrapped: http.DefaultTransport}} + addrTokenProvider = cachedAddrToken + addrTokenCache = &bearerTokenCache{} +) + +var initialTokenPatterns = []*regexp.Regexp{ + regexp.MustCompile(`initialToken\\":\\"([^"\\]+)`), + regexp.MustCompile(`"initialToken"\s*:\s*"([^"]+)`), + regexp.MustCompile(`initialToken":"([^"]+)`), +} + +type bearerToken struct { + value string + expires time.Time +} + +func (t bearerToken) valid(at time.Time) bool { + return t.value != "" && t.expires.After(at.Add(addrTokenRefreshMargin)) +} + +type bearerTokenCache struct { + mu sync.Mutex + token bearerToken +} + +func cachedAddrToken(ctx context.Context) (string, error) { + return addrTokenCache.tokenValue(ctx) +} + +func (c *bearerTokenCache) tokenValue(ctx context.Context) (string, error) { + at := now() + c.mu.Lock() + defer c.mu.Unlock() + + if c.token.valid(at) { + return c.token.value, nil + } + + token, err := fetchAddrToken(ctx) + if err != nil { + slog.WarnContext(ctx, "address token refresh failed", "error", err) + return "", err + } + c.token = token + return token.value, nil +} + +func fetchAddrToken(ctx context.Context) (bearerToken, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addrTokenURI, nil) + if err != nil { + return bearerToken{}, err + } + setBrowserDocumentHeaders(req) + + resp, err := addrTokenClient.Do(req) + if err != nil { + return bearerToken{}, fmt.Errorf("fetch address token page: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return bearerToken{}, fmt.Errorf("fetch address token page: status %d", resp.StatusCode) + } + + token, err := extractInitialToken(io.LimitReader(resp.Body, maxTokenPageSize)) + if err != nil { + return bearerToken{}, fmt.Errorf("parse address token page: %w", err) + } + + expires, err := jwtExpiry(token) + if err != nil { + slog.DebugContext(ctx, "using fallback address token expiry", "error", err) + expires = now().Add(addrTokenFallbackTTL) + } + return bearerToken{value: token, expires: expires}, nil +} + +func extractInitialToken(r io.Reader) (string, error) { + body, err := io.ReadAll(r) + if err != nil { + return "", err + } + + for _, pattern := range initialTokenPatterns { + matches := pattern.FindSubmatch(body) + if len(matches) == 2 { + token := strings.TrimSpace(string(matches[1])) + if token == "" { + return "", errors.New("empty initialToken") + } + return token, nil + } + } + return "", errors.New("initialToken not found") +} + +func jwtExpiry(token string) (time.Time, error) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return time.Time{}, errors.New("token is not a JWT") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return time.Time{}, fmt.Errorf("decode JWT payload: %w", err) + } + + var claims struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return time.Time{}, fmt.Errorf("decode JWT claims: %w", err) + } + if claims.Exp <= 0 { + return time.Time{}, errors.New("JWT exp claim missing") + } + return time.Unix(claims.Exp, 0).UTC(), nil +} diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..457207c --- /dev/null +++ b/token_test.go @@ -0,0 +1,316 @@ +package aklapi + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBearerTokenValid(t *testing.T) { + base := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + token bearerToken + at time.Time + want bool + }{ + {"empty token", bearerToken{expires: base.Add(time.Hour)}, base, false}, + {"valid before margin", bearerToken{value: "token", expires: base.Add(time.Minute)}, base, true}, + {"near expiry", bearerToken{value: "token", expires: base.Add(30 * time.Second)}, base, false}, + {"expired", bearerToken{value: "token", expires: base.Add(-time.Second)}, base, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.token.valid(tt.at)) + }) + } +} + +func TestBearerTokenCache(t *testing.T) { + base := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + startToken bearerToken + wantRequests int32 + wantToken string + wantExpiresAt time.Time + }{ + { + name: "reuses cached token", + startToken: bearerToken{value: "cached", expires: base.Add(time.Hour)}, + wantRequests: 0, + wantToken: "cached", + wantExpiresAt: base.Add(time.Hour), + }, + { + name: "refreshes expired token", + startToken: bearerToken{value: "expired", expires: base.Add(-time.Second)}, + wantRequests: 1, + wantToken: "fresh", + wantExpiresAt: base.Add(time.Hour), + }, + { + name: "refreshes token near expiry", + startToken: bearerToken{value: "stale", expires: base.Add(10 * time.Second)}, + wantRequests: 1, + wantToken: "fresh", + wantExpiresAt: base.Add(time.Hour), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + restoreTokenGlobals(t) + now = func() time.Time { return base } + + token := jwtWithExp(t, base.Add(time.Hour)) + var requests int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requests, 1) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write([]byte(`self.__next_f.push([1,"initialToken\":\"` + token + `\""])`)) + })) + t.Cleanup(ts.Close) + + addrTokenURI = ts.URL + addrTokenClient = ts.Client() + addrTokenCache = &bearerTokenCache{token: tt.startToken} + + got, err := cachedAddrToken(t.Context()) + require.NoError(t, err) + wantToken := tt.wantToken + if tt.wantRequests > 0 { + wantToken = token + } + assert.Equal(t, wantToken, got) + assert.Equal(t, tt.wantRequests, atomic.LoadInt32(&requests)) + assert.Equal(t, tt.wantExpiresAt, addrTokenCache.token.expires) + }) + } +} + +func TestBearerTokenCacheConcurrentRefresh(t *testing.T) { + restoreTokenGlobals(t) + base := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + now = func() time.Time { return base } + + token := jwtWithExp(t, base.Add(time.Hour)) + var requests int32 + start := make(chan struct{}) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requests, 1) + <-start + w.Write([]byte(`self.__next_f.push([1,"initialToken\":\"` + token + `\""])`)) + })) + t.Cleanup(ts.Close) + + addrTokenURI = ts.URL + addrTokenClient = ts.Client() + addrTokenCache = &bearerTokenCache{} + + const goroutines = 20 + errs := make(chan error, goroutines) + results := make(chan string, goroutines) + for range goroutines { + go func() { + got, err := cachedAddrToken(context.Background()) + errs <- err + results <- got + }() + } + + close(start) + for range goroutines { + require.NoError(t, <-errs) + assert.Equal(t, token, <-results) + } + assert.Equal(t, int32(1), atomic.LoadInt32(&requests)) +} + +func TestFetchAddrToken(t *testing.T) { + restoreTokenGlobals(t) + base := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + now = func() time.Time { return base } + + tests := []struct { + name string + body string + status int + wantToken string + wantExp time.Time + wantErr string + }{ + { + name: "JWT expiry", + body: `self.__next_f.push([1,"initialToken\":\"` + jwtWithExp(t, base.Add(time.Hour)) + `\""])`, + status: http.StatusOK, + wantToken: jwtWithExp(t, base.Add(time.Hour)), + wantExp: base.Add(time.Hour), + }, + { + name: "opaque token uses fallback TTL", + body: `self.__next_f.push([1,"initialToken\":\"opaque-token\""])`, + status: http.StatusOK, + wantToken: "opaque-token", + wantExp: base.Add(addrTokenFallbackTTL), + }, + { + name: "non-200", + body: `error`, + status: http.StatusInternalServerError, + wantErr: "status 500", + }, + { + name: "missing token", + body: `self.__next_f.push([1,"{}"])`, + status: http.StatusOK, + wantErr: "initialToken not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, browserAccept, r.Header.Get("Accept")) + assert.Equal(t, "none", r.Header.Get("Sec-Fetch-Site")) + w.WriteHeader(tt.status) + w.Write([]byte(tt.body)) + })) + t.Cleanup(ts.Close) + + addrTokenURI = ts.URL + addrTokenClient = ts.Client() + + got, err := fetchAddrToken(t.Context()) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantToken, got.value) + assert.Equal(t, tt.wantExp, got.expires) + }) + } +} + +func TestExtractInitialToken(t *testing.T) { + tests := []struct { + name string + html string + want string + wantErr string + }{ + { + name: "escaped Next Flight payload", + html: `self.__next_f.push([1,"initialToken\":\"jwt-token\",\"children\":[]"])`, + want: "jwt-token", + }, + { + name: "JSON payload", + html: `{"initialToken":"json-token"}`, + want: "json-token", + }, + { + name: "multiple chunks", + html: `self.__next_f.push([1,"{}"]);self.__next_f.push([2,"initialToken\":\"second-token\""])`, + want: "second-token", + }, + { + name: "missing token", + html: `self.__next_f.push([1,"{}"])`, + wantErr: "initialToken not found", + }, + { + name: "malformed empty token", + html: `{"initialToken":""}`, + wantErr: "initialToken not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractInitialToken(strings.NewReader(tt.html)) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestJWTExpiry(t *testing.T) { + expires := time.Date(2026, 5, 22, 12, 30, 0, 0, time.UTC) + + tests := []struct { + name string + token string + want time.Time + wantErr string + }{ + {"valid", jwtWithExp(t, expires), expires, ""}, + {"not JWT", "opaque", time.Time{}, "token is not a JWT"}, + {"bad base64", "header.not*base64.sig", time.Time{}, "decode JWT payload"}, + {"bad JSON", "header." + base64.RawURLEncoding.EncodeToString([]byte(`not-json`)) + ".sig", time.Time{}, "decode JWT claims"}, + {"missing exp", jwtWithPayload(t, map[string]any{"sub": "test"}), time.Time{}, "JWT exp claim missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := jwtExpiry(tt.token) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func restoreTokenGlobals(t *testing.T) { + t.Helper() + oldURI := addrTokenURI + oldActionID := addrTokenActionID + oldClient := addrTokenClient + oldProvider := addrTokenProvider + oldCache := addrTokenCache + oldNow := now + t.Cleanup(func() { + addrTokenURI = oldURI + addrTokenActionID = oldActionID + addrTokenClient = oldClient + addrTokenProvider = oldProvider + addrTokenCache = oldCache + now = oldNow + }) +} + +func jwtWithExp(t *testing.T, expires time.Time) string { + t.Helper() + return jwtWithPayload(t, map[string]any{"exp": expires.Unix()}) +} + +func jwtWithPayload(t *testing.T, payload map[string]any) string { + t.Helper() + claims, err := json.Marshal(payload) + require.NoError(t, err) + return "header." + base64.RawURLEncoding.EncodeToString(claims) + ".signature" +}