fuzz nv vs cuda (#7066)

* fuzz nv vs cuda

* fixes

* smth

* um

* cmp the same

* dnrt

* correct gpfifo scan

* fix
nimlgen 2024-10-15 22:22:40 +03:00 committed by GitHub
parent 8ff6514ba3
commit b025495e5c
2 changed files with 179 additions and 91 deletions

extra/nv_gpu_driver/nv_ioctl.py

@@ -1,6 +1,6 @@
# type: ignore
import ctypes, ctypes.util, struct, platform, pathlib, re, time, os, signal
from tinygrad.helpers import from_mv, to_mv, getenv
from tinygrad.helpers import from_mv, to_mv, getenv, init_c_struct_t
from typing import List, Tuple, Union, Type, Any  # used by make_qmd_struct_type's annotations
from hexdump import hexdump
start = time.perf_counter()
@@ -14,6 +14,7 @@ def get_struct(argp, stype):
return ctypes.cast(ctypes.c_void_p(argp), ctypes.POINTER(stype)).contents
def dump_struct(st):
if getenv("IOCTL", 0) == 0: return
print("\t", st.__class__.__name__, end=" { ")
for v in type(st)._fields_: print(f"{v[0]}={getattr(st, v[0])}", end=" ")
print("}")
@@ -87,31 +88,32 @@ def ioctl(fd, request, argp):
fn = os.readlink(f"/proc/self/fd/{fd}")
#print(f"ioctl {request:8x} {fn:20s}")
idir, size, itype, nr = (request>>30), (request>>16)&0x3FFF, (request>>8)&0xFF, request&0xFF
print(f"#{global_ioctl_id}: ", end="")
if getenv("IOCTL", 0) >= 1: print(f"#{global_ioctl_id}: ", end="")
if itype == ord(nv_gpu.NV_IOCTL_MAGIC):
if nr == nv_gpu.NV_ESC_RM_CONTROL:
s = get_struct(argp, nv_gpu.NVOS54_PARAMETERS)
if s.cmd in nvcmds:
name, struc = nvcmds[s.cmd]
print(f"NV_ESC_RM_CONTROL cmd={name:30s} hClient={s.hClient}, hObject={s.hObject}, flags={s.flags}, params={s.params}, paramsSize={s.paramsSize}, status={s.status}")
if getenv("IOCTL", 0) >= 1:
print(f"NV_ESC_RM_CONTROL cmd={name:30s} hClient={s.hClient}, hObject={s.hObject}, flags={s.flags}, params={s.params}, paramsSize={s.paramsSize}, status={s.status}")
if struc is not None: dump_struct(get_struct(s.params, struc))
elif hasattr(nv_gpu, name+"_PARAMS"): dump_struct(get_struct(argp, getattr(nv_gpu, name+"_PARAMS")))
elif name == "NVA06C_CTRL_CMD_GPFIFO_SCHEDULE": dump_struct(get_struct(argp, nv_gpu.NVA06C_CTRL_GPFIFO_SCHEDULE_PARAMS))
elif name == "NV83DE_CTRL_CMD_GET_MAPPINGS": dump_struct(get_struct(s.params, nv_gpu.NV83DE_CTRL_DEBUG_GET_MAPPINGS_PARAMETERS))
else:
print("unhandled cmd", hex(s.cmd))
if getenv("IOCTL", 0) >= 1: print("unhandled cmd", hex(s.cmd))
# format_struct(s)
# print(f"{(st-start)*1000:7.2f} ms +{et*1000.:7.2f} ms : {ret:2d} = {name:40s}", ' '.join(format_struct(s)))
elif nr == nv_gpu.NV_ESC_RM_ALLOC:
s = get_struct(argp, nv_gpu.NVOS21_PARAMETERS)
print(f"NV_ESC_RM_ALLOC hClass={nvclasses.get(s.hClass, f'unk=0x{s.hClass:X}'):30s}, hRoot={s.hRoot}, hObjectParent={s.hObjectParent}, pAllocParms={s.pAllocParms}, hObjectNew={s.hObjectNew} status={s.status}")
if getenv("IOCTL", 0) >= 1: print(f"NV_ESC_RM_ALLOC hClass={nvclasses.get(s.hClass, f'unk=0x{s.hClass:X}'):30s}, hRoot={s.hRoot}, hObjectParent={s.hObjectParent}, pAllocParms={s.pAllocParms}, hObjectNew={s.hObjectNew} status={s.status}")
if s.pAllocParms is not None:
if s.hClass == nv_gpu.NV01_DEVICE_0: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV0080_ALLOC_PARAMETERS))
if s.hClass == nv_gpu.FERMI_VASPACE_A: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_VASPACE_ALLOCATION_PARAMETERS))
if s.hClass == nv_gpu.NV50_MEMORY_VIRTUAL: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_MEMORY_ALLOCATION_PARAMS))
if s.hClass == nv_gpu.NV1_MEMORY_USER: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_MEMORY_ALLOCATION_PARAMS))
if s.hClass == nv_gpu.NV1_MEMORY_SYSTEM: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV_MEMORY_ALLOCATION_PARAMS))
# if s.hClass == nv_gpu.NV1_EVENT_OS_EVENT: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV0005_ALLOC_PARAMETERS))
if s.hClass == nv_gpu.GT200_DEBUGGER: dump_struct(get_struct(s.pAllocParms, nv_gpu.NV83DE_ALLOC_PARAMETERS))
if s.hClass == nv_gpu.AMPERE_CHANNEL_GPFIFO_A:
sx = get_struct(s.pAllocParms, nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS)
@@ -121,32 +123,38 @@ def ioctl(fd, request, argp):
if s.hClass == nv_gpu.TURING_USERMODE_A: gpus_user_modes.append(s.hObjectNew)
elif nr == nv_gpu.NV_ESC_RM_MAP_MEMORY:
# nv_ioctl_nvos33_parameters_with_fd
s = get_struct(argp, nv_gpu.NVOS33_PARAMETERS)
print(f"NV_ESC_RM_MAP_MEMORY hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, length={s.length} flags={s.flags} pLinearAddress={s.pLinearAddress}")
if getenv("IOCTL", 0) >= 1:
s = get_struct(argp, nv_gpu.NVOS33_PARAMETERS)
print(f"NV_ESC_RM_MAP_MEMORY hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, length={s.length} flags={s.flags} pLinearAddress={s.pLinearAddress}")
elif nr == nv_gpu.NV_ESC_RM_UPDATE_DEVICE_MAPPING_INFO:
s = get_struct(argp, nv_gpu.NVOS56_PARAMETERS)
print(f"NV_ESC_RM_UPDATE_DEVICE_MAPPING_INFO hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, pOldCpuAddress={s.pOldCpuAddress} pNewCpuAddress={s.pNewCpuAddress} status={s.status}")
if getenv("IOCTL", 0) >= 1:
s = get_struct(argp, nv_gpu.NVOS56_PARAMETERS)
print(f"NV_ESC_RM_UPDATE_DEVICE_MAPPING_INFO hClient={s.hClient}, hDevice={s.hDevice}, hMemory={s.hMemory}, pOldCpuAddress={s.pOldCpuAddress} pNewCpuAddress={s.pNewCpuAddress} status={s.status}")
elif nr == nv_gpu.NV_ESC_RM_ALLOC_MEMORY:
s = get_struct(argp, nv_gpu.nv_ioctl_nvos02_parameters_with_fd)
print(f"NV_ESC_RM_ALLOC_MEMORY fd={s.fd}, hRoot={s.params.hRoot}, hObjectParent={s.params.hObjectParent}, hObjectNew={s.params.hObjectNew}, hClass={s.params.hClass}, flags={s.params.flags}, pMemory={s.params.pMemory}, limit={s.params.limit}, status={s.params.status}")
if getenv("IOCTL", 0) >= 1:
s = get_struct(argp, nv_gpu.nv_ioctl_nvos02_parameters_with_fd)
print(f"NV_ESC_RM_ALLOC_MEMORY fd={s.fd}, hRoot={s.params.hRoot}, hObjectParent={s.params.hObjectParent}, hObjectNew={s.params.hObjectNew}, hClass={s.params.hClass}, flags={s.params.flags}, pMemory={s.params.pMemory}, limit={s.params.limit}, status={s.params.status}")
elif nr == nv_gpu.NV_ESC_ALLOC_OS_EVENT:
s = get_struct(argp, nv_gpu.nv_ioctl_nvos02_parameters_with_fd)
if getenv("IOCTL", 0) >= 1:
s = get_struct(argp, nv_gpu.nv_ioctl_alloc_os_event_t)
print(f"NV_ESC_ALLOC_OS_EVENT hClient={s.hClient} hDevice={s.hDevice} fd={s.fd} Status={s.Status}")
elif nr == nv_gpu.NV_ESC_REGISTER_FD:
s = get_struct(argp, nv_gpu.nv_ioctl_register_fd_t)
print(f"NV_ESC_REGISTER_FD fd={s.ctl_fd}")
if getenv("IOCTL", 0) >= 1:
s = get_struct(argp, nv_gpu.nv_ioctl_register_fd_t)
print(f"NV_ESC_REGISTER_FD fd={s.ctl_fd}")
elif nr in nvescs:
print(nvescs[nr])
if getenv("IOCTL", 0) >= 1: print(nvescs[nr])
else:
print("unhandled NR", nr)
if getenv("IOCTL", 0) >= 1: print("unhandled NR", nr)
elif fn.endswith("nvidia-uvm"):
print(f"{nvuvms.get(request, f'UVM UNKNOWN {request=}')}")
if nvuvms.get(request) is not None: dump_struct(get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS")))
if nvuvms.get(request) == "UVM_MAP_EXTERNAL_ALLOCATION":
st = get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS"))
for i in range(st.gpuAttributesCount):
print("perGpuAttributes[{i}] = ", end="")
dump_struct(st.perGpuAttributes[i])
print("ok")
if getenv("IOCTL", 0) >= 1:
print(f"{nvuvms.get(request, f'UVM UNKNOWN {request=}')}")
if nvuvms.get(request) is not None: dump_struct(get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS")))
if nvuvms.get(request) == "UVM_MAP_EXTERNAL_ALLOCATION":
st = get_struct(argp, getattr(nv_gpu, nvuvms.get(request)+"_PARAMS"))
for i in range(st.gpuAttributesCount):
print("perGpuAttributes[{i}] = ", end="")
dump_struct(st.perGpuAttributes[i])
if getenv("IOCTL") >= 2: print("ioctl", f"{idir=} {size=} {itype=} {nr=} {fd=} {ret=}", fn)
return ret
@@ -166,23 +174,40 @@ if getenv("IOCTL") >= 3: orig_mmap_mv = install_hook(libc.mmap, _mmap)
import collections
old_gpputs = collections.defaultdict(int)
def _dump_gpfifo(mark):
print("_dump_gpfifo:", mark)
launches = []
# print("_dump_gpfifo:", mark)
for start,size in gpus_fifo:
gpfifo_controls = nv_gpu.AmpereAControlGPFifo.from_address(start+size*8)
gpfifo = to_mv(start, gpfifo_controls.GPPut * 8).cast("Q")
if old_gpputs[start] == gpfifo_controls.GPPut: continue
gpfifo = to_mv(start, size * 8).cast("Q")
while old_gpputs[start] != gpfifo_controls.GPPut:
addr = ((gpfifo[old_gpputs[start]] & ((1 << 40)-1)) >> 2) << 2
pckt_cnt = (gpfifo[old_gpputs[start]]>>42)&((1 << 20)-1)
print(f"gpfifo {start}: {gpfifo_controls.GPPut=}")
for i in range(old_gpputs[start], gpfifo_controls.GPPut):
addr = ((gpfifo[i % size] & ((1 << 40)-1)) >> 2) << 2
pckt_cnt = (gpfifo[i % size]>>42)&((1 << 20)-1)
print(f"\t{i}: 0x{gpfifo[i % size]:x}: addr:0x{addr:x} packets:{pckt_cnt} sync:{(gpfifo[i % size] >> 63) & 0x1} fetch:{gpfifo[i % size] & 0x1}")
old_gpputs[start] = gpfifo_controls.GPPut
_dump_qmd(addr, pckt_cnt)
# print(f"\t{i}: 0x{gpfifo[i % size]:x}: addr:0x{addr:x} packets:{pckt_cnt} sync:{(gpfifo[i % size] >> 63) & 0x1} fetch:{gpfifo[i % size] & 0x1}")
x = _dump_qmd(addr, pckt_cnt)
if isinstance(x, list): launches += x
old_gpputs[start] += 1
old_gpputs[start] %= size
return launches
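# Illustrative sketch (helper name is hypothetical): how the 64-bit GPFIFO
# entries are unpacked in the loop above.
def _decode_gpfifo_entry(e:int):
  addr = ((e & ((1 << 40)-1)) >> 2) << 2  # bits 2..39: pushbuffer address, 4-byte aligned
  pckt_cnt = (e >> 42) & ((1 << 20)-1)    # bits 42..61: count of 32-bit words at addr
  sync, fetch = (e >> 63) & 0x1, e & 0x1  # bit 63: sync flag, bit 0: fetch flag
  return addr, pckt_cnt, sync, fetch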
import types
def make_qmd_struct_type():
fields: List[Tuple[str, Union[Type[ctypes.c_uint64], Type[ctypes.c_uint32]], Any]] = []
bits = [(name,dt) for name,dt in nv_gpu.__dict__.items() if name.startswith("NVC6C0_QMDV03_00") and isinstance(dt, tuple)]
bits += [(name+f"_{i}",dt(i)) for name,dt in nv_gpu.__dict__.items() for i in range(8) if name.startswith("NVC6C0_QMDV03_00") and callable(dt)]
bits = sorted(bits, key=lambda x: x[1][1])
for i,(name, data) in enumerate(bits):
if i > 0 and (gap:=(data[1] - bits[i-1][1][0] - 1)) != 0: fields.append((f"_reserved{i}", ctypes.c_uint32, gap))
fields.append((name.replace("NVC6C0_QMDV03_00_", "").lower(), ctypes.c_uint32, data[0]-data[1]+1))
if len(fields) >= 2 and fields[-2][0].endswith('_lower') and fields[-1][0].endswith('_upper') and fields[-1][0][:-6] == fields[-2][0][:-6]:
fields = fields[:-2] + [(fields[-1][0][:-6], ctypes.c_uint64, fields[-1][2] + fields[-2][2])]
return init_c_struct_t(tuple(fields))
qmd_struct_t = make_qmd_struct_type()
assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4
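# The NVC6C0_QMDV03_00_* definitions in nv_gpu come in two shapes: a plain
# field is a (hi_bit, lo_bit) tuple, and an indexed field is a callable
# i -> (hi_bit, lo_bit), e.g. a constant-buffer field evaluated per slot i.
# make_qmd_struct_type sorts fields by lo_bit, pads gaps with _reservedN
# bitfields, and fuses adjacent *_lower/*_upper halves into one c_uint64, so
# the resulting struct overlays a raw 0x100-byte QMD exactly.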
def _dump_qmd(address, packets):
qmds = []
gpfifo = to_mv(address, packets * 4).cast("I")
i = 0
@@ -194,27 +219,49 @@ def _dump_qmd(address, packets):
subc = (dat>>13) & 7
mthd = (dat<<2) & 0x7FFF
method_name = nvqcmds.get(mthd, f"unknown method #{mthd}")
print(f"\t\t{method_name}, {typ=} {size=} {subc=} {mthd=}")
for j in range(size): print(f"\t\t\t{j}: {gpfifo[i+j+1]} | 0x{gpfifo[i+j+1]:x}")
if getenv("IOCTL", 0) >= 1:
print(f"\t\t{method_name}, {typ=} {size=} {subc=} {mthd=}")
for j in range(size): print(f"\t\t\t{j}: {gpfifo[i+j+1]} | 0x{gpfifo[i+j+1]:x}")
if mthd == 792:
for x in dir(nv_gpu):
if x.startswith("NVC6C0_QMDV03_00_"):
vv = getattr(nv_gpu, x)
bits = None
if isinstance(vv, tuple) and len(vv) == 2:
bits = vv
if isinstance(vv, types.FunctionType):
bits = vv(0)
if bits is not None:
res = 0
for bt in range(bits[1], bits[0]+1): res |= ((gpfifo[i + 3 + bt // 32] >> (bt % 32)) & 0x1) << (bt - bits[1])
if res != 0: print(f"{x}, {hex(res)} | {bin(res)}")
const_addr = gpfifo[i+35] + ((gpfifo[i+36] & 0xffff) << 32)
const_len = ((gpfifo[i+36] >> 19))
# hexdump(to_mv(const_addr, const_len))
qmds.append(qmd_struct_t.from_address(address + 12 + i * 4))
elif mthd == nv_gpu.NVC6C0_SEND_PCAS_A:
qmds.append(qmd_struct_t.from_address(gpfifo[i+1] << 8))
i += size + 1
return qmds
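# As handled above, a QMD reaches the GPU in one of two ways: embedded in the
# pushbuffer after method 792 (the QMD words follow the method header), or by
# reference via NVC6C0_SEND_PCAS_A, whose argument is the QMD address >> 8.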
# Used by the fuzzer to check the CUDA and NV sides side by side.
# collect_last_launch_state returns a state to capture; compare_launch_state is the compare function.
def before_launch(): _dump_gpfifo("before launch")
def collect_last_launch_state(): return _dump_gpfifo("after launch")
def compare_launch_state(states1, states2):
states1 = states1 or list()
states2 = states2 or list()
if len(states1) != 1 or len(states2) != 1:
return False, f"Some states not captured. {len(states1)}!=1 || {len(states2)}!=1"
for i in range(len(states1)):
state1, state2 = states1[i], states2[i]
for n in ['qmd_major_version', 'invalidate_shader_data_cache',
'sm_global_caching_enable', 'invalidate_texture_header_cache', 'invalidate_texture_sampler_cache',
'barrier_count', 'sampler_index', 'api_visible_call_limit', 'cwd_membar_type', 'sass_version',
'min_sm_config_shared_mem_size', 'max_sm_config_shared_mem_size', 'register_count_v',
'target_sm_config_shared_mem_size', 'shared_memory_size']:
if getattr(state1, n) != getattr(state2, n):
return False, f"Field {n} mismatch: {getattr(state1, n)} vs {getattr(state2, n)}"
# Allow NV to allocate more; at least this is not a correctness problem, so ignore it here.
# Hmm, CUDA's minimum is 0x640; is this a hw-required minimum? (will check)
if state1.shader_local_memory_high_size < state2.shader_local_memory_high_size and state2.shader_local_memory_high_size > 0x640:
return False, f"Field shader_local_memory_high_size mismatch: {state1.shader_local_memory_high_size} vs {state2.shader_local_memory_high_size}"
for i in range(8):
n = f"constant_buffer_valid_{i}"
if getattr(state1, n) != getattr(state2, n):
return False, f"Field {n} mismatch: {getattr(state1, n)} vs {getattr(state2, n)}"
return True, "PASS"
# IOCTL=1 PTX=1 CUDA=1 python3 test/test_ops.py TestOps.test_tiny_add

test/external/fuzz_linearizer.py

@@ -1,5 +1,5 @@
import random, traceback, ctypes, argparse
from typing import List, Tuple, DefaultDict
from typing import List, Tuple, DefaultDict, Any
import numpy as np
from collections import defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_lin
@@ -14,6 +14,21 @@ from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
from tinygrad.ops import UnaryOps, UOp, UOps
from test.helpers import is_dtype_supported
def on_linearizer_will_run(): pass
def on_linearizer_did_run(): pass
def compare_states(x, y): return True
if getenv("VALIDATE_HCQ"):
if Device.DEFAULT == "NV":
print("VALIDATE_HCQ: Comparing NV to CUDA")
import extra.nv_gpu_driver.nv_ioctl
validate_device = Device["CUDA"]
on_linearizer_will_run = extra.nv_gpu_driver.nv_ioctl.before_launch
on_linearizer_did_run = extra.nv_gpu_driver.nv_ioctl.collect_last_launch_state
compare_states = extra.nv_gpu_driver.nv_ioctl.compare_launch_state
else:
print(colored("VALIDATE_HCQ options is ignored", 'red'))
def tuplize_uops(uops:List[UOp]) -> Tuple:
return tuple([(x.op, x.dtype, tuple(uops.index(x) for x in x.src), x.arg) for x in uops])
@@ -39,36 +54,44 @@ def get_fuzz_rawbufs(lin):
data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
else:
data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().lazydata.realized.as_buffer())
return rawbufs
def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
rawbuf = type(rawbuf)(Device.DEFAULT, rawbuf.size if size is None else size, rawbuf.dtype).allocate()
if zero:
def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_device=None):
rawbuf = type(old_rawbuf)(force_device or old_rawbuf.device, old_rawbuf.size if size is None else size, old_rawbuf.dtype).allocate()
if copy:
with Context(DEBUG=0): rawbuf.copyin(old_rawbuf.as_buffer())
elif zero:
with Context(DEBUG=0):
mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
ctypes.memset(from_mv(mv), 0, len(mv))
rawbuf.copyin(mv)
return rawbuf
def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None):
def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> Tuple[str, Any]: # (error msg, run state)
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}
# TODO: images need required_optimizations
try:
prg = CompiledRunner(lin.to_program())
except KeyboardInterrupt: raise
except Exception:
traceback.print_exc()
return "COMPILE_ERROR"
return "COMPILE_ERROR", None
if getenv("VALIDATE_HCQ"): on_linearizer_will_run()
try:
prg(rawbufs, var_vals, wait=True)
except KeyboardInterrupt: raise
except Exception:
traceback.print_exc()
return "EXEC_ERROR"
return "EXEC_ERROR", None
return "PASS"
if getenv("VALIDATE_HCQ"): run_state = on_linearizer_did_run()
else: run_state = None
return "PASS", run_state
def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
# TODO: for bfloat16 it compiles the linearizer, but does not run it because numpy cannot generate a bf16 buffer.
@@ -80,8 +103,9 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
rawbufs = get_fuzz_rawbufs(lin)
else:
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
except KeyboardInterrupt: raise
except BaseException:
return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth, None)
if var_vals is None:
# TODO: handle symbolic max case
@@ -90,18 +114,19 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
if ground_truth is None and not has_bf16:
unoptimized = Kernel(lin.ast)
unoptimized.required_optimizations()
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
if run_linearizer(unoptimized, rawbufs, var_vals)[0] != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth, None)
ground_truth = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)).copy()
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
return (run_msg, rawbufs, var_vals, ground_truth,)
run_msg, run_state = run_linearizer(lin, rawbufs, var_vals)
if run_msg != "PASS": return (run_msg, rawbufs, var_vals, ground_truth, run_state)
try:
if not has_bf16:
result = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype))
np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
except KeyboardInterrupt: raise
except AssertionError as e:
if DEBUG >= 2:
print(f"COMPARE_ERROR details: {e}")
@@ -111,9 +136,9 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
mismatched_ground_truth = ground_truth[mismatch_indices]
for i, idx in enumerate(mismatch_indices[0]):
print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth, run_state)
return ("PASS", rawbufs, var_vals, ground_truth,)
return ("PASS", rawbufs, var_vals, ground_truth, run_state)
def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
SEED = getenv("SEED", 42)
@@ -124,7 +149,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
seen_uops = {}
last_lins = [lin]
failures:DefaultDict[str, List[Tuple[Tuple[UOp,...],List[Opt]]]] = defaultdict(list)
rawbufs, var_vals, ground_truth = None, None, None
rawbufs, var_vals, ground_truth, validate_rawbufs = None, None, None, None
FUZZ_ALL_ACTIONS = getenv("FUZZ_ALL_ACTIONS", 0)
FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
@@ -156,6 +181,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
# stop if kernel uops repeat
try: tuops = tuplize_uops(test_lin.linearize().uops)
except KeyboardInterrupt: raise
except BaseException as e:
print(test_lin.ast)
print(test_lin.applied_opts)
@@ -168,7 +194,19 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape())
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
(msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
if state1 is not None and validate_device is not None:
validate_lin = test_lin.copy()
validate_lin.opts = validate_device.renderer
if validate_rawbufs is None:
validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.dname) for x in rawbufs]
(_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
if _msg != "PASS": failures[f"VALIDATE_DEV_{_msg}"].append((validate_lin.ast, validate_lin.applied_opts))
ok, err_msg = compare_states(state1, state2)
if not ok: failures["HCQ_COMPARE_FAILURE"].append((err_msg, test_lin.ast, test_lin.applied_opts, state1, state2))
if msg != "PASS":
print(test_lin.ast)
print(test_lin.applied_opts)
@@ -220,28 +258,31 @@ if __name__ == "__main__":
failed_ids = []
failures = defaultdict(list)
seen_ast_strs = set()
for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue
if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU
if ast in seen_ast_strs: continue
seen_ast_strs.add(ast)
lin = ast_str_to_lin(ast)
if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs):
print("skipping kernel due to not supported dtype")
continue
try:
for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue
if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU
if ast in seen_ast_strs: continue
seen_ast_strs.add(ast)
with Timing(f"tested ast {i}: "):
tested += 1
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
if fuzz_failures: failed_ids.append(i)
for k, v in fuzz_failures.items():
for f in v:
failures[k].append(f)
lin = ast_str_to_lin(ast)
if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs):
print("skipping kernel due to not supported dtype")
continue
with Timing(f"tested ast {i}: "):
tested += 1
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
if fuzz_failures: failed_ids.append(i)
for k, v in fuzz_failures.items():
for f in v:
failures[k].append(f)
except KeyboardInterrupt: print(colored("STOPPING...", 'red'))
for msg, errors in failures.items():
for i, (ast, opts) in enumerate(errors):
print(f"{msg} {i} kernel: {(ast,opts)}") # easier to use with output with verify_kernel.py
for i, payload in enumerate(errors):
print(f"{msg} {i} kernel: {payload}") # easier to use with output with verify_kernel.py
print(f"{tested=}")
if failures: