Add MedFlamingo Multimodal VQA Pipeline #954

Open
zarmeen2 wants to merge 18 commits into sunlabuiuc:master from zDoda:master

zarmeen2 commented Apr 7, 2026

Contributors:

Contribution Type: Full pipeline

Overview

This PR implements a full end-to-end Medical Visual Question Answering (VQA) pipeline based on the MedFlamingo architecture. It adds a new dataset loader, task definition, and model that integrate with PyHealth's existing Trainer, SampleDataset, and BaseModel conventions. The model also exposes a standalone generate() interface for free-text, few-shot VQA.


Files Changed

PyHealth/
├── docs/api/
│   ├── datasets/
│   │   └── pyhealth.datasets.VQARADDataset.rst  # new
│   ├── models/
│   │   └── pyhealth.models.MedFlamingo.rst       # new
│   ├── tasks/
│   │   └── pyhealth.tasks.MedicalVQATask.rst     # new
│   ├── datasets.rst                              # updated
│   ├── models.rst                                # updated
│   └── tasks.rst                                 # updated
├── examples/
│   └── vqarad_medvqa_medflamingo.py              # new: end-to-end example + ablation study
├── pyhealth/
│   ├── datasets/
│   │   ├── configs/
│   │   │   └── vqarad.yaml                       # new: schema config
│   │   ├── __init__.py                           # updated: exports VQARADDataset
│   │   └── vqarad.py                             # new: VQA-RAD dataset loader
│   ├── models/
│   │   ├── __init__.py                           # updated: exports MedFlamingo, MedFlamingoLayer
│   │   └── medflamingo.py                        # new: PerceiverResampler, MedFlamingoLayer, MedFlamingo
│   └── tasks/
│       ├── __init__.py                           # updated: exports MedicalVQATask
│       └── medical_vqa_task.py                   # new: MedicalVQATask
├── tests/core/
│   ├── test_medflamingo.py                       # updated: model unit tests with CPU stubs
│   ├── test_medical_vqa_task.py                  # new: isolated task unit tests
│   └── test_vqarad.py                            # new: dataset integration tests
├── pixi.lock                                     # updated: dependency lock file
└── test_medflamingo.py                           # new: root-level smoke test

New Components

VQARADDataset — Loads the VQA-RAD dataset from its raw JSON, normalizes field name variants, and writes a flat CSV for PyHealth's base dataset pipeline. Overrides set_task() to auto-inject ImageProcessor(mode="RGB", image_size=224).
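A minimal sketch of the normalize-then-flatten step described above, using only the standard library. The alias table, the `normalize_record` and `records_to_csv` helpers, and the sample records are hypothetical illustrations, not the code in `vqarad.py`.

```python
import csv
import io
import json

# Hypothetical alias table: canonical column -> raw key variants.
# The real loader's variants are not listed in this PR description.
FIELD_ALIASES = {
    "image": ("image", "image_name", "img"),
    "question": ("question", "sent"),
    "answer": ("answer", "label"),
}


def normalize_record(raw):
    """Map one raw JSON record onto canonical keys; missing fields become ''."""
    row = {}
    for canonical, variants in FIELD_ALIASES.items():
        value = next((raw[key] for key in variants if key in raw), "")
        row[canonical] = str(value).strip()
    return row


def records_to_csv(records, handle):
    """Write normalized records as a flat CSV with one row per question."""
    writer = csv.DictWriter(handle, fieldnames=list(FIELD_ALIASES))
    writer.writeheader()
    for raw in records:
        writer.writerow(normalize_record(raw))


# Two made-up records that use different key spellings for the same fields.
raw_json = json.dumps([
    {"image_name": "a.jpg", "question": "Is there a fracture?", "answer": "no"},
    {"image": "b.jpg", "question": "Which organ is shown?", "label": "Lung "},
])

buffer = io.StringIO()
records_to_csv(json.loads(raw_json), buffer)
csv_text = buffer.getvalue()
```

Writing a single flat table keeps the downstream schema config simple: every row carries the same three columns regardless of how the raw JSON spelled them.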

MedicalVQATask — Converts patient VQA-RAD events into {image, question, answer} sample dicts. Declares input_schema: {image: "image", question: "text"} and output_schema: {answer: "multiclass"}, framing VQA as closed-set classification for the standard training loop.
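The schema and sample shape can be illustrated with a standalone stand-in. This sketch deliberately does not subclass PyHealth's task base class; the event dictionary layout (`image_path`, `question`, `answer`), the `patient_id` key, and the lowercasing of answers are assumptions made for the example.

```python
class MedicalVQATaskSketch:
    """Standalone illustration of the task contract; not the PR's class."""

    task_name = "MedicalVQATaskSketch"
    # Schemas as declared in the PR description.
    input_schema = {"image": "image", "question": "text"}
    output_schema = {"answer": "multiclass"}

    def __call__(self, patient_id, events):
        """Turn a patient's events into {image, question, answer} samples."""
        samples = []
        for event in events:
            image = event.get("image_path")
            question = event.get("question")
            answer = event.get("answer")
            # Skip incomplete events rather than emitting partial samples.
            if not image or not question or answer is None:
                continue
            samples.append({
                "patient_id": patient_id,
                "image": image,
                "question": question,
                # Assumption: answers are case-folded so "Yes" and "yes"
                # map to the same class label.
                "answer": str(answer).strip().lower(),
            })
        return samples


events = [
    {"image_path": "imgs/synpic1.jpg", "question": "Is the heart enlarged?", "answer": "Yes"},
    {"image_path": "", "question": "Which plane is shown?", "answer": "axial"},
]
samples = MedicalVQATaskSketch()("patient-0", events)
```

Because the output schema is `multiclass`, each distinct answer string becomes one class, which is why consistent answer normalization matters for the closed-set framing.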

MedFlamingo — Integrates a frozen CLIP ViT-L/14 vision encoder and a frozen LLM (default: OPT-6.7B) with trainable gated cross-attention layers (MedFlamingoLayer) inserted every N LLM layers. A PerceiverResampler compresses variable-length CLIP patch tokens to a fixed 64-token sequence before each cross-attention block. Only the cross-attention layers and classification head are trained.
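To make "inserted every N LLM layers" concrete, the sketch below computes which decoder layers would receive a cross-attention block. The placement rule (a block at each 0-based layer index `i` where `(i + 1) % N == 0`) is an assumption borrowed from common Flamingo-style implementations; the PR description does not state the exact rule, and `cross_attention_layer_indices` is a hypothetical helper.

```python
def cross_attention_layer_indices(num_llm_layers, every_n):
    """Return 0-based indices of decoder layers paired with a cross-attention block.

    Assumed rule: one block at every `every_n`-th layer, counting from 1.
    """
    if every_n < 1:
        raise ValueError("every_n must be a positive integer")
    return [i for i in range(num_llm_layers) if (i + 1) % every_n == 0]


# OPT-6.7B has 32 decoder layers; every_n=4 is an arbitrary example value.
indices = cross_attention_layer_indices(num_llm_layers=32, every_n=4)
num_trainable_blocks = len(indices)
```

Under this assumed rule, a larger N means fewer trainable blocks, which is the trade-off the `cross_attn_every_n_layers` ablation axis in the example script explores.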


Design Decisions

  • Zero-initialized gates: Both the attention gate and the FFN gate in each MedFlamingoLayer are initialized to zero, so at the start of training the model's outputs match those of the frozen LLM. This prevents unstable early updates and follows the original Flamingo paper.

  • Dual interface: forward() conforms to PyHealth's BaseModel contract for Trainer compatibility; generate() supports open-ended few-shot generation by passing visually-conditioned embeddings (inputs_embeds) directly to the LLM.

  • Testable by design: _init_vision_encoder() and _init_lang_model() are isolated methods overridden in TestableMedFlamingo to swap in CPU-only stubs, so tests run without any model downloads.
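The effect of the zero-initialized gates can be shown with plain Python. This is a simplified sketch that assumes the tanh-gated residual form from the Flamingo paper, `y = x + tanh(gate) * update`; `gated_residual` and the numeric values are illustrative and not taken from `medflamingo.py`.

```python
import math


def gated_residual(hidden, update, gate):
    """Add a gated update to the hidden state: hidden + tanh(gate) * update."""
    scale = math.tanh(gate)
    return [h + scale * u for h, u in zip(hidden, update)]


hidden = [0.5, -1.0, 2.0]      # stand-in for the frozen LLM's hidden state
update = [10.0, 10.0, 10.0]    # stand-in for a cross-attention output

# With the gate at zero, tanh(0) == 0, so the update is ignored and the
# layer passes the frozen LLM's hidden state through unchanged.
at_init = gated_residual(hidden, update, gate=0.0)

# Once training moves the gate away from zero, visual information mixes in.
after_some_training = gated_residual(hidden, update, gate=0.1)
```

The same reasoning applies to both gates in a layer: while a gate is zero, its branch contributes nothing, so the untrained cross-attention weights cannot disturb the pretrained language model.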


Tests

The test suite is split into three focused files. test_medflamingo.py tests the model in isolation using lightweight CPU stubs — covering forward(), generate() (single and few-shot), inputs_embeds usage verification, and gradient flow. test_medical_vqa_task.py tests MedicalVQATask schema, sample emission, and edge cases with dummy objects. test_vqarad.py runs VQARADDataset end-to-end against a synthetic fixture, validating CSV output, patient/event parsing, and processed sample shapes.
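The stub-swapping approach used by the model tests can be sketched as an overridable-hook pattern. The classes below are hypothetical miniatures: only the hook names `_init_vision_encoder()` and `_init_lang_model()` come from the PR, while `ModelSketch`, `TestableModelSketch`, `describe()`, and the stub behavior are invented for illustration.

```python
class ModelSketch:
    """Miniature model whose heavy components are built in overridable hooks."""

    def __init__(self):
        self.vision_encoder = self._init_vision_encoder()
        self.lang_model = self._init_lang_model()

    def _init_vision_encoder(self):
        # In a real model this is where pretrained weights would be loaded.
        raise RuntimeError("would download pretrained vision weights")

    def _init_lang_model(self):
        raise RuntimeError("would download pretrained LLM weights")

    def describe(self, image, question):
        """Run the (stubbed or real) components in sequence."""
        features = self.vision_encoder(image)
        return self.lang_model(features, question)


class TestableModelSketch(ModelSketch):
    """Test subclass that swaps in cheap CPU-only stubs via the hooks."""

    def _init_vision_encoder(self):
        # Stub: one feature equal to the number of pixels.
        return lambda image: [float(len(image))]

    def _init_lang_model(self):
        # Stub: deterministic string so assertions are easy to write.
        return lambda features, question: f"{question} -> {features[0]:.0f} px"


result = TestableModelSketch().describe([0, 0, 0, 0], "size?")
```

Keeping component construction in dedicated methods means the test subclass overrides only those methods, so the rest of the model code runs unmodified under test.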

zarmeen2 and others added 18 commits March 26, 2026 20:25
Adding MedFlamingo Model Scaffold
- Fix MedFlamingo.generate() to pass inputs_embeds so xattn visual
  conditioning is actually applied (was passing raw input_ids)
- Fix MedFlamingo.__init__() to initialise self._fc = None when no
  dataset is supplied (prevents AttributeError in forward())
- VQARADDataset.prepare_metadata(): filter rows whose image file is
  missing from disk (14 OSF images never existed); logs a warning
- Remove duplicate VQARADDataset import in datasets/__init__.py
- Remove duplicate MedicalVQATask import in tasks/__init__.py
- medical_vqa_task.py: add module docstring, full Google-style class
  docstring, and __call__ docstring with Args / Returns / Example
- examples/vqarad_medvqa_medflamingo.py: full rewrite with three
  ablation axes (cross_attn_every_n_layers, num_resampler_tokens,
  freeze_vision), --ablation CLI flag, helper functions, usage docs
- tests/core/test_medflamingo.py: remove all TODO stubs; add isolated
  MedicalVQATask unit tests and test_generate_uses_inputs_embeds;
  fix Patient construction to use Polars DataFrame API

Contributors: Zarmeen Hasan (zarmeen2), Camdyn Zook (camdynz2)
…testing-suggestions

Feat/docs example repo clean up and testing suggestions
zarmeen2 changed the title from "Med-Flamingo Full Pipeline Implementation" to "Add MedFlamingo Multimodal VQA Pipeline" Apr 9, 2026
zarmeen2 marked this pull request as ready for review April 9, 2026 00:46

2 participants