You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Bug 1: OOM check in import_external_tensor uses aligned_export_size (export called first)
Bug 2: owns_tensor checks data_ptr() range for all tensors; zero-element tensors no longer unconditionally claimed
Bug 3: GC auto-free via weakref.finalize on tensor.untyped_storage() (survives reshape()/other view ops) + _pending_free deque (avoids lock-reentry deadlock)
Bug 4: Physical memory renewed at reuse time — _remap_free_block releases old mem_unmap+mem_release, creates new mem_create+mem_map, then restores mem_set_access(base_va, cumulative, ...) — releasing at free time is incompatible with ROCm's requirement that hipMemSetAccess must start from base_va (calling on sub-ranges gives HIP error 1)
Design 5: _CUDAArrayInterface defined once at module level
Design 6: _element_size() with module-level _DTYPE_ELEMENT_SIZE cache
Design 7: FD leak in import_external_tensor fixed with try/finally
Design 8: Unused va_multiplier parameter removed
Test 9: test_vmem_pow2_oom — RuntimeError on VA exhaustion
Test 10: test_vmem_pow2_thread_safety — 4 threads × 20 alloc/free cycles
Test 11: test_vmem_pow2_close + test_vmem_pow2_close_disables_finalizers
Test 12: All print() calls removed from tests
Generation counter added to segment API: get_allocation_segments() returns (offset, size, va, generation) 4-tuples; symmetric_heap uses (offset, size, generation) as dedup key, handles stale peer mappings
vmem_allocator.py returns generation=0 in segments (backward compatible)
This section details on the original issue you should resolve
<issue_title>Implement power-of-two vmem allocator</issue_title>
<issue_description>Implement a new virtual memory power of two allocator. Allocator must support as symmetric api</issue_description>
<agent_instructions>You have access to AMD gpu. Make sure everything is tested and working </agent_instructions>
Comments on the Issue (you are @copilot in this section)
@copilot Please address the following issues in this PR:
Critical Bugs
import_external_tensor OOM check mismatch: The OOM check on line 390 uses aligned_size (derived from alloc_size) but the actual mapping on line 403 uses aligned_export_size (derived from export_size). These can differ, so the check may pass but the mapping could still exceed the heap, or vice versa. The check should use aligned_export_size.
owns_tensor incorrectly claims zero-element tensors: owns_tensor returns True unconditionally for any zero-element tensor, even ones not allocated by this allocator. A torch.zeros(0, device="cuda") created externally would be incorrectly claimed as owned. This can break routing logic that depends on owns_tensor.
No GC-based auto-free: Tensors that go out of scope leak their VA block permanently — the block is never returned to the free list until the entire allocator is close()'d. The allocator should use a weak-reference finalizer (e.g., weakref.ref with a callback) on returned tensors to automatically return blocks to the free list when the tensor is garbage collected. Users should not be required to manually call allocator.free(t) for normal usage.
Physical memory not released on free(): Currently free() only does bookkeeping — the physical memory (mem_create handle) stays mapped, permanently holding VRAM for blocks that aren't in use. On free(), the allocator should call mem_unmap and mem_release to return physical memory to the system while keeping the VA offset on the free list for reuse. On re-allocation from the free list, call mem_create + mem_map again to back the reserved VA with fresh physical memory. The cumulative_mapped_size / mem_set_access logic will need adjustment to handle gaps in the physical mapping. The DMA-BUF concern can be handled by re-exporting segments after reallocation rather than assuming all segments are always physically present.
Design Issues
_CUDAArrayInterface class redefined on every call: The inner class _CUDAArrayInterface is redefined inside both allocate() and import_external_tensor() on every invocation. Define it once at module level or as a nested class of the allocator.
Inefficient element_size computation: torch.tensor([], dtype=dtype).element_size() creates a temporary tensor on every allocation just to get the element size. Cache this or use a more efficient approach.
File descriptor leak in import_external_tensor: If an exception occurs between export_dmabuf_handle and os.close(dmabuf_fd), the fd leaks. Wrap in try/finally.
Unused va_multiplier parameter: The constructor accepts va_multiplier but never uses it. Either implement it or remove it.
Testing Gaps
No test for OOM / heap exhaustion — the RuntimeError on VA space exhaustion is untested.
No test for thread safety — given the Lock, concurrent allocation/free should be tested.
No test for close() / resource cleanup — the close() and __del__ paths are untested.
Tests use print() for pass reporting — remove prints; rely on pytest assertions.
@copilot Please address the following issues in this PR:
Critical Bugs...
All 12 issues addressed in commit 17edd0b. Key points:
Critical bugs:
OOM check in import_external_tensor now uses aligned_export_size (export called first, then check)
owns_tensor checks data_ptr() range for all tensors — zero-element tensors no longer unconditionally claimed
GC auto-free via weakref.finalize on tensor.untyped_storage() (not the Python wrapper, which doesn't survive .reshape()) + _pending_free deque to avoid lock-reentry deadlock
Physical memory is renewed at reuse time: _remap_free_block releases the old handle (mem_unmap + mem_release) and creates fresh physical memory (mem_create + mem_map) at the same VA, then restores mem_set_access(base_va, cumulative_size, …). Releasing at free() time is incompatible with ROCm's requirement that hipMemSetAccess must start from base_va — calling it on arbitrary sub-ranges gives HIP error 1 (invalid argument), so VA ranges must always stay gap-free
Segment API:get_allocation_segments() now returns (offset, size, va, generation) 4-tuples; _remap_free_block increments the generation; symmetric_heap replaces the _peer_imported_segments set with _peer_segment_generations dict and handles stale peer mappings (unmap old, remap new, re-set access)
New tests: OOM, thread safety, close(), GC auto-free, generation counter, zero-element owns_tensor check; all print() calls removed
All 24 new tests pass on 2 ranks; all 2255 existing tests pass.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
VMemPow2Allocatoriniris/allocators/vmem_pow2_allocator.pyiris/allocators/__init__.pyto exportVMemPow2Allocatoriris/symmetric_heap.pyto supportallocator_type="vmem_pow2"tests/unittests/test_vmem_pow2_allocator.pywith comprehensive testsimport_external_tensorusesaligned_export_size(export called first)owns_tensorchecksdata_ptr()range for all tensors; zero-element tensors no longer unconditionally claimedweakref.finalizeontensor.untyped_storage()(survivesreshape()/other view ops) +_pending_freedeque (avoids lock-reentry deadlock)_remap_free_blockreleases oldmem_unmap+mem_release, creates newmem_create+mem_map, then restoresmem_set_access(base_va, cumulative, ...)— releasing at free time is incompatible with ROCm's requirement thathipMemSetAccessmust start frombase_va(calling on sub-ranges gives HIP error 1)_CUDAArrayInterfacedefined once at module level_element_size()with module-level_DTYPE_ELEMENT_SIZEcacheimport_external_tensorfixed withtry/finallyva_multiplierparameter removedtest_vmem_pow2_oom— RuntimeError on VA exhaustiontest_vmem_pow2_thread_safety— 4 threads × 20 alloc/free cyclestest_vmem_pow2_close+test_vmem_pow2_close_disables_finalizersprint()calls removed from testsget_allocation_segments()returns(offset, size, va, generation)4-tuples; symmetric_heap uses(offset, size, generation)as dedup key, handles stale peer mappingsvmem_allocator.pyreturnsgeneration=0in segments (backward compatible)ptr==0note, deque thread-safety note)Original prompt
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.