I looked at moving from kvm_arch_vcpu_pre_fault_memory() to directly calling kvm_tdp_map_page(), so that we could potentially put whatever check we need in kvm_arch_vcpu_pre_fault_memory(). It required a couple of exports:

diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index 03737f3aaeeb..9004ac597a85 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -277,6 +277,7 @@ extern bool tdp_mmu_enabled;
 #endif

 int kvm_tdp_mmu_get_walk_mirror_pfn(struct kvm_vcpu *vcpu, u64 gpa, kvm_pfn_t *pfn);
+int kvm_tdp_map_page(struct kvm_vcpu *vcpu, gpa_t gpa, u64 error_code, u8 *level);

 static inline bool kvm_memslots_have_rmaps(struct kvm *kvm)
 {
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 7bb6b17b455f..4a3e471ec9fe 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -4721,8 +4721,7 @@ int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault)
 	return direct_page_fault(vcpu, fault);
 }

-static int kvm_tdp_map_page(struct kvm_vcpu *vcpu, gpa_t gpa, u64 error_code,
-			    u8 *level)
+int kvm_tdp_map_page(struct kvm_vcpu *vcpu, gpa_t gpa, u64 error_code, u8 *level)
 {
 	int r;

@@ -4759,6 +4758,7 @@ static int kvm_tdp_map_page(struct kvm_vcpu *vcpu, gpa_t gpa, u64 error_code,
 		return -EIO;
 	}
 }
+EXPORT_SYMBOL_GPL(kvm_tdp_map_page);

 long kvm_arch_vcpu_pre_fault_memory(struct kvm_vcpu *vcpu,
 				    struct kvm_pre_fault_memory *range)
@@ -5770,6 +5770,7 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
 out:
 	return r;
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_load);

 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
diff --git a/arch/x86/kvm/vmx/tdx.c b/arch/x86/kvm/vmx/tdx.c
index 9ac0821eb44b..7161ef68f3da 100644
--- a/arch/x86/kvm/vmx/tdx.c
+++ b/arch/x86/kvm/vmx/tdx.c
@@ -2809,11 +2809,13 @@ struct tdx_gmem_post_populate_arg {
 static int tdx_gmem_post_populate(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
 				  void __user *src, int order, void *_arg)
 {
+	u64 error_code = PFERR_GUEST_FINAL_MASK | PFERR_PRIVATE_ACCESS;
 	struct kvm_tdx *kvm_tdx = to_kvm_tdx(kvm);
 	struct tdx_gmem_post_populate_arg *arg = _arg;
 	struct kvm_vcpu *vcpu = arg->vcpu;
 	struct kvm_memory_slot *slot;
 	gpa_t gpa = gfn_to_gpa(gfn);
+	u8 level = PG_LEVEL_4K;
 	struct page *page;
 	kvm_pfn_t mmu_pfn;
 	int ret, i;
@@ -2832,6 +2834,10 @@ static int tdx_gmem_post_populate(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
 		goto out_put_page;
 	}

+	ret = kvm_tdp_map_page(vcpu, gpa, error_code, &level);
+	if (ret < 0)
+		goto out_put_page;
+
 	read_lock(&kvm->mmu_lock);

 	ret = kvm_tdp_mmu_get_walk_mirror_pfn(vcpu, gpa, &mmu_pfn);
@@ -2910,6 +2916,7 @@ static int tdx_vcpu_init_mem_region(struct kvm_vcpu *vcpu, struct kvm_tdx_cmd *c
 	mutex_lock(&kvm->slots_lock);
 	idx = srcu_read_lock(&kvm->srcu);

+	kvm_mmu_reload(vcpu);
 	ret = 0;
 	while (region.nr_pages) {
 		if (signal_pending(current)) {
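
For clarity, here is a rough, consolidated sketch of the pre-map sequence that tdx_gmem_post_populate() ends up doing with the hunks above applied. It is only an illustration, not part of the diff: the helper name tdx_premap_gpa_sketch() is made up, the later populate work and its error handling are elided, and it assumes kvm_mmu_reload() has already run in tdx_vcpu_init_mem_region() (last hunk) so the vCPU's MMU root is valid before kvm_tdp_map_page() is called.

/*
 * Illustrative sketch only: fault the private GPA into the TDP MMU
 * directly, instead of bouncing through kvm_arch_vcpu_pre_fault_memory(),
 * then retrieve the pfn that the walk installed in the mirror EPT.
 */
static int tdx_premap_gpa_sketch(struct kvm_vcpu *vcpu, gpa_t gpa,
				 kvm_pfn_t *mmu_pfn)
{
	u64 error_code = PFERR_GUEST_FINAL_MASK | PFERR_PRIVATE_ACCESS;
	u8 level = PG_LEVEL_4K;
	int ret;

	/* Map the GPA at 4K granularity; returns a negative errno on failure. */
	ret = kvm_tdp_map_page(vcpu, gpa, error_code, &level);
	if (ret < 0)
		return ret;

	/*
	 * Walk the mirror EPT under mmu_lock for the pfn that was just
	 * installed; the diff continues with the actual populate work from
	 * this point while still holding the lock.
	 */
	read_lock(&vcpu->kvm->mmu_lock);
	ret = kvm_tdp_mmu_get_walk_mirror_pfn(vcpu, gpa, mmu_pfn);
	read_unlock(&vcpu->kvm->mmu_lock);

	return ret;
}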