about summary refs log tree commit diff
path: root/website/sandbox/learnpianochords/src
diff options
context:
space:
mode:
Diffstat (limited to 'website/sandbox/learnpianochords/src')
-rw-r--r--website/sandbox/learnpianochords/src/server/Fixtures.hs13
-rw-r--r--website/sandbox/learnpianochords/src/server/GoogleSignIn.hs53
-rw-r--r--website/sandbox/learnpianochords/src/server/Spec.hs30
-rw-r--r--website/sandbox/learnpianochords/src/server/TestUtils.hs12
4 files changed, 87 insertions, 21 deletions
diff --git a/website/sandbox/learnpianochords/src/server/Fixtures.hs b/website/sandbox/learnpianochords/src/server/Fixtures.hs
index 93599c3e884e..475553643319 100644
--- a/website/sandbox/learnpianochords/src/server/Fixtures.hs
+++ b/website/sandbox/learnpianochords/src/server/Fixtures.hs
@@ -7,25 +7,28 @@ import Web.JWT
 import Utils
 
 import qualified Data.Map as Map
+import qualified GoogleSignIn
+import qualified TestUtils
 --------------------------------------------------------------------------------
 
 -- | These are the JWT fields that I'd like to overwrite in the `googleJWT`
 -- function.
 data JWTFields = JWTFields
   { overwriteSigner :: Signer
-  , overwriteAud :: Maybe StringOrURI
+  , overwriteAuds :: [StringOrURI]
   }
 
 defaultJWTFields :: JWTFields
 defaultJWTFields = JWTFields
   { overwriteSigner = hmacSecret "secret"
-  , overwriteAud = stringOrURI "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+  , overwriteAuds = ["771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"]
+                    |> fmap TestUtils.unsafeStringOrURI
   }
 
-googleJWT :: JWTFields -> Maybe (JWT UnverifiedJWT)
+googleJWT :: JWTFields -> GoogleSignIn.EncodedJWT
 googleJWT JWTFields{..} =
   encodeSigned signer jwtHeader claimSet
-  |> decode
+  |> GoogleSignIn.EncodedJWT
   where
     signer :: Signer
     signer = overwriteSigner
@@ -42,7 +45,7 @@ googleJWT JWTFields{..} =
     claimSet = JWTClaimsSet
       { iss = stringOrURI "accounts.google.com"
       , sub = stringOrURI "114079822315085727057"
-      , aud = overwriteAud |> fmap Left
+      , aud = overwriteAuds |> Right |> Just
       -- TODO: Replace date creation with a human-readable date constructor.
       , Web.JWT.exp = numericDate 1596756453
       , nbf = Nothing
diff --git a/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs b/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
index 1ea252eea5ae..72fa608c47b4 100644
--- a/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
+++ b/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
@@ -1,14 +1,63 @@
+{-# LANGUAGE OverloadedStrings #-}
 --------------------------------------------------------------------------------
 module GoogleSignIn where
 --------------------------------------------------------------------------------
+import Data.String.Conversions (cs)
+import Data.Text (Text)
 import Web.JWT
+import Utils
+
+import qualified Network.HTTP.Simple as HTTP
 --------------------------------------------------------------------------------
 
+newtype EncodedJWT = EncodedJWT Text
+
+-- | Some of the errors that a JWT
+data ValidationResult
+  = Valid
+  | DecodeError
+  | GoogleSaysInvalid Text
+  | NoMatchingClientIDs [StringOrURI]
+  | ClientIDParseFailure Text
+  deriving (Eq, Show)
+
 -- | Returns True when the supplied `jwt` meets the following criteria:
 -- * The token has been signed by Google
 -- * The value of `aud` matches my Google client's ID
 -- * The value of `iss` matches is "accounts.google.com" or
 --   "https://accounts.google.com"
 -- * The `exp` time has not passed
-jwtIsValid :: JWT UnverifiedJWT -> IO Bool
-jwtIsValid jwt = pure False
+--
+-- Set `skipHTTP` to `True` to avoid making the network request for testing.
+jwtIsValid :: Bool
+           -> EncodedJWT
+           -> IO ValidationResult
+jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do
+  case encodedJWT |> decode of
+    Nothing -> pure DecodeError
+    Just jwt -> do
+      if skipHTTP then
+        continue jwt
+      else do
+        let request = "https://oauth2.googleapis.com/tokeninfo"
+                      |> HTTP.setRequestQueryString [ ( "id_token", Just (cs encodedJWT) ) ]
+        res <- HTTP.httpLBS request
+        if HTTP.getResponseStatusCode res /= 200 then
+          pure $ GoogleSaysInvalid (res |> HTTP.getResponseBody |> cs)
+        else
+          continue jwt
+  where
+    continue :: JWT UnverifiedJWT -> IO ValidationResult
+    continue jwt = do
+      let audValues = jwt |> claims |> auds
+          mClientID = stringOrURI "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+      case mClientID of
+        Nothing ->
+          pure $ ClientIDParseFailure "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+        Just clientID ->
+          -- TODO: Prefer reading clientID from a config. I'm thinking of the
+          -- AppContext type having my Configuration
+          if not $ clientID `elem` audValues then
+            pure $ NoMatchingClientIDs audValues
+          else
+            pure Valid
diff --git a/website/sandbox/learnpianochords/src/server/Spec.hs b/website/sandbox/learnpianochords/src/server/Spec.hs
index 1f9b9bb4bf9c..6c683cbbf2a7 100644
--- a/website/sandbox/learnpianochords/src/server/Spec.hs
+++ b/website/sandbox/learnpianochords/src/server/Spec.hs
@@ -3,27 +3,29 @@
 module Spec where
 --------------------------------------------------------------------------------
 import Test.Hspec
-import Web.JWT
 import Utils
+import GoogleSignIn (ValidationResult(..))
 
 import qualified GoogleSignIn
 import qualified Fixtures as F
+import qualified TestUtils
 --------------------------------------------------------------------------------
 
 main :: IO ()
 main = hspec $ do
-  describe "GoogleSignIn" $ do
+  describe "GoogleSignIn" $
     describe "jwtIsValid" $ do
-      it "returns false when the signature is invalid" $ do
-        let mJWT = F.defaultJWTFields { F.overwriteSigner = hmacSecret "wrong" }
-                   |> F.googleJWT
-        case mJWT of
-          Nothing  -> True `shouldBe` False
-          Just jwt -> GoogleSignIn.jwtIsValid jwt `shouldReturn` False
+      let jwtIsValid' = GoogleSignIn.jwtIsValid True
+      it "returns validation error when the aud field doesn't match my client ID" $ do
+        let auds = ["wrong-client-id"]
+                   |> fmap TestUtils.unsafeStringOrURI
+            encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds }
+                         |> F.googleJWT
+        jwtIsValid' encodedJWT `shouldReturn` NoMatchingClientIDs auds
 
-      it "returns false when the aud field doesn't match my client ID" $ do
-        let mJWT = F.defaultJWTFields { F.overwriteAud = stringOrURI "wrong" }
-                  |> F.googleJWT
-        case mJWT of
-          Nothing  -> True `shouldBe` False
-          Just jwt -> GoogleSignIn.jwtIsValid jwt `shouldReturn` False
+      it "returns validation success when one of the aud fields matches my client ID" $ do
+        let auds = ["wrong-client-id", "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"]
+                   |> fmap TestUtils.unsafeStringOrURI
+            encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds }
+                         |> F.googleJWT
+        jwtIsValid' encodedJWT `shouldReturn` Valid
diff --git a/website/sandbox/learnpianochords/src/server/TestUtils.hs b/website/sandbox/learnpianochords/src/server/TestUtils.hs
new file mode 100644
index 000000000000..c586f7f219ba
--- /dev/null
+++ b/website/sandbox/learnpianochords/src/server/TestUtils.hs
@@ -0,0 +1,12 @@
+--------------------------------------------------------------------------------
+module TestUtils where
+--------------------------------------------------------------------------------
+import Web.JWT
+import Data.String.Conversions (cs)
+--------------------------------------------------------------------------------
+
+unsafeStringOrURI :: String -> StringOrURI
+unsafeStringOrURI x =
+  case stringOrURI (cs x) of
+    Nothing -> error $ "Failed to convert to StringOrURI: " ++ x
+    Just x  -> x