Skip to content

Commit 2c2675b

Browse files
tpoliawabbiemery
authored andcommitted
refactor: Rewrite server middleware with websocket support
The builtin fastapi support for middleware only supports http/rest requests. To enable the same middleware for websockets, a new implementation of the starlette middleware is required and it makes sense to use the same implementation for both protocols. For rest requests, there should not be any change in behaviour from this change.
1 parent 53af457 commit 2c2675b

3 files changed

Lines changed: 67 additions & 38 deletions

File tree

src/blueapi/service/main.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,19 @@
2525
get_tracer,
2626
start_as_current_span,
2727
)
28-
from opentelemetry.context import attach
2928
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
30-
from opentelemetry.propagate import get_global_textmap
3129
from opentelemetry.trace import get_tracer_provider
3230
from pydantic import ValidationError
3331
from pydantic.json_schema import SkipJsonSchema
3432
from starlette.responses import JSONResponse
3533
from super_state_machine.errors import TransitionError
3634

37-
from blueapi import __version__
3835
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
3936
from blueapi.service import interface
37+
from blueapi.service.middleware import (
38+
ObservabilityContextPropagator,
39+
VersionHeaders,
40+
)
4041
from blueapi.worker import TrackableTask, WorkerState
4142
from blueapi.worker.event import TaskStatusEnum
4243

@@ -126,8 +127,9 @@ def get_app(config: ApplicationConfig):
126127
app.include_router(secure_router_v1, dependencies=dependencies)
127128
app.add_exception_handler(KeyError, on_key_error_404)
128129
app.add_exception_handler(jwt.PyJWTError, on_token_error_401)
129-
app.middleware("http")(add_version_headers)
130-
app.middleware("http")(inject_propagated_observability_context)
130+
131+
app.add_middleware(ObservabilityContextPropagator)
132+
app.add_middleware(VersionHeaders)
131133
app.middleware("http")(log_request_details)
132134
if config.api.cors:
133135
app.add_middleware(
@@ -607,15 +609,6 @@ def start(config: ApplicationConfig):
607609
)
608610

609611

610-
async def add_version_headers(
611-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
612-
):
613-
response = await call_next(request)
614-
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
615-
response.headers["X-BlueAPI-Version"] = __version__
616-
return response
617-
618-
619612
async def log_request_details(
620613
request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
621614
) -> Response:
@@ -637,25 +630,3 @@ async def log_request_details(
637630
LOGGER.info(log_message, extra=extra)
638631

639632
return response
640-
641-
642-
async def inject_propagated_observability_context(
643-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
644-
) -> Response:
645-
"""Middleware to extract any propagated observability context from the
646-
HTTP headers and attach it to the local one.
647-
"""
648-
headers = request.headers
649-
if ApplicationConfig.CONTEXT_HEADER in headers:
650-
carrier = {
651-
ApplicationConfig.CONTEXT_HEADER: headers[ApplicationConfig.CONTEXT_HEADER]
652-
}
653-
if ApplicationConfig.VENDOR_CONTEXT_HEADER in headers:
654-
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = headers[
655-
ApplicationConfig.VENDOR_CONTEXT_HEADER
656-
]
657-
ctx = get_global_textmap().extract(carrier)
658-
659-
attach(ctx)
660-
response = await call_next(request)
661-
return response

src/blueapi/service/middleware.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import logging
2+
3+
from opentelemetry.context import attach
4+
from opentelemetry.propagate import get_global_textmap
5+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
6+
7+
from blueapi import __version__
8+
from blueapi.config import ApplicationConfig
9+
10+
OBS_LOGGER = logging.getLogger("blueapi.service.middleware.observability")
11+
12+
CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode()
13+
VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode()
14+
15+
API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8"))
16+
VERSION = (b"x-blueapi-version", __version__.encode("utf-8"))
17+
18+
19+
class VersionHeaders:
20+
def __init__(self, app: ASGIApp):
21+
self.app = app
22+
23+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
24+
if scope.get("type") not in ("websocket", "http"):
25+
return await self.app(scope, receive, send)
26+
27+
async def local_send(message: Message):
28+
if message["type"] in ("websocket.accept", "http.response.start"):
29+
message["headers"].append(VERSION)
30+
message["headers"].append(API_VERSION)
31+
await send(message)
32+
33+
await self.app(scope, receive, local_send)
34+
35+
36+
class ObservabilityContextPropagator:
37+
def __init__(self, app: ASGIApp):
38+
self.app = app
39+
40+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
41+
if scope["type"] not in ("http", "websocket"):
42+
return await self.app(scope, receive, send)
43+
44+
ctx = None
45+
v_ctx = None
46+
for key, val in scope.get("headers", ()):
47+
if key == CONTEXT_HEADER:
48+
ctx = val.decode()
49+
elif key == VENDOR_CONTEXT_HEADER:
50+
v_ctx = val.decode()
51+
if ctx:
52+
OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx)
53+
carrier = {ApplicationConfig.CONTEXT_HEADER: ctx}
54+
if v_ctx:
55+
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx
56+
attach(get_global_textmap().extract(carrier))
57+
58+
await self.app(scope, receive, send)

tests/unit_tests/service/test_main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from blueapi import __version__
99
from blueapi.config import ApplicationConfig
1010
from blueapi.service.main import (
11-
add_version_headers,
1211
get_passthrough_headers,
1312
log_request_details,
1413
)
14+
from blueapi.service.middleware import VersionHeaders
1515

1616

1717
async def test_add_version_header():
1818
app = FastAPI()
19-
app.middleware("http")(add_version_headers)
19+
app.add_middleware(VersionHeaders)
2020

2121
@app.get("/")
2222
async def root():

0 commit comments

Comments
 (0)