diff options
| -rw-r--r-- | crypto/rsa.c | 105 | ||||
| -rw-r--r-- | crypto/rsa_helper.c | 111 | ||||
| -rw-r--r-- | include/crypto/internal/rsa.h | 22 | 
3 files changed, 135 insertions, 103 deletions
| diff --git a/crypto/rsa.c b/crypto/rsa.c index 77d737f52147..dc692d43b666 100644 --- a/crypto/rsa.c +++ b/crypto/rsa.c @@ -10,16 +10,23 @@   */  #include <linux/module.h> +#include <linux/mpi.h>  #include <crypto/internal/rsa.h>  #include <crypto/internal/akcipher.h>  #include <crypto/akcipher.h>  #include <crypto/algapi.h> +struct rsa_mpi_key { +	MPI n; +	MPI e; +	MPI d; +}; +  /*   * RSAEP function [RFC3447 sec 5.1.1]   * c = m^e mod n;   */ -static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m) +static int _rsa_enc(const struct rsa_mpi_key *key, MPI c, MPI m)  {  	/* (1) Validate 0 <= m < n */  	if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0) @@ -33,7 +40,7 @@ static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m)   * RSADP function [RFC3447 sec 5.1.2]   * m = c^d mod n;   */ -static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c) +static int _rsa_dec(const struct rsa_mpi_key *key, MPI m, MPI c)  {  	/* (1) Validate 0 <= c < n */  	if (mpi_cmp_ui(c, 0) < 0 || mpi_cmp(c, key->n) >= 0) @@ -47,7 +54,7 @@ static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c)   * RSASP1 function [RFC3447 sec 5.2.1]   * s = m^d mod n   */ -static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m) +static int _rsa_sign(const struct rsa_mpi_key *key, MPI s, MPI m)  {  	/* (1) Validate 0 <= m < n */  	if (mpi_cmp_ui(m, 0) < 0 || mpi_cmp(m, key->n) >= 0) @@ -61,7 +68,7 @@ static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m)   * RSAVP1 function [RFC3447 sec 5.2.2]   * m = s^e mod n;   */ -static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s) +static int _rsa_verify(const struct rsa_mpi_key *key, MPI m, MPI s)  {  	/* (1) Validate 0 <= s < n */  	if (mpi_cmp_ui(s, 0) < 0 || mpi_cmp(s, key->n) >= 0) @@ -71,7 +78,7 @@ static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s)  	return mpi_powm(m, s, key->e, key->n);  } -static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm) +static inline struct rsa_mpi_key *rsa_get_key(struct crypto_akcipher *tfm)  {  	return akcipher_tfm_ctx(tfm);  } @@ -79,7 +86,7 @@ static inline struct rsa_key *rsa_get_key(struct crypto_akcipher *tfm)  static int rsa_enc(struct akcipher_request *req)  {  	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); -	const struct rsa_key *pkey = rsa_get_key(tfm); +	const struct rsa_mpi_key *pkey = rsa_get_key(tfm);  	MPI m, c = mpi_alloc(0);  	int ret = 0;  	int sign; @@ -118,7 +125,7 @@ err_free_c:  static int rsa_dec(struct akcipher_request *req)  {  	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); -	const struct rsa_key *pkey = rsa_get_key(tfm); +	const struct rsa_mpi_key *pkey = rsa_get_key(tfm);  	MPI c, m = mpi_alloc(0);  	int ret = 0;  	int sign; @@ -156,7 +163,7 @@ err_free_m:  static int rsa_sign(struct akcipher_request *req)  {  	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); -	const struct rsa_key *pkey = rsa_get_key(tfm); +	const struct rsa_mpi_key *pkey = rsa_get_key(tfm);  	MPI m, s = mpi_alloc(0);  	int ret = 0;  	int sign; @@ -195,7 +202,7 @@ err_free_s:  static int rsa_verify(struct akcipher_request *req)  {  	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req); -	const struct rsa_key *pkey = rsa_get_key(tfm); +	const struct rsa_mpi_key *pkey = rsa_get_key(tfm);  	MPI s, m = mpi_alloc(0);  	int ret = 0;  	int sign; @@ -233,6 +240,16 @@ err_free_m:  	return ret;  } +static void rsa_free_mpi_key(struct rsa_mpi_key *key) +{ +	mpi_free(key->d); +	mpi_free(key->e); +	mpi_free(key->n); +	key->d = NULL; +	key->e = NULL; +	key->n = NULL; +} +  static int rsa_check_key_length(unsigned int len)  {  	switch (len) { @@ -251,49 +268,87 @@ static int rsa_check_key_length(unsigned int len)  static int rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,  			   unsigned int keylen)  { -	struct rsa_key *pkey = akcipher_tfm_ctx(tfm); +	struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm); +	struct rsa_key raw_key = {0};  	int ret; -	ret = rsa_parse_pub_key(pkey, key, keylen); +	/* Free the old MPI key if any */ +	rsa_free_mpi_key(mpi_key); + +	ret = rsa_parse_pub_key(&raw_key, key, keylen);  	if (ret)  		return ret; -	if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) { -		rsa_free_key(pkey); -		ret = -EINVAL; +	mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz); +	if (!mpi_key->e) +		goto err; + +	mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz); +	if (!mpi_key->n) +		goto err; + +	if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) { +		rsa_free_mpi_key(mpi_key); +		return -EINVAL;  	} -	return ret; + +	return 0; + +err: +	rsa_free_mpi_key(mpi_key); +	return -ENOMEM;  }  static int rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,  			    unsigned int keylen)  { -	struct rsa_key *pkey = akcipher_tfm_ctx(tfm); +	struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm); +	struct rsa_key raw_key = {0};  	int ret; -	ret = rsa_parse_priv_key(pkey, key, keylen); +	/* Free the old MPI key if any */ +	rsa_free_mpi_key(mpi_key); + +	ret = rsa_parse_priv_key(&raw_key, key, keylen);  	if (ret)  		return ret; -	if (rsa_check_key_length(mpi_get_size(pkey->n) << 3)) { -		rsa_free_key(pkey); -		ret = -EINVAL; +	mpi_key->d = mpi_read_raw_data(raw_key.d, raw_key.d_sz); +	if (!mpi_key->d) +		goto err; + +	mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz); +	if (!mpi_key->e) +		goto err; + +	mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz); +	if (!mpi_key->n) +		goto err; + +	if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) { +		rsa_free_mpi_key(mpi_key); +		return -EINVAL;  	} -	return ret; + +	return 0; + +err: +	rsa_free_mpi_key(mpi_key); +	return -ENOMEM;  }  static int rsa_max_size(struct crypto_akcipher *tfm)  { -	struct rsa_key *pkey = akcipher_tfm_ctx(tfm); +	struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm);  	return pkey->n ? mpi_get_size(pkey->n) : -EINVAL;  }  static void rsa_exit_tfm(struct crypto_akcipher *tfm)  { -	struct rsa_key *pkey = akcipher_tfm_ctx(tfm); +	struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm); -	rsa_free_key(pkey); +	rsa_free_mpi_key(pkey);  }  static struct akcipher_alg rsa = { @@ -310,7 +365,7 @@ static struct akcipher_alg rsa = {  		.cra_driver_name = "rsa-generic",  		.cra_priority = 100,  		.cra_module = THIS_MODULE, -		.cra_ctxsize = sizeof(struct rsa_key), +		.cra_ctxsize = sizeof(struct rsa_mpi_key),  	},  }; diff --git a/crypto/rsa_helper.c b/crypto/rsa_helper.c index d226f48d0907..583656af4fe2 100644 --- a/crypto/rsa_helper.c +++ b/crypto/rsa_helper.c @@ -22,20 +22,29 @@ int rsa_get_n(void *context, size_t hdrlen, unsigned char tag,  	      const void *value, size_t vlen)  {  	struct rsa_key *key = context; +	const u8 *ptr = value; +	size_t n_sz = vlen; -	key->n = mpi_read_raw_data(value, vlen); - -	if (!key->n) -		return -ENOMEM; - -	/* In FIPS mode only allow key size 2K & 3K */ -	if (fips_enabled && (mpi_get_size(key->n) != 256 && -			     mpi_get_size(key->n) != 384)) { -		pr_err("RSA: key size not allowed in FIPS mode\n"); -		mpi_free(key->n); -		key->n = NULL; +	/* invalid key provided */ +	if (!value || !vlen)  		return -EINVAL; + +	if (fips_enabled) { +		while (!*ptr && n_sz) { +			ptr++; +			n_sz--; +		} + +		/* In FIPS mode only allow key size 2K & 3K */ +		if (n_sz != 256 && n_sz != 384) { +			pr_err("RSA: key size not allowed in FIPS mode\n"); +			return -EINVAL; +		}  	} + +	key->n = value; +	key->n_sz = vlen; +  	return 0;  } @@ -44,10 +53,12 @@ int rsa_get_e(void *context, size_t hdrlen, unsigned char tag,  {  	struct rsa_key *key = context; -	key->e = mpi_read_raw_data(value, vlen); +	/* invalid key provided */ +	if (!value || !key->n_sz || !vlen || vlen > key->n_sz) +		return -EINVAL; -	if (!key->e) -		return -ENOMEM; +	key->e = value; +	key->e_sz = vlen;  	return 0;  } @@ -57,46 +68,20 @@ int rsa_get_d(void *context, size_t hdrlen, unsigned char tag,  {  	struct rsa_key *key = context; -	key->d = mpi_read_raw_data(value, vlen); - -	if (!key->d) -		return -ENOMEM; - -	/* In FIPS mode only allow key size 2K & 3K */ -	if (fips_enabled && (mpi_get_size(key->d) != 256 && -			     mpi_get_size(key->d) != 384)) { -		pr_err("RSA: key size not allowed in FIPS mode\n"); -		mpi_free(key->d); -		key->d = NULL; +	/* invalid key provided */ +	if (!value || !key->n_sz || !vlen || vlen > key->n_sz)  		return -EINVAL; -	} -	return 0; -} -static void free_mpis(struct rsa_key *key) -{ -	mpi_free(key->n); -	mpi_free(key->e); -	mpi_free(key->d); -	key->n = NULL; -	key->e = NULL; -	key->d = NULL; -} +	key->d = value; +	key->d_sz = vlen; -/** - * rsa_free_key() - frees rsa key allocated by rsa_parse_key() - * - * @rsa_key:	struct rsa_key key representation - */ -void rsa_free_key(struct rsa_key *key) -{ -	free_mpis(key); +	return 0;  } -EXPORT_SYMBOL_GPL(rsa_free_key);  /** - * rsa_parse_pub_key() - extracts an rsa public key from BER encoded buffer - *			 and stores it in the provided struct rsa_key + * rsa_parse_pub_key() - decodes the BER encoded buffer and stores in the + *                       provided struct rsa_key, pointers to the raw key as is, + *                       so that the caller can copy it or MPI parse it, etc.   *   * @rsa_key:	struct rsa_key key representation   * @key:	key in BER format @@ -107,23 +92,15 @@ EXPORT_SYMBOL_GPL(rsa_free_key);  int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key,  		      unsigned int key_len)  { -	int ret; - -	free_mpis(rsa_key); -	ret = asn1_ber_decoder(&rsapubkey_decoder, rsa_key, key, key_len); -	if (ret < 0) -		goto error; - -	return 0; -error: -	free_mpis(rsa_key); -	return ret; +	return asn1_ber_decoder(&rsapubkey_decoder, rsa_key, key, key_len);  }  EXPORT_SYMBOL_GPL(rsa_parse_pub_key);  /** - * rsa_parse_pub_key() - extracts an rsa private key from BER encoded buffer - *			 and stores it in the provided struct rsa_key + * rsa_parse_priv_key() - decodes the BER encoded buffer and stores in the + *                        provided struct rsa_key, pointers to the raw key + *                        as is, so that the caller can copy it or MPI parse it, + *                        etc.   *   * @rsa_key:	struct rsa_key key representation   * @key:	key in BER format @@ -134,16 +111,6 @@ EXPORT_SYMBOL_GPL(rsa_parse_pub_key);  int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,  		       unsigned int key_len)  { -	int ret; - -	free_mpis(rsa_key); -	ret = asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len); -	if (ret < 0) -		goto error; - -	return 0; -error: -	free_mpis(rsa_key); -	return ret; +	return asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len);  }  EXPORT_SYMBOL_GPL(rsa_parse_priv_key); diff --git a/include/crypto/internal/rsa.h b/include/crypto/internal/rsa.h index c7585bdecbc2..d6c042a2ee52 100644 --- a/include/crypto/internal/rsa.h +++ b/include/crypto/internal/rsa.h @@ -12,12 +12,24 @@   */  #ifndef _RSA_HELPER_  #define _RSA_HELPER_ -#include <linux/mpi.h> +#include <linux/types.h> +/** + * rsa_key - RSA key structure + * @n           : RSA modulus raw byte stream + * @e           : RSA public exponent raw byte stream + * @d           : RSA private exponent raw byte stream + * @n_sz        : length in bytes of RSA modulus n + * @e_sz        : length in bytes of RSA public exponent + * @d_sz        : length in bytes of RSA private exponent + */  struct rsa_key { -	MPI n; -	MPI e; -	MPI d; +	const u8 *n; +	const u8 *e; +	const u8 *d; +	size_t n_sz; +	size_t e_sz; +	size_t d_sz;  };  int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key, @@ -26,7 +38,5 @@ int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key,  int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,  		       unsigned int key_len); -void rsa_free_key(struct rsa_key *rsa_key); -  extern struct crypto_template rsa_pkcs1pad_tmpl;  #endif | 
