Skip to content

Commit 17e86b0

Browse files
Improve type hints for plan signature to include defaults
1 parent a8aa805 commit 17e86b0

4 files changed

Lines changed: 42 additions & 40 deletions

File tree

src/blueapi/client/client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,24 +218,33 @@ def _build_args(self, *args, **kwargs):
218218

219219
def __repr__(self):
220220
props = self.model.parameter_schema.get("properties", {})
221+
required = set(self.required)
222+
221223
tab = " "
222224
args = []
225+
223226
for name, info in props.items():
224227
typ = _pretty_type(info)
225228
arg = f"{name}: {typ}"
226-
if name not in self.required:
227-
arg = f"{arg} | None = None"
229+
230+
if name not in required:
231+
if "default" in info:
232+
default = repr(info["default"])
233+
arg = f"{arg} = {default}"
234+
else:
235+
arg = f"{arg} | None = None"
236+
228237
args.append(arg)
229238

230239
single_line = f"{self.name}({', '.join(args)})"
231240
max_length = 100
232241
max_args_inline = 3
242+
233243
if len(single_line) <= max_length and len(args) <= max_args_inline:
234244
return single_line
235245

236246
# Fall back to multiline if too many arguments or too long.
237247
multiline_args = ",\n".join(f"{tab}{arg}" for arg in args)
238-
239248
return f"{self.name}(\n{multiline_args}\n)"
240249

241250

src/blueapi/core/context.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from importlib import import_module
66
from inspect import Parameter, isclass, signature
77
from types import ModuleType, NoneType, UnionType
8-
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
8+
from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints
99

1010
from bluesky.protocols import HasName
1111
from bluesky.run_engine import RunEngine
@@ -516,14 +516,16 @@ def _type_spec_for_function(
516516
):
517517
default_factory = self._composite_factory(arg_type)
518518
_type = SkipJsonSchema[self._convert_type(arg_type, no_default)]
519+
field_info = FieldInfo(default_factory=default_factory)
519520
else:
520-
default_factory = DefaultFactory(para.default)
521521
_type = self._convert_type(arg_type, no_default)
522-
factory = None if no_default else default_factory
523-
new_args[name] = (
524-
_type,
525-
FieldInfo(default_factory=factory),
526-
)
522+
if no_default:
523+
field_info = FieldInfo()
524+
else:
525+
field_info = FieldInfo(default=para.default)
526+
527+
new_args[name] = (_type, field_info)
528+
527529
return new_args
528530

529531
def _convert_type(self, typ: Any, no_default: bool = True) -> type:
@@ -574,19 +576,3 @@ def _inject_composite():
574576
return composite_class(**devices)
575577

576578
return _inject_composite
577-
578-
579-
D = TypeVar("D")
580-
581-
582-
class DefaultFactory(Generic[D]):
583-
_value: D
584-
585-
def __init__(self, value: D):
586-
self._value = value
587-
588-
def __call__(self) -> D:
589-
return self._value
590-
591-
def __eq__(self, other) -> bool:
592-
return other.__class__ == self.__class__ and self._value == other._value

tests/unit_tests/client/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,8 @@ def test_plan_multi_parameter_fallback_help_text(client):
730730
"two": {
731731
"anyOf": [{"items": {}, "type": "array"}, {"type": "boolean"}],
732732
},
733-
"three": {},
734-
"four": {},
733+
"three": {"default": 3},
734+
"four": {"default": None},
735735
},
736736
"required": ["one", "two"],
737737
},
@@ -742,8 +742,8 @@ def test_plan_multi_parameter_fallback_help_text(client):
742742
plan.help_text == "Plan foo(\n"
743743
" one: Any,\n"
744744
" two: list[Any] | bool,\n"
745-
" three: Any | None = None,\n"
746-
" four: Any | None = None\n"
745+
" three: Any = 3,\n"
746+
" four: Any = None\n"
747747
")"
748748
)
749749

tests/unit_tests/core/test_context.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
TiledConfig,
4747
)
4848
from blueapi.core import BlueskyContext, is_bluesky_compatible_device
49-
from blueapi.core.context import DefaultFactory, generic_bounds, qualified_name
49+
from blueapi.core.context import generic_bounds, qualified_name
5050
from blueapi.core.protocols import DeviceConnectResult, DeviceManager
5151
from blueapi.utils.connect_devices import _establish_device_connections
5252
from blueapi.utils.invalid_config_error import InvalidConfigError
@@ -431,9 +431,9 @@ def test_with_config_passes_mock_to_with_dodal_module(
431431
def test_function_spec(empty_context: BlueskyContext):
432432
spec = empty_context._type_spec_for_function(has_some_params)
433433
assert spec["foo"][0] is int
434-
assert spec["foo"][1].default_factory == DefaultFactory(42)
434+
assert spec["foo"][1].default == 42
435435
assert spec["bar"][0] is str
436-
assert spec["bar"][1].default_factory == DefaultFactory("bar")
436+
assert spec["bar"][1].default == "bar"
437437

438438

439439
def test_basic_type_conversion(empty_context: BlueskyContext):
@@ -514,7 +514,7 @@ def default_movable(mov: Movable = inject("demo")) -> MsgGenerator:
514514
spec = empty_context._type_spec_for_function(default_movable)
515515
movable_ref = empty_context._reference(Movable)
516516
assert spec["mov"][0] == movable_ref
517-
assert spec["mov"][1].default_factory == DefaultFactory("demo")
517+
assert spec["mov"][1].default == "demo"
518518

519519

520520
def test_generic_default_device_reference(empty_context: BlueskyContext):
@@ -524,7 +524,7 @@ def default_movable(mov: Movable[float] = inject("demo")) -> MsgGenerator:
524524
spec = empty_context._type_spec_for_function(default_movable)
525525
motor_ref = empty_context._reference(Movable[float])
526526
assert spec["mov"][0] == motor_ref
527-
assert spec["mov"][1].default_factory == DefaultFactory("demo")
527+
assert spec["mov"][1].default == "demo"
528528

529529

530530
class ConcreteStoppable(Stoppable):
@@ -574,7 +574,7 @@ def test_str_default(empty_context: BlueskyContext, sim_motor: Motor, alt_motor:
574574

575575
spec = empty_context._type_spec_for_function(has_default_reference)
576576
assert spec["m"][0] is movable_ref
577-
assert (df := spec["m"][1].default_factory) and df() == SIM_MOTOR_NAME # type: ignore
577+
assert spec["m"][1].default == SIM_MOTOR_NAME
578578

579579
assert has_default_reference.__name__ in empty_context.plans
580580
model = empty_context.plans[has_default_reference.__name__].model
@@ -593,7 +593,7 @@ def test_nested_str_default(
593593

594594
spec = empty_context._type_spec_for_function(has_default_nested_reference)
595595
assert spec["m"][0] == list[movable_ref]
596-
assert (df := spec["m"][1].default_factory) and df() == [SIM_MOTOR_NAME] # type: ignore
596+
assert spec["m"][1].default == [SIM_MOTOR_NAME]
597597

598598
assert has_default_nested_reference.__name__ in empty_context.plans
599599
model = empty_context.plans[has_default_nested_reference.__name__].model
@@ -697,7 +697,7 @@ def demo_plan(foo: int | None = None) -> MsgGenerator:
697697
empty_context.register_plan(demo_plan)
698698
schema = empty_context.plans["demo_plan"].model.model_json_schema()
699699
assert schema["properties"] == {
700-
"foo": {"title": "Foo", "type": "integer"},
700+
"foo": {"title": "Foo", "type": "integer", "default": None},
701701
}
702702
assert "foo" not in schema.get("required", [])
703703

@@ -725,7 +725,11 @@ def demo_plan(foo: int | str | None = None) -> MsgGenerator:
725725
empty_context.register_plan(demo_plan)
726726
schema = empty_context.plans["demo_plan"].model.model_json_schema()
727727
assert schema["properties"] == {
728-
"foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "string"}]}
728+
"foo": {
729+
"title": "Foo",
730+
"anyOf": [{"type": "integer"}, {"type": "string"}],
731+
"default": None,
732+
}
729733
}
730734
assert "foo" not in schema.get("required", [])
731735

@@ -739,7 +743,10 @@ def demo_plan(foo: int | None) -> MsgGenerator:
739743
empty_context.register_plan(demo_plan)
740744
schema = empty_context.plans["demo_plan"].model.model_json_schema()
741745
assert schema["properties"] == {
742-
"foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "null"}]}
746+
"foo": {
747+
"title": "Foo",
748+
"anyOf": [{"type": "integer"}, {"type": "null"}],
749+
}
743750
}
744751
assert "foo" in schema.get("required", [])
745752

0 commit comments

Comments
 (0)