Re: [PATCH v11 10/20] virt: sev-guest: Carve out SNP message context structure
From: Tom Lendacky
Date: Fri Sep 13 2024 - 11:52:46 EST
On 7/31/24 10:08, Nikunj A Dadhania wrote:
> Currently, the sev-guest driver is the only user of SNP guest messaging.
> snp_guest_dev structure holds all the allocated buffers, secrets page and
s/snp_guest_dev structure/The snp_guest_dev structure/
> VMPCK details. In preparation of adding messaging allocation and
s/of/for/
> initialization APIs, decouple snp_guest_dev from messaging-related
> information by carving out guest message context structure(snp_msg_desc).
s/out guest/out the guest/
>
> Incorporate this newly added context into snp_send_guest_request() and all
> related functions, replacing the use of the snp_guest_dev.
>
> No functional change.
>
> Signed-off-by: Nikunj A Dadhania <nikunj@xxxxxxx>
Reviewed-by: Tom Lendacky <thomas.lendacky@xxxxxxx>
> ---
> arch/x86/include/asm/sev.h | 21 +++
> drivers/virt/coco/sev-guest/sev-guest.c | 183 ++++++++++++------------
> 2 files changed, 111 insertions(+), 93 deletions(-)
>
> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
> index 27fa1c9c3465..2e49c4a9e7fe 100644
> --- a/arch/x86/include/asm/sev.h
> +++ b/arch/x86/include/asm/sev.h
> @@ -234,6 +234,27 @@ struct snp_secrets_page {
> u8 rsvd4[3744];
> } __packed;
>
> +struct snp_msg_desc {
> + /* request and response are in unencrypted memory */
> + struct snp_guest_msg *request, *response;
> +
> + /*
> + * Avoid information leakage by double-buffering shared messages
> + * in fields that are in regular encrypted memory.
> + */
> + struct snp_guest_msg secret_request, secret_response;
> +
> + struct snp_secrets_page *secrets;
> + struct snp_req_data input;
> +
> + void *certs_data;
> +
> + struct aesgcm_ctx *ctx;
> +
> + u32 *os_area_msg_seqno;
> + u8 *vmpck;
> +};
> +
> /*
> * The SVSM Calling Area (CA) related structures.
> */
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 42f7126f1718..38ddabcd7ba3 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -40,26 +40,13 @@ struct snp_guest_dev {
> struct device *dev;
> struct miscdevice misc;
>
> - void *certs_data;
> - struct aesgcm_ctx *ctx;
> - /* request and response are in unencrypted memory */
> - struct snp_guest_msg *request, *response;
> -
> - /*
> - * Avoid information leakage by double-buffering shared messages
> - * in fields that are in regular encrypted memory.
> - */
> - struct snp_guest_msg secret_request, secret_response;
> + struct snp_msg_desc *msg_desc;
>
> - struct snp_secrets_page *secrets;
> - struct snp_req_data input;
> union {
> struct snp_report_req report;
> struct snp_derived_key_req derived_key;
> struct snp_ext_report_req ext_report;
> } req;
> - u32 *os_area_msg_seqno;
> - u8 *vmpck;
> };
>
> /*
> @@ -76,12 +63,12 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
> /* Mutex to serialize the shared buffer access and command handling. */
> static DEFINE_MUTEX(snp_cmd_mutex);
>
> -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +static bool is_vmpck_empty(struct snp_msg_desc *mdesc)
> {
> char zero_key[VMPCK_KEY_LEN] = {0};
>
> - if (snp_dev->vmpck)
> - return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
> + if (mdesc->vmpck)
> + return !memcmp(mdesc->vmpck, zero_key, VMPCK_KEY_LEN);
>
> return true;
> }
> @@ -103,30 +90,30 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
> * vulnerable. If the sequence number were incremented for a fresh IV the ASP
> * will reject the request.
> */
> -static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
> +static void snp_disable_vmpck(struct snp_msg_desc *mdesc)
> {
> - dev_alert(snp_dev->dev, "Disabling VMPCK%d communication key to prevent IV reuse.\n",
> + pr_alert("Disabling VMPCK%d communication key to prevent IV reuse.\n",
> vmpck_id);
> - memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
> - snp_dev->vmpck = NULL;
> + memzero_explicit(mdesc->vmpck, VMPCK_KEY_LEN);
> + mdesc->vmpck = NULL;
> }
>
> -static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> +static inline u64 __snp_get_msg_seqno(struct snp_msg_desc *mdesc)
> {
> u64 count;
>
> lockdep_assert_held(&snp_cmd_mutex);
>
> /* Read the current message sequence counter from secrets pages */
> - count = *snp_dev->os_area_msg_seqno;
> + count = *mdesc->os_area_msg_seqno;
>
> return count + 1;
> }
>
> /* Return a non-zero on success */
> -static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> +static u64 snp_get_msg_seqno(struct snp_msg_desc *mdesc)
> {
> - u64 count = __snp_get_msg_seqno(snp_dev);
> + u64 count = __snp_get_msg_seqno(mdesc);
>
> /*
> * The message sequence counter for the SNP guest request is a 64-bit
> @@ -137,20 +124,20 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
> * invalid number and will fail the message request.
> */
> if (count >= UINT_MAX) {
> - dev_err(snp_dev->dev, "request message sequence counter overflow\n");
> + pr_err("request message sequence counter overflow\n");
> return 0;
> }
>
> return count;
> }
>
> -static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
> +static void snp_inc_msg_seqno(struct snp_msg_desc *mdesc)
> {
> /*
> * The counter is also incremented by the PSP, so increment it by 2
> * and save in secrets page.
> */
> - *snp_dev->os_area_msg_seqno += 2;
> + *mdesc->os_area_msg_seqno += 2;
> }
>
> static inline struct snp_guest_dev *to_snp_dev(struct file *file)
> @@ -177,13 +164,13 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
> return ctx;
> }
>
> -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *req)
> +static int verify_and_dec_payload(struct snp_msg_desc *mdesc, struct snp_guest_req *req)
> {
> - struct snp_guest_msg *resp_msg = &snp_dev->secret_response;
> - struct snp_guest_msg *req_msg = &snp_dev->secret_request;
> + struct snp_guest_msg *resp_msg = &mdesc->secret_response;
> + struct snp_guest_msg *req_msg = &mdesc->secret_request;
> struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr;
> struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr;
> - struct aesgcm_ctx *ctx = snp_dev->ctx;
> + struct aesgcm_ctx *ctx = mdesc->ctx;
> u8 iv[GCM_AES_IV_SIZE] = {};
>
> pr_debug("response [seqno %lld type %d version %d sz %d]\n",
> @@ -191,7 +178,7 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_gues
> resp_msg_hdr->msg_sz);
>
> /* Copy response from shared memory to encrypted memory. */
> - memcpy(resp_msg, snp_dev->response, sizeof(*resp_msg));
> + memcpy(resp_msg, mdesc->response, sizeof(*resp_msg));
>
> /* Verify that the sequence counter is incremented by 1 */
> if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1)))
> @@ -218,11 +205,11 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_gues
> return 0;
> }
>
> -static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req)
> +static int enc_payload(struct snp_msg_desc *mdesc, u64 seqno, struct snp_guest_req *req)
> {
> - struct snp_guest_msg *msg = &snp_dev->secret_request;
> + struct snp_guest_msg *msg = &mdesc->secret_request;
> struct snp_guest_msg_hdr *hdr = &msg->hdr;
> - struct aesgcm_ctx *ctx = snp_dev->ctx;
> + struct aesgcm_ctx *ctx = mdesc->ctx;
> u8 iv[GCM_AES_IV_SIZE] = {};
>
> memset(msg, 0, sizeof(*msg));
> @@ -253,7 +240,7 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_gues
> return 0;
> }
>
> -static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
> +static int __handle_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_req *req,
> struct snp_guest_request_ioctl *rio)
> {
> unsigned long req_start = jiffies;
> @@ -268,7 +255,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
> * sequence number must be incremented or the VMPCK must be deleted to
> * prevent reuse of the IV.
> */
> - rc = snp_issue_guest_request(req, &snp_dev->input, rio);
> + rc = snp_issue_guest_request(req, &mdesc->input, rio);
> switch (rc) {
> case -ENOSPC:
> /*
> @@ -278,7 +265,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
> * order to increment the sequence number and thus avoid
> * IV reuse.
> */
> - override_npages = snp_dev->input.data_npages;
> + override_npages = mdesc->input.data_npages;
> req->exit_code = SVM_VMGEXIT_GUEST_REQUEST;
>
> /*
> @@ -318,7 +305,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
> * structure and any failure will wipe the VMPCK, preventing further
> * use anyway.
> */
> - snp_inc_msg_seqno(snp_dev);
> + snp_inc_msg_seqno(mdesc);
>
> if (override_err) {
> rio->exitinfo2 = override_err;
> @@ -334,12 +321,12 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
> }
>
> if (override_npages)
> - snp_dev->input.data_npages = override_npages;
> + mdesc->input.data_npages = override_npages;
>
> return rc;
> }
>
> -static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req,
> +static int snp_send_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_req *req,
> struct snp_guest_request_ioctl *rio)
> {
> u64 seqno;
> @@ -348,15 +335,15 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
> guard(mutex)(&snp_cmd_mutex);
>
> /* Get message sequence and verify that its a non-zero */
> - seqno = snp_get_msg_seqno(snp_dev);
> + seqno = snp_get_msg_seqno(mdesc);
> if (!seqno)
> return -EIO;
>
> /* Clear shared memory's response for the host to populate. */
> - memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
> + memset(mdesc->response, 0, sizeof(struct snp_guest_msg));
>
> - /* Encrypt the userspace provided payload in snp_dev->secret_request. */
> - rc = enc_payload(snp_dev, seqno, req);
> + /* Encrypt the userspace provided payload in mdesc->secret_request. */
> + rc = enc_payload(mdesc, seqno, req);
> if (rc)
> return rc;
>
> @@ -364,34 +351,33 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
> * Write the fully encrypted request to the shared unencrypted
> * request page.
> */
> - memcpy(snp_dev->request, &snp_dev->secret_request,
> - sizeof(snp_dev->secret_request));
> + memcpy(mdesc->request, &mdesc->secret_request,
> + sizeof(mdesc->secret_request));
>
> - rc = __handle_guest_request(snp_dev, req, rio);
> + rc = __handle_guest_request(mdesc, req, rio);
> if (rc) {
> if (rc == -EIO &&
> rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
> return rc;
>
> - dev_alert(snp_dev->dev,
> - "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
> - rc, rio->exitinfo2);
> + pr_alert("Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
> + rc, rio->exitinfo2);
>
> - snp_disable_vmpck(snp_dev);
> + snp_disable_vmpck(mdesc);
> return rc;
> }
>
> - rc = verify_and_dec_payload(snp_dev, req);
> + rc = verify_and_dec_payload(mdesc, req);
> if (rc) {
> - dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
> - snp_disable_vmpck(snp_dev);
> + pr_alert("Detected unexpected decode failure from ASP. rc: %d\n", rc);
> + snp_disable_vmpck(mdesc);
> return rc;
> }
>
> return 0;
> }
>
> -static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
> +static int handle_guest_request(struct snp_msg_desc *mdesc, u64 exit_code,
> struct snp_guest_request_ioctl *rio, u8 type,
> void *req_buf, size_t req_sz, void *resp_buf,
> u32 resp_sz)
> @@ -407,7 +393,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
> .exit_code = exit_code,
> };
>
> - return snp_send_guest_request(snp_dev, &req, rio);
> + return snp_send_guest_request(mdesc, &req, rio);
> }
>
> struct snp_req_resp {
> @@ -418,6 +404,7 @@ struct snp_req_resp {
> static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
> {
> struct snp_report_req *report_req = &snp_dev->req.report;
> + struct snp_msg_desc *mdesc = snp_dev->msg_desc;
> struct snp_report_resp *report_resp;
> int rc, resp_len;
>
> @@ -432,12 +419,12 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
> * response payload. Make sure that it has enough space to cover the
> * authtag.
> */
> - resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize;
> + resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize;
> report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
> if (!report_resp)
> return -ENOMEM;
>
> - rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ,
> + rc = handle_guest_request(mdesc, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ,
> report_req, sizeof(*report_req), report_resp->data, resp_len);
> if (rc)
> goto e_free;
> @@ -454,6 +441,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
> {
> struct snp_derived_key_req *derived_key_req = &snp_dev->req.derived_key;
> struct snp_derived_key_resp derived_key_resp = {0};
> + struct snp_msg_desc *mdesc = snp_dev->msg_desc;
> int rc, resp_len;
> /* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
> u8 buf[64 + 16];
> @@ -466,7 +454,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
> * response payload. Make sure that it has enough space to cover the
> * authtag.
> */
> - resp_len = sizeof(derived_key_resp.data) + snp_dev->ctx->authsize;
> + resp_len = sizeof(derived_key_resp.data) + mdesc->ctx->authsize;
> if (sizeof(buf) < resp_len)
> return -ENOMEM;
>
> @@ -474,7 +462,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
> sizeof(*derived_key_req)))
> return -EFAULT;
>
> - rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_KEY_REQ,
> + rc = handle_guest_request(mdesc, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_KEY_REQ,
> derived_key_req, sizeof(*derived_key_req), buf, resp_len);
> if (rc)
> return rc;
> @@ -495,6 +483,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
>
> {
> struct snp_ext_report_req *report_req = &snp_dev->req.ext_report;
> + struct snp_msg_desc *mdesc = snp_dev->msg_desc;
> struct snp_report_resp *report_resp;
> int ret, npages = 0, resp_len;
> sockptr_t certs_address;
> @@ -527,7 +516,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
> * the host. If host does not supply any certs in it, then copy
> * zeros to indicate that certificate data was not provided.
> */
> - memset(snp_dev->certs_data, 0, report_req->certs_len);
> + memset(mdesc->certs_data, 0, report_req->certs_len);
> npages = report_req->certs_len >> PAGE_SHIFT;
> cmd:
> /*
> @@ -535,19 +524,19 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
> * response payload. Make sure that it has enough space to cover the
> * authtag.
> */
> - resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize;
> + resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize;
> report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
> if (!report_resp)
> return -ENOMEM;
>
> - snp_dev->input.data_npages = npages;
> - ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ,
> + mdesc->input.data_npages = npages;
> + ret = handle_guest_request(mdesc, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ,
> &report_req->data, sizeof(report_req->data),
> report_resp->data, resp_len);
>
> /* If certs length is invalid then copy the returned length */
> if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
> - report_req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
> + report_req->certs_len = mdesc->input.data_npages << PAGE_SHIFT;
>
> if (copy_to_sockptr(io->req_data, report_req, sizeof(*report_req)))
> ret = -EFAULT;
> @@ -556,7 +545,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
> if (ret)
> goto e_free;
>
> - if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, report_req->certs_len)) {
> + if (npages && copy_to_sockptr(certs_address, mdesc->certs_data, report_req->certs_len)) {
> ret = -EFAULT;
> goto e_free;
> }
> @@ -572,6 +561,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
> static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long arg)
> {
> struct snp_guest_dev *snp_dev = to_snp_dev(file);
> + struct snp_msg_desc *mdesc = snp_dev->msg_desc;
> void __user *argp = (void __user *)arg;
> struct snp_guest_request_ioctl input;
> struct snp_req_resp io;
> @@ -587,7 +577,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
> return -EINVAL;
>
> /* Check if the VMPCK is not empty */
> - if (is_vmpck_empty(snp_dev)) {
> + if (is_vmpck_empty(mdesc)) {
> dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
> return -ENOTTY;
> }
> @@ -862,7 +852,7 @@ static int sev_report_new(struct tsm_report *report, void *data)
> return -ENOMEM;
>
> /* Check if the VMPCK is not empty */
> - if (is_vmpck_empty(snp_dev)) {
> + if (is_vmpck_empty(snp_dev->msg_desc)) {
> dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
> return -ENOTTY;
> }
> @@ -992,6 +982,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
> struct snp_secrets_page *secrets;
> struct device *dev = &pdev->dev;
> struct snp_guest_dev *snp_dev;
> + struct snp_msg_desc *mdesc;
> struct miscdevice *misc;
> void __iomem *mapping;
> int ret;
> @@ -1014,46 +1005,50 @@ static int __init sev_guest_probe(struct platform_device *pdev)
> if (!snp_dev)
> goto e_unmap;
>
> + mdesc = devm_kzalloc(&pdev->dev, sizeof(struct snp_msg_desc), GFP_KERNEL);
> + if (!mdesc)
> + goto e_unmap;
> +
> /* Adjust the default VMPCK key based on the executing VMPL level */
> if (vmpck_id == -1)
> vmpck_id = snp_vmpl;
>
> ret = -EINVAL;
> - snp_dev->vmpck = get_vmpck(vmpck_id, secrets, &snp_dev->os_area_msg_seqno);
> - if (!snp_dev->vmpck) {
> + mdesc->vmpck = get_vmpck(vmpck_id, secrets, &mdesc->os_area_msg_seqno);
> + if (!mdesc->vmpck) {
> dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id);
> goto e_unmap;
> }
>
> /* Verify that VMPCK is not zero. */
> - if (is_vmpck_empty(snp_dev)) {
> + if (is_vmpck_empty(mdesc)) {
> dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id);
> goto e_unmap;
> }
>
> platform_set_drvdata(pdev, snp_dev);
> snp_dev->dev = dev;
> - snp_dev->secrets = secrets;
> + mdesc->secrets = secrets;
>
> /* Ensure SNP guest messages do not span more than a page */
> BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
>
> /* Allocate the shared page used for the request and response message. */
> - snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> - if (!snp_dev->request)
> + mdesc->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> + if (!mdesc->request)
> goto e_unmap;
>
> - snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> - if (!snp_dev->response)
> + mdesc->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> + if (!mdesc->response)
> goto e_free_request;
>
> - snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
> - if (!snp_dev->certs_data)
> + mdesc->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
> + if (!mdesc->certs_data)
> goto e_free_response;
>
> ret = -EIO;
> - snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
> - if (!snp_dev->ctx)
> + mdesc->ctx = snp_init_crypto(mdesc->vmpck, VMPCK_KEY_LEN);
> + if (!mdesc->ctx)
> goto e_free_cert_data;
>
> misc = &snp_dev->misc;
> @@ -1062,9 +1057,9 @@ static int __init sev_guest_probe(struct platform_device *pdev)
> misc->fops = &snp_guest_fops;
>
> /* Initialize the input addresses for guest request */
> - snp_dev->input.req_gpa = __pa(snp_dev->request);
> - snp_dev->input.resp_gpa = __pa(snp_dev->response);
> - snp_dev->input.data_gpa = __pa(snp_dev->certs_data);
> + mdesc->input.req_gpa = __pa(mdesc->request);
> + mdesc->input.resp_gpa = __pa(mdesc->response);
> + mdesc->input.data_gpa = __pa(mdesc->certs_data);
>
> /* Set the privlevel_floor attribute based on the vmpck_id */
> sev_tsm_ops.privlevel_floor = vmpck_id;
> @@ -1081,17 +1076,18 @@ static int __init sev_guest_probe(struct platform_device *pdev)
> if (ret)
> goto e_free_ctx;
>
> + snp_dev->msg_desc = mdesc;
> dev_info(dev, "Initialized SEV guest driver (using VMPCK%d communication key)\n", vmpck_id);
> return 0;
>
> e_free_ctx:
> - kfree(snp_dev->ctx);
> + kfree(mdesc->ctx);
> e_free_cert_data:
> - free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
> + free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE);
> e_free_response:
> - free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
> + free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
> e_free_request:
> - free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
> + free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
> e_unmap:
> iounmap(mapping);
> return ret;
> @@ -1100,11 +1096,12 @@ static int __init sev_guest_probe(struct platform_device *pdev)
> static void __exit sev_guest_remove(struct platform_device *pdev)
> {
> struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
> + struct snp_msg_desc *mdesc = snp_dev->msg_desc;
>
> - free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
> - free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
> - free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
> - kfree(snp_dev->ctx);
> + free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE);
> + free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
> + free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
> + kfree(mdesc->ctx);
> misc_deregister(&snp_dev->misc);
> }
>