mirror of https://github.com/commaai/tinygrad.git
move acc to end (#5568)
* move acc to end * confirmed pictures are the same * relax that * Update test_ops.py
This commit is contained in:
parent
2de82b8a5d
commit
0ad87021e2
|
@ -424,5 +424,5 @@ if __name__ == "__main__":
|
|||
and not args.weights:
|
||||
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
|
||||
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
|
||||
assert distance < 2e-3, colored(f"validation failed with {distance=}", "red")
|
||||
assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
|
||||
print(colored(f"output validated with {distance=}", "green"))
|
||||
|
|
|
@ -812,8 +812,8 @@ class TestOps(unittest.TestCase):
|
|||
b = Tensor.ones(3,3)
|
||||
a @ b
|
||||
def test_multidot(self):
|
||||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
|
||||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
|
||||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
|
||||
def test_sum_simple(self):
|
||||
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
|
||||
|
|
|
@ -231,8 +231,9 @@ constant_folder = PatternMatcher([
|
|||
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
# deal with UNMUL
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
|
||||
lambda c1,c2,v: v if c1.arg == c2.arg else None),
|
||||
(UOp.cvar('c1') * UOp(UOps.UNMUL, src=(UOp.cvar('c2'), UOp.var('v'))), lambda c1,c2,v: v if c1.arg == c2.arg else None),
|
||||
(UOp.cvar('c1') * (UOp.var('add') + UOp(UOps.UNMUL, src=(UOp.cvar('c2'), UOp.var('v')))),
|
||||
lambda c1, add, c2, v: (add*c1+v) if c1.arg == c2.arg else None),
|
||||
(UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
|
||||
(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
|
||||
# indexing (with a multiply offset)!
|
||||
|
@ -410,14 +411,14 @@ def do_reduce_with_expand(root):
|
|||
const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
|
||||
ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
|
||||
acc_number += 1
|
||||
alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)]
|
||||
if len(expands_reduce):
|
||||
assert root.src[0].op is UOps.EXPAND
|
||||
expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce]))
|
||||
assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src)
|
||||
for xx in root.src[0].src:
|
||||
ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
|
||||
ret = functools.reduce(lambda x,y: UOp.alu(alu_op, x, y), root.src[0].src+(ret,))
|
||||
else:
|
||||
ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, root.src[0])
|
||||
ret = UOp.alu(alu_op, ret, root.src[0])
|
||||
ret = UOp(UOps.PHI, ret.dtype, (acc, ret))
|
||||
if len(expands_non_reduce): ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])])
|
||||
return ret
|
||||
|
|
|
@ -35,8 +35,8 @@ class UOp:
|
|||
src: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def commutative(self) -> bool:
|
||||
return self.op is UOps.ALU and \
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
|
||||
return self.op is UOps.UNMUL or (self.op is UOps.ALU and \
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
||||
@functools.cached_property
|
||||
def cmp_tuple(self):
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
|
|
Loading…
Reference in New Issue