|
6 | 6 | "crypto/rsa" |
7 | 7 | "crypto/sha256" |
8 | 8 | "encoding/json" |
| 9 | + "fmt" |
9 | 10 | "net/http" |
10 | 11 | "net/http/cookiejar" |
11 | 12 | "net/http/httptest" |
@@ -271,3 +272,144 @@ func TestPrivateClient(t *testing.T) { |
271 | 272 | require.NotEmpty(t, newAccessToken) |
272 | 273 | require.NotEqual(t, originalAccessToken, newAccessToken, "Access token should be different after refresh") |
273 | 274 | } |
| 275 | + |
| 276 | +// registerTestClient is a helper that registers a private OAuth client and returns the registration response. |
| 277 | +func registerTestClient(t *testing.T, serverURL string) registrationResponse { |
| 278 | + t.Helper() |
| 279 | + |
| 280 | + regReq := registrationRequest{ |
| 281 | + ClientName: "Test OAuth Client", |
| 282 | + GrantTypes: []string{"authorization_code", "refresh_token"}, |
| 283 | + ResponseTypes: []string{"code"}, |
| 284 | + TokenEndpointAuthMethod: "client_secret_basic", |
| 285 | + Scope: "test", |
| 286 | + RedirectURIs: []string{"http://localhost:8080/callback"}, |
| 287 | + } |
| 288 | + |
| 289 | + reqBody, err := json.Marshal(regReq) |
| 290 | + require.NoError(t, err) |
| 291 | + |
| 292 | + resp, err := http.Post(serverURL+RegistrationEndpoint, "application/json", bytes.NewReader(reqBody)) |
| 293 | + require.NoError(t, err) |
| 294 | + defer resp.Body.Close() |
| 295 | + |
| 296 | + require.Equal(t, http.StatusCreated, resp.StatusCode) |
| 297 | + |
| 298 | + var regResp registrationResponse |
| 299 | + err = json.NewDecoder(resp.Body).Decode(®Resp) |
| 300 | + require.NoError(t, err) |
| 301 | + |
| 302 | + return regResp |
| 303 | +} |
| 304 | + |
| 305 | +// testAuthFlowWithURL performs the OAuth authorization flow given a raw authorization URL |
| 306 | +// and returns the callback URL after authorization completes. |
| 307 | +func testAuthFlowWithURL(t *testing.T, serverURL, authURL string) *url.URL { |
| 308 | + t.Helper() |
| 309 | + |
| 310 | + jar, err := cookiejar.New(nil) |
| 311 | + require.NoError(t, err) |
| 312 | + client := &http.Client{ |
| 313 | + Jar: jar, |
| 314 | + CheckRedirect: func(req *http.Request, via []*http.Request) error { |
| 315 | + return http.ErrUseLastResponse |
| 316 | + }, |
| 317 | + } |
| 318 | + |
| 319 | + // Step 1: Make initial authorization request |
| 320 | + authResp, err := client.Get(authURL) |
| 321 | + require.NoError(t, err) |
| 322 | + defer authResp.Body.Close() |
| 323 | + |
| 324 | + require.Contains(t, []int{http.StatusFound, http.StatusSeeOther}, authResp.StatusCode, |
| 325 | + "expected redirect, got %d", authResp.StatusCode) |
| 326 | + location := authResp.Header.Get("Location") |
| 327 | + require.NotEmpty(t, location) |
| 328 | + require.Contains(t, location, strings.ReplaceAll(AuthorizationReturnEndpoint, ":ar_id", "")) |
| 329 | + |
| 330 | + // Step 2: Follow the redirect to complete authorization |
| 331 | + authReturnResp, err := client.Get(serverURL + location) |
| 332 | + require.NoError(t, err) |
| 333 | + defer authReturnResp.Body.Close() |
| 334 | + |
| 335 | + require.Contains(t, []int{http.StatusFound, http.StatusSeeOther}, authReturnResp.StatusCode, |
| 336 | + "expected redirect with authorization code, got %d", authReturnResp.StatusCode) |
| 337 | + callbackLocation := authReturnResp.Header.Get("Location") |
| 338 | + require.NotEmpty(t, callbackLocation) |
| 339 | + |
| 340 | + callbackURL, err := url.Parse(callbackLocation) |
| 341 | + require.NoError(t, err) |
| 342 | + require.NotEmpty(t, callbackURL.Query().Get("code"), "callback URL should contain an authorization code") |
| 343 | + |
| 344 | + return callbackURL |
| 345 | +} |
| 346 | + |
| 347 | +func TestAuthWithoutState(t *testing.T) { |
| 348 | + server, _, _ := setupTestServer(t) |
| 349 | + regResp := registerTestClient(t, server.URL) |
| 350 | + |
| 351 | + // Build authorization URL manually WITHOUT a state parameter |
| 352 | + authURL := fmt.Sprintf("%s%s?response_type=code&client_id=%s&redirect_uri=%s", |
| 353 | + server.URL, AuthorizationEndpoint, regResp.ClientID, |
| 354 | + url.QueryEscape("http://localhost:8080/callback")) |
| 355 | + |
| 356 | + callbackURL := testAuthFlowWithURL(t, server.URL, authURL) |
| 357 | + |
| 358 | + // Server should have generated a state and echoed it back |
| 359 | + require.NotEmpty(t, callbackURL.Query().Get("state"), "server should generate a state when client omits it") |
| 360 | + |
| 361 | + // Exchange authorization code for tokens |
| 362 | + code := callbackURL.Query().Get("code") |
| 363 | + tokenReq := url.Values{} |
| 364 | + tokenReq.Set("grant_type", "authorization_code") |
| 365 | + tokenReq.Set("code", code) |
| 366 | + tokenReq.Set("redirect_uri", "http://localhost:8080/callback") |
| 367 | + tokenReq.Set("client_id", regResp.ClientID) |
| 368 | + tokenReq.Set("client_secret", regResp.ClientSecret) |
| 369 | + |
| 370 | + tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq) |
| 371 | + require.NoError(t, err) |
| 372 | + defer tokenResp.Body.Close() |
| 373 | + |
| 374 | + require.Equal(t, http.StatusOK, tokenResp.StatusCode) |
| 375 | + |
| 376 | + var tokenResult map[string]any |
| 377 | + err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult) |
| 378 | + require.NoError(t, err) |
| 379 | + require.NotEmpty(t, tokenResult["access_token"]) |
| 380 | +} |
| 381 | + |
| 382 | +func TestAuthWithEmptyState(t *testing.T) { |
| 383 | + server, _, _ := setupTestServer(t) |
| 384 | + regResp := registerTestClient(t, server.URL) |
| 385 | + |
| 386 | + // Build authorization URL with an empty state parameter |
| 387 | + authURL := fmt.Sprintf("%s%s?response_type=code&client_id=%s&redirect_uri=%s&state=", |
| 388 | + server.URL, AuthorizationEndpoint, regResp.ClientID, |
| 389 | + url.QueryEscape("http://localhost:8080/callback")) |
| 390 | + |
| 391 | + callbackURL := testAuthFlowWithURL(t, server.URL, authURL) |
| 392 | + |
| 393 | + // Server should have generated a state and echoed it back |
| 394 | + require.NotEmpty(t, callbackURL.Query().Get("state"), "server should generate a state when client sends empty state") |
| 395 | + |
| 396 | + // Exchange authorization code for tokens |
| 397 | + code := callbackURL.Query().Get("code") |
| 398 | + tokenReq := url.Values{} |
| 399 | + tokenReq.Set("grant_type", "authorization_code") |
| 400 | + tokenReq.Set("code", code) |
| 401 | + tokenReq.Set("redirect_uri", "http://localhost:8080/callback") |
| 402 | + tokenReq.Set("client_id", regResp.ClientID) |
| 403 | + tokenReq.Set("client_secret", regResp.ClientSecret) |
| 404 | + |
| 405 | + tokenResp, err := http.PostForm(server.URL+TokenEndpoint, tokenReq) |
| 406 | + require.NoError(t, err) |
| 407 | + defer tokenResp.Body.Close() |
| 408 | + |
| 409 | + require.Equal(t, http.StatusOK, tokenResp.StatusCode) |
| 410 | + |
| 411 | + var tokenResult map[string]any |
| 412 | + err = json.NewDecoder(tokenResp.Body).Decode(&tokenResult) |
| 413 | + require.NoError(t, err) |
| 414 | + require.NotEmpty(t, tokenResult["access_token"]) |
| 415 | +} |
0 commit comments