Currently, the SEV guest driver stores pointers to the VMPCK and to os_area_msg_seqno, both obtained from the secrets page. To remove this dependency, store only vmpck_id and use it to index the appropriate key and the corresponding message sequence number in the secrets page. As part of this, the vmpck_id module parameter becomes unsigned, with VMPCK_MAX_NUM (instead of -1) meaning "use the key matching the current VMPL", and a disabled key is detected by its all-zero contents rather than by a NULL pointer. Signed-off-by: Nikunj A Dadhania <nikunj@xxxxxxx> --- drivers/virt/coco/sev-guest/sev-guest.c | 74 ++++++++++++------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c index a5602c84769f..fcd61df08702 100644 --- a/drivers/virt/coco/sev-guest/sev-guest.c +++ b/drivers/virt/coco/sev-guest/sev-guest.c @@ -58,8 +58,7 @@ struct snp_guest_dev { struct snp_derived_key_req derived_key; struct snp_ext_report_req ext_report; } req; - u32 *os_area_msg_seqno; - u8 *vmpck; + unsigned int vmpck_id; }; /* @@ -69,21 +68,24 @@ struct snp_guest_dev { * Should the default key be wiped (see snp_disable_vmpck()), this parameter * allows for using one of the remaining VMPCKs. */ -static int vmpck_id = -1; -module_param(vmpck_id, int, 0444); +static u32 vmpck_id = VMPCK_MAX_NUM; +module_param(vmpck_id, uint, 0444); MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP."); /* Mutex to serialize the shared buffer access and command handling.
*/ static DEFINE_MUTEX(snp_cmd_mutex); +static inline u8 *get_vmpck(struct snp_guest_dev *snp_dev) +{ + return snp_dev->secrets->vmpck[snp_dev->vmpck_id]; +} + static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) { char zero_key[VMPCK_KEY_LEN] = {0}; + u8 *key = get_vmpck(snp_dev); - if (snp_dev->vmpck) - return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN); - - return true; + return !memcmp(key, zero_key, VMPCK_KEY_LEN); } /* @@ -105,28 +107,24 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) */ static void snp_disable_vmpck(struct snp_guest_dev *snp_dev) { - dev_alert(snp_dev->dev, "Disabling VMPCK%d communication key to prevent IV reuse.\n", - vmpck_id); - memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN); - snp_dev->vmpck = NULL; -} - -static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev) -{ - u64 count; - - lockdep_assert_held(&snp_cmd_mutex); + u8 *key = get_vmpck(snp_dev); - /* Read the current message sequence counter from secrets pages */ - count = *snp_dev->os_area_msg_seqno; + if (is_vmpck_empty(snp_dev)) + return; - return count + 1; + dev_alert(snp_dev->dev, "Disabling VMPCK%u communication key to prevent IV reuse.\n", + snp_dev->vmpck_id); + memzero_explicit(key, VMPCK_KEY_LEN); } /* Return a non-zero on success */ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev) { - u64 count = __snp_get_msg_seqno(snp_dev); + u64 count; + + lockdep_assert_held(&snp_cmd_mutex); + + count = snp_dev->secrets->os_area.msg_seqno[snp_dev->vmpck_id] + 1; /* * The message sequence counter for the SNP guest request is a 64-bit @@ -150,7 +148,7 @@ static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev) * The counter is also incremented by the PSP, so increment it by 2 * and save in secrets page. 
*/ - *snp_dev->os_area_msg_seqno += 2; + snp_dev->secrets->os_area.msg_seqno[snp_dev->vmpck_id] += 2; } static inline struct snp_guest_dev *to_snp_dev(struct file *file) @@ -160,15 +158,17 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file) return container_of(dev, struct snp_guest_dev, misc); } -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen) +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev) { struct aesgcm_ctx *ctx; + u8 *key; ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT); if (!ctx) return NULL; - if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) { + key = get_vmpck(snp_dev); + if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) { pr_err("Crypto context initialization failed\n"); kfree(ctx); return NULL; @@ -676,13 +676,14 @@ static const struct file_operations snp_guest_fops = { .unlocked_ioctl = snp_guest_ioctl, }; -static u8 *get_vmpck(int id, struct snp_secrets_page *secrets, u32 **seqno) +static bool assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id) { - if (!(id < VMPCK_MAX_NUM)) - return NULL; + if (!(vmpck_id < VMPCK_MAX_NUM)) + return false; + + dev->vmpck_id = vmpck_id; - *seqno = &secrets->os_area.msg_seqno[id]; - return secrets->vmpck[id]; + return true; } struct snp_msg_report_resp_hdr { @@ -1015,25 +1016,24 @@ static int __init sev_guest_probe(struct platform_device *pdev) goto e_unmap; /* Adjust the default VMPCK key based on the executing VMPL level */ - if (vmpck_id == -1) + if (vmpck_id == VMPCK_MAX_NUM) vmpck_id = snp_vmpl; ret = -EINVAL; - snp_dev->vmpck = get_vmpck(vmpck_id, secrets, &snp_dev->os_area_msg_seqno); - if (!snp_dev->vmpck) { + snp_dev->secrets = secrets; + if (!assign_vmpck(snp_dev, vmpck_id)) { dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id); goto e_unmap; } /* Verify that VMPCK is not zero. 
*/ if (is_vmpck_empty(snp_dev)) { - dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id); + dev_err(dev, "Empty VMPCK%d communication key\n", snp_dev->vmpck_id); goto e_unmap; } platform_set_drvdata(pdev, snp_dev); snp_dev->dev = dev; - snp_dev->secrets = secrets; /* Allocate secret request and response message for double buffering */ snp_dev->secret_request = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL); @@ -1058,7 +1058,7 @@ static int __init sev_guest_probe(struct platform_device *pdev) goto e_free_response; ret = -EIO; - snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN); + snp_dev->ctx = snp_init_crypto(snp_dev); if (!snp_dev->ctx) goto e_free_cert_data; -- 2.34.1