about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--monzo_ynab/tokens.go160
1 files changed, 102 insertions, 58 deletions
diff --git a/monzo_ynab/tokens.go b/monzo_ynab/tokens.go
index d969ce6e43b4..7bebeef11729 100644
--- a/monzo_ynab/tokens.go
+++ b/monzo_ynab/tokens.go
@@ -8,17 +8,19 @@ package main
 ////////////////////////////////////////////////////////////////////////////////
 
 import (
-	"bytes"
+	"auth"
 	"encoding/json"
 	"fmt"
+	"io"
+	"kv"
 	"log"
 	"net/http"
 	"net/url"
 	"os"
-	"time"
-	"kv"
 	"os/signal"
 	"syscall"
+	"time"
+	"utils"
 )
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -71,25 +73,38 @@ var chans = &channels{
 }
 
 var (
-	monzoClientId      = os.Getenv("monzo_client_id")
-	monzoClientSecret  = os.Getenv("monzo_client_secret")
+	monzoClientId     = os.Getenv("monzo_client_id")
+	monzoClientSecret = os.Getenv("monzo_client_secret")
 )
 
 ////////////////////////////////////////////////////////////////////////////////
 // Utils
 ////////////////////////////////////////////////////////////////////////////////
 
+// Print the access and refresh tokens for debugging.
+func logTokens(access string, refresh string) {
+	log.Printf("Access: %s\n", access)
+	log.Printf("Refresh: %s\n", refresh)
+}
+
+func (state *state) String() string {
+	return fmt.Sprintf("state{\n\taccessToken: \"%s\",\n\trefreshToken: \"%s\"\n}\n", state.accessToken, state.refreshToken)
+}
+
 // Schedule a token refresh for `expiresIn` seconds using the provided
 // `refreshToken`. This will update the application state with the access token
 // and schedule an additional token refresh for the newly acquired tokens.
 func scheduleTokenRefresh(expiresIn int, refreshToken string) {
 	duration := time.Second * time.Duration(expiresIn)
 	timestamp := time.Now().Local().Add(duration)
+	// TODO(wpcarro): Consider adding a more human readable version that will
+	// log the number of hours, minutes, etc. until the next refresh.
 	log.Printf("Scheduling token refresh for %v\n", timestamp)
 	time.Sleep(duration)
 	log.Println("Refreshing tokens now...")
 	access, refresh := refreshTokens(refreshToken)
 	log.Println("Successfully refreshed tokens.")
+	logTokens(access, refresh)
 	chans.writes <- writeMsg{state{access, refresh}}
 }
 
@@ -104,22 +119,42 @@ func refreshTokens(refreshToken string) (string, string) {
 		"client_secret": {monzoClientSecret},
 		"refresh_token": {refreshToken},
 	})
+	if res.StatusCode != http.StatusOK {
+		// TODO(wpcarro): Considering panicking here.
+		utils.DebugResponse(res)
+	}
 	if err != nil {
-		log.Println(res)
+		utils.DebugResponse(res)
 		log.Fatal("The request to Monzo to refresh our access token failed.", err)
 	}
 	defer res.Body.Close()
 	payload := &refreshTokenResponse{}
 	err = json.NewDecoder(res.Body).Decode(payload)
 	if err != nil {
-		log.Println(res)
 		log.Fatal("Could not decode the JSON response from Monzo.", err)
 	}
+
 	go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)
 
+	// Interestingly, JSON decoding into the refreshTokenResponse can success
+	// even if the decoder doesn't populate any of the fields in the
+	// refreshTokenResponse struct. From what I read, it isn't possible to make
+	// these fields as required using an annotation, so this guard must suffice
+	// for now.
+	if payload.AccessToken == "" || payload.RefreshToken == "" {
+		log.Fatal("JSON parsed correctly but failed to populate token fields.")
+	}
+
 	return payload.AccessToken, payload.RefreshToken
 }
 
+func persistTokens(access string, refresh string) {
+	log.Println("Persisting tokens...")
+	kv.Set("monzoAccessToken", access)
+	kv.Set("monzoRefreshToken", refresh)
+	log.Println("Successfully persisted tokens.")
+}
+
 // Listen for SIGINT and SIGTERM signals. When received, persist the access and
 // refresh tokens and shutdown the server.
 func handleInterrupts() {
@@ -132,14 +167,8 @@ func handleInterrupts() {
 	go func() {
 		sig := <-sigs
 		log.Printf("Received signal to shutdown. %v\n", sig)
-		// Persist existing tokens
-		log.Println("Persisting existing credentials...")
-		msg := readMsg{make(chan state)}
-		chans.reads <- msg
-		state := <-msg.sender
-		kv.Set("monzoAccessToken", state.accessToken)
-		kv.Set("monzoRefreshToken", state.refreshToken)
-		log.Println("Credentials persisted.")
+		state := getState()
+		persistTokens(state.accessToken, state.refreshToken)
 		done <- true
 	}()
 
@@ -148,6 +177,13 @@ func handleInterrupts() {
 	os.Exit(0)
 }
 
+// Return our application state.
+func getState() state {
+	msg := readMsg{make(chan state)}
+	chans.reads <- msg
+	return <-msg.sender
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Main
 ////////////////////////////////////////////////////////////////////////////////
@@ -158,11 +194,21 @@ func main() {
 	refreshToken := fmt.Sprintf("%v", kv.Get("monzoRefreshToken"))
 
 	log.Println("Attempting to retrieve cached credentials...")
-	log.Printf("Access token: %s\n", accessToken)
-	log.Printf("Refresh token: %s\n", refreshToken)
+	logTokens(accessToken, refreshToken)
 
 	if accessToken == "" || refreshToken == "" {
-		log.Fatal("Cannot start server without access or refresh tokens.")
+		log.Println("Cached credentials are absent. Authorizing client...")
+		authCode := auth.GetAuthCode(monzoClientId)
+		tokens := auth.GetTokensFromAuthCode(authCode, monzoClientId, monzoClientSecret)
+		accessToken, refreshToken = tokens.AccessToken, tokens.RefreshToken
+		go persistTokens(accessToken, refreshToken)
+		go scheduleTokenRefresh(tokens.ExpiresIn, refreshToken)
+	} else {
+		// If we have tokens, they may be expiring soon. We don't know because
+		// we aren't storing the expiration timestamp in the state or in the
+		// store. Until we have that information, and to be safe, let's refresh
+		// the tokens.
+		scheduleTokenRefresh(0, refreshToken)
 	}
 
 	// Gracefully shutdown.
@@ -174,54 +220,52 @@ func main() {
 		for {
 			select {
 			case msg := <-chans.reads:
-				log.Printf("Reading from state.")
-				log.Printf("Access Token: %s\n", state.accessToken)
-				log.Printf("Refresh Token: %s\n", state.refreshToken)
+				log.Println("Reading from state...")
+				log.Println(state)
 				msg.sender <- *state
 			case msg := <-chans.writes:
-				fmt.Printf("Writing new state: %v\n", msg.state)
+				log.Println("Writing to state.")
+				log.Printf("Old: %s\n", state)
 				*state = msg.state
+				log.Printf("New: %s\n", state)
 			}
 		}
 	}()
 
 	// Listen to inbound requests.
 	fmt.Println("Listening on http://localhost:4242 ...")
-	log.Fatal(http.ListenAndServe(":4242", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		if req.URL.Path == "/refresh-tokens" && req.Method == "POST" {
-			msg := readMsg{make(chan state)}
-			chans.reads <- msg
-			state := <-msg.sender
-			go scheduleTokenRefresh(0, state.refreshToken)
-			fmt.Fprintf(w, "Done.")
-
-		} else if req.URL.Path == "/set-tokens" && req.Method == "POST" {
-			// Parse
-			payload := &setTokensRequest{}
-			err := json.NewDecoder(req.Body).Decode(payload)
-			if err != nil {
-				log.Fatal("Could not decode the user's JSON request.", err)
-			}
+	log.Fatal(http.ListenAndServe(":4242",
+		http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			if req.URL.Path == "/refresh-tokens" && req.Method == "POST" {
+				state := getState()
+				go scheduleTokenRefresh(0, state.refreshToken)
+				fmt.Fprintf(w, "Done.")
 
-			// Update application state
-			msg := writeMsg{state{payload.AccessToken, payload.RefreshToken}}
-			chans.writes <- msg
-
-			// Refresh tokens
-			go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)
-
-			// Ack
-			fmt.Fprintf(w, "Done.")
-		} else if req.URL.Path == "/state" && req.Method == "GET" {
-			// TODO(wpcarro): Ensure that this returns serialized state.
-			w.Header().Set("Content-type", "application/json")
-			msg := readMsg{make(chan state)}
-			chans.reads <- msg
-			state := <-msg.sender
-			payload, _ := json.Marshal(state)
-			fmt.Fprintf(w, "Application state: %s\n", bytes.NewBuffer(payload))
-		} else {
-			log.Printf("Unhandled request: %v\n", *req)
-		}
-	})))
+			} else if req.URL.Path == "/set-tokens" && req.Method == "POST" {
+				// Parse
+				payload := &setTokensRequest{}
+				err := json.NewDecoder(req.Body).Decode(payload)
+				if err != nil {
+					log.Fatal("Could not decode the user's JSON request.", err)
+				}
+
+				// Update application state
+				msg := writeMsg{state{payload.AccessToken, payload.RefreshToken}}
+				chans.writes <- msg
+
+				// Refresh tokens
+				go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)
+
+				// Ack
+				fmt.Fprintf(w, "Done.")
+			} else if req.URL.Path == "/state" && req.Method == "GET" {
+				// TODO(wpcarro): Ensure that this returns serialized state.
+				w.Header().Set("Content-type", "application/json")
+				state := getState()
+				payload, _ := json.Marshal(state)
+				io.WriteString(w, string(payload))
+			} else {
+				log.Printf("Unhandled request: %v\n", *req)
+			}
+		})))
 }