about summary refs log tree commit diff
path: root/website/sandbox/learnpianochords/src
diff options
context:
space:
mode:
Diffstat (limited to 'website/sandbox/learnpianochords/src')
-rw-r--r--website/sandbox/learnpianochords/src/server/Fixtures.hs4
-rw-r--r--website/sandbox/learnpianochords/src/server/GoogleSignIn.hs35
-rw-r--r--website/sandbox/learnpianochords/src/server/Spec.hs12
3 files changed, 42 insertions, 9 deletions
diff --git a/website/sandbox/learnpianochords/src/server/Fixtures.hs b/website/sandbox/learnpianochords/src/server/Fixtures.hs
index 475553643319..ea7e0301ec4f 100644
--- a/website/sandbox/learnpianochords/src/server/Fixtures.hs
+++ b/website/sandbox/learnpianochords/src/server/Fixtures.hs
@@ -16,6 +16,7 @@ import qualified TestUtils
 data JWTFields = JWTFields
   { overwriteSigner :: Signer
   , overwriteAuds :: [StringOrURI]
+  , overwriteIss :: StringOrURI
   }
 
 defaultJWTFields :: JWTFields
@@ -23,6 +24,7 @@ defaultJWTFields = JWTFields
   { overwriteSigner = hmacSecret "secret"
   , overwriteAuds = ["771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"]
                     |> fmap TestUtils.unsafeStringOrURI
+  , overwriteIss = TestUtils.unsafeStringOrURI "accounts.google.com"
   }
 
 googleJWT :: JWTFields -> GoogleSignIn.EncodedJWT
@@ -43,7 +45,7 @@ googleJWT JWTFields{..} =
 
     claimSet :: JWTClaimsSet
     claimSet = JWTClaimsSet
-      { iss = stringOrURI "accounts.google.com"
+      { iss = Just overwriteIss
       , sub = stringOrURI "114079822315085727057"
       , aud = overwriteAuds |> Right |> Just
       -- TODO: Replace date creation with a human-readable date constructor.
diff --git a/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs b/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
index 72fa608c47b4..f138f2b615c8 100644
--- a/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
+++ b/website/sandbox/learnpianochords/src/server/GoogleSignIn.hs
@@ -8,6 +8,7 @@ import Web.JWT
 import Utils
 
 import qualified Network.HTTP.Simple as HTTP
+import qualified Data.Text as Text
 --------------------------------------------------------------------------------
 
 newtype EncodedJWT = EncodedJWT Text
@@ -18,7 +19,9 @@ data ValidationResult
   | DecodeError
   | GoogleSaysInvalid Text
   | NoMatchingClientIDs [StringOrURI]
-  | ClientIDParseFailure Text
+  | WrongIssuer StringOrURI
+  | StringOrURIParseFailure Text
+  | MissingIssuer
   deriving (Eq, Show)
 
 -- | Returns True when the supplied `jwt` meets the following criteria:
@@ -49,15 +52,31 @@ jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do
   where
     continue :: JWT UnverifiedJWT -> IO ValidationResult
     continue jwt = do
-      let audValues = jwt |> claims |> auds
-          mClientID = stringOrURI "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
-      case mClientID of
-        Nothing ->
-          pure $ ClientIDParseFailure "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
-        Just clientID ->
+      let audValues :: [StringOrURI]
+          audValues = jwt |> claims |> auds
+          expectedClientID :: Text
+          expectedClientID = "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"
+          expectedIssuers :: [Text]
+          expectedIssuers = [ "accounts.google.com"
+                            , "https://accounts.google.com"
+                            ]
+          mExpectedClientID :: Maybe StringOrURI
+          mExpectedClientID = stringOrURI expectedClientID
+          mExpectedIssuers :: Maybe [StringOrURI]
+          mExpectedIssuers = expectedIssuers |> traverse stringOrURI
+      case (mExpectedClientID, mExpectedIssuers) of
+        (Nothing, _) -> pure $ StringOrURIParseFailure expectedClientID
+        (_, Nothing) -> pure $ StringOrURIParseFailure (Text.unwords expectedIssuers)
+        (Just clientID, Just parsedIssuers) ->
           -- TODO: Prefer reading clientID from a config. I'm thinking of the
           -- AppContext type having my Configuration
           if not $ clientID `elem` audValues then
             pure $ NoMatchingClientIDs audValues
           else
-            pure Valid
+            case jwt |> claims |> iss of
+              Nothing -> pure MissingIssuer
+              Just jwtIssuer ->
+                if not $ jwtIssuer `elem` parsedIssuers then
+                  pure $ WrongIssuer jwtIssuer
+                else
+                  pure Valid
diff --git a/website/sandbox/learnpianochords/src/server/Spec.hs b/website/sandbox/learnpianochords/src/server/Spec.hs
index 20c7b96b952f..96f10a9c4332 100644
--- a/website/sandbox/learnpianochords/src/server/Spec.hs
+++ b/website/sandbox/learnpianochords/src/server/Spec.hs
@@ -32,3 +32,15 @@ main = hspec $ do
             encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds }
                          |> F.googleJWT
         jwtIsValid' encodedJWT `shouldReturn` Valid
+
+      it "returns validation error when one of the iss field doesn't match accounts.google.com or https://accounts.google.com" $ do
+        let erroneousIssuer = TestUtils.unsafeStringOrURI "not-accounts.google.com"
+            encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer }
+                         |> F.googleJWT
+        jwtIsValid' encodedJWT `shouldReturn` WrongIssuer erroneousIssuer
+
+      it "returns validation success when the iss field matches accounts.google.com or https://accounts.google.com" $ do
+        let erroneousIssuer = TestUtils.unsafeStringOrURI "https://accounts.google.com"
+            encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer }
+                         |> F.googleJWT
+        jwtIsValid' encodedJWT `shouldReturn` Valid