Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/source/tendencies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ If the ``value`` is not specified, it will be set to the last value of the previ
- {type: linear, to: 3, duration: 10}
- {type: constant, duration: 10}

Value types
-----------

The ``value`` is type-aware and is not restricted to floating-point numbers:

* **Numeric values** (floats and integers) produce a floating-point signal that is
interpolated across gaps between tendencies, as usual. Integers are treated as
floats during evaluation.
* **Categorical values** (strings or booleans) describe non-numeric signals, such
as a heating scheme (``"nbi"``, ``"ec"``) or an on/off flag. These are held as a
step (zero-order hold) rather than interpolated: across a gap, the previous value
is carried forward, and the same applies when extrapolating beyond the waveform.

.. code-block:: yaml

- {type: constant, value: ohmic, duration: 2}
- {type: constant, value: nbi, duration: 2}

.. note::
Categorical and numeric values cannot be mixed within a single waveform, because
the gaps between them cannot be interpolated. Doing so raises a validation error.

Linear Tendency
===============

Expand Down
46 changes: 46 additions & 0 deletions tests/tendencies/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,49 @@ def test_declarative_assignments():
assert t2.value == 6
assert not t1.annotations
assert not t2.annotations


def test_integer_value_not_coerced():
"""Integer inputs are kept as integers, not coerced to float."""
tendency = ConstantTendency(user_duration=1, user_value=5)
assert tendency.value == 5
assert isinstance(tendency.value, int)
assert not tendency.is_categorical
assert not tendency.annotations


def test_string_value():
"""A constant tendency can hold a string (non-numeric) value."""
tendency = ConstantTendency(user_duration=2, user_value="nbi")
assert tendency.value == "nbi"
assert tendency.is_categorical
assert tendency.start_value == "nbi"
assert tendency.end_value == "nbi"

_, values = tendency.get_value(np.array([0.0, 1.0, 2.0]))
assert list(values) == ["nbi", "nbi", "nbi"]
assert not tendency.annotations


def test_boolean_value():
"""A constant tendency can hold a boolean value, kept as a (categorical) bool."""
tendency = ConstantTendency(user_duration=2, user_value=True)
assert tendency.value is True
assert tendency.is_categorical
# start/end values round-trip through numpy, so compare by value (np.bool_)
assert bool(tendency.start_value) is True
assert bool(tendency.end_value) is True

_, values = tendency.get_value(np.array([0.0, 1.0, 2.0]))
assert [bool(v) for v in values] == [True, True, True]
assert values.dtype == bool
assert not tendency.annotations


def test_string_value_chained():
"""A value-less string constant inherits the previous string value."""
prev = ConstantTendency(user_value="ec", user_start=0, user_duration=1)
tendency = ConstantTendency(user_duration=1)
tendency.set_previous_tendency(prev)
assert tendency.value == "ec"
assert tendency.is_categorical
66 changes: 66 additions & 0 deletions tests/test_waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,69 @@ def test_overlap_derivatives():
expected = [2, 2, -1.5, -1.5, -1.5, -1.5, -1.5]
values = waveform.get_derivative(np.linspace(0, 3, 7))
assert np.allclose(values, expected)


def test_string_waveform():
"""A waveform of string constants evaluates as a zero-order-hold step function."""
waveform = Waveform(
waveform=[
{"user_type": "constant", "user_value": "ohmic", "user_duration": 2},
{"user_type": "constant", "user_value": "nbi", "user_duration": 2},
]
)
assert waveform.is_categorical

_, values = waveform.get_value(np.array([0.0, 1.0, 2.0, 3.0]))
# The switch happens at t=2; later tendencies take precedence at the boundary.
assert list(values) == ["ohmic", "ohmic", "nbi", "nbi"]

# Values are held (not interpolated) when extrapolating beyond the domain.
_, extrap = waveform.get_value(np.array([-1.0, 5.0]))
assert list(extrap) == ["ohmic", "nbi"]


def test_boolean_waveform():
"""A waveform of boolean constants evaluates as a zero-order-hold step function."""
waveform = Waveform(
waveform=[
{"user_type": "constant", "user_value": False, "user_duration": 2},
{"user_type": "constant", "user_value": True, "user_duration": 2},
]
)
assert waveform.is_categorical

_, values = waveform.get_value(np.array([0.0, 1.0, 2.0, 3.0]))
assert [bool(v) for v in values] == [False, False, True, True]


def test_numeric_waveform_evaluates_to_float():
"""A numeric waveform evaluates to a float array (the int value is not stepwise)."""
waveform = Waveform(
waveform=[{"user_type": "constant", "user_value": 8, "user_duration": 3}]
)
assert not waveform.is_categorical
_, values = waveform.get_value(np.array([0.0, 1.0, 2.0, 3.0]))
assert values.dtype == float


def test_mixing_categorical_and_numeric_is_flagged():
"""Mixing categorical (string/bool) and numeric tendencies adds an annotation."""
waveform = Waveform(
waveform=[
{"user_type": "constant", "user_value": "nbi", "user_duration": 2},
{"user_type": "constant", "user_value": 3, "user_duration": 2},
]
)
assert waveform.annotations
assert any("mix" in a["text"].lower() for a in waveform.annotations)


def test_homogeneous_categorical_not_flagged():
"""A waveform of only categorical values is not flagged as mixed."""
waveform = Waveform(
waveform=[
{"user_type": "constant", "user_value": "nbi", "user_duration": 2},
{"user_type": "constant", "user_value": "ec", "user_duration": 2},
]
)
assert not waveform.annotations
5 changes: 3 additions & 2 deletions waveform_editor/tendencies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class BaseTendency(param.Parameterized):
values from the start value of this tendency.
""",
)
start_value = param.Number(default=0.0, doc="Value at self.start")
end_value = param.Number(default=0.0, doc="Value at self.end")
start_value = param.Parameter(default=0.0, doc="Value at self.start")
end_value = param.Parameter(default=0.0, doc="Value at self.end")

start_derivative = param.Number(default=0.0, doc="Derivative at self.start")
end_derivative = param.Number(default=0.0, doc="Derivative at self.end")
Expand All @@ -77,6 +77,7 @@ class BaseTendency(param.Parameterized):
)
annotations = param.ClassSelector(class_=Annotations, default=Annotations())
allow_zero_duration = False
is_categorical = False

def __init__(self, **kwargs):
super().__init__()
Expand Down
11 changes: 9 additions & 2 deletions waveform_editor/tendencies/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ConstantTendency(BaseTendency):
Constant tendency class for a constant signal.
"""

user_value = param.Number(
user_value = param.Parameter(
default=None,
doc="The constant value of the tendency provided by the user.",
)
Expand All @@ -19,6 +19,13 @@ def __init__(self, **kwargs):
self.value = 0.0
super().__init__(**kwargs)

@property
def is_categorical(self):
"""Whether this constant holds a non-numeric (categorical) value, e.g. a
string or boolean, that is held as a step rather than interpolated."""
value = self.value
return isinstance(value, bool) or not isinstance(value, (int, float))

def get_value(
Comment thread
DaanVanVugt marked this conversation as resolved.
self, time: np.ndarray | None = None
) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -33,7 +40,7 @@ def get_value(
"""
if time is None:
time = np.array([self.start, self.end])
values = self.value * np.ones(len(time))
values = np.full(len(time), self.value)
return time, values

def get_derivative(self, time: np.ndarray) -> np.ndarray:
Expand Down
42 changes: 35 additions & 7 deletions waveform_editor/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def __init__(
if waveform is not None:
self._process_waveform(waveform)

@property
def is_categorical(self):
"""Whether this waveform produces non-numeric (categorical) values, e.g.
strings or booleans, which are held as steps rather than interpolated."""
return any(t.is_categorical for t in self.tendencies)

def get_value(
self, time: np.ndarray | None = None
) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -97,7 +103,11 @@ def _evaluate_tendencies(self, time, eval_derivatives=False):
Returns:
numpy array containing the computed values.
"""
values = np.zeros_like(time, dtype=float)
is_categorical = self.is_categorical and not eval_derivatives
if is_categorical:
values = np.empty(len(time), dtype=object)
else:
values = np.zeros_like(time, dtype=float)

for i, tendency in enumerate(self.tendencies):
mask = (time >= tendency.start) & (time <= tendency.end)
Expand All @@ -107,17 +117,17 @@ def _evaluate_tendencies(self, time, eval_derivatives=False):
else:
_, values[mask] = tendency.get_value(time[mask])

# Handle gaps between tendencies, we linearly interpolate between the
# gap values.
# Handle gaps between tendencies (hold for strings, interpolate otherwise).
if i and tendency.prev_tendency.end < tendency.start:
prev_tendency = tendency.prev_tendency
mask = (time < tendency.start) & (time > prev_tendency.end)
slope = (tendency.start_value - prev_tendency.end_value) / (
tendency.start - prev_tendency.end
)
if np.any(mask):
if eval_derivatives:
values[mask] = slope
values[mask] = (
tendency.start_value - prev_tendency.end_value
) / (tendency.start - prev_tendency.end)
elif is_categorical:
values[mask] = prev_tendency.end_value
else:
values[mask] = np.interp(
time[mask],
Expand Down Expand Up @@ -173,11 +183,29 @@ def _process_waveform(self, waveform):
self.tendencies[i - 1].set_next_tendency(self.tendencies[i])
self.tendencies[i].set_previous_tendency(self.tendencies[i - 1])

self._validate_value_types()

self.update_annotations()

for tendency in self.tendencies:
tendency.param.watch(self.update_annotations, "annotations")

def _validate_value_types(self):
"""Categorical (e.g. string or boolean) and numeric tendencies cannot be
mixed within a single waveform, as the gaps between them cannot be
interpolated. Flag the first tendency whose type breaks the pattern."""
if not self.tendencies:
return
first_is_categorical = self.tendencies[0].is_categorical
for tendency in self.tendencies[1:]:
if tendency.is_categorical != first_is_categorical:
error_msg = (
"Cannot mix categorical (e.g. string or boolean) and numeric "
"values within a single waveform.\n"
)
self.annotations.add(tendency.line_number, error_msg)
break

def update_annotations(self, event=None):
"""Merges the annotations of the individual tendencies into the annotations
of this waveform."""
Expand Down
Loading