diff options
Diffstat (limited to 'fun/clbot/gerrit/watcher.go')
-rw-r--r-- | fun/clbot/gerrit/watcher.go | 277 |
1 files changed, 277 insertions, 0 deletions
diff --git a/fun/clbot/gerrit/watcher.go b/fun/clbot/gerrit/watcher.go new file mode 100644 index 000000000000..80a431f92250 --- /dev/null +++ b/fun/clbot/gerrit/watcher.go @@ -0,0 +1,277 @@ +// Package gerrit implements a watcher for Gerrit events. +package gerrit + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" + + "code.tvl.fyi/fun/clbot/gerrit/gerritevents" + "github.com/cenkalti/backoff/v4" + log "github.com/golang/glog" + "golang.org/x/crypto/ssh" +) + +// zeroStartingBackOff is a backoff.BackOff that returns "0" as the first Duration after a reset. +// This is useful for constructing loops and just enforcing a backoff duration on every loop, rather than incorporating this logic into the loop directly. +type zeroStartingBackOff struct { + bo backoff.BackOff + initial bool +} + +// NextBackOff returns the next back off duration to use. +// For the first call after a call to Reset(), this is 0. For each subsequent duration, the underlying BackOff is consulted. +func (bo *zeroStartingBackOff) NextBackOff() time.Duration { + if bo.initial == true { + bo.initial = false + return 0 + } + return bo.bo.NextBackOff() +} + +// Reset resets to the initial state, and also passes a Reset through to the underlying BackOff. +func (bo *zeroStartingBackOff) Reset() { + bo.initial = true + bo.bo.Reset() +} + +// closer provides an embeddable implementation of Close which awaits a main loop acknowledging it has stopped. +type closer struct { + stop chan struct{} + stopped chan struct{} +} + +// newCloser returns a closer with the channels initialised. +func newCloser() closer { + return closer{ + stop: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +// Close stops the main loop, waiting for the main loop to stop until it stops or the context is cancelled, whichever happens first. +func (c *closer) Close(ctx context.Context) error { + select { + case <-c.stopped: + return nil + case <-c.stop: + return nil + case <-ctx.Done(): + return ctx.Err() + default: + } + close(c.stop) + select { + case <-c.stopped: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// lineWriter is an io.Writer which splits on \n and outputs each line (with no trailing newline) to its output channel. +type lineWriter struct { + buf string + out chan string +} + +// Write accepts a slice of bytes containing zero or more new lines. +// If the contained channel is non-buffering or is full, this will block. +func (w *lineWriter) Write(p []byte) (n int, err error) { + w.buf += string(p) + pieces := strings.Split(w.buf, "\n") + w.buf = pieces[len(pieces)-1] + for n := 0; n < len(pieces)-1; n++ { + w.out <- pieces[n] + } + return len(p), nil +} + +// restartingClient is a simple SSH client that repeatedly connects to an SSH server, runs a command, and outputs the lines output by it on stdout onto a channel. +type restartingClient struct { + closer + + network string + addr string + cfg *ssh.ClientConfig + + exec string + output chan string + shutdown func() +} + +var ( + errStopConnect = errors.New("gerrit: told to stop reconnecting by remote server") +) + +func (c *restartingClient) runOnce() error { + netConn, err := net.Dial(c.network, c.addr) + if err != nil { + return fmt.Errorf("connecting to %v/%v: %w", c.network, c.addr, err) + } + defer netConn.Close() + + sshConn, newCh, newReq, err := ssh.NewClientConn(netConn, c.addr, c.cfg) + if err != nil { + return fmt.Errorf("creating SSH connection to %v/%v: %w", c.network, c.addr, err) + } + defer sshConn.Close() + + goAway := false + passedThroughReqs := make(chan *ssh.Request) + go func() { + defer close(passedThroughReqs) + for req := range newReq { + if req.Type == "goaway" { + goAway = true + log.Warningf("remote end %v/%v told me to go away!", c.network, c.addr) + sshConn.Close() + netConn.Close() + } + passedThroughReqs <- req + } + }() + + cl := ssh.NewClient(sshConn, newCh, passedThroughReqs) + + sess, err := cl.NewSession() + if err != nil { + return fmt.Errorf("NewSession on %v/%v: %w", c.network, c.addr, err) + } + defer sess.Close() + + sess.Stdout = &lineWriter{out: c.output} + + if err := sess.Start(c.exec); err != nil { + return fmt.Errorf("Start(%q) on %v/%v: %w", c.exec, c.network, c.addr, err) + } + + log.Infof("connected to %v/%v", c.network, c.addr) + + done := make(chan struct{}) + go func() { + sess.Wait() + close(done) + }() + go func() { + select { + case <-c.stop: + sess.Close() + case <-done: + } + return + }() + <-done + + if goAway { + return errStopConnect + } + return nil +} + +func (c *restartingClient) run() { + defer close(c.stopped) + ebo := backoff.NewExponentialBackOff() + ebo.MaxElapsedTime = 0 + bo := &zeroStartingBackOff{bo: ebo, initial: true} + for { + timer := time.NewTimer(bo.NextBackOff()) + select { + case <-c.stop: + timer.Stop() + return + case <-timer.C: + break + } + if err := c.runOnce(); err == errStopConnect { + if c.shutdown != nil { + c.shutdown() + return + } + } else if err != nil { + log.Errorf("SSH: %v", err) + } else { + bo.Reset() + } + } +} + +// Output returns the channel on which each newline-delimited string output by the executed command's stdout can be received. +func (c *restartingClient) Output() <-chan string { + return c.output +} + +// dialRestartingClient creates a new restartingClient. +func dialRestartingClient(network, addr string, config *ssh.ClientConfig, exec string, shutdown func()) (*restartingClient, error) { + c := &restartingClient{ + closer: newCloser(), + network: network, + addr: addr, + cfg: config, + exec: exec, + output: make(chan string), + shutdown: shutdown, + } + go c.run() + return c, nil +} + +// Watcher watches +type Watcher struct { + closer + c *restartingClient + + output chan gerritevents.Event +} + +// Close shuts down the SSH client connection, if any, and closes the output channel. +// It blocks until shutdown is complete or until the context is cancelled, whichever comes first. +func (w *Watcher) Close(ctx context.Context) { + w.c.Close(ctx) + w.closer.Close(ctx) +} + +func (w *Watcher) run() { + defer close(w.stopped) + defer close(w.output) + for { + select { + case <-w.stop: + return + case o := <-w.c.Output(): + ev, err := gerritevents.Parse([]byte(o)) + if err != nil { + log.Errorf("failed to parse event %v: %v", o, err) + continue + } + w.output <- ev + } + } +} + +// Events returns the channel upon which parsed Gerrit events can be received. +func (w *Watcher) Events() <-chan gerritevents.Event { + return w.output +} + +// New returns a running Watcher from which events can be read. +// It will begin connecting to the provided address immediately. +func New(ctx context.Context, network, addr string, cfg *ssh.ClientConfig) (*Watcher, error) { + wc := newCloser() + rc, err := dialRestartingClient(network, addr, cfg, "gerrit stream-events", func() { + wc.Close(context.Background()) + }) + if err != nil { + return nil, fmt.Errorf("dialRestartingClient: %w", err) + } + w := &Watcher{ + closer: wc, + c: rc, + output: make(chan gerritevents.Event), + } + go w.run() + return w, nil +} |