diff options
-rw-r--r-- | monzo_ynab/tokens.go | 160 |
1 files changed, 102 insertions, 58 deletions
diff --git a/monzo_ynab/tokens.go b/monzo_ynab/tokens.go index d969ce6e43b4..7bebeef11729 100644 --- a/monzo_ynab/tokens.go +++ b/monzo_ynab/tokens.go @@ -8,17 +8,19 @@ package main //////////////////////////////////////////////////////////////////////////////// import ( - "bytes" + "auth" "encoding/json" "fmt" + "io" + "kv" "log" "net/http" "net/url" "os" - "time" - "kv" "os/signal" "syscall" + "time" + "utils" ) //////////////////////////////////////////////////////////////////////////////// @@ -71,25 +73,38 @@ var chans = &channels{ } var ( - monzoClientId = os.Getenv("monzo_client_id") - monzoClientSecret = os.Getenv("monzo_client_secret") + monzoClientId = os.Getenv("monzo_client_id") + monzoClientSecret = os.Getenv("monzo_client_secret") ) //////////////////////////////////////////////////////////////////////////////// // Utils //////////////////////////////////////////////////////////////////////////////// +// Print the access and refresh tokens for debugging. +func logTokens(access string, refresh string) { + log.Printf("Access: %s\n", access) + log.Printf("Refresh: %s\n", refresh) +} + +func (state *state) String() string { + return fmt.Sprintf("state{\n\taccessToken: \"%s\",\n\trefreshToken: \"%s\"\n}\n", state.accessToken, state.refreshToken) +} + // Schedule a token refresh for `expiresIn` seconds using the provided // `refreshToken`. This will update the application state with the access token // and schedule an additional token refresh for the newly acquired tokens. func scheduleTokenRefresh(expiresIn int, refreshToken string) { duration := time.Second * time.Duration(expiresIn) timestamp := time.Now().Local().Add(duration) + // TODO(wpcarro): Consider adding a more human readable version that will + // log the number of hours, minutes, etc. until the next refresh. log.Printf("Scheduling token refresh for %v\n", timestamp) time.Sleep(duration) log.Println("Refreshing tokens now...") access, refresh := refreshTokens(refreshToken) log.Println("Successfully refreshed tokens.") + logTokens(access, refresh) chans.writes <- writeMsg{state{access, refresh}} } @@ -104,22 +119,42 @@ func refreshTokens(refreshToken string) (string, string) { "client_secret": {monzoClientSecret}, "refresh_token": {refreshToken}, }) + if res.StatusCode != http.StatusOK { + // TODO(wpcarro): Considering panicking here. + utils.DebugResponse(res) + } if err != nil { - log.Println(res) + utils.DebugResponse(res) log.Fatal("The request to Monzo to refresh our access token failed.", err) } defer res.Body.Close() payload := &refreshTokenResponse{} err = json.NewDecoder(res.Body).Decode(payload) if err != nil { - log.Println(res) log.Fatal("Could not decode the JSON response from Monzo.", err) } + go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken) + // Interestingly, JSON decoding into the refreshTokenResponse can success + // even if the decoder doesn't populate any of the fields in the + // refreshTokenResponse struct. From what I read, it isn't possible to make + // these fields as required using an annotation, so this guard must suffice + // for now. + if payload.AccessToken == "" || payload.RefreshToken == "" { + log.Fatal("JSON parsed correctly but failed to populate token fields.") + } + return payload.AccessToken, payload.RefreshToken } +func persistTokens(access string, refresh string) { + log.Println("Persisting tokens...") + kv.Set("monzoAccessToken", access) + kv.Set("monzoRefreshToken", refresh) + log.Println("Successfully persisted tokens.") +} + // Listen for SIGINT and SIGTERM signals. When received, persist the access and // refresh tokens and shutdown the server. func handleInterrupts() { @@ -132,14 +167,8 @@ func handleInterrupts() { go func() { sig := <-sigs log.Printf("Received signal to shutdown. %v\n", sig) - // Persist existing tokens - log.Println("Persisting existing credentials...") - msg := readMsg{make(chan state)} - chans.reads <- msg - state := <-msg.sender - kv.Set("monzoAccessToken", state.accessToken) - kv.Set("monzoRefreshToken", state.refreshToken) - log.Println("Credentials persisted.") + state := getState() + persistTokens(state.accessToken, state.refreshToken) done <- true }() @@ -148,6 +177,13 @@ func handleInterrupts() { os.Exit(0) } +// Return our application state. +func getState() state { + msg := readMsg{make(chan state)} + chans.reads <- msg + return <-msg.sender +} + //////////////////////////////////////////////////////////////////////////////// // Main //////////////////////////////////////////////////////////////////////////////// @@ -158,11 +194,21 @@ func main() { refreshToken := fmt.Sprintf("%v", kv.Get("monzoRefreshToken")) log.Println("Attempting to retrieve cached credentials...") - log.Printf("Access token: %s\n", accessToken) - log.Printf("Refresh token: %s\n", refreshToken) + logTokens(accessToken, refreshToken) if accessToken == "" || refreshToken == "" { - log.Fatal("Cannot start server without access or refresh tokens.") + log.Println("Cached credentials are absent. Authorizing client...") + authCode := auth.GetAuthCode(monzoClientId) + tokens := auth.GetTokensFromAuthCode(authCode, monzoClientId, monzoClientSecret) + accessToken, refreshToken = tokens.AccessToken, tokens.RefreshToken + go persistTokens(accessToken, refreshToken) + go scheduleTokenRefresh(tokens.ExpiresIn, refreshToken) + } else { + // If we have tokens, they may be expiring soon. We don't know because + // we aren't storing the expiration timestamp in the state or in the + // store. Until we have that information, and to be safe, let's refresh + // the tokens. + scheduleTokenRefresh(0, refreshToken) } // Gracefully shutdown. @@ -174,54 +220,52 @@ func main() { for { select { case msg := <-chans.reads: - log.Printf("Reading from state.") - log.Printf("Access Token: %s\n", state.accessToken) - log.Printf("Refresh Token: %s\n", state.refreshToken) + log.Println("Reading from state...") + log.Println(state) msg.sender <- *state case msg := <-chans.writes: - fmt.Printf("Writing new state: %v\n", msg.state) + log.Println("Writing to state.") + log.Printf("Old: %s\n", state) *state = msg.state + log.Printf("New: %s\n", state) } } }() // Listen to inbound requests. fmt.Println("Listening on http://localhost:4242 ...") - log.Fatal(http.ListenAndServe(":4242", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL.Path == "/refresh-tokens" && req.Method == "POST" { - msg := readMsg{make(chan state)} - chans.reads <- msg - state := <-msg.sender - go scheduleTokenRefresh(0, state.refreshToken) - fmt.Fprintf(w, "Done.") - - } else if req.URL.Path == "/set-tokens" && req.Method == "POST" { - // Parse - payload := &setTokensRequest{} - err := json.NewDecoder(req.Body).Decode(payload) - if err != nil { - log.Fatal("Could not decode the user's JSON request.", err) - } + log.Fatal(http.ListenAndServe(":4242", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/refresh-tokens" && req.Method == "POST" { + state := getState() + go scheduleTokenRefresh(0, state.refreshToken) + fmt.Fprintf(w, "Done.") - // Update application state - msg := writeMsg{state{payload.AccessToken, payload.RefreshToken}} - chans.writes <- msg - - // Refresh tokens - go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken) - - // Ack - fmt.Fprintf(w, "Done.") - } else if req.URL.Path == "/state" && req.Method == "GET" { - // TODO(wpcarro): Ensure that this returns serialized state. - w.Header().Set("Content-type", "application/json") - msg := readMsg{make(chan state)} - chans.reads <- msg - state := <-msg.sender - payload, _ := json.Marshal(state) - fmt.Fprintf(w, "Application state: %s\n", bytes.NewBuffer(payload)) - } else { - log.Printf("Unhandled request: %v\n", *req) - } - }))) + } else if req.URL.Path == "/set-tokens" && req.Method == "POST" { + // Parse + payload := &setTokensRequest{} + err := json.NewDecoder(req.Body).Decode(payload) + if err != nil { + log.Fatal("Could not decode the user's JSON request.", err) + } + + // Update application state + msg := writeMsg{state{payload.AccessToken, payload.RefreshToken}} + chans.writes <- msg + + // Refresh tokens + go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken) + + // Ack + fmt.Fprintf(w, "Done.") + } else if req.URL.Path == "/state" && req.Method == "GET" { + // TODO(wpcarro): Ensure that this returns serialized state. + w.Header().Set("Content-type", "application/json") + state := getState() + payload, _ := json.Marshal(state) + io.WriteString(w, string(payload)) + } else { + log.Printf("Unhandled request: %v\n", *req) + } + }))) } |