From 486cf2e7cfa44cbc37b0ae27b4b9907285744af3 Mon Sep 17 00:00:00 2001 From: Julio Machado Date: Fri, 24 Apr 2026 01:11:58 -0300 Subject: [PATCH 1/4] ENH: __verify_trigger method and tests implementation. adaptation of out_of_rail_trigger to pass in __verify_trigger Co-authored-by: Copilot --- rocketpy/simulation/events.py | 89 +++++++++++++++++++++++++++- rocketpy/simulation/flight.py | 2 +- tests/unit/simulation/test_events.py | 25 ++++++++ 3 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 tests/unit/simulation/test_events.py diff --git a/rocketpy/simulation/events.py b/rocketpy/simulation/events.py index 13b775725..95ab746a9 100644 --- a/rocketpy/simulation/events.py +++ b/rocketpy/simulation/events.py @@ -1,3 +1,6 @@ +import inspect +from typing import get_type_hints +import warnings class Event: # TODO: should "sensors" arg of the trigger function be a dictionary instead # of a list? It would be more intuitive to access the sensors by name @@ -89,11 +92,10 @@ def __init__(self, trigger, action, name, event_context=None): passed to subsequent calls. Defaults to an empty dictionary if not provided. """ + self.name = name self.trigger = self.__verify_trigger(trigger) self.action = self.__verify_action(action) - self.name = name self.event_context = event_context if event_context is not None else {} - # TODO: implement tracking for whether this event is currently enabled # or disabled. The disable_event flag from the action return value should # control whether this event continues to be checked for triggering. @@ -122,8 +124,89 @@ def __verify_trigger(self, trigger): # 2. Return type annotation is bool or can be tested to return bool # 3. Consider allowing signature to be flexible (accepts **kwargs) # to accommodate user-defined custom event_context keys + # verify if the return type is bool when annotated + return_annotation = get_type_hints(trigger).get('return', None) + if return_annotation is not None and return_annotation is not bool: + raise ValueError(f"Trigger function {self.name} must return a boolean value.") + # verify if the trigger function accepts **kwargs and therefore can + # receive standard event arguments plus custom event_context keys + s = inspect.signature(trigger) + if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()): + raise ValueError( + f"Trigger function {self.name} must accept **kwargs to receive event context " + f"and simulation state." + ) + if any(p.kind == inspect.Parameter.POSITIONAL_ONLY for p in s.parameters.values()): + raise ValueError( + f"Trigger function {self.name} must accept keyword arguments; " + "positional-only parameters are not supported." + ) + # Helper function to generate dummy values based on type annotations + # of parameters, allowing to test the function without real values + def _placeholder_for_parameter(parameter): + annotation = parameter.annotation + if annotation is inspect.Parameter.empty: + warnings.warn(f"Trigger function {self.name}: Test with parameters skipped due " + f"to missing type annotation for parameter '{parameter.name}'. \n" + f"Is highly recommended that parameters have type annotations " + f"(var: type). Parameter '{parameter.name}' has no annotation.") + skip_test = True + return None, skip_test + if annotation in (int, float): + return 0, False + if annotation is bool: + return False, False + if annotation is str: + return "", False + if annotation in (list, tuple, set, dict): + return annotation(), False + origin = getattr(annotation, "__origin__", None) + if origin in (list, tuple, set, dict): + return origin(), False + return None, False + # Build a dictionary with dummy values to test if function accepts **kwargs + # Include an unexpected argument to validate the function doesn't complain + test_kwargs = {"unexpected_kwarg": 123} + skip_test = False + # Iterate through function parameters to generate appropriate test values + for name, parameter in s.parameters.items(): + if parameter.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + if parameter.default is inspect.Parameter.empty: + annotation = parameter.annotation + if annotation in (list, tuple, set, dict): + skip_test = True + elif hasattr(annotation, "__origin__") and getattr(annotation, "__origin__", None) in (list, tuple, set, dict): + skip_test = True + else: + test_kwargs[name], skip_test = _placeholder_for_parameter(parameter) + # Execute the trigger function with test values to validate compatibility + # If TypeError occurs, the function doesn't properly accept **kwargs + if not skip_test: + try: + trigger(**test_kwargs) + except TypeError as e: + raise ValueError( + f"Trigger function {self.name} must accept arbitrary kwargs without raising " + "a TypeError." + ) from e + except Exception as e: + raise ValueError( + f"Trigger function {self.name} must accept arbitrary kwargs without raising " + f"an error: {e}" + ) from e + else: + # Test was skipped due to complex types; warn user to validate manually + warnings.warn( + f"Trigger function {self.name}: Test with parameters " + f"skipped for parameters with complex types " + f"(list, tuple, set, dict). Ensure the function handles " + f"arbitrary inputs gracefully." + ) return trigger - + def __verify_action(self, action): """Verifies that the action function is valid. diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 9e78d169b..50dbf5a8d 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -619,7 +619,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements self.ode_solver = ode_solver # Events - def out_of_rail_trigger(state): + def out_of_rail_trigger(state, **kwargs) -> bool: return ( state[0] ** 2 + state[1] ** 2 + (state[2] - self.env.elevation) ** 2 >= self.effective_1rl**2 diff --git a/tests/unit/simulation/test_events.py b/tests/unit/simulation/test_events.py new file mode 100644 index 000000000..f8e272b31 --- /dev/null +++ b/tests/unit/simulation/test_events.py @@ -0,0 +1,25 @@ +import pytest + +from rocketpy.simulation.events import Event + + +def test_verify_trigger_accepts_required_args_with_kwargs(): + def trigger(a: int, b: float, **kwargs) -> bool: + return True + + def action(**kwargs): + return None + + event = Event(trigger=trigger, action=action, name="test") + assert event.trigger is trigger + + +def test_verify_trigger_rejects_missing_kwargs(): + def trigger(a, b) -> bool: + return True + + def action(**kwargs): + return None + + with pytest.raises(ValueError, match=r"must accept \*\*kwargs"): + Event(trigger=trigger, action=action, name="test") From a2599400fc3f5c012d6bd06fc051436ca584ff20 Mon Sep 17 00:00:00 2001 From: Julio Machado Date: Tue, 28 Apr 2026 19:40:49 -0300 Subject: [PATCH 2/4] mnt: formatting --- rocketpy/simulation/events.py | 119 +++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 52 deletions(-) diff --git a/rocketpy/simulation/events.py b/rocketpy/simulation/events.py index 95ab746a9..deb5fd779 100644 --- a/rocketpy/simulation/events.py +++ b/rocketpy/simulation/events.py @@ -1,31 +1,33 @@ import inspect -from typing import get_type_hints import warnings +from typing import get_type_hints + + class Event: # TODO: should "sensors" arg of the trigger function be a dictionary instead - # of a list? It would be more intuitive to access the sensors by name + # of a list? It would be more intuitive to access the sensors by name def __init__(self, trigger, action, name, event_context=None): """Initializes an Event object. Parameters ---------- trigger : function - A function that must return a boolean value. The event will be - triggered when this function returns True. The function should be + A function that must return a boolean value. The event will be + triggered when this function returns True. The function should be defined with the following signature: trigger(**kwargs), where kwargs is a dictionary containing the keys: - `"time"` (float): The current simulation time in seconds. - `"state"` (list): The state vector of the simulation, structured as `[x, y, z, vx, vy, vz, e0, e1, e2, e3, wx, wy, wz]`. - - `"state_dot"` (list): The time derivative of the state vector, + - `"state_dot"` (list): The time derivative of the state vector, structured as `[vx, vy, vz, ax, ay, az, e0_dot, e1_dot, e2_dot, e3_dot, wx_dot, wy_dot, wz_dot]`. - `"sampling_rate"` (float or None): The sampling rate of the event, in seconds. If None, the event will be checked for triggering at every time step of the simulation. If a float - value is provided, the event will only be checked for + value is provided, the event will only be checked for triggering at that specific time interval. - - `"sensors"` (list): A list of sensors that are attached to the + - `"sensors"` (list): A list of sensors that are attached to the rocket. The most recent measurements of the sensors are provided with the ``sensor.measurement`` attribute. The sensors are listed in the same order as they are added to the rocket. @@ -37,20 +39,20 @@ def __init__(self, trigger, action, name, event_context=None): - `"phase_index"` (int): The index of the current flight phase. - `"node_index"` (int): The index of the current node in the current flight phase. - - Any additional custom key-value pairs provided via the + - Any additional custom key-value pairs provided via the `event_context` parameter (see below). - + action : function A function that will be executed when the event is triggered. The - function should be defined with the following signature: + function should be defined with the following signature: action(**kwargs), where kwargs is a dictionary containing the same keys as the trigger function. The action function can also modify the state of the simulation by returning a dictionary with the keys: - `"state"` (list): A new state vector to replace the current state vector. The structure of the state vector is the same as the one provided in the trigger function. - - `"disable_event"` (bool): If True, the event will not be - checked for triggering again after being triggered, making + - `"disable_event"` (bool): If True, the event will not be + checked for triggering again after being triggered, making it a one-time event. Defaults to True. - `"new_events"` (list): A list of new Event objects to be added to the simulation when the event is triggered. This can be @@ -58,38 +60,38 @@ def __init__(self, trigger, action, name, event_context=None): triggered, such as a parachute deployment event that spawns a new event to check for the parachute deployment after a certain time delay. - - `"remove_events"` (list): A list of Event objects to be - removed from the simulation when the event is triggered. This - can be used to create events that remove other events when - they are triggered, such as a parachute deployment event that + - `"remove_events"` (list): A list of Event objects to be + removed from the simulation when the event is triggered. This + can be used to create events that remove other events when + they are triggered, such as a parachute deployment event that removes the apogee event when it is triggered. - - Any other key-value pairs defined in `event_context` will - also be included. These allow you to maintain custom state or - counters across multiple trigger and action calls. Use cases + - Any other key-value pairs defined in `event_context` will + also be included. These allow you to maintain custom state or + counters across multiple trigger and action calls. Use cases include: tracking the number of times an event has been triggered - (e.g., `{"trigger_count": 0}`), recording the time of the last - trigger (e.g., `{"last_trigger_time": None}`), or any other + (e.g., `{"trigger_count": 0}`), recording the time of the last + trigger (e.g., `{"last_trigger_time": None}`), or any other custom data your trigger/action functions need to share state. - - Example: If you initialize the event with - `event_context={"trigger_count": 0}`, your trigger and action - functions will receive `trigger_count=0` in their kwargs dict. - You can then update this value in the action function by - including it in the returned dictionary (e.g., - `{"trigger_count": 1}`), and it will be passed to subsequent + + Example: If you initialize the event with + `event_context={"trigger_count": 0}`, your trigger and action + functions will receive `trigger_count=0` in their kwargs dict. + You can then update this value in the action function by + including it in the returned dictionary (e.g., + `{"trigger_count": 1}`), and it will be passed to subsequent trigger/action calls. name : str A name for the event, used for identification purposes. event_context : dict, optional - A dictionary of custom key-value pairs that will be passed to the - trigger and action functions. This allows you to initialize and - maintain custom state that persists across multiple trigger/action - calls. For example, `event_context={"trigger_count": 0, - "last_trigger_time": None}` can be used to track event state. - When the action function returns a dictionary with updated values - (e.g., `{"trigger_count": 1}`), those values persist and are - passed to subsequent calls. Defaults to an empty dictionary if not + A dictionary of custom key-value pairs that will be passed to the + trigger and action functions. This allows you to initialize and + maintain custom state that persists across multiple trigger/action + calls. For example, `event_context={"trigger_count": 0, + "last_trigger_time": None}` can be used to track event state. + When the action function returns a dictionary with updated values + (e.g., `{"trigger_count": 1}`), those values persist and are + passed to subsequent calls. Defaults to an empty dictionary if not provided. """ self.name = name @@ -125,31 +127,40 @@ def __verify_trigger(self, trigger): # 3. Consider allowing signature to be flexible (accepts **kwargs) # to accommodate user-defined custom event_context keys # verify if the return type is bool when annotated - return_annotation = get_type_hints(trigger).get('return', None) + return_annotation = get_type_hints(trigger).get("return", None) if return_annotation is not None and return_annotation is not bool: - raise ValueError(f"Trigger function {self.name} must return a boolean value.") + raise ValueError( + f"Trigger function {self.name} must return a boolean value." + ) # verify if the trigger function accepts **kwargs and therefore can # receive standard event arguments plus custom event_context keys s = inspect.signature(trigger) - if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()): + if not any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in s.parameters.values() + ): raise ValueError( f"Trigger function {self.name} must accept **kwargs to receive event context " f"and simulation state." ) - if any(p.kind == inspect.Parameter.POSITIONAL_ONLY for p in s.parameters.values()): + if any( + p.kind == inspect.Parameter.POSITIONAL_ONLY for p in s.parameters.values() + ): raise ValueError( f"Trigger function {self.name} must accept keyword arguments; " "positional-only parameters are not supported." ) + # Helper function to generate dummy values based on type annotations # of parameters, allowing to test the function without real values def _placeholder_for_parameter(parameter): annotation = parameter.annotation if annotation is inspect.Parameter.empty: - warnings.warn(f"Trigger function {self.name}: Test with parameters skipped due " - f"to missing type annotation for parameter '{parameter.name}'. \n" + warnings.warn( + f"Trigger function {self.name}: Test with parameters skipped due " + f"to missing type annotation for parameter '{parameter.name}'. \n" f"Is highly recommended that parameters have type annotations " - f"(var: type). Parameter '{parameter.name}' has no annotation.") + f"(var: type). Parameter '{parameter.name}' has no annotation." + ) skip_test = True return None, skip_test if annotation in (int, float): @@ -164,6 +175,7 @@ def _placeholder_for_parameter(parameter): if origin in (list, tuple, set, dict): return origin(), False return None, False + # Build a dictionary with dummy values to test if function accepts **kwargs # Include an unexpected argument to validate the function doesn't complain test_kwargs = {"unexpected_kwarg": 123} @@ -178,10 +190,14 @@ def _placeholder_for_parameter(parameter): annotation = parameter.annotation if annotation in (list, tuple, set, dict): skip_test = True - elif hasattr(annotation, "__origin__") and getattr(annotation, "__origin__", None) in (list, tuple, set, dict): + elif hasattr(annotation, "__origin__") and getattr( + annotation, "__origin__", None + ) in (list, tuple, set, dict): skip_test = True else: - test_kwargs[name], skip_test = _placeholder_for_parameter(parameter) + test_kwargs[name], skip_test = _placeholder_for_parameter( + parameter + ) # Execute the trigger function with test values to validate compatibility # If TypeError occurs, the function doesn't properly accept **kwargs if not skip_test: @@ -274,21 +290,20 @@ def __call__(self, *args, **kwds): pass - # TODO: add a parameter to the Event class that specify whether the event should -# be triggered only once, or if it can be triggered multiple times. Also, add a +# be triggered only once, or if it can be triggered multiple times. Also, add a # way to stop the event from continuously triggering on command inside the action -# function, such as a "disable" method that can be called inside the action +# function, such as a "disable" method that can be called inside the action # function to prevent the event from being triggered again. # TODO: add a parameter to the Event class that specify whether the event should # be a discrete event, meaning that it should only be checked for triggering at # specific time intervals (e.g. every 0.1 seconds) instead of at every time step -# of the simulation. This would be useful for parachute events. This should be +# of the simulation. This would be useful for parachute events. This should be # done by adding a "sampling_rate" parameter to the Event class, that is none by # default (meaning that the event is checked at every time step), but if it is -# set to a float value, the event will only be checked for triggering at that -# specific time interval. The flight class should be able to differentiate +# set to a float value, the event will only be checked for triggering at that +# specific time interval. The flight class should be able to differentiate # between the discrete and continuous events (we will handle this later) @@ -302,4 +317,4 @@ def __call__(self, *args, **kwds): # - Respect the disable_event flag and sampling_rate to control when events # are checked for triggering # - Handle the sampling_rate logic: only check events at the specified intervals, -# not at every simulation time step \ No newline at end of file +# not at every simulation time step From bc80a1f2a9ee2497cb590e1d3b867f2fea868ac5 Mon Sep 17 00:00:00 2001 From: Julio Machado Date: Wed, 29 Apr 2026 16:30:29 -0300 Subject: [PATCH 3/4] ENH: Events class - Update of the verify trigger method, implementation of the verify action method and adaptation of the out of rail trigger and action functions in flight class in order to pass in the new verifications Co-authored-by: Copilot --- rocketpy/simulation/events.py | 140 ++++++--------------------- rocketpy/simulation/flight.py | 14 ++- tests/unit/simulation/test_events.py | 38 +++++++- 3 files changed, 77 insertions(+), 115 deletions(-) diff --git a/rocketpy/simulation/events.py b/rocketpy/simulation/events.py index deb5fd779..79d6e6ccc 100644 --- a/rocketpy/simulation/events.py +++ b/rocketpy/simulation/events.py @@ -1,9 +1,10 @@ import inspect -import warnings from typing import get_type_hints class Event: + """A class representing an event in the simulation.""" + # TODO: should "sensors" arg of the trigger function be a dictionary instead # of a list? It would be more intuitive to access the sensors by name def __init__(self, trigger, action, name, event_context=None): @@ -102,7 +103,6 @@ def __init__(self, trigger, action, name, event_context=None): # or disabled. The disable_event flag from the action return value should # control whether this event continues to be checked for triggering. - # TODO: check_trigger does note receive enough arguments to substitute parachute events def __verify_trigger(self, trigger): """Verifies that the trigger function is valid. @@ -119,107 +119,23 @@ def __verify_trigger(self, trigger): Raises ------ ValueError - If the trigger function does not have the correct signature or does not return a boolean value. + If the trigger function does not have the correct signature or does not return a boolean value + (at least if not declared or annotated). """ - # TODO: implement inspection of trigger function to verify: - # 1. It accepts **kwargs (accepts arbitrary keyword arguments) - # 2. Return type annotation is bool or can be tested to return bool - # 3. Consider allowing signature to be flexible (accepts **kwargs) - # to accommodate user-defined custom event_context keys - # verify if the return type is bool when annotated - return_annotation = get_type_hints(trigger).get("return", None) - if return_annotation is not None and return_annotation is not bool: - raise ValueError( - f"Trigger function {self.name} must return a boolean value." - ) - # verify if the trigger function accepts **kwargs and therefore can - # receive standard event arguments plus custom event_context keys + # verify if the trigger function accepts only **kwargs arguments s = inspect.signature(trigger) - if not any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in s.parameters.values() - ): + if any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()): raise ValueError( - f"Trigger function {self.name} must accept **kwargs to receive event context " - f"and simulation state." + f"The Trigger function of the {self.name} event must accept only keyword arguments. def {trigger.__name__}(**kwargs) -> bool:" ) - if any( - p.kind == inspect.Parameter.POSITIONAL_ONLY for p in s.parameters.values() - ): + # Verify if the return type annotation is bool. + # Since is not possible to know for sure if the user is actually returning a bool value, + # we enforce bool annotation to motivate users to actually return bool values. + return_annotation = get_type_hints(trigger).get("return", None) + if return_annotation is not bool: raise ValueError( - f"Trigger function {self.name} must accept keyword arguments; " - "positional-only parameters are not supported." - ) - - # Helper function to generate dummy values based on type annotations - # of parameters, allowing to test the function without real values - def _placeholder_for_parameter(parameter): - annotation = parameter.annotation - if annotation is inspect.Parameter.empty: - warnings.warn( - f"Trigger function {self.name}: Test with parameters skipped due " - f"to missing type annotation for parameter '{parameter.name}'. \n" - f"Is highly recommended that parameters have type annotations " - f"(var: type). Parameter '{parameter.name}' has no annotation." - ) - skip_test = True - return None, skip_test - if annotation in (int, float): - return 0, False - if annotation is bool: - return False, False - if annotation is str: - return "", False - if annotation in (list, tuple, set, dict): - return annotation(), False - origin = getattr(annotation, "__origin__", None) - if origin in (list, tuple, set, dict): - return origin(), False - return None, False - - # Build a dictionary with dummy values to test if function accepts **kwargs - # Include an unexpected argument to validate the function doesn't complain - test_kwargs = {"unexpected_kwarg": 123} - skip_test = False - # Iterate through function parameters to generate appropriate test values - for name, parameter in s.parameters.items(): - if parameter.kind in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ): - if parameter.default is inspect.Parameter.empty: - annotation = parameter.annotation - if annotation in (list, tuple, set, dict): - skip_test = True - elif hasattr(annotation, "__origin__") and getattr( - annotation, "__origin__", None - ) in (list, tuple, set, dict): - skip_test = True - else: - test_kwargs[name], skip_test = _placeholder_for_parameter( - parameter - ) - # Execute the trigger function with test values to validate compatibility - # If TypeError occurs, the function doesn't properly accept **kwargs - if not skip_test: - try: - trigger(**test_kwargs) - except TypeError as e: - raise ValueError( - f"Trigger function {self.name} must accept arbitrary kwargs without raising " - "a TypeError." - ) from e - except Exception as e: - raise ValueError( - f"Trigger function {self.name} must accept arbitrary kwargs without raising " - f"an error: {e}" - ) from e - else: - # Test was skipped due to complex types; warn user to validate manually - warnings.warn( - f"Trigger function {self.name}: Test with parameters " - f"skipped for parameters with complex types " - f"(list, tuple, set, dict). Ensure the function handles " - f"arbitrary inputs gracefully." + f"The Trigger function of the {self.name} event must return a boolean value and must be annotated with '-> bool' for type checking.\n" + f"def {trigger.__name__}(**kwargs) -> bool:" ) return trigger @@ -239,17 +155,25 @@ def __verify_action(self, action): Raises ------ ValueError - If the action function does not have the correct signature. + If the action function does not have the correct signature or does not return a valid type. """ - # TODO: implement inspection of action function to verify: - # 1. It accepts **kwargs (accepts arbitrary keyword arguments) - # 2. It can optionally return None or a dict with any of these keys: - # - \"state\": list of floats - # - \"disable_event\": bool - # - \"new_events\": list of Event objects - # - \"remove_events\": list of Event objects - # - Any custom keys to update event_context - # 3. Raise ValueError if signature doesn't match expectations + # verify if the action function accepts only **kwargs arguments + s = inspect.signature(action) + if any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()): + raise ValueError( + f"The Action function of the {self.name} event must accept only keyword arguments. def {action.__name__}(**kwargs) -> None or dict:" + ) + # verify if the return type annotation is None or dict + # Since is not possible to know for sure if the user is actually returning None or a dict, + # we enforce None or dict annotation to motivate users to actually return None or dict. + return_annotation = get_type_hints(action).get("return", None) + if return_annotation is not None and return_annotation is not ( + type(None) or dict + ): + raise ValueError( + f"The Action function of the {self.name} event must return None or a dictionary and must be annotated with '-> None' or '-> dict' for type checking.\n" + f"def {action.__name__}(**kwargs) -> None or dict:" + ) return action def __repr__(self): diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 50dbf5a8d..614fe0852 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -619,7 +619,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements self.ode_solver = ode_solver # Events - def out_of_rail_trigger(state, **kwargs) -> bool: + def out_of_rail_trigger(**kwargs) -> bool: + state = kwargs["state"] return ( state[0] ** 2 + state[1] ** 2 + (state[2] - self.env.elevation) ** 2 >= self.effective_1rl**2 @@ -1031,8 +1032,10 @@ def __check_simulation_events(self, phase, phase_index, node_index): # TODO: make all these 3 events be handled with the Events class # Check for first out of rail event if len(self.out_of_rail_state) == 1: - if self.out_of_rail_event.trigger(self.y_sol): - return self.out_of_rail_event.action(phase, phase_index, node_index) + if self.out_of_rail_event.trigger(state=self.y_sol): + return self.out_of_rail_event.action( + phase=phase, phase_index=phase_index, node_index=node_index + ) # Check for apogee event # TODO: negative vz doesn't really mean apogee. Improve this. @@ -1045,7 +1048,7 @@ def __check_simulation_events(self, phase, phase_index, node_index): return False - def __handle_out_of_rail_event(self, phase, phase_index, node_index): + def __handle_out_of_rail_event(self, **kwargs): """Handle the out of rail event. Parameters @@ -1062,6 +1065,9 @@ def __handle_out_of_rail_event(self, phase, phase_index, node_index): bool True to indicate the simulation should break. """ + phase = kwargs.get("phase") + phase_index = kwargs.get("phase_index") + node_index = kwargs.get("node_index") # Check exactly when it went out using root finding # Disconsider elevation self.solution[-2][3] -= self.env.elevation diff --git a/tests/unit/simulation/test_events.py b/tests/unit/simulation/test_events.py index f8e272b31..057e84cf0 100644 --- a/tests/unit/simulation/test_events.py +++ b/tests/unit/simulation/test_events.py @@ -3,8 +3,8 @@ from rocketpy.simulation.events import Event -def test_verify_trigger_accepts_required_args_with_kwargs(): - def trigger(a: int, b: float, **kwargs) -> bool: +def test_verify_trigger_accepts_only_kwargs(): + def trigger(**kwargs) -> bool: return True def action(**kwargs): @@ -21,5 +21,37 @@ def trigger(a, b) -> bool: def action(**kwargs): return None - with pytest.raises(ValueError, match=r"must accept \*\*kwargs"): + with pytest.raises( + ValueError, + match=r"The Trigger function of the test event must accept only keyword arguments. def trigger\(\*\*kwargs\) -> bool:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_trigger_rejects_args_with_kwargs(): + def trigger(a, b, **kwargs) -> bool: + return True + + def action(**kwargs): + return None + + with pytest.raises( + ValueError, + match=r"The Trigger function of the test event must accept only keyword arguments. def trigger\(\*\*kwargs\) -> bool:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_trigger_rejects_triggers_without_bool_return_annotation(): + def trigger(**kwargs): + return True + + def action(**kwargs): + return None + + with pytest.raises( + ValueError, + match="The Trigger function of the test event must return a boolean value and must be annotated with '-> bool' for type checking.\n" + + r"def trigger\(\*\*kwargs\) -> bool\:", + ): Event(trigger=trigger, action=action, name="test") From fa55a623734863050c03e2566bb4ae76eaa297f6 Mon Sep 17 00:00:00 2001 From: Julio Machado Date: Mon, 4 May 2026 13:17:01 -0300 Subject: [PATCH 4/4] ENH: correction and update of verify trigger and action methods and implementation of new tests, mostly for verify actions method. --- rocketpy/simulation/events.py | 23 ++++-- rocketpy/simulation/flight.py | 10 +-- tests/unit/simulation/test_events.py | 116 ++++++++++++++++++++++++++- 3 files changed, 133 insertions(+), 16 deletions(-) diff --git a/rocketpy/simulation/events.py b/rocketpy/simulation/events.py index 79d6e6ccc..24e650dd0 100644 --- a/rocketpy/simulation/events.py +++ b/rocketpy/simulation/events.py @@ -123,8 +123,12 @@ def __verify_trigger(self, trigger): (at least if not declared or annotated). """ # verify if the trigger function accepts only **kwargs arguments + # also avoids functions with no arguments, since they can't be used as triggers s = inspect.signature(trigger) - if any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()): + if ( + any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()) + or len(s.parameters) == 0 + ): raise ValueError( f"The Trigger function of the {self.name} event must accept only keyword arguments. def {trigger.__name__}(**kwargs) -> bool:" ) @@ -159,20 +163,25 @@ def __verify_action(self, action): """ # verify if the action function accepts only **kwargs arguments s = inspect.signature(action) - if any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()): + if ( + any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()) + or len(s.parameters) == 0 + ): raise ValueError( - f"The Action function of the {self.name} event must accept only keyword arguments. def {action.__name__}(**kwargs) -> None or dict:" + f"The Action function of the {self.name} event must accept only keyword arguments. def {action.__name__}(**kwargs) -> None | dict:" ) # verify if the return type annotation is None or dict # Since is not possible to know for sure if the user is actually returning None or a dict, # we enforce None or dict annotation to motivate users to actually return None or dict. - return_annotation = get_type_hints(action).get("return", None) - if return_annotation is not None and return_annotation is not ( - type(None) or dict + return_annotation = get_type_hints(action).get("return", int) + if ( + (return_annotation is not type(None)) + and (return_annotation is not dict) + and (return_annotation is not bool) ): raise ValueError( f"The Action function of the {self.name} event must return None or a dictionary and must be annotated with '-> None' or '-> dict' for type checking.\n" - f"def {action.__name__}(**kwargs) -> None or dict:" + f"def {action.__name__}(**kwargs) -> None | dict:" ) return action diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 614fe0852..2e38a6378 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -1048,10 +1048,10 @@ def __check_simulation_events(self, phase, phase_index, node_index): return False - def __handle_out_of_rail_event(self, **kwargs): + def __handle_out_of_rail_event(self, **kwargs) -> bool: """Handle the out of rail event. - Parameters + keyword arguments are passed by the Event class when the trigger function is called. ---------- phase : FlightPhase The current flight phase. @@ -1065,9 +1065,9 @@ def __handle_out_of_rail_event(self, **kwargs): bool True to indicate the simulation should break. """ - phase = kwargs.get("phase") - phase_index = kwargs.get("phase_index") - node_index = kwargs.get("node_index") + phase = kwargs["phase"] + phase_index = kwargs["phase_index"] + node_index = kwargs["node_index"] # Check exactly when it went out using root finding # Disconsider elevation self.solution[-2][3] -= self.env.elevation diff --git a/tests/unit/simulation/test_events.py b/tests/unit/simulation/test_events.py index 057e84cf0..9e569eeb8 100644 --- a/tests/unit/simulation/test_events.py +++ b/tests/unit/simulation/test_events.py @@ -7,18 +7,35 @@ def test_verify_trigger_accepts_only_kwargs(): def trigger(**kwargs) -> bool: return True - def action(**kwargs): + def action(**kwargs) -> None: return None event = Event(trigger=trigger, action=action, name="test") assert event.trigger is trigger +def test_verify_trigger_evaluation_of_number_of_parameters(): + def trigger(**kwargs) -> bool: + a = kwargs["a"] + b = kwargs["b"] + c = kwargs["c"] + return a + b + c == 6 + + def action(**kwargs) -> None: + return None + + kwargs_test = {"a": 1, "b": 2, "c": 3} + assert trigger(**kwargs_test) + + event = Event(trigger=trigger, action=action, name="test") + assert event.trigger is trigger + + def test_verify_trigger_rejects_missing_kwargs(): def trigger(a, b) -> bool: return True - def action(**kwargs): + def action(**kwargs) -> None: return None with pytest.raises( @@ -32,7 +49,21 @@ def test_verify_trigger_rejects_args_with_kwargs(): def trigger(a, b, **kwargs) -> bool: return True - def action(**kwargs): + def action(**kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Trigger function of the test event must accept only keyword arguments. def trigger\(\*\*kwargs\) -> bool:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_trigger_rejects_triggers_with_no_parameters(): + def trigger() -> bool: + return True + + def action(**kwargs) -> None: return None with pytest.raises( @@ -46,7 +77,7 @@ def test_verify_trigger_rejects_triggers_without_bool_return_annotation(): def trigger(**kwargs): return True - def action(**kwargs): + def action(**kwargs) -> None: return None with pytest.raises( @@ -55,3 +86,80 @@ def action(**kwargs): + r"def trigger\(\*\*kwargs\) -> bool\:", ): Event(trigger=trigger, action=action, name="test") + + +# The following tests verify if action functions were correctly implemented + + +def test_verify_action_accepts_only_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> None: + return None + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action + + +def test_verify_action_rejects_missing_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(a, b) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Action function of the test event must accept only keyword arguments. def action\(\*\*kwargs\) -> None \| dict:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_action_rejects_args_with_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(a, b, **kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Action function of the test event must accept only keyword arguments. def action\(\*\*kwargs\) -> None \| dict:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_action_accepts_dict_return_type(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> dict: + return {"key": "value"} + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action + + +def test_verify_action_accepts_none_return_type(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> None: + return None + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action + + +# this was also allowed because some actions functions already return bool, they need to be updated +# then this test can be removed and the check for bool return type can be removed from the events.py file +def test_verify_action_accepts_bool_return_type(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> bool: + return True + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action