• Optimizing retrieval-augmented LLM inference isn’t just about bigger GPUs; it’s about smarter orchestration.

    🚨 The Problem

    A few months ago, our retrieval-augmented generation (RAG) system was hitting its limits.

    We were serving clinical queries from healthcare providers, but:

    • p95 latency: 850 ms
    • GPU utilization: ~35 %
    • Query throughput: barely 300 QPS under load

    The pipeline looked like this:

    User Query
      ↓
    1️⃣ Embed query → FAISS search
    2️⃣ Lexical retrieval → Elasticsearch (BM25)
    3️⃣ Merge + re-rank
    4️⃣ Construct context → LLM inference
    5️⃣ De-identify + postprocess
      ↓
    Response
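
    In code terms, the flow was roughly the following (a simplified sketch with hypothetical helper names, not our production service):

    # embed, faiss_search, bm25_search, merge, rerank, build_context,
    # llm_generate, and deidentify are illustrative placeholders.
    def answer(query: str) -> str:
        q_vec = embed(query)                          # 1. embed the query
        dense_hits = faiss_search(q_vec, k=20)        #    dense retrieval (FAISS)
        lexical_hits = bm25_search(query, size=20)    # 2. lexical retrieval (BM25)
        candidates = merge(dense_hits, lexical_hits)  # 3. merge + re-rank
        ranked = rerank(query, candidates)
        context = build_context(ranked[:5])           # 4. construct context
        draft = llm_generate(context, query)          #    LLM inference
        return deidentify(draft)                      # 5. de-identify + postprocess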
    

    Despite using a strong model, our serving stack wasn’t scaling.

    We needed to bring latency below 500 ms while keeping costs in check.


    🧠 Profiling the Bottleneck

    We profiled the full inference path using:

    • Nsight Systems and PyTorch Profiler for GPU traces
    • Prometheus + Grafana for latency and token-level metrics
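
    For the GPU side, a trimmed-down version of the PyTorch Profiler pass we used to spot idle gaps looks like this (`model` and `inputs` stand in for the actual serving objects):

    import torch
    from torch.profiler import profile, ProfilerActivity

    # Capture CPU + CUDA activity for one generation step (simplified sketch).
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        with torch.no_grad():
            model.generate(**inputs, max_new_tokens=128)

    # Per-op summary sorted by GPU time, plus a Chrome trace where the
    # inter-kernel idle gaps show up as empty space on the CUDA stream rows.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
    prof.export_chrome_trace("inference_trace.json")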

    Key findings:

    • 40 % GPU idle time between kernels (batching inefficiency)
    • Retrieval (FAISS + Elastic) serialized — ~120 ms wasted
    • KV cache fragmentation consuming unnecessary VRAM
    • Static batching delaying low-traffic requests

    ⚙️ The Redesign

    We re-architected inference around three major improvements:

    (1) Tensor-parallel serving via vLLM, (2) Continuous batching, and (3) Semantic caching.


    🚀 1. Multi-GPU Tensor Parallelism

    We deployed our model using vLLM on AWS p5 instances (8× H100 GPUs):

    python -m vllm.entrypoints.api_server \
      --model meta-llama/Llama-2-70b-chat-hf \
      --tensor-parallel-size 4 \
      --gpu-memory-utilization 0.9 \
      --port 8000
    
    • Model weights sharded across 4 GPUs
    • Each GPU stores KV-cache only for its own attention heads
    • Communication handled by NCCL over NVLink
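
    Serving is plain HTTP from there; a minimal client call (assuming the default /generate route of this demo server rather than the OpenAI-compatible endpoint) looks roughly like:

    import requests

    # Minimal client call against the vLLM demo API server started above.
    # The /generate route and JSON fields are the demo server's defaults;
    # adjust if you expose the OpenAI-compatible endpoint instead.
    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "prompt": "Context: <retrieved passages>\n\nQuestion: <clinical query>",
            "max_tokens": 256,
            "temperature": 0.2,
        },
        timeout=30,
    )
    print(resp.json()["text"])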

    Result:

    • Linear scaling of throughput
    • GPU utilization ↑ to 85 %
    • Per-token latency ↓ 35 %

    ⚡ 2. Continuous Batching

    Static batching was replaced with continuous batching, where new requests join an ongoing batch mid-generation.

    This kept GPUs fully occupied even with variable traffic.

    Before: waiting to fill batches → GPU idle gaps

    After: dynamic merge/split at token granularity
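
    Conceptually, the scheduling loop behaves like the toy simulation below (illustration only, not vLLM's actual scheduler):

    import random
    from collections import deque

    # Toy model of continuous batching: requests join the active batch at token
    # boundaries instead of waiting for a whole batch to drain.
    MAX_BATCH = 4

    class Request:
        def __init__(self, rid, target_len):
            self.rid, self.target_len, self.tokens = rid, target_len, []
        def done(self):
            return len(self.tokens) >= self.target_len

    queue = deque(Request(i, random.randint(3, 8)) for i in range(10))
    active, step = [], 0

    while active or queue:
        # admit new requests at every decode step, up to the batch budget
        while queue and len(active) < MAX_BATCH:
            active.append(queue.popleft())

        # one decode step for every in-flight request (stand-in for a real forward pass)
        for req in active:
            req.tokens.append(f"tok{step}")

        # finished requests leave immediately, freeing their slot for queued work
        active = [r for r in active if not r.done()]
        step += 1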

    Result:

    • Throughput ↑ 2.8×
    • Latency variance (p95 – p99) ↓ 50 %

    🧠 3. Semantic Caching

    We introduced a Redis-backed semantic cache keyed by query embeddings.

    cache_key = f"embed:{hash(embedding)}"
    if redis.exists(cache_key):
        results = redis.get(cache_key)
    else:
        results = run_retrieval(embedding)
        redis.setex(cache_key, 3600, results)
    

    When a new query arrives, we check for high cosine similarity (> 0.9) with cached embeddings.

    If found, we reuse the retrieval + context immediately — skipping FAISS + Elastic calls.
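
    The exact-hash key above only catches byte-identical embeddings; the semantic part is the similarity check, shown here in simplified form (with cached embeddings held in a plain dict rather than a vector index):

    import numpy as np

    SIM_THRESHOLD = 0.9  # matches the cosine-similarity cutoff above

    def find_semantic_hit(query_emb, cached_embs):
        """Return the cache key of the closest cached embedding, or None.

        `cached_embs` maps cache keys to stored query embeddings; in production
        this scan would run against a small vector index instead of a Python loop.
        """
        q = query_emb / np.linalg.norm(query_emb)
        best_key, best_sim = None, -1.0
        for key, emb in cached_embs.items():
            sim = float(np.dot(q, emb / np.linalg.norm(emb)))
            if sim > best_sim:
                best_key, best_sim = key, sim
        return best_key if best_sim >= SIM_THRESHOLD else None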

    Result:

    • Cache hit rate 71 %
    • Vector DB load ↓ 60 %
    • End-to-end latency ↓ ~90 ms for repeated/semantically similar queries

    🔍 Results

    Metric                      | Before | After  | Δ
    GPU Utilization             | 35 %   | 85 %   | +50 pts
    p95 Latency                 | 850 ms | 490 ms | ↓ 42 %
    Cost per Query              |   –    |   –    | ↓ 37 %
    NDCG@10 (Retrieval Quality) | 0.73   | 0.89   | +22 %
    Cache Hit Rate              | 0 %    | 71 %   |   –

    🧩 Lessons Learned

    • Batching is a hidden bottleneck — static batching kills latency under variable load.
    • Tensor parallelism isn’t free — one slow GPU throttles the group; NVLink topology matters.
    • Semantic caching > lexical caching — paraphrased queries benefit enormously.
    • Monitoring is critical — token throughput and cache hit dashboards caught regressions early.

    🧠 Takeaway

    Optimizing LLM inference isn’t about throwing bigger GPUs at the problem — it’s about making each token, cache lookup, and GPU cycle count.

    By re-architecting our serving stack with vLLM, tensor parallelism, and semantic caching, we now sustain 1 K QPS at < 500 ms p95, with better retrieval grounding and 35 % lower cost.


    ✍️ Final Thought

    We plan to open-source part of our semantic caching implementation soon.

    If you’re scaling RAG inference or debugging GPU underutilization, I’d love to compare notes — reach out anytime.