WBC Analyzer: Robust OOD Generalization in Peripheral Blood Smears

End-to-end white blood cell classification using DenseNet121 with custom WBCAttention and MedSwish layers, achieving 98.53% in-distribution and 89.05% OOD accuracy via retraining-free inference adaptation, featuring a 5-step Medical Enhanced Filter (MEF) pipeline and an autonomous LLM shortcut-detection agent.

Overview

White blood cell (WBC) classification in peripheral blood smears is a critical task for diagnosing leukemia, infections, and immune disorders. However, deploying deep learning models in real-world clinical environments is challenging due to domain shift: differences in staining kits, exposure levels, and microscope camera sensors across laboratories cause model accuracy to plummet.

This project introduces WBC Analyzer, a deep learning system that generalizes well to slides acquired on unseen laboratory hardware, without retraining. By combining a custom attention-guided backbone, a test-time normalization pipeline, and an autonomous multimodal LLM auditing agent, WBC Analyzer aims to keep its clinical predictions reliable and interpretable.

This research has been published as an academic preprint (DOI: 10.13140/RG.2.2.34201.79208).


Results at a Glance

| Evaluation Setting | Method / Adaptation | Accuracy | Change in Accuracy |
|---|---|---|---|
| In-Distribution (ID) | DenseNet121 + WBCAttention | 98.53% | Baseline |
| Out-of-Distribution (OOD) | Naive Baseline (Unadapted) | 56.96% | -41.57 pp vs. ID |
| Out-of-Distribution (OOD) | WBC Analyzer (Adapted) | 89.05% | +32.09 pp vs. naive OOD |

Note: OOD evaluation was performed on unseen hospital datasets using different slide preparation hardware.


Phase 1: Model Architecture — WBCAttention & MedSwish

To prevent the neural network from memorizing background blood plasma noise, I designed a specialized architecture that forces the network to focus on cellular morphology.

1. WBCAttention Block

The model utilizes a custom spatial and channel attention mechanism inserted between the dense blocks of a DenseNet121 backbone.

  • Channel Attention: Compresses spatial dimensions using both global average pooling and global max pooling, then feeds the statistics through a shared MLP to weight feature maps based on semantic importance (e.g., emphasizing nuclear structure features over background).
  • Spatial Attention: Evaluates where features are active across the spatial grid, applying a convolution over concatenated max and average pooling maps to localize cell nuclei and granules.
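The two gates described above follow the same pattern as CBAM (Convolutional Block Attention Module). The post does not show their code, so the following is a minimal PyTorch sketch that assumes CBAM-style defaults (reduction ratio 16, a 7×7 spatial kernel, sigmoid outputs). Each gate returns weights rather than gated features, so they can be multiplied into the feature maps exactly as the WBCAttention class does.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel gate: per-channel weights in [0, 1], shape (N, C, 1, 1)."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = max(in_channels // reduction, 1)
        # Shared MLP (written as 1x1 convs) applied to both pooled descriptors
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(x.mean(dim=(2, 3), keepdim=True))  # global average pooling
        mx = self.mlp(x.amax(dim=(2, 3), keepdim=True))   # global max pooling
        return torch.sigmoid(avg + mx)


class SpatialAttention(nn.Module):
    """Spatial gate: per-pixel weights in [0, 1], shape (N, 1, H, W)."""

    def __init__(self, kernel_size=7):
        super().__init__()
        # One convolution over the stacked [mean, max] channel-pooled maps
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = x.mean(dim=1, keepdim=True)
        mx = x.amax(dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))


# Usage: apply the gates in the same order as WBCAttention (channel, then spatial)
x = torch.randn(2, 64, 28, 28)
channel_gate, spatial_gate = ChannelAttention(64), SpatialAttention()
refined = channel_gate(x) * x
refined = spatial_gate(refined) * refined
print(refined.shape)  # torch.Size([2, 64, 28, 28])
```

Because both gates only rescale features, the block preserves the tensor shape and can be dropped between DenseNet121 stages without changing the surrounding layers.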
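import torch.nn as nn

# ChannelAttention and SpatialAttention are the two gates described above;
# each returns attention weights that are multiplied into the feature maps.
class WBCAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.channel_gate = ChannelAttention(in_channels, reduction)
        self.spatial_gate = SpatialAttention()

    def forward(self, x):
        # Reweight channels first ("what"), then spatial locations ("where")
        out = self.channel_gate(x) * x
        out = self.spatial_gate(out) * out
        return out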

2. MedSwish Activation Function

Standard ReLU activations suffer from “dying neurons” in low-contrast medical images, while standard Swish can lead to gradient instability during early epochs. MedSwish adds two parameters to the sigmoid gate to preserve gradient flow in low-contrast, stain-shifted regions:

\[\text{MedSwish}(x) = x \cdot \text{sigmoid}(\beta \cdot x + \alpha)\]

where $\beta$ controls the gating sharpness and $\alpha$ acts as a soft-bias threshold tailored to medical image distributions, stabilizing backpropagation through deep feature maps.
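A minimal PyTorch sketch of the formula follows. Whether $\alpha$ and $\beta$ are learned or fixed is not stated in the post; here they are assumed to be learnable scalars initialized to $\beta = 1$, $\alpha = 0$, in which case MedSwish starts out identical to Swish (SiLU). The second check illustrates the role of $\alpha$: differentiating $x \cdot \text{sigmoid}(\beta x + \alpha)$ at $x = 0$ gives $\text{sigmoid}(\alpha)$, so a positive $\alpha$ lets more gradient through near zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MedSwish(nn.Module):
    """MedSwish(x) = x * sigmoid(beta * x + alpha)."""

    def __init__(self, beta=1.0, alpha=0.0):
        super().__init__()
        # Assumption: both parameters are learnable scalars (not confirmed by the post)
        self.beta = nn.Parameter(torch.tensor(float(beta)))
        self.alpha = nn.Parameter(torch.tensor(float(alpha)))

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x + self.alpha)


# Sanity check: with beta=1 and alpha=0 the formula reduces to Swish (SiLU)
x = torch.linspace(-4.0, 4.0, steps=9)
print(torch.allclose(MedSwish()(x), F.silu(x)))  # True

# At x = 0 the derivative equals sigmoid(alpha); here that is sigmoid(1.0)
z = torch.zeros(1, requires_grad=True)
MedSwish(alpha=1.0)(z).sum().backward()
```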


Phase 2: Domain Adaptation — Medical Enhanced Filter (MEF)

To normalize images before inference, I built a 5-step preprocessing pipeline that translates staining styles and exposure levels into a standardized distribution.

The MEF Normalization Pipeline

  1. Grayscale Luminance Extraction: Converts input to single-channel luminance to detect overall illumination levels.
  2. Adaptive Tissue Thresholding: Separates background plasma (which contains slide dirt and light glare) from white and red blood cells.
  3. LAB Color Space Alignment: Maps the RGB input to LAB color space. The mean and standard deviation of $a$ (magenta-green) and $b$ (yellow-blue) channels are matched to a clinical reference template.
  4. Contrast-Limited Adaptive Histogram Equalization (CLAHE): Enhances internal nuclear structure (chromatin patterns) without amplifying background noise.
  5. Exposure Normalization: Scales the L channel to match reference brightness levels.
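Step 3 is a per-channel moment match. For a channel $c \in \{a, b\}$ with image statistics $\mu_c, \sigma_c$ and reference-template statistics $\mu_c^{\text{ref}}, \sigma_c^{\text{ref}}$, each pixel value is mapped as follows, with a small $\epsilon$ guarding against division by zero (the code uses $\epsilon = 10^{-5}$):

\[\hat{c} = (c - \mu_c)\,\frac{\sigma_c^{\text{ref}}}{\sigma_c + \epsilon} + \mu_c^{\text{ref}}\]

Subtracting $\mu_c$ centers the channel, the ratio rescales its spread to $\sigma_c^{\text{ref}}$ (up to $\epsilon$), and adding $\mu_c^{\text{ref}}$ moves its mean onto the template, before the result is clipped back to the 8-bit range $[0, 255]$.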
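import cv2
import numpy as np


def medical_enhanced_filter(image, reference_lab_stats):
    # image: uint8 BGR image (OpenCV channel order)
    # reference_lab_stats: [{'mean': ..., 'std': ...}, {'mean': ..., 'std': ...}]
    # holding the a and b statistics of the clinical reference template
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)

    # Step 3: align the a/b color distributions to the clinical reference.
    # Collect the results explicitly; reassigning the loop variable alone
    # would leave a and b unchanged.
    aligned = []
    for channel, ref_stats in zip((a, b), reference_lab_stats):
        channel = channel.astype(np.float32)
        mean, std = channel.mean(), channel.std()
        normalized = (channel - mean) * (ref_stats['std'] / (std + 1e-5)) + ref_stats['mean']
        aligned.append(np.clip(normalized, 0, 255).astype(np.uint8))
    a, b = aligned

    # Step 4: enhance internal cell structure (chromatin) with CLAHE on L only
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l_enhanced = clahe.apply(l)

    # Reconstruct the BGR image from the enhanced L and aligned a/b channels
    normalized_img = cv2.merge((l_enhanced, a, b))
    return cv2.cvtColor(normalized_img, cv2.COLOR_LAB2BGR)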
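The function above covers steps 3 and 4. Steps 1, 2, and 5 are not shown in the post, so the following NumPy-only sketch fills them in under explicit assumptions: ITU-R BT.601 weights for the luminance, Otsu's method as the "adaptive" threshold, and exposure measured on the background plasma, which mostly reflects lamp brightness. The helper names `luminance`, `otsu_threshold`, and `normalize_exposure` are hypothetical. For brevity the usage applies step 5 to the luminance image; in the full pipeline it would act on the LAB L channel.

```python
import numpy as np


def luminance(bgr):
    """Step 1: single-channel luminance with BT.601 weights (input in B, G, R order)."""
    bgr = bgr.astype(np.float64)
    y = 0.114 * bgr[..., 0] + 0.587 * bgr[..., 1] + 0.299 * bgr[..., 2]
    return np.clip(np.rint(y), 0, 255).astype(np.uint8)


def otsu_threshold(gray):
    """Step 2 (assumed method): Otsu's threshold, maximizing between-class variance."""
    hist, _ = np.histogram(gray, bins=256, range=(0, 256))
    prob = hist / hist.sum()
    omega = np.cumsum(prob)                  # weight of the dark class for each cut t
    mu = np.cumsum(prob * np.arange(256))    # cumulative mean of the dark class
    mu_total = mu[-1]
    with np.errstate(divide="ignore", invalid="ignore"):
        between_var = (mu_total * omega - mu) ** 2 / (omega * (1.0 - omega))
    return int(np.nanargmax(between_var))    # cuts with an empty class are NaN and skipped


def normalize_exposure(channel, mask, ref_mean):
    """Step 5: scale brightness so the mean over `mask` matches ref_mean."""
    values = channel.astype(np.float64)
    scale = ref_mean / max(values[mask].mean(), 1e-5)
    return np.clip(np.rint(values * scale), 0, 255).astype(np.uint8)


# Usage on a synthetic smear: bright plasma background with one darker stained cell
img = np.full((64, 64, 3), 200, dtype=np.uint8)
img[16:48, 16:48] = (150, 60, 120)       # B, G, R
gray = luminance(img)
t = otsu_threshold(gray)
cells = gray <= t                        # stained cells are darker than plasma
# Assumption: exposure is measured on the background (~cells), not on the cells
out = normalize_exposure(gray, ~cells, ref_mean=220.0)
```

Measuring exposure on the background rather than on the cells keeps the correction independent of how many cells, or which cell types, happen to be in the field of view.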

Phase 3: Explainability & LLM-Driven Shortcut Auditing

A major risk in medical AI is shortcut learning — e.g., the model classifying a cell as a Lymphocyte simply because the slide template has a specific pen mark or background stain artifact.

To tackle this, WBC Analyzer integrates an automated auditing agent:

graph TD
    Image[Input Microscope Slide] --> Model[DenseNet121 + WBCAttention]
    Model --> GradCAM[Generate Grad-CAM Heatmap]
    Model --> Pred[Class Prediction]
    GradCAM --> LLMAgent[Autonomous Multimodal LLM Agent]
    Pred --> LLMAgent
    LLMAgent --> Audit[Evaluate Saliency Fit]
    Audit -->|Nucleus & Cytoplasm Focused| Normal[Pass Audit]
    Audit -->|Background/Shortcut Focused| Warn[Trigger Low-Confidence Flag]
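The "Generate Grad-CAM Heatmap" node can be implemented with standard Grad-CAM: average the gradients of the class score over each feature map to get channel weights, take a ReLU of the weighted sum of the activations, and upsample to the input size. The post does not show its implementation or say which layer it explains, so this is a minimal sketch under assumptions: the function name `grad_cam` is hypothetical, and a toy CNN stands in for DenseNet121 + WBCAttention (with DenseNet121, the last dense block is a common target layer).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, target_layer, image, class_idx=None):
    """Grad-CAM heatmap in [0, 1] for one image of shape (1, C, H, W)."""
    store = {}

    def save_activation(module, inputs, output):
        store["act"] = output
        # Capture d(score)/d(activation) when backward reaches this tensor
        output.register_hook(lambda grad: store.__setitem__("grad", grad))

    handle = target_layer.register_forward_hook(save_activation)
    try:
        logits = model(image)
    finally:
        handle.remove()
    if class_idx is None:
        class_idx = int(logits.argmax(dim=1)[0])   # explain the predicted class
    model.zero_grad()
    logits[0, class_idx].backward()

    act, grad = store["act"].detach(), store["grad"]
    weights = grad.mean(dim=(2, 3), keepdim=True)             # pooled gradient per channel
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))    # weighted sum of feature maps
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    return (cam / cam.max().clamp(min=1e-8))[0, 0]


# Usage with a toy CNN standing in for the real classifier
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5),
)
heatmap = grad_cam(model, model[0], torch.randn(1, 3, 32, 32))
```

The resulting `heatmap` has the image's spatial size, so it can be blended over the original smear to produce the overlay that is passed to the auditing agent.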

The LLM Audit Prompt Strategy

The system generates a Grad-CAM heatmap overlay. The multimodal agent (GPT-4o, falling back to Gemini 2.5 Flash if rate-limited) receives both the original image and the Grad-CAM heatmap.

It judges whether the heatmap is aligned with the cell itself, using three semantic rules:

  • Condition 1: Is the peak activation (red zone) focused on the cell nucleus and cytoplasm?
  • Condition 2: Is there significant activation in empty background regions?
  • Condition 3: Are there dark slide artifacts or staining pools attracting model attention?
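The post does not publish its prompt, so the following is a hypothetical illustration of how the three conditions could be turned into a structured request and a fail-safe verdict. The JSON schema and the names `AUDIT_RULES`, `build_audit_prompt`, and `is_shortcut` are assumptions; the image and heatmap would be attached to the request separately through the LLM client.

```python
import json

AUDIT_RULES = (
    "Is the peak activation (red zone) focused on the cell nucleus and cytoplasm?",
    "Is there significant activation in empty background regions?",
    "Are there dark slide artifacts or staining pools attracting model attention?",
)


def build_audit_prompt(predicted_class):
    """Text part of the multimodal audit request (images are attached separately)."""
    rules = "\n".join(f"{i}. {rule}" for i, rule in enumerate(AUDIT_RULES, start=1))
    return (
        f"The classifier predicted '{predicted_class}'. You are given the original "
        "smear image and its Grad-CAM overlay. Answer each question:\n"
        f"{rules}\n"
        'Reply with JSON only: {"nucleus_focused": bool, "background_activation": bool, '
        '"artifact_attention": bool, "explanation": str}'
    )


def is_shortcut(reply_text):
    """True unless attention is on the cell and nowhere suspicious."""
    try:
        verdict = json.loads(reply_text)
    except json.JSONDecodeError:
        return True  # unparseable reply: fail safe and flag for manual review
    if not isinstance(verdict, dict):
        return True
    return (not verdict.get("nucleus_focused", False)
            or verdict.get("background_activation", True)
            or verdict.get("artifact_attention", True))


prompt = build_audit_prompt("Lymphocyte")
clean = is_shortcut('{"nucleus_focused": true, "background_activation": false, '
                    '"artifact_attention": false}')
```

Missing or malformed fields default to "shortcut detected", so an unreliable LLM reply can only add a warning, never remove one.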

If the LLM agent flags a background shortcut, the system appends a low-confidence indicator to the API response, warning the clinical operator to review the classification manually.
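The provider fallback and the low-confidence flag can be sketched without tying the code to a specific SDK. In this hypothetical sketch, `RateLimitError` is a stand-in for whatever exception the real client raises on HTTP 429, the clients are plain callables, and the response keys `audit_report` and `low_confidence` are assumed names, not the project's documented API.

```python
class RateLimitError(Exception):
    """Stand-in for the rate-limit exception raised by the real LLM client."""


def audit_with_fallback(primary, fallback, request):
    """Call the primary model (e.g. GPT-4o); on a rate limit, use the fallback once."""
    try:
        return primary(request)
    except RateLimitError:
        return fallback(request)


def attach_audit(response, shortcut_detected, report):
    """Add the audit outcome to the API response without changing the prediction."""
    response = dict(response)
    response["audit_report"] = report
    response["low_confidence"] = bool(shortcut_detected)
    return response


# Usage with fake clients (real SDK calls are deliberately not shown)
def fake_gpt4o(request):
    raise RateLimitError("429 Too Many Requests")


def fake_gemini(request):
    return "gemini reply"


reply = audit_with_fallback(fake_gpt4o, fake_gemini, {"prompt": "..."})
result = attach_audit({"prediction": "Neutrophil"}, True, reply)
```

Keeping the prediction untouched and only adding a flag leaves the final decision with the clinical operator, which matches the manual-review workflow described above.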


Phase 4: Retraining-Free Test-Time Adaptation

Instead of expensive model retraining or fine-tuning, WBC Analyzer uses test-time entropy minimization during inference.

When a batch of slides from a new clinic arrives:

  1. The model computes predictions.
  2. It measures Shannon entropy of the softmax outputs. High entropy indicates the model is confused by the new domain.
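  3. The system updates only the scale and shift parameters ($\gamma$, $\beta$) of the BatchNorm layers via backpropagation to minimize the entropy of the predicted class distribution $p(c)$ (this $\beta$ is the BatchNorm shift, unrelated to the MedSwish $\beta$):
\[\mathcal{L}_{\text{entropy}} = - \sum_{c} p(c) \log p(c)\]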
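This procedure matches the recipe popularized by Tent (fully test-time adaptation by entropy minimization). A minimal PyTorch sketch follows, under assumptions the post does not specify: one SGD step per incoming batch, the per-sample entropy averaged over the batch, BatchNorm normalizing with the incoming batch's statistics, and the helper names `configure_for_tta` and `entropy_step`. A toy model stands in for the trained backbone.

```python
import torch
import torch.nn as nn


def configure_for_tta(model):
    """Freeze all weights except BatchNorm scale (gamma) and shift (beta)."""
    model.eval()                  # keep dropout and similar layers in inference mode
    model.requires_grad_(False)
    params = []
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.train()        # normalize with the incoming batch's statistics
            module.requires_grad_(True)
            params += [p for p in (module.weight, module.bias) if p is not None]
    return params


def entropy_step(model, optimizer, images):
    """One unsupervised update minimizing the batch-mean Shannon entropy."""
    logits = model(images)
    log_probs = logits.log_softmax(dim=1)
    loss = -(log_probs.exp() * log_probs).sum(dim=1).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return logits.detach(), loss.item()


# Usage with a toy model; a real run would wrap the trained DenseNet121 backbone
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5),
)
params = configure_for_tta(model)
optimizer = torch.optim.SGD(params, lr=1e-3, momentum=0.9)
conv_before = model[0].weight.clone()
shift_before = model[1].bias.clone()
batch = torch.randn(16, 3, 32, 32)   # stands in for slides from a new clinic
logits, loss = entropy_step(model, optimizer, batch)
```

Only $\gamma$ and $\beta$ reach the optimizer, so the convolutional filters learned on the training domain stay frozen while the normalization adapts to the new clinic's staining and exposure.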

This dynamic shift adapts the model’s intermediate representations to the new domain’s statistics, yielding a +32.09 pp accuracy gain entirely without supervision: no human labeling is needed, and no convolutional weights are modified.


Live Demo & Integration

WBC Analyzer is served through a Flask REST API with the following characteristics:

  • Inference Latency: ~45ms per image (on NVIDIA T4 GPU)
  • API Throughput: ~22 requests/second
  • Interpretable Output: Returns class probabilities, a URL to the Grad-CAM visualization, and the LLM shortcut audit report.
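The latency and throughput figures are consistent with each other if requests are served one at a time, because the throughput of a strictly sequential server is the reciprocal of its per-request latency:

\[\frac{1\ \text{request}}{0.045\ \text{s}} \approx 22.2\ \text{requests/second}\]

Concurrent workers or batched inference could push throughput past this bound, so the reported number suggests a sequential setup, although the post does not say so explicitly.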

You can inspect the source code on GitHub or try the web interface in the Live Demo.

This post is licensed under CC BY 4.0 by the author.