Skip to content

Commit 420e73f

Browse files
committed
Add websocket support to middleware
1 parent be070c4 commit 420e73f

3 files changed

Lines changed: 69 additions & 14 deletions

File tree

src/blueapi/service/main.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@
3636
from starlette.responses import JSONResponse
3737
from super_state_machine.errors import TransitionError
3838

39-
from blueapi import __version__
4039
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
4140
from blueapi.core.bluesky_types import DataEvent
4241
from blueapi.service import interface
4342
from blueapi.service.authentication import CommonHttpOAuth
43+
from blueapi.service.middleware import (
44+
ObservabilityContextPropagator,
45+
VersionHeaders,
46+
)
4447
from blueapi.worker import TrackableTask, WorkerState
4548
from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent
4649
from blueapi.worker.worker_errors import WorkerBusyError
@@ -132,8 +135,9 @@ def get_app(config: ApplicationConfig):
132135
app.include_router(secure_router, dependencies=dependencies)
133136
app.add_exception_handler(KeyError, on_key_error_404)
134137
app.add_exception_handler(jwt.PyJWTError, on_token_error_401)
135-
app.middleware("http")(add_version_headers)
136-
app.middleware("http")(inject_propagated_observability_context)
138+
139+
app.add_middleware(ObservabilityContextPropagator)
140+
app.add_middleware(VersionHeaders)
137141
app.middleware("http")(log_request_details)
138142
if config.api.cors:
139143
app.add_middleware(
@@ -625,15 +629,6 @@ def start(config: ApplicationConfig):
625629
)
626630

627631

628-
async def add_version_headers(
629-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
630-
):
631-
response = await call_next(request)
632-
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
633-
response.headers["X-BlueAPI-Version"] = __version__
634-
return response
635-
636-
637632
async def log_request_details(
638633
request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
639634
) -> Response:

src/blueapi/service/middleware.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import logging
2+
3+
from opentelemetry.context import attach
4+
from opentelemetry.propagate import get_global_textmap
5+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
6+
7+
from blueapi import __version__
8+
from blueapi.config import ApplicationConfig
9+
10+
OBS_LOGGER = logging.getLogger("blueapi.service.middleware.obs")
11+
VER_LOGGER = logging.getLogger("blueapi.service.middleware.version")
12+
13+
CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode()
14+
VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode()
15+
16+
API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8"))
17+
VERSION = (b"x-blueapi-version", __version__.encode("utf-8"))
18+
19+
20+
class VersionHeaders:
21+
def __init__(self, app: ASGIApp):
22+
self.app = app
23+
24+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
25+
if scope.get("type") not in ("websocket", "http"):
26+
return await self.app(scope, receive, send)
27+
28+
async def local_send(message: Message):
29+
VER_LOGGER.info("message: %s", message)
30+
if message["type"] in ("websocket.accept", "http.response.start"):
31+
message["headers"].append(VERSION)
32+
message["headers"].append(API_VERSION)
33+
await send(message)
34+
35+
await self.app(scope, receive, local_send)
36+
37+
38+
class ObservabilityContextPropagator:
39+
def __init__(self, app: ASGIApp):
40+
self.app = app
41+
42+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
43+
if scope["type"] not in ("http", "websocket"):
44+
return await self.app(scope, receive, send)
45+
46+
ctx = None
47+
v_ctx = None
48+
for key, val in scope.get("headers", ()):
49+
if key == CONTEXT_HEADER:
50+
ctx = val.decode()
51+
elif key == VENDOR_CONTEXT_HEADER:
52+
v_ctx = val.decode()
53+
if ctx:
54+
OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx)
55+
carrier = {ApplicationConfig.CONTEXT_HEADER: ctx}
56+
if v_ctx:
57+
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx
58+
attach(get_global_textmap().extract(carrier))
59+
60+
await self.app(scope, receive, send)

tests/unit_tests/service/test_main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from blueapi import __version__
99
from blueapi.config import ApplicationConfig
1010
from blueapi.service.main import (
11-
add_version_headers,
1211
get_passthrough_headers,
1312
log_request_details,
1413
)
14+
from blueapi.service.middleware import VersionHeaders
1515

1616

1717
async def test_add_version_header():
1818
app = FastAPI()
19-
app.middleware("http")(add_version_headers)
19+
app.add_middleware(VersionHeaders)
2020

2121
@app.get("/")
2222
async def root():

0 commit comments

Comments
 (0)