Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions api/experimentation/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
from dataclasses import dataclass
from datetime import datetime

from core.dataclasses import AuthorData
from experimentation.stats import Inference, VariantStats
from experimentation.types import ExposureGranularity
from features.feature_states.models import FeatureValueType
from features.versioning.dataclasses import MultivariateValueChangeSet


@dataclass(frozen=True)
class RolloutSpec:
enabled: bool
rollout_percentage: float
feature_state_value: str
value_type: FeatureValueType
multivariate_values: list[MultivariateValueChangeSet]
author: AuthorData


@dataclass(frozen=True)
Expand Down
26 changes: 26 additions & 0 deletions api/experimentation/migrations/0009_add_rollout_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.2.14 on 2026-06-19 09:59

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("experimentation", "0008_experiment_results"),
("segments", "0030_add_default_to_segment_version"),
]

operations = [
migrations.AddField(
model_name="experiment",
name="rollout_segment",
field=models.OneToOneField(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="experiment_rollout",
to="segments.segment",
),
),
]
7 changes: 7 additions & 0 deletions api/experimentation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ class Experiment(LifecycleModelMixin, SoftDeleteExportableModel): # type: ignor
updated_at = models.DateTimeField(auto_now=True)
started_at = models.DateTimeField(null=True, blank=True)
ended_at = models.DateTimeField(null=True, blank=True)
rollout_segment = models.OneToOneField(
"segments.Segment",
on_delete=models.SET_NULL,
related_name="experiment_rollout",
null=True,
blank=True,
)

class Meta:
constraints = [
Expand Down
57 changes: 56 additions & 1 deletion api/experimentation/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from django.db.models import QuerySet
from rest_framework import serializers

from core.dataclasses import AuthorData
from environments.models import Environment
from experimentation.dataclasses import WarehouseEventStats
from experimentation.dataclasses import RolloutSpec, WarehouseEventStats
from experimentation.metric_definitions import validate_metric_definition
from experimentation.models import (
ExpectedDirection,
Expand All @@ -18,14 +19,20 @@
WarehouseConnection,
WarehouseType,
)
from experimentation.services import apply_experiment_rollout
from experimentation.types import (
SNOWFLAKE_DEFAULTS,
MetricExperimentResult,
SnowflakeConfig,
)
from features.feature_states.serializers import (
FeatureValueSerializer,
MultivariateValueSerializer,
)
from features.feature_types import MULTIVARIATE
from features.models import Feature
from features.multivariate.serializers import NestedMultivariateFeatureOptionSerializer
from features.versioning.dataclasses import MultivariateValueChangeSet


class WarehouseConnectionSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
Expand Down Expand Up @@ -207,6 +214,35 @@ class ExperimentMetricInlineSerializer(serializers.Serializer): # type: ignore[
expected_direction = serializers.ChoiceField(choices=ExpectedDirection.choices)


class ExperimentRolloutSerializer(serializers.Serializer): # type: ignore[type-arg]
enabled = serializers.BooleanField(required=True)
rollout_percentage = serializers.FloatField(
required=True, min_value=0, max_value=100
)
feature_state_value = FeatureValueSerializer(required=True)
multivariate_feature_state_values = MultivariateValueSerializer(
many=True, required=False
)

@staticmethod
def to_spec(data: dict[str, Any], request: Any) -> RolloutSpec:
value = data["feature_state_value"]
return RolloutSpec(
enabled=data["enabled"],
rollout_percentage=data["rollout_percentage"],
feature_state_value=value["value"],
value_type=value["type"],
multivariate_values=[
MultivariateValueChangeSet(
multivariate_feature_option_id=mv["multivariate_feature_option"],
percentage_allocation=mv["percentage_allocation"],
)
for mv in data.get("multivariate_feature_state_values", [])
],
author=AuthorData.from_request(request),
)


class ExperimentSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
# Annotated with the common base type so ExperimentListSerializer can
# override the field with a read-only representation.
Expand All @@ -215,6 +251,7 @@ class ExperimentSerializer(serializers.ModelSerializer): # type: ignore[type-ar
required=False,
write_only=True,
)
experiment_rollout = ExperimentRolloutSerializer(required=False, write_only=True)

class Meta:
model = Experiment
Expand All @@ -225,6 +262,7 @@ class Meta:
"hypothesis",
"status",
"metrics",
"experiment_rollout",
"created_at",
"updated_at",
"started_at",
Expand Down Expand Up @@ -260,6 +298,15 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
raise serializers.ValidationError(
{"metrics": "Cannot change the metrics of an existing experiment."}
)
if self.instance is not None and "experiment_rollout" in attrs:
raise serializers.ValidationError(
{
"experiment_rollout": (
"Cannot change the rollout via this endpoint; "
"use the rollout endpoint instead."
)
}
)
self._validate_metrics(attrs.get("metrics") or [])
return attrs

Expand All @@ -272,6 +319,7 @@ def _validate_metrics(self, metrics: list[dict[str, Any]]) -> None:

def create(self, validated_data: dict[str, Any]) -> Experiment:
metrics: list[dict[str, Any]] = validated_data.pop("metrics", [])
rollout: dict[str, Any] | None = validated_data.pop("experiment_rollout", None)
with transaction.atomic():
experiment: Experiment = super().create(validated_data)
ExperimentMetric.objects.bulk_create(
Expand All @@ -282,6 +330,13 @@ def create(self, validated_data: dict[str, Any]) -> Experiment:
)
for entry in metrics
)
if rollout is not None:
apply_experiment_rollout(
experiment,
ExperimentRolloutSerializer.to_spec(
rollout, self.context["request"]
),
)
return experiment


Expand Down
80 changes: 80 additions & 0 deletions api/experimentation/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from clickhouse_driver import Client
from clickhouse_driver.util.helpers import parse_url
from django.conf import settings
from django.db import transaction
from django.db.models import Q
from django.utils import timezone
from flag_engine.segments.constants import PERCENTAGE_SPLIT
Comment thread
gagantrivedi marked this conversation as resolved.
from rest_framework.exceptions import ValidationError

from audit.models import AuditLog
from audit.related_object_type import RelatedObjectType
Expand All @@ -32,6 +35,7 @@
MetricSpec,
ResultsAggregates,
ResultsSummary,
RolloutSpec,
WarehouseEventStats,
)
from experimentation.models import (
Expand All @@ -50,7 +54,10 @@
srm_p_value,
)
from features.models import FeatureState
from features.versioning.dataclasses import FlagChangeSet
from features.versioning.versioning_service import update_flag
from integrations.flagsmith.client import get_openfeature_client
from segments.models import Condition, Segment, SegmentRule

if typing.TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -512,6 +519,79 @@ def transition_experiment_status(
return experiment


def _create_rollout_segment(
experiment: Experiment, rollout_percentage: float
) -> Segment:
segment: Segment = Segment.objects.create(
name=f"experiment-{experiment.id}-rollout",
project=experiment.feature.project,
is_system_segment=True,
)
rule = SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE)
Condition.objects.create(
rule=rule,
operator=PERCENTAGE_SPLIT,
property="$.identity.key",
value=str(rollout_percentage),
)
return segment


def validate_rollout_spec(experiment: Experiment, spec: RolloutSpec) -> None:
option_ids = [v.multivariate_feature_option_id for v in spec.multivariate_values]
if len(option_ids) != len(set(option_ids)):
raise ValidationError("Multivariate options must be unique")
valid_option_ids = set(
experiment.feature.multivariate_options.values_list("id", flat=True)
)
if invalid := set(option_ids) - valid_option_ids:
raise ValidationError(
f"Multivariate options {sorted(invalid)} do not belong to the feature"
)
total = sum(v.percentage_allocation for v in spec.multivariate_values)
if total > 100:
raise ValidationError(
f"Multivariate allocations must not exceed 100%, got {total}%."
)


def _sync_rollout_segment(experiment: Experiment, rollout_percentage: float) -> Segment:
segment = experiment.rollout_segment
if segment is not None:
condition = Condition.objects.get(
rule__segment=segment, operator=PERCENTAGE_SPLIT
)
condition.value = str(rollout_percentage)
condition.save()
return segment
segment = _create_rollout_segment(experiment, rollout_percentage)
experiment.rollout_segment = segment
experiment.save()
return segment


def apply_experiment_rollout(experiment: Experiment, spec: RolloutSpec) -> None:
if experiment.status in (ExperimentStatus.RUNNING, ExperimentStatus.COMPLETED):
raise ValidationError(
f"Cannot change the rollout of a {experiment.status} experiment."
)
validate_rollout_spec(experiment, spec)
with transaction.atomic():
segment = _sync_rollout_segment(experiment, spec.rollout_percentage)
update_flag(
experiment.environment,
experiment.feature,
FlagChangeSet(
author=spec.author,
enabled=spec.enabled,
feature_state_value=spec.feature_state_value,
type_=spec.value_type,
segment_id=segment.id,
multivariate_values=spec.multivariate_values,
),
)


def mark_warehouse_pending_connection(
connection: WarehouseConnection,
) -> WarehouseConnection:
Expand Down
15 changes: 14 additions & 1 deletion api/experimentation/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@
ExperimentListSerializer,
ExperimentMetricSerializer,
ExperimentResultsSerializer,
ExperimentRolloutSerializer,
ExperimentSerializer,
MetricSerializer,
WarehouseConnectionSerializer,
)
from experimentation.services import (
annotate_warehouse_event_stats,
apply_experiment_rollout,
create_experiment_audit_log,
create_metric_audit_log,
create_warehouse_audit_log,
Expand Down Expand Up @@ -176,7 +178,7 @@ def get_serializer_context(self) -> dict[str, Any]:
return context

def get_serializer_class(self) -> type[BaseSerializer[Experiment]]:
if self.action in ("list", "retrieve", "start", "pause", "complete"):
if self.action in ("list", "retrieve", "start", "pause", "complete", "rollout"):
return ExperimentListSerializer
return ExperimentSerializer

Expand Down Expand Up @@ -290,6 +292,17 @@ def pause(self, request: Request, **kwargs: object) -> Response:
def complete(self, request: Request, **kwargs: object) -> Response:
return self._transition_status(ExperimentStatus.COMPLETED)

@action(detail=True, methods=["patch"])
def rollout(self, request: Request, **kwargs: object) -> Response:
experiment: Experiment = self.get_object()
serializer = ExperimentRolloutSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
apply_experiment_rollout(
experiment,
ExperimentRolloutSerializer.to_spec(serializer.validated_data, request),
)
return Response(self.get_serializer(experiment).data)

@action(detail=True, methods=["get"])
def exposures(self, request: Request, **kwargs: object) -> Response:
experiment: Experiment = self.get_object()
Expand Down
Loading
Loading