38 changes: 32 additions & 6 deletions CHANGELOG.md
@@ -3,11 +3,37 @@ All notable changes to fairseq2 are documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.8.0] - TBD
- fsspec integration for remote filesystem support. Checkpoints can be saved to
and loaded from S3 via `--checkpoint-dir s3://bucket/path/`. Requires `s3fs`.
- New `GlobalFileSystem` replaces `LocalFileSystem` as default, dispatching to
the appropriate backend based on URI scheme.
## [0.8.0] - March 25th, 2026
- fsspec integration for remote filesystem support. Checkpoints can be saved to and loaded from S3 via `--checkpoint-dir s3://bucket/path/`. Requires `s3fs`. (#1126)
- New `GlobalFileSystem` replaces `LocalFileSystem` as default, dispatching to the appropriate backend based on URI scheme. (#1126)
- PyTorch 2.9.1 and 2.10 (forward compatibility) are now supported. PyTorch 2.9 introduced breaking changes to LR scheduler return types, which have been addressed. (#1477, #1491, #1456)
- **Breaking**: Trainer, evaluator, generator, validator, and task moved from `fairseq2.recipe` to `fairseq2` package root. (#1417)
- **Breaking**: LM recipes restructured: `text_generate` renamed to `generate`, SFT configs removed/renamed, recipe config classes changed. (#1431, #1432, #1433)
- **Breaking**: `RecipeModel` is deprecated. Access the model directly via `.module` instead. (#1403)
- **Breaking**: `pq.ParquetDataset` replaced with `pyarrow.dataset` interface. (#1490)
- **Breaking**: `resolve_optional` renamed to `maybe_resolve`. (#1462)
- **Breaking**: Revised `ModelCheckpointLoader` API. (#1475)
- **Breaking**: Refactored tensor sharded modules (embedding, projection, FFN, attention). (#1476)
- New context managers for procedural programming: `GangContext`, `DeviceContext`, `DataTypeContext`, `current_dtype`. These eliminate the need to pass state through nested function calls. (#1474, #1473, #1464)
- `CheckpointManager`, `Optimizer`, and `LRScheduler` now exposed in `RecipeContext`. (#1461)
- Synchronous asset loading across ranks for models and tokenizers. Use when all ranks need identical assets loaded simultaneously. (#1429, #1426)
- `CheckpointManager.register_save_hook` allows custom logic during checkpoint saves. (#1439)
- Config files now support `${env:<NAME>}` to interpolate environment variables. (#1435)
- `--no-rich` CLI flag disables rich text output for log parsing. (#1421)
- Hugging Face export now runs in an isolated process, with its command line and logs saved for debugging. (#1459, #1458, #1437, #1434)
- Improved support for gated Hugging Face models. (#1422)
- `get_family` utility functions for detecting model families. (#1454)
- Gemma3n model family (E2B/E4B) with text + audio inference and SFT training. (#1496)
- Generic Hugging Face model integration: load, shard, and train any Hugging Face CausalLM model directly through `HgCausalLMAdapter` without requiring a native fairseq2 reimplementation. Includes FSDP sharding, HF tokenizer integration, and SFT recipe support. (#1479)
- `AssetDownloadManager` gains `local_only` parameter and custom download subpath support. (#1423, #1425)
- Recipes now set Python `random` and `numpy` seeds for reproducibility. (#1419)
- Wandb metric recorder now respects wandb environment variables. (#1440)
- Improved `share_parameters` implementation. (#1484)
- Fixed `cross_entropy` with `reduction="mean"` to properly exclude padding tokens from the denominator. (#1455)
- Fixed `Flash3SDPA` to support the `flash-attn-3` v3.0.0 package API (`flash_attn_3._C` / `torch.ops.flash_attn_3`) in addition to the legacy `flash_attn_3_cuda` module. (#1495)
- Fixed data pipeline sampling bug when `allow_repeats=False` with many pipelines. (#1471)
- Fixed `DataParallelFacade` weakref errors. (#1447, #1436)
- Fixed WER calculation to use lists instead of tensors. (#1413)
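The `GlobalFileSystem` entry above dispatches on the URI scheme of a path. A minimal stdlib sketch of the idea (the function name is hypothetical, not fairseq2's actual implementation, which delegates to fsspec backends):

```python
from urllib.parse import urlparse

def pick_filesystem(path: str) -> str:
    # Dispatch on the URI scheme; plain paths have no scheme and map to local I/O.
    scheme = urlparse(path).scheme
    if scheme in ("", "file"):
        return "local"
    return scheme  # e.g. "s3" would map to fsspec's s3fs backend

print(pick_filesystem("/tmp/ckpt"))          # local
print(pick_filesystem("s3://bucket/path/"))  # s3
```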
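The `${env:<NAME>}` interpolation listed above can be sketched in a few lines of stdlib Python; this is a simplified stand-in for fairseq2's actual resolver, shown only to illustrate the syntax:

```python
import os
import re

_ENV_RE = re.compile(r"\$\{env:([A-Za-z_][A-Za-z0-9_]*)\}")

def interpolate_env(text: str) -> str:
    # Replace each ${env:NAME} occurrence with that environment variable's value.
    return _ENV_RE.sub(lambda m: os.environ[m.group(1)], text)

os.environ["RUN_NAME"] = "demo"
print(interpolate_env("output_dir: /checkpoints/${env:RUN_NAME}"))
# output_dir: /checkpoints/demo
```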

## [0.7.0] - Nov 4th, 2025
- `RecipeModel` is now callable and forwards the call to `RecipeModel.module`
@@ -24,7 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [0.6.0] - Oct 7th, 2025
- `fairseq2.sharder` is deprecated. fairseq2 now expects parallelism strategies
to be applied within model factories. This gives model authors full control
over how parallelism is applied to their models. [More info](https://github.com/facebookresearch/fairseq2/pull/1349)
over how parallelism is applied to their models. [More info](https://github.com/facebookresearch/fairseq2/pull/1349)
- `Gangs` can now be used as a context manager, along with a new `maybe_get_current_gangs()`
helper function. This feature is particularly useful in procedural programming,
as it eliminates the need to pass a `Gangs` instance through every function call.
52 changes: 51 additions & 1 deletion README.md
@@ -107,6 +107,44 @@ matrix shows the supported combinations.
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td rowspan=3><code>0.8</code></td>
<td><code>2.9.1</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td><code>2.8.0</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td><code>2.7.1</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td rowspan=3><code>0.7</code></td>
<td><code>2.9.0</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td><code>2.8.0</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td><code>2.7.1</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>cpu</code>, <code>cu126</code>, <code>cu128</code></td>
<td><code>x86_64</code></td>
</tr>
<tr>
<td rowspan=3><code>0.6</code></td>
<td><code>2.8.0</code></td>
@@ -200,11 +238,23 @@ the supported combinations.
</thead>
<tbody>
<tr>
<td rowspan=3><code>HEAD</code></td>
<td><code>HEAD</code></td>
<td><code>2.9.1</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>arm64</code></td>
</tr>
<tr>
<td><code>0.8</code></td>
<td><code>2.9.1</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>arm64</code></td>
</tr>
<tr>
<td><code>0.7</code></td>
<td><code>2.9.0</code></td>
<td><code>&gt;=3.10</code>, <code>&lt;=3.12</code></td>
<td><code>arm64</code></td>
</tr>
<tr>
<td rowspan=2><code>0.6</code></td>
<td><code>2.8.0</code></td>
2 changes: 1 addition & 1 deletion requirements-devel.txt
@@ -1,4 +1,4 @@
black~=25.1
black~=26.3
flake8~=7.1
flake8-pyi~=24.6
flake8-pyproject~=1.2
2 changes: 1 addition & 1 deletion setup.py
@@ -76,7 +76,7 @@
"arrow": [
"pyarrow>=17.0",
"retrying~=1.3",
"pandas~=2.0",
"pandas~=2.2",
"polars~=1.19",
"xxhash~=3.5",
],
4 changes: 2 additions & 2 deletions src/fairseq2/metrics/recorders/jsonl.py
@@ -9,7 +9,7 @@
import json
import re
from collections.abc import Mapping, Sequence
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import Final, TextIO, final

@@ -85,7 +85,7 @@ def sanitize(value: object) -> object:
f"`values` must consist of objects of types `{int}`, `{float}`, `{Tensor}`, and `{str}` only."
)

output: dict[str, object] = {"Time": datetime.utcnow().isoformat()}
output: dict[str, object] = {"Time": datetime.now(timezone.utc).isoformat()}

if step_nr is not None:
output["Step"] = step_nr
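The hunk above replaces the deprecated `datetime.utcnow()` — which returns a *naive* datetime — with a timezone-aware equivalent. The difference is visible in the ISO output:

```python
from datetime import datetime, timezone

aware = datetime.now(timezone.utc)
print(aware.tzinfo)            # UTC
print(aware.isoformat()[-6:])  # +00:00 — the offset is included
# datetime.utcnow() would give tzinfo=None and no offset in isoformat()
```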
3 changes: 3 additions & 0 deletions src/fairseq2/model_checkpoint/__init__.py
@@ -49,6 +49,9 @@
from fairseq2.model_checkpoint.loader import (
CorruptModelCheckpointError as CorruptModelCheckpointError,
)

# TODO: Deprecated, will be removed in v0.14.
ModelCheckpointError = CorruptModelCheckpointError
from fairseq2.model_checkpoint.loader import (
ModelCheckpointLoader as ModelCheckpointLoader,
)
21 changes: 14 additions & 7 deletions src/fairseq2/models/transformer/sdpa/flash3.py
@@ -13,11 +13,18 @@
from typing_extensions import override

try:
import flash_attn_3_cuda # type: ignore[import-not-found]
import flash_attn_3._C as _flash_attn_3_C # type: ignore[import-not-found,import-untyped] # noqa: F401,N812

_flash_attn_3_ops: Any = torch.ops.flash_attn_3
except ImportError:
_has_flash_attn_3 = False
else:
_has_flash_attn_3 = True
try:
import flash_attn_3_cuda as _flash_attn_3_C # type: ignore[import-not-found,no-redef] # noqa: F401,N812

_flash_attn_3_ops = _flash_attn_3_C
except ImportError:
_flash_attn_3_ops = None

_has_flash_attn_3 = _flash_attn_3_ops is not None

from fairseq2.error import NotSupportedError, OperationalError
from fairseq2.models.transformer.attention_bias import (
@@ -156,7 +163,7 @@ def _flash_attn_3_op(
v = _contiguous(v)

# fmt: off
out, softmax_lse, *_ = flash_attn_3_cuda.fwd(
out, softmax_lse, *_ = _flash_attn_3_ops.fwd(
q,
k,
v,
@@ -332,7 +339,7 @@ def _flash_attn_3_varlen_op(
cu_seqlens_k = _contiguous(cu_seqlens_k)

# fmt: off
out, softmax_lse, *_ = flash_attn_3_cuda.fwd(
out, softmax_lse, *_ = _flash_attn_3_ops.fwd(
q,
k,
v,
@@ -491,7 +498,7 @@ def _flash_attn_3_bwd_op(
rhs_window_size: int,
) -> None:
# fmt: off
flash_attn_3_cuda.bwd(
_flash_attn_3_ops.bwd(
dout,
q,
k,
3 changes: 3 additions & 0 deletions src/fairseq2/recipe/__init__.py
@@ -32,5 +32,8 @@
from fairseq2.evaluator import EvalUnit as EvalUnit
from fairseq2.generator import Generator as Generator
from fairseq2.generator import GeneratorUnit as GeneratorUnit
from fairseq2.task import Task as Task
from fairseq2.task import TaskStopException as TaskStopException
from fairseq2.trainer import Trainer as Trainer
from fairseq2.trainer import TrainUnit as TrainUnit
from fairseq2.validator import Validator as Validator
32 changes: 27 additions & 5 deletions src/fairseq2/recipe/config.py
@@ -439,7 +439,11 @@ def validate(self) -> ValidationResult:
)

if self.publish_metrics_every_n_steps is not None:
if self.validate_every_n_steps % self.publish_metrics_every_n_steps != 0: # fmt: skip
not_multiple = (
self.validate_every_n_steps % self.publish_metrics_every_n_steps
!= 0
)
if not_multiple:
result.add_error(
f"`validate_every_n_steps` must be a multiple of `publish_metrics_every_n_steps` ({self.publish_metrics_every_n_steps}), but is {self.validate_every_n_steps} instead."
)
@@ -451,7 +455,12 @@ def validate(self) -> ValidationResult:
)

if self.publish_metrics_every_n_data_epochs is not None:
if self.validate_every_n_data_epochs % self.publish_metrics_every_n_data_epochs != 0: # fmt: skip
not_multiple = (
self.validate_every_n_data_epochs
% self.publish_metrics_every_n_data_epochs
!= 0
)
if not_multiple:
result.add_error(
f"`validate_every_n_data_epochs` must be a multiple of `publish_metrics_every_n_data_epochs` ({self.publish_metrics_every_n_data_epochs}), but is {self.validate_every_n_data_epochs} instead."
)
@@ -463,7 +472,11 @@ def validate(self) -> ValidationResult:
)

if self.publish_metrics_every_n_steps is not None:
if self.checkpoint_every_n_steps % self.publish_metrics_every_n_steps != 0: # fmt: skip
not_multiple = (
self.checkpoint_every_n_steps % self.publish_metrics_every_n_steps
!= 0
)
if not_multiple:
result.add_error(
f"`checkpoint_every_n_steps` must be a multiple of `publish_metrics_every_n_steps` ({self.publish_metrics_every_n_steps}), but is {self.checkpoint_every_n_steps} instead."
)
Expand All @@ -475,7 +488,12 @@ def validate(self) -> ValidationResult:
)

if self.publish_metrics_every_n_data_epochs is not None:
if self.checkpoint_every_n_data_epochs % self.publish_metrics_every_n_data_epochs != 0: # fmt: skip
not_multiple = (
self.checkpoint_every_n_data_epochs
% self.publish_metrics_every_n_data_epochs
!= 0
)
if not_multiple:
result.add_error(
f"`checkpoint_every_n_data_epochs` must be a multiple of `publish_metrics_every_n_data_epochs` ({self.publish_metrics_every_n_data_epochs}), but is {self.checkpoint_every_n_data_epochs} instead."
)
@@ -509,7 +527,11 @@ def validate(self) -> ValidationResult:
)

if self.checkpoint_every_n_steps is not None:
if self.keep_checkpoint_every_n_steps % self.checkpoint_every_n_steps != 0: # fmt: skip
not_multiple = (
self.keep_checkpoint_every_n_steps % self.checkpoint_every_n_steps
!= 0
)
if not_multiple:
result.add_error(
f"`keep_checkpoint_every_n_steps` must be a multiple of `checkpoint_every_n_steps` ({self.checkpoint_every_n_steps}), but is {self.keep_checkpoint_every_n_steps} instead."
)
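The hunks above repeat the same `not_multiple` divisibility check four times with different field names. A hypothetical helper (not present in fairseq2) factoring out the pattern might look like:

```python
from __future__ import annotations

def check_multiple(name: str, value: int, divisor_name: str, divisor: int) -> str | None:
    # Return an error message when `value` is not a whole multiple of `divisor`,
    # mirroring the message format used in the validation code above.
    if value % divisor != 0:
        return (
            f"`{name}` must be a multiple of `{divisor_name}` ({divisor}), "
            f"but is {value} instead."
        )
    return None

print(check_multiple("validate_every_n_steps", 9, "publish_metrics_every_n_steps", 3))
# None
print(check_multiple("validate_every_n_steps", 10, "publish_metrics_every_n_steps", 3))
# prints the formatted error message
```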
10 changes: 10 additions & 0 deletions src/fairseq2/recipe/task.py
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.task import Task as Task # noqa: F401
from fairseq2.task import TaskStopException as TaskStopException # noqa: F401
11 changes: 11 additions & 0 deletions src/fairseq2/recipe/validator.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.validator import NOOP_VALIDATOR as NOOP_VALIDATOR # noqa: F401
from fairseq2.validator import StandardValidator as StandardValidator # noqa: F401
from fairseq2.validator import Validator as Validator # noqa: F401
9 changes: 9 additions & 0 deletions src/fairseq2/runtime/dependency.py
@@ -30,6 +30,7 @@
from fairseq2.error import InternalError, InvalidOperationError
from fairseq2.runtime.lazy import Lazy
from fairseq2.runtime.lookup import Lookup
from fairseq2.utils.warn import _warn_deprecated

T = TypeVar("T")

@@ -43,6 +44,14 @@ def maybe_resolve(
self, kls: type[T], *, key: Hashable | None = None
) -> T | None: ...

def resolve_optional(
self, kls: type[T], *, key: Hashable | None = None
) -> T | None:
_warn_deprecated(
"`resolve_optional()` is deprecated and will be removed in v0.14. Use `maybe_resolve()` instead."
)
return self.maybe_resolve(kls, key=key)

@abstractmethod
def iter_keys(self, kls: type[T]) -> Iterator[Hashable]: ...

12 changes: 10 additions & 2 deletions src/fairseq2/trainer.py
@@ -191,7 +191,11 @@ def __init__(
)

if publish_metrics_every_n_data_epochs is not None:
if validate_every_n_data_epochs % publish_metrics_every_n_data_epochs != 0: # fmt: skip
not_multiple = (
validate_every_n_data_epochs % publish_metrics_every_n_data_epochs
!= 0
)
if not_multiple:
raise ValueError(
f"`validate_every_n_data_epochs` must be a multiple of `publish_metrics_every_n_data_epochs` ({publish_metrics_every_n_data_epochs}), but is {validate_every_n_data_epochs} instead."
)
@@ -215,7 +219,11 @@ def __init__(
)

if publish_metrics_every_n_data_epochs is not None:
if checkpoint_every_n_data_epochs % publish_metrics_every_n_data_epochs != 0: # fmt: skip
not_multiple = (
checkpoint_every_n_data_epochs % publish_metrics_every_n_data_epochs
!= 0
)
if not_multiple:
raise ValueError(
f"`checkpoint_every_n_data_epochs` must be a multiple of `publish_metrics_every_n_data_epochs` ({publish_metrics_every_n_data_epochs}), but is {checkpoint_every_n_data_epochs} instead."
)
4 changes: 2 additions & 2 deletions tests/unit/nn/test_sharded.py
@@ -29,10 +29,10 @@ def test_get_shard_dims_work() -> None:

module = Sequential(
Linear(32, 32, bias=True),
ColumnShardedLinear(32, 32, bias=True, gangs=gangs),
ColumnShardedLinear(32, 32, bias=True, gangs=gangs, device=device),
Linear(32, 32, bias=False),
Sequential(
RowShardedLinear(32, 32, bias=True, gangs=gangs),
RowShardedLinear(32, 32, bias=True, gangs=gangs, device=device),
),
)
