Skip to content

Commit d0ba007

Browse files
committed
Add middleware tests
1 parent d79de0f commit d0ba007

1 file changed

Lines changed: 108 additions & 0 deletions

File tree

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
2+
3+
import pytest
4+
from starlette.types import ASGIApp
5+
6+
from blueapi.config import ApplicationConfig
7+
from blueapi.service.middleware import (
8+
API_VERSION,
9+
CONTEXT_HEADER,
10+
VENDOR_CONTEXT_HEADER,
11+
VERSION,
12+
ObservabilityContextPropagator,
13+
VersionHeaders,
14+
)
15+
16+
17+
@pytest.fixture
18+
def app():
19+
return AsyncMock(spec=ASGIApp)
20+
21+
22+
@pytest.mark.parametrize(
23+
"protocol,message_type",
24+
[("http", "http.response.start"), ("websocket", "websocket.accept")],
25+
)
26+
async def test_version_headers_added(app: Mock, protocol: str, message_type: str):
27+
vh = VersionHeaders(app)
28+
29+
send = AsyncMock()
30+
scope = {"type": protocol}
31+
await vh(scope, Mock(), send)
32+
33+
# the middleware wraps the send function so we need to extract the function
34+
# the app was actually called with
35+
local_send = app.call_args[0][2]
36+
37+
# Calling the wrapped send method here is equivalent to what the downstream
38+
# framework would do after the middleware has done its thing
39+
message = {"type": message_type, "headers": []}
40+
await local_send(message)
41+
42+
# Check the headers were sent to the original send method
43+
send.assert_called_once_with(
44+
{"type": message_type, "headers": [VERSION, API_VERSION]}
45+
)
46+
47+
48+
async def test_version_headers_ignore_non_http_or_websockets(app: Mock):
49+
vh = VersionHeaders(app)
50+
51+
scope = {"type": "other"}
52+
send = Mock()
53+
recv = Mock()
54+
55+
await vh(scope, recv, send)
56+
57+
# for non-http/ws requests, the original args are passed directly
58+
app.assert_called_once_with(scope, recv, send)
59+
60+
61+
async def test_obs_context_ignores_non_http_or_websockets(app: Mock):
62+
ocp = ObservabilityContextPropagator(app)
63+
64+
scope = MagicMock()
65+
scope.__getitem__.side_effect = {"type": "other"}.__getitem__
66+
67+
with patch("blueapi.service.middleware.attach") as att:
68+
await ocp(scope, Mock(), Mock())
69+
70+
att.assert_not_called()
71+
scope.get.assert_not_called()
72+
73+
74+
@pytest.mark.parametrize("protocol", ["http", "websocket"])
75+
async def test_obs_context_passes_context(app: Mock, protocol: str):
76+
ocp = ObservabilityContextPropagator(app)
77+
scope = {"type": protocol, "headers": ((CONTEXT_HEADER, b"req_context"),)}
78+
79+
with patch("blueapi.service.middleware.attach") as att:
80+
with patch("blueapi.service.middleware.get_global_textmap") as get_global:
81+
get_global.return_value.extract.side_effect = lambda x: x
82+
await ocp(scope, Mock(), Mock())
83+
84+
att.assert_called_once_with({ApplicationConfig.CONTEXT_HEADER: "req_context"})
85+
86+
87+
@pytest.mark.parametrize("protocol", ["http", "websocket"])
88+
async def test_obs_context_passes_vendor_context(app: Mock, protocol: str):
89+
ocp = ObservabilityContextPropagator(app)
90+
scope = {
91+
"type": protocol,
92+
"headers": (
93+
(CONTEXT_HEADER, b"req_context"),
94+
(VENDOR_CONTEXT_HEADER, b"vendor_context"),
95+
),
96+
}
97+
98+
with patch("blueapi.service.middleware.attach") as att:
99+
with patch("blueapi.service.middleware.get_global_textmap") as get_global:
100+
get_global.return_value.extract.side_effect = lambda x: x
101+
await ocp(scope, Mock(), Mock())
102+
103+
att.assert_called_once_with(
104+
{
105+
ApplicationConfig.CONTEXT_HEADER: "req_context",
106+
ApplicationConfig.VENDOR_CONTEXT_HEADER: "vendor_context",
107+
}
108+
)

0 commit comments

Comments
 (0)