[PATCH v4 2/4] crypto: support rsa-pss encoding

From: Hongbo Li
Date: Wed Apr 07 2021 - 09:42:38 EST


This patch add the support of rsa-pss encoding which is described in
RFC8017 section 8.1 and section 9.1.

Similar to rsa-pkcs1, we create a pss template. Parse pss related params
mgfhash and saltlen in set_pub_key.

Implement a mgf function according to RFC8017 section B.2.
Implement the verification according to RFC8017 section 8.1.2 and 9.1.2

Signed-off-by: Hongbo Li <herbert.tencent@xxxxxxxxx>
---
crypto/Makefile | 7 +-
crypto/rsa-psspad.c | 398 ++++++++++++++++++++++++++++++++++++++++++
crypto/rsa.c | 14 +-
crypto/rsa_helper.c | 127 ++++++++++++++
crypto/rsapss_params.asn1 | 21 +++
include/crypto/internal/rsa.h | 25 ++-
6 files changed, 583 insertions(+), 9 deletions(-)
create mode 100644 crypto/rsa-psspad.c
create mode 100644 crypto/rsapss_params.asn1

diff --git a/crypto/Makefile b/crypto/Makefile
index 10526d4..2c65744 100644
--- a/crypto/Makefile
+++ b/crypto/Makefile
@@ -33,13 +33,18 @@ obj-$(CONFIG_CRYPTO_DH) += dh_generic.o

$(obj)/rsapubkey.asn1.o: $(obj)/rsapubkey.asn1.c $(obj)/rsapubkey.asn1.h
$(obj)/rsaprivkey.asn1.o: $(obj)/rsaprivkey.asn1.c $(obj)/rsaprivkey.asn1.h
-$(obj)/rsa_helper.o: $(obj)/rsapubkey.asn1.h $(obj)/rsaprivkey.asn1.h
+$(obj)/rsapss_params.asn1.o: $(obj)/rsapss_params.asn1.c \
+ $(obj)/rsapss_params.asn1.h
+$(obj)/rsa_helper.o: $(obj)/rsapubkey.asn1.h $(obj)/rsaprivkey.asn1.h \
+ $(obj)/rsapss_params.asn1.h

rsa_generic-y := rsapubkey.asn1.o
rsa_generic-y += rsaprivkey.asn1.o
+rsa_generic-y += rsapss_params.asn1.o
rsa_generic-y += rsa.o
rsa_generic-y += rsa_helper.o
rsa_generic-y += rsa-pkcs1pad.o
+rsa_generic-y += rsa-psspad.o
obj-$(CONFIG_CRYPTO_RSA) += rsa_generic.o

$(obj)/sm2signature.asn1.o: $(obj)/sm2signature.asn1.c $(obj)/sm2signature.asn1.h
diff --git a/crypto/rsa-psspad.c b/crypto/rsa-psspad.c
new file mode 100644
index 0000000..342c4cc
--- /dev/null
+++ b/crypto/rsa-psspad.c
@@ -0,0 +1,398 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * RSA PSS padding templates.
+ *
+ * Copyright (c) 2021 Hongbo Li <herberthbli@xxxxxxxxxxx>
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the Free
+ * Software Foundation; either version 2 of the License, or (at your option)
+ * any later version.
+ */
+
+#include <crypto/hash.h>
+#include <crypto/internal/rsa.h>
+#include <crypto/internal/akcipher.h>
+
+struct psspad_inst_ctx {
+ struct crypto_akcipher_spawn spawn;
+};
+
+struct psspad_request {
+ struct scatterlist out_sg[1];
+ uint8_t *out_buf;
+ struct akcipher_request child_req;
+};
+
+static const u8 *psspad_unpack(void *dst, const void *src, size_t sz)
+{
+ memcpy(dst, src, sz);
+ return src + sz;
+}
+
+static int psspad_set_pub_key(struct crypto_akcipher *tfm, const void *key,
+ unsigned int keylen)
+{
+ struct rsa_pss_ctx *ctx = akcipher_tfm_ctx(tfm);
+ const u8 *ptr;
+ u32 algo, paramlen;
+ int err;
+
+ ctx->key_size = 0;
+
+ err = crypto_akcipher_set_pub_key(ctx->child, key, keylen);
+ if (err)
+ return err;
+
+ /* Find out new modulus size from rsa implementation */
+ err = crypto_akcipher_maxsize(ctx->child);
+ if (err > PAGE_SIZE)
+ return -EOPNOTSUPP;
+
+ ctx->key_size = err;
+
+ ptr = key + keylen;
+ ptr = psspad_unpack(&algo, ptr, sizeof(algo));
+ ptr = psspad_unpack(&paramlen, ptr, sizeof(paramlen));
+ err = rsa_parse_pss_params(ctx, ptr, paramlen);
+ if (err < 0)
+ return err;
+
+ if (!ctx->hash_algo)
+ ctx->hash_algo = "sha1";
+ if (!ctx->mgf_algo)
+ ctx->mgf_algo = "mgf1";
+ if (!ctx->mgf_hash_algo)
+ ctx->mgf_hash_algo = "sha1";
+ if (!ctx->salt_len)
+ ctx->salt_len = RSA_PSS_DEFAULT_SALT_LEN;
+
+ return 0;
+}
+
+static int psspad_mgf1(const char *hash_algo, u8 *seed, u32 seed_len, u8 *mask,
+ u32 masklen)
+{
+ struct crypto_shash *tfm = NULL;
+ u32 hlen, cnt, tlen;
+ u8 c[4], digest[RSA_MAX_DIGEST_SIZE], buf[RSA_MAX_DIGEST_SIZE + 4];
+ int i, err = 0;
+ SHASH_DESC_ON_STACK(desc, tfm);
+
+ tfm = crypto_alloc_shash(hash_algo, 0, 0);
+ if (IS_ERR(tfm)) {
+ err = PTR_ERR(tfm);
+ return err;
+ }
+ desc->tfm = tfm;
+ hlen = crypto_shash_digestsize(tfm);
+ cnt = DIV_ROUND_UP(masklen, hlen);
+ tlen = 0;
+ for (i = 0; i < cnt; i++) {
+ /* C = I2OSP (counter, 4) */
+ c[0] = (i >> 24) & 0xff;
+ c[1] = (i >> 16) & 0xff;
+ c[2] = (i >> 8) & 0xff;
+ c[3] = i & 0xff;
+
+ memcpy(buf, seed, seed_len);
+ memcpy(buf + seed_len, c, 4);
+ err = crypto_shash_digest(desc, buf,
+ seed_len + 4, digest);
+ if (err < 0)
+ goto free;
+
+ /* T = T || Hash(mgfSeed || C) */
+ tlen = i * hlen;
+ if (i == cnt - 1)
+ memcpy(mask + tlen, digest, masklen - tlen);
+ else
+ memcpy(mask + tlen, digest, hlen);
+ }
+free:
+ crypto_free_shash(tfm);
+ return err;
+}
+
+/* EMSA-PSS-VERIFY (M, EM, emBits) */
+static int psspad_verify_complete(struct akcipher_request *req, int err)
+{
+ struct crypto_akcipher *ak_tfm = crypto_akcipher_reqtfm(req);
+ struct rsa_pss_ctx *ctx = akcipher_tfm_ctx(ak_tfm);
+ struct psspad_request *req_ctx = akcipher_request_ctx(req);
+ struct crypto_akcipher *rsa_tfm;
+ struct rsa_mpi_key *mpi_key;
+ struct crypto_shash *tfm = NULL;
+ u32 i, hlen, slen, modbits, embits, emlen, masklen, buflen;
+ u8 *em, *h, *maskeddb, *dbmask, *db, *salt;
+ u8 mhash[RSA_MAX_DIGEST_SIZE], digest[RSA_MAX_DIGEST_SIZE];
+ u8 *buf = NULL;
+ SHASH_DESC_ON_STACK(desc, tfm);
+
+ if (err)
+ goto free;
+
+ tfm = crypto_alloc_shash(ctx->hash_algo, 0, 0);
+ if (IS_ERR(tfm)) {
+ err = PTR_ERR(tfm);
+ tfm = NULL;
+ goto free;
+ }
+ desc->tfm = tfm;
+ hlen = crypto_shash_digestsize(tfm);
+
+ /* mhash */
+ sg_pcopy_to_buffer(req->src,
+ sg_nents_for_len(req->src,
+ req->src_len + req->dst_len),
+ mhash, hlen, req->src_len);
+
+ err = -EINVAL;
+
+ /* section 8.1.2. emLen = \ceil ((modBits - 1)/8) */
+ rsa_tfm = crypto_akcipher_reqtfm(&req_ctx->child_req);
+ mpi_key = akcipher_tfm_ctx(rsa_tfm);
+ modbits = mpi_get_nbits(mpi_key->n);
+ embits = modbits - 1;
+ emlen = DIV_ROUND_UP(embits, 8);
+
+ /* 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. */
+ slen = ctx->salt_len;
+ if (emlen < hlen + slen + 2)
+ goto free;
+
+ /* 4. If the rightmost octet of EM does not have hexadecimal value
+ * 0xbc, output "inconsistent" and stop.
+ */
+ em = req_ctx->out_buf;
+ if (em[emlen - 1] != 0xbc)
+ goto free;
+
+
+ /* 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
+ * and let H be the next hLen octets.
+ */
+ maskeddb = em;
+ masklen = emlen - hlen - 1;
+ h = em + masklen;
+
+ /* 6. If the leftmost 8emLen - emBits bits of the leftmost octet in
+ * maskedDB are not all equal to zero, output "inconsistent" and
+ * stop.
+ */
+ if (maskeddb[0] & ~(0xff >> (8 * emlen - embits)))
+ goto free;
+
+ /* 7. Let dbMask = MGF(H, emLen - hLen - 1). */
+ buflen = max_t(u32, masklen, 8 + hlen + slen);
+ buf = kmalloc(buflen, GFP_KERNEL);
+ if (!buf) {
+ err = -ENOMEM;
+ goto free;
+ }
+ dbmask = buf;
+ err = psspad_mgf1(ctx->mgf_hash_algo, h, hlen, dbmask, masklen);
+ if (err)
+ goto free;
+
+ /* 8. Let DB = maskedDB \xor dbMask. */
+ db = maskeddb;
+ for (i = 0; i < masklen; i++)
+ db[i] = maskeddb[i] ^ dbmask[i];
+
+ /* 9. Set the leftmost 8emLen - emBits bits of the leftmost octet
+ * in DB to zero.
+ */
+ db[0] &= 0xff >> (8 * emlen - embits);
+
+ /* 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not
+ * zero or if the octet at position emLen - hLen - sLen - 1 (the
+ * leftmost position is "position 1") does not have hexadecimal
+ * value 0x01, output "inconsistent" and stop.
+ */
+ for (i = 0; i < emlen - hlen - slen - 2; i++) {
+ if (db[i]) {
+ err = -EINVAL;
+ goto free;
+ }
+ }
+ if (db[i] != 1)
+ goto free;
+
+ /* 11. Let salt be the last sLen octets of DB. */
+ salt = db + masklen - slen;
+
+ /* 12. M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; */
+ memset(buf, 0, 8);
+ memcpy(buf + 8, mhash, hlen);
+ memcpy(buf + 8 + hlen, salt, slen);
+
+ /* 13. Let H' = Hash(M'), an octet string of length hLen. */
+ err = crypto_shash_digest(desc, buf, 8 + hlen + slen, digest);
+ if (err < 0)
+ goto free;
+
+ /* 14. If H = H', output "consistent". Otherwise, output
+ * "inconsistent".
+ */
+ if (memcmp(h, digest, hlen))
+ err = -EKEYREJECTED;
+
+free:
+ if (tfm)
+ crypto_free_shash(tfm);
+ kfree_sensitive(req_ctx->out_buf);
+ kfree(buf);
+ return err;
+}
+
+static void psspad_verify_complete_cb(
+ struct crypto_async_request *child_async_req, int err)
+{
+ struct akcipher_request *req = child_async_req->data;
+ struct crypto_async_request async_req;
+
+ if (err == -EINPROGRESS)
+ return;
+
+ async_req.data = req->base.data;
+ async_req.tfm = crypto_akcipher_tfm(crypto_akcipher_reqtfm(req));
+ async_req.flags = child_async_req->flags;
+ req->base.complete(&async_req, psspad_verify_complete(req, err));
+}
+
+static int psspad_verify(struct akcipher_request *req)
+{
+ struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
+ struct rsa_pss_ctx *ctx = akcipher_tfm_ctx(tfm);
+ struct psspad_request *req_ctx = akcipher_request_ctx(req);
+ int err;
+
+ if (WARN_ON(req->dst) ||
+ WARN_ON(!req->dst_len) ||
+ !ctx->key_size || req->src_len < ctx->key_size)
+ return -EINVAL;
+
+ req_ctx->out_buf = kmalloc(ctx->key_size + req->dst_len, GFP_KERNEL);
+ if (!req_ctx->out_buf)
+ return -ENOMEM;
+
+ sg_init_table(req_ctx->out_sg, 1);
+ sg_set_buf(req_ctx->out_sg, req_ctx->out_buf, ctx->key_size);
+
+ akcipher_request_set_tfm(&req_ctx->child_req, ctx->child);
+ akcipher_request_set_callback(&req_ctx->child_req, req->base.flags,
+ psspad_verify_complete_cb, req);
+
+ /* Reuse input buffer, output to a new buffer */
+ akcipher_request_set_crypt(&req_ctx->child_req, req->src,
+ req_ctx->out_sg, req->src_len,
+ ctx->key_size);
+
+ err = crypto_akcipher_encrypt(&req_ctx->child_req);
+ if (err != -EINPROGRESS && err != -EBUSY)
+ return psspad_verify_complete(req, err);
+
+ return err;
+}
+
+static unsigned int psspad_get_max_size(struct crypto_akcipher *tfm)
+{
+ struct rsa_pss_ctx *ctx = akcipher_tfm_ctx(tfm);
+
+ return ctx->key_size;
+}
+
+static int psspad_init_tfm(struct crypto_akcipher *tfm)
+{
+ struct akcipher_instance *inst = akcipher_alg_instance(tfm);
+ struct psspad_inst_ctx *ictx = akcipher_instance_ctx(inst);
+ struct rsa_pss_ctx *ctx = akcipher_tfm_ctx(tfm);
+ struct crypto_akcipher *child_tfm;
+
+ child_tfm = crypto_spawn_akcipher(&ictx->spawn);
+ if (IS_ERR(child_tfm))
+ return PTR_ERR(child_tfm);
+
+ ctx->child = child_tfm;
+ return 0;
+}
+
+static void psspad_exit_tfm(struct crypto_akcipher *tfm)
+{
+ struct rsa_pss_ctx *ctx = akcipher_tfm_ctx(tfm);
+
+ crypto_free_akcipher(ctx->child);
+}
+
+static void psspad_free(struct akcipher_instance *inst)
+{
+ struct psspad_inst_ctx *ctx = akcipher_instance_ctx(inst);
+ struct crypto_akcipher_spawn *spawn = &ctx->spawn;
+
+ crypto_drop_akcipher(spawn);
+ kfree(inst);
+}
+
+static int psspad_create(struct crypto_template *tmpl, struct rtattr **tb)
+{
+ u32 mask;
+ struct akcipher_instance *inst;
+ struct psspad_inst_ctx *ctx;
+ struct akcipher_alg *rsa_alg;
+ int err;
+
+ err = crypto_check_attr_type(tb, CRYPTO_ALG_TYPE_AKCIPHER, &mask);
+ if (err)
+ return err;
+
+ inst = kzalloc(sizeof(*inst) + sizeof(*ctx), GFP_KERNEL);
+ if (!inst)
+ return -ENOMEM;
+
+ ctx = akcipher_instance_ctx(inst);
+
+ err = crypto_grab_akcipher(&ctx->spawn, akcipher_crypto_instance(inst),
+ crypto_attr_alg_name(tb[1]), 0, mask);
+ if (err)
+ goto err_free_inst;
+
+ rsa_alg = crypto_spawn_akcipher_alg(&ctx->spawn);
+
+ err = -ENAMETOOLONG;
+ if (snprintf(inst->alg.base.cra_name,
+ CRYPTO_MAX_ALG_NAME, "psspad(%s)",
+ rsa_alg->base.cra_name) >= CRYPTO_MAX_ALG_NAME)
+ goto err_free_inst;
+
+ if (snprintf(inst->alg.base.cra_driver_name,
+ CRYPTO_MAX_ALG_NAME, "psspad(%s)",
+ rsa_alg->base.cra_driver_name) >=
+ CRYPTO_MAX_ALG_NAME)
+ goto err_free_inst;
+
+ inst->alg.base.cra_priority = rsa_alg->base.cra_priority;
+ inst->alg.base.cra_ctxsize = sizeof(struct rsa_pss_ctx);
+
+ inst->alg.init = psspad_init_tfm;
+ inst->alg.exit = psspad_exit_tfm;
+ inst->alg.verify = psspad_verify;
+ inst->alg.set_pub_key = psspad_set_pub_key;
+ inst->alg.max_size = psspad_get_max_size;
+ inst->alg.reqsize = sizeof(struct psspad_request) + rsa_alg->reqsize;
+
+ inst->free = psspad_free;
+
+ err = akcipher_register_instance(tmpl, inst);
+ if (err) {
+err_free_inst:
+ psspad_free(inst);
+ }
+ return err;
+}
+
+struct crypto_template rsa_psspad_tmpl = {
+ .name = "psspad",
+ .create = psspad_create,
+ .module = THIS_MODULE,
+};
diff --git a/crypto/rsa.c b/crypto/rsa.c
index 4cdbec9..adc9b2d2 100644
--- a/crypto/rsa.c
+++ b/crypto/rsa.c
@@ -6,18 +6,11 @@
*/

#include <linux/module.h>
-#include <linux/mpi.h>
#include <crypto/internal/rsa.h>
#include <crypto/internal/akcipher.h>
#include <crypto/akcipher.h>
#include <crypto/algapi.h>

-struct rsa_mpi_key {
- MPI n;
- MPI e;
- MPI d;
-};
-
/*
* RSAEP function [RFC3447 sec 5.1.1]
* c = m^e mod n;
@@ -269,12 +262,19 @@ static int rsa_init(void)
return err;
}

+ err = crypto_register_template(&rsa_psspad_tmpl);
+ if (err) {
+ crypto_unregister_akcipher(&rsa);
+ return err;
+ }
+
return 0;
}

static void rsa_exit(void)
{
crypto_unregister_template(&rsa_pkcs1pad_tmpl);
+ crypto_unregister_template(&rsa_psspad_tmpl);
crypto_unregister_akcipher(&rsa);
}

diff --git a/crypto/rsa_helper.c b/crypto/rsa_helper.c
index 94266f2..912d975 100644
--- a/crypto/rsa_helper.c
+++ b/crypto/rsa_helper.c
@@ -12,6 +12,7 @@
#include <crypto/internal/rsa.h>
#include "rsapubkey.asn1.h"
#include "rsaprivkey.asn1.h"
+#include "rsapss_params.asn1.h"

int rsa_get_n(void *context, size_t hdrlen, unsigned char tag,
const void *value, size_t vlen)
@@ -148,6 +149,115 @@ int rsa_get_qinv(void *context, size_t hdrlen, unsigned char tag,
return 0;
}

+int rsa_get_pss_hash(void *context, size_t hdrlen, unsigned char tag,
+ const void *value, size_t vlen)
+{
+ struct rsa_pss_ctx *ctx = context;
+ enum OID oid;
+
+ if (!value || !vlen)
+ return -EINVAL;
+
+ oid = look_up_OID(value, vlen);
+ switch (oid) {
+ case OID_sha1:
+ ctx->hash_algo = "sha1";
+ break;
+ case OID_sha224:
+ ctx->hash_algo = "sha224";
+ break;
+ case OID_sha256:
+ ctx->hash_algo = "sha256";
+ break;
+ case OID_sha384:
+ ctx->hash_algo = "sha384";
+ break;
+ case OID_sha512:
+ ctx->hash_algo = "sha512";
+ break;
+ default:
+ return -ENOPKG;
+
+ }
+
+ return 0;
+}
+
+int rsa_get_pss_mgf(void *context, size_t hdrlen, unsigned char tag,
+ const void *value, size_t vlen)
+{
+ struct rsa_pss_ctx *ctx = context;
+ enum OID oid;
+
+ if (!value || !vlen)
+ return -EINVAL;
+
+ oid = look_up_OID(value, vlen);
+ if (oid != OID_rsa_mgf1)
+ return -ENOPKG;
+ ctx->mgf_algo = "mgf1";
+
+ return 0;
+}
+
+int rsa_get_pss_mgf_hash(void *context, size_t hdrlen, unsigned char tag,
+ const void *value, size_t vlen)
+{
+ struct rsa_pss_ctx *ctx = context;
+ enum OID oid;
+
+ if (!value || !vlen)
+ return -EINVAL;
+ /* todo, merge with get_pss_hash */
+ oid = look_up_OID(value, vlen);
+ switch (oid) {
+ case OID_sha1:
+ ctx->mgf_hash_algo = "sha1";
+ break;
+ case OID_sha224:
+ ctx->mgf_hash_algo = "sha224";
+ break;
+ case OID_sha256:
+ ctx->mgf_hash_algo = "sha256";
+ break;
+ case OID_sha384:
+ ctx->mgf_hash_algo = "sha384";
+ break;
+ case OID_sha512:
+ ctx->mgf_hash_algo = "sha512";
+ break;
+ default:
+ return -ENOPKG;
+ }
+
+ return 0;
+}
+
+int rsa_get_pss_saltlen(void *context, size_t hdrlen, unsigned char tag,
+ const void *value, size_t vlen)
+{
+ struct rsa_pss_ctx *ctx = context;
+
+ if (!value || vlen < 1 || vlen > 2)
+ return -EINVAL;
+
+ if (vlen == 1)
+ ctx->salt_len = *(u8 *)value;
+ else if (vlen == 2)
+ ctx->salt_len = ntohs(*(u16 *)value);
+
+ return 0;
+}
+
+int rsa_get_pss_trailerfield(void *context, size_t hdrlen, unsigned char tag,
+ const void *value, size_t vlen)
+{
+ if (!value || !vlen || *(u8 *)value != 1)
+ return -EINVAL;
+
+ return 0;
+}
+
/**
* rsa_parse_pub_key() - decodes the BER encoded buffer and stores in the
* provided struct rsa_key, pointers to the raw key as is,
@@ -184,3 +294,20 @@ int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,
return asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len);
}
EXPORT_SYMBOL_GPL(rsa_parse_priv_key);
+
+/**
+ * rsa_parse_pss_params() - decodes the BER encoded pss padding params
+ *
+ * @ctx: struct rsa_pss_ctx, pss padding context
+ * @params: params in BER format
+ * @params_len: length of params
+ *
+ * Return: 0 on success or error code in case of error
+ */
+int rsa_parse_pss_params(struct rsa_pss_ctx *ctx, const void *params,
+ unsigned int params_len)
+{
+ return asn1_ber_decoder(&rsapss_params_decoder, ctx, params,
+ params_len);
+}
+EXPORT_SYMBOL_GPL(rsa_parse_pss_params);
diff --git a/crypto/rsapss_params.asn1 b/crypto/rsapss_params.asn1
new file mode 100644
index 0000000..4d6b0ba
--- /dev/null
+++ b/crypto/rsapss_params.asn1
@@ -0,0 +1,21 @@
+-- rfc4055 section 3.1.
+
+RSAPSS_Params ::= SEQUENCE {
+ hashAlgorithm [0] HashAlgorithm OPTIONAL,
+ maskGenAlgorithm [1] MaskGenAlgorithm OPTIONAL,
+ saltLen [2] INTEGER OPTIONAL ({ rsa_get_pss_saltlen }),
+ trailerField [3] INTEGER OPTIONAL ({ rsa_get_pss_trailerfield })
+ }
+
+HashAlgorithm ::= SEQUENCE {
+ algorithm OBJECT IDENTIFIER ({ rsa_get_pss_hash })
+ }
+
+MaskGenAlgorithm ::= SEQUENCE {
+ algorithm OBJECT IDENTIFIER ({ rsa_get_pss_mgf }),
+ hashAlgorithm MgfHashAlgorithm
+ }
+
+MgfHashAlgorithm ::= SEQUENCE {
+ algorithm OBJECT IDENTIFIER ({ rsa_get_pss_mgf_hash })
+ }
diff --git a/include/crypto/internal/rsa.h b/include/crypto/internal/rsa.h
index e870133..cfb0801 100644
--- a/include/crypto/internal/rsa.h
+++ b/include/crypto/internal/rsa.h
@@ -8,6 +8,12 @@
#ifndef _RSA_HELPER_
#define _RSA_HELPER_
#include <linux/types.h>
+#include <linux/mpi.h>
+#include <linux/oid_registry.h>
+#include <crypto/sha2.h>
+
+#define RSA_MAX_DIGEST_SIZE SHA512_DIGEST_SIZE
+#define RSA_PSS_DEFAULT_SALT_LEN 20

/**
* rsa_key - RSA key structure
@@ -47,11 +53,28 @@ struct rsa_key {
size_t qinv_sz;
};

+struct rsa_mpi_key {
+ MPI n;
+ MPI e;
+ MPI d;
+};
+
+struct rsa_pss_ctx {
+ struct crypto_akcipher *child;
+ unsigned int key_size;
+ const char *hash_algo;
+ const char *mgf_algo;
+ const char *mgf_hash_algo;
+ u32 salt_len;
+};
+
int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key,
unsigned int key_len);

int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,
unsigned int key_len);
-
+int rsa_parse_pss_params(struct rsa_pss_ctx *ctx, const void *params,
+ unsigned int params_len);
extern struct crypto_template rsa_pkcs1pad_tmpl;
+extern struct crypto_template rsa_psspad_tmpl;
#endif
--
1.8.3.1