Revert "delete arg from cast [run_process_replay] (#6202)" (#6214)

This reverts commit ec52a09393.
This commit is contained in:
George Hotz 2024-08-20 16:45:30 -07:00 committed by GitHub
parent 89c4cffd86
commit 296368f0dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -314,7 +314,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
if not st.contiguous or tr_next.op in ReduceOps: break if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize) # don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is UnaryOps.CAST and tr.dtype.itemsize > tr.srcs[0].dtype.itemsize: if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
tr = tr.srcs[0].base tr = tr.srcs[0].base
reduce_for_op[tr] = r reduce_for_op[tr] = r
realizes[tr] = None realizes[tr] = None

View File

@ -108,7 +108,7 @@ class LazyBuffer:
# TODO: applying this makes gpt2 slower # TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st) return self.base.cast(dtype, bitcast)._view(self.st)
cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, None, (self,)) return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable) def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views) def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)