diff --git a/docs/source/tendencies.rst b/docs/source/tendencies.rst index 88030e0e..88458147 100644 --- a/docs/source/tendencies.rst +++ b/docs/source/tendencies.rst @@ -51,6 +51,28 @@ If the ``value`` is not specified, it will be set to the last value of the previ - {type: linear, to: 3, duration: 10} - {type: constant, duration: 10} +Value types +----------- + +The ``value`` is type-aware and is not restricted to floating-point numbers: + +* **Numeric values** (floats and integers) produce a floating-point signal that is + interpolated across gaps between tendencies, as usual. Integers are treated as + floats during evaluation. +* **Categorical values** (strings or booleans) describe non-numeric signals, such + as a heating scheme (``"nbi"``, ``"ec"``) or an on/off flag. These are held as a + step (zero-order hold) rather than interpolated: across a gap, the previous value + is carried forward, and the same applies when extrapolating beyond the waveform. + +.. code-block:: yaml + + - {type: constant, value: ohmic, duration: 2} + - {type: constant, value: nbi, duration: 2} + +.. note:: + Categorical and numeric values cannot be mixed within a single waveform, because + the gaps between them cannot be interpolated. Doing so raises a validation error. + Linear Tendency =============== diff --git a/tests/tendencies/test_constant.py b/tests/tendencies/test_constant.py index c21f87cf..b557dfec 100644 --- a/tests/tendencies/test_constant.py +++ b/tests/tendencies/test_constant.py @@ -86,3 +86,49 @@ def test_declarative_assignments(): assert t2.value == 6 assert not t1.annotations assert not t2.annotations + + +def test_integer_value_not_coerced(): + """Integer inputs are kept as integers, not coerced to float.""" + tendency = ConstantTendency(user_duration=1, user_value=5) + assert tendency.value == 5 + assert isinstance(tendency.value, int) + assert not tendency.is_categorical + assert not tendency.annotations + + +def test_string_value(): + """A constant tendency can hold a string (non-numeric) value.""" + tendency = ConstantTendency(user_duration=2, user_value="nbi") + assert tendency.value == "nbi" + assert tendency.is_categorical + assert tendency.start_value == "nbi" + assert tendency.end_value == "nbi" + + _, values = tendency.get_value(np.array([0.0, 1.0, 2.0])) + assert list(values) == ["nbi", "nbi", "nbi"] + assert not tendency.annotations + + +def test_boolean_value(): + """A constant tendency can hold a boolean value, kept as a (categorical) bool.""" + tendency = ConstantTendency(user_duration=2, user_value=True) + assert tendency.value is True + assert tendency.is_categorical + # start/end values round-trip through numpy, so compare by value (np.bool_) + assert bool(tendency.start_value) is True + assert bool(tendency.end_value) is True + + _, values = tendency.get_value(np.array([0.0, 1.0, 2.0])) + assert [bool(v) for v in values] == [True, True, True] + assert values.dtype == bool + assert not tendency.annotations + + +def test_string_value_chained(): + """A value-less string constant inherits the previous string value.""" + prev = ConstantTendency(user_value="ec", user_start=0, user_duration=1) + tendency = ConstantTendency(user_duration=1) + tendency.set_previous_tendency(prev) + assert tendency.value == "ec" + assert tendency.is_categorical diff --git a/tests/test_waveform.py b/tests/test_waveform.py index 6735d8a3..2786e760 100644 --- a/tests/test_waveform.py +++ b/tests/test_waveform.py @@ -273,3 +273,69 @@ def test_overlap_derivatives(): expected = [2, 2, -1.5, -1.5, -1.5, -1.5, -1.5] values = waveform.get_derivative(np.linspace(0, 3, 7)) assert np.allclose(values, expected) + + +def test_string_waveform(): + """A waveform of string constants evaluates as a zero-order-hold step function.""" + waveform = Waveform( + waveform=[ + {"user_type": "constant", "user_value": "ohmic", "user_duration": 2}, + {"user_type": "constant", "user_value": "nbi", "user_duration": 2}, + ] + ) + assert waveform.is_categorical + + _, values = waveform.get_value(np.array([0.0, 1.0, 2.0, 3.0])) + # The switch happens at t=2; later tendencies take precedence at the boundary. + assert list(values) == ["ohmic", "ohmic", "nbi", "nbi"] + + # Values are held (not interpolated) when extrapolating beyond the domain. + _, extrap = waveform.get_value(np.array([-1.0, 5.0])) + assert list(extrap) == ["ohmic", "nbi"] + + +def test_boolean_waveform(): + """A waveform of boolean constants evaluates as a zero-order-hold step function.""" + waveform = Waveform( + waveform=[ + {"user_type": "constant", "user_value": False, "user_duration": 2}, + {"user_type": "constant", "user_value": True, "user_duration": 2}, + ] + ) + assert waveform.is_categorical + + _, values = waveform.get_value(np.array([0.0, 1.0, 2.0, 3.0])) + assert [bool(v) for v in values] == [False, False, True, True] + + +def test_numeric_waveform_evaluates_to_float(): + """A numeric waveform evaluates to a float array (the int value is not stepwise).""" + waveform = Waveform( + waveform=[{"user_type": "constant", "user_value": 8, "user_duration": 3}] + ) + assert not waveform.is_categorical + _, values = waveform.get_value(np.array([0.0, 1.0, 2.0, 3.0])) + assert values.dtype == float + + +def test_mixing_categorical_and_numeric_is_flagged(): + """Mixing categorical (string/bool) and numeric tendencies adds an annotation.""" + waveform = Waveform( + waveform=[ + {"user_type": "constant", "user_value": "nbi", "user_duration": 2}, + {"user_type": "constant", "user_value": 3, "user_duration": 2}, + ] + ) + assert waveform.annotations + assert any("mix" in a["text"].lower() for a in waveform.annotations) + + +def test_homogeneous_categorical_not_flagged(): + """A waveform of only categorical values is not flagged as mixed.""" + waveform = Waveform( + waveform=[ + {"user_type": "constant", "user_value": "nbi", "user_duration": 2}, + {"user_type": "constant", "user_value": "ec", "user_duration": 2}, + ] + ) + assert not waveform.annotations diff --git a/waveform_editor/tendencies/base.py b/waveform_editor/tendencies/base.py index 50b43198..f92246d3 100644 --- a/waveform_editor/tendencies/base.py +++ b/waveform_editor/tendencies/base.py @@ -61,8 +61,8 @@ class BaseTendency(param.Parameterized): values from the start value of this tendency. """, ) - start_value = param.Number(default=0.0, doc="Value at self.start") - end_value = param.Number(default=0.0, doc="Value at self.end") + start_value = param.Parameter(default=0.0, doc="Value at self.start") + end_value = param.Parameter(default=0.0, doc="Value at self.end") start_derivative = param.Number(default=0.0, doc="Derivative at self.start") end_derivative = param.Number(default=0.0, doc="Derivative at self.end") @@ -77,6 +77,7 @@ class BaseTendency(param.Parameterized): ) annotations = param.ClassSelector(class_=Annotations, default=Annotations()) allow_zero_duration = False + is_categorical = False def __init__(self, **kwargs): super().__init__() diff --git a/waveform_editor/tendencies/constant.py b/waveform_editor/tendencies/constant.py index 6b1cd4af..01c8d7fd 100644 --- a/waveform_editor/tendencies/constant.py +++ b/waveform_editor/tendencies/constant.py @@ -10,7 +10,7 @@ class ConstantTendency(BaseTendency): Constant tendency class for a constant signal. """ - user_value = param.Number( + user_value = param.Parameter( default=None, doc="The constant value of the tendency provided by the user.", ) @@ -19,6 +19,13 @@ def __init__(self, **kwargs): self.value = 0.0 super().__init__(**kwargs) + @property + def is_categorical(self): + """Whether this constant holds a non-numeric (categorical) value, e.g. a + string or boolean, that is held as a step rather than interpolated.""" + value = self.value + return isinstance(value, bool) or not isinstance(value, (int, float)) + def get_value( self, time: np.ndarray | None = None ) -> tuple[np.ndarray, np.ndarray]: @@ -33,7 +40,7 @@ def get_value( """ if time is None: time = np.array([self.start, self.end]) - values = self.value * np.ones(len(time)) + values = np.full(len(time), self.value) return time, values def get_derivative(self, time: np.ndarray) -> np.ndarray: diff --git a/waveform_editor/waveform.py b/waveform_editor/waveform.py index 17011175..a7d77373 100644 --- a/waveform_editor/waveform.py +++ b/waveform_editor/waveform.py @@ -49,6 +49,12 @@ def __init__( if waveform is not None: self._process_waveform(waveform) + @property + def is_categorical(self): + """Whether this waveform produces non-numeric (categorical) values, e.g. + strings or booleans, which are held as steps rather than interpolated.""" + return any(t.is_categorical for t in self.tendencies) + def get_value( self, time: np.ndarray | None = None ) -> tuple[np.ndarray, np.ndarray]: @@ -97,7 +103,11 @@ def _evaluate_tendencies(self, time, eval_derivatives=False): Returns: numpy array containing the computed values. """ - values = np.zeros_like(time, dtype=float) + is_categorical = self.is_categorical and not eval_derivatives + if is_categorical: + values = np.empty(len(time), dtype=object) + else: + values = np.zeros_like(time, dtype=float) for i, tendency in enumerate(self.tendencies): mask = (time >= tendency.start) & (time <= tendency.end) @@ -107,17 +117,17 @@ def _evaluate_tendencies(self, time, eval_derivatives=False): else: _, values[mask] = tendency.get_value(time[mask]) - # Handle gaps between tendencies, we linearly interpolate between the - # gap values. + # Handle gaps between tendencies (hold for strings, interpolate otherwise). if i and tendency.prev_tendency.end < tendency.start: prev_tendency = tendency.prev_tendency mask = (time < tendency.start) & (time > prev_tendency.end) - slope = (tendency.start_value - prev_tendency.end_value) / ( - tendency.start - prev_tendency.end - ) if np.any(mask): if eval_derivatives: - values[mask] = slope + values[mask] = ( + tendency.start_value - prev_tendency.end_value + ) / (tendency.start - prev_tendency.end) + elif is_categorical: + values[mask] = prev_tendency.end_value else: values[mask] = np.interp( time[mask], @@ -173,11 +183,29 @@ def _process_waveform(self, waveform): self.tendencies[i - 1].set_next_tendency(self.tendencies[i]) self.tendencies[i].set_previous_tendency(self.tendencies[i - 1]) + self._validate_value_types() + self.update_annotations() for tendency in self.tendencies: tendency.param.watch(self.update_annotations, "annotations") + def _validate_value_types(self): + """Categorical (e.g. string or boolean) and numeric tendencies cannot be + mixed within a single waveform, as the gaps between them cannot be + interpolated. Flag the first tendency whose type breaks the pattern.""" + if not self.tendencies: + return + first_is_categorical = self.tendencies[0].is_categorical + for tendency in self.tendencies[1:]: + if tendency.is_categorical != first_is_categorical: + error_msg = ( + "Cannot mix categorical (e.g. string or boolean) and numeric " + "values within a single waveform.\n" + ) + self.annotations.add(tendency.line_number, error_msg) + break + def update_annotations(self, event=None): """Merges the annotations of the individual tendencies into the annotations of this waveform."""