[PATCH v9 07/24] virt: sev-guest: Store VMPCK index to SNP guest device structure

From: Nikunj A Dadhania
Date: Fri May 31 2024 - 00:33:32 EST


Currently, the SEV guest driver retrieves pointers to the VMPCK and to
os_area_msg_seqno from the secrets page at probe time and caches them in
struct snp_guest_dev. To get rid of this dependency, store the VMPCK index
(vmpck_id) in struct snp_guest_dev instead, and use it to index the
appropriate key and the corresponding message sequence number.

Signed-off-by: Nikunj A Dadhania <nikunj@xxxxxxx>
---
drivers/virt/coco/sev-guest/sev-guest.c | 67 ++++++++++++-------------
1 file changed, 33 insertions(+), 34 deletions(-)

diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index a3c0b22d2e14..0729d0b73495 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -55,8 +55,7 @@ struct snp_guest_dev {
struct snp_derived_key_req derived_key;
struct snp_ext_report_req ext_report;
} req;
- u32 *os_area_msg_seqno;
- u8 *vmpck;
+ unsigned int vmpck_id; /* index into secrets->vmpck[] and secrets->os_area.msg_seqno[] */
};

static u32 vmpck_id;
@@ -66,14 +65,17 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
/* Mutex to serialize the shared buffer access and command handling. */
static DEFINE_MUTEX(snp_cmd_mutex);

+static inline u8 *get_vmpck(struct snp_guest_dev *snp_dev)
+{
+ return snp_dev->secrets->vmpck[snp_dev->vmpck_id]; /* no bounds check here: relies on assign_vmpck() at probe */
+}
+
static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
{
char zero_key[VMPCK_KEY_LEN] = {0};
+ u8 *key = get_vmpck(snp_dev);

- if (snp_dev->vmpck)
- return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
-
- return true;
+ return !memcmp(key, zero_key, VMPCK_KEY_LEN); /* all-zero key is treated as unusable */
}

/*
@@ -95,28 +97,23 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
*/
static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
{
- dev_alert(snp_dev->dev, "Disabling VMPCK%d to prevent IV reuse.\n",
- vmpck_id);
- memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
- snp_dev->vmpck = NULL;
-}
-
-static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
-{
- u64 count;
-
- lockdep_assert_held(&snp_cmd_mutex);
+ u8 *key = get_vmpck(snp_dev);

- /* Read the current message sequence counter from secrets pages */
- count = *snp_dev->os_area_msg_seqno;
+ if (is_vmpck_empty(snp_dev)) /* already zeroed: nothing left to disable */
+ return;

- return count + 1;
+ dev_alert(snp_dev->dev, "Disabling VMPCK%u to prevent IV reuse.\n", snp_dev->vmpck_id);
+ memzero_explicit(key, VMPCK_KEY_LEN); /* wipes the key in the secrets page itself */
}

/* Return a non-zero on success */
static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
{
- u64 count = __snp_get_msg_seqno(snp_dev);
+ u64 count;
+
+ lockdep_assert_held(&snp_cmd_mutex);
+
+ count = snp_dev->secrets->os_area.msg_seqno[snp_dev->vmpck_id] + 1; /* this VMPCK's counter from the secrets page */

/*
* The message sequence counter for the SNP guest request is a 64-bit
@@ -140,7 +137,7 @@ static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
* The counter is also incremented by the PSP, so increment it by 2
* and save in secrets page.
*/
- *snp_dev->os_area_msg_seqno += 2;
+ snp_dev->secrets->os_area.msg_seqno[snp_dev->vmpck_id] += 2; /* same slot snp_get_msg_seqno() reads */
}

static inline struct snp_guest_dev *to_snp_dev(struct file *file)
@@ -150,15 +147,17 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
return container_of(dev, struct snp_guest_dev, misc);
}

-static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
+static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
{
struct aesgcm_ctx *ctx;
+ u8 *key;

ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
if (!ctx)
return NULL;

- if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+ key = get_vmpck(snp_dev); /* VMPCK selected by snp_dev->vmpck_id */
+ if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
pr_err("Crypto context initialization failed\n");
kfree(ctx);
return NULL;
@@ -666,13 +665,14 @@ static const struct file_operations snp_guest_fops = {
.unlocked_ioctl = snp_guest_ioctl,
};

-static u8 *get_vmpck(int id, struct snp_secrets_page *secrets, u32 **seqno)
+static bool assign_vmpck(struct snp_guest_dev *dev, unsigned int id)
{
- if ((id + 1) > VMPCK_MAX_NUM)
- return NULL;
+ /* Not "(id + 1) > VMPCK_MAX_NUM": that wraps to 0 for id == UINT_MAX. */
+ if (id >= VMPCK_MAX_NUM)
+ return false;
+ dev->vmpck_id = id; /* later used to index secrets->vmpck[] and os_area.msg_seqno[] */

- *seqno = &secrets->os_area.msg_seqno[id];
- return secrets->vmpck[id];
+ return true;
}

struct snp_msg_report_resp_hdr {
@@ -828,21 +828,20 @@ static int __init sev_guest_probe(struct platform_device *pdev)
goto e_unmap;

ret = -EINVAL;
- snp_dev->vmpck = get_vmpck(vmpck_id, secrets, &snp_dev->os_area_msg_seqno);
- if (!snp_dev->vmpck) {
+ snp_dev->secrets = secrets; /* must be set first: is_vmpck_empty() below reads the key through it */
+ if (!assign_vmpck(snp_dev, vmpck_id)) {
dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id);
goto e_unmap;
}

/* Verify that VMPCK is not zero. */
if (is_vmpck_empty(snp_dev)) {
- dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id);
+ dev_err(dev, "Empty VMPCK%u communication key\n", snp_dev->vmpck_id);
goto e_unmap;
}

platform_set_drvdata(pdev, snp_dev);
snp_dev->dev = dev;
- snp_dev->secrets = secrets;

/* Allocate secret request and response message for double buffering */
snp_dev->secret_request = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL);
@@ -867,7 +866,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
goto e_free_response;

ret = -EIO;
- snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
+ snp_dev->ctx = snp_init_crypto(snp_dev); /* keyed from the VMPCK at snp_dev->vmpck_id */
if (!snp_dev->ctx)
goto e_free_cert_data;

--
2.34.1