diff options
Diffstat (limited to 'monzo_ynab')
-rw-r--r-- | monzo_ynab/tokens.go | 72 |
1 files changed, 42 insertions, 30 deletions
diff --git a/monzo_ynab/tokens.go b/monzo_ynab/tokens.go index b417e49faf44..7afd86e4cf72 100644 --- a/monzo_ynab/tokens.go +++ b/monzo_ynab/tokens.go @@ -55,7 +55,8 @@ type readMsg struct { } type writeMsg struct { - state state + state state + sender chan bool } type channels struct { @@ -102,10 +103,10 @@ func scheduleTokenRefresh(expiresIn int, refreshToken string) { log.Printf("Scheduling token refresh for %v\n", timestamp) time.Sleep(duration) log.Println("Refreshing tokens now...") - access, refresh := refreshTokens(refreshToken) + accessToken, refreshToken := refreshTokens(refreshToken) log.Println("Successfully refreshed tokens.") - logTokens(access, refresh) - chans.writes <- writeMsg{state{access, refresh}} + logTokens(accessToken, refreshToken) + setState(accessToken, refreshToken) } // Exchange existing credentials for a new access token and `refreshToken`. Also @@ -177,6 +178,13 @@ func handleInterrupts() { os.Exit(0) } +// Set `accessToken` and `refreshToken` on application state. +func setState(accessToken string, refreshToken string) { + msg := writeMsg{state{accessToken, refreshToken}, make(chan bool)} + chans.writes <- msg + <-msg.sender +} + // Return our application state. func getState() state { msg := readMsg{make(chan state)} @@ -189,24 +197,9 @@ func getState() state { //////////////////////////////////////////////////////////////////////////////// func main() { - // Retrieve cached tokens from store. - accessToken := fmt.Sprintf("%v", kv.Get("monzoAccessToken")) - refreshToken := fmt.Sprintf("%v", kv.Get("monzoRefreshToken")) - - log.Println("Attempting to retrieve cached credentials...") - logTokens(accessToken, refreshToken) - - if accessToken == "" || refreshToken == "" { - log.Println("Cached credentials are absent. Authorizing client...") - authCode := auth.GetAuthCode(monzoClientId) - tokens := auth.GetTokensFromAuthCode(authCode, monzoClientId, monzoClientSecret) - accessToken, refreshToken = tokens.AccessToken, tokens.RefreshToken - go persistTokens(accessToken, refreshToken) - go scheduleTokenRefresh(tokens.ExpiresIn, refreshToken) - } // Manage application state. go func() { - state := &state{accessToken, refreshToken} + state := &state{} for { select { case msg := <-chans.reads: @@ -218,18 +211,39 @@ func main() { log.Printf("Old: %s\n", state) *state = msg.state log.Printf("New: %s\n", state) + // As an attempt to maintain consistency between application + // state and persisted state, everytime we write to the + // application state, we will write to the store. + persistTokens(state.accessToken, state.refreshToken) + msg.sender <- true } } }() - // Gracefully shutdown. - go handleInterrupts() + // Retrieve cached tokens from store. + accessToken := fmt.Sprintf("%v", kv.Get("monzoAccessToken")) + refreshToken := fmt.Sprintf("%v", kv.Get("monzoRefreshToken")) - // If we have tokens, they may be expiring soon. We don't know because - // we aren't storing the expiration timestamp in the state or in the - // store. Until we have that information, and to be safe, let's refresh - // the tokens. - scheduleTokenRefresh(0, refreshToken) + log.Println("Attempting to retrieve cached credentials...") + logTokens(accessToken, refreshToken) + + if accessToken == "" || refreshToken == "" { + log.Println("Cached credentials are absent. Authorizing client...") + authCode := auth.GetAuthCode(monzoClientId) + tokens := auth.GetTokensFromAuthCode(authCode, monzoClientId, monzoClientSecret) + setState(tokens.AccessToken, tokens.RefreshToken) + go scheduleTokenRefresh(tokens.ExpiresIn, tokens.RefreshToken) + } else { + setState(accessToken, refreshToken) + // If we have tokens, they may be expiring soon. We don't know because + // we aren't storing the expiration timestamp in the state or in the + // store. Until we have that information, and to be safe, let's refresh + // the tokens. + go scheduleTokenRefresh(0, refreshToken) + } + + // Gracefully handle shutdowns. + go handleInterrupts() // Listen to inbound requests. fmt.Println("Listening on http://localhost:4242 ...") @@ -239,7 +253,6 @@ func main() { state := getState() go scheduleTokenRefresh(0, state.refreshToken) fmt.Fprintf(w, "Done.") - } else if req.URL.Path == "/set-tokens" && req.Method == "POST" { // Parse payload := &setTokensRequest{} @@ -249,8 +262,7 @@ func main() { } // Update application state - msg := writeMsg{state{payload.AccessToken, payload.RefreshToken}} - chans.writes <- msg + setState(payload.AccessToken, payload.RefreshToken) // Refresh tokens go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken) |