Re: [PATCH v8 3/9] x86/virt/tdx: Use auto-generated code to read global metadata

From: Huang, Kai
Date: Fri Dec 13 2024 - 06:18:22 EST


On Thu, 2024-11-14 at 00:57 +1300, Kai Huang wrote:
> The tdx_global_metadata.{hc} can be generated by running below:
>
>  #python tdx_global_metadata.py global_metadata.json \
> tdx_global_metadata.h tdx_global_metadata.c
>
> .. where the 'tdx_global_metadata.py' can be found in [4] and the
> 'global_metadata.json' can be fetched from [3].
>

[...]

> Link: https://cdrdv2.intel.com/v1/dl/getContent/795381 [3]
> Link: https://lore.kernel.org/d5aed06ae4b46df5db97fdbac9c01843920a2f96.camel@xxxxxxxxx/ [4]

Hi Dave,

I'll remove the bugfix (patch 8) and the CMR reading code, and only keep what is
required for KVM TDX, in the next version.

I updated the script and attached it here so that there is a lore link to the
script, which can be used to reproduce the generated code in the next version.

It only has the CMR part removed, so it is not interesting to read. Thanks.
#! /usr/bin/env python3
import json
import os
import re
import sys

# Note: this script does not run as part of the build process.
# It is used to generate structs from the TDX global_metadata.json
# file, and functions to fill in said structs. Rerun it if
# you need more fields.

# Metadata classes to generate.  Each key names one class (it becomes the
# suffix of the generated struct and reader function); each value lists the
# "Field Name" entries from global_metadata.json that the class contains.
# Iteration order of this dict is the order in which code is emitted.
TDX_STRUCTS = {
    "features": [
        "TDX_FEATURES0",
    ],
    "tdmr": [
        "MAX_TDMRS",
        "MAX_RESERVED_PER_TDMR",
        "PAMT_4K_ENTRY_SIZE",
        "PAMT_2M_ENTRY_SIZE",
        "PAMT_1G_ENTRY_SIZE",
    ],
}

# Name prefixes used when composing identifiers in the generated C code.
STRUCT_PREFIX = "tdx_sys_info"    # struct <prefix>_<class>
FUNC_PREFIX = "get_tdx_sys_info"  # function <prefix>_<class>()
STRVAR_PREFIX = "sysinfo"         # parameter <prefix>_<class>

def print_class_struct_field(field_name, element_bytes, num_fields, num_elements, file):
    """Write one C struct member declaration to ``file``.

    The member type is ``u<bits>`` where ``bits`` is ``element_bytes * 8``.
    An array dimension is appended for ``num_fields`` and then for
    ``num_elements``, each only when its value is greater than 1.  The
    emitted line starts with a tab and ends with ``;``.
    """
    dimensions = "".join(
        "[%d]" % (count) for count in (num_fields, num_elements) if count > 1
    )
    c_type = "u%s" % (element_bytes * 8)
    print("\t%s %s%s;" % (c_type, field_name, dimensions), file=file)

def print_class_struct(class_name, fields, file):
    """Write the C definition of ``struct <STRUCT_PREFIX>_<class_name>``.

    ``fields`` is an iterable of field descriptions; each must provide the
    keys "Field Name", "Element Size (Bytes)", "Num Fields" and
    "Num Elements".  One member is emitted per description, named after the
    lower-cased field name.
    """
    print("struct %s_%s {" % (STRUCT_PREFIX, class_name), file=file)
    for desc in fields:
        member = desc["Field Name"].lower()
        size = int(desc["Element Size (Bytes)"])
        count = int(desc["Num Fields"])
        elems = int(desc["Num Elements"])
        print_class_struct_field(member, size, count, elems, file=file)
    print("};", file=file)

def print_read_field(field_id, struct_var, struct_member, indent, file):
    """Write the C statement that reads one metadata field into a struct.

    The generated code calls ``read_sys_metadata_field(<field_id>, &val)``
    only while ``ret`` is still zero, and on success stores ``val`` into
    ``<struct_var>-><struct_member>``.  ``indent`` is prepended to the ``if``
    line; the assignment line gets one extra tab.
    """
    guard_line = "%sif (!ret && !(ret = read_sys_metadata_field(%s, &val)))" % (
        indent,
        field_id,
    )
    store_line = "%s\t%s->%s = val;" % (indent, struct_var, struct_member)
    print("\n".join((guard_line, store_line)), file=file)

def print_class_function(class_name, fields, file):
    """Write the C function that reads every metadata field of one class.

    The generated function is::

        static int <FUNC_PREFIX>_<class_name>(
                struct <STRUCT_PREFIX>_<class_name> *<STRVAR_PREFIX>_<class_name>)

    For each entry of ``fields`` (dicts providing "Field Name",
    "Base FIELD_ID (Hex)", "Num Fields" and "Num Elements") one read
    statement is emitted via print_read_field(), wrapped in a ``for`` loop
    over ``i`` when the field has more than one instance and in a ``for``
    loop over ``j`` when it has more than one element.

    ``fields`` is iterated more than once, so it must be a re-iterable
    sequence such as a list.
    """
    func_name = "%s_%s" % (FUNC_PREFIX, class_name)
    struct_name = "%s_%s" % (STRUCT_PREFIX, class_name)
    struct_var = "%s_%s" % (STRVAR_PREFIX, class_name)

    print("static int %s(struct %s *%s)" % (func_name, struct_name, struct_var), file=file)
    print("{", file=file)
    print("\tint ret = 0;", file=file)
    print("\tu64 val;", file=file)

    # Declare exactly the loop counters the body below will use.  'j' must be
    # declared even when no field needs 'i'; otherwise the generated C code
    # references an undeclared variable.
    loop_vars = []
    if any(int(f["Num Fields"]) > 1 for f in fields):
        loop_vars.append("i")
    if any(int(f["Num Elements"]) > 1 for f in fields):
        loop_vars.append("j")
    if loop_vars:
        print("\tint %s;" % (", ".join(loop_vars)), file=file)

    print(file=file)
    for f in fields:
        fname = f["Field Name"]
        field_id = f["Base FIELD_ID (Hex)"]
        num_fields = int(f["Num Fields"])
        num_elements = int(f["Num Elements"])
        struct_member = fname.lower()
        indent = "\t"
        if num_fields > 1:
            # Some arrays are bounded by a count read at runtime rather than
            # by the maximum from the JSON.  The bound expression names the
            # 'cmr' / 'td_conf' struct variable, so it presumably only
            # resolves inside the reader generated for that class.
            if fname == "CMR_BASE" or fname == "CMR_SIZE":
                limit = "%s_%s->num_cmrs" % (STRVAR_PREFIX, "cmr")
            elif fname == "CPUID_CONFIG_LEAVES" or fname == "CPUID_CONFIG_VALUES":
                limit = "%s_%s->num_cpuid_config" % (STRVAR_PREFIX, "td_conf")
            else:
                limit = "%d" % (num_fields)
            print("%sfor (i = 0; i < %s; i++)" % (indent, limit), file=file)
            indent += "\t"
            field_id += " + i"
            struct_member += "[i]"
        if num_elements > 1:
            print("%sfor (j = 0; j < %d; j++)" % (indent, num_elements), file=file)
            indent += "\t"
            if num_fields > 1:
                # Yields "BASE + i * N + j"; C precedence binds "i * N" first.
                field_id += " * %d + j" % (num_elements)
            else:
                # No 'i' term to scale: appending " * N + j" here would
                # multiply the base field ID itself by N.
                field_id += " + j"
            struct_member += "[j]"

        print_read_field(
            field_id,
            struct_var,
            struct_member,
            indent,
            file=file,
        )

    print(file=file)
    print("\treturn ret;", file=file)
    print("}", file=file)

def print_main_struct(file):
    """Write ``struct tdx_sys_info``, which embeds every per-class struct.

    One member is emitted per key of TDX_STRUCTS, in dict order; the member
    is named after the class and typed ``struct <STRUCT_PREFIX>_<class>``.
    """
    members = [
        "\tstruct %s %s;" % ("%s_%s" % (STRUCT_PREFIX, cls), cls)
        for cls in TDX_STRUCTS
    ]
    print("struct tdx_sys_info {", file=file)
    for line in members:
        print(line, file=file)
    print("};", file=file)

def print_main_function(file):
    """Write ``get_tdx_sys_info()``, which calls every per-class reader.

    One ``ret = ret ?: <FUNC_PREFIX>_<class>(&sysinfo-><class>);`` line is
    emitted per key of TDX_STRUCTS, in dict order, so the generated function
    stops calling readers after the first one that returns non-zero.
    """
    prologue = (
        "static int get_tdx_sys_info(struct tdx_sys_info *sysinfo)",
        "{",
        "\tint ret = 0;",
        "",
    )
    calls = [
        "\tret = ret ?: %s(&sysinfo->%s);" % ("%s_%s" % (FUNC_PREFIX, cls), cls)
        for cls in TDX_STRUCTS
    ]
    epilogue = ("", "\treturn ret;", "}")
    for line in (*prologue, *calls, *epilogue):
        print(line, file=file)

# Script body.
# Usage: tdx_global_metadata.py <global_metadata.json> <output.h> <output.c>
if len(sys.argv) < 4:
    sys.exit("usage: %s <global_metadata.json> <output.h> <output.c>" % sys.argv[0])

jsonfile = sys.argv[1]
hfile = sys.argv[2]
cfile = sys.argv[3]

# Derive the include-guard suffix from the header's file name only, mapping
# every character that is not valid in a C identifier to '_'.  A plain
# replace(".", "_") on the full argument produces an invalid macro name when
# the header is given with a directory (e.g. "out/foo.h") or contains
# characters such as '-'.
hfileifdef = re.sub(r"[^0-9A-Za-z_]", "_", os.path.basename(hfile))
guard = "_X86_VIRT_TDX_AUTO_GENERATED_" + hfileifdef.upper()

with open(jsonfile, "r") as f:
    json_in = json.load(f)
    # Index the field descriptions by name for lookup from TDX_STRUCTS.
    fields = {x["Field Name"]: x for x in json_in["Fields"]}

# Fail before touching the output files if TDX_STRUCTS names a field that the
# JSON does not describe; otherwise a half-written header would be left behind.
missing = [x for names in TDX_STRUCTS.values() for x in names if x not in fields]
if missing:
    sys.exit("%s: unknown metadata field(s): %s" % (jsonfile, ", ".join(missing)))

with open(hfile, "w") as f:
    print("/* SPDX-License-Identifier: GPL-2.0 */", file=f)
    print("/* Automatically generated TDX global metadata structures. */", file=f)
    print("#ifndef " + guard, file=f)
    print("#define " + guard, file=f)
    print(file=f)
    print("#include <linux/types.h>", file=f)
    print(file=f)
    for class_name, field_names in TDX_STRUCTS.items():
        print_class_struct(class_name, [fields[x] for x in field_names], file=f)
        print(file=f)
    print_main_struct(file=f)
    print(file=f)
    print("#endif", file=f)

with open(cfile, "w") as f:
    print("// SPDX-License-Identifier: GPL-2.0", file=f)
    print("/*", file=f)
    print(" * Automatically generated functions to read TDX global metadata.", file=f)
    print(" *", file=f)
    print(" * This file doesn't compile on its own as it lacks of inclusion", file=f)
    print(" * of SEAMCALL wrapper primitive which reads global metadata.", file=f)
    print(" * Include this file to other C file instead.", file=f)
    print(" */", file=f)
    for class_name, field_names in TDX_STRUCTS.items():
        print(file=f)
        print_class_function(class_name, [fields[x] for x in field_names], file=f)
    print(file=f)
    print_main_function(file=f)