hotfix: force_transcendental to fix process replay

This commit is contained in:
George Hotz 2024-09-26 11:13:16 +08:00
parent a6a70aa4bd
commit ff880f5be4
1 changed files with 3 additions and 3 deletions

View File

@ -275,8 +275,8 @@ transcendental_patterns = [
]
@functools.lru_cache(None)
def get_extra_patterns(ops):
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or TRANSCENDENTAL >= 2]
def get_extra_patterns(ops, force_transcendental=False):
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
shiftable_consts = set([2**i for i in range(64)])
pat += [
@ -750,7 +750,7 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
sink = graph_rewrite(sink, constant_folder+just_reduce)
sink = graph_rewrite(sink, constant_folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, constant_folder+reducer)
sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else ()))
sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
return sink