Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 30 additions & 35 deletions bq/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import platform
import socketserver
import sys
import threading
import typing
Expand All @@ -13,6 +14,7 @@
from importlib.metadata import version
from wsgiref.simple_server import make_server
from wsgiref.simple_server import WSGIRequestHandler
from wsgiref.simple_server import WSGIServer

import venusian
from sqlalchemy import func
Expand Down Expand Up @@ -52,6 +54,8 @@ def log_message(self, format, *args):
)
)

class ThreadingWSGIServer(socketserver.ThreadingMixIn, WSGIServer):
    """WSGI server that handles every incoming request on its own thread."""

    # Mark the per-request threads as daemon threads so that lingering
    # connections never keep the process alive once the main thread exits.
    daemon_threads = True

class BeanQueue:
def __init__(
Expand All @@ -70,6 +74,9 @@ def __init__(
self._worker_update_shutdown_event: threading.Event = threading.Event()
# noop if metrics thread is not started yet, shutdown if it is started
self._metrics_server_shutdown: typing.Callable[[], None] = lambda: None
# Health state as atomic tuple: (is_ok, info_dict)
# Written by heartbeat/main threads, read by HTTP handler threads
self._health_state: tuple[bool, dict] = (False, {})

def create_default_engine(self):
# Use thread-safe connection pool when thread pool executor is enabled
Expand Down Expand Up @@ -195,6 +202,7 @@ def update_workers(
db.commit()

if current_worker.state != models.WorkerState.RUNNING:
self._health_state = (False, {"state": str(current_worker.state)})
# This probably means we are somehow very slow to update the heartbeat in time, or the timeout window
# is set too short. It could also be that the administrator updated the worker state to something other than
# RUNNING. Regardless of the reason, let's stop processing.
Expand All @@ -214,51 +222,35 @@ def update_workers(
current_worker.last_heartbeat = func.now()
db.add(current_worker)
db.commit()
self._health_state = (
current_worker.state == models.WorkerState.RUNNING,
{"state": str(current_worker.state)},
)

def _serve_http_request(
    self, worker_id: typing.Any, environ: dict, start_response: typing.Callable
) -> list[bytes]:
    """Serve one WSGI request for the metrics HTTP server.

    ``/healthz`` is answered only from the cached ``self._health_state``
    tuple ``(is_ok, info)``; no database session is opened while handling
    the request. Every other path gets a JSON 404.

    Args:
        worker_id: Identifier of the current worker; echoed back as a string.
        environ: WSGI environ of the request. Only ``PATH_INFO`` is read.
        start_response: WSGI ``start_response`` callable.

    Returns:
        A single-element list holding the UTF-8 encoded JSON body.
    """
    json_headers = [("Content-Type", "application/json")]

    def render(payload: dict) -> list[bytes]:
        return [json.dumps(payload).encode("utf8")]

    # PEP 3333 allows PATH_INFO to be omitted when it would be empty, so
    # treat a missing key as an unknown path instead of raising KeyError.
    if environ.get("PATH_INFO", "") != "/healthz":
        # TODO: add other metrics endpoints
        start_response("404 NOT FOUND", json_headers)
        return render(dict(status="not found"))

    # Read the tuple once so both halves come from the same snapshot.
    health_ok, health_info = self._health_state
    payload = {
        "status": "ok" if health_ok else "error",
        "worker_id": str(worker_id),
    }
    # Merge the extra info without letting it override the reserved keys;
    # passing it as **kwargs would raise TypeError on a key collision.
    for key, value in health_info.items():
        payload.setdefault(key, value)

    start_response(
        "200 OK" if health_ok else "500 Internal Server Error",
        json_headers,
    )
    return render(payload)

def run_metrics_http_server(self, worker_id: typing.Any):
Expand All @@ -269,6 +261,7 @@ def run_metrics_http_server(self, worker_id: typing.Any):
port,
functools.partial(self._serve_http_request, worker_id),
handler_class=WSGIRequestHandlerWithLogger,
server_class=ThreadingWSGIServer,
) as httpd:
# expose graceful shutdown to the main thread
self._metrics_server_shutdown = httpd.shutdown
Expand Down Expand Up @@ -475,6 +468,7 @@ def process_tasks(
db.add(worker)
dispatch_service.listen(channels)
db.commit()
self._health_state = (True, {"state": "RUNNING"})

metrics_server_thread = None
if self.config.METRICS_HTTP_SERVER_ENABLED:
Expand Down Expand Up @@ -538,6 +532,7 @@ def process_tasks(
)
except (SystemExit, KeyboardInterrupt):
db.rollback()
self._health_state = (False, {})
logger.info("Shutting down ...")

# Shutdown the executor if it was created
Expand Down
100 changes: 100 additions & 0 deletions tests/unit/test_healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import json
from unittest.mock import MagicMock

import pytest

from bq.app import BeanQueue
from bq.config import Config


def _make_environ(path: str) -> dict:
    """Return a bare-bones WSGI environ describing a GET request for *path*."""
    return dict(PATH_INFO=path, REQUEST_METHOD="GET")


@pytest.fixture
def bq():
    """Provide a BeanQueue built from a real Config.

    The database URL is only a placeholder; the tests in this module are
    written on the assumption that no connection is ever opened through it.
    """
    config = Config(DATABASE_URL="postgresql://test@localhost/test")
    return BeanQueue(config=config)

class TestHealthzEndpoint:
    """Tests for the /healthz HTTP handler."""

    JSON_HEADERS = [("Content-Type", "application/json")]

    @staticmethod
    def _request(bq, worker_id: str, path: str):
        """Call the handler once.

        Returns the ``start_response`` mock and the decoded JSON body so each
        test only states what it expects.
        """
        start_response = MagicMock()
        chunks = bq._serve_http_request(
            worker_id, _make_environ(path), start_response
        )
        return start_response, json.loads(chunks[0])

    def test_healthz_returns_200_when_healthy(self, bq):
        bq._health_state = (True, {"state": "RUNNING"})

        start_response, body = self._request(bq, "42", "/healthz")

        start_response.assert_called_once_with("200 OK", self.JSON_HEADERS)
        assert body["status"] == "ok"
        assert body["worker_id"] == "42"
        assert body["state"] == "RUNNING"

    def test_healthz_returns_500_when_unhealthy(self, bq):
        bq._health_state = (False, {"state": "SHUTDOWN"})

        start_response, body = self._request(bq, "42", "/healthz")

        start_response.assert_called_once_with(
            "500 Internal Server Error", self.JSON_HEADERS
        )
        assert body["status"] == "error"
        assert body["worker_id"] == "42"
        assert body["state"] == "SHUTDOWN"

    def test_healthz_returns_500_before_worker_initialized(self, bq):
        """Before process_tasks runs, _health_state is ``(False, {})``."""
        start_response, body = self._request(bq, "1", "/healthz")

        start_response.assert_called_once_with(
            "500 Internal Server Error", self.JSON_HEADERS
        )
        assert body["status"] == "error"
        assert body["worker_id"] == "1"
        # The info dict is still empty, so no state may be reported yet.
        assert "state" not in body

    def test_healthz_does_not_create_db_session(self, bq):
        """The critical fix: /healthz must never touch the DB."""
        bq._health_state = (True, {"state": "RUNNING"})
        bq.make_session = MagicMock()

        self._request(bq, "42", "/healthz")

        bq.make_session.assert_not_called()

    def test_unknown_path_returns_404(self, bq):
        start_response, body = self._request(bq, "42", "/unknown")

        start_response.assert_called_once_with(
            "404 NOT FOUND", self.JSON_HEADERS
        )
        assert body["status"] == "not found"

    def test_404_does_not_create_db_session(self, bq):
        bq.make_session = MagicMock()

        self._request(bq, "42", "/anything")

        bq.make_session.assert_not_called()


class TestHealthStateInitialization:
    """Tests for the initial value of ``_health_state``."""

    def test_defaults_to_unhealthy(self, bq):
        # A freshly constructed instance reports "not ok" with no extra info.
        unhealthy_without_info = (False, {})
        assert bq._health_state == unhealthy_without_info