diff --git a/.fernignore b/.fernignore
index 01c9184..38b5616 100644
--- a/.fernignore
+++ b/.fernignore
@@ -13,6 +13,7 @@ src/schematic/client.py
 scripts/
 src/schematic/cache/
 src/schematic/event_buffer.py
+src/schematic/event_capture.py
 src/schematic/http_client.py
 src/schematic/logging.py
 src/schematic/datastream/
diff --git a/README.md b/README.md
index cce6876..4d68255 100644
--- a/README.md
+++ b/README.md
@@ -642,6 +642,25 @@ client = Schematic("", config)
 client.check_flag("some-flag-key")  # Returns True
 ```

+You can also set flag defaults dynamically after the client has been constructed, using `set_flag_default` and `set_flag_defaults`. This is useful in automated testing contexts, where you may want to specify per-test flag values:
+
+```python
+from schematic.client import Schematic, SchematicConfig
+
+client = Schematic("", SchematicConfig(offline=True))
+
+# Set a single flag default
+client.set_flag_default("some-flag-key", True)
+
+# Or set multiple flag defaults at once
+client.set_flag_defaults({
+    "some-flag-key": True,
+    "another-flag-key": False,
+})
+
+client.check_flag("some-flag-key")  # Returns True
+```
+
 ### Timeouts

 By default, requests time out after 60 seconds. You can configure this with a timeout option at the client or request level.
diff --git a/src/schematic/client.py b/src/schematic/client.py
index 1f1827f..e621a14 100644
--- a/src/schematic/client.py
+++ b/src/schematic/client.py
@@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, List, Optional, Union

 import httpx

-
 from .base_client import AsyncBaseSchematic, BaseSchematic
 from .cache import DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL, AsyncCacheProvider, CacheProvider, LocalCache
 from .datastream import DataStreamClient, DataStreamClientOptions
 from .event_buffer import AsyncEventBuffer, EventBuffer
+from .event_capture import AsyncEventCaptureClient, EventCaptureClient
 from .http_client import AsyncOfflineHTTPClient, OfflineHTTPClient
 from .logging import get_default_logger
 from .types import (
@@ -59,6 +59,7 @@ class DataStreamConfig:
 class SchematicConfig:
     base_url: Optional[str] = None
     event_buffer_period: Optional[int] = None
+    event_capture_url: Optional[str] = None
     flag_defaults: Optional[Dict[str, bool]] = None
     follow_redirects: Optional[bool] = True
     httpx_client: Optional[httpx.Client] = None
@@ -83,8 +84,14 @@ def __init__(self, api_key: str, config: Optional[SchematicConfig] = None):
         self.event_buffer_period = config.event_buffer_period
         self.logger = config.logger or get_default_logger()
         self.flag_defaults = config.flag_defaults or {}
+        self.event_capture_client = EventCaptureClient(
+            api_key=api_key,
+            base_url=config.event_capture_url,
+            httpx_client=httpx_client,
+            get_headers=self._client_wrapper.get_headers,
+        )
         self.event_buffer = EventBuffer(
-            events_api=self.events,
+            event_sender=self.event_capture_client,
             logger=self.logger,
             period=self.event_buffer_period,
         )
@@ -101,6 +108,7 @@ def initialize(self) -> None:

     def shutdown(self) -> None:
         self.event_buffer.stop()
+        self.event_capture_client.close()

     def check_flag(
         self,
@@ -305,6 +313,12 @@ def _enqueue_event(self, event_type: str, body: EventBody) -> None:
     def _get_flag_default(self, flag_key: str) -> bool:
         return self.flag_defaults.get(flag_key, False)

+    def set_flag_default(self, flag_key: str, value: bool) -> None:
+        self.flag_defaults[flag_key] = value
+
+    def set_flag_defaults(self, values: Dict[str, bool]) -> None:
+        self.flag_defaults.update(values)
+
     def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] = None) -> bool:
         if options and options.default_value is not None:
             if callable(options.default_value):
@@ -317,6 +331,7 @@
 class AsyncSchematicConfig:
     base_url: Optional[str] = None
     event_buffer_period: Optional[int] = None
+    event_capture_url: Optional[str] = None
     flag_defaults: Optional[Dict[str, bool]] = None
     follow_redirects: Optional[bool] = True
     httpx_client: Optional[httpx.AsyncClient] = None
@@ -372,8 +387,14 @@ def __init__(self, api_key: str, config: Optional[AsyncSchematicConfig] = None):
         self.event_buffer_period = config.event_buffer_period
         self.logger = config.logger or get_default_logger()
         self.flag_defaults = config.flag_defaults or {}
+        self.event_capture_client = AsyncEventCaptureClient(
+            api_key=api_key,
+            base_url=config.event_capture_url,
+            httpx_client=httpx_client,
+            get_headers=self._client_wrapper.get_headers,
+        )
         self.event_buffer = AsyncEventBuffer(
-            events_api=self.events,
+            event_sender=self.event_capture_client,
             logger=self.logger,
             period=self.event_buffer_period,
         )
@@ -727,6 +748,12 @@ async def _enqueue_event(self, event_type: str, body: EventBody) -> None:
     def _get_flag_default(self, flag_key: str) -> bool:
         return self.flag_defaults.get(flag_key, False)

+    def set_flag_default(self, flag_key: str, value: bool) -> None:
+        self.flag_defaults[flag_key] = value
+
+    def set_flag_defaults(self, values: Dict[str, bool]) -> None:
+        self.flag_defaults.update(values)
+
     def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] = None) -> bool:
         if options and options.default_value is not None:
             if callable(options.default_value):
@@ -767,6 +794,7 @@ async def shutdown(self) -> None:
             # Flush and stop the event buffer
             await self.event_buffer.stop()
+            await self.event_capture_client.close()

             self.logger.info("Shutdown complete.")
         except Exception as e:
             self.logger.error(f"Error during shutdown: {e}")
diff --git a/src/schematic/event_buffer.py b/src/schematic/event_buffer.py
index 617bea5..cccf564 100644
--- a/src/schematic/event_buffer.py
+++ b/src/schematic/event_buffer.py
@@ -5,7 +5,7 @@
 import time
 from typing import List, Optional

-from .events.client import AsyncEventsClient, EventsClient
+from .event_capture import AsyncEventCaptureClient, EventCaptureClient
 from .types import CreateEventRequestBody

 DEFAULT_MAX_EVENTS = 100  # Default maximum number of events
@@ -17,7 +17,7 @@
 class EventBuffer:
     def __init__(
         self,
-        events_api: EventsClient,
+        event_sender: EventCaptureClient,
         logger: logging.Logger,
         period: Optional[int] = None,
         max_events: int = DEFAULT_MAX_EVENTS,
@@ -25,7 +25,7 @@ def __init__(
         initial_retry_delay: float = DEFAULT_INITIAL_RETRY_DELAY,
     ):
         self.events: List[CreateEventRequestBody] = []
-        self.events_api = events_api
+        self.event_sender = event_sender
         self.interval = period or DEFAULT_EVENT_BUFFER_PERIOD
         self.logger = logger
         self.max_events = max_events
@@ -61,7 +61,7 @@ def _process_events(self, events_to_process):
                 if retry_count > 0:
                     self.logger.info(f"Retrying event batch submission (attempt {retry_count} of {self.max_retries})")

-                self.events_api.create_event_batch(events=events_to_process)
+                self.event_sender.send_batch(events_to_process)
                 success = True

             except Exception as e:
@@ -125,7 +125,7 @@
 class AsyncEventBuffer:
     def __init__(
         self,
-        events_api: AsyncEventsClient,
+        event_sender: AsyncEventCaptureClient,
         logger: logging.Logger,
         period: Optional[int] = None,
         max_events: int = DEFAULT_MAX_EVENTS,
@@ -133,7 +133,7 @@ def __init__(
         initial_retry_delay: float = DEFAULT_INITIAL_RETRY_DELAY,
     ):
         self.events: List[CreateEventRequestBody] = []
-        self.events_api = events_api
+        self.event_sender = event_sender
         self.interval = period or DEFAULT_EVENT_BUFFER_PERIOD
         self.logger = logger
         self.max_events = max_events
@@ -168,7 +168,7 @@ async def _process_events_async(self, events_to_process):
                 if retry_count > 0:
                     self.logger.info(f"Retrying event batch submission (attempt {retry_count} of {self.max_retries})")

-                await self.events_api.create_event_batch(events=events_to_process)
+                await self.event_sender.send_batch(events_to_process)
                 success = True

             except Exception as e:
diff --git a/src/schematic/event_capture.py b/src/schematic/event_capture.py
new file mode 100644
index 0000000..8f7196f
--- /dev/null
+++ b/src/schematic/event_capture.py
@@ -0,0 +1,149 @@
+import datetime as dt
+import typing
+
+import httpx
+import pydantic
+from .core.pydantic_utilities import UniversalBaseModel
+from .types import CreateEventRequestBody, EventBody, EventType
+
+DEFAULT_EVENT_CAPTURE_BASE_URL = "https://c.schematichq.com"
+DEFAULT_TIMEOUT = 10.0
+
+
+class _CaptureEventPayload(UniversalBaseModel):
+    """Wire format for a single event sent to the capture service.
+
+    Mirrors the shape used by the Go/Ruby/C# SDKs: `type` (not `event_type`)
+    and an `api_key` field embedded on each event.
+    """
+
+    api_key: str = pydantic.Field()
+    body: typing.Optional[EventBody] = None
+    type: EventType = pydantic.Field()
+    sent_at: typing.Optional[dt.datetime] = None
+
+
+class _CaptureBatchPayload(UniversalBaseModel):
+    events: typing.List[_CaptureEventPayload]
+
+
+def _to_payload(event: CreateEventRequestBody, api_key: str) -> _CaptureEventPayload:
+    return _CaptureEventPayload(
+        api_key=api_key,
+        body=event.body,
+        type=event.event_type,
+        sent_at=event.sent_at,
+    )
+
+
+def _build_endpoint(base_url: str) -> str:
+    return base_url.rstrip("/") + "/batch"
+
+
+def _build_headers(
+    api_key: str,
+    get_headers: typing.Optional[typing.Callable[[], typing.Dict[str, str]]] = None,
+) -> typing.Dict[str, str]:
+    """Build the headers for capture-service requests.
+
+    Reuses the API client's header builder when provided so we send the same
+    X-Fern-* / User-Agent / X-Schematic-Api-Key headers as the REST client —
+    keeps observability consistent across both endpoints.
+    """
+    headers: typing.Dict[str, str] = {}
+    if get_headers is not None:
+        headers.update(get_headers())
+    else:
+        headers["X-Schematic-Api-Key"] = api_key
+    headers["Content-Type"] = "application/json"
+    return headers
+
+
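+# Illustrative shape of a serialized batch (hypothetical values; exclude_none
+# omits unset fields):
+#   {"events": [{"api_key": "sch_dev_...", "type": "identify", "body": {...}}]}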
+ """ + headers: typing.Dict[str, str] = {} + if get_headers is not None: + headers.update(get_headers()) + else: + headers["X-Schematic-Api-Key"] = api_key + headers["Content-Type"] = "application/json" + return headers + + +def _serialize_batch( + events: typing.List[CreateEventRequestBody], api_key: str +) -> str: + batch = _CaptureBatchPayload( + events=[_to_payload(e, api_key) for e in events], + ) + return batch.model_dump_json(by_alias=True, exclude_none=True) + + +class EventCaptureClient: + """HTTP client for sending events to the Schematic event capture service.""" + + def __init__( + self, + api_key: str, + base_url: typing.Optional[str] = None, + httpx_client: typing.Optional[httpx.Client] = None, + get_headers: typing.Optional[typing.Callable[[], typing.Dict[str, str]]] = None, + ): + self._api_key = api_key + self._base_url = base_url or DEFAULT_EVENT_CAPTURE_BASE_URL + self._owns_client = httpx_client is None + self._httpx_client = httpx_client or httpx.Client(timeout=DEFAULT_TIMEOUT) + self._get_headers = get_headers + + def send_batch(self, events: typing.List[CreateEventRequestBody]) -> None: + if not events: + return + + body = _serialize_batch(events, self._api_key) + response = self._httpx_client.post( + _build_endpoint(self._base_url), + content=body, + headers=_build_headers(self._api_key, self._get_headers), + ) + + if response.status_code < 200 or response.status_code >= 300: + raise RuntimeError( + f"capture service returned HTTP {response.status_code}: {response.text}" + ) + + def close(self) -> None: + if self._owns_client: + self._httpx_client.close() + + +class AsyncEventCaptureClient: + """Async HTTP client for sending events to the Schematic event capture service.""" + + def __init__( + self, + api_key: str, + base_url: typing.Optional[str] = None, + httpx_client: typing.Optional[httpx.AsyncClient] = None, + get_headers: typing.Optional[typing.Callable[[], typing.Dict[str, str]]] = None, + ): + self._api_key = api_key + self._base_url = base_url or DEFAULT_EVENT_CAPTURE_BASE_URL + self._owns_client = httpx_client is None + self._httpx_client = httpx_client or httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) + self._get_headers = get_headers + + async def send_batch(self, events: typing.List[CreateEventRequestBody]) -> None: + if not events: + return + + body = _serialize_batch(events, self._api_key) + response = await self._httpx_client.post( + _build_endpoint(self._base_url), + content=body, + headers=_build_headers(self._api_key, self._get_headers), + ) + + if response.status_code < 200 or response.status_code >= 300: + raise RuntimeError( + f"capture service returned HTTP {response.status_code}: {response.text}" + ) + + async def close(self) -> None: + if self._owns_client: + await self._httpx_client.aclose() diff --git a/tests/custom/test_event_buffer.py b/tests/custom/test_event_buffer.py index 7068f43..64407bd 100644 --- a/tests/custom/test_event_buffer.py +++ b/tests/custom/test_event_buffer.py @@ -1,7 +1,7 @@ import threading import time import unittest -from unittest.mock import MagicMock, patch, call +from unittest.mock import AsyncMock, MagicMock, patch, call import pytest import asyncio @@ -13,10 +13,10 @@ class TestEventBuffer(unittest.TestCase): def setUp(self): - self.mock_api = MagicMock() + self.mock_sender = MagicMock() self.mock_logger = MagicMock() self.event_buffer = EventBuffer( - events_api=self.mock_api, logger=self.mock_logger, period=1, max_events=5 + event_sender=self.mock_sender, logger=self.mock_logger, period=1, max_events=5 ) def 
tearDown(self): @@ -45,7 +45,7 @@ def test_flush(self): self.event_buffer.events = [event] self.event_buffer._flush() - self.mock_api.create_event_batch.assert_called_once_with(events=[event]) + self.mock_sender.send_batch.assert_called_once_with([event]) self.assertEqual(len(self.event_buffer.events), 0) @@ -59,10 +59,10 @@ def test_shutdown_flushes_remaining(self): Verify that stop() flushes buffered events even if batch isn't full. """ - mock_api = MagicMock() + mock_sender = MagicMock() mock_logger = MagicMock() buffer = EventBuffer( - events_api=mock_api, + event_sender=mock_sender, logger=mock_logger, period=10, # Long period so periodic flush won't trigger max_events=100, # Large batch so auto-flush won't trigger @@ -74,14 +74,14 @@ def test_shutdown_flushes_remaining(self): buffer.push(event) # No flush should have happened yet - mock_api.create_event_batch.assert_not_called() + mock_sender.send_batch.assert_not_called() # Stop the buffer, which should flush remaining events buffer.stop() # Verify all events were flushed - mock_api.create_event_batch.assert_called_once() - flushed_events = mock_api.create_event_batch.call_args.kwargs["events"] + mock_sender.send_batch.assert_called_once() + flushed_events = mock_sender.send_batch.call_args.args[0] self.assertEqual(len(flushed_events), 5) def test_push_after_shutdown_rejected(self): @@ -90,10 +90,10 @@ def test_push_after_shutdown_rejected(self): Mirrors Go/Node behavior — after shutdown, the buffer logs an error and refuses to accept new events rather than queuing them. """ - mock_api = MagicMock() + mock_sender = MagicMock() mock_logger = MagicMock() buffer = EventBuffer( - events_api=mock_api, + event_sender=mock_sender, logger=mock_logger, period=10, max_events=100, @@ -110,17 +110,17 @@ def test_push_after_shutdown_rejected(self): "Event buffer is stopped, not accepting new events" ) # And no API call should have been issued for the rejected event - mock_api.create_event_batch.assert_not_called() + mock_sender.send_batch.assert_not_called() def test_concurrent_push(self): """Corresponds to Go TestEventBuffer_ConcurrentPush. Verify no events are lost when pushing from multiple threads. 
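+        # Note: close() below only tears down the httpx client when this class
+        # created it; a caller-supplied client remains the caller's to manage.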
""" - mock_api = MagicMock() + mock_sender = MagicMock() mock_logger = MagicMock() buffer = EventBuffer( - events_api=mock_api, + event_sender=mock_sender, logger=mock_logger, period=10, # Long period to avoid periodic flush during test max_events=1000, # Large batch to avoid auto-flush @@ -152,8 +152,8 @@ def worker(): # Count total events sent total_sent = sum( - len(c.kwargs["events"]) - for c in mock_api.create_event_batch.call_args_list + len(c.args[0]) + for c in mock_sender.send_batch.call_args_list ) self.assertEqual(total_sent, total_expected) @@ -163,7 +163,7 @@ class TestAsyncEventBuffer: async def test_push_event(self): # Create a separate mock and buffer instance just for this test - mock_api = MagicMock() + mock_sender = AsyncMock() mock_logger = MagicMock() task_mock = MagicMock() @@ -171,7 +171,7 @@ async def test_push_event(self): with patch('asyncio.create_task', return_value=task_mock): # Then create the buffer, which uses create_task buffer = AsyncEventBuffer( - events_api=mock_api, logger=mock_logger, period=1, max_events=5 + event_sender=mock_sender, logger=mock_logger, period=1, max_events=5 ) # Test push event @@ -186,7 +186,7 @@ async def test_push_event(self): async def test_push_event_exceeding_max_events(self): # Create a separate mock and buffer instance just for this test - mock_api = MagicMock() + mock_sender = AsyncMock() mock_logger = MagicMock() task_mock = MagicMock() @@ -194,7 +194,7 @@ async def test_push_event_exceeding_max_events(self): with patch('asyncio.create_task', return_value=task_mock): # Then create the buffer, which uses create_task buffer = AsyncEventBuffer( - events_api=mock_api, logger=mock_logger, period=1, max_events=5 + event_sender=mock_sender, logger=mock_logger, period=1, max_events=5 ) # Setup test @@ -215,7 +215,7 @@ async def test_push_event_exceeding_max_events(self): async def test_flush(self): # Create a separate mock and buffer instance just for this test - mock_api = MagicMock() + mock_sender = AsyncMock() mock_logger = MagicMock() task_mock = MagicMock() @@ -223,7 +223,7 @@ async def test_flush(self): with patch('asyncio.create_task', return_value=task_mock): # Also patch the max_retries to 0 to disable retry behavior for this test buffer = AsyncEventBuffer( - events_api=mock_api, logger=mock_logger, period=1, max_events=5, + event_sender=mock_sender, logger=mock_logger, period=1, max_events=5, max_retries=0 # Disable retries for this specific test ) @@ -235,7 +235,7 @@ async def test_flush(self): await buffer._flush() # Verify expectations - mock_api.create_event_batch.assert_called_once_with(events=[event]) + mock_sender.send_batch.assert_called_once_with([event]) assert len(buffer.events) == 0 # Clean up @@ -244,7 +244,7 @@ async def test_flush(self): async def test_stop(self): # Create a separate mock and buffer instance just for this test - mock_api = MagicMock() + mock_sender = AsyncMock() mock_logger = MagicMock() task_mock = MagicMock() @@ -252,7 +252,7 @@ async def test_stop(self): with patch('asyncio.create_task', return_value=task_mock): # Then create the buffer, which uses create_task buffer = AsyncEventBuffer( - events_api=mock_api, logger=mock_logger, period=1, max_events=5 + event_sender=mock_sender, logger=mock_logger, period=1, max_events=5 ) # Test stopping the buffer @@ -264,13 +264,13 @@ async def test_stop(self): async def test_shutdown_flushes_remaining(self): """Corresponds to Go TestEventBuffer_ShutdownFlushesRemaining (async).""" - mock_api = MagicMock() + mock_sender = AsyncMock() mock_logger = 
MagicMock() task_mock = MagicMock() with patch('asyncio.create_task', return_value=task_mock): buffer = AsyncEventBuffer( - events_api=mock_api, + event_sender=mock_sender, logger=mock_logger, period=10, max_events=100, @@ -282,25 +282,25 @@ async def test_shutdown_flushes_remaining(self): event = MagicMock(spec=CreateEventRequestBody) await buffer.push(event) - mock_api.create_event_batch.assert_not_called() + mock_sender.send_batch.assert_not_called() # Stop should flush remaining events await buffer.stop() - mock_api.create_event_batch.assert_called_once() - flushed = mock_api.create_event_batch.call_args.kwargs["events"] + mock_sender.send_batch.assert_called_once() + flushed = mock_sender.send_batch.call_args.args[0] assert len(flushed) == 5 async def test_push_after_shutdown_rejected(self): """Async equivalent of TestEventBuffer.test_push_after_shutdown_rejected.""" - mock_api = MagicMock() + mock_sender = AsyncMock() mock_logger = MagicMock() task_mock = MagicMock() with patch('asyncio.create_task', return_value=task_mock): buffer = AsyncEventBuffer( - events_api=mock_api, + event_sender=mock_sender, logger=mock_logger, period=10, max_events=100, @@ -316,7 +316,7 @@ async def test_push_after_shutdown_rejected(self): mock_logger.error.assert_called_with( "Event buffer is stopped, not accepting new events" ) - mock_api.create_event_batch.assert_not_called() + mock_sender.send_batch.assert_not_called() if __name__ == "__main__": diff --git a/tests/custom/test_event_buffer_retry.py b/tests/custom/test_event_buffer_retry.py index 3e037cb..6434a0d 100644 --- a/tests/custom/test_event_buffer_retry.py +++ b/tests/custom/test_event_buffer_retry.py @@ -12,10 +12,10 @@ class TestEventBufferRetry(unittest.TestCase): """Test the retry mechanism in the synchronous EventBuffer.""" def setUp(self): - self.mock_api = MagicMock() + self.mock_sender = MagicMock() self.mock_logger = MagicMock() self.event_buffer = EventBuffer( - events_api=self.mock_api, logger=self.mock_logger, period=1, max_events=5 + event_sender=self.mock_sender, logger=self.mock_logger, period=1, max_events=5 ) def tearDown(self): @@ -27,7 +27,7 @@ def test_flush_with_retry(self): self.event_buffer.events = [event] # Configure mock to fail twice then succeed on third attempt - self.mock_api.create_event_batch.side_effect = [ + self.mock_sender.send_batch.side_effect = [ Exception("API failure 1"), Exception("API failure 2"), None # Success @@ -37,7 +37,7 @@ def test_flush_with_retry(self): self.event_buffer._flush() # Verify retry attempts - self.assertEqual(self.mock_api.create_event_batch.call_count, 3) + self.assertEqual(self.mock_sender.send_batch.call_count, 3) self.assertEqual(mock_sleep.call_count, 2) # Sleep called twice (between retries) # Verify events are cleared after success @@ -58,10 +58,10 @@ def test_exponential_backoff_timing(self): Mirrors the spec's "Event buffer exponential backoff timing is correct" check item — currently missing in every SDK. 
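+        # httpx does not raise on non-2xx responses by itself; raising here lets
+        # the event buffer's retry logic treat the whole batch as failed.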
""" - mock_api = MagicMock() + mock_sender = MagicMock() mock_logger = MagicMock() buffer = EventBuffer( - events_api=mock_api, + event_sender=mock_sender, logger=mock_logger, period=10, max_events=100, @@ -72,7 +72,7 @@ def test_exponential_backoff_timing(self): try: event = MagicMock(spec=CreateEventRequestBody) buffer.events = [event] - mock_api.create_event_batch.side_effect = Exception("always fails") + mock_sender.send_batch.side_effect = Exception("always fails") sleeps = [] with patch("time.sleep", side_effect=sleeps.append): @@ -101,13 +101,13 @@ def test_flush_with_max_retries_exhausted(self): self.event_buffer.events = [event] # Configure mock to always fail - self.mock_api.create_event_batch.side_effect = Exception("API failure") + self.mock_sender.send_batch.side_effect = Exception("API failure") with patch("time.sleep") as mock_sleep: # Mock sleep to speed up test self.event_buffer._flush() # Verify all retry attempts were made - self.assertEqual(self.mock_api.create_event_batch.call_count, DEFAULT_MAX_RETRIES + 1) + self.assertEqual(self.mock_sender.send_batch.call_count, DEFAULT_MAX_RETRIES + 1) self.assertEqual(mock_sleep.call_count, DEFAULT_MAX_RETRIES) # Verify events are cleared even after failure @@ -124,18 +124,18 @@ class TestAsyncEventBufferRetry: @pytest.fixture async def buffer_with_mock_periodic_flush(self): """Setup an AsyncEventBuffer with a mocked _periodic_flush function.""" - mock_api = AsyncMock() + mock_sender = AsyncMock() mock_logger = MagicMock() - + # Create the buffer buffer = AsyncEventBuffer( - events_api=mock_api, logger=mock_logger, period=1, max_events=5 + event_sender=mock_sender, logger=mock_logger, period=1, max_events=5 ) - + # Replace the _periodic_flush with a dummy coro that just returns immediately async def dummy_periodic_flush(): pass - + # Patch the method and cancel the task with patch.object(buffer, '_periodic_flush', dummy_periodic_flush): buffer.flush_task.cancel() # Cancel the original task to avoid warnings @@ -143,56 +143,56 @@ async def dummy_periodic_flush(): await buffer.flush_task except asyncio.CancelledError: pass - + # Create a new task with our dummy function buffer.flush_task = asyncio.create_task(dummy_periodic_flush()) - - yield buffer, mock_api, mock_logger - + + yield buffer, mock_sender, mock_logger + # Clean up after test buffer.flush_task.cancel() await buffer.stop() async def test_flush_with_retry(self, buffer_with_mock_periodic_flush): """Test that the AsyncEventBuffer retries failed API calls.""" - buffer, mock_api, mock_logger = buffer_with_mock_periodic_flush - + buffer, mock_sender, mock_logger = buffer_with_mock_periodic_flush + # Setup test data event = MagicMock(spec=CreateEventRequestBody) buffer.events = [event] - + # Configure mock to fail twice then succeed on third attempt - mock_api.create_event_batch.side_effect = [ + mock_sender.send_batch.side_effect = [ Exception("API failure 1"), Exception("API failure 2"), None # Success ] - + # Execute with mocked sleep with patch("asyncio.sleep", return_value=None) as mock_sleep: await buffer._flush() - + # Verify retry attempts - assert mock_api.create_event_batch.call_count == 3 + assert mock_sender.send_batch.call_count == 3 assert mock_sleep.call_count == 2 # Sleep called twice (between retries) - + # Verify events are cleared after success assert len(buffer.events) == 0 - + # Verify logging mock_logger.warning.assert_called() mock_logger.info.assert_called_with("Event batch submission succeeded after 2 retries") async def 
test_exponential_backoff_timing(self, buffer_with_mock_periodic_flush): """Async equivalent of TestEventBufferRetry.test_exponential_backoff_timing.""" - buffer, mock_api, _ = buffer_with_mock_periodic_flush + buffer, mock_sender, _ = buffer_with_mock_periodic_flush # Use canonical defaults: initial_retry_delay=1, max_retries=3 buffer.initial_retry_delay = 1 buffer.max_retries = 3 event = MagicMock(spec=CreateEventRequestBody) buffer.events = [event] - mock_api.create_event_batch.side_effect = Exception("always fails") + mock_sender.send_batch.side_effect = Exception("always fails") sleeps = [] @@ -214,28 +214,28 @@ async def fake_sleep(delay): async def test_flush_with_max_retries_exhausted(self, buffer_with_mock_periodic_flush): """Test that the AsyncEventBuffer gives up after max_retries attempts.""" - buffer, mock_api, mock_logger = buffer_with_mock_periodic_flush - + buffer, mock_sender, mock_logger = buffer_with_mock_periodic_flush + # Setup test data event = MagicMock(spec=CreateEventRequestBody) buffer.events = [event] - + # Configure mock to always fail - mock_api.create_event_batch.side_effect = Exception("API failure") - + mock_sender.send_batch.side_effect = Exception("API failure") + with patch("asyncio.sleep", return_value=None) as mock_sleep: await buffer._flush() - + # Verify all retry attempts were made - assert mock_api.create_event_batch.call_count == DEFAULT_MAX_RETRIES + 1 + assert mock_sender.send_batch.call_count == DEFAULT_MAX_RETRIES + 1 assert mock_sleep.call_count == DEFAULT_MAX_RETRIES - + # Verify events are cleared even after failure assert len(buffer.events) == 0 - + # Verify error is logged mock_logger.error.assert_called() if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()