Diffstat (limited to 'tools/testing/selftests/kvm/lib')
-rw-r--r--	tools/testing/selftests/kvm/lib/kvm_util.c		 24
-rw-r--r--	tools/testing/selftests/kvm/lib/loongarch/exception.S	 59
-rw-r--r--	tools/testing/selftests/kvm/lib/loongarch/processor.c	346
-rw-r--r--	tools/testing/selftests/kvm/lib/loongarch/ucall.c	 38
-rw-r--r--	tools/testing/selftests/kvm/lib/lru_gen_util.c		387
-rw-r--r--	tools/testing/selftests/kvm/lib/riscv/handlers.S	139
-rw-r--r--	tools/testing/selftests/kvm/lib/riscv/processor.c	  2
-rw-r--r--	tools/testing/selftests/kvm/lib/test_util.c		 42
-rw-r--r--	tools/testing/selftests/kvm/lib/x86/processor.c		  4
-rw-r--r--	tools/testing/selftests/kvm/lib/x86/sev.c		 76
10 files changed, 1024 insertions, 93 deletions
diff --git a/tools/testing/selftests/kvm/lib/kvm_util.c b/tools/testing/selftests/kvm/lib/kvm_util.c
index 815bc45dd8dc..a055343a7bf7 100644
--- a/tools/testing/selftests/kvm/lib/kvm_util.c
+++ b/tools/testing/selftests/kvm/lib/kvm_util.c
@@ -222,6 +222,7 @@ const char *vm_guest_mode_string(uint32_t i)
 		[VM_MODE_P36V48_4K]	= "PA-bits:36,  VA-bits:48,  4K pages",
 		[VM_MODE_P36V48_16K]	= "PA-bits:36,  VA-bits:48, 16K pages",
 		[VM_MODE_P36V48_64K]	= "PA-bits:36,  VA-bits:48, 64K pages",
+		[VM_MODE_P47V47_16K]	= "PA-bits:47,  VA-bits:47, 16K pages",
 		[VM_MODE_P36V47_16K]	= "PA-bits:36,  VA-bits:47, 16K pages",
 	};
 	_Static_assert(sizeof(strings)/sizeof(char *) == NUM_VM_MODES,
@@ -248,6 +249,7 @@ const struct vm_guest_mode_params vm_guest_mode_params[] = {
 	[VM_MODE_P36V48_4K]	= { 36, 48,  0x1000, 12 },
 	[VM_MODE_P36V48_16K]	= { 36, 48,  0x4000, 14 },
 	[VM_MODE_P36V48_64K]	= { 36, 48, 0x10000, 16 },
+	[VM_MODE_P47V47_16K]	= { 47, 47,  0x4000, 14 },
 	[VM_MODE_P36V47_16K]	= { 36, 47,  0x4000, 14 },
 };
 _Static_assert(sizeof(vm_guest_mode_params)/sizeof(struct vm_guest_mode_params) == NUM_VM_MODES,
@@ -319,6 +321,7 @@ struct kvm_vm *____vm_create(struct vm_shape shape)
 	case VM_MODE_P36V48_16K:
 		vm->pgtable_levels = 4;
 		break;
+	case VM_MODE_P47V47_16K:
 	case VM_MODE_P36V47_16K:
 		vm->pgtable_levels = 3;
 		break;
@@ -444,6 +447,15 @@ void kvm_set_files_rlimit(uint32_t nr_vcpus)
 }
 
+static bool is_guest_memfd_required(struct vm_shape shape)
+{
+#ifdef __x86_64__
+	return shape.type == KVM_X86_SNP_VM;
+#else
+	return false;
+#endif
+}
+
 struct kvm_vm *__vm_create(struct vm_shape shape, uint32_t nr_runnable_vcpus,
 			   uint64_t nr_extra_pages)
 {
@@ -451,7 +463,7 @@ struct kvm_vm *__vm_create(struct vm_shape shape, uint32_t nr_runnable_vcpus,
 						 nr_extra_pages);
 	struct userspace_mem_region *slot0;
 	struct kvm_vm *vm;
-	int i;
+	int i, flags;
 
 	kvm_set_files_rlimit(nr_runnable_vcpus);
 
@@ -460,7 +472,15 @@ struct kvm_vm *__vm_create(struct vm_shape shape, uint32_t nr_runnable_vcpus,
 
 	vm = ____vm_create(shape);
 
-	vm_userspace_mem_region_add(vm, VM_MEM_SRC_ANONYMOUS, 0, 0, nr_pages, 0);
+	/*
+	 * Force GUEST_MEMFD for the primary memory region if necessary, e.g.
+	 * for CoCo VMs that require GUEST_MEMFD backed private memory.
+	 */
+	flags = 0;
+	if (is_guest_memfd_required(shape))
+		flags |= KVM_MEM_GUEST_MEMFD;
+
+	vm_userspace_mem_region_add(vm, VM_MEM_SRC_ANONYMOUS, 0, 0, nr_pages, flags);
 
 	for (i = 0; i < NR_MEM_REGIONS; i++)
 		vm->memslots[i] = 0;
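The __vm_create() change above means callers no longer have to request KVM_MEM_GUEST_MEMFD themselves for VM types that mandate it. A minimal usage sketch, assuming the vm_shape definitions from kvm_util.h and a test-defined guest; this is illustrative, not part of the patch:

	struct vm_shape shape = {
		.mode = VM_MODE_DEFAULT,
		.type = KVM_X86_SNP_VM,	/* triggers is_guest_memfd_required() */
	};
	/* slot0 comes back GUEST_MEMFD-backed without any extra setup. */
	struct kvm_vm *vm = __vm_create(shape, 1, 0);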
diff --git a/tools/testing/selftests/kvm/lib/loongarch/exception.S b/tools/testing/selftests/kvm/lib/loongarch/exception.S
new file mode 100644
index 000000000000..88bfa505c6f5
--- /dev/null
+++ b/tools/testing/selftests/kvm/lib/loongarch/exception.S
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+
+#include "processor.h"
+
+/* address of refill exception should be 4K aligned */
+.balign	4096
+.global handle_tlb_refill
+handle_tlb_refill:
+	csrwr	t0, LOONGARCH_CSR_TLBRSAVE
+	csrrd	t0, LOONGARCH_CSR_PGD
+	lddir	t0, t0, 3
+	lddir	t0, t0, 1
+	ldpte	t0, 0
+	ldpte	t0, 1
+	tlbfill
+	csrrd	t0, LOONGARCH_CSR_TLBRSAVE
+	ertn
+
+	/*
+	 * save and restore all gprs except base register,
+	 * and default value of base register is sp ($r3).
+	 */
+.macro save_gprs base
+	.irp n,1,2,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
+	st.d    $r\n, \base, 8 * \n
+	.endr
+.endm
+
+.macro restore_gprs base
+	.irp n,1,2,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
+	ld.d    $r\n, \base, 8 * \n
+	.endr
+.endm
+
+/* address of general exception should be 4K aligned */
+.balign	4096
+.global handle_exception
+handle_exception:
+	csrwr  sp, LOONGARCH_CSR_KS0
+	csrrd  sp, LOONGARCH_CSR_KS1
+	addi.d sp, sp, -EXREGS_SIZE
+
+	save_gprs sp
+	/* save sp register to stack */
+	csrrd  t0, LOONGARCH_CSR_KS0
+	st.d   t0, sp, 3 * 8
+
+	csrrd  t0, LOONGARCH_CSR_ERA
+	st.d   t0, sp, PC_OFFSET_EXREGS
+	csrrd  t0, LOONGARCH_CSR_ESTAT
+	st.d   t0, sp, ESTAT_OFFSET_EXREGS
+	csrrd  t0, LOONGARCH_CSR_BADV
+	st.d   t0, sp, BADV_OFFSET_EXREGS
+
+	or     a0, sp, zero
+	bl route_exception
+	restore_gprs sp
+	csrrd  sp, LOONGARCH_CSR_KS0
+	ertn
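handle_exception stores GPR n at byte offset 8 * n from sp, so the frame it hands to route_exception() only lines up if struct ex_regs in processor.h uses the same layout. A sketch of the assumed structure (the pc/estat/badv field names are taken from the route_exception() accesses in processor.c below; the authoritative definition lives in processor.h):

	struct ex_regs {
		unsigned long regs[32];	/* $r0..$r31; $r3 (sp) patched in from KS0 */
		unsigned long pc;	/* CSR_ERA,   i.e. PC_OFFSET_EXREGS == 8 * 32 */
		unsigned long estat;	/* CSR_ESTAT, ESTAT_OFFSET_EXREGS */
		unsigned long badv;	/* CSR_BADV,  BADV_OFFSET_EXREGS */
	};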
diff --git a/tools/testing/selftests/kvm/lib/loongarch/processor.c b/tools/testing/selftests/kvm/lib/loongarch/processor.c
new file mode 100644
index 000000000000..0ac1abcb71cb
--- /dev/null
+++ b/tools/testing/selftests/kvm/lib/loongarch/processor.c
@@ -0,0 +1,346 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <assert.h>
+#include <linux/compiler.h>
+
+#include "kvm_util.h"
+#include "processor.h"
+#include "ucall_common.h"
+
+#define LOONGARCH_PAGE_TABLE_PHYS_MIN		0x200000
+#define LOONGARCH_GUEST_STACK_VADDR_MIN		0x200000
+
+static vm_paddr_t invalid_pgtable[4];
+
+static uint64_t virt_pte_index(struct kvm_vm *vm, vm_vaddr_t gva, int level)
+{
+	unsigned int shift;
+	uint64_t mask;
+
+	shift = level * (vm->page_shift - 3) + vm->page_shift;
+	mask = (1UL << (vm->page_shift - 3)) - 1;
+	return (gva >> shift) & mask;
+}
+
+static uint64_t pte_addr(struct kvm_vm *vm, uint64_t entry)
+{
+	return entry & ~((0x1UL << vm->page_shift) - 1);
+}
+
+static uint64_t ptrs_per_pte(struct kvm_vm *vm)
+{
+	return 1 << (vm->page_shift - 3);
+}
+
+static void virt_set_pgtable(struct kvm_vm *vm, vm_paddr_t table, vm_paddr_t child)
+{
+	uint64_t *ptep;
+	int i, ptrs_per_pte;
+
+	ptep = addr_gpa2hva(vm, table);
+	ptrs_per_pte = 1 << (vm->page_shift - 3);
+	for (i = 0; i < ptrs_per_pte; i++)
+		WRITE_ONCE(*(ptep + i), child);
+}
+
+void virt_arch_pgd_alloc(struct kvm_vm *vm)
+{
+	int i;
+	vm_paddr_t child, table;
+
+	if (vm->pgd_created)
+		return;
+
+	child = table = 0;
+	for (i = 0; i < vm->pgtable_levels; i++) {
+		invalid_pgtable[i] = child;
+		table = vm_phy_page_alloc(vm, LOONGARCH_PAGE_TABLE_PHYS_MIN,
+				vm->memslots[MEM_REGION_PT]);
+		TEST_ASSERT(table, "Fail to allocate page table at level %d\n", i);
+		virt_set_pgtable(vm, table, child);
+		child = table;
+	}
+	vm->pgd = table;
+	vm->pgd_created = true;
+}
+
+static int virt_pte_none(uint64_t *ptep, int level)
+{
+	return *ptep == invalid_pgtable[level];
+}
+
+static uint64_t *virt_populate_pte(struct kvm_vm *vm, vm_vaddr_t gva, int alloc)
+{
+	int level;
+	uint64_t *ptep;
+	vm_paddr_t child;
+
+	if (!vm->pgd_created)
+		goto unmapped_gva;
+
+	child = vm->pgd;
+	level = vm->pgtable_levels - 1;
+	while (level > 0) {
+		ptep = addr_gpa2hva(vm, child) + virt_pte_index(vm, gva, level) * 8;
+		if (virt_pte_none(ptep, level)) {
+			if (alloc) {
+				child = vm_alloc_page_table(vm);
+				virt_set_pgtable(vm, child, invalid_pgtable[level - 1]);
+				WRITE_ONCE(*ptep, child);
+			} else
+				goto unmapped_gva;
+
+		} else
+			child = pte_addr(vm, *ptep);
+		level--;
+	}
+
+	ptep = addr_gpa2hva(vm, child) + virt_pte_index(vm, gva, level) * 8;
+	return ptep;
+
+unmapped_gva:
+	TEST_FAIL("No mapping for vm virtual address, gva: 0x%lx", gva);
+	exit(EXIT_FAILURE);
+}
+
+vm_paddr_t addr_arch_gva2gpa(struct kvm_vm *vm, vm_vaddr_t gva)
+{
+	uint64_t *ptep;
+
+	ptep = virt_populate_pte(vm, gva, 0);
+	TEST_ASSERT(*ptep != 0, "Virtual address vaddr: 0x%lx not mapped\n", gva);
+
+	return pte_addr(vm, *ptep) + (gva & (vm->page_size - 1));
+}
+
+void virt_arch_pg_map(struct kvm_vm *vm, uint64_t vaddr, uint64_t paddr)
+{
+	uint32_t prot_bits;
+	uint64_t *ptep;
+
+	TEST_ASSERT((vaddr % vm->page_size) == 0,
+			"Virtual address not on page boundary,\n"
+			"vaddr: 0x%lx vm->page_size: 0x%x", vaddr, vm->page_size);
+	TEST_ASSERT(sparsebit_is_set(vm->vpages_valid,
+			(vaddr >> vm->page_shift)),
+			"Invalid virtual address, vaddr: 0x%lx", vaddr);
+	TEST_ASSERT((paddr % vm->page_size) == 0,
+			"Physical address not on page boundary,\n"
+			"paddr: 0x%lx vm->page_size: 0x%x", paddr, vm->page_size);
+	TEST_ASSERT((paddr >> vm->page_shift) <= vm->max_gfn,
+			"Physical address beyond maximum supported,\n"
+			"paddr: 0x%lx vm->max_gfn: 0x%lx vm->page_size: 0x%x",
+			paddr, vm->max_gfn, vm->page_size);
+
+	ptep = virt_populate_pte(vm, vaddr, 1);
+	prot_bits = _PAGE_PRESENT | __READABLE | __WRITEABLE | _CACHE_CC | _PAGE_USER;
+	WRITE_ONCE(*ptep, paddr | prot_bits);
+}
+
+static void pte_dump(FILE *stream, struct kvm_vm *vm, uint8_t indent, uint64_t page, int level)
+{
+	uint64_t pte, *ptep;
+	static const char * const type[] = { "pte", "pmd", "pud", "pgd"};
+
+	if (level < 0)
+		return;
+
+	for (pte = page; pte < page + ptrs_per_pte(vm) * 8; pte += 8) {
+		ptep = addr_gpa2hva(vm, pte);
+		if (virt_pte_none(ptep, level))
+			continue;
+		fprintf(stream, "%*s%s: %lx: %lx at %p\n",
+				indent, "", type[level], pte, *ptep, ptep);
+		pte_dump(stream, vm, indent + 1, pte_addr(vm, *ptep), level--);
+	}
+}
+
+void virt_arch_dump(FILE *stream, struct kvm_vm *vm, uint8_t indent)
+{
+	int level;
+
+	if (!vm->pgd_created)
+		return;
+
+	level = vm->pgtable_levels - 1;
+	pte_dump(stream, vm, indent, vm->pgd, level);
+}
+
+void vcpu_arch_dump(FILE *stream, struct kvm_vcpu *vcpu, uint8_t indent)
+{
+}
+
+void assert_on_unhandled_exception(struct kvm_vcpu *vcpu)
+{
+	struct ucall uc;
+
+	if (get_ucall(vcpu, &uc) != UCALL_UNHANDLED)
+		return;
+
+	TEST_FAIL("Unexpected exception (pc:0x%lx, estat:0x%lx, badv:0x%lx)",
+			uc.args[0], uc.args[1], uc.args[2]);
+}
+
+void route_exception(struct ex_regs *regs)
+{
+	unsigned long pc, estat, badv;
+
+	pc = regs->pc;
+	badv  = regs->badv;
+	estat = regs->estat;
+	ucall(UCALL_UNHANDLED, 3, pc, estat, badv);
+	while (1) ;
+}
+
+void vcpu_args_set(struct kvm_vcpu *vcpu, unsigned int num, ...)
+{
+	int i;
+	va_list ap;
+	struct kvm_regs regs;
+
+	TEST_ASSERT(num >= 1 && num <= 8, "Unsupported number of args,\n"
+		    "num: %u\n", num);
+
+	vcpu_regs_get(vcpu, &regs);
+
+	va_start(ap, num);
+	for (i = 0; i < num; i++)
+		regs.gpr[i + 4] = va_arg(ap, uint64_t);
+	va_end(ap);
+
+	vcpu_regs_set(vcpu, &regs);
+}
+
+static void loongarch_get_csr(struct kvm_vcpu *vcpu, uint64_t id, void *addr)
+{
+	uint64_t csrid;
+
+	csrid = KVM_REG_LOONGARCH_CSR | KVM_REG_SIZE_U64 | 8 * id;
+	__vcpu_get_reg(vcpu, csrid, addr);
+}
+
+static void loongarch_set_csr(struct kvm_vcpu *vcpu, uint64_t id, uint64_t val)
+{
+	uint64_t csrid;
+
+	csrid = KVM_REG_LOONGARCH_CSR | KVM_REG_SIZE_U64 | 8 * id;
+	__vcpu_set_reg(vcpu, csrid, val);
+}
+
+static void loongarch_vcpu_setup(struct kvm_vcpu *vcpu)
+{
+	int width;
+	unsigned long val;
+	struct kvm_vm *vm = vcpu->vm;
+
+	switch (vm->mode) {
+	case VM_MODE_P36V47_16K:
+	case VM_MODE_P47V47_16K:
+		break;
+
+	default:
+		TEST_FAIL("Unknown guest mode, mode: 0x%x", vm->mode);
+	}
+
+	/* user mode and page enable mode */
+	val = PLV_USER | CSR_CRMD_PG;
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_CRMD, val);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_PRMD, val);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_EUEN, 1);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_ECFG, 0);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_TCFG, 0);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_ASID, 1);
+
+	val = 0;
+	width = vm->page_shift - 3;
+
+	switch (vm->pgtable_levels) {
+	case 4:
+		/* pud page shift and width */
+		val = (vm->page_shift + width * 2) << 20 | (width << 25);
+		/* fall through */
+	case 3:
+		/* pmd page shift and width */
+		val |= (vm->page_shift + width) << 10 | (width << 15);
+		/* pte page shift and width */
+		val |= vm->page_shift | width << 5;
+		break;
+	default:
+		TEST_FAIL("Got %u page table levels, expected 3 or 4", vm->pgtable_levels);
+	}
+
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_PWCTL0, val);
+
+	/* PGD page shift and width */
+	val = (vm->page_shift + width * (vm->pgtable_levels - 1)) | width << 6;
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_PWCTL1, val);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_PGDL, vm->pgd);
+
+	/*
+	 * Refill exception runs in real mode, so the entry
+	 * address should be a physical address.
+	 */
+	val = addr_gva2gpa(vm, (unsigned long)handle_tlb_refill);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_TLBRENTRY, val);
+
+	/*
+	 * General exception runs in page-enabled mode, so the entry
+	 * address should be a virtual address.
+	 */
+	val = (unsigned long)handle_exception;
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_EENTRY, val);
+
+	loongarch_get_csr(vcpu, LOONGARCH_CSR_TLBIDX, &val);
+	val &= ~CSR_TLBIDX_SIZEM;
+	val |= PS_DEFAULT_SIZE << CSR_TLBIDX_SIZE;
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_TLBIDX, val);
+
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_STLBPGSIZE, PS_DEFAULT_SIZE);
+
+	/* LOONGARCH_CSR_KS1 is used for exception stack */
+	val = __vm_vaddr_alloc(vm, vm->page_size,
+			LOONGARCH_GUEST_STACK_VADDR_MIN, MEM_REGION_DATA);
+	TEST_ASSERT(val != 0, "No memory for exception stack");
+	val = val + vm->page_size;
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_KS1, val);
+
+	loongarch_get_csr(vcpu, LOONGARCH_CSR_TLBREHI, &val);
+	val &= ~CSR_TLBREHI_PS;
+	val |= PS_DEFAULT_SIZE << CSR_TLBREHI_PS_SHIFT;
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_TLBREHI, val);
+
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_CPUID, vcpu->id);
+	loongarch_set_csr(vcpu, LOONGARCH_CSR_TMID,  vcpu->id);
+}
+
+struct kvm_vcpu *vm_arch_vcpu_add(struct kvm_vm *vm, uint32_t vcpu_id)
+{
+	size_t stack_size;
+	uint64_t stack_vaddr;
+	struct kvm_regs regs;
+	struct kvm_vcpu *vcpu;
+
+	vcpu = __vm_vcpu_add(vm, vcpu_id);
+	stack_size = vm->page_size;
+	stack_vaddr = __vm_vaddr_alloc(vm, stack_size,
+			LOONGARCH_GUEST_STACK_VADDR_MIN, MEM_REGION_DATA);
+	TEST_ASSERT(stack_vaddr != 0, "No memory for vm stack");
+
+	loongarch_vcpu_setup(vcpu);
+	/* Setup guest general purpose registers */
+	vcpu_regs_get(vcpu, &regs);
+	regs.gpr[3] = stack_vaddr + stack_size;
+	vcpu_regs_set(vcpu, &regs);
+
+	return vcpu;
+}
+
+void vcpu_arch_set_entry_point(struct kvm_vcpu *vcpu, void *guest_code)
+{
+	struct kvm_regs regs;
+
+	/* Setup guest PC register */
+	vcpu_regs_get(vcpu, &regs);
+	regs.pc = (uint64_t)guest_code;
+	vcpu_regs_set(vcpu, &regs);
+}
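The index math in virt_pte_index() is worth a worked example. With 16K pages, each level translates page_shift - 3 = 11 bits, so three levels plus the 14-bit page offset cover exactly the 47 VA bits of the new VM_MODE_P47V47_16K mode. A self-contained sketch that mirrors the function above rather than calling it:

	#include <stdint.h>
	#include <stdio.h>

	/* Mirror of virt_pte_index() for a 16K-page, 3-level guest. */
	static uint64_t pte_index(uint64_t gva, int level)
	{
		const int page_shift = 14;
		int shift = level * (page_shift - 3) + page_shift;
		uint64_t mask = (1UL << (page_shift - 3)) - 1;

		return (gva >> shift) & mask;
	}

	int main(void)
	{
		uint64_t gva = 0x700000004000UL;

		/* Prints "pgd 1792, pmd 0, pte 1" for this gva. */
		printf("pgd %lu, pmd %lu, pte %lu\n",
		       pte_index(gva, 2), pte_index(gva, 1), pte_index(gva, 0));
		return 0;
	}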
diff --git a/tools/testing/selftests/kvm/lib/loongarch/ucall.c b/tools/testing/selftests/kvm/lib/loongarch/ucall.c
new file mode 100644
index 000000000000..fc6cbb50573f
--- /dev/null
+++ b/tools/testing/selftests/kvm/lib/loongarch/ucall.c
@@ -0,0 +1,38 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * ucall support. A ucall is a "hypercall to userspace".
+ *
+ */
+#include "kvm_util.h"
+
+/*
+ * ucall_exit_mmio_addr holds per-VM values (global data is duplicated by each
+ * VM), it must not be accessed from host code.
+ */
+vm_vaddr_t *ucall_exit_mmio_addr;
+
+void ucall_arch_init(struct kvm_vm *vm, vm_paddr_t mmio_gpa)
+{
+	vm_vaddr_t mmio_gva = vm_vaddr_unused_gap(vm, vm->page_size, KVM_UTIL_MIN_VADDR);
+
+	virt_map(vm, mmio_gva, mmio_gpa, 1);
+
+	vm->ucall_mmio_addr = mmio_gpa;
+
+	write_guest_global(vm, ucall_exit_mmio_addr, (vm_vaddr_t *)mmio_gva);
+}
+
+void *ucall_arch_get_ucall(struct kvm_vcpu *vcpu)
+{
+	struct kvm_run *run = vcpu->run;
+
+	if (run->exit_reason == KVM_EXIT_MMIO &&
+	    run->mmio.phys_addr == vcpu->vm->ucall_mmio_addr) {
+		TEST_ASSERT(run->mmio.is_write && run->mmio.len == sizeof(uint64_t),
+			    "Unexpected ucall exit mmio address access");
+
+		return (void *)(*((uint64_t *)run->mmio.data));
+	}
+
+	return NULL;
+}
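For reference, the MMIO trick above follows the same pattern as the other ucall backends: the guest writes a pointer to its ucall struct to the magic GVA, KVM exits with KVM_EXIT_MMIO, and the host turns that back into a ucall. The usual calling pattern, using the generic ucall_common API (nothing here is new to this patch):

	/* guest side */
	GUEST_SYNC(1);	/* ends up writing &uc to *ucall_exit_mmio_addr */

	/* host side */
	struct ucall uc;

	vcpu_run(vcpu);
	switch (get_ucall(vcpu, &uc)) {
	case UCALL_SYNC:	/* uc.args[1] == 1 */
		break;
	case UCALL_ABORT:
		REPORT_GUEST_ASSERT(uc);
	}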
diff --git a/tools/testing/selftests/kvm/lib/lru_gen_util.c b/tools/testing/selftests/kvm/lib/lru_gen_util.c
new file mode 100644
index 000000000000..46a14fd63d9e
--- /dev/null
+++ b/tools/testing/selftests/kvm/lib/lru_gen_util.c
@@ -0,0 +1,387 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2025, Google LLC.
+ */
+
+#include <time.h>
+
+#include "lru_gen_util.h"
+
+/*
+ * Tracks state while we parse memcg lru_gen stats. The file we're parsing is
+ * structured like this (some extra whitespace elided):
+ *
+ * memcg (id) (path)
+ * node (id)
+ * (gen_nr) (age_in_ms) (nr_anon_pages) (nr_file_pages)
+ */
+struct memcg_stats_parse_context {
+	bool consumed; /* Whether or not this line was consumed */
+	/* Next parse handler to invoke */
+	void (*next_handler)(struct memcg_stats *stats,
+			     struct memcg_stats_parse_context *ctx,
+			     char *line);
+	int current_node_idx; /* Current index in nodes array */
+	const char *name; /* The name of the memcg we're looking for */
+};
+
+static void memcg_stats_handle_searching(struct memcg_stats *stats,
+					 struct memcg_stats_parse_context *ctx,
+					 char *line);
+static void memcg_stats_handle_in_memcg(struct memcg_stats *stats,
+					struct memcg_stats_parse_context *ctx,
+					char *line);
+static void memcg_stats_handle_in_node(struct memcg_stats *stats,
+				       struct memcg_stats_parse_context *ctx,
+				       char *line);
+
+struct split_iterator {
+	char *str;
+	char *save;
+};
+
+static char *split_next(struct split_iterator *it)
+{
+	char *ret = strtok_r(it->str, " \t\n\r", &it->save);
+
+	it->str = NULL;
+	return ret;
+}
+
+static void memcg_stats_handle_searching(struct memcg_stats *stats,
+					 struct memcg_stats_parse_context *ctx,
+					 char *line)
+{
+	struct split_iterator it = { .str = line };
+	char *prefix = split_next(&it);
+	char *memcg_id = split_next(&it);
+	char *memcg_name = split_next(&it);
+	char *end;
+
+	ctx->consumed = true;
+
+	if (!prefix || strcmp("memcg", prefix))
+		return; /* Not a memcg line (maybe empty), skip */
+
+	TEST_ASSERT(memcg_id && memcg_name,
+		    "malformed memcg line; no memcg id or memcg_name");
+
+	if (strcmp(memcg_name + 1, ctx->name))
+		return; /* Wrong memcg, skip */
+
+	/* Found it! */
+
+	stats->memcg_id = strtoul(memcg_id, &end, 10);
+	TEST_ASSERT(*end == '\0', "malformed memcg id '%s'", memcg_id);
+	if (!stats->memcg_id)
+		return; /* Removed memcg? */
+
+	ctx->next_handler = memcg_stats_handle_in_memcg;
+}
+
+static void memcg_stats_handle_in_memcg(struct memcg_stats *stats,
+					struct memcg_stats_parse_context *ctx,
+					char *line)
+{
+	struct split_iterator it = { .str = line };
+	char *prefix = split_next(&it);
+	char *id = split_next(&it);
+	long found_node_id;
+	char *end;
+
+	ctx->consumed = true;
+	ctx->current_node_idx = -1;
+
+	if (!prefix)
+		return; /* Skip empty lines */
+
+	if (!strcmp("memcg", prefix)) {
+		/* Memcg done, found next one; stop. */
+		ctx->next_handler = NULL;
+		return;
+	} else if (strcmp("node", prefix))
+		TEST_ASSERT(false, "found malformed line after 'memcg ...',"
+				   " token: '%s'", prefix);
+
+	/* At this point we know we have a node line. Parse the ID. */
+
+	TEST_ASSERT(id, "malformed node line; no node id");
+
+	found_node_id = strtol(id, &end, 10);
+	TEST_ASSERT(*end == '\0', "malformed node id '%s'", id);
+
+	ctx->current_node_idx = stats->nr_nodes++;
+	TEST_ASSERT(ctx->current_node_idx < MAX_NR_NODES,
+		    "memcg has stats for too many nodes, max is %d",
+		    MAX_NR_NODES);
+	stats->nodes[ctx->current_node_idx].node = found_node_id;
+
+	ctx->next_handler = memcg_stats_handle_in_node;
+}
+
+static void memcg_stats_handle_in_node(struct memcg_stats *stats,
+				       struct memcg_stats_parse_context *ctx,
+				       char *line)
+{
+	char *my_line = strdup(line);
+	struct split_iterator it = { .str = my_line };
+	char *gen, *age, *nr_anon, *nr_file;
+	struct node_stats *node_stats;
+	struct generation_stats *gen_stats;
+	char *end;
+
+	TEST_ASSERT(it.str, "failed to copy input line");
+
+	gen = split_next(&it);
+
+	if (!gen)
+		goto out_consume; /* Skip empty lines */
+
+	if (!strcmp("memcg", gen) || !strcmp("node", gen)) {
+		/*
+		 * Reached next memcg or node section. Don't consume, let the
+		 * other handler deal with this.
+		 */
+		ctx->next_handler = memcg_stats_handle_in_memcg;
+		goto out;
+	}
+
+	node_stats = &stats->nodes[ctx->current_node_idx];
+	TEST_ASSERT(node_stats->nr_gens < MAX_NR_GENS,
+		    "found too many generation lines; max is %d",
+		    MAX_NR_GENS);
+	gen_stats = &node_stats->gens[node_stats->nr_gens++];
+
+	age = split_next(&it);
+	nr_anon = split_next(&it);
+	nr_file = split_next(&it);
+
+	TEST_ASSERT(age && nr_anon && nr_file,
+		    "malformed generation line; not enough tokens");
+
+	gen_stats->gen = (int)strtol(gen, &end, 10);
+	TEST_ASSERT(*end == '\0', "malformed generation number '%s'", gen);
+
+	gen_stats->age_ms = strtol(age, &end, 10);
+	TEST_ASSERT(*end == '\0', "malformed generation age '%s'", age);
+
+	gen_stats->nr_anon = strtol(nr_anon, &end, 10);
+	TEST_ASSERT(*end == '\0', "malformed anonymous page count '%s'",
+		    nr_anon);
+
+	gen_stats->nr_file = strtol(nr_file, &end, 10);
+	TEST_ASSERT(*end == '\0', "malformed file page count '%s'", nr_file);
+
+out_consume:
+	ctx->consumed = true;
+out:
+	free(my_line);
+}
+
+static void print_memcg_stats(const struct memcg_stats *stats, const char *name)
+{
+	int node, gen;
+
+	pr_debug("stats for memcg %s (id %lu):\n", name, stats->memcg_id);
+	for (node = 0; node < stats->nr_nodes; ++node) {
+		pr_debug("\tnode %d\n", stats->nodes[node].node);
+		for (gen = 0; gen < stats->nodes[node].nr_gens; ++gen) {
+			const struct generation_stats *gstats =
+				&stats->nodes[node].gens[gen];
+
+			pr_debug("\t\tgen %d\tage_ms %ld"
+				 "\tnr_anon %ld\tnr_file %ld\n",
+				 gstats->gen, gstats->age_ms, gstats->nr_anon,
+				 gstats->nr_file);
+		}
+	}
+}
+
+/* Re-read lru_gen debugfs information for @memcg into @stats. */
+void lru_gen_read_memcg_stats(struct memcg_stats *stats, const char *memcg)
+{
+	FILE *f;
+	ssize_t read = 0;
+	char *line = NULL;
+	size_t bufsz;
+	struct memcg_stats_parse_context ctx = {
+		.next_handler = memcg_stats_handle_searching,
+		.name = memcg,
+	};
+
+	memset(stats, 0, sizeof(struct memcg_stats));
+
+	f = fopen(LRU_GEN_DEBUGFS, "r");
+	TEST_ASSERT(f, "fopen(%s) failed", LRU_GEN_DEBUGFS);
+
+	while (ctx.next_handler && (read = getline(&line, &bufsz, f)) > 0) {
+		ctx.consumed = false;
+
+		do {
+			ctx.next_handler(stats, &ctx, line);
+			if (!ctx.next_handler)
+				break;
+		} while (!ctx.consumed);
+	}
+
+	if (read < 0 && !feof(f))
+		TEST_ASSERT(false, "getline(%s) failed", LRU_GEN_DEBUGFS);
+
+	TEST_ASSERT(stats->memcg_id > 0, "Couldn't find memcg: %s\n"
+		    "Did the memcg get created in the proper mount?",
+		    memcg);
+	if (line)
+		free(line);
+	TEST_ASSERT(!fclose(f), "fclose(%s) failed", LRU_GEN_DEBUGFS);
+
+	print_memcg_stats(stats, memcg);
+}
+
+/*
+ * Find all pages tracked by lru_gen for this memcg in generation @target_gen.
+ *
+ * If @target_gen is negative, look for all generations.
+ */
+long lru_gen_sum_memcg_stats_for_gen(int target_gen,
+				     const struct memcg_stats *stats)
+{
+	int node, gen;
+	long total_nr = 0;
+
+	for (node = 0; node < stats->nr_nodes; ++node) {
+		const struct node_stats *node_stats = &stats->nodes[node];
+
+		for (gen = 0; gen < node_stats->nr_gens; ++gen) {
+			const struct generation_stats *gen_stats =
+				&node_stats->gens[gen];
+
+			if (target_gen >= 0 && gen_stats->gen != target_gen)
+				continue;
+
+			total_nr += gen_stats->nr_anon + gen_stats->nr_file;
+		}
+	}
+
+	return total_nr;
+}
+
+/* Find all pages tracked by lru_gen for this memcg. */
+long lru_gen_sum_memcg_stats(const struct memcg_stats *stats)
+{
+	return lru_gen_sum_memcg_stats_for_gen(-1, stats);
+}
+
+/*
+ * Whether lru_gen aging should force page table scanning.
+ *
+ * If you want to set this to false, you will need to do eviction
+ * before doing extra aging passes.
+ */
+static const bool force_scan = true;
+
+static void run_aging_impl(unsigned long memcg_id, int node_id, int max_gen)
+{
+	FILE *f = fopen(LRU_GEN_DEBUGFS, "w");
+	char *command;
+	size_t sz;
+
+	TEST_ASSERT(f, "fopen(%s) failed", LRU_GEN_DEBUGFS);
+	sz = asprintf(&command, "+ %lu %d %d 1 %d\n",
+		      memcg_id, node_id, max_gen, force_scan);
+	TEST_ASSERT(sz > 0, "creating aging command failed");
+
+	pr_debug("Running aging command: %s", command);
+	if (fwrite(command, sizeof(char), sz, f) < sz) {
+		TEST_ASSERT(false, "writing aging command %s to %s failed",
+			    command, LRU_GEN_DEBUGFS);
+	}
+
+	TEST_ASSERT(!fclose(f), "fclose(%s) failed", LRU_GEN_DEBUGFS);
+}
+
+void lru_gen_do_aging(struct memcg_stats *stats, const char *memcg)
+{
+	int node, gen;
+
+	pr_debug("lru_gen: invoking aging...\n");
+
+	/* Must read memcg stats to construct the proper aging command. */
+	lru_gen_read_memcg_stats(stats, memcg);
+
+	for (node = 0; node < stats->nr_nodes; ++node) {
+		int max_gen = 0;
+
+		for (gen = 0; gen < stats->nodes[node].nr_gens; ++gen) {
+			int this_gen = stats->nodes[node].gens[gen].gen;
+
+			max_gen = max_gen > this_gen ? max_gen : this_gen;
+		}
+
+		run_aging_impl(stats->memcg_id, stats->nodes[node].node,
+			       max_gen);
+	}
+
+	/* Re-read so callers get updated information */
+	lru_gen_read_memcg_stats(stats, memcg);
+}
+
+/*
+ * Find which generation contains at least @pages pages, assuming that
+ * such a generation exists.
+ */
+int lru_gen_find_generation(const struct memcg_stats *stats,
+			    unsigned long pages)
+{
+	int node, gen, gen_idx, min_gen = INT_MAX, max_gen = -1;
+
+	for (node = 0; node < stats->nr_nodes; ++node)
+		for (gen_idx = 0; gen_idx < stats->nodes[node].nr_gens;
+		     ++gen_idx) {
+			gen = stats->nodes[node].gens[gen_idx].gen;
+			max_gen = gen > max_gen ? gen : max_gen;
+			min_gen = gen < min_gen ? gen : min_gen;
+		}
+
+	for (gen = min_gen; gen <= max_gen; ++gen)
+		/* See if this generation has enough pages. */
+		if (lru_gen_sum_memcg_stats_for_gen(gen, stats) > pages)
+			return gen;
+
+	return -1;
+}
+
+bool lru_gen_usable(void)
+{
+	long required_features = LRU_GEN_ENABLED | LRU_GEN_MM_WALK;
+	int lru_gen_fd, lru_gen_debug_fd;
+	char mglru_feature_str[8] = {};
+	long mglru_features;
+
+	lru_gen_fd = open(LRU_GEN_ENABLED_PATH, O_RDONLY);
+	if (lru_gen_fd < 0) {
+		puts("lru_gen: Could not open " LRU_GEN_ENABLED_PATH);
+		return false;
+	}
+	if (read(lru_gen_fd, &mglru_feature_str, 7) < 7) {
+		puts("lru_gen: Could not read from " LRU_GEN_ENABLED_PATH);
+		close(lru_gen_fd);
+		return false;
+	}
+	close(lru_gen_fd);
+
+	mglru_features = strtol(mglru_feature_str, NULL, 16);
+	if ((mglru_features & required_features) != required_features) {
+		printf("lru_gen: missing features, got: 0x%lx, expected: 0x%lx\n",
+		       mglru_features, required_features);
+		printf("lru_gen: Try 'echo 0x%lx > /sys/kernel/mm/lru_gen/enabled'\n",
+		       required_features);
+		return false;
+	}
+
+	lru_gen_debug_fd = open(LRU_GEN_DEBUGFS, O_RDWR);
+	__TEST_REQUIRE(lru_gen_debug_fd >= 0,
+		       "lru_gen: Could not open " LRU_GEN_DEBUGFS ", "
+		       "but lru_gen is enabled, so cannot use page_idle.");
+	close(lru_gen_debug_fd);
+	return true;
+}
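Taken together, the helpers above are meant to be driven in a fixed order: check availability, age, then query. A typical sequence as a caller might write it (the memcg name is illustrative):

	struct memcg_stats stats;
	long total;
	int gen;

	TEST_REQUIRE(lru_gen_usable());
	/* Ages every node once and re-reads the stats into @stats. */
	lru_gen_do_aging(&stats, "kvm-selftest-memcg");
	total = lru_gen_sum_memcg_stats(&stats);
	/* Find a generation holding more than half of the pages, if any. */
	gen = lru_gen_find_generation(&stats, total / 2);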
diff --git a/tools/testing/selftests/kvm/lib/riscv/handlers.S b/tools/testing/selftests/kvm/lib/riscv/handlers.S
index aa0abd3f35bb..b787b982e922 100644
--- a/tools/testing/selftests/kvm/lib/riscv/handlers.S
+++ b/tools/testing/selftests/kvm/lib/riscv/handlers.S
@@ -10,85 +10,88 @@
 #include <asm/csr.h>
 
 .macro save_context
-	addi  sp, sp, (-8*34)
-	sd    x1, 0(sp)
-	sd    x2, 8(sp)
-	sd    x3, 16(sp)
-	sd    x4, 24(sp)
-	sd    x5, 32(sp)
-	sd    x6, 40(sp)
-	sd    x7, 48(sp)
-	sd    x8, 56(sp)
-	sd    x9, 64(sp)
-	sd    x10, 72(sp)
-	sd    x11, 80(sp)
-	sd    x12, 88(sp)
-	sd    x13, 96(sp)
-	sd    x14, 104(sp)
-	sd    x15, 112(sp)
-	sd    x16, 120(sp)
-	sd    x17, 128(sp)
-	sd    x18, 136(sp)
-	sd    x19, 144(sp)
-	sd    x20, 152(sp)
-	sd    x21, 160(sp)
-	sd    x22, 168(sp)
-	sd    x23, 176(sp)
-	sd    x24, 184(sp)
-	sd    x25, 192(sp)
-	sd    x26, 200(sp)
-	sd    x27, 208(sp)
-	sd    x28, 216(sp)
-	sd    x29, 224(sp)
-	sd    x30, 232(sp)
-	sd    x31, 240(sp)
+	addi  sp, sp, (-8*36)
+	sd    x1, 8(sp)
+	sd    x2, 16(sp)
+	sd    x3, 24(sp)
+	sd    x4, 32(sp)
+	sd    x5, 40(sp)
+	sd    x6, 48(sp)
+	sd    x7, 56(sp)
+	sd    x8, 64(sp)
+	sd    x9, 72(sp)
+	sd    x10, 80(sp)
+	sd    x11, 88(sp)
+	sd    x12, 96(sp)
+	sd    x13, 104(sp)
+	sd    x14, 112(sp)
+	sd    x15, 120(sp)
+	sd    x16, 128(sp)
+	sd    x17, 136(sp)
+	sd    x18, 144(sp)
+	sd    x19, 152(sp)
+	sd    x20, 160(sp)
+	sd    x21, 168(sp)
+	sd    x22, 176(sp)
+	sd    x23, 184(sp)
+	sd    x24, 192(sp)
+	sd    x25, 200(sp)
+	sd    x26, 208(sp)
+	sd    x27, 216(sp)
+	sd    x28, 224(sp)
+	sd    x29, 232(sp)
+	sd    x30, 240(sp)
+	sd    x31, 248(sp)
 	csrr  s0, CSR_SEPC
 	csrr  s1, CSR_SSTATUS
-	csrr  s2, CSR_SCAUSE
-	sd    s0, 248(sp)
+	csrr  s2, CSR_STVAL
+	csrr  s3, CSR_SCAUSE
+	sd    s0, 0(sp)
 	sd    s1, 256(sp)
 	sd    s2, 264(sp)
+	sd    s3, 272(sp)
 .endm
 
 .macro restore_context
+	ld    s3, 272(sp)
 	ld    s2, 264(sp)
 	ld    s1, 256(sp)
-	ld    s0, 248(sp)
-	csrw  CSR_SCAUSE, s2
+	ld    s0, 0(sp)
+	csrw  CSR_SCAUSE, s3
 	csrw  CSR_SSTATUS, s1
 	csrw  CSR_SEPC, s0
-	ld    x31, 240(sp)
-	ld    x30, 232(sp)
-	ld    x29, 224(sp)
-	ld    x28, 216(sp)
-	ld    x27, 208(sp)
-	ld    x26, 200(sp)
-	ld    x25, 192(sp)
-	ld    x24, 184(sp)
-	ld    x23, 176(sp)
-	ld    x22, 168(sp)
-	ld    x21, 160(sp)
-	ld    x20, 152(sp)
-	ld    x19, 144(sp)
-	ld    x18, 136(sp)
-	ld    x17, 128(sp)
-	ld    x16, 120(sp)
-	ld    x15, 112(sp)
-	ld    x14, 104(sp)
-	ld    x13, 96(sp)
-	ld    x12, 88(sp)
-	ld    x11, 80(sp)
-	ld    x10, 72(sp)
-	ld    x9, 64(sp)
-	ld    x8, 56(sp)
-	ld    x7, 48(sp)
-	ld    x6, 40(sp)
-	ld    x5, 32(sp)
-	ld    x4, 24(sp)
-	ld    x3, 16(sp)
-	ld    x2, 8(sp)
-	ld    x1, 0(sp)
-	addi  sp, sp, (8*34)
+	ld    x31, 248(sp)
+	ld    x30, 240(sp)
+	ld    x29, 232(sp)
+	ld    x28, 224(sp)
+	ld    x27, 216(sp)
+	ld    x26, 208(sp)
+	ld    x25, 200(sp)
+	ld    x24, 192(sp)
+	ld    x23, 184(sp)
+	ld    x22, 176(sp)
+	ld    x21, 168(sp)
+	ld    x20, 160(sp)
+	ld    x19, 152(sp)
+	ld    x18, 144(sp)
+	ld    x17, 136(sp)
+	ld    x16, 128(sp)
+	ld    x15, 120(sp)
+	ld    x14, 112(sp)
+	ld    x13, 104(sp)
+	ld    x12, 96(sp)
+	ld    x11, 88(sp)
+	ld    x10, 80(sp)
+	ld    x9, 72(sp)
+	ld    x8, 64(sp)
+	ld    x7, 56(sp)
+	ld    x6, 48(sp)
+	ld    x5, 40(sp)
+	ld    x4, 32(sp)
+	ld    x3, 24(sp)
+	ld    x2, 16(sp)
+	ld    x1, 8(sp)
+	addi  sp, sp, (8*36)
 .endm
 
 .balign 4
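The reworked frame appears to match the layout of the kernel's RISC-V struct pt_regs, which is what lets the next hunk retype route_exception()'s argument: sepc lands at 0(sp), x1..x31 at 8(sp)..248(sp), then sstatus, stval and scause at 256, 264 and 272. Offsets for reference (8*36 = 288 bytes total, preserving 16-byte stack alignment):

	0(sp)    epc
	8(sp)..248(sp)  x1..x31
	256(sp)  status
	264(sp)  badaddr (stval)
	272(sp)  cause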
diff --git a/tools/testing/selftests/kvm/lib/riscv/processor.c b/tools/testing/selftests/kvm/lib/riscv/processor.c
index dd663bcf0cc0..2eac7d4b59e9 100644
--- a/tools/testing/selftests/kvm/lib/riscv/processor.c
+++ b/tools/testing/selftests/kvm/lib/riscv/processor.c
@@ -402,7 +402,7 @@ struct handlers {
 	exception_handler_fn exception_handlers[NR_VECTORS][NR_EXCEPTIONS];
 };
 
-void route_exception(struct ex_regs *regs)
+void route_exception(struct pt_regs *regs)
 {
 	struct handlers *handlers = (struct handlers *)exception_handlers;
 	int vector = 0, ec;
diff --git a/tools/testing/selftests/kvm/lib/test_util.c b/tools/testing/selftests/kvm/lib/test_util.c
index 8ed0b74ae837..03eb99af9b8d 100644
--- a/tools/testing/selftests/kvm/lib/test_util.c
+++ b/tools/testing/selftests/kvm/lib/test_util.c
@@ -132,37 +132,57 @@ void print_skip(const char *fmt, ...)
 	puts(", skipping test");
 }
 
-bool thp_configured(void)
+static bool test_sysfs_path(const char *path)
 {
-	int ret;
 	struct stat statbuf;
+	int ret;
 
-	ret = stat("/sys/kernel/mm/transparent_hugepage", &statbuf);
+	ret = stat(path, &statbuf);
 	TEST_ASSERT(ret == 0 || (ret == -1 && errno == ENOENT),
-		    "Error in stating /sys/kernel/mm/transparent_hugepage");
+		    "Error in stat()ing '%s'", path);
 
 	return ret == 0;
 }
 
-size_t get_trans_hugepagesz(void)
+bool thp_configured(void)
+{
+	return test_sysfs_path("/sys/kernel/mm/transparent_hugepage");
+}
+
+static size_t get_sysfs_val(const char *path)
 {
 	size_t size;
 	FILE *f;
 	int ret;
 
-	TEST_ASSERT(thp_configured(), "THP is not configured in host kernel");
-
-	f = fopen("/sys/kernel/mm/transparent_hugepage/hpage_pmd_size", "r");
-	TEST_ASSERT(f != NULL, "Error in opening transparent_hugepage/hpage_pmd_size");
+	f = fopen(path, "r");
+	TEST_ASSERT(f, "Error opening '%s'", path);
 
 	ret = fscanf(f, "%ld", &size);
+	TEST_ASSERT(ret > 0, "Error reading '%s'", path);
+
+	/* Re-scan the input stream to verify the entire file was read. */
 	ret = fscanf(f, "%ld", &size);
-	TEST_ASSERT(ret < 1, "Error reading transparent_hugepage/hpage_pmd_size");
-	fclose(f);
+	TEST_ASSERT(ret < 1, "Error reading '%s'", path);
 
+	fclose(f);
 	return size;
 }
 
+size_t get_trans_hugepagesz(void)
+{
+	TEST_ASSERT(thp_configured(), "THP is not configured in host kernel");
+
+	return get_sysfs_val("/sys/kernel/mm/transparent_hugepage/hpage_pmd_size");
+}
+
+bool is_numa_balancing_enabled(void)
+{
+	if (!test_sysfs_path("/proc/sys/kernel/numa_balancing"))
+		return false;
+	return get_sysfs_val("/proc/sys/kernel/numa_balancing") == 1;
+}
+
 size_t get_def_hugetlb_pagesz(void)
 {
 	char buf[64];
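The new helpers compose naturally, with the paths hard-coded above. For example, using the standard selftest logging helper pr_info (this snippet is illustrative, not from the patch):

	if (thp_configured())
		pr_info("PMD hugepage size: %zu bytes\n", get_trans_hugepagesz());

	/* Tests sensitive to page migration can now check this up front. */
	if (is_numa_balancing_enabled())
		pr_info("NUMA balancing is enabled\n");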
diff --git a/tools/testing/selftests/kvm/lib/x86/processor.c b/tools/testing/selftests/kvm/lib/x86/processor.c
index bd5a802fa7a5..a92dc1dad085 100644
--- a/tools/testing/selftests/kvm/lib/x86/processor.c
+++ b/tools/testing/selftests/kvm/lib/x86/processor.c
@@ -639,7 +639,7 @@ void kvm_arch_vm_post_create(struct kvm_vm *vm)
 	sync_global_to_guest(vm, host_cpu_is_amd);
 	sync_global_to_guest(vm, is_forced_emulation_enabled);
 
-	if (vm->type == KVM_X86_SEV_VM || vm->type == KVM_X86_SEV_ES_VM) {
+	if (is_sev_vm(vm)) {
 		struct kvm_sev_init init = { 0 };
 
 		vm_sev_ioctl(vm, KVM_SEV_INIT2, &init);
@@ -1156,7 +1156,7 @@ void kvm_get_cpu_address_width(unsigned int *pa_bits, unsigned int *va_bits)
 
 void kvm_init_vm_address_properties(struct kvm_vm *vm)
 {
-	if (vm->type == KVM_X86_SEV_VM || vm->type == KVM_X86_SEV_ES_VM) {
+	if (is_sev_vm(vm)) {
 		vm->arch.sev_fd = open_sev_dev_path_or_exit();
 		vm->arch.c_bit = BIT_ULL(this_cpu_property(X86_PROPERTY_SEV_C_BIT));
 		vm->gpa_tag_mask = vm->arch.c_bit;
diff --git a/tools/testing/selftests/kvm/lib/x86/sev.c b/tools/testing/selftests/kvm/lib/x86/sev.c
index e9535ee20b7f..c3a9838f4806 100644
--- a/tools/testing/selftests/kvm/lib/x86/sev.c
+++ b/tools/testing/selftests/kvm/lib/x86/sev.c
@@ -14,7 +14,8 @@
  * and find the first range, but that's correct because the condition
  * expression would cause us to quit the loop.
  */
-static void encrypt_region(struct kvm_vm *vm, struct userspace_mem_region *region)
+static void encrypt_region(struct kvm_vm *vm, struct userspace_mem_region *region,
+			   uint8_t page_type, bool private)
 {
 	const struct sparsebit *protected_phy_pages = region->protected_phy_pages;
 	const vm_paddr_t gpa_base = region->region.guest_phys_addr;
@@ -24,25 +25,35 @@ static void encrypt_region(struct kvm_vm *vm, struct userspace_mem_regio
 	if (!sparsebit_any_set(protected_phy_pages))
 		return;
 
-	sev_register_encrypted_memory(vm, region);
+	if (!is_sev_snp_vm(vm))
+		sev_register_encrypted_memory(vm, region);
 
 	sparsebit_for_each_set_range(protected_phy_pages, i, j) {
 		const uint64_t size = (j - i + 1) * vm->page_size;
 		const uint64_t offset = (i - lowest_page_in_region) * vm->page_size;
 
-		sev_launch_update_data(vm, gpa_base + offset, size);
+		if (private)
+			vm_mem_set_private(vm, gpa_base + offset, size);
+
+		if (is_sev_snp_vm(vm))
+			snp_launch_update_data(vm, gpa_base + offset,
+					       (uint64_t)addr_gpa2hva(vm, gpa_base + offset),
+					       size, page_type);
+		else
+			sev_launch_update_data(vm, gpa_base + offset, size);
+
 	}
 }
 
 void sev_vm_init(struct kvm_vm *vm)
 {
 	if (vm->type == KVM_X86_DEFAULT_VM) {
-		assert(vm->arch.sev_fd == -1);
+		TEST_ASSERT_EQ(vm->arch.sev_fd, -1);
 		vm->arch.sev_fd = open_sev_dev_path_or_exit();
 		vm_sev_ioctl(vm, KVM_SEV_INIT, NULL);
 	} else {
 		struct kvm_sev_init init = { 0 };
 
-		assert(vm->type == KVM_X86_SEV_VM);
+		TEST_ASSERT_EQ(vm->type, KVM_X86_SEV_VM);
 		vm_sev_ioctl(vm, KVM_SEV_INIT2, &init);
 	}
 }
@@ -50,16 +61,24 @@ void sev_vm_init(struct kvm_vm *vm)
 void sev_es_vm_init(struct kvm_vm *vm)
 {
 	if (vm->type == KVM_X86_DEFAULT_VM) {
-		assert(vm->arch.sev_fd == -1);
+		TEST_ASSERT_EQ(vm->arch.sev_fd, -1);
 		vm->arch.sev_fd = open_sev_dev_path_or_exit();
 		vm_sev_ioctl(vm, KVM_SEV_ES_INIT, NULL);
 	} else {
 		struct kvm_sev_init init = { 0 };
 
-		assert(vm->type == KVM_X86_SEV_ES_VM);
+		TEST_ASSERT_EQ(vm->type, KVM_X86_SEV_ES_VM);
 		vm_sev_ioctl(vm, KVM_SEV_INIT2, &init);
 	}
 }
 
+void snp_vm_init(struct kvm_vm *vm)
+{
+	struct kvm_sev_init init = { 0 };
+
+	TEST_ASSERT_EQ(vm->type, KVM_X86_SNP_VM);
+	vm_sev_ioctl(vm, KVM_SEV_INIT2, &init);
+}
+
 void sev_vm_launch(struct kvm_vm *vm, uint32_t policy)
 {
 	struct kvm_sev_launch_start launch_start = {
@@ -76,7 +95,7 @@ void sev_vm_launch(struct kvm_vm *vm, uint32_t policy)
 	TEST_ASSERT_EQ(status.state, SEV_GUEST_STATE_LAUNCH_UPDATE);
 
 	hash_for_each(vm->regions.slot_hash, ctr, region, slot_node)
-		encrypt_region(vm, region);
+		encrypt_region(vm, region, KVM_SEV_PAGE_TYPE_INVALID, false);
 
 	if (policy & SEV_POLICY_ES)
 		vm_sev_ioctl(vm, KVM_SEV_LAUNCH_UPDATE_VMSA, NULL);
@@ -112,6 +131,33 @@ void sev_vm_launch_finish(struct kvm_vm *vm)
 	TEST_ASSERT_EQ(status.state, SEV_GUEST_STATE_RUNNING);
 }
 
+void snp_vm_launch_start(struct kvm_vm *vm, uint64_t policy)
+{
+	struct kvm_sev_snp_launch_start launch_start = {
+		.policy = policy,
+	};
+
+	vm_sev_ioctl(vm, KVM_SEV_SNP_LAUNCH_START, &launch_start);
+}
+
+void snp_vm_launch_update(struct kvm_vm *vm)
+{
+	struct userspace_mem_region *region;
+	int ctr;
+
+	hash_for_each(vm->regions.slot_hash, ctr, region, slot_node)
+		encrypt_region(vm, region, KVM_SEV_SNP_PAGE_TYPE_NORMAL, true);
+
+	vm->arch.is_pt_protected = true;
+}
+
+void snp_vm_launch_finish(struct kvm_vm *vm)
+{
+	struct kvm_sev_snp_launch_finish launch_finish = { 0 };
+
+	vm_sev_ioctl(vm, KVM_SEV_SNP_LAUNCH_FINISH, &launch_finish);
+}
+
 struct kvm_vm *vm_sev_create_with_one_vcpu(uint32_t type, void *guest_code,
 					   struct kvm_vcpu **cpu)
 {
@@ -128,8 +174,20 @@ struct kvm_vm *vm_sev_create_with_one_vcpu(uint32_t type, void *guest_code,
 	return vm;
 }
 
-void vm_sev_launch(struct kvm_vm *vm, uint32_t policy, uint8_t *measurement)
+void vm_sev_launch(struct kvm_vm *vm, uint64_t policy, uint8_t *measurement)
 {
+	if (is_sev_snp_vm(vm)) {
+		vm_enable_cap(vm, KVM_CAP_EXIT_HYPERCALL, BIT(KVM_HC_MAP_GPA_RANGE));
+
+		snp_vm_launch_start(vm, policy);
+
+		snp_vm_launch_update(vm);
+
+		snp_vm_launch_finish(vm);
+
+		return;
+	}
+
 	sev_vm_launch(vm, policy);
 
 	if (!measurement)
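With these hooks in place, an SNP launch is driven entirely by vm_sev_launch(): start, then update (which converts and encrypts every protected range via encrypt_region() above), then finish. A sketch of the intended call sequence; snp_policy is a placeholder here, the real SNP policy bits come from the selftests sev.h in this series:

	struct kvm_vcpu *vcpu;
	uint64_t policy = snp_policy;	/* placeholder for the SNP policy bits */
	struct kvm_vm *vm;

	vm = vm_sev_create_with_one_vcpu(KVM_X86_SNP_VM, guest_code, &vcpu);
	/* For SNP this runs LAUNCH_START, LAUNCH_UPDATE and LAUNCH_FINISH. */
	vm_sev_launch(vm, policy, NULL);
	vcpu_run(vcpu);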
