-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathtest_main.py
More file actions
81 lines (67 loc) · 2.31 KB
/
test_main.py
File metadata and controls
81 lines (67 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from unittest import mock
from unittest.mock import Mock, call
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from blueapi import __version__
from blueapi.config import ApplicationConfig
from blueapi.service.main import (
add_version_headers,
get_passthrough_headers,
log_request_details,
)
async def test_add_version_header():
    """Check that the version middleware sets both version headers.

    A throwaway FastAPI app exposing a single GET route is wrapped with
    ``add_version_headers`` and queried through ``TestClient``; the response
    must carry the REST API version and the blueapi package version.
    """
    api = FastAPI()
    # Register the middleware under test on the "http" hook.
    api.middleware("http")(add_version_headers)

    @api.get("/")
    async def root():
        return {"message": "Hello World"}

    received = TestClient(api).get("/").headers

    assert received["X-API-VERSION"] == ApplicationConfig.REST_API_VERSION
    assert received["X-BlueAPI-VERSION"] == __version__
@pytest.mark.parametrize("path,level", [("/", "info"), ("/healthz", "debug")])
async def test_log_request_details(path: str, level: str):
    """Check that the request-logging middleware logs at the expected level.

    For each parametrized route, a POST with body ``foo`` must produce two
    calls on the patched logger's ``level`` method: first the request line,
    then the request line followed by the 200 status. Both calls carry the raw
    request body in ``extra``.
    """
    body_extra = {"request_body": b"foo"}
    expected_calls = [
        call(f"testclient:50000 POST {path}", extra=body_extra),
        call(f"testclient:50000 POST {path} 200", extra=body_extra),
    ]

    # Swap out the module-level logger so the emitted records can be inspected.
    with mock.patch("blueapi.service.main.LOGGER") as logger:
        api = FastAPI()
        api.middleware("http")(log_request_details)

        @api.post(path)
        async def root():
            return {"message": "Hello World"}

        reply = TestClient(api).post(path, content="foo")

        assert reply.status_code == 200
        getattr(logger, level).assert_has_calls(expected_calls)
@pytest.mark.parametrize(
    "headers, expected_headers",
    [
        ({}, {}),
        ({"foo": "bar"}, {}),
        ({"authorization": "yes"}, {"authorization": "yes"}),
        ({"autHORIzation": "yes"}, {"autHORIzation": "yes"}),
        ({"autHORIzation": "yes", "foo": "bar"}, {"autHORIzation": "yes"}),
        ({"autHORIzation": ""}, {"autHORIzation": ""}),
    ],
)
def test_get_passthrough_headers(
    headers: dict[str, str], expected_headers: dict[str, str]
):
    """Check which request headers ``get_passthrough_headers`` keeps.

    The parametrized cases expect the authorization header to be returned
    regardless of its letter casing (with the original casing and value
    preserved, including an empty value) and unrelated headers to be dropped.
    """
    # A spec'd mock stands in for a real request; only ``headers`` is populated.
    fake_request = Mock(spec=Request, headers=headers)

    forwarded = get_passthrough_headers(fake_request)

    assert forwarded == expected_headers