h-jia/BimedLLama

FlowerTune LLM on Medical Dataset

This directory performs federated instruction tuning with FedProx using the pretrained ContactDoctor/Bio-Medical-Llama-3-8B model on a medical dataset. We use Flower Datasets to download, partition, and preprocess the dataset, and Flower's Simulation Engine to simulate the federated LLM fine-tuning process, which allows the training to run on a single GPU.

PEFT Adapter

The fine-tuning results have been submitted as a PEFT adapter and can be accessed here:

Methodology

This experiment performs federated LLM fine-tuning with LoRA using the 🤗PEFT library. The clients' models are aggregated with the FedProx strategy.
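FedProx differs from plain FedAvg by adding a proximal term to each client's local objective, which penalizes drift from the global weights. A minimal sketch of that local loss (the proximal coefficient `mu` below is an illustrative value, not taken from this repository's config):

```python
import numpy as np

def fedprox_local_loss(task_loss, local_w, global_w, mu=0.01):
    """Local FedProx objective: task loss plus a proximal penalty
    (mu/2) * ||w - w_global||^2 that discourages client drift."""
    prox = 0.5 * mu * float(np.sum((local_w - global_w) ** 2))
    return task_loss + prox

# When the local weights equal the global weights, the penalty vanishes.
w = np.ones(4)
print(fedprox_local_loss(1.0, w, w))                 # 1.0 (no penalty)
print(fedprox_local_loss(1.0, w + 1.0, w, mu=2.0))   # 1.0 + 0.5*2*4 = 5.0
```

The penalty keeps each client's update close to the global model, which stabilizes aggregation when client data distributions differ.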

Bio-Medical-Llama-3-8B

For the Bio-Medical-Llama-3-8B model, we adopted the following fine-tuning methodology:

  • Precision: BF16 for model weights with TF32 for computation
  • Quantization: 4-bit quantization to reduce memory footprint
  • Gradient Checkpointing: Enabled to save memory during backpropagation
  • Optimizer: Paged AdamW 8-bit
  • LoRA Configuration:
    • Rank (r): 8
    • Alpha: 32
  • Batch Size: 16 per device
  • Gradient Accumulation: 1 step
  • Learning Rate: Constant with warmup
    • Max: 1e-5
    • Min: 1e-6
  • Training Schedule:
    • Warmup Steps: 2
    • Max Steps: 6
    • Epochs: 3
    • Max Gradient Norm: 1.0
  • Rounds: 10 server rounds
  • Client Participation: 15% of clients selected per round
  • Federation Type: Local simulation with 20 supernodes
  • Resources per Client: 6 CPUs, 1 GPU
  • Dataset: Medical Meadow Medical Flashcards
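With rank r=8 and alpha=32, the LoRA update is applied to a frozen base weight as W + (alpha/r)·BA, i.e. with a scaling factor of 4. A small NumPy sketch of the arithmetic (the matrix dimensions are illustrative toy sizes, not the model's actual projection shapes):

```python
import numpy as np

d_out, d_in, r, alpha = 64, 32, 8, 32   # toy dims; r and alpha match the config

W = np.random.randn(d_out, d_in)        # frozen base weight
A = np.random.randn(r, d_in) * 0.01     # trainable low-rank "down" matrix
B = np.zeros((d_out, r))                # "up" matrix, zero-initialized

# Effective weight at inference: base plus scaled low-rank update.
scaling = alpha / r                     # 32 / 8 = 4.0
W_eff = W + scaling * (B @ A)
```

Because B is zero-initialized, W_eff equals W at the start of training, so fine-tuning begins from the unchanged pretrained model; only A and B (a tiny fraction of the 8B parameters) are trained and exchanged between clients.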

When bf16 and tf32 are enabled, model weights are stored in bf16, matrix multiplications can use TF32 tensor cores, and optimizer updates are accumulated in full 32-bit precision.

Evaluation Results

Hardware: 1× NVIDIA A100 GPU

  • pubmedqa: 0.6580
  • medqa: 0.6031
  • medmcqa: 0.6834
  • careqa: 0.5367
  • average: 0.6203

Communication Budget

741.89 MB
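As a rough plausibility check (the per-transfer figure below is back-of-the-envelope, not a measured value): with 10 rounds, 3 clients sampled per round (15% of 20), and each sampled client both downloading and uploading the PEFT adapter once per round, the reported 741.89 MB corresponds to about 12.36 MB per adapter transfer:

```python
def total_comm_mb(adapter_mb, rounds, clients_per_round):
    # Each sampled client receives the global adapter and sends its update back.
    return rounds * clients_per_round * 2 * adapter_mb

clients_per_round = int(0.15 * 20)          # 3 clients per round
transfers = 10 * clients_per_round * 2      # 60 adapter transfers in total
print(741.89 / transfers)                   # ~12.36 MB per transfer
```

This is consistent with only the small LoRA adapter, not the full 8B-parameter model, being communicated each round.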

Environments setup

Project dependencies are defined in pyproject.toml. Install them in an activated Python environment with:

pip install -e .

Experimental setup

The dataset is divided into 20 partitions in an IID fashion, and each partition is assigned to a ClientApp. We randomly sample a fraction (0.15) of the total nodes to participate in each round, for a total of 10 rounds. All settings are defined in pyproject.toml.
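The IID split itself is handled by Flower Datasets, but conceptually it amounts to shuffling example indices and slicing them into 20 near-equal partitions. A library-free sketch of that idea (the dataset size below is a hypothetical placeholder, and this is an illustration rather than the exact Flower Datasets implementation):

```python
import random

def iid_partition(num_examples, num_partitions, seed=42):
    """Shuffle example indices and split them into near-equal IID partitions."""
    rng = random.Random(seed)
    indices = list(range(num_examples))
    rng.shuffle(indices)
    # Round-robin slicing keeps partition sizes within one example of each other.
    return [indices[i::num_partitions] for i in range(num_partitions)]

parts = iid_partition(num_examples=10_000, num_partitions=20)  # toy dataset size
```

Because the split is IID, every partition is statistically similar, which makes this a relatively benign federated setting compared to non-IID label or quantity skew.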

Important

Please note that [tool.flwr.app.config.static] and options.num-supernodes under [tool.flwr.federations.local-simulation] must not be modified if you plan to participate in the LLM leaderboard, to ensure a fair competition.

Running the challenge

Run the challenge with the default config values. The configs are defined in the [tool.flwr.app.config] section of pyproject.toml and are loaded automatically.

flwr run

Model saving

The global PEFT model checkpoints are saved every 5 rounds after aggregation on the server side by default. This interval can be changed via train.save-every-round under the [tool.flwr.app.config] section of pyproject.toml.
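Server-side checkpointing can be gated on the round number; a minimal sketch of that logic (the function and parameter names are illustrative, not the repository's actual code):

```python
def should_save(server_round, save_every=5, total_rounds=10):
    """Save the aggregated PEFT checkpoint every `save_every` rounds,
    and always save after the final round."""
    return server_round % save_every == 0 or server_round == total_rounds

print([r for r in range(1, 11) if should_save(r)])  # [5, 10]
```

With the defaults above (save-every-round=5, 10 server rounds), checkpoints would be written after rounds 5 and 10.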

Acknowledgements

This code is based on the following repositories:

We thank the authors for their valuable contributions to the medical LLM fine-tuning community.
