// Package gerrit implements a watcher for Gerrit events. package gerrit import ( "context" "errors" "fmt" "net" "strings" "time" "code.tvl.fyi/fun/clbot/backoffutil" "code.tvl.fyi/fun/clbot/gerrit/gerritevents" log "github.com/golang/glog" "golang.org/x/crypto/ssh" ) // closer provides an embeddable implementation of Close which awaits a main loop acknowledging it has stopped. type closer struct { stop chan struct{} stopped chan struct{} } // newCloser returns a closer with the channels initialised. func newCloser() closer { return closer{ stop: make(chan struct{}), stopped: make(chan struct{}), } } // Close stops the main loop, waiting for the main loop to stop until it stops or the context is cancelled, whichever happens first. func (c *closer) Close(ctx context.Context) error { select { case <-c.stopped: return nil case <-c.stop: return nil case <-ctx.Done(): return ctx.Err() default: } close(c.stop) select { case <-c.stopped: return nil case <-ctx.Done(): return ctx.Err() } } // lineWriter is an io.Writer which splits on \n and outputs each line (with no trailing newline) to its output channel. type lineWriter struct { buf string out chan string } // Write accepts a slice of bytes containing zero or more new lines. // If the contained channel is non-buffering or is full, this will block. func (w *lineWriter) Write(p []byte) (n int, err error) { w.buf += string(p) pieces := strings.Split(w.buf, "\n") w.buf = pieces[len(pieces)-1] for n := 0; n < len(pieces)-1; n++ { w.out <- pieces[n] } return len(p), nil } // restartingClient is a simple SSH client that repeatedly connects to an SSH server, runs a command, and outputs the lines output by it on stdout onto a channel. type restartingClient struct { closer network string addr string cfg *ssh.ClientConfig exec string output chan string shutdown func() } var ( errStopConnect = errors.New("gerrit: told to stop reconnecting by remote server") ) func (c *restartingClient) runOnce() error { netConn, err := net.Dial(c.network, c.addr) if err != nil { return fmt.Errorf("connecting to %v/%v: %w", c.network, c.addr, err) } defer netConn.Close() sshConn, newCh, newReq, err := ssh.NewClientConn(netConn, c.addr, c.cfg) if err != nil { return fmt.Errorf("creating SSH connection to %v/%v: %w", c.network, c.addr, err) } defer sshConn.Close() goAway := false passedThroughReqs := make(chan *ssh.Request) go func() { defer close(passedThroughReqs) for req := range newReq { if req.Type == "goaway" { goAway = true log.Warningf("remote end %v/%v told me to go away!", c.network, c.addr) sshConn.Close() netConn.Close() } passedThroughReqs <- req } }() cl := ssh.NewClient(sshConn, newCh, passedThroughReqs) sess, err := cl.NewSession() if err != nil { return fmt.Errorf("NewSession on %v/%v: %w", c.network, c.addr, err) } defer sess.Close() sess.Stdout = &lineWriter{out: c.output} if err := sess.Start(c.exec); err != nil { return fmt.Errorf("Start(%q) on %v/%v: %w", c.exec, c.network, c.addr, err) } log.Infof("connected to %v/%v", c.network, c.addr) done := make(chan struct{}) go func() { sess.Wait() close(done) }() go func() { select { case <-c.stop: sess.Close() case <-done: } return }() <-done if goAway { return errStopConnect } return nil } func (c *restartingClient) run() { defer close(c.stopped) bo := backoffutil.NewDefaultBackOff() for { timer := time.NewTimer(bo.NextBackOff()) select { case <-c.stop: timer.Stop() return case <-timer.C: break } if err := c.runOnce(); err == errStopConnect { if c.shutdown != nil { c.shutdown() return } } else if err != nil { log.Errorf("SSH: %v", err) } else { bo.Reset() } } } // Output returns the channel on which each newline-delimited string output by the executed command's stdout can be received. func (c *restartingClient) Output() <-chan string { return c.output } // dialRestartingClient creates a new restartingClient. func dialRestartingClient(network, addr string, config *ssh.ClientConfig, exec string, shutdown func()) (*restartingClient, error) { c := &restartingClient{ closer: newCloser(), network: network, addr: addr, cfg: config, exec: exec, output: make(chan string), shutdown: shutdown, } go c.run() return c, nil } // Watcher watches type Watcher struct { closer c *restartingClient output chan gerritevents.Event } // Close shuts down the SSH client connection, if any, and closes the output channel. // It blocks until shutdown is complete or until the context is cancelled, whichever comes first. func (w *Watcher) Close(ctx context.Context) { w.c.Close(ctx) w.closer.Close(ctx) } func (w *Watcher) run() { defer close(w.stopped) defer close(w.output) for { select { case <-w.stop: return case o := <-w.c.Output(): ev, err := gerritevents.Parse([]byte(o)) if err != nil { log.Errorf("failed to parse event %v: %v", o, err) continue } w.output <- ev } } } // Events returns the channel upon which parsed Gerrit events can be received. func (w *Watcher) Events() <-chan gerritevents.Event { return w.output } // New returns a running Watcher from which events can be read. // It will begin connecting to the provided address immediately. func New(ctx context.Context, network, addr string, cfg *ssh.ClientConfig) (*Watcher, error) { wc := newCloser() rc, err := dialRestartingClient(network, addr, cfg, "gerrit stream-events", func() { wc.Close(context.Background()) }) if err != nil { return nil, fmt.Errorf("dialRestartingClient: %w", err) } w := &Watcher{ closer: wc, c: rc, output: make(chan gerritevents.Event), } go w.run() return w, nil }