
gpt-oss-120B on DGX Spark · part 2

[vLLM] Running a 120B Model on DGX Spark at 60 tok/s — Zero API Cost, Six Bugs

2026-03-19 · 10 min read · #dgx-spark #sm121 #vllm #gpt-oss

Preface

The goal was simple: give the openclaw agent running on this machine a 120B model as its brain. Local inference, no API quota, no rate limits, no cost per token. A model that's always available, runs at the edge of the hardware's bandwidth ceiling, and doesn't phone home.

The DGX Spark (GB10, SM121, 128 GB unified memory) is the hardware. gpt-oss-120B is OpenAI's open-weight 120B model, MXFP4-quantized. At 273 GB/s memory bandwidth and ~60 GB of weights — of which only the active experts are read per decoded token — the practical ceiling is ~60 tok/s. That's the number to hit.

Between zero and that number: six bugs. This is the path through them.


Part 1 covered the SM121-specific NVFP4 fixes — if any model outputs !!!!! on your DGX Spark, fix those first. This article picks up after that, with gpt-oss-120B specifically: a different quantization format, a different tokenizer, and its own set of obstacles.


What Makes gpt-oss-120B Different from Other MXFP4 Models?

gpt-oss-120B is OpenAI's open-weight 120B model, MXFP4-quantized, and it speaks a chat format called harmony. It uses the openai_harmony tokenizer (which depends on tiktoken, not HuggingFace tokenizers). It has a --reasoning-parser designed specifically for its output format. These three facts each contribute exactly one bug to this story.
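For orientation, a harmony-rendered conversation looks roughly like this (a sketch based on the openai_harmony documentation — exact token spellings come from the o200k_harmony encoding):

```
<|start|>system<|message|>You are a helpful assistant.<|end|>
<|start|>user<|message|>Hello!<|end|>
<|start|>assistant<|channel|>analysis<|message|>The user greets me...<|end|>
<|start|>assistant<|channel|>final<|message|>Hi there!<|return|>
```

The analysis/final channel split is what the reasoning parser keys on (Bug 4), and the specially-rendered system turn is the one that gets mangled in Bug 5.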

The SM121-compatible vLLM fork is eugr's branch, which patches MXFP4 support to work on GB10.

Note: Since this article was written, eugr has published spark-vllm-docker — a Docker-based setup with prebuilt nightly wheels, model recipes, and multi-node support. It includes a run-recipe.sh openai-gpt-oss-120b that handles most of the configuration below automatically. The manual approach documented here still applies if you're patching stock vLLM directly, but the Docker repo is the easier path for most setups.


Bug 1: Import Path Mismatch

The eugr fork was written against his own vLLM tree. When you apply the patches to stock vLLM 0.17.1, one import breaks immediately:

# eugr fork path (wrong on stock vLLM)
from vllm.model_executor.layers.quantization.quant_utils import cutlass_fp4_supported

# stock vLLM 0.17.1 (correct)
from vllm.model_executor.layers.quantization.nvfp4_utils import cutlass_fp4_supported

This is in mxfp4.py. The server crashes on startup, and the error message is not helpful. One-line fix.


Bug 2: --enforce-eager Cuts Speed in Half

First successful run. Speed: 26 tok/s. Expected: ~59 tok/s.

The serve script had --enforce-eager. This flag disables CUDAGraph and torch.compile — it's a debugging flag that forces eager execution mode. Someone added it for debugging at some point and forgot to remove it.

Remove it. 26 tok/s → 59 tok/s.

--enforce-eager should never appear in a production serve script. If it's there, remove it before debugging anything else.


Bug 3: tiktoken Vocab Download Fails on Air-Gapped Machines

gpt-oss uses the openai_harmony tokenizer, which uses tiktoken under the hood. On startup, tiktoken downloads its vocab file from OpenAI's CDN:

https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken

The GX10 is on an internal network. Download fails silently or with a cryptic error:

HarmonyError: error downloading or loading vocab file

The workaround requires reading the tiktoken source code. tiktoken caches vocab files under the SHA1 hash of the download URL as the filename. This is not documented anywhere. The hash for o200k_base.tiktoken is:

fb374d419588a4632f3f557e76b4b70aebbca790

Fix:

# on a machine with internet access
mkdir -p ~/models/tiktoken_cache
wget "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" \
  -O ~/models/tiktoken_cache/fb374d419588a4632f3f557e76b4b70aebbca790

# in your serve script
export TIKTOKEN_RS_CACHE_DIR=/home/username/models/tiktoken_cache

The directory and filename must match exactly. No extension on the cached file — just the raw SHA1 hash.
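The hash can be derived rather than memorized — tiktoken appears to key its cache on the SHA1 of the download URL (the convention in Python tiktoken's source; tiktoken-rs mirrors it). A quick sketch to confirm the filename:

```python
import hashlib

# tiktoken's cache filename is the SHA1 hex digest of the vocab file's
# download URL — not of the file contents, and with no extension added.
url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
cache_name = hashlib.sha1(url.encode()).hexdigest()
print(cache_name)  # should match the hash quoted above
```

Useful if you ever need the cache filename for a different encoding: hash its URL the same way.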


Bug 4: content: null — The Reasoning Parser Trap

Server running. First test request. The GX10 logs show tokens being generated. But no response arrives. Check the logs:

content: None
reasoning_len: 431
tokens: {'prompt_tokens': 68, 'completion_tokens': 100}

The --reasoning-parser openai_gptoss flag routes all output into the reasoning field and sets content to null. The parser is designed for the channelized output format gpt-oss uses for its reasoning — but if your client only reads content (as any standard OpenAI-compatible client does), it gets nothing.

Fix: remove --reasoning-parser openai_gptoss. Content appears.

(This diagnosis was correct, but bug 6 was also active. Removing the parser helped, but without fixing bug 6, outputs still degraded into repetition loops for anything longer than a few sentences.)


Bug 5: System Messages Bypass Harmony Encoding

gpt-oss uses the harmony message format. vLLM processes chat completion requests roughly like this:

  1. Build a system message using get_system_message()
  2. Iterate over request.messages and append each one

The bug (tracked as vLLM PR #31607, unmerged at time of writing): if the client sends a message with role: "system" in the messages array, it gets serialized as a raw harmony message rather than going through get_system_message(). The model sees a malformed token sequence and starts producing garbage.

Any client that sends a system prompt hits this — which is most of them. For openclaw, this matters: every request from the agent loop includes a system prompt.

Manual fix: patch vllm/entrypoints/openai/serving_chat.py to extract system-role messages from request.messages and route them through get_system_message(instructions=...).

Also required:

export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1

This env var defaults to 0. Without it, the instructions parameter to get_system_message() is silently ignored, and the system prompt never reaches the model regardless.


Bug 6: The Wrong Environment Variable (Silent Failure)

This is the one that takes the longest to find.

After bugs 1–5 are fixed, simple requests work. But anything longer falls into a repetition loop:

The user is asking about... The user is asking about... The user is asking about...

Temperature, top_p, repetition_penalty — none of it matters. The loop always wins.

The known cause: on SM121, CUTLASS_FP4 produces first-token logit corruption (see Part 1 and vLLM issue #37030). Fix is to force Marlin for all MXFP4 GEMMs.

Caveat: This applies to stock vLLM. The spark-vllm-docker build uses a different patched vLLM with --mxfp4-backend CUTLASS and VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 — their patches make CUTLASS work on SM121. If you're using that Docker image, the Marlin workaround is not required. If you're on stock vLLM or applying patches manually, it is.

The serve script already had this:

export VLLM_NVFP4_GEMM_BACKEND=marlin   # ← looks right

This environment variable does not exist in vLLM 0.17.1.

vLLM never reads it — the name isn't among its known environment variables — so the export is silently ignored. No warning, no error. Backend selection falls back to auto, and auto-selection on SM12x picks CUTLASS_FP4. The startup log shows:

[MXFP4] Auto-selected: CUTLASS_FP4 (vLLM native SM120 FP4 grouped GEMM for SM12x)

The correct env var is:

export VLLM_MXFP4_BACKEND=marlin   # ← correct

When this is right, the startup log shows:

[MXFP4] Using backend: marlin (VLLM_MXFP4_BACKEND=marlin)

After this fix: hi → normal response. Longer story request → 3,970 chars, no loop, finish_reason: stop.
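The failure mode generalizes: vLLM resolves settings through a registry of known environment variables, so an unrecognized export is simply never consulted. An illustrative sketch of that pattern (not vLLM's actual envs module):

```python
import os

# Registry of recognized variables, vLLM-style: name -> resolver.
KNOWN_ENV = {
    "VLLM_MXFP4_BACKEND": lambda: os.environ.get("VLLM_MXFP4_BACKEND", "auto"),
}

# Wrong name: exported into the environment, but never read by anything.
os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "marlin"

backend = KNOWN_ENV["VLLM_MXFP4_BACKEND"]()
# backend falls back to "auto" -> auto-selection -> CUTLASS_FP4 on SM12x.
# No warning is ever printed for the unrecognized export.
```

This is why the startup log, not the script, is the source of truth for which backend is actually in use.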


What Does the Working Serve Script Look Like?

#!/bin/bash
source /home/username/.python-vllm-eugr/bin/activate

# SM121 backend: force Marlin everywhere (CUTLASS_FP4 is broken on SM121)
# NOTE: if using eugr/spark-vllm-docker, their patches make CUTLASS work —
#       use --mxfp4-backend CUTLASS + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 instead
export VLLM_MXFP4_BACKEND=marlin               # ← NOT VLLM_NVFP4_GEMM_BACKEND
export VLLM_MARLIN_USE_ATOMIC_ADD=1            # SM121 Marlin atomic race fix
export FLASHINFER_DISABLE_VERSION_CHECK=1

# Offline tokenizer cache
export TIKTOKEN_RS_CACHE_DIR=/home/username/models/tiktoken_cache

# gpt-oss harmony system message fix
export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1

# Clear compile caches between runs
rm -rf ~/.cache/flashinfer/ ~/.cache/vllm/torch_compile_cache/ 2>/dev/null || true

exec vllm serve /home/username/models/gpt-oss-120b \
  --served-model-name gpt-oss-120b \
  --host 0.0.0.0 --port 8001 \
  --quantization mxfp4 \
  --mxfp4-layers moe,qkv,o,lm_head \
  --kv-cache-dtype fp8 \
  --max-model-len 131072 \
  --max-num-batched-tokens 8192 \
  --gpu-memory-utilization 0.90 \
  --attention-backend FLASHINFER \
  --moe-backend marlin

Two flags updated from the original script, based on the eugr/spark-vllm-docker recipe and an NVIDIA Developer Forum thread benchmarking gpt-oss-120B on GB10:

  • --attention-backend FLASHINFER replaces TRITON_ATTN. TRITON is a fallback path; FLASHINFER is the complete path and benchmarks faster.
  • --mxfp4-layers moe,qkv,o,lm_head explicitly quantizes projection layers. Without it, qkv/o/lm_head run in BF16 — leaving performance on the table.

Note: --reasoning-parser openai_gptoss is not included. Remove it unless you specifically need the reasoning channel separated.

eugr venv users: --attention-backend FLASHINFER and --mxfp4-layers are specific to spark-vllm-docker's patched build. If using the eugr venv directly, omit both — TRITON_ATTN is auto-selected and performance is equivalent (~63 tok/s measured).


Performance

  • Decode speed: ~59 tok/s
  • Backend: Marlin W4A16 (weights dequantized at inference)
  • KV cache (fp8, 0.90 utilization): ~580K tokens capacity
  • Max context: 131K tokens
  • CUDAGraph: ✅ captured

The arithmetic: GB10 has 273 GB/s memory bandwidth. gpt-oss-120B is a mixture-of-experts model, so each decoded token reads only the active experts plus shared weights — on the order of 4.5 GB at MXFP4, not the full 60 GB. At 273 GB/s, that puts the decode ceiling around 60 tok/s. The measured 59 tok/s matches the bandwidth ceiling almost exactly. This is what working looks like on GB10.
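The ceiling as back-of-envelope arithmetic — the 4.5 GB/token figure is an estimate of the active-weight bytes read per decoded token at MXFP4, not a measured number:

```python
bandwidth_gb_s = 273.0      # GB10 memory bandwidth
active_gb_per_token = 4.5   # assumed weight bytes read per decoded token (MoE active experts)

# Decode is bandwidth-bound: each token requires one pass over the active weights.
ceiling_tok_s = bandwidth_gb_s / active_gb_per_token
print(f"{ceiling_tok_s:.0f} tok/s")  # ~61 tok/s
```

Any measured rate near this number means the serve config is no longer the bottleneck; the memory bus is.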


What Does 59 tok/s at 131K Context Mean for a Local Agent?

59 tok/s at 131K context, on a machine that fits under a desk.

For openclaw — an always-on AI agent running locally — the arithmetic is different from cloud API use. There's no cost per token. No rate limit. No latency spike when a paid tier gets throttled. The model is available at any hour, responding to any prompt, without a usage meter running in the background.

120B parameters means the model handles complex reasoning, long system prompts, and multi-turn context without the quality ceiling that smaller local models hit. openclaw runs on this model full-time as its primary inference backend. The setup cost was six bugs and a few hours. The ongoing cost is electricity.

The DGX Spark makes this possible because of the unified memory architecture: 128 GB accessible by both CPU and GPU, enough to hold a 60 GB model with room left for a large KV cache — without the quantization tradeoffs that a GPU with less VRAM would force.


What Was Gained

The bug that cost the most time: Bug 6 (wrong env var) was invisible. No error, no warning. The startup log is the only signal — and only if you know what to look for. A script can look completely correct and silently do nothing. Always read the startup log before debugging model behavior.

Transferable diagnostics:

  • Repetition loops on gpt-oss → check startup log for Auto-selected: CUTLASS_FP4. If present, the Marlin env var isn't being read.
  • content: null in responses → reasoning parser is routing output to the wrong field. Check the parser flag first.
  • Tokenizer download failures on internal networks → tiktoken uses SHA1-hashed filenames for its cache. The workaround is pre-downloading and pointing TIKTOKEN_RS_CACHE_DIR at the directory.

The pattern that applies everywhere: --enforce-eager in a production serve script is a silent 2× regression. If you inherited a script and the speed is wrong, check for this flag before debugging anything else.
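The "read the script before debugging" rule can itself be automated. A minimal lint sketch for the two silent killers from this article:

```python
def lint_serve_script(text: str) -> list[str]:
    """Flag known-silent misconfigurations in a vLLM serve script."""
    problems = []
    if "--enforce-eager" in text:
        problems.append("--enforce-eager: disables CUDAGraph/torch.compile (~2x slower)")
    if "VLLM_NVFP4_GEMM_BACKEND" in text:
        problems.append("VLLM_NVFP4_GEMM_BACKEND: not read by vLLM 0.17.1; use VLLM_MXFP4_BACKEND")
    return problems
```

Both findings are invisible at runtime — no error, no warning — which is exactly why a static check pays off.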


Setup Checklist

If you're serving gpt-oss-120B on SM121:

  1. Apply the SM121 fixes from Part 1 first.
  2. Fix the mxfp4.py import path (eugr fork path → stock vLLM path).
  3. Remove --enforce-eager from your serve script. Check. Then check again.
  4. Pre-download the tiktoken vocab; set TIKTOKEN_RS_CACHE_DIR.
  5. Use VLLM_MXFP4_BACKEND=marlin — not VLLM_NVFP4_GEMM_BACKEND. The wrong variable silently does nothing.
  6. Confirm the startup log says Using backend: marlin, not Auto-selected: CUTLASS_FP4.
  7. If sending system prompts: patch serving_chat.py for PR #31607, and set VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1.

Also in this series: Qwen3.5-122B Runs. But at 14 tok/s — the GDN Kernel Gap

FAQ

How do I serve gpt-oss-120B on DGX Spark (SM121) with vLLM at near-theoretical speed?
Use VLLM_MXFP4_BACKEND=marlin (not VLLM_NVFP4_GEMM_BACKEND — the latter silently does nothing). Confirm startup log shows 'Using backend: marlin', not 'Auto-selected: CUTLASS_FP4'. Remove --enforce-eager. With correct config, gpt-oss-120B hits ~59 tok/s on GB10 (273 GB/s bandwidth, 60 GB weights).
What is the correct env var to force Marlin backend in vLLM for MXFP4 models?
VLLM_MXFP4_BACKEND=marlin — not VLLM_NVFP4_GEMM_BACKEND. In vLLM 0.17.1, VLLM_NVFP4_GEMM_BACKEND does not exist and is silently ignored. The startup log will show 'Auto-selected: CUTLASS_FP4' if the wrong variable is used, and repetition loops follow.
Why does tiktoken fail to load the gpt-oss tokenizer on an air-gapped machine?
tiktoken downloads o200k_base.tiktoken from OpenAI's CDN at startup. It caches using the SHA1 hash as the filename (fb374d419588a4632f3f557e76b4b70aebbca790 — no extension). Pre-download on an internet-connected machine and set TIKTOKEN_RS_CACHE_DIR to point to the directory containing the hash-named file.
Why does gpt-oss-120B return content: null even though the model is generating tokens?
--reasoning-parser openai_gptoss routes all output to the reasoning field and sets content to null. Standard OpenAI-compatible clients read content, so they get nothing. Remove --reasoning-parser openai_gptoss unless you specifically need the reasoning channel separated.