diff --git a/.gitignore b/.gitignore index 94f3e27c93..37291bad15 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ dist/ # for VS Code debug sessions **/__debug_bin* **/ginkgo.report +.history/ \ No newline at end of file diff --git a/cmd/edit/machinepool/cmd.go b/cmd/edit/machinepool/cmd.go index 56e25a7376..b0b6d19ea8 100644 --- a/cmd/edit/machinepool/cmd.go +++ b/cmd/edit/machinepool/cmd.go @@ -52,7 +52,7 @@ func NewEditMachinePoolCommand() *cobra.Command { Long: long, Aliases: aliases, Example: example, - Args: machinepool.NewMachinepoolArgsFunction(false), + Args: cobra.MaximumNArgs(1), Run: rosa.DefaultRunner(rosa.RuntimeWithOCM(), EditMachinePoolRunner(options)), } diff --git a/cmd/rosa/main.go b/cmd/rosa/main.go index 4160f38eb2..f2e4e3a8a9 100644 --- a/cmd/rosa/main.go +++ b/cmd/rosa/main.go @@ -21,6 +21,7 @@ import ( "os" "strings" + cobra_mcp "github.com/paulczar/cobra-mcp/pkg" "github.com/spf13/cobra" "github.com/openshift/rosa/pkg/arguments" @@ -49,6 +50,55 @@ func init() { // Register the subcommands: commands.RegisterCommands(root) + + // Add MCP support with sub-process execution mode + serverConfig := &cobra_mcp.ServerConfig{ + ToolPrefix: "rosa", + ExecutionMode: "sub-process", + EnableResources: true, + } + mcpCmd := cobra_mcp.NewMCPCommand(root, serverConfig) + mcpCmd.Args = cobra.NoArgs + // Set Args for mcp subcommands + for _, subCmd := range mcpCmd.Commands() { + subCmd.Args = cobra.NoArgs + // Convert RunE to Run for test compatibility + if subCmd.RunE != nil { + runE := subCmd.RunE + subCmd.RunE = nil + subCmd.Run = func(cmd *cobra.Command, args []string) { + if err := runE(cmd, args); err != nil { + cmd.PrintErrln(err) + os.Exit(1) + } + } + } + } + root.AddCommand(mcpCmd) + + // Add Chat support + chatConfig := &cobra_mcp.ChatConfig{ + Model: "gpt-5-mini", + Debug: false, + } + chatCmd := cobra_mcp.NewChatCommand(root, chatConfig, serverConfig) + chatCmd.Args = cobra.NoArgs + // Set Args and Run for system-message subcommand + 
if systemMsgCmd := chatCmd.Commands()[0]; systemMsgCmd != nil && systemMsgCmd.Name() == "system-message" { + systemMsgCmd.Args = cobra.NoArgs + // Convert RunE to Run for test compatibility + if systemMsgCmd.RunE != nil { + runE := systemMsgCmd.RunE + systemMsgCmd.RunE = nil + systemMsgCmd.Run = func(cmd *cobra.Command, args []string) { + if err := runE(cmd, args); err != nil { + cmd.PrintErrln(err) + os.Exit(1) + } + } + } + } + root.AddCommand(chatCmd) } func main() { diff --git a/cmd/rosa/structure_test/command_args/rosa/chat/command_args.yml b/cmd/rosa/structure_test/command_args/rosa/chat/command_args.yml new file mode 100644 index 0000000000..37fa4a9071 --- /dev/null +++ b/cmd/rosa/structure_test/command_args/rosa/chat/command_args.yml @@ -0,0 +1,7 @@ +- name: api-key +- name: api-url +- name: debug +- name: message +- name: model +- name: stdin +- name: system-message-file diff --git a/cmd/rosa/structure_test/command_args/rosa/chat/system-message/command_args.yml b/cmd/rosa/structure_test/command_args/rosa/chat/system-message/command_args.yml new file mode 100644 index 0000000000..c78c243d6d --- /dev/null +++ b/cmd/rosa/structure_test/command_args/rosa/chat/system-message/command_args.yml @@ -0,0 +1 @@ +- name: system-message-file diff --git a/cmd/rosa/structure_test/command_args/rosa/mcp/start/command_args.yml b/cmd/rosa/structure_test/command_args/rosa/mcp/start/command_args.yml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cmd/rosa/structure_test/command_args/rosa/mcp/stream/command_args.yml b/cmd/rosa/structure_test/command_args/rosa/mcp/stream/command_args.yml new file mode 100644 index 0000000000..7105c23ed8 --- /dev/null +++ b/cmd/rosa/structure_test/command_args/rosa/mcp/stream/command_args.yml @@ -0,0 +1 @@ +- name: port diff --git a/cmd/rosa/structure_test/command_args/rosa/mcp/tools/command_args.yml b/cmd/rosa/structure_test/command_args/rosa/mcp/tools/command_args.yml new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/cmd/rosa/structure_test/command_structure.yml b/cmd/rosa/structure_test/command_structure.yml index 76cf276be8..34ccee098d 100755 --- a/cmd/rosa/structure_test/command_structure.yml +++ b/cmd/rosa/structure_test/command_structure.yml @@ -6,6 +6,9 @@ # name: rosa children: +- name: chat + children: + - name: system-message - name: completion - name: config children: @@ -139,6 +142,11 @@ children: - name: versions - name: login - name: logout +- name: mcp + children: + - name: start + - name: stream + - name: tools - name: logs children: - name: install diff --git a/go.mod b/go.mod index c8e5f25fd0..3aaeecf36a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/openshift/rosa -go 1.24.0 +go 1.25.3 require ( github.com/AlecAivazis/survey/v2 v2.2.15 @@ -32,11 +32,12 @@ require ( github.com/openshift-online/ocm-api-model/clientapi v0.0.437 github.com/openshift-online/ocm-common v0.0.31 github.com/openshift-online/ocm-sdk-go v0.1.482 + github.com/paulczar/cobra-mcp v1.3.0 github.com/pkg/errors v0.9.1 github.com/robfig/cron/v3 v3.0.1 github.com/sirupsen/logrus v1.9.3 - github.com/spf13/cobra v1.8.0 - github.com/spf13/pflag v1.0.5 + github.com/spf13/cobra v1.10.1 + github.com/spf13/pflag v1.0.9 github.com/zgalor/weberr v0.6.0 gitlab.com/c0b/go-ordered-json v0.0.0-20201030195603-febf46534d5a go.uber.org/mock v0.5.2 @@ -50,8 +51,16 @@ require ( github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.34.0 // indirect github.com/aws/aws-sdk-go-v2/service/ram v1.26.1 // indirect github.com/go-jose/go-jose/v4 v4.0.2 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect + github.com/modelcontextprotocol/go-sdk v1.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/openai/openai-go v1.12.0 // indirect github.com/openshift-online/ocm-api-model/model v0.0.437 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 
// indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect ) @@ -80,7 +89,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect github.com/danieljoos/wincred v1.2.0 // indirect github.com/dvsekhvalnov/jose2go v1.6.0 // indirect github.com/evanphx/json-patch/v5 v5.6.0 // indirect diff --git a/go.sum b/go.sum index 9ab2cb877d..e5518e2fb0 100644 --- a/go.sum +++ b/go.sum @@ -95,8 +95,8 @@ github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.6 h1:XJtiaUW6dEEqVuZiMTn1ldk455QWwEIsMIJlo5vtkx0= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.7 h1:6pwm8kMQKCmgUg0ZHTm5+/YvRK0s3THD/28+T6/kk4A= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/danieljoos/wincred v1.2.0 h1:ozqKHaLK0W/ii4KVbbvluM91W2H3Sh0BncbUNPS7jLE= @@ -141,6 +141,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= 
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -258,6 +260,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk= github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= +github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= +github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -274,6 +278,8 @@ github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8 github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= 
+github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/openshift-online/ocm-api-model/clientapi v0.0.437 h1:dbBQu3oDKP/EbevS5a9qWX1WsIZRyrjPm73f6jQprqc= github.com/openshift-online/ocm-api-model/clientapi v0.0.437/go.mod h1:fZwy5HY2URG9nrExvQeXrDU/08TGqZ16f8oymVEN5lo= github.com/openshift-online/ocm-api-model/model v0.0.437 h1:FxXKbDGODZ5jFpTK1Czpz1uRXm945fA/E1cLxeVsOmk= @@ -282,6 +288,8 @@ github.com/openshift-online/ocm-common v0.0.31 h1:csxB4UQAUhwhDOVBmOzUKgtemuwV9r github.com/openshift-online/ocm-common v0.0.31/go.mod h1:VEkuZp9aqbXtetZ5ycND6QpvhykvTuBF3oPsVM1X3vI= github.com/openshift-online/ocm-sdk-go v0.1.482 h1:vLj56Y7/nC3FBTn+uTNqYJRrOCKp9E4gTtJJSJF2tSc= github.com/openshift-online/ocm-sdk-go v0.1.482/go.mod h1:VEvdienFpwhq9rfrj92s3eMcyaT1A0MS9LwGEARG6Zw= +github.com/paulczar/cobra-mcp v1.3.0 h1:nUAIdE3yP8lpFi2J/E3IHIKjRQuaQ4B73+Pl0WYg350= +github.com/paulczar/cobra-mcp v1.3.0/go.mod h1:XPGRXfPrRi6trhZtiPCUJnh+nzuhUV8tYSEF0nTFHiQ= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -315,10 +323,10 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod 
h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -333,6 +341,18 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 
v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/vendor/github.com/cpuguy83/go-md2man/v2/md2man/debug.go b/vendor/github.com/cpuguy83/go-md2man/v2/md2man/debug.go new file mode 100644 index 0000000000..0ec4b12c75 --- /dev/null +++ b/vendor/github.com/cpuguy83/go-md2man/v2/md2man/debug.go @@ -0,0 +1,62 @@ +package md2man + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/russross/blackfriday/v2" +) + +func fmtListFlags(flags blackfriday.ListType) string { + knownFlags := []struct { + name string + flag blackfriday.ListType + }{ + {"ListTypeOrdered", blackfriday.ListTypeOrdered}, + {"ListTypeDefinition", blackfriday.ListTypeDefinition}, + {"ListTypeTerm", blackfriday.ListTypeTerm}, + {"ListItemContainsBlock", blackfriday.ListItemContainsBlock}, + {"ListItemBeginningOfList", blackfriday.ListItemBeginningOfList}, + {"ListItemEndOfList", blackfriday.ListItemEndOfList}, + } + + var f []string + for _, kf := range knownFlags { + if flags&kf.flag != 0 { + f = append(f, kf.name) + flags &^= kf.flag + } + } + if flags != 0 { + f = append(f, fmt.Sprintf("Unknown(%#x)", flags)) + } + return strings.Join(f, "|") +} + +type debugDecorator struct { + blackfriday.Renderer +} + +func depth(node *blackfriday.Node) int { + d := 0 + for n := node.Parent; n != nil; n = n.Parent { + d++ + } + return d +} + +func (d *debugDecorator) RenderNode(w io.Writer, node *blackfriday.Node, entering bool) blackfriday.WalkStatus { + fmt.Fprintf(os.Stderr, "%s%s %v %v\n", + strings.Repeat(" ", depth(node)), + map[bool]string{true: "+", false: "-"}[entering], + node, + fmtListFlags(node.ListFlags)) + var b strings.Builder + status := d.Renderer.RenderNode(io.MultiWriter(&b, w), node, entering) + if b.Len() 
> 0 { + fmt.Fprintf(os.Stderr, ">> %q\n", b.String()) + } + return status +} diff --git a/vendor/github.com/cpuguy83/go-md2man/v2/md2man/md2man.go b/vendor/github.com/cpuguy83/go-md2man/v2/md2man/md2man.go index 42bf32aab0..62d91b77d5 100644 --- a/vendor/github.com/cpuguy83/go-md2man/v2/md2man/md2man.go +++ b/vendor/github.com/cpuguy83/go-md2man/v2/md2man/md2man.go @@ -1,16 +1,23 @@ package md2man import ( + "os" + "strconv" + "github.com/russross/blackfriday/v2" ) // Render converts a markdown document into a roff formatted document. func Render(doc []byte) []byte { renderer := NewRoffRenderer() + var r blackfriday.Renderer = renderer + if v, _ := strconv.ParseBool(os.Getenv("MD2MAN_DEBUG")); v { + r = &debugDecorator{Renderer: r} + } return blackfriday.Run(doc, []blackfriday.Option{ - blackfriday.WithRenderer(renderer), + blackfriday.WithRenderer(r), blackfriday.WithExtensions(renderer.GetExtensions()), }...) } diff --git a/vendor/github.com/cpuguy83/go-md2man/v2/md2man/roff.go b/vendor/github.com/cpuguy83/go-md2man/v2/md2man/roff.go index 4b19188d90..96a80c99b8 100644 --- a/vendor/github.com/cpuguy83/go-md2man/v2/md2man/roff.go +++ b/vendor/github.com/cpuguy83/go-md2man/v2/md2man/roff.go @@ -1,6 +1,7 @@ package md2man import ( + "bufio" "bytes" "fmt" "io" @@ -13,68 +14,72 @@ import ( // roffRenderer implements the blackfriday.Renderer interface for creating // roff format (manpages) from markdown text type roffRenderer struct { - extensions blackfriday.Extensions listCounters []int firstHeader bool - firstDD bool listDepth int } const ( - titleHeader = ".TH " - topLevelHeader = "\n\n.SH " - secondLevelHdr = "\n.SH " - otherHeader = "\n.SS " - crTag = "\n" - emphTag = "\\fI" - emphCloseTag = "\\fP" - strongTag = "\\fB" - strongCloseTag = "\\fP" - breakTag = "\n.br\n" - paraTag = "\n.PP\n" - hruleTag = "\n.ti 0\n\\l'\\n(.lu'\n" - linkTag = "\n\\[la]" - linkCloseTag = "\\[ra]" - codespanTag = "\\fB" - codespanCloseTag = "\\fR" - codeTag = "\n.EX\n" - codeCloseTag = 
"\n.EE\n" - quoteTag = "\n.PP\n.RS\n" - quoteCloseTag = "\n.RE\n" - listTag = "\n.RS\n" - listCloseTag = "\n.RE\n" - dtTag = "\n.TP\n" - dd2Tag = "\n" - tableStart = "\n.TS\nallbox;\n" - tableEnd = ".TE\n" - tableCellStart = "T{\n" - tableCellEnd = "\nT}\n" + titleHeader = ".TH " + topLevelHeader = "\n\n.SH " + secondLevelHdr = "\n.SH " + otherHeader = "\n.SS " + crTag = "\n" + emphTag = "\\fI" + emphCloseTag = "\\fP" + strongTag = "\\fB" + strongCloseTag = "\\fP" + breakTag = "\n.br\n" + paraTag = "\n.PP\n" + hruleTag = "\n.ti 0\n\\l'\\n(.lu'\n" + linkTag = "\n\\[la]" + linkCloseTag = "\\[ra]" + codespanTag = "\\fB" + codespanCloseTag = "\\fR" + codeTag = "\n.EX\n" + codeCloseTag = ".EE\n" // Do not prepend a newline character since code blocks, by definition, include a newline already (or at least as how blackfriday gives us on). + quoteTag = "\n.PP\n.RS\n" + quoteCloseTag = "\n.RE\n" + listTag = "\n.RS\n" + listCloseTag = ".RE\n" + dtTag = "\n.TP\n" + dd2Tag = "\n" + tableStart = "\n.TS\nallbox;\n" + tableEnd = ".TE\n" + tableCellStart = "T{\n" + tableCellEnd = "\nT}\n" + tablePreprocessor = `'\" t` ) // NewRoffRenderer creates a new blackfriday Renderer for generating roff documents // from markdown func NewRoffRenderer() *roffRenderer { // nolint: golint - var extensions blackfriday.Extensions - - extensions |= blackfriday.NoIntraEmphasis - extensions |= blackfriday.Tables - extensions |= blackfriday.FencedCode - extensions |= blackfriday.SpaceHeadings - extensions |= blackfriday.Footnotes - extensions |= blackfriday.Titleblock - extensions |= blackfriday.DefinitionLists - return &roffRenderer{ - extensions: extensions, - } + return &roffRenderer{} } // GetExtensions returns the list of extensions used by this renderer implementation -func (r *roffRenderer) GetExtensions() blackfriday.Extensions { - return r.extensions +func (*roffRenderer) GetExtensions() blackfriday.Extensions { + return blackfriday.NoIntraEmphasis | + blackfriday.Tables | + 
blackfriday.FencedCode | + blackfriday.SpaceHeadings | + blackfriday.Footnotes | + blackfriday.Titleblock | + blackfriday.DefinitionLists } // RenderHeader handles outputting the header at document start func (r *roffRenderer) RenderHeader(w io.Writer, ast *blackfriday.Node) { + // We need to walk the tree to check if there are any tables. + // If there are, we need to enable the roff table preprocessor. + ast.Walk(func(node *blackfriday.Node, entering bool) blackfriday.WalkStatus { + if node.Type == blackfriday.Table { + out(w, tablePreprocessor+"\n") + return blackfriday.Terminate + } + return blackfriday.GoToNext + }) + // disable hyphenation out(w, ".nh\n") } @@ -91,7 +96,23 @@ func (r *roffRenderer) RenderNode(w io.Writer, node *blackfriday.Node, entering switch node.Type { case blackfriday.Text: - escapeSpecialChars(w, node.Literal) + // Special case: format the NAME section as required for proper whatis parsing. + // Refer to the lexgrog(1) and groff_man(7) manual pages for details. 
+ if node.Parent != nil && + node.Parent.Type == blackfriday.Paragraph && + node.Parent.Prev != nil && + node.Parent.Prev.Type == blackfriday.Heading && + node.Parent.Prev.FirstChild != nil && + bytes.EqualFold(node.Parent.Prev.FirstChild.Literal, []byte("NAME")) { + before, after, found := bytesCut(node.Literal, []byte(" - ")) + escapeSpecialChars(w, before) + if found { + out(w, ` \- `) + escapeSpecialChars(w, after) + } + } else { + escapeSpecialChars(w, node.Literal) + } case blackfriday.Softbreak: out(w, crTag) case blackfriday.Hardbreak: @@ -129,14 +150,25 @@ func (r *roffRenderer) RenderNode(w io.Writer, node *blackfriday.Node, entering case blackfriday.Document: break case blackfriday.Paragraph: - // roff .PP markers break lists - if r.listDepth > 0 { - return blackfriday.GoToNext - } if entering { - out(w, paraTag) + if r.listDepth > 0 { + // roff .PP markers break lists + if node.Prev != nil { // continued paragraph + if node.Prev.Type == blackfriday.List && node.Prev.ListFlags&blackfriday.ListTypeDefinition == 0 { + out(w, ".IP\n") + } else { + out(w, crTag) + } + } + } else if node.Prev != nil && node.Prev.Type == blackfriday.Heading { + out(w, crTag) + } else { + out(w, paraTag) + } } else { - out(w, crTag) + if node.Next == nil || node.Next.Type != blackfriday.List { + out(w, crTag) + } } case blackfriday.BlockQuote: if entering { @@ -199,6 +231,10 @@ func (r *roffRenderer) handleHeading(w io.Writer, node *blackfriday.Node, enteri func (r *roffRenderer) handleList(w io.Writer, node *blackfriday.Node, entering bool) { openTag := listTag closeTag := listCloseTag + if (entering && r.listDepth == 0) || (!entering && r.listDepth == 1) { + openTag = crTag + closeTag = "" + } if node.ListFlags&blackfriday.ListTypeDefinition != 0 { // tags for definition lists handled within Item node openTag = "" @@ -227,23 +263,25 @@ func (r *roffRenderer) handleItem(w io.Writer, node *blackfriday.Node, entering } else if node.ListFlags&blackfriday.ListTypeTerm != 0 { // DT 
(definition term): line just before DD (see below). out(w, dtTag) - r.firstDD = true } else if node.ListFlags&blackfriday.ListTypeDefinition != 0 { // DD (definition description): line that starts with ": ". // // We have to distinguish between the first DD and the // subsequent ones, as there should be no vertical // whitespace between the DT and the first DD. - if r.firstDD { - r.firstDD = false - } else { - out(w, dd2Tag) + if node.Prev != nil && node.Prev.ListFlags&(blackfriday.ListTypeTerm|blackfriday.ListTypeDefinition) == blackfriday.ListTypeDefinition { + if node.Prev.Type == blackfriday.Item && + node.Prev.LastChild != nil && + node.Prev.LastChild.Type == blackfriday.List && + node.Prev.LastChild.ListFlags&blackfriday.ListTypeDefinition == 0 { + out(w, ".IP\n") + } else { + out(w, dd2Tag) + } } } else { out(w, ".IP \\(bu 2\n") } - } else { - out(w, "\n") } } @@ -322,6 +360,28 @@ func out(w io.Writer, output string) { } func escapeSpecialChars(w io.Writer, text []byte) { + scanner := bufio.NewScanner(bytes.NewReader(text)) + + // count the number of lines in the text + // we need to know this to avoid adding a newline after the last line + n := bytes.Count(text, []byte{'\n'}) + idx := 0 + + for scanner.Scan() { + dt := scanner.Bytes() + if idx < n { + idx++ + dt = append(dt, '\n') + } + escapeSpecialCharsLine(w, dt) + } + + if err := scanner.Err(); err != nil { + panic(err) + } +} + +func escapeSpecialCharsLine(w io.Writer, text []byte) { for i := 0; i < len(text); i++ { // escape initial apostrophe or period if len(text) >= 1 && (text[0] == '\'' || text[0] == '.') { @@ -346,3 +406,12 @@ func escapeSpecialChars(w io.Writer, text []byte) { w.Write([]byte{'\\', text[i]}) // nolint: errcheck } } + +// bytesCut is a copy of [bytes.Cut] to provide compatibility with go1.17 +// and older. We can remove this once we drop support for go1.17 and older. 
+func bytesCut(s, sep []byte) (before, after []byte, found bool) { + if i := bytes.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, nil, false +} diff --git a/vendor/github.com/google/jsonschema-go/LICENSE b/vendor/github.com/google/jsonschema-go/LICENSE new file mode 100644 index 0000000000..1cb53e9df9 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 JSON Schema Go Project Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go b/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go new file mode 100644 index 0000000000..d4dd6436b8 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go @@ -0,0 +1,76 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. 
+// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import "maps" + +// An annotations tracks certain properties computed by keywords that are used by validation. +// ("Annotation" is the spec's term.) +// In particular, the unevaluatedItems and unevaluatedProperties keywords need to know which +// items and properties were evaluated (validated successfully). +type annotations struct { + allItems bool // all items were evaluated + endIndex int // 1+largest index evaluated by prefixItems + evaluatedIndexes map[int]bool // set of indexes evaluated by contains + allProperties bool // all properties were evaluated + evaluatedProperties map[string]bool // set of properties evaluated by various keywords +} + +// noteIndex marks i as evaluated. +func (a *annotations) noteIndex(i int) { + if a.evaluatedIndexes == nil { + a.evaluatedIndexes = map[int]bool{} + } + a.evaluatedIndexes[i] = true +} + +// noteEndIndex marks items with index less than end as evaluated. +func (a *annotations) noteEndIndex(end int) { + if end > a.endIndex { + a.endIndex = end + } +} + +// noteProperty marks prop as evaluated. +func (a *annotations) noteProperty(prop string) { + if a.evaluatedProperties == nil { + a.evaluatedProperties = map[string]bool{} + } + a.evaluatedProperties[prop] = true +} + +// noteProperties marks all the properties in props as evaluated. +func (a *annotations) noteProperties(props map[string]bool) { + a.evaluatedProperties = merge(a.evaluatedProperties, props) +} + +// merge adds b's annotations to a. +// a must not be nil. 
+func (a *annotations) merge(b *annotations) { + if b == nil { + return + } + if b.allItems { + a.allItems = true + } + if b.endIndex > a.endIndex { + a.endIndex = b.endIndex + } + a.evaluatedIndexes = merge(a.evaluatedIndexes, b.evaluatedIndexes) + if b.allProperties { + a.allProperties = true + } + a.evaluatedProperties = merge(a.evaluatedProperties, b.evaluatedProperties) +} + +// merge adds t's keys to s and returns s. +// If s is nil, it returns a copy of t. +func merge[K comparable](s, t map[K]bool) map[K]bool { + if s == nil { + return maps.Clone(t) + } + maps.Copy(s, t) + return s +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/doc.go b/vendor/github.com/google/jsonschema-go/jsonschema/doc.go new file mode 100644 index 0000000000..a34bab725f --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/doc.go @@ -0,0 +1,101 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +/* +Package jsonschema is an implementation of the [JSON Schema specification], +a JSON-based format for describing the structure of JSON data. +The package can be used to read schemas for code generation, and to validate +data using the draft 2020-12 specification. Validation with other drafts +or custom meta-schemas is not supported. + +Construct a [Schema] as you would any Go struct (for example, by writing +a struct literal), or unmarshal a JSON schema into a [Schema] in the usual +way (with [encoding/json], for instance). It can then be used for code +generation or other purposes without further processing. +You can also infer a schema from a Go struct. + +# Resolution + +A Schema can refer to other schemas, both inside and outside itself. These +references must be resolved before a schema can be used for validation. +Call [Schema.Resolve] to obtain a resolved schema (called a [Resolved]). 
+If the schema has external references, pass a [ResolveOptions] with a [Loader] +to load them. To validate default values in a schema, set +[ResolveOptions.ValidateDefaults] to true. + +# Validation + +Call [Resolved.Validate] to validate a JSON value. The value must be a +Go value that looks like the result of unmarshaling a JSON value into an +[any] or a struct. For example, the JSON value + + {"name": "Al", "scores": [90, 80, 100]} + +could be represented as the Go value + + map[string]any{ + "name": "Al", + "scores": []any{90, 80, 100}, + } + +or as a value of this type: + + type Player struct { + Name string `json:"name"` + Scores []int `json:"scores"` + } + +# Inference + +The [For] function returns a [Schema] describing the given Go type. +Each field in the struct becomes a property of the schema. +The values of "json" tags are respected: the field's property name is taken +from the tag, and fields omitted from the JSON are omitted from the schema as +well. +For example, `jsonschema.For[Player]()` returns this schema: + + { + "properties": { + "name": { + "type": "string" + }, + "scores": { + "type": "array", + "items": {"type": "integer"} + } + "required": ["name", "scores"], + "additionalProperties": {"not": {}} + } + } + +Use the "jsonschema" struct tag to provide a description for the property: + + type Player struct { + Name string `json:"name" jsonschema:"player name"` + Scores []int `json:"scores" jsonschema:"scores of player's games"` + } + +# Deviations from the specification + +Regular expressions are processed with Go's regexp package, which differs +from ECMA 262, most significantly in not supporting back-references. +See [this table of differences] for more. + +The "format" keyword described in [section 7 of the validation spec] is recorded +in the Schema, but is ignored during validation. +It does not even produce [annotations]. +Use the "pattern" keyword instead: it will work more reliably across JSON Schema +implementations. 
See [learnjsonschema.com] for more recommendations about "format". + +The content keywords described in [section 8 of the validation spec] +are recorded in the schema, but ignored during validation. + +[JSON Schema specification]: https://json-schema.org +[section 7 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 +[section 8 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 +[learnjsonschema.com]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ +[this table of differences]: https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 +[annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations +*/ +package jsonschema diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/infer.go b/vendor/github.com/google/jsonschema-go/jsonschema/infer.go new file mode 100644 index 0000000000..ae624ad094 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/infer.go @@ -0,0 +1,248 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains functions that infer a schema from a Go type. + +package jsonschema + +import ( + "fmt" + "log/slog" + "maps" + "math/big" + "reflect" + "regexp" + "time" +) + +// ForOptions are options for the [For] and [ForType] functions. +type ForOptions struct { + // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON + // Schema are ignored instead of causing an error. + // This allows callers to adjust the resulting schema using custom knowledge. + // For example, an interface type where all the possible implementations are + // known can be described with "oneof". + IgnoreInvalidTypes bool + + // TypeSchemas maps types to their schemas. 
+ // If [For] encounters a type that is a key in this map, the + // corresponding value is used as the resulting schema (after cloning to + // ensure uniqueness). + // Types in this map override the default translations, as described + // in [For]'s documentation. + TypeSchemas map[reflect.Type]*Schema +} + +// For constructs a JSON schema object for the given type argument. +// If non-nil, the provided options configure certain aspects of this contruction, +// described below. + +// It translates Go types into compatible JSON schema types, as follows. +// These defaults can be overridden by [ForOptions.TypeSchemas]. +// +// - Strings have schema type "string". +// - Bools have schema type "boolean". +// - Signed and unsigned integer types have schema type "integer". +// - Floating point types have schema type "number". +// - Slices and arrays have schema type "array", and a corresponding schema +// for items. +// - Maps with string key have schema type "object", and corresponding +// schema for additionalProperties. +// - Structs have schema type "object", and disallow additionalProperties. +// Their properties are derived from exported struct fields, using the +// struct field JSON name. Fields that are marked "omitempty" are +// considered optional; all other fields become required properties. +// - Some types in the standard library that implement json.Marshaler +// translate to schemas that match the values to which they marshal. +// For example, [time.Time] translates to the schema for strings. +// +// For will return an error if there is a cycle in the types. +// +// By default, For returns an error if t contains (possibly recursively) any of the +// following Go types, as they are incompatible with the JSON schema spec. +// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead. 
+// - maps with key other than 'string' +// - function types +// - channel types +// - complex numbers +// - unsafe pointers +// +// This function recognizes struct field tags named "jsonschema". +// A jsonschema tag on a field is used as the description for the corresponding property. +// For future compatibility, descriptions must not start with "WORD=", where WORD is a +// sequence of non-whitespace characters. +func For[T any](opts *ForOptions) (*Schema, error) { + if opts == nil { + opts = &ForOptions{} + } + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + maps.Copy(schemas, opts.TypeSchemas) + s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + var z T + return nil, fmt.Errorf("For[%T](): %w", z, err) + } + return s, nil +} + +// ForType is like [For], but takes a [reflect.Type] +func ForType(t reflect.Type, opts *ForOptions) (*Schema, error) { + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + maps.Copy(schemas, opts.TypeSchemas) + s, err := forType(t, map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + return nil, fmt.Errorf("ForType(%s): %w", t, err) + } + return s, nil +} + +func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { + // Follow pointers: the schema for *T is almost the same as for T, except that + // an explicit JSON "null" is allowed for the pointer. 
+ allowNull := false + for t.Kind() == reflect.Pointer { + allowNull = true + t = t.Elem() + } + + // Check for cycles + // User defined types have a name, so we can skip those that are natively defined + if t.Name() != "" { + if seen[t] { + return nil, fmt.Errorf("cycle detected for type %v", t) + } + seen[t] = true + defer delete(seen, t) + } + + if s := schemas[t]; s != nil { + return s.CloneSchemas(), nil + } + + var ( + s = new(Schema) + err error + ) + + switch t.Kind() { + case reflect.Bool: + s.Type = "boolean" + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Uintptr: + s.Type = "integer" + + case reflect.Float32, reflect.Float64: + s.Type = "number" + + case reflect.Interface: + // Unrestricted + + case reflect.Map: + if t.Key().Kind() != reflect.String { + if ignore { + return nil, nil // ignore + } + return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) + } + if t.Key().Kind() != reflect.String { + } + s.Type = "object" + s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas) + if err != nil { + return nil, fmt.Errorf("computing map value schema: %v", err) + } + if ignore && s.AdditionalProperties == nil { + // Ignore if the element type is invalid. + return nil, nil + } + + case reflect.Slice, reflect.Array: + s.Type = "array" + s.Items, err = forType(t.Elem(), seen, ignore, schemas) + if err != nil { + return nil, fmt.Errorf("computing element schema: %v", err) + } + if ignore && s.Items == nil { + // Ignore if the element type is invalid. 
+ return nil, nil + } + if t.Kind() == reflect.Array { + s.MinItems = Ptr(t.Len()) + s.MaxItems = Ptr(t.Len()) + } + + case reflect.String: + s.Type = "string" + + case reflect.Struct: + s.Type = "object" + // no additional properties are allowed + s.AdditionalProperties = falseSchema() + for _, field := range reflect.VisibleFields(t) { + if field.Anonymous { + continue + } + + info := fieldJSONInfo(field) + if info.omit { + continue + } + if s.Properties == nil { + s.Properties = make(map[string]*Schema) + } + fs, err := forType(field.Type, seen, ignore, schemas) + if err != nil { + return nil, err + } + if ignore && fs == nil { + // Skip fields of invalid type. + continue + } + if tag, ok := field.Tag.Lookup("jsonschema"); ok { + if tag == "" { + return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) + } + if disallowedPrefixRegexp.MatchString(tag) { + return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag) + } + fs.Description = tag + } + s.Properties[info.name] = fs + if !info.settings["omitempty"] && !info.settings["omitzero"] { + s.Required = append(s.Required, info.name) + } + } + + default: + if ignore { + // Ignore. + return nil, nil + } + return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) + } + if allowNull && s.Type != "" { + s.Types = []string{"null", s.Type} + s.Type = "" + } + return s, nil +} + +// initialSchemaMap holds types from the standard library that have MarshalJSON methods. +var initialSchemaMap = make(map[reflect.Type]*Schema) + +func init() { + ss := &Schema{Type: "string"} + initialSchemaMap[reflect.TypeFor[time.Time]()] = ss + initialSchemaMap[reflect.TypeFor[slog.Level]()] = ss + initialSchemaMap[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} + initialSchemaMap[reflect.TypeFor[big.Rat]()] = ss + initialSchemaMap[reflect.TypeFor[big.Float]()] = ss +} + +// Disallow jsonschema tag values beginning "WORD=", for future expansion. 
+var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go b/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go new file mode 100644 index 0000000000..ed1b16991c --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go @@ -0,0 +1,146 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements JSON Pointers. +// A JSON Pointer is a path that refers to one JSON value within another. +// If the path is empty, it refers to the root value. +// Otherwise, it is a sequence of slash-prefixed strings, like "/points/1/x", +// selecting successive properties (for JSON objects) or items (for JSON arrays). +// For example, when applied to this JSON value: +// { +// "points": [ +// {"x": 1, "y": 2}, +// {"x": 3, "y": 4} +// ] +// } +// +// the JSON Pointer "/points/1/x" refers to the number 3. +// See the spec at https://datatracker.ietf.org/doc/html/rfc6901. + +package jsonschema + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +var ( + jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") + jsonPointerUnescaper = strings.NewReplacer("~0", "~", "~1", "/") +) + +func escapeJSONPointerSegment(s string) string { + return jsonPointerEscaper.Replace(s) +} + +func unescapeJSONPointerSegment(s string) string { + return jsonPointerUnescaper.Replace(s) +} + +// parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't +// convert strings to numbers, because that depends on the traversal: a segment +// is treated as a number when applied to an array, but a string when applied to +// an object. See section 4 of the spec. 
+func parseJSONPointer(ptr string) (segments []string, err error) { + if ptr == "" { + return nil, nil + } + if ptr[0] != '/' { + return nil, fmt.Errorf("JSON Pointer %q does not begin with '/'", ptr) + } + // Unlike file paths, consecutive slashes are not coalesced. + // Split is nicer than Cut here, because it gets a final "/" right. + segments = strings.Split(ptr[1:], "/") + if strings.Contains(ptr, "~") { + // Undo the simple escaping rules that allow one to include a slash in a segment. + for i := range segments { + segments[i] = unescapeJSONPointerSegment(segments[i]) + } + } + return segments, nil +} + +// dereferenceJSONPointer returns the Schema that sptr points to within s, +// or an error if none. +// This implementation suffices for JSON Schema: pointers are applied only to Schemas, +// and refer only to Schemas. +func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { + defer wrapf(&err, "JSON Pointer %q", sptr) + + segments, err := parseJSONPointer(sptr) + if err != nil { + return nil, err + } + v := reflect.ValueOf(s) + for _, seg := range segments { + switch v.Kind() { + case reflect.Pointer: + v = v.Elem() + if !v.IsValid() { + return nil, errors.New("navigated to nil reference") + } + fallthrough // if valid, can only be a pointer to a Schema + + case reflect.Struct: + // The segment must refer to a field in a Schema. + if v.Type() != reflect.TypeFor[Schema]() { + return nil, fmt.Errorf("navigated to non-Schema %s", v.Type()) + } + v = lookupSchemaField(v, seg) + if !v.IsValid() { + return nil, fmt.Errorf("no schema field %q", seg) + } + case reflect.Slice, reflect.Array: + // The segment must be an integer without leading zeroes that refers to an item in the + // slice or array. 
+ if seg == "-" { + return nil, errors.New("the JSON Pointer array segment '-' is not supported") + } + if len(seg) > 1 && seg[0] == '0' { + return nil, fmt.Errorf("segment %q has leading zeroes", seg) + } + n, err := strconv.Atoi(seg) + if err != nil { + return nil, fmt.Errorf("invalid int: %q", seg) + } + if n < 0 || n >= v.Len() { + return nil, fmt.Errorf("index %d is out of bounds for array of length %d", n, v.Len()) + } + v = v.Index(n) + // Cannot be invalid. + case reflect.Map: + // The segment must be a key in the map. + v = v.MapIndex(reflect.ValueOf(seg)) + if !v.IsValid() { + return nil, fmt.Errorf("no key %q in map", seg) + } + default: + return nil, fmt.Errorf("value %s (%s) is not a schema, slice or map", v, v.Type()) + } + } + if s, ok := v.Interface().(*Schema); ok { + return s, nil + } + return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) +} + +// lookupSchemaField returns the value of the field with the given name in v, +// or the zero value if there is no such field or it is not of type Schema or *Schema. +func lookupSchemaField(v reflect.Value, name string) reflect.Value { + if name == "type" { + // The "type" keyword may refer to Type or Types. + // At most one will be non-zero. + if t := v.FieldByName("Type"); !t.IsZero() { + return t + } + return v.FieldByName("Types") + } + if sf, ok := schemaFieldMap[name]; ok { + return v.FieldByIndex(sf.Index) + } + return reflect.Value{} +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go b/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go new file mode 100644 index 0000000000..ece9be8807 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go @@ -0,0 +1,548 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +// This file deals with preparing a schema for validation, including various checks, +// optimizations, and the resolution of cross-schema references. + +package jsonschema + +import ( + "errors" + "fmt" + "net/url" + "reflect" + "regexp" + "strings" +) + +// A Resolved consists of a [Schema] along with associated information needed to +// validate documents against it. +// A Resolved has been validated against its meta-schema, and all its references +// (the $ref and $dynamicRef keywords) have been resolved to their referenced Schemas. +// Call [Schema.Resolve] to obtain a Resolved from a Schema. +type Resolved struct { + root *Schema + // map from $ids to their schemas + resolvedURIs map[string]*Schema + // map from schemas to additional info computed during resolution + resolvedInfos map[*Schema]*resolvedInfo +} + +func newResolved(s *Schema) *Resolved { + return &Resolved{ + root: s, + resolvedURIs: map[string]*Schema{}, + resolvedInfos: map[*Schema]*resolvedInfo{}, + } +} + +// resolvedInfo holds information specific to a schema that is computed by [Schema.Resolve]. +type resolvedInfo struct { + s *Schema + // The JSON Pointer path from the root schema to here. + // Used in errors. + path string + // The schema's base schema. + // If the schema is the root or has an ID, its base is itself. + // Otherwise, its base is the innermost enclosing schema whose base + // is itself. + // Intuitively, a base schema is one that can be referred to with a + // fragmentless URI. + base *Schema + // The URI for the schema, if it is the root or has an ID. + // Otherwise nil. + // Invariants: + // s.base.uri != nil. + // s.base == s <=> s.uri != nil + uri *url.URL + // The schema to which Ref refers. + resolvedRef *Schema + + // If the schema has a dynamic ref, exactly one of the next two fields + // will be non-zero after successful resolution. + // The schema to which the dynamic ref refers when it acts lexically. 
+ resolvedDynamicRef *Schema + // The anchor to look up on the stack when the dynamic ref acts dynamically. + dynamicRefAnchor string + + // The following fields are independent of arguments to Schema.Resolved, + // so they could live on the Schema. We put them here for simplicity. + + // The set of required properties. + isRequired map[string]bool + + // Compiled regexps. + pattern *regexp.Regexp + patternProperties map[*regexp.Regexp]*Schema + + // Map from anchors to subschemas. + anchors map[string]anchorInfo +} + +// Schema returns the schema that was resolved. +// It must not be modified. +func (r *Resolved) Schema() *Schema { return r.root } + +// schemaString returns a short string describing the schema. +func (r *Resolved) schemaString(s *Schema) string { + if s.ID != "" { + return s.ID + } + info := r.resolvedInfos[s] + if info.path != "" { + return info.path + } + return "" +} + +// A Loader reads and unmarshals the schema at uri, if any. +type Loader func(uri *url.URL) (*Schema, error) + +// ResolveOptions are options for [Schema.Resolve]. +type ResolveOptions struct { + // BaseURI is the URI relative to which the root schema should be resolved. + // If non-empty, must be an absolute URI (one that starts with a scheme). + // It is resolved (in the URI sense; see [url.ResolveReference]) with root's + // $id property. + // If the resulting URI is not absolute, then the schema cannot contain + // relative URI references. + BaseURI string + // Loader loads schemas that are referred to by a $ref but are not under the + // root schema (remote references). + // If nil, resolving a remote reference will return an error. + Loader Loader + // ValidateDefaults determines whether to validate values of "default" keywords + // against their schemas. + // The [JSON Schema specification] does not require this, but it is recommended + // if defaults will be used. 
+ // + // [JSON Schema specification]: https://json-schema.org/understanding-json-schema/reference/annotations + ValidateDefaults bool +} + +// Resolve resolves all references within the schema and performs other tasks that +// prepare the schema for validation. +// If opts is nil, the default values are used. +// The schema must not be changed after Resolve is called. +// The same schema may be resolved multiple times. +func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { + // There are up to five steps required to prepare a schema to validate. + // 1. Load: read the schema from somewhere and unmarshal it. + // This schema (root) may have been loaded or created in memory, but other schemas that + // come into the picture in step 4 will be loaded by the given loader. + // 2. Check: validate the schema against a meta-schema, and perform other well-formedness checks. + // Precompute some values along the way. + // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and + // resolve (in the URI sense) all identifiers and anchors with their bases. This step results + // in a map from URIs to schemas within root. + // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. + // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. 
+ r := &resolver{loaded: map[string]*Resolved{}} + if opts != nil { + r.opts = *opts + } + var base *url.URL + if r.opts.BaseURI == "" { + base = &url.URL{} // so we can call ResolveReference on it + } else { + var err error + base, err = url.Parse(r.opts.BaseURI) + if err != nil { + return nil, fmt.Errorf("parsing base URI: %w", err) + } + } + + if r.opts.Loader == nil { + r.opts.Loader = func(uri *url.URL) (*Schema, error) { + return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") + } + } + + resolved, err := r.resolve(root, base) + if err != nil { + return nil, err + } + if r.opts.ValidateDefaults { + if err := resolved.validateDefaults(); err != nil { + return nil, err + } + } + // TODO: before we return, throw away anything we don't need for validation. + return resolved, nil +} + +// A resolver holds the state for resolution. +type resolver struct { + opts ResolveOptions + // A cache of loaded and partly resolved schemas. (They may not have had their + // refs resolved.) The cache ensures that the loader will never be called more + // than once with the same URI, and that reference cycles are handled properly. + loaded map[string]*Resolved +} + +func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { + if baseURI.Fragment != "" { + return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) + } + rs := newResolved(s) + + if err := s.check(rs.resolvedInfos); err != nil { + return nil, err + } + + if err := resolveURIs(rs, baseURI); err != nil { + return nil, err + } + + // Remember the schema by both the URI we loaded it from and its canonical name, + // which may differ if the schema has an $id. + // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. 
+ r.loaded[baseURI.String()] = rs + r.loaded[rs.resolvedInfos[s].uri.String()] = rs + + if err := r.resolveRefs(rs); err != nil { + return nil, err + } + return rs, nil +} + +func (root *Schema) check(infos map[*Schema]*resolvedInfo) error { + // Check for structural validity. Do this first and fail fast: + // bad structure will cause other code to panic. + if err := root.checkStructure(infos); err != nil { + return err + } + + var errs []error + report := func(err error) { errs = append(errs, err) } + + for ss := range root.all() { + ss.checkLocal(report, infos) + } + return errors.Join(errs...) +} + +// checkStructure verifies that root and its subschemas form a tree. +// It also assigns each schema a unique path, to improve error messages. +func (root *Schema) checkStructure(infos map[*Schema]*resolvedInfo) error { + assert(len(infos) == 0, "non-empty infos") + + var check func(reflect.Value, []byte) error + check = func(v reflect.Value, path []byte) error { + // For the purpose of error messages, the root schema has path "root" + // and other schemas' paths are their JSON Pointer from the root. + p := "root" + if len(path) > 0 { + p = string(path) + } + s := v.Interface().(*Schema) + if s == nil { + return fmt.Errorf("jsonschema: schema at %s is nil", p) + } + if info, ok := infos[s]; ok { + // We've seen s before. + // The schema graph at root is not a tree, but it needs to + // be because a schema's base must be unique. + // A cycle would also put Schema.all into an infinite recursion. + return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", + root, info.path, p) + } + infos[s] = &resolvedInfo{s: s, path: p} + + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. + // A nil is valid: it just means the field isn't present. 
+ if !fv.IsNil() { + if err := check(fv, fmt.Appendf(path, "/%s", info.jsonName)); err != nil { + return err + } + } + + case schemaSliceType: + for i := range fv.Len() { + if err := check(fv.Index(i), fmt.Appendf(path, "/%s/%d", info.jsonName, i)); err != nil { + return err + } + } + + case schemaMapType: + iter := fv.MapRange() + for iter.Next() { + key := escapeJSONPointerSegment(iter.Key().String()) + if err := check(iter.Value(), fmt.Appendf(path, "/%s/%s", info.jsonName, key)); err != nil { + return err + } + } + } + + } + return nil + } + + return check(reflect.ValueOf(root), make([]byte, 0, 256)) +} + +// checkLocal checks s for validity, independently of other schemas it may refer to. +// Since checking a regexp involves compiling it, checkLocal saves those compiled regexps +// in the schema for later use. +// It appends the errors it finds to errs. +func (s *Schema) checkLocal(report func(error), infos map[*Schema]*resolvedInfo) { + addf := func(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) + } + + if s == nil { + addf("nil subschema") + return + } + if err := s.basicChecks(); err != nil { + report(err) + return + } + + // TODO: validate the schema's properties, + // ideally by jsonschema-validating it against the meta-schema. + + // Some properties are present so that Schemas can round-trip, but we do not + // validate them. + // Currently, it's just the $vocabulary property. + // As a special case, we can validate the 2020-12 meta-schema. + if s.Vocabulary != nil && s.Schema != draft202012 { + addf("cannot validate a schema with $vocabulary") + } + + info := infos[s] + + // Check and compile regexps. 
+ if s.Pattern != "" { + re, err := regexp.Compile(s.Pattern) + if err != nil { + addf("pattern: %v", err) + } else { + info.pattern = re + } + } + if len(s.PatternProperties) > 0 { + info.patternProperties = map[*regexp.Regexp]*Schema{} + for reString, subschema := range s.PatternProperties { + re, err := regexp.Compile(reString) + if err != nil { + addf("patternProperties[%q]: %v", reString, err) + continue + } + info.patternProperties[re] = subschema + } + } + + // Build a set of required properties, to avoid quadratic behavior when validating + // a struct. + if len(s.Required) > 0 { + info.isRequired = map[string]bool{} + for _, r := range s.Required { + info.isRequired[r] = true + } + } +} + +// resolveURIs resolves the ids and anchors in all the schemas of root, relative +// to baseURI. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section +// 8.2.1. +// +// Every schema has a base URI and a parent base URI. +// +// The parent base URI is the base URI of the lexically enclosing schema, or for +// a root schema, the URI it was loaded from or the one supplied to [Schema.Resolve]. +// +// If the schema has no $id property, the base URI of a schema is that of its parent. +// If the schema does have an $id, it must be a URI, possibly relative. The schema's +// base URI is the $id resolved (in the sense of [url.URL.ResolveReference]) against +// the parent base. +// +// As an example, consider this schema loaded from http://a.com/root.json (quotes omitted): +// +// { +// allOf: [ +// {$id: "sub1.json", minLength: 5}, +// {$id: "http://b.com", minimum: 10}, +// {not: {maximum: 20}} +// ] +// } +// +// The base URIs are as follows. Schema locations are expressed in the JSON Pointer notation. 
+// +// schema base URI +// root http://a.com/root.json +// allOf/0 http://a.com/sub1.json +// allOf/1 http://b.com (absolute $id; doesn't matter that it's not under the loaded URI) +// allOf/2 http://a.com/root.json (inherited from parent) +// allOf/2/not http://a.com/root.json (inherited from parent) +func resolveURIs(rs *Resolved, baseURI *url.URL) error { + var resolve func(s, base *Schema) error + resolve = func(s, base *Schema) error { + info := rs.resolvedInfos[s] + baseInfo := rs.resolvedInfos[base] + + // ids are scoped to the root. + if s.ID != "" { + // A non-empty ID establishes a new base. + idURI, err := url.Parse(s.ID) + if err != nil { + return err + } + if idURI.Fragment != "" { + return fmt.Errorf("$id %s must not have a fragment", s.ID) + } + // The base URI for this schema is its $id resolved against the parent base. + info.uri = baseInfo.uri.ResolveReference(idURI) + if !info.uri.IsAbs() { + return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %q)", s.ID, baseInfo.uri) + } + rs.resolvedURIs[info.uri.String()] = s + base = s // needed for anchors + baseInfo = rs.resolvedInfos[base] + } + info.base = base + + // Anchors and dynamic anchors are URI fragments that are scoped to their base. + // We treat them as keys in a map stored within the schema. + setAnchor := func(anchor string, dynamic bool) error { + if anchor != "" { + if _, ok := baseInfo.anchors[anchor]; ok { + return fmt.Errorf("duplicate anchor %q in %s", anchor, baseInfo.uri) + } + if baseInfo.anchors == nil { + baseInfo.anchors = map[string]anchorInfo{} + } + baseInfo.anchors[anchor] = anchorInfo{s, dynamic} + } + return nil + } + + setAnchor(s.Anchor, false) + setAnchor(s.DynamicAnchor, true) + + for c := range s.children() { + if err := resolve(c, base); err != nil { + return err + } + } + return nil + } + + // Set the root URI to the base for now. If the root has an $id, this will change. 
+ rs.resolvedInfos[rs.root].uri = baseURI + // The original base, even if changed, is still a valid way to refer to the root. + rs.resolvedURIs[baseURI.String()] = rs.root + + return resolve(rs.root, rs.root) +} + +// resolveRefs replaces every ref in the schemas with the schema it refers to. +// A reference that doesn't resolve within the schema may refer to some other schema +// that needs to be loaded. +func (r *resolver) resolveRefs(rs *Resolved) error { + for s := range rs.root.all() { + info := rs.resolvedInfos[s] + if s.Ref != "" { + refSchema, _, err := r.resolveRef(rs, s, s.Ref) + if err != nil { + return err + } + // Whether or not the anchor referred to by $ref fragment is dynamic, + // the ref still treats it lexically. + info.resolvedRef = refSchema + } + if s.DynamicRef != "" { + refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) + if err != nil { + return err + } + if frag != "" { + // The dynamic ref's fragment points to a dynamic anchor. + // We must resolve the fragment at validation time. + info.dynamicRefAnchor = frag + } else { + // There is no dynamic anchor in the lexically referenced schema, + // so the dynamic ref behaves like a lexical ref. + info.resolvedDynamicRef = refSchema + } + } + } + return nil +} + +// resolveRef resolves the reference ref, which is either s.Ref or s.DynamicRef. +func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, dynamicFragment string, err error) { + refURI, err := url.Parse(ref) + if err != nil { + return nil, "", err + } + // URI-resolve the ref against the current base URI to get a complete URI. + base := rs.resolvedInfos[s].base + refURI = rs.resolvedInfos[base].uri.ResolveReference(refURI) + // The non-fragment part of a ref URI refers to the base URI of some schema. + // This part is the same for dynamic refs too: their non-fragment part resolves + // lexically. + u := *refURI + u.Fragment = "" + fraglessRefURI := &u + // Look it up locally. 
+ referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] + if referencedSchema == nil { + // The schema is remote. Maybe we've already loaded it. + // We assume that the non-fragment part of refURI refers to a top-level schema + // document. That is, we don't support the case exemplified by + // http://foo.com/bar.json/baz, where the document is in bar.json and + // the reference points to a subschema within it. + // TODO: support that case. + if lrs := r.loaded[fraglessRefURI.String()]; lrs != nil { + referencedSchema = lrs.root + } else { + // Try to load the schema. + ls, err := r.opts.Loader(fraglessRefURI) + if err != nil { + return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err) + } + lrs, err := r.resolve(ls, fraglessRefURI) + if err != nil { + return nil, "", err + } + referencedSchema = lrs.root + assert(referencedSchema != nil, "nil referenced schema") + // Copy the resolvedInfos from lrs into rs, without overwriting + // (hence we can't use maps.Insert). + for s, i := range lrs.resolvedInfos { + if rs.resolvedInfos[s] == nil { + rs.resolvedInfos[s] = i + } + } + } + } + + frag := refURI.Fragment + // Look up frag in refSchema. + // frag is either a JSON Pointer or the name of an anchor. + // A JSON Pointer is either the empty string or begins with a '/', + // whereas anchors are always non-empty strings that don't contain slashes. + if frag != "" && !strings.HasPrefix(frag, "/") { + resInfo := rs.resolvedInfos[referencedSchema] + info, found := resInfo.anchors[frag] + + if !found { + return nil, "", fmt.Errorf("no anchor %q in %s", frag, s) + } + if info.dynamic { + dynamicFragment = frag + } + return info.schema, dynamicFragment, nil + } + // frag is a JSON Pointer. 
+ s, err = dereferenceJSONPointer(referencedSchema, frag) + return s, "", err +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/schema.go b/vendor/github.com/google/jsonschema-go/jsonschema/schema.go new file mode 100644 index 0000000000..3b4db9a6e3 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/schema.go @@ -0,0 +1,436 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "bytes" + "cmp" + "encoding/json" + "errors" + "fmt" + "iter" + "maps" + "math" + "reflect" + "slices" +) + +// A Schema is a JSON schema object. +// It corresponds to the 2020-12 draft, as described in https://json-schema.org/draft/2020-12, +// specifically: +// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-01 +// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01 +// +// A Schema value may have non-zero values for more than one field: +// all relevant non-zero fields are used for validation. +// There is one exception to provide more Go type-safety: the Type and Types fields +// are mutually exclusive. +// +// Since this struct is a Go representation of a JSON value, it inherits JSON's +// distinction between nil and empty. Nil slices and maps are considered absent, +// but empty ones are present and affect validation. For example, +// +// Schema{Enum: nil} +// +// is equivalent to an empty schema, so it validates every instance. But +// +// Schema{Enum: []any{}} +// +// requires equality to some slice element, so it vacuously rejects every instance. +type Schema struct { + // core + ID string `json:"$id,omitempty"` + Schema string `json:"$schema,omitempty"` + Ref string `json:"$ref,omitempty"` + Comment string `json:"$comment,omitempty"` + Defs map[string]*Schema `json:"$defs,omitempty"` + // definitions is deprecated but still allowed. 
It is a synonym for $defs. + Definitions map[string]*Schema `json:"definitions,omitempty"` + + Anchor string `json:"$anchor,omitempty"` + DynamicAnchor string `json:"$dynamicAnchor,omitempty"` + DynamicRef string `json:"$dynamicRef,omitempty"` + Vocabulary map[string]bool `json:"$vocabulary,omitempty"` + + // metadata + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Default json.RawMessage `json:"default,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + WriteOnly bool `json:"writeOnly,omitempty"` + Examples []any `json:"examples,omitempty"` + + // validation + // Use Type for a single type, or Types for multiple types; never both. + Type string `json:"-"` + Types []string `json:"-"` + Enum []any `json:"enum,omitempty"` + // Const is *any because a JSON null (Go nil) is a valid value. + Const *any `json:"const,omitempty"` + MultipleOf *float64 `json:"multipleOf,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` + ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` + Pattern string `json:"pattern,omitempty"` + + // arrays + PrefixItems []*Schema `json:"prefixItems,omitempty"` + Items *Schema `json:"items,omitempty"` + MinItems *int `json:"minItems,omitempty"` + MaxItems *int `json:"maxItems,omitempty"` + AdditionalItems *Schema `json:"additionalItems,omitempty"` + UniqueItems bool `json:"uniqueItems,omitempty"` + Contains *Schema `json:"contains,omitempty"` + MinContains *int `json:"minContains,omitempty"` // *int, not int: default is 1, not 0 + MaxContains *int `json:"maxContains,omitempty"` + UnevaluatedItems *Schema `json:"unevaluatedItems,omitempty"` + + // objects + MinProperties *int `json:"minProperties,omitempty"` + MaxProperties *int 
`json:"maxProperties,omitempty"` + Required []string `json:"required,omitempty"` + DependentRequired map[string][]string `json:"dependentRequired,omitempty"` + Properties map[string]*Schema `json:"properties,omitempty"` + PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` + AdditionalProperties *Schema `json:"additionalProperties,omitempty"` + PropertyNames *Schema `json:"propertyNames,omitempty"` + UnevaluatedProperties *Schema `json:"unevaluatedProperties,omitempty"` + + // logic + AllOf []*Schema `json:"allOf,omitempty"` + AnyOf []*Schema `json:"anyOf,omitempty"` + OneOf []*Schema `json:"oneOf,omitempty"` + Not *Schema `json:"not,omitempty"` + + // conditional + If *Schema `json:"if,omitempty"` + Then *Schema `json:"then,omitempty"` + Else *Schema `json:"else,omitempty"` + DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` + + // other + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 + ContentEncoding string `json:"contentEncoding,omitempty"` + ContentMediaType string `json:"contentMediaType,omitempty"` + ContentSchema *Schema `json:"contentSchema,omitempty"` + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 + Format string `json:"format,omitempty"` + + // Extra allows for additional keywords beyond those specified. + Extra map[string]any `json:"-"` +} + +// falseSchema returns a new Schema tree that fails to validate any value. +func falseSchema() *Schema { + return &Schema{Not: &Schema{}} +} + +// anchorInfo records the subschema to which an anchor refers, and whether +// the anchor keyword is $anchor or $dynamicAnchor. +type anchorInfo struct { + schema *Schema + dynamic bool +} + +// String returns a short description of the schema. 
+func (s *Schema) String() string { + if s.ID != "" { + return s.ID + } + if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { + return fmt.Sprintf("anchor %s", a) + } + return "" +} + +// CloneSchemas returns a copy of s. +// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas. +// This allows both s and s.CloneSchemas() to appear as sub-schemas of the same parent. +func (s *Schema) CloneSchemas() *Schema { + if s == nil { + return nil + } + s2 := *s + v := reflect.ValueOf(&s2) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + sscss := fv.Interface().(*Schema) + fv.Set(reflect.ValueOf(sscss.CloneSchemas())) + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + slice = slices.Clone(slice) + for i, ss := range slice { + slice[i] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(slice)) + + case schemaMapType: + m := fv.Interface().(map[string]*Schema) + m = maps.Clone(m) + for k, ss := range m { + m[k] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(m)) + } + } + return &s2 +} + +func (s *Schema) basicChecks() error { + if s.Type != "" && s.Types != nil { + return errors.New("both Type and Types are set; at most one should be") + } + if s.Defs != nil && s.Definitions != nil { + return errors.New("both Defs and Definitions are set; at most one should be") + } + return nil +} + +type schemaWithoutMethods Schema // doesn't implement json.{Unm,M}arshaler + +func (s *Schema) MarshalJSON() ([]byte, error) { + if err := s.basicChecks(); err != nil { + return nil, err + } + + // Marshal either Type or Types as "type". 
+ var typ any + switch { + case s.Type != "": + typ = s.Type + case s.Types != nil: + typ = s.Types + } + ms := struct { + Type any `json:"type,omitempty"` + *schemaWithoutMethods + }{ + Type: typ, + schemaWithoutMethods: (*schemaWithoutMethods)(s), + } + bs, err := marshalStructWithMap(&ms, "Extra") + if err != nil { + return nil, err + } + // Marshal {} as true and {"not": {}} as false. + // It is wasteful to do this here instead of earlier, but much easier. + switch { + case bytes.Equal(bs, []byte(`{}`)): + bs = []byte("true") + case bytes.Equal(bs, []byte(`{"not":true}`)): + bs = []byte("false") + } + return bs, nil +} + +func (s *Schema) UnmarshalJSON(data []byte) error { + // A JSON boolean is a valid schema. + var b bool + if err := json.Unmarshal(data, &b); err == nil { + if b { + // true is the empty schema, which validates everything. + *s = Schema{} + } else { + // false is the schema that validates nothing. + *s = *falseSchema() + } + return nil + } + + ms := struct { + Type json.RawMessage `json:"type,omitempty"` + Const json.RawMessage `json:"const,omitempty"` + MinLength *integer `json:"minLength,omitempty"` + MaxLength *integer `json:"maxLength,omitempty"` + MinItems *integer `json:"minItems,omitempty"` + MaxItems *integer `json:"maxItems,omitempty"` + MinProperties *integer `json:"minProperties,omitempty"` + MaxProperties *integer `json:"maxProperties,omitempty"` + MinContains *integer `json:"minContains,omitempty"` + MaxContains *integer `json:"maxContains,omitempty"` + + *schemaWithoutMethods + }{ + schemaWithoutMethods: (*schemaWithoutMethods)(s), + } + if err := unmarshalStructWithMap(data, &ms, "Extra"); err != nil { + return err + } + // Unmarshal "type" as either Type or Types. 
+ var err error + if len(ms.Type) > 0 { + switch ms.Type[0] { + case '"': + err = json.Unmarshal(ms.Type, &s.Type) + case '[': + err = json.Unmarshal(ms.Type, &s.Types) + default: + err = fmt.Errorf(`invalid value for "type": %q`, ms.Type) + } + } + if err != nil { + return err + } + + unmarshalAnyPtr := func(p **any, raw json.RawMessage) error { + if len(raw) == 0 { + return nil + } + if bytes.Equal(raw, []byte("null")) { + *p = new(any) + return nil + } + return json.Unmarshal(raw, p) + } + + // Setting Const to a pointer to null will marshal properly, but won't + // unmarshal: the *any is set to nil, not a pointer to nil. + if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { + return err + } + + set := func(dst **int, src *integer) { + if src != nil { + *dst = Ptr(int(*src)) + } + } + + set(&s.MinLength, ms.MinLength) + set(&s.MaxLength, ms.MaxLength) + set(&s.MinItems, ms.MinItems) + set(&s.MaxItems, ms.MaxItems) + set(&s.MinProperties, ms.MinProperties) + set(&s.MaxProperties, ms.MaxProperties) + set(&s.MinContains, ms.MinContains) + set(&s.MaxContains, ms.MaxContains) + + return nil +} + +type integer int32 // for the integer-valued fields of Schema + +func (ip *integer) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + // nothing to do + return nil + } + // If there is a decimal point, src is a floating-point number. + var i int64 + if bytes.ContainsRune(data, '.') { + var f float64 + if err := json.Unmarshal(data, &f); err != nil { + return errors.New("not a number") + } + i = int64(f) + if float64(i) != f { + return errors.New("not an integer value") + } + } else { + if err := json.Unmarshal(data, &i); err != nil { + return errors.New("cannot be unmarshaled into an int") + } + } + // Ensure behavior is the same on both 32-bit and 64-bit systems. + if i < math.MinInt32 || i > math.MaxInt32 { + return errors.New("integer is out of range") + } + *ip = integer(i) + return nil +} + +// Ptr returns a pointer to a new variable whose value is x. 
+func Ptr[T any](x T) *T { return &x } + +// every applies f preorder to every schema under s including s. +// The second argument to f is the path to the schema appended to the argument path. +// It stops when f returns false. +func (s *Schema) every(f func(*Schema) bool) bool { + return f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) +} + +// everyChild reports whether f is true for every immediate child schema of s. +func (s *Schema) everyChild(f func(*Schema) bool) bool { + v := reflect.ValueOf(s) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. A nil is valid: it just means the field isn't present. + c := fv.Interface().(*Schema) + if c != nil && !f(c) { + return false + } + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + for _, c := range slice { + if !f(c) { + return false + } + } + + case schemaMapType: + // Sort keys for determinism. + m := fv.Interface().(map[string]*Schema) + for _, k := range slices.Sorted(maps.Keys(m)) { + if !f(m[k]) { + return false + } + } + } + } + return true +} + +// all wraps every in an iterator. +func (s *Schema) all() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.every(yield) } +} + +// children wraps everyChild in an iterator. 
+func (s *Schema) children() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.everyChild(yield) } +} + +var ( + schemaType = reflect.TypeFor[*Schema]() + schemaSliceType = reflect.TypeFor[[]*Schema]() + schemaMapType = reflect.TypeFor[map[string]*Schema]() +) + +type structFieldInfo struct { + sf reflect.StructField + jsonName string +} + +var ( + // the visible fields of Schema that have a JSON name, sorted by that name + schemaFieldInfos []structFieldInfo + // map from JSON name to field + schemaFieldMap = map[string]reflect.StructField{} +) + +func init() { + for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { + info := fieldJSONInfo(sf) + if !info.omit { + schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.name}) + } + } + slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { + return cmp.Compare(i1.jsonName, i2.jsonName) + }) + for _, info := range schemaFieldInfos { + schemaFieldMap[info.jsonName] = info.sf + } +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/util.go b/vendor/github.com/google/jsonschema-go/jsonschema/util.go new file mode 100644 index 0000000000..5cfa27dc64 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/util.go @@ -0,0 +1,463 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "bytes" + "cmp" + "encoding/binary" + "encoding/json" + "fmt" + "hash/maphash" + "math" + "math/big" + "reflect" + "slices" + "strings" + "sync" +) + +// Equal reports whether two Go values representing JSON values are equal according +// to the JSON Schema spec. +// The values must not contain cycles. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-4.2.2. +// It behaves like reflect.DeepEqual, except that numbers are compared according +// to mathematical equality. 
+func Equal(x, y any) bool { + return equalValue(reflect.ValueOf(x), reflect.ValueOf(y)) +} + +func equalValue(x, y reflect.Value) bool { + // Copied from src/reflect/deepequal.go, omitting the visited check (because JSON + // values are trees). + if !x.IsValid() || !y.IsValid() { + return x.IsValid() == y.IsValid() + } + + // Treat numbers specially. + rx, ok1 := jsonNumber(x) + ry, ok2 := jsonNumber(y) + if ok1 && ok2 { + return rx.Cmp(ry) == 0 + } + if x.Kind() != y.Kind() { + return false + } + switch x.Kind() { + case reflect.Array: + if x.Len() != y.Len() { + return false + } + for i := range x.Len() { + if !equalValue(x.Index(i), y.Index(i)) { + return false + } + } + return true + case reflect.Slice: + if x.IsNil() != y.IsNil() { + return false + } + if x.Len() != y.Len() { + return false + } + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + // Special case for []byte, which is common. + if x.Type().Elem().Kind() == reflect.Uint8 && x.Type() == y.Type() { + return bytes.Equal(x.Bytes(), y.Bytes()) + } + for i := range x.Len() { + if !equalValue(x.Index(i), y.Index(i)) { + return false + } + } + return true + case reflect.Interface: + if x.IsNil() || y.IsNil() { + return x.IsNil() == y.IsNil() + } + return equalValue(x.Elem(), y.Elem()) + case reflect.Pointer: + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + return equalValue(x.Elem(), y.Elem()) + case reflect.Struct: + t := x.Type() + if t != y.Type() { + return false + } + for i := range t.NumField() { + sf := t.Field(i) + if !sf.IsExported() { + continue + } + if !equalValue(x.FieldByIndex(sf.Index), y.FieldByIndex(sf.Index)) { + return false + } + } + return true + case reflect.Map: + if x.IsNil() != y.IsNil() { + return false + } + if x.Len() != y.Len() { + return false + } + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + iter := x.MapRange() + for iter.Next() { + vx := iter.Value() + vy := y.MapIndex(iter.Key()) + if !vy.IsValid() || !equalValue(vx, vy) 
{ + return false + } + } + return true + case reflect.Func: + if x.Type() != y.Type() { + return false + } + if x.IsNil() && y.IsNil() { + return true + } + panic("cannot compare functions") + case reflect.String: + return x.String() == y.String() + case reflect.Bool: + return x.Bool() == y.Bool() + // Ints, uints and floats handled in jsonNumber, at top of function. + default: + panic(fmt.Sprintf("unsupported kind: %s", x.Kind())) + } +} + +// hashValue adds v to the data hashed by h. v must not have cycles. +// hashValue panics if the value contains functions or channels, or maps whose +// key type is not string. +// It ignores unexported fields of structs. +// Calls to hashValue with the equal values (in the sense +// of [Equal]) result in the same sequence of values written to the hash. +func hashValue(h *maphash.Hash, v reflect.Value) { + // TODO: replace writes of basic types with WriteComparable in 1.24. + + writeUint := func(u uint64) { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], u) + h.Write(buf[:]) + } + + var write func(reflect.Value) + write = func(v reflect.Value) { + if r, ok := jsonNumber(v); ok { + // We want 1.0 and 1 to hash the same. + // big.Rats are always normalized, so they will be. + // We could do this more efficiently by handling the int and float cases + // separately, but that's premature. + writeUint(uint64(r.Sign() + 1)) + h.Write(r.Num().Bytes()) + h.Write(r.Denom().Bytes()) + return + } + switch v.Kind() { + case reflect.Invalid: + h.WriteByte(0) + case reflect.String: + h.WriteString(v.String()) + case reflect.Bool: + if v.Bool() { + h.WriteByte(1) + } else { + h.WriteByte(0) + } + case reflect.Complex64, reflect.Complex128: + c := v.Complex() + writeUint(math.Float64bits(real(c))) + writeUint(math.Float64bits(imag(c))) + case reflect.Array, reflect.Slice: + // Although we could treat []byte more efficiently, + // JSON values are unlikely to contain them. 
+ writeUint(uint64(v.Len())) + for i := range v.Len() { + write(v.Index(i)) + } + case reflect.Interface, reflect.Pointer: + write(v.Elem()) + case reflect.Struct: + t := v.Type() + for i := range t.NumField() { + if sf := t.Field(i); sf.IsExported() { + write(v.FieldByIndex(sf.Index)) + } + } + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + panic("map with non-string key") + } + // Sort the keys so the hash is deterministic. + keys := v.MapKeys() + // Write the length. That distinguishes between, say, two consecutive + // maps with disjoint keys from one map that has the items of both. + writeUint(uint64(len(keys))) + slices.SortFunc(keys, func(x, y reflect.Value) int { return cmp.Compare(x.String(), y.String()) }) + for _, k := range keys { + write(k) + write(v.MapIndex(k)) + } + // Ints, uints and floats handled in jsonNumber, at top of function. + default: + panic(fmt.Sprintf("unsupported kind: %s", v.Kind())) + } + } + + write(v) +} + +// jsonNumber converts a numeric value or a json.Number to a [big.Rat]. +// If v is not a number, it returns nil, false. +func jsonNumber(v reflect.Value) (*big.Rat, bool) { + r := new(big.Rat) + switch { + case !v.IsValid(): + return nil, false + case v.CanInt(): + r.SetInt64(v.Int()) + case v.CanUint(): + r.SetUint64(v.Uint()) + case v.CanFloat(): + r.SetFloat64(v.Float()) + default: + jn, ok := v.Interface().(json.Number) + if !ok { + return nil, false + } + if _, ok := r.SetString(jn.String()); !ok { + // This can fail in rare cases; for example, "1e9999999". + // That is a valid JSON number, since the spec puts no limit on the size + // of the exponent. + return nil, false + } + } + return r, true +} + +// jsonType returns a string describing the type of the JSON value, +// as described in the JSON Schema specification: +// https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1. +// It returns "", false if the value is not valid JSON. 
+func jsonType(v reflect.Value) (string, bool) { + if !v.IsValid() { + // Not v.IsNil(): a nil []any is still a JSON array. + return "null", true + } + if v.CanInt() || v.CanUint() { + return "integer", true + } + if v.CanFloat() { + if _, f := math.Modf(v.Float()); f == 0 { + return "integer", true + } + return "number", true + } + switch v.Kind() { + case reflect.Bool: + return "boolean", true + case reflect.String: + return "string", true + case reflect.Slice, reflect.Array: + return "array", true + case reflect.Map, reflect.Struct: + return "object", true + default: + return "", false + } +} + +func assert(cond bool, msg string) { + if !cond { + panic("assertion failed: " + msg) + } +} + +// marshalStructWithMap marshals its first argument to JSON, treating the field named +// mapField as an embedded map. The first argument must be a pointer to +// a struct. The underlying type of mapField must be a map[string]any, and it must have +// a "-" json tag, meaning it will not be marshaled. +// +// For example, given this struct: +// +// type S struct { +// A int +// Extra map[string] any `json:"-"` +// } +// +// and this value: +// +// s := S{A: 1, Extra: map[string]any{"B": 2}} +// +// the call marshalJSONWithMap(s, "Extra") would return +// +// {"A": 1, "B": 2} +// +// It is an error if the map contains the same key as another struct field's +// JSON name. +// +// marshalStructWithMap calls json.Marshal on a value of type T, so T must not +// have a MarshalJSON method that calls this function, on pain of infinite regress. +// +// Note that there is a similar function in mcp/util.go, but they are not the same. +// Here the function requires `-` json tag, does not clear the mapField map, +// and handles embedded struct due to the implementation of jsonNames in this package. +// +// TODO: avoid this restriction on T by forcing it to marshal in a default way. +// See https://go.dev/play/p/EgXKJHxEx_R. 
+func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { + // Marshal the struct and the map separately, and concatenate the bytes. + // This strategy is dramatically less complicated than + // constructing a synthetic struct or map with the combined keys. + if s == nil { + return []byte("null"), nil + } + s2 := *s + vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) + mapVal := vMapField.Interface().(map[string]any) + + // Check for duplicates. + names := jsonNames(reflect.TypeFor[T]()) + for key := range mapVal { + if names[key] { + return nil, fmt.Errorf("map key %q duplicates struct field", key) + } + } + + structBytes, err := json.Marshal(s2) + if err != nil { + return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) + } + if len(mapVal) == 0 { + return structBytes, nil + } + mapBytes, err := json.Marshal(mapVal) + if err != nil { + return nil, err + } + if len(structBytes) == 2 { // must be "{}" + return mapBytes, nil + } + // "{X}" + "{Y}" => "{X,Y}" + res := append(structBytes[:len(structBytes)-1], ',') + res = append(res, mapBytes[1:]...) + return res, nil +} + +// unmarshalStructWithMap is the inverse of marshalStructWithMap. +// T has the same restrictions as in that function. +// +// Note that there is a similar function in mcp/util.go, but they are not the same. +// Here jsonNames also returns fields from embedded structs, hence this function +// handles embedded structs as well. +func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { + // Unmarshal into the struct, ignoring unknown fields. + if err := json.Unmarshal(data, v); err != nil { + return err + } + // Unmarshal into the map. + m := map[string]any{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + // Delete from the map the fields of the struct. 
+ for n := range jsonNames(reflect.TypeFor[T]()) { + delete(m, n) + } + if len(m) != 0 { + reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) + } + return nil +} + +var jsonNamesMap sync.Map // from reflect.Type to map[string]bool + +// jsonNames returns the set of JSON object keys that t will marshal into, +// including fields from embedded structs in t. +// t must be a struct type. +// +// Note that there is a similar function in mcp/util.go, but they are not the same +// Here the function recurses over embedded structs and includes fields from them. +func jsonNames(t reflect.Type) map[string]bool { + // Lock not necessary: at worst we'll duplicate work. + if val, ok := jsonNamesMap.Load(t); ok { + return val.(map[string]bool) + } + m := map[string]bool{} + for i := range t.NumField() { + field := t.Field(i) + // handle embedded structs + if field.Anonymous { + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + for n := range jsonNames(fieldType) { + m[n] = true + } + continue + } + info := fieldJSONInfo(field) + if !info.omit { + m[info.name] = true + } + } + jsonNamesMap.Store(t, m) + return m +} + +type jsonInfo struct { + omit bool // unexported or first tag element is "-" + name string // Go field name or first tag element. Empty if omit is true. + settings map[string]bool // "omitempty", "omitzero", etc. +} + +// fieldJSONInfo reports information about how encoding/json +// handles the given struct field. +// If the field is unexported, jsonInfo.omit is true and no other jsonInfo field +// is populated. +// If the field is exported and has no tag, then name is the field's name and all +// other fields are false. +// Otherwise, the information is obtained from the tag. 
+func fieldJSONInfo(f reflect.StructField) jsonInfo { + if !f.IsExported() { + return jsonInfo{omit: true} + } + info := jsonInfo{name: f.Name} + if tag, ok := f.Tag.Lookup("json"); ok { + name, rest, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return jsonInfo{omit: true} + } + if name != "" { + info.name = name + } + if len(rest) > 0 { + info.settings = map[string]bool{} + for _, s := range strings.Split(rest, ",") { + info.settings[s] = true + } + } + } + return info +} + +// wrapf wraps *errp with the given formatted message if *errp is not nil. +func wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/validate.go b/vendor/github.com/google/jsonschema-go/jsonschema/validate.go new file mode 100644 index 0000000000..b895bbd41b --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/validate.go @@ -0,0 +1,789 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "encoding/json" + "errors" + "fmt" + "hash/maphash" + "iter" + "math" + "math/big" + "reflect" + "slices" + "strings" + "sync" + "unicode/utf8" +) + +// The value of the "$schema" keyword for the version that we can validate. +const draft202012 = "https://json-schema.org/draft/2020-12/schema" + +// Validate validates the instance, which must be a JSON value, against the schema. +// It returns nil if validation is successful or an error if it is not. +// If the schema type is "object", instance can be a map[string]any or a struct. 
+func (rs *Resolved) Validate(instance any) error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + return st.validate(reflect.ValueOf(instance), st.rs.root, nil) +} + +// validateDefaults walks the schema tree. If it finds a default, it validates it +// against the schema containing it. +// +// TODO(jba): account for dynamic refs. This algorithm simple-mindedly +// treats each schema with a default as its own root. +func (rs *Resolved) validateDefaults() error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + for s := range rs.root.all() { + // We checked for nil schemas in [Schema.Resolve]. + assert(s != nil, "nil schema") + if s.DynamicRef != "" { + return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", rs.schemaString(s)) + } + if s.Default != nil { + var d any + if err := json.Unmarshal(s.Default, &d); err != nil { + return fmt.Errorf("unmarshaling default value of schema %s: %w", rs.schemaString(s), err) + } + if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { + return err + } + } + } + return nil +} + +// state is the state of single call to ResolvedSchema.Validate. +type state struct { + rs *Resolved + // stack holds the schemas from recursive calls to validate. + // These are the "dynamic scopes" used to resolve dynamic references. + // https://json-schema.org/draft/2020-12/json-schema-core#scopes + stack []*Schema +} + +// validate validates the reflected value of the instance. +func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { + defer wrapf(&err, "validating %s", st.rs.schemaString(schema)) + + // Maintain a stack for dynamic schema resolution. 
+ st.stack = append(st.stack, schema) // push + defer func() { + st.stack = st.stack[:len(st.stack)-1] // pop + }() + + // We checked for nil schemas in [Schema.Resolve]. + assert(schema != nil, "nil schema") + + // Step through interfaces and pointers. + for instance.Kind() == reflect.Pointer || instance.Kind() == reflect.Interface { + instance = instance.Elem() + } + + schemaInfo := st.rs.resolvedInfos[schema] + + // type: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1 + if schema.Type != "" || schema.Types != nil { + gotType, ok := jsonType(instance) + if !ok { + return fmt.Errorf("type: %v of type %[1]T is not a valid JSON value", instance) + } + if schema.Type != "" { + // "number" subsumes integers + if !(gotType == schema.Type || + gotType == "integer" && schema.Type == "number") { + return fmt.Errorf("type: %v has type %q, want %q", instance, gotType, schema.Type) + } + } else { + if !(slices.Contains(schema.Types, gotType) || (gotType == "integer" && slices.Contains(schema.Types, "number"))) { + return fmt.Errorf("type: %v has type %q, want one of %q", + instance, gotType, strings.Join(schema.Types, ", ")) + } + } + } + // enum: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.2 + if schema.Enum != nil { + ok := false + for _, e := range schema.Enum { + if equalValue(reflect.ValueOf(e), instance) { + ok = true + break + } + } + if !ok { + return fmt.Errorf("enum: %v does not equal any of: %v", instance, schema.Enum) + } + } + + // const: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.3 + if schema.Const != nil { + if !equalValue(reflect.ValueOf(*schema.Const), instance) { + return fmt.Errorf("const: %v does not equal %v", instance, *schema.Const) + } + } + + // numbers: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.2 + if schema.MultipleOf != nil || schema.Minimum != nil || schema.Maximum 
!= nil || schema.ExclusiveMinimum != nil || schema.ExclusiveMaximum != nil { + n, ok := jsonNumber(instance) + if ok { // these keywords don't apply to non-numbers + if schema.MultipleOf != nil { + // TODO: validate MultipleOf as non-zero. + // The test suite assumes floats. + nf, _ := n.Float64() // don't care if it's exact or not + if _, f := math.Modf(nf / *schema.MultipleOf); f != 0 { + return fmt.Errorf("multipleOf: %s is not a multiple of %f", n, *schema.MultipleOf) + } + } + + m := new(big.Rat) // reuse for all of the following + cmp := func(f float64) int { return n.Cmp(m.SetFloat64(f)) } + + if schema.Minimum != nil && cmp(*schema.Minimum) < 0 { + return fmt.Errorf("minimum: %s is less than %f", n, *schema.Minimum) + } + if schema.Maximum != nil && cmp(*schema.Maximum) > 0 { + return fmt.Errorf("maximum: %s is greater than %f", n, *schema.Maximum) + } + if schema.ExclusiveMinimum != nil && cmp(*schema.ExclusiveMinimum) <= 0 { + return fmt.Errorf("exclusiveMinimum: %s is less than or equal to %f", n, *schema.ExclusiveMinimum) + } + if schema.ExclusiveMaximum != nil && cmp(*schema.ExclusiveMaximum) >= 0 { + return fmt.Errorf("exclusiveMaximum: %s is greater than or equal to %f", n, *schema.ExclusiveMaximum) + } + } + } + + // strings: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.3 + if instance.Kind() == reflect.String && (schema.MinLength != nil || schema.MaxLength != nil || schema.Pattern != "") { + str := instance.String() + n := utf8.RuneCountInString(str) + if schema.MinLength != nil { + if m := *schema.MinLength; n < m { + return fmt.Errorf("minLength: %q contains %d Unicode code points, fewer than %d", str, n, m) + } + } + if schema.MaxLength != nil { + if m := *schema.MaxLength; n > m { + return fmt.Errorf("maxLength: %q contains %d Unicode code points, more than %d", str, n, m) + } + } + + if schema.Pattern != "" && !schemaInfo.pattern.MatchString(str) { + return fmt.Errorf("pattern: %q does not match 
regular expression %q", str, schema.Pattern) + } + } + + var anns annotations // all the annotations for this call and child calls + + // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 + if schema.Ref != "" { + if err := st.validate(instance, schemaInfo.resolvedRef, &anns); err != nil { + return err + } + } + + // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 + if schema.DynamicRef != "" { + // The ref behaves lexically or dynamically, but not both. + assert((schemaInfo.resolvedDynamicRef == nil) != (schemaInfo.dynamicRefAnchor == ""), + "DynamicRef not resolved properly") + if schemaInfo.resolvedDynamicRef != nil { + // Same as $ref. + if err := st.validate(instance, schemaInfo.resolvedDynamicRef, &anns); err != nil { + return err + } + } else { + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. 
+ var dynamicSchema *Schema + for _, s := range st.stack { + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[schemaInfo.dynamicRefAnchor] + if ok && info.dynamic { + dynamicSchema = info.schema + break + } + } + if dynamicSchema == nil { + return fmt.Errorf("missing dynamic anchor %q", schemaInfo.dynamicRefAnchor) + } + if err := st.validate(instance, dynamicSchema, &anns); err != nil { + return err + } + } + } + + // logic + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 + // These must happen before arrays and objects because if they evaluate an item or property, + // then the unevaluatedItems/Properties schemas don't apply to it. + // See https://json-schema.org/draft/2020-12/json-schema-core#section-11.2, paragraph 4. + // + // If any of these fail, then validation fails, even if there is an unevaluatedXXX + // keyword in the schema. The spec is unclear about this, but that is the intention. + + valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns) == nil } + + if schema.AllOf != nil { + for _, ss := range schema.AllOf { + if err := st.validate(instance, ss, &anns); err != nil { + return err + } + } + } + if schema.AnyOf != nil { + // We must visit them all, to collect annotations. + ok := false + for _, ss := range schema.AnyOf { + if valid(ss, &anns) { + ok = true + } + } + if !ok { + return fmt.Errorf("anyOf: did not validate against any of %v", schema.AnyOf) + } + } + if schema.OneOf != nil { + // Exactly one. + var okSchema *Schema + for _, ss := range schema.OneOf { + if valid(ss, &anns) { + if okSchema != nil { + return fmt.Errorf("oneOf: validated against both %v and %v", okSchema, ss) + } + okSchema = ss + } + } + if okSchema == nil { + return fmt.Errorf("oneOf: did not validate against any of %v", schema.OneOf) + } + } + if schema.Not != nil { + // Ignore annotations from "not". 
+ if valid(schema.Not, nil) { + return fmt.Errorf("not: validated against %v", schema.Not) + } + } + if schema.If != nil { + var ss *Schema + if valid(schema.If, &anns) { + ss = schema.Then + } else { + ss = schema.Else + } + if ss != nil { + if err := st.validate(instance, ss, &anns); err != nil { + return err + } + } + } + + // arrays + // TODO(jba): consider arrays of structs. + if instance.Kind() == reflect.Array || instance.Kind() == reflect.Slice { + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.1 + // This validate call doesn't collect annotations for the items of the instance; they are separate + // instances in their own right. + // TODO(jba): if the test suite doesn't cover this case, add a test. For example, nested arrays. + for i, ischema := range schema.PrefixItems { + if i >= instance.Len() { + break // shorter is OK + } + if err := st.validate(instance.Index(i), ischema, nil); err != nil { + return err + } + } + anns.noteEndIndex(min(len(schema.PrefixItems), instance.Len())) + + if schema.Items != nil { + for i := len(schema.PrefixItems); i < instance.Len(); i++ { + if err := st.validate(instance.Index(i), schema.Items, nil); err != nil { + return err + } + } + // Note that all the items in this array have been validated. + anns.allItems = true + } + + nContains := 0 + if schema.Contains != nil { + for i := range instance.Len() { + if err := st.validate(instance.Index(i), schema.Contains, nil); err == nil { + nContains++ + anns.noteIndex(i) + } + } + if nContains == 0 && (schema.MinContains == nil || *schema.MinContains > 0) { + return fmt.Errorf("contains: %s does not have an item matching %s", instance, schema.Contains) + } + } + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.4 + // TODO(jba): check that these next four keywords' values are integers. 
+ if schema.MinContains != nil && schema.Contains != nil { + if m := *schema.MinContains; nContains < m { + return fmt.Errorf("minContains: contains validated %d items, less than %d", nContains, m) + } + } + if schema.MaxContains != nil && schema.Contains != nil { + if m := *schema.MaxContains; nContains > m { + return fmt.Errorf("maxContains: contains validated %d items, greater than %d", nContains, m) + } + } + if schema.MinItems != nil { + if m := *schema.MinItems; instance.Len() < m { + return fmt.Errorf("minItems: array length %d is less than %d", instance.Len(), m) + } + } + if schema.MaxItems != nil { + if m := *schema.MaxItems; instance.Len() > m { + return fmt.Errorf("maxItems: array length %d is greater than %d", instance.Len(), m) + } + } + if schema.UniqueItems { + if instance.Len() > 1 { + // Hash each item and compare the hashes. + // If two hashes differ, the items differ. + // If two hashes are the same, compare the collisions for equality. + // (The same logic as hash table lookup.) + // TODO(jba): Use container/hash.Map when it becomes available (https://go.dev/issue/69559), + hashes := map[uint64][]int{} // from hash to indices + seed := maphash.MakeSeed() + for i := range instance.Len() { + item := instance.Index(i) + var h maphash.Hash + h.SetSeed(seed) + hashValue(&h, item) + hv := h.Sum64() + if sames := hashes[hv]; len(sames) > 0 { + for _, j := range sames { + if equalValue(item, instance.Index(j)) { + return fmt.Errorf("uniqueItems: array items %d and %d are equal", i, j) + } + } + } + hashes[hv] = append(hashes[hv], i) + } + } + } + + // https://json-schema.org/draft/2020-12/json-schema-core#section-11.2 + if schema.UnevaluatedItems != nil && !anns.allItems { + // Apply this subschema to all items in the array that haven't been successfully validated. + // That includes validations by subschemas on the same instance, like allOf. 
+ for i := anns.endIndex; i < instance.Len(); i++ { + if !anns.evaluatedIndexes[i] { + if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil); err != nil { + return err + } + } + } + anns.allItems = true + } + } + + // objects + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.2 + // Validating structs is problematic. See https://github.com/google/jsonschema-go/issues/23. + if instance.Kind() == reflect.Struct { + return errors.New("cannot validate against a struct; see https://github.com/google/jsonschema-go/issues/23 for details") + } + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + // Track the evaluated properties for just this schema, to support additionalProperties. + // If we used anns here, then we'd be including properties evaluated in subschemas + // from allOf, etc., which additionalProperties shouldn't observe. + evalProps := map[string]bool{} + for prop, subschema := range schema.Properties { + val := property(instance, prop) + if !val.IsValid() { + // It's OK if the instance doesn't have the property. + continue + } + // If the instance is a struct and an optional property has the zero + // value, then we could interpret it as present or missing. Be generous: + // assume it's missing, and thus always validates successfully. + if instance.Kind() == reflect.Struct && val.IsZero() && !schemaInfo.isRequired[prop] { + continue + } + if err := st.validate(val, subschema, nil); err != nil { + return err + } + evalProps[prop] = true + } + if len(schema.PatternProperties) > 0 { + for prop, val := range properties(instance) { + // Check every matching pattern. 
+ for re, schema := range schemaInfo.patternProperties { + if re.MatchString(prop) { + if err := st.validate(val, schema, nil); err != nil { + return err + } + evalProps[prop] = true + } + } + } + } + if schema.AdditionalProperties != nil { + // Special case for a better error message when additional properties is + // 'falsy' + // + // If additionalProperties is {"not":{}} (which is how we + // unmarshal "false"), we can produce a better error message that + // summarizes all the extra properties. Otherwise, we fall back to the + // default validation. + // + // Note: this is much faster than comparing with falseSchema using Equal. + isFalsy := schema.AdditionalProperties.Not != nil && reflect.ValueOf(*schema.AdditionalProperties.Not).IsZero() + if isFalsy { + var disallowed []string + for prop := range properties(instance) { + if !evalProps[prop] { + disallowed = append(disallowed, prop) + } + } + if len(disallowed) > 0 { + return fmt.Errorf("unexpected additional properties %q", disallowed) + } + } else { + // Apply to all properties not handled above. + for prop, val := range properties(instance) { + if !evalProps[prop] { + if err := st.validate(val, schema.AdditionalProperties, nil); err != nil { + return err + } + evalProps[prop] = true + } + } + } + } + anns.noteProperties(evalProps) + if schema.PropertyNames != nil { + // Note: properties unnecessarily fetches each value. We could define a propertyNames function + // if performance ever matters. 
+ for prop := range properties(instance) { + if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil); err != nil { + return err + } + } + } + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 + var min, max int + if schema.MinProperties != nil || schema.MaxProperties != nil { + min, max = numPropertiesBounds(instance, schemaInfo.isRequired) + } + if schema.MinProperties != nil { + if n, m := max, *schema.MinProperties; n < m { + return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) + } + } + if schema.MaxProperties != nil { + if n, m := min, *schema.MaxProperties; n > m { + return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) + } + } + + hasProperty := func(prop string) bool { + return property(instance, prop).IsValid() + } + + missingProperties := func(props []string) []string { + var missing []string + for _, p := range props { + if !hasProperty(p) { + missing = append(missing, p) + } + } + return missing + } + + if schema.Required != nil { + if m := missingProperties(schema.Required); len(m) > 0 { + return fmt.Errorf("required: missing properties: %q", m) + } + } + if schema.DependentRequired != nil { + // "Validation succeeds if, for each name that appears in both the instance + // and as a name within this keyword's value, every item in the corresponding + // array is also the name of a property in the instance." §6.5.4 + for dprop, reqs := range schema.DependentRequired { + if hasProperty(dprop) { + if m := missingProperties(reqs); len(m) > 0 { + return fmt.Errorf("dependentRequired[%q]: missing properties %q", dprop, m) + } + } + } + } + + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2.2.4 + if schema.DependentSchemas != nil { + // This does not collect annotations, although it seems like it should. 
+ for dprop, ss := range schema.DependentSchemas { + if hasProperty(dprop) { + // TODO: include dependentSchemas[dprop] in the errors. + err := st.validate(instance, ss, &anns) + if err != nil { + return err + } + } + } + } + if schema.UnevaluatedProperties != nil && !anns.allProperties { + // This looks a lot like AdditionalProperties, but depends on in-place keywords like allOf + // in addition to sibling keywords. + for prop, val := range properties(instance) { + if !anns.evaluatedProperties[prop] { + if err := st.validate(val, schema.UnevaluatedProperties, nil); err != nil { + return err + } + } + } + // The spec says the annotation should be the set of evaluated properties, but we can optimize + // by setting a single boolean, since after this succeeds all properties will be validated. + // See https://json-schema.slack.com/archives/CT7FF623C/p1745592564381459. + anns.allProperties = true + } + } + + if callerAnns != nil { + // Our caller wants to know what we've validated. + callerAnns.merge(&anns) + } + return nil +} + +// resolveDynamicRef returns the schema referred to by the argument schema's +// $dynamicRef value. +// It returns an error if the dynamic reference has no referent. +// If there is no $dynamicRef, resolveDynamicRef returns nil, nil. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2. +func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { + if schema.DynamicRef == "" { + return nil, nil + } + info := st.rs.resolvedInfos[schema] + // The ref behaves lexically or dynamically, but not both. + assert((info.resolvedDynamicRef == nil) != (info.dynamicRefAnchor == ""), + "DynamicRef not statically resolved properly") + if r := info.resolvedDynamicRef; r != nil { + // Same as $ref. + return r, nil + } + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. 
This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. + for _, s := range st.stack { + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[info.dynamicRefAnchor] + if ok && info.dynamic { + return info.schema, nil + } + } + return nil, fmt.Errorf("missing dynamic anchor %q", info.dynamicRefAnchor) +} + +// ApplyDefaults modifies an instance by applying the schema's defaults to it. If +// a schema or sub-schema has a default, then a corresponding zero instance value +// is set to the default. +// +// The JSON Schema specification does not describe how defaults should be interpreted. +// This method honors defaults only on properties, and only those that are not required. +// If the instance is a map and the property is missing, the property is added to +// the map with the default. +// If the instance is a struct, the field corresponding to the property exists, and +// its value is zero, the field is set to the default. +// ApplyDefaults can panic if a default cannot be assigned to a field. +// +// The argument must be a pointer to the instance. +// (In case we decide that top-level defaults are meaningful.) +// +// It is recommended to first call Resolve with a ValidateDefaults option of true, +// then call this method, and lastly call Validate. +func (rs *Resolved) ApplyDefaults(instancep any) error { + // TODO(jba): consider what defaults on top-level or array instances might mean. + // TODO(jba): follow $ref and $dynamicRef + // TODO(jba): apply defaults on sub-schemas to corresponding sub-instances. 
+ st := &state{rs: rs} + return st.applyDefaults(reflect.ValueOf(instancep), rs.root) +} + +// Leave this as a potentially recursive helper function, because we'll surely want +// to apply defaults on sub-schemas someday. +func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { + defer wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) + + schemaInfo := st.rs.resolvedInfos[schema] + instance := instancep.Elem() + if instance.Kind() == reflect.Interface && instance.IsValid() { + // If we unmarshalled into 'any', the default object unmarshalling will be map[string]any. + instance = instance.Elem() + } + if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + } + for prop, subschema := range schema.Properties { + // Ignore defaults on required properties. (A required property shouldn't have a default.) + if schemaInfo.isRequired[prop] { + continue + } + val := property(instance, prop) + switch instance.Kind() { + case reflect.Map: + // If there is a default for this property, and the map key is missing, + // set the map value to the default. + if subschema.Default != nil && !val.IsValid() { + // Create an lvalue, since map values aren't addressable. + lvalue := reflect.New(instance.Type().Elem()) + if err := json.Unmarshal(subschema.Default, lvalue.Interface()); err != nil { + return err + } + instance.SetMapIndex(reflect.ValueOf(prop), lvalue.Elem()) + } + case reflect.Struct: + // If there is a default for this property, and the field exists but is zero, + // set the field to the default. 
+ if subschema.Default != nil && val.IsValid() && val.IsZero() { + if err := json.Unmarshal(subschema.Default, val.Addr().Interface()); err != nil { + return err + } + } + default: + panic(fmt.Sprintf("applyDefaults: property %s: bad value %s of kind %s", + prop, instance, instance.Kind())) + } + } + } + return nil +} + +// property returns the value of the property of v with the given name, or the invalid +// reflect.Value if there is none. +// If v is a map, the property is the value of the map whose key is name. +// If v is a struct, the property is the value of the field with the given name according +// to the encoding/json package (see [jsonName]). +// If v is anything else, property panics. +func property(v reflect.Value, name string) reflect.Value { + switch v.Kind() { + case reflect.Map: + return v.MapIndex(reflect.ValueOf(name)) + case reflect.Struct: + props := structPropertiesOf(v.Type()) + // Ignore nonexistent properties. + if sf, ok := props[name]; ok { + return v.FieldByIndex(sf.Index) + } + return reflect.Value{} + default: + panic(fmt.Sprintf("property(%q): bad value %s of kind %s", name, v, v.Kind())) + } +} + +// properties returns an iterator over the names and values of all properties +// in v, which must be a map or a struct. +// If a struct, zero-valued properties that are marked omitempty or omitzero +// are excluded. 
+func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { + return func(yield func(string, reflect.Value) bool) { + switch v.Kind() { + case reflect.Map: + for k, e := range v.Seq2() { + if !yield(k.String(), e) { + return + } + } + case reflect.Struct: + for name, sf := range structPropertiesOf(v.Type()) { + val := v.FieldByIndex(sf.Index) + if val.IsZero() { + info := fieldJSONInfo(sf) + if info.settings["omitempty"] || info.settings["omitzero"] { + continue + } + } + if !yield(name, val) { + return + } + } + default: + panic(fmt.Sprintf("bad value %s of kind %s", v, v.Kind())) + } + } +} + +// numPropertiesBounds returns bounds on the number of v's properties. +// v must be a map or a struct. +// If v is a map, both bounds are the map's size. +// If v is a struct, the max is the number of struct properties. +// But since we don't know whether a zero value indicates a missing optional property +// or not, be generous and use the number of non-zero properties as the min. +func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) { + switch v.Kind() { + case reflect.Map: + return v.Len(), v.Len() + case reflect.Struct: + sp := structPropertiesOf(v.Type()) + min := 0 + for prop, sf := range sp { + if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] { + min++ + } + } + return min, len(sp) + default: + panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) + } +} + +// A propertyMap is a map from property name to struct field index. +type propertyMap = map[string]reflect.StructField + +var structProperties sync.Map // from reflect.Type to propertyMap + +// structPropertiesOf returns the JSON Schema properties for the struct type t. +// The caller must not mutate the result. +func structPropertiesOf(t reflect.Type) propertyMap { + // Mutex not necessary: at worst we'll recompute the same value. 
+ if props, ok := structProperties.Load(t); ok { + return props.(propertyMap) + } + props := map[string]reflect.StructField{} + for _, sf := range reflect.VisibleFields(t) { + if sf.Anonymous { + continue + } + info := fieldJSONInfo(sf) + if !info.omit { + props[info.name] = sf + } + } + structProperties.Store(t, props) + return props +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE new file mode 100644 index 0000000000..508be92666 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Go MCP SDK Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go new file mode 100644 index 0000000000..0eea1d873c --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go @@ -0,0 +1,120 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "errors" + "net/http" + "slices" + "strings" + "time" +) + +// TokenInfo holds information from a bearer token. +type TokenInfo struct { + Scopes []string + Expiration time.Time + // TODO: add standard JWT fields + Extra map[string]any +} + +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// The error that a TokenVerifier should return for OAuth-specific protocol errors. +var ErrOAuth = errors.New("oauth error") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. +// The HTTP request is provided in case verifying the token involves checking it. +type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) + +// RequireBearerTokenOptions are options for [RequireBearerToken]. +type RequireBearerTokenOptions struct { + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. + ResourceMetadataURL string + // The required scopes. + Scopes []string +} + +type tokenInfoKey struct{} + +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. 
+func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + +// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. +// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. +// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header +// is populated to enable [protected resource metadata]. +// +// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 +func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { + // Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts. + + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenInfo, errmsg, code := verify(r, verifier, opts) + if code != 0 { + if code == http.StatusUnauthorized || code == http.StatusForbidden { + if opts != nil && opts.ResourceMetadataURL != "" { + w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL) + } + } + http.Error(w, errmsg, code) + return + } + r = r.WithContext(context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo)) + handler.ServeHTTP(w, r) + }) + } +} + +func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) { + // Extract bearer token. + authHeader := req.Header.Get("Authorization") + fields := strings.Fields(authHeader) + if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { + return nil, "no bearer token", http.StatusUnauthorized + } + + // Verify the token and get information from it. 
+ tokenInfo, err := verifier(req.Context(), fields[1], req) + if err != nil { + if errors.Is(err, ErrInvalidToken) { + return nil, err.Error(), http.StatusUnauthorized + } + if errors.Is(err, ErrOAuth) { + return nil, err.Error(), http.StatusBadRequest + } + return nil, err.Error(), http.StatusInternalServerError + } + + // Check scopes. All must be present. + if opts != nil { + // Note: quadratic, but N is small. + for _, s := range opts.Scopes { + if !slices.Contains(tokenInfo.Scopes, s) { + return nil, "insufficient scope", http.StatusForbidden + } + } + } + + // Check expiration. + if tokenInfo.Expiration.IsZero() { + return nil, "token missing expiration", http.StatusUnauthorized + } + if tokenInfo.Expiration.Before(time.Now()) { + return nil, "token expired", http.StatusUnauthorized + } + return tokenInfo, "", 0 +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go new file mode 100644 index 0000000000..acadc51be3 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go @@ -0,0 +1,123 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. 
+type HTTPTransport struct { + handler OAuthHandler + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. 
+ req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go new file mode 100644 index 0000000000..5549ee1c9e --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go @@ -0,0 +1,825 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +// Binder builds a connection configuration. +// This may be used in servers to generate a new configuration per connection. 
+// ConnectionOptions itself implements Binder returning itself unmodified, to +// allow for the simple cases where no per connection information is needed. +type Binder interface { + // Bind returns the ConnectionOptions to use when establishing the passed-in + // Connection. + // + // The connection is not ready to use when Bind is called, + // but Bind may close it without reading or writing to it. + Bind(context.Context, *Connection) ConnectionOptions +} + +// A BinderFunc implements the Binder interface for a standalone Bind function. +type BinderFunc func(context.Context, *Connection) ConnectionOptions + +func (f BinderFunc) Bind(ctx context.Context, c *Connection) ConnectionOptions { + return f(ctx, c) +} + +var _ Binder = BinderFunc(nil) + +// ConnectionOptions holds the options for new connections. +type ConnectionOptions struct { + // Framer allows control over the message framing and encoding. + // If nil, HeaderFramer will be used. + Framer Framer + // Preempter allows registration of a pre-queue message handler. + // If nil, no messages will be preempted. + Preempter Preempter + // Handler is used as the queued message handler for inbound messages. + // If nil, all responses will be ErrNotHandled. + Handler Handler + // OnInternalError, if non-nil, is called with any internal errors that occur + // while serving the connection, such as protocol errors or invariant + // violations. (If nil, internal errors result in panics.) + OnInternalError func(error) +} + +// Connection manages the jsonrpc2 protocol, connecting responses back to their +// calls. +// Connection is bidirectional; it does not have a designated server or client +// end. 
+type Connection struct { + seq int64 // must only be accessed using atomic operations + + stateMu sync.Mutex + state inFlightState // accessed only in updateInFlight + done chan struct{} // closed (under stateMu) when state.closed is true and all goroutines have completed + + writer Writer + handler Handler + + onInternalError func(error) + onDone func() +} + +// inFlightState records the state of the incoming and outgoing calls on a +// Connection. +type inFlightState struct { + connClosing bool // true when the Connection's Close method has been called + reading bool // true while the readIncoming goroutine is running + readErr error // non-nil when the readIncoming goroutine exits (typically io.EOF) + writeErr error // non-nil if a call to the Writer has failed with a non-canceled Context + + // closer shuts down and cleans up the Reader and Writer state, ideally + // interrupting any Read or Write call that is currently blocked. It is closed + // when the state is idle and one of: connClosing is true, readErr is non-nil, + // or writeErr is non-nil. + // + // After the closer has been invoked, the closer field is set to nil + // and the closeErr field is simultaneously set to its result. + closer io.Closer + closeErr error // error returned from closer.Close + + outgoingCalls map[ID]*AsyncCall // calls only + outgoingNotifications int // # of notifications awaiting "write" + + // incoming stores the total number of incoming calls and notifications + // that have not yet written or processed a result. + incoming int + + incomingByID map[ID]*incomingRequest // calls only + + // handlerQueue stores the backlog of calls and notifications that were not + // already handled by a preempter. + // The queue does not include the request currently being handled (if any). 
+ handlerQueue []*incomingRequest + handlerRunning bool +} + +// updateInFlight locks the state of the connection's in-flight requests, allows +// f to mutate that state, and closes the connection if it is idle and either +// is closing or has a read or write error. +func (c *Connection) updateInFlight(f func(*inFlightState)) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + + s := &c.state + + f(s) + + select { + case <-c.done: + // The connection was already completely done at the start of this call to + // updateInFlight, so it must remain so. (The call to f should have noticed + // that and avoided making any updates that would cause the state to be + // non-idle.) + if !s.idle() { + panic("jsonrpc2: updateInFlight transitioned to non-idle when already done") + } + return + default: + } + + if s.idle() && s.shuttingDown(ErrUnknown) != nil { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil // prevent duplicate Close calls + } + if s.reading { + // The readIncoming goroutine is still running. Our call to Close should + // cause it to exit soon, at which point it will make another call to + // updateInFlight, set s.reading to false, and mark the Connection done. + } else { + // The readIncoming goroutine has exited, or never started to begin with. + // Since everything else is idle, we're completely done. + if c.onDone != nil { + c.onDone() + } + close(c.done) + } + } +} + +// idle reports whether the connection is in a state with no pending calls or +// notifications. +// +// If idle returns true, the readIncoming goroutine may still be running, +// but no other goroutines are doing work on behalf of the connection. +func (s *inFlightState) idle() bool { + return len(s.outgoingCalls) == 0 && s.outgoingNotifications == 0 && s.incoming == 0 && !s.handlerRunning +} + +// shuttingDown reports whether the connection is in a state that should +// disallow new (incoming and outgoing) calls. 
It returns either nil or +// an error that is or wraps the provided errClosing. +func (s *inFlightState) shuttingDown(errClosing error) error { + if s.connClosing { + // If Close has been called explicitly, it doesn't matter what state the + // Reader and Writer are in: we shouldn't be starting new work because the + // caller told us not to start new work. + return errClosing + } + if s.readErr != nil { + // If the read side of the connection is broken, we cannot read new call + // requests, and cannot read responses to our outgoing calls. + return fmt.Errorf("%w: %v", errClosing, s.readErr) + } + if s.writeErr != nil { + // If the write side of the connection is broken, we cannot write responses + // for incoming calls, and cannot write requests for outgoing calls. + return fmt.Errorf("%w: %v", errClosing, s.writeErr) + } + return nil +} + +// incomingRequest is used to track an incoming request as it is being handled +type incomingRequest struct { + *Request // the request being processed + ctx context.Context + cancel context.CancelFunc +} + +// Bind returns the options unmodified. +func (o ConnectionOptions) Bind(context.Context, *Connection) ConnectionOptions { + return o +} + +// A ConnectionConfig configures a bidirectional jsonrpc2 connection. +type ConnectionConfig struct { + Reader Reader // required + Writer Writer // required + Closer io.Closer // required + Preempter Preempter // optional + Bind func(*Connection) Handler // required + OnDone func() // optional + OnInternalError func(error) // optional +} + +// NewConnection creates a new [Connection] object and starts processing +// incoming messages. 
+func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { + ctx = notDone{ctx} + + c := &Connection{ + state: inFlightState{closer: cfg.Closer}, + done: make(chan struct{}), + writer: cfg.Writer, + onDone: cfg.OnDone, + onInternalError: cfg.OnInternalError, + } + c.handler = cfg.Bind(c) + c.start(ctx, cfg.Reader, cfg.Preempter) + return c +} + +// bindConnection creates a new connection and runs it. +// +// This is used by the Dial and Serve functions to build the actual connection. +// +// The connection is closed automatically (and its resources cleaned up) when +// the last request has completed after the underlying ReadWriteCloser breaks, +// but it may be stopped earlier by calling Close (for a clean shutdown). +func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Binder, onDone func()) *Connection { + // TODO: Should we create a new event span here? + // This will propagate cancellation from ctx; should it? + ctx := notDone{bindCtx} + + c := &Connection{ + state: inFlightState{closer: rwc}, + done: make(chan struct{}), + onDone: onDone, + } + // It's tempting to set a finalizer on c to verify that the state has gone + // idle when the connection becomes unreachable. Unfortunately, the Binder + // interface makes that unsafe: it allows the Handler to close over the + // Connection, which could create a reference cycle that would cause the + // Connection to become uncollectable. 
+ + options := binder.Bind(bindCtx, c) + framer := options.Framer + if framer == nil { + framer = HeaderFramer() + } + c.handler = options.Handler + if c.handler == nil { + c.handler = defaultHandler{} + } + c.onInternalError = options.OnInternalError + + c.writer = framer.Writer(rwc) + reader := framer.Reader(rwc) + c.start(ctx, reader, options.Preempter) + return c +} + +func (c *Connection) start(ctx context.Context, reader Reader, preempter Preempter) { + c.updateInFlight(func(s *inFlightState) { + select { + case <-c.done: + // Bind already closed the connection; don't start a goroutine to read it. + return + default: + } + + // The goroutine started here will continue until the underlying stream is closed. + // + // (If the Binder closed the Connection already, this should error out and + // return almost immediately.) + s.reading = true + go c.readIncoming(ctx, reader, preempter) + }) +} + +// Notify invokes the target method but does not wait for a response. +// The params will be marshaled to JSON before sending over the wire, and will +// be handed to the method invoked. +func (c *Connection) Notify(ctx context.Context, method string, params any) (err error) { + attempted := false + + defer func() { + if attempted { + c.updateInFlight(func(s *inFlightState) { + s.outgoingNotifications-- + }) + } + }() + + c.updateInFlight(func(s *inFlightState) { + // If the connection is shutting down, allow outgoing notifications only if + // there is at least one call still in flight. The number of calls in flight + // cannot increase once shutdown begins, and allowing outgoing notifications + // may permit notifications that will cancel in-flight calls. 
+ if len(s.outgoingCalls) == 0 && len(s.incomingByID) == 0 { + err = s.shuttingDown(ErrClientClosing) + if err != nil { + return + } + } + s.outgoingNotifications++ + attempted = true + }) + if err != nil { + return err + } + + notify, err := NewNotification(method, params) + if err != nil { + return fmt.Errorf("marshaling notify parameters: %v", err) + } + + return c.write(ctx, notify) +} + +// Call invokes the target method and returns an object that can be used to await the response. +// The params will be marshaled to JSON before sending over the wire, and will +// be handed to the method invoked. +// You do not have to wait for the response, it can just be ignored if not needed. +// If sending the call failed, the response will be ready and have the error in it. +func (c *Connection) Call(ctx context.Context, method string, params any) *AsyncCall { + // Generate a new request identifier. + id := Int64ID(atomic.AddInt64(&c.seq, 1)) + + ac := &AsyncCall{ + id: id, + ready: make(chan struct{}), + } + // When this method returns, either ac is retired, or the request has been + // written successfully and the call is awaiting a response (to be provided by + // the readIncoming goroutine). + + call, err := NewCall(ac.id, method, params) + if err != nil { + ac.retire(&Response{ID: id, Error: fmt.Errorf("marshaling call parameters: %w", err)}) + return ac + } + + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrClientClosing) + if err != nil { + return + } + if s.outgoingCalls == nil { + s.outgoingCalls = make(map[ID]*AsyncCall) + } + s.outgoingCalls[ac.id] = ac + }) + if err != nil { + ac.retire(&Response{ID: id, Error: err}) + return ac + } + + if err := c.write(ctx, call); err != nil { + // Sending failed. We will never get a response, so deliver a fake one if it + // wasn't already retired by the connection breaking. 
+ c.updateInFlight(func(s *inFlightState) { + if s.outgoingCalls[ac.id] == ac { + delete(s.outgoingCalls, ac.id) + ac.retire(&Response{ID: id, Error: err}) + } else { + // ac was already retired by the readIncoming goroutine: + // perhaps our write raced with the Read side of the connection breaking. + } + }) + } + return ac +} + +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + +type AsyncCall struct { + id ID + ready chan struct{} // closed after response has been set + response *Response +} + +// ID used for this call. +// This can be used to cancel the call if needed. +func (ac *AsyncCall) ID() ID { return ac.id } + +// IsReady can be used to check if the result is already prepared. +// This is guaranteed to return true on a result for which Await has already +// returned, or a call that failed to send in the first place. +func (ac *AsyncCall) IsReady() bool { + select { + case <-ac.ready: + return true + default: + return false + } +} + +// retire processes the response to the call. 
+func (ac *AsyncCall) retire(response *Response) { + select { + case <-ac.ready: + panic(fmt.Sprintf("jsonrpc2: retire called twice for ID %v", ac.id)) + default: + } + + ac.response = response + close(ac.ready) +} + +// Await waits for (and decodes) the results of a Call. +// The response will be unmarshaled from JSON into the result. +func (ac *AsyncCall) Await(ctx context.Context, result any) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ac.ready: + } + if ac.response.Error != nil { + return ac.response.Error + } + if result == nil { + return nil + } + return json.Unmarshal(ac.response.Result, result) +} + +// Cancel cancels the Context passed to the Handle call for the inbound message +// with the given ID. +// +// Cancel will not complain if the ID is not a currently active message, and it +// will not cause any messages that have not arrived yet with that ID to be +// cancelled. +func (c *Connection) Cancel(id ID) { + var req *incomingRequest + c.updateInFlight(func(s *inFlightState) { + req = s.incomingByID[id] + }) + if req != nil { + req.cancel() + } +} + +// Wait blocks until the connection is fully closed, but does not close it. +func (c *Connection) Wait() error { + return c.wait(true) +} + +// wait for the connection to close, and aggregates the most cause of its +// termination, if abnormal. +// +// The fromWait argument allows this logic to be shared with Close, where we +// only want to expose the closeErr. +// +// (Previously, Wait also only returned the closeErr, which was misleading if +// the connection was broken for another reason). 
+func (c *Connection) wait(fromWait bool) error { + var err error + <-c.done + c.updateInFlight(func(s *inFlightState) { + if fromWait { + if !errors.Is(s.readErr, io.EOF) { + err = s.readErr + } + if err == nil && !errors.Is(s.writeErr, io.EOF) { + err = s.writeErr + } + } + if err == nil { + err = s.closeErr + } + }) + return err +} + +// Close stops accepting new requests, waits for in-flight requests and enqueued +// Handle calls to complete, and then closes the underlying stream. +// +// After the start of a Close, notification requests (that lack IDs and do not +// receive responses) will continue to be passed to the Preempter, but calls +// with IDs will receive immediate responses with ErrServerClosing, and no new +// requests (not even notifications!) will be enqueued to the Handler. +func (c *Connection) Close() error { + // Stop handling new requests, and interrupt the reader (by closing the + // connection) as soon as the active requests finish. + c.updateInFlight(func(s *inFlightState) { s.connClosing = true }) + return c.wait(false) +} + +// readIncoming collects inbound messages from the reader and delivers them, either responding +// to outgoing calls or feeding requests to the queue. +func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter Preempter) { + var err error + for { + var msg Message + msg, err = reader.Read(ctx) + if err != nil { + break + } + + switch msg := msg.(type) { + case *Request: + c.acceptRequest(ctx, msg, preempter) + + case *Response: + c.updateInFlight(func(s *inFlightState) { + if ac, ok := s.outgoingCalls[msg.ID]; ok { + delete(s.outgoingCalls, msg.ID) + ac.retire(msg) + } else { + // TODO: How should we report unexpected responses? 
+ } + }) + + default: + c.internalErrorf("Read returned an unexpected message of type %T", msg) + } + } + + c.updateInFlight(func(s *inFlightState) { + s.reading = false + s.readErr = err + + // Retire any outgoing requests that were still in flight: with the Reader no + // longer being processed, they necessarily cannot receive a response. + for id, ac := range s.outgoingCalls { + ac.retire(&Response{ID: id, Error: err}) + } + s.outgoingCalls = nil + }) +} + +// acceptRequest either handles msg synchronously or enqueues it to be handled +// asynchronously. +func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter Preempter) { + // In theory notifications cannot be cancelled, but we build them a cancel + // context anyway. + reqCtx, cancel := context.WithCancel(ctx) + req := &incomingRequest{ + Request: msg, + ctx: reqCtx, + cancel: cancel, + } + + // If the request is a call, add it to the incoming map so it can be + // cancelled (or responded) by ID. + var err error + c.updateInFlight(func(s *inFlightState) { + s.incoming++ + + if req.IsCall() { + if s.incomingByID[req.ID] != nil { + err = fmt.Errorf("%w: request ID %v already in use", ErrInvalidRequest, req.ID) + req.ID = ID{} // Don't misattribute this error to the existing request. + return + } + + if s.incomingByID == nil { + s.incomingByID = make(map[ID]*incomingRequest) + } + s.incomingByID[req.ID] = req + + // When shutting down, reject all new Call requests, even if they could + // theoretically be handled by the preempter. The preempter could return + // ErrAsyncResponse, which would increase the amount of work in flight + // when we're trying to ensure that it strictly decreases. 
+ err = s.shuttingDown(ErrServerClosing) + } + }) + if err != nil { + c.processResult("acceptRequest", req, nil, err) + return + } + + if preempter != nil { + result, err := preempter.Preempt(req.ctx, req.Request) + + if !errors.Is(err, ErrNotHandled) { + c.processResult("Preempt", req, result, err) + return + } + } + + c.updateInFlight(func(s *inFlightState) { + // If the connection is shutting down, don't enqueue anything to the + // handler — not even notifications. That ensures that if the handler + // continues to make progress, it will eventually become idle and + // close the connection. + err = s.shuttingDown(ErrServerClosing) + if err != nil { + return + } + + // We enqueue requests that have not been preempted to an unbounded slice. + // Unfortunately, we cannot in general limit the size of the handler + // queue: we have to read every response that comes in on the wire + // (because it may be responding to a request issued by, say, an + // asynchronous handler), and in order to get to that response we have + // to read all of the requests that came in ahead of it. + s.handlerQueue = append(s.handlerQueue, req) + if !s.handlerRunning { + // We start the handleAsync goroutine when it has work to do, and let it + // exit when the queue empties. + // + // Otherwise, in order to synchronize the handler we would need some other + // goroutine (probably readIncoming?) to explicitly wait for handleAsync + // to finish, and that would complicate error reporting: either the error + // report from the goroutine would be blocked on the handler emptying its + // queue (which was tried, and introduced a deadlock detected by + // TestCloseCallRace), or the error would need to be reported separately + // from synchronizing completion. Allowing the handler goroutine to exit + // when idle seems simpler than trying to implement either of those + // alternatives correctly. 
+ s.handlerRunning = true + go c.handleAsync() + } + }) + if err != nil { + c.processResult("acceptRequest", req, nil, err) + } +} + +// handleAsync invokes the handler on the requests in the handler queue +// sequentially until the queue is empty. +func (c *Connection) handleAsync() { + for { + var req *incomingRequest + c.updateInFlight(func(s *inFlightState) { + if len(s.handlerQueue) > 0 { + req, s.handlerQueue = s.handlerQueue[0], s.handlerQueue[1:] + } else { + s.handlerRunning = false + } + }) + if req == nil { + return + } + + // Only deliver to the Handler if not already canceled. + if err := req.ctx.Err(); err != nil { + c.updateInFlight(func(s *inFlightState) { + if s.writeErr != nil { + // Assume that req.ctx was canceled due to s.writeErr. + // TODO(#51365): use a Context API to plumb this through req.ctx. + err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr) + } + }) + c.processResult("handleAsync", req, nil, err) + continue + } + + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch + } +} + +// processResult processes the result of a request and, if appropriate, sends a response. +func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { + switch err { + case ErrNotHandled, ErrMethodNotFound: + // Add detail describing the unhandled method. + err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) + } + + if result != nil && err != nil { + c.internalErrorf("%#v returned a non-nil result with a non-nil error for %s:\n%v\n%#v", from, req.Method, err, result) + result = nil // Discard the spurious result and respond with err. 
+ } + + if req.IsCall() { + if result == nil && err == nil { + err = c.internalErrorf("%#v returned a nil result and nil error for a %q Request that requires a Response", from, req.Method) + } + + response, respErr := NewResponse(req.ID, result, err) + + // The caller could theoretically reuse the request's ID as soon as we've + // sent the response, so ensure that it is removed from the incoming map + // before sending. + c.updateInFlight(func(s *inFlightState) { + delete(s.incomingByID, req.ID) + }) + if respErr == nil { + writeErr := c.write(notDone{req.ctx}, response) + if err == nil { + err = writeErr + } + } else { + err = c.internalErrorf("%#v returned a malformed result for %q: %w", from, req.Method, respErr) + } + } else { // req is a notification + if result != nil { + err = c.internalErrorf("%#v returned a non-nil result for a %q Request without an ID", from, req.Method) + } else if err != nil { + err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err) + } + } + if err != nil { + // TODO: can/should we do anything with this error beyond writing it to the event log? + // (Is this the right label to attach to the log?) + } + + // Cancel the request to free any associated resources. + req.cancel() + c.updateInFlight(func(s *inFlightState) { + if s.incoming == 0 { + panic("jsonrpc2: processResult called when incoming count is already zero") + } + s.incoming-- + }) + return nil +} + +// write is used by all things that write outgoing messages, including replies. +// it makes sure that writes are atomic +func (c *Connection) write(ctx context.Context, msg Message) error { + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. 
+ c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For rejected requests, we don't set the writeErr (which would break the + // connection). They can just be returned to the caller. + if errors.Is(err, ErrRejected) { + return err + } + + if err != nil && ctx.Err() == nil { + // The call to Write failed, and since ctx.Err() is nil we can't attribute + // the failure (even indirectly) to Context cancellation. The writer appears + // to be broken, and future writes are likely to also fail. + // + // If the read side of the connection is also broken, we might not even be + // able to receive cancellation notifications. Since we can't reliably write + // the results of incoming calls and can't receive explicit cancellations, + // cancel the calls now. + c.updateInFlight(func(s *inFlightState) { + if s.writeErr == nil { + s.writeErr = err + for _, r := range s.incomingByID { + r.cancel() + } + } + }) + } + + return err +} + +// internalErrorf reports an internal error. By default it panics, but if +// c.onInternalError is non-nil it instead calls that and returns an error +// wrapping ErrInternal. +func (c *Connection) internalErrorf(format string, args ...any) error { + err := fmt.Errorf(format, args...) + if c.onInternalError == nil { + panic("jsonrpc2: " + err.Error()) + } + c.onInternalError(err) + + return fmt.Errorf("%w: %v", ErrInternal, err) +} + +// notDone is a context.Context wrapper that returns a nil Done channel. 
+type notDone struct{ ctx context.Context } + +func (ic notDone) Value(key any) any { + return ic.ctx.Value(key) +} + +func (notDone) Done() <-chan struct{} { return nil } +func (notDone) Err() error { return nil } +func (notDone) Deadline() (time.Time, bool) { return time.Time{}, false } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go new file mode 100644 index 0000000000..46fcc9db99 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go @@ -0,0 +1,208 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "sync" +) + +// Reader abstracts the transport mechanics from the JSON RPC protocol. +// A Conn reads messages from the reader it was provided on construction, +// and assumes that each call to Read fully transfers a single message, +// or returns an error. +// +// A reader is not safe for concurrent use, it is expected it will be used by +// a single Conn in a safe manner. +type Reader interface { + // Read gets the next message from the stream. + Read(context.Context) (Message, error) +} + +// Writer abstracts the transport mechanics from the JSON RPC protocol. +// A Conn writes messages using the writer it was provided on construction, +// and assumes that each call to Write fully transfers a single message, +// or returns an error. +// +// A writer must be safe for concurrent use, as writes may occur concurrently +// in practice: libraries may make calls or respond to requests asynchronously. +type Writer interface { + // Write sends a message to the stream. 
+ Write(context.Context, Message) error +} + +// Framer wraps low level byte readers and writers into jsonrpc2 message +// readers and writers. +// It is responsible for the framing and encoding of messages into wire form. +// +// TODO(rfindley): rethink the framer interface, as with JSONRPC2 batching +// there is a need for Reader and Writer to be correlated, and while the +// implementation of framing here allows that, it is not made explicit by the +// interface. +// +// Perhaps a better interface would be +// +// Frame(io.ReadWriteCloser) (Reader, Writer). +type Framer interface { + // Reader wraps a byte reader into a message reader. + Reader(io.Reader) Reader + // Writer wraps a byte writer into a message writer. + Writer(io.Writer) Writer +} + +// RawFramer returns a new Framer. +// The messages are sent with no wrapping, and rely on json decode consistency +// to determine message boundaries. +func RawFramer() Framer { return rawFramer{} } + +type rawFramer struct{} +type rawReader struct{ in *json.Decoder } +type rawWriter struct { + mu sync.Mutex + out io.Writer +} + +func (rawFramer) Reader(rw io.Reader) Reader { + return &rawReader{in: json.NewDecoder(rw)} +} + +func (rawFramer) Writer(rw io.Writer) Writer { + return &rawWriter{out: rw} +} + +func (r *rawReader) Read(ctx context.Context) (Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + var raw json.RawMessage + if err := r.in.Decode(&raw); err != nil { + return nil, err + } + msg, err := DecodeMessage(raw) + return msg, err +} + +func (w *rawWriter) Write(ctx context.Context, msg Message) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + data, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + + w.mu.Lock() + defer w.mu.Unlock() + _, err = w.out.Write(data) + return err +} + +// HeaderFramer returns a new Framer. 
+// The messages are sent with HTTP content length and MIME type headers. +// This is the format used by LSP and others. +func HeaderFramer() Framer { return headerFramer{} } + +type headerFramer struct{} +type headerReader struct{ in *bufio.Reader } +type headerWriter struct { + mu sync.Mutex + out io.Writer +} + +func (headerFramer) Reader(rw io.Reader) Reader { + return &headerReader{in: bufio.NewReader(rw)} +} + +func (headerFramer) Writer(rw io.Writer) Writer { + return &headerWriter{out: rw} +} + +func (r *headerReader) Read(ctx context.Context) (Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + firstRead := true // to detect a clean EOF below + var contentLength int64 + // read the header, stop on the first empty line + for { + line, err := r.in.ReadString('\n') + if err != nil { + if err == io.EOF { + if firstRead && line == "" { + return nil, io.EOF // clean EOF + } + err = io.ErrUnexpectedEOF + } + return nil, fmt.Errorf("failed reading header line: %w", err) + } + firstRead = false + + line = strings.TrimSpace(line) + // check we have a header line + if line == "" { + break + } + colon := strings.IndexRune(line, ':') + if colon < 0 { + return nil, fmt.Errorf("invalid header line %q", line) + } + name, value := line[:colon], strings.TrimSpace(line[colon+1:]) + switch name { + case "Content-Length": + if contentLength, err = strconv.ParseInt(value, 10, 32); err != nil { + return nil, fmt.Errorf("failed parsing Content-Length: %v", value) + } + if contentLength <= 0 { + return nil, fmt.Errorf("invalid Content-Length: %v", contentLength) + } + default: + // ignoring unknown headers + } + } + if contentLength == 0 { + return nil, fmt.Errorf("missing Content-Length header") + } + data := make([]byte, contentLength) + _, err := io.ReadFull(r.in, data) + if err != nil { + return nil, err + } + msg, err := DecodeMessage(data) + return msg, err +} + +func (w *headerWriter) Write(ctx context.Context, msg Message) error { + 
select { + case <-ctx.Done(): + return ctx.Err() + default: + } + w.mu.Lock() + defer w.mu.Unlock() + + data, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + _, err = fmt.Fprintf(w.out, "Content-Length: %v\r\n\r\n", len(data)) + if err == nil { + _, err = w.out.Write(data) + } + return err +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go new file mode 100644 index 0000000000..234e6ee3a1 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go @@ -0,0 +1,121 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec. +// https://www.jsonrpc.org/specification +// It is intended to be compatible with other implementations at the wire level. +package jsonrpc2 + +import ( + "context" + "errors" +) + +var ( + // ErrIdleTimeout is returned when serving timed out waiting for new connections. + ErrIdleTimeout = errors.New("timed out waiting for new connections") + + // ErrNotHandled is returned from a Handler or Preempter to indicate it did + // not handle the request. + // + // If a Handler returns ErrNotHandled, the server replies with + // ErrMethodNotFound. + ErrNotHandled = errors.New("JSON RPC not handled") +) + +// Preempter handles messages on a connection before they are queued to the main +// handler. +// Primarily this is used for cancel handlers or notifications for which out of +// order processing is not an issue. +type Preempter interface { + // Preempt is invoked for each incoming request before it is queued for handling. + // + // If Preempt returns ErrNotHandled, the request will be queued, + // and eventually passed to a Handle call. 
+ // + // Otherwise, the result and error are processed as if returned by Handle. + // + // Preempt must not block. (The Context passed to it is for Values only.) + Preempt(ctx context.Context, req *Request) (result any, err error) +} + +// A PreempterFunc implements the Preempter interface for a standalone Preempt function. +type PreempterFunc func(ctx context.Context, req *Request) (any, error) + +func (f PreempterFunc) Preempt(ctx context.Context, req *Request) (any, error) { + return f(ctx, req) +} + +var _ Preempter = PreempterFunc(nil) + +// Handler handles messages on a connection. +type Handler interface { + // Handle is invoked sequentially for each incoming request that has not + // already been handled by a Preempter. + // + // If the Request has a nil ID, Handle must return a nil result, + // and any error may be logged but will not be reported to the caller. + // + // If the Request has a non-nil ID, Handle must return either a + // non-nil, JSON-marshalable result, or a non-nil error. + // + // The Context passed to Handle will be canceled if the + // connection is broken or the request is canceled or completed. + // (If Handle returns ErrAsyncResponse, ctx will remain uncanceled + // until either Cancel or Respond is called for the request's ID.) + Handle(ctx context.Context, req *Request) (result any, err error) +} + +type defaultHandler struct{} + +func (defaultHandler) Preempt(context.Context, *Request) (any, error) { + return nil, ErrNotHandled +} + +func (defaultHandler) Handle(context.Context, *Request) (any, error) { + return nil, ErrNotHandled +} + +// A HandlerFunc implements the Handler interface for a standalone Handle function. +type HandlerFunc func(ctx context.Context, req *Request) (any, error) + +func (f HandlerFunc) Handle(ctx context.Context, req *Request) (any, error) { + return f(ctx, req) +} + +var _ Handler = HandlerFunc(nil) + +// async is a small helper for operations with an asynchronous result that you +// can wait for. 
+type async struct { + ready chan struct{} // closed when done + firstErr chan error // 1-buffered; contains either nil or the first non-nil error +} + +func newAsync() *async { + var a async + a.ready = make(chan struct{}) + a.firstErr = make(chan error, 1) + a.firstErr <- nil + return &a +} + +func (a *async) done() { + close(a.ready) +} + +func (a *async) wait() error { + <-a.ready + err := <-a.firstErr + a.firstErr <- err + return err +} + +func (a *async) setError(err error) { + storedErr := <-a.firstErr + if storedErr == nil { + storedErr = err + } + a.firstErr <- storedErr +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go new file mode 100644 index 0000000000..791e698d96 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go @@ -0,0 +1,212 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "encoding/json" + "errors" + "fmt" +) + +// ID is a Request identifier, which is defined by the spec to be a string, integer, or null. +// https://www.jsonrpc.org/specification#request_object +type ID struct { + value any +} + +// MakeID coerces the given Go value to an ID. The value should be the +// default JSON marshaling of a Request identifier: nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +// +// TODO: ID can't be a json.Marshaler/Unmarshaler, because we want to omitzero. +// Simplify this package by making ID json serializable once we can rely on +// omitzero. 
+func MakeID(v any) (ID, error) { + switch v := v.(type) { + case nil: + return ID{}, nil + case float64: + return Int64ID(int64(v)), nil + case string: + return StringID(v), nil + } + return ID{}, fmt.Errorf("%w: invalid ID type %T", ErrParse, v) +} + +// Message is the interface to all jsonrpc2 message types. +// They share no common functionality, but are a closed set of concrete types +// that are allowed to implement this interface. The message types are *Request +// and *Response. +type Message interface { + // marshal builds the wire form from the API form. + // It is private, which makes the set of Message implementations closed. + marshal(to *wireCombined) +} + +// Request is a Message sent to a peer to request behavior. +// If it has an ID it is a call, otherwise it is a notification. +type Request struct { + // ID of this request, used to tie the Response back to the request. + // This will be nil for notifications. + ID ID + // Method is a string containing the method name to invoke. + Method string + // Params is either a struct or an array with the parameters of the method. + Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any +} + +// Response is a Message used as a reply to a call Request. +// It will have the same ID as the call it is a response to. +type Response struct { + // result is the content of the response. + Result json.RawMessage + // err is set only if the call failed. + Error error + // id of the request this is a response to. + ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any +} + +// StringID creates a new string request identifier. +func StringID(s string) ID { return ID{value: s} } + +// Int64ID creates a new integer request identifier. 
+func Int64ID(i int64) ID { return ID{value: i} } + +// IsValid returns true if the ID is a valid identifier. +// The default value for ID will return false. +func (id ID) IsValid() bool { return id.value != nil } + +// Raw returns the underlying value of the ID. +func (id ID) Raw() any { return id.value } + +// NewNotification constructs a new Notification message for the supplied +// method and parameters. +func NewNotification(method string, params any) (*Request, error) { + p, merr := marshalToRaw(params) + return &Request{Method: method, Params: p}, merr +} + +// NewCall constructs a new Call message for the supplied ID, method and +// parameters. +func NewCall(id ID, method string, params any) (*Request, error) { + p, merr := marshalToRaw(params) + return &Request{ID: id, Method: method, Params: p}, merr +} + +func (msg *Request) IsCall() bool { return msg.ID.IsValid() } + +func (msg *Request) marshal(to *wireCombined) { + to.ID = msg.ID.value + to.Method = msg.Method + to.Params = msg.Params +} + +// NewResponse constructs a new Response message that is a reply to the +// supplied. If err is set result may be ignored. 
+func NewResponse(id ID, result any, rerr error) (*Response, error) { + r, merr := marshalToRaw(result) + return &Response{ID: id, Result: r, Error: rerr}, merr +} + +func (msg *Response) marshal(to *wireCombined) { + to.ID = msg.ID.value + to.Error = toWireError(msg.Error) + to.Result = msg.Result +} + +func toWireError(err error) *WireError { + if err == nil { + // no error, the response is complete + return nil + } + if err, ok := err.(*WireError); ok { + // already a wire error, just use it + return err + } + result := &WireError{Message: err.Error()} + var wrapped *WireError + if errors.As(err, &wrapped) { + // if we wrapped a wire error, keep the code from the wrapped error + // but the message from the outer error + result.Code = wrapped.Code + } + return result +} + +func EncodeMessage(msg Message) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.Marshal(&wire) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + +// EncodeIndent is like EncodeMessage, but honors indents. +// TODO(rfindley): refactor so that this concern is handled independently. +// Perhaps we should pass in a json.Encoder? 
+func EncodeIndent(msg Message, prefix, indent string) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.MarshalIndent(&wire, prefix, indent) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + +func DecodeMessage(data []byte) (Message, error) { + msg := wireCombined{} + if err := json.Unmarshal(data, &msg); err != nil { + return nil, fmt.Errorf("unmarshaling jsonrpc message: %w", err) + } + if msg.VersionTag != wireVersion { + return nil, fmt.Errorf("invalid message version tag %q; expected %q", msg.VersionTag, wireVersion) + } + id, err := MakeID(msg.ID) + if err != nil { + return nil, err + } + if msg.Method != "" { + // has a method, must be a call + return &Request{ + Method: msg.Method, + ID: id, + Params: msg.Params, + }, nil + } + // no method, should be a response + if !id.IsValid() { + return nil, ErrInvalidRequest + } + resp := &Response{ + ID: id, + Result: msg.Result, + } + // we have to check if msg.Error is nil to avoid a typed error + if msg.Error != nil { + resp.Error = msg.Error + } + return resp, nil +} + +func marshalToRaw(obj any) (json.RawMessage, error) { + if obj == nil { + return nil, nil + } + data, err := json.Marshal(obj) + if err != nil { + return nil, err + } + return json.RawMessage(data), nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go new file mode 100644 index 0000000000..05db062618 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go @@ -0,0 +1,138 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package jsonrpc2 + +import ( + "context" + "io" + "net" + "os" +) + +// This file contains implementations of the transport primitives that use the standard network +// package. + +// NetListenOptions is the optional arguments to the NetListen function. +type NetListenOptions struct { + NetListenConfig net.ListenConfig + NetDialer net.Dialer +} + +// NetListener returns a new Listener that listens on a socket using the net package. +func NetListener(ctx context.Context, network, address string, options NetListenOptions) (Listener, error) { + ln, err := options.NetListenConfig.Listen(ctx, network, address) + if err != nil { + return nil, err + } + return &netListener{net: ln}, nil +} + +// netListener is the implementation of Listener for connections made using the net package. +type netListener struct { + net net.Listener +} + +// Accept blocks waiting for an incoming connection to the listener. +func (l *netListener) Accept(context.Context) (io.ReadWriteCloser, error) { + return l.net.Accept() +} + +// Close will cause the listener to stop listening. It will not close any connections that have +// already been accepted. +func (l *netListener) Close() error { + addr := l.net.Addr() + err := l.net.Close() + if addr.Network() == "unix" { + rerr := os.Remove(addr.String()) + if rerr != nil && err == nil { + err = rerr + } + } + return err +} + +// Dialer returns a dialer that can be used to connect to the listener. +func (l *netListener) Dialer() Dialer { + return NetDialer(l.net.Addr().Network(), l.net.Addr().String(), net.Dialer{}) +} + +// NetDialer returns a Dialer using the supplied standard network dialer. 
+func NetDialer(network, address string, nd net.Dialer) Dialer {
+	return &netDialer{
+		network: network,
+		address: address,
+		dialer:  nd,
+	}
+}
+
+type netDialer struct {
+	network string
+	address string
+	dialer  net.Dialer
+}
+
+func (n *netDialer) Dial(ctx context.Context) (io.ReadWriteCloser, error) {
+	return n.dialer.DialContext(ctx, n.network, n.address)
+}
+
+// NetPipeListener returns a new Listener that listens using net.Pipe.
+// It is only possible to connect to it using the Dialer returned by the
+// Dialer method, each call to that method will generate a new pipe the other
+// side of which will be returned from the Accept call.
+func NetPipeListener(ctx context.Context) (Listener, error) {
+	return &netPiper{
+		done:   make(chan struct{}),
+		dialed: make(chan io.ReadWriteCloser),
+	}, nil
+}
+
+// netPiper is the implementation of Listener built on top of net.Pipes.
+type netPiper struct {
+	done   chan struct{}
+	dialed chan io.ReadWriteCloser
+}
+
+// Accept blocks waiting for an incoming connection to the listener.
+func (l *netPiper) Accept(context.Context) (io.ReadWriteCloser, error) {
+	// Block until the pipe is dialed or the listener is closed,
+	// preferring the latter if already closed at the start of Accept.
+	select {
+	case <-l.done:
+		return nil, net.ErrClosed
+	default:
+	}
+	select {
+	case rwc := <-l.dialed:
+		return rwc, nil
+	case <-l.done:
+		return nil, net.ErrClosed
+	}
+}
+
+// Close will cause the listener to stop listening. It will not close any connections that have
+// already been accepted.
+func (l *netPiper) Close() error { + // unblock any accept calls that are pending + close(l.done) + return nil +} + +func (l *netPiper) Dialer() Dialer { + return l +} + +func (l *netPiper) Dial(ctx context.Context) (io.ReadWriteCloser, error) { + client, server := net.Pipe() + + select { + case l.dialed <- server: + return client, nil + + case <-l.done: + client.Close() + server.Close() + return nil, net.ErrClosed + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go new file mode 100644 index 0000000000..424163aaf6 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go @@ -0,0 +1,330 @@ +// Copyright 2020 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "context" + "fmt" + "io" + "runtime" + "sync" + "sync/atomic" + "time" +) + +// Listener is implemented by protocols to accept new inbound connections. +type Listener interface { + // Accept accepts an inbound connection to a server. + // It blocks until either an inbound connection is made, or the listener is closed. + Accept(context.Context) (io.ReadWriteCloser, error) + + // Close closes the listener. + // Any blocked Accept or Dial operations will unblock and return errors. + Close() error + + // Dialer returns a dialer that can be used to connect to this listener + // locally. + // If a listener does not implement this it will return nil. + Dialer() Dialer +} + +// Dialer is used by clients to dial a server. +type Dialer interface { + // Dial returns a new communication byte stream to a listening server. + Dial(ctx context.Context) (io.ReadWriteCloser, error) +} + +// Server is a running server that is accepting incoming connections. 
+type Server struct {
+	listener Listener
+	binder   Binder
+	async    *async
+
+	shutdownOnce sync.Once
+	closing      int32 // atomic: set to nonzero when Shutdown is called
+}
+
+// Dial uses the dialer to make a new connection, wraps the returned
+// reader and writer using the framer to make a stream, and then builds
+// a connection on top of that stream using the binder.
+//
+// The returned Connection will operate independently using the Preempter and/or
+// Handler provided by the Binder, and will release its own resources when the
+// connection is broken, but the caller may Close it earlier to stop accepting
+// (or sending) new requests.
+//
+// If non-nil, the onDone function is called when the connection is closed.
+func Dial(ctx context.Context, dialer Dialer, binder Binder, onDone func()) (*Connection, error) {
+	// dial a server
+	rwc, err := dialer.Dial(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return bindConnection(ctx, rwc, binder, onDone), nil
+}
+
+// NewServer starts a new server listening for incoming connections and returns
+// it.
+// This returns a fully running and connected server, it does not block on
+// the listener.
+// You can call Wait to block on the server, or Shutdown to get the server to
+// terminate gracefully.
+// To notice incoming connections, use an intercepting Binder.
+func NewServer(ctx context.Context, listener Listener, binder Binder) *Server {
+	server := &Server{
+		listener: listener,
+		binder:   binder,
+		async:    newAsync(),
+	}
+	go server.run(ctx)
+	return server
+}
+
+// Wait returns only when the server has shut down.
+func (s *Server) Wait() error {
+	return s.async.wait()
+}
+
+// Shutdown informs the server to stop accepting new connections.
+func (s *Server) Shutdown() { + s.shutdownOnce.Do(func() { + atomic.StoreInt32(&s.closing, 1) + s.listener.Close() + }) +} + +// run accepts incoming connections from the listener, +// If IdleTimeout is non-zero, run exits after there are no clients for this +// duration, otherwise it exits only on error. +func (s *Server) run(ctx context.Context) { + defer s.async.done() + + var activeConns sync.WaitGroup + for { + rwc, err := s.listener.Accept(ctx) + if err != nil { + // Only Shutdown closes the listener. If we get an error after Shutdown is + // called, assume that was the cause and don't report the error; + // otherwise, report the error in case it is unexpected. + if atomic.LoadInt32(&s.closing) == 0 { + s.async.setError(err) + } + // We are done generating new connections for good. + break + } + + // A new inbound connection. + activeConns.Add(1) + _ = bindConnection(ctx, rwc, s.binder, activeConns.Done) // unregisters itself when done + } + activeConns.Wait() +} + +// NewIdleListener wraps a listener with an idle timeout. +// +// When there are no active connections for at least the timeout duration, +// calls to Accept will fail with ErrIdleTimeout. +// +// A connection is considered inactive as soon as its Close method is called. +func NewIdleListener(timeout time.Duration, wrap Listener) Listener { + l := &idleListener{ + wrapped: wrap, + timeout: timeout, + active: make(chan int, 1), + timedOut: make(chan struct{}), + idleTimer: make(chan *time.Timer, 1), + } + l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired) + return l +} + +type idleListener struct { + wrapped Listener + timeout time.Duration + + // Only one of these channels is receivable at any given time. + active chan int // count of active connections; closed when Close is called if not timed out + timedOut chan struct{} // closed when the idle timer expires + idleTimer chan *time.Timer // holds the timer only when idle +} + +// Accept accepts an incoming connection. 
+//
+// If an incoming connection is accepted concurrent to the listener being closed
+// due to idleness, the new connection is immediately closed.
+func (l *idleListener) Accept(ctx context.Context) (io.ReadWriteCloser, error) {
+	rwc, err := l.wrapped.Accept(ctx)
+
+	select {
+	case n, ok := <-l.active:
+		if err != nil {
+			if ok {
+				l.active <- n
+			}
+			return nil, err
+		}
+		if ok {
+			l.active <- n + 1
+		} else {
+			// l.wrapped.Close has been called, but Accept returned a
+			// connection. This race can occur with concurrent Accept and Close calls
+			// with any net.Listener, and it is benign: since the listener was closed
+			// explicitly, it can't have also timed out.
+		}
+		return l.newConn(rwc), nil
+
+	case <-l.timedOut:
+		if err == nil {
+			// Keeping the connection open would leave the listener simultaneously
+			// active and closed due to idleness, which would be contradictory and
+			// confusing. Close the connection and pretend that it never happened.
+			rwc.Close()
+		} else {
+			// In theory the timeout could have raced with an unrelated error return
+			// from Accept. However, ErrIdleTimeout is arguably still valid (since we
+			// would have closed due to the timeout independent of the error), and the
+			// harm from returning a spurious ErrIdleTimeout is negligible anyway.
+		}
+		return nil, ErrIdleTimeout
+
+	case timer := <-l.idleTimer:
+		if err != nil {
+			// The idle timer doesn't run until it receives itself from the idleTimer
+			// channel, so it can't have called l.wrapped.Close yet and thus err can't
+			// be ErrIdleTimeout. Leave the idle timer as it was and return whatever
+			// error we got.
+			l.idleTimer <- timer
+			return nil, err
+		}
+
+		if !timer.Stop() {
+			// Failed to stop the timer — the timer goroutine is in the process of
+			// firing. Send the timer back to the timer goroutine so that it can
+			// safely close the timedOut channel, and then wait for the listener to
+			// actually be closed before we return ErrIdleTimeout.
+ l.idleTimer <- timer + rwc.Close() + <-l.timedOut + return nil, ErrIdleTimeout + } + + l.active <- 1 + return l.newConn(rwc), nil + } +} + +func (l *idleListener) Close() error { + select { + case _, ok := <-l.active: + if ok { + close(l.active) + } + + case <-l.timedOut: + // Already closed by the timer; take care not to double-close if the caller + // only explicitly invokes this Close method once, since the io.Closer + // interface explicitly leaves doubled Close calls undefined. + return ErrIdleTimeout + + case timer := <-l.idleTimer: + if !timer.Stop() { + // Couldn't stop the timer. It shouldn't take long to run, so just wait + // (so that the Listener is guaranteed to be closed before we return) + // and pretend that this call happened afterward. + // That way we won't leak any timers or goroutines when Close returns. + l.idleTimer <- timer + <-l.timedOut + return ErrIdleTimeout + } + close(l.active) + } + + return l.wrapped.Close() +} + +func (l *idleListener) Dialer() Dialer { + return l.wrapped.Dialer() +} + +func (l *idleListener) timerExpired() { + select { + case n, ok := <-l.active: + if ok { + panic(fmt.Sprintf("jsonrpc2: idleListener idle timer fired with %d connections still active", n)) + } else { + panic("jsonrpc2: Close finished with idle timer still running") + } + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired more than once") + + case <-l.idleTimer: + // The timer for this very call! + } + + // Close the Listener with all channels still blocked to ensure that this call + // to l.wrapped.Close doesn't race with the one in l.Close. + defer close(l.timedOut) + l.wrapped.Close() +} + +func (l *idleListener) connClosed() { + select { + case n, ok := <-l.active: + if !ok { + // l is already closed, so it can't close due to idleness, + // and we don't need to track the number of active connections any more. 
+ return + } + n-- + if n == 0 { + l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired) + } else { + l.active <- n + } + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired before last active connection was closed") + + case <-l.idleTimer: + panic("jsonrpc2: idleListener idle timer active before last active connection was closed") + } +} + +type idleListenerConn struct { + wrapped io.ReadWriteCloser + l *idleListener + closeOnce sync.Once +} + +func (l *idleListener) newConn(rwc io.ReadWriteCloser) *idleListenerConn { + c := &idleListenerConn{ + wrapped: rwc, + l: l, + } + + // A caller that forgets to call Close may disrupt the idleListener's + // accounting, even though the file descriptor for the underlying connection + // may eventually be garbage-collected anyway. + // + // Set a (best-effort) finalizer to verify that a Close call always occurs. + // (We will clear the finalizer explicitly in Close.) + runtime.SetFinalizer(c, func(c *idleListenerConn) { + panic("jsonrpc2: IdleListener connection became unreachable without a call to Close") + }) + + return c +} + +func (c *idleListenerConn) Read(p []byte) (int, error) { return c.wrapped.Read(p) } +func (c *idleListenerConn) Write(p []byte) (int, error) { return c.wrapped.Write(p) } + +func (c *idleListenerConn) Close() error { + defer c.closeOnce.Do(func() { + c.l.connClosed() + runtime.SetFinalizer(c, nil) + }) + return c.wrapped.Close() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go new file mode 100644 index 0000000000..8be2872e43 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go @@ -0,0 +1,97 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package jsonrpc2 + +import ( + "encoding/json" +) + +// This file contains the go forms of the wire specification. +// see http://www.jsonrpc.org/specification for details + +var ( + // ErrParse is used when invalid JSON was received by the server. + ErrParse = NewError(-32700, "parse error") + // ErrInvalidRequest is used when the JSON sent is not a valid Request object. + ErrInvalidRequest = NewError(-32600, "invalid request") + // ErrMethodNotFound should be returned by the handler when the method does + // not exist / is not available. + ErrMethodNotFound = NewError(-32601, "method not found") + // ErrInvalidParams should be returned by the handler when method + // parameter(s) were invalid. + ErrInvalidParams = NewError(-32602, "invalid params") + // ErrInternal indicates a failure to process a call correctly + ErrInternal = NewError(-32603, "internal error") + + // The following errors are not part of the json specification, but + // compliant extensions specific to this implementation. + + // ErrServerOverloaded is returned when a message was refused due to a + // server being temporarily unable to accept any new messages. + ErrServerOverloaded = NewError(-32000, "overloaded") + // ErrUnknown should be used for all non coded errors. + ErrUnknown = NewError(-32001, "unknown error") + // ErrServerClosing is returned for calls that arrive while the server is closing. + ErrServerClosing = NewError(-32004, "server is closing") + // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. + ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. 
+ // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32004, "rejected by transport") +) + +const wireVersion = "2.0" + +// wireCombined has all the fields of both Request and Response. +// We can decode this and then work out which it is. +type wireCombined struct { + VersionTag string `json:"jsonrpc"` + ID any `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *WireError `json:"error,omitempty"` +} + +// WireError represents a structured error in a Response. +type WireError struct { + // Code is an error code indicating the type of failure. + Code int64 `json:"code"` + // Message is a short description of the error. + Message string `json:"message"` + // Data is optional structured data containing additional information about the error. + Data json.RawMessage `json:"data,omitempty"` +} + +// NewError returns an error that will encode on the wire correctly. +// The standard codes are made available from this package, this function should +// only be used to build errors for application specific codes as allowed by the +// specification. 
+func NewError(code int64, message string) error { + return &WireError{ + Code: code, + Message: message, + } +} + +func (err *WireError) Error() string { + return err.Message +} + +func (err *WireError) Is(other error) bool { + w, ok := other.(*WireError) + if !ok { + return false + } + return err.Code == w.Code +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go new file mode 100644 index 0000000000..4b5c325fa9 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go @@ -0,0 +1,44 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package util + +import ( + "cmp" + "fmt" + "iter" + "slices" +) + +// Helpers below are copied from gopls' moremaps package. + +// Sorted returns an iterator over the entries of m in key order. +func Sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] { + // TODO(adonovan): use maps.Sorted if proposal #68598 is accepted. + return func(yield func(K, V) bool) { + keys := KeySlice(m) + slices.Sort(keys) + for _, k := range keys { + if !yield(k, m[k]) { + break + } + } + } +} + +// KeySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)). +func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +// Wrapf wraps *errp with the given formatted message if *errp is not nil. 
+func Wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go new file mode 100644 index 0000000000..849060d57e --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go @@ -0,0 +1,23 @@ +// Copyright 2019 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package xcontext is a package to offer the extra functionality we need +// from contexts that is not available from the standard context package. +package xcontext + +import ( + "context" + "time" +) + +// Detach returns a context that keeps all the values of its parent context +// but detaches from the cancellation and error handling. +func Detach(ctx context.Context) context.Context { return detachedContext{ctx} } + +type detachedContext struct{ parent context.Context } + +func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false } +func (v detachedContext) Done() <-chan struct{} { return nil } +func (v detachedContext) Err() error { return nil } +func (v detachedContext) Value(key any) any { return v.parent.Value(key) } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go b/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go new file mode 100644 index 0000000000..1633d4e3c7 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go @@ -0,0 +1,39 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc exposes part of a JSON-RPC v2 implementation +// for use by mcp transport authors. 
+package jsonrpc + +import "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + +type ( + // ID is a JSON-RPC request ID. + ID = jsonrpc2.ID + // Message is a JSON-RPC message. + Message = jsonrpc2.Message + // Request is a JSON-RPC request. + Request = jsonrpc2.Request + // Response is a JSON-RPC response. + Response = jsonrpc2.Response +) + +// MakeID coerces the given Go value to an ID. The value should be the +// default JSON marshaling of a Request identifier: nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +func MakeID(v any) (ID, error) { + return jsonrpc2.MakeID(v) +} + +// EncodeMessage serializes a JSON-RPC message to its wire format. +func EncodeMessage(msg Message) ([]byte, error) { + return jsonrpc2.EncodeMessage(msg) +} + +// DecodeMessage deserializes JSON-RPC wire format data into a Message. +// It returns either a Request or Response based on the message content. +func DecodeMessage(data []byte) (Message, error) { + return jsonrpc2.DecodeMessage(data) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go new file mode 100644 index 0000000000..d7e3ae5a68 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go @@ -0,0 +1,762 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "iter" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// A Client is an MCP client, which may be connected to an MCP server +// using the [Client.Connect] method. 
+type Client struct {
+	impl                    *Implementation
+	opts                    ClientOptions
+	mu                      sync.Mutex
+	roots                   *featureSet[*Root]
+	sessions                []*ClientSession
+	sendingMethodHandler_   MethodHandler
+	receivingMethodHandler_ MethodHandler
+}
+
+// NewClient creates a new [Client].
+//
+// Use [Client.Connect] to connect it to an MCP server.
+//
+// The first argument must not be nil.
+//
+// If non-nil, the provided options configure the Client.
+func NewClient(impl *Implementation, opts *ClientOptions) *Client {
+	if impl == nil {
+		panic("nil Implementation")
+	}
+	c := &Client{
+		impl:                    impl,
+		roots:                   newFeatureSet(func(r *Root) string { return r.URI }),
+		sendingMethodHandler_:   defaultSendingMethodHandler[*ClientSession],
+		receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession],
+	}
+	if opts != nil {
+		c.opts = *opts
+	}
+	return c
+}
+
+// ClientOptions configures the behavior of the client.
+type ClientOptions struct {
+	// CreateMessageHandler handles incoming requests for sampling/createMessage.
+	//
+	// Setting CreateMessageHandler to a non-nil value causes the client to
+	// advertise the sampling capability.
+	CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error)
+	// ElicitationHandler handles incoming requests for elicitation/create.
+	//
+	// Setting ElicitationHandler to a non-nil value causes the client to
+	// advertise the elicitation capability.
+	ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error)
+	// Handlers for notifications from the server.
+	ToolListChangedHandler      func(context.Context, *ToolListChangedRequest)
+	PromptListChangedHandler    func(context.Context, *PromptListChangedRequest)
+	ResourceListChangedHandler  func(context.Context, *ResourceListChangedRequest)
+	ResourceUpdatedHandler      func(context.Context, *ResourceUpdatedNotificationRequest)
+	LoggingMessageHandler       func(context.Context, *LoggingMessageRequest)
+	ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest)
+	// If non-zero, defines an interval for regular "ping" requests.
+	// If the peer fails to respond to pings originating from the keepalive check,
+	// the session is automatically closed.
+	KeepAlive time.Duration
+}
+
+// bind implements the binder[*ClientSession] interface, so that Clients can
+// be connected using [connect].
+func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
+	assert(mcpConn != nil && conn != nil, "nil connection")
+	cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose}
+	if state != nil {
+		cs.state = *state
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.sessions = append(c.sessions, cs)
+	return cs
+}
+
+// disconnect implements the binder[*Client] interface, so that
+// Clients can be connected using [connect].
+func (c *Client) disconnect(cs *ClientSession) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.sessions = slices.DeleteFunc(c.sessions, func(cs2 *ClientSession) bool {
+		return cs2 == cs
+	})
+}
+
+// TODO: Consider exporting this type and its field.
+type unsupportedProtocolVersionError struct {
+	version string
+}
+
+func (e unsupportedProtocolVersionError) Error() string {
+	return fmt.Sprintf("unsupported protocol version: %q", e.version)
+}
+
+// ClientSessionOptions is reserved for future use.
+type ClientSessionOptions struct{}
+
+func (c *Client) capabilities() *ClientCapabilities {
+	caps := &ClientCapabilities{}
+	caps.Roots.ListChanged = true
+	if c.opts.CreateMessageHandler != nil {
+		caps.Sampling = &SamplingCapabilities{}
+	}
+	if c.opts.ElicitationHandler != nil {
+		caps.Elicitation = &ElicitationCapabilities{}
+	}
+	return caps
+}
+
+// Connect begins an MCP session by connecting to a server over the given
+// transport. The resulting session is initialized, and ready to use.
+//
+// Typically, it is the responsibility of the client to close the connection
+// when it is no longer needed. However, if the connection is closed by the
+// server, calls or notifications will return an error wrapping
+// [ErrConnectionClosed].
+func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
+	cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
+	if err != nil {
+		return nil, err
+	}
+
+	params := &InitializeParams{
+		ProtocolVersion: latestProtocolVersion,
+		ClientInfo:      c.impl,
+		Capabilities:    c.capabilities(),
+	}
+	req := &InitializeRequest{Session: cs, Params: params}
+	res, err := handleSend[*InitializeResult](ctx, methodInitialize, req)
+	if err != nil {
+		_ = cs.Close()
+		return nil, err
+	}
+	if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) {
+		return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
+	}
+	cs.state.InitializeResult = res
+	if hc, ok := cs.mcpConn.(clientConnection); ok {
+		hc.sessionUpdated(cs.state)
+	}
+	req2 := &initializedClientRequest{Session: cs, Params: &InitializedParams{}}
+	if err := handleNotify(ctx, notificationInitialized, req2); err != nil {
+		_ = cs.Close()
+		return nil, err
+	}
+
+	if c.opts.KeepAlive > 0 {
+		cs.startKeepalive(c.opts.KeepAlive)
+	}
+
+	return cs, nil
+}
+
+// A ClientSession is a logical connection with an MCP server. Its
+// methods can be used to send requests or notifications to the server.
+// Create a session by calling [Client.Connect].
+//
+// Call [ClientSession.Close] to close the connection, or await server
+// termination with [ClientSession.Wait].
+type ClientSession struct {
+	// Ensure that onClose is called at most once.
+	// We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the
+	// onClose callback triggers a re-entrant call to Close.
+	calledOnClose atomic.Bool
+	onClose       func()
+
+	conn            *jsonrpc2.Connection
+	client          *Client
+	keepaliveCancel context.CancelFunc
+	mcpConn         Connection
+
+	// No mutex is (currently) required to guard the session state, because it is
+	// only set synchronously during Client.Connect.
+	state clientSessionState
+}
+
+type clientSessionState struct {
+	InitializeResult *InitializeResult
+}
+
+func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult }
+
+func (cs *ClientSession) ID() string {
+	if c, ok := cs.mcpConn.(hasSessionID); ok {
+		return c.SessionID()
+	}
+	return ""
+}
+
+// Close performs a graceful close of the connection, preventing new requests
+// from being handled, and waiting for ongoing requests to return. Close then
+// terminates the connection.
+//
+// Close is idempotent and concurrency safe.
+func (cs *ClientSession) Close() error {
+	// Note: keepaliveCancel access is safe without a mutex because:
+	// 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls)
+	// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
+	// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
+	//    Close is idempotent and conn.Close() handles concurrent calls correctly
+	if cs.keepaliveCancel != nil {
+		cs.keepaliveCancel()
+	}
+	err := cs.conn.Close()
+
+	if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) {
+		cs.onClose()
+	}
+
+	return err
+}
+
+// Wait waits for the connection to be closed by the server.
+// Generally, clients should be responsible for closing the connection.
+func (cs *ClientSession) Wait() error {
+	return cs.conn.Wait()
+}
+
+// startKeepalive starts the keepalive mechanism for this client session.
+func (cs *ClientSession) startKeepalive(interval time.Duration) {
+	startKeepalive(cs, interval, &cs.keepaliveCancel)
+}
+
+// AddRoots adds the given roots to the client,
+// replacing any with the same URIs,
+// and notifies any connected servers.
+func (c *Client) AddRoots(roots ...*Root) {
+	// Only notify if something could change.
+	if len(roots) == 0 {
+		return
+	}
+	changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{},
+		func() bool { c.roots.add(roots...); return true })
+}
+
+// RemoveRoots removes the roots with the given URIs,
+// and notifies any connected servers if the list has changed.
+// It is not an error to remove a nonexistent root.
+func (c *Client) RemoveRoots(uris ...string) {
+	changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{},
+		func() bool { return c.roots.remove(uris...) })
+}
+
+// changeAndNotify is called when a feature is added or removed.
+// It calls change, which should do the work and report whether a change actually occurred.
+// If there was a change, it notifies a snapshot of the sessions.
+func changeAndNotify[P Params](c *Client, notification string, params P, change func() bool) {
+	var sessions []*ClientSession
+	// Lock for the change, but not for the notification.
+	c.mu.Lock()
+	if change() {
+		sessions = slices.Clone(c.sessions)
+	}
+	c.mu.Unlock()
+	notifySessions(sessions, notification, params)
+}
+
+func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	roots := slices.Collect(c.roots.all())
+	if roots == nil {
+		roots = []*Root{} // avoid JSON null
+	}
+	return &ListRootsResult{
+		Roots: roots,
+	}, nil
+}
+
+func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
+	if c.opts.CreateMessageHandler == nil {
+		// TODO: wrap or annotate this error? Pick a standard code?
+		return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support CreateMessage")
+	}
+	return c.opts.CreateMessageHandler(ctx, req)
+}
+
+func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) {
+	if c.opts.ElicitationHandler == nil {
+		// TODO: wrap or annotate this error? Pick a standard code?
+		return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support elicitation")
+	}
+
+	// Validate that the requested schema only contains top-level properties without nesting
+	schema, err := validateElicitSchema(req.Params.RequestedSchema)
+	if err != nil {
+		return nil, jsonrpc2.NewError(codeInvalidParams, err.Error())
+	}
+
+	res, err := c.opts.ElicitationHandler(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	// Validate elicitation result content against requested schema
+	if schema != nil && res.Content != nil {
+		// TODO: is this the correct behavior if validation fails?
+		// It isn't the *server's* params that are invalid, so why would we return
+		// this code to the server?
+		resolved, err := schema.Resolve(nil)
+		if err != nil {
+			return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
+		}
+
+		if err := resolved.Validate(res.Content); err != nil {
+			return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
+		}
+	}
+
+	return res, nil
+}
+
+// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements.
+// Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only.
+func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) {
+	if wireSchema == nil {
+		return nil, nil // nil schema is allowed
+	}
+
+	var schema *jsonschema.Schema
+	if err := remarshal(wireSchema, &schema); err != nil {
+		return nil, err
+	}
+
+	// The root schema must be of type "object" if specified
+	if schema.Type != "" && schema.Type != "object" {
+		return nil, fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type)
+	}
+
+	// Check if the schema has properties
+	if schema.Properties != nil {
+		for propName, propSchema := range schema.Properties {
+			if propSchema == nil {
+				continue
+			}
+
+			if err := validateElicitProperty(propName, propSchema); err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	return schema, nil
+}
+
+// validateElicitProperty validates a single property in an elicitation schema.
+func validateElicitProperty(propName string, propSchema *jsonschema.Schema) error {
+	// Check if this property has nested properties (not allowed)
+	if len(propSchema.Properties) > 0 {
+		return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName)
+	}
+
+	// Validate based on the property type - only primitives are supported
+	switch propSchema.Type {
+	case "string":
+		return validateElicitStringProperty(propName, propSchema)
+	case "number", "integer":
+		return validateElicitNumberProperty(propName, propSchema)
+	case "boolean":
+		return validateElicitBooleanProperty(propName, propSchema)
+	default:
+		return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type)
+	}
+}
+
+// validateElicitStringProperty validates string-type properties, including enums.
+func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema) error {
+	// Handle enum validation (enums are a special case of strings)
+	if len(propSchema.Enum) > 0 {
+		// Enums must be string type (or untyped which defaults to string)
+		if propSchema.Type != "" && propSchema.Type != "string" {
+			return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type)
+		}
+		// Enum values themselves are validated by the JSON schema library
+		// Validate enumNames if present - must match enum length
+		if propSchema.Extra != nil {
+			if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists {
+				// Type check enumNames - should be a slice
+				if enumNamesSlice, ok := enumNamesRaw.([]any); ok {
+					if len(enumNamesSlice) != len(propSchema.Enum) {
+						return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice))
+					}
+				} else {
+					return fmt.Errorf("elicit schema property %q has invalid enumNames type, must be an array", propName)
+				}
+			}
+		}
+		return nil
+	}
+
+	// Validate format if specified - only specific formats are allowed
+	if propSchema.Format != "" {
+		allowedFormats := map[string]bool{
+			"email":     true,
+			"uri":       true,
+			"date":      true,
+			"date-time": true,
+		}
+		if !allowedFormats[propSchema.Format] {
+			return fmt.Errorf("elicit schema property %q has unsupported format %q, only email, uri, date, and date-time are allowed", propName, propSchema.Format)
+		}
+	}
+
+	// Validate minLength constraint if specified
+	if propSchema.MinLength != nil {
+		if *propSchema.MinLength < 0 {
+			return fmt.Errorf("elicit schema property %q has invalid minLength %d, must be non-negative", propName, *propSchema.MinLength)
+		}
+	}
+
+	// Validate maxLength constraint if specified
+	if propSchema.MaxLength != nil {
+		if *propSchema.MaxLength < 0 {
+			return fmt.Errorf("elicit schema property %q has invalid maxLength %d, must be non-negative", propName, *propSchema.MaxLength)
+		}
+		// Check that maxLength >= minLength if both are specified
+		if propSchema.MinLength != nil && *propSchema.MaxLength < *propSchema.MinLength {
+			return fmt.Errorf("elicit schema property %q has maxLength %d less than minLength %d", propName, *propSchema.MaxLength, *propSchema.MinLength)
+		}
+	}
+
+	return nil
+}
+
+// validateElicitNumberProperty validates number and integer-type properties.
+func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema) error {
+	if propSchema.Minimum != nil && propSchema.Maximum != nil {
+		if *propSchema.Maximum < *propSchema.Minimum {
+			return fmt.Errorf("elicit schema property %q has maximum %g less than minimum %g", propName, *propSchema.Maximum, *propSchema.Minimum)
+		}
+	}
+
+	return nil
+}
+
+// validateElicitBooleanProperty validates boolean-type properties.
+func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error {
+	// Validate default value if specified - must be a valid boolean
+	if propSchema.Default != nil {
+		var defaultValue bool
+		if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil {
+			return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err)
+		}
+	}
+
+	return nil
+}
+
+// AddSendingMiddleware wraps the current sending method handler using the provided
+// middleware. Middleware is applied from right to left, so that the first one is
+// executed first.
+//
+// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as
+// m1(m2(m3(handler))).
+//
+// Sending middleware is called when a request is sent. It is useful for tasks
+// such as tracing, metrics, and adding progress tokens.
+func (c *Client) AddSendingMiddleware(middleware ...Middleware) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	addMiddleware(&c.sendingMethodHandler_, middleware)
+}
+
+// AddReceivingMiddleware wraps the current receiving method handler using
+// the provided middleware. Middleware is applied from right to left, so that the
+// first one is executed first.
+//
+// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as
+// m1(m2(m3(handler))).
+//
+// Receiving middleware is called when a request is received. It is useful for tasks
+// such as authentication, request logging and metrics.
+func (c *Client) AddReceivingMiddleware(middleware ...Middleware) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	addMiddleware(&c.receivingMethodHandler_, middleware)
+}
+
+// clientMethodInfos maps from the RPC method name to methodInfos.
+//
+// The 'allowMissingParams' values are extracted from the protocol schema.
+// TODO(rfindley): actually load and validate the protocol schema, rather than
+// curating these method flags.
+var clientMethodInfos = map[string]methodInfo{
+	methodComplete:                  newClientMethodInfo(clientSessionMethod((*ClientSession).Complete), 0),
+	methodPing:                      newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK),
+	methodListRoots:                 newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK),
+	methodCreateMessage:             newClientMethodInfo(clientMethod((*Client).createMessage), 0),
+	methodElicit:                    newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK),
+	notificationCancelled:           newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK),
+	notificationToolListChanged:     newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK),
+	notificationPromptListChanged:   newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK),
+	notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK),
+	notificationResourceUpdated:     newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK),
+	notificationLoggingMessage:      newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification),
+	notificationProgress:            newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification),
+}
+
+func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
+	return serverMethodInfos
+}
+
+func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo {
+	return clientMethodInfos
+}
+
+func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
+	if req.IsCall() {
+		jsonrpc2.Async(ctx)
+	}
+	return handleReceive(ctx, cs, req)
+}
+
+func (cs *ClientSession) sendingMethodHandler() MethodHandler {
+	cs.client.mu.Lock()
+	defer cs.client.mu.Unlock()
+	return cs.client.sendingMethodHandler_
+}
+
+func (cs *ClientSession) receivingMethodHandler() MethodHandler {
+	cs.client.mu.Lock()
+	defer cs.client.mu.Unlock()
+	return cs.client.receivingMethodHandler_
+}
+
+// getConn implements [Session.getConn].
+func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn }
+
+func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) {
+	return &emptyResult{}, nil
+}
+
+// cancel is a placeholder: cancellation is handled by the jsonrpc2 package.
+//
+// It should never be invoked in practice because cancellation is preempted,
+// but having its signature here facilitates the construction of methodInfo
+// that can be used to validate incoming cancellation notifications.
+func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) {
+	return nil, nil
+}
+
+func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] {
+	return &ClientRequest[P]{Session: cs, Params: params}
+}
+
+// Ping makes an MCP "ping" request to the server.
+func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error {
+	_, err := handleSend[*emptyResult](ctx, methodPing, newClientRequest(cs, orZero[Params](params)))
+	return err
+}
+
+// ListPrompts lists prompts that are currently available on the server.
+func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) {
+	return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params)))
+}
+
+// GetPrompt gets a prompt from the server.
+func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) {
+	return handleSend[*GetPromptResult](ctx, methodGetPrompt, newClientRequest(cs, orZero[Params](params)))
+}
+
+// ListTools lists tools that are currently available on the server.
+func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) {
+	return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params)))
+}
+
+// CallTool calls the tool with the given parameters.
+//
+// The params.Arguments can be any value that marshals into a JSON object.
+func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) {
+	if params == nil {
+		params = new(CallToolParams)
+	}
+	if params.Arguments == nil {
+		// Avoid sending nil over the wire.
+		params.Arguments = map[string]any{}
+	}
+	return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
+}
+
+func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error {
+	_, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params)))
+	return err
+}
+
+// ListResources lists the resources that are currently available on the server.
+func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) {
+	return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params)))
+}
+
+// ListResourceTemplates lists the resource templates that are currently available on the server.
+func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) {
+	return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params)))
+}
+
+// ReadResource asks the server to read a resource and return its contents.
+func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) {
+	return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params)))
+}
+
+func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) {
+	return handleSend[*CompleteResult](ctx, methodComplete, newClientRequest(cs, orZero[Params](params)))
+}
+
+// Subscribe sends a "resources/subscribe" request to the server, asking for
+// notifications when the specified resource changes.
+func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error {
+	_, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params)))
+	return err
+}
+
+// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling
+// a previous subscription.
+func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error {
+	_, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params)))
+	return err
+}
+
+func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) {
+	if h := c.opts.ToolListChangedHandler; h != nil {
+		h(ctx, req)
+	}
+	return nil, nil
+}
+
+func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) {
+	if h := c.opts.PromptListChangedHandler; h != nil {
+		h(ctx, req)
+	}
+	return nil, nil
+}
+
+func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) {
+	if h := c.opts.ResourceListChangedHandler; h != nil {
+		h(ctx, req)
+	}
+	return nil, nil
+}
+
+func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) {
+	if h := c.opts.ResourceUpdatedHandler; h != nil {
+		h(ctx, req)
+	}
+	return nil, nil
+}
+
+func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) {
+	if h := c.opts.LoggingMessageHandler; h != nil {
+		h(ctx, req)
+	}
+	return nil, nil
+}
+
+func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) {
+	if h := cs.client.opts.ProgressNotificationHandler; h != nil {
+		h(ctx, clientRequestFor(cs, params))
+	}
+	return nil, nil
+}
+
+// NotifyProgress sends a progress notification from the client to the server
+// associated with this session.
+// This can be used if the client is performing a long-running task that was
+// initiated by the server.
+func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error {
+	return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params)))
+}
+
+// Tools provides an iterator for all tools available on the server,
+// automatically fetching pages and managing cursors.
+// The params argument can set the initial cursor.
+// Iteration stops at the first encountered error, which will be yielded.
+func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[*Tool, error] {
+	if params == nil {
+		params = &ListToolsParams{}
+	}
+	return paginate(ctx, params, cs.ListTools, func(res *ListToolsResult) []*Tool {
+		return res.Tools
+	})
+}
+
+// Resources provides an iterator for all resources available on the server,
+// automatically fetching pages and managing cursors.
+// The params argument can set the initial cursor.
+// Iteration stops at the first encountered error, which will be yielded.
+func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesParams) iter.Seq2[*Resource, error] {
+	if params == nil {
+		params = &ListResourcesParams{}
+	}
+	return paginate(ctx, params, cs.ListResources, func(res *ListResourcesResult) []*Resource {
+		return res.Resources
+	})
+}
+
+// ResourceTemplates provides an iterator for all resource templates available on the server,
+// automatically fetching pages and managing cursors.
+// The params argument can set the initial cursor.
+// Iteration stops at the first encountered error, which will be yielded.
+func (cs *ClientSession) ResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) iter.Seq2[*ResourceTemplate, error] {
+	if params == nil {
+		params = &ListResourceTemplatesParams{}
+	}
+	return paginate(ctx, params, cs.ListResourceTemplates, func(res *ListResourceTemplatesResult) []*ResourceTemplate {
+		return res.ResourceTemplates
+	})
+}
+
+// Prompts provides an iterator for all prompts available on the server,
+// automatically fetching pages and managing cursors.
+// The params argument can set the initial cursor.
+// Iteration stops at the first encountered error, which will be yielded.
+func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) iter.Seq2[*Prompt, error] {
+	if params == nil {
+		params = &ListPromptsParams{}
+	}
+	return paginate(ctx, params, cs.ListPrompts, func(res *ListPromptsResult) []*Prompt {
+		return res.Prompts
+	})
+}
+
+// paginate is a generic helper function to provide a paginated iterator.
+func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[*T, error] {
+	return func(yield func(*T, error) bool) {
+		for {
+			res, err := listFunc(ctx, params)
+			if err != nil {
+				yield(nil, err)
+				return
+			}
+			for _, r := range items(res) {
+				if !yield(r, nil) {
+					return
+				}
+			}
+			nextCursorVal := res.nextCursorPtr()
+			if nextCursorVal == nil || *nextCursorVal == "" {
+				return
+			}
+			*params.cursorPtr() = *nextCursorVal
+		}
+	}
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go
new file mode 100644
index 0000000000..b531eaf132
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go
@@ -0,0 +1,108 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"os/exec"
+	"syscall"
+	"time"
+)
+
+var defaultTerminateDuration = 5 * time.Second // mutable for testing
+
+// A CommandTransport is a [Transport] that runs a command and communicates
+// with it over stdin/stdout, using newline-delimited JSON.
+type CommandTransport struct {
+	Command *exec.Cmd
+	// TerminateDuration controls how long Close waits after closing stdin
+	// for the process to exit before sending SIGTERM.
+	// If zero or negative, the default of 5s is used.
+	TerminateDuration time.Duration
+}
+
+// Connect starts the command, and connects to it over stdin/stdout.
+func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) {
+	stdout, err := t.Command.StdoutPipe()
+	if err != nil {
+		return nil, err
+	}
+	stdout = io.NopCloser(stdout) // close the connection by closing stdin, not stdout
+	stdin, err := t.Command.StdinPipe()
+	if err != nil {
+		return nil, err
+	}
+	if err := t.Command.Start(); err != nil {
+		return nil, err
+	}
+	td := t.TerminateDuration
+	if td <= 0 {
+		td = defaultTerminateDuration
+	}
+	return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil
+}
+
+// A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over
+// stdin/stdout pipes.
+type pipeRWC struct {
+	cmd               *exec.Cmd
+	stdout            io.ReadCloser
+	stdin             io.WriteCloser
+	terminateDuration time.Duration
+}
+
+func (s *pipeRWC) Read(p []byte) (n int, err error) {
+	return s.stdout.Read(p)
+}
+
+func (s *pipeRWC) Write(p []byte) (n int, err error) {
+	return s.stdin.Write(p)
+}
+
+// Close closes the input stream to the child process, and awaits normal
+// termination of the command. If the command does not exit, it is signalled to
+// terminate, and then eventually killed.
+func (s *pipeRWC) Close() error {
+	// Spec:
+	// "For the stdio transport, the client SHOULD initiate shutdown by:...
+
+	// "...First, closing the input stream to the child process (the server)"
+	if err := s.stdin.Close(); err != nil {
+		return fmt.Errorf("closing stdin: %v", err)
+	}
+	resChan := make(chan error, 1)
+	go func() {
+		resChan <- s.cmd.Wait()
+	}()
+	// "...Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time"
+	wait := func() (error, bool) {
+		select {
+		case err := <-resChan:
+			return err, true
+		case <-time.After(s.terminateDuration):
+		}
+		return nil, false
+	}
+	if err, ok := wait(); ok {
+		return err
+	}
+	// Note the condition here: if sending SIGTERM fails, don't wait and just
+	// move on to SIGKILL.
+	if err := s.cmd.Process.Signal(syscall.SIGTERM); err == nil {
+		if err, ok := wait(); ok {
+			return err
+		}
+	}
+	// "...Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM"
+	if err := s.cmd.Process.Kill(); err != nil {
+		return err
+	}
+	if err, ok := wait(); ok {
+		return err
+	}
+	return fmt.Errorf("unresponsive subprocess")
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go
new file mode 100644
index 0000000000..e53cad14bc
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go
@@ -0,0 +1,284 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// TODO(findleyr): update JSON marshalling of all content types to preserve required fields.
+// (See [TextContent.MarshalJSON], which handles this for text content).
+
+package mcp
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+)
+
+// A Content is a [TextContent], [ImageContent], [AudioContent],
+// [ResourceLink], or [EmbeddedResource].
+type Content interface {
+	MarshalJSON() ([]byte, error)
+	fromWire(*wireContent)
+}
+
+// TextContent is a textual content.
+type TextContent struct {
+	Text        string
+	Meta        Meta
+	Annotations *Annotations
+}
+
+func (c *TextContent) MarshalJSON() ([]byte, error) {
+	// Custom wire format to ensure the required "text" field is always included, even when empty.
+	wire := struct {
+		Type        string       `json:"type"`
+		Text        string       `json:"text"`
+		Meta        Meta         `json:"_meta,omitempty"`
+		Annotations *Annotations `json:"annotations,omitempty"`
+	}{
+		Type:        "text",
+		Text:        c.Text,
+		Meta:        c.Meta,
+		Annotations: c.Annotations,
+	}
+	return json.Marshal(wire)
+}
+
+func (c *TextContent) fromWire(wire *wireContent) {
+	c.Text = wire.Text
+	c.Meta = wire.Meta
+	c.Annotations = wire.Annotations
+}
+
+// ImageContent contains base64-encoded image data.
+type ImageContent struct {
+	Meta        Meta
+	Annotations *Annotations
+	Data        []byte // base64-encoded
+	MIMEType    string
+}
+
+func (c *ImageContent) MarshalJSON() ([]byte, error) {
+	// Custom wire format to ensure required fields are always included, even when empty.
+	data := c.Data
+	if data == nil {
+		data = []byte{}
+	}
+	wire := imageAudioWire{
+		Type:        "image",
+		MIMEType:    c.MIMEType,
+		Data:        data,
+		Meta:        c.Meta,
+		Annotations: c.Annotations,
+	}
+	return json.Marshal(wire)
+}
+
+func (c *ImageContent) fromWire(wire *wireContent) {
+	c.MIMEType = wire.MIMEType
+	c.Data = wire.Data
+	c.Meta = wire.Meta
+	c.Annotations = wire.Annotations
+}
+
+// AudioContent contains base64-encoded audio data.
+type AudioContent struct {
+	Data        []byte
+	MIMEType    string
+	Meta        Meta
+	Annotations *Annotations
+}
+
+func (c AudioContent) MarshalJSON() ([]byte, error) {
+	// Custom wire format to ensure required fields are always included, even when empty.
+	data := c.Data
+	if data == nil {
+		data = []byte{}
+	}
+	wire := imageAudioWire{
+		Type:        "audio",
+		MIMEType:    c.MIMEType,
+		Data:        data,
+		Meta:        c.Meta,
+		Annotations: c.Annotations,
+	}
+	return json.Marshal(wire)
+}
+
+func (c *AudioContent) fromWire(wire *wireContent) {
+	c.MIMEType = wire.MIMEType
+	c.Data = wire.Data
+	c.Meta = wire.Meta
+	c.Annotations = wire.Annotations
+}
+
+// Custom wire format to ensure required fields are always included, even when empty.
+type imageAudioWire struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + +// ResourceLink is a link to a resource +type ResourceLink struct { + URI string + Name string + Title string + Description string + MIMEType string + Size *int64 + Meta Meta + Annotations *Annotations +} + +func (c *ResourceLink) MarshalJSON() ([]byte, error) { + return json.Marshal(&wireContent{ + Type: "resource_link", + URI: c.URI, + Name: c.Name, + Title: c.Title, + Description: c.Description, + MIMEType: c.MIMEType, + Size: c.Size, + Meta: c.Meta, + Annotations: c.Annotations, + }) +} + +func (c *ResourceLink) fromWire(wire *wireContent) { + c.URI = wire.URI + c.Name = wire.Name + c.Title = wire.Title + c.Description = wire.Description + c.MIMEType = wire.MIMEType + c.Size = wire.Size + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// EmbeddedResource contains embedded resources. +type EmbeddedResource struct { + Resource *ResourceContents + Meta Meta + Annotations *Annotations +} + +func (c *EmbeddedResource) MarshalJSON() ([]byte, error) { + return json.Marshal(&wireContent{ + Type: "resource", + Resource: c.Resource, + Meta: c.Meta, + Annotations: c.Annotations, + }) +} + +func (c *EmbeddedResource) fromWire(wire *wireContent) { + c.Resource = wire.Resource + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// ResourceContents contains the contents of a specific resource or +// sub-resource. +type ResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitempty"` + Meta Meta `json:"_meta,omitempty"` +} + +func (r *ResourceContents) MarshalJSON() ([]byte, error) { + // If we could assume Go 1.24, we could use omitzero for Blob and avoid this method. 
+ if r.URI == "" { + return nil, errors.New("ResourceContents missing URI") + } + if r.Blob == nil { + // Text. Marshal normally. + type wireResourceContents ResourceContents // (lacks MarshalJSON method) + return json.Marshal((wireResourceContents)(*r)) + } + // Blob. + if r.Text != "" { + return nil, errors.New("ResourceContents has non-zero Text and Blob fields") + } + // r.Blob may be the empty slice, so marshal with an alternative definition. + br := struct { + URI string `json:"uri,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Blob []byte `json:"blob"` + Meta Meta `json:"_meta,omitempty"` + }{ + URI: r.URI, + MIMEType: r.MIMEType, + Blob: r.Blob, + Meta: r.Meta, + } + return json.Marshal(br) +} + +// wireContent is the wire format for content. +// It represents the protocol types TextContent, ImageContent, AudioContent, +// ResourceLink, and EmbeddedResource. +// The Type field distinguishes them. In the protocol, each type has a constant +// value for the field. +// At most one of Text, Data, Resource, and URI is non-zero. 
+type wireContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitempty"` + Resource *ResourceContents `json:"resource,omitempty"` + URI string `json:"uri,omitempty"` + Name string `json:"name,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Size *int64 `json:"size,omitempty"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + +func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { + var blocks []Content + for _, wire := range wires { + block, err := contentFromWire(wire, allow) + if err != nil { + return nil, err + } + blocks = append(blocks, block) + } + return blocks, nil +} + +func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { + if wire == nil { + return nil, fmt.Errorf("nil content") + } + if allow != nil && !allow[wire.Type] { + return nil, fmt.Errorf("invalid content type %q", wire.Type) + } + switch wire.Type { + case "text": + v := new(TextContent) + v.fromWire(wire) + return v, nil + case "image": + v := new(ImageContent) + v.fromWire(wire) + return v, nil + case "audio": + v := new(AudioContent) + v.fromWire(wire) + return v, nil + case "resource_link": + v := new(ResourceLink) + v.fromWire(wire) + return v, nil + case "resource": + v := new(EmbeddedResource) + v.fromWire(wire) + return v, nil + } + return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go new file mode 100644 index 0000000000..281f5925ae --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go @@ -0,0 +1,426 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. 
+// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file is for SSE events. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events. + +package mcp + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "iter" + "maps" + "net/http" + "slices" + "strings" + "sync" +) + +// If true, MemoryEventStore will do frequent validation to check invariants, slowing it down. +// Enable for debugging. +const validateMemoryEventStore = false + +// An Event is a server-sent event. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. +type Event struct { + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field +} + +// Empty reports whether the Event is empty. +func (e Event) Empty() bool { + return e.Name == "" && e.ID == "" && len(e.Data) == 0 +} + +// writeEvent writes the event to w, and flushes. +func writeEvent(w io.Writer, evt Event) (int, error) { + var b bytes.Buffer + if evt.Name != "" { + fmt.Fprintf(&b, "event: %s\n", evt.Name) + } + if evt.ID != "" { + fmt.Fprintf(&b, "id: %s\n", evt.ID) + } + fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) + n, err := w.Write(b.Bytes()) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return n, err +} + +// scanEvents iterates SSE events in the given scanner. The iterated error is +// terminal: if encountered, the stream is corrupt or broken and should no +// longer be used. +// +// TODO(rfindley): consider a different API here that makes failure modes more +// apparent. +func scanEvents(r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewScanner(r) + const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size + scanner.Buffer(nil, maxTokenSize) + + // TODO: investigate proper behavior when events are out of order, or have + // non-standard names. 
+ var ( + eventKey = []byte("event") + idKey = []byte("id") + dataKey = []byte("data") + ) + + return func(yield func(Event, error) bool) { + // iterate event from the wire. + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples + // + // - `key: value` line records. + // - Consecutive `data: ...` fields are joined with newlines. + // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and + // 'data', these are the only three we consider. + // - Lines starting with ":" are ignored. + // - Records are terminated with two consecutive newlines. + var ( + evt Event + dataBuf *bytes.Buffer // if non-nil, preceding field was also data + ) + flushData := func() { + if dataBuf != nil { + evt.Data = dataBuf.Bytes() + dataBuf = nil + } + } + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + flushData() + // \n\n is the record delimiter + if !evt.Empty() && !yield(evt, nil) { + return + } + evt = Event{} + continue + } + before, after, found := bytes.Cut(line, []byte{':'}) + if !found { + yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + return + } + if !bytes.Equal(before, dataKey) { + flushData() + } + switch { + case bytes.Equal(before, eventKey): + evt.Name = strings.TrimSpace(string(after)) + case bytes.Equal(before, idKey): + evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, dataKey): + data := bytes.TrimSpace(after) + if dataBuf != nil { + dataBuf.WriteByte('\n') + dataBuf.Write(data) + } else { + dataBuf = new(bytes.Buffer) + dataBuf.Write(data) + } + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + } + if !yield(Event{}, err) { + return + } + } + flushData() + if !evt.Empty() { + yield(evt, nil) + } + } +} + +// An EventStore tracks data for SSE streams. 
+// A single EventStore suffices for all sessions, since session IDs are +// globally unique. So one EventStore can be created per process, for +// all Servers in the process. +// Such a store is able to bound resource usage for the entire process. +// +// All of an EventStore's methods must be safe for use by multiple goroutines. +type EventStore interface { + // Open is called when a new stream is created. It may be used to ensure that + // the underlying data structure for the stream is initialized, making it + // ready to store and replay event streams. + Open(_ context.Context, sessionID, streamID string) error + + // Append appends data for an outgoing event to given stream, which is part of the + // given session. + Append(_ context.Context, sessionID, streamID string, data []byte) error + + // After returns an iterator over the data for the given session and stream, beginning + // just after the given index. + // + // Once the iterator yields a non-nil error, it will stop. + // After's iterator must return an error immediately if any data after index was + // dropped; it must not return partial results. + // The stream must have been opened previously (see [EventStore.Open]). + After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] + + // SessionClosed informs the store that the given session is finished, along + // with all of its streams. + // + // A store cannot rely on this method being called for cleanup. It should institute + // additional mechanisms, such as timeouts, to reclaim storage. + SessionClosed(_ context.Context, sessionID string) error + + // There is no StreamClosed method. A server doesn't know when a stream is finished, because + // the client can always send a GET with a Last-Event-ID referring to the stream. +} + +// A dataList is a list of []byte. +// The zero dataList is ready to use. 
+type dataList struct { + size int // total size of data bytes + first int // the stream index of the first element in data + data [][]byte +} + +func (dl *dataList) appendData(d []byte) { + // If we allowed empty data, we would consume memory without incrementing the size. + // We could of course account for that, but we keep it simple and assume there is no + // empty data. + if len(d) == 0 { + panic("empty data item") + } + dl.data = append(dl.data, d) + dl.size += len(d) +} + +// removeFirst removes the first data item in dl, returning the size of the item. +// It panics if dl is empty. +func (dl *dataList) removeFirst() int { + if len(dl.data) == 0 { + panic("empty dataList") + } + r := len(dl.data[0]) + dl.size -= r + dl.data[0] = nil // help GC + dl.data = dl.data[1:] + dl.first++ + return r +} + +// A MemoryEventStore is an [EventStore] backed by memory. +type MemoryEventStore struct { + mu sync.Mutex + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[string]*dataList // session ID -> stream ID -> *dataList +} + +// MemoryEventStoreOptions are options for a [MemoryEventStore]. +type MemoryEventStoreOptions struct{} + +// MaxBytes returns the maximum number of bytes that the store will retain before +// purging data. +func (s *MemoryEventStore) MaxBytes() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.maxBytes +} + +// SetMaxBytes sets the maximum number of bytes the store will retain before purging +// data. The argument must not be negative. If it is zero, a suitable default will be used. +// SetMaxBytes can be called at any time. The size of the store will be adjusted +// immediately. 
+func (s *MemoryEventStore) SetMaxBytes(n int) { + s.mu.Lock() + defer s.mu.Unlock() + switch { + case n < 0: + panic("negative argument") + case n == 0: + s.maxBytes = defaultMaxBytes + default: + s.maxBytes = n + } + s.purge() +} + +const defaultMaxBytes = 10 << 20 // 10 MiB + +// NewMemoryEventStore creates a [MemoryEventStore] with the default value +// for MaxBytes. +func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { + return &MemoryEventStore{ + maxBytes: defaultMaxBytes, + store: make(map[string]map[string]*dataList), + } +} + +// Open implements [EventStore.Open]. It ensures that the underlying data +// structures for the given session are initialized and ready for use. +func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.init(sessionID, streamID) + return nil +} + +// init is an internal helper function that ensures the nested map structure for a +// given sessionID and streamID exists, creating it if necessary. It returns the +// dataList associated with the specified IDs. +// Requires s.mu. +func (s *MemoryEventStore) init(sessionID, streamID string) *dataList { + streamMap, ok := s.store[sessionID] + if !ok { + streamMap = make(map[string]*dataList) + s.store[sessionID] = streamMap + } + dl, ok := streamMap[streamID] + if !ok { + dl = &dataList{} + streamMap[streamID] = dl + } + return dl +} + +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + dl := s.init(sessionID, streamID) + // Purge before adding, so at least the current data item will be present. + // (That could result in nBytes > maxBytes, but we'll live with that.) 
+ s.purge() + dl.appendData(data) + s.nBytes += len(data) + return nil +} + +// ErrEventsPurged is the error that [EventStore.After] should return if the event just after the +// index is no longer available. +var ErrEventsPurged = errors.New("data purged") + +// After implements [EventStore.After]. +func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { + // Return the data items to yield. + // We must copy, because dataList.removeFirst nils out slice elements. + copyData := func() ([][]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + streamMap, ok := s.store[sessionID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID) + } + dl, ok := streamMap[streamID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) + } + start := index + 1 + if dl.first > start { + return nil, fmt.Errorf("MemoryEventStore.After: index %d, stream ID %v, session %q: %w", + index, streamID, sessionID, ErrEventsPurged) + } + return slices.Clone(dl.data[start-dl.first:]), nil + } + + return func(yield func([]byte, error) bool) { + ds, err := copyData() + if err != nil { + yield(nil, err) + return + } + for _, d := range ds { + if !yield(d, nil) { + return + } + } + } +} + +// SessionClosed implements [EventStore.SessionClosed]. +func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, dl := range s.store[sessionID] { + s.nBytes -= dl.size + } + delete(s.store, sessionID) + s.validate() + return nil +} + +// purge removes data until no more than s.maxBytes bytes are in use. +// It must be called with s.mu held. +func (s *MemoryEventStore) purge() { + // Remove the first element of every dataList until below the max. 
+ for s.nBytes > s.maxBytes { + changed := false + for _, sm := range s.store { + for _, dl := range sm { + if dl.size > 0 { + r := dl.removeFirst() + if r > 0 { + changed = true + s.nBytes -= r + } + } + } + } + if !changed { + panic("no progress during purge") + } + } + s.validate() +} + +// validate checks that the store's data structures are valid. +// It must be called with s.mu held. +func (s *MemoryEventStore) validate() { + if !validateMemoryEventStore { + return + } + // Check that we're accounting for the size correctly. + n := 0 + for _, sm := range s.store { + for _, dl := range sm { + for _, d := range dl.data { + n += len(d) + } + } + } + if n != s.nBytes { + panic("sizes don't add up") + } +} + +// debugString returns a string containing the state of s. +// Used in tests. +func (s *MemoryEventStore) debugString() string { + s.mu.Lock() + defer s.mu.Unlock() + var b strings.Builder + for i, sess := range slices.Sorted(maps.Keys(s.store)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + sm := s.store[sess] + for i, sid := range slices.Sorted(maps.Keys(sm)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + dl := sm[sid] + fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first) + for _, d := range dl.data { + fmt.Fprintf(&b, " %s", d) + } + } + } + return b.String() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go new file mode 100644 index 0000000000..438370fe58 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go @@ -0,0 +1,114 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "iter" + "maps" + "slices" +) + +// This file contains implementations that are common to all features. +// A feature is an item provided to a peer. 
In the 2025-03-26 spec, +// the features are prompt, tool, resource and root. + +// A featureSet is a collection of features of type T. +// Every feature has a unique ID, and the spec never mentions +// an ordering for the List calls, so what it calls a "list" is actually a set. +// +// An alternative implementation would use an ordered map, but that's probably +// not necessary as adds and removes are rare, and usually batched. +type featureSet[T any] struct { + uniqueID func(T) string + features map[string]T + sortedKeys []string // lazily computed; nil after add or remove +} + +// newFeatureSet creates a new featureSet for features of type T. +// The argument function should return the unique ID for a single feature. +func newFeatureSet[T any](uniqueIDFunc func(T) string) *featureSet[T] { + return &featureSet[T]{ + uniqueID: uniqueIDFunc, + features: make(map[string]T), + } +} + +// add adds each feature to the set if it is not present, +// or replaces an existing feature. +func (s *featureSet[T]) add(fs ...T) { + for _, f := range fs { + s.features[s.uniqueID(f)] = f + } + s.sortedKeys = nil +} + +// remove removes all features with the given uids from the set if present, +// and returns whether any were removed. +// It is not an error to remove a nonexistent feature. +func (s *featureSet[T]) remove(uids ...string) bool { + changed := false + for _, uid := range uids { + if _, ok := s.features[uid]; ok { + changed = true + delete(s.features, uid) + } + } + if changed { + s.sortedKeys = nil + } + return changed +} + +// get returns the feature with the given uid. +// If there is none, it returns zero, false. +func (s *featureSet[T]) get(uid string) (T, bool) { + t, ok := s.features[uid] + return t, ok +} + +// len returns the number of features in the set. +func (s *featureSet[T]) len() int { return len(s.features) } + +// all returns an iterator over of all the features in the set +// sorted by unique ID. 
+func (s *featureSet[T]) all() iter.Seq[T] { + s.sortKeys() + return func(yield func(T) bool) { + s.yieldFrom(0, yield) + } +} + +// above returns an iterator over features in the set whose unique IDs are +// greater than `uid`, in ascending ID order. +func (s *featureSet[T]) above(uid string) iter.Seq[T] { + s.sortKeys() + index, found := slices.BinarySearch(s.sortedKeys, uid) + if found { + index++ + } + return func(yield func(T) bool) { + s.yieldFrom(index, yield) + } +} + +// sortKeys is a helper that maintains a sorted list of feature IDs. It +// computes this list lazily upon its first call after a modification, or +// if it's nil. +func (s *featureSet[T]) sortKeys() { + if s.sortedKeys != nil { + return + } + s.sortedKeys = slices.Sorted(maps.Keys(s.features)) +} + +// yieldFrom is a helper that iterates over the features in the set, +// starting at the given index, and calls the yield function for each one. +func (s *featureSet[T]) yieldFrom(index int, yield func(T) bool) { + for i := index; i < len(s.sortedKeys); i++ { + if !yield(s.features[s.sortedKeys[i]]) { + return + } + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go new file mode 100644 index 0000000000..208427e226 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go @@ -0,0 +1,207 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "log/slog" + "sync" + "time" +) + +// Logging levels. 
+const ( + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelNotice = (slog.LevelInfo + slog.LevelWarn) / 2 + LevelWarning = slog.LevelWarn + LevelError = slog.LevelError + LevelCritical = slog.LevelError + 4 + LevelAlert = slog.LevelError + 8 + LevelEmergency = slog.LevelError + 12 +) + +var slogToMCP = map[slog.Level]LoggingLevel{ + LevelDebug: "debug", + LevelInfo: "info", + LevelNotice: "notice", + LevelWarning: "warning", + LevelError: "error", + LevelCritical: "critical", + LevelAlert: "alert", + LevelEmergency: "emergency", +} + +var mcpToSlog = make(map[LoggingLevel]slog.Level) + +func init() { + for sl, ml := range slogToMCP { + mcpToSlog[ml] = sl + } +} + +func slogLevelToMCP(sl slog.Level) LoggingLevel { + if ml, ok := slogToMCP[sl]; ok { + return ml + } + return "debug" // for lack of a better idea +} + +func mcpLevelToSlog(ll LoggingLevel) slog.Level { + if sl, ok := mcpToSlog[ll]; ok { + return sl + } + // TODO: is there a better default? + return LevelDebug +} + +// compareLevels behaves like [cmp.Compare] for [LoggingLevel]s. +func compareLevels(l1, l2 LoggingLevel) int { + return cmp.Compare(mcpLevelToSlog(l1), mcpLevelToSlog(l2)) +} + +// LoggingHandlerOptions are options for a LoggingHandler. +type LoggingHandlerOptions struct { + // The value for the "logger" field of logging notifications. + LoggerName string + // Limits the rate at which log messages are sent. + // Excess messages are dropped. + // If zero, there is no rate limiting. + MinInterval time.Duration +} + +// A LoggingHandler is a [slog.Handler] for MCP. +type LoggingHandler struct { + opts LoggingHandlerOptions + ss *ServerSession + // Ensures that the buffer reset is atomic with the write (see Handle). + // A pointer so that clones share the mutex. See + // https://github.com/golang/example/blob/master/slog-handler-guide/README.md#getting-the-mutex-right. 
+ mu *sync.Mutex + lastMessageSent time.Time // for rate-limiting + buf *bytes.Buffer + handler slog.Handler +} + +// discardHandler is a slog.Handler that drops all logs. +// TODO: use slog.DiscardHandler when we require Go 1.24+. +type discardHandler struct{} + +func (discardHandler) Enabled(context.Context, slog.Level) bool { return false } +func (discardHandler) Handle(context.Context, slog.Record) error { return nil } +func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} } +func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} } + +// ensureLogger returns l if non-nil, otherwise a discard logger. +func ensureLogger(l *slog.Logger) *slog.Logger { + if l != nil { + return l + } + return slog.New(discardHandler{}) +} + +// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a +// [slog.JSONHandler]. +func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler { + var buf bytes.Buffer + jsonHandler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + // Remove level: it appears in LoggingMessageParams. + if a.Key == slog.LevelKey { + return slog.Attr{} + } + return a + }, + }) + lh := &LoggingHandler{ + ss: ss, + mu: new(sync.Mutex), + buf: &buf, + handler: jsonHandler, + } + if opts != nil { + lh.opts = *opts + } + return lh +} + +// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level. +func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { + // This is also checked in ServerSession.LoggingMessage, so checking it here + // is just an optimization that skips building the JSON. + h.ss.mu.Lock() + mcpLevel := h.ss.state.LogLevel + h.ss.mu.Unlock() + return level >= mcpLevelToSlog(mcpLevel) +} + +// WithAttrs implements [slog.Handler.WithAttrs]. 
+func (h *LoggingHandler) WithAttrs(as []slog.Attr) slog.Handler { + h2 := *h + h2.handler = h.handler.WithAttrs(as) + return &h2 +} + +// WithGroup implements [slog.Handler.WithGroup]. +func (h *LoggingHandler) WithGroup(name string) slog.Handler { + h2 := *h + h2.handler = h.handler.WithGroup(name) + return &h2 +} + +// Handle implements [slog.Handler.Handle] by writing the Record to a JSONHandler, +// then calling [ServerSession.LoggingMessage] with the result. +func (h *LoggingHandler) Handle(ctx context.Context, r slog.Record) error { + err := h.handle(ctx, r) + // TODO(jba): find a way to surface the error. + // The return value will probably be ignored. + return err +} + +func (h *LoggingHandler) handle(ctx context.Context, r slog.Record) error { + // Observe the rate limit. + // TODO(jba): use golang.org/x/time/rate. (We can't here because it would require adding + // golang.org/x/time to the go.mod file.) + h.mu.Lock() + skip := time.Since(h.lastMessageSent) < h.opts.MinInterval + h.mu.Unlock() + if skip { + return nil + } + + var err error + // Make the buffer reset atomic with the record write. + // We are careful here in the unlikely event that the handler panics. + // We don't want to hold the lock for the entire function, because Notify is + // an I/O operation. + // This can result in out-of-order delivery. + func() { + h.mu.Lock() + defer h.mu.Unlock() + h.buf.Reset() + err = h.handler.Handle(ctx, r) + }() + if err != nil { + return err + } + + h.mu.Lock() + h.lastMessageSent = time.Now() + h.mu.Unlock() + + params := &LoggingMessageParams{ + Logger: h.opts.LoggerName, + Level: slogLevelToMCP(r.Level), + Data: json.RawMessage(h.buf.Bytes()), + } + // We pass the argument context to Notify, even though slog.Handler.Handle's + // documentation says not to. + // In this case logging is a service to clients, not a means for debugging the + // server, so we want to cancel the log message. 
+ return h.ss.Log(ctx, params) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go new file mode 100644 index 0000000000..56e950b869 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go @@ -0,0 +1,88 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The mcp package provides an SDK for writing model context protocol clients +// and servers. +// +// To get started, create either a [Client] or [Server], add features to it +// using `AddXXX` functions, and connect it to a peer using a [Transport]. +// +// For example, to run a simple server on the [StdioTransport]: +// +// server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) +// +// // Using the generic AddTool automatically populates the the input and output +// // schema of the tool. +// type args struct { +// Name string `json:"name" jsonschema:"the person to greet"` +// } +// mcp.AddTool(server, &mcp.Tool{ +// Name: "greet", +// Description: "say hi", +// }, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { +// return &mcp.CallToolResult{ +// Content: []mcp.Content{ +// &mcp.TextContent{Text: "Hi " + args.Name}, +// }, +// }, nil, nil +// }) +// +// // Run the server on the stdio transport. 
+// if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { +// log.Printf("Server failed: %v", err) +// } +// +// To connect to this server, use the [CommandTransport]: +// +// client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) +// transport := &mcp.CommandTransport{Command: exec.Command("myserver")} +// session, err := client.Connect(ctx, transport, nil) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Close() +// +// params := &mcp.CallToolParams{ +// Name: "greet", +// Arguments: map[string]any{"name": "you"}, +// } +// res, err := session.CallTool(ctx, params) +// if err != nil { +// log.Fatalf("CallTool failed: %v", err) +// } +// +// # Clients, servers, and sessions +// +// In this SDK, both a [Client] and [Server] may handle many concurrent +// connections. Each time a client or server is connected to a peer using a +// [Transport], it creates a new session (either a [ClientSession] or +// [ServerSession]): +// +// Client Server +// ⇅ (jsonrpc2) ⇅ +// ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession +// +// The session types expose an API to interact with its peer. For example, +// [ClientSession.CallTool] or [ServerSession.ListRoots]. +// +// # Adding features +// +// Add MCP servers to your Client or Server using AddXXX methods (for example +// [Client.AddRoot] or [Server.AddPrompt]). If any peers are connected when +// AddXXX is called, they will receive a corresponding change notification +// (for example notifications/roots/list_changed). +// +// Adding tools is special: tools may be bound to ordinary Go functions by +// using the top-level generic [AddTool] function, which allows specifying an +// input and output type. When AddTool is used, the tool's input schema and +// output schema are automatically populated, and inputs are automatically +// validated. As a special case, if the output type is 'any', no output schema +// is generated. 
+//
+//	func double(_ context.Context, _ *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
+//		return nil, Out{Answer: 2*in.Number}, nil
+//	}
+//	...
+//	mcp.AddTool(server, &mcp.Tool{Name: "double"}, double)
+package mcp
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go
new file mode 100644
index 0000000000..62f38a36af
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go
@@ -0,0 +1,17 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+import (
+	"context"
+)
+
+// A PromptHandler handles a call to prompts/get.
+type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error)
+
+type serverPrompt struct {
+	prompt  *Prompt
+	handler PromptHandler
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go
new file mode 100644
index 0000000000..1312dfbdc8
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go
@@ -0,0 +1,1165 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+// Protocol types for version 2025-06-18.
+// To see the schema changes from the previous version, run:
+//
+//	prefix=https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema
+//	sdiff -l <(curl $prefix/2025-03-26/schema.ts) <(curl $prefix/2025-06-18/schema.ts)
+
+import (
+	"encoding/json"
+	"fmt"
+)
+
+// Optional annotations for the client. The client can use annotations to inform
+// how objects are used or displayed.
+type Annotations struct {
+	// Describes who the intended customer of this object or data is.
+ // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., []Role{"user", "assistant"}). + Audience []Role `json:"audience,omitempty"` + // The moment the resource was last modified, as an ISO 8601 formatted string. + // + // Should be an ISO 8601 formatted string (e.g., "2025-01-12T15:00:58Z"). + // + // Examples: last activity timestamp in an open file, timestamp when the + // resource was attached, etc. + LastModified string `json:"lastModified,omitempty"` + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that the + // data is entirely optional. + Priority float64 `json:"priority,omitempty"` +} + +// CallToolParams is used by clients to call a tool. +type CallToolParams struct { + // Meta is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool to call. + Name string `json:"name"` + // Arguments holds the tool arguments. It can hold any value that can be + // marshaled to JSON. + Arguments any `json:"arguments,omitempty"` +} + +// CallToolParamsRaw is passed to tool handlers on the server. Its arguments +// are not yet unmarshaled (hence "raw"), so that the handlers can perform +// unmarshaling themselves. +type CallToolParamsRaw struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool being called. + Name string `json:"name"` + // Arguments is the raw arguments received over the wire from the client. It + // is the responsibility of the tool handler to unmarshal and validate the + // Arguments (see [AddTool]). 
+ Arguments json.RawMessage `json:"arguments,omitempty"` +} + +// A CallToolResult is the server's response to a tool call. +// +// The [ToolHandler] and [ToolHandlerFor] handler functions return this result, +// though [ToolHandlerFor] populates much of it automatically as documented at +// each field. +type CallToolResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + + // A list of content objects that represent the unstructured result of the tool + // call. + // + // When using a [ToolHandlerFor] with structured output, if Content is unset + // it will be populated with JSON text content corresponding to the + // structured output value. + Content []Content `json:"content"` + + // StructuredContent is an optional value that represents the structured + // result of the tool call. It must marshal to a JSON object. + // + // When using a [ToolHandlerFor] with structured output, you should not + // populate this field. It will be automatically populated with the typed Out + // value. + StructuredContent any `json:"structuredContent,omitempty"` + + // IsError reports whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + // + // Any errors that originate from the tool should be reported inside the + // Content field, with IsError set to true, not as an MCP protocol-level + // error response. Otherwise, the LLM would not be able to see that an error + // occurred and self-correct. + // + // However, any errors in finding the tool, an error indicating that the + // server does not support tool calls, or any other exceptional conditions, + // should be reported as an MCP error response. + // + // When using a [ToolHandlerFor], this field is automatically set when the + // tool handler returns an error, and the error string is included as text in + // the Content field. 
+ IsError bool `json:"isError,omitempty"` + + // The error passed to setError, if any. + // It is not marshaled, and therefore it is only visible on the server. + // Its only use is in server sending middleware, where it can be accessed + // with getError. + err error +} + +// TODO(#64): consider exposing setError (and getError), by adding an error +// field on CallToolResult. +func (r *CallToolResult) setError(err error) { + r.Content = []Content{&TextContent{Text: err.Error()}} + r.IsError = true + r.err = err +} + +// getError returns the error set with setError, or nil if none. +// This function always returns nil on clients. +func (r *CallToolResult) getError() error { + return r.err +} + +func (*CallToolResult) isResult() {} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (x *CallToolResult) UnmarshalJSON(data []byte) error { + type res CallToolResult // avoid recursion + var wire struct { + res + Content []*wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { + return err + } + *x = CallToolResult(wire.res) + return nil +} + +func (x *CallToolParams) isParams() {} +func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } + +func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } + +type CancelledParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An optional string describing the reason for the cancellation. This may be + // logged or presented to the user. 
+ Reason string `json:"reason,omitempty"` + // The ID of the request to cancel. + // + // This must correspond to the ID of a request previously issued in the same + // direction. + RequestID any `json:"requestId"` +} + +func (x *CancelledParams) isParams() {} +func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Capabilities a client may support. Known capabilities are defined here, in +// this schema, but this is not a closed set: any client can define its own, +// additional capabilities. +type ClientCapabilities struct { + // Experimental, non-standard capabilities that the client supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Present if the client supports listing roots. + Roots struct { + // Whether the client supports notifications for changes to the roots list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"roots,omitempty"` + // Present if the client supports sampling from an LLM. + Sampling *SamplingCapabilities `json:"sampling,omitempty"` + // Present if the client supports elicitation from the server. + Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"` +} + +type CompleteParamsArgument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` +} + +// CompleteContext represents additional, optional context for completions. +type CompleteContext struct { + // Previously-resolved variables in a URI template or prompt. + Arguments map[string]string `json:"arguments,omitempty"` +} + +// CompleteReference represents a completion reference type (ref/prompt ref/resource). +// The Type field determines which other fields are relevant. +type CompleteReference struct { + Type string `json:"type"` + // Name is relevant when Type is "ref/prompt". 
+ Name string `json:"name,omitempty"` + // URI is relevant when Type is "ref/resource". + URI string `json:"uri,omitempty"` +} + +func (r *CompleteReference) UnmarshalJSON(data []byte) error { + type wireCompleteReference CompleteReference // for naive unmarshaling + var r2 wireCompleteReference + if err := json.Unmarshal(data, &r2); err != nil { + return err + } + switch r2.Type { + case "ref/prompt", "ref/resource": + if r2.Type == "ref/prompt" && r2.URI != "" { + return fmt.Errorf("reference of type %q must not have a URI set", r2.Type) + } + if r2.Type == "ref/resource" && r2.Name != "" { + return fmt.Errorf("reference of type %q must not have a Name set", r2.Type) + } + default: + return fmt.Errorf("unrecognized content type %q", r2.Type) + } + *r = CompleteReference(r2) + return nil +} + +func (r *CompleteReference) MarshalJSON() ([]byte, error) { + // Validation for marshalling: ensure consistency before converting to JSON. + switch r.Type { + case "ref/prompt": + if r.URI != "" { + return nil, fmt.Errorf("reference of type %q must not have a URI set for marshalling", r.Type) + } + case "ref/resource": + if r.Name != "" { + return nil, fmt.Errorf("reference of type %q must not have a Name set for marshalling", r.Type) + } + default: + return nil, fmt.Errorf("unrecognized reference type %q for marshalling", r.Type) + } + + type wireReference CompleteReference + return json.Marshal(wireReference(*r)) +} + +type CompleteParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + // The argument's information + Argument CompleteParamsArgument `json:"argument"` + Context *CompleteContext `json:"context,omitempty"` + Ref *CompleteReference `json:"ref"` +} + +func (*CompleteParams) isParams() {} + +type CompletionResultDetails struct { + HasMore bool `json:"hasMore,omitempty"` + Total int `json:"total,omitempty"` + Values []string `json:"values"` +} + +// The server's response to a completion/complete request +type CompleteResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Completion CompletionResultDetails `json:"completion"` +} + +func (*CompleteResult) isResult() {} + +type CreateMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // A request to include context from one or more MCP servers (including the + // caller), to be attached to the prompt. The client may ignore this request. + IncludeContext string `json:"includeContext,omitempty"` + // The maximum number of tokens to sample, as requested by the server. The + // client may choose to sample fewer tokens than requested. + MaxTokens int64 `json:"maxTokens"` + Messages []*SamplingMessage `json:"messages"` + // Optional metadata to pass through to the LLM provider. The format of this + // metadata is provider-specific. + Metadata any `json:"metadata,omitempty"` + // The server's preferences for which model to select. The client may ignore + // these preferences. + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + // An optional system prompt the server wants to use for sampling. The client + // may modify or omit this prompt. 
+ SystemPrompt string `json:"systemPrompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to a sampling/create_message request from the server. +// The client should inform the user before returning the sampled message, to +// allow them to inspect the response (human in the loop) and decide whether to +// allow the server to see it. +type CreateMessageResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Content Content `json:"content"` + // The name of the model that generated the message. + Model string `json:"model"` + Role Role `json:"role"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +func (*CreateMessageResult) isResult() {} +func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { + type result CreateMessageResult // avoid recursion + var wire struct { + result + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *r = CreateMessageResult(wire.result) + return nil +} + +type GetPromptParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` + // The name of the prompt or prompt template. 
+ Name string `json:"name"` +} + +func (x *GetPromptParams) isParams() {} +func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } +func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The server's response to a prompts/get request from the client. +type GetPromptResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []*PromptMessage `json:"messages"` +} + +func (*GetPromptResult) isResult() {} + +type InitializeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Capabilities *ClientCapabilities `json:"capabilities"` + ClientInfo *Implementation `json:"clientInfo"` + // The latest version of the Model Context Protocol that the client supports. + // The client may decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` +} + +func (x *InitializeParams) isParams() {} +func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } +func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// After receiving an initialize request from the client, the server sends this +// response. +type InitializeResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Capabilities *ServerCapabilities `json:"capabilities"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of available + // tools, resources, etc. It can be thought of like a "hint" to the model. 
For + // example, this information may be added to the system prompt. + Instructions string `json:"instructions,omitempty"` + // The version of the Model Context Protocol that the server wants to use. This + // may not match the version that the client requested. If the client cannot + // support this version, it must disconnect. + ProtocolVersion string `json:"protocolVersion"` + ServerInfo *Implementation `json:"serverInfo"` +} + +func (*InitializeResult) isResult() {} + +type InitializedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *InitializedParams) isParams() {} +func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type ListPromptsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListPromptsParams) isParams() {} +func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a prompts/list request from the client. +type ListPromptsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. 
+ NextCursor string `json:"nextCursor,omitempty"` + Prompts []*Prompt `json:"prompts"` +} + +func (x *ListPromptsResult) isResult() {} +func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListResourceTemplatesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListResourceTemplatesParams) isParams() {} +func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` +} + +func (x *ListResourceTemplatesResult) isResult() {} +func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListResourcesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. 
+ Cursor string `json:"cursor,omitempty"` +} + +func (x *ListResourcesParams) isParams() {} +func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a resources/list request from the client. +type ListResourcesResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Resources []*Resource `json:"resources"` +} + +func (x *ListResourcesResult) isResult() {} +func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListRootsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to a roots/list request from the server. This result +// contains an array of Root objects, each representing a root directory or file +// that the server can operate on. +type ListRootsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + Roots []*Root `json:"roots"` +} + +func (*ListRootsResult) isResult() {} + +type ListToolsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListToolsParams) isParams() {} +func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a tools/list request from the client. +type ListToolsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Tools []*Tool `json:"tools"` +} + +func (x *ListToolsResult) isResult() {} +func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } + +// The severity of a log message. +// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +type LoggingMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` + // The severity of this log message. 
+ Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` +} + +func (x *LoggingMessageParams) isParams() {} +func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } +func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client should treat this as a substring of a model name; for example: - + // `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` - `sonnet` + // should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. - + // `claude` should match any Claude model + // + // The client may also map the string to a different provider's model name or a + // different model family, as long as it fills a similar niche; for example: - + // `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +// The server's preferences for model selection, requested of the client during +// sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" model is +// rarely straightforward. Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client may ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // How much to prioritize cost when selecting a model. 
A value of 0 means cost + // is not important, while a value of 1 means cost is the most important factor. + CostPriority float64 `json:"costPriority,omitempty"` + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client must evaluate them in order (such + // that the first match is taken). + // + // The client should prioritize these hints over the numeric priorities, but may + // still use the priorities to select from ambiguous matches. + Hints []*ModelHint `json:"hints,omitempty"` + // How much to prioritize intelligence and capabilities when selecting a model. + // A value of 0 means intelligence is not important, while a value of 1 means + // intelligence is the most important factor. + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` +} + +type PingParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *PingParams) isParams() {} +func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } +func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type ProgressNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The progress token which was given in the initial request, used to associate + // this notification with the request that is proceeding. + ProgressToken any `json:"progressToken"` + // An optional message describing the current progress. + Message string `json:"message,omitempty"` + // The progress thus far. 
This should increase every time progress is made, even + // if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + // Zero means unknown. + Total float64 `json:"total,omitempty"` +} + +func (*ProgressNotificationParams) isParams() {} + +// A prompt or prompt template that the server offers. +type Prompt struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // A list of arguments to use for templating the prompt. + Arguments []*PromptArgument `json:"arguments,omitempty"` + // An optional description of what this prompt provides + Description string `json:"description,omitempty"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` +} + +// Describes an argument that a prompt can accept. +type PromptArgument struct { + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + Required bool `json:"required,omitempty"` +} + +type PromptListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` +} + +func (x *PromptListChangedParams) isParams() {} +func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Describes a message returned as part of a prompt. +// +// This is similar to SamplingMessage, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Content Content `json:"content"` + Role Role `json:"role"` +} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (m *PromptMessage) UnmarshalJSON(data []byte) error { + type msg PromptMessage // avoid recursion + var wire struct { + msg + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.msg.Content, err = contentFromWire(wire.Content, nil); err != nil { + return err + } + *m = PromptMessage(wire.msg) + return nil +} + +type ReadResourceParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to read. The URI can use any protocol; it is up to + // the server how to interpret it. + URI string `json:"uri"` +} + +func (x *ReadResourceParams) isParams() {} +func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The server's response to a resources/read request from the client. +type ReadResourceResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Contents []*ResourceContents `json:"contents"` +} + +func (*ReadResourceResult) isResult() {} + +// A known resource that the server is capable of reading. 
+type Resource struct {
+	// See [specification/2025-06-18/basic/index#general-fields] for notes on _meta
+	// usage.
+	Meta `json:"_meta,omitempty"`
+	// Optional annotations for the client.
+	Annotations *Annotations `json:"annotations,omitempty"`
+	// A description of what this resource represents.
+	//
+	// This can be used by clients to improve the LLM's understanding of available
+	// resources. It can be thought of like a "hint" to the model.
+	Description string `json:"description,omitempty"`
+	// The MIME type of this resource, if known.
+	MIMEType string `json:"mimeType,omitempty"`
+	// Intended for programmatic or logical use, but used as a display name in past
+	// specs or fallback (if title isn't present).
+	Name string `json:"name"`
+	// The size of the raw resource content, in bytes (i.e., before base64 encoding
+	// or any tokenization), if known.
+	//
+	// This can be used by Hosts to display file sizes and estimate context window
+	// usage.
+	Size int64 `json:"size,omitempty"`
+	// Intended for UI and end-user contexts — optimized to be human-readable and
+	// easily understood, even by those unfamiliar with domain-specific terminology.
+	//
+	// If not provided, the name should be used for display (except for Tool, where
+	// Annotations.Title should be given precedence over using name, if
+	// present).
+	Title string `json:"title,omitempty"`
+	// The URI of this resource.
+	URI string `json:"uri"`
+}
+
+// ResourceListChangedParams carries no data beyond the reserved _meta field;
+// it is the payload of the notifications/resources/list_changed notification.
+type ResourceListChangedParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+}
+
+// The three methods below follow the same pattern as the other *Params types
+// in this file: isParams marks the type as a protocol params payload, and the
+// progress-token accessors delegate to the shared _meta-based helpers.
+func (x *ResourceListChangedParams) isParams()               {}
+func (x *ResourceListChangedParams) GetProgressToken() any   { return getProgressToken(x) }
+func (x *ResourceListChangedParams) SetProgressToken(t any)  { setProgressToken(x, t) }
+
+// A template description for resources available on the server.
+type ResourceTemplate struct {
+	// See [specification/2025-06-18/basic/index#general-fields] for notes on _meta
+	// usage.
+	Meta `json:"_meta,omitempty"`
+	// Optional annotations for the client.
+	Annotations *Annotations `json:"annotations,omitempty"`
+	// A description of what this template is for.
+	//
+	// This can be used by clients to improve the LLM's understanding of available
+	// resources. It can be thought of like a "hint" to the model.
+	Description string `json:"description,omitempty"`
+	// The MIME type for all resources that match this template. This should only be
+	// included if all resources matching this template have the same type.
+	MIMEType string `json:"mimeType,omitempty"`
+	// Intended for programmatic or logical use, but used as a display name in past
+	// specs or fallback (if title isn't present).
+	Name string `json:"name"`
+	// Intended for UI and end-user contexts — optimized to be human-readable and
+	// easily understood, even by those unfamiliar with domain-specific terminology.
+	//
+	// If not provided, the name should be used for display (except for Tool, where
+	// Annotations.Title should be given precedence over using name, if
+	// present).
+	Title string `json:"title,omitempty"`
+	// A URI template (according to RFC 6570) that can be used to construct resource
+	// URIs.
+	URITemplate string `json:"uriTemplate"`
+}
+
+// The sender or recipient of messages and data in a conversation.
+type Role string
+
+// Represents a root directory or file that the server can operate on.
+type Root struct {
+	// See [specification/2025-06-18/basic/index#general-fields] for notes on _meta
+	// usage.
+	Meta `json:"_meta,omitempty"`
+	// An optional name for the root. This can be used to provide a human-readable
+	// identifier for the root, which may be useful for display purposes or for
+	// referencing the root in other parts of the application.
+	Name string `json:"name,omitempty"`
+	// The URI identifying the root. This *must* start with file:// for now. This
+	// restriction may be relaxed in future versions of the protocol to allow other
+	// URI schemes.
+	URI string `json:"uri"`
+}
+
+type RootsListChangedParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+}
+
+func (x *RootsListChangedParams) isParams()              {}
+func (x *RootsListChangedParams) GetProgressToken() any  { return getProgressToken(x) }
+func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) }
+
+// SamplingCapabilities describes the capabilities for sampling.
+type SamplingCapabilities struct{}
+
+// ElicitationCapabilities describes the capabilities for elicitation.
+type ElicitationCapabilities struct{}
+
+// Describes a message issued to or received from an LLM API.
+type SamplingMessage struct {
+	Content Content `json:"content"`
+	Role    Role    `json:"role"`
+}
+
+// UnmarshalJSON handles the unmarshalling of content into the Content
+// interface.
+func (m *SamplingMessage) UnmarshalJSON(data []byte) error {
+	type msg SamplingMessage // avoid recursion
+	var wire struct {
+		msg
+		Content *wireContent `json:"content"`
+	}
+	if err := json.Unmarshal(data, &wire); err != nil {
+		return err
+	}
+	var err error
+	// Sampling messages are restricted to text, image, and audio content;
+	// contrast with PromptMessage, which passes nil to allow all content kinds.
+	if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil {
+		return err
+	}
+	*m = SamplingMessage(wire.msg)
+	return nil
+}
+
+type SetLoggingLevelParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+	// The level of logging that the client wants to receive from the server. The
+	// server should send all logs at this level and higher (i.e., more severe) to
+	// the client as notifications/message.
+	Level LoggingLevel `json:"level"`
+}
+
+func (x *SetLoggingLevelParams) isParams()              {}
+func (x *SetLoggingLevelParams) GetProgressToken() any  { return getProgressToken(x) }
+func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) }
+
+// Definition for a tool the client can call.
+type Tool struct {
+	// See [specification/2025-06-18/basic/index#general-fields] for notes on _meta
+	// usage.
+	Meta `json:"_meta,omitempty"`
+	// Optional additional tool information.
+	//
+	// Display name precedence order is: title, annotations.title, then name.
+	Annotations *ToolAnnotations `json:"annotations,omitempty"`
+	// A human-readable description of the tool.
+	//
+	// This can be used by clients to improve the LLM's understanding of available
+	// tools. It can be thought of like a "hint" to the model.
+	Description string `json:"description,omitempty"`
+	// InputSchema holds a JSON Schema object defining the expected parameters
+	// for the tool.
+	//
+	// From the server, this field may be set to any value that JSON-marshals to
+	// valid JSON schema (including json.RawMessage). However, for tools added
+	// using [AddTool], which automatically validates inputs and outputs, the
+	// schema must be in a draft the SDK understands. Currently, the SDK uses
+	// github.com/google/jsonschema-go for inference and validation, which only
+	// supports the 2020-12 draft of JSON schema. To do your own validation, use
+	// [Server.AddTool].
+	//
+	// From the client, this field will hold the default JSON marshaling of the
+	// server's input schema (a map[string]any).
+	InputSchema any `json:"inputSchema"`
+	// Intended for programmatic or logical use, but used as a display name in past
+	// specs or fallback (if title isn't present).
+	Name string `json:"name"`
+	// OutputSchema holds an optional JSON Schema object defining the structure
+	// of the tool's output returned in the StructuredContent field of a
+	// CallToolResult.
+	//
+	// From the server, this field may be set to any value that JSON-marshals to
+	// valid JSON schema (including json.RawMessage). However, for tools added
+	// using [AddTool], which automatically validates inputs and outputs, the
+	// schema must be in a draft the SDK understands. Currently, the SDK uses
+	// github.com/google/jsonschema-go for inference and validation, which only
+	// supports the 2020-12 draft of JSON schema. To do your own validation, use
+	// [Server.AddTool].
+	//
+	// From the client, this field will hold the default JSON marshaling of the
+	// server's output schema (a map[string]any).
+	OutputSchema any `json:"outputSchema,omitempty"`
+	// Intended for UI and end-user contexts — optimized to be human-readable and
+	// easily understood, even by those unfamiliar with domain-specific terminology.
+	// If not provided, Annotations.Title should be used for display if present,
+	// otherwise Name.
+	Title string `json:"title,omitempty"`
+}
+
+// Additional properties describing a Tool to clients.
+//
+// NOTE: all properties in ToolAnnotations are hints. They are not
+// guaranteed to provide a faithful description of tool behavior (including
+// descriptive properties like title).
+//
+// Clients should never make tool use decisions based on ToolAnnotations
+// received from untrusted servers.
+type ToolAnnotations struct {
+	// If true, the tool may perform destructive updates to its environment. If
+	// false, the tool performs only additive updates.
+	//
+	// (This property is meaningful only when ReadOnlyHint == false.)
+	//
+	// Default: true
+	DestructiveHint *bool `json:"destructiveHint,omitempty"`
+	// If true, calling the tool repeatedly with the same arguments will have no
+	// additional effect on its environment.
+	//
+	// (This property is meaningful only when ReadOnlyHint == false.)
+	//
+	// Default: false
+	IdempotentHint bool `json:"idempotentHint,omitempty"`
+	// If true, this tool may interact with an "open world" of external entities. If
+	// false, the tool's domain of interaction is closed. For example, the world of
+	// a web search tool is open, whereas that of a memory tool is not.
+	//
+	// Default: true
+	OpenWorldHint *bool `json:"openWorldHint,omitempty"`
+	// If true, the tool does not modify its environment.
+	//
+	// Default: false
+	ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
+	// A human-readable title for the tool.
+	Title string `json:"title,omitempty"`
+}
+
+type ToolListChangedParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+}
+
+func (x *ToolListChangedParams) isParams()              {}
+func (x *ToolListChangedParams) GetProgressToken() any  { return getProgressToken(x) }
+func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) }
+
+// Sent from the client to request resources/updated notifications from the
+// server whenever a particular resource changes.
+type SubscribeParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+	// The URI of the resource to subscribe to.
+	URI string `json:"uri"`
+}
+
+func (*SubscribeParams) isParams() {}
+
+// Sent from the client to request cancellation of resources/updated
+// notifications from the server. This should follow a previous
+// resources/subscribe request.
+type UnsubscribeParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+	// The URI of the resource to unsubscribe from.
+	URI string `json:"uri"`
+}
+
+func (*UnsubscribeParams) isParams() {}
+
+// A notification from the server to the client, informing it that a resource
+// has changed and may need to be read again. This should only be sent if the
+// client previously sent a resources/subscribe request.
+type ResourceUpdatedNotificationParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+	// The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.
+	URI string `json:"uri"`
+}
+
+func (*ResourceUpdatedNotificationParams) isParams() {}
+
+// TODO(jba): add CompleteRequest and related types.
+
+// A request from the server to elicit additional information from the user via the client.
+type ElicitParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+	// The message to present to the user.
+	Message string `json:"message"`
+	// A JSON schema object defining the requested elicitation schema.
+	//
+	// From the server, this field may be set to any value that can JSON-marshal
+	// to valid JSON schema (including json.RawMessage for raw schema values).
+	// Internally, the SDK uses github.com/google/jsonschema-go for validation,
+	// which only supports the 2020-12 draft of the JSON schema spec.
+	//
+	// From the client, this field will use the default JSON marshaling (a
+	// map[string]any).
+	//
+	// Only top-level properties are allowed, without nesting.
+	RequestedSchema any `json:"requestedSchema"`
+}
+
+func (x *ElicitParams) isParams() {}
+
+func (x *ElicitParams) GetProgressToken() any  { return getProgressToken(x) }
+func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) }
+
+// The client's response to an elicitation/create request from the server.
+type ElicitResult struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+	Meta `json:"_meta,omitempty"`
+	// The user action in response to the elicitation.
+	// - "accept": User submitted the form/confirmed the action
+	// - "decline": User explicitly declined the action
+	// - "cancel": User dismissed without making an explicit choice
+	Action string `json:"action"`
+	// The submitted form data, only present when action is "accept".
+	// Contains values matching the requested schema.
+	Content map[string]any `json:"content,omitempty"`
+}
+
+func (*ElicitResult) isResult() {}
+
+// An Implementation describes the name and version of an MCP implementation, with an optional
+// title for UI representation.
+type Implementation struct {
+	// Intended for programmatic or logical use, but used as a display name in past
+	// specs or fallback (if title isn't present).
+	Name string `json:"name"`
+	// Intended for UI and end-user contexts — optimized to be human-readable and
+	// easily understood, even by those unfamiliar with domain-specific terminology.
+	Title string `json:"title,omitempty"`
+	// Version of the implementation. Always serialized (no omitempty).
+	Version string `json:"version"`
+}
+
+// Present if the server supports argument autocompletion suggestions.
+type CompletionCapabilities struct{}
+
+// Present if the server supports sending log messages to the client.
+type LoggingCapabilities struct{}
+
+// Present if the server offers any prompt templates.
+type PromptCapabilities struct {
+	// Whether this server supports notifications for changes to the prompt list.
+	ListChanged bool `json:"listChanged,omitempty"`
+}
+
+// Present if the server offers any resources to read.
+type ResourceCapabilities struct {
+	// Whether this server supports notifications for changes to the resource list.
+	ListChanged bool `json:"listChanged,omitempty"`
+	// Whether this server supports subscribing to resource updates.
+	Subscribe bool `json:"subscribe,omitempty"`
+}
+
+// Capabilities that a server may support. Known capabilities are defined here,
+// in this schema, but this is not a closed set: any server can define its own,
+// additional capabilities.
+type ServerCapabilities struct {
+	// Present if the server supports argument autocompletion suggestions.
+	Completions *CompletionCapabilities `json:"completions,omitempty"`
+	// Experimental, non-standard capabilities that the server supports.
+	Experimental map[string]any `json:"experimental,omitempty"`
+	// Present if the server supports sending log messages to the client.
+	Logging *LoggingCapabilities `json:"logging,omitempty"`
+	// Present if the server offers any prompt templates.
+	Prompts *PromptCapabilities `json:"prompts,omitempty"`
+	// Present if the server offers any resources to read.
+	Resources *ResourceCapabilities `json:"resources,omitempty"`
+	// Present if the server offers any tools to call.
+	Tools *ToolCapabilities `json:"tools,omitempty"`
+}
+
+// Present if the server offers any tools to call.
+type ToolCapabilities struct {
+	// Whether this server supports notifications for changes to the tool list.
+	ListChanged bool `json:"listChanged,omitempty"`
+}
+
+// JSON-RPC method and notification names defined by the MCP protocol.
+const (
+	methodCallTool                  = "tools/call"
+	notificationCancelled           = "notifications/cancelled"
+	methodComplete                  = "completion/complete"
+	methodCreateMessage             = "sampling/createMessage"
+	methodElicit                    = "elicitation/create"
+	methodGetPrompt                 = "prompts/get"
+	methodInitialize                = "initialize"
+	notificationInitialized         = "notifications/initialized"
+	methodListPrompts               = "prompts/list"
+	methodListResourceTemplates     = "resources/templates/list"
+	methodListResources             = "resources/list"
+	methodListRoots                 = "roots/list"
+	methodListTools                 = "tools/list"
+	notificationLoggingMessage      = "notifications/message"
+	methodPing                      = "ping"
+	notificationProgress            = "notifications/progress"
+	notificationPromptListChanged   = "notifications/prompts/list_changed"
+	methodReadResource              = "resources/read"
+	notificationResourceListChanged = "notifications/resources/list_changed"
+	notificationResourceUpdated     = "notifications/resources/updated"
+	notificationRootsListChanged    = "notifications/roots/list_changed"
+	methodSetLevel                  = "logging/setLevel"
+	methodSubscribe                 = "resources/subscribe"
+	notificationToolListChanged     = "notifications/tools/list_changed"
+)
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go
new file mode 100644
index 0000000000..82b700f564
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go
@@ -0,0 +1,37 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// This file holds the request types.
+
+package mcp
+
+// Requests received by the server (sent by the client), each an alias of the
+// generic ServerRequest instantiated with the matching params type.
+type (
+	CallToolRequest                   = ServerRequest[*CallToolParamsRaw]
+	CompleteRequest                   = ServerRequest[*CompleteParams]
+	GetPromptRequest                  = ServerRequest[*GetPromptParams]
+	InitializedRequest                = ServerRequest[*InitializedParams]
+	ListPromptsRequest                = ServerRequest[*ListPromptsParams]
+	ListResourcesRequest              = ServerRequest[*ListResourcesParams]
+	ListResourceTemplatesRequest      = ServerRequest[*ListResourceTemplatesParams]
+	ListToolsRequest                  = ServerRequest[*ListToolsParams]
+	ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams]
+	ReadResourceRequest               = ServerRequest[*ReadResourceParams]
+	RootsListChangedRequest           = ServerRequest[*RootsListChangedParams]
+	SubscribeRequest                  = ServerRequest[*SubscribeParams]
+	UnsubscribeRequest                = ServerRequest[*UnsubscribeParams]
+)
+
+// Requests received by the client (sent by the server).
+// NOTE(review): initializedClientRequest is deliberately unexported —
+// presumably internal-only dispatch; confirm against upstream before exporting.
+type (
+	CreateMessageRequest               = ClientRequest[*CreateMessageParams]
+	ElicitRequest                      = ClientRequest[*ElicitParams]
+	initializedClientRequest           = ClientRequest[*InitializedParams]
+	InitializeRequest                  = ClientRequest[*InitializeParams]
+	ListRootsRequest                   = ClientRequest[*ListRootsParams]
+	LoggingMessageRequest              = ClientRequest[*LoggingMessageParams]
+	ProgressNotificationClientRequest  = ClientRequest[*ProgressNotificationParams]
+	PromptListChangedRequest           = ClientRequest[*PromptListChangedParams]
+	ResourceListChangedRequest         = ClientRequest[*ResourceListChangedParams]
+	ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams]
+	ToolListChangedRequest             = ClientRequest[*ToolListChangedParams]
+)
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go
new file mode 100644
index 0000000000..8746edaedf
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go
@@ -0,0 +1,164 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/url"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
+	"github.com/modelcontextprotocol/go-sdk/internal/util"
+	"github.com/yosida95/uritemplate/v3"
+)
+
+// A serverResource associates a Resource with its handler.
+type serverResource struct {
+	resource *Resource
+	handler  ResourceHandler
+}
+
+// A serverResourceTemplate associates a ResourceTemplate with its handler.
+type serverResourceTemplate struct {
+	resourceTemplate *ResourceTemplate
+	handler          ResourceHandler
+}
+
+// A ResourceHandler is a function that reads a resource.
+// It will be called when the client calls [ClientSession.ReadResource].
+// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError].
+type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceResult, error)
+
+// ResourceNotFoundError returns an error indicating that a resource being read could
+// not be found. The URI is carried in the error's JSON data payload.
+func ResourceNotFoundError(uri string) error {
+	return &jsonrpc2.WireError{
+		Code:    codeResourceNotFound,
+		Message: "Resource not found",
+		Data:    json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)),
+	}
+}
+
+// readFileResource reads from the filesystem at a URI relative to dirFilepath, respecting
+// the roots.
+// dirFilepath and rootFilepaths are absolute filesystem paths.
+func readFileResource(rawURI, dirFilepath string, rootFilepaths []string) ([]byte, error) {
+	uriFilepath, err := computeURIFilepath(rawURI, dirFilepath, rootFilepaths)
+	if err != nil {
+		return nil, err
+	}
+
+	var data []byte
+	err = withFile(dirFilepath, uriFilepath, func(f *os.File) error {
+		var err error
+		data, err = io.ReadAll(f)
+		return err
+	})
+	// Translate a local "not found" into the protocol-level error so clients
+	// see a codeResourceNotFound JSON-RPC error rather than an OS error.
+	// NOTE(review): os.IsNotExist does not unwrap wrapped errors; if withFile
+	// ever wraps, errors.Is(err, fs.ErrNotExist) would be more robust — confirm.
+	if os.IsNotExist(err) {
+		err = ResourceNotFoundError(rawURI)
+	}
+	return data, err
+}
+
+// computeURIFilepath returns a path relative to dirFilepath.
+// The dirFilepath and rootFilepaths are absolute file paths.
+//
+// It validates that rawURI is a non-empty file:// URI, that its path cannot
+// escape dirFilepath, and (when roots are configured) that the resolved
+// absolute path lies under at least one root.
+func computeURIFilepath(rawURI, dirFilepath string, rootFilepaths []string) (string, error) {
+	// We use "file path" to mean a filesystem path.
+	uri, err := url.Parse(rawURI)
+	if err != nil {
+		return "", err
+	}
+	if uri.Scheme != "file" {
+		return "", fmt.Errorf("URI is not a file: %s", uri)
+	}
+	if uri.Path == "" {
+		// A more specific error than the one below, to catch the
+		// common mistake "file://foo".
+		return "", errors.New("empty path")
+	}
+	// The URI's path is interpreted relative to dirFilepath, and in the local filesystem.
+	// It must not try to escape its directory.
+	// filepath.Localize rejects paths that are absolute or traverse upward,
+	// so the result is guaranteed to stay under dirFilepath when joined.
+	uriFilepathRel, err := filepath.Localize(strings.TrimPrefix(uri.Path, "/"))
+	if err != nil {
+		return "", fmt.Errorf("%q cannot be localized: %w", uriFilepathRel, err)
+	}
+
+	// Check roots, if there are any.
+	if len(rootFilepaths) > 0 {
+		// To check against the roots, we need an absolute file path, not relative to the directory.
+		// uriFilepath is local, so the joined path is under dirFilepath.
+		uriFilepathAbs := filepath.Join(dirFilepath, uriFilepathRel)
+		rootOK := false
+		// Check that the requested file path is under some root.
+		// Since both paths are absolute, that's equivalent to filepath.Rel constructing
+		// a local path.
+		for _, rootFilepathAbs := range rootFilepaths {
+			if rel, err := filepath.Rel(rootFilepathAbs, uriFilepathAbs); err == nil && filepath.IsLocal(rel) {
+				rootOK = true
+				break
+			}
+		}
+		if !rootOK {
+			return "", fmt.Errorf("URI path %q is not under any root", uriFilepathAbs)
+		}
+	}
+	return uriFilepathRel, nil
+}
+
+// fileRoots transforms the Roots obtained from the client into absolute paths on
+// the local filesystem.
+// TODO(jba): expose this functionality to user ResourceHandlers,
+// so they don't have to repeat it.
+func fileRoots(rawRoots []*Root) ([]string, error) {
+	var fileRoots []string
+	for _, r := range rawRoots {
+		fr, err := fileRoot(r)
+		if err != nil {
+			return nil, err
+		}
+		fileRoots = append(fileRoots, fr)
+	}
+	return fileRoots, nil
+}
+
+// fileRoot returns the absolute path for Root.
+// It requires root.URI to be a file:// URI with a non-empty absolute path;
+// errors are wrapped with the offending URI via the deferred util.Wrapf.
+func fileRoot(root *Root) (_ string, err error) {
+	defer util.Wrapf(&err, "root %q", root.URI)
+
+	// Convert to absolute file path.
+	rurl, err := url.Parse(root.URI)
+	if err != nil {
+		return "", err
+	}
+	if rurl.Scheme != "file" {
+		return "", errors.New("not a file URI")
+	}
+	if rurl.Path == "" {
+		// A more specific error than the one below, to catch the
+		// common mistake "file://foo".
+		return "", errors.New("empty path")
+	}
+	// We don't want Localize here: we want an absolute path, which is not local.
+	fileRoot := filepath.Clean(filepath.FromSlash(rurl.Path))
+	if !filepath.IsAbs(fileRoot) {
+		return "", errors.New("not an absolute path")
+	}
+	return fileRoot, nil
+}
+
+// Matches reports whether the receiver's uri template matches the uri.
+// A template that fails to parse matches nothing.
+// NOTE(review): the template is recompiled on every call; if this appears on a
+// hot path upstream, caching the compiled template in serverResourceTemplate
+// would avoid the repeated parse — confirm with a profile before changing.
+func (sr *serverResourceTemplate) Matches(uri string) bool {
+	tmpl, err := uritemplate.New(sr.resourceTemplate.URITemplate)
+	if err != nil {
+		return false
+	}
+	return tmpl.Regexp().MatchString(uri)
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go
new file mode 100644
index 0000000000..4a35603c66
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go
@@ -0,0 +1,29 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package mcp
+
+import (
+	"errors"
+	"os"
+)
+
+// withFile calls f on the file at join(dir, rel),
+// protecting against path traversal attacks.
+func withFile(dir, rel string, f func(*os.File) error) (err error) {
+	// os.OpenRoot (Go 1.24) confines all subsequent opens to dir, so even a
+	// crafted rel (symlinks, "..") cannot escape it.
+	r, err := os.OpenRoot(dir)
+	if err != nil {
+		return err
+	}
+	defer r.Close()
+	file, err := r.Open(rel)
+	if err != nil {
+		return err
+	}
+	// Record error, in case f writes.
+	defer func() { err = errors.Join(err, file.Close()) }()
+	return f(file)
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go
new file mode 100644
index 0000000000..d1f72eedc4
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go
@@ -0,0 +1,25 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.24
+
+package mcp
+
+import (
+	"errors"
+	"os"
+	"path/filepath"
+)
+
+// withFile calls f on the file at join(dir, rel).
+// It does not protect against path traversal attacks.
+// (Pre-1.24 fallback: callers rely on computeURIFilepath having already
+// rejected escaping paths.)
+func withFile(dir, rel string, f func(*os.File) error) (err error) {
+	file, err := os.Open(filepath.Join(dir, rel))
+	if err != nil {
+		return err
+	}
+	// Record error, in case f writes.
+	defer func() { err = errors.Join(err, file.Close()) }()
+	return f(file)
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go
new file mode 100644
index 0000000000..4a7bc89a89
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go
@@ -0,0 +1,1303 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+import (
+	"bytes"
+	"context"
+	"encoding/base64"
+	"encoding/gob"
+	"encoding/json"
+	"fmt"
+	"iter"
+	"log/slog"
+	"maps"
+	"net/url"
+	"path/filepath"
+	"reflect"
+	"slices"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/google/jsonschema-go/jsonschema"
+	"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
+	"github.com/modelcontextprotocol/go-sdk/internal/util"
+	"github.com/modelcontextprotocol/go-sdk/jsonrpc"
+	"github.com/yosida95/uritemplate/v3"
+)
+
+// DefaultPageSize is the default for [ServerOptions.PageSize].
+const DefaultPageSize = 1000
+
+// A Server is an instance of an MCP server.
+//
+// Servers expose server-side MCP features, which can serve one or more MCP
+// sessions by using [Server.Run].
+type Server struct {
+	// fixed at creation
+	impl *Implementation
+	opts ServerOptions
+
+	// mu guards the mutable state below.
+	// NOTE(review): the trailing-underscore handler fields are presumably
+	// accessed only via accessor methods while holding mu — confirm upstream.
+	mu                      sync.Mutex
+	prompts                 *featureSet[*serverPrompt]
+	tools                   *featureSet[*serverTool]
+	resources               *featureSet[*serverResource]
+	resourceTemplates       *featureSet[*serverResourceTemplate]
+	sessions                []*ServerSession
+	sendingMethodHandler_   MethodHandler
+	receivingMethodHandler_ MethodHandler
+	resourceSubscriptions   map[string]map[*ServerSession]bool // uri -> session -> bool
+}
+
+// ServerOptions is used to configure behavior of the server.
+type ServerOptions struct {
+	// Optional instructions for connected clients.
+	Instructions string
+	// If non-nil, log server activity.
+	Logger *slog.Logger
+	// If non-nil, called when "notifications/initialized" is received.
+	InitializedHandler func(context.Context, *InitializedRequest)
+	// PageSize is the maximum number of items to return in a single page for
+	// list methods (e.g. ListTools).
+	//
+	// If zero, defaults to [DefaultPageSize].
+	PageSize int
+	// If non-nil, called when "notifications/roots/list_changed" is received.
+	RootsListChangedHandler func(context.Context, *RootsListChangedRequest)
+	// If non-nil, called when "notifications/progress" is received.
+	ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest)
+	// If non-nil, called when "completion/complete" is received.
+	CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error)
+	// If non-zero, defines an interval for regular "ping" requests.
+	// If the peer fails to respond to pings originating from the keepalive check,
+	// the session is automatically closed.
+	KeepAlive time.Duration
+	// Function called when a client session subscribes to a resource.
+	// Must be set together with UnsubscribeHandler (NewServer panics otherwise).
+	SubscribeHandler func(context.Context, *SubscribeRequest) error
+	// Function called when a client session unsubscribes from a resource.
+	UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error
+	// If true, advertises the prompts capability during initialization,
+	// even if no prompts have been registered.
+	HasPrompts bool
+	// If true, advertises the resources capability during initialization,
+	// even if no resources have been registered.
+	HasResources bool
+	// If true, advertises the tools capability during initialization,
+	// even if no tools have been registered.
+	HasTools bool
+
+	// GetSessionID provides the next session ID to use for an incoming request.
+	// If nil, a default randomly generated ID will be used.
+	//
+	// Session IDs should be globally unique across the scope of the server,
+	// which may span multiple processes in the case of distributed servers.
+	//
+	// As a special case, if GetSessionID returns the empty string, the
+	// Mcp-Session-Id header will not be set.
+	GetSessionID func() string
+}
+
+// NewServer creates a new MCP server. The resulting server has no features:
+// add features using the various Server.AddXXX methods, and the [AddTool] function.
+//
+// The server can be connected to one or more MCP clients using [Server.Run].
+//
+// The first argument must not be nil.
+//
+// If non-nil, the provided options are used to configure the server.
func NewServer(impl *Implementation, options *ServerOptions) *Server {
	if impl == nil {
		panic("nil Implementation")
	}
	// Work on a copy of the options so later reads cannot observe caller-side
	// mutation of the original.
	var opts ServerOptions
	if options != nil {
		opts = *options
	}
	options = nil // prevent reuse
	if opts.PageSize < 0 {
		panic(fmt.Errorf("invalid page size %d", opts.PageSize))
	}
	if opts.PageSize == 0 {
		opts.PageSize = DefaultPageSize
	}
	// Subscribe/Unsubscribe handlers must be provided as a pair.
	if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil {
		panic("SubscribeHandler requires UnsubscribeHandler")
	}
	if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
		panic("UnsubscribeHandler requires SubscribeHandler")
	}

	if opts.GetSessionID == nil {
		opts.GetSessionID = randText
	}

	if opts.Logger == nil { // ensure we have a logger
		opts.Logger = ensureLogger(nil)
	}

	return &Server{
		impl: impl,
		opts: opts,
		// Feature sets are keyed by the feature's identifying attribute:
		// name for prompts and tools, URI (template) for resources.
		prompts:                 newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }),
		tools:                   newFeatureSet(func(t *serverTool) string { return t.tool.Name }),
		resources:               newFeatureSet(func(r *serverResource) string { return r.resource.URI }),
		resourceTemplates:       newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }),
		sendingMethodHandler_:   defaultSendingMethodHandler[*ServerSession],
		receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
		resourceSubscriptions:   make(map[string]map[*ServerSession]bool),
	}
}

// AddPrompt adds a [Prompt] to the server, or replaces one with the same name.
func (s *Server) AddPrompt(p *Prompt, h PromptHandler) {
	// Assume there was a change, since add replaces existing items.
	// (It's possible an item was replaced with an identical one, but not worth checking.)
	s.changeAndNotify(
		notificationPromptListChanged,
		&PromptListChangedParams{},
		func() bool { s.prompts.add(&serverPrompt{p, h}); return true })
}

// RemovePrompts removes the prompts with the given names.
// It is not an error to remove a nonexistent prompt.
+func (s *Server) RemovePrompts(names ...string) { + s.changeAndNotify(notificationPromptListChanged, &PromptListChangedParams{}, + func() bool { return s.prompts.remove(names...) }) +} + +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// The Tool argument must not be modified after this call. +// +// The tool's input schema must be non-nil and have the type "object". For a tool +// that takes no input, or one where any input is valid, set [Tool.InputSchema] to +// `{"type": "object"}`, using your preferred library or `json.RawMessage`. +// +// If present, [Tool.OutputSchema] must also have type "object". +// +// When the handler is invoked as part of a CallTool request, req.Params.Arguments +// will be a json.RawMessage. +// +// Unmarshaling the arguments and validating them against the input schema are the +// caller's responsibility. +// +// Validating the result against the output schema, if any, is the caller's responsibility. +// +// Setting the result's Content, StructuredContent and IsError fields are the caller's +// responsibility. +// +// Most users should use the top-level function [AddTool], which handles all these +// responsibilities. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + if t.InputSchema == nil { + // This prevents the tool author from forgetting to write a schema where + // one should be provided. If we papered over this by supplying the empty + // schema, then every input would be validated and the problem wouldn't be + // discovered until runtime, when the LLM sent bad data. 
+ panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) + } + if s, ok := t.InputSchema.(*jsonschema.Schema); ok { + if s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) + } + } else { + var m map[string]any + if err := remarshal(t.InputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal input schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object" (got %v)`, t.Name, typ)) + } + } + if t.OutputSchema != nil { + if s, ok := t.OutputSchema.(*jsonschema.Schema); ok { + if s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + } + } else { + var m map[string]any + if err := remarshal(t.OutputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal output schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object" (got %v)`, t.Name, typ)) + } + } + } + st := &serverTool{tool: t, handler: h} + // Assume there was a change, since add replaces existing tools. + // (It's possible a tool was replaced with an identical one, but not worth checking.) + // TODO: Batch these changes by size and time? The typescript SDK doesn't. + // TODO: Surface notify error here? best not, in case we need to batch. + s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, + func() bool { s.tools.add(st); return true }) +} + +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { + tt := *t + + // Special handling for an "any" input: treat as an empty object. 
+ if reflect.TypeFor[In]() == reflect.TypeFor[any]() && t.InputSchema == nil { + tt.InputSchema = &jsonschema.Schema{Type: "object"} + } + + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + elemZero any // only non-nil if Out is a pointer type + outputResolved *jsonschema.Resolved + ) + if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) + if err != nil { + return nil, nil, fmt.Errorf("output schema: %v", err) + } + } + + th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + var input json.RawMessage + if req.Params.Arguments != nil { + input = req.Params.Arguments + } + // Validate input and apply defaults. + var err error + input, err = applySchema(input, inputResolved) + if err != nil { + // TODO(#450): should this be considered a tool error? (and similar below) + return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err) + } + + // Unmarshal and validate args. + var in In + if input != nil { + if err := json.Unmarshal(input, &in); err != nil { + return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err) + } + } + + // Call typed handler. 
+ res, out, err := h(ctx, req, in) + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors + if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc2.WireError); ok { + return nil, wireErr + } + // For regular errors, embed them in the tool result as per MCP spec + var errRes CallToolResult + errRes.setError(err) + return &errRes, nil + } + + if res == nil { + res = &CallToolResult{} + } + + // Marshal the output and put the RawMessage in the StructuredContent field. + var outval any = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the unpointered type. + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + outval = elemZero + } + } + if outval != nil { + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.StructuredContent = outJSON // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. 
+ if res.Content == nil { + res.Content = []Content{&TextContent{ + Text: string(outJSON), + }} + } + } + return res, nil + } // end of handler + + return &tt, th, nil +} + +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) { + var internalSchema *jsonschema.Schema + if *sfield == nil { + rt := reflect.TypeFor[T]() + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + if err == nil { + *sfield = internalSchema + } + } else if err := remarshal(*sfield, &internalSchema); err != nil { + return zero, err + } + if err != nil { + return zero, err + } + *rfield, err = internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} + +// AddTool adds a tool and typed tool handler to the server. +// +// If the tool's input schema is nil, it is set to the schema inferred from the +// In type parameter. Types are inferred from Go types, and property +// descriptions are read from the 'jsonschema' struct tag. Internally, the SDK +// uses the github.com/google/jsonschema-go package for inference and +// validation. 
The In type argument must be a map or a struct, so that its +// inferred JSON Schema has type "object", as required by the spec. As a +// special case, if the In type is 'any', the tool's input schema is set to an +// empty object schema value. +// +// If the tool's output schema is nil, and the Out type is not 'any', the +// output schema is set to the schema inferred from the Out type argument, +// which must also be a map or struct. If the Out type is 'any', the output +// schema is omitted. +// +// Unlike [Server.AddTool], AddTool does a lot automatically, and forces +// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed +// description of this automatic behavior. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + tt, hh, err := toolForErr(t, h) + if err != nil { + panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) + } + s.AddTool(tt, hh) +} + +// RemoveTools removes the tools with the given names. +// It is not an error to remove a nonexistent tool. +func (s *Server) RemoveTools(names ...string) { + s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, + func() bool { return s.tools.remove(names...) }) +} + +// AddResource adds a [Resource] to the server, or replaces one with the same URI. +// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResource(r *Resource, h ResourceHandler) { + s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, + func() bool { + if _, err := url.Parse(r.URI); err != nil { + panic(err) // url.Parse includes the URI in the error + } + s.resources.add(&serverResource{r, h}) + return true + }) +} + +// RemoveResources removes the resources with the given URIs. +// It is not an error to remove a nonexistent resource. 
func (s *Server) RemoveResources(uris ...string) {
	s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
		func() bool { return s.resources.remove(uris...) })
}

// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one
// with the same URI template.
// AddResourceTemplate panics if the URI template is syntactically invalid.
// NOTE(review): unlike the original doc claim, absoluteness (a non-empty
// scheme) is NOT verified here — only template syntax is checked. Confirm
// whether a scheme check is also required, as [Server.AddResource] documents.
func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
	s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
		func() bool {
			// Validate the URI template syntax
			_, err := uritemplate.New(t.URITemplate)
			if err != nil {
				panic(fmt.Errorf("URI template %q is invalid: %w", t.URITemplate, err))
			}
			s.resourceTemplates.add(&serverResourceTemplate{t, h})
			return true
		})
}

// RemoveResourceTemplates removes the resource templates with the given URI templates.
// It is not an error to remove a nonexistent resource template.
func (s *Server) RemoveResourceTemplates(uriTemplates ...string) {
	s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
		func() bool { return s.resourceTemplates.remove(uriTemplates...) })
}

// capabilities computes the capabilities to advertise during initialization,
// based on the currently registered features, the Has* option overrides, and
// the configured subscribe/completion handlers.
func (s *Server) capabilities() *ServerCapabilities {
	s.mu.Lock()
	defer s.mu.Unlock()

	caps := &ServerCapabilities{
		Logging: &LoggingCapabilities{},
	}
	if s.opts.HasTools || s.tools.len() > 0 {
		caps.Tools = &ToolCapabilities{ListChanged: true}
	}
	if s.opts.HasPrompts || s.prompts.len() > 0 {
		caps.Prompts = &PromptCapabilities{ListChanged: true}
	}
	if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 {
		caps.Resources = &ResourceCapabilities{ListChanged: true}
		if s.opts.SubscribeHandler != nil {
			caps.Resources.Subscribe = true
		}
	}
	if s.opts.CompletionHandler != nil {
		caps.Completions = &CompletionCapabilities{}
	}
	return caps
}

// complete handles the "completion/complete" request by delegating to the
// configured CompletionHandler, or failing with method-not-found if none is set.
func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteResult, error) {
	if s.opts.CompletionHandler == nil {
		return nil, jsonrpc2.ErrMethodNotFound
	}
	return s.opts.CompletionHandler(ctx, req)
}

// changeAndNotify is called when a feature is added or removed.
// It calls change, which should do the work and report whether a change actually occurred.
// If there was a change, it notifies a snapshot of the sessions.
func (s *Server) changeAndNotify(notification string, params Params, change func() bool) {
	var sessions []*ServerSession
	// Lock for the change, but not for the notification.
	s.mu.Lock()
	if change() {
		sessions = slices.Clone(s.sessions)
	}
	s.mu.Unlock()
	notifySessions(sessions, notification, params)
}

// Sessions returns an iterator that yields the current set of server sessions.
//
// There is no guarantee that the iterator observes sessions that are added or
// removed during iteration.
+func (s *Server) Sessions() iter.Seq[*ServerSession] { + s.mu.Lock() + clients := slices.Clone(s.sessions) + s.mu.Unlock() + return slices.Values(clients) +} + +func (s *Server) listPrompts(_ context.Context, req *ListPromptsRequest) (*ListPromptsResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListPromptsParams{} + } + return paginateList(s.prompts, s.opts.PageSize, req.Params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { + res.Prompts = []*Prompt{} // avoid JSON null + for _, p := range prompts { + res.Prompts = append(res.Prompts, p.prompt) + } + }) +} + +func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + s.mu.Lock() + prompt, ok := s.prompts.get(req.Params.Name) + s.mu.Unlock() + if !ok { + // Return a proper JSON-RPC error with the correct error code + return nil, &jsonrpc2.WireError{ + Code: codeInvalidParams, + Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), + } + } + return prompt.handler(ctx, req) +} + +func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListToolsParams{} + } + return paginateList(s.tools, s.opts.PageSize, req.Params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { + res.Tools = []*Tool{} // avoid JSON null + for _, t := range tools { + res.Tools = append(res.Tools, t.tool) + } + }) +} + +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + s.mu.Lock() + st, ok := s.tools.get(req.Params.Name) + s.mu.Unlock() + if !ok { + return nil, &jsonrpc2.WireError{ + Code: codeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } + } + res, err := st.handler(ctx, req) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } + 
return res, err +} + +func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListResourcesParams{} + } + return paginateList(s.resources, s.opts.PageSize, req.Params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { + res.Resources = []*Resource{} // avoid JSON null + for _, r := range resources { + res.Resources = append(res.Resources, r.resource) + } + }) +} + +func (s *Server) listResourceTemplates(_ context.Context, req *ListResourceTemplatesRequest) (*ListResourceTemplatesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListResourceTemplatesParams{} + } + return paginateList(s.resourceTemplates, s.opts.PageSize, req.Params, &ListResourceTemplatesResult{}, + func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { + res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null + for _, rt := range rts { + res.ResourceTemplates = append(res.ResourceTemplates, rt.resourceTemplate) + } + }) +} + +func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + uri := req.Params.URI + // Look up the resource URI in the lists of resources and resource templates. + // This is a security check as well as an information lookup. + handler, mimeType, ok := s.lookupResourceHandler(uri) + if !ok { + // Don't expose the server configuration to the client. + // Treat an unregistered resource the same as a registered one that couldn't be found. + return nil, ResourceNotFoundError(uri) + } + res, err := handler(ctx, req) + if err != nil { + return nil, err + } + if res == nil || res.Contents == nil { + return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) + } + // As a convenience, populate some fields. 
+ for _, c := range res.Contents { + if c.URI == "" { + c.URI = uri + } + if c.MIMEType == "" { + c.MIMEType = mimeType + } + } + return res, nil +} + +// lookupResourceHandler returns the resource handler and MIME type for the resource or +// resource template matching uri. If none, the last return value is false. +func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + // Try resources first. + if r, ok := s.resources.get(uri); ok { + return r.handler, r.resource.MIMEType, true + } + // Look for matching template. + for rt := range s.resourceTemplates.all() { + if rt.Matches(uri) { + return rt.handler, rt.resourceTemplate.MIMEType, true + } + } + return nil, "", false +} + +// fileResourceHandler returns a ReadResourceHandler that reads paths using dir as +// a base directory. +// It honors client roots and protects against path traversal attacks. +// +// The dir argument should be a filesystem path. It need not be absolute, but +// that is recommended to avoid a dependency on the current working directory (the +// check against client roots is done with an absolute path). If dir is not absolute +// and the current working directory is unavailable, fileResourceHandler panics. +// +// Lexical path traversal attacks, where the path has ".." elements that escape dir, +// are always caught. Go 1.24 and above also protects against symlink-based attacks, +// where symlinks under dir lead out of the tree. +func fileResourceHandler(dir string) ResourceHandler { + // Convert dir to an absolute path. + dirFilepath, err := filepath.Abs(dir) + if err != nil { + panic(err) + } + return func(ctx context.Context, req *ReadResourceRequest) (_ *ReadResourceResult, err error) { + defer util.Wrapf(&err, "reading resource %s", req.Params.URI) + + // TODO(#25): use a memoizing API here. 
+ rootRes, err := req.Session.ListRoots(ctx, nil) + if err != nil { + return nil, fmt.Errorf("listing roots: %w", err) + } + roots, err := fileRoots(rootRes.Roots) + if err != nil { + return nil, err + } + data, err := readFileResource(req.Params.URI, dirFilepath, roots) + if err != nil { + return nil, err + } + // TODO(jba): figure out mime type. Omit for now: Server.readResource will fill it in. + return &ReadResourceResult{Contents: []*ResourceContents{ + {URI: req.Params.URI, Blob: data}, + }}, nil + } +} + +// ResourceUpdated sends a notification to all clients that have subscribed to the +// resource specified in params. This method is the primary way for a +// server author to signal that a resource has changed. +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + subscribedSessions := s.resourceSubscriptions[params.URI] + sessions := slices.Collect(maps.Keys(subscribedSessions)) + s.mu.Unlock() + notifySessions(sessions, notificationResourceUpdated, params) + s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) + return nil +} + +func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.resourceSubscriptions[req.Params.URI] == nil { + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) + } + s.resourceSubscriptions[req.Params.URI][req.Session] = true + s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emptyResult, error) { + if 
s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if subscribedSessions, ok := s.resourceSubscriptions[req.Params.URI]; ok { + delete(subscribedSessions, req.Session) + if len(subscribedSessions) == 0 { + delete(s.resourceSubscriptions, req.Params.URI) + } + } + s.opts.Logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + + return &emptyResult{}, nil +} + +// Run runs the server over the given transport, which must be persistent. +// +// Run blocks until the client terminates the connection or the provided +// context is cancelled. If the context is cancelled, Run closes the connection. +// +// If tools have been added to the server before this call, then the server will +// advertise the capability for tools, including the ability to send list-changed notifications. +// If no tools have been added, the server will not have the tool capability. +// The same goes for other features like prompts and resources. +// +// Run is a convenience for servers that handle a single session (or one session at a time). +// It need not be called on servers that are used for multiple concurrent connections, +// as with [StreamableHTTPHandler]. 
func (s *Server) Run(ctx context.Context, t Transport) error {
	s.opts.Logger.Info("server run start")
	ss, err := s.Connect(ctx, t, nil)
	if err != nil {
		s.opts.Logger.Error("server connect failed", "error", err)
		return err
	}

	// Wait for session termination in the background so we can also react to
	// context cancellation.
	ssClosed := make(chan error)
	go func() {
		ssClosed <- ss.Wait()
	}()

	select {
	case <-ctx.Done():
		// NOTE(review): the Close error is discarded here; ctx.Err() is
		// reported instead.
		ss.Close()
		<-ssClosed // wait until waiting go routine above actually completes
		s.opts.Logger.Error("server run cancelled", "error", ctx.Err())
		return ctx.Err()
	case err := <-ssClosed:
		if err != nil {
			s.opts.Logger.Error("server session ended with error", "error", err)
		} else {
			s.opts.Logger.Info("server session ended")
		}
		return err
	}
}

// bind implements the binder[*ServerSession] interface, so that Servers can
// be connected using [connect].
func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession {
	assert(mcpConn != nil && conn != nil, "nil connection")
	ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose}
	if state != nil {
		ss.state = *state
	}
	s.mu.Lock()
	s.sessions = append(s.sessions, ss)
	s.mu.Unlock()
	s.opts.Logger.Info("server session connected", "session_id", ss.ID())
	return ss
}

// disconnect implements the binder[*ServerSession] interface, so that
// Servers can be connected using [connect].
func (s *Server) disconnect(cc *ServerSession) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool {
		return cc2 == cc
	})

	// Drop the departing session from every resource subscription so it can
	// no longer receive resource-updated notifications.
	for _, subscribedSessions := range s.resourceSubscriptions {
		delete(subscribedSessions, cc)
	}
	s.opts.Logger.Info("server session disconnected", "session_id", cc.ID())
}

// ServerSessionOptions configures the server session.
+type ServerSessionOptions struct { + State *ServerSessionState + + onClose func() // used to clean up associated resources +} + +// Connect connects the MCP server over the given transport and starts handling +// messages. +// +// It returns a connection object that may be used to terminate the connection +// (with [Connection.Close]), or await client termination (with +// [Connection.Wait]). +// +// If opts.State is non-nil, it is the initial state for the server. +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + var state *ServerSessionState + var onClose func() + if opts != nil { + state = opts.State + onClose = opts.onClose + } + + s.opts.Logger.Info("server connecting") + ss, err := connect(ctx, t, s, state, onClose) + if err != nil { + s.opts.Logger.Error("server connect error", "error", err) + return nil, err + } + return ss, nil +} + +// TODO: (nit) move all ServerSession methods below the ServerSession declaration. +func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if params == nil { + // Since we use nilness to signal 'initialized' state, we must ensure that + // params are non-nil. 
+ params = new(InitializedParams) + } + var wasInit, wasInitd bool + ss.updateState(func(state *ServerSessionState) { + wasInit = state.InitializeParams != nil + wasInitd = state.InitializedParams != nil + if wasInit && !wasInitd { + state.InitializedParams = params + } + }) + + if !wasInit { + ss.server.opts.Logger.Error("initialized before initialize") + return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) + } + if wasInitd { + ss.server.opts.Logger.Error("duplicate initialized notification") + return nil, fmt.Errorf("duplicate %q received", notificationInitialized) + } + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } + if h := ss.server.opts.InitializedHandler; h != nil { + h(ctx, serverRequestFor(ss, params)) + } + ss.server.opts.Logger.Info("session initialized") + return nil, nil +} + +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsListChangedRequest) (Result, error) { + if h := s.opts.RootsListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { + if h := ss.server.opts.ProgressNotificationHandler; h != nil { + h(ctx, serverRequestFor(ss, p)) + } + return nil, nil +} + +// NotifyProgress sends a progress notification from the server to the client +// associated with this session. +// This is typically used to report on the status of a long-running request +// that was initiated by the client. +func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { + return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) +} + +func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { + return &ServerRequest[P]{Session: ss, Params: params} +} + +// A ServerSession is a logical connection from a single MCP client. 
Its +// methods can be used to send requests or notifications to the client. Create +// a session by calling [Server.Connect]. +// +// Call [ServerSession.Close] to close the connection, or await client +// termination with [ServerSession.Wait]. +type ServerSession struct { + // Ensure that onClose is called at most once. + // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the + // onClose callback triggers a re-entrant call to Close. + calledOnClose atomic.Bool + onClose func() + + server *Server + conn *jsonrpc2.Connection + mcpConn Connection + keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + + mu sync.Mutex + state ServerSessionState +} + +func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { + ss.mu.Lock() + mut(&ss.state) + copy := ss.state + ss.mu.Unlock() + if c, ok := ss.mcpConn.(serverConnection); ok { + c.sessionUpdated(copy) + } +} + +// hasInitialized reports whether the server has received the initialized +// notification. +// +// TODO(findleyr): use this to prevent change notifications. +func (ss *ServerSession) hasInitialized() bool { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializedParams != nil +} + +// checkInitialized returns a formatted error if the server has not yet +// received the initialized notification. +func (ss *ServerSession) checkInitialized(method string) error { + if !ss.hasInitialized() { + // TODO(rfindley): enable this check. + // Right now is is flaky, because server tests don't await the initialized notification. 
+ // Perhaps requests should simply block until they have received the initialized notification + + // if strings.HasPrefix(method, "notifications/") { + // return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized) + // } else { + // return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized) + // } + } + return nil +} + +func (ss *ServerSession) ID() string { + if c, ok := ss.mcpConn.(hasSessionID); ok { + return c.SessionID() + } + return "" +} + +// Ping pings the client. +func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { + _, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params))) + return err +} + +// ListRoots lists the client roots. +func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { + if err := ss.checkInitialized(methodListRoots); err != nil { + return nil, err + } + return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params))) +} + +// CreateMessage sends a sampling request to the client. +func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } + if params == nil { + params = &CreateMessageParams{Messages: []*SamplingMessage{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessage{} // avoid JSON "null" + params = &p2 + } + return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) +} + +// Elicit sends an elicitation request to the client asking for user input. 
+func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) { + if err := ss.checkInitialized(methodElicit); err != nil { + return nil, err + } + return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) +} + +// Log sends a log message to the client. +// The message is not sent if the client has not called SetLevel, or if its level +// is below that of the last SetLevel. +func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { + ss.mu.Lock() + logLevel := ss.state.LogLevel + ss.mu.Unlock() + if logLevel == "" { + // The spec is unclear, but seems to imply that no log messages are sent until the client + // sets the level. + // TODO(jba): read other SDKs, possibly file an issue. + return nil + } + if compareLevels(params.Level, logLevel) < 0 { + return nil + } + return handleNotify(ctx, notificationLoggingMessage, newServerRequest(ss, orZero[Params](params))) +} + +// AddSendingMiddleware wraps the current sending method handler using the provided +// middleware. Middleware is applied from right to left, so that the first one is +// executed first. +// +// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Sending middleware is called when a request is sent. It is useful for tasks +// such as tracing, metrics, and adding progress tokens. +func (s *Server) AddSendingMiddleware(middleware ...Middleware) { + s.mu.Lock() + defer s.mu.Unlock() + addMiddleware(&s.sendingMethodHandler_, middleware) +} + +// AddReceivingMiddleware wraps the current receiving method handler using +// the provided middleware. Middleware is applied from right to left, so that the +// first one is executed first. +// +// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Receiving middleware is called when a request is received. 
It is useful for tasks +// such as authentication, request logging and metrics. +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { + s.mu.Lock() + defer s.mu.Unlock() + addMiddleware(&s.receivingMethodHandler_, middleware) +} + +// serverMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. +var serverMethodInfos = map[string]methodInfo{ + methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), + methodInitialize: newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0), + methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), + methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), + methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), + methodListTools: newServerMethodInfo(serverMethod((*Server).listTools), missingParamsOK), + methodCallTool: newServerMethodInfo(serverMethod((*Server).callTool), 0), + methodListResources: newServerMethodInfo(serverMethod((*Server).listResources), missingParamsOK), + methodListResourceTemplates: newServerMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), + methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), + methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), + methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), + notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), + 
notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), + notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), +} + +func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } + +func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } + +func (ss *ServerSession) sendingMethodHandler() MethodHandler { + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.sendingMethodHandler_ +} + +func (ss *ServerSession) receivingMethodHandler() MethodHandler { + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.receivingMethodHandler_ +} + +// getConn implements [session.getConn]. +func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } + +// handle invokes the method described by the given JSON RPC request. +func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + ss.mu.Lock() + initialized := ss.state.InitializeParams != nil + ss.mu.Unlock() + + // From the spec: + // "The client SHOULD NOT send requests other than pings before the server + // has responded to the initialize request." 
+ switch req.Method { + case methodInitialize, methodPing, notificationInitialized: + default: + if !initialized { + ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method) + return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) + } + } + + // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and + // notifications synchronously, except for 'initialize' which shouldn't be + // asynchronous to other + if req.IsCall() && req.Method != methodInitialize { + jsonrpc2.Async(ctx) + } + + // For the streamable transport, we need the request ID to correlate + // server->client calls and notifications to the incoming request from which + // they originated. See [idContextKey] for details. + ctx = context.WithValue(ctx, idContextKey{}, req.ID) + return handleReceive(ctx, ss, req) +} + +// InitializeParams returns the InitializeParams provided during the client's +// initial connection. +func (ss *ServerSession) InitializeParams() *InitializeParams { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializeParams +} + +func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { + if params == nil { + return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) + } + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = params + }) + + s := ss.server + return &InitializeResult{ + // TODO(rfindley): alter behavior when falling back to an older version: + // reject unsupported features. + ProtocolVersion: negotiatedVersion(params.ProtocolVersion), + Capabilities: s.capabilities(), + Instructions: s.opts.Instructions, + ServerInfo: s.impl, + }, nil +} + +func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error) { + return &emptyResult{}, nil +} + +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. 
+// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + +func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelParams) (*emptyResult, error) { + ss.updateState(func(state *ServerSessionState) { + state.LogLevel = params.Level + }) + ss.server.opts.Logger.Info("client log level set", "level", params.Level) + return &emptyResult{}, nil +} + +// Close performs a graceful shutdown of the connection, preventing new +// requests from being handled, and waiting for ongoing requests to return. +// Close then terminates the connection. +// +// Close is idempotent and concurrency safe. +func (ss *ServerSession) Close() error { + if ss.keepaliveCancel != nil { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + ss.keepaliveCancel() + } + err := ss.conn.Close() + + if ss.onClose != nil && ss.calledOnClose.CompareAndSwap(false, true) { + ss.onClose() + } + + return err +} + +// Wait waits for the connection to be closed by the client. +func (ss *ServerSession) Wait() error { + return ss.conn.Wait() +} + +// startKeepalive starts the keepalive mechanism for this server session. +func (ss *ServerSession) startKeepalive(interval time.Duration) { + startKeepalive(ss, interval, &ss.keepaliveCancel) +} + +// pageToken is the internal structure for the opaque pagination cursor. 
+// It will be Gob-encoded and then Base64-encoded for use as a string token. +type pageToken struct { + LastUID string // The unique ID of the last resource seen. +} + +// encodeCursor encodes a unique identifier (UID) into a opaque pagination cursor +// by serializing a pageToken struct. +func encodeCursor(uid string) (string, error) { + var buf bytes.Buffer + token := pageToken{LastUID: uid} + encoder := gob.NewEncoder(&buf) + if err := encoder.Encode(token); err != nil { + return "", fmt.Errorf("failed to encode page token: %w", err) + } + return base64.URLEncoding.EncodeToString(buf.Bytes()), nil +} + +// decodeCursor decodes an opaque pagination cursor into the original pageToken struct. +func decodeCursor(cursor string) (*pageToken, error) { + decodedBytes, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return nil, fmt.Errorf("failed to decode cursor: %w", err) + } + + var token pageToken + buf := bytes.NewBuffer(decodedBytes) + decoder := gob.NewDecoder(buf) + if err := decoder.Decode(&token); err != nil { + return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor) + } + return &token, nil +} + +// paginateList is a generic helper that returns a paginated slice of items +// from a featureSet. It populates the provided result res with the items +// and sets its next cursor for subsequent pages. +// If there are no more pages, the next cursor within the result will be an empty string. +func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageSize int, params P, res R, setFunc func(R, []T)) (R, error) { + var seq iter.Seq[T] + if params.cursorPtr() == nil || *params.cursorPtr() == "" { + seq = fs.all() + } else { + pageToken, err := decodeCursor(*params.cursorPtr()) + // According to the spec, invalid cursors should return Invalid params. 
+ if err != nil { + var zero R + return zero, jsonrpc2.ErrInvalidParams + } + seq = fs.above(pageToken.LastUID) + } + var count int + var features []T + for f := range seq { + count++ + // If we've seen pageSize + 1 elements, we've gathered enough info to determine + // if there's a next page. Stop processing the sequence. + if count == pageSize+1 { + break + } + features = append(features, f) + } + setFunc(res, features) + // No remaining pages. + if count < pageSize+1 { + return res, nil + } + nextCursor, err := encodeCursor(fs.uniqueID(features[len(features)-1])) + if err != nil { + var zero R + return zero, err + } + *res.nextCursorPtr() = nextCursor + return res, nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go new file mode 100644 index 0000000000..dcf9888cc4 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// hasSessionID is the interface which, if implemented by connections, informs +// the session about their session ID. +// +// TODO(rfindley): remove SessionID methods from connections, when it doesn't +// make sense. Or remove it from the Sessions entirely: why does it even need +// to be exposed? +type hasSessionID interface { + SessionID() string +} + +// ServerSessionState is the state of a session. +type ServerSessionState struct { + // InitializeParams are the parameters from 'initialize'. + InitializeParams *InitializeParams `json:"initializeParams"` + + // InitializedParams are the parameters from 'notifications/initialized'. + InitializedParams *InitializedParams `json:"initializedParams"` + + // LogLevel is the logging level for the session. 
+ LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go new file mode 100644 index 0000000000..e90bcbd8d5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go @@ -0,0 +1,541 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains code shared between client and server, including +// method handler and middleware definitions. +// +// Much of this is here so that we can factor out commonalities using +// generics. If this becomes unwieldy, it can perhaps be simplified with +// reflection. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "reflect" + "slices" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + // latestProtocolVersion is the latest protocol version that this version of + // the SDK supports. + // + // It is the version that the client sends in the initialization request, and + // the default version used by the server. + latestProtocolVersion = protocolVersion20250618 + protocolVersion20250618 = "2025-06-18" + protocolVersion20250326 = "2025-03-26" + protocolVersion20241105 = "2024-11-05" +) + +var supportedProtocolVersions = []string{ + protocolVersion20250618, + protocolVersion20250326, + protocolVersion20241105, +} + +// negotiatedVersion returns the effective protocol version to use, given a +// client version. +func negotiatedVersion(clientVersion string) string { + // In general, prefer to use the clientVersion, but if we don't support the + // client's version, use the latest version. 
+ // + // This handles the case where a new spec version is released, and the SDK + // does not support it yet. + if !slices.Contains(supportedProtocolVersions, clientVersion) { + return latestProtocolVersion + } + return clientVersion +} + +// A MethodHandler handles MCP messages. +// For methods, exactly one of the return values must be nil. +// For notifications, both must be nil. +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) + +// A Session is either a [ClientSession] or a [ServerSession]. +type Session interface { + // ID returns the session ID, or the empty string if there is none. + ID() string + + sendingMethodInfos() map[string]methodInfo + receivingMethodInfos() map[string]methodInfo + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler + getConn() *jsonrpc2.Connection +} + +// Middleware is a function from [MethodHandler] to [MethodHandler]. +type Middleware func(MethodHandler) MethodHandler + +// addMiddleware wraps the handler in the middleware functions. +func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { + for _, m := range slices.Backward(middleware) { + *handlerp = m(*handlerp) + } +} + +func defaultSendingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().sendingMethodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. + return nil, jsonrpc2.ErrNotHandled + } + // Notifications don't have results. + if strings.HasPrefix(method, "notifications/") { + return nil, req.GetSession().getConn().Notify(ctx, method, req.GetParams()) + } + // Create the result to unmarshal into. + // The concrete type of the result is the return type of the receiving function. 
+ res := info.newResult() + if err := call(ctx, req.GetSession().getConn(), method, req.GetParams(), res); err != nil { + return nil, err + } + return res, nil +} + +// Helper method to avoid typed nil. +func orZero[T any, P *U, U any](p P) T { + if p == nil { + var zero T + return zero + } + return any(p).(T) +} + +func handleNotify(ctx context.Context, method string, req Request) error { + mh := req.GetSession().sendingMethodHandler() + _, err := mh(ctx, method, req) + return err +} + +func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { + mh := req.GetSession().sendingMethodHandler() + // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. + res, err := mh(ctx, method, req) + if err != nil { + var z R + return z, err + } + return res.(R), nil +} + +// defaultReceivingMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. +func defaultReceivingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().receivingMethodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. + return nil, jsonrpc2.ErrNotHandled + } + return info.handleMethod(ctx, method, req) +} + +func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { + info, err := checkRequest(jreq, session.receivingMethodInfos()) + if err != nil { + return nil, err + } + params, err := info.unmarshalParams(jreq.Params) + if err != nil { + return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) + } + + mh := session.receivingMethodHandler() + re, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, re) + // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. 
+ res, err := mh(ctx, jreq.Method, req) + if err != nil { + return nil, err + } + return res, nil +} + +// checkRequest checks the given request against the provided method info, to +// ensure it is a valid MCP request. +// +// If valid, the relevant method info is returned. Otherwise, a non-nil error +// is returned describing why the request is invalid. +// +// This is extracted from request handling so that it can be called in the +// transport layer to preemptively reject bad requests. +func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) { + info, ok := infos[req.Method] + if !ok { + return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) + } + if info.flags¬ification != 0 && req.IsCall() { + return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + if info.flags¬ification == 0 && !req.IsCall() { + return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + // missingParamsOK is checked here to catch the common case where "params" is + // missing entirely. + // + // However, it's checked again after unmarshalling to catch the rare but + // possible case where "params" is JSON null (see https://go.dev/issue/33835). + if info.flags&missingParamsOK == 0 && len(req.Params) == 0 { + return methodInfo{}, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return info, nil +} + +// methodInfo is information about sending and receiving a method. +type methodInfo struct { + // flags is a collection of flags controlling how the JSONRPC method is + // handled. See individual flag values for documentation. + flags methodFlags + // Unmarshal params from the wire into a Params struct. + // Used on the receive side. + unmarshalParams func(json.RawMessage) (Params, error) + newRequest func(Session, Params, *RequestExtra) Request + // Run the code when a call to the method is received. 
+ // Used on the receive side. + handleMethod MethodHandler + // Create a pointer to a Result struct. + // Used on the send side. + newResult func() Result +} + +// The following definitions support converting from typed to untyped method handlers. +// Type parameter meanings: +// - S: sessions +// - P: params +// - R: results + +// A typedMethodHandler is like a MethodHandler, but with type information. +type ( + typedClientMethodHandler[P Params, R Result] func(context.Context, *ClientRequest[P]) (R, error) + typedServerMethodHandler[P Params, R Result] func(context.Context, *ServerRequest[P]) (R, error) +) + +type paramsPtr[T any] interface { + *T + Params +} + +type methodFlags int + +const ( + notification methodFlags = 1 << iota // method is a notification, not request + missingParamsOK // params may be missing or null +) + +func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { + r := &ClientRequest[P]{Session: s.(*ClientSession)} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ClientRequest[P])) + }) + return mi +} + +func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ServerRequest[P])) + }) + return mi +} + +// newMethodInfo creates a methodInfo from a typedMethodHandler. 
+// +// If isRequest is set, the method is treated as a request rather than a +// notification. +func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInfo { + return methodInfo{ + flags: flags, + unmarshalParams: func(m json.RawMessage) (Params, error) { + var p P + if m != nil { + if err := json.Unmarshal(m, &p); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + } + } + // We must check missingParamsOK here, in addition to checkRequest, to + // catch the edge cases where "params" is set to JSON null. + // See also https://go.dev/issue/33835. + // + // We need to ensure that p is non-null to guard against crashes, as our + // internal code or externally provided handlers may assume that params + // is non-null. + if flags&missingParamsOK == 0 && p == nil { + return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return orZero[Params](p), nil + }, + // newResult is used on the send side, to construct the value to unmarshal the result into. + // R is a pointer to a result struct. There is no way to "unpointer" it without reflection. + // TODO(jba): explore generic approaches to this, perhaps by treating R in + // the signature as the unpointered type. + newResult: func() Result { return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) }, + } +} + +// serverMethod is glue for creating a typedMethodHandler from a method on Server. +func serverMethod[P Params, R Result]( + f func(*Server, context.Context, *ServerRequest[P]) (R, error), +) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.Session.server, ctx, req) + } +} + +// clientMethod is glue for creating a typedMethodHandler from a method on Client. 
+func clientMethod[P Params, R Result]( + f func(*Client, context.Context, *ClientRequest[P]) (R, error), +) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.Session.client, ctx, req) + } +} + +// serverSessionMethod is glue for creating a typedServerMethodHandler from a method on ServerSession. +func serverSessionMethod[P Params, R Result](f func(*ServerSession, context.Context, P) (R, error)) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.GetSession().(*ServerSession), ctx, req.Params) + } +} + +// clientSessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. +func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Context, P) (R, error)) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.GetSession().(*ClientSession), ctx, req.Params) + } +} + +// Error codes +const ( + codeResourceNotFound = -32002 + // The error code if the method exists and was called properly, but the peer does not support it. + codeUnsupportedMethod = -31001 + // The error code for invalid parameters + codeInvalidParams = -32602 +) + +// notifySessions calls Notify on all the sessions. +// Should be called on a copy of the peer sessions. +func notifySessions[S Session, P Params](sessions []S, method string, params P) { + if sessions == nil { + return + } + // TODO: make this timeout configurable, or call handleNotify asynchronously. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // TODO: there's a potential spec violation here, when the feature list + // changes before the session (client or server) is initialized. 
+ for _, s := range sessions { + req := newRequest(s, params) + if err := handleNotify(ctx, method, req); err != nil { + // TODO(jba): surface this error better + log.Printf("calling %s: %v", method, err) + } + } +} + +func newRequest[S Session, P Params](s S, p P) Request { + switch s := any(s).(type) { + case *ClientSession: + return &ClientRequest[P]{Session: s, Params: p} + case *ServerSession: + return &ServerRequest[P]{Session: s, Params: p} + default: + panic("bad session") + } +} + +// Meta is additional metadata for requests, responses and other types. +type Meta map[string]any + +// GetMeta returns metadata from a value. +func (m Meta) GetMeta() map[string]any { return m } + +// SetMeta sets the metadata on a value. +func (m *Meta) SetMeta(x map[string]any) { *m = x } + +const progressTokenKey = "progressToken" + +func getProgressToken(p Params) any { + return p.GetMeta()[progressTokenKey] +} + +func setProgressToken(p Params, pt any) { + switch pt.(type) { + // Support int32 and int64 for atomic.IntNN. + case int, int32, int64, string: + default: + panic(fmt.Sprintf("progress token %v is of type %[1]T, not int or string", pt)) + } + m := p.GetMeta() + if m == nil { + m = map[string]any{} + } + m[progressTokenKey] = pt +} + +// A Request is a method request with parameters and additional information, such as the session. +// Request is implemented by [*ClientRequest] and [*ServerRequest]. +type Request interface { + isRequest() + GetSession() Session + GetParams() Params + // GetExtra returns the Extra field for ServerRequests, and nil for ClientRequests. + GetExtra() *RequestExtra +} + +// A ClientRequest is a request to a client. +type ClientRequest[P Params] struct { + Session *ClientSession + Params P +} + +// A ServerRequest is a request to a server. +type ServerRequest[P Params] struct { + Session *ServerSession + Params P + Extra *RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. 
+type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any + Header http.Header // header from HTTP request, if any +} + +func (*ClientRequest[P]) isRequest() {} +func (*ServerRequest[P]) isRequest() {} + +func (r *ClientRequest[P]) GetSession() Session { return r.Session } +func (r *ServerRequest[P]) GetSession() Session { return r.Session } + +func (r *ClientRequest[P]) GetParams() Params { return r.Params } +func (r *ServerRequest[P]) GetParams() Params { return r.Params } + +func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } +func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } + +func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { + return &ServerRequest[P]{Session: s, Params: p} +} + +func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] { + return &ClientRequest[P]{Session: s, Params: p} +} + +// Params is a parameter (input) type for an MCP call or notification. +type Params interface { + // GetMeta returns metadata from a value. + GetMeta() map[string]any + // SetMeta sets the metadata on a value. + SetMeta(map[string]any) + + // isParams discourages implementation of Params outside of this package. + isParams() +} + +// RequestParams is a parameter (input) type for an MCP request. +type RequestParams interface { + Params + + // GetProgressToken returns the progress token from the params' Meta field, or nil + // if there is none. + GetProgressToken() any + + // SetProgressToken sets the given progress token into the params' Meta field. + // It panics if its argument is not an int or a string. + SetProgressToken(any) +} + +// Result is a result of an MCP call. +type Result interface { + // isResult discourages implementation of Result outside of this package. + isResult() + + // GetMeta returns metadata from a value. + GetMeta() map[string]any + // SetMeta sets the metadata on a value. 
+ SetMeta(map[string]any) +} + +// emptyResult is returned by methods that have no result, like ping. +// Those methods cannot return nil, because jsonrpc2 cannot handle nils. +type emptyResult struct{} + +func (*emptyResult) isResult() {} +func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } +func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") } + +type listParams interface { + // Returns a pointer to the param's Cursor field. + cursorPtr() *string +} + +type listResult[T any] interface { + // Returns a pointer to the param's NextCursor field. + nextCursorPtr() *string +} + +// keepaliveSession represents a session that supports keepalive functionality. +type keepaliveSession interface { + Ping(ctx context.Context, params *PingParams) error + Close() error +} + +// startKeepalive starts the keepalive mechanism for a session. +// It assigns the cancel function to the provided cancelPtr and starts a goroutine +// that sends ping messages at the specified interval. +func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + // Assign cancel function before starting goroutine to avoid race condition. + // We cannot return it because the caller may need to cancel during the + // window between goroutine scheduling and function return. 
+ *cancelPtr = cancel + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) + err := session.Ping(pingCtx, nil) + pingCancel() + if err != nil { + // Ping failed, close the session + _ = session.Close() + return + } + } + } + }() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go new file mode 100644 index 0000000000..7f644918bb --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go @@ -0,0 +1,479 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// This file implements support for SSE (HTTP with server-sent events) +// transport server and client. +// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +// +// The transport is simple, at least relative to the new streamable transport +// introduced in the 2025-03-26 version of the spec. In short: +// +// 1. Sessions are initiated via a hanging GET request, which streams +// server->client messages as SSE 'message' events. +// 2. The first event in the SSE stream must be an 'endpoint' event that +// informs the client of the session endpoint. +// 3. The client POSTs client->server messages to the session endpoint. +// +// Therefore, the each new GET request hands off its responsewriter to an +// [SSEServerTransport] type that abstracts the transport as follows: +// - Write writes a new event to the responseWriter, or fails if the GET has +// exited. 
+// - Read reads off a message queue that is pushed to via POST requests. +// - Close causes the hanging GET to exit. + +// SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by +// the [2024-11-05 version] of the MCP spec. +// +// [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +type SSEHandler struct { + getServer func(request *http.Request) *Server + opts SSEOptions + onConnection func(*ServerSession) // for testing; must not block + + mu sync.Mutex + sessions map[string]*SSEServerTransport +} + +// SSEOptions specifies options for an [SSEHandler]. +// for now, it is empty, but may be extended in future. +// https://github.com/modelcontextprotocol/go-sdk/issues/507 +type SSEOptions struct{} + +// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP +// sessions created via incoming HTTP requests. +// +// Sessions are created when the client issues a GET request to the server, +// which must accept text/event-stream responses (server-sent events). +// For each such request, a new [SSEServerTransport] is created with a distinct +// messages endpoint, and connected to the server returned by getServer. +// The SSEHandler also handles requests to the message endpoints, by +// delegating them to the relevant server transport. +// +// The getServer function may return a distinct [Server] for each new +// request, or reuse an existing server. If it returns nil, the handler +// will return a 400 Bad Request. +func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ + getServer: getServer, + sessions: make(map[string]*SSEServerTransport), + } + + if opts != nil { + s.opts = *opts + } + + return s +} + +// A SSEServerTransport is a logical SSE session created through a hanging GET +// request. +// +// Use [SSEServerTransport.Connect] to initiate the flow of messages. 
+//
+// When connected, it returns the following [Connection] implementation:
+//   - Writes are SSE 'message' events to the GET response.
+//   - Reads are received from POSTs to the session endpoint, via
+//     [SSEServerTransport.ServeHTTP].
+//   - Close terminates the hanging GET.
+//
+// The transport is itself an [http.Handler]. It is the caller's responsibility
+// to ensure that the resulting transport serves HTTP requests on the given
+// session endpoint.
+//
+// Each SSEServerTransport may be connected (via [Server.Connect]) at most
+// once, since [SSEServerTransport.ServeHTTP] serves messages to the connected
+// session.
+//
+// Most callers should instead use an [SSEHandler], which transparently handles
+// the delegation to SSEServerTransports.
+type SSEServerTransport struct {
+	// Endpoint is the endpoint for this session, where the client can POST
+	// messages.
+	Endpoint string
+
+	// Response is the hanging response body to the incoming GET request.
+	Response http.ResponseWriter
+
+	// incoming is the queue of incoming messages.
+	// It is never closed, and by convention, incoming is non-nil if and only if
+	// the transport is connected.
+	incoming chan jsonrpc.Message
+
+	// We must guard both pushes to the incoming queue and writes to the response
+	// writer, because incoming POST requests are arbitrarily concurrent and we
+	// need to ensure we don't push to the queue, or write to the
+	// ResponseWriter, after the session GET request exits.
+	mu     sync.Mutex    // also guards writes to Response
+	closed bool          // set when the stream is closed
+	done   chan struct{} // closed when the connection is closed
+}
+
+// ServeHTTP handles POST requests to the transport endpoint.
+func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	if t.incoming == nil {
+		http.Error(w, "session not connected", http.StatusInternalServerError)
+		return
+	}
+
+	// Read and parse the message.
+ data, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + // Optionally, we could just push the data onto a channel, and let the + // message fail to parse when it is read. This failure seems a bit more + // useful + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + http.Error(w, "failed to parse body", http.StatusBadRequest) + return + } + if req, ok := msg.(*jsonrpc.Request); ok { + if _, err := checkRequest(req, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + select { + case t.incoming <- msg: + w.WriteHeader(http.StatusAccepted) + case <-t.done: + http.Error(w, "session closed", http.StatusBadRequest) + } +} + +// Connect sends the 'endpoint' event to the client. +// See [SSEServerTransport] for more details on the [Connection] implementation. +func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { + if t.incoming != nil { + return nil, fmt.Errorf("already connected") + } + t.incoming = make(chan jsonrpc.Message, 100) + t.done = make(chan struct{}) + _, err := writeEvent(t.Response, Event{ + Name: "endpoint", + Data: []byte(t.Endpoint), + }) + if err != nil { + return nil, err + } + return &sseServerConn{t: t}, nil +} + +func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + sessionID := req.URL.Query().Get("sessionid") + + // TODO: consider checking Content-Type here. For now, we are lax. + + // For POST requests, the message body is a message to send to a session. + if req.Method == http.MethodPost { + // Look up the session. 
+ if sessionID == "" { + http.Error(w, "sessionid must be provided", http.StatusBadRequest) + return + } + h.mu.Lock() + session := h.sessions[sessionID] + h.mu.Unlock() + if session == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + session.ServeHTTP(w, req) + return + } + + if req.Method != http.MethodGet { + http.Error(w, "invalid method", http.StatusMethodNotAllowed) + return + } + + // GET requests create a new session, and serve messages over SSE. + + // TODO: it's not entirely documented whether we should check Accept here. + // Let's again be lax and assume the client will accept SSE. + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + sessionID = randText() + endpoint, err := req.URL.Parse("?sessionid=" + sessionID) + if err != nil { + http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) + return + } + + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} + + // The session is terminated when the request exits. + h.mu.Lock() + h.sessions[sessionID] = transport + h.mu.Unlock() + defer func() { + h.mu.Lock() + delete(h.sessions, sessionID) + h.mu.Unlock() + }() + + server := h.getServer(req) + if server == nil { + // The getServer argument to NewSSEHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + ss, err := server.Connect(req.Context(), transport, nil) + if err != nil { + http.Error(w, "connection failed", http.StatusInternalServerError) + return + } + if h.onConnection != nil { + h.onConnection(ss) + } + defer ss.Close() // close the transport when the GET exits + + select { + case <-req.Context().Done(): + case <-transport.done: + } +} + +// sseServerConn implements the [Connection] interface for a single [SSEServerTransport]. +// It hides the Connection interface from the SSEServerTransport API. 
+type sseServerConn struct { + t *SSEServerTransport +} + +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (s *sseServerConn) SessionID() string { return "" } + +// Read implements jsonrpc2.Reader. +func (s *sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg := <-s.t.incoming: + return msg, nil + case <-s.t.done: + return nil, io.EOF + } +} + +// Write implements jsonrpc2.Writer. +func (s *sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if ctx.Err() != nil { + return ctx.Err() + } + + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + + s.t.mu.Lock() + defer s.t.mu.Unlock() + + // Note that it is invalid to write to a ResponseWriter after ServeHTTP has + // exited, and so we must lock around this write and check isDone, which is + // set before the hanging GET exits. + if s.t.closed { + return io.EOF + } + + _, err = writeEvent(s.t.Response, Event{Name: "message", Data: data}) + return err +} + +// Close implements io.Closer, and closes the session. +// +// It must be safe to call Close more than once, as the close may +// asynchronously be initiated by either the server closing its connection, or +// by the hanging GET exiting. +func (s *sseServerConn) Close() error { + s.t.mu.Lock() + defer s.t.mu.Unlock() + if !s.t.closed { + s.t.closed = true + close(s.t.done) + } + return nil +} + +// An SSEClientTransport is a [Transport] that can communicate with an MCP +// endpoint serving the SSE transport defined by the 2024-11-05 version of the +// spec. +// +// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +type SSEClientTransport struct { + // Endpoint is the SSE endpoint to connect to. + Endpoint string + + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. 
+ HTTPClient *http.Client +} + +// Connect connects through the client endpoint. +func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { + parsedURL, err := url.Parse(c.Endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint: %v", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", c.Endpoint, nil) + if err != nil { + return nil, err + } + httpClient := c.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + req.Header.Set("Accept", "text/event-stream") + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + msgEndpoint, err := func() (*url.URL, error) { + var evt Event + for evt, err = range scanEvents(resp.Body) { + break + } + if err != nil { + return nil, err + } + if evt.Name != "endpoint" { + return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") + } + raw := string(evt.Data) + return parsedURL.Parse(raw) + }() + if err != nil { + resp.Body.Close() + return nil, fmt.Errorf("missing endpoint: %v", err) + } + + // From here on, the stream takes ownership of resp.Body. + s := &sseClientConn{ + client: httpClient, + msgEndpoint: msgEndpoint, + incoming: make(chan []byte, 100), + body: resp.Body, + done: make(chan struct{}), + } + + go func() { + defer s.Close() // close the transport when the GET exits + + for evt, err := range scanEvents(resp.Body) { + if err != nil { + return + } + select { + case s.incoming <- evt.Data: + case <-s.done: + return + } + } + }() + + return s, nil +} + +// An sseClientConn is a logical jsonrpc2 connection that implements the client +// half of the SSE protocol: +// - Writes are POSTS to the session endpoint. +// - Reads are SSE 'message' events, and pushes them onto a buffered channel. +// - Close terminates the GET request. 
+type sseClientConn struct { + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages + + mu sync.Mutex + body io.ReadCloser // body of the hanging GET + closed bool // set when the stream is closed + done chan struct{} // closed when the stream is closed +} + +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (c *sseClientConn) SessionID() string { return "" } + +func (c *sseClientConn) isDone() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case <-c.done: + return nil, io.EOF + + case data := <-c.incoming: + // TODO(rfindley): do we really need to check this? We receive from c.done above. + if c.isDone() { + return nil, io.EOF + } + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + return nil, err + } + return msg, nil + } +} + +func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + if c.isDone() { + return io.EOF + } + req, err := http.NewRequestWithContext(ctx, "POST", c.msgEndpoint.String(), bytes.NewReader(data)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("failed to write: %s", resp.Status) + } + return nil +} + +func (c *sseClientConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.closed { + c.closed = true + _ = c.body.Close() + close(c.done) + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go 
new file mode 100644 index 0000000000..178b24662a --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go @@ -0,0 +1,1777 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "maps" + "math" + "math/rand/v2" + "net/http" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" +) + +// A StreamableHTTPHandler is an http.Handler that serves streamable MCP +// sessions, as defined by the [MCP spec]. +// +// [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html +type StreamableHTTPHandler struct { + getServer func(*http.Request) *Server + opts StreamableHTTPOptions + + onTransportDeletion func(sessionID string) // for testing + + mu sync.Mutex + sessions map[string]*sessionInfo // keyed by session ID +} + +type sessionInfo struct { + session *ServerSession + transport *StreamableServerTransport + + // If timeout is set, automatically close the session after an idle period. + timeout time.Duration + timerMu sync.Mutex + refs int // reference count + timer *time.Timer +} + +// startPOST signals that a POST request for this session is starting (which +// carries a client->server message), pausing the session timeout if it was +// running. +// +// TODO: we may want to also pause the timer when resuming non-standalone SSE +// streams, but that is tricy to implement. Clients should generally make +// keepalive pings if they want to keep the session live. 
+func (i *sessionInfo) startPOST() {
+	if i.timeout <= 0 {
+		return
+	}
+
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+
+	if i.timer == nil {
+		return // timer stopped permanently
+	}
+	if i.refs == 0 {
+		i.timer.Stop()
+	}
+	i.refs++
+}
+
+// endPOST signals that a request for this session is ending, starting the
+// timeout if there are no other requests running.
+func (i *sessionInfo) endPOST() {
+	if i.timeout <= 0 {
+		return
+	}
+
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+
+	if i.timer == nil {
+		return // timer stopped permanently
+	}
+
+	i.refs--
+	assert(i.refs >= 0, "negative ref count")
+	if i.refs == 0 {
+		i.timer.Reset(i.timeout)
+	}
+}
+
+// stopTimer stops the inactivity timer permanently.
+func (i *sessionInfo) stopTimer() {
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+	if i.timer != nil {
+		i.timer.Stop()
+		i.timer = nil
+	}
+}
+
+// StreamableHTTPOptions configures the StreamableHTTPHandler.
+type StreamableHTTPOptions struct {
+	// Stateless controls whether the session is 'stateless'.
+	//
+	// A stateless server does not validate the Mcp-Session-Id header, and uses a
+	// temporary session with default initialization parameters. Any
+	// server->client request is rejected immediately as there's no way for the
+	// client to respond. Server->Client notifications may reach the client if
+	// they are made in the context of an incoming request, as described in the
+	// documentation for [StreamableServerTransport].
+	Stateless bool
+
+	// TODO(#148): support session retention (?)
+
+	// JSONResponse causes streamable responses to return application/json rather
+	// than text/event-stream ([§2.1.5] of the spec).
+	//
+	// [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
+	JSONResponse bool
+
+	// Logger specifies the logger to use.
+	// If nil, do not log.
+	Logger *slog.Logger
+
+	// EventStore enables stream resumption.
+ // + // If set, EventStore will be used to persist stream events and replay them + // upon stream resumption. + EventStore EventStore + + // SessionTimeout configures a timeout for idle sessions. + // + // When sessions receive no new HTTP requests from the client for this + // duration, they are automatically closed. + // + // If SessionTimeout is the zero value, idle sessions are never closed. + SessionTimeout time.Duration +} + +// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. +// +// The getServer function is used to create or look up servers for new +// sessions. It is OK for getServer to return the same server multiple times. +// If getServer returns nil, a 400 Bad Request will be served. +func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { + h := &StreamableHTTPHandler{ + getServer: getServer, + sessions: make(map[string]*sessionInfo), + } + if opts != nil { + h.opts = *opts + } + + if h.opts.Logger == nil { // ensure we have a logger + h.opts.Logger = ensureLogger(nil) + } + + return h +} + +// closeAll closes all ongoing sessions, for tests. +// +// TODO(rfindley): investigate the best API for callers to configure their +// session lifecycle. (?) +// +// Should we allow passing in a session store? That would allow the handler to +// be stateless. +func (h *StreamableHTTPHandler) closeAll() { + // TODO: if we ever expose this outside of tests, we'll need to do better + // than simply collecting sessions while holding the lock: we need to prevent + // new sessions from being added. + // + // Currently, sessions remove themselves from h.sessions when closed, so we + // can't call Close while holding the lock. 
+	h.mu.Lock()
+	sessionInfos := slices.Collect(maps.Values(h.sessions))
+	h.sessions = nil
+	h.mu.Unlock()
+	for _, s := range sessionInfos {
+		s.session.Close()
+	}
+}
+
+func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	// Allow multiple 'Accept' headers.
+	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax
+	accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",")
+	var jsonOK, streamOK bool
+	for _, c := range accept {
+		switch strings.TrimSpace(c) {
+		case "application/json", "application/*":
+			jsonOK = true
+		case "text/event-stream", "text/*":
+			streamOK = true
+		case "*/*":
+			jsonOK = true
+			streamOK = true
+		}
+	}
+
+	if req.Method == http.MethodGet {
+		if !streamOK {
+			http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest)
+			return
+		}
+	} else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { // TODO: consolidate with handling of http method below.
+		http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest)
+		return
+	}
+
+	sessionID := req.Header.Get(sessionIDHeader)
+	var sessInfo *sessionInfo
+	if sessionID != "" {
+		h.mu.Lock()
+		sessInfo = h.sessions[sessionID]
+		h.mu.Unlock()
+		if sessInfo == nil && !h.opts.Stateless {
+			// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
+			// validation, we require that the session ID matches a known session.
+			//
+			// In stateless mode, a temporary transport will be created below.
+			http.Error(w, "session not found", http.StatusNotFound)
+			return
+		}
+	}
+
+	if req.Method == http.MethodDelete {
+		if sessionID == "" {
+			http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
+			return
+		}
+		if sessInfo != nil { // sessInfo may be nil in stateless mode
+			// Closing the session also removes it from h.sessions, due to the
+			// onClose callback.
+ sessInfo.session.Close() + } + w.WriteHeader(http.StatusNoContent) + return + } + + switch req.Method { + case http.MethodPost, http.MethodGet: + if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { + http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) + return + } + default: + w.Header().Set("Allow", "GET, POST, DELETE") + http.Error(w, "Method Not Allowed: streamable MCP servers support GET, POST, and DELETE requests", http.StatusMethodNotAllowed) + return + } + + // [§2.7] of the spec (2025-06-18) states: + // + // "If using HTTP, the client MUST include the MCP-Protocol-Version: + // HTTP header on all subsequent requests to the MCP + // server, allowing the MCP server to respond based on the MCP protocol + // version. + // + // For example: MCP-Protocol-Version: 2025-06-18 + // The protocol version sent by the client SHOULD be the one negotiated during + // initialization. + // + // For backwards compatibility, if the server does not receive an + // MCP-Protocol-Version header, and has no other way to identify the version - + // for example, by relying on the protocol version negotiated during + // initialization - the server SHOULD assume protocol version 2025-03-26. + // + // If the server receives a request with an invalid or unsupported + // MCP-Protocol-Version, it MUST respond with 400 Bad Request." + // + // Since this wasn't present in the 2025-03-26 version of the spec, this + // effectively means: + // 1. IF the client provides a version header, it must be a supported + // version. + // 2. In stateless mode, where we've lost the state of the initialize + // request, we assume that whatever the client tells us is the truth (or + // assume 2025-03-26 if the client doesn't say anything). + // + // This logic matches the typescript SDK. 
+ // + // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + if !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) + return + } + + if sessInfo == nil { + server := h.getServer(req) + if server == nil { + // The getServer argument to NewStreamableHTTPHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = server.opts.GetSessionID() + } + transport := &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + } + + // Sessions without a session ID are also stateless: there's no way to + // address them. + stateless := h.opts.Stateless || sessionID == "" + // To support stateless mode, we initialize the session with a default + // state, so that it doesn't reject subsequent requests. + var connectOpts *ServerSessionOptions + if stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. + var hasInitialize, hasInitialized bool + { + // TODO: verify that this allows protocol version negotiation for + // stateless servers. + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + req.Body.Close() + + // Reset the body so that it can be read later. 
+ req.Body = io.NopCloser(bytes.NewBuffer(body)) + + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + switch req.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } + } + } + } + } + + // If we don't have InitializeParams or InitializedParams in the request, + // set the initial state to a default value. + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion, + } + } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } + state.LogLevel = "info" + connectOpts = &ServerSessionOptions{ + State: state, + } + } else { + // Cleanup is only required in stateful mode, as transportation is + // not stored in the map otherwise. + connectOpts = &ServerSessionOptions{ + onClose: func() { + h.mu.Lock() + defer h.mu.Unlock() + if info, ok := h.sessions[transport.SessionID]; ok { + info.stopTimer() + delete(h.sessions, transport.SessionID) + if h.onTransportDeletion != nil { + h.onTransportDeletion(transport.SessionID) + } + } + }, + } + } + + // Pass req.Context() here, to allow middleware to add context values. + // The context is detached in the jsonrpc2 library when handling the + // long-running stream. + session, err := server.Connect(req.Context(), transport, connectOpts) + if err != nil { + http.Error(w, "failed connection", http.StatusInternalServerError) + return + } + sessInfo = &sessionInfo{ + session: session, + transport: transport, + } + + if stateless { + // Stateless mode: close the session when the request exits. + defer session.Close() // close the fake session after handling the request + } else { + // Otherwise, save the transport so that it can be reused + + // Clean up the session when it times out. + // + // Note that the timer here may fire multiple times, but + // sessInfo.session.Close is idempotent. 
+ if h.opts.SessionTimeout > 0 { + sessInfo.timeout = h.opts.SessionTimeout + sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() { + sessInfo.session.Close() + }) + } + h.mu.Lock() + h.sessions[transport.SessionID] = sessInfo + h.mu.Unlock() + defer func() { + // If initialization failed, clean up the session (#578). + if session.InitializeParams() == nil { + // Initialization failed. + session.Close() + } + }() + } + } + + if req.Method == http.MethodPost { + sessInfo.startPOST() + defer sessInfo.endPOST() + } + + sessInfo.transport.ServeHTTP(w, req) +} + +// A StreamableServerTransport implements the server side of the MCP streamable +// transport. +// +// Each StreamableServerTransport must be connected (via [Server.Connect]) at +// most once, since [StreamableServerTransport.ServeHTTP] serves messages to +// the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the related stream, or to the standalone SSE stream, +// according to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. +// - Requests or notifications made with a context.Context value derived from +// an incoming request handler, are routed to the HTTP response +// corresponding to that request, unless it has already terminated, in +// which case they are routed to the standalone SSE stream. +// - Requests or notifications made with a detached context.Context value are +// routed to the standalone SSE stream. +type StreamableServerTransport struct { + // SessionID is the ID of this session. + // + // If SessionID is the empty string, this is a 'stateless' session, which has + // limited ability to communicate with the client. Otherwise, the session ID + // must be globally unique, that is, different from any other session ID + // anywhere, past and future. 
(We recommend using a crypto random number + // generator to produce one, as with [crypto/rand.Text].) + SessionID string + + // Stateless controls whether the eventstore is 'Stateless'. Server sessions + // connected to a stateless transport are disallowed from making outgoing + // requests. + // + // See also [StreamableHTTPOptions.Stateless]. + Stateless bool + + // EventStore enables stream resumption. + // + // If set, EventStore will be used to persist stream events and replay them + // upon stream resumption. + EventStore EventStore + + // jsonResponse, if set, tells the server to prefer to respond to requests + // using application/json responses rather than text/event-stream. + // + // Specifically, responses will be application/json whenever incoming POST + // request contain only a single message. In this case, notifications or + // requests made within the context of a server request will be sent to the + // standalone SSE stream, if any. + // + // TODO(rfindley): jsonResponse should be exported, since + // StreamableHTTPOptions.JSONResponse is exported, and we want to allow users + // to write their own streamable HTTP handler. + jsonResponse bool + + // optional logger provided through the [StreamableHTTPOptions.Logger]. + // + // TODO(rfindley): logger should be exported, since we want to allow users + // to write their own streamable HTTP handler. + logger *slog.Logger + + // connection is non-nil if and only if the transport has been connected. + connection *streamableServerConn +} + +// Connect implements the [Transport] interface. 
func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) {
	if t.connection != nil {
		return nil, fmt.Errorf("transport already connected")
	}
	t.connection = &streamableServerConn{
		sessionID:      t.SessionID,
		stateless:      t.Stateless,
		eventStore:     t.EventStore,
		jsonResponse:   t.jsonResponse,
		logger:         ensureLogger(t.logger), // see #556: must be non-nil
		incoming:       make(chan jsonrpc.Message, 10),
		done:           make(chan struct{}),
		streams:        make(map[string]*stream),
		requestStreams: make(map[jsonrpc.ID]string),
	}
	// Stream 0 corresponds to the standalone SSE stream.
	//
	// It is always text/event-stream, since it must carry arbitrarily many
	// messages.
	var err error
	t.connection.streams[""], err = t.connection.newStream(ctx, nil, "")
	if err != nil {
		return nil, err
	}
	return t.connection, nil
}

type streamableServerConn struct {
	sessionID    string
	stateless    bool
	jsonResponse bool
	eventStore   EventStore

	logger *slog.Logger

	incoming chan jsonrpc.Message // messages from the client to the server

	mu sync.Mutex // guards all fields below

	// Sessions are closed exactly once.
	isDone bool
	done   chan struct{}

	// Sessions can have multiple logical connections (which we call streams),
	// corresponding to HTTP requests. Additionally, streams may be resumed by
	// subsequent HTTP requests, when the HTTP connection is terminated
	// unexpectedly.
	//
	// Therefore, we use a logical stream ID to key the stream state, and
	// perform the accounting described below when incoming HTTP requests are
	// handled.

	// streams holds the logical streams for this session, keyed by their ID.
	//
	// Lifecycle: streams persist until all of their responses are received from
	// the server.
	streams map[string]*stream

	// requestStreams maps incoming requests to their logical stream ID.
	//
	// Lifecycle: requestStreams persist until their response is received.
	requestStreams map[jsonrpc.ID]string
}

func (c *streamableServerConn) SessionID() string {
	return c.sessionID
}

// A stream is a single logical stream of SSE events within a server session.
// A stream begins with a client request, or with a client GET that has
// no Last-Event-ID header.
//
// A stream ends only when its session ends; we cannot determine its end otherwise,
// since a client may send a GET with a Last-Event-ID that references the stream
// at any time.
type stream struct {
	// id is the logical ID for the stream, unique within a session.
	//
	// The standalone SSE stream has id "".
	id string

	// mu guards the fields below, as well as storage of new messages in the
	// connection's event store (if any).
	mu sync.Mutex

	// If non-nil, deliver writes data directly to the HTTP response.
	//
	// Only one HTTP response may receive messages at a given time. An active
	// HTTP connection acquires ownership of the stream by setting this field.
	deliver func(data []byte, final bool) error

	// requests is the set of unanswered incoming requests for the stream.
	//
	// Requests are removed when their response has been received.
	requests map[jsonrpc.ID]struct{}
}

// doneLocked reports whether the stream is logically complete: all of its
// requests have been answered. The standalone SSE stream (id == "") is never
// considered done.
//
// s.mu must be held while calling this function.
func (s *stream) doneLocked() bool {
	return len(s.requests) == 0 && s.id != ""
}

// newStream creates the stream state for the given logical stream ID,
// tracking the given set of unanswered requests. If an event store is
// configured, the stream is also opened in the store so that its events can
// be persisted for later replay.
func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonrpc.ID]struct{}, id string) (*stream, error) {
	if c.eventStore != nil {
		if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
			return nil, err
		}
	}
	return &stream{
		id:       id,
		requests: requests,
	}, nil
}

// We track the incoming request ID inside the handler context using
// idContextKey, so that notifications and server->client calls that occur in
// the course of handling incoming requests are correlated with the incoming
// request that caused them, and can be dispatched as server-sent events to the
// correct HTTP request.
//
// Currently, this is implemented in [ServerSession.handle]. This is not ideal,
// because it means that a user of the MCP package couldn't implement the
// streamable transport, as they'd lack this privileged access.
//
// If we ever wanted to expose this mechanism, we have a few options:
//  1. Make ServerSession an interface, and provide an implementation of
//     ServerSession to handlers that closes over the incoming request ID.
//  2. Expose a 'HandlerTransport' interface that allows transports to provide
//     a handler middleware, so that we don't hard-code this behavior in
//     ServerSession.handle.
//  3. Add a `func ForRequest(context.Context) jsonrpc.ID` accessor that lets
//     any transport access the incoming request ID.
//
// For now, by giving only the StreamableServerTransport access to the request
// ID, we avoid having to make this API decision.
type idContextKey struct{}

// ServeHTTP handles a single HTTP request for the session.
func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if t.connection == nil {
		http.Error(w, "transport not connected", http.StatusInternalServerError)
		return
	}
	switch req.Method {
	case http.MethodGet:
		t.connection.serveGET(w, req)
	case http.MethodPost:
		t.connection.servePOST(w, req)
	default:
		// Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP.
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, "unsupported method", http.StatusMethodNotAllowed)
		return
	}
}

// serveGET streams messages to a hanging http GET, with stream ID and last
// message parsed from the Last-Event-ID header.
//
// Errors are reported by writing an HTTP error status and message to w.
func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) {
	// streamID "" corresponds to the default GET request.
	streamID := ""
	// By default, we haven't seen a last index. Since indices start at 0, we represent
	// that by -1. This is incremented just before each event is written.
	lastIdx := -1
	if len(req.Header.Values("Last-Event-ID")) > 0 {
		eid := req.Header.Get("Last-Event-ID")
		var ok bool
		streamID, lastIdx, ok = parseEventID(eid)
		if !ok {
			http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest)
			return
		}
		// Replay requires a configured event store.
		if c.eventStore == nil {
			http.Error(w, "stream replay unsupported", http.StatusBadRequest)
			return
		}
	}

	ctx, cancel := context.WithCancel(req.Context())
	defer cancel()

	stream, done := c.acquireStream(ctx, w, streamID, &lastIdx)
	if stream == nil {
		// Errors (if any) were already written by acquireStream.
		return
	}
	// Release the stream when we're done.
	defer func() {
		stream.mu.Lock()
		stream.deliver = nil
		stream.mu.Unlock()
	}()

	select {
	case <-ctx.Done():
		// request cancelled
	case <-done:
		// request complete
	case <-c.done:
		// session closed
	}
}

// writeEvent writes an SSE event to w corresponding to the given stream, data, and index.
// lastIdx is incremented before writing, so that it continues to point to the index of the
// last event written to the stream.
func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, data []byte, lastIdx *int) error {
	*lastIdx++
	e := Event{
		Name: "message",
		Data: data,
	}
	// Event IDs are only meaningful for resumption, which requires an event store.
	if c.eventStore != nil {
		e.ID = formatEventID(stream.id, *lastIdx)
	}
	if _, err := writeEvent(w, e); err != nil {
		return err
	}
	return nil
}

// acquireStream acquires the stream and replays all events since lastIdx, if
// any, updating lastIdx accordingly. If non-nil, the resulting stream will be
// registered for receiving new messages, and the resulting done channel will
// be closed when all related messages have been delivered.
//
// If any errors occur, they will be written to w and the resulting stream will
// be nil. The resulting stream may also be nil if the stream is complete.
//
// Importantly, this function must hold the stream mutex until done replaying
// all messages, so that no delivery or storage of new messages occurs while
// the stream is still replaying.
func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
	// if tempStream is set, the stream is done and we're just replaying messages.
	//
	// We record a temporary stream to claim exclusive replay rights.
	tempStream := false
	c.mu.Lock()
	s, ok := c.streams[streamID]
	if !ok {
		// The stream is logically done, but claim exclusive rights to replay it by
		// adding a temporary entry in the streams map.
		//
		// We create this entry with a non-nil deliver function, to ensure it isn't
		// claimed by another request before we lock it below.
		tempStream = true
		s = &stream{
			id:      streamID,
			deliver: func([]byte, bool) error { return nil },
		}
		c.streams[streamID] = s

		// Since this stream is transient, we must clean up after replaying.
		defer func() {
			c.mu.Lock()
			delete(c.streams, streamID)
			c.mu.Unlock()
		}()
	}
	c.mu.Unlock()

	s.mu.Lock()
	defer s.mu.Unlock()

	// Check that this stream wasn't claimed by another request.
	if !tempStream && s.deliver != nil {
		http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict)
		return nil, nil
	}

	// Collect events to replay. Collect them all before writing, so that we
	// have an opportunity to set the HTTP status code on an error.
	//
	// As indicated above, we must do that while holding stream.mu, so that no
	// new messages are added to the eventstore until we've replayed all previous
	// messages, and registered our delivery function.
	var toReplay [][]byte
	if c.eventStore != nil {
		for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, *lastIdx) {
			if err != nil {
				// We can't replay events, perhaps because the underlying event store
				// has garbage collected its storage.
				//
				// We must be careful here: any 404 will signal to the client that the
				// *session* is not found, rather than the stream.
				//
				// 400 is not really accurate, but should at least have no side effects.
				// Other SDKs (typescript) do not have a mechanism for events to be purged.
				http.Error(w, "failed to replay events", http.StatusBadRequest)
				return nil, nil
			}
			toReplay = append(toReplay, data)
		}
	}

	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
	w.Header().Set("Connection", "keep-alive")

	if s.id == "" {
		// Issue #410: the standalone SSE stream is likely not to receive messages
		// for a long time. Ensure that headers are flushed.
		w.WriteHeader(http.StatusOK)
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
	}

	for _, data := range toReplay {
		if err := c.writeEvent(w, s, data, lastIdx); err != nil {
			// Writing to w failed; nothing more we can report via the response.
			return nil, nil
		}
	}

	if tempStream || s.doneLocked() {
		// Nothing more to do.
		return nil, nil
	}

	// The stream is not done: register a delivery function before the stream is
	// unlocked, allowing the connection to write new events.
	done := make(chan struct{})
	s.deliver = func(data []byte, final bool) error {
		if err := ctx.Err(); err != nil {
			return err
		}
		err := c.writeEvent(w, s, data, lastIdx)
		if final {
			close(done)
		}
		return err
	}
	return s, done
}

// servePOST handles an incoming message, and replies with either an outgoing
// message stream or single response object, depending on whether the
// jsonResponse option is set.
//
// Errors are reported by writing an HTTP error status and message to w.
func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) {
	if len(req.Header.Values("Last-Event-ID")) > 0 {
		http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest)
		return
	}

	// Read incoming messages.
	body, err := io.ReadAll(req.Body)
	if err != nil {
		http.Error(w, "failed to read body", http.StatusBadRequest)
		return
	}
	if len(body) == 0 {
		http.Error(w, "POST requires a non-empty body", http.StatusBadRequest)
		return
	}
	incoming, isBatch, err := readBatch(body)
	if err != nil {
		http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
		return
	}

	protocolVersion := req.Header.Get(protocolVersionHeader)
	if protocolVersion == "" {
		protocolVersion = protocolVersion20250326
	}

	// Batching was removed from the spec as of 2025-06-18.
	if isBatch && protocolVersion >= protocolVersion20250618 {
		http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest)
		return
	}

	// TODO(rfindley): no tests fail if we reject batch JSON requests entirely.
	// We need to test this with older protocol versions.
	// if isBatch && c.jsonResponse {
	// 	http.Error(w, "server does not support batch requests", http.StatusBadRequest)
	// 	return
	// }

	calls := make(map[jsonrpc.ID]struct{})
	tokenInfo := auth.TokenInfoFromContext(req.Context())
	isInitialize := false
	for _, msg := range incoming {
		if jreq, ok := msg.(*jsonrpc.Request); ok {
			// Preemptively check that this is a valid request, so that we can fail
			// the HTTP request. If we didn't do this, a request with a bad method or
			// missing ID could be silently swallowed.
			if _, err := checkRequest(jreq, serverMethodInfos); err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			if jreq.Method == methodInitialize {
				isInitialize = true
			}
			jreq.Extra = &RequestExtra{
				TokenInfo: tokenInfo,
				Header:    req.Header,
			}
			if jreq.IsCall() {
				calls[jreq.ID] = struct{}{}
			}
		}
	}

	// If we don't have any calls, we can just publish the incoming messages and return.
	// No need to track a logical stream.
	if len(calls) == 0 {
		for _, msg := range incoming {
			select {
			case c.incoming <- msg:
			case <-c.done:
				// The session is closing. Since we haven't yet written any data to the
				// response, we can signal to the client that the session is gone.
				http.Error(w, "session is closing", http.StatusNotFound)
				return
			}
		}
		w.WriteHeader(http.StatusAccepted)
		return
	}

	// Invariant: we have at least one call.
	//
	// Create a logical stream to track its responses.
	// Important: don't publish the incoming messages until the stream is
	// registered, as the server may attempt to respond to incoming messages as
	// soon as they're published.
	stream, err := c.newStream(req.Context(), calls, randText())
	if err != nil {
		http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
		return
	}

	// Set response headers. Accept was checked in [StreamableHTTPHandler].
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	if c.jsonResponse {
		w.Header().Set("Content-Type", "application/json")
	} else {
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Connection", "keep-alive")
	}
	if c.sessionID != "" && isInitialize {
		w.Header().Set(sessionIDHeader, c.sessionID)
	}

	// Message delivery has two paths, depending on whether we're responding with JSON or
	// event stream.
	done := make(chan struct{}) // closed after the final response is written
	if c.jsonResponse {
		var msgs []json.RawMessage
		stream.deliver = func(data []byte, final bool) error {
			// Collect messages until we've received the final response.
			//
			// In recent protocol versions, there should only be one message as
			// batching is disabled, as checked above.
			msgs = append(msgs, data)
			if !final {
				return nil
			}
			defer close(done) // final response

			// Write either the JSON object corresponding to the one response, or a
			// JSON array corresponding to the batch response.
			var toWrite []byte
			if len(msgs) == 1 {
				toWrite = []byte(msgs[0])
			} else {
				var err error
				toWrite, err = json.Marshal(msgs)
				if err != nil {
					return err
				}
			}
			_, err = w.Write(toWrite)
			return err
		}
	} else {
		// Write events in the order we receive them.
		lastIndex := -1
		stream.deliver = func(data []byte, final bool) error {
			if final {
				defer close(done)
			}
			return c.writeEvent(w, stream, data, &lastIndex)
		}
	}

	// Release ownership of the stream by unsetting deliver.
	defer func() {
		stream.mu.Lock()
		// TODO(rfindley): if we have no event store, we should really cancel all
		// remaining requests here, since the client will never get the results.
		stream.deliver = nil
		stream.mu.Unlock()
	}()

	// The stream is now set up to deliver messages.
	//
	// Register it before publishing incoming messages.
	c.mu.Lock()
	c.streams[stream.id] = stream
	for reqID := range calls {
		c.requestStreams[reqID] = stream.id
	}
	c.mu.Unlock()

	// Publish incoming messages.
	for _, msg := range incoming {
		select {
		case c.incoming <- msg:
			// Note: don't select on req.Context().Done() here, since we've already
			// received the requests and may have already published a response message
			// or notification. The client could resume the stream.
		case <-c.done:
			// Session closed: we don't know if any data has been written, so it's
			// too late to write a status code here.
			return
		}
	}

	select {
	case <-req.Context().Done():
		// request cancelled
	case <-done:
		// request complete
	case <-c.done:
		// session is closed
	}
}

// Event IDs: encode both the logical connection ID and the index, as
// <streamID>_<idx>, to be consistent with the typescript implementation.

// formatEventID returns the event ID to use for the logical connection ID
// streamID and message index idx.
//
// See also [parseEventID].
func formatEventID(sid string, idx int) string {
	return fmt.Sprintf("%s_%d", sid, idx)
}

// parseEventID parses a Last-Event-ID value into a logical stream id and
// index.
//
// See also [formatEventID].
func parseEventID(eventID string) (streamID string, idx int, ok bool) {
	parts := strings.Split(eventID, "_")
	if len(parts) != 2 {
		return "", 0, false
	}
	streamID = parts[0]
	idx, err := strconv.Atoi(parts[1])
	if err != nil || idx < 0 {
		return "", 0, false
	}
	return streamID, idx, true
}

// Read implements the [Connection] interface.
func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case msg, ok := <-c.incoming:
		if !ok {
			return nil, io.EOF
		}
		return msg, nil
	case <-c.done:
		return nil, io.EOF
	}
}

// Write implements the [Connection] interface.
func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
	// Throughout this function, note that any error that wraps ErrRejected
	// indicates a failure that does not cause the connection to break.
	//
	// Most errors don't break the connection: unlike a true bidirectional
	// stream, a failure to deliver to a stream is not an indication that the
	// logical session is broken.
	data, err := jsonrpc2.EncodeMessage(msg)
	if err != nil {
		return err
	}

	if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") {
		// Requests aren't possible with stateless servers, or when there's no session ID.
		return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
	}

	// Find the incoming request that this write relates to, if any.
	var (
		relatedRequest jsonrpc.ID
		responseTo     jsonrpc.ID // if valid, the message is a response to this request
	)
	if resp, ok := msg.(*jsonrpc.Response); ok {
		// If the message is a response, it relates to its request (of course).
		relatedRequest = resp.ID
		responseTo = resp.ID
	} else {
		// Otherwise, we check whether the request was made in the context of an
		// ongoing request. This may not be the case if the request was made with
		// an unrelated context.
		if v := ctx.Value(idContextKey{}); v != nil {
			relatedRequest = v.(jsonrpc.ID)
		}
	}

	// If the stream is application/json, but the message is not a response, we
	// must send it out of band to the standalone SSE stream.
	if c.jsonResponse && !responseTo.IsValid() {
		relatedRequest = jsonrpc.ID{}
	}

	// Write the message to the stream.
	var s *stream
	c.mu.Lock()
	if relatedRequest.IsValid() {
		if streamID, ok := c.requestStreams[relatedRequest]; ok {
			s = c.streams[streamID]
		}
	} else {
		s = c.streams[""] // standalone SSE stream
	}
	if responseTo.IsValid() {
		// Once we've responded to a request, disallow related messages by removing
		// the stream association. This also releases memory.
		delete(c.requestStreams, responseTo)
	}
	sessionClosed := c.isDone
	c.mu.Unlock()

	if s == nil {
		// The request was made in the context of an ongoing request, but that
		// request is complete.
		//
		// In the future, we could be less strict and allow the request to land on
		// the standalone SSE stream.
		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
	}
	if sessionClosed {
		return errors.New("session is closed")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.doneLocked() {
		// It's possible that the stream was completed in between getting s above,
		// and acquiring the stream lock. In order to avoid acquiring s.mu while
		// holding c.mu, we check the terminal condition again.
		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
	}
	// Perform accounting on responses.
	if responseTo.IsValid() {
		if _, ok := s.requests[responseTo]; !ok {
			panic(fmt.Sprintf("internal error: stream %v: response to untracked request %v", s.id, responseTo))
		}
		if s.id == "" {
			// This should be guaranteed not to happen by the stream resolution logic
			// above, but be defensive: we don't ever want to delete the standalone
			// stream.
			panic("internal error: response on standalone stream")
		}
		delete(s.requests, responseTo)
		if len(s.requests) == 0 {
			c.mu.Lock()
			delete(c.streams, s.id)
			c.mu.Unlock()
		}
	}

	// Delivery succeeds if the message is persisted to the event store, or
	// handed to an active HTTP response (or both).
	delivered := false
	if c.eventStore != nil {
		if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil {
			// TODO: report a side-channel error.
		} else {
			delivered = true
		}
	}
	if s.deliver != nil {
		if err := s.deliver(data, s.doneLocked()); err != nil {
			// TODO: report a side-channel error.
		} else {
			delivered = true
		}
	}
	if !delivered {
		return fmt.Errorf("%w: undelivered message", jsonrpc2.ErrRejected)
	}
	return nil
}

// Close implements the [Connection] interface.
func (c *streamableServerConn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.isDone {
		c.isDone = true
		close(c.done)
		if c.eventStore != nil {
			// TODO: find a way to plumb a context here, or an event store with a long-running
			// close operation can take arbitrary time. Alternative: impose a fixed timeout here.
			return c.eventStore.SessionClosed(context.TODO(), c.sessionID)
		}
	}
	return nil
}

// A StreamableClientTransport is a [Transport] that can communicate with an MCP
// endpoint serving the streamable HTTP transport defined by the 2025-03-26
// version of the spec.
type StreamableClientTransport struct {
	Endpoint   string
	HTTPClient *http.Client
	// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
	// It defaults to 5. To disable retries, use a negative number.
	MaxRetries int

	// TODO(rfindley): propose exporting these.

	// If strict is set, the transport is in 'strict mode', where any violation
	// of the MCP spec causes a failure.
	strict bool
	// If logger is set, it is used to log aspects of the transport, such as spec
	// violations that were ignored.
	logger *slog.Logger
}

// These settings are not (yet) exposed to the user in
// StreamableClientTransport.
const (
	// reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt.
	// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
	// It must be 1.0 or greater if MaxRetries is greater than 0.
	reconnectGrowFactor = 1.5
	// reconnectInitialDelay is the base delay for the first reconnect attempt.
	reconnectInitialDelay = 1 * time.Second
	// reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
	reconnectMaxDelay = 30 * time.Second
)

// Connect implements the [Transport] interface.
//
// The resulting [Connection] writes messages via POST requests to the
// transport URL with the Mcp-Session-Id header set, and reads messages from
// hanging requests.
//
// When closed, the connection issues a DELETE request to terminate the logical
// session.
func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) {
	client := t.HTTPClient
	if client == nil {
		client = http.DefaultClient
	}
	maxRetries := t.MaxRetries
	if maxRetries == 0 {
		maxRetries = 5
	} else if maxRetries < 0 {
		maxRetries = 0
	}
	// Create a new cancellable context that will manage the connection's lifecycle.
	// This is crucial for cleanly shutting down the background SSE listener by
	// cancelling its blocking network operations, which prevents hangs on exit.
	connCtx, cancel := context.WithCancel(ctx)
	conn := &streamableClientConn{
		url:        t.Endpoint,
		client:     client,
		incoming:   make(chan jsonrpc.Message, 10),
		done:       make(chan struct{}),
		maxRetries: maxRetries,
		strict:     t.strict,
		logger:     t.logger,
		ctx:        connCtx,
		cancel:     cancel,
		failed:     make(chan struct{}),
	}
	return conn, nil
}

type streamableClientConn struct {
	url        string
	client     *http.Client
	ctx        context.Context
	cancel     context.CancelFunc
	incoming   chan jsonrpc.Message
	maxRetries int
	strict     bool         // from [StreamableClientTransport.strict]
	logger     *slog.Logger // from [StreamableClientTransport.logger]

	// Guard calls to Close, as it may be called multiple times.
	closeOnce sync.Once
	closeErr  error
	done      chan struct{} // signal graceful termination

	// Logical reads are distributed across multiple http requests. Whenever any
	// of them fails to process their response, we must break the connection, by
	// failing the pending Read.
	//
	// Achieve this by storing the failure message, and signalling when reads are
	// broken. See also [streamableClientConn.fail] and
	// [streamableClientConn.failure].
	failOnce sync.Once
	_failure error
	failed   chan struct{} // signal failure

	// Guard the initialization state.
	mu                sync.Mutex
	initializedResult *InitializeResult
	sessionID         string
}

// errSessionMissing distinguishes if the session is known to not be present on
// the server (see [streamableClientConn.fail]).
//
// TODO(rfindley): should we expose this error value (and its corresponding
// API) to the user?
//
// The spec says that if the server returns 404, clients should reestablish
// a session. For now, we delegate that to the user, but do they need a way to
// differentiate a 'NotFound' error from other errors?
var errSessionMissing = errors.New("session not found")

var _ clientConnection = (*streamableClientConn)(nil)

// sessionUpdated records the initialize result and starts the standalone SSE
// listener once initialization is complete.
func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
	c.mu.Lock()
	c.initializedResult = state.InitializeResult
	c.mu.Unlock()

	// Start the standalone SSE stream as soon as we have the initialized
	// result.
	//
	// § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be
	// used to open an SSE stream, allowing the server to communicate to the
	// client, without the client first sending data via HTTP POST.
	//
	// We have to wait for initialized, because until we've received
	// initialized, we don't know whether the server requires a sessionID.
	//
	// § 2.5: A server using the Streamable HTTP transport MAY assign a session
	// ID at initialization time, by including it in an Mcp-Session-Id header
	// on the HTTP response containing the InitializeResult.
	c.connectStandaloneSSE()
}

// connectStandaloneSSE opens the standalone (GET) SSE stream and, on success,
// hands its response to a background handleSSE goroutine.
func (c *streamableClientConn) connectStandaloneSSE() {
	resp, err := c.connectSSE("")
	if err != nil {
		c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
		return
	}

	// [§2.2.3]: "The server MUST either return Content-Type:
	// text/event-stream in response to this HTTP GET, or else return HTTP
	// 405 Method Not Allowed, indicating that the server does not offer an
	// SSE stream at this endpoint."
	//
	// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
	if resp.StatusCode == http.StatusMethodNotAllowed {
		// The server doesn't support the standalone SSE stream.
		resp.Body.Close()
		return
	}
	if resp.StatusCode == http.StatusNotFound && !c.strict {
		// modelcontextprotocol/gosdk#393: some servers return NotFound instead
		// of MethodNotAllowed for the standalone SSE stream.
		//
		// Treat this like MethodNotAllowed in non-strict mode.
		if c.logger != nil {
			c.logger.Warn("got 404 instead of 405 for standalone SSE stream")
		}
		resp.Body.Close()
		return
	}
	summary := "standalone SSE stream"
	if err := c.checkResponse(summary, resp); err != nil {
		c.fail(err)
		return
	}
	go c.handleSSE(summary, resp, true, nil)
}

// fail handles an asynchronous error while reading.
//
// If err is non-nil, it is terminal, and subsequent (or pending) Reads will
// fail.
//
// If err wraps errSessionMissing, the failure indicates that the session is no
// longer present on the server, and no final DELETE will be performed when
// closing the connection.
func (c *streamableClientConn) fail(err error) {
	if err != nil {
		// Only the first failure is recorded; later failures are dropped.
		c.failOnce.Do(func() {
			c._failure = err
			close(c.failed)
		})
	}
}

// failure returns the recorded terminal error, or nil if none has occurred.
func (c *streamableClientConn) failure() error {
	select {
	case <-c.failed:
		return c._failure
	default:
		return nil
	}
}

func (c *streamableClientConn) SessionID() string {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.sessionID
}

// Read implements the [Connection] interface.
func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
	if err := c.failure(); err != nil {
		return nil, err
	}
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-c.failed:
		return nil, c.failure()
	case <-c.done:
		return nil, io.EOF
	case msg := <-c.incoming:
		return msg, nil
	}
}

// Write implements the [Connection] interface.
func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
	// A prior terminal failure poisons all subsequent writes.
	if err := c.failure(); err != nil {
		return err
	}

	var requestSummary string
	var isCall bool
	switch msg := msg.(type) {
	case *jsonrpc.Request:
		requestSummary = fmt.Sprintf("sending %q", msg.Method)
		isCall = msg.IsCall()
	case *jsonrpc.Response:
		requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
	default:
		panic("unreachable")
	}

	data, err := jsonrpc.EncodeMessage(msg)
	if err != nil {
		return fmt.Errorf("%s: %v", requestSummary, err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json, text/event-stream")
	c.setMCPHeaders(req)

	resp, err := c.client.Do(req)
	if err != nil {
		return fmt.Errorf("%s: %v", requestSummary, err)
	}

	// checkResponse closes resp.Body when it returns a non-nil error.
	if err := c.checkResponse(requestSummary, resp); err != nil {
		c.fail(err)
		return err
	}

	// Record the session ID assigned by the server (if any); a change of
	// session ID mid-session is an error.
	if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" {
		c.mu.Lock()
		hadSessionID := c.sessionID
		if hadSessionID == "" {
			c.sessionID = sessionID
		}
		c.mu.Unlock()
		if hadSessionID != "" && hadSessionID != sessionID {
			resp.Body.Close()
			return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID)
		}
	}
	// TODO(rfindley): this logic isn't quite right.
	// We should keep going even if the server returns 202, if we have a call.
	if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted {
		// [§2.1.4]: "If the input is a JSON-RPC response or notification:
		// If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body."
		//
		// [§2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
		resp.Body.Close()
		return nil
	} else if !isCall && !c.strict {
		// Some servers return 200, even with an empty json body.
		// Ignore this response in non-strict mode.
		if c.logger != nil {
			c.logger.Warn(fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode))
		}
		resp.Body.Close()
		return nil
	}

	// Dispatch the response body to a background handler based on its media
	// type; the handler owns (and closes) resp.Body.
	contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0])
	switch contentType {
	case "application/json":
		go c.handleJSON(requestSummary, resp)

	case "text/event-stream":
		var forCall *jsonrpc.Request
		if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() {
			forCall = jsonReq
		}
		// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
		go c.handleSSE(requestSummary, resp, false, forCall)

	default:
		resp.Body.Close()
		return fmt.Errorf("%s: unsupported content type %q", requestSummary, contentType)
	}
	return nil
}

// testAuth controls whether a fake Authorization header is added to outgoing requests.
// TODO: replace with a better mechanism when client-side auth is in place.
+var testAuth atomic.Bool + +func (c *streamableClientConn) setMCPHeaders(req *http.Request) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initializedResult != nil { + req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + } + if c.sessionID != "" { + req.Header.Set(sessionIDHeader, c.sessionID) + } + if testAuth.Load() { + req.Header.Set("Authorization", "Bearer foo") + } +} + +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) + return + } + select { + case c.incoming <- msg: + case <-c.done: + // The connection was closed by the client; exit gracefully. + } +} + +// handleSSE manages the lifecycle of an SSE connection. It can be either +// persistent (for the main GET listener) or temporary (for a POST response). +// +// If forCall is set, it is the call that initiated the stream, and the +// stream is complete when we receive its response. +func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) { + for { + // Connection was successful. Continue the loop with the new response. + // TODO: we should set a reasonable limit on the number of times we'll try + // getting a response for a given request. + // + // Eventually, if we don't get the response, we should stop trying and + // fail the request. + lastEventID, clientClosed := c.processStream(requestSummary, resp, forCall) + + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). 
+		if lastEventID == "" && !persistent {
+			return
+		}
+
+		// The stream was interrupted or ended by the server. Attempt to reconnect.
+		newResp, err := c.connectSSE(lastEventID)
+		if err != nil {
+			// All reconnection attempts failed: fail the connection.
+			c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
+			return
+		}
+		resp = newResp
+		if err := c.checkResponse(requestSummary, resp); err != nil {
+			c.fail(err)
+			return
+		}
+	}
+}
+
+// checkResponse checks the status code of the provided response, and
+// translates it into an error if the request was unsuccessful.
+//
+// The response body is closed if a non-nil error is returned.
+func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) {
+	defer func() {
+		if err != nil {
+			resp.Body.Close()
+		}
+	}()
+	// §2.5.3: "The server MAY terminate the session at any time, after
+	// which it MUST respond to requests containing that session ID with HTTP
+	// 404 Not Found."
+	if resp.StatusCode == http.StatusNotFound {
+		// Return an errSessionMissing to avoid sending a redundant DELETE when the
+		// session is already gone.
+		return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)
+	}
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode))
+	}
+	return nil
+}
+
+// processStream reads from a single response body, sending events to the
+// incoming channel. It returns the ID of the last processed event and a flag
+// indicating if the connection was closed by the client. If resp is nil, it
+// returns "", false.
+func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, clientClosed bool) { + defer resp.Body.Close() + for evt, err := range scanEvents(resp.Body) { + if err != nil { + // TODO: we should differentiate EOF from other errors here. + break + } + + if evt.ID != "" { + lastEventID = evt.ID + } + + msg, err := jsonrpc.DecodeMessage(evt.Data) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) + return "", true + } + + select { + case c.incoming <- msg: + if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { + // TODO: we should never get a response when forReq is nil (the standalone SSE request). + // We should detect this case. + if jsonResp.ID == forCall.ID { + return "", true + } + } + case <-c.done: + // The connection was closed by the client; exit gracefully. + return "", true + } + } + // The loop finished without an error, indicating the server closed the stream. + // + // If the lastEventID is "", the stream is not retryable and we should + // report a synthetic error for the call. + if lastEventID == "" && forCall != nil { + errmsg := &jsonrpc2.Response{ + ID: forCall.ID, + Error: fmt.Errorf("request terminated without response"), + } + select { + case c.incoming <- errmsg: + case <-c.done: + } + } + return lastEventID, false +} + +// connectSSE handles the logic of connecting a text/event-stream connection. +// +// If lastEventID is set, it is the last-event ID of a stream being resumed. +// +// If connection fails, connectSSE retries with an exponential backoff +// strategy. It returns a new, valid HTTP response if successful, or an error +// if all retries are exhausted. +func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, error) { + var finalErr error + // If lastEventID is set, we've already connected successfully once, so + // consider that to be the first attempt. 
+ attempt := 0 + if lastEventID != "" { + attempt = 1 + } + for ; attempt <= c.maxRetries; attempt++ { + select { + case <-c.done: + return nil, fmt.Errorf("connection closed by client during reconnect") + case <-time.After(calculateReconnectDelay(attempt)): + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) + if err != nil { + return nil, err + } + c.setMCPHeaders(req) + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + resp, err := c.client.Do(req) + if err != nil { + finalErr = err // Store the error and try again. + continue + } + return resp, nil + } + } + // If the loop completes, all retries have failed, or the client is closing. + if finalErr != nil { + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) + } + return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries) +} + +// Close implements the [Connection] interface. +func (c *streamableClientConn) Close() error { + c.closeOnce.Do(func() { + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. + } else { + req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil) + if err != nil { + c.closeErr = err + } else { + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err + } + } + } + + // Cancel any hanging network requests after cleanup. + c.cancel() + close(c.done) + }) + return c.closeErr +} + +// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. +func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } + // Calculate the exponential backoff using the grow factor. + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) + // Cap the backoffDuration at maxDelay. 
+ backoffDuration = min(backoffDuration, reconnectMaxDelay) + + // Use a full jitter using backoffDuration + jitter := rand.N(backoffDuration) + + return backoffDuration + jitter +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go new file mode 100644 index 0000000000..12b02b7bb0 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go @@ -0,0 +1,103 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/jsonschema-go/jsonschema" +) + +// A ToolHandler handles a call to tools/call. +// +// This is a low-level API, for use with [Server.AddTool]. It does not do any +// pre- or post-processing of the request or result: the params contain raw +// arguments, no input validation is performed, and the result is returned to +// the user as-is, without any validation of the output. +// +// Most users will write a [ToolHandlerFor] and install it with the generic +// [AddTool] function. +// +// If ToolHandler returns an error, it is treated as a protocol error. By +// contrast, [ToolHandlerFor] automatically populates [CallToolResult.IsError] +// and [CallToolResult.Content] accordingly. +type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) + +// A ToolHandlerFor handles a call to tools/call with typed arguments and results. +// +// Use [AddTool] to add a ToolHandlerFor to a server. +// +// Unlike [ToolHandler], [ToolHandlerFor] provides significant functionality +// out of the box, and enforces that the tool conforms to the MCP spec: +// - The In type provides a default input schema for the tool, though it may +// be overridden in [AddTool]. +// - The input value is automatically unmarshaled from req.Params.Arguments. 
+//   - The input value is automatically validated against its input schema.
+//     Invalid input is rejected before getting to the handler.
+//   - If the Out type is not the empty interface [any], it provides the
+//     default output schema for the tool (which again may be overridden in
+//     [AddTool]).
+//   - The Out value is used to populate result.StructuredOutput.
+//   - If [CallToolResult.Content] is unset, it is populated with the JSON
+//     content of the output.
+//   - An error result is treated as a tool error, rather than a protocol
+//     error, and is therefore packed into CallToolResult.Content, with
+//     [IsError] set.
+//
+// For these reasons, most users can ignore the [CallToolRequest] argument and
+// [CallToolResult] return values entirely. In fact, it is permissible to
+// return a nil CallToolResult, if you only care about returning an output value
+// or error. The effective result will be populated as described above.
+type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error)
+
+// A serverTool is a tool definition that is bound to a tool handler.
+type serverTool struct {
+	tool    *Tool
+	handler ToolHandler
+}
+
+// applySchema validates whether data is valid JSON according to the provided
+// schema, after applying schema defaults.
+//
+// Returns the JSON value augmented with defaults.
+func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
+	// TODO: use reflection to create the struct type to unmarshal into.
+	// Separate validation from assignment.
+
+	// Use default JSON marshalling for validation.
+	//
+	// This avoids inconsistent representation due to custom marshallers, such as
+	// time.Time (issue #449).
+	//
+	// Additionally, unmarshalling into a map ensures that the resulting JSON is
+	// at least {}, even if data is empty.
For example, arguments is technically + // an optional property of callToolParams, and we still want to apply the + // defaults in this case. + // + // TODO(rfindley): in which cases can resolved be nil? + if resolved != nil { + v := make(map[string]any) + if len(data) > 0 { + if err := json.Unmarshal(data, &v); err != nil { + return nil, fmt.Errorf("unmarshaling arguments: %w", err) + } + } + if err := resolved.ApplyDefaults(&v); err != nil { + return nil, fmt.Errorf("applying schema defaults:\n%w", err) + } + if err := resolved.Validate(&v); err != nil { + return nil, err + } + // We must re-marshal with the default values applied. + var err error + data, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("marshalling with defaults: %v", err) + } + } + return data, nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go new file mode 100644 index 0000000000..cacd65fd53 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go @@ -0,0 +1,643 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// ErrConnectionClosed is returned when sending a message to a connection that +// is closed or in the process of closing. +var ErrConnectionClosed = errors.New("connection closed") + +// A Transport is used to create a bidirectional connection between MCP client +// and server. +// +// Transports should be used for at most one call to [Server.Connect] or +// [Client.Connect]. 
+type Transport interface { + // Connect returns the logical JSON-RPC connection.. + // + // It is called exactly once by [Server.Connect] or [Client.Connect]. + Connect(ctx context.Context) (Connection, error) +} + +// A Connection is a logical bidirectional JSON-RPC connection. +type Connection interface { + // Read reads the next message to process off the connection. + // + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. + Read(context.Context) (jsonrpc.Message, error) + + // Write writes a new message to the connection. + // + // Write may be called concurrently, as calls or responses may occur + // concurrently in user code. + Write(context.Context, jsonrpc.Message) error + + // Close closes the connection. It is implicitly called whenever a Read or + // Write fails. + // + // Close may be called multiple times, potentially concurrently. + Close() error + + // TODO(#148): remove SessionID from this interface. + SessionID() string +} + +// A ClientConnection is a [Connection] that is specific to the MCP client. +// +// If client connections implement this interface, they may receive information +// about changes to the client session. +// +// TODO: should this interface be exported? +type clientConnection interface { + Connection + + // SessionUpdated is called whenever the client session state changes. + sessionUpdated(clientSessionState) +} + +// A serverConnection is a Connection that is specific to the MCP server. +// +// If server connections implement this interface, they receive information +// about changes to the server session. +// +// TODO: should this interface be exported? +type serverConnection interface { + Connection + sessionUpdated(ServerSessionState) +} + +// A StdioTransport is a [Transport] that communicates over stdin/stdout using +// newline-delimited JSON. +type StdioTransport struct{} + +// Connect implements the [Transport] interface. 
+func (*StdioTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{os.Stdin, nopCloserWriter{os.Stdout}}), nil +} + +// nopCloserWriter is an io.WriteCloser with a trivial Close method. +type nopCloserWriter struct { + io.Writer +} + +func (nopCloserWriter) Close() error { return nil } + +// An IOTransport is a [Transport] that communicates over separate +// io.ReadCloser and io.WriteCloser using newline-delimited JSON. +type IOTransport struct { + Reader io.ReadCloser + Writer io.WriteCloser +} + +// Connect implements the [Transport] interface. +func (t *IOTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{t.Reader, t.Writer}), nil +} + +// An InMemoryTransport is a [Transport] that communicates over an in-memory +// network connection, using newline-delimited JSON. +// +// InMemoryTransports should be constructed using [NewInMemoryTransports], +// which returns two transports connected to each other. +type InMemoryTransport struct { + rwc io.ReadWriteCloser +} + +// Connect implements the [Transport] interface. +func (t *InMemoryTransport) Connect(context.Context) (Connection, error) { + return newIOConn(t.rwc), nil +} + +// NewInMemoryTransports returns two [InMemoryTransport] objects that connect +// to each other. +// +// The resulting transports are symmetrical: use either to connect to a server, +// and then the other to connect to a client. Servers must be connected before +// clients, as the client initializes the MCP session during connection. +func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { + c1, c2 := net.Pipe() + return &InMemoryTransport{c1}, &InMemoryTransport{c2} +} + +type binder[T handler, State any] interface { + // TODO(rfindley): the bind API has gotten too complicated. Simplify. 
+ bind(Connection, *jsonrpc2.Connection, State, func()) T + disconnect(T) +} + +type handler interface { + handle(ctx context.Context, req *jsonrpc.Request) (any, error) +} + +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { + var zero H + mcpConn, err := t.Connect(ctx) + if err != nil { + return zero, err + } + // If logging is configured, write message logs. + reader, writer := jsonrpc2.Reader(mcpConn), jsonrpc2.Writer(mcpConn) + var ( + h H + preempter canceller + ) + bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { + h = b.bind(mcpConn, conn, s, onClose) + preempter.conn = conn + return jsonrpc2.HandlerFunc(h.handle) + } + _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ + Reader: reader, + Writer: writer, + Closer: mcpConn, + Bind: bind, + Preempter: &preempter, + OnDone: func() { + b.disconnect(h) + }, + OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, + }) + assert(preempter.conn != nil, "unbound preempter") + return h, nil +} + +// A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP +// cancelled notifications. +type canceller struct { + conn *jsonrpc2.Connection +} + +// Preempt implements [jsonrpc2.Preempter]. +func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { + if req.Method == notificationCancelled { + var params CancelledParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, err + } + id, err := jsonrpc2.MakeID(params.RequestID) + if err != nil { + return nil, err + } + go c.conn.Cancel(id) + } + return nil, jsonrpc2.ErrNotHandled +} + +// call executes and awaits a jsonrpc2 call on the given connection, +// translating errors into the mcp domain. 
+func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error { + // TODO: the "%w"s in this function effectively make jsonrpc2.WireError part of the API. + // Consider alternatives. + call := conn.Call(ctx, method, params) + err := call.Await(ctx, result) + switch { + case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): + return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) + case ctx.Err() != nil: + // Notify the peer of cancellation. + err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{ + Reason: ctx.Err().Error(), + RequestID: call.ID().Raw(), + }) + return errors.Join(ctx.Err(), err) + case err != nil: + return fmt.Errorf("calling %q: %w", method, err) + } + return nil +} + +// A LoggingTransport is a [Transport] that delegates to another transport, +// writing RPC logs to an io.Writer. +type LoggingTransport struct { + Transport Transport + Writer io.Writer +} + +// Connect connects the underlying transport, returning a [Connection] that writes +// logs to the configured destination. +func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { + delegate, err := t.Transport.Connect(ctx) + if err != nil { + return nil, err + } + return &loggingConn{delegate: delegate, w: t.Writer}, nil +} + +type loggingConn struct { + delegate Connection + + mu sync.Mutex + w io.Writer +} + +func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } + +// Read is a stream middleware that logs incoming messages. 
+func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { + msg, err := s.delegate.Read(ctx) + + if err != nil { + s.mu.Lock() + fmt.Fprintf(s.w, "read error: %v\n", err) + s.mu.Unlock() + } else { + data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() + if err != nil { + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) + } + fmt.Fprintf(s.w, "read: %s\n", string(data)) + s.mu.Unlock() + } + + return msg, err +} + +// Write is a stream middleware that logs outgoing messages. +func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { + err := s.delegate.Write(ctx, msg) + if err != nil { + s.mu.Lock() + fmt.Fprintf(s.w, "write error: %v\n", err) + s.mu.Unlock() + } else { + data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() + if err != nil { + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) + } + fmt.Fprintf(s.w, "write: %s\n", string(data)) + s.mu.Unlock() + } + return err +} + +func (s *loggingConn) Close() error { + return s.delegate.Close() +} + +// A rwc binds an io.ReadCloser and io.WriteCloser together to create an +// io.ReadWriteCloser. +type rwc struct { + rc io.ReadCloser + wc io.WriteCloser +} + +func (r rwc) Read(p []byte) (n int, err error) { + return r.rc.Read(p) +} + +func (r rwc) Write(p []byte) (n int, err error) { + return r.wc.Write(p) +} + +func (r rwc) Close() error { + rcErr := r.rc.Close() + + var wcErr error + if r.wc != nil { // we only allow a nil writer in unit tests + wcErr = r.wc.Close() + } + + return errors.Join(rcErr, wcErr) +} + +// An ioConn is a transport that delimits messages with newlines across +// a bidirectional stream, and supports jsonrpc.2 message batching. +// +// See https://github.com/ndjson/ndjson-spec for discussion of newline +// delimited JSON. +// +// See [msgBatch] for more discussion of message batching. +type ioConn struct { + protocolVersion string // negotiated version, set during session initialization. 
+
+	writeMu sync.Mutex         // guards Write, which must be concurrency safe.
+	rwc     io.ReadWriteCloser // the underlying stream
+
+	// incoming receives messages from the read loop started in [newIOConn].
+	incoming <-chan msgOrErr
+
+	// If outgoingBatch has a positive capacity, it will be used to batch requests
+	// and notifications before sending.
+	outgoingBatch []jsonrpc.Message
+
+	// Unread messages in the last batch. Since reads are serialized, there is no
+	// need to guard here.
+	queue []jsonrpc.Message
+
+	// batches correlate incoming requests to the batch in which they arrived.
+	// Since writes may be concurrent to reads, we need to guard this with a mutex.
+	batchMu sync.Mutex
+	batches map[jsonrpc2.ID]*msgBatch // lazily allocated
+
+	closeOnce sync.Once
+	closed    chan struct{}
+	closeErr  error
+}
+
+type msgOrErr struct {
+	msg json.RawMessage
+	err error
+}
+
+func newIOConn(rwc io.ReadWriteCloser) *ioConn {
+	var (
+		incoming = make(chan msgOrErr)
+		closed   = make(chan struct{})
+	)
+	// Start a goroutine for reads, so that we can select on the incoming channel
+	// in [ioConn.Read] and unblock the read as soon as Close is called (see #224).
+	//
+	// This leaks a goroutine if rwc.Read does not unblock after it is closed,
+	// but that is unavoidable since AFAIK there is no (easy and portable) way to
+	// guarantee that reads of stdin are unblocked when closed.
+	go func() {
+		dec := json.NewDecoder(rwc)
+		for {
+			var raw json.RawMessage
+			err := dec.Decode(&raw)
+			// If decoding was successful, check for trailing data at the end of the stream.
+			if err == nil {
+				// Read the next byte to check if there is trailing data.
+				var tr [1]byte
+				if n, readErr := dec.Buffered().Read(tr[:]); n > 0 {
+					// If read byte is not a newline, it is an error.
+ if tr[0] != '\n' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } + select { + case incoming <- msgOrErr{msg: raw, err: err}: + case <-closed: + return + } + if err != nil { + return + } + } + }() + return &ioConn{ + rwc: rwc, + incoming: incoming, + closed: closed, + } +} + +func (c *ioConn) SessionID() string { return "" } + +func (c *ioConn) sessionUpdated(state ServerSessionState) { + protocolVersion := "" + if state.InitializeParams != nil { + protocolVersion = state.InitializeParams.ProtocolVersion + } + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + c.protocolVersion = negotiatedVersion(protocolVersion) +} + +// addBatch records a msgBatch for an incoming batch payload. +// It returns an error if batch is malformed, containing previously seen IDs. +// +// See [msgBatch] for more. +func (t *ioConn) addBatch(batch *msgBatch) error { + t.batchMu.Lock() + defer t.batchMu.Unlock() + for id := range batch.unresolved { + if _, ok := t.batches[id]; ok { + return fmt.Errorf("%w: batch contains previously seen request %v", jsonrpc2.ErrInvalidRequest, id.Raw()) + } + } + for id := range batch.unresolved { + if t.batches == nil { + t.batches = make(map[jsonrpc2.ID]*msgBatch) + } + t.batches[id] = batch + } + return nil +} + +// updateBatch records a response in the message batch tracking the +// corresponding incoming call, if any. +// +// The second result reports whether resp was part of a batch. If this is true, +// the first result is nil if the batch is still incomplete, or the full set of +// batch responses if resp completed the batch. 
+func (t *ioConn) updateBatch(resp *jsonrpc.Response) ([]*jsonrpc.Response, bool) { + t.batchMu.Lock() + defer t.batchMu.Unlock() + + if batch, ok := t.batches[resp.ID]; ok { + idx, ok := batch.unresolved[resp.ID] + if !ok { + panic("internal error: inconsistent batches") + } + batch.responses[idx] = resp + delete(batch.unresolved, resp.ID) + delete(t.batches, resp.ID) + if len(batch.unresolved) == 0 { + return batch.responses, true + } + return nil, true + } + return nil, false +} + +// A msgBatch records information about an incoming batch of jsonrpc.2 calls. +// +// The jsonrpc.2 spec (https://www.jsonrpc.org/specification#batch) says: +// +// "The Server should respond with an Array containing the corresponding +// Response objects, after all of the batch Request objects have been +// processed. A Response object SHOULD exist for each Request object, except +// that there SHOULD NOT be any Response objects for notifications. The Server +// MAY process a batch rpc call as a set of concurrent tasks, processing them +// in any order and with any width of parallelism." +// +// Therefore, a msgBatch keeps track of outstanding calls and their responses. +// When there are no unresolved calls, the response payload is sent. +type msgBatch struct { + unresolved map[jsonrpc2.ID]int + responses []*jsonrpc.Response +} + +func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { + // As a matter of principle, enforce that reads on a closed context return an + // error. 
+ select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if len(t.queue) > 0 { + next := t.queue[0] + t.queue = t.queue[1:] + return next, nil + } + + var raw json.RawMessage + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case v := <-t.incoming: + if v.err != nil { + return nil, v.err + } + raw = v.msg + + case <-t.closed: + return nil, io.EOF + } + + msgs, batch, err := readBatch(raw) + if err != nil { + return nil, err + } + if batch && t.protocolVersion >= protocolVersion20250618 { + return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion) + } + + t.queue = msgs[1:] + + if batch { + var respBatch *msgBatch // track incoming requests in the batch + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + if respBatch == nil { + respBatch = &msgBatch{ + unresolved: make(map[jsonrpc2.ID]int), + } + } + if _, ok := respBatch.unresolved[req.ID]; ok { + return nil, fmt.Errorf("duplicate message ID %q", req.ID) + } + respBatch.unresolved[req.ID] = len(respBatch.responses) + respBatch.responses = append(respBatch.responses, nil) + } + } + if respBatch != nil { + // The batch contains one or more incoming requests to track. + if err := t.addBatch(respBatch); err != nil { + return nil, err + } + } + } + return msgs[0], err +} + +// readBatch reads batch data, which may be either a single JSON-RPC message, +// or an array of JSON-RPC messages. +func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { + // Try to read an array of messages first. + var rawBatch []json.RawMessage + if err := json.Unmarshal(data, &rawBatch); err == nil { + if len(rawBatch) == 0 { + return nil, true, fmt.Errorf("empty batch") + } + for _, raw := range rawBatch { + msg, err := jsonrpc2.DecodeMessage(raw) + if err != nil { + return nil, true, err + } + msgs = append(msgs, msg) + } + return msgs, true, nil + } + // Try again with a single message. 
+ msg, err := jsonrpc2.DecodeMessage(data) + return []jsonrpc.Message{msg}, false, err +} + +func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // As in [ioConn.Read], enforce that Writes on a closed context are an error. + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + t.writeMu.Lock() + defer t.writeMu.Unlock() + + // Batching support: if msg is a Response, it may have completed a batch, so + // check that first. Otherwise, it is a request or notification, and we may + // want to collect it into a batch before sending, if we're configured to use + // outgoing batches. + if resp, ok := msg.(*jsonrpc.Response); ok { + if batch, ok := t.updateBatch(resp); ok { + if len(batch) > 0 { + data, err := marshalMessages(batch) + if err != nil { + return err + } + data = append(data, '\n') + _, err = t.rwc.Write(data) + return err + } + return nil + } + } else if len(t.outgoingBatch) < cap(t.outgoingBatch) { + t.outgoingBatch = append(t.outgoingBatch, msg) + if len(t.outgoingBatch) == cap(t.outgoingBatch) { + data, err := marshalMessages(t.outgoingBatch) + t.outgoingBatch = t.outgoingBatch[:0] + if err != nil { + return err + } + data = append(data, '\n') + _, err = t.rwc.Write(data) + return err + } + return nil + } + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + data = append(data, '\n') // newline delimited + _, err = t.rwc.Write(data) + return err +} + +func (t *ioConn) Close() error { + t.closeOnce.Do(func() { + t.closeErr = t.rwc.Close() + close(t.closed) + }) + return t.closeErr +} + +func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { + var rawMsgs []json.RawMessage + for _, msg := range msgs { + raw, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return nil, fmt.Errorf("encoding batch message: %w", err) + } + rawMsgs = append(rawMsgs, raw) + } + return json.Marshal(rawMsgs) +} diff --git 
a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go new file mode 100644 index 0000000000..5ada466e50 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go @@ -0,0 +1,43 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "crypto/rand" + "encoding/json" +) + +func assert(cond bool, msg string) { + if !cond { + panic(msg) + } +} + +// Copied from crypto/rand. +// TODO: once 1.24 is assured, just use crypto/rand. +const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" + +func randText() string { + // ⌈log₃₂ 2¹²⁸⌉ = 26 chars + src := make([]byte, 26) + rand.Read(src) + for i := range src { + src[i] = base32alphabet[src[i]%32] + } + return string(src) +} + +// remarshal marshals from to JSON, and then unmarshals into to, which must be +// a pointer type. 
+func remarshal(from, to any) error { + data, err := json.Marshal(from) + if err != nil { + return err + } + if err := json.Unmarshal(data, to); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/openai/openai-go/.gitignore b/vendor/github.com/openai/openai-go/.gitignore new file mode 100644 index 0000000000..c6d0501519 --- /dev/null +++ b/vendor/github.com/openai/openai-go/.gitignore @@ -0,0 +1,4 @@ +.prism.log +codegen.log +Brewfile.lock.json +.idea/ diff --git a/vendor/github.com/openai/openai-go/.release-please-manifest.json b/vendor/github.com/openai/openai-go/.release-please-manifest.json new file mode 100644 index 0000000000..de0960aba8 --- /dev/null +++ b/vendor/github.com/openai/openai-go/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "1.12.0" +} \ No newline at end of file diff --git a/vendor/github.com/openai/openai-go/.stats.yml b/vendor/github.com/openai/openai-go/.stats.yml new file mode 100644 index 0000000000..2f2ae96cde --- /dev/null +++ b/vendor/github.com/openai/openai-go/.stats.yml @@ -0,0 +1,4 @@ +configured_endpoints: 97 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai%2Fopenai-721e6ccaa72205ee14c71f8163129920464fb814b95d3df9567a9476bbd9b7fb.yml +openapi_spec_hash: 2115413a21df8b5bf9e4552a74df4312 +config_hash: 9606bb315a193bfd8da0459040143242 diff --git a/vendor/github.com/openai/openai-go/Brewfile b/vendor/github.com/openai/openai-go/Brewfile new file mode 100644 index 0000000000..577e34a4b3 --- /dev/null +++ b/vendor/github.com/openai/openai-go/Brewfile @@ -0,0 +1 @@ +brew "go" diff --git a/vendor/github.com/openai/openai-go/CHANGELOG.md b/vendor/github.com/openai/openai-go/CHANGELOG.md new file mode 100644 index 0000000000..16a13929b2 --- /dev/null +++ b/vendor/github.com/openai/openai-go/CHANGELOG.md @@ -0,0 +1,473 @@ +# Changelog + +## 1.12.0 (2025-07-30) + +Full Changelog: [v1.11.1...v1.12.0](https://github.com/openai/openai-go/compare/v1.11.1...v1.12.0) + +### 
Features + +* **api:** manual updates ([16312ea](https://github.com/openai/openai-go/commit/16312ea2fea76c7cd2db4f38dfa10e0839f52d3e)) + + +### Chores + +* **client:** refactor streaming slightly to better future proof it ([0b9cb85](https://github.com/openai/openai-go/commit/0b9cb85a6bf0f2386e5db13aed34fbfad645efbe)) + +## 1.11.1 (2025-07-22) + +Full Changelog: [v1.11.0...v1.11.1](https://github.com/openai/openai-go/compare/v1.11.0...v1.11.1) + +### Bug Fixes + +* **client:** process custom base url ahead of time ([cc1c23e](https://github.com/openai/openai-go/commit/cc1c23e3b1f4645004cb07b75816e3df445e73df)) + + +### Chores + +* **api:** event shapes more accurate ([2acd10d](https://github.com/openai/openai-go/commit/2acd10df4df52d1954d9ee3a98e5a4e56531533b)) + +## 1.11.0 (2025-07-16) + +Full Changelog: [v1.10.3...v1.11.0](https://github.com/openai/openai-go/compare/v1.10.3...v1.11.0) + +### Features + +* **api:** manual updates ([97ed7fd](https://github.com/openai/openai-go/commit/97ed7fd1d432ad0144ec76bcebb61c9aaa1148de)) + +## 1.10.3 (2025-07-15) + +Full Changelog: [v1.10.2...v1.10.3](https://github.com/openai/openai-go/compare/v1.10.2...v1.10.3) + +## 1.10.2 (2025-07-15) + +Full Changelog: [v1.10.1...v1.10.2](https://github.com/openai/openai-go/compare/v1.10.1...v1.10.2) + +### Chores + +* **api:** update realtime specs, build config ([3d2afda](https://github.com/openai/openai-go/commit/3d2afda006bd1f9e7ebde27b2873efa67e5e480d)) + +## 1.10.1 (2025-07-11) + +Full Changelog: [v1.10.0...v1.10.1](https://github.com/openai/openai-go/compare/v1.10.0...v1.10.1) + +### Chores + +* **api:** specification cleanup ([5dbf6d2](https://github.com/openai/openai-go/commit/5dbf6d2cebe770d980db7888d705d1642ccd9cbc)) +* lint tests in subpackages ([02f440d](https://github.com/openai/openai-go/commit/02f440dc6d899d7816b9fec9c47c09b393a7dd6c)) + +## 1.10.0 (2025-07-10) + +Full Changelog: [v1.9.0...v1.10.0](https://github.com/openai/openai-go/compare/v1.9.0...v1.10.0) + +### Features 
+ +* **api:** add file_url, fix event ID ([cb33971](https://github.com/openai/openai-go/commit/cb339714b65249844a87009192b2cf1508329673)) + +## 1.9.0 (2025-07-10) + +Full Changelog: [v1.8.3...v1.9.0](https://github.com/openai/openai-go/compare/v1.8.3...v1.9.0) + +### Features + +* **client:** expand max streaming buffer size ([44390c8](https://github.com/openai/openai-go/commit/44390c81fdf33144f088b3ee8fef02269634dbe9)) + +## 1.8.3 (2025-07-08) + +Full Changelog: [v1.8.2...v1.8.3](https://github.com/openai/openai-go/compare/v1.8.2...v1.8.3) + +### Chores + +* **ci:** only run for pushes and fork pull requests ([d6aab99](https://github.com/openai/openai-go/commit/d6aab99dadf267201add9812ba34ab2d5c70e0f4)) +* **internal:** fix lint script for tests ([9c0a745](https://github.com/openai/openai-go/commit/9c0a74553c57ea5c29fb55f5ca2e122ca96031a4)) +* lint tests ([2bd38d2](https://github.com/openai/openai-go/commit/2bd38d248cf2097254d1821a44c87827805732d1)) + +## 1.8.2 (2025-06-27) + +Full Changelog: [v1.8.1...v1.8.2](https://github.com/openai/openai-go/compare/v1.8.1...v1.8.2) + +### Bug Fixes + +* don't try to deserialize as json when ResponseBodyInto is []byte ([74ad0f8](https://github.com/openai/openai-go/commit/74ad0f8fab0f956234503a9ba26fbd395944dcf8)) +* **pagination:** check if page data is empty in GetNextPage ([c9becdc](https://github.com/openai/openai-go/commit/c9becdc9908f2a1961160837c6ab8cd9064e7854)) + +## 1.8.1 (2025-06-26) + +Full Changelog: [v1.8.0...v1.8.1](https://github.com/openai/openai-go/compare/v1.8.0...v1.8.1) + +### Chores + +* **api:** remove unsupported property ([e22316a](https://github.com/openai/openai-go/commit/e22316adcd8f2c5aa672b12453cbd287de0e1878)) +* **docs:** update README to include links to docs on Webhooks ([7bb8f85](https://github.com/openai/openai-go/commit/7bb8f8549fdd98997b1d145cbae98ff0146b4e43)) + +## 1.8.0 (2025-06-26) + +Full Changelog: [v1.7.0...v1.8.0](https://github.com/openai/openai-go/compare/v1.7.0...v1.8.0) + +### 
Features + +* **api:** webhook and deep research support ([f6a7e7d](https://github.com/openai/openai-go/commit/f6a7e7dcd8801facc4f8d981f1ca43786c10de1e)) + + +### Chores + +* **internal:** add tests for breaking change detection ([339522d](https://github.com/openai/openai-go/commit/339522d38cd31b0753a8df37b8924f7e7dfb0b1d)) + +## 1.7.0 (2025-06-23) + +Full Changelog: [v1.6.0...v1.7.0](https://github.com/openai/openai-go/compare/v1.6.0...v1.7.0) + +### Features + +* **api:** make model and inputs not required to create response ([19f0b76](https://github.com/openai/openai-go/commit/19f0b76378d35b3d81c60c85bf2e64d6bf85b9c2)) +* **api:** update api shapes for usage and code interpreter ([d24d42c](https://github.com/openai/openai-go/commit/d24d42cba60e565627e8ffb1cac63a5085ddb6da)) +* **client:** add escape hatch for null slice & maps ([9c633d6](https://github.com/openai/openai-go/commit/9c633d6f1dbcc0b153f42f831ee7e13d6fe62296)) + + +### Chores + +* fix documentation of null map ([8f3a134](https://github.com/openai/openai-go/commit/8f3a134e500b1b7791ab855adaef2d7b10d2d1c3)) + +## 1.6.0 (2025-06-17) + +Full Changelog: [v1.5.0...v1.6.0](https://github.com/openai/openai-go/compare/v1.5.0...v1.6.0) + +### Features + +* **api:** add reusable prompt IDs ([280c698](https://github.com/openai/openai-go/commit/280c698015eba5f6bd47e2fce038eb401f6ef0f2)) +* **api:** manual updates ([740f840](https://github.com/openai/openai-go/commit/740f84006ac283a25f5ad96aaf845a3c8a51c6ac)) +* **client:** add debug log helper ([5715c49](https://github.com/openai/openai-go/commit/5715c491c483f8dab4ea2a900c400384f6810024)) + + +### Chores + +* **ci:** enable for pull requests ([9ed793a](https://github.com/openai/openai-go/commit/9ed793a51010423db464a7b7bd263d2fd275967f)) + +## 1.5.0 (2025-06-10) + +Full Changelog: [v1.4.0...v1.5.0](https://github.com/openai/openai-go/compare/v1.4.0...v1.5.0) + +### Features + +* **api:** Add o3-pro model IDs 
([3bbd0b8](https://github.com/openai/openai-go/commit/3bbd0b8f09030a6c571900d444742c4fc2a3c211)) + +## 1.4.0 (2025-06-09) + +Full Changelog: [v1.3.0...v1.4.0](https://github.com/openai/openai-go/compare/v1.3.0...v1.4.0) + +### Features + +* **client:** allow overriding unions ([27c6299](https://github.com/openai/openai-go/commit/27c6299cb4ac275c6542b5691d81b795e65eeff6)) + + +### Bug Fixes + +* **client:** cast to raw message when converting to params ([a3282b0](https://github.com/openai/openai-go/commit/a3282b01a8d9a2c0cd04f24b298bf2ffcd160ebd)) + +## 1.3.0 (2025-06-03) + +Full Changelog: [v1.2.1...v1.3.0](https://github.com/openai/openai-go/compare/v1.2.1...v1.3.0) + +### Features + +* **api:** add new realtime and audio models, realtime session options ([8b8f62b](https://github.com/openai/openai-go/commit/8b8f62b8e185f3fe4aaa99e892df5d35638931a1)) + +## 1.2.1 (2025-06-02) + +Full Changelog: [v1.2.0...v1.2.1](https://github.com/openai/openai-go/compare/v1.2.0...v1.2.1) + +### Bug Fixes + +* **api:** Fix evals and code interpreter interfaces ([7e244c7](https://github.com/openai/openai-go/commit/7e244c73caad6b4768cced9a798452f03b1165c8)) +* fix error ([a200fca](https://github.com/openai/openai-go/commit/a200fca92c3fa413cf724f424077d1537fa2ca3e)) + + +### Chores + +* make go mod tidy continue on error ([48f41c2](https://github.com/openai/openai-go/commit/48f41c2993bf6181018da859ae759951261f9ee2)) + +## 1.2.0 (2025-05-29) + +Full Changelog: [v1.1.0...v1.2.0](https://github.com/openai/openai-go/compare/v1.1.0...v1.2.0) + +### Features + +* **api:** Config update for pakrym-stream-param ([84d59d5](https://github.com/openai/openai-go/commit/84d59d5cbc7521ddcc04435317903fd4ec3d17f6)) + + +### Bug Fixes + +* **client:** return binary content from `get /containers/{container_id}/files/{file_id}/content` ([f8c8de1](https://github.com/openai/openai-go/commit/f8c8de18b720b224267d54da53d7d919ed0fdff3)) + + +### Chores + +* deprecate Assistants API 
([027470e](https://github.com/openai/openai-go/commit/027470e066ea6bbca1aeeb4fb9a8a3430babb84c)) +* **internal:** fix release workflows ([fd46533](https://github.com/openai/openai-go/commit/fd4653316312755ccab7435fca9fb0a2d8bf8fbb)) + +## 1.1.0 (2025-05-22) + +Full Changelog: [v1.0.0...v1.1.0](https://github.com/openai/openai-go/compare/v1.0.0...v1.1.0) + +### Features + +* **api:** add container endpoint ([2bd777d](https://github.com/openai/openai-go/commit/2bd777d6813b5dfcd3a2d339047a944c478dcd64)) +* **api:** new API tools ([e7e2123](https://github.com/openai/openai-go/commit/e7e2123de7cafef515e07adde6edd45a7035b610)) +* **api:** new streaming helpers for background responses ([422a0db](https://github.com/openai/openai-go/commit/422a0db3c674135e23dd200f5d8d785bd0be33e6)) + + +### Chores + +* **docs:** grammar improvements ([f4b23dd](https://github.com/openai/openai-go/commit/f4b23dd31facfc8839310854521b48060ef76be2)) +* improve devcontainer setup ([dfdaeec](https://github.com/openai/openai-go/commit/dfdaeec2d6dd5cd679514d60c49b68c5df9e1b1e)) + +## 1.0.0 (2025-05-19) + +Full Changelog: [v0.1.0-beta.11...v1.0.0](https://github.com/openai/openai-go/compare/v0.1.0-beta.11...v1.0.0) + +### ⚠ BREAKING CHANGES + +* **client:** rename file array param variant +* **api:** improve naming and remove assistants +* **accumulator:** update casing ([#401](https://github.com/openai/openai-go/issues/401)) + +### Features + +* **api:** improve naming and remove assistants ([4c623b8](https://github.com/openai/openai-go/commit/4c623b88a9025db1961cc57985eb7374342f43e7)) + + +### Bug Fixes + +* **accumulator:** update casing ([#401](https://github.com/openai/openai-go/issues/401)) ([d59453c](https://github.com/openai/openai-go/commit/d59453c95b89fdd0b51305778dec0a39ce3a9d2a)) +* **client:** correctly set stream key for multipart ([0ec68f0](https://github.com/openai/openai-go/commit/0ec68f0d779e7726931b1115eca9ae81eab59ba8)) +* **client:** don't panic on marshal with extra null field 
([9c15332](https://github.com/openai/openai-go/commit/9c153320272d212beaa516d4c70d54ae8053a958)) +* **client:** increase max stream buffer size ([9456455](https://github.com/openai/openai-go/commit/945645559c5d68d9e28cf445d9c3b83e5fc6bd35)) +* **client:** rename file array param variant ([4cfcf86](https://github.com/openai/openai-go/commit/4cfcf869280e7531fbbc8c00db0dd9271d07c423)) +* **client:** use scanner for streaming ([aa58806](https://github.com/openai/openai-go/commit/aa58806bffc3aed68425c480414ddbb4dac3fa78)) + + +### Chores + +* **docs:** typo fix ([#400](https://github.com/openai/openai-go/issues/400)) ([bececf2](https://github.com/openai/openai-go/commit/bececf24cd0324b7c991b7d7f1d3eff6bf71f996)) +* **examples:** migrate enum ([#447](https://github.com/openai/openai-go/issues/447)) ([814dd8b](https://github.com/openai/openai-go/commit/814dd8b6cfe4eeb535dc8ecd161a409ea2eb6698)) +* **examples:** migrate to latest version ([#444](https://github.com/openai/openai-go/issues/444)) ([1c8754f](https://github.com/openai/openai-go/commit/1c8754ff905ed023f6381c8493910d63039407de)) +* **examples:** remove beta assisstants examples ([#445](https://github.com/openai/openai-go/issues/445)) ([5891583](https://github.com/openai/openai-go/commit/589158372be9c0517b5508f9ccd872fdb1fe480b)) +* **example:** update fine-tuning ([#450](https://github.com/openai/openai-go/issues/450)) ([421e3c5](https://github.com/openai/openai-go/commit/421e3c5065ace2d5ddd3d13a036477fff9123e5f)) + +## 0.1.0-beta.11 (2025-05-16) + +Full Changelog: [v0.1.0-beta.10...v0.1.0-beta.11](https://github.com/openai/openai-go/compare/v0.1.0-beta.10...v0.1.0-beta.11) + +### ⚠ BREAKING CHANGES + +* **client:** clearer array variant names +* **client:** rename resp package +* **client:** improve core function names +* **client:** improve union variant names +* **client:** improve param subunions & deduplicate types + +### Features + +* **api:** add image sizes, reasoning encryption 
([0852fb3](https://github.com/openai/openai-go/commit/0852fb3101dc940761f9e4f32875bfcf3669eada)) +* **api:** add o3 and o4-mini model IDs ([3fabca6](https://github.com/openai/openai-go/commit/3fabca6b5c610edfb7bcd0cab5334a06444df0b0)) +* **api:** Add reinforcement fine-tuning api support ([831a124](https://github.com/openai/openai-go/commit/831a12451cfce907b5ae4d294b9c2ac95f40d97a)) +* **api:** adding gpt-4.1 family of model IDs ([1ef19d4](https://github.com/openai/openai-go/commit/1ef19d4cc94992dc435d7d5f28b30c9b1d255cd4)) +* **api:** adding new image model support ([bf17880](https://github.com/openai/openai-go/commit/bf17880e182549c5c0fc34ec05df3184f223bc00)) +* **api:** manual updates ([11f5716](https://github.com/openai/openai-go/commit/11f5716afa86aa100f80f3fa127e1d49203e5e21)) +* **api:** responses x eval api ([183aaf7](https://github.com/openai/openai-go/commit/183aaf700f1d7ffad4ac847627d9ace65379c459)) +* **api:** Updating Assistants and Evals API schemas ([47ca619](https://github.com/openai/openai-go/commit/47ca619fa1b439cf3a68c98e48e9bf1942f0568b)) +* **client:** add dynamic streaming buffer to handle large lines ([8e6aad6](https://github.com/openai/openai-go/commit/8e6aad6d54fc73f1fcc174e1f06c9b3cf00c2689)) +* **client:** add helper method to generate constant structs ([ff82809](https://github.com/openai/openai-go/commit/ff828094b561fc11184fed83f04424b6f68f7781)) +* **client:** add support for endpoint-specific base URLs in python ([072dce4](https://github.com/openai/openai-go/commit/072dce46486d373fa0f0de5415f5270b01c2d972)) +* **client:** add support for reading base URL from environment variable ([0d37268](https://github.com/openai/openai-go/commit/0d372687d673990290bad583f1906a2b121960b2)) +* **client:** clearer array variant names ([a5d8b5d](https://github.com/openai/openai-go/commit/a5d8b5d6b161e3083184586840b2cbe0606d8de1)) +* **client:** experimental support for unmarshalling into param structs 
([5234875](https://github.com/openai/openai-go/commit/523487582e15a47e2f409f183568551258f4b8fe)) +* **client:** improve param subunions & deduplicate types ([8a78f37](https://github.com/openai/openai-go/commit/8a78f37c25abf10498d16d210de3078f491ff23e)) +* **client:** rename resp package ([4433516](https://github.com/openai/openai-go/commit/443351625ee290937a25425719b099ce785bd21b)) +* **client:** support more time formats ([ec171b2](https://github.com/openai/openai-go/commit/ec171b2405c46f9cf04560760da001f7133d2fec)) +* fix lint ([9c50a1e](https://github.com/openai/openai-go/commit/9c50a1eb9f93b578cb78085616f6bfab69f21dbc)) + + +### Bug Fixes + +* **client:** clean up reader resources ([710b92e](https://github.com/openai/openai-go/commit/710b92eaa7e94c03aeeca7479668677b32acb154)) +* **client:** correctly update body in WithJSONSet ([f2d7118](https://github.com/openai/openai-go/commit/f2d7118295dd3073aa449426801d02e6f60bdaa3)) +* **client:** improve core function names ([9f312a9](https://github.com/openai/openai-go/commit/9f312a9b14f5424d44d5834f1b82f3d3fcd57db2)) +* **client:** improve union variant names ([a2c3de9](https://github.com/openai/openai-go/commit/a2c3de9e6c9f6e406b953f6de2eb78d1e72ec1b5)) +* **client:** include path for type names in example code ([69561c5](https://github.com/openai/openai-go/commit/69561c549e18bd16a3641d62769479b125a4e955)) +* **client:** resolve issue with optional multipart files ([910d173](https://github.com/openai/openai-go/commit/910d1730e97a03898e5dee7c889844a2ccec3e56)) +* **client:** time format encoding fix ([ca17553](https://github.com/openai/openai-go/commit/ca175533ac8a17d36be1f531bbaa89c770da3f58)) +* **client:** unmarshal responses properly ([fc9fec3](https://github.com/openai/openai-go/commit/fc9fec3c466ba9f633c3f7a4eebb5ebd3b85e8ac)) +* handle empty bodies in WithJSONSet ([8372464](https://github.com/openai/openai-go/commit/83724640c6c00dcef1547dcabace309f17d14afc)) +* **pagination:** handle errors when applying options 
([eebf84b](https://github.com/openai/openai-go/commit/eebf84bf19f0eb6d9fa21e64bb83b0258e8cb42c)) + + +### Chores + +* **ci:** add timeout thresholds for CI jobs ([26b0dd7](https://github.com/openai/openai-go/commit/26b0dd760c142ca3aa287e8441bbe44cc8b3be0b)) +* **ci:** only use depot for staging repos ([7682154](https://github.com/openai/openai-go/commit/7682154fdbcbe2a2ffdb2df590647a1712d52275)) +* **ci:** run on more branches and use depot runners ([d7badbc](https://github.com/openai/openai-go/commit/d7badbc0d17bcf3cffec332f65cb68e531cb3176)) +* **docs:** document pre-request options ([4befa5a](https://github.com/openai/openai-go/commit/4befa5a48ca61372715f36c45e72eb159d95bf2d)) +* **docs:** update respjson package name ([9a00229](https://github.com/openai/openai-go/commit/9a002299a91e1145f053c51b1a4de10298fd2f43)) +* **readme:** improve formatting ([a847e8d](https://github.com/openai/openai-go/commit/a847e8df45f725f9652fcea53ce57d3b9046efc7)) +* **utils:** add internal resp to param utility ([239c4e2](https://github.com/openai/openai-go/commit/239c4e2cb32c7af71ab14668ccc2f52ea59653f9)) + + +### Documentation + +* update documentation links to be more uniform ([f5f0bb0](https://github.com/openai/openai-go/commit/f5f0bb05ee705d84119806f8e703bf2e0becb1fa)) + +## 0.1.0-beta.10 (2025-04-14) + +Full Changelog: [v0.1.0-beta.9...v0.1.0-beta.10](https://github.com/openai/openai-go/compare/v0.1.0-beta.9...v0.1.0-beta.10) + +### Chores + +* **internal:** expand CI branch coverage ([#369](https://github.com/openai/openai-go/issues/369)) ([258dda8](https://github.com/openai/openai-go/commit/258dda8007a69b9c2720b225ee6d27474d676a93)) +* **internal:** reduce CI branch coverage ([a2f7c03](https://github.com/openai/openai-go/commit/a2f7c03eb984d98f29f908df103ea1743f2e3d9a)) + +## 0.1.0-beta.9 (2025-04-09) + +Full Changelog: [v0.1.0-beta.8...v0.1.0-beta.9](https://github.com/openai/openai-go/compare/v0.1.0-beta.8...v0.1.0-beta.9) + +### Chores + +* workaround build errors 
([#366](https://github.com/openai/openai-go/issues/366)) ([adeb003](https://github.com/openai/openai-go/commit/adeb003cab8efbfbf4424e03e96a0f5e728551cb)) + +## 0.1.0-beta.8 (2025-04-09) + +Full Changelog: [v0.1.0-beta.7...v0.1.0-beta.8](https://github.com/openai/openai-go/compare/v0.1.0-beta.7...v0.1.0-beta.8) + +### Features + +* **api:** Add evalapi to sdk ([#360](https://github.com/openai/openai-go/issues/360)) ([88977d1](https://github.com/openai/openai-go/commit/88977d1868dbbe0060c56ba5dac8eb19773e4938)) +* **api:** manual updates ([#363](https://github.com/openai/openai-go/issues/363)) ([5d068e0](https://github.com/openai/openai-go/commit/5d068e0053172db7f5b75038aa215eee074eeeed)) +* **client:** add escape hatch to omit required param fields ([#354](https://github.com/openai/openai-go/issues/354)) ([9690d6b](https://github.com/openai/openai-go/commit/9690d6b49f8b00329afc038ec15116750853e620)) +* **client:** support custom http clients ([#357](https://github.com/openai/openai-go/issues/357)) ([b5a624f](https://github.com/openai/openai-go/commit/b5a624f658cad774094427b36b05e446b41e8c52)) + + +### Chores + +* **docs:** readme improvements ([#356](https://github.com/openai/openai-go/issues/356)) ([b2f8539](https://github.com/openai/openai-go/commit/b2f8539d6316e3443aa733be2c95926696119c13)) +* **internal:** fix examples ([#361](https://github.com/openai/openai-go/issues/361)) ([de398b4](https://github.com/openai/openai-go/commit/de398b453d398299eb80c15f8fdb2bcbef5eeed6)) +* **internal:** skip broken test ([#362](https://github.com/openai/openai-go/issues/362)) ([cccead9](https://github.com/openai/openai-go/commit/cccead9ba916142ac8fbe6e8926d706511e32ae3)) +* **tests:** improve enum examples ([#359](https://github.com/openai/openai-go/issues/359)) ([e0b9739](https://github.com/openai/openai-go/commit/e0b9739920114d6e991d3947b67fdf62cfaa09c7)) + +## 0.1.0-beta.7 (2025-04-07) + +Full Changelog: 
[v0.1.0-beta.6...v0.1.0-beta.7](https://github.com/openai/openai-go/compare/v0.1.0-beta.6...v0.1.0-beta.7) + +### Features + +* **client:** make response union's AsAny method type safe ([#352](https://github.com/openai/openai-go/issues/352)) ([1252f56](https://github.com/openai/openai-go/commit/1252f56c917e57d6d2b031501b2ff5f89f87cf87)) + + +### Chores + +* **docs:** doc improvements ([#350](https://github.com/openai/openai-go/issues/350)) ([80debc8](https://github.com/openai/openai-go/commit/80debc824eaacb4b07c8f3e8b1d0488d860d5be5)) + +## 0.1.0-beta.6 (2025-04-04) + +Full Changelog: [v0.1.0-beta.5...v0.1.0-beta.6](https://github.com/openai/openai-go/compare/v0.1.0-beta.5...v0.1.0-beta.6) + +### Features + +* **api:** manual updates ([4e39609](https://github.com/openai/openai-go/commit/4e39609d499b88039f1c90cc4b56e26f28fd58ea)) +* **client:** support unions in query and forms ([#347](https://github.com/openai/openai-go/issues/347)) ([cf8af37](https://github.com/openai/openai-go/commit/cf8af373ab7c019c75e886855009ffaca320d0e3)) + +## 0.1.0-beta.5 (2025-04-03) + +Full Changelog: [v0.1.0-beta.4...v0.1.0-beta.5](https://github.com/openai/openai-go/compare/v0.1.0-beta.4...v0.1.0-beta.5) + +### Features + +* **api:** manual updates ([563cc50](https://github.com/openai/openai-go/commit/563cc505f2ab17749bb77e937342a6614243b975)) +* **client:** omitzero on required id parameter ([#339](https://github.com/openai/openai-go/issues/339)) ([c0b4842](https://github.com/openai/openai-go/commit/c0b484266ccd9faee66873916d8c0c92ea9f1014)) + + +### Bug Fixes + +* **client:** return error on bad custom url instead of panic ([#341](https://github.com/openai/openai-go/issues/341)) ([a06c5e6](https://github.com/openai/openai-go/commit/a06c5e632242e53d3fdcc8964931acb533a30b7e)) +* **client:** support multipart encoding array formats ([#342](https://github.com/openai/openai-go/issues/342)) ([5993b28](https://github.com/openai/openai-go/commit/5993b28309d02c2d748b54d98934ef401dcd193a)) +* 
**client:** unmarshal stream events into fresh memory ([#340](https://github.com/openai/openai-go/issues/340)) ([52c3e08](https://github.com/openai/openai-go/commit/52c3e08f51d471d728e5acd16b3c304b51be2d03)) + +## 0.1.0-beta.4 (2025-04-02) + +Full Changelog: [v0.1.0-beta.3...v0.1.0-beta.4](https://github.com/openai/openai-go/compare/v0.1.0-beta.3...v0.1.0-beta.4) + +### Features + +* **api:** manual updates ([bc4fe73](https://github.com/openai/openai-go/commit/bc4fe73eec9c4d39229e4beae8eaafb55b1d3364)) +* **api:** manual updates ([aa7ff10](https://github.com/openai/openai-go/commit/aa7ff10b0616a6b2ece45cb10e9c83f25e35aded)) + + +### Chores + +* **docs:** update file uploads in README ([#333](https://github.com/openai/openai-go/issues/333)) ([471c452](https://github.com/openai/openai-go/commit/471c4525c94e83cf4b78cb6c9b2f65a8a27bf3ce)) +* **internal:** codegen related update ([#335](https://github.com/openai/openai-go/issues/335)) ([48422dc](https://github.com/openai/openai-go/commit/48422dcca333ab808ccb02506c033f1c69d2aa19)) +* Remove deprecated/unused remote spec feature ([c5077a1](https://github.com/openai/openai-go/commit/c5077a154a6db79b73cf4978bdc08212c6da6423)) + +## 0.1.0-beta.3 (2025-03-28) + +Full Changelog: [v0.1.0-beta.2...v0.1.0-beta.3](https://github.com/openai/openai-go/compare/v0.1.0-beta.2...v0.1.0-beta.3) + +### ⚠ BREAKING CHANGES + +* **client:** add enums ([#327](https://github.com/openai/openai-go/issues/327)) + +### Features + +* **api:** add `get /chat/completions` endpoint ([e8ed116](https://github.com/openai/openai-go/commit/e8ed1168576c885cb26fbf819b9c8d24975749bd)) +* **api:** add `get /responses/{response_id}/input_items` endpoint ([8870c26](https://github.com/openai/openai-go/commit/8870c26f010a596adcf37ac10dba096bdd4394e3)) + + +### Bug Fixes + +* **client:** add enums ([#327](https://github.com/openai/openai-go/issues/327)) ([b0e3afb](https://github.com/openai/openai-go/commit/b0e3afbd6f18fd9fc2a5ea9174bd7ec0ac0614db)) + + +### Chores 
+ +* add hash of OpenAPI spec/config inputs to .stats.yml ([104b786](https://github.com/openai/openai-go/commit/104b7861bb025514999b143f7d1de45d2dab659f)) +* add request options to client tests ([#321](https://github.com/openai/openai-go/issues/321)) ([f5239ce](https://github.com/openai/openai-go/commit/f5239ceecf36835341eac5121ed1770020c4806a)) +* **api:** updates to supported Voice IDs ([#325](https://github.com/openai/openai-go/issues/325)) ([477727a](https://github.com/openai/openai-go/commit/477727a44b0fb72493c4749cc60171e0d30f98ec)) +* **docs:** improve security documentation ([#319](https://github.com/openai/openai-go/issues/319)) ([0271053](https://github.com/openai/openai-go/commit/027105363ab30ac3e189234908169faf94e0ca49)) +* fix typos ([#324](https://github.com/openai/openai-go/issues/324)) ([dba15f7](https://github.com/openai/openai-go/commit/dba15f74d63814ce16f778e1017a209a42f46179)) + +## 0.1.0-beta.2 (2025-03-22) + +Full Changelog: [v0.1.0-beta.1...v0.1.0-beta.2](https://github.com/openai/openai-go/compare/v0.1.0-beta.1...v0.1.0-beta.2) + +### Bug Fixes + +* **client:** elide fields in ToAssistantParam ([#309](https://github.com/openai/openai-go/issues/309)) ([1fcd837](https://github.com/openai/openai-go/commit/1fcd83753ea806745d278a5b94797bbee0f018ed)) + +## 0.1.0-beta.1 (2025-03-22) + +Full Changelog: [v0.1.0-alpha.67...v0.1.0-beta.1](https://github.com/openai/openai-go/compare/v0.1.0-alpha.67...v0.1.0-beta.1) + +### Chores + +* **docs:** clarify breaking changes ([#306](https://github.com/openai/openai-go/issues/306)) ([db4bd1f](https://github.com/openai/openai-go/commit/db4bd1f5304aa523a6b62da6e2571487d4248518)) + +## 0.1.0-alpha.67 (2025-03-21) + +Full Changelog: [v0.1.0-alpha.66...v0.1.0-alpha.67](https://github.com/openai/openai-go/compare/v0.1.0-alpha.66...v0.1.0-alpha.67) + +### ⚠ BREAKING CHANGES + +* **api:** migrate to v2 + +### Features + +* **api:** migrate to v2 
([9377508](https://github.com/openai/openai-go/commit/9377508e45ae485d11c3199d6d3d91d345f1b76e)) +* **api:** new models for TTS, STT, + new audio features for Realtime ([#298](https://github.com/openai/openai-go/issues/298)) ([48fa064](https://github.com/openai/openai-go/commit/48fa064202a6e4a3e850d435b29f6fe9a1fe53f4)) + + +### Chores + +* **internal:** bugfix ([0d8c1f4](https://github.com/openai/openai-go/commit/0d8c1f4e801785728b6ad3342146fe38874d6c04)) + + +### Documentation + +* add migration guide ([#302](https://github.com/openai/openai-go/issues/302)) ([19e32fa](https://github.com/openai/openai-go/commit/19e32fa595e65048bb129e813c697991117abca2)) diff --git a/vendor/github.com/openai/openai-go/CONTRIBUTING.md b/vendor/github.com/openai/openai-go/CONTRIBUTING.md new file mode 100644 index 0000000000..95426be2d2 --- /dev/null +++ b/vendor/github.com/openai/openai-go/CONTRIBUTING.md @@ -0,0 +1,66 @@ +## Setting up the environment + +To set up the repository, run: + +```sh +$ ./scripts/bootstrap +$ ./scripts/lint +``` + +This will install all the required dependencies and build the SDK. + +You can also [install go 1.18+ manually](https://go.dev/doc/install). + +## Modifying/Adding code + +Most of the SDK is generated code. Modifications to code will be persisted between generations, but may +result in merge conflicts between manual patches and changes from the generator. The generator will never +modify the contents of the `lib/` and `examples/` directories. + +## Adding and running examples + +All files in the `examples/` directory are not modified by the generator and can be freely edited or added to. + +```go +# add an example to examples//main.go + +package main + +func main() { + // ... +} +``` + +```sh +$ go run ./examples/ +``` + +## Using the repository from source + +To use a local version of this library from source in another project, edit the `go.mod` with a replace +directive. 
This can be done through the CLI with the following: + +```sh +$ go mod edit -replace github.com/openai/openai-go=/path/to/openai-go +``` + +## Running tests + +Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. + +```sh +# you will need npm installed +$ npx prism mock path/to/your/openapi.yml +``` + +```sh +$ ./scripts/test +``` + +## Formatting + +This library uses the standard gofmt code formatter: + +```sh +$ ./scripts/format +``` diff --git a/vendor/github.com/openai/openai-go/LICENSE b/vendor/github.com/openai/openai-go/LICENSE new file mode 100644 index 0000000000..f011417af6 --- /dev/null +++ b/vendor/github.com/openai/openai-go/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2025 OpenAI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/openai/openai-go/MIGRATION.md b/vendor/github.com/openai/openai-go/MIGRATION.md new file mode 100644 index 0000000000..54990b8aca --- /dev/null +++ b/vendor/github.com/openai/openai-go/MIGRATION.md @@ -0,0 +1,284 @@ +# OpenAI Go Migration Guide + +Go Reference + +This SDK includes breaking changes to improve the ergonomics of constructing parameters and accessing responses. + +To reduce verbosity, the `openai.F(...)` and `param.Field[T]` have been removed. +All calls to `openai.F(...)` can be deleted. + +The SDK now uses the \`json:"...,omitzero"\` struct tag to omit fields. Nested structs, arrays and maps +can be declared like normal. + +The old SDK used interfaces for unions in requests, which required +a type assertion to access variants and fields. The new design uses +structs with a field for each variant, wherein only one field can be set. +These struct unions also expose 'Get' methods to access and mutate subfields +which may be shared by multiple variants. + +# Request parameters + +## Required primitives parameters serialize their zero values (`string`, `int64`, etc.) + +> [!CAUTION] +> +> **This change can cause new behavior in existing code, without compiler warnings.** + +While migrating, ensure that all required fields are explicitly set. A required primitive +field `Age` will use the \`json:"age,required"\` struct tag without `omitzero`. 
+ +If a required primitive field is not set, the zero value will be serialized. +This was not the case in with `param.Field[T]`. + +```diff +type FooParams struct { +- Age param.Field[int64] `json:"age,required"` +- Name param.Field[string] `json:"name"` ++ Age int64 `json:"age,required"` // <== Notice no omitzero ++ Name param.Opt[string] `json:"name,omitzero"` +} +``` + + + + + + + + + + +
PreviousNew
+ +```go +_ = FooParams{ + Name: openai.String("Jerry") +} +`{"name": "Jerry"}` // (after serialization) +``` + + + +```go +_ = FooParams{ + Name: openai.String("Jerry") +} +`{"name": "Jerry", "age": 0}` // <== Notice the age field +``` + +
+ +The required field `"age"` is now present as `0`. Fields without the \`json:"...,omitzero"\` struct tag +are always serialized, including their zero values. + +## Transition from `param.Field[T]` to `omitzero` + +The `openai.F(...)` function and `param.Field[T]` type are no longer present in the new SDK. + +To represent omitted fields, the SDK uses \`json:"...,omitzero"\` semantics from Go 1.24+ for JSON encoding[^1]. `omitzero` always omits fields +with zero values. + +In all cases other than optional primitives, `openai.F()` can simply be removed. +For optional primitive types, such as `param.Opt[string]`, you can use `openai.String(string)` to construct the value. +Similar functions exist for other primitive types like `openai.Int(int)`, `openai.Bool(bool)`, etc. + +`omitzero` is used for fields whose type is either a struct, slice, map, string enum, +or wrapped optional primitive (e.g. `param.Opt[T]`). Required primitive fields don't use `omitzero`. + +**Example User Code: Constructing a request** + +```diff +foo = FooParams{ +- RequiredString: openai.String("hello"), ++ RequiredString: "hello", + +- OptionalString: openai.String("hi"), ++ OptionalString: openai.String("hi"), + +- Array: openai.F([]BarParam{ +- BarParam{Prop: ... } +- }), ++ Array: []BarParam{ ++ BarParam{Prop: ... } ++ }, + +- RequiredObject: openai.F(BarParam{ ... }), ++ RequiredObject: BarParam{ ... }, + +- OptionalObject: openai.F(BarParam{ ... }), ++ OptionalObject: BarParam{ ... 
}, + +- StringEnum: openai.F[BazEnum]("baz-ok"), ++ StringEnum: "baz-ok", +} +``` + +**Internal SDK Code: Fields of a request struct:** + +```diff +type FooParams struct { +- RequiredString param.Field[string] `json:"required_string,required"` ++ RequiredString string `json:"required_string,required"` + +- OptionalString param.Field[string] `json:"optional_string"` ++ OptionalString param.Opt[string] `json:"optional_string,omitzero"` + +- Array param.Field[[]BarParam] `json"array"` ++ Array []BarParam `json"array,omitzero"` + +- Map param.Field[map[string]BarParam] `json"map"` ++ Map map[string]BarParam `json"map,omitzero"` + +- RequiredObject param.Field[BarParam] `json:"required_object,required"` ++ RequiredObject BarParam `json:"required_object,omitzero,required"` + +- OptionalObject param.Field[BarParam] `json:"optional_object"` ++ OptionalObject BarParam `json:"optional_object,omitzero"` + +- StringEnum param.Field[BazEnum] `json:"string_enum"` ++ StringEnum BazEnum `json:"string_enum,omitzero"` +} +``` + +## Request Unions: Removing interfaces and moving to structs + +For a type `AnimalUnionParam` which could be either a `CatParam | DogParam`. + + + + + + + + + + + + + + + + + +
Previous New
+ +```go +type AnimalParam interface { + ImplAnimalParam() +} + +func (Dog) ImplAnimalParam() {} +func (Cat) ImplAnimalParam() {} +``` + + + +```go +type AnimalUnionParam struct { + OfCat *Cat `json:",omitzero,inline` + OfDog *Dog `json:",omitzero,inline` +} +``` + +
+ +```go +var dog AnimalParam = DogParam{ + Name: "spot", ... +} +var cat AnimalParam = CatParam{ + Name: "whiskers", ... +} +``` + + + +```go +dog := AnimalUnionParam{ + OfDog: &DogParam{Name: "spot", ... }, +} +cat := AnimalUnionParam{ + OfCat: &CatParam{Name: "whiskers", ... }, +} +``` + +
+ +```go +var name string +switch v := animal.(type) { +case Dog: + name = v.Name +case Cat: + name = v.Name +} +``` + + + +```go +// Accessing fields +var name *string = animal.GetName() +``` + +
+
+## Sending explicit `null` values
+
+The old SDK had a function `param.Null[T]()` which could set `param.Field[T]` to `null`.
+
+The new SDK uses `param.Null[T]()` to set a `param.Opt[T]` to `null`,
+but `param.NullStruct[T]()` to set a param struct `T` to `null`.
+
+```diff
+- var nullPrimitive param.Field[int64] = param.Null[int64]()
++ var nullPrimitive param.Opt[int64] = param.Null[int64]()
+
+- var nullStruct param.Field[BarParam] = param.Null[BarParam]()
++ var nullStruct BarParam = param.NullStruct[BarParam]()
+```
+
+## Sending custom values
+
+The `openai.Raw[T](any)` function has been removed. All request structs now support a
+`.WithExtraField(map[string]any)` method to customize the fields.
+
+```diff
+foo := FooParams{
+  A: param.String("hello"),
+- B: param.Raw[string](12) // sending `12` instead of a string
+}
++ foo.SetExtraFields(map[string]any{
++   "B": 12,
++ })
+```
+
+# Response Properties
+
+## Checking for presence of optional fields
+
+The `.IsNull()` method has been changed to `.Valid()` to better reflect its behavior.
+
+```diff
+- if !resp.Foo.JSON.Bar.IsNull() {
++ if resp.Foo.JSON.Bar.Valid() {
+  println("bar is present:", resp.Foo.Bar)
+}
+```
+
+| Previous       | New                          | Returns true for values |
+| -------------- | ---------------------------- | ----------------------- |
+| `.IsNull()`    | `!.Valid()`                  | `null` or Omitted       |
+| `.IsMissing()` | `.Raw() == respjson.Omitted` | Omitted                 |
+|                | `.Raw() == respjson.Null`    |
+
+## Checking Raw JSON of a response
+
+The `.RawJSON()` method has moved to the parent of the `.JSON` property.
+ +```diff +- resp.Foo.JSON.RawJSON() ++ resp.Foo.RawJSON() +``` + +[^1]: The SDK doesn't require Go 1.24, despite supporting the `omitzero` feature diff --git a/vendor/github.com/openai/openai-go/README.md b/vendor/github.com/openai/openai-go/README.md new file mode 100644 index 0000000000..153fc9367d --- /dev/null +++ b/vendor/github.com/openai/openai-go/README.md @@ -0,0 +1,948 @@ +# OpenAI Go API Library + +Go Reference + +The OpenAI Go library provides convenient access to the [OpenAI REST API](https://platform.openai.com/docs) +from applications written in Go. + +> [!WARNING] +> The latest version of this package uses a new design with significant breaking changes. +> Please refer to the [migration guide](./MIGRATION.md) for more information on how to update your code. + +## Installation + + + +```go +import ( + "github.com/openai/openai-go" // imported as openai +) +``` + + + +Or to pin the version: + + + +```sh +go get -u 'github.com/openai/openai-go@v1.12.0' +``` + + + +## Requirements + +This library requires Go 1.18+. + +## Usage + +The full API of this library can be found in [api.md](api.md). + +```go +package main + +import ( + "context" + "fmt" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/shared" +) + +func main() { + client := openai.NewClient( + option.WithAPIKey("My API Key"), // defaults to os.LookupEnv("OPENAI_API_KEY") + ) + chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }, + Model: openai.ChatModelGPT4o, + }) + if err != nil { + panic(err.Error()) + } + println(chatCompletion.Choices[0].Message.Content) +} + +``` + +
+Conversations + +```go +param := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("What kind of houseplant is easy to take care of?"), + }, + Seed: openai.Int(1), + Model: openai.ChatModelGPT4o, +} + +completion, err := client.Chat.Completions.New(ctx, param) + +param.Messages = append(param.Messages, completion.Choices[0].Message.ToParam()) +param.Messages = append(param.Messages, openai.UserMessage("How big are those?")) + +// continue the conversation +completion, err = client.Chat.Completions.New(ctx, param) +``` + +
+ +
+Streaming responses + +```go +question := "Write an epic" + +stream := client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(question), + }, + Seed: openai.Int(0), + Model: openai.ChatModelGPT4o, +}) + +// optionally, an accumulator helper can be used +acc := openai.ChatCompletionAccumulator{} + +for stream.Next() { + chunk := stream.Current() + acc.AddChunk(chunk) + + if content, ok := acc.JustFinishedContent(); ok { + println("Content stream finished:", content) + } + + // if using tool calls + if tool, ok := acc.JustFinishedToolCall(); ok { + println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments) + } + + if refusal, ok := acc.JustFinishedRefusal(); ok { + println("Refusal stream finished:", refusal) + } + + // it's best to use chunks after handling JustFinished events + if len(chunk.Choices) > 0 { + println(chunk.Choices[0].Delta.Content) + } +} + +if stream.Err() != nil { + panic(stream.Err()) +} + +// After the stream is finished, acc can be used like a ChatCompletion +_ = acc.Choices[0].Message.Content +``` + +> See the [full streaming and accumulation example](./examples/chat-completion-accumulating/main.go) + +
+ +
+Tool calling + +```go +import ( + "encoding/json" + // ... +) + +// ... + +question := "What is the weather in New York City?" + +params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(question), + }, + Tools: []openai.ChatCompletionToolParam{ + { + Function: openai.FunctionDefinitionParam{ + Name: "get_weather", + Description: openai.String("Get weather at the given location"), + Parameters: openai.FunctionParameters{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]string{ + "type": "string", + }, + }, + "required": []string{"location"}, + }, + }, + }, + }, + Model: openai.ChatModelGPT4o, +} + +// If there is a was a function call, continue the conversation +params.Messages = append(params.Messages, completion.Choices[0].Message.ToParam()) +for _, toolCall := range toolCalls { + if toolCall.Function.Name == "get_weather" { + // Extract the location from the function call arguments + var args map[string]interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args) + if err != nil { + panic(err) + } + location := args["location"].(string) + + // Simulate getting weather data + weatherData := getWeather(location) + + // Print the weather data + fmt.Printf("Weather in %s: %s\n", location, weatherData) + + params.Messages = append(params.Messages, openai.ToolMessage(weatherData, toolCall.ID)) + } +} + +// ... continue the conversation with the information provided by the tool +``` + +> See the [full tool calling example](./examples/chat-completion-tool-calling/main.go) + +
+ +
+Structured outputs + +```go +import ( + "encoding/json" + "github.com/invopop/jsonschema" + // ... +) + +// A struct that will be converted to a Structured Outputs response schema +type HistoricalComputer struct { + Origin Origin `json:"origin" jsonschema_description:"The origin of the computer"` + Name string `json:"full_name" jsonschema_description:"The name of the device model"` + Legacy string `json:"legacy" jsonschema:"enum=positive,enum=neutral,enum=negative" jsonschema_description:"Its influence on the field of computing"` + NotableFacts []string `json:"notable_facts" jsonschema_description:"A few key facts about the computer"` +} + +type Origin struct { + YearBuilt int64 `json:"year_of_construction" jsonschema_description:"The year it was made"` + Organization string `json:"organization" jsonschema_description:"The organization that was in charge of its development"` +} + +func GenerateSchema[T any]() interface{} { + // Structured Outputs uses a subset of JSON schema + // These flags are necessary to comply with the subset + reflector := jsonschema.Reflector{ + AllowAdditionalProperties: false, + DoNotReference: true, + } + var v T + schema := reflector.Reflect(v) + return schema +} + +// Generate the JSON schema at initialization time +var HistoricalComputerResponseSchema = GenerateSchema[HistoricalComputer]() + +func main() { + + // ... + + question := "What computer ran the first neural network?" + + schemaParam := openai.ResponseFormatJSONSchemaJSONSchemaParam{ + Name: "historical_computer", + Description: openai.String("Notable information about a computer"), + Schema: HistoricalComputerResponseSchema, + Strict: openai.Bool(true), + } + + chat, _ := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + // ... 
+ ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{ + OfJSONSchema: &openai.ResponseFormatJSONSchemaParam{ + JSONSchema: schemaParam, + }, + }, + // only certain models can perform structured outputs + Model: openai.ChatModelGPT4o2024_08_06, + }) + + // extract into a well-typed struct + var historicalComputer HistoricalComputer + _ = json.Unmarshal([]byte(chat.Choices[0].Message.Content), &historicalComputer) + + historicalComputer.Name + historicalComputer.Origin.YearBuilt + historicalComputer.Origin.Organization + for i, fact := range historicalComputer.NotableFacts { + // ... + } +} +``` + +> See the [full structured outputs example](./examples/structured-outputs/main.go) + +
+ + +### Request fields + +The openai library uses the [`omitzero`](https://tip.golang.org/doc/go1.24#encodingjsonpkgencodingjson) +semantics from the Go 1.24+ `encoding/json` release for request fields. + +Required primitive fields (`int64`, `string`, etc.) feature the tag \`json:"...,required"\`. These +fields are always serialized, even their zero values. + +Optional primitive types are wrapped in a `param.Opt[T]`. These fields can be set with the provided constructors, `openai.String(string)`, `openai.Int(int64)`, etc. + +Any `param.Opt[T]`, map, slice, struct or string enum uses the +tag \`json:"...,omitzero"\`. Its zero value is considered omitted. + +The `param.IsOmitted(any)` function can confirm the presence of any `omitzero` field. + +```go +p := openai.ExampleParams{ + ID: "id_xxx", // required property + Name: openai.String("..."), // optional property + + Point: openai.Point{ + X: 0, // required field will serialize as 0 + Y: openai.Int(1), // optional field will serialize as 1 + // ... omitted non-required fields will not be serialized + }, + + Origin: openai.Origin{}, // the zero value of [Origin] is considered omitted +} +``` + +To send `null` instead of a `param.Opt[T]`, use `param.Null[T]()`. +To send `null` instead of a struct `T`, use `param.NullStruct[T]()`. + +```go +p.Name = param.Null[string]() // 'null' instead of string +p.Point = param.NullStruct[Point]() // 'null' instead of struct + +param.IsNull(p.Name) // true +param.IsNull(p.Point) // true +``` + +Request structs contain a `.SetExtraFields(map[string]any)` method which can send non-conforming +fields in the request body. Extra fields overwrite any struct fields with a matching +key. For security reasons, only use `SetExtraFields` with trusted data. + +To send a custom value instead of a struct, use `param.Override[T](value)`. 
+
+```go
+// In cases where the API specifies a given type,
+// but you want to send something else, use [SetExtraFields]:
+p.SetExtraFields(map[string]any{
+  "x": 0.01, // send "x" as a float instead of int
+})
+
+// Send a number instead of an object
+custom := param.Override[openai.FooParams](12)
+```
+
+### Request unions
+
+Unions are represented as a struct with fields prefixed by "Of" for each of its variants;
+only one field can be non-zero. The non-zero field will be serialized.
+
+Sub-properties of the union can be accessed via methods on the union struct.
+These methods return a mutable pointer to the underlying data, if present.
+
+```go
+// Only one field can be non-zero, use param.IsOmitted() to check if a field is set
+type AnimalUnionParam struct {
+	OfCat *Cat `json:",omitzero,inline"`
+	OfDog *Dog `json:",omitzero,inline"`
+}
+
+animal := AnimalUnionParam{
+	OfCat: &Cat{
+		Name: "Whiskers",
+		Owner: PersonParam{
+			Address: AddressParam{Street: "3333 Coyote Hill Rd", Zip: 0},
+		},
+	},
+}
+
+// Mutating a field
+if address := animal.GetOwner().GetAddress(); address != nil {
+	address.ZipCode = 94304
+}
+```
+
+### Response objects
+
+All fields in response structs are ordinary value types (not pointers or wrappers).
+Response structs also include a special `JSON` field containing metadata about
+each property.
+
+```go
+type Animal struct {
+	Name string `json:"name,nullable"`
+	Owners int `json:"owners"`
+	Age int `json:"age"`
+	JSON struct {
+		Name respjson.Field
+		Owner respjson.Field
+		Age respjson.Field
+		ExtraFields map[string]respjson.Field
+	} `json:"-"`
+}
+```
+
+To handle optional data, use the `.Valid()` method on the JSON field.
+`.Valid()` returns false if a field is `null`, not present, or couldn't be marshaled.
+
+If `.Valid()` is false, the corresponding field will simply be its zero value.
+ +```go +raw := `{"owners": 1, "name": null}` + +var res Animal +json.Unmarshal([]byte(raw), &res) + +// Accessing regular fields + +res.Owners // 1 +res.Name // "" +res.Age // 0 + +// Optional field checks + +res.JSON.Owners.Valid() // true +res.JSON.Name.Valid() // false +res.JSON.Age.Valid() // false + +// Raw JSON values + +res.JSON.Owners.Raw() // "1" +res.JSON.Name.Raw() == "null" // true +res.JSON.Name.Raw() == respjson.Null // true +res.JSON.Age.Raw() == "" // true +res.JSON.Age.Raw() == respjson.Omitted // true +``` + +These `.JSON` structs also include an `ExtraFields` map containing +any properties in the json response that were not specified +in the struct. This can be useful for API features not yet +present in the SDK. + +```go +body := res.JSON.ExtraFields["my_unexpected_field"].Raw() +``` + +### Response Unions + +In responses, unions are represented by a flattened struct containing all possible fields from each of the +object variants. +To convert it to a variant use the `.AsFooVariant()` method or the `.AsAny()` method if present. + +If a response value union contains primitive values, primitive fields will be alongside +the properties but prefixed with `Of` and feature the tag `json:"...,inline"`. + +```go +type AnimalUnion struct { + // From variants [Dog], [Cat] + Owner Person `json:"owner"` + // From variant [Dog] + DogBreed string `json:"dog_breed"` + // From variant [Cat] + CatBreed string `json:"cat_breed"` + // ... + + JSON struct { + Owner respjson.Field + // ... + } `json:"-"` +} + +// If animal variant +if animal.Owner.Address.ZipCode == "" { + panic("missing zip code") +} + +// Switch on the variant +switch variant := animal.AsAny().(type) { +case Dog: +case Cat: +default: + panic("unexpected type") +} +``` + +### RequestOptions + +This library uses the functional options pattern. Functions defined in the +`option` package return a `RequestOption`, which is a closure that mutates a +`RequestConfig`. 
These options can be supplied to the client or at individual +requests. For example: + +```go +client := openai.NewClient( + // Adds a header to every request made by the client + option.WithHeader("X-Some-Header", "custom_header_info"), +) + +client.Chat.Completions.New(context.TODO(), ..., + // Override the header + option.WithHeader("X-Some-Header", "some_other_custom_header_info"), + // Add an undocumented field to the request body, using sjson syntax + option.WithJSONSet("some.json.path", map[string]string{"my": "object"}), +) +``` + +The request option `option.WithDebugLog(nil)` may be helpful while debugging. + +See the [full list of request options](https://pkg.go.dev/github.com/openai/openai-go/option). + +### Pagination + +This library provides some conveniences for working with paginated list endpoints. + +You can use `.ListAutoPaging()` methods to iterate through items across all pages: + +```go +iter := client.FineTuning.Jobs.ListAutoPaging(context.TODO(), openai.FineTuningJobListParams{ + Limit: openai.Int(20), +}) +// Automatically fetches more pages as needed. +for iter.Next() { + fineTuningJob := iter.Current() + fmt.Printf("%+v\n", fineTuningJob) +} +if err := iter.Err(); err != nil { + panic(err.Error()) +} +``` + +Or you can use simple `.List()` methods to fetch a single page and receive a standard response object +with additional helper methods like `.GetNextPage()`, e.g.: + +```go +page, err := client.FineTuning.Jobs.List(context.TODO(), openai.FineTuningJobListParams{ + Limit: openai.Int(20), +}) +for page != nil { + for _, job := range page.Data { + fmt.Printf("%+v\n", job) + } + page, err = page.GetNextPage() +} +if err != nil { + panic(err.Error()) +} +``` + +### Errors + +When the API returns a non-success status code, we return an error with type +`*openai.Error`. 
This contains the `StatusCode`, `*http.Request`, and +`*http.Response` values of the request, as well as the JSON of the error body +(much like other response objects in the SDK). + +To handle errors, we recommend that you use the `errors.As` pattern: + +```go +_, err := client.FineTuning.Jobs.New(context.TODO(), openai.FineTuningJobNewParams{ + Model: openai.FineTuningJobNewParamsModelBabbage002, + TrainingFile: "file-abc123", +}) +if err != nil { + var apierr *openai.Error + if errors.As(err, &apierr) { + println(string(apierr.DumpRequest(true))) // Prints the serialized HTTP request + println(string(apierr.DumpResponse(true))) // Prints the serialized HTTP response + } + panic(err.Error()) // GET "/fine_tuning/jobs": 400 Bad Request { ... } +} +``` + +When other errors occur, they are returned unwrapped; for example, +if HTTP transport fails, you might receive `*url.Error` wrapping `*net.OpError`. + +### Timeouts + +Requests do not time out by default; use context to configure a timeout for a request lifecycle. + +Note that if a request is [retried](#retries), the context timeout does not start over. +To set a per-retry timeout, use `option.WithRequestTimeout()`. + +```go +// This sets the timeout for the request, including all the retries. +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) +defer cancel() +client.Chat.Completions.New( + ctx, + openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: openai.String("How can I list all files in a directory using Python?"), + }, + }, + }}, + Model: shared.ChatModelGPT4_1, + }, + // This sets the per-retry timeout + option.WithRequestTimeout(20*time.Second), +) +``` + +### File uploads + +Request parameters that correspond to file uploads in multipart requests are typed as +`io.Reader`. 
The contents of the `io.Reader` will by default be sent as a multipart form +part with the file name of "anonymous_file" and content-type of "application/octet-stream". + +The file name and content-type can be customized by implementing `Name() string` or `ContentType() +string` on the run-time type of `io.Reader`. Note that `os.File` implements `Name() string`, so a +file returned by `os.Open` will be sent with the file name on disk. + +We also provide a helper `openai.File(reader io.Reader, filename string, contentType string)` +which can be used to wrap any `io.Reader` with the appropriate file name and content type. + +```go +// A file from the file system +file, err := os.Open("input.jsonl") +openai.FileNewParams{ + File: file, + Purpose: openai.FilePurposeFineTune, +} + +// A file from a string +openai.FileNewParams{ + File: strings.NewReader("my file contents"), + Purpose: openai.FilePurposeFineTune, +} + +// With a custom filename and contentType +openai.FileNewParams{ + File: openai.File(strings.NewReader(`{"hello": "foo"}`), "file.go", "application/json"), + Purpose: openai.FilePurposeFineTune, +} +``` + +## Webhook Verification + +Verifying webhook signatures is _optional but encouraged_. + +For more information about webhooks, see [the API docs](https://platform.openai.com/docs/guides/webhooks). + +### Parsing webhook payloads + +For most use cases, you will likely want to verify the webhook and parse the payload at the same time. To achieve this, we provide the method `client.Webhooks.Unwrap()`, which parses a webhook request and verifies that it was sent by OpenAI. This method will return an error if the signature is invalid. + +Note that the `body` parameter should be the raw JSON bytes sent from the server (do not parse it first). The `Unwrap()` method will parse this JSON for you into an event object after verifying the webhook was sent from OpenAI. 
+ +```go +package main + +import ( + "io" + "log" + "net/http" + "os" + + "github.com/gin-gonic/gin" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/webhooks" +) + +func main() { + client := openai.NewClient( + option.WithWebhookSecret(os.Getenv("OPENAI_WEBHOOK_SECRET")), // env var used by default; explicit here. + ) + + r := gin.Default() + + r.POST("/webhook", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reading request body"}) + return + } + defer c.Request.Body.Close() + + webhookEvent, err := client.Webhooks.Unwrap(body, c.Request.Header) + if err != nil { + log.Printf("Invalid webhook signature: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid signature"}) + return + } + + switch event := webhookEvent.AsAny().(type) { + case webhooks.ResponseCompletedWebhookEvent: + log.Printf("Response completed: %+v", event.Data) + case webhooks.ResponseFailedWebhookEvent: + log.Printf("Response failed: %+v", event.Data) + default: + log.Printf("Unhandled event type: %T", event) + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) + }) + + r.Run(":8000") +} +``` + +### Verifying webhook payloads directly + +In some cases, you may want to verify the webhook separately from parsing the payload. If you prefer to handle these steps separately, we provide the method `client.Webhooks.VerifySignature()` to _only verify_ the signature of a webhook request. Like `Unwrap()`, this method will return an error if the signature is invalid. + +Note that the `body` parameter should be the raw JSON bytes sent from the server (do not parse it first). You will then need to parse the body after verifying the signature. 
+ +```go +package main + +import ( + "encoding/json" + "io" + "log" + "net/http" + "os" + + "github.com/gin-gonic/gin" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) + +func main() { + client := openai.NewClient( + option.WithWebhookSecret(os.Getenv("OPENAI_WEBHOOK_SECRET")), // env var used by default; explicit here. + ) + + r := gin.Default() + + r.POST("/webhook", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reading request body"}) + return + } + defer c.Request.Body.Close() + + err = client.Webhooks.VerifySignature(body, c.Request.Header) + if err != nil { + log.Printf("Invalid webhook signature: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid signature"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) + }) + + r.Run(":8000") +} +``` + +### Retries + +Certain errors will be automatically retried 2 times by default, with a short exponential backoff. +We retry by default all connection errors, 408 Request Timeout, 409 Conflict, 429 Rate Limit, +and >=500 Internal errors. + +You can use the `WithMaxRetries` option to configure or disable this: + +```go +// Configure the default for all requests: +client := openai.NewClient( + option.WithMaxRetries(0), // default is 2 +) + +// Override per-request: +client.Chat.Completions.New( + context.TODO(), + openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: openai.String("How can I get the name of the current day in JavaScript?"), + }, + }, + }}, + Model: shared.ChatModelGPT4_1, + }, + option.WithMaxRetries(5), +) +``` + +### Accessing raw response data (e.g. response headers) + +You can access the raw HTTP response data by using the `option.WithResponseInto()` request option. 
This is useful when +you need to examine response headers, status codes, or other details. + +```go +// Create a variable to store the HTTP response +var response *http.Response +chatCompletion, err := client.Chat.Completions.New( + context.TODO(), + openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{{ + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfString: openai.String("Say this is a test"), + }, + }, + }}, + Model: shared.ChatModelGPT4_1, + }, + option.WithResponseInto(&response), +) +if err != nil { + // handle error +} +fmt.Printf("%+v\n", chatCompletion) + +fmt.Printf("Status Code: %d\n", response.StatusCode) +fmt.Printf("Headers: %+#v\n", response.Header) +``` + +### Making custom/undocumented requests + +This library is typed for convenient access to the documented API. If you need to access undocumented +endpoints, params, or response properties, the library can still be used. + +#### Undocumented endpoints + +To make requests to undocumented endpoints, you can use `client.Get`, `client.Post`, and other HTTP verbs. +`RequestOptions` on the client, such as retries, will be respected when making these requests. + +```go +var ( + // params can be an io.Reader, a []byte, an encoding/json serializable object, + // or a "…Params" struct defined in this library. + params map[string]any + + // result can be an []byte, *http.Response, a encoding/json deserializable object, + // or a model defined in this library. + result *http.Response +) +err := client.Post(context.Background(), "/unspecified", params, &result) +if err != nil { + … +} +``` + +#### Undocumented request params + +To make requests using undocumented parameters, you may use either the `option.WithQuerySet()` +or the `option.WithJSONSet()` methods. 
 + +```go +params := FooNewParams{ + ID: "id_xxxx", + Data: FooNewParamsData{ + FirstName: openai.String("John"), + }, +} +client.Foo.New(context.Background(), params, option.WithJSONSet("data.last_name", "Doe")) +``` + +#### Undocumented response properties + +To access undocumented response properties, you may either access the raw JSON of the response as a string +with `result.JSON.RawJSON()`, or get the raw JSON of a particular field on the result with +`result.JSON.Foo.Raw()`. + +Any fields that are not present on the response struct will be saved and can be accessed by `result.JSON.ExtraFields()` which returns the extra fields as a `map[string]Field`. + +### Middleware + +We provide `option.WithMiddleware` which applies the given +middleware to requests. + +```go +func Logger(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) { + // Before the request + start := time.Now() + LogReq(req) + + // Forward the request to the next handler + res, err = next(req) + + // Handle stuff after the request + end := time.Now() + LogRes(res, err, end.Sub(start)) + + return res, err +} + +client := openai.NewClient( + option.WithMiddleware(Logger), +) +``` + +When multiple middlewares are provided as variadic arguments, the middlewares +are applied left to right. If `option.WithMiddleware` is given +multiple times, for example first in the client then the method, the +middleware in the client will run first and the middleware given in the method +will run next. + +You may also replace the default `http.Client` with +`option.WithHTTPClient(client)`. Only one http client is +accepted (this overwrites any previous client) and receives requests after any +middleware has been applied. + +## Microsoft Azure OpenAI + +To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), +use the `option.RequestOption` functions in the `azure` package. 
 + +```go +package main + +import ( + "fmt" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" +) + +func main() { + const azureOpenAIEndpoint = "https://<azure-openai-resource>.openai.azure.com" + + // The latest API versions, including previews, can be found here: + // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning + const azureOpenAIAPIVersion = "2024-06-01" + + tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + fmt.Printf("Failed to create the DefaultAzureCredential: %s", err) + os.Exit(1) + } + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + + // Choose between authenticating using a TokenCredential or an API Key + azure.WithTokenCredential(tokenCredential), + // or azure.WithAPIKey(azureOpenAIAPIKey), + ) +} +``` + + +## Semantic versioning + +This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: + +1. Changes to library internals which are technically public but not intended or documented for external use. _(Please open a GitHub issue to let us know if you are relying on such internals.)_ +2. Changes that we do not expect to impact the vast majority of users in practice. + +We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience. + +We are keen for your feedback; please open an [issue](https://www.github.com/openai/openai-go/issues) with questions, bugs, or suggestions. + +## Contributing + +See [the contributing documentation](./CONTRIBUTING.md). 
diff --git a/vendor/github.com/openai/openai-go/SECURITY.md b/vendor/github.com/openai/openai-go/SECURITY.md new file mode 100644 index 0000000000..4adb0c54f1 --- /dev/null +++ b/vendor/github.com/openai/openai-go/SECURITY.md @@ -0,0 +1,29 @@ +# Security Policy + +## Reporting Security Issues + +This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. + +To report a security issue, please contact the Stainless team at security@stainless.com. + +## Responsible Disclosure + +We appreciate the efforts of security researchers and individuals who help us maintain the security of +SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible +disclosure practices by allowing us a reasonable amount of time to investigate and address the issue +before making any information public. + +## Reporting Non-SDK Related Security Issues + +If you encounter security issues that are not directly related to SDKs but pertain to the services +or products provided by OpenAI, please follow the respective company's security reporting guidelines. + +### OpenAI Terms and Policies + +Our Security Policy can be found at [Security Policy URL](https://openai.com/policies/coordinated-vulnerability-disclosure-policy). + +Please contact disclosure@openai.com for any questions or concerns regarding the security of our services. + +--- + +Thank you for helping us keep the SDKs and systems they interact with secure. diff --git a/vendor/github.com/openai/openai-go/aliases.go b/vendor/github.com/openai/openai-go/aliases.go new file mode 100644 index 0000000000..73b6c0cc1a --- /dev/null +++ b/vendor/github.com/openai/openai-go/aliases.go @@ -0,0 +1,440 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "github.com/openai/openai-go/internal/apierror" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/shared" +) + +// aliased to make [param.APIUnion] private when embedding +type paramUnion = param.APIUnion + +// aliased to make [param.APIObject] private when embedding +type paramObj = param.APIObject + +type Error = apierror.Error + +// This is an alias to an internal type. +type ChatModel = shared.ChatModel + +// Equals "gpt-4.1" +const ChatModelGPT4_1 = shared.ChatModelGPT4_1 + +// Equals "gpt-4.1-mini" +const ChatModelGPT4_1Mini = shared.ChatModelGPT4_1Mini + +// Equals "gpt-4.1-nano" +const ChatModelGPT4_1Nano = shared.ChatModelGPT4_1Nano + +// Equals "gpt-4.1-2025-04-14" +const ChatModelGPT4_1_2025_04_14 = shared.ChatModelGPT4_1_2025_04_14 + +// Equals "gpt-4.1-mini-2025-04-14" +const ChatModelGPT4_1Mini2025_04_14 = shared.ChatModelGPT4_1Mini2025_04_14 + +// Equals "gpt-4.1-nano-2025-04-14" +const ChatModelGPT4_1Nano2025_04_14 = shared.ChatModelGPT4_1Nano2025_04_14 + +// Equals "o4-mini" +const ChatModelO4Mini = shared.ChatModelO4Mini + +// Equals "o4-mini-2025-04-16" +const ChatModelO4Mini2025_04_16 = shared.ChatModelO4Mini2025_04_16 + +// Equals "o3" +const ChatModelO3 = shared.ChatModelO3 + +// Equals "o3-2025-04-16" +const ChatModelO3_2025_04_16 = shared.ChatModelO3_2025_04_16 + +// Equals "o3-mini" +const ChatModelO3Mini = shared.ChatModelO3Mini + +// Equals "o3-mini-2025-01-31" +const ChatModelO3Mini2025_01_31 = shared.ChatModelO3Mini2025_01_31 + +// Equals "o1" +const ChatModelO1 = shared.ChatModelO1 + +// Equals "o1-2024-12-17" +const ChatModelO1_2024_12_17 = shared.ChatModelO1_2024_12_17 + +// Equals "o1-preview" +const ChatModelO1Preview = shared.ChatModelO1Preview + +// Equals "o1-preview-2024-09-12" +const ChatModelO1Preview2024_09_12 = shared.ChatModelO1Preview2024_09_12 + +// Equals "o1-mini" +const ChatModelO1Mini = shared.ChatModelO1Mini + +// Equals "o1-mini-2024-09-12" +const 
ChatModelO1Mini2024_09_12 = shared.ChatModelO1Mini2024_09_12 + +// Equals "gpt-4o" +const ChatModelGPT4o = shared.ChatModelGPT4o + +// Equals "gpt-4o-2024-11-20" +const ChatModelGPT4o2024_11_20 = shared.ChatModelGPT4o2024_11_20 + +// Equals "gpt-4o-2024-08-06" +const ChatModelGPT4o2024_08_06 = shared.ChatModelGPT4o2024_08_06 + +// Equals "gpt-4o-2024-05-13" +const ChatModelGPT4o2024_05_13 = shared.ChatModelGPT4o2024_05_13 + +// Equals "gpt-4o-audio-preview" +const ChatModelGPT4oAudioPreview = shared.ChatModelGPT4oAudioPreview + +// Equals "gpt-4o-audio-preview-2024-10-01" +const ChatModelGPT4oAudioPreview2024_10_01 = shared.ChatModelGPT4oAudioPreview2024_10_01 + +// Equals "gpt-4o-audio-preview-2024-12-17" +const ChatModelGPT4oAudioPreview2024_12_17 = shared.ChatModelGPT4oAudioPreview2024_12_17 + +// Equals "gpt-4o-audio-preview-2025-06-03" +const ChatModelGPT4oAudioPreview2025_06_03 = shared.ChatModelGPT4oAudioPreview2025_06_03 + +// Equals "gpt-4o-mini-audio-preview" +const ChatModelGPT4oMiniAudioPreview = shared.ChatModelGPT4oMiniAudioPreview + +// Equals "gpt-4o-mini-audio-preview-2024-12-17" +const ChatModelGPT4oMiniAudioPreview2024_12_17 = shared.ChatModelGPT4oMiniAudioPreview2024_12_17 + +// Equals "gpt-4o-search-preview" +const ChatModelGPT4oSearchPreview = shared.ChatModelGPT4oSearchPreview + +// Equals "gpt-4o-mini-search-preview" +const ChatModelGPT4oMiniSearchPreview = shared.ChatModelGPT4oMiniSearchPreview + +// Equals "gpt-4o-search-preview-2025-03-11" +const ChatModelGPT4oSearchPreview2025_03_11 = shared.ChatModelGPT4oSearchPreview2025_03_11 + +// Equals "gpt-4o-mini-search-preview-2025-03-11" +const ChatModelGPT4oMiniSearchPreview2025_03_11 = shared.ChatModelGPT4oMiniSearchPreview2025_03_11 + +// Equals "chatgpt-4o-latest" +const ChatModelChatgpt4oLatest = shared.ChatModelChatgpt4oLatest + +// Equals "codex-mini-latest" +const ChatModelCodexMiniLatest = shared.ChatModelCodexMiniLatest + +// Equals "gpt-4o-mini" +const ChatModelGPT4oMini = 
shared.ChatModelGPT4oMini + +// Equals "gpt-4o-mini-2024-07-18" +const ChatModelGPT4oMini2024_07_18 = shared.ChatModelGPT4oMini2024_07_18 + +// Equals "gpt-4-turbo" +const ChatModelGPT4Turbo = shared.ChatModelGPT4Turbo + +// Equals "gpt-4-turbo-2024-04-09" +const ChatModelGPT4Turbo2024_04_09 = shared.ChatModelGPT4Turbo2024_04_09 + +// Equals "gpt-4-0125-preview" +const ChatModelGPT4_0125Preview = shared.ChatModelGPT4_0125Preview + +// Equals "gpt-4-turbo-preview" +const ChatModelGPT4TurboPreview = shared.ChatModelGPT4TurboPreview + +// Equals "gpt-4-1106-preview" +const ChatModelGPT4_1106Preview = shared.ChatModelGPT4_1106Preview + +// Equals "gpt-4-vision-preview" +const ChatModelGPT4VisionPreview = shared.ChatModelGPT4VisionPreview + +// Equals "gpt-4" +const ChatModelGPT4 = shared.ChatModelGPT4 + +// Equals "gpt-4-0314" +const ChatModelGPT4_0314 = shared.ChatModelGPT4_0314 + +// Equals "gpt-4-0613" +const ChatModelGPT4_0613 = shared.ChatModelGPT4_0613 + +// Equals "gpt-4-32k" +const ChatModelGPT4_32k = shared.ChatModelGPT4_32k + +// Equals "gpt-4-32k-0314" +const ChatModelGPT4_32k0314 = shared.ChatModelGPT4_32k0314 + +// Equals "gpt-4-32k-0613" +const ChatModelGPT4_32k0613 = shared.ChatModelGPT4_32k0613 + +// Equals "gpt-3.5-turbo" +const ChatModelGPT3_5Turbo = shared.ChatModelGPT3_5Turbo + +// Equals "gpt-3.5-turbo-16k" +const ChatModelGPT3_5Turbo16k = shared.ChatModelGPT3_5Turbo16k + +// Equals "gpt-3.5-turbo-0301" +const ChatModelGPT3_5Turbo0301 = shared.ChatModelGPT3_5Turbo0301 + +// Equals "gpt-3.5-turbo-0613" +const ChatModelGPT3_5Turbo0613 = shared.ChatModelGPT3_5Turbo0613 + +// Equals "gpt-3.5-turbo-1106" +const ChatModelGPT3_5Turbo1106 = shared.ChatModelGPT3_5Turbo1106 + +// Equals "gpt-3.5-turbo-0125" +const ChatModelGPT3_5Turbo0125 = shared.ChatModelGPT3_5Turbo0125 + +// Equals "gpt-3.5-turbo-16k-0613" +const ChatModelGPT3_5Turbo16k0613 = shared.ChatModelGPT3_5Turbo16k0613 + +// A filter used to compare a specified attribute key to a given value using 
a +// defined comparison operation. +// +// This is an alias to an internal type. +type ComparisonFilter = shared.ComparisonFilter + +// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`. +// +// - `eq`: equals +// - `ne`: not equal +// - `gt`: greater than +// - `gte`: greater than or equal +// - `lt`: less than +// - `lte`: less than or equal +// +// This is an alias to an internal type. +type ComparisonFilterType = shared.ComparisonFilterType + +// Equals "eq" +const ComparisonFilterTypeEq = shared.ComparisonFilterTypeEq + +// Equals "ne" +const ComparisonFilterTypeNe = shared.ComparisonFilterTypeNe + +// Equals "gt" +const ComparisonFilterTypeGt = shared.ComparisonFilterTypeGt + +// Equals "gte" +const ComparisonFilterTypeGte = shared.ComparisonFilterTypeGte + +// Equals "lt" +const ComparisonFilterTypeLt = shared.ComparisonFilterTypeLt + +// Equals "lte" +const ComparisonFilterTypeLte = shared.ComparisonFilterTypeLte + +// The value to compare against the attribute key; supports string, number, or +// boolean types. +// +// This is an alias to an internal type. +type ComparisonFilterValueUnion = shared.ComparisonFilterValueUnion + +// A filter used to compare a specified attribute key to a given value using a +// defined comparison operation. +// +// This is an alias to an internal type. +type ComparisonFilterParam = shared.ComparisonFilterParam + +// The value to compare against the attribute key; supports string, number, or +// boolean types. +// +// This is an alias to an internal type. +type ComparisonFilterValueUnionParam = shared.ComparisonFilterValueUnionParam + +// Combine multiple filters using `and` or `or`. +// +// This is an alias to an internal type. +type CompoundFilter = shared.CompoundFilter + +// Type of operation: `and` or `or`. +// +// This is an alias to an internal type. 
+type CompoundFilterType = shared.CompoundFilterType + +// Equals "and" +const CompoundFilterTypeAnd = shared.CompoundFilterTypeAnd + +// Equals "or" +const CompoundFilterTypeOr = shared.CompoundFilterTypeOr + +// Combine multiple filters using `and` or `or`. +// +// This is an alias to an internal type. +type CompoundFilterParam = shared.CompoundFilterParam + +// This is an alias to an internal type. +type ErrorObject = shared.ErrorObject + +// This is an alias to an internal type. +type FunctionDefinition = shared.FunctionDefinition + +// This is an alias to an internal type. +type FunctionDefinitionParam = shared.FunctionDefinitionParam + +// The parameters the functions accepts, described as a JSON Schema object. See the +// [guide](https://platform.openai.com/docs/guides/function-calling) for examples, +// and the +// [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for +// documentation about the format. +// +// Omitting `parameters` defines a function with an empty parameter list. +// +// This is an alias to an internal type. +type FunctionParameters = shared.FunctionParameters + +// Set of 16 key-value pairs that can be attached to an object. This can be useful +// for storing additional information about the object in a structured format, and +// querying for objects via API or the dashboard. +// +// Keys are strings with a maximum length of 64 characters. Values are strings with +// a maximum length of 512 characters. +// +// This is an alias to an internal type. +type Metadata = shared.Metadata + +// **o-series models only** +// +// Configuration options for +// [reasoning models](https://platform.openai.com/docs/guides/reasoning). +// +// This is an alias to an internal type. +type Reasoning = shared.Reasoning + +// **Deprecated:** use `summary` instead. +// +// A summary of the reasoning performed by the model. This can be useful for +// debugging and understanding the model's reasoning process. 
One of `auto`, +// `concise`, or `detailed`. +// +// This is an alias to an internal type. +type ReasoningGenerateSummary = shared.ReasoningGenerateSummary + +// Equals "auto" +const ReasoningGenerateSummaryAuto = shared.ReasoningGenerateSummaryAuto + +// Equals "concise" +const ReasoningGenerateSummaryConcise = shared.ReasoningGenerateSummaryConcise + +// Equals "detailed" +const ReasoningGenerateSummaryDetailed = shared.ReasoningGenerateSummaryDetailed + +// A summary of the reasoning performed by the model. This can be useful for +// debugging and understanding the model's reasoning process. One of `auto`, +// `concise`, or `detailed`. +// +// This is an alias to an internal type. +type ReasoningSummary = shared.ReasoningSummary + +// Equals "auto" +const ReasoningSummaryAuto = shared.ReasoningSummaryAuto + +// Equals "concise" +const ReasoningSummaryConcise = shared.ReasoningSummaryConcise + +// Equals "detailed" +const ReasoningSummaryDetailed = shared.ReasoningSummaryDetailed + +// **o-series models only** +// +// Configuration options for +// [reasoning models](https://platform.openai.com/docs/guides/reasoning). +// +// This is an alias to an internal type. +type ReasoningParam = shared.ReasoningParam + +// **o-series models only** +// +// Constrains effort on reasoning for +// [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently +// supported values are `low`, `medium`, and `high`. Reducing reasoning effort can +// result in faster responses and fewer tokens used on reasoning in a response. +// +// This is an alias to an internal type. +type ReasoningEffort = shared.ReasoningEffort + +// Equals "low" +const ReasoningEffortLow = shared.ReasoningEffortLow + +// Equals "medium" +const ReasoningEffortMedium = shared.ReasoningEffortMedium + +// Equals "high" +const ReasoningEffortHigh = shared.ReasoningEffortHigh + +// JSON object response format. An older method of generating JSON responses. 
Using +// `json_schema` is recommended for models that support it. Note that the model +// will not generate JSON without a system or user message instructing it to do so. +// +// This is an alias to an internal type. +type ResponseFormatJSONObject = shared.ResponseFormatJSONObject + +// JSON object response format. An older method of generating JSON responses. Using +// `json_schema` is recommended for models that support it. Note that the model +// will not generate JSON without a system or user message instructing it to do so. +// +// This is an alias to an internal type. +type ResponseFormatJSONObjectParam = shared.ResponseFormatJSONObjectParam + +// JSON Schema response format. Used to generate structured JSON responses. Learn +// more about +// [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs). +// +// This is an alias to an internal type. +type ResponseFormatJSONSchema = shared.ResponseFormatJSONSchema + +// Structured Outputs configuration options, including a JSON Schema. +// +// This is an alias to an internal type. +type ResponseFormatJSONSchemaJSONSchema = shared.ResponseFormatJSONSchemaJSONSchema + +// JSON Schema response format. Used to generate structured JSON responses. Learn +// more about +// [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs). +// +// This is an alias to an internal type. +type ResponseFormatJSONSchemaParam = shared.ResponseFormatJSONSchemaParam + +// Structured Outputs configuration options, including a JSON Schema. +// +// This is an alias to an internal type. +type ResponseFormatJSONSchemaJSONSchemaParam = shared.ResponseFormatJSONSchemaJSONSchemaParam + +// Default response format. Used to generate text responses. +// +// This is an alias to an internal type. +type ResponseFormatText = shared.ResponseFormatText + +// Default response format. Used to generate text responses. +// +// This is an alias to an internal type. 
+type ResponseFormatTextParam = shared.ResponseFormatTextParam + +// This is an alias to an internal type. +type ResponsesModel = shared.ResponsesModel + +// Equals "o1-pro" +const ResponsesModelO1Pro = shared.ResponsesModelO1Pro + +// Equals "o1-pro-2025-03-19" +const ResponsesModelO1Pro2025_03_19 = shared.ResponsesModelO1Pro2025_03_19 + +// Equals "o3-pro" +const ResponsesModelO3Pro = shared.ResponsesModelO3Pro + +// Equals "o3-pro-2025-06-10" +const ResponsesModelO3Pro2025_06_10 = shared.ResponsesModelO3Pro2025_06_10 + +// Equals "o3-deep-research" +const ResponsesModelO3DeepResearch = shared.ResponsesModelO3DeepResearch + +// Equals "o3-deep-research-2025-06-26" +const ResponsesModelO3DeepResearch2025_06_26 = shared.ResponsesModelO3DeepResearch2025_06_26 + +// Equals "o4-mini-deep-research" +const ResponsesModelO4MiniDeepResearch = shared.ResponsesModelO4MiniDeepResearch + +// Equals "o4-mini-deep-research-2025-06-26" +const ResponsesModelO4MiniDeepResearch2025_06_26 = shared.ResponsesModelO4MiniDeepResearch2025_06_26 + +// Equals "computer-use-preview" +const ResponsesModelComputerUsePreview = shared.ResponsesModelComputerUsePreview + +// Equals "computer-use-preview-2025-03-11" +const ResponsesModelComputerUsePreview2025_03_11 = shared.ResponsesModelComputerUsePreview2025_03_11 diff --git a/vendor/github.com/openai/openai-go/api.md b/vendor/github.com/openai/openai-go/api.md new file mode 100644 index 0000000000..841a08913b --- /dev/null +++ b/vendor/github.com/openai/openai-go/api.md @@ -0,0 +1,791 @@ +# Shared Params Types + +- shared.ChatModel +- shared.ComparisonFilterParam +- shared.CompoundFilterParam +- shared.FunctionDefinitionParam +- shared.FunctionParameters +- shared.Metadata +- shared.ReasoningParam +- shared.ReasoningEffort +- shared.ResponseFormatJSONObjectParam +- shared.ResponseFormatJSONSchemaParam +- shared.ResponseFormatTextParam +- shared.ResponsesModel + +# Shared Response Types + +- shared.ChatModel +- shared.ComparisonFilter +- 
shared.CompoundFilter +- shared.ErrorObject +- shared.FunctionDefinition +- shared.FunctionParameters +- shared.Metadata +- shared.Reasoning +- shared.ReasoningEffort +- shared.ResponseFormatJSONObject +- shared.ResponseFormatJSONSchema +- shared.ResponseFormatText +- shared.ResponsesModel + +# Completions + +Response Types: + +- openai.Completion +- openai.CompletionChoice +- openai.CompletionUsage + +Methods: + +- client.Completions.New(ctx context.Context, body openai.CompletionNewParams) (openai.Completion, error) + +# Chat + +## Completions + +Params Types: + +- openai.ChatCompletionAssistantMessageParam +- openai.ChatCompletionAudioParam +- openai.ChatCompletionContentPartUnionParam +- openai.ChatCompletionContentPartImageParam +- openai.ChatCompletionContentPartInputAudioParam +- openai.ChatCompletionContentPartRefusalParam +- openai.ChatCompletionContentPartTextParam +- openai.ChatCompletionDeveloperMessageParam +- openai.ChatCompletionFunctionCallOptionParam +- openai.ChatCompletionFunctionMessageParam +- openai.ChatCompletionMessageParamUnion +- openai.ChatCompletionMessageToolCallParam +- openai.ChatCompletionNamedToolChoiceParam +- openai.ChatCompletionPredictionContentParam +- openai.ChatCompletionStreamOptionsParam +- openai.ChatCompletionSystemMessageParam +- openai.ChatCompletionToolParam +- openai.ChatCompletionToolChoiceOptionUnionParam +- openai.ChatCompletionToolMessageParam +- openai.ChatCompletionUserMessageParam + +Response Types: + +- openai.ChatCompletion +- openai.ChatCompletionAudio +- openai.ChatCompletionChunk +- openai.ChatCompletionContentPartImage +- openai.ChatCompletionContentPartText +- openai.ChatCompletionDeleted +- openai.ChatCompletionMessage +- openai.ChatCompletionMessageToolCall +- openai.ChatCompletionStoreMessage +- openai.ChatCompletionTokenLogprob + +Methods: + +- client.Chat.Completions.New(ctx context.Context, body openai.ChatCompletionNewParams) (openai.ChatCompletion, error) +- client.Chat.Completions.Get(ctx 
context.Context, completionID string) (openai.ChatCompletion, error) +- client.Chat.Completions.Update(ctx context.Context, completionID string, body openai.ChatCompletionUpdateParams) (openai.ChatCompletion, error) +- client.Chat.Completions.List(ctx context.Context, query openai.ChatCompletionListParams) (pagination.CursorPage[openai.ChatCompletion], error) +- client.Chat.Completions.Delete(ctx context.Context, completionID string) (openai.ChatCompletionDeleted, error) + +### Messages + +Methods: + +- client.Chat.Completions.Messages.List(ctx context.Context, completionID string, query openai.ChatCompletionMessageListParams) (pagination.CursorPage[openai.ChatCompletionStoreMessage], error) + +# Embeddings + +Params Types: + +- openai.EmbeddingModel + +Response Types: + +- openai.CreateEmbeddingResponse +- openai.Embedding + +Methods: + +- client.Embeddings.New(ctx context.Context, body openai.EmbeddingNewParams) (openai.CreateEmbeddingResponse, error) + +# Files + +Params Types: + +- openai.FilePurpose + +Response Types: + +- openai.FileDeleted +- openai.FileObject + +Methods: + +- client.Files.New(ctx context.Context, body openai.FileNewParams) (openai.FileObject, error) +- client.Files.Get(ctx context.Context, fileID string) (openai.FileObject, error) +- client.Files.List(ctx context.Context, query openai.FileListParams) (pagination.CursorPage[openai.FileObject], error) +- client.Files.Delete(ctx context.Context, fileID string) (openai.FileDeleted, error) +- client.Files.Content(ctx context.Context, fileID string) (http.Response, error) + +# Images + +Params Types: + +- openai.ImageModel + +Response Types: + +- openai.Image +- openai.ImageEditCompletedEvent +- openai.ImageEditPartialImageEvent +- openai.ImageEditStreamEventUnion +- openai.ImageGenCompletedEvent +- openai.ImageGenPartialImageEvent +- openai.ImageGenStreamEventUnion +- openai.ImagesResponse + +Methods: + +- client.Images.NewVariation(ctx context.Context, body openai.ImageNewVariationParams) 
(openai.ImagesResponse, error) +- client.Images.Edit(ctx context.Context, body openai.ImageEditParams) (openai.ImagesResponse, error) +- client.Images.Generate(ctx context.Context, body openai.ImageGenerateParams) (openai.ImagesResponse, error) + +# Audio + +Params Types: + +- openai.AudioModel +- openai.AudioResponseFormat + +## Transcriptions + +Params Types: + +- openai.TranscriptionInclude + +Response Types: + +- openai.Transcription +- openai.TranscriptionStreamEventUnion +- openai.TranscriptionTextDeltaEvent +- openai.TranscriptionTextDoneEvent + +Methods: + +- client.Audio.Transcriptions.New(ctx context.Context, body openai.AudioTranscriptionNewParams) (Transcription, error) + +## Translations + +Response Types: + +- openai.Translation + +Methods: + +- client.Audio.Translations.New(ctx context.Context, body openai.AudioTranslationNewParams) (Translation, error) + +## Speech + +Params Types: + +- openai.SpeechModel + +Methods: + +- client.Audio.Speech.New(ctx context.Context, body openai.AudioSpeechNewParams) (http.Response, error) + +# Moderations + +Params Types: + +- openai.ModerationImageURLInputParam +- openai.ModerationModel +- openai.ModerationMultiModalInputUnionParam +- openai.ModerationTextInputParam + +Response Types: + +- openai.Moderation +- openai.ModerationNewResponse + +Methods: + +- client.Moderations.New(ctx context.Context, body openai.ModerationNewParams) (openai.ModerationNewResponse, error) + +# Models + +Response Types: + +- openai.Model +- openai.ModelDeleted + +Methods: + +- client.Models.Get(ctx context.Context, model string) (openai.Model, error) +- client.Models.List(ctx context.Context) (pagination.Page[openai.Model], error) +- client.Models.Delete(ctx context.Context, model string) (openai.ModelDeleted, error) + +# FineTuning + +## Methods + +Params Types: + +- openai.DpoHyperparameters +- openai.DpoMethodParam +- openai.ReinforcementHyperparameters +- openai.ReinforcementMethodParam +- openai.SupervisedHyperparameters +- 
openai.SupervisedMethodParam + +Response Types: + +- openai.DpoHyperparametersResp +- openai.DpoMethod +- openai.ReinforcementHyperparametersResp +- openai.ReinforcementMethod +- openai.SupervisedHyperparametersResp +- openai.SupervisedMethod + +## Jobs + +Response Types: + +- openai.FineTuningJob +- openai.FineTuningJobEvent +- openai.FineTuningJobWandbIntegration +- openai.FineTuningJobWandbIntegrationObject + +Methods: + +- client.FineTuning.Jobs.New(ctx context.Context, body openai.FineTuningJobNewParams) (openai.FineTuningJob, error) +- client.FineTuning.Jobs.Get(ctx context.Context, fineTuningJobID string) (openai.FineTuningJob, error) +- client.FineTuning.Jobs.List(ctx context.Context, query openai.FineTuningJobListParams) (pagination.CursorPage[openai.FineTuningJob], error) +- client.FineTuning.Jobs.Cancel(ctx context.Context, fineTuningJobID string) (openai.FineTuningJob, error) +- client.FineTuning.Jobs.ListEvents(ctx context.Context, fineTuningJobID string, query openai.FineTuningJobListEventsParams) (pagination.CursorPage[openai.FineTuningJobEvent], error) +- client.FineTuning.Jobs.Pause(ctx context.Context, fineTuningJobID string) (openai.FineTuningJob, error) +- client.FineTuning.Jobs.Resume(ctx context.Context, fineTuningJobID string) (openai.FineTuningJob, error) + +### Checkpoints + +Response Types: + +- openai.FineTuningJobCheckpoint + +Methods: + +- client.FineTuning.Jobs.Checkpoints.List(ctx context.Context, fineTuningJobID string, query openai.FineTuningJobCheckpointListParams) (pagination.CursorPage[openai.FineTuningJobCheckpoint], error) + +## Checkpoints + +### Permissions + +Response Types: + +- openai.FineTuningCheckpointPermissionNewResponse +- openai.FineTuningCheckpointPermissionGetResponse +- openai.FineTuningCheckpointPermissionDeleteResponse + +Methods: + +- client.FineTuning.Checkpoints.Permissions.New(ctx context.Context, fineTunedModelCheckpoint string, body openai.FineTuningCheckpointPermissionNewParams) 
(pagination.Page[openai.FineTuningCheckpointPermissionNewResponse], error) +- client.FineTuning.Checkpoints.Permissions.Get(ctx context.Context, fineTunedModelCheckpoint string, query openai.FineTuningCheckpointPermissionGetParams) (openai.FineTuningCheckpointPermissionGetResponse, error) +- client.FineTuning.Checkpoints.Permissions.Delete(ctx context.Context, fineTunedModelCheckpoint string, permissionID string) (openai.FineTuningCheckpointPermissionDeleteResponse, error) + +## Alpha + +### Graders + +Response Types: + +- openai.FineTuningAlphaGraderRunResponse +- openai.FineTuningAlphaGraderValidateResponse + +Methods: + +- client.FineTuning.Alpha.Graders.Run(ctx context.Context, body openai.FineTuningAlphaGraderRunParams) (openai.FineTuningAlphaGraderRunResponse, error) +- client.FineTuning.Alpha.Graders.Validate(ctx context.Context, body openai.FineTuningAlphaGraderValidateParams) (openai.FineTuningAlphaGraderValidateResponse, error) + +# Graders + +## GraderModels + +Params Types: + +- openai.LabelModelGraderParam +- openai.MultiGraderParam +- openai.PythonGraderParam +- openai.ScoreModelGraderParam +- openai.StringCheckGraderParam +- openai.TextSimilarityGraderParam + +Response Types: + +- openai.LabelModelGrader +- openai.MultiGrader +- openai.PythonGrader +- openai.ScoreModelGrader +- openai.StringCheckGrader +- openai.TextSimilarityGrader + +# VectorStores + +Params Types: + +- openai.AutoFileChunkingStrategyParam +- openai.FileChunkingStrategyParamUnion +- openai.StaticFileChunkingStrategyParam +- openai.StaticFileChunkingStrategyObjectParam + +Response Types: + +- openai.FileChunkingStrategyUnion +- openai.OtherFileChunkingStrategyObject +- openai.StaticFileChunkingStrategy +- openai.StaticFileChunkingStrategyObject +- openai.VectorStore +- openai.VectorStoreDeleted +- openai.VectorStoreSearchResponse + +Methods: + +- client.VectorStores.New(ctx context.Context, body openai.VectorStoreNewParams) (openai.VectorStore, error) +- client.VectorStores.Get(ctx 
context.Context, vectorStoreID string) (openai.VectorStore, error) +- client.VectorStores.Update(ctx context.Context, vectorStoreID string, body openai.VectorStoreUpdateParams) (openai.VectorStore, error) +- client.VectorStores.List(ctx context.Context, query openai.VectorStoreListParams) (pagination.CursorPage[openai.VectorStore], error) +- client.VectorStores.Delete(ctx context.Context, vectorStoreID string) (openai.VectorStoreDeleted, error) +- client.VectorStores.Search(ctx context.Context, vectorStoreID string, body openai.VectorStoreSearchParams) (pagination.Page[openai.VectorStoreSearchResponse], error) + +## Files + +Response Types: + +- openai.VectorStoreFile +- openai.VectorStoreFileDeleted +- openai.VectorStoreFileContentResponse + +Methods: + +- client.VectorStores.Files.New(ctx context.Context, vectorStoreID string, body openai.VectorStoreFileNewParams) (openai.VectorStoreFile, error) +- client.VectorStores.Files.Get(ctx context.Context, vectorStoreID string, fileID string) (openai.VectorStoreFile, error) +- client.VectorStores.Files.Update(ctx context.Context, vectorStoreID string, fileID string, body openai.VectorStoreFileUpdateParams) (openai.VectorStoreFile, error) +- client.VectorStores.Files.List(ctx context.Context, vectorStoreID string, query openai.VectorStoreFileListParams) (pagination.CursorPage[openai.VectorStoreFile], error) +- client.VectorStores.Files.Delete(ctx context.Context, vectorStoreID string, fileID string) (openai.VectorStoreFileDeleted, error) +- client.VectorStores.Files.Content(ctx context.Context, vectorStoreID string, fileID string) (pagination.Page[openai.VectorStoreFileContentResponse], error) + +## FileBatches + +Response Types: + +- openai.VectorStoreFileBatch + +Methods: + +- client.VectorStores.FileBatches.New(ctx context.Context, vectorStoreID string, body openai.VectorStoreFileBatchNewParams) (openai.VectorStoreFileBatch, error) +- client.VectorStores.FileBatches.Get(ctx context.Context, vectorStoreID string, 
batchID string) (openai.VectorStoreFileBatch, error) +- client.VectorStores.FileBatches.Cancel(ctx context.Context, vectorStoreID string, batchID string) (openai.VectorStoreFileBatch, error) +- client.VectorStores.FileBatches.ListFiles(ctx context.Context, vectorStoreID string, batchID string, query openai.VectorStoreFileBatchListFilesParams) (pagination.CursorPage[openai.VectorStoreFile], error) + +# Webhooks + +Response Types: + +- webhooks.BatchCancelledWebhookEvent +- webhooks.BatchCompletedWebhookEvent +- webhooks.BatchExpiredWebhookEvent +- webhooks.BatchFailedWebhookEvent +- webhooks.EvalRunCanceledWebhookEvent +- webhooks.EvalRunFailedWebhookEvent +- webhooks.EvalRunSucceededWebhookEvent +- webhooks.FineTuningJobCancelledWebhookEvent +- webhooks.FineTuningJobFailedWebhookEvent +- webhooks.FineTuningJobSucceededWebhookEvent +- webhooks.ResponseCancelledWebhookEvent +- webhooks.ResponseCompletedWebhookEvent +- webhooks.ResponseFailedWebhookEvent +- webhooks.ResponseIncompleteWebhookEvent +- webhooks.UnwrapWebhookEventUnion + +Methods: + +- client.Webhooks.Unwrap(body []byte, headers http.Header, opts ...option.RequestOption) (*webhooks.UnwrapWebhookEventUnion, error) +- client.Webhooks.UnwrapWithTolerance(body []byte, headers http.Header, tolerance time.Duration, opts ...option.RequestOption) (*webhooks.UnwrapWebhookEventUnion, error) +- client.Webhooks.UnwrapWithToleranceAndTime(body []byte, headers http.Header, tolerance time.Duration, now time.Time, opts ...option.RequestOption) (*webhooks.UnwrapWebhookEventUnion, error) +- client.Webhooks.VerifySignature(body []byte, headers http.Header, opts ...option.RequestOption) error +- client.Webhooks.VerifySignatureWithTolerance(body []byte, headers http.Header, tolerance time.Duration, opts ...option.RequestOption) error +- client.Webhooks.VerifySignatureWithToleranceAndTime(body []byte, headers http.Header, tolerance time.Duration, now time.Time, opts ...option.RequestOption) error + +# Beta + +## Assistants + 
+Params Types: + +- openai.AssistantToolUnionParam +- openai.CodeInterpreterToolParam +- openai.FileSearchToolParam +- openai.FunctionToolParam + +Response Types: + +- openai.Assistant +- openai.AssistantDeleted +- openai.AssistantStreamEventUnion +- openai.AssistantToolUnion +- openai.CodeInterpreterTool +- openai.FileSearchTool +- openai.FunctionTool + +Methods: + +- client.Beta.Assistants.New(ctx context.Context, body openai.BetaAssistantNewParams) (openai.Assistant, error) +- client.Beta.Assistants.Get(ctx context.Context, assistantID string) (openai.Assistant, error) +- client.Beta.Assistants.Update(ctx context.Context, assistantID string, body openai.BetaAssistantUpdateParams) (openai.Assistant, error) +- client.Beta.Assistants.List(ctx context.Context, query openai.BetaAssistantListParams) (pagination.CursorPage[openai.Assistant], error) +- client.Beta.Assistants.Delete(ctx context.Context, assistantID string) (openai.AssistantDeleted, error) + +## Threads + +Params Types: + +- openai.AssistantResponseFormatOptionUnionParam +- openai.AssistantToolChoiceParam +- openai.AssistantToolChoiceFunctionParam +- openai.AssistantToolChoiceOptionUnionParam + +Response Types: + +- openai.AssistantResponseFormatOptionUnion +- openai.AssistantToolChoice +- openai.AssistantToolChoiceFunction +- openai.AssistantToolChoiceOptionUnion +- openai.Thread +- openai.ThreadDeleted + +Methods: + +- client.Beta.Threads.New(ctx context.Context, body openai.BetaThreadNewParams) (openai.Thread, error) +- client.Beta.Threads.Get(ctx context.Context, threadID string) (openai.Thread, error) +- client.Beta.Threads.Update(ctx context.Context, threadID string, body openai.BetaThreadUpdateParams) (openai.Thread, error) +- client.Beta.Threads.Delete(ctx context.Context, threadID string) (openai.ThreadDeleted, error) +- client.Beta.Threads.NewAndRun(ctx context.Context, body openai.BetaThreadNewAndRunParams) (openai.Run, error) + +### Runs + +Response Types: + +- 
openai.RequiredActionFunctionToolCall +- openai.Run +- openai.RunStatus + +Methods: + +- client.Beta.Threads.Runs.New(ctx context.Context, threadID string, params openai.BetaThreadRunNewParams) (openai.Run, error) +- client.Beta.Threads.Runs.Get(ctx context.Context, threadID string, runID string) (openai.Run, error) +- client.Beta.Threads.Runs.Update(ctx context.Context, threadID string, runID string, body openai.BetaThreadRunUpdateParams) (openai.Run, error) +- client.Beta.Threads.Runs.List(ctx context.Context, threadID string, query openai.BetaThreadRunListParams) (pagination.CursorPage[openai.Run], error) +- client.Beta.Threads.Runs.Cancel(ctx context.Context, threadID string, runID string) (openai.Run, error) +- client.Beta.Threads.Runs.SubmitToolOutputs(ctx context.Context, threadID string, runID string, body openai.BetaThreadRunSubmitToolOutputsParams) (openai.Run, error) + +#### Steps + +Params Types: + +- openai.RunStepInclude + +Response Types: + +- openai.CodeInterpreterLogs +- openai.CodeInterpreterOutputImage +- openai.CodeInterpreterToolCall +- openai.CodeInterpreterToolCallDelta +- openai.FileSearchToolCall +- openai.FileSearchToolCallDelta +- openai.FunctionToolCall +- openai.FunctionToolCallDelta +- openai.MessageCreationStepDetails +- openai.RunStep +- openai.RunStepDelta +- openai.RunStepDeltaEvent +- openai.RunStepDeltaMessageDelta +- openai.ToolCallUnion +- openai.ToolCallDeltaUnion +- openai.ToolCallDeltaObject +- openai.ToolCallsStepDetails + +Methods: + +- client.Beta.Threads.Runs.Steps.Get(ctx context.Context, threadID string, runID string, stepID string, query openai.BetaThreadRunStepGetParams) (openai.RunStep, error) +- client.Beta.Threads.Runs.Steps.List(ctx context.Context, threadID string, runID string, query openai.BetaThreadRunStepListParams) (pagination.CursorPage[openai.RunStep], error) + +### Messages + +Params Types: + +- openai.ImageFileParam +- openai.ImageFileContentBlockParam +- openai.ImageURLParam +- 
openai.ImageURLContentBlockParam +- openai.MessageContentPartParamUnion +- openai.TextContentBlockParam + +Response Types: + +- openai.AnnotationUnion +- openai.AnnotationDeltaUnion +- openai.FileCitationAnnotation +- openai.FileCitationDeltaAnnotation +- openai.FilePathAnnotation +- openai.FilePathDeltaAnnotation +- openai.ImageFile +- openai.ImageFileContentBlock +- openai.ImageFileDelta +- openai.ImageFileDeltaBlock +- openai.ImageURL +- openai.ImageURLContentBlock +- openai.ImageURLDelta +- openai.ImageURLDeltaBlock +- openai.Message +- openai.MessageContentUnion +- openai.MessageContentDeltaUnion +- openai.MessageDeleted +- openai.MessageDelta +- openai.MessageDeltaEvent +- openai.RefusalContentBlock +- openai.RefusalDeltaBlock +- openai.Text +- openai.TextContentBlock +- openai.TextDelta +- openai.TextDeltaBlock + +Methods: + +- client.Beta.Threads.Messages.New(ctx context.Context, threadID string, body openai.BetaThreadMessageNewParams) (openai.Message, error) +- client.Beta.Threads.Messages.Get(ctx context.Context, threadID string, messageID string) (openai.Message, error) +- client.Beta.Threads.Messages.Update(ctx context.Context, threadID string, messageID string, body openai.BetaThreadMessageUpdateParams) (openai.Message, error) +- client.Beta.Threads.Messages.List(ctx context.Context, threadID string, query openai.BetaThreadMessageListParams) (pagination.CursorPage[openai.Message], error) +- client.Beta.Threads.Messages.Delete(ctx context.Context, threadID string, messageID string) (openai.MessageDeleted, error) + +# Batches + +Response Types: + +- openai.Batch +- openai.BatchError +- openai.BatchRequestCounts + +Methods: + +- client.Batches.New(ctx context.Context, body openai.BatchNewParams) (openai.Batch, error) +- client.Batches.Get(ctx context.Context, batchID string) (openai.Batch, error) +- client.Batches.List(ctx context.Context, query openai.BatchListParams) (pagination.CursorPage[openai.Batch], error) +- client.Batches.Cancel(ctx 
context.Context, batchID string) (openai.Batch, error) + +# Uploads + +Response Types: + +- openai.Upload + +Methods: + +- client.Uploads.New(ctx context.Context, body openai.UploadNewParams) (openai.Upload, error) +- client.Uploads.Cancel(ctx context.Context, uploadID string) (openai.Upload, error) +- client.Uploads.Complete(ctx context.Context, uploadID string, body openai.UploadCompleteParams) (openai.Upload, error) + +## Parts + +Response Types: + +- openai.UploadPart + +Methods: + +- client.Uploads.Parts.New(ctx context.Context, uploadID string, body openai.UploadPartNewParams) (openai.UploadPart, error) + +# Responses + +Params Types: + +- responses.ComputerToolParam +- responses.EasyInputMessageParam +- responses.FileSearchToolParam +- responses.FunctionToolParam +- responses.ResponseCodeInterpreterToolCallParam +- responses.ResponseComputerToolCallParam +- responses.ResponseComputerToolCallOutputScreenshotParam +- responses.ResponseFileSearchToolCallParam +- responses.ResponseFormatTextConfigUnionParam +- responses.ResponseFormatTextJSONSchemaConfigParam +- responses.ResponseFunctionToolCallParam +- responses.ResponseFunctionWebSearchParam +- responses.ResponseIncludable +- responses.ResponseInputParam +- responses.ResponseInputContentUnionParam +- responses.ResponseInputFileParam +- responses.ResponseInputImageParam +- responses.ResponseInputItemUnionParam +- responses.ResponseInputMessageContentListParam +- responses.ResponseInputTextParam +- responses.ResponseOutputMessageParam +- responses.ResponseOutputRefusalParam +- responses.ResponseOutputTextParam +- responses.ResponsePromptParam +- responses.ResponseReasoningItemParam +- responses.ResponseTextConfigParam +- responses.ToolUnionParam +- responses.ToolChoiceFunctionParam +- responses.ToolChoiceMcpParam +- responses.ToolChoiceOptions +- responses.ToolChoiceTypesParam +- responses.WebSearchToolParam + +Response Types: + +- responses.ComputerTool +- responses.EasyInputMessage +- responses.FileSearchTool 
+- responses.FunctionTool +- responses.Response +- responses.ResponseAudioDeltaEvent +- responses.ResponseAudioDoneEvent +- responses.ResponseAudioTranscriptDeltaEvent +- responses.ResponseAudioTranscriptDoneEvent +- responses.ResponseCodeInterpreterCallCodeDeltaEvent +- responses.ResponseCodeInterpreterCallCodeDoneEvent +- responses.ResponseCodeInterpreterCallCompletedEvent +- responses.ResponseCodeInterpreterCallInProgressEvent +- responses.ResponseCodeInterpreterCallInterpretingEvent +- responses.ResponseCodeInterpreterToolCall +- responses.ResponseCompletedEvent +- responses.ResponseComputerToolCall +- responses.ResponseComputerToolCallOutputItem +- responses.ResponseComputerToolCallOutputScreenshot +- responses.ResponseContentPartAddedEvent +- responses.ResponseContentPartDoneEvent +- responses.ResponseCreatedEvent +- responses.ResponseError +- responses.ResponseErrorEvent +- responses.ResponseFailedEvent +- responses.ResponseFileSearchCallCompletedEvent +- responses.ResponseFileSearchCallInProgressEvent +- responses.ResponseFileSearchCallSearchingEvent +- responses.ResponseFileSearchToolCall +- responses.ResponseFormatTextConfigUnion +- responses.ResponseFormatTextJSONSchemaConfig +- responses.ResponseFunctionCallArgumentsDeltaEvent +- responses.ResponseFunctionCallArgumentsDoneEvent +- responses.ResponseFunctionToolCall +- responses.ResponseFunctionToolCallItem +- responses.ResponseFunctionToolCallOutputItem +- responses.ResponseFunctionWebSearch +- responses.ResponseImageGenCallCompletedEvent +- responses.ResponseImageGenCallGeneratingEvent +- responses.ResponseImageGenCallInProgressEvent +- responses.ResponseImageGenCallPartialImageEvent +- responses.ResponseInProgressEvent +- responses.ResponseIncompleteEvent +- responses.ResponseInputContentUnion +- responses.ResponseInputFile +- responses.ResponseInputImage +- responses.ResponseInputItemUnion +- responses.ResponseInputMessageContentList +- responses.ResponseInputMessageItem +- 
responses.ResponseInputText +- responses.ResponseItemUnion +- responses.ResponseMcpCallArgumentsDeltaEvent +- responses.ResponseMcpCallArgumentsDoneEvent +- responses.ResponseMcpCallCompletedEvent +- responses.ResponseMcpCallFailedEvent +- responses.ResponseMcpCallInProgressEvent +- responses.ResponseMcpListToolsCompletedEvent +- responses.ResponseMcpListToolsFailedEvent +- responses.ResponseMcpListToolsInProgressEvent +- responses.ResponseOutputItemUnion +- responses.ResponseOutputItemAddedEvent +- responses.ResponseOutputItemDoneEvent +- responses.ResponseOutputMessage +- responses.ResponseOutputRefusal +- responses.ResponseOutputText +- responses.ResponseOutputTextAnnotationAddedEvent +- responses.ResponsePrompt +- responses.ResponseQueuedEvent +- responses.ResponseReasoningItem +- responses.ResponseReasoningSummaryDeltaEvent +- responses.ResponseReasoningSummaryDoneEvent +- responses.ResponseReasoningSummaryPartAddedEvent +- responses.ResponseReasoningSummaryPartDoneEvent +- responses.ResponseReasoningSummaryTextDeltaEvent +- responses.ResponseReasoningSummaryTextDoneEvent +- responses.ResponseRefusalDeltaEvent +- responses.ResponseRefusalDoneEvent +- responses.ResponseStatus +- responses.ResponseStreamEventUnion +- responses.ResponseTextConfig +- responses.ResponseTextDeltaEvent +- responses.ResponseTextDoneEvent +- responses.ResponseUsage +- responses.ResponseWebSearchCallCompletedEvent +- responses.ResponseWebSearchCallInProgressEvent +- responses.ResponseWebSearchCallSearchingEvent +- responses.ToolUnion +- responses.ToolChoiceFunction +- responses.ToolChoiceMcp +- responses.ToolChoiceOptions +- responses.ToolChoiceTypes +- responses.WebSearchTool + +Methods: + +- client.Responses.New(ctx context.Context, body responses.ResponseNewParams) (responses.Response, error) +- client.Responses.Get(ctx context.Context, responseID string, query responses.ResponseGetParams) (responses.Response, error) +- client.Responses.Delete(ctx context.Context, responseID string) 
error +- client.Responses.Cancel(ctx context.Context, responseID string) (responses.Response, error) + +## InputItems + +Response Types: + +- responses.ResponseItemList + +Methods: + +- client.Responses.InputItems.List(ctx context.Context, responseID string, query responses.InputItemListParams) (pagination.CursorPage[responses.ResponseItemUnion], error) + +# Containers + +Response Types: + +- openai.ContainerNewResponse +- openai.ContainerGetResponse +- openai.ContainerListResponse + +Methods: + +- client.Containers.New(ctx context.Context, body openai.ContainerNewParams) (openai.ContainerNewResponse, error) +- client.Containers.Get(ctx context.Context, containerID string) (openai.ContainerGetResponse, error) +- client.Containers.List(ctx context.Context, query openai.ContainerListParams) (pagination.CursorPage[openai.ContainerListResponse], error) +- client.Containers.Delete(ctx context.Context, containerID string) error + +## Files + +Response Types: + +- openai.ContainerFileNewResponse +- openai.ContainerFileGetResponse +- openai.ContainerFileListResponse + +Methods: + +- client.Containers.Files.New(ctx context.Context, containerID string, body openai.ContainerFileNewParams) (openai.ContainerFileNewResponse, error) +- client.Containers.Files.Get(ctx context.Context, containerID string, fileID string) (openai.ContainerFileGetResponse, error) +- client.Containers.Files.List(ctx context.Context, containerID string, query openai.ContainerFileListParams) (pagination.CursorPage[openai.ContainerFileListResponse], error) +- client.Containers.Files.Delete(ctx context.Context, containerID string, fileID string) error + +### Content + +Methods: + +- client.Containers.Files.Content.Get(ctx context.Context, containerID string, fileID string) (http.Response, error) diff --git a/vendor/github.com/openai/openai-go/audio.go b/vendor/github.com/openai/openai-go/audio.go new file mode 100644 index 0000000000..9cd3e19dcf --- /dev/null +++ 
b/vendor/github.com/openai/openai-go/audio.go @@ -0,0 +1,53 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// AudioService contains methods and other services that help with interacting with +// the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewAudioService] method instead. +type AudioService struct { + Options []option.RequestOption + Transcriptions AudioTranscriptionService + Translations AudioTranslationService + Speech AudioSpeechService +} + +// NewAudioService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewAudioService(opts ...option.RequestOption) (r AudioService) { + r = AudioService{} + r.Options = opts + r.Transcriptions = NewAudioTranscriptionService(opts...) + r.Translations = NewAudioTranslationService(opts...) + r.Speech = NewAudioSpeechService(opts...) + return +} + +type AudioModel = string + +const ( + AudioModelWhisper1 AudioModel = "whisper-1" + AudioModelGPT4oTranscribe AudioModel = "gpt-4o-transcribe" + AudioModelGPT4oMiniTranscribe AudioModel = "gpt-4o-mini-transcribe" +) + +// The format of the output, in one of these options: `json`, `text`, `srt`, +// `verbose_json`, or `vtt`. For `gpt-4o-transcribe` and `gpt-4o-mini-transcribe`, +// the only supported format is `json`. 
+type AudioResponseFormat string + +const ( + AudioResponseFormatJSON AudioResponseFormat = "json" + AudioResponseFormatText AudioResponseFormat = "text" + AudioResponseFormatSRT AudioResponseFormat = "srt" + AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" + AudioResponseFormatVTT AudioResponseFormat = "vtt" +) diff --git a/vendor/github.com/openai/openai-go/audiospeech.go b/vendor/github.com/openai/openai-go/audiospeech.go new file mode 100644 index 0000000000..8adc81a68e --- /dev/null +++ b/vendor/github.com/openai/openai-go/audiospeech.go @@ -0,0 +1,126 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "net/http" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" +) + +// AudioSpeechService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewAudioSpeechService] method instead. +type AudioSpeechService struct { + Options []option.RequestOption +} + +// NewAudioSpeechService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewAudioSpeechService(opts ...option.RequestOption) (r AudioSpeechService) { + r = AudioSpeechService{} + r.Options = opts + return +} + +// Generates audio from the input text. +func (r *AudioSpeechService) New(ctx context.Context, body AudioSpeechNewParams, opts ...option.RequestOption) (res *http.Response, err error) { + opts = append(r.Options[:], opts...) 
+ opts = append([]option.RequestOption{option.WithHeader("Accept", "application/octet-stream")}, opts...) + path := "audio/speech" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +type SpeechModel = string + +const ( + SpeechModelTTS1 SpeechModel = "tts-1" + SpeechModelTTS1HD SpeechModel = "tts-1-hd" + SpeechModelGPT4oMiniTTS SpeechModel = "gpt-4o-mini-tts" +) + +type AudioSpeechNewParams struct { + // The text to generate audio for. The maximum length is 4096 characters. + Input string `json:"input,required"` + // One of the available [TTS models](https://platform.openai.com/docs/models#tts): + // `tts-1`, `tts-1-hd` or `gpt-4o-mini-tts`. + Model SpeechModel `json:"model,omitzero,required"` + // The voice to use when generating the audio. Supported voices are `alloy`, `ash`, + // `ballad`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`, `shimmer`, and + // `verse`. Previews of the voices are available in the + // [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options). + Voice AudioSpeechNewParamsVoice `json:"voice,omitzero,required"` + // Control the voice of your generated audio with additional instructions. Does not + // work with `tts-1` or `tts-1-hd`. + Instructions param.Opt[string] `json:"instructions,omitzero"` + // The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is + // the default. + Speed param.Opt[float64] `json:"speed,omitzero"` + // The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, + // `wav`, and `pcm`. + // + // Any of "mp3", "opus", "aac", "flac", "wav", "pcm". + ResponseFormat AudioSpeechNewParamsResponseFormat `json:"response_format,omitzero"` + // The format to stream the audio in. Supported formats are `sse` and `audio`. + // `sse` is not supported for `tts-1` or `tts-1-hd`. + // + // Any of "sse", "audio". 
+ StreamFormat AudioSpeechNewParamsStreamFormat `json:"stream_format,omitzero"` + paramObj +} + +func (r AudioSpeechNewParams) MarshalJSON() (data []byte, err error) { + type shadow AudioSpeechNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *AudioSpeechNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The voice to use when generating the audio. Supported voices are `alloy`, `ash`, +// `ballad`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`, `shimmer`, and +// `verse`. Previews of the voices are available in the +// [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options). +type AudioSpeechNewParamsVoice string + +const ( + AudioSpeechNewParamsVoiceAlloy AudioSpeechNewParamsVoice = "alloy" + AudioSpeechNewParamsVoiceAsh AudioSpeechNewParamsVoice = "ash" + AudioSpeechNewParamsVoiceBallad AudioSpeechNewParamsVoice = "ballad" + AudioSpeechNewParamsVoiceCoral AudioSpeechNewParamsVoice = "coral" + AudioSpeechNewParamsVoiceEcho AudioSpeechNewParamsVoice = "echo" + AudioSpeechNewParamsVoiceSage AudioSpeechNewParamsVoice = "sage" + AudioSpeechNewParamsVoiceShimmer AudioSpeechNewParamsVoice = "shimmer" + AudioSpeechNewParamsVoiceVerse AudioSpeechNewParamsVoice = "verse" +) + +// The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, +// `wav`, and `pcm`. +type AudioSpeechNewParamsResponseFormat string + +const ( + AudioSpeechNewParamsResponseFormatMP3 AudioSpeechNewParamsResponseFormat = "mp3" + AudioSpeechNewParamsResponseFormatOpus AudioSpeechNewParamsResponseFormat = "opus" + AudioSpeechNewParamsResponseFormatAAC AudioSpeechNewParamsResponseFormat = "aac" + AudioSpeechNewParamsResponseFormatFLAC AudioSpeechNewParamsResponseFormat = "flac" + AudioSpeechNewParamsResponseFormatWAV AudioSpeechNewParamsResponseFormat = "wav" + AudioSpeechNewParamsResponseFormatPCM AudioSpeechNewParamsResponseFormat = "pcm" +) + +// The format to stream the audio in. 
Supported formats are `sse` and `audio`. +// `sse` is not supported for `tts-1` or `tts-1-hd`. +type AudioSpeechNewParamsStreamFormat string + +const ( + AudioSpeechNewParamsStreamFormatSSE AudioSpeechNewParamsStreamFormat = "sse" + AudioSpeechNewParamsStreamFormatAudio AudioSpeechNewParamsStreamFormat = "audio" +) diff --git a/vendor/github.com/openai/openai-go/audiotranscription.go b/vendor/github.com/openai/openai-go/audiotranscription.go new file mode 100644 index 0000000000..a7a7913899 --- /dev/null +++ b/vendor/github.com/openai/openai-go/audiotranscription.go @@ -0,0 +1,654 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "bytes" + "context" + "encoding/json" + "io" + "mime/multipart" + "net/http" + + "github.com/openai/openai-go/internal/apiform" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/packages/ssestream" + "github.com/openai/openai-go/shared/constant" +) + +// AudioTranscriptionService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewAudioTranscriptionService] method instead. +type AudioTranscriptionService struct { + Options []option.RequestOption +} + +// NewAudioTranscriptionService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. 
+func NewAudioTranscriptionService(opts ...option.RequestOption) (r AudioTranscriptionService) { + r = AudioTranscriptionService{} + r.Options = opts + return +} + +// Transcribes audio into the input language. +func (r *AudioTranscriptionService) New(ctx context.Context, body AudioTranscriptionNewParams, opts ...option.RequestOption) (res *Transcription, err error) { + opts = append(r.Options[:], opts...) + path := "audio/transcriptions" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Transcribes audio into the input language. +func (r *AudioTranscriptionService) NewStreaming(ctx context.Context, body AudioTranscriptionNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[TranscriptionStreamEventUnion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + body.SetExtraFields(map[string]any{ + "stream": "true", + }) + path := "audio/transcriptions" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[TranscriptionStreamEventUnion](ssestream.NewDecoder(raw), err) +} + +// Represents a transcription response returned by model, based on the provided +// input. +type Transcription struct { + // The transcribed text. + Text string `json:"text,required"` + // The log probabilities of the tokens in the transcription. Only returned with the + // models `gpt-4o-transcribe` and `gpt-4o-mini-transcribe` if `logprobs` is added + // to the `include` array. + Logprobs []TranscriptionLogprob `json:"logprobs"` + // Token usage statistics for the request. + Usage TranscriptionUsageUnion `json:"usage"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Text respjson.Field + Logprobs respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Transcription) RawJSON() string { return r.JSON.raw } +func (r *Transcription) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type TranscriptionLogprob struct { + // The token in the transcription. + Token string `json:"token"` + // The bytes of the token. + Bytes []float64 `json:"bytes"` + // The log probability of the token. + Logprob float64 `json:"logprob"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Token respjson.Field + Bytes respjson.Field + Logprob respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionLogprob) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionLogprob) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// TranscriptionUsageUnion contains all possible properties and values from +// [TranscriptionUsageTokens], [TranscriptionUsageDuration]. +// +// Use the [TranscriptionUsageUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type TranscriptionUsageUnion struct { + // This field is from variant [TranscriptionUsageTokens]. + InputTokens int64 `json:"input_tokens"` + // This field is from variant [TranscriptionUsageTokens]. + OutputTokens int64 `json:"output_tokens"` + // This field is from variant [TranscriptionUsageTokens]. + TotalTokens int64 `json:"total_tokens"` + // Any of "tokens", "duration". + Type string `json:"type"` + // This field is from variant [TranscriptionUsageTokens]. 
+ InputTokenDetails TranscriptionUsageTokensInputTokenDetails `json:"input_token_details"` + // This field is from variant [TranscriptionUsageDuration]. + Seconds float64 `json:"seconds"` + JSON struct { + InputTokens respjson.Field + OutputTokens respjson.Field + TotalTokens respjson.Field + Type respjson.Field + InputTokenDetails respjson.Field + Seconds respjson.Field + raw string + } `json:"-"` +} + +// anyTranscriptionUsage is implemented by each variant of +// [TranscriptionUsageUnion] to add type safety for the return type of +// [TranscriptionUsageUnion.AsAny] +type anyTranscriptionUsage interface { + implTranscriptionUsageUnion() +} + +func (TranscriptionUsageTokens) implTranscriptionUsageUnion() {} +func (TranscriptionUsageDuration) implTranscriptionUsageUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := TranscriptionUsageUnion.AsAny().(type) { +// case openai.TranscriptionUsageTokens: +// case openai.TranscriptionUsageDuration: +// default: +// fmt.Errorf("no variant present") +// } +func (u TranscriptionUsageUnion) AsAny() anyTranscriptionUsage { + switch u.Type { + case "tokens": + return u.AsTokens() + case "duration": + return u.AsDuration() + } + return nil +} + +func (u TranscriptionUsageUnion) AsTokens() (v TranscriptionUsageTokens) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u TranscriptionUsageUnion) AsDuration() (v TranscriptionUsageDuration) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u TranscriptionUsageUnion) RawJSON() string { return u.JSON.raw } + +func (r *TranscriptionUsageUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Usage statistics for models billed by token usage. +type TranscriptionUsageTokens struct { + // Number of input tokens billed for this request. 
+ InputTokens int64 `json:"input_tokens,required"` + // Number of output tokens generated. + OutputTokens int64 `json:"output_tokens,required"` + // Total number of tokens used (input + output). + TotalTokens int64 `json:"total_tokens,required"` + // The type of the usage object. Always `tokens` for this variant. + Type constant.Tokens `json:"type,required"` + // Details about the input tokens billed for this request. + InputTokenDetails TranscriptionUsageTokensInputTokenDetails `json:"input_token_details"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + InputTokens respjson.Field + OutputTokens respjson.Field + TotalTokens respjson.Field + Type respjson.Field + InputTokenDetails respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionUsageTokens) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionUsageTokens) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details about the input tokens billed for this request. +type TranscriptionUsageTokensInputTokenDetails struct { + // Number of audio tokens billed for this request. + AudioTokens int64 `json:"audio_tokens"` + // Number of text tokens billed for this request. + TextTokens int64 `json:"text_tokens"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + AudioTokens respjson.Field + TextTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionUsageTokensInputTokenDetails) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionUsageTokensInputTokenDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Usage statistics for models billed by audio input duration. 
+type TranscriptionUsageDuration struct { + // Duration of the input audio in seconds. + Seconds float64 `json:"seconds,required"` + // The type of the usage object. Always `duration` for this variant. + Type constant.Duration `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Seconds respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionUsageDuration) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionUsageDuration) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type TranscriptionInclude string + +const ( + TranscriptionIncludeLogprobs TranscriptionInclude = "logprobs" +) + +// TranscriptionStreamEventUnion contains all possible properties and values from +// [TranscriptionTextDeltaEvent], [TranscriptionTextDoneEvent]. +// +// Use the [TranscriptionStreamEventUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type TranscriptionStreamEventUnion struct { + // This field is from variant [TranscriptionTextDeltaEvent]. + Delta string `json:"delta"` + // Any of "transcript.text.delta", "transcript.text.done". + Type string `json:"type"` + // This field is a union of [[]TranscriptionTextDeltaEventLogprob], + // [[]TranscriptionTextDoneEventLogprob] + Logprobs TranscriptionStreamEventUnionLogprobs `json:"logprobs"` + // This field is from variant [TranscriptionTextDoneEvent]. + Text string `json:"text"` + // This field is from variant [TranscriptionTextDoneEvent]. 
+ Usage TranscriptionTextDoneEventUsage `json:"usage"` + JSON struct { + Delta respjson.Field + Type respjson.Field + Logprobs respjson.Field + Text respjson.Field + Usage respjson.Field + raw string + } `json:"-"` +} + +// anyTranscriptionStreamEvent is implemented by each variant of +// [TranscriptionStreamEventUnion] to add type safety for the return type of +// [TranscriptionStreamEventUnion.AsAny] +type anyTranscriptionStreamEvent interface { + implTranscriptionStreamEventUnion() +} + +func (TranscriptionTextDeltaEvent) implTranscriptionStreamEventUnion() {} +func (TranscriptionTextDoneEvent) implTranscriptionStreamEventUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := TranscriptionStreamEventUnion.AsAny().(type) { +// case openai.TranscriptionTextDeltaEvent: +// case openai.TranscriptionTextDoneEvent: +// default: +// fmt.Errorf("no variant present") +// } +func (u TranscriptionStreamEventUnion) AsAny() anyTranscriptionStreamEvent { + switch u.Type { + case "transcript.text.delta": + return u.AsTranscriptTextDelta() + case "transcript.text.done": + return u.AsTranscriptTextDone() + } + return nil +} + +func (u TranscriptionStreamEventUnion) AsTranscriptTextDelta() (v TranscriptionTextDeltaEvent) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u TranscriptionStreamEventUnion) AsTranscriptTextDone() (v TranscriptionTextDoneEvent) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u TranscriptionStreamEventUnion) RawJSON() string { return u.JSON.raw } + +func (r *TranscriptionStreamEventUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// TranscriptionStreamEventUnionLogprobs is an implicit subunion of +// [TranscriptionStreamEventUnion]. TranscriptionStreamEventUnionLogprobs provides +// convenient access to the sub-properties of the union. 
+// +// For type safety it is recommended to directly use a variant of the +// [TranscriptionStreamEventUnion]. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfTranscriptionTextDeltaEventLogprobs +// OfTranscriptionTextDoneEventLogprobs] +type TranscriptionStreamEventUnionLogprobs struct { + // This field will be present if the value is a + // [[]TranscriptionTextDeltaEventLogprob] instead of an object. + OfTranscriptionTextDeltaEventLogprobs []TranscriptionTextDeltaEventLogprob `json:",inline"` + // This field will be present if the value is a + // [[]TranscriptionTextDoneEventLogprob] instead of an object. + OfTranscriptionTextDoneEventLogprobs []TranscriptionTextDoneEventLogprob `json:",inline"` + JSON struct { + OfTranscriptionTextDeltaEventLogprobs respjson.Field + OfTranscriptionTextDoneEventLogprobs respjson.Field + raw string + } `json:"-"` +} + +func (r *TranscriptionStreamEventUnionLogprobs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Emitted when there is an additional text delta. This is also the first event +// emitted when the transcription starts. Only emitted when you +// [create a transcription](https://platform.openai.com/docs/api-reference/audio/create-transcription) +// with the `Stream` parameter set to `true`. +type TranscriptionTextDeltaEvent struct { + // The text delta that was additionally transcribed. + Delta string `json:"delta,required"` + // The type of the event. Always `transcript.text.delta`. + Type constant.TranscriptTextDelta `json:"type,required"` + // The log probabilities of the delta. Only included if you + // [create a transcription](https://platform.openai.com/docs/api-reference/audio/create-transcription) + // with the `include[]` parameter set to `logprobs`. + Logprobs []TranscriptionTextDeltaEventLogprob `json:"logprobs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Delta respjson.Field + Type respjson.Field + Logprobs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionTextDeltaEvent) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionTextDeltaEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type TranscriptionTextDeltaEventLogprob struct { + // The token that was used to generate the log probability. + Token string `json:"token"` + // The bytes that were used to generate the log probability. + Bytes []int64 `json:"bytes"` + // The log probability of the token. + Logprob float64 `json:"logprob"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Token respjson.Field + Bytes respjson.Field + Logprob respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionTextDeltaEventLogprob) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionTextDeltaEventLogprob) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Emitted when the transcription is complete. Contains the complete transcription +// text. Only emitted when you +// [create a transcription](https://platform.openai.com/docs/api-reference/audio/create-transcription) +// with the `Stream` parameter set to `true`. +type TranscriptionTextDoneEvent struct { + // The text that was transcribed. + Text string `json:"text,required"` + // The type of the event. Always `transcript.text.done`. + Type constant.TranscriptTextDone `json:"type,required"` + // The log probabilities of the individual tokens in the transcription. Only + // included if you + // [create a transcription](https://platform.openai.com/docs/api-reference/audio/create-transcription) + // with the `include[]` parameter set to `logprobs`. 
+ Logprobs []TranscriptionTextDoneEventLogprob `json:"logprobs"` + // Usage statistics for models billed by token usage. + Usage TranscriptionTextDoneEventUsage `json:"usage"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Text respjson.Field + Type respjson.Field + Logprobs respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionTextDoneEvent) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionTextDoneEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type TranscriptionTextDoneEventLogprob struct { + // The token that was used to generate the log probability. + Token string `json:"token"` + // The bytes that were used to generate the log probability. + Bytes []int64 `json:"bytes"` + // The log probability of the token. + Logprob float64 `json:"logprob"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Token respjson.Field + Bytes respjson.Field + Logprob respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionTextDoneEventLogprob) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionTextDoneEventLogprob) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Usage statistics for models billed by token usage. +type TranscriptionTextDoneEventUsage struct { + // Number of input tokens billed for this request. + InputTokens int64 `json:"input_tokens,required"` + // Number of output tokens generated. + OutputTokens int64 `json:"output_tokens,required"` + // Total number of tokens used (input + output). + TotalTokens int64 `json:"total_tokens,required"` + // The type of the usage object. Always `tokens` for this variant. 
+ Type constant.Tokens `json:"type,required"` + // Details about the input tokens billed for this request. + InputTokenDetails TranscriptionTextDoneEventUsageInputTokenDetails `json:"input_token_details"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + InputTokens respjson.Field + OutputTokens respjson.Field + TotalTokens respjson.Field + Type respjson.Field + InputTokenDetails respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionTextDoneEventUsage) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionTextDoneEventUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details about the input tokens billed for this request. +type TranscriptionTextDoneEventUsageInputTokenDetails struct { + // Number of audio tokens billed for this request. + AudioTokens int64 `json:"audio_tokens"` + // Number of text tokens billed for this request. + TextTokens int64 `json:"text_tokens"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + AudioTokens respjson.Field + TextTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TranscriptionTextDoneEventUsageInputTokenDetails) RawJSON() string { return r.JSON.raw } +func (r *TranscriptionTextDoneEventUsageInputTokenDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type AudioTranscriptionNewParams struct { + // The audio file object (not file name) to transcribe, in one of these formats: + // flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + File io.Reader `json:"file,omitzero,required" format:"binary"` + // ID of the model to use. 
The options are `gpt-4o-transcribe`, + // `gpt-4o-mini-transcribe`, and `whisper-1` (which is powered by our open source + // Whisper V2 model). + Model AudioModel `json:"model,omitzero,required"` + // The language of the input audio. Supplying the input language in + // [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) (e.g. `en`) + // format will improve accuracy and latency. + Language param.Opt[string] `json:"language,omitzero"` + // An optional text to guide the model's style or continue a previous audio + // segment. The + // [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + // should match the audio language. + Prompt param.Opt[string] `json:"prompt,omitzero"` + // The sampling temperature, between 0 and 1. Higher values like 0.8 will make the + // output more random, while lower values like 0.2 will make it more focused and + // deterministic. If set to 0, the model will use + // [log probability](https://en.wikipedia.org/wiki/Log_probability) to + // automatically increase the temperature until certain thresholds are hit. + Temperature param.Opt[float64] `json:"temperature,omitzero"` + // Controls how the audio is cut into chunks. When set to `"auto"`, the server + // first normalizes loudness and then uses voice activity detection (VAD) to choose + // boundaries. `server_vad` object can be provided to tweak VAD detection + // parameters manually. If unset, the audio is transcribed as a single block. + ChunkingStrategy AudioTranscriptionNewParamsChunkingStrategyUnion `json:"chunking_strategy,omitzero"` + // Additional information to include in the transcription response. `logprobs` will + // return the log probabilities of the tokens in the response to understand the + // model's confidence in the transcription. `logprobs` only works with + // response_format set to `json` and only with the models `gpt-4o-transcribe` and + // `gpt-4o-mini-transcribe`. 
+ Include []TranscriptionInclude `json:"include,omitzero"` + // The format of the output, in one of these options: `json`, `text`, `srt`, + // `verbose_json`, or `vtt`. For `gpt-4o-transcribe` and `gpt-4o-mini-transcribe`, + // the only supported format is `json`. + // + // Any of "json", "text", "srt", "verbose_json", "vtt". + ResponseFormat AudioResponseFormat `json:"response_format,omitzero"` + // The timestamp granularities to populate for this transcription. + // `response_format` must be set `verbose_json` to use timestamp granularities. + // Either or both of these options are supported: `word`, or `segment`. Note: There + // is no additional latency for segment timestamps, but generating word timestamps + // incurs additional latency. + // + // Any of "word", "segment". + TimestampGranularities []string `json:"timestamp_granularities,omitzero"` + paramObj +} + +func (r AudioTranscriptionNewParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err == nil { + err = apiform.WriteExtras(writer, r.ExtraFields()) + } + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type AudioTranscriptionNewParamsChunkingStrategyUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfAudioTranscriptionNewsChunkingStrategyVadConfig *AudioTranscriptionNewParamsChunkingStrategyVadConfig `json:",omitzero,inline"` + paramUnion +} + +func (u AudioTranscriptionNewParamsChunkingStrategyUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfAudioTranscriptionNewsChunkingStrategyVadConfig) +} +func (u *AudioTranscriptionNewParamsChunkingStrategyUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *AudioTranscriptionNewParamsChunkingStrategyUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfAudioTranscriptionNewsChunkingStrategyVadConfig) { + return u.OfAudioTranscriptionNewsChunkingStrategyVadConfig + } + return nil +} + +// The property Type is required. +type AudioTranscriptionNewParamsChunkingStrategyVadConfig struct { + // Must be set to `server_vad` to enable manual chunking using server side VAD. + // + // Any of "server_vad". + Type string `json:"type,omitzero,required"` + // Amount of audio to include before the VAD detected speech (in milliseconds). + PrefixPaddingMs param.Opt[int64] `json:"prefix_padding_ms,omitzero"` + // Duration of silence to detect speech stop (in milliseconds). With shorter values + // the model will respond more quickly, but may jump in on short pauses from the + // user. + SilenceDurationMs param.Opt[int64] `json:"silence_duration_ms,omitzero"` + // Sensitivity threshold (0.0 to 1.0) for voice activity detection. A higher + // threshold will require louder audio to activate the model, and thus might + // perform better in noisy environments. 
+ Threshold param.Opt[float64] `json:"threshold,omitzero"` + paramObj +} + +func (r AudioTranscriptionNewParamsChunkingStrategyVadConfig) MarshalJSON() (data []byte, err error) { + type shadow AudioTranscriptionNewParamsChunkingStrategyVadConfig + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *AudioTranscriptionNewParamsChunkingStrategyVadConfig) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[AudioTranscriptionNewParamsChunkingStrategyVadConfig]( + "type", "server_vad", + ) +} diff --git a/vendor/github.com/openai/openai-go/audiotranslation.go b/vendor/github.com/openai/openai-go/audiotranslation.go new file mode 100644 index 0000000000..aa754e9469 --- /dev/null +++ b/vendor/github.com/openai/openai-go/audiotranslation.go @@ -0,0 +1,117 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + + "github.com/openai/openai-go/internal/apiform" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" +) + +// AudioTranslationService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewAudioTranslationService] method instead. +type AudioTranslationService struct { + Options []option.RequestOption +} + +// NewAudioTranslationService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. 
+func NewAudioTranslationService(opts ...option.RequestOption) (r AudioTranslationService) { + r = AudioTranslationService{} + r.Options = opts + return +} + +// Translates audio into English. +func (r *AudioTranslationService) New(ctx context.Context, body AudioTranslationNewParams, opts ...option.RequestOption) (res *Translation, err error) { + opts = append(r.Options[:], opts...) + path := "audio/translations" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +type Translation struct { + Text string `json:"text,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Text respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Translation) RawJSON() string { return r.JSON.raw } +func (r *Translation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type AudioTranslationNewParams struct { + // The audio file object (not file name) translate, in one of these formats: flac, + // mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + File io.Reader `json:"file,omitzero,required" format:"binary"` + // ID of the model to use. Only `whisper-1` (which is powered by our open source + // Whisper V2 model) is currently available. + Model AudioModel `json:"model,omitzero,required"` + // An optional text to guide the model's style or continue a previous audio + // segment. The + // [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + // should be in English. + Prompt param.Opt[string] `json:"prompt,omitzero"` + // The sampling temperature, between 0 and 1. Higher values like 0.8 will make the + // output more random, while lower values like 0.2 will make it more focused and + // deterministic. 
If set to 0, the model will use + // [log probability](https://en.wikipedia.org/wiki/Log_probability) to + // automatically increase the temperature until certain thresholds are hit. + Temperature param.Opt[float64] `json:"temperature,omitzero"` + // The format of the output, in one of these options: `json`, `text`, `srt`, + // `verbose_json`, or `vtt`. + // + // Any of "json", "text", "srt", "verbose_json", "vtt". + ResponseFormat AudioTranslationNewParamsResponseFormat `json:"response_format,omitzero"` + paramObj +} + +func (r AudioTranslationNewParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err == nil { + err = apiform.WriteExtras(writer, r.ExtraFields()) + } + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +// The format of the output, in one of these options: `json`, `text`, `srt`, +// `verbose_json`, or `vtt`. +type AudioTranslationNewParamsResponseFormat string + +const ( + AudioTranslationNewParamsResponseFormatJSON AudioTranslationNewParamsResponseFormat = "json" + AudioTranslationNewParamsResponseFormatText AudioTranslationNewParamsResponseFormat = "text" + AudioTranslationNewParamsResponseFormatSRT AudioTranslationNewParamsResponseFormat = "srt" + AudioTranslationNewParamsResponseFormatVerboseJSON AudioTranslationNewParamsResponseFormat = "verbose_json" + AudioTranslationNewParamsResponseFormatVTT AudioTranslationNewParamsResponseFormat = "vtt" +) diff --git a/vendor/github.com/openai/openai-go/batch.go b/vendor/github.com/openai/openai-go/batch.go new file mode 100644 index 0000000000..36e02ed372 --- /dev/null +++ b/vendor/github.com/openai/openai-go/batch.go @@ -0,0 +1,343 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// BatchService contains methods and other services that help with interacting with +// the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBatchService] method instead. +type BatchService struct { + Options []option.RequestOption +} + +// NewBatchService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewBatchService(opts ...option.RequestOption) (r BatchService) { + r = BatchService{} + r.Options = opts + return +} + +// Creates and executes a batch from an uploaded file of requests +func (r *BatchService) New(ctx context.Context, body BatchNewParams, opts ...option.RequestOption) (res *Batch, err error) { + opts = append(r.Options[:], opts...) + path := "batches" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Retrieves a batch. +func (r *BatchService) Get(ctx context.Context, batchID string, opts ...option.RequestOption) (res *Batch, err error) { + opts = append(r.Options[:], opts...) 
+ if batchID == "" { + err = errors.New("missing required batch_id parameter") + return + } + path := fmt.Sprintf("batches/%s", batchID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// List your organization's batches. +func (r *BatchService) List(ctx context.Context, query BatchListParams, opts ...option.RequestOption) (res *pagination.CursorPage[Batch], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + path := "batches" + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// List your organization's batches. +func (r *BatchService) ListAutoPaging(ctx context.Context, query BatchListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[Batch] { + return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...)) +} + +// Cancels an in-progress batch. The batch will be in status `cancelling` for up to +// 10 minutes, before changing to `cancelled`, where it will have partial results +// (if any) available in the output file. +func (r *BatchService) Cancel(ctx context.Context, batchID string, opts ...option.RequestOption) (res *Batch, err error) { + opts = append(r.Options[:], opts...) + if batchID == "" { + err = errors.New("missing required batch_id parameter") + return + } + path := fmt.Sprintf("batches/%s/cancel", batchID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return +} + +type Batch struct { + ID string `json:"id,required"` + // The time frame within which the batch should be processed. + CompletionWindow string `json:"completion_window,required"` + // The Unix timestamp (in seconds) for when the batch was created. 
+ CreatedAt int64 `json:"created_at,required"` + // The OpenAI API endpoint used by the batch. + Endpoint string `json:"endpoint,required"` + // The ID of the input file for the batch. + InputFileID string `json:"input_file_id,required"` + // The object type, which is always `batch`. + Object constant.Batch `json:"object,required"` + // The current status of the batch. + // + // Any of "validating", "failed", "in_progress", "finalizing", "completed", + // "expired", "cancelling", "cancelled". + Status BatchStatus `json:"status,required"` + // The Unix timestamp (in seconds) for when the batch was cancelled. + CancelledAt int64 `json:"cancelled_at"` + // The Unix timestamp (in seconds) for when the batch started cancelling. + CancellingAt int64 `json:"cancelling_at"` + // The Unix timestamp (in seconds) for when the batch was completed. + CompletedAt int64 `json:"completed_at"` + // The ID of the file containing the outputs of requests with errors. + ErrorFileID string `json:"error_file_id"` + Errors BatchErrors `json:"errors"` + // The Unix timestamp (in seconds) for when the batch expired. + ExpiredAt int64 `json:"expired_at"` + // The Unix timestamp (in seconds) for when the batch will expire. + ExpiresAt int64 `json:"expires_at"` + // The Unix timestamp (in seconds) for when the batch failed. + FailedAt int64 `json:"failed_at"` + // The Unix timestamp (in seconds) for when the batch started finalizing. + FinalizingAt int64 `json:"finalizing_at"` + // The Unix timestamp (in seconds) for when the batch started processing. + InProgressAt int64 `json:"in_progress_at"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. 
+ Metadata shared.Metadata `json:"metadata,nullable"` + // The ID of the file containing the outputs of successfully executed requests. + OutputFileID string `json:"output_file_id"` + // The request counts for different statuses within the batch. + RequestCounts BatchRequestCounts `json:"request_counts"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CompletionWindow respjson.Field + CreatedAt respjson.Field + Endpoint respjson.Field + InputFileID respjson.Field + Object respjson.Field + Status respjson.Field + CancelledAt respjson.Field + CancellingAt respjson.Field + CompletedAt respjson.Field + ErrorFileID respjson.Field + Errors respjson.Field + ExpiredAt respjson.Field + ExpiresAt respjson.Field + FailedAt respjson.Field + FinalizingAt respjson.Field + InProgressAt respjson.Field + Metadata respjson.Field + OutputFileID respjson.Field + RequestCounts respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Batch) RawJSON() string { return r.JSON.raw } +func (r *Batch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The current status of the batch. +type BatchStatus string + +const ( + BatchStatusValidating BatchStatus = "validating" + BatchStatusFailed BatchStatus = "failed" + BatchStatusInProgress BatchStatus = "in_progress" + BatchStatusFinalizing BatchStatus = "finalizing" + BatchStatusCompleted BatchStatus = "completed" + BatchStatusExpired BatchStatus = "expired" + BatchStatusCancelling BatchStatus = "cancelling" + BatchStatusCancelled BatchStatus = "cancelled" +) + +type BatchErrors struct { + Data []BatchError `json:"data"` + // The object type, which is always `list`. + Object string `json:"object"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Data respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r BatchErrors) RawJSON() string { return r.JSON.raw } +func (r *BatchErrors) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BatchError struct { + // An error code identifying the error type. + Code string `json:"code"` + // The line number of the input file where the error occurred, if applicable. + Line int64 `json:"line,nullable"` + // A human-readable message providing more details about the error. + Message string `json:"message"` + // The name of the parameter that caused the error, if applicable. + Param string `json:"param,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Code respjson.Field + Line respjson.Field + Message respjson.Field + Param respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r BatchError) RawJSON() string { return r.JSON.raw } +func (r *BatchError) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The request counts for different statuses within the batch. +type BatchRequestCounts struct { + // Number of requests that have been completed successfully. + Completed int64 `json:"completed,required"` + // Number of requests that have failed. + Failed int64 `json:"failed,required"` + // Total number of requests in the batch. + Total int64 `json:"total,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Completed respjson.Field + Failed respjson.Field + Total respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r BatchRequestCounts) RawJSON() string { return r.JSON.raw } +func (r *BatchRequestCounts) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BatchNewParams struct { + // The time frame within which the batch should be processed. Currently only `24h` + // is supported. + // + // Any of "24h". + CompletionWindow BatchNewParamsCompletionWindow `json:"completion_window,omitzero,required"` + // The endpoint to be used for all requests in the batch. Currently + // `/v1/responses`, `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` + // are supported. Note that `/v1/embeddings` batches are also restricted to a + // maximum of 50,000 embedding inputs across all requests in the batch. + // + // Any of "/v1/responses", "/v1/chat/completions", "/v1/embeddings", + // "/v1/completions". + Endpoint BatchNewParamsEndpoint `json:"endpoint,omitzero,required"` + // The ID of an uploaded file that contains requests for the new batch. + // + // See [upload file](https://platform.openai.com/docs/api-reference/files/create) + // for how to upload a file. + // + // Your input file must be formatted as a + // [JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input), + // and must be uploaded with the purpose `batch`. The file can contain up to 50,000 + // requests, and can be up to 200 MB in size. + InputFileID string `json:"input_file_id,required"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. 
+ Metadata shared.Metadata `json:"metadata,omitzero"` + paramObj +} + +func (r BatchNewParams) MarshalJSON() (data []byte, err error) { + type shadow BatchNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BatchNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The time frame within which the batch should be processed. Currently only `24h` +// is supported. +type BatchNewParamsCompletionWindow string + +const ( + BatchNewParamsCompletionWindow24h BatchNewParamsCompletionWindow = "24h" +) + +// The endpoint to be used for all requests in the batch. Currently +// `/v1/responses`, `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` +// are supported. Note that `/v1/embeddings` batches are also restricted to a +// maximum of 50,000 embedding inputs across all requests in the batch. +type BatchNewParamsEndpoint string + +const ( + BatchNewParamsEndpointV1Responses BatchNewParamsEndpoint = "/v1/responses" + BatchNewParamsEndpointV1ChatCompletions BatchNewParamsEndpoint = "/v1/chat/completions" + BatchNewParamsEndpointV1Embeddings BatchNewParamsEndpoint = "/v1/embeddings" + BatchNewParamsEndpointV1Completions BatchNewParamsEndpoint = "/v1/completions" +) + +type BatchListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 100, and the default is 20. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [BatchListParams]'s query parameters as `url.Values`. 
+func (r BatchListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} diff --git a/vendor/github.com/openai/openai-go/beta.go b/vendor/github.com/openai/openai-go/beta.go new file mode 100644 index 0000000000..79fb960a92 --- /dev/null +++ b/vendor/github.com/openai/openai-go/beta.go @@ -0,0 +1,31 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// BetaService contains methods and other services that help with interacting with +// the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBetaService] method instead. +type BetaService struct { + Options []option.RequestOption + Assistants BetaAssistantService + // Deprecated: The Assistants API is deprecated in favor of the Responses API + Threads BetaThreadService +} + +// NewBetaService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewBetaService(opts ...option.RequestOption) (r BetaService) { + r = BetaService{} + r.Options = opts + r.Assistants = NewBetaAssistantService(opts...) + r.Threads = NewBetaThreadService(opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/betaassistant.go b/vendor/github.com/openai/openai-go/betaassistant.go new file mode 100644 index 0000000000..69cb0fa8d5 --- /dev/null +++ b/vendor/github.com/openai/openai-go/betaassistant.go @@ -0,0 +1,2246 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// BetaAssistantService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBetaAssistantService] method instead. +type BetaAssistantService struct { + Options []option.RequestOption +} + +// NewBetaAssistantService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewBetaAssistantService(opts ...option.RequestOption) (r BetaAssistantService) { + r = BetaAssistantService{} + r.Options = opts + return +} + +// Create an assistant with a model and instructions. +func (r *BetaAssistantService) New(ctx context.Context, body BetaAssistantNewParams, opts ...option.RequestOption) (res *Assistant, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + path := "assistants" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Retrieves an assistant. +func (r *BetaAssistantService) Get(ctx context.Context, assistantID string, opts ...option.RequestOption) (res *Assistant, err error) { + opts = append(r.Options[:], opts...) 
+ opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if assistantID == "" { + err = errors.New("missing required assistant_id parameter") + return + } + path := fmt.Sprintf("assistants/%s", assistantID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Modifies an assistant. +func (r *BetaAssistantService) Update(ctx context.Context, assistantID string, body BetaAssistantUpdateParams, opts ...option.RequestOption) (res *Assistant, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if assistantID == "" { + err = errors.New("missing required assistant_id parameter") + return + } + path := fmt.Sprintf("assistants/%s", assistantID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Returns a list of assistants. +func (r *BetaAssistantService) List(ctx context.Context, query BetaAssistantListParams, opts ...option.RequestOption) (res *pagination.CursorPage[Assistant], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithResponseInto(&raw)}, opts...) + path := "assistants" + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Returns a list of assistants. +func (r *BetaAssistantService) ListAutoPaging(ctx context.Context, query BetaAssistantListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[Assistant] { + return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...)) +} + +// Delete an assistant. 
+func (r *BetaAssistantService) Delete(ctx context.Context, assistantID string, opts ...option.RequestOption) (res *AssistantDeleted, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if assistantID == "" { + err = errors.New("missing required assistant_id parameter") + return + } + path := fmt.Sprintf("assistants/%s", assistantID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...) + return +} + +// Represents an `assistant` that can call the model and use tools. +type Assistant struct { + // The identifier, which can be referenced in API endpoints. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the assistant was created. + CreatedAt int64 `json:"created_at,required"` + // The description of the assistant. The maximum length is 512 characters. + Description string `json:"description,required"` + // The system instructions that the assistant uses. The maximum length is 256,000 + // characters. + Instructions string `json:"instructions,required"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,required"` + // ID of the model to use. You can use the + // [List models](https://platform.openai.com/docs/api-reference/models/list) API to + // see all of your available models, or see our + // [Model overview](https://platform.openai.com/docs/models) for descriptions of + // them. + Model string `json:"model,required"` + // The name of the assistant. The maximum length is 256 characters. 
+ Name string `json:"name,required"` + // The object type, which is always `assistant`. + Object constant.Assistant `json:"object,required"` + // A list of tool enabled on the assistant. There can be a maximum of 128 tools per + // assistant. Tools can be of types `code_interpreter`, `file_search`, or + // `function`. + Tools []AssistantToolUnion `json:"tools,required"` + // Specifies the format that the model must output. Compatible with + // [GPT-4o](https://platform.openai.com/docs/models#gpt-4o), + // [GPT-4 Turbo](https://platform.openai.com/docs/models#gpt-4-turbo-and-gpt-4), + // and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + // + // Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured + // Outputs which ensures the model will match your supplied JSON schema. Learn more + // in the + // [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + // + // Setting to `{ "type": "json_object" }` enables JSON mode, which ensures the + // message the model generates is valid JSON. + // + // **Important:** when using JSON mode, you **must** also instruct the model to + // produce JSON yourself via a system or user message. Without this, the model may + // generate an unending stream of whitespace until the generation reaches the token + // limit, resulting in a long-running and seemingly "stuck" request. Also note that + // the message content may be partially cut off if `finish_reason="length"`, which + // indicates the generation exceeded `max_tokens` or the conversation exceeded the + // max context length. + ResponseFormat AssistantResponseFormatOptionUnion `json:"response_format,nullable"` + // What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + // make the output more random, while lower values like 0.2 will make it more + // focused and deterministic. 
+ Temperature float64 `json:"temperature,nullable"` + // A set of resources that are used by the assistant's tools. The resources are + // specific to the type of tool. For example, the `code_interpreter` tool requires + // a list of file IDs, while the `file_search` tool requires a list of vector store + // IDs. + ToolResources AssistantToolResources `json:"tool_resources,nullable"` + // An alternative to sampling with temperature, called nucleus sampling, where the + // model considers the results of the tokens with top_p probability mass. So 0.1 + // means only the tokens comprising the top 10% probability mass are considered. + // + // We generally recommend altering this or temperature but not both. + TopP float64 `json:"top_p,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Description respjson.Field + Instructions respjson.Field + Metadata respjson.Field + Model respjson.Field + Name respjson.Field + Object respjson.Field + Tools respjson.Field + ResponseFormat respjson.Field + Temperature respjson.Field + ToolResources respjson.Field + TopP respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Assistant) RawJSON() string { return r.JSON.raw } +func (r *Assistant) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A set of resources that are used by the assistant's tools. The resources are +// specific to the type of tool. For example, the `code_interpreter` tool requires +// a list of file IDs, while the `file_search` tool requires a list of vector store +// IDs. 
+type AssistantToolResources struct {
+ CodeInterpreter AssistantToolResourcesCodeInterpreter `json:"code_interpreter"`
+ FileSearch AssistantToolResourcesFileSearch `json:"file_search"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ CodeInterpreter respjson.Field
+ FileSearch respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r AssistantToolResources) RawJSON() string { return r.JSON.raw }
+func (r *AssistantToolResources) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type AssistantToolResourcesCodeInterpreter struct {
+ // A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
+ // available to the `code_interpreter` tool. There can be a maximum of 20 files
+ // associated with the tool.
+ FileIDs []string `json:"file_ids"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ FileIDs respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r AssistantToolResourcesCodeInterpreter) RawJSON() string { return r.JSON.raw }
+func (r *AssistantToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type AssistantToolResourcesFileSearch struct {
+ // The ID of the
+ // [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
+ // attached to this assistant. There can be a maximum of 1 vector store attached to
+ // the assistant.
+ VectorStoreIDs []string `json:"vector_store_ids"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct { + VectorStoreIDs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AssistantToolResourcesFileSearch) RawJSON() string { return r.JSON.raw } +func (r *AssistantToolResourcesFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type AssistantDeleted struct { + ID string `json:"id,required"` + Deleted bool `json:"deleted,required"` + Object constant.AssistantDeleted `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Deleted respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AssistantDeleted) RawJSON() string { return r.JSON.raw } +func (r *AssistantDeleted) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// AssistantStreamEventUnion contains all possible properties and values from +// [AssistantStreamEventThreadCreated], [AssistantStreamEventThreadRunCreated], +// [AssistantStreamEventThreadRunQueued], +// [AssistantStreamEventThreadRunInProgress], +// [AssistantStreamEventThreadRunRequiresAction], +// [AssistantStreamEventThreadRunCompleted], +// [AssistantStreamEventThreadRunIncomplete], +// [AssistantStreamEventThreadRunFailed], +// [AssistantStreamEventThreadRunCancelling], +// [AssistantStreamEventThreadRunCancelled], +// [AssistantStreamEventThreadRunExpired], +// [AssistantStreamEventThreadRunStepCreated], +// [AssistantStreamEventThreadRunStepInProgress], +// [AssistantStreamEventThreadRunStepDelta], +// [AssistantStreamEventThreadRunStepCompleted], +// [AssistantStreamEventThreadRunStepFailed], +// [AssistantStreamEventThreadRunStepCancelled], +// [AssistantStreamEventThreadRunStepExpired], +// [AssistantStreamEventThreadMessageCreated], +// 
[AssistantStreamEventThreadMessageInProgress], +// [AssistantStreamEventThreadMessageDelta], +// [AssistantStreamEventThreadMessageCompleted], +// [AssistantStreamEventThreadMessageIncomplete], [AssistantStreamEventErrorEvent]. +// +// Use the [AssistantStreamEventUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type AssistantStreamEventUnion struct { + // This field is a union of [Thread], [Run], [RunStep], [RunStepDeltaEvent], + // [Message], [MessageDeltaEvent], [shared.ErrorObject] + Data AssistantStreamEventUnionData `json:"data"` + // Any of "thread.created", "thread.run.created", "thread.run.queued", + // "thread.run.in_progress", "thread.run.requires_action", "thread.run.completed", + // "thread.run.incomplete", "thread.run.failed", "thread.run.cancelling", + // "thread.run.cancelled", "thread.run.expired", "thread.run.step.created", + // "thread.run.step.in_progress", "thread.run.step.delta", + // "thread.run.step.completed", "thread.run.step.failed", + // "thread.run.step.cancelled", "thread.run.step.expired", + // "thread.message.created", "thread.message.in_progress", "thread.message.delta", + // "thread.message.completed", "thread.message.incomplete", "error". + Event string `json:"event"` + // This field is from variant [AssistantStreamEventThreadCreated]. 
+ Enabled bool `json:"enabled"` + JSON struct { + Data respjson.Field + Event respjson.Field + Enabled respjson.Field + raw string + } `json:"-"` +} + +// anyAssistantStreamEvent is implemented by each variant of +// [AssistantStreamEventUnion] to add type safety for the return type of +// [AssistantStreamEventUnion.AsAny] +type anyAssistantStreamEvent interface { + implAssistantStreamEventUnion() +} + +func (AssistantStreamEventThreadCreated) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunCreated) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunQueued) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunInProgress) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunRequiresAction) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunCompleted) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunIncomplete) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunFailed) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunCancelling) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunCancelled) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunExpired) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepCreated) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepInProgress) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepDelta) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepCompleted) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepFailed) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepCancelled) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadRunStepExpired) implAssistantStreamEventUnion() {} +func (AssistantStreamEventThreadMessageCreated) implAssistantStreamEventUnion() {} +func 
(AssistantStreamEventThreadMessageInProgress) implAssistantStreamEventUnion() {}
func (AssistantStreamEventThreadMessageDelta) implAssistantStreamEventUnion() {}
func (AssistantStreamEventThreadMessageCompleted) implAssistantStreamEventUnion() {}
func (AssistantStreamEventThreadMessageIncomplete) implAssistantStreamEventUnion() {}
func (AssistantStreamEventErrorEvent) implAssistantStreamEventUnion() {}

// Use the following switch statement to find the correct variant
//
//	switch variant := AssistantStreamEventUnion.AsAny().(type) {
//	case openai.AssistantStreamEventThreadCreated:
//	case openai.AssistantStreamEventThreadRunCreated:
//	case openai.AssistantStreamEventThreadRunQueued:
//	case openai.AssistantStreamEventThreadRunInProgress:
//	case openai.AssistantStreamEventThreadRunRequiresAction:
//	case openai.AssistantStreamEventThreadRunCompleted:
//	case openai.AssistantStreamEventThreadRunIncomplete:
//	case openai.AssistantStreamEventThreadRunFailed:
//	case openai.AssistantStreamEventThreadRunCancelling:
//	case openai.AssistantStreamEventThreadRunCancelled:
//	case openai.AssistantStreamEventThreadRunExpired:
//	case openai.AssistantStreamEventThreadRunStepCreated:
//	case openai.AssistantStreamEventThreadRunStepInProgress:
//	case openai.AssistantStreamEventThreadRunStepDelta:
//	case openai.AssistantStreamEventThreadRunStepCompleted:
//	case openai.AssistantStreamEventThreadRunStepFailed:
//	case openai.AssistantStreamEventThreadRunStepCancelled:
//	case openai.AssistantStreamEventThreadRunStepExpired:
//	case openai.AssistantStreamEventThreadMessageCreated:
//	case openai.AssistantStreamEventThreadMessageInProgress:
//	case openai.AssistantStreamEventThreadMessageDelta:
//	case openai.AssistantStreamEventThreadMessageCompleted:
//	case openai.AssistantStreamEventThreadMessageIncomplete:
//	case openai.AssistantStreamEventErrorEvent:
//	default:
//	  fmt.Errorf("no variant present")
//	}
func (u AssistantStreamEventUnion) AsAny() anyAssistantStreamEvent {
	// Dispatch on the wire-level event name; each case re-decodes the union's
	// raw JSON payload into the corresponding strongly typed variant.
	switch u.Event {
	case "thread.created":
		return u.AsThreadCreated()
	case "thread.run.created":
		return u.AsThreadRunCreated()
	case "thread.run.queued":
		return u.AsThreadRunQueued()
	case "thread.run.in_progress":
		return u.AsThreadRunInProgress()
	case "thread.run.requires_action":
		return u.AsThreadRunRequiresAction()
	case "thread.run.completed":
		return u.AsThreadRunCompleted()
	case "thread.run.incomplete":
		return u.AsThreadRunIncomplete()
	case "thread.run.failed":
		return u.AsThreadRunFailed()
	case "thread.run.cancelling":
		return u.AsThreadRunCancelling()
	case "thread.run.cancelled":
		return u.AsThreadRunCancelled()
	case "thread.run.expired":
		return u.AsThreadRunExpired()
	case "thread.run.step.created":
		return u.AsThreadRunStepCreated()
	case "thread.run.step.in_progress":
		return u.AsThreadRunStepInProgress()
	case "thread.run.step.delta":
		return u.AsThreadRunStepDelta()
	case "thread.run.step.completed":
		return u.AsThreadRunStepCompleted()
	case "thread.run.step.failed":
		return u.AsThreadRunStepFailed()
	case "thread.run.step.cancelled":
		return u.AsThreadRunStepCancelled()
	case "thread.run.step.expired":
		return u.AsThreadRunStepExpired()
	case "thread.message.created":
		return u.AsThreadMessageCreated()
	case "thread.message.in_progress":
		return u.AsThreadMessageInProgress()
	case "thread.message.delta":
		return u.AsThreadMessageDelta()
	case "thread.message.completed":
		return u.AsThreadMessageCompleted()
	case "thread.message.incomplete":
		return u.AsThreadMessageIncomplete()
	case "error":
		return u.AsErrorEvent()
	}
	// Unknown or unset event name: no variant is present.
	return nil
}

// AsThreadCreated decodes the union's raw JSON into the "thread.created" variant.
// NOTE(review): the UnmarshalRoot error is deliberately discarded here, as in
// every accessor below; on failure v is returned partially zero-valued.
func (u AssistantStreamEventUnion) AsThreadCreated() (v AssistantStreamEventThreadCreated) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunCreated decodes the union's raw JSON into the "thread.run.created" variant.
func (u AssistantStreamEventUnion) AsThreadRunCreated() (v AssistantStreamEventThreadRunCreated) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunQueued decodes the union's raw JSON into the "thread.run.queued" variant.
func (u AssistantStreamEventUnion) AsThreadRunQueued() (v AssistantStreamEventThreadRunQueued) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunInProgress decodes the union's raw JSON into the "thread.run.in_progress" variant.
func (u AssistantStreamEventUnion) AsThreadRunInProgress() (v AssistantStreamEventThreadRunInProgress) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunRequiresAction decodes the union's raw JSON into the "thread.run.requires_action" variant.
func (u AssistantStreamEventUnion) AsThreadRunRequiresAction() (v AssistantStreamEventThreadRunRequiresAction) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunCompleted decodes the union's raw JSON into the "thread.run.completed" variant.
func (u AssistantStreamEventUnion) AsThreadRunCompleted() (v AssistantStreamEventThreadRunCompleted) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunIncomplete decodes the union's raw JSON into the "thread.run.incomplete" variant.
func (u AssistantStreamEventUnion) AsThreadRunIncomplete() (v AssistantStreamEventThreadRunIncomplete) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunFailed decodes the union's raw JSON into the "thread.run.failed" variant.
func (u AssistantStreamEventUnion) AsThreadRunFailed() (v AssistantStreamEventThreadRunFailed) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunCancelling decodes the union's raw JSON into the "thread.run.cancelling" variant.
func (u AssistantStreamEventUnion) AsThreadRunCancelling() (v AssistantStreamEventThreadRunCancelling) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunCancelled decodes the union's raw JSON into the "thread.run.cancelled" variant.
func (u AssistantStreamEventUnion) AsThreadRunCancelled() (v AssistantStreamEventThreadRunCancelled) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunExpired decodes the union's raw JSON into the "thread.run.expired" variant.
func (u AssistantStreamEventUnion) AsThreadRunExpired() (v AssistantStreamEventThreadRunExpired) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepCreated decodes the union's raw JSON into the "thread.run.step.created" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepCreated() (v AssistantStreamEventThreadRunStepCreated) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepInProgress decodes the union's raw JSON into the "thread.run.step.in_progress" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepInProgress() (v AssistantStreamEventThreadRunStepInProgress) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepDelta decodes the union's raw JSON into the "thread.run.step.delta" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepDelta() (v AssistantStreamEventThreadRunStepDelta) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepCompleted decodes the union's raw JSON into the "thread.run.step.completed" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepCompleted() (v AssistantStreamEventThreadRunStepCompleted) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepFailed decodes the union's raw JSON into the "thread.run.step.failed" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepFailed() (v AssistantStreamEventThreadRunStepFailed) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepCancelled decodes the union's raw JSON into the "thread.run.step.cancelled" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepCancelled() (v AssistantStreamEventThreadRunStepCancelled) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadRunStepExpired decodes the union's raw JSON into the "thread.run.step.expired" variant.
func (u AssistantStreamEventUnion) AsThreadRunStepExpired() (v AssistantStreamEventThreadRunStepExpired) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadMessageCreated decodes the union's raw JSON into the "thread.message.created" variant.
func (u AssistantStreamEventUnion) AsThreadMessageCreated() (v AssistantStreamEventThreadMessageCreated) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadMessageInProgress decodes the union's raw JSON into the "thread.message.in_progress" variant.
func (u AssistantStreamEventUnion) AsThreadMessageInProgress() (v AssistantStreamEventThreadMessageInProgress) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadMessageDelta decodes the union's raw JSON into the "thread.message.delta" variant.
func (u AssistantStreamEventUnion) AsThreadMessageDelta() (v AssistantStreamEventThreadMessageDelta) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadMessageCompleted decodes the union's raw JSON into the "thread.message.completed" variant.
func (u AssistantStreamEventUnion) AsThreadMessageCompleted() (v AssistantStreamEventThreadMessageCompleted) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsThreadMessageIncomplete decodes the union's raw JSON into the "thread.message.incomplete" variant.
func (u AssistantStreamEventUnion) AsThreadMessageIncomplete() (v AssistantStreamEventThreadMessageIncomplete) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// AsErrorEvent decodes the union's raw JSON into the "error" variant.
func (u AssistantStreamEventUnion) AsErrorEvent() (v AssistantStreamEventErrorEvent) {
	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
	return
}

// Returns the
unmodified JSON received from the API
func (u AssistantStreamEventUnion) RawJSON() string { return u.JSON.raw }

func (r *AssistantStreamEventUnion) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// AssistantStreamEventUnionData is an implicit subunion of
// [AssistantStreamEventUnion]. AssistantStreamEventUnionData provides convenient
// access to the sub-properties of the union.
//
// For type safety it is recommended to directly use a variant of the
// [AssistantStreamEventUnion].
type AssistantStreamEventUnionData struct {
	ID        string `json:"id"`
	CreatedAt int64  `json:"created_at"`
	// This field is from variant [Thread].
	Metadata shared.Metadata `json:"metadata"`
	Object   string          `json:"object"`
	// This field is from variant [Thread].
	ToolResources ThreadToolResources `json:"tool_resources"`
	AssistantID   string              `json:"assistant_id"`
	CancelledAt   int64               `json:"cancelled_at"`
	CompletedAt   int64               `json:"completed_at"`
	// This field is from variant [Run].
	ExpiresAt int64 `json:"expires_at"`
	FailedAt  int64 `json:"failed_at"`
	// This field is a union of [RunIncompleteDetails], [MessageIncompleteDetails]
	IncompleteDetails AssistantStreamEventUnionDataIncompleteDetails `json:"incomplete_details"`
	// This field is from variant [Run].
	Instructions string `json:"instructions"`
	// This field is a union of [RunLastError], [RunStepLastError]
	LastError AssistantStreamEventUnionDataLastError `json:"last_error"`
	// This field is from variant [Run].
	MaxCompletionTokens int64 `json:"max_completion_tokens"`
	// This field is from variant [Run].
	MaxPromptTokens int64 `json:"max_prompt_tokens"`
	// This field is from variant [Run].
	Model string `json:"model"`
	// This field is from variant [Run].
	ParallelToolCalls bool `json:"parallel_tool_calls"`
	// This field is from variant [Run].
	RequiredAction RunRequiredAction `json:"required_action"`
	// This field is from variant [Run].
	ResponseFormat AssistantResponseFormatOptionUnion `json:"response_format"`
	// This field is from variant [Run].
	StartedAt int64  `json:"started_at"`
	Status    string `json:"status"`
	ThreadID  string `json:"thread_id"`
	// This field is from variant [Run].
	ToolChoice AssistantToolChoiceOptionUnion `json:"tool_choice"`
	// This field is from variant [Run].
	Tools []AssistantToolUnion `json:"tools"`
	// This field is from variant [Run].
	TruncationStrategy RunTruncationStrategy `json:"truncation_strategy"`
	// This field is a union of [RunUsage], [RunStepUsage]
	Usage AssistantStreamEventUnionDataUsage `json:"usage"`
	// This field is from variant [Run].
	Temperature float64 `json:"temperature"`
	// This field is from variant [Run].
	TopP float64 `json:"top_p"`
	// This field is from variant [RunStep].
	ExpiredAt int64  `json:"expired_at"`
	RunID     string `json:"run_id"`
	// This field is from variant [RunStep].
	StepDetails RunStepStepDetailsUnion `json:"step_details"`
	Type        string                  `json:"type"`
	// This field is a union of [RunStepDelta], [MessageDelta]
	Delta AssistantStreamEventUnionDataDelta `json:"delta"`
	// This field is from variant [Message].
	Attachments []MessageAttachment `json:"attachments"`
	// This field is from variant [Message].
	Content []MessageContentUnion `json:"content"`
	// This field is from variant [Message].
	IncompleteAt int64 `json:"incomplete_at"`
	// This field is from variant [Message].
	Role MessageRole `json:"role"`
	// This field is from variant [shared.ErrorObject].
	Code string `json:"code"`
	// This field is from variant [shared.ErrorObject].
	Message string `json:"message"`
	// This field is from variant [shared.ErrorObject].
	Param string `json:"param"`
	// JSON holds per-field decode metadata (one respjson.Field per property
	// above) plus the raw payload bytes; it is excluded from re-serialization.
	JSON struct {
		ID                  respjson.Field
		CreatedAt           respjson.Field
		Metadata            respjson.Field
		Object              respjson.Field
		ToolResources       respjson.Field
		AssistantID         respjson.Field
		CancelledAt         respjson.Field
		CompletedAt         respjson.Field
		ExpiresAt           respjson.Field
		FailedAt            respjson.Field
		IncompleteDetails   respjson.Field
		Instructions        respjson.Field
		LastError           respjson.Field
		MaxCompletionTokens respjson.Field
		MaxPromptTokens     respjson.Field
		Model               respjson.Field
		ParallelToolCalls   respjson.Field
		RequiredAction      respjson.Field
		ResponseFormat      respjson.Field
		StartedAt           respjson.Field
		Status              respjson.Field
		ThreadID            respjson.Field
		ToolChoice          respjson.Field
		Tools               respjson.Field
		TruncationStrategy  respjson.Field
		Usage               respjson.Field
		Temperature         respjson.Field
		TopP                respjson.Field
		ExpiredAt           respjson.Field
		RunID               respjson.Field
		StepDetails         respjson.Field
		Type                respjson.Field
		Delta               respjson.Field
		Attachments         respjson.Field
		Content             respjson.Field
		IncompleteAt        respjson.Field
		Role                respjson.Field
		Code                respjson.Field
		Message             respjson.Field
		Param               respjson.Field
		raw                 string
	} `json:"-"`
}

func (r *AssistantStreamEventUnionData) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// AssistantStreamEventUnionDataIncompleteDetails is an implicit subunion of
// [AssistantStreamEventUnion]. AssistantStreamEventUnionDataIncompleteDetails
// provides convenient access to the sub-properties of the union.
//
// For type safety it is recommended to directly use a variant of the
// [AssistantStreamEventUnion].
type AssistantStreamEventUnionDataIncompleteDetails struct {
	Reason string `json:"reason"`
	// JSON holds per-field decode metadata plus the raw payload bytes.
	JSON struct {
		Reason respjson.Field
		raw    string
	} `json:"-"`
}

func (r *AssistantStreamEventUnionDataIncompleteDetails) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// AssistantStreamEventUnionDataLastError is an implicit subunion of
// [AssistantStreamEventUnion]. AssistantStreamEventUnionDataLastError provides
// convenient access to the sub-properties of the union.
//
// For type safety it is recommended to directly use a variant of the
// [AssistantStreamEventUnion].
type AssistantStreamEventUnionDataLastError struct {
	Code    string `json:"code"`
	Message string `json:"message"`
	// JSON holds per-field decode metadata plus the raw payload bytes.
	JSON struct {
		Code    respjson.Field
		Message respjson.Field
		raw     string
	} `json:"-"`
}

func (r *AssistantStreamEventUnionDataLastError) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// AssistantStreamEventUnionDataUsage is an implicit subunion of
// [AssistantStreamEventUnion]. AssistantStreamEventUnionDataUsage provides
// convenient access to the sub-properties of the union.
//
// For type safety it is recommended to directly use a variant of the
// [AssistantStreamEventUnion].
type AssistantStreamEventUnionDataUsage struct {
	CompletionTokens int64 `json:"completion_tokens"`
	PromptTokens     int64 `json:"prompt_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
	// JSON holds per-field decode metadata plus the raw payload bytes.
	JSON struct {
		CompletionTokens respjson.Field
		PromptTokens     respjson.Field
		TotalTokens      respjson.Field
		raw              string
	} `json:"-"`
}

func (r *AssistantStreamEventUnionDataUsage) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// AssistantStreamEventUnionDataDelta is an implicit subunion of
// [AssistantStreamEventUnion]. AssistantStreamEventUnionDataDelta provides
// convenient access to the sub-properties of the union.
//
// For type safety it is recommended to directly use a variant of the
// [AssistantStreamEventUnion].
type AssistantStreamEventUnionDataDelta struct {
	// This field is from variant [RunStepDelta].
	StepDetails RunStepDeltaStepDetailsUnion `json:"step_details"`
	// This field is from variant [MessageDelta].
	Content []MessageContentDeltaUnion `json:"content"`
	// This field is from variant [MessageDelta].
	Role MessageDeltaRole `json:"role"`
	// JSON holds per-field decode metadata plus the raw payload bytes.
	JSON struct {
		StepDetails respjson.Field
		Content     respjson.Field
		Role        respjson.Field
		raw         string
	} `json:"-"`
}

func (r *AssistantStreamEventUnionDataDelta) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a new
// [thread](https://platform.openai.com/docs/api-reference/threads/object) is
// created.
type AssistantStreamEventThreadCreated struct {
	// Represents a thread that contains
	// [messages](https://platform.openai.com/docs/api-reference/messages).
	Data  Thread                 `json:"data,required"`
	Event constant.ThreadCreated `json:"event,required"`
	// Whether to enable input audio transcription.
	Enabled bool `json:"enabled"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		Enabled     respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadCreated) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadCreated) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a new
// [run](https://platform.openai.com/docs/api-reference/runs/object) is created.
type AssistantStreamEventThreadRunCreated struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                      `json:"data,required"`
	Event constant.ThreadRunCreated `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunCreated) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunCreated) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// moves to a `queued` status.
type AssistantStreamEventThreadRunQueued struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                      `json:"data,required"`
	Event constant.ThreadRunQueued `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunQueued) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunQueued) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// moves to an `in_progress` status.
type AssistantStreamEventThreadRunInProgress struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                          `json:"data,required"`
	Event constant.ThreadRunInProgress `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunInProgress) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunInProgress) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// moves to a `requires_action` status.
type AssistantStreamEventThreadRunRequiresAction struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                              `json:"data,required"`
	Event constant.ThreadRunRequiresAction `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunRequiresAction) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunRequiresAction) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// is completed.
type AssistantStreamEventThreadRunCompleted struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                         `json:"data,required"`
	Event constant.ThreadRunCompleted `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunCompleted) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunCompleted) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// ends with status `incomplete`.
type AssistantStreamEventThreadRunIncomplete struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                          `json:"data,required"`
	Event constant.ThreadRunIncomplete `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunIncomplete) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunIncomplete) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// fails.
type AssistantStreamEventThreadRunFailed struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                      `json:"data,required"`
	Event constant.ThreadRunFailed `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunFailed) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunFailed) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// moves to a `cancelling` status.
type AssistantStreamEventThreadRunCancelling struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                          `json:"data,required"`
	Event constant.ThreadRunCancelling `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunCancelling) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunCancelling) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// is cancelled.
type AssistantStreamEventThreadRunCancelled struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                         `json:"data,required"`
	Event constant.ThreadRunCancelled `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunCancelled) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunCancelled) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object)
// expires.
type AssistantStreamEventThreadRunExpired struct {
	// Represents an execution run on a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Run                       `json:"data,required"`
	Event constant.ThreadRunExpired `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunExpired) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunExpired) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// is created.
type AssistantStreamEventThreadRunStepCreated struct {
	// Represents a step in execution of a run.
	Data  RunStep                       `json:"data,required"`
	Event constant.ThreadRunStepCreated `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepCreated) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepCreated) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// moves to an `in_progress` state.
type AssistantStreamEventThreadRunStepInProgress struct {
	// Represents a step in execution of a run.
	Data  RunStep                          `json:"data,required"`
	Event constant.ThreadRunStepInProgress `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepInProgress) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepInProgress) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when parts of a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// are being streamed.
type AssistantStreamEventThreadRunStepDelta struct {
	// Represents a run step delta i.e. any changed fields on a run step during
	// streaming.
	Data  RunStepDeltaEvent           `json:"data,required"`
	Event constant.ThreadRunStepDelta `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepDelta) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepDelta) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// is completed.
type AssistantStreamEventThreadRunStepCompleted struct {
	// Represents a step in execution of a run.
	Data  RunStep                         `json:"data,required"`
	Event constant.ThreadRunStepCompleted `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepCompleted) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepCompleted) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// fails.
type AssistantStreamEventThreadRunStepFailed struct {
	// Represents a step in execution of a run.
	Data  RunStep                      `json:"data,required"`
	Event constant.ThreadRunStepFailed `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepFailed) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepFailed) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// is cancelled.
type AssistantStreamEventThreadRunStepCancelled struct {
	// Represents a step in execution of a run.
	Data  RunStep                         `json:"data,required"`
	Event constant.ThreadRunStepCancelled `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepCancelled) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepCancelled) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object)
// expires.
type AssistantStreamEventThreadRunStepExpired struct {
	// Represents a step in execution of a run.
	Data  RunStep                       `json:"data,required"`
	Event constant.ThreadRunStepExpired `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadRunStepExpired) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadRunStepExpired) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [message](https://platform.openai.com/docs/api-reference/messages/object) is
// created.
type AssistantStreamEventThreadMessageCreated struct {
	// Represents a message within a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Message                       `json:"data,required"`
	Event constant.ThreadMessageCreated `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadMessageCreated) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadMessageCreated) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [message](https://platform.openai.com/docs/api-reference/messages/object) moves
// to an `in_progress` state.
type AssistantStreamEventThreadMessageInProgress struct {
	// Represents a message within a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Message                          `json:"data,required"`
	Event constant.ThreadMessageInProgress `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadMessageInProgress) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadMessageInProgress) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when parts of a
// [Message](https://platform.openai.com/docs/api-reference/messages/object) are
// being streamed.
type AssistantStreamEventThreadMessageDelta struct {
	// Represents a message delta i.e. any changed fields on a message during
	// streaming.
	Data  MessageDeltaEvent           `json:"data,required"`
	Event constant.ThreadMessageDelta `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadMessageDelta) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadMessageDelta) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [message](https://platform.openai.com/docs/api-reference/messages/object) is
// completed.
type AssistantStreamEventThreadMessageCompleted struct {
	// Represents a message within a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Message                         `json:"data,required"`
	Event constant.ThreadMessageCompleted `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadMessageCompleted) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadMessageCompleted) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when a
// [message](https://platform.openai.com/docs/api-reference/messages/object) ends
// before it is completed.
type AssistantStreamEventThreadMessageIncomplete struct {
	// Represents a message within a
	// [thread](https://platform.openai.com/docs/api-reference/threads).
	Data  Message                          `json:"data,required"`
	Event constant.ThreadMessageIncomplete `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventThreadMessageIncomplete) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventThreadMessageIncomplete) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Occurs when an
// [error](https://platform.openai.com/docs/guides/error-codes#api-errors) occurs.
// This can happen due to an internal server error or a timeout.
type AssistantStreamEventErrorEvent struct {
	Data  shared.ErrorObject `json:"data,required"`
	Event constant.Error     `json:"event,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		Data        respjson.Field
		Event       respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r AssistantStreamEventErrorEvent) RawJSON() string { return r.JSON.raw }
func (r *AssistantStreamEventErrorEvent) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// AssistantToolUnion contains all possible properties and values from
// [CodeInterpreterTool], [FileSearchTool], [FunctionTool].
//
// Use the [AssistantToolUnion.AsAny] method to switch on the variant.
//
// Use the methods beginning with 'As' to cast the union to one of its variants.
type AssistantToolUnion struct {
	// Any of "code_interpreter", "file_search", "function".
	Type string `json:"type"`
	// This field is from variant [FileSearchTool].
	FileSearch FileSearchToolFileSearch `json:"file_search"`
	// This field is from variant [FunctionTool].
	Function shared.FunctionDefinition `json:"function"`
	// JSON holds per-field decode metadata plus the raw payload bytes.
	JSON struct {
		Type       respjson.Field
		FileSearch respjson.Field
		Function   respjson.Field
		raw        string
	} `json:"-"`
}

// anyAssistantTool is implemented by each variant of [AssistantToolUnion] to add
// type safety for the return type of [AssistantToolUnion.AsAny]
type anyAssistantTool interface {
	implAssistantToolUnion()
}

func (CodeInterpreterTool) implAssistantToolUnion() {}
func (FileSearchTool) implAssistantToolUnion()      {}
func (FunctionTool) implAssistantToolUnion()        {}

// Use the following switch statement to find the correct variant
//
//	switch variant := AssistantToolUnion.AsAny().(type) {
//	case openai.CodeInterpreterTool:
//	case openai.FileSearchTool:
//	case openai.FunctionTool:
//	default:
//	  fmt.Errorf("no variant present")
//	}
func (u AssistantToolUnion) AsAny() anyAssistantTool {
	// Dispatch on the "type" discriminator to the matching typed variant.
	switch u.Type {
	case "code_interpreter":
		return u.AsCodeInterpreter()
	case "file_search":
		return u.AsFileSearch()
	case "function":
+ return u.AsFunction() + } + return nil +} + +func (u AssistantToolUnion) AsCodeInterpreter() (v CodeInterpreterTool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AssistantToolUnion) AsFileSearch() (v FileSearchTool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AssistantToolUnion) AsFunction() (v FunctionTool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u AssistantToolUnion) RawJSON() string { return u.JSON.raw } + +func (r *AssistantToolUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this AssistantToolUnion to a AssistantToolUnionParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// AssistantToolUnionParam.Overrides() +func (r AssistantToolUnion) ToParam() AssistantToolUnionParam { + return param.Override[AssistantToolUnionParam](json.RawMessage(r.RawJSON())) +} + +func AssistantToolParamOfFunction(function shared.FunctionDefinitionParam) AssistantToolUnionParam { + var variant FunctionToolParam + variant.Function = function + return AssistantToolUnionParam{OfFunction: &variant} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type AssistantToolUnionParam struct { + OfCodeInterpreter *CodeInterpreterToolParam `json:",omitzero,inline"` + OfFileSearch *FileSearchToolParam `json:",omitzero,inline"` + OfFunction *FunctionToolParam `json:",omitzero,inline"` + paramUnion +} + +func (u AssistantToolUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch, u.OfFunction) +} +func (u *AssistantToolUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *AssistantToolUnionParam) asAny() any { + if !param.IsOmitted(u.OfCodeInterpreter) { + return u.OfCodeInterpreter + } else if !param.IsOmitted(u.OfFileSearch) { + return u.OfFileSearch + } else if !param.IsOmitted(u.OfFunction) { + return u.OfFunction + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u AssistantToolUnionParam) GetFileSearch() *FileSearchToolFileSearchParam { + if vt := u.OfFileSearch; vt != nil { + return &vt.FileSearch + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u AssistantToolUnionParam) GetFunction() *shared.FunctionDefinitionParam { + if vt := u.OfFunction; vt != nil { + return &vt.Function + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u AssistantToolUnionParam) GetType() *string { + if vt := u.OfCodeInterpreter; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFileSearch; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFunction; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[AssistantToolUnionParam]( + "type", + apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"), + apijson.Discriminator[FileSearchToolParam]("file_search"), + apijson.Discriminator[FunctionToolParam]("function"), + ) +} + +type CodeInterpreterTool struct { + // The type of tool being defined: `code_interpreter` + Type constant.CodeInterpreter `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterTool) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterTool) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this CodeInterpreterTool to a CodeInterpreterToolParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// CodeInterpreterToolParam.Overrides() +func (r CodeInterpreterTool) ToParam() CodeInterpreterToolParam { + return param.Override[CodeInterpreterToolParam](json.RawMessage(r.RawJSON())) +} + +func NewCodeInterpreterToolParam() CodeInterpreterToolParam { + return CodeInterpreterToolParam{ + Type: "code_interpreter", + } +} + +// This struct has a constant value, construct it with +// [NewCodeInterpreterToolParam]. 
+type CodeInterpreterToolParam struct { + // The type of tool being defined: `code_interpreter` + Type constant.CodeInterpreter `json:"type,required"` + paramObj +} + +func (r CodeInterpreterToolParam) MarshalJSON() (data []byte, err error) { + type shadow CodeInterpreterToolParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *CodeInterpreterToolParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FileSearchTool struct { + // The type of tool being defined: `file_search` + Type constant.FileSearch `json:"type,required"` + // Overrides for the file search tool. + FileSearch FileSearchToolFileSearch `json:"file_search"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + FileSearch respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchTool) RawJSON() string { return r.JSON.raw } +func (r *FileSearchTool) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this FileSearchTool to a FileSearchToolParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// FileSearchToolParam.Overrides() +func (r FileSearchTool) ToParam() FileSearchToolParam { + return param.Override[FileSearchToolParam](json.RawMessage(r.RawJSON())) +} + +// Overrides for the file search tool. +type FileSearchToolFileSearch struct { + // The maximum number of results the file search tool should output. The default is + // 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between + // 1 and 50 inclusive. + // + // Note that the file search tool may output fewer than `max_num_results` results. 
+ // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + MaxNumResults int64 `json:"max_num_results"` + // The ranking options for the file search. If not specified, the file search tool + // will use the `auto` ranker and a score_threshold of 0. + // + // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + RankingOptions FileSearchToolFileSearchRankingOptions `json:"ranking_options"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MaxNumResults respjson.Field + RankingOptions respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolFileSearch) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The ranking options for the file search. If not specified, the file search tool +// will use the `auto` ranker and a score_threshold of 0. +// +// See the +// [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) +// for more information. +type FileSearchToolFileSearchRankingOptions struct { + // The score threshold for the file search. All values must be a floating point + // number between 0 and 1. + ScoreThreshold float64 `json:"score_threshold,required"` + // The ranker to use for the file search. If not specified will use the `auto` + // ranker. + // + // Any of "auto", "default_2024_08_21". + Ranker string `json:"ranker"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ScoreThreshold respjson.Field + Ranker respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolFileSearchRankingOptions) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolFileSearchRankingOptions) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The property Type is required. +type FileSearchToolParam struct { + // Overrides for the file search tool. + FileSearch FileSearchToolFileSearchParam `json:"file_search,omitzero"` + // The type of tool being defined: `file_search` + // + // This field can be elided, and will marshal its zero value as "file_search". + Type constant.FileSearch `json:"type,required"` + paramObj +} + +func (r FileSearchToolParam) MarshalJSON() (data []byte, err error) { + type shadow FileSearchToolParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FileSearchToolParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Overrides for the file search tool. +type FileSearchToolFileSearchParam struct { + // The maximum number of results the file search tool should output. The default is + // 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between + // 1 and 50 inclusive. + // + // Note that the file search tool may output fewer than `max_num_results` results. + // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + MaxNumResults param.Opt[int64] `json:"max_num_results,omitzero"` + // The ranking options for the file search. If not specified, the file search tool + // will use the `auto` ranker and a score_threshold of 0. 
+ // + // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + RankingOptions FileSearchToolFileSearchRankingOptionsParam `json:"ranking_options,omitzero"` + paramObj +} + +func (r FileSearchToolFileSearchParam) MarshalJSON() (data []byte, err error) { + type shadow FileSearchToolFileSearchParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FileSearchToolFileSearchParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The ranking options for the file search. If not specified, the file search tool +// will use the `auto` ranker and a score_threshold of 0. +// +// See the +// [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) +// for more information. +// +// The property ScoreThreshold is required. +type FileSearchToolFileSearchRankingOptionsParam struct { + // The score threshold for the file search. All values must be a floating point + // number between 0 and 1. + ScoreThreshold float64 `json:"score_threshold,required"` + // The ranker to use for the file search. If not specified will use the `auto` + // ranker. + // + // Any of "auto", "default_2024_08_21". 
+ Ranker string `json:"ranker,omitzero"` + paramObj +} + +func (r FileSearchToolFileSearchRankingOptionsParam) MarshalJSON() (data []byte, err error) { + type shadow FileSearchToolFileSearchRankingOptionsParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FileSearchToolFileSearchRankingOptionsParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FileSearchToolFileSearchRankingOptionsParam]( + "ranker", "auto", "default_2024_08_21", + ) +} + +type FunctionTool struct { + Function shared.FunctionDefinition `json:"function,required"` + // The type of tool being defined: `function` + Type constant.Function `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Function respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FunctionTool) RawJSON() string { return r.JSON.raw } +func (r *FunctionTool) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this FunctionTool to a FunctionToolParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// FunctionToolParam.Overrides() +func (r FunctionTool) ToParam() FunctionToolParam { + return param.Override[FunctionToolParam](json.RawMessage(r.RawJSON())) +} + +// The properties Function, Type are required. +type FunctionToolParam struct { + Function shared.FunctionDefinitionParam `json:"function,omitzero,required"` + // The type of tool being defined: `function` + // + // This field can be elided, and will marshal its zero value as "function". 
+ Type constant.Function `json:"type,required"` + paramObj +} + +func (r FunctionToolParam) MarshalJSON() (data []byte, err error) { + type shadow FunctionToolParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FunctionToolParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantNewParams struct { + // ID of the model to use. You can use the + // [List models](https://platform.openai.com/docs/api-reference/models/list) API to + // see all of your available models, or see our + // [Model overview](https://platform.openai.com/docs/models) for descriptions of + // them. + Model shared.ChatModel `json:"model,omitzero,required"` + // The description of the assistant. The maximum length is 512 characters. + Description param.Opt[string] `json:"description,omitzero"` + // The system instructions that the assistant uses. The maximum length is 256,000 + // characters. + Instructions param.Opt[string] `json:"instructions,omitzero"` + // The name of the assistant. The maximum length is 256 characters. + Name param.Opt[string] `json:"name,omitzero"` + // What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + // make the output more random, while lower values like 0.2 will make it more + // focused and deterministic. + Temperature param.Opt[float64] `json:"temperature,omitzero"` + // An alternative to sampling with temperature, called nucleus sampling, where the + // model considers the results of the tokens with top_p probability mass. So 0.1 + // means only the tokens comprising the top 10% probability mass are considered. + // + // We generally recommend altering this or temperature but not both. + TopP param.Opt[float64] `json:"top_p,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. 
+ // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + // **o-series models only** + // + // Constrains effort on reasoning for + // [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently + // supported values are `low`, `medium`, and `high`. Reducing reasoning effort can + // result in faster responses and fewer tokens used on reasoning in a response. + // + // Any of "low", "medium", "high". + ReasoningEffort shared.ReasoningEffort `json:"reasoning_effort,omitzero"` + // A set of resources that are used by the assistant's tools. The resources are + // specific to the type of tool. For example, the `code_interpreter` tool requires + // a list of file IDs, while the `file_search` tool requires a list of vector store + // IDs. + ToolResources BetaAssistantNewParamsToolResources `json:"tool_resources,omitzero"` + // Specifies the format that the model must output. Compatible with + // [GPT-4o](https://platform.openai.com/docs/models#gpt-4o), + // [GPT-4 Turbo](https://platform.openai.com/docs/models#gpt-4-turbo-and-gpt-4), + // and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + // + // Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured + // Outputs which ensures the model will match your supplied JSON schema. Learn more + // in the + // [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + // + // Setting to `{ "type": "json_object" }` enables JSON mode, which ensures the + // message the model generates is valid JSON. + // + // **Important:** when using JSON mode, you **must** also instruct the model to + // produce JSON yourself via a system or user message. 
Without this, the model may + // generate an unending stream of whitespace until the generation reaches the token + // limit, resulting in a long-running and seemingly "stuck" request. Also note that + // the message content may be partially cut off if `finish_reason="length"`, which + // indicates the generation exceeded `max_tokens` or the conversation exceeded the + // max context length. + ResponseFormat AssistantResponseFormatOptionUnionParam `json:"response_format,omitzero"` + // A list of tool enabled on the assistant. There can be a maximum of 128 tools per + // assistant. Tools can be of types `code_interpreter`, `file_search`, or + // `function`. + Tools []AssistantToolUnionParam `json:"tools,omitzero"` + paramObj +} + +func (r BetaAssistantNewParams) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A set of resources that are used by the assistant's tools. The resources are +// specific to the type of tool. For example, the `code_interpreter` tool requires +// a list of file IDs, while the `file_search` tool requires a list of vector store +// IDs. 
+type BetaAssistantNewParamsToolResources struct { + CodeInterpreter BetaAssistantNewParamsToolResourcesCodeInterpreter `json:"code_interpreter,omitzero"` + FileSearch BetaAssistantNewParamsToolResourcesFileSearch `json:"file_search,omitzero"` + paramObj +} + +func (r BetaAssistantNewParamsToolResources) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResources + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResources) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantNewParamsToolResourcesCodeInterpreter struct { + // A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made + // available to the `code_interpreter` tool. There can be a maximum of 20 files + // associated with the tool. + FileIDs []string `json:"file_ids,omitzero"` + paramObj +} + +func (r BetaAssistantNewParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResourcesCodeInterpreter + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantNewParamsToolResourcesFileSearch struct { + // The + // [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) + // attached to this assistant. There can be a maximum of 1 vector store attached to + // the assistant. + VectorStoreIDs []string `json:"vector_store_ids,omitzero"` + // A helper to create a + // [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) + // with file_ids and attach it to this assistant. There can be a maximum of 1 + // vector store attached to the assistant. 
+ VectorStores []BetaAssistantNewParamsToolResourcesFileSearchVectorStore `json:"vector_stores,omitzero"` + paramObj +} + +func (r BetaAssistantNewParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResourcesFileSearch + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantNewParamsToolResourcesFileSearchVectorStore struct { + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + // The chunking strategy used to chunk the file(s). If not set, will use the `auto` + // strategy. + ChunkingStrategy BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion `json:"chunking_strategy,omitzero"` + // A list of [file](https://platform.openai.com/docs/api-reference/files) IDs to + // add to the vector store. There can be a maximum of 10000 files in a vector + // store. + FileIDs []string `json:"file_ids,omitzero"` + paramObj +} + +func (r BetaAssistantNewParamsToolResourcesFileSearchVectorStore) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResourcesFileSearchVectorStore + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResourcesFileSearchVectorStore) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion struct { + OfAuto *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto `json:",omitzero,inline"` + OfStatic *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic `json:",omitzero,inline"` + paramUnion +} + +func (u BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfStatic) +} +func (u *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return u.OfAuto + } else if !param.IsOmitted(u.OfStatic) { + return u.OfStatic + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) GetStatic() *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic { + if vt := u.OfStatic; vt != nil { + return &vt.Static + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) GetType() *string { + if vt := u.OfAuto; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfStatic; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion]( + "type", + apijson.Discriminator[BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto]("auto"), + apijson.Discriminator[BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic]("static"), + ) +} + +func NewBetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto() BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto { + return BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto{ + Type: "auto", + } +} + +// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of +// `800` and `chunk_overlap_tokens` of `400`. +// +// This struct has a constant value, construct it with +// [NewBetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto]. +type BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto struct { + // Always `auto`. + Type constant.Auto `json:"type,required"` + paramObj +} + +func (r BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties Static, Type are required. 
+type BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic struct { + Static BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic `json:"static,omitzero,required"` + // Always `static`. + // + // This field can be elided, and will marshal its zero value as "static". + Type constant.Static `json:"type,required"` + paramObj +} + +func (r BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties ChunkOverlapTokens, MaxChunkSizeTokens are required. +type BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic struct { + // The number of tokens that overlap between chunks. The default value is `400`. + // + // Note that the overlap must not exceed half of `max_chunk_size_tokens`. + ChunkOverlapTokens int64 `json:"chunk_overlap_tokens,required"` + // The maximum number of tokens in each chunk. The default value is `800`. The + // minimum value is `100` and the maximum value is `4096`. + MaxChunkSizeTokens int64 `json:"max_chunk_size_tokens,required"` + paramObj +} + +func (r BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantUpdateParams struct { + // The description of the assistant. 
The maximum length is 512 characters. + Description param.Opt[string] `json:"description,omitzero"` + // The system instructions that the assistant uses. The maximum length is 256,000 + // characters. + Instructions param.Opt[string] `json:"instructions,omitzero"` + // The name of the assistant. The maximum length is 256 characters. + Name param.Opt[string] `json:"name,omitzero"` + // What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + // make the output more random, while lower values like 0.2 will make it more + // focused and deterministic. + Temperature param.Opt[float64] `json:"temperature,omitzero"` + // An alternative to sampling with temperature, called nucleus sampling, where the + // model considers the results of the tokens with top_p probability mass. So 0.1 + // means only the tokens comprising the top 10% probability mass are considered. + // + // We generally recommend altering this or temperature but not both. + TopP param.Opt[float64] `json:"top_p,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + // **o-series models only** + // + // Constrains effort on reasoning for + // [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently + // supported values are `low`, `medium`, and `high`. Reducing reasoning effort can + // result in faster responses and fewer tokens used on reasoning in a response. + // + // Any of "low", "medium", "high". + ReasoningEffort shared.ReasoningEffort `json:"reasoning_effort,omitzero"` + // A set of resources that are used by the assistant's tools. The resources are + // specific to the type of tool. 
For example, the `code_interpreter` tool requires + // a list of file IDs, while the `file_search` tool requires a list of vector store + // IDs. + ToolResources BetaAssistantUpdateParamsToolResources `json:"tool_resources,omitzero"` + // ID of the model to use. You can use the + // [List models](https://platform.openai.com/docs/api-reference/models/list) API to + // see all of your available models, or see our + // [Model overview](https://platform.openai.com/docs/models) for descriptions of + // them. + Model BetaAssistantUpdateParamsModel `json:"model,omitzero"` + // Specifies the format that the model must output. Compatible with + // [GPT-4o](https://platform.openai.com/docs/models#gpt-4o), + // [GPT-4 Turbo](https://platform.openai.com/docs/models#gpt-4-turbo-and-gpt-4), + // and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + // + // Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured + // Outputs which ensures the model will match your supplied JSON schema. Learn more + // in the + // [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + // + // Setting to `{ "type": "json_object" }` enables JSON mode, which ensures the + // message the model generates is valid JSON. + // + // **Important:** when using JSON mode, you **must** also instruct the model to + // produce JSON yourself via a system or user message. Without this, the model may + // generate an unending stream of whitespace until the generation reaches the token + // limit, resulting in a long-running and seemingly "stuck" request. Also note that + // the message content may be partially cut off if `finish_reason="length"`, which + // indicates the generation exceeded `max_tokens` or the conversation exceeded the + // max context length. + ResponseFormat AssistantResponseFormatOptionUnionParam `json:"response_format,omitzero"` + // A list of tool enabled on the assistant. There can be a maximum of 128 tools per + // assistant. 
Tools can be of types `code_interpreter`, `file_search`, or + // `function`. + Tools []AssistantToolUnionParam `json:"tools,omitzero"` + paramObj +} + +func (r BetaAssistantUpdateParams) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantUpdateParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantUpdateParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ID of the model to use. You can use the +// [List models](https://platform.openai.com/docs/api-reference/models/list) API to +// see all of your available models, or see our +// [Model overview](https://platform.openai.com/docs/models) for descriptions of +// them. +type BetaAssistantUpdateParamsModel string + +const ( + BetaAssistantUpdateParamsModelGPT4_1 BetaAssistantUpdateParamsModel = "gpt-4.1" + BetaAssistantUpdateParamsModelGPT4_1Mini BetaAssistantUpdateParamsModel = "gpt-4.1-mini" + BetaAssistantUpdateParamsModelGPT4_1Nano BetaAssistantUpdateParamsModel = "gpt-4.1-nano" + BetaAssistantUpdateParamsModelGPT4_1_2025_04_14 BetaAssistantUpdateParamsModel = "gpt-4.1-2025-04-14" + BetaAssistantUpdateParamsModelGPT4_1Mini2025_04_14 BetaAssistantUpdateParamsModel = "gpt-4.1-mini-2025-04-14" + BetaAssistantUpdateParamsModelGPT4_1Nano2025_04_14 BetaAssistantUpdateParamsModel = "gpt-4.1-nano-2025-04-14" + BetaAssistantUpdateParamsModelO3Mini BetaAssistantUpdateParamsModel = "o3-mini" + BetaAssistantUpdateParamsModelO3Mini2025_01_31 BetaAssistantUpdateParamsModel = "o3-mini-2025-01-31" + BetaAssistantUpdateParamsModelO1 BetaAssistantUpdateParamsModel = "o1" + BetaAssistantUpdateParamsModelO1_2024_12_17 BetaAssistantUpdateParamsModel = "o1-2024-12-17" + BetaAssistantUpdateParamsModelGPT4o BetaAssistantUpdateParamsModel = "gpt-4o" + BetaAssistantUpdateParamsModelGPT4o2024_11_20 BetaAssistantUpdateParamsModel = "gpt-4o-2024-11-20" + BetaAssistantUpdateParamsModelGPT4o2024_08_06 BetaAssistantUpdateParamsModel = "gpt-4o-2024-08-06" + 
BetaAssistantUpdateParamsModelGPT4o2024_05_13 BetaAssistantUpdateParamsModel = "gpt-4o-2024-05-13" + BetaAssistantUpdateParamsModelGPT4oMini BetaAssistantUpdateParamsModel = "gpt-4o-mini" + BetaAssistantUpdateParamsModelGPT4oMini2024_07_18 BetaAssistantUpdateParamsModel = "gpt-4o-mini-2024-07-18" + BetaAssistantUpdateParamsModelGPT4_5Preview BetaAssistantUpdateParamsModel = "gpt-4.5-preview" + BetaAssistantUpdateParamsModelGPT4_5Preview2025_02_27 BetaAssistantUpdateParamsModel = "gpt-4.5-preview-2025-02-27" + BetaAssistantUpdateParamsModelGPT4Turbo BetaAssistantUpdateParamsModel = "gpt-4-turbo" + BetaAssistantUpdateParamsModelGPT4Turbo2024_04_09 BetaAssistantUpdateParamsModel = "gpt-4-turbo-2024-04-09" + BetaAssistantUpdateParamsModelGPT4_0125Preview BetaAssistantUpdateParamsModel = "gpt-4-0125-preview" + BetaAssistantUpdateParamsModelGPT4TurboPreview BetaAssistantUpdateParamsModel = "gpt-4-turbo-preview" + BetaAssistantUpdateParamsModelGPT4_1106Preview BetaAssistantUpdateParamsModel = "gpt-4-1106-preview" + BetaAssistantUpdateParamsModelGPT4VisionPreview BetaAssistantUpdateParamsModel = "gpt-4-vision-preview" + BetaAssistantUpdateParamsModelGPT4 BetaAssistantUpdateParamsModel = "gpt-4" + BetaAssistantUpdateParamsModelGPT4_0314 BetaAssistantUpdateParamsModel = "gpt-4-0314" + BetaAssistantUpdateParamsModelGPT4_0613 BetaAssistantUpdateParamsModel = "gpt-4-0613" + BetaAssistantUpdateParamsModelGPT4_32k BetaAssistantUpdateParamsModel = "gpt-4-32k" + BetaAssistantUpdateParamsModelGPT4_32k0314 BetaAssistantUpdateParamsModel = "gpt-4-32k-0314" + BetaAssistantUpdateParamsModelGPT4_32k0613 BetaAssistantUpdateParamsModel = "gpt-4-32k-0613" + BetaAssistantUpdateParamsModelGPT3_5Turbo BetaAssistantUpdateParamsModel = "gpt-3.5-turbo" + BetaAssistantUpdateParamsModelGPT3_5Turbo16k BetaAssistantUpdateParamsModel = "gpt-3.5-turbo-16k" + BetaAssistantUpdateParamsModelGPT3_5Turbo0613 BetaAssistantUpdateParamsModel = "gpt-3.5-turbo-0613" + 
BetaAssistantUpdateParamsModelGPT3_5Turbo1106 BetaAssistantUpdateParamsModel = "gpt-3.5-turbo-1106" + BetaAssistantUpdateParamsModelGPT3_5Turbo0125 BetaAssistantUpdateParamsModel = "gpt-3.5-turbo-0125" + BetaAssistantUpdateParamsModelGPT3_5Turbo16k0613 BetaAssistantUpdateParamsModel = "gpt-3.5-turbo-16k-0613" +) + +// A set of resources that are used by the assistant's tools. The resources are +// specific to the type of tool. For example, the `code_interpreter` tool requires +// a list of file IDs, while the `file_search` tool requires a list of vector store +// IDs. +type BetaAssistantUpdateParamsToolResources struct { + CodeInterpreter BetaAssistantUpdateParamsToolResourcesCodeInterpreter `json:"code_interpreter,omitzero"` + FileSearch BetaAssistantUpdateParamsToolResourcesFileSearch `json:"file_search,omitzero"` + paramObj +} + +func (r BetaAssistantUpdateParamsToolResources) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantUpdateParamsToolResources + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantUpdateParamsToolResources) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantUpdateParamsToolResourcesCodeInterpreter struct { + // Overrides the list of + // [file](https://platform.openai.com/docs/api-reference/files) IDs made available + // to the `code_interpreter` tool. There can be a maximum of 20 files associated + // with the tool. 
+ FileIDs []string `json:"file_ids,omitzero"` + paramObj +} + +func (r BetaAssistantUpdateParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantUpdateParamsToolResourcesCodeInterpreter + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantUpdateParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantUpdateParamsToolResourcesFileSearch struct { + // Overrides the + // [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) + // attached to this assistant. There can be a maximum of 1 vector store attached to + // the assistant. + VectorStoreIDs []string `json:"vector_store_ids,omitzero"` + paramObj +} + +func (r BetaAssistantUpdateParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) { + type shadow BetaAssistantUpdateParamsToolResourcesFileSearch + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaAssistantUpdateParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaAssistantListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A cursor for use in pagination. `before` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // starting with obj_foo, your subsequent call can include before=obj_foo in order + // to fetch the previous page of the list. + Before param.Opt[string] `query:"before,omitzero" json:"-"` + // A limit on the number of objects to be returned. 
Limit can range between 1 and + // 100, and the default is 20. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order BetaAssistantListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [BetaAssistantListParams]'s query parameters as +// `url.Values`. +func (r BetaAssistantListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. +type BetaAssistantListParamsOrder string + +const ( + BetaAssistantListParamsOrderAsc BetaAssistantListParamsOrder = "asc" + BetaAssistantListParamsOrderDesc BetaAssistantListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/betathread.go b/vendor/github.com/openai/openai-go/betathread.go new file mode 100644 index 0000000000..7e351bf39a --- /dev/null +++ b/vendor/github.com/openai/openai-go/betathread.go @@ -0,0 +1,1564 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/packages/ssestream" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// BetaThreadService contains methods and other services that help with interacting +// with the openai API. 
+// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBetaThreadService] method instead. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +type BetaThreadService struct { + Options []option.RequestOption + // Deprecated: The Assistants API is deprecated in favor of the Responses API + Runs BetaThreadRunService + // Deprecated: The Assistants API is deprecated in favor of the Responses API + Messages BetaThreadMessageService +} + +// NewBetaThreadService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewBetaThreadService(opts ...option.RequestOption) (r BetaThreadService) { + r = BetaThreadService{} + r.Options = opts + r.Runs = NewBetaThreadRunService(opts...) + r.Messages = NewBetaThreadMessageService(opts...) + return +} + +// Create a thread. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadService) New(ctx context.Context, body BetaThreadNewParams, opts ...option.RequestOption) (res *Thread, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + path := "threads" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Retrieves a thread. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadService) Get(ctx context.Context, threadID string, opts ...option.RequestOption) (res *Thread, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) 
+ if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s", threadID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Modifies a thread. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadService) Update(ctx context.Context, threadID string, body BetaThreadUpdateParams, opts ...option.RequestOption) (res *Thread, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s", threadID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Delete a thread. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadService) Delete(ctx context.Context, threadID string, opts ...option.RequestOption) (res *ThreadDeleted, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s", threadID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...) + return +} + +// Create a thread and run it in one request. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadService) NewAndRun(ctx context.Context, body BetaThreadNewAndRunParams, opts ...option.RequestOption) (res *Run, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) 
+ path := "threads/runs" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Create a thread and run it in one request. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadService) NewAndRunStreaming(ctx context.Context, body BetaThreadNewAndRunParams, opts ...option.RequestOption) (stream *ssestream.Stream[AssistantStreamEventUnion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithJSONSet("stream", true)}, opts...) + path := "threads/runs" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err) +} + +// AssistantResponseFormatOptionUnion contains all possible properties and values +// from [constant.Auto], [shared.ResponseFormatText], +// [shared.ResponseFormatJSONObject], [shared.ResponseFormatJSONSchema]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto] +type AssistantResponseFormatOptionUnion struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + Type string `json:"type"` + // This field is from variant [shared.ResponseFormatJSONSchema]. 
+ JSONSchema shared.ResponseFormatJSONSchemaJSONSchema `json:"json_schema"` + JSON struct { + OfAuto respjson.Field + Type respjson.Field + JSONSchema respjson.Field + raw string + } `json:"-"` +} + +func (u AssistantResponseFormatOptionUnion) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AssistantResponseFormatOptionUnion) AsText() (v shared.ResponseFormatText) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AssistantResponseFormatOptionUnion) AsJSONObject() (v shared.ResponseFormatJSONObject) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AssistantResponseFormatOptionUnion) AsJSONSchema() (v shared.ResponseFormatJSONSchema) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u AssistantResponseFormatOptionUnion) RawJSON() string { return u.JSON.raw } + +func (r *AssistantResponseFormatOptionUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this AssistantResponseFormatOptionUnion to a +// AssistantResponseFormatOptionUnionParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// AssistantResponseFormatOptionUnionParam.Overrides() +func (r AssistantResponseFormatOptionUnion) ToParam() AssistantResponseFormatOptionUnionParam { + return param.Override[AssistantResponseFormatOptionUnionParam](json.RawMessage(r.RawJSON())) +} + +func AssistantResponseFormatOptionParamOfAuto() AssistantResponseFormatOptionUnionParam { + return AssistantResponseFormatOptionUnionParam{OfAuto: constant.ValueOf[constant.Auto]()} +} + +func AssistantResponseFormatOptionParamOfJSONSchema(jsonSchema shared.ResponseFormatJSONSchemaJSONSchemaParam) AssistantResponseFormatOptionUnionParam { + var variant shared.ResponseFormatJSONSchemaParam + variant.JSONSchema = jsonSchema + return AssistantResponseFormatOptionUnionParam{OfJSONSchema: &variant} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type AssistantResponseFormatOptionUnionParam struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfText *shared.ResponseFormatTextParam `json:",omitzero,inline"` + OfJSONObject *shared.ResponseFormatJSONObjectParam `json:",omitzero,inline"` + OfJSONSchema *shared.ResponseFormatJSONSchemaParam `json:",omitzero,inline"` + paramUnion +} + +func (u AssistantResponseFormatOptionUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfText, u.OfJSONObject, u.OfJSONSchema) +} +func (u *AssistantResponseFormatOptionUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *AssistantResponseFormatOptionUnionParam) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfText) { + return u.OfText + } else if !param.IsOmitted(u.OfJSONObject) { + return u.OfJSONObject + } else if !param.IsOmitted(u.OfJSONSchema) { + return u.OfJSONSchema + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u AssistantResponseFormatOptionUnionParam) GetJSONSchema() *shared.ResponseFormatJSONSchemaJSONSchemaParam { + if vt := u.OfJSONSchema; vt != nil { + return &vt.JSONSchema + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u AssistantResponseFormatOptionUnionParam) GetType() *string { + if vt := u.OfText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfJSONObject; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfJSONSchema; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// Specifies a tool the model should use. Use to force the model to call a specific +// tool. +type AssistantToolChoice struct { + // The type of the tool. If type is `function`, the function name must be set + // + // Any of "function", "code_interpreter", "file_search". + Type AssistantToolChoiceType `json:"type,required"` + Function AssistantToolChoiceFunction `json:"function"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + Function respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AssistantToolChoice) RawJSON() string { return r.JSON.raw } +func (r *AssistantToolChoice) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this AssistantToolChoice to a AssistantToolChoiceParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// AssistantToolChoiceParam.Overrides() +func (r AssistantToolChoice) ToParam() AssistantToolChoiceParam { + return param.Override[AssistantToolChoiceParam](json.RawMessage(r.RawJSON())) +} + +// The type of the tool. 
If type is `function`, the function name must be set +type AssistantToolChoiceType string + +const ( + AssistantToolChoiceTypeFunction AssistantToolChoiceType = "function" + AssistantToolChoiceTypeCodeInterpreter AssistantToolChoiceType = "code_interpreter" + AssistantToolChoiceTypeFileSearch AssistantToolChoiceType = "file_search" +) + +// Specifies a tool the model should use. Use to force the model to call a specific +// tool. +// +// The property Type is required. +type AssistantToolChoiceParam struct { + // The type of the tool. If type is `function`, the function name must be set + // + // Any of "function", "code_interpreter", "file_search". + Type AssistantToolChoiceType `json:"type,omitzero,required"` + Function AssistantToolChoiceFunctionParam `json:"function,omitzero"` + paramObj +} + +func (r AssistantToolChoiceParam) MarshalJSON() (data []byte, err error) { + type shadow AssistantToolChoiceParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *AssistantToolChoiceParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type AssistantToolChoiceFunction struct { + // The name of the function to call. + Name string `json:"name,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AssistantToolChoiceFunction) RawJSON() string { return r.JSON.raw } +func (r *AssistantToolChoiceFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this AssistantToolChoiceFunction to a +// AssistantToolChoiceFunctionParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// AssistantToolChoiceFunctionParam.Overrides() +func (r AssistantToolChoiceFunction) ToParam() AssistantToolChoiceFunctionParam { + return param.Override[AssistantToolChoiceFunctionParam](json.RawMessage(r.RawJSON())) +} + +// The property Name is required. +type AssistantToolChoiceFunctionParam struct { + // The name of the function to call. + Name string `json:"name,required"` + paramObj +} + +func (r AssistantToolChoiceFunctionParam) MarshalJSON() (data []byte, err error) { + type shadow AssistantToolChoiceFunctionParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *AssistantToolChoiceFunctionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// AssistantToolChoiceOptionUnion contains all possible properties and values from +// [string], [AssistantToolChoice]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto] +type AssistantToolChoiceOptionUnion struct { + // This field will be present if the value is a [string] instead of an object. + OfAuto string `json:",inline"` + // This field is from variant [AssistantToolChoice]. + Type AssistantToolChoiceType `json:"type"` + // This field is from variant [AssistantToolChoice]. 
+ Function AssistantToolChoiceFunction `json:"function"` + JSON struct { + OfAuto respjson.Field + Type respjson.Field + Function respjson.Field + raw string + } `json:"-"` +} + +func (u AssistantToolChoiceOptionUnion) AsAuto() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AssistantToolChoiceOptionUnion) AsAssistantToolChoice() (v AssistantToolChoice) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u AssistantToolChoiceOptionUnion) RawJSON() string { return u.JSON.raw } + +func (r *AssistantToolChoiceOptionUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this AssistantToolChoiceOptionUnion to a +// AssistantToolChoiceOptionUnionParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// AssistantToolChoiceOptionUnionParam.Overrides() +func (r AssistantToolChoiceOptionUnion) ToParam() AssistantToolChoiceOptionUnionParam { + return param.Override[AssistantToolChoiceOptionUnionParam](json.RawMessage(r.RawJSON())) +} + +// `none` means the model will not call any tools and instead generates a message. +// `auto` means the model can pick between generating a message or calling one or +// more tools. `required` means the model must call one or more tools before +// responding to the user. 
+type AssistantToolChoiceOptionAuto string + +const ( + AssistantToolChoiceOptionAutoNone AssistantToolChoiceOptionAuto = "none" + AssistantToolChoiceOptionAutoAuto AssistantToolChoiceOptionAuto = "auto" + AssistantToolChoiceOptionAutoRequired AssistantToolChoiceOptionAuto = "required" +) + +func AssistantToolChoiceOptionParamOfAssistantToolChoice(type_ AssistantToolChoiceType) AssistantToolChoiceOptionUnionParam { + var variant AssistantToolChoiceParam + variant.Type = type_ + return AssistantToolChoiceOptionUnionParam{OfAssistantToolChoice: &variant} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type AssistantToolChoiceOptionUnionParam struct { + // Check if union is this variant with !param.IsOmitted(union.OfAuto) + OfAuto param.Opt[string] `json:",omitzero,inline"` + OfAssistantToolChoice *AssistantToolChoiceParam `json:",omitzero,inline"` + paramUnion +} + +func (u AssistantToolChoiceOptionUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfAssistantToolChoice) +} +func (u *AssistantToolChoiceOptionUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *AssistantToolChoiceOptionUnionParam) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfAssistantToolChoice) { + return u.OfAssistantToolChoice + } + return nil +} + +// Represents a thread that contains +// [messages](https://platform.openai.com/docs/api-reference/messages). +type Thread struct { + // The identifier, which can be referenced in API endpoints. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the thread was created. + CreatedAt int64 `json:"created_at,required"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. 
+ // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,required"` + // The object type, which is always `thread`. + Object constant.Thread `json:"object,required"` + // A set of resources that are made available to the assistant's tools in this + // thread. The resources are specific to the type of tool. For example, the + // `code_interpreter` tool requires a list of file IDs, while the `file_search` + // tool requires a list of vector store IDs. + ToolResources ThreadToolResources `json:"tool_resources,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Metadata respjson.Field + Object respjson.Field + ToolResources respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Thread) RawJSON() string { return r.JSON.raw } +func (r *Thread) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A set of resources that are made available to the assistant's tools in this +// thread. The resources are specific to the type of tool. For example, the +// `code_interpreter` tool requires a list of file IDs, while the `file_search` +// tool requires a list of vector store IDs. +type ThreadToolResources struct { + CodeInterpreter ThreadToolResourcesCodeInterpreter `json:"code_interpreter"` + FileSearch ThreadToolResourcesFileSearch `json:"file_search"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + CodeInterpreter respjson.Field + FileSearch respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ThreadToolResources) RawJSON() string { return r.JSON.raw } +func (r *ThreadToolResources) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ThreadToolResourcesCodeInterpreter struct { + // A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made + // available to the `code_interpreter` tool. There can be a maximum of 20 files + // associated with the tool. + FileIDs []string `json:"file_ids"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileIDs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ThreadToolResourcesCodeInterpreter) RawJSON() string { return r.JSON.raw } +func (r *ThreadToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ThreadToolResourcesFileSearch struct { + // The + // [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) + // attached to this thread. There can be a maximum of 1 vector store attached to + // the thread. + VectorStoreIDs []string `json:"vector_store_ids"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
	JSON struct {
		VectorStoreIDs respjson.Field
		ExtraFields    map[string]respjson.Field
		raw            string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r ThreadToolResourcesFileSearch) RawJSON() string { return r.JSON.raw }
func (r *ThreadToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// ThreadDeleted is the API response confirming that a thread was deleted.
type ThreadDeleted struct {
	ID      string                 `json:"id,required"`
	Deleted bool                   `json:"deleted,required"`
	Object  constant.ThreadDeleted `json:"object,required"`
	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
	JSON struct {
		ID          respjson.Field
		Deleted     respjson.Field
		Object      respjson.Field
		ExtraFields map[string]respjson.Field
		raw         string
	} `json:"-"`
}

// Returns the unmodified JSON received from the API
func (r ThreadDeleted) RawJSON() string { return r.JSON.raw }
func (r *ThreadDeleted) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewParams is the request body for creating a new thread.
type BetaThreadNewParams struct {
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	// A set of resources that are made available to the assistant's tools in this
	// thread. The resources are specific to the type of tool. For example, the
	// `code_interpreter` tool requires a list of file IDs, while the `file_search`
	// tool requires a list of vector store IDs.
	ToolResources BetaThreadNewParamsToolResources `json:"tool_resources,omitzero"`
	// A list of [messages](https://platform.openai.com/docs/api-reference/messages) to
	// start the thread with.
	Messages []BetaThreadNewParamsMessage `json:"messages,omitzero"`
	paramObj
}

func (r BetaThreadNewParams) MarshalJSON() (data []byte, err error) {
	// The shadow type shares this type's layout but carries none of its
	// methods, so marshalling it cannot re-enter this MarshalJSON.
	type shadow BetaThreadNewParams
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParams) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// The properties Content, Role are required.
type BetaThreadNewParamsMessage struct {
	// The text contents of the message.
	Content BetaThreadNewParamsMessageContentUnion `json:"content,omitzero,required"`
	// The role of the entity that is creating the message. Allowed values include:
	//
	// - `user`: Indicates the message is sent by an actual user and should be used in
	//   most cases to represent user-generated messages.
	// - `assistant`: Indicates the message is generated by the assistant. Use this
	//   value to insert messages from the assistant into the conversation.
	//
	// Any of "user", "assistant".
	Role string `json:"role,omitzero,required"`
	// A list of files attached to the message, and the tools they should be added to.
	Attachments []BetaThreadNewParamsMessageAttachment `json:"attachments,omitzero"`
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	paramObj
}

func (r BetaThreadNewParamsMessage) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsMessage
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsMessage) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

func init() {
	// Restrict the "role" field to the values the API accepts.
	apijson.RegisterFieldValidator[BetaThreadNewParamsMessage](
		"role", "user", "assistant",
	)
}

// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type BetaThreadNewParamsMessageContentUnion struct {
	OfString              param.Opt[string]              `json:",omitzero,inline"`
	OfArrayOfContentParts []MessageContentPartParamUnion `json:",omitzero,inline"`
	paramUnion
}

func (u BetaThreadNewParamsMessageContentUnion) MarshalJSON() ([]byte, error) {
	return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
}
func (u *BetaThreadNewParamsMessageContentUnion) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, u)
}

// asAny returns a pointer to whichever variant is set, or nil if none is.
func (u *BetaThreadNewParamsMessageContentUnion) asAny() any {
	if !param.IsOmitted(u.OfString) {
		return &u.OfString.Value
	} else if !param.IsOmitted(u.OfArrayOfContentParts) {
		return &u.OfArrayOfContentParts
	}
	return nil
}

// BetaThreadNewParamsMessageAttachment attaches a file to a message and names
// the tools that may use it.
type BetaThreadNewParamsMessageAttachment struct {
	// The ID of the file to attach to the message.
	FileID param.Opt[string] `json:"file_id,omitzero"`
	// The tools to add this file to.
	Tools []BetaThreadNewParamsMessageAttachmentToolUnion `json:"tools,omitzero"`
	paramObj
}

func (r BetaThreadNewParamsMessageAttachment) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsMessageAttachment
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsMessageAttachment) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
+type BetaThreadNewParamsMessageAttachmentToolUnion struct { + OfCodeInterpreter *CodeInterpreterToolParam `json:",omitzero,inline"` + OfFileSearch *BetaThreadNewParamsMessageAttachmentToolFileSearch `json:",omitzero,inline"` + paramUnion +} + +func (u BetaThreadNewParamsMessageAttachmentToolUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch) +} +func (u *BetaThreadNewParamsMessageAttachmentToolUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaThreadNewParamsMessageAttachmentToolUnion) asAny() any { + if !param.IsOmitted(u.OfCodeInterpreter) { + return u.OfCodeInterpreter + } else if !param.IsOmitted(u.OfFileSearch) { + return u.OfFileSearch + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u BetaThreadNewParamsMessageAttachmentToolUnion) GetType() *string { + if vt := u.OfCodeInterpreter; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFileSearch; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[BetaThreadNewParamsMessageAttachmentToolUnion]( + "type", + apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"), + apijson.Discriminator[BetaThreadNewParamsMessageAttachmentToolFileSearch]("file_search"), + ) +} + +func NewBetaThreadNewParamsMessageAttachmentToolFileSearch() BetaThreadNewParamsMessageAttachmentToolFileSearch { + return BetaThreadNewParamsMessageAttachmentToolFileSearch{ + Type: "file_search", + } +} + +// This struct has a constant value, construct it with +// [NewBetaThreadNewParamsMessageAttachmentToolFileSearch]. 
type BetaThreadNewParamsMessageAttachmentToolFileSearch struct {
	// The type of tool being defined: `file_search`
	Type constant.FileSearch `json:"type,required"`
	paramObj
}

func (r BetaThreadNewParamsMessageAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) {
	// The shadow type drops this type's methods so marshalling cannot recurse.
	type shadow BetaThreadNewParamsMessageAttachmentToolFileSearch
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsMessageAttachmentToolFileSearch) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
// tool requires a list of vector store IDs.
type BetaThreadNewParamsToolResources struct {
	CodeInterpreter BetaThreadNewParamsToolResourcesCodeInterpreter `json:"code_interpreter,omitzero"`
	FileSearch      BetaThreadNewParamsToolResourcesFileSearch      `json:"file_search,omitzero"`
	paramObj
}

func (r BetaThreadNewParamsToolResources) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsToolResources
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResources) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewParamsToolResourcesCodeInterpreter holds the files exposed to
// the code_interpreter tool.
type BetaThreadNewParamsToolResourcesCodeInterpreter struct {
	// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
	// available to the `code_interpreter` tool. There can be a maximum of 20 files
	// associated with the tool.
	FileIDs []string `json:"file_ids,omitzero"`
	paramObj
}

func (r BetaThreadNewParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsToolResourcesCodeInterpreter
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewParamsToolResourcesFileSearch holds the vector stores exposed to
// the file_search tool.
type BetaThreadNewParamsToolResourcesFileSearch struct {
	// The
	// [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
	// attached to this thread. There can be a maximum of 1 vector store attached to
	// the thread.
	VectorStoreIDs []string `json:"vector_store_ids,omitzero"`
	// A helper to create a
	// [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
	// with file_ids and attach it to this thread. There can be a maximum of 1 vector
	// store attached to the thread.
	VectorStores []BetaThreadNewParamsToolResourcesFileSearchVectorStore `json:"vector_stores,omitzero"`
	paramObj
}

func (r BetaThreadNewParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsToolResourcesFileSearch
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewParamsToolResourcesFileSearchVectorStore describes a vector
// store to create and attach to the thread.
type BetaThreadNewParamsToolResourcesFileSearchVectorStore struct {
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	// The chunking strategy used to chunk the file(s). If not set, will use the `auto`
	// strategy.
	ChunkingStrategy BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion `json:"chunking_strategy,omitzero"`
	// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs to
	// add to the vector store. There can be a maximum of 10000 files in a vector
	// store.
	FileIDs []string `json:"file_ids,omitzero"`
	paramObj
}

func (r BetaThreadNewParamsToolResourcesFileSearchVectorStore) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStore
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStore) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion struct {
	OfAuto   *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto   `json:",omitzero,inline"`
	OfStatic *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic `json:",omitzero,inline"`
	paramUnion
}

func (u BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) MarshalJSON() ([]byte, error) {
	return param.MarshalUnion(u, u.OfAuto, u.OfStatic)
}
func (u *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, u)
}

// asAny returns whichever variant is set, or nil if neither is.
func (u *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) asAny() any {
	if !param.IsOmitted(u.OfAuto) {
		return u.OfAuto
	} else if !param.IsOmitted(u.OfStatic) {
		return u.OfStatic
	}
	return nil
}

// Returns a pointer to the underlying variant's property, if present.
func (u BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) GetStatic() *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic {
	if vt := u.OfStatic; vt != nil {
		return &vt.Static
	}
	return nil
}

// Returns a pointer to the underlying variant's property, if present.
func (u BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) GetType() *string {
	if vt := u.OfAuto; vt != nil {
		return (*string)(&vt.Type)
	} else if vt := u.OfStatic; vt != nil {
		return (*string)(&vt.Type)
	}
	return nil
}

func init() {
	// Register the JSON "type" discriminator so unmarshalling selects the
	// correct chunking-strategy variant.
	apijson.RegisterUnion[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion](
		"type",
		apijson.Discriminator[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto]("auto"),
		apijson.Discriminator[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic]("static"),
	)
}

// NewBetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto
// returns the constant "auto" chunking-strategy value.
func NewBetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto() BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto {
	return BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto{
		Type: "auto",
	}
}

// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of
// `800` and `chunk_overlap_tokens` of `400`.
//
// This struct has a constant value, construct it with
// [NewBetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto].
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto struct {
	// Always `auto`.
	Type constant.Auto `json:"type,required"`
	paramObj
}

func (r BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) MarshalJSON() (data []byte, err error) {
	// The shadow type drops this type's methods so marshalling cannot recurse.
	type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// The properties Static, Type are required.
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic struct {
	Static BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic `json:"static,omitzero,required"`
	// Always `static`.
	//
	// This field can be elided, and will marshal its zero value as "static".
	Type constant.Static `json:"type,required"`
	paramObj
}

func (r BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// The properties ChunkOverlapTokens, MaxChunkSizeTokens are required.
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic struct {
	// The number of tokens that overlap between chunks. The default value is `400`.
	//
	// Note that the overlap must not exceed half of `max_chunk_size_tokens`.
	ChunkOverlapTokens int64 `json:"chunk_overlap_tokens,required"`
	// The maximum number of tokens in each chunk. The default value is `800`. The
	// minimum value is `100` and the maximum value is `4096`.
	MaxChunkSizeTokens int64 `json:"max_chunk_size_tokens,required"`
	paramObj
}

func (r BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadUpdateParams is the request body for modifying an existing thread.
type BetaThreadUpdateParams struct {
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	// A set of resources that are made available to the assistant's tools in this
	// thread. The resources are specific to the type of tool. For example, the
	// `code_interpreter` tool requires a list of file IDs, while the `file_search`
	// tool requires a list of vector store IDs.
	ToolResources BetaThreadUpdateParamsToolResources `json:"tool_resources,omitzero"`
	paramObj
}

func (r BetaThreadUpdateParams) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadUpdateParams
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadUpdateParams) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
// tool requires a list of vector store IDs.
type BetaThreadUpdateParamsToolResources struct {
	CodeInterpreter BetaThreadUpdateParamsToolResourcesCodeInterpreter `json:"code_interpreter,omitzero"`
	FileSearch      BetaThreadUpdateParamsToolResourcesFileSearch      `json:"file_search,omitzero"`
	paramObj
}

func (r BetaThreadUpdateParamsToolResources) MarshalJSON() (data []byte, err error) {
	// The shadow type shares this type's layout but carries none of its
	// methods, so marshalling it cannot re-enter this MarshalJSON.
	type shadow BetaThreadUpdateParamsToolResources
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadUpdateParamsToolResources) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadUpdateParamsToolResourcesCodeInterpreter holds the files exposed to
// the code_interpreter tool.
type BetaThreadUpdateParamsToolResourcesCodeInterpreter struct {
	// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
	// available to the `code_interpreter` tool. There can be a maximum of 20 files
	// associated with the tool.
	FileIDs []string `json:"file_ids,omitzero"`
	paramObj
}

func (r BetaThreadUpdateParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadUpdateParamsToolResourcesCodeInterpreter
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadUpdateParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadUpdateParamsToolResourcesFileSearch holds the vector stores exposed
// to the file_search tool.
type BetaThreadUpdateParamsToolResourcesFileSearch struct {
	// The
	// [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
	// attached to this thread. There can be a maximum of 1 vector store attached to
	// the thread.
	VectorStoreIDs []string `json:"vector_store_ids,omitzero"`
	paramObj
}

func (r BetaThreadUpdateParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadUpdateParamsToolResourcesFileSearch
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadUpdateParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewAndRunParams is the request body for creating a thread and
// immediately running it with an assistant.
type BetaThreadNewAndRunParams struct {
	// The ID of the
	// [assistant](https://platform.openai.com/docs/api-reference/assistants) to use to
	// execute this run.
	AssistantID string `json:"assistant_id,required"`
	// Override the default system message of the assistant. This is useful for
	// modifying the behavior on a per-run basis.
	Instructions param.Opt[string] `json:"instructions,omitzero"`
	// The maximum number of completion tokens that may be used over the course of the
	// run. The run will make a best effort to use only the number of completion tokens
	// specified, across multiple turns of the run. If the run exceeds the number of
	// completion tokens specified, the run will end with status `incomplete`. See
	// `incomplete_details` for more info.
	MaxCompletionTokens param.Opt[int64] `json:"max_completion_tokens,omitzero"`
	// The maximum number of prompt tokens that may be used over the course of the run.
	// The run will make a best effort to use only the number of prompt tokens
	// specified, across multiple turns of the run. If the run exceeds the number of
	// prompt tokens specified, the run will end with status `incomplete`. See
	// `incomplete_details` for more info.
	MaxPromptTokens param.Opt[int64] `json:"max_prompt_tokens,omitzero"`
	// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
	// make the output more random, while lower values like 0.2 will make it more
	// focused and deterministic.
	Temperature param.Opt[float64] `json:"temperature,omitzero"`
	// An alternative to sampling with temperature, called nucleus sampling, where the
	// model considers the results of the tokens with top_p probability mass. So 0.1
	// means only the tokens comprising the top 10% probability mass are considered.
	//
	// We generally recommend altering this or temperature but not both.
	TopP param.Opt[float64] `json:"top_p,omitzero"`
	// Whether to enable
	// [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling)
	// during tool use.
	ParallelToolCalls param.Opt[bool] `json:"parallel_tool_calls,omitzero"`
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	// The ID of the [Model](https://platform.openai.com/docs/api-reference/models) to
	// be used to execute this run. If a value is provided here, it will override the
	// model associated with the assistant. If not, the model associated with the
	// assistant will be used.
	Model shared.ChatModel `json:"model,omitzero"`
	// A set of resources that are used by the assistant's tools. The resources are
	// specific to the type of tool. For example, the `code_interpreter` tool requires
	// a list of file IDs, while the `file_search` tool requires a list of vector store
	// IDs.
	ToolResources BetaThreadNewAndRunParamsToolResources `json:"tool_resources,omitzero"`
	// Override the tools the assistant can use for this run. This is useful for
	// modifying the behavior on a per-run basis.
	Tools []AssistantToolUnionParam `json:"tools,omitzero"`
	// Controls for how a thread will be truncated prior to the run. Use this to
	// control the initial context window of the run.
	TruncationStrategy BetaThreadNewAndRunParamsTruncationStrategy `json:"truncation_strategy,omitzero"`
	// Specifies the format that the model must output. Compatible with
	// [GPT-4o](https://platform.openai.com/docs/models#gpt-4o),
	// [GPT-4 Turbo](https://platform.openai.com/docs/models#gpt-4-turbo-and-gpt-4),
	// and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.
	//
	// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
	// Outputs which ensures the model will match your supplied JSON schema. Learn more
	// in the
	// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
	//
	// Setting to `{ "type": "json_object" }` enables JSON mode, which ensures the
	// message the model generates is valid JSON.
	//
	// **Important:** when using JSON mode, you **must** also instruct the model to
	// produce JSON yourself via a system or user message. Without this, the model may
	// generate an unending stream of whitespace until the generation reaches the token
	// limit, resulting in a long-running and seemingly "stuck" request. Also note that
	// the message content may be partially cut off if `finish_reason="length"`, which
	// indicates the generation exceeded `max_tokens` or the conversation exceeded the
	// max context length.
	ResponseFormat AssistantResponseFormatOptionUnionParam `json:"response_format,omitzero"`
	// Options to create a new thread. If no thread is provided when running a request,
	// an empty thread will be created.
	Thread BetaThreadNewAndRunParamsThread `json:"thread,omitzero"`
	// Controls which (if any) tool is called by the model. `none` means the model will
	// not call any tools and instead generates a message. `auto` is the default value
	// and means the model can pick between generating a message or calling one or more
	// tools. `required` means the model must call one or more tools before responding
	// to the user. Specifying a particular tool like `{"type": "file_search"}` or
	// `{"type": "function", "function": {"name": "my_function"}}` forces the model to
	// call that tool.
	ToolChoice AssistantToolChoiceOptionUnionParam `json:"tool_choice,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParams) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParams
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParams) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Options to create a new thread. If no thread is provided when running a request,
// an empty thread will be created.
type BetaThreadNewAndRunParamsThread struct {
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	// A set of resources that are made available to the assistant's tools in this
	// thread. The resources are specific to the type of tool. For example, the
	// `code_interpreter` tool requires a list of file IDs, while the `file_search`
	// tool requires a list of vector store IDs.
	ToolResources BetaThreadNewAndRunParamsThreadToolResources `json:"tool_resources,omitzero"`
	// A list of [messages](https://platform.openai.com/docs/api-reference/messages) to
	// start the thread with.
	Messages []BetaThreadNewAndRunParamsThreadMessage `json:"messages,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThread) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThread
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThread) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// The properties Content, Role are required.
type BetaThreadNewAndRunParamsThreadMessage struct {
	// The text contents of the message.
	Content BetaThreadNewAndRunParamsThreadMessageContentUnion `json:"content,omitzero,required"`
	// The role of the entity that is creating the message. Allowed values include:
	//
	// - `user`: Indicates the message is sent by an actual user and should be used in
	//   most cases to represent user-generated messages.
	// - `assistant`: Indicates the message is generated by the assistant. Use this
	//   value to insert messages from the assistant into the conversation.
	//
	// Any of "user", "assistant".
	Role string `json:"role,omitzero,required"`
	// A list of files attached to the message, and the tools they should be added to.
	Attachments []BetaThreadNewAndRunParamsThreadMessageAttachment `json:"attachments,omitzero"`
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadMessage) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThreadMessage
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadMessage) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

func init() {
	// Restrict the "role" field to the values the API accepts.
	apijson.RegisterFieldValidator[BetaThreadNewAndRunParamsThreadMessage](
		"role", "user", "assistant",
	)
}

// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type BetaThreadNewAndRunParamsThreadMessageContentUnion struct {
	OfString              param.Opt[string]              `json:",omitzero,inline"`
	OfArrayOfContentParts []MessageContentPartParamUnion `json:",omitzero,inline"`
	paramUnion
}

func (u BetaThreadNewAndRunParamsThreadMessageContentUnion) MarshalJSON() ([]byte, error) {
	return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
}
func (u *BetaThreadNewAndRunParamsThreadMessageContentUnion) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, u)
}

// asAny returns a pointer to whichever variant is set, or nil if none is.
func (u *BetaThreadNewAndRunParamsThreadMessageContentUnion) asAny() any {
	if !param.IsOmitted(u.OfString) {
		return &u.OfString.Value
	} else if !param.IsOmitted(u.OfArrayOfContentParts) {
		return &u.OfArrayOfContentParts
	}
	return nil
}

// BetaThreadNewAndRunParamsThreadMessageAttachment attaches a file to a message
// and names the tools that may use it.
type BetaThreadNewAndRunParamsThreadMessageAttachment struct {
	// The ID of the file to attach to the message.
	FileID param.Opt[string] `json:"file_id,omitzero"`
	// The tools to add this file to.
	Tools []BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion `json:"tools,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadMessageAttachment) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThreadMessageAttachment
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadMessageAttachment) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion struct {
	OfCodeInterpreter *CodeInterpreterToolParam                                       `json:",omitzero,inline"`
	OfFileSearch      *BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch `json:",omitzero,inline"`
	paramUnion
}

func (u BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) MarshalJSON() ([]byte, error) {
	return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch)
}
func (u *BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, u)
}

// asAny returns whichever variant is set, or nil if neither is.
func (u *BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) asAny() any {
	if !param.IsOmitted(u.OfCodeInterpreter) {
		return u.OfCodeInterpreter
	} else if !param.IsOmitted(u.OfFileSearch) {
		return u.OfFileSearch
	}
	return nil
}

// Returns a pointer to the underlying variant's property, if present.
func (u BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) GetType() *string {
	if vt := u.OfCodeInterpreter; vt != nil {
		return (*string)(&vt.Type)
	} else if vt := u.OfFileSearch; vt != nil {
		return (*string)(&vt.Type)
	}
	return nil
}

func init() {
	// Register the JSON "type" discriminator so unmarshalling selects the
	// correct concrete variant for this union.
	apijson.RegisterUnion[BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion](
		"type",
		apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"),
		apijson.Discriminator[BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch]("file_search"),
	)
}

// NewBetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch returns the
// constant file_search tool value.
func NewBetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch() BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch {
	return BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch{
		Type: "file_search",
	}
}

// This struct has a constant value, construct it with
// [NewBetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch].
type BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch struct {
	// The type of tool being defined: `file_search`
	Type constant.FileSearch `json:"type,required"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) {
	// The shadow type drops this type's methods so marshalling cannot recurse.
	type shadow BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
// tool requires a list of vector store IDs.
type BetaThreadNewAndRunParamsThreadToolResources struct {
	CodeInterpreter BetaThreadNewAndRunParamsThreadToolResourcesCodeInterpreter `json:"code_interpreter,omitzero"`
	FileSearch      BetaThreadNewAndRunParamsThreadToolResourcesFileSearch      `json:"file_search,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadToolResources) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThreadToolResources
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadToolResources) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewAndRunParamsThreadToolResourcesCodeInterpreter holds the files
// exposed to the code_interpreter tool.
type BetaThreadNewAndRunParamsThreadToolResourcesCodeInterpreter struct {
	// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
	// available to the `code_interpreter` tool. There can be a maximum of 20 files
	// associated with the tool.
	FileIDs []string `json:"file_ids,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThreadToolResourcesCodeInterpreter
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewAndRunParamsThreadToolResourcesFileSearch holds the vector
// stores exposed to the file_search tool.
type BetaThreadNewAndRunParamsThreadToolResourcesFileSearch struct {
	// The
	// [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
	// attached to this thread. There can be a maximum of 1 vector store attached to
	// the thread.
	VectorStoreIDs []string `json:"vector_store_ids,omitzero"`
	// A helper to create a
	// [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
	// with file_ids and attach it to this thread. There can be a maximum of 1 vector
	// store attached to the thread.
	VectorStores []BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStore `json:"vector_stores,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadToolResourcesFileSearch) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThreadToolResourcesFileSearch
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStore describes a
// vector store to create and attach to the thread.
type BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStore struct {
	// Set of 16 key-value pairs that can be attached to an object. This can be useful
	// for storing additional information about the object in a structured format, and
	// querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with
	// a maximum length of 512 characters.
	Metadata shared.Metadata `json:"metadata,omitzero"`
	// The chunking strategy used to chunk the file(s). If not set, will use the `auto`
	// strategy.
	ChunkingStrategy BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion `json:"chunking_strategy,omitzero"`
	// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs to
	// add to the vector store. There can be a maximum of 10000 files in a vector
	// store.
	FileIDs []string `json:"file_ids,omitzero"`
	paramObj
}

func (r BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStore) MarshalJSON() (data []byte, err error) {
	type shadow BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStore
	return param.MarshalObject(r, (*shadow)(&r))
}
func (r *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStore) UnmarshalJSON(data []byte) error {
	return apijson.UnmarshalRoot(data, r)
}

// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
+type BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion struct { + OfAuto *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto `json:",omitzero,inline"` + OfStatic *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStatic `json:",omitzero,inline"` + paramUnion +} + +func (u BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfStatic) +} +func (u *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return u.OfAuto + } else if !param.IsOmitted(u.OfStatic) { + return u.OfStatic + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion) GetStatic() *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic { + if vt := u.OfStatic; vt != nil { + return &vt.Static + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion) GetType() *string { + if vt := u.OfAuto; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfStatic; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyUnion]( + "type", + apijson.Discriminator[BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto]("auto"), + apijson.Discriminator[BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStatic]("static"), + ) +} + +func NewBetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto() BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto { + return BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto{ + Type: "auto", + } +} + +// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of +// `800` and `chunk_overlap_tokens` of `400`. +// +// This struct has a constant value, construct it with +// [NewBetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto]. +type BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto struct { + // Always `auto`. + Type constant.Auto `json:"type,required"` + paramObj +} + +func (r BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyAuto) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties Static, Type are required. 
+type BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStatic struct { + Static BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic `json:"static,omitzero,required"` + // Always `static`. + // + // This field can be elided, and will marshal its zero value as "static". + Type constant.Static `json:"type,required"` + paramObj +} + +func (r BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStatic) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStatic + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStatic) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties ChunkOverlapTokens, MaxChunkSizeTokens are required. +type BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic struct { + // The number of tokens that overlap between chunks. The default value is `400`. + // + // Note that the overlap must not exceed half of `max_chunk_size_tokens`. + ChunkOverlapTokens int64 `json:"chunk_overlap_tokens,required"` + // The maximum number of tokens in each chunk. The default value is `800`. The + // minimum value is `100` and the maximum value is `4096`. 
+ MaxChunkSizeTokens int64 `json:"max_chunk_size_tokens,required"` + paramObj +} + +func (r BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadNewAndRunParamsThreadToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A set of resources that are used by the assistant's tools. The resources are +// specific to the type of tool. For example, the `code_interpreter` tool requires +// a list of file IDs, while the `file_search` tool requires a list of vector store +// IDs. +type BetaThreadNewAndRunParamsToolResources struct { + CodeInterpreter BetaThreadNewAndRunParamsToolResourcesCodeInterpreter `json:"code_interpreter,omitzero"` + FileSearch BetaThreadNewAndRunParamsToolResourcesFileSearch `json:"file_search,omitzero"` + paramObj +} + +func (r BetaThreadNewAndRunParamsToolResources) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadNewAndRunParamsToolResources + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadNewAndRunParamsToolResources) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadNewAndRunParamsToolResourcesCodeInterpreter struct { + // A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made + // available to the `code_interpreter` tool. There can be a maximum of 20 files + // associated with the tool. 
+	FileIDs []string `json:"file_ids,omitzero"`
+	paramObj
+}
+
+func (r BetaThreadNewAndRunParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) {
+	type shadow BetaThreadNewAndRunParamsToolResourcesCodeInterpreter
+	return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *BetaThreadNewAndRunParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+type BetaThreadNewAndRunParamsToolResourcesFileSearch struct {
+	// The ID of the
+	// [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
+	// attached to this assistant. There can be a maximum of 1 vector store attached to
+	// the assistant.
+	VectorStoreIDs []string `json:"vector_store_ids,omitzero"`
+	paramObj
+}
+
+func (r BetaThreadNewAndRunParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) {
+	type shadow BetaThreadNewAndRunParamsToolResourcesFileSearch
+	return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *BetaThreadNewAndRunParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// Controls for how a thread will be truncated prior to the run. Use this to
+// control the initial context window of the run.
+//
+// The property Type is required.
+type BetaThreadNewAndRunParamsTruncationStrategy struct {
+	// The truncation strategy to use for the thread. The default is `auto`. If set to
+	// `last_messages`, the thread will be truncated to the n most recent messages in
+	// the thread. When set to `auto`, messages in the middle of the thread will be
+	// dropped to fit the context length of the model, `max_prompt_tokens`.
+	//
+	// Any of "auto", "last_messages".
+	Type string `json:"type,omitzero,required"`
+	// The number of most recent messages from the thread when constructing the context
+	// for the run.
+ LastMessages param.Opt[int64] `json:"last_messages,omitzero"` + paramObj +} + +func (r BetaThreadNewAndRunParamsTruncationStrategy) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadNewAndRunParamsTruncationStrategy + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadNewAndRunParamsTruncationStrategy) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[BetaThreadNewAndRunParamsTruncationStrategy]( + "type", "auto", "last_messages", + ) +} diff --git a/vendor/github.com/openai/openai-go/betathreadmessage.go b/vendor/github.com/openai/openai-go/betathreadmessage.go new file mode 100644 index 0000000000..3078e0a68e --- /dev/null +++ b/vendor/github.com/openai/openai-go/betathreadmessage.go @@ -0,0 +1,1712 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// BetaThreadMessageService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBetaThreadMessageService] method instead. 
+// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +type BetaThreadMessageService struct { + Options []option.RequestOption +} + +// NewBetaThreadMessageService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewBetaThreadMessageService(opts ...option.RequestOption) (r BetaThreadMessageService) { + r = BetaThreadMessageService{} + r.Options = opts + return +} + +// Create a message. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadMessageService) New(ctx context.Context, threadID string, body BetaThreadMessageNewParams, opts ...option.RequestOption) (res *Message, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s/messages", threadID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Retrieve a message. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadMessageService) Get(ctx context.Context, threadID string, messageID string, opts ...option.RequestOption) (res *Message, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if messageID == "" { + err = errors.New("missing required message_id parameter") + return + } + path := fmt.Sprintf("threads/%s/messages/%s", threadID, messageID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) 
+ return +} + +// Modifies a message. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadMessageService) Update(ctx context.Context, threadID string, messageID string, body BetaThreadMessageUpdateParams, opts ...option.RequestOption) (res *Message, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if messageID == "" { + err = errors.New("missing required message_id parameter") + return + } + path := fmt.Sprintf("threads/%s/messages/%s", threadID, messageID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Returns a list of messages for a given thread. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadMessageService) List(ctx context.Context, threadID string, query BetaThreadMessageListParams, opts ...option.RequestOption) (res *pagination.CursorPage[Message], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithResponseInto(&raw)}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s/messages", threadID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Returns a list of messages for a given thread. 
+// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadMessageService) ListAutoPaging(ctx context.Context, threadID string, query BetaThreadMessageListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[Message] { + return pagination.NewCursorPageAutoPager(r.List(ctx, threadID, query, opts...)) +} + +// Deletes a message. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadMessageService) Delete(ctx context.Context, threadID string, messageID string, opts ...option.RequestOption) (res *MessageDeleted, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if messageID == "" { + err = errors.New("missing required message_id parameter") + return + } + path := fmt.Sprintf("threads/%s/messages/%s", threadID, messageID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...) + return +} + +// AnnotationUnion contains all possible properties and values from +// [FileCitationAnnotation], [FilePathAnnotation]. +// +// Use the [AnnotationUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type AnnotationUnion struct { + EndIndex int64 `json:"end_index"` + // This field is from variant [FileCitationAnnotation]. + FileCitation FileCitationAnnotationFileCitation `json:"file_citation"` + StartIndex int64 `json:"start_index"` + Text string `json:"text"` + // Any of "file_citation", "file_path". + Type string `json:"type"` + // This field is from variant [FilePathAnnotation]. 
+ FilePath FilePathAnnotationFilePath `json:"file_path"` + JSON struct { + EndIndex respjson.Field + FileCitation respjson.Field + StartIndex respjson.Field + Text respjson.Field + Type respjson.Field + FilePath respjson.Field + raw string + } `json:"-"` +} + +// anyAnnotation is implemented by each variant of [AnnotationUnion] to add type +// safety for the return type of [AnnotationUnion.AsAny] +type anyAnnotation interface { + implAnnotationUnion() +} + +func (FileCitationAnnotation) implAnnotationUnion() {} +func (FilePathAnnotation) implAnnotationUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := AnnotationUnion.AsAny().(type) { +// case openai.FileCitationAnnotation: +// case openai.FilePathAnnotation: +// default: +// fmt.Errorf("no variant present") +// } +func (u AnnotationUnion) AsAny() anyAnnotation { + switch u.Type { + case "file_citation": + return u.AsFileCitation() + case "file_path": + return u.AsFilePath() + } + return nil +} + +func (u AnnotationUnion) AsFileCitation() (v FileCitationAnnotation) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AnnotationUnion) AsFilePath() (v FilePathAnnotation) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u AnnotationUnion) RawJSON() string { return u.JSON.raw } + +func (r *AnnotationUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// AnnotationDeltaUnion contains all possible properties and values from +// [FileCitationDeltaAnnotation], [FilePathDeltaAnnotation]. +// +// Use the [AnnotationDeltaUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type AnnotationDeltaUnion struct { + Index int64 `json:"index"` + // Any of "file_citation", "file_path". 
+ Type string `json:"type"` + EndIndex int64 `json:"end_index"` + // This field is from variant [FileCitationDeltaAnnotation]. + FileCitation FileCitationDeltaAnnotationFileCitation `json:"file_citation"` + StartIndex int64 `json:"start_index"` + Text string `json:"text"` + // This field is from variant [FilePathDeltaAnnotation]. + FilePath FilePathDeltaAnnotationFilePath `json:"file_path"` + JSON struct { + Index respjson.Field + Type respjson.Field + EndIndex respjson.Field + FileCitation respjson.Field + StartIndex respjson.Field + Text respjson.Field + FilePath respjson.Field + raw string + } `json:"-"` +} + +// anyAnnotationDelta is implemented by each variant of [AnnotationDeltaUnion] to +// add type safety for the return type of [AnnotationDeltaUnion.AsAny] +type anyAnnotationDelta interface { + implAnnotationDeltaUnion() +} + +func (FileCitationDeltaAnnotation) implAnnotationDeltaUnion() {} +func (FilePathDeltaAnnotation) implAnnotationDeltaUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := AnnotationDeltaUnion.AsAny().(type) { +// case openai.FileCitationDeltaAnnotation: +// case openai.FilePathDeltaAnnotation: +// default: +// fmt.Errorf("no variant present") +// } +func (u AnnotationDeltaUnion) AsAny() anyAnnotationDelta { + switch u.Type { + case "file_citation": + return u.AsFileCitation() + case "file_path": + return u.AsFilePath() + } + return nil +} + +func (u AnnotationDeltaUnion) AsFileCitation() (v FileCitationDeltaAnnotation) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u AnnotationDeltaUnion) AsFilePath() (v FilePathDeltaAnnotation) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u AnnotationDeltaUnion) RawJSON() string { return u.JSON.raw } + +func (r *AnnotationDeltaUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A citation 
within the message that points to a specific quote from a specific +// File associated with the assistant or the message. Generated when the assistant +// uses the "file_search" tool to search files. +type FileCitationAnnotation struct { + EndIndex int64 `json:"end_index,required"` + FileCitation FileCitationAnnotationFileCitation `json:"file_citation,required"` + StartIndex int64 `json:"start_index,required"` + // The text in the message content that needs to be replaced. + Text string `json:"text,required"` + // Always `file_citation`. + Type constant.FileCitation `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + EndIndex respjson.Field + FileCitation respjson.Field + StartIndex respjson.Field + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileCitationAnnotation) RawJSON() string { return r.JSON.raw } +func (r *FileCitationAnnotation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FileCitationAnnotationFileCitation struct { + // The ID of the specific File the citation is from. + FileID string `json:"file_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileCitationAnnotationFileCitation) RawJSON() string { return r.JSON.raw } +func (r *FileCitationAnnotationFileCitation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A citation within the message that points to a specific quote from a specific +// File associated with the assistant or the message. Generated when the assistant +// uses the "file_search" tool to search files. 
+type FileCitationDeltaAnnotation struct { + // The index of the annotation in the text content part. + Index int64 `json:"index,required"` + // Always `file_citation`. + Type constant.FileCitation `json:"type,required"` + EndIndex int64 `json:"end_index"` + FileCitation FileCitationDeltaAnnotationFileCitation `json:"file_citation"` + StartIndex int64 `json:"start_index"` + // The text in the message content that needs to be replaced. + Text string `json:"text"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + EndIndex respjson.Field + FileCitation respjson.Field + StartIndex respjson.Field + Text respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileCitationDeltaAnnotation) RawJSON() string { return r.JSON.raw } +func (r *FileCitationDeltaAnnotation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FileCitationDeltaAnnotationFileCitation struct { + // The ID of the specific File the citation is from. + FileID string `json:"file_id"` + // The specific quote in the file. + Quote string `json:"quote"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + Quote respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileCitationDeltaAnnotationFileCitation) RawJSON() string { return r.JSON.raw } +func (r *FileCitationDeltaAnnotationFileCitation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A URL for the file that's generated when the assistant used the +// `code_interpreter` tool to generate a file. 
+type FilePathAnnotation struct { + EndIndex int64 `json:"end_index,required"` + FilePath FilePathAnnotationFilePath `json:"file_path,required"` + StartIndex int64 `json:"start_index,required"` + // The text in the message content that needs to be replaced. + Text string `json:"text,required"` + // Always `file_path`. + Type constant.FilePath `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + EndIndex respjson.Field + FilePath respjson.Field + StartIndex respjson.Field + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FilePathAnnotation) RawJSON() string { return r.JSON.raw } +func (r *FilePathAnnotation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FilePathAnnotationFilePath struct { + // The ID of the file that was generated. + FileID string `json:"file_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FilePathAnnotationFilePath) RawJSON() string { return r.JSON.raw } +func (r *FilePathAnnotationFilePath) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A URL for the file that's generated when the assistant used the +// `code_interpreter` tool to generate a file. +type FilePathDeltaAnnotation struct { + // The index of the annotation in the text content part. + Index int64 `json:"index,required"` + // Always `file_path`. + Type constant.FilePath `json:"type,required"` + EndIndex int64 `json:"end_index"` + FilePath FilePathDeltaAnnotationFilePath `json:"file_path"` + StartIndex int64 `json:"start_index"` + // The text in the message content that needs to be replaced. 
+ Text string `json:"text"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + EndIndex respjson.Field + FilePath respjson.Field + StartIndex respjson.Field + Text respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FilePathDeltaAnnotation) RawJSON() string { return r.JSON.raw } +func (r *FilePathDeltaAnnotation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FilePathDeltaAnnotationFilePath struct { + // The ID of the file that was generated. + FileID string `json:"file_id"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FilePathDeltaAnnotationFilePath) RawJSON() string { return r.JSON.raw } +func (r *FilePathDeltaAnnotationFilePath) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ImageFile struct { + // The [File](https://platform.openai.com/docs/api-reference/files) ID of the image + // in the message content. Set `purpose="vision"` when uploading the File if you + // need to later display the file content. + FileID string `json:"file_id,required"` + // Specifies the detail level of the image if specified by the user. `low` uses + // fewer tokens, you can opt in to high resolution using `high`. + // + // Any of "auto", "low", "high". + Detail ImageFileDetail `json:"detail"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + FileID respjson.Field + Detail respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageFile) RawJSON() string { return r.JSON.raw } +func (r *ImageFile) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ImageFile to a ImageFileParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ImageFileParam.Overrides() +func (r ImageFile) ToParam() ImageFileParam { + return param.Override[ImageFileParam](json.RawMessage(r.RawJSON())) +} + +// Specifies the detail level of the image if specified by the user. `low` uses +// fewer tokens, you can opt in to high resolution using `high`. +type ImageFileDetail string + +const ( + ImageFileDetailAuto ImageFileDetail = "auto" + ImageFileDetailLow ImageFileDetail = "low" + ImageFileDetailHigh ImageFileDetail = "high" +) + +// The property FileID is required. +type ImageFileParam struct { + // The [File](https://platform.openai.com/docs/api-reference/files) ID of the image + // in the message content. Set `purpose="vision"` when uploading the File if you + // need to later display the file content. + FileID string `json:"file_id,required"` + // Specifies the detail level of the image if specified by the user. `low` uses + // fewer tokens, you can opt in to high resolution using `high`. + // + // Any of "auto", "low", "high". 
+ Detail ImageFileDetail `json:"detail,omitzero"` + paramObj +} + +func (r ImageFileParam) MarshalJSON() (data []byte, err error) { + type shadow ImageFileParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ImageFileParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// References an image [File](https://platform.openai.com/docs/api-reference/files) +// in the content of a message. +type ImageFileContentBlock struct { + ImageFile ImageFile `json:"image_file,required"` + // Always `image_file`. + Type constant.ImageFile `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageFile respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageFileContentBlock) RawJSON() string { return r.JSON.raw } +func (r *ImageFileContentBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ImageFileContentBlock to a ImageFileContentBlockParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ImageFileContentBlockParam.Overrides() +func (r ImageFileContentBlock) ToParam() ImageFileContentBlockParam { + return param.Override[ImageFileContentBlockParam](json.RawMessage(r.RawJSON())) +} + +// References an image [File](https://platform.openai.com/docs/api-reference/files) +// in the content of a message. +// +// The properties ImageFile, Type are required. +type ImageFileContentBlockParam struct { + ImageFile ImageFileParam `json:"image_file,omitzero,required"` + // Always `image_file`. + // + // This field can be elided, and will marshal its zero value as "image_file". 
+ Type constant.ImageFile `json:"type,required"` + paramObj +} + +func (r ImageFileContentBlockParam) MarshalJSON() (data []byte, err error) { + type shadow ImageFileContentBlockParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ImageFileContentBlockParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ImageFileDelta struct { + // Specifies the detail level of the image if specified by the user. `low` uses + // fewer tokens, you can opt in to high resolution using `high`. + // + // Any of "auto", "low", "high". + Detail ImageFileDeltaDetail `json:"detail"` + // The [File](https://platform.openai.com/docs/api-reference/files) ID of the image + // in the message content. Set `purpose="vision"` when uploading the File if you + // need to later display the file content. + FileID string `json:"file_id"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Detail respjson.Field + FileID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageFileDelta) RawJSON() string { return r.JSON.raw } +func (r *ImageFileDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Specifies the detail level of the image if specified by the user. `low` uses +// fewer tokens, you can opt in to high resolution using `high`. +type ImageFileDeltaDetail string + +const ( + ImageFileDeltaDetailAuto ImageFileDeltaDetail = "auto" + ImageFileDeltaDetailLow ImageFileDeltaDetail = "low" + ImageFileDeltaDetailHigh ImageFileDeltaDetail = "high" +) + +// References an image [File](https://platform.openai.com/docs/api-reference/files) +// in the content of a message. +type ImageFileDeltaBlock struct { + // The index of the content part in the message. + Index int64 `json:"index,required"` + // Always `image_file`. 
+ Type constant.ImageFile `json:"type,required"` + ImageFile ImageFileDelta `json:"image_file"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + ImageFile respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageFileDeltaBlock) RawJSON() string { return r.JSON.raw } +func (r *ImageFileDeltaBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ImageURL struct { + // The external URL of the image, must be a supported image types: jpeg, jpg, png, + // gif, webp. + URL string `json:"url,required" format:"uri"` + // Specifies the detail level of the image. `low` uses fewer tokens, you can opt in + // to high resolution using `high`. Default value is `auto` + // + // Any of "auto", "low", "high". + Detail ImageURLDetail `json:"detail"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + URL respjson.Field + Detail respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageURL) RawJSON() string { return r.JSON.raw } +func (r *ImageURL) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ImageURL to a ImageURLParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ImageURLParam.Overrides() +func (r ImageURL) ToParam() ImageURLParam { + return param.Override[ImageURLParam](json.RawMessage(r.RawJSON())) +} + +// Specifies the detail level of the image. `low` uses fewer tokens, you can opt in +// to high resolution using `high`. 
Default value is `auto` +type ImageURLDetail string + +const ( + ImageURLDetailAuto ImageURLDetail = "auto" + ImageURLDetailLow ImageURLDetail = "low" + ImageURLDetailHigh ImageURLDetail = "high" +) + +// The property URL is required. +type ImageURLParam struct { + // The external URL of the image, must be a supported image types: jpeg, jpg, png, + // gif, webp. + URL string `json:"url,required" format:"uri"` + // Specifies the detail level of the image. `low` uses fewer tokens, you can opt in + // to high resolution using `high`. Default value is `auto` + // + // Any of "auto", "low", "high". + Detail ImageURLDetail `json:"detail,omitzero"` + paramObj +} + +func (r ImageURLParam) MarshalJSON() (data []byte, err error) { + type shadow ImageURLParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ImageURLParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// References an image URL in the content of a message. +type ImageURLContentBlock struct { + ImageURL ImageURL `json:"image_url,required"` + // The type of the content part. + Type constant.ImageURL `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageURL respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageURLContentBlock) RawJSON() string { return r.JSON.raw } +func (r *ImageURLContentBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ImageURLContentBlock to a ImageURLContentBlockParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// ImageURLContentBlockParam.Overrides() +func (r ImageURLContentBlock) ToParam() ImageURLContentBlockParam { + return param.Override[ImageURLContentBlockParam](json.RawMessage(r.RawJSON())) +} + +// References an image URL in the content of a message. +// +// The properties ImageURL, Type are required. +type ImageURLContentBlockParam struct { + ImageURL ImageURLParam `json:"image_url,omitzero,required"` + // The type of the content part. + // + // This field can be elided, and will marshal its zero value as "image_url". + Type constant.ImageURL `json:"type,required"` + paramObj +} + +func (r ImageURLContentBlockParam) MarshalJSON() (data []byte, err error) { + type shadow ImageURLContentBlockParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ImageURLContentBlockParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ImageURLDelta struct { + // Specifies the detail level of the image. `low` uses fewer tokens, you can opt in + // to high resolution using `high`. + // + // Any of "auto", "low", "high". + Detail ImageURLDeltaDetail `json:"detail"` + // The URL of the image, must be a supported image types: jpeg, jpg, png, gif, + // webp. + URL string `json:"url"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Detail respjson.Field + URL respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageURLDelta) RawJSON() string { return r.JSON.raw } +func (r *ImageURLDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Specifies the detail level of the image. `low` uses fewer tokens, you can opt in +// to high resolution using `high`. 
+type ImageURLDeltaDetail string + +const ( + ImageURLDeltaDetailAuto ImageURLDeltaDetail = "auto" + ImageURLDeltaDetailLow ImageURLDeltaDetail = "low" + ImageURLDeltaDetailHigh ImageURLDeltaDetail = "high" +) + +// References an image URL in the content of a message. +type ImageURLDeltaBlock struct { + // The index of the content part in the message. + Index int64 `json:"index,required"` + // Always `image_url`. + Type constant.ImageURL `json:"type,required"` + ImageURL ImageURLDelta `json:"image_url"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + ImageURL respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageURLDeltaBlock) RawJSON() string { return r.JSON.raw } +func (r *ImageURLDeltaBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Represents a message within a +// [thread](https://platform.openai.com/docs/api-reference/threads). +type Message struct { + // The identifier, which can be referenced in API endpoints. + ID string `json:"id,required"` + // If applicable, the ID of the + // [assistant](https://platform.openai.com/docs/api-reference/assistants) that + // authored this message. + AssistantID string `json:"assistant_id,required"` + // A list of files attached to the message, and the tools they were added to. + Attachments []MessageAttachment `json:"attachments,required"` + // The Unix timestamp (in seconds) for when the message was completed. + CompletedAt int64 `json:"completed_at,required"` + // The content of the message in array of text and/or images. + Content []MessageContentUnion `json:"content,required"` + // The Unix timestamp (in seconds) for when the message was created. + CreatedAt int64 `json:"created_at,required"` + // The Unix timestamp (in seconds) for when the message was marked as incomplete. 
+ IncompleteAt int64 `json:"incomplete_at,required"` + // On an incomplete message, details about why the message is incomplete. + IncompleteDetails MessageIncompleteDetails `json:"incomplete_details,required"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,required"` + // The object type, which is always `thread.message`. + Object constant.ThreadMessage `json:"object,required"` + // The entity that produced the message. One of `user` or `assistant`. + // + // Any of "user", "assistant". + Role MessageRole `json:"role,required"` + // The ID of the [run](https://platform.openai.com/docs/api-reference/runs) + // associated with the creation of this message. Value is `null` when messages are + // created manually using the create message or create thread endpoints. + RunID string `json:"run_id,required"` + // The status of the message, which can be either `in_progress`, `incomplete`, or + // `completed`. + // + // Any of "in_progress", "incomplete", "completed". + Status MessageStatus `json:"status,required"` + // The [thread](https://platform.openai.com/docs/api-reference/threads) ID that + // this message belongs to. + ThreadID string `json:"thread_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + AssistantID respjson.Field + Attachments respjson.Field + CompletedAt respjson.Field + Content respjson.Field + CreatedAt respjson.Field + IncompleteAt respjson.Field + IncompleteDetails respjson.Field + Metadata respjson.Field + Object respjson.Field + Role respjson.Field + RunID respjson.Field + Status respjson.Field + ThreadID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Message) RawJSON() string { return r.JSON.raw } +func (r *Message) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type MessageAttachment struct { + // The ID of the file to attach to the message. + FileID string `json:"file_id"` + // The tools to add this file to. + Tools []MessageAttachmentToolUnion `json:"tools"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + Tools respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageAttachment) RawJSON() string { return r.JSON.raw } +func (r *MessageAttachment) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// MessageAttachmentToolUnion contains all possible properties and values from +// [CodeInterpreterTool], [MessageAttachmentToolFileSearchTool]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. 
+type MessageAttachmentToolUnion struct { + Type string `json:"type"` + JSON struct { + Type respjson.Field + raw string + } `json:"-"` +} + +func (u MessageAttachmentToolUnion) AsCodeInterpreterTool() (v CodeInterpreterTool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageAttachmentToolUnion) AsFileSearchTool() (v MessageAttachmentToolFileSearchTool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u MessageAttachmentToolUnion) RawJSON() string { return u.JSON.raw } + +func (r *MessageAttachmentToolUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type MessageAttachmentToolFileSearchTool struct { + // The type of tool being defined: `file_search` + Type constant.FileSearch `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageAttachmentToolFileSearchTool) RawJSON() string { return r.JSON.raw } +func (r *MessageAttachmentToolFileSearchTool) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// On an incomplete message, details about why the message is incomplete. +type MessageIncompleteDetails struct { + // The reason the message is incomplete. + // + // Any of "content_filter", "max_tokens", "run_cancelled", "run_expired", + // "run_failed". + Reason string `json:"reason,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Reason respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageIncompleteDetails) RawJSON() string { return r.JSON.raw } +func (r *MessageIncompleteDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The entity that produced the message. One of `user` or `assistant`. +type MessageRole string + +const ( + MessageRoleUser MessageRole = "user" + MessageRoleAssistant MessageRole = "assistant" +) + +// The status of the message, which can be either `in_progress`, `incomplete`, or +// `completed`. +type MessageStatus string + +const ( + MessageStatusInProgress MessageStatus = "in_progress" + MessageStatusIncomplete MessageStatus = "incomplete" + MessageStatusCompleted MessageStatus = "completed" +) + +// MessageContentUnion contains all possible properties and values from +// [ImageFileContentBlock], [ImageURLContentBlock], [TextContentBlock], +// [RefusalContentBlock]. +// +// Use the [MessageContentUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type MessageContentUnion struct { + // This field is from variant [ImageFileContentBlock]. + ImageFile ImageFile `json:"image_file"` + // Any of "image_file", "image_url", "text", "refusal". + Type string `json:"type"` + // This field is from variant [ImageURLContentBlock]. + ImageURL ImageURL `json:"image_url"` + // This field is from variant [TextContentBlock]. + Text Text `json:"text"` + // This field is from variant [RefusalContentBlock]. 
+ Refusal string `json:"refusal"` + JSON struct { + ImageFile respjson.Field + Type respjson.Field + ImageURL respjson.Field + Text respjson.Field + Refusal respjson.Field + raw string + } `json:"-"` +} + +// anyMessageContent is implemented by each variant of [MessageContentUnion] to add +// type safety for the return type of [MessageContentUnion.AsAny] +type anyMessageContent interface { + implMessageContentUnion() +} + +func (ImageFileContentBlock) implMessageContentUnion() {} +func (ImageURLContentBlock) implMessageContentUnion() {} +func (TextContentBlock) implMessageContentUnion() {} +func (RefusalContentBlock) implMessageContentUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := MessageContentUnion.AsAny().(type) { +// case openai.ImageFileContentBlock: +// case openai.ImageURLContentBlock: +// case openai.TextContentBlock: +// case openai.RefusalContentBlock: +// default: +// fmt.Errorf("no variant present") +// } +func (u MessageContentUnion) AsAny() anyMessageContent { + switch u.Type { + case "image_file": + return u.AsImageFile() + case "image_url": + return u.AsImageURL() + case "text": + return u.AsText() + case "refusal": + return u.AsRefusal() + } + return nil +} + +func (u MessageContentUnion) AsImageFile() (v ImageFileContentBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageContentUnion) AsImageURL() (v ImageURLContentBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageContentUnion) AsText() (v TextContentBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageContentUnion) AsRefusal() (v RefusalContentBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u MessageContentUnion) RawJSON() string { return u.JSON.raw } + +func (r *MessageContentUnion) UnmarshalJSON(data []byte) error 
{ + return apijson.UnmarshalRoot(data, r) +} + +// MessageContentDeltaUnion contains all possible properties and values from +// [ImageFileDeltaBlock], [TextDeltaBlock], [RefusalDeltaBlock], +// [ImageURLDeltaBlock]. +// +// Use the [MessageContentDeltaUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type MessageContentDeltaUnion struct { + Index int64 `json:"index"` + // Any of "image_file", "text", "refusal", "image_url". + Type string `json:"type"` + // This field is from variant [ImageFileDeltaBlock]. + ImageFile ImageFileDelta `json:"image_file"` + // This field is from variant [TextDeltaBlock]. + Text TextDelta `json:"text"` + // This field is from variant [RefusalDeltaBlock]. + Refusal string `json:"refusal"` + // This field is from variant [ImageURLDeltaBlock]. + ImageURL ImageURLDelta `json:"image_url"` + JSON struct { + Index respjson.Field + Type respjson.Field + ImageFile respjson.Field + Text respjson.Field + Refusal respjson.Field + ImageURL respjson.Field + raw string + } `json:"-"` +} + +// anyMessageContentDelta is implemented by each variant of +// [MessageContentDeltaUnion] to add type safety for the return type of +// [MessageContentDeltaUnion.AsAny] +type anyMessageContentDelta interface { + implMessageContentDeltaUnion() +} + +func (ImageFileDeltaBlock) implMessageContentDeltaUnion() {} +func (TextDeltaBlock) implMessageContentDeltaUnion() {} +func (RefusalDeltaBlock) implMessageContentDeltaUnion() {} +func (ImageURLDeltaBlock) implMessageContentDeltaUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := MessageContentDeltaUnion.AsAny().(type) { +// case openai.ImageFileDeltaBlock: +// case openai.TextDeltaBlock: +// case openai.RefusalDeltaBlock: +// case openai.ImageURLDeltaBlock: +// default: +// fmt.Errorf("no variant present") +// } +func (u MessageContentDeltaUnion) AsAny() anyMessageContentDelta { + 
switch u.Type { + case "image_file": + return u.AsImageFile() + case "text": + return u.AsText() + case "refusal": + return u.AsRefusal() + case "image_url": + return u.AsImageURL() + } + return nil +} + +func (u MessageContentDeltaUnion) AsImageFile() (v ImageFileDeltaBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageContentDeltaUnion) AsText() (v TextDeltaBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageContentDeltaUnion) AsRefusal() (v RefusalDeltaBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MessageContentDeltaUnion) AsImageURL() (v ImageURLDeltaBlock) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u MessageContentDeltaUnion) RawJSON() string { return u.JSON.raw } + +func (r *MessageContentDeltaUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func MessageContentPartParamOfImageFile(imageFile ImageFileParam) MessageContentPartParamUnion { + var variant ImageFileContentBlockParam + variant.ImageFile = imageFile + return MessageContentPartParamUnion{OfImageFile: &variant} +} + +func MessageContentPartParamOfImageURL(imageURL ImageURLParam) MessageContentPartParamUnion { + var variant ImageURLContentBlockParam + variant.ImageURL = imageURL + return MessageContentPartParamUnion{OfImageURL: &variant} +} + +func MessageContentPartParamOfText(text string) MessageContentPartParamUnion { + var variant TextContentBlockParam + variant.Text = text + return MessageContentPartParamUnion{OfText: &variant} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type MessageContentPartParamUnion struct { + OfImageFile *ImageFileContentBlockParam `json:",omitzero,inline"` + OfImageURL *ImageURLContentBlockParam `json:",omitzero,inline"` + OfText *TextContentBlockParam `json:",omitzero,inline"` + paramUnion +} + +func (u MessageContentPartParamUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfImageFile, u.OfImageURL, u.OfText) +} +func (u *MessageContentPartParamUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *MessageContentPartParamUnion) asAny() any { + if !param.IsOmitted(u.OfImageFile) { + return u.OfImageFile + } else if !param.IsOmitted(u.OfImageURL) { + return u.OfImageURL + } else if !param.IsOmitted(u.OfText) { + return u.OfText + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MessageContentPartParamUnion) GetImageFile() *ImageFileParam { + if vt := u.OfImageFile; vt != nil { + return &vt.ImageFile + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MessageContentPartParamUnion) GetImageURL() *ImageURLParam { + if vt := u.OfImageURL; vt != nil { + return &vt.ImageURL + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MessageContentPartParamUnion) GetText() *string { + if vt := u.OfText; vt != nil { + return &vt.Text + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u MessageContentPartParamUnion) GetType() *string { + if vt := u.OfImageFile; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfImageURL; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfText; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[MessageContentPartParamUnion]( + "type", + apijson.Discriminator[ImageFileContentBlockParam]("image_file"), + apijson.Discriminator[ImageURLContentBlockParam]("image_url"), + apijson.Discriminator[TextContentBlockParam]("text"), + ) +} + +type MessageDeleted struct { + ID string `json:"id,required"` + Deleted bool `json:"deleted,required"` + Object constant.ThreadMessageDeleted `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Deleted respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageDeleted) RawJSON() string { return r.JSON.raw } +func (r *MessageDeleted) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The delta containing the fields that have changed on the Message. +type MessageDelta struct { + // The content of the message in array of text and/or images. + Content []MessageContentDeltaUnion `json:"content"` + // The entity that produced the message. One of `user` or `assistant`. + // + // Any of "user", "assistant". + Role MessageDeltaRole `json:"role"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Content respjson.Field + Role respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageDelta) RawJSON() string { return r.JSON.raw } +func (r *MessageDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The entity that produced the message. One of `user` or `assistant`. +type MessageDeltaRole string + +const ( + MessageDeltaRoleUser MessageDeltaRole = "user" + MessageDeltaRoleAssistant MessageDeltaRole = "assistant" +) + +// Represents a message delta i.e. any changed fields on a message during +// streaming. +type MessageDeltaEvent struct { + // The identifier of the message, which can be referenced in API endpoints. + ID string `json:"id,required"` + // The delta containing the fields that have changed on the Message. + Delta MessageDelta `json:"delta,required"` + // The object type, which is always `thread.message.delta`. + Object constant.ThreadMessageDelta `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Delta respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageDeltaEvent) RawJSON() string { return r.JSON.raw } +func (r *MessageDeltaEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The refusal content generated by the assistant. +type RefusalContentBlock struct { + Refusal string `json:"refusal,required"` + // Always `refusal`. + Type constant.Refusal `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Refusal respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RefusalContentBlock) RawJSON() string { return r.JSON.raw } +func (r *RefusalContentBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The refusal content that is part of a message. +type RefusalDeltaBlock struct { + // The index of the refusal part in the message. + Index int64 `json:"index,required"` + // Always `refusal`. + Type constant.Refusal `json:"type,required"` + Refusal string `json:"refusal"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + Refusal respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RefusalDeltaBlock) RawJSON() string { return r.JSON.raw } +func (r *RefusalDeltaBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type Text struct { + Annotations []AnnotationUnion `json:"annotations,required"` + // The data that makes up the text. + Value string `json:"value,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Annotations respjson.Field + Value respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Text) RawJSON() string { return r.JSON.raw } +func (r *Text) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The text content that is part of a message. +type TextContentBlock struct { + Text Text `json:"text,required"` + // Always `text`. + Type constant.Text `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TextContentBlock) RawJSON() string { return r.JSON.raw } +func (r *TextContentBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The text content that is part of a message. +// +// The properties Text, Type are required. +type TextContentBlockParam struct { + // Text content to be sent to the model + Text string `json:"text,required"` + // Always `text`. + // + // This field can be elided, and will marshal its zero value as "text". + Type constant.Text `json:"type,required"` + paramObj +} + +func (r TextContentBlockParam) MarshalJSON() (data []byte, err error) { + type shadow TextContentBlockParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *TextContentBlockParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type TextDelta struct { + Annotations []AnnotationDeltaUnion `json:"annotations"` + // The data that makes up the text. + Value string `json:"value"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Annotations respjson.Field + Value respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TextDelta) RawJSON() string { return r.JSON.raw } +func (r *TextDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The text content that is part of a message. +type TextDeltaBlock struct { + // The index of the content part in the message. + Index int64 `json:"index,required"` + // Always `text`. + Type constant.Text `json:"type,required"` + Text TextDelta `json:"text"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Index respjson.Field + Type respjson.Field + Text respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TextDeltaBlock) RawJSON() string { return r.JSON.raw } +func (r *TextDeltaBlock) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadMessageNewParams struct { + // The text contents of the message. + Content BetaThreadMessageNewParamsContentUnion `json:"content,omitzero,required"` + // The role of the entity that is creating the message. Allowed values include: + // + // - `user`: Indicates the message is sent by an actual user and should be used in + // most cases to represent user-generated messages. + // - `assistant`: Indicates the message is generated by the assistant. Use this + // value to insert messages from the assistant into the conversation. + // + // Any of "user", "assistant". + Role BetaThreadMessageNewParamsRole `json:"role,omitzero,required"` + // A list of files attached to the message, and the tools they should be added to. + Attachments []BetaThreadMessageNewParamsAttachment `json:"attachments,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + paramObj +} + +func (r BetaThreadMessageNewParams) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadMessageNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadMessageNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. 
+// +// Use [param.IsOmitted] to confirm if a field is set. +type BetaThreadMessageNewParamsContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []MessageContentPartParamUnion `json:",omitzero,inline"` + paramUnion +} + +func (u BetaThreadMessageNewParamsContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *BetaThreadMessageNewParamsContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaThreadMessageNewParamsContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +// The role of the entity that is creating the message. Allowed values include: +// +// - `user`: Indicates the message is sent by an actual user and should be used in +// most cases to represent user-generated messages. +// - `assistant`: Indicates the message is generated by the assistant. Use this +// value to insert messages from the assistant into the conversation. +type BetaThreadMessageNewParamsRole string + +const ( + BetaThreadMessageNewParamsRoleUser BetaThreadMessageNewParamsRole = "user" + BetaThreadMessageNewParamsRoleAssistant BetaThreadMessageNewParamsRole = "assistant" +) + +type BetaThreadMessageNewParamsAttachment struct { + // The ID of the file to attach to the message. + FileID param.Opt[string] `json:"file_id,omitzero"` + // The tools to add this file to. 
+ Tools []BetaThreadMessageNewParamsAttachmentToolUnion `json:"tools,omitzero"` + paramObj +} + +func (r BetaThreadMessageNewParamsAttachment) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadMessageNewParamsAttachment + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadMessageNewParamsAttachment) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type BetaThreadMessageNewParamsAttachmentToolUnion struct { + OfCodeInterpreter *CodeInterpreterToolParam `json:",omitzero,inline"` + OfFileSearch *BetaThreadMessageNewParamsAttachmentToolFileSearch `json:",omitzero,inline"` + paramUnion +} + +func (u BetaThreadMessageNewParamsAttachmentToolUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch) +} +func (u *BetaThreadMessageNewParamsAttachmentToolUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaThreadMessageNewParamsAttachmentToolUnion) asAny() any { + if !param.IsOmitted(u.OfCodeInterpreter) { + return u.OfCodeInterpreter + } else if !param.IsOmitted(u.OfFileSearch) { + return u.OfFileSearch + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u BetaThreadMessageNewParamsAttachmentToolUnion) GetType() *string { + if vt := u.OfCodeInterpreter; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFileSearch; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[BetaThreadMessageNewParamsAttachmentToolUnion]( + "type", + apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"), + apijson.Discriminator[BetaThreadMessageNewParamsAttachmentToolFileSearch]("file_search"), + ) +} + +func NewBetaThreadMessageNewParamsAttachmentToolFileSearch() BetaThreadMessageNewParamsAttachmentToolFileSearch { + return BetaThreadMessageNewParamsAttachmentToolFileSearch{ + Type: "file_search", + } +} + +// This struct has a constant value, construct it with +// [NewBetaThreadMessageNewParamsAttachmentToolFileSearch]. +type BetaThreadMessageNewParamsAttachmentToolFileSearch struct { + // The type of tool being defined: `file_search` + Type constant.FileSearch `json:"type,required"` + paramObj +} + +func (r BetaThreadMessageNewParamsAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadMessageNewParamsAttachmentToolFileSearch + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadMessageNewParamsAttachmentToolFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadMessageUpdateParams struct { + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. 
+ Metadata shared.Metadata `json:"metadata,omitzero"` + paramObj +} + +func (r BetaThreadMessageUpdateParams) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadMessageUpdateParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadMessageUpdateParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadMessageListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A cursor for use in pagination. `before` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // starting with obj_foo, your subsequent call can include before=obj_foo in order + // to fetch the previous page of the list. + Before param.Opt[string] `query:"before,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 100, and the default is 20. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Filter messages by the run ID that generated them. + RunID param.Opt[string] `query:"run_id,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order BetaThreadMessageListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [BetaThreadMessageListParams]'s query parameters as +// `url.Values`. 
+func (r BetaThreadMessageListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. +type BetaThreadMessageListParamsOrder string + +const ( + BetaThreadMessageListParamsOrderAsc BetaThreadMessageListParamsOrder = "asc" + BetaThreadMessageListParamsOrderDesc BetaThreadMessageListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/betathreadrun.go b/vendor/github.com/openai/openai-go/betathreadrun.go new file mode 100644 index 0000000000..7d7e11662b --- /dev/null +++ b/vendor/github.com/openai/openai-go/betathreadrun.go @@ -0,0 +1,960 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/packages/ssestream" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// BetaThreadRunService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBetaThreadRunService] method instead. 
+// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +type BetaThreadRunService struct { + Options []option.RequestOption + // Deprecated: The Assistants API is deprecated in favor of the Responses API + Steps BetaThreadRunStepService +} + +// NewBetaThreadRunService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewBetaThreadRunService(opts ...option.RequestOption) (r BetaThreadRunService) { + r = BetaThreadRunService{} + r.Options = opts + r.Steps = NewBetaThreadRunStepService(opts...) + return +} + +// Create a run. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) New(ctx context.Context, threadID string, params BetaThreadRunNewParams, opts ...option.RequestOption) (res *Run, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs", threadID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, params, &res, opts...) + return +} + +// Create a run. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) NewStreaming(ctx context.Context, threadID string, params BetaThreadRunNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[AssistantStreamEventUnion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithJSONSet("stream", true)}, opts...) 
+ if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs", threadID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, params, &raw, opts...) + return ssestream.NewStream[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err) +} + +// Retrieves a run. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) Get(ctx context.Context, threadID string, runID string, opts ...option.RequestOption) (res *Run, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s", threadID, runID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Modifies a run. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) Update(ctx context.Context, threadID string, runID string, body BetaThreadRunUpdateParams, opts ...option.RequestOption) (res *Run, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s", threadID, runID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Returns a list of runs belonging to a thread. 
+// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) List(ctx context.Context, threadID string, query BetaThreadRunListParams, opts ...option.RequestOption) (res *pagination.CursorPage[Run], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithResponseInto(&raw)}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs", threadID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Returns a list of runs belonging to a thread. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) ListAutoPaging(ctx context.Context, threadID string, query BetaThreadRunListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[Run] { + return pagination.NewCursorPageAutoPager(r.List(ctx, threadID, query, opts...)) +} + +// Cancels a run that is `in_progress`. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) Cancel(ctx context.Context, threadID string, runID string, opts ...option.RequestOption) (res *Run, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) 
+ if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s/cancel", threadID, runID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return +} + +// When a run has the `status: "requires_action"` and `required_action.type` is +// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the +// tool calls once they're all completed. All outputs must be submitted in a single +// request. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) SubmitToolOutputs(ctx context.Context, threadID string, runID string, body BetaThreadRunSubmitToolOutputsParams, opts ...option.RequestOption) (res *Run, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// When a run has the `status: "requires_action"` and `required_action.type` is +// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the +// tool calls once they're all completed. All outputs must be submitted in a single +// request. 
+// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunService) SubmitToolOutputsStreaming(ctx context.Context, threadID string, runID string, body BetaThreadRunSubmitToolOutputsParams, opts ...option.RequestOption) (stream *ssestream.Stream[AssistantStreamEventUnion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithJSONSet("stream", true)}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err) +} + +// Tool call objects +type RequiredActionFunctionToolCall struct { + // The ID of the tool call. This ID must be referenced when you submit the tool + // outputs in using the + // [Submit tool outputs to run](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs) + // endpoint. + ID string `json:"id,required"` + // The function definition. + Function RequiredActionFunctionToolCallFunction `json:"function,required"` + // The type of tool call the output is required for. For now, this is always + // `function`. + Type constant.Function `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + Function respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RequiredActionFunctionToolCall) RawJSON() string { return r.JSON.raw } +func (r *RequiredActionFunctionToolCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The function definition. +type RequiredActionFunctionToolCallFunction struct { + // The arguments that the model expects you to pass to the function. + Arguments string `json:"arguments,required"` + // The name of the function. + Name string `json:"name,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RequiredActionFunctionToolCallFunction) RawJSON() string { return r.JSON.raw } +func (r *RequiredActionFunctionToolCallFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Represents an execution run on a +// [thread](https://platform.openai.com/docs/api-reference/threads). +type Run struct { + // The identifier, which can be referenced in API endpoints. + ID string `json:"id,required"` + // The ID of the + // [assistant](https://platform.openai.com/docs/api-reference/assistants) used for + // execution of this run. + AssistantID string `json:"assistant_id,required"` + // The Unix timestamp (in seconds) for when the run was cancelled. + CancelledAt int64 `json:"cancelled_at,required"` + // The Unix timestamp (in seconds) for when the run was completed. + CompletedAt int64 `json:"completed_at,required"` + // The Unix timestamp (in seconds) for when the run was created. 
+ CreatedAt int64 `json:"created_at,required"` + // The Unix timestamp (in seconds) for when the run will expire. + ExpiresAt int64 `json:"expires_at,required"` + // The Unix timestamp (in seconds) for when the run failed. + FailedAt int64 `json:"failed_at,required"` + // Details on why the run is incomplete. Will be `null` if the run is not + // incomplete. + IncompleteDetails RunIncompleteDetails `json:"incomplete_details,required"` + // The instructions that the + // [assistant](https://platform.openai.com/docs/api-reference/assistants) used for + // this run. + Instructions string `json:"instructions,required"` + // The last error associated with this run. Will be `null` if there are no errors. + LastError RunLastError `json:"last_error,required"` + // The maximum number of completion tokens specified to have been used over the + // course of the run. + MaxCompletionTokens int64 `json:"max_completion_tokens,required"` + // The maximum number of prompt tokens specified to have been used over the course + // of the run. + MaxPromptTokens int64 `json:"max_prompt_tokens,required"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,required"` + // The model that the + // [assistant](https://platform.openai.com/docs/api-reference/assistants) used for + // this run. + Model string `json:"model,required"` + // The object type, which is always `thread.run`. + Object constant.ThreadRun `json:"object,required"` + // Whether to enable + // [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling) + // during tool use. 
+ ParallelToolCalls bool `json:"parallel_tool_calls,required"` + // Details on the action required to continue the run. Will be `null` if no action + // is required. + RequiredAction RunRequiredAction `json:"required_action,required"` + // Specifies the format that the model must output. Compatible with + // [GPT-4o](https://platform.openai.com/docs/models#gpt-4o), + // [GPT-4 Turbo](https://platform.openai.com/docs/models#gpt-4-turbo-and-gpt-4), + // and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + // + // Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured + // Outputs which ensures the model will match your supplied JSON schema. Learn more + // in the + // [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + // + // Setting to `{ "type": "json_object" }` enables JSON mode, which ensures the + // message the model generates is valid JSON. + // + // **Important:** when using JSON mode, you **must** also instruct the model to + // produce JSON yourself via a system or user message. Without this, the model may + // generate an unending stream of whitespace until the generation reaches the token + // limit, resulting in a long-running and seemingly "stuck" request. Also note that + // the message content may be partially cut off if `finish_reason="length"`, which + // indicates the generation exceeded `max_tokens` or the conversation exceeded the + // max context length. + ResponseFormat AssistantResponseFormatOptionUnion `json:"response_format,required"` + // The Unix timestamp (in seconds) for when the run was started. + StartedAt int64 `json:"started_at,required"` + // The status of the run, which can be either `queued`, `in_progress`, + // `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, + // `incomplete`, or `expired`. + // + // Any of "queued", "in_progress", "requires_action", "cancelling", "cancelled", + // "failed", "completed", "incomplete", "expired". 
+	Status RunStatus `json:"status,required"`
+	// The ID of the [thread](https://platform.openai.com/docs/api-reference/threads)
+	// that was executed on as a part of this run.
+	ThreadID string `json:"thread_id,required"`
+	// Controls which (if any) tool is called by the model. `none` means the model will
+	// not call any tools and instead generates a message. `auto` is the default value
+	// and means the model can pick between generating a message or calling one or more
+	// tools. `required` means the model must call one or more tools before responding
+	// to the user. Specifying a particular tool like `{"type": "file_search"}` or
+	// `{"type": "function", "function": {"name": "my_function"}}` forces the model to
+	// call that tool.
+	ToolChoice AssistantToolChoiceOptionUnion `json:"tool_choice,required"`
+	// The list of tools that the
+	// [assistant](https://platform.openai.com/docs/api-reference/assistants) used for
+	// this run.
+	Tools []AssistantToolUnion `json:"tools,required"`
+	// Controls for how a thread will be truncated prior to the run. Use this to
+	// control the initial context window of the run.
+	TruncationStrategy RunTruncationStrategy `json:"truncation_strategy,required"`
+	// Usage statistics related to the run. This value will be `null` if the run is not
+	// in a terminal state (i.e. `in_progress`, `queued`, etc.).
+	Usage RunUsage `json:"usage,required"`
+	// The sampling temperature used for this run. If not set, defaults to 1.
+	Temperature float64 `json:"temperature,nullable"`
+	// The nucleus sampling value used for this run. If not set, defaults to 1.
+	TopP float64 `json:"top_p,nullable"`
+	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct { + ID respjson.Field + AssistantID respjson.Field + CancelledAt respjson.Field + CompletedAt respjson.Field + CreatedAt respjson.Field + ExpiresAt respjson.Field + FailedAt respjson.Field + IncompleteDetails respjson.Field + Instructions respjson.Field + LastError respjson.Field + MaxCompletionTokens respjson.Field + MaxPromptTokens respjson.Field + Metadata respjson.Field + Model respjson.Field + Object respjson.Field + ParallelToolCalls respjson.Field + RequiredAction respjson.Field + ResponseFormat respjson.Field + StartedAt respjson.Field + Status respjson.Field + ThreadID respjson.Field + ToolChoice respjson.Field + Tools respjson.Field + TruncationStrategy respjson.Field + Usage respjson.Field + Temperature respjson.Field + TopP respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Run) RawJSON() string { return r.JSON.raw } +func (r *Run) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details on why the run is incomplete. Will be `null` if the run is not +// incomplete. +type RunIncompleteDetails struct { + // The reason why the run is incomplete. This will point to which specific token + // limit was reached over the course of the run. + // + // Any of "max_completion_tokens", "max_prompt_tokens". + Reason string `json:"reason"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Reason respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunIncompleteDetails) RawJSON() string { return r.JSON.raw } +func (r *RunIncompleteDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The last error associated with this run. Will be `null` if there are no errors. 
+type RunLastError struct { + // One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`. + // + // Any of "server_error", "rate_limit_exceeded", "invalid_prompt". + Code string `json:"code,required"` + // A human-readable description of the error. + Message string `json:"message,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Code respjson.Field + Message respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunLastError) RawJSON() string { return r.JSON.raw } +func (r *RunLastError) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details on the action required to continue the run. Will be `null` if no action +// is required. +type RunRequiredAction struct { + // Details on the tool outputs needed for this run to continue. + SubmitToolOutputs RunRequiredActionSubmitToolOutputs `json:"submit_tool_outputs,required"` + // For now, this is always `submit_tool_outputs`. + Type constant.SubmitToolOutputs `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + SubmitToolOutputs respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunRequiredAction) RawJSON() string { return r.JSON.raw } +func (r *RunRequiredAction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details on the tool outputs needed for this run to continue. +type RunRequiredActionSubmitToolOutputs struct { + // A list of the relevant tool calls. + ToolCalls []RequiredActionFunctionToolCall `json:"tool_calls,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+	JSON struct {
+		ToolCalls   respjson.Field
+		ExtraFields map[string]respjson.Field
+		raw         string
+	} `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r RunRequiredActionSubmitToolOutputs) RawJSON() string { return r.JSON.raw }
+func (r *RunRequiredActionSubmitToolOutputs) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// Controls for how a thread will be truncated prior to the run. Use this to
+// control the initial context window of the run.
+type RunTruncationStrategy struct {
+	// The truncation strategy to use for the thread. The default is `auto`. If set to
+	// `last_messages`, the thread will be truncated to the n most recent messages in
+	// the thread. When set to `auto`, messages in the middle of the thread will be
+	// dropped to fit the context length of the model, `max_prompt_tokens`.
+	//
+	// Any of "auto", "last_messages".
+	Type string `json:"type,required"`
+	// The number of most recent messages from the thread when constructing the context
+	// for the run.
+	LastMessages int64 `json:"last_messages,nullable"`
+	// JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+	JSON struct {
+		Type         respjson.Field
+		LastMessages respjson.Field
+		ExtraFields  map[string]respjson.Field
+		raw          string
+	} `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r RunTruncationStrategy) RawJSON() string { return r.JSON.raw }
+func (r *RunTruncationStrategy) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// Usage statistics related to the run. This value will be `null` if the run is not
+// in a terminal state (i.e. `in_progress`, `queued`, etc.).
+type RunUsage struct {
+	// Number of completion tokens used over the course of the run.
+	CompletionTokens int64 `json:"completion_tokens,required"`
+	// Number of prompt tokens used over the course of the run.
+ PromptTokens int64 `json:"prompt_tokens,required"` + // Total number of tokens used (prompt + completion). + TotalTokens int64 `json:"total_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + CompletionTokens respjson.Field + PromptTokens respjson.Field + TotalTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunUsage) RawJSON() string { return r.JSON.raw } +func (r *RunUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The status of the run, which can be either `queued`, `in_progress`, +// `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, +// `incomplete`, or `expired`. +type RunStatus string + +const ( + RunStatusQueued RunStatus = "queued" + RunStatusInProgress RunStatus = "in_progress" + RunStatusRequiresAction RunStatus = "requires_action" + RunStatusCancelling RunStatus = "cancelling" + RunStatusCancelled RunStatus = "cancelled" + RunStatusFailed RunStatus = "failed" + RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" + RunStatusExpired RunStatus = "expired" +) + +type BetaThreadRunNewParams struct { + // The ID of the + // [assistant](https://platform.openai.com/docs/api-reference/assistants) to use to + // execute this run. + AssistantID string `json:"assistant_id,required"` + // Appends additional instructions at the end of the instructions for the run. This + // is useful for modifying the behavior on a per-run basis without overriding other + // instructions. + AdditionalInstructions param.Opt[string] `json:"additional_instructions,omitzero"` + // Overrides the + // [instructions](https://platform.openai.com/docs/api-reference/assistants/createAssistant) + // of the assistant. This is useful for modifying the behavior on a per-run basis. 
+ Instructions param.Opt[string] `json:"instructions,omitzero"` + // The maximum number of completion tokens that may be used over the course of the + // run. The run will make a best effort to use only the number of completion tokens + // specified, across multiple turns of the run. If the run exceeds the number of + // completion tokens specified, the run will end with status `incomplete`. See + // `incomplete_details` for more info. + MaxCompletionTokens param.Opt[int64] `json:"max_completion_tokens,omitzero"` + // The maximum number of prompt tokens that may be used over the course of the run. + // The run will make a best effort to use only the number of prompt tokens + // specified, across multiple turns of the run. If the run exceeds the number of + // prompt tokens specified, the run will end with status `incomplete`. See + // `incomplete_details` for more info. + MaxPromptTokens param.Opt[int64] `json:"max_prompt_tokens,omitzero"` + // What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + // make the output more random, while lower values like 0.2 will make it more + // focused and deterministic. + Temperature param.Opt[float64] `json:"temperature,omitzero"` + // An alternative to sampling with temperature, called nucleus sampling, where the + // model considers the results of the tokens with top_p probability mass. So 0.1 + // means only the tokens comprising the top 10% probability mass are considered. + // + // We generally recommend altering this or temperature but not both. + TopP param.Opt[float64] `json:"top_p,omitzero"` + // Whether to enable + // [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling) + // during tool use. + ParallelToolCalls param.Opt[bool] `json:"parallel_tool_calls,omitzero"` + // Adds additional messages to the thread before creating the run. 
+ AdditionalMessages []BetaThreadRunNewParamsAdditionalMessage `json:"additional_messages,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + // The ID of the [Model](https://platform.openai.com/docs/api-reference/models) to + // be used to execute this run. If a value is provided here, it will override the + // model associated with the assistant. If not, the model associated with the + // assistant will be used. + Model shared.ChatModel `json:"model,omitzero"` + // **o-series models only** + // + // Constrains effort on reasoning for + // [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently + // supported values are `low`, `medium`, and `high`. Reducing reasoning effort can + // result in faster responses and fewer tokens used on reasoning in a response. + // + // Any of "low", "medium", "high". + ReasoningEffort shared.ReasoningEffort `json:"reasoning_effort,omitzero"` + // Override the tools the assistant can use for this run. This is useful for + // modifying the behavior on a per-run basis. + Tools []AssistantToolUnionParam `json:"tools,omitzero"` + // Controls for how a thread will be truncated prior to the run. Use this to + // control the intial context window of the run. + TruncationStrategy BetaThreadRunNewParamsTruncationStrategy `json:"truncation_strategy,omitzero"` + // A list of additional fields to include in the response. Currently the only + // supported value is `step_details.tool_calls[*].file_search.results[*].content` + // to fetch the file search result content. 
+ // + // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + Include []RunStepInclude `query:"include,omitzero" json:"-"` + // Specifies the format that the model must output. Compatible with + // [GPT-4o](https://platform.openai.com/docs/models#gpt-4o), + // [GPT-4 Turbo](https://platform.openai.com/docs/models#gpt-4-turbo-and-gpt-4), + // and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + // + // Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured + // Outputs which ensures the model will match your supplied JSON schema. Learn more + // in the + // [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + // + // Setting to `{ "type": "json_object" }` enables JSON mode, which ensures the + // message the model generates is valid JSON. + // + // **Important:** when using JSON mode, you **must** also instruct the model to + // produce JSON yourself via a system or user message. Without this, the model may + // generate an unending stream of whitespace until the generation reaches the token + // limit, resulting in a long-running and seemingly "stuck" request. Also note that + // the message content may be partially cut off if `finish_reason="length"`, which + // indicates the generation exceeded `max_tokens` or the conversation exceeded the + // max context length. + ResponseFormat AssistantResponseFormatOptionUnionParam `json:"response_format,omitzero"` + // Controls which (if any) tool is called by the model. `none` means the model will + // not call any tools and instead generates a message. `auto` is the default value + // and means the model can pick between generating a message or calling one or more + // tools. `required` means the model must call one or more tools before responding + // to the user. 
Specifying a particular tool like `{"type": "file_search"}` or + // `{"type": "function", "function": {"name": "my_function"}}` forces the model to + // call that tool. + ToolChoice AssistantToolChoiceOptionUnionParam `json:"tool_choice,omitzero"` + paramObj +} + +func (r BetaThreadRunNewParams) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// URLQuery serializes [BetaThreadRunNewParams]'s query parameters as `url.Values`. +func (r BetaThreadRunNewParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// The properties Content, Role are required. +type BetaThreadRunNewParamsAdditionalMessage struct { + // The text contents of the message. + Content BetaThreadRunNewParamsAdditionalMessageContentUnion `json:"content,omitzero,required"` + // The role of the entity that is creating the message. Allowed values include: + // + // - `user`: Indicates the message is sent by an actual user and should be used in + // most cases to represent user-generated messages. + // - `assistant`: Indicates the message is generated by the assistant. Use this + // value to insert messages from the assistant into the conversation. + // + // Any of "user", "assistant". + Role string `json:"role,omitzero,required"` + // A list of files attached to the message, and the tools they should be added to. + Attachments []BetaThreadRunNewParamsAdditionalMessageAttachment `json:"attachments,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. 
+ // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + paramObj +} + +func (r BetaThreadRunNewParamsAdditionalMessage) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunNewParamsAdditionalMessage + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunNewParamsAdditionalMessage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[BetaThreadRunNewParamsAdditionalMessage]( + "role", "user", "assistant", + ) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type BetaThreadRunNewParamsAdditionalMessageContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []MessageContentPartParamUnion `json:",omitzero,inline"` + paramUnion +} + +func (u BetaThreadRunNewParamsAdditionalMessageContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *BetaThreadRunNewParamsAdditionalMessageContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaThreadRunNewParamsAdditionalMessageContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +type BetaThreadRunNewParamsAdditionalMessageAttachment struct { + // The ID of the file to attach to the message. + FileID param.Opt[string] `json:"file_id,omitzero"` + // The tools to add this file to. 
+ Tools []BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion `json:"tools,omitzero"` + paramObj +} + +func (r BetaThreadRunNewParamsAdditionalMessageAttachment) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunNewParamsAdditionalMessageAttachment + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunNewParamsAdditionalMessageAttachment) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion struct { + OfCodeInterpreter *CodeInterpreterToolParam `json:",omitzero,inline"` + OfFileSearch *BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch `json:",omitzero,inline"` + paramUnion +} + +func (u BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch) +} +func (u *BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) asAny() any { + if !param.IsOmitted(u.OfCodeInterpreter) { + return u.OfCodeInterpreter + } else if !param.IsOmitted(u.OfFileSearch) { + return u.OfFileSearch + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) GetType() *string { + if vt := u.OfCodeInterpreter; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFileSearch; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion]( + "type", + apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"), + apijson.Discriminator[BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch]("file_search"), + ) +} + +func NewBetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch() BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch { + return BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch{ + Type: "file_search", + } +} + +// This struct has a constant value, construct it with +// [NewBetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch]. +type BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch struct { + // The type of tool being defined: `file_search` + Type constant.FileSearch `json:"type,required"` + paramObj +} + +func (r BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Controls for how a thread will be truncated prior to the run. Use this to +// control the intial context window of the run. +// +// The property Type is required. +type BetaThreadRunNewParamsTruncationStrategy struct { + // The truncation strategy to use for the thread. The default is `auto`. If set to + // `last_messages`, the thread will be truncated to the n most recent messages in + // the thread. 
When set to `auto`, messages in the middle of the thread will be + // dropped to fit the context length of the model, `max_prompt_tokens`. + // + // Any of "auto", "last_messages". + Type string `json:"type,omitzero,required"` + // The number of most recent messages from the thread when constructing the context + // for the run. + LastMessages param.Opt[int64] `json:"last_messages,omitzero"` + paramObj +} + +func (r BetaThreadRunNewParamsTruncationStrategy) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunNewParamsTruncationStrategy + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunNewParamsTruncationStrategy) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[BetaThreadRunNewParamsTruncationStrategy]( + "type", "auto", "last_messages", + ) +} + +type BetaThreadRunUpdateParams struct { + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + paramObj +} + +func (r BetaThreadRunUpdateParams) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunUpdateParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunUpdateParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadRunListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. 
+ After param.Opt[string] `query:"after,omitzero" json:"-"` + // A cursor for use in pagination. `before` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // starting with obj_foo, your subsequent call can include before=obj_foo in order + // to fetch the previous page of the list. + Before param.Opt[string] `query:"before,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 100, and the default is 20. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order BetaThreadRunListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [BetaThreadRunListParams]'s query parameters as +// `url.Values`. +func (r BetaThreadRunListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. +type BetaThreadRunListParamsOrder string + +const ( + BetaThreadRunListParamsOrderAsc BetaThreadRunListParamsOrder = "asc" + BetaThreadRunListParamsOrderDesc BetaThreadRunListParamsOrder = "desc" +) + +type BetaThreadRunSubmitToolOutputsParams struct { + // A list of tools for which the outputs are being submitted. 
+ ToolOutputs []BetaThreadRunSubmitToolOutputsParamsToolOutput `json:"tool_outputs,omitzero,required"` + paramObj +} + +func (r BetaThreadRunSubmitToolOutputsParams) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunSubmitToolOutputsParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunSubmitToolOutputsParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadRunSubmitToolOutputsParamsToolOutput struct { + // The output of the tool call to be submitted to continue the run. + Output param.Opt[string] `json:"output,omitzero"` + // The ID of the tool call in the `required_action` object within the run object + // the output is being submitted for. + ToolCallID param.Opt[string] `json:"tool_call_id,omitzero"` + paramObj +} + +func (r BetaThreadRunSubmitToolOutputsParamsToolOutput) MarshalJSON() (data []byte, err error) { + type shadow BetaThreadRunSubmitToolOutputsParamsToolOutput + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *BetaThreadRunSubmitToolOutputsParamsToolOutput) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} diff --git a/vendor/github.com/openai/openai-go/betathreadrunstep.go b/vendor/github.com/openai/openai-go/betathreadrunstep.go new file mode 100644 index 0000000000..1ae783e574 --- /dev/null +++ b/vendor/github.com/openai/openai-go/betathreadrunstep.go @@ -0,0 +1,1393 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// BetaThreadRunStepService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewBetaThreadRunStepService] method instead. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +type BetaThreadRunStepService struct { + Options []option.RequestOption +} + +// NewBetaThreadRunStepService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewBetaThreadRunStepService(opts ...option.RequestOption) (r BetaThreadRunStepService) { + r = BetaThreadRunStepService{} + r.Options = opts + return +} + +// Retrieves a run step. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunStepService) Get(ctx context.Context, threadID string, runID string, stepID string, query BetaThreadRunStepGetParams, opts ...option.RequestOption) (res *RunStep, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...) 
+ if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + if stepID == "" { + err = errors.New("missing required step_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s/steps/%s", threadID, runID, stepID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) + return +} + +// Returns a list of run steps belonging to a run. +// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunStepService) List(ctx context.Context, threadID string, runID string, query BetaThreadRunStepListParams, opts ...option.RequestOption) (res *pagination.CursorPage[RunStep], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2"), option.WithResponseInto(&raw)}, opts...) + if threadID == "" { + err = errors.New("missing required thread_id parameter") + return + } + if runID == "" { + err = errors.New("missing required run_id parameter") + return + } + path := fmt.Sprintf("threads/%s/runs/%s/steps", threadID, runID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Returns a list of run steps belonging to a run. 
+// +// Deprecated: The Assistants API is deprecated in favor of the Responses API +func (r *BetaThreadRunStepService) ListAutoPaging(ctx context.Context, threadID string, runID string, query BetaThreadRunStepListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[RunStep] { + return pagination.NewCursorPageAutoPager(r.List(ctx, threadID, runID, query, opts...)) +} + +// Text output from the Code Interpreter tool call as part of a run step. +type CodeInterpreterLogs struct { + // The index of the output in the outputs array. + Index int64 `json:"index,required"` + // Always `logs`. + Type constant.Logs `json:"type,required"` + // The text output from the Code Interpreter tool call. + Logs string `json:"logs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + Logs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterLogs) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterLogs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type CodeInterpreterOutputImage struct { + // The index of the output in the outputs array. + Index int64 `json:"index,required"` + // Always `image`. + Type constant.Image `json:"type,required"` + Image CodeInterpreterOutputImageImage `json:"image"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Index respjson.Field + Type respjson.Field + Image respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterOutputImage) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterOutputImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type CodeInterpreterOutputImageImage struct { + // The [file](https://platform.openai.com/docs/api-reference/files) ID of the + // image. + FileID string `json:"file_id"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterOutputImageImage) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterOutputImageImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details of the Code Interpreter tool call the run step was involved in. +type CodeInterpreterToolCall struct { + // The ID of the tool call. + ID string `json:"id,required"` + // The Code Interpreter tool call definition. + CodeInterpreter CodeInterpreterToolCallCodeInterpreter `json:"code_interpreter,required"` + // The type of tool call. This is always going to be `code_interpreter` for this + // type of tool call. + Type constant.CodeInterpreter `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + CodeInterpreter respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCall) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The Code Interpreter tool call definition. +type CodeInterpreterToolCallCodeInterpreter struct { + // The input to the Code Interpreter tool call. + Input string `json:"input,required"` + // The outputs from the Code Interpreter tool call. Code Interpreter can output one + // or more items, including text (`logs`) or images (`image`). Each of these are + // represented by a different object type. + Outputs []CodeInterpreterToolCallCodeInterpreterOutputUnion `json:"outputs,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Input respjson.Field + Outputs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCallCodeInterpreter) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCallCodeInterpreter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// CodeInterpreterToolCallCodeInterpreterOutputUnion contains all possible +// properties and values from [CodeInterpreterToolCallCodeInterpreterOutputLogs], +// [CodeInterpreterToolCallCodeInterpreterOutputImage]. +// +// Use the [CodeInterpreterToolCallCodeInterpreterOutputUnion.AsAny] method to +// switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type CodeInterpreterToolCallCodeInterpreterOutputUnion struct { + // This field is from variant [CodeInterpreterToolCallCodeInterpreterOutputLogs]. 
+ Logs string `json:"logs"` + // Any of "logs", "image". + Type string `json:"type"` + // This field is from variant [CodeInterpreterToolCallCodeInterpreterOutputImage]. + Image CodeInterpreterToolCallCodeInterpreterOutputImageImage `json:"image"` + JSON struct { + Logs respjson.Field + Type respjson.Field + Image respjson.Field + raw string + } `json:"-"` +} + +// anyCodeInterpreterToolCallCodeInterpreterOutput is implemented by each variant +// of [CodeInterpreterToolCallCodeInterpreterOutputUnion] to add type safety for +// the return type of [CodeInterpreterToolCallCodeInterpreterOutputUnion.AsAny] +type anyCodeInterpreterToolCallCodeInterpreterOutput interface { + implCodeInterpreterToolCallCodeInterpreterOutputUnion() +} + +func (CodeInterpreterToolCallCodeInterpreterOutputLogs) implCodeInterpreterToolCallCodeInterpreterOutputUnion() { +} +func (CodeInterpreterToolCallCodeInterpreterOutputImage) implCodeInterpreterToolCallCodeInterpreterOutputUnion() { +} + +// Use the following switch statement to find the correct variant +// +// switch variant := CodeInterpreterToolCallCodeInterpreterOutputUnion.AsAny().(type) { +// case openai.CodeInterpreterToolCallCodeInterpreterOutputLogs: +// case openai.CodeInterpreterToolCallCodeInterpreterOutputImage: +// default: +// fmt.Errorf("no variant present") +// } +func (u CodeInterpreterToolCallCodeInterpreterOutputUnion) AsAny() anyCodeInterpreterToolCallCodeInterpreterOutput { + switch u.Type { + case "logs": + return u.AsLogs() + case "image": + return u.AsImage() + } + return nil +} + +func (u CodeInterpreterToolCallCodeInterpreterOutputUnion) AsLogs() (v CodeInterpreterToolCallCodeInterpreterOutputLogs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u CodeInterpreterToolCallCodeInterpreterOutputUnion) AsImage() (v CodeInterpreterToolCallCodeInterpreterOutputImage) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the 
API +func (u CodeInterpreterToolCallCodeInterpreterOutputUnion) RawJSON() string { return u.JSON.raw } + +func (r *CodeInterpreterToolCallCodeInterpreterOutputUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Text output from the Code Interpreter tool call as part of a run step. +type CodeInterpreterToolCallCodeInterpreterOutputLogs struct { + // The text output from the Code Interpreter tool call. + Logs string `json:"logs,required"` + // Always `logs`. + Type constant.Logs `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Logs respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCallCodeInterpreterOutputLogs) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCallCodeInterpreterOutputLogs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type CodeInterpreterToolCallCodeInterpreterOutputImage struct { + Image CodeInterpreterToolCallCodeInterpreterOutputImageImage `json:"image,required"` + // Always `image`. + Type constant.Image `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Image respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCallCodeInterpreterOutputImage) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCallCodeInterpreterOutputImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type CodeInterpreterToolCallCodeInterpreterOutputImageImage struct { + // The [file](https://platform.openai.com/docs/api-reference/files) ID of the + // image. 
+ FileID string `json:"file_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCallCodeInterpreterOutputImageImage) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCallCodeInterpreterOutputImageImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details of the Code Interpreter tool call the run step was involved in. +type CodeInterpreterToolCallDelta struct { + // The index of the tool call in the tool calls array. + Index int64 `json:"index,required"` + // The type of tool call. This is always going to be `code_interpreter` for this + // type of tool call. + Type constant.CodeInterpreter `json:"type,required"` + // The ID of the tool call. + ID string `json:"id"` + // The Code Interpreter tool call definition. + CodeInterpreter CodeInterpreterToolCallDeltaCodeInterpreter `json:"code_interpreter"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + ID respjson.Field + CodeInterpreter respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCallDelta) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCallDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The Code Interpreter tool call definition. +type CodeInterpreterToolCallDeltaCodeInterpreter struct { + // The input to the Code Interpreter tool call. + Input string `json:"input"` + // The outputs from the Code Interpreter tool call. Code Interpreter can output one + // or more items, including text (`logs`) or images (`image`). 
Each of these are + // represented by a different object type. + Outputs []CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion `json:"outputs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Input respjson.Field + Outputs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CodeInterpreterToolCallDeltaCodeInterpreter) RawJSON() string { return r.JSON.raw } +func (r *CodeInterpreterToolCallDeltaCodeInterpreter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion contains all possible +// properties and values from [CodeInterpreterLogs], [CodeInterpreterOutputImage]. +// +// Use the [CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion.AsAny] method to +// switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion struct { + Index int64 `json:"index"` + // Any of "logs", "image". + Type string `json:"type"` + // This field is from variant [CodeInterpreterLogs]. + Logs string `json:"logs"` + // This field is from variant [CodeInterpreterOutputImage]. 
+ Image CodeInterpreterOutputImageImage `json:"image"` + JSON struct { + Index respjson.Field + Type respjson.Field + Logs respjson.Field + Image respjson.Field + raw string + } `json:"-"` +} + +// anyCodeInterpreterToolCallDeltaCodeInterpreterOutput is implemented by each +// variant of [CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion] to add type +// safety for the return type of +// [CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion.AsAny] +type anyCodeInterpreterToolCallDeltaCodeInterpreterOutput interface { + implCodeInterpreterToolCallDeltaCodeInterpreterOutputUnion() +} + +func (CodeInterpreterLogs) implCodeInterpreterToolCallDeltaCodeInterpreterOutputUnion() {} +func (CodeInterpreterOutputImage) implCodeInterpreterToolCallDeltaCodeInterpreterOutputUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion.AsAny().(type) { +// case openai.CodeInterpreterLogs: +// case openai.CodeInterpreterOutputImage: +// default: +// fmt.Errorf("no variant present") +// } +func (u CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) AsAny() anyCodeInterpreterToolCallDeltaCodeInterpreterOutput { + switch u.Type { + case "logs": + return u.AsLogs() + case "image": + return u.AsImage() + } + return nil +} + +func (u CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) AsLogs() (v CodeInterpreterLogs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) AsImage() (v CodeInterpreterOutputImage) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) RawJSON() string { return u.JSON.raw } + +func (r *CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type 
FileSearchToolCall struct { + // The ID of the tool call object. + ID string `json:"id,required"` + // For now, this is always going to be an empty object. + FileSearch FileSearchToolCallFileSearch `json:"file_search,required"` + // The type of tool call. This is always going to be `file_search` for this type of + // tool call. + Type constant.FileSearch `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + FileSearch respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolCall) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// For now, this is always going to be an empty object. +type FileSearchToolCallFileSearch struct { + // The ranking options for the file search. + RankingOptions FileSearchToolCallFileSearchRankingOptions `json:"ranking_options"` + // The results of the file search. + Results []FileSearchToolCallFileSearchResult `json:"results"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + RankingOptions respjson.Field + Results respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolCallFileSearch) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolCallFileSearch) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The ranking options for the file search. +type FileSearchToolCallFileSearchRankingOptions struct { + // The ranker to use for the file search. If not specified will use the `auto` + // ranker. + // + // Any of "auto", "default_2024_08_21". 
+ Ranker string `json:"ranker,required"` + // The score threshold for the file search. All values must be a floating point + // number between 0 and 1. + ScoreThreshold float64 `json:"score_threshold,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Ranker respjson.Field + ScoreThreshold respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolCallFileSearchRankingOptions) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolCallFileSearchRankingOptions) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A result instance of the file search. +type FileSearchToolCallFileSearchResult struct { + // The ID of the file that result was found in. + FileID string `json:"file_id,required"` + // The name of the file that result was found in. + FileName string `json:"file_name,required"` + // The score of the result. All values must be a floating point number between 0 + // and 1. + Score float64 `json:"score,required"` + // The content of the result that was found. The content is only included if + // requested via the include query parameter. + Content []FileSearchToolCallFileSearchResultContent `json:"content"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileID respjson.Field + FileName respjson.Field + Score respjson.Field + Content respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolCallFileSearchResult) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolCallFileSearchResult) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FileSearchToolCallFileSearchResultContent struct { + // The text content of the file. 
+ Text string `json:"text"` + // The type of the content. + // + // Any of "text". + Type string `json:"type"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolCallFileSearchResultContent) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolCallFileSearchResultContent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FileSearchToolCallDelta struct { + // For now, this is always going to be an empty object. + FileSearch any `json:"file_search,required"` + // The index of the tool call in the tool calls array. + Index int64 `json:"index,required"` + // The type of tool call. This is always going to be `file_search` for this type of + // tool call. + Type constant.FileSearch `json:"type,required"` + // The ID of the tool call object. + ID string `json:"id"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FileSearch respjson.Field + Index respjson.Field + Type respjson.Field + ID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileSearchToolCallDelta) RawJSON() string { return r.JSON.raw } +func (r *FileSearchToolCallDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FunctionToolCall struct { + // The ID of the tool call object. + ID string `json:"id,required"` + // The definition of the function that was called. + Function FunctionToolCallFunction `json:"function,required"` + // The type of tool call. This is always going to be `function` for this type of + // tool call. 
+ Type constant.Function `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Function respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FunctionToolCall) RawJSON() string { return r.JSON.raw } +func (r *FunctionToolCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The definition of the function that was called. +type FunctionToolCallFunction struct { + // The arguments passed to the function. + Arguments string `json:"arguments,required"` + // The name of the function. + Name string `json:"name,required"` + // The output of the function. This will be `null` if the outputs have not been + // [submitted](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs) + // yet. + Output string `json:"output,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + Output respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FunctionToolCallFunction) RawJSON() string { return r.JSON.raw } +func (r *FunctionToolCallFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FunctionToolCallDelta struct { + // The index of the tool call in the tool calls array. + Index int64 `json:"index,required"` + // The type of tool call. This is always going to be `function` for this type of + // tool call. + Type constant.Function `json:"type,required"` + // The ID of the tool call object. + ID string `json:"id"` + // The definition of the function that was called. 
+ Function FunctionToolCallDeltaFunction `json:"function"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Index respjson.Field + Type respjson.Field + ID respjson.Field + Function respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FunctionToolCallDelta) RawJSON() string { return r.JSON.raw } +func (r *FunctionToolCallDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The definition of the function that was called. +type FunctionToolCallDeltaFunction struct { + // The arguments passed to the function. + Arguments string `json:"arguments"` + // The name of the function. + Name string `json:"name"` + // The output of the function. This will be `null` if the outputs have not been + // [submitted](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs) + // yet. + Output string `json:"output,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + Output respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FunctionToolCallDeltaFunction) RawJSON() string { return r.JSON.raw } +func (r *FunctionToolCallDeltaFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details of the message creation by the run step. +type MessageCreationStepDetails struct { + MessageCreation MessageCreationStepDetailsMessageCreation `json:"message_creation,required"` + // Always `message_creation`. + Type constant.MessageCreation `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + MessageCreation respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageCreationStepDetails) RawJSON() string { return r.JSON.raw } +func (r *MessageCreationStepDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type MessageCreationStepDetailsMessageCreation struct { + // The ID of the message that was created by this run step. + MessageID string `json:"message_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MessageID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MessageCreationStepDetailsMessageCreation) RawJSON() string { return r.JSON.raw } +func (r *MessageCreationStepDetailsMessageCreation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Represents a step in execution of a run. +type RunStep struct { + // The identifier of the run step, which can be referenced in API endpoints. + ID string `json:"id,required"` + // The ID of the + // [assistant](https://platform.openai.com/docs/api-reference/assistants) + // associated with the run step. + AssistantID string `json:"assistant_id,required"` + // The Unix timestamp (in seconds) for when the run step was cancelled. + CancelledAt int64 `json:"cancelled_at,required"` + // The Unix timestamp (in seconds) for when the run step completed. + CompletedAt int64 `json:"completed_at,required"` + // The Unix timestamp (in seconds) for when the run step was created. + CreatedAt int64 `json:"created_at,required"` + // The Unix timestamp (in seconds) for when the run step expired. A step is + // considered expired if the parent run is expired. 
+ ExpiredAt int64 `json:"expired_at,required"` + // The Unix timestamp (in seconds) for when the run step failed. + FailedAt int64 `json:"failed_at,required"` + // The last error associated with this run step. Will be `null` if there are no + // errors. + LastError RunStepLastError `json:"last_error,required"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,required"` + // The object type, which is always `thread.run.step`. + Object constant.ThreadRunStep `json:"object,required"` + // The ID of the [run](https://platform.openai.com/docs/api-reference/runs) that + // this run step is a part of. + RunID string `json:"run_id,required"` + // The status of the run step, which can be either `in_progress`, `cancelled`, + // `failed`, `completed`, or `expired`. + // + // Any of "in_progress", "cancelled", "failed", "completed", "expired". + Status RunStepStatus `json:"status,required"` + // The details of the run step. + StepDetails RunStepStepDetailsUnion `json:"step_details,required"` + // The ID of the [thread](https://platform.openai.com/docs/api-reference/threads) + // that was run. + ThreadID string `json:"thread_id,required"` + // The type of run step, which can be either `message_creation` or `tool_calls`. + // + // Any of "message_creation", "tool_calls". + Type RunStepType `json:"type,required"` + // Usage statistics related to the run step. This value will be `null` while the + // run step's status is `in_progress`. + Usage RunStepUsage `json:"usage,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + AssistantID respjson.Field + CancelledAt respjson.Field + CompletedAt respjson.Field + CreatedAt respjson.Field + ExpiredAt respjson.Field + FailedAt respjson.Field + LastError respjson.Field + Metadata respjson.Field + Object respjson.Field + RunID respjson.Field + Status respjson.Field + StepDetails respjson.Field + ThreadID respjson.Field + Type respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStep) RawJSON() string { return r.JSON.raw } +func (r *RunStep) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The last error associated with this run step. Will be `null` if there are no +// errors. +type RunStepLastError struct { + // One of `server_error` or `rate_limit_exceeded`. + // + // Any of "server_error", "rate_limit_exceeded". + Code string `json:"code,required"` + // A human-readable description of the error. + Message string `json:"message,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Code respjson.Field + Message respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStepLastError) RawJSON() string { return r.JSON.raw } +func (r *RunStepLastError) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The status of the run step, which can be either `in_progress`, `cancelled`, +// `failed`, `completed`, or `expired`. 
+type RunStepStatus string + +const ( + RunStepStatusInProgress RunStepStatus = "in_progress" + RunStepStatusCancelled RunStepStatus = "cancelled" + RunStepStatusFailed RunStepStatus = "failed" + RunStepStatusCompleted RunStepStatus = "completed" + RunStepStatusExpired RunStepStatus = "expired" +) + +// RunStepStepDetailsUnion contains all possible properties and values from +// [MessageCreationStepDetails], [ToolCallsStepDetails]. +// +// Use the [RunStepStepDetailsUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type RunStepStepDetailsUnion struct { + // This field is from variant [MessageCreationStepDetails]. + MessageCreation MessageCreationStepDetailsMessageCreation `json:"message_creation"` + // Any of "message_creation", "tool_calls". + Type string `json:"type"` + // This field is from variant [ToolCallsStepDetails]. + ToolCalls []ToolCallUnion `json:"tool_calls"` + JSON struct { + MessageCreation respjson.Field + Type respjson.Field + ToolCalls respjson.Field + raw string + } `json:"-"` +} + +// anyRunStepStepDetails is implemented by each variant of +// [RunStepStepDetailsUnion] to add type safety for the return type of +// [RunStepStepDetailsUnion.AsAny] +type anyRunStepStepDetails interface { + implRunStepStepDetailsUnion() +} + +func (MessageCreationStepDetails) implRunStepStepDetailsUnion() {} +func (ToolCallsStepDetails) implRunStepStepDetailsUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := RunStepStepDetailsUnion.AsAny().(type) { +// case openai.MessageCreationStepDetails: +// case openai.ToolCallsStepDetails: +// default: +// fmt.Errorf("no variant present") +// } +func (u RunStepStepDetailsUnion) AsAny() anyRunStepStepDetails { + switch u.Type { + case "message_creation": + return u.AsMessageCreation() + case "tool_calls": + return u.AsToolCalls() + } + return nil +} + +func (u RunStepStepDetailsUnion) 
AsMessageCreation() (v MessageCreationStepDetails) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u RunStepStepDetailsUnion) AsToolCalls() (v ToolCallsStepDetails) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u RunStepStepDetailsUnion) RawJSON() string { return u.JSON.raw } + +func (r *RunStepStepDetailsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The type of run step, which can be either `message_creation` or `tool_calls`. +type RunStepType string + +const ( + RunStepTypeMessageCreation RunStepType = "message_creation" + RunStepTypeToolCalls RunStepType = "tool_calls" +) + +// Usage statistics related to the run step. This value will be `null` while the +// run step's status is `in_progress`. +type RunStepUsage struct { + // Number of completion tokens used over the course of the run step. + CompletionTokens int64 `json:"completion_tokens,required"` + // Number of prompt tokens used over the course of the run step. + PromptTokens int64 `json:"prompt_tokens,required"` + // Total number of tokens used (prompt + completion). + TotalTokens int64 `json:"total_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + CompletionTokens respjson.Field + PromptTokens respjson.Field + TotalTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStepUsage) RawJSON() string { return r.JSON.raw } +func (r *RunStepUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The delta containing the fields that have changed on the run step. +type RunStepDelta struct { + // The details of the run step. 
+ StepDetails RunStepDeltaStepDetailsUnion `json:"step_details"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + StepDetails respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStepDelta) RawJSON() string { return r.JSON.raw } +func (r *RunStepDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// RunStepDeltaStepDetailsUnion contains all possible properties and values from +// [RunStepDeltaMessageDelta], [ToolCallDeltaObject]. +// +// Use the [RunStepDeltaStepDetailsUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type RunStepDeltaStepDetailsUnion struct { + // Any of "message_creation", "tool_calls". + Type string `json:"type"` + // This field is from variant [RunStepDeltaMessageDelta]. + MessageCreation RunStepDeltaMessageDeltaMessageCreation `json:"message_creation"` + // This field is from variant [ToolCallDeltaObject]. 
+ ToolCalls []ToolCallDeltaUnion `json:"tool_calls"` + JSON struct { + Type respjson.Field + MessageCreation respjson.Field + ToolCalls respjson.Field + raw string + } `json:"-"` +} + +// anyRunStepDeltaStepDetails is implemented by each variant of +// [RunStepDeltaStepDetailsUnion] to add type safety for the return type of +// [RunStepDeltaStepDetailsUnion.AsAny] +type anyRunStepDeltaStepDetails interface { + implRunStepDeltaStepDetailsUnion() +} + +func (RunStepDeltaMessageDelta) implRunStepDeltaStepDetailsUnion() {} +func (ToolCallDeltaObject) implRunStepDeltaStepDetailsUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := RunStepDeltaStepDetailsUnion.AsAny().(type) { +// case openai.RunStepDeltaMessageDelta: +// case openai.ToolCallDeltaObject: +// default: +// fmt.Errorf("no variant present") +// } +func (u RunStepDeltaStepDetailsUnion) AsAny() anyRunStepDeltaStepDetails { + switch u.Type { + case "message_creation": + return u.AsMessageCreation() + case "tool_calls": + return u.AsToolCalls() + } + return nil +} + +func (u RunStepDeltaStepDetailsUnion) AsMessageCreation() (v RunStepDeltaMessageDelta) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u RunStepDeltaStepDetailsUnion) AsToolCalls() (v ToolCallDeltaObject) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u RunStepDeltaStepDetailsUnion) RawJSON() string { return u.JSON.raw } + +func (r *RunStepDeltaStepDetailsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Represents a run step delta i.e. any changed fields on a run step during +// streaming. +type RunStepDeltaEvent struct { + // The identifier of the run step, which can be referenced in API endpoints. + ID string `json:"id,required"` + // The delta containing the fields that have changed on the run step. 
+ Delta RunStepDelta `json:"delta,required"` + // The object type, which is always `thread.run.step.delta`. + Object constant.ThreadRunStepDelta `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Delta respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStepDeltaEvent) RawJSON() string { return r.JSON.raw } +func (r *RunStepDeltaEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details of the message creation by the run step. +type RunStepDeltaMessageDelta struct { + // Always `message_creation`. + Type constant.MessageCreation `json:"type,required"` + MessageCreation RunStepDeltaMessageDeltaMessageCreation `json:"message_creation"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + MessageCreation respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStepDeltaMessageDelta) RawJSON() string { return r.JSON.raw } +func (r *RunStepDeltaMessageDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type RunStepDeltaMessageDeltaMessageCreation struct { + // The ID of the message that was created by this run step. + MessageID string `json:"message_id"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + MessageID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r RunStepDeltaMessageDeltaMessageCreation) RawJSON() string { return r.JSON.raw } +func (r *RunStepDeltaMessageDeltaMessageCreation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type RunStepInclude string + +const ( + RunStepIncludeStepDetailsToolCallsFileSearchResultsContent RunStepInclude = "step_details.tool_calls[*].file_search.results[*].content" +) + +// ToolCallUnion contains all possible properties and values from +// [CodeInterpreterToolCall], [FileSearchToolCall], [FunctionToolCall]. +// +// Use the [ToolCallUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type ToolCallUnion struct { + ID string `json:"id"` + // This field is from variant [CodeInterpreterToolCall]. + CodeInterpreter CodeInterpreterToolCallCodeInterpreter `json:"code_interpreter"` + // Any of "code_interpreter", "file_search", "function". + Type string `json:"type"` + // This field is from variant [FileSearchToolCall]. + FileSearch FileSearchToolCallFileSearch `json:"file_search"` + // This field is from variant [FunctionToolCall]. 
+ Function FunctionToolCallFunction `json:"function"` + JSON struct { + ID respjson.Field + CodeInterpreter respjson.Field + Type respjson.Field + FileSearch respjson.Field + Function respjson.Field + raw string + } `json:"-"` +} + +// anyToolCall is implemented by each variant of [ToolCallUnion] to add type safety +// for the return type of [ToolCallUnion.AsAny] +type anyToolCall interface { + implToolCallUnion() +} + +func (CodeInterpreterToolCall) implToolCallUnion() {} +func (FileSearchToolCall) implToolCallUnion() {} +func (FunctionToolCall) implToolCallUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := ToolCallUnion.AsAny().(type) { +// case openai.CodeInterpreterToolCall: +// case openai.FileSearchToolCall: +// case openai.FunctionToolCall: +// default: +// fmt.Errorf("no variant present") +// } +func (u ToolCallUnion) AsAny() anyToolCall { + switch u.Type { + case "code_interpreter": + return u.AsCodeInterpreter() + case "file_search": + return u.AsFileSearch() + case "function": + return u.AsFunction() + } + return nil +} + +func (u ToolCallUnion) AsCodeInterpreter() (v CodeInterpreterToolCall) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ToolCallUnion) AsFileSearch() (v FileSearchToolCall) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ToolCallUnion) AsFunction() (v FunctionToolCall) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ToolCallUnion) RawJSON() string { return u.JSON.raw } + +func (r *ToolCallUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToolCallDeltaUnion contains all possible properties and values from +// [CodeInterpreterToolCallDelta], [FileSearchToolCallDelta], +// [FunctionToolCallDelta]. +// +// Use the [ToolCallDeltaUnion.AsAny] method to switch on the variant. 
+// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type ToolCallDeltaUnion struct { + Index int64 `json:"index"` + // Any of "code_interpreter", "file_search", "function". + Type string `json:"type"` + ID string `json:"id"` + // This field is from variant [CodeInterpreterToolCallDelta]. + CodeInterpreter CodeInterpreterToolCallDeltaCodeInterpreter `json:"code_interpreter"` + // This field is from variant [FileSearchToolCallDelta]. + FileSearch any `json:"file_search"` + // This field is from variant [FunctionToolCallDelta]. + Function FunctionToolCallDeltaFunction `json:"function"` + JSON struct { + Index respjson.Field + Type respjson.Field + ID respjson.Field + CodeInterpreter respjson.Field + FileSearch respjson.Field + Function respjson.Field + raw string + } `json:"-"` +} + +// anyToolCallDelta is implemented by each variant of [ToolCallDeltaUnion] to add +// type safety for the return type of [ToolCallDeltaUnion.AsAny] +type anyToolCallDelta interface { + implToolCallDeltaUnion() +} + +func (CodeInterpreterToolCallDelta) implToolCallDeltaUnion() {} +func (FileSearchToolCallDelta) implToolCallDeltaUnion() {} +func (FunctionToolCallDelta) implToolCallDeltaUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := ToolCallDeltaUnion.AsAny().(type) { +// case openai.CodeInterpreterToolCallDelta: +// case openai.FileSearchToolCallDelta: +// case openai.FunctionToolCallDelta: +// default: +// fmt.Errorf("no variant present") +// } +func (u ToolCallDeltaUnion) AsAny() anyToolCallDelta { + switch u.Type { + case "code_interpreter": + return u.AsCodeInterpreter() + case "file_search": + return u.AsFileSearch() + case "function": + return u.AsFunction() + } + return nil +} + +func (u ToolCallDeltaUnion) AsCodeInterpreter() (v CodeInterpreterToolCallDelta) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ToolCallDeltaUnion) AsFileSearch() (v 
FileSearchToolCallDelta) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ToolCallDeltaUnion) AsFunction() (v FunctionToolCallDelta) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ToolCallDeltaUnion) RawJSON() string { return u.JSON.raw } + +func (r *ToolCallDeltaUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details of the tool call. +type ToolCallDeltaObject struct { + // Always `tool_calls`. + Type constant.ToolCalls `json:"type,required"` + // An array of tool calls the run step was involved in. These can be associated + // with one of three types of tools: `code_interpreter`, `file_search`, or + // `function`. + ToolCalls []ToolCallDeltaUnion `json:"tool_calls"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ToolCalls respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ToolCallDeltaObject) RawJSON() string { return r.JSON.raw } +func (r *ToolCallDeltaObject) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Details of the tool call. +type ToolCallsStepDetails struct { + // An array of tool calls the run step was involved in. These can be associated + // with one of three types of tools: `code_interpreter`, `file_search`, or + // `function`. + ToolCalls []ToolCallUnion `json:"tool_calls,required"` + // Always `tool_calls`. + Type constant.ToolCalls `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ToolCalls respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ToolCallsStepDetails) RawJSON() string { return r.JSON.raw } +func (r *ToolCallsStepDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type BetaThreadRunStepGetParams struct { + // A list of additional fields to include in the response. Currently the only + // supported value is `step_details.tool_calls[*].file_search.results[*].content` + // to fetch the file search result content. + // + // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + Include []RunStepInclude `query:"include,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [BetaThreadRunStepGetParams]'s query parameters as +// `url.Values`. +func (r BetaThreadRunStepGetParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +type BetaThreadRunStepListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A cursor for use in pagination. `before` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // starting with obj_foo, your subsequent call can include before=obj_foo in order + // to fetch the previous page of the list. 
+ Before param.Opt[string] `query:"before,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 100, and the default is 20. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // A list of additional fields to include in the response. Currently the only + // supported value is `step_details.tool_calls[*].file_search.results[*].content` + // to fetch the file search result content. + // + // See the + // [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) + // for more information. + Include []RunStepInclude `query:"include,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order BetaThreadRunStepListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [BetaThreadRunStepListParams]'s query parameters as +// `url.Values`. +func (r BetaThreadRunStepListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. +type BetaThreadRunStepListParamsOrder string + +const ( + BetaThreadRunStepListParamsOrderAsc BetaThreadRunStepListParamsOrder = "asc" + BetaThreadRunStepListParamsOrderDesc BetaThreadRunStepListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/chat.go b/vendor/github.com/openai/openai-go/chat.go new file mode 100644 index 0000000000..f579dc3cd7 --- /dev/null +++ b/vendor/github.com/openai/openai-go/chat.go @@ -0,0 +1,28 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// ChatService contains methods and other services that help with interacting with +// the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewChatService] method instead. +type ChatService struct { + Options []option.RequestOption + Completions ChatCompletionService +} + +// NewChatService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewChatService(opts ...option.RequestOption) (r ChatService) { + r = ChatService{} + r.Options = opts + r.Completions = NewChatCompletionService(opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/chatcompletion.go b/vendor/github.com/openai/openai-go/chatcompletion.go new file mode 100644 index 0000000000..4ebf54b5dc --- /dev/null +++ b/vendor/github.com/openai/openai-go/chatcompletion.go @@ -0,0 +1,2738 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/packages/ssestream" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// ChatCompletionService contains methods and other services that help with +// interacting with the openai API. 
+// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewChatCompletionService] method instead. +type ChatCompletionService struct { + Options []option.RequestOption + Messages ChatCompletionMessageService +} + +// NewChatCompletionService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewChatCompletionService(opts ...option.RequestOption) (r ChatCompletionService) { + r = ChatCompletionService{} + r.Options = opts + r.Messages = NewChatCompletionMessageService(opts...) + return +} + +// **Starting a new project?** We recommend trying +// [Responses](https://platform.openai.com/docs/api-reference/responses) to take +// advantage of the latest OpenAI platform features. Compare +// [Chat Completions with Responses](https://platform.openai.com/docs/guides/responses-vs-chat-completions?api-mode=responses). +// +// --- +// +// Creates a model response for the given chat conversation. Learn more in the +// [text generation](https://platform.openai.com/docs/guides/text-generation), +// [vision](https://platform.openai.com/docs/guides/vision), and +// [audio](https://platform.openai.com/docs/guides/audio) guides. +// +// Parameter support can differ depending on the model used to generate the +// response, particularly for newer reasoning models. Parameters that are only +// supported for reasoning models are noted below. For the current state of +// unsupported parameters in reasoning models, +// [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning). +func (r *ChatCompletionService) New(ctx context.Context, body ChatCompletionNewParams, opts ...option.RequestOption) (res *ChatCompletion, err error) { + opts = append(r.Options[:], opts...) 
+ path := "chat/completions" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// **Starting a new project?** We recommend trying +// [Responses](https://platform.openai.com/docs/api-reference/responses) to take +// advantage of the latest OpenAI platform features. Compare +// [Chat Completions with Responses](https://platform.openai.com/docs/guides/responses-vs-chat-completions?api-mode=responses). +// +// --- +// +// Creates a model response for the given chat conversation. Learn more in the +// [text generation](https://platform.openai.com/docs/guides/text-generation), +// [vision](https://platform.openai.com/docs/guides/vision), and +// [audio](https://platform.openai.com/docs/guides/audio) guides. +// +// Parameter support can differ depending on the model used to generate the +// response, particularly for newer reasoning models. Parameters that are only +// supported for reasoning models are noted below. For the current state of +// unsupported parameters in reasoning models, +// [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning). +func (r *ChatCompletionService) NewStreaming(ctx context.Context, body ChatCompletionNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[ChatCompletionChunk]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...) + path := "chat/completions" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[ChatCompletionChunk](ssestream.NewDecoder(raw), err) +} + +// Get a stored chat completion. Only Chat Completions that have been created with +// the `store` parameter set to `true` will be returned. 
+func (r *ChatCompletionService) Get(ctx context.Context, completionID string, opts ...option.RequestOption) (res *ChatCompletion, err error) { + opts = append(r.Options[:], opts...) + if completionID == "" { + err = errors.New("missing required completion_id parameter") + return + } + path := fmt.Sprintf("chat/completions/%s", completionID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Modify a stored chat completion. Only Chat Completions that have been created +// with the `store` parameter set to `true` can be modified. Currently, the only +// supported modification is to update the `metadata` field. +func (r *ChatCompletionService) Update(ctx context.Context, completionID string, body ChatCompletionUpdateParams, opts ...option.RequestOption) (res *ChatCompletion, err error) { + opts = append(r.Options[:], opts...) + if completionID == "" { + err = errors.New("missing required completion_id parameter") + return + } + path := fmt.Sprintf("chat/completions/%s", completionID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// List stored Chat Completions. Only Chat Completions that have been stored with +// the `store` parameter set to `true` will be returned. +func (r *ChatCompletionService) List(ctx context.Context, query ChatCompletionListParams, opts ...option.RequestOption) (res *pagination.CursorPage[ChatCompletion], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + path := "chat/completions" + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// List stored Chat Completions. 
Only Chat Completions that have been stored with +// the `store` parameter set to `true` will be returned. +func (r *ChatCompletionService) ListAutoPaging(ctx context.Context, query ChatCompletionListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[ChatCompletion] { + return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...)) +} + +// Delete a stored chat completion. Only Chat Completions that have been created +// with the `store` parameter set to `true` can be deleted. +func (r *ChatCompletionService) Delete(ctx context.Context, completionID string, opts ...option.RequestOption) (res *ChatCompletionDeleted, err error) { + opts = append(r.Options[:], opts...) + if completionID == "" { + err = errors.New("missing required completion_id parameter") + return + } + path := fmt.Sprintf("chat/completions/%s", completionID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...) + return +} + +// Represents a chat completion response returned by model, based on the provided +// input. +type ChatCompletion struct { + // A unique identifier for the chat completion. + ID string `json:"id,required"` + // A list of chat completion choices. Can be more than one if `n` is greater + // than 1. + Choices []ChatCompletionChoice `json:"choices,required"` + // The Unix timestamp (in seconds) of when the chat completion was created. + Created int64 `json:"created,required"` + // The model used for the chat completion. + Model string `json:"model,required"` + // The object type, which is always `chat.completion`. + Object constant.ChatCompletion `json:"object,required"` + // Specifies the processing type used for serving the request. + // + // - If set to 'auto', then the request will be processed with the service tier + // configured in the Project settings. Unless otherwise configured, the Project + // will use 'default'. 
+ // - If set to 'default', then the request will be processed with the standard + // pricing and performance for the selected model. + // - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or + // 'priority', then the request will be processed with the corresponding service + // tier. [Contact sales](https://openai.com/contact-sales) to learn more about + // Priority processing. + // - When not set, the default behavior is 'auto'. + // + // When the `service_tier` parameter is set, the response body will include the + // `service_tier` value based on the processing mode actually used to serve the + // request. This response value may be different from the value set in the + // parameter. + // + // Any of "auto", "default", "flex", "scale", "priority". + ServiceTier ChatCompletionServiceTier `json:"service_tier,nullable"` + // This fingerprint represents the backend configuration that the model runs with. + // + // Can be used in conjunction with the `seed` request parameter to understand when + // backend changes have been made that might impact determinism. + SystemFingerprint string `json:"system_fingerprint"` + // Usage statistics for the completion request. + Usage CompletionUsage `json:"usage"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Choices respjson.Field + Created respjson.Field + Model respjson.Field + Object respjson.Field + ServiceTier respjson.Field + SystemFingerprint respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletion) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionChoice struct { + // The reason the model stopped generating tokens. 
This will be `stop` if the model + // hit a natural stop point or a provided stop sequence, `length` if the maximum + // number of tokens specified in the request was reached, `content_filter` if + // content was omitted due to a flag from our content filters, `tool_calls` if the + // model called a tool, or `function_call` (deprecated) if the model called a + // function. + // + // Any of "stop", "length", "tool_calls", "content_filter", "function_call". + FinishReason string `json:"finish_reason,required"` + // The index of the choice in the list of choices. + Index int64 `json:"index,required"` + // Log probability information for the choice. + Logprobs ChatCompletionChoiceLogprobs `json:"logprobs,required"` + // A chat completion message generated by the model. + Message ChatCompletionMessage `json:"message,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FinishReason respjson.Field + Index respjson.Field + Logprobs respjson.Field + Message respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChoice) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChoice) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Log probability information for the choice. +type ChatCompletionChoiceLogprobs struct { + // A list of message content tokens with log probability information. + Content []ChatCompletionTokenLogprob `json:"content,required"` + // A list of message refusal tokens with log probability information. + Refusal []ChatCompletionTokenLogprob `json:"refusal,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Content respjson.Field + Refusal respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChoiceLogprobs) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChoiceLogprobs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Specifies the processing type used for serving the request. +// +// - If set to 'auto', then the request will be processed with the service tier +// configured in the Project settings. Unless otherwise configured, the Project +// will use 'default'. +// - If set to 'default', then the request will be processed with the standard +// pricing and performance for the selected model. +// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or +// 'priority', then the request will be processed with the corresponding service +// tier. [Contact sales](https://openai.com/contact-sales) to learn more about +// Priority processing. +// - When not set, the default behavior is 'auto'. +// +// When the `service_tier` parameter is set, the response body will include the +// `service_tier` value based on the processing mode actually used to serve the +// request. This response value may be different from the value set in the +// parameter. +type ChatCompletionServiceTier string + +const ( + ChatCompletionServiceTierAuto ChatCompletionServiceTier = "auto" + ChatCompletionServiceTierDefault ChatCompletionServiceTier = "default" + ChatCompletionServiceTierFlex ChatCompletionServiceTier = "flex" + ChatCompletionServiceTierScale ChatCompletionServiceTier = "scale" + ChatCompletionServiceTierPriority ChatCompletionServiceTier = "priority" +) + +// Messages sent by the model in response to user messages. +// +// The property Role is required. +type ChatCompletionAssistantMessageParam struct { + // The refusal message by the assistant. 
+ Refusal param.Opt[string] `json:"refusal,omitzero"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name param.Opt[string] `json:"name,omitzero"` + // Data about a previous audio response from the model. + // [Learn more](https://platform.openai.com/docs/guides/audio). + Audio ChatCompletionAssistantMessageParamAudio `json:"audio,omitzero"` + // The contents of the assistant message. Required unless `tool_calls` or + // `function_call` is specified. + Content ChatCompletionAssistantMessageParamContentUnion `json:"content,omitzero"` + // Deprecated and replaced by `tool_calls`. The name and arguments of a function + // that should be called, as generated by the model. + // + // Deprecated: deprecated + FunctionCall ChatCompletionAssistantMessageParamFunctionCall `json:"function_call,omitzero"` + // The tool calls generated by the model, such as function calls. + ToolCalls []ChatCompletionMessageToolCallParam `json:"tool_calls,omitzero"` + // The role of the messages author, in this case `assistant`. + // + // This field can be elided, and will marshal its zero value as "assistant". + Role constant.Assistant `json:"role,required"` + paramObj +} + +func (r ChatCompletionAssistantMessageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionAssistantMessageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionAssistantMessageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Data about a previous audio response from the model. +// [Learn more](https://platform.openai.com/docs/guides/audio). +// +// The property ID is required. +type ChatCompletionAssistantMessageParamAudio struct { + // Unique identifier for a previous audio response from the model. 
+ ID string `json:"id,required"` + paramObj +} + +func (r ChatCompletionAssistantMessageParamAudio) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionAssistantMessageParamAudio + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionAssistantMessageParamAudio) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ChatCompletionAssistantMessageParamContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionAssistantMessageParamContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *ChatCompletionAssistantMessageParamContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionAssistantMessageParamContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion struct { + OfText *ChatCompletionContentPartTextParam `json:",omitzero,inline"` + OfRefusal *ChatCompletionContentPartRefusalParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfText, u.OfRefusal) +} +func (u *ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) asAny() any { + if !param.IsOmitted(u.OfText) { + return u.OfText + } else if !param.IsOmitted(u.OfRefusal) { + return u.OfRefusal + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) GetText() *string { + if vt := u.OfText; vt != nil { + return &vt.Text + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) GetRefusal() *string { + if vt := u.OfRefusal; vt != nil { + return &vt.Refusal + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) GetType() *string { + if vt := u.OfText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfRefusal; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion]( + "type", + apijson.Discriminator[ChatCompletionContentPartTextParam]("text"), + apijson.Discriminator[ChatCompletionContentPartRefusalParam]("refusal"), + ) +} + +// Deprecated and replaced by `tool_calls`. 
The name and arguments of a function +// that should be called, as generated by the model. +// +// Deprecated: deprecated +// +// The properties Arguments, Name are required. +type ChatCompletionAssistantMessageParamFunctionCall struct { + // The arguments to call the function with, as generated by the model in JSON + // format. Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments,required"` + // The name of the function to call. + Name string `json:"name,required"` + paramObj +} + +func (r ChatCompletionAssistantMessageParamFunctionCall) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionAssistantMessageParamFunctionCall + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionAssistantMessageParamFunctionCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// If the audio output modality is requested, this object contains data about the +// audio response from the model. +// [Learn more](https://platform.openai.com/docs/guides/audio). +type ChatCompletionAudio struct { + // Unique identifier for this audio response. + ID string `json:"id,required"` + // Base64 encoded audio bytes generated by the model, in the format specified in + // the request. + Data string `json:"data,required"` + // The Unix timestamp (in seconds) for when this audio response will no longer be + // accessible on the server for use in multi-turn conversations. + ExpiresAt int64 `json:"expires_at,required"` + // Transcript of the audio generated by the model. + Transcript string `json:"transcript,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + Data respjson.Field + ExpiresAt respjson.Field + Transcript respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionAudio) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionAudio) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Parameters for audio output. Required when audio output is requested with +// `modalities: ["audio"]`. +// [Learn more](https://platform.openai.com/docs/guides/audio). +// +// The properties Format, Voice are required. +type ChatCompletionAudioParam struct { + // Specifies the output audio format. Must be one of `wav`, `mp3`, `flac`, `opus`, + // or `pcm16`. + // + // Any of "wav", "aac", "mp3", "flac", "opus", "pcm16". + Format ChatCompletionAudioParamFormat `json:"format,omitzero,required"` + // The voice the model uses to respond. Supported voices are `alloy`, `ash`, + // `ballad`, `coral`, `echo`, `fable`, `nova`, `onyx`, `sage`, and `shimmer`. + Voice ChatCompletionAudioParamVoice `json:"voice,omitzero,required"` + paramObj +} + +func (r ChatCompletionAudioParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionAudioParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionAudioParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Specifies the output audio format. Must be one of `wav`, `mp3`, `flac`, `opus`, +// or `pcm16`. 
+type ChatCompletionAudioParamFormat string + +const ( + ChatCompletionAudioParamFormatWAV ChatCompletionAudioParamFormat = "wav" + ChatCompletionAudioParamFormatAAC ChatCompletionAudioParamFormat = "aac" + ChatCompletionAudioParamFormatMP3 ChatCompletionAudioParamFormat = "mp3" + ChatCompletionAudioParamFormatFLAC ChatCompletionAudioParamFormat = "flac" + ChatCompletionAudioParamFormatOpus ChatCompletionAudioParamFormat = "opus" + ChatCompletionAudioParamFormatPcm16 ChatCompletionAudioParamFormat = "pcm16" +) + +// The voice the model uses to respond. Supported voices are `alloy`, `ash`, +// `ballad`, `coral`, `echo`, `fable`, `nova`, `onyx`, `sage`, and `shimmer`. +type ChatCompletionAudioParamVoice string + +const ( + ChatCompletionAudioParamVoiceAlloy ChatCompletionAudioParamVoice = "alloy" + ChatCompletionAudioParamVoiceAsh ChatCompletionAudioParamVoice = "ash" + ChatCompletionAudioParamVoiceBallad ChatCompletionAudioParamVoice = "ballad" + ChatCompletionAudioParamVoiceCoral ChatCompletionAudioParamVoice = "coral" + ChatCompletionAudioParamVoiceEcho ChatCompletionAudioParamVoice = "echo" + ChatCompletionAudioParamVoiceSage ChatCompletionAudioParamVoice = "sage" + ChatCompletionAudioParamVoiceShimmer ChatCompletionAudioParamVoice = "shimmer" + ChatCompletionAudioParamVoiceVerse ChatCompletionAudioParamVoice = "verse" +) + +// Represents a streamed chunk of a chat completion response returned by the model, +// based on the provided input. +// [Learn more](https://platform.openai.com/docs/guides/streaming-responses). +type ChatCompletionChunk struct { + // A unique identifier for the chat completion. Each chunk has the same ID. + ID string `json:"id,required"` + // A list of chat completion choices. Can contain more than one elements if `n` is + // greater than 1. Can also be empty for the last chunk if you set + // `stream_options: {"include_usage": true}`. 
+ Choices []ChatCompletionChunkChoice `json:"choices,required"` + // The Unix timestamp (in seconds) of when the chat completion was created. Each + // chunk has the same timestamp. + Created int64 `json:"created,required"` + // The model to generate the completion. + Model string `json:"model,required"` + // The object type, which is always `chat.completion.chunk`. + Object constant.ChatCompletionChunk `json:"object,required"` + // Specifies the processing type used for serving the request. + // + // - If set to 'auto', then the request will be processed with the service tier + // configured in the Project settings. Unless otherwise configured, the Project + // will use 'default'. + // - If set to 'default', then the request will be processed with the standard + // pricing and performance for the selected model. + // - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or + // 'priority', then the request will be processed with the corresponding service + // tier. [Contact sales](https://openai.com/contact-sales) to learn more about + // Priority processing. + // - When not set, the default behavior is 'auto'. + // + // When the `service_tier` parameter is set, the response body will include the + // `service_tier` value based on the processing mode actually used to serve the + // request. This response value may be different from the value set in the + // parameter. + // + // Any of "auto", "default", "flex", "scale", "priority". + ServiceTier ChatCompletionChunkServiceTier `json:"service_tier,nullable"` + // This fingerprint represents the backend configuration that the model runs with. + // Can be used in conjunction with the `seed` request parameter to understand when + // backend changes have been made that might impact determinism. + SystemFingerprint string `json:"system_fingerprint"` + // An optional field that will only be present when you set + // `stream_options: {"include_usage": true}` in your request. 
When present, it + // contains a null value **except for the last chunk** which contains the token + // usage statistics for the entire request. + // + // **NOTE:** If the stream is interrupted or cancelled, you may not receive the + // final usage chunk which contains the total token usage for the request. + Usage CompletionUsage `json:"usage,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Choices respjson.Field + Created respjson.Field + Model respjson.Field + Object respjson.Field + ServiceTier respjson.Field + SystemFingerprint respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunk) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunk) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionChunkChoice struct { + // A chat completion delta generated by streamed model responses. + Delta ChatCompletionChunkChoiceDelta `json:"delta,required"` + // The reason the model stopped generating tokens. This will be `stop` if the model + // hit a natural stop point or a provided stop sequence, `length` if the maximum + // number of tokens specified in the request was reached, `content_filter` if + // content was omitted due to a flag from our content filters, `tool_calls` if the + // model called a tool, or `function_call` (deprecated) if the model called a + // function. + // + // Any of "stop", "length", "tool_calls", "content_filter", "function_call". + FinishReason string `json:"finish_reason,required"` + // The index of the choice in the list of choices. + Index int64 `json:"index,required"` + // Log probability information for the choice. 
+ Logprobs ChatCompletionChunkChoiceLogprobs `json:"logprobs,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Delta respjson.Field + FinishReason respjson.Field + Index respjson.Field + Logprobs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunkChoice) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunkChoice) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A chat completion delta generated by streamed model responses. +type ChatCompletionChunkChoiceDelta struct { + // The contents of the chunk message. + Content string `json:"content,nullable"` + // Deprecated and replaced by `tool_calls`. The name and arguments of a function + // that should be called, as generated by the model. + // + // Deprecated: deprecated + FunctionCall ChatCompletionChunkChoiceDeltaFunctionCall `json:"function_call"` + // The refusal message generated by the model. + Refusal string `json:"refusal,nullable"` + // The role of the author of this message. + // + // Any of "developer", "system", "user", "assistant", "tool". + Role string `json:"role"` + ToolCalls []ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Content respjson.Field + FunctionCall respjson.Field + Refusal respjson.Field + Role respjson.Field + ToolCalls respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunkChoiceDelta) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunkChoiceDelta) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Deprecated and replaced by `tool_calls`. 
The name and arguments of a function +// that should be called, as generated by the model. +// +// Deprecated: deprecated +type ChatCompletionChunkChoiceDeltaFunctionCall struct { + // The arguments to call the function with, as generated by the model in JSON + // format. Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments"` + // The name of the function to call. + Name string `json:"name"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunkChoiceDeltaFunctionCall) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunkChoiceDeltaFunctionCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionChunkChoiceDeltaToolCall struct { + Index int64 `json:"index,required"` + // The ID of the tool call. + ID string `json:"id"` + Function ChatCompletionChunkChoiceDeltaToolCallFunction `json:"function"` + // The type of the tool. Currently, only `function` is supported. + // + // Any of "function". + Type string `json:"type"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Index respjson.Field + ID respjson.Field + Function respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunkChoiceDeltaToolCall) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunkChoiceDeltaToolCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionChunkChoiceDeltaToolCallFunction struct { + // The arguments to call the function with, as generated by the model in JSON + // format. Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments"` + // The name of the function to call. + Name string `json:"name"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunkChoiceDeltaToolCallFunction) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunkChoiceDeltaToolCallFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Log probability information for the choice. +type ChatCompletionChunkChoiceLogprobs struct { + // A list of message content tokens with log probability information. + Content []ChatCompletionTokenLogprob `json:"content,required"` + // A list of message refusal tokens with log probability information. + Refusal []ChatCompletionTokenLogprob `json:"refusal,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Content respjson.Field + Refusal respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionChunkChoiceLogprobs) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionChunkChoiceLogprobs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Specifies the processing type used for serving the request. +// +// - If set to 'auto', then the request will be processed with the service tier +// configured in the Project settings. Unless otherwise configured, the Project +// will use 'default'. +// - If set to 'default', then the request will be processed with the standard +// pricing and performance for the selected model. +// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or +// 'priority', then the request will be processed with the corresponding service +// tier. [Contact sales](https://openai.com/contact-sales) to learn more about +// Priority processing. +// - When not set, the default behavior is 'auto'. +// +// When the `service_tier` parameter is set, the response body will include the +// `service_tier` value based on the processing mode actually used to serve the +// request. This response value may be different from the value set in the +// parameter. 
+type ChatCompletionChunkServiceTier string + +const ( + ChatCompletionChunkServiceTierAuto ChatCompletionChunkServiceTier = "auto" + ChatCompletionChunkServiceTierDefault ChatCompletionChunkServiceTier = "default" + ChatCompletionChunkServiceTierFlex ChatCompletionChunkServiceTier = "flex" + ChatCompletionChunkServiceTierScale ChatCompletionChunkServiceTier = "scale" + ChatCompletionChunkServiceTierPriority ChatCompletionChunkServiceTier = "priority" +) + +func TextContentPart(text string) ChatCompletionContentPartUnionParam { + var variant ChatCompletionContentPartTextParam + variant.Text = text + return ChatCompletionContentPartUnionParam{OfText: &variant} +} + +func ImageContentPart(imageURL ChatCompletionContentPartImageImageURLParam) ChatCompletionContentPartUnionParam { + var variant ChatCompletionContentPartImageParam + variant.ImageURL = imageURL + return ChatCompletionContentPartUnionParam{OfImageURL: &variant} +} + +func InputAudioContentPart(inputAudio ChatCompletionContentPartInputAudioInputAudioParam) ChatCompletionContentPartUnionParam { + var variant ChatCompletionContentPartInputAudioParam + variant.InputAudio = inputAudio + return ChatCompletionContentPartUnionParam{OfInputAudio: &variant} +} + +func FileContentPart(file ChatCompletionContentPartFileFileParam) ChatCompletionContentPartUnionParam { + var variant ChatCompletionContentPartFileParam + variant.File = file + return ChatCompletionContentPartUnionParam{OfFile: &variant} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionContentPartUnionParam struct { + OfText *ChatCompletionContentPartTextParam `json:",omitzero,inline"` + OfImageURL *ChatCompletionContentPartImageParam `json:",omitzero,inline"` + OfInputAudio *ChatCompletionContentPartInputAudioParam `json:",omitzero,inline"` + OfFile *ChatCompletionContentPartFileParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionContentPartUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfText, u.OfImageURL, u.OfInputAudio, u.OfFile) +} +func (u *ChatCompletionContentPartUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionContentPartUnionParam) asAny() any { + if !param.IsOmitted(u.OfText) { + return u.OfText + } else if !param.IsOmitted(u.OfImageURL) { + return u.OfImageURL + } else if !param.IsOmitted(u.OfInputAudio) { + return u.OfInputAudio + } else if !param.IsOmitted(u.OfFile) { + return u.OfFile + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionContentPartUnionParam) GetText() *string { + if vt := u.OfText; vt != nil { + return &vt.Text + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionContentPartUnionParam) GetImageURL() *ChatCompletionContentPartImageImageURLParam { + if vt := u.OfImageURL; vt != nil { + return &vt.ImageURL + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionContentPartUnionParam) GetInputAudio() *ChatCompletionContentPartInputAudioInputAudioParam { + if vt := u.OfInputAudio; vt != nil { + return &vt.InputAudio + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u ChatCompletionContentPartUnionParam) GetFile() *ChatCompletionContentPartFileFileParam { + if vt := u.OfFile; vt != nil { + return &vt.File + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionContentPartUnionParam) GetType() *string { + if vt := u.OfText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfImageURL; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfInputAudio; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFile; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +func init() { + apijson.RegisterUnion[ChatCompletionContentPartUnionParam]( + "type", + apijson.Discriminator[ChatCompletionContentPartTextParam]("text"), + apijson.Discriminator[ChatCompletionContentPartImageParam]("image_url"), + apijson.Discriminator[ChatCompletionContentPartInputAudioParam]("input_audio"), + apijson.Discriminator[ChatCompletionContentPartFileParam]("file"), + ) +} + +// Learn about [file inputs](https://platform.openai.com/docs/guides/text) for text +// generation. +// +// The properties File, Type are required. +type ChatCompletionContentPartFileParam struct { + File ChatCompletionContentPartFileFileParam `json:"file,omitzero,required"` + // The type of the content part. Always `file`. + // + // This field can be elided, and will marshal its zero value as "file". + Type constant.File `json:"type,required"` + paramObj +} + +func (r ChatCompletionContentPartFileParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartFileParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartFileParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionContentPartFileFileParam struct { + // The base64 encoded file data, used when passing the file to the model as a + // string. 
+ FileData param.Opt[string] `json:"file_data,omitzero"` + // The ID of an uploaded file to use as input. + FileID param.Opt[string] `json:"file_id,omitzero"` + // The name of the file, used when passing the file to the model as a string. + Filename param.Opt[string] `json:"filename,omitzero"` + paramObj +} + +func (r ChatCompletionContentPartFileFileParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartFileFileParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartFileFileParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Learn about [image inputs](https://platform.openai.com/docs/guides/vision). +type ChatCompletionContentPartImage struct { + ImageURL ChatCompletionContentPartImageImageURL `json:"image_url,required"` + // The type of the content part. + Type constant.ImageURL `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageURL respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionContentPartImage) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionContentPartImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ChatCompletionContentPartImage to a +// ChatCompletionContentPartImageParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// ChatCompletionContentPartImageParam.Overrides() +func (r ChatCompletionContentPartImage) ToParam() ChatCompletionContentPartImageParam { + return param.Override[ChatCompletionContentPartImageParam](json.RawMessage(r.RawJSON())) +} + +type ChatCompletionContentPartImageImageURL struct { + // Either a URL of the image or the base64 encoded image data. + URL string `json:"url,required" format:"uri"` + // Specifies the detail level of the image. Learn more in the + // [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding). + // + // Any of "auto", "low", "high". + Detail string `json:"detail"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + URL respjson.Field + Detail respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionContentPartImageImageURL) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionContentPartImageImageURL) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Learn about [image inputs](https://platform.openai.com/docs/guides/vision). +// +// The properties ImageURL, Type are required. +type ChatCompletionContentPartImageParam struct { + ImageURL ChatCompletionContentPartImageImageURLParam `json:"image_url,omitzero,required"` + // The type of the content part. + // + // This field can be elided, and will marshal its zero value as "image_url". + Type constant.ImageURL `json:"type,required"` + paramObj +} + +func (r ChatCompletionContentPartImageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartImageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartImageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The property URL is required. 
+type ChatCompletionContentPartImageImageURLParam struct { + // Either a URL of the image or the base64 encoded image data. + URL string `json:"url,required" format:"uri"` + // Specifies the detail level of the image. Learn more in the + // [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding). + // + // Any of "auto", "low", "high". + Detail string `json:"detail,omitzero"` + paramObj +} + +func (r ChatCompletionContentPartImageImageURLParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartImageImageURLParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartImageImageURLParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[ChatCompletionContentPartImageImageURLParam]( + "detail", "auto", "low", "high", + ) +} + +// Learn about [audio inputs](https://platform.openai.com/docs/guides/audio). +// +// The properties InputAudio, Type are required. +type ChatCompletionContentPartInputAudioParam struct { + InputAudio ChatCompletionContentPartInputAudioInputAudioParam `json:"input_audio,omitzero,required"` + // The type of the content part. Always `input_audio`. + // + // This field can be elided, and will marshal its zero value as "input_audio". + Type constant.InputAudio `json:"type,required"` + paramObj +} + +func (r ChatCompletionContentPartInputAudioParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartInputAudioParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartInputAudioParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties Data, Format are required. +type ChatCompletionContentPartInputAudioInputAudioParam struct { + // Base64 encoded audio data. + Data string `json:"data,required"` + // The format of the encoded audio data. 
Currently supports "wav" and "mp3". + // + // Any of "wav", "mp3". + Format string `json:"format,omitzero,required"` + paramObj +} + +func (r ChatCompletionContentPartInputAudioInputAudioParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartInputAudioInputAudioParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartInputAudioInputAudioParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[ChatCompletionContentPartInputAudioInputAudioParam]( + "format", "wav", "mp3", + ) +} + +// The properties Refusal, Type are required. +type ChatCompletionContentPartRefusalParam struct { + // The refusal message generated by the model. + Refusal string `json:"refusal,required"` + // The type of the content part. + // + // This field can be elided, and will marshal its zero value as "refusal". + Type constant.Refusal `json:"type,required"` + paramObj +} + +func (r ChatCompletionContentPartRefusalParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartRefusalParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartRefusalParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Learn about +// [text inputs](https://platform.openai.com/docs/guides/text-generation). +type ChatCompletionContentPartText struct { + // The text content. + Text string `json:"text,required"` + // The type of the content part. + Type constant.Text `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionContentPartText) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionContentPartText) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ChatCompletionContentPartText to a +// ChatCompletionContentPartTextParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ChatCompletionContentPartTextParam.Overrides() +func (r ChatCompletionContentPartText) ToParam() ChatCompletionContentPartTextParam { + return param.Override[ChatCompletionContentPartTextParam](json.RawMessage(r.RawJSON())) +} + +// Learn about +// [text inputs](https://platform.openai.com/docs/guides/text-generation). +// +// The properties Text, Type are required. +type ChatCompletionContentPartTextParam struct { + // The text content. + Text string `json:"text,required"` + // The type of the content part. + // + // This field can be elided, and will marshal its zero value as "text". + Type constant.Text `json:"type,required"` + paramObj +} + +func (r ChatCompletionContentPartTextParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionContentPartTextParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionContentPartTextParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionDeleted struct { + // The ID of the chat completion that was deleted. + ID string `json:"id,required"` + // Whether the chat completion was deleted. + Deleted bool `json:"deleted,required"` + // The type of object being deleted. 
+ Object constant.ChatCompletionDeleted `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Deleted respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionDeleted) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionDeleted) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Developer-provided instructions that the model should follow, regardless of +// messages sent by the user. With o1 models and newer, `developer` messages +// replace the previous `system` messages. +// +// The properties Content, Role are required. +type ChatCompletionDeveloperMessageParam struct { + // The contents of the developer message. + Content ChatCompletionDeveloperMessageParamContentUnion `json:"content,omitzero,required"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name param.Opt[string] `json:"name,omitzero"` + // The role of the messages author, in this case `developer`. + // + // This field can be elided, and will marshal its zero value as "developer". + Role constant.Developer `json:"role,required"` + paramObj +} + +func (r ChatCompletionDeveloperMessageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionDeveloperMessageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionDeveloperMessageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionDeveloperMessageParamContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []ChatCompletionContentPartTextParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionDeveloperMessageParamContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *ChatCompletionDeveloperMessageParamContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionDeveloperMessageParamContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +// Specifying a particular function via `{"name": "my_function"}` forces the model +// to call that function. +// +// The property Name is required. +type ChatCompletionFunctionCallOptionParam struct { + // The name of the function to call. + Name string `json:"name,required"` + paramObj +} + +func (r ChatCompletionFunctionCallOptionParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionFunctionCallOptionParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionFunctionCallOptionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Deprecated: deprecated +// +// The properties Content, Name, Role are required. +type ChatCompletionFunctionMessageParam struct { + // The contents of the function message. + Content param.Opt[string] `json:"content,omitzero,required"` + // The name of the function to call. + Name string `json:"name,required"` + // The role of the messages author, in this case `function`. + // + // This field can be elided, and will marshal its zero value as "function". 
+ Role constant.Function `json:"role,required"` + paramObj +} + +func (r ChatCompletionFunctionMessageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionFunctionMessageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionFunctionMessageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A chat completion message generated by the model. +type ChatCompletionMessage struct { + // The contents of the message. + Content string `json:"content,required"` + // The refusal message generated by the model. + Refusal string `json:"refusal,required"` + // The role of the author of this message. + Role constant.Assistant `json:"role,required"` + // Annotations for the message, when applicable, as when using the + // [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat). + Annotations []ChatCompletionMessageAnnotation `json:"annotations"` + // If the audio output modality is requested, this object contains data about the + // audio response from the model. + // [Learn more](https://platform.openai.com/docs/guides/audio). + Audio ChatCompletionAudio `json:"audio,nullable"` + // Deprecated and replaced by `tool_calls`. The name and arguments of a function + // that should be called, as generated by the model. + // + // Deprecated: deprecated + FunctionCall ChatCompletionMessageFunctionCall `json:"function_call"` + // The tool calls generated by the model, such as function calls. + ToolCalls []ChatCompletionMessageToolCall `json:"tool_calls"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Content respjson.Field + Refusal respjson.Field + Role respjson.Field + Annotations respjson.Field + Audio respjson.Field + FunctionCall respjson.Field + ToolCalls respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionMessage) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionMessage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func (r ChatCompletionMessage) ToParam() ChatCompletionMessageParamUnion { + asst := r.ToAssistantMessageParam() + return ChatCompletionMessageParamUnion{OfAssistant: &asst} +} + +func (r ChatCompletionMessage) ToAssistantMessageParam() ChatCompletionAssistantMessageParam { + var p ChatCompletionAssistantMessageParam + + // It is important to not rely on the JSON metadata property + // here, it may be unset if the receiver was generated via a + // [ChatCompletionAccumulator]. + // + // Explicit null is intentionally elided from the response. + if r.Content != "" { + p.Content.OfString = String(r.Content) + } + if r.Refusal != "" { + p.Refusal = String(r.Refusal) + } + + p.Audio.ID = r.Audio.ID + p.Role = r.Role + p.FunctionCall.Arguments = r.FunctionCall.Arguments + p.FunctionCall.Name = r.FunctionCall.Name + + if len(r.ToolCalls) > 0 { + p.ToolCalls = make([]ChatCompletionMessageToolCallParam, len(r.ToolCalls)) + for i, v := range r.ToolCalls { + p.ToolCalls[i].ID = v.ID + p.ToolCalls[i].Function.Arguments = v.Function.Arguments + p.ToolCalls[i].Function.Name = v.Function.Name + } + } + return p +} + +// A URL citation when using web search. +type ChatCompletionMessageAnnotation struct { + // The type of the URL citation. Always `url_citation`. + Type constant.URLCitation `json:"type,required"` + // A URL citation when using web search. 
+ URLCitation ChatCompletionMessageAnnotationURLCitation `json:"url_citation,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + URLCitation respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionMessageAnnotation) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionMessageAnnotation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A URL citation when using web search. +type ChatCompletionMessageAnnotationURLCitation struct { + // The index of the last character of the URL citation in the message. + EndIndex int64 `json:"end_index,required"` + // The index of the first character of the URL citation in the message. + StartIndex int64 `json:"start_index,required"` + // The title of the web resource. + Title string `json:"title,required"` + // The URL of the web resource. + URL string `json:"url,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + EndIndex respjson.Field + StartIndex respjson.Field + Title respjson.Field + URL respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionMessageAnnotationURLCitation) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionMessageAnnotationURLCitation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Deprecated and replaced by `tool_calls`. The name and arguments of a function +// that should be called, as generated by the model. +// +// Deprecated: deprecated +type ChatCompletionMessageFunctionCall struct { + // The arguments to call the function with, as generated by the model in JSON + // format. 
Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments,required"` + // The name of the function to call. + Name string `json:"name,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionMessageFunctionCall) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionMessageFunctionCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func AssistantMessage[T string | []ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion](content T) ChatCompletionMessageParamUnion { + var assistant ChatCompletionAssistantMessageParam + switch v := any(content).(type) { + case string: + assistant.Content.OfString = param.NewOpt(v) + case []ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion: + assistant.Content.OfArrayOfContentParts = v + } + return ChatCompletionMessageParamUnion{OfAssistant: &assistant} +} + +func DeveloperMessage[T string | []ChatCompletionContentPartTextParam](content T) ChatCompletionMessageParamUnion { + var developer ChatCompletionDeveloperMessageParam + switch v := any(content).(type) { + case string: + developer.Content.OfString = param.NewOpt(v) + case []ChatCompletionContentPartTextParam: + developer.Content.OfArrayOfContentParts = v + } + return ChatCompletionMessageParamUnion{OfDeveloper: &developer} +} + +func SystemMessage[T string | []ChatCompletionContentPartTextParam](content T) ChatCompletionMessageParamUnion { + var system ChatCompletionSystemMessageParam + switch v := any(content).(type) { + case string: + system.Content.OfString = param.NewOpt(v) + case 
[]ChatCompletionContentPartTextParam: + system.Content.OfArrayOfContentParts = v + } + return ChatCompletionMessageParamUnion{OfSystem: &system} +} + +func UserMessage[T string | []ChatCompletionContentPartUnionParam](content T) ChatCompletionMessageParamUnion { + var user ChatCompletionUserMessageParam + switch v := any(content).(type) { + case string: + user.Content.OfString = param.NewOpt(v) + case []ChatCompletionContentPartUnionParam: + user.Content.OfArrayOfContentParts = v + } + return ChatCompletionMessageParamUnion{OfUser: &user} +} + +func ChatCompletionMessageParamOfAssistant[ + T string | []ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion, +](content T) ChatCompletionMessageParamUnion { + var assistant ChatCompletionAssistantMessageParam + switch v := any(content).(type) { + case string: + assistant.Content.OfString = param.NewOpt(v) + case []ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion: + assistant.Content.OfArrayOfContentParts = v + } + return ChatCompletionMessageParamUnion{OfAssistant: &assistant} +} + +func ToolMessage[T string | []ChatCompletionContentPartTextParam](content T, toolCallID string) ChatCompletionMessageParamUnion { + var tool ChatCompletionToolMessageParam + switch v := any(content).(type) { + case string: + tool.Content.OfString = param.NewOpt(v) + case []ChatCompletionContentPartTextParam: + tool.Content.OfArrayOfContentParts = v + } + tool.ToolCallID = toolCallID + return ChatCompletionMessageParamUnion{OfTool: &tool} +} + +func ChatCompletionMessageParamOfFunction(content string, name string) ChatCompletionMessageParamUnion { + var function ChatCompletionFunctionMessageParam + function.Content = param.NewOpt(content) + function.Name = name + return ChatCompletionMessageParamUnion{OfFunction: &function} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionMessageParamUnion struct { + OfDeveloper *ChatCompletionDeveloperMessageParam `json:",omitzero,inline"` + OfSystem *ChatCompletionSystemMessageParam `json:",omitzero,inline"` + OfUser *ChatCompletionUserMessageParam `json:",omitzero,inline"` + OfAssistant *ChatCompletionAssistantMessageParam `json:",omitzero,inline"` + OfTool *ChatCompletionToolMessageParam `json:",omitzero,inline"` + OfFunction *ChatCompletionFunctionMessageParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionMessageParamUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfDeveloper, + u.OfSystem, + u.OfUser, + u.OfAssistant, + u.OfTool, + u.OfFunction) +} +func (u *ChatCompletionMessageParamUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionMessageParamUnion) asAny() any { + if !param.IsOmitted(u.OfDeveloper) { + return u.OfDeveloper + } else if !param.IsOmitted(u.OfSystem) { + return u.OfSystem + } else if !param.IsOmitted(u.OfUser) { + return u.OfUser + } else if !param.IsOmitted(u.OfAssistant) { + return u.OfAssistant + } else if !param.IsOmitted(u.OfTool) { + return u.OfTool + } else if !param.IsOmitted(u.OfFunction) { + return u.OfFunction + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionMessageParamUnion) GetAudio() *ChatCompletionAssistantMessageParamAudio { + if vt := u.OfAssistant; vt != nil { + return &vt.Audio + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionMessageParamUnion) GetFunctionCall() *ChatCompletionAssistantMessageParamFunctionCall { + if vt := u.OfAssistant; vt != nil { + return &vt.FunctionCall + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u ChatCompletionMessageParamUnion) GetRefusal() *string { + if vt := u.OfAssistant; vt != nil && vt.Refusal.Valid() { + return &vt.Refusal.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionMessageParamUnion) GetToolCalls() []ChatCompletionMessageToolCallParam { + if vt := u.OfAssistant; vt != nil { + return vt.ToolCalls + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionMessageParamUnion) GetToolCallID() *string { + if vt := u.OfTool; vt != nil { + return &vt.ToolCallID + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionMessageParamUnion) GetRole() *string { + if vt := u.OfDeveloper; vt != nil { + return (*string)(&vt.Role) + } else if vt := u.OfSystem; vt != nil { + return (*string)(&vt.Role) + } else if vt := u.OfUser; vt != nil { + return (*string)(&vt.Role) + } else if vt := u.OfAssistant; vt != nil { + return (*string)(&vt.Role) + } else if vt := u.OfTool; vt != nil { + return (*string)(&vt.Role) + } else if vt := u.OfFunction; vt != nil { + return (*string)(&vt.Role) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u ChatCompletionMessageParamUnion) GetName() *string { + if vt := u.OfDeveloper; vt != nil && vt.Name.Valid() { + return &vt.Name.Value + } else if vt := u.OfSystem; vt != nil && vt.Name.Valid() { + return &vt.Name.Value + } else if vt := u.OfUser; vt != nil && vt.Name.Valid() { + return &vt.Name.Value + } else if vt := u.OfAssistant; vt != nil && vt.Name.Valid() { + return &vt.Name.Value + } else if vt := u.OfFunction; vt != nil { + return (*string)(&vt.Name) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u ChatCompletionMessageParamUnion) GetContent() (res chatCompletionMessageParamUnionContent) { + if vt := u.OfDeveloper; vt != nil { + res.any = vt.Content.asAny() + } else if vt := u.OfSystem; vt != nil { + res.any = vt.Content.asAny() + } else if vt := u.OfUser; vt != nil { + res.any = vt.Content.asAny() + } else if vt := u.OfAssistant; vt != nil { + res.any = vt.Content.asAny() + } else if vt := u.OfTool; vt != nil { + res.any = vt.Content.asAny() + } else if vt := u.OfFunction; vt != nil && vt.Content.Valid() { + res.any = &vt.Content.Value + } + return +} + +// Can have the runtime types [*string], [_[]ChatCompletionContentPartTextParam], +// [_[]ChatCompletionContentPartUnionParam], +// [\*[]ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion] +type chatCompletionMessageParamUnionContent struct{ any } + +// Use the following switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *string: +// case *[]openai.ChatCompletionContentPartTextParam: +// case *[]openai.ChatCompletionContentPartUnionParam: +// case *[]openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion: +// default: +// fmt.Errorf("not present") +// } +func (u chatCompletionMessageParamUnionContent) AsAny() any { return u.any } + +func init() { + apijson.RegisterUnion[ChatCompletionMessageParamUnion]( + "role", + 
apijson.Discriminator[ChatCompletionDeveloperMessageParam]("developer"), + apijson.Discriminator[ChatCompletionSystemMessageParam]("system"), + apijson.Discriminator[ChatCompletionUserMessageParam]("user"), + apijson.Discriminator[ChatCompletionAssistantMessageParam]("assistant"), + apijson.Discriminator[ChatCompletionToolMessageParam]("tool"), + apijson.Discriminator[ChatCompletionFunctionMessageParam]("function"), + ) +} + +type ChatCompletionMessageToolCall struct { + // The ID of the tool call. + ID string `json:"id,required"` + // The function that the model called. + Function ChatCompletionMessageToolCallFunction `json:"function,required"` + // The type of the tool. Currently, only `function` is supported. + Type constant.Function `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Function respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionMessageToolCall) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionMessageToolCall) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ChatCompletionMessageToolCall to a +// ChatCompletionMessageToolCallParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ChatCompletionMessageToolCallParam.Overrides() +func (r ChatCompletionMessageToolCall) ToParam() ChatCompletionMessageToolCallParam { + return param.Override[ChatCompletionMessageToolCallParam](json.RawMessage(r.RawJSON())) +} + +// The function that the model called. +type ChatCompletionMessageToolCallFunction struct { + // The arguments to call the function with, as generated by the model in JSON + // format. 
Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments,required"` + // The name of the function to call. + Name string `json:"name,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Arguments respjson.Field + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionMessageToolCallFunction) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionMessageToolCallFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties ID, Function, Type are required. +type ChatCompletionMessageToolCallParam struct { + // The ID of the tool call. + ID string `json:"id,required"` + // The function that the model called. + Function ChatCompletionMessageToolCallFunctionParam `json:"function,omitzero,required"` + // The type of the tool. Currently, only `function` is supported. + // + // This field can be elided, and will marshal its zero value as "function". + Type constant.Function `json:"type,required"` + paramObj +} + +func (r ChatCompletionMessageToolCallParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionMessageToolCallParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionMessageToolCallParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The function that the model called. +// +// The properties Arguments, Name are required. +type ChatCompletionMessageToolCallFunctionParam struct { + // The arguments to call the function with, as generated by the model in JSON + // format. 
Note that the model does not always generate valid JSON, and may + // hallucinate parameters not defined by your function schema. Validate the + // arguments in your code before calling your function. + Arguments string `json:"arguments,required"` + // The name of the function to call. + Name string `json:"name,required"` + paramObj +} + +func (r ChatCompletionMessageToolCallFunctionParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionMessageToolCallFunctionParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionMessageToolCallFunctionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Specifies a tool the model should use. Use to force the model to call a specific +// function. +// +// The properties Function, Type are required. +type ChatCompletionNamedToolChoiceParam struct { + Function ChatCompletionNamedToolChoiceFunctionParam `json:"function,omitzero,required"` + // The type of the tool. Currently, only `function` is supported. + // + // This field can be elided, and will marshal its zero value as "function". + Type constant.Function `json:"type,required"` + paramObj +} + +func (r ChatCompletionNamedToolChoiceParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNamedToolChoiceParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNamedToolChoiceParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The property Name is required. +type ChatCompletionNamedToolChoiceFunctionParam struct { + // The name of the function to call. 
+ Name string `json:"name,required"` + paramObj +} + +func (r ChatCompletionNamedToolChoiceFunctionParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNamedToolChoiceFunctionParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNamedToolChoiceFunctionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Static predicted output content, such as the content of a text file that is +// being regenerated. +// +// The properties Content, Type are required. +type ChatCompletionPredictionContentParam struct { + // The content that should be matched when generating a model response. If + // generated tokens would match this content, the entire model response can be + // returned much more quickly. + Content ChatCompletionPredictionContentContentUnionParam `json:"content,omitzero,required"` + // The type of the predicted content you want to provide. This type is currently + // always `content`. + // + // This field can be elided, and will marshal its zero value as "content". + Type constant.Content `json:"type,required"` + paramObj +} + +func (r ChatCompletionPredictionContentParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionPredictionContentParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionPredictionContentParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionPredictionContentContentUnionParam struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []ChatCompletionContentPartTextParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionPredictionContentContentUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *ChatCompletionPredictionContentContentUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionPredictionContentContentUnionParam) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +// A chat completion message generated by the model. +type ChatCompletionStoreMessage struct { + // The identifier of the chat message. + ID string `json:"id,required"` + // If a content parts array was provided, this is an array of `text` and + // `image_url` parts. Otherwise, null. + ContentParts []ChatCompletionStoreMessageContentPartUnion `json:"content_parts,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + ContentParts respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` + ChatCompletionMessage +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionStoreMessage) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionStoreMessage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ChatCompletionStoreMessageContentPartUnion contains all possible properties and +// values from [ChatCompletionContentPartText], [ChatCompletionContentPartImage]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. 
+type ChatCompletionStoreMessageContentPartUnion struct { + // This field is from variant [ChatCompletionContentPartText]. + Text string `json:"text"` + Type string `json:"type"` + // This field is from variant [ChatCompletionContentPartImage]. + ImageURL ChatCompletionContentPartImageImageURL `json:"image_url"` + JSON struct { + Text respjson.Field + Type respjson.Field + ImageURL respjson.Field + raw string + } `json:"-"` +} + +func (u ChatCompletionStoreMessageContentPartUnion) AsTextContentPart() (v ChatCompletionContentPartText) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ChatCompletionStoreMessageContentPartUnion) AsImageContentPart() (v ChatCompletionContentPartImage) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ChatCompletionStoreMessageContentPartUnion) RawJSON() string { return u.JSON.raw } + +func (r *ChatCompletionStoreMessageContentPartUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Options for streaming response. Only set this when you set `stream: true`. +type ChatCompletionStreamOptionsParam struct { + // If set, an additional chunk will be streamed before the `data: [DONE]` message. + // The `usage` field on this chunk shows the token usage statistics for the entire + // request, and the `choices` field will always be an empty array. + // + // All other chunks will also include a `usage` field, but with a null value. + // **NOTE:** If the stream is interrupted, you may not receive the final usage + // chunk which contains the total token usage for the request. 
+ IncludeUsage param.Opt[bool] `json:"include_usage,omitzero"` + paramObj +} + +func (r ChatCompletionStreamOptionsParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionStreamOptionsParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionStreamOptionsParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Developer-provided instructions that the model should follow, regardless of +// messages sent by the user. With o1 models and newer, use `developer` messages +// for this purpose instead. +// +// The properties Content, Role are required. +type ChatCompletionSystemMessageParam struct { + // The contents of the system message. + Content ChatCompletionSystemMessageParamContentUnion `json:"content,omitzero,required"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name param.Opt[string] `json:"name,omitzero"` + // The role of the messages author, in this case `system`. + // + // This field can be elided, and will marshal its zero value as "system". + Role constant.System `json:"role,required"` + paramObj +} + +func (r ChatCompletionSystemMessageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionSystemMessageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionSystemMessageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionSystemMessageParamContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []ChatCompletionContentPartTextParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionSystemMessageParamContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *ChatCompletionSystemMessageParamContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionSystemMessageParamContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +type ChatCompletionTokenLogprob struct { + // The token. + Token string `json:"token,required"` + // A list of integers representing the UTF-8 bytes representation of the token. + // Useful in instances where characters are represented by multiple tokens and + // their byte representations must be combined to generate the correct text + // representation. Can be `null` if there is no bytes representation for the token. + Bytes []int64 `json:"bytes,required"` + // The log probability of this token, if it is within the top 20 most likely + // tokens. Otherwise, the value `-9999.0` is used to signify that the token is very + // unlikely. + Logprob float64 `json:"logprob,required"` + // List of the most likely tokens and their log probability, at this token + // position. In rare cases, there may be fewer than the number of requested + // `top_logprobs` returned. + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Token respjson.Field + Bytes respjson.Field + Logprob respjson.Field + TopLogprobs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionTokenLogprob) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionTokenLogprob) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionTokenLogprobTopLogprob struct { + // The token. + Token string `json:"token,required"` + // A list of integers representing the UTF-8 bytes representation of the token. + // Useful in instances where characters are represented by multiple tokens and + // their byte representations must be combined to generate the correct text + // representation. Can be `null` if there is no bytes representation for the token. + Bytes []int64 `json:"bytes,required"` + // The log probability of this token, if it is within the top 20 most likely + // tokens. Otherwise, the value `-9999.0` is used to signify that the token is very + // unlikely. + Logprob float64 `json:"logprob,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Token respjson.Field + Bytes respjson.Field + Logprob respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ChatCompletionTokenLogprobTopLogprob) RawJSON() string { return r.JSON.raw } +func (r *ChatCompletionTokenLogprobTopLogprob) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties Function, Type are required. +type ChatCompletionToolParam struct { + Function shared.FunctionDefinitionParam `json:"function,omitzero,required"` + // The type of the tool. Currently, only `function` is supported. + // + // This field can be elided, and will marshal its zero value as "function". 
+ Type constant.Function `json:"type,required"` + paramObj +} + +func (r ChatCompletionToolParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionToolParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionToolParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func ChatCompletionToolChoiceOptionParamOfChatCompletionNamedToolChoice(function ChatCompletionNamedToolChoiceFunctionParam) ChatCompletionToolChoiceOptionUnionParam { + var variant ChatCompletionNamedToolChoiceParam + variant.Function = function + return ChatCompletionToolChoiceOptionUnionParam{OfChatCompletionNamedToolChoice: &variant} +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ChatCompletionToolChoiceOptionUnionParam struct { + // Check if union is this variant with !param.IsOmitted(union.OfAuto) + OfAuto param.Opt[string] `json:",omitzero,inline"` + OfChatCompletionNamedToolChoice *ChatCompletionNamedToolChoiceParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionToolChoiceOptionUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfChatCompletionNamedToolChoice) +} +func (u *ChatCompletionToolChoiceOptionUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionToolChoiceOptionUnionParam) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfChatCompletionNamedToolChoice) { + return u.OfChatCompletionNamedToolChoice + } + return nil +} + +// `none` means the model will not call any tool and instead generates a message. +// `auto` means the model can pick between generating a message or calling one or +// more tools. `required` means the model must call one or more tools. 
+type ChatCompletionToolChoiceOptionAuto string + +const ( + ChatCompletionToolChoiceOptionAutoNone ChatCompletionToolChoiceOptionAuto = "none" + ChatCompletionToolChoiceOptionAutoAuto ChatCompletionToolChoiceOptionAuto = "auto" + ChatCompletionToolChoiceOptionAutoRequired ChatCompletionToolChoiceOptionAuto = "required" +) + +// The properties Content, Role, ToolCallID are required. +type ChatCompletionToolMessageParam struct { + // The contents of the tool message. + Content ChatCompletionToolMessageParamContentUnion `json:"content,omitzero,required"` + // Tool call that this message is responding to. + ToolCallID string `json:"tool_call_id,required"` + // The role of the messages author, in this case `tool`. + // + // This field can be elided, and will marshal its zero value as "tool". + Role constant.Tool `json:"role,required"` + paramObj +} + +func (r ChatCompletionToolMessageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionToolMessageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionToolMessageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionToolMessageParamContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []ChatCompletionContentPartTextParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionToolMessageParamContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *ChatCompletionToolMessageParamContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionToolMessageParamContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +// Messages sent by an end user, containing prompts or additional context +// information. +// +// The properties Content, Role are required. +type ChatCompletionUserMessageParam struct { + // The contents of the user message. + Content ChatCompletionUserMessageParamContentUnion `json:"content,omitzero,required"` + // An optional name for the participant. Provides the model information to + // differentiate between participants of the same role. + Name param.Opt[string] `json:"name,omitzero"` + // The role of the messages author, in this case `user`. + // + // This field can be elided, and will marshal its zero value as "user". + Role constant.User `json:"role,required"` + paramObj +} + +func (r ChatCompletionUserMessageParam) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionUserMessageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionUserMessageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionUserMessageParamContentUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfContentParts []ChatCompletionContentPartUnionParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionUserMessageParamContentUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts) +} +func (u *ChatCompletionUserMessageParamContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionUserMessageParamContentUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfContentParts) { + return &u.OfArrayOfContentParts + } + return nil +} + +type ChatCompletionNewParams struct { + // A list of messages comprising the conversation so far. Depending on the + // [model](https://platform.openai.com/docs/models) you use, different message + // types (modalities) are supported, like + // [text](https://platform.openai.com/docs/guides/text-generation), + // [images](https://platform.openai.com/docs/guides/vision), and + // [audio](https://platform.openai.com/docs/guides/audio). + Messages []ChatCompletionMessageParamUnion `json:"messages,omitzero,required"` + // Model ID used to generate the response, like `gpt-4o` or `o3`. OpenAI offers a + // wide range of models with different capabilities, performance characteristics, + // and price points. Refer to the + // [model guide](https://platform.openai.com/docs/models) to browse and compare + // available models. + Model shared.ChatModel `json:"model,omitzero,required"` + // Number between -2.0 and 2.0. Positive values penalize new tokens based on their + // existing frequency in the text so far, decreasing the model's likelihood to + // repeat the same line verbatim. + FrequencyPenalty param.Opt[float64] `json:"frequency_penalty,omitzero"` + // Whether to return log probabilities of the output tokens or not. 
If true, + // returns the log probabilities of each output token returned in the `content` of + // `message`. + Logprobs param.Opt[bool] `json:"logprobs,omitzero"` + // An upper bound for the number of tokens that can be generated for a completion, + // including visible output tokens and + // [reasoning tokens](https://platform.openai.com/docs/guides/reasoning). + MaxCompletionTokens param.Opt[int64] `json:"max_completion_tokens,omitzero"` + // The maximum number of [tokens](/tokenizer) that can be generated in the chat + // completion. This value can be used to control + // [costs](https://openai.com/api/pricing/) for text generated via API. + // + // This value is now deprecated in favor of `max_completion_tokens`, and is not + // compatible with + // [o-series models](https://platform.openai.com/docs/guides/reasoning). + MaxTokens param.Opt[int64] `json:"max_tokens,omitzero"` + // How many chat completion choices to generate for each input message. Note that + // you will be charged based on the number of generated tokens across all of the + // choices. Keep `n` as `1` to minimize costs. + N param.Opt[int64] `json:"n,omitzero"` + // Number between -2.0 and 2.0. Positive values penalize new tokens based on + // whether they appear in the text so far, increasing the model's likelihood to + // talk about new topics. + PresencePenalty param.Opt[float64] `json:"presence_penalty,omitzero"` + // This feature is in Beta. If specified, our system will make a best effort to + // sample deterministically, such that repeated requests with the same `seed` and + // parameters should return the same result. Determinism is not guaranteed, and you + // should refer to the `system_fingerprint` response parameter to monitor changes + // in the backend. 
+ Seed param.Opt[int64] `json:"seed,omitzero"` + // Whether or not to store the output of this chat completion request for use in + // our [model distillation](https://platform.openai.com/docs/guides/distillation) + // or [evals](https://platform.openai.com/docs/guides/evals) products. + // + // Supports text and image inputs. Note: image inputs over 10MB will be dropped. + Store param.Opt[bool] `json:"store,omitzero"` + // What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + // make the output more random, while lower values like 0.2 will make it more + // focused and deterministic. We generally recommend altering this or `top_p` but + // not both. + Temperature param.Opt[float64] `json:"temperature,omitzero"` + // An integer between 0 and 20 specifying the number of most likely tokens to + // return at each token position, each with an associated log probability. + // `logprobs` must be set to `true` if this parameter is used. + TopLogprobs param.Opt[int64] `json:"top_logprobs,omitzero"` + // An alternative to sampling with temperature, called nucleus sampling, where the + // model considers the results of the tokens with top_p probability mass. So 0.1 + // means only the tokens comprising the top 10% probability mass are considered. + // + // We generally recommend altering this or `temperature` but not both. + TopP param.Opt[float64] `json:"top_p,omitzero"` + // Whether to enable + // [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling) + // during tool use. + ParallelToolCalls param.Opt[bool] `json:"parallel_tool_calls,omitzero"` + // Used by OpenAI to cache responses for similar requests to optimize your cache + // hit rates. Replaces the `user` field. + // [Learn more](https://platform.openai.com/docs/guides/prompt-caching). 
+ PromptCacheKey param.Opt[string] `json:"prompt_cache_key,omitzero"` + // A stable identifier used to help detect users of your application that may be + // violating OpenAI's usage policies. The IDs should be a string that uniquely + // identifies each user. We recommend hashing their username or email address, in + // order to avoid sending us any identifying information. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers). + SafetyIdentifier param.Opt[string] `json:"safety_identifier,omitzero"` + // This field is being replaced by `safety_identifier` and `prompt_cache_key`. Use + // `prompt_cache_key` instead to maintain caching optimizations. A stable + // identifier for your end-users. Used to boost cache hit rates by better bucketing + // similar requests and to help OpenAI detect and prevent abuse. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers). + User param.Opt[string] `json:"user,omitzero"` + // Parameters for audio output. Required when audio output is requested with + // `modalities: ["audio"]`. + // [Learn more](https://platform.openai.com/docs/guides/audio). + Audio ChatCompletionAudioParam `json:"audio,omitzero"` + // Modify the likelihood of specified tokens appearing in the completion. + // + // Accepts a JSON object that maps tokens (specified by their token ID in the + // tokenizer) to an associated bias value from -100 to 100. Mathematically, the + // bias is added to the logits generated by the model prior to sampling. The exact + // effect will vary per model, but values between -1 and 1 should decrease or + // increase likelihood of selection; values like -100 or 100 should result in a ban + // or exclusive selection of the relevant token. + LogitBias map[string]int64 `json:"logit_bias,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. 
This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + // Output types that you would like the model to generate. Most models are capable + // of generating text, which is the default: + // + // `["text"]` + // + // The `gpt-4o-audio-preview` model can also be used to + // [generate audio](https://platform.openai.com/docs/guides/audio). To request that + // this model generate both text and audio responses, you can use: + // + // `["text", "audio"]` + // + // Any of "text", "audio". + Modalities []string `json:"modalities,omitzero"` + // **o-series models only** + // + // Constrains effort on reasoning for + // [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently + // supported values are `low`, `medium`, and `high`. Reducing reasoning effort can + // result in faster responses and fewer tokens used on reasoning in a response. + // + // Any of "low", "medium", "high". + ReasoningEffort shared.ReasoningEffort `json:"reasoning_effort,omitzero"` + // Specifies the processing type used for serving the request. + // + // - If set to 'auto', then the request will be processed with the service tier + // configured in the Project settings. Unless otherwise configured, the Project + // will use 'default'. + // - If set to 'default', then the request will be processed with the standard + // pricing and performance for the selected model. + // - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or + // 'priority', then the request will be processed with the corresponding service + // tier. [Contact sales](https://openai.com/contact-sales) to learn more about + // Priority processing. + // - When not set, the default behavior is 'auto'. 
+ // + // When the `service_tier` parameter is set, the response body will include the + // `service_tier` value based on the processing mode actually used to serve the + // request. This response value may be different from the value set in the + // parameter. + // + // Any of "auto", "default", "flex", "scale", "priority". + ServiceTier ChatCompletionNewParamsServiceTier `json:"service_tier,omitzero"` + // Not supported with latest reasoning models `o3` and `o4-mini`. + // + // Up to 4 sequences where the API will stop generating further tokens. The + // returned text will not contain the stop sequence. + Stop ChatCompletionNewParamsStopUnion `json:"stop,omitzero"` + // Options for streaming response. Only set this when you set `stream: true`. + StreamOptions ChatCompletionStreamOptionsParam `json:"stream_options,omitzero"` + // Deprecated in favor of `tool_choice`. + // + // Controls which (if any) function is called by the model. + // + // `none` means the model will not call a function and instead generates a message. + // + // `auto` means the model can pick between generating a message or calling a + // function. + // + // Specifying a particular function via `{"name": "my_function"}` forces the model + // to call that function. + // + // `none` is the default when no functions are present. `auto` is the default if + // functions are present. + FunctionCall ChatCompletionNewParamsFunctionCallUnion `json:"function_call,omitzero"` + // Deprecated in favor of `tools`. + // + // A list of functions the model may generate JSON inputs for. + Functions []ChatCompletionNewParamsFunction `json:"functions,omitzero"` + // Static predicted output content, such as the content of a text file that is + // being regenerated. + Prediction ChatCompletionPredictionContentParam `json:"prediction,omitzero"` + // An object specifying the format that the model must output. 
+ // + // Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured + // Outputs which ensures the model will match your supplied JSON schema. Learn more + // in the + // [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + // + // Setting to `{ "type": "json_object" }` enables the older JSON mode, which + // ensures the message the model generates is valid JSON. Using `json_schema` is + // preferred for models that support it. + ResponseFormat ChatCompletionNewParamsResponseFormatUnion `json:"response_format,omitzero"` + // Controls which (if any) tool is called by the model. `none` means the model will + // not call any tool and instead generates a message. `auto` means the model can + // pick between generating a message or calling one or more tools. `required` means + // the model must call one or more tools. Specifying a particular tool via + // `{"type": "function", "function": {"name": "my_function"}}` forces the model to + // call that tool. + // + // `none` is the default when no tools are present. `auto` is the default if tools + // are present. + ToolChoice ChatCompletionToolChoiceOptionUnionParam `json:"tool_choice,omitzero"` + // A list of tools the model may call. Currently, only functions are supported as a + // tool. Use this to provide a list of functions the model may generate JSON inputs + // for. A max of 128 functions are supported. + Tools []ChatCompletionToolParam `json:"tools,omitzero"` + // This tool searches the web for relevant results to use in a response. Learn more + // about the + // [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat). 
+ WebSearchOptions ChatCompletionNewParamsWebSearchOptions `json:"web_search_options,omitzero"` + paramObj +} + +func (r ChatCompletionNewParams) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ChatCompletionNewParamsFunctionCallUnion struct { + // Check if union is this variant with !param.IsOmitted(union.OfFunctionCallMode) + OfFunctionCallMode param.Opt[string] `json:",omitzero,inline"` + OfFunctionCallOption *ChatCompletionFunctionCallOptionParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionNewParamsFunctionCallUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfFunctionCallMode, u.OfFunctionCallOption) +} +func (u *ChatCompletionNewParamsFunctionCallUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionNewParamsFunctionCallUnion) asAny() any { + if !param.IsOmitted(u.OfFunctionCallMode) { + return &u.OfFunctionCallMode + } else if !param.IsOmitted(u.OfFunctionCallOption) { + return u.OfFunctionCallOption + } + return nil +} + +// `none` means the model will not call a function and instead generates a message. +// `auto` means the model can pick between generating a message or calling a +// function. +type ChatCompletionNewParamsFunctionCallFunctionCallMode string + +const ( + ChatCompletionNewParamsFunctionCallFunctionCallModeNone ChatCompletionNewParamsFunctionCallFunctionCallMode = "none" + ChatCompletionNewParamsFunctionCallFunctionCallModeAuto ChatCompletionNewParamsFunctionCallFunctionCallMode = "auto" +) + +// Deprecated: deprecated +// +// The property Name is required. +type ChatCompletionNewParamsFunction struct { + // The name of the function to be called. 
Must be a-z, A-Z, 0-9, or contain + // underscores and dashes, with a maximum length of 64. + Name string `json:"name,required"` + // A description of what the function does, used by the model to choose when and + // how to call the function. + Description param.Opt[string] `json:"description,omitzero"` + // The parameters the functions accepts, described as a JSON Schema object. See the + // [guide](https://platform.openai.com/docs/guides/function-calling) for examples, + // and the + // [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for + // documentation about the format. + // + // Omitting `parameters` defines a function with an empty parameter list. + Parameters shared.FunctionParameters `json:"parameters,omitzero"` + paramObj +} + +func (r ChatCompletionNewParamsFunction) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNewParamsFunction + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNewParamsFunction) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ChatCompletionNewParamsResponseFormatUnion struct { + OfText *shared.ResponseFormatTextParam `json:",omitzero,inline"` + OfJSONSchema *shared.ResponseFormatJSONSchemaParam `json:",omitzero,inline"` + OfJSONObject *shared.ResponseFormatJSONObjectParam `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionNewParamsResponseFormatUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfText, u.OfJSONSchema, u.OfJSONObject) +} +func (u *ChatCompletionNewParamsResponseFormatUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionNewParamsResponseFormatUnion) asAny() any { + if !param.IsOmitted(u.OfText) { + return u.OfText + } else if !param.IsOmitted(u.OfJSONSchema) { + return u.OfJSONSchema + } else if !param.IsOmitted(u.OfJSONObject) { + return u.OfJSONObject + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionNewParamsResponseFormatUnion) GetJSONSchema() *shared.ResponseFormatJSONSchemaJSONSchemaParam { + if vt := u.OfJSONSchema; vt != nil { + return &vt.JSONSchema + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ChatCompletionNewParamsResponseFormatUnion) GetType() *string { + if vt := u.OfText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfJSONSchema; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfJSONObject; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// Specifies the processing type used for serving the request. +// +// - If set to 'auto', then the request will be processed with the service tier +// configured in the Project settings. Unless otherwise configured, the Project +// will use 'default'. +// - If set to 'default', then the request will be processed with the standard +// pricing and performance for the selected model. 
+// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or +// 'priority', then the request will be processed with the corresponding service +// tier. [Contact sales](https://openai.com/contact-sales) to learn more about +// Priority processing. +// - When not set, the default behavior is 'auto'. +// +// When the `service_tier` parameter is set, the response body will include the +// `service_tier` value based on the processing mode actually used to serve the +// request. This response value may be different from the value set in the +// parameter. +type ChatCompletionNewParamsServiceTier string + +const ( + ChatCompletionNewParamsServiceTierAuto ChatCompletionNewParamsServiceTier = "auto" + ChatCompletionNewParamsServiceTierDefault ChatCompletionNewParamsServiceTier = "default" + ChatCompletionNewParamsServiceTierFlex ChatCompletionNewParamsServiceTier = "flex" + ChatCompletionNewParamsServiceTierScale ChatCompletionNewParamsServiceTier = "scale" + ChatCompletionNewParamsServiceTierPriority ChatCompletionNewParamsServiceTier = "priority" +) + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ChatCompletionNewParamsStopUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfStringArray []string `json:",omitzero,inline"` + paramUnion +} + +func (u ChatCompletionNewParamsStopUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfStringArray) +} +func (u *ChatCompletionNewParamsStopUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ChatCompletionNewParamsStopUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfStringArray) { + return &u.OfStringArray + } + return nil +} + +// This tool searches the web for relevant results to use in a response. 
Learn more +// about the +// [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat). +type ChatCompletionNewParamsWebSearchOptions struct { + // Approximate location parameters for the search. + UserLocation ChatCompletionNewParamsWebSearchOptionsUserLocation `json:"user_location,omitzero"` + // High level guidance for the amount of context window space to use for the + // search. One of `low`, `medium`, or `high`. `medium` is the default. + // + // Any of "low", "medium", "high". + SearchContextSize string `json:"search_context_size,omitzero"` + paramObj +} + +func (r ChatCompletionNewParamsWebSearchOptions) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNewParamsWebSearchOptions + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNewParamsWebSearchOptions) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[ChatCompletionNewParamsWebSearchOptions]( + "search_context_size", "low", "medium", "high", + ) +} + +// Approximate location parameters for the search. +// +// The properties Approximate, Type are required. +type ChatCompletionNewParamsWebSearchOptionsUserLocation struct { + // Approximate location parameters for the search. + Approximate ChatCompletionNewParamsWebSearchOptionsUserLocationApproximate `json:"approximate,omitzero,required"` + // The type of location approximation. Always `approximate`. + // + // This field can be elided, and will marshal its zero value as "approximate". 
+ Type constant.Approximate `json:"type,required"` + paramObj +} + +func (r ChatCompletionNewParamsWebSearchOptionsUserLocation) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNewParamsWebSearchOptionsUserLocation + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNewParamsWebSearchOptionsUserLocation) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Approximate location parameters for the search. +type ChatCompletionNewParamsWebSearchOptionsUserLocationApproximate struct { + // Free text input for the city of the user, e.g. `San Francisco`. + City param.Opt[string] `json:"city,omitzero"` + // The two-letter [ISO country code](https://en.wikipedia.org/wiki/ISO_3166-1) of + // the user, e.g. `US`. + Country param.Opt[string] `json:"country,omitzero"` + // Free text input for the region of the user, e.g. `California`. + Region param.Opt[string] `json:"region,omitzero"` + // The [IANA timezone](https://timeapi.io/documentation/iana-timezones) of the + // user, e.g. `America/Los_Angeles`. + Timezone param.Opt[string] `json:"timezone,omitzero"` + paramObj +} + +func (r ChatCompletionNewParamsWebSearchOptionsUserLocationApproximate) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionNewParamsWebSearchOptionsUserLocationApproximate + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionNewParamsWebSearchOptionsUserLocationApproximate) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionUpdateParams struct { + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. 
+ Metadata shared.Metadata `json:"metadata,omitzero,required"` + paramObj +} + +func (r ChatCompletionUpdateParams) MarshalJSON() (data []byte, err error) { + type shadow ChatCompletionUpdateParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ChatCompletionUpdateParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ChatCompletionListParams struct { + // Identifier for the last chat completion from the previous pagination request. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // Number of Chat Completions to retrieve. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // The model used to generate the Chat Completions. + Model param.Opt[string] `query:"model,omitzero" json:"-"` + // A list of metadata keys to filter the Chat Completions by. Example: + // + // `metadata[key1]=value1&metadata[key2]=value2` + Metadata shared.Metadata `query:"metadata,omitzero" json:"-"` + // Sort order for Chat Completions by timestamp. Use `asc` for ascending order or + // `desc` for descending order. Defaults to `asc`. + // + // Any of "asc", "desc". + Order ChatCompletionListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [ChatCompletionListParams]'s query parameters as +// `url.Values`. +func (r ChatCompletionListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order for Chat Completions by timestamp. Use `asc` for ascending order or +// `desc` for descending order. Defaults to `asc`. 
+type ChatCompletionListParamsOrder string + +const ( + ChatCompletionListParamsOrderAsc ChatCompletionListParamsOrder = "asc" + ChatCompletionListParamsOrderDesc ChatCompletionListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/chatcompletionmessage.go b/vendor/github.com/openai/openai-go/chatcompletionmessage.go new file mode 100644 index 0000000000..4b44d41684 --- /dev/null +++ b/vendor/github.com/openai/openai-go/chatcompletionmessage.go @@ -0,0 +1,96 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" +) + +// ChatCompletionMessageService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewChatCompletionMessageService] method instead. +type ChatCompletionMessageService struct { + Options []option.RequestOption +} + +// NewChatCompletionMessageService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewChatCompletionMessageService(opts ...option.RequestOption) (r ChatCompletionMessageService) { + r = ChatCompletionMessageService{} + r.Options = opts + return +} + +// Get the messages in a stored chat completion. Only Chat Completions that have +// been created with the `store` parameter set to `true` will be returned. 
+func (r *ChatCompletionMessageService) List(ctx context.Context, completionID string, query ChatCompletionMessageListParams, opts ...option.RequestOption) (res *pagination.CursorPage[ChatCompletionStoreMessage], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + if completionID == "" { + err = errors.New("missing required completion_id parameter") + return + } + path := fmt.Sprintf("chat/completions/%s/messages", completionID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Get the messages in a stored chat completion. Only Chat Completions that have +// been created with the `store` parameter set to `true` will be returned. +func (r *ChatCompletionMessageService) ListAutoPaging(ctx context.Context, completionID string, query ChatCompletionMessageListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[ChatCompletionStoreMessage] { + return pagination.NewCursorPageAutoPager(r.List(ctx, completionID, query, opts...)) +} + +type ChatCompletionMessageListParams struct { + // Identifier for the last message from the previous pagination request. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // Number of messages to retrieve. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Sort order for messages by timestamp. Use `asc` for ascending order or `desc` + // for descending order. Defaults to `asc`. + // + // Any of "asc", "desc". + Order ChatCompletionMessageListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [ChatCompletionMessageListParams]'s query parameters as +// `url.Values`. 
+func (r ChatCompletionMessageListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order for messages by timestamp. Use `asc` for ascending order or `desc` +// for descending order. Defaults to `asc`. +type ChatCompletionMessageListParamsOrder string + +const ( + ChatCompletionMessageListParamsOrderAsc ChatCompletionMessageListParamsOrder = "asc" + ChatCompletionMessageListParamsOrderDesc ChatCompletionMessageListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/client.go b/vendor/github.com/openai/openai-go/client.go new file mode 100644 index 0000000000..3d78d86b4e --- /dev/null +++ b/vendor/github.com/openai/openai-go/client.go @@ -0,0 +1,161 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "net/http" + "os" + + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/responses" + "github.com/openai/openai-go/webhooks" +) + +// Client creates a struct with services and top level methods that help with +// interacting with the openai API. You should not instantiate this client +// directly, and instead use the [NewClient] method instead. 
+type Client struct { + Options []option.RequestOption + Completions CompletionService + Chat ChatService + Embeddings EmbeddingService + Files FileService + Images ImageService + Audio AudioService + Moderations ModerationService + Models ModelService + FineTuning FineTuningService + Graders GraderService + VectorStores VectorStoreService + Webhooks webhooks.WebhookService + Beta BetaService + Batches BatchService + Uploads UploadService + Responses responses.ResponseService + Containers ContainerService +} + +// DefaultClientOptions read from the environment (OPENAI_API_KEY, OPENAI_ORG_ID, +// OPENAI_PROJECT_ID, OPENAI_WEBHOOK_SECRET, OPENAI_BASE_URL). This should be used +// to initialize new clients. +func DefaultClientOptions() []option.RequestOption { + defaults := []option.RequestOption{option.WithEnvironmentProduction()} + if o, ok := os.LookupEnv("OPENAI_BASE_URL"); ok { + defaults = append(defaults, option.WithBaseURL(o)) + } + if o, ok := os.LookupEnv("OPENAI_API_KEY"); ok { + defaults = append(defaults, option.WithAPIKey(o)) + } + if o, ok := os.LookupEnv("OPENAI_ORG_ID"); ok { + defaults = append(defaults, option.WithOrganization(o)) + } + if o, ok := os.LookupEnv("OPENAI_PROJECT_ID"); ok { + defaults = append(defaults, option.WithProject(o)) + } + if o, ok := os.LookupEnv("OPENAI_WEBHOOK_SECRET"); ok { + defaults = append(defaults, option.WithWebhookSecret(o)) + } + return defaults +} + +// NewClient generates a new client with the default option read from the +// environment (OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, +// OPENAI_WEBHOOK_SECRET, OPENAI_BASE_URL). The option passed in as arguments are +// applied after these default arguments, and all option will be passed down to the +// services and requests that this client makes. +func NewClient(opts ...option.RequestOption) (r Client) { + opts = append(DefaultClientOptions(), opts...) + + r = Client{Options: opts} + + r.Completions = NewCompletionService(opts...) 
+ r.Chat = NewChatService(opts...) + r.Embeddings = NewEmbeddingService(opts...) + r.Files = NewFileService(opts...) + r.Images = NewImageService(opts...) + r.Audio = NewAudioService(opts...) + r.Moderations = NewModerationService(opts...) + r.Models = NewModelService(opts...) + r.FineTuning = NewFineTuningService(opts...) + r.Graders = NewGraderService(opts...) + r.VectorStores = NewVectorStoreService(opts...) + r.Webhooks = webhooks.NewWebhookService(opts...) + r.Beta = NewBetaService(opts...) + r.Batches = NewBatchService(opts...) + r.Uploads = NewUploadService(opts...) + r.Responses = responses.NewResponseService(opts...) + r.Containers = NewContainerService(opts...) + + return +} + +// Execute makes a request with the given context, method, URL, request params, +// response, and request options. This is useful for hitting undocumented endpoints +// while retaining the base URL, auth, retries, and other options from the client. +// +// If a byte slice or an [io.Reader] is supplied to params, it will be used as-is +// for the request body. +// +// The params is by default serialized into the body using [encoding/json]. If your +// type implements a MarshalJSON function, it will be used instead to serialize the +// request. If a URLQuery method is implemented, the returned [url.Values] will be +// used as query strings to the url. +// +// If your params struct uses [param.Field], you must provide either [MarshalJSON], +// [URLQuery], and/or [MarshalForm] functions. It is undefined behavior to use a +// struct uses [param.Field] without specifying how it is serialized. +// +// Any "…Params" object defined in this library can be used as the request +// argument. Note that 'path' arguments will not be forwarded into the url. +// +// The response body will be deserialized into the res variable, depending on its +// type: +// +// - A pointer to a [*http.Response] is populated by the raw response. 
+// - A pointer to a byte array will be populated with the contents of the request +// body. +// - A pointer to any other type uses this library's default JSON decoding, which +// respects UnmarshalJSON if it is defined on the type. +// - A nil value will not read the response body. +// +// For even greater flexibility, see [option.WithResponseInto] and +// [option.WithResponseBodyInto]. +func (r *Client) Execute(ctx context.Context, method string, path string, params any, res any, opts ...option.RequestOption) error { + opts = append(r.Options, opts...) + return requestconfig.ExecuteNewRequest(ctx, method, path, params, res, opts...) +} + +// Get makes a GET request with the given URL, params, and optionally deserializes +// to a response. See [Execute] documentation on the params and response. +func (r *Client) Get(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodGet, path, params, res, opts...) +} + +// Post makes a POST request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Post(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPost, path, params, res, opts...) +} + +// Put makes a PUT request with the given URL, params, and optionally deserializes +// to a response. See [Execute] documentation on the params and response. +func (r *Client) Put(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPut, path, params, res, opts...) +} + +// Patch makes a PATCH request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. 
+func (r *Client) Patch(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPatch, path, params, res, opts...) +} + +// Delete makes a DELETE request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Delete(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodDelete, path, params, res, opts...) +} diff --git a/vendor/github.com/openai/openai-go/completion.go b/vendor/github.com/openai/openai-go/completion.go new file mode 100644 index 0000000000..72b2c5102c --- /dev/null +++ b/vendor/github.com/openai/openai-go/completion.go @@ -0,0 +1,426 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "net/http" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/packages/ssestream" + "github.com/openai/openai-go/shared/constant" +) + +// CompletionService contains methods and other services that help with interacting +// with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewCompletionService] method instead. +type CompletionService struct { + Options []option.RequestOption +} + +// NewCompletionService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. 
+func NewCompletionService(opts ...option.RequestOption) (r CompletionService) { + r = CompletionService{} + r.Options = opts + return +} + +// Creates a completion for the provided prompt and parameters. +func (r *CompletionService) New(ctx context.Context, body CompletionNewParams, opts ...option.RequestOption) (res *Completion, err error) { + opts = append(r.Options[:], opts...) + path := "completions" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Creates a completion for the provided prompt and parameters. +func (r *CompletionService) NewStreaming(ctx context.Context, body CompletionNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[Completion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...) + path := "completions" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[Completion](ssestream.NewDecoder(raw), err) +} + +// Represents a completion response from the API. Note: both the streamed and +// non-streamed response objects share the same shape (unlike the chat endpoint). +type Completion struct { + // A unique identifier for the completion. + ID string `json:"id,required"` + // The list of completion choices the model generated for the input prompt. + Choices []CompletionChoice `json:"choices,required"` + // The Unix timestamp (in seconds) of when the completion was created. + Created int64 `json:"created,required"` + // The model used for completion. + Model string `json:"model,required"` + // The object type, which is always "text_completion" + Object constant.TextCompletion `json:"object,required"` + // This fingerprint represents the backend configuration that the model runs with. 
+ // + // Can be used in conjunction with the `seed` request parameter to understand when + // backend changes have been made that might impact determinism. + SystemFingerprint string `json:"system_fingerprint"` + // Usage statistics for the completion request. + Usage CompletionUsage `json:"usage"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Choices respjson.Field + Created respjson.Field + Model respjson.Field + Object respjson.Field + SystemFingerprint respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Completion) RawJSON() string { return r.JSON.raw } +func (r *Completion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type CompletionChoice struct { + // The reason the model stopped generating tokens. This will be `stop` if the model + // hit a natural stop point or a provided stop sequence, `length` if the maximum + // number of tokens specified in the request was reached, or `content_filter` if + // content was omitted due to a flag from our content filters. + // + // Any of "stop", "length", "content_filter". + FinishReason CompletionChoiceFinishReason `json:"finish_reason,required"` + Index int64 `json:"index,required"` + Logprobs CompletionChoiceLogprobs `json:"logprobs,required"` + Text string `json:"text,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + FinishReason respjson.Field + Index respjson.Field + Logprobs respjson.Field + Text respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CompletionChoice) RawJSON() string { return r.JSON.raw } +func (r *CompletionChoice) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The reason the model stopped generating tokens. This will be `stop` if the model +// hit a natural stop point or a provided stop sequence, `length` if the maximum +// number of tokens specified in the request was reached, or `content_filter` if +// content was omitted due to a flag from our content filters. +type CompletionChoiceFinishReason string + +const ( + CompletionChoiceFinishReasonStop CompletionChoiceFinishReason = "stop" + CompletionChoiceFinishReasonLength CompletionChoiceFinishReason = "length" + CompletionChoiceFinishReasonContentFilter CompletionChoiceFinishReason = "content_filter" +) + +type CompletionChoiceLogprobs struct { + TextOffset []int64 `json:"text_offset"` + TokenLogprobs []float64 `json:"token_logprobs"` + Tokens []string `json:"tokens"` + TopLogprobs []map[string]float64 `json:"top_logprobs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + TextOffset respjson.Field + TokenLogprobs respjson.Field + Tokens respjson.Field + TopLogprobs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CompletionChoiceLogprobs) RawJSON() string { return r.JSON.raw } +func (r *CompletionChoiceLogprobs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Usage statistics for the completion request. +type CompletionUsage struct { + // Number of tokens in the generated completion. 
+ CompletionTokens int64 `json:"completion_tokens,required"` + // Number of tokens in the prompt. + PromptTokens int64 `json:"prompt_tokens,required"` + // Total number of tokens used in the request (prompt + completion). + TotalTokens int64 `json:"total_tokens,required"` + // Breakdown of tokens used in a completion. + CompletionTokensDetails CompletionUsageCompletionTokensDetails `json:"completion_tokens_details"` + // Breakdown of tokens used in the prompt. + PromptTokensDetails CompletionUsagePromptTokensDetails `json:"prompt_tokens_details"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + CompletionTokens respjson.Field + PromptTokens respjson.Field + TotalTokens respjson.Field + CompletionTokensDetails respjson.Field + PromptTokensDetails respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CompletionUsage) RawJSON() string { return r.JSON.raw } +func (r *CompletionUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Breakdown of tokens used in a completion. +type CompletionUsageCompletionTokensDetails struct { + // When using Predicted Outputs, the number of tokens in the prediction that + // appeared in the completion. + AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"` + // Audio input tokens generated by the model. + AudioTokens int64 `json:"audio_tokens"` + // Tokens generated by the model for reasoning. + ReasoningTokens int64 `json:"reasoning_tokens"` + // When using Predicted Outputs, the number of tokens in the prediction that did + // not appear in the completion. However, like reasoning tokens, these tokens are + // still counted in the total completion tokens for purposes of billing, output, + // and context window limits. 
+ RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + AcceptedPredictionTokens respjson.Field + AudioTokens respjson.Field + ReasoningTokens respjson.Field + RejectedPredictionTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CompletionUsageCompletionTokensDetails) RawJSON() string { return r.JSON.raw } +func (r *CompletionUsageCompletionTokensDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Breakdown of tokens used in the prompt. +type CompletionUsagePromptTokensDetails struct { + // Audio input tokens present in the prompt. + AudioTokens int64 `json:"audio_tokens"` + // Cached tokens present in the prompt. + CachedTokens int64 `json:"cached_tokens"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + AudioTokens respjson.Field + CachedTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CompletionUsagePromptTokensDetails) RawJSON() string { return r.JSON.raw } +func (r *CompletionUsagePromptTokensDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type CompletionNewParams struct { + // The prompt(s) to generate completions for, encoded as a string, array of + // strings, array of tokens, or array of token arrays. + // + // Note that <|endoftext|> is the document separator that the model sees during + // training, so if a prompt is not specified the model will generate as if from the + // beginning of a new document. + Prompt CompletionNewParamsPromptUnion `json:"prompt,omitzero,required"` + // ID of the model to use. 
You can use the + // [List models](https://platform.openai.com/docs/api-reference/models/list) API to + // see all of your available models, or see our + // [Model overview](https://platform.openai.com/docs/models) for descriptions of + // them. + Model CompletionNewParamsModel `json:"model,omitzero,required"` + // Generates `best_of` completions server-side and returns the "best" (the one with + // the highest log probability per token). Results cannot be streamed. + // + // When used with `n`, `best_of` controls the number of candidate completions and + // `n` specifies how many to return – `best_of` must be greater than `n`. + // + // **Note:** Because this parameter generates many completions, it can quickly + // consume your token quota. Use carefully and ensure that you have reasonable + // settings for `max_tokens` and `stop`. + BestOf param.Opt[int64] `json:"best_of,omitzero"` + // Echo back the prompt in addition to the completion + Echo param.Opt[bool] `json:"echo,omitzero"` + // Number between -2.0 and 2.0. Positive values penalize new tokens based on their + // existing frequency in the text so far, decreasing the model's likelihood to + // repeat the same line verbatim. + // + // [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation) + FrequencyPenalty param.Opt[float64] `json:"frequency_penalty,omitzero"` + // Include the log probabilities on the `logprobs` most likely output tokens, as + // well the chosen tokens. For example, if `logprobs` is 5, the API will return a + // list of the 5 most likely tokens. The API will always return the `logprob` of + // the sampled token, so there may be up to `logprobs+1` elements in the response. + // + // The maximum value for `logprobs` is 5. + Logprobs param.Opt[int64] `json:"logprobs,omitzero"` + // The maximum number of [tokens](/tokenizer) that can be generated in the + // completion. 
+ // + // The token count of your prompt plus `max_tokens` cannot exceed the model's + // context length. + // [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + // for counting tokens. + MaxTokens param.Opt[int64] `json:"max_tokens,omitzero"` + // How many completions to generate for each prompt. + // + // **Note:** Because this parameter generates many completions, it can quickly + // consume your token quota. Use carefully and ensure that you have reasonable + // settings for `max_tokens` and `stop`. + N param.Opt[int64] `json:"n,omitzero"` + // Number between -2.0 and 2.0. Positive values penalize new tokens based on + // whether they appear in the text so far, increasing the model's likelihood to + // talk about new topics. + // + // [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation) + PresencePenalty param.Opt[float64] `json:"presence_penalty,omitzero"` + // If specified, our system will make a best effort to sample deterministically, + // such that repeated requests with the same `seed` and parameters should return + // the same result. + // + // Determinism is not guaranteed, and you should refer to the `system_fingerprint` + // response parameter to monitor changes in the backend. + Seed param.Opt[int64] `json:"seed,omitzero"` + // The suffix that comes after a completion of inserted text. + // + // This parameter is only supported for `gpt-3.5-turbo-instruct`. + Suffix param.Opt[string] `json:"suffix,omitzero"` + // What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + // make the output more random, while lower values like 0.2 will make it more + // focused and deterministic. + // + // We generally recommend altering this or `top_p` but not both. 
+ Temperature param.Opt[float64] `json:"temperature,omitzero"` + // An alternative to sampling with temperature, called nucleus sampling, where the + // model considers the results of the tokens with top_p probability mass. So 0.1 + // means only the tokens comprising the top 10% probability mass are considered. + // + // We generally recommend altering this or `temperature` but not both. + TopP param.Opt[float64] `json:"top_p,omitzero"` + // A unique identifier representing your end-user, which can help OpenAI to monitor + // and detect abuse. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids). + User param.Opt[string] `json:"user,omitzero"` + // Modify the likelihood of specified tokens appearing in the completion. + // + // Accepts a JSON object that maps tokens (specified by their token ID in the GPT + // tokenizer) to an associated bias value from -100 to 100. You can use this + // [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. + // Mathematically, the bias is added to the logits generated by the model prior to + // sampling. The exact effect will vary per model, but values between -1 and 1 + // should decrease or increase likelihood of selection; values like -100 or 100 + // should result in a ban or exclusive selection of the relevant token. + // + // As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token + // from being generated. + LogitBias map[string]int64 `json:"logit_bias,omitzero"` + // Not supported with latest reasoning models `o3` and `o4-mini`. + // + // Up to 4 sequences where the API will stop generating further tokens. The + // returned text will not contain the stop sequence. + Stop CompletionNewParamsStopUnion `json:"stop,omitzero"` + // Options for streaming response. Only set this when you set `stream: true`. 
+ StreamOptions ChatCompletionStreamOptionsParam `json:"stream_options,omitzero"` + paramObj +} + +func (r CompletionNewParams) MarshalJSON() (data []byte, err error) { + type shadow CompletionNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *CompletionNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ID of the model to use. You can use the +// [List models](https://platform.openai.com/docs/api-reference/models/list) API to +// see all of your available models, or see our +// [Model overview](https://platform.openai.com/docs/models) for descriptions of +// them. +type CompletionNewParamsModel string + +const ( + CompletionNewParamsModelGPT3_5TurboInstruct CompletionNewParamsModel = "gpt-3.5-turbo-instruct" + CompletionNewParamsModelDavinci002 CompletionNewParamsModel = "davinci-002" + CompletionNewParamsModelBabbage002 CompletionNewParamsModel = "babbage-002" +) + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type CompletionNewParamsPromptUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfStrings []string `json:",omitzero,inline"` + OfArrayOfTokens []int64 `json:",omitzero,inline"` + OfArrayOfTokenArrays [][]int64 `json:",omitzero,inline"` + paramUnion +} + +func (u CompletionNewParamsPromptUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfStrings, u.OfArrayOfTokens, u.OfArrayOfTokenArrays) +} +func (u *CompletionNewParamsPromptUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *CompletionNewParamsPromptUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfStrings) { + return &u.OfArrayOfStrings + } else if !param.IsOmitted(u.OfArrayOfTokens) { + return &u.OfArrayOfTokens + } else if !param.IsOmitted(u.OfArrayOfTokenArrays) { + return &u.OfArrayOfTokenArrays + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type CompletionNewParamsStopUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfStringArray []string `json:",omitzero,inline"` + paramUnion +} + +func (u CompletionNewParamsStopUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfStringArray) +} +func (u *CompletionNewParamsStopUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *CompletionNewParamsStopUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfStringArray) { + return &u.OfStringArray + } + return nil +} diff --git a/vendor/github.com/openai/openai-go/container.go b/vendor/github.com/openai/openai-go/container.go new file mode 100644 index 0000000000..357bb988fd --- /dev/null +++ b/vendor/github.com/openai/openai-go/container.go @@ -0,0 +1,352 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" +) + +// ContainerService contains methods and other services that help with interacting +// with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewContainerService] method instead. +type ContainerService struct { + Options []option.RequestOption + Files ContainerFileService +} + +// NewContainerService generates a new service that applies the given options to +// each request. 
These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewContainerService(opts ...option.RequestOption) (r ContainerService) { + r = ContainerService{} + r.Options = opts + r.Files = NewContainerFileService(opts...) + return +} + +// Create Container +func (r *ContainerService) New(ctx context.Context, body ContainerNewParams, opts ...option.RequestOption) (res *ContainerNewResponse, err error) { + opts = append(r.Options[:], opts...) + path := "containers" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Retrieve Container +func (r *ContainerService) Get(ctx context.Context, containerID string, opts ...option.RequestOption) (res *ContainerGetResponse, err error) { + opts = append(r.Options[:], opts...) + if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + path := fmt.Sprintf("containers/%s", containerID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// List Containers +func (r *ContainerService) List(ctx context.Context, query ContainerListParams, opts ...option.RequestOption) (res *pagination.CursorPage[ContainerListResponse], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + path := "containers" + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) 
+ if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// List Containers +func (r *ContainerService) ListAutoPaging(ctx context.Context, query ContainerListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[ContainerListResponse] { + return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...)) +} + +// Delete Container +func (r *ContainerService) Delete(ctx context.Context, containerID string, opts ...option.RequestOption) (err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...) + if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + path := fmt.Sprintf("containers/%s", containerID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, nil, opts...) + return +} + +type ContainerNewResponse struct { + // Unique identifier for the container. + ID string `json:"id,required"` + // Unix timestamp (in seconds) when the container was created. + CreatedAt int64 `json:"created_at,required"` + // Name of the container. + Name string `json:"name,required"` + // The type of this object. + Object string `json:"object,required"` + // Status of the container (e.g., active, deleted). + Status string `json:"status,required"` + // The container will expire after this time period. The anchor is the reference + // point for the expiration. The minutes is the number of minutes after the anchor + // before the container expires. + ExpiresAfter ContainerNewResponseExpiresAfter `json:"expires_after"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Name respjson.Field + Object respjson.Field + Status respjson.Field + ExpiresAfter respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerNewResponse) RawJSON() string { return r.JSON.raw } +func (r *ContainerNewResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The container will expire after this time period. The anchor is the reference +// point for the expiration. The minutes is the number of minutes after the anchor +// before the container expires. +type ContainerNewResponseExpiresAfter struct { + // The reference point for the expiration. + // + // Any of "last_active_at". + Anchor string `json:"anchor"` + // The number of minutes after the anchor before the container expires. + Minutes int64 `json:"minutes"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Anchor respjson.Field + Minutes respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerNewResponseExpiresAfter) RawJSON() string { return r.JSON.raw } +func (r *ContainerNewResponseExpiresAfter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ContainerGetResponse struct { + // Unique identifier for the container. + ID string `json:"id,required"` + // Unix timestamp (in seconds) when the container was created. + CreatedAt int64 `json:"created_at,required"` + // Name of the container. + Name string `json:"name,required"` + // The type of this object. + Object string `json:"object,required"` + // Status of the container (e.g., active, deleted). + Status string `json:"status,required"` + // The container will expire after this time period. The anchor is the reference + // point for the expiration. 
The minutes is the number of minutes after the anchor + // before the container expires. + ExpiresAfter ContainerGetResponseExpiresAfter `json:"expires_after"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Name respjson.Field + Object respjson.Field + Status respjson.Field + ExpiresAfter respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerGetResponse) RawJSON() string { return r.JSON.raw } +func (r *ContainerGetResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The container will expire after this time period. The anchor is the reference +// point for the expiration. The minutes is the number of minutes after the anchor +// before the container expires. +type ContainerGetResponseExpiresAfter struct { + // The reference point for the expiration. + // + // Any of "last_active_at". + Anchor string `json:"anchor"` + // The number of minutes after the anchor before the container expires. + Minutes int64 `json:"minutes"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Anchor respjson.Field + Minutes respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerGetResponseExpiresAfter) RawJSON() string { return r.JSON.raw } +func (r *ContainerGetResponseExpiresAfter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ContainerListResponse struct { + // Unique identifier for the container. + ID string `json:"id,required"` + // Unix timestamp (in seconds) when the container was created. + CreatedAt int64 `json:"created_at,required"` + // Name of the container. + Name string `json:"name,required"` + // The type of this object. 
+ Object string `json:"object,required"` + // Status of the container (e.g., active, deleted). + Status string `json:"status,required"` + // The container will expire after this time period. The anchor is the reference + // point for the expiration. The minutes is the number of minutes after the anchor + // before the container expires. + ExpiresAfter ContainerListResponseExpiresAfter `json:"expires_after"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Name respjson.Field + Object respjson.Field + Status respjson.Field + ExpiresAfter respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerListResponse) RawJSON() string { return r.JSON.raw } +func (r *ContainerListResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The container will expire after this time period. The anchor is the reference +// point for the expiration. The minutes is the number of minutes after the anchor +// before the container expires. +type ContainerListResponseExpiresAfter struct { + // The reference point for the expiration. + // + // Any of "last_active_at". + Anchor string `json:"anchor"` + // The number of minutes after the anchor before the container expires. + Minutes int64 `json:"minutes"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Anchor respjson.Field + Minutes respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerListResponseExpiresAfter) RawJSON() string { return r.JSON.raw } +func (r *ContainerListResponseExpiresAfter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ContainerNewParams struct { + // Name of the container to create. 
+ Name string `json:"name,required"` + // Container expiration time in seconds relative to the 'anchor' time. + ExpiresAfter ContainerNewParamsExpiresAfter `json:"expires_after,omitzero"` + // IDs of files to copy to the container. + FileIDs []string `json:"file_ids,omitzero"` + paramObj +} + +func (r ContainerNewParams) MarshalJSON() (data []byte, err error) { + type shadow ContainerNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ContainerNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Container expiration time in seconds relative to the 'anchor' time. +// +// The properties Anchor, Minutes are required. +type ContainerNewParamsExpiresAfter struct { + // Time anchor for the expiration time. Currently only 'last_active_at' is + // supported. + // + // Any of "last_active_at". + Anchor string `json:"anchor,omitzero,required"` + Minutes int64 `json:"minutes,required"` + paramObj +} + +func (r ContainerNewParamsExpiresAfter) MarshalJSON() (data []byte, err error) { + type shadow ContainerNewParamsExpiresAfter + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ContainerNewParamsExpiresAfter) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[ContainerNewParamsExpiresAfter]( + "anchor", "last_active_at", + ) +} + +type ContainerListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 100, and the default is 20. 
+ Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order ContainerListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [ContainerListParams]'s query parameters as `url.Values`. +func (r ContainerListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. +type ContainerListParamsOrder string + +const ( + ContainerListParamsOrderAsc ContainerListParamsOrder = "asc" + ContainerListParamsOrderDesc ContainerListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/containerfile.go b/vendor/github.com/openai/openai-go/containerfile.go new file mode 100644 index 0000000000..bd50b5297a --- /dev/null +++ b/vendor/github.com/openai/openai-go/containerfile.go @@ -0,0 +1,286 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apiform" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared/constant" +) + +// ContainerFileService contains methods and other services that help with +// interacting with the openai API. 
+// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewContainerFileService] method instead. +type ContainerFileService struct { + Options []option.RequestOption + Content ContainerFileContentService +} + +// NewContainerFileService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewContainerFileService(opts ...option.RequestOption) (r ContainerFileService) { + r = ContainerFileService{} + r.Options = opts + r.Content = NewContainerFileContentService(opts...) + return +} + +// Create a Container File +// +// You can send either a multipart/form-data request with the raw file content, or +// a JSON request with a file ID. +func (r *ContainerFileService) New(ctx context.Context, containerID string, body ContainerFileNewParams, opts ...option.RequestOption) (res *ContainerFileNewResponse, err error) { + opts = append(r.Options[:], opts...) + if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + path := fmt.Sprintf("containers/%s/files", containerID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Retrieve Container File +func (r *ContainerFileService) Get(ctx context.Context, containerID string, fileID string, opts ...option.RequestOption) (res *ContainerFileGetResponse, err error) { + opts = append(r.Options[:], opts...) + if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + if fileID == "" { + err = errors.New("missing required file_id parameter") + return + } + path := fmt.Sprintf("containers/%s/files/%s", containerID, fileID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) 
+ return +} + +// List Container files +func (r *ContainerFileService) List(ctx context.Context, containerID string, query ContainerFileListParams, opts ...option.RequestOption) (res *pagination.CursorPage[ContainerFileListResponse], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + path := fmt.Sprintf("containers/%s/files", containerID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// List Container files +func (r *ContainerFileService) ListAutoPaging(ctx context.Context, containerID string, query ContainerFileListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[ContainerFileListResponse] { + return pagination.NewCursorPageAutoPager(r.List(ctx, containerID, query, opts...)) +} + +// Delete Container File +func (r *ContainerFileService) Delete(ctx context.Context, containerID string, fileID string, opts ...option.RequestOption) (err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...) + if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + if fileID == "" { + err = errors.New("missing required file_id parameter") + return + } + path := fmt.Sprintf("containers/%s/files/%s", containerID, fileID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, nil, opts...) + return +} + +type ContainerFileNewResponse struct { + // Unique identifier for the file. + ID string `json:"id,required"` + // Size of the file in bytes. + Bytes int64 `json:"bytes,required"` + // The container this file belongs to. 
+ ContainerID string `json:"container_id,required"` + // Unix timestamp (in seconds) when the file was created. + CreatedAt int64 `json:"created_at,required"` + // The type of this object (`container.file`). + Object constant.ContainerFile `json:"object,required"` + // Path of the file in the container. + Path string `json:"path,required"` + // Source of the file (e.g., `user`, `assistant`). + Source string `json:"source,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Bytes respjson.Field + ContainerID respjson.Field + CreatedAt respjson.Field + Object respjson.Field + Path respjson.Field + Source respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerFileNewResponse) RawJSON() string { return r.JSON.raw } +func (r *ContainerFileNewResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ContainerFileGetResponse struct { + // Unique identifier for the file. + ID string `json:"id,required"` + // Size of the file in bytes. + Bytes int64 `json:"bytes,required"` + // The container this file belongs to. + ContainerID string `json:"container_id,required"` + // Unix timestamp (in seconds) when the file was created. + CreatedAt int64 `json:"created_at,required"` + // The type of this object (`container.file`). + Object constant.ContainerFile `json:"object,required"` + // Path of the file in the container. + Path string `json:"path,required"` + // Source of the file (e.g., `user`, `assistant`). + Source string `json:"source,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + Bytes respjson.Field + ContainerID respjson.Field + CreatedAt respjson.Field + Object respjson.Field + Path respjson.Field + Source respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerFileGetResponse) RawJSON() string { return r.JSON.raw } +func (r *ContainerFileGetResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ContainerFileListResponse struct { + // Unique identifier for the file. + ID string `json:"id,required"` + // Size of the file in bytes. + Bytes int64 `json:"bytes,required"` + // The container this file belongs to. + ContainerID string `json:"container_id,required"` + // Unix timestamp (in seconds) when the file was created. + CreatedAt int64 `json:"created_at,required"` + // The type of this object (`container.file`). + Object constant.ContainerFile `json:"object,required"` + // Path of the file in the container. + Path string `json:"path,required"` + // Source of the file (e.g., `user`, `assistant`). + Source string `json:"source,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Bytes respjson.Field + ContainerID respjson.Field + CreatedAt respjson.Field + Object respjson.Field + Path respjson.Field + Source respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ContainerFileListResponse) RawJSON() string { return r.JSON.raw } +func (r *ContainerFileListResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ContainerFileNewParams struct { + // Name of the file to create. + FileID param.Opt[string] `json:"file_id,omitzero"` + // The File object (not file name) to be uploaded. 
+ File io.Reader `json:"file,omitzero" format:"binary"` + paramObj +} + +func (r ContainerFileNewParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err == nil { + err = apiform.WriteExtras(writer, r.ExtraFields()) + } + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +type ContainerFileListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 100, and the default is 20. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order ContainerFileListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [ContainerFileListParams]'s query parameters as +// `url.Values`. +func (r ContainerFileListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. 
+type ContainerFileListParamsOrder string + +const ( + ContainerFileListParamsOrderAsc ContainerFileListParamsOrder = "asc" + ContainerFileListParamsOrderDesc ContainerFileListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/containerfilecontent.go b/vendor/github.com/openai/openai-go/containerfilecontent.go new file mode 100644 index 0000000000..0fb0fa9f1f --- /dev/null +++ b/vendor/github.com/openai/openai-go/containerfilecontent.go @@ -0,0 +1,49 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" +) + +// ContainerFileContentService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewContainerFileContentService] method instead. +type ContainerFileContentService struct { + Options []option.RequestOption +} + +// NewContainerFileContentService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewContainerFileContentService(opts ...option.RequestOption) (r ContainerFileContentService) { + r = ContainerFileContentService{} + r.Options = opts + return +} + +// Retrieve Container File Content +func (r *ContainerFileContentService) Get(ctx context.Context, containerID string, fileID string, opts ...option.RequestOption) (res *http.Response, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "application/binary")}, opts...) 
+ if containerID == "" { + err = errors.New("missing required container_id parameter") + return + } + if fileID == "" { + err = errors.New("missing required file_id parameter") + return + } + path := fmt.Sprintf("containers/%s/files/%s/content", containerID, fileID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/embedding.go b/vendor/github.com/openai/openai-go/embedding.go new file mode 100644 index 0000000000..f7650f5545 --- /dev/null +++ b/vendor/github.com/openai/openai-go/embedding.go @@ -0,0 +1,203 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "net/http" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared/constant" +) + +// EmbeddingService contains methods and other services that help with interacting +// with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewEmbeddingService] method instead. +type EmbeddingService struct { + Options []option.RequestOption +} + +// NewEmbeddingService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewEmbeddingService(opts ...option.RequestOption) (r EmbeddingService) { + r = EmbeddingService{} + r.Options = opts + return +} + +// Creates an embedding vector representing the input text. 
+func (r *EmbeddingService) New(ctx context.Context, body EmbeddingNewParams, opts ...option.RequestOption) (res *CreateEmbeddingResponse, err error) { + opts = append(r.Options[:], opts...) + path := "embeddings" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +type CreateEmbeddingResponse struct { + // The list of embeddings generated by the model. + Data []Embedding `json:"data,required"` + // The name of the model used to generate the embedding. + Model string `json:"model,required"` + // The object type, which is always "list". + Object constant.List `json:"object,required"` + // The usage information for the request. + Usage CreateEmbeddingResponseUsage `json:"usage,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + Model respjson.Field + Object respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CreateEmbeddingResponse) RawJSON() string { return r.JSON.raw } +func (r *CreateEmbeddingResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The usage information for the request. +type CreateEmbeddingResponseUsage struct { + // The number of tokens used by the prompt. + PromptTokens int64 `json:"prompt_tokens,required"` + // The total number of tokens used by the request. + TotalTokens int64 `json:"total_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + PromptTokens respjson.Field + TotalTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r CreateEmbeddingResponseUsage) RawJSON() string { return r.JSON.raw } +func (r *CreateEmbeddingResponseUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Represents an embedding vector returned by embedding endpoint. +type Embedding struct { + // The embedding vector, which is a list of floats. The length of vector depends on + // the model as listed in the + // [embedding guide](https://platform.openai.com/docs/guides/embeddings). + Embedding []float64 `json:"embedding,required"` + // The index of the embedding in the list of embeddings. + Index int64 `json:"index,required"` + // The object type, which is always "embedding". + Object constant.Embedding `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Embedding respjson.Field + Index respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Embedding) RawJSON() string { return r.JSON.raw } +func (r *Embedding) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type EmbeddingModel = string + +const ( + EmbeddingModelTextEmbeddingAda002 EmbeddingModel = "text-embedding-ada-002" + EmbeddingModelTextEmbedding3Small EmbeddingModel = "text-embedding-3-small" + EmbeddingModelTextEmbedding3Large EmbeddingModel = "text-embedding-3-large" +) + +type EmbeddingNewParams struct { + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single request, pass an array of strings or array of token arrays. 
+ // The input must not exceed the max input tokens for the model (8192 tokens for + // all embedding models), cannot be an empty string, and any array must be 2048 + // dimensions or less. + // [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + // for counting tokens. In addition to the per-input token limit, all embedding + // models enforce a maximum of 300,000 tokens summed across all inputs in a single + // request. + Input EmbeddingNewParamsInputUnion `json:"input,omitzero,required"` + // ID of the model to use. You can use the + // [List models](https://platform.openai.com/docs/api-reference/models/list) API to + // see all of your available models, or see our + // [Model overview](https://platform.openai.com/docs/models) for descriptions of + // them. + Model EmbeddingModel `json:"model,omitzero,required"` + // The number of dimensions the resulting output embeddings should have. Only + // supported in `text-embedding-3` and later models. + Dimensions param.Opt[int64] `json:"dimensions,omitzero"` + // A unique identifier representing your end-user, which can help OpenAI to monitor + // and detect abuse. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids). + User param.Opt[string] `json:"user,omitzero"` + // The format to return the embeddings in. Can be either `float` or + // [`base64`](https://pypi.org/project/pybase64/). + // + // Any of "float", "base64". + EncodingFormat EmbeddingNewParamsEncodingFormat `json:"encoding_format,omitzero"` + paramObj +} + +func (r EmbeddingNewParams) MarshalJSON() (data []byte, err error) { + type shadow EmbeddingNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EmbeddingNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type EmbeddingNewParamsInputUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfArrayOfStrings []string `json:",omitzero,inline"` + OfArrayOfTokens []int64 `json:",omitzero,inline"` + OfArrayOfTokenArrays [][]int64 `json:",omitzero,inline"` + paramUnion +} + +func (u EmbeddingNewParamsInputUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfArrayOfStrings, u.OfArrayOfTokens, u.OfArrayOfTokenArrays) +} +func (u *EmbeddingNewParamsInputUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *EmbeddingNewParamsInputUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfArrayOfStrings) { + return &u.OfArrayOfStrings + } else if !param.IsOmitted(u.OfArrayOfTokens) { + return &u.OfArrayOfTokens + } else if !param.IsOmitted(u.OfArrayOfTokenArrays) { + return &u.OfArrayOfTokenArrays + } + return nil +} + +// The format to return the embeddings in. Can be either `float` or +// [`base64`](https://pypi.org/project/pybase64/). 
+type EmbeddingNewParamsEncodingFormat string + +const ( + EmbeddingNewParamsEncodingFormatFloat EmbeddingNewParamsEncodingFormat = "float" + EmbeddingNewParamsEncodingFormatBase64 EmbeddingNewParamsEncodingFormat = "base64" +) diff --git a/vendor/github.com/openai/openai-go/field.go b/vendor/github.com/openai/openai-go/field.go new file mode 100644 index 0000000000..affd8998fa --- /dev/null +++ b/vendor/github.com/openai/openai-go/field.go @@ -0,0 +1,45 @@ +package openai + +import ( + "github.com/openai/openai-go/packages/param" + "io" + "time" +) + +func String(s string) param.Opt[string] { return param.NewOpt(s) } +func Int(i int64) param.Opt[int64] { return param.NewOpt(i) } +func Bool(b bool) param.Opt[bool] { return param.NewOpt(b) } +func Float(f float64) param.Opt[float64] { return param.NewOpt(f) } +func Time(t time.Time) param.Opt[time.Time] { return param.NewOpt(t) } + +func Opt[T comparable](v T) param.Opt[T] { return param.NewOpt(v) } +func Ptr[T any](v T) *T { return &v } + +func IntPtr(v int64) *int64 { return &v } +func BoolPtr(v bool) *bool { return &v } +func FloatPtr(v float64) *float64 { return &v } +func StringPtr(v string) *string { return &v } +func TimePtr(v time.Time) *time.Time { return &v } + +func File(rdr io.Reader, filename string, contentType string) file { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Filename() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} diff --git a/vendor/github.com/openai/openai-go/file.go b/vendor/github.com/openai/openai-go/file.go new file mode 100644 index 0000000000..7aa70565e5 --- /dev/null +++ b/vendor/github.com/openai/openai-go/file.go @@ -0,0 +1,314 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apiform" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared/constant" +) + +// FileService contains methods and other services that help with interacting with +// the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFileService] method instead. +type FileService struct { + Options []option.RequestOption +} + +// NewFileService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewFileService(opts ...option.RequestOption) (r FileService) { + r = FileService{} + r.Options = opts + return +} + +// Upload a file that can be used across various endpoints. Individual files can be +// up to 512 MB, and the size of all files uploaded by one organization can be up +// to 100 GB. +// +// The Assistants API supports files up to 2 million tokens and of specific file +// types. See the +// [Assistants Tools guide](https://platform.openai.com/docs/assistants/tools) for +// details. +// +// The Fine-tuning API only supports `.jsonl` files. The input also has certain +// required formats for fine-tuning +// [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input) or +// [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) +// models. 
+// +// The Batch API only supports `.jsonl` files up to 200 MB in size. The input also +// has a specific required +// [format](https://platform.openai.com/docs/api-reference/batch/request-input). +// +// Please [contact us](https://help.openai.com/) if you need to increase these +// storage limits. +func (r *FileService) New(ctx context.Context, body FileNewParams, opts ...option.RequestOption) (res *FileObject, err error) { + opts = append(r.Options[:], opts...) + path := "files" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Returns information about a specific file. +func (r *FileService) Get(ctx context.Context, fileID string, opts ...option.RequestOption) (res *FileObject, err error) { + opts = append(r.Options[:], opts...) + if fileID == "" { + err = errors.New("missing required file_id parameter") + return + } + path := fmt.Sprintf("files/%s", fileID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Returns a list of files. +func (r *FileService) List(ctx context.Context, query FileListParams, opts ...option.RequestOption) (res *pagination.CursorPage[FileObject], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + path := "files" + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Returns a list of files. +func (r *FileService) ListAutoPaging(ctx context.Context, query FileListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[FileObject] { + return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...)) +} + +// Delete a file. 
+func (r *FileService) Delete(ctx context.Context, fileID string, opts ...option.RequestOption) (res *FileDeleted, err error) { + opts = append(r.Options[:], opts...) + if fileID == "" { + err = errors.New("missing required file_id parameter") + return + } + path := fmt.Sprintf("files/%s", fileID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...) + return +} + +// Returns the contents of the specified file. +func (r *FileService) Content(ctx context.Context, fileID string, opts ...option.RequestOption) (res *http.Response, err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "application/binary")}, opts...) + if fileID == "" { + err = errors.New("missing required file_id parameter") + return + } + path := fmt.Sprintf("files/%s/content", fileID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +type FileDeleted struct { + ID string `json:"id,required"` + Deleted bool `json:"deleted,required"` + Object constant.File `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Deleted respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileDeleted) RawJSON() string { return r.JSON.raw } +func (r *FileDeleted) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The `File` object represents a document that has been uploaded to OpenAI. +type FileObject struct { + // The file identifier, which can be referenced in the API endpoints. + ID string `json:"id,required"` + // The size of the file, in bytes. + Bytes int64 `json:"bytes,required"` + // The Unix timestamp (in seconds) for when the file was created. + CreatedAt int64 `json:"created_at,required"` + // The name of the file. 
+ Filename string `json:"filename,required"` + // The object type, which is always `file`. + Object constant.File `json:"object,required"` + // The intended purpose of the file. Supported values are `assistants`, + // `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results`, + // `vision`, and `user_data`. + // + // Any of "assistants", "assistants_output", "batch", "batch_output", "fine-tune", + // "fine-tune-results", "vision", "user_data". + Purpose FileObjectPurpose `json:"purpose,required"` + // Deprecated. The current status of the file, which can be either `uploaded`, + // `processed`, or `error`. + // + // Any of "uploaded", "processed", "error". + // + // Deprecated: deprecated + Status FileObjectStatus `json:"status,required"` + // The Unix timestamp (in seconds) for when the file will expire. + ExpiresAt int64 `json:"expires_at"` + // Deprecated. For details on why a fine-tuning training file failed validation, + // see the `error` field on `fine_tuning.job`. + // + // Deprecated: deprecated + StatusDetails string `json:"status_details"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Bytes respjson.Field + CreatedAt respjson.Field + Filename respjson.Field + Object respjson.Field + Purpose respjson.Field + Status respjson.Field + ExpiresAt respjson.Field + StatusDetails respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileObject) RawJSON() string { return r.JSON.raw } +func (r *FileObject) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The intended purpose of the file. Supported values are `assistants`, +// `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results`, +// `vision`, and `user_data`. 
+type FileObjectPurpose string + +const ( + FileObjectPurposeAssistants FileObjectPurpose = "assistants" + FileObjectPurposeAssistantsOutput FileObjectPurpose = "assistants_output" + FileObjectPurposeBatch FileObjectPurpose = "batch" + FileObjectPurposeBatchOutput FileObjectPurpose = "batch_output" + FileObjectPurposeFineTune FileObjectPurpose = "fine-tune" + FileObjectPurposeFineTuneResults FileObjectPurpose = "fine-tune-results" + FileObjectPurposeVision FileObjectPurpose = "vision" + FileObjectPurposeUserData FileObjectPurpose = "user_data" +) + +// Deprecated. The current status of the file, which can be either `uploaded`, +// `processed`, or `error`. +type FileObjectStatus string + +const ( + FileObjectStatusUploaded FileObjectStatus = "uploaded" + FileObjectStatusProcessed FileObjectStatus = "processed" + FileObjectStatusError FileObjectStatus = "error" +) + +// The intended purpose of the uploaded file. One of: - `assistants`: Used in the +// Assistants API - `batch`: Used in the Batch API - `fine-tune`: Used for +// fine-tuning - `vision`: Images used for vision fine-tuning - `user_data`: +// Flexible file type for any purpose - `evals`: Used for eval data sets +type FilePurpose string + +const ( + FilePurposeAssistants FilePurpose = "assistants" + FilePurposeBatch FilePurpose = "batch" + FilePurposeFineTune FilePurpose = "fine-tune" + FilePurposeVision FilePurpose = "vision" + FilePurposeUserData FilePurpose = "user_data" + FilePurposeEvals FilePurpose = "evals" +) + +type FileNewParams struct { + // The File object (not file name) to be uploaded. + File io.Reader `json:"file,omitzero,required" format:"binary"` + // The intended purpose of the uploaded file. 
One of: - `assistants`: Used in the + // Assistants API - `batch`: Used in the Batch API - `fine-tune`: Used for + // fine-tuning - `vision`: Images used for vision fine-tuning - `user_data`: + // Flexible file type for any purpose - `evals`: Used for eval data sets + // + // Any of "assistants", "batch", "fine-tune", "vision", "user_data", "evals". + Purpose FilePurpose `json:"purpose,omitzero,required"` + paramObj +} + +func (r FileNewParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err == nil { + err = apiform.WriteExtras(writer, r.ExtraFields()) + } + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +type FileListParams struct { + // A cursor for use in pagination. `after` is an object ID that defines your place + // in the list. For instance, if you make a list request and receive 100 objects, + // ending with obj_foo, your subsequent call can include after=obj_foo in order to + // fetch the next page of the list. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // A limit on the number of objects to be returned. Limit can range between 1 and + // 10,000, and the default is 10,000. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Only return files with the given purpose. + Purpose param.Opt[string] `query:"purpose,omitzero" json:"-"` + // Sort order by the `created_at` timestamp of the objects. `asc` for ascending + // order and `desc` for descending order. + // + // Any of "asc", "desc". + Order FileListParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FileListParams]'s query parameters as `url.Values`. 
+func (r FileListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Sort order by the `created_at` timestamp of the objects. `asc` for ascending +// order and `desc` for descending order. +type FileListParamsOrder string + +const ( + FileListParamsOrderAsc FileListParamsOrder = "asc" + FileListParamsOrderDesc FileListParamsOrder = "desc" +) diff --git a/vendor/github.com/openai/openai-go/finetuning.go b/vendor/github.com/openai/openai-go/finetuning.go new file mode 100644 index 0000000000..3de51c1364 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuning.go @@ -0,0 +1,34 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// FineTuningService contains methods and other services that help with interacting +// with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningService] method instead. +type FineTuningService struct { + Options []option.RequestOption + Methods FineTuningMethodService + Jobs FineTuningJobService + Checkpoints FineTuningCheckpointService + Alpha FineTuningAlphaService +} + +// NewFineTuningService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewFineTuningService(opts ...option.RequestOption) (r FineTuningService) { + r = FineTuningService{} + r.Options = opts + r.Methods = NewFineTuningMethodService(opts...) + r.Jobs = NewFineTuningJobService(opts...) + r.Checkpoints = NewFineTuningCheckpointService(opts...) 
+ r.Alpha = NewFineTuningAlphaService(opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/finetuningalpha.go b/vendor/github.com/openai/openai-go/finetuningalpha.go new file mode 100644 index 0000000000..986a178ffc --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningalpha.go @@ -0,0 +1,28 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// FineTuningAlphaService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningAlphaService] method instead. +type FineTuningAlphaService struct { + Options []option.RequestOption + Graders FineTuningAlphaGraderService +} + +// NewFineTuningAlphaService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewFineTuningAlphaService(opts ...option.RequestOption) (r FineTuningAlphaService) { + r = FineTuningAlphaService{} + r.Options = opts + r.Graders = NewFineTuningAlphaGraderService(opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/finetuningalphagrader.go b/vendor/github.com/openai/openai-go/finetuningalphagrader.go new file mode 100644 index 0000000000..4920583278 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningalphagrader.go @@ -0,0 +1,672 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" +) + +// FineTuningAlphaGraderService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningAlphaGraderService] method instead. +type FineTuningAlphaGraderService struct { + Options []option.RequestOption +} + +// NewFineTuningAlphaGraderService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewFineTuningAlphaGraderService(opts ...option.RequestOption) (r FineTuningAlphaGraderService) { + r = FineTuningAlphaGraderService{} + r.Options = opts + return +} + +// Run a grader. +func (r *FineTuningAlphaGraderService) Run(ctx context.Context, body FineTuningAlphaGraderRunParams, opts ...option.RequestOption) (res *FineTuningAlphaGraderRunResponse, err error) { + opts = append(r.Options[:], opts...) + path := "fine_tuning/alpha/graders/run" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Validate a grader. +func (r *FineTuningAlphaGraderService) Validate(ctx context.Context, body FineTuningAlphaGraderValidateParams, opts ...option.RequestOption) (res *FineTuningAlphaGraderValidateResponse, err error) { + opts = append(r.Options[:], opts...) + path := "fine_tuning/alpha/graders/validate" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) 
+ return +} + +type FineTuningAlphaGraderRunResponse struct { + Metadata FineTuningAlphaGraderRunResponseMetadata `json:"metadata,required"` + ModelGraderTokenUsagePerModel map[string]any `json:"model_grader_token_usage_per_model,required"` + Reward float64 `json:"reward,required"` + SubRewards map[string]any `json:"sub_rewards,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Metadata respjson.Field + ModelGraderTokenUsagePerModel respjson.Field + Reward respjson.Field + SubRewards respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningAlphaGraderRunResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningAlphaGraderRunResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningAlphaGraderRunResponseMetadata struct { + Errors FineTuningAlphaGraderRunResponseMetadataErrors `json:"errors,required"` + ExecutionTime float64 `json:"execution_time,required"` + Name string `json:"name,required"` + SampledModelName string `json:"sampled_model_name,required"` + Scores map[string]any `json:"scores,required"` + TokenUsage int64 `json:"token_usage,required"` + Type string `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Errors respjson.Field + ExecutionTime respjson.Field + Name respjson.Field + SampledModelName respjson.Field + Scores respjson.Field + TokenUsage respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningAlphaGraderRunResponseMetadata) RawJSON() string { return r.JSON.raw } +func (r *FineTuningAlphaGraderRunResponseMetadata) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningAlphaGraderRunResponseMetadataErrors struct { + FormulaParseError bool `json:"formula_parse_error,required"` + InvalidVariableError bool `json:"invalid_variable_error,required"` + ModelGraderParseError bool `json:"model_grader_parse_error,required"` + ModelGraderRefusalError bool `json:"model_grader_refusal_error,required"` + ModelGraderServerError bool `json:"model_grader_server_error,required"` + ModelGraderServerErrorDetails string `json:"model_grader_server_error_details,required"` + OtherError bool `json:"other_error,required"` + PythonGraderRuntimeError bool `json:"python_grader_runtime_error,required"` + PythonGraderRuntimeErrorDetails string `json:"python_grader_runtime_error_details,required"` + PythonGraderServerError bool `json:"python_grader_server_error,required"` + PythonGraderServerErrorType string `json:"python_grader_server_error_type,required"` + SampleParseError bool `json:"sample_parse_error,required"` + TruncatedObservationError bool `json:"truncated_observation_error,required"` + UnresponsiveRewardError bool `json:"unresponsive_reward_error,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + FormulaParseError respjson.Field + InvalidVariableError respjson.Field + ModelGraderParseError respjson.Field + ModelGraderRefusalError respjson.Field + ModelGraderServerError respjson.Field + ModelGraderServerErrorDetails respjson.Field + OtherError respjson.Field + PythonGraderRuntimeError respjson.Field + PythonGraderRuntimeErrorDetails respjson.Field + PythonGraderServerError respjson.Field + PythonGraderServerErrorType respjson.Field + SampleParseError respjson.Field + TruncatedObservationError respjson.Field + UnresponsiveRewardError respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningAlphaGraderRunResponseMetadataErrors) RawJSON() string { return r.JSON.raw } +func (r *FineTuningAlphaGraderRunResponseMetadataErrors) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningAlphaGraderValidateResponse struct { + // The grader used for the fine-tuning job. + Grader FineTuningAlphaGraderValidateResponseGraderUnion `json:"grader"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Grader respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningAlphaGraderValidateResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningAlphaGraderValidateResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningAlphaGraderValidateResponseGraderUnion contains all possible +// properties and values from [StringCheckGrader], [TextSimilarityGrader], +// [PythonGrader], [ScoreModelGrader], [MultiGrader]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. 
+type FineTuningAlphaGraderValidateResponseGraderUnion struct { + // This field is a union of [string], [string], [[]ScoreModelGraderInput] + Input FineTuningAlphaGraderValidateResponseGraderUnionInput `json:"input"` + Name string `json:"name"` + // This field is from variant [StringCheckGrader]. + Operation StringCheckGraderOperation `json:"operation"` + Reference string `json:"reference"` + Type string `json:"type"` + // This field is from variant [TextSimilarityGrader]. + EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric"` + // This field is from variant [PythonGrader]. + Source string `json:"source"` + // This field is from variant [PythonGrader]. + ImageTag string `json:"image_tag"` + // This field is from variant [ScoreModelGrader]. + Model string `json:"model"` + // This field is from variant [ScoreModelGrader]. + Range []float64 `json:"range"` + // This field is from variant [ScoreModelGrader]. + SamplingParams any `json:"sampling_params"` + // This field is from variant [MultiGrader]. + CalculateOutput string `json:"calculate_output"` + // This field is from variant [MultiGrader]. 
+ Graders MultiGraderGradersUnion `json:"graders"` + JSON struct { + Input respjson.Field + Name respjson.Field + Operation respjson.Field + Reference respjson.Field + Type respjson.Field + EvaluationMetric respjson.Field + Source respjson.Field + ImageTag respjson.Field + Model respjson.Field + Range respjson.Field + SamplingParams respjson.Field + CalculateOutput respjson.Field + Graders respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsStringCheckGrader() (v StringCheckGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsTextSimilarityGrader() (v TextSimilarityGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsPythonGrader() (v PythonGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsScoreModelGrader() (v ScoreModelGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsMultiGrader() (v MultiGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningAlphaGraderValidateResponseGraderUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningAlphaGraderValidateResponseGraderUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningAlphaGraderValidateResponseGraderUnionInput is an implicit subunion of +// [FineTuningAlphaGraderValidateResponseGraderUnion]. +// FineTuningAlphaGraderValidateResponseGraderUnionInput provides convenient access +// to the sub-properties of the union. +// +// For type safety it is recommended to directly use a variant of the +// [FineTuningAlphaGraderValidateResponseGraderUnion]. 
+// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfString or OfScoreModelGraderInputArray +type FineTuningAlphaGraderValidateResponseGraderUnionInput struct { + // This field will be present if the value is a [string] instead of an object. + OfString string `json:",inline"` + // This field will be present if the value is a [[]ScoreModelGraderInput] instead + // of an object. + OfScoreModelGraderInputArray []ScoreModelGraderInput `json:",inline"` + JSON struct { + OfString respjson.Field + OfScoreModelGraderInputArray respjson.Field + raw string + } `json:"-"` +} + +func (r *FineTuningAlphaGraderValidateResponseGraderUnionInput) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningAlphaGraderRunParams struct { + // The grader used for the fine-tuning job. + Grader FineTuningAlphaGraderRunParamsGraderUnion `json:"grader,omitzero,required"` + // The model sample to be evaluated. This value will be used to populate the + // `sample` namespace. See + // [the guide](https://platform.openai.com/docs/guides/graders) for more details. + // The `output_json` variable will be populated if the model sample is a valid JSON + // string. + ModelSample string `json:"model_sample,required"` + // The dataset item provided to the grader. This will be used to populate the + // `item` namespace. See + // [the guide](https://platform.openai.com/docs/guides/graders) for more details. + Item any `json:"item,omitzero"` + paramObj +} + +func (r FineTuningAlphaGraderRunParams) MarshalJSON() (data []byte, err error) { + type shadow FineTuningAlphaGraderRunParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningAlphaGraderRunParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type FineTuningAlphaGraderRunParamsGraderUnion struct { + OfStringCheck *StringCheckGraderParam `json:",omitzero,inline"` + OfTextSimilarity *TextSimilarityGraderParam `json:",omitzero,inline"` + OfPython *PythonGraderParam `json:",omitzero,inline"` + OfScoreModel *ScoreModelGraderParam `json:",omitzero,inline"` + OfMulti *MultiGraderParam `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningAlphaGraderRunParamsGraderUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfStringCheck, + u.OfTextSimilarity, + u.OfPython, + u.OfScoreModel, + u.OfMulti) +} +func (u *FineTuningAlphaGraderRunParamsGraderUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningAlphaGraderRunParamsGraderUnion) asAny() any { + if !param.IsOmitted(u.OfStringCheck) { + return u.OfStringCheck + } else if !param.IsOmitted(u.OfTextSimilarity) { + return u.OfTextSimilarity + } else if !param.IsOmitted(u.OfPython) { + return u.OfPython + } else if !param.IsOmitted(u.OfScoreModel) { + return u.OfScoreModel + } else if !param.IsOmitted(u.OfMulti) { + return u.OfMulti + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetOperation() *string { + if vt := u.OfStringCheck; vt != nil { + return (*string)(&vt.Operation) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetEvaluationMetric() *string { + if vt := u.OfTextSimilarity; vt != nil { + return (*string)(&vt.EvaluationMetric) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetSource() *string { + if vt := u.OfPython; vt != nil { + return &vt.Source + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetImageTag() *string { + if vt := u.OfPython; vt != nil && vt.ImageTag.Valid() { + return &vt.ImageTag.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetModel() *string { + if vt := u.OfScoreModel; vt != nil { + return &vt.Model + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetRange() []float64 { + if vt := u.OfScoreModel; vt != nil { + return vt.Range + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetSamplingParams() *any { + if vt := u.OfScoreModel; vt != nil { + return &vt.SamplingParams + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetCalculateOutput() *string { + if vt := u.OfMulti; vt != nil { + return &vt.CalculateOutput + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetGraders() *MultiGraderGradersUnionParam { + if vt := u.OfMulti; vt != nil { + return &vt.Graders + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetName() *string { + if vt := u.OfStringCheck; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfTextSimilarity; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfPython; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfScoreModel; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfMulti; vt != nil { + return (*string)(&vt.Name) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetReference() *string { + if vt := u.OfStringCheck; vt != nil { + return (*string)(&vt.Reference) + } else if vt := u.OfTextSimilarity; vt != nil { + return (*string)(&vt.Reference) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetType() *string { + if vt := u.OfStringCheck; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfTextSimilarity; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfPython; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfScoreModel; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfMulti; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u FineTuningAlphaGraderRunParamsGraderUnion) GetInput() (res fineTuningAlphaGraderRunParamsGraderUnionInput) { + if vt := u.OfStringCheck; vt != nil { + res.any = &vt.Input + } else if vt := u.OfTextSimilarity; vt != nil { + res.any = &vt.Input + } else if vt := u.OfScoreModel; vt != nil { + res.any = &vt.Input + } + return +} + +// Can have the runtime types [*string], [\*[]ScoreModelGraderInputParam] +type fineTuningAlphaGraderRunParamsGraderUnionInput struct{ any } + +// Use the following switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *string: +// case *[]openai.ScoreModelGraderInputParam: +// default: +// fmt.Errorf("not present") +// } +func (u fineTuningAlphaGraderRunParamsGraderUnionInput) AsAny() any { return u.any } + +func init() { + apijson.RegisterUnion[FineTuningAlphaGraderRunParamsGraderUnion]( + "type", + apijson.Discriminator[StringCheckGraderParam]("string_check"), + apijson.Discriminator[TextSimilarityGraderParam]("text_similarity"), + apijson.Discriminator[PythonGraderParam]("python"), + 
apijson.Discriminator[ScoreModelGraderParam]("score_model"), + apijson.Discriminator[MultiGraderParam]("multi"), + ) +} + +type FineTuningAlphaGraderValidateParams struct { + // The grader used for the fine-tuning job. + Grader FineTuningAlphaGraderValidateParamsGraderUnion `json:"grader,omitzero,required"` + paramObj +} + +func (r FineTuningAlphaGraderValidateParams) MarshalJSON() (data []byte, err error) { + type shadow FineTuningAlphaGraderValidateParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningAlphaGraderValidateParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningAlphaGraderValidateParamsGraderUnion struct { + OfStringCheckGrader *StringCheckGraderParam `json:",omitzero,inline"` + OfTextSimilarityGrader *TextSimilarityGraderParam `json:",omitzero,inline"` + OfPythonGrader *PythonGraderParam `json:",omitzero,inline"` + OfScoreModelGrader *ScoreModelGraderParam `json:",omitzero,inline"` + OfMultiGrader *MultiGraderParam `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningAlphaGraderValidateParamsGraderUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfStringCheckGrader, + u.OfTextSimilarityGrader, + u.OfPythonGrader, + u.OfScoreModelGrader, + u.OfMultiGrader) +} +func (u *FineTuningAlphaGraderValidateParamsGraderUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningAlphaGraderValidateParamsGraderUnion) asAny() any { + if !param.IsOmitted(u.OfStringCheckGrader) { + return u.OfStringCheckGrader + } else if !param.IsOmitted(u.OfTextSimilarityGrader) { + return u.OfTextSimilarityGrader + } else if !param.IsOmitted(u.OfPythonGrader) { + return u.OfPythonGrader + } else if !param.IsOmitted(u.OfScoreModelGrader) { + return u.OfScoreModelGrader + } else if !param.IsOmitted(u.OfMultiGrader) { + return 
u.OfMultiGrader + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetOperation() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Operation) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetEvaluationMetric() *string { + if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.EvaluationMetric) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetSource() *string { + if vt := u.OfPythonGrader; vt != nil { + return &vt.Source + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetImageTag() *string { + if vt := u.OfPythonGrader; vt != nil && vt.ImageTag.Valid() { + return &vt.ImageTag.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetModel() *string { + if vt := u.OfScoreModelGrader; vt != nil { + return &vt.Model + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetRange() []float64 { + if vt := u.OfScoreModelGrader; vt != nil { + return vt.Range + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetSamplingParams() *any { + if vt := u.OfScoreModelGrader; vt != nil { + return &vt.SamplingParams + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetCalculateOutput() *string { + if vt := u.OfMultiGrader; vt != nil { + return &vt.CalculateOutput + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetGraders() *MultiGraderGradersUnionParam { + if vt := u.OfMultiGrader; vt != nil { + return &vt.Graders + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetName() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfPythonGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfMultiGrader; vt != nil { + return (*string)(&vt.Name) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetReference() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Reference) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Reference) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetType() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfPythonGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfMultiGrader; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetInput() (res fineTuningAlphaGraderValidateParamsGraderUnionInput) { + if vt := u.OfStringCheckGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfTextSimilarityGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfScoreModelGrader; vt != nil { + res.any = &vt.Input + } + return +} + +// Can have the runtime types [*string], [\*[]ScoreModelGraderInputParam] +type fineTuningAlphaGraderValidateParamsGraderUnionInput struct{ any } + +// Use the following switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *string: +// case *[]openai.ScoreModelGraderInputParam: +// default: +// fmt.Errorf("not present") +// } +func (u fineTuningAlphaGraderValidateParamsGraderUnionInput) AsAny() any { return u.any } diff --git a/vendor/github.com/openai/openai-go/finetuningcheckpoint.go b/vendor/github.com/openai/openai-go/finetuningcheckpoint.go new file mode 100644 index 0000000000..11a485e7c8 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningcheckpoint.go @@ -0,0 +1,28 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// FineTuningCheckpointService contains methods and other services that help with +// interacting with the openai API. 
+// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningCheckpointService] method instead. +type FineTuningCheckpointService struct { + Options []option.RequestOption + Permissions FineTuningCheckpointPermissionService +} + +// NewFineTuningCheckpointService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewFineTuningCheckpointService(opts ...option.RequestOption) (r FineTuningCheckpointService) { + r = FineTuningCheckpointService{} + r.Options = opts + r.Permissions = NewFineTuningCheckpointPermissionService(opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/finetuningcheckpointpermission.go b/vendor/github.com/openai/openai-go/finetuningcheckpointpermission.go new file mode 100644 index 0000000000..2799291914 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningcheckpointpermission.go @@ -0,0 +1,254 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared/constant" +) + +// FineTuningCheckpointPermissionService contains methods and other services that +// help with interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. 
You should not instantiate this service directly, and instead use +// the [NewFineTuningCheckpointPermissionService] method instead. +type FineTuningCheckpointPermissionService struct { + Options []option.RequestOption +} + +// NewFineTuningCheckpointPermissionService generates a new service that applies +// the given options to each request. These options are applied after the parent +// client's options (if there is one), and before any request-specific options. +func NewFineTuningCheckpointPermissionService(opts ...option.RequestOption) (r FineTuningCheckpointPermissionService) { + r = FineTuningCheckpointPermissionService{} + r.Options = opts + return +} + +// **NOTE:** Calling this endpoint requires an [admin API key](../admin-api-keys). +// +// This enables organization owners to share fine-tuned models with other projects +// in their organization. +func (r *FineTuningCheckpointPermissionService) New(ctx context.Context, fineTunedModelCheckpoint string, body FineTuningCheckpointPermissionNewParams, opts ...option.RequestOption) (res *pagination.Page[FineTuningCheckpointPermissionNewResponse], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + if fineTunedModelCheckpoint == "" { + err = errors.New("missing required fine_tuned_model_checkpoint parameter") + return + } + path := fmt.Sprintf("fine_tuning/checkpoints/%s/permissions", fineTunedModelCheckpoint) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodPost, path, body, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// **NOTE:** Calling this endpoint requires an [admin API key](../admin-api-keys). +// +// This enables organization owners to share fine-tuned models with other projects +// in their organization. 
+func (r *FineTuningCheckpointPermissionService) NewAutoPaging(ctx context.Context, fineTunedModelCheckpoint string, body FineTuningCheckpointPermissionNewParams, opts ...option.RequestOption) *pagination.PageAutoPager[FineTuningCheckpointPermissionNewResponse] { + return pagination.NewPageAutoPager(r.New(ctx, fineTunedModelCheckpoint, body, opts...)) +} + +// **NOTE:** This endpoint requires an [admin API key](../admin-api-keys). +// +// Organization owners can use this endpoint to view all permissions for a +// fine-tuned model checkpoint. +func (r *FineTuningCheckpointPermissionService) Get(ctx context.Context, fineTunedModelCheckpoint string, query FineTuningCheckpointPermissionGetParams, opts ...option.RequestOption) (res *FineTuningCheckpointPermissionGetResponse, err error) { + opts = append(r.Options[:], opts...) + if fineTunedModelCheckpoint == "" { + err = errors.New("missing required fine_tuned_model_checkpoint parameter") + return + } + path := fmt.Sprintf("fine_tuning/checkpoints/%s/permissions", fineTunedModelCheckpoint) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) + return +} + +// **NOTE:** This endpoint requires an [admin API key](../admin-api-keys). +// +// Organization owners can use this endpoint to delete a permission for a +// fine-tuned model checkpoint. +func (r *FineTuningCheckpointPermissionService) Delete(ctx context.Context, fineTunedModelCheckpoint string, permissionID string, opts ...option.RequestOption) (res *FineTuningCheckpointPermissionDeleteResponse, err error) { + opts = append(r.Options[:], opts...) 
+ if fineTunedModelCheckpoint == "" { + err = errors.New("missing required fine_tuned_model_checkpoint parameter") + return + } + if permissionID == "" { + err = errors.New("missing required permission_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/checkpoints/%s/permissions/%s", fineTunedModelCheckpoint, permissionID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...) + return +} + +// The `checkpoint.permission` object represents a permission for a fine-tuned +// model checkpoint. +type FineTuningCheckpointPermissionNewResponse struct { + // The permission identifier, which can be referenced in the API endpoints. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the permission was created. + CreatedAt int64 `json:"created_at,required"` + // The object type, which is always "checkpoint.permission". + Object constant.CheckpointPermission `json:"object,required"` + // The project identifier that the permission is for. + ProjectID string `json:"project_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Object respjson.Field + ProjectID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCheckpointPermissionNewResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCheckpointPermissionNewResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCheckpointPermissionGetResponse struct { + Data []FineTuningCheckpointPermissionGetResponseData `json:"data,required"` + HasMore bool `json:"has_more,required"` + Object constant.List `json:"object,required"` + FirstID string `json:"first_id,nullable"` + LastID string `json:"last_id,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + HasMore respjson.Field + Object respjson.Field + FirstID respjson.Field + LastID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCheckpointPermissionGetResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCheckpointPermissionGetResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The `checkpoint.permission` object represents a permission for a fine-tuned +// model checkpoint. +type FineTuningCheckpointPermissionGetResponseData struct { + // The permission identifier, which can be referenced in the API endpoints. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the permission was created. + CreatedAt int64 `json:"created_at,required"` + // The object type, which is always "checkpoint.permission". + Object constant.CheckpointPermission `json:"object,required"` + // The project identifier that the permission is for. 
+ ProjectID string `json:"project_id,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Object respjson.Field + ProjectID respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCheckpointPermissionGetResponseData) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCheckpointPermissionGetResponseData) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCheckpointPermissionDeleteResponse struct { + // The ID of the fine-tuned model checkpoint permission that was deleted. + ID string `json:"id,required"` + // Whether the fine-tuned model checkpoint permission was successfully deleted. + Deleted bool `json:"deleted,required"` + // The object type, which is always "checkpoint.permission". + Object constant.CheckpointPermission `json:"object,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Deleted respjson.Field + Object respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCheckpointPermissionDeleteResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCheckpointPermissionDeleteResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCheckpointPermissionNewParams struct { + // The project identifiers to grant access to. 
+ ProjectIDs []string `json:"project_ids,omitzero,required"` + paramObj +} + +func (r FineTuningCheckpointPermissionNewParams) MarshalJSON() (data []byte, err error) { + type shadow FineTuningCheckpointPermissionNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningCheckpointPermissionNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCheckpointPermissionGetParams struct { + // Identifier for the last permission ID from the previous pagination request. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // Number of permissions to retrieve. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // The ID of the project to get permissions for. + ProjectID param.Opt[string] `query:"project_id,omitzero" json:"-"` + // The order in which to retrieve permissions. + // + // Any of "ascending", "descending". + Order FineTuningCheckpointPermissionGetParamsOrder `query:"order,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FineTuningCheckpointPermissionGetParams]'s query parameters +// as `url.Values`. +func (r FineTuningCheckpointPermissionGetParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// The order in which to retrieve permissions. 
+type FineTuningCheckpointPermissionGetParamsOrder string + +const ( + FineTuningCheckpointPermissionGetParamsOrderAscending FineTuningCheckpointPermissionGetParamsOrder = "ascending" + FineTuningCheckpointPermissionGetParamsOrderDescending FineTuningCheckpointPermissionGetParamsOrder = "descending" +) diff --git a/vendor/github.com/openai/openai-go/finetuningjob.go b/vendor/github.com/openai/openai-go/finetuningjob.go new file mode 100644 index 0000000000..5776f03c38 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningjob.go @@ -0,0 +1,880 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared" + "github.com/openai/openai-go/shared/constant" +) + +// FineTuningJobService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningJobService] method instead. +type FineTuningJobService struct { + Options []option.RequestOption + Checkpoints FineTuningJobCheckpointService +} + +// NewFineTuningJobService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. 
+func NewFineTuningJobService(opts ...option.RequestOption) (r FineTuningJobService) { + r = FineTuningJobService{} + r.Options = opts + r.Checkpoints = NewFineTuningJobCheckpointService(opts...) + return +} + +// Creates a fine-tuning job which begins the process of creating a new model from +// a given dataset. +// +// Response includes details of the enqueued job including job status and the name +// of the fine-tuned models once complete. +// +// [Learn more about fine-tuning](https://platform.openai.com/docs/guides/model-optimization) +func (r *FineTuningJobService) New(ctx context.Context, body FineTuningJobNewParams, opts ...option.RequestOption) (res *FineTuningJob, err error) { + opts = append(r.Options[:], opts...) + path := "fine_tuning/jobs" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Get info about a fine-tuning job. +// +// [Learn more about fine-tuning](https://platform.openai.com/docs/guides/model-optimization) +func (r *FineTuningJobService) Get(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) { + opts = append(r.Options[:], opts...) + if fineTuningJobID == "" { + err = errors.New("missing required fine_tuning_job_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/jobs/%s", fineTuningJobID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// List your organization's fine-tuning jobs +func (r *FineTuningJobService) List(ctx context.Context, query FineTuningJobListParams, opts ...option.RequestOption) (res *pagination.CursorPage[FineTuningJob], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + path := "fine_tuning/jobs" + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) 
+ if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// List your organization's fine-tuning jobs +func (r *FineTuningJobService) ListAutoPaging(ctx context.Context, query FineTuningJobListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[FineTuningJob] { + return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...)) +} + +// Immediately cancel a fine-tune job. +func (r *FineTuningJobService) Cancel(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) { + opts = append(r.Options[:], opts...) + if fineTuningJobID == "" { + err = errors.New("missing required fine_tuning_job_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/jobs/%s/cancel", fineTuningJobID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return +} + +// Get status updates for a fine-tuning job. +func (r *FineTuningJobService) ListEvents(ctx context.Context, fineTuningJobID string, query FineTuningJobListEventsParams, opts ...option.RequestOption) (res *pagination.CursorPage[FineTuningJobEvent], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) + if fineTuningJobID == "" { + err = errors.New("missing required fine_tuning_job_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/jobs/%s/events", fineTuningJobID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// Get status updates for a fine-tuning job. 
+func (r *FineTuningJobService) ListEventsAutoPaging(ctx context.Context, fineTuningJobID string, query FineTuningJobListEventsParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[FineTuningJobEvent] { + return pagination.NewCursorPageAutoPager(r.ListEvents(ctx, fineTuningJobID, query, opts...)) +} + +// Pause a fine-tune job. +func (r *FineTuningJobService) Pause(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) { + opts = append(r.Options[:], opts...) + if fineTuningJobID == "" { + err = errors.New("missing required fine_tuning_job_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/jobs/%s/pause", fineTuningJobID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return +} + +// Resume a fine-tune job. +func (r *FineTuningJobService) Resume(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) { + opts = append(r.Options[:], opts...) + if fineTuningJobID == "" { + err = errors.New("missing required fine_tuning_job_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/jobs/%s/resume", fineTuningJobID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return +} + +// The `fine_tuning.job` object represents a fine-tuning job that has been created +// through the API. +type FineTuningJob struct { + // The object identifier, which can be referenced in the API endpoints. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the fine-tuning job was created. + CreatedAt int64 `json:"created_at,required"` + // For fine-tuning jobs that have `failed`, this will contain more information on + // the cause of the failure. + Error FineTuningJobError `json:"error,required"` + // The name of the fine-tuned model that is being created. The value will be null + // if the fine-tuning job is still running. 
+ FineTunedModel string `json:"fine_tuned_model,required"` + // The Unix timestamp (in seconds) for when the fine-tuning job was finished. The + // value will be null if the fine-tuning job is still running. + FinishedAt int64 `json:"finished_at,required"` + // The hyperparameters used for the fine-tuning job. This value will only be + // returned when running `supervised` jobs. + Hyperparameters FineTuningJobHyperparameters `json:"hyperparameters,required"` + // The base model that is being fine-tuned. + Model string `json:"model,required"` + // The object type, which is always "fine_tuning.job". + Object constant.FineTuningJob `json:"object,required"` + // The organization that owns the fine-tuning job. + OrganizationID string `json:"organization_id,required"` + // The compiled results file ID(s) for the fine-tuning job. You can retrieve the + // results with the + // [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + ResultFiles []string `json:"result_files,required"` + // The seed used for the fine-tuning job. + Seed int64 `json:"seed,required"` + // The current status of the fine-tuning job, which can be either + // `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. + // + // Any of "validating_files", "queued", "running", "succeeded", "failed", + // "cancelled". + Status FineTuningJobStatus `json:"status,required"` + // The total number of billable tokens processed by this fine-tuning job. The value + // will be null if the fine-tuning job is still running. + TrainedTokens int64 `json:"trained_tokens,required"` + // The file ID used for training. You can retrieve the training data with the + // [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + TrainingFile string `json:"training_file,required"` + // The file ID used for validation. 
You can retrieve the validation results with + // the + // [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + ValidationFile string `json:"validation_file,required"` + // The Unix timestamp (in seconds) for when the fine-tuning job is estimated to + // finish. The value will be null if the fine-tuning job is not running. + EstimatedFinish int64 `json:"estimated_finish,nullable"` + // A list of integrations to enable for this fine-tuning job. + Integrations []FineTuningJobWandbIntegrationObject `json:"integrations,nullable"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,nullable"` + // The method used for fine-tuning. + Method FineTuningJobMethod `json:"method"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Error respjson.Field + FineTunedModel respjson.Field + FinishedAt respjson.Field + Hyperparameters respjson.Field + Model respjson.Field + Object respjson.Field + OrganizationID respjson.Field + ResultFiles respjson.Field + Seed respjson.Field + Status respjson.Field + TrainedTokens respjson.Field + TrainingFile respjson.Field + ValidationFile respjson.Field + EstimatedFinish respjson.Field + Integrations respjson.Field + Metadata respjson.Field + Method respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJob) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJob) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// For fine-tuning jobs that have `failed`, this will contain more information on +// the cause of the failure. +type FineTuningJobError struct { + // A machine-readable error code. + Code string `json:"code,required"` + // A human-readable error message. + Message string `json:"message,required"` + // The parameter that was invalid, usually `training_file` or `validation_file`. + // This field will be null if the failure was not parameter-specific. + Param string `json:"param,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Code respjson.Field + Message respjson.Field + Param respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobError) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobError) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The hyperparameters used for the fine-tuning job. This value will only be +// returned when running `supervised` jobs. 
+type FineTuningJobHyperparameters struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize FineTuningJobHyperparametersBatchSizeUnion `json:"batch_size,nullable"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier FineTuningJobHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs FineTuningJobHyperparametersNEpochsUnion `json:"n_epochs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + BatchSize respjson.Field + LearningRateMultiplier respjson.Field + NEpochs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobHyperparameters) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobHyperparameters) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningJobHyperparametersBatchSizeUnion contains all possible properties and +// values from [constant.Auto], [int64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfInt] +type FineTuningJobHyperparametersBatchSizeUnion struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [int64] instead of an object. 
+ OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningJobHyperparametersBatchSizeUnion) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningJobHyperparametersBatchSizeUnion) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningJobHyperparametersBatchSizeUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningJobHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningJobHyperparametersLearningRateMultiplierUnion contains all possible +// properties and values from [constant.Auto], [float64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfFloat] +type FineTuningJobHyperparametersLearningRateMultiplierUnion struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [float64] instead of an object. 
+ OfFloat float64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfFloat respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningJobHyperparametersLearningRateMultiplierUnion) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningJobHyperparametersLearningRateMultiplierUnion) AsFloat() (v float64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningJobHyperparametersLearningRateMultiplierUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningJobHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningJobHyperparametersNEpochsUnion contains all possible properties and +// values from [constant.Auto], [int64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfInt] +type FineTuningJobHyperparametersNEpochsUnion struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [int64] instead of an object. 
+ OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningJobHyperparametersNEpochsUnion) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningJobHyperparametersNEpochsUnion) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningJobHyperparametersNEpochsUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningJobHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The current status of the fine-tuning job, which can be either +// `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. +type FineTuningJobStatus string + +const ( + FineTuningJobStatusValidatingFiles FineTuningJobStatus = "validating_files" + FineTuningJobStatusQueued FineTuningJobStatus = "queued" + FineTuningJobStatusRunning FineTuningJobStatus = "running" + FineTuningJobStatusSucceeded FineTuningJobStatus = "succeeded" + FineTuningJobStatusFailed FineTuningJobStatus = "failed" + FineTuningJobStatusCancelled FineTuningJobStatus = "cancelled" +) + +// The method used for fine-tuning. +type FineTuningJobMethod struct { + // The type of method. Is either `supervised`, `dpo`, or `reinforcement`. + // + // Any of "supervised", "dpo", "reinforcement". + Type string `json:"type,required"` + // Configuration for the DPO fine-tuning method. + Dpo DpoMethod `json:"dpo"` + // Configuration for the reinforcement fine-tuning method. + Reinforcement ReinforcementMethod `json:"reinforcement"` + // Configuration for the supervised fine-tuning method. + Supervised SupervisedMethod `json:"supervised"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Type respjson.Field + Dpo respjson.Field + Reinforcement respjson.Field + Supervised respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobMethod) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobMethod) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Fine-tuning job event object +type FineTuningJobEvent struct { + // The object identifier. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the fine-tuning job was created. + CreatedAt int64 `json:"created_at,required"` + // The log level of the event. + // + // Any of "info", "warn", "error". + Level FineTuningJobEventLevel `json:"level,required"` + // The message of the event. + Message string `json:"message,required"` + // The object type, which is always "fine_tuning.job.event". + Object constant.FineTuningJobEvent `json:"object,required"` + // The data associated with the event. + Data any `json:"data"` + // The type of event. + // + // Any of "message", "metrics". + Type FineTuningJobEventType `json:"type"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Level respjson.Field + Message respjson.Field + Object respjson.Field + Data respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobEvent) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The log level of the event. 
+type FineTuningJobEventLevel string + +const ( + FineTuningJobEventLevelInfo FineTuningJobEventLevel = "info" + FineTuningJobEventLevelWarn FineTuningJobEventLevel = "warn" + FineTuningJobEventLevelError FineTuningJobEventLevel = "error" +) + +// The type of event. +type FineTuningJobEventType string + +const ( + FineTuningJobEventTypeMessage FineTuningJobEventType = "message" + FineTuningJobEventTypeMetrics FineTuningJobEventType = "metrics" +) + +// The settings for your integration with Weights and Biases. This payload +// specifies the project that metrics will be sent to. Optionally, you can set an +// explicit display name for your run, add tags to your run, and set a default +// entity (team, username, etc) to be associated with your run. +type FineTuningJobWandbIntegration struct { + // The name of the project that the new run will be created under. + Project string `json:"project,required"` + // The entity to use for the run. This allows you to set the team or username of + // the WandB user that you would like associated with the run. If not set, the + // default entity for the registered WandB API key is used. + Entity string `json:"entity,nullable"` + // A display name to set for the run. If not set, we will use the Job ID as the + // name. + Name string `json:"name,nullable"` + // A list of tags to be attached to the newly created run. These tags are passed + // through directly to WandB. Some default tags are generated by OpenAI: + // "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + Tags []string `json:"tags"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Project respjson.Field + Entity respjson.Field + Name respjson.Field + Tags respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobWandbIntegration) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobWandbIntegration) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningJobWandbIntegrationObject struct { + // The type of the integration being enabled for the fine-tuning job + Type constant.Wandb `json:"type,required"` + // The settings for your integration with Weights and Biases. This payload + // specifies the project that metrics will be sent to. Optionally, you can set an + // explicit display name for your run, add tags to your run, and set a default + // entity (team, username, etc) to be associated with your run. + Wandb FineTuningJobWandbIntegration `json:"wandb,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + Wandb respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobWandbIntegrationObject) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobWandbIntegrationObject) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningJobNewParams struct { + // The name of the model to fine-tune. You can select one of the + // [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned). + Model FineTuningJobNewParamsModel `json:"model,omitzero,required"` + // The ID of an uploaded file that contains training data. + // + // See [upload file](https://platform.openai.com/docs/api-reference/files/create) + // for how to upload a file. + // + // Your dataset must be formatted as a JSONL file. 
Additionally, you must upload + // your file with the purpose `fine-tune`. + // + // The contents of the file should differ depending on if the model uses the + // [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input), + // [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) + // format, or if the fine-tuning method uses the + // [preference](https://platform.openai.com/docs/api-reference/fine-tuning/preference-input) + // format. + // + // See the + // [fine-tuning guide](https://platform.openai.com/docs/guides/model-optimization) + // for more details. + TrainingFile string `json:"training_file,required"` + // The seed controls the reproducibility of the job. Passing in the same seed and + // job parameters should produce the same results, but may differ in rare cases. If + // a seed is not specified, one will be generated for you. + Seed param.Opt[int64] `json:"seed,omitzero"` + // A string of up to 64 characters that will be added to your fine-tuned model + // name. + // + // For example, a `suffix` of "custom-model-name" would produce a model name like + // `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`. + Suffix param.Opt[string] `json:"suffix,omitzero"` + // The ID of an uploaded file that contains validation data. + // + // If you provide this file, the data is used to generate validation metrics + // periodically during fine-tuning. These metrics can be viewed in the fine-tuning + // results file. The same data should not be present in both train and validation + // files. + // + // Your dataset must be formatted as a JSONL file. You must upload your file with + // the purpose `fine-tune`. + // + // See the + // [fine-tuning guide](https://platform.openai.com/docs/guides/model-optimization) + // for more details. + ValidationFile param.Opt[string] `json:"validation_file,omitzero"` + // A list of integrations to enable for your fine-tuning job. 
+ Integrations []FineTuningJobNewParamsIntegration `json:"integrations,omitzero"` + // Set of 16 key-value pairs that can be attached to an object. This can be useful + // for storing additional information about the object in a structured format, and + // querying for objects via API or the dashboard. + // + // Keys are strings with a maximum length of 64 characters. Values are strings with + // a maximum length of 512 characters. + Metadata shared.Metadata `json:"metadata,omitzero"` + // The hyperparameters used for the fine-tuning job. This value is now deprecated + // in favor of `method`, and should be passed in under the `method` parameter. + Hyperparameters FineTuningJobNewParamsHyperparameters `json:"hyperparameters,omitzero"` + // The method used for fine-tuning. + Method FineTuningJobNewParamsMethod `json:"method,omitzero"` + paramObj +} + +func (r FineTuningJobNewParams) MarshalJSON() (data []byte, err error) { + type shadow FineTuningJobNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningJobNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The name of the model to fine-tune. You can select one of the +// [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned). +type FineTuningJobNewParamsModel string + +const ( + FineTuningJobNewParamsModelBabbage002 FineTuningJobNewParamsModel = "babbage-002" + FineTuningJobNewParamsModelDavinci002 FineTuningJobNewParamsModel = "davinci-002" + FineTuningJobNewParamsModelGPT3_5Turbo FineTuningJobNewParamsModel = "gpt-3.5-turbo" + FineTuningJobNewParamsModelGPT4oMini FineTuningJobNewParamsModel = "gpt-4o-mini" +) + +// The hyperparameters used for the fine-tuning job. This value is now deprecated +// in favor of `method`, and should be passed in under the `method` parameter. +// +// Deprecated: deprecated +type FineTuningJobNewParamsHyperparameters struct { + // Number of examples in each batch. 
A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize FineTuningJobNewParamsHyperparametersBatchSizeUnion `json:"batch_size,omitzero"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs FineTuningJobNewParamsHyperparametersNEpochsUnion `json:"n_epochs,omitzero"` + paramObj +} + +func (r FineTuningJobNewParamsHyperparameters) MarshalJSON() (data []byte, err error) { + type shadow FineTuningJobNewParamsHyperparameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningJobNewParamsHyperparameters) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningJobNewParamsHyperparametersBatchSizeUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningJobNewParamsHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *FineTuningJobNewParamsHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningJobNewParamsHyperparametersBatchSizeUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfFloat param.Opt[float64] `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfFloat) +} +func (u *FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfFloat) { + return &u.OfFloat.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningJobNewParamsHyperparametersNEpochsUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningJobNewParamsHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *FineTuningJobNewParamsHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningJobNewParamsHyperparametersNEpochsUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// The properties Type, Wandb are required. +type FineTuningJobNewParamsIntegration struct { + // The settings for your integration with Weights and Biases. This payload + // specifies the project that metrics will be sent to. 
Optionally, you can set an + // explicit display name for your run, add tags to your run, and set a default + // entity (team, username, etc) to be associated with your run. + Wandb FineTuningJobNewParamsIntegrationWandb `json:"wandb,omitzero,required"` + // The type of integration to enable. Currently, only "wandb" (Weights and Biases) + // is supported. + // + // This field can be elided, and will marshal its zero value as "wandb". + Type constant.Wandb `json:"type,required"` + paramObj +} + +func (r FineTuningJobNewParamsIntegration) MarshalJSON() (data []byte, err error) { + type shadow FineTuningJobNewParamsIntegration + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningJobNewParamsIntegration) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The settings for your integration with Weights and Biases. This payload +// specifies the project that metrics will be sent to. Optionally, you can set an +// explicit display name for your run, add tags to your run, and set a default +// entity (team, username, etc) to be associated with your run. +// +// The property Project is required. +type FineTuningJobNewParamsIntegrationWandb struct { + // The name of the project that the new run will be created under. + Project string `json:"project,required"` + // The entity to use for the run. This allows you to set the team or username of + // the WandB user that you would like associated with the run. If not set, the + // default entity for the registered WandB API key is used. + Entity param.Opt[string] `json:"entity,omitzero"` + // A display name to set for the run. If not set, we will use the Job ID as the + // name. + Name param.Opt[string] `json:"name,omitzero"` + // A list of tags to be attached to the newly created run. These tags are passed + // through directly to WandB. Some default tags are generated by OpenAI: + // "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". 
+ Tags []string `json:"tags,omitzero"` + paramObj +} + +func (r FineTuningJobNewParamsIntegrationWandb) MarshalJSON() (data []byte, err error) { + type shadow FineTuningJobNewParamsIntegrationWandb + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningJobNewParamsIntegrationWandb) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The method used for fine-tuning. +// +// The property Type is required. +type FineTuningJobNewParamsMethod struct { + // The type of method. Is either `supervised`, `dpo`, or `reinforcement`. + // + // Any of "supervised", "dpo", "reinforcement". + Type string `json:"type,omitzero,required"` + // Configuration for the DPO fine-tuning method. + Dpo DpoMethodParam `json:"dpo,omitzero"` + // Configuration for the reinforcement fine-tuning method. + Reinforcement ReinforcementMethodParam `json:"reinforcement,omitzero"` + // Configuration for the supervised fine-tuning method. + Supervised SupervisedMethodParam `json:"supervised,omitzero"` + paramObj +} + +func (r FineTuningJobNewParamsMethod) MarshalJSON() (data []byte, err error) { + type shadow FineTuningJobNewParamsMethod + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningJobNewParamsMethod) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FineTuningJobNewParamsMethod]( + "type", "supervised", "dpo", "reinforcement", + ) +} + +type FineTuningJobListParams struct { + // Identifier for the last job from the previous pagination request. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // Number of fine-tuning jobs to retrieve. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + // Optional metadata filter. To filter, use the syntax `metadata[k]=v`. + // Alternatively, set `metadata=null` to indicate no metadata. 
+ Metadata map[string]string `query:"metadata,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FineTuningJobListParams]'s query parameters as +// `url.Values`. +func (r FineTuningJobListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +type FineTuningJobListEventsParams struct { + // Identifier for the last event from the previous pagination request. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // Number of events to retrieve. + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FineTuningJobListEventsParams]'s query parameters as +// `url.Values`. +func (r FineTuningJobListEventsParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} diff --git a/vendor/github.com/openai/openai-go/finetuningjobcheckpoint.go b/vendor/github.com/openai/openai-go/finetuningjobcheckpoint.go new file mode 100644 index 0000000000..69ef75da23 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningjobcheckpoint.go @@ -0,0 +1,149 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/pagination" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared/constant" +) + +// FineTuningJobCheckpointService contains methods and other services that help +// with interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningJobCheckpointService] method instead. +type FineTuningJobCheckpointService struct { + Options []option.RequestOption +} + +// NewFineTuningJobCheckpointService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewFineTuningJobCheckpointService(opts ...option.RequestOption) (r FineTuningJobCheckpointService) { + r = FineTuningJobCheckpointService{} + r.Options = opts + return +} + +// List checkpoints for a fine-tuning job. +func (r *FineTuningJobCheckpointService) List(ctx context.Context, fineTuningJobID string, query FineTuningJobCheckpointListParams, opts ...option.RequestOption) (res *pagination.CursorPage[FineTuningJobCheckpoint], err error) { + var raw *http.Response + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...) 
+ if fineTuningJobID == "" { + err = errors.New("missing required fine_tuning_job_id parameter") + return + } + path := fmt.Sprintf("fine_tuning/jobs/%s/checkpoints", fineTuningJobID) + cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...) + if err != nil { + return nil, err + } + err = cfg.Execute() + if err != nil { + return nil, err + } + res.SetPageConfig(cfg, raw) + return res, nil +} + +// List checkpoints for a fine-tuning job. +func (r *FineTuningJobCheckpointService) ListAutoPaging(ctx context.Context, fineTuningJobID string, query FineTuningJobCheckpointListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[FineTuningJobCheckpoint] { + return pagination.NewCursorPageAutoPager(r.List(ctx, fineTuningJobID, query, opts...)) +} + +// The `fine_tuning.job.checkpoint` object represents a model checkpoint for a +// fine-tuning job that is ready to use. +type FineTuningJobCheckpoint struct { + // The checkpoint identifier, which can be referenced in the API endpoints. + ID string `json:"id,required"` + // The Unix timestamp (in seconds) for when the checkpoint was created. + CreatedAt int64 `json:"created_at,required"` + // The name of the fine-tuned checkpoint model that is created. + FineTunedModelCheckpoint string `json:"fine_tuned_model_checkpoint,required"` + // The name of the fine-tuning job that this checkpoint was created from. + FineTuningJobID string `json:"fine_tuning_job_id,required"` + // Metrics at the step number during the fine-tuning job. + Metrics FineTuningJobCheckpointMetrics `json:"metrics,required"` + // The object type, which is always "fine_tuning.job.checkpoint". + Object constant.FineTuningJobCheckpoint `json:"object,required"` + // The step number that the checkpoint was created at. + StepNumber int64 `json:"step_number,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ID respjson.Field + CreatedAt respjson.Field + FineTunedModelCheckpoint respjson.Field + FineTuningJobID respjson.Field + Metrics respjson.Field + Object respjson.Field + StepNumber respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobCheckpoint) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobCheckpoint) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Metrics at the step number during the fine-tuning job. +type FineTuningJobCheckpointMetrics struct { + FullValidLoss float64 `json:"full_valid_loss"` + FullValidMeanTokenAccuracy float64 `json:"full_valid_mean_token_accuracy"` + Step float64 `json:"step"` + TrainLoss float64 `json:"train_loss"` + TrainMeanTokenAccuracy float64 `json:"train_mean_token_accuracy"` + ValidLoss float64 `json:"valid_loss"` + ValidMeanTokenAccuracy float64 `json:"valid_mean_token_accuracy"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + FullValidLoss respjson.Field + FullValidMeanTokenAccuracy respjson.Field + Step respjson.Field + TrainLoss respjson.Field + TrainMeanTokenAccuracy respjson.Field + ValidLoss respjson.Field + ValidMeanTokenAccuracy respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningJobCheckpointMetrics) RawJSON() string { return r.JSON.raw } +func (r *FineTuningJobCheckpointMetrics) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningJobCheckpointListParams struct { + // Identifier for the last checkpoint ID from the previous pagination request. + After param.Opt[string] `query:"after,omitzero" json:"-"` + // Number of checkpoints to retrieve. 
+ Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FineTuningJobCheckpointListParams]'s query parameters as +// `url.Values`. +func (r FineTuningJobCheckpointListParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatBrackets, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} diff --git a/vendor/github.com/openai/openai-go/finetuningmethod.go b/vendor/github.com/openai/openai-go/finetuningmethod.go new file mode 100644 index 0000000000..b315a9dc34 --- /dev/null +++ b/vendor/github.com/openai/openai-go/finetuningmethod.go @@ -0,0 +1,1487 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "encoding/json" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/shared/constant" +) + +// FineTuningMethodService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningMethodService] method instead. +type FineTuningMethodService struct { + Options []option.RequestOption +} + +// NewFineTuningMethodService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewFineTuningMethodService(opts ...option.RequestOption) (r FineTuningMethodService) { + r = FineTuningMethodService{} + r.Options = opts + return +} + +// The hyperparameters used for the DPO fine-tuning job. 
+type DpoHyperparametersResp struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize DpoHyperparametersBatchSizeUnionResp `json:"batch_size"` + // The beta value for the DPO method. A higher beta value will increase the weight + // of the penalty between the policy and reference model. + Beta DpoHyperparametersBetaUnionResp `json:"beta"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier DpoHyperparametersLearningRateMultiplierUnionResp `json:"learning_rate_multiplier"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs DpoHyperparametersNEpochsUnionResp `json:"n_epochs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + BatchSize respjson.Field + Beta respjson.Field + LearningRateMultiplier respjson.Field + NEpochs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r DpoHyperparametersResp) RawJSON() string { return r.JSON.raw } +func (r *DpoHyperparametersResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this DpoHyperparametersResp to a DpoHyperparameters. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// DpoHyperparameters.Overrides() +func (r DpoHyperparametersResp) ToParam() DpoHyperparameters { + return param.Override[DpoHyperparameters](json.RawMessage(r.RawJSON())) +} + +// DpoHyperparametersBatchSizeUnionResp contains all possible properties and values +// from [constant.Auto], [int64]. 
+// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfInt] +type DpoHyperparametersBatchSizeUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [int64] instead of an object. + OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u DpoHyperparametersBatchSizeUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u DpoHyperparametersBatchSizeUnionResp) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u DpoHyperparametersBatchSizeUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *DpoHyperparametersBatchSizeUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// DpoHyperparametersBetaUnionResp contains all possible properties and values from +// [constant.Auto], [float64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfFloat] +type DpoHyperparametersBetaUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [float64] instead of an object. 
+ OfFloat float64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfFloat respjson.Field + raw string + } `json:"-"` +} + +func (u DpoHyperparametersBetaUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u DpoHyperparametersBetaUnionResp) AsFloat() (v float64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u DpoHyperparametersBetaUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *DpoHyperparametersBetaUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// DpoHyperparametersLearningRateMultiplierUnionResp contains all possible +// properties and values from [constant.Auto], [float64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfFloat] +type DpoHyperparametersLearningRateMultiplierUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [float64] instead of an object. 
+ OfFloat float64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfFloat respjson.Field + raw string + } `json:"-"` +} + +func (u DpoHyperparametersLearningRateMultiplierUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u DpoHyperparametersLearningRateMultiplierUnionResp) AsFloat() (v float64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u DpoHyperparametersLearningRateMultiplierUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *DpoHyperparametersLearningRateMultiplierUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// DpoHyperparametersNEpochsUnionResp contains all possible properties and values +// from [constant.Auto], [int64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfInt] +type DpoHyperparametersNEpochsUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [int64] instead of an object. 
+ OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u DpoHyperparametersNEpochsUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u DpoHyperparametersNEpochsUnionResp) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u DpoHyperparametersNEpochsUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *DpoHyperparametersNEpochsUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The hyperparameters used for the DPO fine-tuning job. +type DpoHyperparameters struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize DpoHyperparametersBatchSizeUnion `json:"batch_size,omitzero"` + // The beta value for the DPO method. A higher beta value will increase the weight + // of the penalty between the policy and reference model. + Beta DpoHyperparametersBetaUnion `json:"beta,omitzero"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier DpoHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs DpoHyperparametersNEpochsUnion `json:"n_epochs,omitzero"` + paramObj +} + +func (r DpoHyperparameters) MarshalJSON() (data []byte, err error) { + type shadow DpoHyperparameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *DpoHyperparameters) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type DpoHyperparametersBatchSizeUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u DpoHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *DpoHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *DpoHyperparametersBatchSizeUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type DpoHyperparametersBetaUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfFloat param.Opt[float64] `json:",omitzero,inline"` + paramUnion +} + +func (u DpoHyperparametersBetaUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfFloat) +} +func (u *DpoHyperparametersBetaUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *DpoHyperparametersBetaUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfFloat) { + return &u.OfFloat.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type DpoHyperparametersLearningRateMultiplierUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfFloat param.Opt[float64] `json:",omitzero,inline"` + paramUnion +} + +func (u DpoHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfFloat) +} +func (u *DpoHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *DpoHyperparametersLearningRateMultiplierUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfFloat) { + return &u.OfFloat.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type DpoHyperparametersNEpochsUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u DpoHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *DpoHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *DpoHyperparametersNEpochsUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Configuration for the DPO fine-tuning method. +type DpoMethod struct { + // The hyperparameters used for the DPO fine-tuning job. + Hyperparameters DpoHyperparametersResp `json:"hyperparameters"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Hyperparameters respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r DpoMethod) RawJSON() string { return r.JSON.raw } +func (r *DpoMethod) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this DpoMethod to a DpoMethodParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// DpoMethodParam.Overrides() +func (r DpoMethod) ToParam() DpoMethodParam { + return param.Override[DpoMethodParam](json.RawMessage(r.RawJSON())) +} + +// Configuration for the DPO fine-tuning method. +type DpoMethodParam struct { + // The hyperparameters used for the DPO fine-tuning job. + Hyperparameters DpoHyperparameters `json:"hyperparameters,omitzero"` + paramObj +} + +func (r DpoMethodParam) MarshalJSON() (data []byte, err error) { + type shadow DpoMethodParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *DpoMethodParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The hyperparameters used for the reinforcement fine-tuning job. +type ReinforcementHyperparametersResp struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize ReinforcementHyperparametersBatchSizeUnionResp `json:"batch_size"` + // Multiplier on amount of compute used for exploring search space during training. + ComputeMultiplier ReinforcementHyperparametersComputeMultiplierUnionResp `json:"compute_multiplier"` + // The number of training steps between evaluation runs. + EvalInterval ReinforcementHyperparametersEvalIntervalUnionResp `json:"eval_interval"` + // Number of evaluation samples to generate per training step. 
+ EvalSamples ReinforcementHyperparametersEvalSamplesUnionResp `json:"eval_samples"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier ReinforcementHyperparametersLearningRateMultiplierUnionResp `json:"learning_rate_multiplier"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs ReinforcementHyperparametersNEpochsUnionResp `json:"n_epochs"` + // Level of reasoning effort. + // + // Any of "default", "low", "medium", "high". + ReasoningEffort ReinforcementHyperparametersReasoningEffort `json:"reasoning_effort"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + BatchSize respjson.Field + ComputeMultiplier respjson.Field + EvalInterval respjson.Field + EvalSamples respjson.Field + LearningRateMultiplier respjson.Field + NEpochs respjson.Field + ReasoningEffort respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ReinforcementHyperparametersResp) RawJSON() string { return r.JSON.raw } +func (r *ReinforcementHyperparametersResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ReinforcementHyperparametersResp to a +// ReinforcementHyperparameters. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ReinforcementHyperparameters.Overrides() +func (r ReinforcementHyperparametersResp) ToParam() ReinforcementHyperparameters { + return param.Override[ReinforcementHyperparameters](json.RawMessage(r.RawJSON())) +} + +// ReinforcementHyperparametersBatchSizeUnionResp contains all possible properties +// and values from [constant.Auto], [int64]. 
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfAuto OfInt]
+type ReinforcementHyperparametersBatchSizeUnionResp struct {
+	// This field will be present if the value is a [constant.Auto] instead of an
+	// object.
+	OfAuto constant.Auto `json:",inline"`
+	// This field will be present if the value is a [int64] instead of an object.
+	OfInt int64 `json:",inline"`
+	JSON struct {
+		OfAuto respjson.Field
+		OfInt  respjson.Field
+		raw    string
+	} `json:"-"`
+}
+
+func (u ReinforcementHyperparametersBatchSizeUnionResp) AsAuto() (v constant.Auto) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+func (u ReinforcementHyperparametersBatchSizeUnionResp) AsInt() (v int64) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersBatchSizeUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersBatchSizeUnionResp) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersComputeMultiplierUnionResp contains all possible
+// properties and values from [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfAuto OfFloat]
+type ReinforcementHyperparametersComputeMultiplierUnionResp struct {
+	// This field will be present if the value is a [constant.Auto] instead of an
+	// object.
+	OfAuto constant.Auto `json:",inline"`
+	// This field will be present if the value is a [float64] instead of an object.
+	OfFloat float64 `json:",inline"`
+	JSON struct {
+		OfAuto  respjson.Field
+		OfFloat respjson.Field
+		raw     string
+	} `json:"-"`
+}
+
+func (u ReinforcementHyperparametersComputeMultiplierUnionResp) AsAuto() (v constant.Auto) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+func (u ReinforcementHyperparametersComputeMultiplierUnionResp) AsFloat() (v float64) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersComputeMultiplierUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersComputeMultiplierUnionResp) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersEvalIntervalUnionResp contains all possible
+// properties and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfAuto OfInt]
+type ReinforcementHyperparametersEvalIntervalUnionResp struct {
+	// This field will be present if the value is a [constant.Auto] instead of an
+	// object.
+	OfAuto constant.Auto `json:",inline"`
+	// This field will be present if the value is a [int64] instead of an object.
+	OfInt int64 `json:",inline"`
+	JSON struct {
+		OfAuto respjson.Field
+		OfInt  respjson.Field
+		raw    string
+	} `json:"-"`
+}
+
+func (u ReinforcementHyperparametersEvalIntervalUnionResp) AsAuto() (v constant.Auto) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+func (u ReinforcementHyperparametersEvalIntervalUnionResp) AsInt() (v int64) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersEvalIntervalUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersEvalIntervalUnionResp) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersEvalSamplesUnionResp contains all possible
+// properties and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfAuto OfInt]
+type ReinforcementHyperparametersEvalSamplesUnionResp struct {
+	// This field will be present if the value is a [constant.Auto] instead of an
+	// object.
+	OfAuto constant.Auto `json:",inline"`
+	// This field will be present if the value is a [int64] instead of an object.
+	OfInt int64 `json:",inline"`
+	JSON struct {
+		OfAuto respjson.Field
+		OfInt  respjson.Field
+		raw    string
+	} `json:"-"`
+}
+
+func (u ReinforcementHyperparametersEvalSamplesUnionResp) AsAuto() (v constant.Auto) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+func (u ReinforcementHyperparametersEvalSamplesUnionResp) AsInt() (v int64) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersEvalSamplesUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersEvalSamplesUnionResp) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersLearningRateMultiplierUnionResp contains all
+// possible properties and values from [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfAuto OfFloat]
+type ReinforcementHyperparametersLearningRateMultiplierUnionResp struct {
+	// This field will be present if the value is a [constant.Auto] instead of an
+	// object.
+	OfAuto constant.Auto `json:",inline"`
+	// This field will be present if the value is a [float64] instead of an object.
+	OfFloat float64 `json:",inline"`
+	JSON struct {
+		OfAuto  respjson.Field
+		OfFloat respjson.Field
+		raw     string
+	} `json:"-"`
+}
+
+func (u ReinforcementHyperparametersLearningRateMultiplierUnionResp) AsAuto() (v constant.Auto) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+func (u ReinforcementHyperparametersLearningRateMultiplierUnionResp) AsFloat() (v float64) {
+	apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+	return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersLearningRateMultiplierUnionResp) RawJSON() string {
+	return u.JSON.raw
+}
+
+func (r *ReinforcementHyperparametersLearningRateMultiplierUnionResp) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersNEpochsUnionResp contains all possible properties
+// and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfAuto OfInt]
+type ReinforcementHyperparametersNEpochsUnionResp struct {
+	// This field will be present if the value is a [constant.Auto] instead of an
+	// object.
+	OfAuto constant.Auto `json:",inline"`
+	// This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u ReinforcementHyperparametersNEpochsUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ReinforcementHyperparametersNEpochsUnionResp) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ReinforcementHyperparametersNEpochsUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *ReinforcementHyperparametersNEpochsUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Level of reasoning effort. +type ReinforcementHyperparametersReasoningEffort string + +const ( + ReinforcementHyperparametersReasoningEffortDefault ReinforcementHyperparametersReasoningEffort = "default" + ReinforcementHyperparametersReasoningEffortLow ReinforcementHyperparametersReasoningEffort = "low" + ReinforcementHyperparametersReasoningEffortMedium ReinforcementHyperparametersReasoningEffort = "medium" + ReinforcementHyperparametersReasoningEffortHigh ReinforcementHyperparametersReasoningEffort = "high" +) + +// The hyperparameters used for the reinforcement fine-tuning job. +type ReinforcementHyperparameters struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize ReinforcementHyperparametersBatchSizeUnion `json:"batch_size,omitzero"` + // Multiplier on amount of compute used for exploring search space during training. + ComputeMultiplier ReinforcementHyperparametersComputeMultiplierUnion `json:"compute_multiplier,omitzero"` + // The number of training steps between evaluation runs. + EvalInterval ReinforcementHyperparametersEvalIntervalUnion `json:"eval_interval,omitzero"` + // Number of evaluation samples to generate per training step. 
+ EvalSamples ReinforcementHyperparametersEvalSamplesUnion `json:"eval_samples,omitzero"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier ReinforcementHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs ReinforcementHyperparametersNEpochsUnion `json:"n_epochs,omitzero"` + // Level of reasoning effort. + // + // Any of "default", "low", "medium", "high". + ReasoningEffort ReinforcementHyperparametersReasoningEffort `json:"reasoning_effort,omitzero"` + paramObj +} + +func (r ReinforcementHyperparameters) MarshalJSON() (data []byte, err error) { + type shadow ReinforcementHyperparameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ReinforcementHyperparameters) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ReinforcementHyperparametersBatchSizeUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *ReinforcementHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementHyperparametersBatchSizeUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ReinforcementHyperparametersComputeMultiplierUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfFloat param.Opt[float64] `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementHyperparametersComputeMultiplierUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfFloat) +} +func (u *ReinforcementHyperparametersComputeMultiplierUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementHyperparametersComputeMultiplierUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfFloat) { + return &u.OfFloat.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ReinforcementHyperparametersEvalIntervalUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementHyperparametersEvalIntervalUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *ReinforcementHyperparametersEvalIntervalUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementHyperparametersEvalIntervalUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ReinforcementHyperparametersEvalSamplesUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementHyperparametersEvalSamplesUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *ReinforcementHyperparametersEvalSamplesUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementHyperparametersEvalSamplesUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type ReinforcementHyperparametersLearningRateMultiplierUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfFloat param.Opt[float64] `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfFloat) +} +func (u *ReinforcementHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementHyperparametersLearningRateMultiplierUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfFloat) { + return &u.OfFloat.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ReinforcementHyperparametersNEpochsUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *ReinforcementHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementHyperparametersNEpochsUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Configuration for the reinforcement fine-tuning method. +type ReinforcementMethod struct { + // The grader used for the fine-tuning job. + Grader ReinforcementMethodGraderUnion `json:"grader,required"` + // The hyperparameters used for the reinforcement fine-tuning job. + Hyperparameters ReinforcementHyperparametersResp `json:"hyperparameters"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Grader respjson.Field + Hyperparameters respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ReinforcementMethod) RawJSON() string { return r.JSON.raw } +func (r *ReinforcementMethod) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ReinforcementMethod to a ReinforcementMethodParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// ReinforcementMethodParam.Overrides() +func (r ReinforcementMethod) ToParam() ReinforcementMethodParam { + return param.Override[ReinforcementMethodParam](json.RawMessage(r.RawJSON())) +} + +// ReinforcementMethodGraderUnion contains all possible properties and values from +// [StringCheckGrader], [TextSimilarityGrader], [PythonGrader], [ScoreModelGrader], +// [MultiGrader]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type ReinforcementMethodGraderUnion struct { + // This field is a union of [string], [string], [[]ScoreModelGraderInput] + Input ReinforcementMethodGraderUnionInput `json:"input"` + Name string `json:"name"` + // This field is from variant [StringCheckGrader]. + Operation StringCheckGraderOperation `json:"operation"` + Reference string `json:"reference"` + Type string `json:"type"` + // This field is from variant [TextSimilarityGrader]. + EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric"` + // This field is from variant [PythonGrader]. + Source string `json:"source"` + // This field is from variant [PythonGrader]. + ImageTag string `json:"image_tag"` + // This field is from variant [ScoreModelGrader]. + Model string `json:"model"` + // This field is from variant [ScoreModelGrader]. + Range []float64 `json:"range"` + // This field is from variant [ScoreModelGrader]. + SamplingParams any `json:"sampling_params"` + // This field is from variant [MultiGrader]. + CalculateOutput string `json:"calculate_output"` + // This field is from variant [MultiGrader]. 
+ Graders MultiGraderGradersUnion `json:"graders"` + JSON struct { + Input respjson.Field + Name respjson.Field + Operation respjson.Field + Reference respjson.Field + Type respjson.Field + EvaluationMetric respjson.Field + Source respjson.Field + ImageTag respjson.Field + Model respjson.Field + Range respjson.Field + SamplingParams respjson.Field + CalculateOutput respjson.Field + Graders respjson.Field + raw string + } `json:"-"` +} + +func (u ReinforcementMethodGraderUnion) AsStringCheckGrader() (v StringCheckGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ReinforcementMethodGraderUnion) AsTextSimilarityGrader() (v TextSimilarityGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ReinforcementMethodGraderUnion) AsPythonGrader() (v PythonGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ReinforcementMethodGraderUnion) AsScoreModelGrader() (v ScoreModelGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ReinforcementMethodGraderUnion) AsMultiGrader() (v MultiGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ReinforcementMethodGraderUnion) RawJSON() string { return u.JSON.raw } + +func (r *ReinforcementMethodGraderUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ReinforcementMethodGraderUnionInput is an implicit subunion of +// [ReinforcementMethodGraderUnion]. ReinforcementMethodGraderUnionInput provides +// convenient access to the sub-properties of the union. +// +// For type safety it is recommended to directly use a variant of the +// [ReinforcementMethodGraderUnion]. 
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: [OfString OfScoreModelGraderInputArray]
+type ReinforcementMethodGraderUnionInput struct {
+	// This field will be present if the value is a [string] instead of an object.
+	OfString string `json:",inline"`
+	// This field will be present if the value is a [[]ScoreModelGraderInput] instead
+	// of an object.
+	OfScoreModelGraderInputArray []ScoreModelGraderInput `json:",inline"`
+	JSON struct {
+		OfString                     respjson.Field
+		OfScoreModelGraderInputArray respjson.Field
+		raw                          string
+	} `json:"-"`
+}
+
+func (r *ReinforcementMethodGraderUnionInput) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// Configuration for the reinforcement fine-tuning method.
+//
+// The property Grader is required.
+type ReinforcementMethodParam struct {
+	// The grader used for the fine-tuning job.
+	Grader ReinforcementMethodGraderUnionParam `json:"grader,omitzero,required"`
+	// The hyperparameters used for the reinforcement fine-tuning job.
+	Hyperparameters ReinforcementHyperparameters `json:"hyperparameters,omitzero"`
+	paramObj
+}
+
+func (r ReinforcementMethodParam) MarshalJSON() (data []byte, err error) {
+	type shadow ReinforcementMethodParam
+	return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ReinforcementMethodParam) UnmarshalJSON(data []byte) error {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementMethodGraderUnionParam struct { + OfStringCheckGrader *StringCheckGraderParam `json:",omitzero,inline"` + OfTextSimilarityGrader *TextSimilarityGraderParam `json:",omitzero,inline"` + OfPythonGrader *PythonGraderParam `json:",omitzero,inline"` + OfScoreModelGrader *ScoreModelGraderParam `json:",omitzero,inline"` + OfMultiGrader *MultiGraderParam `json:",omitzero,inline"` + paramUnion +} + +func (u ReinforcementMethodGraderUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfStringCheckGrader, + u.OfTextSimilarityGrader, + u.OfPythonGrader, + u.OfScoreModelGrader, + u.OfMultiGrader) +} +func (u *ReinforcementMethodGraderUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ReinforcementMethodGraderUnionParam) asAny() any { + if !param.IsOmitted(u.OfStringCheckGrader) { + return u.OfStringCheckGrader + } else if !param.IsOmitted(u.OfTextSimilarityGrader) { + return u.OfTextSimilarityGrader + } else if !param.IsOmitted(u.OfPythonGrader) { + return u.OfPythonGrader + } else if !param.IsOmitted(u.OfScoreModelGrader) { + return u.OfScoreModelGrader + } else if !param.IsOmitted(u.OfMultiGrader) { + return u.OfMultiGrader + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetOperation() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Operation) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetEvaluationMetric() *string { + if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.EvaluationMetric) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u ReinforcementMethodGraderUnionParam) GetSource() *string { + if vt := u.OfPythonGrader; vt != nil { + return &vt.Source + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetImageTag() *string { + if vt := u.OfPythonGrader; vt != nil && vt.ImageTag.Valid() { + return &vt.ImageTag.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetModel() *string { + if vt := u.OfScoreModelGrader; vt != nil { + return &vt.Model + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetRange() []float64 { + if vt := u.OfScoreModelGrader; vt != nil { + return vt.Range + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetSamplingParams() *any { + if vt := u.OfScoreModelGrader; vt != nil { + return &vt.SamplingParams + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetCalculateOutput() *string { + if vt := u.OfMultiGrader; vt != nil { + return &vt.CalculateOutput + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetGraders() *MultiGraderGradersUnionParam { + if vt := u.OfMultiGrader; vt != nil { + return &vt.Graders + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u ReinforcementMethodGraderUnionParam) GetName() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfPythonGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfMultiGrader; vt != nil { + return (*string)(&vt.Name) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetReference() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Reference) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Reference) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ReinforcementMethodGraderUnionParam) GetType() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfPythonGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfMultiGrader; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u ReinforcementMethodGraderUnionParam) GetInput() (res reinforcementMethodGraderUnionParamInput) { + if vt := u.OfStringCheckGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfTextSimilarityGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfScoreModelGrader; vt != nil { + res.any = &vt.Input + } + return +} + +// Can have the runtime types [*string], [\*[]ScoreModelGraderInputParam] +type reinforcementMethodGraderUnionParamInput struct{ any } + +// Use the following 
switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *string: +// case *[]openai.ScoreModelGraderInputParam: +// default: +// fmt.Errorf("not present") +// } +func (u reinforcementMethodGraderUnionParamInput) AsAny() any { return u.any } + +// The hyperparameters used for the fine-tuning job. +type SupervisedHyperparametersResp struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize SupervisedHyperparametersBatchSizeUnionResp `json:"batch_size"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier SupervisedHyperparametersLearningRateMultiplierUnionResp `json:"learning_rate_multiplier"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs SupervisedHyperparametersNEpochsUnionResp `json:"n_epochs"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + BatchSize respjson.Field + LearningRateMultiplier respjson.Field + NEpochs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r SupervisedHyperparametersResp) RawJSON() string { return r.JSON.raw } +func (r *SupervisedHyperparametersResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this SupervisedHyperparametersResp to a +// SupervisedHyperparameters. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// SupervisedHyperparameters.Overrides() +func (r SupervisedHyperparametersResp) ToParam() SupervisedHyperparameters { + return param.Override[SupervisedHyperparameters](json.RawMessage(r.RawJSON())) +} + +// SupervisedHyperparametersBatchSizeUnionResp contains all possible properties and +// values from [constant.Auto], [int64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfInt] +type SupervisedHyperparametersBatchSizeUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [int64] instead of an object. + OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u SupervisedHyperparametersBatchSizeUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u SupervisedHyperparametersBatchSizeUnionResp) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u SupervisedHyperparametersBatchSizeUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *SupervisedHyperparametersBatchSizeUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// SupervisedHyperparametersLearningRateMultiplierUnionResp contains all possible +// properties and values from [constant.Auto], [float64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. 
+// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfFloat] +type SupervisedHyperparametersLearningRateMultiplierUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [float64] instead of an object. + OfFloat float64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfFloat respjson.Field + raw string + } `json:"-"` +} + +func (u SupervisedHyperparametersLearningRateMultiplierUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u SupervisedHyperparametersLearningRateMultiplierUnionResp) AsFloat() (v float64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u SupervisedHyperparametersLearningRateMultiplierUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *SupervisedHyperparametersLearningRateMultiplierUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// SupervisedHyperparametersNEpochsUnionResp contains all possible properties and +// values from [constant.Auto], [int64]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfAuto OfInt] +type SupervisedHyperparametersNEpochsUnionResp struct { + // This field will be present if the value is a [constant.Auto] instead of an + // object. + OfAuto constant.Auto `json:",inline"` + // This field will be present if the value is a [int64] instead of an object. 
+ OfInt int64 `json:",inline"` + JSON struct { + OfAuto respjson.Field + OfInt respjson.Field + raw string + } `json:"-"` +} + +func (u SupervisedHyperparametersNEpochsUnionResp) AsAuto() (v constant.Auto) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u SupervisedHyperparametersNEpochsUnionResp) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u SupervisedHyperparametersNEpochsUnionResp) RawJSON() string { return u.JSON.raw } + +func (r *SupervisedHyperparametersNEpochsUnionResp) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The hyperparameters used for the fine-tuning job. +type SupervisedHyperparameters struct { + // Number of examples in each batch. A larger batch size means that model + // parameters are updated less frequently, but with lower variance. + BatchSize SupervisedHyperparametersBatchSizeUnion `json:"batch_size,omitzero"` + // Scaling factor for the learning rate. A smaller learning rate may be useful to + // avoid overfitting. + LearningRateMultiplier SupervisedHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"` + // The number of epochs to train the model for. An epoch refers to one full cycle + // through the training dataset. + NEpochs SupervisedHyperparametersNEpochsUnion `json:"n_epochs,omitzero"` + paramObj +} + +func (r SupervisedHyperparameters) MarshalJSON() (data []byte, err error) { + type shadow SupervisedHyperparameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *SupervisedHyperparameters) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type SupervisedHyperparametersBatchSizeUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u SupervisedHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *SupervisedHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *SupervisedHyperparametersBatchSizeUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type SupervisedHyperparametersLearningRateMultiplierUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfFloat param.Opt[float64] `json:",omitzero,inline"` + paramUnion +} + +func (u SupervisedHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfFloat) +} +func (u *SupervisedHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *SupervisedHyperparametersLearningRateMultiplierUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfFloat) { + return &u.OfFloat.Value + } + return nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type SupervisedHyperparametersNEpochsUnion struct { + // Construct this variant with constant.ValueOf[constant.Auto]() + OfAuto constant.Auto `json:",omitzero,inline"` + OfInt param.Opt[int64] `json:",omitzero,inline"` + paramUnion +} + +func (u SupervisedHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfAuto, u.OfInt) +} +func (u *SupervisedHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *SupervisedHyperparametersNEpochsUnion) asAny() any { + if !param.IsOmitted(u.OfAuto) { + return &u.OfAuto + } else if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } + return nil +} + +// Configuration for the supervised fine-tuning method. +type SupervisedMethod struct { + // The hyperparameters used for the fine-tuning job. + Hyperparameters SupervisedHyperparametersResp `json:"hyperparameters"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Hyperparameters respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r SupervisedMethod) RawJSON() string { return r.JSON.raw } +func (r *SupervisedMethod) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this SupervisedMethod to a SupervisedMethodParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// SupervisedMethodParam.Overrides() +func (r SupervisedMethod) ToParam() SupervisedMethodParam { + return param.Override[SupervisedMethodParam](json.RawMessage(r.RawJSON())) +} + +// Configuration for the supervised fine-tuning method. +type SupervisedMethodParam struct { + // The hyperparameters used for the fine-tuning job. 
+ Hyperparameters SupervisedHyperparameters `json:"hyperparameters,omitzero"` + paramObj +} + +func (r SupervisedMethodParam) MarshalJSON() (data []byte, err error) { + type shadow SupervisedMethodParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *SupervisedMethodParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} diff --git a/vendor/github.com/openai/openai-go/grader.go b/vendor/github.com/openai/openai-go/grader.go new file mode 100644 index 0000000000..0a12b450a9 --- /dev/null +++ b/vendor/github.com/openai/openai-go/grader.go @@ -0,0 +1,28 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "github.com/openai/openai-go/option" +) + +// GraderService contains methods and other services that help with interacting +// with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewGraderService] method instead. +type GraderService struct { + Options []option.RequestOption + GraderModels GraderGraderModelService +} + +// NewGraderService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewGraderService(opts ...option.RequestOption) (r GraderService) { + r = GraderService{} + r.Options = opts + r.GraderModels = NewGraderGraderModelService(opts...) + return +} diff --git a/vendor/github.com/openai/openai-go/gradergradermodel.go b/vendor/github.com/openai/openai-go/gradergradermodel.go new file mode 100644 index 0000000000..27236a69e0 --- /dev/null +++ b/vendor/github.com/openai/openai-go/gradergradermodel.go @@ -0,0 +1,1373 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package openai + +import ( + "encoding/json" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/responses" + "github.com/openai/openai-go/shared/constant" +) + +// GraderGraderModelService contains methods and other services that help with +// interacting with the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewGraderGraderModelService] method instead. +type GraderGraderModelService struct { + Options []option.RequestOption +} + +// NewGraderGraderModelService generates a new service that applies the given +// options to each request. These options are applied after the parent client's +// options (if there is one), and before any request-specific options. +func NewGraderGraderModelService(opts ...option.RequestOption) (r GraderGraderModelService) { + r = GraderGraderModelService{} + r.Options = opts + return +} + +// A LabelModelGrader object which uses a model to assign labels to each item in +// the evaluation. +type LabelModelGrader struct { + Input []LabelModelGraderInput `json:"input,required"` + // The labels to assign to each item in the evaluation. + Labels []string `json:"labels,required"` + // The model to use for the evaluation. Must support structured outputs. + Model string `json:"model,required"` + // The name of the grader. + Name string `json:"name,required"` + // The labels that indicate a passing result. Must be a subset of labels. + PassingLabels []string `json:"passing_labels,required"` + // The object type, which is always `label_model`. + Type constant.LabelModel `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Input respjson.Field + Labels respjson.Field + Model respjson.Field + Name respjson.Field + PassingLabels respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r LabelModelGrader) RawJSON() string { return r.JSON.raw } +func (r *LabelModelGrader) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this LabelModelGrader to a LabelModelGraderParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// LabelModelGraderParam.Overrides() +func (r LabelModelGrader) ToParam() LabelModelGraderParam { + return param.Override[LabelModelGraderParam](json.RawMessage(r.RawJSON())) +} + +// A message input to the model with a role indicating instruction following +// hierarchy. Instructions given with the `developer` or `system` role take +// precedence over instructions given with the `user` role. Messages with the +// `assistant` role are presumed to have been generated by the model in previous +// interactions. +type LabelModelGraderInput struct { + // Inputs to the model - can contain template strings. + Content LabelModelGraderInputContentUnion `json:"content,required"` + // The role of the message input. One of `user`, `assistant`, `system`, or + // `developer`. + // + // Any of "user", "assistant", "system", "developer". + Role string `json:"role,required"` + // The type of the message input. Always `message`. + // + // Any of "message". + Type string `json:"type"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Content respjson.Field + Role respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r LabelModelGraderInput) RawJSON() string { return r.JSON.raw } +func (r *LabelModelGraderInput) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// LabelModelGraderInputContentUnion contains all possible properties and values +// from [string], [responses.ResponseInputText], +// [LabelModelGraderInputContentOutputText], +// [LabelModelGraderInputContentInputImage], [[]any]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfString OfAnArrayOfInputTextAndInputImage] +type LabelModelGraderInputContentUnion struct { + // This field will be present if the value is a [string] instead of an object. + OfString string `json:",inline"` + // This field will be present if the value is a [[]any] instead of an object. + OfAnArrayOfInputTextAndInputImage []any `json:",inline"` + Text string `json:"text"` + Type string `json:"type"` + // This field is from variant [LabelModelGraderInputContentInputImage]. + ImageURL string `json:"image_url"` + // This field is from variant [LabelModelGraderInputContentInputImage]. 
+ Detail string `json:"detail"` + JSON struct { + OfString respjson.Field + OfAnArrayOfInputTextAndInputImage respjson.Field + Text respjson.Field + Type respjson.Field + ImageURL respjson.Field + Detail respjson.Field + raw string + } `json:"-"` +} + +func (u LabelModelGraderInputContentUnion) AsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u LabelModelGraderInputContentUnion) AsInputText() (v responses.ResponseInputText) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u LabelModelGraderInputContentUnion) AsOutputText() (v LabelModelGraderInputContentOutputText) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u LabelModelGraderInputContentUnion) AsInputImage() (v LabelModelGraderInputContentInputImage) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u LabelModelGraderInputContentUnion) AsAnArrayOfInputTextAndInputImage() (v []any) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u LabelModelGraderInputContentUnion) RawJSON() string { return u.JSON.raw } + +func (r *LabelModelGraderInputContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A text output from the model. +type LabelModelGraderInputContentOutputText struct { + // The text output from the model. + Text string `json:"text,required"` + // The type of the output text. Always `output_text`. + Type constant.OutputText `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r LabelModelGraderInputContentOutputText) RawJSON() string { return r.JSON.raw } +func (r *LabelModelGraderInputContentOutputText) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// An image input to the model. +type LabelModelGraderInputContentInputImage struct { + // The URL of the image input. + ImageURL string `json:"image_url,required"` + // The type of the image input. Always `input_image`. + Type constant.InputImage `json:"type,required"` + // The detail level of the image to be sent to the model. One of `high`, `low`, or + // `auto`. Defaults to `auto`. + Detail string `json:"detail"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageURL respjson.Field + Type respjson.Field + Detail respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r LabelModelGraderInputContentInputImage) RawJSON() string { return r.JSON.raw } +func (r *LabelModelGraderInputContentInputImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A LabelModelGrader object which uses a model to assign labels to each item in +// the evaluation. +// +// The properties Input, Labels, Model, Name, PassingLabels, Type are required. +type LabelModelGraderParam struct { + Input []LabelModelGraderInputParam `json:"input,omitzero,required"` + // The labels to assign to each item in the evaluation. + Labels []string `json:"labels,omitzero,required"` + // The model to use for the evaluation. Must support structured outputs. + Model string `json:"model,required"` + // The name of the grader. + Name string `json:"name,required"` + // The labels that indicate a passing result. 
Must be a subset of labels. + PassingLabels []string `json:"passing_labels,omitzero,required"` + // The object type, which is always `label_model`. + // + // This field can be elided, and will marshal its zero value as "label_model". + Type constant.LabelModel `json:"type,required"` + paramObj +} + +func (r LabelModelGraderParam) MarshalJSON() (data []byte, err error) { + type shadow LabelModelGraderParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *LabelModelGraderParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A message input to the model with a role indicating instruction following +// hierarchy. Instructions given with the `developer` or `system` role take +// precedence over instructions given with the `user` role. Messages with the +// `assistant` role are presumed to have been generated by the model in previous +// interactions. +// +// The properties Content, Role are required. +type LabelModelGraderInputParam struct { + // Inputs to the model - can contain template strings. + Content LabelModelGraderInputContentUnionParam `json:"content,omitzero,required"` + // The role of the message input. One of `user`, `assistant`, `system`, or + // `developer`. + // + // Any of "user", "assistant", "system", "developer". + Role string `json:"role,omitzero,required"` + // The type of the message input. Always `message`. + // + // Any of "message". 
+ Type string `json:"type,omitzero"` + paramObj +} + +func (r LabelModelGraderInputParam) MarshalJSON() (data []byte, err error) { + type shadow LabelModelGraderInputParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *LabelModelGraderInputParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[LabelModelGraderInputParam]( + "role", "user", "assistant", "system", "developer", + ) + apijson.RegisterFieldValidator[LabelModelGraderInputParam]( + "type", "message", + ) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type LabelModelGraderInputContentUnionParam struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfInputText *responses.ResponseInputTextParam `json:",omitzero,inline"` + OfOutputText *LabelModelGraderInputContentOutputTextParam `json:",omitzero,inline"` + OfInputImage *LabelModelGraderInputContentInputImageParam `json:",omitzero,inline"` + OfAnArrayOfInputTextAndInputImage []any `json:",omitzero,inline"` + paramUnion +} + +func (u LabelModelGraderInputContentUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, + u.OfInputText, + u.OfOutputText, + u.OfInputImage, + u.OfAnArrayOfInputTextAndInputImage) +} +func (u *LabelModelGraderInputContentUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *LabelModelGraderInputContentUnionParam) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfInputText) { + return u.OfInputText + } else if !param.IsOmitted(u.OfOutputText) { + return u.OfOutputText + } else if !param.IsOmitted(u.OfInputImage) { + return u.OfInputImage + } else if !param.IsOmitted(u.OfAnArrayOfInputTextAndInputImage) { + return &u.OfAnArrayOfInputTextAndInputImage + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u LabelModelGraderInputContentUnionParam) GetImageURL() *string { + if vt := u.OfInputImage; vt != nil { + return &vt.ImageURL + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u LabelModelGraderInputContentUnionParam) GetDetail() *string { + if vt := u.OfInputImage; vt != nil && vt.Detail.Valid() { + return &vt.Detail.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u LabelModelGraderInputContentUnionParam) GetText() *string { + if vt := u.OfInputText; vt != nil { + return (*string)(&vt.Text) + } else if vt := u.OfOutputText; vt != nil { + return (*string)(&vt.Text) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u LabelModelGraderInputContentUnionParam) GetType() *string { + if vt := u.OfInputText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfOutputText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfInputImage; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// A text output from the model. +// +// The properties Text, Type are required. +type LabelModelGraderInputContentOutputTextParam struct { + // The text output from the model. + Text string `json:"text,required"` + // The type of the output text. Always `output_text`. + // + // This field can be elided, and will marshal its zero value as "output_text". + Type constant.OutputText `json:"type,required"` + paramObj +} + +func (r LabelModelGraderInputContentOutputTextParam) MarshalJSON() (data []byte, err error) { + type shadow LabelModelGraderInputContentOutputTextParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *LabelModelGraderInputContentOutputTextParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// An image input to the model. +// +// The properties ImageURL, Type are required. 
+type LabelModelGraderInputContentInputImageParam struct { + // The URL of the image input. + ImageURL string `json:"image_url,required"` + // The detail level of the image to be sent to the model. One of `high`, `low`, or + // `auto`. Defaults to `auto`. + Detail param.Opt[string] `json:"detail,omitzero"` + // The type of the image input. Always `input_image`. + // + // This field can be elided, and will marshal its zero value as "input_image". + Type constant.InputImage `json:"type,required"` + paramObj +} + +func (r LabelModelGraderInputContentInputImageParam) MarshalJSON() (data []byte, err error) { + type shadow LabelModelGraderInputContentInputImageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *LabelModelGraderInputContentInputImageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A MultiGrader object combines the output of multiple graders to produce a single +// score. +type MultiGrader struct { + // A formula to calculate the output based on grader results. + CalculateOutput string `json:"calculate_output,required"` + // A StringCheckGrader object that performs a string comparison between input and + // reference using a specified operation. + Graders MultiGraderGradersUnion `json:"graders,required"` + // The name of the grader. + Name string `json:"name,required"` + // The object type, which is always `multi`. + Type constant.Multi `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + CalculateOutput respjson.Field + Graders respjson.Field + Name respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r MultiGrader) RawJSON() string { return r.JSON.raw } +func (r *MultiGrader) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this MultiGrader to a MultiGraderParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// MultiGraderParam.Overrides() +func (r MultiGrader) ToParam() MultiGraderParam { + return param.Override[MultiGraderParam](json.RawMessage(r.RawJSON())) +} + +// MultiGraderGradersUnion contains all possible properties and values from +// [StringCheckGrader], [TextSimilarityGrader], [PythonGrader], [ScoreModelGrader], +// [LabelModelGrader]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type MultiGraderGradersUnion struct { + // This field is a union of [string], [string], [[]ScoreModelGraderInput], + // [[]LabelModelGraderInput] + Input MultiGraderGradersUnionInput `json:"input"` + Name string `json:"name"` + // This field is from variant [StringCheckGrader]. + Operation StringCheckGraderOperation `json:"operation"` + Reference string `json:"reference"` + Type string `json:"type"` + // This field is from variant [TextSimilarityGrader]. + EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric"` + // This field is from variant [PythonGrader]. + Source string `json:"source"` + // This field is from variant [PythonGrader]. + ImageTag string `json:"image_tag"` + Model string `json:"model"` + // This field is from variant [ScoreModelGrader]. + Range []float64 `json:"range"` + // This field is from variant [ScoreModelGrader]. 
+ SamplingParams any `json:"sampling_params"` + // This field is from variant [LabelModelGrader]. + Labels []string `json:"labels"` + // This field is from variant [LabelModelGrader]. + PassingLabels []string `json:"passing_labels"` + JSON struct { + Input respjson.Field + Name respjson.Field + Operation respjson.Field + Reference respjson.Field + Type respjson.Field + EvaluationMetric respjson.Field + Source respjson.Field + ImageTag respjson.Field + Model respjson.Field + Range respjson.Field + SamplingParams respjson.Field + Labels respjson.Field + PassingLabels respjson.Field + raw string + } `json:"-"` +} + +func (u MultiGraderGradersUnion) AsStringCheckGrader() (v StringCheckGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MultiGraderGradersUnion) AsTextSimilarityGrader() (v TextSimilarityGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MultiGraderGradersUnion) AsPythonGrader() (v PythonGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MultiGraderGradersUnion) AsScoreModelGrader() (v ScoreModelGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u MultiGraderGradersUnion) AsLabelModelGrader() (v LabelModelGrader) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u MultiGraderGradersUnion) RawJSON() string { return u.JSON.raw } + +func (r *MultiGraderGradersUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// MultiGraderGradersUnionInput is an implicit subunion of +// [MultiGraderGradersUnion]. MultiGraderGradersUnionInput provides convenient +// access to the sub-properties of the union. +// +// For type safety it is recommended to directly use a variant of the +// [MultiGraderGradersUnion]. 
+// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfString OfScoreModelGraderInputArray +// OfLabelModelGraderInputArray] +type MultiGraderGradersUnionInput struct { + // This field will be present if the value is a [string] instead of an object. + OfString string `json:",inline"` + // This field will be present if the value is a [[]ScoreModelGraderInput] instead + // of an object. + OfScoreModelGraderInputArray []ScoreModelGraderInput `json:",inline"` + // This field will be present if the value is a [[]LabelModelGraderInput] instead + // of an object. + OfLabelModelGraderInputArray []LabelModelGraderInput `json:",inline"` + JSON struct { + OfString respjson.Field + OfScoreModelGraderInputArray respjson.Field + OfLabelModelGraderInputArray respjson.Field + raw string + } `json:"-"` +} + +func (r *MultiGraderGradersUnionInput) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A MultiGrader object combines the output of multiple graders to produce a single +// score. +// +// The properties CalculateOutput, Graders, Name, Type are required. +type MultiGraderParam struct { + // A formula to calculate the output based on grader results. + CalculateOutput string `json:"calculate_output,required"` + // A StringCheckGrader object that performs a string comparison between input and + // reference using a specified operation. + Graders MultiGraderGradersUnionParam `json:"graders,omitzero,required"` + // The name of the grader. + Name string `json:"name,required"` + // The object type, which is always `multi`. + // + // This field can be elided, and will marshal its zero value as "multi". 
+ Type constant.Multi `json:"type,required"` + paramObj +} + +func (r MultiGraderParam) MarshalJSON() (data []byte, err error) { + type shadow MultiGraderParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *MultiGraderParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type MultiGraderGradersUnionParam struct { + OfStringCheckGrader *StringCheckGraderParam `json:",omitzero,inline"` + OfTextSimilarityGrader *TextSimilarityGraderParam `json:",omitzero,inline"` + OfPythonGrader *PythonGraderParam `json:",omitzero,inline"` + OfScoreModelGrader *ScoreModelGraderParam `json:",omitzero,inline"` + OfLabelModelGrader *LabelModelGraderParam `json:",omitzero,inline"` + paramUnion +} + +func (u MultiGraderGradersUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfStringCheckGrader, + u.OfTextSimilarityGrader, + u.OfPythonGrader, + u.OfScoreModelGrader, + u.OfLabelModelGrader) +} +func (u *MultiGraderGradersUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *MultiGraderGradersUnionParam) asAny() any { + if !param.IsOmitted(u.OfStringCheckGrader) { + return u.OfStringCheckGrader + } else if !param.IsOmitted(u.OfTextSimilarityGrader) { + return u.OfTextSimilarityGrader + } else if !param.IsOmitted(u.OfPythonGrader) { + return u.OfPythonGrader + } else if !param.IsOmitted(u.OfScoreModelGrader) { + return u.OfScoreModelGrader + } else if !param.IsOmitted(u.OfLabelModelGrader) { + return u.OfLabelModelGrader + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetOperation() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Operation) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u MultiGraderGradersUnionParam) GetEvaluationMetric() *string { + if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.EvaluationMetric) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetSource() *string { + if vt := u.OfPythonGrader; vt != nil { + return &vt.Source + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetImageTag() *string { + if vt := u.OfPythonGrader; vt != nil && vt.ImageTag.Valid() { + return &vt.ImageTag.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetRange() []float64 { + if vt := u.OfScoreModelGrader; vt != nil { + return vt.Range + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetSamplingParams() *any { + if vt := u.OfScoreModelGrader; vt != nil { + return &vt.SamplingParams + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetLabels() []string { + if vt := u.OfLabelModelGrader; vt != nil { + return vt.Labels + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetPassingLabels() []string { + if vt := u.OfLabelModelGrader; vt != nil { + return vt.PassingLabels + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u MultiGraderGradersUnionParam) GetName() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfPythonGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Name) + } else if vt := u.OfLabelModelGrader; vt != nil { + return (*string)(&vt.Name) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetReference() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Reference) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Reference) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u MultiGraderGradersUnionParam) GetType() *string { + if vt := u.OfStringCheckGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfTextSimilarityGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfPythonGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfLabelModelGrader; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u MultiGraderGradersUnionParam) GetModel() *string { + if vt := u.OfScoreModelGrader; vt != nil { + return (*string)(&vt.Model) + } else if vt := u.OfLabelModelGrader; vt != nil { + return (*string)(&vt.Model) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u MultiGraderGradersUnionParam) GetInput() (res multiGraderGradersUnionParamInput) { + if vt := u.OfStringCheckGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfTextSimilarityGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfScoreModelGrader; vt != nil { + res.any = &vt.Input + } else if vt := u.OfLabelModelGrader; vt != nil { + res.any = &vt.Input + } + return +} + +// Can have the runtime types [*string], [_[]ScoreModelGraderInputParam], +// [_[]LabelModelGraderInputParam] +type multiGraderGradersUnionParamInput struct{ any } + +// Use the following switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *string: +// case *[]openai.ScoreModelGraderInputParam: +// case *[]openai.LabelModelGraderInputParam: +// default: +// fmt.Errorf("not present") +// } +func (u multiGraderGradersUnionParamInput) AsAny() any { return u.any } + +// A PythonGrader object that runs a python script on the input. +type PythonGrader struct { + // The name of the grader. + Name string `json:"name,required"` + // The source code of the python script. + Source string `json:"source,required"` + // The object type, which is always `python`. + Type constant.Python `json:"type,required"` + // The image tag to use for the python script. + ImageTag string `json:"image_tag"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Name respjson.Field + Source respjson.Field + Type respjson.Field + ImageTag respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r PythonGrader) RawJSON() string { return r.JSON.raw } +func (r *PythonGrader) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this PythonGrader to a PythonGraderParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// PythonGraderParam.Overrides() +func (r PythonGrader) ToParam() PythonGraderParam { + return param.Override[PythonGraderParam](json.RawMessage(r.RawJSON())) +} + +// A PythonGrader object that runs a python script on the input. +// +// The properties Name, Source, Type are required. +type PythonGraderParam struct { + // The name of the grader. + Name string `json:"name,required"` + // The source code of the python script. + Source string `json:"source,required"` + // The image tag to use for the python script. + ImageTag param.Opt[string] `json:"image_tag,omitzero"` + // The object type, which is always `python`. + // + // This field can be elided, and will marshal its zero value as "python". + Type constant.Python `json:"type,required"` + paramObj +} + +func (r PythonGraderParam) MarshalJSON() (data []byte, err error) { + type shadow PythonGraderParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *PythonGraderParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A ScoreModelGrader object that uses a model to assign a score to the input. +type ScoreModelGrader struct { + // The input text. This may include template strings. + Input []ScoreModelGraderInput `json:"input,required"` + // The model to use for the evaluation. + Model string `json:"model,required"` + // The name of the grader. 
+ Name string `json:"name,required"` + // The object type, which is always `score_model`. + Type constant.ScoreModel `json:"type,required"` + // The range of the score. Defaults to `[0, 1]`. + Range []float64 `json:"range"` + // The sampling parameters for the model. + SamplingParams any `json:"sampling_params"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Input respjson.Field + Model respjson.Field + Name respjson.Field + Type respjson.Field + Range respjson.Field + SamplingParams respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ScoreModelGrader) RawJSON() string { return r.JSON.raw } +func (r *ScoreModelGrader) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this ScoreModelGrader to a ScoreModelGraderParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// ScoreModelGraderParam.Overrides() +func (r ScoreModelGrader) ToParam() ScoreModelGraderParam { + return param.Override[ScoreModelGraderParam](json.RawMessage(r.RawJSON())) +} + +// A message input to the model with a role indicating instruction following +// hierarchy. Instructions given with the `developer` or `system` role take +// precedence over instructions given with the `user` role. Messages with the +// `assistant` role are presumed to have been generated by the model in previous +// interactions. +type ScoreModelGraderInput struct { + // Inputs to the model - can contain template strings. + Content ScoreModelGraderInputContentUnion `json:"content,required"` + // The role of the message input. One of `user`, `assistant`, `system`, or + // `developer`. + // + // Any of "user", "assistant", "system", "developer". 
+ Role string `json:"role,required"` + // The type of the message input. Always `message`. + // + // Any of "message". + Type string `json:"type"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Content respjson.Field + Role respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ScoreModelGraderInput) RawJSON() string { return r.JSON.raw } +func (r *ScoreModelGraderInput) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ScoreModelGraderInputContentUnion contains all possible properties and values +// from [string], [responses.ResponseInputText], +// [ScoreModelGraderInputContentOutputText], +// [ScoreModelGraderInputContentInputImage], [[]any]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfString OfAnArrayOfInputTextAndInputImage] +type ScoreModelGraderInputContentUnion struct { + // This field will be present if the value is a [string] instead of an object. + OfString string `json:",inline"` + // This field will be present if the value is a [[]any] instead of an object. + OfAnArrayOfInputTextAndInputImage []any `json:",inline"` + Text string `json:"text"` + Type string `json:"type"` + // This field is from variant [ScoreModelGraderInputContentInputImage]. + ImageURL string `json:"image_url"` + // This field is from variant [ScoreModelGraderInputContentInputImage]. 
+ Detail string `json:"detail"` + JSON struct { + OfString respjson.Field + OfAnArrayOfInputTextAndInputImage respjson.Field + Text respjson.Field + Type respjson.Field + ImageURL respjson.Field + Detail respjson.Field + raw string + } `json:"-"` +} + +func (u ScoreModelGraderInputContentUnion) AsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ScoreModelGraderInputContentUnion) AsInputText() (v responses.ResponseInputText) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ScoreModelGraderInputContentUnion) AsOutputText() (v ScoreModelGraderInputContentOutputText) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ScoreModelGraderInputContentUnion) AsInputImage() (v ScoreModelGraderInputContentInputImage) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ScoreModelGraderInputContentUnion) AsAnArrayOfInputTextAndInputImage() (v []any) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ScoreModelGraderInputContentUnion) RawJSON() string { return u.JSON.raw } + +func (r *ScoreModelGraderInputContentUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A text output from the model. +type ScoreModelGraderInputContentOutputText struct { + // The text output from the model. + Text string `json:"text,required"` + // The type of the output text. Always `output_text`. + Type constant.OutputText `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Text respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ScoreModelGraderInputContentOutputText) RawJSON() string { return r.JSON.raw } +func (r *ScoreModelGraderInputContentOutputText) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// An image input to the model. +type ScoreModelGraderInputContentInputImage struct { + // The URL of the image input. + ImageURL string `json:"image_url,required"` + // The type of the image input. Always `input_image`. + Type constant.InputImage `json:"type,required"` + // The detail level of the image to be sent to the model. One of `high`, `low`, or + // `auto`. Defaults to `auto`. + Detail string `json:"detail"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageURL respjson.Field + Type respjson.Field + Detail respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ScoreModelGraderInputContentInputImage) RawJSON() string { return r.JSON.raw } +func (r *ScoreModelGraderInputContentInputImage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A ScoreModelGrader object that uses a model to assign a score to the input. +// +// The properties Input, Model, Name, Type are required. +type ScoreModelGraderParam struct { + // The input text. This may include template strings. + Input []ScoreModelGraderInputParam `json:"input,omitzero,required"` + // The model to use for the evaluation. + Model string `json:"model,required"` + // The name of the grader. + Name string `json:"name,required"` + // The range of the score. Defaults to `[0, 1]`. + Range []float64 `json:"range,omitzero"` + // The sampling parameters for the model. 
+ SamplingParams any `json:"sampling_params,omitzero"` + // The object type, which is always `score_model`. + // + // This field can be elided, and will marshal its zero value as "score_model". + Type constant.ScoreModel `json:"type,required"` + paramObj +} + +func (r ScoreModelGraderParam) MarshalJSON() (data []byte, err error) { + type shadow ScoreModelGraderParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ScoreModelGraderParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A message input to the model with a role indicating instruction following +// hierarchy. Instructions given with the `developer` or `system` role take +// precedence over instructions given with the `user` role. Messages with the +// `assistant` role are presumed to have been generated by the model in previous +// interactions. +// +// The properties Content, Role are required. +type ScoreModelGraderInputParam struct { + // Inputs to the model - can contain template strings. + Content ScoreModelGraderInputContentUnionParam `json:"content,omitzero,required"` + // The role of the message input. One of `user`, `assistant`, `system`, or + // `developer`. + // + // Any of "user", "assistant", "system", "developer". + Role string `json:"role,omitzero,required"` + // The type of the message input. Always `message`. + // + // Any of "message". + Type string `json:"type,omitzero"` + paramObj +} + +func (r ScoreModelGraderInputParam) MarshalJSON() (data []byte, err error) { + type shadow ScoreModelGraderInputParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ScoreModelGraderInputParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[ScoreModelGraderInputParam]( + "role", "user", "assistant", "system", "developer", + ) + apijson.RegisterFieldValidator[ScoreModelGraderInputParam]( + "type", "message", + ) +} + +// Only one field can be non-zero. 
+// +// Use [param.IsOmitted] to confirm if a field is set. +type ScoreModelGraderInputContentUnionParam struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfInputText *responses.ResponseInputTextParam `json:",omitzero,inline"` + OfOutputText *ScoreModelGraderInputContentOutputTextParam `json:",omitzero,inline"` + OfInputImage *ScoreModelGraderInputContentInputImageParam `json:",omitzero,inline"` + OfAnArrayOfInputTextAndInputImage []any `json:",omitzero,inline"` + paramUnion +} + +func (u ScoreModelGraderInputContentUnionParam) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, + u.OfInputText, + u.OfOutputText, + u.OfInputImage, + u.OfAnArrayOfInputTextAndInputImage) +} +func (u *ScoreModelGraderInputContentUnionParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ScoreModelGraderInputContentUnionParam) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfInputText) { + return u.OfInputText + } else if !param.IsOmitted(u.OfOutputText) { + return u.OfOutputText + } else if !param.IsOmitted(u.OfInputImage) { + return u.OfInputImage + } else if !param.IsOmitted(u.OfAnArrayOfInputTextAndInputImage) { + return &u.OfAnArrayOfInputTextAndInputImage + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ScoreModelGraderInputContentUnionParam) GetImageURL() *string { + if vt := u.OfInputImage; vt != nil { + return &vt.ImageURL + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ScoreModelGraderInputContentUnionParam) GetDetail() *string { + if vt := u.OfInputImage; vt != nil && vt.Detail.Valid() { + return &vt.Detail.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. 
+func (u ScoreModelGraderInputContentUnionParam) GetText() *string { + if vt := u.OfInputText; vt != nil { + return (*string)(&vt.Text) + } else if vt := u.OfOutputText; vt != nil { + return (*string)(&vt.Text) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u ScoreModelGraderInputContentUnionParam) GetType() *string { + if vt := u.OfInputText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfOutputText; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfInputImage; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// A text output from the model. +// +// The properties Text, Type are required. +type ScoreModelGraderInputContentOutputTextParam struct { + // The text output from the model. + Text string `json:"text,required"` + // The type of the output text. Always `output_text`. + // + // This field can be elided, and will marshal its zero value as "output_text". + Type constant.OutputText `json:"type,required"` + paramObj +} + +func (r ScoreModelGraderInputContentOutputTextParam) MarshalJSON() (data []byte, err error) { + type shadow ScoreModelGraderInputContentOutputTextParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ScoreModelGraderInputContentOutputTextParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// An image input to the model. +// +// The properties ImageURL, Type are required. +type ScoreModelGraderInputContentInputImageParam struct { + // The URL of the image input. + ImageURL string `json:"image_url,required"` + // The detail level of the image to be sent to the model. One of `high`, `low`, or + // `auto`. Defaults to `auto`. + Detail param.Opt[string] `json:"detail,omitzero"` + // The type of the image input. Always `input_image`. + // + // This field can be elided, and will marshal its zero value as "input_image". 
+ Type constant.InputImage `json:"type,required"` + paramObj +} + +func (r ScoreModelGraderInputContentInputImageParam) MarshalJSON() (data []byte, err error) { + type shadow ScoreModelGraderInputContentInputImageParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ScoreModelGraderInputContentInputImageParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A StringCheckGrader object that performs a string comparison between input and +// reference using a specified operation. +type StringCheckGrader struct { + // The input text. This may include template strings. + Input string `json:"input,required"` + // The name of the grader. + Name string `json:"name,required"` + // The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`. + // + // Any of "eq", "ne", "like", "ilike". + Operation StringCheckGraderOperation `json:"operation,required"` + // The reference text. This may include template strings. + Reference string `json:"reference,required"` + // The object type, which is always `string_check`. + Type constant.StringCheck `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Input respjson.Field + Name respjson.Field + Operation respjson.Field + Reference respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r StringCheckGrader) RawJSON() string { return r.JSON.raw } +func (r *StringCheckGrader) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this StringCheckGrader to a StringCheckGraderParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. 
Test for this with +// StringCheckGraderParam.Overrides() +func (r StringCheckGrader) ToParam() StringCheckGraderParam { + return param.Override[StringCheckGraderParam](json.RawMessage(r.RawJSON())) +} + +// The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`. +type StringCheckGraderOperation string + +const ( + StringCheckGraderOperationEq StringCheckGraderOperation = "eq" + StringCheckGraderOperationNe StringCheckGraderOperation = "ne" + StringCheckGraderOperationLike StringCheckGraderOperation = "like" + StringCheckGraderOperationIlike StringCheckGraderOperation = "ilike" +) + +// A StringCheckGrader object that performs a string comparison between input and +// reference using a specified operation. +// +// The properties Input, Name, Operation, Reference, Type are required. +type StringCheckGraderParam struct { + // The input text. This may include template strings. + Input string `json:"input,required"` + // The name of the grader. + Name string `json:"name,required"` + // The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`. + // + // Any of "eq", "ne", "like", "ilike". + Operation StringCheckGraderOperation `json:"operation,omitzero,required"` + // The reference text. This may include template strings. + Reference string `json:"reference,required"` + // The object type, which is always `string_check`. + // + // This field can be elided, and will marshal its zero value as "string_check". + Type constant.StringCheck `json:"type,required"` + paramObj +} + +func (r StringCheckGraderParam) MarshalJSON() (data []byte, err error) { + type shadow StringCheckGraderParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *StringCheckGraderParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A TextSimilarityGrader object which grades text based on similarity metrics. +type TextSimilarityGrader struct { + // The evaluation metric to use. 
One of `fuzzy_match`, `bleu`, `gleu`, `meteor`, + // `rouge_1`, `rouge_2`, `rouge_3`, `rouge_4`, `rouge_5`, or `rouge_l`. + // + // Any of "fuzzy_match", "bleu", "gleu", "meteor", "rouge_1", "rouge_2", "rouge_3", + // "rouge_4", "rouge_5", "rouge_l". + EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric,required"` + // The text being graded. + Input string `json:"input,required"` + // The name of the grader. + Name string `json:"name,required"` + // The text being graded against. + Reference string `json:"reference,required"` + // The type of grader. + Type constant.TextSimilarity `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + EvaluationMetric respjson.Field + Input respjson.Field + Name respjson.Field + Reference respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r TextSimilarityGrader) RawJSON() string { return r.JSON.raw } +func (r *TextSimilarityGrader) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// ToParam converts this TextSimilarityGrader to a TextSimilarityGraderParam. +// +// Warning: the fields of the param type will not be present. ToParam should only +// be used at the last possible moment before sending a request. Test for this with +// TextSimilarityGraderParam.Overrides() +func (r TextSimilarityGrader) ToParam() TextSimilarityGraderParam { + return param.Override[TextSimilarityGraderParam](json.RawMessage(r.RawJSON())) +} + +// The evaluation metric to use. One of `fuzzy_match`, `bleu`, `gleu`, `meteor`, +// `rouge_1`, `rouge_2`, `rouge_3`, `rouge_4`, `rouge_5`, or `rouge_l`. 
+type TextSimilarityGraderEvaluationMetric string + +const ( + TextSimilarityGraderEvaluationMetricFuzzyMatch TextSimilarityGraderEvaluationMetric = "fuzzy_match" + TextSimilarityGraderEvaluationMetricBleu TextSimilarityGraderEvaluationMetric = "bleu" + TextSimilarityGraderEvaluationMetricGleu TextSimilarityGraderEvaluationMetric = "gleu" + TextSimilarityGraderEvaluationMetricMeteor TextSimilarityGraderEvaluationMetric = "meteor" + TextSimilarityGraderEvaluationMetricRouge1 TextSimilarityGraderEvaluationMetric = "rouge_1" + TextSimilarityGraderEvaluationMetricRouge2 TextSimilarityGraderEvaluationMetric = "rouge_2" + TextSimilarityGraderEvaluationMetricRouge3 TextSimilarityGraderEvaluationMetric = "rouge_3" + TextSimilarityGraderEvaluationMetricRouge4 TextSimilarityGraderEvaluationMetric = "rouge_4" + TextSimilarityGraderEvaluationMetricRouge5 TextSimilarityGraderEvaluationMetric = "rouge_5" + TextSimilarityGraderEvaluationMetricRougeL TextSimilarityGraderEvaluationMetric = "rouge_l" +) + +// A TextSimilarityGrader object which grades text based on similarity metrics. +// +// The properties EvaluationMetric, Input, Name, Reference, Type are required. +type TextSimilarityGraderParam struct { + // The evaluation metric to use. One of `fuzzy_match`, `bleu`, `gleu`, `meteor`, + // `rouge_1`, `rouge_2`, `rouge_3`, `rouge_4`, `rouge_5`, or `rouge_l`. + // + // Any of "fuzzy_match", "bleu", "gleu", "meteor", "rouge_1", "rouge_2", "rouge_3", + // "rouge_4", "rouge_5", "rouge_l". + EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric,omitzero,required"` + // The text being graded. + Input string `json:"input,required"` + // The name of the grader. + Name string `json:"name,required"` + // The text being graded against. + Reference string `json:"reference,required"` + // The type of grader. + // + // This field can be elided, and will marshal its zero value as "text_similarity". 
+ Type constant.TextSimilarity `json:"type,required"` + paramObj +} + +func (r TextSimilarityGraderParam) MarshalJSON() (data []byte, err error) { + type shadow TextSimilarityGraderParam + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *TextSimilarityGraderParam) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} diff --git a/vendor/github.com/openai/openai-go/image.go b/vendor/github.com/openai/openai-go/image.go new file mode 100644 index 0000000000..63f7b2eec1 --- /dev/null +++ b/vendor/github.com/openai/openai-go/image.go @@ -0,0 +1,1300 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package openai + +import ( + "bytes" + "context" + "encoding/json" + "io" + "mime/multipart" + "net/http" + + "github.com/openai/openai-go/internal/apiform" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" + "github.com/openai/openai-go/packages/respjson" + "github.com/openai/openai-go/packages/ssestream" + "github.com/openai/openai-go/shared/constant" +) + +// ImageService contains methods and other services that help with interacting with +// the openai API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewImageService] method instead. +type ImageService struct { + Options []option.RequestOption +} + +// NewImageService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewImageService(opts ...option.RequestOption) (r ImageService) { + r = ImageService{} + r.Options = opts + return +} + +// Creates a variation of a given image. This endpoint only supports `dall-e-2`. 
+func (r *ImageService) NewVariation(ctx context.Context, body ImageNewVariationParams, opts ...option.RequestOption) (res *ImagesResponse, err error) { + opts = append(r.Options[:], opts...) + path := "images/variations" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Creates an edited or extended image given one or more source images and a +// prompt. This endpoint only supports `gpt-image-1` and `dall-e-2`. +func (r *ImageService) Edit(ctx context.Context, body ImageEditParams, opts ...option.RequestOption) (res *ImagesResponse, err error) { + opts = append(r.Options[:], opts...) + path := "images/edits" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Creates an edited or extended image given one or more source images and a +// prompt. This endpoint only supports `gpt-image-1` and `dall-e-2`. +func (r *ImageService) EditStreaming(ctx context.Context, body ImageEditParams, opts ...option.RequestOption) (stream *ssestream.Stream[ImageEditStreamEventUnion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + body.SetExtraFields(map[string]any{ + "stream": "true", + }) + path := "images/edits" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[ImageEditStreamEventUnion](ssestream.NewDecoder(raw), err) +} + +// Creates an image given a prompt. +// [Learn more](https://platform.openai.com/docs/guides/images). +func (r *ImageService) Generate(ctx context.Context, body ImageGenerateParams, opts ...option.RequestOption) (res *ImagesResponse, err error) { + opts = append(r.Options[:], opts...) + path := "images/generations" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Creates an image given a prompt. +// [Learn more](https://platform.openai.com/docs/guides/images). 
+func (r *ImageService) GenerateStreaming(ctx context.Context, body ImageGenerateParams, opts ...option.RequestOption) (stream *ssestream.Stream[ImageGenStreamEventUnion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...) + path := "images/generations" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[ImageGenStreamEventUnion](ssestream.NewDecoder(raw), err) +} + +// Represents the content or the URL of an image generated by the OpenAI API. +type Image struct { + // The base64-encoded JSON of the generated image. Default value for `gpt-image-1`, + // and only present if `response_format` is set to `b64_json` for `dall-e-2` and + // `dall-e-3`. + B64JSON string `json:"b64_json"` + // For `dall-e-3` only, the revised prompt that was used to generate the image. + RevisedPrompt string `json:"revised_prompt"` + // When using `dall-e-2` or `dall-e-3`, the URL of the generated image if + // `response_format` is set to `url` (default value). Unsupported for + // `gpt-image-1`. + URL string `json:"url"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + B64JSON respjson.Field + RevisedPrompt respjson.Field + URL respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r Image) RawJSON() string { return r.JSON.raw } +func (r *Image) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Emitted when image editing has completed and the final image is available. +type ImageEditCompletedEvent struct { + // Base64-encoded final edited image data, suitable for rendering as an image. + B64JSON string `json:"b64_json,required"` + // The background setting for the edited image. + // + // Any of "transparent", "opaque", "auto". 
+ Background ImageEditCompletedEventBackground `json:"background,required"` + // The Unix timestamp when the event was created. + CreatedAt int64 `json:"created_at,required"` + // The output format for the edited image. + // + // Any of "png", "webp", "jpeg". + OutputFormat ImageEditCompletedEventOutputFormat `json:"output_format,required"` + // The quality setting for the edited image. + // + // Any of "low", "medium", "high", "auto". + Quality ImageEditCompletedEventQuality `json:"quality,required"` + // The size of the edited image. + // + // Any of "1024x1024", "1024x1536", "1536x1024", "auto". + Size ImageEditCompletedEventSize `json:"size,required"` + // The type of the event. Always `image_edit.completed`. + Type constant.ImageEditCompleted `json:"type,required"` + // For `gpt-image-1` only, the token usage information for the image generation. + Usage ImageEditCompletedEventUsage `json:"usage,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + B64JSON respjson.Field + Background respjson.Field + CreatedAt respjson.Field + OutputFormat respjson.Field + Quality respjson.Field + Size respjson.Field + Type respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageEditCompletedEvent) RawJSON() string { return r.JSON.raw } +func (r *ImageEditCompletedEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The background setting for the edited image. +type ImageEditCompletedEventBackground string + +const ( + ImageEditCompletedEventBackgroundTransparent ImageEditCompletedEventBackground = "transparent" + ImageEditCompletedEventBackgroundOpaque ImageEditCompletedEventBackground = "opaque" + ImageEditCompletedEventBackgroundAuto ImageEditCompletedEventBackground = "auto" +) + +// The output format for the edited image. 
+type ImageEditCompletedEventOutputFormat string + +const ( + ImageEditCompletedEventOutputFormatPNG ImageEditCompletedEventOutputFormat = "png" + ImageEditCompletedEventOutputFormatWebP ImageEditCompletedEventOutputFormat = "webp" + ImageEditCompletedEventOutputFormatJPEG ImageEditCompletedEventOutputFormat = "jpeg" +) + +// The quality setting for the edited image. +type ImageEditCompletedEventQuality string + +const ( + ImageEditCompletedEventQualityLow ImageEditCompletedEventQuality = "low" + ImageEditCompletedEventQualityMedium ImageEditCompletedEventQuality = "medium" + ImageEditCompletedEventQualityHigh ImageEditCompletedEventQuality = "high" + ImageEditCompletedEventQualityAuto ImageEditCompletedEventQuality = "auto" +) + +// The size of the edited image. +type ImageEditCompletedEventSize string + +const ( + ImageEditCompletedEventSize1024x1024 ImageEditCompletedEventSize = "1024x1024" + ImageEditCompletedEventSize1024x1536 ImageEditCompletedEventSize = "1024x1536" + ImageEditCompletedEventSize1536x1024 ImageEditCompletedEventSize = "1536x1024" + ImageEditCompletedEventSizeAuto ImageEditCompletedEventSize = "auto" +) + +// For `gpt-image-1` only, the token usage information for the image generation. +type ImageEditCompletedEventUsage struct { + // The number of tokens (images and text) in the input prompt. + InputTokens int64 `json:"input_tokens,required"` + // The input tokens detailed information for the image generation. + InputTokensDetails ImageEditCompletedEventUsageInputTokensDetails `json:"input_tokens_details,required"` + // The number of image tokens in the output image. + OutputTokens int64 `json:"output_tokens,required"` + // The total number of tokens (images and text) used for the image generation. + TotalTokens int64 `json:"total_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + InputTokens respjson.Field + InputTokensDetails respjson.Field + OutputTokens respjson.Field + TotalTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageEditCompletedEventUsage) RawJSON() string { return r.JSON.raw } +func (r *ImageEditCompletedEventUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The input tokens detailed information for the image generation. +type ImageEditCompletedEventUsageInputTokensDetails struct { + // The number of image tokens in the input prompt. + ImageTokens int64 `json:"image_tokens,required"` + // The number of text tokens in the input prompt. + TextTokens int64 `json:"text_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageTokens respjson.Field + TextTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageEditCompletedEventUsageInputTokensDetails) RawJSON() string { return r.JSON.raw } +func (r *ImageEditCompletedEventUsageInputTokensDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Emitted when a partial image is available during image editing streaming. +type ImageEditPartialImageEvent struct { + // Base64-encoded partial image data, suitable for rendering as an image. + B64JSON string `json:"b64_json,required"` + // The background setting for the requested edited image. + // + // Any of "transparent", "opaque", "auto". + Background ImageEditPartialImageEventBackground `json:"background,required"` + // The Unix timestamp when the event was created. + CreatedAt int64 `json:"created_at,required"` + // The output format for the requested edited image. + // + // Any of "png", "webp", "jpeg". 
+ OutputFormat ImageEditPartialImageEventOutputFormat `json:"output_format,required"` + // 0-based index for the partial image (streaming). + PartialImageIndex int64 `json:"partial_image_index,required"` + // The quality setting for the requested edited image. + // + // Any of "low", "medium", "high", "auto". + Quality ImageEditPartialImageEventQuality `json:"quality,required"` + // The size of the requested edited image. + // + // Any of "1024x1024", "1024x1536", "1536x1024", "auto". + Size ImageEditPartialImageEventSize `json:"size,required"` + // The type of the event. Always `image_edit.partial_image`. + Type constant.ImageEditPartialImage `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + B64JSON respjson.Field + Background respjson.Field + CreatedAt respjson.Field + OutputFormat respjson.Field + PartialImageIndex respjson.Field + Quality respjson.Field + Size respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageEditPartialImageEvent) RawJSON() string { return r.JSON.raw } +func (r *ImageEditPartialImageEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The background setting for the requested edited image. +type ImageEditPartialImageEventBackground string + +const ( + ImageEditPartialImageEventBackgroundTransparent ImageEditPartialImageEventBackground = "transparent" + ImageEditPartialImageEventBackgroundOpaque ImageEditPartialImageEventBackground = "opaque" + ImageEditPartialImageEventBackgroundAuto ImageEditPartialImageEventBackground = "auto" +) + +// The output format for the requested edited image. 
+type ImageEditPartialImageEventOutputFormat string + +const ( + ImageEditPartialImageEventOutputFormatPNG ImageEditPartialImageEventOutputFormat = "png" + ImageEditPartialImageEventOutputFormatWebP ImageEditPartialImageEventOutputFormat = "webp" + ImageEditPartialImageEventOutputFormatJPEG ImageEditPartialImageEventOutputFormat = "jpeg" +) + +// The quality setting for the requested edited image. +type ImageEditPartialImageEventQuality string + +const ( + ImageEditPartialImageEventQualityLow ImageEditPartialImageEventQuality = "low" + ImageEditPartialImageEventQualityMedium ImageEditPartialImageEventQuality = "medium" + ImageEditPartialImageEventQualityHigh ImageEditPartialImageEventQuality = "high" + ImageEditPartialImageEventQualityAuto ImageEditPartialImageEventQuality = "auto" +) + +// The size of the requested edited image. +type ImageEditPartialImageEventSize string + +const ( + ImageEditPartialImageEventSize1024x1024 ImageEditPartialImageEventSize = "1024x1024" + ImageEditPartialImageEventSize1024x1536 ImageEditPartialImageEventSize = "1024x1536" + ImageEditPartialImageEventSize1536x1024 ImageEditPartialImageEventSize = "1536x1024" + ImageEditPartialImageEventSizeAuto ImageEditPartialImageEventSize = "auto" +) + +// ImageEditStreamEventUnion contains all possible properties and values from +// [ImageEditPartialImageEvent], [ImageEditCompletedEvent]. +// +// Use the [ImageEditStreamEventUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type ImageEditStreamEventUnion struct { + B64JSON string `json:"b64_json"` + Background string `json:"background"` + CreatedAt int64 `json:"created_at"` + OutputFormat string `json:"output_format"` + // This field is from variant [ImageEditPartialImageEvent]. + PartialImageIndex int64 `json:"partial_image_index"` + Quality string `json:"quality"` + Size string `json:"size"` + // Any of "image_edit.partial_image", "image_edit.completed". 
+ Type string `json:"type"` + // This field is from variant [ImageEditCompletedEvent]. + Usage ImageEditCompletedEventUsage `json:"usage"` + JSON struct { + B64JSON respjson.Field + Background respjson.Field + CreatedAt respjson.Field + OutputFormat respjson.Field + PartialImageIndex respjson.Field + Quality respjson.Field + Size respjson.Field + Type respjson.Field + Usage respjson.Field + raw string + } `json:"-"` +} + +// anyImageEditStreamEvent is implemented by each variant of +// [ImageEditStreamEventUnion] to add type safety for the return type of +// [ImageEditStreamEventUnion.AsAny] +type anyImageEditStreamEvent interface { + implImageEditStreamEventUnion() +} + +func (ImageEditPartialImageEvent) implImageEditStreamEventUnion() {} +func (ImageEditCompletedEvent) implImageEditStreamEventUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := ImageEditStreamEventUnion.AsAny().(type) { +// case openai.ImageEditPartialImageEvent: +// case openai.ImageEditCompletedEvent: +// default: +// fmt.Errorf("no variant present") +// } +func (u ImageEditStreamEventUnion) AsAny() anyImageEditStreamEvent { + switch u.Type { + case "image_edit.partial_image": + return u.AsImageEditPartialImage() + case "image_edit.completed": + return u.AsImageEditCompleted() + } + return nil +} + +func (u ImageEditStreamEventUnion) AsImageEditPartialImage() (v ImageEditPartialImageEvent) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u ImageEditStreamEventUnion) AsImageEditCompleted() (v ImageEditCompletedEvent) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ImageEditStreamEventUnion) RawJSON() string { return u.JSON.raw } + +func (r *ImageEditStreamEventUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Emitted when image generation has completed and the final image is available. 
+type ImageGenCompletedEvent struct { + // Base64-encoded image data, suitable for rendering as an image. + B64JSON string `json:"b64_json,required"` + // The background setting for the generated image. + // + // Any of "transparent", "opaque", "auto". + Background ImageGenCompletedEventBackground `json:"background,required"` + // The Unix timestamp when the event was created. + CreatedAt int64 `json:"created_at,required"` + // The output format for the generated image. + // + // Any of "png", "webp", "jpeg". + OutputFormat ImageGenCompletedEventOutputFormat `json:"output_format,required"` + // The quality setting for the generated image. + // + // Any of "low", "medium", "high", "auto". + Quality ImageGenCompletedEventQuality `json:"quality,required"` + // The size of the generated image. + // + // Any of "1024x1024", "1024x1536", "1536x1024", "auto". + Size ImageGenCompletedEventSize `json:"size,required"` + // The type of the event. Always `image_generation.completed`. + Type constant.ImageGenerationCompleted `json:"type,required"` + // For `gpt-image-1` only, the token usage information for the image generation. + Usage ImageGenCompletedEventUsage `json:"usage,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + B64JSON respjson.Field + Background respjson.Field + CreatedAt respjson.Field + OutputFormat respjson.Field + Quality respjson.Field + Size respjson.Field + Type respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageGenCompletedEvent) RawJSON() string { return r.JSON.raw } +func (r *ImageGenCompletedEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The background setting for the generated image. 
+type ImageGenCompletedEventBackground string + +const ( + ImageGenCompletedEventBackgroundTransparent ImageGenCompletedEventBackground = "transparent" + ImageGenCompletedEventBackgroundOpaque ImageGenCompletedEventBackground = "opaque" + ImageGenCompletedEventBackgroundAuto ImageGenCompletedEventBackground = "auto" +) + +// The output format for the generated image. +type ImageGenCompletedEventOutputFormat string + +const ( + ImageGenCompletedEventOutputFormatPNG ImageGenCompletedEventOutputFormat = "png" + ImageGenCompletedEventOutputFormatWebP ImageGenCompletedEventOutputFormat = "webp" + ImageGenCompletedEventOutputFormatJPEG ImageGenCompletedEventOutputFormat = "jpeg" +) + +// The quality setting for the generated image. +type ImageGenCompletedEventQuality string + +const ( + ImageGenCompletedEventQualityLow ImageGenCompletedEventQuality = "low" + ImageGenCompletedEventQualityMedium ImageGenCompletedEventQuality = "medium" + ImageGenCompletedEventQualityHigh ImageGenCompletedEventQuality = "high" + ImageGenCompletedEventQualityAuto ImageGenCompletedEventQuality = "auto" +) + +// The size of the generated image. +type ImageGenCompletedEventSize string + +const ( + ImageGenCompletedEventSize1024x1024 ImageGenCompletedEventSize = "1024x1024" + ImageGenCompletedEventSize1024x1536 ImageGenCompletedEventSize = "1024x1536" + ImageGenCompletedEventSize1536x1024 ImageGenCompletedEventSize = "1536x1024" + ImageGenCompletedEventSizeAuto ImageGenCompletedEventSize = "auto" +) + +// For `gpt-image-1` only, the token usage information for the image generation. +type ImageGenCompletedEventUsage struct { + // The number of tokens (images and text) in the input prompt. + InputTokens int64 `json:"input_tokens,required"` + // The input tokens detailed information for the image generation. + InputTokensDetails ImageGenCompletedEventUsageInputTokensDetails `json:"input_tokens_details,required"` + // The number of image tokens in the output image. 
+ OutputTokens int64 `json:"output_tokens,required"` + // The total number of tokens (images and text) used for the image generation. + TotalTokens int64 `json:"total_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + InputTokens respjson.Field + InputTokensDetails respjson.Field + OutputTokens respjson.Field + TotalTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageGenCompletedEventUsage) RawJSON() string { return r.JSON.raw } +func (r *ImageGenCompletedEventUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The input tokens detailed information for the image generation. +type ImageGenCompletedEventUsageInputTokensDetails struct { + // The number of image tokens in the input prompt. + ImageTokens int64 `json:"image_tokens,required"` + // The number of text tokens in the input prompt. + TextTokens int64 `json:"text_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageTokens respjson.Field + TextTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageGenCompletedEventUsageInputTokensDetails) RawJSON() string { return r.JSON.raw } +func (r *ImageGenCompletedEventUsageInputTokensDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Emitted when a partial image is available during image generation streaming. +type ImageGenPartialImageEvent struct { + // Base64-encoded partial image data, suitable for rendering as an image. + B64JSON string `json:"b64_json,required"` + // The background setting for the requested image. + // + // Any of "transparent", "opaque", "auto". 
+ Background ImageGenPartialImageEventBackground `json:"background,required"` + // The Unix timestamp when the event was created. + CreatedAt int64 `json:"created_at,required"` + // The output format for the requested image. + // + // Any of "png", "webp", "jpeg". + OutputFormat ImageGenPartialImageEventOutputFormat `json:"output_format,required"` + // 0-based index for the partial image (streaming). + PartialImageIndex int64 `json:"partial_image_index,required"` + // The quality setting for the requested image. + // + // Any of "low", "medium", "high", "auto". + Quality ImageGenPartialImageEventQuality `json:"quality,required"` + // The size of the requested image. + // + // Any of "1024x1024", "1024x1536", "1536x1024", "auto". + Size ImageGenPartialImageEventSize `json:"size,required"` + // The type of the event. Always `image_generation.partial_image`. + Type constant.ImageGenerationPartialImage `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + B64JSON respjson.Field + Background respjson.Field + CreatedAt respjson.Field + OutputFormat respjson.Field + PartialImageIndex respjson.Field + Quality respjson.Field + Size respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImageGenPartialImageEvent) RawJSON() string { return r.JSON.raw } +func (r *ImageGenPartialImageEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The background setting for the requested image. 
+type ImageGenPartialImageEventBackground string + +const ( + ImageGenPartialImageEventBackgroundTransparent ImageGenPartialImageEventBackground = "transparent" + ImageGenPartialImageEventBackgroundOpaque ImageGenPartialImageEventBackground = "opaque" + ImageGenPartialImageEventBackgroundAuto ImageGenPartialImageEventBackground = "auto" +) + +// The output format for the requested image. +type ImageGenPartialImageEventOutputFormat string + +const ( + ImageGenPartialImageEventOutputFormatPNG ImageGenPartialImageEventOutputFormat = "png" + ImageGenPartialImageEventOutputFormatWebP ImageGenPartialImageEventOutputFormat = "webp" + ImageGenPartialImageEventOutputFormatJPEG ImageGenPartialImageEventOutputFormat = "jpeg" +) + +// The quality setting for the requested image. +type ImageGenPartialImageEventQuality string + +const ( + ImageGenPartialImageEventQualityLow ImageGenPartialImageEventQuality = "low" + ImageGenPartialImageEventQualityMedium ImageGenPartialImageEventQuality = "medium" + ImageGenPartialImageEventQualityHigh ImageGenPartialImageEventQuality = "high" + ImageGenPartialImageEventQualityAuto ImageGenPartialImageEventQuality = "auto" +) + +// The size of the requested image. +type ImageGenPartialImageEventSize string + +const ( + ImageGenPartialImageEventSize1024x1024 ImageGenPartialImageEventSize = "1024x1024" + ImageGenPartialImageEventSize1024x1536 ImageGenPartialImageEventSize = "1024x1536" + ImageGenPartialImageEventSize1536x1024 ImageGenPartialImageEventSize = "1536x1024" + ImageGenPartialImageEventSizeAuto ImageGenPartialImageEventSize = "auto" +) + +// ImageGenStreamEventUnion contains all possible properties and values from +// [ImageGenPartialImageEvent], [ImageGenCompletedEvent]. +// +// Use the [ImageGenStreamEventUnion.AsAny] method to switch on the variant. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. 
+type ImageGenStreamEventUnion struct { + B64JSON string `json:"b64_json"` + Background string `json:"background"` + CreatedAt int64 `json:"created_at"` + OutputFormat string `json:"output_format"` + // This field is from variant [ImageGenPartialImageEvent]. + PartialImageIndex int64 `json:"partial_image_index"` + Quality string `json:"quality"` + Size string `json:"size"` + // Any of "image_generation.partial_image", "image_generation.completed". + Type string `json:"type"` + // This field is from variant [ImageGenCompletedEvent]. + Usage ImageGenCompletedEventUsage `json:"usage"` + JSON struct { + B64JSON respjson.Field + Background respjson.Field + CreatedAt respjson.Field + OutputFormat respjson.Field + PartialImageIndex respjson.Field + Quality respjson.Field + Size respjson.Field + Type respjson.Field + Usage respjson.Field + raw string + } `json:"-"` +} + +// anyImageGenStreamEvent is implemented by each variant of +// [ImageGenStreamEventUnion] to add type safety for the return type of +// [ImageGenStreamEventUnion.AsAny] +type anyImageGenStreamEvent interface { + implImageGenStreamEventUnion() +} + +func (ImageGenPartialImageEvent) implImageGenStreamEventUnion() {} +func (ImageGenCompletedEvent) implImageGenStreamEventUnion() {} + +// Use the following switch statement to find the correct variant +// +// switch variant := ImageGenStreamEventUnion.AsAny().(type) { +// case openai.ImageGenPartialImageEvent: +// case openai.ImageGenCompletedEvent: +// default: +// fmt.Errorf("no variant present") +// } +func (u ImageGenStreamEventUnion) AsAny() anyImageGenStreamEvent { + switch u.Type { + case "image_generation.partial_image": + return u.AsImageGenerationPartialImage() + case "image_generation.completed": + return u.AsImageGenerationCompleted() + } + return nil +} + +func (u ImageGenStreamEventUnion) AsImageGenerationPartialImage() (v ImageGenPartialImageEvent) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u 
ImageGenStreamEventUnion) AsImageGenerationCompleted() (v ImageGenCompletedEvent) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u ImageGenStreamEventUnion) RawJSON() string { return u.JSON.raw } + +func (r *ImageGenStreamEventUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ImageModel = string + +const ( + ImageModelDallE2 ImageModel = "dall-e-2" + ImageModelDallE3 ImageModel = "dall-e-3" + ImageModelGPTImage1 ImageModel = "gpt-image-1" +) + +// The response from the image generation endpoint. +type ImagesResponse struct { + // The Unix timestamp (in seconds) of when the image was created. + Created int64 `json:"created,required"` + // The background parameter used for the image generation. Either `transparent` or + // `opaque`. + // + // Any of "transparent", "opaque". + Background ImagesResponseBackground `json:"background"` + // The list of generated images. + Data []Image `json:"data"` + // The output format of the image generation. Either `png`, `webp`, or `jpeg`. + // + // Any of "png", "webp", "jpeg". + OutputFormat ImagesResponseOutputFormat `json:"output_format"` + // The quality of the image generated. Either `low`, `medium`, or `high`. + // + // Any of "low", "medium", "high". + Quality ImagesResponseQuality `json:"quality"` + // The size of the image generated. Either `1024x1024`, `1024x1536`, or + // `1536x1024`. + // + // Any of "1024x1024", "1024x1536", "1536x1024". + Size ImagesResponseSize `json:"size"` + // For `gpt-image-1` only, the token usage information for the image generation. + Usage ImagesResponseUsage `json:"usage"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Created respjson.Field + Background respjson.Field + Data respjson.Field + OutputFormat respjson.Field + Quality respjson.Field + Size respjson.Field + Usage respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImagesResponse) RawJSON() string { return r.JSON.raw } +func (r *ImagesResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The background parameter used for the image generation. Either `transparent` or +// `opaque`. +type ImagesResponseBackground string + +const ( + ImagesResponseBackgroundTransparent ImagesResponseBackground = "transparent" + ImagesResponseBackgroundOpaque ImagesResponseBackground = "opaque" +) + +// The output format of the image generation. Either `png`, `webp`, or `jpeg`. +type ImagesResponseOutputFormat string + +const ( + ImagesResponseOutputFormatPNG ImagesResponseOutputFormat = "png" + ImagesResponseOutputFormatWebP ImagesResponseOutputFormat = "webp" + ImagesResponseOutputFormatJPEG ImagesResponseOutputFormat = "jpeg" +) + +// The quality of the image generated. Either `low`, `medium`, or `high`. +type ImagesResponseQuality string + +const ( + ImagesResponseQualityLow ImagesResponseQuality = "low" + ImagesResponseQualityMedium ImagesResponseQuality = "medium" + ImagesResponseQualityHigh ImagesResponseQuality = "high" +) + +// The size of the image generated. Either `1024x1024`, `1024x1536`, or +// `1536x1024`. +type ImagesResponseSize string + +const ( + ImagesResponseSize1024x1024 ImagesResponseSize = "1024x1024" + ImagesResponseSize1024x1536 ImagesResponseSize = "1024x1536" + ImagesResponseSize1536x1024 ImagesResponseSize = "1536x1024" +) + +// For `gpt-image-1` only, the token usage information for the image generation. +type ImagesResponseUsage struct { + // The number of tokens (images and text) in the input prompt. 
+ InputTokens int64 `json:"input_tokens,required"` + // The input tokens detailed information for the image generation. + InputTokensDetails ImagesResponseUsageInputTokensDetails `json:"input_tokens_details,required"` + // The number of output tokens generated by the model. + OutputTokens int64 `json:"output_tokens,required"` + // The total number of tokens (images and text) used for the image generation. + TotalTokens int64 `json:"total_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + InputTokens respjson.Field + InputTokensDetails respjson.Field + OutputTokens respjson.Field + TotalTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImagesResponseUsage) RawJSON() string { return r.JSON.raw } +func (r *ImagesResponseUsage) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The input tokens detailed information for the image generation. +type ImagesResponseUsageInputTokensDetails struct { + // The number of image tokens in the input prompt. + ImageTokens int64 `json:"image_tokens,required"` + // The number of text tokens in the input prompt. + TextTokens int64 `json:"text_tokens,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ImageTokens respjson.Field + TextTokens respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r ImagesResponseUsageInputTokensDetails) RawJSON() string { return r.JSON.raw } +func (r *ImagesResponseUsageInputTokensDetails) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type ImageNewVariationParams struct { + // The image to use as the basis for the variation(s). Must be a valid PNG file, + // less than 4MB, and square. 
+ Image io.Reader `json:"image,omitzero,required" format:"binary"` + // The number of images to generate. Must be between 1 and 10. + N param.Opt[int64] `json:"n,omitzero"` + // A unique identifier representing your end-user, which can help OpenAI to monitor + // and detect abuse. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids). + User param.Opt[string] `json:"user,omitzero"` + // The model to use for image generation. Only `dall-e-2` is supported at this + // time. + Model ImageModel `json:"model,omitzero"` + // The format in which the generated images are returned. Must be one of `url` or + // `b64_json`. URLs are only valid for 60 minutes after the image has been + // generated. + // + // Any of "url", "b64_json". + ResponseFormat ImageNewVariationParamsResponseFormat `json:"response_format,omitzero"` + // The size of the generated images. Must be one of `256x256`, `512x512`, or + // `1024x1024`. + // + // Any of "256x256", "512x512", "1024x1024". + Size ImageNewVariationParamsSize `json:"size,omitzero"` + paramObj +} + +func (r ImageNewVariationParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err == nil { + err = apiform.WriteExtras(writer, r.ExtraFields()) + } + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +// The format in which the generated images are returned. Must be one of `url` or +// `b64_json`. URLs are only valid for 60 minutes after the image has been +// generated. 
+type ImageNewVariationParamsResponseFormat string + +const ( + ImageNewVariationParamsResponseFormatURL ImageNewVariationParamsResponseFormat = "url" + ImageNewVariationParamsResponseFormatB64JSON ImageNewVariationParamsResponseFormat = "b64_json" +) + +// The size of the generated images. Must be one of `256x256`, `512x512`, or +// `1024x1024`. +type ImageNewVariationParamsSize string + +const ( + ImageNewVariationParamsSize256x256 ImageNewVariationParamsSize = "256x256" + ImageNewVariationParamsSize512x512 ImageNewVariationParamsSize = "512x512" + ImageNewVariationParamsSize1024x1024 ImageNewVariationParamsSize = "1024x1024" +) + +type ImageEditParams struct { + // The image(s) to edit. Must be a supported image file or an array of images. + // + // For `gpt-image-1`, each image should be a `png`, `webp`, or `jpg` file less than + // 50MB. You can provide up to 16 images. + // + // For `dall-e-2`, you can only provide one image, and it should be a square `png` + // file less than 4MB. + Image ImageEditParamsImageUnion `json:"image,omitzero,required" format:"binary"` + // A text description of the desired image(s). The maximum length is 1000 + // characters for `dall-e-2`, and 32000 characters for `gpt-image-1`. + Prompt string `json:"prompt,required"` + // The number of images to generate. Must be between 1 and 10. + N param.Opt[int64] `json:"n,omitzero"` + // The compression level (0-100%) for the generated images. This parameter is only + // supported for `gpt-image-1` with the `webp` or `jpeg` output formats, and + // defaults to 100. + OutputCompression param.Opt[int64] `json:"output_compression,omitzero"` + // The number of partial images to generate. This parameter is used for streaming + // responses that return partial images. Value must be between 0 and 3. When set to + // 0, the response will be a single image sent in one streaming event. 
+ // + // Note that the final image may be sent before the full number of partial images + // are generated if the full image is generated more quickly. + PartialImages param.Opt[int64] `json:"partial_images,omitzero"` + // A unique identifier representing your end-user, which can help OpenAI to monitor + // and detect abuse. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids). + User param.Opt[string] `json:"user,omitzero"` + // Allows to set transparency for the background of the generated image(s). This + // parameter is only supported for `gpt-image-1`. Must be one of `transparent`, + // `opaque` or `auto` (default value). When `auto` is used, the model will + // automatically determine the best background for the image. + // + // If `transparent`, the output format needs to support transparency, so it should + // be set to either `png` (default value) or `webp`. + // + // Any of "transparent", "opaque", "auto". + Background ImageEditParamsBackground `json:"background,omitzero"` + // Control how much effort the model will exert to match the style and features, + // especially facial features, of input images. This parameter is only supported + // for `gpt-image-1`. Supports `high` and `low`. Defaults to `low`. + // + // Any of "high", "low". + InputFidelity ImageEditParamsInputFidelity `json:"input_fidelity,omitzero"` + // The model to use for image generation. Only `dall-e-2` and `gpt-image-1` are + // supported. Defaults to `dall-e-2` unless a parameter specific to `gpt-image-1` + // is used. + Model ImageModel `json:"model,omitzero"` + // The format in which the generated images are returned. This parameter is only + // supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`. The + // default value is `png`. + // + // Any of "png", "jpeg", "webp". + OutputFormat ImageEditParamsOutputFormat `json:"output_format,omitzero"` + // The quality of the image that will be generated. 
`high`, `medium` and `low` are + // only supported for `gpt-image-1`. `dall-e-2` only supports `standard` quality. + // Defaults to `auto`. + // + // Any of "standard", "low", "medium", "high", "auto". + Quality ImageEditParamsQuality `json:"quality,omitzero"` + // The format in which the generated images are returned. Must be one of `url` or + // `b64_json`. URLs are only valid for 60 minutes after the image has been + // generated. This parameter is only supported for `dall-e-2`, as `gpt-image-1` + // will always return base64-encoded images. + // + // Any of "url", "b64_json". + ResponseFormat ImageEditParamsResponseFormat `json:"response_format,omitzero"` + // The size of the generated images. Must be one of `1024x1024`, `1536x1024` + // (landscape), `1024x1536` (portrait), or `auto` (default value) for + // `gpt-image-1`, and one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. + // + // Any of "256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto". + Size ImageEditParamsSize `json:"size,omitzero"` + // An additional image whose fully transparent areas (e.g. where alpha is zero) + // indicate where `image` should be edited. If there are multiple images provided, + // the mask will be applied on the first image. Must be a valid PNG file, less than + // 4MB, and have the same dimensions as `image`. + Mask io.Reader `json:"mask,omitzero" format:"binary"` + paramObj +} + +func (r ImageEditParams) MarshalMultipart() (data []byte, contentType string, err error) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + err = apiform.MarshalRoot(r, writer) + if err == nil { + err = apiform.WriteExtras(writer, r.ExtraFields()) + } + if err != nil { + writer.Close() + return nil, "", err + } + err = writer.Close() + if err != nil { + return nil, "", err + } + return buf.Bytes(), writer.FormDataContentType(), nil +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. 
+type ImageEditParamsImageUnion struct { + OfFile io.Reader `json:",omitzero,inline"` + OfFileArray []io.Reader `json:",omitzero,inline"` + paramUnion +} + +func (u ImageEditParamsImageUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfFile, u.OfFileArray) +} +func (u *ImageEditParamsImageUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *ImageEditParamsImageUnion) asAny() any { + if !param.IsOmitted(u.OfFile) { + return &u.OfFile + } else if !param.IsOmitted(u.OfFileArray) { + return &u.OfFileArray + } + return nil +} + +// Allows to set transparency for the background of the generated image(s). This +// parameter is only supported for `gpt-image-1`. Must be one of `transparent`, +// `opaque` or `auto` (default value). When `auto` is used, the model will +// automatically determine the best background for the image. +// +// If `transparent`, the output format needs to support transparency, so it should +// be set to either `png` (default value) or `webp`. +type ImageEditParamsBackground string + +const ( + ImageEditParamsBackgroundTransparent ImageEditParamsBackground = "transparent" + ImageEditParamsBackgroundOpaque ImageEditParamsBackground = "opaque" + ImageEditParamsBackgroundAuto ImageEditParamsBackground = "auto" +) + +// Control how much effort the model will exert to match the style and features, +// especially facial features, of input images. This parameter is only supported +// for `gpt-image-1`. Supports `high` and `low`. Defaults to `low`. +type ImageEditParamsInputFidelity string + +const ( + ImageEditParamsInputFidelityHigh ImageEditParamsInputFidelity = "high" + ImageEditParamsInputFidelityLow ImageEditParamsInputFidelity = "low" +) + +// The format in which the generated images are returned. This parameter is only +// supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`. The +// default value is `png`. 
+type ImageEditParamsOutputFormat string + +const ( + ImageEditParamsOutputFormatPNG ImageEditParamsOutputFormat = "png" + ImageEditParamsOutputFormatJPEG ImageEditParamsOutputFormat = "jpeg" + ImageEditParamsOutputFormatWebP ImageEditParamsOutputFormat = "webp" +) + +// The quality of the image that will be generated. `high`, `medium` and `low` are +// only supported for `gpt-image-1`. `dall-e-2` only supports `standard` quality. +// Defaults to `auto`. +type ImageEditParamsQuality string + +const ( + ImageEditParamsQualityStandard ImageEditParamsQuality = "standard" + ImageEditParamsQualityLow ImageEditParamsQuality = "low" + ImageEditParamsQualityMedium ImageEditParamsQuality = "medium" + ImageEditParamsQualityHigh ImageEditParamsQuality = "high" + ImageEditParamsQualityAuto ImageEditParamsQuality = "auto" +) + +// The format in which the generated images are returned. Must be one of `url` or +// `b64_json`. URLs are only valid for 60 minutes after the image has been +// generated. This parameter is only supported for `dall-e-2`, as `gpt-image-1` +// will always return base64-encoded images. +type ImageEditParamsResponseFormat string + +const ( + ImageEditParamsResponseFormatURL ImageEditParamsResponseFormat = "url" + ImageEditParamsResponseFormatB64JSON ImageEditParamsResponseFormat = "b64_json" +) + +// The size of the generated images. Must be one of `1024x1024`, `1536x1024` +// (landscape), `1024x1536` (portrait), or `auto` (default value) for +// `gpt-image-1`, and one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. 
+type ImageEditParamsSize string + +const ( + ImageEditParamsSize256x256 ImageEditParamsSize = "256x256" + ImageEditParamsSize512x512 ImageEditParamsSize = "512x512" + ImageEditParamsSize1024x1024 ImageEditParamsSize = "1024x1024" + ImageEditParamsSize1536x1024 ImageEditParamsSize = "1536x1024" + ImageEditParamsSize1024x1536 ImageEditParamsSize = "1024x1536" + ImageEditParamsSizeAuto ImageEditParamsSize = "auto" +) + +type ImageGenerateParams struct { + // A text description of the desired image(s). The maximum length is 32000 + // characters for `gpt-image-1`, 1000 characters for `dall-e-2` and 4000 characters + // for `dall-e-3`. + Prompt string `json:"prompt,required"` + // The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only + // `n=1` is supported. + N param.Opt[int64] `json:"n,omitzero"` + // The compression level (0-100%) for the generated images. This parameter is only + // supported for `gpt-image-1` with the `webp` or `jpeg` output formats, and + // defaults to 100. + OutputCompression param.Opt[int64] `json:"output_compression,omitzero"` + // The number of partial images to generate. This parameter is used for streaming + // responses that return partial images. Value must be between 0 and 3. When set to + // 0, the response will be a single image sent in one streaming event. + // + // Note that the final image may be sent before the full number of partial images + // are generated if the full image is generated more quickly. + PartialImages param.Opt[int64] `json:"partial_images,omitzero"` + // A unique identifier representing your end-user, which can help OpenAI to monitor + // and detect abuse. + // [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids). + User param.Opt[string] `json:"user,omitzero"` + // Allows to set transparency for the background of the generated image(s). This + // parameter is only supported for `gpt-image-1`. 
Must be one of `transparent`, + // `opaque` or `auto` (default value). When `auto` is used, the model will + // automatically determine the best background for the image. + // + // If `transparent`, the output format needs to support transparency, so it should + // be set to either `png` (default value) or `webp`. + // + // Any of "transparent", "opaque", "auto". + Background ImageGenerateParamsBackground `json:"background,omitzero"` + // The model to use for image generation. One of `dall-e-2`, `dall-e-3`, or + // `gpt-image-1`. Defaults to `dall-e-2` unless a parameter specific to + // `gpt-image-1` is used. + Model ImageModel `json:"model,omitzero"` + // Control the content-moderation level for images generated by `gpt-image-1`. Must + // be either `low` for less restrictive filtering or `auto` (default value). + // + // Any of "low", "auto". + Moderation ImageGenerateParamsModeration `json:"moderation,omitzero"` + // The format in which the generated images are returned. This parameter is only + // supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`. + // + // Any of "png", "jpeg", "webp". + OutputFormat ImageGenerateParamsOutputFormat `json:"output_format,omitzero"` + // The quality of the image that will be generated. + // + // - `auto` (default value) will automatically select the best quality for the + // given model. + // - `high`, `medium` and `low` are supported for `gpt-image-1`. + // - `hd` and `standard` are supported for `dall-e-3`. + // - `standard` is the only option for `dall-e-2`. + // + // Any of "standard", "hd", "low", "medium", "high", "auto". + Quality ImageGenerateParamsQuality `json:"quality,omitzero"` + // The format in which generated images with `dall-e-2` and `dall-e-3` are + // returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes + // after the image has been generated. This parameter isn't supported for + // `gpt-image-1` which will always return base64-encoded images. 
+ // + // Any of "url", "b64_json". + ResponseFormat ImageGenerateParamsResponseFormat `json:"response_format,omitzero"` + // The size of the generated images. Must be one of `1024x1024`, `1536x1024` + // (landscape), `1024x1536` (portrait), or `auto` (default value) for + // `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and + // one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3`. + // + // Any of "auto", "1024x1024", "1536x1024", "1024x1536", "256x256", "512x512", + // "1792x1024", "1024x1792". + Size ImageGenerateParamsSize `json:"size,omitzero"` + // The style of the generated images. This parameter is only supported for + // `dall-e-3`. Must be one of `vivid` or `natural`. Vivid causes the model to lean + // towards generating hyper-real and dramatic images. Natural causes the model to + // produce more natural, less hyper-real looking images. + // + // Any of "vivid", "natural". + Style ImageGenerateParamsStyle `json:"style,omitzero"` + paramObj +} + +func (r ImageGenerateParams) MarshalJSON() (data []byte, err error) { + type shadow ImageGenerateParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *ImageGenerateParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Allows to set transparency for the background of the generated image(s). This +// parameter is only supported for `gpt-image-1`. Must be one of `transparent`, +// `opaque` or `auto` (default value). When `auto` is used, the model will +// automatically determine the best background for the image. +// +// If `transparent`, the output format needs to support transparency, so it should +// be set to either `png` (default value) or `webp`. 
+type ImageGenerateParamsBackground string + +const ( + ImageGenerateParamsBackgroundTransparent ImageGenerateParamsBackground = "transparent" + ImageGenerateParamsBackgroundOpaque ImageGenerateParamsBackground = "opaque" + ImageGenerateParamsBackgroundAuto ImageGenerateParamsBackground = "auto" +) + +// Control the content-moderation level for images generated by `gpt-image-1`. Must +// be either `low` for less restrictive filtering or `auto` (default value). +type ImageGenerateParamsModeration string + +const ( + ImageGenerateParamsModerationLow ImageGenerateParamsModeration = "low" + ImageGenerateParamsModerationAuto ImageGenerateParamsModeration = "auto" +) + +// The format in which the generated images are returned. This parameter is only +// supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`. +type ImageGenerateParamsOutputFormat string + +const ( + ImageGenerateParamsOutputFormatPNG ImageGenerateParamsOutputFormat = "png" + ImageGenerateParamsOutputFormatJPEG ImageGenerateParamsOutputFormat = "jpeg" + ImageGenerateParamsOutputFormatWebP ImageGenerateParamsOutputFormat = "webp" +) + +// The quality of the image that will be generated. +// +// - `auto` (default value) will automatically select the best quality for the +// given model. +// - `high`, `medium` and `low` are supported for `gpt-image-1`. +// - `hd` and `standard` are supported for `dall-e-3`. +// - `standard` is the only option for `dall-e-2`. 
+type ImageGenerateParamsQuality string + +const ( + ImageGenerateParamsQualityStandard ImageGenerateParamsQuality = "standard" + ImageGenerateParamsQualityHD ImageGenerateParamsQuality = "hd" + ImageGenerateParamsQualityLow ImageGenerateParamsQuality = "low" + ImageGenerateParamsQualityMedium ImageGenerateParamsQuality = "medium" + ImageGenerateParamsQualityHigh ImageGenerateParamsQuality = "high" + ImageGenerateParamsQualityAuto ImageGenerateParamsQuality = "auto" +) + +// The format in which generated images with `dall-e-2` and `dall-e-3` are +// returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes +// after the image has been generated. This parameter isn't supported for +// `gpt-image-1` which will always return base64-encoded images. +type ImageGenerateParamsResponseFormat string + +const ( + ImageGenerateParamsResponseFormatURL ImageGenerateParamsResponseFormat = "url" + ImageGenerateParamsResponseFormatB64JSON ImageGenerateParamsResponseFormat = "b64_json" +) + +// The size of the generated images. Must be one of `1024x1024`, `1536x1024` +// (landscape), `1024x1536` (portrait), or `auto` (default value) for +// `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and +// one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3`. +type ImageGenerateParamsSize string + +const ( + ImageGenerateParamsSizeAuto ImageGenerateParamsSize = "auto" + ImageGenerateParamsSize1024x1024 ImageGenerateParamsSize = "1024x1024" + ImageGenerateParamsSize1536x1024 ImageGenerateParamsSize = "1536x1024" + ImageGenerateParamsSize1024x1536 ImageGenerateParamsSize = "1024x1536" + ImageGenerateParamsSize256x256 ImageGenerateParamsSize = "256x256" + ImageGenerateParamsSize512x512 ImageGenerateParamsSize = "512x512" + ImageGenerateParamsSize1792x1024 ImageGenerateParamsSize = "1792x1024" + ImageGenerateParamsSize1024x1792 ImageGenerateParamsSize = "1024x1792" +) + +// The style of the generated images. 
This parameter is only supported for +// `dall-e-3`. Must be one of `vivid` or `natural`. Vivid causes the model to lean +// towards generating hyper-real and dramatic images. Natural causes the model to +// produce more natural, less hyper-real looking images. +type ImageGenerateParamsStyle string + +const ( + ImageGenerateParamsStyleVivid ImageGenerateParamsStyle = "vivid" + ImageGenerateParamsStyleNatural ImageGenerateParamsStyle = "natural" +) diff --git a/vendor/github.com/openai/openai-go/internal/apierror/apierror.go b/vendor/github.com/openai/openai-go/internal/apierror/apierror.go new file mode 100644 index 0000000000..1b3b9e0319 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apierror/apierror.go @@ -0,0 +1,58 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package apierror + +import ( + "fmt" + "net/http" + "net/http/httputil" + + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/packages/respjson" +) + +// Error represents an error that originates from the API, i.e. when a request is +// made and the API returns a response with a HTTP status code. Other errors are +// not wrapped by this SDK. +type Error struct { + Code string `json:"code,required"` + Message string `json:"message,required"` + Param string `json:"param,required"` + Type string `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + Code respjson.Field + Message respjson.Field + Param respjson.Field + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` + StatusCode int + Request *http.Request + Response *http.Response +} + +// Returns the unmodified JSON received from the API +func (r Error) RawJSON() string { return r.JSON.raw } +func (r *Error) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func (r *Error) Error() string { + // Attempt to re-populate the response body + return fmt.Sprintf("%s %q: %d %s %s", r.Request.Method, r.Request.URL, r.Response.StatusCode, http.StatusText(r.Response.StatusCode), r.JSON.raw) +} + +func (r *Error) DumpRequest(body bool) []byte { + if r.Request.GetBody != nil { + r.Request.Body, _ = r.Request.GetBody() + } + out, _ := httputil.DumpRequestOut(r.Request, body) + return out +} + +func (r *Error) DumpResponse(body bool) []byte { + out, _ := httputil.DumpResponse(r.Response, body) + return out +} diff --git a/vendor/github.com/openai/openai-go/internal/apiform/encoder.go b/vendor/github.com/openai/openai-go/internal/apiform/encoder.go new file mode 100644 index 0000000000..f1bd16497c --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiform/encoder.go @@ -0,0 +1,465 @@ +package apiform + +import ( + "fmt" + "io" + "mime/multipart" + "net/textproto" + "path" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/openai/openai-go/packages/param" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value any, writer *multipart.Writer) error { + e := &encoder{ + dateFormat: time.RFC3339, + arrayFmt: "brackets", + } + return e.marshal(value, writer) +} + +func MarshalRoot(value any, writer *multipart.Writer) error { + e := &encoder{ + root: true, + dateFormat: time.RFC3339, + arrayFmt: "brackets", + } + return e.marshal(value, writer) +} + +func MarshalWithSettings(value any, writer *multipart.Writer, arrayFormat 
string) error { + e := &encoder{ + arrayFmt: arrayFormat, + dateFormat: time.RFC3339, + } + return e.marshal(value, writer) +} + +type encoder struct { + arrayFmt string + dateFormat string + root bool +} + +type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (e *encoder) marshal(value any, writer *multipart.Writer) error { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc("", val, writer) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error { + wg.Wait() + return f(key, v, writer) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. 
+ f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if t.Implements(reflect.TypeOf((*io.Reader)(nil)).Elem()) { + return e.newReaderTypeEncoder() + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if !v.IsValid() || v.IsNil() { + return nil + } + return innerEncoder(key, v.Elem(), writer) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Slice, reflect.Array: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, v.String()) + } + case reflect.Bool: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if v.Bool() { + return writer.WriteField(key, "true") + } + return writer.WriteField(key, "false") + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatInt(v.Int(), 10)) + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10)) + } + case reflect.Float32: + return func(key string, v 
reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32)) + } + case reflect.Float64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64)) + } + default: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func arrayKeyEncoder(arrayFmt string) func(string, int) string { + var keyFn func(string, int) string + switch arrayFmt { + case "comma", "repeat": + keyFn = func(k string, _ int) string { return k } + case "brackets": + keyFn = func(key string, _ int) string { return key + "[]" } + case "indices:dots": + keyFn = func(k string, i int) string { + if k == "" { + return strconv.Itoa(i) + } + return k + "." + strconv.Itoa(i) + } + case "indices:brackets": + keyFn = func(k string, i int) string { + if k == "" { + return strconv.Itoa(i) + } + return k + "[" + strconv.Itoa(i) + "]" + } + } + return keyFn +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + keyFn := arrayKeyEncoder(e.arrayFmt) + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if keyFn == nil { + return fmt.Errorf("apiform: unsupported array format") + } + for i := 0; i < v.Len(); i++ { + err := itemEncoder(keyFn(key, i), v.Index(i), writer) + if err != nil { + return err + } + } + return nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { + return e.newStructUnionTypeEncoder(t) + } + } + + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows 
us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseFormStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" || ptag.name == "" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + + var encoderFn encoderFunc + if ptag.omitzero { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error { + if value.IsZero() { + return nil + } + return typeEncoderFn(key, value, writer) + } + } else { + encoderFn = e.typeEncoder(field.Type) + } + encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return 
encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + if key != "" { + key = key + "." + } + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + err := ef.fn(key+ef.tag.name, field, writer) + if err != nil { + return err + } + } + + if extraEncoder != nil { + err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer) + if err != nil { + return err + } + } + + return nil + } +} + +var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() + +func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { + var fieldEncoders []encoderFunc + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type == paramUnionType && field.Anonymous { + fieldEncoders = append(fieldEncoders, nil) + continue + } + fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) + } + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + for i := 0; i < t.NumField(); i++ { + if value.Field(i).Type() == paramUnionType { + continue + } + if !value.Field(i).IsZero() { + return fieldEncoders[i](key, value.Field(i), writer) + } + } + return fmt.Errorf("apiform: union %s has no field set", t.String()) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format)) + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + value = value.Elem() + if !value.IsValid() { + return nil + } + return e.typeEncoder(value.Type())(key, value, writer) + } +} + +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +func (e 
*encoder) newReaderTypeEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + reader, ok := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader) + if !ok { + return nil + } + filename := "anonymous_file" + contentType := "application/octet-stream" + if named, ok := reader.(interface{ Filename() string }); ok { + filename = named.Filename() + } else if named, ok := reader.(interface{ Name() string }); ok { + filename = path.Base(named.Name()) + } + if typed, ok := reader.(interface{ ContentType() string }); ok { + contentType = typed.ContentType() + } + + // Below is taken almost 1-for-1 from [multipart.CreateFormFile] + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename))) + h.Set("Content-Type", contentType) + filewriter, err := writer.CreatePart(h) + if err != nil { + return err + } + _, err = io.Copy(filewriter, reader) + return err + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. +func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error { + type mapPair struct { + key string + value reflect.Value + } + + if key != "" { + key = key + "." 
+ } + + pairs := []mapPair{} + + iter := v.MapRange() + for iter.Next() { + if iter.Key().Type().Kind() == reflect.String { + pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()}) + } else { + return fmt.Errorf("cannot encode a map with a non string key") + } + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].key < pairs[j].key + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + err := elementEncoder(key+string(p.key), p.value, writer) + if err != nil { + return err + } + } + + return nil +} + +func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return e.encodeMapEntries(key, value, writer) + } +} + +func WriteExtras(writer *multipart.Writer, extras map[string]any) (err error) { + for k, v := range extras { + str, ok := v.(string) + if !ok { + break + } + err = writer.WriteField(k, str) + if err != nil { + break + } + } + return +} diff --git a/vendor/github.com/openai/openai-go/internal/apiform/form.go b/vendor/github.com/openai/openai-go/internal/apiform/form.go new file mode 100644 index 0000000000..5445116e99 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiform/form.go @@ -0,0 +1,5 @@ +package apiform + +type Marshaler interface { + MarshalMultipart() ([]byte, string, error) +} diff --git a/vendor/github.com/openai/openai-go/internal/apiform/richparam.go b/vendor/github.com/openai/openai-go/internal/apiform/richparam.go new file mode 100644 index 0000000000..690a87b7ef --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiform/richparam.go @@ -0,0 +1,20 @@ +package apiform + +import ( + "github.com/openai/openai-go/packages/param" + "mime/multipart" + "reflect" +) + +func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.newPrimitiveTypeEncoder(f.Type) + return func(key 
string, value reflect.Value, writer *multipart.Writer) error { + if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() { + return enc(key, value.FieldByIndex(f.Index), writer) + } else if ok && param.IsNull(opt) { + return writer.WriteField(key, "null") + } + return nil + } +} diff --git a/vendor/github.com/openai/openai-go/internal/apiform/tag.go b/vendor/github.com/openai/openai-go/internal/apiform/tag.go new file mode 100644 index 0000000000..736fc1ea65 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiform/tag.go @@ -0,0 +1,51 @@ +package apiform + +import ( + "reflect" + "strings" +) + +const jsonStructTag = "json" +const formStructTag = "form" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + omitzero bool +} + +func parseFormStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(formStructTag) + if !ok { + raw, ok = field.Tag.Lookup(jsonStructTag) + } + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "omitzero": + tag.omitzero = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/decoder.go b/vendor/github.com/openai/openai-go/internal/apijson/decoder.go new file mode 100644 index 0000000000..b3f1bf7a67 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/decoder.go @@ -0,0 +1,691 @@ +// The deserialization algorithm from apijson may be subject to improvements +// between minor versions, particularly with respect to calling [json.Unmarshal] +// into param 
unions. + +package apijson + +import ( + "encoding/json" + "fmt" + "github.com/openai/openai-go/packages/param" + "reflect" + "strconv" + "sync" + "time" + "unsafe" + + "github.com/tidwall/gjson" +) + +// decoders is a synchronized map with roughly the following type: +// map[reflect.Type]decoderFunc +var decoders sync.Map + +// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded +// data and stores it in the given pointer. +func Unmarshal(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339} + return d.unmarshal(raw, to) +} + +// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the +// root element. Useful if a struct's UnmarshalJSON is overrode to use the +// behavior of this encoder versus the standard library. +func UnmarshalRoot(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339, root: true} + return d.unmarshal(raw, to) +} + +// decoderBuilder contains the 'compile-time' state of the decoder. +type decoderBuilder struct { + // Whether or not this is the first element and called by [UnmarshalRoot], see + // the documentation there to see why this is necessary. + root bool + // The dateFormat (a format string for [time.Format]) which is chosen by the + // last struct tag that was seen. + dateFormat string +} + +// decoderState contains the 'run-time' state of the decoder. +type decoderState struct { + strict bool + exactness exactness + validator *validationEntry +} + +// Exactness refers to how close to the type the result was if deserialization +// was successful. This is useful in deserializing unions, where you want to try +// each entry, first with strict, then with looser validation, without actually +// having to do a lot of redundant work by marshalling twice (or maybe even more +// times). +type exactness int8 + +const ( + // Some values had to fudged a bit, for example by converting a string to an + // int, or an enum with extra values. 
+ loose exactness = iota + // There are some extra arguments, but other wise it matches the union. + extras + // Exactly right. + exact +) + +type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error + +type decoderField struct { + tag parsedStructTag + fn decoderFunc + idx []int + goname string +} + +type decoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (d *decoderBuilder) unmarshal(raw []byte, to any) error { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return fmt.Errorf("apijson: cannot marshal into invalid value") + } + return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact}) +} + +// unmarshalWithExactness is used for internal testing purposes. +func (d *decoderBuilder) unmarshalWithExactness(raw []byte, to any) (exactness, error) { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return 0, fmt.Errorf("apijson: cannot marshal into invalid value") + } + state := decoderState{strict: false, exactness: exact} + err := d.typeDecoder(value.Type())(result, value, &state) + return state.exactness, err +} + +func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc { + entry := decoderEntry{ + Type: t, + dateFormat: d.dateFormat, + root: d.root, + } + + if fi, ok := decoders.Load(entry); ok { + return fi.(decoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. 
+ var ( + wg sync.WaitGroup + f decoderFunc + ) + wg.Add(1) + fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error { + wg.Wait() + return f(node, v, state) + })) + if loaded { + return fi.(decoderFunc) + } + + // Compute the real decoder and replace the indirect func with it. + f = d.newTypeDecoder(t) + wg.Done() + decoders.Store(entry, f) + return f +} + +// validatedTypeDecoder wraps the type decoder with a validator. This is helpful +// for ensuring that enum fields are correct. +func (d *decoderBuilder) validatedTypeDecoder(t reflect.Type, entry *validationEntry) decoderFunc { + dec := d.typeDecoder(t) + if entry == nil { + return dec + } + + // Thread the current validation entry through the decoder, + // but clean up in time for the next field. + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + state.validator = entry + err := dec(node, v, state) + state.validator = nil + return err + } +} + +func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + if v.Kind() == reflect.Pointer && v.CanSet() { + v.Set(reflect.New(v.Type().Elem())) + } + return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return d.newTimeTypeDecoder(t) + } + + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return d.newOptTypeDecoder(t) + } + + if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + return unmarshalerDecoder + } + if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + if _, ok := unionVariants[t]; !ok { + return indirectUnmarshalerDecoder + } + 
} + d.root = false + + if _, ok := unionRegistry[t]; ok { + if isStructUnion(t) { + return d.newStructUnionDecoder(t) + } + return d.newUnionDecoder(t) + } + + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + innerDecoder := d.typeDecoder(inner) + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if !v.IsValid() { + return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v) + } + + newValue := reflect.New(inner).Elem() + err := innerDecoder(n, newValue, state) + if err != nil { + return err + } + + v.Set(newValue.Addr()) + return nil + } + case reflect.Struct: + if isStructUnion(t) { + return d.newStructUnionDecoder(t) + } + return d.newStructTypeDecoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return d.newArrayTypeDecoder(t) + case reflect.Map: + return d.newMapDecoder(t) + case reflect.Interface: + return func(node gjson.Result, value reflect.Value, state *decoderState) error { + if !value.IsValid() { + return fmt.Errorf("apijson: unexpected invalid value %+#v", value) + } + if node.Value() != nil && value.CanSet() { + value.Set(reflect.ValueOf(node.Value())) + } + return nil + } + default: + return d.newPrimitiveTypeDecoder(t) + } +} + +func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc { + keyType := t.Key() + itemType := t.Elem() + itemDecoder := d.typeDecoder(itemType) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + mapValue := reflect.MakeMapWithSize(t, len(node.Map())) + + node.ForEach(func(key, value gjson.Result) bool { + // It's fine for us to just use `ValueOf` here because the key types will + // always be primitive types so we don't need to decode it using the standard pattern + keyValue := reflect.ValueOf(key.Value()) + if !keyValue.IsValid() { + if err == nil { + err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String()) + } + return false + } + if keyValue.Type() != keyType { + if err == nil 
{ + err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type()) + } + return false + } + + itemValue := reflect.New(itemType).Elem() + itemerr := itemDecoder(value, itemValue, state) + if itemerr != nil { + if err == nil { + err = itemerr + } + return false + } + + mapValue.SetMapIndex(keyValue, itemValue) + return true + }) + + if err != nil { + return err + } + value.Set(mapValue) + return nil + } +} + +func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc { + itemDecoder := d.typeDecoder(t.Elem()) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if !node.IsArray() { + return fmt.Errorf("apijson: could not deserialize to an array") + } + + arrayNode := node.Array() + + arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode)) + for i, itemNode := range arrayNode { + err = itemDecoder(itemNode, arrayValue.Index(i), state) + if err != nil { + return err + } + } + + value.Set(arrayValue) + return nil + } +} + +func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { + // map of json field name to struct field decoders + decoderFields := map[string]decoderField{} + anonymousDecoders := []decoderField{} + extraDecoder := (*decoderField)(nil) + var inlineDecoders []decoderField + + validationEntries := validationRegistry[t] + + for i := 0; i < t.NumField(); i++ { + idx := []int{i} + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + + var validator *validationEntry + for _, entry := range validationEntries { + if entry.field.Offset == field.Offset { + validator = &entry + break + } + } + + // If this is an embedded struct, traverse one level deeper to extract + // the fields and get their encoders as well. 
+ if field.Anonymous { + anonymousDecoders = append(anonymousDecoders, decoderField{ + fn: d.typeDecoder(field.Type), + idx: idx[:], + }) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported fields if they're tagged with + // `extras` because that field shouldn't be part of the public API. + if ptag.extras { + extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} + continue + } + if ptag.inline { + df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + inlineDecoders = append(inlineDecoders, df) + continue + } + if ptag.metadata { + continue + } + + oldFormat := d.dateFormat + dateFormat, ok := parseFormatStructTag(field) + if ok { + switch dateFormat { + case "date-time": + d.dateFormat = time.RFC3339 + case "date": + d.dateFormat = "2006-01-02" + } + } + + decoderFields[ptag.name] = decoderField{ + ptag, + d.validatedTypeDecoder(field.Type, validator), + idx, field.Name, + } + + d.dateFormat = oldFormat + } + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if field := value.FieldByName("JSON"); field.IsValid() { + if raw := field.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, node.Raw) + } + } + + for _, decoder := range anonymousDecoders { + // ignore errors + decoder.fn(node, value.FieldByIndex(decoder.idx), state) + } + + for _, inlineDecoder := range inlineDecoders { + var meta Field + dest := value.FieldByIndex(inlineDecoder.idx) + isValid := false + if dest.IsValid() && node.Type != gjson.Null { + inlineState := decoderState{exactness: state.exactness, strict: true} + err = inlineDecoder.fn(node, dest, &inlineState) + if err == nil { + isValid = true + } + } + + if node.Type == gjson.Null { + meta = Field{ + raw: node.Raw, + status: null, + } + } else if !isValid { + // If an 
inline decoder fails, unset the field and move on. + if dest.IsValid() { + dest.SetZero() + } + continue + } else if isValid { + meta = Field{ + raw: node.Raw, + status: valid, + } + } + setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta) + } + + typedExtraType := reflect.Type(nil) + typedExtraFields := reflect.Value{} + if extraDecoder != nil { + typedExtraType = value.FieldByIndex(extraDecoder.idx).Type() + typedExtraFields = reflect.MakeMap(typedExtraType) + } + untypedExtraFields := map[string]Field{} + + for fieldName, itemNode := range node.Map() { + df, explicit := decoderFields[fieldName] + var ( + dest reflect.Value + fn decoderFunc + meta Field + ) + if explicit { + fn = df.fn + dest = value.FieldByIndex(df.idx) + } + if !explicit && extraDecoder != nil { + dest = reflect.New(typedExtraType.Elem()).Elem() + fn = extraDecoder.fn + } + + isValid := false + if dest.IsValid() && itemNode.Type != gjson.Null { + err = fn(itemNode, dest, state) + if err == nil { + isValid = true + } + } + + // Handle null [param.Opt] + if itemNode.Type == gjson.Null && dest.IsValid() && dest.Type().Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + dest.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(itemNode.Raw)) + continue + } + + if itemNode.Type == gjson.Null { + meta = Field{ + raw: itemNode.Raw, + status: null, + } + } else if !isValid { + meta = Field{ + raw: itemNode.Raw, + status: invalid, + } + } else if isValid { + meta = Field{ + raw: itemNode.Raw, + status: valid, + } + } + + if explicit { + setMetadataSubField(value, df.idx, df.goname, meta) + } + if !explicit { + untypedExtraFields[fieldName] = meta + } + if !explicit && extraDecoder != nil { + typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest) + } + } + + if extraDecoder != nil && typedExtraFields.Len() > 0 { + value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields) + } + + // Set exactness to 'extras' if there are untyped, extra fields. 
+ if len(untypedExtraFields) > 0 && state.exactness > extras { + state.exactness = extras + } + + if len(untypedExtraFields) > 0 { + setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields) + } + return nil + } +} + +func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc { + switch t.Kind() { + case reflect.String: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetString(n.String()) + if guardStrict(state, n.Type != gjson.String) { + return fmt.Errorf("apijson: failed to parse string strictly") + } + // Everything that is not an object can be loosely stringified. + if n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse string") + } + + state.validateString(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed string enum validation") + } + return nil + } + case reflect.Bool: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetBool(n.Bool()) + if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) { + return fmt.Errorf("apijson: failed to parse bool strictly") + } + // Numbers and strings that are either 'true' or 'false' can be loosely + // deserialized as bool. + if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse bool") + } + + state.validateBool(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed bool enum validation") + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetInt(n.Int()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) { + return fmt.Errorf("apijson: failed to parse int strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as numbers. 
+ if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse int") + } + + state.validateInt(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed int enum validation") + } + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetUint(n.Uint()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) { + return fmt.Errorf("apijson: failed to parse uint strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as uint. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse uint") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed uint enum validation") + } + return nil + } + case reflect.Float32, reflect.Float64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetFloat(n.Float()) + if guardStrict(state, n.Type != gjson.Number) { + return fmt.Errorf("apijson: failed to parse float strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as floats. 
+ if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse float") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed float enum validation") + } + return nil + } + default: + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + return fmt.Errorf("unknown type received at primitive decoder: %s", t.String()) + } + } +} + +func (d *decoderBuilder) newOptTypeDecoder(t reflect.Type) decoderFunc { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + valueField, _ := t.FieldByName("Value") + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + state.validateOptKind(n, valueField.Type) + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) + } +} + +func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc { + format := d.dateFormat + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + parsed, err := time.Parse(format, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + + if guardStrict(state, true) { + return err + } + + layouts := []string{ + "2006-01-02", + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05Z0700", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05Z07:00", + "2006-01-02 15:04:05Z0700", + "2006-01-02 15:04:05", + } + + for _, layout := range layouts { + parsed, err := time.Parse(layout, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + } + + return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str) + } +} + +func setUnexportedField(field reflect.Value, value any) { + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func guardStrict(state *decoderState, cond bool) bool { + if !cond { + return false + } + + if state.strict { + return true + } + + state.exactness = loose + return false +} + +func 
canParseAsNumber(str string) bool { + _, err := strconv.ParseFloat(str, 64) + return err == nil +} + +var stringType = reflect.TypeOf(string("")) + +func guardUnknown(state *decoderState, v reflect.Value) bool { + if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) { + return true + } + + constantString, ok := v.Interface().(interface{ Default() string }) + named := v.Type() != stringType + if guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) { + return true + } + return false +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/encoder.go b/vendor/github.com/openai/openai-go/internal/apijson/encoder.go new file mode 100644 index 0000000000..8358a2f0a5 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/encoder.go @@ -0,0 +1,392 @@ +package apijson + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/tidwall/sjson" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value any) ([]byte, error) { + e := &encoder{dateFormat: time.RFC3339} + return e.marshal(value) +} + +func MarshalRoot(value any) ([]byte, error) { + e := &encoder{root: true, dateFormat: time.RFC3339} + return e.marshal(value) +} + +type encoder struct { + dateFormat string + root bool +} + +type encoderFunc func(value reflect.Value) ([]byte, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (e *encoder) marshal(value any) ([]byte, error) { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc(val) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := 
encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) { + wg.Wait() + return f(v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Interface().(json.Marshaler).MarshalJSON() +} + +func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Addr().Interface().(json.Marshaler).MarshalJSON() +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return indirectMarshalerEncoder + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(v reflect.Value) ([]byte, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(v.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note 
that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(v reflect.Value) ([]byte, error) { + return json.Marshal(v.Interface()) + } + case reflect.Bool: + return func(v reflect.Value) ([]byte, error) { + if v.Bool() { + return []byte("true"), nil + } + return []byte("false"), nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatInt(v.Int(), 10)), nil + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatUint(v.Uint(), 10)), nil + } + case reflect.Float32: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil + } + case reflect.Float64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil + } + default: + return func(v reflect.Value) ([]byte, error) { + return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + + return func(value reflect.Value) ([]byte, error) { + json := []byte("[]") + for i := 0; i < value.Len(); i++ { + var value, err = itemEncoder(value.Index(i)) + if err != nil { + return nil, err + } + if value == nil { + // Assume that empty items should be inserted as `null` so that the output array + // will be the same length as the input array + value = []byte("null") + } + + json, err = sjson.SetRawBytes(json, "-1", value) + if err != nil { + return nil, err + } + } + + return json, nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows us to 
recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(value reflect.Value) (json []byte, err error) { + json = []byte("{}") + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + encoded, err := ef.fn(field) + if err != nil { + return nil, err + } + if encoded == nil { + continue + } + json, err = 
sjson.SetRawBytes(json, ef.tag.name, encoded) + if err != nil { + return nil, err + } + } + + if extraEncoder != nil { + json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx)) + if err != nil { + return nil, err + } + } + return + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(value reflect.Value) (json []byte, err error) { + present := value.FieldByName("Present") + if !present.Bool() { + return nil, nil + } + null := value.FieldByName("Null") + if null.Bool() { + return []byte("null"), nil + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(raw) + } + return enc(value.FieldByName("Value")) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(value reflect.Value) (json []byte, err error) { + return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(value reflect.Value) ([]byte, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(value) + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. 
+func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) { + type mapPair struct { + key []byte + value reflect.Value + } + + pairs := []mapPair{} + keyEncoder := e.typeEncoder(v.Type().Key()) + + iter := v.MapRange() + for iter.Next() { + var encodedKeyString string + if iter.Key().Type().Kind() == reflect.String { + encodedKeyString = iter.Key().String() + } else { + var err error + encodedKeyBytes, err := keyEncoder(iter.Key()) + if err != nil { + return nil, err + } + encodedKeyString = string(encodedKeyBytes) + } + encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString)) + pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()}) + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return bytes.Compare(pairs[i].key, pairs[j].key) < 0 + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + encodedValue, err := elementEncoder(p.value) + if err != nil { + return nil, err + } + if len(encodedValue) == 0 { + continue + } + json, err = sjson.SetRawBytes(json, string(p.key), encodedValue) + if err != nil { + return nil, err + } + } + + return json, nil +} + +func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc { + return func(value reflect.Value) ([]byte, error) { + json := []byte("{}") + var err error + json, err = e.encodeMapEntries(json, value) + if err != nil { + return nil, err + } + return json, nil + } +} + +// If we want to set a literal key value into JSON using sjson, we need to make sure it doesn't have +// special characters that sjson interprets as a path. 
+var sjsonReplacer *strings.Replacer = strings.NewReplacer(".", "\\.", ":", "\\:", "*", "\\*") diff --git a/vendor/github.com/openai/openai-go/internal/apijson/enum.go b/vendor/github.com/openai/openai-go/internal/apijson/enum.go new file mode 100644 index 0000000000..18b218a8e7 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/enum.go @@ -0,0 +1,145 @@ +package apijson + +import ( + "fmt" + "reflect" + "slices" + "sync" + + "github.com/tidwall/gjson" +) + +/********************/ +/* Validating Enums */ +/********************/ + +type validationEntry struct { + field reflect.StructField + required bool + legalValues struct { + strings []string + // 1 represents true, 0 represents false, -1 represents either + bools int + ints []int64 + } +} + +type validatorFunc func(reflect.Value) exactness + +var validators sync.Map +var validationRegistry = map[reflect.Type][]validationEntry{} + +func RegisterFieldValidator[T any, V string | bool | int](fieldName string, values ...V) { + var t T + parentType := reflect.TypeOf(t) + + if _, ok := validationRegistry[parentType]; !ok { + validationRegistry[parentType] = []validationEntry{} + } + + // The following checks run at initialization time, + // it is impossible for them to panic if any tests pass. 
+ if parentType.Kind() != reflect.Struct { + panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String())) + } + + var field reflect.StructField + found := false + for i := 0; i < parentType.NumField(); i++ { + ptag, ok := parseJSONStructTag(parentType.Field(i)) + if ok && ptag.name == fieldName { + field = parentType.Field(i) + found = true + break + } + } + + if !found { + panic(fmt.Sprintf("apijson: cannot find field %s in struct %s", fieldName, parentType.String())) + } + + newEntry := validationEntry{field: field} + newEntry.legalValues.bools = -1 // default to either + + switch values := any(values).(type) { + case []string: + newEntry.legalValues.strings = values + case []int: + newEntry.legalValues.ints = make([]int64, len(values)) + for i, value := range values { + newEntry.legalValues.ints[i] = int64(value) + } + case []bool: + for i, value := range values { + var next int + if value { + next = 1 + } + if i > 0 && newEntry.legalValues.bools != next { + newEntry.legalValues.bools = -1 // accept either + break + } + newEntry.legalValues.bools = next + } + } + + // Store the information necessary to create a validator, so that we can use it + // lazily create the validator function when did. 
+ validationRegistry[parentType] = append(validationRegistry[parentType], newEntry) +} + +func (state *decoderState) validateString(v reflect.Value) { + if state.validator == nil { + return + } + if !slices.Contains(state.validator.legalValues.strings, v.String()) { + state.exactness = loose + } +} + +func (state *decoderState) validateInt(v reflect.Value) { + if state.validator == nil { + return + } + if !slices.Contains(state.validator.legalValues.ints, v.Int()) { + state.exactness = loose + } +} + +func (state *decoderState) validateBool(v reflect.Value) { + if state.validator == nil { + return + } + b := v.Bool() + if state.validator.legalValues.bools == 1 && b == false { + state.exactness = loose + } else if state.validator.legalValues.bools == 0 && b == true { + state.exactness = loose + } +} + +func (state *decoderState) validateOptKind(node gjson.Result, t reflect.Type) { + switch node.Type { + case gjson.JSON: + state.exactness = loose + case gjson.Null: + return + case gjson.False, gjson.True: + if t.Kind() != reflect.Bool { + state.exactness = loose + } + case gjson.Number: + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return + default: + state.exactness = loose + } + case gjson.String: + if t.Kind() != reflect.String { + state.exactness = loose + } + } +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/field.go b/vendor/github.com/openai/openai-go/internal/apijson/field.go new file mode 100644 index 0000000000..854d6dd78d --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/field.go @@ -0,0 +1,23 @@ +package apijson + +type status uint8 + +const ( + missing status = iota + null + invalid + valid +) + +type Field struct { + raw string + status status +} + +// Returns true if the field is explicitly `null` _or_ if it is not present at all (ie, missing). 
+// To check if the field's key is present in the JSON with an explicit null value, +// you must check `f.IsNull() && !f.IsMissing()`. +func (j Field) IsNull() bool { return j.status <= null } +func (j Field) IsMissing() bool { return j.status == missing } +func (j Field) IsInvalid() bool { return j.status == invalid } +func (j Field) Raw() string { return j.raw } diff --git a/vendor/github.com/openai/openai-go/internal/apijson/port.go b/vendor/github.com/openai/openai-go/internal/apijson/port.go new file mode 100644 index 0000000000..b40013c13a --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/port.go @@ -0,0 +1,120 @@ +package apijson + +import ( + "fmt" + "reflect" +) + +// Port copies over values from one struct to another struct. +func Port(from any, to any) error { + toVal := reflect.ValueOf(to) + fromVal := reflect.ValueOf(from) + + if toVal.Kind() != reflect.Ptr || toVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer") + } + + for toVal.Kind() == reflect.Ptr { + toVal = toVal.Elem() + } + toType := toVal.Type() + + for fromVal.Kind() == reflect.Ptr { + fromVal = fromVal.Elem() + } + fromType := fromVal.Type() + + if toType.Kind() != reflect.Struct { + return fmt.Errorf("destination must be a non-nil pointer to a struct (%v %v)", toType, toType.Kind()) + } + + values := map[string]reflect.Value{} + fields := map[string]reflect.Value{} + + fromJSON := fromVal.FieldByName("JSON") + toJSON := toVal.FieldByName("JSON") + + // Iterate through the fields of v and load all the "normal" fields in the struct to the map of + // string to reflect.Value, as well as their raw .JSON.Foo counterpart indicated by j. + var getFields func(t reflect.Type, v reflect.Value) + getFields = func(t reflect.Type, v reflect.Value) { + j := v.FieldByName("JSON") + + // Recurse into anonymous fields first, since the fields on the object should win over the fields in the + // embedded object. 
+ for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Anonymous { + getFields(field.Type, v.Field(i)) + continue + } + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok || ptag.name == "-" || ptag.name == "" { + continue + } + values[ptag.name] = v.Field(i) + if j.IsValid() { + fields[ptag.name] = j.FieldByName(field.Name) + } + } + } + getFields(fromType, fromVal) + + // Use the values from the previous step to populate the 'to' struct. + for i := 0; i < toType.NumField(); i++ { + field := toType.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + if ptag.name == "-" { + continue + } + if value, ok := values[ptag.name]; ok { + delete(values, ptag.name) + if field.Type.Kind() == reflect.Interface { + toVal.Field(i).Set(value) + } else { + switch value.Kind() { + case reflect.String: + toVal.Field(i).SetString(value.String()) + case reflect.Bool: + toVal.Field(i).SetBool(value.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + toVal.Field(i).SetInt(value.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + toVal.Field(i).SetUint(value.Uint()) + case reflect.Float32, reflect.Float64: + toVal.Field(i).SetFloat(value.Float()) + default: + toVal.Field(i).Set(value) + } + } + } + + if fromJSONField, ok := fields[ptag.name]; ok { + if toJSONField := toJSON.FieldByName(field.Name); toJSONField.IsValid() { + toJSONField.Set(fromJSONField) + } + } + } + + // Finally, copy over the .JSON.raw and .JSON.ExtraFields + if toJSON.IsValid() { + if raw := toJSON.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, fromJSON.Interface().(interface{ RawJSON() string }).RawJSON()) + } + + if toExtraFields := toJSON.FieldByName("ExtraFields"); toExtraFields.IsValid() { + if fromExtraFields := fromJSON.FieldByName("ExtraFields"); fromExtraFields.IsValid() { + setUnexportedField(toExtraFields, 
fromExtraFields.Interface()) + } + } + } + + return nil +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/registry.go b/vendor/github.com/openai/openai-go/internal/apijson/registry.go new file mode 100644 index 0000000000..2a24982700 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/registry.go @@ -0,0 +1,51 @@ +package apijson + +import ( + "reflect" + + "github.com/tidwall/gjson" +) + +type UnionVariant struct { + TypeFilter gjson.Type + DiscriminatorValue any + Type reflect.Type +} + +var unionRegistry = map[reflect.Type]unionEntry{} +var unionVariants = map[reflect.Type]any{} + +type unionEntry struct { + discriminatorKey string + variants []UnionVariant +} + +func Discriminator[T any](value any) UnionVariant { + var zero T + return UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: value, + Type: reflect.TypeOf(zero), + } +} + +func RegisterUnion[T any](discriminator string, variants ...UnionVariant) { + typ := reflect.TypeOf((*T)(nil)).Elem() + unionRegistry[typ] = unionEntry{ + discriminatorKey: discriminator, + variants: variants, + } + for _, variant := range variants { + unionVariants[variant.Type] = typ + } +} + +// Useful to wrap a union type to force it to use [apijson.UnmarshalJSON] since you cannot define an +// UnmarshalJSON function on the interface itself. 
+type UnionUnmarshaler[T any] struct { + Value T +} + +func (c *UnionUnmarshaler[T]) UnmarshalJSON(buf []byte) error { + return UnmarshalRoot(buf, &c.Value) +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/subfield.go b/vendor/github.com/openai/openai-go/internal/apijson/subfield.go new file mode 100644 index 0000000000..782d3a7800 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/subfield.go @@ -0,0 +1,67 @@ +package apijson + +import ( + "github.com/openai/openai-go/packages/respjson" + "reflect" +) + +func getSubField(root reflect.Value, index []int, name string) reflect.Value { + strct := root.FieldByIndex(index[:len(index)-1]) + if !strct.IsValid() { + panic("couldn't find encapsulating struct for field " + name) + } + meta := strct.FieldByName("JSON") + if !meta.IsValid() { + return reflect.Value{} + } + field := meta.FieldByName(name) + if !field.IsValid() { + return reflect.Value{} + } + return field +} + +func setMetadataSubField(root reflect.Value, index []int, name string, meta Field) { + target := getSubField(root, index, name) + if !target.IsValid() { + return + } + + if target.Type() == reflect.TypeOf(meta) { + target.Set(reflect.ValueOf(meta)) + } else if respMeta := meta.toRespField(); target.Type() == reflect.TypeOf(respMeta) { + target.Set(reflect.ValueOf(respMeta)) + } +} + +func setMetadataExtraFields(root reflect.Value, index []int, name string, metaExtras map[string]Field) { + target := getSubField(root, index, name) + if !target.IsValid() { + return + } + + if target.Type() == reflect.TypeOf(metaExtras) { + target.Set(reflect.ValueOf(metaExtras)) + return + } + + newMap := make(map[string]respjson.Field, len(metaExtras)) + if target.Type() == reflect.TypeOf(newMap) { + for k, v := range metaExtras { + newMap[k] = v.toRespField() + } + target.Set(reflect.ValueOf(newMap)) + } +} + +func (f Field) toRespField() respjson.Field { + if f.IsMissing() { + return respjson.Field{} + } else if f.IsNull() { + 
return respjson.NewField("null") + } else if f.IsInvalid() { + return respjson.NewInvalidField(f.raw) + } else { + return respjson.NewField(f.raw) + } +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/tag.go b/vendor/github.com/openai/openai-go/internal/apijson/tag.go new file mode 100644 index 0000000000..812fb3caf4 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/tag.go @@ -0,0 +1,47 @@ +package apijson + +import ( + "reflect" + "strings" +) + +const jsonStructTag = "json" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + inline bool +} + +func parseJSONStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(jsonStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "inline": + tag.inline = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/vendor/github.com/openai/openai-go/internal/apijson/union.go b/vendor/github.com/openai/openai-go/internal/apijson/union.go new file mode 100644 index 0000000000..2f23de43a7 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apijson/union.go @@ -0,0 +1,202 @@ +package apijson + +import ( + "errors" + "github.com/openai/openai-go/packages/param" + "reflect" + + "github.com/tidwall/gjson" +) + +var apiUnionType = reflect.TypeOf(param.APIUnion{}) + +func isStructUnion(t reflect.Type) bool { + if t.Kind() != reflect.Struct { + return false + } + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == apiUnionType && t.Field(i).Anonymous { + return true + } + } + return 
false +} + +func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) { + var t T + entry := unionEntry{ + discriminatorKey: key, + variants: []UnionVariant{}, + } + for k, typ := range mappings { + entry.variants = append(entry.variants, UnionVariant{ + DiscriminatorValue: k, + Type: typ, + }) + } + unionRegistry[reflect.TypeOf(t)] = entry +} + +func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc { + type variantDecoder struct { + decoder decoderFunc + field reflect.StructField + discriminatorValue any + } + + variants := []variantDecoder{} + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.Anonymous && field.Type == apiUnionType { + continue + } + + decoder := d.typeDecoder(field.Type) + variants = append(variants, variantDecoder{ + decoder: decoder, + field: field, + }) + } + + unionEntry, discriminated := unionRegistry[t] + for _, unionVariant := range unionEntry.variants { + for i := 0; i < len(variants); i++ { + variant := &variants[i] + if variant.field.Type.Elem() == unionVariant.Type { + variant.discriminatorValue = unionVariant.DiscriminatorValue + break + } + } + } + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 { + discriminator := n.Get(unionEntry.discriminatorKey).Value() + for _, variant := range variants { + if discriminator == variant.discriminatorValue { + inner := v.FieldByIndex(variant.field.Index) + return variant.decoder(n, inner, state) + } + } + return errors.New("apijson: was not able to find discriminated union variant") + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + bestVariant := -1 + for i, variant := range variants { + // Pointers are used to discern JSON object variants from value variants + if n.Type != gjson.JSON && variant.field.Type.Kind() == reflect.Ptr { + continue + } + + sub := decoderState{strict: state.strict, 
exactness: exact} + inner := v.FieldByIndex(variant.field.Index) + err := variant.decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + bestExactness = exact + bestVariant = i + break + } + if sub.exactness > bestExactness { + bestExactness = sub.exactness + bestVariant = i + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + for i := 0; i < len(variants); i++ { + if i == bestVariant { + continue + } + v.FieldByIndex(variants[i].field.Index).SetZero() + } + + return nil + } +} + +// newUnionDecoder returns a decoderFunc that deserializes into a union using an +// algorithm roughly similar to Pydantic's [smart algorithm]. +// +// Conceptually this is equivalent to choosing the best schema based on how 'exact' +// the deserialization is for each of the schemas. +// +// If there is a tie in the level of exactness, then the tie is broken +// left-to-right. 
+// +// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode +func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc { + unionEntry, ok := unionRegistry[t] + if !ok { + panic("apijson: couldn't find union of type " + t.String() + " in union registry") + } + decoders := []decoderFunc{} + for _, variant := range unionEntry.variants { + decoder := d.typeDecoder(variant.Type) + decoders = append(decoders, decoder) + } + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + // If there is a discriminator match, circumvent the exactness logic entirely + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + + if len(unionEntry.discriminatorKey) != 0 { + discriminatorValue := n.Get(unionEntry.discriminatorKey).Value() + if discriminatorValue == variant.DiscriminatorValue { + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, state) + v.Set(inner) + return err + } + } + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + sub := decoderState{strict: state.strict, exactness: exact} + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + v.Set(inner) + return nil + } + if sub.exactness > bestExactness { + v.Set(inner) + bestExactness = sub.exactness + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + return nil + } +} diff --git a/vendor/github.com/openai/openai-go/internal/apiquery/encoder.go b/vendor/github.com/openai/openai-go/internal/apiquery/encoder.go new file mode 100644 index 
0000000000..94bc40c32e --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiquery/encoder.go @@ -0,0 +1,415 @@ +package apiquery + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/openai/openai-go/packages/param" +) + +var encoders sync.Map // map[reflect.Type]encoderFunc + +type encoder struct { + dateFormat string + root bool + settings QuerySettings +} + +type encoderFunc func(key string, value reflect.Value) ([]Pair, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool + settings QuerySettings +} + +type Pair struct { + key string + value string +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + settings: e.settings, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) { + wg.Wait() + return f(key, v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. 
+ f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) { + s, err := value.Interface().(json.Marshaler).MarshalJSON() + if err != nil { + return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err) + } + return []Pair{{key, string(s)}}, nil +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder(t) + } + + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + + e.root = false + switch t.Kind() { + case reflect.Pointer: + encoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + if !value.IsValid() || value.IsNil() { + return + } + return encoder(key, value.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { + return e.newStructUnionTypeEncoder(t) + } + } + + encoderFields := []encoderField{} + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. 
+ var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If query tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseQueryStructTag(field) + if !ok { + continue + } + + if (ptag.name == "-" || ptag.name == "") && !ptag.inline { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + var encoderFn encoderFunc + if ptag.omitzero { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value) ([]Pair, error) { + if value.IsZero() { + return nil, nil + } + return typeEncoderFn(key, value) + } + } else { + encoderFn = e.typeEncoder(field.Type) + } + encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for _, ef := range encoderFields { + var subkey string = e.renderKeyPath(key, ef.tag.name) + if ef.tag.inline { + subkey = key + } + + field := value.FieldByIndex(ef.idx) + subpairs, suberr := ef.fn(subkey, field) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) 
+ } + return + } +} + +var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() + +func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { + var fieldEncoders []encoderFunc + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type == paramUnionType && field.Anonymous { + fieldEncoders = append(fieldEncoders, nil) + continue + } + fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) + } + + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for i := 0; i < t.NumField(); i++ { + if value.Field(i).Type() == paramUnionType { + continue + } + if !value.Field(i).IsZero() { + return fieldEncoders[i](key, value.Field(i)) + } + } + return nil, fmt.Errorf("apiquery: union %s has no field set", t.String()) + } +} + +func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { + keyEncoder := e.typeEncoder(t.Key()) + elementEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + iter := value.MapRange() + for iter.Next() { + encodedKey, err := keyEncoder("", iter.Key()) + if err != nil { + return nil, err + } + if len(encodedKey) != 1 { + return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive") + } + subkey := encodedKey[0].value + keyPath := e.renderKeyPath(key, subkey) + subpairs, suberr := elementEncoder(keyPath, iter.Value()) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) 
+ } + return + } +} + +func (e *encoder) renderKeyPath(key string, subkey string) string { + if len(key) == 0 { + return subkey + } + if e.settings.NestedFormat == NestedQueryFormatDots { + return fmt.Sprintf("%s.%s", key, subkey) + } + return fmt.Sprintf("%s[%s]", key, subkey) +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + switch e.settings.ArrayFormat { + case ArrayQueryFormatComma: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, v reflect.Value) ([]Pair, error) { + elements := []string{} + for i := 0; i < v.Len(); i++ { + innerPairs, err := innerEncoder("", v.Index(i)) + if err != nil { + return nil, err + } + for _, pair := range innerPairs { + elements = append(elements, pair.value) + } + } + if len(elements) == 0 { + return []Pair{}, nil + } + return []Pair{{key, strings.Join(elements, ",")}}, nil + } + case ArrayQueryFormatRepeat: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for i := 0; i < value.Len(); i++ { + subpairs, suberr := innerEncoder(key, value.Index(i)) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) + } + return + } + case ArrayQueryFormatIndices: + panic("The array indices format is not supported yet") + case ArrayQueryFormatBrackets: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + pairs = []Pair{} + for i := 0; i < value.Len(); i++ { + subpairs, suberr := innerEncoder(key+"[]", value.Index(i)) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) 
+ } + return + } + default: + panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat)) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.newPrimitiveTypeEncoder(inner) + return func(key string, v reflect.Value) ([]Pair, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(key, v.Elem()) + } + case reflect.String: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, v.String()}}, nil + } + case reflect.Bool: + return func(key string, v reflect.Value) ([]Pair, error) { + if v.Bool() { + return []Pair{{key, "true"}}, nil + } + return []Pair{{key, "false"}}, nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil + } + case reflect.Float32, reflect.Float64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil + } + case reflect.Complex64, reflect.Complex128: + bitSize := 64 + if t.Kind() == reflect.Complex128 { + bitSize = 128 + } + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil + } + default: + return func(key string, v reflect.Value) ([]Pair, error) { + return nil, nil + } + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(key string, value reflect.Value) ([]Pair, error) { + present := value.FieldByName("Present") + if !present.Bool() { + return nil, nil + } + null 
:= value.FieldByName("Null") + if null.Bool() { + return nil, fmt.Errorf("apiquery: field cannot be null") + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(key, raw) + } + return enc(key, value.FieldByName("Value")) + } +} + +func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value) ([]Pair, error) { + return []Pair{{ + key, + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format), + }}, nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value) ([]Pair, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(key, value) + } + +} diff --git a/vendor/github.com/openai/openai-go/internal/apiquery/query.go b/vendor/github.com/openai/openai-go/internal/apiquery/query.go new file mode 100644 index 0000000000..0f379fa33a --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiquery/query.go @@ -0,0 +1,55 @@ +package apiquery + +import ( + "net/url" + "reflect" + "time" +) + +func MarshalWithSettings(value any, settings QuerySettings) (url.Values, error) { + e := encoder{time.RFC3339, true, settings} + kv := url.Values{} + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + + pairs, err := e.typeEncoder(typ)("", val) + if err != nil { + return nil, err + } + for _, pair := range pairs { + kv.Add(pair.key, pair.value) + } + return kv, nil +} + +func Marshal(value any) (url.Values, error) { + return MarshalWithSettings(value, QuerySettings{}) +} + +type Queryer interface { + URLQuery() (url.Values, error) +} + +type QuerySettings struct { + NestedFormat NestedQueryFormat + ArrayFormat ArrayQueryFormat +} + +type NestedQueryFormat int + +const ( + NestedQueryFormatBrackets NestedQueryFormat = iota + NestedQueryFormatDots +) + +type ArrayQueryFormat int + 
+const ( + ArrayQueryFormatComma ArrayQueryFormat = iota + ArrayQueryFormatRepeat + ArrayQueryFormatIndices + ArrayQueryFormatBrackets +) diff --git a/vendor/github.com/openai/openai-go/internal/apiquery/richparam.go b/vendor/github.com/openai/openai-go/internal/apiquery/richparam.go new file mode 100644 index 0000000000..b1636e9ae1 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiquery/richparam.go @@ -0,0 +1,20 @@ +package apiquery + +import ( + "reflect" + + "github.com/openai/openai-go/packages/param" +) + +func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + return func(key string, value reflect.Value) ([]Pair, error) { + if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() { + return enc(key, value.FieldByIndex(f.Index)) + } else if ok && param.IsNull(opt) { + return []Pair{{key, "null"}}, nil + } + return nil, nil + } +} diff --git a/vendor/github.com/openai/openai-go/internal/apiquery/tag.go b/vendor/github.com/openai/openai-go/internal/apiquery/tag.go new file mode 100644 index 0000000000..772c40e1a9 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/apiquery/tag.go @@ -0,0 +1,44 @@ +package apiquery + +import ( + "reflect" + "strings" +) + +const queryStructTag = "query" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + omitempty bool + omitzero bool + inline bool +} + +func parseQueryStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(queryStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "omitzero": + tag.omitzero = true + case "omitempty": + tag.omitempty = true + case "inline": + tag.inline = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { 
+ format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/vendor/github.com/openai/openai-go/internal/encoding/json/decode.go b/vendor/github.com/openai/openai-go/internal/encoding/json/decode.go new file mode 100644 index 0000000000..93214331b5 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/encoding/json/decode.go @@ -0,0 +1,1324 @@ +// Vendored from Go 1.24.0-pre-release +// To find alterations, check package shims, and comments beginning in SHIM(). +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "encoding" + "encoding/base64" + "fmt" + "github.com/openai/openai-go/internal/encoding/json/shims" + "reflect" + "strconv" + "strings" + "unicode" + "unicode/utf16" + "unicode/utf8" + _ "unsafe" // for linkname +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an [InvalidUnmarshalError]. +// +// Unmarshal uses the inverse of the encodings that +// [Marshal] uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a value implementing [Unmarshaler], +// Unmarshal calls that value's [Unmarshaler.UnmarshalJSON] method, including +// when the input is a JSON null. 
+// Otherwise, if the value implements [encoding.TextUnmarshaler] +// and the input is a JSON quoted string, Unmarshal calls +// [encoding.TextUnmarshaler.UnmarshalText] with the unquoted form of the string. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by [Marshal] (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. By +// default, object keys which don't have a corresponding struct field are +// ignored (see [Decoder.DisallowUnknownFields] for an alternative). +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// - bool, for JSON booleans +// - float64, for JSON numbers +// - string, for JSON strings +// - []any, for JSON arrays +// - map[string]any, for JSON objects +// - nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a map, Unmarshal first establishes a map to +// use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the JSON object into the map. The map's key type must +// either be any string type, an integer, or implement [encoding.TextUnmarshaler]. +// +// If the JSON-encoded data contain a syntax error, Unmarshal returns a [SyntaxError]. 
+// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an [UnmarshalTypeError] describing the earliest such error. In any +// case, it's not guaranteed that all the remaining fields following +// the problematic one will be unmarshaled into the target object. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// “not present,” unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +func Unmarshal(data []byte, v any) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +// +// By convention, to approximate the behavior of [Unmarshal] itself, +// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. 
+type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes + Struct string // name of the struct type containing the field + Field string // the full path from root node to the field, include embedded struct +} + +func (e *UnmarshalTypeError) Error() string { + if e.Struct != "" || e.Field != "" { + return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + e.Field + " of type " + e.Type.String() + } + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// +// Deprecated: No longer used; kept for compatibility. +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to [Unmarshal]. +// (The argument to [Unmarshal] must be a non-nil pointer.) 
+type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + d.scanWhile(scanSkipSpace) + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err := d.value(rv) + if err != nil { + return d.addErrorContext(err) + } + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// An errorContext provides context for type errors during decoding. +type errorContext struct { + Struct reflect.Type + FieldStack []string +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // next read offset in data + opcode int // last read result + scan scanner + errorContext *errorContext + savedError error + useNumber bool + disallowUnknownFields bool +} + +// readIndex returns the position of the last byte read. +func (d *decodeState) readIndex() int { + return d.off - 1 +} + +// phasePanicMsg is used as a panic message when we end up with something that +// shouldn't happen. 
It can indicate a bug in the JSON decoder, or that +// something is editing the data slice while the decoder executes. +const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?" + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + if d.errorContext != nil { + d.errorContext.Struct = nil + // Reuse the allocated space for the FieldStack slice. + d.errorContext.FieldStack = d.errorContext.FieldStack[:0] + } + return d +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = d.addErrorContext(err) + } +} + +// addErrorContext returns a new error enhanced with information from d.errorContext +func (d *decodeState) addErrorContext(err error) error { + if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) { + switch err := err.(type) { + case *UnmarshalTypeError: + err.Struct = d.errorContext.Struct.Name() + fieldStack := d.errorContext.FieldStack + if err.Field != "" { + fieldStack = append(fieldStack, err.Field) + } + err.Field = strings.Join(fieldStack, ".") + } + } + return err +} + +// skip scans to the end of what was started. +func (d *decodeState) skip() { + s, data, i := &d.scan, d.data, d.off + depth := len(s.parseState) + for { + op := s.step(s, data[i]) + i++ + if len(s.parseState) < depth { + d.off = i + d.opcode = op + return + } + } +} + +// scanNext processes the byte at d.data[d.off]. +func (d *decodeState) scanNext() { + if d.off < len(d.data) { + d.opcode = d.scan.step(&d.scan, d.data[d.off]) + d.off++ + } else { + d.opcode = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. 
+func (d *decodeState) scanWhile(op int) { + s, data, i := &d.scan, d.data, d.off + for i < len(data) { + newOp := s.step(s, data[i]) + i++ + if newOp != op { + d.opcode = newOp + d.off = i + return + } + } + + d.off = len(data) + 1 // mark processed EOF with len+1 + d.opcode = d.scan.eof() +} + +// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the +// common case where we're decoding a literal. The decoder scans the input +// twice, once for syntax errors and to check the length of the value, and the +// second to perform the decoding. +// +// Only in the second step do we use decodeState to tokenize literals, so we +// know there aren't any syntax errors. We can take advantage of that knowledge, +// and scan a literal's bytes much more quickly. +func (d *decodeState) rescanLiteral() { + data, i := d.data, d.off +Switch: + switch data[i-1] { + case '"': // string + for ; i < len(data); i++ { + switch data[i] { + case '\\': + i++ // escaped char + case '"': + i++ // tokenize the closing quote too + break Switch + } + } + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number + for ; i < len(data); i++ { + switch data[i] { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + '.', 'e', 'E', '+', '-': + default: + break Switch + } + } + case 't': // true + i += len("rue") + case 'f': // false + i += len("alse") + case 'n': // null + i += len("ull") + } + if i < len(data) { + d.opcode = stateEndValue(&d.scan, data[i]) + } else { + d.opcode = scanEnd + } + d.off = i + 1 +} + +// value consumes a JSON value from d.data[d.off-1:], decoding into v, and +// reads the following byte ahead. If v is invalid, the value is discarded. +// The first byte of the value has been read already. 
+func (d *decodeState) value(v reflect.Value) error { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray: + if v.IsValid() { + if err := d.array(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginObject: + if v.IsValid() { + if err := d.object(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginLiteral: + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + if v.IsValid() { + if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil { + return err + } + } + } + return nil +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() any { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray, scanBeginObject: + d.skip() + d.scanNext() + + case scanBeginLiteral: + v := d.literalInterface() + switch v.(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// If it encounters an Unmarshaler, indirect stops and returns that. +// If decodingNull is true, indirect stops at the first settable pointer so it +// can be set to nil. +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. 
+ // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Pointer { + break + } + + if decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v any + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem().Equal(v) { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if !decodingNull { + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into v. +// The first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) error { + // Check for unmarshaler. 
+ u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + ai := d.arrayInterface() + v.Set(reflect.ValueOf(ai)) + return nil + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + case reflect.Array, reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + // Expand slice length, growing the slice if necessary. + if v.Kind() == reflect.Slice { + if i >= v.Cap() { + v.Grow(1) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + if err := d.value(v.Index(i)); err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + if err := d.value(reflect.Value{}); err != nil { + return err + } + } + i++ + + // Next token must be , or ]. 
+ if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + for ; i < v.Len(); i++ { + v.Index(i).SetZero() // zero remainder of array + } + } else { + v.SetLen(i) // truncate the slice + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } + return nil +} + +var nullLiteral = []byte("null") + +// SHIM(reflect): reflect.TypeFor[T]() reflect.T +var textUnmarshalerType = shims.TypeFor[encoding.TextUnmarshaler]() + +// object consumes an object from d.data[d.off-1:], decoding into v. +// The first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + t := v.Type() + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + oi := d.objectInterface() + v.Set(reflect.ValueOf(oi)) + return nil + } + + var fields structFields + + // Check type of target: + // struct or + // map[T1]T2 where T1 is string, an integer type, + // or an encoding.TextUnmarshaler + switch v.Kind() { + case reflect.Map: + // Map key must either have string kind, have an integer kind, + // or be an encoding.TextUnmarshaler. 
+ switch t.Key().Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) { + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + fields = cachedTypeFields(t) + // ok + default: + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + + var mapElem reflect.Value + var origErrorContext errorContext + if d.errorContext != nil { + origErrorContext = *d.errorContext + } + + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquoteBytes(item) + if !ok { + panic(phasePanicMsg) + } + + // Figure out field corresponding to key. 
+ var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := t.Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.SetZero() + } + subv = mapElem + } else { + f := fields.byExactName[string(key)] + if f == nil { + f = fields.byFoldedName[string(foldName(key))] + } + if f != nil { + subv = v + destring = f.quoted + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + for i, ind := range f.index { + if subv.Kind() == reflect.Pointer { + if subv.IsNil() { + // If a struct embeds a pointer to an unexported type, + // it is not possible to set a newly allocated value + // since the field is unexported. + // + // See https://golang.org/issue/21357 + if !subv.CanSet() { + d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem())) + // Invalidate subv to ensure d.value(subv) skips over + // the JSON value without assigning it to subv. + subv = reflect.Value{} + destring = false + break + } + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + if i < len(f.index)-1 { + d.errorContext.FieldStack = append( + d.errorContext.FieldStack, + subv.Type().Field(ind).Name, + ) + } + subv = subv.Field(ind) + } + d.errorContext.Struct = t + d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name) + } else if d.disallowUnknownFields { + d.saveError(fmt.Errorf("json: unknown field %q", key)) + } + } + + // Read : before value. 
+ if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + if err := d.literalStore(nullLiteral, subv, false); err != nil { + return err + } + case string: + if err := d.literalStore([]byte(qv), subv, true); err != nil { + return err + } + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + if err := d.value(subv); err != nil { + return err + } + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kt := t.Key() + var kv reflect.Value + if reflect.PointerTo(kt).Implements(textUnmarshalerType) { + kv = reflect.New(kt) + if err := d.literalStore(item, kv, true); err != nil { + return err + } + kv = kv.Elem() + } else { + switch kt.Kind() { + case reflect.String: + kv = reflect.New(kt).Elem() + kv.SetString(string(key)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := string(key) + n, err := strconv.ParseInt(s, 10, 64) + // SHIM(reflect): reflect.Type.OverflowInt(int64) bool + okt := shims.OverflowableType{Type: kt} + if err != nil || okt.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.New(kt).Elem() + kv.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + s := string(key) + n, err := strconv.ParseUint(s, 10, 64) + // SHIM(reflect): reflect.Type.OverflowUint(uint64) bool + okt := shims.OverflowableType{Type: kt} + if err != nil || okt.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.New(kt).Elem() + kv.SetUint(n) + default: + panic("json: Unexpected key type") // 
should never occur + } + } + if kv.IsValid() { + v.SetMapIndex(kv, subv) + } + } + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.errorContext != nil { + // Reset errorContext to its original state. + // Keep the same underlying array for FieldStack, to reuse the + // space and avoid unnecessary allocs. + d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)] + d.errorContext.Struct = origErrorContext.Struct + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + return nil +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (any, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + // SHIM(reflect): reflect.TypeFor[T]() reflect.Type + return nil, &UnmarshalTypeError{Value: "number " + s, Type: shims.TypeFor[float64](), Offset: int64(d.off)} + } + return f, nil +} + +// SHIM(reflect): TypeFor[T]() reflect.Type +var numberType = shims.TypeFor[Number]() + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error { + // Check for unmarshaler. + if len(item) == 0 { + // Empty string given. 
+ d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + isNull := item[0] == 'n' // null + u, ut, pv := indirect(v, isNull) + if u != nil { + return u.UnmarshalJSON(item) + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + val := "number" + switch item[0] { + case 'n': + val = "null" + case 't', 'f': + val = "bool" + } + d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())}) + return nil + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + return ut.UnmarshalText(s) + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. + if fromQuoted && string(item) != "null" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice: + v.SetZero() + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := item[0] == 't' + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. 
+ if fromQuoted && string(item) != "true" && string(item) != "false" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + t := string(s) + if v.Type() == numberType && !isValidNumber(t) { + return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item) + } + v.SetString(t) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q 
into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + // s must be a valid number, because it's + // already been tokenized. + v.SetString(string(item)) + break + } + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Interface: + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(string(item), 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(string(item), 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(string(item), v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetFloat(n) + } + } + return nil +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns any. 
+func (d *decodeState) valueInterface() (val any) { + switch d.opcode { + default: + panic(phasePanicMsg) + case scanBeginArray: + val = d.arrayInterface() + d.scanNext() + case scanBeginObject: + val = d.objectInterface() + d.scanNext() + case scanBeginLiteral: + val = d.literalInterface() + } + return +} + +// arrayInterface is like array but returns []any. +func (d *decodeState) arrayInterface() []any { + var v = make([]any, 0) + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + return v +} + +// objectInterface is like object but returns map[string]any. +func (d *decodeState) objectInterface() map[string]any { + m := make(map[string]any) + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read string key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + return m +} + +// literalInterface consumes and returns a literal from d.data[d.off-1:] and +// it reads the following byte ahead. 
The first byte of the literal has been +// read already (that's how the caller knows it's a literal). +func (d *decodeState) literalInterface() any { + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + item := d.data[start:d.readIndex()] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + panic(phasePanicMsg) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + var r rune + for _, c := range s[2:6] { + switch { + case '0' <= c && c <= '9': + c = c - '0' + case 'a' <= c && c <= 'f': + c = c - 'a' + 10 + case 'A' <= c && c <= 'F': + c = c - 'A' + 10 + default: + return -1 + } + r = r*16 + rune(c) + } + return r +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +// unquoteBytes should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/bytedance/sonic +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname unquoteBytes +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. 
+ r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. 
+ default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/vendor/github.com/openai/openai-go/internal/encoding/json/encode.go b/vendor/github.com/openai/openai-go/internal/encoding/json/encode.go new file mode 100644 index 0000000000..d547132584 --- /dev/null +++ b/vendor/github.com/openai/openai-go/internal/encoding/json/encode.go @@ -0,0 +1,1391 @@ +// Vendored from Go 1.24.0-pre-release +// To find alterations, check package shims, and comments beginning in SHIM(). +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON as defined in +// RFC 7159. The mapping between JSON and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "cmp" + "encoding" + "encoding/base64" + "fmt" + "github.com/openai/openai-go/internal/encoding/json/sentinel" + "github.com/openai/openai-go/internal/encoding/json/shims" + "math" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" + _ "unsafe" // for linkname +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements [Marshaler] +// and is not a nil pointer, Marshal calls [Marshaler.MarshalJSON] +// to produce JSON. If no [Marshaler.MarshalJSON] method is present but the +// value implements [encoding.TextMarshaler] instead, Marshal calls +// [encoding.TextMarshaler.MarshalText] and encodes the result as a JSON string. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// [Unmarshaler.UnmarshalJSON]. 
+// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and [Number] values encode as JSON numbers. +// NaN and +/-Inf values will return an [UnsupportedValueError]. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// So that the JSON will be safe to embed inside HTML