logging: change LOGKERN to LOGKERNS to match LOGOPS (#4193)

also add printing of ast and applied_opts during verify_kernel
to more easily debug errors if they come up
This commit is contained in:
Francis Lam 2024-04-16 13:08:32 -07:00 committed by GitHub
parent 7fb220a567
commit e9c1616b27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -39,7 +39,10 @@ if __name__ == "__main__":
for i, kern_str in enumerate(kern_strs):
print(f"testing kernel {i}")
test_lin = kern_str_to_lin(kern_str)
for op in test_lin.ast: print_tree(op)
for op in test_lin.ast:
print_tree(op)
print(op)
print(test_lin.applied_opts)
print(test_lin.colored_shape())
(msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)
if msg != "PASS":

View File

@ -206,7 +206,7 @@ class MultiDeviceJITGraph(Runner):
raise NotImplementedError("override this")
method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], bool], CompiledRunner] = {}
logkern, logkern_level = open(getenv("LOGKERN", ""), "a") if getenv("LOGKERN", "") else None, getenv("LOGKERN_LEVEL", 1)
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
class Compiled:
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
@ -248,9 +248,9 @@ class Compiled:
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
if logkern is not None and logkern_level > 1: logkern.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
# TODO: check the correctness inline once compare_linearizer is in core
if logkern is not None: logkern.writelines([f"{(k.ast, k.applied_opts)}\n"])
if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
return k