---
arch/x86/include/asm/sev.h | 6 +--
drivers/virt/coco/sev-guest/sev-guest.c | 63 +++++++++++++++----------
2 files changed, 42 insertions(+), 27 deletions(-)
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 91f08af31078..82d9250aac34 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -185,6 +185,9 @@ struct snp_guest_req {
unsigned int vmpck_id;
u8 msg_version;
u8 msg_type;
+
+ struct snp_req_data input;
+ void *certs_data;
};
/*
@@ -245,9 +248,6 @@ struct snp_msg_desc {
struct snp_guest_msg secret_request, secret_response;
struct snp_secrets_page *secrets;
- struct snp_req_data input;
-
- void *certs_data;
struct aesgcm_ctx *ctx;
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index af64e6191f74..480159606434 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -249,7 +249,7 @@ static int __handle_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_r
* sequence number must be incremented or the VMPCK must be deleted to
* prevent reuse of the IV.
*/
- rc = snp_issue_guest_request(req, &mdesc->input, rio);
+ rc = snp_issue_guest_request(req, &req->input, rio);
switch (rc) {
case -ENOSPC:
/*
@@ -259,7 +259,7 @@ static int __handle_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_r
* order to increment the sequence number and thus avoid
* IV reuse.
*/
- override_npages = mdesc->input.data_npages;
+ override_npages = req->input.data_npages;
req->exit_code = SVM_VMGEXIT_GUEST_REQUEST;
/*
@@ -315,7 +315,7 @@ static int __handle_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_r
}
if (override_npages)
- mdesc->input.data_npages = override_npages;
+ req->input.data_npages = override_npages;
return rc;
}
@@ -354,6 +354,11 @@ static int snp_send_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_r
memcpy(mdesc->request, &mdesc->secret_request,
sizeof(mdesc->secret_request));
+ /* Initialize the input addresses for guest request */
+ req->input.req_gpa = __pa(mdesc->request);
+ req->input.resp_gpa = __pa(mdesc->response);
+ req->input.data_gpa = req->certs_data ? __pa(req->certs_data) : 0;
+
rc = __handle_guest_request(mdesc, req, rio);
if (rc) {
if (rc == -EIO &&
@@ -495,6 +500,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
struct snp_guest_req req = {};
int ret, npages = 0, resp_len;
sockptr_t certs_address;
+ struct page *page;
if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data))
return -EINVAL;
@@ -528,8 +534,20 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
* the host. If host does not supply any certs in it, then copy
* zeros to indicate that certificate data was not provided.
*/
- memset(mdesc->certs_data, 0, report_req->certs_len);
npages = report_req->certs_len >> PAGE_SHIFT;
+ page = alloc_pages(GFP_KERNEL_ACCOUNT | __GFP_ZERO,
+ get_order(report_req->certs_len));
+ if (!page)
+ return -ENOMEM;
+
+ req.certs_data = page_address(page);
+ ret = set_memory_decrypted((unsigned long)req.certs_data, npages);
+ if (ret) {
+ pr_err("failed to mark page shared, ret=%d\n", ret);
+ __free_pages(page, get_order(report_req->certs_len));
+ return -EFAULT;
+ }
+
cmd:
/*
* The intermediate response buffer is used while decrypting the
@@ -538,10 +556,12 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
*/
resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize;
report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
- if (!report_resp)
- return -ENOMEM;
+ if (!report_resp) {
+ ret = -ENOMEM;
+ goto e_free_data;
+ }
- mdesc->input.data_npages = npages;
+ req.input.data_npages = npages;
req.msg_version = arg->msg_version;
req.msg_type = SNP_MSG_REPORT_REQ;
@@ -556,7 +576,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
/* If certs length is invalid then copy the returned length */
if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
- report_req->certs_len = mdesc->input.data_npages << PAGE_SHIFT;
+ report_req->certs_len = req.input.data_npages << PAGE_SHIFT;
if (copy_to_sockptr(io->req_data, report_req, sizeof(*report_req)))
ret = -EFAULT;
@@ -565,7 +585,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
if (ret)
goto e_free;
- if (npages && copy_to_sockptr(certs_address, mdesc->certs_data, report_req->certs_len)) {
+ if (npages && copy_to_sockptr(certs_address, req.certs_data, report_req->certs_len)) {
ret = -EFAULT;
goto e_free;
}
@@ -575,6 +595,13 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
e_free:
kfree(report_resp);
+e_free_data:
+ if (npages) {
+ if (set_memory_encrypted((unsigned long)req.certs_data, npages))
+ WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
+ else
+ __free_pages(page, get_order(report_req->certs_len));
+ }
return ret;
}
@@ -1048,35 +1075,26 @@ static int __init sev_guest_probe(struct platform_device *pdev)
if (!mdesc->response)
goto e_free_request;
- mdesc->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
- if (!mdesc->certs_data)
- goto e_free_response;
-
ret = -EIO;
mdesc->ctx = snp_init_crypto(mdesc->vmpck, VMPCK_KEY_LEN);
if (!mdesc->ctx)
- goto e_free_cert_data;
+ goto e_free_response;
misc = &snp_dev->misc;
misc->minor = MISC_DYNAMIC_MINOR;
misc->name = DEVICE_NAME;
misc->fops = &snp_guest_fops;
- /* Initialize the input addresses for guest request */
- mdesc->input.req_gpa = __pa(mdesc->request);
- mdesc->input.resp_gpa = __pa(mdesc->response);
- mdesc->input.data_gpa = __pa(mdesc->certs_data);
-
/* Set the privlevel_floor attribute based on the vmpck_id */
sev_tsm_ops.privlevel_floor = vmpck_id;
ret = tsm_register(&sev_tsm_ops, snp_dev);
if (ret)
- goto e_free_cert_data;
+ goto e_free_response;
ret = devm_add_action_or_reset(&pdev->dev, unregister_sev_tsm, NULL);
if (ret)
- goto e_free_cert_data;
+ goto e_free_response;
ret = misc_register(misc);
if (ret)
@@ -1088,8 +1106,6 @@ static int __init sev_guest_probe(struct platform_device *pdev)
e_free_ctx:
kfree(mdesc->ctx);
-e_free_cert_data:
- free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE);
e_free_response:
free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
e_free_request:
@@ -1104,7 +1120,6 @@ static void __exit sev_guest_remove(struct platform_device *pdev)
struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
struct snp_msg_desc *mdesc = snp_dev->msg_desc;
- free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE);
free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
kfree(mdesc->ctx);