|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from dataclasses import dataclass, field |
| 4 | +from inspect import Parameter |
4 | 5 | from pathlib import Path |
5 | 6 | from types import ModuleType, NoneType |
6 | | -from typing import Any, Generic, TypeVar, Union |
| 7 | +from typing import Any, Generic, TypeVar, Union, get_args, get_type_hints |
7 | 8 | from unittest.mock import ANY, MagicMock, Mock, patch |
8 | 9 |
|
9 | 10 | import pytest |
@@ -88,6 +89,10 @@ def has_some_params(foo: int = 42, bar: str = "bar") -> MsgGenerator: |
88 | 89 | yield from () |
89 | 90 |
|
90 | 91 |
|
| 92 | +def has_optional_parameter(foo: dict[str, Any] | None = None) -> MsgGenerator: |
| 93 | + yield from () |
| 94 | + |
| 95 | + |
91 | 96 | def has_typeless_param(foo) -> MsgGenerator: |
92 | 97 | yield from () |
93 | 98 |
|
@@ -169,7 +174,9 @@ def some_configurable() -> SomeConfigurable: |
169 | 174 | return SomeConfigurable() |
170 | 175 |
|
171 | 176 |
|
172 | | -@pytest.mark.parametrize("plan", [has_no_params, has_one_param, has_some_params]) |
| 177 | +@pytest.mark.parametrize( |
| 178 | + "plan", [has_no_params, has_one_param, has_some_params, has_optional_parameter] |
| 179 | +) |
173 | 180 | def test_add_plan(empty_context: BlueskyContext, plan: PlanGenerator): |
174 | 181 | empty_context.register_plan(plan) |
175 | 182 | assert plan.__name__ in empty_context.plans |
@@ -428,14 +435,25 @@ def test_with_config_passes_mock_to_with_dodal_module( |
428 | 435 | mock_with_dodal_module.assert_called_once_with(ANY, mock=mock) |
429 | 436 |
|
430 | 437 |
|
431 | | -def test_function_spec(empty_context: BlueskyContext): |
| 438 | +def test_function_spec_with_some_params(empty_context: BlueskyContext): |
432 | 439 | spec = empty_context._type_spec_for_function(has_some_params) |
433 | 440 | assert spec["foo"][0] is int |
434 | 441 | assert spec["foo"][1].default == 42 |
435 | 442 | assert spec["bar"][0] is str |
436 | 443 | assert spec["bar"][1].default == "bar" |
437 | 444 |
|
438 | 445 |
|
| 446 | +def test_function_spec_with_optional_params(empty_context: BlueskyContext): |
| 447 | + spec = empty_context._type_spec_for_function(has_optional_parameter) |
| 448 | + types = get_type_hints(has_optional_parameter) |
| 449 | + arg_type = types.get("foo", Parameter.empty) |
| 450 | + |
| 451 | + _type = SkipJsonSchema[empty_context._convert_type(arg_type, False)] |
| 452 | + inner_type, *annotations = get_args(_type) |
| 453 | + assert spec["foo"][0] == inner_type |
| 454 | + assert spec["foo"][1].default is None |
| 455 | + |
| 456 | + |
439 | 457 | def test_basic_type_conversion(empty_context: BlueskyContext): |
440 | 458 | assert empty_context._convert_type(int) is int |
441 | 459 | assert empty_context._convert_type(dict[str, int]) == dict[str, int] |
|
0 commit comments