Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .fernignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ src/schematic/client.py
scripts/
src/schematic/cache/
src/schematic/event_buffer.py
src/schematic/event_capture.py
src/schematic/http_client.py
src/schematic/logging.py
src/schematic/datastream/
Expand Down
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,25 @@ client = Schematic("", config)
client.check_flag("some-flag-key") # Returns True
```

You can also set flag defaults dynamically after the client has been constructed using `set_flag_default` and `set_flag_defaults`. This is useful in automated testing contexts, where you may want to specify per-test flag values:

```python
from schematic.client import Schematic, SchematicConfig

client = Schematic("", SchematicConfig(offline=True))

# Set a single flag default
client.set_flag_default("some-flag-key", True)

# Or set multiple flag defaults at once
client.set_flag_defaults({
"some-flag-key": True,
"another-flag-key": False,
})

client.check_flag("some-flag-key") # Returns True
```

### Timeouts
By default, requests time out after 60 seconds. You can configure this with a
timeout option at the client or request level.
Expand Down
34 changes: 31 additions & 3 deletions src/schematic/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Any, Callable, Dict, List, Optional, Union

import httpx

from .base_client import AsyncBaseSchematic, BaseSchematic
from .cache import DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL, AsyncCacheProvider, CacheProvider, LocalCache
from .datastream import DataStreamClient, DataStreamClientOptions
from .event_buffer import AsyncEventBuffer, EventBuffer
from .event_capture import AsyncEventCaptureClient, EventCaptureClient
from .http_client import AsyncOfflineHTTPClient, OfflineHTTPClient
from .logging import get_default_logger
from .types import (
Expand Down Expand Up @@ -59,6 +59,7 @@ class DataStreamConfig:
class SchematicConfig:
base_url: Optional[str] = None
event_buffer_period: Optional[int] = None
event_capture_url: Optional[str] = None
flag_defaults: Optional[Dict[str, bool]] = None
follow_redirects: Optional[bool] = True
httpx_client: Optional[httpx.Client] = None
Expand All @@ -83,8 +84,14 @@ def __init__(self, api_key: str, config: Optional[SchematicConfig] = None):
self.event_buffer_period = config.event_buffer_period
self.logger = config.logger or get_default_logger()
self.flag_defaults = config.flag_defaults or {}
self.event_capture_client = EventCaptureClient(
api_key=api_key,
base_url=config.event_capture_url,
httpx_client=httpx_client,
get_headers=self._client_wrapper.get_headers,
)
self.event_buffer = EventBuffer(
events_api=self.events,
event_sender=self.event_capture_client,
logger=self.logger,
period=self.event_buffer_period,
)
Expand All @@ -101,6 +108,7 @@ def initialize(self) -> None:

def shutdown(self) -> None:
self.event_buffer.stop()
self.event_capture_client.close()

def check_flag(
self,
Expand Down Expand Up @@ -305,6 +313,12 @@ def _enqueue_event(self, event_type: str, body: EventBody) -> None:
def _get_flag_default(self, flag_key: str) -> bool:
return self.flag_defaults.get(flag_key, False)

def set_flag_default(self, flag_key: str, value: bool) -> None:
self.flag_defaults[flag_key] = value

def set_flag_defaults(self, values: Dict[str, bool]) -> None:
self.flag_defaults.update(values)

def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] = None) -> bool:
if options and options.default_value is not None:
if callable(options.default_value):
Expand All @@ -317,6 +331,7 @@ def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] =
class AsyncSchematicConfig:
base_url: Optional[str] = None
event_buffer_period: Optional[int] = None
event_capture_url: Optional[str] = None
flag_defaults: Optional[Dict[str, bool]] = None
follow_redirects: Optional[bool] = True
httpx_client: Optional[httpx.AsyncClient] = None
Expand Down Expand Up @@ -372,8 +387,14 @@ def __init__(self, api_key: str, config: Optional[AsyncSchematicConfig] = None):
self.event_buffer_period = config.event_buffer_period
self.logger = config.logger or get_default_logger()
self.flag_defaults = config.flag_defaults or {}
self.event_capture_client = AsyncEventCaptureClient(
api_key=api_key,
base_url=config.event_capture_url,
httpx_client=httpx_client,
get_headers=self._client_wrapper.get_headers,
)
self.event_buffer = AsyncEventBuffer(
events_api=self.events,
event_sender=self.event_capture_client,
logger=self.logger,
period=self.event_buffer_period,
)
Expand Down Expand Up @@ -727,6 +748,12 @@ async def _enqueue_event(self, event_type: str, body: EventBody) -> None:
def _get_flag_default(self, flag_key: str) -> bool:
return self.flag_defaults.get(flag_key, False)

def set_flag_default(self, flag_key: str, value: bool) -> None:
self.flag_defaults[flag_key] = value

def set_flag_defaults(self, values: Dict[str, bool]) -> None:
self.flag_defaults.update(values)

def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] = None) -> bool:
if options and options.default_value is not None:
if callable(options.default_value):
Expand Down Expand Up @@ -767,6 +794,7 @@ async def shutdown(self) -> None:

# Flush and stop the event buffer
await self.event_buffer.stop()
await self.event_capture_client.close()
self.logger.info("Shutdown complete.")
except Exception as e:
self.logger.error(f"Error during shutdown: {e}")
Expand Down
14 changes: 7 additions & 7 deletions src/schematic/event_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from typing import List, Optional

from .events.client import AsyncEventsClient, EventsClient
from .event_capture import AsyncEventCaptureClient, EventCaptureClient
from .types import CreateEventRequestBody

DEFAULT_MAX_EVENTS = 100 # Default maximum number of events
Expand All @@ -17,15 +17,15 @@
class EventBuffer:
def __init__(
self,
events_api: EventsClient,
event_sender: EventCaptureClient,
logger: logging.Logger,
period: Optional[int] = None,
max_events: int = DEFAULT_MAX_EVENTS,
max_retries: int = DEFAULT_MAX_RETRIES,
initial_retry_delay: float = DEFAULT_INITIAL_RETRY_DELAY,
):
self.events: List[CreateEventRequestBody] = []
self.events_api = events_api
self.event_sender = event_sender
self.interval = period or DEFAULT_EVENT_BUFFER_PERIOD
self.logger = logger
self.max_events = max_events
Expand Down Expand Up @@ -61,7 +61,7 @@ def _process_events(self, events_to_process):
if retry_count > 0:
self.logger.info(f"Retrying event batch submission (attempt {retry_count} of {self.max_retries})")

self.events_api.create_event_batch(events=events_to_process)
self.event_sender.send_batch(events_to_process)
success = True

except Exception as e:
Expand Down Expand Up @@ -125,15 +125,15 @@ def stop(self):
class AsyncEventBuffer:
def __init__(
self,
events_api: AsyncEventsClient,
event_sender: AsyncEventCaptureClient,
logger: logging.Logger,
period: Optional[int] = None,
max_events: int = DEFAULT_MAX_EVENTS,
max_retries: int = DEFAULT_MAX_RETRIES,
initial_retry_delay: float = DEFAULT_INITIAL_RETRY_DELAY,
):
self.events: List[CreateEventRequestBody] = []
self.events_api = events_api
self.event_sender = event_sender
self.interval = period or DEFAULT_EVENT_BUFFER_PERIOD
self.logger = logger
self.max_events = max_events
Expand Down Expand Up @@ -168,7 +168,7 @@ async def _process_events_async(self, events_to_process):
if retry_count > 0:
self.logger.info(f"Retrying event batch submission (attempt {retry_count} of {self.max_retries})")

await self.events_api.create_event_batch(events=events_to_process)
await self.event_sender.send_batch(events_to_process)
success = True

except Exception as e:
Expand Down
142 changes: 142 additions & 0 deletions src/schematic/event_capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import datetime as dt
import typing

import httpx
import pydantic
from .core.pydantic_utilities import UniversalBaseModel
from .types import CreateEventRequestBody, EventBody, EventType

DEFAULT_EVENT_CAPTURE_BASE_URL = "https://c.schematichq.com"
DEFAULT_TIMEOUT = 10.0


class _CaptureEventPayload(UniversalBaseModel):
"""Wire format for a single event sent to the capture service.

Mirrors the shape used by the Go/Ruby/C# SDKs: `type` (not `event_type`)
and an `api_key` field embedded on each event.
"""

api_key: str = pydantic.Field()
body: typing.Optional[EventBody] = None
type: EventType = pydantic.Field()
sent_at: typing.Optional[dt.datetime] = None


class _CaptureBatchPayload(UniversalBaseModel):
events: typing.List[_CaptureEventPayload]


def _to_payload(event: CreateEventRequestBody, api_key: str) -> _CaptureEventPayload:
return _CaptureEventPayload(
api_key=api_key,
body=event.body,
type=event.event_type,
sent_at=event.sent_at,
)


def _build_endpoint(base_url: str) -> str:
return base_url.rstrip("/") + "/batch"


def _build_headers(
api_key: str,
get_headers: typing.Optional[typing.Callable[[], typing.Dict[str, str]]] = None,
) -> typing.Dict[str, str]:
"""Build the headers for capture-service requests.

Reuses the API client's header builder when provided so we send the same
X-Fern-* / User-Agent / X-Schematic-Api-Key headers as the REST client —
keeps observability consistent across both endpoints.
"""
headers: typing.Dict[str, str] = {}
if get_headers is not None:
headers.update(get_headers())
else:
headers["X-Schematic-Api-Key"] = api_key
headers["Content-Type"] = "application/json"
return headers


def _serialize_batch(
events: typing.List[CreateEventRequestBody], api_key: str
) -> str:
batch = _CaptureBatchPayload(
events=[_to_payload(e, api_key) for e in events],
)
return batch.model_dump_json(by_alias=True, exclude_none=True)


class EventCaptureClient:
"""HTTP client for sending events to the Schematic event capture service."""

def __init__(
self,
api_key: str,
base_url: typing.Optional[str] = None,
httpx_client: typing.Optional[httpx.Client] = None,
get_headers: typing.Optional[typing.Callable[[], typing.Dict[str, str]]] = None,
):
self._api_key = api_key
self._base_url = base_url or DEFAULT_EVENT_CAPTURE_BASE_URL
self._owns_client = httpx_client is None
self._httpx_client = httpx_client or httpx.Client(timeout=DEFAULT_TIMEOUT)
self._get_headers = get_headers

def send_batch(self, events: typing.List[CreateEventRequestBody]) -> None:
if not events:
return

body = _serialize_batch(events, self._api_key)
response = self._httpx_client.post(
_build_endpoint(self._base_url),
content=body,
headers=_build_headers(self._api_key, self._get_headers),
)

if response.status_code < 200 or response.status_code >= 300:
raise RuntimeError(
f"capture service returned HTTP {response.status_code}: {response.text}"
)

def close(self) -> None:
if self._owns_client:
self._httpx_client.close()


class AsyncEventCaptureClient:
"""Async HTTP client for sending events to the Schematic event capture service."""

def __init__(
self,
api_key: str,
base_url: typing.Optional[str] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
get_headers: typing.Optional[typing.Callable[[], typing.Dict[str, str]]] = None,
):
self._api_key = api_key
self._base_url = base_url or DEFAULT_EVENT_CAPTURE_BASE_URL
self._owns_client = httpx_client is None
self._httpx_client = httpx_client or httpx.AsyncClient(timeout=DEFAULT_TIMEOUT)
self._get_headers = get_headers

async def send_batch(self, events: typing.List[CreateEventRequestBody]) -> None:
if not events:
return

body = _serialize_batch(events, self._api_key)
response = await self._httpx_client.post(
_build_endpoint(self._base_url),
content=body,
headers=_build_headers(self._api_key, self._get_headers),
)

if response.status_code < 200 or response.status_code >= 300:
raise RuntimeError(
f"capture service returned HTTP {response.status_code}: {response.text}"
)

async def close(self) -> None:
if self._owns_client:
await self._httpx_client.aclose()
Loading
Loading