1 change: 1 addition & 0 deletions .gitignore
@@ -7,4 +7,5 @@ poetry.lock
dist
clients/python/moondream/torch
wandb/
bitblas_cache/
moondream_finetune.safetensors
2 changes: 2 additions & 0 deletions moondream/torch/config.py
@@ -12,6 +12,8 @@ class TextConfig:
n_heads: int = 32
n_kv_heads: int = 32
prefix_attn: int = 730
group_size: int = 128
cache_dir: str = "./bitblas_cache"


@dataclass(frozen=True)
65 changes: 65 additions & 0 deletions moondream/torch/layers.py
@@ -1,8 +1,12 @@
from dataclasses import dataclass
from typing import Literal

import bitblas
from bitblas.cache import OperatorCache

import torch
from torch.nn import functional as F
import torch.nn as nn


def gelu_approx(x):
@@ -15,6 +19,66 @@ class LinearWeights:
bias: torch.Tensor


class Linear(nn.Module):
"""
Linear layer with support for bitblas quantization.
If dtype is torch.int8, it uses bitblas for quantization.
Otherwise, it uses a standard nn.Linear layer.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
operator_cache: OperatorCache = None,
cache_dir: str = None,
group_size: int = 128,
):
super().__init__()

if dtype == torch.int8:
self.linear = bitblas.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
with_zeros=True,
zeros_mode="original",
with_scaling=True,
A_dtype="float16",
W_dtype="uint4",
accum_dtype="float16",
out_dtype="float16",
fast_decoding=True,
enable_tuning=True,
operator_cache=operator_cache,
database_path=cache_dir,
group_size=group_size,
)
else:
self.linear = nn.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
dtype=torch.float16,
)

def forward(self, x):
return self.linear(x)

@property
def weight(self) -> torch.Tensor:
try:
return self.linear.weight
except AttributeError:
return self.linear.qweight

@property
def bias(self) -> torch.Tensor:
return self.linear.bias
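
A minimal usage sketch of the wrapper (a sketch only: it assumes BitBLAS is installed, a CUDA device is available, and the 2048 feature size and cache directory are illustrative):

import torch
from bitblas.cache import OperatorCache

cache = OperatorCache()
# dtype == torch.int8 selects the BitBLAS path (uint4 weights, float16 activations)
quantized = Linear(2048, 2048, dtype=torch.int8, operator_cache=cache,
                   cache_dir="./bitblas_cache", group_size=128)
# Any other dtype falls back to a standard float16 nn.Linear.
plain = Linear(2048, 2048).cuda()

x = torch.randn(1, 16, 2048, dtype=torch.float16, device="cuda")
y = plain(x)  # both variants expose the same forward() signature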


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
return F.linear(x, w.weight, w.bias)

Expand All @@ -37,6 +101,7 @@ class MLPWeights:


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:

x = w.fc1(x)
x = gelu_approx(x)
x = w.fc2(x)
26 changes: 15 additions & 11 deletions moondream/torch/moondream.py
@@ -66,12 +66,16 @@ class MoondreamModel(nn.Module):
def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
super().__init__()
self.config = config
self.dtype = dtype
self.setup_caches_flag = setup_caches

self.tokenizer = Tokenizer.from_pretrained(
"vikhyatk/moondream2", revision="2025-01-09"
)

self.vision = build_vision_model(config.vision, dtype)
self.text = build_text_model(config.text, dtype)

self.text = None

# Region Model
self.region = nn.ModuleDict(
@@ -125,11 +129,11 @@ def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
self.register_buffer("attn_mask", attn_mask, persistent=False)

# Initialize KV caches.
if setup_caches:
self._setup_caches()

def _setup_caches(self):
"""Setup KV caches for the text model"""
if self.text is None:
return # Can't set up caches without text model

c = self.config.text
for b in self.text.blocks:
b.kv_cache = KVCache(
@@ -163,15 +167,14 @@ def _decode_one_tok(

def compile(self):
# TODO: vision_projection is not being compiled
self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
self._prefill = torch.compile(self._prefill, fullgraph=True)
self._decode_one_tok = torch.compile(
self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
self._vis_enc = torch.compile(
self._vis_enc, fullgraph=False, mode="reduce-overhead"
)
# self._prefill = torch.compile(self._prefill)
# self._decode_one_tok = torch.compile(self._decode_one_tok)

def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)

torch._dynamo.mark_dynamic(all_crops, 0)

outputs = self._vis_enc(all_crops)
@@ -201,6 +204,7 @@ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:

# Run through text model in addition to the vision encoder, to minimize
# re-computation if multiple queries are performed on this image.

with torch.inference_mode():
img_emb = self._run_vision_encoder(image)
bos_emb = text_encoder(
@@ -236,10 +240,10 @@ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
def _prefill_prompt(
self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
):

with torch.inference_mode():
prompt_emb = text_encoder(prompt_tokens, self.text)
torch._dynamo.mark_dynamic(prompt_emb, 1)

mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
hidden = self._prefill(prompt_emb, mask, pos_ids)
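With self.text now initialized to None, the text model is attached after construction rather than built in __init__ (the attachment itself is outside this diff). A hedged sketch of one way the pieces could fit together, using build_text_model from text.py below; the quantized dtype and device are assumptions:

import torch
from moondream.torch.config import MoondreamConfig
from moondream.torch.moondream import MoondreamModel
from moondream.torch.text import build_text_model

config = MoondreamConfig()
model = MoondreamModel(config)  # vision tower is built; model.text is None
model.text = build_text_model(config.text, linear_dtype=torch.int8)  # quantized text tower
model._setup_caches()  # KV caches can be allocated now that model.text exists
model = model.to("cuda")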
7 changes: 6 additions & 1 deletion moondream/torch/sample.py
@@ -8,8 +8,10 @@

from .weights import load_weights_into_model
from .moondream import MoondreamModel, MoondreamConfig
import time

if __name__ == "__main__":
start = time.time()
parser = argparse.ArgumentParser()
parser.add_argument("--image", "-i", type=str, required=True)
parser.add_argument("--prompt", "-p", type=str, required=True)
@@ -32,17 +34,20 @@
config = MoondreamConfig.from_dict(config)
else:
config = MoondreamConfig()

model = MoondreamModel(config)
load_weights_into_model(args.model, model)
model = model.to(device)

# Encode image.
image_path = args.image
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image not found at {image_path}")
image = Image.open(image_path)
model = model.to(device)

if not args.benchmark:

# model.compile()
encoded_image = model.encode_image(image)

# Short caption
57 changes: 41 additions & 16 deletions moondream/torch/text.py
@@ -2,8 +2,9 @@
import torch.nn as nn

from torch.nn import functional as F
from bitblas.cache import OperatorCache

from .layers import layer_norm, mlp
from .layers import layer_norm, mlp, Linear
from .rope import apply_rotary_emb, precompute_freqs_cis
from .config import TextConfig

@@ -26,6 +27,7 @@ def attn(
head_dim = d_model // n_heads

qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)

q_dim = n_heads * head_dim
kv_dim = n_kv_heads * head_dim

@@ -139,6 +141,7 @@ def text_decoder(
n_kv_heads=config.n_kv_heads,
position_ids=position_ids,
)

l_mlp = mlp(l_in, block.mlp)
x = x + l_attn + l_mlp

@@ -158,44 +161,66 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
return logits


def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
def build_text_model(
config: TextConfig,
linear_dtype: torch.dtype = torch.float16,
layernorm_dtype: torch.dtype = torch.float16,
) -> (
nn.Module
):  # note: layernorm_dtype is also used for lm_head and wte, not just the layer norms
qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))

operator_cache = None
cache_dir = None
group_size = None
if linear_dtype == torch.int8:

operator_cache = OperatorCache()
cache_dir = config.cache_dir
group_size = config.group_size

def create_linear(in_features, out_features, dtype=linear_dtype):
# Factory for creating Linear layers so we don't have to repeat the shared quantization arguments for every layer.
return Linear(
in_features=in_features,
out_features=out_features,
dtype=dtype,
operator_cache=operator_cache,
cache_dir=cache_dir,
group_size=group_size,
)

text = nn.ModuleDict(
{
"blocks": nn.ModuleList(
[
nn.ModuleDict(
{
"ln": nn.LayerNorm(config.dim, dtype=dtype),
"ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
"attn": nn.ModuleDict(
{
"qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
"proj": nn.Linear(
config.dim, config.dim, dtype=dtype
),
"qkv": create_linear(config.dim, qkv_dim),
"proj": create_linear(config.dim, config.dim),
}
),
"mlp": nn.ModuleDict(
{
"fc1": nn.Linear(
config.dim, config.ff_dim, dtype=dtype
),
"fc2": nn.Linear(
config.ff_dim, config.dim, dtype=dtype
),
"fc1": create_linear(config.dim, config.ff_dim),
"fc2": create_linear(config.ff_dim, config.dim),
}
),
}
)
for _ in range(config.n_layers)
]
),
"post_ln": nn.LayerNorm(config.dim, dtype=dtype),
"lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
"post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
"lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype),
}
)
text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
text.wte = nn.Parameter(
torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
)
text.register_buffer(
"freqs_cis",
precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
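For reference, a short sketch of how the two dtype knobs might be used (a sketch under the assumption that the TextConfig defaults from config.py above are in effect; BitBLAS kernels are only built and cached when linear_dtype is torch.int8):

import torch
from moondream.torch.config import TextConfig
from moondream.torch.text import build_text_model

cfg = TextConfig()  # cache_dir="./bitblas_cache" and group_size=128 by default
text_fp16 = build_text_model(cfg)  # plain float16 nn.Linear layers
text_uint4 = build_text_model(cfg, linear_dtype=torch.int8)  # BitBLAS uint4 linears sharing one OperatorCache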