diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index 00c09c4c0..cd0de775d 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -98,6 +98,9 @@ jobs: - python-version: '3.12' test-type: mypy jax-version: 'newest' + - python-version: '3.12' + test-type: pyrefly + jax-version: 'newest' steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Setup uv @@ -128,6 +131,8 @@ jobs: uv run --no-sync tests/run_all_tests.sh --only-pytype elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then uv run --no-sync tests/run_all_tests.sh --only-mypy + elif [[ "${{ matrix.test-type }}" == "pyrefly" ]]; then + uv run --no-sync tests/run_all_tests.sh --only-pyrefly else echo "Unknown test type: ${{ matrix.test-type }}" exit 1 diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 27ace4869..547a2749b 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -189,6 +189,7 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): inputs, kernel, bias = self.promote_dtype( inputs, kernel, bias, dtype=self.dtype ) + assert inputs is not None and kernel is not None if self.dot_general_cls is not None: dot_general = self.dot_general_cls() diff --git a/flax/traceback_util.py b/flax/traceback_util.py index 7b33e66cc..164d37c04 100644 --- a/flax/traceback_util.py +++ b/flax/traceback_util.py @@ -25,7 +25,7 @@ # Whether to filter flax frames from traceback. _flax_filter_tracebacks = config.flax_filter_frames # Flax specific set of paths to exclude from tracebacks. -_flax_exclusions = set() +_flax_exclusions = set() # type: ignore[var-annotated] # re-import JAX symbol for convenience. diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 17ac93355..b764e9ad1 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -1163,7 +1163,7 @@ def read_chunk(i): else: checkpoint_contents = fp.read() - state_dict = serialization.msgpack_restore(checkpoint_contents) + state_dict = serialization.msgpack_restore(bytes(checkpoint_contents)) state_dict = _restore_mpas( state_dict, target, diff --git a/pyproject.toml b/pyproject.toml index aefae59d4..67a4477a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ testing = [ "jraph>=0.0.6dev0", "ml-collections", "mypy", + "pyrefly", "opencv-python", # Set protobuf version to prevent error in # examples/mnist/train_test.py::TrainTest::test_train_and_evaluate diff --git a/pyrefly.toml b/pyrefly.toml new file mode 100644 index 000000000..55ec5f915 --- /dev/null +++ b/pyrefly.toml @@ -0,0 +1,26 @@ +# Pyrefly configuration - migrated from mypy +# Only type-check flax/linen/linear.py for now; expand as issues are resolved. +project-includes = ["flax/linen/linear.py"] + +preset = "legacy" +ignore-missing-imports = [ + "tensorflow.*", + "tensorboard.*", + "absl.*", + "jax.*", + "rich.*", + "jaxlib.cuda.*", + "jaxlib.cpu.*", + "msgpack", + "numpy.*", + "optax.*", + "orbax.*", + "opt_einsum.*", + "scipy.*", + "libtpu.*", + "jaxlib.mlir.*", + "yaml", +] + +[errors] +missing-attribute = "ignore" diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 9a5e45496..3141b4ea4 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -5,6 +5,7 @@ PYTEST_OPTS= RUN_DOCTEST=false RUN_MYPY=false RUN_PYTEST=false +RUN_PYREFLY=false RUN_PYTYPE=false GH_VENV=false @@ -30,6 +31,9 @@ case $flag in --only-mypy) RUN_MYPY=true ;; + --only-pyrefly) + RUN_PYREFLY=true + ;; --use-venv) GH_VENV=true ;; @@ -40,12 +44,13 @@ case $flag in esac done -# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy is set, run all tests -if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY; then +# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy, --only-pyrefly is set, run all tests +if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY && ! $RUN_PYREFLY; then RUN_DOCTEST=true RUN_PYTEST=true RUN_PYTYPE=true RUN_MYPY=true + RUN_PYREFLY=true fi # Activate cached virtual env for github CI @@ -58,6 +63,7 @@ echo "PYTEST_OPTS: $PYTEST_OPTS" echo "RUN_DOCTEST: $RUN_DOCTEST" echo "RUN_PYTEST: $RUN_PYTEST" echo "RUN_MYPY: $RUN_MYPY" +echo "RUN_PYREFLY: $RUN_PYREFLY" echo "RUN_PYTYPE: $RUN_PYTYPE" echo "GH_VENV: $GH_VENV" echo "WHICH PYTHON: $(which python)" @@ -155,5 +161,11 @@ if $RUN_MYPY; then mypy --config pyproject.toml flax/ --show-error-codes fi +if $RUN_PYREFLY; then + echo "=== RUNNING PYREFLY ===" + # Type-check using pyrefly.toml (currently scoped to flax/linen/linear.py). + pyrefly check +fi + # Return error code 0 if no real failures happened. echo "finished all tests."