mirror of https://github.com/commaai/tinygrad.git
This reverts commit ec52a09393
.
This commit is contained in:
parent
89c4cffd86
commit
296368f0dd
|
@ -314,7 +314,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
|
||||||
if not st.contiguous or tr_next.op in ReduceOps: break
|
if not st.contiguous or tr_next.op in ReduceOps: break
|
||||||
tr = tr_next
|
tr = tr_next
|
||||||
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
||||||
if tr.op is UnaryOps.CAST and tr.dtype.itemsize > tr.srcs[0].dtype.itemsize:
|
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
|
||||||
tr = tr.srcs[0].base
|
tr = tr.srcs[0].base
|
||||||
reduce_for_op[tr] = r
|
reduce_for_op[tr] = r
|
||||||
realizes[tr] = None
|
realizes[tr] = None
|
||||||
|
|
|
@ -108,7 +108,7 @@ class LazyBuffer:
|
||||||
# TODO: applying this makes gpt2 slower
|
# TODO: applying this makes gpt2 slower
|
||||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||||
cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
||||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, None, (self,))
|
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
||||||
|
|
||||||
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
|
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
|
||||||
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
||||||
|
|
Loading…
Reference in New Issue