Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion docs/workflow.html
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,7 @@ <h3>Example config</h3>
<span class="ck">coding:</span>
<span class="ck">type:</span> <span class="cs">task</span>
<span class="ck">action:</span> <span class="ca">ai.code</span>
<span class="ck">model:</span> <span class="cv">opus</span> <span class="cc"># override the settings-level model for this state</span>
<span class="ck">params:</span>
<span class="ck">max_turns:</span> <span class="cv">50</span>
<span class="ck">max_duration:</span> <span class="cv">30m</span>
Expand Down Expand Up @@ -1379,6 +1380,18 @@ <h3 id="settings">settings block reference</h3>
<td><em>built-in</em></td>
<td>Custom Docker image to use for containerized Claude sessions.</td>
</tr>
<tr>
<td><code>model</code></td>
<td>string</td>
<td><em>CLI default</em></td>
<td>
Default Claude model for all AI states. Accepts aliases
(<code>opus</code>, <code>sonnet</code>, <code>haiku</code>) or
full canonical IDs (e.g. <code>claude-haiku-4-5-20251001</code>).
Can be overridden per-state with the state-level
<code>model</code> field.
</td>
</tr>
</tbody>
</table>

Expand All @@ -1393,7 +1406,8 @@ <h3 id="settings">settings block reference</h3>
<span class="ck">max_turns:</span> <span class="cv">50</span> <span class="cc"># stop session after 50 turns</span>
<span class="ck">max_duration:</span> <span class="cv">30</span> <span class="cc"># stop session after 30 minutes</span>
<span class="ck">auto_merge:</span> <span class="cv">true</span> <span class="cc"># merge automatically when CI passes</span>
<span class="ck">merge_method:</span> <span class="cv">squash</span> <span class="cc"># rebase | squash | merge</span></pre>
<span class="ck">merge_method:</span> <span class="cv">squash</span> <span class="cc"># rebase | squash | merge</span>
<span class="ck">model:</span> <span class="cv">sonnet</span> <span class="cc"># default model for all AI states</span></pre>
</div>

<h3 id="triggers">triggers block</h3>
Expand Down Expand Up @@ -1643,6 +1657,53 @@ <h3 id="state-wait">wait</h3>
<code>wait</code>.
</p>

<h4 id="state-model">model</h4>
<p>
Any AI state (<code>ai.code</code>, <code>ai.plan</code>,
<code>ai.document</code>, <code>ai.fix_ci</code>,
<code>ai.resolve_conflicts</code>, <code>ai.address_review</code>,
<code>ai.review</code>, <code>ai.summarize</code>) can set a
<code>model</code> field to override the Claude model used for that
state. Accepts aliases (<code>opus</code>, <code>sonnet</code>,
<code>haiku</code>) or full canonical IDs. Overrides the
<code>settings.model</code> default. If neither is set, the CLI
default model is used.
</p>
<div class="code-block">
<div class="code-header">
<span class="code-filename">per-state model override</span>
</div>
<pre><span class="ck">coding:</span>
<span class="ck">type:</span> <span class="cs">task</span>
<span class="ck">action:</span> <span class="ca">ai.code</span>
<span class="ck">model:</span> <span class="cv">opus</span> <span class="cc"># use opus for this state only</span>
<span class="ck">next:</span> <span class="cv">open_pr</span>

<span class="ck">fix_ci:</span>
<span class="ck">type:</span> <span class="cs">task</span>
<span class="ck">action:</span> <span class="ca">ai.fix_ci</span>
<span class="ck">model:</span> <span class="cv">haiku</span> <span class="cc"># use a faster model for CI fixes</span>
<span class="ck">next:</span> <span class="cv">push_ci_fix</span></pre>
</div>
<p>
Built-in templates that contain AI actions also accept a
<code>model</code> parameter, which is passed through to the
template's AI states:
</p>
<div class="code-block">
<div class="code-header">
<span class="code-filename">model via template param</span>
</div>
<pre><span class="ck">coding:</span>
<span class="ck">type:</span> <span class="cs">template</span>
<span class="ck">use:</span> <span class="cv">builtin:code</span>
<span class="ck">params:</span>
<span class="ck">model:</span> <span class="cv">haiku</span>
<span class="ck">exits:</span>
<span class="ck">success:</span> <span class="cv">open_pr</span>
<span class="ck">failure:</span> <span class="cv">failed</span></pre>
</div>

<h3 id="state-choice">choice</h3>
<p>
A <code>choice</code> state reads values from the accumulated step
Expand Down Expand Up @@ -1809,6 +1870,7 @@ <h4 id="templates">Built-in templates</h4>
</thead>
<tbody>
<tr><td>containerized</td><td>bool</td><td>true</td><td>Run the planning session inside a container.</td></tr>
<tr><td>model</td><td>string</td><td><em>none</em></td><td>Claude model for the planning session (e.g. <code>haiku</code>, <code>sonnet</code>, <code>opus</code>).</td></tr>
</tbody>
</table>
</div>
Expand Down Expand Up @@ -1843,6 +1905,7 @@ <h4 id="templates">Built-in templates</h4>
<tbody>
<tr><td>containerized</td><td>bool</td><td>true</td><td>Run the coding session inside a container.</td></tr>
<tr><td>simplify</td><td>bool</td><td>false</td><td>Run the simplify pass after coding to clean up the implementation.</td></tr>
<tr><td>model</td><td>string</td><td><em>none</em></td><td>Claude model for the coding session (e.g. <code>haiku</code>, <code>sonnet</code>, <code>opus</code>).</td></tr>
</tbody>
</table>
</div>
Expand Down Expand Up @@ -1877,6 +1940,7 @@ <h4 id="templates">Built-in templates</h4>
</thead>
<tbody>
<tr><td>containerized</td><td>bool</td><td>true</td><td>Run the documentation session inside a container.</td></tr>
<tr><td>model</td><td>string</td><td><em>none</em></td><td>Claude model for the documentation session (e.g. <code>haiku</code>, <code>sonnet</code>, <code>opus</code>).</td></tr>
</tbody>
</table>
</div>
Expand Down Expand Up @@ -1944,6 +2008,7 @@ <h4 id="templates">Built-in templates</h4>
</thead>
<tbody>
<tr><td>simplify</td><td>bool</td><td>false</td><td>Run the simplify pass when resolving conflicts or fixing CI.</td></tr>
<tr><td>model</td><td>string</td><td><em>none</em></td><td>Claude model for the AI states (fix CI, resolve conflicts) in this template.</td></tr>
</tbody>
</table>
</div>
Expand Down Expand Up @@ -1978,6 +2043,7 @@ <h4 id="templates">Built-in templates</h4>
</thead>
<tbody>
<tr><td>simplify</td><td>bool</td><td>false</td><td>Run the simplify pass when addressing review feedback.</td></tr>
<tr><td>model</td><td>string</td><td><em>none</em></td><td>Claude model for the address-review AI state in this template.</td></tr>
</tbody>
</table>
</div>
Expand Down
64 changes: 64 additions & 0 deletions internal/daemon/actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10398,6 +10398,70 @@ func TestStartSummarize_CreatesWorker(t *testing.T) {
}
}

// TestStartSummarize_SetsModel checks that startSummarize forwards the model
// configured on the current workflow state to the runner it creates. The state
// is configured with the "haiku" alias and the runner is expected to report the
// canonical model ID.
func TestStartSummarize_SetsModel(t *testing.T) {
	const (
		repoPath  = "/test/repo"
		wantModel = "claude-haiku-4-5-20251001"
	)

	cfg := testConfig()
	cfg.Repos = []string{repoPath}

	// Canned git responses for the symbolic-ref and diff invocations.
	gitMock := exec.NewMockExecutor(nil)
	gitMock.AddPrefixMatch("git", []string{"symbolic-ref"}, exec.MockResponse{
		Stdout: []byte("refs/remotes/origin/main\n"),
	})
	gitMock.AddPrefixMatch("git", []string{"diff"}, exec.MockResponse{
		Stdout: []byte("diff --git a/foo.go b/foo.go\n+added line\n"),
	})

	d := testDaemonWithExec(cfg, gitMock)
	d.repoFilter = repoPath

	// Workflow whose "summarize" state carries a per-state model override.
	d.workflowConfigs = map[string]*workflow.Config{
		repoPath: {
			States: map[string]*workflow.State{
				"summarize": {
					Type:   workflow.StateTypeTask,
					Action: "ai.summarize",
					Model:  "haiku",
				},
			},
		},
	}

	// Remember the runner handed out by the factory so it can be inspected
	// once startSummarize returns.
	var captured *claude.MockRunner
	d.sessionMgr.SetRunnerFactory(func(sessionID, workingDir, repoPath string, sessionStarted bool, initialMessages []claude.Message) claude.RunnerInterface {
		captured = claude.NewMockRunner(sessionID, false, nil)
		return captured
	})

	sess := testSession("sess-1")
	sess.RepoPath = repoPath
	sess.WorkTree = "/test/worktree-sess-1"
	sess.BaseBranch = "main"
	cfg.AddSession(*sess)

	d.state.AddWorkItem(&daemonstate.WorkItem{
		ID:          "work-1",
		IssueRef:    config.IssueRef{Source: "github", ID: "42", Title: "Fix bug"},
		SessionID:   "sess-1",
		Branch:      "feature-42",
		CurrentStep: "summarize",
		StepData:    map[string]any{"_repo_path": repoPath},
	})

	item, _ := d.state.GetWorkItem("work-1")

	if err := d.startSummarize(t.Context(), item); err != nil {
		t.Fatalf("startSummarize failed: %v", err)
	}

	if captured == nil {
		t.Fatal("expected runner factory to be called")
	}
	if got := captured.GetModel(); got != wantModel {
		t.Errorf("expected model %q, got %q", wantModel, got)
	}
}

// --- injectScheduledIssue tests ---

func TestInjectScheduledIssue_EnqueuesWorkItem(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions internal/daemon/coding.go
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,7 @@ func (d *Daemon) startSummarize(ctx context.Context, item daemonstate.WorkItem)
w := d.createWorkerWithPrompt(ctx, item, sess, initialMsg, resolvedPrompt, summarizeTools)
runner := d.sessionMgr.GetOrCreateRunner(sess)
runner.SetDisallowedTools(claude.ToolSetPlanningDeny)
runner.SetModel(d.resolveStateModel(wfCfg, item.CurrentStep))
w.SetPlanningMode(true)
maxTurns := params.Int("max_turns", 0)
maxDuration := params.Duration("max_duration", 0)
Expand Down
11 changes: 11 additions & 0 deletions internal/workflow/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,14 @@ func PlanTemplateConfig() *TemplateConfig {
},
Params: []TemplateParam{
{Name: "containerized", Default: true},
{Name: "model", Default: ""},
},
States: map[string]*State{
"planning": {
Type: StateTypeTask,
Action: "ai.plan",
DisplayName: "Planning",
Model: "{{model}}",
Params: map[string]any{
"max_turns": 30,
"max_duration": "15m",
Expand Down Expand Up @@ -328,12 +330,14 @@ func CodeTemplateConfig() *TemplateConfig {
Params: []TemplateParam{
{Name: "simplify", Default: false},
{Name: "containerized", Default: true},
{Name: "model", Default: ""},
},
States: map[string]*State{
"coding": {
Type: StateTypeTask,
Action: "ai.code",
DisplayName: "Coding",
Model: "{{model}}",
Params: map[string]any{
"max_turns": 50,
"max_duration": "30m",
Expand Down Expand Up @@ -366,12 +370,14 @@ func DocumentTemplateConfig() *TemplateConfig {
},
Params: []TemplateParam{
{Name: "containerized", Default: true},
{Name: "model", Default: ""},
},
States: map[string]*State{
"documenting": {
Type: StateTypeTask,
Action: "ai.document",
DisplayName: "Documenting",
Model: "{{model}}",
Params: map[string]any{
"max_turns": 50,
"max_duration": "30m",
Expand Down Expand Up @@ -441,6 +447,7 @@ func CITemplateConfig() *TemplateConfig {
},
Params: []TemplateParam{
{Name: "simplify", Default: false},
{Name: "model", Default: ""},
},
States: map[string]*State{
"await_ci": {
Expand Down Expand Up @@ -480,6 +487,7 @@ func CITemplateConfig() *TemplateConfig {
Type: StateTypeTask,
Action: "ai.resolve_conflicts",
DisplayName: "Resolving Conflicts",
Model: "{{model}}",
Params: map[string]any{
"max_conflict_rounds": 3,
"simplify": "{{simplify}}",
Expand All @@ -499,6 +507,7 @@ func CITemplateConfig() *TemplateConfig {
Type: StateTypeTask,
Action: "ai.fix_ci",
DisplayName: "Fixing CI",
Model: "{{model}}",
Params: map[string]any{
"max_ci_fix_rounds": 3,
"simplify": "{{simplify}}",
Expand Down Expand Up @@ -559,6 +568,7 @@ func ReviewTemplateConfig() *TemplateConfig {
},
Params: []TemplateParam{
{Name: "simplify", Default: false},
{Name: "model", Default: ""},
},
States: map[string]*State{
"await_review": {
Expand Down Expand Up @@ -589,6 +599,7 @@ func ReviewTemplateConfig() *TemplateConfig {
Type: StateTypeTask,
Action: "ai.address_review",
DisplayName: "Addressing Review",
Model: "{{model}}",
Params: map[string]any{
"max_review_rounds": 3,
"simplify": "{{simplify}}",
Expand Down
9 changes: 6 additions & 3 deletions internal/workflow/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,15 @@ func resolveParams(defs []TemplateParam, overrides map[string]any) map[string]an
var paramPlaceholder = regexp.MustCompile(`\{\{(\w+)\}\}`)

// applyParamSubstitution replaces {{param_name}} placeholders in the string
// values of state.Params with the corresponding value from params.
// Only string values within state.Params are substituted.
// values of state.Params and state.Model with the corresponding value from params.
func applyParamSubstitution(state *State, params map[string]any) {
if len(state.Params) == 0 || len(params) == 0 {
if len(params) == 0 {
return
}
// Substitute in state.Model (e.g. "{{model}}" → "haiku").
if state.Model != "" {
state.Model = substituteParams(state.Model, params)
}
for k, v := range state.Params {
if s, ok := v.(string); ok {
// When the entire value is a single {{name}} placeholder, replace it
Expand Down
69 changes: 69 additions & 0 deletions internal/workflow/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,75 @@ func TestExpandTemplates_BuiltinLinearAwaitState(t *testing.T) {
}
}

// TestExpandTemplates_BuiltinModelParam verifies the "model" parameter of the
// builtin templates: with no override the expanded AI states carry an empty
// Model, and with params.model set the value appears on each listed AI state.
func TestExpandTemplates_BuiltinModelParam(t *testing.T) {
	cases := []struct {
		name     string
		builtin  string
		aiStates []string
	}{
		{"plan", "builtin:plan", []string{"_t_start_planning"}},
		{"code", "builtin:code", []string{"_t_start_coding"}},
		{"document", "builtin:document", []string{"_t_start_documenting"}},
		{"ci", "builtin:ci", []string{"_t_start_fix_ci", "_t_start_resolve_conflicts"}},
		{"review", "builtin:review", []string{"_t_start_address_review"}},
	}

	// expandAndCheck builds a workflow with a single template state named
	// "start", expands it, and passes the Model of every listed AI state to
	// check. A nil params map leaves the template state's Params unset.
	expandAndCheck := func(t *testing.T, builtin string, params map[string]any, aiStates []string, check func(t *testing.T, stateName, model string)) {
		cfg := minimalCfg(map[string]*State{
			"start": {
				Type:   StateTypeTemplate,
				Use:    builtin,
				Params: params,
				Exits:  map[string]string{"success": "done", "failure": "failed"},
			},
		})
		// The review template is additionally wired with a ci_regression exit.
		if builtin == "builtin:review" {
			cfg.States["start"].Exits["ci_regression"] = "failed"
		}
		cfg.Start = "start"

		expanded, err := ExpandTemplates(cfg, t.TempDir())
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		for _, stateName := range aiStates {
			st := expanded.States[stateName]
			if st == nil {
				t.Fatalf("state %q missing", stateName)
			}
			check(t, stateName, st.Model)
		}
	}

	for _, tc := range cases {
		t.Run(tc.name+"_default_empty", func(t *testing.T) {
			expandAndCheck(t, tc.builtin, nil, tc.aiStates, func(t *testing.T, stateName, model string) {
				if model != "" {
					t.Errorf("%s: model should be empty by default, got %q", stateName, model)
				}
			})
		})
		t.Run(tc.name+"_model_override", func(t *testing.T) {
			expandAndCheck(t, tc.builtin, map[string]any{"model": "haiku"}, tc.aiStates, func(t *testing.T, stateName, model string) {
				if model != "haiku" {
					t.Errorf("%s: model = %q, want %q", stateName, model, "haiku")
				}
			})
		})
	}
}

// TestExpandTemplates_ModularComposition verifies that multiple modular templates
// can be composed into a complete workflow (the primary use case).
func TestExpandTemplates_ModularComposition(t *testing.T) {
Expand Down
Loading