diff --git a/examples/sdxl.py b/examples/sdxl.py index b6ea7439..cc45bd37 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -424,5 +424,5 @@ if __name__ == "__main__": and not args.weights: ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png"))) distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item() - assert distance < 2e-3, colored(f"validation failed with {distance=}", "red") + assert distance < 4e-3, colored(f"validation failed with {distance=}", "red") print(colored(f"output validated with {distance=}", "green")) diff --git a/test/test_ops.py b/test/test_ops.py index 3e74119a..782f2b09 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -812,8 +812,8 @@ class TestOps(unittest.TestCase): b = Tensor.ones(3,3) a @ b def test_multidot(self): - helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot) - helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot) + helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) + helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) def test_sum_simple(self): helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 724f11ae..7dfa0629 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -231,8 +231,9 @@ constant_folder = PatternMatcher([ (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)), UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), # deal with UNMUL - (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]), - lambda c1,c2,v: v if c1.arg == c2.arg else None), + (UOp.cvar('c1') * UOp(UOps.UNMUL, src=(UOp.cvar('c2'), UOp.var('v'))), lambda c1,c2,v: v if c1.arg == c2.arg else None), + (UOp.cvar('c1') * (UOp.var('add') + UOp(UOps.UNMUL, src=(UOp.cvar('c2'), UOp.var('v')))), + lambda c1, add, c2, v: (add*c1+v) if c1.arg == c2.arg else None), (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero), (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))), # indexing (with a multiply offset)! @@ -410,14 +411,14 @@ def do_reduce_with_expand(root): const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar())) ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,)) acc_number += 1 + alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)] if len(expands_reduce): assert root.src[0].op is UOps.EXPAND expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce])) assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src) - for xx in root.src[0].src: - ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx) + ret = functools.reduce(lambda x,y: UOp.alu(alu_op, x, y), root.src[0].src+(ret,)) else: - ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, root.src[0]) + ret = UOp.alu(alu_op, ret, root.src[0]) ret = UOp(UOps.PHI, ret.dtype, (acc, ret)) if len(expands_non_reduce): ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])]) return ret diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 99912bc7..f286471e 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -35,8 +35,8 @@ class UOp: src: Tuple[UOp, ...] = tuple() arg: Any = None def commutative(self) -> bool: - return self.op is UOps.ALU and \ - self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR} + return self.op is UOps.UNMUL or (self.op is UOps.ALU and \ + self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}) @functools.cached_property def cmp_tuple(self): # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX