about summary refs log tree commit diff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs33
1 files changed, 23 insertions, 10 deletions
diff --git a/src/lib.rs b/src/lib.rs
index ba856c176615..523a9898d7a6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -55,11 +55,13 @@ extern crate serde_json;
 
 use base64::{decode_config, URL_SAFE};
 use openssl::bn::BigNum;
+use openssl::error::ErrorStack;
+use openssl::hash::MessageDigest;
 use openssl::pkey::{Public, PKey};
 use openssl::rsa::Rsa;
 use openssl::sign::Verifier;
-use openssl::hash::MessageDigest;
-use openssl::error::ErrorStack;
+use serde::de::DeserializeOwned;
+use serde_json::Value;
 
 #[cfg(test)]
 mod tests;
@@ -135,6 +137,9 @@ pub enum ValidationError {
     /// a more specific error variant could not be constructed.
     OpenSSL(ErrorStack),
 
+    /// JSON decoding into a provided type failed.
+    JSON(serde_json::Error),
+
     /// One or more claim validations failed.
     // TODO: Provide reasons?
     InvalidClaims,
@@ -146,6 +151,10 @@ impl From<ErrorStack> for ValidationError {
     fn from(err: ErrorStack) -> Self { ValidationError::OpenSSL(err) }
 }
 
+impl From<serde_json::Error> for ValidationError {
+    fn from(err: serde_json::Error) -> Self { ValidationError::JSON(err) }
+}
+
 /// Attempt to extract the `kid`-claim out of a JWT's header claims.
 ///
 /// This function is normally used when a token provider has multiple
@@ -162,19 +171,14 @@ pub fn token_kid(jwt: &JWT) -> JWTResult<Option<String>> {
         return Err(ValidationError::MalformedJWT);
     }
 
-    // The token components are individually base64 decoded, decode
-    // just the first part and deserialise it into the expected
-    // representation.
-    let headers_json = base64::decode_config(parts[0], URL_SAFE)
-        .map_err(|_| ValidationError::MalformedJWT)?;
-
+    // Decode only the first part of the token into a specialised
+    // representation:
     #[derive(Deserialize)]
     struct KidOnly {
         kid: Option<String>,
     }
 
-    let kid_only: KidOnly = serde_json::from_slice(&headers_json)
-        .map_err(|_| ValidationError::MalformedJWT)?;
+    let kid_only: KidOnly = deserialize_part(parts[0])?;
 
     Ok(kid_only.kid)
 }
@@ -212,6 +216,15 @@ fn public_key_from_jwk(jwk: &JWK) -> JWTResult<Rsa<Public>> {
     Rsa::from_public_components(jwk_n, jwk_e).map_err(Into::into)
 }
 
+/// Decode a base64-URL encoded string and deserialise the resulting
+/// JSON.
+fn deserialize_part<T: DeserializeOwned>(part: &str) -> JWTResult<T> {
+    let json = base64::decode_config(part, URL_SAFE)
+        .map_err(|_| ValidationError::MalformedJWT)?;
+
+    serde_json::from_slice(&json).map_err(Into::into)
+}
+
 /// Validate the signature on a JWT using a provided public key.
 ///
 /// A JWT is made up of three components (headers, claims, signature)