Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
## [0.22.3-rc.1](https://github.com/sequential-parameter-optimization/spotforecast2/compare/v0.22.2...v0.22.3-rc.1) (2026-03-25)

### Bug Fixes

* cache via cache_home ([6996b2d](https://github.com/sequential-parameter-optimization/spotforecast2/commit/6996b2da1b4ea3dfb3ec7c070345c788a7dc97e0))

### Documentation

* demo10 -> demo100 ([8b52033](https://github.com/sequential-parameter-optimization/spotforecast2/commit/8b520331568910045c0ea0f01423b8c874b5229e))
* return df ([709c588](https://github.com/sequential-parameter-optimization/spotforecast2/commit/709c588dc984402216311b3a71bbd12c6dca448b))

## [0.22.2](https://github.com/sequential-parameter-optimization/spotforecast2/compare/v0.22.1...v0.22.2) (2026-03-25)

### Bug Fixes
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "spotforecast2"
version = "0.22.2"
version = "0.22.3-rc.1"
description = "Forecasting with spot"
readme = "README.md"
license = { text = "AGPL-3.0-or-later" }
Expand Down
7 changes: 1 addition & 6 deletions src/spotforecast2/manager/multitask/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,6 @@ class BaseTask:
data_frame_name:
Identifier for the active dataset, used for
cache-directory naming and model file naming.
cache_data:
Whether to cache intermediate data to disk. Boolean flag.
cache_home:
Cache directory path. String or Path.
agg_weights:
Expand Down Expand Up @@ -217,7 +215,6 @@ def __init__(
dataframe: Optional[pd.DataFrame] = None,
data_test: Optional[pd.DataFrame] = None,
data_frame_name: str = "default",
cache_data: bool = True,
cache_home: Optional[Path] = None,
agg_weights: Optional[List[float]] = None,
index_name: str = "DateTime",
Expand All @@ -244,7 +241,6 @@ def __init__(
self._dataframe = dataframe
self.data_frame_name = data_frame_name
self.data_test = data_test
self.cache_data = cache_data
self.cache_home = cache_home
self.agg_weights = agg_weights
self.index_name = index_name
Expand Down Expand Up @@ -317,7 +313,6 @@ def _build_config(self, **overrides: Any) -> ConfigMulti:
"use_exogenous_features": self.use_exogenous_features,
"index_name": self.index_name,
"cache_home": get_cache_home(self.cache_home),
"cache_data": self.cache_data,
"n_trials_optuna": self.n_trials_optuna,
"n_trials_spotoptim": self.n_trials_spotoptim,
"n_initial_spotoptim": self.n_initial_spotoptim,
Expand Down Expand Up @@ -552,7 +547,7 @@ def build_exogenous_features(self) -> "BaseTask":
longitude=self.config.longitude,
timezone=self.config.timezone,
freq="h",
cache_home=self.config.cache_home if self.cache_data else None,
cache_home=self.config.cache_home,
verbose=self.verbose,
)
self.logger.info(" Weather features: %s", weather_features.shape)
Expand Down
3 changes: 0 additions & 3 deletions src/spotforecast2/manager/multitask/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class MultiTask(BaseTask):
data_test: Pre-loaded input DataFrame with test data. The DataFrame must contain a
datetime column matching ``index_name`` plus at least one
numeric target column. Optional.
cache_data: Whether to cache intermediate data to disk.
cache_home: Cache directory path.
agg_weights: Per-target aggregation weights.
index_name: Datetime column name in the raw CSV / DataFrame.
Expand Down Expand Up @@ -107,7 +106,6 @@ def __init__(
dataframe: Optional[pd.DataFrame] = None,
data_test: Optional[pd.DataFrame] = None,
data_frame_name: str = "default",
cache_data: bool = True,
cache_home: Optional[Path] = None,
agg_weights: Optional[List[float]] = None,
index_name: str = "DateTime",
Expand Down Expand Up @@ -137,7 +135,6 @@ def __init__(
dataframe=dataframe,
data_test=data_test,
data_frame_name=data_frame_name,
cache_data=cache_data,
cache_home=cache_home,
agg_weights=agg_weights,
index_name=index_name,
Expand Down
72 changes: 40 additions & 32 deletions src/spotforecast2/manager/multitask/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@
def run(
dataframe: pd.DataFrame = None,
task: str = "lazy",
cache_data: bool = True,
cache_home: Optional[str] = None,
bounds: Optional[List[Tuple[float, float]]] = None,
agg_weights: Optional[List[float]] = None,
project_name: str = "test_project",
n_trials_optuna: Optional[int] = 10,
train_days: Optional[int] = 3 * 365,
val_days: Optional[int] = 31,
imputation_method: str = "weighted",
show_progress: bool = False,
plot_with_outliers: bool = False,
show: bool = False,
Expand All @@ -82,8 +82,6 @@ def run(
task: Pipeline mode — one of ``"lazy"``, ``"optuna"``,
``"spotoptim"``, ``"predict"``, or ``"clean"``.
Defaults to ``"lazy"``.
cache_data: Whether to cache the preprocessed data. Defaults to
``True``.
cache_home: Optional path to the cache directory. Defaults to
``None``, which uses the package default cache location that
is defined via spotforecast2_safe's `get_cache_home()`.
Expand All @@ -98,57 +96,71 @@ def run(
train_days: Optional number of days in the training window. Defaults to 3 years (1095 days).
val_days: Optional number of days in the validation window. If
``None``, the default of 31 days is used.
show_progress: Whether to show an Optuna progress bar during optimization. Default is False.
plot_with_outliers: Whether to generate a visualization of the data with outliers highlighted. Defaults to False.
show: Whether to display prediction figures after running each task. Defaults to False.
verbose: Default is False.
imputation_method: Method used for imputation of detected
outliers. Passed to the ``imputation_method`` argument of
MultiTask. Options are ``"weighted"`` or ``"linear"``. Defaults to ``"weighted"``.
show_progress:
Whether to print progress messages during pipeline execution.
Defaults to False.
plot_with_outliers:
Whether to generate a visualization of the data with outliers highlighted. Defaults to False.
show:
Whether to display prediction figures after running each task. Defaults to False.
verbose:
Whether to enable verbose output. Defaults to False.
log_level:
Logging level. Default is 40 (ERROR). Other common values include 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 50 (CRITICAL).
**kwargs: Additional keyword arguments forwarded verbatim to
MultiTask (e.g. ``predict_size``, ``train_days``,
``val_days``, ``cache_home``).
MultiTask.

Returns:
DataFrame whose index is the forecast horizon timestamps and
whose single column ``"forecast"`` contains the aggregated
predicted values. For the ``"clean"`` task an empty DataFrame
is returned.
DataFrame:
DataFrame whose index is the forecast horizon timestamps and
whose single column ``"forecast"`` contains the aggregated
predicted values. For the ``"clean"`` task an empty DataFrame
is returned.

Raises:
ValueError: If ``task`` is not one of the supported task names.
ValueError:
If ``task`` is not one of the supported task names.

Examples:
Run the pipeline using cached or default model parameters
(``"lazy"`` task):

```{python}
import pandas as pd
from spotforecast2.manager.multitask.runner import run
from spotforecast2_safe.data.fetch_data import fetch_data, get_package_data_home
import warnings
warnings.filterwarnings("ignore")

data_home = get_package_data_home()
df = fetch_data(filename=str(data_home / "demo10.csv"))
df = fetch_data(filename=str(data_home / "demo02.csv"))

forecast = run(df, task="lazy", project_name="demo10", predict_size=24)
forecast = run(df, task="lazy", project_name="demo02", train_days = 365, predict_size=24, imputation_method="linear")
print(forecast)
```

Tune hyperparameters via Optuna Bayesian search (``"optuna"`` task):

```{python}
import pandas as pd
from spotforecast2.manager.multitask.runner import run
from spotforecast2_safe.data.fetch_data import fetch_data, get_package_data_home
import warnings
warnings.filterwarnings("ignore")

data_home = get_package_data_home()
df = fetch_data(filename=str(data_home / "demo10.csv"))
df = fetch_data(filename=str(data_home / "demo02.csv"))

forecast = run(
df,
task="optuna",
project_name="demo10",
n_trials_optuna=20,
project_name="demo02",
n_trials_optuna=5,
predict_size=24,
train_days=365,
val_days=7,
imputation_method="linear"
)
print(forecast)
```
Expand All @@ -158,14 +170,15 @@ def run(
``"optuna"``) must have saved models to the cache first:

```{python}
import pandas as pd
from spotforecast2.manager.multitask.runner import run
from spotforecast2_safe.data.fetch_data import fetch_data, get_package_data_home
import warnings
warnings.filterwarnings("ignore")

data_home = get_package_data_home()
df = fetch_data(filename=str(data_home / "demo10.csv"))
df = fetch_data(filename=str(data_home / "demo02.csv"))

forecast = run(df, task="predict", project_name="demo10", predict_size=24)
forecast = run(df, task="predict", project_name="demo02", predict_size=24, imputation_method="linear")
print(forecast)
```

Expand All @@ -175,25 +188,20 @@ def run(
```{python}
from spotforecast2.manager.multitask.runner import run

result = run(task="clean", project_name="demo10")
result = run(task="clean", project_name="demo02")
print(result.empty)
```
"""
if task not in _ALL_TASKS:
raise ValueError(f"Unknown task '{task}'. Choose from: {sorted(_ALL_TASKS)}")

if cache_data and cache_home is None:
# issue a warning if caching is enabled but no cache_home is provided, as this will use the package default cache location
print(
f"[run] Warning: cache_data is True but no cache_home provided. Using package default cache location {get_cache_home()}."
)
if cache_home is None:
cache_home = get_cache_home()

if task == "clean":
mt = MultiTask(
task="clean",
data_frame_name=project_name,
cache_data=True,
cache_home=cache_home,
**kwargs,
)
Expand All @@ -211,11 +219,11 @@ def run(
data_frame_name=project_name,
agg_weights=effective_agg_weights,
bounds=effective_bounds,
cache_data=cache_data,
cache_home=cache_home,
n_trials_optuna=n_trials_optuna,
train_days=train_days,
val_days=val_days,
imputation_method=imputation_method,
show_progress=show_progress,
verbose=verbose,
log_level=log_level,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def test_bounds_default_is_none(self):
def test_project_name_default(self):
assert self._sig().parameters["project_name"].default == "test_project"

def test_no_cache_data_param(self):
assert "cache_data" not in self._sig().parameters


# ---------------------------------------------------------------------------
# ValueError on unknown task
Expand Down Expand Up @@ -154,7 +157,6 @@ def test_multitask_constructed_with_clean(self, MockMT):
MockMT.assert_called_once_with(
task="clean",
data_frame_name="mydata",
cache_data=True,
cache_home=get_cache_home(),
)

Expand Down
48 changes: 22 additions & 26 deletions tests/test_runner_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

- show=True/False forwarded to mt.run()
- plot_with_outliers=True/False controls mt.plot_with_outliers() call
- cache_data=True with no explicit cache_home triggers warning and
auto-resolves to get_cache_home()
- cache_data=False leaves cache_home as None without printing a warning
- Explicit cache_home is forwarded as-is and suppresses the warning
- cache_home=None auto-resolves to get_cache_home() for both pipeline
and clean tasks
- Explicit cache_home is forwarded as-is for both pipeline and clean tasks
- No warning is printed when cache_home is None
- Custom agg_weights forwarded to MultiTask constructor
- Scalar parameters n_trials_optuna, train_days, val_days, show_progress,
verbose, and log_level forwarded to MultiTask constructor
Expand Down Expand Up @@ -124,57 +124,53 @@ def test_plot_with_outliers_clean_task_never_called(self, MockMT):


# ---------------------------------------------------------------------------
# cache_data / cache_home interaction
# cache_home auto-resolution
# ---------------------------------------------------------------------------


class TestCacheDataBehavior:
class TestCacheHomeBehavior:
"""Tests auto-resolution and forwarding of cache_home."""

@patch("spotforecast2.manager.multitask.runner.MultiTask")
def test_cache_data_true_no_home_uses_get_cache_home(self, MockMT):
def test_none_cache_home_auto_resolves_for_pipeline_task(self, MockMT):
from spotforecast2_safe.data.fetch_data import get_cache_home

mt = _mock_mt()
MockMT.return_value = mt
run(_DUMMY_DF, task="lazy", cache_data=True, cache_home=None)
run(_DUMMY_DF, task="lazy", cache_home=None)
_, kwargs = MockMT.call_args
assert kwargs["cache_home"] == get_cache_home()

@patch("spotforecast2.manager.multitask.runner.MultiTask")
def test_cache_data_true_no_home_prints_warning(self, MockMT, capsys):
MockMT.return_value = _mock_mt()
run(_DUMMY_DF, task="lazy", cache_data=True, cache_home=None)
captured = capsys.readouterr()
assert "Warning" in captured.out
def test_none_cache_home_auto_resolves_for_clean_task(self, MockMT):
from spotforecast2_safe.data.fetch_data import get_cache_home

@patch("spotforecast2.manager.multitask.runner.MultiTask")
def test_cache_data_false_no_warning_printed(self, MockMT, capsys):
MockMT.return_value = _mock_mt()
run(_DUMMY_DF, task="lazy", cache_data=False)
captured = capsys.readouterr()
assert captured.out == ""
mt = _mock_mt()
MockMT.return_value = mt
run(_DUMMY_DF, task="clean", cache_home=None)
_, kwargs = MockMT.call_args
assert kwargs["cache_home"] == get_cache_home()

@patch("spotforecast2.manager.multitask.runner.MultiTask")
def test_cache_data_false_cache_home_none_forwarded(self, MockMT):
def test_explicit_cache_home_forwarded_for_pipeline_task(self, MockMT):
mt = _mock_mt()
MockMT.return_value = mt
run(_DUMMY_DF, task="lazy", cache_data=False, cache_home=None)
run(_DUMMY_DF, task="lazy", cache_home="/my/cache")
_, kwargs = MockMT.call_args
assert kwargs["cache_home"] is None
assert kwargs["cache_home"] == "/my/cache"

@patch("spotforecast2.manager.multitask.runner.MultiTask")
def test_explicit_cache_home_forwarded_as_is(self, MockMT):
def test_explicit_cache_home_forwarded_for_clean_task(self, MockMT):
mt = _mock_mt()
MockMT.return_value = mt
run(_DUMMY_DF, task="lazy", cache_data=True, cache_home="/my/cache")
run(_DUMMY_DF, task="clean", cache_home="/my/cache")
_, kwargs = MockMT.call_args
assert kwargs["cache_home"] == "/my/cache"

@patch("spotforecast2.manager.multitask.runner.MultiTask")
def test_explicit_cache_home_suppresses_warning(self, MockMT, capsys):
def test_no_warning_printed_when_cache_home_none(self, MockMT, capsys):
MockMT.return_value = _mock_mt()
run(_DUMMY_DF, task="lazy", cache_data=True, cache_home="/my/cache")
run(_DUMMY_DF, task="lazy", cache_home=None)
captured = capsys.readouterr()
assert captured.out == ""

Expand Down
Loading