diff --git a/src/node/node_inbound_message.h b/src/node/node_inbound_message.h index bc4278c3955..273501d9c10 100644 --- a/src/node/node_inbound_message.h +++ b/src/node/node_inbound_message.h @@ -8,18 +8,28 @@ #include "ds/state_machine.h" #include "node/node_types.h" +#include + namespace ccf { + inline bool can_process_node_inbound_message( + ::ds::StateMachine& sm) + { + static const std::set active_states{ + NodeStartupState::partOfNetwork, + NodeStartupState::partOfPublicNetwork, + NodeStartupState::readingPrivateLedger}; + + return sm.check_one_of(active_states); + } + // Reads a serialised node_inbound message and dispatches it to the - // appropriate handler, but only once the node is part of the network. This - // includes forwarded commands, which must not be executed until the node is - // part of the network, as some commands may otherwise exhibit undefined - // behaviour. + // appropriate handler. Callers must check can_process_node_inbound_message() + // before calling this. template void recv_node_inbound_message( const uint8_t* data, size_t size, - ::ds::StateMachine& sm, TForwarder* cmd_forwarder, TChannels* n2n_channels, TConsensus* consensus) @@ -30,19 +40,6 @@ namespace ccf const auto* payload_data = payload.data; auto payload_size = payload.size; - static const std::set active_states{ - NodeStartupState::partOfNetwork, - NodeStartupState::partOfPublicNetwork, - NodeStartupState::readingPrivateLedger}; - - if (!sm.check_one_of(active_states)) - { - LOG_DEBUG_FMT( - "Ignoring node msg received too early - current state is {}", - sm.value()); - return; - } - switch (msg_type) { case forwarded_msg: diff --git a/src/node/node_state.h b/src/node/node_state.h index 53d158831b8..8b4d01105b4 100644 --- a/src/node/node_state.h +++ b/src/node/node_state.h @@ -2255,13 +2255,16 @@ namespace ccf void recv_node_inbound(const uint8_t* data, size_t size) { + if (!can_process_node_inbound_message(sm)) + { + LOG_DEBUG_FMT( + "Ignoring node msg received too early - current state is {}", + sm.value()); + return; + } + recv_node_inbound_message( - data, - size, - sm, - cmd_forwarder.get(), - n2n_channels.get(), - consensus.get()); + data, size, cmd_forwarder.get(), n2n_channels.get(), consensus.get()); } // diff --git a/src/node/test/node_inbound_message.cpp b/src/node/test/node_inbound_message.cpp index 8198d845fbe..fd8c9fe3b62 100644 --- a/src/node/test/node_inbound_message.cpp +++ b/src/node/test/node_inbound_message.cpp @@ -54,7 +54,7 @@ namespace } TEST_CASE( - "recv_node_inbound_message gating" * + "node inbound message gate and dispatch" * doctest::test_suite("node_inbound_message")) { const ccf::NodeId from("0123456789abcdef"); @@ -71,155 +71,80 @@ TEST_CASE( ccf::NodeStartupState::partOfPublicNetwork, ccf::NodeStartupState::readingPrivateLedger}; - SUBCASE("Forwarded commands are not processed before part of network") + SUBCASE("Caller gate rejects early states") { - const auto serialised = - serialise_node_inbound(ccf::forwarded_msg, from, payload); - for (const auto state : early_states) { INFO("Early state: ", state); ds::StateMachine sm("test", state); - StubHandler forwarder; - StubHandler channels; - StubHandler consensus; - - ccf::recv_node_inbound_message( - serialised.data(), - serialised.size(), - sm, - &forwarder, - &channels, - &consensus); - - REQUIRE(forwarder.call_count == 0); - REQUIRE(channels.call_count == 0); - REQUIRE(consensus.call_count == 0); + REQUIRE(!ccf::can_process_node_inbound_message(sm)); } } - SUBCASE("Forwarded commands are processed once part of network") + SUBCASE("Caller gate accepts active states") { - const auto serialised = - serialise_node_inbound(ccf::forwarded_msg, from, payload); - for (const auto state : active_states) { INFO("Active state: ", state); ds::StateMachine sm("test", state); - StubHandler forwarder; - StubHandler channels; - StubHandler consensus; - - ccf::recv_node_inbound_message( - serialised.data(), - serialised.size(), - sm, - &forwarder, - &channels, - &consensus); - - REQUIRE(forwarder.call_count == 1); - REQUIRE(forwarder.last_from == from); - REQUIRE(forwarder.last_payload == payload); - REQUIRE(channels.call_count == 0); - REQUIRE(consensus.call_count == 0); + REQUIRE(ccf::can_process_node_inbound_message(sm)); } } - SUBCASE("Channel messages are gated and dispatched identically") + SUBCASE("Forwarded commands dispatch to command forwarder") + { + const auto serialised = + serialise_node_inbound(ccf::forwarded_msg, from, payload); + + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), serialised.size(), &forwarder, &channels, &consensus); + + REQUIRE(forwarder.call_count == 1); + REQUIRE(forwarder.last_from == from); + REQUIRE(forwarder.last_payload == payload); + REQUIRE(channels.call_count == 0); + REQUIRE(consensus.call_count == 0); + } + + SUBCASE("Channel messages dispatch to node-to-node channels") { const auto serialised = serialise_node_inbound(ccf::channel_msg, from, payload); - { - INFO("Dropped before part of network"); - ds::StateMachine sm( - "test", ccf::NodeStartupState::pending); - StubHandler forwarder; - StubHandler channels; - StubHandler consensus; - - ccf::recv_node_inbound_message( - serialised.data(), - serialised.size(), - sm, - &forwarder, - &channels, - &consensus); - - REQUIRE(channels.call_count == 0); - } + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; - { - INFO("Dispatched once part of network"); - ds::StateMachine sm( - "test", ccf::NodeStartupState::partOfNetwork); - StubHandler forwarder; - StubHandler channels; - StubHandler consensus; - - ccf::recv_node_inbound_message( - serialised.data(), - serialised.size(), - sm, - &forwarder, - &channels, - &consensus); - - REQUIRE(channels.call_count == 1); - REQUIRE(channels.last_from == from); - REQUIRE(channels.last_payload == payload); - REQUIRE(forwarder.call_count == 0); - REQUIRE(consensus.call_count == 0); - } + ccf::recv_node_inbound_message( + serialised.data(), serialised.size(), &forwarder, &channels, &consensus); + + REQUIRE(channels.call_count == 1); + REQUIRE(channels.last_from == from); + REQUIRE(channels.last_payload == payload); + REQUIRE(forwarder.call_count == 0); + REQUIRE(consensus.call_count == 0); } - SUBCASE("Consensus messages are gated and dispatched identically") + SUBCASE("Consensus messages dispatch to consensus") { const auto serialised = serialise_node_inbound(ccf::consensus_msg, from, payload); - { - INFO("Dropped before part of network"); - ds::StateMachine sm( - "test", ccf::NodeStartupState::initialized); - StubHandler forwarder; - StubHandler channels; - StubHandler consensus; - - ccf::recv_node_inbound_message( - serialised.data(), - serialised.size(), - sm, - &forwarder, - &channels, - &consensus); - - REQUIRE(consensus.call_count == 0); - } + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; - { - INFO("Dispatched once part of network"); - ds::StateMachine sm( - "test", ccf::NodeStartupState::readingPrivateLedger); - StubHandler forwarder; - StubHandler channels; - StubHandler consensus; - - ccf::recv_node_inbound_message( - serialised.data(), - serialised.size(), - sm, - &forwarder, - &channels, - &consensus); - - REQUIRE(consensus.call_count == 1); - REQUIRE(consensus.last_from == from); - REQUIRE(consensus.last_payload == payload); - REQUIRE(forwarder.call_count == 0); - REQUIRE(channels.call_count == 0); - } + ccf::recv_node_inbound_message( + serialised.data(), serialised.size(), &forwarder, &channels, &consensus); + + REQUIRE(consensus.call_count == 1); + REQUIRE(consensus.last_from == from); + REQUIRE(consensus.last_payload == payload); + REQUIRE(forwarder.call_count == 0); + REQUIRE(channels.call_count == 0); } }