
vLLM based Eval framework #3531

Open

dipannita08 wants to merge 2 commits into main from eval-framework-01

Conversation

dipannita08 (Collaborator) commented Mar 31, 2026

Description

Implement an evaluation framework with a vLLM backend. Requirements, design, and further details: go/eval-framework-vllm

FIXES: b/508639301

Tests

  • Unit tests
  • E2E tests: b/508639301

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Mar 31, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


github-actions Bot commented

🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions Bot left a comment

📋 Review Summary

The implementation of the vLLM-based evaluation framework is a strong addition to MaxText, providing native support for custom benchmarks, lm-evaluation-harness, and evalchemy. The code is well-structured, but it needs critical updates to correctly support multi-host TPU environments where lead-rank coordination is essential.

🔍 General Feedback

  • Rank Coordination: In multi-host TPU setups, client-side operations (warmup, generation, reporting) must be restricted to jax.process_index() == 0 to avoid redundant work and failures on non-lead ranks (a minimal sketch follows this list).
  • Configurability: Key parameters like request timeouts should be made configurable via the CLI/config files rather than being hardcoded.
  • Efficiency: Minor optimizations in NLTK data handling and FastAPI request processing would improve the overall robustness and performance of the evaluation tool.
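
For the rank-coordination point above, a minimal sketch of a lead-rank guard, assuming the standard JAX multi-host API (the helper names below are hypothetical, not from this PR):

```python
import jax

def is_lead_process() -> bool:
  """True only on the lead host of a multi-host TPU job."""
  return jax.process_index() == 0

def run_on_lead_only(fn, *args, **kwargs):
  """Run client-side work (warmup, generation, reporting) on the lead host only.

  Non-lead hosts return None instead of issuing duplicate HTTP requests
  against the vLLM server.
  """
  if not is_lead_process():
    return None
  return fn(*args, **kwargs)

# Usage sketch, e.g. wrapping the runner's warmup call:
# run_on_lead_only(warmup_server, base_url=base_url, model=model_name, sample_requests=requests)
```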

Comment thread src/maxtext/eval/runner/eval_runner.py Outdated
Comment thread src/maxtext/eval/runner/lm_eval_runner.py Outdated
Comment thread src/maxtext/eval/runner/evalchemy_runner.py Outdated
Comment thread src/maxtext/eval/scoring/rouge_scorer.py Outdated
Comment thread src/maxtext/eval/runner/async_client.py Outdated
Comment thread src/maxtext/eval/runner/server_manager.py
RissyRan (Collaborator) left a comment

Thanks for the PR! I’ve done a high-level pass, though I haven’t done a deep dive into the code just yet. Have you had a chance to run any benchmarks? I'm curious if you're seeing decent scores. Also, is multi-host benchmarking for large models within the scope of this PR?

Comment thread src/maxtext/eval/README.md Outdated
Comment thread src/maxtext/eval/README.md Outdated
Comment thread src/maxtext/eval/README.md Outdated
Comment thread src/maxtext/eval/runner/lm_eval_runner.py Outdated
dipannita08 force-pushed the eval-framework-01 branch 2 times, most recently from cb0295b to 348c539 on April 23, 2026 01:12
dipannita08 force-pushed the eval-framework-01 branch from c2d7d04 to 03ce83a on May 1, 2026 17:39
dipannita08 requested a review from RissyRan on May 1, 2026 18:13
Comment thread src/maxtext/eval/datasets/base.py Outdated
entrpn (Collaborator) commented May 1, 2026

@dipannita08 has this code been tested with an nnx checkpoint? @hengtaoguo recently added support for this and it is needed in order to support distilled checkpoints. Hengtao's PR: #3188

RissyRan (Collaborator) left a comment

Thanks for the change!

Could we potentially do a run to see how it performs? Kurt's final presentation shows timings for common benchmarks using JetStream; it would be great to see whether we can speed them up with this tool.

Comment thread src/maxtext/eval/configs/mlperf.yml
@@ -0,0 +1,8 @@
# Base evaluation configuration.
Collaborator left a comment

Shall we also include benchmark name or eval dataset in this base yml?

dipannita08 (Collaborator, Author) replied

The base config covers server and generation parameters shared across all eval runs (temperature, concurrency, tensor_parallel_size, etc.). The benchmark name and dataset-specific settings (e.g., num_samples, max_tokens) live in task-specific configs like mlperf.yml, so the base config can be reused for different benchmarks without modification.

For the harness-based runners (lm_eval, evalchemy), benchmark/task selection is handled by the --tasks CLI arg (--tasks gsm8k mmlu gpqa) rather than config files (examples in the bug I shared above).
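
As a rough illustration of how such a split might be consumed (the file contents, field names, and merge-at-load behavior below are assumptions for illustration, not the PR's actual schema):

```python
import yaml

# Hypothetical contents of the shared base config and a task-specific config.
BASE_YML = """
temperature: 0.0
concurrency: 64
tensor_parallel_size: 4
"""

TASK_YML = """
benchmark: mlperf
num_samples: 1000
max_tokens: 1024
"""

def load_eval_config(base_text: str, task_text: str) -> dict:
  """Merge shared server/generation settings with task-specific overrides."""
  config = yaml.safe_load(base_text) or {}
  config.update(yaml.safe_load(task_text) or {})
  return config

print(load_eval_config(BASE_YML, TASK_YML))
# {'temperature': 0.0, 'concurrency': 64, 'tensor_parallel_size': 4,
#  'benchmark': 'mlperf', 'num_samples': 1000, 'max_tokens': 1024}
```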

Comment thread src/maxtext/eval/runner/common.py
Comment thread src/maxtext/eval/runner/eval_runner.py
dipannita08 (Collaborator, Author) commented

> @dipannita08 has this code been tested with an nnx checkpoint? @hengtaoguo recently added support for this and it is needed in order to support distilled checkpoints. Hengtao's PR: #3188

Yes, the checkpoint loading in model_creation_utils.py defaults to nnx and falls back to Linen if it detects the params.params double-nesting. MaxTextForCausalLM in the adapter is an nnx.Module. Some example runs with Qwen3-30B-A3B NNX checkpoints are in b/508639301.
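
A rough sketch of that fallback check (the real logic lives in model_creation_utils.py; the structure below is an assumed illustration):

```python
def detect_checkpoint_flavor(restored_state: dict) -> str:
  """Guess whether a restored checkpoint tree is NNX-style or Linen-style.

  Per the comment above, Linen checkpoints carry a nested 'params'/'params'
  structure, while NNX checkpoints do not.
  """
  params = restored_state.get("params", {})
  if isinstance(params, dict) and "params" in params:
    return "linen"  # params.params double-nesting detected
  return "nnx"

# Example:
# detect_checkpoint_flavor({"params": {"params": {"decoder": {}}}}) -> "linen"
# detect_checkpoint_flavor({"params": {"decoder": {}}})             -> "nnx"
```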

dipannita08 (Collaborator, Author) commented

> Could we potentially do a run to see how it performs? Kurt's final presentation shows timings for common benchmarks using JetStream; it would be great to see whether we can speed them up with this tool.

Please see E2E results in the bug: b/508639301

hengtaoguo (Collaborator) left a comment

Thanks for the great work, Dipannita! Do you plan to add documentation for this route as a follow-up?

{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
Collaborator left a comment

Out of curiosity, is it usually preferred to use the tokenizer's chat_template if available?

dipannita08 (Collaborator, Author) replied

It is preferred for any custom benchmarks. We defer all of the tokenizer chat templating to the harness if it already supports the benchmark in question.

_DEFAULT_CONCURRENCY = 64
_DEFAULT_MAX_TOKENS = 1024
_DEFAULT_TEMPERATURE = 0.0
_COMPLETIONS_PATH = "/v1/completions"
Collaborator left a comment

For new customized datasets in the future, I assume they would need new endpoints here?

dipannita08 (Collaborator, Author) replied

For multi-modal we will likely need to add an endpoint here, but no new endpoints are needed for custom text-based datasets. The /v1/completions endpoint accepts any pre-formatted prompt string, so a new dataset just needs to implement BenchmarkDataset.sample_requests() and return SampleRequest objects with the prompt already formatted. For chat-format datasets, the /v1/chat/completions endpoint is already implemented in server_manager.py and used by the lm-eval/evalchemy harness path. The custom dataset path uses /v1/completions because prompts are pre-formatted via the chat template in eval_runner.py before they're sent.
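
A minimal sketch of what a new text-only dataset could look like under that contract (the exact BenchmarkDataset/SampleRequest signatures are assumptions here, not copied from the PR):

```python
from dataclasses import dataclass

@dataclass
class SampleRequest:
  # Assumed fields; the PR's actual SampleRequest may differ.
  prompt: str
  max_tokens: int
  reference: str | None = None

class MyCustomDataset:
  """Hypothetical custom dataset: yields pre-formatted prompts for /v1/completions."""

  def __init__(self, records, tokenizer, max_tokens=1024):
    self.records = records
    self.tokenizer = tokenizer
    self.max_tokens = max_tokens

  def sample_requests(self, num_samples):
    requests = []
    for record in self.records[:num_samples]:
      messages = [{"role": "user", "content": record["question"]}]
      # Prompt formatting (including the chat template) happens client-side,
      # so /v1/completions only ever sees a plain string.
      prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      requests.append(SampleRequest(prompt=prompt, max_tokens=self.max_tokens, reference=record.get("answer")))
    return requests
```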

# See the License for the specific language governing permissions and
# limitations under the License.

"""Async HTTP client for the /v1/completions endpoint.
Collaborator left a comment

Excuse my naive question: what does "completions" mean in this scenario? Is this async_client designed specifically for the completions task, or is it generalizable to other tasks too?

dipannita08 (Collaborator, Author) replied

This is the OpenAI text completion API: it lets us send a raw prompt string, and the model returns the text continuation. The /v1/chat/completions endpoint, by contrast, takes a structured messages array and applies chat templating server-side. We use /v1/completions in the custom dataset path because prompt formatting (including chat templating) is done client-side in eval_runner.py before the request is sent, so each dataset has control over prompt structure. I expect we will mostly use the industry-standard datasets (eval_harness) in most cases, though.
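
For illustration, the two request shapes look roughly like this (endpoint paths are OpenAI-compatible; the payloads below are a generic sketch, not the PR's actual client code, and the server address is assumed):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed local vLLM server address

# /v1/completions: the client sends a fully formatted prompt string.
completion = requests.post(
    f"{BASE_URL}/v1/completions",
    json={
        "model": "my-model",
        "prompt": "Q: What is 2 + 2?\nA:",
        "max_tokens": 16,
        "temperature": 0.0,
    },
    timeout=60,
).json()

# /v1/chat/completions: the client sends structured messages and the server
# applies the chat template before generation.
chat = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "max_tokens": 16,
        "temperature": 0.0,
    },
    timeout=60,
).json()
```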

from maxtext.eval.runner.async_client import generate_batch
from maxtext.eval.runner.common import build_server_manager, maybe_upload_to_gcs, resolve_token
from maxtext.eval.runner.warmup import warmup_server
from maxtext.eval.scoring.registry import get_scorer
Collaborator left a comment

Are there any specific reasons not to import at the top level?

dipannita08 (Collaborator, Author) replied

These imports are deferred because they pull in optional dependencies; deferring them lets helper modules and methods be imported and tested without pulling in vLLM and its dependencies.
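
The pattern in question, sketched generically (the import paths are the ones quoted above; the call signatures in the body are hypothetical, for illustration only):

```python
def run_custom_eval(config):
  """Entry point that actually needs the vLLM-backed stack.

  The heavy imports live inside the function, so importing this module
  (e.g. from unit tests of pure helpers) does not require vLLM to be installed.
  """
  # Deferred, optional-dependency imports (paths taken from the PR's runner).
  from maxtext.eval.runner.async_client import generate_batch
  from maxtext.eval.runner.common import build_server_manager
  from maxtext.eval.runner.warmup import warmup_server

  # Hypothetical call signatures for illustration only.
  server = build_server_manager(config)
  warmup_server(base_url=server.base_url, model=config.model_name, sample_requests=config.requests)
  return generate_batch(config.requests)
```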

base_url = server.base_url

# Warmup server.
warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
Collaborator left a comment

vLLM runs pre-compilation when it receives the first request by default. Is pre-compilation disabled in this framework?

dipannita08 (Collaborator, Author) replied

Pre-compilation is not disabled; it is additionally handled by warmup_server() before any evaluation requests are sent, which speeds up the first vLLM request (a pre-vLLM warmup).

The SKIP_JAX_PRECOMPILE=1 environment variable set in server_manager.py only suppresses JAX's startup precompilation pass (which would try to compile shapes before the model config is known); the warmup step then covers exactly the shapes needed for the eval workload. This gives more targeted compilation than JAX's default precompile and avoids wasting compile time on shapes that won't be used.
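
A rough sketch of what such a targeted warmup might look like (warmup_server's real implementation is in the PR; the body below is an assumed illustration that reuses the /v1/completions endpoint):

```python
import requests

def warmup_server(base_url, model, sample_requests, num_warmup=2, timeout=600):
  """Send a few representative requests so compilation happens before the eval starts.

  Using prompts and lengths drawn from the actual eval workload means the
  compiled shapes match what generation will need.
  """
  for sample in sample_requests[:num_warmup]:
    requests.post(
        f"{base_url}/v1/completions",
        json={
            "model": model,
            "prompt": sample.prompt,
            "max_tokens": sample.max_tokens,
            "temperature": 0.0,
        },
        timeout=timeout,
    ).raise_for_status()
```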

Collaborator replied

Sounds good to me! Have you verified that you are not seeing any instances of compilation during generation?

dipannita08 (Collaborator, Author) replied

I did check for single host; it should be similar for multi-host, but I have not verified that.

