diff --git a/CHANGES/1328.bugfix.rst b/CHANGES/1328.bugfix.rst new file mode 100644 index 000000000..e90957a77 --- /dev/null +++ b/CHANGES/1328.bugfix.rst @@ -0,0 +1,2 @@ +Fixed global counter system using an atomic variable. +-- by :user:`Vizonex`. diff --git a/multidict/_multilib/state.h b/multidict/_multilib/state.h index 4e2610b6c..ec836e29b 100644 --- a/multidict/_multilib/state.h +++ b/multidict/_multilib/state.h @@ -5,6 +5,8 @@ extern "C" { #endif +#include + /* State of the _multidict module */ typedef struct { PyTypeObject *IStrType; @@ -26,7 +28,7 @@ typedef struct { PyObject *str_lower; PyObject *str_name; - uint64_t global_version; + _Atomic uint64_t global_version; } mod_state; static inline mod_state * @@ -128,7 +130,11 @@ get_mod_state_by_def(PyObject *self) static inline uint64_t NEXT_VERSION(mod_state *state) { - return ++state->global_version; + /* relaxed is fine here as we only care about the atomicity of the RMW + * itself */ + return atomic_fetch_add_explicit( + &state->global_version, 1, memory_order_relaxed) + + 1; } #ifdef __cplusplus diff --git a/setup.py b/setup.py index 318229fcc..a5fbb899c 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,8 @@ import os -import platform import sys from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext NO_EXTENSIONS = bool(os.environ.get("MULTIDICT_NO_EXTENSIONS")) DEBUG_BUILD = bool(os.environ.get("MULTIDICT_DEBUG_BUILD")) @@ -10,26 +10,42 @@ if sys.implementation.name != "cpython": NO_EXTENSIONS = True -CFLAGS = ["-O0", "-g3", "-UNDEBUG"] if DEBUG_BUILD else ["-O3", "-DNDEBUG"] - -if platform.system() != "Windows": - CFLAGS.extend( - [ - "-std=c11", - "-Wall", - "-Wsign-compare", - "-Wconversion", - "-fno-strict-aliasing", - "-Wno-conversion", - "-Werror", - ] - ) +BASE_CFLAGS = ["O0", "g3", "UNDEBUG"] if DEBUG_BUILD else ["O3", "DNDEBUG"] + +UNIX_CFLAGS = [ + "-std=c11", + "-Wall", + "-Wsign-compare", + "-Wconversion", + "-fno-strict-aliasing", + "-Wno-conversion", + "-Werror", +] + +MSVC_CFLAGS = ["/std:c11", "/experimental:c11atomics"] + + +class BuildExt(build_ext): + def build_extensions(self): + if self.compiler.compiler_type == "msvc": + for ext in self.extensions: + ext.extra_compile_args.extend(MSVC_CFLAGS) + for flag in BASE_CFLAGS: + # XXX: MSVC Doesn't have a /O3 flag only O2 is possible... + ext.extra_compile_args.append("/O2" if flag == "O3" else f"/{flag}") + else: + for ext in self.extensions: + ext.extra_compile_args.extend(UNIX_CFLAGS) + for flag in BASE_CFLAGS: + ext.extra_compile_args.append(f"-{flag}") + super().build_extensions() + extensions = [ Extension( "multidict._multidict", ["multidict/_multidict.c"], - extra_compile_args=CFLAGS, + extra_compile_args=[], ), ] @@ -38,7 +54,7 @@ print("*********************") print("* Accelerated build *") print("*********************") - setup(ext_modules=extensions) + setup(ext_modules=extensions, cmdclass={"build_ext": BuildExt}) else: print("*********************") print("* Pure Python build *") diff --git a/tests/isolated/multidict_global_counter.py b/tests/isolated/multidict_global_counter.py new file mode 100644 index 000000000..fd8c25ff6 --- /dev/null +++ b/tests/isolated/multidict_global_counter.py @@ -0,0 +1,33 @@ +import sysconfig +import threading + +import multidict +from multidict import MultiDict + +FREETHREADED = bool(sysconfig.get_config_var("Py_GIL_DISABLED")) + + +md: MultiDict[int] = MultiDict() +N, M = 3, 100 +baseline = multidict.getversion(md) # type: ignore[arg-type] + + +def worker(tid: int) -> None: + for i in range(M): + md[f"k{tid}_{i}"] = i + + +if (__name__ == "__main__") and FREETHREADED: + threads = [threading.Thread(target=worker, args=(tid,)) for tid in range(N)] + for t in threads: + t.start() + for t in threads: + t.join() + + observed = multidict.getversion(md) - baseline # type: ignore[arg-type] + expected = N * M + assert expected == observed, ( + f"expected delta: {expected}" + f" observed: {observed} " + f"lost: {expected - observed}" + ) diff --git a/tests/test_leaks.py b/tests/test_leaks.py index 105853ddd..18dc3de6c 100644 --- a/tests/test_leaks.py +++ b/tests/test_leaks.py @@ -18,6 +18,7 @@ "multidict_type_leak.py", "multidict_type_leak_items_values.py", "multidict_pop.py", + "multidict_global_counter.py", ), ) @pytest.mark.leaks