diff --git a/docs/async.md b/docs/async.md index 3680d7ac..313d03f8 100644 --- a/docs/async.md +++ b/docs/async.md @@ -75,7 +75,7 @@ the initial state: ... print(list(sm.configuration_values)) >>> asyncio.run(show_problem()) -[None] +[] ``` diff --git a/statemachine/configuration.py b/statemachine/configuration.py index 8c8d5d5f..c17123aa 100644 --- a/statemachine/configuration.py +++ b/statemachine/configuration.py @@ -55,75 +55,53 @@ def value(self) -> Any: @value.setter def value(self, val: Any): - self._invalidate() - if val is not None and not isinstance(val, MutableSet) and val not in self._states_map: - raise InvalidStateValue(val) - setattr(self._model, self._state_field, val) + if val is None: + self._write_to_model(OrderedSet()) + elif isinstance(val, MutableSet): + self._write_to_model(OrderedSet(val) if not isinstance(val, OrderedSet) else val) + else: + self._write_to_model(OrderedSet([val])) @property def values(self) -> OrderedSet[Any]: """The set of raw state values currently active.""" - v = self.value - if isinstance(v, OrderedSet): - return v - return OrderedSet([v]) + return self._read_from_model() # -- Resolved states ------------------------------------------------------- @property def states(self) -> "OrderedSet[State]": """The set of currently active :class:`State` instances (cached).""" - csv = self.value - if self._cached is not None and self._cached_value is csv: + raw = self.value + if self._cached is not None and self._cached_value is raw: return self._cached - if csv is None: + if raw is None: return OrderedSet() - instance_states = self._instance_states - if not isinstance(csv, MutableSet): - result = OrderedSet([instance_states[self._states_map[csv].id]]) - else: - result = OrderedSet([instance_states[self._states_map[v].id] for v in csv]) - + # Normalize inline (avoid second getattr via _read_from_model) + values = raw if isinstance(raw, MutableSet) else (raw,) + result = OrderedSet(self._instance_states[self._states_map[v].id] for v in values) self._cached = result - self._cached_value = csv + self._cached_value = raw return result @states.setter def states(self, new_configuration: "OrderedSet[State]"): - if len(new_configuration) == 0: - self.value = None - elif len(new_configuration) == 1: - self.value = next(iter(new_configuration)).value - else: - self.value = OrderedSet(s.value for s in new_configuration) + self._write_to_model(OrderedSet(s.value for s in new_configuration)) # -- Incremental mutation (used by the engine) ----------------------------- def add(self, state: "State"): - """Add *state* to the configuration, maintaining the dual representation.""" - csv = self.value - if csv is None: - self.value = state.value - elif isinstance(csv, MutableSet): - csv.add(state.value) - self.value = csv - else: - self.value = OrderedSet([csv, state.value]) + """Add *state* to the configuration.""" + values = self._read_from_model() + values.add(state.value) + self._write_to_model(values) def discard(self, state: "State"): - """Remove *state* from the configuration, normalizing back to scalar.""" - csv = self.value - if isinstance(csv, MutableSet): - csv.discard(state.value) - if len(csv) == 0: - self.value = None - elif len(csv) == 1: - self.value = next(iter(csv)) - else: - self.value = csv - elif csv == state.value: - self.value = None + """Remove *state* from the configuration.""" + values = self._read_from_model() + values.discard(state.value) + self._write_to_model(values) # -- Deprecated v2 compat -------------------------------------------------- @@ -153,7 +131,31 @@ def current_state(self) -> "State | OrderedSet[State]": except KeyError as err: raise InvalidStateValue(csv) from err - # -- Internal -------------------------------------------------------------- + # -- Internal: model boundary ---------------------------------------------- + + def _read_from_model(self) -> OrderedSet: + """Normalize: model value → always ``OrderedSet``.""" + raw = self.value + if raw is None: + return OrderedSet() + if isinstance(raw, OrderedSet): + return raw + if isinstance(raw, MutableSet): + return OrderedSet(raw) + return OrderedSet([raw]) + + def _write_to_model(self, values: OrderedSet): + """Denormalize: ``OrderedSet`` → ``None | scalar | OrderedSet`` for model.""" + self._invalidate() + if len(values) == 0: + raw = None + elif len(values) == 1: + raw = next(iter(values)) + else: + raw = values + if raw is not None and not isinstance(raw, MutableSet) and raw not in self._states_map: + raise InvalidStateValue(raw) + setattr(self._model, self._state_field, raw) def _invalidate(self): self._cached = None diff --git a/tests/test_api_contract.py b/tests/test_api_contract.py new file mode 100644 index 00000000..0ac7d2d6 --- /dev/null +++ b/tests/test_api_contract.py @@ -0,0 +1,280 @@ +"""Contract tests: observable behavior of public Configuration APIs. + +Documents the exact values returned by each public API across all supported +topologies (flat, compound, parallel, complex parallel) and lifecycle phases +(initial state, after transitions, final state). + +APIs under test (StateChart): + sm.current_state_value -- raw value stored on the model + sm.configuration_values -- OrderedSet of raw values + sm.configuration -- OrderedSet[State] + sm.current_state -- State or OrderedSet[State] (deprecated) + +API under test (Model): + model.state -- raw attribute on the model object +""" + +import warnings +from typing import Any + +import pytest +from statemachine.orderedset import OrderedSet + +from statemachine import State +from statemachine import StateChart + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class Model: + """Explicit model to verify raw state persistence independently.""" + + def __init__(self): + self.state: Any = None + + +# --------------------------------------------------------------------------- +# Topologies +# --------------------------------------------------------------------------- + + +class FlatSC(StateChart): + s1 = State(initial=True) + s2 = State() + s3 = State(final=True) + + go = s1.to(s2) + finish = s2.to(s3) + + +class CompoundSC(StateChart): + class parent(State.Compound): + child1 = State(initial=True) + child2 = State() + move = child1.to(child2) + + done = State(final=True) + leave = parent.to(done) + + +class ParallelSC(StateChart): + class regions(State.Parallel): + class region_a(State.Compound): + a1 = State(initial=True) + a2 = State() + go_a = a1.to(a2) + + class region_b(State.Compound): + b1 = State(initial=True) + b2 = State() + go_b = b1.to(b2) + + +class ComplexParallelSC(StateChart): + class top(State.Parallel): + class left(State.Compound): + class nested(State.Compound): + l1 = State(initial=True) + l2 = State() + move_l = l1.to(l2) + + left_done = State(final=True) + finish_left = nested.to(left_done) + + class right(State.Compound): + r1 = State(initial=True) + r2 = State() + move_r = r1.to(r2) + + +# --------------------------------------------------------------------------- +# Assertion helper +# --------------------------------------------------------------------------- + + +def assert_contract(sm, model, expected_ids: set): + """Assert the full observable API contract. + + When exactly one state is active, the model stores a scalar and + ``current_state`` returns a single ``State``. When multiple states + are active (compound/parallel), the model stores an ``OrderedSet`` + and ``current_state`` returns ``OrderedSet[State]``. + """ + scalar = len(expected_ids) == 1 + + # model.state and current_state_value point to the same object + assert model.state is sm.current_state_value + + if scalar: + val = next(iter(expected_ids)) + assert model.state == val + assert not isinstance(model.state, OrderedSet) + else: + assert isinstance(model.state, OrderedSet) + assert set(model.state) == expected_ids + + # configuration_values -- always OrderedSet of raw values + assert isinstance(sm.configuration_values, OrderedSet) + assert set(sm.configuration_values) == expected_ids + + # configuration -- always OrderedSet[State] + assert len(sm.configuration) == len(expected_ids) + assert {s.id for s in sm.configuration} == expected_ids + + # current_state (deprecated) -- unwrapped when single + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + cs = sm.current_state + if scalar: + assert not isinstance(cs, OrderedSet) + assert cs.id == next(iter(expected_ids)) + else: + assert isinstance(cs, OrderedSet) + assert {s.id for s in cs} == expected_ids + + +# --------------------------------------------------------------------------- +# Main contract matrix: topology x lifecycle x engine +# --------------------------------------------------------------------------- + + +SCENARIOS = [ + # -- Flat -- + pytest.param(FlatSC, [], {"s1"}, id="flat-initial"), + pytest.param(FlatSC, ["go"], {"s2"}, id="flat-after-go"), + pytest.param(FlatSC, ["go", "finish"], {"s3"}, id="flat-final"), + # -- Compound -- + pytest.param(CompoundSC, [], {"parent", "child1"}, id="compound-initial"), + pytest.param(CompoundSC, ["move"], {"parent", "child2"}, id="compound-inner-move"), + pytest.param(CompoundSC, ["leave"], {"done"}, id="compound-exit"), + # -- Parallel -- + pytest.param( + ParallelSC, + [], + {"regions", "region_a", "a1", "region_b", "b1"}, + id="parallel-initial", + ), + pytest.param( + ParallelSC, + ["go_a"], + {"regions", "region_a", "a2", "region_b", "b1"}, + id="parallel-one-region", + ), + pytest.param( + ParallelSC, + ["go_a", "go_b"], + {"regions", "region_a", "a2", "region_b", "b2"}, + id="parallel-both-regions", + ), + # -- Complex parallel -- + pytest.param( + ComplexParallelSC, + [], + {"top", "left", "nested", "l1", "right", "r1"}, + id="complex-initial", + ), + pytest.param( + ComplexParallelSC, + ["move_l"], + {"top", "left", "nested", "l2", "right", "r1"}, + id="complex-nested-move", + ), + pytest.param( + ComplexParallelSC, + ["move_r"], + {"top", "left", "nested", "l1", "right", "r2"}, + id="complex-other-region", + ), + pytest.param( + ComplexParallelSC, + ["move_l", "move_r"], + {"top", "left", "nested", "l2", "right", "r2"}, + id="complex-both-regions", + ), + pytest.param( + ComplexParallelSC, + ["finish_left"], + {"top", "left", "left_done", "right", "r1"}, + id="complex-exit-nested", + ), +] + + +@pytest.mark.parametrize(("sc_class", "events", "expected_ids"), SCENARIOS) +async def test_configuration_contract(sm_runner, sc_class, events, expected_ids): + model = Model() + sm = await sm_runner.start(sc_class, model=model) + for event in events: + await sm_runner.send(sm, event) + assert_contract(sm, model, expected_ids) + + +# --------------------------------------------------------------------------- +# Model setter contract +# --------------------------------------------------------------------------- + +SETTER_SCENARIOS = [ + pytest.param(FlatSC, "s2", {"s2"}, id="scalar-on-flat"), + pytest.param( + CompoundSC, + OrderedSet(["parent", "child2"]), + {"parent", "child2"}, + id="orderedset-on-compound", + ), + pytest.param(CompoundSC, "done", {"done"}, id="scalar-collapses-orderedset"), +] + + +@pytest.mark.parametrize(("sc_class", "new_value", "expected_ids"), SETTER_SCENARIOS) +async def test_setter_contract(sm_runner, sc_class, new_value, expected_ids): + model = Model() + sm = await sm_runner.start(sc_class, model=model) + sm.current_state_value = new_value + assert_contract(sm, model, expected_ids) + + +async def test_set_none_clears_configuration(sm_runner): + model = Model() + sm = await sm_runner.start(FlatSC, model=model) + + sm.current_state_value = None + + assert model.state is None + assert sm.current_state_value is None + assert sm.configuration_values == OrderedSet() + assert sm.configuration == OrderedSet() + + +# --------------------------------------------------------------------------- +# Uninitialized state (async-only: sync enters initial state in __init__) +# --------------------------------------------------------------------------- + +UNINITIALIZED_SCENARIOS = [ + pytest.param(FlatSC, {"s1"}, id="flat"), + pytest.param(CompoundSC, {"parent", "child1"}, id="compound"), + pytest.param( + ParallelSC, + {"regions", "region_a", "a1", "region_b", "b1"}, + id="parallel", + ), +] + + +@pytest.mark.parametrize(("sc_class", "expected_ids"), UNINITIALIZED_SCENARIOS) +async def test_uninitialized_then_activated(sc_class, expected_ids): + from tests.conftest import _AsyncListener + + model = Model() + sm = sc_class(model=model, listeners=[_AsyncListener()]) + + # Before activation: all APIs reflect empty configuration + assert model.state is None + assert sm.current_state_value is None + assert sm.configuration_values == OrderedSet() + assert sm.configuration == OrderedSet() + + # After activation: full contract holds + await sm.activate_initial_state() + assert_contract(sm, model, expected_ids) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 18961960..07f13ea6 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -41,6 +41,42 @@ def test_set_multi_element_configuration(self): assert sm.current_state_value == OrderedSet([ParallelSM.s1.value, ParallelSM.s2.value]) +class TestConfigurationValueSetter: + def test_set_value_none_writes_none_to_model(self): + sm = ParallelSM() + assert sm.current_state_value is not None + + sm.current_state_value = None + assert sm.current_state_value is None + assert sm.configuration_values == OrderedSet() + + def test_set_value_plain_set_coerces_to_ordered_set(self): + sm = ParallelSM() + s1_val = ParallelSM.s1.value + s2_val = ParallelSM.s2.value + + # Assign a plain set (MutableSet but not OrderedSet) + sm.current_state_value = {s1_val, s2_val} + # Model should store an OrderedSet (denormalized back to it) + assert isinstance(sm.current_state_value, OrderedSet) + assert sm.current_state_value == OrderedSet([s1_val, s2_val]) + + +class TestReadFromModelNonOrderedSet: + def test_read_from_model_coerces_plain_set(self): + """When the model stores a plain set, _read_from_model coerces it.""" + sm = ParallelSM() + s1_val = ParallelSM.s1.value + s2_val = ParallelSM.s2.value + + # Bypass the value setter to place a plain set on the model + setattr(sm._config._model, sm._config._state_field, {s1_val, s2_val}) + + values = sm._config._read_from_model() + assert isinstance(values, OrderedSet) + assert values == OrderedSet([s1_val, s2_val]) + + class TestConfigurationDiscard: def test_discard_nonmatching_scalar(self): sm = ParallelSM()