From ccb8c9f363a036e331dbb3f3d04f8d1ce6364793 Mon Sep 17 00:00:00 2001 From: Zack Hubert Date: Mon, 16 Mar 2026 06:35:48 -0700 Subject: [PATCH] fix(workflow): complete model selection for ai.summarize and builtin templates ai.summarize (added in #432) was missing the SetModel() call that all other AI action handlers received in #431. Also extends the builtin template parameter system so that plan, code, document, ci, and review templates accept a `model` param and pass it through to their AI states. Documents settings.model, per-state model, and template model param in docs/workflow.html. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/workflow.html | 68 ++++++++++++++++++++++++++++- internal/daemon/actions_test.go | 64 +++++++++++++++++++++++++++ internal/daemon/coding.go | 1 + internal/workflow/defaults.go | 11 +++++ internal/workflow/template.go | 9 ++-- internal/workflow/template_test.go | 69 ++++++++++++++++++++++++++++++ 6 files changed, 218 insertions(+), 4 deletions(-) diff --git a/docs/workflow.html b/docs/workflow.html index 8b66b92..26dbe05 100644 --- a/docs/workflow.html +++ b/docs/workflow.html @@ -1283,6 +1283,7 @@

Example config

coding: type: task action: ai.code + model: opus # override the settings-level model for this state params: max_turns: 50 max_duration: 30m @@ -1379,6 +1380,18 @@

settings block reference

built-in Custom Docker image to use for containerized Claude sessions. + + model + string + CLI default + + Default Claude model for all AI states. Accepts aliases + (opus, sonnet, haiku) or + full canonical IDs (e.g. claude-haiku-4-5-20251001). + Can be overridden per-state with the state-level + model field. + + @@ -1393,7 +1406,8 @@

settings block reference

max_turns: 50 # stop session after 50 turns max_duration: 30 # stop session after 30 minutes auto_merge: true # merge automatically when CI passes - merge_method: squash # rebase | squash | merge + merge_method: squash # rebase | squash | merge + model: sonnet # default model for all AI states

triggers block

@@ -1643,6 +1657,53 @@

wait

wait.

+

model

+

+ Any AI state (ai.code, ai.plan, + ai.document, ai.fix_ci, + ai.resolve_conflicts, ai.address_review, + ai.review, ai.summarize) can set a + model field to override the Claude model used for that + state. Accepts aliases (opus, sonnet, + haiku) or full canonical IDs. Overrides the + settings.model default. If neither is set, the CLI + default model is used. +

+
+
+ per-state model override +
+
coding:
+  type: task
+  action: ai.code
+  model: opus            # use opus for this state only
+  next: open_pr
+
+fix_ci:
+  type: task
+  action: ai.fix_ci
+  model: haiku           # use a faster model for CI fixes
+  next: push_ci_fix
+
+

+ Builtin templates that contain AI actions also accept a + model parameter, which is passed through to the + template's AI states: +

+
+
+ model via template param +
+
coding:
+  type: template
+  use: builtin:code
+  params:
+    model: haiku
+  exits:
+    success: open_pr
+    failure: failed
+
+

choice

A choice state reads values from the accumulated step @@ -1809,6 +1870,7 @@

Built-in templates

containerizedbooltrueRun the planning session inside a container. + modelstringnoneClaude model for the planning session (e.g. haiku, sonnet, opus). @@ -1843,6 +1905,7 @@

Built-in templates

containerizedbooltrueRun the coding session inside a container. simplifyboolfalseRun the simplify pass after coding to clean up the implementation. + modelstringnoneClaude model for the coding session (e.g. haiku, sonnet, opus). @@ -1877,6 +1940,7 @@

Built-in templates

containerizedbooltrueRun the documentation session inside a container. + modelstringnoneClaude model for the documentation session (e.g. haiku, sonnet, opus). @@ -1944,6 +2008,7 @@

Built-in templates

simplifyboolfalseRun the simplify pass when resolving conflicts or fixing CI. + modelstringnoneClaude model for the AI states (fix CI, resolve conflicts) in this template. @@ -1978,6 +2043,7 @@

Built-in templates

simplifyboolfalseRun the simplify pass when addressing review feedback. + modelstringnoneClaude model for the address-review AI state in this template. diff --git a/internal/daemon/actions_test.go b/internal/daemon/actions_test.go index 5665431..ba265d1 100644 --- a/internal/daemon/actions_test.go +++ b/internal/daemon/actions_test.go @@ -10398,6 +10398,70 @@ func TestStartSummarize_CreatesWorker(t *testing.T) { } } +func TestStartSummarize_SetsModel(t *testing.T) { + cfg := testConfig() + cfg.Repos = []string{"/test/repo"} + + mockExec := exec.NewMockExecutor(nil) + mockExec.AddPrefixMatch("git", []string{"symbolic-ref"}, exec.MockResponse{ + Stdout: []byte("refs/remotes/origin/main\n"), + }) + mockExec.AddPrefixMatch("git", []string{"diff"}, exec.MockResponse{ + Stdout: []byte("diff --git a/foo.go b/foo.go\n+added line\n"), + }) + + d := testDaemonWithExec(cfg, mockExec) + d.repoFilter = "/test/repo" + + // Configure workflow with per-state model. + wfCfg := &workflow.Config{ + States: map[string]*workflow.State{ + "summarize": { + Type: workflow.StateTypeTask, + Action: "ai.summarize", + Model: "haiku", + }, + }, + } + d.workflowConfigs = map[string]*workflow.Config{"/test/repo": wfCfg} + + var capturedRunner *claude.MockRunner + d.sessionMgr.SetRunnerFactory(func(sessionID, workingDir, repoPath string, sessionStarted bool, initialMessages []claude.Message) claude.RunnerInterface { + r := claude.NewMockRunner(sessionID, false, nil) + capturedRunner = r + return r + }) + + sess := testSession("sess-1") + sess.RepoPath = "/test/repo" + sess.WorkTree = "/test/worktree-sess-1" + sess.BaseBranch = "main" + cfg.AddSession(*sess) + + d.state.AddWorkItem(&daemonstate.WorkItem{ + ID: "work-1", + IssueRef: config.IssueRef{Source: "github", ID: "42", Title: "Fix bug"}, + SessionID: "sess-1", + Branch: "feature-42", + CurrentStep: "summarize", + StepData: map[string]any{"_repo_path": "/test/repo"}, + }) + + item, _ := d.state.GetWorkItem("work-1") + + err := d.startSummarize(t.Context(), item) + if err != nil { + t.Fatalf("startSummarize failed: %v", err) + } + + if capturedRunner == nil { + t.Fatal("expected runner factory to be called") + } + if got := capturedRunner.GetModel(); got != "claude-haiku-4-5-20251001" { + t.Errorf("expected model %q, got %q", "claude-haiku-4-5-20251001", got) + } +} + // --- injectScheduledIssue tests --- func TestInjectScheduledIssue_EnqueuesWorkItem(t *testing.T) { diff --git a/internal/daemon/coding.go b/internal/daemon/coding.go index 498fce8..f609255 100644 --- a/internal/daemon/coding.go +++ b/internal/daemon/coding.go @@ -1529,6 +1529,7 @@ func (d *Daemon) startSummarize(ctx context.Context, item daemonstate.WorkItem) w := d.createWorkerWithPrompt(ctx, item, sess, initialMsg, resolvedPrompt, summarizeTools) runner := d.sessionMgr.GetOrCreateRunner(sess) runner.SetDisallowedTools(claude.ToolSetPlanningDeny) + runner.SetModel(d.resolveStateModel(wfCfg, item.CurrentStep)) w.SetPlanningMode(true) maxTurns := params.Int("max_turns", 0) maxDuration := params.Duration("max_duration", 0) diff --git a/internal/workflow/defaults.go b/internal/workflow/defaults.go index 8a6c173..eb29ad0 100644 --- a/internal/workflow/defaults.go +++ b/internal/workflow/defaults.go @@ -259,12 +259,14 @@ func PlanTemplateConfig() *TemplateConfig { }, Params: []TemplateParam{ {Name: "containerized", Default: true}, + {Name: "model", Default: ""}, }, States: map[string]*State{ "planning": { Type: StateTypeTask, Action: "ai.plan", DisplayName: "Planning", + Model: "{{model}}", Params: map[string]any{ "max_turns": 30, "max_duration": "15m", @@ -328,12 +330,14 @@ func CodeTemplateConfig() *TemplateConfig { Params: []TemplateParam{ {Name: "simplify", Default: false}, {Name: "containerized", Default: true}, + {Name: "model", Default: ""}, }, States: map[string]*State{ "coding": { Type: StateTypeTask, Action: "ai.code", DisplayName: "Coding", + Model: "{{model}}", Params: map[string]any{ "max_turns": 50, "max_duration": "30m", @@ -366,12 +370,14 @@ func DocumentTemplateConfig() *TemplateConfig { }, Params: []TemplateParam{ {Name: "containerized", Default: true}, + {Name: "model", Default: ""}, }, States: map[string]*State{ "documenting": { Type: StateTypeTask, Action: "ai.document", DisplayName: "Documenting", + Model: "{{model}}", Params: map[string]any{ "max_turns": 50, "max_duration": "30m", @@ -441,6 +447,7 @@ func CITemplateConfig() *TemplateConfig { }, Params: []TemplateParam{ {Name: "simplify", Default: false}, + {Name: "model", Default: ""}, }, States: map[string]*State{ "await_ci": { @@ -480,6 +487,7 @@ func CITemplateConfig() *TemplateConfig { Type: StateTypeTask, Action: "ai.resolve_conflicts", DisplayName: "Resolving Conflicts", + Model: "{{model}}", Params: map[string]any{ "max_conflict_rounds": 3, "simplify": "{{simplify}}", @@ -499,6 +507,7 @@ func CITemplateConfig() *TemplateConfig { Type: StateTypeTask, Action: "ai.fix_ci", DisplayName: "Fixing CI", + Model: "{{model}}", Params: map[string]any{ "max_ci_fix_rounds": 3, "simplify": "{{simplify}}", @@ -559,6 +568,7 @@ func ReviewTemplateConfig() *TemplateConfig { }, Params: []TemplateParam{ {Name: "simplify", Default: false}, + {Name: "model", Default: ""}, }, States: map[string]*State{ "await_review": { @@ -589,6 +599,7 @@ func ReviewTemplateConfig() *TemplateConfig { Type: StateTypeTask, Action: "ai.address_review", DisplayName: "Addressing Review", + Model: "{{model}}", Params: map[string]any{ "max_review_rounds": 3, "simplify": "{{simplify}}", diff --git a/internal/workflow/template.go b/internal/workflow/template.go index a98a877..a739935 100644 --- a/internal/workflow/template.go +++ b/internal/workflow/template.go @@ -261,12 +261,15 @@ func resolveParams(defs []TemplateParam, overrides map[string]any) map[string]an var paramPlaceholder = regexp.MustCompile(`\{\{(\w+)\}\}`) // applyParamSubstitution replaces {{param_name}} placeholders in the string -// values of state.Params with the corresponding value from params. -// Only string values within state.Params are substituted. +// values of state.Params and state.Model with the corresponding value from params. func applyParamSubstitution(state *State, params map[string]any) { - if len(state.Params) == 0 || len(params) == 0 { + if len(params) == 0 { return } + // Substitute in state.Model (e.g. "{{model}}" → "haiku"). + if state.Model != "" { + state.Model = substituteParams(state.Model, params) + } for k, v := range state.Params { if s, ok := v.(string); ok { // When the entire value is a single {{name}} placeholder, replace it diff --git a/internal/workflow/template_test.go b/internal/workflow/template_test.go index b14252e..d2f3970 100644 --- a/internal/workflow/template_test.go +++ b/internal/workflow/template_test.go @@ -1777,6 +1777,75 @@ func TestExpandTemplates_BuiltinLinearAwaitState(t *testing.T) { } } +func TestExpandTemplates_BuiltinModelParam(t *testing.T) { + tests := []struct { + name string + builtin string + aiStates []string + }{ + {"plan", "builtin:plan", []string{"_t_start_planning"}}, + {"code", "builtin:code", []string{"_t_start_coding"}}, + {"document", "builtin:document", []string{"_t_start_documenting"}}, + {"ci", "builtin:ci", []string{"_t_start_fix_ci", "_t_start_resolve_conflicts"}}, + {"review", "builtin:review", []string{"_t_start_address_review"}}, + } + for _, tt := range tests { + t.Run(tt.name+"_default_empty", func(t *testing.T) { + cfg := minimalCfg(map[string]*State{ + "start": { + Type: StateTypeTemplate, + Use: tt.builtin, + Exits: map[string]string{"success": "done", "failure": "failed"}, + }, + }) + if tt.builtin == "builtin:review" { + cfg.States["start"].Exits["ci_regression"] = "failed" + } + cfg.Start = "start" + result, err := ExpandTemplates(cfg, t.TempDir()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, stateName := range tt.aiStates { + s := result.States[stateName] + if s == nil { + t.Fatalf("state %q missing", stateName) + } + if s.Model != "" { + t.Errorf("%s: model should be empty by default, got %q", stateName, s.Model) + } + } + }) + t.Run(tt.name+"_model_override", func(t *testing.T) { + cfg := minimalCfg(map[string]*State{ + "start": { + Type: StateTypeTemplate, + Use: tt.builtin, + Params: map[string]any{"model": "haiku"}, + Exits: map[string]string{"success": "done", "failure": "failed"}, + }, + }) + if tt.builtin == "builtin:review" { + cfg.States["start"].Exits["ci_regression"] = "failed" + } + cfg.Start = "start" + result, err := ExpandTemplates(cfg, t.TempDir()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, stateName := range tt.aiStates { + s := result.States[stateName] + if s == nil { + t.Fatalf("state %q missing", stateName) + } + if s.Model != "haiku" { + t.Errorf("%s: model = %q, want %q", stateName, s.Model, "haiku") + } + } + }) + } +} + // TestExpandTemplates_ModularComposition verifies that multiple modular templates // can be composed into a complete workflow (the primary use case). func TestExpandTemplates_ModularComposition(t *testing.T) {