Cracking the Medical Coding Challenge: Fine-Tuning BioBERT for ICD-10 Classification (Part 1)

  • MyrinNew
    Senior Member
    • Feb 2024
    • 5175


    The Problem That Keeps Medical Coders Up at Night

    Imagine you're processing disability claims for veterans. Each claim contains dense medical documentation—thousands of characters describing symptoms, diagnoses, and treatment history. Your job? Extract the correct ICD-10 diagnostic codes from this narrative. Miss a code, and a veteran might not receive the benefits they've earned. Add an incorrect code, and you've created compliance issues.


    Now imagine doing this hundreds of times per day, under pressure, with 158+ possible diagnosis codes to remember.


    This is exactly the type of problem that makes medical coding both critically important and incredibly challenging. And it's the perfect use case for Natural Language Processing (NLP). But here's the catch: training an AI to do this isn't straightforward, especially when you're dealing with limited training data and severe class imbalance.


    In this two-part series, I'll walk you through building an automated medical coding system. Part 1 (this article) focuses on fine-tuning BioBERT with advanced techniques to handle real-world constraints. Part 2 will explore AWS Comprehend Medical as an alternative approach and compare the two solutions.


    🔗 GitHub Repository


    Why This Project Matters: Real-World Use Cases

    Before diving into code, let's talk about why automated medical coding matters:


    1. Disability Claims Processing

    The Department of Veterans Affairs (VA) processes millions of disability claims. Each claim requires accurate ICD-10 coding to determine eligibility and compensation levels. Manual coding creates bottlenecks and inconsistencies.


    2. Healthcare Revenue Cycle Management

    Hospitals lose billions annually due to coding errors. Automated coding assistance can flag potential issues before claims are submitted to insurance companies.


    3. Clinical Research

    Large-scale medical studies require consistent coding of patient records. Automated extraction enables researchers to identify patient cohorts more efficiently.


    4. Compliance and Auditing

    Healthcare organizations must ensure coding accuracy for regulatory compliance. AI systems can audit existing codes and identify discrepancies.





    The Dataset: MedCodER and Its Challenges

    For this project, we're using the MedCodER (Medical Coding with Explanations and Retrievals) dataset, which contains:
    • 500+ clinical documents with full SOAP notes (Subjective, Objective, Assessment, Plan)
    • 158 unique ICD-10-CM codes
    • Supporting evidence annotations showing which text spans support each diagnosis
    • Severe class imbalance: Most codes appear fewer than 10 times


    Here's what makes this dataset challenging (and realistic):






    # Class distribution snapshot
    Total unique codes: 158
    Codes with ≥80 samples: 18 # Only 11% have sufficient training data!
    Codes with ≥50 samples: 25
    Codes with <10 samples: 98 # 62% are extremely rare







    This closely mirrors real-world medical data: common conditions like diabetes and hypertension appear frequently, while rare diseases have only a handful of examples.
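A snapshot like the one above can be produced with a simple frequency count over the per-example labels. The sketch below is illustrative only: the `summarize_code_counts` helper and the toy label list are assumptions made for this example, not code or data from the repository.

```python
from collections import Counter

def summarize_code_counts(labels, thresholds=(80, 50)):
    """Count how many distinct codes meet each frequency threshold."""
    counts = Counter(labels)
    summary = {"total_unique": len(counts)}
    for t in thresholds:
        summary[f">={t}"] = sum(1 for n in counts.values() if n >= t)
    summary["<10"] = sum(1 for n in counts.values() if n < 10)
    return summary

# Toy data (hypothetical): one common code, one mid-frequency code, two rare ones
toy_labels = ["I10"] * 90 + ["E11.9"] * 55 + ["E78.5"] * 4 + ["J45.909"] * 2
summary = summarize_code_counts(toy_labels)
print(summary)
```

Running the same count on the real label column is a quick way to decide where a frequency cutoff should sit before any training happens.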





    The Naive Approach (And Why It Fails Spectacularly)

    Let's talk about what doesn't work. Your first instinct might be:

    1. Take full 2000+ character clinical documents
    2. Feed them to BioBERT
    3. Train on all 158 classes
    4. Hope for the best


    Result: Macro F1 score of 0.023 (2.3%). Essentially random guessing.
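To see why macro F1 is so unforgiving here, it helps to compute it by hand. The sketch below is a minimal pure-Python version that averages per-class F1 over the classes present in the true labels; the toy labels are hypothetical and only illustrate how a model that ignores rare classes can post decent accuracy and a poor macro F1.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every class seen in y_true."""
    classes = sorted(set(y_true))
    f1_scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    return sum(f1_scores) / len(f1_scores)

# Toy example (hypothetical labels): the model always predicts the majority class
y_true = ["I10"] * 8 + ["E11.9", "E78.5"]
y_pred = ["I10"] * 10
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
score = macro_f1(y_true, y_pred)
print(f"accuracy={accuracy:.2f}, macro F1={score:.3f}")
```

Every class counts equally in the average, so each class the model never predicts contributes a zero. With 158 classes and most of them rare, those zeros dominate.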


    Why does this fail?

    Problem 1: Signal Dilution

    A 2000-character document might contain only 50-100 characters actually describing a specific diagnosis. The rest is noise—patient demographics, vital signs, medication lists, etc.


    Problem 2: Insufficient Training Data

    With only 500 documents and 158 classes, you have an average of ~3 documents per class. Deep learning models typically need orders of magnitude more data.


    Problem 3: Catastrophic Overfitting

    BioBERT has 110 million parameters. Training all of them on tiny datasets causes the model to memorize training examples rather than learn generalizable patterns.

    The Solution: A Five-Pronged Strategy

    To reach a 94.4% macro F1 score on the 18 most frequent codes (roughly 41 times the naive baseline's 0.023, a ~4,000% improvement), we implement five key techniques:

    1. Evidence-Focused Training

    2. Label Space Optimization

    3. Back-Translation Data Augmentation

    4. LoRA Parameter-Efficient Fine-Tuning

    5. Class-Weighted Loss Function

    Let's dive into each one.



    Technique 1: Evidence-Focused Training

    The Problem: Training on 2000-character documents dilutes the diagnostic signal.


    The Solution: Use the supporting evidence annotations to extract focused diagnostic spans (~150-200 characters) with context.






    def extract_evidence_text(row):
        """Extract the evidence span, plus surrounding context, from the full document text."""
        text = row['medical_record_text']
        start = int(row['Start'])
        end = int(row['End'])

        # Extract with a ±50-character context window, clamped to the document bounds
        context_start = max(0, start - 50)
        context_end = min(len(text), end + 50)

        return text[context_start:context_end]







    Why this works: We're giving the model concentrated diagnostic information. Instead of finding a needle in a haystack, we're handing it the needle.
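Here is the same windowing logic as a standalone sketch you can run without the dataset. The function name `extract_with_context`, the `window` parameter, and the synthetic record are assumptions for illustration; only the ±50-character clamped window comes from the approach described here.

```python
def extract_with_context(text, start, end, window=50):
    """Return text[start:end] widened by `window` characters on each side."""
    context_start = max(0, start - window)
    context_end = min(len(text), end + window)
    return text[context_start:context_end]

# Hypothetical record: filler text around a single diagnostic statement
filler = "Vitals stable. " * 20
evidence = "Diagnosis: Essential (primary) hypertension."
record = filler + evidence + " " + filler
start = record.index(evidence)
end = start + len(evidence)

span = extract_with_context(record, start, end)
print(len(record), len(span))
```

The clamping with `max` and `min` matters for evidence near the start or end of a note, where a fixed offset would otherwise run past the document bounds.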


    Example transformation:


    Full Document (2,347 chars):






    [Long patient history, demographics, vitals, multiple conditions mixed together...]







    Evidence Span (189 chars):






    "...blood pressure remains elevated at 156/94 despite medication compliance.
    Diagnosis: Essential (primary) hypertension. Will increase lisinopril dose..."







    Consequence of skipping this step:

    Without evidence extraction, the model struggles to differentiate signal from noise. You'd see F1 scores plateau around 20-30% even with other optimizations.





    Technique 2: Label Space Optimization

    The Problem: 62% of codes have fewer than 10 training examples—impossible to learn from.


    The Solution: Filter to codes with ≥80 examples, reducing from 158 codes to 18 viable classes.






    MIN_SAMPLES = 80

    # Count examples per code and keep only the codes that meet the threshold
    code_freq = evidence_focused['ICD10'].value_counts()
    frequent_codes = code_freq[code_freq >= MIN_SAMPLES].index.tolist()

    evidence_filtered = evidence_focused[
        evidence_focused['ICD10'].isin(frequent_codes)
    ].reset_index(drop=True)

    print(f"Reduced to {len(frequent_codes)} codes")  # 18 codes
    print(f"Retained {len(evidence_filtered)} examples")  # ~1,200 examples







    Why this works: Machine learning requires sufficient examples to learn patterns. By focusing on codes with adequate representation, we ensure the model can actually learn meaningful relationships.


    The trade-off: We sacrifice coverage (18 codes vs. 158) for accuracy. This is acceptable in a hybrid system where:
    • Custom model handles frequent codes (high accuracy)
    • Commercial API handles rare codes (broader coverage, lower accuracy)
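One way to picture the hybrid system is a small router that trusts the custom model only for frequent codes it predicts confidently, and defers everything else. This is a sketch under stated assumptions: `route_prediction`, the `fake_commercial_api` stand-in, and the 0.8 confidence threshold are all hypothetical, not part of this project's code.

```python
def route_prediction(custom_prediction, confidence, fallback_fn, text,
                     frequent_codes, min_confidence=0.8):
    """Use the custom model's code when it is confident; otherwise fall back."""
    if custom_prediction in frequent_codes and confidence >= min_confidence:
        return custom_prediction, "custom"
    return fallback_fn(text), "fallback"

# Hypothetical stand-in for a commercial API call
def fake_commercial_api(text):
    return "J45.909"

frequent_codes = {"I10", "E11.9", "E78.5"}
confident = route_prediction("I10", 0.97, fake_commercial_api, "bp elevated", frequent_codes)
uncertain = route_prediction("I10", 0.41, fake_commercial_api, "wheezing", frequent_codes)
print(confident, uncertain)
```

In practice the threshold would need tuning on held-out data, since a classifier trained on 18 classes will still assign one of them to text that belongs to none.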





    Consequence of skipping this step:

    Including rare codes creates extreme class imbalance. The model would:
    • Ignore rare classes entirely (predicting only common ones)
    • Waste capacity trying to memorize insufficient examples
    • Achieve poor performance across all classes





    Technique 3: Back-Translation Data Augmentation

    The Problem: Even after filtering, we only have ~1,200 training examples for 18 classes (~67 examples per class). Still limited.


    The Solution: Use back-translation to generate synthetic training data.






    from functools import lru_cache
    from transformers import MarianMTModel, MarianTokenizer

    @lru_cache(maxsize=None)
    def load_marian(name):
        """Load a MarianMT model/tokenizer pair once and reuse it across calls."""
        return MarianMTModel.from_pretrained(name), MarianTokenizer.from_pretrained(name)

    def translate(text, name):
        """Translate a single string with the given MarianMT checkpoint."""
        model, tokenizer = load_marian(name)
        inputs = tokenizer(text, return_tensors='pt', truncation=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def back_translate(text, pivot_lang='de'):
        """Translate EN→pivot→EN (German by default) to create a paraphrased version"""
        pivot_text = translate(text, f'Helsinki-NLP/opus-mt-en-{pivot_lang}')
        return translate(pivot_text, f'Helsinki-NLP/opus-mt-{pivot_lang}-en')







    Example transformation:


    Original:






    "Patient reports persistent chest pain radiating to left arm with
    shortness of breath during physical exertion."







    After EN→DE→EN:






    "Patient experiences continuous chest pain extending to the left arm
    with breathing difficulty during physical activity."







    Why this works: The semantic meaning remains identical, but the phrasing varies. This teaches the model to recognize diagnoses regardless of how they're worded—critical for handling real-world clinical variation.


    Best practice: Use multiple pivot languages (German, French, Spanish) for up to 4× data expansion (the originals plus one paraphrase per language). In our demo, we use only German, for a 1.2× expansion, to save time.
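The multi-pivot idea can be sketched without any translation models by treating each round trip as a plain function. Everything here is an assumption for illustration: `augment_with_pivots` is a hypothetical helper, the lambdas stand in for real EN→pivot→EN calls, and the example text and label are toy data.

```python
def augment_with_pivots(examples, translators):
    """Return originals plus one paraphrase per pivot, skipping unchanged text."""
    augmented = list(examples)
    for text, label in examples:
        seen = {text}
        for translate in translators.values():
            paraphrase = translate(text)
            if paraphrase not in seen:  # drop round trips that changed nothing
                seen.add(paraphrase)
                augmented.append((paraphrase, label))
    return augmented

# Hypothetical stand-ins for real EN→pivot→EN round trips
translators = {
    "de": lambda t: t.replace("persistent", "continuous"),
    "fr": lambda t: t.replace("reports", "describes"),
    "es": lambda t: t,  # a round trip that returns the input unchanged
}
examples = [("Patient reports persistent chest pain.", "I20.9")]
result = augment_with_pivots(examples, translators)
print(len(result) / len(examples))
```

Short clinical phrases often survive a round trip unchanged, so three pivot languages give at most 4× the data. Dropping exact duplicates keeps the training set from silently over-weighting those examples.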


    Critical requirement: Keep the validation set 100% original data





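    import pandas as pd
    from sklearn.model_selection import train_test_split

    # Split BEFORE augmentation (fixed seed so the split is reproducible)
    train_orig, val_orig = train_test_split(original_df, test_size=0.2, random_state=42)

    # Augment ONLY training data
    train_augmented = augment_with_back_translation(train_orig)
    train_final = pd.concat([train_orig, train_augmented], ignore_index=True)

    # Validation stays 100% original
    val_final = val_orig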







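    Why this matters: If a paraphrase and its source sentence end up on opposite sides of the split, validation contains near-duplicates of training data and the metrics come out overly optimistic. The model can also be rewarded for learning artifacts of the translation process rather than true diagnostic patterns.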
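A quick leakage check is cheap insurance. The sketch below splits first, augments only the training portion, and then verifies that no validation text (or paraphrase of one) appears in training. The `split_then_augment` helper, the toy notes, and the string-suffix "paraphraser" are hypothetical stand-ins, not the repository's implementation.

```python
import random

def split_then_augment(examples, paraphrase, test_size=0.2, seed=42):
    """Split first, then add paraphrases to the training portion only."""
    shuffled = examples[:]
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * test_size))
    val, train = shuffled[:n_val], shuffled[n_val:]
    train_aug = train + [(paraphrase(text), label) for text, label in train]
    return train_aug, val

# Hypothetical data and a stand-in paraphraser
examples = [(f"note {i}", i % 3) for i in range(10)]
train, val = split_then_augment(examples, lambda t: t + " (paraphrased)")

# No validation text, and no paraphrase of one, appears in training
val_texts = {text for text, _ in val}
leaked = [text for text, _ in train if text.replace(" (paraphrased)", "") in val_texts]
print(len(train), len(val), len(leaked))
```

With real back-translations there is no marker suffix to strip, so an equivalent check would track the ID of each paraphrase's source example instead.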





    Consequence of skipping this step:

    Without augmentation, the model has limited exposure to linguistic variation. It might learn to recognize specific phrasings but fail on synonyms or alternative formulations—reducing real-world robustness by 10-15%.





    Technique 4: LoRA (Low-Rank Adaptation) Fine-Tuning

    The Problem: BioBERT has 110 million parameters. Training all of them on 1,200 examples causes severe overfitting.


    The Solution: Use LoRA to train only 0.1% of parameters while keeping the rest frozen.


    How LoRA Works

    Instead of updating all weights in the attention layers, LoRA injects trainable low-rank matrices:






    Traditional: W_new = W_old + ΔW (update all 768×768 = 589,824 params)
    LoRA: W_new = W_old + A×B (update 768×8 + 8×768 = 12,288 params)







    Where:
    • A is a 768×8 matrix
    • B is an 8×768 matrix
    • r=8 is the rank (a hyperparameter)
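The arithmetic is easy to check yourself. The sketch below recomputes the per-matrix parameter counts and uses a tiny pure-Python matrix product to confirm that A×B has the same shape as the weight matrix it updates. The helper names and the 4×2 and 2×4 toy matrices are assumptions for illustration.

```python
def lora_param_counts(d_in, d_out, r):
    """Parameters updated per weight matrix: full update vs. rank-r factors."""
    full = d_in * d_out
    low_rank = d_in * r + r * d_out
    return full, low_rank

def matmul(a, b):
    """Plain-Python matrix product, enough for a tiny shape check."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

full, low_rank = lora_param_counts(768, 768, r=8)
print(full, low_rank, f"{low_rank / full:.1%}")

# Tiny illustration with d=4, r=2: A×B has the same shape as W
A = [[1, 0], [0, 1], [1, 1], [2, 0]]      # 4×2
B = [[1, 2, 3, 4], [0, 1, 0, 1]]          # 2×4
delta_W = matmul(A, B)                    # 4×4 update added to the frozen W
print(len(delta_W), len(delta_W[0]))
```

Doubling the rank doubles the adapter size, which is why r acts as a direct capacity knob.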




    from peft import LoraConfig, get_peft_model, TaskType
    from transformers import AutoModelForSequenceClassification

    # Load base BioBERT model
    base_model = AutoModelForSequenceClassification.from_pretrained(
        'dmis-lab/biobert-v1.1',
        num_labels=18,
        problem_type='single_label_classification'
    )

    # Configure LoRA
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=8,                                # Rank: controls capacity vs. overfitting trade-off
        lora_alpha=16,                      # Scaling factor (typically 2×r)
        lora_dropout=0.1,
        target_modules=["query", "value"],  # Apply to Q/V attention projections
        inference_mode=False
    )

    # Apply LoRA adapter
    model = get_peft_model(base_model, lora_config)

    print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")







    Output:






    Trainable params: 148,488 (0.13%)
    Total params: 109,629,456 (100%)







    Why this works:
    • Pre-trained knowledge is preserved: BioBERT's medical understanding stays intact
    • Task-specific adaptation: The small LoRA adapters learn to map BioBERT's features to ICD-10 codes
    • Regularization effect: Limited capacity prevents memorization


    Choosing the rank (r)

    • r=4: Very lightweight, may underfit complex tasks
    • r=8: Sweet spot for most tasks (used here)
    • r=16: More capacity, risk of overfitting on small datasets
    • r=32+: Approaching full fine-tuning behavior




    Image source (LoRA diagram): Hugging Face PEFT documentation, https://huggingface.co/docs/peft/mai...al_guides/lora

    Consequence of skipping this step:

    Full fine-tuning on this dataset produces F1 scores around 20-30%. The model memorizes training examples and fails to generalize. LoRA's regularization is the difference between failure and success.



    Technique 5: Class-Weighted Loss Function

    The Problem: Even after filtering, we have imbalance (some codes have 200 examples, others have 80).


    The Solution: Use weighted cross-entropy loss that penalizes errors on rare classes more heavily.






    import numpy as np
    import torch
    from torch import nn
    from sklearn.utils.class_weight import compute_class_weight
    from transformers import Trainer

    # Compute balanced class weights
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.arange(num_labels),
        y=train_df['label_id']
    )

    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)

    # Custom Trainer with weighted loss
    class WeightedTrainer(Trainer):
        def __init__(self, class_weights=None, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.class_weights = class_weights

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits

            # Weighted cross-entropy loss; keep the weights on the same device as the logits
            weight = None if self.class_weights is None else self.class_weights.to(logits.device)
            loss_fct = nn.CrossEntropyLoss(weight=weight)
            loss = loss_fct(logits, labels)

            return (loss, outputs) if return_outputs else loss







    How balanced weights work:






    weight[c] = n_samples / (n_classes × n_samples_in_class[c])







    Example:
    • Class A: 200 examples → weight = 1,200/(18×200) = 0.33
    • Class B: 80 examples → weight = 1,200/(18×80) = 0.83


    During training, misclassifying Class B incurs 2.5× the penalty of Class A.
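The numbers above can be reproduced in a few lines. This sketch applies the balanced-weight formula to the two example classes and shows how the weight scales the per-example cross-entropy term, before any batch-level averaging. The helper names and the 0.1 probability are assumptions for illustration.

```python
import math

def balanced_weight(n_samples, n_classes, n_in_class):
    """weight[c] = n_samples / (n_classes × n_samples_in_class[c])"""
    return n_samples / (n_classes * n_in_class)

def weighted_nll(prob_of_true_class, weight):
    """Per-example weighted cross-entropy term: -weight × log(p_true)."""
    return -weight * math.log(prob_of_true_class)

w_a = balanced_weight(1200, 18, 200)
w_b = balanced_weight(1200, 18, 80)
print(round(w_a, 2), round(w_b, 2), round(w_b / w_a, 2))

# Same mistake (true class given probability 0.1), different penalties
loss_a = weighted_nll(0.1, w_a)
loss_b = weighted_nll(0.1, w_b)
print(round(loss_b / loss_a, 2))
```

The ratio between the two weights depends only on the class counts (200/80), so it stays at 2.5 no matter how many total samples or classes there are.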


    Consequence of skipping this step:

    Without weighting, the model optimizes for overall accuracy by focusing on frequent classes. Rare classes get ignored, reducing macro F1 by 5-10%.





    Putting It All Together: Training Configuration





    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir='./models/biobert-lora-improved',
        eval_strategy='epoch',
        save_strategy='epoch',             # Must match eval_strategy for load_best_model_at_end
        learning_rate=2e-4,                # Higher LR for LoRA (10× standard fine-tuning)
        per_device_train_batch_size=16,
        num_train_epochs=15,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model='macro_f1',  # compute_metrics must return a 'macro_f1' key
        fp16=True,                         # Mixed precision for faster training
        warmup_ratio=0.1,
    )

    trainer = WeightedTrainer(
        class_weights=class_weights_tensor,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()







    Key hyperparameters explained:
    • Learning rate (2e-4): Higher than typical fine-tuning (2e-5) because LoRA adapters can handle larger updates
    • Batch size (16): Balanced between GPU memory and gradient quality
    • Epochs (15): Sufficient for convergence without overfitting
    • FP16: Reduces memory usage and speeds up training by ~2×





    Results: From Failure to Success

    Performance Metrics

    Metric            Value
    Accuracy          94.4%
    Macro F1          0.944
    Weighted F1       0.945
    Macro Precision   0.944
    Macro Recall      0.950


    Comparison to naive approach:


    Approach                                           Macro F1   Change
    Naive (full docs, all classes, full fine-tuning)   0.023      Baseline
    Improved (evidence + LoRA + augmentation)          0.944      +4,000%


    Per-Class Performance

    The model achieves balanced performance across all 18 classes:






                    precision    recall  f1-score   support

           E11.9         0.95      0.95      0.95        20
             I10         0.93      0.97      0.95        15
           E78.5         0.94      0.94      0.94        18
             ...

       macro avg         0.94      0.95      0.94       240
    weighted avg         0.95      0.94      0.95       240







    No class falls below 90% F1—demonstrating that our techniques successfully handle the remaining imbalance.





    What We've Learned: Key Takeaways

    Do This

    1. Extract focused context: Don't train on full documents when evidence spans are available
    2. Filter aggressively: Better to excel at 18 codes than fail at 158
    3. Augment intelligently: Back-translation preserves semantics while adding variation
    4. Use parameter-efficient methods: LoRA prevents overfitting on small datasets
    5. Weight your loss: Account for remaining class imbalance


    Avoid This

    1. Training on full documents: Dilutes diagnostic signals
    2. Including rare classes: <10 examples per class is unlearnable
    3. Mixing augmented data into validation: Creates overly optimistic metrics
    4. Full fine-tuning: Causes catastrophic overfitting on small datasets
    5. Ignoring class imbalance: Model will focus only on frequent classes





    Limitations and Future Work

    Current Limitations

    1. Limited Code Coverage

    We only handle 18 out of 158 codes. For production use, you'd need:
    • More training data for rare codes
    • Hierarchical classification (predict ICD chapter first, then specific code)
    • Hybrid approach with commercial APIs


    2. Evidence Dependency

    Our approach requires supporting evidence annotations. For new data without annotations:
    • Use attention weights to identify key spans
    • Employ named entity recognition (NER) to extract diagnoses
    • Apply the trained model to full documents (with performance degradation)
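For the last option, one common workaround (not something implemented in this project) is to slide a window over the full document and keep the highest-scoring chunk, so the model still sees inputs close to the span lengths it was trained on. The sketch below assumes character-based windows of 200 with a stride of 100. `sliding_windows`, `best_window`, and the keyword scorer standing in for the classifier are all hypothetical.

```python
def sliding_windows(text, size=200, stride=100):
    """Yield (start, chunk) pairs covering the text with overlapping windows."""
    if len(text) <= size:
        yield 0, text
        return
    start = 0
    while True:
        yield start, text[start:start + size]
        if start + size >= len(text):
            break
        start += stride

def best_window(text, score_fn, size=200, stride=100):
    """Return the (score, start, chunk) of the highest-scoring window."""
    return max((score_fn(chunk), start, chunk)
               for start, chunk in sliding_windows(text, size, stride))

# Hypothetical scorer standing in for the classifier's top-class probability
def keyword_score(chunk):
    return 1.0 if "hypertension" in chunk else 0.0

document = ("Vitals stable. " * 30
            + "Diagnosis: Essential (primary) hypertension."
            + " Follow up." * 20)
score, start, chunk = best_window(document, keyword_score)
print(score, start)
```

Overlapping windows reduce the chance that a diagnostic phrase is cut in half at a chunk boundary, at the cost of roughly twice as many forward passes.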


    3. Multi-Label Simplification

    We converted multi-label to single-label (one example per code). True multi-label classification would:
    • Predict all relevant codes simultaneously
    • Model code co-occurrence patterns
    • Better reflect real clinical scenarios
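To make the difference concrete, here is a sketch of the multi-label formulation: targets become one multi-hot vector per document, and decoding keeps every code whose independent score clears a threshold rather than taking a single argmax. The helper names, the document IDs, the scores, and the 0.5 threshold are assumptions for illustration.

```python
def to_multi_hot(doc_codes, code_index):
    """Turn each document's set of codes into a 0/1 vector over the label space."""
    vectors = {}
    for doc_id, codes in doc_codes.items():
        vec = [0] * len(code_index)
        for code in codes:
            vec[code_index[code]] = 1
        vectors[doc_id] = vec
    return vectors

def decode(scores, code_index, threshold=0.5):
    """Keep every code whose independent (sigmoid-style) score clears the threshold."""
    index_to_code = {i: c for c, i in code_index.items()}
    return sorted(index_to_code[i] for i, s in enumerate(scores) if s >= threshold)

code_index = {"E11.9": 0, "E78.5": 1, "I10": 2}
doc_codes = {"doc-1": {"I10", "E11.9"}, "doc-2": {"E78.5"}}  # hypothetical documents
targets = to_multi_hot(doc_codes, code_index)
predicted = decode([0.91, 0.12, 0.77], code_index)
print(targets, predicted)
```

On the model side, this corresponds to replacing the softmax and cross-entropy pairing with per-label sigmoids and a binary cross-entropy loss.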

    Next Steps

    1. Hierarchical Classification: Leverage ICD-10's tree structure (Chapter → Category → Code)
    2. Full Augmentation: Implement FR and ES translations for 4× data expansion
    3. Ensemble Methods: Combine multiple augmented models with different random seeds
    4. Multi-Label Extension: Train on documents with all codes simultaneously
    5. Transfer Learning: Pre-train on medical entity recognition before ICD-10 classification
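As a starting point for the hierarchical idea, the levels can be derived from the code string itself. The sketch below splits a code into chapter, three-character category, and full code. The `icd10_levels` helper is hypothetical, and the chapter map is deliberately partial, covering only the two chapters that contain the example codes from the per-class results.

```python
def icd10_levels(code, chapter_ranges):
    """Split a code into (chapter, category, full code) for coarse-to-fine prediction."""
    category = code.split(".")[0]          # e.g. "E11" from "E11.9"
    for chapter, (low, high) in chapter_ranges.items():
        if low <= category <= high:        # same-length codes compare lexicographically
            return chapter, category, code
    return None, category, code

# Partial, illustrative map covering only the chapters of the example codes
chapter_ranges = {
    "Endocrine, nutritional and metabolic diseases": ("E00", "E89"),
    "Diseases of the circulatory system": ("I00", "I99"),
}
for code in ("E11.9", "E78.5", "I10"):
    print(icd10_levels(code, chapter_ranges))
```

A hierarchical model could predict the chapter first and restrict the fine-grained classifier to codes within it, which lets rare codes share training signal with their more common siblings.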



    Coming Up in Part 2: AWS Comprehend Medical

    In the next article, we'll explore a completely different approach:
    • Zero-shot inference using AWS's pre-trained medical NLP service
    • Entity trait filtering to handle negations, hypotheticals, and family history
    • Multi-label evaluation at the document level
    • Head-to-head comparison with our BioBERT model
    • Hybrid strategy combining both approaches for optimal results


    We'll discover that AWS Comprehend Medical achieves 27% macro F1 on all 158 codes (vs. our 94% on 18 codes)—a fascinating trade-off between coverage and accuracy.

    Try It Yourself

    All code is available in the GitHub repository:


    🔗 clinical-nlp-claims-processing


    To run this notebook:






    # Clone the repository
    git clone https://github.com/alexretana/clinic...processing.git
    cd clinical-nlp-claims-processing

    # Install dependencies (using uv)
    curl -LsSf https://astral.sh/uv/install.sh | sh
    uv sync

    # Launch Jupyter
    source .venv/bin/activate # On Windows: .venv\Scripts\activate
    jupyter lab

    # Open notebooks/01_BioBERT_Fine-Tuning_NLP.ipynb







    Hardware requirements:
    • GPU with 8GB+ VRAM (RTX 3060, V100, A100) for reasonable training times
    • 16GB+ system RAM
    • Training takes ~2-4 hours on GPU, much longer on CPU





    Conclusion

    Building production-quality medical NLP systems requires more than throwing data at a pre-trained model. We combined five techniques:
    • Evidence-focused training
    • Strategic label filtering
    • Back-translation augmentation
    • LoRA parameter-efficient fine-tuning
    • Class-weighted loss


    Together, they transformed a failing system (2.3% macro F1 across all 158 codes) into one that reaches 94.4% macro F1 on the 18 most frequent codes, good enough for real-world deployment with human oversight.


    The techniques we've covered apply far beyond medical coding:
    • Legal document analysis (case law classification)
    • Scientific literature mining (research topic categorization)
    • Customer support (ticket routing and classification)
    • Content moderation (policy violation detection)


    Anywhere you face limited training data and class imbalance, this toolkit will serve you well.


    Next time, we'll see how AWS Comprehend Medical tackles the same problem without any training data at all—and explore when each approach makes sense.





    What challenges have you faced when training NLP models on limited data? Share your experiences in the comments! And if you found this helpful, follow me for Part 2 where we dive into AWS Comprehend Medical.






    Tags: #machinelearning #nlp #healthcare #python #biobert #transformers #medicalcoding #datascience



