diff --git a/extra/nv_gpu_driver/nv_ioctl.py b/extra/nv_gpu_driver/nv_ioctl.py index 659d8b7c..d8d28006 100644 --- a/extra/nv_gpu_driver/nv_ioctl.py +++ b/extra/nv_gpu_driver/nv_ioctl.py @@ -1,6 +1,6 @@ # type: ignore import ctypes, ctypes.util, struct, platform, pathlib, re, time, os, signal -from tinygrad.helpers import from_mv, to_mv, getenv +from tinygrad.helpers import from_mv, to_mv, getenv, init_c_struct_t from hexdump import hexdump start = time.perf_counter() @@ -14,6 +14,7 @@ def get_struct(argp, stype): return ctypes.cast(ctypes.c_void_p(argp), ctypes.POINTER(stype)).contents def dump_struct(st): + if getenv("IOCTL", 0) == 0: return print("\t", st.__class__.__name__, end=" { ") for v in type(st)._fields_: print(f"{v[0]}={getattr(st, v[0])}", end=" ") print("}") @@ -87,31 +88,32 @@ def ioctl(fd, request, argp): fn = os.readlink(f"/proc/self/fd/{fd}") #print(f"ioctl {request:8x} {fn:20s}") idir, size, itype, nr = (request>>30), (request>>16)&0x3FFF, (request>>8)&0xFF, request&0xFF - print(f"#{global_ioctl_id}: ", end="") + if getenv("IOCTL", 0) >= 1: print(f"#{global_ioctl_id}: ", end="") if itype == ord(nv_gpu.NV_IOCTL_MAGIC): if nr == nv_gpu.NV_ESC_RM_CONTROL: s = get_struct(argp, nv_gpu.NVOS54_PARAMETERS) if s.cmd in nvcmds: name, struc = nvcmds[s.cmd] - print(f"NV_ESC_RM_CONTROL cmd={name:30s} hClient={s.hClient}, hObject={s.hObject}, flags={s.flags}, params={s.params}, paramsSize={s.paramsSize}, status={s.status}") + if getenv("IOCTL", 0) >= 1: + print(f"NV_ESC_RM_CONTROL cmd={name:30s} hClient={s.hClient}, hObject={s.hObject}, flags={s.flags}, params={s.params}, paramsSize={s.paramsSize}, status={s.status}") + if struc is not None: dump_struct(get_struct(s.params, struc)) elif hasattr(nv_gpu, name+"_PARAMS"): dump_struct(get_struct(argp, getattr(nv_gpu, name+"_PARAMS"))) elif name == "NVA06C_CTRL_CMD_GPFIFO_SCHEDULE": dump_struct(get_struct(argp, nv_gpu.NVA06C_CTRL_GPFIFO_SCHEDULE_PARAMS)) elif name == "NV83DE_CTRL_CMD_GET_MAPPINGS": dump_struct(get_struct(s.params, nv_gpu.NV83DE_CTRL_DEBUG_GET_MAPPINGS_PARAMETERS)) else: - print("unhandled cmd", hex(s.cmd)) + if getenv("IOCTL", 0) >= 1: print("unhandled cmd", hex(s.cmd)) # format_struct(s) # print(f"{(st-start)*1000:7.2f} ms +{et*1000.:7.2f} ms : {ret:2d} = {name:40s}", ' '.join(format_struct(s))) elif nr == nv_gpu.NV_ESC_RM_ALLOC: s = get_struct(argp, nv_gpu.NVOS21_PARAMETERS) - print(f"NV_ESC_RM_ALLOC hClass={nvclasses.get(s.hClass, f'unk=0x{s.hClass:X}'):30s}, hRoot={s.hRoot}, hObjectParent={s.hObjectParent}, pAllocParms={s.pAllocParms}, hObjectNew={s.hObjectNew} status={s.status}") + if getenv("IOCTL", 0) >= 1: print(f"NV_ESC_RM_ALLOC hClass={nvclasses.get(s.hClass, f'unk=0x{s.hClass:X}'):30s}, hRoot={s.hRoot}, hObjectParent={s.hObjectParent}, pAllocParms={s.pAllocParms}, hObjectNew={s.hObjectNew} status={s.status}") if s.pAllocParms is not None: if s.hClass == nv_gpu.NV01_DEVICE_0: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV0080_ALLOC_PARAMETERS)) if s.hClass == nv_gpu.FERMI_VASPACE_A: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_VASPACE_ALLOCATION_PARAMETERS)) if s.hClass == nv_gpu.NV50_MEMORY_VIRTUAL: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_MEMORY_ALLOCATION_PARAMS)) if s.hClass == nv_gpu.NV1_MEMORY_USER: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_MEMORY_ALLOCATION_PARAMS)) if s.hClass == nv_gpu.NV1_MEMORY_SYSTEM: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_MEMORY_ALLOCATION_PARAMS)) - # if s.hClass == nv_gpu.NV1_EVENT_OS_EVENT: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV0005_ALLOC_PARAMETERS)) if s.hClass == nv_gpu.GT200_DEBUGGER: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV83DE_ALLOC_PARAMETERS)) if s.hClass == nv_gpu.AMPERE_CHANNEL_GPFIFO_A: sx = get_struct(s.pAllocParms, nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS) @@ -121,32 +123,38 @@ def ioctl(fd, request, argp): if s.hClass == nv_gpu.TURING_USERMODE_A: gpus_user_modes.append(s.hObjectNew) elif nr == nv_gpu.NV_ESC_RM_MAP_MEMORY: # nv_ioctl_nvos33_parameters_with_fd - s = get_struct(argp, nv_gpu.NVOS33_PARAMETERS) - print(f"NV_ESC_RM_MAP_MEMORY hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, length={s.length} flags={s.flags} pLinearAddress={s.pLinearAddress}") + if getenv("IOCTL", 0) >= 1: + s = get_struct(argp, nv_gpu.NVOS33_PARAMETERS) + print(f"NV_ESC_RM_MAP_MEMORY hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, length={s.length} flags={s.flags} pLinearAddress={s.pLinearAddress}") elif nr == nv_gpu.NV_ESC_RM_UPDATE_DEVICE_MAPPING_INFO: - s = get_struct(argp, nv_gpu.NVOS56_PARAMETERS) - print(f"NV_ESC_RM_UPDATE_DEVICE_MAPPING_INFO hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, pOldCpuAddress={s.pOldCpuAddress} pNewCpuAddress={s.pNewCpuAddress} status={s.status}") + if getenv("IOCTL", 0) >= 1: + s = get_struct(argp, nv_gpu.NVOS56_PARAMETERS) + print(f"NV_ESC_RM_UPDATE_DEVICE_MAPPING_INFO hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, pOldCpuAddress={s.pOldCpuAddress} pNewCpuAddress={s.pNewCpuAddress} status={s.status}") elif nr == nv_gpu.NV_ESC_RM_ALLOC_MEMORY: - s = get_struct(argp, nv_gpu.nv_ioctl_nvos02_parameters_with_fd) - print(f"NV_ESC_RM_ALLOC_MEMORY fd={s.fd}, hRoot={s.params.hRoot}, hObjectParent={s.params.hObjectParent}, hObjectNew={s.params.hObjectNew}, hClass={s.params.hClass}, flags={s.params.flags}, pMemory={s.params.pMemory}, limit={s.params.limit}, status={s.params.status}") + if getenv("IOCTL", 0) >= 1: + s = get_struct(argp, nv_gpu.nv_ioctl_nvos02_parameters_with_fd) + print(f"NV_ESC_RM_ALLOC_MEMORY fd={s.fd}, hRoot={s.params.hRoot}, hObjectParent={s.params.hObjectParent}, hObjectNew={s.params.hObjectNew}, hClass={s.params.hClass}, flags={s.params.flags}, pMemory={s.params.pMemory}, limit={s.params.limit}, status={s.params.status}") elif nr == nv_gpu.NV_ESC_ALLOC_OS_EVENT: - s = get_struct(argp, nv_gpu.nv_ioctl_nvos02_parameters_with_fd) + if getenv("IOCTL", 0) >= 1: + s = get_struct(argp, nv_gpu.nv_ioctl_alloc_os_event_t) + print(f"NV_ESC_ALLOC_OS_EVENT hClient={s.hClient} hDevice={s.hDevice} fd={s.fd} Status={s.Status}") elif nr == nv_gpu.NV_ESC_REGISTER_FD: - s = get_struct(argp, nv_gpu.nv_ioctl_register_fd_t) - print(f"NV_ESC_REGISTER_FD fd={s.ctl_fd}") + if getenv("IOCTL", 0) >= 1: + s = get_struct(argp, nv_gpu.nv_ioctl_register_fd_t) + print(f"NV_ESC_REGISTER_FD fd={s.ctl_fd}") elif nr in nvescs: - print(nvescs[nr]) + if getenv("IOCTL", 0) >= 1: print(nvescs[nr]) else: - print("unhandled NR", nr) + if getenv("IOCTL", 0) >= 1: print("unhandled NR", nr) elif fn.endswith("nvidia-uvm"): - print(f"{nvuvms.get(request, f'UVM UNKNOWN {request=}')}") - if nvuvms.get(request) is not None: dump_struct(get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS"))) - if nvuvms.get(request) == "UVM_MAP_EXTERNAL_ALLOCATION": - st = get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS")) - for i in range(st.gpuAttributesCount): - print("perGpuAttributes[{i}] = ", end="") - dump_struct(st.perGpuAttributes[i]) - print("ok") + if getenv("IOCTL", 0) >= 1: + print(f"{nvuvms.get(request, f'UVM UNKNOWN {request=}')}") + if nvuvms.get(request) is not None: dump_struct(get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS"))) + if nvuvms.get(request) == "UVM_MAP_EXTERNAL_ALLOCATION": + st = get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS")) + for i in range(st.gpuAttributesCount): + print("perGpuAttributes[{i}] = ", end="") + dump_struct(st.perGpuAttributes[i]) if getenv("IOCTL") >= 2: print("ioctl", f"{idir=} {size=} {itype=} {nr=} {fd=} {ret=}", fn) return ret @@ -166,23 +174,40 @@ if getenv("IOCTL") >= 3: orig_mmap_mv = install_hook(libc.mmap, _mmap) import collections old_gpputs = collections.defaultdict(int) def _dump_gpfifo(mark): - print("_dump_gpfifo:", mark) + launches = [] + + # print("_dump_gpfifo:", mark) for start,size in gpus_fifo: gpfifo_controls = nv_gpu.AmpereAControlGPFifo.from_address(start+size*8) - gpfifo = to_mv(start, gpfifo_controls.GPPut * 8).cast("Q") - if old_gpputs[start] == gpfifo_controls.GPPut: continue + gpfifo = to_mv(start, size * 8).cast("Q") + while old_gpputs[start] != gpfifo_controls.GPPut: + addr = ((gpfifo[old_gpputs[start]] & ((1 << 40)-1)) >> 2) << 2 + pckt_cnt = (gpfifo[old_gpputs[start]]>>42)&((1 << 20)-1) - print(f"gpfifo {start}: {gpfifo_controls.GPPut=}") - for i in range(old_gpputs[start], gpfifo_controls.GPPut): - addr = ((gpfifo[i % size] & ((1 << 40)-1)) >> 2) << 2 - pckt_cnt = (gpfifo[i % size]>>42)&((1 << 20)-1) - - print(f"\t{i}: 0x{gpfifo[i % size]:x}: addr:0x{addr:x} packets:{pckt_cnt} sync:{(gpfifo[i % size] >> 63) & 0x1} fetch:{gpfifo[i % size] & 0x1}") - old_gpputs[start] = gpfifo_controls.GPPut - _dump_qmd(addr, pckt_cnt) + # print(f"\t{i}: 0x{gpfifo[i % size]:x}: addr:0x{addr:x} packets:{pckt_cnt} sync:{(gpfifo[i % size] >> 63) & 0x1} fetch:{gpfifo[i % size] & 0x1}") + x = _dump_qmd(addr, pckt_cnt) + if isinstance(x, list): launches += x + old_gpputs[start] += 1 + old_gpputs[start] %= size + return launches import types +def make_qmd_struct_type(): + fields: List[Tuple[str, Union[Type[ctypes.c_uint64], Type[ctypes.c_uint32]], Any]] = [] + bits = [(name,dt) for name,dt in nv_gpu.__dict__.items() if name.startswith("NVC6C0_QMDV03_00") and isinstance(dt, tuple)] + bits += [(name+f"_{i}",dt(i)) for name,dt in nv_gpu.__dict__.items() for i in range(8) if name.startswith("NVC6C0_QMDV03_00") and callable(dt)] + bits = sorted(bits, key=lambda x: x[1][1]) + for i,(name, data) in enumerate(bits): + if i > 0 and (gap:=(data[1] - bits[i-1][1][0] - 1)) != 0: fields.append((f"_reserved{i}", ctypes.c_uint32, gap)) + fields.append((name.replace("NVC6C0_QMDV03_00_", "").lower(), ctypes.c_uint32, data[0]-data[1]+1)) + if len(fields) >= 2 and fields[-2][0].endswith('_lower') and fields[-1][0].endswith('_upper') and fields[-1][0][:-6] == fields[-2][0][:-6]: + fields = fields[:-2] + [(fields[-1][0][:-6], ctypes.c_uint64, fields[-1][2] + fields[-2][2])] + return init_c_struct_t(tuple(fields)) +qmd_struct_t = make_qmd_struct_type() +assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4 + def _dump_qmd(address, packets): + qmds = [] gpfifo = to_mv(address, packets * 4).cast("I") i = 0 @@ -194,27 +219,49 @@ def _dump_qmd(address, packets): subc = (dat>>13) & 7 mthd = (dat<<2) & 0x7FFF method_name = nvqcmds.get(mthd, f"unknown method #{mthd}") - print(f"\t\t{method_name}, {typ=} {size=} {subc=} {mthd=}") - for j in range(size): print(f"\t\t\t{j}: {gpfifo[i+j+1]} | 0x{gpfifo[i+j+1]:x}") + if getenv("IOCTL", 0) >= 1: + print(f"\t\t{method_name}, {typ=} {size=} {subc=} {mthd=}") + for j in range(size): print(f"\t\t\t{j}: {gpfifo[i+j+1]} | 0x{gpfifo[i+j+1]:x}") if mthd == 792: - for x in dir(nv_gpu): - if x.startswith("NVC6C0_QMDV03_00_"): - vv = getattr(nv_gpu, x) - bits = None - if isinstance(vv, tuple) and len(vv) == 2: - bits = vv - if isinstance(vv, types.FunctionType): - bits = vv(0) - - if bits is not None: - res = 0 - for bt in range(bits[1], bits[0]+1): res |= ((gpfifo[i + 3 + bt // 32] >> (bt % 32)) & 0x1) << (bt - bits[1]) - if res != 0: print(f"{x}, {hex(res)} | {bin(res)}") - - const_addr = gpfifo[i+35] + ((gpfifo[i+36] & 0xffff) << 32) - const_len = ((gpfifo[i+36] >> 19)) - # hexdump(to_mv(const_addr, const_len)) + qmds.append(qmd_struct_t.from_address(address + 12 + i * 4)) + elif mthd == nv_gpu.NVC6C0_SEND_PCAS_A: + qmds.append(qmd_struct_t.from_address(gpfifo[i+1] << 8)) i += size + 1 + return qmds + +# This is to be used in fuzzer, check cuda/nv side by side. +# Return a state which should be compare and compare function. +def before_launch(): _dump_gpfifo("before launch") +def collect_last_launch_state(): return _dump_gpfifo("after launch") + +def compare_launch_state(states1, states2): + states1 = states1 or list() + states2 = states2 or list() + if len(states1) != 1 or len(states2) != 1: + return False, f"Some states not captured. {len(states1)}!=1 || {len(states2)}!=1" + + for i in range(len(states1)): + state1, state2 = states1[i], states2[i] + + for n in ['qmd_major_version', 'invalidate_shader_data_cache', 'invalidate_shader_data_cache', + 'sm_global_caching_enable', 'invalidate_texture_header_cache', 'invalidate_texture_sampler_cache', + 'barrier_count', 'sampler_index', 'api_visible_call_limit', 'cwd_membar_type', 'sass_version', + 'min_sm_config_shared_mem_size', 'max_sm_config_shared_mem_size', 'register_count_v', + 'target_sm_config_shared_mem_size', 'shared_memory_size']: + if getattr(state1, n) != getattr(state1, n): + return False, f"Field {n} mismatch: {getattr(state1, n)} vs {getattr(state2, n)}" + + # Allow NV to allocate more, at least this is not exact problem, so ignore it here. + # Hmm, CUDA minimum is 0x640, is this hw-required minimum (will check)? + if state1.shader_local_memory_high_size < state2.shader_local_memory_high_size and state2.shader_local_memory_high_size > 0x640: + return False, f"Field shader_local_memory_high_size mismatch: {state1.shader_local_memory_high_size} vs {state2.shader_local_memory_high_size}" + + for i in range(8): + n = f"constant_buffer_valid_{i}" + if getattr(state1, n) != getattr(state1, n): + return False, f"Field {n} mismatch: {getattr(state1, n)} vs {getattr(state2, n)}" + + return True, "PASS" # IOCTL=1 PTX=1 CUDA=1 python3 test/test_ops.py TestOps.test_tiny_add \ No newline at end of file diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 9ce2d858..fab8ec6f 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -1,5 +1,5 @@ import random, traceback, ctypes, argparse -from typing import List, Tuple, DefaultDict +from typing import List, Tuple, DefaultDict, Any import numpy as np from collections import defaultdict from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_lin @@ -14,6 +14,21 @@ from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Tim from tinygrad.ops import UnaryOps, UOp, UOps from test.helpers import is_dtype_supported +def on_linearizer_will_run(): pass +def on_linearizer_did_run(): pass +def compare_states(x, y): return True + +if getenv("VALIDATE_HCQ"): + if Device.DEFAULT == "NV": + print("VALIDATE_HCQ: Comparing NV to CUDA") + import extra.nv_gpu_driver.nv_ioctl + validate_device = Device["CUDA"] + on_linearizer_will_run = extra.nv_gpu_driver.nv_ioctl.before_launch + on_linearizer_did_run = extra.nv_gpu_driver.nv_ioctl.collect_last_launch_state + compare_states = extra.nv_gpu_driver.nv_ioctl.compare_launch_state + else: + print(colored("VALIDATE_HCQ options is ignored", 'red')) + def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.op, x.dtype, tuple(uops.index(x) for x in x.src), x.arg) for x in uops]) @@ -39,36 +54,44 @@ def get_fuzz_rawbufs(lin): data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype)) else: data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype)) - rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer()) + rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().lazydata.realized.as_buffer()) return rawbufs -def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None): - rawbuf = type(rawbuf)(Device.DEFAULT, rawbuf.size if size is None else size, rawbuf.dtype).allocate() - if zero: +def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_device=None): + rawbuf = type(old_rawbuf)(force_device or old_rawbuf.device, old_rawbuf.size if size is None else size, old_rawbuf.dtype).allocate() + if copy: + with Context(DEBUG=0): rawbuf.copyin(old_rawbuf.as_buffer()) + elif zero: with Context(DEBUG=0): mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize)) ctypes.memset(from_mv(mv), 0, len(mv)) rawbuf.copyin(mv) return rawbuf -def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None): +def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> Tuple[str, Any]: # (error msg, run state) if rawbufs is None: rawbufs = bufs_from_lin(lin) if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()} # TODO: images needs required_optimization try: prg = CompiledRunner(lin.to_program()) + except KeyboardInterrupt: raise except Exception: traceback.print_exc() - return "COMPILE_ERROR" + return "COMPILE_ERROR", None + if getenv("VALIDATE_HCQ"): on_linearizer_will_run() try: prg(rawbufs, var_vals, wait=True) + except KeyboardInterrupt: raise except Exception: traceback.print_exc() - return "EXEC_ERROR" + return "EXEC_ERROR", None - return "PASS" + if getenv("VALIDATE_HCQ"): run_state = on_linearizer_did_run() + else: run_state = None + + return "PASS", run_state def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2): # TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer. @@ -80,8 +103,9 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No rawbufs = get_fuzz_rawbufs(lin) else: rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer + except KeyboardInterrupt: raise except BaseException: - return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,) + return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth, None) if var_vals is None: # TODO: handle symbolic max case @@ -90,18 +114,19 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No if ground_truth is None and not has_bf16: unoptimized = Kernel(lin.ast) unoptimized.required_optimizations() - if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS": - return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,) + if run_linearizer(unoptimized, rawbufs, var_vals)[0] != "PASS": + return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth, None) ground_truth = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)).copy() rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer - if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS": - return (run_msg, rawbufs, var_vals, ground_truth,) + run_msg, run_state = run_linearizer(lin, rawbufs, var_vals) + if run_msg != "PASS": return (run_msg, rawbufs, var_vals, ground_truth, run_state) try: if not has_bf16: result = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)) np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol) + except KeyboardInterrupt: raise except AssertionError as e: if DEBUG >= 2: print(f"COMPARE_ERROR details: {e}") @@ -111,9 +136,9 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No mismatched_ground_truth = ground_truth[mismatch_indices] for i, idx in enumerate(mismatch_indices[0]): print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}") - return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,) + return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth, run_state) - return ("PASS", rawbufs, var_vals, ground_truth,) + return ("PASS", rawbufs, var_vals, ground_truth, run_state) def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): SEED = getenv("SEED", 42) @@ -124,7 +149,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): seen_uops = {} last_lins = [lin] failures:DefaultDict[str, List[Tuple[Tuple[UOp,...],List[Opt]]]] = defaultdict(list) - rawbufs, var_vals, ground_truth = None, None, None + rawbufs, var_vals, ground_truth, validate_rawbufs = None, None, None, None FUZZ_ALL_ACTIONS = getenv("FUZZ_ALL_ACTIONS", 0) FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0) @@ -156,6 +181,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): # stop if kernel uops repeat try: tuops = tuplize_uops(test_lin.linearize().uops) + except KeyboardInterrupt: raise except BaseException as e: print(test_lin.ast) print(test_lin.applied_opts) @@ -168,7 +194,19 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape()) - (msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol) + (msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol) + if state1 is not None and validate_device is not None: + validate_lin = test_lin.copy() + validate_lin.opts = validate_device.renderer + if validate_rawbufs is None: + validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.dname) for x in rawbufs] + (_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol) + + if _msg != "PASS": failures[f"VALIDATE_DEV_{_msg}"].append((validate_lin.ast, validate_lin.applied_opts)) + + ok, err_msg = compare_states(state1, state2) + if not ok: failures["HCQ_COMPARE_FAILURE"].append((err_msg, test_lin.ast, test_lin.applied_opts, state1, state2)) + if msg != "PASS": print(test_lin.ast) print(test_lin.applied_opts) @@ -220,28 +258,31 @@ if __name__ == "__main__": failed_ids = [] failures = defaultdict(list) seen_ast_strs = set() - for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]): - if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue - if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU - if ast in seen_ast_strs: continue - seen_ast_strs.add(ast) - lin = ast_str_to_lin(ast) - if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs): - print("skipping kernel due to not supported dtype") - continue + try: + for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]): + if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue + if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU + if ast in seen_ast_strs: continue + seen_ast_strs.add(ast) - with Timing(f"tested ast {i}: "): - tested += 1 - fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol) - if fuzz_failures: failed_ids.append(i) - for k, v in fuzz_failures.items(): - for f in v: - failures[k].append(f) + lin = ast_str_to_lin(ast) + if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs): + print("skipping kernel due to not supported dtype") + continue + + with Timing(f"tested ast {i}: "): + tested += 1 + fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol) + if fuzz_failures: failed_ids.append(i) + for k, v in fuzz_failures.items(): + for f in v: + failures[k].append(f) + except KeyboardInterrupt: print(colored("STOPPING...", 'red')) for msg, errors in failures.items(): - for i, (ast, opts) in enumerate(errors): - print(f"{msg} {i} kernel: {(ast,opts)}") # easier to use with output with verify_kernel.py + for i, payload in enumerate(errors): + print(f"{msg} {i} kernel: {payload}") # easier to use with output with verify_kernel.py print(f"{tested=}") if failures: