Title: Add tensor key filtering to avoid streaming unused safetensors tensors
InstantTensor currently appears to schedule/read every tensor in a safetensors file set when using safe_open(...).tensors(). This is inefficient for loaders that only need a subset of checkpoint tensors.
Use case: vLLM MTP speculative decoding.
In vLLM, enabling MTP speculative decoding loads the target model, then loads an MTP drafter from the same checkpoint. The MTP drafter is much smaller and only needs MTP-related tensors, but with load_format=instanttensor the second load streams the whole checkpoint again before vLLM can discard non-MTP tensors.
Observed behavior:
Loading safetensors using InstantTensor loader: 100% | 232670/232670 | ~97s
...
Loading drafter model...
Loading safetensors using InstantTensor loader: 100% | 232670/232670 | ~99s
The second pass is for the MTP drafter. It only needs a subset of tensors, but InstantTensor reads all tensors because vLLM calls f.tensors().
Requested feature:
Add support for filtering tensor keys before opening/scheduling native I/O, for example:
    with instanttensor.safe_open(
        files,
        framework="pt",
        device=device,
        process_group=process_group,
        include_keys=needed_keys,
    ) as f:
        for name, tensor in f.tensors():
            ...
or:
    with instanttensor.safe_open(
        files,
        framework="pt",
        device=device,
        process_group=process_group,
        tensor_filter=lambda name: name in needed_keys,
    ) as f:
        ...
Important requirement:
The filter should be applied before building native read offsets/chunks, so skipped tensors are not read from disk and do not contribute to buffer sizing/progress totals.
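To illustrate the ordering this requirement implies, here is a minimal sketch that filters safetensors header entries before any read plan exists. It does not reflect InstantTensor internals, which are not shown in this issue. `ReadPlan`, `build_read_plan`, and `make_safetensors_blob` are hypothetical names, and the `mtp.` prefix is only an illustrative key pattern. The only format facts assumed are the standard safetensors layout: an 8-byte little-endian header length, then a JSON header mapping each name to `dtype`, `shape`, and `data_offsets`, with an optional `__metadata__` entry.

```python
import json
import struct
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ReadPlan:
    # (name, absolute_file_offset, nbytes) for each tensor that will be read
    reads: List[Tuple[str, int, int]]
    total_bytes: int


def build_read_plan(blob: bytes, keep: Callable[[str], bool]) -> ReadPlan:
    """Parse a safetensors header and plan reads only for kept keys.

    A real loader would read just the header from disk; an in-memory
    blob is used here to keep the sketch self-contained.
    """
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_len])
    data_start = 8 + header_len

    reads = []
    for name, info in header.items():
        if name == "__metadata__" or not keep(name):
            continue  # skipped keys never enter the plan
        begin, end = info["data_offsets"]
        reads.append((name, data_start + begin, end - begin))

    reads.sort(key=lambda r: r[1])  # issue reads in file order
    total = sum(nbytes for _, _, nbytes in reads)
    return ReadPlan(reads=reads, total_bytes=total)


def make_safetensors_blob(tensors: Dict[str, bytes]) -> bytes:
    """Build a tiny in-memory safetensors file from name -> raw U8 bytes."""
    header, payload, offset = {}, b"", 0
    for name, raw in tensors.items():
        header[name] = {
            "dtype": "U8",
            "shape": [len(raw)],
            "data_offsets": [offset, offset + len(raw)],
        }
        payload += raw
        offset += len(raw)
    encoded = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(encoded)) + encoded + payload


blob = make_safetensors_blob({
    "model.layers.0.weight": b"\x00" * 64,
    "mtp.proj.weight": b"\x01" * 16,
})
plan = build_read_plan(blob, keep=lambda name: name.startswith("mtp."))
print(plan.reads, plan.total_bytes)
```

Because the filter runs before offsets are collected, `total_bytes` (and anything derived from it, such as buffer sizes or a progress total) only counts the kept tensors.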
Why this likely needs InstantTensor support:
vLLM can determine which weights are needed, but after f.tensors() yields a tensor, the payload has already been read. A vLLM-only filter can only discard tensors after I/O. Filtering whole shard files via model.safetensors.index.json is only a partial workaround because shards often contain mixed needed/unneeded tensors.
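The shard-level workaround mentioned above can be sketched as follows, assuming the usual Hugging Face index layout in which `model.safetensors.index.json` has a `weight_map` from tensor name to shard filename. `shards_for_keys` and `wasted_keys` are hypothetical helpers, and the tensor and shard names are made up for illustration.

```python
import json
from typing import Dict, Iterable, List, Set

# Stand-in for the contents of model.safetensors.index.json
INDEX_JSON = """
{
  "metadata": {},
  "weight_map": {
    "model.layers.0.weight": "model-00001-of-00002.safetensors",
    "model.layers.1.weight": "model-00002-of-00002.safetensors",
    "mtp.proj.weight": "model-00002-of-00002.safetensors"
  }
}
"""


def shards_for_keys(index: Dict, needed_keys: Iterable[str]) -> Set[str]:
    """Return shard filenames that hold at least one needed key."""
    weight_map = index["weight_map"]
    return {weight_map[k] for k in needed_keys if k in weight_map}


def wasted_keys(index: Dict, needed_keys: Iterable[str]) -> List[str]:
    """Keys that live in a selected shard but are not needed."""
    needed = set(needed_keys)
    selected = shards_for_keys(index, needed)
    return sorted(
        name
        for name, shard in index["weight_map"].items()
        if shard in selected and name not in needed
    )


index = json.loads(INDEX_JSON)
needed = ["mtp.proj.weight"]
print(shards_for_keys(index, needed))
print(wasted_keys(index, needed))
```

Here one shard is skipped entirely, but `model.layers.1.weight` shares a shard with the needed tensor and would still be streamed, which is why file-level filtering alone is only a partial fix.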
Expected result:
- keys() / offset_keys() return only included keys, or a separate filtered iterator is provided.
- tensors() only streams included tensors.
- total_tensor_size, buffer sizing, and progress total reflect included tensors only.
- Distributed loading remains deterministic when all ranks pass the same filter/list.
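For the determinism point, one possible way for callers to confirm that every rank passes the same key list is to compare an order-independent digest of it. This is only a sketch of a caller-side check, not something InstantTensor is known to do. `filter_digest` is a hypothetical helper and assumes tensor names contain no newline characters.

```python
import hashlib
from typing import Iterable


def filter_digest(include_keys: Iterable[str]) -> str:
    """Hash an include list so that order and duplicates do not matter."""
    canonical = "\n".join(sorted(set(include_keys)))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


rank0 = filter_digest(["mtp.proj.weight", "mtp.norm.weight"])
rank1 = filter_digest(["mtp.norm.weight", "mtp.proj.weight", "mtp.proj.weight"])
print(rank0 == rank1)
```

In a real job, ranks could exchange these digests (for example with `torch.distributed.all_gather_object`) and fail fast on a mismatch before any I/O is scheduled.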
This would significantly reduce startup time and I/O for vLLM MTP/speculative decoding and other partial-loading workflows.