Add tensor key filtering to avoid streaming unused safetensors tensors #14

@bbartels

InstantTensor currently appears to schedule/read every tensor in a safetensors file set when using safe_open(...).tensors(). This is inefficient for loaders that only need a subset of checkpoint tensors.

Use case: vLLM MTP speculative decoding.

In vLLM, enabling MTP speculative decoding loads the target model, then loads an MTP drafter from the same checkpoint. The MTP drafter is much smaller and only needs MTP-related tensors, but with load_format=instanttensor the second load streams the whole checkpoint again before vLLM can discard non-MTP tensors.

Observed behavior:

Loading safetensors using InstantTensor loader: 100% | 232670/232670 | ~97s
...
Loading drafter model...
Loading safetensors using InstantTensor loader: 100% | 232670/232670 | ~99s

The second pass is for the MTP drafter. It only needs a subset of tensors, but InstantTensor reads all tensors because vLLM calls f.tensors().

Requested feature:

Add support for filtering tensor keys before opening/scheduling native I/O, for example:

with instanttensor.safe_open(
    files,
    framework="pt",
    device=device,
    process_group=process_group,
    include_keys=needed_keys,  # proposed parameter: only these keys are scheduled for I/O
) as f:
    for name, tensor in f.tensors():
        ...

or:

with instanttensor.safe_open(
    files,
    framework="pt",
    device=device,
    process_group=process_group,
    tensor_filter=lambda name: name in needed_keys,  # proposed parameter: predicate form of the same filter
) as f:
    ...

Important requirement:

The filter should be applied before building native read offsets/chunks, so skipped tensors are not read from disk and do not contribute to buffer sizing/progress totals.
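To illustrate what filtering before offset planning could look like, here is a minimal sketch that works directly on the safetensors file layout (an 8-byte little-endian header length, a JSON header, then the raw byte buffer). It is not InstantTensor's implementation: `read_header` and `plan_reads` are hypothetical helper names, and the real native scheduler may chunk and order reads differently. The point is only that excluded keys never enter the read plan or the byte total.

```python
import json
import struct


def read_header(path):
    """Return (header_dict, data_start) for one safetensors file."""
    with open(path, "rb") as fh:
        # First 8 bytes: little-endian u64 giving the JSON header length.
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    header.pop("__metadata__", None)  # optional metadata entry, not a tensor
    return header, 8 + header_len


def plan_reads(path, include_keys=None):
    """Build absolute (name, offset, length) reads for included tensors only.

    Hypothetical helper: include_keys=None means "no filter".
    """
    header, data_start = read_header(path)
    plan = []
    for name, info in header.items():
        if include_keys is not None and name not in include_keys:
            continue  # dropped before any payload I/O is scheduled
        begin, end = info["data_offsets"]  # relative to the byte buffer
        plan.append((name, data_start + begin, end - begin))
    plan.sort(key=lambda item: item[1])  # keep reads in file order
    total_bytes = sum(length for _, _, length in plan)
    return plan, total_bytes
```

Because `total_bytes` is computed from the filtered plan, buffer sizing and progress totals derived from it would cover included tensors only.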

Why this likely needs InstantTensor support:

vLLM can determine which weights are needed, but after f.tensors() yields a tensor, the payload has already been read. A vLLM-only filter can only discard tensors after I/O. Filtering whole shard files via model.safetensors.index.json is only a partial workaround because shards often contain mixed needed/unneeded tensors.
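The limit of the shard-level workaround can be made concrete with a small sketch. It assumes only the `weight_map` object from model.safetensors.index.json (tensor name to shard filename); `shard_level_waste` is a hypothetical helper name, not part of vLLM or InstantTensor.

```python
from collections import defaultdict


def shard_level_waste(weight_map, needed_keys):
    """Count needed vs. unneeded tensors in every shard that must be opened.

    weight_map has the shape of the "weight_map" object in
    model.safetensors.index.json: {tensor_name: shard_filename}.
    """
    per_shard = defaultdict(lambda: {"needed": 0, "unneeded": 0})
    for name, shard in weight_map.items():
        bucket = "needed" if name in needed_keys else "unneeded"
        per_shard[shard][bucket] += 1
    # Shards without any needed tensor can be skipped entirely.
    # Every remaining shard is still read in full without key-level filtering.
    return {s: c for s, c in per_shard.items() if c["needed"] > 0}
```

Any shard in the result with a nonzero `unneeded` count is one where whole-file filtering still streams tensors that are discarded afterwards.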

Expected result:

  • keys() / offset_keys() return only included keys, or a separate filtered iterator is provided.
  • tensors() only streams included tensors.
  • total_tensor_size, buffer sizing, and progress total reflect included tensors only.
  • Distributed loading remains deterministic when all ranks pass the same filter/list.
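One way the two proposed filter forms and the determinism requirement could fit together is sketched below. Everything here is an assumption for illustration: `select_keys` and `selection_digest` are hypothetical names, sorting is just one possible canonical order, and InstantTensor may order and partition tensors differently.

```python
import hashlib


def select_keys(all_keys, include_keys=None, tensor_filter=None):
    """Return the included keys in a deterministic (sorted) order."""
    if include_keys is not None and tensor_filter is not None:
        raise ValueError("pass include_keys or tensor_filter, not both")
    if include_keys is not None:
        wanted = set(include_keys)
        keep = wanted.__contains__
    elif tensor_filter is not None:
        keep = tensor_filter
    else:
        keep = lambda name: True  # no filter: current behaviour
    # Sorting makes the result independent of header or argument order.
    return sorted(name for name in all_keys if keep(name))


def selection_digest(selected):
    """Hash of the selection that ranks could compare before scheduling I/O."""
    h = hashlib.sha256()
    for name in selected:
        h.update(name.encode("utf-8") + b"\0")
    return h.hexdigest()
```

If every rank derives the same digest (exchanged, for example, with `torch.distributed.all_gather_object`), a mismatched filter can be reported before any rank starts reading.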

This would significantly reduce startup time and I/O for vLLM MTP/speculative decoding and other partial-loading workflows.
