Skip to content

Commit cc3355f

Browse files
authored
feat: Support for dataclasses as composite types (#1242)
This modifies the previous PR #1231 and adds the following: * Simplifies the creation of composites, the composite no longer needs to be an instance of `BaseModel` and instead can be a dataclass * The composite does not need to specify a default and instead determines the device to inject from the attribute name
1 parent 62b3b59 commit cc3355f

5 files changed

Lines changed: 151 additions & 25 deletions

File tree

docs/how-to/write-plans.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,24 @@ The type annotations in the example above (e.g. `: str`, `: int`, `-> MsgGenerat
2424

2525
## Injecting Devices
2626

27-
Some plans are created for specific sets of devices, or will almost always be used with the same devices, it is useful to be able to specify defaults. [Dodal makes this easy with its factory functions](https://diamondlightsource.github.io/dodal/main/how-to/include-devices-in-plans.html).
27+
Some plans are created for specific sets of devices, or will almost always be used with the same devices, it is useful to be able to specify defaults. [Dodal makes this easy with its inject function](https://diamondlightsource.github.io/dodal/main/reference/generated/dodal.common.html#dodal.common.inject).
28+
29+
## Injecting multiple devices
30+
31+
If a plan requires multiple devices to be injected at once, rather than have a plan with several device parameters each of them with their own injection default, it is possible to define a device composite which can be accepted as a parameter.
32+
33+
For example you could define a composite as below:
34+
35+
```{literalinclude} ../../tests/unit_tests/code_examples/device_composite.py
36+
:language: python
37+
```
38+
39+
Then in your plan module:
40+
41+
```{literalinclude} ../../tests/unit_tests/code_examples/plan_with_composite.py
42+
:language: python
43+
```
44+
2845

2946
## Injecting Metadata
3047

src/blueapi/core/context.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import sys
33
from collections.abc import Callable
4-
from dataclasses import InitVar, dataclass, field
4+
from dataclasses import InitVar, dataclass, field, fields, is_dataclass
55
from importlib import import_module
66
from inspect import Parameter, isclass, signature
77
from types import ModuleType, NoneType, UnionType
@@ -12,7 +12,12 @@
1212
from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider
1313
from dodal.utils import AnyDevice, make_all_devices
1414
from ophyd_async.core import NotConnectedError, PathProvider
15-
from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
15+
from pydantic import (
16+
BaseModel,
17+
GetCoreSchemaHandler,
18+
GetJsonSchemaHandler,
19+
create_model,
20+
)
1621
from pydantic.fields import FieldInfo
1722
from pydantic.json_schema import JsonSchemaValue, SkipJsonSchema
1823
from pydantic_core import CoreSchema, core_schema
@@ -100,7 +105,7 @@ def is_bluesky_type(typ: type) -> bool:
100105
return typ in BLUESKY_PROTOCOLS or isinstance(typ, BLUESKY_PROTOCOLS)
101106

102107

103-
C = TypeVar("C", bound=BaseModel, covariant=True)
108+
C = TypeVar("C", covariant=True)
104109

105110

106111
@dataclass
@@ -442,16 +447,19 @@ def _type_spec_for_function(
442447
)
443448

444449
no_default = para.default is Parameter.empty
445-
default_factory = (
446-
self._composite_factory(arg_type)
447-
if isclass(arg_type)
448-
and issubclass(arg_type, BaseModel)
450+
if (
451+
isclass(arg_type)
452+
and (issubclass(arg_type, BaseModel) or is_dataclass(arg_type))
449453
and isinstance(para.default, str)
450-
else DefaultFactory(para.default)
451-
)
454+
):
455+
default_factory = self._composite_factory(arg_type)
456+
_type = SkipJsonSchema[self._convert_type(arg_type, no_default)]
457+
else:
458+
default_factory = DefaultFactory(para.default)
459+
_type = self._convert_type(arg_type, no_default)
452460
factory = None if no_default else default_factory
453461
new_args[name] = (
454-
self._convert_type(arg_type, no_default),
462+
_type,
455463
FieldInfo(default_factory=factory),
456464
)
457465
return new_args
@@ -487,14 +495,20 @@ def _convert_type(self, typ: type | Any, no_default: bool = True) -> type:
487495

488496
def _composite_factory(self, composite_class: type[C]) -> Callable[[], C]:
489497
def _inject_composite():
490-
devices = {
491-
field: self.find_device(info.default)
492-
if info.annotation is not None
493-
and is_bluesky_type(info.annotation)
494-
and isinstance(info.default, str)
495-
else info.default
496-
for field, info in composite_class.model_fields.items()
497-
}
498+
if issubclass(composite_class, BaseModel):
499+
devices = {
500+
field_name: self.find_device(field_name)
501+
for field_name in composite_class.model_fields.keys()
502+
}
503+
else:
504+
assert is_dataclass(composite_class), (
505+
f"Unsupported composite type: {composite_class}, composite must be"
506+
" a pydantic BaseModel or a dataclass"
507+
)
508+
devices = {
509+
field.name: self.find_device(field.name)
510+
for field in fields(composite_class)
511+
}
498512
return composite_class(**devices)
499513

500514
return _inject_composite
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import pydantic
2+
from tests.unit_tests.code_examples.device_module import BimorphMirror
3+
4+
5+
@pydantic.dataclasses.dataclass(config={"arbitrary_types_allowed": True})
6+
class MyDeviceComposite:
7+
oav: BimorphMirror
8+
# More devices here....
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from bluesky.utils import MsgGenerator
2+
from dodal.common import inject
3+
from tests.unit_tests.code_examples.device_composite import MyDeviceComposite
4+
5+
6+
def my_plan(
7+
parameter_one: int,
8+
parameter_two: str,
9+
my_necessary_devices: MyDeviceComposite = inject(""),
10+
) -> MsgGenerator[None]:
11+
# logic goes here
12+
...

tests/unit_tests/worker/test_task_worker.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
import itertools
23
import threading
34
from collections.abc import Callable, Iterable
@@ -7,6 +8,7 @@
78
from typing import Any, TypeVar
89
from unittest.mock import ANY, MagicMock, Mock, patch
910

11+
import pydantic
1012
import pytest
1113
from bluesky.protocols import Movable, Status
1214
from bluesky.utils import MsgGenerator
@@ -21,6 +23,7 @@
2123
from blueapi.config import DeviceSource, EnvironmentConfig
2224
from blueapi.core import BlueskyContext, EventStream
2325
from blueapi.core.bluesky_types import DataEvent
26+
from blueapi.service.model import PlanModel
2427
from blueapi.utils.base_model import BlueapiBaseModel
2528
from blueapi.worker import (
2629
Task,
@@ -657,6 +660,20 @@ def injected_device_plan(
657660
assert params["dev"] == fake_device
658661

659662

663+
def test_injected_devices_plan_model(
664+
fake_device: FakeDevice,
665+
context: BlueskyContext,
666+
):
667+
def injected_device_plan(
668+
dev: FakeDevice = inject(fake_device.name),
669+
) -> MsgGenerator:
670+
yield from ()
671+
672+
context.register_plan(injected_device_plan)
673+
plan = context.plans["injected_device_plan"]
674+
PlanModel.from_plan(plan)
675+
676+
660677
def test_missing_injected_devices_fail_early(
661678
context: BlueskyContext,
662679
):
@@ -695,25 +712,83 @@ def test_cycle_without_otel_context(mock_logger: Mock, inert_worker: TaskWorker)
695712

696713

697714
class MyComposite(BlueapiBaseModel):
698-
dev_a: FakeDevice = inject(fake_device.name)
699-
dev_b: FakeDevice = inject(second_fake_device.name)
715+
fake_device: FakeDevice
716+
second_fake_device: FakeDevice
700717

701718
model_config = {"arbitrary_types_allowed": True}
702719

703720

721+
@pydantic.dataclasses.dataclass(config={"arbitrary_types_allowed": True})
722+
class MyPydanticDataClassComposite:
723+
fake_device: FakeDevice
724+
second_fake_device: FakeDevice
725+
726+
727+
@dataclasses.dataclass()
728+
class MyStandardDataClassComposite:
729+
fake_device: FakeDevice
730+
second_fake_device: FakeDevice
731+
732+
704733
def injected_device_plan(composite: MyComposite = inject("")) -> MsgGenerator:
705734
yield from ()
706735

707736

737+
def injected_dataclass_device_plan(
738+
composite: MyPydanticDataClassComposite = inject(""),
739+
) -> MsgGenerator:
740+
yield from ()
741+
742+
743+
def injected_standard_dataclass_device_plan(
744+
composite: MyStandardDataClassComposite = inject(""),
745+
) -> MsgGenerator:
746+
yield from ()
747+
748+
708749
def test_injected_composite_devices_are_found(
709750
fake_device: FakeDevice,
710751
second_fake_device: FakeDevice,
711752
context: BlueskyContext,
712753
):
713754
context.register_plan(injected_device_plan)
714755
params = Task(name="injected_device_plan").prepare_params(context)
715-
assert params["composite"].dev_a == fake_device
716-
assert params["composite"].dev_b == second_fake_device
756+
assert params["composite"].fake_device == fake_device
757+
assert params["composite"].second_fake_device == second_fake_device
758+
759+
760+
def test_injected_composite_devices_plan_model(
761+
fake_device: FakeDevice,
762+
second_fake_device: FakeDevice,
763+
context: BlueskyContext,
764+
):
765+
context.register_plan(injected_device_plan)
766+
plan = context.plans["injected_device_plan"]
767+
PlanModel.from_plan(plan)
768+
769+
770+
def test_injected_composite_with_pydantic_dataclass(
771+
context: BlueskyContext,
772+
fake_device: FakeDevice,
773+
second_fake_device: FakeDevice,
774+
):
775+
context.register_plan(injected_dataclass_device_plan)
776+
params = Task(name="injected_dataclass_device_plan").prepare_params(context)
777+
assert params["composite"].fake_device == fake_device
778+
assert params["composite"].second_fake_device == second_fake_device
779+
780+
781+
def test_injected_composite_with_standard_dataclass(
782+
context: BlueskyContext,
783+
fake_device: FakeDevice,
784+
second_fake_device: FakeDevice,
785+
):
786+
context.register_plan(injected_standard_dataclass_device_plan)
787+
params = Task(name="injected_standard_dataclass_device_plan").prepare_params(
788+
context
789+
)
790+
assert params["composite"].fake_device == fake_device
791+
assert params["composite"].second_fake_device == second_fake_device
717792

718793

719794
def test_plan_module_with_composite_devices_can_be_loaded_before_device_module(
@@ -725,5 +800,5 @@ def test_plan_module_with_composite_devices_can_be_loaded_before_device_module(
725800
context_without_devices.register_device(fake_device)
726801
context_without_devices.register_device(second_fake_device)
727802
params = Task(name="injected_device_plan").prepare_params(context_without_devices)
728-
assert params["composite"].dev_a == fake_device
729-
assert params["composite"].dev_b == second_fake_device
803+
assert params["composite"].fake_device == fake_device
804+
assert params["composite"].second_fake_device == second_fake_device

0 commit comments

Comments
 (0)