move acc to end (#5568)

* move acc to end

* confirmed pictures are the same

* relax that

* Update test_ops.py
This commit is contained in:
George Hotz 2024-07-19 03:06:52 -07:00 committed by GitHub
parent 2de82b8a5d
commit 0ad87021e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 11 additions and 10 deletions

View File

@ -424,5 +424,5 @@ if __name__ == "__main__":
and not args.weights:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
assert distance < 2e-3, colored(f"validation failed with {distance=}", "red")
assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
print(colored(f"output validated with {distance=}", "green"))

View File

@ -812,8 +812,8 @@ class TestOps(unittest.TestCase):
b = Tensor.ones(3,3)
a @ b
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
def test_sum_simple(self):
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])

View File

@ -231,8 +231,9 @@ constant_folder = PatternMatcher([
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
# deal with UNMUL
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
lambda c1,c2,v: v if c1.arg == c2.arg else None),
(UOp.cvar('c1') * UOp(UOps.UNMUL, src=(UOp.cvar('c2'), UOp.var('v'))), lambda c1,c2,v: v if c1.arg == c2.arg else None),
(UOp.cvar('c1') * (UOp.var('add') + UOp(UOps.UNMUL, src=(UOp.cvar('c2'), UOp.var('v')))),
lambda c1, add, c2, v: (add*c1+v) if c1.arg == c2.arg else None),
(UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
# indexing (with a multiply offset)!
@ -410,14 +411,14 @@ def do_reduce_with_expand(root):
const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
acc_number += 1
alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)]
if len(expands_reduce):
assert root.src[0].op is UOps.EXPAND
expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce]))
assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src)
for xx in root.src[0].src:
ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
ret = functools.reduce(lambda x,y: UOp.alu(alu_op, x, y), root.src[0].src+(ret,))
else:
ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, root.src[0])
ret = UOp.alu(alu_op, ret, root.src[0])
ret = UOp(UOps.PHI, ret.dtype, (acc, ret))
if len(expands_non_reduce): ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])])
return ret

View File

@ -35,8 +35,8 @@ class UOp:
src: Tuple[UOp, ...] = tuple()
arg: Any = None
def commutative(self) -> bool:
return self.op is UOps.ALU and \
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
return self.op is UOps.UNMUL or (self.op is UOps.ALU and \
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
@functools.cached_property
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX