Drop vmpck and os_area_msg_seqno pointers so that the secrets page layout does not need to be exposed to the sev-guest driver after the rework. Instead, add helper APIs to access vmpck and os_area_msg_seqno when needed. Also, rename is_vmpck_empty() to snp_is_vmpck_empty() in preparation for moving to sev.c. Signed-off-by: Nikunj A Dadhania <nikunj@xxxxxxx> Reviewed-by: Tom Lendacky <thomas.lendacky@xxxxxxx> --- drivers/virt/coco/sev-guest/sev-guest.c | 95 ++++++++++++------------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c index 1579140d43ec..0f2134deca51 100644 --- a/drivers/virt/coco/sev-guest/sev-guest.c +++ b/drivers/virt/coco/sev-guest/sev-guest.c @@ -59,22 +59,29 @@ struct snp_guest_dev { struct snp_derived_key_req derived_key; struct snp_ext_report_req ext_report; } req; - u32 *os_area_msg_seqno; - u8 *vmpck; + unsigned int vmpck_id; }; static u32 vmpck_id; module_param(vmpck_id, uint, 0444); MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP."); -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) +static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev) { - char zero_key[VMPCK_KEY_LEN] = {0}; + return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN; +} - if (snp_dev->vmpck) - return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN); +static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev) +{ + return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id; +} - return true; +static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev) +{ + char zero_key[VMPCK_KEY_LEN] = {0}; + u8 *key = snp_get_vmpck(snp_dev); + + return !memcmp(key, zero_key, VMPCK_KEY_LEN); } /* @@ -96,20 +103,22 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) */ static void snp_disable_vmpck(struct snp_guest_dev *snp_dev) { - dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV 
reuse.\n", - vmpck_id); - memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN); - snp_dev->vmpck = NULL; + u8 *key = snp_get_vmpck(snp_dev); + + dev_alert(snp_dev->dev, "Disabling vmpck_id %u to prevent IV reuse.\n", + snp_dev->vmpck_id); + memzero_explicit(key, VMPCK_KEY_LEN); } static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev) { + u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev); u64 count; lockdep_assert_held(&snp_dev->cmd_mutex); /* Read the current message sequence counter from secrets pages */ - count = *snp_dev->os_area_msg_seqno; + count = *os_area_msg_seqno; return count + 1; } @@ -137,11 +146,13 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev) static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev) { + u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev); + /* * The counter is also incremented by the PSP, so increment it by 2 * and save in secrets page. */ - *snp_dev->os_area_msg_seqno += 2; + *os_area_msg_seqno += 2; } static inline struct snp_guest_dev *to_snp_dev(struct file *file) @@ -151,15 +162,22 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file) return container_of(dev, struct snp_guest_dev, misc); } -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen) +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev) { struct aesgcm_ctx *ctx; + u8 *key; + + if (snp_is_vmpck_empty(snp_dev)) { + pr_err("VM communication key VMPCK%u is null\n", vmpck_id); + return NULL; + } ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT); if (!ctx) return NULL; - if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) { + key = snp_get_vmpck(snp_dev); + if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) { pr_err("Crypto context initialization failed\n"); kfree(ctx); return NULL; @@ -589,7 +607,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long mutex_lock(&snp_dev->cmd_mutex); /* Check if the VMPCK is not empty */ - if 
(is_vmpck_empty(snp_dev)) { + if (snp_is_vmpck_empty(snp_dev)) { dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); mutex_unlock(&snp_dev->cmd_mutex); return -ENOTTY; @@ -666,32 +684,14 @@ static const struct file_operations snp_guest_fops = { .unlocked_ioctl = snp_guest_ioctl, }; -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno) +bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id) { - u8 *key = NULL; + if (WARN_ON(vmpck_id > 3)) + return false; - switch (id) { - case 0: - *seqno = &layout->os_area.msg_seqno_0; - key = layout->vmpck0; - break; - case 1: - *seqno = &layout->os_area.msg_seqno_1; - key = layout->vmpck1; - break; - case 2: - *seqno = &layout->os_area.msg_seqno_2; - key = layout->vmpck2; - break; - case 3: - *seqno = &layout->os_area.msg_seqno_3; - key = layout->vmpck3; - break; - default: - break; - } + dev->vmpck_id = vmpck_id; - return key; + return true; } struct snp_msg_report_resp_hdr { @@ -727,7 +727,7 @@ static int sev_report_new(struct tsm_report *report, void *data) guard(mutex)(&snp_dev->cmd_mutex); /* Check if the VMPCK is not empty */ - if (is_vmpck_empty(snp_dev)) { + if (snp_is_vmpck_empty(snp_dev)) { dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); return -ENOTTY; } @@ -847,22 +847,21 @@ static int __init sev_guest_probe(struct platform_device *pdev) goto e_unmap; ret = -EINVAL; - snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno); - if (!snp_dev->vmpck) { - dev_err(dev, "invalid vmpck id %d\n", vmpck_id); + snp_dev->layout = layout; + if (!snp_assign_vmpck(snp_dev, vmpck_id)) { + dev_err(dev, "invalid vmpck id %u\n", vmpck_id); goto e_unmap; } /* Verify that VMPCK is not zero. 
*/ - if (is_vmpck_empty(snp_dev)) { - dev_err(dev, "vmpck id %d is null\n", vmpck_id); + if (snp_is_vmpck_empty(snp_dev)) { + dev_err(dev, "vmpck id %u is null\n", vmpck_id); goto e_unmap; } mutex_init(&snp_dev->cmd_mutex); platform_set_drvdata(pdev, snp_dev); snp_dev->dev = dev; - snp_dev->layout = layout; /* Allocate the shared page used for the request and response message. */ snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); @@ -878,7 +877,7 @@ static int __init sev_guest_probe(struct platform_device *pdev) goto e_free_response; ret = -EIO; - snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN); + snp_dev->ctx = snp_init_crypto(snp_dev); if (!snp_dev->ctx) goto e_free_cert_data; @@ -903,7 +902,7 @@ static int __init sev_guest_probe(struct platform_device *pdev) if (ret) goto e_free_ctx; - dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id); + dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", vmpck_id); return 0; e_free_ctx: -- 2.34.1