Skip to content

Commit d79de0f

Browse files
committed
refactor: Rewrite server middleware with websocket support
The builtin fastapi support for middleware only supports http/rest requests. To enable the same middleware for websockets, a new implementation of the starlette middleware is required and it makes sense to use the same implementation for both protocols. For rest requests, there should not be any change in behaviour from this change.
1 parent 107d54e commit d79de0f

3 files changed

Lines changed: 67 additions & 38 deletions

File tree

src/blueapi/service/main.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,19 @@
2525
get_tracer,
2626
start_as_current_span,
2727
)
28-
from opentelemetry.context import attach
2928
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
30-
from opentelemetry.propagate import get_global_textmap
3129
from opentelemetry.trace import get_tracer_provider
3230
from pydantic import ValidationError
3331
from pydantic.json_schema import SkipJsonSchema
3432
from starlette.responses import JSONResponse
3533
from super_state_machine.errors import TransitionError
3634

37-
from blueapi import __version__
3835
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
3936
from blueapi.service import interface
37+
from blueapi.service.middleware import (
38+
ObservabilityContextPropagator,
39+
VersionHeaders,
40+
)
4041
from blueapi.worker import TrackableTask, WorkerState
4142
from blueapi.worker.event import TaskStatusEnum
4243

@@ -124,8 +125,9 @@ def get_app(config: ApplicationConfig):
124125
app.include_router(secure_router, dependencies=dependencies)
125126
app.add_exception_handler(KeyError, on_key_error_404)
126127
app.add_exception_handler(jwt.PyJWTError, on_token_error_401)
127-
app.middleware("http")(add_version_headers)
128-
app.middleware("http")(inject_propagated_observability_context)
128+
129+
app.add_middleware(ObservabilityContextPropagator)
130+
app.add_middleware(VersionHeaders)
129131
app.middleware("http")(log_request_details)
130132
if config.api.cors:
131133
app.add_middleware(
@@ -569,15 +571,6 @@ def start(config: ApplicationConfig):
569571
)
570572

571573

572-
async def add_version_headers(
573-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
574-
):
575-
response = await call_next(request)
576-
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
577-
response.headers["X-BlueAPI-Version"] = __version__
578-
return response
579-
580-
581574
async def log_request_details(
582575
request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
583576
) -> Response:
@@ -599,25 +592,3 @@ async def log_request_details(
599592
LOGGER.info(log_message, extra=extra)
600593

601594
return response
602-
603-
604-
async def inject_propagated_observability_context(
605-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
606-
) -> Response:
607-
"""Middleware to extract any propagated observability context from the
608-
HTTP headers and attach it to the local one.
609-
"""
610-
headers = request.headers
611-
if ApplicationConfig.CONTEXT_HEADER in headers:
612-
carrier = {
613-
ApplicationConfig.CONTEXT_HEADER: headers[ApplicationConfig.CONTEXT_HEADER]
614-
}
615-
if ApplicationConfig.VENDOR_CONTEXT_HEADER in headers:
616-
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = headers[
617-
ApplicationConfig.VENDOR_CONTEXT_HEADER
618-
]
619-
ctx = get_global_textmap().extract(carrier)
620-
621-
attach(ctx)
622-
response = await call_next(request)
623-
return response

src/blueapi/service/middleware.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import logging
2+
3+
from opentelemetry.context import attach
4+
from opentelemetry.propagate import get_global_textmap
5+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
6+
7+
from blueapi import __version__
8+
from blueapi.config import ApplicationConfig
9+
10+
OBS_LOGGER = logging.getLogger("blueapi.service.middleware.observability")
11+
12+
CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode()
13+
VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode()
14+
15+
API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8"))
16+
VERSION = (b"x-blueapi-version", __version__.encode("utf-8"))
17+
18+
19+
class VersionHeaders:
20+
def __init__(self, app: ASGIApp):
21+
self.app = app
22+
23+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
24+
if scope.get("type") not in ("websocket", "http"):
25+
return await self.app(scope, receive, send)
26+
27+
async def local_send(message: Message):
28+
if message["type"] in ("websocket.accept", "http.response.start"):
29+
message["headers"].append(VERSION)
30+
message["headers"].append(API_VERSION)
31+
await send(message)
32+
33+
await self.app(scope, receive, local_send)
34+
35+
36+
class ObservabilityContextPropagator:
37+
def __init__(self, app: ASGIApp):
38+
self.app = app
39+
40+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
41+
if scope["type"] not in ("http", "websocket"):
42+
return await self.app(scope, receive, send)
43+
44+
ctx = None
45+
v_ctx = None
46+
for key, val in scope.get("headers", ()):
47+
if key == CONTEXT_HEADER:
48+
ctx = val.decode()
49+
elif key == VENDOR_CONTEXT_HEADER:
50+
v_ctx = val.decode()
51+
if ctx:
52+
OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx)
53+
carrier = {ApplicationConfig.CONTEXT_HEADER: ctx}
54+
if v_ctx:
55+
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx
56+
attach(get_global_textmap().extract(carrier))
57+
58+
await self.app(scope, receive, send)

tests/unit_tests/service/test_main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from blueapi import __version__
99
from blueapi.config import ApplicationConfig
1010
from blueapi.service.main import (
11-
add_version_headers,
1211
get_passthrough_headers,
1312
log_request_details,
1413
)
14+
from blueapi.service.middleware import VersionHeaders
1515

1616

1717
async def test_add_version_header():
1818
app = FastAPI()
19-
app.middleware("http")(add_version_headers)
19+
app.add_middleware(VersionHeaders)
2020

2121
@app.get("/")
2222
async def root():

0 commit comments

Comments
 (0)