Simplify get_vmpck() and prepare it to be used as an API: replace it
with snp_assign_vmpck(), which updates the snp_guest_dev structure
directly. Add vmpck_id to struct snp_guest_dev so that it can be used by
the SNP guest request API, removing the need to use the vmpck_id
command line parameter directly.
Signed-off-by: Nikunj A Dadhania <nikunj@xxxxxxx>
---
drivers/virt/coco/sev-guest/sev-guest.c | 41 ++++++++-----------------
1 file changed, 12 insertions(+), 29 deletions(-)
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index ec93dee330f2..4901ebc8fa1a 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -45,6 +45,7 @@ struct snp_guest_dev {
struct snp_req_data input;
u32 *os_area_msg_seqno;
u8 *vmpck;
+ u8 vmpck_id;
};
static u32 vmpck_id;
@@ -80,7 +81,7 @@ static inline unsigned int get_ctx_authsize(struct snp_guest_dev *snp_dev)
static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
{
dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
- vmpck_id);
+ snp_dev->vmpck_id);
memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
snp_dev->vmpck = NULL;
}
@@ -339,7 +340,7 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
/* Encrypt the userspace provided payload */
- rc = enc_payload(snp_dev, seqno, req, vmpck_id);
+ rc = enc_payload(snp_dev, seqno, req, snp_dev->vmpck_id);
if (rc)
return rc;
@@ -364,7 +365,6 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues
return 0;
}
-
static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, u8 msg_version,
u8 msg_type, void *req_buf, size_t req_sz, void *resp_buf,
u32 resp_sz, __u64 *fw_err)
@@ -625,32 +625,27 @@ static const struct file_operations snp_guest_fops = {
.unlocked_ioctl = snp_guest_ioctl,
};
+/**
+ * snp_assign_vmpck() - Select the VMPCK key and message sequence number.
+ * @dev: guest device; dev->layout must already point at the secrets page.
+ * @vmpck_id: VMPCK index, 0 to 3.
+ *
+ * Assumes vmpck0..vmpck3 and msg_seqno_0..msg_seqno_3 are laid out
+ * contiguously in struct snp_secrets_page_layout.
+ *
+ * Return: true on success, false if @vmpck_id is out of range.
+ */
-static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
+static bool snp_assign_vmpck(struct snp_guest_dev *dev, int vmpck_id)
{
- u8 *key = NULL;
+ /* No WARN_ON(): the id can come from the vmpck_id parameter. */
+ if (vmpck_id < 0 || vmpck_id > 3)
+ return false;
- switch (id) {
- case 0:
- *seqno = &layout->os_area.msg_seqno_0;
- key = layout->vmpck0;
- break;
- case 1:
- *seqno = &layout->os_area.msg_seqno_1;
- key = layout->vmpck1;
- break;
- case 2:
- *seqno = &layout->os_area.msg_seqno_2;
- key = layout->vmpck2;
- break;
- case 3:
- *seqno = &layout->os_area.msg_seqno_3;
- key = layout->vmpck3;
- break;
- default:
- break;
- }
+ dev->vmpck_id = vmpck_id;
+ dev->vmpck = dev->layout->vmpck0 + vmpck_id * VMPCK_KEY_LEN;
+ dev->os_area_msg_seqno = &dev->layout->os_area.msg_seqno_0 + vmpck_id;
- return key;
+ return true;
}
static int __init sev_guest_probe(struct platform_device *pdev)
@@ -682,8 +666,8 @@ static int __init sev_guest_probe(struct platform_device *pdev)
goto e_unmap;
ret = -EINVAL;
- snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
- if (!snp_dev->vmpck) {
+ snp_dev->layout = layout;
+ if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
goto e_unmap;
}
@@ -697,7 +681,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
mutex_init(&snp_dev->cmd_mutex);
platform_set_drvdata(pdev, snp_dev);
snp_dev->dev = dev;
- snp_dev->layout = layout;
/* Allocate the shared page used for the request and response message. */
snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));