WBC Analyzer: Robust OOD Generalization in Peripheral Blood Smears
End-to-end white blood cell classification using DenseNet121 with custom WBCAttention and MedSwish layers, achieving 98.53% in-distribution and 89.05% OOD accuracy via retraining-free inference adaptation, featuring a 5-step Medical Enhanced Filter (MEF) pipeline and an autonomous LLM shortcut-detection agent.
Overview
White blood cell (WBC) classification in peripheral blood smears is a critical task for diagnosing leukemia, infections, and immune disorders. However, deploying deep learning models in real-world clinical environments is challenging due to domain shift: differences in staining kits, exposure levels, and microscope camera sensors across laboratories cause model accuracy to plummet.
This project introduces WBC Analyzer, a deep learning system that reaches 89.05% accuracy on data from unseen hardware without retraining. By combining a custom attention-guided backbone with a test-time normalization pipeline and an autonomous multimodal LLM auditing agent, WBC Analyzer is designed to keep clinical predictions reliable and interpretable.
This research has been published as an academic preprint (DOI: 10.13140/RG.2.2.34201.79208).
Results at a Glance
| Evaluation Setting | Method / Adaptation | Accuracy | Delta |
|---|---|---|---|
| In-Distribution (ID) | DenseNet121 + WBCAttention | 98.53% | Baseline |
| Out-of-Distribution (OOD) | Naive Baseline (Unadapted) | 56.96% | -41.57 pp vs. ID |
| Out-of-Distribution (OOD) | WBC Analyzer (Adapted) | 89.05% | +32.09 pp vs. unadapted OOD |
Note: OOD evaluation was performed on unseen hospital datasets using different slide preparation hardware.
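The deltas in the table follow directly from the three accuracies. As a quick arithmetic check (the variable names are illustrative, and the "share recovered" ratio is a derived figure, not one reported by the project):

```python
# Accuracies from the table above, in percent
id_acc = 98.53
ood_naive = 56.96
ood_adapted = 89.05

drop_pp = round(id_acc - ood_naive, 2)         # loss when moving to OOD data
gain_pp = round(ood_adapted - ood_naive, 2)    # gain from adaptation
remaining_pp = round(id_acc - ood_adapted, 2)  # gap still left to ID accuracy
recovered = gain_pp / drop_pp                  # share of the drop won back

print(f"drop: {drop_pp} pp, gain: {gain_pp} pp, remaining: {remaining_pp} pp")
print(f"share of the ID-to-OOD drop recovered: {recovered:.1%}")
```

Under these numbers, adaptation wins back roughly 77% of the accuracy lost to domain shift.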
Phase 1: Model Architecture — WBCAttention & MedSwish
To prevent the neural network from memorizing background blood plasma noise, I designed a specialized architecture that forces the network to focus on cellular morphology.
1. WBCAttention Block
The model utilizes a custom spatial and channel attention mechanism inserted between the dense blocks of a DenseNet121 backbone.
- Channel Attention: Compresses spatial dimensions using both global average pooling and global max pooling, then feeds the statistics through a shared MLP to weight feature maps based on semantic importance (e.g., emphasizing nuclear structure features over background).
- Spatial Attention: Evaluates where features are active across the spatial grid, applying a convolution over concatenated max and average pooling maps to localize cell nuclei and granules.
import torch.nn as nn

class WBCAttention(nn.Module):
    """Channel attention followed by spatial attention."""
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.channel_gate = ChannelAttention(in_channels, reduction)
        self.spatial_gate = SpatialAttention()

    def forward(self, x):
        # Reweight feature maps per channel, then per spatial location
        out = self.channel_gate(x) * x
        out = self.spatial_gate(out) * out
        return out
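The block above relies on `ChannelAttention` and `SpatialAttention`, which are not shown. A minimal sketch of what they might look like, following the two bullet descriptions, is given here. The 1x1-convolution MLP, the 7x7 spatial kernel, and the sigmoid gates are assumptions borrowed from common channel/spatial attention designs, not confirmed details of this project.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Gate per channel from avg- and max-pooled statistics (sketch)."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = max(in_channels // reduction, 1)
        # Shared MLP, written as 1x1 convolutions so it accepts (N, C, 1, 1)
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        return torch.sigmoid(avg + mx)  # shape (N, C, 1, 1)


class SpatialAttention(nn.Module):
    """Gate per location from channel-wise avg and max maps (sketch)."""

    def __init__(self, kernel_size=7):  # kernel size is an assumption
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size,
                              padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)
        mx = torch.amax(x, dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))  # (N, 1, H, W)


# Shape check on a dummy feature map
features = torch.randn(2, 64, 14, 14)
channel_gate = ChannelAttention(64)(features)
spatial_gate = SpatialAttention()(features)
print(channel_gate.shape, spatial_gate.shape)
```

Both gates broadcast against the input, which is why `WBCAttention.forward` can multiply them with `x` directly.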
2. MedSwish Activation Function
Standard ReLU activations suffer from “dying neurons” in low-contrast medical images, while standard Swish can sometimes lead to gradient instability during early epochs. MedSwish adds a scale and an offset inside the sigmoid gate to preserve gradient flow in low-contrast, stain-shifted regions:
\[\text{MedSwish}(x) = x \cdot \text{sigmoid}(\beta \cdot x + \alpha)\]
where $\beta$ controls the gating sharpness and $\alpha$ acts as a soft-bias threshold tailored for medical image distributions, stabilizing backpropagation for deep feature maps.
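The formula translates almost directly into a PyTorch module. The sketch below assumes that $\beta$ and $\alpha$ are learnable scalars initialized to 1 and 0; the source does not state how they are set, so treat both the learnability and the initial values as assumptions. With those defaults the function coincides with standard Swish (SiLU), which gives a convenient sanity check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MedSwish(nn.Module):
    """x * sigmoid(beta * x + alpha), with scalar beta and alpha (sketch)."""

    def __init__(self, beta=1.0, alpha=0.0):
        super().__init__()
        # Assumption: both parameters are trained along with the network
        self.beta = nn.Parameter(torch.tensor(float(beta)))
        self.alpha = nn.Parameter(torch.tensor(float(alpha)))

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x + self.alpha)


x = torch.linspace(-3.0, 3.0, steps=7)
default_out = MedSwish()(x)             # beta=1, alpha=0: same as SiLU
shifted_out = MedSwish(alpha=1.0)(x)    # positive alpha opens the gate wider
print(torch.allclose(default_out, F.silu(x)))
```

A positive $\alpha$ raises the gate value for every input, so positive inputs pass through more strongly, while the output at $x = 0$ stays at zero.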
Phase 2: Domain Adaptation — Medical Enhanced Filter (MEF)
To normalize images before inference, I built a 5-step preprocessing pipeline that translates staining styles and exposure levels into a standardized distribution.
The MEF Normalization Pipeline
- Grayscale Luminance Extraction: Converts input to single-channel luminance to detect overall illumination levels.
- Adaptive Tissue Thresholding: Separates background plasma (which contains slide dirt and light glare) from white and red blood cells.
- LAB Color Space Alignment: Maps the RGB input to LAB color space. The mean and standard deviation of $a$ (magenta-green) and $b$ (yellow-blue) channels are matched to a clinical reference template.
- Contrast-Limited Adaptive Histogram Equalization (CLAHE): Enhances internal nuclear structure (chromatin patterns) without amplifying background noise.
- Exposure Normalization: Scales the L channel to match reference brightness levels.
import cv2
import numpy as np

def medical_enhanced_filter(image, reference_lab_stats):
    """Normalize a BGR uint8 image.

    reference_lab_stats holds two {'mean', 'std'} dicts, for the a and b
    channels in that order.
    """
    # Convert input image to LAB color space
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)

    # 1. Align color distributions to clinical reference
    aligned = []
    for channel, ref_stats in zip([a, b], reference_lab_stats):
        mean, std = channel.mean(), channel.std()
        normalized = (channel - mean) * (ref_stats['std'] / (std + 1e-5)) + ref_stats['mean']
        # Collect the aligned channels so the merge below uses them
        aligned.append(np.clip(normalized, 0, 255).astype(np.uint8))
    a, b = aligned

    # 2. Enhance internal cell structure with CLAHE
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l_enhanced = clahe.apply(l)

    # Reconstruct BGR image
    normalized_img = cv2.merge((l_enhanced, a, b))
    return cv2.cvtColor(normalized_img, cv2.COLOR_LAB2BGR)
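The function above expects `reference_lab_stats` and does not cover the exposure step from the list. The following NumPy-only sketch shows one way both pieces could be supplied, working on already-split 8-bit LAB channels. The helper names, the multiplicative scaling of L, and the synthetic arrays are all illustrative assumptions.

```python
import numpy as np


def channel_stats(channel):
    """Mean and standard deviation of one 8-bit channel as a plain dict."""
    return {"mean": float(channel.mean()), "std": float(channel.std())}


def build_reference_lab_stats(ref_a, ref_b):
    """Stats for the a and b channels, in the order the filter zips them."""
    return [channel_stats(ref_a), channel_stats(ref_b)]


def normalize_exposure(l_channel, reference_l_mean):
    """Scale the L channel so its mean matches a reference brightness."""
    scale = reference_l_mean / (float(l_channel.mean()) + 1e-5)
    scaled = np.rint(l_channel.astype(np.float32) * scale)
    return np.clip(scaled, 0, 255).astype(np.uint8)


# Synthetic stand-ins for the channels of a reference template
ref_a = np.full((8, 8), 130, dtype=np.uint8)
ref_b = np.full((8, 8), 120, dtype=np.uint8)
reference_lab_stats = build_reference_lab_stats(ref_a, ref_b)

# An under-exposed L channel brought up to a reference mean of 128
rng = np.random.default_rng(0)
dark_l = rng.integers(40, 90, size=(64, 64), dtype=np.uint8)
bright_l = normalize_exposure(dark_l, reference_l_mean=128.0)
print(dark_l.mean(), bright_l.mean())
```

In a real pipeline the reference channels would come from `cv2.split` applied to the LAB version of the clinical template image rather than from constant arrays.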
Phase 3: Explainability & LLM-Driven Shortcut Auditing
A major risk in medical AI is shortcut learning — e.g., the model classifying a cell as a Lymphocyte simply because the slide template has a specific pen mark or background stain artifact.
To tackle this, WBC Analyzer integrates an automated auditing agent:
graph TD
Image[Input Microscope Slide] --> Model[DenseNet121 + WBCAttention]
Model --> GradCAM[Generate Grad-CAM Heatmap]
Model --> Pred[Class Prediction]
GradCAM --> LLMAgent[Autonomous Multimodal LLM Agent]
Pred --> LLMAgent
LLMAgent --> Audit[Evaluate Saliency Fit]
Audit -->|Nucleus & Cytoplasm Focused| Normal[Pass Audit]
Audit -->|Background/Shortcut Focused| Warn[Trigger Low-Confidence Flag]
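The Grad-CAM step in the diagram can be sketched with a forward hook and `torch.autograd.grad`. This is a generic Grad-CAM implementation under stated assumptions, not the project's code: it handles a single image, the choice of target layer is left to the caller, and the toy network below only stands in for the DenseNet121 backbone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, target_layer, image, class_idx=None):
    """Return a heatmap in [0, 1] with the image's spatial size, plus the class."""
    captured = {}

    def save_activation(_module, _inputs, output):
        captured["acts"] = output

    handle = target_layer.register_forward_hook(save_activation)
    model.eval()
    try:
        logits = model(image)
    finally:
        handle.remove()

    if class_idx is None:
        class_idx = int(logits.argmax(dim=1)[0])

    acts = captured["acts"]
    # Gradient of the class score with respect to the feature maps
    grads = torch.autograd.grad(logits[0, class_idx], acts)[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)  # one weight per channel
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear",
                        align_corners=False)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    return cam[0, 0].detach(), class_idx


# Toy stand-in network (hypothetical): 3 classes, 16x16 input
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 3),
)
heatmap, predicted = grad_cam(toy_model, toy_model[1], torch.randn(1, 3, 16, 16))
print(heatmap.shape, predicted)
```

The heatmap can then be colorized and blended over the input image before it is handed to the auditing agent.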
The LLM Audit Prompt Strategy
The system generates a Grad-CAM heatmap overlay. The multimodal agent (GPT-4o, falling back to Gemini 2.5 Flash if rate-limited) receives both the original image and the Grad-CAM heatmap.
It evaluates the alignment using semantic rules:
- Condition 1: Is the peak activation (red zone) focused on the cell nucleus and cytoplasm?
- Condition 2: Is there significant activation in empty background regions?
- Condition 3: Are there dark slide artifacts or staining pools attracting model attention?
If the LLM agent flags a background shortcut, the system appends a low-confidence indicator to the API response, warning the clinical operator to review the classification manually.
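How the three conditions turn into a flag is not spelled out, so the sketch below assumes the simplest rule: any failed condition marks the prediction as low confidence. It takes the agent's answers as already-parsed booleans and makes no LLM call; `AuditAnswers`, `apply_audit`, and the response field names are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class AuditAnswers:
    focus_on_cell: bool          # Condition 1: peak on nucleus and cytoplasm
    background_activation: bool  # Condition 2: significant empty-background activation
    artifact_attention: bool     # Condition 3: artifacts or stain pools attract attention


def apply_audit(response, answers):
    """Return a copy of the API response with the audit outcome attached."""
    shortcut_suspected = (
        not answers.focus_on_cell
        or answers.background_activation
        or answers.artifact_attention
    )
    audited = dict(response)
    audited["low_confidence"] = shortcut_suspected
    audited["needs_manual_review"] = shortcut_suspected
    return audited


prediction = {"label": "Lymphocyte", "probability": 0.97}
passed = apply_audit(prediction, AuditAnswers(True, False, False))
flagged = apply_audit(prediction, AuditAnswers(True, False, True))
print(passed)
print(flagged)
```

Keeping the decision rule in plain code, separate from the model call, makes the flagging behavior testable even when the LLM provider changes.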
Phase 4: Retraining-Free Test-Time Adaptation
Instead of expensive model retraining or fine-tuning, WBC Analyzer uses test-time entropy minimization during inference.
When a batch of slides from a new clinic arrives:
- The model computes predictions.
- It measures Shannon entropy of the softmax outputs. High entropy indicates the model is confused by the new domain.
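- The system updates only the scale and shift parameters ($\gamma$, $\beta$) in the BatchNorm layers via backpropagation to minimize the mean prediction entropy over the batch (this $\beta$ is the BatchNorm shift, not the MedSwish gate parameter):
\[H(p) = -\sum_{c} p_c \log p_c\]
where $p_c$ is the softmax probability assigned to class $c$.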
This update shifts the model’s intermediate representations toward the new domain’s statistics, yielding the +32.09 pp accuracy gain without any human labels and without modifying the convolutional weights.
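The adaptation loop can be sketched in a few lines of PyTorch. The version below is a generic entropy-minimization step under stated assumptions: BatchNorm layers normalize with the statistics of the incoming batch, the optimizer is plain SGD with an arbitrary learning rate, and a tiny toy network stands in for the real model. None of these choices are confirmed by the source.

```python
import torch
import torch.nn as nn


def softmax_entropy(logits):
    """Shannon entropy of the softmax distribution, one value per sample."""
    log_p = torch.log_softmax(logits, dim=1)
    return -(log_p.exp() * log_p).sum(dim=1)


def collect_bn_params(model):
    """Freeze everything except BatchNorm scale and shift; return those."""
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    bn_params = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()  # assumption: normalize with current batch statistics
            m.weight.requires_grad_(True)
            m.bias.requires_grad_(True)
            bn_params += [m.weight, m.bias]
    return bn_params


def adapt_step(model, batch, optimizer):
    """One unsupervised update that lowers mean prediction entropy."""
    loss = softmax_entropy(model(batch)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy stand-in network and an unlabeled batch from a "new clinic"
torch.manual_seed(0)
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
)
conv_before = toy[0].weight.detach().clone()
bn_before = toy[1].weight.detach().clone()

optimizer = torch.optim.SGD(collect_bn_params(toy), lr=0.1)
batch = torch.randn(16, 3, 32, 32)
losses = [adapt_step(toy, batch, optimizer) for _ in range(5)]
print(losses)
```

Only the BatchNorm parameters are passed to the optimizer, so the convolutional and linear weights stay exactly as trained.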
Live Demo & Integration
WBC Analyzer is packaged as a Flask REST API with the following characteristics:
- Inference Latency: ~45 ms per image (on NVIDIA T4 GPU)
- API Throughput: ~22 requests/second
- Interpretable Output: Returns class probabilities, a URL to the Grad-CAM visualization, and the LLM shortcut audit report.
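The latency and throughput figures are consistent with each other if requests are served one at a time, which is an assumption about the deployment rather than something the source states:

```python
latency_s = 0.045  # ~45 ms per image, from the list above

# With no batching or concurrency, throughput is the inverse of latency
sequential_rps = 1.0 / latency_s
print(f"one request at a time: about {sequential_rps:.1f} requests/second")
```

Batching several images per forward pass or running multiple workers would change this relationship.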
You can inspect the source code on GitHub or try the web interface in the Live Demo.