Skip to content

Fix GraphUpdate from_config for nested serialized layers (improves Keras / tfmot compatibility)#923

Open
ur-miya wants to merge 1 commit into
tensorflow:mainfrom
ur-miya:fix-gnn-from-config-robustness
Open

Fix GraphUpdate from_config for nested serialized layers (improves Keras / tfmot compatibility)#923
ur-miya wants to merge 1 commit into
tensorflow:mainfrom
ur-miya:fix-gnn-from-config-robustness

Conversation

@ur-miya
Copy link
Copy Markdown

@ur-miya ur-miya commented Jun 3, 2026

Summary

This PR makes several TF‑GNN Keras layers (GraphUpdate, EdgeSetUpdate, NodeSetUpdate, ContextUpdate) more robust when reconstructed from serialized configs, especially in setups where tooling such as tfmot.quantization.keras serializes nested layers as config dicts instead of already-instantiated Layer objects.

The constructors and _check_is_layer(...) contract are unchanged; only from_config(...) paths are made tolerant to seeing serialized sub‑layer configs and normalize them back to tf.keras.layers.Layer instances before calling __init__.


Problem

In some environments (e.g. TF 2.16 + tf-keras + tfmot QAT), the following pattern can occur:

  • A model containing TF‑GNN layers (e.g. GraphUpdate) is cloned/serialized via Keras (clone_model, model.to_json, etc.).
  • Nested sub‑layers (e.g. EdgeSetUpdate, NodeSetUpdate.next_state, context/node/edge inputs) are stored in the config as dicts of the form {class_name, config, ...} rather than as Layer instances.
  • Later, Keras calls GraphUpdate.from_config(config) / EdgeSetUpdate.from_config(config). The corresponding __init__ implementations still expect actual Layer objects and call _check_is_layer(...), leading to errors like:

GraphUpdate(edge_sets={edge: ...}) must be a tf.keras.layer.Layer, got type: dict
EdgeSetUpdate(next_state=...) must be a tf.keras.layer.Layer, got type: dict.

This makes it difficult to combine TF‑GNN models with tooling that relies on cloning and graph transformations via Keras configs (e.g. QAT with tfmot).


Approach

The idea is to leave _check_is_layer(...) and the type guarantees in the constructors unchanged, and instead make from_config(...) smart enough to recognize serialized Keras layers and deserialize them back into Layer instances.

  1. Introduce a small helper:
def _maybe_deserialize_layer(obj):
  if isinstance(obj, tf.keras.layers.Layer):
    return obj
  if isinstance(obj, dict) and "class_name" in obj and "config" in obj:
    return tf.keras.layers.deserialize(obj)
  return obj
  1. Update the from_config(...) methods:
  • GraphUpdate

    • Keep the existing du.pop_by_prefix logic:

      config["edge_sets"] = du.pop_by_prefix(config, "edge_sets/")
      config["node_sets"] = du.pop_by_prefix(config, "node_sets/")
    • After that, run _maybe_deserialize_layer on:

      • each value in config["edge_sets"] and config["node_sets"];
      • config["context"] (if present).
  • EdgeSetUpdate

    • Add a custom from_config that:
      • copies the config;
      • applies _maybe_deserialize_layer to config["next_state"] (if present);
      • calls cls(**config).
  • NodeSetUpdate

    • Extend from_config to:
      • call du.pop_by_prefix(config, "edge_set_inputs/") as before;
      • run _maybe_deserialize_layer on each edge_set_inputs[...];
      • run _maybe_deserialize_layer on next_state (if present).
  • ContextUpdate

    • Extend from_config to:
      • call du.pop_by_prefix on node_set_inputs/* and edge_set_inputs/*;
      • run _maybe_deserialize_layer on each node_set_inputs[...] and edge_set_inputs[...];
      • run _maybe_deserialize_layer on next_state (if present).

In all cases, _check_is_layer(...) in the constructors is unchanged and will still reject non‑Layer values that could not be normalized.


Rationale / why this is safe

  • We do not change runtime call semantics or relax any type checks. All constructors still require actual tf.keras.layers.Layer instances and _check_is_layer(...) continues to enforce that.
  • Only from_config(...) code paths are touched, i.e. reconstruction from serialized configs (load_model, clone_model, etc.), not ordinary model building.
  • _maybe_deserialize_layer(...) is conservative:
    • if obj is already a Layer, it is returned as‑is;
    • if obj is a dict with class_name and config, it matches the standard Keras serialization format and is deserialized via tf.keras.layers.deserialize;
    • in all other cases obj is left unchanged and _check_is_layer(...) will still raise if it is not a Layer.

The pattern is applied consistently across GraphUpdate, EdgeSetUpdate, NodeSetUpdate and ContextUpdate, which makes future maintenance and reasoning about serialization behavior easier.


Tests

  • New tests in tensorflow_gnn/keras/layers/graph_update_test.py:

    • GraphUpdateSerializationTest.test_graph_update_from_config_deserializes_nested_layers

      • Constructs a GraphUpdate with nested EdgeSetUpdate / NodeSetUpdate using TF‑GNN layers (including Pool and NextStateFromConcat), calls get_config() and GraphUpdate.from_config(config), and asserts the rebuilt instance is a GraphUpdate whose get_config() still contains keys like "edge_sets/edge" and "node_sets/node".
    • GraphUpdateSerializationTest.test_edge_set_update_from_config_deserializes_next_state

      • Verifies that EdgeSetUpdate.from_config(...) correctly restores next_state and that get_config() on the rebuilt layer includes "next_state".
    • GraphUpdateSerializationTest.test_node_set_update_from_config_deserializes_nested_layers

      • Verifies that NodeSetUpdate.from_config(...) correctly restores edge_set_inputs and next_state, and that get_config() on the rebuilt layer includes "edge_set_inputs/edge" and "next_state".
  • Existing tests in graph_update_test.py continue to pass.

  • All tests were run under a fresh virtualenv with TensorFlow 2.16+, tensorflow-gnn 1.0.3 and tf-keras, with:

    TF_USE_LEGACY_KERAS=1 python -m tensorflow_gnn.keras.layers.graph_update_test

@google-cla
Copy link
Copy Markdown

google-cla Bot commented Jun 3, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant