Currently, the guest message is PAGE_SIZE bytes and the payload is
hard-coded to 4000 bytes, which assumes that the snp_guest_msg_hdr
structure is 96 bytes. Remove the structure size assumption and the
hard-coded payload size, and make the payload a flexible array member
instead. Because a structure ending in a flexible array member cannot
hold its payload inline, allocate the secret request and response
buffers (SNP_GUEST_MSG_SIZE bytes each) dynamically at probe time.

While at it, rename the local guest message variables for clarity.

Signed-off-by: Nikunj A Dadhania <nikunj@xxxxxxx>
Suggested-by: Tom Lendacky <thomas.lendacky@xxxxxxx>
---
 drivers/virt/coco/sev-guest/sev-guest.h |  5 +-
 drivers/virt/coco/sev-guest/sev-guest.c | 74 +++++++++++++++----------
 2 files changed, 48 insertions(+), 31 deletions(-)

diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
index ceb798a404d6..97796f658fd3 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/drivers/virt/coco/sev-guest/sev-guest.h
@@ -60,7 +60,10 @@ struct snp_guest_msg_hdr {
 
 struct snp_guest_msg {
 	struct snp_guest_msg_hdr hdr;
-	u8 payload[4000];
+	u8 payload[];
 } __packed;
 
+#define SNP_GUEST_MSG_SIZE 4096
+#define SNP_GUEST_MSG_PAYLOAD_SIZE (SNP_GUEST_MSG_SIZE - sizeof(struct snp_guest_msg))
+
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 7e1bf2056b47..69bd817239d8 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -48,7 +48,7 @@ struct snp_guest_dev {
 	 * Avoid information leakage by double-buffering shared messages
 	 * in fields that are in regular encrypted memory.
 	 */
-	struct snp_guest_msg secret_request, secret_response;
+	struct snp_guest_msg *secret_request, *secret_response;
 
 	struct snp_secrets_page *secrets;
 	struct snp_req_data input;
@@ -171,40 +171,40 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
 
 static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
 {
-	struct snp_guest_msg *resp = &snp_dev->secret_response;
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
-	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+	struct snp_guest_msg *resp_msg = snp_dev->secret_response;
+	struct snp_guest_msg *req_msg = snp_dev->secret_request;
+	struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr;
+	struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
 	pr_debug("response [seqno %lld type %d version %d sz %d]\n",
-		 resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
-		 resp_hdr->msg_sz);
+		 resp_msg_hdr->msg_seqno, resp_msg_hdr->msg_type, resp_msg_hdr->msg_version,
+		 resp_msg_hdr->msg_sz);
 
 	/* Copy response from shared memory to encrypted memory. */
-	memcpy(resp, snp_dev->response, sizeof(*resp));
+	memcpy(resp_msg, snp_dev->response, SNP_GUEST_MSG_SIZE);
 
 	/* Verify that the sequence counter is incremented by 1 */
-	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
+	if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1)))
 		return -EBADMSG;
 
 	/* Verify response message type and version number. */
-	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
-	    resp_hdr->msg_version != req_hdr->msg_version)
+	if (resp_msg_hdr->msg_type != (req_msg_hdr->msg_type + 1) ||
+	    resp_msg_hdr->msg_version != req_msg_hdr->msg_version)
 		return -EBADMSG;
 
 	/*
 	 * If the message size is greater than our buffer length then return
 	 * an error.
 	 */
-	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
+	if (unlikely((resp_msg_hdr->msg_sz + ctx->authsize) > sz))
 		return -EBADMSG;
 
 	/* Decrypt the payload */
-	memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
-	if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
-			    &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
+	memcpy(iv, &resp_msg_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_msg_hdr->msg_seqno)));
+	if (!aesgcm_decrypt(ctx, payload, resp_msg->payload, resp_msg_hdr->msg_sz,
+			    &resp_msg_hdr->algo, AAD_LEN, iv, resp_msg_hdr->authtag))
 		return -EBADMSG;
 
 	return 0;
@@ -213,12 +213,12 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
 static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
 		       void *payload, size_t sz)
 {
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *hdr = &req->hdr;
+	struct snp_guest_msg *msg = snp_dev->secret_request;
+	struct snp_guest_msg_hdr *hdr = &msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
-	memset(req, 0, sizeof(*req));
+	memset(msg, 0, SNP_GUEST_MSG_SIZE);
 
 	hdr->algo = SNP_AEAD_AES_256_GCM;
 	hdr->hdr_version = MSG_HDR_VER;
@@ -236,11 +236,11 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
 	pr_debug("request [seqno %lld type %d version %d sz %d]\n",
 		 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
 
-	if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
+	if (WARN_ON((sz + ctx->authsize) > SNP_GUEST_MSG_PAYLOAD_SIZE))
 		return -EBADMSG;
 
 	memcpy(iv, &hdr->msg_seqno, min(sizeof(iv), sizeof(hdr->msg_seqno)));
-	aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
+	aesgcm_encrypt(ctx, msg->payload, payload, sz, &hdr->algo, AAD_LEN,
 		       iv, hdr->authtag);
 
 	return 0;
@@ -346,7 +346,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 		return -EIO;
 
 	/* Clear shared memory's response for the host to populate. */
-	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
+	memset(snp_dev->response, 0, SNP_GUEST_MSG_SIZE);
 
 	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
 	rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
@@ -357,8 +357,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	 * Write the fully encrypted request to the shared unencrypted
 	 * request page.
 	 */
-	memcpy(snp_dev->request, &snp_dev->secret_request,
-	       sizeof(snp_dev->secret_request));
+	memcpy(snp_dev->request, snp_dev->secret_request, SNP_GUEST_MSG_SIZE);
 
 	rc = __handle_guest_request(snp_dev, exit_code, rio);
 	if (rc) {
@@ -842,12 +841,21 @@ static int __init sev_guest_probe(struct platform_device *pdev)
 	snp_dev->dev = dev;
 	snp_dev->secrets = secrets;
 
+	/* Allocate secret request and response message for double buffering */
+	snp_dev->secret_request = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL);
+	if (!snp_dev->secret_request)
+		goto e_unmap;
+
+	snp_dev->secret_response = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL);
+	if (!snp_dev->secret_response)
+		goto e_free_secret_req;
+
 	/* Allocate the shared page used for the request and response message. */
-	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
+	snp_dev->request = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);
 	if (!snp_dev->request)
-		goto e_unmap;
+		goto e_free_secret_resp;
 
-	snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
+	snp_dev->response = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);
 	if (!snp_dev->response)
 		goto e_free_request;
 
@@ -890,9 +898,13 @@ static int __init sev_guest_probe(struct platform_device *pdev)
 e_free_cert_data:
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
 e_free_response:
-	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
+	free_shared_pages(snp_dev->response, SNP_GUEST_MSG_SIZE);
 e_free_request:
-	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
+	free_shared_pages(snp_dev->request, SNP_GUEST_MSG_SIZE);
+e_free_secret_resp:
+	kfree(snp_dev->secret_response);
+e_free_secret_req:
+	kfree(snp_dev->secret_request);
 e_unmap:
 	iounmap(mapping);
 	return ret;
@@ -903,8 +915,10 @@ static void __exit sev_guest_remove(struct platform_device *pdev)
 	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
 
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
-	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
-	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
+	free_shared_pages(snp_dev->response, SNP_GUEST_MSG_SIZE);
+	free_shared_pages(snp_dev->request, SNP_GUEST_MSG_SIZE);
+	kfree(snp_dev->secret_response);
+	kfree(snp_dev->secret_request);
 	kfree(snp_dev->ctx);
 	misc_deregister(&snp_dev->misc);
 }
-- 
2.34.1