diff --git a/cmd/search.go b/cmd/search.go index 3554baa..d092501 100644 --- a/cmd/search.go +++ b/cmd/search.go @@ -52,6 +52,10 @@ func runSearch(cmd *cobra.Command, args []string) error { after = searchAroundFlag } + if before < 0 || after < 0 { + return fmt.Errorf("--before, --after, and --around must be non-negative") + } + client := daemon.NewClient() if err := client.EnsureDaemon(); err != nil { return fmt.Errorf("daemon: %w", err) diff --git a/internal/daemon/server.go b/internal/daemon/server.go index c6b1ef9..8a8225c 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -17,6 +17,21 @@ import ( "github.com/schovi/shelli/internal/ansi" ) +type ptyHandle struct { + f *os.File + closeOnce sync.Once +} + +func (p *ptyHandle) Close() { + p.closeOnce.Do(func() { + p.f.Close() + }) +} + +func (p *ptyHandle) File() *os.File { + return p.f +} + type Session struct { Name string `json:"name"` PID int `json:"pid"` @@ -38,7 +53,7 @@ type SessionInfo struct { type Server struct { mu sync.Mutex sessions map[string]*Session - ptys map[string]*os.File + ptys map[string]*ptyHandle cmds map[string]*exec.Cmd doneChans map[string]chan struct{} frameDetectors map[string]*ansi.FrameDetector @@ -87,7 +102,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { s := &Server{ sessions: make(map[string]*Session), - ptys: make(map[string]*os.File), + ptys: make(map[string]*ptyHandle), cmds: make(map[string]*exec.Cmd), doneChans: make(map[string]chan struct{}), frameDetectors: make(map[string]*ansi.FrameDetector), @@ -162,7 +177,10 @@ func (s *Server) Start() error { for { conn, err := listener.Accept() if err != nil { - if s.listener == nil { + s.mu.Lock() + isShutdown := s.listener == nil + s.mu.Unlock() + if isShutdown { return nil } return err @@ -211,8 +229,8 @@ func (s *Server) Shutdown() { if done, ok := s.doneChans[name]; ok { close(done) } - if ptmx, ok := s.ptys[name]; ok { - ptmx.Close() + if handle, ok := s.ptys[name]; ok { + handle.Close() } if cmd, ok := s.cmds[name]; ok { cmd.Process.Kill() @@ -380,8 +398,9 @@ func (s *Server) handleCreate(req Request) Response { CreatedAt: now, } + handle := &ptyHandle{f: ptmx} s.sessions[req.Name] = sess - s.ptys[req.Name] = ptmx + s.ptys[req.Name] = handle s.cmds[req.Name] = cmd s.doneChans[req.Name] = make(chan struct{}) if req.TUIMode { @@ -389,7 +408,7 @@ func (s *Server) handleCreate(req Request) Response { s.responders[req.Name] = ansi.NewTerminalResponder(ptmx) } - go s.captureOutput(req.Name, ptmx, cmd) + go s.captureOutput(req.Name, handle, cmd) return Response{Success: true, Data: map[string]interface{}{ "name": sess.Name, @@ -401,7 +420,7 @@ func (s *Server) handleCreate(req Request) Response { }} } -func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) { +func (s *Server) captureOutput(name string, handle *ptyHandle, cmd *exec.Cmd) { s.mu.Lock() done := s.doneChans[name] detector := s.frameDetectors[name] @@ -409,6 +428,8 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) { storage := s.storage s.mu.Unlock() + f := handle.File() + if detector != nil { defer func() { if pending := detector.Flush(); len(pending) > 0 { @@ -425,8 +446,8 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) { default: } - ptmx.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - n, err := ptmx.Read(buf) + f.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + n, err := f.Read(buf) if n > 0 { data := buf[:n] if responder != nil { @@ -450,7 +471,7 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) { } cmd.Wait() - ptmx.Close() + handle.Close() s.mu.Lock() defer s.mu.Unlock() @@ -626,11 +647,12 @@ func (s *Server) handleSnapshot(req Request) Response { s.mu.Unlock() return Response{Success: false, Error: fmt.Sprintf("session %q is not in TUI mode (snapshot requires --tui)", req.Name)} } - ptmx, ok := s.ptys[req.Name] + handle, ok := s.ptys[req.Name] if !ok { s.mu.Unlock() return Response{Success: false, Error: fmt.Sprintf("session %q PTY not available", req.Name)} } + ptmx := handle.File() cmd := s.cmds[req.Name] storage := s.storage s.mu.Unlock() @@ -793,7 +815,7 @@ func (s *Server) handleSend(req Request) Response { s.mu.Unlock() return Response{Success: false, Error: fmt.Sprintf("session %q is stopped", req.Name)} } - ptmx, ok := s.ptys[req.Name] + handle, ok := s.ptys[req.Name] s.mu.Unlock() if !ok { @@ -805,7 +827,7 @@ func (s *Server) handleSend(req Request) Response { data += "\n" } - if _, err := ptmx.WriteString(data); err != nil { + if _, err := handle.File().WriteString(data); err != nil { return Response{Success: false, Error: err.Error()} } @@ -830,8 +852,8 @@ func (s *Server) handleStop(req Request) Response { delete(s.doneChans, req.Name) } - if ptmx, ok := s.ptys[req.Name]; ok { - ptmx.Close() + if handle, ok := s.ptys[req.Name]; ok { + handle.Close() delete(s.ptys, req.Name) } @@ -863,29 +885,25 @@ func (s *Server) handleStop(req Request) Response { func (s *Server) handleKill(req Request) Response { s.mu.Lock() - defer s.mu.Unlock() sess, exists := s.sessions[req.Name] if !exists { + s.mu.Unlock() return Response{Success: false, Error: fmt.Sprintf("session %q not found", req.Name)} } + var proc *os.Process if sess.State == StateRunning { if done, ok := s.doneChans[req.Name]; ok { close(done) delete(s.doneChans, req.Name) } - - if ptmx, ok := s.ptys[req.Name]; ok { - ptmx.Close() + if handle, ok := s.ptys[req.Name]; ok { + handle.Close() delete(s.ptys, req.Name) } - if cmd, ok := s.cmds[req.Name]; ok { - cmd.Process.Signal(syscall.SIGTERM) - time.Sleep(KillGracePeriod) - cmd.Process.Signal(syscall.SIGKILL) - cmd.Wait() + proc = cmd.Process delete(s.cmds, req.Name) } } @@ -894,6 +912,16 @@ func (s *Server) handleKill(req Request) Response { delete(s.sessions, req.Name) delete(s.frameDetectors, req.Name) delete(s.responders, req.Name) + s.mu.Unlock() + + if proc != nil { + go func() { + proc.Signal(syscall.SIGTERM) + time.Sleep(KillGracePeriod) + proc.Signal(syscall.SIGKILL) + proc.Wait() + }() + } return Response{Success: true} } @@ -916,6 +944,10 @@ func (s *Server) handleSize(req Request) Response { } func (s *Server) handleSearch(req Request) Response { + if req.Before < 0 || req.After < 0 { + return Response{Success: false, Error: "before and after must be non-negative"} + } + s.mu.Lock() _, exists := s.sessions[req.Name] if !exists { @@ -1056,7 +1088,7 @@ func (s *Server) handleResize(req Request) Response { return Response{Success: false, Error: fmt.Sprintf("session %q is stopped", req.Name)} } - ptmx, ok := s.ptys[req.Name] + handle, ok := s.ptys[req.Name] if !ok { s.mu.Unlock() return Response{Success: false, Error: fmt.Sprintf("session %q not running", req.Name)} @@ -1078,7 +1110,7 @@ func (s *Server) handleResize(req Request) Response { rows = meta.Rows } - if err := pty.Setsize(ptmx, &pty.Winsize{Cols: clampUint16(cols), Rows: clampUint16(rows)}); err != nil { + if err := pty.Setsize(handle.File(), &pty.Winsize{Cols: clampUint16(cols), Rows: clampUint16(rows)}); err != nil { return Response{Success: false, Error: fmt.Sprintf("resize: %v", err)} } diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index 57d6920..99925d4 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -831,6 +831,10 @@ func (r *ToolRegistry) callSearch(args json.RawMessage) (*CallToolResult, error) after = a.Around } + if before < 0 || after < 0 { + return nil, fmt.Errorf("before, after, and around must be non-negative") + } + resp, err := r.client.Search(daemon.SearchRequest{ Name: a.Name, Pattern: a.Pattern,