diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index d0cdf8ce..b9b74dcd 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -314,7 +314,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
           if not st.contiguous or tr_next.op in ReduceOps: break
           tr = tr_next
         # don't cast to higher size before store (tr cannot be realized if forced_realize)
-        if tr.op is UnaryOps.CAST and tr.dtype.itemsize > tr.srcs[0].dtype.itemsize:
+        if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
           tr = tr.srcs[0].base
       reduce_for_op[tr] = r
       realizes[tr] = None
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 2bca279c..d8d44dfa 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -108,7 +108,7 @@ class LazyBuffer:
       # TODO: applying this makes gpt2 slower
       return self.base.cast(dtype, bitcast)._view(self.st)
     cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, None, (self,))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
 
   def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
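
The two hunks are coupled: the lazy.py change stores the cast's target dtype as the op `arg` when building a CAST LazyBuffer, which is what lets the schedule.py hunk compare `tr.arg.itemsize` against the source dtype instead of `tr.dtype.itemsize`. Below is a minimal standalone sketch of that invariant; `ToyDType`, `ToyLazyBuffer`, `cast`, and `chase_stops_before_upcast` are illustrative stand-ins, not tinygrad's real classes or API.

```python
# Toy model (not tinygrad) of the invariant the patch establishes:
# a CAST node carries its target dtype as the op arg, so the scheduler
# can read the cast width from `arg`.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyDType:
  name: str
  itemsize: int

float16 = ToyDType("half", 2)
float32 = ToyDType("float", 4)

@dataclass
class ToyLazyBuffer:
  dtype: ToyDType
  op: Optional[str] = None
  arg: Optional[ToyDType] = None
  srcs: Tuple["ToyLazyBuffer", ...] = ()

def cast(buf: ToyLazyBuffer, dtype: ToyDType) -> ToyLazyBuffer:
  # mirrors the lazy.py hunk: pass `dtype` (previously None) as the op arg
  return ToyLazyBuffer(dtype=dtype, op="CAST", arg=dtype, srcs=(buf,))

def chase_stops_before_upcast(tr: ToyLazyBuffer) -> ToyLazyBuffer:
  # mirrors the schedule.py hunk: don't store a cast to a wider dtype;
  # read the target width from tr.arg (the real code returns tr.srcs[0].base)
  if tr.op == "CAST" and tr.arg is not None and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
    return tr.srcs[0]
  return tr

src = ToyLazyBuffer(dtype=float16)
upcast = cast(src, float32)
assert chase_stops_before_upcast(upcast) is src  # the float32 upcast is not stored
```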