
feat(gemma4): complete Gemma 4 rewrite — Chapters 1-7 #14

Open
cofin wants to merge 81 commits into main from feat/gemma4

Conversation


@cofin cofin commented Apr 3, 2026

Complete rewrite from Gemma 2/3/3n to Gemma 4. All legacy code deleted. Supports all four Gemma 4 families: 31B dense, E2B, E4B, and 26B MoE (A4B).

What changed:

  • Deleted all Gemma 2/3/3n code — models, configs, tests, ops
  • New config.json-based variant detection and HuggingFace download backend
  • Hybrid attention engine — sliding-window (theta=10K) + full global (theta=1M) with per-layer dispatch
  • SigLIP vision encoder — bidirectional attention, variable-resolution preprocessing, token merging during prefill, video frame extraction
  • Per-Layer Embedding (PLE) for E2B/E4B — embed→project→norm→inject per layer, shared-KV attention
  • Audio tower — mel spectrogram via numpy.fft, AudioHydrator, process_audio FFI (weight hydration still pending HF tensor name standardization)
  • MoE routing — 128 experts top-8, sparse GEGLU dispatch with weighted sum
  • Full native GPU backend — ComputeBackend trait, layer-by-layer weight streaming, GPU-resident KV cache, element-wise/reduction/matmul kernels, GPU dispatch in core.mojo
  • Automated perf-benchmark CI workflow on every push to main and PRs
  • Pinned Python 3.12 in Makefile install to avoid freethreaded 3.14 pickup

78 commits, 148 Python tests passing, 4 Mojo test files passing.

cofin added 30 commits April 2, 2026 21:09
Remove convert.py (Orbax→safetensors converter), orbax_loader.py,
tensorstore dependency, all Orbax code paths from loader.py and hub.py,
parity_config.py, and all 21 legacy test files that tested Gemma 2/3/3n
behavior. Loader now supports safetensors only.
Replace ModelVariant (STANDARD/NANO) with Gemma4Variant enum supporting
four Gemma 4 architectures: DENSE_31B, DENSE_E2B, DENSE_E4B, MOE_26B_A4B.
Detection parses HuggingFace config.json instead of tensor name inspection.
Update GenerationConfig/EmbeddingConfig defaults to google/gemma-4-31B-it,
top_k=64, top_p=0.95. Update hub tokenizer path for gemma4 family, update
class docstrings and error messages from Gemma 3 to Gemma 4.
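For reference, the config.json-driven detection described above can be sketched in Python. The `Gemma4Variant` names come from this commit; the config keys other than `num_local_experts` (which a later commit references) are assumptions, since the real schema depends on the published Gemma 4 configs:

```python
from enum import Enum

class Gemma4Variant(Enum):
    DENSE_31B = "dense_31b"
    DENSE_E2B = "dense_e2b"
    DENSE_E4B = "dense_e4b"
    MOE_26B_A4B = "moe_26b_a4b"

def detect_variant(config: dict) -> Gemma4Variant:
    """Sketch of variant detection from a parsed config.json dict."""
    # MoE configs carry an expert count (see the later num_local_experts fix).
    if config.get("num_local_experts", 0) > 1:
        return Gemma4Variant.MOE_26B_A4B
    # Hypothetical key for the PLE/shared-KV variants; the real field
    # name depends on the HuggingFace config schema.
    if config.get("num_kv_shared_layers", 0) > 0:
        hidden = config.get("hidden_size", 0)
        return Gemma4Variant.DENSE_E2B if hidden <= 2048 else Gemma4Variant.DENSE_E4B
    return Gemma4Variant.DENSE_31B
```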
Replace _format_instruction_prompt with _format_gemma4_prompt supporting
plain strings, system prompts, and multi-turn message lists. Thread
system_prompt parameter through SyncGemmaModel and AsyncGemmaModel
generate/generate_stream methods.
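A minimal sketch of the prompt formatting this commit describes, assuming Gemma 4 keeps the `<start_of_turn>`/`<end_of_turn>` markers of earlier Gemma releases (the function name and exact template are illustrative, not taken from the diff):

```python
def format_gemma4_prompt(messages, system_prompt=None):
    """Format a plain string or a multi-turn message list into a
    Gemma-style prompt. Assumes earlier-Gemma turn markers."""
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    parts = []
    for i, msg in enumerate(messages):
        content = msg["content"]
        # Gemma has no dedicated system role; fold the system prompt
        # into the first user turn.
        if i == 0 and system_prompt:
            content = f"{system_prompt}\n\n{content}"
        role = "model" if msg["role"] == "assistant" else "user"
        parts.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")
    parts.append("<start_of_turn>model\n")
    return "".join(parts)
```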
Update model.py docstrings and pyproject.toml description from Gemma 3
to Gemma 4. Verification: grep finds zero legacy references in src/py/.
Add test_config.py (11 tests), test_variant_detection.py (12 tests),
test_chat_template.py (7 tests), test_loader.py (5 tests), and
test_hub.py (10 tests). Total 45 tests covering all Ch1 deliverables.
Replace obstore GCSStore with HTTPStore for downloading Gemma 4 models
from HuggingFace. Shard discovery via model.safetensors.index.json
parsing instead of GCS bucket listing. Support HF_TOKEN env var for
gated model access. Delete GCS-specific code paths.
Add _bf16_to_f32() utility that converts bfloat16 raw bytes to float32
via uint16→uint32 left-shift. SafetensorsLoader.get_tensor_metadata()
now detects BF16 dtype from safetensors headers and converts in-memory,
reporting F32 to Mojo. f32 and i8 tensors pass through unchanged.
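The conversion is cheap because bfloat16 is simply the top 16 bits of a float32. A NumPy sketch of the same uint16→uint32 left-shift (assuming little-endian raw bytes, as safetensors uses):

```python
import numpy as np

def bf16_to_f32(raw: bytes) -> np.ndarray:
    """Reinterpret bfloat16 raw bytes as float32: widen each uint16 to
    uint32, shift it into the high 16 bits, and view as float32."""
    u16 = np.frombuffer(raw, dtype=np.uint16)
    u32 = u16.astype(np.uint32) << 16
    return u32.view(np.float32)
```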
Add resolve_tokenizer() with priority: explicit path > HF-downloaded
tokenizer.model in model dir. Add validate_config_json() to verify
Gemma 4 model_type and required fields from HuggingFace config.
Add test_hf_token.py (HF_TOKEN env var detection, Bearer header
injection, token passthrough to download_sync) and
test_config_json_download.py (config.json validation, variant detection
integration). Full suite: 96 tests passing.
KVCache with per-layer ring buffer (sliding) and linear (full) allocation
in a single contiguous arena. Dual RoPE frequency tables (theta=10K for
sliding, theta=1M with partial rotation for full). Layer type dispatch in
forward_gemma4_layer/forward_gemma4_step. compute_kv_cache_memory() Python
utility with 9 tests. Deleted ~1,800 lines of Nano/GPU/Vision dead code.
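The memory math behind such a utility can be sketched as follows. The layer-type labels and the exact signature of `compute_kv_cache_memory()` are assumptions; the key point is that sliding layers only ever hold `window_size` slots while full layers hold `max_seq_len`:

```python
def compute_kv_cache_memory(layer_types, window_size, max_seq_len,
                            num_kv_heads, head_dim, bytes_per_elem=4):
    """Hybrid KV cache budget sketch: ring buffers for sliding-window
    layers, linear buffers for full-attention layers; each slot holds
    both K and V (hence the factor of 2)."""
    total = 0
    for lt in layer_types:
        slots = window_size if lt == "sliding_attention" else max_seq_len
        total += 2 * slots * num_kv_heads * head_dim * bytes_per_elem
    return total
```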
Rewrite core.mojo to use forward_gemma4_step with hybrid KVCache and
dual-theta RoPETables. Delete all Nano/Vision/GPU runtime code (~3,200
lines removed). Wire Python config.json parsing (layer_types,
window_size, partial_rotary_factor, k_eq_v) through to Mojo init.

- core.mojo: new init creates KVCache + RoPETables on heap, step_mojo
  calls forward_gemma4_step, reset_cache zeros hybrid KVCache,
  free_arena cleans up all heap objects
- layers.mojo: remove legacy forward_attention, forward_layer,
  forward_step, forward_sequence (replaced by Gemma 4 functions in Ch3)
- model.py: add _parse_gemma4_architecture(), pass layer_types as
  native Python list to Mojo, simplify _reset_llm_session_state
- Delete 5 dead Mojo test files, rewrite 3 for Gemma 4
- 119 Python tests passing, zero dead code references
test_altup_contract.mojo and test_nano_layers.mojo were deleted in Ch3
(Nano cleanup) but remained in the parametrize list. The former caused
a hard failure; the latter was masked by the unstable-skip list.
123 tests passing, 1 skipped (known unstable test_layers.mojo).
All Nano/Vision/GPU dead code removed from core.mojo (~3,200 lines net).
FFI boundary rewritten for Gemma 4 forward path.
…5 Task 5.1)

Vision encoder weight structs for SigLIP ViT: VisionLayerWeights holds
q/k/v/o_proj + fc1/fc2 MLP + layer_norm1/2 per layer. VisionModelWeights
holds patch/position embeddings, post-norm, projection, and layer list.
…Task 5.5)

Standalone GELU activation (SigLIP uses standard GELU, not GEGLU).
Non-overlapping average_pool_2d reduces vision tokens by kernel^2 (9x for 3x3).
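A NumPy reference for the pooling math (illustrative, not the Mojo op itself): for a 3×3 kernel, a 6×6 token grid collapses to 2×2, the stated kernel² reduction.

```python
import numpy as np

def average_pool_2d(x: np.ndarray, kernel: int) -> np.ndarray:
    """Non-overlapping average pooling over an (H, W, C) token grid.
    Reduces token count by kernel**2 (e.g. 9x for 3x3)."""
    h, w, c = x.shape
    assert h % kernel == 0 and w % kernel == 0
    x = x.reshape(h // kernel, kernel, w // kernel, kernel, c)
    return x.mean(axis=(1, 3))
```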
… Task 5.3)

Rewrite ImageHydrator with token budget selection, SigLIP normalization
((pixel/255 - 0.5) / 0.5), and 16x16 patch extraction. Returns ImageInput
dataclass with patches, grid dims, and token count.
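The normalization and patching steps can be sketched in NumPy (function name and return layout are illustrative; the `(p/255 - 0.5) / 0.5` scaling and 16×16 patch size come from the commit):

```python
import numpy as np

def preprocess_image(pixels: np.ndarray, patch_size: int = 16) -> np.ndarray:
    """SigLIP-style preprocessing sketch: scale uint8 pixels to [-1, 1],
    then cut into non-overlapping patch_size x patch_size patches.
    pixels: (H, W, 3) uint8; returns (num_patches, patch_size**2 * 3)."""
    x = (pixels.astype(np.float32) / 255.0 - 0.5) / 0.5
    h, w, c = x.shape
    gh, gw = h // patch_size, w // patch_size
    x = x[: gh * patch_size, : gw * patch_size]          # drop ragged edges
    x = x.reshape(gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 2, 1, 3, 4)                       # (gh, gw, ps, ps, c)
    return x.reshape(gh * gw, patch_size * patch_size * c)
```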
…5.4)

Extract video frames at 1 FPS, max 32 frames/60s, using lowest token
budget (70). No new Python deps — ffmpeg subprocess only.
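A subprocess-only sketch of that extraction. The command-builder below uses real ffmpeg flags (`-vf fps=…`, `-frames:v`); the helper names are illustrative and assume the `ffmpeg` binary is on PATH:

```python
import subprocess

def ffmpeg_frame_cmd(video_path: str, out_pattern: str,
                     fps: int = 1, max_frames: int = 32) -> list[str]:
    """Build the ffmpeg command: sample at `fps` frames/sec and stop
    after `max_frames` frames (the 1 FPS / 32-frame budget above)."""
    return ["ffmpeg", "-i", video_path,
            "-vf", f"fps={fps}",
            "-frames:v", str(max_frames),
            out_pattern]

def extract_frames(video_path: str, out_pattern: str = "frame_%04d.png"):
    """Run the extraction; raises CalledProcessError on ffmpeg failure."""
    subprocess.run(ffmpeg_frame_cmd(video_path, out_pattern),
                   check=True, capture_output=True)
```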
… 5.6)

Batched multi-head attention for SigLIP vision encoder. All patches attend
to all patches — no causal mask, no KV cache, no RoPE. Uses mat_mat_mul
for batched Q/K/V projections.
…Ch5 Task 5.7)

Full SigLIP encoder pipeline: patch embed → position embed → N vision
transformer layers (LayerNorm + bidirectional attn + GELU MLP) →
post-LayerNorm → 3×3 average pooling → decoder projection.
Add _build_vision_runtime/_build_vision_from_runtime to extract
vision_tower.* tensors. Extend _init_model_impl_mojo to load vision
weights when num_vision_layers > 0. Flatten/hydrate via Appender/Hydrator.
…5 Task 5.8)

process_image_mojo runs vision encoder on patches, stores embeddings.
step_with_embedding_mojo feeds pre-computed vision embeddings through
decoder layers, bypassing token embedding lookup. Both registered in
PyInit__core.
Parse vision_config section: num_hidden_layers, hidden_size,
num_attention_heads, intermediate_size. Parse image_token_index.
Add max_image_tokens to GenerationConfig (default 560).
… (Ch5 Task 5.9)

During prefill, replace <image> placeholder tokens with vision embeddings
via step_with_embedding. CoreBackend.process_images now passes ImageInput
attributes (patches, grid_h, grid_w) to the FFI layer.
Remove hasattr guard on process_image — it's now registered in
PyInit__core. All 117 Python tests + 4 Mojo test files pass.
No legacy 384/patch_size=14 references remain.
Full multimodal vision pipeline:
- VisionLayerWeights/VisionModelWeights structs
- gelu + average_pool_2d ops
- Bidirectional vision attention (no causal mask)
- Full SigLIP encoder: patch embed → position embed → N layers → pooling → projection
- Variable-resolution preprocessing: token budgets, SigLIP normalization, 16×16 patches
- Video frame extraction via ffmpeg
- process_image + step_with_embedding FFI
- Token merging: replace <image> placeholders with vision embeddings
- Vision config wiring from config.json

117 Python tests + 4 Mojo test files passing.
…/config

Chapter 6 (E2B/E4B):
- PLELayerWeights struct with per-layer embedding/projection/norm
- forward_ple_input: embed → project → norm → add to hidden state
- forward_gemma4_ple_step: E2B/E4B forward with PLE + shared-KV dispatch
- Shared-KV config parsing (kv_sharing_layer_map)
- Double-wide MLP flag parsing (use_double_wide_mlp)
- ModelWeights extended with ple_layers + has_ple

Chapter 7 (MoE 26B):
- MoEExpertWeights, MoELayerWeights, MoEModelWeights structs
- top_k op in ops.mojo (Int32 indices)
- forward_moe_router: logits → softmax → top_k → renormalize
- forward_moe_experts: sparse GEGLU expert dispatch + weighted sum
- forward_moe_layer: attention (K=V) + MoE block
- forward_gemma4_moe_step: full 26B forward path
- MoE config parsing: num_experts, moe_top_k, moe_intermediate_size
- Variant detection fix for num_local_experts

129 Python tests + 4 Mojo test files passing.
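The router math in the Chapter 7 list (softmax → top_k → renormalize, then a sparse weighted sum over the selected experts) can be sketched in NumPy; function names are illustrative:

```python
import numpy as np

def moe_route(logits: np.ndarray, top_k: int = 8):
    """Softmax over expert logits, keep the top_k experts, and
    renormalize their probabilities to sum to 1."""
    probs = np.exp(logits - logits.max())   # numerically stable softmax
    probs /= probs.sum()
    idx = np.argpartition(probs, -top_k)[-top_k:]
    weights = probs[idx] / probs[idx].sum()
    return idx, weights

def moe_forward(x, experts, logits, top_k=8):
    """Sparse dispatch: run only the selected experts, weighted sum."""
    idx, w = moe_route(logits, top_k)
    return sum(wi * experts[i](x) for i, wi in zip(idx, w))
```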
Chapter 5 — Vision Pipeline:
  VisionLayerWeights/VisionModelWeights, gelu + average_pool_2d ops,
  bidirectional vision attention, SigLIP encoder, variable-resolution
  preprocessing, video frame extraction, process_image + step_with_embedding
  FFI, token merging, vision config wiring.

Chapter 6 — E2B/E4B:
  PLELayerWeights, forward_ple_input, forward_gemma4_ple_step,
  shared-KV attention dispatch, double-wide MLP config, PLE/KV config parsing.

Chapter 7 — MoE 26B:
  MoEExpertWeights/MoELayerWeights/MoEModelWeights, top_k op,
  forward_moe_router, forward_moe_experts, forward_moe_layer,
  forward_gemma4_moe_step, MoE config parsing.

129 Python tests + 4 Mojo test files passing.
Zero legacy vision references (384/patch_size=14).
Zero unauthorized dependencies (scipy/librosa/torchaudio).
cofin added 29 commits April 3, 2026 22:06
upload_layer_weights packs all 13 dense layer tensors into staging buffer.
upload_expert_weights uploads a single MoE expert (gate/up/down_proj).
upload_vision_layer_weights uploads 8 vision encoder tensors.
Each function does one host→device copy and returns new weight struct
with device pointers, preserving original shapes.
All 10 tasks: GPUContext, WeightStage, PersistentBuffers, GPUKVCache,
GPUScratch, GPU init/cleanup paths, dense/MoE/vision upload functions.
Widen ops.mojo pointer types from MutExternalOrigin to MutAnyOrigin
so both CPU heap pointers and GPU DeviceBuffer pointers work without
casts. Define ComputeBackend trait with @staticmethod methods for all
9 ops. CPUBackend delegates to the existing @always_inline free
functions with zero overhead.

Chapter 2, Tasks 2.1-2.2.
New file ops_gpu.mojo with thread-per-element GPU kernels:
- gelu_kernel: standard GELU activation using erf()
- geglu_kernel: fused gate-GELU-up for MLP layers
- rope_rotate_kernel: RoPE rotation, one thread per dim pair

Also adds launch config utilities (ceildiv, optimal_block_size) and
comptime constants (BLOCK_1D=256, TILE_BK=64, TILE_BM/BN=16).

Chapter 2, Tasks 2.3-2.4, 2.13 (partial).
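For reference, the math these kernels implement, in scalar Python: the exact erf-based GELU (not the tanh approximation) and the fused gate-GELU-up product.

```python
import math

def gelu(x: float) -> float:
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(gate: float, up: float) -> float:
    """Fused GEGLU used in the MLP: GELU(gate) * up."""
    return gelu(gate) * up
```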
- softmax_kernel[BLOCK_SIZE]: single-block softmax using block_max
  and block_sum reductions for vectors up to BLOCK_SIZE
- softmax_strided_kernel[BLOCK_SIZE]: strided variant for vectors
  larger than BLOCK_SIZE (multi-element per thread)
- rms_norm_kernel[BLOCK_SIZE]: RMS normalization with Gemma's
  (1+w) scaling using block_sum for sum-of-squares

All kernels use std.gpu.primitives.block for reductions, with
BLOCK_SIZE as a comptime parameter for compile-time specialization.

Chapter 2, Tasks 2.5-2.6.
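A NumPy reference for the rms_norm_kernel's math, including Gemma's (1 + w) scaling convention (the eps value is an assumption):

```python
import numpy as np

def rms_norm(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize x by its root-mean-square, then scale by (1 + w),
    where w is the learned weight vector."""
    rms = np.sqrt(np.mean(x * x) + eps)
    return (x / rms) * (1.0 + w)
```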
Three tiled matmul kernels using shared memory:
- vec_mat_mul_kernel: 1D tiled with shared x vector, BK=64 tile,
  each thread computes one output element
- mat_mat_mul_kernel: 2D tiled (BM=16, BN=16, BK=16), two shared
  tiles for x and w, classic 2D matmul for batched vision ops
- vec_mat_mul_i8_kernel: same tiling as vec_mat_mul but loads int8
  weights with float32 cast and per-tensor scale

Chapter 2, Tasks 2.7-2.9.
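The int8 variant's dequantization can be stated as a one-line NumPy reference (weight layout `(out, in)` and per-tensor scale are assumptions consistent with the description above):

```python
import numpy as np

def vec_mat_mul_i8(x: np.ndarray, w_i8: np.ndarray, scale: float) -> np.ndarray:
    """Reference for the int8 kernel: cast int8 weights to float32,
    apply one per-tensor scale, then multiply by the input vector."""
    return (w_i8.astype(np.float32) * scale) @ x
```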
- average_pool_2d_kernel: one block per output spatial position,
  threads parallelize across hidden dimension, averages kernel×kernel
  input blocks (SigLIP vision encoder pooling)
- top_k_kernel: single-warp kernel for MoE expert selection (k=8,
  n=128), uses warp_max for distributed max finding across lanes

Chapter 2, Tasks 2.10-2.11.
GPUBackend provides convenience methods that handle grid/block sizing
and shared memory allocation for each kernel. It cannot implement the
ComputeBackend trait directly (GPU launches need a DeviceContext), so
layers dispatch via comptime if has_accelerator().

Methods: launch_gelu, launch_geglu, launch_rope_rotate,
launch_softmax, launch_rms_norm, launch_vec_mat_mul,
launch_mat_mat_mul, launch_vec_mat_mul_i8, launch_average_pool_2d,
launch_top_k.

Chapter 2, Tasks 2.12-2.13.
- Replace deprecated 'alias' with 'comptime' for constants in model.mojo.
- Transition from verbose 'List[T]()' to bracket literals '[]' for empty list initialization.
- Audit 'raises' consistency across all 'def' functions for future Mojo compatibility.
- Ensure all tests pass with updated syntax.
… dispatch

Wire GPU weight streaming through all Gemma 4 step functions (dense, PLE,
MoE, vision/audio encoders). Key changes:

- Add 2-phase MoE weight streaming: attention+router uploaded via new
  upload_moe_attention_weights, then each selected expert streamed on-demand
- Add persistent GPU buffer support to step_with_embedding and MoE step
- Replace scalar residual loops in MoE layer with backend.vector_add
- Fix GPUBackend: move 7 misplaced trait methods from kernel functions back
  into the struct (vector_add, average_pool_2d, top_k, kv_write, attention)
- Fix module-level comptime-if scoping: use unconditional GPU type imports
- Fix Mojo keyword conflicts (out -> dst) and DevicePassable (Bool -> Int)
- Remove raises from upload functions, use abort-on-error pattern
- Remove duplicate get_attention_range and except block
…te utility

Complete remaining gpu-forward-paths tasks:
- Task 4.9: Parameterize forward_vision_encoder with S/C type params for
  GPU weight streaming via upload_vision_layer_weights per layer
- Task 4.10: Same pattern for forward_audio_encoder
- Task 4.12: Add copy_state kernel+function to gpu_context.mojo for
  device-to-device state copy without host involvement
- Replace scalar state-swap loops with backend.copy() in both encoders
- Fix test_vision_encoder.mojo to pass CPUBackend and new params
- Add test_gpu_forward.mojo with CPU regression and GPU integration tests
All 12 tasks verified, user confirmed. Epic mogemma-92py.4 closed.
- Task 5.6: GPU branch in process_image_mojo using GPUBackend + weight streaming
- Task 5.7: GPU branch stub in process_audio_mojo (blocked on weight hydration)
- Task 5.8: GPU branch in generate_embeddings_mojo with per-sequence GPU KV reset
- Task 5.10: GPU-aware reset_cache_mojo with GPUKVCache.reset() kernel
- GPUKVCache.reset(): new method using zero-fill kernel on device buffers
…d docs

- Task 5.11: Improve GPU unavailable error message with actionable guidance
- Task 5.12: Add device/backend span attributes to OTel telemetry spans
- Task 5.13: Add GPU metrics collection to benchmark payload
- Task 5.14: Register GPU test files in test_mojo.py parametrize list
- Task 5.15: Add GPU requirements note to README
All 15 tasks verified. Epic mogemma-92py.5 closed.
GPU dispatch end-to-end: init, step, encoders, embeddings, cleanup, reset,
Python validation, telemetry, benchmarks, CI guards, docs.
Adds a GitHub Actions workflow that runs `make benchmark` on every push
to main and on pull requests. Results are uploaded as artifacts for
90-day retention. Purely informational — does not gate PRs.
Standardizes dictionary keys for GPU resources ('_gpu_context_ptr' and
'_gpu_weight_stage_ptr') to ensure consistent lookup across FFI entry
points. This fixes runtime lookups and ensures template instantiation
gating works as intended during build on CPU-only CI environments.
