Issue Title
train_llama_torchtitan example: Multiple issues running on TPU v4-32
Issue Description
Summary
Encountered multiple issues setting up and running the train_llama_torchtitan example on TPU v4-32. Documenting all problems and fixes.
Environment
- Hardware: TPU v4-32 (4 workers)
- Python: 3.11
Issue 1: Missing splash_attn.py file
The example references splash_attn.py for TPU flash attention, but the file is not automatically included when copying the example files.
Fix: Manually download splash_attn.py from the repository.
Issue 2: Python version compatibility
The TPU VM defaulted to an older Python version incompatible with the dependencies.
Fix: Install and use Python 3.11 explicitly in install_deps.sh.
Issue 3: pip install -e . fails for splash_attn
The install script tried to run pip install -e . for splash_attn, but it's a standalone file, not a package.
Fix: Remove the pip install -e . line for splash_attn from install_deps.sh.
Issue 4: freqs_cis KeyError / state_dict mismatch
Error:
KeyError: 'freqs_cis'
or
RuntimeError: Unexpected key(s) in state_dict: "freqs_cis"
Cause: The code unconditionally handles freqs_cis the way scan mode expects it. With use_scan=False, the model stores freqs_cis differently, so loading it through the state_dict fails.
Fix: Conditionally handle freqs_cis based on use_scan:
state_dict = create_sharded_weights(gpt, mesh, sharding_map)
if "freqs_cis" in state_dict:
    state_dict.pop("freqs_cis")
if use_scan:
    state_dict["freqs_cis"] = freqs_cis.to("jax").apply_jax(jax.device_put, replicated)
gpt.load_state_dict(state_dict, assign=True)
if not use_scan:
    gpt.freqs_cis = freqs_cis.to("jax").apply_jax(jax.device_put, replicated)
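Stripped of the JAX device-placement calls, the intent of the fix can be sketched as a plain-dict helper (the function name is mine, not from the example; device placement and the final load_state_dict/attribute assignment are elided):

```python
def prepare_freqs_cis(state_dict, freqs_cis, use_scan):
    """Return a copy of state_dict with freqs_cis handled per use_scan.

    In scan mode freqs_cis is loaded through the state_dict; otherwise it
    is removed here so it can be assigned directly on the model afterwards.
    """
    sd = dict(state_dict)
    sd.pop("freqs_cis", None)  # drop any stale entry unconditionally
    if use_scan:
        sd["freqs_cis"] = freqs_cis  # scan mode: load via state_dict
    return sd
```

This keeps load_state_dict from ever seeing an unexpected freqs_cis key in non-scan mode, which is exactly what triggered the error above.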
Issue 5: Embedding sharding error with tp_parallelism=1
Error:
jax._src.core.ShardingTypeError: Use `.at[...].get(out_sharding=)` to specify the output sharding of a gather from a sharded source.
Got operand=ShapedArray(bfloat16[128256@fsdp,4096@tp]), indices=ShapedArray(int32[16@fsdp,2048,1])
Cause: sharding_map_original shards tok_embeddings.weight as ("fsdp", "tp"). With tp_parallelism=1, the embedding lookup becomes a gather from a sharded table whose output sharding JAX cannot infer, so it raises ShardingTypeError.
Fix: Add sharding_map_original_fsdp for non-scan mode with tp_parallelism=1:
sharding_map_original_fsdp = {
    "tok_embeddings.weight": (None, "fsdp"),  # vocab replicated, hidden sharded
    "output.weight": (None, "fsdp"),
    # ... other weights with FSDP-only sharding
}
Update the selection logic to use this map when use_scan=False and tp_parallelism=1.
Issue 6: Missing sharding map selection for scan mode with tp_parallelism=1
The code always selected sharding_map_scan, regardless of the tp_parallelism value.
Fix: Update selection logic:
if use_scan:
    if tp_parallelism == 1:
        sharding_map = sharding_map_scan_fsdp
    else:
        sharding_map = sharding_map_scan
else:
    if tp_parallelism == 1:
        sharding_map = sharding_map_original_fsdp
    else:
        sharding_map = sharding_map_original
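The same four-way choice can be wrapped in a small helper, which makes the scan/TP combinations easy to unit-test in isolation (the function name and the `maps` dict keyed by (mode, fsdp_only) are illustrative, not from the example):

```python
def select_sharding_map(use_scan, tp_parallelism, maps):
    """Pick the sharding map for a given scan / tensor-parallelism combo.

    `maps` holds the four candidate maps, keyed by
    ("scan" | "original", fsdp_only: bool).
    """
    mode = "scan" if use_scan else "original"
    fsdp_only = tp_parallelism == 1  # TP disabled: avoid any "tp"-axis sharding
    return maps[(mode, fsdp_only)]
```

Collapsing the branches this way also guarantees that every new mode combination must have an explicit entry, instead of silently falling through to sharding_map_scan as before.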