about summary refs log tree commit diff
path: root/fun
diff options
context:
space:
mode:
authorLuke Granger-Brown <hg@lukegb.com>2020-06-14T19·17+0100
committerlukegb <lukegb@tvl.fyi>2020-06-15T16·48+0000
commit5acd03817a9ed5f3b597960ec2f192740a586631 (patch)
tree7e7074292881bf353a3ed2b8f51931772c3e1b95 /fun
parenta577fd83d63306381f37085a1999c16ca3d19a7b (diff)
chore(clbot): Add signal handler to make clbot shutdown cleanly on SIGINT. r/962
Change-Id: I3c6eeeb99f9d81cdbcb10880c9075ac94c4f5d19
Reviewed-on: https://cl.tvl.fyi/c/depot/+/341
Reviewed-by: tazjin <mail@tazj.in>
Diffstat (limited to 'fun')
-rw-r--r--fun/clbot/clbot.go25
1 files changed, 24 insertions, 1 deletions
diff --git a/fun/clbot/clbot.go b/fun/clbot/clbot.go
index 943df3bd387f..7fa12f2c3bc3 100644
--- a/fun/clbot/clbot.go
+++ b/fun/clbot/clbot.go
@@ -5,6 +5,7 @@ import (
 	"flag"
 	"io/ioutil"
 	"os"
+	"os/signal"
 	"time"
 
 	"code.tvl.fyi/fun/clbot/gerrit"
@@ -55,11 +56,27 @@ func checkRequired(fs *flag.FlagSet) {
 	}
 }
 
+var shutdownFuncs []func()
+
+func callOnShutdown(f func()) {
+	shutdownFuncs = append(shutdownFuncs, f)
+}
+
 func main() {
 	flag.Parse()
 	checkRequired(required)
 
-	ctx := context.Background()
+	shutdownCh := make(chan os.Signal)
+	signal.Notify(shutdownCh, os.Interrupt)
+	go func() {
+		<-shutdownCh
+		for n := len(shutdownFuncs); n >= 0; n-- {
+			shutdownFuncs[n]()
+		}
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	callOnShutdown(cancel)
 	cfg := &ssh.ClientConfig{
 		User:            *gerritAuthUsername,
 		Auth:            []ssh.AuthMethod{mustPrivateKey(*gerritAuthKeyPath)},
@@ -67,10 +84,16 @@ func main() {
 		Timeout:         *gerritSSHTimeout,
 	}
 	cfg.SetDefaults()
+
 	gw, err := gerrit.New(ctx, "tcp", *gerritAddr, cfg)
 	if err != nil {
 		log.Errorf("gerrit.New(%q): %v", *gerritAddr, err)
 	}
+	callOnShutdown(func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		gw.Close(ctx)
+	})
 
 	for e := range gw.Events() {
 		log.Infof("hello: %v", spew.Sdump(e))