# Tutorial: Profiling Prefill vs Decode with roctx Markers
rocm-trace-lite (RTL) includes a built-in roctx shim — no libroctx64 install needed. Insert markers in your code to analyze GPU kernel hotspots per inference phase.
## Quick Start
### 1. Install RTL

```sh
pip install rocm_trace_lite-0.3.2-py3-none-linux_x86_64.whl
```
### 2. Add roctx markers to your code
```python
import ctypes
import torch

# Load roctx from the global symbol table (RTL injects it via LD_PRELOAD)
lib = ctypes.CDLL(None)
roctx_push = lib.roctxRangePushA
roctx_push.argtypes = [ctypes.c_char_p]
roctx_push.restype = ctypes.c_int
roctx_pop = lib.roctxRangePop
roctx_pop.restype = ctypes.c_int

# Mark prefill phase
roctx_push(b"prefill")
# ... prefill code (large-batch GEMM, attention, etc.) ...
torch.cuda.synchronize()
roctx_pop()

# Mark decode phase
roctx_push(b"decode")
# ... decode loop (skinny GEMM, token-by-token generation) ...
torch.cuda.synchronize()
roctx_pop()
```
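The raw push/pop calls above are easy to leave unbalanced when an exception fires mid-region. A small context manager keeps every push paired with a pop and degrades to a no-op when RTL's shim is not preloaded, so the same script runs with or without tracing. This is a sketch of our own (`roctx_range` is not part of RTL):

```python
import ctypes
from contextlib import contextmanager


def _load_roctx():
    # Resolve roctx symbols from the global symbol table; returns (None, None)
    # when RTL's shim is not preloaded, so markers become no-ops.
    try:
        lib = ctypes.CDLL(None)
        push = lib.roctxRangePushA
        push.argtypes = [ctypes.c_char_p]
        push.restype = ctypes.c_int
        pop = lib.roctxRangePop
        pop.restype = ctypes.c_int
        return push, pop
    except (OSError, AttributeError):
        return None, None


_PUSH, _POP = _load_roctx()


@contextmanager
def roctx_range(name: str):
    # Pops in a finally block so the range closes even if the body raises.
    if _PUSH is not None:
        _PUSH(name.encode())
    try:
        yield
    finally:
        if _POP is not None:
            _POP()
```

With this helper, `with roctx_range("prefill"): ...` replaces each push/pop pair; keep `torch.cuda.synchronize()` as the last statement inside the `with` block so the range covers all launched kernels.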
### 3. Run with RTL

```sh
rtl trace -o trace.db python3 my_model.py
```
`rtl trace` automatically sets `LD_PRELOAD` so roctx symbols are globally visible to your code.
### 4. Analyze by region

```sh
rtl summary trace.db
```
Or query the SQLite trace directly:
```sql
-- List all roctx markers with duration
SELECT s.string AS marker, (o.end - o.start) / 1e6 AS duration_ms
FROM rocpd_op o JOIN rocpd_string s ON o.description_id = s.id
WHERE o.gpuId = -1 ORDER BY o.start;

-- Top kernels within a specific region
SELECT s.string AS kernel, COUNT(*) AS calls,
       ROUND(AVG(o.end - o.start) / 1e3, 1) AS avg_us
FROM rocpd_op o JOIN rocpd_string s ON o.description_id = s.id
WHERE o.gpuId >= 0 AND o.start BETWEEN <region_start> AND <region_end>
GROUP BY s.string ORDER BY SUM(o.end - o.start) DESC LIMIT 10;
```
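The per-region query is easy to drive from Python with the stdlib `sqlite3` module, e.g. to compare several regions in one script. A minimal sketch assuming the `rocpd_op`/`rocpd_string` schema shown above (the `top_kernels` helper name is ours, not part of RTL):

```python
import sqlite3


def top_kernels(db_path: str, region_start: int, region_end: int, limit: int = 10):
    """Return (kernel_name, calls, avg_us) rows for GPU kernels launched
    inside [region_start, region_end], heaviest total GPU time first."""
    con = sqlite3.connect(db_path)
    try:
        return con.execute(
            """
            SELECT s.string AS kernel, COUNT(*) AS calls,
                   ROUND(AVG(o.end - o.start) / 1e3, 1) AS avg_us
            FROM rocpd_op o JOIN rocpd_string s ON o.description_id = s.id
            WHERE o.gpuId >= 0 AND o.start BETWEEN ? AND ?
            GROUP BY s.string ORDER BY SUM(o.end - o.start) DESC LIMIT ?
            """,
            (region_start, region_end, limit),
        ).fetchall()
    finally:
        con.close()
```

Feed it the `start`/`end` timestamps of a roctx marker row (from the first query above) to get that phase's kernel hotspots as Python tuples.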
## Example Output
```text
roctx markers:
  prefill: 24.3ms
  decode:  54.9ms

=== prefill (74 ops, 0.7ms GPU time) ===
Cijk_Ailk_Bljk (MT64x64x256)    calls=12    avg=16.1us  27.7%  <- Attention GEMM (large batch)
Cijk_Ailk_Bljk (MT64x128x128)   calls=12    avg=13.2us  22.7%  <- FFN up-project
Cijk_Ailk_Bljk (MT64x16x512)    calls=13    avg=12.0us  22.5%  <- FFN down-project
vectorized_elementwise (gelu)   calls=12    avg=5.7us    9.9%  <- Activation
ScaleAlphaVec_PostGSU8          calls=12    avg=4.8us    8.2%  <- Post-GEMM scale

=== decode (6336 ops, 36.9ms GPU time) ===
Cijk_Ailk_Bljk (MT256x16x384)   calls=1600  avg=8.0us   34.6%  <- Skinny GEMM (batch=1)
ScaleAlphaVec_PostGSU8_VW1      calls=1600  avg=4.5us   19.3%  <- Post-GEMM scale
Cijk_Ailk_Bljk (MT64x16x128)    calls=768   avg=9.1us   18.9%  <- FFN GEMM
ScaleAlphaVec_PostGSU8_VW4      calls=768   avg=4.8us   10.1%
vectorized_elementwise (gelu)   calls=768   avg=4.0us    8.3%
```
## Key Insights

| Phase | Characteristics | Optimization direction |
|---|---|---|
| Prefill | Large-batch GEMMs: few calls, long duration each | Compute-bound: optimize GEMM tile sizes |
| Decode | Skinny GEMMs (batch=1): many calls, short duration | Memory-bound: reduce kernel launch overhead |
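A back-of-the-envelope arithmetic-intensity calculation explains the split in the table. Using illustrative shapes (hidden size 4096, fp16; these numbers are our assumption, not taken from the trace above), a prefill GEMM performs hundreds of FLOPs per byte moved while a decode GEMM performs roughly one:

```python
def arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for an (m x k) @ (k x n) GEMM, counting A, B, and C
    traffic once each; bytes_per_elem=2 models fp16."""
    flops = 2 * m * k * n                                   # one multiply + one add per MAC
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # A + B + C
    return flops / bytes_moved


# Prefill-like: 512-token prompt; decode-like: one token per step
prefill_ai = arithmetic_intensity(512, 4096, 4096)
decode_ai = arithmetic_intensity(1, 4096, 4096)
print(f"prefill: {prefill_ai:.0f} FLOPs/byte, decode: {decode_ai:.1f} FLOPs/byte")
# -> prefill: 410 FLOPs/byte, decode: 1.0 FLOPs/byte
```

Since current GPUs sustain well over 100 FLOPs per byte of HBM bandwidth, a decode GEMM at about 1 FLOP/byte is firmly memory-bound, which is why cutting launch overhead and fusing kernels pays off in that phase.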
## Visualize in Perfetto

```sh
rtl convert trace.db -o trace.json
# Open trace.json at https://ui.perfetto.dev
```
roctx markers appear as spans on the timeline, showing prefill and decode phases with all GPU kernels nested inside.
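If you prefer scripting over the UI, the converted file can be inspected directly. A sketch assuming `rtl convert` emits standard Chrome Trace Event JSON, where roctx spans are complete events (`"ph": "X"`) with microsecond durations; the `roctx_spans` helper name is ours, and you should verify the format your RTL version actually emits:

```python
import json


def roctx_spans(trace_path: str, markers=("prefill", "decode")):
    """Return (name, duration_ms) for complete events whose name matches a
    roctx marker, assuming Chrome Trace Event JSON ("ts"/"dur" in microseconds)."""
    with open(trace_path) as f:
        data = json.load(f)
    # The format allows either a bare event array or {"traceEvents": [...]}
    events = data["traceEvents"] if isinstance(data, dict) else data
    spans = []
    for ev in events:
        if ev.get("ph") == "X" and ev.get("name") in markers:
            spans.append((ev["name"], ev.get("dur", 0) / 1e3))
    return spans
```

This makes it easy to diff prefill/decode durations across runs in CI without opening the Perfetto UI.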
## Notes
- Call `torch.cuda.synchronize()` before `roctx_pop()` to ensure all GPU kernels within the region have completed
- `rtl trace` automatically sets `LD_PRELOAD` for roctx support (no manual setup needed)
- No libroctx64 install needed: RTL's built-in shim provides all roctx symbols
- Markers are stored in the same SQLite trace DB alongside the GPU kernel records