Add MedFlamingo Multimodal VQA Pipeline #954
Open
zarmeen2 wants to merge 18 commits into sunlabuiuc:master from
Conversation
Adding MedFlamingo Model Scaffold
add MedFlamingo to models.rst
- Fix MedFlamingo.generate() to pass inputs_embeds so xattn visual conditioning is actually applied (was passing raw input_ids)
- Fix MedFlamingo.__init__() to initialise self._fc = None when no dataset is supplied (prevents AttributeError in forward())
- VQARADDataset.prepare_metadata(): filter rows whose image file is missing from disk (14 OSF images never existed); logs a warning
- Remove duplicate VQARADDataset import in datasets/__init__.py
- Remove duplicate MedicalVQATask import in tasks/__init__.py
- medical_vqa_task.py: add module docstring, full Google-style class docstring, and __call__ docstring with Args / Returns / Example
- examples/vqarad_medvqa_medflamingo.py: full rewrite with three ablation axes (cross_attn_every_n_layers, num_resampler_tokens, freeze_vision), --ablation CLI flag, helper functions, usage docs
- tests/core/test_medflamingo.py: remove all TODO stubs; add isolated MedicalVQATask unit tests and test_generate_uses_inputs_embeds; fix Patient construction to use Polars DataFrame API

Contributors: Zarmeen Hasan (zarmeen2), Camdyn Zook (camdynz2)
Feat/medflamingo full pipeline
…testing-suggestions Feat/docs example repo clean up and testing suggestions
…testing-suggestions test fixes
Contributors:
Contribution Type: Full pipeline
Overview
This PR implements a full end-to-end Medical Visual Question Answering (VQA) pipeline based on the MedFlamingo architecture. It adds a new dataset loader, task definition, and model that integrate with PyHealth's existing Trainer, SampleDataset, and BaseModel conventions. The model also exposes a standalone generate() interface for free-text, few-shot VQA.

Files Changed

New Components

VQARADDataset — Loads the VQA-RAD dataset from its raw JSON, normalizes field name variants, and writes a flat CSV for PyHealth's base dataset pipeline. Overrides set_task() to auto-inject ImageProcessor(mode="RGB", image_size=224).

MedicalVQATask — Converts patient VQA-RAD events into {image, question, answer} sample dicts. Declares input_schema: {image: "image", question: "text"} and output_schema: {answer: "multiclass"}, framing VQA as closed-set classification for the standard training loop.

MedFlamingo — Integrates a frozen CLIP ViT-L/14 vision encoder and a frozen LLM (default: OPT-6.7B) with trainable gated cross-attention layers (MedFlamingoLayer) inserted every N LLM layers. A PerceiverResampler compresses variable-length CLIP patch tokens to a fixed 64-token sequence before each cross-attention block. Only the cross-attention layers and classification head are trained.

Design Decisions
Zero-initialized gates: Both the attention gate and FFN gate in each MedFlamingoLayer are initialized to zero, so the model starts training as the frozen LLM, preventing unstable early updates (following the original Flamingo paper).

Dual interface: forward() conforms to PyHealth's BaseModel contract for Trainer compatibility; generate() supports open-ended few-shot generation by passing visually conditioned embeddings (inputs_embeds) directly to the LLM.

Testable by design: _init_vision_encoder() and _init_lang_model() are isolated methods overridden in TestableMedFlamingo to swap in CPU-only stubs, so tests run without any model downloads.

Tests
The test suite is split into three focused files.
test_medflamingo.py — Tests the model in isolation using lightweight CPU stubs, covering forward(), generate() (single and few-shot), inputs_embeds usage verification, and gradient flow.

test_medical_vqa_task.py — Tests MedicalVQATask schema, sample emission, and edge cases with dummy objects.

test_vqarad.py — Runs VQARADDataset end-to-end against a synthetic fixture, validating CSV output, patient/event parsing, and processed sample shapes.
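To make the MedicalVQATask contract described above concrete, here is a minimal sketch. The helper function and the event field names are illustrative stand-ins based on this PR description, not PyHealth's actual implementation.

```python
def vqarad_event_to_sample(event: dict) -> dict:
    """Hypothetical converter mirroring the described {image, question,
    answer} sample emission; the event keys are assumptions."""
    return {
        "image": event["image_path"],   # later resized/normalized by ImageProcessor
        "question": event["question"],  # free-text model input
        "answer": event["answer"],      # closed-set label (multiclass)
    }

# Schemas as declared by the task, per the PR description.
input_schema = {"image": "image", "question": "text"}
output_schema = {"answer": "multiclass"}
```

Framing the answer as a multiclass label is what lets the standard Trainer loop handle VQA without any generation-specific plumbing.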
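The zero-initialized gating described under Design Decisions can be sketched in plain Python: a scalar tanh gate on the residual branch, as in Flamingo. The class name and list-based "tensors" are illustrative only.

```python
import math

class GatedResidual:
    """Tanh-gated residual connection: with alpha initialized to 0 the
    branch contributes nothing, so the layer starts as an identity over
    the frozen LLM's hidden states."""

    def __init__(self):
        self.alpha = 0.0  # learnable scalar in the real model

    def __call__(self, hidden, branch_out):
        g = math.tanh(self.alpha)  # gate in (-1, 1); exactly 0 at init
        return [h + g * b for h, b in zip(hidden, branch_out)]
```

At initialization the cross-attention (or FFN) branch is fully suppressed, so early training cannot destabilize the pretrained LLM; as alpha is learned, visual information is blended in gradually.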
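The dual-interface decision, together with the generate() fix noted in the commit log, can be sketched as follows. The class is a toy stand-in: the fusion step is a placeholder for the real gated cross-attention, and the method signatures are simplified.

```python
class MedFlamingoSketch:
    """Toy skeleton: forward() serves the Trainer loop, while generate()
    must hand the LLM visually conditioned embeddings via inputs_embeds.
    Passing raw input_ids would silently drop the visual conditioning --
    the bug this PR fixes."""

    def __init__(self, lm):
        self.lm = lm  # stand-in for the frozen LLM

    def _fuse(self, text_embeds, visual_tokens):
        # Placeholder: the real model conditions through MedFlamingoLayer
        # cross-attention, not by adding features.
        bias = sum(visual_tokens)
        return [t + bias for t in text_embeds]

    def forward(self, text_embeds, visual_tokens):
        # Trainer-facing path: produce scores for the closed-set head.
        return {"logits": self._fuse(text_embeds, visual_tokens)}

    def generate(self, text_embeds, visual_tokens):
        # Open-ended path: pass inputs_embeds, not input_ids.
        fused = self._fuse(text_embeds, visual_tokens)
        return self.lm.generate(inputs_embeds=fused)
```

The key point is that both paths share the same conditioning step; only the output contract differs (classification dict vs. free-text generation).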
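The stub-override seam the tests rely on might look like this. The class bodies are hypothetical; only the method names _init_vision_encoder / _init_lang_model and the TestableMedFlamingo name come from the PR.

```python
class MedFlamingoBase:
    """Skeleton showing the override seam used for testing."""

    def __init__(self):
        # Model construction is routed through isolated factory methods
        # so subclasses can replace them without touching __init__.
        self.vision_encoder = self._init_vision_encoder()
        self.lang_model = self._init_lang_model()

    def _init_vision_encoder(self):
        raise NotImplementedError("real model downloads CLIP ViT-L/14")

    def _init_lang_model(self):
        raise NotImplementedError("real model downloads OPT-6.7B")


class TestableMedFlamingo(MedFlamingoBase):
    """CPU-only stand-in: no checkpoints are ever downloaded."""

    def _init_vision_encoder(self):
        return lambda image: [0.0] * 4  # fake patch features

    def _init_lang_model(self):
        return lambda embeds: "stub-answer"
```

Because the seam is at the factory-method level, everything else in the model (gating, resampling, the forward/generate paths) runs unmodified in tests.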