From e6a99034b9a0365b18c319fb7380124b7f30f7a6 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Fri, 13 Mar 2026 21:20:46 -0500 Subject: [PATCH 01/79] feat: incremental loading --- cmd/wire/check_cmd.go | 7 +- cmd/wire/diff_cmd.go | 9 +- cmd/wire/gen_cmd.go | 11 +- cmd/wire/incremental_flag.go | 60 ++ cmd/wire/main.go | 42 +- cmd/wire/show_cmd.go | 7 +- cmd/wire/watch_cmd.go | 11 +- go.mod | 2 +- internal/wire/cache_bypass.go | 17 + internal/wire/cache_test.go | 6 +- internal/wire/generate_package.go | 2 +- internal/wire/incremental.go | 65 ++ internal/wire/incremental_bench_test.go | 654 +++++++++++++ internal/wire/incremental_fingerprint.go | 421 ++++++++ internal/wire/incremental_fingerprint_test.go | 104 ++ internal/wire/incremental_graph.go | 306 ++++++ internal/wire/incremental_graph_test.go | 97 ++ internal/wire/incremental_manifest.go | 876 +++++++++++++++++ internal/wire/incremental_session.go | 95 ++ internal/wire/incremental_summary.go | 647 ++++++++++++ internal/wire/incremental_summary_test.go | 287 ++++++ internal/wire/incremental_test.go | 65 ++ internal/wire/load_debug.go | 304 ++++++ internal/wire/loader_test.go | 920 ++++++++++++++++++ internal/wire/local_fastpath.go | 556 +++++++++++ internal/wire/parse.go | 26 +- internal/wire/parser_lazy_loader.go | 64 +- internal/wire/parser_lazy_loader_test.go | 55 +- internal/wire/time_compat.go | 22 + internal/wire/timing.go | 8 + internal/wire/wire.go | 89 +- internal/wire/wire_test.go | 50 + 32 files changed, 5843 insertions(+), 42 deletions(-) create mode 100644 cmd/wire/incremental_flag.go create mode 100644 internal/wire/cache_bypass.go create mode 100644 internal/wire/incremental.go create mode 100644 internal/wire/incremental_bench_test.go create mode 100644 internal/wire/incremental_fingerprint.go create mode 100644 internal/wire/incremental_fingerprint_test.go create mode 100644 internal/wire/incremental_graph.go create mode 100644 internal/wire/incremental_graph_test.go create mode 100644 internal/wire/incremental_manifest.go create mode 100644 internal/wire/incremental_session.go create mode 100644 internal/wire/incremental_summary.go create mode 100644 internal/wire/incremental_summary_test.go create mode 100644 internal/wire/incremental_test.go create mode 100644 internal/wire/load_debug.go create mode 100644 internal/wire/local_fastpath.go create mode 100644 internal/wire/time_compat.go diff --git a/cmd/wire/check_cmd.go b/cmd/wire/check_cmd.go index 7857437..71872d9 100644 --- a/cmd/wire/check_cmd.go +++ b/cmd/wire/check_cmd.go @@ -26,8 +26,9 @@ import ( ) type checkCmd struct { - tags string - profile profileFlags + tags string + incremental optionalBoolFlag + profile profileFlags } // Name returns the subcommand name. @@ -52,6 +53,7 @@ func (*checkCmd) Usage() string { // SetFlags registers flags for the subcommand. func (cmd *checkCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -65,6 +67,7 @@ func (cmd *checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/diff_cmd.go b/cmd/wire/diff_cmd.go index 592cced..c7facca 100644 --- a/cmd/wire/diff_cmd.go +++ b/cmd/wire/diff_cmd.go @@ -29,9 +29,10 @@ import ( ) type diffCmd struct { - headerFile string - tags string - profile profileFlags + headerFile string + tags string + incremental optionalBoolFlag + profile profileFlags } // Name returns the subcommand name. @@ -60,6 +61,7 @@ func (*diffCmd) Usage() string { func (cmd *diffCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -77,6 +79,7 @@ func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index 1532dd4..13b88ed 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -29,6 +29,7 @@ type genCmd struct { headerFile string prefixFileName string tags string + incremental optionalBoolFlag profile profileFlags } @@ -55,6 +56,7 @@ func (cmd *genCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -68,6 +70,7 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { @@ -107,8 +110,12 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa // No Wire output. Maybe errors, maybe no Wire directives. continue } - if err := out.Commit(); err == nil { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) success = false diff --git a/cmd/wire/incremental_flag.go b/cmd/wire/incremental_flag.go new file mode 100644 index 0000000..2962128 --- /dev/null +++ b/cmd/wire/incremental_flag.go @@ -0,0 +1,60 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "flag" + "strconv" + + "github.com/goforj/wire/internal/wire" +) + +type optionalBoolFlag struct { + value bool + set bool +} + +func (f *optionalBoolFlag) String() string { + if f == nil { + return "" + } + return strconv.FormatBool(f.value) +} + +func (f *optionalBoolFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + f.value = v + f.set = true + return nil +} + +func (f *optionalBoolFlag) IsBoolFlag() bool { + return true +} + +func (f *optionalBoolFlag) apply(ctx context.Context) context.Context { + if f == nil || !f.set { + return ctx + } + return wire.WithIncremental(ctx, f.value) +} + +func addIncrementalFlag(f *optionalBoolFlag, fs *flag.FlagSet) { + fs.Var(f, "incremental", "enable the incremental engine (overrides "+wire.IncrementalEnvVar+")") +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 2f90783..3166531 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -34,6 +34,13 @@ import ( "github.com/google/subcommands" ) +var topLevelIncremental optionalBoolFlag + +const ( + ansiRed = "\033[31m" + ansiReset = "\033[0m" +) + // main wires up subcommands and executes the selected command. func main() { subcommands.Register(subcommands.CommandsCommand(), "") @@ -45,6 +52,7 @@ func main() { subcommands.Register(&genCmd{}, "") subcommands.Register(&watchCmd{}, "") subcommands.Register(&showCmd{}, "") + addIncrementalFlag(&topLevelIncremental, flag.CommandLine) flag.Parse() // Initialize the default logger to log to stderr. @@ -71,9 +79,9 @@ func main() { // Default to running the "gen" command. if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] { genCmd := &genCmd{} - os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine))) + os.Exit(int(genCmd.Execute(topLevelIncremental.apply(context.Background()), flag.CommandLine))) } - os.Exit(int(subcommands.Execute(context.Background()))) + os.Exit(int(subcommands.Execute(topLevelIncremental.apply(context.Background())))) } // installStackDumper registers signal handlers to dump goroutine stacks. @@ -200,6 +208,34 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { // logErrors logs each error with consistent formatting. func logErrors(errs []error) { for _, err := range errs { - log.Println(strings.Replace(err.Error(), "\n", "\n\t", -1)) + msg := err.Error() + if strings.Contains(msg, "\n") { + logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) + continue + } + logMultilineError(msg) + } +} + +func logMultilineError(msg string) { + if shouldColorStderr() { + log.Print(ansiRed + msg + ansiReset) + return + } + log.Print(msg) +} + +func shouldColorStderr() bool { + if os.Getenv("NO_COLOR") != "" { + return false + } + term := os.Getenv("TERM") + if term == "" || term == "dumb" { + return false + } + info, err := os.Stderr.Stat() + if err != nil { + return false } + return (info.Mode() & os.ModeCharDevice) != 0 } diff --git a/cmd/wire/show_cmd.go b/cmd/wire/show_cmd.go index 5a81b29..1313ade 100644 --- a/cmd/wire/show_cmd.go +++ b/cmd/wire/show_cmd.go @@ -34,8 +34,9 @@ import ( ) type showCmd struct { - tags string - profile profileFlags + tags string + incremental optionalBoolFlag + profile profileFlags } // Name returns the subcommand name. @@ -62,6 +63,7 @@ func (*showCmd) Usage() string { // SetFlags registers flags for the subcommand. func (cmd *showCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -75,6 +77,7 @@ func (cmd *showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index 779625f..13743cd 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -36,6 +36,7 @@ type watchCmd struct { headerFile string prefixFileName string tags string + incremental optionalBoolFlag profile profileFlags pollInterval time.Duration rescanInterval time.Duration @@ -63,6 +64,7 @@ func (cmd *watchCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) f.DurationVar(&cmd.pollInterval, "poll_interval", 250*time.Millisecond, "interval between file stat checks") f.DurationVar(&cmd.rescanInterval, "rescan_interval", 2*time.Second, "interval to rescan for new or removed Go files") cmd.profile.addFlags(f) @@ -77,6 +79,7 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } defer stop() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) if cmd.pollInterval <= 0 { log.Println("poll_interval must be greater than zero") @@ -126,8 +129,12 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter if len(out.Content) == 0 { continue } - if err := out.Commit(); err == nil { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) success = false diff --git a/go.mod b/go.mod index e800555..5db8855 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/goforj/wire go 1.19 require ( + github.com/fsnotify/fsnotify v1.7.0 github.com/google/go-cmp v0.6.0 github.com/google/subcommands v1.2.0 github.com/pmezard/go-difflib v1.0.0 @@ -10,7 +11,6 @@ require ( ) require ( - github.com/fsnotify/fsnotify v1.7.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.23.0 // indirect diff --git a/internal/wire/cache_bypass.go b/internal/wire/cache_bypass.go new file mode 100644 index 0000000..b195eef --- /dev/null +++ b/internal/wire/cache_bypass.go @@ -0,0 +1,17 @@ +package wire + +import "context" + +type bypassPackageCacheKey struct{} + +func withBypassPackageCache(ctx context.Context) context.Context { + return context.WithValue(ctx, bypassPackageCacheKey{}, true) +} + +func bypassPackageCache(ctx context.Context) bool { + if ctx == nil { + return false + } + v, _ := ctx.Value(bypassPackageCacheKey{}).(bool) + return v +} diff --git a/internal/wire/cache_test.go b/internal/wire/cache_test.go index bc55bae..6ffb20a 100644 --- a/internal/wire/cache_test.go +++ b/internal/wire/cache_test.go @@ -123,8 +123,10 @@ func TestCacheInvalidation(t *testing.T) { if key2 == key { t.Fatal("expected cache key to change after source update") } - if cached, ok := readCache(key2); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after second Generate") + if !IncrementalEnabled(ctx, env) { + if cached, ok := readCache(key2); !ok || len(cached) == 0 { + t.Fatal("expected cache entry after second Generate") + } } } diff --git a/internal/wire/generate_package.go b/internal/wire/generate_package.go index de34aa6..01d3d20 100644 --- a/internal/wire/generate_package.go +++ b/internal/wire/generate_package.go @@ -47,7 +47,7 @@ func generateForPackage(ctx context.Context, pkg *packages.Package, loader *lazy res.Errs = append(res.Errs, err) return res } - if cacheKey != "" { + if cacheKey != "" && !bypassPackageCache(ctx) { cacheHitStart := time.Now() if cached, ok := readCache(cacheKey); ok { res.Content = cached diff --git a/internal/wire/incremental.go b/internal/wire/incremental.go new file mode 100644 index 0000000..0bc334c --- /dev/null +++ b/internal/wire/incremental.go @@ -0,0 +1,65 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "strconv" + "strings" +) + +const IncrementalEnvVar = "WIRE_INCREMENTAL" + +type incrementalKey struct{} + +// WithIncremental overrides incremental-mode resolution for the provided +// context. This takes precedence over the environment variable. +func WithIncremental(ctx context.Context, enabled bool) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, incrementalKey{}, enabled) +} + +// IncrementalEnabled reports whether incremental mode is enabled for the +// current operation. A context override takes precedence over env. +func IncrementalEnabled(ctx context.Context, env []string) bool { + if ctx != nil { + if v := ctx.Value(incrementalKey{}); v != nil { + if enabled, ok := v.(bool); ok { + return enabled + } + } + } + raw, ok := lookupEnv(env, IncrementalEnvVar) + if !ok { + return false + } + enabled, err := strconv.ParseBool(strings.TrimSpace(raw)) + if err != nil { + return false + } + return enabled +} + +func lookupEnv(env []string, key string) (string, bool) { + prefix := key + "=" + for i := len(env) - 1; i >= 0; i-- { + if strings.HasPrefix(env[i], prefix) { + return strings.TrimPrefix(env[i], prefix), true + } + } + return "", false +} diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go new file mode 100644 index 0000000..b981d23 --- /dev/null +++ b/internal/wire/incremental_bench_test.go @@ -0,0 +1,654 @@ +package wire + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "text/tabwriter" + "testing" + "time" +) + +const ( + largeBenchmarkTestPackageCount = 24 + largeBenchmarkHelperCount = 12 +) + +var largeBenchmarkSizes = []int{10, 100, 1000} + +func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { + cacheHooksMu.Lock() + state := saveCacheHooks() + b.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(b) + + for i := 0; i < b.N; i++ { + cacheRoot := b.TempDir() + osTempDir = func() string { return cacheRoot } + + root := b.TempDir() + writeIncrementalBenchmarkModule(b, repoRoot, root) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("baseline Generate returned errors: %v", errs) + } + + writeBenchmarkFile(b, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeBenchmarkFile(b, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + b.StartTimer() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + b.StopTimer() + if len(errs) > 0 { + b.Fatalf("incremental shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + b.Fatalf("unexpected Generate results: %+v", gens) + } + } +} + +func BenchmarkGenerateLargeRepoNormalShapeChange(b *testing.B) { + runLargeRepoShapeChangeBenchmarks(b, false) +} + +func BenchmarkGenerateLargeRepoIncrementalShapeChange(b *testing.B) { + runLargeRepoShapeChangeBenchmarks(b, true) +} + +func TestPrintLargeRepoBenchmarkComparisonTable(t *testing.T) { + if os.Getenv("WIRE_BENCH_TABLE") == "" { + t.Skip("set WIRE_BENCH_TABLE=1 to print the large-repo benchmark comparison table") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + rows := make([]largeRepoBenchmarkRow, 0, len(largeBenchmarkSizes)) + for _, packageCount := range largeBenchmarkSizes { + coldNormal := measureLargeRepoColdOnce(t, repoRoot, packageCount, false) + coldIncremental := measureLargeRepoColdOnce(t, repoRoot, packageCount, true) + normal := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, false) + incremental := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, true) + knownToggle := measureLargeRepoKnownToggleOnce(t, repoRoot, packageCount) + rows = append(rows, largeRepoBenchmarkRow{ + packageCount: packageCount, + coldNormal: coldNormal, + coldIncremental: coldIncremental, + normal: normal, + incremental: incremental, + knownToggle: knownToggle, + }) + } + + var out strings.Builder + tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "size\tcold normal\tcold incr\tcold delta\tcold x\tshape normal\tshape incr\tshape delta\tshape x\tknown toggle") + for _, row := range rows { + fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%.2fx\t%s\t%s\t%s\t%.2fx\t%s\n", + row.packageCount, + formatBenchmarkDuration(row.coldNormal), + formatBenchmarkDuration(row.coldIncremental), + formatPercentImprovement(row.coldNormal, row.coldIncremental), + speedupRatio(row.coldNormal, row.coldIncremental), + formatBenchmarkDuration(row.normal), + formatBenchmarkDuration(row.incremental), + formatPercentImprovement(row.normal, row.incremental), + speedupRatio(row.normal, row.incremental), + formatBenchmarkDuration(row.knownToggle), + ) + } + if err := tw.Flush(); err != nil { + t.Fatalf("flush benchmark table: %v", err) + } + fmt.Print(out.String()) +} + +func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { + if os.Getenv("WIRE_BENCH_BREAKDOWN") == "" { + t.Skip("set WIRE_BENCH_BREAKDOWN=1 to print the large-repo shape-change breakdown table") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + var out strings.Builder + tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "size\tnormal total\tbase load\tlazy load\tincr total\tfast load\tfast generate\tspeedup") + for _, packageCount := range largeBenchmarkSizes { + normal := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, false) + incremental := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, true) + fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%s\t%s\t%s\t%.2fx\n", + packageCount, + formatBenchmarkDuration(normal.total), + formatBenchmarkDuration(normal.label("load.packages.base.load")), + formatBenchmarkDuration(normal.label("load.packages.lazy.load")), + formatBenchmarkDuration(incremental.total), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), + speedupRatio(normal.total, incremental.total), + ) + } + if err := tw.Flush(); err != nil { + t.Fatalf("flush breakdown table: %v", err) + } + fmt.Print(out.String()) +} + +func writeIncrementalBenchmarkModule(tb testing.TB, repoRoot string, root string) { + tb.Helper() + + writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) +} + +func TestGenerateIncrementalLargeRepoShapeChangeMatchesNormalGenerate(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := benchmarkRepoRoot(t) + root := t.TempDir() + writeLargeBenchmarkModule(t, repoRoot, root, largeBenchmarkTestPackageCount) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(t, root, largeBenchmarkTestPackageCount/2) + + var incrementalLabels []string + incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + incrementalLabels = append(incrementalLabels, label) + }) + incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental large-repo Generate returned errors: %v", errs) + } + if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected incremental results: %+v", incrementalGens) + } + if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected large-repo shape change to use local fast path, labels=%v", incrementalLabels) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal large-repo Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { + t.Fatalf("unexpected normal results: %+v", normalGens) + } + if incrementalGens[0].OutputPath != normalGens[0].OutputPath { + t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) + } + if string(incrementalGens[0].Content) != string(normalGens[0].Content) { + t.Fatal("large-repo shape-changing incremental output differs from normal Generate output") + } +} + +func runLargeRepoShapeChangeBenchmarks(b *testing.B, incremental bool) { + cacheHooksMu.Lock() + state := saveCacheHooks() + b.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(b) + for _, packageCount := range largeBenchmarkSizes { + packageCount := packageCount + b.Run(fmt.Sprintf("size=%d", packageCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StartTimer() + _ = measureLargeRepoShapeChangeOnce(b, repoRoot, packageCount, incremental) + b.StopTimer() + } + }) + } +} + +type largeRepoBenchmarkRow struct { + packageCount int + coldNormal time.Duration + coldIncremental time.Duration + normal time.Duration + incremental time.Duration + knownToggle time.Duration +} + +type shapeChangeTrace struct { + total time.Duration + labels map[string]time.Duration +} + +func measureLargeRepoShapeChangeOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := context.Background() + if incremental { + ctx = WithIncremental(ctx, true) + } + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(tb, root, packageCount/2) + + start := time.Now() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + dur := time.Since(start) + if len(errs) > 0 { + tb.Fatalf("shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected Generate results: %+v", gens) + } + return dur +} + +func measureLargeRepoShapeChangeTraceOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) shapeChangeTrace { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := context.Background() + if incremental { + ctx = WithIncremental(ctx, true) + } + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(tb, root, packageCount/2) + + trace := shapeChangeTrace{labels: make(map[string]time.Duration)} + ctx = WithTiming(ctx, func(label string, dur time.Duration) { + trace.labels[label] += dur + }) + start := time.Now() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + trace.total = time.Since(start) + if len(errs) > 0 { + tb.Fatalf("shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected Generate results: %+v", gens) + } + return trace +} + +func (s shapeChangeTrace) label(name string) time.Duration { + if s.labels == nil { + return 0 + } + return s.labels[name] +} + +func measureLargeRepoColdOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := context.Background() + if incremental { + ctx = WithIncremental(ctx, true) + } + + start := time.Now() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + dur := time.Since(start) + if len(errs) > 0 { + tb.Fatalf("cold Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected cold Generate results: %+v", gens) + } + return dur +} + +func measureLargeRepoKnownToggleOnce(tb testing.TB, repoRoot string, packageCount int) time.Duration { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + mutatedIndex := packageCount / 2 + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(tb, root, mutatedIndex) + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + tb.Fatalf("mutated Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected mutated Generate results: %+v", gens) + } + + writeLargeBenchmarkPackage(tb, root, mutatedIndex, false) + + start := time.Now() + gens, errs = Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + dur := time.Since(start) + if len(errs) > 0 { + tb.Fatalf("toggle Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected toggle Generate results: %+v", gens) + } + return dur +} + +func formatPercentImprovement(normal time.Duration, incremental time.Duration) string { + if normal <= 0 { + return "0.0%" + } + improvement := 100 * (float64(normal-incremental) / float64(normal)) + return fmt.Sprintf("%.1f%%", improvement) +} + +func speedupRatio(normal time.Duration, incremental time.Duration) float64 { + if incremental <= 0 { + return 0 + } + return float64(normal) / float64(incremental) +} + +func formatBenchmarkDuration(d time.Duration) string { + switch { + case d >= time.Second: + return fmt.Sprintf("%.2fs", d.Seconds()) + case d >= time.Millisecond: + return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) + case d >= time.Microsecond: + return fmt.Sprintf("%.2fµs", float64(d)/float64(time.Microsecond)) + default: + return d.String() + } +} + +func writeLargeBenchmarkModule(tb testing.TB, repoRoot string, root string, packageCount int) { + tb.Helper() + + writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + wireImports := []string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"github.com/goforj/wire\"", + } + appImports := []string{ + "package app", + "", + "import (", + } + buildArgs := []string{"\twire.Build("} + argNames := make([]string, 0, packageCount) + for i := 0; i < packageCount; i++ { + pkgName := fmt.Sprintf("layer%02d", i) + wireImports = append(wireImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) + appImports = append(appImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) + buildArgs = append(buildArgs, fmt.Sprintf("\t\t%s.NewSet,", pkgName)) + argNames = append(argNames, fmt.Sprintf("dep%02d *%s.Token", i, pkgName)) + } + wireImports = append(wireImports, ")", "") + appImports = append(appImports, ")", "") + wireFile := append([]string{}, wireImports...) + wireFile = append(wireFile, "func Init() *App {") + wireFile = append(wireFile, buildArgs...) + wireFile = append(wireFile, "\t\tNewApp,", "\t)", "\treturn nil", "}", "") + writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join(wireFile, "\n")) + + appGo := append(appImports[:len(appImports)-2], // reuse imports without trailing blank line + ")", + "", + "type App struct {", + "\tCount int", + "}", + "", + fmt.Sprintf("func NewApp(%s) *App {", strings.Join(argNames, ", ")), + fmt.Sprintf("\treturn &App{Count: %d}", packageCount), + "}", + "", + ) + writeBenchmarkFile(tb, filepath.Join(root, "app", "app.go"), strings.Join(appGo, "\n")) + + for i := 0; i < packageCount; i++ { + writeLargeBenchmarkPackage(tb, root, i, false) + } +} + +func mutateLargeBenchmarkModule(tb testing.TB, root string, mutatedIndex int) { + tb.Helper() + writeLargeBenchmarkPackage(tb, root, mutatedIndex, true) +} + +func writeLargeBenchmarkPackage(tb testing.TB, root string, index int, mutated bool) { + tb.Helper() + + pkgName := fmt.Sprintf("layer%02d", index) + pkgDir := filepath.Join(root, pkgName) + + writeBenchmarkFile(tb, filepath.Join(pkgDir, "helpers.go"), renderLargeBenchmarkHelpers(pkgName, index, mutated)) + writeBenchmarkFile(tb, filepath.Join(pkgDir, "wire.go"), renderLargeBenchmarkWire(pkgName, mutated)) +} + +func renderLargeBenchmarkHelpers(pkgName string, index int, mutated bool) string { + lines := []string{ + "package " + pkgName, + "", + "import (", + "\t\"fmt\"", + "\t\"strconv\"", + "\t\"strings\"", + ")", + "", + "type Config struct {", + "\tLabel string", + "}", + "", + "type Weight int", + "", + "type Token struct {", + "\tConfig Config", + "\tWeight Weight", + "}", + "", + fmt.Sprintf("func NewConfig() Config { return Config{Label: %q} }", pkgName), + "", + } + if mutated { + lines = append(lines, + fmt.Sprintf("func NewWeight() Weight { return Weight(%d) }", index+100), + "", + "func New(cfg Config, weight Weight) *Token {", + fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), + "\treturn &Token{Config: cfg, Weight: weight}", + "}", + "", + ) + } else { + lines = append(lines, + "func New(cfg Config) *Token {", + fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), + "\treturn &Token{Config: cfg}", + "}", + "", + ) + } + for i := 0; i < largeBenchmarkHelperCount; i++ { + lines = append(lines, fmt.Sprintf("func helper%02d() string {", i)) + lines = append(lines, fmt.Sprintf("\treturn strings.ToUpper(fmt.Sprintf(\"%%s-%%d\", %q, %d)) + strconv.Itoa(%d)", pkgName, i, index+i)) + lines = append(lines, "}", "") + } + return strings.Join(lines, "\n") +} + +func renderLargeBenchmarkWire(pkgName string, mutated bool) string { + lines := []string{ + "package " + pkgName, + "", + "import (", + "\t\"github.com/goforj/wire\"", + ")", + "", + } + if mutated { + lines = append(lines, "var NewSet = wire.NewSet(NewConfig, NewWeight, New)", "") + } else { + lines = append(lines, "var NewSet = wire.NewSet(NewConfig, New)", "") + } + return strings.Join(lines, "\n") +} + +func strconvQuote(s string) string { + return fmt.Sprintf("%q", s) +} + +func benchmarkRepoRoot(tb testing.TB) string { + tb.Helper() + wd, err := os.Getwd() + if err != nil { + tb.Fatalf("Getwd failed: %v", err) + } + repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) + if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { + tb.Fatalf("repo root not found at %s: %v", repoRoot, err) + } + return repoRoot +} + +func writeBenchmarkFile(tb testing.TB, path string, content string) { + tb.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + tb.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + tb.Fatalf("WriteFile failed: %v", err) + } +} diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go new file mode 100644 index 0000000..886d07f --- /dev/null +++ b/internal/wire/incremental_fingerprint.go @@ -0,0 +1,421 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +const incrementalFingerprintVersion = "wire-incremental-v1" + +type packageFingerprint struct { + Version string + WD string + Tags string + PkgPath string + Files []cacheFile + ShapeHash string + LocalImports []string +} + +type fingerprintStats struct { + localPackages int + metaHits int + metaMisses int + unchanged int + changed int +} + +type incrementalFingerprintSnapshot struct { + stats fingerprintStats + changed []string + fingerprints map[string]*packageFingerprint +} + +func analyzeIncrementalFingerprints(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { + if !IncrementalEnabled(ctx, env) { + return nil + } + start := timeNow() + snapshot := collectIncrementalFingerprints(wd, tags, pkgs) + debugf(ctx, "incremental.fingerprint local_pkgs=%d meta_hits=%d meta_misses=%d unchanged=%d changed=%d total=%s", + snapshot.stats.localPackages, + snapshot.stats.metaHits, + snapshot.stats.metaMisses, + snapshot.stats.unchanged, + snapshot.stats.changed, + timeSince(start), + ) + if len(snapshot.changed) > 0 { + debugf(ctx, "incremental.fingerprint changed_pkgs=%s", strings.Join(snapshot.changed, ", ")) + } + return snapshot +} + +func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { + all := collectAllPackages(pkgs) + moduleRoot := findModuleRoot(wd) + snapshot := &incrementalFingerprintSnapshot{ + fingerprints: make(map[string]*packageFingerprint), + } + for _, pkg := range all { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + snapshot.stats.localPackages++ + files := packageFingerprintFiles(pkg) + if len(files) == 0 { + continue + } + sort.Strings(files) + metaFiles, err := buildCacheFiles(files) + if err != nil { + snapshot.stats.metaMisses++ + continue + } + key := incrementalFingerprintKey(wd, tags, pkg.PkgPath) + if prev, ok := readIncrementalFingerprint(key); ok && incrementalFingerprintMetaMatches(prev, wd, tags, pkg.PkgPath, metaFiles) { + snapshot.stats.metaHits++ + snapshot.stats.unchanged++ + snapshot.fingerprints[pkg.PkgPath] = prev + continue + } + snapshot.stats.metaMisses++ + fp, err := buildPackageFingerprint(wd, tags, pkg, metaFiles) + if err != nil { + continue + } + prev, hadPrev := readIncrementalFingerprint(key) + writeIncrementalFingerprint(key, fp) + snapshot.fingerprints[pkg.PkgPath] = fp + if hadPrev && incrementalFingerprintEquivalent(prev, fp) { + snapshot.stats.unchanged++ + continue + } + snapshot.stats.changed++ + snapshot.changed = append(snapshot.changed, pkg.PkgPath) + } + sort.Strings(snapshot.changed) + return snapshot +} + +func packageFingerprintFiles(pkg *packages.Package) []string { + if pkg == nil { + return nil + } + if len(pkg.CompiledGoFiles) > 0 { + return append([]string(nil), pkg.CompiledGoFiles...) + } + return append([]string(nil), pkg.GoFiles...) +} + +func incrementalFingerprintEquivalent(a, b *packageFingerprint) bool { + if a == nil || b == nil { + return false + } + if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || filepath.Clean(a.WD) != filepath.Clean(b.WD) { + return false + } + if len(a.LocalImports) != len(b.LocalImports) { + return false + } + for i := range a.LocalImports { + if a.LocalImports[i] != b.LocalImports[i] { + return false + } + } + return true +} + +func incrementalFingerprintMetaMatches(prev *packageFingerprint, wd string, tags string, pkgPath string, files []cacheFile) bool { + if prev == nil || prev.Version != incrementalFingerprintVersion { + return false + } + if filepath.Clean(prev.WD) != filepath.Clean(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { + return false + } + if len(prev.Files) != len(files) { + return false + } + for i := range prev.Files { + if prev.Files[i] != files[i] { + return false + } + } + return true +} + +func buildPackageFingerprint(wd string, tags string, pkg *packages.Package, files []cacheFile) (*packageFingerprint, error) { + shapeHash, err := packageShapeHash(packageFingerprintFiles(pkg)) + if err != nil { + return nil, err + } + localImports := make([]string, 0, len(pkg.Imports)) + moduleRoot := findModuleRoot(wd) + for _, imp := range pkg.Imports { + if classifyPackageLocation(moduleRoot, imp) == "local" { + localImports = append(localImports, imp.PkgPath) + } + } + sort.Strings(localImports) + return &packageFingerprint{ + Version: incrementalFingerprintVersion, + WD: filepath.Clean(wd), + Tags: tags, + PkgPath: pkg.PkgPath, + Files: append([]cacheFile(nil), files...), + ShapeHash: shapeHash, + LocalImports: localImports, + }, nil +} + +func packageShapeHash(files []string) (string, error) { + fset := token.NewFileSet() + var buf bytes.Buffer + for _, name := range files { + file, err := parser.ParseFile(fset, name, nil, parser.SkipObjectResolution) + if err != nil { + return "", err + } + stripFunctionBodies(file) + if err := printer.Fprint(&buf, fset, file); err != nil { + return "", err + } + buf.WriteByte(0) + } + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum[:]), nil +} + +func stripFunctionBodies(file *ast.File) { + if file == nil { + return + } + for _, decl := range file.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + fn.Body = nil + fn.Doc = nil + } + } +} + +func incrementalFingerprintKey(wd string, tags string, pkgPath string) string { + h := sha256.New() + h.Write([]byte(incrementalFingerprintVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(tags)) + h.Write([]byte{0}) + h.Write([]byte(pkgPath)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func incrementalFingerprintPath(key string) string { + return filepath.Join(cacheDir(), key+".ifp") +} + +func readIncrementalFingerprint(key string) (*packageFingerprint, bool) { + data, err := osReadFile(incrementalFingerprintPath(key)) + if err != nil { + return nil, false + } + fp, err := decodeIncrementalFingerprint(data) + if err != nil { + return nil, false + } + return fp, true +} + +func writeIncrementalFingerprint(key string, fp *packageFingerprint) { + data, err := encodeIncrementalFingerprint(fp) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".ifp-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalFingerprintPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func encodeIncrementalFingerprint(fp *packageFingerprint) ([]byte, error) { + var buf bytes.Buffer + writeString := func(s string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { + return err + } + _, err := buf.WriteString(s) + return err + } + writeCacheFiles := func(files []cacheFile) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { + return err + } + for _, f := range files { + if err := writeString(f.Path); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { + return err + } + } + return nil + } + writeStrings := func(items []string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(items))); err != nil { + return err + } + for _, item := range items { + if err := writeString(item); err != nil { + return err + } + } + return nil + } + if fp == nil { + return nil, fmt.Errorf("nil fingerprint") + } + for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { + if err := writeString(s); err != nil { + return nil, err + } + } + if err := writeCacheFiles(fp.Files); err != nil { + return nil, err + } + if err := writeStrings(fp.LocalImports); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func decodeIncrementalFingerprint(data []byte) (*packageFingerprint, error) { + r := bytes.NewReader(data) + readString := func() (string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return "", err + } + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { + return "", err + } + return string(buf), nil + } + readCacheFiles := func() ([]cacheFile, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]cacheFile, 0, n) + for i := uint32(0); i < n; i++ { + path, err := readString() + if err != nil { + return nil, err + } + var size int64 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return nil, err + } + var modTime int64 + if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { + return nil, err + } + out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) + } + return out, nil + } + readStrings := func() ([]string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]string, 0, n) + for i := uint32(0); i < n; i++ { + item, err := readString() + if err != nil { + return nil, err + } + out = append(out, item) + } + return out, nil + } + version, err := readString() + if err != nil { + return nil, err + } + wd, err := readString() + if err != nil { + return nil, err + } + tags, err := readString() + if err != nil { + return nil, err + } + pkgPath, err := readString() + if err != nil { + return nil, err + } + shapeHash, err := readString() + if err != nil { + return nil, err + } + files, err := readCacheFiles() + if err != nil { + return nil, err + } + localImports, err := readStrings() + if err != nil { + return nil, err + } + return &packageFingerprint{ + Version: version, + WD: wd, + Tags: tags, + PkgPath: pkgPath, + ShapeHash: shapeHash, + Files: files, + LocalImports: localImports, + }, nil +} diff --git a/internal/wire/incremental_fingerprint_test.go b/internal/wire/incremental_fingerprint_test.go new file mode 100644 index 0000000..afe81de --- /dev/null +++ b/internal/wire/incremental_fingerprint_test.go @@ -0,0 +1,104 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "os" + "path/filepath" + "testing" + + "golang.org/x/tools/go/packages" +) + +func TestPackageShapeHashIgnoresFunctionBodies(t *testing.T) { + dir := t.TempDir() + file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") + hash1, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash first failed: %v", err) + } + if err := os.WriteFile(file, []byte("package p\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + hash2, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash second failed: %v", err) + } + if hash1 != hash2 { + t.Fatalf("body-only change should not affect shape hash: %q vs %q", hash1, hash2) + } +} + +func TestIncrementalFingerprintRoundTrip(t *testing.T) { + fp := &packageFingerprint{ + Version: incrementalFingerprintVersion, + WD: "/tmp/app", + Tags: "dev", + PkgPath: "example.com/app", + ShapeHash: "shape", + Files: []cacheFile{{Path: "/tmp/app/pkg.go", Size: 12, ModTime: 34}}, + LocalImports: []string{"example.com/dep"}, + } + data, err := encodeIncrementalFingerprint(fp) + if err != nil { + t.Fatalf("encodeIncrementalFingerprint failed: %v", err) + } + got, err := decodeIncrementalFingerprint(data) + if err != nil { + t.Fatalf("decodeIncrementalFingerprint failed: %v", err) + } + if !incrementalFingerprintEquivalent(fp, got) { + t.Fatalf("fingerprint mismatch after round-trip: got %+v want %+v", got, fp) + } + if len(got.Files) != 1 || got.Files[0] != fp.Files[0] { + t.Fatalf("file metadata mismatch after round-trip: got %+v want %+v", got.Files, fp.Files) + } +} + +func TestCollectIncrementalFingerprintsTreatsBodyOnlyChangeAsUnchanged(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") + file := filepath.Join(root, "app", "app.go") + writeFile(t, file, "package app\n\nfunc Hello() string { return \"a\" }\n") + pkg := &packages.Package{ + PkgPath: "example.com/app", + CompiledGoFiles: []string{file}, + GoFiles: []string{file}, + Imports: map[string]*packages.Package{}, + } + + snapshot := collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) + if snapshot.stats.changed != 1 || len(snapshot.changed) != 1 || snapshot.changed[0] != pkg.PkgPath { + t.Fatalf("first run stats=%+v changed=%v", snapshot.stats, snapshot.changed) + } + + if err := os.WriteFile(file, []byte("package app\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + snapshot = collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) + if snapshot.stats.unchanged != 1 { + t.Fatalf("body-only change should be unchanged by shape, stats=%+v changed=%v", snapshot.stats, snapshot.changed) + } + if len(snapshot.changed) != 0 { + t.Fatalf("body-only change should not report changed packages, got %v", snapshot.changed) + } +} diff --git a/internal/wire/incremental_graph.go b/internal/wire/incremental_graph.go new file mode 100644 index 0000000..66cf28d --- /dev/null +++ b/internal/wire/incremental_graph.go @@ -0,0 +1,306 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +const incrementalGraphVersion = "wire-incremental-graph-v1" + +type incrementalGraph struct { + Version string + WD string + Tags string + Roots []string + LocalReverse map[string][]string +} + +func analyzeIncrementalGraph(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) { + if !IncrementalEnabled(ctx, env) || snapshot == nil { + return + } + graph := buildIncrementalGraph(wd, tags, pkgs) + writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) + if len(snapshot.changed) == 0 { + return + } + affected := affectedRoots(graph, snapshot.changed) + if len(affected) > 0 { + debugf(ctx, "incremental.graph changed=%s affected_roots=%s", stringsJoin(snapshot.changed), stringsJoin(affected)) + } else { + debugf(ctx, "incremental.graph changed=%s affected_roots=", stringsJoin(snapshot.changed)) + } +} + +func buildIncrementalGraph(wd string, tags string, pkgs []*packages.Package) *incrementalGraph { + moduleRoot := findModuleRoot(wd) + graph := &incrementalGraph{ + Version: incrementalGraphVersion, + WD: filepath.Clean(wd), + Tags: tags, + Roots: make([]string, 0, len(pkgs)), + LocalReverse: make(map[string][]string), + } + for _, pkg := range pkgs { + if pkg == nil { + continue + } + graph.Roots = append(graph.Roots, pkg.PkgPath) + } + sort.Strings(graph.Roots) + for _, pkg := range collectAllPackages(pkgs) { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + for _, imp := range pkg.Imports { + if classifyPackageLocation(moduleRoot, imp) != "local" { + continue + } + graph.LocalReverse[imp.PkgPath] = append(graph.LocalReverse[imp.PkgPath], pkg.PkgPath) + } + } + for path := range graph.LocalReverse { + sort.Strings(graph.LocalReverse[path]) + } + return graph +} + +func affectedRoots(graph *incrementalGraph, changed []string) []string { + if graph == nil || len(changed) == 0 { + return nil + } + rootSet := make(map[string]struct{}, len(graph.Roots)) + for _, root := range graph.Roots { + rootSet[root] = struct{}{} + } + seen := make(map[string]struct{}) + queue := append([]string(nil), changed...) + affected := make(map[string]struct{}) + for len(queue) > 0 { + cur := queue[0] + queue = queue[1:] + if _, ok := seen[cur]; ok { + continue + } + seen[cur] = struct{}{} + if _, ok := rootSet[cur]; ok { + affected[cur] = struct{}{} + } + for _, next := range graph.LocalReverse[cur] { + if _, ok := seen[next]; !ok { + queue = append(queue, next) + } + } + } + out := make([]string, 0, len(affected)) + for root := range affected { + out = append(out, root) + } + sort.Strings(out) + return out +} + +func incrementalGraphKey(wd string, tags string, roots []string) string { + h := sha256.New() + h.Write([]byte(incrementalGraphVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(tags)) + h.Write([]byte{0}) + for _, root := range roots { + h.Write([]byte(root)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func incrementalGraphPath(key string) string { + return filepath.Join(cacheDir(), key+".igr") +} + +func writeIncrementalGraph(key string, graph *incrementalGraph) { + data, err := encodeIncrementalGraph(graph) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".igr-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalGraphPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func readIncrementalGraph(key string) (*incrementalGraph, bool) { + data, err := osReadFile(incrementalGraphPath(key)) + if err != nil { + return nil, false + } + graph, err := decodeIncrementalGraph(data) + if err != nil { + return nil, false + } + return graph, true +} + +func encodeIncrementalGraph(graph *incrementalGraph) ([]byte, error) { + if graph == nil { + return nil, fmt.Errorf("nil incremental graph") + } + var buf bytes.Buffer + writeString := func(s string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { + return err + } + _, err := buf.WriteString(s) + return err + } + for _, s := range []string{graph.Version, graph.WD, graph.Tags} { + if err := writeString(s); err != nil { + return nil, err + } + } + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(graph.Roots))); err != nil { + return nil, err + } + for _, root := range graph.Roots { + if err := writeString(root); err != nil { + return nil, err + } + } + keys := make([]string, 0, len(graph.LocalReverse)) + for k := range graph.LocalReverse { + keys = append(keys, k) + } + sort.Strings(keys) + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(keys))); err != nil { + return nil, err + } + for _, k := range keys { + if err := writeString(k); err != nil { + return nil, err + } + children := append([]string(nil), graph.LocalReverse[k]...) + sort.Strings(children) + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(children))); err != nil { + return nil, err + } + for _, child := range children { + if err := writeString(child); err != nil { + return nil, err + } + } + } + return buf.Bytes(), nil +} + +func decodeIncrementalGraph(data []byte) (*incrementalGraph, error) { + r := bytes.NewReader(data) + readString := func() (string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return "", err + } + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { + return "", err + } + return string(buf), nil + } + version, err := readString() + if err != nil { + return nil, err + } + wd, err := readString() + if err != nil { + return nil, err + } + tags, err := readString() + if err != nil { + return nil, err + } + var rootCount uint32 + if err := binary.Read(r, binary.LittleEndian, &rootCount); err != nil { + return nil, err + } + roots := make([]string, 0, rootCount) + for i := uint32(0); i < rootCount; i++ { + root, err := readString() + if err != nil { + return nil, err + } + roots = append(roots, root) + } + var edgeCount uint32 + if err := binary.Read(r, binary.LittleEndian, &edgeCount); err != nil { + return nil, err + } + reverse := make(map[string][]string, edgeCount) + for i := uint32(0); i < edgeCount; i++ { + k, err := readString() + if err != nil { + return nil, err + } + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + children := make([]string, 0, n) + for j := uint32(0); j < n; j++ { + child, err := readString() + if err != nil { + return nil, err + } + children = append(children, child) + } + reverse[k] = children + } + return &incrementalGraph{ + Version: version, + WD: wd, + Tags: tags, + Roots: roots, + LocalReverse: reverse, + }, nil +} + +func stringsJoin(items []string) string { + if len(items) == 0 { + return "" + } + return strings.Join(items, ",") +} diff --git a/internal/wire/incremental_graph_test.go b/internal/wire/incremental_graph_test.go new file mode 100644 index 0000000..8a91b54 --- /dev/null +++ b/internal/wire/incremental_graph_test.go @@ -0,0 +1,97 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "path/filepath" + "reflect" + "testing" + + "golang.org/x/tools/go/packages" +) + +func TestIncrementalGraphRoundTrip(t *testing.T) { + graph := &incrementalGraph{ + Version: incrementalGraphVersion, + WD: "/tmp/app", + Tags: "dev", + Roots: []string{"example.com/app", "example.com/other"}, + LocalReverse: map[string][]string{ + "example.com/dep": {"example.com/app"}, + "example.com/sub": {"example.com/dep", "example.com/other"}, + }, + } + data, err := encodeIncrementalGraph(graph) + if err != nil { + t.Fatalf("encodeIncrementalGraph failed: %v", err) + } + got, err := decodeIncrementalGraph(data) + if err != nil { + t.Fatalf("decodeIncrementalGraph failed: %v", err) + } + if !reflect.DeepEqual(got, graph) { + t.Fatalf("graph round-trip mismatch:\n got=%+v\nwant=%+v", got, graph) + } +} + +func TestAffectedRoots(t *testing.T) { + graph := &incrementalGraph{ + Roots: []string{"example.com/app", "example.com/other"}, + LocalReverse: map[string][]string{ + "example.com/dep": {"example.com/app"}, + "example.com/sub": {"example.com/dep", "example.com/other"}, + }, + } + got := affectedRoots(graph, []string{"example.com/sub"}) + want := []string{"example.com/app", "example.com/other"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("affectedRoots=%v want %v", got, want) + } +} + +func TestBuildIncrementalGraph(t *testing.T) { + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") + + appFile := filepath.Join(root, "app", "app.go") + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, appFile, "package app\n") + writeFile(t, depFile, "package dep\n") + + dep := &packages.Package{ + PkgPath: "example.com/test/dep", + CompiledGoFiles: []string{depFile}, + GoFiles: []string{depFile}, + Imports: map[string]*packages.Package{}, + } + app := &packages.Package{ + PkgPath: "example.com/test/app", + CompiledGoFiles: []string{appFile}, + GoFiles: []string{appFile}, + Imports: map[string]*packages.Package{ + "example.com/test/dep": dep, + }, + } + + graph := buildIncrementalGraph(root, "", []*packages.Package{app}) + if len(graph.Roots) != 1 || graph.Roots[0] != app.PkgPath { + t.Fatalf("unexpected roots: %v", graph.Roots) + } + got := graph.LocalReverse[dep.PkgPath] + want := []string{app.PkgPath} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected reverse edges: got=%v want=%v", got, want) + } +} diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go new file mode 100644 index 0000000..ae36c77 --- /dev/null +++ b/internal/wire/incremental_manifest.go @@ -0,0 +1,876 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +const incrementalManifestVersion = "wire-incremental-manifest-v1" + +type incrementalManifest struct { + Version string + WD string + Tags string + Prefix string + HeaderHash string + EnvHash string + Patterns []string + LocalPackages []packageFingerprint + ExternalPkgs []externalPackageExport + ExternalFiles []cacheFile + ExtraFiles []cacheFile + Outputs []incrementalOutput +} + +type externalPackageExport struct { + PkgPath string + ExportFile string +} + +type incrementalOutput struct { + PkgPath string + OutputPath string + ContentKey string +} + +type incrementalPreloadState struct { + selectorKey string + manifest *incrementalManifest + valid bool + currentLocal []packageFingerprint + reason string +} + +func readPreloadIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { + state, ok := prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) + return readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, state, ok) +} + +func readPreloadIncrementalManifestResultsFromState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, ok bool) ([]GenerateResult, bool) { + if !ok { + debugf(ctx, "incremental.preload_manifest miss reason=no_manifest") + return nil, false + } + if state.valid { + results, ok := incrementalManifestOutputs(state.manifest) + if !ok { + debugf(ctx, "incremental.preload_manifest miss reason=outputs") + return nil, false + } + debugf(ctx, "incremental.preload_manifest hit outputs=%d", len(results)) + return results, true + } else if archived := readStateIncrementalManifest(state.selectorKey, state.currentLocal); archived != nil { + if ok, _, _ := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); ok { + results, ok := incrementalManifestOutputs(archived) + if !ok { + debugf(ctx, "incremental.preload_manifest miss reason=state_outputs") + return nil, false + } + writeIncrementalManifestFile(state.selectorKey, archived) + debugf(ctx, "incremental.preload_manifest state_hit outputs=%d", len(results)) + return results, true + } + debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) + return nil, false + } else { + debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) + return nil, false + } +} + +func prepareIncrementalPreloadState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (*incrementalPreloadState, bool) { + selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) + manifest, ok := readIncrementalManifest(selectorKey) + if !ok { + return nil, false + } + valid, currentLocal, reason := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) + return &incrementalPreloadState{ + selectorKey: selectorKey, + manifest: manifest, + valid: valid, + currentLocal: currentLocal, + reason: reason, + }, true +} + +func readIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) ([]GenerateResult, bool) { + if snapshot == nil || snapshot.stats.changed != 0 { + return nil, false + } + key := incrementalManifestSelectorKey(wd, env, patterns, opts) + manifest, ok := readIncrementalManifest(key) + if !ok || !incrementalManifestValid(manifest, wd, env, patterns, opts, pkgs) { + return nil, false + } + results := make([]GenerateResult, 0, len(manifest.Outputs)) + for _, out := range manifest.Outputs { + content, ok := readCache(out.ContentKey) + if !ok { + return nil, false + } + results = append(results, GenerateResult{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + Content: content, + }) + } + debugf(ctx, "incremental.manifest hit outputs=%d", len(results)) + return results, true +} + +func writeIncrementalManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { + if snapshot == nil || len(generated) == 0 { + return + } + externalPkgs := buildExternalPackageExports(wd, pkgs) + externalFiles, err := buildExternalPackageFiles(wd, pkgs) + if err != nil { + return + } + manifest := &incrementalManifest{ + Version: incrementalManifestVersion, + WD: filepath.Clean(wd), + Tags: opts.Tags, + Prefix: opts.PrefixOutputFile, + HeaderHash: headerHash(opts.Header), + EnvHash: envHash(env), + Patterns: sortedStrings(patterns), + LocalPackages: snapshotPackageFingerprints(snapshot), + ExternalPkgs: externalPkgs, + ExternalFiles: externalFiles, + ExtraFiles: extraCacheFiles(wd), + } + for _, out := range generated { + if len(out.Content) == 0 || out.OutputPath == "" { + continue + } + contentKey := incrementalContentKey(out.Content) + writeCache(contentKey, out.Content) + manifest.Outputs = append(manifest.Outputs, incrementalOutput{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + ContentKey: contentKey, + }) + } + if len(manifest.Outputs) == 0 { + return + } + selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) + stateKey := incrementalManifestStateKey(selectorKey, manifest.LocalPackages) + writeIncrementalManifestFile(selectorKey, manifest) + writeIncrementalManifestFile(stateKey, manifest) +} + +func incrementalManifestSelectorKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { + h := sha256.New() + h.Write([]byte(incrementalManifestVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(envHash(env))) + h.Write([]byte{0}) + h.Write([]byte(opts.Tags)) + h.Write([]byte{0}) + h.Write([]byte(opts.PrefixOutputFile)) + h.Write([]byte{0}) + h.Write([]byte(headerHash(opts.Header))) + h.Write([]byte{0}) + for _, p := range sortedStrings(patterns) { + h.Write([]byte(p)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func snapshotPackageFingerprints(snapshot *incrementalFingerprintSnapshot) []packageFingerprint { + if snapshot == nil || len(snapshot.fingerprints) == 0 { + return nil + } + paths := make([]string, 0, len(snapshot.fingerprints)) + for path := range snapshot.fingerprints { + paths = append(paths, path) + } + sort.Strings(paths) + out := make([]packageFingerprint, 0, len(paths)) + for _, path := range paths { + if fp := snapshot.fingerprints[path]; fp != nil { + out = append(out, *fp) + } + } + return out +} + +func buildExternalPackageFiles(wd string, pkgs []*packages.Package) ([]cacheFile, error) { + moduleRoot := findModuleRoot(wd) + seen := make(map[string]struct{}) + var files []string + for _, pkg := range collectAllPackages(pkgs) { + if classifyPackageLocation(moduleRoot, pkg) == "local" { + continue + } + names := pkg.CompiledGoFiles + if len(names) == 0 { + names = pkg.GoFiles + } + for _, name := range names { + clean := filepath.Clean(name) + if _, ok := seen[clean]; ok { + continue + } + seen[clean] = struct{}{} + files = append(files, clean) + } + } + sort.Strings(files) + return buildCacheFiles(files) +} + +func buildExternalPackageExports(wd string, pkgs []*packages.Package) []externalPackageExport { + moduleRoot := findModuleRoot(wd) + out := make([]externalPackageExport, 0) + for _, pkg := range collectAllPackages(pkgs) { + if classifyPackageLocation(moduleRoot, pkg) == "local" { + continue + } + if pkg == nil || pkg.PkgPath == "" || pkg.ExportFile == "" { + continue + } + out = append(out, externalPackageExport{ + PkgPath: pkg.PkgPath, + ExportFile: pkg.ExportFile, + }) + } + sort.Slice(out, func(i, j int) bool { return out[i].PkgPath < out[j].PkgPath }) + return out +} + +func incrementalManifestValid(manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) bool { + if manifest == nil || manifest.Version != incrementalManifestVersion { + return false + } + if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + return false + } + if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { + return false + } + if len(manifest.Patterns) != len(patterns) { + return false + } + for i, p := range sortedStrings(patterns) { + if manifest.Patterns[i] != p { + return false + } + } + currentExternal, err := buildExternalPackageFiles(wd, pkgs) + if err != nil || len(currentExternal) != len(manifest.ExternalFiles) { + return false + } + for i := range currentExternal { + if currentExternal[i] != manifest.ExternalFiles[i] { + return false + } + } + if len(manifest.ExtraFiles) > 0 { + current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) + if err != nil || len(current) != len(manifest.ExtraFiles) { + return false + } + for i := range current { + if current[i] != manifest.ExtraFiles[i] { + return false + } + } + } + return len(manifest.Outputs) > 0 +} + +func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) (bool, []packageFingerprint, string) { + if manifest == nil || manifest.Version != incrementalManifestVersion { + return false, nil, "version" + } + if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + return false, nil, "config" + } + if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { + return false, nil, "env" + } + if len(manifest.Patterns) != len(patterns) { + return false, nil, "patterns.length" + } + for i, p := range sortedStrings(patterns) { + if manifest.Patterns[i] != p { + return false, nil, "patterns.value" + } + } + if len(manifest.ExtraFiles) > 0 { + current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) + if err != nil || len(current) != len(manifest.ExtraFiles) { + return false, nil, "extra_files" + } + for i := range current { + if current[i] != manifest.ExtraFiles[i] { + return false, nil, "extra_files.diff" + } + } + } + currentLocal, ok, reason := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) + if !ok { + return false, currentLocal, "local_packages." + reason + } + if len(manifest.ExternalFiles) > 0 { + current, err := buildCacheFilesFromMeta(manifest.ExternalFiles) + if err != nil || len(current) != len(manifest.ExternalFiles) { + return false, currentLocal, "external_files" + } + for i := range current { + if current[i] != manifest.ExternalFiles[i] { + return false, currentLocal, "external_files.diff" + } + } + } + if len(manifest.Outputs) == 0 { + return false, currentLocal, "outputs" + } + return true, currentLocal, "" +} + +func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) ([]packageFingerprint, bool, string) { + currentState := make([]packageFingerprint, 0, len(local)) + var firstReason string + for _, fp := range local { + if len(fp.Files) == 0 { + if firstReason == "" { + firstReason = fp.PkgPath + ".files" + } + continue + } + storedFiles := filesFromMeta(fp.Files) + if len(storedFiles) == 0 { + if firstReason == "" { + firstReason = fp.PkgPath + ".stored_files" + } + continue + } + currentMeta, err := buildCacheFiles(storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".meta_error" + } + continue + } + currentFP := fp + currentFP.Files = append([]cacheFile(nil), currentMeta...) + sameMeta := len(currentMeta) == len(fp.Files) + if sameMeta { + for i := range currentMeta { + if currentMeta[i] != fp.Files[i] { + sameMeta = false + break + } + } + } + if !sameMeta { + shapeHash, err := packageShapeHash(storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_error" + } + continue + } + currentFP.ShapeHash = shapeHash + if shapeHash != fp.ShapeHash { + debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_mismatch" + } + } + } + if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".dir_scan_error" + } + continue + } else if changed { + debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) + if firstReason == "" { + firstReason = fp.PkgPath + ".introduced_relevant_files" + } + } + currentState = append(currentState, currentFP) + } + if firstReason != "" { + return currentState, false, firstReason + } + return currentState, true, "" +} + +func incrementalManifestOutputs(manifest *incrementalManifest) ([]GenerateResult, bool) { + results := make([]GenerateResult, 0, len(manifest.Outputs)) + for _, out := range manifest.Outputs { + content, ok := readCache(out.ContentKey) + if !ok { + return nil, false + } + results = append(results, GenerateResult{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + Content: content, + }) + } + return results, true +} + +func readStateIncrementalManifest(selectorKey string, local []packageFingerprint) *incrementalManifest { + if len(local) == 0 { + return nil + } + stateKey := incrementalManifestStateKey(selectorKey, local) + manifest, ok := readIncrementalManifest(stateKey) + if !ok { + return nil + } + return manifest +} + +func incrementalManifestStateKey(selectorKey string, local []packageFingerprint) string { + h := sha256.New() + h.Write([]byte(selectorKey)) + h.Write([]byte{0}) + for _, fp := range snapshotPackageFingerprints(&incrementalFingerprintSnapshot{fingerprints: fingerprintsFromSlice(local)}) { + h.Write([]byte(fp.PkgPath)) + h.Write([]byte{0}) + h.Write([]byte(fp.ShapeHash)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func fingerprintsFromSlice(local []packageFingerprint) map[string]*packageFingerprint { + if len(local) == 0 { + return nil + } + out := make(map[string]*packageFingerprint, len(local)) + for i := range local { + fp := local[i] + out[fp.PkgPath] = &fp + } + return out +} + +func filesFromMeta(files []cacheFile) []string { + out := make([]string, 0, len(files)) + for _, f := range files { + out = append(out, filepath.Clean(f.Path)) + } + sort.Strings(out) + return out +} + +func packageDirectoryIntroducedRelevantFiles(files []cacheFile) (bool, error) { + dirs := make(map[string]struct{}) + old := make(map[string]struct{}, len(files)) + for _, f := range files { + path := filepath.Clean(f.Path) + dirs[filepath.Dir(path)] = struct{}{} + old[path] = struct{}{} + } + for dir := range dirs { + entries, err := os.ReadDir(dir) + if err != nil { + return false, err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".go") { + continue + } + if strings.HasSuffix(name, "_test.go") { + continue + } + if strings.HasSuffix(name, "wire_gen.go") { + continue + } + path := filepath.Clean(filepath.Join(dir, name)) + if _, ok := old[path]; !ok { + return true, nil + } + } + } + return false, nil +} + +func incrementalManifestPath(key string) string { + return filepath.Join(cacheDir(), key+".iman") +} + +func readIncrementalManifest(key string) (*incrementalManifest, bool) { + data, err := osReadFile(incrementalManifestPath(key)) + if err != nil { + return nil, false + } + manifest, err := decodeIncrementalManifest(data) + if err != nil { + return nil, false + } + return manifest, true +} + +func writeIncrementalManifestFile(key string, manifest *incrementalManifest) { + data, err := encodeIncrementalManifest(manifest) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".iman-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalManifestPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func encodeIncrementalManifest(manifest *incrementalManifest) ([]byte, error) { + var buf bytes.Buffer + if manifest == nil { + return nil, fmt.Errorf("nil incremental manifest") + } + writeString := func(s string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { + return err + } + _, err := buf.WriteString(s) + return err + } + writeCacheFiles := func(files []cacheFile) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { + return err + } + for _, f := range files { + if err := writeString(f.Path); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { + return err + } + } + return nil + } + writeExternalPkgs := func(pkgs []externalPackageExport) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(pkgs))); err != nil { + return err + } + for _, pkg := range pkgs { + if err := writeString(pkg.PkgPath); err != nil { + return err + } + if err := writeString(pkg.ExportFile); err != nil { + return err + } + } + return nil + } + writeFingerprints := func(fps []packageFingerprint) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fps))); err != nil { + return err + } + for _, fp := range fps { + for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { + if err := writeString(s); err != nil { + return err + } + } + if err := writeCacheFiles(fp.Files); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fp.LocalImports))); err != nil { + return err + } + for _, imp := range fp.LocalImports { + if err := writeString(imp); err != nil { + return err + } + } + } + return nil + } + writeOutputs := func(outputs []incrementalOutput) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(outputs))); err != nil { + return err + } + for _, out := range outputs { + for _, s := range []string{out.PkgPath, out.OutputPath, out.ContentKey} { + if err := writeString(s); err != nil { + return err + } + } + } + return nil + } + for _, s := range []string{manifest.Version, manifest.WD, manifest.Tags, manifest.Prefix, manifest.HeaderHash, manifest.EnvHash} { + if err := writeString(s); err != nil { + return nil, err + } + } + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(manifest.Patterns))); err != nil { + return nil, err + } + for _, p := range manifest.Patterns { + if err := writeString(p); err != nil { + return nil, err + } + } + if err := writeFingerprints(manifest.LocalPackages); err != nil { + return nil, err + } + if err := writeExternalPkgs(manifest.ExternalPkgs); err != nil { + return nil, err + } + if err := writeCacheFiles(manifest.ExternalFiles); err != nil { + return nil, err + } + if err := writeCacheFiles(manifest.ExtraFiles); err != nil { + return nil, err + } + if err := writeOutputs(manifest.Outputs); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func decodeIncrementalManifest(data []byte) (*incrementalManifest, error) { + r := bytes.NewReader(data) + readString := func() (string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return "", err + } + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { + return "", err + } + return string(buf), nil + } + readCacheFiles := func() ([]cacheFile, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]cacheFile, 0, n) + for i := uint32(0); i < n; i++ { + path, err := readString() + if err != nil { + return nil, err + } + var size int64 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return nil, err + } + var modTime int64 + if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { + return nil, err + } + out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) + } + return out, nil + } + readExternalPkgs := func() ([]externalPackageExport, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]externalPackageExport, 0, n) + for i := uint32(0); i < n; i++ { + pkgPath, err := readString() + if err != nil { + return nil, err + } + exportFile, err := readString() + if err != nil { + return nil, err + } + out = append(out, externalPackageExport{PkgPath: pkgPath, ExportFile: exportFile}) + } + return out, nil + } + readFingerprints := func() ([]packageFingerprint, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]packageFingerprint, 0, n) + for i := uint32(0); i < n; i++ { + version, err := readString() + if err != nil { + return nil, err + } + wd, err := readString() + if err != nil { + return nil, err + } + tags, err := readString() + if err != nil { + return nil, err + } + pkgPath, err := readString() + if err != nil { + return nil, err + } + shapeHash, err := readString() + if err != nil { + return nil, err + } + files, err := readCacheFiles() + if err != nil { + return nil, err + } + var importCount uint32 + if err := binary.Read(r, binary.LittleEndian, &importCount); err != nil { + return nil, err + } + localImports := make([]string, 0, importCount) + for j := uint32(0); j < importCount; j++ { + imp, err := readString() + if err != nil { + return nil, err + } + localImports = append(localImports, imp) + } + out = append(out, packageFingerprint{ + Version: version, + WD: wd, + Tags: tags, + PkgPath: pkgPath, + ShapeHash: shapeHash, + Files: files, + LocalImports: localImports, + }) + } + return out, nil + } + readOutputs := func() ([]incrementalOutput, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]incrementalOutput, 0, n) + for i := uint32(0); i < n; i++ { + pkgPath, err := readString() + if err != nil { + return nil, err + } + outputPath, err := readString() + if err != nil { + return nil, err + } + contentKey, err := readString() + if err != nil { + return nil, err + } + out = append(out, incrementalOutput{PkgPath: pkgPath, OutputPath: outputPath, ContentKey: contentKey}) + } + return out, nil + } + fields := make([]string, 6) + for i := range fields { + s, err := readString() + if err != nil { + return nil, err + } + fields[i] = s + } + var patternCount uint32 + if err := binary.Read(r, binary.LittleEndian, &patternCount); err != nil { + return nil, err + } + patterns := make([]string, 0, patternCount) + for i := uint32(0); i < patternCount; i++ { + p, err := readString() + if err != nil { + return nil, err + } + patterns = append(patterns, p) + } + localPackages, err := readFingerprints() + if err != nil { + return nil, err + } + externalPkgs, err := readExternalPkgs() + if err != nil { + return nil, err + } + externalFiles, err := readCacheFiles() + if err != nil { + return nil, err + } + extraFiles, err := readCacheFiles() + if err != nil { + return nil, err + } + outputs, err := readOutputs() + if err != nil { + return nil, err + } + return &incrementalManifest{ + Version: fields[0], + WD: fields[1], + Tags: fields[2], + Prefix: fields[3], + HeaderHash: fields[4], + EnvHash: fields[5], + Patterns: patterns, + LocalPackages: localPackages, + ExternalPkgs: externalPkgs, + ExternalFiles: externalFiles, + ExtraFiles: extraFiles, + Outputs: outputs, + }, nil +} + +func incrementalContentKey(content []byte) string { + sum := sha256.Sum256(content) + return fmt.Sprintf("%x", sum[:]) +} diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go new file mode 100644 index 0000000..fda6605 --- /dev/null +++ b/internal/wire/incremental_session.go @@ -0,0 +1,95 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "crypto/sha256" + "encoding/hex" + "go/ast" + "go/token" + "path/filepath" + "strings" + "sync" +) + +type incrementalSession struct { + fset *token.FileSet + mu sync.Mutex + parsedDeps map[string]cachedParsedFile +} + +type cachedParsedFile struct { + hash string + file *ast.File +} + +var incrementalSessions sync.Map + +func sessionKey(wd string, env []string, tags string) string { + var b strings.Builder + b.WriteString(filepath.Clean(wd)) + b.WriteByte('\n') + b.WriteString(tags) + b.WriteByte('\n') + for _, entry := range env { + b.WriteString(entry) + b.WriteByte('\x00') + } + return b.String() +} + +func getIncrementalSession(wd string, env []string, tags string) *incrementalSession { + key := sessionKey(wd, env, tags) + if session, ok := incrementalSessions.Load(key); ok { + return session.(*incrementalSession) + } + session := &incrementalSession{ + fset: token.NewFileSet(), + parsedDeps: make(map[string]cachedParsedFile), + } + actual, _ := incrementalSessions.LoadOrStore(key, session) + return actual.(*incrementalSession) +} + +func (s *incrementalSession) getParsedDep(filename string, src []byte) (*ast.File, bool) { + if s == nil { + return nil, false + } + hash := hashSource(src) + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.parsedDeps[filepath.Clean(filename)] + if !ok || entry.hash != hash { + return nil, false + } + return entry.file, true +} + +func (s *incrementalSession) storeParsedDep(filename string, src []byte, file *ast.File) { + if s == nil || file == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.parsedDeps[filepath.Clean(filename)] = cachedParsedFile{ + hash: hashSource(src), + file: file, + } +} + +func hashSource(src []byte) string { + sum := sha256.Sum256(src) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go new file mode 100644 index 0000000..faaa9b8 --- /dev/null +++ b/internal/wire/incremental_summary.go @@ -0,0 +1,647 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + "go/ast" + "go/types" + "path/filepath" + "sort" + + "golang.org/x/tools/go/packages" +) + +const incrementalSummaryVersion = "wire-incremental-summary-v1" + +type packageSummary struct { + Version string + WD string + Tags string + PkgPath string + ShapeHash string + LocalImports []string + ProviderSets []providerSetSummary + Injectors []injectorSummary +} + +type providerSetSummary struct { + VarName string + Providers []providerSummary + Imports []providerSetRefSummary + Bindings []ifaceBindingSummary + Values []string + Fields []fieldSummary + InputTypes []string +} + +type providerSummary struct { + PkgPath string + Name string + Args []providerInputSummary + Out []string + Varargs bool + IsStruct bool + HasCleanup bool + HasErr bool +} + +type providerInputSummary struct { + Type string + FieldName string +} + +type providerSetRefSummary struct { + PkgPath string + VarName string +} + +type ifaceBindingSummary struct { + Iface string + Provided string +} + +type fieldSummary struct { + PkgPath string + Parent string + Name string + Out []string +} + +type injectorSummary struct { + Name string + Inputs []string + Output string + Build providerSetSummary +} + +type packageSummarySnapshot struct { + Changed map[string]*packageSummary + Unchanged map[string]*packageSummary +} + +func incrementalSummaryKey(wd string, tags string, pkgPath string) string { + h := sha256.New() + h.Write([]byte(incrementalSummaryVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(tags)) + h.Write([]byte{0}) + h.Write([]byte(pkgPath)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func incrementalSummaryPath(key string) string { + return filepath.Join(cacheDir(), key+".isum") +} + +func readIncrementalPackageSummary(key string) (*packageSummary, bool) { + data, err := osReadFile(incrementalSummaryPath(key)) + if err != nil { + return nil, false + } + summary, err := decodeIncrementalSummary(data) + if err != nil { + return nil, false + } + return summary, true +} + +func writeIncrementalPackageSummary(key string, summary *packageSummary) { + data, err := encodeIncrementalSummary(summary) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".isum-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalSummaryPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) { + if loader == nil || len(pkgs) == 0 { + return + } + moduleRoot := findModuleRoot(loader.wd) + all := collectAllPackages(pkgs) + for path, pkg := range loader.loaded { + if pkg != nil { + all[path] = pkg + } + } + allPkgs := make([]*packages.Package, 0, len(all)) + for _, pkg := range all { + allPkgs = append(allPkgs, pkg) + } + oc := newObjectCache(allPkgs, loader) + for _, pkg := range all { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { + continue + } + summary, err := buildPackageSummary(loader, oc, pkg) + if err != nil { + continue + } + writeIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath), summary) + } +} + +func collectIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) *packageSummarySnapshot { + if loader == nil || loader.fingerprints == nil { + return nil + } + snapshot := &packageSummarySnapshot{ + Changed: make(map[string]*packageSummary), + Unchanged: make(map[string]*packageSummary), + } + changed := make(map[string]struct{}, len(loader.fingerprints.changed)) + for _, path := range loader.fingerprints.changed { + changed[path] = struct{}{} + } + moduleRoot := findModuleRoot(loader.wd) + oc := newObjectCache(pkgs, loader) + for _, pkg := range collectAllPackages(pkgs) { + if pkg == nil { + continue + } + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + if _, ok := changed[pkg.PkgPath]; ok { + if pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { + loaded, errs := oc.ensurePackage(pkg.PkgPath) + if len(errs) > 0 { + continue + } + pkg = loaded + } + if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { + continue + } + summary, err := buildPackageSummary(loader, oc, pkg) + if err != nil { + continue + } + snapshot.Changed[pkg.PkgPath] = summary + continue + } + if summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath)); ok { + snapshot.Unchanged[pkg.PkgPath] = summary + } + } + return snapshot +} + +func buildPackageSummary(loader *lazyLoader, oc *objectCache, pkg *packages.Package) (*packageSummary, error) { + if loader == nil || oc == nil || pkg == nil { + return nil, fmt.Errorf("missing loader, object cache, or package") + } + summary := &packageSummary{ + Version: incrementalSummaryVersion, + WD: filepath.Clean(loader.wd), + Tags: loader.tags, + PkgPath: pkg.PkgPath, + } + if snapshot := loader.fingerprints; snapshot != nil { + if fp := snapshot.fingerprints[pkg.PkgPath]; fp != nil { + summary.ShapeHash = fp.ShapeHash + summary.LocalImports = append(summary.LocalImports, fp.LocalImports...) + } + } + scope := pkg.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if !isProviderSetType(obj.Type()) { + continue + } + item, errs := oc.get(obj) + if len(errs) > 0 { + continue + } + pset, ok := item.(*ProviderSet) + if !ok { + continue + } + summary.ProviderSets = append(summary.ProviderSets, summarizeProviderSet(pset)) + } + sort.Slice(summary.ProviderSets, func(i, j int) bool { + return summary.ProviderSets[i].VarName < summary.ProviderSets[j].VarName + }) + for _, file := range pkg.Syntax { + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + buildCall, err := findInjectorBuild(pkg.TypesInfo, fn) + if err != nil || buildCall == nil { + continue + } + sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) + ins, out, err := injectorFuncSignature(sig) + if err != nil { + continue + } + injectorArgs := &InjectorArgs{ + Name: fn.Name.Name, + Tuple: ins, + Pos: fn.Pos(), + } + set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "") + if len(errs) > 0 { + continue + } + summary.Injectors = append(summary.Injectors, injectorSummary{ + Name: fn.Name.Name, + Inputs: summarizeTuple(ins), + Output: summaryTypeString(out.out), + Build: summarizeProviderSet(set), + }) + } + } + sort.Slice(summary.Injectors, func(i, j int) bool { + return summary.Injectors[i].Name < summary.Injectors[j].Name + }) + return summary, nil +} + +func summarizeProviderSet(pset *ProviderSet) providerSetSummary { + if pset == nil { + return providerSetSummary{} + } + summary := providerSetSummary{ + VarName: pset.VarName, + } + for _, provider := range pset.Providers { + summary.Providers = append(summary.Providers, summarizeProvider(provider)) + } + for _, imported := range pset.Imports { + summary.Imports = append(summary.Imports, providerSetRefSummary{ + PkgPath: imported.PkgPath, + VarName: imported.VarName, + }) + } + for _, binding := range pset.Bindings { + summary.Bindings = append(summary.Bindings, ifaceBindingSummary{ + Iface: summaryTypeString(binding.Iface), + Provided: summaryTypeString(binding.Provided), + }) + } + for _, value := range pset.Values { + summary.Values = append(summary.Values, summaryTypeString(value.Out)) + } + for _, field := range pset.Fields { + item := fieldSummary{ + Parent: summaryTypeString(field.Parent), + Name: field.Name, + Out: summarizeTypes(field.Out), + } + if field.Pkg != nil { + item.PkgPath = field.Pkg.Path() + } + summary.Fields = append(summary.Fields, item) + } + if pset.InjectorArgs != nil { + summary.InputTypes = summarizeTuple(pset.InjectorArgs.Tuple) + } + sort.Slice(summary.Providers, func(i, j int) bool { + return summary.Providers[i].PkgPath+"."+summary.Providers[i].Name < summary.Providers[j].PkgPath+"."+summary.Providers[j].Name + }) + sort.Slice(summary.Imports, func(i, j int) bool { + return summary.Imports[i].PkgPath+"."+summary.Imports[i].VarName < summary.Imports[j].PkgPath+"."+summary.Imports[j].VarName + }) + sort.Slice(summary.Bindings, func(i, j int) bool { + return summary.Bindings[i].Iface+":"+summary.Bindings[i].Provided < summary.Bindings[j].Iface+":"+summary.Bindings[j].Provided + }) + sort.Strings(summary.Values) + sort.Slice(summary.Fields, func(i, j int) bool { + return summary.Fields[i].Parent+"."+summary.Fields[i].Name < summary.Fields[j].Parent+"."+summary.Fields[j].Name + }) + sort.Strings(summary.InputTypes) + return summary +} + +func summarizeProvider(provider *Provider) providerSummary { + summary := providerSummary{ + Name: provider.Name, + Varargs: provider.Varargs, + IsStruct: provider.IsStruct, + HasCleanup: provider.HasCleanup, + HasErr: provider.HasErr, + Out: summarizeTypes(provider.Out), + } + if provider.Pkg != nil { + summary.PkgPath = provider.Pkg.Path() + } + for _, arg := range provider.Args { + summary.Args = append(summary.Args, providerInputSummary{ + Type: summaryTypeString(arg.Type), + FieldName: arg.FieldName, + }) + } + return summary +} + +func summarizeTuple(tuple *types.Tuple) []string { + if tuple == nil { + return nil + } + out := make([]string, 0, tuple.Len()) + for i := 0; i < tuple.Len(); i++ { + out = append(out, summaryTypeString(tuple.At(i).Type())) + } + return out +} + +func summarizeTypes(typesList []types.Type) []string { + out := make([]string, 0, len(typesList)) + for _, t := range typesList { + out = append(out, summaryTypeString(t)) + } + return out +} + +func summaryTypeString(t types.Type) string { + if t == nil { + return "" + } + return types.TypeString(t, func(pkg *types.Package) string { + if pkg == nil { + return "" + } + return pkg.Path() + }) +} + +func encodeIncrementalSummary(summary *packageSummary) ([]byte, error) { + if summary == nil { + return nil, fmt.Errorf("nil package summary") + } + var buf bytes.Buffer + enc := binarySummaryEncoder{buf: &buf} + enc.string(summary.Version) + enc.string(summary.WD) + enc.string(summary.Tags) + enc.string(summary.PkgPath) + enc.string(summary.ShapeHash) + enc.strings(summary.LocalImports) + enc.providerSets(summary.ProviderSets) + enc.u32(uint32(len(summary.Injectors))) + for _, injector := range summary.Injectors { + enc.string(injector.Name) + enc.strings(injector.Inputs) + enc.string(injector.Output) + enc.providerSet(injector.Build) + } + if enc.err != nil { + return nil, enc.err + } + return buf.Bytes(), nil +} + +func decodeIncrementalSummary(data []byte) (*packageSummary, error) { + dec := binarySummaryDecoder{r: bytes.NewReader(data)} + summary := &packageSummary{ + Version: dec.string(), + WD: dec.string(), + Tags: dec.string(), + PkgPath: dec.string(), + ShapeHash: dec.string(), + } + summary.LocalImports = dec.strings() + summary.ProviderSets = dec.providerSets() + for n := dec.u32(); n > 0; n-- { + summary.Injectors = append(summary.Injectors, injectorSummary{ + Name: dec.string(), + Inputs: dec.strings(), + Output: dec.string(), + Build: dec.providerSet(), + }) + } + if dec.err != nil { + return nil, dec.err + } + return summary, nil +} + +type binarySummaryEncoder struct { + buf *bytes.Buffer + err error +} + +func (e *binarySummaryEncoder) u32(v uint32) { + if e.err != nil { + return + } + e.err = binary.Write(e.buf, binary.LittleEndian, v) +} + +func (e *binarySummaryEncoder) string(s string) { + e.u32(uint32(len(s))) + if e.err != nil { + return + } + _, e.err = e.buf.WriteString(s) +} + +func (e *binarySummaryEncoder) bool(v bool) { + if e.err != nil { + return + } + var b byte + if v { + b = 1 + } + e.err = e.buf.WriteByte(b) +} + +func (e *binarySummaryEncoder) strings(values []string) { + e.u32(uint32(len(values))) + for _, v := range values { + e.string(v) + } +} + +func (e *binarySummaryEncoder) providerSets(values []providerSetSummary) { + e.u32(uint32(len(values))) + for _, value := range values { + e.providerSet(value) + } +} + +func (e *binarySummaryEncoder) providerSet(value providerSetSummary) { + e.string(value.VarName) + e.u32(uint32(len(value.Providers))) + for _, provider := range value.Providers { + e.string(provider.PkgPath) + e.string(provider.Name) + e.u32(uint32(len(provider.Args))) + for _, arg := range provider.Args { + e.string(arg.Type) + e.string(arg.FieldName) + } + e.strings(provider.Out) + e.bool(provider.Varargs) + e.bool(provider.IsStruct) + e.bool(provider.HasCleanup) + e.bool(provider.HasErr) + } + e.u32(uint32(len(value.Imports))) + for _, imported := range value.Imports { + e.string(imported.PkgPath) + e.string(imported.VarName) + } + e.u32(uint32(len(value.Bindings))) + for _, binding := range value.Bindings { + e.string(binding.Iface) + e.string(binding.Provided) + } + e.strings(value.Values) + e.u32(uint32(len(value.Fields))) + for _, field := range value.Fields { + e.string(field.PkgPath) + e.string(field.Parent) + e.string(field.Name) + e.strings(field.Out) + } + e.strings(value.InputTypes) +} + +type binarySummaryDecoder struct { + r *bytes.Reader + err error +} + +func (d *binarySummaryDecoder) u32() uint32 { + if d.err != nil { + return 0 + } + var v uint32 + d.err = binary.Read(d.r, binary.LittleEndian, &v) + return v +} + +func (d *binarySummaryDecoder) string() string { + n := d.u32() + if d.err != nil { + return "" + } + buf := make([]byte, n) + _, d.err = d.r.Read(buf) + return string(buf) +} + +func (d *binarySummaryDecoder) bool() bool { + if d.err != nil { + return false + } + b, err := d.r.ReadByte() + if err != nil { + d.err = err + return false + } + return b != 0 +} + +func (d *binarySummaryDecoder) strings() []string { + n := d.u32() + if d.err != nil { + return nil + } + out := make([]string, 0, n) + for i := uint32(0); i < n; i++ { + out = append(out, d.string()) + } + return out +} + +func (d *binarySummaryDecoder) providerSets() []providerSetSummary { + n := d.u32() + if d.err != nil { + return nil + } + out := make([]providerSetSummary, 0, n) + for i := uint32(0); i < n; i++ { + out = append(out, d.providerSet()) + } + return out +} + +func (d *binarySummaryDecoder) providerSet() providerSetSummary { + value := providerSetSummary{ + VarName: d.string(), + } + for n := d.u32(); n > 0; n-- { + provider := providerSummary{ + PkgPath: d.string(), + Name: d.string(), + } + for m := d.u32(); m > 0; m-- { + provider.Args = append(provider.Args, providerInputSummary{ + Type: d.string(), + FieldName: d.string(), + }) + } + provider.Out = d.strings() + provider.Varargs = d.bool() + provider.IsStruct = d.bool() + provider.HasCleanup = d.bool() + provider.HasErr = d.bool() + value.Providers = append(value.Providers, provider) + } + for n := d.u32(); n > 0; n-- { + value.Imports = append(value.Imports, providerSetRefSummary{ + PkgPath: d.string(), + VarName: d.string(), + }) + } + for n := d.u32(); n > 0; n-- { + value.Bindings = append(value.Bindings, ifaceBindingSummary{ + Iface: d.string(), + Provided: d.string(), + }) + } + value.Values = d.strings() + for n := d.u32(); n > 0; n-- { + value.Fields = append(value.Fields, fieldSummary{ + PkgPath: d.string(), + Parent: d.string(), + Name: d.string(), + Out: d.strings(), + }) + } + value.InputTypes = d.strings() + return value +} diff --git a/internal/wire/incremental_summary_test.go b/internal/wire/incremental_summary_test.go new file mode 100644 index 0000000..efb4028 --- /dev/null +++ b/internal/wire/incremental_summary_test.go @@ -0,0 +1,287 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestIncrementalSummaryEncodeDecodeRoundTrip(t *testing.T) { + summary := &packageSummary{ + Version: incrementalSummaryVersion, + WD: "/tmp/app", + Tags: "dev", + PkgPath: "example.com/app/dep", + ShapeHash: "abc123", + LocalImports: []string{"example.com/app/shared"}, + ProviderSets: []providerSetSummary{{ + VarName: "Set", + Providers: []providerSummary{{ + PkgPath: "example.com/app/dep", + Name: "NewThing", + Args: []providerInputSummary{{Type: "string"}}, + Out: []string{"*example.com/app/dep.Thing"}, + HasCleanup: true, + }}, + Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, + Bindings: []ifaceBindingSummary{{Iface: "error", Provided: "*example.com/app/dep.Thing"}}, + Values: []string{"string"}, + Fields: []fieldSummary{{PkgPath: "example.com/app/dep", Parent: "example.com/app/dep.Config", Name: "Name", Out: []string{"string"}}}, + InputTypes: []string{"context.Context"}, + }}, + Injectors: []injectorSummary{{ + Name: "Init", + Inputs: []string{"context.Context"}, + Output: "*example.com/app/dep.Thing", + Build: providerSetSummary{ + Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, + }, + }}, + } + data, err := encodeIncrementalSummary(summary) + if err != nil { + t.Fatalf("encodeIncrementalSummary: %v", err) + } + got, err := decodeIncrementalSummary(data) + if err != nil { + t.Fatalf("decodeIncrementalSummary: %v", err) + } + if got.Version != summary.Version || got.PkgPath != summary.PkgPath || got.ShapeHash != summary.ShapeHash { + t.Fatalf("decoded summary mismatch: %+v", got) + } + if len(got.ProviderSets) != 1 || got.ProviderSets[0].VarName != "Set" { + t.Fatalf("decoded provider sets mismatch: %+v", got.ProviderSets) + } + if len(got.Injectors) != 1 || got.Injectors[0].Name != "Init" { + t.Fatalf("decoded injectors mismatch: %+v", got.Injectors) + } +} + +func TestBuildPackageSummary(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct{ Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo { return &Foo{Message: msg} }", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors: %v", errs) + } + oc := newObjectCache(pkgs, loader) + loadedDep, errs := oc.ensurePackage("example.com/app/dep") + if len(errs) > 0 { + t.Fatalf("ensurePackage returned errors: %v", errs) + } + summary, err := buildPackageSummary(loader, oc, loadedDep) + if err != nil { + t.Fatalf("buildPackageSummary: %v", err) + } + if summary.PkgPath != "example.com/app/dep" { + t.Fatalf("summary pkg path = %q", summary.PkgPath) + } + if len(summary.ProviderSets) != 1 || summary.ProviderSets[0].VarName != "Set" { + t.Fatalf("unexpected provider sets: %+v", summary.ProviderSets) + } + if len(summary.ProviderSets[0].Providers) != 2 { + t.Fatalf("unexpected providers: %+v", summary.ProviderSets[0].Providers) + } + loadedApp, errs := oc.ensurePackage("example.com/app/app") + if len(errs) > 0 { + t.Fatalf("ensurePackage app returned errors: %v", errs) + } + appSummary, err := buildPackageSummary(loader, oc, loadedApp) + if err != nil { + t.Fatalf("buildPackageSummary app: %v", err) + } + if len(appSummary.Injectors) != 1 || appSummary.Injectors[0].Name != "Init" { + t.Fatalf("unexpected injectors: %+v", appSummary.Injectors) + } + if len(appSummary.Injectors[0].Build.Imports) != 1 || appSummary.Injectors[0].Build.Imports[0].PkgPath != "example.com/app/dep" { + t.Fatalf("unexpected injector imports: %+v", appSummary.Injectors[0].Build.Imports) + } +} + +func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct{ Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo { return &Foo{Message: msg} }", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate result: %+v", gens) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct{ Message string; Count int }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo { return &Foo{Message: msg, Count: count} }", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors: %v", errs) + } + snapshot := collectIncrementalPackageSummaries(loader, pkgs) + if snapshot == nil { + t.Fatal("collectIncrementalPackageSummaries returned nil") + } + if _, ok := snapshot.Changed["example.com/app/dep"]; !ok { + t.Fatalf("expected changed dep summary, got %+v", snapshot.Changed) + } + if _, ok := snapshot.Unchanged["example.com/app/app"]; !ok { + t.Fatalf("expected unchanged app summary from cache, got %+v", snapshot.Unchanged) + } + if len(snapshot.Unchanged["example.com/app/app"].Injectors) != 1 { + t.Fatalf("unexpected cached app summary: %+v", snapshot.Unchanged["example.com/app/app"]) + } + if len(snapshot.Changed["example.com/app/dep"].ProviderSets) != 1 { + t.Fatalf("unexpected changed dep summary: %+v", snapshot.Changed["example.com/app/dep"]) + } +} diff --git a/internal/wire/incremental_test.go b/internal/wire/incremental_test.go new file mode 100644 index 0000000..a531123 --- /dev/null +++ b/internal/wire/incremental_test.go @@ -0,0 +1,65 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "testing" +) + +func TestIncrementalEnabledDefaultOff(t *testing.T) { + if IncrementalEnabled(context.Background(), nil) { + t.Fatal("IncrementalEnabled should default to false") + } +} + +func TestIncrementalEnabledFromEnv(t *testing.T) { + env := []string{ + "FOO=bar", + IncrementalEnvVar + "=true", + } + if !IncrementalEnabled(context.Background(), env) { + t.Fatal("IncrementalEnabled should read the environment variable") + } +} + +func TestIncrementalEnabledUsesLastEnvValue(t *testing.T) { + env := []string{ + IncrementalEnvVar + "=false", + IncrementalEnvVar + "=true", + } + if !IncrementalEnabled(context.Background(), env) { + t.Fatal("IncrementalEnabled should use the last matching env value") + } +} + +func TestIncrementalEnabledContextOverridesEnv(t *testing.T) { + env := []string{ + IncrementalEnvVar + "=false", + } + ctx := WithIncremental(context.Background(), true) + if !IncrementalEnabled(ctx, env) { + t.Fatal("context override should take precedence over env") + } +} + +func TestIncrementalEnabledInvalidEnvFallsBackFalse(t *testing.T) { + env := []string{ + IncrementalEnvVar + "=maybe", + } + if IncrementalEnabled(context.Background(), env) { + t.Fatal("invalid env value should not enable incremental mode") + } +} diff --git a/internal/wire/load_debug.go b/internal/wire/load_debug.go new file mode 100644 index 0000000..fd8c4d7 --- /dev/null +++ b/internal/wire/load_debug.go @@ -0,0 +1,304 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/tools/go/packages" +) + +type parseFileStats struct { + mu sync.Mutex + calls int + primaryCalls int + depCalls int + cacheHits int + cacheMisses int + errors int + total time.Duration +} + +func (ps *parseFileStats) record(primary bool, dur time.Duration, err error, cacheHit bool) { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.calls++ + if primary { + ps.primaryCalls++ + } else { + ps.depCalls++ + } + if cacheHit { + ps.cacheHits++ + } else { + ps.cacheMisses++ + } + ps.total += dur + if err != nil { + ps.errors++ + } +} + +func (ps *parseFileStats) snapshot() parseFileStats { + ps.mu.Lock() + defer ps.mu.Unlock() + return parseFileStats{ + calls: ps.calls, + primaryCalls: ps.primaryCalls, + depCalls: ps.depCalls, + cacheHits: ps.cacheHits, + cacheMisses: ps.cacheMisses, + errors: ps.errors, + total: ps.total, + } +} + +type loadScopeStats struct { + roots int + totalPackages int + compiledFiles int + syntaxFiles int + packagesWithSyntax int + packagesWithTypes int + packagesWithTypesInfo int + localPackages int + localSyntaxPackages int + externalPackages int + externalSyntaxPkgs int + unknownPackages int + topCompiled []string + topSyntax []string +} + +type packageMetric struct { + path string + count int +} + +func logLoadDebug(ctx context.Context, scope string, mode packages.LoadMode, subject string, wd string, pkgs []*packages.Package, parseStats *parseFileStats) { + if timing(ctx) == nil { + return + } + stats := summarizeLoadScope(wd, pkgs) + debugf(ctx, "load.debug scope=%s subject=%s mode=%s roots=%d total_pkgs=%d compiled_files=%d syntax_files=%d syntax_pkgs=%d typed_pkgs=%d types_info_pkgs=%d local_pkgs=%d local_syntax_pkgs=%d external_pkgs=%d external_syntax_pkgs=%d unknown_pkgs=%d", + scope, + subject, + formatLoadMode(mode), + stats.roots, + stats.totalPackages, + stats.compiledFiles, + stats.syntaxFiles, + stats.packagesWithSyntax, + stats.packagesWithTypes, + stats.packagesWithTypesInfo, + stats.localPackages, + stats.localSyntaxPackages, + stats.externalPackages, + stats.externalSyntaxPkgs, + stats.unknownPackages, + ) + if len(stats.topCompiled) > 0 { + debugf(ctx, "load.debug scope=%s top_compiled_files=%s", scope, strings.Join(stats.topCompiled, ", ")) + } + if len(stats.topSyntax) > 0 { + debugf(ctx, "load.debug scope=%s top_syntax_files=%s", scope, strings.Join(stats.topSyntax, ", ")) + } + if parseStats != nil { + snap := parseStats.snapshot() + debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.total=%s", + scope, + snap.calls, + snap.primaryCalls, + snap.depCalls, + snap.cacheHits, + snap.cacheMisses, + snap.errors, + snap.total, + ) + } +} + +func summarizeLoadScope(wd string, pkgs []*packages.Package) loadScopeStats { + all := collectAllPackages(pkgs) + stats := loadScopeStats{ + roots: len(pkgs), + totalPackages: len(all), + } + moduleRoot := findModuleRoot(wd) + var compiled []packageMetric + var syntax []packageMetric + for _, pkg := range all { + if pkg == nil { + continue + } + compiledCount := len(pkg.CompiledGoFiles) + syntaxCount := len(pkg.Syntax) + stats.compiledFiles += compiledCount + stats.syntaxFiles += syntaxCount + if syntaxCount > 0 { + stats.packagesWithSyntax++ + } + if pkg.Types != nil { + stats.packagesWithTypes++ + } + if pkg.TypesInfo != nil { + stats.packagesWithTypesInfo++ + } + class := classifyPackageLocation(moduleRoot, pkg) + switch class { + case "local": + stats.localPackages++ + if syntaxCount > 0 { + stats.localSyntaxPackages++ + } + case "external": + stats.externalPackages++ + if syntaxCount > 0 { + stats.externalSyntaxPkgs++ + } + default: + stats.unknownPackages++ + } + if compiledCount > 0 { + compiled = append(compiled, packageMetric{path: pkg.PkgPath, count: compiledCount}) + } + if syntaxCount > 0 { + syntax = append(syntax, packageMetric{path: pkg.PkgPath, count: syntaxCount}) + } + } + stats.topCompiled = topPackageMetrics(compiled) + stats.topSyntax = topPackageMetrics(syntax) + return stats +} + +func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { + if moduleRoot == "" || pkg == nil { + return "unknown" + } + for _, name := range pkg.CompiledGoFiles { + if isWithinRoot(moduleRoot, name) { + return "local" + } + return "external" + } + for _, name := range pkg.GoFiles { + if isWithinRoot(moduleRoot, name) { + return "local" + } + return "external" + } + return "unknown" +} + +func isWithinRoot(root, name string) bool { + cleanRoot := filepath.Clean(root) + cleanName := filepath.Clean(name) + if cleanName == cleanRoot { + return true + } + rel, err := filepath.Rel(cleanRoot, cleanName) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + +func topPackageMetrics(metrics []packageMetric) []string { + sort.Slice(metrics, func(i, j int) bool { + if metrics[i].count == metrics[j].count { + return metrics[i].path < metrics[j].path + } + return metrics[i].count > metrics[j].count + }) + if len(metrics) > 5 { + metrics = metrics[:5] + } + out := make([]string, 0, len(metrics)) + for _, m := range metrics { + out = append(out, fmt.Sprintf("%s(%d)", m.path, m.count)) + } + return out +} + +func findModuleRoot(wd string) string { + dir := filepath.Clean(wd) + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return "" + } + dir = parent + } +} + +func formatLoadMode(mode packages.LoadMode) string { + flags := []struct { + bit packages.LoadMode + name string + }{ + {packages.NeedName, "NeedName"}, + {packages.NeedFiles, "NeedFiles"}, + {packages.NeedCompiledGoFiles, "NeedCompiledGoFiles"}, + {packages.NeedImports, "NeedImports"}, + {packages.NeedDeps, "NeedDeps"}, + {packages.NeedExportsFile, "NeedExportsFile"}, + {packages.NeedTypes, "NeedTypes"}, + {packages.NeedSyntax, "NeedSyntax"}, + {packages.NeedTypesInfo, "NeedTypesInfo"}, + {packages.NeedTypesSizes, "NeedTypesSizes"}, + {packages.NeedModule, "NeedModule"}, + {packages.NeedEmbedFiles, "NeedEmbedFiles"}, + {packages.NeedEmbedPatterns, "NeedEmbedPatterns"}, + } + var parts []string + for _, flag := range flags { + if mode&flag.bit != 0 { + parts = append(parts, flag.name) + } + } + if len(parts) == 0 { + return "0" + } + return strings.Join(parts, "|") +} + +func primaryFileSet(files map[string]struct{}) map[string]struct{} { + if len(files) == 0 { + return nil + } + out := make(map[string]struct{}, len(files)) + for name := range files { + out[filepath.Clean(name)] = struct{}{} + } + return out +} + +func isPrimaryFile(primary map[string]struct{}, filename string) bool { + if len(primary) == 0 { + return false + } + _, ok := primary[filepath.Clean(filename)] + return ok +} diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 1fbd96c..2f41c8d 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -20,6 +20,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) func TestLoadAndGenerateModule(t *testing.T) { @@ -124,6 +125,925 @@ func TestLoadAndGenerateModule(t *testing.T) { } } +func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + + info, errs := Load(context.Background(), root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("Load returned errors: %v", errs) + } + if info == nil || len(info.Injectors) != 1 { + t.Fatalf("Load returned unexpected info: %+v errs=%v", info, errs) + } + + incrementalCtx := WithIncremental(context.Background(), true) + incrementalInfo, errs := Load(incrementalCtx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("incremental Load returned errors: %v", errs) + } + if incrementalInfo == nil || len(incrementalInfo.Injectors) != 1 { + t.Fatalf("incremental Load returned unexpected info: %+v errs=%v", incrementalInfo, errs) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + incrementalGens, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(incrementalGens) != 1 { + t.Fatalf("unexpected result counts: normal=%d incremental=%d", len(normalGens), len(incrementalGens)) + } + if len(normalGens[0].Errs) > 0 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected generate errors: normal=%v incremental=%v", normalGens[0].Errs, incrementalGens[0].Errs) + } + if normalGens[0].OutputPath != incrementalGens[0].OutputPath { + t.Fatalf("output paths differ: normal=%q incremental=%q", normalGens[0].OutputPath, incrementalGens[0].OutputPath) + } + if string(normalGens[0].Content) != string(incrementalGens[0].Content) { + t.Fatalf("generated content differs between normal and incremental modes") + } +} + +func TestGenerateIncrementalManifestSkipsLazyLoadOnBodyOnlyChange(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "app", "wire_gen.go"), strings.Join([]string{ + "//go:build !wireinject", + "", + "package app", + "", + "func generated() {}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "app", "app_test.go"), strings.Join([]string{ + "package app", + "", + "func testOnly() {}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + var firstLabels []string + firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { + firstLabels = append(firstLabels, label) + }) + first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + if !containsLabel(firstLabels, "load.packages.lazy.load") { + t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) + } + + if err := os.WriteFile(depFile, []byte(strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"b\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected second Generate to hit preload incremental manifest before package load, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected second Generate to skip lazy load, labels=%v", secondLabels) + } + if string(first[0].Content) != string(second[0].Content) { + t.Fatal("expected body-only change to reuse identical generated output") + } +} + +func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + var firstLabels []string + firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { + firstLabels = append(firstLabels, label) + }) + first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + if !containsLabel(firstLabels, "load.packages.lazy.load") { + t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected shape-changing incremental run to skip package load via local fast path, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected shape-changing incremental run to skip lazy load via local fast path, labels=%v", secondLabels) + } + if !containsLabel(secondLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected shape-changing incremental run to use local fast path, labels=%v", secondLabels) + } + if string(first[0].Content) == string(second[0].Content) { + t.Fatal("expected shape-changing edit to regenerate different output") + } +} + +func TestGenerateIncrementalRepeatedShapeStateHitsPreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected repeated shape state to hit preload manifest before package load, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected repeated shape state to skip lazy load, labels=%v", secondLabels) + } + if string(first[0].Content) != string(second[0].Content) { + t.Fatal("expected repeated shape state to reuse identical generated output") + } +} + +func TestGenerateIncrementalShapeChangeThenRepeatHitsPreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ + "package extra", + "", + "type Marker struct{}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected shape-changing Generate to skip package load via local fast path, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected shape-changing Generate to skip lazy load via local fast path, labels=%v", secondLabels) + } + if !containsLabel(secondLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected shape-changing Generate to use local fast path, labels=%v", secondLabels) + } + + var thirdLabels []string + thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { + thirdLabels = append(thirdLabels, label) + }) + third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("third Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected third Generate result: %+v", third) + } + if containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected repeated shape-changing state to hit preload manifest before package load, labels=%v", thirdLabels) + } + if containsLabel(thirdLabels, "load.packages.lazy.load") { + t.Fatalf("expected repeated shape-changing state to skip lazy load, labels=%v", thirdLabels) + } + if string(second[0].Content) != string(third[0].Content) { + t.Fatal("expected repeated shape-changing state to reuse identical generated output") + } +} + +func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeIncrementalBenchmarkModule(t, repoRoot, root) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var incrementalLabels []string + incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + incrementalLabels = append(incrementalLabels, label) + }) + incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental shape-change Generate returned errors: %v", errs) + } + if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) + } + if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected incremental shape-change Generate to use local fast path, labels=%v", incrementalLabels) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate results: %+v", normalGens) + } + if incrementalGens[0].OutputPath != normalGens[0].OutputPath { + t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) + } + if string(incrementalGens[0].Content) != string(normalGens[0].Content) { + t.Fatal("shape-changing incremental output differs from normal Generate output") + } +} + +func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(second) != 0 { + t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) + } + if len(errs) == 0 { + t.Fatal("expected invalid incremental generate to return errors") + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected invalid incremental generate to stop before slow-path load, labels=%v", secondLabels) + } + if got := errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { + t.Fatalf("expected fast-path type-check error, got %q", got) + } +} + +func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + + oldDep := strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n") + newDep := strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n") + oldWire := strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n") + newWire := strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n") + + writeFile(t, depFile, oldDep) + writeFile(t, wireFile, oldWire) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, newDep) + writeFile(t, wireFile, newWire) + second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + + writeFile(t, depFile, oldDep) + writeFile(t, wireFile, oldWire) + + var thirdLabels []string + thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { + thirdLabels = append(thirdLabels, label) + }) + third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("third Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected third Generate result: %+v", third) + } + if containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected toggled-back shape state to hit archived preload manifest before package load, labels=%v", thirdLabels) + } + if containsLabel(thirdLabels, "load.packages.lazy.load") { + t.Fatalf("expected toggled-back shape state to skip lazy load, labels=%v", thirdLabels) + } + if string(first[0].Content) != string(third[0].Content) { + t.Fatal("expected toggled-back shape state to reuse archived generated output") + } +} + +func containsLabel(labels []string, want string) bool { + for _, label := range labels { + if label == want { + return true + } + } + return false +} + func mustRepoRoot(t *testing.T) string { t.Helper() wd, err := os.Getwd() diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go new file mode 100644 index 0000000..466dcc2 --- /dev/null +++ b/internal/wire/local_fastpath.go @@ -0,0 +1,556 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "fmt" + "go/ast" + "go/format" + importerpkg "go/importer" + "go/parser" + "go/token" + "go/types" + "io" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "time" + + "golang.org/x/tools/go/packages" +) + +func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState) ([]GenerateResult, bool, bool, []error) { + if state == nil || state.manifest == nil { + return nil, false, false, nil + } + if !strings.HasSuffix(state.reason, ".shape_mismatch") { + return nil, false, false, nil + } + roots := manifestOutputPkgPaths(state.manifest) + if len(roots) != 1 { + return nil, false, false, nil + } + changed := changedPackagePaths(state.manifest.LocalPackages, state.currentLocal) + if len(changed) != 1 { + return nil, false, false, nil + } + graph, ok := readIncrementalGraph(incrementalGraphKey(wd, opts.Tags, roots)) + if !ok { + return nil, false, false, nil + } + affected := affectedRoots(graph, changed) + if len(affected) != 1 || affected[0] != roots[0] { + return nil, false, false, nil + } + + fastPathStart := time.Now() + loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], state.currentLocal, state.manifest.ExternalPkgs) + if err != nil { + debugf(ctx, "incremental.local_fastpath miss reason=%v", err) + if shouldBypassIncrementalManifestAfterFastPathError(err) { + return nil, true, true, []error{err} + } + return nil, false, false, nil + } + logTiming(ctx, "incremental.local_fastpath.load", fastPathStart) + + generated, errs := generateFromTypedPackages(ctx, loaded.root, loaded.allPackages, opts) + logTiming(ctx, "incremental.local_fastpath.generate", fastPathStart) + if len(errs) > 0 { + return nil, true, true, errs + } + + snapshot := &incrementalFingerprintSnapshot{ + fingerprints: loaded.fingerprints, + changed: append([]string(nil), changed...), + } + loader := &lazyLoader{ + ctx: ctx, + wd: wd, + env: env, + tags: opts.Tags, + fset: loaded.fset, + fingerprints: snapshot, + loaded: make(map[string]*packages.Package, len(loaded.byPath)), + } + for path, pkg := range loaded.byPath { + loader.loaded[path] = pkg + } + writeIncrementalFingerprints(snapshot, wd, opts.Tags) + writeIncrementalPackageSummaries(loader, loaded.allPackages) + writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) + writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) + + debugf(ctx, "incremental.local_fastpath hit root=%s changed=%s", roots[0], strings.Join(changed, ",")) + return generated, true, false, nil +} + +func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "type-check failed for ") +} + +func formatLocalTypeCheckError(wd string, pkgPath string, errs []packages.Error) error { + if len(errs) == 0 { + return fmt.Errorf("type-check failed for %s", pkgPath) + } + root := findModuleRoot(wd) + lines := []string{} + for _, pkgErr := range errs { + details := normalizeErrorLines(pkgErr.Msg, root) + if len(details) == 0 { + continue + } + lines = append(lines, fmt.Sprintf("type-check failed for %s: %s", pkgPath, details[0])) + for _, line := range details[1:] { + lines = append(lines, line) + } + } + if len(lines) == 0 { + lines = append(lines, fmt.Sprintf("type-check failed for %s", pkgPath)) + } + return fmt.Errorf("%s", strings.Join(lines, "\n")) +} + +func normalizeErrorLines(msg string, root string) []string { + msg = strings.TrimSpace(msg) + if msg == "" { + return []string{"unknown error"} + } + lines := unfoldTypeCheckChain(msg) + for i := range lines { + lines[i] = relativizeErrorLine(lines[i], root) + } + if len(lines) == 0 { + return []string{"unknown error"} + } + return lines +} + +func relativizeErrorLine(line string, root string) string { + if root == "" { + return line + } + cleanRoot := filepath.Clean(root) + prefix := cleanRoot + string(os.PathSeparator) + return strings.ReplaceAll(line, prefix, "") +} + +func unfoldTypeCheckChain(msg string) []string { + msg = strings.TrimSpace(msg) + if msg == "" { + return nil + } + if inner, outer, ok := splitNestedTypeCheck(msg); ok { + lines := []string{strings.TrimSpace(outer)} + return append(lines, unfoldTypeCheckChain(inner)...) + } + parts := strings.Split(msg, "\n") + lines := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + lines = append(lines, part) + } + return lines +} + +func splitNestedTypeCheck(msg string) (inner string, outer string, ok bool) { + msg = strings.TrimSpace(msg) + if len(msg) < 2 || msg[len(msg)-1] != ')' { + return "", "", false + } + depth := 0 + for i := len(msg) - 1; i >= 0; i-- { + switch msg[i] { + case ')': + depth++ + case '(': + depth-- + if depth == 0 { + inner = strings.TrimSpace(msg[i+1 : len(msg)-1]) + if strings.HasPrefix(inner, "type-check failed for ") { + return inner, strings.TrimSpace(msg[:i]), true + } + return "", "", false + } + } + } + return "", "", false +} + +type localFastPathLoaded struct { + fset *token.FileSet + root *packages.Package + allPackages []*packages.Package + byPath map[string]*packages.Package + fingerprints map[string]*packageFingerprint +} + +type localFastPathLoader struct { + ctx context.Context + wd string + tags string + fset *token.FileSet + rootPkgPath string + meta map[string]*packageFingerprint + pkgs map[string]*packages.Package + externalMeta map[string]externalPackageExport + externalImp types.Importer +} + +func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { + meta := fingerprintsFromSlice(current) + if len(meta) == 0 { + return nil, fmt.Errorf("no local fingerprints") + } + if meta[rootPkgPath] == nil { + return nil, fmt.Errorf("missing root package fingerprint") + } + externalMeta := make(map[string]externalPackageExport, len(external)) + for _, item := range external { + if item.PkgPath == "" || item.ExportFile == "" { + continue + } + externalMeta[item.PkgPath] = item + } + loader := &localFastPathLoader{ + ctx: ctx, + wd: wd, + tags: tags, + fset: token.NewFileSet(), + rootPkgPath: rootPkgPath, + meta: meta, + pkgs: make(map[string]*packages.Package, len(meta)), + externalMeta: externalMeta, + } + loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) + root, err := loader.load(rootPkgPath) + if err != nil { + return nil, err + } + all := make([]*packages.Package, 0, len(loader.pkgs)) + for _, pkg := range loader.pkgs { + all = append(all, pkg) + } + sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath }) + return &localFastPathLoaded{ + fset: loader.fset, + root: root, + allPackages: all, + byPath: loader.pkgs, + fingerprints: loader.meta, + }, nil +} + +func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { + if pkg := l.pkgs[pkgPath]; pkg != nil { + return pkg, nil + } + fp := l.meta[pkgPath] + if fp == nil { + return nil, fmt.Errorf("package %s not tracked as local", pkgPath) + } + files := filesFromMeta(fp.Files) + if len(files) == 0 { + return nil, fmt.Errorf("package %s has no files", pkgPath) + } + mode := parser.ParseComments | parser.SkipObjectResolution + syntax := make([]*ast.File, 0, len(files)) + for _, name := range files { + file, err := parser.ParseFile(l.fset, name, nil, mode) + if err != nil { + return nil, err + } + syntax = append(syntax, file) + } + if len(syntax) == 0 { + return nil, fmt.Errorf("package %s parsed no files", pkgPath) + } + + pkgName := syntax[0].Name.Name + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + Scopes: make(map[ast.Node]*types.Scope), + Instances: make(map[*ast.Ident]types.Instance), + } + pkg := &packages.Package{ + Fset: l.fset, + Name: pkgName, + PkgPath: pkgPath, + GoFiles: append([]string(nil), files...), + CompiledGoFiles: append([]string(nil), files...), + Syntax: syntax, + TypesInfo: info, + Imports: make(map[string]*packages.Package), + } + l.pkgs[pkgPath] = pkg + + conf := &types.Config{ + Importer: importerFunc(func(path string) (*types.Package, error) { + return l.importPackage(path) + }), + Sizes: types.SizesFor("gc", runtime.GOARCH), + Error: func(err error) { + pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) + }, + } + checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) + if checkedPkg != nil { + pkg.Types = checkedPkg + } + if err != nil && len(pkg.Errors) == 0 { + return nil, err + } + if len(pkg.Errors) > 0 { + return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) + } + + imports := packageImportPaths(syntax) + localImports := make([]string, 0, len(imports)) + for _, path := range imports { + if dep := l.pkgs[path]; dep != nil { + pkg.Imports[path] = dep + localImports = append(localImports, path) + } + } + sort.Strings(localImports) + updated := *fp + updated.LocalImports = localImports + updated.Tags = l.tags + updated.WD = filepath.Clean(l.wd) + l.meta[pkgPath] = &updated + return pkg, nil +} + +func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) { + if l.meta[path] != nil { + pkg, err := l.load(path) + if err != nil { + return nil, err + } + return pkg.Types, nil + } + if l.externalImp == nil { + return nil, fmt.Errorf("missing external importer") + } + return l.externalImp.Import(path) +} + +func (l *localFastPathLoader) openExternalExport(path string) (io.ReadCloser, error) { + meta, ok := l.externalMeta[path] + if !ok || meta.ExportFile == "" { + return nil, fmt.Errorf("missing export data for %s", path) + } + return os.Open(meta.ExportFile) +} + +type importerFunc func(string) (*types.Package, error) + +func (fn importerFunc) Import(path string) (*types.Package, error) { + return fn(path) +} + +func packageImportPaths(files []*ast.File) []string { + seen := make(map[string]struct{}) + var out []string + for _, file := range files { + for _, spec := range file.Imports { + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + continue + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + out = append(out, path) + } + } + sort.Strings(out) + return out +} + +func generateFromTypedPackages(ctx context.Context, root *packages.Package, allPkgs []*packages.Package, opts *GenerateOptions) ([]GenerateResult, []error) { + if root == nil { + return nil, []error{fmt.Errorf("missing root package")} + } + if opts == nil { + opts = &GenerateOptions{} + } + pkgStart := time.Now() + res := GenerateResult{PkgPath: root.PkgPath} + outDir, err := detectOutputDir(root.GoFiles) + logTiming(ctx, "generate.package."+root.PkgPath+".output_dir", pkgStart) + if err != nil { + res.Errs = append(res.Errs, err) + return []GenerateResult{res}, nil + } + res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") + + oc := newObjectCache(allPkgs, nil) + g := newGen(root) + injectorStart := time.Now() + injectorFiles, errs := generateInjectors(oc, g, root) + logTiming(ctx, "generate.package."+root.PkgPath+".injectors", injectorStart) + if len(errs) > 0 { + res.Errs = errs + return []GenerateResult{res}, nil + } + copyStart := time.Now() + copyNonInjectorDecls(g, injectorFiles, root.TypesInfo) + logTiming(ctx, "generate.package."+root.PkgPath+".copy_non_injectors", copyStart) + frameStart := time.Now() + goSrc := g.frame(opts.Tags) + logTiming(ctx, "generate.package."+root.PkgPath+".frame", frameStart) + if len(opts.Header) > 0 { + goSrc = append(opts.Header, goSrc...) + } + formatStart := time.Now() + fmtSrc, err := format.Source(goSrc) + logTiming(ctx, "generate.package."+root.PkgPath+".format", formatStart) + if err != nil { + res.Errs = append(res.Errs, err) + } else { + goSrc = fmtSrc + } + res.Content = goSrc + logTiming(ctx, "generate.package."+root.PkgPath+".total", pkgStart) + return []GenerateResult{res}, nil +} + +func writeIncrementalFingerprints(snapshot *incrementalFingerprintSnapshot, wd string, tags string) { + if snapshot == nil { + return + } + for _, fp := range snapshotPackageFingerprints(snapshot) { + fp := fp + writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), &fp) + } +} + +func writeIncrementalManifestFromState(wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { + if snapshot == nil || len(generated) == 0 || state == nil || state.manifest == nil { + return + } + manifest := &incrementalManifest{ + Version: incrementalManifestVersion, + WD: filepath.Clean(wd), + Tags: opts.Tags, + Prefix: opts.PrefixOutputFile, + HeaderHash: headerHash(opts.Header), + EnvHash: envHash(env), + Patterns: sortedStrings(patterns), + LocalPackages: snapshotPackageFingerprints(snapshot), + ExternalPkgs: append([]externalPackageExport(nil), state.manifest.ExternalPkgs...), + ExternalFiles: append([]cacheFile(nil), state.manifest.ExternalFiles...), + ExtraFiles: extraCacheFiles(wd), + } + for _, out := range generated { + if len(out.Content) == 0 || out.OutputPath == "" { + continue + } + contentKey := incrementalContentKey(out.Content) + writeCache(contentKey, out.Content) + manifest.Outputs = append(manifest.Outputs, incrementalOutput{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + ContentKey: contentKey, + }) + } + if len(manifest.Outputs) == 0 { + return + } + selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) + writeIncrementalManifestFile(selectorKey, manifest) + writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, manifest.LocalPackages), manifest) +} + +func writeIncrementalGraphFromSnapshot(wd string, tags string, roots []string, fps map[string]*packageFingerprint) { + if len(roots) == 0 || len(fps) == 0 { + return + } + graph := &incrementalGraph{ + Version: incrementalGraphVersion, + WD: filepath.Clean(wd), + Tags: tags, + Roots: append([]string(nil), roots...), + LocalReverse: make(map[string][]string), + } + sort.Strings(graph.Roots) + for _, fp := range fps { + if fp == nil { + continue + } + for _, imp := range fp.LocalImports { + graph.LocalReverse[imp] = append(graph.LocalReverse[imp], fp.PkgPath) + } + } + for path := range graph.LocalReverse { + sort.Strings(graph.LocalReverse[path]) + } + writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) +} + +func manifestOutputPkgPaths(manifest *incrementalManifest) []string { + if manifest == nil || len(manifest.Outputs) == 0 { + return nil + } + seen := make(map[string]struct{}, len(manifest.Outputs)) + paths := make([]string, 0, len(manifest.Outputs)) + for _, out := range manifest.Outputs { + if out.PkgPath == "" { + continue + } + if _, ok := seen[out.PkgPath]; ok { + continue + } + seen[out.PkgPath] = struct{}{} + paths = append(paths, out.PkgPath) + } + sort.Strings(paths) + return paths +} + +func changedPackagePaths(previous []packageFingerprint, current []packageFingerprint) []string { + if len(current) == 0 { + return nil + } + prevByPath := make(map[string]packageFingerprint, len(previous)) + for _, fp := range previous { + prevByPath[fp.PkgPath] = fp + } + changed := make([]string, 0, len(current)) + for _, fp := range current { + prev, ok := prevByPath[fp.PkgPath] + if !ok || !incrementalFingerprintEquivalent(&prev, &fp) { + changed = append(changed, fp.PkgPath) + } + } + sort.Strings(changed) + return changed +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index fc1b353..2f038a9 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -250,6 +250,9 @@ type Field struct { // In case of duplicate environment variables, the last one in the list // takes precedence. func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { + if IncrementalEnabled(ctx, env) { + debugf(ctx, "incremental=enabled") + } loadStart := time.Now() pkgs, loader, errs := load(ctx, wd, env, tags, patterns) logTiming(ctx, "load.packages", loadStart) @@ -365,7 +368,13 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // In case of duplicate environment variables, the last one in the list // takes precedence. func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, *lazyLoader, []error) { + var session *incrementalSession fset := token.NewFileSet() + if IncrementalEnabled(ctx, env) { + session = getIncrementalSession(wd, env, tags) + fset = session.fset + debugf(ctx, "incremental session=enabled") + } baseCfg := &packages.Config{ Context: ctx, Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps, @@ -384,6 +393,7 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] baseLoadStart := time.Now() pkgs, err := packages.Load(baseCfg, escaped...) logTiming(ctx, "load.packages.base.load", baseLoadStart) + logLoadDebug(ctx, "base", baseCfg.Mode, strings.Join(patterns, ","), wd, pkgs, nil) if err != nil { return nil, nil, []error{err} } @@ -393,15 +403,19 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] if len(errs) > 0 { return nil, nil, errs } + fingerprints := analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) + analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) baseFiles := collectPackageFiles(pkgs) loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: tags, - fset: fset, - baseFiles: baseFiles, + ctx: ctx, + wd: wd, + env: env, + tags: tags, + fset: fset, + baseFiles: baseFiles, + session: session, + fingerprints: fingerprints, } return pkgs, loader, nil } diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go index b3d7011..223c9ad 100644 --- a/internal/wire/parser_lazy_loader.go +++ b/internal/wire/parser_lazy_loader.go @@ -26,12 +26,15 @@ import ( ) type lazyLoader struct { - ctx context.Context - wd string - env []string - tags string - fset *token.FileSet - baseFiles map[string]map[string]struct{} + ctx context.Context + wd string + env []string + tags string + fset *token.FileSet + baseFiles map[string]map[string]struct{} + session *incrementalSession + fingerprints *incrementalFingerprintSnapshot + loaded map[string]*packages.Package } func collectPackageFiles(pkgs []*packages.Package) map[string]map[string]struct{} { @@ -74,10 +77,11 @@ func (ll *lazyLoader) load(pkgPath string) ([]*packages.Package, []error) { } func (ll *lazyLoader) fullMode() packages.LoadMode { - return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax + return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile } func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timingLabel string) ([]*packages.Package, []error) { + parseStats := &parseFileStats{} cfg := &packages.Config{ Context: ll.ctx, Mode: mode, @@ -85,7 +89,7 @@ func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timin Env: ll.env, BuildFlags: []string{"-tags=wireinject"}, Fset: ll.fset, - ParseFile: ll.parseFileFor(pkgPath), + ParseFile: ll.parseFileFor(pkgPath, parseStats), } if len(ll.tags) > 0 { cfg.BuildFlags[0] += " " + ll.tags @@ -93,6 +97,7 @@ func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timin loadStart := time.Now() pkgs, err := packages.Load(cfg, "pattern="+pkgPath) logTiming(ll.ctx, timingLabel, loadStart) + logLoadDebug(ll.ctx, "lazy", mode, pkgPath, ll.wd, pkgs, parseStats) if err != nil { return nil, []error{err} } @@ -100,26 +105,52 @@ func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timin if len(errs) > 0 { return nil, errs } + ll.rememberPackages(pkgs) return pkgs, nil } -func (ll *lazyLoader) parseFileFor(pkgPath string) func(*token.FileSet, string, []byte) (*ast.File, error) { - primary := ll.baseFiles[pkgPath] +func (ll *lazyLoader) rememberPackages(pkgs []*packages.Package) { + if ll == nil || len(pkgs) == 0 { + return + } + if ll.loaded == nil { + ll.loaded = make(map[string]*packages.Package) + } + for path, pkg := range collectAllPackages(pkgs) { + if pkg != nil { + ll.loaded[path] = pkg + } + } +} + +func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(*token.FileSet, string, []byte) (*ast.File, error) { + primary := primaryFileSet(ll.baseFiles[pkgPath]) return func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - mode := parser.SkipObjectResolution - if primary != nil { - if _, ok := primary[filepath.Clean(filename)]; ok { - mode = parser.ParseComments | parser.SkipObjectResolution + start := time.Now() + isPrimary := isPrimaryFile(primary, filename) + if !isPrimary && ll.session != nil { + if file, ok := ll.session.getParsedDep(filename, src); ok { + if stats != nil { + stats.record(false, time.Since(start), nil, true) + } + return file, nil } } + mode := parser.SkipObjectResolution + if isPrimary { + mode = parser.ParseComments | parser.SkipObjectResolution + } file, err := parser.ParseFile(fset, filename, src, mode) + if stats != nil { + stats.record(isPrimary, time.Since(start), err, false) + } if err != nil { return nil, err } if primary == nil { return file, nil } - if _, ok := primary[filepath.Clean(filename)]; ok { + if isPrimary { return file, nil } for _, decl := range file.Decls { @@ -128,6 +159,9 @@ func (ll *lazyLoader) parseFileFor(pkgPath string) func(*token.FileSet, string, fn.Doc = nil } } + if ll.session != nil { + ll.session.storeParsedDep(filename, src, file) + } return file, nil } } diff --git a/internal/wire/parser_lazy_loader_test.go b/internal/wire/parser_lazy_loader_test.go index 31838ea..86b49da 100644 --- a/internal/wire/parser_lazy_loader_test.go +++ b/internal/wire/parser_lazy_loader_test.go @@ -47,7 +47,7 @@ func TestLazyLoaderParseFileFor(t *testing.T) { "", }, "\n") - parse := ll.parseFileFor(pkgPath) + parse := ll.parseFileFor(pkgPath, &parseFileStats{}) file, err := parse(fset, primary, []byte(src)) if err != nil { t.Fatalf("parse primary: %v", err) @@ -73,6 +73,59 @@ func TestLazyLoaderParseFileFor(t *testing.T) { } } +func TestLazyLoaderParseFileForCachesDependencyFiles(t *testing.T) { + t.Helper() + fset := token.NewFileSet() + pkgPath := "example.com/pkg" + root := t.TempDir() + primary := filepath.Join(root, "primary.go") + secondary := filepath.Join(root, "secondary.go") + session := &incrementalSession{ + fset: fset, + parsedDeps: make(map[string]cachedParsedFile), + } + ll := &lazyLoader{ + fset: fset, + baseFiles: map[string]map[string]struct{}{ + pkgPath: {filepath.Clean(primary): {}}, + }, + session: session, + } + src := []byte(strings.Join([]string{ + "package pkg", + "", + "func Foo() {", + "\tprintln(\"hi\")", + "}", + "", + }, "\n")) + + stats1 := &parseFileStats{} + parse1 := ll.parseFileFor(pkgPath, stats1) + file1, err := parse1(fset, secondary, src) + if err != nil { + t.Fatalf("first parse: %v", err) + } + snap1 := stats1.snapshot() + if snap1.cacheHits != 0 || snap1.cacheMisses != 1 { + t.Fatalf("first parse stats = %+v, want 0 hits and 1 miss", snap1) + } + + stats2 := &parseFileStats{} + parse2 := ll.parseFileFor(pkgPath, stats2) + file2, err := parse2(fset, secondary, src) + if err != nil { + t.Fatalf("second parse: %v", err) + } + if file1 != file2 { + t.Fatal("expected cached dependency parse to reuse AST") + } + snap2 := stats2.snapshot() + if snap2.cacheHits != 1 || snap2.cacheMisses != 0 { + t.Fatalf("second parse stats = %+v, want 1 hit and 0 misses", snap2) + } +} + func TestLoadModuleUsesWireinjectTagsForDeps(t *testing.T) { repoRoot := mustRepoRoot(t) root := t.TempDir() diff --git a/internal/wire/time_compat.go b/internal/wire/time_compat.go new file mode 100644 index 0000000..6f0c9c4 --- /dev/null +++ b/internal/wire/time_compat.go @@ -0,0 +1,22 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import "time" + +var ( + timeNow = time.Now + timeSince = time.Since +) diff --git a/internal/wire/timing.go b/internal/wire/timing.go index 376d573..d83754b 100644 --- a/internal/wire/timing.go +++ b/internal/wire/timing.go @@ -16,6 +16,7 @@ package wire import ( "context" + "log" "time" ) @@ -49,3 +50,10 @@ func logTiming(ctx context.Context, label string, start time.Time) { t(label, time.Since(start)) } } + +func debugf(ctx context.Context, format string, args ...interface{}) { + if timing(ctx) == nil { + return + } + log.Printf("timing: "+format, args...) +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index aa3efe3..64202dc 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -25,6 +25,7 @@ import ( "go/token" "go/types" "io/ioutil" + "os" "path/filepath" "sort" "strconv" @@ -53,10 +54,27 @@ type GenerateResult struct { // Commit writes the generated file to disk. func (gen GenerateResult) Commit() error { + _, err := gen.CommitWithStatus() + return err +} + +// CommitWithStatus writes the generated file to disk when the content changed. +// It returns whether the file was written. +func (gen GenerateResult) CommitWithStatus() (bool, error) { if len(gen.Content) == 0 { - return nil + return false, nil + } + current, err := os.ReadFile(gen.OutputPath) + if err == nil && bytes.Equal(current, gen.Content) { + return false, nil } - return ioutil.WriteFile(gen.OutputPath, gen.Content, 0666) + if err != nil && !os.IsNotExist(err) { + return false, err + } + if err := ioutil.WriteFile(gen.OutputPath, gen.Content, 0666); err != nil { + return false, err + } + return true, nil } // GenerateOptions holds options for Generate. @@ -83,6 +101,20 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } + var preloadState *incrementalPreloadState + bypassIncrementalManifest := false + if IncrementalEnabled(ctx, env) { + debugf(ctx, "incremental=enabled") + preloadState, _ = prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) + if cached, ok := readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, preloadState, preloadState != nil); ok { + return cached, nil + } + if generated, ok, bypass, errs := tryIncrementalLocalFastPath(ctx, wd, env, patterns, opts, preloadState); ok || len(errs) > 0 { + return generated, errs + } else if bypass { + bypassIncrementalManifest = true + } + } if cached, ok := readManifestResults(wd, env, patterns, opts); ok { return cached, nil } @@ -92,16 +124,69 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if len(errs) > 0 { return nil, errs } + if !bypassIncrementalManifest { + if cached, ok := readIncrementalManifestResults(ctx, wd, env, patterns, opts, pkgs, loader.fingerprints); ok { + warmPackageOutputCache(pkgs, opts, cached) + return cached, nil + } + } else { + debugf(ctx, "incremental.manifest bypass reason=fastpath_error") + ctx = withBypassPackageCache(ctx) + } generated := make([]GenerateResult, len(pkgs)) for i, pkg := range pkgs { generated[i] = generateForPackage(ctx, pkg, loader, opts) } if allGeneratedOK(generated) { + if IncrementalEnabled(ctx, env) { + writeIncrementalPackageSummaries(loader, pkgs) + } writeManifest(wd, env, patterns, opts, pkgs) + writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) } return generated, nil } +func warmPackageOutputCache(pkgs []*packages.Package, opts *GenerateOptions, generated []GenerateResult) { + if len(pkgs) == 0 || len(generated) == 0 { + return + } + byPkg := make(map[string][]byte, len(generated)) + for _, gen := range generated { + if len(gen.Content) == 0 { + continue + } + byPkg[gen.PkgPath] = gen.Content + } + for _, pkg := range pkgs { + content := byPkg[pkg.PkgPath] + if len(content) == 0 { + continue + } + key, err := cacheKeyForPackage(pkg, opts) + if err != nil || key == "" { + continue + } + writeCache(key, content) + } +} + +func incrementalManifestPackages(pkgs []*packages.Package, loader *lazyLoader) []*packages.Package { + if loader == nil || len(loader.loaded) == 0 { + return pkgs + } + out := make([]*packages.Package, 0, len(loader.loaded)) + for _, pkg := range loader.loaded { + if pkg != nil { + out = append(out, pkg) + } + } + if len(out) == 0 { + return pkgs + } + return out +} + // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 14080df..cb167aa 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -111,6 +111,7 @@ func TestWire(t *testing.T) { t.Log(e.Error()) gotErrStrings[i] = scrubError(gopath, e.Error()) } + gotErrStrings = filterLegacyCompilerErrors(gotErrStrings) if !test.wantWireError { t.Fatal("Did not expect errors. To -record an error, create want/wire_errs.txt.") } @@ -191,6 +192,33 @@ func TestGenerateResultCommit(t *testing.T) { } } +func TestGenerateResultCommitWithStatus(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "wire_gen.go") + gen := GenerateResult{ + OutputPath: path, + Content: []byte("package p\n"), + } + + wrote, err := gen.CommitWithStatus() + if err != nil { + t.Fatalf("first CommitWithStatus failed: %v", err) + } + if !wrote { + t.Fatal("expected first CommitWithStatus call to write") + } + + wrote, err = gen.CommitWithStatus() + if err != nil { + t.Fatalf("second CommitWithStatus failed: %v", err) + } + if wrote { + t.Fatal("expected second CommitWithStatus call to report unchanged") + } +} + func TestZeroValue(t *testing.T) { t.Parallel() @@ -521,6 +549,28 @@ func scrubLineColumn(s string) (replacement string, n int) { return ":x:y", n } +func filterLegacyCompilerErrors(errs []string) []string { + hasCanonicalPath := false + for _, err := range errs { + if strings.HasPrefix(err, "example.com/") { + hasCanonicalPath = true + break + } + } + if !hasCanonicalPath { + return errs + } + + filtered := errs[:0] + for _, err := range errs { + if strings.HasPrefix(err, "-: # ") { + continue + } + filtered = append(filtered, err) + } + return filtered +} + type testCase struct { name string pkg string From e3f07cb43397fbd5e25cb45c334300fc764126ef Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Fri, 13 Mar 2026 23:35:20 -0500 Subject: [PATCH 02/79] feat(incremental): reuse unchanged local packages in fast path and harden fallback behavior --- internal/wire/incremental_bench_test.go | 117 +++++-- internal/wire/incremental_manifest.go | 4 - internal/wire/incremental_summary.go | 11 +- internal/wire/loader_test.go | 142 ++++++++ internal/wire/local_fastpath.go | 371 +++++++++++++++++++-- internal/wire/parse.go | 20 +- internal/wire/summary_provider_resolver.go | 223 +++++++++++++ 7 files changed, 826 insertions(+), 62 deletions(-) create mode 100644 internal/wire/summary_provider_resolver.go diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go index b981d23..911a8c7 100644 --- a/internal/wire/incremental_bench_test.go +++ b/internal/wire/incremental_bench_test.go @@ -5,15 +5,15 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" - "text/tabwriter" "testing" "time" ) const ( largeBenchmarkTestPackageCount = 24 - largeBenchmarkHelperCount = 12 + largeBenchmarkHelperCount = 12 ) var largeBenchmarkSizes = []int{10, 100, 1000} @@ -106,36 +106,42 @@ func TestPrintLargeRepoBenchmarkComparisonTable(t *testing.T) { incremental := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, true) knownToggle := measureLargeRepoKnownToggleOnce(t, repoRoot, packageCount) rows = append(rows, largeRepoBenchmarkRow{ - packageCount: packageCount, - coldNormal: coldNormal, - coldIncremental: coldIncremental, - normal: normal, - incremental: incremental, - knownToggle: knownToggle, + packageCount: packageCount, + coldNormal: coldNormal, + coldIncremental: coldIncremental, + normal: normal, + incremental: incremental, + knownToggle: knownToggle, }) } - var out strings.Builder - tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) - fmt.Fprintln(tw, "size\tcold normal\tcold incr\tcold delta\tcold x\tshape normal\tshape incr\tshape delta\tshape x\tknown toggle") + table := [][]string{{ + "repo size", + "cold old", + "cold new", + "cold delta", + "shape old", + "shape new", + "shape delta", + "known toggle", + "cold speedup", + "shape speedup", + }} for _, row := range rows { - fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%.2fx\t%s\t%s\t%s\t%.2fx\t%s\n", - row.packageCount, + table = append(table, []string{ + strconv.Itoa(row.packageCount), formatBenchmarkDuration(row.coldNormal), formatBenchmarkDuration(row.coldIncremental), formatPercentImprovement(row.coldNormal, row.coldIncremental), - speedupRatio(row.coldNormal, row.coldIncremental), formatBenchmarkDuration(row.normal), formatBenchmarkDuration(row.incremental), formatPercentImprovement(row.normal, row.incremental), - speedupRatio(row.normal, row.incremental), formatBenchmarkDuration(row.knownToggle), - ) - } - if err := tw.Flush(); err != nil { - t.Fatalf("flush benchmark table: %v", err) + fmt.Sprintf("%.2fx", speedupRatio(row.coldNormal, row.coldIncremental)), + fmt.Sprintf("%.2fx", speedupRatio(row.normal, row.incremental)), + }) } - fmt.Print(out.String()) + fmt.Print(renderASCIITable(table)) } func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { @@ -151,27 +157,35 @@ func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { }) repoRoot := benchmarkRepoRoot(t) - var out strings.Builder - tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) - fmt.Fprintln(tw, "size\tnormal total\tbase load\tlazy load\tincr total\tfast load\tfast generate\tspeedup") + rows := [][]string{{ + "repo size", + "old total", + "old base load", + "old typed load", + "new total", + "new local load", + "new cached sets", + "new injector solve", + "new generate", + "speedup", + }} for _, packageCount := range largeBenchmarkSizes { normal := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, false) incremental := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, true) - fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%s\t%s\t%s\t%.2fx\n", - packageCount, + rows = append(rows, []string{ + strconv.Itoa(packageCount), formatBenchmarkDuration(normal.total), formatBenchmarkDuration(normal.label("load.packages.base.load")), formatBenchmarkDuration(normal.label("load.packages.lazy.load")), formatBenchmarkDuration(incremental.total), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.summary_resolve")), + formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.injectors")), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), - speedupRatio(normal.total, incremental.total), - ) - } - if err := tw.Flush(); err != nil { - t.Fatalf("flush breakdown table: %v", err) + fmt.Sprintf("%.2fx", speedupRatio(normal.total, incremental.total)), + }) } - fmt.Print(out.String()) + fmt.Print(renderASCIITable(rows)) } func writeIncrementalBenchmarkModule(tb testing.TB, repoRoot string, root string) { @@ -652,3 +666,44 @@ func writeBenchmarkFile(tb testing.TB, path string, content string) { tb.Fatalf("WriteFile failed: %v", err) } } + +func renderASCIITable(rows [][]string) string { + if len(rows) == 0 { + return "" + } + widths := make([]int, len(rows[0])) + for _, row := range rows { + for i, cell := range row { + if len(cell) > widths[i] { + widths[i] = len(cell) + } + } + } + var b strings.Builder + border := func() { + b.WriteByte('+') + for _, width := range widths { + b.WriteString(strings.Repeat("-", width+2)) + b.WriteByte('+') + } + b.WriteByte('\n') + } + writeRow := func(row []string) { + b.WriteByte('|') + for i, cell := range row { + b.WriteByte(' ') + b.WriteString(cell) + b.WriteString(strings.Repeat(" ", widths[i]-len(cell)+1)) + b.WriteByte('|') + } + b.WriteByte('\n') + } + border() + writeRow(rows[0]) + border() + for _, row := range rows[1:] { + writeRow(row) + } + border() + return b.String() +} diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index ae36c77..11d250f 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -250,12 +250,8 @@ func buildExternalPackageFiles(wd string, pkgs []*packages.Package) ([]cacheFile } func buildExternalPackageExports(wd string, pkgs []*packages.Package) []externalPackageExport { - moduleRoot := findModuleRoot(wd) out := make([]externalPackageExport, 0) for _, pkg := range collectAllPackages(pkgs) { - if classifyPackageLocation(moduleRoot, pkg) == "local" { - continue - } if pkg == nil || pkg.PkgPath == "" || pkg.ExportFile == "" { continue } diff --git a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go index faaa9b8..2930b37 100644 --- a/internal/wire/incremental_summary.go +++ b/internal/wire/incremental_summary.go @@ -148,6 +148,10 @@ func writeIncrementalPackageSummary(key string, summary *packageSummary) { } func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) { + writeIncrementalPackageSummariesWithSummary(loader, pkgs, nil, nil) +} + +func writeIncrementalPackageSummariesWithSummary(loader *lazyLoader, pkgs []*packages.Package, summary *summaryProviderResolver, only map[string]struct{}) { if loader == nil || len(pkgs) == 0 { return } @@ -162,11 +166,16 @@ func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Packa for _, pkg := range all { allPkgs = append(allPkgs, pkg) } - oc := newObjectCache(allPkgs, loader) + oc := newObjectCacheWithLoader(allPkgs, loader, nil, summary) for _, pkg := range all { if classifyPackageLocation(moduleRoot, pkg) != "local" { continue } + if len(only) > 0 { + if _, ok := only[pkg.PkgPath]; !ok { + continue + } + } if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { continue } diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 2f41c8d..6a26d8e 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -794,6 +794,148 @@ func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { } } +func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"example.com/app/router\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *router.Routes {", + "\twire.Build(dep.Set, router.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Controller struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewController(msg string) *Controller {", + "\treturn &Controller{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewController)", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ + "package router", + "", + "import \"example.com/app/dep\"", + "", + "type Routes struct { Controller *dep.Controller }", + "", + "func ProvideRoutes(controller *dep.Controller) *Routes {", + "\treturn &Routes{Controller: controller}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ + "package router", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(ProvideRoutes)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Controller struct { Message string; Count int }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func NewController(msg string, count int) *Controller {", + "\treturn &Controller{Message: msg, Count: count}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewCount, NewController)", + "", + }, "\n")) + + var incrementalLabels []string + incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + incrementalLabels = append(incrementalLabels, label) + }) + incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental Generate returned errors: %v", errs) + } + if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) + } + if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected incremental Generate to use local fast path, labels=%v", incrementalLabels) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate results: %+v", normalGens) + } + if string(incrementalGens[0].Content) != string(normalGens[0].Content) { + t.Fatal("incremental output differs from normal Generate output when unchanged package depends on changed package") + } +} + func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go index 466dcc2..4ef1f8f 100644 --- a/internal/wire/local_fastpath.go +++ b/internal/wire/local_fastpath.go @@ -59,7 +59,7 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p } fastPathStart := time.Now() - loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], state.currentLocal, state.manifest.ExternalPkgs) + loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], changed, state.currentLocal, state.manifest.ExternalPkgs) if err != nil { debugf(ctx, "incremental.local_fastpath miss reason=%v", err) if shouldBypassIncrementalManifestAfterFastPathError(err) { @@ -69,7 +69,7 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p } logTiming(ctx, "incremental.local_fastpath.load", fastPathStart) - generated, errs := generateFromTypedPackages(ctx, loaded.root, loaded.allPackages, opts) + generated, errs := generateFromTypedPackages(ctx, loaded, opts) logTiming(ctx, "incremental.local_fastpath.generate", fastPathStart) if len(errs) > 0 { return nil, true, true, errs @@ -91,8 +91,12 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p for path, pkg := range loaded.byPath { loader.loaded[path] = pkg } + changedSet := make(map[string]struct{}, len(snapshot.changed)) + for _, path := range snapshot.changed { + changedSet[path] = struct{}{} + } writeIncrementalFingerprints(snapshot, wd, opts.Tags) - writeIncrementalPackageSummaries(loader, loaded.allPackages) + writeIncrementalPackageSummariesWithSummary(loader, loaded.allPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) @@ -105,6 +109,9 @@ func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { return false } msg := err.Error() + if strings.Contains(msg, "missing external export data for ") { + return false + } return strings.Contains(msg, "type-check failed for ") } @@ -205,6 +212,7 @@ type localFastPathLoaded struct { allPackages []*packages.Package byPath map[string]*packages.Package fingerprints map[string]*packageFingerprint + loader *localFastPathLoader } type localFastPathLoader struct { @@ -212,14 +220,19 @@ type localFastPathLoader struct { wd string tags string fset *token.FileSet + modulePrefix string rootPkgPath string + changedPkgs map[string]struct{} + sourcePkgs map[string]struct{} + summaries map[string]*packageSummary meta map[string]*packageFingerprint pkgs map[string]*packages.Package + imported map[string]*types.Package externalMeta map[string]externalPackageExport externalImp types.Importer } -func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { +func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { meta := fingerprintsFromSlice(current) if len(meta) == 0 { return nil, fmt.Errorf("no local fingerprints") @@ -239,11 +252,38 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r wd: wd, tags: tags, fset: token.NewFileSet(), + modulePrefix: moduleImportPrefix(meta), rootPkgPath: rootPkgPath, + changedPkgs: make(map[string]struct{}, len(changed)), + sourcePkgs: make(map[string]struct{}), + summaries: make(map[string]*packageSummary), meta: meta, pkgs: make(map[string]*packages.Package, len(meta)), + imported: make(map[string]*types.Package, len(meta)+len(externalMeta)), externalMeta: externalMeta, } + for _, path := range changed { + loader.changedPkgs[path] = struct{}{} + } + loader.markSourceClosure() + candidates := make(map[string]*packageSummary) + for path, fp := range meta { + if path == rootPkgPath { + continue + } + if _, changed := loader.changedPkgs[path]; changed { + continue + } + if _, ok := externalMeta[path]; !ok { + continue + } + summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(wd, tags, path)) + if !ok || summary == nil || summary.ShapeHash != fp.ShapeHash { + continue + } + candidates[path] = summary + } + loader.summaries = filterSupportedPackageSummaries(candidates) loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) root, err := loader.load(rootPkgPath) if err != nil { @@ -260,6 +300,7 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r allPackages: all, byPath: loader.pkgs, fingerprints: loader.meta, + loader: loader, }, nil } @@ -275,10 +316,13 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { if len(files) == 0 { return nil, fmt.Errorf("package %s has no files", pkgPath) } - mode := parser.ParseComments | parser.SkipObjectResolution + mode := parser.SkipObjectResolution + if pkgPath == l.rootPkgPath { + mode |= parser.ParseComments + } syntax := make([]*ast.File, 0, len(files)) for _, name := range files { - file, err := parser.ParseFile(l.fset, name, nil, mode) + file, err := l.parseFileForFastPath(name, mode, pkgPath) if err != nil { return nil, err } @@ -289,15 +333,7 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { } pkgName := syntax[0].Name.Name - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - Scopes: make(map[ast.Node]*types.Scope), - Instances: make(map[*ast.Ident]types.Instance), - } + info := newFastPathTypesInfo(pkgPath == l.rootPkgPath) pkg := &packages.Package{ Fset: l.fset, Name: pkgName, @@ -314,7 +350,8 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { Importer: importerFunc(func(path string) (*types.Package, error) { return l.importPackage(path) }), - Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: l.shouldIgnoreFuncBodies(pkgPath), + Sizes: types.SizesFor("gc", runtime.GOARCH), Error: func(err error) { pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) }, @@ -322,6 +359,10 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) if checkedPkg != nil { pkg.Types = checkedPkg + l.imported[pkgPath] = checkedPkg + } + if l.shouldRetryWithoutBodyStripping(pkgPath, pkg.Errors) { + return l.reloadWithoutBodyStripping(pkgPath, files, mode, pkg) } if err != nil && len(pkg.Errors) == 0 { return nil, err @@ -347,7 +388,71 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { return pkg, nil } +func (l *localFastPathLoader) parseFileForFastPath(name string, mode parser.Mode, pkgPath string) (*ast.File, error) { + file, err := parser.ParseFile(l.fset, name, nil, mode) + if err != nil { + return nil, err + } + if l.shouldStripFunctionBodies(pkgPath) { + stripFunctionBodies(file) + pruneImportsWithoutTopLevelUse(file) + } + return file, nil +} + +func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files []string, mode parser.Mode, pkg *packages.Package) (*packages.Package, error) { + syntax := make([]*ast.File, 0, len(files)) + for _, name := range files { + file, err := parser.ParseFile(l.fset, name, nil, mode) + if err != nil { + return nil, err + } + syntax = append(syntax, file) + } + pkg.Syntax = syntax + pkg.Errors = nil + pkg.TypesInfo = newFastPathTypesInfo(pkgPath == l.rootPkgPath) + conf := &types.Config{ + Importer: importerFunc(func(path string) (*types.Package, error) { + return l.importPackage(path) + }), + IgnoreFuncBodies: false, + Sizes: types.SizesFor("gc", runtime.GOARCH), + Error: func(err error) { + pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) + }, + } + checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, pkg.TypesInfo) + if checkedPkg != nil { + pkg.Types = checkedPkg + l.imported[pkgPath] = checkedPkg + } + if err != nil && len(pkg.Errors) == 0 { + return nil, err + } + if len(pkg.Errors) > 0 { + return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) + } + return pkg, nil +} + +func (l *localFastPathLoader) shouldRetryWithoutBodyStripping(pkgPath string, errs []packages.Error) bool { + if !l.shouldStripFunctionBodies(pkgPath) || len(errs) == 0 { + return false + } + for _, pkgErr := range errs { + msg := pkgErr.Msg + if strings.Contains(msg, "missing function body") || strings.Contains(msg, "func init must have a body") { + return true + } + } + return false +} + func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) { + if l.shouldImportFromExport(path) { + return l.importExportPackage(path) + } if l.meta[path] != nil { pkg, err := l.load(path) if err != nil { @@ -358,17 +463,132 @@ func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) if l.externalImp == nil { return nil, fmt.Errorf("missing external importer") } - return l.externalImp.Import(path) + return l.importExportPackage(path) } func (l *localFastPathLoader) openExternalExport(path string) (io.ReadCloser, error) { meta, ok := l.externalMeta[path] if !ok || meta.ExportFile == "" { - return nil, fmt.Errorf("missing export data for %s", path) + if l.meta[path] != nil || l.isLikelyLocalImport(path) { + return nil, fmt.Errorf("missing local export data for %s", path) + } + return nil, fmt.Errorf("missing external export data for %s", path) } return os.Open(meta.ExportFile) } +func (l *localFastPathLoader) isLikelyLocalImport(path string) bool { + if l == nil || l.modulePrefix == "" { + return false + } + return path == l.modulePrefix || strings.HasPrefix(path, l.modulePrefix+"/") +} + +func moduleImportPrefix(meta map[string]*packageFingerprint) string { + if len(meta) == 0 { + return "" + } + paths := make([]string, 0, len(meta)) + for path := range meta { + paths = append(paths, path) + } + sort.Strings(paths) + prefix := strings.Split(paths[0], "/") + for _, path := range paths[1:] { + parts := strings.Split(path, "/") + n := len(prefix) + if len(parts) < n { + n = len(parts) + } + i := 0 + for i < n && prefix[i] == parts[i] { + i++ + } + prefix = prefix[:i] + if len(prefix) == 0 { + return "" + } + } + return strings.Join(prefix, "/") +} + +func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, error) { + if l == nil { + return nil, fmt.Errorf("missing local fast path loader") + } + if pkg := l.imported[path]; pkg != nil { + return pkg, nil + } + if l.externalImp == nil { + return nil, fmt.Errorf("missing external importer") + } + pkg, err := l.externalImp.Import(path) + if err != nil { + return nil, err + } + l.imported[path] = pkg + return pkg, nil +} + +func (l *localFastPathLoader) shouldImportFromExport(pkgPath string) bool { + if l == nil { + return false + } + if _, source := l.sourcePkgs[pkgPath]; source { + return false + } + _, ok := l.summaries[pkgPath] + return ok +} + +func (l *localFastPathLoader) markSourceClosure() { + if l == nil { + return + } + reverse := make(map[string][]string) + for pkgPath, fp := range l.meta { + if fp == nil { + continue + } + for _, imp := range fp.LocalImports { + reverse[imp] = append(reverse[imp], pkgPath) + } + } + queue := make([]string, 0, len(l.changedPkgs)+1) + queue = append(queue, l.rootPkgPath) + for pkgPath := range l.changedPkgs { + queue = append(queue, pkgPath) + } + for len(queue) > 0 { + pkgPath := queue[0] + queue = queue[1:] + if _, seen := l.sourcePkgs[pkgPath]; seen { + continue + } + l.sourcePkgs[pkgPath] = struct{}{} + for _, importer := range reverse[pkgPath] { + if _, seen := l.sourcePkgs[importer]; !seen { + queue = append(queue, importer) + } + } + } +} + +func (l *localFastPathLoader) shouldStripFunctionBodies(pkgPath string) bool { + if l == nil { + return false + } + if pkgPath == l.rootPkgPath { + return false + } + _, changed := l.changedPkgs[pkgPath] + return !changed +} + +func (l *localFastPathLoader) shouldIgnoreFuncBodies(pkgPath string) bool { + return l.shouldStripFunctionBodies(pkgPath) +} + type importerFunc func(string) (*types.Package, error) func (fn importerFunc) Import(path string) (*types.Package, error) { @@ -395,7 +615,105 @@ func packageImportPaths(files []*ast.File) []string { return out } -func generateFromTypedPackages(ctx context.Context, root *packages.Package, allPkgs []*packages.Package, opts *GenerateOptions) ([]GenerateResult, []error) { +func newFastPathTypesInfo(full bool) *types.Info { + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + if !full { + return info + } + info.Implicits = make(map[ast.Node]types.Object) + info.Selections = make(map[*ast.SelectorExpr]*types.Selection) + info.Scopes = make(map[ast.Node]*types.Scope) + info.Instances = make(map[*ast.Ident]types.Instance) + return info +} + +func pruneImportsWithoutTopLevelUse(file *ast.File) { + if file == nil || len(file.Imports) == 0 { + return + } + used := usedImportNames(file) + filtered := file.Imports[:0] + for _, spec := range file.Imports { + if spec == nil || spec.Path == nil { + continue + } + name := importName(spec) + if name == "_" || name == "." { + filtered = append(filtered, spec) + continue + } + if _, ok := used[name]; ok { + filtered = append(filtered, spec) + } + } + file.Imports = filtered + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { + continue + } + specs := gen.Specs[:0] + for _, spec := range gen.Specs { + importSpec, ok := spec.(*ast.ImportSpec) + if !ok || importSpec.Path == nil { + continue + } + name := importName(importSpec) + if name == "_" || name == "." { + specs = append(specs, spec) + continue + } + if _, ok := used[name]; ok { + specs = append(specs, spec) + } + } + gen.Specs = specs + } +} + +func usedImportNames(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + ast.Inspect(file, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + return used +} + +func importName(spec *ast.ImportSpec) string { + if spec == nil || spec.Path == nil { + return "" + } + if spec.Name != nil && spec.Name.Name != "" { + return spec.Name.Name + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + return "" + } + if slash := strings.LastIndex(path, "/"); slash >= 0 { + path = path[slash+1:] + } + return path +} + +func generateFromTypedPackages(ctx context.Context, loaded *localFastPathLoaded, opts *GenerateOptions) ([]GenerateResult, []error) { + if loaded == nil { + return nil, []error{fmt.Errorf("missing loaded packages")} + } + root := loaded.root if root == nil { return nil, []error{fmt.Errorf("missing root package")} } @@ -412,7 +730,11 @@ func generateFromTypedPackages(ctx context.Context, root *packages.Package, allP } res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - oc := newObjectCache(allPkgs, nil) + var summary *summaryProviderResolver + if loaded.loader != nil { + summary = newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage) + } + oc := newObjectCacheWithLoader(loaded.allPackages, nil, nil, summary) g := newGen(root) injectorStart := time.Now() injectorFiles, errs := generateInjectors(oc, g, root) @@ -447,9 +769,12 @@ func writeIncrementalFingerprints(snapshot *incrementalFingerprintSnapshot, wd s if snapshot == nil { return } - for _, fp := range snapshotPackageFingerprints(snapshot) { - fp := fp - writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), &fp) + for _, path := range snapshot.changed { + fp := snapshot.fingerprints[path] + if fp == nil { + continue + } + writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), fp) } } diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2f038a9..73f218d 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -471,6 +471,7 @@ type objectCache struct { objects map[objRef]objCacheEntry hasher typeutil.Hasher loader *lazyLoader + summary *summaryProviderResolver } type objRef struct { @@ -484,6 +485,10 @@ type objCacheEntry struct { } func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { + return newObjectCacheWithLoader(pkgs, loader, nil, nil) +} + +func newObjectCacheWithLoader(pkgs []*packages.Package, loader *lazyLoader, _ *localFastPathLoader, summary *summaryProviderResolver) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } @@ -493,6 +498,7 @@ func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { objects: make(map[objRef]objCacheEntry), hasher: typeutil.MakeHasher(), loader: loader, + summary: summary, } if oc.fset == nil && loader != nil { oc.fset = loader.fset @@ -557,9 +563,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { if ent, cached := oc.objects[ref]; cached { return ent.val, append([]error(nil), ent.errs...) } - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } defer func() { oc.objects[ref] = objCacheEntry{ val: val, @@ -568,6 +571,14 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { }() switch obj := obj.(type) { case *types.Var: + if isProviderSetType(obj.Type()) && oc.summary != nil { + if pset, ok, summaryErrs := oc.summary.Resolve(obj.Pkg().Path(), obj.Name()); ok { + return pset, summaryErrs + } + } + if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { + return nil, errs + } spec := oc.varDecl(obj) if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} @@ -583,6 +594,9 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { case *types.Func: return processFuncProvider(oc.fset, obj) default: + if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { + return nil, errs + } return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } } diff --git a/internal/wire/summary_provider_resolver.go b/internal/wire/summary_provider_resolver.go new file mode 100644 index 0000000..c93e0c5 --- /dev/null +++ b/internal/wire/summary_provider_resolver.go @@ -0,0 +1,223 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "fmt" + "go/token" + "go/types" + "time" + + "golang.org/x/tools/go/types/typeutil" +) + +type summaryProviderResolver struct { + ctx context.Context + fset *token.FileSet + summaries map[string]*packageSummary + importPackage func(string) (*types.Package, error) + cache map[providerSetRefSummary]*ProviderSet + resolving map[providerSetRefSummary]struct{} + supported map[string]bool +} + +func newSummaryProviderResolver(ctx context.Context, summaries map[string]*packageSummary, importPackage func(string) (*types.Package, error)) *summaryProviderResolver { + if len(summaries) == 0 || importPackage == nil { + return nil + } + r := &summaryProviderResolver{ + ctx: ctx, + fset: token.NewFileSet(), + summaries: make(map[string]*packageSummary, len(summaries)), + importPackage: importPackage, + cache: make(map[providerSetRefSummary]*ProviderSet), + resolving: make(map[providerSetRefSummary]struct{}), + supported: make(map[string]bool, len(summaries)), + } + for pkgPath, summary := range summaries { + if summary == nil { + continue + } + r.summaries[pkgPath] = summary + } + for pkgPath := range r.summaries { + r.supported[pkgPath] = r.packageSupported(pkgPath, make(map[string]struct{})) + } + return r +} + +func filterSupportedPackageSummaries(summaries map[string]*packageSummary) map[string]*packageSummary { + if len(summaries) == 0 { + return nil + } + resolver := &summaryProviderResolver{ + summaries: summaries, + supported: make(map[string]bool, len(summaries)), + } + out := make(map[string]*packageSummary) + for pkgPath, summary := range summaries { + if summary == nil { + continue + } + if resolver.packageSupported(pkgPath, make(map[string]struct{})) { + out[pkgPath] = summary + } + } + return out +} + +func (r *summaryProviderResolver) Resolve(pkgPath string, varName string) (*ProviderSet, bool, []error) { + if r == nil || !r.supported[pkgPath] { + return nil, false, nil + } + start := time.Now() + set, err := r.resolve(providerSetRefSummary{PkgPath: pkgPath, VarName: varName}) + logTiming(r.ctx, "incremental.local_fastpath.summary_resolve", start) + if err != nil { + return nil, true, []error{err} + } + return set, true, nil +} + +func (r *summaryProviderResolver) resolve(ref providerSetRefSummary) (*ProviderSet, error) { + if set := r.cache[ref]; set != nil { + return set, nil + } + if _, ok := r.resolving[ref]; ok { + return nil, fmt.Errorf("summary provider set cycle for %s.%s", ref.PkgPath, ref.VarName) + } + summary := r.summaries[ref.PkgPath] + if summary == nil { + return nil, fmt.Errorf("missing package summary for %s", ref.PkgPath) + } + setSummary, ok := r.findProviderSet(summary, ref.VarName) + if !ok { + return nil, fmt.Errorf("missing provider set summary for %s.%s", ref.PkgPath, ref.VarName) + } + r.resolving[ref] = struct{}{} + defer delete(r.resolving, ref) + + pkg, err := r.importPackage(ref.PkgPath) + if err != nil { + return nil, err + } + set := &ProviderSet{ + PkgPath: ref.PkgPath, + VarName: ref.VarName, + } + for _, provider := range setSummary.Providers { + resolved, err := r.resolveProvider(pkg, provider) + if err != nil { + return nil, err + } + set.Providers = append(set.Providers, resolved) + } + for _, imported := range setSummary.Imports { + child, err := r.resolve(imported) + if err != nil { + return nil, err + } + set.Imports = append(set.Imports, child) + } + hasher := typeutil.MakeHasher() + providerMap, srcMap, errs := buildProviderMap(r.fset, hasher, set) + if len(errs) > 0 { + return nil, errs[0] + } + if errs := verifyAcyclic(providerMap, hasher); len(errs) > 0 { + return nil, errs[0] + } + set.providerMap = providerMap + set.srcMap = srcMap + r.cache[ref] = set + return set, nil +} + +func (r *summaryProviderResolver) resolveProvider(pkg *types.Package, summary providerSummary) (*Provider, error) { + if summary.IsStruct || len(summary.Out) == 0 { + return nil, fmt.Errorf("unsupported summary provider %s.%s", summary.PkgPath, summary.Name) + } + if pkg == nil || pkg.Path() != summary.PkgPath { + var err error + pkg, err = r.importPackage(summary.PkgPath) + if err != nil { + return nil, err + } + } + obj := pkg.Scope().Lookup(summary.Name) + fn, ok := obj.(*types.Func) + if !ok { + return nil, fmt.Errorf("summary provider %s.%s missing function", summary.PkgPath, summary.Name) + } + provider, errs := processFuncProvider(r.fset, fn) + if len(errs) > 0 { + return nil, errs[0] + } + return provider, nil +} + +func (r *summaryProviderResolver) findProviderSet(summary *packageSummary, varName string) (providerSetSummary, bool) { + if summary == nil { + return providerSetSummary{}, false + } + for _, set := range summary.ProviderSets { + if set.VarName == varName { + return set, true + } + } + return providerSetSummary{}, false +} + +func (r *summaryProviderResolver) packageSupported(pkgPath string, visiting map[string]struct{}) bool { + if ok, seen := r.supported[pkgPath]; seen { + return ok + } + if _, seen := visiting[pkgPath]; seen { + return false + } + summary := r.summaries[pkgPath] + if summary == nil { + return false + } + visiting[pkgPath] = struct{}{} + defer delete(visiting, pkgPath) + for _, set := range summary.ProviderSets { + if !providerSetSummarySupported(set) { + return false + } + for _, imported := range set.Imports { + if _, ok := r.summaries[imported.PkgPath]; !ok { + return false + } + if !r.packageSupported(imported.PkgPath, visiting) { + return false + } + } + } + return true +} + +func providerSetSummarySupported(summary providerSetSummary) bool { + if len(summary.Bindings) > 0 || len(summary.Values) > 0 || len(summary.Fields) > 0 || len(summary.InputTypes) > 0 { + return false + } + for _, provider := range summary.Providers { + if provider.IsStruct { + return false + } + } + return true +} From 2eb540098f701a041e7eff360bf2f72c321873f7 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Fri, 13 Mar 2026 23:47:20 -0500 Subject: [PATCH 03/79] perf(incremental): trim cold bootstrap work and keep warm shape changes fast --- internal/wire/incremental.go | 20 +++++ internal/wire/incremental_fingerprint.go | 99 ++++++++++++++++++++++- internal/wire/incremental_manifest.go | 14 +++- internal/wire/incremental_summary_test.go | 10 ++- internal/wire/loader_test.go | 37 +++++++++ internal/wire/parse.go | 7 +- internal/wire/wire.go | 37 ++++++++- 7 files changed, 213 insertions(+), 11 deletions(-) diff --git a/internal/wire/incremental.go b/internal/wire/incremental.go index 0bc334c..007027b 100644 --- a/internal/wire/incremental.go +++ b/internal/wire/incremental.go @@ -23,6 +23,7 @@ import ( const IncrementalEnvVar = "WIRE_INCREMENTAL" type incrementalKey struct{} +type incrementalColdBootstrapKey struct{} // WithIncremental overrides incremental-mode resolution for the provided // context. This takes precedence over the environment variable. @@ -33,6 +34,13 @@ func WithIncremental(ctx context.Context, enabled bool) context.Context { return context.WithValue(ctx, incrementalKey{}, enabled) } +func withIncrementalColdBootstrap(ctx context.Context, enabled bool) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, incrementalColdBootstrapKey{}, enabled) +} + // IncrementalEnabled reports whether incremental mode is enabled for the // current operation. A context override takes precedence over env. func IncrementalEnabled(ctx context.Context, env []string) bool { @@ -54,6 +62,18 @@ func IncrementalEnabled(ctx context.Context, env []string) bool { return enabled } +func incrementalColdBootstrapEnabled(ctx context.Context) bool { + if ctx == nil { + return false + } + if v := ctx.Value(incrementalColdBootstrapKey{}); v != nil { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + func lookupEnv(env []string, key string) (string, bool) { prefix := key + "=" for i := len(env) - 1; i >= 0; i-- { diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go index 886d07f..46485f7 100644 --- a/internal/wire/incremental_fingerprint.go +++ b/internal/wire/incremental_fingerprint.go @@ -124,6 +124,52 @@ func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Pac return snapshot } +func buildIncrementalManifestSnapshotFromPackages(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { + all := collectAllPackages(pkgs) + moduleRoot := findModuleRoot(wd) + snapshot := &incrementalFingerprintSnapshot{ + fingerprints: make(map[string]*packageFingerprint), + } + for _, pkg := range all { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + files := packageFingerprintFiles(pkg) + if len(files) == 0 { + continue + } + sort.Strings(files) + metaFiles, err := buildCacheFiles(files) + if err != nil { + continue + } + shapeHash, err := packageShapeHashFromSyntax(pkg, files) + if err != nil { + continue + } + localImports := make([]string, 0, len(pkg.Imports)) + for _, imp := range pkg.Imports { + if classifyPackageLocation(moduleRoot, imp) == "local" { + localImports = append(localImports, imp.PkgPath) + } + } + sort.Strings(localImports) + snapshot.fingerprints[pkg.PkgPath] = &packageFingerprint{ + Version: incrementalFingerprintVersion, + WD: filepath.Clean(wd), + Tags: tags, + PkgPath: pkg.PkgPath, + Files: metaFiles, + ShapeHash: shapeHash, + LocalImports: localImports, + } + } + if len(snapshot.fingerprints) == 0 { + return nil + } + return snapshot +} + func packageFingerprintFiles(pkg *packages.Package) []string { if pkg == nil { return nil @@ -202,16 +248,63 @@ func packageShapeHash(files []string) (string, error) { if err != nil { return "", err } - stripFunctionBodies(file) - if err := printer.Fprint(&buf, fset, file); err != nil { - return "", err + writeSyntaxShapeHash(&buf, fset, file) + buf.WriteByte(0) + } + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum[:]), nil +} + +func packageShapeHashFromSyntax(pkg *packages.Package, files []string) (string, error) { + if pkg == nil || len(pkg.Syntax) == 0 || pkg.Fset == nil { + return packageShapeHash(files) + } + var buf bytes.Buffer + for _, file := range pkg.Syntax { + if file == nil { + continue } + writeSyntaxShapeHash(&buf, pkg.Fset, file) buf.WriteByte(0) } sum := sha256.Sum256(buf.Bytes()) return fmt.Sprintf("%x", sum[:]), nil } +func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File) { + if file == nil || buf == nil || fset == nil { + return + } + if file.Name != nil { + buf.WriteString("package ") + buf.WriteString(file.Name.Name) + buf.WriteByte('\n') + } + for _, decl := range file.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + writeNodeHash(buf, fset, decl.Recv) + buf.WriteByte(' ') + if decl.Name != nil { + buf.WriteString(decl.Name.Name) + } + buf.WriteByte(' ') + writeNodeHash(buf, fset, decl.Type) + buf.WriteByte('\n') + default: + writeNodeHash(buf, fset, decl) + buf.WriteByte('\n') + } + } +} + +func writeNodeHash(buf *bytes.Buffer, fset *token.FileSet, node interface{}) { + if buf == nil || fset == nil || node == nil { + return + } + _ = printer.Fprint(buf, fset, node) +} + func stripFunctionBodies(file *ast.File) { if file == nil { return diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index 11d250f..cd88976 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -143,13 +143,21 @@ func readIncrementalManifestResults(ctx context.Context, wd string, env []string } func writeIncrementalManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { + writeIncrementalManifestWithOptions(wd, env, patterns, opts, pkgs, snapshot, generated, true) +} + +func writeIncrementalManifestWithOptions(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult, includeExternalFiles bool) { if snapshot == nil || len(generated) == 0 { return } externalPkgs := buildExternalPackageExports(wd, pkgs) - externalFiles, err := buildExternalPackageFiles(wd, pkgs) - if err != nil { - return + var externalFiles []cacheFile + if includeExternalFiles { + var err error + externalFiles, err = buildExternalPackageFiles(wd, pkgs) + if err != nil { + return + } } manifest := &incrementalManifest{ Version: incrementalManifestVersion, diff --git a/internal/wire/incremental_summary_test.go b/internal/wire/incremental_summary_test.go index efb4028..ae85651 100644 --- a/internal/wire/incremental_summary_test.go +++ b/internal/wire/incremental_summary_test.go @@ -242,6 +242,14 @@ func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { if len(gens) != 1 || len(gens[0].Errs) > 0 { t.Fatalf("unexpected Generate result: %+v", gens) } + pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors while seeding summaries: %v", errs) + } + if _, errs := newObjectCache(pkgs, loader).ensurePackage("example.com/app/app"); len(errs) > 0 { + t.Fatalf("ensurePackage returned errors while seeding summaries: %v", errs) + } + writeIncrementalPackageSummaries(loader, pkgs) writeFile(t, depFile, strings.Join([]string{ "package dep", @@ -264,7 +272,7 @@ func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { "", }, "\n")) - pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + pkgs, loader, errs = load(ctx, root, env, "", []string{"./app"}) if len(errs) > 0 { t.Fatalf("load returned errors: %v", errs) } diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 6a26d8e..6899249 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -794,6 +794,43 @@ func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { } } +func TestGenerateIncrementalColdBootstrapStillSeedsFastPath(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeLargeBenchmarkModule(t, repoRoot, root, 24) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("cold bootstrap Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(t, root, 12) + + var labels []string + timedCtx := WithTiming(ctx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate results: %+v", gens) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected cold bootstrap to seed fast path, labels=%v", labels) + } +} + func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 73f218d..0cb0551 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -403,8 +403,11 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] if len(errs) > 0 { return nil, nil, errs } - fingerprints := analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) - analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) + var fingerprints *incrementalFingerprintSnapshot + if !incrementalColdBootstrapEnabled(ctx) { + fingerprints = analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) + analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) + } baseFiles := collectPackageFiles(pkgs) loader := &lazyLoader{ diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 64202dc..8b617ea 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -103,9 +103,14 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } var preloadState *incrementalPreloadState bypassIncrementalManifest := false + coldBootstrap := false if IncrementalEnabled(ctx, env) { debugf(ctx, "incremental=enabled") preloadState, _ = prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) + coldBootstrap = preloadState == nil + if coldBootstrap { + ctx = withIncrementalColdBootstrap(ctx, true) + } if cached, ok := readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, preloadState, preloadState != nil); ok { return cached, nil } @@ -139,14 +144,42 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } if allGeneratedOK(generated) { if IncrementalEnabled(ctx, env) { - writeIncrementalPackageSummaries(loader, pkgs) + if coldBootstrap { + snapshot := buildIncrementalManifestSnapshotFromPackages(wd, opts.Tags, incrementalManifestPackages(pkgs, loader)) + writeIncrementalManifestWithOptions(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), snapshot, generated, false) + if snapshot != nil { + writeIncrementalGraphFromSnapshot(wd, opts.Tags, manifestOutputPkgPathsFromGenerated(generated), snapshot.fingerprints) + } + } else { + writeIncrementalPackageSummaries(loader, pkgs) + writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) + } } writeManifest(wd, env, patterns, opts, pkgs) - writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) } return generated, nil } +func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { + if len(generated) == 0 { + return nil + } + seen := make(map[string]struct{}, len(generated)) + out := make([]string, 0, len(generated)) + for _, gen := range generated { + if gen.PkgPath == "" { + continue + } + if _, ok := seen[gen.PkgPath]; ok { + continue + } + seen[gen.PkgPath] = struct{}{} + out = append(out, gen.PkgPath) + } + sort.Strings(out) + return out +} + func warmPackageOutputCache(pkgs []*packages.Package, opts *GenerateOptions, generated []GenerateResult) { if len(pkgs) == 0 || len(generated) == 0 { return From ad2d561df609fbda6a5780547d21b3abf55e800c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sat, 14 Mar 2026 00:11:01 -0500 Subject: [PATCH 04/79] perf(incremental): load deps conditionally --- internal/wire/parse.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 0cb0551..a7a1a02 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -377,7 +377,7 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] } baseCfg := &packages.Config{ Context: ctx, - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps, + Mode: baseLoadMode(ctx), Dir: wd, Env: env, BuildFlags: []string{"-tags=wireinject"}, @@ -423,6 +423,14 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] return pkgs, loader, nil } +func baseLoadMode(ctx context.Context) packages.LoadMode { + mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports + if !incrementalColdBootstrapEnabled(ctx) { + mode |= packages.NeedDeps + } + return mode +} + func collectLoadErrors(pkgs []*packages.Package) []error { var errs []error for _, p := range pkgs { From 578b24f7be26cab46528ad6ce71f58266921189a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sat, 14 Mar 2026 00:25:39 -0500 Subject: [PATCH 05/79] chore(incremental): clear session cache --- cmd/wire/cache_cmd.go | 7 +++++-- internal/wire/cache_coverage_test.go | 30 ++++++++++++++++++++++++++++ internal/wire/cache_store.go | 1 + internal/wire/incremental_session.go | 7 +++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go index e5ceda4..f34d381 100644 --- a/cmd/wire/cache_cmd.go +++ b/cmd/wire/cache_cmd.go @@ -38,9 +38,9 @@ func (*cacheCmd) Synopsis() string { // Usage returns the help text for the subcommand. func (*cacheCmd) Usage() string { - return `cache [-clear] + return `cache [-clear|clear] - By default, prints the cache directory. With -clear, removes all cache files. + By default, prints the cache directory. With -clear or clear, removes all cache files. ` } @@ -51,6 +51,9 @@ func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { // Execute runs the subcommand. func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if f.NArg() > 0 && f.Arg(0) == "clear" { + cmd.clear = true + } if cmd.clear { if err := wire.ClearCache(); err != nil { log.Printf("failed to clear cache: %v\n", err) diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go index d65de26..316f30e 100644 --- a/internal/wire/cache_coverage_test.go +++ b/internal/wire/cache_coverage_test.go @@ -166,6 +166,36 @@ func TestCacheStoreReadWrite(t *testing.T) { } } +func TestClearCacheClearsIncrementalSessions(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + tempDir := t.TempDir() + osTempDir = func() string { return tempDir } + + sessionA := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") + if sessionA == nil { + t.Fatal("expected incremental session") + } + sessionB := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") + if sessionA != sessionB { + t.Fatal("expected same incremental session before clear") + } + + if err := ClearCache(); err != nil { + t.Fatalf("ClearCache failed: %v", err) + } + + sessionC := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") + if sessionC == nil { + t.Fatal("expected incremental session after clear") + } + if sessionC == sessionA { + t.Fatal("expected ClearCache to drop in-process incremental sessions") + } +} + func TestCacheStoreReadError(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() diff --git a/internal/wire/cache_store.go b/internal/wire/cache_store.go index dce5565..0c959cf 100644 --- a/internal/wire/cache_store.go +++ b/internal/wire/cache_store.go @@ -32,6 +32,7 @@ func CacheDir() string { // ClearCache removes all cached data. func ClearCache() error { + clearIncrementalSessions() return osRemoveAll(cacheDir()) } diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go index fda6605..72b051b 100644 --- a/internal/wire/incremental_session.go +++ b/internal/wire/incremental_session.go @@ -37,6 +37,13 @@ type cachedParsedFile struct { var incrementalSessions sync.Map +func clearIncrementalSessions() { + incrementalSessions.Range(func(key, _ any) bool { + incrementalSessions.Delete(key) + return true + }) +} + func sessionKey(wd string, env []string, tags string) string { var b strings.Builder b.WriteString(filepath.Clean(wd)) From 83806b9e0dbb0d9da99bd0749c50801dbf54449a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sat, 14 Mar 2026 13:33:42 -0500 Subject: [PATCH 06/79] fix(cli): improve wire error coloring and solve error labeling --- cmd/wire/main.go | 72 ++++++- cmd/wire/main_test.go | 114 ++++++++++ internal/wire/incremental_fingerprint.go | 3 + internal/wire/incremental_manifest.go | 9 + internal/wire/loader_test.go | 264 ++++++++++++++++++++++- internal/wire/local_fastpath.go | 20 ++ internal/wire/parser_lazy_loader.go | 23 +- internal/wire/wire.go | 6 + 8 files changed, 495 insertions(+), 16 deletions(-) create mode 100644 cmd/wire/main_test.go diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 3166531..d40e439 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -21,6 +21,7 @@ import ( "context" "flag" "fmt" + "io" "io/ioutil" "log" "os" @@ -37,7 +38,7 @@ import ( var topLevelIncremental optionalBoolFlag const ( - ansiRed = "\033[31m" + ansiRed = "\033[1;31m" ansiReset = "\033[0m" ) @@ -208,7 +209,7 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { // logErrors logs each error with consistent formatting. func logErrors(errs []error) { for _, err := range errs { - msg := err.Error() + msg := formatLoggedError(err) if strings.Contains(msg, "\n") { logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) continue @@ -217,25 +218,78 @@ func logErrors(errs []error) { } } -func logMultilineError(msg string) { - if shouldColorStderr() { - log.Print(ansiRed + msg + ansiReset) - return +func formatLoggedError(err error) string { + if err == nil { + return "" + } + msg := err.Error() + if strings.HasPrefix(msg, "inject ") { + return "solve failed\n" + msg } - log.Print(msg) + if idx := strings.Index(msg, ": inject "); idx >= 0 { + return "solve failed\n" + msg + } + return msg +} + +func logMultilineError(msg string) { + writeErrorLog(os.Stderr, msg) } func shouldColorStderr() bool { - if os.Getenv("NO_COLOR") != "" { + return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) +} + +func shouldColorOutput(isTTY bool, term string) bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { return false } - term := os.Getenv("TERM") + if forceColorEnabled() { + return true + } if term == "" || term == "dumb" { return false } + return isTTY +} + +func forceColorEnabled() bool { + return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" +} + +func stderrIsTTY() bool { info, err := os.Stderr.Stat() if err != nil { return false } return (info.Mode() & os.ModeCharDevice) != 0 } + +func writeErrorLog(w io.Writer, msg string) { + line := "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, colorizeLines(line)) + return + } + _, _ = io.WriteString(w, line) +} + +func colorizeLines(s string) string { + if s == "" { + return "" + } + parts := strings.SplitAfter(s, "\n") + var b strings.Builder + for _, part := range parts { + if part == "" { + continue + } + b.WriteString(ansiRed) + b.WriteString(part) + b.WriteString(ansiReset) + } + return b.String() +} diff --git a/cmd/wire/main_test.go b/cmd/wire/main_test.go new file mode 100644 index 0000000..b172f62 --- /dev/null +++ b/cmd/wire/main_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestFormatLoggedErrorAddsSolveHeader(t *testing.T) { + err := testError("inject InitializeApplication: no provider found for *example.Foo") + got := formatLoggedError(err) + want := "solve failed\ninject InitializeApplication: no provider found for *example.Foo" + if got != want { + t.Fatalf("formatLoggedError() = %q, want %q", got, want) + } +} + +func TestFormatLoggedErrorAddsSolveHeaderWithPositionPrefix(t *testing.T) { + err := testError("/tmp/wire.go:12:1: inject InitializeApplication: no provider found for *example.Foo") + got := formatLoggedError(err) + want := "solve failed\n/tmp/wire.go:12:1: inject InitializeApplication: no provider found for *example.Foo" + if got != want { + t.Fatalf("formatLoggedError() = %q, want %q", got, want) + } +} + +func TestFormatLoggedErrorLeavesNonSolveErrorsUnchanged(t *testing.T) { + err := testError("type-check failed for example.com/app/app") + got := formatLoggedError(err) + if got != err.Error() { + t.Fatalf("formatLoggedError() = %q, want %q", got, err.Error()) + } +} + +func TestShouldColorOutputForceColorOverridesTTYRequirement(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if !shouldColorOutput(false, "xterm-256color") { + t.Fatal("shouldColorOutput() = false, want true when FORCE_COLOR is set") + } +} + +func TestShouldColorOutputNoColorWins(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "1") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if shouldColorOutput(true, "xterm-256color") { + t.Fatal("shouldColorOutput() = true, want false when NO_COLOR is set") + } +} + +func TestShouldColorOutputTTYFallback(t *testing.T) { + t.Setenv("FORCE_COLOR", "") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if !shouldColorOutput(true, "xterm-256color") { + t.Fatal("shouldColorOutput() = false, want true for tty stderr") + } + if shouldColorOutput(false, "xterm-256color") { + t.Fatal("shouldColorOutput() = true, want false for non-tty stderr without force color") + } +} + +func TestWriteErrorLogFormatsWirePrefix(t *testing.T) { + var buf bytes.Buffer + writeErrorLog(&buf, "type-check failed for example.com/app/app") + got := buf.String() + want := "wire: type-check failed for example.com/app/app\n" + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteErrorLogColorsWholeBlockWhenForced(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + var buf bytes.Buffer + writeErrorLog(&buf, "type-check failed for example.com/app/app") + got := buf.String() + want := ansiRed + "wire: type-check failed for example.com/app/app\n" + ansiReset + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + var buf bytes.Buffer + writeErrorLog(&buf, "\n first line\n second line") + got := buf.String() + want := ansiRed + "wire: \n" + ansiReset + + ansiRed + " first line\n" + ansiReset + + ansiRed + " second line\n" + ansiReset + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +type testError string + +func (e testError) Error() string { return string(e) } diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go index 46485f7..42c2317 100644 --- a/internal/wire/incremental_fingerprint.go +++ b/internal/wire/incremental_fingerprint.go @@ -54,6 +54,7 @@ type fingerprintStats struct { type incrementalFingerprintSnapshot struct { stats fingerprintStats changed []string + touched []string fingerprints map[string]*packageFingerprint } @@ -106,6 +107,7 @@ func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Pac continue } snapshot.stats.metaMisses++ + snapshot.touched = append(snapshot.touched, pkg.PkgPath) fp, err := buildPackageFingerprint(wd, tags, pkg, metaFiles) if err != nil { continue @@ -121,6 +123,7 @@ func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Pac snapshot.changed = append(snapshot.changed, pkg.PkgPath) } sort.Strings(snapshot.changed) + sort.Strings(snapshot.touched) return snapshot } diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index cd88976..4a55c19 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -414,6 +414,8 @@ func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packag if firstReason == "" { firstReason = fp.PkgPath + ".shape_mismatch" } + } else if firstReason == "" { + firstReason = fp.PkgPath + ".meta_changed" } } if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { @@ -574,6 +576,13 @@ func writeIncrementalManifestFile(key string, manifest *incrementalManifest) { } } +func removeIncrementalManifestFile(key string) { + if key == "" { + return + } + _ = osRemove(incrementalManifestPath(key)) +} + func encodeIncrementalManifest(manifest *incrementalManifest) ([]byte, error) { var buf bytes.Buffer if manifest == nil { diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 6899249..d74c0f8 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -220,7 +220,7 @@ func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { } } -func TestGenerateIncrementalManifestSkipsLazyLoadOnBodyOnlyChange(t *testing.T) { +func TestGenerateIncrementalBodyOnlyChangeFallsBackToLoadAndReusesOutput(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() t.Cleanup(func() { restoreCacheHooks(state) }) @@ -340,17 +340,119 @@ func TestGenerateIncrementalManifestSkipsLazyLoadOnBodyOnlyChange(t *testing.T) if len(second) != 1 || len(second[0].Errs) > 0 { t.Fatalf("unexpected second Generate result: %+v", second) } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected second Generate to hit preload incremental manifest before package load, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected second Generate to skip lazy load, labels=%v", secondLabels) + if !containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected second Generate to re-load packages after body-only change, labels=%v", secondLabels) } if string(first[0].Content) != string(second[0].Content) { t.Fatal("expected body-only change to reuse identical generated output") } } +func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn missing", + "}", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(second) != 0 { + t.Fatalf("expected invalid body-only change to stop before generation, got %+v", second) + } + if len(errs) == 0 { + t.Fatal("expected invalid body-only change to return errors") + } + if !containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected invalid body-only change to bypass preload manifest and load packages, labels=%v", secondLabels) + } + if got := errs[0].Error(); !strings.Contains(got, "undefined: missing") { + t.Fatalf("expected load/type-check error from invalid body-only change, got %q", got) + } +} + func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -1078,6 +1180,156 @@ func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) if got := errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { t.Fatalf("expected fast-path type-check error, got %q", got) } + if _, ok := readIncrementalManifest(incrementalManifestSelectorKey(root, env, []string{"./app"}, &GenerateOptions{})); ok { + t.Fatal("expected invalid incremental generate to invalidate selector manifest") + } +} + +func TestGenerateIncrementalRecoversAfterInvalidShapeChange(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(second) != 0 { + t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) + } + if len(errs) == 0 { + t.Fatal("expected invalid incremental generate to return errors") + } + clearIncrementalSessions() + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var thirdLabels []string + thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { + thirdLabels = append(thirdLabels, label) + }) + third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("recovery incremental Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected recovery incremental Generate result: %+v", third) + } + + normal, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors: %v", errs) + } + if len(normal) != 1 || len(normal[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate result: %+v", normal) + } + if string(third[0].Content) != string(normal[0].Content) { + t.Fatal("incremental output differs from normal Generate output after recovering from invalid shape change") + } + if !containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected recovery run to fall back to normal load after invalidating stale manifest, labels=%v", thirdLabels) + } } func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t *testing.T) { diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go index 4ef1f8f..04d9cb8 100644 --- a/internal/wire/local_fastpath.go +++ b/internal/wire/local_fastpath.go @@ -63,6 +63,7 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p if err != nil { debugf(ctx, "incremental.local_fastpath miss reason=%v", err) if shouldBypassIncrementalManifestAfterFastPathError(err) { + invalidateIncrementalPreloadState(state) return nil, true, true, []error{err} } return nil, false, false, nil @@ -104,6 +105,18 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p return generated, true, false, nil } +func validateIncrementalTouchedPackages(ctx context.Context, wd string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot) error { + if state == nil || state.manifest == nil || snapshot == nil || len(snapshot.touched) == 0 { + return nil + } + roots := manifestOutputPkgPaths(state.manifest) + if len(roots) != 1 { + return nil + } + _, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], snapshot.touched, snapshotPackageFingerprints(snapshot), state.manifest.ExternalPkgs) + return err +} + func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { if err == nil { return false @@ -115,6 +128,13 @@ func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { return strings.Contains(msg, "type-check failed for ") } +func invalidateIncrementalPreloadState(state *incrementalPreloadState) { + if state == nil { + return + } + removeIncrementalManifestFile(state.selectorKey) +} + func formatLocalTypeCheckError(wd string, pkgPath string, errs []packages.Error) error { if len(errs) == 0 { return fmt.Errorf("type-check failed for %s", pkgPath) diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go index 223c9ad..f6137bc 100644 --- a/internal/wire/parser_lazy_loader.go +++ b/internal/wire/parser_lazy_loader.go @@ -128,7 +128,8 @@ func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(* return func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { start := time.Now() isPrimary := isPrimaryFile(primary, filename) - if !isPrimary && ll.session != nil { + keepBodies := ll.shouldKeepDependencyBodies(filename) + if !isPrimary && !keepBodies && ll.session != nil { if file, ok := ll.session.getParsedDep(filename, src); ok { if stats != nil { stats.record(false, time.Since(start), nil, true) @@ -153,6 +154,9 @@ func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(* if isPrimary { return file, nil } + if keepBodies { + return file, nil + } for _, decl := range file.Decls { if fn, ok := decl.(*ast.FuncDecl); ok { fn.Body = nil @@ -165,3 +169,20 @@ func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(* return file, nil } } + +func (ll *lazyLoader) shouldKeepDependencyBodies(filename string) bool { + if ll == nil || ll.fingerprints == nil || len(ll.fingerprints.touched) == 0 { + return false + } + clean := filepath.Clean(filename) + for _, pkgPath := range ll.fingerprints.touched { + files := ll.baseFiles[pkgPath] + if len(files) == 0 { + continue + } + if _, ok := files[clean]; ok { + return true + } + } + return false +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 8b617ea..1b6140f 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -129,6 +129,12 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if len(errs) > 0 { return nil, errs } + if err := validateIncrementalTouchedPackages(ctx, wd, opts, preloadState, loader.fingerprints); err != nil { + if shouldBypassIncrementalManifestAfterFastPathError(err) { + return nil, []error{err} + } + bypassIncrementalManifest = true + } if !bypassIncrementalManifest { if cached, ok := readIncrementalManifestResults(ctx, wd, env, patterns, opts, pkgs, loader.fingerprints); ok { warmPackageOutputCache(pkgs, opts, cached) From c6e4f4e3babe2f5308e29568981b6c23b98e48d2 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 00:51:55 -0500 Subject: [PATCH 07/79] feat(incremental): harden loader and scenario tooling --- cmd/wire/gen_cmd.go | 4 +- cmd/wire/main.go | 41 +- cmd/wire/main_test.go | 31 +- cmd/wire/watch_cmd.go | 4 +- internal/wire/cache_coverage_test.go | 8 +- internal/wire/cache_key.go | 68 +- internal/wire/cache_manifest.go | 21 +- internal/wire/cache_scope.go | 69 + internal/wire/cache_scope_test.go | 59 + internal/wire/incremental_bench_test.go | 798 +++++++++++- internal/wire/incremental_fingerprint.go | 169 ++- internal/wire/incremental_fingerprint_test.go | 38 + internal/wire/incremental_graph.go | 4 +- internal/wire/incremental_manifest.go | 371 +++++- internal/wire/incremental_session.go | 2 +- internal/wire/incremental_summary.go | 2 +- internal/wire/loader_test.go | 1113 ++++++++++++++++- internal/wire/local_export.go | 97 ++ internal/wire/local_fastpath.go | 156 ++- internal/wire/wire.go | 4 + scripts/incremental-scenarios.sh | 137 ++ 21 files changed, 3076 insertions(+), 120 deletions(-) create mode 100644 internal/wire/cache_scope.go create mode 100644 internal/wire/cache_scope_test.go create mode 100644 internal/wire/local_export.go create mode 100755 scripts/incremental-scenarios.sh diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index 13b88ed..e98556f 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -112,9 +112,9 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa } if wrote, err := out.CommitWithStatus(); err == nil { if wrote { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } else { - log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) diff --git a/cmd/wire/main.go b/cmd/wire/main.go index d40e439..efaf767 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -38,8 +38,12 @@ import ( var topLevelIncremental optionalBoolFlag const ( - ansiRed = "\033[1;31m" - ansiReset = "\033[0m" + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " + maxLoggedErrorLines = 5 ) // main wires up subcommands and executes the selected command. @@ -209,7 +213,7 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { // logErrors logs each error with consistent formatting. func logErrors(errs []error) { for _, err := range errs { - msg := formatLoggedError(err) + msg := truncateLoggedError(formatLoggedError(err)) if strings.Contains(msg, "\n") { logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) continue @@ -232,10 +236,27 @@ func formatLoggedError(err error) string { return msg } +func truncateLoggedError(msg string) string { + if msg == "" { + return "" + } + lines := strings.Split(msg, "\n") + if len(lines) <= maxLoggedErrorLines { + return msg + } + omitted := len(lines) - maxLoggedErrorLines + lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... (%d additional lines omitted)", omitted)) + return strings.Join(lines, "\n") +} + func logMultilineError(msg string) { writeErrorLog(os.Stderr, msg) } +func logSuccessf(format string, args ...interface{}) { + writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) +} + func shouldColorStderr() bool { return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) } @@ -266,7 +287,7 @@ func stderrIsTTY() bool { } func writeErrorLog(w io.Writer, msg string) { - line := "wire: " + msg + line := errorSig + "wire: " + msg if !strings.HasSuffix(line, "\n") { line += "\n" } @@ -277,6 +298,18 @@ func writeErrorLog(w io.Writer, msg string) { _, _ = io.WriteString(w, line) } +func writeStatusLog(w io.Writer, msg string) { + line := successSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, ansiGreen+line+ansiReset) + return + } + _, _ = io.WriteString(w, line) +} + func colorizeLines(s string) string { if s == "" { return "" diff --git a/cmd/wire/main_test.go b/cmd/wire/main_test.go index b172f62..7fe4720 100644 --- a/cmd/wire/main_test.go +++ b/cmd/wire/main_test.go @@ -2,6 +2,8 @@ package main import ( "bytes" + "fmt" + "strings" "testing" ) @@ -31,6 +33,19 @@ func TestFormatLoggedErrorLeavesNonSolveErrorsUnchanged(t *testing.T) { } } +func TestTruncateLoggedErrorSummarizesLargeBlocks(t *testing.T) { + lines := make([]string, 0, maxLoggedErrorLines+3) + for i := 0; i < maxLoggedErrorLines+3; i++ { + lines = append(lines, fmt.Sprintf("line %d", i+1)) + } + got := truncateLoggedError(strings.Join(lines, "\n")) + wantLines := append(append([]string(nil), lines[:maxLoggedErrorLines]...), "... (3 additional lines omitted)") + want := strings.Join(wantLines, "\n") + if got != want { + t.Fatalf("truncateLoggedError() = %q, want %q", got, want) + } +} + func TestShouldColorOutputForceColorOverridesTTYRequirement(t *testing.T) { t.Setenv("FORCE_COLOR", "1") t.Setenv("NO_COLOR", "") @@ -71,7 +86,7 @@ func TestWriteErrorLogFormatsWirePrefix(t *testing.T) { var buf bytes.Buffer writeErrorLog(&buf, "type-check failed for example.com/app/app") got := buf.String() - want := "wire: type-check failed for example.com/app/app\n" + want := errorSig + "wire: type-check failed for example.com/app/app\n" if got != want { t.Fatalf("writeErrorLog() = %q, want %q", got, want) } @@ -86,7 +101,7 @@ func TestWriteErrorLogColorsWholeBlockWhenForced(t *testing.T) { var buf bytes.Buffer writeErrorLog(&buf, "type-check failed for example.com/app/app") got := buf.String() - want := ansiRed + "wire: type-check failed for example.com/app/app\n" + ansiReset + want := ansiRed + errorSig + "wire: type-check failed for example.com/app/app\n" + ansiReset if got != want { t.Fatalf("writeErrorLog() = %q, want %q", got, want) } @@ -101,7 +116,7 @@ func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { var buf bytes.Buffer writeErrorLog(&buf, "\n first line\n second line") got := buf.String() - want := ansiRed + "wire: \n" + ansiReset + + want := ansiRed + errorSig + "wire: \n" + ansiReset + ansiRed + " first line\n" + ansiReset + ansiRed + " second line\n" + ansiReset if got != want { @@ -109,6 +124,16 @@ func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { } } +func TestWriteStatusLogFormatsSuccessPrefix(t *testing.T) { + var buf bytes.Buffer + writeStatusLog(&buf, "example.com/app: wrote /tmp/wire_gen.go (12ms)") + got := buf.String() + want := successSig + "wire: example.com/app: wrote /tmp/wire_gen.go (12ms)\n" + if got != want { + t.Fatalf("writeStatusLog() = %q, want %q", got, want) + } +} + type testError string func (e testError) Error() string { return string(e) } diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index 13743cd..cb1b31b 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -131,9 +131,9 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } if wrote, err := out.CommitWithStatus(); err == nil { if wrote { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } else { - log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go index 316f30e..faf6e62 100644 --- a/internal/wire/cache_coverage_test.go +++ b/internal/wire/cache_coverage_test.go @@ -605,16 +605,18 @@ func TestManifestKeyHelpers(t *testing.T) { PrefixOutputFile: "prefix", Header: []byte("header"), } + wd := t.TempDir() + patterns := []string{"./a", "./b"} manifest := &cacheManifest{ - WD: t.TempDir(), + WD: runCacheScope(wd, patterns), EnvHash: envHash(env), Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), - Patterns: []string{"./a", "./b"}, + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), } got := manifestKeyFromManifest(manifest) - want := manifestKey(manifest.WD, env, manifest.Patterns, opts) + want := manifestKey(wd, env, patterns, opts) if got != want { t.Fatalf("manifest key mismatch: got %q, want %q", got, want) } diff --git a/internal/wire/cache_key.go b/internal/wire/cache_key.go index 2aa8881..f22c6c0 100644 --- a/internal/wire/cache_key.go +++ b/internal/wire/cache_key.go @@ -18,7 +18,9 @@ import ( "crypto/sha256" "fmt" "path/filepath" + "runtime" "sort" + "sync" "golang.org/x/tools/go/packages" ) @@ -209,17 +211,69 @@ func cacheMetaMatches(meta *cacheMeta, pkg *packages.Package, opts *GenerateOpti // buildCacheFiles converts file paths into cache metadata entries. func buildCacheFiles(files []string) ([]cacheFile, error) { - out := make([]cacheFile, 0, len(files)) - for _, name := range files { - info, err := osStat(name) + return buildCacheFilesWithStats(files, func(path string) (cacheFile, error) { + info, err := osStat(path) if err != nil { - return nil, err + return cacheFile{}, err } - out = append(out, cacheFile{ - Path: filepath.Clean(name), + return cacheFile{ + Path: filepath.Clean(path), Size: info.Size(), ModTime: info.ModTime().UnixNano(), - }) + }, nil + }) +} + +func buildCacheFilesWithStats[T any](items []T, stat func(T) (cacheFile, error)) ([]cacheFile, error) { + if len(items) == 0 { + return nil, nil + } + if len(items) == 1 { + file, err := stat(items[0]) + if err != nil { + return nil, err + } + return []cacheFile{file}, nil + } + out := make([]cacheFile, len(items)) + workers := runtime.GOMAXPROCS(0) + if workers < 4 { + workers = 4 + } + if workers > len(items) { + workers = len(items) + } + var ( + wg sync.WaitGroup + mu sync.Mutex + firstErr error + indexCh = make(chan int, len(items)) + ) + for i := range items { + indexCh <- i + } + close(indexCh) + wg.Add(workers) + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for i := range indexCh { + file, err := stat(items[i]) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + continue + } + out[i] = file + } + }() + } + wg.Wait() + if firstErr != nil { + return nil, firstErr } return out, nil } diff --git a/internal/wire/cache_manifest.go b/internal/wire/cache_manifest.go index 127aa55..57be68b 100644 --- a/internal/wire/cache_manifest.go +++ b/internal/wire/cache_manifest.go @@ -79,14 +79,15 @@ func writeManifest(wd string, env []string, patterns []string, opts *GenerateOpt return } key := manifestKey(wd, env, patterns, opts) + scope := runCacheScope(wd, patterns) manifest := &cacheManifest{ Version: cacheVersion, - WD: wd, + WD: scope, Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), EnvHash: envHash(env), - Patterns: sortedStrings(patterns), + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), } manifest.ExtraFiles = extraCacheFiles(wd) for _, pkg := range pkgs { @@ -138,7 +139,7 @@ func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptio h := sha256.New() h.Write([]byte(cacheVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(runCacheScope(wd, patterns))) h.Write([]byte{0}) h.Write([]byte(envHash(env))) h.Write([]byte{0}) @@ -148,7 +149,7 @@ func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptio h.Write([]byte{0}) h.Write([]byte(headerHash(opts.Header))) h.Write([]byte{0}) - for _, p := range sortedStrings(patterns) { + for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { h.Write([]byte(p)) h.Write([]byte{0}) } @@ -293,19 +294,17 @@ func manifestValid(manifest *cacheManifest) bool { // buildCacheFilesFromMeta re-stats files to compare metadata. func buildCacheFilesFromMeta(files []cacheFile) ([]cacheFile, error) { - out := make([]cacheFile, 0, len(files)) - for _, file := range files { + return buildCacheFilesWithStats(files, func(file cacheFile) (cacheFile, error) { info, err := osStat(file.Path) if err != nil { - return nil, err + return cacheFile{}, err } - out = append(out, cacheFile{ + return cacheFile{ Path: filepath.Clean(file.Path), Size: info.Size(), ModTime: info.ModTime().UnixNano(), - }) - } - return out, nil + }, nil + }) } // extraCacheFiles returns Go module/workspace files affecting builds. diff --git a/internal/wire/cache_scope.go b/internal/wire/cache_scope.go new file mode 100644 index 0000000..fe161a7 --- /dev/null +++ b/internal/wire/cache_scope.go @@ -0,0 +1,69 @@ +package wire + +import ( + "path/filepath" + "sort" + "strings" +) + +func packageCacheScope(wd string) string { + if root := findModuleRoot(wd); root != "" { + return filepath.Clean(root) + } + return filepath.Clean(wd) +} + +func runCacheScope(wd string, patterns []string) string { + scopeRoot := packageCacheScope(wd) + normalized := normalizePatternsForScope(wd, scopeRoot, patterns) + if len(normalized) == 0 { + return scopeRoot + } + return scopeRoot + "\n" + strings.Join(normalized, "\n") +} + +func normalizePatternsForScope(wd string, scopeRoot string, patterns []string) []string { + if len(patterns) == 0 { + return nil + } + out := make([]string, 0, len(patterns)) + for _, pattern := range patterns { + out = append(out, normalizePatternForScope(wd, scopeRoot, pattern)) + } + sort.Strings(out) + return out +} + +func normalizePatternForScope(wd string, scopeRoot string, pattern string) string { + if pattern == "" { + return pattern + } + if filepath.IsAbs(pattern) || strings.HasPrefix(pattern, ".") { + abs := pattern + if !filepath.IsAbs(abs) { + abs = filepath.Join(wd, pattern) + } + abs = filepath.Clean(abs) + if scopeRoot != "" { + if rel, ok := pathWithinRoot(scopeRoot, abs); ok { + if rel == "." { + return "." + } + return filepath.ToSlash(rel) + } + } + return filepath.ToSlash(abs) + } + return pattern +} + +func pathWithinRoot(root string, path string) (string, bool) { + rel, err := filepath.Rel(filepath.Clean(root), filepath.Clean(path)) + if err != nil { + return "", false + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", false + } + return rel, true +} diff --git a/internal/wire/cache_scope_test.go b/internal/wire/cache_scope_test.go new file mode 100644 index 0000000..9cc518b --- /dev/null +++ b/internal/wire/cache_scope_test.go @@ -0,0 +1,59 @@ +package wire + +import ( + "path/filepath" + "testing" +) + +func TestRunScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") + wireDir := filepath.Join(root, "wire") + writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") + + env := []string{"GOOS=darwin"} + opts := &GenerateOptions{Tags: "wireinject", PrefixOutputFile: "gen_"} + + rootKey := manifestKey(root, env, []string{"./wire"}, opts) + subdirKey := manifestKey(wireDir, env, []string{"."}, opts) + if rootKey != subdirKey { + t.Fatalf("manifestKey mismatch: root=%q subdir=%q", rootKey, subdirKey) + } + + rootIncrementalKey := incrementalManifestSelectorKey(root, env, []string{"./wire"}, opts) + subdirIncrementalKey := incrementalManifestSelectorKey(wireDir, env, []string{"."}, opts) + if rootIncrementalKey != subdirIncrementalKey { + t.Fatalf("incrementalManifestSelectorKey mismatch: root=%q subdir=%q", rootIncrementalKey, subdirIncrementalKey) + } +} + +func TestPackageScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") + wireDir := filepath.Join(root, "wire") + writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") + + rootFingerprintKey := incrementalFingerprintKey(root, "wireinject", "example.com/app/wire") + subdirFingerprintKey := incrementalFingerprintKey(wireDir, "wireinject", "example.com/app/wire") + if rootFingerprintKey != subdirFingerprintKey { + t.Fatalf("incrementalFingerprintKey mismatch: root=%q subdir=%q", rootFingerprintKey, subdirFingerprintKey) + } + + rootSummaryKey := incrementalSummaryKey(root, "wireinject", "example.com/app/wire") + subdirSummaryKey := incrementalSummaryKey(wireDir, "wireinject", "example.com/app/wire") + if rootSummaryKey != subdirSummaryKey { + t.Fatalf("incrementalSummaryKey mismatch: root=%q subdir=%q", rootSummaryKey, subdirSummaryKey) + } + + rootGraphKey := incrementalGraphKey(root, "wireinject", []string{"example.com/app/wire"}) + subdirGraphKey := incrementalGraphKey(wireDir, "wireinject", []string{"example.com/app/wire"}) + if rootGraphKey != subdirGraphKey { + t.Fatalf("incrementalGraphKey mismatch: root=%q subdir=%q", rootGraphKey, subdirGraphKey) + } + + rootSessionKey := sessionKey(root, []string{"GOOS=darwin"}, "wireinject") + subdirSessionKey := sessionKey(wireDir, []string{"GOOS=darwin"}, "wireinject") + if rootSessionKey != subdirSessionKey { + t.Fatalf("sessionKey mismatch: root=%q subdir=%q", rootSessionKey, subdirSessionKey) + } +} diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go index 911a8c7..b300c4f 100644 --- a/internal/wire/incremental_bench_test.go +++ b/internal/wire/incremental_bench_test.go @@ -5,10 +5,12 @@ import ( "fmt" "os" "path/filepath" + "sort" "strconv" "strings" "testing" "time" + "unicode/utf8" ) const ( @@ -18,6 +20,38 @@ const ( var largeBenchmarkSizes = []int{10, 100, 1000} +type incrementalScenarioBenchmarkCase struct { + name string + mutate func(tb testing.TB, root string) + measure func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace + wantErr bool +} + +type incrementalScenarioTrace struct { + total time.Duration + labels map[string]time.Duration +} + +type incrementalScenarioBudget struct { + total time.Duration + validateLocal time.Duration + validateExt time.Duration + validateTouch time.Duration + validateTouchHit time.Duration + outputs time.Duration + generateLoad time.Duration + localFastpath time.Duration +} + +type largeRepoPerformanceBudget struct { + shapeTotal time.Duration + localLoad time.Duration + parse time.Duration + typecheck time.Duration + generate time.Duration + knownToggle time.Duration +} + func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { cacheHooksMu.Lock() state := saveCacheHooks() @@ -77,6 +111,129 @@ func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { } } +func BenchmarkGenerateIncrementalScenarioMatrix(b *testing.B) { + cacheHooksMu.Lock() + state := saveCacheHooks() + b.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(b) + for _, scenario := range incrementalScenarioBenchmarks() { + scenario := scenario + b.Run(scenario.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StartTimer() + _ = measureIncrementalScenarioOnce(b, repoRoot, scenario) + b.StopTimer() + } + }) + } +} + +func TestPrintIncrementalScenarioBenchmarkTable(t *testing.T) { + if os.Getenv("WIRE_BENCH_SCENARIOS") == "" { + t.Skip("set WIRE_BENCH_SCENARIOS=1 to print the incremental scenario benchmark table") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + rows := [][]string{{ + "scenario", + "total", + "local pkgs", + "external", + "touched", + "touch hit", + "outputs", + "gen load", + "local fastpath", + }} + for _, scenario := range incrementalScenarioBenchmarks() { + trace := measureIncrementalScenarioOnce(t, repoRoot, scenario) + rows = append(rows, []string{ + scenario.name, + formatBenchmarkDuration(trace.total), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_local_packages")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_external_files")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched_cache_hit")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.outputs")), + formatBenchmarkDuration(trace.label("generate.load")), + formatBenchmarkDuration(trace.label("incremental.local_fastpath.load")), + }) + } + fmt.Print(renderASCIITable(rows)) +} + +func TestIncrementalScenarioPerformanceBudgets(t *testing.T) { + if os.Getenv("WIRE_PERF_BUDGETS") == "" { + t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + budgets := incrementalScenarioPerformanceBudgets() + for _, scenario := range incrementalScenarioBenchmarks() { + scenario := scenario + budget, ok := budgets[scenario.name] + if !ok { + t.Fatalf("missing performance budget for scenario %q", scenario.name) + } + t.Run(scenario.name, func(t *testing.T) { + trace := measureIncrementalScenarioMedian(t, repoRoot, scenario, 5) + assertScenarioBudget(t, trace, budget) + }) + } +} + +func TestLargeRepoPerformanceBudgets(t *testing.T) { + if os.Getenv("WIRE_PERF_BUDGETS") == "" { + t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + budgets := largeRepoPerformanceBudgets() + for _, packageCount := range largeBenchmarkSizes { + packageCount := packageCount + budget, ok := budgets[packageCount] + if !ok { + t.Fatalf("missing large-repo performance budget for size %d", packageCount) + } + t.Run(strconv.Itoa(packageCount), func(t *testing.T) { + trace := measureLargeRepoShapeChangeTraceMedian(t, repoRoot, packageCount, true, 3) + checkBudgetDuration(t, "shape_total", trace.total, budget.shapeTotal) + checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localLoad) + checkBudgetDuration(t, "parse", trace.label("incremental.local_fastpath.parse"), budget.parse) + checkBudgetDuration(t, "typecheck", trace.label("incremental.local_fastpath.typecheck"), budget.typecheck) + checkBudgetDuration(t, "generate", trace.label("incremental.local_fastpath.generate"), budget.generate) + + knownToggle := measureLargeRepoKnownToggleMedian(t, repoRoot, packageCount, 3) + checkBudgetDuration(t, "known_toggle", knownToggle, budget.knownToggle) + }) + } +} + func BenchmarkGenerateLargeRepoNormalShapeChange(b *testing.B) { runLargeRepoShapeChangeBenchmarks(b, false) } @@ -164,8 +321,10 @@ func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { "old typed load", "new total", "new local load", - "new cached sets", + "new parse", + "new typecheck", "new injector solve", + "new format", "new generate", "speedup", }} @@ -179,8 +338,10 @@ func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { formatBenchmarkDuration(normal.label("load.packages.lazy.load")), formatBenchmarkDuration(incremental.total), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.summary_resolve")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.parse")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.typecheck")), formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.injectors")), + formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.format")), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), fmt.Sprintf("%.2fx", speedupRatio(normal.total, incremental.total)), }) @@ -314,6 +475,576 @@ func runLargeRepoShapeChangeBenchmarks(b *testing.B, incremental bool) { } } +func incrementalScenarioBenchmarks() []incrementalScenarioBenchmarkCase { + return []incrementalScenarioBenchmarkCase{ + { + name: "preload_unchanged", + mutate: func(testing.TB, string) {}, + }, + { + name: "preload_whitespace_only_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "", + "func New(msg string) *Foo {", + "", + "\treturn &Foo{Message: helper(msg)}", + "", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_body_only_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string {", + "\treturn helper(SQLText)", + "}", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_body_only_repeat_change", + measure: func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace { + writeBodyOnlyScenarioVariant(tb, root, "b") + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("warm changed variant Generate returned errors: %v", errs) + } + writeBodyOnlyScenarioVariant(tb, root, "a") + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("reset variant Generate returned errors: %v", errs) + } + writeBodyOnlyScenarioVariant(tb, root, "b") + trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} + timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { + trace.labels[label] += dur + }) + start := time.Now() + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + trace.total = time.Since(start) + if len(errs) > 0 { + tb.Fatalf("%s: Generate returned errors: %v", "preload_body_only_repeat_change", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("%s: unexpected Generate results: %+v", "preload_body_only_repeat_change", gens) + } + return trace + }, + }, + { + name: "local_fastpath_method_body_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func (f Foo) Summary() string {", + "\treturn helper(f.Message)", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_const_value_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"blue\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_var_initializer_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 2", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_add_top_level_helper", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func NewTag() string { return \"tag\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_import_only_implementation_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return fmt.Sprint(msg) }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_signature_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 7", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func NewCount() int { return defaultCount }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: count}", + "}", + "", + }, "\n")) + writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_struct_field_addition", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: defaultCount}", + "}", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_interface_method_addition", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Fooer interface {", + "\tMessage() string", + "\tCount() int", + "}", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "fallback_invalid_body_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return missing }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + wantErr: true, + }, + } +} + +func incrementalScenarioPerformanceBudgets() map[string]incrementalScenarioBudget { + return map[string]incrementalScenarioBudget{ + "preload_unchanged": { + total: 300 * time.Millisecond, + validateLocal: 25 * time.Millisecond, + validateExt: 25 * time.Millisecond, + validateTouch: 5 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_whitespace_only_change": { + total: 300 * time.Millisecond, + validateLocal: 25 * time.Millisecond, + validateExt: 25 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_body_only_change": { + total: 400 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_body_only_repeat_change": { + total: 150 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 5 * time.Millisecond, + validateTouchHit: 5 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "local_fastpath_method_body_change": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "preload_import_only_implementation_change": { + total: 150 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 50 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_const_value_change": { + total: 400 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_var_initializer_change": { + total: 400 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "local_fastpath_add_top_level_helper": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "local_fastpath_signature_change": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "local_fastpath_struct_field_addition": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "local_fastpath_interface_method_addition": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "fallback_invalid_body_change": { + total: 800 * time.Millisecond, + generateLoad: 500 * time.Millisecond, + }, + } +} + +func measureIncrementalScenarioOnce(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase) incrementalScenarioTrace { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeIncrementalScenarioBenchmarkModule(tb, repoRoot, root) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + if scenario.measure != nil { + return scenario.measure(tb, root, env, ctx) + } + + scenario.mutate(tb, root) + + trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} + timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { + trace.labels[label] += dur + }) + start := time.Now() + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + trace.total = time.Since(start) + + if scenario.wantErr { + if len(errs) == 0 { + tb.Fatalf("%s: expected Generate errors", scenario.name) + } + if len(gens) != 0 { + tb.Fatalf("%s: expected no generated results on error, got %+v", scenario.name, gens) + } + return trace + } + + if len(errs) > 0 { + tb.Fatalf("%s: Generate returned errors: %v", scenario.name, errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("%s: unexpected Generate results: %+v", scenario.name, gens) + } + return trace +} + +func writeIncrementalScenarioBenchmarkModule(tb testing.TB, repoRoot string, root string) { + tb.Helper() + + writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeBodyOnlyScenarioVariant(tb, root, "green") +} + +func writeBodyOnlyScenarioVariant(tb testing.TB, root string, value string) { + tb.Helper() + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"" + value + "\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) +} + +func measureIncrementalScenarioMedian(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase, samples int) incrementalScenarioTrace { + tb.Helper() + if samples <= 0 { + samples = 1 + } + traces := make([]incrementalScenarioTrace, 0, samples) + for i := 0; i < samples; i++ { + traces = append(traces, measureIncrementalScenarioOnce(tb, repoRoot, scenario)) + } + sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) + return traces[len(traces)/2] +} + +func assertScenarioBudget(t *testing.T, trace incrementalScenarioTrace, budget incrementalScenarioBudget) { + t.Helper() + checkBudgetDuration(t, "total", trace.total, budget.total) + checkBudgetDuration(t, "validate_local_packages", trace.label("incremental.preload_manifest.validate_local_packages"), budget.validateLocal) + checkBudgetDuration(t, "validate_external_files", trace.label("incremental.preload_manifest.validate_external_files"), budget.validateExt) + checkBudgetDuration(t, "validate_touched", trace.label("incremental.preload_manifest.validate_touched"), budget.validateTouch) + checkBudgetDuration(t, "validate_touched_cache_hit", trace.label("incremental.preload_manifest.validate_touched_cache_hit"), budget.validateTouchHit) + checkBudgetDuration(t, "outputs", trace.label("incremental.preload_manifest.outputs"), budget.outputs) + checkBudgetDuration(t, "generate_load", trace.label("generate.load"), budget.generateLoad) + checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localFastpath) +} + +func checkBudgetDuration(t *testing.T, name string, got time.Duration, max time.Duration) { + t.Helper() + if max <= 0 { + return + } + if got > max { + t.Fatalf("%s exceeded budget: got=%s max=%s", name, got, max) + } +} + +func (s incrementalScenarioTrace) label(name string) time.Duration { + if s.labels == nil { + return 0 + } + return s.labels[name] +} + type largeRepoBenchmarkRow struct { packageCount int coldNormal time.Duration @@ -328,6 +1059,35 @@ type shapeChangeTrace struct { labels map[string]time.Duration } +func largeRepoPerformanceBudgets() map[int]largeRepoPerformanceBudget { + return map[int]largeRepoPerformanceBudget{ + 10: { + shapeTotal: 45 * time.Millisecond, + localLoad: 3 * time.Millisecond, + parse: 500 * time.Microsecond, + typecheck: 4 * time.Millisecond, + generate: 3 * time.Millisecond, + knownToggle: 3 * time.Millisecond, + }, + 100: { + shapeTotal: 35 * time.Millisecond, + localLoad: 20 * time.Millisecond, + parse: 1500 * time.Microsecond, + typecheck: 12 * time.Millisecond, + generate: 20 * time.Millisecond, + knownToggle: 15 * time.Millisecond, + }, + 1000: { + shapeTotal: 260 * time.Millisecond, + localLoad: 110 * time.Millisecond, + parse: 4 * time.Millisecond, + typecheck: 70 * time.Millisecond, + generate: 180 * time.Millisecond, + knownToggle: 90 * time.Millisecond, + }, + } +} + func measureLargeRepoShapeChangeOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { tb.Helper() @@ -394,6 +1154,19 @@ func measureLargeRepoShapeChangeTraceOnce(tb testing.TB, repoRoot string, packag return trace } +func measureLargeRepoShapeChangeTraceMedian(tb testing.TB, repoRoot string, packageCount int, incremental bool, samples int) shapeChangeTrace { + tb.Helper() + if samples <= 0 { + samples = 1 + } + traces := make([]shapeChangeTrace, 0, samples) + for i := 0; i < samples; i++ { + traces = append(traces, measureLargeRepoShapeChangeTraceOnce(tb, repoRoot, packageCount, incremental)) + } + sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) + return traces[len(traces)/2] +} + func (s shapeChangeTrace) label(name string) time.Duration { if s.labels == nil { return 0 @@ -466,6 +1239,19 @@ func measureLargeRepoKnownToggleOnce(tb testing.TB, repoRoot string, packageCoun return dur } +func measureLargeRepoKnownToggleMedian(tb testing.TB, repoRoot string, packageCount int, samples int) time.Duration { + tb.Helper() + if samples <= 0 { + samples = 1 + } + values := make([]time.Duration, 0, samples) + for i := 0; i < samples; i++ { + values = append(values, measureLargeRepoKnownToggleOnce(tb, repoRoot, packageCount)) + } + sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) + return values[len(values)/2] +} + func formatPercentImprovement(normal time.Duration, incremental time.Duration) string { if normal <= 0 { return "0.0%" @@ -488,7 +1274,7 @@ func formatBenchmarkDuration(d time.Duration) string { case d >= time.Millisecond: return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) case d >= time.Microsecond: - return fmt.Sprintf("%.2fµs", float64(d)/float64(time.Microsecond)) + return fmt.Sprintf("%.2fus", float64(d)/float64(time.Microsecond)) default: return d.String() } @@ -674,8 +1460,8 @@ func renderASCIITable(rows [][]string) string { widths := make([]int, len(rows[0])) for _, row := range rows { for i, cell := range row { - if len(cell) > widths[i] { - widths[i] = len(cell) + if width := utf8.RuneCountInString(cell); width > widths[i] { + widths[i] = width } } } @@ -693,7 +1479,7 @@ func renderASCIITable(rows [][]string) string { for i, cell := range row { b.WriteByte(' ') b.WriteString(cell) - b.WriteString(strings.Repeat(" ", widths[i]-len(cell)+1)) + b.WriteString(strings.Repeat(" ", widths[i]-utf8.RuneCountInString(cell)+1)) b.WriteByte('|') } b.WriteByte('\n') diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go index 42c2317..be39982 100644 --- a/internal/wire/incremental_fingerprint.go +++ b/internal/wire/incremental_fingerprint.go @@ -31,7 +31,7 @@ import ( "golang.org/x/tools/go/packages" ) -const incrementalFingerprintVersion = "wire-incremental-v1" +const incrementalFingerprintVersion = "wire-incremental-v3" type packageFingerprint struct { Version string @@ -39,6 +39,8 @@ type packageFingerprint struct { Tags string PkgPath string Files []cacheFile + Dirs []cacheFile + ContentHash string ShapeHash string LocalImports []string } @@ -159,10 +161,12 @@ func buildIncrementalManifestSnapshotFromPackages(wd string, tags string, pkgs [ sort.Strings(localImports) snapshot.fingerprints[pkg.PkgPath] = &packageFingerprint{ Version: incrementalFingerprintVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, PkgPath: pkg.PkgPath, Files: metaFiles, + Dirs: mustBuildPackageDirCacheFiles(files), + ContentHash: mustHashPackageFiles(files), ShapeHash: shapeHash, LocalImports: localImports, } @@ -183,11 +187,52 @@ func packageFingerprintFiles(pkg *packages.Package) []string { return append([]string(nil), pkg.GoFiles...) } +func packageFingerprintDirs(files []string) []string { + if len(files) == 0 { + return nil + } + dirs := make([]string, 0, len(files)) + seen := make(map[string]struct{}, len(files)) + for _, name := range files { + dir := filepath.Clean(filepath.Dir(name)) + if _, ok := seen[dir]; ok { + continue + } + seen[dir] = struct{}{} + dirs = append(dirs, dir) + } + sort.Strings(dirs) + return dirs +} + +func mustBuildPackageDirCacheFiles(files []string) []cacheFile { + dirs := packageFingerprintDirs(files) + if len(dirs) == 0 { + return nil + } + meta, err := buildCacheFiles(dirs) + if err != nil { + return nil + } + return meta +} + +func mustHashPackageFiles(files []string) string { + if len(files) == 0 { + return "" + } + hash, err := hashFiles(files) + if err != nil { + return "" + } + return hash +} + func incrementalFingerprintEquivalent(a, b *packageFingerprint) bool { if a == nil || b == nil { return false } - if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || filepath.Clean(a.WD) != filepath.Clean(b.WD) { + if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || a.WD != b.WD { return false } if len(a.LocalImports) != len(b.LocalImports) { @@ -205,7 +250,7 @@ func incrementalFingerprintMetaMatches(prev *packageFingerprint, wd string, tags if prev == nil || prev.Version != incrementalFingerprintVersion { return false } - if filepath.Clean(prev.WD) != filepath.Clean(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { + if prev.WD != packageCacheScope(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { return false } if len(prev.Files) != len(files) { @@ -234,10 +279,12 @@ func buildPackageFingerprint(wd string, tags string, pkg *packages.Package, file sort.Strings(localImports) return &packageFingerprint{ Version: incrementalFingerprintVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, PkgPath: pkg.PkgPath, Files: append([]cacheFile(nil), files...), + Dirs: mustBuildPackageDirCacheFiles(packageFingerprintFiles(pkg)), + ContentHash: mustHashPackageFiles(packageFingerprintFiles(pkg)), ShapeHash: shapeHash, LocalImports: localImports, }, nil @@ -278,6 +325,7 @@ func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File if file == nil || buf == nil || fset == nil { return } + usedImports := usedImportNamesInShape(file) if file.Name != nil { buf.WriteString("package ") buf.WriteString(file.Name.Name) @@ -294,6 +342,10 @@ func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File buf.WriteByte(' ') writeNodeHash(buf, fset, decl.Type) buf.WriteByte('\n') + case *ast.GenDecl: + if writeGenDeclShapeHash(buf, fset, decl, usedImports) { + buf.WriteByte('\n') + } default: writeNodeHash(buf, fset, decl) buf.WriteByte('\n') @@ -301,6 +353,111 @@ func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File } } +func writeGenDeclShapeHash(buf *bytes.Buffer, fset *token.FileSet, decl *ast.GenDecl, usedImports map[string]struct{}) bool { + if buf == nil || fset == nil || decl == nil { + return false + } + var specBuf bytes.Buffer + wrote := false + for _, spec := range decl.Specs { + switch spec := spec.(type) { + case *ast.ImportSpec: + name := importName(spec) + if name == "_" || name == "." { + if spec.Name != nil { + specBuf.WriteString(spec.Name.Name) + } + specBuf.WriteByte(' ') + writeNodeHash(&specBuf, fset, spec.Path) + specBuf.WriteByte('\n') + wrote = true + break + } + if _, ok := usedImports[name]; !ok { + continue + } + if spec.Name != nil { + specBuf.WriteString(spec.Name.Name) + } + specBuf.WriteByte(' ') + writeNodeHash(&specBuf, fset, spec.Path) + case *ast.TypeSpec: + if spec.Name != nil { + specBuf.WriteString(spec.Name.Name) + } + specBuf.WriteByte(' ') + writeNodeHash(&specBuf, fset, spec.Type) + case *ast.ValueSpec: + for _, name := range spec.Names { + if name != nil { + specBuf.WriteString(name.Name) + } + specBuf.WriteByte(' ') + } + if spec.Type != nil { + writeNodeHash(&specBuf, fset, spec.Type) + } + default: + writeNodeHash(&specBuf, fset, spec) + } + specBuf.WriteByte('\n') + wrote = true + } + if !wrote { + return false + } + buf.WriteString(decl.Tok.String()) + buf.WriteByte(' ') + buf.Write(specBuf.Bytes()) + return true +} + +func usedImportNamesInShape(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + if file == nil { + return used + } + record := func(node ast.Node) { + ast.Inspect(node, func(n ast.Node) bool { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + } + for _, decl := range file.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + if decl.Recv != nil { + record(decl.Recv) + } + if decl.Type != nil { + record(decl.Type) + } + case *ast.GenDecl: + for _, spec := range decl.Specs { + switch spec := spec.(type) { + case *ast.TypeSpec: + if spec.Type != nil { + record(spec.Type) + } + case *ast.ValueSpec: + if spec.Type != nil { + record(spec.Type) + } + } + } + } + } + return used +} + func writeNodeHash(buf *bytes.Buffer, fset *token.FileSet, node interface{}) { if buf == nil || fset == nil || node == nil { return @@ -324,7 +481,7 @@ func incrementalFingerprintKey(wd string, tags string, pkgPath string) string { h := sha256.New() h.Write([]byte(incrementalFingerprintVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(packageCacheScope(wd))) h.Write([]byte{0}) h.Write([]byte(tags)) h.Write([]byte{0}) diff --git a/internal/wire/incremental_fingerprint_test.go b/internal/wire/incremental_fingerprint_test.go index afe81de..920d08e 100644 --- a/internal/wire/incremental_fingerprint_test.go +++ b/internal/wire/incremental_fingerprint_test.go @@ -41,6 +41,44 @@ func TestPackageShapeHashIgnoresFunctionBodies(t *testing.T) { } } +func TestPackageShapeHashIgnoresConstValueChanges(t *testing.T) { + dir := t.TempDir() + file := writeTempFile(t, dir, "pkg.go", "package p\n\nconst SQLText = \"a\"\n") + hash1, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash first failed: %v", err) + } + if err := os.WriteFile(file, []byte("package p\n\nconst SQLText = \"b\"\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + hash2, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash second failed: %v", err) + } + if hash1 != hash2 { + t.Fatalf("const-value change should not affect shape hash: %q vs %q", hash1, hash2) + } +} + +func TestPackageShapeHashIgnoresImplementationOnlyImportChanges(t *testing.T) { + dir := t.TempDir() + file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") + hash1, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash first failed: %v", err) + } + if err := os.WriteFile(file, []byte("package p\n\nimport \"fmt\"\n\nfunc Hello() string { return fmt.Sprint(\"a\") }\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + hash2, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash second failed: %v", err) + } + if hash1 != hash2 { + t.Fatalf("implementation-only import change should not affect shape hash: %q vs %q", hash1, hash2) + } +} + func TestIncrementalFingerprintRoundTrip(t *testing.T) { fp := &packageFingerprint{ Version: incrementalFingerprintVersion, diff --git a/internal/wire/incremental_graph.go b/internal/wire/incremental_graph.go index 66cf28d..37b3d0f 100644 --- a/internal/wire/incremental_graph.go +++ b/internal/wire/incremental_graph.go @@ -58,7 +58,7 @@ func buildIncrementalGraph(wd string, tags string, pkgs []*packages.Package) *in moduleRoot := findModuleRoot(wd) graph := &incrementalGraph{ Version: incrementalGraphVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, Roots: make([]string, 0, len(pkgs)), LocalReverse: make(map[string][]string), @@ -126,7 +126,7 @@ func incrementalGraphKey(wd string, tags string, roots []string) string { h := sha256.New() h.Write([]byte(incrementalGraphVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(packageCacheScope(wd))) h.Write([]byte{0}) h.Write([]byte(tags)) h.Write([]byte{0}) diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index 4a55c19..8fab10e 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -20,6 +20,7 @@ import ( "crypto/sha256" "encoding/binary" "fmt" + "go/token" "os" "path/filepath" "sort" @@ -28,7 +29,7 @@ import ( "golang.org/x/tools/go/packages" ) -const incrementalManifestVersion = "wire-incremental-manifest-v1" +const incrementalManifestVersion = "wire-incremental-manifest-v3" type incrementalManifest struct { Version string @@ -61,9 +62,19 @@ type incrementalPreloadState struct { manifest *incrementalManifest valid bool currentLocal []packageFingerprint + touched []string reason string } +type incrementalPreloadValidation struct { + valid bool + currentLocal []packageFingerprint + touched []string + reason string +} + +const touchedValidationVersion = "wire-touched-validation-v1" + func readPreloadIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { state, ok := prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) return readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, state, ok) @@ -75,15 +86,36 @@ func readPreloadIncrementalManifestResultsFromState(ctx context.Context, wd stri return nil, false } if state.valid { + validateStart := timeNow() + if len(state.touched) > 0 { + debugf(ctx, "incremental.preload_manifest touched=%s", strings.Join(state.touched, ",")) + } + if err := validateIncrementalPreloadTouchedPackages(ctx, wd, env, opts, state.currentLocal, state.touched); err != nil { + logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) + if shouldBypassIncrementalManifestAfterFastPathError(err) { + invalidateIncrementalPreloadState(state) + } + debugf(ctx, "incremental.preload_manifest miss reason=touched_validation") + return nil, false + } + logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) + outputsStart := timeNow() results, ok := incrementalManifestOutputs(state.manifest) + logTiming(ctx, "incremental.preload_manifest.outputs", outputsStart) if !ok { debugf(ctx, "incremental.preload_manifest miss reason=outputs") return nil, false } + if manifestNeedsLocalRefresh(state.manifest.LocalPackages, state.currentLocal) { + refreshed := *state.manifest + refreshed.LocalPackages = append([]packageFingerprint(nil), state.currentLocal...) + writeIncrementalManifestFile(state.selectorKey, &refreshed) + writeIncrementalManifestFile(incrementalManifestStateKey(state.selectorKey, refreshed.LocalPackages), &refreshed) + } debugf(ctx, "incremental.preload_manifest hit outputs=%d", len(results)) return results, true } else if archived := readStateIncrementalManifest(state.selectorKey, state.currentLocal); archived != nil { - if ok, _, _ := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); ok { + if validation := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); validation.valid { results, ok := incrementalManifestOutputs(archived) if !ok { debugf(ctx, "incremental.preload_manifest miss reason=state_outputs") @@ -107,13 +139,14 @@ func prepareIncrementalPreloadState(ctx context.Context, wd string, env []string if !ok { return nil, false } - valid, currentLocal, reason := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) + validation := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) return &incrementalPreloadState{ selectorKey: selectorKey, manifest: manifest, - valid: valid, - currentLocal: currentLocal, - reason: reason, + valid: validation.valid, + currentLocal: validation.currentLocal, + touched: validation.touched, + reason: validation.reason, }, true } @@ -150,6 +183,7 @@ func writeIncrementalManifestWithOptions(wd string, env []string, patterns []str if snapshot == nil || len(generated) == 0 { return } + scope := runCacheScope(wd, patterns) externalPkgs := buildExternalPackageExports(wd, pkgs) var externalFiles []cacheFile if includeExternalFiles { @@ -161,12 +195,12 @@ func writeIncrementalManifestWithOptions(wd string, env []string, patterns []str } manifest := &incrementalManifest{ Version: incrementalManifestVersion, - WD: filepath.Clean(wd), + WD: scope, Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), EnvHash: envHash(env), - Patterns: sortedStrings(patterns), + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), LocalPackages: snapshotPackageFingerprints(snapshot), ExternalPkgs: externalPkgs, ExternalFiles: externalFiles, @@ -197,7 +231,7 @@ func incrementalManifestSelectorKey(wd string, env []string, patterns []string, h := sha256.New() h.Write([]byte(incrementalManifestVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(runCacheScope(wd, patterns))) h.Write([]byte{0}) h.Write([]byte(envHash(env))) h.Write([]byte{0}) @@ -207,7 +241,7 @@ func incrementalManifestSelectorKey(wd string, env []string, patterns []string, h.Write([]byte{0}) h.Write([]byte(headerHash(opts.Header))) h.Write([]byte{0}) - for _, p := range sortedStrings(patterns) { + for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { h.Write([]byte(p)) h.Write([]byte{0}) } @@ -276,16 +310,17 @@ func incrementalManifestValid(manifest *incrementalManifest, wd string, env []st if manifest == nil || manifest.Version != incrementalManifestVersion { return false } - if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { return false } if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { return false } - if len(manifest.Patterns) != len(patterns) { + normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) + if len(manifest.Patterns) != len(normalizedPatterns) { return false } - for i, p := range sortedStrings(patterns) { + for i, p := range normalizedPatterns { if manifest.Patterns[i] != p { return false } @@ -313,58 +348,93 @@ func incrementalManifestValid(manifest *incrementalManifest, wd string, env []st return len(manifest.Outputs) > 0 } -func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) (bool, []packageFingerprint, string) { +func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) incrementalPreloadValidation { if manifest == nil || manifest.Version != incrementalManifestVersion { - return false, nil, "version" + return incrementalPreloadValidation{reason: "version"} } - if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { - return false, nil, "config" + if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + return incrementalPreloadValidation{reason: "config"} } if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { - return false, nil, "env" + return incrementalPreloadValidation{reason: "env"} } - if len(manifest.Patterns) != len(patterns) { - return false, nil, "patterns.length" + normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) + if len(manifest.Patterns) != len(normalizedPatterns) { + return incrementalPreloadValidation{reason: "patterns.length"} } - for i, p := range sortedStrings(patterns) { + for i, p := range normalizedPatterns { if manifest.Patterns[i] != p { - return false, nil, "patterns.value" + return incrementalPreloadValidation{reason: "patterns.value"} } } if len(manifest.ExtraFiles) > 0 { + extraStart := timeNow() current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) + logTiming(ctx, "incremental.preload_manifest.validate_extra_files", extraStart) if err != nil || len(current) != len(manifest.ExtraFiles) { - return false, nil, "extra_files" + return incrementalPreloadValidation{reason: "extra_files"} } for i := range current { if current[i] != manifest.ExtraFiles[i] { - return false, nil, "extra_files.diff" + return incrementalPreloadValidation{reason: "extra_files.diff"} } } } - currentLocal, ok, reason := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) - if !ok { - return false, currentLocal, "local_packages." + reason + localStart := timeNow() + packagesState := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) + logTiming(ctx, "incremental.preload_manifest.validate_local_packages", localStart) + if !packagesState.valid { + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "local_packages." + packagesState.reason, + } } if len(manifest.ExternalFiles) > 0 { + externalStart := timeNow() current, err := buildCacheFilesFromMeta(manifest.ExternalFiles) + logTiming(ctx, "incremental.preload_manifest.validate_external_files", externalStart) if err != nil || len(current) != len(manifest.ExternalFiles) { - return false, currentLocal, "external_files" + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "external_files", + } } for i := range current { if current[i] != manifest.ExternalFiles[i] { - return false, currentLocal, "external_files.diff" + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "external_files.diff", + } } } } if len(manifest.Outputs) == 0 { - return false, currentLocal, "outputs" + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "outputs", + } + } + return incrementalPreloadValidation{ + valid: true, + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, } - return true, currentLocal, "" } -func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) ([]packageFingerprint, bool, string) { +type incrementalLocalPackagesState struct { + valid bool + currentLocal []packageFingerprint + touched []string + reason string +} + +func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) incrementalLocalPackagesState { currentState := make([]packageFingerprint, 0, len(local)) + touched := make([]string, 0, len(local)) var firstReason string for _, fp := range local { if len(fp.Files) == 0 { @@ -400,42 +470,158 @@ func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packag } } if !sameMeta { - shapeHash, err := packageShapeHash(storedFiles) + if diffs := describeCacheFileDiffs(fp.Files, currentMeta); len(diffs) > 0 { + debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_diff=%s", fp.PkgPath, strings.Join(diffs, "; ")) + } + contentHash, err := hashFiles(storedFiles) if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) + debugf(ctx, "incremental.preload_manifest local_pkg=%s content_error=%v", fp.PkgPath, err) if firstReason == "" { - firstReason = fp.PkgPath + ".shape_error" + firstReason = fp.PkgPath + ".content_error" } continue } - currentFP.ShapeHash = shapeHash - if shapeHash != fp.ShapeHash { - debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) - if firstReason == "" { - firstReason = fp.PkgPath + ".shape_mismatch" + currentFP.ContentHash = contentHash + if contentHash != fp.ContentHash { + debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_content=%s current_content=%s hash_files=%s", fp.PkgPath, fp.ContentHash, contentHash, strings.Join(storedFiles, ",")) + shapeHash, err := packageShapeHash(storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_error" + } + continue + } + currentFP.ShapeHash = shapeHash + if shapeHash != fp.ShapeHash { + debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_mismatch" + } + } else { + debugf(ctx, "incremental.preload_manifest local_pkg=%s content_changed_shape_unchanged", fp.PkgPath) + touched = append(touched, fp.PkgPath) } - } else if firstReason == "" { - firstReason = fp.PkgPath + ".meta_changed" } } - if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) + currentDirs, dirsChanged, err := packageDirectoryMetaChanged(fp, storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_meta_error=%v", fp.PkgPath, err) if firstReason == "" { - firstReason = fp.PkgPath + ".dir_scan_error" + firstReason = fp.PkgPath + ".dir_meta_error" } continue - } else if changed { - debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) - if firstReason == "" { - firstReason = fp.PkgPath + ".introduced_relevant_files" + } + currentFP.Dirs = currentDirs + if dirsChanged { + if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".dir_scan_error" + } + continue + } else if changed { + debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) + if firstReason == "" { + firstReason = fp.PkgPath + ".introduced_relevant_files" + } } } currentState = append(currentState, currentFP) } if firstReason != "" { - return currentState, false, firstReason + return incrementalLocalPackagesState{ + currentLocal: currentState, + touched: touched, + reason: firstReason, + } + } + sort.Strings(touched) + return incrementalLocalPackagesState{ + valid: true, + currentLocal: currentState, + touched: touched, + } +} + +func validateIncrementalPreloadTouchedPackages(ctx context.Context, wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) error { + if len(touched) == 0 { + return nil + } + cacheKey := touchedValidationKey(wd, env, opts, local, touched) + if cacheKey != "" { + cacheHitStart := timeNow() + if _, ok := readCache(cacheKey); ok { + logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_hit", cacheHitStart) + return nil + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, + Dir: wd, + Env: env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: token.NewFileSet(), + } + if len(opts.Tags) > 0 { + cfg.BuildFlags[0] += " " + opts.Tags + } + loadStart := timeNow() + pkgs, err := packages.Load(cfg, touched...) + logTiming(ctx, "incremental.preload_manifest.validate_touched_load", loadStart) + if err != nil { + return err + } + errorsStart := timeNow() + byPath := make(map[string]*packages.Package, len(pkgs)) + for _, pkg := range pkgs { + if pkg != nil { + byPath[pkg.PkgPath] = pkg + } + } + for _, path := range touched { + if pkg := byPath[path]; pkg != nil && len(pkg.Errors) > 0 { + logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) + return formatLocalTypeCheckError(wd, pkg.PkgPath, pkg.Errors) + } } - return currentState, true, "" + logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) + if cacheKey != "" { + cacheWriteStart := timeNow() + writeCache(cacheKey, []byte("ok")) + logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_write", cacheWriteStart) + } + return nil +} + +func touchedValidationKey(wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) string { + if len(touched) == 0 { + return "" + } + byPath := fingerprintsFromSlice(local) + h := sha256.New() + h.Write([]byte(touchedValidationVersion)) + h.Write([]byte{0}) + h.Write([]byte(packageCacheScope(wd))) + h.Write([]byte{0}) + h.Write([]byte(envHash(env))) + h.Write([]byte{0}) + if opts != nil { + h.Write([]byte(opts.Tags)) + } + h.Write([]byte{0}) + for _, pkgPath := range touched { + fp := byPath[pkgPath] + if fp == nil || fp.ContentHash == "" { + return "" + } + h.Write([]byte(pkgPath)) + h.Write([]byte{0}) + h.Write([]byte(fp.ContentHash)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) } func incrementalManifestOutputs(manifest *incrementalManifest) ([]GenerateResult, bool) { @@ -500,6 +686,89 @@ func filesFromMeta(files []cacheFile) []string { return out } +func describeCacheFileDiffs(stored []cacheFile, current []cacheFile) []string { + if len(stored) == 0 && len(current) == 0 { + return nil + } + storedByPath := make(map[string]cacheFile, len(stored)) + currentByPath := make(map[string]cacheFile, len(current)) + for _, file := range stored { + storedByPath[filepath.Clean(file.Path)] = file + } + for _, file := range current { + currentByPath[filepath.Clean(file.Path)] = file + } + paths := make([]string, 0, len(storedByPath)+len(currentByPath)) + seen := make(map[string]struct{}, len(storedByPath)+len(currentByPath)) + for path := range storedByPath { + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + paths = append(paths, path) + } + for path := range currentByPath { + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + paths = append(paths, path) + } + sort.Strings(paths) + diffs := make([]string, 0, len(paths)) + for _, path := range paths { + storedFile, storedOK := storedByPath[path] + currentFile, currentOK := currentByPath[path] + switch { + case !storedOK: + diffs = append(diffs, fmt.Sprintf("%s added size=%d mtime=%d", path, currentFile.Size, currentFile.ModTime)) + case !currentOK: + diffs = append(diffs, fmt.Sprintf("%s removed size=%d mtime=%d", path, storedFile.Size, storedFile.ModTime)) + case storedFile != currentFile: + diffs = append(diffs, fmt.Sprintf("%s size:%d->%d mtime:%d->%d", path, storedFile.Size, currentFile.Size, storedFile.ModTime, currentFile.ModTime)) + } + } + return diffs +} + +func manifestNeedsLocalRefresh(stored []packageFingerprint, current []packageFingerprint) bool { + if len(stored) != len(current) { + return false + } + for i := range stored { + if stored[i].PkgPath != current[i].PkgPath { + return false + } + if stored[i].ContentHash == "" && current[i].ContentHash != "" { + return true + } + if len(stored[i].Dirs) == 0 && len(current[i].Dirs) > 0 { + return true + } + } + return false +} + +func packageDirectoryMetaChanged(fp packageFingerprint, storedFiles []string) ([]cacheFile, bool, error) { + dirs := packageFingerprintDirs(storedFiles) + if len(dirs) == 0 { + return nil, false, nil + } + current, err := buildCacheFiles(dirs) + if err != nil { + return nil, false, err + } + if len(fp.Dirs) != len(current) { + return current, true, nil + } + for i := range current { + if current[i] != fp.Dirs[i] { + return current, true, nil + } + } + return current, false, nil +} + func packageDirectoryIntroducedRelevantFiles(files []cacheFile) (bool, error) { dirs := make(map[string]struct{}) old := make(map[string]struct{}, len(files)) diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go index 72b051b..2fdaa2b 100644 --- a/internal/wire/incremental_session.go +++ b/internal/wire/incremental_session.go @@ -46,7 +46,7 @@ func clearIncrementalSessions() { func sessionKey(wd string, env []string, tags string) string { var b strings.Builder - b.WriteString(filepath.Clean(wd)) + b.WriteString(packageCacheScope(wd)) b.WriteByte('\n') b.WriteString(tags) b.WriteByte('\n') diff --git a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go index 2930b37..934f637 100644 --- a/internal/wire/incremental_summary.go +++ b/internal/wire/incremental_summary.go @@ -99,7 +99,7 @@ func incrementalSummaryKey(wd string, tags string, pkgPath string) string { h := sha256.New() h.Write([]byte(incrementalSummaryVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(packageCacheScope(wd))) h.Write([]byte{0}) h.Write([]byte(tags)) h.Write([]byte{0}) diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index d74c0f8..37e27d9 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -220,7 +220,7 @@ func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { } } -func TestGenerateIncrementalBodyOnlyChangeFallsBackToLoadAndReusesOutput(t *testing.T) { +func TestGenerateIncrementalBodyOnlyChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() t.Cleanup(func() { restoreCacheHooks(state) }) @@ -340,14 +340,253 @@ func TestGenerateIncrementalBodyOnlyChangeFallsBackToLoadAndReusesOutput(t *test if len(second) != 1 || len(second[0].Errs) > 0 { t.Fatalf("unexpected second Generate result: %+v", second) } - if !containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected second Generate to re-load packages after body-only change, labels=%v", secondLabels) + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected second Generate to reuse preload manifest after body-only change, labels=%v", secondLabels) } if string(first[0].Content) != string(second[0].Content) { t.Fatal("expected body-only change to reuse identical generated output") } } +func TestGenerateIncrementalTouchedValidationCacheReusesSuccessfulValidation(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeBodyVariant := func(message string) { + t.Helper() + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"" + message + "\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + } + writeBodyVariant("a") + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeBodyVariant("b") + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected first body-only variant change to avoid generate.load, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "incremental.preload_manifest.validate_touched_cache_hit") { + t.Fatalf("did not expect first body-only variant change to hit touched validation cache, labels=%v", secondLabels) + } + + writeBodyVariant("a") + third, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("third Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected third Generate result: %+v", third) + } + + writeBodyVariant("b") + + var fourthLabels []string + fourthCtx := WithTiming(ctx, func(label string, _ time.Duration) { + fourthLabels = append(fourthLabels, label) + }) + fourth, errs := Generate(fourthCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("fourth Generate returned errors: %v", errs) + } + if len(fourth) != 1 || len(fourth[0].Errs) > 0 { + t.Fatalf("unexpected fourth Generate result: %+v", fourth) + } + if containsLabel(fourthLabels, "generate.load") { + t.Fatalf("expected repeated body-only variant change to avoid generate.load, labels=%v", fourthLabels) + } + if !containsLabel(fourthLabels, "incremental.preload_manifest.validate_touched_cache_hit") { + t.Fatalf("expected repeated body-only variant change to hit touched validation cache, labels=%v", fourthLabels) + } + if string(first[0].Content) != string(fourth[0].Content) { + t.Fatal("expected repeated body-only variant change to reuse identical generated output") + } +} + +func TestGenerateIncrementalConstValueChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"blue\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected const-value change to reuse preload manifest, labels=%v", secondLabels) + } + if string(first[0].Content) != string(second[0].Content) { + t.Fatal("expected const-value change to reuse identical generated output") + } +} + func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -453,6 +692,448 @@ func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t * } } +func TestGenerateIncrementalScenarioMatrix(t *testing.T) { + t.Parallel() + + type scenarioExpectation struct { + mode string + wantErr bool + wantSameOutput bool + } + + scenarios := []struct { + name string + apply func(t *testing.T, fx incrementalScenarioFixture) + want scenarioExpectation + }{ + { + name: "comment_only_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "// SQLText controls SQL highlighting in log output.", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "whitespace_only_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "", + "func New(msg string) *Foo {", + "", + "\treturn &Foo{Message: helper(msg)}", + "", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "function_body_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string {", + "\treturn helper(SQLText)", + "}", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "method_body_change_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func (f Foo) Summary() string {", + "\treturn helper(f.Message)", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "const_value_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"blue\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "var_initializer_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 2", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "add_top_level_helper_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func NewTag() string { return \"tag\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "import_only_implementation_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return fmt.Sprint(msg) }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "signature_change_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 7", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func NewCount() int { return defaultCount }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: count}", + "}", + "", + }, "\n")) + writeFile(t, fx.wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: false}, + }, + { + name: "struct_field_addition_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: defaultCount}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "interface_method_addition_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Fooer interface {", + "\tMessage() string", + "\tCount() int", + "}", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "new_source_file_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.extraFile, strings.Join([]string{ + "package dep", + "", + "func NewTag() string { return \"tag\" }", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "fast", wantSameOutput: true}, + }, + { + name: "invalid_body_change_falls_back_and_errors", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return missing }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "generate_load", wantErr: true}, + }, + } + + for _, scenario := range scenarios { + scenario := scenario + t.Run(scenario.name, func(t *testing.T) { + fx := newIncrementalScenarioFixture(t) + + first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("baseline Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected baseline Generate result: %+v", first) + } + + scenario.apply(t, fx) + + var labels []string + timedCtx := WithTiming(fx.ctx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + second, errs := Generate(timedCtx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + + if scenario.want.wantErr { + if len(errs) == 0 { + t.Fatal("expected Generate to return errors") + } + if len(second) != 0 { + t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) + } + } else { + if len(errs) > 0 { + t.Fatalf("incremental Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected incremental Generate result: %+v", second) + } + } + + switch scenario.want.mode { + case "preload": + if containsLabel(labels, "generate.load") { + t.Fatalf("expected preload reuse without generate.load, labels=%v", labels) + } + case "fast": + if containsLabel(labels, "generate.load") { + t.Fatalf("expected fast incremental path without generate.load, labels=%v", labels) + } + case "local_fastpath": + if containsLabel(labels, "generate.load") { + t.Fatalf("expected local fast path without generate.load, labels=%v", labels) + } + if containsLabel(labels, "load.packages.lazy.load") { + t.Fatalf("expected local fast path to skip lazy load, labels=%v", labels) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected local fast path load, labels=%v", labels) + } + case "generate_load": + if !containsLabel(labels, "generate.load") { + t.Fatalf("expected generate.load fallback, labels=%v", labels) + } + default: + t.Fatalf("unknown expected mode %q", scenario.want.mode) + } + + if scenario.want.wantErr { + return + } + + normal, errs := Generate(context.Background(), fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors after edit: %v", errs) + } + if len(normal) != 1 || len(normal[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate result after edit: %+v", normal) + } + if second[0].OutputPath != normal[0].OutputPath { + t.Fatalf("output paths differ: incremental=%q normal=%q", second[0].OutputPath, normal[0].OutputPath) + } + if string(second[0].Content) != string(normal[0].Content) { + t.Fatalf("incremental output differs from normal output after %s", scenario.name) + } + if scenario.want.wantSameOutput && string(first[0].Content) != string(second[0].Content) { + t.Fatalf("expected generated output to stay unchanged for %s", scenario.name) + } + if !scenario.want.wantSameOutput && string(first[0].Content) == string(second[0].Content) { + t.Fatalf("expected generated output to change for %s", scenario.name) + } + }) + } +} + func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -933,6 +1614,155 @@ func TestGenerateIncrementalColdBootstrapStillSeedsFastPath(t *testing.T) { } } +func TestLoadLocalPackagesForFastPathImportsUnchangedLocalDependencyFromLocalExport(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeDepRouterModule(t, root, repoRoot) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + depPkgPath := "example.com/app/dep" + depExportPath := mustLocalExportPath(t, root, env, depPkgPath) + if _, err := os.Stat(depExportPath); err != nil { + t.Fatalf("expected local export artifact at %s: %v", depExportPath, err) + } + + mutateRouterModule(t, root) + + preloadState, ok := prepareIncrementalPreloadState(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if !ok || preloadState == nil || preloadState.manifest == nil { + t.Fatal("expected preload state after baseline incremental generate") + } + loaded, err := loadLocalPackagesForFastPath(context.Background(), root, "", "example.com/app/app", []string{"example.com/app/router"}, preloadState.currentLocal, preloadState.manifest.ExternalPkgs) + if err != nil { + t.Fatalf("loadLocalPackagesForFastPath returned error: %v", err) + } + if _, ok := loaded.loader.localExports[depPkgPath]; !ok { + t.Fatalf("expected %s to be a local export candidate", depPkgPath) + } + if _, ok := loaded.loader.sourcePkgs[depPkgPath]; ok { + t.Fatalf("did not expect %s to be source-loaded", depPkgPath) + } + typesPkg, err := loaded.loader.importPackage(depPkgPath) + if err != nil { + t.Fatalf("importPackage(%s) returned error: %v", depPkgPath, err) + } + if typesPkg == nil || !typesPkg.Complete() { + t.Fatalf("expected complete imported package for %s, got %#v", depPkgPath, typesPkg) + } + if loaded.loader.pkgs[depPkgPath] != nil { + t.Fatalf("expected %s to avoid source loading when local export artifact is present", depPkgPath) + } +} + +func TestGenerateIncrementalMissingLocalExportFallsBackSafely(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeDepRouterModule(t, root, repoRoot) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + if err := os.Remove(depExportPath); err != nil { + t.Fatalf("Remove(%s) failed: %v", depExportPath, err) + } + + mutateRouterModule(t, root) + + var labels []string + timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate results: %+v", gens) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected missing local export to stay on local fast path, labels=%v", labels) + } + refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + if _, err := os.Stat(refreshedExportPath); err != nil { + t.Fatalf("expected local export artifact to be refreshed at %s: %v", refreshedExportPath, err) + } +} + +func TestGenerateIncrementalCorruptedLocalExportFallsBackSafely(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeDepRouterModule(t, root, repoRoot) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + if err := os.WriteFile(depExportPath, []byte("not-a-valid-export"), 0644); err != nil { + t.Fatalf("WriteFile(%s) failed: %v", depExportPath, err) + } + + mutateRouterModule(t, root) + + var labels []string + timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate results: %+v", gens) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected corrupted local export to stay on local fast path, labels=%v", labels) + } + refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + data, err := os.ReadFile(refreshedExportPath) + if err != nil { + t.Fatalf("ReadFile(%s) failed: %v", refreshedExportPath, err) + } + if string(data) == "not-a-valid-export" { + t.Fatalf("expected corrupted local export artifact to be refreshed at %s", refreshedExportPath) + } +} + func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -1174,9 +2004,6 @@ func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) if len(errs) == 0 { t.Fatal("expected invalid incremental generate to return errors") } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected invalid incremental generate to stop before slow-path load, labels=%v", secondLabels) - } if got := errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { t.Fatalf("expected fast-path type-check error, got %q", got) } @@ -1327,8 +2154,8 @@ func TestGenerateIncrementalRecoversAfterInvalidShapeChange(t *testing.T) { if string(third[0].Content) != string(normal[0].Content) { t.Fatal("incremental output differs from normal Generate output after recovering from invalid shape change") } - if !containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected recovery run to fall back to normal load after invalidating stale manifest, labels=%v", thirdLabels) + if !containsLabel(thirdLabels, "incremental.local_fastpath.load") && !containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected recovery run to rebuild through local fast path or normal load, labels=%v", thirdLabels) } } @@ -1466,6 +2293,55 @@ func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t } } +func TestGenerateIncrementalPreloadHitRefreshesMissingContentHashes(t *testing.T) { + fx := newIncrementalScenarioFixture(t) + + first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("baseline Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected baseline Generate result: %+v", first) + } + + selectorKey := incrementalManifestSelectorKey(fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + manifest, ok := readIncrementalManifest(selectorKey) + if !ok { + t.Fatal("expected incremental manifest after baseline generate") + } + if len(manifest.LocalPackages) == 0 { + t.Fatal("expected local packages in incremental manifest") + } + + stale := *manifest + stale.LocalPackages = append([]packageFingerprint(nil), manifest.LocalPackages...) + for i := range stale.LocalPackages { + stale.LocalPackages[i].ContentHash = "" + stale.LocalPackages[i].Dirs = nil + } + writeIncrementalManifestFile(selectorKey, &stale) + writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, stale.LocalPackages), &stale) + + second, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("refresh Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected refresh Generate result: %+v", second) + } + + preloadState, ok := prepareIncrementalPreloadState(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if !ok { + t.Fatal("expected preload state after manifest refresh") + } + if !preloadState.valid { + t.Fatalf("expected refreshed preload state to be valid, reason=%s", preloadState.reason) + } + if len(preloadState.touched) != 0 { + t.Fatalf("expected refreshed preload state to have no touched packages, got %v", preloadState.touched) + } +} + func containsLabel(labels []string, want string) bool { for _, label := range labels { if label == want { @@ -1475,6 +2351,96 @@ func containsLabel(labels []string, want string) bool { return false } +type incrementalScenarioFixture struct { + root string + env []string + ctx context.Context + depFile string + wireFile string + extraFile string +} + +func newIncrementalScenarioFixture(t *testing.T) incrementalScenarioFixture { + t.Helper() + + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + return incrementalScenarioFixture{ + root: root, + env: append(os.Environ(), "GOWORK=off"), + ctx: WithIncremental(context.Background(), true), + depFile: depFile, + wireFile: wireFile, + extraFile: filepath.Join(root, "dep", "extra.go"), + } +} + func mustRepoRoot(t *testing.T) string { t.Helper() wd, err := os.Getwd() @@ -1488,6 +2454,137 @@ func mustRepoRoot(t *testing.T) string { return repoRoot } +func writeDepRouterModule(t *testing.T, root string, repoRoot string) { + t.Helper() + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"example.com/app/router\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *router.Routes {", + "\twire.Build(dep.Set, router.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Controller struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewController(msg string) *Controller {", + "\treturn &Controller{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewController)", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ + "package router", + "", + "import \"example.com/app/dep\"", + "", + "type Routes struct { Controller *dep.Controller }", + "", + "func ProvideRoutes(controller *dep.Controller) *Routes {", + "\treturn &Routes{Controller: controller}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ + "package router", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(ProvideRoutes)", + "", + }, "\n")) +} + +func mutateRouterModule(t *testing.T, root string) { + t.Helper() + writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ + "package router", + "", + "import \"example.com/app/dep\"", + "", + "type Routes struct {", + "\tController *dep.Controller", + "\tVersion int", + "}", + "", + "func NewVersion() int {", + "\treturn 2", + "}", + "", + "func ProvideRoutes(controller *dep.Controller, version int) *Routes {", + "\treturn &Routes{Controller: controller, Version: version}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ + "package router", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewVersion, ProvideRoutes)", + "", + }, "\n")) +} + +func mustLocalExportPath(t *testing.T, root string, env []string, pkgPath string) string { + t.Helper() + pkgs, loader, errs := load(context.Background(), root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors: %v", errs) + } + if loader == nil { + t.Fatal("load returned nil loader") + } + if _, errs := loader.load("example.com/app/app"); len(errs) > 0 { + t.Fatalf("lazy load returned errors: %v", errs) + } + snapshot := buildIncrementalManifestSnapshotFromPackages(root, "", incrementalManifestPackages(pkgs, loader)) + if snapshot == nil || snapshot.fingerprints[pkgPath] == nil { + t.Fatalf("missing fingerprint for %s", pkgPath) + } + path := localExportPathForFingerprint(root, "", snapshot.fingerprints[pkgPath]) + if path == "" { + t.Fatalf("missing local export path for %s", pkgPath) + } + return path +} + func writeFile(t *testing.T, path string, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { diff --git a/internal/wire/local_export.go b/internal/wire/local_export.go new file mode 100644 index 0000000..f83ed7b --- /dev/null +++ b/internal/wire/local_export.go @@ -0,0 +1,97 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "crypto/sha256" + "fmt" + "go/token" + "go/types" + "path/filepath" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/packages" +) + +const localExportVersion = "wire-local-export-v1" + +func localExportKey(wd string, tags string, pkgPath string, shapeHash string) string { + sum := sha256.Sum256([]byte(localExportVersion + "\x00" + packageCacheScope(wd) + "\x00" + tags + "\x00" + pkgPath + "\x00" + shapeHash)) + return fmt.Sprintf("%x", sum[:]) +} + +func localExportPath(key string) string { + return filepath.Join(cacheDir(), key+".iexp") +} + +func localExportPathForFingerprint(wd string, tags string, fp *packageFingerprint) string { + if fp == nil || fp.PkgPath == "" || fp.ShapeHash == "" { + return "" + } + return localExportPath(localExportKey(wd, tags, fp.PkgPath, fp.ShapeHash)) +} + +func localExportExists(wd string, tags string, fp *packageFingerprint) bool { + path := localExportPathForFingerprint(wd, tags, fp) + if path == "" { + return false + } + _, err := osStat(path) + return err == nil +} + +func writeLocalPackageExports(wd string, tags string, pkgs []*packages.Package, fps map[string]*packageFingerprint) { + if len(pkgs) == 0 || len(fps) == 0 { + return + } + moduleRoot := findModuleRoot(wd) + for _, pkg := range pkgs { + if pkg == nil || pkg.Types == nil || pkg.PkgPath == "" { + continue + } + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + fp := fps[pkg.PkgPath] + path := localExportPathForFingerprint(wd, tags, fp) + if path == "" { + continue + } + writeLocalPackageExportFile(path, pkg.Fset, pkg.Types) + } +} + +func writeLocalPackageExportFile(path string, fset *token.FileSet, pkg *types.Package) { + if path == "" || fset == nil || pkg == nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, filepath.Base(path)+".tmp-") + if err != nil { + return + } + writeErr := gcexportdata.Write(tmp, fset, pkg) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), path); err != nil { + osRemove(tmp.Name()) + } +} diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go index 04d9cb8..89ea402 100644 --- a/internal/wire/local_fastpath.go +++ b/internal/wire/local_fastpath.go @@ -31,6 +31,7 @@ import ( "strings" "time" + "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" ) @@ -96,8 +97,10 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p for _, path := range snapshot.changed { changedSet[path] = struct{}{} } + currentPackages := loaded.currentPackages() writeIncrementalFingerprints(snapshot, wd, opts.Tags) - writeIncrementalPackageSummariesWithSummary(loader, loaded.allPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) + writeLocalPackageExports(wd, opts.Tags, currentPackages, loaded.fingerprints) + writeIncrementalPackageSummariesWithSummary(loader, currentPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) @@ -235,6 +238,21 @@ type localFastPathLoaded struct { loader *localFastPathLoader } +func (l *localFastPathLoaded) currentPackages() []*packages.Package { + if l == nil { + return nil + } + if l.loader == nil || len(l.loader.pkgs) == 0 { + return l.allPackages + } + all := make([]*packages.Package, 0, len(l.loader.pkgs)) + for _, pkg := range l.loader.pkgs { + all = append(all, pkg) + } + sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath }) + return all +} + type localFastPathLoader struct { ctx context.Context wd string @@ -249,10 +267,25 @@ type localFastPathLoader struct { pkgs map[string]*packages.Package imported map[string]*types.Package externalMeta map[string]externalPackageExport + localExports map[string]string externalImp types.Importer + externalFallback types.Importer } func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { + return loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, changed, current, external, false) +} + +func validateTouchedPackagesFastPath(ctx context.Context, wd string, tags string, touched []string, current []packageFingerprint, external []externalPackageExport) error { + if len(touched) == 0 { + return nil + } + rootPkgPath := touched[0] + _, err := loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, touched, current, external, true) + return err +} + +func loadLocalPackagesForFastPathMode(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport, validationOnly bool) (*localFastPathLoaded, error) { meta := fingerprintsFromSlice(current) if len(meta) == 0 { return nil, fmt.Errorf("no local fingerprints") @@ -265,6 +298,9 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r if item.PkgPath == "" || item.ExportFile == "" { continue } + if meta[item.PkgPath] != nil { + continue + } externalMeta[item.PkgPath] = item } loader := &localFastPathLoader{ @@ -281,12 +317,18 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r pkgs: make(map[string]*packages.Package, len(meta)), imported: make(map[string]*types.Package, len(meta)+len(externalMeta)), externalMeta: externalMeta, + localExports: make(map[string]string), } for _, path := range changed { loader.changedPkgs[path] = struct{}{} } - loader.markSourceClosure() - candidates := make(map[string]*packageSummary) + if validationOnly { + for path := range loader.changedPkgs { + loader.sourcePkgs[path] = struct{}{} + } + } else { + loader.markSourceClosure() + } for path, fp := range meta { if path == rootPkgPath { continue @@ -294,7 +336,19 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r if _, changed := loader.changedPkgs[path]; changed { continue } - if _, ok := externalMeta[path]; !ok { + if _, ok := loader.sourcePkgs[path]; ok { + continue + } + if exportPath := localExportPathForFingerprint(wd, tags, fp); exportPath != "" && localExportExists(wd, tags, fp) { + loader.localExports[path] = exportPath + } + } + candidates := make(map[string]*packageSummary) + for path, fp := range meta { + if path == rootPkgPath { + continue + } + if _, changed := loader.changedPkgs[path]; changed { continue } summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(wd, tags, path)) @@ -305,9 +359,24 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r } loader.summaries = filterSupportedPackageSummaries(candidates) loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) - root, err := loader.load(rootPkgPath) - if err != nil { - return nil, err + loader.externalFallback = importerpkg.ForCompiler(loader.fset, "gc", nil) + var root *packages.Package + if validationOnly { + for _, path := range changed { + pkg, err := loader.load(path) + if err != nil { + return nil, err + } + if root == nil { + root = pkg + } + } + } else { + var err error + root, err = loader.load(rootPkgPath) + if err != nil { + return nil, err + } } all := make([]*packages.Package, 0, len(loader.pkgs)) for _, pkg := range loader.pkgs { @@ -341,6 +410,7 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { mode |= parser.ParseComments } syntax := make([]*ast.File, 0, len(files)) + parseStart := time.Now() for _, name := range files { file, err := l.parseFileForFastPath(name, mode, pkgPath) if err != nil { @@ -348,6 +418,7 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { } syntax = append(syntax, file) } + logTiming(l.ctx, "incremental.local_fastpath.parse", parseStart) if len(syntax) == 0 { return nil, fmt.Errorf("package %s parsed no files", pkgPath) } @@ -376,7 +447,9 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) }, } + typecheckStart := time.Now() checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) + logTiming(l.ctx, "incremental.local_fastpath.typecheck", typecheckStart) if checkedPkg != nil { pkg.Types = checkedPkg l.imported[pkgPath] = checkedPkg @@ -422,6 +495,7 @@ func (l *localFastPathLoader) parseFileForFastPath(name string, mode parser.Mode func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files []string, mode parser.Mode, pkg *packages.Package) (*packages.Package, error) { syntax := make([]*ast.File, 0, len(files)) + parseStart := time.Now() for _, name := range files { file, err := parser.ParseFile(l.fset, name, nil, mode) if err != nil { @@ -429,6 +503,7 @@ func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files [ } syntax = append(syntax, file) } + logTiming(l.ctx, "incremental.local_fastpath.parse_retry", parseStart) pkg.Syntax = syntax pkg.Errors = nil pkg.TypesInfo = newFastPathTypesInfo(pkgPath == l.rootPkgPath) @@ -442,7 +517,9 @@ func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files [ pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) }, } + typecheckStart := time.Now() checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, pkg.TypesInfo) + logTiming(l.ctx, "incremental.local_fastpath.typecheck_retry", typecheckStart) if checkedPkg != nil { pkg.Types = checkedPkg l.imported[pkgPath] = checkedPkg @@ -471,13 +548,29 @@ func (l *localFastPathLoader) shouldRetryWithoutBodyStripping(pkgPath string, er func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) { if l.shouldImportFromExport(path) { - return l.importExportPackage(path) + pkg, err := l.importExportPackage(path) + if err == nil { + return pkg, nil + } + // Cached local export artifacts are an optimization only. If one is + // missing or corrupted, fall back to source loading for correctness. + if _, ok := l.localExports[path]; ok && l.meta[path] != nil { + delete(l.localExports, path) + pkg, loadErr := l.load(path) + if loadErr == nil { + l.refreshLocalExport(path, pkg) + return pkg.Types, nil + } + return nil, loadErr + } + return nil, err } if l.meta[path] != nil { pkg, err := l.load(path) if err != nil { return nil, err } + l.refreshLocalExport(path, pkg) return pkg.Types, nil } if l.externalImp == nil { @@ -536,7 +629,20 @@ func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, if l == nil { return nil, fmt.Errorf("missing local fast path loader") } - if pkg := l.imported[path]; pkg != nil { + if pkg := l.imported[path]; pkg != nil && pkg.Complete() { + return pkg, nil + } + if exportPath := l.localExports[path]; exportPath != "" { + f, err := os.Open(exportPath) + if err != nil { + return nil, err + } + defer f.Close() + pkg, err := gcexportdata.Read(f, l.fset, l.imported, path) + if err != nil { + return nil, err + } + l.imported[path] = pkg return pkg, nil } if l.externalImp == nil { @@ -544,6 +650,13 @@ func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, } pkg, err := l.externalImp.Import(path) if err != nil { + if l.externalFallback != nil && strings.Contains(err.Error(), "missing external export data for ") { + pkg, fallbackErr := l.externalFallback.Import(path) + if fallbackErr == nil { + l.imported[path] = pkg + return pkg, nil + } + } return nil, err } l.imported[path] = pkg @@ -557,10 +670,26 @@ func (l *localFastPathLoader) shouldImportFromExport(pkgPath string) bool { if _, source := l.sourcePkgs[pkgPath]; source { return false } - _, ok := l.summaries[pkgPath] + if _, ok := l.localExports[pkgPath]; ok { + return true + } + _, ok := l.externalMeta[pkgPath] return ok } +func (l *localFastPathLoader) refreshLocalExport(pkgPath string, pkg *packages.Package) { + if l == nil || pkg == nil || pkg.Fset == nil || pkg.Types == nil { + return + } + fp := l.meta[pkgPath] + exportPath := localExportPathForFingerprint(l.wd, l.tags, fp) + if exportPath == "" { + return + } + writeLocalPackageExportFile(exportPath, pkg.Fset, pkg.Types) + l.localExports[pkgPath] = exportPath +} + func (l *localFastPathLoader) markSourceClosure() { if l == nil { return @@ -802,14 +931,15 @@ func writeIncrementalManifestFromState(wd string, env []string, patterns []strin if snapshot == nil || len(generated) == 0 || state == nil || state.manifest == nil { return } + scope := runCacheScope(wd, patterns) manifest := &incrementalManifest{ Version: incrementalManifestVersion, - WD: filepath.Clean(wd), + WD: scope, Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), EnvHash: envHash(env), - Patterns: sortedStrings(patterns), + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), LocalPackages: snapshotPackageFingerprints(snapshot), ExternalPkgs: append([]externalPackageExport(nil), state.manifest.ExternalPkgs...), ExternalFiles: append([]cacheFile(nil), state.manifest.ExternalFiles...), @@ -841,7 +971,7 @@ func writeIncrementalGraphFromSnapshot(wd string, tags string, roots []string, f } graph := &incrementalGraph{ Version: incrementalGraphVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, Roots: append([]string(nil), roots...), LocalReverse: make(map[string][]string), diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 1b6140f..24ca575 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -154,9 +154,13 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o snapshot := buildIncrementalManifestSnapshotFromPackages(wd, opts.Tags, incrementalManifestPackages(pkgs, loader)) writeIncrementalManifestWithOptions(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), snapshot, generated, false) if snapshot != nil { + writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), snapshot.fingerprints) writeIncrementalGraphFromSnapshot(wd, opts.Tags, manifestOutputPkgPathsFromGenerated(generated), snapshot.fingerprints) + loader.fingerprints = snapshot } + writeIncrementalPackageSummaries(loader, pkgs) } else { + writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), loader.fingerprints.fingerprints) writeIncrementalPackageSummaries(loader, pkgs) writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) } diff --git a/scripts/incremental-scenarios.sh b/scripts/incremental-scenarios.sh new file mode 100755 index 0000000..b59e970 --- /dev/null +++ b/scripts/incremental-scenarios.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + +usage() { + cat <<'EOF' +Usage: + scripts/incremental-scenarios.sh test + scripts/incremental-scenarios.sh matrix + scripts/incremental-scenarios.sh table + scripts/incremental-scenarios.sh budgets + scripts/incremental-scenarios.sh bench + scripts/incremental-scenarios.sh large-table + scripts/incremental-scenarios.sh large-breakdown + scripts/incremental-scenarios.sh report + scripts/incremental-scenarios.sh all + +Commands: + test Run the full internal/wire test suite. + matrix Run the incremental scenario matrix correctness test. + table Print the incremental scenario timing table. + budgets Enforce the incremental scenario performance budgets. + bench Run the incremental scenario benchmark suite. + large-table Print the large-repo comparison timing table. + large-breakdown Print the large-repo shape-change breakdown table. + report Run the main timing report: scenario table, budgets, and large-repo table. + all Run matrix, table, budgets, and the large-repo table in sequence. +EOF +} + +print_section() { + local title="$1" + printf '\n== %s ==\n' "$title" +} + +print_test_table() { + local output_file="$1" + awk ' + /^\+[-+]+\+$/ { in_table=1 } + in_table && !/^--- PASS:/ && !/^PASS$/ && !/^ok[[:space:]]/ { print } + /^--- PASS:/ && in_table { exit } + ' "$output_file" +} + +run_test_table() { + local env_var="$1" + local test_name="$2" + local output_file + output_file="$(mktemp)" + env "$env_var"=1 go test ./internal/wire -run "$test_name" -count=1 -v >"$output_file" + print_test_table "$output_file" + rm -f "$output_file" +} + +run_test() { + go test ./internal/wire -count=1 +} + +run_matrix() { + go test ./internal/wire -run TestGenerateIncrementalScenarioMatrix -count=1 +} + +run_table() { + run_test_table WIRE_BENCH_SCENARIOS TestPrintIncrementalScenarioBenchmarkTable +} + +run_budgets() { + WIRE_PERF_BUDGETS=1 go test ./internal/wire -run TestIncrementalScenarioPerformanceBudgets -count=1 >/dev/null + echo "PASS" +} + +run_bench() { + go test ./internal/wire -run '^$' -bench BenchmarkGenerateIncrementalScenarioMatrix -benchmem -count=1 +} + +run_large_table() { + run_test_table WIRE_BENCH_TABLE TestPrintLargeRepoBenchmarkComparisonTable +} + +run_large_breakdown() { + run_test_table WIRE_BENCH_BREAKDOWN TestPrintLargeRepoShapeChangeBreakdownTable +} + +run_report() { + print_section "Scenario Timing Table" + run_table + print_section "Scenario Performance Budgets" + run_budgets + print_section "Large Repo Comparison Table" + run_large_table +} + +cmd="${1:-}" +case "$cmd" in + test) + run_test + ;; + matrix) + run_matrix + ;; + table) + run_table + ;; + budgets) + run_budgets + ;; + bench) + run_bench + ;; + large-table) + run_large_table + ;; + large-breakdown) + run_large_breakdown + ;; + report) + run_report + ;; + all) + run_matrix + run_report + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "Unknown command: $cmd" >&2 + usage >&2 + exit 1 + ;; +esac From d3486f551be2d63e51dc9bfc04440e91bb942bdc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 03:15:56 -0500 Subject: [PATCH 08/79] feat: custom loader initial --- cmd/wire/cache_cmd.go | 67 - cmd/wire/check_cmd.go | 3 - cmd/wire/diff_cmd.go | 3 - cmd/wire/gen_cmd.go | 3 - cmd/wire/incremental_flag.go | 60 - cmd/wire/main.go | 9 +- cmd/wire/show_cmd.go | 3 - cmd/wire/watch_cmd.go | 3 - internal/loader/custom.go | 1014 +++++++ internal/loader/discovery.go | 94 + internal/loader/fallback.go | 224 ++ internal/loader/loader.go | 135 + internal/loader/loader_test.go | 821 ++++++ internal/loader/mode.go | 38 + internal/loader/timing.go | 41 + internal/wire/cache_bypass.go | 17 - internal/wire/cache_coverage_test.go | 1099 ------- internal/wire/cache_generate_test.go | 100 - internal/wire/cache_key.go | 352 --- internal/wire/cache_manifest.go | 393 --- internal/wire/cache_scope.go | 69 - internal/wire/cache_scope_test.go | 59 - internal/wire/cache_store.go | 77 - internal/wire/cache_test.go | 387 --- internal/wire/generate_package.go | 126 - internal/wire/generate_package_test.go | 137 - internal/wire/incremental.go | 85 - internal/wire/incremental_bench_test.go | 1495 ---------- internal/wire/incremental_fingerprint.go | 674 ----- internal/wire/incremental_fingerprint_test.go | 142 - internal/wire/incremental_graph.go | 306 -- internal/wire/incremental_graph_test.go | 97 - internal/wire/incremental_manifest.go | 1158 -------- internal/wire/incremental_session.go | 102 - internal/wire/incremental_summary.go | 656 ----- internal/wire/incremental_summary_test.go | 295 -- internal/wire/incremental_test.go | 65 - internal/wire/load_debug.go | 31 +- internal/wire/loader_test.go | 2596 ----------------- internal/wire/loader_timing_bridge.go | 17 + .../{cache_hooks.go => loader_validation.go} | 34 +- internal/wire/local_export.go | 97 - internal/wire/local_fastpath.go | 1031 ------- internal/wire/parse.go | 146 +- internal/wire/parse_coverage_test.go | 12 +- internal/wire/parser_lazy_loader.go | 188 -- internal/wire/parser_lazy_loader_test.go | 204 -- internal/wire/summary_provider_resolver.go | 223 -- internal/wire/wire.go | 149 +- internal/wire/wire_test.go | 121 +- 50 files changed, 2645 insertions(+), 12613 deletions(-) delete mode 100644 cmd/wire/cache_cmd.go delete mode 100644 cmd/wire/incremental_flag.go create mode 100644 internal/loader/custom.go create mode 100644 internal/loader/discovery.go create mode 100644 internal/loader/fallback.go create mode 100644 internal/loader/loader.go create mode 100644 internal/loader/loader_test.go create mode 100644 internal/loader/mode.go create mode 100644 internal/loader/timing.go delete mode 100644 internal/wire/cache_bypass.go delete mode 100644 internal/wire/cache_coverage_test.go delete mode 100644 internal/wire/cache_generate_test.go delete mode 100644 internal/wire/cache_key.go delete mode 100644 internal/wire/cache_manifest.go delete mode 100644 internal/wire/cache_scope.go delete mode 100644 internal/wire/cache_scope_test.go delete mode 100644 internal/wire/cache_store.go delete mode 100644 internal/wire/cache_test.go delete mode 100644 internal/wire/generate_package.go delete mode 100644 internal/wire/generate_package_test.go delete mode 100644 internal/wire/incremental.go delete mode 100644 internal/wire/incremental_bench_test.go delete mode 100644 internal/wire/incremental_fingerprint.go delete mode 100644 internal/wire/incremental_fingerprint_test.go delete mode 100644 internal/wire/incremental_graph.go delete mode 100644 internal/wire/incremental_graph_test.go delete mode 100644 internal/wire/incremental_manifest.go delete mode 100644 internal/wire/incremental_session.go delete mode 100644 internal/wire/incremental_summary.go delete mode 100644 internal/wire/incremental_summary_test.go delete mode 100644 internal/wire/incremental_test.go delete mode 100644 internal/wire/loader_test.go create mode 100644 internal/wire/loader_timing_bridge.go rename internal/wire/{cache_hooks.go => loader_validation.go} (50%) delete mode 100644 internal/wire/local_export.go delete mode 100644 internal/wire/local_fastpath.go delete mode 100644 internal/wire/parser_lazy_loader.go delete mode 100644 internal/wire/parser_lazy_loader_test.go delete mode 100644 internal/wire/summary_provider_resolver.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go deleted file mode 100644 index f34d381..0000000 --- a/cmd/wire/cache_cmd.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "flag" - "fmt" - "log" - - "github.com/goforj/wire/internal/wire" - "github.com/google/subcommands" -) - -type cacheCmd struct { - clear bool -} - -// Name returns the subcommand name. -func (*cacheCmd) Name() string { return "cache" } - -// Synopsis returns a short summary of the subcommand. -func (*cacheCmd) Synopsis() string { - return "inspect or clear the wire cache" -} - -// Usage returns the help text for the subcommand. -func (*cacheCmd) Usage() string { - return `cache [-clear|clear] - - By default, prints the cache directory. With -clear or clear, removes all cache files. -` -} - -// SetFlags registers flags for the subcommand. -func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { - f.BoolVar(&cmd.clear, "clear", false, "remove all cached data") -} - -// Execute runs the subcommand. -func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if f.NArg() > 0 && f.Arg(0) == "clear" { - cmd.clear = true - } - if cmd.clear { - if err := wire.ClearCache(); err != nil { - log.Printf("failed to clear cache: %v\n", err) - return subcommands.ExitFailure - } - log.Printf("cleared cache at %s\n", wire.CacheDir()) - return subcommands.ExitSuccess - } - fmt.Println(wire.CacheDir()) - return subcommands.ExitSuccess -} diff --git a/cmd/wire/check_cmd.go b/cmd/wire/check_cmd.go index 71872d9..897bec2 100644 --- a/cmd/wire/check_cmd.go +++ b/cmd/wire/check_cmd.go @@ -27,7 +27,6 @@ import ( type checkCmd struct { tags string - incremental optionalBoolFlag profile profileFlags } @@ -53,7 +52,6 @@ func (*checkCmd) Usage() string { // SetFlags registers flags for the subcommand. func (cmd *checkCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -67,7 +65,6 @@ func (cmd *checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/diff_cmd.go b/cmd/wire/diff_cmd.go index c7facca..5aad2f1 100644 --- a/cmd/wire/diff_cmd.go +++ b/cmd/wire/diff_cmd.go @@ -31,7 +31,6 @@ import ( type diffCmd struct { headerFile string tags string - incremental optionalBoolFlag profile profileFlags } @@ -61,7 +60,6 @@ func (*diffCmd) Usage() string { func (cmd *diffCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -79,7 +77,6 @@ func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index e98556f..aceefee 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -29,7 +29,6 @@ type genCmd struct { headerFile string prefixFileName string tags string - incremental optionalBoolFlag profile profileFlags } @@ -56,7 +55,6 @@ func (cmd *genCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -70,7 +68,6 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/incremental_flag.go b/cmd/wire/incremental_flag.go deleted file mode 100644 index 2962128..0000000 --- a/cmd/wire/incremental_flag.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "flag" - "strconv" - - "github.com/goforj/wire/internal/wire" -) - -type optionalBoolFlag struct { - value bool - set bool -} - -func (f *optionalBoolFlag) String() string { - if f == nil { - return "" - } - return strconv.FormatBool(f.value) -} - -func (f *optionalBoolFlag) Set(s string) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - f.value = v - f.set = true - return nil -} - -func (f *optionalBoolFlag) IsBoolFlag() bool { - return true -} - -func (f *optionalBoolFlag) apply(ctx context.Context) context.Context { - if f == nil || !f.set { - return ctx - } - return wire.WithIncremental(ctx, f.value) -} - -func addIncrementalFlag(f *optionalBoolFlag, fs *flag.FlagSet) { - fs.Var(f, "incremental", "enable the incremental engine (overrides "+wire.IncrementalEnvVar+")") -} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index efaf767..4426ee1 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -35,8 +35,6 @@ import ( "github.com/google/subcommands" ) -var topLevelIncremental optionalBoolFlag - const ( ansiRed = "\033[1;31m" ansiGreen = "\033[1;32m" @@ -52,12 +50,10 @@ func main() { subcommands.Register(subcommands.FlagsCommand(), "") subcommands.Register(subcommands.HelpCommand(), "") subcommands.Register(&checkCmd{}, "") - subcommands.Register(&cacheCmd{}, "") subcommands.Register(&diffCmd{}, "") subcommands.Register(&genCmd{}, "") subcommands.Register(&watchCmd{}, "") subcommands.Register(&showCmd{}, "") - addIncrementalFlag(&topLevelIncremental, flag.CommandLine) flag.Parse() // Initialize the default logger to log to stderr. @@ -74,7 +70,6 @@ func main() { "help": true, // builtin "flags": true, // builtin "check": true, - "cache": true, "diff": true, "gen": true, "serve": true, @@ -84,9 +79,9 @@ func main() { // Default to running the "gen" command. if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] { genCmd := &genCmd{} - os.Exit(int(genCmd.Execute(topLevelIncremental.apply(context.Background()), flag.CommandLine))) + os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine))) } - os.Exit(int(subcommands.Execute(topLevelIncremental.apply(context.Background())))) + os.Exit(int(subcommands.Execute(context.Background()))) } // installStackDumper registers signal handlers to dump goroutine stacks. diff --git a/cmd/wire/show_cmd.go b/cmd/wire/show_cmd.go index 1313ade..10c737f 100644 --- a/cmd/wire/show_cmd.go +++ b/cmd/wire/show_cmd.go @@ -35,7 +35,6 @@ import ( type showCmd struct { tags string - incremental optionalBoolFlag profile profileFlags } @@ -63,7 +62,6 @@ func (*showCmd) Usage() string { // SetFlags registers flags for the subcommand. func (cmd *showCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -77,7 +75,6 @@ func (cmd *showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index cb1b31b..ebdfa0e 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -36,7 +36,6 @@ type watchCmd struct { headerFile string prefixFileName string tags string - incremental optionalBoolFlag profile profileFlags pollInterval time.Duration rescanInterval time.Duration @@ -64,7 +63,6 @@ func (cmd *watchCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) f.DurationVar(&cmd.pollInterval, "poll_interval", 250*time.Millisecond, "interval between file stat checks") f.DurationVar(&cmd.rescanInterval, "rescan_interval", 2*time.Second, "interval to rescan for new or removed Go files") cmd.profile.addFlags(f) @@ -79,7 +77,6 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } defer stop() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) if cmd.pollInterval <= 0 { log.Println("poll_interval must be greater than zero") diff --git a/internal/loader/custom.go b/internal/loader/custom.go new file mode 100644 index 0000000..ffa2d48 --- /dev/null +++ b/internal/loader/custom.go @@ -0,0 +1,1014 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "fmt" + "go/ast" + importerpkg "go/importer" + "go/parser" + "go/scanner" + "go/token" + "go/types" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + "sort" + "strings" + "time" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/packages" +) + +type unsupportedError struct { + reason string +} + +func (e unsupportedError) Error() string { return e.reason } + +type packageMeta struct { + ImportPath string + Name string + Dir string + DepOnly bool + Export string + GoFiles []string + CompiledGoFiles []string + Imports []string + ImportMap map[string]string + Error *goListError +} + +type goListError struct { + Err string +} + +type customValidator struct { + fset *token.FileSet + meta map[string]*packageMeta + touched map[string]struct{} + packages map[string]*types.Package + importer types.Importer + loading map[string]bool +} + +type customTypedGraphLoader struct { + workspace string + ctx context.Context + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + stats typedLoadStats +} + +type typedLoadStats struct { + read time.Duration + parse time.Duration + typecheck time.Duration + localRead time.Duration + externalRead time.Duration + localParse time.Duration + externalParse time.Duration + localTypecheck time.Duration + externalTypecheck time.Duration + filesRead int + packages int + localPackages int + externalPackages int + localFilesRead int + externalFilesRead int +} + +func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { + if len(req.Touched) == 0 { + return &TouchedValidationResult{Backend: ModeCustom}, nil + } + meta, err := discoverTouchedMetadata(ctx, req) + if err != nil { + return nil, err + } + validator := &customValidator{ + fset: token.NewFileSet(), + meta: meta, + touched: make(map[string]struct{}, len(req.Touched)), + packages: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + for _, path := range req.Touched { + if !metadataMatchesFingerprint(path, meta, req.Local) { + return nil, unsupportedError{reason: "metadata fingerprint mismatch"} + } + validator.touched[path] = struct{}{} + } + out := make([]*packages.Package, 0, len(req.Touched)) + for _, path := range req.Touched { + pkg, err := validator.validatePackage(path) + if err != nil { + return nil, err + } + out = append(out, pkg) + } + return &TouchedValidationResult{ + Packages: out, + Backend: ModeCustom, + }, nil +} + +func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { + discoveryStart := time.Now() + meta, err := runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: req.NeedDeps, + }) + if err != nil { + return nil, err + } + logTiming(ctx, "loader.custom.root.discovery", discoveryStart) + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + pkgs := make(map[string]*packages.Package, len(meta)) + for path, m := range meta { + pkgs[path] = &packages.Package{ + ID: m.ImportPath, + Name: m.Name, + PkgPath: m.ImportPath, + GoFiles: append([]string(nil), metaFiles(m)...), + CompiledGoFiles: append([]string(nil), metaFiles(m)...), + Imports: make(map[string]*packages.Package), + } + if m.Error != nil && strings.TrimSpace(m.Error.Err) != "" { + pkgs[path].Errors = append(pkgs[path].Errors, packages.Error{ + Pos: "-", + Msg: m.Error.Err, + Kind: packages.ListError, + }) + } + } + for path, m := range meta { + pkg := pkgs[path] + for _, imp := range m.Imports { + target := imp + if mapped := m.ImportMap[imp]; mapped != "" { + target = mapped + } + if dep := pkgs[target]; dep != nil { + pkg.Imports[imp] = dep + } + } + } + roots := make([]*packages.Package, 0, len(req.Patterns)) + for _, m := range meta { + if m.DepOnly { + continue + } + if pkg := pkgs[m.ImportPath]; pkg != nil { + roots = append(roots, pkg) + } + } + if len(roots) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + return &RootLoadResult{ + Packages: roots, + Backend: ModeCustom, + Discovery: discoverySnapshotForMeta(meta, req.NeedDeps), + }, nil +} + +func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*LazyLoadResult, error) { + stopProfile, profileErr := startLoaderCPUProfile(req.Env) + if profileErr != nil { + return nil, profileErr + } + if stopProfile != nil { + defer stopProfile() + } + var ( + meta map[string]*packageMeta + err error + ) + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + meta = req.Discovery.meta + } else { + meta, err = runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: []string{req.Package}, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + } + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + fset := req.Fset + if fset == nil { + fset = token.NewFileSet() + } + l := &customTypedGraphLoader{ + workspace: detectModuleRoot(req.WD), + ctx: ctx, + fset: fset, + meta: meta, + targets: map[string]struct{}{req.Package: {}}, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + root, err := l.loadPackage(req.Package) + if err != nil { + return nil, err + } + logDuration(ctx, "loader.custom.lazy.read_files.cumulative", l.stats.read) + logDuration(ctx, "loader.custom.lazy.parse_files.cumulative", l.stats.parse) + logDuration(ctx, "loader.custom.lazy.typecheck.cumulative", l.stats.typecheck) + logDuration(ctx, "loader.custom.lazy.read_files.local.cumulative", l.stats.localRead) + logDuration(ctx, "loader.custom.lazy.read_files.external.cumulative", l.stats.externalRead) + logDuration(ctx, "loader.custom.lazy.parse_files.local.cumulative", l.stats.localParse) + logDuration(ctx, "loader.custom.lazy.parse_files.external.cumulative", l.stats.externalParse) + logDuration(ctx, "loader.custom.lazy.typecheck.local.cumulative", l.stats.localTypecheck) + logDuration(ctx, "loader.custom.lazy.typecheck.external.cumulative", l.stats.externalTypecheck) + return &LazyLoadResult{ + Packages: []*packages.Package{root}, + Backend: ModeCustom, + }, nil +} + +func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + meta, err := runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + fset := req.Fset + if fset == nil { + fset = token.NewFileSet() + } + targets := make(map[string]struct{}) + for _, m := range meta { + if m.DepOnly { + continue + } + targets[m.ImportPath] = struct{}{} + } + if len(targets) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + l := &customTypedGraphLoader{ + workspace: detectModuleRoot(req.WD), + ctx: ctx, + fset: fset, + meta: meta, + targets: targets, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + roots := make([]*packages.Package, 0, len(targets)) + for _, m := range meta { + if m.DepOnly { + continue + } + root, err := l.loadPackage(m.ImportPath) + if err != nil { + return nil, err + } + roots = append(roots, root) + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + logDuration(ctx, "loader.custom.typed.read_files.cumulative", l.stats.read) + logDuration(ctx, "loader.custom.typed.parse_files.cumulative", l.stats.parse) + logDuration(ctx, "loader.custom.typed.typecheck.cumulative", l.stats.typecheck) + logDuration(ctx, "loader.custom.typed.read_files.local.cumulative", l.stats.localRead) + logDuration(ctx, "loader.custom.typed.read_files.external.cumulative", l.stats.externalRead) + logDuration(ctx, "loader.custom.typed.parse_files.local.cumulative", l.stats.localParse) + logDuration(ctx, "loader.custom.typed.parse_files.external.cumulative", l.stats.externalParse) + logDuration(ctx, "loader.custom.typed.typecheck.local.cumulative", l.stats.localTypecheck) + logDuration(ctx, "loader.custom.typed.typecheck.external.cumulative", l.stats.externalTypecheck) + return &PackageLoadResult{ + Packages: roots, + Backend: ModeCustom, + }, nil +} + +func (v *customValidator) validatePackage(path string) (*packages.Package, error) { + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing metadata for touched package"} + } + if v.loading[path] { + return nil, unsupportedError{reason: "touched package cycle"} + } + v.loading[path] = true + defer delete(v.loading, path) + pkg := &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: v.fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } + if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + return pkg, nil + } + files, errs := v.parseFiles(metaFiles(meta)) + pkg.Errors = append(pkg.Errors, errs...) + if len(files) == 0 { + return pkg, nil + } + + tpkg := types.NewPackage(meta.ImportPath, meta.Name) + v.packages[meta.ImportPath] = tpkg + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + importer := importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := importPath + if mapped := meta.ImportMap[importPath]; mapped != "" { + target = mapped + } + if _, ok := v.touched[target]; ok { + if typed := v.packages[target]; typed != nil && typed.Complete() { + if depMeta := v.meta[target]; depMeta != nil { + pkg.Imports[importPath] = touchedPackageStub(v.fset, depMeta) + } + return typed, nil + } + checked, err := v.validatePackage(target) + if err != nil { + return nil, err + } + pkg.Imports[importPath] = checked + if len(checked.Errors) > 0 { + return nil, fmt.Errorf("touched dependency %s has errors", target) + } + if typed := v.packages[target]; typed != nil { + return typed, nil + } + return nil, unsupportedError{reason: "missing typed touched dependency"} + } + dep, err := v.importFromExport(target) + if err == nil { + if depMeta := v.meta[target]; depMeta != nil { + pkg.Imports[importPath] = touchedPackageStub(v.fset, depMeta) + } else { + pkg.Imports[importPath] = &packages.Package{PkgPath: target, Name: dep.Name()} + } + } + return dep, err + }) + var typeErrors []packages.Error + cfg := &types.Config{ + Importer: importer, + Sizes: types.SizesFor("gc", runtime.GOARCH), + Error: func(err error) { + typeErrors = append(typeErrors, toPackagesError(v.fset, err)) + }, + } + checker := types.NewChecker(cfg, v.fset, tpkg, info) + if err := checker.Files(files); err != nil && len(typeErrors) == 0 { + typeErrors = append(typeErrors, toPackagesError(v.fset, err)) + } + pkg.Syntax = files + pkg.Types = tpkg + pkg.TypesInfo = info + typeErrors = append(typeErrors, v.validateDeclaredImports(meta, files)...) + pkg.Errors = append(pkg.Errors, typeErrors...) + return pkg, nil +} + +func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, error) { + if pkg := l.packages[path]; pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + return pkg, nil + } + meta := l.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing lazy-load metadata"} + } + if l.loading[path] { + if pkg := l.packages[path]; pkg != nil { + return pkg, nil + } + return nil, unsupportedError{reason: "lazy-load cycle"} + } + l.loading[path] = true + defer delete(l.loading, path) + l.stats.packages++ + isLocal := isWorkspacePackage(l.workspace, meta.Dir) + if isLocal { + l.stats.localPackages++ + } else { + l.stats.externalPackages++ + } + + pkg := l.packages[path] + if pkg == nil { + pkg = &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: l.fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } + l.packages[path] = pkg + } + files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) + pkg.Errors = append(pkg.Errors, parseErrs...) + if len(files) == 0 { + if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + } + return pkg, nil + } + + tpkg := l.typesPkgs[path] + if tpkg == nil { + tpkg = types.NewPackage(meta.ImportPath, meta.Name) + l.typesPkgs[path] = tpkg + } + _, isTarget := l.targets[path] + needFullState := isTarget || isWorkspacePackage(l.workspace, meta.Dir) + var info *types.Info + if needFullState { + info = &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + } + var typeErrors []packages.Error + cfg := &types.Config{ + Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: !isWorkspacePackage(l.workspace, meta.Dir), + Importer: importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := importPath + if mapped := meta.ImportMap[importPath]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + return nil, err + } + pkg.Imports[importPath] = dep + if dep.Types != nil { + return dep.Types, nil + } + if typed := l.typesPkgs[target]; typed != nil { + return typed, nil + } + if len(dep.Errors) > 0 { + return nil, unsupportedError{reason: "lazy-load dependency has errors"} + } + return nil, unsupportedError{reason: "missing typed lazy-load dependency"} + }), + Error: func(err error) { + typeErrors = append(typeErrors, toPackagesError(l.fset, err)) + }, + } + checker := types.NewChecker(cfg, l.fset, tpkg, info) + typecheckStart := time.Now() + if err := checker.Files(files); err != nil && len(typeErrors) == 0 { + typeErrors = append(typeErrors, toPackagesError(l.fset, err)) + } + typecheckDuration := time.Since(typecheckStart) + l.stats.typecheck += typecheckDuration + if isLocal { + l.stats.localTypecheck += typecheckDuration + } else { + l.stats.externalTypecheck += typecheckDuration + } + if needFullState { + pkg.Syntax = files + } else { + pkg.Syntax = nil + } + pkg.Types = tpkg + pkg.TypesInfo = info + pkg.Errors = append(pkg.Errors, typeErrors...) + return pkg, nil +} + +func (v *customValidator) importFromExport(path string) (*types.Package, error) { + if typed := v.packages[path]; typed != nil && typed.Complete() { + return typed, nil + } + if v.importer != nil { + if imported, err := v.importer.Import(path); err == nil { + v.packages[path] = imported + return imported, nil + } + } + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing dependency metadata"} + } + if meta.Export == "" { + return v.loadDependencyFromSource(path) + } + exportPath := meta.Export + if !filepath.IsAbs(exportPath) { + exportPath = filepath.Join(meta.Dir, exportPath) + } + f, err := os.Open(exportPath) + if err != nil { + return nil, unsupportedError{reason: "open export data"} + } + defer f.Close() + r, err := gcexportdata.NewReader(f) + if err != nil { + return nil, unsupportedError{reason: "read export data"} + } + view := make(map[string]*types.Package, len(v.packages)) + for pkgPath, pkg := range v.packages { + view[pkgPath] = pkg + } + tpkg, err := gcexportdata.Read(r, v.fset, view, path) + if err != nil { + return v.loadDependencyFromSource(path) + } + v.packages[path] = tpkg + return tpkg, nil +} + +func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, error) { + if typed := v.packages[path]; typed != nil && typed.Complete() { + return typed, nil + } + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing source dependency metadata"} + } + if v.loading[path] { + if typed := v.packages[path]; typed != nil { + return typed, nil + } + return nil, unsupportedError{reason: "dependency cycle"} + } + v.loading[path] = true + defer delete(v.loading, path) + + tpkg := v.packages[path] + if tpkg == nil { + tpkg = types.NewPackage(meta.ImportPath, meta.Name) + v.packages[path] = tpkg + } + files, errs := v.parseFiles(metaFiles(meta)) + if len(errs) > 0 { + return nil, unsupportedError{reason: "dependency parse error"} + } + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + cfg := &types.Config{ + Importer: importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := importPath + if mapped := meta.ImportMap[importPath]; mapped != "" { + target = mapped + } + if _, ok := v.touched[target]; ok { + checked, err := v.validatePackage(target) + if err != nil { + return nil, err + } + if len(checked.Errors) > 0 { + return nil, unsupportedError{reason: "touched dependency has validation errors"} + } + return v.packages[target], nil + } + return v.importFromExport(target) + }), + Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: true, + } + if err := types.NewChecker(cfg, v.fset, tpkg, info).Files(files); err != nil { + return nil, unsupportedError{reason: "dependency typecheck error"} + } + return tpkg, nil +} + +func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Error) { + files := make([]*ast.File, 0, len(names)) + var errs []packages.Error + for _, name := range names { + src, err := os.ReadFile(name) + if err != nil { + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + continue + } + f, err := parser.ParseFile(v.fset, name, src, parser.AllErrors|parser.ParseComments) + if err != nil { + switch typed := err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + } + if f != nil { + files = append(files, f) + } + } + return files, errs +} + +func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*ast.File, []packages.Error) { + files := make([]*ast.File, 0, len(names)) + var errs []packages.Error + for _, name := range names { + readStart := time.Now() + src, err := os.ReadFile(name) + readDuration := time.Since(readStart) + l.stats.read += readDuration + l.stats.filesRead++ + if isLocal { + l.stats.localRead += readDuration + l.stats.localFilesRead++ + } else { + l.stats.externalRead += readDuration + l.stats.externalFilesRead++ + } + if err != nil { + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + continue + } + var f *ast.File + parseStart := time.Now() + if l.parseFile != nil { + f, err = l.parseFile(l.fset, name, src) + } else { + f, err = parser.ParseFile(l.fset, name, src, parser.AllErrors|parser.ParseComments) + } + parseDuration := time.Since(parseStart) + l.stats.parse += parseDuration + if isLocal { + l.stats.localParse += parseDuration + } else { + l.stats.externalParse += parseDuration + } + if err != nil { + switch typed := err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + } + if f != nil { + files = append(files, f) + } + } + return files, errs +} + +func toPackagesError(fset *token.FileSet, err error) packages.Error { + switch typed := err.(type) { + case packages.Error: + return typed + case types.Error: + return packages.Error{ + Pos: typed.Fset.Position(typed.Pos).String(), + Msg: typed.Msg, + Kind: packages.TypeError, + } + default: + pos := "-" + if fset != nil { + if te, ok := err.(interface{ Pos() token.Pos }); ok { + pos = fset.Position(te.Pos()).String() + } + } + return packages.Error{Pos: pos, Msg: err.Error(), Kind: packages.UnknownError} + } +} + +type importerFunc func(path string) (*types.Package, error) + +func (f importerFunc) Import(path string) (*types.Package, error) { return f(path) } + +func (v *customValidator) validateDeclaredImports(meta *packageMeta, files []*ast.File) []packages.Error { + var errs []packages.Error + for _, file := range files { + used := usedImportsInFile(file) + for _, spec := range file.Imports { + if spec == nil || spec.Path == nil { + continue + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + continue + } + target := path + if mapped := meta.ImportMap[path]; mapped != "" { + target = mapped + } + name := importName(spec) + if name != "_" && name != "." { + if _, ok := used[name]; !ok { + errs = append(errs, packages.Error{ + Pos: v.fset.Position(spec.Pos()).String(), + Msg: fmt.Sprintf("%q imported and not used", path), + Kind: packages.TypeError, + }) + continue + } + } + if _, err := v.importFromExport(target); err != nil { + errs = append(errs, packages.Error{ + Pos: v.fset.Position(spec.Pos()).String(), + Msg: fmt.Sprintf("could not import %s", path), + Kind: packages.TypeError, + }) + } + } + } + return errs +} + +func usedImportsInFile(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + ast.Inspect(file, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + return used +} + +func importName(spec *ast.ImportSpec) string { + if spec == nil || spec.Path == nil { + return "" + } + if spec.Name != nil && spec.Name.Name != "" { + return spec.Name.Name + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + return "" + } + if slash := strings.LastIndex(path, "/"); slash >= 0 { + path = path[slash+1:] + } + return path +} + +func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) (map[string]*packageMeta, error) { + metas, err := runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Touched, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + if len(metas) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + for _, touched := range req.Touched { + if _, ok := metas[touched]; !ok { + return nil, unsupportedError{reason: "missing touched package in metadata"} + } + } + return metas, nil +} + +func normalizeImports(imports []string, importMap map[string]string) []string { + if len(imports) == 0 { + return nil + } + out := make([]string, 0, len(imports)) + for _, imp := range imports { + if mapped := importMap[imp]; mapped != "" { + out = append(out, mapped) + continue + } + out = append(out, imp) + } + sort.Strings(out) + return out +} + +func metaFiles(meta *packageMeta) []string { + if meta == nil { + return nil + } + if len(meta.CompiledGoFiles) > 0 { + return meta.CompiledGoFiles + } + return meta.GoFiles +} + +func discoverySnapshotForMeta(meta map[string]*packageMeta, complete bool) *DiscoverySnapshot { + if !complete || len(meta) == 0 { + return nil + } + return &DiscoverySnapshot{meta: meta} +} + +func isWorkspacePackage(workspaceRoot, dir string) bool { + if workspaceRoot == "" || dir == "" { + return false + } + workspaceRoot = canonicalLoaderPath(workspaceRoot) + dir = canonicalLoaderPath(dir) + if dir == workspaceRoot { + return true + } + rel, err := filepath.Rel(workspaceRoot, dir) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +func detectModuleRoot(start string) string { + start = canonicalLoaderPath(start) + for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { + if info, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !info.IsDir() { + return dir + } + next := filepath.Dir(dir) + if next == dir { + break + } + } + return start +} + +func canonicalLoaderPath(path string) string { + path = filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return path +} + +func startLoaderCPUProfile(env []string) (func(), error) { + path := envValue(env, "WIRE_LOADER_CPU_PROFILE") + if strings.TrimSpace(path) == "" { + return nil, nil + } + f, err := os.Create(path) + if err != nil { + return nil, err + } + if err := pprof.StartCPUProfile(f); err != nil { + _ = f.Close() + return nil, err + } + return func() { + pprof.StopCPUProfile() + _ = f.Close() + }, nil +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key { + return value + } + } + return "" +} + + +func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { + if meta == nil { + return nil + } + return &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } +} + +func metadataMatchesFingerprint(pkgPath string, meta map[string]*packageMeta, local []LocalPackageFingerprint) bool { + for _, fp := range local { + if fp.PkgPath != pkgPath { + continue + } + pm := meta[pkgPath] + if pm == nil { + return false + } + want := append([]string(nil), fp.Files...) + got := append([]string(nil), metaFiles(pm)...) + sort.Strings(want) + sort.Strings(got) + if len(want) != len(got) { + return false + } + for i := range want { + if filepath.Clean(want[i]) != filepath.Clean(got[i]) { + return false + } + } + return true + } + return true +} diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go new file mode 100644 index 0000000..a6aba46 --- /dev/null +++ b/internal/loader/discovery.go @@ -0,0 +1,94 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" +) + +type goListRequest struct { + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool +} + +func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + args := []string{"list", "-json", "-e", "-compiled", "-export"} + if req.NeedDeps { + args = append(args, "-deps") + } + if req.Tags != "" { + args = append(args, "-tags=wireinject "+req.Tags) + } else { + args = append(args, "-tags=wireinject") + } + args = append(args, "--") + args = append(args, req.Patterns...) + + cmd := exec.CommandContext(ctx, "go", args...) + cmd.Dir = req.WD + if len(req.Env) > 0 { + cmd.Env = req.Env + } else { + cmd.Env = os.Environ() + } + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("go list: %w: %s", err, stderr.String()) + } + dec := json.NewDecoder(&stdout) + out := make(map[string]*packageMeta) + for { + var meta packageMeta + if err := dec.Decode(&meta); err != nil { + if err == io.EOF { + break + } + return nil, err + } + if meta.ImportPath == "" { + continue + } + for i, name := range meta.GoFiles { + if !filepath.IsAbs(name) { + meta.GoFiles[i] = filepath.Join(meta.Dir, name) + } + } + for i, name := range meta.CompiledGoFiles { + if !filepath.IsAbs(name) { + meta.CompiledGoFiles[i] = filepath.Join(meta.Dir, name) + } + } + if meta.Export != "" && !filepath.IsAbs(meta.Export) { + meta.Export = filepath.Join(meta.Dir, meta.Export) + } + meta.Imports = normalizeImports(meta.Imports, meta.ImportMap) + copyMeta := meta + out[meta.ImportPath] = ©Meta + } + return out, nil +} diff --git a/internal/loader/fallback.go b/internal/loader/fallback.go new file mode 100644 index 0000000..513694c --- /dev/null +++ b/internal/loader/fallback.go @@ -0,0 +1,224 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "errors" + "go/token" + + "golang.org/x/tools/go/packages" +) + +type defaultLoader struct{} + +func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + var unsupported unsupportedError + if req.LoaderMode != ModeFallback { + result, err := loadPackagesCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &PackageLoadResult{ + Backend: ModeFallback, + } + switch req.LoaderMode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + if unsupported.reason != "" { + result.FallbackDetail = unsupported.reason + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: req.Mode, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if cfg.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.ParseFile != nil { + cfg.ParseFile = req.ParseFile + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + escaped := make([]string, len(req.Patterns)) + for i := range req.Patterns { + escaped[i] = "pattern=" + req.Patterns[i] + } + pkgs, err := packages.Load(cfg, escaped...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) LoadRootGraph(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { + var unsupported unsupportedError + if req.Mode != ModeFallback { + result, err := loadRootGraphCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &RootLoadResult{ + Backend: ModeFallback, + } + switch req.Mode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + if unsupported.reason != "" { + result.FallbackDetail = unsupported.reason + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if req.NeedDeps { + cfg.Mode |= packages.NeedDeps + } + if req.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + escaped := make([]string, len(req.Patterns)) + for i := range req.Patterns { + escaped[i] = "pattern=" + req.Patterns[i] + } + pkgs, err := packages.Load(cfg, escaped...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) LoadTypedPackageGraph(ctx context.Context, req LazyLoadRequest) (*LazyLoadResult, error) { + var unsupported unsupportedError + if req.LoaderMode != ModeFallback { + result, err := loadTypedPackageGraphCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &LazyLoadResult{ + Backend: ModeFallback, + } + switch req.LoaderMode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + if unsupported.reason != "" { + result.FallbackDetail = unsupported.reason + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: req.Mode, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if cfg.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.ParseFile != nil { + cfg.ParseFile = req.ParseFile + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + pkgs, err := packages.Load(cfg, "pattern="+req.Package) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) ValidateTouchedPackages(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { + var unsupported unsupportedError + if req.Mode != ModeFallback { + result, err := validateTouchedPackagesCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + return validateTouchedPackagesFallback(ctx, req, unsupported.reason) +} + +func validateTouchedPackagesFallback(ctx context.Context, req TouchedValidationRequest, detail string) (*TouchedValidationResult, error) { + result := &TouchedValidationResult{ + Backend: ModeFallback, + } + switch req.Mode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + result.FallbackDetail = detail + } + if len(req.Touched) == 0 { + return result, nil + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: token.NewFileSet(), + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + pkgs, err := packages.Load(cfg, req.Touched...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} diff --git a/internal/loader/loader.go b/internal/loader/loader.go new file mode 100644 index 0000000..e26747b --- /dev/null +++ b/internal/loader/loader.go @@ -0,0 +1,135 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "go/ast" + "go/token" + + "golang.org/x/tools/go/packages" +) + +type Mode string + +const ( + ModeAuto Mode = "auto" + ModeCustom Mode = "custom" + ModeFallback Mode = "fallback" +) + +type FallbackReason string + +const ( + FallbackReasonNone FallbackReason = "" + FallbackReasonForcedFallback FallbackReason = "forced_fallback" + FallbackReasonCustomNotImplemented FallbackReason = "custom_not_implemented" + FallbackReasonCustomUnsupported FallbackReason = "custom_unsupported" +) + +type LocalPackageFingerprint struct { + PkgPath string + ContentHash string + ShapeHash string + Files []string +} + +type DiscoverySnapshot struct { + meta map[string]*packageMeta +} + +type TouchedValidationRequest struct { + WD string + Env []string + Tags string + Touched []string + Local []LocalPackageFingerprint + Mode Mode +} + +type TouchedValidationResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type RootLoadRequest struct { + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool + Mode Mode + Fset *token.FileSet +} + +type RootLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string + Discovery *DiscoverySnapshot +} + +type PackageLoadRequest struct { + WD string + Env []string + Tags string + Patterns []string + Mode packages.LoadMode + LoaderMode Mode + Fset *token.FileSet + ParseFile ParseFileFunc +} + +type PackageLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type ParseFileFunc func(*token.FileSet, string, []byte) (*ast.File, error) + +type LazyLoadRequest struct { + WD string + Env []string + Tags string + Package string + Mode packages.LoadMode + LoaderMode Mode + Fset *token.FileSet + ParseFile ParseFileFunc + Discovery *DiscoverySnapshot +} + +type LazyLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type Loader interface { + LoadPackages(context.Context, PackageLoadRequest) (*PackageLoadResult, error) + LoadRootGraph(context.Context, RootLoadRequest) (*RootLoadResult, error) + LoadTypedPackageGraph(context.Context, LazyLoadRequest) (*LazyLoadResult, error) + ValidateTouchedPackages(context.Context, TouchedValidationRequest) (*TouchedValidationResult, error) +} + +func New() Loader { + return defaultLoader{} +} diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go new file mode 100644 index 0000000..0e38d99 --- /dev/null +++ b/internal/loader/loader_test.go @@ -0,0 +1,821 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "sort" + "strings" + "testing" + + "golang.org/x/tools/go/packages" +) + +func TestModeFromEnvDefaultAuto(t *testing.T) { + if got := ModeFromEnv(nil); got != ModeAuto { + t.Fatalf("ModeFromEnv(nil) = %q, want %q", got, ModeAuto) + } +} + +func TestModeFromEnvUsesLastMatchingValue(t *testing.T) { + env := []string{ + "WIRE_LOADER_MODE=fallback", + "OTHER=value", + "WIRE_LOADER_MODE=custom", + } + if got := ModeFromEnv(env); got != ModeCustom { + t.Fatalf("ModeFromEnv(...) = %q, want %q", got, ModeCustom) + } +} + +func TestModeFromEnvIgnoresInvalidValues(t *testing.T) { + env := []string{ + "WIRE_LOADER_MODE=invalid", + } + if got := ModeFromEnv(env); got != ModeAuto { + t.Fatalf("ModeFromEnv(...) = %q, want %q", got, ModeAuto) + } +} + +func TestFallbackLoaderReasonFromMode(t *testing.T) { + l := New() + + gotAuto, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: ".", + Env: []string{}, + Touched: []string{}, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if gotAuto.Backend != ModeCustom { + t.Fatalf("auto backend = %q, want %q", gotAuto.Backend, ModeCustom) + } + if gotAuto.FallbackReason != FallbackReasonNone { + t.Fatalf("auto fallback reason = %q, want none", gotAuto.FallbackReason) + } + + gotForced, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: ".", + Env: []string{}, + Touched: []string{}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + if gotForced.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("forced fallback reason = %q, want %q", gotForced.FallbackReason, FallbackReasonForcedFallback) + } +} + +func TestCustomTouchedValidationSuccess(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nimport \"fmt\"\n\nfunc Use() string { return fmt.Sprint(\"ok\") }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if len(got.Packages[0].Errors) != 0 { + t.Fatalf("unexpected package errors: %+v", got.Packages[0].Errors) + } +} + +func TestValidateTouchedPackagesAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Use() string { return \"ok\" }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } +} + +func TestCustomTouchedValidationTypeError(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Broken() int { return missing }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if len(got.Packages[0].Errors) == 0 { + t.Fatal("expected type-check errors") + } +} + +func TestValidateTouchedPackagesCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Use() *dep.T { return dep.New() }\n") + + l := New() + custom, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/app"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + fallback, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/app"}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, false) +} + +func TestValidateTouchedPackagesCustomMatchesFallbackTypeErrors(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Broken() int { return missing }\n") + + l := New() + custom, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + fallback, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, false) +} + +func TestValidateTouchedPackagesAutoReportsFallbackDetail(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Use() string { return \"ok\" }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Local: []LocalPackageFingerprint{ + { + PkgPath: "example.com/app/a", + ContentHash: "wrong", + ShapeHash: "wrong", + Files: []string{filepath.Join(root, "a", "a.go")}, + }, + }, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonCustomUnsupported { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonCustomUnsupported) + } + if got.FallbackDetail != "metadata fingerprint mismatch" { + t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + } +} + +func TestLoadRootGraphFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"fmt\"\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("LoadRootGraph error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonForcedFallback) + } + if len(got.Packages) == 0 { + t.Fatal("expected loaded root packages") + } +} + +func TestLoadRootGraphCustom(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("LoadRootGraph(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if got.Packages[0].Imports["example.com/app/dep"] == nil { + t.Fatal("expected custom root graph to wire local import dependency") + } +} + +func TestLoadRootGraphAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeAuto, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } +} + +func TestMetaFilesFallsBackToGoFiles(t *testing.T) { + meta := &packageMeta{ + GoFiles: []string{"a.go", "b.go"}, + } + got := metaFiles(meta) + if len(got) != 2 || got[0] != "a.go" || got[1] != "b.go" { + t.Fatalf("metaFiles(go-only) = %v, want GoFiles fallback", got) + } + + meta.CompiledGoFiles = []string{"c.go"} + got = metaFiles(meta) + if len(got) != 1 || got[0] != "c.go" { + t.Fatalf("metaFiles(compiled) = %v, want CompiledGoFiles", got) + } +} + +func TestLoadTypedPackageGraphFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nfunc Value() int { return 42 }\n") + + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonForcedFallback) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if parseCalls == 0 { + t.Fatal("expected ParseFile hook to be used") + } +} + +func TestLoadTypedPackageGraphCustom(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + rootPkg := got.Packages[0] + if rootPkg.Types == nil || rootPkg.TypesInfo == nil || len(rootPkg.Syntax) == 0 { + t.Fatalf("root package missing typed syntax: %+v", rootPkg) + } + depPkg := rootPkg.Imports["example.com/app/dep"] + if depPkg == nil || depPkg.Types == nil || len(depPkg.Syntax) == 0 { + t.Fatalf("dep package missing typed syntax: %+v", depPkg) + } + if parseCalls < 2 { + t.Fatalf("parseCalls = %d, want at least 2", parseCalls) + } +} + +func TestLoadTypedPackageGraphAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeAuto, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } +} + +func TestLoadTypedPackageGraphCustomKeepsExternalPackagesLight(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + rootPkg := got.Packages[0] + fmtPkg := rootPkg.Imports["fmt"] + if fmtPkg == nil { + t.Fatal("expected fmt import package") + } + if fmtPkg.Types == nil { + t.Fatalf("fmt package missing types: %+v", fmtPkg) + } + if fmtPkg.TypesInfo != nil { + t.Fatalf("fmt package TypesInfo should be nil, got %+v", fmtPkg.TypesInfo) + } + if len(fmtPkg.Syntax) != 0 { + t.Fatalf("fmt package Syntax len = %d, want 0", len(fmtPkg.Syntax)) + } +} + +func TestLoadRootGraphCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + + l := New() + custom, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeCustom, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(custom) error = %v", err) + } + fallback, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeFallback, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, false) +} + +func TestLoadTypedPackageGraphCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + l := New() + custom, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + fallback, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomMatchesFallbackTypeErrors(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nfunc Broken() int { return missing }\n") + + l := New() + custom, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + fallback, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func comparePackageGraphs(t *testing.T, got []*packages.Package, want []*packages.Package, requireTyped bool) { + t.Helper() + gotAll := collectGraph(got) + wantAll := collectGraph(want) + if len(gotAll) != len(wantAll) { + t.Fatalf("package graph size = %d, want %d", len(gotAll), len(wantAll)) + } + for path, wantPkg := range wantAll { + gotPkg := gotAll[path] + if gotPkg == nil { + t.Fatalf("missing package %q in custom graph", path) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", path, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", path, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", path, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", path, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", path, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", path, gotTyped, wantTyped) + } + } + } +} + +func compareRootPackagesOnly(t *testing.T, got []*packages.Package, want []*packages.Package, requireTyped bool) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("root package count = %d, want %d", len(got), len(want)) + } + gotByPath := make(map[string]*packages.Package, len(got)) + for _, pkg := range got { + gotByPath[pkg.PkgPath] = pkg + } + for _, wantPkg := range want { + gotPkg := gotByPath[wantPkg.PkgPath] + if gotPkg == nil { + t.Fatalf("missing root package %q", wantPkg.PkgPath) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", wantPkg.PkgPath, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", wantPkg.PkgPath, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", wantPkg.PkgPath, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", wantPkg.PkgPath, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", wantPkg.PkgPath, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", wantPkg.PkgPath, gotTyped, wantTyped) + } + } + } +} + +func collectGraph(roots []*packages.Package) map[string]*packages.Package { + out := make(map[string]*packages.Package) + stack := append([]*packages.Package(nil), roots...) + for len(stack) > 0 { + pkg := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if pkg == nil || out[pkg.PkgPath] != nil { + continue + } + out[pkg.PkgPath] = pkg + for _, imp := range pkg.Imports { + stack = append(stack, imp) + } + } + return out +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + aCopy := append([]string(nil), a...) + bCopy := append([]string(nil), b...) + for i := range aCopy { + aCopy[i] = normalizePathForCompare(aCopy[i]) + } + for i := range bCopy { + bCopy[i] = normalizePathForCompare(bCopy[i]) + } + sort.Strings(aCopy) + sort.Strings(bCopy) + for i := range aCopy { + if aCopy[i] != bCopy[i] { + return false + } + } + return true +} + +func equalImportPaths(a, b map[string]*packages.Package) bool { + return equalStrings(sortedImportPaths(a), sortedImportPaths(b)) +} + +func sortedImportPaths(m map[string]*packages.Package) []string { + out := make([]string, 0, len(m)) + for path := range m { + out = append(out, path) + } + sort.Strings(out) + return out +} + +func normalizePathForCompare(path string) string { + if path == "" { + return "" + } + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return filepath.Clean(path) +} + +func comparableErrors(errs []packages.Error) []string { + seen := make(map[string]struct{}, len(errs)) + out := make([]string, 0, len(errs)) + add := func(value string) { + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + for _, err := range errs { + if strings.HasPrefix(err.Msg, "# ") { + for _, value := range expandSummaryDiagnostics(err.Msg) { + add(value) + } + continue + } + pos := normalizeErrorPos(err.Pos) + add(pos + "|" + err.Msg) + } + sort.Strings(out) + return out +} + +func normalizeErrorPos(pos string) string { + if pos == "" || pos == "-" { + return pos + } + parts := strings.Split(pos, ":") + if len(parts) < 2 { + return shortenComparablePath(normalizePathForCompare(pos)) + } + path := shortenComparablePath(normalizePathForCompare(parts[0])) + return strings.Join(append([]string{path}, parts[1:]...), ":") +} + +func expandSummaryDiagnostics(msg string) []string { + lines := strings.Split(msg, "\n") + out := make([]string, 0, len(lines)) + for _, line := range lines[1:] { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if parts := strings.SplitN(line, ": ", 2); len(parts) == 2 { + pos := normalizeErrorPos(parts[0]) + out = append(out, pos+"|"+parts[1]) + continue + } + out = append(out, line) + } + return out +} + +func shortenComparablePath(path string) string { + path = filepath.Clean(path) + parts := strings.Split(path, string(filepath.Separator)) + if len(parts) >= 2 { + return filepath.Join(parts[len(parts)-2], parts[len(parts)-1]) + } + return path +} + +func writeTestFile(t *testing.T, path string, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("MkdirAll(%q) error = %v", path, err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +} diff --git a/internal/loader/mode.go b/internal/loader/mode.go new file mode 100644 index 0000000..b08710b --- /dev/null +++ b/internal/loader/mode.go @@ -0,0 +1,38 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import "strings" + +const ModeEnvVar = "WIRE_LOADER_MODE" + +func ModeFromEnv(env []string) Mode { + mode := ModeAuto + for _, entry := range env { + name, value, ok := strings.Cut(entry, "=") + if !ok || name != ModeEnvVar { + continue + } + switch strings.ToLower(strings.TrimSpace(value)) { + case string(ModeCustom): + mode = ModeCustom + case string(ModeFallback): + mode = ModeFallback + case "", string(ModeAuto): + mode = ModeAuto + } + } + return mode +} diff --git a/internal/loader/timing.go b/internal/loader/timing.go new file mode 100644 index 0000000..0211f17 --- /dev/null +++ b/internal/loader/timing.go @@ -0,0 +1,41 @@ +package loader + +import ( + "context" + "time" +) + +type timingLogger func(string, time.Duration) + +type timingKey struct{} + +func WithTiming(ctx context.Context, logf func(string, time.Duration)) context.Context { + if logf == nil { + return ctx + } + return context.WithValue(ctx, timingKey{}, timingLogger(logf)) +} + +func timing(ctx context.Context) timingLogger { + if ctx == nil { + return nil + } + if v := ctx.Value(timingKey{}); v != nil { + if t, ok := v.(timingLogger); ok { + return t + } + } + return nil +} + +func logTiming(ctx context.Context, label string, start time.Time) { + if t := timing(ctx); t != nil { + t(label, time.Since(start)) + } +} + +func logDuration(ctx context.Context, label string, d time.Duration) { + if t := timing(ctx); t != nil { + t(label, d) + } +} diff --git a/internal/wire/cache_bypass.go b/internal/wire/cache_bypass.go deleted file mode 100644 index b195eef..0000000 --- a/internal/wire/cache_bypass.go +++ /dev/null @@ -1,17 +0,0 @@ -package wire - -import "context" - -type bypassPackageCacheKey struct{} - -func withBypassPackageCache(ctx context.Context) context.Context { - return context.WithValue(ctx, bypassPackageCacheKey{}, true) -} - -func bypassPackageCache(ctx context.Context) bool { - if ctx == nil { - return false - } - v, _ := ctx.Value(bypassPackageCacheKey{}).(bool) - return v -} diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go deleted file mode 100644 index faf6e62..0000000 --- a/internal/wire/cache_coverage_test.go +++ /dev/null @@ -1,1099 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "errors" - "io/fs" - "os" - "path/filepath" - "sort" - "sync" - "testing" - - "golang.org/x/tools/go/packages" -) - -type cacheHookState struct { - osCreateTemp func(string, string) (*os.File, error) - osMkdirAll func(string, os.FileMode) error - osReadFile func(string) ([]byte, error) - osRemove func(string) error - osRemoveAll func(string) error - osRename func(string, string) error - osStat func(string) (os.FileInfo, error) - osTempDir func() string - jsonMarshal func(any) ([]byte, error) - jsonUnmarshal func([]byte, any) error - extraCachePathsFunc func(string) []string - cacheKeyForPackage func(*packages.Package, *GenerateOptions) (string, error) - detectOutputDir func([]string) (string, error) - buildCacheFiles func([]string) ([]cacheFile, error) - buildCacheFilesFrom func([]cacheFile) ([]cacheFile, error) - rootPackageFiles func(*packages.Package) []string - hashFiles func([]string) (string, error) -} - -var cacheHooksMu sync.Mutex - -func lockCacheHooks(t *testing.T) { - t.Helper() - cacheHooksMu.Lock() - t.Cleanup(func() { - cacheHooksMu.Unlock() - }) -} - -func saveCacheHooks() cacheHookState { - return cacheHookState{ - osCreateTemp: osCreateTemp, - osMkdirAll: osMkdirAll, - osReadFile: osReadFile, - osRemove: osRemove, - osRemoveAll: osRemoveAll, - osRename: osRename, - osStat: osStat, - osTempDir: osTempDir, - jsonMarshal: jsonMarshal, - jsonUnmarshal: jsonUnmarshal, - extraCachePathsFunc: extraCachePathsFunc, - cacheKeyForPackage: cacheKeyForPackageFunc, - detectOutputDir: detectOutputDirFunc, - buildCacheFiles: buildCacheFilesFunc, - buildCacheFilesFrom: buildCacheFilesFromMetaFunc, - rootPackageFiles: rootPackageFilesFunc, - hashFiles: hashFilesFunc, - } -} - -func restoreCacheHooks(state cacheHookState) { - osCreateTemp = state.osCreateTemp - osMkdirAll = state.osMkdirAll - osReadFile = state.osReadFile - osRemove = state.osRemove - osRemoveAll = state.osRemoveAll - osRename = state.osRename - osStat = state.osStat - osTempDir = state.osTempDir - jsonMarshal = state.jsonMarshal - jsonUnmarshal = state.jsonUnmarshal - extraCachePathsFunc = state.extraCachePathsFunc - cacheKeyForPackageFunc = state.cacheKeyForPackage - detectOutputDirFunc = state.detectOutputDir - buildCacheFilesFunc = state.buildCacheFiles - buildCacheFilesFromMetaFunc = state.buildCacheFilesFrom - rootPackageFilesFunc = state.rootPackageFiles - hashFilesFunc = state.hashFiles -} - -func writeTempFile(t *testing.T, dir, name, content string) string { - t.Helper() - path := filepath.Join(dir, name) - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("WriteFile(%s) failed: %v", path, err) - } - return path -} - -func cloneManifest(src *cacheManifest) *cacheManifest { - if src == nil { - return nil - } - dst := *src - if src.Patterns != nil { - dst.Patterns = append([]string(nil), src.Patterns...) - } - if src.ExtraFiles != nil { - dst.ExtraFiles = append([]cacheFile(nil), src.ExtraFiles...) - } - if src.Packages != nil { - dst.Packages = make([]manifestPackage, len(src.Packages)) - for i, pkg := range src.Packages { - dstPkg := pkg - if pkg.Files != nil { - dstPkg.Files = append([]cacheFile(nil), pkg.Files...) - } - if pkg.RootFiles != nil { - dstPkg.RootFiles = append([]cacheFile(nil), pkg.RootFiles...) - } - dst.Packages[i] = dstPkg - } - } - return &dst -} - -func TestCacheStoreReadWrite(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if got := CacheDir(); got == "" { - t.Fatal("expected CacheDir to return a value") - } - - key := "cache-store" - want := []byte("content") - writeCache(key, want) - - got, ok := readCache(key) - if !ok { - t.Fatal("expected cache hit") - } - if !bytes.Equal(got, want) { - t.Fatalf("cache content mismatch: got %q, want %q", got, want) - } - if err := ClearCache(); err != nil { - t.Fatalf("ClearCache failed: %v", err) - } - if _, ok := readCache(key); ok { - t.Fatal("expected cache miss after clear") - } -} - -func TestClearCacheClearsIncrementalSessions(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - sessionA := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") - if sessionA == nil { - t.Fatal("expected incremental session") - } - sessionB := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") - if sessionA != sessionB { - t.Fatal("expected same incremental session before clear") - } - - if err := ClearCache(); err != nil { - t.Fatalf("ClearCache failed: %v", err) - } - - sessionC := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") - if sessionC == nil { - t.Fatal("expected incremental session after clear") - } - if sessionC == sessionA { - t.Fatal("expected ClearCache to drop in-process incremental sessions") - } -} - -func TestCacheStoreReadError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - osReadFile = func(string) ([]byte, error) { - return nil, errors.New("boom") - } - if _, ok := readCache("missing"); ok { - t.Fatal("expected cache miss on read error") - } -} - -func TestCacheStoreWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - t.Run("mkdir", func(t *testing.T) { - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeCache("mkdir", []byte("data")) - }) - - t.Run("create", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { - return nil, errors.New("create") - } - writeCache("create", []byte("data")) - }) - - t.Run("write", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeCache("write", []byte("data")) - }) - - t.Run("rename-exist", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { - return fs.ErrExist - } - writeCache("exist", []byte("data")) - }) - - t.Run("rename", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { - return errors.New("rename") - } - writeCache("rename", []byte("data")) - }) -} - -func TestCacheDirError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - osRemoveAll = func(string) error { return errors.New("remove") } - if err := ClearCache(); err == nil { - t.Fatal("expected ClearCache error") - } -} - -func TestPackageFiles(t *testing.T) { - tempDir := t.TempDir() - rootFile := writeTempFile(t, tempDir, "root.go", "package root\n") - childFile := writeTempFile(t, tempDir, "child.go", "package child\n") - - child := &packages.Package{ - PkgPath: "example.com/child", - CompiledGoFiles: []string{childFile}, - } - root := &packages.Package{ - PkgPath: "example.com/root", - GoFiles: []string{rootFile}, - Imports: map[string]*packages.Package{ - "child": child, - "dup": child, - "nil": nil, - }, - } - got := packageFiles(root) - sort.Strings(got) - if len(got) != 2 { - t.Fatalf("expected 2 files, got %d", len(got)) - } - if got[0] != childFile || got[1] != rootFile { - t.Fatalf("unexpected files: %v", got) - } -} - -func TestCacheKeyEmptyPackage(t *testing.T) { - key, err := cacheKeyForPackage(&packages.Package{PkgPath: "example.com/empty"}, &GenerateOptions{}) - if err != nil { - t.Fatalf("cacheKeyForPackage error: %v", err) - } - if key != "" { - t.Fatalf("expected empty cache key, got %q", key) - } -} - -func TestCacheKeyMetaHit(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - file := writeTempFile(t, tempDir, "hit.go", "package hit\n") - pkg := &packages.Package{ - PkgPath: "example.com/hit", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - metaKey := cacheMetaKey(pkg, opts) - writeCacheMeta(metaKey, meta) - - got, err := cacheKeyForPackage(pkg, opts) - if err != nil { - t.Fatalf("cacheKeyForPackage error: %v", err) - } - if got != contentHash { - t.Fatalf("cache key mismatch: got %q, want %q", got, contentHash) - } -} - -func TestCacheKeyErrorPaths(t *testing.T) { - pkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{filepath.Join(t.TempDir(), "missing.go")}, - } - if _, err := cacheKeyForPackage(pkg, &GenerateOptions{}); err == nil { - t.Fatal("expected cacheKeyForPackage error") - } - if _, err := buildCacheFiles([]string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected buildCacheFiles error") - } - if _, err := contentHashForPaths("example.com/missing", &GenerateOptions{}, []string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected contentHashForPaths error") - } - if _, err := hashFiles([]string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected hashFiles error") - } - if got, err := hashFiles(nil); err != nil || got != "" { - t.Fatalf("expected empty hashFiles result, got %q err=%v", got, err) - } -} - -func TestCacheMetaMatches(t *testing.T) { - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "meta.go", "package meta\n") - pkg := &packages.Package{ - PkgPath: "example.com/meta", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - files := packageFiles(pkg) - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - if !cacheMetaMatches(meta, pkg, opts, files) { - t.Fatal("expected cacheMetaMatches to succeed") - } - badVersion := *meta - badVersion.Version = "nope" - if cacheMetaMatches(&badVersion, pkg, opts, files) { - t.Fatal("expected version mismatch") - } - badPkg := *meta - badPkg.PkgPath = "example.com/other" - if cacheMetaMatches(&badPkg, pkg, opts, files) { - t.Fatal("expected pkg mismatch") - } - badHeader := *meta - badHeader.HeaderHash = "bad" - if cacheMetaMatches(&badHeader, pkg, opts, files) { - t.Fatal("expected header mismatch") - } - shortFiles := *meta - shortFiles.Files = nil - if cacheMetaMatches(&shortFiles, pkg, opts, files) { - t.Fatal("expected file count mismatch") - } - fileMismatch := *meta - fileMismatch.Files = append([]cacheFile(nil), meta.Files...) - fileMismatch.Files[0].Size++ - if cacheMetaMatches(&fileMismatch, pkg, opts, files) { - t.Fatal("expected file metadata mismatch") - } - pkgNoRoot := &packages.Package{PkgPath: pkg.PkgPath} - if cacheMetaMatches(meta, pkgNoRoot, opts, files) { - t.Fatal("expected missing root files") - } - noRootHash := *meta - noRootHash.RootHash = "" - if cacheMetaMatches(&noRootHash, pkg, opts, files) { - t.Fatal("expected empty root hash mismatch") - } - missingRootPkg := &packages.Package{ - PkgPath: "example.com/meta", - GoFiles: []string{filepath.Join(tempDir, "missing.go")}, - } - if cacheMetaMatches(meta, missingRootPkg, opts, files) { - t.Fatal("expected root hash error") - } - badRoot := *meta - badRoot.RootHash = "bad" - if cacheMetaMatches(&badRoot, pkg, opts, files) { - t.Fatal("expected root hash mismatch") - } - emptyContent := *meta - emptyContent.ContentHash = "" - if cacheMetaMatches(&emptyContent, pkg, opts, files) { - t.Fatal("expected empty content hash mismatch") - } - - if cacheMetaMatches(meta, pkg, opts, []string{filepath.Join(tempDir, "missing.go")}) { - t.Fatal("expected buildCacheFiles error") - } -} - -func TestCacheMetaReadWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if _, ok := readCacheMeta("missing"); ok { - t.Fatal("expected cache meta miss") - } - - osReadFile = func(string) ([]byte, error) { - return []byte("{bad json"), nil - } - if _, ok := readCacheMeta("bad-json"); ok { - t.Fatal("expected cache meta miss on invalid json") - } - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeCacheMeta("mkdir", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("marshal") } - writeCacheMeta("marshal", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { return nil, errors.New("create") } - writeCacheMeta("create", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeCacheMeta("write", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { return errors.New("rename") } - writeCacheMeta("rename", &cacheMeta{}) -} - -func TestManifestReadWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if _, ok := readManifest("missing"); ok { - t.Fatal("expected manifest miss") - } - - osReadFile = func(string) ([]byte, error) { - return []byte("{bad json"), nil - } - if _, ok := readManifest("bad-json"); ok { - t.Fatal("expected manifest miss on invalid json") - } - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeManifestFile("mkdir", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("marshal") } - writeManifestFile("marshal", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { return nil, errors.New("create") } - writeManifestFile("create", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeManifestFile("write", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { return errors.New("rename") } - writeManifestFile("rename", &cacheManifest{}) -} - -func TestManifestKeyHelpers(t *testing.T) { - if got := manifestKeyFromManifest(nil); got != "" { - t.Fatalf("expected empty manifest key, got %q", got) - } - env := []string{"A=B"} - opts := &GenerateOptions{ - Tags: "tags", - PrefixOutputFile: "prefix", - Header: []byte("header"), - } - wd := t.TempDir() - patterns := []string{"./a", "./b"} - manifest := &cacheManifest{ - WD: runCacheScope(wd, patterns), - EnvHash: envHash(env), - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - } - got := manifestKeyFromManifest(manifest) - want := manifestKey(wd, env, patterns, opts) - if got != want { - t.Fatalf("manifest key mismatch: got %q, want %q", got, want) - } -} - -func TestReadManifestResultsPaths(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected no manifest") - } - - key := manifestKey(wd, env, patterns, opts) - invalid := &cacheManifest{Version: cacheVersion, WD: wd, EnvHash: "", Packages: nil} - writeManifestFile(key, invalid) - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected invalid manifest miss") - } - - file := writeTempFile(t, wd, "wire.go", "package app\n") - pkg := &packages.Package{ - PkgPath: "example.com/app", - GoFiles: []string{file}, - } - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFiles(rootFiles) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - valid := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - Packages: []manifestPackage{ - { - PkgPath: pkg.PkgPath, - OutputPath: filepath.Join(wd, "wire_gen.go"), - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }, - }, - } - writeManifestFile(key, valid) - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected cache miss without content") - } - writeCache(contentHash, []byte("wire")) - if results, ok := readManifestResults(wd, env, patterns, opts); !ok || len(results) != 1 { - t.Fatalf("expected manifest cache hit, got ok=%v results=%d", ok, len(results)) - } -} - -func TestWriteManifestBranches(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - - writeManifest(wd, env, patterns, opts, nil) - - writeManifest(wd, env, patterns, opts, []*packages.Package{nil}) - - writeManifest(wd, env, patterns, opts, []*packages.Package{{PkgPath: "example.com/empty"}}) - - missingFilePkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{filepath.Join(wd, "missing.go")}, - } - writeManifest(wd, env, patterns, opts, []*packages.Package{missingFilePkg}) - - conflictDir := t.TempDir() - fileA := writeTempFile(t, conflictDir, "a.go", "package a\n") - fileB := writeTempFile(t, t.TempDir(), "b.go", "package b\n") - conflictPkg := &packages.Package{ - PkgPath: "example.com/conflict", - GoFiles: []string{fileA, fileB}, - } - writeManifest(wd, env, patterns, opts, []*packages.Package{conflictPkg}) - - okFile := writeTempFile(t, wd, "ok.go", "package ok\n") - okPkg := &packages.Package{ - PkgPath: "example.com/ok", - GoFiles: []string{okFile}, - } - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "", errors.New("cache key") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "", nil - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "hash", nil - } - detectOutputDirFunc = func([]string) (string, error) { - return "", errors.New("output") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - detectOutputDirFunc = state.detectOutputDir - buildCacheFilesFunc = func([]string) ([]cacheFile, error) { - return nil, errors.New("build") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - call := 0 - buildCacheFilesFunc = func([]string) ([]cacheFile, error) { - call++ - if call > 1 { - return nil, errors.New("root") - } - return []cacheFile{{Path: okFile}}, nil - } - rootPackageFilesFunc = func(*packages.Package) []string { - return []string{okFile} - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - buildCacheFilesFunc = state.buildCacheFiles - hashFilesFunc = func([]string) (string, error) { - return "", errors.New("hash") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - restoreCacheHooks(state) - statCalls := 0 - osStat = func(name string) (os.FileInfo, error) { - statCalls++ - if statCalls > 3 { - return nil, errors.New("stat") - } - return state.osStat(name) - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - readCalls := 0 - osReadFile = func(name string) ([]byte, error) { - readCalls++ - if readCalls > 2 { - return nil, errors.New("read") - } - return state.osReadFile(name) - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) -} - -func TestManifestValidationAndExtras(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - if manifestValid(nil) { - t.Fatal("expected nil manifest invalid") - } - if manifestValid(&cacheManifest{Version: "bad"}) { - t.Fatal("expected version mismatch") - } - if manifestValid(&cacheManifest{Version: cacheVersion}) { - t.Fatal("expected missing env hash") - } - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "valid.go", "package valid\n") - files, err := buildCacheFiles([]string{file}) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles([]string{file}) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - valid := &cacheManifest{ - Version: cacheVersion, - WD: tempDir, - EnvHash: "env", - Packages: []manifestPackage{{PkgPath: "example.com/valid", Files: files, RootFiles: files, ContentHash: "hash", RootHash: rootHash}}, - ExtraFiles: nil, - } - if !manifestValid(valid) { - t.Fatal("expected valid manifest") - } - - invalidExtra := cloneManifest(valid) - invalidExtra.ExtraFiles = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidExtra) { - t.Fatal("expected invalid extra files") - } - - extraMismatch := cloneManifest(valid) - extraMismatch.ExtraFiles = []cacheFile{files[0]} - extraMismatch.ExtraFiles[0].Size++ - if manifestValid(extraMismatch) { - t.Fatal("expected extra file metadata mismatch") - } - - invalidPkg := cloneManifest(valid) - invalidPkg.Packages[0].ContentHash = "" - if manifestValid(invalidPkg) { - t.Fatal("expected invalid content hash") - } - - invalidRoot := cloneManifest(valid) - invalidRoot.Packages[0].RootHash = "" - if manifestValid(invalidRoot) { - t.Fatal("expected invalid root hash") - } - - invalidFiles := cloneManifest(valid) - invalidFiles.Packages[0].Files = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidFiles) { - t.Fatal("expected invalid package files") - } - - fileMismatch := cloneManifest(valid) - fileMismatch.Packages[0].Files = []cacheFile{files[0]} - fileMismatch.Packages[0].Files[0].Size++ - if manifestValid(fileMismatch) { - t.Fatal("expected package file mismatch") - } - - invalidRootFiles := cloneManifest(valid) - invalidRootFiles.Packages[0].RootFiles = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidRootFiles) { - t.Fatal("expected invalid root files") - } - - rootMismatch := cloneManifest(valid) - rootMismatch.Packages[0].RootFiles = []cacheFile{files[0]} - rootMismatch.Packages[0].RootFiles[0].Size++ - if manifestValid(rootMismatch) { - t.Fatal("expected root file mismatch") - } - - emptyRoot := cloneManifest(valid) - emptyRoot.Packages[0].RootFiles = nil - if manifestValid(emptyRoot) { - t.Fatal("expected empty root files") - } - - badHash := cloneManifest(valid) - badHash.Packages[0].RootHash = "bad" - if manifestValid(badHash) { - t.Fatal("expected root hash mismatch") - } - - if _, err := buildCacheFilesFromMeta([]cacheFile{{Path: filepath.Join(tempDir, "missing.go")}}); err == nil { - t.Fatal("expected buildCacheFilesFromMeta error") - } - - extraCachePathsFunc = func(string) []string { - return []string{file, file, filepath.Join(tempDir, "missing.go")} - } - extras := extraCacheFiles(tempDir) - if len(extras) != 1 { - t.Fatalf("expected 1 extra file, got %d", len(extras)) - } - - extraCachePathsFunc = func(string) []string { return nil } - if extras := extraCacheFiles(tempDir); extras != nil { - t.Fatal("expected nil extras") - } - - extraCachePathsFunc = func(string) []string { return []string{file, writeTempFile(t, tempDir, "go.sum", "sum\n")} } - if extras := extraCacheFiles(tempDir); len(extras) < 2 { - t.Fatalf("expected extras to include two files, got %v", extras) - } -} - -func TestExtraCachePaths(t *testing.T) { - tempDir := t.TempDir() - rootMod := writeTempFile(t, tempDir, "go.mod", "module example.com/root\n") - writeTempFile(t, tempDir, "go.sum", "sum\n") - nested := filepath.Join(tempDir, "nested", "dir") - if err := os.MkdirAll(nested, 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - paths := extraCachePaths(nested) - if len(paths) < 2 { - t.Fatalf("expected extra cache paths, got %v", paths) - } - found := false - for _, path := range paths { - if path == rootMod { - found = true - break - } - } - if !found { - t.Fatalf("expected %s in paths: %v", rootMod, paths) - } - if got := sortedStrings(nil); got != nil { - t.Fatal("expected nil for empty sortedStrings") - } - if got := envHash(nil); got != "" { - t.Fatal("expected empty env hash") - } -} - -func TestRootPackageFiles(t *testing.T) { - if rootPackageFiles(nil) != nil { - t.Fatal("expected nil root files for nil package") - } - tempDir := t.TempDir() - compiled := writeTempFile(t, tempDir, "compiled.go", "package compiled\n") - pkg := &packages.Package{ - PkgPath: "example.com/compiled", - CompiledGoFiles: []string{compiled}, - } - got := rootPackageFiles(pkg) - if len(got) != 1 || got[0] != compiled { - t.Fatalf("unexpected compiled files: %v", got) - } -} - -func TestAddExtraCachePath(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "go.mod", "module example.com\n") - var paths []string - seen := make(map[string]struct{}) - addExtraCachePath(&paths, seen, file) - addExtraCachePath(&paths, seen, file) - if len(paths) != 1 { - t.Fatalf("expected 1 path, got %d", len(paths)) - } - addExtraCachePath(&paths, seen, filepath.Join(tempDir, "missing.go")) - if len(paths) != 1 { - t.Fatalf("unexpected extra path append: %v", paths) - } -} - -func TestManifestValidHookBranches(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "hook.go", "package hook\n") - files, err := buildCacheFiles([]string{file}) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles([]string{file}) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - base := &cacheManifest{ - Version: cacheVersion, - WD: tempDir, - EnvHash: "env", - Packages: []manifestPackage{{PkgPath: "example.com/hook", Files: files, RootFiles: files, ContentHash: "hash", RootHash: rootHash}}, - ExtraFiles: []cacheFile{files[0]}, - } - - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == files[0].Path { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(base) { - t.Fatal("expected extra file length mismatch") - } - - restoreCacheHooks(state) - emptyRoot := cloneManifest(base) - emptyRoot.Packages[0].RootFiles = nil - if manifestValid(emptyRoot) { - t.Fatal("expected empty root files") - } - - restoreCacheHooks(state) - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == file { - return nil, errors.New("pkg files") - } - return buildCacheFilesFromMeta(in) - } - noExtra := cloneManifest(base) - noExtra.ExtraFiles = nil - if manifestValid(noExtra) { - t.Fatal("expected pkg files error") - } - - restoreCacheHooks(state) - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == file { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected pkg files length mismatch") - } - - restoreCacheHooks(state) - call := 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return nil, errors.New("root files") - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files error") - } - - restoreCacheHooks(state) - call = 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files length mismatch") - } - - restoreCacheHooks(state) - call = 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return []cacheFile{{Path: file, Size: files[0].Size + 1}}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files mismatch") - } -} diff --git a/internal/wire/cache_generate_test.go b/internal/wire/cache_generate_test.go deleted file mode 100644 index d009f73..0000000 --- a/internal/wire/cache_generate_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "sort" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestGenerateUsesManifestCache(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - file := filepath.Join(wd, "provider.go") - if err := os.WriteFile(file, []byte("package p\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - key := manifestKey(wd, env, patterns, opts) - - pkg := &packages.Package{ - PkgPath: "example.com/p", - GoFiles: []string{file}, - } - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFiles(rootFiles) - if err != nil { - t.Fatalf("buildCacheFiles root error: %v", err) - } - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - - manifest := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - Packages: []manifestPackage{ - { - PkgPath: pkg.PkgPath, - OutputPath: filepath.Join(wd, "wire_gen.go"), - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }, - }, - } - writeManifestFile(key, manifest) - writeCache(contentHash, []byte("wire")) - - results, errs := Generate(context.Background(), wd, env, patterns, opts) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(results) != 1 || string(results[0].Content) != "wire" { - t.Fatalf("unexpected cached results: %+v", results) - } -} diff --git a/internal/wire/cache_key.go b/internal/wire/cache_key.go deleted file mode 100644 index f22c6c0..0000000 --- a/internal/wire/cache_key.go +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "path/filepath" - "runtime" - "sort" - "sync" - - "golang.org/x/tools/go/packages" -) - -// cacheVersion is the schema/version identifier for cache entries. -const cacheVersion = "wire-cache-v3" - -// cacheFile captures file metadata used to validate cached content. -type cacheFile struct { - Path string `json:"path"` - Size int64 `json:"size"` - ModTime int64 `json:"mod_time"` -} - -// cacheMeta tracks inputs and outputs for a single package cache entry. -type cacheMeta struct { - Version string `json:"version"` - PkgPath string `json:"pkg_path"` - Tags string `json:"tags"` - Prefix string `json:"prefix"` - HeaderHash string `json:"header_hash"` - Files []cacheFile `json:"files"` - ContentHash string `json:"content_hash"` - RootHash string `json:"root_hash"` -} - -// cacheKeyForPackage returns the content hash for a package, if cacheable. -func cacheKeyForPackage(pkg *packages.Package, opts *GenerateOptions) (string, error) { - files := packageFiles(pkg) - if len(files) == 0 { - return "", nil - } - sort.Strings(files) - metaKey := cacheMetaKey(pkg, opts) - if meta, ok := readCacheMeta(metaKey); ok { - if cacheMetaMatches(meta, pkg, opts, files) { - return meta.ContentHash, nil - } - } - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - return "", err - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - return "", err - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - return "", err - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - writeCacheMeta(metaKey, meta) - return contentHash, nil -} - -// packageFiles returns the transitive Go files for a package graph. -func packageFiles(root *packages.Package) []string { - seen := make(map[string]struct{}) - var files []string - stack := []*packages.Package{root} - for len(stack) > 0 { - p := stack[len(stack)-1] - stack = stack[:len(stack)-1] - if p == nil { - continue - } - if _, ok := seen[p.PkgPath]; ok { - continue - } - seen[p.PkgPath] = struct{}{} - if len(p.CompiledGoFiles) > 0 { - files = append(files, p.CompiledGoFiles...) - } else if len(p.GoFiles) > 0 { - files = append(files, p.GoFiles...) - } - for _, imp := range p.Imports { - stack = append(stack, imp) - } - } - return files -} - -// cacheMetaKey builds the key for a package's cache metadata entry. -func cacheMetaKey(pkg *packages.Package, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(pkg.PkgPath)) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// cacheMetaPath returns the on-disk path for a cache metadata key. -func cacheMetaPath(key string) string { - return filepath.Join(cacheDir(), key+".json") -} - -// readCacheMeta loads a cached metadata entry if it exists. -func readCacheMeta(key string) (*cacheMeta, bool) { - data, err := osReadFile(cacheMetaPath(key)) - if err != nil { - return nil, false - } - var meta cacheMeta - if err := jsonUnmarshal(data, &meta); err != nil { - return nil, false - } - return &meta, true -} - -// writeCacheMeta persists cache metadata to disk. -func writeCacheMeta(key string, meta *cacheMeta) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - data, err := jsonMarshal(meta) - if err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".meta-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - path := cacheMetaPath(key) - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} - -// cacheMetaMatches reports whether metadata matches the current package inputs. -func cacheMetaMatches(meta *cacheMeta, pkg *packages.Package, opts *GenerateOptions, files []string) bool { - if meta.Version != cacheVersion { - return false - } - if meta.PkgPath != pkg.PkgPath || meta.Tags != opts.Tags || meta.Prefix != opts.PrefixOutputFile { - return false - } - if meta.HeaderHash != headerHash(opts.Header) { - return false - } - if len(meta.Files) != len(files) { - return false - } - current, err := buildCacheFiles(files) - if err != nil { - return false - } - for i := range meta.Files { - if meta.Files[i] != current[i] { - return false - } - } - rootFiles := rootPackageFiles(pkg) - if len(rootFiles) == 0 || meta.RootHash == "" { - return false - } - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil || rootHash != meta.RootHash { - return false - } - return meta.ContentHash != "" -} - -// buildCacheFiles converts file paths into cache metadata entries. -func buildCacheFiles(files []string) ([]cacheFile, error) { - return buildCacheFilesWithStats(files, func(path string) (cacheFile, error) { - info, err := osStat(path) - if err != nil { - return cacheFile{}, err - } - return cacheFile{ - Path: filepath.Clean(path), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }, nil - }) -} - -func buildCacheFilesWithStats[T any](items []T, stat func(T) (cacheFile, error)) ([]cacheFile, error) { - if len(items) == 0 { - return nil, nil - } - if len(items) == 1 { - file, err := stat(items[0]) - if err != nil { - return nil, err - } - return []cacheFile{file}, nil - } - out := make([]cacheFile, len(items)) - workers := runtime.GOMAXPROCS(0) - if workers < 4 { - workers = 4 - } - if workers > len(items) { - workers = len(items) - } - var ( - wg sync.WaitGroup - mu sync.Mutex - firstErr error - indexCh = make(chan int, len(items)) - ) - for i := range items { - indexCh <- i - } - close(indexCh) - wg.Add(workers) - for i := 0; i < workers; i++ { - go func() { - defer wg.Done() - for i := range indexCh { - file, err := stat(items[i]) - if err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() - continue - } - out[i] = file - } - }() - } - wg.Wait() - if firstErr != nil { - return nil, firstErr - } - return out, nil -} - -// headerHash returns a stable hash of the generated header content. -func headerHash(header []byte) string { - if len(header) == 0 { - return "" - } - sum := sha256.Sum256(header) - return fmt.Sprintf("%x", sum[:]) -} - -// contentHashForFiles hashes the current package inputs using file paths. -func contentHashForFiles(pkg *packages.Package, opts *GenerateOptions, files []string) (string, error) { - return contentHashForPaths(pkg.PkgPath, opts, files) -} - -// contentHashForPaths hashes the provided file contents and options. -func contentHashForPaths(pkgPath string, opts *GenerateOptions, files []string) (string, error) { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, name := range files { - h.Write([]byte(name)) - h.Write([]byte{0}) - data, err := osReadFile(name) - if err != nil { - return "", err - } - h.Write(data) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} - -// rootPackageFiles returns the direct Go files for the root package. -func rootPackageFiles(pkg *packages.Package) []string { - if pkg == nil { - return nil - } - if len(pkg.CompiledGoFiles) > 0 { - return append([]string(nil), pkg.CompiledGoFiles...) - } - if len(pkg.GoFiles) > 0 { - return append([]string(nil), pkg.GoFiles...) - } - return nil -} - -// hashFiles returns a combined content hash for the provided paths. -func hashFiles(files []string) (string, error) { - if len(files) == 0 { - return "", nil - } - h := sha256.New() - for _, name := range files { - h.Write([]byte(name)) - h.Write([]byte{0}) - data, err := osReadFile(name) - if err != nil { - return "", err - } - h.Write(data) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} diff --git a/internal/wire/cache_manifest.go b/internal/wire/cache_manifest.go deleted file mode 100644 index 57be68b..0000000 --- a/internal/wire/cache_manifest.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "path/filepath" - "sort" - - "golang.org/x/tools/go/packages" -) - -// cacheManifest stores per-run cache metadata for generated packages. -type cacheManifest struct { - Version string `json:"version"` - WD string `json:"wd"` - Tags string `json:"tags"` - Prefix string `json:"prefix"` - HeaderHash string `json:"header_hash"` - EnvHash string `json:"env_hash"` - Patterns []string `json:"patterns"` - Packages []manifestPackage `json:"packages"` - ExtraFiles []cacheFile `json:"extra_files"` -} - -// manifestPackage captures cached output for a single package. -type manifestPackage struct { - PkgPath string `json:"pkg_path"` - OutputPath string `json:"output_path"` - Files []cacheFile `json:"files"` - ContentHash string `json:"content_hash"` - RootFiles []cacheFile `json:"root_files"` - RootHash string `json:"root_hash"` -} - -var extraCachePathsFunc = extraCachePaths - -// readManifestResults loads cached generation results if still valid. -func readManifestResults(wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { - key := manifestKey(wd, env, patterns, opts) - manifest, ok := readManifest(key) - if !ok { - return nil, false - } - if !manifestValid(manifest) { - return nil, false - } - results := make([]GenerateResult, 0, len(manifest.Packages)) - for _, pkg := range manifest.Packages { - content, ok := readCache(pkg.ContentHash) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: pkg.PkgPath, - OutputPath: pkg.OutputPath, - Content: content, - }) - } - return results, true -} - -// writeManifest persists cache metadata for a successful run. -func writeManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) { - if len(pkgs) == 0 { - return - } - key := manifestKey(wd, env, patterns, opts) - scope := runCacheScope(wd, patterns) - manifest := &cacheManifest{ - Version: cacheVersion, - WD: scope, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - } - manifest.ExtraFiles = extraCacheFiles(wd) - for _, pkg := range pkgs { - if pkg == nil { - continue - } - files := packageFiles(pkg) - if len(files) == 0 { - continue - } - sort.Strings(files) - contentHash, err := cacheKeyForPackageFunc(pkg, opts) - if err != nil || contentHash == "" { - continue - } - outDir, err := detectOutputDirFunc(pkg.GoFiles) - if err != nil { - continue - } - outputPath := filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - metaFiles, err := buildCacheFilesFunc(files) - if err != nil { - continue - } - rootFiles := rootPackageFilesFunc(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFilesFunc(rootFiles) - if err != nil { - continue - } - rootHash, err := hashFilesFunc(rootFiles) - if err != nil { - continue - } - manifest.Packages = append(manifest.Packages, manifestPackage{ - PkgPath: pkg.PkgPath, - OutputPath: outputPath, - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }) - } - writeManifestFile(key, manifest) -} - -// manifestKey builds the cache key for a given run configuration. -func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(runCacheScope(wd, patterns))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// manifestKeyFromManifest rebuilds the cache key from stored metadata. -func manifestKeyFromManifest(manifest *cacheManifest) string { - if manifest == nil { - return "" - } - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(manifest.WD))) - h.Write([]byte{0}) - h.Write([]byte(manifest.EnvHash)) - h.Write([]byte{0}) - h.Write([]byte(manifest.Tags)) - h.Write([]byte{0}) - h.Write([]byte(manifest.Prefix)) - h.Write([]byte{0}) - h.Write([]byte(manifest.HeaderHash)) - h.Write([]byte{0}) - for _, p := range sortedStrings(manifest.Patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// readManifest loads the cached manifest by key. -func readManifest(key string) (*cacheManifest, bool) { - data, err := osReadFile(cacheManifestPath(key)) - if err != nil { - return nil, false - } - var manifest cacheManifest - if err := jsonUnmarshal(data, &manifest); err != nil { - return nil, false - } - return &manifest, true -} - -// writeManifestFile writes the manifest to disk. -func writeManifestFile(key string, manifest *cacheManifest) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - data, err := jsonMarshal(manifest) - if err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".manifest-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - path := cacheManifestPath(key) - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} - -// cacheManifestPath returns the on-disk path for a manifest key. -func cacheManifestPath(key string) string { - return filepath.Join(cacheDir(), key+".manifest.json") -} - -// manifestValid reports whether the manifest still matches current inputs. -func manifestValid(manifest *cacheManifest) bool { - if manifest == nil || manifest.Version != cacheVersion { - return false - } - if manifest.EnvHash == "" || len(manifest.Packages) == 0 { - return false - } - if len(manifest.ExtraFiles) > 0 { - current, err := buildCacheFilesFromMetaFunc(manifest.ExtraFiles) - if err != nil { - return false - } - if len(current) != len(manifest.ExtraFiles) { - return false - } - for i := range manifest.ExtraFiles { - if manifest.ExtraFiles[i] != current[i] { - return false - } - } - } - for i := range manifest.Packages { - pkg := manifest.Packages[i] - if pkg.ContentHash == "" { - return false - } - if len(pkg.RootFiles) == 0 || pkg.RootHash == "" { - return false - } - current, err := buildCacheFilesFromMetaFunc(pkg.Files) - if err != nil { - return false - } - if len(current) != len(pkg.Files) { - return false - } - for j := range pkg.Files { - if pkg.Files[j] != current[j] { - return false - } - } - rootCurrent, err := buildCacheFilesFromMetaFunc(pkg.RootFiles) - if err != nil { - return false - } - if len(rootCurrent) != len(pkg.RootFiles) { - return false - } - for j := range pkg.RootFiles { - if pkg.RootFiles[j] != rootCurrent[j] { - return false - } - } - rootPaths := make([]string, 0, len(pkg.RootFiles)) - for _, file := range pkg.RootFiles { - rootPaths = append(rootPaths, file.Path) - } - sort.Strings(rootPaths) - rootHash, err := hashFiles(rootPaths) - if err != nil || rootHash != pkg.RootHash { - return false - } - } - return true -} - -// buildCacheFilesFromMeta re-stats files to compare metadata. -func buildCacheFilesFromMeta(files []cacheFile) ([]cacheFile, error) { - return buildCacheFilesWithStats(files, func(file cacheFile) (cacheFile, error) { - info, err := osStat(file.Path) - if err != nil { - return cacheFile{}, err - } - return cacheFile{ - Path: filepath.Clean(file.Path), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }, nil - }) -} - -// extraCacheFiles returns Go module/workspace files affecting builds. -func extraCacheFiles(wd string) []cacheFile { - paths := extraCachePathsFunc(wd) - if len(paths) == 0 { - return nil - } - out := make([]cacheFile, 0, len(paths)) - seen := make(map[string]struct{}) - for _, path := range paths { - path = filepath.Clean(path) - if _, ok := seen[path]; ok { - continue - } - info, err := osStat(path) - if err != nil { - continue - } - seen[path] = struct{}{} - out = append(out, cacheFile{ - Path: path, - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }) - } - sort.Slice(out, func(i, j int) bool { - return out[i].Path < out[j].Path - }) - return out -} - -// extraCachePaths finds go.mod/go.sum/go.work files for a working dir. -func extraCachePaths(wd string) []string { - var paths []string - dir := filepath.Clean(wd) - seen := make(map[string]struct{}) - for { - for _, name := range []string{"go.work", "go.work.sum", "go.mod", "go.sum"} { - full := filepath.Join(dir, name) - addExtraCachePath(&paths, seen, full) - } - parent := filepath.Dir(dir) - if parent == dir { - break - } - dir = parent - } - return paths -} - -// addExtraCachePath appends an existing file if it has not been seen. -func addExtraCachePath(paths *[]string, seen map[string]struct{}, full string) { - if _, ok := seen[full]; ok { - return - } - if _, err := osStat(full); err != nil { - return - } - *paths = append(*paths, full) - seen[full] = struct{}{} -} - -// sortedStrings returns a sorted copy of the input slice. -func sortedStrings(values []string) []string { - if len(values) == 0 { - return nil - } - out := append([]string(nil), values...) - sort.Strings(out) - return out -} - -// envHash returns a stable hash of environment variables. -func envHash(env []string) string { - if len(env) == 0 { - return "" - } - sorted := sortedStrings(env) - h := sha256.New() - for _, v := range sorted { - h.Write([]byte(v)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} diff --git a/internal/wire/cache_scope.go b/internal/wire/cache_scope.go deleted file mode 100644 index fe161a7..0000000 --- a/internal/wire/cache_scope.go +++ /dev/null @@ -1,69 +0,0 @@ -package wire - -import ( - "path/filepath" - "sort" - "strings" -) - -func packageCacheScope(wd string) string { - if root := findModuleRoot(wd); root != "" { - return filepath.Clean(root) - } - return filepath.Clean(wd) -} - -func runCacheScope(wd string, patterns []string) string { - scopeRoot := packageCacheScope(wd) - normalized := normalizePatternsForScope(wd, scopeRoot, patterns) - if len(normalized) == 0 { - return scopeRoot - } - return scopeRoot + "\n" + strings.Join(normalized, "\n") -} - -func normalizePatternsForScope(wd string, scopeRoot string, patterns []string) []string { - if len(patterns) == 0 { - return nil - } - out := make([]string, 0, len(patterns)) - for _, pattern := range patterns { - out = append(out, normalizePatternForScope(wd, scopeRoot, pattern)) - } - sort.Strings(out) - return out -} - -func normalizePatternForScope(wd string, scopeRoot string, pattern string) string { - if pattern == "" { - return pattern - } - if filepath.IsAbs(pattern) || strings.HasPrefix(pattern, ".") { - abs := pattern - if !filepath.IsAbs(abs) { - abs = filepath.Join(wd, pattern) - } - abs = filepath.Clean(abs) - if scopeRoot != "" { - if rel, ok := pathWithinRoot(scopeRoot, abs); ok { - if rel == "." { - return "." - } - return filepath.ToSlash(rel) - } - } - return filepath.ToSlash(abs) - } - return pattern -} - -func pathWithinRoot(root string, path string) (string, bool) { - rel, err := filepath.Rel(filepath.Clean(root), filepath.Clean(path)) - if err != nil { - return "", false - } - if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { - return "", false - } - return rel, true -} diff --git a/internal/wire/cache_scope_test.go b/internal/wire/cache_scope_test.go deleted file mode 100644 index 9cc518b..0000000 --- a/internal/wire/cache_scope_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package wire - -import ( - "path/filepath" - "testing" -) - -func TestRunScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") - wireDir := filepath.Join(root, "wire") - writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") - - env := []string{"GOOS=darwin"} - opts := &GenerateOptions{Tags: "wireinject", PrefixOutputFile: "gen_"} - - rootKey := manifestKey(root, env, []string{"./wire"}, opts) - subdirKey := manifestKey(wireDir, env, []string{"."}, opts) - if rootKey != subdirKey { - t.Fatalf("manifestKey mismatch: root=%q subdir=%q", rootKey, subdirKey) - } - - rootIncrementalKey := incrementalManifestSelectorKey(root, env, []string{"./wire"}, opts) - subdirIncrementalKey := incrementalManifestSelectorKey(wireDir, env, []string{"."}, opts) - if rootIncrementalKey != subdirIncrementalKey { - t.Fatalf("incrementalManifestSelectorKey mismatch: root=%q subdir=%q", rootIncrementalKey, subdirIncrementalKey) - } -} - -func TestPackageScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") - wireDir := filepath.Join(root, "wire") - writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") - - rootFingerprintKey := incrementalFingerprintKey(root, "wireinject", "example.com/app/wire") - subdirFingerprintKey := incrementalFingerprintKey(wireDir, "wireinject", "example.com/app/wire") - if rootFingerprintKey != subdirFingerprintKey { - t.Fatalf("incrementalFingerprintKey mismatch: root=%q subdir=%q", rootFingerprintKey, subdirFingerprintKey) - } - - rootSummaryKey := incrementalSummaryKey(root, "wireinject", "example.com/app/wire") - subdirSummaryKey := incrementalSummaryKey(wireDir, "wireinject", "example.com/app/wire") - if rootSummaryKey != subdirSummaryKey { - t.Fatalf("incrementalSummaryKey mismatch: root=%q subdir=%q", rootSummaryKey, subdirSummaryKey) - } - - rootGraphKey := incrementalGraphKey(root, "wireinject", []string{"example.com/app/wire"}) - subdirGraphKey := incrementalGraphKey(wireDir, "wireinject", []string{"example.com/app/wire"}) - if rootGraphKey != subdirGraphKey { - t.Fatalf("incrementalGraphKey mismatch: root=%q subdir=%q", rootGraphKey, subdirGraphKey) - } - - rootSessionKey := sessionKey(root, []string{"GOOS=darwin"}, "wireinject") - subdirSessionKey := sessionKey(wireDir, []string{"GOOS=darwin"}, "wireinject") - if rootSessionKey != subdirSessionKey { - t.Fatalf("sessionKey mismatch: root=%q subdir=%q", rootSessionKey, subdirSessionKey) - } -} diff --git a/internal/wire/cache_store.go b/internal/wire/cache_store.go deleted file mode 100644 index 0c959cf..0000000 --- a/internal/wire/cache_store.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "errors" - "io/fs" - "path/filepath" -) - -// cacheDir returns the base directory for Wire cache files. -func cacheDir() string { - return filepath.Join(osTempDir(), "wire-cache") -} - -// CacheDir returns the directory used for Wire's cache. -func CacheDir() string { - return cacheDir() -} - -// ClearCache removes all cached data. -func ClearCache() error { - clearIncrementalSessions() - return osRemoveAll(cacheDir()) -} - -// cachePath builds the on-disk path for a cached content hash. -func cachePath(key string) string { - return filepath.Join(cacheDir(), key+".bin") -} - -// readCache reads a cached content blob by key. -func readCache(key string) ([]byte, bool) { - data, err := osReadFile(cachePath(key)) - if err != nil { - return nil, false - } - return data, true -} - -// writeCache persists a content blob for the provided cache key. -func writeCache(key string, content []byte) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - path := cachePath(key) - tmp, err := osCreateTemp(dir, key+".tmp-") - if err != nil { - return - } - _, writeErr := tmp.Write(content) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), path); err != nil { - if errors.Is(err, fs.ErrExist) { - osRemove(tmp.Name()) - return - } - osRemove(tmp.Name()) - } -} diff --git a/internal/wire/cache_test.go b/internal/wire/cache_test.go deleted file mode 100644 index 6ffb20a..0000000 --- a/internal/wire/cache_test.go +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestCacheInvalidation(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - depPath := filepath.Join(root, "dep", "dep.go") - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - first, errs := Generate(ctx, root, env, []string{"./app"}, opts) - if len(errs) > 0 { - t.Fatalf("first Generate errors: %v", errs) - } - if len(first) != 1 || len(first[0].Content) == 0 { - t.Fatalf("first Generate returned unexpected result: %+v", first) - } - - pkgs, _, errs := load(ctx, root, env, opts.Tags, []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("load failed: %v", errs) - } - key, err := cacheKeyForPackage(pkgs[0], opts) - if err != nil { - t.Fatalf("cacheKeyForPackage failed: %v", err) - } - if cached, ok := readCache(key); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after first Generate") - } - - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"goodbye\"", - "}", - "", - }, "\n")) - - second, errs := Generate(ctx, root, env, []string{"./app"}, opts) - if len(errs) > 0 { - t.Fatalf("second Generate errors: %v", errs) - } - if len(second) != 1 || len(second[0].Content) == 0 { - t.Fatalf("second Generate returned unexpected result: %+v", second) - } - pkgs, _, errs = load(ctx, root, env, opts.Tags, []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("reload failed: %v", errs) - } - key2, err := cacheKeyForPackage(pkgs[0], opts) - if err != nil { - t.Fatalf("cacheKeyForPackage after update failed: %v", err) - } - if key2 == key { - t.Fatal("expected cache key to change after source update") - } - if !IncrementalEnabled(ctx, env) { - if cached, ok := readCache(key2); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after second Generate") - } - } -} - -func TestManifestInvalidation(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - depPath := filepath.Join(root, "dep", "dep.go") - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"goodbye\"", - "}", - "", - }, "\n")) - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after source update") - } -} - -func TestManifestInvalidationGoMod(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - goModPath := filepath.Join(root, "go.mod") - writeFile(t, goModPath, strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - writeFile(t, goModPath, strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0 // updated", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after go.mod update") - } -} - -func TestManifestInvalidationSameTimestamp(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - wirePath := filepath.Join(root, "app", "wire.go") - writeFile(t, wirePath, strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - info, err := os.Stat(wirePath) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - originalMod := info.ModTime() - - original, err := os.ReadFile(wirePath) - if err != nil { - t.Fatalf("ReadFile failed: %v", err) - } - updated := strings.Replace(string(original), "ProvideMessage", "ProvideMassage", 1) - if len(updated) != len(original) { - t.Fatalf("expected updated content to keep length; got %d vs %d", len(updated), len(original)) - } - if err := os.WriteFile(wirePath, []byte(updated), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - if err := os.Chtimes(wirePath, originalMod, originalMod); err != nil { - t.Fatalf("Chtimes failed: %v", err) - } - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after same-timestamp content update") - } -} diff --git a/internal/wire/generate_package.go b/internal/wire/generate_package.go deleted file mode 100644 index 01d3d20..0000000 --- a/internal/wire/generate_package.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "errors" - "fmt" - "go/format" - "path/filepath" - "time" - - "golang.org/x/tools/go/packages" -) - -// generateForPackage runs Wire code generation for a single package. -func generateForPackage(ctx context.Context, pkg *packages.Package, loader *lazyLoader, opts *GenerateOptions) GenerateResult { - if opts == nil { - opts = &GenerateOptions{} - } - pkgStart := time.Now() - res := GenerateResult{ - PkgPath: pkg.PkgPath, - } - dirStart := time.Now() - outDir, err := detectOutputDir(pkg.GoFiles) - logTiming(ctx, "generate.package."+pkg.PkgPath+".output_dir", dirStart) - if err != nil { - res.Errs = append(res.Errs, err) - return res - } - res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - cacheKey, err := cacheKeyForPackage(pkg, opts) - if err != nil { - res.Errs = append(res.Errs, err) - return res - } - if cacheKey != "" && !bypassPackageCache(ctx) { - cacheHitStart := time.Now() - if cached, ok := readCache(cacheKey); ok { - res.Content = cached - logTiming(ctx, "generate.package."+pkg.PkgPath+".cache_hit", cacheHitStart) - logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) - return res - } - } - oc := newObjectCache([]*packages.Package{pkg}, loader) - if loaded, errs := oc.ensurePackage(pkg.PkgPath); len(errs) > 0 { - res.Errs = append(res.Errs, errs...) - return res - } else if loaded != nil { - pkg = loaded - } - g := newGen(pkg) - injectorStart := time.Now() - injectorFiles, errs := generateInjectors(oc, g, pkg) - logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) - if len(errs) > 0 { - res.Errs = errs - return res - } - copyStart := time.Now() - copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) - logTiming(ctx, "generate.package."+pkg.PkgPath+".copy_non_injectors", copyStart) - frameStart := time.Now() - goSrc := g.frame(opts.Tags) - logTiming(ctx, "generate.package."+pkg.PkgPath+".frame", frameStart) - if len(opts.Header) > 0 { - goSrc = append(opts.Header, goSrc...) - } - formatStart := time.Now() - fmtSrc, err := format.Source(goSrc) - logTiming(ctx, "generate.package."+pkg.PkgPath+".format", formatStart) - if err != nil { - // This is likely a bug from a poorly generated source file. - // Add an error but also the unformatted source. - res.Errs = append(res.Errs, err) - } else { - goSrc = fmtSrc - } - res.Content = goSrc - if cacheKey != "" && len(res.Errs) == 0 { - writeCache(cacheKey, res.Content) - } - logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) - return res -} - -// allGeneratedOK reports whether every package result succeeded. -func allGeneratedOK(results []GenerateResult) bool { - if len(results) == 0 { - return false - } - for _, res := range results { - if len(res.Errs) > 0 { - return false - } - } - return true -} - -// detectOutputDir returns a shared directory for the provided file paths. -func detectOutputDir(paths []string) (string, error) { - if len(paths) == 0 { - return "", errors.New("no files to derive output directory from") - } - dir := filepath.Dir(paths[0]) - for _, p := range paths[1:] { - if dir2 := filepath.Dir(p); dir2 != dir { - return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) - } - } - return dir, nil -} diff --git a/internal/wire/generate_package_test.go b/internal/wire/generate_package_test.go deleted file mode 100644 index 51b15b1..0000000 --- a/internal/wire/generate_package_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestGenerateForPackageOptionAndDetectErrors(t *testing.T) { - res := generateForPackage(context.Background(), &packages.Package{PkgPath: "example.com/empty"}, nil, nil) - if len(res.Errs) == 0 { - t.Fatal("expected error for empty package") - } - if _, err := detectOutputDir(nil); err == nil { - t.Fatal("expected detectOutputDir error") - } -} - -func TestGenerateForPackageCacheKeyError(t *testing.T) { - tempDir := t.TempDir() - missing := filepath.Join(tempDir, "missing.go") - pkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{missing}, - } - res := generateForPackage(context.Background(), pkg, nil, &GenerateOptions{}) - if len(res.Errs) == 0 { - t.Fatal("expected cache key error") - } -} - -func TestGenerateForPackageCacheHit(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - file := writeTempFile(t, tempDir, "hit.go", "package hit\n") - pkg := &packages.Package{ - PkgPath: "example.com/hit", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - key, err := cacheKeyForPackage(pkg, opts) - if err != nil || key == "" { - t.Fatalf("cacheKeyForPackage failed: %v", err) - } - writeCache(key, []byte("cached")) - res := generateForPackage(context.Background(), pkg, nil, opts) - if string(res.Content) != "cached" { - t.Fatalf("expected cached content, got %q", res.Content) - } -} - -func TestGenerateForPackageFormatError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - repoRoot := mustRepoRoot(t) - writeTempFile(t, tempDir, "go.mod", strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - appDir := filepath.Join(tempDir, "app") - if err := os.MkdirAll(appDir, 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - writeTempFile(t, appDir, "wire.go", strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import \"github.com/goforj/wire\"", - "", - "func Init() string {", - "\twire.Build(NewMessage)", - "\treturn \"\"", - "}", - "", - "func NewMessage() string { return \"ok\" }", - "", - }, "\n")) - - ctx := context.Background() - env := append(os.Environ(), "GOWORK=off") - pkgs, loader, errs := load(ctx, tempDir, env, "", []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("load errors: %v", errs) - } - opts := &GenerateOptions{Header: []byte("invalid")} - res := generateForPackage(ctx, pkgs[0], loader, opts) - if len(res.Errs) == 0 { - t.Fatal("expected format.Source error") - } -} - -func TestAllGeneratedOK(t *testing.T) { - if allGeneratedOK(nil) { - t.Fatal("expected empty results to be false") - } - if allGeneratedOK([]GenerateResult{{Errs: []error{context.DeadlineExceeded}}}) { - t.Fatal("expected errors to be false") - } - if !allGeneratedOK([]GenerateResult{{}}) { - t.Fatal("expected success results to be true") - } -} diff --git a/internal/wire/incremental.go b/internal/wire/incremental.go deleted file mode 100644 index 007027b..0000000 --- a/internal/wire/incremental.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "strconv" - "strings" -) - -const IncrementalEnvVar = "WIRE_INCREMENTAL" - -type incrementalKey struct{} -type incrementalColdBootstrapKey struct{} - -// WithIncremental overrides incremental-mode resolution for the provided -// context. This takes precedence over the environment variable. -func WithIncremental(ctx context.Context, enabled bool) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, incrementalKey{}, enabled) -} - -func withIncrementalColdBootstrap(ctx context.Context, enabled bool) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, incrementalColdBootstrapKey{}, enabled) -} - -// IncrementalEnabled reports whether incremental mode is enabled for the -// current operation. A context override takes precedence over env. -func IncrementalEnabled(ctx context.Context, env []string) bool { - if ctx != nil { - if v := ctx.Value(incrementalKey{}); v != nil { - if enabled, ok := v.(bool); ok { - return enabled - } - } - } - raw, ok := lookupEnv(env, IncrementalEnvVar) - if !ok { - return false - } - enabled, err := strconv.ParseBool(strings.TrimSpace(raw)) - if err != nil { - return false - } - return enabled -} - -func incrementalColdBootstrapEnabled(ctx context.Context) bool { - if ctx == nil { - return false - } - if v := ctx.Value(incrementalColdBootstrapKey{}); v != nil { - if enabled, ok := v.(bool); ok { - return enabled - } - } - return false -} - -func lookupEnv(env []string, key string) (string, bool) { - prefix := key + "=" - for i := len(env) - 1; i >= 0; i-- { - if strings.HasPrefix(env[i], prefix) { - return strings.TrimPrefix(env[i], prefix), true - } - } - return "", false -} diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go deleted file mode 100644 index b300c4f..0000000 --- a/internal/wire/incremental_bench_test.go +++ /dev/null @@ -1,1495 +0,0 @@ -package wire - -import ( - "context" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - "testing" - "time" - "unicode/utf8" -) - -const ( - largeBenchmarkTestPackageCount = 24 - largeBenchmarkHelperCount = 12 -) - -var largeBenchmarkSizes = []int{10, 100, 1000} - -type incrementalScenarioBenchmarkCase struct { - name string - mutate func(tb testing.TB, root string) - measure func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace - wantErr bool -} - -type incrementalScenarioTrace struct { - total time.Duration - labels map[string]time.Duration -} - -type incrementalScenarioBudget struct { - total time.Duration - validateLocal time.Duration - validateExt time.Duration - validateTouch time.Duration - validateTouchHit time.Duration - outputs time.Duration - generateLoad time.Duration - localFastpath time.Duration -} - -type largeRepoPerformanceBudget struct { - shapeTotal time.Duration - localLoad time.Duration - parse time.Duration - typecheck time.Duration - generate time.Duration - knownToggle time.Duration -} - -func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { - cacheHooksMu.Lock() - state := saveCacheHooks() - b.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(b) - - for i := 0; i < b.N; i++ { - cacheRoot := b.TempDir() - osTempDir = func() string { return cacheRoot } - - root := b.TempDir() - writeIncrementalBenchmarkModule(b, repoRoot, root) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - b.Fatalf("baseline Generate returned errors: %v", errs) - } - - writeBenchmarkFile(b, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeBenchmarkFile(b, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - b.StartTimer() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - b.StopTimer() - if len(errs) > 0 { - b.Fatalf("incremental shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - b.Fatalf("unexpected Generate results: %+v", gens) - } - } -} - -func BenchmarkGenerateIncrementalScenarioMatrix(b *testing.B) { - cacheHooksMu.Lock() - state := saveCacheHooks() - b.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(b) - for _, scenario := range incrementalScenarioBenchmarks() { - scenario := scenario - b.Run(scenario.name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StartTimer() - _ = measureIncrementalScenarioOnce(b, repoRoot, scenario) - b.StopTimer() - } - }) - } -} - -func TestPrintIncrementalScenarioBenchmarkTable(t *testing.T) { - if os.Getenv("WIRE_BENCH_SCENARIOS") == "" { - t.Skip("set WIRE_BENCH_SCENARIOS=1 to print the incremental scenario benchmark table") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - rows := [][]string{{ - "scenario", - "total", - "local pkgs", - "external", - "touched", - "touch hit", - "outputs", - "gen load", - "local fastpath", - }} - for _, scenario := range incrementalScenarioBenchmarks() { - trace := measureIncrementalScenarioOnce(t, repoRoot, scenario) - rows = append(rows, []string{ - scenario.name, - formatBenchmarkDuration(trace.total), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_local_packages")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_external_files")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched_cache_hit")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.outputs")), - formatBenchmarkDuration(trace.label("generate.load")), - formatBenchmarkDuration(trace.label("incremental.local_fastpath.load")), - }) - } - fmt.Print(renderASCIITable(rows)) -} - -func TestIncrementalScenarioPerformanceBudgets(t *testing.T) { - if os.Getenv("WIRE_PERF_BUDGETS") == "" { - t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - budgets := incrementalScenarioPerformanceBudgets() - for _, scenario := range incrementalScenarioBenchmarks() { - scenario := scenario - budget, ok := budgets[scenario.name] - if !ok { - t.Fatalf("missing performance budget for scenario %q", scenario.name) - } - t.Run(scenario.name, func(t *testing.T) { - trace := measureIncrementalScenarioMedian(t, repoRoot, scenario, 5) - assertScenarioBudget(t, trace, budget) - }) - } -} - -func TestLargeRepoPerformanceBudgets(t *testing.T) { - if os.Getenv("WIRE_PERF_BUDGETS") == "" { - t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - budgets := largeRepoPerformanceBudgets() - for _, packageCount := range largeBenchmarkSizes { - packageCount := packageCount - budget, ok := budgets[packageCount] - if !ok { - t.Fatalf("missing large-repo performance budget for size %d", packageCount) - } - t.Run(strconv.Itoa(packageCount), func(t *testing.T) { - trace := measureLargeRepoShapeChangeTraceMedian(t, repoRoot, packageCount, true, 3) - checkBudgetDuration(t, "shape_total", trace.total, budget.shapeTotal) - checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localLoad) - checkBudgetDuration(t, "parse", trace.label("incremental.local_fastpath.parse"), budget.parse) - checkBudgetDuration(t, "typecheck", trace.label("incremental.local_fastpath.typecheck"), budget.typecheck) - checkBudgetDuration(t, "generate", trace.label("incremental.local_fastpath.generate"), budget.generate) - - knownToggle := measureLargeRepoKnownToggleMedian(t, repoRoot, packageCount, 3) - checkBudgetDuration(t, "known_toggle", knownToggle, budget.knownToggle) - }) - } -} - -func BenchmarkGenerateLargeRepoNormalShapeChange(b *testing.B) { - runLargeRepoShapeChangeBenchmarks(b, false) -} - -func BenchmarkGenerateLargeRepoIncrementalShapeChange(b *testing.B) { - runLargeRepoShapeChangeBenchmarks(b, true) -} - -func TestPrintLargeRepoBenchmarkComparisonTable(t *testing.T) { - if os.Getenv("WIRE_BENCH_TABLE") == "" { - t.Skip("set WIRE_BENCH_TABLE=1 to print the large-repo benchmark comparison table") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - rows := make([]largeRepoBenchmarkRow, 0, len(largeBenchmarkSizes)) - for _, packageCount := range largeBenchmarkSizes { - coldNormal := measureLargeRepoColdOnce(t, repoRoot, packageCount, false) - coldIncremental := measureLargeRepoColdOnce(t, repoRoot, packageCount, true) - normal := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, false) - incremental := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, true) - knownToggle := measureLargeRepoKnownToggleOnce(t, repoRoot, packageCount) - rows = append(rows, largeRepoBenchmarkRow{ - packageCount: packageCount, - coldNormal: coldNormal, - coldIncremental: coldIncremental, - normal: normal, - incremental: incremental, - knownToggle: knownToggle, - }) - } - - table := [][]string{{ - "repo size", - "cold old", - "cold new", - "cold delta", - "shape old", - "shape new", - "shape delta", - "known toggle", - "cold speedup", - "shape speedup", - }} - for _, row := range rows { - table = append(table, []string{ - strconv.Itoa(row.packageCount), - formatBenchmarkDuration(row.coldNormal), - formatBenchmarkDuration(row.coldIncremental), - formatPercentImprovement(row.coldNormal, row.coldIncremental), - formatBenchmarkDuration(row.normal), - formatBenchmarkDuration(row.incremental), - formatPercentImprovement(row.normal, row.incremental), - formatBenchmarkDuration(row.knownToggle), - fmt.Sprintf("%.2fx", speedupRatio(row.coldNormal, row.coldIncremental)), - fmt.Sprintf("%.2fx", speedupRatio(row.normal, row.incremental)), - }) - } - fmt.Print(renderASCIITable(table)) -} - -func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { - if os.Getenv("WIRE_BENCH_BREAKDOWN") == "" { - t.Skip("set WIRE_BENCH_BREAKDOWN=1 to print the large-repo shape-change breakdown table") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - rows := [][]string{{ - "repo size", - "old total", - "old base load", - "old typed load", - "new total", - "new local load", - "new parse", - "new typecheck", - "new injector solve", - "new format", - "new generate", - "speedup", - }} - for _, packageCount := range largeBenchmarkSizes { - normal := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, false) - incremental := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, true) - rows = append(rows, []string{ - strconv.Itoa(packageCount), - formatBenchmarkDuration(normal.total), - formatBenchmarkDuration(normal.label("load.packages.base.load")), - formatBenchmarkDuration(normal.label("load.packages.lazy.load")), - formatBenchmarkDuration(incremental.total), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.parse")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.typecheck")), - formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.injectors")), - formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.format")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), - fmt.Sprintf("%.2fx", speedupRatio(normal.total, incremental.total)), - }) - } - fmt.Print(renderASCIITable(rows)) -} - -func writeIncrementalBenchmarkModule(tb testing.TB, repoRoot string, root string) { - tb.Helper() - - writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) -} - -func TestGenerateIncrementalLargeRepoShapeChangeMatchesNormalGenerate(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := benchmarkRepoRoot(t) - root := t.TempDir() - writeLargeBenchmarkModule(t, repoRoot, root, largeBenchmarkTestPackageCount) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(t, root, largeBenchmarkTestPackageCount/2) - - var incrementalLabels []string - incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - incrementalLabels = append(incrementalLabels, label) - }) - incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental large-repo Generate returned errors: %v", errs) - } - if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected incremental results: %+v", incrementalGens) - } - if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected large-repo shape change to use local fast path, labels=%v", incrementalLabels) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal large-repo Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { - t.Fatalf("unexpected normal results: %+v", normalGens) - } - if incrementalGens[0].OutputPath != normalGens[0].OutputPath { - t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) - } - if string(incrementalGens[0].Content) != string(normalGens[0].Content) { - t.Fatal("large-repo shape-changing incremental output differs from normal Generate output") - } -} - -func runLargeRepoShapeChangeBenchmarks(b *testing.B, incremental bool) { - cacheHooksMu.Lock() - state := saveCacheHooks() - b.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(b) - for _, packageCount := range largeBenchmarkSizes { - packageCount := packageCount - b.Run(fmt.Sprintf("size=%d", packageCount), func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StartTimer() - _ = measureLargeRepoShapeChangeOnce(b, repoRoot, packageCount, incremental) - b.StopTimer() - } - }) - } -} - -func incrementalScenarioBenchmarks() []incrementalScenarioBenchmarkCase { - return []incrementalScenarioBenchmarkCase{ - { - name: "preload_unchanged", - mutate: func(testing.TB, string) {}, - }, - { - name: "preload_whitespace_only_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "", - "func New(msg string) *Foo {", - "", - "\treturn &Foo{Message: helper(msg)}", - "", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_body_only_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string {", - "\treturn helper(SQLText)", - "}", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_body_only_repeat_change", - measure: func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace { - writeBodyOnlyScenarioVariant(tb, root, "b") - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("warm changed variant Generate returned errors: %v", errs) - } - writeBodyOnlyScenarioVariant(tb, root, "a") - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("reset variant Generate returned errors: %v", errs) - } - writeBodyOnlyScenarioVariant(tb, root, "b") - trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} - timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { - trace.labels[label] += dur - }) - start := time.Now() - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - trace.total = time.Since(start) - if len(errs) > 0 { - tb.Fatalf("%s: Generate returned errors: %v", "preload_body_only_repeat_change", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("%s: unexpected Generate results: %+v", "preload_body_only_repeat_change", gens) - } - return trace - }, - }, - { - name: "local_fastpath_method_body_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func (f Foo) Summary() string {", - "\treturn helper(f.Message)", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_const_value_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"blue\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_var_initializer_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 2", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_add_top_level_helper", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func NewTag() string { return \"tag\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_import_only_implementation_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "import \"fmt\"", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return fmt.Sprint(msg) }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_signature_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 7", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func NewCount() int { return defaultCount }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: count}", - "}", - "", - }, "\n")) - writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_struct_field_addition", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: defaultCount}", - "}", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_interface_method_addition", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Fooer interface {", - "\tMessage() string", - "\tCount() int", - "}", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "fallback_invalid_body_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return missing }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - wantErr: true, - }, - } -} - -func incrementalScenarioPerformanceBudgets() map[string]incrementalScenarioBudget { - return map[string]incrementalScenarioBudget{ - "preload_unchanged": { - total: 300 * time.Millisecond, - validateLocal: 25 * time.Millisecond, - validateExt: 25 * time.Millisecond, - validateTouch: 5 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_whitespace_only_change": { - total: 300 * time.Millisecond, - validateLocal: 25 * time.Millisecond, - validateExt: 25 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_body_only_change": { - total: 400 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_body_only_repeat_change": { - total: 150 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 5 * time.Millisecond, - validateTouchHit: 5 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "local_fastpath_method_body_change": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "preload_import_only_implementation_change": { - total: 150 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 50 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_const_value_change": { - total: 400 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_var_initializer_change": { - total: 400 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "local_fastpath_add_top_level_helper": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "local_fastpath_signature_change": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "local_fastpath_struct_field_addition": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "local_fastpath_interface_method_addition": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "fallback_invalid_body_change": { - total: 800 * time.Millisecond, - generateLoad: 500 * time.Millisecond, - }, - } -} - -func measureIncrementalScenarioOnce(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase) incrementalScenarioTrace { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeIncrementalScenarioBenchmarkModule(tb, repoRoot, root) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - if scenario.measure != nil { - return scenario.measure(tb, root, env, ctx) - } - - scenario.mutate(tb, root) - - trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} - timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { - trace.labels[label] += dur - }) - start := time.Now() - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - trace.total = time.Since(start) - - if scenario.wantErr { - if len(errs) == 0 { - tb.Fatalf("%s: expected Generate errors", scenario.name) - } - if len(gens) != 0 { - tb.Fatalf("%s: expected no generated results on error, got %+v", scenario.name, gens) - } - return trace - } - - if len(errs) > 0 { - tb.Fatalf("%s: Generate returned errors: %v", scenario.name, errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("%s: unexpected Generate results: %+v", scenario.name, gens) - } - return trace -} - -func writeIncrementalScenarioBenchmarkModule(tb testing.TB, repoRoot string, root string) { - tb.Helper() - - writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeBodyOnlyScenarioVariant(tb, root, "green") -} - -func writeBodyOnlyScenarioVariant(tb testing.TB, root string, value string) { - tb.Helper() - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"" + value + "\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) -} - -func measureIncrementalScenarioMedian(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase, samples int) incrementalScenarioTrace { - tb.Helper() - if samples <= 0 { - samples = 1 - } - traces := make([]incrementalScenarioTrace, 0, samples) - for i := 0; i < samples; i++ { - traces = append(traces, measureIncrementalScenarioOnce(tb, repoRoot, scenario)) - } - sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) - return traces[len(traces)/2] -} - -func assertScenarioBudget(t *testing.T, trace incrementalScenarioTrace, budget incrementalScenarioBudget) { - t.Helper() - checkBudgetDuration(t, "total", trace.total, budget.total) - checkBudgetDuration(t, "validate_local_packages", trace.label("incremental.preload_manifest.validate_local_packages"), budget.validateLocal) - checkBudgetDuration(t, "validate_external_files", trace.label("incremental.preload_manifest.validate_external_files"), budget.validateExt) - checkBudgetDuration(t, "validate_touched", trace.label("incremental.preload_manifest.validate_touched"), budget.validateTouch) - checkBudgetDuration(t, "validate_touched_cache_hit", trace.label("incremental.preload_manifest.validate_touched_cache_hit"), budget.validateTouchHit) - checkBudgetDuration(t, "outputs", trace.label("incremental.preload_manifest.outputs"), budget.outputs) - checkBudgetDuration(t, "generate_load", trace.label("generate.load"), budget.generateLoad) - checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localFastpath) -} - -func checkBudgetDuration(t *testing.T, name string, got time.Duration, max time.Duration) { - t.Helper() - if max <= 0 { - return - } - if got > max { - t.Fatalf("%s exceeded budget: got=%s max=%s", name, got, max) - } -} - -func (s incrementalScenarioTrace) label(name string) time.Duration { - if s.labels == nil { - return 0 - } - return s.labels[name] -} - -type largeRepoBenchmarkRow struct { - packageCount int - coldNormal time.Duration - coldIncremental time.Duration - normal time.Duration - incremental time.Duration - knownToggle time.Duration -} - -type shapeChangeTrace struct { - total time.Duration - labels map[string]time.Duration -} - -func largeRepoPerformanceBudgets() map[int]largeRepoPerformanceBudget { - return map[int]largeRepoPerformanceBudget{ - 10: { - shapeTotal: 45 * time.Millisecond, - localLoad: 3 * time.Millisecond, - parse: 500 * time.Microsecond, - typecheck: 4 * time.Millisecond, - generate: 3 * time.Millisecond, - knownToggle: 3 * time.Millisecond, - }, - 100: { - shapeTotal: 35 * time.Millisecond, - localLoad: 20 * time.Millisecond, - parse: 1500 * time.Microsecond, - typecheck: 12 * time.Millisecond, - generate: 20 * time.Millisecond, - knownToggle: 15 * time.Millisecond, - }, - 1000: { - shapeTotal: 260 * time.Millisecond, - localLoad: 110 * time.Millisecond, - parse: 4 * time.Millisecond, - typecheck: 70 * time.Millisecond, - generate: 180 * time.Millisecond, - knownToggle: 90 * time.Millisecond, - }, - } -} - -func measureLargeRepoShapeChangeOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - if incremental { - ctx = WithIncremental(ctx, true) - } - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(tb, root, packageCount/2) - - start := time.Now() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - dur := time.Since(start) - if len(errs) > 0 { - tb.Fatalf("shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected Generate results: %+v", gens) - } - return dur -} - -func measureLargeRepoShapeChangeTraceOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) shapeChangeTrace { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - if incremental { - ctx = WithIncremental(ctx, true) - } - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(tb, root, packageCount/2) - - trace := shapeChangeTrace{labels: make(map[string]time.Duration)} - ctx = WithTiming(ctx, func(label string, dur time.Duration) { - trace.labels[label] += dur - }) - start := time.Now() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - trace.total = time.Since(start) - if len(errs) > 0 { - tb.Fatalf("shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected Generate results: %+v", gens) - } - return trace -} - -func measureLargeRepoShapeChangeTraceMedian(tb testing.TB, repoRoot string, packageCount int, incremental bool, samples int) shapeChangeTrace { - tb.Helper() - if samples <= 0 { - samples = 1 - } - traces := make([]shapeChangeTrace, 0, samples) - for i := 0; i < samples; i++ { - traces = append(traces, measureLargeRepoShapeChangeTraceOnce(tb, repoRoot, packageCount, incremental)) - } - sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) - return traces[len(traces)/2] -} - -func (s shapeChangeTrace) label(name string) time.Duration { - if s.labels == nil { - return 0 - } - return s.labels[name] -} - -func measureLargeRepoColdOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - if incremental { - ctx = WithIncremental(ctx, true) - } - - start := time.Now() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - dur := time.Since(start) - if len(errs) > 0 { - tb.Fatalf("cold Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected cold Generate results: %+v", gens) - } - return dur -} - -func measureLargeRepoKnownToggleOnce(tb testing.TB, repoRoot string, packageCount int) time.Duration { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - mutatedIndex := packageCount / 2 - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(tb, root, mutatedIndex) - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - tb.Fatalf("mutated Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected mutated Generate results: %+v", gens) - } - - writeLargeBenchmarkPackage(tb, root, mutatedIndex, false) - - start := time.Now() - gens, errs = Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - dur := time.Since(start) - if len(errs) > 0 { - tb.Fatalf("toggle Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected toggle Generate results: %+v", gens) - } - return dur -} - -func measureLargeRepoKnownToggleMedian(tb testing.TB, repoRoot string, packageCount int, samples int) time.Duration { - tb.Helper() - if samples <= 0 { - samples = 1 - } - values := make([]time.Duration, 0, samples) - for i := 0; i < samples; i++ { - values = append(values, measureLargeRepoKnownToggleOnce(tb, repoRoot, packageCount)) - } - sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) - return values[len(values)/2] -} - -func formatPercentImprovement(normal time.Duration, incremental time.Duration) string { - if normal <= 0 { - return "0.0%" - } - improvement := 100 * (float64(normal-incremental) / float64(normal)) - return fmt.Sprintf("%.1f%%", improvement) -} - -func speedupRatio(normal time.Duration, incremental time.Duration) float64 { - if incremental <= 0 { - return 0 - } - return float64(normal) / float64(incremental) -} - -func formatBenchmarkDuration(d time.Duration) string { - switch { - case d >= time.Second: - return fmt.Sprintf("%.2fs", d.Seconds()) - case d >= time.Millisecond: - return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) - case d >= time.Microsecond: - return fmt.Sprintf("%.2fus", float64(d)/float64(time.Microsecond)) - default: - return d.String() - } -} - -func writeLargeBenchmarkModule(tb testing.TB, repoRoot string, root string, packageCount int) { - tb.Helper() - - writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - wireImports := []string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"github.com/goforj/wire\"", - } - appImports := []string{ - "package app", - "", - "import (", - } - buildArgs := []string{"\twire.Build("} - argNames := make([]string, 0, packageCount) - for i := 0; i < packageCount; i++ { - pkgName := fmt.Sprintf("layer%02d", i) - wireImports = append(wireImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) - appImports = append(appImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) - buildArgs = append(buildArgs, fmt.Sprintf("\t\t%s.NewSet,", pkgName)) - argNames = append(argNames, fmt.Sprintf("dep%02d *%s.Token", i, pkgName)) - } - wireImports = append(wireImports, ")", "") - appImports = append(appImports, ")", "") - wireFile := append([]string{}, wireImports...) - wireFile = append(wireFile, "func Init() *App {") - wireFile = append(wireFile, buildArgs...) - wireFile = append(wireFile, "\t\tNewApp,", "\t)", "\treturn nil", "}", "") - writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join(wireFile, "\n")) - - appGo := append(appImports[:len(appImports)-2], // reuse imports without trailing blank line - ")", - "", - "type App struct {", - "\tCount int", - "}", - "", - fmt.Sprintf("func NewApp(%s) *App {", strings.Join(argNames, ", ")), - fmt.Sprintf("\treturn &App{Count: %d}", packageCount), - "}", - "", - ) - writeBenchmarkFile(tb, filepath.Join(root, "app", "app.go"), strings.Join(appGo, "\n")) - - for i := 0; i < packageCount; i++ { - writeLargeBenchmarkPackage(tb, root, i, false) - } -} - -func mutateLargeBenchmarkModule(tb testing.TB, root string, mutatedIndex int) { - tb.Helper() - writeLargeBenchmarkPackage(tb, root, mutatedIndex, true) -} - -func writeLargeBenchmarkPackage(tb testing.TB, root string, index int, mutated bool) { - tb.Helper() - - pkgName := fmt.Sprintf("layer%02d", index) - pkgDir := filepath.Join(root, pkgName) - - writeBenchmarkFile(tb, filepath.Join(pkgDir, "helpers.go"), renderLargeBenchmarkHelpers(pkgName, index, mutated)) - writeBenchmarkFile(tb, filepath.Join(pkgDir, "wire.go"), renderLargeBenchmarkWire(pkgName, mutated)) -} - -func renderLargeBenchmarkHelpers(pkgName string, index int, mutated bool) string { - lines := []string{ - "package " + pkgName, - "", - "import (", - "\t\"fmt\"", - "\t\"strconv\"", - "\t\"strings\"", - ")", - "", - "type Config struct {", - "\tLabel string", - "}", - "", - "type Weight int", - "", - "type Token struct {", - "\tConfig Config", - "\tWeight Weight", - "}", - "", - fmt.Sprintf("func NewConfig() Config { return Config{Label: %q} }", pkgName), - "", - } - if mutated { - lines = append(lines, - fmt.Sprintf("func NewWeight() Weight { return Weight(%d) }", index+100), - "", - "func New(cfg Config, weight Weight) *Token {", - fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), - "\treturn &Token{Config: cfg, Weight: weight}", - "}", - "", - ) - } else { - lines = append(lines, - "func New(cfg Config) *Token {", - fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), - "\treturn &Token{Config: cfg}", - "}", - "", - ) - } - for i := 0; i < largeBenchmarkHelperCount; i++ { - lines = append(lines, fmt.Sprintf("func helper%02d() string {", i)) - lines = append(lines, fmt.Sprintf("\treturn strings.ToUpper(fmt.Sprintf(\"%%s-%%d\", %q, %d)) + strconv.Itoa(%d)", pkgName, i, index+i)) - lines = append(lines, "}", "") - } - return strings.Join(lines, "\n") -} - -func renderLargeBenchmarkWire(pkgName string, mutated bool) string { - lines := []string{ - "package " + pkgName, - "", - "import (", - "\t\"github.com/goforj/wire\"", - ")", - "", - } - if mutated { - lines = append(lines, "var NewSet = wire.NewSet(NewConfig, NewWeight, New)", "") - } else { - lines = append(lines, "var NewSet = wire.NewSet(NewConfig, New)", "") - } - return strings.Join(lines, "\n") -} - -func strconvQuote(s string) string { - return fmt.Sprintf("%q", s) -} - -func benchmarkRepoRoot(tb testing.TB) string { - tb.Helper() - wd, err := os.Getwd() - if err != nil { - tb.Fatalf("Getwd failed: %v", err) - } - repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) - if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { - tb.Fatalf("repo root not found at %s: %v", repoRoot, err) - } - return repoRoot -} - -func writeBenchmarkFile(tb testing.TB, path string, content string) { - tb.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - tb.Fatalf("MkdirAll failed: %v", err) - } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - tb.Fatalf("WriteFile failed: %v", err) - } -} - -func renderASCIITable(rows [][]string) string { - if len(rows) == 0 { - return "" - } - widths := make([]int, len(rows[0])) - for _, row := range rows { - for i, cell := range row { - if width := utf8.RuneCountInString(cell); width > widths[i] { - widths[i] = width - } - } - } - var b strings.Builder - border := func() { - b.WriteByte('+') - for _, width := range widths { - b.WriteString(strings.Repeat("-", width+2)) - b.WriteByte('+') - } - b.WriteByte('\n') - } - writeRow := func(row []string) { - b.WriteByte('|') - for i, cell := range row { - b.WriteByte(' ') - b.WriteString(cell) - b.WriteString(strings.Repeat(" ", widths[i]-utf8.RuneCountInString(cell)+1)) - b.WriteByte('|') - } - b.WriteByte('\n') - } - border() - writeRow(rows[0]) - border() - for _, row := range rows[1:] { - writeRow(row) - } - border() - return b.String() -} diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go deleted file mode 100644 index be39982..0000000 --- a/internal/wire/incremental_fingerprint.go +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "fmt" - "go/ast" - "go/parser" - "go/printer" - "go/token" - "path/filepath" - "sort" - "strings" - - "golang.org/x/tools/go/packages" -) - -const incrementalFingerprintVersion = "wire-incremental-v3" - -type packageFingerprint struct { - Version string - WD string - Tags string - PkgPath string - Files []cacheFile - Dirs []cacheFile - ContentHash string - ShapeHash string - LocalImports []string -} - -type fingerprintStats struct { - localPackages int - metaHits int - metaMisses int - unchanged int - changed int -} - -type incrementalFingerprintSnapshot struct { - stats fingerprintStats - changed []string - touched []string - fingerprints map[string]*packageFingerprint -} - -func analyzeIncrementalFingerprints(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { - if !IncrementalEnabled(ctx, env) { - return nil - } - start := timeNow() - snapshot := collectIncrementalFingerprints(wd, tags, pkgs) - debugf(ctx, "incremental.fingerprint local_pkgs=%d meta_hits=%d meta_misses=%d unchanged=%d changed=%d total=%s", - snapshot.stats.localPackages, - snapshot.stats.metaHits, - snapshot.stats.metaMisses, - snapshot.stats.unchanged, - snapshot.stats.changed, - timeSince(start), - ) - if len(snapshot.changed) > 0 { - debugf(ctx, "incremental.fingerprint changed_pkgs=%s", strings.Join(snapshot.changed, ", ")) - } - return snapshot -} - -func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { - all := collectAllPackages(pkgs) - moduleRoot := findModuleRoot(wd) - snapshot := &incrementalFingerprintSnapshot{ - fingerprints: make(map[string]*packageFingerprint), - } - for _, pkg := range all { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - snapshot.stats.localPackages++ - files := packageFingerprintFiles(pkg) - if len(files) == 0 { - continue - } - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - snapshot.stats.metaMisses++ - continue - } - key := incrementalFingerprintKey(wd, tags, pkg.PkgPath) - if prev, ok := readIncrementalFingerprint(key); ok && incrementalFingerprintMetaMatches(prev, wd, tags, pkg.PkgPath, metaFiles) { - snapshot.stats.metaHits++ - snapshot.stats.unchanged++ - snapshot.fingerprints[pkg.PkgPath] = prev - continue - } - snapshot.stats.metaMisses++ - snapshot.touched = append(snapshot.touched, pkg.PkgPath) - fp, err := buildPackageFingerprint(wd, tags, pkg, metaFiles) - if err != nil { - continue - } - prev, hadPrev := readIncrementalFingerprint(key) - writeIncrementalFingerprint(key, fp) - snapshot.fingerprints[pkg.PkgPath] = fp - if hadPrev && incrementalFingerprintEquivalent(prev, fp) { - snapshot.stats.unchanged++ - continue - } - snapshot.stats.changed++ - snapshot.changed = append(snapshot.changed, pkg.PkgPath) - } - sort.Strings(snapshot.changed) - sort.Strings(snapshot.touched) - return snapshot -} - -func buildIncrementalManifestSnapshotFromPackages(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { - all := collectAllPackages(pkgs) - moduleRoot := findModuleRoot(wd) - snapshot := &incrementalFingerprintSnapshot{ - fingerprints: make(map[string]*packageFingerprint), - } - for _, pkg := range all { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - files := packageFingerprintFiles(pkg) - if len(files) == 0 { - continue - } - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - continue - } - shapeHash, err := packageShapeHashFromSyntax(pkg, files) - if err != nil { - continue - } - localImports := make([]string, 0, len(pkg.Imports)) - for _, imp := range pkg.Imports { - if classifyPackageLocation(moduleRoot, imp) == "local" { - localImports = append(localImports, imp.PkgPath) - } - } - sort.Strings(localImports) - snapshot.fingerprints[pkg.PkgPath] = &packageFingerprint{ - Version: incrementalFingerprintVersion, - WD: packageCacheScope(wd), - Tags: tags, - PkgPath: pkg.PkgPath, - Files: metaFiles, - Dirs: mustBuildPackageDirCacheFiles(files), - ContentHash: mustHashPackageFiles(files), - ShapeHash: shapeHash, - LocalImports: localImports, - } - } - if len(snapshot.fingerprints) == 0 { - return nil - } - return snapshot -} - -func packageFingerprintFiles(pkg *packages.Package) []string { - if pkg == nil { - return nil - } - if len(pkg.CompiledGoFiles) > 0 { - return append([]string(nil), pkg.CompiledGoFiles...) - } - return append([]string(nil), pkg.GoFiles...) -} - -func packageFingerprintDirs(files []string) []string { - if len(files) == 0 { - return nil - } - dirs := make([]string, 0, len(files)) - seen := make(map[string]struct{}, len(files)) - for _, name := range files { - dir := filepath.Clean(filepath.Dir(name)) - if _, ok := seen[dir]; ok { - continue - } - seen[dir] = struct{}{} - dirs = append(dirs, dir) - } - sort.Strings(dirs) - return dirs -} - -func mustBuildPackageDirCacheFiles(files []string) []cacheFile { - dirs := packageFingerprintDirs(files) - if len(dirs) == 0 { - return nil - } - meta, err := buildCacheFiles(dirs) - if err != nil { - return nil - } - return meta -} - -func mustHashPackageFiles(files []string) string { - if len(files) == 0 { - return "" - } - hash, err := hashFiles(files) - if err != nil { - return "" - } - return hash -} - -func incrementalFingerprintEquivalent(a, b *packageFingerprint) bool { - if a == nil || b == nil { - return false - } - if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || a.WD != b.WD { - return false - } - if len(a.LocalImports) != len(b.LocalImports) { - return false - } - for i := range a.LocalImports { - if a.LocalImports[i] != b.LocalImports[i] { - return false - } - } - return true -} - -func incrementalFingerprintMetaMatches(prev *packageFingerprint, wd string, tags string, pkgPath string, files []cacheFile) bool { - if prev == nil || prev.Version != incrementalFingerprintVersion { - return false - } - if prev.WD != packageCacheScope(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { - return false - } - if len(prev.Files) != len(files) { - return false - } - for i := range prev.Files { - if prev.Files[i] != files[i] { - return false - } - } - return true -} - -func buildPackageFingerprint(wd string, tags string, pkg *packages.Package, files []cacheFile) (*packageFingerprint, error) { - shapeHash, err := packageShapeHash(packageFingerprintFiles(pkg)) - if err != nil { - return nil, err - } - localImports := make([]string, 0, len(pkg.Imports)) - moduleRoot := findModuleRoot(wd) - for _, imp := range pkg.Imports { - if classifyPackageLocation(moduleRoot, imp) == "local" { - localImports = append(localImports, imp.PkgPath) - } - } - sort.Strings(localImports) - return &packageFingerprint{ - Version: incrementalFingerprintVersion, - WD: packageCacheScope(wd), - Tags: tags, - PkgPath: pkg.PkgPath, - Files: append([]cacheFile(nil), files...), - Dirs: mustBuildPackageDirCacheFiles(packageFingerprintFiles(pkg)), - ContentHash: mustHashPackageFiles(packageFingerprintFiles(pkg)), - ShapeHash: shapeHash, - LocalImports: localImports, - }, nil -} - -func packageShapeHash(files []string) (string, error) { - fset := token.NewFileSet() - var buf bytes.Buffer - for _, name := range files { - file, err := parser.ParseFile(fset, name, nil, parser.SkipObjectResolution) - if err != nil { - return "", err - } - writeSyntaxShapeHash(&buf, fset, file) - buf.WriteByte(0) - } - sum := sha256.Sum256(buf.Bytes()) - return fmt.Sprintf("%x", sum[:]), nil -} - -func packageShapeHashFromSyntax(pkg *packages.Package, files []string) (string, error) { - if pkg == nil || len(pkg.Syntax) == 0 || pkg.Fset == nil { - return packageShapeHash(files) - } - var buf bytes.Buffer - for _, file := range pkg.Syntax { - if file == nil { - continue - } - writeSyntaxShapeHash(&buf, pkg.Fset, file) - buf.WriteByte(0) - } - sum := sha256.Sum256(buf.Bytes()) - return fmt.Sprintf("%x", sum[:]), nil -} - -func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File) { - if file == nil || buf == nil || fset == nil { - return - } - usedImports := usedImportNamesInShape(file) - if file.Name != nil { - buf.WriteString("package ") - buf.WriteString(file.Name.Name) - buf.WriteByte('\n') - } - for _, decl := range file.Decls { - switch decl := decl.(type) { - case *ast.FuncDecl: - writeNodeHash(buf, fset, decl.Recv) - buf.WriteByte(' ') - if decl.Name != nil { - buf.WriteString(decl.Name.Name) - } - buf.WriteByte(' ') - writeNodeHash(buf, fset, decl.Type) - buf.WriteByte('\n') - case *ast.GenDecl: - if writeGenDeclShapeHash(buf, fset, decl, usedImports) { - buf.WriteByte('\n') - } - default: - writeNodeHash(buf, fset, decl) - buf.WriteByte('\n') - } - } -} - -func writeGenDeclShapeHash(buf *bytes.Buffer, fset *token.FileSet, decl *ast.GenDecl, usedImports map[string]struct{}) bool { - if buf == nil || fset == nil || decl == nil { - return false - } - var specBuf bytes.Buffer - wrote := false - for _, spec := range decl.Specs { - switch spec := spec.(type) { - case *ast.ImportSpec: - name := importName(spec) - if name == "_" || name == "." { - if spec.Name != nil { - specBuf.WriteString(spec.Name.Name) - } - specBuf.WriteByte(' ') - writeNodeHash(&specBuf, fset, spec.Path) - specBuf.WriteByte('\n') - wrote = true - break - } - if _, ok := usedImports[name]; !ok { - continue - } - if spec.Name != nil { - specBuf.WriteString(spec.Name.Name) - } - specBuf.WriteByte(' ') - writeNodeHash(&specBuf, fset, spec.Path) - case *ast.TypeSpec: - if spec.Name != nil { - specBuf.WriteString(spec.Name.Name) - } - specBuf.WriteByte(' ') - writeNodeHash(&specBuf, fset, spec.Type) - case *ast.ValueSpec: - for _, name := range spec.Names { - if name != nil { - specBuf.WriteString(name.Name) - } - specBuf.WriteByte(' ') - } - if spec.Type != nil { - writeNodeHash(&specBuf, fset, spec.Type) - } - default: - writeNodeHash(&specBuf, fset, spec) - } - specBuf.WriteByte('\n') - wrote = true - } - if !wrote { - return false - } - buf.WriteString(decl.Tok.String()) - buf.WriteByte(' ') - buf.Write(specBuf.Bytes()) - return true -} - -func usedImportNamesInShape(file *ast.File) map[string]struct{} { - used := make(map[string]struct{}) - if file == nil { - return used - } - record := func(node ast.Node) { - ast.Inspect(node, func(n ast.Node) bool { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return true - } - ident, ok := sel.X.(*ast.Ident) - if !ok || ident.Name == "" { - return true - } - used[ident.Name] = struct{}{} - return true - }) - } - for _, decl := range file.Decls { - switch decl := decl.(type) { - case *ast.FuncDecl: - if decl.Recv != nil { - record(decl.Recv) - } - if decl.Type != nil { - record(decl.Type) - } - case *ast.GenDecl: - for _, spec := range decl.Specs { - switch spec := spec.(type) { - case *ast.TypeSpec: - if spec.Type != nil { - record(spec.Type) - } - case *ast.ValueSpec: - if spec.Type != nil { - record(spec.Type) - } - } - } - } - } - return used -} - -func writeNodeHash(buf *bytes.Buffer, fset *token.FileSet, node interface{}) { - if buf == nil || fset == nil || node == nil { - return - } - _ = printer.Fprint(buf, fset, node) -} - -func stripFunctionBodies(file *ast.File) { - if file == nil { - return - } - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - fn.Body = nil - fn.Doc = nil - } - } -} - -func incrementalFingerprintKey(wd string, tags string, pkgPath string) string { - h := sha256.New() - h.Write([]byte(incrementalFingerprintVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(tags)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalFingerprintPath(key string) string { - return filepath.Join(cacheDir(), key+".ifp") -} - -func readIncrementalFingerprint(key string) (*packageFingerprint, bool) { - data, err := osReadFile(incrementalFingerprintPath(key)) - if err != nil { - return nil, false - } - fp, err := decodeIncrementalFingerprint(data) - if err != nil { - return nil, false - } - return fp, true -} - -func writeIncrementalFingerprint(key string, fp *packageFingerprint) { - data, err := encodeIncrementalFingerprint(fp) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".ifp-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalFingerprintPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func encodeIncrementalFingerprint(fp *packageFingerprint) ([]byte, error) { - var buf bytes.Buffer - writeString := func(s string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { - return err - } - _, err := buf.WriteString(s) - return err - } - writeCacheFiles := func(files []cacheFile) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { - return err - } - for _, f := range files { - if err := writeString(f.Path); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { - return err - } - } - return nil - } - writeStrings := func(items []string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(items))); err != nil { - return err - } - for _, item := range items { - if err := writeString(item); err != nil { - return err - } - } - return nil - } - if fp == nil { - return nil, fmt.Errorf("nil fingerprint") - } - for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { - if err := writeString(s); err != nil { - return nil, err - } - } - if err := writeCacheFiles(fp.Files); err != nil { - return nil, err - } - if err := writeStrings(fp.LocalImports); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -func decodeIncrementalFingerprint(data []byte) (*packageFingerprint, error) { - r := bytes.NewReader(data) - readString := func() (string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return "", err - } - buf := make([]byte, n) - if _, err := r.Read(buf); err != nil { - return "", err - } - return string(buf), nil - } - readCacheFiles := func() ([]cacheFile, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]cacheFile, 0, n) - for i := uint32(0); i < n; i++ { - path, err := readString() - if err != nil { - return nil, err - } - var size int64 - if err := binary.Read(r, binary.LittleEndian, &size); err != nil { - return nil, err - } - var modTime int64 - if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { - return nil, err - } - out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) - } - return out, nil - } - readStrings := func() ([]string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]string, 0, n) - for i := uint32(0); i < n; i++ { - item, err := readString() - if err != nil { - return nil, err - } - out = append(out, item) - } - return out, nil - } - version, err := readString() - if err != nil { - return nil, err - } - wd, err := readString() - if err != nil { - return nil, err - } - tags, err := readString() - if err != nil { - return nil, err - } - pkgPath, err := readString() - if err != nil { - return nil, err - } - shapeHash, err := readString() - if err != nil { - return nil, err - } - files, err := readCacheFiles() - if err != nil { - return nil, err - } - localImports, err := readStrings() - if err != nil { - return nil, err - } - return &packageFingerprint{ - Version: version, - WD: wd, - Tags: tags, - PkgPath: pkgPath, - ShapeHash: shapeHash, - Files: files, - LocalImports: localImports, - }, nil -} diff --git a/internal/wire/incremental_fingerprint_test.go b/internal/wire/incremental_fingerprint_test.go deleted file mode 100644 index 920d08e..0000000 --- a/internal/wire/incremental_fingerprint_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "os" - "path/filepath" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestPackageShapeHashIgnoresFunctionBodies(t *testing.T) { - dir := t.TempDir() - file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") - hash1, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash first failed: %v", err) - } - if err := os.WriteFile(file, []byte("package p\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - hash2, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash second failed: %v", err) - } - if hash1 != hash2 { - t.Fatalf("body-only change should not affect shape hash: %q vs %q", hash1, hash2) - } -} - -func TestPackageShapeHashIgnoresConstValueChanges(t *testing.T) { - dir := t.TempDir() - file := writeTempFile(t, dir, "pkg.go", "package p\n\nconst SQLText = \"a\"\n") - hash1, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash first failed: %v", err) - } - if err := os.WriteFile(file, []byte("package p\n\nconst SQLText = \"b\"\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - hash2, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash second failed: %v", err) - } - if hash1 != hash2 { - t.Fatalf("const-value change should not affect shape hash: %q vs %q", hash1, hash2) - } -} - -func TestPackageShapeHashIgnoresImplementationOnlyImportChanges(t *testing.T) { - dir := t.TempDir() - file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") - hash1, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash first failed: %v", err) - } - if err := os.WriteFile(file, []byte("package p\n\nimport \"fmt\"\n\nfunc Hello() string { return fmt.Sprint(\"a\") }\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - hash2, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash second failed: %v", err) - } - if hash1 != hash2 { - t.Fatalf("implementation-only import change should not affect shape hash: %q vs %q", hash1, hash2) - } -} - -func TestIncrementalFingerprintRoundTrip(t *testing.T) { - fp := &packageFingerprint{ - Version: incrementalFingerprintVersion, - WD: "/tmp/app", - Tags: "dev", - PkgPath: "example.com/app", - ShapeHash: "shape", - Files: []cacheFile{{Path: "/tmp/app/pkg.go", Size: 12, ModTime: 34}}, - LocalImports: []string{"example.com/dep"}, - } - data, err := encodeIncrementalFingerprint(fp) - if err != nil { - t.Fatalf("encodeIncrementalFingerprint failed: %v", err) - } - got, err := decodeIncrementalFingerprint(data) - if err != nil { - t.Fatalf("decodeIncrementalFingerprint failed: %v", err) - } - if !incrementalFingerprintEquivalent(fp, got) { - t.Fatalf("fingerprint mismatch after round-trip: got %+v want %+v", got, fp) - } - if len(got.Files) != 1 || got.Files[0] != fp.Files[0] { - t.Fatalf("file metadata mismatch after round-trip: got %+v want %+v", got.Files, fp.Files) - } -} - -func TestCollectIncrementalFingerprintsTreatsBodyOnlyChangeAsUnchanged(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") - file := filepath.Join(root, "app", "app.go") - writeFile(t, file, "package app\n\nfunc Hello() string { return \"a\" }\n") - pkg := &packages.Package{ - PkgPath: "example.com/app", - CompiledGoFiles: []string{file}, - GoFiles: []string{file}, - Imports: map[string]*packages.Package{}, - } - - snapshot := collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) - if snapshot.stats.changed != 1 || len(snapshot.changed) != 1 || snapshot.changed[0] != pkg.PkgPath { - t.Fatalf("first run stats=%+v changed=%v", snapshot.stats, snapshot.changed) - } - - if err := os.WriteFile(file, []byte("package app\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - snapshot = collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) - if snapshot.stats.unchanged != 1 { - t.Fatalf("body-only change should be unchanged by shape, stats=%+v changed=%v", snapshot.stats, snapshot.changed) - } - if len(snapshot.changed) != 0 { - t.Fatalf("body-only change should not report changed packages, got %v", snapshot.changed) - } -} diff --git a/internal/wire/incremental_graph.go b/internal/wire/incremental_graph.go deleted file mode 100644 index 37b3d0f..0000000 --- a/internal/wire/incremental_graph.go +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "fmt" - "path/filepath" - "sort" - "strings" - - "golang.org/x/tools/go/packages" -) - -const incrementalGraphVersion = "wire-incremental-graph-v1" - -type incrementalGraph struct { - Version string - WD string - Tags string - Roots []string - LocalReverse map[string][]string -} - -func analyzeIncrementalGraph(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) { - if !IncrementalEnabled(ctx, env) || snapshot == nil { - return - } - graph := buildIncrementalGraph(wd, tags, pkgs) - writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) - if len(snapshot.changed) == 0 { - return - } - affected := affectedRoots(graph, snapshot.changed) - if len(affected) > 0 { - debugf(ctx, "incremental.graph changed=%s affected_roots=%s", stringsJoin(snapshot.changed), stringsJoin(affected)) - } else { - debugf(ctx, "incremental.graph changed=%s affected_roots=", stringsJoin(snapshot.changed)) - } -} - -func buildIncrementalGraph(wd string, tags string, pkgs []*packages.Package) *incrementalGraph { - moduleRoot := findModuleRoot(wd) - graph := &incrementalGraph{ - Version: incrementalGraphVersion, - WD: packageCacheScope(wd), - Tags: tags, - Roots: make([]string, 0, len(pkgs)), - LocalReverse: make(map[string][]string), - } - for _, pkg := range pkgs { - if pkg == nil { - continue - } - graph.Roots = append(graph.Roots, pkg.PkgPath) - } - sort.Strings(graph.Roots) - for _, pkg := range collectAllPackages(pkgs) { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - for _, imp := range pkg.Imports { - if classifyPackageLocation(moduleRoot, imp) != "local" { - continue - } - graph.LocalReverse[imp.PkgPath] = append(graph.LocalReverse[imp.PkgPath], pkg.PkgPath) - } - } - for path := range graph.LocalReverse { - sort.Strings(graph.LocalReverse[path]) - } - return graph -} - -func affectedRoots(graph *incrementalGraph, changed []string) []string { - if graph == nil || len(changed) == 0 { - return nil - } - rootSet := make(map[string]struct{}, len(graph.Roots)) - for _, root := range graph.Roots { - rootSet[root] = struct{}{} - } - seen := make(map[string]struct{}) - queue := append([]string(nil), changed...) - affected := make(map[string]struct{}) - for len(queue) > 0 { - cur := queue[0] - queue = queue[1:] - if _, ok := seen[cur]; ok { - continue - } - seen[cur] = struct{}{} - if _, ok := rootSet[cur]; ok { - affected[cur] = struct{}{} - } - for _, next := range graph.LocalReverse[cur] { - if _, ok := seen[next]; !ok { - queue = append(queue, next) - } - } - } - out := make([]string, 0, len(affected)) - for root := range affected { - out = append(out, root) - } - sort.Strings(out) - return out -} - -func incrementalGraphKey(wd string, tags string, roots []string) string { - h := sha256.New() - h.Write([]byte(incrementalGraphVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(tags)) - h.Write([]byte{0}) - for _, root := range roots { - h.Write([]byte(root)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalGraphPath(key string) string { - return filepath.Join(cacheDir(), key+".igr") -} - -func writeIncrementalGraph(key string, graph *incrementalGraph) { - data, err := encodeIncrementalGraph(graph) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".igr-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalGraphPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func readIncrementalGraph(key string) (*incrementalGraph, bool) { - data, err := osReadFile(incrementalGraphPath(key)) - if err != nil { - return nil, false - } - graph, err := decodeIncrementalGraph(data) - if err != nil { - return nil, false - } - return graph, true -} - -func encodeIncrementalGraph(graph *incrementalGraph) ([]byte, error) { - if graph == nil { - return nil, fmt.Errorf("nil incremental graph") - } - var buf bytes.Buffer - writeString := func(s string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { - return err - } - _, err := buf.WriteString(s) - return err - } - for _, s := range []string{graph.Version, graph.WD, graph.Tags} { - if err := writeString(s); err != nil { - return nil, err - } - } - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(graph.Roots))); err != nil { - return nil, err - } - for _, root := range graph.Roots { - if err := writeString(root); err != nil { - return nil, err - } - } - keys := make([]string, 0, len(graph.LocalReverse)) - for k := range graph.LocalReverse { - keys = append(keys, k) - } - sort.Strings(keys) - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(keys))); err != nil { - return nil, err - } - for _, k := range keys { - if err := writeString(k); err != nil { - return nil, err - } - children := append([]string(nil), graph.LocalReverse[k]...) - sort.Strings(children) - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(children))); err != nil { - return nil, err - } - for _, child := range children { - if err := writeString(child); err != nil { - return nil, err - } - } - } - return buf.Bytes(), nil -} - -func decodeIncrementalGraph(data []byte) (*incrementalGraph, error) { - r := bytes.NewReader(data) - readString := func() (string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return "", err - } - buf := make([]byte, n) - if _, err := r.Read(buf); err != nil { - return "", err - } - return string(buf), nil - } - version, err := readString() - if err != nil { - return nil, err - } - wd, err := readString() - if err != nil { - return nil, err - } - tags, err := readString() - if err != nil { - return nil, err - } - var rootCount uint32 - if err := binary.Read(r, binary.LittleEndian, &rootCount); err != nil { - return nil, err - } - roots := make([]string, 0, rootCount) - for i := uint32(0); i < rootCount; i++ { - root, err := readString() - if err != nil { - return nil, err - } - roots = append(roots, root) - } - var edgeCount uint32 - if err := binary.Read(r, binary.LittleEndian, &edgeCount); err != nil { - return nil, err - } - reverse := make(map[string][]string, edgeCount) - for i := uint32(0); i < edgeCount; i++ { - k, err := readString() - if err != nil { - return nil, err - } - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - children := make([]string, 0, n) - for j := uint32(0); j < n; j++ { - child, err := readString() - if err != nil { - return nil, err - } - children = append(children, child) - } - reverse[k] = children - } - return &incrementalGraph{ - Version: version, - WD: wd, - Tags: tags, - Roots: roots, - LocalReverse: reverse, - }, nil -} - -func stringsJoin(items []string) string { - if len(items) == 0 { - return "" - } - return strings.Join(items, ",") -} diff --git a/internal/wire/incremental_graph_test.go b/internal/wire/incremental_graph_test.go deleted file mode 100644 index 8a91b54..0000000 --- a/internal/wire/incremental_graph_test.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "path/filepath" - "reflect" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestIncrementalGraphRoundTrip(t *testing.T) { - graph := &incrementalGraph{ - Version: incrementalGraphVersion, - WD: "/tmp/app", - Tags: "dev", - Roots: []string{"example.com/app", "example.com/other"}, - LocalReverse: map[string][]string{ - "example.com/dep": {"example.com/app"}, - "example.com/sub": {"example.com/dep", "example.com/other"}, - }, - } - data, err := encodeIncrementalGraph(graph) - if err != nil { - t.Fatalf("encodeIncrementalGraph failed: %v", err) - } - got, err := decodeIncrementalGraph(data) - if err != nil { - t.Fatalf("decodeIncrementalGraph failed: %v", err) - } - if !reflect.DeepEqual(got, graph) { - t.Fatalf("graph round-trip mismatch:\n got=%+v\nwant=%+v", got, graph) - } -} - -func TestAffectedRoots(t *testing.T) { - graph := &incrementalGraph{ - Roots: []string{"example.com/app", "example.com/other"}, - LocalReverse: map[string][]string{ - "example.com/dep": {"example.com/app"}, - "example.com/sub": {"example.com/dep", "example.com/other"}, - }, - } - got := affectedRoots(graph, []string{"example.com/sub"}) - want := []string{"example.com/app", "example.com/other"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("affectedRoots=%v want %v", got, want) - } -} - -func TestBuildIncrementalGraph(t *testing.T) { - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") - - appFile := filepath.Join(root, "app", "app.go") - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, appFile, "package app\n") - writeFile(t, depFile, "package dep\n") - - dep := &packages.Package{ - PkgPath: "example.com/test/dep", - CompiledGoFiles: []string{depFile}, - GoFiles: []string{depFile}, - Imports: map[string]*packages.Package{}, - } - app := &packages.Package{ - PkgPath: "example.com/test/app", - CompiledGoFiles: []string{appFile}, - GoFiles: []string{appFile}, - Imports: map[string]*packages.Package{ - "example.com/test/dep": dep, - }, - } - - graph := buildIncrementalGraph(root, "", []*packages.Package{app}) - if len(graph.Roots) != 1 || graph.Roots[0] != app.PkgPath { - t.Fatalf("unexpected roots: %v", graph.Roots) - } - got := graph.LocalReverse[dep.PkgPath] - want := []string{app.PkgPath} - if !reflect.DeepEqual(got, want) { - t.Fatalf("unexpected reverse edges: got=%v want=%v", got, want) - } -} diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go deleted file mode 100644 index 8fab10e..0000000 --- a/internal/wire/incremental_manifest.go +++ /dev/null @@ -1,1158 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "fmt" - "go/token" - "os" - "path/filepath" - "sort" - "strings" - - "golang.org/x/tools/go/packages" -) - -const incrementalManifestVersion = "wire-incremental-manifest-v3" - -type incrementalManifest struct { - Version string - WD string - Tags string - Prefix string - HeaderHash string - EnvHash string - Patterns []string - LocalPackages []packageFingerprint - ExternalPkgs []externalPackageExport - ExternalFiles []cacheFile - ExtraFiles []cacheFile - Outputs []incrementalOutput -} - -type externalPackageExport struct { - PkgPath string - ExportFile string -} - -type incrementalOutput struct { - PkgPath string - OutputPath string - ContentKey string -} - -type incrementalPreloadState struct { - selectorKey string - manifest *incrementalManifest - valid bool - currentLocal []packageFingerprint - touched []string - reason string -} - -type incrementalPreloadValidation struct { - valid bool - currentLocal []packageFingerprint - touched []string - reason string -} - -const touchedValidationVersion = "wire-touched-validation-v1" - -func readPreloadIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { - state, ok := prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) - return readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, state, ok) -} - -func readPreloadIncrementalManifestResultsFromState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, ok bool) ([]GenerateResult, bool) { - if !ok { - debugf(ctx, "incremental.preload_manifest miss reason=no_manifest") - return nil, false - } - if state.valid { - validateStart := timeNow() - if len(state.touched) > 0 { - debugf(ctx, "incremental.preload_manifest touched=%s", strings.Join(state.touched, ",")) - } - if err := validateIncrementalPreloadTouchedPackages(ctx, wd, env, opts, state.currentLocal, state.touched); err != nil { - logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) - if shouldBypassIncrementalManifestAfterFastPathError(err) { - invalidateIncrementalPreloadState(state) - } - debugf(ctx, "incremental.preload_manifest miss reason=touched_validation") - return nil, false - } - logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) - outputsStart := timeNow() - results, ok := incrementalManifestOutputs(state.manifest) - logTiming(ctx, "incremental.preload_manifest.outputs", outputsStart) - if !ok { - debugf(ctx, "incremental.preload_manifest miss reason=outputs") - return nil, false - } - if manifestNeedsLocalRefresh(state.manifest.LocalPackages, state.currentLocal) { - refreshed := *state.manifest - refreshed.LocalPackages = append([]packageFingerprint(nil), state.currentLocal...) - writeIncrementalManifestFile(state.selectorKey, &refreshed) - writeIncrementalManifestFile(incrementalManifestStateKey(state.selectorKey, refreshed.LocalPackages), &refreshed) - } - debugf(ctx, "incremental.preload_manifest hit outputs=%d", len(results)) - return results, true - } else if archived := readStateIncrementalManifest(state.selectorKey, state.currentLocal); archived != nil { - if validation := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); validation.valid { - results, ok := incrementalManifestOutputs(archived) - if !ok { - debugf(ctx, "incremental.preload_manifest miss reason=state_outputs") - return nil, false - } - writeIncrementalManifestFile(state.selectorKey, archived) - debugf(ctx, "incremental.preload_manifest state_hit outputs=%d", len(results)) - return results, true - } - debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) - return nil, false - } else { - debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) - return nil, false - } -} - -func prepareIncrementalPreloadState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (*incrementalPreloadState, bool) { - selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) - manifest, ok := readIncrementalManifest(selectorKey) - if !ok { - return nil, false - } - validation := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) - return &incrementalPreloadState{ - selectorKey: selectorKey, - manifest: manifest, - valid: validation.valid, - currentLocal: validation.currentLocal, - touched: validation.touched, - reason: validation.reason, - }, true -} - -func readIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) ([]GenerateResult, bool) { - if snapshot == nil || snapshot.stats.changed != 0 { - return nil, false - } - key := incrementalManifestSelectorKey(wd, env, patterns, opts) - manifest, ok := readIncrementalManifest(key) - if !ok || !incrementalManifestValid(manifest, wd, env, patterns, opts, pkgs) { - return nil, false - } - results := make([]GenerateResult, 0, len(manifest.Outputs)) - for _, out := range manifest.Outputs { - content, ok := readCache(out.ContentKey) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - Content: content, - }) - } - debugf(ctx, "incremental.manifest hit outputs=%d", len(results)) - return results, true -} - -func writeIncrementalManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { - writeIncrementalManifestWithOptions(wd, env, patterns, opts, pkgs, snapshot, generated, true) -} - -func writeIncrementalManifestWithOptions(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult, includeExternalFiles bool) { - if snapshot == nil || len(generated) == 0 { - return - } - scope := runCacheScope(wd, patterns) - externalPkgs := buildExternalPackageExports(wd, pkgs) - var externalFiles []cacheFile - if includeExternalFiles { - var err error - externalFiles, err = buildExternalPackageFiles(wd, pkgs) - if err != nil { - return - } - } - manifest := &incrementalManifest{ - Version: incrementalManifestVersion, - WD: scope, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - LocalPackages: snapshotPackageFingerprints(snapshot), - ExternalPkgs: externalPkgs, - ExternalFiles: externalFiles, - ExtraFiles: extraCacheFiles(wd), - } - for _, out := range generated { - if len(out.Content) == 0 || out.OutputPath == "" { - continue - } - contentKey := incrementalContentKey(out.Content) - writeCache(contentKey, out.Content) - manifest.Outputs = append(manifest.Outputs, incrementalOutput{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - ContentKey: contentKey, - }) - } - if len(manifest.Outputs) == 0 { - return - } - selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) - stateKey := incrementalManifestStateKey(selectorKey, manifest.LocalPackages) - writeIncrementalManifestFile(selectorKey, manifest) - writeIncrementalManifestFile(stateKey, manifest) -} - -func incrementalManifestSelectorKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(incrementalManifestVersion)) - h.Write([]byte{0}) - h.Write([]byte(runCacheScope(wd, patterns))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func snapshotPackageFingerprints(snapshot *incrementalFingerprintSnapshot) []packageFingerprint { - if snapshot == nil || len(snapshot.fingerprints) == 0 { - return nil - } - paths := make([]string, 0, len(snapshot.fingerprints)) - for path := range snapshot.fingerprints { - paths = append(paths, path) - } - sort.Strings(paths) - out := make([]packageFingerprint, 0, len(paths)) - for _, path := range paths { - if fp := snapshot.fingerprints[path]; fp != nil { - out = append(out, *fp) - } - } - return out -} - -func buildExternalPackageFiles(wd string, pkgs []*packages.Package) ([]cacheFile, error) { - moduleRoot := findModuleRoot(wd) - seen := make(map[string]struct{}) - var files []string - for _, pkg := range collectAllPackages(pkgs) { - if classifyPackageLocation(moduleRoot, pkg) == "local" { - continue - } - names := pkg.CompiledGoFiles - if len(names) == 0 { - names = pkg.GoFiles - } - for _, name := range names { - clean := filepath.Clean(name) - if _, ok := seen[clean]; ok { - continue - } - seen[clean] = struct{}{} - files = append(files, clean) - } - } - sort.Strings(files) - return buildCacheFiles(files) -} - -func buildExternalPackageExports(wd string, pkgs []*packages.Package) []externalPackageExport { - out := make([]externalPackageExport, 0) - for _, pkg := range collectAllPackages(pkgs) { - if pkg == nil || pkg.PkgPath == "" || pkg.ExportFile == "" { - continue - } - out = append(out, externalPackageExport{ - PkgPath: pkg.PkgPath, - ExportFile: pkg.ExportFile, - }) - } - sort.Slice(out, func(i, j int) bool { return out[i].PkgPath < out[j].PkgPath }) - return out -} - -func incrementalManifestValid(manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) bool { - if manifest == nil || manifest.Version != incrementalManifestVersion { - return false - } - if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { - return false - } - if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { - return false - } - normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) - if len(manifest.Patterns) != len(normalizedPatterns) { - return false - } - for i, p := range normalizedPatterns { - if manifest.Patterns[i] != p { - return false - } - } - currentExternal, err := buildExternalPackageFiles(wd, pkgs) - if err != nil || len(currentExternal) != len(manifest.ExternalFiles) { - return false - } - for i := range currentExternal { - if currentExternal[i] != manifest.ExternalFiles[i] { - return false - } - } - if len(manifest.ExtraFiles) > 0 { - current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) - if err != nil || len(current) != len(manifest.ExtraFiles) { - return false - } - for i := range current { - if current[i] != manifest.ExtraFiles[i] { - return false - } - } - } - return len(manifest.Outputs) > 0 -} - -func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) incrementalPreloadValidation { - if manifest == nil || manifest.Version != incrementalManifestVersion { - return incrementalPreloadValidation{reason: "version"} - } - if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { - return incrementalPreloadValidation{reason: "config"} - } - if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { - return incrementalPreloadValidation{reason: "env"} - } - normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) - if len(manifest.Patterns) != len(normalizedPatterns) { - return incrementalPreloadValidation{reason: "patterns.length"} - } - for i, p := range normalizedPatterns { - if manifest.Patterns[i] != p { - return incrementalPreloadValidation{reason: "patterns.value"} - } - } - if len(manifest.ExtraFiles) > 0 { - extraStart := timeNow() - current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) - logTiming(ctx, "incremental.preload_manifest.validate_extra_files", extraStart) - if err != nil || len(current) != len(manifest.ExtraFiles) { - return incrementalPreloadValidation{reason: "extra_files"} - } - for i := range current { - if current[i] != manifest.ExtraFiles[i] { - return incrementalPreloadValidation{reason: "extra_files.diff"} - } - } - } - localStart := timeNow() - packagesState := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) - logTiming(ctx, "incremental.preload_manifest.validate_local_packages", localStart) - if !packagesState.valid { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "local_packages." + packagesState.reason, - } - } - if len(manifest.ExternalFiles) > 0 { - externalStart := timeNow() - current, err := buildCacheFilesFromMeta(manifest.ExternalFiles) - logTiming(ctx, "incremental.preload_manifest.validate_external_files", externalStart) - if err != nil || len(current) != len(manifest.ExternalFiles) { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "external_files", - } - } - for i := range current { - if current[i] != manifest.ExternalFiles[i] { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "external_files.diff", - } - } - } - } - if len(manifest.Outputs) == 0 { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "outputs", - } - } - return incrementalPreloadValidation{ - valid: true, - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - } -} - -type incrementalLocalPackagesState struct { - valid bool - currentLocal []packageFingerprint - touched []string - reason string -} - -func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) incrementalLocalPackagesState { - currentState := make([]packageFingerprint, 0, len(local)) - touched := make([]string, 0, len(local)) - var firstReason string - for _, fp := range local { - if len(fp.Files) == 0 { - if firstReason == "" { - firstReason = fp.PkgPath + ".files" - } - continue - } - storedFiles := filesFromMeta(fp.Files) - if len(storedFiles) == 0 { - if firstReason == "" { - firstReason = fp.PkgPath + ".stored_files" - } - continue - } - currentMeta, err := buildCacheFiles(storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".meta_error" - } - continue - } - currentFP := fp - currentFP.Files = append([]cacheFile(nil), currentMeta...) - sameMeta := len(currentMeta) == len(fp.Files) - if sameMeta { - for i := range currentMeta { - if currentMeta[i] != fp.Files[i] { - sameMeta = false - break - } - } - } - if !sameMeta { - if diffs := describeCacheFileDiffs(fp.Files, currentMeta); len(diffs) > 0 { - debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_diff=%s", fp.PkgPath, strings.Join(diffs, "; ")) - } - contentHash, err := hashFiles(storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s content_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".content_error" - } - continue - } - currentFP.ContentHash = contentHash - if contentHash != fp.ContentHash { - debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_content=%s current_content=%s hash_files=%s", fp.PkgPath, fp.ContentHash, contentHash, strings.Join(storedFiles, ",")) - shapeHash, err := packageShapeHash(storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".shape_error" - } - continue - } - currentFP.ShapeHash = shapeHash - if shapeHash != fp.ShapeHash { - debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) - if firstReason == "" { - firstReason = fp.PkgPath + ".shape_mismatch" - } - } else { - debugf(ctx, "incremental.preload_manifest local_pkg=%s content_changed_shape_unchanged", fp.PkgPath) - touched = append(touched, fp.PkgPath) - } - } - } - currentDirs, dirsChanged, err := packageDirectoryMetaChanged(fp, storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_meta_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".dir_meta_error" - } - continue - } - currentFP.Dirs = currentDirs - if dirsChanged { - if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".dir_scan_error" - } - continue - } else if changed { - debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) - if firstReason == "" { - firstReason = fp.PkgPath + ".introduced_relevant_files" - } - } - } - currentState = append(currentState, currentFP) - } - if firstReason != "" { - return incrementalLocalPackagesState{ - currentLocal: currentState, - touched: touched, - reason: firstReason, - } - } - sort.Strings(touched) - return incrementalLocalPackagesState{ - valid: true, - currentLocal: currentState, - touched: touched, - } -} - -func validateIncrementalPreloadTouchedPackages(ctx context.Context, wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) error { - if len(touched) == 0 { - return nil - } - cacheKey := touchedValidationKey(wd, env, opts, local, touched) - if cacheKey != "" { - cacheHitStart := timeNow() - if _, ok := readCache(cacheKey); ok { - logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_hit", cacheHitStart) - return nil - } - } - cfg := &packages.Config{ - Context: ctx, - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, - Dir: wd, - Env: env, - BuildFlags: []string{"-tags=wireinject"}, - Fset: token.NewFileSet(), - } - if len(opts.Tags) > 0 { - cfg.BuildFlags[0] += " " + opts.Tags - } - loadStart := timeNow() - pkgs, err := packages.Load(cfg, touched...) - logTiming(ctx, "incremental.preload_manifest.validate_touched_load", loadStart) - if err != nil { - return err - } - errorsStart := timeNow() - byPath := make(map[string]*packages.Package, len(pkgs)) - for _, pkg := range pkgs { - if pkg != nil { - byPath[pkg.PkgPath] = pkg - } - } - for _, path := range touched { - if pkg := byPath[path]; pkg != nil && len(pkg.Errors) > 0 { - logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) - return formatLocalTypeCheckError(wd, pkg.PkgPath, pkg.Errors) - } - } - logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) - if cacheKey != "" { - cacheWriteStart := timeNow() - writeCache(cacheKey, []byte("ok")) - logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_write", cacheWriteStart) - } - return nil -} - -func touchedValidationKey(wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) string { - if len(touched) == 0 { - return "" - } - byPath := fingerprintsFromSlice(local) - h := sha256.New() - h.Write([]byte(touchedValidationVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - if opts != nil { - h.Write([]byte(opts.Tags)) - } - h.Write([]byte{0}) - for _, pkgPath := range touched { - fp := byPath[pkgPath] - if fp == nil || fp.ContentHash == "" { - return "" - } - h.Write([]byte(pkgPath)) - h.Write([]byte{0}) - h.Write([]byte(fp.ContentHash)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalManifestOutputs(manifest *incrementalManifest) ([]GenerateResult, bool) { - results := make([]GenerateResult, 0, len(manifest.Outputs)) - for _, out := range manifest.Outputs { - content, ok := readCache(out.ContentKey) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - Content: content, - }) - } - return results, true -} - -func readStateIncrementalManifest(selectorKey string, local []packageFingerprint) *incrementalManifest { - if len(local) == 0 { - return nil - } - stateKey := incrementalManifestStateKey(selectorKey, local) - manifest, ok := readIncrementalManifest(stateKey) - if !ok { - return nil - } - return manifest -} - -func incrementalManifestStateKey(selectorKey string, local []packageFingerprint) string { - h := sha256.New() - h.Write([]byte(selectorKey)) - h.Write([]byte{0}) - for _, fp := range snapshotPackageFingerprints(&incrementalFingerprintSnapshot{fingerprints: fingerprintsFromSlice(local)}) { - h.Write([]byte(fp.PkgPath)) - h.Write([]byte{0}) - h.Write([]byte(fp.ShapeHash)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func fingerprintsFromSlice(local []packageFingerprint) map[string]*packageFingerprint { - if len(local) == 0 { - return nil - } - out := make(map[string]*packageFingerprint, len(local)) - for i := range local { - fp := local[i] - out[fp.PkgPath] = &fp - } - return out -} - -func filesFromMeta(files []cacheFile) []string { - out := make([]string, 0, len(files)) - for _, f := range files { - out = append(out, filepath.Clean(f.Path)) - } - sort.Strings(out) - return out -} - -func describeCacheFileDiffs(stored []cacheFile, current []cacheFile) []string { - if len(stored) == 0 && len(current) == 0 { - return nil - } - storedByPath := make(map[string]cacheFile, len(stored)) - currentByPath := make(map[string]cacheFile, len(current)) - for _, file := range stored { - storedByPath[filepath.Clean(file.Path)] = file - } - for _, file := range current { - currentByPath[filepath.Clean(file.Path)] = file - } - paths := make([]string, 0, len(storedByPath)+len(currentByPath)) - seen := make(map[string]struct{}, len(storedByPath)+len(currentByPath)) - for path := range storedByPath { - if _, ok := seen[path]; ok { - continue - } - seen[path] = struct{}{} - paths = append(paths, path) - } - for path := range currentByPath { - if _, ok := seen[path]; ok { - continue - } - seen[path] = struct{}{} - paths = append(paths, path) - } - sort.Strings(paths) - diffs := make([]string, 0, len(paths)) - for _, path := range paths { - storedFile, storedOK := storedByPath[path] - currentFile, currentOK := currentByPath[path] - switch { - case !storedOK: - diffs = append(diffs, fmt.Sprintf("%s added size=%d mtime=%d", path, currentFile.Size, currentFile.ModTime)) - case !currentOK: - diffs = append(diffs, fmt.Sprintf("%s removed size=%d mtime=%d", path, storedFile.Size, storedFile.ModTime)) - case storedFile != currentFile: - diffs = append(diffs, fmt.Sprintf("%s size:%d->%d mtime:%d->%d", path, storedFile.Size, currentFile.Size, storedFile.ModTime, currentFile.ModTime)) - } - } - return diffs -} - -func manifestNeedsLocalRefresh(stored []packageFingerprint, current []packageFingerprint) bool { - if len(stored) != len(current) { - return false - } - for i := range stored { - if stored[i].PkgPath != current[i].PkgPath { - return false - } - if stored[i].ContentHash == "" && current[i].ContentHash != "" { - return true - } - if len(stored[i].Dirs) == 0 && len(current[i].Dirs) > 0 { - return true - } - } - return false -} - -func packageDirectoryMetaChanged(fp packageFingerprint, storedFiles []string) ([]cacheFile, bool, error) { - dirs := packageFingerprintDirs(storedFiles) - if len(dirs) == 0 { - return nil, false, nil - } - current, err := buildCacheFiles(dirs) - if err != nil { - return nil, false, err - } - if len(fp.Dirs) != len(current) { - return current, true, nil - } - for i := range current { - if current[i] != fp.Dirs[i] { - return current, true, nil - } - } - return current, false, nil -} - -func packageDirectoryIntroducedRelevantFiles(files []cacheFile) (bool, error) { - dirs := make(map[string]struct{}) - old := make(map[string]struct{}, len(files)) - for _, f := range files { - path := filepath.Clean(f.Path) - dirs[filepath.Dir(path)] = struct{}{} - old[path] = struct{}{} - } - for dir := range dirs { - entries, err := os.ReadDir(dir) - if err != nil { - return false, err - } - for _, entry := range entries { - if entry.IsDir() { - continue - } - name := entry.Name() - if !strings.HasSuffix(name, ".go") { - continue - } - if strings.HasSuffix(name, "_test.go") { - continue - } - if strings.HasSuffix(name, "wire_gen.go") { - continue - } - path := filepath.Clean(filepath.Join(dir, name)) - if _, ok := old[path]; !ok { - return true, nil - } - } - } - return false, nil -} - -func incrementalManifestPath(key string) string { - return filepath.Join(cacheDir(), key+".iman") -} - -func readIncrementalManifest(key string) (*incrementalManifest, bool) { - data, err := osReadFile(incrementalManifestPath(key)) - if err != nil { - return nil, false - } - manifest, err := decodeIncrementalManifest(data) - if err != nil { - return nil, false - } - return manifest, true -} - -func writeIncrementalManifestFile(key string, manifest *incrementalManifest) { - data, err := encodeIncrementalManifest(manifest) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".iman-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalManifestPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func removeIncrementalManifestFile(key string) { - if key == "" { - return - } - _ = osRemove(incrementalManifestPath(key)) -} - -func encodeIncrementalManifest(manifest *incrementalManifest) ([]byte, error) { - var buf bytes.Buffer - if manifest == nil { - return nil, fmt.Errorf("nil incremental manifest") - } - writeString := func(s string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { - return err - } - _, err := buf.WriteString(s) - return err - } - writeCacheFiles := func(files []cacheFile) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { - return err - } - for _, f := range files { - if err := writeString(f.Path); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { - return err - } - } - return nil - } - writeExternalPkgs := func(pkgs []externalPackageExport) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(pkgs))); err != nil { - return err - } - for _, pkg := range pkgs { - if err := writeString(pkg.PkgPath); err != nil { - return err - } - if err := writeString(pkg.ExportFile); err != nil { - return err - } - } - return nil - } - writeFingerprints := func(fps []packageFingerprint) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fps))); err != nil { - return err - } - for _, fp := range fps { - for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { - if err := writeString(s); err != nil { - return err - } - } - if err := writeCacheFiles(fp.Files); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fp.LocalImports))); err != nil { - return err - } - for _, imp := range fp.LocalImports { - if err := writeString(imp); err != nil { - return err - } - } - } - return nil - } - writeOutputs := func(outputs []incrementalOutput) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(outputs))); err != nil { - return err - } - for _, out := range outputs { - for _, s := range []string{out.PkgPath, out.OutputPath, out.ContentKey} { - if err := writeString(s); err != nil { - return err - } - } - } - return nil - } - for _, s := range []string{manifest.Version, manifest.WD, manifest.Tags, manifest.Prefix, manifest.HeaderHash, manifest.EnvHash} { - if err := writeString(s); err != nil { - return nil, err - } - } - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(manifest.Patterns))); err != nil { - return nil, err - } - for _, p := range manifest.Patterns { - if err := writeString(p); err != nil { - return nil, err - } - } - if err := writeFingerprints(manifest.LocalPackages); err != nil { - return nil, err - } - if err := writeExternalPkgs(manifest.ExternalPkgs); err != nil { - return nil, err - } - if err := writeCacheFiles(manifest.ExternalFiles); err != nil { - return nil, err - } - if err := writeCacheFiles(manifest.ExtraFiles); err != nil { - return nil, err - } - if err := writeOutputs(manifest.Outputs); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -func decodeIncrementalManifest(data []byte) (*incrementalManifest, error) { - r := bytes.NewReader(data) - readString := func() (string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return "", err - } - buf := make([]byte, n) - if _, err := r.Read(buf); err != nil { - return "", err - } - return string(buf), nil - } - readCacheFiles := func() ([]cacheFile, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]cacheFile, 0, n) - for i := uint32(0); i < n; i++ { - path, err := readString() - if err != nil { - return nil, err - } - var size int64 - if err := binary.Read(r, binary.LittleEndian, &size); err != nil { - return nil, err - } - var modTime int64 - if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { - return nil, err - } - out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) - } - return out, nil - } - readExternalPkgs := func() ([]externalPackageExport, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]externalPackageExport, 0, n) - for i := uint32(0); i < n; i++ { - pkgPath, err := readString() - if err != nil { - return nil, err - } - exportFile, err := readString() - if err != nil { - return nil, err - } - out = append(out, externalPackageExport{PkgPath: pkgPath, ExportFile: exportFile}) - } - return out, nil - } - readFingerprints := func() ([]packageFingerprint, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]packageFingerprint, 0, n) - for i := uint32(0); i < n; i++ { - version, err := readString() - if err != nil { - return nil, err - } - wd, err := readString() - if err != nil { - return nil, err - } - tags, err := readString() - if err != nil { - return nil, err - } - pkgPath, err := readString() - if err != nil { - return nil, err - } - shapeHash, err := readString() - if err != nil { - return nil, err - } - files, err := readCacheFiles() - if err != nil { - return nil, err - } - var importCount uint32 - if err := binary.Read(r, binary.LittleEndian, &importCount); err != nil { - return nil, err - } - localImports := make([]string, 0, importCount) - for j := uint32(0); j < importCount; j++ { - imp, err := readString() - if err != nil { - return nil, err - } - localImports = append(localImports, imp) - } - out = append(out, packageFingerprint{ - Version: version, - WD: wd, - Tags: tags, - PkgPath: pkgPath, - ShapeHash: shapeHash, - Files: files, - LocalImports: localImports, - }) - } - return out, nil - } - readOutputs := func() ([]incrementalOutput, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]incrementalOutput, 0, n) - for i := uint32(0); i < n; i++ { - pkgPath, err := readString() - if err != nil { - return nil, err - } - outputPath, err := readString() - if err != nil { - return nil, err - } - contentKey, err := readString() - if err != nil { - return nil, err - } - out = append(out, incrementalOutput{PkgPath: pkgPath, OutputPath: outputPath, ContentKey: contentKey}) - } - return out, nil - } - fields := make([]string, 6) - for i := range fields { - s, err := readString() - if err != nil { - return nil, err - } - fields[i] = s - } - var patternCount uint32 - if err := binary.Read(r, binary.LittleEndian, &patternCount); err != nil { - return nil, err - } - patterns := make([]string, 0, patternCount) - for i := uint32(0); i < patternCount; i++ { - p, err := readString() - if err != nil { - return nil, err - } - patterns = append(patterns, p) - } - localPackages, err := readFingerprints() - if err != nil { - return nil, err - } - externalPkgs, err := readExternalPkgs() - if err != nil { - return nil, err - } - externalFiles, err := readCacheFiles() - if err != nil { - return nil, err - } - extraFiles, err := readCacheFiles() - if err != nil { - return nil, err - } - outputs, err := readOutputs() - if err != nil { - return nil, err - } - return &incrementalManifest{ - Version: fields[0], - WD: fields[1], - Tags: fields[2], - Prefix: fields[3], - HeaderHash: fields[4], - EnvHash: fields[5], - Patterns: patterns, - LocalPackages: localPackages, - ExternalPkgs: externalPkgs, - ExternalFiles: externalFiles, - ExtraFiles: extraFiles, - Outputs: outputs, - }, nil -} - -func incrementalContentKey(content []byte) string { - sum := sha256.Sum256(content) - return fmt.Sprintf("%x", sum[:]) -} diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go deleted file mode 100644 index 2fdaa2b..0000000 --- a/internal/wire/incremental_session.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "encoding/hex" - "go/ast" - "go/token" - "path/filepath" - "strings" - "sync" -) - -type incrementalSession struct { - fset *token.FileSet - mu sync.Mutex - parsedDeps map[string]cachedParsedFile -} - -type cachedParsedFile struct { - hash string - file *ast.File -} - -var incrementalSessions sync.Map - -func clearIncrementalSessions() { - incrementalSessions.Range(func(key, _ any) bool { - incrementalSessions.Delete(key) - return true - }) -} - -func sessionKey(wd string, env []string, tags string) string { - var b strings.Builder - b.WriteString(packageCacheScope(wd)) - b.WriteByte('\n') - b.WriteString(tags) - b.WriteByte('\n') - for _, entry := range env { - b.WriteString(entry) - b.WriteByte('\x00') - } - return b.String() -} - -func getIncrementalSession(wd string, env []string, tags string) *incrementalSession { - key := sessionKey(wd, env, tags) - if session, ok := incrementalSessions.Load(key); ok { - return session.(*incrementalSession) - } - session := &incrementalSession{ - fset: token.NewFileSet(), - parsedDeps: make(map[string]cachedParsedFile), - } - actual, _ := incrementalSessions.LoadOrStore(key, session) - return actual.(*incrementalSession) -} - -func (s *incrementalSession) getParsedDep(filename string, src []byte) (*ast.File, bool) { - if s == nil { - return nil, false - } - hash := hashSource(src) - s.mu.Lock() - defer s.mu.Unlock() - entry, ok := s.parsedDeps[filepath.Clean(filename)] - if !ok || entry.hash != hash { - return nil, false - } - return entry.file, true -} - -func (s *incrementalSession) storeParsedDep(filename string, src []byte, file *ast.File) { - if s == nil || file == nil { - return - } - s.mu.Lock() - defer s.mu.Unlock() - s.parsedDeps[filepath.Clean(filename)] = cachedParsedFile{ - hash: hashSource(src), - file: file, - } -} - -func hashSource(src []byte) string { - sum := sha256.Sum256(src) - return hex.EncodeToString(sum[:]) -} diff --git a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go deleted file mode 100644 index 934f637..0000000 --- a/internal/wire/incremental_summary.go +++ /dev/null @@ -1,656 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "crypto/sha256" - "encoding/binary" - "fmt" - "go/ast" - "go/types" - "path/filepath" - "sort" - - "golang.org/x/tools/go/packages" -) - -const incrementalSummaryVersion = "wire-incremental-summary-v1" - -type packageSummary struct { - Version string - WD string - Tags string - PkgPath string - ShapeHash string - LocalImports []string - ProviderSets []providerSetSummary - Injectors []injectorSummary -} - -type providerSetSummary struct { - VarName string - Providers []providerSummary - Imports []providerSetRefSummary - Bindings []ifaceBindingSummary - Values []string - Fields []fieldSummary - InputTypes []string -} - -type providerSummary struct { - PkgPath string - Name string - Args []providerInputSummary - Out []string - Varargs bool - IsStruct bool - HasCleanup bool - HasErr bool -} - -type providerInputSummary struct { - Type string - FieldName string -} - -type providerSetRefSummary struct { - PkgPath string - VarName string -} - -type ifaceBindingSummary struct { - Iface string - Provided string -} - -type fieldSummary struct { - PkgPath string - Parent string - Name string - Out []string -} - -type injectorSummary struct { - Name string - Inputs []string - Output string - Build providerSetSummary -} - -type packageSummarySnapshot struct { - Changed map[string]*packageSummary - Unchanged map[string]*packageSummary -} - -func incrementalSummaryKey(wd string, tags string, pkgPath string) string { - h := sha256.New() - h.Write([]byte(incrementalSummaryVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(tags)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalSummaryPath(key string) string { - return filepath.Join(cacheDir(), key+".isum") -} - -func readIncrementalPackageSummary(key string) (*packageSummary, bool) { - data, err := osReadFile(incrementalSummaryPath(key)) - if err != nil { - return nil, false - } - summary, err := decodeIncrementalSummary(data) - if err != nil { - return nil, false - } - return summary, true -} - -func writeIncrementalPackageSummary(key string, summary *packageSummary) { - data, err := encodeIncrementalSummary(summary) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".isum-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalSummaryPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) { - writeIncrementalPackageSummariesWithSummary(loader, pkgs, nil, nil) -} - -func writeIncrementalPackageSummariesWithSummary(loader *lazyLoader, pkgs []*packages.Package, summary *summaryProviderResolver, only map[string]struct{}) { - if loader == nil || len(pkgs) == 0 { - return - } - moduleRoot := findModuleRoot(loader.wd) - all := collectAllPackages(pkgs) - for path, pkg := range loader.loaded { - if pkg != nil { - all[path] = pkg - } - } - allPkgs := make([]*packages.Package, 0, len(all)) - for _, pkg := range all { - allPkgs = append(allPkgs, pkg) - } - oc := newObjectCacheWithLoader(allPkgs, loader, nil, summary) - for _, pkg := range all { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - if len(only) > 0 { - if _, ok := only[pkg.PkgPath]; !ok { - continue - } - } - if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { - continue - } - summary, err := buildPackageSummary(loader, oc, pkg) - if err != nil { - continue - } - writeIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath), summary) - } -} - -func collectIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) *packageSummarySnapshot { - if loader == nil || loader.fingerprints == nil { - return nil - } - snapshot := &packageSummarySnapshot{ - Changed: make(map[string]*packageSummary), - Unchanged: make(map[string]*packageSummary), - } - changed := make(map[string]struct{}, len(loader.fingerprints.changed)) - for _, path := range loader.fingerprints.changed { - changed[path] = struct{}{} - } - moduleRoot := findModuleRoot(loader.wd) - oc := newObjectCache(pkgs, loader) - for _, pkg := range collectAllPackages(pkgs) { - if pkg == nil { - continue - } - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - if _, ok := changed[pkg.PkgPath]; ok { - if pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { - loaded, errs := oc.ensurePackage(pkg.PkgPath) - if len(errs) > 0 { - continue - } - pkg = loaded - } - if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { - continue - } - summary, err := buildPackageSummary(loader, oc, pkg) - if err != nil { - continue - } - snapshot.Changed[pkg.PkgPath] = summary - continue - } - if summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath)); ok { - snapshot.Unchanged[pkg.PkgPath] = summary - } - } - return snapshot -} - -func buildPackageSummary(loader *lazyLoader, oc *objectCache, pkg *packages.Package) (*packageSummary, error) { - if loader == nil || oc == nil || pkg == nil { - return nil, fmt.Errorf("missing loader, object cache, or package") - } - summary := &packageSummary{ - Version: incrementalSummaryVersion, - WD: filepath.Clean(loader.wd), - Tags: loader.tags, - PkgPath: pkg.PkgPath, - } - if snapshot := loader.fingerprints; snapshot != nil { - if fp := snapshot.fingerprints[pkg.PkgPath]; fp != nil { - summary.ShapeHash = fp.ShapeHash - summary.LocalImports = append(summary.LocalImports, fp.LocalImports...) - } - } - scope := pkg.Types.Scope() - for _, name := range scope.Names() { - obj := scope.Lookup(name) - if !isProviderSetType(obj.Type()) { - continue - } - item, errs := oc.get(obj) - if len(errs) > 0 { - continue - } - pset, ok := item.(*ProviderSet) - if !ok { - continue - } - summary.ProviderSets = append(summary.ProviderSets, summarizeProviderSet(pset)) - } - sort.Slice(summary.ProviderSets, func(i, j int) bool { - return summary.ProviderSets[i].VarName < summary.ProviderSets[j].VarName - }) - for _, file := range pkg.Syntax { - for _, decl := range file.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - buildCall, err := findInjectorBuild(pkg.TypesInfo, fn) - if err != nil || buildCall == nil { - continue - } - sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) - ins, out, err := injectorFuncSignature(sig) - if err != nil { - continue - } - injectorArgs := &InjectorArgs{ - Name: fn.Name.Name, - Tuple: ins, - Pos: fn.Pos(), - } - set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "") - if len(errs) > 0 { - continue - } - summary.Injectors = append(summary.Injectors, injectorSummary{ - Name: fn.Name.Name, - Inputs: summarizeTuple(ins), - Output: summaryTypeString(out.out), - Build: summarizeProviderSet(set), - }) - } - } - sort.Slice(summary.Injectors, func(i, j int) bool { - return summary.Injectors[i].Name < summary.Injectors[j].Name - }) - return summary, nil -} - -func summarizeProviderSet(pset *ProviderSet) providerSetSummary { - if pset == nil { - return providerSetSummary{} - } - summary := providerSetSummary{ - VarName: pset.VarName, - } - for _, provider := range pset.Providers { - summary.Providers = append(summary.Providers, summarizeProvider(provider)) - } - for _, imported := range pset.Imports { - summary.Imports = append(summary.Imports, providerSetRefSummary{ - PkgPath: imported.PkgPath, - VarName: imported.VarName, - }) - } - for _, binding := range pset.Bindings { - summary.Bindings = append(summary.Bindings, ifaceBindingSummary{ - Iface: summaryTypeString(binding.Iface), - Provided: summaryTypeString(binding.Provided), - }) - } - for _, value := range pset.Values { - summary.Values = append(summary.Values, summaryTypeString(value.Out)) - } - for _, field := range pset.Fields { - item := fieldSummary{ - Parent: summaryTypeString(field.Parent), - Name: field.Name, - Out: summarizeTypes(field.Out), - } - if field.Pkg != nil { - item.PkgPath = field.Pkg.Path() - } - summary.Fields = append(summary.Fields, item) - } - if pset.InjectorArgs != nil { - summary.InputTypes = summarizeTuple(pset.InjectorArgs.Tuple) - } - sort.Slice(summary.Providers, func(i, j int) bool { - return summary.Providers[i].PkgPath+"."+summary.Providers[i].Name < summary.Providers[j].PkgPath+"."+summary.Providers[j].Name - }) - sort.Slice(summary.Imports, func(i, j int) bool { - return summary.Imports[i].PkgPath+"."+summary.Imports[i].VarName < summary.Imports[j].PkgPath+"."+summary.Imports[j].VarName - }) - sort.Slice(summary.Bindings, func(i, j int) bool { - return summary.Bindings[i].Iface+":"+summary.Bindings[i].Provided < summary.Bindings[j].Iface+":"+summary.Bindings[j].Provided - }) - sort.Strings(summary.Values) - sort.Slice(summary.Fields, func(i, j int) bool { - return summary.Fields[i].Parent+"."+summary.Fields[i].Name < summary.Fields[j].Parent+"."+summary.Fields[j].Name - }) - sort.Strings(summary.InputTypes) - return summary -} - -func summarizeProvider(provider *Provider) providerSummary { - summary := providerSummary{ - Name: provider.Name, - Varargs: provider.Varargs, - IsStruct: provider.IsStruct, - HasCleanup: provider.HasCleanup, - HasErr: provider.HasErr, - Out: summarizeTypes(provider.Out), - } - if provider.Pkg != nil { - summary.PkgPath = provider.Pkg.Path() - } - for _, arg := range provider.Args { - summary.Args = append(summary.Args, providerInputSummary{ - Type: summaryTypeString(arg.Type), - FieldName: arg.FieldName, - }) - } - return summary -} - -func summarizeTuple(tuple *types.Tuple) []string { - if tuple == nil { - return nil - } - out := make([]string, 0, tuple.Len()) - for i := 0; i < tuple.Len(); i++ { - out = append(out, summaryTypeString(tuple.At(i).Type())) - } - return out -} - -func summarizeTypes(typesList []types.Type) []string { - out := make([]string, 0, len(typesList)) - for _, t := range typesList { - out = append(out, summaryTypeString(t)) - } - return out -} - -func summaryTypeString(t types.Type) string { - if t == nil { - return "" - } - return types.TypeString(t, func(pkg *types.Package) string { - if pkg == nil { - return "" - } - return pkg.Path() - }) -} - -func encodeIncrementalSummary(summary *packageSummary) ([]byte, error) { - if summary == nil { - return nil, fmt.Errorf("nil package summary") - } - var buf bytes.Buffer - enc := binarySummaryEncoder{buf: &buf} - enc.string(summary.Version) - enc.string(summary.WD) - enc.string(summary.Tags) - enc.string(summary.PkgPath) - enc.string(summary.ShapeHash) - enc.strings(summary.LocalImports) - enc.providerSets(summary.ProviderSets) - enc.u32(uint32(len(summary.Injectors))) - for _, injector := range summary.Injectors { - enc.string(injector.Name) - enc.strings(injector.Inputs) - enc.string(injector.Output) - enc.providerSet(injector.Build) - } - if enc.err != nil { - return nil, enc.err - } - return buf.Bytes(), nil -} - -func decodeIncrementalSummary(data []byte) (*packageSummary, error) { - dec := binarySummaryDecoder{r: bytes.NewReader(data)} - summary := &packageSummary{ - Version: dec.string(), - WD: dec.string(), - Tags: dec.string(), - PkgPath: dec.string(), - ShapeHash: dec.string(), - } - summary.LocalImports = dec.strings() - summary.ProviderSets = dec.providerSets() - for n := dec.u32(); n > 0; n-- { - summary.Injectors = append(summary.Injectors, injectorSummary{ - Name: dec.string(), - Inputs: dec.strings(), - Output: dec.string(), - Build: dec.providerSet(), - }) - } - if dec.err != nil { - return nil, dec.err - } - return summary, nil -} - -type binarySummaryEncoder struct { - buf *bytes.Buffer - err error -} - -func (e *binarySummaryEncoder) u32(v uint32) { - if e.err != nil { - return - } - e.err = binary.Write(e.buf, binary.LittleEndian, v) -} - -func (e *binarySummaryEncoder) string(s string) { - e.u32(uint32(len(s))) - if e.err != nil { - return - } - _, e.err = e.buf.WriteString(s) -} - -func (e *binarySummaryEncoder) bool(v bool) { - if e.err != nil { - return - } - var b byte - if v { - b = 1 - } - e.err = e.buf.WriteByte(b) -} - -func (e *binarySummaryEncoder) strings(values []string) { - e.u32(uint32(len(values))) - for _, v := range values { - e.string(v) - } -} - -func (e *binarySummaryEncoder) providerSets(values []providerSetSummary) { - e.u32(uint32(len(values))) - for _, value := range values { - e.providerSet(value) - } -} - -func (e *binarySummaryEncoder) providerSet(value providerSetSummary) { - e.string(value.VarName) - e.u32(uint32(len(value.Providers))) - for _, provider := range value.Providers { - e.string(provider.PkgPath) - e.string(provider.Name) - e.u32(uint32(len(provider.Args))) - for _, arg := range provider.Args { - e.string(arg.Type) - e.string(arg.FieldName) - } - e.strings(provider.Out) - e.bool(provider.Varargs) - e.bool(provider.IsStruct) - e.bool(provider.HasCleanup) - e.bool(provider.HasErr) - } - e.u32(uint32(len(value.Imports))) - for _, imported := range value.Imports { - e.string(imported.PkgPath) - e.string(imported.VarName) - } - e.u32(uint32(len(value.Bindings))) - for _, binding := range value.Bindings { - e.string(binding.Iface) - e.string(binding.Provided) - } - e.strings(value.Values) - e.u32(uint32(len(value.Fields))) - for _, field := range value.Fields { - e.string(field.PkgPath) - e.string(field.Parent) - e.string(field.Name) - e.strings(field.Out) - } - e.strings(value.InputTypes) -} - -type binarySummaryDecoder struct { - r *bytes.Reader - err error -} - -func (d *binarySummaryDecoder) u32() uint32 { - if d.err != nil { - return 0 - } - var v uint32 - d.err = binary.Read(d.r, binary.LittleEndian, &v) - return v -} - -func (d *binarySummaryDecoder) string() string { - n := d.u32() - if d.err != nil { - return "" - } - buf := make([]byte, n) - _, d.err = d.r.Read(buf) - return string(buf) -} - -func (d *binarySummaryDecoder) bool() bool { - if d.err != nil { - return false - } - b, err := d.r.ReadByte() - if err != nil { - d.err = err - return false - } - return b != 0 -} - -func (d *binarySummaryDecoder) strings() []string { - n := d.u32() - if d.err != nil { - return nil - } - out := make([]string, 0, n) - for i := uint32(0); i < n; i++ { - out = append(out, d.string()) - } - return out -} - -func (d *binarySummaryDecoder) providerSets() []providerSetSummary { - n := d.u32() - if d.err != nil { - return nil - } - out := make([]providerSetSummary, 0, n) - for i := uint32(0); i < n; i++ { - out = append(out, d.providerSet()) - } - return out -} - -func (d *binarySummaryDecoder) providerSet() providerSetSummary { - value := providerSetSummary{ - VarName: d.string(), - } - for n := d.u32(); n > 0; n-- { - provider := providerSummary{ - PkgPath: d.string(), - Name: d.string(), - } - for m := d.u32(); m > 0; m-- { - provider.Args = append(provider.Args, providerInputSummary{ - Type: d.string(), - FieldName: d.string(), - }) - } - provider.Out = d.strings() - provider.Varargs = d.bool() - provider.IsStruct = d.bool() - provider.HasCleanup = d.bool() - provider.HasErr = d.bool() - value.Providers = append(value.Providers, provider) - } - for n := d.u32(); n > 0; n-- { - value.Imports = append(value.Imports, providerSetRefSummary{ - PkgPath: d.string(), - VarName: d.string(), - }) - } - for n := d.u32(); n > 0; n-- { - value.Bindings = append(value.Bindings, ifaceBindingSummary{ - Iface: d.string(), - Provided: d.string(), - }) - } - value.Values = d.strings() - for n := d.u32(); n > 0; n-- { - value.Fields = append(value.Fields, fieldSummary{ - PkgPath: d.string(), - Parent: d.string(), - Name: d.string(), - Out: d.strings(), - }) - } - value.InputTypes = d.strings() - return value -} diff --git a/internal/wire/incremental_summary_test.go b/internal/wire/incremental_summary_test.go deleted file mode 100644 index ae85651..0000000 --- a/internal/wire/incremental_summary_test.go +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestIncrementalSummaryEncodeDecodeRoundTrip(t *testing.T) { - summary := &packageSummary{ - Version: incrementalSummaryVersion, - WD: "/tmp/app", - Tags: "dev", - PkgPath: "example.com/app/dep", - ShapeHash: "abc123", - LocalImports: []string{"example.com/app/shared"}, - ProviderSets: []providerSetSummary{{ - VarName: "Set", - Providers: []providerSummary{{ - PkgPath: "example.com/app/dep", - Name: "NewThing", - Args: []providerInputSummary{{Type: "string"}}, - Out: []string{"*example.com/app/dep.Thing"}, - HasCleanup: true, - }}, - Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, - Bindings: []ifaceBindingSummary{{Iface: "error", Provided: "*example.com/app/dep.Thing"}}, - Values: []string{"string"}, - Fields: []fieldSummary{{PkgPath: "example.com/app/dep", Parent: "example.com/app/dep.Config", Name: "Name", Out: []string{"string"}}}, - InputTypes: []string{"context.Context"}, - }}, - Injectors: []injectorSummary{{ - Name: "Init", - Inputs: []string{"context.Context"}, - Output: "*example.com/app/dep.Thing", - Build: providerSetSummary{ - Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, - }, - }}, - } - data, err := encodeIncrementalSummary(summary) - if err != nil { - t.Fatalf("encodeIncrementalSummary: %v", err) - } - got, err := decodeIncrementalSummary(data) - if err != nil { - t.Fatalf("decodeIncrementalSummary: %v", err) - } - if got.Version != summary.Version || got.PkgPath != summary.PkgPath || got.ShapeHash != summary.ShapeHash { - t.Fatalf("decoded summary mismatch: %+v", got) - } - if len(got.ProviderSets) != 1 || got.ProviderSets[0].VarName != "Set" { - t.Fatalf("decoded provider sets mismatch: %+v", got.ProviderSets) - } - if len(got.Injectors) != 1 || got.Injectors[0].Name != "Init" { - t.Fatalf("decoded injectors mismatch: %+v", got.Injectors) - } -} - -func TestBuildPackageSummary(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{ Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo { return &Foo{Message: msg} }", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors: %v", errs) - } - oc := newObjectCache(pkgs, loader) - loadedDep, errs := oc.ensurePackage("example.com/app/dep") - if len(errs) > 0 { - t.Fatalf("ensurePackage returned errors: %v", errs) - } - summary, err := buildPackageSummary(loader, oc, loadedDep) - if err != nil { - t.Fatalf("buildPackageSummary: %v", err) - } - if summary.PkgPath != "example.com/app/dep" { - t.Fatalf("summary pkg path = %q", summary.PkgPath) - } - if len(summary.ProviderSets) != 1 || summary.ProviderSets[0].VarName != "Set" { - t.Fatalf("unexpected provider sets: %+v", summary.ProviderSets) - } - if len(summary.ProviderSets[0].Providers) != 2 { - t.Fatalf("unexpected providers: %+v", summary.ProviderSets[0].Providers) - } - loadedApp, errs := oc.ensurePackage("example.com/app/app") - if len(errs) > 0 { - t.Fatalf("ensurePackage app returned errors: %v", errs) - } - appSummary, err := buildPackageSummary(loader, oc, loadedApp) - if err != nil { - t.Fatalf("buildPackageSummary app: %v", err) - } - if len(appSummary.Injectors) != 1 || appSummary.Injectors[0].Name != "Init" { - t.Fatalf("unexpected injectors: %+v", appSummary.Injectors) - } - if len(appSummary.Injectors[0].Build.Imports) != 1 || appSummary.Injectors[0].Build.Imports[0].PkgPath != "example.com/app/dep" { - t.Fatalf("unexpected injector imports: %+v", appSummary.Injectors[0].Build.Imports) - } -} - -func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct{ Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo { return &Foo{Message: msg} }", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate result: %+v", gens) - } - pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors while seeding summaries: %v", errs) - } - if _, errs := newObjectCache(pkgs, loader).ensurePackage("example.com/app/app"); len(errs) > 0 { - t.Fatalf("ensurePackage returned errors while seeding summaries: %v", errs) - } - writeIncrementalPackageSummaries(loader, pkgs) - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct{ Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo { return &Foo{Message: msg, Count: count} }", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - pkgs, loader, errs = load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors: %v", errs) - } - snapshot := collectIncrementalPackageSummaries(loader, pkgs) - if snapshot == nil { - t.Fatal("collectIncrementalPackageSummaries returned nil") - } - if _, ok := snapshot.Changed["example.com/app/dep"]; !ok { - t.Fatalf("expected changed dep summary, got %+v", snapshot.Changed) - } - if _, ok := snapshot.Unchanged["example.com/app/app"]; !ok { - t.Fatalf("expected unchanged app summary from cache, got %+v", snapshot.Unchanged) - } - if len(snapshot.Unchanged["example.com/app/app"].Injectors) != 1 { - t.Fatalf("unexpected cached app summary: %+v", snapshot.Unchanged["example.com/app/app"]) - } - if len(snapshot.Changed["example.com/app/dep"].ProviderSets) != 1 { - t.Fatalf("unexpected changed dep summary: %+v", snapshot.Changed["example.com/app/dep"]) - } -} diff --git a/internal/wire/incremental_test.go b/internal/wire/incremental_test.go deleted file mode 100644 index a531123..0000000 --- a/internal/wire/incremental_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "testing" -) - -func TestIncrementalEnabledDefaultOff(t *testing.T) { - if IncrementalEnabled(context.Background(), nil) { - t.Fatal("IncrementalEnabled should default to false") - } -} - -func TestIncrementalEnabledFromEnv(t *testing.T) { - env := []string{ - "FOO=bar", - IncrementalEnvVar + "=true", - } - if !IncrementalEnabled(context.Background(), env) { - t.Fatal("IncrementalEnabled should read the environment variable") - } -} - -func TestIncrementalEnabledUsesLastEnvValue(t *testing.T) { - env := []string{ - IncrementalEnvVar + "=false", - IncrementalEnvVar + "=true", - } - if !IncrementalEnabled(context.Background(), env) { - t.Fatal("IncrementalEnabled should use the last matching env value") - } -} - -func TestIncrementalEnabledContextOverridesEnv(t *testing.T) { - env := []string{ - IncrementalEnvVar + "=false", - } - ctx := WithIncremental(context.Background(), true) - if !IncrementalEnabled(ctx, env) { - t.Fatal("context override should take precedence over env") - } -} - -func TestIncrementalEnabledInvalidEnvFallsBackFalse(t *testing.T) { - env := []string{ - IncrementalEnvVar + "=maybe", - } - if IncrementalEnabled(context.Background(), env) { - t.Fatal("invalid env value should not enable incremental mode") - } -} diff --git a/internal/wire/load_debug.go b/internal/wire/load_debug.go index fd8c4d7..d3d5fc1 100644 --- a/internal/wire/load_debug.go +++ b/internal/wire/load_debug.go @@ -124,7 +124,7 @@ func logLoadDebug(ctx context.Context, scope string, mode packages.LoadMode, sub } if parseStats != nil { snap := parseStats.snapshot() - debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.total=%s", + debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.cumulative=%s", scope, snap.calls, snap.primaryCalls, @@ -190,6 +190,23 @@ func summarizeLoadScope(wd string, pkgs []*packages.Package) loadScopeStats { return stats } +func collectAllPackages(pkgs []*packages.Package) map[string]*packages.Package { + all := make(map[string]*packages.Package) + stack := append([]*packages.Package(nil), pkgs...) + for len(stack) > 0 { + p := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if p == nil || all[p.PkgPath] != nil { + continue + } + all[p.PkgPath] = p + for _, imp := range p.Imports { + stack = append(stack, imp) + } + } + return all +} + func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { if moduleRoot == "" || pkg == nil { return "unknown" @@ -210,8 +227,8 @@ func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { } func isWithinRoot(root, name string) bool { - cleanRoot := filepath.Clean(root) - cleanName := filepath.Clean(name) + cleanRoot := canonicalPath(root) + cleanName := canonicalPath(name) if cleanName == cleanRoot { return true } @@ -222,6 +239,14 @@ func isWithinRoot(root, name string) bool { return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) } +func canonicalPath(path string) string { + clean := filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(clean); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return clean +} + func topPackageMetrics(metrics []packageMetric) []string { sort.Slice(metrics, func(i, j int) bool { if metrics[i].count == metrics[j].count { diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go deleted file mode 100644 index 37e27d9..0000000 --- a/internal/wire/loader_test.go +++ /dev/null @@ -1,2596 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -func TestLoadAndGenerateModule(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "app.go"), strings.Join([]string{ - "package app", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.New)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{}", - "", - "func New() *Foo {", - "\treturn &Foo{}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "noop", "noop.go"), strings.Join([]string{ - "package noop", - "", - "type Thing struct{}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - - info, errs := Load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil { - t.Fatal("Load returned nil info") - } - if len(info.Injectors) != 1 || info.Injectors[0].FuncName != "Init" { - t.Fatalf("Load returned unexpected injectors: %+v", info.Injectors) - } - - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 { - t.Fatalf("Generate returned %d results, want 1", len(gens)) - } - if len(gens[0].Errs) > 0 { - t.Fatalf("Generate result had errors: %v", gens[0].Errs) - } - if len(gens[0].Content) == 0 { - t.Fatal("Generate returned empty output for wire package") - } - if gens[0].OutputPath == "" { - t.Fatal("Generate returned empty output path") - } - - noops, errs := Generate(ctx, root, env, []string{"./noop"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate noop returned errors: %v", errs) - } - if len(noops) != 1 { - t.Fatalf("Generate noop returned %d results, want 1", len(noops)) - } - if len(noops[0].Errs) > 0 { - t.Fatalf("Generate noop result had errors: %v", noops[0].Errs) - } - if noops[0].OutputPath == "" { - t.Fatal("Generate noop returned empty output path") - } - if len(noops[0].Content) != 0 { - t.Fatal("Generate noop returned unexpected output") - } -} - -func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - - info, errs := Load(context.Background(), root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil || len(info.Injectors) != 1 { - t.Fatalf("Load returned unexpected info: %+v errs=%v", info, errs) - } - - incrementalCtx := WithIncremental(context.Background(), true) - incrementalInfo, errs := Load(incrementalCtx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("incremental Load returned errors: %v", errs) - } - if incrementalInfo == nil || len(incrementalInfo.Injectors) != 1 { - t.Fatalf("incremental Load returned unexpected info: %+v errs=%v", incrementalInfo, errs) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - incrementalGens, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(incrementalGens) != 1 { - t.Fatalf("unexpected result counts: normal=%d incremental=%d", len(normalGens), len(incrementalGens)) - } - if len(normalGens[0].Errs) > 0 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected generate errors: normal=%v incremental=%v", normalGens[0].Errs, incrementalGens[0].Errs) - } - if normalGens[0].OutputPath != incrementalGens[0].OutputPath { - t.Fatalf("output paths differ: normal=%q incremental=%q", normalGens[0].OutputPath, incrementalGens[0].OutputPath) - } - if string(normalGens[0].Content) != string(incrementalGens[0].Content) { - t.Fatalf("generated content differs between normal and incremental modes") - } -} - -func TestGenerateIncrementalBodyOnlyChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "wire_gen.go"), strings.Join([]string{ - "//go:build !wireinject", - "", - "package app", - "", - "func generated() {}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "app_test.go"), strings.Join([]string{ - "package app", - "", - "func testOnly() {}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - var firstLabels []string - firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { - firstLabels = append(firstLabels, label) - }) - first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - if !containsLabel(firstLabels, "load.packages.lazy.load") { - t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) - } - - if err := os.WriteFile(depFile, []byte(strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"b\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected second Generate to reuse preload manifest after body-only change, labels=%v", secondLabels) - } - if string(first[0].Content) != string(second[0].Content) { - t.Fatal("expected body-only change to reuse identical generated output") - } -} - -func TestGenerateIncrementalTouchedValidationCacheReusesSuccessfulValidation(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeBodyVariant := func(message string) { - t.Helper() - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"" + message + "\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - } - writeBodyVariant("a") - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeBodyVariant("b") - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected first body-only variant change to avoid generate.load, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "incremental.preload_manifest.validate_touched_cache_hit") { - t.Fatalf("did not expect first body-only variant change to hit touched validation cache, labels=%v", secondLabels) - } - - writeBodyVariant("a") - third, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("third Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected third Generate result: %+v", third) - } - - writeBodyVariant("b") - - var fourthLabels []string - fourthCtx := WithTiming(ctx, func(label string, _ time.Duration) { - fourthLabels = append(fourthLabels, label) - }) - fourth, errs := Generate(fourthCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("fourth Generate returned errors: %v", errs) - } - if len(fourth) != 1 || len(fourth[0].Errs) > 0 { - t.Fatalf("unexpected fourth Generate result: %+v", fourth) - } - if containsLabel(fourthLabels, "generate.load") { - t.Fatalf("expected repeated body-only variant change to avoid generate.load, labels=%v", fourthLabels) - } - if !containsLabel(fourthLabels, "incremental.preload_manifest.validate_touched_cache_hit") { - t.Fatalf("expected repeated body-only variant change to hit touched validation cache, labels=%v", fourthLabels) - } - if string(first[0].Content) != string(fourth[0].Content) { - t.Fatal("expected repeated body-only variant change to reuse identical generated output") - } -} - -func TestGenerateIncrementalConstValueChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"blue\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected const-value change to reuse preload manifest, labels=%v", secondLabels) - } - if string(first[0].Content) != string(second[0].Content) { - t.Fatal("expected const-value change to reuse identical generated output") - } -} - -func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn missing", - "}", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(second) != 0 { - t.Fatalf("expected invalid body-only change to stop before generation, got %+v", second) - } - if len(errs) == 0 { - t.Fatal("expected invalid body-only change to return errors") - } - if !containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected invalid body-only change to bypass preload manifest and load packages, labels=%v", secondLabels) - } - if got := errs[0].Error(); !strings.Contains(got, "undefined: missing") { - t.Fatalf("expected load/type-check error from invalid body-only change, got %q", got) - } -} - -func TestGenerateIncrementalScenarioMatrix(t *testing.T) { - t.Parallel() - - type scenarioExpectation struct { - mode string - wantErr bool - wantSameOutput bool - } - - scenarios := []struct { - name string - apply func(t *testing.T, fx incrementalScenarioFixture) - want scenarioExpectation - }{ - { - name: "comment_only_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "// SQLText controls SQL highlighting in log output.", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "whitespace_only_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "", - "func New(msg string) *Foo {", - "", - "\treturn &Foo{Message: helper(msg)}", - "", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "function_body_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string {", - "\treturn helper(SQLText)", - "}", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "method_body_change_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func (f Foo) Summary() string {", - "\treturn helper(f.Message)", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "const_value_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"blue\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "var_initializer_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 2", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "add_top_level_helper_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func NewTag() string { return \"tag\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "import_only_implementation_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "import \"fmt\"", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return fmt.Sprint(msg) }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "signature_change_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 7", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func NewCount() int { return defaultCount }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: count}", - "}", - "", - }, "\n")) - writeFile(t, fx.wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: false}, - }, - { - name: "struct_field_addition_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: defaultCount}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "interface_method_addition_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Fooer interface {", - "\tMessage() string", - "\tCount() int", - "}", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "new_source_file_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.extraFile, strings.Join([]string{ - "package dep", - "", - "func NewTag() string { return \"tag\" }", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "fast", wantSameOutput: true}, - }, - { - name: "invalid_body_change_falls_back_and_errors", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return missing }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "generate_load", wantErr: true}, - }, - } - - for _, scenario := range scenarios { - scenario := scenario - t.Run(scenario.name, func(t *testing.T) { - fx := newIncrementalScenarioFixture(t) - - first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("baseline Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected baseline Generate result: %+v", first) - } - - scenario.apply(t, fx) - - var labels []string - timedCtx := WithTiming(fx.ctx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - second, errs := Generate(timedCtx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - - if scenario.want.wantErr { - if len(errs) == 0 { - t.Fatal("expected Generate to return errors") - } - if len(second) != 0 { - t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) - } - } else { - if len(errs) > 0 { - t.Fatalf("incremental Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected incremental Generate result: %+v", second) - } - } - - switch scenario.want.mode { - case "preload": - if containsLabel(labels, "generate.load") { - t.Fatalf("expected preload reuse without generate.load, labels=%v", labels) - } - case "fast": - if containsLabel(labels, "generate.load") { - t.Fatalf("expected fast incremental path without generate.load, labels=%v", labels) - } - case "local_fastpath": - if containsLabel(labels, "generate.load") { - t.Fatalf("expected local fast path without generate.load, labels=%v", labels) - } - if containsLabel(labels, "load.packages.lazy.load") { - t.Fatalf("expected local fast path to skip lazy load, labels=%v", labels) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected local fast path load, labels=%v", labels) - } - case "generate_load": - if !containsLabel(labels, "generate.load") { - t.Fatalf("expected generate.load fallback, labels=%v", labels) - } - default: - t.Fatalf("unknown expected mode %q", scenario.want.mode) - } - - if scenario.want.wantErr { - return - } - - normal, errs := Generate(context.Background(), fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors after edit: %v", errs) - } - if len(normal) != 1 || len(normal[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate result after edit: %+v", normal) - } - if second[0].OutputPath != normal[0].OutputPath { - t.Fatalf("output paths differ: incremental=%q normal=%q", second[0].OutputPath, normal[0].OutputPath) - } - if string(second[0].Content) != string(normal[0].Content) { - t.Fatalf("incremental output differs from normal output after %s", scenario.name) - } - if scenario.want.wantSameOutput && string(first[0].Content) != string(second[0].Content) { - t.Fatalf("expected generated output to stay unchanged for %s", scenario.name) - } - if !scenario.want.wantSameOutput && string(first[0].Content) == string(second[0].Content) { - t.Fatalf("expected generated output to change for %s", scenario.name) - } - }) - } -} - -func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - var firstLabels []string - firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { - firstLabels = append(firstLabels, label) - }) - first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - if !containsLabel(firstLabels, "load.packages.lazy.load") { - t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected shape-changing incremental run to skip package load via local fast path, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected shape-changing incremental run to skip lazy load via local fast path, labels=%v", secondLabels) - } - if !containsLabel(secondLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected shape-changing incremental run to use local fast path, labels=%v", secondLabels) - } - if string(first[0].Content) == string(second[0].Content) { - t.Fatal("expected shape-changing edit to regenerate different output") - } -} - -func TestGenerateIncrementalRepeatedShapeStateHitsPreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected repeated shape state to hit preload manifest before package load, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected repeated shape state to skip lazy load, labels=%v", secondLabels) - } - if string(first[0].Content) != string(second[0].Content) { - t.Fatal("expected repeated shape state to reuse identical generated output") - } -} - -func TestGenerateIncrementalShapeChangeThenRepeatHitsPreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ - "package extra", - "", - "type Marker struct{}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected shape-changing Generate to skip package load via local fast path, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected shape-changing Generate to skip lazy load via local fast path, labels=%v", secondLabels) - } - if !containsLabel(secondLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected shape-changing Generate to use local fast path, labels=%v", secondLabels) - } - - var thirdLabels []string - thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { - thirdLabels = append(thirdLabels, label) - }) - third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("third Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected third Generate result: %+v", third) - } - if containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected repeated shape-changing state to hit preload manifest before package load, labels=%v", thirdLabels) - } - if containsLabel(thirdLabels, "load.packages.lazy.load") { - t.Fatalf("expected repeated shape-changing state to skip lazy load, labels=%v", thirdLabels) - } - if string(second[0].Content) != string(third[0].Content) { - t.Fatal("expected repeated shape-changing state to reuse identical generated output") - } -} - -func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeIncrementalBenchmarkModule(t, repoRoot, root) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var incrementalLabels []string - incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - incrementalLabels = append(incrementalLabels, label) - }) - incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental shape-change Generate returned errors: %v", errs) - } - if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) - } - if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected incremental shape-change Generate to use local fast path, labels=%v", incrementalLabels) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate results: %+v", normalGens) - } - if incrementalGens[0].OutputPath != normalGens[0].OutputPath { - t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) - } - if string(incrementalGens[0].Content) != string(normalGens[0].Content) { - t.Fatal("shape-changing incremental output differs from normal Generate output") - } -} - -func TestGenerateIncrementalColdBootstrapStillSeedsFastPath(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeLargeBenchmarkModule(t, repoRoot, root, 24) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("cold bootstrap Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(t, root, 12) - - var labels []string - timedCtx := WithTiming(ctx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate results: %+v", gens) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected cold bootstrap to seed fast path, labels=%v", labels) - } -} - -func TestLoadLocalPackagesForFastPathImportsUnchangedLocalDependencyFromLocalExport(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeDepRouterModule(t, root, repoRoot) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - depPkgPath := "example.com/app/dep" - depExportPath := mustLocalExportPath(t, root, env, depPkgPath) - if _, err := os.Stat(depExportPath); err != nil { - t.Fatalf("expected local export artifact at %s: %v", depExportPath, err) - } - - mutateRouterModule(t, root) - - preloadState, ok := prepareIncrementalPreloadState(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if !ok || preloadState == nil || preloadState.manifest == nil { - t.Fatal("expected preload state after baseline incremental generate") - } - loaded, err := loadLocalPackagesForFastPath(context.Background(), root, "", "example.com/app/app", []string{"example.com/app/router"}, preloadState.currentLocal, preloadState.manifest.ExternalPkgs) - if err != nil { - t.Fatalf("loadLocalPackagesForFastPath returned error: %v", err) - } - if _, ok := loaded.loader.localExports[depPkgPath]; !ok { - t.Fatalf("expected %s to be a local export candidate", depPkgPath) - } - if _, ok := loaded.loader.sourcePkgs[depPkgPath]; ok { - t.Fatalf("did not expect %s to be source-loaded", depPkgPath) - } - typesPkg, err := loaded.loader.importPackage(depPkgPath) - if err != nil { - t.Fatalf("importPackage(%s) returned error: %v", depPkgPath, err) - } - if typesPkg == nil || !typesPkg.Complete() { - t.Fatalf("expected complete imported package for %s, got %#v", depPkgPath, typesPkg) - } - if loaded.loader.pkgs[depPkgPath] != nil { - t.Fatalf("expected %s to avoid source loading when local export artifact is present", depPkgPath) - } -} - -func TestGenerateIncrementalMissingLocalExportFallsBackSafely(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeDepRouterModule(t, root, repoRoot) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - if err := os.Remove(depExportPath); err != nil { - t.Fatalf("Remove(%s) failed: %v", depExportPath, err) - } - - mutateRouterModule(t, root) - - var labels []string - timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate results: %+v", gens) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected missing local export to stay on local fast path, labels=%v", labels) - } - refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - if _, err := os.Stat(refreshedExportPath); err != nil { - t.Fatalf("expected local export artifact to be refreshed at %s: %v", refreshedExportPath, err) - } -} - -func TestGenerateIncrementalCorruptedLocalExportFallsBackSafely(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeDepRouterModule(t, root, repoRoot) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - if err := os.WriteFile(depExportPath, []byte("not-a-valid-export"), 0644); err != nil { - t.Fatalf("WriteFile(%s) failed: %v", depExportPath, err) - } - - mutateRouterModule(t, root) - - var labels []string - timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate results: %+v", gens) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected corrupted local export to stay on local fast path, labels=%v", labels) - } - refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - data, err := os.ReadFile(refreshedExportPath) - if err != nil { - t.Fatalf("ReadFile(%s) failed: %v", refreshedExportPath, err) - } - if string(data) == "not-a-valid-export" { - t.Fatalf("expected corrupted local export artifact to be refreshed at %s", refreshedExportPath) - } -} - -func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"example.com/app/router\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *router.Routes {", - "\twire.Build(dep.Set, router.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Controller struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewController(msg string) *Controller {", - "\treturn &Controller{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewController)", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ - "package router", - "", - "import \"example.com/app/dep\"", - "", - "type Routes struct { Controller *dep.Controller }", - "", - "func ProvideRoutes(controller *dep.Controller) *Routes {", - "\treturn &Routes{Controller: controller}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ - "package router", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(ProvideRoutes)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Controller struct { Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func NewController(msg string, count int) *Controller {", - "\treturn &Controller{Message: msg, Count: count}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewCount, NewController)", - "", - }, "\n")) - - var incrementalLabels []string - incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - incrementalLabels = append(incrementalLabels, label) - }) - incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental Generate returned errors: %v", errs) - } - if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) - } - if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected incremental Generate to use local fast path, labels=%v", incrementalLabels) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate results: %+v", normalGens) - } - if string(incrementalGens[0].Content) != string(normalGens[0].Content) { - t.Fatal("incremental output differs from normal Generate output when unchanged package depends on changed package") - } -} - -func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "import \"example.com/app/extra\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(second) != 0 { - t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) - } - if len(errs) == 0 { - t.Fatal("expected invalid incremental generate to return errors") - } - if got := errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { - t.Fatalf("expected fast-path type-check error, got %q", got) - } - if _, ok := readIncrementalManifest(incrementalManifestSelectorKey(root, env, []string{"./app"}, &GenerateOptions{})); ok { - t.Fatal("expected invalid incremental generate to invalidate selector manifest") - } -} - -func TestGenerateIncrementalRecoversAfterInvalidShapeChange(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "import \"example.com/app/extra\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(second) != 0 { - t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) - } - if len(errs) == 0 { - t.Fatal("expected invalid incremental generate to return errors") - } - clearIncrementalSessions() - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var thirdLabels []string - thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { - thirdLabels = append(thirdLabels, label) - }) - third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("recovery incremental Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected recovery incremental Generate result: %+v", third) - } - - normal, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors: %v", errs) - } - if len(normal) != 1 || len(normal[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate result: %+v", normal) - } - if string(third[0].Content) != string(normal[0].Content) { - t.Fatal("incremental output differs from normal Generate output after recovering from invalid shape change") - } - if !containsLabel(thirdLabels, "incremental.local_fastpath.load") && !containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected recovery run to rebuild through local fast path or normal load, labels=%v", thirdLabels) - } -} - -func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - - oldDep := strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n") - newDep := strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n") - oldWire := strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n") - newWire := strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n") - - writeFile(t, depFile, oldDep) - writeFile(t, wireFile, oldWire) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, newDep) - writeFile(t, wireFile, newWire) - second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - - writeFile(t, depFile, oldDep) - writeFile(t, wireFile, oldWire) - - var thirdLabels []string - thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { - thirdLabels = append(thirdLabels, label) - }) - third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("third Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected third Generate result: %+v", third) - } - if containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected toggled-back shape state to hit archived preload manifest before package load, labels=%v", thirdLabels) - } - if containsLabel(thirdLabels, "load.packages.lazy.load") { - t.Fatalf("expected toggled-back shape state to skip lazy load, labels=%v", thirdLabels) - } - if string(first[0].Content) != string(third[0].Content) { - t.Fatal("expected toggled-back shape state to reuse archived generated output") - } -} - -func TestGenerateIncrementalPreloadHitRefreshesMissingContentHashes(t *testing.T) { - fx := newIncrementalScenarioFixture(t) - - first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("baseline Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected baseline Generate result: %+v", first) - } - - selectorKey := incrementalManifestSelectorKey(fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - manifest, ok := readIncrementalManifest(selectorKey) - if !ok { - t.Fatal("expected incremental manifest after baseline generate") - } - if len(manifest.LocalPackages) == 0 { - t.Fatal("expected local packages in incremental manifest") - } - - stale := *manifest - stale.LocalPackages = append([]packageFingerprint(nil), manifest.LocalPackages...) - for i := range stale.LocalPackages { - stale.LocalPackages[i].ContentHash = "" - stale.LocalPackages[i].Dirs = nil - } - writeIncrementalManifestFile(selectorKey, &stale) - writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, stale.LocalPackages), &stale) - - second, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("refresh Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected refresh Generate result: %+v", second) - } - - preloadState, ok := prepareIncrementalPreloadState(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if !ok { - t.Fatal("expected preload state after manifest refresh") - } - if !preloadState.valid { - t.Fatalf("expected refreshed preload state to be valid, reason=%s", preloadState.reason) - } - if len(preloadState.touched) != 0 { - t.Fatalf("expected refreshed preload state to have no touched packages, got %v", preloadState.touched) - } -} - -func containsLabel(labels []string, want string) bool { - for _, label := range labels { - if label == want { - return true - } - } - return false -} - -type incrementalScenarioFixture struct { - root string - env []string - ctx context.Context - depFile string - wireFile string - extraFile string -} - -func newIncrementalScenarioFixture(t *testing.T) incrementalScenarioFixture { - t.Helper() - - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - return incrementalScenarioFixture{ - root: root, - env: append(os.Environ(), "GOWORK=off"), - ctx: WithIncremental(context.Background(), true), - depFile: depFile, - wireFile: wireFile, - extraFile: filepath.Join(root, "dep", "extra.go"), - } -} - -func mustRepoRoot(t *testing.T) string { - t.Helper() - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd failed: %v", err) - } - repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) - if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { - t.Fatalf("repo root not found at %s: %v", repoRoot, err) - } - return repoRoot -} - -func writeDepRouterModule(t *testing.T, root string, repoRoot string) { - t.Helper() - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"example.com/app/router\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *router.Routes {", - "\twire.Build(dep.Set, router.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Controller struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewController(msg string) *Controller {", - "\treturn &Controller{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewController)", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ - "package router", - "", - "import \"example.com/app/dep\"", - "", - "type Routes struct { Controller *dep.Controller }", - "", - "func ProvideRoutes(controller *dep.Controller) *Routes {", - "\treturn &Routes{Controller: controller}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ - "package router", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(ProvideRoutes)", - "", - }, "\n")) -} - -func mutateRouterModule(t *testing.T, root string) { - t.Helper() - writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ - "package router", - "", - "import \"example.com/app/dep\"", - "", - "type Routes struct {", - "\tController *dep.Controller", - "\tVersion int", - "}", - "", - "func NewVersion() int {", - "\treturn 2", - "}", - "", - "func ProvideRoutes(controller *dep.Controller, version int) *Routes {", - "\treturn &Routes{Controller: controller, Version: version}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ - "package router", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewVersion, ProvideRoutes)", - "", - }, "\n")) -} - -func mustLocalExportPath(t *testing.T, root string, env []string, pkgPath string) string { - t.Helper() - pkgs, loader, errs := load(context.Background(), root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors: %v", errs) - } - if loader == nil { - t.Fatal("load returned nil loader") - } - if _, errs := loader.load("example.com/app/app"); len(errs) > 0 { - t.Fatalf("lazy load returned errors: %v", errs) - } - snapshot := buildIncrementalManifestSnapshotFromPackages(root, "", incrementalManifestPackages(pkgs, loader)) - if snapshot == nil || snapshot.fingerprints[pkgPath] == nil { - t.Fatalf("missing fingerprint for %s", pkgPath) - } - path := localExportPathForFingerprint(root, "", snapshot.fingerprints[pkgPath]) - if path == "" { - t.Fatalf("missing local export path for %s", pkgPath) - } - return path -} - -func writeFile(t *testing.T, path string, content string) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } -} diff --git a/internal/wire/loader_timing_bridge.go b/internal/wire/loader_timing_bridge.go new file mode 100644 index 0000000..2de0245 --- /dev/null +++ b/internal/wire/loader_timing_bridge.go @@ -0,0 +1,17 @@ +package wire + +import ( + "context" + "time" + + "github.com/goforj/wire/internal/loader" +) + +func withLoaderTiming(ctx context.Context) context.Context { + if t := timing(ctx); t != nil { + return loader.WithTiming(ctx, func(label string, d time.Duration) { + t(label, d) + }) + } + return ctx +} diff --git a/internal/wire/cache_hooks.go b/internal/wire/loader_validation.go similarity index 50% rename from internal/wire/cache_hooks.go rename to internal/wire/loader_validation.go index 9d4be6d..6868b7b 100644 --- a/internal/wire/cache_hooks.go +++ b/internal/wire/loader_validation.go @@ -15,27 +15,19 @@ package wire import ( - "encoding/json" - "os" -) + "context" -var ( - osCreateTemp = os.CreateTemp - osMkdirAll = os.MkdirAll - osReadFile = os.ReadFile - osRemove = os.Remove - osRemoveAll = os.RemoveAll - osRename = os.Rename - osStat = os.Stat - osTempDir = os.TempDir + "github.com/goforj/wire/internal/loader" +) - jsonMarshal = json.Marshal - jsonUnmarshal = json.Unmarshal +func loaderValidationMode(ctx context.Context, wd string, env []string) bool { + return effectiveLoaderMode(ctx, wd, env) != loader.ModeFallback +} - cacheKeyForPackageFunc = cacheKeyForPackage - detectOutputDirFunc = detectOutputDir - buildCacheFilesFunc = buildCacheFiles - buildCacheFilesFromMetaFunc = buildCacheFilesFromMeta - rootPackageFilesFunc = rootPackageFiles - hashFilesFunc = hashFiles -) +func effectiveLoaderMode(ctx context.Context, wd string, env []string) loader.Mode { + mode := loader.ModeFromEnv(env) + if mode != loader.ModeAuto { + return mode + } + return loader.ModeAuto +} diff --git a/internal/wire/local_export.go b/internal/wire/local_export.go deleted file mode 100644 index f83ed7b..0000000 --- a/internal/wire/local_export.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "go/token" - "go/types" - "path/filepath" - - "golang.org/x/tools/go/gcexportdata" - "golang.org/x/tools/go/packages" -) - -const localExportVersion = "wire-local-export-v1" - -func localExportKey(wd string, tags string, pkgPath string, shapeHash string) string { - sum := sha256.Sum256([]byte(localExportVersion + "\x00" + packageCacheScope(wd) + "\x00" + tags + "\x00" + pkgPath + "\x00" + shapeHash)) - return fmt.Sprintf("%x", sum[:]) -} - -func localExportPath(key string) string { - return filepath.Join(cacheDir(), key+".iexp") -} - -func localExportPathForFingerprint(wd string, tags string, fp *packageFingerprint) string { - if fp == nil || fp.PkgPath == "" || fp.ShapeHash == "" { - return "" - } - return localExportPath(localExportKey(wd, tags, fp.PkgPath, fp.ShapeHash)) -} - -func localExportExists(wd string, tags string, fp *packageFingerprint) bool { - path := localExportPathForFingerprint(wd, tags, fp) - if path == "" { - return false - } - _, err := osStat(path) - return err == nil -} - -func writeLocalPackageExports(wd string, tags string, pkgs []*packages.Package, fps map[string]*packageFingerprint) { - if len(pkgs) == 0 || len(fps) == 0 { - return - } - moduleRoot := findModuleRoot(wd) - for _, pkg := range pkgs { - if pkg == nil || pkg.Types == nil || pkg.PkgPath == "" { - continue - } - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - fp := fps[pkg.PkgPath] - path := localExportPathForFingerprint(wd, tags, fp) - if path == "" { - continue - } - writeLocalPackageExportFile(path, pkg.Fset, pkg.Types) - } -} - -func writeLocalPackageExportFile(path string, fset *token.FileSet, pkg *types.Package) { - if path == "" || fset == nil || pkg == nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, filepath.Base(path)+".tmp-") - if err != nil { - return - } - writeErr := gcexportdata.Write(tmp, fset, pkg) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go deleted file mode 100644 index 89ea402..0000000 --- a/internal/wire/local_fastpath.go +++ /dev/null @@ -1,1031 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "fmt" - "go/ast" - "go/format" - importerpkg "go/importer" - "go/parser" - "go/token" - "go/types" - "io" - "os" - "path/filepath" - "runtime" - "sort" - "strings" - "time" - - "golang.org/x/tools/go/gcexportdata" - "golang.org/x/tools/go/packages" -) - -func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState) ([]GenerateResult, bool, bool, []error) { - if state == nil || state.manifest == nil { - return nil, false, false, nil - } - if !strings.HasSuffix(state.reason, ".shape_mismatch") { - return nil, false, false, nil - } - roots := manifestOutputPkgPaths(state.manifest) - if len(roots) != 1 { - return nil, false, false, nil - } - changed := changedPackagePaths(state.manifest.LocalPackages, state.currentLocal) - if len(changed) != 1 { - return nil, false, false, nil - } - graph, ok := readIncrementalGraph(incrementalGraphKey(wd, opts.Tags, roots)) - if !ok { - return nil, false, false, nil - } - affected := affectedRoots(graph, changed) - if len(affected) != 1 || affected[0] != roots[0] { - return nil, false, false, nil - } - - fastPathStart := time.Now() - loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], changed, state.currentLocal, state.manifest.ExternalPkgs) - if err != nil { - debugf(ctx, "incremental.local_fastpath miss reason=%v", err) - if shouldBypassIncrementalManifestAfterFastPathError(err) { - invalidateIncrementalPreloadState(state) - return nil, true, true, []error{err} - } - return nil, false, false, nil - } - logTiming(ctx, "incremental.local_fastpath.load", fastPathStart) - - generated, errs := generateFromTypedPackages(ctx, loaded, opts) - logTiming(ctx, "incremental.local_fastpath.generate", fastPathStart) - if len(errs) > 0 { - return nil, true, true, errs - } - - snapshot := &incrementalFingerprintSnapshot{ - fingerprints: loaded.fingerprints, - changed: append([]string(nil), changed...), - } - loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: opts.Tags, - fset: loaded.fset, - fingerprints: snapshot, - loaded: make(map[string]*packages.Package, len(loaded.byPath)), - } - for path, pkg := range loaded.byPath { - loader.loaded[path] = pkg - } - changedSet := make(map[string]struct{}, len(snapshot.changed)) - for _, path := range snapshot.changed { - changedSet[path] = struct{}{} - } - currentPackages := loaded.currentPackages() - writeIncrementalFingerprints(snapshot, wd, opts.Tags) - writeLocalPackageExports(wd, opts.Tags, currentPackages, loaded.fingerprints) - writeIncrementalPackageSummariesWithSummary(loader, currentPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) - writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) - writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) - - debugf(ctx, "incremental.local_fastpath hit root=%s changed=%s", roots[0], strings.Join(changed, ",")) - return generated, true, false, nil -} - -func validateIncrementalTouchedPackages(ctx context.Context, wd string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot) error { - if state == nil || state.manifest == nil || snapshot == nil || len(snapshot.touched) == 0 { - return nil - } - roots := manifestOutputPkgPaths(state.manifest) - if len(roots) != 1 { - return nil - } - _, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], snapshot.touched, snapshotPackageFingerprints(snapshot), state.manifest.ExternalPkgs) - return err -} - -func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { - if err == nil { - return false - } - msg := err.Error() - if strings.Contains(msg, "missing external export data for ") { - return false - } - return strings.Contains(msg, "type-check failed for ") -} - -func invalidateIncrementalPreloadState(state *incrementalPreloadState) { - if state == nil { - return - } - removeIncrementalManifestFile(state.selectorKey) -} - -func formatLocalTypeCheckError(wd string, pkgPath string, errs []packages.Error) error { - if len(errs) == 0 { - return fmt.Errorf("type-check failed for %s", pkgPath) - } - root := findModuleRoot(wd) - lines := []string{} - for _, pkgErr := range errs { - details := normalizeErrorLines(pkgErr.Msg, root) - if len(details) == 0 { - continue - } - lines = append(lines, fmt.Sprintf("type-check failed for %s: %s", pkgPath, details[0])) - for _, line := range details[1:] { - lines = append(lines, line) - } - } - if len(lines) == 0 { - lines = append(lines, fmt.Sprintf("type-check failed for %s", pkgPath)) - } - return fmt.Errorf("%s", strings.Join(lines, "\n")) -} - -func normalizeErrorLines(msg string, root string) []string { - msg = strings.TrimSpace(msg) - if msg == "" { - return []string{"unknown error"} - } - lines := unfoldTypeCheckChain(msg) - for i := range lines { - lines[i] = relativizeErrorLine(lines[i], root) - } - if len(lines) == 0 { - return []string{"unknown error"} - } - return lines -} - -func relativizeErrorLine(line string, root string) string { - if root == "" { - return line - } - cleanRoot := filepath.Clean(root) - prefix := cleanRoot + string(os.PathSeparator) - return strings.ReplaceAll(line, prefix, "") -} - -func unfoldTypeCheckChain(msg string) []string { - msg = strings.TrimSpace(msg) - if msg == "" { - return nil - } - if inner, outer, ok := splitNestedTypeCheck(msg); ok { - lines := []string{strings.TrimSpace(outer)} - return append(lines, unfoldTypeCheckChain(inner)...) - } - parts := strings.Split(msg, "\n") - lines := make([]string, 0, len(parts)) - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { - continue - } - lines = append(lines, part) - } - return lines -} - -func splitNestedTypeCheck(msg string) (inner string, outer string, ok bool) { - msg = strings.TrimSpace(msg) - if len(msg) < 2 || msg[len(msg)-1] != ')' { - return "", "", false - } - depth := 0 - for i := len(msg) - 1; i >= 0; i-- { - switch msg[i] { - case ')': - depth++ - case '(': - depth-- - if depth == 0 { - inner = strings.TrimSpace(msg[i+1 : len(msg)-1]) - if strings.HasPrefix(inner, "type-check failed for ") { - return inner, strings.TrimSpace(msg[:i]), true - } - return "", "", false - } - } - } - return "", "", false -} - -type localFastPathLoaded struct { - fset *token.FileSet - root *packages.Package - allPackages []*packages.Package - byPath map[string]*packages.Package - fingerprints map[string]*packageFingerprint - loader *localFastPathLoader -} - -func (l *localFastPathLoaded) currentPackages() []*packages.Package { - if l == nil { - return nil - } - if l.loader == nil || len(l.loader.pkgs) == 0 { - return l.allPackages - } - all := make([]*packages.Package, 0, len(l.loader.pkgs)) - for _, pkg := range l.loader.pkgs { - all = append(all, pkg) - } - sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath }) - return all -} - -type localFastPathLoader struct { - ctx context.Context - wd string - tags string - fset *token.FileSet - modulePrefix string - rootPkgPath string - changedPkgs map[string]struct{} - sourcePkgs map[string]struct{} - summaries map[string]*packageSummary - meta map[string]*packageFingerprint - pkgs map[string]*packages.Package - imported map[string]*types.Package - externalMeta map[string]externalPackageExport - localExports map[string]string - externalImp types.Importer - externalFallback types.Importer -} - -func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { - return loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, changed, current, external, false) -} - -func validateTouchedPackagesFastPath(ctx context.Context, wd string, tags string, touched []string, current []packageFingerprint, external []externalPackageExport) error { - if len(touched) == 0 { - return nil - } - rootPkgPath := touched[0] - _, err := loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, touched, current, external, true) - return err -} - -func loadLocalPackagesForFastPathMode(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport, validationOnly bool) (*localFastPathLoaded, error) { - meta := fingerprintsFromSlice(current) - if len(meta) == 0 { - return nil, fmt.Errorf("no local fingerprints") - } - if meta[rootPkgPath] == nil { - return nil, fmt.Errorf("missing root package fingerprint") - } - externalMeta := make(map[string]externalPackageExport, len(external)) - for _, item := range external { - if item.PkgPath == "" || item.ExportFile == "" { - continue - } - if meta[item.PkgPath] != nil { - continue - } - externalMeta[item.PkgPath] = item - } - loader := &localFastPathLoader{ - ctx: ctx, - wd: wd, - tags: tags, - fset: token.NewFileSet(), - modulePrefix: moduleImportPrefix(meta), - rootPkgPath: rootPkgPath, - changedPkgs: make(map[string]struct{}, len(changed)), - sourcePkgs: make(map[string]struct{}), - summaries: make(map[string]*packageSummary), - meta: meta, - pkgs: make(map[string]*packages.Package, len(meta)), - imported: make(map[string]*types.Package, len(meta)+len(externalMeta)), - externalMeta: externalMeta, - localExports: make(map[string]string), - } - for _, path := range changed { - loader.changedPkgs[path] = struct{}{} - } - if validationOnly { - for path := range loader.changedPkgs { - loader.sourcePkgs[path] = struct{}{} - } - } else { - loader.markSourceClosure() - } - for path, fp := range meta { - if path == rootPkgPath { - continue - } - if _, changed := loader.changedPkgs[path]; changed { - continue - } - if _, ok := loader.sourcePkgs[path]; ok { - continue - } - if exportPath := localExportPathForFingerprint(wd, tags, fp); exportPath != "" && localExportExists(wd, tags, fp) { - loader.localExports[path] = exportPath - } - } - candidates := make(map[string]*packageSummary) - for path, fp := range meta { - if path == rootPkgPath { - continue - } - if _, changed := loader.changedPkgs[path]; changed { - continue - } - summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(wd, tags, path)) - if !ok || summary == nil || summary.ShapeHash != fp.ShapeHash { - continue - } - candidates[path] = summary - } - loader.summaries = filterSupportedPackageSummaries(candidates) - loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) - loader.externalFallback = importerpkg.ForCompiler(loader.fset, "gc", nil) - var root *packages.Package - if validationOnly { - for _, path := range changed { - pkg, err := loader.load(path) - if err != nil { - return nil, err - } - if root == nil { - root = pkg - } - } - } else { - var err error - root, err = loader.load(rootPkgPath) - if err != nil { - return nil, err - } - } - all := make([]*packages.Package, 0, len(loader.pkgs)) - for _, pkg := range loader.pkgs { - all = append(all, pkg) - } - sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath }) - return &localFastPathLoaded{ - fset: loader.fset, - root: root, - allPackages: all, - byPath: loader.pkgs, - fingerprints: loader.meta, - loader: loader, - }, nil -} - -func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { - if pkg := l.pkgs[pkgPath]; pkg != nil { - return pkg, nil - } - fp := l.meta[pkgPath] - if fp == nil { - return nil, fmt.Errorf("package %s not tracked as local", pkgPath) - } - files := filesFromMeta(fp.Files) - if len(files) == 0 { - return nil, fmt.Errorf("package %s has no files", pkgPath) - } - mode := parser.SkipObjectResolution - if pkgPath == l.rootPkgPath { - mode |= parser.ParseComments - } - syntax := make([]*ast.File, 0, len(files)) - parseStart := time.Now() - for _, name := range files { - file, err := l.parseFileForFastPath(name, mode, pkgPath) - if err != nil { - return nil, err - } - syntax = append(syntax, file) - } - logTiming(l.ctx, "incremental.local_fastpath.parse", parseStart) - if len(syntax) == 0 { - return nil, fmt.Errorf("package %s parsed no files", pkgPath) - } - - pkgName := syntax[0].Name.Name - info := newFastPathTypesInfo(pkgPath == l.rootPkgPath) - pkg := &packages.Package{ - Fset: l.fset, - Name: pkgName, - PkgPath: pkgPath, - GoFiles: append([]string(nil), files...), - CompiledGoFiles: append([]string(nil), files...), - Syntax: syntax, - TypesInfo: info, - Imports: make(map[string]*packages.Package), - } - l.pkgs[pkgPath] = pkg - - conf := &types.Config{ - Importer: importerFunc(func(path string) (*types.Package, error) { - return l.importPackage(path) - }), - IgnoreFuncBodies: l.shouldIgnoreFuncBodies(pkgPath), - Sizes: types.SizesFor("gc", runtime.GOARCH), - Error: func(err error) { - pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) - }, - } - typecheckStart := time.Now() - checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) - logTiming(l.ctx, "incremental.local_fastpath.typecheck", typecheckStart) - if checkedPkg != nil { - pkg.Types = checkedPkg - l.imported[pkgPath] = checkedPkg - } - if l.shouldRetryWithoutBodyStripping(pkgPath, pkg.Errors) { - return l.reloadWithoutBodyStripping(pkgPath, files, mode, pkg) - } - if err != nil && len(pkg.Errors) == 0 { - return nil, err - } - if len(pkg.Errors) > 0 { - return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) - } - - imports := packageImportPaths(syntax) - localImports := make([]string, 0, len(imports)) - for _, path := range imports { - if dep := l.pkgs[path]; dep != nil { - pkg.Imports[path] = dep - localImports = append(localImports, path) - } - } - sort.Strings(localImports) - updated := *fp - updated.LocalImports = localImports - updated.Tags = l.tags - updated.WD = filepath.Clean(l.wd) - l.meta[pkgPath] = &updated - return pkg, nil -} - -func (l *localFastPathLoader) parseFileForFastPath(name string, mode parser.Mode, pkgPath string) (*ast.File, error) { - file, err := parser.ParseFile(l.fset, name, nil, mode) - if err != nil { - return nil, err - } - if l.shouldStripFunctionBodies(pkgPath) { - stripFunctionBodies(file) - pruneImportsWithoutTopLevelUse(file) - } - return file, nil -} - -func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files []string, mode parser.Mode, pkg *packages.Package) (*packages.Package, error) { - syntax := make([]*ast.File, 0, len(files)) - parseStart := time.Now() - for _, name := range files { - file, err := parser.ParseFile(l.fset, name, nil, mode) - if err != nil { - return nil, err - } - syntax = append(syntax, file) - } - logTiming(l.ctx, "incremental.local_fastpath.parse_retry", parseStart) - pkg.Syntax = syntax - pkg.Errors = nil - pkg.TypesInfo = newFastPathTypesInfo(pkgPath == l.rootPkgPath) - conf := &types.Config{ - Importer: importerFunc(func(path string) (*types.Package, error) { - return l.importPackage(path) - }), - IgnoreFuncBodies: false, - Sizes: types.SizesFor("gc", runtime.GOARCH), - Error: func(err error) { - pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) - }, - } - typecheckStart := time.Now() - checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, pkg.TypesInfo) - logTiming(l.ctx, "incremental.local_fastpath.typecheck_retry", typecheckStart) - if checkedPkg != nil { - pkg.Types = checkedPkg - l.imported[pkgPath] = checkedPkg - } - if err != nil && len(pkg.Errors) == 0 { - return nil, err - } - if len(pkg.Errors) > 0 { - return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) - } - return pkg, nil -} - -func (l *localFastPathLoader) shouldRetryWithoutBodyStripping(pkgPath string, errs []packages.Error) bool { - if !l.shouldStripFunctionBodies(pkgPath) || len(errs) == 0 { - return false - } - for _, pkgErr := range errs { - msg := pkgErr.Msg - if strings.Contains(msg, "missing function body") || strings.Contains(msg, "func init must have a body") { - return true - } - } - return false -} - -func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) { - if l.shouldImportFromExport(path) { - pkg, err := l.importExportPackage(path) - if err == nil { - return pkg, nil - } - // Cached local export artifacts are an optimization only. If one is - // missing or corrupted, fall back to source loading for correctness. - if _, ok := l.localExports[path]; ok && l.meta[path] != nil { - delete(l.localExports, path) - pkg, loadErr := l.load(path) - if loadErr == nil { - l.refreshLocalExport(path, pkg) - return pkg.Types, nil - } - return nil, loadErr - } - return nil, err - } - if l.meta[path] != nil { - pkg, err := l.load(path) - if err != nil { - return nil, err - } - l.refreshLocalExport(path, pkg) - return pkg.Types, nil - } - if l.externalImp == nil { - return nil, fmt.Errorf("missing external importer") - } - return l.importExportPackage(path) -} - -func (l *localFastPathLoader) openExternalExport(path string) (io.ReadCloser, error) { - meta, ok := l.externalMeta[path] - if !ok || meta.ExportFile == "" { - if l.meta[path] != nil || l.isLikelyLocalImport(path) { - return nil, fmt.Errorf("missing local export data for %s", path) - } - return nil, fmt.Errorf("missing external export data for %s", path) - } - return os.Open(meta.ExportFile) -} - -func (l *localFastPathLoader) isLikelyLocalImport(path string) bool { - if l == nil || l.modulePrefix == "" { - return false - } - return path == l.modulePrefix || strings.HasPrefix(path, l.modulePrefix+"/") -} - -func moduleImportPrefix(meta map[string]*packageFingerprint) string { - if len(meta) == 0 { - return "" - } - paths := make([]string, 0, len(meta)) - for path := range meta { - paths = append(paths, path) - } - sort.Strings(paths) - prefix := strings.Split(paths[0], "/") - for _, path := range paths[1:] { - parts := strings.Split(path, "/") - n := len(prefix) - if len(parts) < n { - n = len(parts) - } - i := 0 - for i < n && prefix[i] == parts[i] { - i++ - } - prefix = prefix[:i] - if len(prefix) == 0 { - return "" - } - } - return strings.Join(prefix, "/") -} - -func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, error) { - if l == nil { - return nil, fmt.Errorf("missing local fast path loader") - } - if pkg := l.imported[path]; pkg != nil && pkg.Complete() { - return pkg, nil - } - if exportPath := l.localExports[path]; exportPath != "" { - f, err := os.Open(exportPath) - if err != nil { - return nil, err - } - defer f.Close() - pkg, err := gcexportdata.Read(f, l.fset, l.imported, path) - if err != nil { - return nil, err - } - l.imported[path] = pkg - return pkg, nil - } - if l.externalImp == nil { - return nil, fmt.Errorf("missing external importer") - } - pkg, err := l.externalImp.Import(path) - if err != nil { - if l.externalFallback != nil && strings.Contains(err.Error(), "missing external export data for ") { - pkg, fallbackErr := l.externalFallback.Import(path) - if fallbackErr == nil { - l.imported[path] = pkg - return pkg, nil - } - } - return nil, err - } - l.imported[path] = pkg - return pkg, nil -} - -func (l *localFastPathLoader) shouldImportFromExport(pkgPath string) bool { - if l == nil { - return false - } - if _, source := l.sourcePkgs[pkgPath]; source { - return false - } - if _, ok := l.localExports[pkgPath]; ok { - return true - } - _, ok := l.externalMeta[pkgPath] - return ok -} - -func (l *localFastPathLoader) refreshLocalExport(pkgPath string, pkg *packages.Package) { - if l == nil || pkg == nil || pkg.Fset == nil || pkg.Types == nil { - return - } - fp := l.meta[pkgPath] - exportPath := localExportPathForFingerprint(l.wd, l.tags, fp) - if exportPath == "" { - return - } - writeLocalPackageExportFile(exportPath, pkg.Fset, pkg.Types) - l.localExports[pkgPath] = exportPath -} - -func (l *localFastPathLoader) markSourceClosure() { - if l == nil { - return - } - reverse := make(map[string][]string) - for pkgPath, fp := range l.meta { - if fp == nil { - continue - } - for _, imp := range fp.LocalImports { - reverse[imp] = append(reverse[imp], pkgPath) - } - } - queue := make([]string, 0, len(l.changedPkgs)+1) - queue = append(queue, l.rootPkgPath) - for pkgPath := range l.changedPkgs { - queue = append(queue, pkgPath) - } - for len(queue) > 0 { - pkgPath := queue[0] - queue = queue[1:] - if _, seen := l.sourcePkgs[pkgPath]; seen { - continue - } - l.sourcePkgs[pkgPath] = struct{}{} - for _, importer := range reverse[pkgPath] { - if _, seen := l.sourcePkgs[importer]; !seen { - queue = append(queue, importer) - } - } - } -} - -func (l *localFastPathLoader) shouldStripFunctionBodies(pkgPath string) bool { - if l == nil { - return false - } - if pkgPath == l.rootPkgPath { - return false - } - _, changed := l.changedPkgs[pkgPath] - return !changed -} - -func (l *localFastPathLoader) shouldIgnoreFuncBodies(pkgPath string) bool { - return l.shouldStripFunctionBodies(pkgPath) -} - -type importerFunc func(string) (*types.Package, error) - -func (fn importerFunc) Import(path string) (*types.Package, error) { - return fn(path) -} - -func packageImportPaths(files []*ast.File) []string { - seen := make(map[string]struct{}) - var out []string - for _, file := range files { - for _, spec := range file.Imports { - path := strings.Trim(spec.Path.Value, "\"") - if path == "" { - continue - } - if _, ok := seen[path]; ok { - continue - } - seen[path] = struct{}{} - out = append(out, path) - } - } - sort.Strings(out) - return out -} - -func newFastPathTypesInfo(full bool) *types.Info { - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - } - if !full { - return info - } - info.Implicits = make(map[ast.Node]types.Object) - info.Selections = make(map[*ast.SelectorExpr]*types.Selection) - info.Scopes = make(map[ast.Node]*types.Scope) - info.Instances = make(map[*ast.Ident]types.Instance) - return info -} - -func pruneImportsWithoutTopLevelUse(file *ast.File) { - if file == nil || len(file.Imports) == 0 { - return - } - used := usedImportNames(file) - filtered := file.Imports[:0] - for _, spec := range file.Imports { - if spec == nil || spec.Path == nil { - continue - } - name := importName(spec) - if name == "_" || name == "." { - filtered = append(filtered, spec) - continue - } - if _, ok := used[name]; ok { - filtered = append(filtered, spec) - } - } - file.Imports = filtered - for _, decl := range file.Decls { - gen, ok := decl.(*ast.GenDecl) - if !ok || gen.Tok != token.IMPORT { - continue - } - specs := gen.Specs[:0] - for _, spec := range gen.Specs { - importSpec, ok := spec.(*ast.ImportSpec) - if !ok || importSpec.Path == nil { - continue - } - name := importName(importSpec) - if name == "_" || name == "." { - specs = append(specs, spec) - continue - } - if _, ok := used[name]; ok { - specs = append(specs, spec) - } - } - gen.Specs = specs - } -} - -func usedImportNames(file *ast.File) map[string]struct{} { - used := make(map[string]struct{}) - ast.Inspect(file, func(node ast.Node) bool { - sel, ok := node.(*ast.SelectorExpr) - if !ok { - return true - } - ident, ok := sel.X.(*ast.Ident) - if !ok || ident.Name == "" { - return true - } - used[ident.Name] = struct{}{} - return true - }) - return used -} - -func importName(spec *ast.ImportSpec) string { - if spec == nil || spec.Path == nil { - return "" - } - if spec.Name != nil && spec.Name.Name != "" { - return spec.Name.Name - } - path := strings.Trim(spec.Path.Value, "\"") - if path == "" { - return "" - } - if slash := strings.LastIndex(path, "/"); slash >= 0 { - path = path[slash+1:] - } - return path -} - -func generateFromTypedPackages(ctx context.Context, loaded *localFastPathLoaded, opts *GenerateOptions) ([]GenerateResult, []error) { - if loaded == nil { - return nil, []error{fmt.Errorf("missing loaded packages")} - } - root := loaded.root - if root == nil { - return nil, []error{fmt.Errorf("missing root package")} - } - if opts == nil { - opts = &GenerateOptions{} - } - pkgStart := time.Now() - res := GenerateResult{PkgPath: root.PkgPath} - outDir, err := detectOutputDir(root.GoFiles) - logTiming(ctx, "generate.package."+root.PkgPath+".output_dir", pkgStart) - if err != nil { - res.Errs = append(res.Errs, err) - return []GenerateResult{res}, nil - } - res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - - var summary *summaryProviderResolver - if loaded.loader != nil { - summary = newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage) - } - oc := newObjectCacheWithLoader(loaded.allPackages, nil, nil, summary) - g := newGen(root) - injectorStart := time.Now() - injectorFiles, errs := generateInjectors(oc, g, root) - logTiming(ctx, "generate.package."+root.PkgPath+".injectors", injectorStart) - if len(errs) > 0 { - res.Errs = errs - return []GenerateResult{res}, nil - } - copyStart := time.Now() - copyNonInjectorDecls(g, injectorFiles, root.TypesInfo) - logTiming(ctx, "generate.package."+root.PkgPath+".copy_non_injectors", copyStart) - frameStart := time.Now() - goSrc := g.frame(opts.Tags) - logTiming(ctx, "generate.package."+root.PkgPath+".frame", frameStart) - if len(opts.Header) > 0 { - goSrc = append(opts.Header, goSrc...) - } - formatStart := time.Now() - fmtSrc, err := format.Source(goSrc) - logTiming(ctx, "generate.package."+root.PkgPath+".format", formatStart) - if err != nil { - res.Errs = append(res.Errs, err) - } else { - goSrc = fmtSrc - } - res.Content = goSrc - logTiming(ctx, "generate.package."+root.PkgPath+".total", pkgStart) - return []GenerateResult{res}, nil -} - -func writeIncrementalFingerprints(snapshot *incrementalFingerprintSnapshot, wd string, tags string) { - if snapshot == nil { - return - } - for _, path := range snapshot.changed { - fp := snapshot.fingerprints[path] - if fp == nil { - continue - } - writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), fp) - } -} - -func writeIncrementalManifestFromState(wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { - if snapshot == nil || len(generated) == 0 || state == nil || state.manifest == nil { - return - } - scope := runCacheScope(wd, patterns) - manifest := &incrementalManifest{ - Version: incrementalManifestVersion, - WD: scope, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - LocalPackages: snapshotPackageFingerprints(snapshot), - ExternalPkgs: append([]externalPackageExport(nil), state.manifest.ExternalPkgs...), - ExternalFiles: append([]cacheFile(nil), state.manifest.ExternalFiles...), - ExtraFiles: extraCacheFiles(wd), - } - for _, out := range generated { - if len(out.Content) == 0 || out.OutputPath == "" { - continue - } - contentKey := incrementalContentKey(out.Content) - writeCache(contentKey, out.Content) - manifest.Outputs = append(manifest.Outputs, incrementalOutput{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - ContentKey: contentKey, - }) - } - if len(manifest.Outputs) == 0 { - return - } - selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) - writeIncrementalManifestFile(selectorKey, manifest) - writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, manifest.LocalPackages), manifest) -} - -func writeIncrementalGraphFromSnapshot(wd string, tags string, roots []string, fps map[string]*packageFingerprint) { - if len(roots) == 0 || len(fps) == 0 { - return - } - graph := &incrementalGraph{ - Version: incrementalGraphVersion, - WD: packageCacheScope(wd), - Tags: tags, - Roots: append([]string(nil), roots...), - LocalReverse: make(map[string][]string), - } - sort.Strings(graph.Roots) - for _, fp := range fps { - if fp == nil { - continue - } - for _, imp := range fp.LocalImports { - graph.LocalReverse[imp] = append(graph.LocalReverse[imp], fp.PkgPath) - } - } - for path := range graph.LocalReverse { - sort.Strings(graph.LocalReverse[path]) - } - writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) -} - -func manifestOutputPkgPaths(manifest *incrementalManifest) []string { - if manifest == nil || len(manifest.Outputs) == 0 { - return nil - } - seen := make(map[string]struct{}, len(manifest.Outputs)) - paths := make([]string, 0, len(manifest.Outputs)) - for _, out := range manifest.Outputs { - if out.PkgPath == "" { - continue - } - if _, ok := seen[out.PkgPath]; ok { - continue - } - seen[out.PkgPath] = struct{}{} - paths = append(paths, out.PkgPath) - } - sort.Strings(paths) - return paths -} - -func changedPackagePaths(previous []packageFingerprint, current []packageFingerprint) []string { - if len(current) == 0 { - return nil - } - prevByPath := make(map[string]packageFingerprint, len(previous)) - for _, fp := range previous { - prevByPath[fp.PkgPath] = fp - } - changed := make([]string, 0, len(current)) - for _, fp := range current { - prev, ok := prevByPath[fp.PkgPath] - if !ok || !incrementalFingerprintEquivalent(&prev, &fp) { - changed = append(changed, fp.PkgPath) - } - } - sort.Strings(changed) - return changed -} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index a7a1a02..e6f8cb1 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "go/ast" + "go/parser" "go/token" "go/types" "os" @@ -30,6 +31,8 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" + + "github.com/goforj/wire/internal/loader" ) // A providerSetSrc captures the source for a type provided by a ProviderSet. @@ -250,11 +253,8 @@ type Field struct { // In case of duplicate environment variables, the last one in the list // takes precedence. func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { - if IncrementalEnabled(ctx, env) { - debugf(ctx, "incremental=enabled") - } loadStart := time.Now() - pkgs, loader, errs := load(ctx, wd, env, tags, patterns) + pkgs, errs := load(ctx, wd, env, tags, patterns) logTiming(ctx, "load.packages", loadStart) if len(errs) > 0 { return nil, errs @@ -267,19 +267,13 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCache(pkgs, loader) + oc := newObjectCache(pkgs) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { // The marker function package confuses analysis. continue } - if loaded, errs := oc.ensurePackage(pkg.PkgPath); len(errs) > 0 { - ec.add(errs...) - continue - } else if loaded != nil { - pkg = loaded - } pkgStart := time.Now() scope := pkg.Types.Scope() setStart := time.Now() @@ -367,68 +361,48 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, *lazyLoader, []error) { - var session *incrementalSession +func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, []error) { fset := token.NewFileSet() - if IncrementalEnabled(ctx, env) { - session = getIncrementalSession(wd, env, tags) - fset = session.fset - debugf(ctx, "incremental session=enabled") - } - baseCfg := &packages.Config{ - Context: ctx, - Mode: baseLoadMode(ctx), - Dir: wd, + loaderMode := effectiveLoaderMode(ctx, wd, env) + parseStats := &parseFileStats{} + loadStart := time.Now() + result, err := loader.New().LoadPackages(withLoaderTiming(ctx), loader.PackageLoadRequest{ + WD: wd, Env: env, - BuildFlags: []string{"-tags=wireinject"}, + Tags: tags, + Patterns: append([]string(nil), patterns...), + Mode: packages.LoadAllSyntax, + LoaderMode: loaderMode, Fset: fset, + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + start := time.Now() + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + parseStats.record(false, time.Since(start), err, false) + return file, err + }, + }) + logTiming(ctx, "load.packages.load", loadStart) + var typedPkgs []*packages.Package + if result != nil { + typedPkgs = result.Packages + debugf(ctx, "load.packages.backend=%s", result.Backend) + if result.FallbackReason != loader.FallbackReasonNone { + debugf(ctx, "load.packages.fallback_reason=%s", result.FallbackReason) + if result.FallbackDetail != "" { + debugf(ctx, "load.packages.fallback_detail=%s", result.FallbackDetail) + } + } } - if len(tags) > 0 { - baseCfg.BuildFlags[0] += " " + tags - } - escaped := make([]string, len(patterns)) - for i := range patterns { - escaped[i] = "pattern=" + patterns[i] - } - baseLoadStart := time.Now() - pkgs, err := packages.Load(baseCfg, escaped...) - logTiming(ctx, "load.packages.base.load", baseLoadStart) - logLoadDebug(ctx, "base", baseCfg.Mode, strings.Join(patterns, ","), wd, pkgs, nil) + logLoadDebug(ctx, "typed", packages.LoadAllSyntax, strings.Join(patterns, ","), wd, typedPkgs, parseStats) if err != nil { - return nil, nil, []error{err} + return nil, []error{err} } - baseErrsStart := time.Now() - errs := collectLoadErrors(pkgs) - logTiming(ctx, "load.packages.base.collect_errors", baseErrsStart) + errs := collectLoadErrors(typedPkgs) + logTiming(ctx, "load.packages.collect_errors", loadStart) if len(errs) > 0 { - return nil, nil, errs - } - var fingerprints *incrementalFingerprintSnapshot - if !incrementalColdBootstrapEnabled(ctx) { - fingerprints = analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) - analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) - } - - baseFiles := collectPackageFiles(pkgs) - loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: tags, - fset: fset, - baseFiles: baseFiles, - session: session, - fingerprints: fingerprints, - } - return pkgs, loader, nil -} - -func baseLoadMode(ctx context.Context) packages.LoadMode { - mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports - if !incrementalColdBootstrapEnabled(ctx) { - mode |= packages.NeedDeps + return nil, errs } - return mode + return typedPkgs, nil } func collectLoadErrors(pkgs []*packages.Package) []error { @@ -481,8 +455,6 @@ type objectCache struct { packages map[string]*packages.Package objects map[objRef]objCacheEntry hasher typeutil.Hasher - loader *lazyLoader - summary *summaryProviderResolver } type objRef struct { @@ -495,11 +467,7 @@ type objCacheEntry struct { errs []error } -func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { - return newObjectCacheWithLoader(pkgs, loader, nil, nil) -} - -func newObjectCacheWithLoader(pkgs []*packages.Package, loader *lazyLoader, _ *localFastPathLoader, summary *summaryProviderResolver) *objectCache { +func newObjectCache(pkgs []*packages.Package) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } @@ -508,11 +476,6 @@ func newObjectCacheWithLoader(pkgs []*packages.Package, loader *lazyLoader, _ *l packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), hasher: typeutil.MakeHasher(), - loader: loader, - summary: summary, - } - if oc.fset == nil && loader != nil { - oc.fset = loader.fset } // Depth-first search of all dependencies to gather import path to // packages.Package mapping. go/packages guarantees that for a single @@ -546,24 +509,6 @@ func (oc *objectCache) registerPackages(pkgs []*packages.Package, replace bool) } } -func (oc *objectCache) ensurePackage(pkgPath string) (*packages.Package, []error) { - if pkg := oc.packages[pkgPath]; pkg != nil && pkg.TypesInfo != nil && len(pkg.Syntax) > 0 { - return pkg, nil - } - if oc.loader == nil { - if pkg := oc.packages[pkgPath]; pkg != nil { - return pkg, nil - } - return nil, []error{fmt.Errorf("package %q is missing type information", pkgPath)} - } - loaded, errs := oc.loader.load(pkgPath) - if len(errs) > 0 { - return nil, errs - } - oc.registerPackages(loaded, true) - return oc.packages[pkgPath], nil -} - // get converts a Go object into a Wire structure. It may return a *Provider, an // *IfaceBinding, a *ProviderSet, a *Value, or a []*Field. func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { @@ -582,14 +527,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { }() switch obj := obj.(type) { case *types.Var: - if isProviderSetType(obj.Type()) && oc.summary != nil { - if pset, ok, summaryErrs := oc.summary.Resolve(obj.Pkg().Path(), obj.Name()); ok { - return pset, summaryErrs - } - } - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } spec := oc.varDecl(obj) if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} @@ -605,9 +542,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { case *types.Func: return processFuncProvider(oc.fset, obj) default: - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 516d1d5..7c7a3b7 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -333,18 +333,18 @@ func TestAllFields(t *testing.T) { } } -func TestObjectCacheEnsurePackage(t *testing.T) { +func TestNewObjectCacheRegistersPackages(t *testing.T) { t.Parallel() fset := token.NewFileSet() pkg := &packages.Package{PkgPath: "example.com/p", Fset: fset} - oc := newObjectCache([]*packages.Package{pkg}, nil) + oc := newObjectCache([]*packages.Package{pkg}) - if got, errs := oc.ensurePackage(pkg.PkgPath); len(errs) != 0 || got != pkg { - t.Fatalf("expected existing package without errors, got pkg=%v errs=%v", got, errs) + if got := oc.packages[pkg.PkgPath]; got != pkg { + t.Fatalf("expected package to be registered, got %v", got) } - if _, errs := oc.ensurePackage("missing.example.com"); len(errs) == 0 { - t.Fatal("expected missing package error") + if got := oc.packages["missing.example.com"]; got != nil { + t.Fatalf("expected missing package to remain absent, got %v", got) } } diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go deleted file mode 100644 index f6137bc..0000000 --- a/internal/wire/parser_lazy_loader.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "go/ast" - "go/parser" - "go/token" - "path/filepath" - "time" - - "golang.org/x/tools/go/packages" -) - -type lazyLoader struct { - ctx context.Context - wd string - env []string - tags string - fset *token.FileSet - baseFiles map[string]map[string]struct{} - session *incrementalSession - fingerprints *incrementalFingerprintSnapshot - loaded map[string]*packages.Package -} - -func collectPackageFiles(pkgs []*packages.Package) map[string]map[string]struct{} { - all := collectAllPackages(pkgs) - out := make(map[string]map[string]struct{}, len(all)) - for path, pkg := range all { - if pkg == nil { - continue - } - files := make(map[string]struct{}, len(pkg.CompiledGoFiles)) - for _, name := range pkg.CompiledGoFiles { - files[filepath.Clean(name)] = struct{}{} - } - if len(files) > 0 { - out[path] = files - } - } - return out -} - -func collectAllPackages(pkgs []*packages.Package) map[string]*packages.Package { - all := make(map[string]*packages.Package) - stack := append([]*packages.Package(nil), pkgs...) - for len(stack) > 0 { - p := stack[len(stack)-1] - stack = stack[:len(stack)-1] - if p == nil || all[p.PkgPath] != nil { - continue - } - all[p.PkgPath] = p - for _, imp := range p.Imports { - stack = append(stack, imp) - } - } - return all -} - -func (ll *lazyLoader) load(pkgPath string) ([]*packages.Package, []error) { - return ll.loadWithMode(pkgPath, ll.fullMode(), "load.packages.lazy.load") -} - -func (ll *lazyLoader) fullMode() packages.LoadMode { - return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile -} - -func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timingLabel string) ([]*packages.Package, []error) { - parseStats := &parseFileStats{} - cfg := &packages.Config{ - Context: ll.ctx, - Mode: mode, - Dir: ll.wd, - Env: ll.env, - BuildFlags: []string{"-tags=wireinject"}, - Fset: ll.fset, - ParseFile: ll.parseFileFor(pkgPath, parseStats), - } - if len(ll.tags) > 0 { - cfg.BuildFlags[0] += " " + ll.tags - } - loadStart := time.Now() - pkgs, err := packages.Load(cfg, "pattern="+pkgPath) - logTiming(ll.ctx, timingLabel, loadStart) - logLoadDebug(ll.ctx, "lazy", mode, pkgPath, ll.wd, pkgs, parseStats) - if err != nil { - return nil, []error{err} - } - errs := collectLoadErrors(pkgs) - if len(errs) > 0 { - return nil, errs - } - ll.rememberPackages(pkgs) - return pkgs, nil -} - -func (ll *lazyLoader) rememberPackages(pkgs []*packages.Package) { - if ll == nil || len(pkgs) == 0 { - return - } - if ll.loaded == nil { - ll.loaded = make(map[string]*packages.Package) - } - for path, pkg := range collectAllPackages(pkgs) { - if pkg != nil { - ll.loaded[path] = pkg - } - } -} - -func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(*token.FileSet, string, []byte) (*ast.File, error) { - primary := primaryFileSet(ll.baseFiles[pkgPath]) - return func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - start := time.Now() - isPrimary := isPrimaryFile(primary, filename) - keepBodies := ll.shouldKeepDependencyBodies(filename) - if !isPrimary && !keepBodies && ll.session != nil { - if file, ok := ll.session.getParsedDep(filename, src); ok { - if stats != nil { - stats.record(false, time.Since(start), nil, true) - } - return file, nil - } - } - mode := parser.SkipObjectResolution - if isPrimary { - mode = parser.ParseComments | parser.SkipObjectResolution - } - file, err := parser.ParseFile(fset, filename, src, mode) - if stats != nil { - stats.record(isPrimary, time.Since(start), err, false) - } - if err != nil { - return nil, err - } - if primary == nil { - return file, nil - } - if isPrimary { - return file, nil - } - if keepBodies { - return file, nil - } - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - fn.Body = nil - fn.Doc = nil - } - } - if ll.session != nil { - ll.session.storeParsedDep(filename, src, file) - } - return file, nil - } -} - -func (ll *lazyLoader) shouldKeepDependencyBodies(filename string) bool { - if ll == nil || ll.fingerprints == nil || len(ll.fingerprints.touched) == 0 { - return false - } - clean := filepath.Clean(filename) - for _, pkgPath := range ll.fingerprints.touched { - files := ll.baseFiles[pkgPath] - if len(files) == 0 { - continue - } - if _, ok := files[clean]; ok { - return true - } - } - return false -} diff --git a/internal/wire/parser_lazy_loader_test.go b/internal/wire/parser_lazy_loader_test.go deleted file mode 100644 index 86b49da..0000000 --- a/internal/wire/parser_lazy_loader_test.go +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "go/ast" - "go/token" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestLazyLoaderParseFileFor(t *testing.T) { - t.Helper() - fset := token.NewFileSet() - pkgPath := "example.com/pkg" - root := t.TempDir() - primary := filepath.Join(root, "primary.go") - secondary := filepath.Join(root, "secondary.go") - ll := &lazyLoader{ - fset: fset, - baseFiles: map[string]map[string]struct{}{ - pkgPath: {filepath.Clean(primary): {}}, - }, - } - src := strings.Join([]string{ - "package pkg", - "", - "// Doc comment", - "func Foo() {", - "\tprintln(\"hi\")", - "}", - "", - }, "\n") - - parse := ll.parseFileFor(pkgPath, &parseFileStats{}) - file, err := parse(fset, primary, []byte(src)) - if err != nil { - t.Fatalf("parse primary: %v", err) - } - fn := firstFuncDecl(t, file) - if fn.Body == nil { - t.Fatal("expected primary file to keep function body") - } - if fn.Doc == nil { - t.Fatal("expected primary file to keep doc comment") - } - - file, err = parse(fset, secondary, []byte(src)) - if err != nil { - t.Fatalf("parse secondary: %v", err) - } - fn = firstFuncDecl(t, file) - if fn.Body != nil { - t.Fatal("expected secondary file to strip function body") - } - if fn.Doc != nil { - t.Fatal("expected secondary file to strip doc comment") - } -} - -func TestLazyLoaderParseFileForCachesDependencyFiles(t *testing.T) { - t.Helper() - fset := token.NewFileSet() - pkgPath := "example.com/pkg" - root := t.TempDir() - primary := filepath.Join(root, "primary.go") - secondary := filepath.Join(root, "secondary.go") - session := &incrementalSession{ - fset: fset, - parsedDeps: make(map[string]cachedParsedFile), - } - ll := &lazyLoader{ - fset: fset, - baseFiles: map[string]map[string]struct{}{ - pkgPath: {filepath.Clean(primary): {}}, - }, - session: session, - } - src := []byte(strings.Join([]string{ - "package pkg", - "", - "func Foo() {", - "\tprintln(\"hi\")", - "}", - "", - }, "\n")) - - stats1 := &parseFileStats{} - parse1 := ll.parseFileFor(pkgPath, stats1) - file1, err := parse1(fset, secondary, src) - if err != nil { - t.Fatalf("first parse: %v", err) - } - snap1 := stats1.snapshot() - if snap1.cacheHits != 0 || snap1.cacheMisses != 1 { - t.Fatalf("first parse stats = %+v, want 0 hits and 1 miss", snap1) - } - - stats2 := &parseFileStats{} - parse2 := ll.parseFileFor(pkgPath, stats2) - file2, err := parse2(fset, secondary, src) - if err != nil { - t.Fatalf("second parse: %v", err) - } - if file1 != file2 { - t.Fatal("expected cached dependency parse to reuse AST") - } - snap2 := stats2.snapshot() - if snap2.cacheHits != 1 || snap2.cacheMisses != 0 { - t.Fatalf("second parse stats = %+v, want 1 hit and 0 misses", snap2) - } -} - -func TestLoadModuleUsesWireinjectTagsForDeps(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.New)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep_inject.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package dep", - "", - "func New() *Foo {", - "\treturn &Foo{}", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - - info, errs := Load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil { - t.Fatal("Load returned nil info") - } - if len(info.Injectors) != 1 || info.Injectors[0].FuncName != "Init" { - t.Fatalf("Load returned unexpected injectors: %+v", info.Injectors) - } -} - -func firstFuncDecl(t *testing.T, file *ast.File) *ast.FuncDecl { - t.Helper() - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - return fn - } - } - t.Fatal("expected function declaration in file") - return nil -} diff --git a/internal/wire/summary_provider_resolver.go b/internal/wire/summary_provider_resolver.go deleted file mode 100644 index c93e0c5..0000000 --- a/internal/wire/summary_provider_resolver.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "fmt" - "go/token" - "go/types" - "time" - - "golang.org/x/tools/go/types/typeutil" -) - -type summaryProviderResolver struct { - ctx context.Context - fset *token.FileSet - summaries map[string]*packageSummary - importPackage func(string) (*types.Package, error) - cache map[providerSetRefSummary]*ProviderSet - resolving map[providerSetRefSummary]struct{} - supported map[string]bool -} - -func newSummaryProviderResolver(ctx context.Context, summaries map[string]*packageSummary, importPackage func(string) (*types.Package, error)) *summaryProviderResolver { - if len(summaries) == 0 || importPackage == nil { - return nil - } - r := &summaryProviderResolver{ - ctx: ctx, - fset: token.NewFileSet(), - summaries: make(map[string]*packageSummary, len(summaries)), - importPackage: importPackage, - cache: make(map[providerSetRefSummary]*ProviderSet), - resolving: make(map[providerSetRefSummary]struct{}), - supported: make(map[string]bool, len(summaries)), - } - for pkgPath, summary := range summaries { - if summary == nil { - continue - } - r.summaries[pkgPath] = summary - } - for pkgPath := range r.summaries { - r.supported[pkgPath] = r.packageSupported(pkgPath, make(map[string]struct{})) - } - return r -} - -func filterSupportedPackageSummaries(summaries map[string]*packageSummary) map[string]*packageSummary { - if len(summaries) == 0 { - return nil - } - resolver := &summaryProviderResolver{ - summaries: summaries, - supported: make(map[string]bool, len(summaries)), - } - out := make(map[string]*packageSummary) - for pkgPath, summary := range summaries { - if summary == nil { - continue - } - if resolver.packageSupported(pkgPath, make(map[string]struct{})) { - out[pkgPath] = summary - } - } - return out -} - -func (r *summaryProviderResolver) Resolve(pkgPath string, varName string) (*ProviderSet, bool, []error) { - if r == nil || !r.supported[pkgPath] { - return nil, false, nil - } - start := time.Now() - set, err := r.resolve(providerSetRefSummary{PkgPath: pkgPath, VarName: varName}) - logTiming(r.ctx, "incremental.local_fastpath.summary_resolve", start) - if err != nil { - return nil, true, []error{err} - } - return set, true, nil -} - -func (r *summaryProviderResolver) resolve(ref providerSetRefSummary) (*ProviderSet, error) { - if set := r.cache[ref]; set != nil { - return set, nil - } - if _, ok := r.resolving[ref]; ok { - return nil, fmt.Errorf("summary provider set cycle for %s.%s", ref.PkgPath, ref.VarName) - } - summary := r.summaries[ref.PkgPath] - if summary == nil { - return nil, fmt.Errorf("missing package summary for %s", ref.PkgPath) - } - setSummary, ok := r.findProviderSet(summary, ref.VarName) - if !ok { - return nil, fmt.Errorf("missing provider set summary for %s.%s", ref.PkgPath, ref.VarName) - } - r.resolving[ref] = struct{}{} - defer delete(r.resolving, ref) - - pkg, err := r.importPackage(ref.PkgPath) - if err != nil { - return nil, err - } - set := &ProviderSet{ - PkgPath: ref.PkgPath, - VarName: ref.VarName, - } - for _, provider := range setSummary.Providers { - resolved, err := r.resolveProvider(pkg, provider) - if err != nil { - return nil, err - } - set.Providers = append(set.Providers, resolved) - } - for _, imported := range setSummary.Imports { - child, err := r.resolve(imported) - if err != nil { - return nil, err - } - set.Imports = append(set.Imports, child) - } - hasher := typeutil.MakeHasher() - providerMap, srcMap, errs := buildProviderMap(r.fset, hasher, set) - if len(errs) > 0 { - return nil, errs[0] - } - if errs := verifyAcyclic(providerMap, hasher); len(errs) > 0 { - return nil, errs[0] - } - set.providerMap = providerMap - set.srcMap = srcMap - r.cache[ref] = set - return set, nil -} - -func (r *summaryProviderResolver) resolveProvider(pkg *types.Package, summary providerSummary) (*Provider, error) { - if summary.IsStruct || len(summary.Out) == 0 { - return nil, fmt.Errorf("unsupported summary provider %s.%s", summary.PkgPath, summary.Name) - } - if pkg == nil || pkg.Path() != summary.PkgPath { - var err error - pkg, err = r.importPackage(summary.PkgPath) - if err != nil { - return nil, err - } - } - obj := pkg.Scope().Lookup(summary.Name) - fn, ok := obj.(*types.Func) - if !ok { - return nil, fmt.Errorf("summary provider %s.%s missing function", summary.PkgPath, summary.Name) - } - provider, errs := processFuncProvider(r.fset, fn) - if len(errs) > 0 { - return nil, errs[0] - } - return provider, nil -} - -func (r *summaryProviderResolver) findProviderSet(summary *packageSummary, varName string) (providerSetSummary, bool) { - if summary == nil { - return providerSetSummary{}, false - } - for _, set := range summary.ProviderSets { - if set.VarName == varName { - return set, true - } - } - return providerSetSummary{}, false -} - -func (r *summaryProviderResolver) packageSupported(pkgPath string, visiting map[string]struct{}) bool { - if ok, seen := r.supported[pkgPath]; seen { - return ok - } - if _, seen := visiting[pkgPath]; seen { - return false - } - summary := r.summaries[pkgPath] - if summary == nil { - return false - } - visiting[pkgPath] = struct{}{} - defer delete(visiting, pkgPath) - for _, set := range summary.ProviderSets { - if !providerSetSummarySupported(set) { - return false - } - for _, imported := range set.Imports { - if _, ok := r.summaries[imported.PkgPath]; !ok { - return false - } - if !r.packageSupported(imported.PkgPath, visiting) { - return false - } - } - } - return true -} - -func providerSetSummarySupported(summary providerSetSummary) bool { - if len(summary.Bindings) > 0 || len(summary.Values) > 0 || len(summary.Fields) > 0 || len(summary.InputTypes) > 0 { - return false - } - for _, provider := range summary.Providers { - if provider.IsStruct { - return false - } - } - return true -} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 24ca575..09bf814 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "go/ast" + "go/format" "go/printer" "go/token" "go/types" @@ -101,75 +102,69 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } - var preloadState *incrementalPreloadState - bypassIncrementalManifest := false - coldBootstrap := false - if IncrementalEnabled(ctx, env) { - debugf(ctx, "incremental=enabled") - preloadState, _ = prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) - coldBootstrap = preloadState == nil - if coldBootstrap { - ctx = withIncrementalColdBootstrap(ctx, true) - } - if cached, ok := readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, preloadState, preloadState != nil); ok { - return cached, nil - } - if generated, ok, bypass, errs := tryIncrementalLocalFastPath(ctx, wd, env, patterns, opts, preloadState); ok || len(errs) > 0 { - return generated, errs - } else if bypass { - bypassIncrementalManifest = true - } - } - if cached, ok := readManifestResults(wd, env, patterns, opts); ok { - return cached, nil - } loadStart := time.Now() - pkgs, loader, errs := load(ctx, wd, env, opts.Tags, patterns) + pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) logTiming(ctx, "generate.load", loadStart) if len(errs) > 0 { return nil, errs } - if err := validateIncrementalTouchedPackages(ctx, wd, opts, preloadState, loader.fingerprints); err != nil { - if shouldBypassIncrementalManifestAfterFastPathError(err) { - return nil, []error{err} - } - bypassIncrementalManifest = true - } - if !bypassIncrementalManifest { - if cached, ok := readIncrementalManifestResults(ctx, wd, env, patterns, opts, pkgs, loader.fingerprints); ok { - warmPackageOutputCache(pkgs, opts, cached) - return cached, nil - } - } else { - debugf(ctx, "incremental.manifest bypass reason=fastpath_error") - ctx = withBypassPackageCache(ctx) - } generated := make([]GenerateResult, len(pkgs)) for i, pkg := range pkgs { - generated[i] = generateForPackage(ctx, pkg, loader, opts) - } - if allGeneratedOK(generated) { - if IncrementalEnabled(ctx, env) { - if coldBootstrap { - snapshot := buildIncrementalManifestSnapshotFromPackages(wd, opts.Tags, incrementalManifestPackages(pkgs, loader)) - writeIncrementalManifestWithOptions(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), snapshot, generated, false) - if snapshot != nil { - writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), snapshot.fingerprints) - writeIncrementalGraphFromSnapshot(wd, opts.Tags, manifestOutputPkgPathsFromGenerated(generated), snapshot.fingerprints) - loader.fingerprints = snapshot - } - writeIncrementalPackageSummaries(loader, pkgs) - } else { - writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), loader.fingerprints.fingerprints) - writeIncrementalPackageSummaries(loader, pkgs) - writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) - } + pkgStart := time.Now() + generated[i].PkgPath = pkg.PkgPath + dirStart := time.Now() + outDir, err := detectOutputDir(pkg.GoFiles) + logTiming(ctx, "generate.package."+pkg.PkgPath+".output_dir", dirStart) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + continue + } + generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") + g := newGen(pkg) + oc := newObjectCache([]*packages.Package{pkg}) + injectorStart := time.Now() + injectorFiles, genErrs := generateInjectors(oc, g, pkg) + logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) + if len(genErrs) > 0 { + generated[i].Errs = genErrs + continue } - writeManifest(wd, env, patterns, opts, pkgs) + copyStart := time.Now() + copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) + logTiming(ctx, "generate.package."+pkg.PkgPath+".copy_non_injectors", copyStart) + frameStart := time.Now() + goSrc := g.frame(opts.Tags) + logTiming(ctx, "generate.package."+pkg.PkgPath+".frame", frameStart) + if len(opts.Header) > 0 { + goSrc = append(opts.Header, goSrc...) + } + formatStart := time.Now() + fmtSrc, err := format.Source(goSrc) + logTiming(ctx, "generate.package."+pkg.PkgPath+".format", formatStart) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + } else { + goSrc = fmtSrc + } + generated[i].Content = goSrc + logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) } return generated, nil } +func detectOutputDir(paths []string) (string, error) { + if len(paths) == 0 { + return "", fmt.Errorf("no files to derive output directory from") + } + dir := filepath.Dir(paths[0]) + for _, p := range paths[1:] { + if dir2 := filepath.Dir(p); dir2 != dir { + return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) + } + } + return dir, nil +} + func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { if len(generated) == 0 { return nil @@ -190,46 +185,6 @@ func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { return out } -func warmPackageOutputCache(pkgs []*packages.Package, opts *GenerateOptions, generated []GenerateResult) { - if len(pkgs) == 0 || len(generated) == 0 { - return - } - byPkg := make(map[string][]byte, len(generated)) - for _, gen := range generated { - if len(gen.Content) == 0 { - continue - } - byPkg[gen.PkgPath] = gen.Content - } - for _, pkg := range pkgs { - content := byPkg[pkg.PkgPath] - if len(content) == 0 { - continue - } - key, err := cacheKeyForPackage(pkg, opts) - if err != nil || key == "" { - continue - } - writeCache(key, content) - } -} - -func incrementalManifestPackages(pkgs []*packages.Package, loader *lazyLoader) []*packages.Package { - if loader == nil || len(loader.loaded) == 0 { - return pkgs - } - out := make([]*packages.Package, 0, len(loader.loaded)) - for _, pkg := range loader.loaded { - if pkg != nil { - out = append(out, pkg) - } - } - if len(out) == 0 { - return pkgs - } - return out -} - // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index cb167aa..dc5cfda 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -26,6 +26,7 @@ import ( "io/ioutil" "os" "os/exec" + "path" "path/filepath" "strings" "testing" @@ -481,6 +482,7 @@ func isIdent(s string) bool { // "C:\GOPATH" and running on Windows, the string // "C:\GOPATH\src\foo\bar.go:15:4" would be rewritten to "foo/bar.go:x:y". func scrubError(gopath string, s string) string { + s = normalizeHeaderRelativeError(s) sb := new(strings.Builder) query := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator) for { @@ -517,7 +519,106 @@ func scrubError(gopath string, s string) string { sb.WriteString(linecol) s = s[linecolLen:] } - return sb.String() + return strings.TrimRight(sb.String(), "\n") +} + +func normalizeHeaderRelativeError(s string) string { + const headerPrefix = "-: # " + if !strings.HasPrefix(s, headerPrefix) { + return s + } + pkgAndRest := strings.TrimPrefix(s, headerPrefix) + newline := strings.IndexByte(pkgAndRest, '\n') + if newline == -1 { + return s + } + pkg := strings.TrimSpace(pkgAndRest[:newline]) + rest := strings.TrimLeft(pkgAndRest[newline+1:], "\n") + if pkg == "" || rest == "" { + return s + } + + firstLineEnd := strings.IndexByte(rest, '\n') + if firstLineEnd == -1 { + firstLineEnd = len(rest) + } + firstLine := rest[:firstLineEnd] + rewritten, ok := canonicalizeRelativeErrorPath(pkg, firstLine) + if !ok { + return s + } + return normalizeLegacyUndefinedQualifiedName(rewritten + rest[firstLineEnd:]) +} + +func canonicalizeRelativeErrorPath(pkg, line string) (string, bool) { + goExt := strings.Index(line, ".go") + if goExt == -1 { + return "", false + } + goExt += len(".go") + linecol, n := scrubLineColumn(line[goExt:]) + if n == 0 { + return "", false + } + file := line[:goExt] + suffix := line[goExt+n:] + file = strings.ReplaceAll(file, "\\", "/") + file = strings.TrimPrefix(file, "./") + file = strings.TrimPrefix(file, "/") + baseDir := path.Base(pkg) + if strings.HasPrefix(file, pkg+"/") { + return file + linecol + suffix, true + } + if strings.HasPrefix(file, baseDir+"/") { + file = pkg + "/" + strings.TrimPrefix(file, baseDir+"/") + return file + linecol + suffix, true + } + if !strings.Contains(file, "/") { + return pkg + "/" + file + linecol + suffix, true + } + return "", false +} + +func normalizeLegacyUndefinedQualifiedName(s string) string { + const marker = ": undefined: " + idx := strings.Index(s, marker) + if idx == -1 { + return s + } + qualified := s[idx+len(marker):] + end := len(qualified) + for i, r := range qualified { + if r == '\n' || r == '\r' || r == '\t' || r == ' ' { + end = i + break + } + } + qualified = qualified[:end] + dot := strings.IndexByte(qualified, '.') + if dot == -1 || dot == 0 || dot == len(qualified)-1 { + return s + } + pkgName := qualified[:dot] + name := qualified[dot+1:] + if name == "" || !isLowerIdent(name) { + return s + } + return s[:idx] + ": name " + name + " not exported by package " + pkgName +} + +func isLowerIdent(s string) bool { + if s == "" { + return false + } + for i, r := range s { + if i == 0 && !unicode.IsLower(r) { + return false + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return false + } + } + return true } func scrubLineColumn(s string) (replacement string, n int) { @@ -571,6 +672,24 @@ func filterLegacyCompilerErrors(errs []string) []string { return filtered } +func TestScrubErrorCanonicalizesHeaderRelativePath(t *testing.T) { + const gopath = "/tmp/wire_test" + got := scrubError(gopath, "-: # example.com/foo\nfoo/wire.go:26:33: not enough arguments in call to wire.InterfaceValue") + want := "example.com/foo/wire.go:x:y: not enough arguments in call to wire.InterfaceValue" + if got != want { + t.Fatalf("scrubError() = %q, want %q", got, want) + } +} + +func TestScrubErrorCanonicalizesHeaderRootRelativePath(t *testing.T) { + const gopath = "/tmp/wire_test" + got := scrubError(gopath, "-: # example.com/foo\n/wire.go:27:17: name foo not exported by package bar") + want := "example.com/foo/wire.go:x:y: name foo not exported by package bar" + if got != want { + t.Fatalf("scrubError() = %q, want %q", got, want) + } +} + type testCase struct { name string pkg string From e7cfc6390e19fee8485fbac62eb4d7eb352134bc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 05:38:30 -0500 Subject: [PATCH 09/79] feat: external loader caching --- cmd/wire/main.go | 4 + internal/loader/artifact_cache.go | 139 ++++++ internal/loader/custom.go | 268 +++++++--- internal/loader/discovery.go | 1 + internal/loader/loader_test.go | 791 ++++++++++++++++++++++++++++++ internal/loader/timing.go | 15 + internal/wire/wire_test.go | 115 +++++ 7 files changed, 1272 insertions(+), 61 deletions(-) create mode 100644 internal/loader/artifact_cache.go diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 4426ee1..f7fd92f 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -186,6 +186,10 @@ func withTiming(ctx context.Context, enabled bool) context.Context { return ctx } return wire.WithTiming(ctx, func(label string, dur time.Duration) { + if dur == 0 && strings.Contains(label, "=") { + log.Printf("timing: %s", label) + return + } log.Printf("timing: %s=%s", label, dur) }) } diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go new file mode 100644 index 0000000..d495143 --- /dev/null +++ b/internal/loader/artifact_cache.go @@ -0,0 +1,139 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "go/token" + "go/types" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/tools/go/gcexportdata" +) + +const ( + loaderArtifactEnv = "WIRE_LOADER_ARTIFACTS" + loaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" +) + +func loaderArtifactEnabled(env []string) bool { + return envValue(env, loaderArtifactEnv) == "1" +} + +func loaderArtifactDir(env []string) (string, error) { + if dir := envValue(env, loaderArtifactDirEnv); dir != "" { + return dir, nil + } + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "loader-artifacts"), nil +} + +func loaderArtifactPath(env []string, meta *packageMeta, isLocal bool) (string, error) { + dir, err := loaderArtifactDir(env) + if err != nil { + return "", err + } + key, err := loaderArtifactKey(meta, isLocal) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".bin"), nil +} + +func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-loader-artifact-v3\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(meta.ImportPath)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(meta.Name)) + sum.Write([]byte{'\n'}) + if !isLocal { + sum.Write([]byte(meta.Export)) + sum.Write([]byte{'\n'}) + if meta.Error != nil { + sum.Write([]byte(meta.Error.Err)) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil + } + for _, name := range metaFiles(meta) { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func readLoaderArtifact(path string, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return readLoaderArtifactData(data, fset, imports, pkgPath) +} + +func readLoaderArtifactData(data []byte, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { + return gcexportdata.Read(bytes.NewReader(data), fset, imports, pkgPath) +} + +func writeLoaderArtifact(path string, fset *token.FileSet, pkg *types.Package) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + return err + } + return os.WriteFile(path, out.Bytes(), 0o644) +} + +func artifactUpToDate(env []string, artifactPath string, meta *packageMeta, isLocal bool) bool { + _, err := os.Stat(artifactPath) + return err == nil +} + +func isProviderSetTypeForLoader(t types.Type) bool { + named, ok := t.(*types.Named) + if !ok { + return false + } + obj := named.Obj() + if obj == nil || obj.Pkg() == nil { + return false + } + switch obj.Pkg().Path() { + case "github.com/goforj/wire", "github.com/google/wire": + return obj.Name() == "ProviderSet" + default: + return false + } +} diff --git a/internal/loader/custom.go b/internal/loader/custom.go index ffa2d48..49fc217 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -68,35 +68,45 @@ type customValidator struct { } type customTypedGraphLoader struct { - workspace string - ctx context.Context - fset *token.FileSet - meta map[string]*packageMeta - targets map[string]struct{} - parseFile ParseFileFunc - packages map[string]*packages.Package - typesPkgs map[string]*types.Package - importer types.Importer - loading map[string]bool - stats typedLoadStats + workspace string + ctx context.Context + env []string + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + isLocalCache map[string]bool + stats typedLoadStats } type typedLoadStats struct { - read time.Duration - parse time.Duration - typecheck time.Duration - localRead time.Duration - externalRead time.Duration - localParse time.Duration - externalParse time.Duration - localTypecheck time.Duration - externalTypecheck time.Duration - filesRead int - packages int - localPackages int - externalPackages int - localFilesRead int - externalFilesRead int + read time.Duration + parse time.Duration + typecheck time.Duration + localRead time.Duration + externalRead time.Duration + localParse time.Duration + externalParse time.Duration + localTypecheck time.Duration + externalTypecheck time.Duration + filesRead int + packages int + localPackages int + externalPackages int + localFilesRead int + externalFilesRead int + artifactRead time.Duration + artifactPath time.Duration + artifactDecode time.Duration + artifactImportLink time.Duration + artifactWrite time.Duration + artifactHits int + artifactMisses int + artifactWrites int } func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { @@ -195,8 +205,8 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes } sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) return &RootLoadResult{ - Packages: roots, - Backend: ModeCustom, + Packages: roots, + Backend: ModeCustom, Discovery: discoverySnapshotForMeta(meta, req.NeedDeps), }, nil } @@ -235,16 +245,19 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz fset = token.NewFileSet() } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - fset: fset, - meta: meta, - targets: map[string]struct{}{req.Package: {}}, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: map[string]struct{}{req.Package: {}}, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + stats: typedLoadStats{}, } root, err := l.loadPackage(req.Package) if err != nil { @@ -259,6 +272,14 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz logDuration(ctx, "loader.custom.lazy.parse_files.external.cumulative", l.stats.externalParse) logDuration(ctx, "loader.custom.lazy.typecheck.local.cumulative", l.stats.localTypecheck) logDuration(ctx, "loader.custom.lazy.typecheck.external.cumulative", l.stats.externalTypecheck) + logDuration(ctx, "loader.custom.lazy.artifact_read", l.stats.artifactRead) + logDuration(ctx, "loader.custom.lazy.artifact_path", l.stats.artifactPath) + logDuration(ctx, "loader.custom.lazy.artifact_decode", l.stats.artifactDecode) + logDuration(ctx, "loader.custom.lazy.artifact_import_link", l.stats.artifactImportLink) + logDuration(ctx, "loader.custom.lazy.artifact_write", l.stats.artifactWrite) + logInt(ctx, "loader.custom.lazy.artifact_hits", l.stats.artifactHits) + logInt(ctx, "loader.custom.lazy.artifact_misses", l.stats.artifactMisses) + logInt(ctx, "loader.custom.lazy.artifact_writes", l.stats.artifactWrites) return &LazyLoadResult{ Packages: []*packages.Package{root}, Backend: ModeCustom, @@ -294,16 +315,19 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, unsupportedError{reason: "no root packages from metadata"} } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - fset: fset, - meta: meta, - targets: targets, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + stats: typedLoadStats{}, } roots := make([]*packages.Package, 0, len(targets)) for _, m := range meta { @@ -326,6 +350,14 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo logDuration(ctx, "loader.custom.typed.parse_files.external.cumulative", l.stats.externalParse) logDuration(ctx, "loader.custom.typed.typecheck.local.cumulative", l.stats.localTypecheck) logDuration(ctx, "loader.custom.typed.typecheck.external.cumulative", l.stats.externalTypecheck) + logDuration(ctx, "loader.custom.typed.artifact_read", l.stats.artifactRead) + logDuration(ctx, "loader.custom.typed.artifact_path", l.stats.artifactPath) + logDuration(ctx, "loader.custom.typed.artifact_decode", l.stats.artifactDecode) + logDuration(ctx, "loader.custom.typed.artifact_import_link", l.stats.artifactImportLink) + logDuration(ctx, "loader.custom.typed.artifact_write", l.stats.artifactWrite) + logInt(ctx, "loader.custom.typed.artifact_hits", l.stats.artifactHits) + logInt(ctx, "loader.custom.typed.artifact_misses", l.stats.artifactMisses) + logInt(ctx, "loader.custom.typed.artifact_writes", l.stats.artifactWrites) return &PackageLoadResult{ Packages: roots, Backend: ModeCustom, @@ -435,30 +467,51 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error } func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, error) { - if pkg := l.packages[path]; pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + if path == "C" { + if pkg := l.packages[path]; pkg != nil { + return pkg, nil + } + tpkg := l.typesPkgs[path] + if tpkg == nil { + tpkg = types.NewPackage("C", "C") + l.typesPkgs[path] = tpkg + } + pkg := &packages.Package{ + ID: "C", + Name: "C", + PkgPath: "C", + Fset: l.fset, + Imports: make(map[string]*packages.Package), + Types: tpkg, + } + l.packages[path] = pkg return pkg, nil } meta := l.meta[path] if meta == nil { - return nil, unsupportedError{reason: "missing lazy-load metadata"} + return nil, unsupportedError{reason: "missing lazy-load metadata for " + path} } + pkg := l.packages[path] if l.loading[path] { - if pkg := l.packages[path]; pkg != nil { + if pkg != nil { return pkg, nil } return nil, unsupportedError{reason: "lazy-load cycle"} } + if pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + return pkg, nil + } l.loading[path] = true defer delete(l.loading, path) l.stats.packages++ - isLocal := isWorkspacePackage(l.workspace, meta.Dir) + _, isTarget := l.targets[path] + isLocal := l.isLocalPackage(path, meta) if isLocal { l.stats.localPackages++ } else { l.stats.externalPackages++ } - pkg := l.packages[path] if pkg == nil { pkg = &packages.Package{ ID: meta.ImportPath, @@ -472,6 +525,28 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } l.packages[path] = pkg } + useArtifact := loaderArtifactEnabled(l.env) && !isTarget && !isLocal + if useArtifact { + if typed, ok := l.readArtifact(path, meta, isLocal); ok { + linkStart := time.Now() + for _, imp := range meta.Imports { + target := imp + if mapped := meta.ImportMap[imp]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + return nil, err + } + pkg.Imports[imp] = dep + } + l.stats.artifactImportLink += time.Since(linkStart) + pkg.Types = typed + pkg.TypesInfo = nil + pkg.Syntax = nil + return pkg, nil + } + } files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) pkg.Errors = append(pkg.Errors, parseErrs...) if len(files) == 0 { @@ -486,12 +561,11 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } tpkg := l.typesPkgs[path] - if tpkg == nil { + if tpkg == nil || tpkg.Complete() || (tpkg.Scope() != nil && len(tpkg.Scope().Names()) > 0) { tpkg = types.NewPackage(meta.ImportPath, meta.Name) l.typesPkgs[path] = tpkg } - _, isTarget := l.targets[path] - needFullState := isTarget || isWorkspacePackage(l.workspace, meta.Dir) + needFullState := isTarget || isLocal var info *types.Info if needFullState { info = &types.Info{ @@ -506,7 +580,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er var typeErrors []packages.Error cfg := &types.Config{ Sizes: types.SizesFor("gc", runtime.GOARCH), - IgnoreFuncBodies: !isWorkspacePackage(l.workspace, meta.Dir), + IgnoreFuncBodies: !isLocal, Importer: importerFunc(func(importPath string) (*types.Package, error) { if importPath == "unsafe" { return types.Unsafe, nil @@ -537,7 +611,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } checker := types.NewChecker(cfg, l.fset, tpkg, info) typecheckStart := time.Now() - if err := checker.Files(files); err != nil && len(typeErrors) == 0 { + if err := l.checkFiles(path, checker, files); err != nil && len(typeErrors) == 0 { typeErrors = append(typeErrors, toPackagesError(l.fset, err)) } typecheckDuration := time.Since(typecheckStart) @@ -555,9 +629,84 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg.Types = tpkg pkg.TypesInfo = info pkg.Errors = append(pkg.Errors, typeErrors...) + if shouldWriteArtifact(l.env, isTarget, isLocal) && len(pkg.Errors) == 0 { + _ = l.writeArtifact(meta, tpkg, isLocal) + } return pkg, nil } +func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { + defer func() { + if r := recover(); r != nil { + err = unsupportedError{reason: fmt.Sprintf("typecheck panic in %s: %v", path, r)} + } + }() + return checker.Files(files) +} + +func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, isLocal bool) (*types.Package, bool) { + start := time.Now() + pathStart := time.Now() + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + l.stats.artifactPath += time.Since(pathStart) + if err != nil { + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + var tpkg *types.Package + decodeStart := time.Now() + tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + l.stats.artifactDecode += time.Since(decodeStart) + if err != nil { + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + l.stats.artifactRead += time.Since(start) + l.stats.artifactHits++ + l.typesPkgs[path] = tpkg + return tpkg, true +} + +func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Package, isLocal bool) error { + start := time.Now() + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + if err != nil { + l.stats.artifactWrite += time.Since(start) + return err + } + if artifactUpToDate(l.env, artifactPath, meta, isLocal) { + l.stats.artifactWrite += time.Since(start) + return nil + } + writeErr := writeLoaderArtifact(artifactPath, l.fset, pkg) + l.stats.artifactWrite += time.Since(start) + if writeErr == nil { + l.stats.artifactWrites++ + } + if writeErr != nil { + return writeErr + } + return nil +} + +func shouldWriteArtifact(env []string, isTarget, isLocal bool) bool { + if !loaderArtifactEnabled(env) || isTarget || isLocal { + return false + } + return true +} + +func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *packageMeta) bool { + if local, ok := l.isLocalCache[importPath]; ok { + return local + } + local := isWorkspacePackage(l.workspace, meta.Dir) + l.isLocalCache[importPath] = local + return local +} + func (v *customValidator) importFromExport(path string) (*types.Package, error) { if typed := v.packages[path]; typed != nil && typed.Complete() { return typed, nil @@ -907,8 +1056,6 @@ func isWorkspacePackage(workspaceRoot, dir string) bool { if workspaceRoot == "" || dir == "" { return false } - workspaceRoot = canonicalLoaderPath(workspaceRoot) - dir = canonicalLoaderPath(dir) if dir == workspaceRoot { return true } @@ -970,7 +1117,6 @@ func envValue(env []string, key string) string { return "" } - func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { if meta == nil { return nil diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index a6aba46..9adbf4e 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -73,6 +73,7 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, if meta.ImportPath == "" { continue } + meta.Dir = canonicalLoaderPath(meta.Dir) for i, name := range meta.GoFiles { if !filepath.IsAbs(name) { meta.GoFiles[i] = filepath.Join(meta.Dir, name) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 0e38d99..50d8690 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -15,16 +15,22 @@ package loader import ( + "bytes" "context" + "fmt" "go/ast" "go/parser" "go/token" + "go/types" "os" "path/filepath" "sort" + "strconv" "strings" "testing" + "time" + "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" ) @@ -351,6 +357,107 @@ func TestMetaFilesFallsBackToGoFiles(t *testing.T) { } } +func TestExportDataPairings(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "lib.go", "package lib\n\ntype T int\n", 0) + if err != nil { + t.Fatalf("ParseFile() error = %v", err) + } + pkg, err := new(types.Config).Check("lib", fset, []*ast.File{file}, nil) + if err != nil { + t.Fatalf("types.Check() error = %v", err) + } + + t.Run("gcexportdata write/read direct", func(t *testing.T) { + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + got, err := gcexportdata.Read(bytes.NewReader(out.Bytes()), token.NewFileSet(), make(map[string]*types.Package), pkg.Path()) + if err != nil { + t.Fatalf("gcexportdata.Read() error = %v", err) + } + if got.Scope().Lookup("T") == nil { + t.Fatal("reimported package missing T") + } + }) + + t.Run("gcexportdata write with newreader fails", func(t *testing.T) { + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + if _, err := gcexportdata.NewReader(bytes.NewReader(out.Bytes())); err == nil { + t.Fatal("gcexportdata.NewReader() unexpectedly succeeded on direct gcexportdata.Write output") + } + }) +} + +func TestExportDataRoundTripWithImports(t *testing.T) { + fset := token.NewFileSet() + depPkg, err := new(types.Config).Check("example.com/dep", fset, []*ast.File{ + mustParseFile(t, fset, "dep.go", `package dep + +type T int +`), + }, nil) + if err != nil { + t.Fatalf("types.Check(dep) error = %v", err) + } + pkg, err := (&types.Config{ + Importer: importerFuncForTest(func(path string) (*types.Package, error) { + if path == "example.com/dep" { + return depPkg, nil + } + if path == "unsafe" { + return types.Unsafe, nil + } + return nil, nil + }), + }).Check("example.com/lib", fset, []*ast.File{ + mustParseFile(t, fset, "lib.go", `package lib + +import "example.com/dep" + +type T struct { + S dep.T +} +`), + }, nil) + if err != nil { + t.Fatalf("types.Check() error = %v", err) + } + + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + imports := make(map[string]*types.Package) + got, err := gcexportdata.Read(bytes.NewReader(out.Bytes()), token.NewFileSet(), imports, pkg.Path()) + if err != nil { + t.Fatalf("gcexportdata.Read() error = %v", err) + } + obj := got.Scope().Lookup("T") + if obj == nil { + t.Fatal("reimported package missing T") + } + named, ok := obj.Type().(*types.Named) + if !ok { + t.Fatalf("T type = %T, want *types.Named", obj.Type()) + } + field := named.Underlying().(*types.Struct).Field(0) + if field.Type().String() != "example.com/dep.T" { + t.Fatalf("field type = %q, want %q", field.Type().String(), "example.com/dep.T") + } + depImport := imports["example.com/dep"] + if depImport == nil { + t.Fatal("imports map missing dep") + } + if depImport.Scope().Lookup("T") == nil { + t.Fatal("dep import missing T after import") + } +} + func TestLoadTypedPackageGraphFallback(t *testing.T) { root := t.TempDir() writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") @@ -494,6 +601,644 @@ func TestLoadTypedPackageGraphCustomKeepsExternalPackagesLight(t *testing.T) { } } +func TestLoadTypedPackageGraphCustomExternalArtifactCache(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + rootPkg := got.Packages[0] + if rootPkg.Imports["fmt"] == nil { + t.Fatal("expected fmt import package") + } + return parseCalls + } + + first := run() + entries, err := os.ReadDir(artifactDir) + if err != nil { + t.Fatalf("ReadDir(%q) error = %v", artifactDir, err) + } + if len(entries) == 0 { + t.Fatal("expected artifact cache files after first run") + } + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() []string { + var labels []string + ctx := WithTiming(context.Background(), func(label string, _ time.Duration) { + labels = append(labels, label) + }) + l := New() + _, err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return labels + } + + _ = run() + second := run() + if !hasPrefixLabel(second, "loader.custom.lazy.artifact_hits=") { + t.Fatalf("second run labels missing artifact hit count: %v", second) + } + if !containsPositiveIntLabel(second, "loader.custom.lazy.artifact_hits=") { + t.Fatalf("second run artifact hit count was not positive: %v", second) + } +} + +func TestLoadTypedPackageGraphCustomLeafLocalArtifactCache(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactCacheWithoutProviderSets(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport \"example.com/app/leaf\"\n\nfunc Provide() string { return leaf.Provide() }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactDisabledForProviderSets(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + repoRoot, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") + writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport (\n\t\"example.com/app/leaf\"\n\t\"github.com/goforj/wire\"\n)\n\nfunc Provide() string { return leaf.Provide() }\n\nvar Set = wire.NewSet(Provide)\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second < first-1 { + t.Fatalf("second parseCalls = %d, expected provider-set package to stay near first run %d", second, first) + } + meta, err := runGoList(context.Background(), goListRequest{ + WD: root, + Env: env, + Patterns: []string{"./app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/app/dep"] + if depMeta == nil { + t.Fatal("missing metadata for example.com/app/dep") + } + hasProviderSets, ok := readLocalArtifactProviderSetFlag(env, depMeta) + if ok && !hasProviderSets { + t.Fatal("expected provider-set package metadata to record provider sets") + } +} + +func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForProviderSetImporter(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + repoRoot, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") + writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "cmd", "cmd.go"), "package cmd\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar Set = wire.NewSet(jobs.Provide)\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cmd\"\n\nfunc Init() string { return \"ok\" }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() (int, []string) { + var ( + parseCalls int + labels []string + ) + ctx := WithTiming(context.Background(), func(label string, _ time.Duration) { + labels = append(labels, label) + }) + l := New() + _, err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls, labels + } + + _, _ = run() + secondCalls, _ := run() + if secondCalls < 2 { + t.Fatalf("second parseCalls = %d, expected source load for cmd and jobs", secondCalls) + } +} + +func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForWireDeclImporter(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + repoRoot, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") + writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "cfg", "cfg.go"), "package cfg\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar V = wire.Value(jobs.Provide())\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cfg\"\n\nfunc Init() string { return \"ok\" }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second < first-1 { + t.Fatalf("second parseCalls = %d, expected source load for cfg and jobs to remain near first run %d", second, first) + } + meta, err := runGoList(context.Background(), goListRequest{ + WD: root, + Env: env, + Patterns: []string{"./app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + cfgMeta := meta["example.com/app/cfg"] + if cfgMeta == nil { + t.Fatal("missing metadata for example.com/app/cfg") + } + flags, ok := readLocalArtifactFlags(env, cfgMeta) + if ok && !flags.wireDecls { + t.Fatal("expected wire decl package metadata to record wire declarations") + } +} + +func TestLoadTypedPackageGraphCustomLocalArtifactPreservesImportedPackageName(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "models", "models.go"), "package models\n\nfunc NewRepo() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "root", "wire.go"), "package root\n\nimport \"example.com/app/models\"\n\nvar _ = models.NewRepo\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + l := New() + load := func() (*LazyLoadResult, error) { + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/root", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + + first, err := load() + if err != nil { + t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) + } + second, err := load() + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + firstRoot := collectGraph(first.Packages)["example.com/app/root"] + secondRoot := collectGraph(second.Packages)["example.com/app/root"] + if firstRoot == nil || secondRoot == nil { + t.Fatal("missing root package") + } + firstModels := firstRoot.Imports["example.com/app/models"] + secondModels := secondRoot.Imports["example.com/app/models"] + if firstModels == nil || secondModels == nil { + t.Fatal("missing imported models package") + } + if firstModels.Types == nil || secondModels.Types == nil { + t.Fatal("expected imported models package to be typed") + } + if firstModels.Types.Name() != "models" { + t.Fatalf("first imported package name = %q, want %q", firstModels.Types.Name(), "models") + } + if secondModels.Types.Name() != "models" { + t.Fatalf("second imported package name = %q, want %q", secondModels.Types.Name(), "models") + } +} + +func TestLoadTypedPackageGraphCustomRealAppDirectImporterBoundarySelectors(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + "WIRE_LOADER_LOCAL_BOUNDARY=direct_importers", + ) + load := func() (*LazyLoadResult, error) { + l := New() + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "test/wire", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + if _, err := load(); err != nil { + t.Fatalf("warm LoadTypedPackageGraph(custom) error = %v", err) + } + got, err := load() + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + graph := collectGraph(got.Packages) + rootPkg := graph["test/wire"] + if rootPkg == nil { + t.Fatal("missing root package test/wire") + } + checkSelector := func(fileSuffix, pkgIdentName string) { + t.Helper() + var targetFile *ast.File + for _, f := range rootPkg.Syntax { + name := rootPkg.Fset.File(f.Pos()).Name() + if strings.HasSuffix(name, fileSuffix) { + targetFile = f + break + } + } + if targetFile == nil { + t.Fatalf("missing syntax file %s", fileSuffix) + } + found := false + ast.Inspect(targetFile, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + pkgIdent, ok := sel.X.(*ast.Ident) + if !ok || pkgIdent.Name != pkgIdentName { + return true + } + found = true + pkgObj, ok := rootPkg.TypesInfo.ObjectOf(pkgIdent).(*types.PkgName) + if !ok || pkgObj == nil { + var importBindings []string + for _, spec := range targetFile.Imports { + obj := rootPkg.TypesInfo.Implicits[spec] + path, _ := strconv.Unquote(spec.Path.Value) + name := "" + if spec.Name != nil { + name = spec.Name.Name + } + switch typed := obj.(type) { + case *types.PkgName: + importBindings = append(importBindings, fmt.Sprintf("%s=>%s(%s)", name, typed.Imported().Path(), typed.Imported().Name())) + case nil: + importBindings = append(importBindings, fmt.Sprintf("%s=>nil[%s]", name, path)) + default: + importBindings = append(importBindings, fmt.Sprintf("%s=>%T[%s]", name, obj, path)) + } + } + importPath := "" + for _, spec := range targetFile.Imports { + path, _ := strconv.Unquote(spec.Path.Value) + name := filepath.Base(path) + if spec.Name != nil { + name = spec.Name.Name + } + if name == pkgIdentName { + importPath = path + break + } + } + var depSummary string + if importPath != "" { + if dep := graph[importPath]; dep != nil { + depSummary = fmt.Sprintf("dep=%s name=%q types=%v typeName=%q errors=%v", importPath, dep.Name, dep.Types != nil, func() string { + if dep.Types == nil { + return "" + } + return dep.Types.Name() + }(), dep.Errors) + } else { + depSummary = "dep_missing=" + importPath + } + } + t.Fatalf("%s selector lost package object for %s; imports=%s; importPath=%q; %s; root errors=%v", fileSuffix, pkgIdentName, strings.Join(importBindings, ", "), importPath, depSummary, rootPkg.Errors) + } + if rootPkg.TypesInfo.ObjectOf(sel.Sel) == nil { + t.Fatalf("%s selector lost object for %s.%s", fileSuffix, pkgIdentName, sel.Sel.Name) + } + return false + }) + if !found { + t.Fatalf("did not find selector using %s in %s", pkgIdentName, fileSuffix) + } + } + checkSelector("inject_repositories.go", "models") + checkSelector("inject_http.go", "http") + if len(rootPkg.Errors) > 0 { + var msgs []string + for _, err := range rootPkg.Errors { + msgs = append(msgs, err.Msg) + } + t.Fatalf("root package has errors under direct importer boundary: %s", strings.Join(msgs, "; ")) + } + for _, p := range []string{"test/internal/models", "test/internal/http"} { + dep := graph[p] + if dep == nil { + t.Fatalf("missing dependency package %s", p) + } + if dep.Types == nil { + t.Fatalf("dependency %s missing types", p) + } + if dep.Name == "" || dep.Types.Name() == "" { + t.Fatalf("dependency %s missing package name", p) + } + if dep.Name != dep.Types.Name() { + t.Fatalf("dependency %s package name mismatch: pkg=%q types=%q", p, dep.Name, dep.Types.Name()) + } + } + _ = fmt.Sprintf +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + load := func(env []string) (map[string]*packages.Package, error) { + l := New() + got, err := l.LoadPackages(context.Background(), PackageLoadRequest{ + WD: root, + Env: env, + Patterns: []string{"."}, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + return nil, err + } + return collectGraph(got.Packages), nil + } + + base, err := load(os.Environ()) + if err != nil { + t.Fatalf("base load error = %v", err) + } + withArtifactsEnv := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + firstArtifact, err := load(withArtifactsEnv) + if err != nil { + t.Fatalf("first artifact load error = %v", err) + } + secondArtifact, err := load(withArtifactsEnv) + if err != nil { + t.Fatalf("second artifact load error = %v", err) + } + if len(base) != len(firstArtifact) { + t.Fatalf("first artifact graph size = %d, want %d", len(firstArtifact), len(base)) + } + if len(base) != len(secondArtifact) { + var missing []string + for path := range base { + if secondArtifact[path] == nil { + missing = append(missing, path) + } + } + sort.Strings(missing) + parents := make(map[string][]string) + for parentPath, pkg := range base { + for impPath := range pkg.Imports { + if secondArtifact[impPath] == nil { + parents[impPath] = append(parents[impPath], parentPath) + } + } + } + parentSummary := make([]string, 0, 5) + for _, path := range missing { + if len(parentSummary) == 5 { + break + } + importers := append([]string(nil), parents[path]...) + sort.Strings(importers) + if len(importers) > 3 { + importers = importers[:3] + } + parentSummary = append(parentSummary, path+" <- "+strings.Join(importers, ",")) + } + if len(missing) > 20 { + missing = missing[:20] + } + secondParent := secondArtifact["github.com/shirou/gopsutil/v4/internal/common"] + secondParentImports := []string(nil) + if secondParent != nil { + secondParentImports = sortedImportPaths(secondParent.Imports) + } + internalCommonParents := append([]string(nil), parents["github.com/shirou/gopsutil/v4/internal/common"]...) + sort.Strings(internalCommonParents) + t.Fatalf("second artifact graph size = %d, want %d; missing sample=%v; parent sample=%v; gopsutil/internal/common parents=%v; gopsutil/internal/common imports on second run=%v", len(secondArtifact), len(base), missing, parentSummary, internalCommonParents, secondParentImports) + } + if compiledFileCount(base) != compiledFileCount(secondArtifact) { + t.Fatalf("second artifact compiled file count = %d, want %d", compiledFileCount(secondArtifact), compiledFileCount(base)) + } +} + func TestLoadRootGraphCustomMatchesFallback(t *testing.T) { root := t.TempDir() writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") @@ -702,6 +1447,14 @@ func collectGraph(roots []*packages.Package) map[string]*packages.Package { return out } +func compiledFileCount(pkgs map[string]*packages.Package) int { + total := 0 + for _, pkg := range pkgs { + total += len(pkg.CompiledGoFiles) + } + return total +} + func equalStrings(a, b []string) bool { if len(a) != len(b) { return false @@ -737,6 +1490,21 @@ func sortedImportPaths(m map[string]*packages.Package) []string { return out } +type importerFuncForTest func(string) (*types.Package, error) + +func (f importerFuncForTest) Import(path string) (*types.Package, error) { + return f(path) +} + +func mustParseFile(t *testing.T, fset *token.FileSet, filename, src string) *ast.File { + t.Helper() + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + if err != nil { + t.Fatalf("ParseFile(%q) error = %v", filename, err) + } + return file +} + func normalizePathForCompare(path string) string { if path == "" { return "" @@ -771,6 +1539,29 @@ func comparableErrors(errs []packages.Error) []string { return out } +func hasPrefixLabel(labels []string, prefix string) bool { + for _, label := range labels { + if strings.HasPrefix(label, prefix) { + return true + } + } + return false +} + +func containsPositiveIntLabel(labels []string, prefix string) bool { + for _, label := range labels { + if !strings.HasPrefix(label, prefix) { + continue + } + value := strings.TrimPrefix(label, prefix) + n, err := strconv.Atoi(value) + if err == nil && n > 0 { + return true + } + } + return false +} + func normalizeErrorPos(pos string) string { if pos == "" || pos == "-" { return pos diff --git a/internal/loader/timing.go b/internal/loader/timing.go index 0211f17..4b902db 100644 --- a/internal/loader/timing.go +++ b/internal/loader/timing.go @@ -2,6 +2,8 @@ package loader import ( "context" + "fmt" + "log" "time" ) @@ -39,3 +41,16 @@ func logDuration(ctx context.Context, label string, d time.Duration) { t(label, d) } } + +func logInt(ctx context.Context, label string, v int) { + if t := timing(ctx); t != nil { + t(fmt.Sprintf("%s=%d", label, v), 0) + } +} + +func debugf(ctx context.Context, format string, args ...interface{}) { + if timing(ctx) == nil { + return + } + log.Printf("timing: "+format, args...) +} diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index dc5cfda..7bd6c20 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -28,6 +28,7 @@ import ( "os/exec" "path" "path/filepath" + "sort" "strings" "testing" "unicode" @@ -220,6 +221,120 @@ func TestGenerateResultCommitWithStatus(t *testing.T) { } } +func TestGenerateRealAppArtifactParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + ctx := context.Background() + + run := func(env []string) ([]GenerateResult, []string) { + t.Helper() + gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) + errStrings := make([]string, len(errs)) + for i, err := range errs { + errStrings[i] = err.Error() + } + sort.Strings(errStrings) + return gens, errStrings + } + + baseGens, baseErrs := run(os.Environ()) + artifactEnv := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + _, warmErrs := run(artifactEnv) + if diff := cmp.Diff(baseErrs, warmErrs); diff != "" { + t.Fatalf("artifact warm-up errors mismatch (-base +warm):\n%s", diff) + } + artifactGens, artifactErrs := run(artifactEnv) + if diff := cmp.Diff(baseErrs, artifactErrs); diff != "" { + t.Fatalf("artifact errors mismatch (-base +artifact):\n%s", diff) + } + if len(baseGens) != len(artifactGens) { + t.Fatalf("generated file count = %d, want %d", len(artifactGens), len(baseGens)) + } + for i := range baseGens { + if baseGens[i].PkgPath != artifactGens[i].PkgPath { + t.Fatalf("generated package[%d] = %q, want %q", i, artifactGens[i].PkgPath, baseGens[i].PkgPath) + } + if diff := cmp.Diff(string(baseGens[i].Content), string(artifactGens[i].Content)); diff != "" { + t.Fatalf("generated content mismatch for %q (-base +artifact):\n%s", baseGens[i].PkgPath, diff) + } + baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) + artifactGenErrs := comparableGenerateErrors(artifactGens[i].Errs) + if diff := cmp.Diff(baseGenErrs, artifactGenErrs); diff != "" { + t.Fatalf("generate errs mismatch for %q (-base +artifact):\n%s", baseGens[i].PkgPath, diff) + } + } +} + +func TestGenerateRealAppSelfOnlyArtifactParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + ctx := context.Background() + + run := func(env []string) ([]GenerateResult, []string) { + t.Helper() + gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) + errStrings := make([]string, len(errs)) + for i, err := range errs { + errStrings[i] = err.Error() + } + sort.Strings(errStrings) + return gens, errStrings + } + + artifactEnv := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_LOCAL_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + _, warmErrs := run(artifactEnv) + if len(warmErrs) > 0 { + t.Fatalf("artifact warm-up errors: %v", warmErrs) + } + baseGens, baseErrs := run(artifactEnv) + + selfOnlyEnv := append(append([]string(nil), artifactEnv...), + "WIRE_LOADER_LOCAL_BOUNDARY=self_only", + ) + selfOnlyGens, selfOnlyErrs := run(selfOnlyEnv) + if diff := cmp.Diff(baseErrs, selfOnlyErrs); diff != "" { + t.Fatalf("self_only errors mismatch (-base +self_only):\n%s", diff) + } + if len(baseGens) != len(selfOnlyGens) { + t.Fatalf("generated file count = %d, want %d", len(selfOnlyGens), len(baseGens)) + } + for i := range baseGens { + if baseGens[i].PkgPath != selfOnlyGens[i].PkgPath { + t.Fatalf("generated package[%d] = %q, want %q", i, selfOnlyGens[i].PkgPath, baseGens[i].PkgPath) + } + if diff := cmp.Diff(string(baseGens[i].Content), string(selfOnlyGens[i].Content)); diff != "" { + t.Fatalf("generated content mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) + } + baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) + selfOnlyGenErrs := comparableGenerateErrors(selfOnlyGens[i].Errs) + if diff := cmp.Diff(baseGenErrs, selfOnlyGenErrs); diff != "" { + t.Fatalf("generate errs mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) + } + } +} + +func comparableGenerateErrors(errs []error) []string { + out := make([]string, len(errs)) + for i, err := range errs { + out[i] = err.Error() + } + sort.Strings(out) + return out +} + func TestZeroValue(t *testing.T) { t.Parallel() From 93796593f84ca7cd5cc127f9a1fd8bc989df0f35 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 05:44:08 -0500 Subject: [PATCH 10/79] chore: remove local caching strat --- internal/loader/loader_test.go | 457 --------------------------------- internal/wire/wire_test.go | 55 ---- 2 files changed, 512 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 50d8690..0c871e0 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -17,7 +17,6 @@ package loader import ( "bytes" "context" - "fmt" "go/ast" "go/parser" "go/token" @@ -694,462 +693,6 @@ func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing. } } -func TestLoadTypedPackageGraphCustomLeafLocalArtifactCache(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") - writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second >= first { - t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) - } -} - -func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactCacheWithoutProviderSets(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") - writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport \"example.com/app/leaf\"\n\nfunc Provide() string { return leaf.Provide() }\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second >= first { - t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) - } -} - -func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactDisabledForProviderSets(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - repoRoot, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") - writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport (\n\t\"example.com/app/leaf\"\n\t\"github.com/goforj/wire\"\n)\n\nfunc Provide() string { return leaf.Provide() }\n\nvar Set = wire.NewSet(Provide)\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second < first-1 { - t.Fatalf("second parseCalls = %d, expected provider-set package to stay near first run %d", second, first) - } - meta, err := runGoList(context.Background(), goListRequest{ - WD: root, - Env: env, - Patterns: []string{"./app"}, - NeedDeps: true, - }) - if err != nil { - t.Fatalf("runGoList() error = %v", err) - } - depMeta := meta["example.com/app/dep"] - if depMeta == nil { - t.Fatal("missing metadata for example.com/app/dep") - } - hasProviderSets, ok := readLocalArtifactProviderSetFlag(env, depMeta) - if ok && !hasProviderSets { - t.Fatal("expected provider-set package metadata to record provider sets") - } -} - -func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForProviderSetImporter(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - repoRoot, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") - writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "cmd", "cmd.go"), "package cmd\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar Set = wire.NewSet(jobs.Provide)\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cmd\"\n\nfunc Init() string { return \"ok\" }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() (int, []string) { - var ( - parseCalls int - labels []string - ) - ctx := WithTiming(context.Background(), func(label string, _ time.Duration) { - labels = append(labels, label) - }) - l := New() - _, err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls, labels - } - - _, _ = run() - secondCalls, _ := run() - if secondCalls < 2 { - t.Fatalf("second parseCalls = %d, expected source load for cmd and jobs", secondCalls) - } -} - -func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForWireDeclImporter(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - repoRoot, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") - writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "cfg", "cfg.go"), "package cfg\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar V = wire.Value(jobs.Provide())\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cfg\"\n\nfunc Init() string { return \"ok\" }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second < first-1 { - t.Fatalf("second parseCalls = %d, expected source load for cfg and jobs to remain near first run %d", second, first) - } - meta, err := runGoList(context.Background(), goListRequest{ - WD: root, - Env: env, - Patterns: []string{"./app"}, - NeedDeps: true, - }) - if err != nil { - t.Fatalf("runGoList() error = %v", err) - } - cfgMeta := meta["example.com/app/cfg"] - if cfgMeta == nil { - t.Fatal("missing metadata for example.com/app/cfg") - } - flags, ok := readLocalArtifactFlags(env, cfgMeta) - if ok && !flags.wireDecls { - t.Fatal("expected wire decl package metadata to record wire declarations") - } -} - -func TestLoadTypedPackageGraphCustomLocalArtifactPreservesImportedPackageName(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") - writeTestFile(t, filepath.Join(root, "models", "models.go"), "package models\n\nfunc NewRepo() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "root", "wire.go"), "package root\n\nimport \"example.com/app/models\"\n\nvar _ = models.NewRepo\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - l := New() - load := func() (*LazyLoadResult, error) { - return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/root", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - } - - first, err := load() - if err != nil { - t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) - } - second, err := load() - if err != nil { - t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) - } - firstRoot := collectGraph(first.Packages)["example.com/app/root"] - secondRoot := collectGraph(second.Packages)["example.com/app/root"] - if firstRoot == nil || secondRoot == nil { - t.Fatal("missing root package") - } - firstModels := firstRoot.Imports["example.com/app/models"] - secondModels := secondRoot.Imports["example.com/app/models"] - if firstModels == nil || secondModels == nil { - t.Fatal("missing imported models package") - } - if firstModels.Types == nil || secondModels.Types == nil { - t.Fatal("expected imported models package to be typed") - } - if firstModels.Types.Name() != "models" { - t.Fatalf("first imported package name = %q, want %q", firstModels.Types.Name(), "models") - } - if secondModels.Types.Name() != "models" { - t.Fatalf("second imported package name = %q, want %q", secondModels.Types.Name(), "models") - } -} - -func TestLoadTypedPackageGraphCustomRealAppDirectImporterBoundarySelectors(t *testing.T) { - root := os.Getenv("WIRE_REAL_APP_ROOT") - if root == "" { - t.Skip("WIRE_REAL_APP_ROOT not set") - } - artifactDir := t.TempDir() - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - "WIRE_LOADER_LOCAL_BOUNDARY=direct_importers", - ) - load := func() (*LazyLoadResult, error) { - l := New() - return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "test/wire", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - } - if _, err := load(); err != nil { - t.Fatalf("warm LoadTypedPackageGraph(custom) error = %v", err) - } - got, err := load() - if err != nil { - t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) - } - graph := collectGraph(got.Packages) - rootPkg := graph["test/wire"] - if rootPkg == nil { - t.Fatal("missing root package test/wire") - } - checkSelector := func(fileSuffix, pkgIdentName string) { - t.Helper() - var targetFile *ast.File - for _, f := range rootPkg.Syntax { - name := rootPkg.Fset.File(f.Pos()).Name() - if strings.HasSuffix(name, fileSuffix) { - targetFile = f - break - } - } - if targetFile == nil { - t.Fatalf("missing syntax file %s", fileSuffix) - } - found := false - ast.Inspect(targetFile, func(node ast.Node) bool { - sel, ok := node.(*ast.SelectorExpr) - if !ok { - return true - } - pkgIdent, ok := sel.X.(*ast.Ident) - if !ok || pkgIdent.Name != pkgIdentName { - return true - } - found = true - pkgObj, ok := rootPkg.TypesInfo.ObjectOf(pkgIdent).(*types.PkgName) - if !ok || pkgObj == nil { - var importBindings []string - for _, spec := range targetFile.Imports { - obj := rootPkg.TypesInfo.Implicits[spec] - path, _ := strconv.Unquote(spec.Path.Value) - name := "" - if spec.Name != nil { - name = spec.Name.Name - } - switch typed := obj.(type) { - case *types.PkgName: - importBindings = append(importBindings, fmt.Sprintf("%s=>%s(%s)", name, typed.Imported().Path(), typed.Imported().Name())) - case nil: - importBindings = append(importBindings, fmt.Sprintf("%s=>nil[%s]", name, path)) - default: - importBindings = append(importBindings, fmt.Sprintf("%s=>%T[%s]", name, obj, path)) - } - } - importPath := "" - for _, spec := range targetFile.Imports { - path, _ := strconv.Unquote(spec.Path.Value) - name := filepath.Base(path) - if spec.Name != nil { - name = spec.Name.Name - } - if name == pkgIdentName { - importPath = path - break - } - } - var depSummary string - if importPath != "" { - if dep := graph[importPath]; dep != nil { - depSummary = fmt.Sprintf("dep=%s name=%q types=%v typeName=%q errors=%v", importPath, dep.Name, dep.Types != nil, func() string { - if dep.Types == nil { - return "" - } - return dep.Types.Name() - }(), dep.Errors) - } else { - depSummary = "dep_missing=" + importPath - } - } - t.Fatalf("%s selector lost package object for %s; imports=%s; importPath=%q; %s; root errors=%v", fileSuffix, pkgIdentName, strings.Join(importBindings, ", "), importPath, depSummary, rootPkg.Errors) - } - if rootPkg.TypesInfo.ObjectOf(sel.Sel) == nil { - t.Fatalf("%s selector lost object for %s.%s", fileSuffix, pkgIdentName, sel.Sel.Name) - } - return false - }) - if !found { - t.Fatalf("did not find selector using %s in %s", pkgIdentName, fileSuffix) - } - } - checkSelector("inject_repositories.go", "models") - checkSelector("inject_http.go", "http") - if len(rootPkg.Errors) > 0 { - var msgs []string - for _, err := range rootPkg.Errors { - msgs = append(msgs, err.Msg) - } - t.Fatalf("root package has errors under direct importer boundary: %s", strings.Join(msgs, "; ")) - } - for _, p := range []string{"test/internal/models", "test/internal/http"} { - dep := graph[p] - if dep == nil { - t.Fatalf("missing dependency package %s", p) - } - if dep.Types == nil { - t.Fatalf("dependency %s missing types", p) - } - if dep.Name == "" || dep.Types.Name() == "" { - t.Fatalf("dependency %s missing package name", p) - } - if dep.Name != dep.Types.Name() { - t.Fatalf("dependency %s package name mismatch: pkg=%q types=%q", p, dep.Name, dep.Types.Name()) - } - } - _ = fmt.Sprintf -} - func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { root := os.Getenv("WIRE_REAL_APP_ROOT") if root == "" { diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 7bd6c20..23db303 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -271,61 +271,6 @@ func TestGenerateRealAppArtifactParity(t *testing.T) { } } -func TestGenerateRealAppSelfOnlyArtifactParity(t *testing.T) { - root := os.Getenv("WIRE_REAL_APP_ROOT") - if root == "" { - t.Skip("WIRE_REAL_APP_ROOT not set") - } - artifactDir := t.TempDir() - ctx := context.Background() - - run := func(env []string) ([]GenerateResult, []string) { - t.Helper() - gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) - errStrings := make([]string, len(errs)) - for i, err := range errs { - errStrings[i] = err.Error() - } - sort.Strings(errStrings) - return gens, errStrings - } - - artifactEnv := append(os.Environ(), - "WIRE_LOADER_ARTIFACTS=1", - "WIRE_LOADER_LOCAL_ARTIFACTS=1", - "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, - ) - _, warmErrs := run(artifactEnv) - if len(warmErrs) > 0 { - t.Fatalf("artifact warm-up errors: %v", warmErrs) - } - baseGens, baseErrs := run(artifactEnv) - - selfOnlyEnv := append(append([]string(nil), artifactEnv...), - "WIRE_LOADER_LOCAL_BOUNDARY=self_only", - ) - selfOnlyGens, selfOnlyErrs := run(selfOnlyEnv) - if diff := cmp.Diff(baseErrs, selfOnlyErrs); diff != "" { - t.Fatalf("self_only errors mismatch (-base +self_only):\n%s", diff) - } - if len(baseGens) != len(selfOnlyGens) { - t.Fatalf("generated file count = %d, want %d", len(selfOnlyGens), len(baseGens)) - } - for i := range baseGens { - if baseGens[i].PkgPath != selfOnlyGens[i].PkgPath { - t.Fatalf("generated package[%d] = %q, want %q", i, selfOnlyGens[i].PkgPath, baseGens[i].PkgPath) - } - if diff := cmp.Diff(string(baseGens[i].Content), string(selfOnlyGens[i].Content)); diff != "" { - t.Fatalf("generated content mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) - } - baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) - selfOnlyGenErrs := comparableGenerateErrors(selfOnlyGens[i].Errs) - if diff := cmp.Diff(baseGenErrs, selfOnlyGenErrs); diff != "" { - t.Fatalf("generate errs mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) - } - } -} - func comparableGenerateErrors(errs []error) []string { out := make([]string, len(errs)) for i, err := range errs { From 3bc75c443ed89a2855a4c7479d8412da2bccb12f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 06:06:56 -0500 Subject: [PATCH 11/79] feat: local caching from wire perspective --- internal/loader/custom.go | 48 ++- internal/semanticcache/cache.go | 143 +++++++ internal/wire/parse.go | 563 ++++++++++++++++++++++++++- internal/wire/parse_coverage_test.go | 228 +++++++++++ internal/wire/wire.go | 2 +- 5 files changed, 978 insertions(+), 6 deletions(-) create mode 100644 internal/semanticcache/cache.go diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 49fc217..d4a7f66 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -33,6 +33,8 @@ import ( "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" + + "github.com/goforj/wire/internal/semanticcache" ) type unsupportedError struct { @@ -525,7 +527,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } l.packages[path] = pkg } - useArtifact := loaderArtifactEnabled(l.env) && !isTarget && !isLocal + useArtifact := loaderArtifactEnabled(l.env) && !isTarget && (!isLocal || l.useLocalSemanticArtifact(meta)) if useArtifact { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() @@ -629,12 +631,23 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg.Types = tpkg pkg.TypesInfo = info pkg.Errors = append(pkg.Errors, typeErrors...) - if shouldWriteArtifact(l.env, isTarget, isLocal) && len(pkg.Errors) == 0 { + if shouldWriteArtifact(l.env, isTarget) && len(pkg.Errors) == 0 { _ = l.writeArtifact(meta, tpkg, isLocal) } return pkg, nil } +func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) bool { + if meta == nil { + return false + } + art, err := semanticcache.Read(l.env, meta.ImportPath, meta.Name, metaFiles(meta)) + if err != nil || art == nil { + return false + } + return art.Supported +} + func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { defer func() { if r := recover(); r != nil { @@ -650,15 +663,37 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) l.stats.artifactPath += time.Since(pathStart) if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) l.stats.artifactRead += time.Since(start) l.stats.artifactMisses++ return nil, false } + if isLocal { + preloadStart := time.Now() + for _, imp := range meta.Imports { + target := imp + if mapped := meta.ImportMap[imp]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=preload_dep_error dep=%s err=%v", path, isLocal, target, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + if dep != nil && dep.Types != nil { + l.typesPkgs[target] = dep.Types + } + } + l.stats.artifactImportLink += time.Since(preloadStart) + } var tpkg *types.Package decodeStart := time.Now() tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) l.stats.artifactDecode += time.Since(decodeStart) if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=decode_error err=%v", path, isLocal, err) l.stats.artifactRead += time.Since(start) l.stats.artifactMisses++ return nil, false @@ -673,10 +708,12 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac start := time.Now() artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) if err != nil { + debugf(l.ctx, "loader.artifact.write_skip pkg=%s local=%t reason=path_error err=%v", meta.ImportPath, isLocal, err) l.stats.artifactWrite += time.Since(start) return err } if artifactUpToDate(l.env, artifactPath, meta, isLocal) { + debugf(l.ctx, "loader.artifact.write_skip pkg=%s local=%t reason=up_to_date", meta.ImportPath, isLocal) l.stats.artifactWrite += time.Since(start) return nil } @@ -684,6 +721,9 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac l.stats.artifactWrite += time.Since(start) if writeErr == nil { l.stats.artifactWrites++ + debugf(l.ctx, "loader.artifact.write_ok pkg=%s local=%t path=%s", meta.ImportPath, isLocal, artifactPath) + } else { + debugf(l.ctx, "loader.artifact.write_fail pkg=%s local=%t err=%v", meta.ImportPath, isLocal, writeErr) } if writeErr != nil { return writeErr @@ -691,8 +731,8 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac return nil } -func shouldWriteArtifact(env []string, isTarget, isLocal bool) bool { - if !loaderArtifactEnabled(env) || isTarget || isLocal { +func shouldWriteArtifact(env []string, isTarget bool) bool { + if !loaderArtifactEnabled(env) || isTarget { return false } return true diff --git a/internal/semanticcache/cache.go b/internal/semanticcache/cache.go new file mode 100644 index 0000000..4442415 --- /dev/null +++ b/internal/semanticcache/cache.go @@ -0,0 +1,143 @@ +package semanticcache + +import ( + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "os" + "path/filepath" + "runtime" + "strconv" +) + +const dirEnv = "WIRE_SEMANTIC_CACHE_DIR" + +type PackageArtifact struct { + Version int + PackagePath string + PackageName string + HasProviderSetVars bool + Supported bool + Vars map[string]ProviderSetArtifact +} + +type ProviderSetArtifact struct { + Items []ProviderSetItemArtifact +} + +type ProviderSetItemArtifact struct { + Kind string + ImportPath string + Name string + Type TypeRef + Type2 TypeRef + FieldNames []string + AllFields bool +} + +type TypeRef struct { + ImportPath string + Name string + Pointer int +} + +func ArtifactPath(env []string, importPath, packageName string, files []string) (string, error) { + dir, err := artifactDir(env) + if err != nil { + return "", err + } + key, err := artifactKey(importPath, packageName, files) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".gob"), nil +} + +func Read(env []string, importPath, packageName string, files []string) (*PackageArtifact, error) { + path, err := ArtifactPath(env, importPath, packageName, files) + if err != nil { + return nil, err + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var art PackageArtifact + if err := gob.NewDecoder(f).Decode(&art); err != nil { + return nil, err + } + return &art, nil +} + +func Write(env []string, importPath, packageName string, files []string, art *PackageArtifact) error { + path, err := ArtifactPath(env, importPath, packageName, files) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(art) +} + +func Exists(env []string, importPath, packageName string, files []string) bool { + path, err := ArtifactPath(env, importPath, packageName, files) + if err != nil { + return false + } + _, err = os.Stat(path) + return err == nil +} + +func artifactDir(env []string) (string, error) { + for i := len(env) - 1; i >= 0; i-- { + key, val, ok := splitEnv(env[i]) + if ok && key == dirEnv && val != "" { + return val, nil + } + } + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "semantic-artifacts"), nil +} + +func artifactKey(importPath, packageName string, files []string) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-semantic-artifact-v1\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(importPath)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(packageName)) + sum.Write([]byte{'\n'}) + for _, name := range files { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func splitEnv(kv string) (string, string, bool) { + for i := 0; i < len(kv); i++ { + if kv[i] == '=' { + return kv[:i], kv[i+1:], true + } + } + return "", "", false +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index e6f8cb1..a825b4b 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -33,6 +33,7 @@ import ( "golang.org/x/tools/go/types/typeutil" "github.com/goforj/wire/internal/loader" + "github.com/goforj/wire/internal/semanticcache" ) // A providerSetSrc captures the source for a type provided by a ProviderSet. @@ -267,7 +268,7 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCache(pkgs) + oc := newObjectCacheWithEnv(pkgs, env) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { @@ -452,8 +453,10 @@ func (in *Injector) String() string { // objectCache is a lazily evaluated mapping of objects to Wire structures. type objectCache struct { fset *token.FileSet + env []string packages map[string]*packages.Package objects map[objRef]objCacheEntry + semantic map[string]*semanticcache.PackageArtifact hasher typeutil.Hasher } @@ -468,13 +471,19 @@ type objCacheEntry struct { } func newObjectCache(pkgs []*packages.Package) *objectCache { + return newObjectCacheWithEnv(pkgs, nil) +} + +func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } oc := &objectCache{ fset: pkgs[0].Fset, + env: append([]string(nil), env...), packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), + semantic: make(map[string]*semanticcache.PackageArtifact), hasher: typeutil.MakeHasher(), } // Depth-first search of all dependencies to gather import path to @@ -482,6 +491,7 @@ func newObjectCache(pkgs []*packages.Package) *objectCache { // call to packages.Load and an import path X, there will exist only // one *packages.Package value with PkgPath X. oc.registerPackages(pkgs, false) + oc.recordSemanticArtifacts() return oc } @@ -528,6 +538,11 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) + if spec == nil && isProviderSetType(obj.Type()) { + if pset, ok, errs := oc.semanticProviderSet(obj); ok { + return pset, errs + } + } if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } @@ -546,6 +561,552 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } +func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { + pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return nil, false, nil + } + art := oc.semanticArtifact(pkg) + if art == nil || !art.Supported { + return nil, false, nil + } + setArt, ok := art.Vars[obj.Name()] + if !ok { + return nil, false, nil + } + pset := &ProviderSet{ + Pos: obj.Pos(), + PkgPath: obj.Pkg().Path(), + VarName: obj.Name(), + } + ec := new(errorCollector) + for _, item := range setArt.Items { + switch item.Kind { + case "func": + providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Providers = append(pset.Providers, providerObj) + case "set": + setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Imports = append(pset.Imports, setObj) + case "bind": + binding, errs := oc.semanticBinding(item) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Bindings = append(pset.Bindings, binding) + case "struct": + providerObj, errs := oc.semanticStructProvider(item) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Providers = append(pset.Providers, providerObj) + case "fields": + fields, errs := oc.semanticFields(item) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Fields = append(pset.Fields, fields...) + default: + ec.add(fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)) + } + } + if len(ec.errors) > 0 { + return nil, true, ec.errors + } + var errs []error + pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) + if len(errs) > 0 { + return nil, true, errs + } + if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { + return nil, true, errs + } + return pset, true, nil +} + +func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + } + obj := pkg.Types.Scope().Lookup(name) + fn, ok := obj.(*types.Func) + if !ok || fn == nil { + return nil, []error{fmt.Errorf("%s.%s is not a provider function", importPath, name)} + } + return processFuncProvider(oc.fset, fn) +} + +func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + } + obj := pkg.Types.Scope().Lookup(name) + v, ok := obj.(*types.Var) + if !ok || v == nil || !isProviderSetType(v.Type()) { + return nil, []error{fmt.Errorf("%s.%s is not a provider set", importPath, name)} + } + item, errs := oc.get(v) + if len(errs) > 0 { + return nil, errs + } + pset, ok := item.(*ProviderSet) + if !ok || pset == nil { + return nil, []error{fmt.Errorf("%s.%s did not resolve to a provider set", importPath, name)} + } + return pset, nil +} + +func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifact) (*IfaceBinding, []error) { + iface, err := oc.semanticType(item.Type) + if err != nil { + return nil, []error{err} + } + provided, err := oc.semanticType(item.Type2) + if err != nil { + return nil, []error{err} + } + return &IfaceBinding{ + Iface: iface, + Provided: provided, + }, nil +} + +func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItemArtifact) (*Provider, []error) { + typeName, err := oc.semanticTypeName(item.Type) + if err != nil { + return nil, []error{err} + } + out := typeName.Type() + st, ok := out.Underlying().(*types.Struct) + if !ok { + return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} + } + provider := &Provider{ + Pkg: typeName.Pkg(), + Name: typeName.Name(), + Pos: typeName.Pos(), + IsStruct: true, + Out: []types.Type{out, types.NewPointer(out)}, + } + if item.AllFields { + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + f := st.Field(i) + provider.Args = append(provider.Args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + } else { + for _, fieldName := range item.FieldNames { + f := lookupStructField(st, fieldName) + if f == nil { + return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} + } + provider.Args = append(provider.Args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + } + return provider, nil +} + +func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact) ([]*Field, []error) { + parent, err := oc.semanticType(item.Type) + if err != nil { + return nil, []error{err} + } + structType, ptrToField, err := structFromFieldsParent(parent) + if err != nil { + return nil, []error{err} + } + fields := make([]*Field, 0, len(item.FieldNames)) + for _, fieldName := range item.FieldNames { + v := lookupStructField(structType, fieldName) + if v == nil { + return nil, []error{fmt.Errorf("field %q not found", fieldName)} + } + out := []types.Type{v.Type()} + if ptrToField { + out = append(out, types.NewPointer(v.Type())) + } + fields = append(fields, &Field{ + Parent: parent, + Name: v.Name(), + Pkg: v.Pkg(), + Pos: v.Pos(), + Out: out, + }) + } + return fields, nil +} + +func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { + typeName, err := oc.semanticTypeName(ref) + if err != nil { + return nil, err + } + var typ types.Type = typeName.Type() + for i := 0; i < ref.Pointer; i++ { + typ = types.NewPointer(typ) + } + return typ, nil +} + +func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { + pkg := oc.packages[ref.ImportPath] + if pkg == nil || pkg.Types == nil { + return nil, fmt.Errorf("missing typed package for %s", ref.ImportPath) + } + obj := pkg.Types.Scope().Lookup(ref.Name) + typeName, ok := obj.(*types.TypeName) + if !ok || typeName == nil { + return nil, fmt.Errorf("%s.%s is not a named type", ref.ImportPath, ref.Name) + } + return typeName, nil +} + +func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { + ptr, ok := parent.(*types.Pointer) + if !ok { + return nil, false, fmt.Errorf("parent type %s is not a pointer", types.TypeString(parent, nil)) + } + switch t := ptr.Elem().Underlying().(type) { + case *types.Pointer: + st, ok := t.Elem().Underlying().(*types.Struct) + if !ok { + return nil, false, fmt.Errorf("parent type %s does not point to a struct", types.TypeString(parent, nil)) + } + return st, true, nil + case *types.Struct: + return t, false, nil + default: + return nil, false, fmt.Errorf("parent type %s does not point to a struct", types.TypeString(parent, nil)) + } +} + +func lookupStructField(st *types.Struct, name string) *types.Var { + for i := 0; i < st.NumFields(); i++ { + if st.Field(i).Name() == name { + return st.Field(i) + } + } + return nil +} + +func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { + if pkg == nil { + return nil + } + if art, ok := oc.semantic[pkg.PkgPath]; ok { + return art + } + if len(oc.env) == 0 || len(pkg.GoFiles) == 0 { + return nil + } + art, err := semanticcache.Read(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles) + if err != nil { + return nil + } + oc.semantic[pkg.PkgPath] = art + return art +} + +func (oc *objectCache) recordSemanticArtifacts() { + if len(oc.env) == 0 { + return + } + for _, pkg := range oc.packages { + if pkg == nil || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil || len(pkg.GoFiles) == 0 { + continue + } + art := buildSemanticArtifact(pkg) + if art == nil { + continue + } + oc.semantic[pkg.PkgPath] = art + _ = semanticcache.Write(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles, art) + } +} + +func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { + if pkg == nil || pkg.Types == nil || pkg.TypesInfo == nil { + return nil + } + art := &semanticcache.PackageArtifact{ + Version: 1, + PackagePath: pkg.PkgPath, + PackageName: pkg.Name, + Supported: true, + Vars: make(map[string]semanticcache.ProviderSetArtifact), + } + scope := pkg.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + v, ok := obj.(*types.Var) + if !ok || !isProviderSetType(v.Type()) { + continue + } + art.HasProviderSetVars = true + spec := semanticVarDecl(pkg, v) + if spec == nil || len(spec.Values) == 0 { + art.Supported = false + continue + } + var idx int + found := false + for i := range spec.Names { + if spec.Names[i].Name == v.Name() { + idx = i + found = true + break + } + } + if !found || idx >= len(spec.Values) { + art.Supported = false + continue + } + setArt, ok := summarizeSemanticProviderSet(pkg.TypesInfo, spec.Values[idx], pkg.PkgPath) + if !ok { + art.Supported = false + continue + } + art.Vars[v.Name()] = setArt + } + return art +} + +func summarizeSemanticProviderSet(info *types.Info, expr ast.Expr, pkgPath string) (semanticcache.ProviderSetArtifact, bool) { + call, ok := astutil.Unparen(expr).(*ast.CallExpr) + if !ok { + return semanticcache.ProviderSetArtifact{}, false + } + fnObj := qualifiedIdentObject(info, call.Fun) + if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) || fnObj.Name() != "NewSet" { + return semanticcache.ProviderSetArtifact{}, false + } + setArt := semanticcache.ProviderSetArtifact{ + Items: make([]semanticcache.ProviderSetItemArtifact, 0, len(call.Args)), + } + for _, arg := range call.Args { + items, ok := summarizeSemanticProviderSetArg(info, astutil.Unparen(arg), pkgPath) + if !ok { + return semanticcache.ProviderSetArtifact{}, false + } + setArt.Items = append(setArt.Items, items...) + } + return setArt, true +} + +func summarizeSemanticProviderSetArg(info *types.Info, expr ast.Expr, pkgPath string) ([]semanticcache.ProviderSetItemArtifact, bool) { + if obj := qualifiedIdentObject(info, expr); obj != nil && obj.Pkg() != nil && obj.Exported() { + item := semanticcache.ProviderSetItemArtifact{ + ImportPath: obj.Pkg().Path(), + Name: obj.Name(), + } + switch typed := obj.(type) { + case *types.Func: + item.Kind = "func" + case *types.Var: + if !isProviderSetType(typed.Type()) { + return nil, false + } + item.Kind = "set" + default: + return nil, false + } + if item.ImportPath == "" { + item.ImportPath = pkgPath + } + return []semanticcache.ProviderSetItemArtifact{item}, true + } + call, ok := expr.(*ast.CallExpr) + if !ok { + return nil, false + } + fnObj := qualifiedIdentObject(info, call.Fun) + if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) { + return nil, false + } + switch fnObj.Name() { + case "NewSet": + nested, ok := summarizeSemanticProviderSet(info, call, pkgPath) + if !ok { + return nil, false + } + return nested.Items, true + case "Bind": + item, ok := summarizeSemanticBind(info, call) + if !ok { + return nil, false + } + return []semanticcache.ProviderSetItemArtifact{item}, true + case "Struct": + item, ok := summarizeSemanticStruct(info, call) + if !ok { + return nil, false + } + return []semanticcache.ProviderSetItemArtifact{item}, true + case "FieldsOf": + item, ok := summarizeSemanticFields(info, call) + if !ok { + return nil, false + } + return []semanticcache.ProviderSetItemArtifact{item}, true + default: + return nil, false + } +} + +func summarizeSemanticBind(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { + if len(call.Args) != 2 { + return semanticcache.ProviderSetItemArtifact{}, false + } + iface, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) + if !ok || iface.Pointer == 0 { + return semanticcache.ProviderSetItemArtifact{}, false + } + iface.Pointer-- + providedType := info.TypeOf(call.Args[1]) + if bindShouldUsePointer(info, call) { + ptr, ok := providedType.(*types.Pointer) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + providedType = ptr.Elem() + } + provided, ok := summarizeTypeRef(providedType) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + return semanticcache.ProviderSetItemArtifact{ + Kind: "bind", + Type: iface, + Type2: provided, + }, true +} + +func summarizeSemanticStruct(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { + if len(call.Args) < 1 { + return semanticcache.ProviderSetItemArtifact{}, false + } + structType := info.TypeOf(call.Args[0]) + ptr, ok := structType.(*types.Pointer) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + ref, ok := summarizeTypeRef(ptr.Elem()) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + item := semanticcache.ProviderSetItemArtifact{ + Kind: "struct", + Type: ref, + } + if allFields(call) { + item.AllFields = true + return item, true + } + item.FieldNames = make([]string, 0, len(call.Args)-1) + for i := 1; i < len(call.Args); i++ { + lit, ok := call.Args[i].(*ast.BasicLit) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + fieldName, err := strconv.Unquote(lit.Value) + if err != nil { + return semanticcache.ProviderSetItemArtifact{}, false + } + item.FieldNames = append(item.FieldNames, fieldName) + } + return item, true +} + +func summarizeSemanticFields(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { + if len(call.Args) < 2 { + return semanticcache.ProviderSetItemArtifact{}, false + } + parent, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + item := semanticcache.ProviderSetItemArtifact{ + Kind: "fields", + Type: parent, + FieldNames: make([]string, 0, len(call.Args)-1), + } + for i := 1; i < len(call.Args); i++ { + lit, ok := call.Args[i].(*ast.BasicLit) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + fieldName, err := strconv.Unquote(lit.Value) + if err != nil { + return semanticcache.ProviderSetItemArtifact{}, false + } + item.FieldNames = append(item.FieldNames, fieldName) + } + return item, true +} + +func summarizeTypeRef(typ types.Type) (semanticcache.TypeRef, bool) { + ref := semanticcache.TypeRef{} + for { + ptr, ok := typ.(*types.Pointer) + if !ok { + break + } + ref.Pointer++ + typ = ptr.Elem() + } + named, ok := typ.(*types.Named) + if !ok { + return semanticcache.TypeRef{}, false + } + obj := named.Obj() + if obj == nil || obj.Pkg() == nil { + return semanticcache.TypeRef{}, false + } + ref.ImportPath = obj.Pkg().Path() + ref.Name = obj.Name() + return ref, true +} + +func semanticVarDecl(pkg *packages.Package, obj *types.Var) *ast.ValueSpec { + pos := obj.Pos() + for _, f := range pkg.Syntax { + tokenFile := pkg.Fset.File(f.Pos()) + if tokenFile == nil { + continue + } + if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { + path, _ := astutil.PathEnclosingInterval(f, pos, pos) + for _, node := range path { + if spec, ok := node.(*ast.ValueSpec); ok { + return spec + } + } + } + } + return nil +} + // varDecl finds the declaration that defines the given variable. func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 7c7a3b7..3a23d18 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -22,6 +22,8 @@ import ( "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" + + "github.com/goforj/wire/internal/semanticcache" ) func TestFindInjectorBuildVariants(t *testing.T) { @@ -220,6 +222,232 @@ func TestProcessStructProviderDuplicateFields(t *testing.T) { } } +func TestSummarizeSemanticProviderSet(t *testing.T) { + t.Parallel() + + info := &types.Info{ + Uses: make(map[*ast.Ident]types.Object), + } + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireIdent := ast.NewIdent("wire") + newSetIdent := ast.NewIdent("NewSet") + info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) + info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) + + depPkg := types.NewPackage("example.com/dep", "dep") + fnIdent := ast.NewIdent("NewMessage") + info.Uses[fnIdent] = types.NewFunc(token.NoPos, depPkg, "NewMessage", types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depPkg, "", types.Typ[types.String])), false)) + + call := &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, + Args: []ast.Expr{ + fnIdent, + }, + } + got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") + if !ok { + t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") + } + if len(got.Items) != 1 { + t.Fatalf("items len = %d, want 1", len(got.Items)) + } + if got.Items[0].Kind != "func" || got.Items[0].ImportPath != "example.com/dep" || got.Items[0].Name != "NewMessage" { + t.Fatalf("unexpected item: %+v", got.Items[0]) + } +} + +func TestSummarizeSemanticProviderSetTypeOnlyForms(t *testing.T) { + t.Parallel() + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Uses: make(map[*ast.Ident]types.Object), + } + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireIdent := ast.NewIdent("wire") + info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) + + appPkg := types.NewPackage("example.com/app", "app") + fooObj := types.NewTypeName(token.NoPos, appPkg, "Foo", nil) + fooNamed := types.NewNamed(fooObj, types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, appPkg, "Message", types.Typ[types.String]), + }, []string{""}), nil) + fooIfaceObj := types.NewTypeName(token.NoPos, appPkg, "Fooer", nil) + fooIface := types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) + + newSetIdent := ast.NewIdent("NewSet") + bindIdent := ast.NewIdent("Bind") + structIdent := ast.NewIdent("Struct") + fieldsIdent := ast.NewIdent("FieldsOf") + info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) + info.Uses[bindIdent] = types.NewFunc(token.NoPos, wirePkg, "Bind", nil) + info.Uses[structIdent] = types.NewFunc(token.NoPos, wirePkg, "Struct", nil) + info.Uses[fieldsIdent] = types.NewFunc(token.NoPos, wirePkg, "FieldsOf", nil) + + newFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Foo")}} + info.Types[newFooCall] = types.TypeAndValue{Type: types.NewPointer(fooNamed)} + newFooIfaceCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Fooer")}} + info.Types[newFooIfaceCall] = types.TypeAndValue{Type: types.NewPointer(fooIface)} + ptrToPtrFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("FooPtr")}} + info.Types[ptrToPtrFooCall] = types.TypeAndValue{Type: types.NewPointer(types.NewPointer(fooNamed))} + + call := &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, + Args: []ast.Expr{ + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: bindIdent}, Args: []ast.Expr{newFooIfaceCall, newFooCall}}, + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"*\""}}}, + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, + }, + } + got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") + if !ok { + t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") + } + if len(got.Items) != 3 { + t.Fatalf("items len = %d, want 3", len(got.Items)) + } + if got.Items[0].Kind != "bind" || got.Items[1].Kind != "struct" || got.Items[2].Kind != "fields" { + t.Fatalf("unexpected kinds: %+v", got.Items) + } +} + +func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) + wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) + + depTypes := types.NewPackage("example.com/dep", "dep") + msgFnSig := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depTypes, "", types.Typ[types.String])), false) + msgFn := types.NewFunc(token.NoPos, depTypes, "NewMessage", msgFnSig) + setVar := types.NewVar(token.NoPos, depTypes, "Set", wireNamed) + depTypes.Scope().Insert(msgFn) + depTypes.Scope().Insert(setVar) + + depPkg := &packages.Package{ + Name: "dep", + PkgPath: depTypes.Path(), + Types: depTypes, + Fset: fset, + Imports: make(map[string]*packages.Package), + } + oc := &objectCache{ + fset: fset, + packages: map[string]*packages.Package{depPkg.PkgPath: depPkg}, + objects: make(map[objRef]objCacheEntry), + semantic: map[string]*semanticcache.PackageArtifact{ + depPkg.PkgPath: { + Version: 1, + PackagePath: depPkg.PkgPath, + PackageName: depPkg.Name, + Supported: true, + Vars: map[string]semanticcache.ProviderSetArtifact{ + "Set": { + Items: []semanticcache.ProviderSetItemArtifact{ + {Kind: "func", ImportPath: depPkg.PkgPath, Name: "NewMessage"}, + }, + }, + }, + }, + }, + hasher: typeutil.MakeHasher(), + } + item, errs := oc.get(setVar) + if len(errs) > 0 { + t.Fatalf("oc.get(Set) errs = %v", errs) + } + pset, ok := item.(*ProviderSet) + if !ok || pset == nil { + t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) + } + if len(pset.Providers) != 1 || pset.Providers[0].Name != "NewMessage" { + t.Fatalf("unexpected providers: %+v", pset.Providers) + } +} + +func TestObjectCacheSemanticProviderSetFallbackTypeOnlyForms(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) + wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) + + appTypes := types.NewPackage("example.com/app", "app") + fooIfaceObj := types.NewTypeName(token.NoPos, appTypes, "Fooer", nil) + _ = types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) + fooObj := types.NewTypeName(token.NoPos, appTypes, "Foo", nil) + _ = types.NewNamed(fooObj, types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, appTypes, "Message", types.Typ[types.String]), + }, []string{""}), nil) + setVar := types.NewVar(token.NoPos, appTypes, "Set", wireNamed) + appTypes.Scope().Insert(fooIfaceObj) + appTypes.Scope().Insert(fooObj) + appTypes.Scope().Insert(setVar) + + appPkg := &packages.Package{ + Name: "app", + PkgPath: appTypes.Path(), + Types: appTypes, + Fset: fset, + Imports: make(map[string]*packages.Package), + } + oc := &objectCache{ + fset: fset, + packages: map[string]*packages.Package{appPkg.PkgPath: appPkg}, + objects: make(map[objRef]objCacheEntry), + semantic: map[string]*semanticcache.PackageArtifact{ + appPkg.PkgPath: { + Version: 1, + PackagePath: appPkg.PkgPath, + PackageName: appPkg.Name, + Supported: true, + Vars: map[string]semanticcache.ProviderSetArtifact{ + "Set": { + Items: []semanticcache.ProviderSetItemArtifact{ + { + Kind: "bind", + Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Fooer"}, + Type2: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, + }, + { + Kind: "struct", + Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, + AllFields: true, + }, + { + Kind: "fields", + Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo", Pointer: 2}, + FieldNames: []string{"Message"}, + }, + }, + }, + }, + }, + }, + hasher: typeutil.MakeHasher(), + } + item, errs := oc.get(setVar) + if len(errs) > 0 { + t.Fatalf("oc.get(Set) errs = %v", errs) + } + pset, ok := item.(*ProviderSet) + if !ok || pset == nil { + t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) + } + if len(pset.Bindings) != 1 { + t.Fatalf("bindings len = %d, want 1", len(pset.Bindings)) + } + if len(pset.Providers) != 1 || !pset.Providers[0].IsStruct { + t.Fatalf("providers = %+v, want one struct provider", pset.Providers) + } + if len(pset.Fields) != 1 || pset.Fields[0].Name != "Message" { + t.Fatalf("fields = %+v, want Message field", pset.Fields) + } +} + func TestProcessFuncProviderErrors(t *testing.T) { t.Parallel() diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 09bf814..99062ac 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -121,7 +121,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") g := newGen(pkg) - oc := newObjectCache([]*packages.Package{pkg}) + oc := newObjectCacheWithEnv([]*packages.Package{pkg}, env) injectorStart := time.Now() injectorFiles, genErrs := generateInjectors(oc, g, pkg) logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) From a8c2a0212bd481a18ba25525800407c903af9679 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 06:40:37 -0500 Subject: [PATCH 12/79] feat: go dep cache --- internal/loader/custom.go | 169 ++++++++++-- internal/loader/discovery.go | 4 + internal/loader/discovery_cache.go | 331 ++++++++++++++++++++++++ internal/loader/discovery_cache_test.go | 126 +++++++++ internal/wire/profile_bench_test.go | 32 +++ 5 files changed, 647 insertions(+), 15 deletions(-) create mode 100644 internal/loader/discovery_cache.go create mode 100644 internal/loader/discovery_cache_test.go create mode 100644 internal/wire/profile_bench_test.go diff --git a/internal/loader/custom.go b/internal/loader/custom.go index d4a7f66..c938acc 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -29,6 +29,7 @@ import ( "runtime/pprof" "sort" "strings" + "sync" "time" "golang.org/x/tools/go/gcexportdata" @@ -82,9 +83,18 @@ type customTypedGraphLoader struct { importer types.Importer loading map[string]bool isLocalCache map[string]bool + localSemanticOK map[string]bool + artifactPrefetch map[string]artifactPrefetchEntry stats typedLoadStats } +type artifactPrefetchEntry struct { + path string + data []byte + err error + ok bool +} + type typedLoadStats struct { read time.Duration parse time.Duration @@ -106,6 +116,9 @@ type typedLoadStats struct { artifactDecode time.Duration artifactImportLink time.Duration artifactWrite time.Duration + artifactPrefetch time.Duration + rootLoad time.Duration + discovery time.Duration artifactHits int artifactMisses int artifactWrites int @@ -225,6 +238,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz meta map[string]*packageMeta err error ) + discoveryStart := time.Now() if req.Discovery != nil && len(req.Discovery.meta) > 0 { meta = req.Discovery.meta } else { @@ -239,6 +253,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz return nil, err } } + discoveryDuration := time.Since(discoveryStart) if len(meta) == 0 { return nil, unsupportedError{reason: "empty go list result"} } @@ -259,12 +274,19 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), loading: make(map[string]bool, len(meta)), isLocalCache: make(map[string]bool, len(meta)), - stats: typedLoadStats{}, - } + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } + prefetchStart := time.Now() + l.prefetchArtifacts() + l.stats.artifactPrefetch = time.Since(prefetchStart) + rootLoadStart := time.Now() root, err := l.loadPackage(req.Package) if err != nil { return nil, err } + l.stats.rootLoad = time.Since(rootLoadStart) logDuration(ctx, "loader.custom.lazy.read_files.cumulative", l.stats.read) logDuration(ctx, "loader.custom.lazy.parse_files.cumulative", l.stats.parse) logDuration(ctx, "loader.custom.lazy.typecheck.cumulative", l.stats.typecheck) @@ -279,6 +301,9 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz logDuration(ctx, "loader.custom.lazy.artifact_decode", l.stats.artifactDecode) logDuration(ctx, "loader.custom.lazy.artifact_import_link", l.stats.artifactImportLink) logDuration(ctx, "loader.custom.lazy.artifact_write", l.stats.artifactWrite) + logDuration(ctx, "loader.custom.lazy.artifact_prefetch.wall", l.stats.artifactPrefetch) + logDuration(ctx, "loader.custom.lazy.root_load.wall", l.stats.rootLoad) + logDuration(ctx, "loader.custom.lazy.discovery.wall", l.stats.discovery) logInt(ctx, "loader.custom.lazy.artifact_hits", l.stats.artifactHits) logInt(ctx, "loader.custom.lazy.artifact_misses", l.stats.artifactMisses) logInt(ctx, "loader.custom.lazy.artifact_writes", l.stats.artifactWrites) @@ -289,6 +314,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + discoveryStart := time.Now() meta, err := runGoList(ctx, goListRequest{ WD: req.WD, Env: req.Env, @@ -299,6 +325,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if err != nil { return nil, err } + discoveryDuration := time.Since(discoveryStart) if len(meta) == 0 { return nil, unsupportedError{reason: "empty go list result"} } @@ -329,8 +356,14 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), loading: make(map[string]bool, len(meta)), isLocalCache: make(map[string]bool, len(meta)), - stats: typedLoadStats{}, - } + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } + prefetchStart := time.Now() + l.prefetchArtifacts() + l.stats.artifactPrefetch = time.Since(prefetchStart) + rootLoadStart := time.Now() roots := make([]*packages.Package, 0, len(targets)) for _, m := range meta { if m.DepOnly { @@ -342,6 +375,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo } roots = append(roots, root) } + l.stats.rootLoad = time.Since(rootLoadStart) sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) logDuration(ctx, "loader.custom.typed.read_files.cumulative", l.stats.read) logDuration(ctx, "loader.custom.typed.parse_files.cumulative", l.stats.parse) @@ -357,6 +391,9 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo logDuration(ctx, "loader.custom.typed.artifact_decode", l.stats.artifactDecode) logDuration(ctx, "loader.custom.typed.artifact_import_link", l.stats.artifactImportLink) logDuration(ctx, "loader.custom.typed.artifact_write", l.stats.artifactWrite) + logDuration(ctx, "loader.custom.typed.artifact_prefetch.wall", l.stats.artifactPrefetch) + logDuration(ctx, "loader.custom.typed.root_load.wall", l.stats.rootLoad) + logDuration(ctx, "loader.custom.typed.discovery.wall", l.stats.discovery) logInt(ctx, "loader.custom.typed.artifact_hits", l.stats.artifactHits) logInt(ctx, "loader.custom.typed.artifact_misses", l.stats.artifactMisses) logInt(ctx, "loader.custom.typed.artifact_writes", l.stats.artifactWrites) @@ -527,7 +564,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } l.packages[path] = pkg } - useArtifact := loaderArtifactEnabled(l.env) && !isTarget && (!isLocal || l.useLocalSemanticArtifact(meta)) + useArtifact := l.shouldUseArtifact(path, meta, isTarget, isLocal) if useArtifact { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() @@ -641,10 +678,15 @@ func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) boo if meta == nil { return false } + if ok, exists := l.localSemanticOK[meta.ImportPath]; exists { + return ok + } art, err := semanticcache.Read(l.env, meta.ImportPath, meta.Name, metaFiles(meta)) if err != nil || art == nil { + l.localSemanticOK[meta.ImportPath] = false return false } + l.localSemanticOK[meta.ImportPath] = art.Supported return art.Supported } @@ -659,14 +701,26 @@ func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, isLocal bool) (*types.Package, bool) { start := time.Now() - pathStart := time.Now() - artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) - l.stats.artifactPath += time.Since(pathStart) - if err != nil { - debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) - l.stats.artifactRead += time.Since(start) - l.stats.artifactMisses++ - return nil, false + entry, prefetched := l.artifactPrefetch[path] + artifactPath := "" + if prefetched { + artifactPath = entry.path + if entry.err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=prefetch_error err=%v", path, isLocal, entry.err) + l.stats.artifactMisses++ + return nil, false + } + } else { + pathStart := time.Now() + var err error + artifactPath, err = loaderArtifactPath(l.env, meta, isLocal) + l.stats.artifactPath += time.Since(pathStart) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } } if isLocal { preloadStart := time.Now() @@ -690,7 +744,12 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is } var tpkg *types.Package decodeStart := time.Now() - tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + var err error + if prefetched { + tpkg, err = readLoaderArtifactData(entry.data, l.fset, l.typesPkgs, path) + } else { + tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + } l.stats.artifactDecode += time.Since(decodeStart) if err != nil { debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=decode_error err=%v", path, isLocal, err) @@ -698,12 +757,92 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is l.stats.artifactMisses++ return nil, false } - l.stats.artifactRead += time.Since(start) + if !prefetched { + l.stats.artifactRead += time.Since(start) + } l.stats.artifactHits++ l.typesPkgs[path] = tpkg return tpkg, true } +func (l *customTypedGraphLoader) shouldUseArtifact(path string, meta *packageMeta, isTarget, isLocal bool) bool { + if !loaderArtifactEnabled(l.env) || isTarget { + return false + } + if !isLocal { + return true + } + return l.useLocalSemanticArtifact(meta) +} + +func (l *customTypedGraphLoader) prefetchArtifacts() { + if !loaderArtifactEnabled(l.env) { + return + } + candidates := make([]string, 0, len(l.meta)) + for path, meta := range l.meta { + _, isTarget := l.targets[path] + isLocal := l.isLocalPackage(path, meta) + if l.shouldUseArtifact(path, meta, isTarget, isLocal) { + candidates = append(candidates, path) + } + } + sort.Strings(candidates) + if len(candidates) == 0 { + return + } + type result struct { + pkg string + entry artifactPrefetchEntry + dur time.Duration + } + jobs := make(chan string, len(candidates)) + results := make(chan result, len(candidates)) + workers := 8 + if len(candidates) < workers { + workers = len(candidates) + } + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for path := range jobs { + start := time.Now() + meta := l.meta[path] + isLocal := l.isLocalPackage(path, meta) + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + entry := artifactPrefetchEntry{path: artifactPath} + if err == nil { + data, readErr := os.ReadFile(artifactPath) + if readErr != nil { + entry.err = readErr + } else { + entry.data = data + entry.ok = true + } + } else { + entry.err = err + } + results <- result{pkg: path, entry: entry, dur: time.Since(start)} + } + }() + } + for _, path := range candidates { + jobs <- path + } + close(jobs) + wg.Wait() + close(results) + for res := range results { + l.artifactPrefetch[res.pkg] = res.entry + l.stats.artifactRead += res.dur + pathStart := time.Now() + _ = res.entry.path + l.stats.artifactPath += time.Since(pathStart) + } +} + func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Package, isLocal bool) error { start := time.Now() artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 9adbf4e..422a3a4 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -34,6 +34,9 @@ type goListRequest struct { } func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + if cached, ok := readDiscoveryCache(req); ok { + return cached, nil + } args := []string{"list", "-json", "-e", "-compiled", "-export"} if req.NeedDeps { args = append(args, "-deps") @@ -91,5 +94,6 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, copyMeta := meta out[meta.ImportPath] = ©Meta } + writeDiscoveryCache(req, out) return out, nil } diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go new file mode 100644 index 0000000..4ec9a12 --- /dev/null +++ b/internal/loader/discovery_cache.go @@ -0,0 +1,331 @@ +package loader + +import ( + "bytes" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "go/parser" + "go/token" + "os" + "path/filepath" + "runtime" + "sort" +) + +type discoveryCacheEntry struct { + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + Workspace string + Meta map[string]*packageMeta + Global []discoveryFileMeta + LocalPkgs []discoveryLocalPackage +} + +type discoveryLocalPackage struct { + ImportPath string + Dir string + DirMeta discoveryDirMeta + Files []discoveryFileFingerprint +} + +type discoveryFileMeta struct { + Path string + Size int64 + ModTime int64 + IsDir bool +} + +type discoveryDirMeta struct { + Path string + Entries []string +} + +type discoveryFileFingerprint struct { + Path string + Hash string +} + +func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { + entry, err := loadDiscoveryCacheEntry(req) + if err != nil || entry == nil { + return nil, false + } + if !validateDiscoveryCacheEntry(entry) { + return nil, false + } + return clonePackageMetaMap(entry.Meta), true +} + +func writeDiscoveryCache(req goListRequest, meta map[string]*packageMeta) { + entry, err := buildDiscoveryCacheEntry(req, meta) + if err != nil { + return + } + _ = saveDiscoveryCacheEntry(req, entry) +} + +func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { + workspace := detectModuleRoot(req.WD) + entry := &discoveryCacheEntry{ + Version: 2, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + Workspace: workspace, + Meta: clonePackageMetaMap(meta), + } + global := []string{ + filepath.Join(workspace, "go.mod"), + filepath.Join(workspace, "go.sum"), + filepath.Join(workspace, "go.work"), + filepath.Join(workspace, "go.work.sum"), + } + for _, name := range global { + if fm, ok := statDiscoveryFile(name); ok { + entry.Global = append(entry.Global, fm) + } + } + locals := make([]discoveryLocalPackage, 0) + for _, pkg := range meta { + if pkg == nil || !isWorkspacePackage(workspace, pkg.Dir) { + continue + } + lp := discoveryLocalPackage{ + ImportPath: pkg.ImportPath, + Dir: pkg.Dir, + } + if fm, ok := statDiscoveryDir(pkg.Dir); ok { + lp.DirMeta = fm + } + for _, name := range metaFiles(pkg) { + if fm, ok := fingerprintDiscoveryFile(name); ok { + lp.Files = append(lp.Files, fm) + } + } + sort.Slice(lp.Files, func(i, j int) bool { return lp.Files[i].Path < lp.Files[j].Path }) + locals = append(locals, lp) + } + sort.Slice(locals, func(i, j int) bool { return locals[i].ImportPath < locals[j].ImportPath }) + entry.LocalPkgs = locals + return entry, nil +} + +func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { + if entry == nil || entry.Version != 2 { + return false + } + for _, fm := range entry.Global { + if !matchesDiscoveryFile(fm) { + return false + } + } + for _, lp := range entry.LocalPkgs { + if !matchesDiscoveryDir(lp.DirMeta) { + return false + } + for _, fm := range lp.Files { + if !matchesDiscoveryFingerprint(fm) { + return false + } + } + } + return true +} + +func discoveryCachePath(req goListRequest) (string, error) { + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + sumReq := struct { + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + Go string + }{ + Version: 2, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + Go: runtime.Version(), + } + key, err := hashGob(sumReq) + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "discovery-cache", key+".gob"), nil +} + +func loadDiscoveryCacheEntry(req goListRequest) (*discoveryCacheEntry, error) { + path, err := discoveryCachePath(req) + if err != nil { + return nil, err + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var entry discoveryCacheEntry + if err := gob.NewDecoder(f).Decode(&entry); err != nil { + return nil, err + } + return &entry, nil +} + +func saveDiscoveryCacheEntry(req goListRequest, entry *discoveryCacheEntry) error { + path, err := discoveryCachePath(req) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(entry) +} + +func statDiscoveryFile(path string) (discoveryFileMeta, bool) { + info, err := os.Stat(path) + if err != nil { + return discoveryFileMeta{}, false + } + return discoveryFileMeta{ + Path: canonicalLoaderPath(path), + Size: info.Size(), + ModTime: info.ModTime().UnixNano(), + IsDir: info.IsDir(), + }, true +} + +func matchesDiscoveryFile(fm discoveryFileMeta) bool { + cur, ok := statDiscoveryFile(fm.Path) + if !ok { + return false + } + return cur.Size == fm.Size && cur.ModTime == fm.ModTime && cur.IsDir == fm.IsDir +} + +func statDiscoveryDir(path string) (discoveryDirMeta, bool) { + entries, err := os.ReadDir(path) + if err != nil { + return discoveryDirMeta{}, false + } + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name()) + } + sort.Strings(names) + return discoveryDirMeta{ + Path: canonicalLoaderPath(path), + Entries: names, + }, true +} + +func matchesDiscoveryDir(dm discoveryDirMeta) bool { + cur, ok := statDiscoveryDir(dm.Path) + if !ok { + return false + } + if len(cur.Entries) != len(dm.Entries) { + return false + } + for i := range cur.Entries { + if cur.Entries[i] != dm.Entries[i] { + return false + } + } + return true +} + +func fingerprintDiscoveryFile(path string) (discoveryFileFingerprint, bool) { + src, err := os.ReadFile(path) + if err != nil { + return discoveryFileFingerprint{}, false + } + sum := sha256.New() + sum.Write([]byte(filepath.Base(path))) + sum.Write([]byte{0}) + file, err := parser.ParseFile(token.NewFileSet(), path, src, parser.ImportsOnly|parser.ParseComments) + if err != nil { + sum.Write(src) + return discoveryFileFingerprint{ + Path: canonicalLoaderPath(path), + Hash: hex.EncodeToString(sum.Sum(nil)), + }, true + } + if offset := int(file.Package) - 1; offset > 0 && offset <= len(src) { + sum.Write(src[:offset]) + } + sum.Write([]byte(file.Name.Name)) + sum.Write([]byte{0}) + for _, imp := range file.Imports { + if imp.Name != nil { + sum.Write([]byte(imp.Name.Name)) + } + sum.Write([]byte{0}) + sum.Write([]byte(imp.Path.Value)) + sum.Write([]byte{0}) + } + return discoveryFileFingerprint{ + Path: canonicalLoaderPath(path), + Hash: hex.EncodeToString(sum.Sum(nil)), + }, true +} + +func matchesDiscoveryFingerprint(fp discoveryFileFingerprint) bool { + cur, ok := fingerprintDiscoveryFile(fp.Path) + if !ok { + return false + } + return cur.Hash == fp.Hash +} + +func hashGob(v interface{}) (string, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(v); err != nil { + return "", err + } + sum := sha256.Sum256(buf.Bytes()) + return hex.EncodeToString(sum[:]), nil +} + +func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { + if len(in) == 0 { + return nil + } + out := make(map[string]*packageMeta, len(in)) + for k, v := range in { + if v == nil { + continue + } + cp := *v + cp.GoFiles = append([]string(nil), v.GoFiles...) + cp.CompiledGoFiles = append([]string(nil), v.CompiledGoFiles...) + cp.Imports = append([]string(nil), v.Imports...) + if v.ImportMap != nil { + cp.ImportMap = make(map[string]string, len(v.ImportMap)) + for mk, mv := range v.ImportMap { + cp.ImportMap[mk] = mv + } + } + if v.Error != nil { + errCopy := *v.Error + cp.Error = &errCopy + } + out[k] = &cp + } + return out +} diff --git a/internal/loader/discovery_cache_test.go b/internal/loader/discovery_cache_test.go new file mode 100644 index 0000000..953824f --- /dev/null +++ b/internal/loader/discovery_cache_test.go @@ -0,0 +1,126 @@ +package loader + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiscoveryFingerprintIgnoresBodyOnlyEdits(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `package example + +import "fmt" + +func Provide() string { + return fmt.Sprint("before") +} +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `package example + +import "fmt" + +func Provide() string { + return fmt.Sprint("after") +} +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after body edit", path) + } + if fpBefore.Hash != fpAfter.Hash { + t.Fatalf("body-only edit changed fingerprint: %s != %s", fpBefore.Hash, fpAfter.Hash) + } +} + +func TestDiscoveryFingerprintDetectsImportChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `package example + +import "strings" +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after import edit", path) + } + if fpBefore.Hash == fpAfter.Hash { + t.Fatalf("import edit did not change fingerprint") + } +} + +func TestDiscoveryFingerprintDetectsHeaderBuildTagChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `//go:build linux + +package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `//go:build darwin + +package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after header edit", path) + } + if fpBefore.Hash == fpAfter.Hash { + t.Fatalf("build tag edit did not change fingerprint") + } +} + +func TestDiscoveryDirDetectsFileSetChange(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "a.go"), []byte("package example\n"), 0o644); err != nil { + t.Fatal(err) + } + before, ok := statDiscoveryDir(dir) + if !ok { + t.Fatalf("statDiscoveryDir(%q) failed", dir) + } + if err := os.WriteFile(filepath.Join(dir, "b.go"), []byte("package example\n"), 0o644); err != nil { + t.Fatal(err) + } + if matchesDiscoveryDir(before) { + t.Fatalf("directory metadata did not detect added file") + } +} diff --git a/internal/wire/profile_bench_test.go b/internal/wire/profile_bench_test.go new file mode 100644 index 0000000..31dc7b7 --- /dev/null +++ b/internal/wire/profile_bench_test.go @@ -0,0 +1,32 @@ +package wire + +import ( + "context" + "os" + "testing" +) + +func BenchmarkGenerateRealAppWarmArtifacts(b *testing.B) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + b.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := b.TempDir() + env := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + ctx := context.Background() + + // Warm the artifact cache once before measurement. + if _, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("warm Generate errors: %v", errs) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("Generate errors: %v", errs) + } + } +} From 26f591bef82bc19092839dfae14971dc89646da8 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 06:54:13 -0500 Subject: [PATCH 13/79] feat(loader): cache unchanged root output --- internal/loader/custom.go | 1 + internal/wire/output_cache.go | 273 ++++++++++++++++++++++++++++++++++ internal/wire/wire.go | 5 + 3 files changed, 279 insertions(+) create mode 100644 internal/wire/output_cache.go diff --git a/internal/loader/custom.go b/internal/loader/custom.go index c938acc..dd2b9e0 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -184,6 +184,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes PkgPath: m.ImportPath, GoFiles: append([]string(nil), metaFiles(m)...), CompiledGoFiles: append([]string(nil), metaFiles(m)...), + ExportFile: m.Export, Imports: make(map[string]*packages.Package), } if m.Error != nil && strings.TrimSpace(m.Error.Err) != "" { diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go new file mode 100644 index 0000000..7d384fb --- /dev/null +++ b/internal/wire/output_cache.go @@ -0,0 +1,273 @@ +package wire + +import ( + "context" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + + "golang.org/x/tools/go/packages" + + "github.com/goforj/wire/internal/loader" +) + +const outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + +type outputCacheEntry struct { + Version int + Content []byte +} + +type outputCacheCandidate struct { + path string + outputPath string +} + +func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, bool) { + if !outputCacheEnabled(ctx, wd, env) { + debugf(ctx, "generate.output_cache=disabled") + return nil, nil, false + } + rootResult, err := loader.New().LoadRootGraph(withLoaderTiming(ctx), loader.RootLoadRequest{ + WD: wd, + Env: env, + Tags: opts.Tags, + Patterns: append([]string(nil), patterns...), + NeedDeps: true, + Mode: effectiveLoaderMode(ctx, wd, env), + }) + if err != nil || rootResult == nil || len(rootResult.Packages) == 0 { + if err != nil { + debugf(ctx, "generate.output_cache=load_root_error") + } else { + debugf(ctx, "generate.output_cache=no_roots") + } + return nil, nil, false + } + candidates := make(map[string]outputCacheCandidate, len(rootResult.Packages)) + results := make([]GenerateResult, 0, len(rootResult.Packages)) + for _, pkg := range rootResult.Packages { + outDir, err := detectOutputDir(pkg.GoFiles) + if err != nil { + debugf(ctx, "generate.output_cache=bad_output_dir") + return candidates, nil, false + } + key, err := outputCacheKey(wd, opts, pkg) + if err != nil { + debugf(ctx, "generate.output_cache=key_error") + return candidates, nil, false + } + path, err := outputCachePath(env, key) + if err != nil { + debugf(ctx, "generate.output_cache=path_error") + return candidates, nil, false + } + candidates[pkg.PkgPath] = outputCacheCandidate{ + path: path, + outputPath: filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go"), + } + entry, ok := readOutputCache(path) + if !ok { + debugf(ctx, "generate.output_cache=miss") + return candidates, nil, false + } + results = append(results, GenerateResult{ + PkgPath: pkg.PkgPath, + OutputPath: filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go"), + Content: entry.Content, + }) + } + debugf(ctx, "generate.output_cache=hit") + return candidates, results, len(results) == len(rootResult.Packages) +} + +func writeGenerateOutputCache(candidates map[string]outputCacheCandidate, generated []GenerateResult) { + for _, gen := range generated { + candidate, ok := candidates[gen.PkgPath] + if !ok || candidate.path == "" || len(gen.Errs) > 0 || len(gen.Content) == 0 { + continue + } + _ = writeOutputCache(candidate.path, &outputCacheEntry{ + Version: 1, + Content: append([]byte(nil), gen.Content...), + }) + } +} + +func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { + if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { + return false + } + return envValue(env, "WIRE_LOADER_ARTIFACTS") == "1" +} + +func outputCachePath(env []string, key string) (string, error) { + dir, err := outputCacheDir(env) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".gob"), nil +} + +func outputCacheDir(env []string) (string, error) { + if dir := envValue(env, outputCacheDirEnv); dir != "" { + return dir, nil + } + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "output-cache"), nil +} + +func readOutputCache(path string) (*outputCacheEntry, bool) { + f, err := os.Open(path) + if err != nil { + return nil, false + } + defer f.Close() + var entry outputCacheEntry + if err := gob.NewDecoder(f).Decode(&entry); err != nil { + return nil, false + } + if entry.Version != 1 { + return nil, false + } + return &entry, true +} + +func writeOutputCache(path string, entry *outputCacheEntry) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(entry) +} + +func outputCacheKey(wd string, opts *GenerateOptions, root *packages.Package) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-output-cache-v1\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(canonicalWirePath(wd))) + sum.Write([]byte{'\n'}) + sum.Write(opts.Header) + sum.Write([]byte{'\n'}) + sum.Write([]byte(opts.Tags)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(root.PkgPath)) + sum.Write([]byte{'\n'}) + workspace := detectWireModuleRoot(wd) + pkgs := reachablePackages(root) + for _, pkg := range pkgs { + sum.Write([]byte(pkg.PkgPath)) + sum.Write([]byte{'\n'}) + if isLocalWirePackage(workspace, pkg) { + files := append([]string(nil), pkg.GoFiles...) + sort.Strings(files) + for _, name := range files { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + if pkg.PkgPath == root.PkgPath { + src, err := os.ReadFile(name) + if err != nil { + return "", err + } + sum.Write(src) + sum.Write([]byte{'\n'}) + } + } + continue + } + sum.Write([]byte(pkg.ExportFile)) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func reachablePackages(root *packages.Package) []*packages.Package { + seen := map[string]bool{} + var out []*packages.Package + var walk func(*packages.Package) + walk = func(pkg *packages.Package) { + if pkg == nil || seen[pkg.PkgPath] { + return + } + seen[pkg.PkgPath] = true + out = append(out, pkg) + paths := make([]string, 0, len(pkg.Imports)) + for path := range pkg.Imports { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + walk(pkg.Imports[path]) + } + } + walk(root) + sort.Slice(out, func(i, j int) bool { return out[i].PkgPath < out[j].PkgPath }) + return out +} + +func isLocalWirePackage(workspace string, pkg *packages.Package) bool { + if pkg == nil || len(pkg.GoFiles) == 0 { + return false + } + dir := filepath.Dir(pkg.GoFiles[0]) + dir = canonicalWirePath(dir) + workspace = canonicalWirePath(workspace) + if dir == workspace { + return true + } + return len(dir) > len(workspace) && dir[:len(workspace)] == workspace && dir[len(workspace)] == filepath.Separator +} + +func detectWireModuleRoot(start string) string { + start = canonicalWirePath(start) + for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { + if info, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !info.IsDir() { + return dir + } + next := filepath.Dir(dir) + if next == dir { + break + } + } + return start +} + +func canonicalWirePath(path string) string { + path = filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return path +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key { + return value + } + } + return "" +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 99062ac..3d787f3 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -102,6 +102,10 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } + cacheCandidates, cached, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) + if ok { + return cached, nil + } loadStart := time.Now() pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) logTiming(ctx, "generate.load", loadStart) @@ -149,6 +153,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o generated[i].Content = goSrc logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) } + writeGenerateOutputCache(cacheCandidates, generated) return generated, nil } From a357a384f501cc87d0687b497e07b67e947744ee Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 08:18:47 -0500 Subject: [PATCH 14/79] chore: bench tweaks --- internal/loader/artifact_cache.go | 2 +- internal/loader/discovery.go | 35 +++- internal/wire/import_bench_test.go | 311 +++++++++++++++++++++++++++++ internal/wire/output_cache.go | 2 +- scripts/import-benchmarks.sh | 33 +++ 5 files changed, 379 insertions(+), 4 deletions(-) create mode 100644 internal/wire/import_bench_test.go create mode 100755 scripts/import-benchmarks.sh diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index d495143..42293cb 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -34,7 +34,7 @@ const ( ) func loaderArtifactEnabled(env []string) bool { - return envValue(env, loaderArtifactEnv) == "1" + return envValue(env, loaderArtifactEnv) != "0" } func loaderArtifactDir(env []string) (string, error) { diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 422a3a4..22b34d3 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -23,6 +23,7 @@ import ( "os" "os/exec" "path/filepath" + "time" ) type goListRequest struct { @@ -34,10 +35,18 @@ type goListRequest struct { } func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + cacheReadStart := time.Now() if cached, ok := readDiscoveryCache(req); ok { + logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) + logDuration(ctx, "loader.discovery.golist.wall", 0) + logDuration(ctx, "loader.discovery.decode.wall", 0) + logDuration(ctx, "loader.discovery.canonicalize.wall", 0) + logDuration(ctx, "loader.discovery.cache_build.wall", 0) + logDuration(ctx, "loader.discovery.cache_write.wall", 0) return cached, nil } - args := []string{"list", "-json", "-e", "-compiled", "-export"} + logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) + args := []string{"list", "-json", "-e", "-compiled"} if req.NeedDeps { args = append(args, "-deps") } @@ -60,22 +69,30 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, var stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr + goListStart := time.Now() if err := cmd.Run(); err != nil { return nil, fmt.Errorf("go list: %w: %s", err, stderr.String()) } + goListDuration := time.Since(goListStart) dec := json.NewDecoder(&stdout) out := make(map[string]*packageMeta) + var decodeDuration time.Duration + var canonicalizeDuration time.Duration for { var meta packageMeta + decodeStart := time.Now() if err := dec.Decode(&meta); err != nil { + decodeDuration += time.Since(decodeStart) if err == io.EOF { break } return nil, err } + decodeDuration += time.Since(decodeStart) if meta.ImportPath == "" { continue } + canonicalizeStart := time.Now() meta.Dir = canonicalLoaderPath(meta.Dir) for i, name := range meta.GoFiles { if !filepath.IsAbs(name) { @@ -91,9 +108,23 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, meta.Export = filepath.Join(meta.Dir, meta.Export) } meta.Imports = normalizeImports(meta.Imports, meta.ImportMap) + canonicalizeDuration += time.Since(canonicalizeStart) copyMeta := meta out[meta.ImportPath] = ©Meta } - writeDiscoveryCache(req, out) + cacheBuildStart := time.Now() + entry, err := buildDiscoveryCacheEntry(req, out) + cacheBuildDuration := time.Since(cacheBuildStart) + if err == nil && entry != nil { + cacheWriteStart := time.Now() + _ = saveDiscoveryCacheEntry(req, entry) + logDuration(ctx, "loader.discovery.cache_write.wall", time.Since(cacheWriteStart)) + } else { + logDuration(ctx, "loader.discovery.cache_write.wall", 0) + } + logDuration(ctx, "loader.discovery.golist.wall", goListDuration) + logDuration(ctx, "loader.discovery.decode.wall", decodeDuration) + logDuration(ctx, "loader.discovery.canonicalize.wall", canonicalizeDuration) + logDuration(ctx, "loader.discovery.cache_build.wall", cacheBuildDuration) return out, nil } diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go new file mode 100644 index 0000000..e134bc2 --- /dev/null +++ b/internal/wire/import_bench_test.go @@ -0,0 +1,311 @@ +package wire + +import ( + "archive/tar" + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +const ( + importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" + stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" + stockWireModulePath = "github.com/google/wire" + currentWireModulePath = "github.com/goforj/wire" +) + +type importBenchRow struct { + imports int + stockCold time.Duration + currentCold time.Duration + currentWarm time.Duration +} + +const importBenchTrials = 3 + +func TestPrintImportScaleBenchmarkTable(t *testing.T) { + if os.Getenv(importBenchEnv) != "1" { + t.Skipf("%s not set", importBenchEnv) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + sizes := []int{10, 100, 1000} + rows := make([]importBenchRow, 0, len(sizes)) + for _, n := range sizes { + stockFixture := createImportBenchFixture(t, n, stockWireModulePath, stockDir) + currentFixture := createImportBenchFixture(t, n, currentWireModulePath, repoRoot) + rows = append(rows, importBenchRow{ + imports: n, + stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)), + currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)), + currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)), + }) + } + printImportBenchTable(t, rows) +} + +func buildWireBinary(t *testing.T, dir, name string) string { + t.Helper() + out := filepath.Join(t.TempDir(), name) + cmd := exec.Command("go", "build", "-o", out, "./cmd/wire") + cmd.Dir = dir + cmd.Env = benchEnv(t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("build wire binary in %s: %v\n%s", dir, err, output) + } + return out +} + +func extractStockWire(t *testing.T, repoRoot, commit string) string { + t.Helper() + tmp := t.TempDir() + cmd := exec.Command("git", "archive", "--format=tar", commit) + cmd.Dir = repoRoot + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatalf("git archive start: %v", err) + } + tr := tar.NewReader(stdout) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("read stock tar: %v", err) + } + target := filepath.Join(tmp, hdr.Name) + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, os.FileMode(hdr.Mode)); err != nil { + t.Fatalf("mkdir %s: %v", target, err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + t.Fatalf("mkdir parent %s: %v", target, err) + } + f, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) + if err != nil { + t.Fatalf("create %s: %v", target, err) + } + if _, err := io.Copy(f, tr); err != nil { + _ = f.Close() + t.Fatalf("write %s: %v", target, err) + } + if err := f.Close(); err != nil { + t.Fatalf("close %s: %v", target, err) + } + } + } + if err := cmd.Wait(); err != nil { + t.Fatalf("git archive wait: %v", err) + } + return tmp +} + +func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireReplaceDir string) string { + t.Helper() + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(importBenchGoMod(wireModulePath, wireReplaceDir)), 0o644); err != nil { + t.Fatal(err) + } + for i := 0; i < imports; i++ { + dir := filepath.Join(root, fmt.Sprintf("dep%04d", i)) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + src := fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) + if err := os.WriteFile(filepath.Join(dir, "dep.go"), []byte(src), 0o644); err != nil { + t.Fatal(err) + } + } + if err := os.MkdirAll(filepath.Join(root, "app"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "app", "wire.go"), []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { + t.Fatal(err) + } + return filepath.Join(root, "app") +} + +func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.Duration { + t.Helper() + cmd := exec.Command(bin, "gen") + cmd.Dir = pkgDir + cmd.Env = append(benchEnv(home, goCache), "WIRE_LOADER_ARTIFACTS=1") + var stderr bytes.Buffer + cmd.Stdout = io.Discard + cmd.Stderr = &stderr + start := time.Now() + if err := cmd.Run(); err != nil { + t.Fatalf("run %s in %s: %v\n%s", bin, pkgDir, err, stderr.String()) + } + return time.Since(start) +} + +func runColdTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + for i := 0; i < trials; i++ { + home := t.TempDir() + goCache := filepath.Join(t.TempDir(), "gocache") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + } + return durations +} + +func runWarmTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + for i := 0; i < trials; i++ { + home := t.TempDir() + goCache := filepath.Join(t.TempDir(), "gocache") + _ = runWireBenchCommand(t, bin, pkgDir, home, goCache) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + } + return durations +} + +func medianDuration(durations []time.Duration) time.Duration { + if len(durations) == 0 { + return 0 + } + sorted := append([]time.Duration(nil), durations...) + for i := 1; i < len(sorted); i++ { + for j := i; j > 0 && sorted[j] < sorted[j-1]; j-- { + sorted[j], sorted[j-1] = sorted[j-1], sorted[j] + } + } + return sorted[len(sorted)/2] +} + +func benchEnv(home, goCache string) []string { + env := append([]string(nil), os.Environ()...) + env = append(env, + "HOME="+home, + "GOCACHE="+goCache, + "GOMODCACHE=/tmp/gomodcache", + ) + return env +} + +func importBenchGoMod(wireModulePath, wireReplaceDir string) string { + return fmt.Sprintf(`module example.com/importbench + +go 1.26 + +require %s v0.0.0 + +replace %s => %s +`, wireModulePath, wireModulePath, wireReplaceDir) +} + +func importBenchWireFile(imports int, wireModulePath string) string { + var b strings.Builder + b.WriteString("//go:build wireinject\n\n") + b.WriteString("package app\n\n") + b.WriteString("import (\n") + b.WriteString(fmt.Sprintf("\twire %q\n", wireModulePath)) + for i := 0; i < imports; i++ { + b.WriteString(fmt.Sprintf("\t%[1]q\n", fmt.Sprintf("example.com/importbench/dep%04d", i))) + } + b.WriteString(")\n\n") + b.WriteString("type App struct{}\n\n") + b.WriteString("func provideApp(") + for i := 0; i < imports; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("d%d *dep%04d.T", i, i)) + } + b.WriteString(") *App {\n\treturn &App{}\n}\n\n") + b.WriteString("func Initialize() *App {\n\twire.Build(wire.NewSet(\n") + for i := 0; i < imports; i++ { + b.WriteString(fmt.Sprintf("\t\tdep%04d.Provide,\n", i)) + } + b.WriteString("\t\tprovideApp,\n\t))\n\treturn nil\n}\n") + return b.String() +} + +func printImportBenchTable(t *testing.T, rows []importBenchRow) { + t.Helper() + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") + fmt.Println("| repo size | stock | current cold | current unchanged | cold speedup | unchanged speedup |") + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") + for _, row := range rows { + fmt.Printf("| %-9d | %-9s | %-12s | %-17s | %-12s | %-17s |\n", + row.imports, + formatMs(row.stockCold), + formatMs(row.currentCold), + formatMs(row.currentWarm), + formatSpeedup(row.stockCold, row.currentCold), + formatSpeedup(row.stockCold, row.currentWarm), + ) + } + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") +} + +func formatMs(d time.Duration) string { + return fmt.Sprintf("%.1fms", float64(d)/float64(time.Millisecond)) +} + +func formatSpeedup(oldDur, newDur time.Duration) string { + if newDur == 0 { + return "inf" + } + return fmt.Sprintf("%.2fx", float64(oldDur)/float64(newDur)) +} + +func TestImportBenchFixtureGenerates(t *testing.T) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + bin := buildWireBinary(t, repoRoot, "fixture-wire") + fixture := createImportBenchFixture(t, 10, currentWireModulePath, repoRoot) + _ = runWireBenchCommand(t, bin, fixture, t.TempDir(), filepath.Join(t.TempDir(), "gocache")) +} + +func TestImportBenchUsesStockArchive(t *testing.T) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + if _, err := os.Stat(filepath.Join(stockDir, "cmd", "wire", "main.go")); err != nil { + t.Fatalf("stock archive missing cmd/wire: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, "go", "list", "./cmd/wire") + cmd.Dir = stockDir + cmd.Env = benchEnv(t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("stock archive not buildable: %v\n%s", err, out) + } +} + +func importBenchRepoRoot() (string, error) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + return filepath.Clean(filepath.Join(wd, "..", "..")), nil +} diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index 7d384fb..35eacfe 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -104,7 +104,7 @@ func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { return false } - return envValue(env, "WIRE_LOADER_ARTIFACTS") == "1" + return envValue(env, "WIRE_LOADER_ARTIFACTS") != "0" } func outputCachePath(env []string, key string) (string, error) { diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh new file mode 100755 index 0000000..f1f4e58 --- /dev/null +++ b/scripts/import-benchmarks.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + +usage() { + cat <<'EOF' +Usage: + scripts/import-benchmarks.sh table + +Commands: + table Print the 10/100/1000 import stock-vs-current benchmark table. +EOF +} + +case "${1:-}" in + table) + WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "Unknown command: ${1}" >&2 + usage >&2 + exit 1 + ;; +esac From d6b36b16bcb8465a12602def5c9691dd1a4eb33c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 15:52:16 -0500 Subject: [PATCH 15/79] chore: re-implement cache --- cmd/wire/cache_cmd.go | 181 +++++++++++++++++++++++++++++ cmd/wire/cache_cmd_test.go | 96 +++++++++++++++ cmd/wire/main.go | 2 + internal/wire/import_bench_test.go | 64 +++++++++- scripts/import-benchmarks.sh | 7 +- 5 files changed, 347 insertions(+), 3 deletions(-) create mode 100644 cmd/wire/cache_cmd.go create mode 100644 cmd/wire/cache_cmd_test.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go new file mode 100644 index 0000000..cdbbd40 --- /dev/null +++ b/cmd/wire/cache_cmd.go @@ -0,0 +1,181 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/google/subcommands" +) + +const ( + loaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" + outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + semanticCacheDirEnv = "WIRE_SEMANTIC_CACHE_DIR" +) + +var osUserCacheDir = os.UserCacheDir + +type cacheCmd struct { + clear bool +} + +type cacheTarget struct { + name string + path string +} + +func (*cacheCmd) Name() string { return "cache" } + +func (*cacheCmd) Synopsis() string { + return "inspect or clear the wire cache" +} + +func (*cacheCmd) Usage() string { + return `cache +cache clear +cache -clear + + By default, prints the cache directory. With -clear or clear, removes all + Wire-managed cache files. +` +} + +func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { + f.BoolVar(&cmd.clear, "clear", false, "clear Wire caches") +} + +func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + _ = ctx + clearRequested := cmd.clear + switch extra := f.Args(); len(extra) { + case 0: + if !clearRequested { + root, err := wireCacheRoot(os.Environ()) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + fmt.Fprintln(os.Stdout, root) + return subcommands.ExitSuccess + } + case 1: + if extra[0] == "clear" { + clearRequested = true + break + } + log.Printf("unknown cache action %q", extra[0]) + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + default: + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + } + if !clearRequested { + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + } + cleared, err := clearWireCaches(os.Environ()) + if err != nil { + log.Printf("failed to clear cache: %v\n", err) + return subcommands.ExitFailure + } + root, err := wireCacheRoot(os.Environ()) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + if len(cleared) == 0 { + log.Printf("cleared cache at %s\n", root) + return subcommands.ExitSuccess + } + log.Printf("cleared cache at %s\n", root) + return subcommands.ExitSuccess +} + +func wireCacheRoot(env []string) (string, error) { + base, err := osUserCacheDir() + if err != nil { + return "", fmt.Errorf("resolve user cache dir: %w", err) + } + return filepath.Join(base, "wire"), nil +} + +func clearWireCaches(env []string) ([]string, error) { + base, err := wireCacheRoot(env) + if err != nil { + return nil, err + } + targets := wireCacheTargets(env, filepath.Dir(base)) + cleared := make([]string, 0, len(targets)) + for _, target := range targets { + info, err := os.Stat(target.path) + if os.IsNotExist(err) { + continue + } + if err != nil { + return cleared, fmt.Errorf("stat %s cache: %w", target.name, err) + } + if !info.IsDir() { + if err := os.Remove(target.path); err != nil { + return cleared, fmt.Errorf("remove %s cache: %w", target.name, err) + } + } else if err := os.RemoveAll(target.path); err != nil { + return cleared, fmt.Errorf("remove %s cache: %w", target.name, err) + } + cleared = append(cleared, target.name) + } + return cleared, nil +} + +func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { + baseWire := filepath.Join(userCacheDir, "wire") + targets := []cacheTarget{ + {name: "loader-artifacts", path: envValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, + {name: "discovery-cache", path: filepath.Join(baseWire, "discovery-cache")}, + {name: "semantic-artifacts", path: envValueDefault(env, semanticCacheDirEnv, filepath.Join(baseWire, "semantic-artifacts"))}, + {name: "output-cache", path: envValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, + } + seen := make(map[string]bool, len(targets)) + deduped := make([]cacheTarget, 0, len(targets)) + for _, target := range targets { + cleaned := filepath.Clean(target.path) + if seen[cleaned] { + continue + } + seen[cleaned] = true + target.path = cleaned + deduped = append(deduped, target) + } + sort.Slice(deduped, func(i, j int) bool { return deduped[i].name < deduped[j].name }) + return deduped +} + +func envValueDefault(env []string, key, fallback string) string { + for i := len(env) - 1; i >= 0; i-- { + parts := strings.SplitN(env[i], "=", 2) + if len(parts) == 2 && parts[0] == key && parts[1] != "" { + return parts[1] + } + } + return fallback +} diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go new file mode 100644 index 0000000..c0c74c1 --- /dev/null +++ b/cmd/wire/cache_cmd_test.go @@ -0,0 +1,96 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestWireCacheTargetsDefault(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + got := wireCacheTargets(nil, base) + want := map[string]string{ + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), + "semantic-artifacts": filepath.Join(base, "wire", "semantic-artifacts"), + } + if len(got) != len(want) { + t.Fatalf("targets len = %d, want %d", len(got), len(want)) + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestWireCacheRoot(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + old := osUserCacheDir + osUserCacheDir = func() (string, error) { return base, nil } + defer func() { osUserCacheDir = old }() + + got, err := wireCacheRoot(nil) + if err != nil { + t.Fatalf("wireCacheRoot() error = %v", err) + } + want := filepath.Join(base, "wire") + if got != want { + t.Fatalf("wireCacheRoot() = %q, want %q", got, want) + } +} + +func TestWireCacheTargetsRespectOverrides(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + env := []string{ + loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), + outputCacheDirEnv + "=" + filepath.Join(base, "output"), + semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), + } + got := wireCacheTargets(env, base) + want := map[string]string{ + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), + "semantic-artifacts": filepath.Join(base, "semantic"), + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestClearWireCachesRemovesTargets(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + env := []string{ + loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), + outputCacheDirEnv + "=" + filepath.Join(base, "output"), + semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), + } + for _, target := range wireCacheTargets(env, base) { + if err := os.MkdirAll(target.path, 0o755); err != nil { + t.Fatalf("MkdirAll(%q): %v", target.path, err) + } + if err := os.WriteFile(filepath.Join(target.path, "marker"), []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile(%q): %v", target.path, err) + } + } + old := osUserCacheDir + osUserCacheDir = func() (string, error) { return base, nil } + defer func() { osUserCacheDir = old }() + + cleared, err := clearWireCaches(env) + if err != nil { + t.Fatalf("clearWireCaches() error = %v", err) + } + if len(cleared) != 4 { + t.Fatalf("cleared len = %d, want 4", len(cleared)) + } + for _, target := range wireCacheTargets(env, base) { + if _, err := os.Stat(target.path); !os.IsNotExist(err) { + t.Fatalf("%s still exists after clear, stat err = %v", target.path, err) + } + } +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index f7fd92f..515673c 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -49,6 +49,7 @@ func main() { subcommands.Register(subcommands.CommandsCommand(), "") subcommands.Register(subcommands.FlagsCommand(), "") subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(&cacheCmd{}, "") subcommands.Register(&checkCmd{}, "") subcommands.Register(&diffCmd{}, "") subcommands.Register(&genCmd{}, "") @@ -69,6 +70,7 @@ func main() { "commands": true, // builtin "help": true, // builtin "flags": true, // builtin + "cache": true, "check": true, "diff": true, "gen": true, diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index e134bc2..770e938 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -16,6 +16,7 @@ import ( const ( importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" + importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -57,6 +58,57 @@ func TestPrintImportScaleBenchmarkTable(t *testing.T) { printImportBenchTable(t, rows) } +func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { + if os.Getenv(importBenchBreakdown) != "1" { + t.Skipf("%s not set", importBenchBreakdown) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + const imports = 1000 + stockFixture := createImportBenchFixture(t, imports, stockWireModulePath, stockDir) + currentFixture := createImportBenchFixture(t, imports, currentWireModulePath, repoRoot) + + stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)) + currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)) + currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)) + + fmt.Printf("repo size: %d\n", imports) + fmt.Printf("stock cold: %s\n", formatMs(stockCold)) + fmt.Printf("current cold: %s\n", formatMs(currentCold)) + fmt.Printf("current unchanged: %s\n", formatMs(currentWarm)) + fmt.Printf("cold speedup: %s\n", formatSpeedup(stockCold, currentCold)) + fmt.Printf("unchanged speedup: %s\n", formatSpeedup(stockCold, currentWarm)) + fmt.Printf("cold gap: %s\n", formatMs(currentCold-stockCold)) + + home := t.TempDir() + goCache := filepath.Join(t.TempDir(), "gocache") + _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, home, goCache, "-timings") + fmt.Println("current cold timings:") + for _, line := range strings.Split(output, "\n") { + if !strings.Contains(line, "wire: timing:") { + continue + } + if strings.Contains(line, "loader.custom.root.discovery=") || + strings.Contains(line, "loader.discovery.") || + strings.Contains(line, "load.packages.load=") || + strings.Contains(line, "loader.custom.typed.artifact_write=") || + strings.Contains(line, "loader.custom.typed.root_load.wall=") || + strings.Contains(line, "loader.custom.typed.discovery.wall=") || + strings.Contains(line, "loader.custom.typed.artifact_writes=") || + strings.Contains(line, "generate.package.") || + strings.Contains(line, "wire.Generate=") || + strings.Contains(line, "total=") { + fmt.Println(line) + } + } +} + func buildWireBinary(t *testing.T, dir, name string) string { t.Helper() out := filepath.Join(t.TempDir(), name) @@ -147,7 +199,15 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.Duration { t.Helper() - cmd := exec.Command(bin, "gen") + d, _ := runWireBenchCommandOutput(t, bin, pkgDir, home, goCache) + return d +} + +func runWireBenchCommandOutput(t *testing.T, bin, pkgDir, home, goCache string, extraArgs ...string) (time.Duration, string) { + t.Helper() + args := []string{"gen"} + args = append(args, extraArgs...) + cmd := exec.Command(bin, args...) cmd.Dir = pkgDir cmd.Env = append(benchEnv(home, goCache), "WIRE_LOADER_ARTIFACTS=1") var stderr bytes.Buffer @@ -157,7 +217,7 @@ func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.D if err := cmd.Run(); err != nil { t.Fatalf("run %s in %s: %v\n%s", bin, pkgDir, err, stderr.String()) } - return time.Since(start) + return time.Since(start), stderr.String() } func runColdTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index f1f4e58..e1c1f97 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -12,9 +12,11 @@ usage() { cat <<'EOF' Usage: scripts/import-benchmarks.sh table + scripts/import-benchmarks.sh breakdown Commands: - table Print the 10/100/1000 import stock-vs-current benchmark table. + table Print the 10/100/1000 import stock-vs-current benchmark table. + breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } @@ -22,6 +24,9 @@ case "${1:-}" in table) WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v ;; + breakdown) + WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v + ;; ""|-h|--help|help) usage ;; From 84087dd9dd619fb4d4f4d0316244a379e2e32822 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 17:59:33 -0500 Subject: [PATCH 16/79] fix: provider discovery --- internal/wire/parse.go | 13 ++++++++++++- internal/wire/parse_coverage_test.go | 21 +++++++-------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index a825b4b..08a3c8a 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -574,6 +574,11 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, if !ok { return nil, false, nil } + for _, item := range setArt.Items { + if item.Kind == "bind" { + return nil, false, nil + } + } pset := &ProviderSet{ Pos: obj.Pos(), PkgPath: obj.Pkg().Path(), @@ -1856,5 +1861,11 @@ func bindShouldUsePointer(info *types.Info, call *ast.CallExpr) bool { fun := call.Fun.(*ast.SelectorExpr) // wire.Bind pkgName := fun.X.(*ast.Ident) // wire wireName := info.ObjectOf(pkgName).(*types.PkgName) // wire package - return wireName.Imported().Scope().Lookup("bindToUsePointer") != nil + if imported := wireName.Imported(); imported != nil { + if isWireImport(imported.Path()) { + return true + } + return imported.Scope().Lookup("bindToUsePointer") != nil + } + return false } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 3a23d18..c3c4d8e 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -367,7 +367,7 @@ func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { } } -func TestObjectCacheSemanticProviderSetFallbackTypeOnlyForms(t *testing.T) { +func TestObjectCacheSemanticProviderSetSkipsBindArtifacts(t *testing.T) { t.Parallel() fset := token.NewFileSet() @@ -429,22 +429,15 @@ func TestObjectCacheSemanticProviderSetFallbackTypeOnlyForms(t *testing.T) { }, hasher: typeutil.MakeHasher(), } - item, errs := oc.get(setVar) + pset, ok, errs := oc.semanticProviderSet(setVar) if len(errs) > 0 { - t.Fatalf("oc.get(Set) errs = %v", errs) - } - pset, ok := item.(*ProviderSet) - if !ok || pset == nil { - t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) - } - if len(pset.Bindings) != 1 { - t.Fatalf("bindings len = %d, want 1", len(pset.Bindings)) + t.Fatalf("semanticProviderSet(Set) errs = %v", errs) } - if len(pset.Providers) != 1 || !pset.Providers[0].IsStruct { - t.Fatalf("providers = %+v, want one struct provider", pset.Providers) + if ok { + t.Fatalf("semanticProviderSet(Set) ok = true, want false") } - if len(pset.Fields) != 1 || pset.Fields[0].Name != "Message" { - t.Fatalf("fields = %+v, want Message field", pset.Fields) + if pset != nil { + t.Fatalf("semanticProviderSet(Set) = %#v, want nil", pset) } } From e968dccc78c647fc54f3b47a3140cb545fbf91d5 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:01:43 -0500 Subject: [PATCH 17/79] chore: benchmark update --- internal/wire/import_bench_test.go | 1086 ++++++++++++++++++++++++++-- scripts/import-benchmarks.sh | 5 + 2 files changed, 1050 insertions(+), 41 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index 770e938..aff71e2 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -4,6 +4,7 @@ import ( "archive/tar" "bytes" "context" + "encoding/json" "fmt" "io" "os" @@ -17,6 +18,8 @@ import ( const ( importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" + importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" + importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -29,6 +32,27 @@ type importBenchRow struct { currentWarm time.Duration } +type importBenchScenarioRow struct { + profile string + localCount int + stdlibCount int + externalCount int + name string + stock time.Duration + current time.Duration +} + +type benchCaches struct { + home string + goCache string +} + +type benchGraphCounts struct { + local int + stdlib int + external int +} + const importBenchTrials = 3 func TestPrintImportScaleBenchmarkTable(t *testing.T) { @@ -42,6 +66,8 @@ func TestPrintImportScaleBenchmarkTable(t *testing.T) { currentBin := buildWireBinary(t, repoRoot, "current-wire") stockDir := extractStockWire(t, repoRoot, stockWireCommit) stockBin := buildWireBinary(t, stockDir, "stock-wire") + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) sizes := []int{10, 100, 1000} rows := make([]importBenchRow, 0, len(sizes)) @@ -50,9 +76,9 @@ func TestPrintImportScaleBenchmarkTable(t *testing.T) { currentFixture := createImportBenchFixture(t, n, currentWireModulePath, repoRoot) rows = append(rows, importBenchRow{ imports: n, - stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)), - currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)), - currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)), + stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, stockCaches, importBenchTrials)), + currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)), + currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)), }) } printImportBenchTable(t, rows) @@ -69,14 +95,16 @@ func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { currentBin := buildWireBinary(t, repoRoot, "current-wire") stockDir := extractStockWire(t, repoRoot, stockWireCommit) stockBin := buildWireBinary(t, stockDir, "stock-wire") + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) const imports = 1000 stockFixture := createImportBenchFixture(t, imports, stockWireModulePath, stockDir) currentFixture := createImportBenchFixture(t, imports, currentWireModulePath, repoRoot) - stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)) - currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)) - currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)) + stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, stockCaches, importBenchTrials)) + currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)) + currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)) fmt.Printf("repo size: %d\n", imports) fmt.Printf("stock cold: %s\n", formatMs(stockCold)) @@ -86,27 +114,202 @@ func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { fmt.Printf("unchanged speedup: %s\n", formatSpeedup(stockCold, currentWarm)) fmt.Printf("cold gap: %s\n", formatMs(currentCold-stockCold)) - home := t.TempDir() - goCache := filepath.Join(t.TempDir(), "gocache") - _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, home, goCache, "-timings") + prewarmGoBenchCache(t, currentFixture, currentCaches) + _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, currentCaches, "-timings") fmt.Println("current cold timings:") - for _, line := range strings.Split(output, "\n") { - if !strings.Contains(line, "wire: timing:") { - continue - } - if strings.Contains(line, "loader.custom.root.discovery=") || - strings.Contains(line, "loader.discovery.") || - strings.Contains(line, "load.packages.load=") || - strings.Contains(line, "loader.custom.typed.artifact_write=") || - strings.Contains(line, "loader.custom.typed.root_load.wall=") || - strings.Contains(line, "loader.custom.typed.discovery.wall=") || - strings.Contains(line, "loader.custom.typed.artifact_writes=") || - strings.Contains(line, "generate.package.") || - strings.Contains(line, "wire.Generate=") || - strings.Contains(line, "total=") { - fmt.Println(line) - } + printScenarioTimingLines(output) +} + +func TestPrintImportScenarioBenchmarkTable(t *testing.T) { + if os.Getenv(importBenchScenarios) != "1" { + t.Skipf("%s not set", importBenchScenarios) } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + type appBenchProfile struct { + localPkgs int + depPkgs int + external bool + label string + } + profiles := []appBenchProfile{ + {localPkgs: 10, depPkgs: 25, label: "local"}, + {localPkgs: 10, depPkgs: 1000, label: "local-high"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external"}, + } + rows := make([]importBenchScenarioRow, 0, len(profiles)*6) + for _, profile := range profiles { + shapeFixture := createAppShapeBenchFixture(t, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot) + shapeCounts := goListGraphCounts(t, shapeFixture, "example.com/appbench", newBenchCaches(t)) + rows = append(rows, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "cold run", + stock: medianDuration(runAppColdTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppColdTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "unchanged rerun", + stock: medianDuration(runAppWarmTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppWarmTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "body-only local edit", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "body", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "body", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "shape change", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "shape", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "shape", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "import change", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "import", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "import", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "known import toggle", + stock: medianDuration(runAppKnownToggleTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppKnownToggleTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + ) + } + printImportScenarioBenchTable(t, rows) +} + +func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { + if os.Getenv(importBenchScenarioBD) != "1" { + t.Skipf("%s not set", importBenchScenarioBD) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + const ( + localPkgs = 10 + depPkgs = 1000 + ) + + stockPkgDir := createAppShapeBenchFixture(t, localPkgs, depPkgs, false, stockWireModulePath, stockDir) + currentPkgDir := createAppShapeBenchFixture(t, localPkgs, depPkgs, false, currentWireModulePath, repoRoot) + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) + + prewarmGoBenchCache(t, stockPkgDir, stockCaches) + _ = runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + writeAppShapeControllerFile(t, filepath.Dir(stockPkgDir), 0, "shape") + _ = runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + writeAppShapeControllerFile(t, filepath.Dir(stockPkgDir), 0, "base") + stockDur := runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + + prewarmGoBenchCache(t, currentPkgDir, currentCaches) + _ = runWireBenchCommand(t, currentBin, currentPkgDir, currentCaches) + writeAppShapeControllerFile(t, filepath.Dir(currentPkgDir), 0, "shape") + _ = runWireBenchCommand(t, currentBin, currentPkgDir, currentCaches) + writeAppShapeControllerFile(t, filepath.Dir(currentPkgDir), 0, "base") + currentDur, currentOutput := runWireBenchCommandOutput(t, currentBin, currentPkgDir, currentCaches, "-timings") + + fmt.Printf("scenario: local=%d dep=%d known import toggle\n", localPkgs, depPkgs) + fmt.Printf("stock: %s\n", formatMs(stockDur)) + fmt.Printf("current: %s\n", formatMs(currentDur)) + fmt.Printf("speedup: %s\n", formatSpeedup(stockDur, currentDur)) + fmt.Println("current timings:") + printScenarioTimingLines(currentOutput) +} + +func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + for i := 0; i < trials; i++ { + caches := newBenchCaches(t) + prewarmGoBenchCache(t, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppWarmTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppScenarioTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir, variant string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + root := filepath.Dir(pkgDir) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, variant) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + root := filepath.Dir(pkgDir) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "shape") + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "base") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations } func buildWireBinary(t *testing.T, dir, name string) string { @@ -122,6 +325,14 @@ func buildWireBinary(t *testing.T, dir, name string) string { return out } +func newBenchCaches(t *testing.T) benchCaches { + t.Helper() + return benchCaches{ + home: t.TempDir(), + goCache: filepath.Join(t.TempDir(), "gocache"), + } +} + func extractStockWire(t *testing.T, repoRoot, commit string) string { t.Helper() tmp := t.TempDir() @@ -183,8 +394,7 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep if err := os.MkdirAll(dir, 0o755); err != nil { t.Fatal(err) } - src := fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) - if err := os.WriteFile(filepath.Join(dir, "dep.go"), []byte(src), 0o644); err != nil { + if err := os.WriteFile(filepath.Join(dir, "dep.go"), []byte(importBenchDepFile(i, "base")), 0o644); err != nil { t.Fatal(err) } } @@ -197,19 +407,614 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep return filepath.Join(root, "app") } -func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.Duration { +func createAppShapeBenchFixture(t *testing.T, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { + t.Helper() + root := t.TempDir() + modulePath := "example.com/appbench" + if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(appShapeGoMod(modulePath, wireModulePath, wireReplaceDir, external)), 0o644); err != nil { + t.Fatal(err) + } + if external { + seedAppShapeExternalGoSum(t, root) + } + for i := 0; i < depPkgs; i++ { + writeAppShapeFile(t, filepath.Join(root, "internal", fmt.Sprintf("dep%04d", i), "dep.go"), appShapeDepFile(i)) + } + writeAppShapeFile(t, filepath.Join(root, "internal", "logger", "logger.go"), appShapeLoggerFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "cache", "cache.go"), appShapeCacheFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "db", "db.go"), appShapeDBFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "config", "config.go"), appShapeConfigFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "metrics", "metrics.go"), appShapeMetricsFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "httpx", "httpx.go"), appShapeHTTPXFile(modulePath)) + if external { + writeAppShapeFile(t, filepath.Join(root, "internal", "extsink", "extsink.go"), appShapeExtSinkFile(modulePath)) + } + writeAppShapeFile(t, filepath.Join(root, "wire", "app.go"), appShapeAppFile(modulePath, features)) + writeAppShapeFile(t, filepath.Join(root, "wire", "wire.go"), appShapeWireFile(modulePath, wireModulePath, features, external)) + for i := 0; i < features; i++ { + writeAppShapeFile(t, filepath.Join(root, "internal", fmt.Sprintf("feature%04d", i), "feature.go"), appShapeFeatureFile(modulePath, wireModulePath, i, depPkgs, external)) + writeAppShapeControllerFile(t, root, i, "base") + } + return filepath.Join(root, "wire") +} + +func writeAppShapeFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func writeAppShapeControllerFile(t *testing.T, root string, index int, variant string) { + t.Helper() + path := filepath.Join(root, "internal", fmt.Sprintf("feature%04d", index), "controller.go") + if err := os.WriteFile(path, []byte(appShapeControllerFile("example.com/appbench", index, variant)), 0o644); err != nil { + t.Fatal(err) + } +} + +func seedAppShapeExternalGoSum(t *testing.T, root string) { + t.Helper() + const source = "/private/tmp/test/go.sum" + data, err := os.ReadFile(source) + if err != nil { + return + } + if err := os.WriteFile(filepath.Join(root, "go.sum"), data, 0o644); err != nil { + t.Fatalf("write seeded go.sum: %v", err) + } +} + +func resetAppShapeBenchFixture(t *testing.T, pkgDir string, features int) { + t.Helper() + root := filepath.Dir(pkgDir) + for i := 0; i < features; i++ { + writeAppShapeControllerFile(t, root, i, "base") + } +} + +func appShapeGoMod(modulePath, wireModulePath, wireReplaceDir string, external bool) string { + extraRequires := "" + if external { + extraRequires = ` + github.com/alecthomas/kong v1.14.0 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/fsnotify/fsnotify v1.7.0 + github.com/glebarez/sqlite v1.11.0 + github.com/goforj/cache v0.1.5 + github.com/goforj/crypt v1.1.0 + github.com/goforj/env/v2 v2.3.0 + github.com/goforj/httpx v1.1.0 + github.com/goforj/null/v6 v6.0.2 + github.com/goforj/queue v0.1.5 + github.com/goforj/queue/driver/redisqueue v0.1.5 + github.com/goforj/scheduler v1.4.0 + github.com/goforj/storage v0.2.5 + github.com/goforj/storage/driver/localstorage v0.2.5 + github.com/goforj/storage/driver/redisstorage v0.2.5 + github.com/goforj/str v1.3.0 + github.com/google/go-cmp v0.6.0 + github.com/google/subcommands v1.2.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 + github.com/hibiken/asynq v0.26.0 + github.com/imroc/req/v3 v3.57.0 + github.com/labstack/echo/v4 v4.15.1 + github.com/pmezard/go-difflib v1.0.0 + github.com/redis/go-redis/v9 v9.17.2 + github.com/rs/zerolog v1.34.0 + github.com/shirou/gopsutil/v4 v4.26.2 + golang.org/x/mod v0.33.0 + golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 + golang.org/x/sys v0.41.0 + golang.org/x/term v0.40.0 + golang.org/x/tools v0.42.0 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.6.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1` + } + return fmt.Sprintf(`module %s + +go 1.26 + +require ( + %s v0.0.0%s +) + +replace %s => %s +`, modulePath, wireModulePath, extraRequires, wireModulePath, wireReplaceDir) +} + +func appShapeLoggerFile(modulePath string) string { + return `package logger + +import ( + "context" + "encoding/json" + "io" + "os" + "sync" + "time" +) + +type Logger struct { + sink io.Writer + mu sync.Mutex +} + +func NewLogger() *Logger { return &Logger{sink: os.Stdout} } + +func (l *Logger) Log(ctx context.Context, msg string, attrs map[string]string) { + l.mu.Lock() + defer l.mu.Unlock() + _, _ = json.Marshal(map[string]any{ + "ctx": ctx != nil, + "msg": msg, + "attrs": attrs, + "time": time.Now().UTC().Format(time.RFC3339Nano), + }) +} +` +} + +func appShapeCacheFile(modulePath string) string { + return `package cache + +type Manager struct{} + +func NewManager() *Manager { return &Manager{} } +` +} + +func appShapeDBFile(modulePath string) string { + return `package db + +import ( + "context" + "database/sql" + "net/url" + "path/filepath" +) + +type DB struct { + driver string + dsn string +} + +func NewDB() *DB { + _ = filepath.Join("var", "lib", "appbench") + _ = sql.LevelDefault + u := &url.URL{Scheme: "postgres", Host: "localhost", Path: "/appbench"} + return &DB{driver: "postgres", dsn: u.String()} +} + +func (db *DB) PingContext(context.Context) error { return nil } +` +} + +func appShapeDepFile(index int) string { + return fmt.Sprintf(`package dep%04d + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "path/filepath" + "strings" +) + +type Value struct { + Name string +} + +func Provide() Value { + sum := sha256.Sum256([]byte(fmt.Sprintf("dep-%%04d", %d))) + return Value{ + Name: filepath.Join("deps", strings.ToLower(hex.EncodeToString(sum[:])))[:16], + } +} +`, index, index) +} + +func appShapeConfigFile(modulePath string) string { + return `package config + +import ( + "encoding/json" + "os" + "strconv" +) + +type Config struct { + Port int + Service string +} + +func NewConfig() *Config { + cfg := &Config{Port: 8080, Service: "appbench"} + if v := os.Getenv("APPBENCH_PORT"); v != "" { + if port, err := strconv.Atoi(v); err == nil { + cfg.Port = port + } + } + _, _ = json.Marshal(cfg) + return cfg +} +` +} + +func appShapeMetricsFile(modulePath string) string { + return `package metrics + +import ( + "expvar" + "fmt" + "sync/atomic" +) + +type Metrics struct { + requests atomic.Int64 + name string +} + +func NewMetrics() *Metrics { + expvar.NewString("appbench_name").Set("appbench") + return &Metrics{name: fmt.Sprintf("appbench_%s", "requests")} +} +` +} + +func appShapeHTTPXFile(modulePath string) string { + return `package httpx + +import ( + "context" + "net/http" + "net/http/httptest" +) + +type Client struct { + client *http.Client +} + +func NewClient() *Client { + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + _ = req.WithContext(context.Background()) + return &Client{client: &http.Client{}} +} +` +} + +func appShapeExtSinkFile(modulePath string) string { + return `package extsink + +import ( + "context" + "fmt" + "os" + + _ "github.com/alecthomas/kong" + _ "github.com/charmbracelet/lipgloss" + _ "github.com/charmbracelet/lipgloss/table" + "github.com/fsnotify/fsnotify" + _ "github.com/glebarez/sqlite" + _ "github.com/goforj/cache" + _ "github.com/goforj/crypt" + _ "github.com/goforj/env/v2" + _ "github.com/goforj/httpx" + _ "github.com/goforj/null/v6" + _ "github.com/goforj/queue" + _ "github.com/goforj/queue/driver/redisqueue" + _ "github.com/goforj/scheduler" + _ "github.com/goforj/storage" + _ "github.com/goforj/storage/driver/localstorage" + _ "github.com/goforj/storage/driver/redisstorage" + _ "github.com/goforj/str" + "github.com/google/go-cmp/cmp" + "github.com/google/subcommands" + _ "github.com/google/uuid" + _ "github.com/gorilla/websocket" + _ "github.com/hibiken/asynq" + _ "github.com/imroc/req/v3" + _ "github.com/labstack/echo/v4" + _ "github.com/labstack/echo/v4/middleware" + "github.com/pmezard/go-difflib/difflib" + _ "github.com/redis/go-redis/v9" + _ "github.com/rs/zerolog" + _ "github.com/shirou/gopsutil/v4/cpu" + _ "github.com/shirou/gopsutil/v4/disk" + _ "github.com/shirou/gopsutil/v4/host" + _ "github.com/shirou/gopsutil/v4/mem" + _ "github.com/shirou/gopsutil/v4/net" + _ "github.com/shirou/gopsutil/v4/process" + "golang.org/x/mod/modfile" + _ "golang.org/x/net/http2" + _ "golang.org/x/net/http2/h2c" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" + _ "golang.org/x/term" + "golang.org/x/tools/go/packages" + _ "gopkg.in/yaml.v3" + _ "gorm.io/driver/mysql" + _ "gorm.io/driver/postgres" + _ "gorm.io/gorm" +) + +type Sink struct { + label string +} + +func NewSink() *Sink { + _ = cmp.Equal("a", "b") + _ = difflib.UnifiedDiff{} + _, _ = modfile.Parse("go.mod", []byte("module example.com/appbench"), nil) + _, _ = packages.Load(&packages.Config{Mode: packages.NeedName}, "fmt") + var g errgroup.Group + g.Go(func() error { return nil }) + _ = unix.Getpid() + _ = fsnotify.Event{Name: os.TempDir()} + _ = subcommands.ExitSuccess + return &Sink{label: fmt.Sprintf("sink:%v", context.Background() != nil)} +} +` +} + +func appShapeFeatureFile(modulePath, wireModulePath string, index, depPkgs int, external bool) string { + pkg := fmt.Sprintf("feature%04d", index) + var depImports strings.Builder + var depUse strings.Builder + for i := 0; i < depPkgs; i++ { + depImports.WriteString(fmt.Sprintf("\tdep%04d %q\n", i, fmt.Sprintf("%s/internal/dep%04d", modulePath, i))) + depUse.WriteString(fmt.Sprintf("\t_ = dep%04d.Provide()\n", i)) + } + externalImport := "" + externalArg := "" + externalField := "" + externalUse := "" + if external { + externalImport = fmt.Sprintf("\t%q\n", modulePath+"/internal/extsink") + externalArg = ", sink *extsink.Sink" + externalField = "\tsink *extsink.Sink\n" + externalUse = "\t_ = sink\n" + } + return fmt.Sprintf(`package %s + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + "time" + wire %q + %q + %q + %q + %q + %q +%s +%s +) + +type Repo struct { + db *db.DB + config *config.Config + metrics *metrics.Metrics +%s} + +type Service struct { + repo *Repo + logger *logger.Logger + client *httpx.Client +} + +func NewRepo(dbConn *db.DB, cfg *config.Config, m *metrics.Metrics, l *logger.Logger%s) *Repo { + _, _ = json.Marshal(map[string]any{"feature": %d, "service": cfg.Service}) + l.Log(context.Background(), "repo.init", map[string]string{"feature": strconv.Itoa(%d)}) +%s return &Repo{db: dbConn, config: cfg, metrics: m} +} + +func NewService(repo *Repo, l *logger.Logger, client *httpx.Client) *Service { + _, _ = url.Parse(fmt.Sprintf("https://example.com/%%04d", %d)) + _ = time.Second + return &Service{repo: repo, logger: l, client: client} +} + +var Set = wire.NewSet(NewRepo, NewService, NewController) +`, pkg, wireModulePath, modulePath+"/internal/config", modulePath+"/internal/db", modulePath+"/internal/httpx", modulePath+"/internal/logger", modulePath+"/internal/metrics", depImports.String(), externalImport, externalField, externalArg, index, index, depUse.String()+externalUse, index) +} + +func appShapeControllerFile(modulePath string, index int, variant string) string { + pkg := fmt.Sprintf("feature%04d", index) + imports := []string{ + `"context"`, + `"fmt"`, + `"net/http"`, + `"strconv"`, + `"` + modulePath + `/internal/logger"`, + } + if variant == "shape" { + imports = append(imports, `"`+modulePath+`/internal/db"`) + } + if variant == "import" { + imports = append(imports, `"strings"`) + } + bodyLine := "" + switch variant { + case "body": + bodyLine = "\t_ = \"body-edit\"\n" + case "import": + bodyLine = "\t_ = strings.TrimSpace(\" import-edit \")\n" + } + extraField := "" + extraArg := "" + extraInit := "" + if variant == "shape" { + extraField = "\tdb *db.DB\n" + extraArg = ", d *db.DB" + extraInit = "\t\tdb: d,\n" + } + return fmt.Sprintf(`package %s + +import ( + %s +) + +type Controller struct { + logger *logger.Logger + service *Service +%s} + +func NewController(l *logger.Logger, s *Service%s) *Controller { +%s l.Log(context.Background(), "controller.init", map[string]string{"feature": strconv.Itoa(%d)}) + _ = http.MethodGet + _ = fmt.Sprintf("feature-%%d", %d) + return &Controller{ + logger: l, + service: s, +%s } +} +`, pkg, strings.Join(imports, "\n\t"), extraField, extraArg, bodyLine, index, index, extraInit) +} + +func appShapeAppFile(modulePath string, features int) string { + var b strings.Builder + b.WriteString("package wire\n\n") + if features > 0 { + b.WriteString("import (\n") + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\tfeature%04d %q\n", i, fmt.Sprintf("%s/internal/feature%04d", modulePath, i))) + } + b.WriteString(")\n\n") + } + b.WriteString("type App struct{}\n\n") + b.WriteString("func NewApp(") + for i := 0; i < features; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("_ *feature%04d.Controller", i)) + } + b.WriteString(") *App {\n\treturn &App{}\n}\n") + return b.String() +} + +func appShapeWireFile(modulePath, wireModulePath string, features int, external bool) string { + var b strings.Builder + b.WriteString("//go:build wireinject\n\n") + b.WriteString("package wire\n\n") + b.WriteString("import (\n") + b.WriteString(fmt.Sprintf("\twire %q\n", wireModulePath)) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/config")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/db")) + if external { + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/extsink")) + } + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/httpx")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/logger")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/metrics")) + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\t%q\n", fmt.Sprintf("%s/internal/feature%04d", modulePath, i))) + } + b.WriteString(")\n\n") + b.WriteString("func Initialize() *App {\n\twire.Build(\n") + b.WriteString("\t\tconfig.NewConfig,\n") + b.WriteString("\t\tlogger.NewLogger,\n") + b.WriteString("\t\tdb.NewDB,\n") + if external { + b.WriteString("\t\textsink.NewSink,\n") + } + b.WriteString("\t\thttpx.NewClient,\n") + b.WriteString("\t\tmetrics.NewMetrics,\n") + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\t\tfeature%04d.Set,\n", i)) + } + b.WriteString("\t\tNewApp,\n\t)\n\treturn nil\n}\n") + return b.String() +} + +func runBodyEditTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchDepFile(t, root, 0, "body") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runShapeEditTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchDepFile(t, root, 0, "shape") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runImportChangeTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports+1, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + writeImportBenchWireFile(t, root, imports, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports+1, wireModulePath) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runKnownImportToggleTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports+1, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + writeImportBenchWireFile(t, root, imports, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports+1, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports, wireModulePath) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runWireBenchCommand(t *testing.T, bin, pkgDir string, caches benchCaches) time.Duration { t.Helper() - d, _ := runWireBenchCommandOutput(t, bin, pkgDir, home, goCache) + d, _ := runWireBenchCommandOutput(t, bin, pkgDir, caches) return d } -func runWireBenchCommandOutput(t *testing.T, bin, pkgDir, home, goCache string, extraArgs ...string) (time.Duration, string) { +func runWireBenchCommandOutput(t *testing.T, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { t.Helper() args := []string{"gen"} args = append(args, extraArgs...) cmd := exec.Command(bin, args...) cmd.Dir = pkgDir - cmd.Env = append(benchEnv(home, goCache), "WIRE_LOADER_ARTIFACTS=1") + cmd.Env = append(benchEnv(caches.home, caches.goCache), "WIRE_LOADER_ARTIFACTS=1") var stderr bytes.Buffer cmd.Stdout = io.Discard cmd.Stderr = &stderr @@ -220,25 +1025,96 @@ func runWireBenchCommandOutput(t *testing.T, bin, pkgDir, home, goCache string, return time.Since(start), stderr.String() } -func runColdTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { +func prewarmGoBenchCache(t *testing.T, pkgDir string, caches benchCaches) { + t.Helper() + prepareBenchModule(t, pkgDir, caches) + cmd := exec.Command("go", "list", "-deps", "./...") + cmd.Dir = pkgDir + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("prewarm go cache in %s: %v\n%s", pkgDir, err, output) + } +} + +func goListGraphCounts(t *testing.T, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { + t.Helper() + prepareBenchModule(t, pkgDir, caches) + cmd := exec.Command("go", "list", "-deps", "-json", "./...") + cmd.Dir = pkgDir + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go list graph counts in %s: %v\n%s", pkgDir, err, output) + } + dec := json.NewDecoder(bytes.NewReader(output)) + seen := make(map[string]struct{}) + var counts benchGraphCounts + for { + var pkg struct { + ImportPath string + Standard bool + } + if err := dec.Decode(&pkg); err != nil { + if err == io.EOF { + break + } + t.Fatalf("decode graph counts for %s: %v", pkgDir, err) + } + if pkg.ImportPath == "" { + continue + } + if _, ok := seen[pkg.ImportPath]; ok { + continue + } + seen[pkg.ImportPath] = struct{}{} + switch { + case pkg.Standard: + counts.stdlib++ + case pkg.ImportPath == modulePath || strings.HasPrefix(pkg.ImportPath, modulePath+"/"): + counts.local++ + default: + counts.external++ + } + } + return counts +} + +func prepareBenchModule(t *testing.T, pkgDir string, caches benchCaches) { + t.Helper() + marker := filepath.Join(filepath.Dir(pkgDir), ".bench-module-ready") + if _, err := os.Stat(marker); err == nil { + return + } + cmd := exec.Command("go", "mod", "tidy") + cmd.Dir = filepath.Dir(pkgDir) + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("prepare bench module in %s: %v\n%s", pkgDir, err, output) + } + if err := os.WriteFile(marker, []byte("ok\n"), 0o644); err != nil { + t.Fatalf("write module marker %s: %v", marker, err) + } +} + +func runColdTrials(t *testing.T, bin, pkgDir string, caches benchCaches, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) for i := 0; i < trials; i++ { - home := t.TempDir() - goCache := filepath.Join(t.TempDir(), "gocache") - durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + prewarmGoBenchCache(t, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) } return durations } -func runWarmTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { +func runWarmTrials(t *testing.T, bin, pkgDir string, caches benchCaches, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) for i := 0; i < trials; i++ { - home := t.TempDir() - goCache := filepath.Join(t.TempDir(), "gocache") - _ = runWireBenchCommand(t, bin, pkgDir, home, goCache) - durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) } return durations } @@ -262,6 +1138,7 @@ func benchEnv(home, goCache string) []string { "HOME="+home, "GOCACHE="+goCache, "GOMODCACHE=/tmp/gomodcache", + "GOSUMDB=off", ) return env } @@ -304,6 +1181,33 @@ func importBenchWireFile(imports int, wireModulePath string) string { return b.String() } +func importBenchDepFile(i int, variant string) string { + switch variant { + case "body": + return fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T {\n\t_ = \"body-edit\"\n\treturn &T{}\n}\n", i) + case "shape": + return fmt.Sprintf("package dep%04d\n\ntype T struct{ Extra int }\n\nfunc Provide() *T { return &T{} }\n", i) + default: + return fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) + } +} + +func writeImportBenchWireFile(t *testing.T, root string, imports int, wireModulePath string) { + t.Helper() + path := filepath.Join(root, "app", "wire.go") + if err := os.WriteFile(path, []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { + t.Fatal(err) + } +} + +func writeImportBenchDepFile(t *testing.T, root string, index int, variant string) { + t.Helper() + path := filepath.Join(root, fmt.Sprintf("dep%04d", index), "dep.go") + if err := os.WriteFile(path, []byte(importBenchDepFile(index, variant)), 0o644); err != nil { + t.Fatal(err) + } +} + func printImportBenchTable(t *testing.T, rows []importBenchRow) { t.Helper() fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") @@ -322,6 +1226,104 @@ func printImportBenchTable(t *testing.T, rows []importBenchRow) { fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") } +func printImportScenarioBenchTable(t *testing.T, rows []importBenchScenarioRow) { + t.Helper() + profileWidth := len("profile") + localWidth := len("local") + stdlibWidth := len("stdlib") + externalWidth := len("external") + changeTypeWidth := len("change type") + stockWidth := len("stock") + currentWidth := len("current") + speedupWidth := len("speedup") + for _, row := range rows { + profileWidth = maxInt(profileWidth, len(row.profile)) + localWidth = maxInt(localWidth, len(fmt.Sprintf("%d", row.localCount))) + stdlibWidth = maxInt(stdlibWidth, len(fmt.Sprintf("%d", row.stdlibCount))) + externalWidth = maxInt(externalWidth, len(fmt.Sprintf("%d", row.externalCount))) + changeTypeWidth = maxInt(changeTypeWidth, len(row.name)) + stockWidth = maxInt(stockWidth, len(formatMs(row.stock))) + currentWidth = maxInt(currentWidth, len(formatMs(row.current))) + speedupWidth = maxInt(speedupWidth, len(formatSpeedup(row.stock, row.current))) + } + sep := fmt.Sprintf("+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+", + strings.Repeat("-", profileWidth), + strings.Repeat("-", localWidth), + strings.Repeat("-", stdlibWidth), + strings.Repeat("-", externalWidth), + strings.Repeat("-", changeTypeWidth), + strings.Repeat("-", stockWidth), + strings.Repeat("-", currentWidth), + strings.Repeat("-", speedupWidth), + ) + fmt.Println(sep) + fmt.Printf("| %*s | %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %-*s |\n", + profileWidth, "profile", + localWidth, "local", + stdlibWidth, "stdlib", + externalWidth, "external", + changeTypeWidth, "change type", + stockWidth, "stock", + currentWidth, "current", + speedupWidth, "speedup", + ) + fmt.Println(sep) + for _, row := range rows { + fmt.Printf("| %*s | %-*d | %-*d | %-*d | %-*s | %-*s | %-*s | %-*s |\n", + profileWidth, row.profile, + localWidth, row.localCount, + stdlibWidth, row.stdlibCount, + externalWidth, row.externalCount, + changeTypeWidth, row.name, + stockWidth, formatMs(row.stock), + currentWidth, formatMs(row.current), + speedupWidth, formatSpeedup(row.stock, row.current), + ) + } + fmt.Println(sep) + fmt.Println() + fmt.Println("change types:") + fmt.Println(" cold run: first wire gen on a fresh Wire cache for that repo shape") + fmt.Println(" unchanged rerun: run wire gen again without changing any files") + fmt.Println(" body-only local edit: change local function body/content without changing imports, types, or constructor signatures") + fmt.Println(" shape change: change local provider/type shape such as constructor params, fields, or return shape") + fmt.Println(" import change: add or remove a local import, which can change discovered package shape") + fmt.Println(" known import toggle: switch back to a previously seen import/shape state in the same repo") +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func printScenarioTimingLines(output string) { + for _, line := range strings.Split(output, "\n") { + if !strings.Contains(line, "wire: timing:") { + continue + } + if strings.Contains(line, "loader.custom.root.discovery=") || + strings.Contains(line, "loader.discovery.") || + strings.Contains(line, "load.packages.load=") || + strings.Contains(line, "load.debug") || + strings.Contains(line, "loader.custom.typed.artifact_read=") || + strings.Contains(line, "loader.custom.typed.artifact_decode=") || + strings.Contains(line, "loader.custom.typed.artifact_import_link=") || + strings.Contains(line, "loader.custom.typed.artifact_write=") || + strings.Contains(line, "loader.custom.typed.root_load.wall=") || + strings.Contains(line, "loader.custom.typed.discovery.wall=") || + strings.Contains(line, "loader.custom.typed.artifact_hits=") || + strings.Contains(line, "loader.custom.typed.artifact_misses=") || + strings.Contains(line, "loader.custom.typed.artifact_writes=") || + strings.Contains(line, "generate.package.") || + strings.Contains(line, "wire.Generate=") || + strings.Contains(line, "total=") { + fmt.Println(line) + } + } +} + func formatMs(d time.Duration) string { return fmt.Sprintf("%.1fms", float64(d)/float64(time.Millisecond)) } @@ -340,7 +1342,9 @@ func TestImportBenchFixtureGenerates(t *testing.T) { } bin := buildWireBinary(t, repoRoot, "fixture-wire") fixture := createImportBenchFixture(t, 10, currentWireModulePath, repoRoot) - _ = runWireBenchCommand(t, bin, fixture, t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + caches := newBenchCaches(t) + prewarmGoBenchCache(t, fixture, caches) + _ = runWireBenchCommand(t, bin, fixture, caches) } func TestImportBenchUsesStockArchive(t *testing.T) { diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index e1c1f97..2eb98c2 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -12,10 +12,12 @@ usage() { cat <<'EOF' Usage: scripts/import-benchmarks.sh table + scripts/import-benchmarks.sh scenarios scripts/import-benchmarks.sh breakdown Commands: table Print the 10/100/1000 import stock-vs-current benchmark table. + scenarios Print the stock-vs-current change-type scenario table. breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } @@ -24,6 +26,9 @@ case "${1:-}" in table) WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v ;; + scenarios) + WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + ;; breakdown) WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v ;; From ba95466e9ebc18668d8d729ffdb6c37b2fba9cf3 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:17:52 -0500 Subject: [PATCH 18/79] fix: ci --- cmd/wire/cache_cmd_test.go | 12 ++-- cmd/wire/check_cmd.go | 4 +- cmd/wire/diff_cmd.go | 6 +- cmd/wire/main.go | 10 ++-- cmd/wire/show_cmd.go | 4 +- internal/loader/custom.go | 90 +++++++++++++++--------------- internal/loader/discovery_cache.go | 18 +++--- internal/wire/import_bench_test.go | 17 ++++-- 8 files changed, 83 insertions(+), 78 deletions(-) diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go index c0c74c1..83924e2 100644 --- a/cmd/wire/cache_cmd_test.go +++ b/cmd/wire/cache_cmd_test.go @@ -10,9 +10,9 @@ func TestWireCacheTargetsDefault(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") got := wireCacheTargets(nil, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), - "output-cache": filepath.Join(base, "wire", "output-cache"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), "semantic-artifacts": filepath.Join(base, "wire", "semantic-artifacts"), } if len(got) != len(want) { @@ -50,9 +50,9 @@ func TestWireCacheTargetsRespectOverrides(t *testing.T) { } got := wireCacheTargets(env, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "loader"), - "output-cache": filepath.Join(base, "output"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), "semantic-artifacts": filepath.Join(base, "semantic"), } for _, target := range got { diff --git a/cmd/wire/check_cmd.go b/cmd/wire/check_cmd.go index 897bec2..7857437 100644 --- a/cmd/wire/check_cmd.go +++ b/cmd/wire/check_cmd.go @@ -26,8 +26,8 @@ import ( ) type checkCmd struct { - tags string - profile profileFlags + tags string + profile profileFlags } // Name returns the subcommand name. diff --git a/cmd/wire/diff_cmd.go b/cmd/wire/diff_cmd.go index 5aad2f1..592cced 100644 --- a/cmd/wire/diff_cmd.go +++ b/cmd/wire/diff_cmd.go @@ -29,9 +29,9 @@ import ( ) type diffCmd struct { - headerFile string - tags string - profile profileFlags + headerFile string + tags string + profile profileFlags } // Name returns the subcommand name. diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 515673c..c13b850 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -36,11 +36,11 @@ import ( ) const ( - ansiRed = "\033[1;31m" - ansiGreen = "\033[1;32m" - ansiReset = "\033[0m" - successSig = "✓ " - errorSig = "x " + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " maxLoggedErrorLines = 5 ) diff --git a/cmd/wire/show_cmd.go b/cmd/wire/show_cmd.go index 10c737f..5a81b29 100644 --- a/cmd/wire/show_cmd.go +++ b/cmd/wire/show_cmd.go @@ -34,8 +34,8 @@ import ( ) type showCmd struct { - tags string - profile profileFlags + tags string + profile profileFlags } // Name returns the subcommand name. diff --git a/internal/loader/custom.go b/internal/loader/custom.go index dd2b9e0..52ebf4e 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -71,21 +71,21 @@ type customValidator struct { } type customTypedGraphLoader struct { - workspace string - ctx context.Context - env []string - fset *token.FileSet - meta map[string]*packageMeta - targets map[string]struct{} - parseFile ParseFileFunc - packages map[string]*packages.Package - typesPkgs map[string]*types.Package - importer types.Importer - loading map[string]bool - isLocalCache map[string]bool - localSemanticOK map[string]bool - artifactPrefetch map[string]artifactPrefetchEntry - stats typedLoadStats + workspace string + ctx context.Context + env []string + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + isLocalCache map[string]bool + localSemanticOK map[string]bool + artifactPrefetch map[string]artifactPrefetchEntry + stats typedLoadStats } type artifactPrefetchEntry struct { @@ -263,21 +263,21 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz fset = token.NewFileSet() } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: map[string]struct{}{req.Package: {}}, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: map[string]struct{}{req.Package: {}}, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, } prefetchStart := time.Now() l.prefetchArtifacts() @@ -345,21 +345,21 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, unsupportedError{reason: "no root packages from metadata"} } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: targets, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, } prefetchStart := time.Now() l.prefetchArtifacts() diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 4ec9a12..3b9fe46 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -14,15 +14,15 @@ import ( ) type discoveryCacheEntry struct { - Version int - WD string - Tags string - Patterns []string - NeedDeps bool - Workspace string - Meta map[string]*packageMeta - Global []discoveryFileMeta - LocalPkgs []discoveryLocalPackage + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + Workspace string + Meta map[string]*packageMeta + Global []discoveryFileMeta + LocalPkgs []discoveryLocalPackage } type discoveryLocalPackage struct { diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index aff71e2..e2f5900 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -26,10 +26,10 @@ const ( ) type importBenchRow struct { - imports int - stockCold time.Duration - currentCold time.Duration - currentWarm time.Duration + imports int + stockCold time.Duration + currentCold time.Duration + currentWarm time.Duration } type importBenchScenarioRow struct { @@ -520,7 +520,7 @@ func appShapeGoMod(modulePath, wireModulePath, wireReplaceDir string, external b } return fmt.Sprintf(`module %s -go 1.26 +go 1.19 require ( %s v0.0.0%s @@ -1146,7 +1146,7 @@ func benchEnv(home, goCache string) []string { func importBenchGoMod(wireModulePath, wireReplaceDir string) string { return fmt.Sprintf(`module example.com/importbench -go 1.26 +go 1.19 require %s v0.0.0 @@ -1352,6 +1352,11 @@ func TestImportBenchUsesStockArchive(t *testing.T) { if err != nil { t.Fatal(err) } + check := exec.Command("git", "cat-file", "-e", stockWireCommit+"^{commit}") + check.Dir = repoRoot + if err := check.Run(); err != nil { + t.Skipf("stock archive commit %s not available in checkout", stockWireCommit) + } stockDir := extractStockWire(t, repoRoot, stockWireCommit) if _, err := os.Stat(filepath.Join(stockDir, "cmd", "wire", "main.go")); err != nil { t.Fatalf("stock archive missing cmd/wire: %v", err) From 4b312b510154d5b87dac53f846b13e14a1d5d0e9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:20:15 -0500 Subject: [PATCH 19/79] fix: ci --- internal/loader/loader_test.go | 25 +++++++++++++++++-------- internal/loader/timing.go | 6 ++---- internal/wire/timing.go | 7 +++---- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 0c871e0..5f734b7 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -248,14 +248,23 @@ func TestValidateTouchedPackagesAutoReportsFallbackDetail(t *testing.T) { if err != nil { t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) } - if got.Backend != ModeFallback { - t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) - } - if got.FallbackReason != FallbackReasonCustomUnsupported { - t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonCustomUnsupported) - } - if got.FallbackDetail != "metadata fingerprint mismatch" { - t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + switch got.Backend { + case ModeCustom: + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want empty for custom backend", got.FallbackReason) + } + if got.FallbackDetail != "" { + t.Fatalf("fallback detail = %q, want empty for custom backend", got.FallbackDetail) + } + case ModeFallback: + if got.FallbackReason != FallbackReasonCustomUnsupported { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonCustomUnsupported) + } + if got.FallbackDetail != "metadata fingerprint mismatch" { + t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + } + default: + t.Fatalf("backend = %q, want %q or %q", got.Backend, ModeCustom, ModeFallback) } } diff --git a/internal/loader/timing.go b/internal/loader/timing.go index 4b902db..1ae9ccd 100644 --- a/internal/loader/timing.go +++ b/internal/loader/timing.go @@ -3,7 +3,6 @@ package loader import ( "context" "fmt" - "log" "time" ) @@ -49,8 +48,7 @@ func logInt(ctx context.Context, label string, v int) { } func debugf(ctx context.Context, format string, args ...interface{}) { - if timing(ctx) == nil { - return + if t := timing(ctx); t != nil { + t(fmt.Sprintf(format, args...), 0) } - log.Printf("timing: "+format, args...) } diff --git a/internal/wire/timing.go b/internal/wire/timing.go index d83754b..84c9022 100644 --- a/internal/wire/timing.go +++ b/internal/wire/timing.go @@ -16,7 +16,7 @@ package wire import ( "context" - "log" + "fmt" "time" ) @@ -52,8 +52,7 @@ func logTiming(ctx context.Context, label string, start time.Time) { } func debugf(ctx context.Context, format string, args ...interface{}) { - if timing(ctx) == nil { - return + if t := timing(ctx); t != nil { + t(fmt.Sprintf(format, args...), 0) } - log.Printf("timing: "+format, args...) } From 433ffc775b43d899f6cbae4bba2bc70a212b3c61 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:26:29 -0500 Subject: [PATCH 20/79] fix: windows tmpdir issue --- internal/loader/loader_test.go | 12 ++++++++---- internal/wire/import_bench_test.go | 9 ++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 5f734b7..27a129a 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -1118,12 +1118,16 @@ func normalizeErrorPos(pos string) string { if pos == "" || pos == "-" { return pos } - parts := strings.Split(pos, ":") - if len(parts) < 2 { + last := strings.LastIndex(pos, ":") + if last == -1 { return shortenComparablePath(normalizePathForCompare(pos)) } - path := shortenComparablePath(normalizePathForCompare(parts[0])) - return strings.Join(append([]string{path}, parts[1:]...), ":") + prev := strings.LastIndex(pos[:last], ":") + if prev == -1 { + return shortenComparablePath(normalizePathForCompare(pos)) + } + path := shortenComparablePath(normalizePathForCompare(pos[:prev])) + return path + pos[prev:] } func expandSummaryDiagnostics(msg string) []string { diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index e2f5900..4f5dd82 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -1137,12 +1137,19 @@ func benchEnv(home, goCache string) []string { env = append(env, "HOME="+home, "GOCACHE="+goCache, - "GOMODCACHE=/tmp/gomodcache", + "GOMODCACHE="+benchModCache(), "GOSUMDB=off", ) return env } +func benchModCache() string { + if path := os.Getenv("GOMODCACHE"); path != "" { + return path + } + return filepath.Join(os.TempDir(), "gomodcache") +} + func importBenchGoMod(wireModulePath, wireReplaceDir string) string { return fmt.Sprintf(`module example.com/importbench From d941179194e76e3891c9f746d8492e54847b9597 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:32:44 -0500 Subject: [PATCH 21/79] fix: windows bench executable path --- internal/wire/import_bench_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index 4f5dd82..cd38190 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" @@ -314,6 +315,9 @@ func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, ex func buildWireBinary(t *testing.T, dir, name string) string { t.Helper() + if runtime.GOOS == "windows" && filepath.Ext(name) != ".exe" { + name += ".exe" + } out := filepath.Join(t.TempDir(), name) cmd := exec.Command("go", "build", "-o", out, "./cmd/wire") cmd.Dir = dir From 88c86347dd3e4aa6fa656e558df310e9cf037ac6 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 14:49:15 -0500 Subject: [PATCH 22/79] fix(loader): strengthen artifact keys for replaced external modules --- internal/loader/artifact_cache.go | 23 +++++ internal/loader/loader_test.go | 160 ++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index 42293cb..e920d5a 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -72,6 +72,29 @@ func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { if !isLocal { sum.Write([]byte(meta.Export)) sum.Write([]byte{'\n'}) + if meta.Export != "" { + info, err := os.Stat(meta.Export) + if err != nil { + return "", err + } + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } else { + for _, name := range metaFiles(meta) { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } + } if meta.Error != nil { sum.Write([]byte(meta.Error.Err)) sum.Write([]byte{'\n'}) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 27a129a..2010f56 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -702,6 +702,166 @@ func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing. } } +func TestLoadTypedPackageGraphCustomArtifactCacheReplacedModuleSourceChange(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\treturn dep.New()", + "}", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + load := func(mode Mode) (*LazyLoadResult, error) { + l := New() + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: appRoot, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + + first, err := load(ModeCustom) + if err != nil { + t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) + } + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar l dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(l)", + "}", + "", + }, "\n")) + + custom, err := load(ModeCustom) + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + if len(custom.Packages) != 1 { + t.Fatalf("second custom packages len = %d, want 1", len(custom.Packages)) + } + if got := comparableErrors(custom.Packages[0].Errors); len(got) != 0 { + t.Fatalf("second custom load returned errors: %v", got) + } + + fallback, err := load(ModeFallback) + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestLoaderArtifactKeyExternalChangesWhenExportFileChanges(t *testing.T) { + exportPath := filepath.Join(t.TempDir(), "dep.a") + writeTestFile(t, exportPath, "first export payload") + + meta := &packageMeta{ + ImportPath: "example.com/dep", + Name: "dep", + Export: exportPath, + } + + first, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(first) error = %v", err) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, exportPath, "second export payload with different contents") + + second, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(second) error = %v", err) + } + + if first == second { + t.Fatalf("loaderArtifactKey did not change after export file update: %q", first) + } +} + +func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testing.T) { + sourcePath := filepath.Join(t.TempDir(), "dep.go") + writeTestFile(t, sourcePath, "package dep\n\nconst Name = \"first\"\n") + + meta := &packageMeta{ + ImportPath: "example.com/dep", + Name: "dep", + GoFiles: []string{sourcePath}, + CompiledGoFiles: []string{sourcePath}, + } + + first, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(first) error = %v", err) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, sourcePath, "package dep\n\nconst Name = \"second\"\n") + + second, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(second) error = %v", err) + } + + if first == second { + t.Fatalf("loaderArtifactKey did not change after external source update without export data: %q", first) + } +} + func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { root := os.Getenv("WIRE_REAL_APP_ROOT") if root == "" { From 09073ee95aff218e7a4feb21d0ba69936da135bf Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:12:02 -0500 Subject: [PATCH 23/79] test(loader): harden cache invalidation and discovery parity coverage --- internal/loader/custom.go | 18 +- internal/loader/discovery.go | 2 +- internal/loader/loader_test.go | 1917 ++++++++++++++++++++++++++++++++ 3 files changed, 1935 insertions(+), 2 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 52ebf4e..10f8c79 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -641,7 +641,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er return typed, nil } if len(dep.Errors) > 0 { - return nil, unsupportedError{reason: "lazy-load dependency has errors"} + return nil, dependencyImportError(dep) } return nil, unsupportedError{reason: "missing typed lazy-load dependency"} }), @@ -1100,6 +1100,22 @@ func toPackagesError(fset *token.FileSet, err error) packages.Error { } } +func dependencyImportError(pkg *packages.Package) error { + if pkg == nil { + return unsupportedError{reason: "lazy-load dependency has errors"} + } + if pkg.Name == "" { + return fmt.Errorf("invalid package name: %q", pkg.Name) + } + for _, err := range pkg.Errors { + if strings.TrimSpace(err.Msg) == "" { + continue + } + return fmt.Errorf("%s", err.Msg) + } + return unsupportedError{reason: "lazy-load dependency has errors"} +} + type importerFunc func(path string) (*types.Package, error) func (f importerFunc) Import(path string) (*types.Package, error) { return f(path) } diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 22b34d3..0e7e69c 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -46,7 +46,7 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, return cached, nil } logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) - args := []string{"list", "-json", "-e", "-compiled"} + args := []string{"list", "-json", "-e", "-compiled", "-export"} if req.NeedDeps { args = append(args, "-deps") } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 2010f56..1cb080d 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -15,13 +15,16 @@ package loader import ( + "archive/zip" "bytes" "context" + "fmt" "go/ast" "go/parser" "go/token" "go/types" "os" + "os/exec" "path/filepath" "sort" "strconv" @@ -805,6 +808,1522 @@ func TestLoadTypedPackageGraphCustomArtifactCacheReplacedModuleSourceChange(t *t compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) } +func TestDiscoveryCacheInvalidatesOnGoModResolutionChange(t *testing.T) { + root := t.TempDir() + depOneRoot := filepath.Join(root, "dep-one") + depTwoRoot := filepath.Join(root, "dep-two") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomCrossWorkspaceReplaceTargetIsolation(t *testing.T) { + cacheHome := t.TempDir() + artifactDir := t.TempDir() + repoOne := filepath.Join(t.TempDir(), "repo-one") + repoTwo := filepath.Join(t.TempDir(), "repo-two") + + depOneRoot := filepath.Join(repoOne, "depmod") + appOneRoot := filepath.Join(repoOne, "appmod") + depTwoRoot := filepath.Join(repoTwo, "depmod") + appTwoRoot := filepath.Join(repoTwo, "appmod") + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appOneRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appOneRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appTwoRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appTwoRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+cacheHome, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + warm := loadTypedPackageGraphForTest(t, appOneRoot, env, "example.com/app/app", ModeCustom) + if len(warm.Packages) != 1 || len(warm.Packages[0].Errors) != 0 { + t.Fatalf("repo one warm custom load returned errors: %+v", warm.Packages) + } + + custom := loadTypedPackageGraphForTest(t, appTwoRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appTwoRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomTransitiveShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "type T struct{}", + "", + "func New() *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func New() *b.T { return b.New() }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/a\"", + "", + "func Init() any { return a.New() }", + "", + }, "\n")) + + first := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "type T struct{}", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func New() *b.T {", + "\tvar logger b.Logger = b.NoopLogger{}", + "\treturn b.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomReplacePathSwitchInvalidatesCaches(t *testing.T) { + root := t.TempDir() + depOneRoot := filepath.Join(root, "dep-one") + depTwoRoot := filepath.Join(root, "dep-two") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.TrimSpace(\" two \") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomDiscoveryCacheReplacedSiblingOutsideWorkspace(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + rootLoad := loadRootGraphForTest(t, appRoot, env, []string{"./app"}, ModeCustom) + if rootLoad.Discovery == nil { + t.Fatal("expected discovery snapshot from custom root load") + } + + first := loadTypedPackageGraphWithDiscoveryForTest(t, appRoot, env, "example.com/app/app", ModeCustom, rootLoad.Discovery) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphWithDiscoveryForTest(t, appRoot, env, "example.com/app/app", ModeCustom, rootLoad.Discovery) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestDiscoveryCacheInvalidatesOnGeneratedFileSetChange(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"base\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "zz_generated.go"), strings.Join([]string{ + "package dep", + "", + "func Generated() string { return \"generated\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() + dep.Generated() }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomBodyOnlyEditWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string {", + "\treturn fmt.Sprint(\"before\")", + "}", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string {", + "\treturn fmt.Sprint(\"after\")", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/dep", true) +} + +func TestLoadTypedPackageGraphCustomReplaceNestedModuleParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(appRoot, "third_party", "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => ./third_party/depmod", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceChainParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + midRoot := filepath.Join(root, "midmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(midRoot, "go.mod"), "module example.com/mid\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(midRoot, "mid.go"), strings.Join([]string{ + "package mid", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require (", + "\texample.com/dep v0.0.0", + "\texample.com/mid v0.0.0", + ")", + "", + "replace example.com/dep => " + depRoot, + "replace example.com/mid => " + midRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/mid\"", + "", + "func Use() string { return mid.Use() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(midRoot, "mid.go"), strings.Join([]string{ + "package mid", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/mid", false) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomGoWorkWorkspaceParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(root, "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.work"), strings.Join([]string{ + "go 1.19", + "", + "use (", + "\t./appmod", + "\t./depmod", + ")", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return strings.TrimSpace(\" ok \") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomCrossWorkspaceModuleIsolation(t *testing.T) { + cacheHome := t.TempDir() + repoOne := filepath.Join(t.TempDir(), "repo-one") + repoTwo := filepath.Join(t.TempDir(), "repo-two") + + writeTestFile(t, filepath.Join(repoOne, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(repoOne, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func Message() string { return \"one\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(repoOne, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(repoTwo, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(repoTwo, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func Message() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(repoTwo, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+cacheHome) + warm := loadTypedPackageGraphForTest(t, repoOne, env, "example.com/app/app", ModeCustom) + if len(warm.Packages) != 1 || len(warm.Packages[0].Errors) != 0 { + t.Fatalf("repo one warm custom load returned errors: %+v", warm.Packages) + } + + custom := loadTypedPackageGraphForTest(t, repoTwo, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, repoTwo, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/dep", true) +} + +func TestDiscoveryCacheInvalidatesOnLocalImportChangeEndToEnd(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func Base() string { return \"base\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ + "package extra", + "", + "func Value() string { return \"extra\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Base() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "func Base() string { return extra.Value() }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomLocalShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type T struct{}", + "", + "func New() *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() *dep.T { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct{}", + "", + "type T struct{}", + "", + "func New(Config) *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() *dep.T { return dep.New(dep.Config{}) }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomTransitiveBodyOnlyWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"before\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func Message() string { return b.Message() }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/a\"", + "", + "func Init() string { return a.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"after\") }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/a", true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/b", true) +} + +func TestLoadTypedPackageGraphCustomKnownShapeToggleWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct { Name string }", + "", + "func New(Config) string { return \"config\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New(dep.Config{Name: \"a\"}) }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"logger\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomNewShapeWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct{}", + "", + "func New() string { return \"ok\" }", + "", + "func NewWithConfig(Config) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.NewWithConfig(dep.Config{}) }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetBodyOnlyWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"before\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"after\") }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomFixtureAppWarmMutationParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(root, "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "base", "base.go"), strings.Join([]string{ + "package base", + "", + "import \"fmt\"", + "", + "func Prefix() string { return fmt.Sprint(\"base:\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "gen", "zz_generated.go"), strings.Join([]string{ + "package gen", + "", + "func Value() string { return \"generated\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "feature", "feature.go"), strings.Join([]string{ + "package feature", + "", + "import (", + "\t\"example.com/app/base\"", + "\t\"example.com/app/gen\"", + "\t\"example.com/dep\"", + ")", + "", + "func Message() string {", + "\treturn base.Prefix() + dep.Message() + gen.Value()", + "}", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/feature\"", + "", + "func Init() string { return feature.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + coldCustom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(coldCustom.Packages) != 1 || len(coldCustom.Packages[0].Errors) != 0 { + t.Fatalf("cold custom load returned errors: %+v", coldCustom.Packages) + } + coldFallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, coldCustom.Packages, coldFallback.Packages, true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/app/feature", true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/app/gen", true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/dep", false) + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func Message(Logger) string { return fmt.Sprint(\"dep2\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "gen", "zz_generated.go"), strings.Join([]string{ + "package gen", + "", + "func Value() string { return \"generated2\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "feature", "feature.go"), strings.Join([]string{ + "package feature", + "", + "import (", + "\t\"example.com/app/base\"", + "\t\"example.com/app/gen\"", + "\t\"example.com/dep\"", + ")", + "", + "func Message() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn base.Prefix() + dep.Message(logger) + gen.Value()", + "}", + "", + }, "\n")) + + warmCustom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + warmFallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, warmCustom.Packages, warmFallback.Packages, true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/app/feature", true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/app/gen", true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomSequentialMutationsParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "helper", "helper.go"), strings.Join([]string{ + "package helper", + "", + "func Prefix() string { return \"prefix:\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + + assertParity := func() { + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) + } + + initial := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(initial.Packages) != 1 || len(initial.Packages[0].Errors) != 0 { + t.Fatalf("initial custom load returned errors: %+v", initial.Packages) + } + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep-body\") }", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import (", + "\t\"example.com/app/helper\"", + "\t\"example.com/dep\"", + ")", + "", + "func Init() string { return helper.Prefix() + dep.Message() }", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func Message(Logger) string { return fmt.Sprint(\"dep-shape\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import (", + "\t\"example.com/app/helper\"", + "\t\"example.com/dep\"", + ")", + "", + "func Init() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn helper.Prefix() + dep.Message(logger)", + "}", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + assertParity() +} + +func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/google/go-cmp v0.6.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "go.sum"), strings.Join([]string{ + "github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=", + "github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"github.com/google/go-cmp/cmp\"", + "", + "func Init() string { return cmp.Diff(\"a\", \"b\") }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOPROXY=off", + "GONOSUMDB=*", + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + ) + + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 { + t.Fatalf("first custom packages len = %d, want 1", len(first.Packages)) + } + if got := comparableErrors(first.Packages[0].Errors); len(got) != 0 { + t.Fatalf("first custom load returned errors: %v", got) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "go.sum"), "") + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T) { + root := t.TempDir() + proxyDir := t.TempDir() + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", + }) + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.1.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nimport \"strings\"\n\nfunc Version() string { return strings.TrimSpace(\"v1.1.0\") }\n", + }) + + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.0.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/extdep/pkg\"", + "", + "func Init() string { return pkg.Version() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOPROXY=file://"+proxyDir, + "GOSUMDB=off", + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + runGoModTidyForTest(t, root, env) + + first := loadPackagesForTest(t, root, env, []string{"./app"}, ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + firstDep := collectGraph(first.Packages)["example.com/extdep/pkg"] + if firstDep == nil { + t.Fatal("expected dependency package for example.com/extdep/pkg") + } + if !containsPathSubstring(firstDep.CompiledGoFiles, "example.com/extdep@v1.0.0") { + t.Fatalf("first dependency files = %v, want version v1.0.0", firstDep.CompiledGoFiles) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.1.0", + "", + }, "\n")) + runGoModTidyForTest(t, root, env) + + custom := loadPackagesForTest(t, root, env, []string{"./app"}, ModeCustom) + fallback := loadPackagesForTest(t, root, env, []string{"./app"}, ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/extdep/pkg", false) + + secondDep := collectGraph(custom.Packages)["example.com/extdep/pkg"] + if secondDep == nil { + t.Fatal("expected dependency package for example.com/extdep/pkg after version change") + } + if !containsPathSubstring(secondDep.CompiledGoFiles, "example.com/extdep@v1.1.0") { + t.Fatalf("second dependency files = %v, want version v1.1.0", secondDep.CompiledGoFiles) + } +} + func TestLoaderArtifactKeyExternalChangesWhenExportFileChanges(t *testing.T) { exportPath := filepath.Join(t.TempDir(), "dep.a") writeTestFile(t, exportPath, "first export payload") @@ -862,6 +2381,225 @@ func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testi } } +func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), "package app\n\nimport \"example.com/dep\"\n\nfunc Use() string { return dep.New() }\n") + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: append(os.Environ(), "GOCACHE=/tmp/gocache", "GOMODCACHE=/tmp/gomodcache"), + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil { + t.Fatal("expected metadata for example.com/dep") + } + if depMeta.Export == "" { + t.Fatalf("expected export data path for replaced module metadata: %+v", depMeta) + } +} + +func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil || depMeta.Export == "" { + t.Fatalf("expected export-backed metadata for example.com/dep: %+v", depMeta) + } + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "//go:build never", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + ) + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList(first) error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil || depMeta.Export != "" { + t.Fatalf("expected no export data for incomplete replaced module: %+v", depMeta) + } + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 { + t.Fatalf("first custom packages len = %d, want 1", len(first.Packages)) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "var _ missing", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + meta, err = runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList(second) error = %v", err) + } + depMeta = meta["example.com/dep"] + if depMeta == nil || depMeta.Export != "" { + t.Fatalf("expected no export data for second incomplete replaced module state: %+v", depMeta) + } + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { root := os.Getenv("WIRE_REAL_APP_ROOT") if root == "" { @@ -1142,6 +2880,44 @@ func compareRootPackagesOnly(t *testing.T, got []*packages.Package, want []*pack } } +func comparePackageByPath(t *testing.T, got []*packages.Package, want []*packages.Package, pkgPath string, requireTyped bool) { + t.Helper() + gotPkg := collectGraph(got)[pkgPath] + if gotPkg == nil { + t.Fatalf("missing package %q in custom graph", pkgPath) + } + wantPkg := collectGraph(want)[pkgPath] + if wantPkg == nil { + t.Fatalf("missing package %q in fallback graph", pkgPath) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", pkgPath, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", pkgPath, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", pkgPath, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", pkgPath, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", pkgPath, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", pkgPath, gotTyped, wantTyped) + } + } +} + func collectGraph(roots []*packages.Package) map[string]*packages.Package { out := make(map[string]*packages.Package) stack := append([]*packages.Package(nil), roots...) @@ -1159,6 +2935,68 @@ func collectGraph(roots []*packages.Package) map[string]*packages.Package { return out } +func loadTypedPackageGraphForTest(t *testing.T, wd string, env []string, pkg string, mode Mode) *LazyLoadResult { + return loadTypedPackageGraphWithDiscoveryForTest(t, wd, env, pkg, mode, nil) +} + +func loadPackagesForTest(t *testing.T, wd string, env []string, patterns []string, mode Mode) *PackageLoadResult { + t.Helper() + l := New() + got, err := l.LoadPackages(context.Background(), PackageLoadRequest{ + WD: wd, + Env: env, + Patterns: patterns, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadPackages(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func loadTypedPackageGraphWithDiscoveryForTest(t *testing.T, wd string, env []string, pkg string, mode Mode, discovery *DiscoverySnapshot) *LazyLoadResult { + t.Helper() + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: wd, + Env: env, + Package: pkg, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + Discovery: discovery, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func loadRootGraphForTest(t *testing.T, wd string, env []string, patterns []string, mode Mode) *RootLoadResult { + t.Helper() + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: wd, + Env: env, + Patterns: patterns, + NeedDeps: true, + Mode: mode, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(%q, %q) error = %v", wd, mode, err) + } + return got +} + func compiledFileCount(pkgs map[string]*packages.Package) int { total := 0 for _, pkg := range pkgs { @@ -1202,6 +3040,85 @@ func sortedImportPaths(m map[string]*packages.Package) []string { return out } +func containsPathSubstring(paths []string, needle string) bool { + for _, path := range paths { + if strings.Contains(normalizePathForCompare(path), needle) { + return true + } + } + return false +} + +func runGoModTidyForTest(t *testing.T, wd string, env []string) { + t.Helper() + cmd := exec.Command("go", "mod", "tidy") + cmd.Dir = wd + cmd.Env = env + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go mod tidy in %q error = %v: %s", wd, err, out) + } +} + +func writeModuleProxyVersion(t *testing.T, proxyDir string, modulePath string, version string, files map[string]string) { + t.Helper() + base := filepath.Join(proxyDir, filepath.FromSlash(modulePath), "@v") + if err := os.MkdirAll(base, 0o755); err != nil { + t.Fatalf("mkdir proxy dir: %v", err) + } + listPath := filepath.Join(base, "list") + appendLineIfMissing(t, listPath, version) + + modFile := "module " + modulePath + "\n\ngo 1.19\n" + writeTestFile(t, filepath.Join(base, version+".mod"), modFile) + writeTestFile(t, filepath.Join(base, version+".info"), fmt.Sprintf("{\"Version\":%q,\"Time\":\"2024-01-01T00:00:00Z\"}\n", version)) + + zipPath := filepath.Join(base, version+".zip") + zipFile, err := os.Create(zipPath) + if err != nil { + t.Fatalf("create proxy zip: %v", err) + } + defer zipFile.Close() + + zw := zip.NewWriter(zipFile) + moduleRoot := modulePath + "@" + version + writeZipFile := func(name string, contents string) { + w, err := zw.Create(moduleRoot + "/" + filepath.ToSlash(name)) + if err != nil { + t.Fatalf("create zip entry %q: %v", name, err) + } + if _, err := w.Write([]byte(contents)); err != nil { + t.Fatalf("write zip entry %q: %v", name, err) + } + } + writeZipFile("go.mod", modFile) + for name, contents := range files { + writeZipFile(name, contents) + } + if err := zw.Close(); err != nil { + t.Fatalf("close proxy zip: %v", err) + } +} + +func appendLineIfMissing(t *testing.T, path string, line string) { + t.Helper() + existing, err := os.ReadFile(path) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("read %q: %v", path, err) + } + for _, existingLine := range strings.Split(strings.TrimSpace(string(existing)), "\n") { + if existingLine == line { + return + } + } + content := string(existing) + if strings.TrimSpace(content) != "" && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += line + "\n" + writeTestFile(t, path, content) +} + type importerFuncForTest func(string) (*types.Package, error) func (f importerFuncForTest) Import(path string) (*types.Package, error) { From fea5e0ab08738a121576c549eb2e7916f5d35917 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:19:24 -0500 Subject: [PATCH 24/79] fix(loader): treat replaced workspace deps as local and harden runtests --- internal/loader/custom.go | 47 +++++++++++++++++++++++++++++- internal/loader/discovery_cache.go | 22 +++++++++++--- internal/loader/loader_test.go | 21 ++++++------- internal/runtests.sh | 8 ++++- 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 10f8c79..18c0f7d 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -54,9 +54,19 @@ type packageMeta struct { CompiledGoFiles []string Imports []string ImportMap map[string]string + Module *goListModule Error *goListError } +type goListModule struct { + Path string + Version string + Main bool + Dir string + GoMod string + Replace *goListModule +} + type goListError struct { Err string } @@ -882,7 +892,7 @@ func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *package if local, ok := l.isLocalCache[importPath]; ok { return local } - local := isWorkspacePackage(l.workspace, meta.Dir) + local := isLocalSourcePackage(l.workspace, meta) l.isLocalCache[importPath] = local return local } @@ -1262,6 +1272,41 @@ func isWorkspacePackage(workspaceRoot, dir string) bool { return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) } +func isLocalSourcePackage(workspaceRoot string, meta *packageMeta) bool { + if meta == nil { + return false + } + if isWorkspacePackage(workspaceRoot, meta.Dir) { + return true + } + mod := localSourceModule(meta.Module) + if mod == nil { + return false + } + if mod.Main { + return true + } + return canonicalLoaderPath(mod.Dir) == canonicalLoaderPath(meta.Dir) || isWorkspacePackage(canonicalLoaderPath(mod.Dir), meta.Dir) +} + +func localSourceModule(mod *goListModule) *goListModule { + if mod == nil { + return nil + } + if mod.Replace != nil { + if local := localSourceModule(mod.Replace); local != nil { + return local + } + } + if mod.Main && mod.Dir != "" { + return mod + } + if mod.Replace != nil && mod.Replace.Dir != "" { + return mod.Replace + } + return nil +} + func detectModuleRoot(start string) string { start = canonicalLoaderPath(start) for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 3b9fe46..e3db86b 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -71,7 +71,7 @@ func writeDiscoveryCache(req goListRequest, meta map[string]*packageMeta) { func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ - Version: 2, + Version: 3, WD: canonicalLoaderPath(req.WD), Tags: req.Tags, Patterns: append([]string(nil), req.Patterns...), @@ -92,7 +92,7 @@ func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) ( } locals := make([]discoveryLocalPackage, 0) for _, pkg := range meta { - if pkg == nil || !isWorkspacePackage(workspace, pkg.Dir) { + if pkg == nil || !isLocalSourcePackage(workspace, pkg) { continue } lp := discoveryLocalPackage{ @@ -116,7 +116,7 @@ func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) ( } func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { - if entry == nil || entry.Version != 2 { + if entry == nil || entry.Version != 3 { return false } for _, fm := range entry.Global { @@ -150,7 +150,7 @@ func discoveryCachePath(req goListRequest) (string, error) { NeedDeps bool Go string }{ - Version: 2, + Version: 3, WD: canonicalLoaderPath(req.WD), Tags: req.Tags, Patterns: append([]string(nil), req.Patterns...), @@ -321,6 +321,9 @@ func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { cp.ImportMap[mk] = mv } } + if v.Module != nil { + cp.Module = cloneGoListModule(v.Module) + } if v.Error != nil { errCopy := *v.Error cp.Error = &errCopy @@ -329,3 +332,14 @@ func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { } return out } + +func cloneGoListModule(in *goListModule) *goListModule { + if in == nil { + return nil + } + cp := *in + if in.Replace != nil { + cp.Replace = cloneGoListModule(in.Replace) + } + return &cp +} diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 1cb080d..065b4fb 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2198,37 +2198,38 @@ func TestLoadTypedPackageGraphCustomSequentialMutationsParity(t *testing.T) { func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { root := t.TempDir() + proxyDir := t.TempDir() homeDir := t.TempDir() + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", + }) + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ "module example.com/app", "", "go 1.19", "", - "require github.com/google/go-cmp v0.6.0", - "", - }, "\n")) - writeTestFile(t, filepath.Join(root, "go.sum"), strings.Join([]string{ - "github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=", - "github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=", + "require example.com/extdep v1.0.0", "", }, "\n")) writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ "package app", "", - "import \"github.com/google/go-cmp/cmp\"", + "import \"example.com/extdep/pkg\"", "", - "func Init() string { return cmp.Diff(\"a\", \"b\") }", + "func Init() string { return pkg.Version() }", "", }, "\n")) env := append(os.Environ(), "HOME="+homeDir, - "GOPROXY=off", - "GONOSUMDB=*", + "GOPROXY=file://"+proxyDir, + "GOSUMDB=off", "GOCACHE=/tmp/gocache", "GOMODCACHE=/tmp/gomodcache", ) + runGoModTidyForTest(t, root, env) first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) if len(first.Packages) != 1 { diff --git a/internal/runtests.sh b/internal/runtests.sh index 28877c1..7d2ddcb 100755 --- a/internal/runtests.sh +++ b/internal/runtests.sh @@ -16,6 +16,9 @@ # https://coderwall.com/p/fkfaqq/safer-bash-scripts-with-set-euxo-pipefail set -euo pipefail +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + if [[ $# -gt 0 ]]; then echo "usage: runtests.sh" 1>&2 exit 64 @@ -34,7 +37,10 @@ fi echo echo "Ensuring .go files are formatted with gofmt -s..." -mapfile -t go_files < <(find . -name '*.go' -type f | grep -v testdata) +go_files=() +while IFS= read -r file; do + go_files+=("$file") +done < <(find . -name '*.go' -type f | grep -v testdata) DIFF="$(gofmt -s -d "${go_files[@]}")" if [ -n "$DIFF" ]; then echo "FAIL: please run gofmt -s and commit the result" From 568d2e08e3de5035cc6257405399e210f9394a51 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:35:40 -0500 Subject: [PATCH 25/79] fix(loader): make cache-hardening tests and runtests portable --- internal/loader/loader_test.go | 51 ++++++++++++++++++++++++++++------ internal/runtests.sh | 9 ++++-- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 065b4fb..4e94afc 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2200,6 +2200,8 @@ func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { root := t.TempDir() proxyDir := t.TempDir() homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", @@ -2226,8 +2228,8 @@ func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { "HOME="+homeDir, "GOPROXY=file://"+proxyDir, "GOSUMDB=off", - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, ) runGoModTidyForTest(t, root, env) @@ -2252,6 +2254,8 @@ func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T proxyDir := t.TempDir() artifactDir := t.TempDir() homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", @@ -2281,8 +2285,8 @@ func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T "HOME="+homeDir, "GOPROXY=file://"+proxyDir, "GOSUMDB=off", - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, loaderArtifactEnv+"=1", loaderArtifactDirEnv+"="+artifactDir, ) @@ -2386,6 +2390,8 @@ func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { root := t.TempDir() depRoot := filepath.Join(root, "depmod") appRoot := filepath.Join(root, "appmod") + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") @@ -2404,7 +2410,7 @@ func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { meta, err := runGoList(context.Background(), goListRequest{ WD: appRoot, - Env: append(os.Environ(), "GOCACHE=/tmp/gocache", "GOMODCACHE=/tmp/gomodcache"), + Env: append(os.Environ(), "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir), Patterns: []string{"example.com/app/app"}, NeedDeps: true, }) @@ -2426,6 +2432,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *tes appRoot := filepath.Join(root, "appmod") artifactDir := t.TempDir() homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ @@ -2456,8 +2464,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *tes env := append(os.Environ(), "HOME="+homeDir, - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, loaderArtifactEnv+"=1", loaderArtifactDirEnv+"="+artifactDir, ) @@ -2517,6 +2525,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t * depRoot := filepath.Join(root, "depmod") appRoot := filepath.Join(root, "appmod") homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ @@ -2549,8 +2559,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t * env := append(os.Environ(), "HOME="+homeDir, - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, ) meta, err := runGoList(context.Background(), goListRequest{ @@ -3120,6 +3130,29 @@ func appendLineIfMissing(t *testing.T, path string, line string) { writeTestFile(t, path, content) } +func tempCacheDirForTest(t *testing.T, pattern string) string { + t.Helper() + dir, err := os.MkdirTemp("", pattern) + if err != nil { + t.Fatalf("MkdirTemp(%q) error = %v", pattern, err) + } + t.Cleanup(func() { + _ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + _ = os.Chmod(path, 0o755) + return nil + } + _ = os.Chmod(path, 0o644) + return nil + }) + _ = os.RemoveAll(dir) + }) + return dir +} + type importerFuncForTest func(string) (*types.Package, error) func (f importerFuncForTest) Import(path string) (*types.Package, error) { diff --git a/internal/runtests.sh b/internal/runtests.sh index 7d2ddcb..905e319 100755 --- a/internal/runtests.sh +++ b/internal/runtests.sh @@ -16,8 +16,13 @@ # https://coderwall.com/p/fkfaqq/safer-bash-scripts-with-set-euxo-pipefail set -euo pipefail -export GOCACHE="${GOCACHE:-/tmp/gocache}" -export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" +tmp_root="${TMPDIR:-${RUNNER_TEMP:-}}" +if [[ -z "${tmp_root}" ]]; then + tmp_root="$(mktemp -d)" +fi + +export GOCACHE="${GOCACHE:-${tmp_root}/gocache}" +export GOMODCACHE="${GOMODCACHE:-${tmp_root}/gomodcache}" if [[ $# -gt 0 ]]; then echo "usage: runtests.sh" 1>&2 From 114d1740875b182ff6453ed7232df47d1b76a693 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:40:02 -0500 Subject: [PATCH 26/79] fix(loader): use valid file GOPROXY URLs in proxy-based tests --- internal/loader/loader_test.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 4e94afc..d539347 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -23,6 +23,7 @@ import ( "go/parser" "go/token" "go/types" + "net/url" "os" "os/exec" "path/filepath" @@ -2226,7 +2227,7 @@ func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { env := append(os.Environ(), "HOME="+homeDir, - "GOPROXY=file://"+proxyDir, + "GOPROXY="+fileURLForTest(t, proxyDir), "GOSUMDB=off", "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir, @@ -2283,7 +2284,7 @@ func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T env := append(os.Environ(), "HOME="+homeDir, - "GOPROXY=file://"+proxyDir, + "GOPROXY="+fileURLForTest(t, proxyDir), "GOSUMDB=off", "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir, @@ -3153,6 +3154,15 @@ func tempCacheDirForTest(t *testing.T, pattern string) string { return dir } +func fileURLForTest(t *testing.T, path string) string { + t.Helper() + u := &url.URL{ + Scheme: "file", + Path: filepath.ToSlash(path), + } + return u.String() +} + type importerFuncForTest func(string) (*types.Package, error) func (f importerFuncForTest) Import(path string) (*types.Package, error) { From 53890d7a7037c92179c6a6542789174be0f74c3f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 17:00:43 -0500 Subject: [PATCH 27/79] fix(loader): format file GOPROXY URLs correctly on windows --- internal/loader/loader_test.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index d539347..a0d96bc 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -23,7 +23,6 @@ import ( "go/parser" "go/token" "go/types" - "net/url" "os" "os/exec" "path/filepath" @@ -3156,11 +3155,11 @@ func tempCacheDirForTest(t *testing.T, pattern string) string { func fileURLForTest(t *testing.T, path string) string { t.Helper() - u := &url.URL{ - Scheme: "file", - Path: filepath.ToSlash(path), + slashed := filepath.ToSlash(path) + if !strings.HasPrefix(slashed, "/") { + slashed = "/" + slashed } - return u.String() + return "file://" + slashed } type importerFuncForTest func(string) (*types.Package, error) From efa02144b77bb6e99e10d0f1d412b8b231e393f2 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 17:04:50 -0500 Subject: [PATCH 28/79] fix(loader): normalize test path comparisons across platforms --- internal/loader/loader_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index a0d96bc..05cfaa7 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -3182,9 +3182,9 @@ func normalizePathForCompare(path string) string { return "" } if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { - return filepath.Clean(resolved) + return filepath.ToSlash(filepath.Clean(resolved)) } - return filepath.Clean(path) + return filepath.ToSlash(filepath.Clean(path)) } func comparableErrors(errs []packages.Error) []string { From 541acdfe8fd2dc42d53f1988ed1f4738c0c1af4a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:31:37 -0500 Subject: [PATCH 29/79] refactor: remove unused loader and wire helpers --- internal/loader/discovery_cache.go | 8 -------- internal/wire/loader_validation.go | 6 +----- internal/wire/wire.go | 20 -------------------- 3 files changed, 1 insertion(+), 33 deletions(-) diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index e3db86b..9d7d932 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -60,14 +60,6 @@ func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { return clonePackageMetaMap(entry.Meta), true } -func writeDiscoveryCache(req goListRequest, meta map[string]*packageMeta) { - entry, err := buildDiscoveryCacheEntry(req, meta) - if err != nil { - return - } - _ = saveDiscoveryCacheEntry(req, entry) -} - func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ diff --git a/internal/wire/loader_validation.go b/internal/wire/loader_validation.go index 6868b7b..cde4d60 100644 --- a/internal/wire/loader_validation.go +++ b/internal/wire/loader_validation.go @@ -20,11 +20,7 @@ import ( "github.com/goforj/wire/internal/loader" ) -func loaderValidationMode(ctx context.Context, wd string, env []string) bool { - return effectiveLoaderMode(ctx, wd, env) != loader.ModeFallback -} - -func effectiveLoaderMode(ctx context.Context, wd string, env []string) loader.Mode { +func effectiveLoaderMode(_ context.Context, _ string, env []string) loader.Mode { mode := loader.ModeFromEnv(env) if mode != loader.ModeAuto { return mode diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 3d787f3..2459723 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -170,26 +170,6 @@ func detectOutputDir(paths []string) (string, error) { return dir, nil } -func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { - if len(generated) == 0 { - return nil - } - seen := make(map[string]struct{}, len(generated)) - out := make([]string, 0, len(generated)) - for _, gen := range generated { - if gen.PkgPath == "" { - continue - } - if _, ok := seen[gen.PkgPath]; ok { - continue - } - seen[gen.PkgPath] = struct{}{} - out = append(out, gen.PkgPath) - } - sort.Strings(out) - return out -} - // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) From 40ab1445c19c08ecb79cb6a61832413c7feaf807 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:31:40 -0500 Subject: [PATCH 30/79] refactor: dedupe command and custom loader helpers --- cmd/wire/gen_cmd.go | 43 +------ cmd/wire/generate_runner.go | 59 +++++++++ cmd/wire/logging.go | 133 ++++++++++++++++++++ cmd/wire/main.go | 125 ------------------- cmd/wire/watch_cmd.go | 48 +------- internal/loader/custom.go | 234 ++++++++++++++---------------------- 6 files changed, 287 insertions(+), 355 deletions(-) create mode 100644 cmd/wire/generate_runner.go create mode 100644 cmd/wire/logging.go diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index aceefee..246caa5 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -19,9 +19,7 @@ import ( "flag" "log" "os" - "time" - "github.com/goforj/wire/internal/wire" "github.com/google/subcommands" ) @@ -66,7 +64,6 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa return subcommands.ExitFailure } defer stop() - totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) wd, err := os.Getwd() @@ -83,46 +80,8 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa opts.PrefixOutputFile = cmd.prefixFileName opts.Tags = cmd.tags - genStart := time.Now() - outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) - logTiming(cmd.profile.timings, "wire.Generate", genStart) - if len(errs) > 0 { - logErrors(errs) - log.Println("generate failed") + if !runGenerateCommand(ctx, wd, os.Environ(), packages(f), opts, cmd.profile.timings) { return subcommands.ExitFailure } - if len(outs) == 0 { - logTiming(cmd.profile.timings, "total", totalStart) - return subcommands.ExitSuccess - } - success := true - writeStart := time.Now() - for _, out := range outs { - if len(out.Errs) > 0 { - logErrors(out.Errs) - log.Printf("%s: generate failed\n", out.PkgPath) - success = false - } - if len(out.Content) == 0 { - // No Wire output. Maybe errors, maybe no Wire directives. - continue - } - if wrote, err := out.CommitWithStatus(); err == nil { - if wrote { - logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } else { - logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } - } else { - log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) - success = false - } - } - if !success { - log.Println("at least one generate failure") - return subcommands.ExitFailure - } - logTiming(cmd.profile.timings, "writes", writeStart) - logTiming(cmd.profile.timings, "total", totalStart) return subcommands.ExitSuccess } diff --git a/cmd/wire/generate_runner.go b/cmd/wire/generate_runner.go new file mode 100644 index 0000000..4dc7b20 --- /dev/null +++ b/cmd/wire/generate_runner.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/goforj/wire/internal/wire" +) + +func runGenerateCommand(ctx context.Context, wd string, env []string, patterns []string, opts *wire.GenerateOptions, timings bool) bool { + totalStart := time.Now() + genStart := time.Now() + outs, errs := wire.Generate(ctx, wd, env, patterns, opts) + logTiming(timings, "wire.Generate", genStart) + if len(errs) > 0 { + logErrors(errs) + log.Println("generate failed") + return false + } + if len(outs) == 0 { + logTiming(timings, "total", totalStart) + return true + } + success := true + writeStart := time.Now() + for _, out := range outs { + if len(out.Errs) > 0 { + logErrors(out.Errs) + log.Printf("%s: generate failed\n", out.PkgPath) + success = false + } + if len(out.Content) == 0 { + continue + } + if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } + } else { + log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) + success = false + } + } + if !success { + log.Println("at least one generate failure") + return false + } + logTiming(timings, "writes", writeStart) + logTiming(timings, "total", totalStart) + return true +} + +func formatDuration(d time.Duration) string { + return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) +} diff --git a/cmd/wire/logging.go b/cmd/wire/logging.go new file mode 100644 index 0000000..7479f13 --- /dev/null +++ b/cmd/wire/logging.go @@ -0,0 +1,133 @@ +package main + +import ( + "fmt" + "io" + "os" + "strings" +) + +const ( + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " + maxLoggedErrorLines = 5 +) + +func logErrors(errs []error) { + for _, err := range errs { + msg := truncateLoggedError(formatLoggedError(err)) + if strings.Contains(msg, "\n") { + logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) + continue + } + logMultilineError(msg) + } +} + +func formatLoggedError(err error) string { + if err == nil { + return "" + } + msg := err.Error() + if strings.HasPrefix(msg, "inject ") { + return "solve failed\n" + msg + } + if idx := strings.Index(msg, ": inject "); idx >= 0 { + return "solve failed\n" + msg + } + return msg +} + +func truncateLoggedError(msg string) string { + if msg == "" { + return "" + } + lines := strings.Split(msg, "\n") + if len(lines) <= maxLoggedErrorLines { + return msg + } + omitted := len(lines) - maxLoggedErrorLines + lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... (%d additional lines omitted)", omitted)) + return strings.Join(lines, "\n") +} + +func logMultilineError(msg string) { + writeErrorLog(os.Stderr, msg) +} + +func logSuccessf(format string, args ...interface{}) { + writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) +} + +func shouldColorStderr() bool { + return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) +} + +func shouldColorOutput(isTTY bool, term string) bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { + return false + } + if forceColorEnabled() { + return true + } + if term == "" || term == "dumb" { + return false + } + return isTTY +} + +func forceColorEnabled() bool { + return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" +} + +func stderrIsTTY() bool { + info, err := os.Stderr.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func writeErrorLog(w io.Writer, msg string) { + line := errorSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, colorizeLines(line)) + return + } + _, _ = io.WriteString(w, line) +} + +func writeStatusLog(w io.Writer, msg string) { + line := successSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, ansiGreen+line+ansiReset) + return + } + _, _ = io.WriteString(w, line) +} + +func colorizeLines(s string) string { + if s == "" { + return "" + } + parts := strings.SplitAfter(s, "\n") + var b strings.Builder + for _, part := range parts { + if part == "" { + continue + } + b.WriteString(ansiRed) + b.WriteString(part) + b.WriteString(ansiReset) + } + return b.String() +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index c13b850..ada16d2 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -21,7 +21,6 @@ import ( "context" "flag" "fmt" - "io" "io/ioutil" "log" "os" @@ -35,15 +34,6 @@ import ( "github.com/google/subcommands" ) -const ( - ansiRed = "\033[1;31m" - ansiGreen = "\033[1;32m" - ansiReset = "\033[0m" - successSig = "✓ " - errorSig = "x " - maxLoggedErrorLines = 5 -) - // main wires up subcommands and executes the selected command. func main() { subcommands.Register(subcommands.CommandsCommand(), "") @@ -212,118 +202,3 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { } // logErrors logs each error with consistent formatting. -func logErrors(errs []error) { - for _, err := range errs { - msg := truncateLoggedError(formatLoggedError(err)) - if strings.Contains(msg, "\n") { - logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) - continue - } - logMultilineError(msg) - } -} - -func formatLoggedError(err error) string { - if err == nil { - return "" - } - msg := err.Error() - if strings.HasPrefix(msg, "inject ") { - return "solve failed\n" + msg - } - if idx := strings.Index(msg, ": inject "); idx >= 0 { - return "solve failed\n" + msg - } - return msg -} - -func truncateLoggedError(msg string) string { - if msg == "" { - return "" - } - lines := strings.Split(msg, "\n") - if len(lines) <= maxLoggedErrorLines { - return msg - } - omitted := len(lines) - maxLoggedErrorLines - lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... (%d additional lines omitted)", omitted)) - return strings.Join(lines, "\n") -} - -func logMultilineError(msg string) { - writeErrorLog(os.Stderr, msg) -} - -func logSuccessf(format string, args ...interface{}) { - writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) -} - -func shouldColorStderr() bool { - return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) -} - -func shouldColorOutput(isTTY bool, term string) bool { - if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { - return false - } - if forceColorEnabled() { - return true - } - if term == "" || term == "dumb" { - return false - } - return isTTY -} - -func forceColorEnabled() bool { - return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" -} - -func stderrIsTTY() bool { - info, err := os.Stderr.Stat() - if err != nil { - return false - } - return (info.Mode() & os.ModeCharDevice) != 0 -} - -func writeErrorLog(w io.Writer, msg string) { - line := errorSig + "wire: " + msg - if !strings.HasSuffix(line, "\n") { - line += "\n" - } - if shouldColorStderr() { - _, _ = io.WriteString(w, colorizeLines(line)) - return - } - _, _ = io.WriteString(w, line) -} - -func writeStatusLog(w io.Writer, msg string) { - line := successSig + "wire: " + msg - if !strings.HasSuffix(line, "\n") { - line += "\n" - } - if shouldColorStderr() { - _, _ = io.WriteString(w, ansiGreen+line+ansiReset) - return - } - _, _ = io.WriteString(w, line) -} - -func colorizeLines(s string) string { - if s == "" { - return "" - } - parts := strings.SplitAfter(s, "\n") - var b strings.Builder - for _, part := range parts { - if part == "" { - continue - } - b.WriteString(ansiRed) - b.WriteString(part) - b.WriteString(ansiReset) - } - return b.String() -} diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index ebdfa0e..45b6bc4 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -27,7 +27,6 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/goforj/wire/internal/wire" "github.com/google/subcommands" ) @@ -102,47 +101,7 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter env := os.Environ() runGenerate := func() { - totalStart := time.Now() - genStart := time.Now() - outs, errs := wire.Generate(ctx, wd, env, packages(f), opts) - logTiming(cmd.profile.timings, "wire.Generate", genStart) - if len(errs) > 0 { - logErrors(errs) - log.Println("generate failed") - return - } - if len(outs) == 0 { - logTiming(cmd.profile.timings, "total", totalStart) - return - } - success := true - writeStart := time.Now() - for _, out := range outs { - if len(out.Errs) > 0 { - logErrors(out.Errs) - log.Printf("%s: generate failed\n", out.PkgPath) - success = false - } - if len(out.Content) == 0 { - continue - } - if wrote, err := out.CommitWithStatus(); err == nil { - if wrote { - logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } else { - logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } - } else { - log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) - success = false - } - } - if !success { - log.Println("at least one generate failure") - return - } - logTiming(cmd.profile.timings, "writes", writeStart) - logTiming(cmd.profile.timings, "total", totalStart) + _ = runGenerateCommand(ctx, wd, env, packages(f), opts, cmd.profile.timings) } root, err := moduleRoot(wd, env) @@ -332,11 +291,6 @@ func moduleRoot(wd string, env []string) (string, error) { return filepath.Dir(path), nil } -// formatDuration renders a short millisecond duration for log output. -func formatDuration(d time.Duration) string { - return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) -} - // watchWithFSNotify runs the watcher using native filesystem notifications. func watchWithFSNotify(root string, onChange func()) error { watcher, err := fsnotify.NewWatcher() diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 18c0f7d..afdafc0 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -188,22 +188,8 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes } pkgs := make(map[string]*packages.Package, len(meta)) for path, m := range meta { - pkgs[path] = &packages.Package{ - ID: m.ImportPath, - Name: m.Name, - PkgPath: m.ImportPath, - GoFiles: append([]string(nil), metaFiles(m)...), - CompiledGoFiles: append([]string(nil), metaFiles(m)...), - ExportFile: m.Export, - Imports: make(map[string]*packages.Package), - } - if m.Error != nil && strings.TrimSpace(m.Error.Err) != "" { - pkgs[path].Errors = append(pkgs[path].Errors, packages.Error{ - Pos: "-", - Msg: m.Error.Err, - Kind: packages.ListError, - }) - } + pkgs[path] = packageStub(nil, m) + appendPackageMetaError(pkgs[path], m) } for path, m := range meta { pkg := pkgs[path] @@ -217,12 +203,10 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes } } } - roots := make([]*packages.Package, 0, len(req.Patterns)) - for _, m := range meta { - if m.DepOnly { - continue - } - if pkg := pkgs[m.ImportPath]; pkg != nil { + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { roots = append(roots, pkg) } } @@ -272,23 +256,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz if fset == nil { fset = token.NewFileSet() } - l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: map[string]struct{}{req.Package: {}}, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, - } + l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, map[string]struct{}{req.Package: {}}, req.ParseFile, discoveryDuration) prefetchStart := time.Now() l.prefetchArtifacts() l.stats.artifactPrefetch = time.Since(prefetchStart) @@ -298,26 +266,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz return nil, err } l.stats.rootLoad = time.Since(rootLoadStart) - logDuration(ctx, "loader.custom.lazy.read_files.cumulative", l.stats.read) - logDuration(ctx, "loader.custom.lazy.parse_files.cumulative", l.stats.parse) - logDuration(ctx, "loader.custom.lazy.typecheck.cumulative", l.stats.typecheck) - logDuration(ctx, "loader.custom.lazy.read_files.local.cumulative", l.stats.localRead) - logDuration(ctx, "loader.custom.lazy.read_files.external.cumulative", l.stats.externalRead) - logDuration(ctx, "loader.custom.lazy.parse_files.local.cumulative", l.stats.localParse) - logDuration(ctx, "loader.custom.lazy.parse_files.external.cumulative", l.stats.externalParse) - logDuration(ctx, "loader.custom.lazy.typecheck.local.cumulative", l.stats.localTypecheck) - logDuration(ctx, "loader.custom.lazy.typecheck.external.cumulative", l.stats.externalTypecheck) - logDuration(ctx, "loader.custom.lazy.artifact_read", l.stats.artifactRead) - logDuration(ctx, "loader.custom.lazy.artifact_path", l.stats.artifactPath) - logDuration(ctx, "loader.custom.lazy.artifact_decode", l.stats.artifactDecode) - logDuration(ctx, "loader.custom.lazy.artifact_import_link", l.stats.artifactImportLink) - logDuration(ctx, "loader.custom.lazy.artifact_write", l.stats.artifactWrite) - logDuration(ctx, "loader.custom.lazy.artifact_prefetch.wall", l.stats.artifactPrefetch) - logDuration(ctx, "loader.custom.lazy.root_load.wall", l.stats.rootLoad) - logDuration(ctx, "loader.custom.lazy.discovery.wall", l.stats.discovery) - logInt(ctx, "loader.custom.lazy.artifact_hits", l.stats.artifactHits) - logInt(ctx, "loader.custom.lazy.artifact_misses", l.stats.artifactMisses) - logInt(ctx, "loader.custom.lazy.artifact_writes", l.stats.artifactWrites) + logTypedLoadStats(ctx, "lazy", l.stats) return &LazyLoadResult{ Packages: []*packages.Package{root}, Backend: ModeCustom, @@ -345,42 +294,21 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo fset = token.NewFileSet() } targets := make(map[string]struct{}) - for _, m := range meta { - if m.DepOnly { - continue - } - targets[m.ImportPath] = struct{}{} + for _, path := range nonDepRootImportPaths(meta) { + targets[path] = struct{}{} } if len(targets) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } - l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: targets, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, - } + l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) prefetchStart := time.Now() l.prefetchArtifacts() l.stats.artifactPrefetch = time.Since(prefetchStart) rootLoadStart := time.Now() - roots := make([]*packages.Package, 0, len(targets)) - for _, m := range meta { - if m.DepOnly { - continue - } - root, err := l.loadPackage(m.ImportPath) + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + root, err := l.loadPackage(path) if err != nil { return nil, err } @@ -388,26 +316,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo } l.stats.rootLoad = time.Since(rootLoadStart) sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) - logDuration(ctx, "loader.custom.typed.read_files.cumulative", l.stats.read) - logDuration(ctx, "loader.custom.typed.parse_files.cumulative", l.stats.parse) - logDuration(ctx, "loader.custom.typed.typecheck.cumulative", l.stats.typecheck) - logDuration(ctx, "loader.custom.typed.read_files.local.cumulative", l.stats.localRead) - logDuration(ctx, "loader.custom.typed.read_files.external.cumulative", l.stats.externalRead) - logDuration(ctx, "loader.custom.typed.parse_files.local.cumulative", l.stats.localParse) - logDuration(ctx, "loader.custom.typed.parse_files.external.cumulative", l.stats.externalParse) - logDuration(ctx, "loader.custom.typed.typecheck.local.cumulative", l.stats.localTypecheck) - logDuration(ctx, "loader.custom.typed.typecheck.external.cumulative", l.stats.externalTypecheck) - logDuration(ctx, "loader.custom.typed.artifact_read", l.stats.artifactRead) - logDuration(ctx, "loader.custom.typed.artifact_path", l.stats.artifactPath) - logDuration(ctx, "loader.custom.typed.artifact_decode", l.stats.artifactDecode) - logDuration(ctx, "loader.custom.typed.artifact_import_link", l.stats.artifactImportLink) - logDuration(ctx, "loader.custom.typed.artifact_write", l.stats.artifactWrite) - logDuration(ctx, "loader.custom.typed.artifact_prefetch.wall", l.stats.artifactPrefetch) - logDuration(ctx, "loader.custom.typed.root_load.wall", l.stats.rootLoad) - logDuration(ctx, "loader.custom.typed.discovery.wall", l.stats.discovery) - logInt(ctx, "loader.custom.typed.artifact_hits", l.stats.artifactHits) - logInt(ctx, "loader.custom.typed.artifact_misses", l.stats.artifactMisses) - logInt(ctx, "loader.custom.typed.artifact_writes", l.stats.artifactWrites) + logTypedLoadStats(ctx, "typed", l.stats) return &PackageLoadResult{ Packages: roots, Backend: ModeCustom, @@ -424,22 +333,8 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error } v.loading[path] = true defer delete(v.loading, path) - pkg := &packages.Package{ - ID: meta.ImportPath, - Name: meta.Name, - PkgPath: meta.ImportPath, - Fset: v.fset, - GoFiles: append([]string(nil), metaFiles(meta)...), - CompiledGoFiles: append([]string(nil), metaFiles(meta)...), - Imports: make(map[string]*packages.Package), - ExportFile: meta.Export, - } - if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { - pkg.Errors = append(pkg.Errors, packages.Error{ - Pos: "-", - Msg: meta.Error.Err, - Kind: packages.ListError, - }) + pkg := packageStub(v.fset, meta) + if appendPackageMetaError(pkg, meta) { return pkg, nil } files, errs := v.parseFiles(metaFiles(meta)) @@ -563,16 +458,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } if pkg == nil { - pkg = &packages.Package{ - ID: meta.ImportPath, - Name: meta.Name, - PkgPath: meta.ImportPath, - Fset: l.fset, - GoFiles: append([]string(nil), metaFiles(meta)...), - CompiledGoFiles: append([]string(nil), metaFiles(meta)...), - Imports: make(map[string]*packages.Package), - ExportFile: meta.Export, - } + pkg = packageStub(l.fset, meta) l.packages[path] = pkg } useArtifact := l.shouldUseArtifact(path, meta, isTarget, isLocal) @@ -600,13 +486,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) pkg.Errors = append(pkg.Errors, parseErrs...) if len(files) == 0 { - if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { - pkg.Errors = append(pkg.Errors, packages.Error{ - Pos: "-", - Msg: meta.Error.Err, - Kind: packages.ListError, - }) - } + appendPackageMetaError(pkg, meta) return pkg, nil } @@ -1358,7 +1238,27 @@ func envValue(env []string, key string) string { return "" } -func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { +func newCustomTypedGraphLoader(ctx context.Context, wd string, env []string, fset *token.FileSet, meta map[string]*packageMeta, targets map[string]struct{}, parseFile ParseFileFunc, discoveryDuration time.Duration) *customTypedGraphLoader { + return &customTypedGraphLoader{ + workspace: detectModuleRoot(wd), + ctx: ctx, + env: append([]string(nil), env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: parseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } +} + +func packageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { if meta == nil { return nil } @@ -1374,6 +1274,58 @@ func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Packag } } +func appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { + if pkg == nil || meta == nil || meta.Error == nil || strings.TrimSpace(meta.Error.Err) == "" { + return false + } + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + return true +} + +func nonDepRootImportPaths(meta map[string]*packageMeta) []string { + paths := make([]string, 0, len(meta)) + for _, m := range meta { + if m == nil || m.DepOnly { + continue + } + paths = append(paths, m.ImportPath) + } + sort.Strings(paths) + return paths +} + +func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { + prefix := "loader.custom." + mode + logDuration(ctx, prefix+".read_files.cumulative", stats.read) + logDuration(ctx, prefix+".parse_files.cumulative", stats.parse) + logDuration(ctx, prefix+".typecheck.cumulative", stats.typecheck) + logDuration(ctx, prefix+".read_files.local.cumulative", stats.localRead) + logDuration(ctx, prefix+".read_files.external.cumulative", stats.externalRead) + logDuration(ctx, prefix+".parse_files.local.cumulative", stats.localParse) + logDuration(ctx, prefix+".parse_files.external.cumulative", stats.externalParse) + logDuration(ctx, prefix+".typecheck.local.cumulative", stats.localTypecheck) + logDuration(ctx, prefix+".typecheck.external.cumulative", stats.externalTypecheck) + logDuration(ctx, prefix+".artifact_read", stats.artifactRead) + logDuration(ctx, prefix+".artifact_path", stats.artifactPath) + logDuration(ctx, prefix+".artifact_decode", stats.artifactDecode) + logDuration(ctx, prefix+".artifact_import_link", stats.artifactImportLink) + logDuration(ctx, prefix+".artifact_write", stats.artifactWrite) + logDuration(ctx, prefix+".artifact_prefetch.wall", stats.artifactPrefetch) + logDuration(ctx, prefix+".root_load.wall", stats.rootLoad) + logDuration(ctx, prefix+".discovery.wall", stats.discovery) + logInt(ctx, prefix+".artifact_hits", stats.artifactHits) + logInt(ctx, prefix+".artifact_misses", stats.artifactMisses) + logInt(ctx, prefix+".artifact_writes", stats.artifactWrites) +} + +func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { + return packageStub(fset, meta) +} + func metadataMatchesFingerprint(pkgPath string, meta map[string]*packageMeta, local []LocalPackageFingerprint) bool { for _, fp := range local { if fp.PkgPath != pkgPath { From 0921843ca888dff91d2ac9774962e7751a111cdf Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:37:26 -0500 Subject: [PATCH 31/79] refactor: make loader artifact policy explicit --- internal/loader/custom.go | 45 ++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index afdafc0..87cab98 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -134,6 +134,11 @@ type typedLoadStats struct { artifactWrites int } +type artifactPolicy struct { + read bool + write bool +} + func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { if len(req.Touched) == 0 { return &TouchedValidationResult{Backend: ModeCustom}, nil @@ -461,8 +466,8 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg = packageStub(l.fset, meta) l.packages[path] = pkg } - useArtifact := l.shouldUseArtifact(path, meta, isTarget, isLocal) - if useArtifact { + artifactPolicy := l.artifactPolicy(meta, isTarget, isLocal) + if artifactPolicy.read { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() for _, imp := range meta.Imports { @@ -559,13 +564,13 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg.Types = tpkg pkg.TypesInfo = info pkg.Errors = append(pkg.Errors, typeErrors...) - if shouldWriteArtifact(l.env, isTarget) && len(pkg.Errors) == 0 { + if artifactPolicy.write && len(pkg.Errors) == 0 { _ = l.writeArtifact(meta, tpkg, isLocal) } return pkg, nil } -func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) bool { +func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMeta) bool { if meta == nil { return false } @@ -581,6 +586,19 @@ func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) boo return art.Supported } +func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { + if !loaderArtifactEnabled(l.env) || isTarget { + return artifactPolicy{} + } + policy := artifactPolicy{write: true} + if !isLocal { + policy.read = true + return policy + } + policy.read = l.localSemanticArtifactSupported(meta) + return policy +} + func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { defer func() { if r := recover(); r != nil { @@ -656,16 +674,6 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is return tpkg, true } -func (l *customTypedGraphLoader) shouldUseArtifact(path string, meta *packageMeta, isTarget, isLocal bool) bool { - if !loaderArtifactEnabled(l.env) || isTarget { - return false - } - if !isLocal { - return true - } - return l.useLocalSemanticArtifact(meta) -} - func (l *customTypedGraphLoader) prefetchArtifacts() { if !loaderArtifactEnabled(l.env) { return @@ -674,7 +682,7 @@ func (l *customTypedGraphLoader) prefetchArtifacts() { for path, meta := range l.meta { _, isTarget := l.targets[path] isLocal := l.isLocalPackage(path, meta) - if l.shouldUseArtifact(path, meta, isTarget, isLocal) { + if l.artifactPolicy(meta, isTarget, isLocal).read { candidates = append(candidates, path) } } @@ -761,13 +769,6 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac return nil } -func shouldWriteArtifact(env []string, isTarget bool) bool { - if !loaderArtifactEnabled(env) || isTarget { - return false - } - return true -} - func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *packageMeta) bool { if local, ok := l.isLocalCache[importPath]; ok { return local From bb7af77028392f6f756b4b8326758367e4a61712 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:38:07 -0500 Subject: [PATCH 32/79] refactor: dedupe custom loader import linking --- internal/loader/custom.go | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 87cab98..827b003 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -470,16 +470,8 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er if artifactPolicy.read { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() - for _, imp := range meta.Imports { - target := imp - if mapped := meta.ImportMap[imp]; mapped != "" { - target = mapped - } - dep, err := l.loadPackage(target) - if err != nil { - return nil, err - } - pkg.Imports[imp] = dep + if err := l.linkPackageImports(pkg, meta); err != nil { + return nil, err } l.stats.artifactImportLink += time.Since(linkStart) pkg.Types = typed @@ -520,10 +512,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er if importPath == "unsafe" { return types.Unsafe, nil } - target := importPath - if mapped := meta.ImportMap[importPath]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, importPath) dep, err := l.loadPackage(target) if err != nil { return nil, err @@ -599,6 +588,17 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL return policy } +func (l *customTypedGraphLoader) linkPackageImports(pkg *packages.Package, meta *packageMeta) error { + for _, imp := range meta.Imports { + dep, err := l.loadPackage(resolvedImportTarget(meta, imp)) + if err != nil { + return err + } + pkg.Imports[imp] = dep + } + return nil +} + func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { defer func() { if r := recover(); r != nil { @@ -1287,6 +1287,16 @@ func appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { return true } +func resolvedImportTarget(meta *packageMeta, importPath string) string { + if meta == nil { + return importPath + } + if mapped := meta.ImportMap[importPath]; mapped != "" { + return mapped + } + return importPath +} + func nonDepRootImportPaths(meta map[string]*packageMeta) []string { paths := make([]string, 0, len(meta)) for _, m := range meta { From df22b6a67bd167af9fd282bfdff5d1d79086934c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:38:47 -0500 Subject: [PATCH 33/79] refactor: share import target resolution in custom loader --- internal/loader/custom.go | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 827b003..a498499 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -199,10 +199,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes for path, m := range meta { pkg := pkgs[path] for _, imp := range m.Imports { - target := imp - if mapped := m.ImportMap[imp]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(m, imp) if dep := pkgs[target]; dep != nil { pkg.Imports[imp] = dep } @@ -362,10 +359,7 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error if importPath == "unsafe" { return types.Unsafe, nil } - target := importPath - if mapped := meta.ImportMap[importPath]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, importPath) if _, ok := v.touched[target]; ok { if typed := v.packages[target]; typed != nil && typed.Complete() { if depMeta := v.meta[target]; depMeta != nil { @@ -859,10 +853,7 @@ func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, if importPath == "unsafe" { return types.Unsafe, nil } - target := importPath - if mapped := meta.ImportMap[importPath]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, importPath) if _, ok := v.touched[target]; ok { checked, err := v.validatePackage(target) if err != nil { @@ -1023,10 +1014,7 @@ func (v *customValidator) validateDeclaredImports(meta *packageMeta, files []*as if path == "" { continue } - target := path - if mapped := meta.ImportMap[path]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, path) name := importName(spec) if name != "_" && name != "." { if _, ok := used[name]; !ok { From a857e0b2e77754431d7175e7e0ca62e50230da7f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:39:36 -0500 Subject: [PATCH 34/79] refactor: centralize types info setup in custom loader --- internal/loader/custom.go | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index a498499..67803fd 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -347,14 +347,7 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error tpkg := types.NewPackage(meta.ImportPath, meta.Name) v.packages[meta.ImportPath] = tpkg - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Scopes: make(map[ast.Node]*types.Scope), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } + info := newTypesInfo() importer := importerFunc(func(importPath string) (*types.Package, error) { if importPath == "unsafe" { return types.Unsafe, nil @@ -489,14 +482,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er needFullState := isTarget || isLocal var info *types.Info if needFullState { - info = &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Scopes: make(map[ast.Node]*types.Scope), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } + info = newTypesInfo() } var typeErrors []packages.Error cfg := &types.Config{ @@ -840,14 +826,7 @@ func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, if len(errs) > 0 { return nil, unsupportedError{reason: "dependency parse error"} } - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Scopes: make(map[ast.Node]*types.Scope), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } + info := newTypesInfo() cfg := &types.Config{ Importer: importerFunc(func(importPath string) (*types.Package, error) { if importPath == "unsafe" { @@ -1285,6 +1264,17 @@ func resolvedImportTarget(meta *packageMeta, importPath string) string { return importPath } +func newTypesInfo() *types.Info { + return &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } +} + func nonDepRootImportPaths(meta map[string]*packageMeta) []string { paths := make([]string, 0, len(meta)) for _, m := range meta { From 11b5498234fe7fe9afadeea8a9026460219ab73f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:40:25 -0500 Subject: [PATCH 35/79] refactor: share parse error conversion in custom loader --- internal/loader/custom.go | 42 +++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 67803fd..4cdc1e8 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -865,18 +865,7 @@ func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Er } f, err := parser.ParseFile(v.fset, name, src, parser.AllErrors|parser.ParseComments) if err != nil { - switch typed := err.(type) { - case scanner.ErrorList: - for _, parseErr := range typed { - errs = append(errs, packages.Error{ - Pos: parseErr.Pos.String(), - Msg: parseErr.Msg, - Kind: packages.ParseError, - }) - } - default: - errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) - } + errs = appendParseErrors(errs, name, err) } if f != nil { files = append(files, f) @@ -920,18 +909,7 @@ func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*as l.stats.externalParse += parseDuration } if err != nil { - switch typed := err.(type) { - case scanner.ErrorList: - for _, parseErr := range typed { - errs = append(errs, packages.Error{ - Pos: parseErr.Pos.String(), - Msg: parseErr.Msg, - Kind: packages.ParseError, - }) - } - default: - errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) - } + errs = appendParseErrors(errs, name, err) } if f != nil { files = append(files, f) @@ -1275,6 +1253,22 @@ func newTypesInfo() *types.Info { } } +func appendParseErrors(errs []packages.Error, name string, err error) []packages.Error { + switch typed := err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + return errs +} + func nonDepRootImportPaths(meta map[string]*packageMeta) []string { paths := make([]string, 0, len(meta)) for _, m := range meta { From 7bd7c65d6246a514a5fc6f7692fcf79c88c5e489 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:41:07 -0500 Subject: [PATCH 36/79] refactor: share source parsing in custom loader --- internal/loader/custom.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 4cdc1e8..3529350 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -863,7 +863,7 @@ func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Er errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) continue } - f, err := parser.ParseFile(v.fset, name, src, parser.AllErrors|parser.ParseComments) + f, err := parseGoSourceFile(v.fset, nil, name, src) if err != nil { errs = appendParseErrors(errs, name, err) } @@ -896,11 +896,7 @@ func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*as } var f *ast.File parseStart := time.Now() - if l.parseFile != nil { - f, err = l.parseFile(l.fset, name, src) - } else { - f, err = parser.ParseFile(l.fset, name, src, parser.AllErrors|parser.ParseComments) - } + f, err = parseGoSourceFile(l.fset, l.parseFile, name, src) parseDuration := time.Since(parseStart) l.stats.parse += parseDuration if isLocal { @@ -1253,6 +1249,13 @@ func newTypesInfo() *types.Info { } } +func parseGoSourceFile(fset *token.FileSet, parseFile ParseFileFunc, name string, src []byte) (*ast.File, error) { + if parseFile != nil { + return parseFile(fset, name, src) + } + return parser.ParseFile(fset, name, src, parser.AllErrors|parser.ParseComments) +} + func appendParseErrors(errs []packages.Error, name string, err error) []packages.Error { switch typed := err.(type) { case scanner.ErrorList: From 6e3fed58e84c6bce7b18e863e355ff34d26c59ea Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:47:19 -0500 Subject: [PATCH 37/79] refactor: centralize semantic artifact cache inputs --- internal/wire/parse.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 08a3c8a..34167fc 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -822,10 +822,11 @@ func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.Pa if art, ok := oc.semantic[pkg.PkgPath]; ok { return art } - if len(oc.env) == 0 || len(pkg.GoFiles) == 0 { + importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) + if !ok { return nil } - art, err := semanticcache.Read(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles) + art, err := semanticcache.Read(oc.env, importPath, packageName, files) if err != nil { return nil } @@ -838,7 +839,8 @@ func (oc *objectCache) recordSemanticArtifacts() { return } for _, pkg := range oc.packages { - if pkg == nil || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil || len(pkg.GoFiles) == 0 { + importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) + if !ok || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil { continue } art := buildSemanticArtifact(pkg) @@ -846,8 +848,15 @@ func (oc *objectCache) recordSemanticArtifacts() { continue } oc.semantic[pkg.PkgPath] = art - _ = semanticcache.Write(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles, art) + _ = semanticcache.Write(oc.env, importPath, packageName, files, art) + } +} + +func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, packageName string, files []string, ok bool) { + if len(env) == 0 || pkg == nil || len(pkg.GoFiles) == 0 { + return "", "", nil, false } + return pkg.PkgPath, pkg.Name, pkg.GoFiles, true } func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { From c2abf4ef0d2caadf6454fb47cc9d62932bd46901 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:48:16 -0500 Subject: [PATCH 38/79] refactor: isolate semantic artifact cache io --- internal/wire/parse.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 34167fc..dee48b0 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -826,7 +826,7 @@ func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.Pa if !ok { return nil } - art, err := semanticcache.Read(oc.env, importPath, packageName, files) + art, err := readSemanticArtifact(oc.env, importPath, packageName, files) if err != nil { return nil } @@ -848,7 +848,7 @@ func (oc *objectCache) recordSemanticArtifacts() { continue } oc.semantic[pkg.PkgPath] = art - _ = semanticcache.Write(oc.env, importPath, packageName, files, art) + _ = writeSemanticArtifact(oc.env, importPath, packageName, files, art) } } @@ -859,6 +859,14 @@ func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, pa return pkg.PkgPath, pkg.Name, pkg.GoFiles, true } +func readSemanticArtifact(env []string, importPath, packageName string, files []string) (*semanticcache.PackageArtifact, error) { + return semanticcache.Read(env, importPath, packageName, files) +} + +func writeSemanticArtifact(env []string, importPath, packageName string, files []string, art *semanticcache.PackageArtifact) error { + return semanticcache.Write(env, importPath, packageName, files, art) +} + func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { if pkg == nil || pkg.Types == nil || pkg.TypesInfo == nil { return nil From 40ea02b0303b1b9c3aeb7be849209a1ec08427e3 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:49:05 -0500 Subject: [PATCH 39/79] refactor: isolate semantic provider set artifact lookup --- internal/wire/parse.go | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index dee48b0..80411ef 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -562,23 +562,10 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { - pkg := oc.packages[obj.Pkg().Path()] - if pkg == nil { - return nil, false, nil - } - art := oc.semanticArtifact(pkg) - if art == nil || !art.Supported { - return nil, false, nil - } - setArt, ok := art.Vars[obj.Name()] + setArt, ok := oc.semanticProviderSetArtifact(obj) if !ok { return nil, false, nil } - for _, item := range setArt.Items { - if item.Kind == "bind" { - return nil, false, nil - } - } pset := &ProviderSet{ Pos: obj.Pos(), PkgPath: obj.Pkg().Path(), @@ -640,6 +627,27 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, return pset, true, nil } +func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcache.ProviderSetArtifact, bool) { + pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return semanticcache.ProviderSetArtifact{}, false + } + art := oc.semanticArtifact(pkg) + if art == nil || !art.Supported { + return semanticcache.ProviderSetArtifact{}, false + } + setArt, ok := art.Vars[obj.Name()] + if !ok { + return semanticcache.ProviderSetArtifact{}, false + } + for _, item := range setArt.Items { + if item.Kind == "bind" { + return semanticcache.ProviderSetArtifact{}, false + } + } + return setArt, true +} + func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { pkg := oc.packages[importPath] if pkg == nil || pkg.Types == nil { From 5f295650675a7ef27313893b599fa091b258de59 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:51:21 -0500 Subject: [PATCH 40/79] refactor: extract semantic provider set item application --- internal/wire/parse.go | 82 ++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 80411ef..b988a2a 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -573,44 +573,8 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, } ec := new(errorCollector) for _, item := range setArt.Items { - switch item.Kind { - case "func": - providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Providers = append(pset.Providers, providerObj) - case "set": - setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Imports = append(pset.Imports, setObj) - case "bind": - binding, errs := oc.semanticBinding(item) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Bindings = append(pset.Bindings, binding) - case "struct": - providerObj, errs := oc.semanticStructProvider(item) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Providers = append(pset.Providers, providerObj) - case "fields": - fields, errs := oc.semanticFields(item) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Fields = append(pset.Fields, fields...) - default: - ec.add(fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)) + if errs := oc.applySemanticProviderSetItem(pset, item); len(errs) > 0 { + ec.add(errs...) } } if len(ec.errors) > 0 { @@ -648,6 +612,48 @@ func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcach return setArt, true } +func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item semanticcache.ProviderSetItemArtifact) []error { + switch item.Kind { + case "func": + providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) + if len(errs) > 0 { + return errs + } + pset.Providers = append(pset.Providers, providerObj) + return nil + case "set": + setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) + if len(errs) > 0 { + return errs + } + pset.Imports = append(pset.Imports, setObj) + return nil + case "bind": + binding, errs := oc.semanticBinding(item) + if len(errs) > 0 { + return errs + } + pset.Bindings = append(pset.Bindings, binding) + return nil + case "struct": + providerObj, errs := oc.semanticStructProvider(item) + if len(errs) > 0 { + return errs + } + pset.Providers = append(pset.Providers, providerObj) + return nil + case "fields": + fields, errs := oc.semanticFields(item) + if len(errs) > 0 { + return errs + } + pset.Fields = append(pset.Fields, fields...) + return nil + default: + return []error{fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)} + } +} + func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { pkg := oc.packages[importPath] if pkg == nil || pkg.Types == nil { From 91845b1a1462ec646fe56615268abc375414e6e0 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:52:32 -0500 Subject: [PATCH 41/79] refactor: share semantic struct field helpers --- internal/wire/parse.go | 69 +++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index b988a2a..04a3301 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -720,29 +720,11 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem IsStruct: true, Out: []types.Type{out, types.NewPointer(out)}, } - if item.AllFields { - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } - } else { - for _, fieldName := range item.FieldNames { - f := lookupStructField(st, fieldName) - if f == nil { - return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} - } - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } + args, errs := semanticStructProviderInputs(st, item) + if len(errs) > 0 { + return nil, errs } + provider.Args = args return provider, nil } @@ -757,9 +739,9 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact } fields := make([]*Field, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { - v := lookupStructField(structType, fieldName) - if v == nil { - return nil, []error{fmt.Errorf("field %q not found", fieldName)} + v, err := requiredStructField(structType, fieldName) + if err != nil { + return nil, []error{err} } out := []types.Type{v.Type()} if ptrToField { @@ -776,6 +758,35 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact return fields, nil } +func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderSetItemArtifact) ([]ProviderInput, []error) { + if item.AllFields { + args := make([]ProviderInput, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + f := st.Field(i) + args = append(args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + return args, nil + } + args := make([]ProviderInput, 0, len(item.FieldNames)) + for _, fieldName := range item.FieldNames { + f, err := requiredStructField(st, fieldName) + if err != nil { + return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} + } + args = append(args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + return args, nil +} + func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { typeName, err := oc.semanticTypeName(ref) if err != nil { @@ -829,6 +840,14 @@ func lookupStructField(st *types.Struct, name string) *types.Var { return nil } +func requiredStructField(st *types.Struct, name string) (*types.Var, error) { + v := lookupStructField(st, name) + if v == nil { + return nil, fmt.Errorf("field %q not found", name) + } + return v, nil +} + func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { if pkg == nil { return nil From 9d894b29550e8bde205b1d16bc56ba11197ef6fc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:57:42 -0500 Subject: [PATCH 42/79] refactor: share semantic package object lookup --- internal/wire/parse.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 04a3301..36a6feb 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -655,11 +655,10 @@ func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item sema } func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { - pkg := oc.packages[importPath] - if pkg == nil || pkg.Types == nil { - return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, []error{err} } - obj := pkg.Types.Scope().Lookup(name) fn, ok := obj.(*types.Func) if !ok || fn == nil { return nil, []error{fmt.Errorf("%s.%s is not a provider function", importPath, name)} @@ -668,11 +667,10 @@ func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []e } func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { - pkg := oc.packages[importPath] - if pkg == nil || pkg.Types == nil { - return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, []error{err} } - obj := pkg.Types.Scope().Lookup(name) v, ok := obj.(*types.Var) if !ok || v == nil || !isProviderSetType(v.Type()) { return nil, []error{fmt.Errorf("%s.%s is not a provider set", importPath, name)} @@ -800,11 +798,10 @@ func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, erro } func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { - pkg := oc.packages[ref.ImportPath] - if pkg == nil || pkg.Types == nil { - return nil, fmt.Errorf("missing typed package for %s", ref.ImportPath) + obj, err := oc.lookupPackageObject(ref.ImportPath, ref.Name) + if err != nil { + return nil, err } - obj := pkg.Types.Scope().Lookup(ref.Name) typeName, ok := obj.(*types.TypeName) if !ok || typeName == nil { return nil, fmt.Errorf("%s.%s is not a named type", ref.ImportPath, ref.Name) @@ -812,6 +809,14 @@ func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeN return typeName, nil } +func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, fmt.Errorf("missing typed package for %s", importPath) + } + return pkg.Types.Scope().Lookup(name), nil +} + func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { ptr, ok := parent.(*types.Pointer) if !ok { From 3775c4cb7ac0e703145edca7762ed8a8ea4c0bbb Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:00:27 -0500 Subject: [PATCH 43/79] refactor: share semantic output type assembly --- internal/wire/parse.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 36a6feb..0fa211b 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -716,7 +716,7 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem Name: typeName.Name(), Pos: typeName.Pos(), IsStruct: true, - Out: []types.Type{out, types.NewPointer(out)}, + Out: typeAndPointer(out), } args, errs := semanticStructProviderInputs(st, item) if len(errs) > 0 { @@ -741,16 +741,12 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact if err != nil { return nil, []error{err} } - out := []types.Type{v.Type()} - if ptrToField { - out = append(out, types.NewPointer(v.Type())) - } fields = append(fields, &Field{ Parent: parent, Name: v.Name(), Pkg: v.Pkg(), Pos: v.Pos(), - Out: out, + Out: fieldOutputTypes(v.Type(), ptrToField), }) } return fields, nil @@ -785,6 +781,18 @@ func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderS return args, nil } +func typeAndPointer(typ types.Type) []types.Type { + return []types.Type{typ, types.NewPointer(typ)} +} + +func fieldOutputTypes(typ types.Type, includePointer bool) []types.Type { + out := []types.Type{typ} + if includePointer { + out = append(out, types.NewPointer(typ)) + } + return out +} + func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { typeName, err := oc.semanticTypeName(ref) if err != nil { From 2091c4b13f0678bf9e33c7144aebd90e2d80ae57 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:04:01 -0500 Subject: [PATCH 44/79] refactor: share struct provider shell assembly --- internal/wire/parse.go | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 0fa211b..3a19ba4 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -711,13 +711,7 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem if !ok { return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} } - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: typeName.Pos(), - IsStruct: true, - Out: typeAndPointer(out), - } + provider := newStructProvider(typeName, typeAndPointer(out)) args, errs := semanticStructProviderInputs(st, item) if len(errs) > 0 { return nil, errs @@ -793,6 +787,16 @@ func fieldOutputTypes(typ types.Type, includePointer bool) []types.Type { return out } +func newStructProvider(typeName types.Object, out []types.Type) *Provider { + return &Provider{ + Pkg: typeName.Pkg(), + Name: typeName.Name(), + Pos: typeName.Pos(), + IsStruct: true, + Out: out, + } +} + func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { typeName, err := oc.semanticTypeName(ref) if err != nil { @@ -1448,14 +1452,9 @@ func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) notePosition(fset.Position(pos), fmt.Errorf("using struct literal to inject %s is deprecated and will be removed in the next release; use wire.Struct instead", typeName.Type()))) - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: pos, - Args: make([]ProviderInput, st.NumFields()), - IsStruct: true, - Out: []types.Type{out, types.NewPointer(out)}, - } + provider := newStructProvider(typeName, typeAndPointer(out)) + provider.Pos = pos + provider.Args = make([]ProviderInput, st.NumFields()) for i := 0; i < st.NumFields(); i++ { f := st.Field(i) provider.Args[i] = ProviderInput{ @@ -1496,13 +1495,7 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call stExpr := call.Args[0].(*ast.CallExpr) typeName := qualifiedIdentObject(info, stExpr.Args[0]) // should be either an identifier or selector - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: typeName.Pos(), - IsStruct: true, - Out: []types.Type{structPtr.Elem(), structPtr}, - } + provider := newStructProvider(typeName, []types.Type{structPtr.Elem(), structPtr}) if allFields(call) { for i := 0; i < st.NumFields(); i++ { if isPrevented(st.Tag(i)) { From a7cd4cde95d2787679cdbcd4e4a8d157eb693eef Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:08:45 -0500 Subject: [PATCH 45/79] refactor: share allowed struct field inputs --- internal/wire/parse.go | 52 ++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 3a19ba4..88657a7 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -748,18 +748,7 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderSetItemArtifact) ([]ProviderInput, []error) { if item.AllFields { - args := make([]ProviderInput, 0, st.NumFields()) - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - args = append(args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } - return args, nil + return providerInputsForAllowedStructFields(st), nil } args := make([]ProviderInput, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { @@ -767,14 +756,29 @@ func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderS if err != nil { return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} } - args = append(args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) + args = append(args, providerInputForVar(f)) } return args, nil } +func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { + args := make([]ProviderInput, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + args = append(args, providerInputForVar(st.Field(i))) + } + return args +} + +func providerInputForVar(v *types.Var) ProviderInput { + return ProviderInput{ + Type: v.Type(), + FieldName: v.Name(), + } +} + func typeAndPointer(typ types.Type) []types.Type { return []types.Type{typ, types.NewPointer(typ)} } @@ -1497,16 +1501,7 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call typeName := qualifiedIdentObject(info, stExpr.Args[0]) // should be either an identifier or selector provider := newStructProvider(typeName, []types.Type{structPtr.Elem(), structPtr}) if allFields(call) { - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } + provider.Args = providerInputsForAllowedStructFields(st) } else { provider.Args = make([]ProviderInput, len(call.Args)-1) for i := 1; i < len(call.Args); i++ { @@ -1514,10 +1509,7 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - provider.Args[i-1] = ProviderInput{ - Type: v.Type(), - FieldName: v.Name(), - } + provider.Args[i-1] = providerInputForVar(v) } } for i := 0; i < len(provider.Args); i++ { From 051fb75473a893cb77470f64137c7213e1f8e9ae Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:09:55 -0500 Subject: [PATCH 46/79] refactor: share selected struct field inputs --- internal/wire/parse.go | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 88657a7..9f9a26d 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -750,24 +750,32 @@ func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderS if item.AllFields { return providerInputsForAllowedStructFields(st), nil } - args := make([]ProviderInput, 0, len(item.FieldNames)) + fields := make([]*types.Var, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { f, err := requiredStructField(st, fieldName) if err != nil { return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} } - args = append(args, providerInputForVar(f)) + fields = append(fields, f) } - return args, nil + return providerInputsForVars(fields), nil } func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { - args := make([]ProviderInput, 0, st.NumFields()) + fields := make([]*types.Var, 0, st.NumFields()) for i := 0; i < st.NumFields(); i++ { if isPrevented(st.Tag(i)) { continue } - args = append(args, providerInputForVar(st.Field(i))) + fields = append(fields, st.Field(i)) + } + return providerInputsForVars(fields) +} + +func providerInputsForVars(vars []*types.Var) []ProviderInput { + args := make([]ProviderInput, 0, len(vars)) + for _, v := range vars { + args = append(args, providerInputForVar(v)) } return args } @@ -1503,14 +1511,15 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call if allFields(call) { provider.Args = providerInputsForAllowedStructFields(st) } else { - provider.Args = make([]ProviderInput, len(call.Args)-1) + fields := make([]*types.Var, 0, len(call.Args)-1) for i := 1; i < len(call.Args); i++ { v, err := checkField(call.Args[i], st) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - provider.Args[i-1] = providerInputForVar(v) + fields = append(fields, v) } + provider.Args = providerInputsForVars(fields) } for i := 0; i < len(provider.Args); i++ { for j := 0; j < i; j++ { From 7f6195a67c2f397ae839fc4debbb0d88434ca760 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:10:59 -0500 Subject: [PATCH 47/79] refactor: share field output assembly for FieldsOf --- internal/wire/parse.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 9f9a26d..3342aad 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -1713,18 +1713,12 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - out := []types.Type{v.Type()} - if isPtrToStruct { - // If the field is from a pointer to a struct, then - // wire.Fields also provides a pointer to the field. - out = append(out, types.NewPointer(v.Type())) - } fields = append(fields, &Field{ Parent: structPtr.Elem(), Name: v.Name(), Pkg: v.Pkg(), Pos: v.Pos(), - Out: out, + Out: fieldOutputTypes(v.Type(), isPtrToStruct), }) } return fields, nil From 326425316c696b346b1f8fea9c8024a9d433be48 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:11:56 -0500 Subject: [PATCH 48/79] refactor: share quoted struct field lookup --- internal/wire/parse.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 3342aad..6395d00 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -877,6 +877,15 @@ func requiredStructField(st *types.Struct, name string) (*types.Var, error) { return v, nil } +func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, int) { + for i := 0; i < st.NumFields(); i++ { + if strings.EqualFold(strconv.Quote(st.Field(i).Name()), quotedName) { + return st.Field(i), i + } + } + return nil, -1 +} + func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { if pkg == nil { return nil @@ -1731,13 +1740,12 @@ func checkField(f ast.Expr, st *types.Struct) (*types.Var, error) { if !ok { return nil, fmt.Errorf("%v must be a string with the field name", f) } - for i := 0; i < st.NumFields(); i++ { - if strings.EqualFold(strconv.Quote(st.Field(i).Name()), b.Value) { - if isPrevented(st.Tag(i)) { - return nil, fmt.Errorf("%s is prevented from injecting by wire", b.Value) - } - return st.Field(i), nil + v, i := lookupQuotedStructField(st, b.Value) + if v != nil { + if isPrevented(st.Tag(i)) { + return nil, fmt.Errorf("%s is prevented from injecting by wire", b.Value) } + return v, nil } return nil, fmt.Errorf("%s is not a field of %s", b.Value, st.String()) } From 41f4d7dfecffec3be083097a737f206934e4a2ff Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:14:35 -0500 Subject: [PATCH 49/79] refactor: share semantic pointer expansion --- internal/wire/parse.go | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 6395d00..2c36c76 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -788,13 +788,13 @@ func providerInputForVar(v *types.Var) ProviderInput { } func typeAndPointer(typ types.Type) []types.Type { - return []types.Type{typ, types.NewPointer(typ)} + return []types.Type{typ, applyTypePointers(typ, 1)} } func fieldOutputTypes(typ types.Type, includePointer bool) []types.Type { out := []types.Type{typ} if includePointer { - out = append(out, types.NewPointer(typ)) + out = append(out, applyTypePointers(typ, 1)) } return out } @@ -814,11 +814,7 @@ func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, erro if err != nil { return nil, err } - var typ types.Type = typeName.Type() - for i := 0; i < ref.Pointer; i++ { - typ = types.NewPointer(typ) - } - return typ, nil + return applyTypePointers(typeName.Type(), ref.Pointer), nil } func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { @@ -841,6 +837,13 @@ func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Objec return pkg.Types.Scope().Lookup(name), nil } +func applyTypePointers(typ types.Type, count int) types.Type { + for i := 0; i < count; i++ { + typ = types.NewPointer(typ) + } + return typ +} + func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { ptr, ok := parent.(*types.Pointer) if !ok { From dea5a636ad6252f2862cfbd7167aacfc69fde04c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:15:23 -0500 Subject: [PATCH 50/79] refactor: reuse field parent struct resolution --- internal/wire/parse.go | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2c36c76..f52200f 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -1697,22 +1697,10 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) } - - var struc *types.Struct - isPtrToStruct := false - switch t := structPtr.Elem().Underlying().(type) { - case *types.Pointer: - struc, ok = t.Elem().Underlying().(*types.Struct) - if !ok { - return nil, notePosition(fset.Position(call.Pos()), - fmt.Errorf(firstArgReqFormat, types.TypeString(struc, nil))) - } - isPtrToStruct = true - case *types.Struct: - struc = t - default: + struc, isPtrToStruct, err := structFromFieldsParent(structPtr) + if err != nil { return nil, notePosition(fset.Position(call.Pos()), - fmt.Errorf(firstArgReqFormat, types.TypeString(t, nil))) + fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) } if struc.NumFields() < len(call.Args)-1 { return nil, notePosition(fset.Position(call.Pos()), From 06060679318057694ce47598f9ec9656aa29360f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:17:47 -0500 Subject: [PATCH 51/79] refactor: share field object assembly --- internal/wire/parse.go | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index f52200f..f1a43b1 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -735,13 +735,7 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact if err != nil { return nil, []error{err} } - fields = append(fields, &Field{ - Parent: parent, - Name: v.Name(), - Pkg: v.Pkg(), - Pos: v.Pos(), - Out: fieldOutputTypes(v.Type(), ptrToField), - }) + fields = append(fields, newField(parent, v, ptrToField)) } return fields, nil } @@ -787,6 +781,16 @@ func providerInputForVar(v *types.Var) ProviderInput { } } +func newField(parent types.Type, v *types.Var, includePointer bool) *Field { + return &Field{ + Parent: parent, + Name: v.Name(), + Pkg: v.Pkg(), + Pos: v.Pos(), + Out: fieldOutputTypes(v.Type(), includePointer), + } +} + func typeAndPointer(typ types.Type) []types.Type { return []types.Type{typ, applyTypePointers(typ, 1)} } @@ -1713,13 +1717,7 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - fields = append(fields, &Field{ - Parent: structPtr.Elem(), - Name: v.Name(), - Pkg: v.Pkg(), - Pos: v.Pos(), - Out: fieldOutputTypes(v.Type(), isPtrToStruct), - }) + fields = append(fields, newField(structPtr.Elem(), v, isPtrToStruct)) } return fields, nil } From f9d735f21db84886fb1e4e7b1961b9d0c79a87dd Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:19:12 -0500 Subject: [PATCH 52/79] refactor: share named struct type resolution --- internal/wire/parse.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index f1a43b1..353d743 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -706,8 +706,7 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem if err != nil { return nil, []error{err} } - out := typeName.Type() - st, ok := out.Underlying().(*types.Struct) + out, st, ok := namedStructType(typeName) if !ok { return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} } @@ -848,6 +847,12 @@ func applyTypePointers(typ types.Type, count int) types.Type { return typ } +func namedStructType(typeName types.Object) (types.Type, *types.Struct, bool) { + out := typeName.Type() + st, ok := out.Underlying().(*types.Struct) + return out, st, ok +} + func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { ptr, ok := parent.(*types.Pointer) if !ok { @@ -1468,8 +1473,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) { // It will not support any new feature introduced after v0.2. Please use the new // wire.Struct syntax for those. func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { - out := typeName.Type() - st, ok := out.Underlying().(*types.Struct) + out, st, ok := namedStructType(typeName) if !ok { return nil, []error{fmt.Errorf("%v does not name a struct", typeName)} } From 4477c75c6ed812d59384b70678f57b9fb1f93932 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:20:16 -0500 Subject: [PATCH 53/79] refactor: share semantic type name lookup --- internal/wire/parse.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 353d743..4220d32 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -821,15 +821,7 @@ func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, erro } func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { - obj, err := oc.lookupPackageObject(ref.ImportPath, ref.Name) - if err != nil { - return nil, err - } - typeName, ok := obj.(*types.TypeName) - if !ok || typeName == nil { - return nil, fmt.Errorf("%s.%s is not a named type", ref.ImportPath, ref.Name) - } - return typeName, nil + return oc.lookupPackageTypeName(ref.ImportPath, ref.Name) } func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { @@ -840,6 +832,18 @@ func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Objec return pkg.Types.Scope().Lookup(name), nil } +func (oc *objectCache) lookupPackageTypeName(importPath, name string) (*types.TypeName, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + typeName, ok := obj.(*types.TypeName) + if !ok || typeName == nil { + return nil, fmt.Errorf("%s.%s is not a named type", importPath, name) + } + return typeName, nil +} + func applyTypePointers(typ types.Type, count int) types.Type { for i := 0; i < count; i++ { typ = types.NewPointer(typ) From 2c0446de652895b389e96a14349b926b223adaed Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:21:11 -0500 Subject: [PATCH 54/79] refactor: share semantic package member lookup --- internal/wire/parse.go | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 4220d32..0677cd0 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -655,26 +655,18 @@ func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item sema } func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { - obj, err := oc.lookupPackageObject(importPath, name) + fn, err := oc.lookupPackageFunc(importPath, name) if err != nil { return nil, []error{err} } - fn, ok := obj.(*types.Func) - if !ok || fn == nil { - return nil, []error{fmt.Errorf("%s.%s is not a provider function", importPath, name)} - } return processFuncProvider(oc.fset, fn) } func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { - obj, err := oc.lookupPackageObject(importPath, name) + v, err := oc.lookupProviderSetVar(importPath, name) if err != nil { return nil, []error{err} } - v, ok := obj.(*types.Var) - if !ok || v == nil || !isProviderSetType(v.Type()) { - return nil, []error{fmt.Errorf("%s.%s is not a provider set", importPath, name)} - } item, errs := oc.get(v) if len(errs) > 0 { return nil, errs @@ -844,6 +836,30 @@ func (oc *objectCache) lookupPackageTypeName(importPath, name string) (*types.Ty return typeName, nil } +func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + fn, ok := obj.(*types.Func) + if !ok || fn == nil { + return nil, fmt.Errorf("%s.%s is not a provider function", importPath, name) + } + return fn, nil +} + +func (oc *objectCache) lookupProviderSetVar(importPath, name string) (*types.Var, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + v, ok := obj.(*types.Var) + if !ok || v == nil || !isProviderSetType(v.Type()) { + return nil, fmt.Errorf("%s.%s is not a provider set", importPath, name) + } + return v, nil +} + func applyTypePointers(typ types.Type, count int) types.Type { for i := 0; i < count; i++ { typ = types.NewPointer(typ) From 848371ff0edfc0a4ba9f06a9423d809e314ad648 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:22:09 -0500 Subject: [PATCH 55/79] refactor: share semantic error wrapping --- internal/wire/parse.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 0677cd0..430cb4f 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -681,11 +681,11 @@ func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSe func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifact) (*IfaceBinding, []error) { iface, err := oc.semanticType(item.Type) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } provided, err := oc.semanticType(item.Type2) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } return &IfaceBinding{ Iface: iface, @@ -696,11 +696,11 @@ func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifac func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItemArtifact) (*Provider, []error) { typeName, err := oc.semanticTypeName(item.Type) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } out, st, ok := namedStructType(typeName) if !ok { - return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} + return nil, semanticErrors(fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)) } provider := newStructProvider(typeName, typeAndPointer(out)) args, errs := semanticStructProviderInputs(st, item) @@ -714,17 +714,17 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact) ([]*Field, []error) { parent, err := oc.semanticType(item.Type) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } structType, ptrToField, err := structFromFieldsParent(parent) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } fields := make([]*Field, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { v, err := requiredStructField(structType, fieldName) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } fields = append(fields, newField(parent, v, ptrToField)) } @@ -765,6 +765,10 @@ func providerInputsForVars(vars []*types.Var) []ProviderInput { return args } +func semanticErrors(err error) []error { + return []error{err} +} + func providerInputForVar(v *types.Var) ProviderInput { return ProviderInput{ Type: v.Type(), From 3944cbb876cca053e5a8b184389a1b71a14a0228 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:23:07 -0500 Subject: [PATCH 56/79] refactor: share provider set finalization --- internal/wire/parse.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 430cb4f..834f8c2 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -580,12 +580,7 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, if len(ec.errors) > 0 { return nil, true, ec.errors } - var errs []error - pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) - if len(errs) > 0 { - return nil, true, errs - } - if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { + if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { return nil, true, errs } return pset, true, nil @@ -1361,15 +1356,22 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast if len(ec.errors) > 0 { return nil, ec.errors } + if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { + return nil, errs + } + return pset, nil +} + +func (oc *objectCache) finalizeProviderSet(pset *ProviderSet) []error { var errs []error pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) if len(errs) > 0 { - return nil, errs + return errs } if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { - return nil, errs + return errs } - return pset, nil + return nil } // structArgType attempts to interpret an expression as a simple struct type. From 12858a2ff65ee60de69cf55e8c1f531e87e58810 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:42:51 -0500 Subject: [PATCH 57/79] refactor: add isolated output cache gate --- internal/wire/output_cache.go | 8 ++++++- internal/wire/output_cache_test.go | 38 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 internal/wire/output_cache_test.go diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index 35eacfe..b95a514 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -17,7 +17,10 @@ import ( "github.com/goforj/wire/internal/loader" ) -const outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" +const ( + outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + outputCacheEnabledEnv = "WIRE_OUTPUT_CACHE" +) type outputCacheEntry struct { Version int @@ -104,6 +107,9 @@ func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { return false } + if envValue(env, outputCacheEnabledEnv) == "0" { + return false + } return envValue(env, "WIRE_LOADER_ARTIFACTS") != "0" } diff --git a/internal/wire/output_cache_test.go b/internal/wire/output_cache_test.go new file mode 100644 index 0000000..a74621b --- /dev/null +++ b/internal/wire/output_cache_test.go @@ -0,0 +1,38 @@ +package wire + +import ( + "context" + "testing" +) + +func TestOutputCacheEnabled(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + env []string + want bool + }{ + { + name: "enabled with artifacts", + env: []string{"WIRE_LOADER_ARTIFACTS=1"}, + want: true, + }, + { + name: "disabled without artifacts", + env: []string{"WIRE_LOADER_ARTIFACTS=0"}, + want: false, + }, + { + name: "disabled by dedicated env", + env: []string{"WIRE_LOADER_ARTIFACTS=1", "WIRE_OUTPUT_CACHE=0"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := outputCacheEnabled(ctx, t.TempDir(), tt.env); got != tt.want { + t.Fatalf("outputCacheEnabled(..., %v) = %v, want %v", tt.env, got, tt.want) + } + }) + } +} From 8dbd59013ae0091cc8f1f7d272884beaddc73c1c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:10:55 -0500 Subject: [PATCH 58/79] refactor: make provider set fallback policy explicit --- internal/wire/parse.go | 43 +++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 834f8c2..cd20bd5 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -538,8 +538,8 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) - if spec == nil && isProviderSetType(obj.Type()) { - if pset, ok, errs := oc.semanticProviderSet(obj); ok { + if isProviderSetType(obj.Type()) { + if pset, ok, errs := oc.providerSetForVar(obj, spec); ok { return pset, errs } } @@ -561,6 +561,13 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } +func (oc *objectCache) providerSetForVar(obj *types.Var, spec *ast.ValueSpec) (*ProviderSet, bool, []error) { + if spec != nil { + return nil, false, nil + } + return oc.semanticProviderSet(obj) +} + func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { setArt, ok := oc.semanticProviderSetArtifact(obj) if !ok { @@ -1213,22 +1220,10 @@ func summarizeTypeRef(typ types.Type) (semanticcache.TypeRef, bool) { } func semanticVarDecl(pkg *packages.Package, obj *types.Var) *ast.ValueSpec { - pos := obj.Pos() - for _, f := range pkg.Syntax { - tokenFile := pkg.Fset.File(f.Pos()) - if tokenFile == nil { - continue - } - if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { - path, _ := astutil.PathEnclosingInterval(f, pos, pos) - for _, node := range path { - if spec, ok := node.(*ast.ValueSpec); ok { - return spec - } - } - } + if pkg == nil { + return nil } - return nil + return valueSpecForVar(pkg.Fset, pkg.Syntax, obj) } // varDecl finds the declaration that defines the given variable. @@ -1236,9 +1231,19 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. // Recommended by https://golang.org/s/types-tutorial pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return nil + } + return valueSpecForVar(oc.fset, pkg.Syntax, obj) +} + +func valueSpecForVar(fset *token.FileSet, files []*ast.File, obj *types.Var) *ast.ValueSpec { pos := obj.Pos() - for _, f := range pkg.Syntax { - tokenFile := oc.fset.File(f.Pos()) + for _, f := range files { + tokenFile := fset.File(f.Pos()) + if tokenFile == nil { + continue + } if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { path, _ := astutil.PathEnclosingInterval(f, pos, pos) for _, node := range path { From 9ef6b11db20058188d10f75fb4cddb214ac6c438 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:11:53 -0500 Subject: [PATCH 59/79] refactor: share custom loader root loading path --- internal/loader/custom.go | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 3529350..634545d 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -259,18 +259,13 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz fset = token.NewFileSet() } l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, map[string]struct{}{req.Package: {}}, req.ParseFile, discoveryDuration) - prefetchStart := time.Now() - l.prefetchArtifacts() - l.stats.artifactPrefetch = time.Since(prefetchStart) - rootLoadStart := time.Now() - root, err := l.loadPackage(req.Package) + roots, err := loadCustomRootPackages(l, []string{req.Package}) if err != nil { return nil, err } - l.stats.rootLoad = time.Since(rootLoadStart) logTypedLoadStats(ctx, "lazy", l.stats) return &LazyLoadResult{ - Packages: []*packages.Package{root}, + Packages: roots, Backend: ModeCustom, }, nil } @@ -303,13 +298,26 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, unsupportedError{reason: "no root packages from metadata"} } l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) + rootPaths := nonDepRootImportPaths(meta) + roots, err := loadCustomRootPackages(l, rootPaths) + if err != nil { + return nil, err + } + logTypedLoadStats(ctx, "typed", l.stats) + return &PackageLoadResult{ + Packages: roots, + Backend: ModeCustom, + }, nil +} + +func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { prefetchStart := time.Now() l.prefetchArtifacts() l.stats.artifactPrefetch = time.Since(prefetchStart) + rootLoadStart := time.Now() - rootPaths := nonDepRootImportPaths(meta) - roots := make([]*packages.Package, 0, len(rootPaths)) - for _, path := range rootPaths { + roots := make([]*packages.Package, 0, len(paths)) + for _, path := range paths { root, err := l.loadPackage(path) if err != nil { return nil, err @@ -318,11 +326,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo } l.stats.rootLoad = time.Since(rootLoadStart) sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) - logTypedLoadStats(ctx, "typed", l.stats) - return &PackageLoadResult{ - Packages: roots, - Backend: ModeCustom, - }, nil + return roots, nil } func (v *customValidator) validatePackage(path string) (*packages.Package, error) { From 4386089cdff454a9d32b4ae131bb6b2f2affc8d9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:12:51 -0500 Subject: [PATCH 60/79] refactor: share custom loader metadata root graph --- internal/loader/custom.go | 67 ++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 634545d..669e3f1 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -191,31 +191,11 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes if len(meta) == 0 { return nil, unsupportedError{reason: "empty go list result"} } - pkgs := make(map[string]*packages.Package, len(meta)) - for path, m := range meta { - pkgs[path] = packageStub(nil, m) - appendPackageMetaError(pkgs[path], m) - } - for path, m := range meta { - pkg := pkgs[path] - for _, imp := range m.Imports { - target := resolvedImportTarget(m, imp) - if dep := pkgs[target]; dep != nil { - pkg.Imports[imp] = dep - } - } - } - rootPaths := nonDepRootImportPaths(meta) - roots := make([]*packages.Package, 0, len(rootPaths)) - for _, path := range rootPaths { - if pkg := pkgs[path]; pkg != nil { - roots = append(roots, pkg) - } - } + pkgs := packageStubGraphFromMeta(nil, meta) + roots := rootPackagesFromMeta(meta, pkgs) if len(roots) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } - sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) return &RootLoadResult{ Packages: roots, Backend: ModeCustom, @@ -290,10 +270,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if fset == nil { fset = token.NewFileSet() } - targets := make(map[string]struct{}) - for _, path := range nonDepRootImportPaths(meta) { - targets[path] = struct{}{} - } + targets := rootTargetSet(meta) if len(targets) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } @@ -1220,6 +1197,24 @@ func packageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { } } +func packageStubGraphFromMeta(fset *token.FileSet, meta map[string]*packageMeta) map[string]*packages.Package { + pkgs := make(map[string]*packages.Package, len(meta)) + for path, m := range meta { + pkgs[path] = packageStub(fset, m) + appendPackageMetaError(pkgs[path], m) + } + for path, m := range meta { + pkg := pkgs[path] + for _, imp := range m.Imports { + target := resolvedImportTarget(m, imp) + if dep := pkgs[target]; dep != nil { + pkg.Imports[imp] = dep + } + } + } + return pkgs +} + func appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { if pkg == nil || meta == nil || meta.Error == nil || strings.TrimSpace(meta.Error.Err) == "" { return false @@ -1288,6 +1283,26 @@ func nonDepRootImportPaths(meta map[string]*packageMeta) []string { return paths } +func rootTargetSet(meta map[string]*packageMeta) map[string]struct{} { + targets := make(map[string]struct{}) + for _, path := range nonDepRootImportPaths(meta) { + targets[path] = struct{}{} + } + return targets +} + +func rootPackagesFromMeta(meta map[string]*packageMeta, pkgs map[string]*packages.Package) []*packages.Package { + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { + roots = append(roots, pkg) + } + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + return roots +} + func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { prefix := "loader.custom." + mode logDuration(ctx, prefix+".read_files.cumulative", stats.read) From b099db8f1a2d65f3a24b3484872c79c9376698fe Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:13:22 -0500 Subject: [PATCH 61/79] refactor: isolate semantic provider set support rule --- internal/wire/parse.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index cd20bd5..f56d2c1 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -606,12 +606,19 @@ func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcach if !ok { return semanticcache.ProviderSetArtifact{}, false } + if !semanticProviderSetArtifactSupported(setArt) { + return semanticcache.ProviderSetArtifact{}, false + } + return setArt, true +} + +func semanticProviderSetArtifactSupported(setArt semanticcache.ProviderSetArtifact) bool { for _, item := range setArt.Items { if item.Kind == "bind" { - return semanticcache.ProviderSetArtifact{}, false + return false } } - return setArt, true + return true } func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item semanticcache.ProviderSetItemArtifact) []error { From cf4bb809a4fa813f1108c372ede3cc46685137ce Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:26:52 -0500 Subject: [PATCH 62/79] refactor: fold back weak cleanup abstractions --- internal/loader/custom.go | 36 +++++++++++++----------------------- internal/wire/parse.go | 10 ++-------- 2 files changed, 15 insertions(+), 31 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 669e3f1..007eebd 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -192,10 +192,17 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes return nil, unsupportedError{reason: "empty go list result"} } pkgs := packageStubGraphFromMeta(nil, meta) - roots := rootPackagesFromMeta(meta, pkgs) + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { + roots = append(roots, pkg) + } + } if len(roots) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) return &RootLoadResult{ Packages: roots, Backend: ModeCustom, @@ -270,12 +277,15 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if fset == nil { fset = token.NewFileSet() } - targets := rootTargetSet(meta) + rootPaths := nonDepRootImportPaths(meta) + targets := make(map[string]struct{}, len(rootPaths)) + for _, path := range rootPaths { + targets[path] = struct{}{} + } if len(targets) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) - rootPaths := nonDepRootImportPaths(meta) roots, err := loadCustomRootPackages(l, rootPaths) if err != nil { return nil, err @@ -1283,26 +1293,6 @@ func nonDepRootImportPaths(meta map[string]*packageMeta) []string { return paths } -func rootTargetSet(meta map[string]*packageMeta) map[string]struct{} { - targets := make(map[string]struct{}) - for _, path := range nonDepRootImportPaths(meta) { - targets[path] = struct{}{} - } - return targets -} - -func rootPackagesFromMeta(meta map[string]*packageMeta, pkgs map[string]*packages.Package) []*packages.Package { - rootPaths := nonDepRootImportPaths(meta) - roots := make([]*packages.Package, 0, len(rootPaths)) - for _, path := range rootPaths { - if pkg := pkgs[path]; pkg != nil { - roots = append(roots, pkg) - } - } - sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) - return roots -} - func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { prefix := "loader.custom." + mode logDuration(ctx, prefix+".read_files.cumulative", stats.read) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index f56d2c1..66d51da 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -539,7 +539,8 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { case *types.Var: spec := oc.varDecl(obj) if isProviderSetType(obj.Type()) { - if pset, ok, errs := oc.providerSetForVar(obj, spec); ok { + if spec == nil { + pset, _, errs := oc.semanticProviderSet(obj) return pset, errs } } @@ -561,13 +562,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } -func (oc *objectCache) providerSetForVar(obj *types.Var, spec *ast.ValueSpec) (*ProviderSet, bool, []error) { - if spec != nil { - return nil, false, nil - } - return oc.semanticProviderSet(obj) -} - func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { setArt, ok := oc.semanticProviderSetArtifact(obj) if !ok { From c323a0f7fdc379cc1330d8e00ccc4ffc56b09914 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:32:44 -0500 Subject: [PATCH 63/79] refactor: narrow loader semantic artifact coupling --- internal/loader/artifact_cache.go | 8 ++++++-- internal/loader/custom.go | 16 ++++++++++++++++ internal/loader/loader_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index e920d5a..f57ba76 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -153,9 +153,13 @@ func isProviderSetTypeForLoader(t types.Type) bool { if obj == nil || obj.Pkg() == nil { return false } - switch obj.Pkg().Path() { + return isWireImportPath(obj.Pkg().Path()) && obj.Name() == "ProviderSet" +} + +func isWireImportPath(path string) bool { + switch path { case "github.com/goforj/wire", "github.com/google/wire": - return obj.Name() == "ProviderSet" + return true default: return false } diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 007eebd..0cdd919 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -546,6 +546,18 @@ func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMet return art.Supported } +func localPackageNeedsSemanticArtifacts(meta *packageMeta) bool { + if meta == nil { + return false + } + for _, path := range meta.Imports { + if isWireImportPath(path) { + return true + } + } + return false +} + func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { if !loaderArtifactEnabled(l.env) || isTarget { return artifactPolicy{} @@ -555,6 +567,10 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL policy.read = true return policy } + if !localPackageNeedsSemanticArtifacts(meta) { + policy.read = true + return policy + } policy.read = l.localSemanticArtifactSupported(meta) return policy } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 05cfaa7..8e5e585 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2386,6 +2386,34 @@ func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testi } } +func TestArtifactPolicyLocalReadOnlyNeedsSemanticForWirePackages(t *testing.T) { + t.Parallel() + + loader := &customTypedGraphLoader{ + env: []string{"WIRE_LOADER_ARTIFACTS=1"}, + localSemanticOK: map[string]bool{"example.com/app": false}, + } + + nonWireMeta := &packageMeta{ + ImportPath: "example.com/app", + Imports: []string{"fmt", "example.com/dep"}, + } + wireMeta := &packageMeta{ + ImportPath: "example.com/app", + Imports: []string{"github.com/goforj/wire"}, + } + + if got := loader.artifactPolicy(nonWireMeta, false, true); !got.read || !got.write { + t.Fatalf("artifactPolicy(non-wire local) = %+v, want read+write", got) + } + if got := loader.artifactPolicy(wireMeta, false, true); got.read || !got.write { + t.Fatalf("artifactPolicy(wire local without semantic support) = %+v, want write-only", got) + } + if got := loader.artifactPolicy(wireMeta, false, false); !got.read || !got.write { + t.Fatalf("artifactPolicy(wire external) = %+v, want read+write", got) + } +} + func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { root := t.TempDir() depRoot := filepath.Join(root, "depmod") From 7bf31e8accdeefa14f1fc1b4269c80472e21bebc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:39:52 -0500 Subject: [PATCH 64/79] refactor: unify semantic provider set support rules --- internal/wire/parse.go | 3 +++ internal/wire/parse_coverage_test.go | 19 +++++++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 66d51da..2c33d99 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -1043,6 +1043,9 @@ func summarizeSemanticProviderSet(info *types.Info, expr ast.Expr, pkgPath strin } setArt.Items = append(setArt.Items, items...) } + if !semanticProviderSetArtifactSupported(setArt) { + return semanticcache.ProviderSetArtifact{}, false + } return setArt, true } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index c3c4d8e..8b7e68a 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -299,14 +299,25 @@ func TestSummarizeSemanticProviderSetTypeOnlyForms(t *testing.T) { &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, }, } + if got, ok := summarizeSemanticProviderSet(info, call, "example.com/app"); ok || len(got.Items) != 0 { + t.Fatalf("summarizeSemanticProviderSet(bind case) = (%+v, %v), want unsupported", got, ok) + } + + call = &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, + Args: []ast.Expr{ + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"*\""}}}, + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, + }, + } got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") if !ok { - t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") + t.Fatal("summarizeSemanticProviderSet(non-bind type-only forms) = unsupported, want supported") } - if len(got.Items) != 3 { - t.Fatalf("items len = %d, want 3", len(got.Items)) + if len(got.Items) != 2 { + t.Fatalf("items len = %d, want 2", len(got.Items)) } - if got.Items[0].Kind != "bind" || got.Items[1].Kind != "struct" || got.Items[2].Kind != "fields" { + if got.Items[0].Kind != "struct" || got.Items[1].Kind != "fields" { t.Fatalf("unexpected kinds: %+v", got.Items) } } From a7fc49c9993dc7fd2db747895f7a92c737f89fff Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:51:06 -0500 Subject: [PATCH 65/79] fix: restore local loader artifact safety gate --- internal/loader/artifact_cache.go | 8 ++------ internal/loader/custom.go | 16 ---------------- internal/loader/loader_test.go | 28 ---------------------------- 3 files changed, 2 insertions(+), 50 deletions(-) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index f57ba76..e920d5a 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -153,13 +153,9 @@ func isProviderSetTypeForLoader(t types.Type) bool { if obj == nil || obj.Pkg() == nil { return false } - return isWireImportPath(obj.Pkg().Path()) && obj.Name() == "ProviderSet" -} - -func isWireImportPath(path string) bool { - switch path { + switch obj.Pkg().Path() { case "github.com/goforj/wire", "github.com/google/wire": - return true + return obj.Name() == "ProviderSet" default: return false } diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 0cdd919..007eebd 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -546,18 +546,6 @@ func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMet return art.Supported } -func localPackageNeedsSemanticArtifacts(meta *packageMeta) bool { - if meta == nil { - return false - } - for _, path := range meta.Imports { - if isWireImportPath(path) { - return true - } - } - return false -} - func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { if !loaderArtifactEnabled(l.env) || isTarget { return artifactPolicy{} @@ -567,10 +555,6 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL policy.read = true return policy } - if !localPackageNeedsSemanticArtifacts(meta) { - policy.read = true - return policy - } policy.read = l.localSemanticArtifactSupported(meta) return policy } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 8e5e585..05cfaa7 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2386,34 +2386,6 @@ func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testi } } -func TestArtifactPolicyLocalReadOnlyNeedsSemanticForWirePackages(t *testing.T) { - t.Parallel() - - loader := &customTypedGraphLoader{ - env: []string{"WIRE_LOADER_ARTIFACTS=1"}, - localSemanticOK: map[string]bool{"example.com/app": false}, - } - - nonWireMeta := &packageMeta{ - ImportPath: "example.com/app", - Imports: []string{"fmt", "example.com/dep"}, - } - wireMeta := &packageMeta{ - ImportPath: "example.com/app", - Imports: []string{"github.com/goforj/wire"}, - } - - if got := loader.artifactPolicy(nonWireMeta, false, true); !got.read || !got.write { - t.Fatalf("artifactPolicy(non-wire local) = %+v, want read+write", got) - } - if got := loader.artifactPolicy(wireMeta, false, true); got.read || !got.write { - t.Fatalf("artifactPolicy(wire local without semantic support) = %+v, want write-only", got) - } - if got := loader.artifactPolicy(wireMeta, false, false); !got.read || !got.write { - t.Fatalf("artifactPolicy(wire external) = %+v, want read+write", got) - } -} - func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { root := t.TempDir() depRoot := filepath.Join(root, "depmod") From df1f6f62d0a211079b98e357167dcea21ca22fd9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 21:36:10 -0500 Subject: [PATCH 66/79] refactor: disable semantic reconstruction by default --- internal/wire/parse.go | 24 ++++++++++++-- internal/wire/parse_coverage_test.go | 2 ++ internal/wire/semantic_reconstruction_test.go | 33 +++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 internal/wire/semantic_reconstruction_test.go diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2c33d99..747cf61 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -36,6 +36,8 @@ import ( "github.com/goforj/wire/internal/semanticcache" ) +const semanticReconstructionEnv = "WIRE_SEMANTIC_RECONSTRUCTION" + // A providerSetSrc captures the source for a type provided by a ProviderSet. // Exactly one of the fields will be set. type providerSetSrc struct { @@ -491,7 +493,9 @@ func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache // call to packages.Load and an import path X, there will exist only // one *packages.Package value with PkgPath X. oc.registerPackages(pkgs, false) - oc.recordSemanticArtifacts() + if semanticReconstructionEnabled(env) { + oc.recordSemanticArtifacts() + } return oc } @@ -588,6 +592,9 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, } func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcache.ProviderSetArtifact, bool) { + if !semanticReconstructionEnabled(oc.env) { + return semanticcache.ProviderSetArtifact{}, false + } pkg := oc.packages[obj.Pkg().Path()] if pkg == nil { return semanticcache.ProviderSetArtifact{}, false @@ -926,6 +933,9 @@ func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, i } func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { + if !semanticReconstructionEnabled(oc.env) { + return nil + } if pkg == nil { return nil } @@ -945,7 +955,7 @@ func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.Pa } func (oc *objectCache) recordSemanticArtifacts() { - if len(oc.env) == 0 { + if len(oc.env) == 0 || !semanticReconstructionEnabled(oc.env) { return } for _, pkg := range oc.packages { @@ -962,6 +972,16 @@ func (oc *objectCache) recordSemanticArtifacts() { } } +func semanticReconstructionEnabled(env []string) bool { + for i := len(env) - 1; i >= 0; i-- { + key, value, ok := strings.Cut(env[i], "=") + if ok && key == semanticReconstructionEnv { + return value == "1" + } + } + return false +} + func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, packageName string, files []string, ok bool) { if len(env) == 0 || pkg == nil || len(pkg.GoFiles) == 0 { return "", "", nil, false diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 8b7e68a..4ba2a30 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -345,6 +345,7 @@ func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { Imports: make(map[string]*packages.Package), } oc := &objectCache{ + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, fset: fset, packages: map[string]*packages.Package{depPkg.PkgPath: depPkg}, objects: make(map[objRef]objCacheEntry), @@ -406,6 +407,7 @@ func TestObjectCacheSemanticProviderSetSkipsBindArtifacts(t *testing.T) { Imports: make(map[string]*packages.Package), } oc := &objectCache{ + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, fset: fset, packages: map[string]*packages.Package{appPkg.PkgPath: appPkg}, objects: make(map[objRef]objCacheEntry), diff --git a/internal/wire/semantic_reconstruction_test.go b/internal/wire/semantic_reconstruction_test.go new file mode 100644 index 0000000..79cbe3c --- /dev/null +++ b/internal/wire/semantic_reconstruction_test.go @@ -0,0 +1,33 @@ +package wire + +import "testing" + +func TestSemanticReconstructionEnabled(t *testing.T) { + tests := []struct { + name string + env []string + want bool + }{ + { + name: "disabled by default", + want: false, + }, + { + name: "enabled by env", + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, + want: true, + }, + { + name: "disabled by env", + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=0"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := semanticReconstructionEnabled(tt.env); got != tt.want { + t.Fatalf("semanticReconstructionEnabled(%v) = %v, want %v", tt.env, got, tt.want) + } + }) + } +} From 3927014faae17b14931efa46ed23296f5631a8b4 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 22:20:11 -0500 Subject: [PATCH 67/79] refactor: remove semantic reconstruction path --- internal/wire/parse.go | 570 +----------------- internal/wire/parse_coverage_test.go | 234 ------- internal/wire/semantic_reconstruction_test.go | 33 - internal/wire/wire.go | 2 +- 4 files changed, 2 insertions(+), 837 deletions(-) delete mode 100644 internal/wire/semantic_reconstruction_test.go diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 747cf61..4350baa 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -33,11 +33,8 @@ import ( "golang.org/x/tools/go/types/typeutil" "github.com/goforj/wire/internal/loader" - "github.com/goforj/wire/internal/semanticcache" ) -const semanticReconstructionEnv = "WIRE_SEMANTIC_RECONSTRUCTION" - // A providerSetSrc captures the source for a type provided by a ProviderSet. // Exactly one of the fields will be set. type providerSetSrc struct { @@ -270,7 +267,7 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCacheWithEnv(pkgs, env) + oc := newObjectCache(pkgs) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { @@ -455,10 +452,8 @@ func (in *Injector) String() string { // objectCache is a lazily evaluated mapping of objects to Wire structures. type objectCache struct { fset *token.FileSet - env []string packages map[string]*packages.Package objects map[objRef]objCacheEntry - semantic map[string]*semanticcache.PackageArtifact hasher typeutil.Hasher } @@ -473,19 +468,13 @@ type objCacheEntry struct { } func newObjectCache(pkgs []*packages.Package) *objectCache { - return newObjectCacheWithEnv(pkgs, nil) -} - -func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } oc := &objectCache{ fset: pkgs[0].Fset, - env: append([]string(nil), env...), packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), - semantic: make(map[string]*semanticcache.PackageArtifact), hasher: typeutil.MakeHasher(), } // Depth-first search of all dependencies to gather import path to @@ -493,9 +482,6 @@ func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache // call to packages.Load and an import path X, there will exist only // one *packages.Package value with PkgPath X. oc.registerPackages(pkgs, false) - if semanticReconstructionEnabled(env) { - oc.recordSemanticArtifacts() - } return oc } @@ -542,12 +528,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) - if isProviderSetType(obj.Type()) { - if spec == nil { - pset, _, errs := oc.semanticProviderSet(obj) - return pset, errs - } - } if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } @@ -566,196 +546,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } -func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { - setArt, ok := oc.semanticProviderSetArtifact(obj) - if !ok { - return nil, false, nil - } - pset := &ProviderSet{ - Pos: obj.Pos(), - PkgPath: obj.Pkg().Path(), - VarName: obj.Name(), - } - ec := new(errorCollector) - for _, item := range setArt.Items { - if errs := oc.applySemanticProviderSetItem(pset, item); len(errs) > 0 { - ec.add(errs...) - } - } - if len(ec.errors) > 0 { - return nil, true, ec.errors - } - if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { - return nil, true, errs - } - return pset, true, nil -} - -func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcache.ProviderSetArtifact, bool) { - if !semanticReconstructionEnabled(oc.env) { - return semanticcache.ProviderSetArtifact{}, false - } - pkg := oc.packages[obj.Pkg().Path()] - if pkg == nil { - return semanticcache.ProviderSetArtifact{}, false - } - art := oc.semanticArtifact(pkg) - if art == nil || !art.Supported { - return semanticcache.ProviderSetArtifact{}, false - } - setArt, ok := art.Vars[obj.Name()] - if !ok { - return semanticcache.ProviderSetArtifact{}, false - } - if !semanticProviderSetArtifactSupported(setArt) { - return semanticcache.ProviderSetArtifact{}, false - } - return setArt, true -} - -func semanticProviderSetArtifactSupported(setArt semanticcache.ProviderSetArtifact) bool { - for _, item := range setArt.Items { - if item.Kind == "bind" { - return false - } - } - return true -} - -func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item semanticcache.ProviderSetItemArtifact) []error { - switch item.Kind { - case "func": - providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) - if len(errs) > 0 { - return errs - } - pset.Providers = append(pset.Providers, providerObj) - return nil - case "set": - setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) - if len(errs) > 0 { - return errs - } - pset.Imports = append(pset.Imports, setObj) - return nil - case "bind": - binding, errs := oc.semanticBinding(item) - if len(errs) > 0 { - return errs - } - pset.Bindings = append(pset.Bindings, binding) - return nil - case "struct": - providerObj, errs := oc.semanticStructProvider(item) - if len(errs) > 0 { - return errs - } - pset.Providers = append(pset.Providers, providerObj) - return nil - case "fields": - fields, errs := oc.semanticFields(item) - if len(errs) > 0 { - return errs - } - pset.Fields = append(pset.Fields, fields...) - return nil - default: - return []error{fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)} - } -} - -func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { - fn, err := oc.lookupPackageFunc(importPath, name) - if err != nil { - return nil, []error{err} - } - return processFuncProvider(oc.fset, fn) -} - -func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { - v, err := oc.lookupProviderSetVar(importPath, name) - if err != nil { - return nil, []error{err} - } - item, errs := oc.get(v) - if len(errs) > 0 { - return nil, errs - } - pset, ok := item.(*ProviderSet) - if !ok || pset == nil { - return nil, []error{fmt.Errorf("%s.%s did not resolve to a provider set", importPath, name)} - } - return pset, nil -} - -func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifact) (*IfaceBinding, []error) { - iface, err := oc.semanticType(item.Type) - if err != nil { - return nil, semanticErrors(err) - } - provided, err := oc.semanticType(item.Type2) - if err != nil { - return nil, semanticErrors(err) - } - return &IfaceBinding{ - Iface: iface, - Provided: provided, - }, nil -} - -func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItemArtifact) (*Provider, []error) { - typeName, err := oc.semanticTypeName(item.Type) - if err != nil { - return nil, semanticErrors(err) - } - out, st, ok := namedStructType(typeName) - if !ok { - return nil, semanticErrors(fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)) - } - provider := newStructProvider(typeName, typeAndPointer(out)) - args, errs := semanticStructProviderInputs(st, item) - if len(errs) > 0 { - return nil, errs - } - provider.Args = args - return provider, nil -} - -func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact) ([]*Field, []error) { - parent, err := oc.semanticType(item.Type) - if err != nil { - return nil, semanticErrors(err) - } - structType, ptrToField, err := structFromFieldsParent(parent) - if err != nil { - return nil, semanticErrors(err) - } - fields := make([]*Field, 0, len(item.FieldNames)) - for _, fieldName := range item.FieldNames { - v, err := requiredStructField(structType, fieldName) - if err != nil { - return nil, semanticErrors(err) - } - fields = append(fields, newField(parent, v, ptrToField)) - } - return fields, nil -} - -func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderSetItemArtifact) ([]ProviderInput, []error) { - if item.AllFields { - return providerInputsForAllowedStructFields(st), nil - } - fields := make([]*types.Var, 0, len(item.FieldNames)) - for _, fieldName := range item.FieldNames { - f, err := requiredStructField(st, fieldName) - if err != nil { - return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} - } - fields = append(fields, f) - } - return providerInputsForVars(fields), nil -} - func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { fields := make([]*types.Var, 0, st.NumFields()) for i := 0; i < st.NumFields(); i++ { @@ -775,10 +565,6 @@ func providerInputsForVars(vars []*types.Var) []ProviderInput { return args } -func semanticErrors(err error) []error { - return []error{err} -} - func providerInputForVar(v *types.Var) ProviderInput { return ProviderInput{ Type: v.Type(), @@ -818,18 +604,6 @@ func newStructProvider(typeName types.Object, out []types.Type) *Provider { } } -func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { - typeName, err := oc.semanticTypeName(ref) - if err != nil { - return nil, err - } - return applyTypePointers(typeName.Type(), ref.Pointer), nil -} - -func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { - return oc.lookupPackageTypeName(ref.ImportPath, ref.Name) -} - func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { pkg := oc.packages[importPath] if pkg == nil || pkg.Types == nil { @@ -838,18 +612,6 @@ func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Objec return pkg.Types.Scope().Lookup(name), nil } -func (oc *objectCache) lookupPackageTypeName(importPath, name string) (*types.TypeName, error) { - obj, err := oc.lookupPackageObject(importPath, name) - if err != nil { - return nil, err - } - typeName, ok := obj.(*types.TypeName) - if !ok || typeName == nil { - return nil, fmt.Errorf("%s.%s is not a named type", importPath, name) - } - return typeName, nil -} - func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, error) { obj, err := oc.lookupPackageObject(importPath, name) if err != nil { @@ -862,18 +624,6 @@ func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, return fn, nil } -func (oc *objectCache) lookupProviderSetVar(importPath, name string) (*types.Var, error) { - obj, err := oc.lookupPackageObject(importPath, name) - if err != nil { - return nil, err - } - v, ok := obj.(*types.Var) - if !ok || v == nil || !isProviderSetType(v.Type()) { - return nil, fmt.Errorf("%s.%s is not a provider set", importPath, name) - } - return v, nil -} - func applyTypePointers(typ types.Type, count int) types.Type { for i := 0; i < count; i++ { typ = types.NewPointer(typ) @@ -932,324 +682,6 @@ func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, i return nil, -1 } -func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { - if !semanticReconstructionEnabled(oc.env) { - return nil - } - if pkg == nil { - return nil - } - if art, ok := oc.semantic[pkg.PkgPath]; ok { - return art - } - importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) - if !ok { - return nil - } - art, err := readSemanticArtifact(oc.env, importPath, packageName, files) - if err != nil { - return nil - } - oc.semantic[pkg.PkgPath] = art - return art -} - -func (oc *objectCache) recordSemanticArtifacts() { - if len(oc.env) == 0 || !semanticReconstructionEnabled(oc.env) { - return - } - for _, pkg := range oc.packages { - importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) - if !ok || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil { - continue - } - art := buildSemanticArtifact(pkg) - if art == nil { - continue - } - oc.semantic[pkg.PkgPath] = art - _ = writeSemanticArtifact(oc.env, importPath, packageName, files, art) - } -} - -func semanticReconstructionEnabled(env []string) bool { - for i := len(env) - 1; i >= 0; i-- { - key, value, ok := strings.Cut(env[i], "=") - if ok && key == semanticReconstructionEnv { - return value == "1" - } - } - return false -} - -func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, packageName string, files []string, ok bool) { - if len(env) == 0 || pkg == nil || len(pkg.GoFiles) == 0 { - return "", "", nil, false - } - return pkg.PkgPath, pkg.Name, pkg.GoFiles, true -} - -func readSemanticArtifact(env []string, importPath, packageName string, files []string) (*semanticcache.PackageArtifact, error) { - return semanticcache.Read(env, importPath, packageName, files) -} - -func writeSemanticArtifact(env []string, importPath, packageName string, files []string, art *semanticcache.PackageArtifact) error { - return semanticcache.Write(env, importPath, packageName, files, art) -} - -func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { - if pkg == nil || pkg.Types == nil || pkg.TypesInfo == nil { - return nil - } - art := &semanticcache.PackageArtifact{ - Version: 1, - PackagePath: pkg.PkgPath, - PackageName: pkg.Name, - Supported: true, - Vars: make(map[string]semanticcache.ProviderSetArtifact), - } - scope := pkg.Types.Scope() - for _, name := range scope.Names() { - obj := scope.Lookup(name) - v, ok := obj.(*types.Var) - if !ok || !isProviderSetType(v.Type()) { - continue - } - art.HasProviderSetVars = true - spec := semanticVarDecl(pkg, v) - if spec == nil || len(spec.Values) == 0 { - art.Supported = false - continue - } - var idx int - found := false - for i := range spec.Names { - if spec.Names[i].Name == v.Name() { - idx = i - found = true - break - } - } - if !found || idx >= len(spec.Values) { - art.Supported = false - continue - } - setArt, ok := summarizeSemanticProviderSet(pkg.TypesInfo, spec.Values[idx], pkg.PkgPath) - if !ok { - art.Supported = false - continue - } - art.Vars[v.Name()] = setArt - } - return art -} - -func summarizeSemanticProviderSet(info *types.Info, expr ast.Expr, pkgPath string) (semanticcache.ProviderSetArtifact, bool) { - call, ok := astutil.Unparen(expr).(*ast.CallExpr) - if !ok { - return semanticcache.ProviderSetArtifact{}, false - } - fnObj := qualifiedIdentObject(info, call.Fun) - if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) || fnObj.Name() != "NewSet" { - return semanticcache.ProviderSetArtifact{}, false - } - setArt := semanticcache.ProviderSetArtifact{ - Items: make([]semanticcache.ProviderSetItemArtifact, 0, len(call.Args)), - } - for _, arg := range call.Args { - items, ok := summarizeSemanticProviderSetArg(info, astutil.Unparen(arg), pkgPath) - if !ok { - return semanticcache.ProviderSetArtifact{}, false - } - setArt.Items = append(setArt.Items, items...) - } - if !semanticProviderSetArtifactSupported(setArt) { - return semanticcache.ProviderSetArtifact{}, false - } - return setArt, true -} - -func summarizeSemanticProviderSetArg(info *types.Info, expr ast.Expr, pkgPath string) ([]semanticcache.ProviderSetItemArtifact, bool) { - if obj := qualifiedIdentObject(info, expr); obj != nil && obj.Pkg() != nil && obj.Exported() { - item := semanticcache.ProviderSetItemArtifact{ - ImportPath: obj.Pkg().Path(), - Name: obj.Name(), - } - switch typed := obj.(type) { - case *types.Func: - item.Kind = "func" - case *types.Var: - if !isProviderSetType(typed.Type()) { - return nil, false - } - item.Kind = "set" - default: - return nil, false - } - if item.ImportPath == "" { - item.ImportPath = pkgPath - } - return []semanticcache.ProviderSetItemArtifact{item}, true - } - call, ok := expr.(*ast.CallExpr) - if !ok { - return nil, false - } - fnObj := qualifiedIdentObject(info, call.Fun) - if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) { - return nil, false - } - switch fnObj.Name() { - case "NewSet": - nested, ok := summarizeSemanticProviderSet(info, call, pkgPath) - if !ok { - return nil, false - } - return nested.Items, true - case "Bind": - item, ok := summarizeSemanticBind(info, call) - if !ok { - return nil, false - } - return []semanticcache.ProviderSetItemArtifact{item}, true - case "Struct": - item, ok := summarizeSemanticStruct(info, call) - if !ok { - return nil, false - } - return []semanticcache.ProviderSetItemArtifact{item}, true - case "FieldsOf": - item, ok := summarizeSemanticFields(info, call) - if !ok { - return nil, false - } - return []semanticcache.ProviderSetItemArtifact{item}, true - default: - return nil, false - } -} - -func summarizeSemanticBind(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { - if len(call.Args) != 2 { - return semanticcache.ProviderSetItemArtifact{}, false - } - iface, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) - if !ok || iface.Pointer == 0 { - return semanticcache.ProviderSetItemArtifact{}, false - } - iface.Pointer-- - providedType := info.TypeOf(call.Args[1]) - if bindShouldUsePointer(info, call) { - ptr, ok := providedType.(*types.Pointer) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - providedType = ptr.Elem() - } - provided, ok := summarizeTypeRef(providedType) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - return semanticcache.ProviderSetItemArtifact{ - Kind: "bind", - Type: iface, - Type2: provided, - }, true -} - -func summarizeSemanticStruct(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { - if len(call.Args) < 1 { - return semanticcache.ProviderSetItemArtifact{}, false - } - structType := info.TypeOf(call.Args[0]) - ptr, ok := structType.(*types.Pointer) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - ref, ok := summarizeTypeRef(ptr.Elem()) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - item := semanticcache.ProviderSetItemArtifact{ - Kind: "struct", - Type: ref, - } - if allFields(call) { - item.AllFields = true - return item, true - } - item.FieldNames = make([]string, 0, len(call.Args)-1) - for i := 1; i < len(call.Args); i++ { - lit, ok := call.Args[i].(*ast.BasicLit) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - fieldName, err := strconv.Unquote(lit.Value) - if err != nil { - return semanticcache.ProviderSetItemArtifact{}, false - } - item.FieldNames = append(item.FieldNames, fieldName) - } - return item, true -} - -func summarizeSemanticFields(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { - if len(call.Args) < 2 { - return semanticcache.ProviderSetItemArtifact{}, false - } - parent, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - item := semanticcache.ProviderSetItemArtifact{ - Kind: "fields", - Type: parent, - FieldNames: make([]string, 0, len(call.Args)-1), - } - for i := 1; i < len(call.Args); i++ { - lit, ok := call.Args[i].(*ast.BasicLit) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - fieldName, err := strconv.Unquote(lit.Value) - if err != nil { - return semanticcache.ProviderSetItemArtifact{}, false - } - item.FieldNames = append(item.FieldNames, fieldName) - } - return item, true -} - -func summarizeTypeRef(typ types.Type) (semanticcache.TypeRef, bool) { - ref := semanticcache.TypeRef{} - for { - ptr, ok := typ.(*types.Pointer) - if !ok { - break - } - ref.Pointer++ - typ = ptr.Elem() - } - named, ok := typ.(*types.Named) - if !ok { - return semanticcache.TypeRef{}, false - } - obj := named.Obj() - if obj == nil || obj.Pkg() == nil { - return semanticcache.TypeRef{}, false - } - ref.ImportPath = obj.Pkg().Path() - ref.Name = obj.Name() - return ref, true -} - -func semanticVarDecl(pkg *packages.Package, obj *types.Var) *ast.ValueSpec { - if pkg == nil { - return nil - } - return valueSpecForVar(pkg.Fset, pkg.Syntax, obj) -} - // varDecl finds the declaration that defines the given variable. func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 4ba2a30..7c7a3b7 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -22,8 +22,6 @@ import ( "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" - - "github.com/goforj/wire/internal/semanticcache" ) func TestFindInjectorBuildVariants(t *testing.T) { @@ -222,238 +220,6 @@ func TestProcessStructProviderDuplicateFields(t *testing.T) { } } -func TestSummarizeSemanticProviderSet(t *testing.T) { - t.Parallel() - - info := &types.Info{ - Uses: make(map[*ast.Ident]types.Object), - } - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireIdent := ast.NewIdent("wire") - newSetIdent := ast.NewIdent("NewSet") - info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) - info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) - - depPkg := types.NewPackage("example.com/dep", "dep") - fnIdent := ast.NewIdent("NewMessage") - info.Uses[fnIdent] = types.NewFunc(token.NoPos, depPkg, "NewMessage", types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depPkg, "", types.Typ[types.String])), false)) - - call := &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, - Args: []ast.Expr{ - fnIdent, - }, - } - got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") - if !ok { - t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") - } - if len(got.Items) != 1 { - t.Fatalf("items len = %d, want 1", len(got.Items)) - } - if got.Items[0].Kind != "func" || got.Items[0].ImportPath != "example.com/dep" || got.Items[0].Name != "NewMessage" { - t.Fatalf("unexpected item: %+v", got.Items[0]) - } -} - -func TestSummarizeSemanticProviderSetTypeOnlyForms(t *testing.T) { - t.Parallel() - - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Uses: make(map[*ast.Ident]types.Object), - } - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireIdent := ast.NewIdent("wire") - info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) - - appPkg := types.NewPackage("example.com/app", "app") - fooObj := types.NewTypeName(token.NoPos, appPkg, "Foo", nil) - fooNamed := types.NewNamed(fooObj, types.NewStruct([]*types.Var{ - types.NewVar(token.NoPos, appPkg, "Message", types.Typ[types.String]), - }, []string{""}), nil) - fooIfaceObj := types.NewTypeName(token.NoPos, appPkg, "Fooer", nil) - fooIface := types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) - - newSetIdent := ast.NewIdent("NewSet") - bindIdent := ast.NewIdent("Bind") - structIdent := ast.NewIdent("Struct") - fieldsIdent := ast.NewIdent("FieldsOf") - info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) - info.Uses[bindIdent] = types.NewFunc(token.NoPos, wirePkg, "Bind", nil) - info.Uses[structIdent] = types.NewFunc(token.NoPos, wirePkg, "Struct", nil) - info.Uses[fieldsIdent] = types.NewFunc(token.NoPos, wirePkg, "FieldsOf", nil) - - newFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Foo")}} - info.Types[newFooCall] = types.TypeAndValue{Type: types.NewPointer(fooNamed)} - newFooIfaceCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Fooer")}} - info.Types[newFooIfaceCall] = types.TypeAndValue{Type: types.NewPointer(fooIface)} - ptrToPtrFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("FooPtr")}} - info.Types[ptrToPtrFooCall] = types.TypeAndValue{Type: types.NewPointer(types.NewPointer(fooNamed))} - - call := &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, - Args: []ast.Expr{ - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: bindIdent}, Args: []ast.Expr{newFooIfaceCall, newFooCall}}, - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"*\""}}}, - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, - }, - } - if got, ok := summarizeSemanticProviderSet(info, call, "example.com/app"); ok || len(got.Items) != 0 { - t.Fatalf("summarizeSemanticProviderSet(bind case) = (%+v, %v), want unsupported", got, ok) - } - - call = &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, - Args: []ast.Expr{ - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"*\""}}}, - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, - }, - } - got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") - if !ok { - t.Fatal("summarizeSemanticProviderSet(non-bind type-only forms) = unsupported, want supported") - } - if len(got.Items) != 2 { - t.Fatalf("items len = %d, want 2", len(got.Items)) - } - if got.Items[0].Kind != "struct" || got.Items[1].Kind != "fields" { - t.Fatalf("unexpected kinds: %+v", got.Items) - } -} - -func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { - t.Parallel() - - fset := token.NewFileSet() - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) - wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) - - depTypes := types.NewPackage("example.com/dep", "dep") - msgFnSig := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depTypes, "", types.Typ[types.String])), false) - msgFn := types.NewFunc(token.NoPos, depTypes, "NewMessage", msgFnSig) - setVar := types.NewVar(token.NoPos, depTypes, "Set", wireNamed) - depTypes.Scope().Insert(msgFn) - depTypes.Scope().Insert(setVar) - - depPkg := &packages.Package{ - Name: "dep", - PkgPath: depTypes.Path(), - Types: depTypes, - Fset: fset, - Imports: make(map[string]*packages.Package), - } - oc := &objectCache{ - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, - fset: fset, - packages: map[string]*packages.Package{depPkg.PkgPath: depPkg}, - objects: make(map[objRef]objCacheEntry), - semantic: map[string]*semanticcache.PackageArtifact{ - depPkg.PkgPath: { - Version: 1, - PackagePath: depPkg.PkgPath, - PackageName: depPkg.Name, - Supported: true, - Vars: map[string]semanticcache.ProviderSetArtifact{ - "Set": { - Items: []semanticcache.ProviderSetItemArtifact{ - {Kind: "func", ImportPath: depPkg.PkgPath, Name: "NewMessage"}, - }, - }, - }, - }, - }, - hasher: typeutil.MakeHasher(), - } - item, errs := oc.get(setVar) - if len(errs) > 0 { - t.Fatalf("oc.get(Set) errs = %v", errs) - } - pset, ok := item.(*ProviderSet) - if !ok || pset == nil { - t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) - } - if len(pset.Providers) != 1 || pset.Providers[0].Name != "NewMessage" { - t.Fatalf("unexpected providers: %+v", pset.Providers) - } -} - -func TestObjectCacheSemanticProviderSetSkipsBindArtifacts(t *testing.T) { - t.Parallel() - - fset := token.NewFileSet() - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) - wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) - - appTypes := types.NewPackage("example.com/app", "app") - fooIfaceObj := types.NewTypeName(token.NoPos, appTypes, "Fooer", nil) - _ = types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) - fooObj := types.NewTypeName(token.NoPos, appTypes, "Foo", nil) - _ = types.NewNamed(fooObj, types.NewStruct([]*types.Var{ - types.NewVar(token.NoPos, appTypes, "Message", types.Typ[types.String]), - }, []string{""}), nil) - setVar := types.NewVar(token.NoPos, appTypes, "Set", wireNamed) - appTypes.Scope().Insert(fooIfaceObj) - appTypes.Scope().Insert(fooObj) - appTypes.Scope().Insert(setVar) - - appPkg := &packages.Package{ - Name: "app", - PkgPath: appTypes.Path(), - Types: appTypes, - Fset: fset, - Imports: make(map[string]*packages.Package), - } - oc := &objectCache{ - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, - fset: fset, - packages: map[string]*packages.Package{appPkg.PkgPath: appPkg}, - objects: make(map[objRef]objCacheEntry), - semantic: map[string]*semanticcache.PackageArtifact{ - appPkg.PkgPath: { - Version: 1, - PackagePath: appPkg.PkgPath, - PackageName: appPkg.Name, - Supported: true, - Vars: map[string]semanticcache.ProviderSetArtifact{ - "Set": { - Items: []semanticcache.ProviderSetItemArtifact{ - { - Kind: "bind", - Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Fooer"}, - Type2: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, - }, - { - Kind: "struct", - Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, - AllFields: true, - }, - { - Kind: "fields", - Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo", Pointer: 2}, - FieldNames: []string{"Message"}, - }, - }, - }, - }, - }, - }, - hasher: typeutil.MakeHasher(), - } - pset, ok, errs := oc.semanticProviderSet(setVar) - if len(errs) > 0 { - t.Fatalf("semanticProviderSet(Set) errs = %v", errs) - } - if ok { - t.Fatalf("semanticProviderSet(Set) ok = true, want false") - } - if pset != nil { - t.Fatalf("semanticProviderSet(Set) = %#v, want nil", pset) - } -} - func TestProcessFuncProviderErrors(t *testing.T) { t.Parallel() diff --git a/internal/wire/semantic_reconstruction_test.go b/internal/wire/semantic_reconstruction_test.go deleted file mode 100644 index 79cbe3c..0000000 --- a/internal/wire/semantic_reconstruction_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package wire - -import "testing" - -func TestSemanticReconstructionEnabled(t *testing.T) { - tests := []struct { - name string - env []string - want bool - }{ - { - name: "disabled by default", - want: false, - }, - { - name: "enabled by env", - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, - want: true, - }, - { - name: "disabled by env", - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=0"}, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := semanticReconstructionEnabled(tt.env); got != tt.want { - t.Fatalf("semanticReconstructionEnabled(%v) = %v, want %v", tt.env, got, tt.want) - } - }) - } -} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 2459723..1c44eba 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -125,7 +125,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") g := newGen(pkg) - oc := newObjectCacheWithEnv([]*packages.Package{pkg}, env) + oc := newObjectCache([]*packages.Package{pkg}) injectorStart := time.Now() injectorFiles, genErrs := generateInjectors(oc, g, pkg) logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) From 4af9edcf37168c3a130fc29e09e06cc62d48a70b Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 23:44:22 -0500 Subject: [PATCH 68/79] refactor: remove semantic cache layer --- cmd/wire/cache_cmd.go | 1 - cmd/wire/cache_cmd_test.go | 20 ++--- internal/loader/custom.go | 22 ----- internal/semanticcache/cache.go | 143 -------------------------------- 4 files changed, 8 insertions(+), 178 deletions(-) delete mode 100644 internal/semanticcache/cache.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go index cdbbd40..1bc4560 100644 --- a/cmd/wire/cache_cmd.go +++ b/cmd/wire/cache_cmd.go @@ -152,7 +152,6 @@ func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { targets := []cacheTarget{ {name: "loader-artifacts", path: envValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, {name: "discovery-cache", path: filepath.Join(baseWire, "discovery-cache")}, - {name: "semantic-artifacts", path: envValueDefault(env, semanticCacheDirEnv, filepath.Join(baseWire, "semantic-artifacts"))}, {name: "output-cache", path: envValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, } seen := make(map[string]bool, len(targets)) diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go index 83924e2..578c2aa 100644 --- a/cmd/wire/cache_cmd_test.go +++ b/cmd/wire/cache_cmd_test.go @@ -10,10 +10,9 @@ func TestWireCacheTargetsDefault(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") got := wireCacheTargets(nil, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), - "output-cache": filepath.Join(base, "wire", "output-cache"), - "semantic-artifacts": filepath.Join(base, "wire", "semantic-artifacts"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), } if len(got) != len(want) { t.Fatalf("targets len = %d, want %d", len(got), len(want)) @@ -46,14 +45,12 @@ func TestWireCacheTargetsRespectOverrides(t *testing.T) { env := []string{ loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), outputCacheDirEnv + "=" + filepath.Join(base, "output"), - semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), } got := wireCacheTargets(env, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "loader"), - "output-cache": filepath.Join(base, "output"), - "semantic-artifacts": filepath.Join(base, "semantic"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), } for _, target := range got { if target.path != want[target.name] { @@ -67,7 +64,6 @@ func TestClearWireCachesRemovesTargets(t *testing.T) { env := []string{ loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), outputCacheDirEnv + "=" + filepath.Join(base, "output"), - semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), } for _, target := range wireCacheTargets(env, base) { if err := os.MkdirAll(target.path, 0o755); err != nil { @@ -85,8 +81,8 @@ func TestClearWireCachesRemovesTargets(t *testing.T) { if err != nil { t.Fatalf("clearWireCaches() error = %v", err) } - if len(cleared) != 4 { - t.Fatalf("cleared len = %d, want 4", len(cleared)) + if len(cleared) != 3 { + t.Fatalf("cleared len = %d, want 3", len(cleared)) } for _, target := range wireCacheTargets(env, base) { if _, err := os.Stat(target.path); !os.IsNotExist(err) { diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 007eebd..6fa586b 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -34,8 +34,6 @@ import ( "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" - - "github.com/goforj/wire/internal/semanticcache" ) type unsupportedError struct { @@ -93,7 +91,6 @@ type customTypedGraphLoader struct { importer types.Importer loading map[string]bool isLocalCache map[string]bool - localSemanticOK map[string]bool artifactPrefetch map[string]artifactPrefetchEntry stats typedLoadStats } @@ -530,22 +527,6 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er return pkg, nil } -func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMeta) bool { - if meta == nil { - return false - } - if ok, exists := l.localSemanticOK[meta.ImportPath]; exists { - return ok - } - art, err := semanticcache.Read(l.env, meta.ImportPath, meta.Name, metaFiles(meta)) - if err != nil || art == nil { - l.localSemanticOK[meta.ImportPath] = false - return false - } - l.localSemanticOK[meta.ImportPath] = art.Supported - return art.Supported -} - func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { if !loaderArtifactEnabled(l.env) || isTarget { return artifactPolicy{} @@ -553,9 +534,7 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL policy := artifactPolicy{write: true} if !isLocal { policy.read = true - return policy } - policy.read = l.localSemanticArtifactSupported(meta) return policy } @@ -1185,7 +1164,6 @@ func newCustomTypedGraphLoader(ctx context.Context, wd string, env []string, fse importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), loading: make(map[string]bool, len(meta)), isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), stats: typedLoadStats{discovery: discoveryDuration}, } diff --git a/internal/semanticcache/cache.go b/internal/semanticcache/cache.go deleted file mode 100644 index 4442415..0000000 --- a/internal/semanticcache/cache.go +++ /dev/null @@ -1,143 +0,0 @@ -package semanticcache - -import ( - "crypto/sha256" - "encoding/gob" - "encoding/hex" - "os" - "path/filepath" - "runtime" - "strconv" -) - -const dirEnv = "WIRE_SEMANTIC_CACHE_DIR" - -type PackageArtifact struct { - Version int - PackagePath string - PackageName string - HasProviderSetVars bool - Supported bool - Vars map[string]ProviderSetArtifact -} - -type ProviderSetArtifact struct { - Items []ProviderSetItemArtifact -} - -type ProviderSetItemArtifact struct { - Kind string - ImportPath string - Name string - Type TypeRef - Type2 TypeRef - FieldNames []string - AllFields bool -} - -type TypeRef struct { - ImportPath string - Name string - Pointer int -} - -func ArtifactPath(env []string, importPath, packageName string, files []string) (string, error) { - dir, err := artifactDir(env) - if err != nil { - return "", err - } - key, err := artifactKey(importPath, packageName, files) - if err != nil { - return "", err - } - return filepath.Join(dir, key+".gob"), nil -} - -func Read(env []string, importPath, packageName string, files []string) (*PackageArtifact, error) { - path, err := ArtifactPath(env, importPath, packageName, files) - if err != nil { - return nil, err - } - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - var art PackageArtifact - if err := gob.NewDecoder(f).Decode(&art); err != nil { - return nil, err - } - return &art, nil -} - -func Write(env []string, importPath, packageName string, files []string, art *PackageArtifact) error { - path, err := ArtifactPath(env, importPath, packageName, files) - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - f, err := os.Create(path) - if err != nil { - return err - } - defer f.Close() - return gob.NewEncoder(f).Encode(art) -} - -func Exists(env []string, importPath, packageName string, files []string) bool { - path, err := ArtifactPath(env, importPath, packageName, files) - if err != nil { - return false - } - _, err = os.Stat(path) - return err == nil -} - -func artifactDir(env []string) (string, error) { - for i := len(env) - 1; i >= 0; i-- { - key, val, ok := splitEnv(env[i]) - if ok && key == dirEnv && val != "" { - return val, nil - } - } - base, err := os.UserCacheDir() - if err != nil { - return "", err - } - return filepath.Join(base, "wire", "semantic-artifacts"), nil -} - -func artifactKey(importPath, packageName string, files []string) (string, error) { - sum := sha256.New() - sum.Write([]byte("wire-semantic-artifact-v1\n")) - sum.Write([]byte(runtime.Version())) - sum.Write([]byte{'\n'}) - sum.Write([]byte(importPath)) - sum.Write([]byte{'\n'}) - sum.Write([]byte(packageName)) - sum.Write([]byte{'\n'}) - for _, name := range files { - info, err := os.Stat(name) - if err != nil { - return "", err - } - sum.Write([]byte(name)) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) - sum.Write([]byte{'\n'}) - } - return hex.EncodeToString(sum.Sum(nil)), nil -} - -func splitEnv(kv string) (string, string, bool) { - for i := 0; i < len(kv); i++ { - if kv[i] == '=' { - return kv[:i], kv[i+1:], true - } - } - return "", "", false -} From bf14bb5e5ea652e0d06ee6c632613074806bdf81 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 23:57:34 -0500 Subject: [PATCH 69/79] refactor: add import benchmark profile filter --- internal/wire/import_bench_test.go | 13 +++++++++++++ scripts/import-benchmarks.sh | 9 +++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index cd38190..ad2af50 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -21,6 +21,7 @@ const ( importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" + importBenchProfile = "WIRE_IMPORT_BENCH_PROFILE" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -145,6 +146,18 @@ func TestPrintImportScenarioBenchmarkTable(t *testing.T) { {localPkgs: 10, depPkgs: 25, external: true, label: "external"}, {localPkgs: 10, depPkgs: 100, external: true, label: "external"}, } + if filter := os.Getenv(importBenchProfile); filter != "" { + filtered := make([]appBenchProfile, 0, len(profiles)) + for _, profile := range profiles { + if profile.label == filter { + filtered = append(filtered, profile) + } + } + if len(filtered) == 0 { + t.Fatalf("%s=%q did not match any benchmark profile", importBenchProfile, filter) + } + profiles = filtered + } rows := make([]importBenchScenarioRow, 0, len(profiles)*6) for _, profile := range profiles { shapeFixture := createAppShapeBenchFixture(t, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot) diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index 2eb98c2..6c1ca47 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -12,12 +12,13 @@ usage() { cat <<'EOF' Usage: scripts/import-benchmarks.sh table - scripts/import-benchmarks.sh scenarios + scripts/import-benchmarks.sh scenarios [profile] scripts/import-benchmarks.sh breakdown Commands: table Print the 10/100/1000 import stock-vs-current benchmark table. scenarios Print the stock-vs-current change-type scenario table. + Optional profiles: local, local-high, external. breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } @@ -27,7 +28,11 @@ case "${1:-}" in WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v ;; scenarios) - WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + if [[ -n "${2:-}" ]]; then + WIRE_IMPORT_BENCH_SCENARIOS=1 WIRE_IMPORT_BENCH_PROFILE="${2}" go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + else + WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + fi ;; breakdown) WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v From bf4a02deb2bb7c342a7d3ee02bd2759e18cb32c9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:05:28 -0500 Subject: [PATCH 70/79] refactor: trim redundant discovery cache metadata --- internal/loader/discovery_cache.go | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 9d7d932..d891d01 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -15,11 +15,6 @@ import ( type discoveryCacheEntry struct { Version int - WD string - Tags string - Patterns []string - NeedDeps bool - Workspace string Meta map[string]*packageMeta Global []discoveryFileMeta LocalPkgs []discoveryLocalPackage @@ -49,6 +44,8 @@ type discoveryFileFingerprint struct { Hash string } +const discoveryCacheVersion = 3 + func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { entry, err := loadDiscoveryCacheEntry(req) if err != nil || entry == nil { @@ -63,13 +60,8 @@ func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ - Version: 3, - WD: canonicalLoaderPath(req.WD), - Tags: req.Tags, - Patterns: append([]string(nil), req.Patterns...), - NeedDeps: req.NeedDeps, - Workspace: workspace, - Meta: clonePackageMetaMap(meta), + Version: discoveryCacheVersion, + Meta: clonePackageMetaMap(meta), } global := []string{ filepath.Join(workspace, "go.mod"), @@ -108,7 +100,7 @@ func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) ( } func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { - if entry == nil || entry.Version != 3 { + if entry == nil || entry.Version != discoveryCacheVersion { return false } for _, fm := range entry.Global { @@ -142,7 +134,7 @@ func discoveryCachePath(req goListRequest) (string, error) { NeedDeps bool Go string }{ - Version: 3, + Version: discoveryCacheVersion, WD: canonicalLoaderPath(req.WD), Tags: req.Tags, Patterns: append([]string(nil), req.Patterns...), From 788798d1b127311871a4b467a99b8c4db16e6f90 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:06:38 -0500 Subject: [PATCH 71/79] refactor: remove redundant discovery cache cloning --- internal/loader/discovery_cache.go | 46 ++---------------------------- 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index d891d01..52a67c9 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -54,14 +54,14 @@ func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { if !validateDiscoveryCacheEntry(entry) { return nil, false } - return clonePackageMetaMap(entry.Meta), true + return entry.Meta, true } func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ Version: discoveryCacheVersion, - Meta: clonePackageMetaMap(meta), + Meta: meta, } global := []string{ filepath.Join(workspace, "go.mod"), @@ -285,45 +285,3 @@ func hashGob(v interface{}) (string, error) { sum := sha256.Sum256(buf.Bytes()) return hex.EncodeToString(sum[:]), nil } - -func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { - if len(in) == 0 { - return nil - } - out := make(map[string]*packageMeta, len(in)) - for k, v := range in { - if v == nil { - continue - } - cp := *v - cp.GoFiles = append([]string(nil), v.GoFiles...) - cp.CompiledGoFiles = append([]string(nil), v.CompiledGoFiles...) - cp.Imports = append([]string(nil), v.Imports...) - if v.ImportMap != nil { - cp.ImportMap = make(map[string]string, len(v.ImportMap)) - for mk, mv := range v.ImportMap { - cp.ImportMap[mk] = mv - } - } - if v.Module != nil { - cp.Module = cloneGoListModule(v.Module) - } - if v.Error != nil { - errCopy := *v.Error - cp.Error = &errCopy - } - out[k] = &cp - } - return out -} - -func cloneGoListModule(in *goListModule) *goListModule { - if in == nil { - return nil - } - cp := *in - if in.Replace != nil { - cp.Replace = cloneGoListModule(in.Replace) - } - return &cp -} From 0921fc2e6eb00b0b4d138d7f7608f3eceec08fbb Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:11:07 -0500 Subject: [PATCH 72/79] refactor: split external benchmark profiles --- internal/wire/import_bench_test.go | 4 ++-- scripts/import-benchmarks.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index ad2af50..d3f8ed7 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -143,8 +143,8 @@ func TestPrintImportScenarioBenchmarkTable(t *testing.T) { profiles := []appBenchProfile{ {localPkgs: 10, depPkgs: 25, label: "local"}, {localPkgs: 10, depPkgs: 1000, label: "local-high"}, - {localPkgs: 10, depPkgs: 25, external: true, label: "external"}, - {localPkgs: 10, depPkgs: 100, external: true, label: "external"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, } if filter := os.Getenv(importBenchProfile); filter != "" { filtered := make([]appBenchProfile, 0, len(profiles)) diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index 6c1ca47..232ccd9 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -18,7 +18,7 @@ Usage: Commands: table Print the 10/100/1000 import stock-vs-current benchmark table. scenarios Print the stock-vs-current change-type scenario table. - Optional profiles: local, local-high, external. + Optional profiles: local, local-high, external-low, external-high. breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } From e52201b03a4817e85fbe39cfb7f392f6a5471c4c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:18:57 -0500 Subject: [PATCH 73/79] refactor: share custom typed load pipeline --- internal/loader/custom.go | 44 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 6fa586b..76f86ba 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -235,19 +235,10 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } } discoveryDuration := time.Since(discoveryStart) - if len(meta) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } - fset := req.Fset - if fset == nil { - fset = token.NewFileSet() - } - l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, map[string]struct{}{req.Package: {}}, req.ParseFile, discoveryDuration) - roots, err := loadCustomRootPackages(l, []string{req.Package}) + roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, map[string]struct{}{req.Package: {}}, []string{req.Package}, req.ParseFile, discoveryDuration, "lazy") if err != nil { return nil, err } - logTypedLoadStats(ctx, "lazy", l.stats) return &LazyLoadResult{ Packages: roots, Backend: ModeCustom, @@ -267,33 +258,40 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, err } discoveryDuration := time.Since(discoveryStart) - if len(meta) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } - fset := req.Fset - if fset == nil { - fset = token.NewFileSet() - } rootPaths := nonDepRootImportPaths(meta) targets := make(map[string]struct{}, len(rootPaths)) for _, path := range rootPaths { targets[path] = struct{}{} } - if len(targets) == 0 { - return nil, unsupportedError{reason: "no root packages from metadata"} - } - l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) - roots, err := loadCustomRootPackages(l, rootPaths) + roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, targets, rootPaths, req.ParseFile, discoveryDuration, "typed") if err != nil { return nil, err } - logTypedLoadStats(ctx, "typed", l.stats) return &PackageLoadResult{ Packages: roots, Backend: ModeCustom, }, nil } +func loadCustomPackagesFromMeta(ctx context.Context, wd string, env []string, fset *token.FileSet, meta map[string]*packageMeta, targets map[string]struct{}, rootPaths []string, parseFile ParseFileFunc, discoveryDuration time.Duration, mode string) ([]*packages.Package, error) { + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + if len(rootPaths) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + if fset == nil { + fset = token.NewFileSet() + } + l := newCustomTypedGraphLoader(ctx, wd, env, fset, meta, targets, parseFile, discoveryDuration) + roots, err := loadCustomRootPackages(l, rootPaths) + if err != nil { + return nil, err + } + logTypedLoadStats(ctx, mode, l.stats) + return roots, nil +} + func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { prefetchStart := time.Now() l.prefetchArtifacts() From 2dc4ac4e5fd4b75b70a70a9e5ade047278440c83 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:21:41 -0500 Subject: [PATCH 74/79] refactor: centralize custom metadata loading --- internal/loader/custom.go | 61 +++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 76f86ba..fce84f4 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -173,8 +173,7 @@ func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationReq } func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { - discoveryStart := time.Now() - meta, err := runGoList(ctx, goListRequest{ + meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ WD: req.WD, Env: req.Env, Tags: req.Tags, @@ -184,10 +183,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes if err != nil { return nil, err } - logTiming(ctx, "loader.custom.root.discovery", discoveryStart) - if len(meta) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } + logDuration(ctx, "loader.custom.root.discovery", discoveryDuration) pkgs := packageStubGraphFromMeta(nil, meta) rootPaths := nonDepRootImportPaths(meta) roots := make([]*packages.Package, 0, len(rootPaths)) @@ -219,22 +215,10 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz meta map[string]*packageMeta err error ) - discoveryStart := time.Now() - if req.Discovery != nil && len(req.Discovery.meta) > 0 { - meta = req.Discovery.meta - } else { - meta, err = runGoList(ctx, goListRequest{ - WD: req.WD, - Env: req.Env, - Tags: req.Tags, - Patterns: []string{req.Package}, - NeedDeps: true, - }) - if err != nil { - return nil, err - } + meta, discoveryDuration, err := loadCustomLazyMeta(ctx, req) + if err != nil { + return nil, err } - discoveryDuration := time.Since(discoveryStart) roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, map[string]struct{}{req.Package: {}}, []string{req.Package}, req.ParseFile, discoveryDuration, "lazy") if err != nil { return nil, err @@ -246,8 +230,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { - discoveryStart := time.Now() - meta, err := runGoList(ctx, goListRequest{ + meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ WD: req.WD, Env: req.Env, Tags: req.Tags, @@ -257,7 +240,6 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if err != nil { return nil, err } - discoveryDuration := time.Since(discoveryStart) rootPaths := nonDepRootImportPaths(meta) targets := make(map[string]struct{}, len(rootPaths)) for _, path := range rootPaths { @@ -292,6 +274,32 @@ func loadCustomPackagesFromMeta(ctx context.Context, wd string, env []string, fs return roots, nil } +func loadCustomMeta(ctx context.Context, req goListRequest) (map[string]*packageMeta, time.Duration, error) { + start := time.Now() + meta, err := runGoList(ctx, req) + duration := time.Since(start) + if err != nil { + return nil, duration, err + } + if len(meta) == 0 { + return nil, duration, unsupportedError{reason: "empty go list result"} + } + return meta, duration, nil +} + +func loadCustomLazyMeta(ctx context.Context, req LazyLoadRequest) (map[string]*packageMeta, time.Duration, error) { + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + return req.Discovery.meta, 0, nil + } + return loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: []string{req.Package}, + NeedDeps: true, + }) +} + func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { prefetchStart := time.Now() l.prefetchArtifacts() @@ -994,7 +1002,7 @@ func importName(spec *ast.ImportSpec) string { } func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) (map[string]*packageMeta, error) { - metas, err := runGoList(ctx, goListRequest{ + metas, _, err := loadCustomMeta(ctx, goListRequest{ WD: req.WD, Env: req.Env, Tags: req.Tags, @@ -1004,9 +1012,6 @@ func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) if err != nil { return nil, err } - if len(metas) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } for _, touched := range req.Touched { if _, ok := metas[touched]; !ok { return nil, unsupportedError{reason: "missing touched package in metadata"} From 6f0966483067e91a987ba1f5f8b3df786fbdf8c3 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:25:39 -0500 Subject: [PATCH 75/79] refactor: share loader fallback reason policy --- internal/loader/fallback.go | 47 ++++++++++--------------------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/internal/loader/fallback.go b/internal/loader/fallback.go index 513694c..860bd50 100644 --- a/internal/loader/fallback.go +++ b/internal/loader/fallback.go @@ -24,6 +24,15 @@ import ( type defaultLoader struct{} +func fallbackReasonDetail(mode Mode, detail string) (FallbackReason, string) { + switch mode { + case ModeFallback: + return FallbackReasonForcedFallback, "" + default: + return FallbackReasonCustomUnsupported, detail + } +} + func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { var unsupported unsupportedError if req.LoaderMode != ModeFallback { @@ -38,15 +47,7 @@ func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) ( result := &PackageLoadResult{ Backend: ModeFallback, } - switch req.LoaderMode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - if unsupported.reason != "" { - result.FallbackDetail = unsupported.reason - } - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.LoaderMode, unsupported.reason) cfg := &packages.Config{ Context: ctx, Mode: req.Mode, @@ -90,15 +91,7 @@ func (defaultLoader) LoadRootGraph(ctx context.Context, req RootLoadRequest) (*R result := &RootLoadResult{ Backend: ModeFallback, } - switch req.Mode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - if unsupported.reason != "" { - result.FallbackDetail = unsupported.reason - } - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.Mode, unsupported.reason) cfg := &packages.Config{ Context: ctx, Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports, @@ -142,15 +135,7 @@ func (defaultLoader) LoadTypedPackageGraph(ctx context.Context, req LazyLoadRequ result := &LazyLoadResult{ Backend: ModeFallback, } - switch req.LoaderMode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - if unsupported.reason != "" { - result.FallbackDetail = unsupported.reason - } - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.LoaderMode, unsupported.reason) cfg := &packages.Config{ Context: ctx, Mode: req.Mode, @@ -194,13 +179,7 @@ func validateTouchedPackagesFallback(ctx context.Context, req TouchedValidationR result := &TouchedValidationResult{ Backend: ModeFallback, } - switch req.Mode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - result.FallbackDetail = detail - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.Mode, detail) if len(req.Touched) == 0 { return result, nil } From d003c9f88f447f922979a08406d6d67b2a1b31ed Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:45:14 -0500 Subject: [PATCH 76/79] refactor: add targeted local profile benchmark --- internal/wire/import_bench_test.go | 80 ++++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index d3f8ed7..24c58da 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -266,6 +266,25 @@ func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { printScenarioTimingLines(currentOutput) } +func BenchmarkCurrentWireLocalProfile(b *testing.B) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + b.Fatal(err) + } + currentBin := buildWireBinary(b, repoRoot, "current-wire") + const ( + features = 10 + depPkgs = 25 + external = false + ) + + for _, variant := range []string{"unchanged", "body", "shape", "import", "known-toggle"} { + b.Run(variant, func(b *testing.B) { + benchmarkCurrentWireAppScenario(b, currentBin, repoRoot, features, depPkgs, external, variant) + }) + } +} + func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) @@ -278,6 +297,35 @@ func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external return durations } +func benchmarkCurrentWireAppScenario(b *testing.B, bin, repoRoot string, features, depPkgs int, external bool, variant string) { + b.Helper() + pkgDir := createAppShapeBenchFixture(b, features, depPkgs, external, currentWireModulePath, repoRoot) + caches := newBenchCaches(b) + root := filepath.Dir(pkgDir) + prewarmGoBenchCache(b, pkgDir, caches) + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + resetAppShapeBenchFixture(b, pkgDir, features) + switch variant { + case "body", "shape", "import": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, variant) + case "known-toggle": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, "shape") + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, "base") + case "unchanged": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + default: + b.Fatalf("unknown benchmark variant %q", variant) + } + b.StartTimer() + _ = runWireBenchCommand(b, bin, pkgDir, caches) + } +} + func runAppWarmTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) @@ -326,7 +374,7 @@ func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, ex return durations } -func buildWireBinary(t *testing.T, dir, name string) string { +func buildWireBinary(t testing.TB, dir, name string) string { t.Helper() if runtime.GOOS == "windows" && filepath.Ext(name) != ".exe" { name += ".exe" @@ -342,7 +390,7 @@ func buildWireBinary(t *testing.T, dir, name string) string { return out } -func newBenchCaches(t *testing.T) benchCaches { +func newBenchCaches(t testing.TB) benchCaches { t.Helper() return benchCaches{ home: t.TempDir(), @@ -350,7 +398,7 @@ func newBenchCaches(t *testing.T) benchCaches { } } -func extractStockWire(t *testing.T, repoRoot, commit string) string { +func extractStockWire(t testing.TB, repoRoot, commit string) string { t.Helper() tmp := t.TempDir() cmd := exec.Command("git", "archive", "--format=tar", commit) @@ -400,7 +448,7 @@ func extractStockWire(t *testing.T, repoRoot, commit string) string { return tmp } -func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireReplaceDir string) string { +func createImportBenchFixture(t testing.TB, imports int, wireModulePath, wireReplaceDir string) string { t.Helper() root := t.TempDir() if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(importBenchGoMod(wireModulePath, wireReplaceDir)), 0o644); err != nil { @@ -424,7 +472,7 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep return filepath.Join(root, "app") } -func createAppShapeBenchFixture(t *testing.T, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { +func createAppShapeBenchFixture(t testing.TB, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { t.Helper() root := t.TempDir() modulePath := "example.com/appbench" @@ -455,7 +503,7 @@ func createAppShapeBenchFixture(t *testing.T, features, depPkgs int, external bo return filepath.Join(root, "wire") } -func writeAppShapeFile(t *testing.T, path, content string) { +func writeAppShapeFile(t testing.TB, path, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { t.Fatal(err) @@ -465,7 +513,7 @@ func writeAppShapeFile(t *testing.T, path, content string) { } } -func writeAppShapeControllerFile(t *testing.T, root string, index int, variant string) { +func writeAppShapeControllerFile(t testing.TB, root string, index int, variant string) { t.Helper() path := filepath.Join(root, "internal", fmt.Sprintf("feature%04d", index), "controller.go") if err := os.WriteFile(path, []byte(appShapeControllerFile("example.com/appbench", index, variant)), 0o644); err != nil { @@ -473,7 +521,7 @@ func writeAppShapeControllerFile(t *testing.T, root string, index int, variant s } } -func seedAppShapeExternalGoSum(t *testing.T, root string) { +func seedAppShapeExternalGoSum(t testing.TB, root string) { t.Helper() const source = "/private/tmp/test/go.sum" data, err := os.ReadFile(source) @@ -485,7 +533,7 @@ func seedAppShapeExternalGoSum(t *testing.T, root string) { } } -func resetAppShapeBenchFixture(t *testing.T, pkgDir string, features int) { +func resetAppShapeBenchFixture(t testing.TB, pkgDir string, features int) { t.Helper() root := filepath.Dir(pkgDir) for i := 0; i < features; i++ { @@ -1019,13 +1067,13 @@ func runKnownImportToggleTrials(t *testing.T, bin string, imports int, wireModul return durations } -func runWireBenchCommand(t *testing.T, bin, pkgDir string, caches benchCaches) time.Duration { +func runWireBenchCommand(t testing.TB, bin, pkgDir string, caches benchCaches) time.Duration { t.Helper() d, _ := runWireBenchCommandOutput(t, bin, pkgDir, caches) return d } -func runWireBenchCommandOutput(t *testing.T, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { +func runWireBenchCommandOutput(t testing.TB, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { t.Helper() args := []string{"gen"} args = append(args, extraArgs...) @@ -1042,7 +1090,7 @@ func runWireBenchCommandOutput(t *testing.T, bin, pkgDir string, caches benchCac return time.Since(start), stderr.String() } -func prewarmGoBenchCache(t *testing.T, pkgDir string, caches benchCaches) { +func prewarmGoBenchCache(t testing.TB, pkgDir string, caches benchCaches) { t.Helper() prepareBenchModule(t, pkgDir, caches) cmd := exec.Command("go", "list", "-deps", "./...") @@ -1054,7 +1102,7 @@ func prewarmGoBenchCache(t *testing.T, pkgDir string, caches benchCaches) { } } -func goListGraphCounts(t *testing.T, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { +func goListGraphCounts(t testing.TB, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { t.Helper() prepareBenchModule(t, pkgDir, caches) cmd := exec.Command("go", "list", "-deps", "-json", "./...") @@ -1097,7 +1145,7 @@ func goListGraphCounts(t *testing.T, pkgDir, modulePath string, caches benchCach return counts } -func prepareBenchModule(t *testing.T, pkgDir string, caches benchCaches) { +func prepareBenchModule(t testing.TB, pkgDir string, caches benchCaches) { t.Helper() marker := filepath.Join(filepath.Dir(pkgDir), ".bench-module-ready") if _, err := os.Stat(marker); err == nil { @@ -1216,7 +1264,7 @@ func importBenchDepFile(i int, variant string) string { } } -func writeImportBenchWireFile(t *testing.T, root string, imports int, wireModulePath string) { +func writeImportBenchWireFile(t testing.TB, root string, imports int, wireModulePath string) { t.Helper() path := filepath.Join(root, "app", "wire.go") if err := os.WriteFile(path, []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { @@ -1224,7 +1272,7 @@ func writeImportBenchWireFile(t *testing.T, root string, imports int, wireModule } } -func writeImportBenchDepFile(t *testing.T, root string, index int, variant string) { +func writeImportBenchDepFile(t testing.TB, root string, index int, variant string) { t.Helper() path := filepath.Join(root, fmt.Sprintf("dep%04d", index), "dep.go") if err := os.WriteFile(path, []byte(importBenchDepFile(index, variant)), 0o644); err != nil { From a1125656097b1724625732ec13a268209acb337d Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:49:59 -0500 Subject: [PATCH 77/79] refactor: add one-shot import profile harness --- internal/wire/import_bench_test.go | 96 ++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index 24c58da..39dd862 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -22,6 +22,9 @@ const ( importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" importBenchProfile = "WIRE_IMPORT_BENCH_PROFILE" + importBenchProfileRun = "WIRE_IMPORT_BENCH_PROFILE_RUN" + importBenchVariant = "WIRE_IMPORT_BENCH_VARIANT" + importBenchCPUProfile = "WIRE_IMPORT_BENCH_CPU_PROFILE" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -134,18 +137,7 @@ func TestPrintImportScenarioBenchmarkTable(t *testing.T) { stockDir := extractStockWire(t, repoRoot, stockWireCommit) stockBin := buildWireBinary(t, stockDir, "stock-wire") - type appBenchProfile struct { - localPkgs int - depPkgs int - external bool - label string - } - profiles := []appBenchProfile{ - {localPkgs: 10, depPkgs: 25, label: "local"}, - {localPkgs: 10, depPkgs: 1000, label: "local-high"}, - {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, - {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, - } + profiles := importBenchAppProfiles() if filter := os.Getenv(importBenchProfile); filter != "" { filtered := make([]appBenchProfile, 0, len(profiles)) for _, profile := range profiles { @@ -266,6 +258,61 @@ func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { printScenarioTimingLines(currentOutput) } +func TestProfileCurrentWireScenarioRun(t *testing.T) { + if os.Getenv(importBenchProfileRun) != "1" { + t.Skipf("%s not set", importBenchProfileRun) + } + profile := os.Getenv(importBenchProfile) + variant := os.Getenv(importBenchVariant) + cpuProfile := os.Getenv(importBenchCPUProfile) + if profile == "" { + t.Fatalf("%s must be set", importBenchProfile) + } + if variant == "" { + t.Fatalf("%s must be set", importBenchVariant) + } + if cpuProfile == "" { + t.Fatalf("%s must be set", importBenchCPUProfile) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + profileCfg, err := importBenchAppProfile(profile) + if err != nil { + t.Fatal(err) + } + pkgDir := createAppShapeBenchFixture(t, profileCfg.localPkgs, profileCfg.depPkgs, profileCfg.external, currentWireModulePath, repoRoot) + caches := newBenchCaches(t) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + + switch variant { + case "unchanged": + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + case "body", "shape", "import": + resetAppShapeBenchFixture(t, pkgDir, profileCfg.localPkgs) + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, variant) + case "known-toggle": + resetAppShapeBenchFixture(t, pkgDir, profileCfg.localPkgs) + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "shape") + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "base") + default: + t.Fatalf("unknown %s %q", importBenchVariant, variant) + } + + dur, output := runWireBenchCommandOutput(t, currentBin, pkgDir, caches, "-cpuprofile="+cpuProfile, "-timings") + fmt.Printf("profile: %s\n", profile) + fmt.Printf("variant: %s\n", variant) + fmt.Printf("duration: %s\n", formatMs(dur)) + fmt.Printf("cpuprofile: %s\n", cpuProfile) + printScenarioTimingLines(output) +} + func BenchmarkCurrentWireLocalProfile(b *testing.B) { repoRoot, err := importBenchRepoRoot() if err != nil { @@ -285,6 +332,31 @@ func BenchmarkCurrentWireLocalProfile(b *testing.B) { } } +type appBenchProfile struct { + localPkgs int + depPkgs int + external bool + label string +} + +func importBenchAppProfiles() []appBenchProfile { + return []appBenchProfile{ + {localPkgs: 10, depPkgs: 25, label: "local"}, + {localPkgs: 10, depPkgs: 1000, label: "local-high"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, + } +} + +func importBenchAppProfile(label string) (appBenchProfile, error) { + for _, profile := range importBenchAppProfiles() { + if profile.label == label { + return profile, nil + } + } + return appBenchProfile{}, fmt.Errorf("%s=%q did not match any benchmark profile", importBenchProfile, label) +} + func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) From ee3ffc9890c2db874dce6e0c116d7f328c075247 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 01:56:00 -0500 Subject: [PATCH 78/79] perf: reuse root discovery for generate loads --- internal/loader/custom.go | 24 +++++++++++++++++------- internal/loader/discovery.go | 6 +++++- internal/loader/discovery_cache.go | 14 ++++++++------ internal/loader/loader.go | 1 + internal/wire/output_cache.go | 16 ++++++++-------- internal/wire/parse.go | 5 +++-- internal/wire/wire.go | 4 ++-- 7 files changed, 44 insertions(+), 26 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index fce84f4..112fe85 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -179,6 +179,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes Tags: req.Tags, Patterns: req.Patterns, NeedDeps: req.NeedDeps, + SkipCompiled: true, }) if err != nil { return nil, err @@ -230,13 +231,22 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { - meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ - WD: req.WD, - Env: req.Env, - Tags: req.Tags, - Patterns: req.Patterns, - NeedDeps: true, - }) + var ( + meta map[string]*packageMeta + discoveryDuration time.Duration + err error + ) + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + meta = req.Discovery.meta + } else { + meta, discoveryDuration, err = loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: true, + }) + } if err != nil { return nil, err } diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 0e7e69c..e416e95 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -32,6 +32,7 @@ type goListRequest struct { Tags string Patterns []string NeedDeps bool + SkipCompiled bool } func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { @@ -46,7 +47,10 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, return cached, nil } logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) - args := []string{"list", "-json", "-e", "-compiled", "-export"} + args := []string{"list", "-json", "-e", "-export"} + if !req.SkipCompiled { + args = append(args, "-compiled") + } if req.NeedDeps { args = append(args, "-deps") } diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 52a67c9..1a93bdf 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -132,14 +132,16 @@ func discoveryCachePath(req goListRequest) (string, error) { Tags string Patterns []string NeedDeps bool + SkipCompiled bool Go string }{ - Version: discoveryCacheVersion, - WD: canonicalLoaderPath(req.WD), - Tags: req.Tags, - Patterns: append([]string(nil), req.Patterns...), - NeedDeps: req.NeedDeps, - Go: runtime.Version(), + Version: discoveryCacheVersion, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + SkipCompiled: req.SkipCompiled, + Go: runtime.Version(), } key, err := hashGob(sumReq) if err != nil { diff --git a/internal/loader/loader.go b/internal/loader/loader.go index e26747b..a507758 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -93,6 +93,7 @@ type PackageLoadRequest struct { LoaderMode Mode Fset *token.FileSet ParseFile ParseFileFunc + Discovery *DiscoverySnapshot } type PackageLoadResult struct { diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index b95a514..bd2bc8b 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -32,10 +32,10 @@ type outputCacheCandidate struct { outputPath string } -func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, bool) { +func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, *loader.DiscoverySnapshot, bool) { if !outputCacheEnabled(ctx, wd, env) { debugf(ctx, "generate.output_cache=disabled") - return nil, nil, false + return nil, nil, nil, false } rootResult, err := loader.New().LoadRootGraph(withLoaderTiming(ctx), loader.RootLoadRequest{ WD: wd, @@ -51,7 +51,7 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa } else { debugf(ctx, "generate.output_cache=no_roots") } - return nil, nil, false + return nil, nil, nil, false } candidates := make(map[string]outputCacheCandidate, len(rootResult.Packages)) results := make([]GenerateResult, 0, len(rootResult.Packages)) @@ -59,17 +59,17 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa outDir, err := detectOutputDir(pkg.GoFiles) if err != nil { debugf(ctx, "generate.output_cache=bad_output_dir") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } key, err := outputCacheKey(wd, opts, pkg) if err != nil { debugf(ctx, "generate.output_cache=key_error") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } path, err := outputCachePath(env, key) if err != nil { debugf(ctx, "generate.output_cache=path_error") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } candidates[pkg.PkgPath] = outputCacheCandidate{ path: path, @@ -78,7 +78,7 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa entry, ok := readOutputCache(path) if !ok { debugf(ctx, "generate.output_cache=miss") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } results = append(results, GenerateResult{ PkgPath: pkg.PkgPath, @@ -87,7 +87,7 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa }) } debugf(ctx, "generate.output_cache=hit") - return candidates, results, len(results) == len(rootResult.Packages) + return candidates, results, rootResult.Discovery, len(results) == len(rootResult.Packages) } func writeGenerateOutputCache(candidates map[string]outputCacheCandidate, generated []GenerateResult) { diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 4350baa..2e9c428 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -254,7 +254,7 @@ type Field struct { // takes precedence. func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { loadStart := time.Now() - pkgs, errs := load(ctx, wd, env, tags, patterns) + pkgs, errs := load(ctx, wd, env, tags, patterns, nil) logTiming(ctx, "load.packages", loadStart) if len(errs) > 0 { return nil, errs @@ -361,7 +361,7 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, []error) { +func load(ctx context.Context, wd string, env []string, tags string, patterns []string, discovery *loader.DiscoverySnapshot) ([]*packages.Package, []error) { fset := token.NewFileSet() loaderMode := effectiveLoaderMode(ctx, wd, env) parseStats := &parseFileStats{} @@ -374,6 +374,7 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] Mode: packages.LoadAllSyntax, LoaderMode: loaderMode, Fset: fset, + Discovery: discovery, ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { start := time.Now() file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 1c44eba..9f5bb9e 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -102,12 +102,12 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } - cacheCandidates, cached, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) + cacheCandidates, cached, discovery, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) if ok { return cached, nil } loadStart := time.Now() - pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) + pkgs, errs := load(ctx, wd, env, opts.Tags, patterns, discovery) logTiming(ctx, "generate.load", loadStart) if len(errs) > 0 { return nil, errs From cf528798871c8089e1d4dec084900115bfdcf035 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 02:50:43 -0500 Subject: [PATCH 79/79] style: format loader discovery changes --- internal/loader/custom.go | 10 +++++----- internal/loader/discovery.go | 10 +++++----- internal/loader/discovery_cache.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 112fe85..b4532af 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -174,11 +174,11 @@ func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationReq func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ - WD: req.WD, - Env: req.Env, - Tags: req.Tags, - Patterns: req.Patterns, - NeedDeps: req.NeedDeps, + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: req.NeedDeps, SkipCompiled: true, }) if err != nil { diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index e416e95..bccfd93 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -27,11 +27,11 @@ import ( ) type goListRequest struct { - WD string - Env []string - Tags string - Patterns []string - NeedDeps bool + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool SkipCompiled bool } diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 1a93bdf..1151853 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -127,13 +127,13 @@ func discoveryCachePath(req goListRequest) (string, error) { return "", err } sumReq := struct { - Version int - WD string - Tags string - Patterns []string - NeedDeps bool + Version int + WD string + Tags string + Patterns []string + NeedDeps bool SkipCompiled bool - Go string + Go string }{ Version: discoveryCacheVersion, WD: canonicalLoaderPath(req.WD),