diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index 3031b42c27a6..b838cfa984ad 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -91,21 +91,66 @@ void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
WARN_ON(!root->tdp_mmu_page);
- spin_lock(&kvm->arch.tdp_mmu_pages_lock);
- list_del_rcu(&root->link);
- spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
+ /*
+ * Ensure root->role.invalid is read after the refcount reaches zero to
+ * avoid zapping the root multiple times, e.g. if a different task
+ * acquires a reference (after the root was marked invalid) and puts
+ * the last reference, all while holding mmu_lock for read. Pairs
+ * with the smp_mb__before_atomic() below.
+ */
+ smp_mb__after_atomic();
+
+ /*
+ * Free the root if it's already invalid. Invalid roots must be zapped
+ * before their last reference is put, i.e. there's no work to be done,
+ * and all roots must be invalidated (see below) before they're freed.
+ * Re-zapping invalid roots would put KVM into an infinite loop (again,
+ * see below).
+ */
+ if (root->role.invalid) {
+ spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+ list_del_rcu(&root->link);
+ spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
+
+ call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
+ return;
+ }
+
+ /*
+ * Invalidate the root to prevent it from being reused by a vCPU, and
+ * so that KVM doesn't re-zap the root when its last reference is put
+ * again (see above).
+ */
+ root->role.invalid = true;
+
+ /*
+ * Ensure role.invalid is visible if a concurrent reader acquires a
+ * reference after the root's refcount is reset. Pairs with the
+ * smp_mb__after_atomic() above.
+ */
+ smp_mb__before_atomic();
+ /*
+ * Note, if mmu_lock is held for read this can race with other readers,
+ * e.g. they may acquire a reference without seeing the root as invalid,
+ * and the refcount may be reset after the root is skipped. Both races
+ * are benign, as flows that must visit all roots, e.g. need to zap
+ * SPTEs for correctness, must take mmu_lock for write to block page
+ * faults, and the only flow that must not consume an invalid root is
+ * allocating a new root for a vCPU, which also takes mmu_lock for write.
+ */
+ refcount_set(&root->tdp_mmu_root_count, 1);
/*
- * A TLB flush is not necessary as KVM performs a local TLB flush when
- * allocating a new root (see kvm_mmu_load()), and when migrating vCPU
- * to a different pCPU. Note, the local TLB flush on reuse also
- * invalidates any paging-structure-cache entries, i.e. TLB entries for
- * intermediate paging structures, that may be zapped, as such entries
- * are associated with the ASID on both VMX and SVM.
+ * Zap the root, then put the refcount "acquired" above. Recursively
+ * call kvm_tdp_mmu_put_root() to test the above logic for avoiding an
+ * infinite loop by freeing invalid roots. By design, the root is
+ * reachable while it's being zapped, thus a different task can put its
+ * last reference, i.e. flowing through kvm_tdp_mmu_put_root() for a
+ * defunct root is unavoidable.
*/
tdp_mmu_zap_root(kvm, root, shared);
-
- call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
+ kvm_tdp_mmu_put_root(kvm, root, shared);
}
enum tdp_mmu_roots_iter_type {