
feat: add use_flex_decoding parameter for non-CUDA devices#326

Open
Mr-Neutr0n wants to merge 1 commit into vikhyat:main from Mr-Neutr0n:fix/flex-decoding-param

Conversation

@Mr-Neutr0n

Summary

Add a use_flex_decoding constructor parameter to MoondreamModel and HfConfig to allow users to disable flex decoding on non-CUDA devices (e.g., MPS on Mac).

Problem

Moondream 3 uses flex decoding, which requires CUDA. Users on Mac with MPS cannot run the model because create_block_mask is CUDA-only. Currently, users must manually patch the source code to set use_flex_decoding = False.
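The gating this PR enables can be sketched as a tiny helper (the name `should_use_flex_decoding` is hypothetical; in the actual change the logic lives inside `MoondreamModel`):

```python
def should_use_flex_decoding(device_type: str, use_flex_decoding: bool = True) -> bool:
    # create_block_mask (flex attention) currently requires CUDA, so flex
    # decoding must stay off on MPS/CPU even if the flag is left at its default.
    return use_flex_decoding and device_type == "cuda"
```

With such a check, an MPS user only has to flip the flag instead of editing the source.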

Solution

Add a use_flex_decoding parameter that defaults to True (backward compatible) but can be set to False for non-CUDA devices.

Changes

  • moondream/torch/moondream.py: Add use_flex_decoding parameter to MoondreamModel.__init__
  • moondream/torch/hf_moondream.py: Add use_flex_decoding to HfConfig and pass through to MoondreamModel
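A minimal sketch of the constructor change (simplified; the real `MoondreamModel.__init__` takes the model config and other arguments not shown here):

```python
class MoondreamModel:
    # Simplified stand-in for moondream/torch/moondream.py.
    def __init__(self, use_flex_decoding: bool = True):
        # Defaulting to True keeps existing CUDA behavior unchanged;
        # MPS/CPU users pass False to skip the CUDA-only block-mask path.
        self.use_flex_decoding = use_flex_decoding
```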

Usage

```python
# For HuggingFace loading on Mac with MPS
from transformers import AutoModelForCausalLM, AutoConfig

config = AutoConfig.from_pretrained('moondream/moondream3', use_flex_decoding=False)
model = AutoModelForCausalLM.from_pretrained('moondream/moondream3', config=config)
model.to('mps')
```

Test plan

  • Load model with use_flex_decoding=True on CUDA (default behavior unchanged)
  • Load model with use_flex_decoding=False on MPS (Mac)
  • Verify inference works correctly with flex decoding disabled
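The first two points can be smoke-tested without GPU hardware by stubbing the config (`FakeHfConfig` below is a stand-in, not the real `HfConfig`):

```python
class FakeHfConfig:
    # Stand-in for HfConfig, just to exercise the new option's default.
    def __init__(self, use_flex_decoding: bool = True):
        self.use_flex_decoding = use_flex_decoding

def test_default_unchanged():
    # Omitting the flag must preserve current CUDA behavior.
    assert FakeHfConfig().use_flex_decoding is True

def test_disable_for_mps():
    # Passing False is what Mac/MPS users would do.
    assert FakeHfConfig(use_flex_decoding=False).use_flex_decoding is False

test_default_unchanged()
test_disable_for_mps()
```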

Fixes #316

Add a `use_flex_decoding` constructor parameter to allow users to disable
flex decoding on non-CUDA devices (e.g., MPS on Mac).

Changes:
- MoondreamModel: Add `use_flex_decoding` parameter (default: True)
- HfConfig: Add `use_flex_decoding` config option
- HfMoondream: Pass through `use_flex_decoding` to MoondreamModel

Usage:
```python
# For HuggingFace loading on Mac with MPS
from transformers import AutoModelForCausalLM, AutoConfig
config = AutoConfig.from_pretrained('moondream/moondream3', use_flex_decoding=False)
model = AutoModelForCausalLM.from_pretrained('moondream/moondream3', config=config)
```

Fixes vikhyat#316
@Mr-Neutr0n
Author

Friendly follow-up - is there anything I can improve in this PR? Happy to address any feedback!

@Mr-Neutr0n
Author

Friendly bump! Let me know if there's anything I should update or improve to help move this forward.



Development

Successfully merging this pull request may close these issues.

Make it easier to unset use_flex_decoding for Moondream 3 on non-CUDA devices
