From ff880f5be4e4b7684cd8bd0e89aeea2b095c30fe Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 26 Sep 2024 11:13:16 +0800 Subject: [PATCH] hotfix: force_transcendental to fix process replay --- tinygrad/codegen/uopgraph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 3b78aebc..18cffefb 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -275,8 +275,8 @@ transcendental_patterns = [ ] @functools.lru_cache(None) -def get_extra_patterns(ops): - pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or TRANSCENDENTAL >= 2] +def get_extra_patterns(ops, force_transcendental=False): + pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental] if BinaryOps.SHL in ops and BinaryOps.SHR in ops: shiftable_consts = set([2**i for i in range(64)]) pat += [ @@ -750,7 +750,7 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: sink = graph_rewrite(sink, constant_folder+just_reduce) sink = graph_rewrite(sink, constant_folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)) sink = graph_rewrite(sink, constant_folder+reducer) - sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else ())) + sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2)) if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher) return sink