diff --git a/CHANGELOG.md b/CHANGELOG.md index fb3ee355f06b..5ab687618bf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [7.0.6] + +[7.0.6]: https://github.com/microsoft/CCF/releases/tag/ccf-7.0.6 + +### Fixed + +- Forwarded commands are no longer processed until the node is part of the network, matching the existing behaviour for other node-to-node messages. Previously a forwarded command could be executed while the node was in an earlier startup state, which could lead to undefined behaviour for some commands (#7936). + ## [7.0.5] [7.0.5]: https://github.com/microsoft/CCF/releases/tag/ccf-7.0.5 diff --git a/CMakeLists.txt b/CMakeLists.txt index d437381db79f..9f6482570118 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -591,6 +591,11 @@ if(BUILD_TESTS) ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/parse_json_safe.cpp ) + add_unit_test( + state_machine_test + ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/state_machine.cpp + ) + add_unit_test( logger_test ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/logger.cpp @@ -780,6 +785,11 @@ if(BUILD_TESTS) PRIVATE ccf_kv ccf_endpoints ccf_tasks ) + add_unit_test( + node_inbound_message_test + ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/node_inbound_message.cpp + ) + add_unit_test( indexing_test ${CMAKE_CURRENT_SOURCE_DIR}/src/indexing/test/indexing.cpp diff --git a/python/pyproject.toml b/python/pyproject.toml index e8809e787664..ebcac209cc44 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ccf" -version = "7.0.5" +version = "7.0.6" authors = [ { name="CCF Team", email="CCF-Sec@microsoft.com" }, ] diff --git a/src/ds/state_machine.h b/src/ds/state_machine.h index 68c2328fc1cf..0eb4c194f029 100644 --- a/src/ds/state_machine.h +++ b/src/ds/state_machine.h @@ -5,6 +5,8 @@ #include "ds/internal_logger.h" #include +#include +#include #include #include @@ -37,6 +39,11 @@ namespace ds return state_ == state.load(); } + [[nodiscard]] bool check_one_of(const std::set& states) const + { + return states.contains(state.load()); + } + [[nodiscard]] T value() const { return state.load(); diff --git a/src/ds/test/state_machine.cpp b/src/ds/test/state_machine.cpp new file mode 100644 index 000000000000..2426e219233e --- /dev/null +++ b/src/ds/test/state_machine.cpp @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "../state_machine.h" + +#include + +namespace +{ + enum class Example + { + A, + B, + C, + D + }; +} + +template <> +struct fmt::formatter : fmt::formatter +{ + template + auto format(Example e, FormatContext& ctx) const + { + std::string_view name = "unknown"; + switch (e) + { + case Example::A: + name = "A"; + break; + case Example::B: + name = "B"; + break; + case Example::C: + name = "C"; + break; + case Example::D: + name = "D"; + break; + } + return fmt::formatter::format(name, ctx); + } +}; + +TEST_CASE("Basic state machine" * doctest::test_suite("state_machine")) +{ + ds::StateMachine sm("example", Example::A); + + REQUIRE(sm.value() == Example::A); + REQUIRE(sm.check(Example::A)); + REQUIRE_FALSE(sm.check(Example::B)); + REQUIRE_NOTHROW(sm.expect(Example::A)); + REQUIRE_THROWS_AS(sm.expect(Example::B), std::logic_error); + + sm.advance(Example::B); + + REQUIRE(sm.value() == Example::B); + REQUIRE(sm.check(Example::B)); + REQUIRE_FALSE(sm.check(Example::A)); + REQUIRE_NOTHROW(sm.expect(Example::B)); + REQUIRE_THROWS_AS(sm.expect(Example::A), std::logic_error); +} + +TEST_CASE("check_one_of" * doctest::test_suite("state_machine")) +{ + ds::StateMachine sm("example", Example::A); + + { + INFO("Current state is part of the set"); + REQUIRE(sm.check_one_of({Example::A})); + REQUIRE(sm.check_one_of({Example::A, Example::B})); + REQUIRE(sm.check_one_of({Example::A, Example::B, Example::C})); + } + + { + INFO("Current state is not part of the set"); + REQUIRE_FALSE(sm.check_one_of({Example::B})); + REQUIRE_FALSE(sm.check_one_of({Example::B, Example::C})); + REQUIRE_FALSE(sm.check_one_of({Example::B, Example::C, Example::D})); + } + + { + INFO("Empty set never matches"); + REQUIRE_FALSE(sm.check_one_of({})); + } + + { + INFO("Set membership follows state transitions"); + const std::set states{Example::B, Example::C}; + REQUIRE_FALSE(sm.check_one_of(states)); + + sm.advance(Example::B); + REQUIRE(sm.check_one_of(states)); + + sm.advance(Example::C); + REQUIRE(sm.check_one_of(states)); + + sm.advance(Example::D); + REQUIRE_FALSE(sm.check_one_of(states)); + } +} diff --git a/src/node/node_inbound_message.h b/src/node/node_inbound_message.h new file mode 100644 index 000000000000..bc4278c39553 --- /dev/null +++ b/src/node/node_inbound_message.h @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "ccf/entity_id.h" +#include "ccf/node_startup_state.h" +#include "ds/internal_logger.h" +#include "ds/state_machine.h" +#include "node/node_types.h" + +namespace ccf +{ + // Reads a serialised node_inbound message and dispatches it to the + // appropriate handler, but only once the node is part of the network. This + // includes forwarded commands, which must not be executed until the node is + // part of the network, as some commands may otherwise exhibit undefined + // behaviour. + template + void recv_node_inbound_message( + const uint8_t* data, + size_t size, + ::ds::StateMachine& sm, + TForwarder* cmd_forwarder, + TChannels* n2n_channels, + TConsensus* consensus) + { + auto [msg_type, from, payload] = + ringbuffer::read_message(data, size); + + const auto* payload_data = payload.data; + auto payload_size = payload.size; + + static const std::set active_states{ + NodeStartupState::partOfNetwork, + NodeStartupState::partOfPublicNetwork, + NodeStartupState::readingPrivateLedger}; + + if (!sm.check_one_of(active_states)) + { + LOG_DEBUG_FMT( + "Ignoring node msg received too early - current state is {}", + sm.value()); + return; + } + + switch (msg_type) + { + case forwarded_msg: + { + if (cmd_forwarder == nullptr) + { + LOG_FAIL_FMT( + "Ignoring forwarded node message: command forwarder not " + "initialised"); + return; + } + cmd_forwarder->recv_message(from, payload_data, payload_size); + return; + } + case channel_msg: + { + if (n2n_channels == nullptr) + { + LOG_FAIL_FMT( + "Ignoring channel node message: node-to-node channels not " + "initialised"); + return; + } + n2n_channels->recv_channel_message(from, payload_data, payload_size); + return; + } + case consensus_msg: + { + if (consensus == nullptr) + { + LOG_FAIL_FMT( + "Ignoring consensus node message: consensus not initialised"); + return; + } + consensus->recv_message(from, payload_data, payload_size); + return; + } + default: + { + throw std::logic_error(fmt::format( + "Unknown node message type: {}", static_cast(msg_type))); + } + } + } +} diff --git a/src/node/node_state.h b/src/node/node_state.h index 23a316e7e490..53d158831b8d 100644 --- a/src/node/node_state.h +++ b/src/node/node_state.h @@ -41,6 +41,7 @@ #include "node/ledger_secret.h" #include "node/ledger_secrets.h" #include "node/local_sealing.h" +#include "node/node_inbound_message.h" #include "node/node_to_node_channel_manager.h" #include "node/recovery_decision_protocol.h" #include "node/signature_cache_subsystem.h" @@ -2254,57 +2255,13 @@ namespace ccf void recv_node_inbound(const uint8_t* data, size_t size) { - auto [msg_type, from, payload] = - ringbuffer::read_message(data, size); - - const auto* payload_data = payload.data; - auto payload_size = payload.size; - - if (msg_type == NodeMsgType::forwarded_msg) - { - cmd_forwarder->recv_message(from, payload_data, payload_size); - } - else - { - // Only process messages once part of network - if ( - !sm.check(NodeStartupState::partOfNetwork) && - !sm.check(NodeStartupState::partOfPublicNetwork) && - !sm.check(NodeStartupState::readingPrivateLedger)) - { - LOG_DEBUG_FMT( - "Ignoring node msg received too early - current state is {}", - sm.value()); - return; - } - - switch (msg_type) - { - case forwarded_msg: - { - LOG_FAIL_FMT("Unexpected forwarded_msg in recv_node_inbound"); - return; - } - case channel_msg: - { - n2n_channels->recv_channel_message( - from, payload_data, payload_size); - return; - } - - case consensus_msg: - { - consensus->recv_message(from, payload_data, payload_size); - return; - } - default: - { - throw std::logic_error(fmt::format( - "Unknown node message type: {}", - static_cast(msg_type))); - } - } - } + recv_node_inbound_message( + data, + size, + sm, + cmd_forwarder.get(), + n2n_channels.get(), + consensus.get()); } // diff --git a/src/node/test/node_inbound_message.cpp b/src/node/test/node_inbound_message.cpp new file mode 100644 index 000000000000..8198d845fbec --- /dev/null +++ b/src/node/test/node_inbound_message.cpp @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "../node_inbound_message.h" + +#include "ccf/ds/nonstd.h" + +#include +#include +#include + +namespace +{ + // Records the most recent message passed to it, so a test can assert whether + // (and with what arguments) a particular handler was invoked. + struct StubHandler + { + std::optional last_from; + std::vector last_payload; + size_t call_count = 0; + + void recv_message(const ccf::NodeId& from, const uint8_t* data, size_t size) + { + last_from = from; + last_payload.assign(data, data + size); + ++call_count; + } + + void recv_channel_message( + const ccf::NodeId& from, const uint8_t* data, size_t size) + { + recv_message(from, data, size); + } + }; + + std::vector serialise_node_inbound( + ccf::NodeMsgType msg_type, + const ccf::NodeId& from, + const std::vector& payload) + { + auto sections = + ringbuffer::MessageSerializers::serialize( + msg_type, + from.value(), + serializer::ByteRange{payload.data(), payload.size()}); + + std::vector result; + ccf::nonstd::tuple_for_each(sections, [&](const auto& s) { + result.insert(result.end(), s->data(), s->data() + s->size()); + }); + return result; + } +} + +TEST_CASE( + "recv_node_inbound_message gating" * + doctest::test_suite("node_inbound_message")) +{ + const ccf::NodeId from("0123456789abcdef"); + const std::vector payload{1, 2, 3, 4, 5}; + + const auto early_states = { + ccf::NodeStartupState::uninitialized, + ccf::NodeStartupState::initialized, + ccf::NodeStartupState::pending, + ccf::NodeStartupState::readingPublicLedger}; + + const auto active_states = { + ccf::NodeStartupState::partOfNetwork, + ccf::NodeStartupState::partOfPublicNetwork, + ccf::NodeStartupState::readingPrivateLedger}; + + SUBCASE("Forwarded commands are not processed before part of network") + { + const auto serialised = + serialise_node_inbound(ccf::forwarded_msg, from, payload); + + for (const auto state : early_states) + { + INFO("Early state: ", state); + ds::StateMachine sm("test", state); + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), + serialised.size(), + sm, + &forwarder, + &channels, + &consensus); + + REQUIRE(forwarder.call_count == 0); + REQUIRE(channels.call_count == 0); + REQUIRE(consensus.call_count == 0); + } + } + + SUBCASE("Forwarded commands are processed once part of network") + { + const auto serialised = + serialise_node_inbound(ccf::forwarded_msg, from, payload); + + for (const auto state : active_states) + { + INFO("Active state: ", state); + ds::StateMachine sm("test", state); + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), + serialised.size(), + sm, + &forwarder, + &channels, + &consensus); + + REQUIRE(forwarder.call_count == 1); + REQUIRE(forwarder.last_from == from); + REQUIRE(forwarder.last_payload == payload); + REQUIRE(channels.call_count == 0); + REQUIRE(consensus.call_count == 0); + } + } + + SUBCASE("Channel messages are gated and dispatched identically") + { + const auto serialised = + serialise_node_inbound(ccf::channel_msg, from, payload); + + { + INFO("Dropped before part of network"); + ds::StateMachine sm( + "test", ccf::NodeStartupState::pending); + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), + serialised.size(), + sm, + &forwarder, + &channels, + &consensus); + + REQUIRE(channels.call_count == 0); + } + + { + INFO("Dispatched once part of network"); + ds::StateMachine sm( + "test", ccf::NodeStartupState::partOfNetwork); + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), + serialised.size(), + sm, + &forwarder, + &channels, + &consensus); + + REQUIRE(channels.call_count == 1); + REQUIRE(channels.last_from == from); + REQUIRE(channels.last_payload == payload); + REQUIRE(forwarder.call_count == 0); + REQUIRE(consensus.call_count == 0); + } + } + + SUBCASE("Consensus messages are gated and dispatched identically") + { + const auto serialised = + serialise_node_inbound(ccf::consensus_msg, from, payload); + + { + INFO("Dropped before part of network"); + ds::StateMachine sm( + "test", ccf::NodeStartupState::initialized); + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), + serialised.size(), + sm, + &forwarder, + &channels, + &consensus); + + REQUIRE(consensus.call_count == 0); + } + + { + INFO("Dispatched once part of network"); + ds::StateMachine sm( + "test", ccf::NodeStartupState::readingPrivateLedger); + StubHandler forwarder; + StubHandler channels; + StubHandler consensus; + + ccf::recv_node_inbound_message( + serialised.data(), + serialised.size(), + sm, + &forwarder, + &channels, + &consensus); + + REQUIRE(consensus.call_count == 1); + REQUIRE(consensus.last_from == from); + REQUIRE(consensus.last_payload == payload); + REQUIRE(forwarder.call_count == 0); + REQUIRE(channels.call_count == 0); + } + } +}