about summary refs log tree commit diff
path: root/monzo_ynab/tokens.go
blob: d969ce6e43b4e5162a10ad118c0de773699b803c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
// Creating a Tokens server to manage my access and refresh tokens. Keeping this
// as a separate server allows me to develop and use the access tokens without
// going through client authorization.
package main

////////////////////////////////////////////////////////////////////////////////
// Dependencies
////////////////////////////////////////////////////////////////////////////////

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"time"
	"kv"
	"os/signal"
	"syscall"
)

////////////////////////////////////////////////////////////////////////////////
// Types
////////////////////////////////////////////////////////////////////////////////

// Wire formats for the JSON bodies this server exchanges.
type (
	// refreshTokenResponse is the body Monzo's OAuth2 token endpoint returns
	// when we trade a refresh token for a new access token.
	refreshTokenResponse struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ClientId     string `json:"client_id"`
		ExpiresIn    int    `json:"expires_in"`
	}

	// setTokensRequest is the body clients POST to /set-tokens to overwrite
	// the server's tokens. ExpiresIn is the delay, in seconds, before the
	// next automatic refresh is attempted.
	setTokensRequest struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int    `json:"expires_in"`
	}
)

// state is the application state: the current Monzo access and refresh
// tokens. A single goroutine started in main owns the live value; other code
// reads and replaces it by exchanging readMsg/writeMsg values over chans.
//
// The fields are unexported, so encoding/json would silently skip them (and
// json struct tags on them would have no effect). MarshalJSON serializes them
// explicitly so that GET /state reports the tokens instead of "{}".
type state struct {
	accessToken  string
	refreshToken string
}

// MarshalJSON encodes s as {"access_token": ..., "refresh_token": ...}.
func (s state) MarshalJSON() ([]byte, error) {
	// An anonymous struct with exported, tagged fields avoids recursing back
	// into this method.
	return json.Marshal(struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
	}{s.accessToken, s.refreshToken})
}

// Messages understood by the state-manager goroutine, plus the bundle of
// channels that carries them.
type (
	// readMsg requests a copy of the current state; the reply is sent on
	// sender.
	readMsg struct {
		sender chan state
	}

	// writeMsg replaces the current state with the enclosed value.
	writeMsg struct {
		state state
	}

	// channels groups the read and write request channels.
	channels struct {
		reads  chan readMsg
		writes chan writeMsg
	}
)

////////////////////////////////////////////////////////////////////////////////
// Top-level Definitions
////////////////////////////////////////////////////////////////////////////////

// chans is the process-wide pair of unbuffered request channels through which
// all reads and writes of the application state are funnelled.
var chans = &channels{
	writes: make(chan writeMsg),
	reads:  make(chan readMsg),
}

// monzoClientId is the Monzo OAuth client ID, read from the environment when
// the package is initialized.
var monzoClientId = os.Getenv("monzo_client_id")

// monzoClientSecret is the Monzo OAuth client secret, read from the
// environment when the package is initialized.
var monzoClientSecret = os.Getenv("monzo_client_secret")

////////////////////////////////////////////////////////////////////////////////
// Utils
////////////////////////////////////////////////////////////////////////////////

// scheduleTokenRefresh sleeps for expiresIn seconds, then exchanges
// refreshToken for a new token pair via refreshTokens and publishes that pair
// to the state manager. Because refreshTokens itself schedules the following
// refresh, one call starts a self-perpetuating refresh cycle.
//
// NOTE(review): nothing cancels a pending sleep. Each call (e.g. from
// /refresh-tokens or /set-tokens) starts an additional cycle alongside any
// already running — confirm this is intended.
func scheduleTokenRefresh(expiresIn int, refreshToken string) {
	wait := time.Duration(expiresIn) * time.Second
	log.Printf("Scheduling token refresh for %v\n", time.Now().Local().Add(wait))
	time.Sleep(wait)

	log.Println("Refreshing tokens now...")
	newAccess, newRefresh := refreshTokens(refreshToken)
	log.Println("Successfully refreshed tokens.")

	chans.writes <- writeMsg{state: state{accessToken: newAccess, refreshToken: newRefresh}}
}

// refreshTokens exchanges refreshToken for a new access/refresh token pair at
// Monzo's OAuth2 token endpoint, schedules the next refresh for when the new
// access token expires, and returns (accessToken, refreshToken).
//
// Every failure is fatal (the process exits via log.Fatal*): a transport
// error, a non-200 response, an undecodable body, or a response missing
// either token. Continuing past a rejected request would publish empty
// tokens and reschedule a refresh with ExpiresIn == 0 and an empty refresh
// token, looping against the API with no delay.
func refreshTokens(refreshToken string) (string, string) {
	// TODO(wpcarro): Support retries with exponential backoff.
	res, err := http.PostForm("https://api.monzo.com/oauth2/token", url.Values{
		"grant_type":    {"refresh_token"},
		"client_id":     {monzoClientId},
		"client_secret": {monzoClientSecret},
		"refresh_token": {refreshToken},
	})
	if err != nil {
		// Per net/http, any Response returned alongside an error can be
		// ignored, so only the error is worth reporting.
		log.Fatalf("The request to Monzo to refresh our access token failed: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		log.Fatalf("Monzo rejected the token refresh with status %s", res.Status)
	}

	payload := &refreshTokenResponse{}
	if err := json.NewDecoder(res.Body).Decode(payload); err != nil {
		log.Fatalf("Could not decode the JSON response from Monzo: %v", err)
	}
	if payload.AccessToken == "" || payload.RefreshToken == "" {
		log.Fatal("Monzo's token response is missing an access or refresh token.")
	}

	go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)

	return payload.AccessToken, payload.RefreshToken
}

// handleInterrupts blocks until the process receives SIGINT or SIGTERM. It
// then fetches the current tokens from the state manager, stores them in the
// kv store under "monzoAccessToken" and "monzoRefreshToken", and exits with
// status 0. main runs it in its own goroutine.
func handleInterrupts() {
	// signal.Notify never blocks when delivering a signal, so the channel must
	// be buffered or a signal that arrives before we receive is dropped.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	sig := <-sigs
	log.Printf("Received signal to shutdown. %v\n", sig)

	// Persist existing tokens
	log.Println("Persisting existing credentials...")
	msg := readMsg{make(chan state)}
	chans.reads <- msg
	current := <-msg.sender
	// NOTE(review): any result kv.Set returns is not checked here.
	kv.Set("monzoAccessToken", current.accessToken)
	kv.Set("monzoRefreshToken", current.refreshToken)
	log.Println("Credentials persisted.")

	log.Println("Received signal to shutdown. Exiting...")
	os.Exit(0)
}

////////////////////////////////////////////////////////////////////////////////
// Main
////////////////////////////////////////////////////////////////////////////////

func main() {
	// Retrieve cached tokens from store.
	accessToken := fmt.Sprintf("%v", kv.Get("monzoAccessToken"))
	refreshToken := fmt.Sprintf("%v", kv.Get("monzoRefreshToken"))

	log.Println("Attempting to retrieve cached credentials...")
	log.Printf("Access token: %s\n", accessToken)
	log.Printf("Refresh token: %s\n", refreshToken)

	if accessToken == "" || refreshToken == "" {
		log.Fatal("Cannot start server without access or refresh tokens.")
	}

	// Gracefully shutdown.
	go handleInterrupts()

	// Manage application state.
	go func() {
		state := &state{accessToken, refreshToken}
		for {
			select {
			case msg := <-chans.reads:
				log.Printf("Reading from state.")
				log.Printf("Access Token: %s\n", state.accessToken)
				log.Printf("Refresh Token: %s\n", state.refreshToken)
				msg.sender <- *state
			case msg := <-chans.writes:
				fmt.Printf("Writing new state: %v\n", msg.state)
				*state = msg.state
			}
		}
	}()

	// Listen to inbound requests.
	fmt.Println("Listening on http://localhost:4242 ...")
	log.Fatal(http.ListenAndServe(":4242", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path == "/refresh-tokens" && req.Method == "POST" {
			msg := readMsg{make(chan state)}
			chans.reads <- msg
			state := <-msg.sender
			go scheduleTokenRefresh(0, state.refreshToken)
			fmt.Fprintf(w, "Done.")

		} else if req.URL.Path == "/set-tokens" && req.Method == "POST" {
			// Parse
			payload := &setTokensRequest{}
			err := json.NewDecoder(req.Body).Decode(payload)
			if err != nil {
				log.Fatal("Could not decode the user's JSON request.", err)
			}

			// Update application state
			msg := writeMsg{state{payload.AccessToken, payload.RefreshToken}}
			chans.writes <- msg

			// Refresh tokens
			go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)

			// Ack
			fmt.Fprintf(w, "Done.")
		} else if req.URL.Path == "/state" && req.Method == "GET" {
			// TODO(wpcarro): Ensure that this returns serialized state.
			w.Header().Set("Content-type", "application/json")
			msg := readMsg{make(chan state)}
			chans.reads <- msg
			state := <-msg.sender
			payload, _ := json.Marshal(state)
			fmt.Fprintf(w, "Application state: %s\n", bytes.NewBuffer(payload))
		} else {
			log.Printf("Unhandled request: %v\n", *req)
		}
	})))
}