Do less realizes (#4141)

* less realize

* corealize jit inputs

* prints

* print before we run
George Hotz 2024-04-10 19:50:50 -07:00 committed by GitHub
parent 06bcae13b4
commit 2e6c39b0b2
3 changed files with 9 additions and 4 deletions
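
The thrust of the change: JIT inputs used to be realized one Tensor at a time (v.realize() inside the dict comprehension), and the llama example forced a realize after every TransformerBlock. Both now defer to a single Tensor.corealize call, which schedules one graph for many outputs at once. A minimal sketch of the difference, using Tensor.corealize exactly as the JIT change below does (shapes are illustrative):

from tinygrad import Tensor

a, b = Tensor.rand(16), Tensor.rand(16)
c, d = a + b, a * b
# before: each realize builds and runs its own schedule
# c.realize(); d.realize()
# after: one corealize builds a single schedule covering both outputs,
# so shared parents like a and b are only scheduled once
Tensor.corealize([c, d])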


@@ -101,7 +101,7 @@ class TransformerBlock:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
+    return h + self.feed_forward(self.ffn_norm(h).half())
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
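
Dropping the .realize() means a TransformerBlock now hands back a lazy tensor, so the whole stack of blocks accumulates into one graph that is scheduled in a single pass at the end of the forward call. A self-contained sketch of the same idea, with block as a hypothetical stand-in for TransformerBlock.__call__:

from tinygrad import Tensor

def block(x: Tensor) -> Tensor:
  # stand-in for TransformerBlock.__call__: return the result lazily
  # instead of calling .realize() on it
  return x + (x * 2).relu()

x = Tensor.rand(4, 8)
for _ in range(3): x = block(x)  # the graph keeps growing across blocks
x.realize()  # one realize, one schedule for the whole stack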


@@ -94,11 +94,14 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = { cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501
-    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_tensors.items()]) #noqa: E501
+    input_tensors: Dict[Union[int, str], Tensor] = { cast(Union[int, str], k):v for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501
+    Tensor.corealize(input_tensors.values())
+    input_lbs: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = {k:v.lazydata for k,v in input_tensors.items()}
+    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_lbs.items()]) #noqa: E501
 
     # get rawbuffers
-    lbs: List[LazyBuffer] = [v for v in input_tensors.values() if isinstance(v, LazyBuffer)] + flatten([mlb.lbs for mlb in input_tensors.values() if isinstance(mlb, MultiLazyBuffer)]) #noqa: E501
+    lbs: List[LazyBuffer] = [v for v in input_lbs.values() if isinstance(v, LazyBuffer)] + \
+      flatten([mlb.lbs for mlb in input_lbs.values() if isinstance(mlb, MultiLazyBuffer)])
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
@@ -113,6 +116,7 @@ class TinyJit(Generic[ReturnType]):
               for x,y in zip(self.expected_name_sts_dtype_device, expected_name_sts_dtype_device)), \
         f"mismatch of input tensors, expected {self.expected_name_sts_dtype_device} got {expected_name_sts_dtype_device}"
       for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels")
       for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
     elif self.cnt == 1:
       # jit capture
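
TinyJit.__call__ now keeps the inputs as Tensors, corealizes them in one pass, and only then pulls out their lazydata (input_lbs) for the shape/dtype/device check and the rawbuffer lookup; the new DEBUG print reports how many cached kernels a replay executes. A hedged usage sketch, assuming TinyJit is importable from the top-level package as in tinygrad of this era (run with DEBUG=1; the kernel count printed depends on the graph):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2 + 1).realize()

# call 1 runs normally, call 2 captures (the cnt == 1 branch above);
# from call 3 on the cache replays and, with DEBUG=1, prints
# "jit execs N kernels"
for _ in range(3): out = step(Tensor.rand(4))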


@@ -238,6 +238,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]
   # confirm everything was scheduled correctly
   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+  if DEBUG >= 1 and len(schedule) > 0: print(f"scheduled {len(schedule)} kernels")
   return schedule, var_vals
 
 def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
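
The matching print in create_schedule_with_vars makes every non-empty schedule visible at DEBUG >= 1, which is what lets you verify that removing realizes actually shrinks the kernel count. A small way to observe it (DEBUG is tinygrad's usual environment-variable knob; the exact N is workload-dependent):

# run as: DEBUG=1 python script.py
from tinygrad import Tensor

out = (Tensor.rand(8, 8) @ Tensor.rand(8, 8)).sum()
out.realize()  # scheduling now logs "scheduled N kernels"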