about summary refs log tree commit diff
path: root/monzo_ynab/tokens.go
blob: 4be967ccb803ac5efc6406c0f8b4dd1705b6c65e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
// Creating a Tokens server to manage my access and refresh tokens. Keeping this
// as a separate server allows me to develop and use the access tokens without
// going through client authorization.
package main

////////////////////////////////////////////////////////////////////////////////
// Dependencies
////////////////////////////////////////////////////////////////////////////////

import (
	"auth"
	"encoding/json"
	"fmt"
	"io"
	"kv"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"syscall"
	"time"
	"utils"
)

////////////////////////////////////////////////////////////////////////////////
// Types
////////////////////////////////////////////////////////////////////////////////

// refreshTokenResponse is the JSON body returned by Monzo's API after we
// request an access token refresh. It is decoded in refreshTokens.
type refreshTokenResponse struct {
	// AccessToken is the newly issued access token.
	AccessToken  string `json:"access_token"`
	// RefreshToken is the token to present for the next refresh.
	RefreshToken string `json:"refresh_token"`
	// ClientId is decoded from the response but not read anywhere in this
	// file.
	ClientId     string `json:"client_id"`
	// ExpiresIn is handed to scheduleTokenRefresh, which interprets it as a
	// number of seconds.
	ExpiresIn    int    `json:"expires_in"`
}

// setTokensRequest is the shape of the JSON request body sent by clients
// wishing to set the state of the server (POST /set-tokens).
type setTokensRequest struct {
	// AccessToken replaces the access token held in application state.
	AccessToken  string `json:"access_token"`
	// RefreshToken replaces the refresh token held in application state.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is handed to scheduleTokenRefresh, which interprets it as a
	// number of seconds to wait before refreshing.
	ExpiresIn    int    `json:"expires_in"`
}

// state is our application state: the current Monzo access token and the
// refresh token that can be exchanged for the next pair.
//
// The fields are unexported, so encoding/json skips them (and ignores any
// struct tags placed on them); without help every state value would serialize
// as `{}`. MarshalJSON provides the intended wire format instead.
type state struct {
	accessToken  string
	refreshToken string
}

// MarshalJSON encodes the state as a JSON object with the keys "access_token"
// and "refresh_token". It has a value receiver so that it applies to both
// state and *state values passed to json.Marshal.
func (s state) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
	}{
		AccessToken:  s.accessToken,
		RefreshToken: s.refreshToken,
	})
}

// readMsg is a request for the current application state. The state-managing
// goroutine started in main replies by sending a copy of the state on sender.
type readMsg struct {
	sender chan state
}

// writeMsg is a request to replace the application state. The state-managing
// goroutine started in main overwrites its state with the state field,
// persists the tokens, and then sends true on sender as an acknowledgement.
type writeMsg struct {
	state  state
	sender chan bool
}

// channels bundles the two request channels through which the rest of the
// program talks to the state-managing goroutine: one for reads, one for
// writes.
type channels struct {
	reads  chan readMsg
	writes chan writeMsg
}

////////////////////////////////////////////////////////////////////////////////
// Top-level Definitions
////////////////////////////////////////////////////////////////////////////////

// chans holds the request channels used by getState and setState. Both
// channels are unbuffered, so a send blocks until the state-managing
// goroutine in main receives the message.
var chans = &channels{
	reads:  make(chan readMsg),
	writes: make(chan writeMsg),
}

// Configuration read from the environment when the package is initialized.
// os.Getenv returns the empty string for a variable that is not set; nothing
// here validates the values.
var (
	// OAuth client ID sent to Monzo when authorizing and refreshing.
	monzoClientId     = os.Getenv("monzo_client_id")
	// OAuth client secret sent to Monzo when authorizing and refreshing.
	monzoClientSecret = os.Getenv("monzo_client_secret")
	// First argument passed to kv.Get and kv.Set when reading and persisting
	// the tokens.
	storePath         = os.Getenv("store_path")
)

////////////////////////////////////////////////////////////////////////////////
// Utils
////////////////////////////////////////////////////////////////////////////////

// logTokens writes the access token and then the refresh token to the
// standard logger, one log line each. Intended for debugging only, since it
// prints the raw token values.
func logTokens(access string, refresh string) {
	lines := []struct {
		format string
		token  string
	}{
		{"Access: %s\n", access},
		{"Refresh: %s\n", refresh},
	}
	for _, line := range lines {
		log.Printf(line.format, line.token)
	}
}

// String renders the state as a multi-line, struct-like listing of both
// tokens. Because it is defined on *state, the fmt and log packages use it
// when a *state is printed.
func (s *state) String() string {
	const layout = "state{\n\taccessToken: \"%s\",\n\trefreshToken: \"%s\"\n}\n"
	return fmt.Sprintf(layout, s.accessToken, s.refreshToken)
}

// scheduleTokenRefresh waits `expiresIn` seconds and then exchanges
// `refreshToken` for a fresh pair of tokens via refreshTokens. The new tokens
// are logged and written to the application state with setState.
//
// The wait is a plain time.Sleep, so this call blocks for the whole duration;
// callers in this file start it with `go`.
func scheduleTokenRefresh(expiresIn int, refreshToken string) {
	wait := time.Duration(expiresIn) * time.Second
	// TODO(wpcarro): Consider adding a more human readable version that will
	// log the number of hours, minutes, etc. until the next refresh.
	log.Printf("Scheduling token refresh for %v\n", time.Now().Local().Add(wait))
	time.Sleep(wait)

	log.Println("Refreshing tokens now...")
	newAccess, newRefresh := refreshTokens(refreshToken)
	log.Println("Successfully refreshed tokens.")

	logTokens(newAccess, newRefresh)
	setState(newAccess, newRefresh)
}

// refreshTokens exchanges `refreshToken` (together with the configured client
// ID and secret) for a new access token and refresh token at Monzo's OAuth
// token endpoint. On success it also starts a goroutine that schedules the
// next refresh using the expiry and refresh token from the response.
//
// It returns the newly acquired access token and refresh token, in that
// order. The process is terminated with log.Fatal if the request fails, the
// response body is not valid JSON, or either token is missing from it.
func refreshTokens(refreshToken string) (string, string) {
	// TODO(wpcarro): Support retries with exponential backoff.
	res, err := http.PostForm("https://api.monzo.com/oauth2/token", url.Values{
		"grant_type":    {"refresh_token"},
		"client_id":     {monzoClientId},
		"client_secret": {monzoClientSecret},
		"refresh_token": {refreshToken},
	})
	// The error must be checked before res is touched: when the request
	// fails res is nil, so reading res.StatusCode first would panic with a
	// nil pointer dereference instead of reporting the real error.
	if err != nil {
		log.Fatal("The request to Monzo to refresh our access token failed.", err)
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		// TODO(wpcarro): Considering panicking here.
		utils.DebugResponse(res)
	}

	payload := &refreshTokenResponse{}
	err = json.NewDecoder(res.Body).Decode(payload)
	if err != nil {
		log.Fatal("Could not decode the JSON response from Monzo.", err)
	}

	// Interestingly, JSON decoding into the refreshTokenResponse can succeed
	// even if the decoder doesn't populate any of the fields in the
	// refreshTokenResponse struct. From what I read, it isn't possible to mark
	// these fields as required using an annotation, so this guard must suffice
	// for now.
	if payload.AccessToken == "" || payload.RefreshToken == "" {
		log.Fatal("JSON parsed correctly but failed to populate token fields.")
	}

	// Only schedule the follow-up refresh once the tokens are known to be
	// usable.
	go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)

	return payload.AccessToken, payload.RefreshToken
}

// persistTokens writes both tokens to the key-value store at storePath: the
// access token under "monzoAccessToken", then the refresh token under
// "monzoRefreshToken". Progress is logged before and after. Nothing returned
// by kv.Set is inspected here.
func persistTokens(access string, refresh string) {
	log.Println("Persisting tokens...")
	entries := []struct {
		key   string
		token string
	}{
		{"monzoAccessToken", access},
		{"monzoRefreshToken", refresh},
	}
	for _, entry := range entries {
		kv.Set(storePath, entry.key, entry.token)
	}
	log.Println("Successfully persisted tokens.")
}

// handleInterrupts blocks until the process receives SIGINT or SIGTERM. It
// then persists the current access and refresh tokens and terminates the
// process with exit status 0, so it never returns. main runs it in its own
// goroutine.
func handleInterrupts() {
	// signal.Notify does not block when delivering, so the channel needs a
	// buffer to avoid missing a signal that arrives before the receive below.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	// This function already runs on a dedicated goroutine, so it can wait for
	// the signal directly; no helper goroutine or completion channel is
	// needed.
	sig := <-sigs
	log.Printf("Received signal to shutdown. %v\n", sig)

	current := getState()
	persistTokens(current.accessToken, current.refreshToken)

	log.Println("Exiting...")
	os.Exit(0)
}

// Set `accessToken` and `refreshToken` on application state.
func setState(accessToken string, refreshToken string) {
	msg := writeMsg{state{accessToken, refreshToken}, make(chan bool)}
	chans.writes <- msg
	<-msg.sender
}

// getState returns a copy of the current application state. It sends a read
// request to the state-managing goroutine and blocks until the reply arrives.
func getState() state {
	reply := make(chan state)
	chans.reads <- readMsg{sender: reply}
	return <-reply
}

////////////////////////////////////////////////////////////////////////////////
// Main
////////////////////////////////////////////////////////////////////////////////

// main wires the token server together:
//
//  1. starts the goroutine that owns the application state and serves the
//     read/write requests sent by getState and setState,
//  2. loads cached tokens from the store, or runs client authorization when
//     they are absent, and schedules a token refresh,
//  3. starts the shutdown signal handler, and
//  4. serves the HTTP API on port 4242:
//     POST /refresh-tokens, POST /set-tokens and GET /state.
func main() {
	// Manage application state. This goroutine is the only code that touches
	// the state value; everything else goes through chans.
	go func() {
		current := &state{}
		for {
			select {
			case msg := <-chans.reads:
				log.Println("Reading from state...")
				log.Println(current)
				msg.sender <- *current
			case msg := <-chans.writes:
				log.Println("Writing to state.")
				log.Printf("Old: %s\n", current)
				*current = msg.state
				log.Printf("New: %s\n", current)
				// As an attempt to maintain consistency between application
				// state and persisted state, every time we write to the
				// application state, we will write to the store.
				persistTokens(current.accessToken, current.refreshToken)
				msg.sender <- true
			}
		}
	}()

	// Retrieve cached tokens from store.
	// NOTE(review): if kv.Get returns nil for a missing key, "%v" formats it
	// as "<nil>", which is non-empty and would skip the authorization branch
	// below -- confirm against kv.Get's behaviour.
	accessToken := fmt.Sprintf("%v", kv.Get(storePath, "monzoAccessToken"))
	refreshToken := fmt.Sprintf("%v", kv.Get(storePath, "monzoRefreshToken"))

	log.Println("Attempting to retrieve cached credentials...")
	logTokens(accessToken, refreshToken)

	if accessToken == "" || refreshToken == "" {
		log.Println("Cached credentials are absent. Authorizing client...")
		authCode := auth.GetAuthCode(monzoClientId)
		tokens := auth.GetTokensFromAuthCode(authCode, monzoClientId, monzoClientSecret)
		setState(tokens.AccessToken, tokens.RefreshToken)
		go scheduleTokenRefresh(tokens.ExpiresIn, tokens.RefreshToken)
	} else {
		setState(accessToken, refreshToken)
		// If we have tokens, they may be expiring soon. We don't know because
		// we aren't storing the expiration timestamp in the state or in the
		// store. Until we have that information, and to be safe, let's refresh
		// the tokens.
		go scheduleTokenRefresh(0, refreshToken)
	}

	// Gracefully handle shutdowns.
	go handleInterrupts()

	// Listen to inbound requests.
	fmt.Println("Listening on http://localhost:4242 ...")
	log.Fatal(http.ListenAndServe(":4242",
		http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			switch {
			case req.URL.Path == "/refresh-tokens" && req.Method == http.MethodPost:
				current := getState()
				go scheduleTokenRefresh(0, current.refreshToken)
				fmt.Fprintf(w, "Done.")

			case req.URL.Path == "/set-tokens" && req.Method == http.MethodPost:
				// Parse. A malformed body is the client's mistake: reject the
				// request instead of terminating the whole server.
				payload := &setTokensRequest{}
				if err := json.NewDecoder(req.Body).Decode(payload); err != nil {
					log.Printf("Could not decode the user's JSON request. %v\n", err)
					http.Error(w, "Could not decode the JSON request body.", http.StatusBadRequest)
					return
				}

				// Update application state
				setState(payload.AccessToken, payload.RefreshToken)

				// Refresh tokens
				go scheduleTokenRefresh(payload.ExpiresIn, payload.RefreshToken)

				// Ack
				fmt.Fprintf(w, "Done.")

			case req.URL.Path == "/state" && req.Method == http.MethodGet:
				// TODO(wpcarro): Ensure that this returns serialized state.
				// encoding/json skips unexported struct fields, so the tokens
				// only appear in the output if state implements
				// json.Marshaler.
				payload, err := json.Marshal(getState())
				if err != nil {
					log.Printf("Could not encode the state as JSON. %v\n", err)
					http.Error(w, "Could not encode the state.", http.StatusInternalServerError)
					return
				}
				w.Header().Set("Content-type", "application/json")
				io.WriteString(w, string(payload))

			default:
				// Tell the client the route does not exist rather than
				// replying with an empty 200.
				log.Printf("Unhandled request: %v\n", *req)
				http.NotFound(w, req)
			}
		})))
}