about summary refs log tree commit diff
path: root/website
diff options
context:
space:
mode:
authorWilliam Carroll <wpcarro@gmail.com>2020-08-08T12·44+0100
committerWilliam Carroll <wpcarro@gmail.com>2020-08-08T12·44+0100
commitd34b146702476f46bcca7d362e56f46227863f1b (patch)
tree6ad489c4509172780f578df9d66602a1c6a6272f /website
parent926d8e643e9ffb7d5f5608793d35381742675073 (diff)
Tests valid and invalid JWTs for the "aud" field
Test that when the JWT contains the client ID for my Google app, the JWT is
valid, and when it doesn't, it's invalid.
Diffstat (limited to 'website')
-rw-r--r--website/sandbox/learnpianochords/src/server/Fixtures.hs13
-rw-r--r--website/sandbox/learnpianochords/src/server/GoogleSignIn.hs53
-rw-r--r--website/sandbox/learnpianochords/src/server/Spec.hs30
-rw-r--r--website/sandbox/learnpianochords/src/server/TestUtils.hs12
4 files changed, 87 insertions, 21 deletions
diff --git a/website/sandbox/learnpianochords/src/server/Fixtures.hs b/website/sandbox/learnpianochords/src/server/Fixtures.hs
index 93599c3e884e..475553643319 100644
--- a/website/sandbox/learnpianochords/src/server/Fixtures.hs
+++ b/website/sandbox/learnpianochords/src/server/Fixtures.hs
@@ -7,25 +7,28 @@ import Web.JWT
 import Utils
 
 import qualified Data.Map as Map
+import qualified GoogleSignIn
+import qualified TestUtils
 --------------------------------------------------------------------------------
 
 -- | These are the JWT fields that I'd like to overwrite in the `googleJWT`
 -- function.
 data JWTFields = JWTFields
   { overwriteSigner :: Signer
-  , overwriteAud :: Maybe StringOrURI
+  , overwriteAuds :: [StringOrURI]
   }
 
 defaultJWTFields :: JWTFields
 defaultJWTFields = JWTFields
   { overwriteSigner = hmacSecret "secret"
-  , overwriteAud = stringOrURI "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+  , overwriteAuds = ["771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"]
+                    |> fmap TestUtils.unsafeStringOrURI
   }
 
-googleJWT :: JWTFields -> Maybe (JWT UnverifiedJWT)
+googleJWT :: JWTFields -> GoogleSignIn.EncodedJWT
 googleJWT JWTFields{..} =
   encodeSigned signer jwtHeader claimSet
-  |> decode
+  |> GoogleSignIn.EncodedJWT
   where
     signer :: Signer
     signer = overwriteSigner
@@ -42,7 +45,7 @@ googleJWT JWTFields{..} =
     claimSet = JWTClaimsSet
       { iss = stringOrURI "accounts.google.com"
       , sub = stringOrURI "114079822315085727057"
-      , aud = overwriteAud |> fmap Left
+      , aud = overwriteAuds |> Right |> Just
       -- TODO: Replace date creation with a human-readable date constructor.
       , Web.JWT.exp = numericDate 1596756453
       , nbf = Nothing
diff --git a/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs b/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
index 1ea252eea5ae..72fa608c47b4 100644
--- a/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
+++ b/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
@@ -1,14 +1,63 @@
+{-# LANGUAGE OverloadedStrings #-}
 --------------------------------------------------------------------------------
 module GoogleSignIn where
 --------------------------------------------------------------------------------
+import Data.String.Conversions (cs)
+import Data.Text (Text)
 import Web.JWT
+import Utils
+
+import qualified Network.HTTP.Simple as HTTP
 --------------------------------------------------------------------------------
 
+newtype EncodedJWT = EncodedJWT Text
+
+-- | Some of the errors that a JWT
+data ValidationResult
+  = Valid
+  | DecodeError
+  | GoogleSaysInvalid Text
+  | NoMatchingClientIDs [StringOrURI]
+  | ClientIDParseFailure Text
+  deriving (Eq, Show)
+
 -- | Returns True when the supplied `jwt` meets the following criteria:
 -- * The token has been signed by Google
 -- * The value of `aud` matches my Google client's ID
 -- * The value of `iss` matches is "accounts.google.com" or
 --   "https://accounts.google.com"
 -- * The `exp` time has not passed
-jwtIsValid :: JWT UnverifiedJWT -> IO Bool
-jwtIsValid jwt = pure False
+--
+-- Set `skipHTTP` to `True` to avoid making the network request for testing.
+jwtIsValid :: Bool
+           -> EncodedJWT
+           -> IO ValidationResult
+jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do
+  case encodedJWT |> decode of
+    Nothing -> pure DecodeError
+    Just jwt -> do
+      if skipHTTP then
+        continue jwt
+      else do
+        let request = "https://oauth2.googleapis.com/tokeninfo"
+                      |> HTTP.setRequestQueryString [ ( "id_token", Just (cs encodedJWT) ) ]
+        res <- HTTP.httpLBS request
+        if HTTP.getResponseStatusCode res /= 200 then
+          pure $ GoogleSaysInvalid (res |> HTTP.getResponseBody |> cs)
+        else
+          continue jwt
+  where
+    continue :: JWT UnverifiedJWT -> IO ValidationResult
+    continue jwt = do
+      let audValues = jwt |> claims |> auds
+          mClientID = stringOrURI "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+      case mClientID of
+        Nothing ->
+          pure $ ClientIDParseFailure "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+        Just clientID ->
+          -- TODO: Prefer reading clientID from a config. I'm thinking of the
+          -- AppContext type having my Configuration
+          if not $ clientID `elem` audValues then
+            pure $ NoMatchingClientIDs audValues
+          else
+            pure Valid
diff --git a/website/sandbox/learnpianochords/src/server/Spec.hs b/website/sandbox/learnpianochords/src/server/Spec.hs
index 1f9b9bb4bf9c..6c683cbbf2a7 100644
--- a/website/sandbox/learnpianochords/src/server/Spec.hs
+++ b/website/sandbox/learnpianochords/src/server/Spec.hs
@@ -3,27 +3,29 @@
 module Spec where
 --------------------------------------------------------------------------------
 import Test.Hspec
-import Web.JWT
 import Utils
+import GoogleSignIn (ValidationResult(..))
 
 import qualified GoogleSignIn
 import qualified Fixtures as F
+import qualified TestUtils
 --------------------------------------------------------------------------------
 
 main :: IO ()
 main = hspec $ do
-  describe "GoogleSignIn" $ do
+  describe "GoogleSignIn" $
     describe "jwtIsValid" $ do
-      it "returns false when the signature is invalid" $ do
-        let mJWT = F.defaultJWTFields { F.overwriteSigner = hmacSecret "wrong" }
-                   |> F.googleJWT
-        case mJWT of
-          Nothing  -> True `shouldBe` False
-          Just jwt -> GoogleSignIn.jwtIsValid jwt `shouldReturn` False
+      let jwtIsValid' = GoogleSignIn.jwtIsValid True
+      it "returns validation error when the aud field doesn't match my client ID" $ do
+        let auds = ["wrong-client-id"]
+                   |> fmap TestUtils.unsafeStringOrURI
+            encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds }
+                         |> F.googleJWT
+        jwtIsValid' encodedJWT `shouldReturn` NoMatchingClientIDs auds
 
-      it "returns false when the aud field doesn't match my client ID" $ do
-        let mJWT = F.defaultJWTFields { F.overwriteAud = stringOrURI "wrong" }
-                  |> F.googleJWT
-        case mJWT of
-          Nothing  -> True `shouldBe` False
-          Just jwt -> GoogleSignIn.jwtIsValid jwt `shouldReturn` False
+      it "returns validation success when one of the aud fields matches my client ID" $ do
+        let auds = ["wrong-client-id", "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"]
+                   |> fmap TestUtils.unsafeStringOrURI
+            encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds }
+                         |> F.googleJWT
+        jwtIsValid' encodedJWT `shouldReturn` Valid
diff --git a/website/sandbox/learnpianochords/src/server/TestUtils.hs b/website/sandbox/learnpianochords/src/server/TestUtils.hs
new file mode 100644
index 000000000000..c586f7f219ba
--- /dev/null
+++ b/website/sandbox/learnpianochords/src/server/TestUtils.hs
@@ -0,0 +1,12 @@
+--------------------------------------------------------------------------------
+module TestUtils where
+--------------------------------------------------------------------------------
+import Web.JWT
+import Data.String.Conversions (cs)
+--------------------------------------------------------------------------------
+
+unsafeStringOrURI :: String -> StringOrURI
+unsafeStringOrURI x =
+  case stringOrURI (cs x) of
+    Nothing -> error $ "Failed to convert to StringOrURI: " ++ x
+    Just x  -> x