mirror of https://github.com/commaai/tinygrad.git
hotfix: more lazyop rename to uop [run_process_replay] (#6157)
This commit is contained in:
parent
17a043edad
commit
be6dda4093
|
@ -9,13 +9,13 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|||
from tinygrad import dtypes
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
class InvalidLazyOpException(Exception): pass
|
||||
class InvalidASTException(Exception): pass
|
||||
def helper_test_verify_ast(*stores:UOp) -> Kernel:
|
||||
sink = UOp(UOps.SINK, None, stores)
|
||||
if DEBUG >= 3:
|
||||
for op in stores: print(op)
|
||||
try: verify_ast(sink)
|
||||
except AssertionError as e: raise InvalidLazyOpException(e.args)
|
||||
except AssertionError as e: raise InvalidASTException(e.args)
|
||||
k = Kernel(sink)
|
||||
k.linearize()
|
||||
if DEBUG >= 6: print_uops(k.uops)
|
||||
|
@ -42,14 +42,14 @@ class TestVerifyAST(unittest.TestCase):
|
|||
a = UOp(UOps.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
b = UOp(UOps.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
st1 = UOp.store(bufs[1], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
||||
with self.assertRaises(InvalidLazyOpException): helper_test_verify_ast(st0, st1)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
|
||||
|
||||
def test_no_implicit_broadcasting(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
|
||||
st = UOp(UOps.STORE, None, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
with self.assertRaises(InvalidLazyOpException): helper_test_verify_ast(st)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
|
||||
def test_shrink_ok(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
||||
|
@ -63,14 +63,14 @@ class TestVerifyAST(unittest.TestCase):
|
|||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
|
||||
with self.assertRaisesRegex(InvalidLazyOpException, "implicit expand"): helper_test_verify_ast(st)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
def test_reduce_add_store(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
||||
with self.assertRaisesRegex(InvalidLazyOpException, "implicit expand"): helper_test_verify_ast(st)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -52,7 +52,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
|||
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
|
||||
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
|
||||
cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
|
||||
"""recursively create a lazyop"""
|
||||
"""recursively create a UOp"""
|
||||
if buf is not buf.base: st, buf = buf.st+st, buf.base
|
||||
if (buf, st) in cache: return cache[(buf, st)]
|
||||
assert buf.op is not None, "base must be a base itself"
|
||||
|
@ -90,13 +90,13 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
|||
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (buf.op, rinfo[1])))
|
||||
|
||||
# elementwise ops pass shapetracker
|
||||
in_ops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
|
||||
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
|
||||
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
|
||||
assert buf in outputs, f"{buf.op} must be writable"
|
||||
return in_ops[0]
|
||||
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_ops))
|
||||
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_ops))
|
||||
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_ops, buf.op))
|
||||
return in_uops[0]
|
||||
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops))
|
||||
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
|
||||
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
|
||||
|
||||
def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
|
||||
|
@ -180,8 +180,8 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
|||
if vv: var_vals.update(vv)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
|
||||
ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src)))
|
||||
return LBScheduleItem(UOp(UOps.SINK, None, tuple(ast)), outs, list(inputs), var_vals,
|
||||
dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
|
||||
sink = UOp(UOps.SINK, None, tuple(ast))
|
||||
return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class LazyBuffer:
|
|||
self._base: Optional[LazyBuffer] = None
|
||||
if base is None:
|
||||
# properties on base
|
||||
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
|
||||
self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
|
||||
assert self.op is not MetaOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
|
||||
|
||||
if self.op is MetaOps.VIEW:
|
||||
|
|
Loading…
Reference in New Issue