Skip to content

Commit afd9841

Browse files
Improve test to include optional parameter
1 parent bd6ceae commit afd9841

1 file changed

Lines changed: 21 additions & 3 deletions

File tree

tests/unit_tests/core/test_context.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4+
from inspect import Parameter
45
from pathlib import Path
56
from types import ModuleType, NoneType
6-
from typing import Any, Generic, TypeVar, Union
7+
from typing import Any, Generic, TypeVar, Union, get_args, get_type_hints
78
from unittest.mock import ANY, MagicMock, Mock, patch
89

910
import pytest
@@ -88,6 +89,10 @@ def has_some_params(foo: int = 42, bar: str = "bar") -> MsgGenerator:
8889
yield from ()
8990

9091

92+
def has_optional_parameter(foo: dict[str, Any] | None = None) -> MsgGenerator:
93+
yield from ()
94+
95+
9196
def has_typeless_param(foo) -> MsgGenerator:
9297
yield from ()
9398

@@ -169,7 +174,9 @@ def some_configurable() -> SomeConfigurable:
169174
return SomeConfigurable()
170175

171176

172-
@pytest.mark.parametrize("plan", [has_no_params, has_one_param, has_some_params])
177+
@pytest.mark.parametrize(
178+
"plan", [has_no_params, has_one_param, has_some_params, has_optional_parameter]
179+
)
173180
def test_add_plan(empty_context: BlueskyContext, plan: PlanGenerator):
174181
empty_context.register_plan(plan)
175182
assert plan.__name__ in empty_context.plans
@@ -428,14 +435,25 @@ def test_with_config_passes_mock_to_with_dodal_module(
428435
mock_with_dodal_module.assert_called_once_with(ANY, mock=mock)
429436

430437

431-
def test_function_spec(empty_context: BlueskyContext):
438+
def test_function_spec_with_some_params(empty_context: BlueskyContext):
432439
spec = empty_context._type_spec_for_function(has_some_params)
433440
assert spec["foo"][0] is int
434441
assert spec["foo"][1].default == 42
435442
assert spec["bar"][0] is str
436443
assert spec["bar"][1].default == "bar"
437444

438445

446+
def test_function_spec_with_optional_params(empty_context: BlueskyContext):
447+
spec = empty_context._type_spec_for_function(has_optional_parameter)
448+
types = get_type_hints(has_optional_parameter)
449+
arg_type = types.get("foo", Parameter.empty)
450+
451+
_type = SkipJsonSchema[empty_context._convert_type(arg_type, False)]
452+
inner_type, *annotations = get_args(_type)
453+
assert spec["foo"][0] == inner_type
454+
assert spec["foo"][1].default is None
455+
456+
439457
def test_basic_type_conversion(empty_context: BlueskyContext):
440458
assert empty_context._convert_type(int) is int
441459
assert empty_context._convert_type(dict[str, int]) == dict[str, int]

0 commit comments

Comments
 (0)