Skip to content

feat: batch inference for tile mode + tessera double-fetch fix#86

Merged
Dinghye merged 7 commits into
mainfrom
speedtile
Jun 11, 2026
Merged

feat: batch inference for tile mode + tessera double-fetch fix#86
Dinghye merged 7 commits into
mainfrom
speedtile

Conversation

@Dinghye

@Dinghye Dinghye commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • Tile-mode batch inference. Adds a new "Tier 1.5" path in InferenceEngine:
    when input_prep="tile" and the embedder supports prefetched batch, every
    input image is sliced into tiles, all tiles across all spatial points are
    flattened into one batch
    , run through a single forward pass, then stitched
    back per point. Previously tile mode fell through to the single-item path
    and lost all batching speedup.
  • Batch APIs on three on-the-fly embedders. FoMo, SatVision-TOA, and
    WildSAT now implement get_embeddings_batch_from_inputs, including new
    batched forward helpers (e.g. _fomo_forward_tokens_batch).
  • Tessera double-fetch fix. _mosaic_and_crop_strict_roi was invoking
    tiles_rows_factory() twice (bounds scan + paste), causing
    geotessera.fetch_embeddings to re-iterate — and on a cold cache,
    re-download — every tile block. Now materialized once into a list both
    passes share. This was the dominant contributor to the "stuck at 0/N"
    interval before the first export_batch chunk.

Changes

Area Files
Inference pipeline src/rs_embed/pipelines/inference.py (+259)
New batch APIs src/rs_embed/embedders/onthefly_{fomo,satvision_toa,wildsat}.py (+581)
Tessera fix src/rs_embed/embedders/precomputed_tessera.py (+7/-2)
Tests tests/test_embedder_base_contracts.py (+55), tests/test_input_prep_tiling.py (+165), minor touch-ups
Misc CHANGELOG, demo notebook

Net: ~1063 insertions / 19 deletions across 11 files.

Notes

  • The new tier only fires for explicit input_prep="tile". Auto mode is
    unchanged — it still requires per-image size inspection that can't be
    batched before the fetch.
  • Tessera memory tradeoff: all tile blocks for the current point are now
    resident simultaneously instead of one at a time. Typical 2–4 km buffer is
    1–4 blocks. Reduce RS_EMBED_TESSERA_BATCH_WORKERS on memory-tight hosts
    with very large buffers.

Test

  • pytest tests/test_input_prep_tiling.py tests/test_embedder_base_contracts.py
  • pytest tests/test_inference_helpers.py tests/test_inspect.py
  • Manual: run export_batch with input_prep="tile" on FoMo/SatVision-TOA/WildSAT and confirm batched path is taken (single forward call per batch_size tiles)
  • Manual: run export_batch including tessera on a cold cache and confirm no duplicate fetch logs for the same tile

Dinghye and others added 4 commits May 23, 2026 18:19
_mosaic_and_crop_strict_roi invoked tiles_rows_factory() twice (bounds
scan + paste), re-triggering geotessera.fetch_embeddings on every tile.
Materialize the rows once so both passes share the same list.

Co-Authored-By: Claude Opus 4.7 <[email protected]>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves export-time throughput by adding a true batched inference path for explicit tile-mode preprocessing, extending on-the-fly embedders with a prefetched-input batch API, and fixing an expensive double-fetch in the Tessera precomputed embedder mosaic path.

Changes:

  • Add “Tier 1.5” inference: tile each prefetched input per point, flatten all tiles across points into batches, run a single forward per batch, then stitch per-point outputs.
  • Implement get_embeddings_batch_from_inputs for FoMo, SatVision-TOA, and WildSAT to enable real GPU batching on prefetched inputs.
  • Materialize Tessera tile rows once in _mosaic_and_crop_strict_roi to avoid re-iterating / re-downloading tiles.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
src/rs_embed/pipelines/inference.py Adds tiled multi-point batching (“Tier 1.5”) and batch capability gating updates.
src/rs_embed/embedders/onthefly_fomo.py Adds batched forward/token handling and get_embeddings_batch_from_inputs.
src/rs_embed/embedders/onthefly_satvision_toa.py Adds get_embeddings_batch_from_inputs implementation for prefetched inputs.
src/rs_embed/embedders/onthefly_wildsat.py Adds batch forward helper and get_embeddings_batch_from_inputs.
src/rs_embed/embedders/precomputed_tessera.py Fixes double iteration of tiles_rows_factory() by materializing once.
tests/test_input_prep_tiling.py Adds tests covering tiled batching across multiple points and dispatch from infer_chunk.
tests/test_inference_helpers.py Updates batch-capability helper test expectations for new can_tiled return.
tests/test_embedder_base_contracts.py Adds a contract test enforcing batch-from-inputs overrides for input_chw on-the-fly embedders.
tests/test_inspect.py Minor import re-ordering.
examples/demo.ipynb Updates demo model list and switches export example to input_prep="tile".
CHANGELOG.md Documents Tessera double-fetch fix and related context.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/rs_embed/pipelines/inference.py Outdated
Comment on lines +452 to +456
if tiled_model_config is not None and embedder_accepts_model_config(
type(embedder), "get_embeddings_batch_from_inputs"
):
batch_kwargs["model_config"] = tiled_model_config

Comment thread tests/test_input_prep_tiling.py Outdated
return Embedding(data=np.asarray([float(x.mean())], dtype=np.float32), meta={})


def _make_engine(tile_size: int = 4) -> tuple:
embedder, model_config, tile_size=tile_size
)
tiled_model_config = tiled_mc
ys, xs = _tile_yx_starts(h=h, w=w, tile_size=tile_size, stride=stride)
Comment thread examples/demo.ipynb
" continue_on_error=True,\n",
" show_progress=True,\n",
" input_prep=\"resize\", # you can chage it to 'tile' if you want to keep original resolution\n",
" input_prep=\"tile\", # you can chage it to 'tile' if you want to keep original resolution\n",

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.

Comment thread src/rs_embed/pipelines/inference.py Outdated
Comment on lines +488 to +491
# Step 3: stitch tiles back into per-point embeddings.
for i, _spatial, _inp in ready:
flat_start, tile_count, tile_metas, h, w, tiled_mc = tile_map[i]
tile_embs = all_tile_embs[flat_start : flat_start + tile_count]
return Embedding(data=np.asarray([float(x.mean())], dtype=np.float32), meta={})


def _make_engine(tile_size: int = 4, max_tiles: int = 16) -> tuple:
Comment thread examples/demo.ipynb
" continue_on_error=True,\n",
" show_progress=True,\n",
" input_prep=\"resize\", # you can chage it to 'tile' if you want to keep original resolution\n",
" input_prep=\"tile\", # you can chage it to 'tile' if you want to keep original resolution\n",
@Dinghye Dinghye merged commit 53a9c07 into main Jun 11, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants