Rewriting model inference with CUDA kernels: the bottleneck was not just GEMM [P]

By AI Maestro, May 18, 2026
I’ve been working on a CUDA-first inference runtime for small-batch / realtime ML workloads.

The core idea is simple: instead of treating PyTorch / TensorRT / generic graph runtimes as the main execution path, I rewrite the model inference path directly with C++/CUDA kernels.

This started from robotics / VLA workloads, but the problem is more general.

In small-batch inference, the bottleneck is often not just a single slow GEMM. A lot of latency comes from the runtime glue around the math:

  • fragmented small kernels
  • norm / residual / activation boundaries
  • quantize / dequantize overhead
  • layout transitions
  • Python / runtime scheduling
  • graph compiler fusion failures
  • precision conversion around FP8 / FP4 regions

For cloud LLM serving, batching can hide a lot of this.

For robotics, VLA, world models, and other realtime workloads, batch size is usually 1. There is nowhere to hide. Every launch, sync, and format boundary shows up directly in latency.
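A back-of-envelope cost model shows why batch size 1 leaves nowhere to hide. The sketch assumes every kernel pays a fixed per-launch cost (launch, sync, format boundary) on top of its useful compute; the struct and function names, and any numbers plugged into them, are hypothetical placeholders rather than measurements. Batching grows the compute term while the launch term stays roughly fixed, which is why large-batch serving can amortize it.

```cpp
#include <cassert>

// Hypothetical cost model: every kernel launch pays a fixed overhead on top
// of its share of the useful compute time.
struct StepCost {
    int    num_kernels;         // kernels launched for one inference step
    double launch_overhead_us;  // fixed cost per launch (assumed, not measured)
    double compute_us;          // total useful math time across all kernels
};

// End-to-end latency of one step: launch overhead plus compute.
double step_latency_us(const StepCost& c) {
    assert(c.num_kernels >= 0 && c.launch_overhead_us >= 0.0 && c.compute_us >= 0.0);
    return c.num_kernels * c.launch_overhead_us + c.compute_us;
}

// Fraction of the step spent on launch overhead rather than math.
double overhead_fraction(const StepCost& c) {
    const double total = step_latency_us(c);
    return total > 0.0 ? (c.num_kernels * c.launch_overhead_us) / total : 0.0;
}
```

With illustrative values of 1000 launches at 5 µs each plus 5 ms of math, half of a 10 ms step is overhead; fusing down to 200 launches brings the same math to 6 ms without touching a single GEMM.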

Some current results from my implementation:

| Model / workload | Hardware | FlashRT latency / throughput |
| --- | --- | --- |
| Pi0.5 | Jetson Thor | ~44 ms |
| Pi0 | Jetson Thor | ~46 ms |
| GROOT N1.6 | Jetson Thor | ~41–45 ms |
| Pi0.5 | RTX 5090 | ~17.6 ms |
| GROOT N1.6 | RTX 5090 | ~12.5–13.1 ms |
| Pi0-FAST | RTX 5090 | ~2.39 ms/token |
| Qwen3.6 27B | RTX 5090 | ~129 tok/s (NVFP4) |
| Motus / Wan-style world model | RTX 5090 | ~1.3 s baseline → targeting ~100 ms E2E |
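The table mixes per-step latency, per-token latency, and throughput. Assuming a single decode stream at batch size 1, where throughput is simply the inverse of per-token latency, these small helpers (names are illustrative) put the rows on a common footing:

```cpp
#include <cassert>

// Single-stream conversion: per-token latency (ms) -> decode throughput (tok/s).
double ms_per_token_to_tok_per_s(double ms_per_token) {
    assert(ms_per_token > 0.0);
    return 1000.0 / ms_per_token;
}

// Single-stream conversion: decode throughput (tok/s) -> per-token latency (ms).
double tok_per_s_to_ms_per_token(double tok_per_s) {
    assert(tok_per_s > 0.0);
    return 1000.0 / tok_per_s;
}

// Speedup factor needed to go from a baseline latency to a target latency.
double required_speedup(double baseline_ms, double target_ms) {
    assert(baseline_ms > 0.0 && target_ms > 0.0);
    return baseline_ms / target_ms;
}
```

By this arithmetic, Pi0-FAST's ~2.39 ms/token corresponds to roughly 418 tok/s, Qwen3.6 27B's ~129 tok/s to roughly 7.75 ms/token, and the Motus target of ~100 ms from a ~1.3 s baseline requires about a 13× end-to-end reduction.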

The Motus / world-model case is especially interesting.

The baseline path takes about 1.3 s end-to-end. The target is ~100 ms E2E, but the hard part is not simply “use a faster GEMM”. The bottlenecks are the VAE, joint attention, launch fragmentation, and a large amount of glue around the actual math.
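Why "use a faster GEMM" cannot be the whole answer follows from Amdahl's law: if GEMMs account for a fraction `p` of end-to-end time and become `s` times faster, the overall speedup is 1 / ((1 − p) + p / s), capped at 1 / (1 − p) no matter how fast the GEMMs get. A minimal sketch (function names are illustrative):

```cpp
#include <cassert>
#include <limits>

// Amdahl's law: overall speedup when only a fraction `p` of end-to-end time
// (for example, the GEMMs) is accelerated by a factor `s`.
double amdahl_speedup(double p, double s) {
    assert(p >= 0.0 && p <= 1.0 && s > 0.0);
    return 1.0 / ((1.0 - p) + p / s);
}

// Upper bound as s -> infinity: the untouched (1 - p) part sets the ceiling.
double amdahl_ceiling(double p) {
    assert(p >= 0.0 && p <= 1.0);
    return p < 1.0 ? 1.0 / (1.0 - p) : std::numeric_limits<double>::infinity();
}
```

As a purely hypothetical illustration, if GEMMs were 60% of a 1.3 s baseline, even infinitely fast GEMMs would give at most 2.5×, about 520 ms. Reaching a 13× reduction requires everything else (VAE, attention, launches, glue) to fit in roughly 1/13 of the baseline time.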

One lesson from this work: lower precision is not automatically a win.

FP8 has been consistently useful. FP4 / NVFP4 is more mixed. It can help memory footprint and some large GEMM regions, but if the FP4 region is small, discontinuous, or surrounded by conversion / scaling overhead, the end-to-end speedup can be tiny.
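The "conversion / scaling overhead" can be seen in a simplified block-scaled 4-bit fake-quantization sketch in the spirit of NVFP4, which uses FP4 E2M1 values in 16-element blocks with a shared scale. Assumptions, loudly: the block scale is kept as a plain float here, whereas NVFP4 stores it as FP8 E4M3 together with a per-tensor FP32 scale; ties round toward the smaller magnitude rather than following hardware rounding; and none of this is FlashRT's implementation. Every FP8/FP4 boundary pays some version of this amax reduction, scale, round, and rescale.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Simplified block-scaled 4-bit quantization in the spirit of NVFP4.
constexpr std::size_t kBlock = 16;  // elements sharing one scale
// Non-negative magnitudes representable in FP4 E2M1.
constexpr std::array<float, 8> kE2M1 = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Round |v| to the nearest E2M1 magnitude and keep the sign.
// Ties go to the smaller magnitude here (hardware rounding may differ).
float round_to_e2m1(float v) {
    assert(std::isfinite(v));
    const float a = std::fabs(v);
    float best = kE2M1[0];
    for (float m : kE2M1)
        if (std::fabs(a - m) < std::fabs(a - best)) best = m;
    return std::copysign(best, v);
}

// Quantize then dequantize each block: the round trip a mixed-precision path
// pays at a precision boundary.
std::vector<float> fake_quant_fp4(const std::vector<float>& x) {
    std::vector<float> out(x.size());
    for (std::size_t start = 0; start < x.size(); start += kBlock) {
        const std::size_t end = std::min(start + kBlock, x.size());
        float amax = 0.0f;  // per-block absolute maximum (a reduction)
        for (std::size_t i = start; i < end; ++i) amax = std::max(amax, std::fabs(x[i]));
        // Map the block's amax onto the largest E2M1 magnitude (6).
        const float scale = amax > 0.0f ? amax / 6.0f : 1.0f;
        for (std::size_t i = start; i < end; ++i)
            out[i] = round_to_e2m1(x[i] / scale) * scale;
    }
    return out;
}
```

If the FP4 region is short, this extra reduction and rescaling (plus any layout change) is paid on a small amount of math, which is one way the end-to-end gain gets eaten.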

For example, in some VLA / world-model paths, switching from FP8 to FP4 gives only a few percent of latency improvement unless the FP4 region is large and deeply fused.
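A rough way to reason about this is to extend Amdahl's law with an overhead term. Normalize the FP8 path to 1.0; let `p` be the fraction of time in the FP4-eligible region, `s` the speedup of that region in FP4, and `c` the added quantize / scale / layout-conversion time. All three are hypothetical inputs that would have to be measured per model and kernel; the sketch only shows the shape of the trade-off.

```cpp
#include <cassert>

// Hypothetical estimate of end-to-end latency after moving part of an FP8
// path to FP4, normalized so the FP8 baseline takes 1.0.
//   p: fraction of baseline time inside the FP4-eligible region
//   s: speedup of that region in FP4 vs FP8 (assumed, kernel-dependent)
//   c: added conversion / scaling time, as a fraction of the baseline
double fp4_relative_latency(double p, double s, double c) {
    assert(p >= 0.0 && p <= 1.0 && s > 0.0 && c >= 0.0);
    return (1.0 - p) + p / s + c;
}

// Percent latency reduction versus the FP8 baseline (negative = regression).
double fp4_gain_percent(double p, double s, double c) {
    return 100.0 * (1.0 - fp4_relative_latency(p, s, c));
}
```

With illustrative inputs, a 10% region sped up 1.5× with 1% conversion cost yields only about 2.3%, while a 70% region with 2% conversion cost yields about 21%; a 5% region with 3% conversion cost comes out slower than FP8.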

This changed how I think about inference optimization.

For large-batch cloud serving, generic runtimes and batching are often enough.

For realtime small-batch inference, the runtime overhead becomes the workload.

Curious if others have seen similar behavior with torch.compile, TensorRT, XLA, Triton, or custom CUDA kernels.

At what point do you stop trying to make a generic compiler optimize the model, and just rewrite the inference path directly?

Implementation: https://github.com/LiangSu8899/FlashRT

submitted by /u/Diligent-End-2711
