[wip] Add batch inference pipeline tutorials (image + LLM)#228

Open
kumare3 wants to merge 1 commit into main from inference-pipeline
Conversation

@kumare3 kumare3 commented Mar 23, 2026

This PR adds two tutorials demonstrating the InferencePipeline abstraction from flyte.extras for maximizing GPU utilization during batch inference.

batch_image_pipeline.py — Image Classification with ResNet-50

3-stage pipeline: download images → resize/normalize → GPU inference → decode labels. Optimizations informed by GPU expert review:

  • torch.compile(dynamic=False, mode="reduce-overhead") with multi-size warmup to pre-compile CUDA graphs for all plausible batch sizes
  • pin_memory() + non_blocking H2D transfer to overlap PCIe with compute
  • Top-5 computed on-device — 200x less D2H data vs full 1000-class logits
  • min_batch_size=8 prevents pathological batch-of-1 under low concurrency (T4 throughput drops ~15x at batch=1 vs batch=32 for ResNet-50)
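The min_batch_size guard in the last bullet can be sketched in plain Python. This is an illustrative standalone sketch, not the tutorial's actual API (`should_flush` and the timeout parameter are hypothetical names):

```python
def should_flush(pending: int, min_batch_size: int = 8,
                 waited_s: float = 0.0, max_wait_s: float = 0.05) -> bool:
    """Flush only when enough items have queued, or the wait deadline expires.

    Without a min_batch_size floor, a lone item under low concurrency is
    dispatched as a batch of 1 — which, per the numbers above, runs
    ResNet-50 on a T4 at roughly 1/15th the throughput of batch=32.
    """
    if pending == 0:
        return False
    if pending >= min_batch_size:
        return True  # enough queued work for an efficient GPU batch
    return waited_s >= max_wait_s  # don't starve latency under low load

# A single queued item is held back until the deadline passes...
assert should_flush(pending=1, waited_s=0.0) is False
# ...but a full batch dispatches immediately.
assert should_flush(pending=8) is True
```

The timeout escape hatch matters: a hard floor with no deadline would deadlock whenever fewer than min_batch_size items remain in the stream.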

batch_llm_pipeline.py — LLM Inference with Qwen2.5-0.5B on JSONL

3-stage pipeline: read JSONL → tokenize with token-count cost estimation → model.generate() via token-budgeted batching → decode + substring eval. Optimizations:

  • Token-budgeted DynamicBatcher (TARGET_BATCH_TOKENS=4096) assembles variable-length prompts into GPU-optimal batches
  • tokenizer.pad() for correct left-padding (avoids pad_token_id=0 / BOS corruption bug from manual padding)
  • eos_token_id for early termination — short answers stop generating instead of padding to max_new_tokens
  • Warmup at full batch=16, max_new_tokens=128 to pre-allocate KV cache
  • Includes 100-prompt demo JSONL with substring-match evaluation
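The token-budget rule from the first bullet can be sketched as a greedy packer. This is a simplified synchronous sketch, not the real DynamicBatcher (which is asynchronous and stream-fed); the key point is that a padded batch costs max_len × batch_size tokens of GPU work, so the budget must be checked against the would-be padded size, not the raw token sum:

```python
TARGET_BATCH_TOKENS = 4096

def budget_batches(prompt_lengths: list[int],
                   target_tokens: int = TARGET_BATCH_TOKENS) -> list[list[int]]:
    """Greedily pack prompts so that max(batch) * len(batch) <= target_tokens."""
    batches: list[list[int]] = []
    current: list[int] = []
    for n in prompt_lengths:
        padded = max(current + [n]) * (len(current) + 1)
        if current and padded > target_tokens:
            batches.append(current)  # budget exceeded: flush and start fresh
            current = [n]
        else:
            current.append(n)
    if current:
        batches.append(current)
    return batches

# Short prompts pack densely (40 per batch at length 100); a single long
# prompt would blow the padded budget, so it lands in its own batch.
print(budget_batches([100] * 50 + [3000]))
```

Left-padding to a shared length is what makes the padded-size accounting accurate, which is also why the tokenizer.pad() fix in the second bullet matters.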

Both examples use ReusePolicy so multiple concurrent tasks on the same replica share a single InferencePipeline singleton — the DynamicBatcher sees items from all streams, producing larger GPU batches automatically.
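The replica-sharing behavior described above can be illustrated with a minimal lazy singleton. All names here are stand-ins (the actual tutorials wire this through ReusePolicy and flyte.extras.InferencePipeline); the sketch only shows why concurrent tasks in one process end up feeding the same batcher:

```python
import asyncio

_pipeline = None
_pipeline_lock = asyncio.Lock()

async def get_pipeline():
    """Return one shared pipeline per process (i.e. per container replica).

    Concurrent tasks scheduled onto the same replica all receive the same
    object, so their items flow into one batcher and form larger GPU batches.
    """
    global _pipeline
    async with _pipeline_lock:  # guard lazy init against concurrent first calls
        if _pipeline is None:
            _pipeline = object()  # stand-in for InferencePipeline(...)
    return _pipeline

async def main() -> bool:
    # Two "tasks" running concurrently on the same replica:
    a, b = await asyncio.gather(get_pipeline(), get_pipeline())
    return a is b

shared = asyncio.run(main())
print(shared)  # True: both tasks got the same instance
```

The lock matters only for the first call: without it, two tasks racing through the None check could each construct a pipeline, defeating the batch-merging behavior.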

Depends on: flyteorg/flyte-sdk#826

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@worker.task(cache="auto", retries=2)
async def generate_responses(
    jsonl_file: flyte.io.File,
we can use flyteplugins-jsonl here.

Comment on lines +379 to +389
# Split into chunk files and upload
tasks = []
for i in range(0, len(lines), chunk_size):
    chunk_lines = lines[i : i + chunk_size]
    chunk_id = f"chunk_{i // chunk_size:03d}"

    # Write chunk to temp file and upload
    chunk_path = os.path.join(tempfile.gettempdir(), f"{chunk_id}.jsonl")
    with open(chunk_path, "w") as f:
        f.write("\n".join(chunk_lines) + "\n")
    chunk_file = await flyte.io.File.from_local(chunk_path)
Comment on lines +187 to +191
loop = asyncio.get_running_loop()
tensor = await loop.run_in_executor(
    _io_cpu_pool,
    lambda: _preprocess_transform(Image.open(BytesIO(resp.content)).convert("RGB")),
)
we can set preprocess_executor no?

        for i in range(len(batch))
    ]

    return await loop.run_in_executor(_gpu_pool, _forward)
don't think we need this, right?

    texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return texts

return await loop.run_in_executor(_gpu_pool, _generate)
don't think we need this, right?
