about summary refs log tree commit diff
path: root/fun
diff options
context:
space:
mode:
Diffstat (limited to 'fun')
-rw-r--r--fun/clbot/clbot.go25
1 files changed, 24 insertions, 1 deletions
diff --git a/fun/clbot/clbot.go b/fun/clbot/clbot.go
index 943df3bd38..7fa12f2c3b 100644
--- a/fun/clbot/clbot.go
+++ b/fun/clbot/clbot.go
@@ -5,6 +5,7 @@ import (
 	"flag"
 	"io/ioutil"
 	"os"
+	"os/signal"
 	"time"
 
 	"code.tvl.fyi/fun/clbot/gerrit"
@@ -55,11 +56,27 @@ func checkRequired(fs *flag.FlagSet) {
 	}
 }
 
+var shutdownFuncs []func()
+
+func callOnShutdown(f func()) {
+	shutdownFuncs = append(shutdownFuncs, f)
+}
+
 func main() {
 	flag.Parse()
 	checkRequired(required)
 
-	ctx := context.Background()
+	shutdownCh := make(chan os.Signal)
+	signal.Notify(shutdownCh, os.Interrupt)
+	go func() {
+		<-shutdownCh
+		for n := len(shutdownFuncs); n >= 0; n-- {
+			shutdownFuncs[n]()
+		}
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	callOnShutdown(cancel)
 	cfg := &ssh.ClientConfig{
 		User:            *gerritAuthUsername,
 		Auth:            []ssh.AuthMethod{mustPrivateKey(*gerritAuthKeyPath)},
@@ -67,10 +84,16 @@ func main() {
 		Timeout:         *gerritSSHTimeout,
 	}
 	cfg.SetDefaults()
+
 	gw, err := gerrit.New(ctx, "tcp", *gerritAddr, cfg)
 	if err != nil {
 		log.Errorf("gerrit.New(%q): %v", *gerritAddr, err)
 	}
+	callOnShutdown(func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		gw.Close(ctx)
+	})
 
 	for e := range gw.Events() {
 		log.Infof("hello: %v", spew.Sdump(e))