This directory conducts federated instruction tuning using FedProx with a pretrained ContactDoctor/Bio-Medical-Llama-3-8B model on a Medical dataset. We use Flower Datasets to download, partition, and preprocess the dataset. Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way, which allows users to perform the training on a single GPU.
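FedProx differs from plain FedAvg by adding a proximal term, (μ/2)·||w − w_global||², to each client's local loss so that local updates do not drift too far from the global model under heterogeneous data. A minimal pure-Python sketch of that penalty (the flat weight lists and the `mu` value are illustrative, not taken from this project's config):

```python
def fedprox_penalty(local_weights, global_weights, mu):
    """FedProx proximal term: (mu / 2) * ||w_local - w_global||^2.

    Added to each client's local training loss so local updates stay
    close to the global model received at the start of the round.
    """
    return 0.5 * mu * sum((w - g) ** 2 for w, g in zip(local_weights, global_weights))


# Illustrative values: two scalar "weights", proximal coefficient mu = 0.1.
penalty = fedprox_penalty([1.0, 2.0], [0.0, 0.0], mu=0.1)
print(penalty)  # 0.5 * 0.1 * (1 + 4) = 0.25
```

In the real strategy this term is added to the client's training loss each step; the aggregation on the server side remains a weighted average of client models.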
The fine-tuning results have been submitted as a PEFT adapter and can be accessed here:
This experiment performs federated LLM fine-tuning with LoRA using the 🤗PEFT library. The clients' models are aggregated with the FedProx strategy.
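LoRA keeps the base weights frozen and trains only a rank-r update B·A per adapted matrix, so a d_out × d_in layer contributes just r·(d_in + d_out) trainable parameters. A quick sketch of the savings, using hidden size 4096 (the Llama-3-8B attention projection size) and the rank r = 8 used in this setup for illustration:

```python
def lora_param_counts(d_in: int, d_out: int, r: int):
    """Trainable parameters for a full update vs. a rank-r LoRA adapter
    (A is r x d_in, B is d_out x r; the frozen base weight is untouched)."""
    full = d_in * d_out
    lora = r * (d_in + d_out)
    return full, lora


full, lora = lora_param_counts(4096, 4096, r=8)
print(full, lora, f"{lora / full:.2%}")  # 16777216 65536 0.39%
```

At rank 8 the adapter for such a layer is under 0.4% of the size of a full update, which is what makes single-GPU federated fine-tuning of an 8B model practical.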
For the Bio-Medical-Llama-3-8B model, we adopted the following fine-tuning methodology:
- Precision: BF16 for model weights with TF32 for computation
- Quantization: 4-bit quantization to reduce memory footprint
- Gradient Checkpointing: Enabled to save memory during backpropagation
- Optimizer: Paged AdamW 8-bit
- LoRA Configuration:
  - Rank (r): 8
  - Alpha: 32
- Batch Size: 16 per device
- Gradient Accumulation: 1 step
- Learning Rate: Constant with warmup
  - Max: 1e-5
  - Min: 1e-6
- Training Schedule:
  - Warmup Steps: 2
  - Max Steps: 6
  - Epochs: 3
- Max Gradient Norm: 1.0
- Rounds: 10 server rounds
- Client Participation: 15% of clients selected per round
- Federation Type: Local simulation with 20 supernodes
- Resources per Client: 6 CPUs, 1 GPU
- Dataset: Medical Meadow Medical Flashcards
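The federation-level settings above typically map onto `pyproject.toml` entries like the following sketch. The key names here are illustrative, not copied from this repository; the authoritative values live in this project's `pyproject.toml`:

```toml
# Hypothetical sketch -- key names are illustrative, not copied from this repo.
[tool.flwr.app.config]
num-server-rounds = 10
train.save-every-round = 5
train.learning-rate-max = 1e-5
train.learning-rate-min = 1e-6

[tool.flwr.federations.local-simulation]
options.num-supernodes = 20
options.backend.client-resources.num-cpus = 6
options.backend.client-resources.num-gpus = 1.0
```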
When bf16 and tf32 are enabled, model weights are stored in bfloat16 while matrix multiplications run on TF32 tensor cores; gradients are likewise computed in bf16 and accumulated in full 32-bit precision for the optimizer updates.
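What bf16 trades away is mantissa precision: it keeps fp32's 8 exponent bits (so the full dynamic range) but only 7 mantissa bits, which is why 32-bit accumulation is still needed for updates. The sketch below reduces a Python float to bf16 precision by truncating the fp32 encoding (truncation is used here for simplicity; real hardware usually rounds to nearest):

```python
import struct

def to_bf16(x: float) -> float:
    """Reduce a float to bfloat16 precision by zeroing the low 16 bits
    of its fp32 encoding (bf16 = fp32 sign + 8 exponent + 7 mantissa bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


print(to_bf16(1.0))                # 1.0 (exactly representable)
print(to_bf16(3.141592653589793))  # 3.140625 -- only ~3 significant digits survive
```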
Hardware: 1× NVIDIA A100 GPU
- pubmedqa: 0.6580
- medqa: 0.6031
- medmcqa: 0.6834
- careqa: 0.5367
- average: 0.6203
741.89 MB
Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:

```shell
pip install -e .
```

The dataset is divided into 20 partitions in an IID fashion; one partition is assigned to each ClientApp.
We randomly sample a fraction (0.15) of the total nodes to participate in each round, for a total of 10 rounds.
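The two steps above — IID partitioning into 20 shards and sampling 15% of nodes per round (with 20 supernodes, that is 3 clients per round) — can be sketched in plain Python. Function names here are illustrative; the actual partitioning is handled by Flower Datasets and the sampling by Flower's strategy:

```python
import random

def iid_partition(num_samples: int, num_partitions: int, seed: int = 42):
    """Shuffle sample indices once, then split them into equal-size shards --
    an IID partition, since every shard is a uniform random subset."""
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return [indices[i::num_partitions] for i in range(num_partitions)]

def sample_clients(num_nodes: int, fraction: float, rng: random.Random):
    """Pick a random fraction of the nodes for one federated round."""
    num_selected = max(1, round(num_nodes * fraction))
    return rng.sample(range(num_nodes), num_selected)


shards = iid_partition(num_samples=1000, num_partitions=20)
print(len(shards), len(shards[0]))  # 20 shards of 50 samples each

rng = random.Random(0)
print(sample_clients(20, 0.15, rng))  # 3 of the 20 supernodes
```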
All settings are defined in `pyproject.toml`.
Important
Please note that `[tool.flwr.app.config.static]` and `options.num-supernodes` under `[tool.flwr.federations.local-simulation]` must not be modified, to ensure a fair comparison, if you plan to participate in the LLM leaderboard.
Run the challenge with the default config values. The configs are defined in the `[tool.flwr.app.config]` entry of `pyproject.toml` and are loaded automatically.

```shell
flwr run
```

By default, the global PEFT model checkpoints are saved every 5 rounds after aggregation on the server side; this interval can be changed with `train.save-every-round` under the `[tool.flwr.app.config]` entry in `pyproject.toml`.
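The round-based saving condition amounts to a simple modulus check on the server round. A sketch (illustrative — the actual hook lives in the server-side strategy):

```python
def should_save_checkpoint(server_round: int, save_every_round: int = 5) -> bool:
    """Save the aggregated global PEFT adapter every `save_every_round` rounds."""
    return server_round % save_every_round == 0


saved_rounds = [r for r in range(1, 11) if should_save_checkpoint(r)]
print(saved_rounds)  # [5, 10] -- with 10 rounds and the default interval of 5
```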
This code is based on the following repositories:
We thank the authors for their valuable contributions to the medical LLM fine-tuning community.