about summary refs log tree commit diff
path: root/fun/clbot/gerrit/watcher.go
diff options
context:
space:
mode:
Diffstat (limited to 'fun/clbot/gerrit/watcher.go')
-rw-r--r-- fun/clbot/gerrit/watcher.go 277
1 files changed, 277 insertions, 0 deletions
diff --git a/fun/clbot/gerrit/watcher.go b/fun/clbot/gerrit/watcher.go
new file mode 100644
index 000000000000..80a431f92250
--- /dev/null
+++ b/fun/clbot/gerrit/watcher.go
@@ -0,0 +1,277 @@
+// Package gerrit implements a watcher for Gerrit events.
+package gerrit
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"strings"
+	"time"
+
+	"code.tvl.fyi/fun/clbot/gerrit/gerritevents"
+	"github.com/cenkalti/backoff/v4"
+	log "github.com/golang/glog"
+	"golang.org/x/crypto/ssh"
+)
+
+// zeroStartingBackOff wraps a backoff.BackOff so that the first duration handed out after a reset is zero.
+// A retry loop can therefore wait for NextBackOff() at the top of every iteration, including the very first one, without special-casing it.
+type zeroStartingBackOff struct {
+	bo      backoff.BackOff
+	initial bool
+}
+
+// NextBackOff yields the delay to apply before the next attempt.
+// The first call following Reset() yields 0; every later call defers to the wrapped BackOff.
+func (bo *zeroStartingBackOff) NextBackOff() time.Duration {
+	if !bo.initial {
+		return bo.bo.NextBackOff()
+	}
+	bo.initial = false
+	return 0
+}
+
+// Reset returns to the initial state and forwards the Reset to the wrapped BackOff.
+func (bo *zeroStartingBackOff) Reset() {
+	bo.initial = true
+	bo.bo.Reset()
+}
+
+// closer provides an embeddable implementation of Close which awaits a main loop acknowledging it has stopped.
+type closer struct {
+	stop    chan struct{} // closed by Close to ask the main loop to stop
+	stopped chan struct{} // closed by the main loop once it has returned
+}
+
+// newCloser returns a closer with the channels initialised.
+func newCloser() closer {
+	return closer{
+		stop:    make(chan struct{}),
+		stopped: make(chan struct{}),
+	}
+}
+
+// Close asks the main loop to stop and then waits for it to acknowledge.
+//
+// If the loop has already stopped, or a stop was already requested by an earlier call, Close returns nil immediately without waiting.
+// Otherwise the stop channel is closed unconditionally - even when ctx is already done - so that the loop is always told to exit; ctx only bounds how long Close waits for the acknowledgement.
+// It returns ctx.Err() if ctx is done before the loop reports that it has stopped.
+//
+// Close is not safe to call from several goroutines at once: two callers racing past the initial check would both close the stop channel.
+func (c *closer) Close(ctx context.Context) error {
+	select {
+	case <-c.stopped:
+		return nil
+	case <-c.stop:
+		return nil
+	default:
+	}
+	// Signal first, wait second: a ctx that has already expired must not prevent the stop request from being delivered, otherwise the main loop would never be told to exit.
+	close(c.stop)
+	select {
+	case <-c.stopped:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// lineWriter is an io.Writer which splits on \n and outputs each line (with no trailing newline) to its output channel.
+type lineWriter struct {
+	buf string      // trailing text received so far that has not yet been terminated by \n
+	out chan string // receives every completed line
+}
+
+// Write appends p to any partial line left over from earlier calls and sends every completed line, without its trailing \n, to the output channel.
+// Text after the last \n is retained until a later Write completes it.
+// Each send blocks while the channel is unbuffered or full. Write always reports len(p) bytes written and a nil error.
+func (w *lineWriter) Write(p []byte) (n int, err error) {
+	data := w.buf + string(p)
+	last := strings.LastIndexByte(data, '\n')
+	if last < 0 {
+		// No complete line yet; keep accumulating.
+		w.buf = data
+		return len(p), nil
+	}
+	// Keep the unterminated tail for the next call before emitting the finished lines.
+	w.buf = data[last+1:]
+	for _, line := range strings.Split(data[:last], "\n") {
+		w.out <- line
+	}
+	return len(p), nil
+}
+
+// restartingClient is a simple SSH client that repeatedly connects to an SSH server, runs a command, and outputs the lines output by it on stdout onto a channel.
+type restartingClient struct {
+	closer
+
+	network string            // network name passed to net.Dial
+	addr    string            // address passed to net.Dial and ssh.NewClientConn
+	cfg     *ssh.ClientConfig // SSH client configuration used for every connection
+
+	exec     string      // command started in each new session
+	output   chan string // receives each stdout line of the command
+	shutdown func()      // optional; called by run when runOnce returns errStopConnect
+}
+
+var (
+	errStopConnect = errors.New("gerrit: told to stop reconnecting by remote server") // returned by runOnce after the server sent a "goaway" request
+)
+
+// runOnce performs a single connect-and-stream cycle.
+//
+// It dials c.addr, establishes an SSH connection, starts c.exec in a new session and blocks until that session ends.
+// While the session runs, every stdout line of the command is delivered on c.output, and closing c.stop closes the session.
+//
+// It returns a wrapped error if dialling, the SSH handshake, session creation or starting the command fails,
+// errStopConnect if the server sent a "goaway" request at any point, and nil otherwise.
+func (c *restartingClient) runOnce() error {
+	netConn, err := net.Dial(c.network, c.addr)
+	if err != nil {
+		return fmt.Errorf("connecting to %v/%v: %w", c.network, c.addr, err)
+	}
+	defer netConn.Close()
+
+	sshConn, newCh, newReq, err := ssh.NewClientConn(netConn, c.addr, c.cfg)
+	if err != nil {
+		return fmt.Errorf("creating SSH connection to %v/%v: %w", c.network, c.addr, err)
+	}
+	defer sshConn.Close()
+
+	// goAway is closed (at most once) by the request forwarder below when the server sends "goaway".
+	// A channel is used instead of a shared bool so that the check at the end of this function is synchronised with the forwarding goroutine.
+	goAway := make(chan struct{})
+	passedThroughReqs := make(chan *ssh.Request)
+	go func() {
+		defer close(passedThroughReqs)
+		// seenGoAway is only touched by this goroutine; it guards against closing goAway twice.
+		seenGoAway := false
+		for req := range newReq {
+			if req.Type == "goaway" {
+				if !seenGoAway {
+					seenGoAway = true
+					close(goAway)
+				}
+				log.Warningf("remote end %v/%v told me to go away!", c.network, c.addr)
+				sshConn.Close()
+				netConn.Close()
+			}
+			// Every request, including "goaway", is still handed on to the ssh.Client.
+			passedThroughReqs <- req
+		}
+	}()
+
+	cl := ssh.NewClient(sshConn, newCh, passedThroughReqs)
+
+	sess, err := cl.NewSession()
+	if err != nil {
+		return fmt.Errorf("NewSession on %v/%v: %w", c.network, c.addr, err)
+	}
+	defer sess.Close()
+
+	sess.Stdout = &lineWriter{out: c.output}
+
+	if err := sess.Start(c.exec); err != nil {
+		return fmt.Errorf("Start(%q) on %v/%v: %w", c.exec, c.network, c.addr, err)
+	}
+
+	log.Infof("connected to %v/%v", c.network, c.addr)
+
+	// done is closed once the remote command has finished or the session has been torn down.
+	done := make(chan struct{})
+	go func() {
+		sess.Wait()
+		close(done)
+	}()
+	// Close the session as soon as a stop is requested; exit quietly if the session ends first.
+	go func() {
+		select {
+		case <-c.stop:
+			sess.Close()
+		case <-done:
+		}
+	}()
+	<-done
+
+	select {
+	case <-goAway:
+		return errStopConnect
+	default:
+		return nil
+	}
+}
+
+// run is the main loop of the client: it calls runOnce repeatedly until a stop is requested through c.stop or the server asks the client to stop reconnecting.
+//
+// Before each attempt it waits for the current back-off delay, which is zero for the first attempt and after any attempt that returned nil.
+// Other errors from runOnce are logged and retried after a delay taken from an exponential back-off.
+// When runOnce reports errStopConnect, the shutdown callback (if one was provided) is invoked and the loop ends.
+// c.stopped is closed on return so that Close can observe that the loop has finished.
+func (c *restartingClient) run() {
+	defer close(c.stopped)
+	ebo := backoff.NewExponentialBackOff()
+	// MaxElapsedTime = 0 disables the elapsed-time limit (see backoff.ExponentialBackOff), so retries continue indefinitely.
+	ebo.MaxElapsedTime = 0
+	bo := &zeroStartingBackOff{bo: ebo, initial: true}
+	for {
+		timer := time.NewTimer(bo.NextBackOff())
+		select {
+		case <-c.stop:
+			timer.Stop()
+			return
+		case <-timer.C:
+		}
+		err := c.runOnce()
+		switch {
+		case errors.Is(err, errStopConnect):
+			// The server told us to stop reconnecting: honour that whether or not a shutdown callback is set.
+			if c.shutdown != nil {
+				c.shutdown()
+			}
+			return
+		case err != nil:
+			log.Errorf("SSH: %v", err)
+		default:
+			bo.Reset()
+		}
+	}
+}
+
+// Output returns the receive-only channel carrying each newline-delimited line that the executed command writes to stdout.
+func (c *restartingClient) Output() <-chan string {
+	return c.output
+}
+
+// dialRestartingClient builds a restartingClient for the given network, address, SSH configuration and command, and starts its run loop in a new goroutine.
+// shutdown may be nil; run invokes it when the server asks the client to stop reconnecting.
+// The returned error is currently always nil.
+func dialRestartingClient(network, addr string, config *ssh.ClientConfig, exec string, shutdown func()) (*restartingClient, error) {
+	c := new(restartingClient)
+	c.closer = newCloser()
+	c.network = network
+	c.addr = addr
+	c.cfg = config
+	c.exec = exec
+	c.output = make(chan string)
+	c.shutdown = shutdown
+
+	go c.run()
+	return c, nil
+}
+
+// Watcher reads the lines produced by a restartingClient, parses each one as a Gerrit event and publishes the result on its output channel.
+type Watcher struct {
+	closer
+	c *restartingClient // source of raw event lines
+
+	output chan gerritevents.Event // parsed events; closed when run returns
+}
+
+// Close shuts down the SSH client connection, if any, and then the Watcher's own loop, which closes the output channel when it exits.
+// It blocks until shutdown is complete or until the context is cancelled, whichever comes first.
+// Close has no return value, so a failure of either step (for example ctx expiring before a loop has exited) is logged rather than silently dropped.
+func (w *Watcher) Close(ctx context.Context) {
+	if err := w.c.Close(ctx); err != nil {
+		log.Warningf("closing SSH client: %v", err)
+	}
+	if err := w.closer.Close(ctx); err != nil {
+		log.Warningf("closing watcher: %v", err)
+	}
+}
+
+// run is the Watcher's main loop: it takes each line from the SSH client, parses it as a Gerrit event and forwards the result on w.output.
+// Lines that fail to parse are logged and skipped.
+// The loop ends when w.stop is closed, whether it is waiting for a line or waiting for a consumer to accept an event.
+// On return w.output is closed first and then w.stopped.
+func (w *Watcher) run() {
+	defer close(w.stopped)
+	defer close(w.output)
+	for {
+		select {
+		case <-w.stop:
+			return
+		case o := <-w.c.Output():
+			ev, err := gerritevents.Parse([]byte(o))
+			if err != nil {
+				log.Errorf("failed to parse event %v: %v", o, err)
+				continue
+			}
+			// Keep watching w.stop while delivering, so Close cannot hang when nobody is reading from Events().
+			select {
+			case w.output <- ev:
+			case <-w.stop:
+				return
+			}
+		}
+	}
+}
+
+// Events returns the receive-only channel on which parsed Gerrit events are delivered; run closes it when the Watcher's loop exits.
+func (w *Watcher) Events() <-chan gerritevents.Event {
+	return w.output
+}
+
+// New creates a Watcher that runs "gerrit stream-events" over SSH against the given address and starts it.
+// Connection attempts begin immediately in the background; parsed events are available from Events().
+// If the server tells the client to stop reconnecting, the Watcher's own loop is shut down as well.
+// ctx is accepted but not currently used.
+func New(ctx context.Context, network, addr string, cfg *ssh.ClientConfig) (*Watcher, error) {
+	w := &Watcher{
+		closer: newCloser(),
+		output: make(chan gerritevents.Event),
+	}
+	// Invoked by the SSH client's loop when the remote end says "goaway".
+	onGoAway := func() {
+		w.closer.Close(context.Background())
+	}
+	client, err := dialRestartingClient(network, addr, cfg, "gerrit stream-events", onGoAway)
+	if err != nil {
+		return nil, fmt.Errorf("dialRestartingClient: %w", err)
+	}
+	w.c = client
+	go w.run()
+	return w, nil
+}