Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func runSearch(cmd *cobra.Command, args []string) error {
after = searchAroundFlag
}

if before < 0 || after < 0 {
return fmt.Errorf("--before, --after, and --around must be non-negative")
}

client := daemon.NewClient()
if err := client.EnsureDaemon(); err != nil {
return fmt.Errorf("daemon: %w", err)
Expand Down
86 changes: 59 additions & 27 deletions internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ import (
"github.com/schovi/shelli/internal/ansi"
)

type ptyHandle struct {
f *os.File
closeOnce sync.Once
}

func (p *ptyHandle) Close() {
p.closeOnce.Do(func() {
p.f.Close()
})
}

func (p *ptyHandle) File() *os.File {
return p.f
}

type Session struct {
Name string `json:"name"`
PID int `json:"pid"`
Expand All @@ -38,7 +53,7 @@ type SessionInfo struct {
type Server struct {
mu sync.Mutex
sessions map[string]*Session
ptys map[string]*os.File
ptys map[string]*ptyHandle
cmds map[string]*exec.Cmd
doneChans map[string]chan struct{}
frameDetectors map[string]*ansi.FrameDetector
Expand Down Expand Up @@ -87,7 +102,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {

s := &Server{
sessions: make(map[string]*Session),
ptys: make(map[string]*os.File),
ptys: make(map[string]*ptyHandle),
cmds: make(map[string]*exec.Cmd),
doneChans: make(map[string]chan struct{}),
frameDetectors: make(map[string]*ansi.FrameDetector),
Expand Down Expand Up @@ -162,7 +177,10 @@ func (s *Server) Start() error {
for {
conn, err := listener.Accept()
if err != nil {
if s.listener == nil {
s.mu.Lock()
isShutdown := s.listener == nil
s.mu.Unlock()
if isShutdown {
return nil
}
return err
Expand Down Expand Up @@ -211,8 +229,8 @@ func (s *Server) Shutdown() {
if done, ok := s.doneChans[name]; ok {
close(done)
}
if ptmx, ok := s.ptys[name]; ok {
ptmx.Close()
if handle, ok := s.ptys[name]; ok {
handle.Close()
}
if cmd, ok := s.cmds[name]; ok {
cmd.Process.Kill()
Expand Down Expand Up @@ -380,16 +398,17 @@ func (s *Server) handleCreate(req Request) Response {
CreatedAt: now,
}

handle := &ptyHandle{f: ptmx}
s.sessions[req.Name] = sess
s.ptys[req.Name] = ptmx
s.ptys[req.Name] = handle
s.cmds[req.Name] = cmd
s.doneChans[req.Name] = make(chan struct{})
if req.TUIMode {
s.frameDetectors[req.Name] = ansi.NewFrameDetector(ansi.DefaultTUIStrategy())
s.responders[req.Name] = ansi.NewTerminalResponder(ptmx)
}

go s.captureOutput(req.Name, ptmx, cmd)
go s.captureOutput(req.Name, handle, cmd)

return Response{Success: true, Data: map[string]interface{}{
"name": sess.Name,
Expand All @@ -401,14 +420,16 @@ func (s *Server) handleCreate(req Request) Response {
}}
}

func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) {
func (s *Server) captureOutput(name string, handle *ptyHandle, cmd *exec.Cmd) {
s.mu.Lock()
done := s.doneChans[name]
detector := s.frameDetectors[name]
responder := s.responders[name]
storage := s.storage
s.mu.Unlock()

f := handle.File()

if detector != nil {
defer func() {
if pending := detector.Flush(); len(pending) > 0 {
Expand All @@ -425,8 +446,8 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) {
default:
}

ptmx.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
n, err := ptmx.Read(buf)
f.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
n, err := f.Read(buf)
if n > 0 {
data := buf[:n]
if responder != nil {
Expand All @@ -450,7 +471,7 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) {
}

cmd.Wait()
ptmx.Close()
handle.Close()

s.mu.Lock()
defer s.mu.Unlock()
Expand Down Expand Up @@ -626,11 +647,12 @@ func (s *Server) handleSnapshot(req Request) Response {
s.mu.Unlock()
return Response{Success: false, Error: fmt.Sprintf("session %q is not in TUI mode (snapshot requires --tui)", req.Name)}
}
ptmx, ok := s.ptys[req.Name]
handle, ok := s.ptys[req.Name]
if !ok {
s.mu.Unlock()
return Response{Success: false, Error: fmt.Sprintf("session %q PTY not available", req.Name)}
}
ptmx := handle.File()
cmd := s.cmds[req.Name]
storage := s.storage
s.mu.Unlock()
Expand Down Expand Up @@ -793,7 +815,7 @@ func (s *Server) handleSend(req Request) Response {
s.mu.Unlock()
return Response{Success: false, Error: fmt.Sprintf("session %q is stopped", req.Name)}
}
ptmx, ok := s.ptys[req.Name]
handle, ok := s.ptys[req.Name]
s.mu.Unlock()

if !ok {
Expand All @@ -805,7 +827,7 @@ func (s *Server) handleSend(req Request) Response {
data += "\n"
}

if _, err := ptmx.WriteString(data); err != nil {
if _, err := handle.File().WriteString(data); err != nil {
return Response{Success: false, Error: err.Error()}
}

Expand All @@ -830,8 +852,8 @@ func (s *Server) handleStop(req Request) Response {
delete(s.doneChans, req.Name)
}

if ptmx, ok := s.ptys[req.Name]; ok {
ptmx.Close()
if handle, ok := s.ptys[req.Name]; ok {
handle.Close()
delete(s.ptys, req.Name)
}

Expand Down Expand Up @@ -863,29 +885,25 @@ func (s *Server) handleStop(req Request) Response {

func (s *Server) handleKill(req Request) Response {
s.mu.Lock()
defer s.mu.Unlock()

sess, exists := s.sessions[req.Name]
if !exists {
s.mu.Unlock()
return Response{Success: false, Error: fmt.Sprintf("session %q not found", req.Name)}
}

var proc *os.Process
if sess.State == StateRunning {
if done, ok := s.doneChans[req.Name]; ok {
close(done)
delete(s.doneChans, req.Name)
}

if ptmx, ok := s.ptys[req.Name]; ok {
ptmx.Close()
if handle, ok := s.ptys[req.Name]; ok {
handle.Close()
delete(s.ptys, req.Name)
}

if cmd, ok := s.cmds[req.Name]; ok {
cmd.Process.Signal(syscall.SIGTERM)
time.Sleep(KillGracePeriod)
cmd.Process.Signal(syscall.SIGKILL)
cmd.Wait()
proc = cmd.Process
delete(s.cmds, req.Name)
}
}
Expand All @@ -894,6 +912,16 @@ func (s *Server) handleKill(req Request) Response {
delete(s.sessions, req.Name)
delete(s.frameDetectors, req.Name)
delete(s.responders, req.Name)
s.mu.Unlock()

if proc != nil {
go func() {
proc.Signal(syscall.SIGTERM)
time.Sleep(KillGracePeriod)
proc.Signal(syscall.SIGKILL)
proc.Wait()
}()
}

return Response{Success: true}
}
Expand All @@ -916,6 +944,10 @@ func (s *Server) handleSize(req Request) Response {
}

func (s *Server) handleSearch(req Request) Response {
if req.Before < 0 || req.After < 0 {
return Response{Success: false, Error: "before and after must be non-negative"}
}

s.mu.Lock()
_, exists := s.sessions[req.Name]
if !exists {
Expand Down Expand Up @@ -1056,7 +1088,7 @@ func (s *Server) handleResize(req Request) Response {
return Response{Success: false, Error: fmt.Sprintf("session %q is stopped", req.Name)}
}

ptmx, ok := s.ptys[req.Name]
handle, ok := s.ptys[req.Name]
if !ok {
s.mu.Unlock()
return Response{Success: false, Error: fmt.Sprintf("session %q not running", req.Name)}
Expand All @@ -1078,7 +1110,7 @@ func (s *Server) handleResize(req Request) Response {
rows = meta.Rows
}

if err := pty.Setsize(ptmx, &pty.Winsize{Cols: clampUint16(cols), Rows: clampUint16(rows)}); err != nil {
if err := pty.Setsize(handle.File(), &pty.Winsize{Cols: clampUint16(cols), Rows: clampUint16(rows)}); err != nil {
return Response{Success: false, Error: fmt.Sprintf("resize: %v", err)}
}

Expand Down
4 changes: 4 additions & 0 deletions internal/mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,10 @@ func (r *ToolRegistry) callSearch(args json.RawMessage) (*CallToolResult, error)
after = a.Around
}

if before < 0 || after < 0 {
return nil, fmt.Errorf("before, after, and around must be non-negative")
}

resp, err := r.client.Search(daemon.SearchRequest{
Name: a.Name,
Pattern: a.Pattern,
Expand Down