about summary refs log tree commit diff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs39
1 files changed, 27 insertions, 12 deletions
diff --git a/src/lib.rs b/src/lib.rs
index e62600e26b2f..135b1df0f9f2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -74,7 +74,7 @@ extern crate openssl;
 extern crate serde;
 extern crate serde_json;
 
-use base64::{decode_config, URL_SAFE};
+use base64::{URL_SAFE_NO_PAD, Config, DecodeError};
 use openssl::bn::BigNum;
 use openssl::error::ErrorStack;
 use openssl::hash::MessageDigest;
@@ -88,6 +88,17 @@ use std::time::{UNIX_EPOCH, Duration, SystemTime};
 #[cfg(test)]
 mod tests;
 
+
+/// URL-safe character set without padding that allows trailing bits,
+/// which appear in some JWT implementations.
+///
+/// Note: The functions on `base64::Config` are not marked `const`,
+/// and the constructors are not exported, which is why this is
+/// implemented as a function.
+fn jwt_forgiving() -> Config {
+    URL_SAFE_NO_PAD.decode_allow_trailing_bits(true)
+}
+
 /// JWT algorithm used. The only supported algorithm is currently
 /// RS256.
 #[derive(Clone, Deserialize, Debug)]
@@ -179,8 +190,11 @@ pub enum Validation {
 /// Possible results of a token validation.
 #[derive(Debug)]
 pub enum ValidationError {
-    /// Token was malformed (various possible reasons!)
-    MalformedJWT,
+    /// Invalid number of token components (not a JWT?)
+    InvalidComponents,
+
+    /// Token segments had invalid base64-encoding.
+    InvalidBase64(DecodeError),
 
     /// Decoding of the provided JWK failed.
     InvalidJWK,
@@ -211,6 +225,10 @@ impl From<serde_json::Error> for ValidationError {
     fn from(err: serde_json::Error) -> Self { ValidationError::JSON(err) }
 }
 
+impl From<DecodeError> for ValidationError {
+    fn from(err: DecodeError) -> Self { ValidationError::InvalidBase64(err) }
+}
+
 /// Attempt to extract the `kid`-claim out of a JWT's header claims.
 ///
 /// This function is normally used when a token provider has multiple
@@ -224,7 +242,7 @@ pub fn token_kid(token: &str) -> JWTResult<Option<String>> {
     // dismissing the rest.
     let parts: Vec<&str> = token.splitn(2, '.').collect();
     if parts.len() != 2 {
-        return Err(ValidationError::MalformedJWT);
+        return Err(ValidationError::InvalidComponents);
     }
 
     // Decode only the first part of the token into a specialised
@@ -262,7 +280,7 @@ pub fn validate(token: &str,
     if parts.len() != 3 {
         // This is unlikely considering that validation has already
         // been performed at this point, but better safe than sorry.
-        return Err(ValidationError::MalformedJWT)
+        return Err(ValidationError::InvalidComponents)
     }
 
     // Perform claim validations before constructing the valid token:
@@ -284,7 +302,7 @@ pub fn validate(token: &str,
 /// Decode a single key fragment (base64-url encoded integer) to an
 /// OpenSSL BigNum.
 fn decode_fragment(fragment: &str) -> JWTResult<BigNum> {
-    let bytes = decode_config(fragment, URL_SAFE)
+    let bytes = base64::decode_config(fragment, jwt_forgiving())
         .map_err(|_| ValidationError::InvalidJWK)?;
 
     BigNum::from_slice(&bytes).map_err(Into::into)
@@ -301,9 +319,7 @@ fn public_key_from_jwk(jwk: &JWK) -> JWTResult<Rsa<Public>> {
 /// Decode a base64-URL encoded string and deserialise the resulting
 /// JSON.
 fn deserialize_part<T: DeserializeOwned>(part: &str) -> JWTResult<T> {
-    let json = base64::decode_config(part, URL_SAFE)
-        .map_err(|_| ValidationError::MalformedJWT)?;
-
+    let json = base64::decode_config(part, jwt_forgiving())?;
     serde_json::from_slice(&json).map_err(Into::into)
 }
 
@@ -321,7 +337,7 @@ fn validate_jwt_signature(jwt: &JWT, key: Rsa<Public>) -> JWTResult<()> {
     // splitting them is unnecessary.
     let token_parts: Vec<&str> = jwt.0.rsplitn(2, '.').collect();
     if token_parts.len() != 2 {
-        return Err(ValidationError::MalformedJWT);
+        return Err(ValidationError::InvalidComponents);
     }
 
     // Second element of the vector will be the signed payload.
@@ -329,8 +345,7 @@ fn validate_jwt_signature(jwt: &JWT, key: Rsa<Public>) -> JWTResult<()> {
 
     // First element of the vector will be the (encoded) signature.
     let sig_b64 = token_parts[0];
-    let sig = base64::decode_config(sig_b64, URL_SAFE)
-        .map_err(|_| ValidationError::MalformedJWT)?;
+    let sig = base64::decode_config(sig_b64, jwt_forgiving())?;
 
     // Verify signature by inserting the payload data and checking it
     // against the decoded signature.