diff --git a/flax/traceback_util.py b/flax/traceback_util.py index 7b33e66cc..164d37c04 100644 --- a/flax/traceback_util.py +++ b/flax/traceback_util.py @@ -25,7 +25,7 @@ # Whether to filter flax frames from traceback. _flax_filter_tracebacks = config.flax_filter_frames # Flax specific set of paths to exclude from tracebacks. -_flax_exclusions = set() +_flax_exclusions = set() # type: ignore[var-annotated] # re-import JAX symbol for convenience. diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 17ac93355..b764e9ad1 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -1163,7 +1163,7 @@ def read_chunk(i): else: checkpoint_contents = fp.read() - state_dict = serialization.msgpack_restore(checkpoint_contents) + state_dict = serialization.msgpack_restore(bytes(checkpoint_contents)) state_dict = _restore_mpas( state_dict, target,