mirror of https://github.com/commaai/tinygrad.git
Do less realizes (#4141)

* less realize
* corealize jit inputs
* prints
* print before we run
parent 06bcae13b4
commit 2e6c39b0b2

@@ -101,7 +101,7 @@ class TransformerBlock:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
+    return h + self.feed_forward(self.ffn_norm(h).half())
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
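
Dropping the trailing .realize() is the "less realize" part of the commit: each TransformerBlock output now stays lazy, so the scheduler sees the whole stack of blocks at once instead of being forced to cut a kernel boundary after every layer. A minimal sketch of the idea with toy tensors (not the model code):

    from tinygrad.tensor import Tensor

    x = Tensor.rand(4, 4)
    h = (x + 1).relu()             # still lazy: no kernel has run yet
    # calling h.realize() here would force a kernel boundary mid-graph, as the
    # old per-block .realize() did; leaving it lazy lets the scheduler fuse
    # the add, relu, and the multiply below into fewer kernels
    out = (h * 2).sum().realize()  # a single realize at the end of the chain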

@@ -94,11 +94,14 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = { cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501
-    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_tensors.items()]) #noqa: E501
+    input_tensors: Dict[Union[int, str], Tensor] = { cast(Union[int, str], k):v for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501
+    Tensor.corealize(input_tensors.values())
+    input_lbs: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = {k:v.lazydata for k,v in input_tensors.items()}
+    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_lbs.items()]) #noqa: E501
 
     # get rawbuffers
-    lbs: List[LazyBuffer] = [v for v in input_tensors.values() if isinstance(v, LazyBuffer)] + flatten([mlb.lbs for mlb in input_tensors.values() if isinstance(mlb, MultiLazyBuffer)]) #noqa: E501
+    lbs: List[LazyBuffer] = [v for v in input_lbs.values() if isinstance(v, LazyBuffer)] + \
+      flatten([mlb.lbs for mlb in input_lbs.values() if isinstance(mlb, MultiLazyBuffer)])
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
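
Tensor.corealize is the "corealize jit inputs" part: rather than realizing each JIT input separately via v.realize() (one schedule built and run per tensor), every input is handed to the scheduler in a single batch. A sketch with illustrative tensor names, assuming the call shape shown in the diff (an iterable of tensors):

    from tinygrad.tensor import Tensor

    a, b, c = Tensor.rand(8), Tensor.rand(8), Tensor.rand(8)

    # old path: one schedule built and executed per input
    #   for t in (a, b, c): t.realize()
    # new path: one combined schedule covering all three inputs
    Tensor.corealize([a, b, c])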

@@ -113,6 +116,7 @@ class TinyJit(Generic[ReturnType]):
                  for x,y in zip(self.expected_name_sts_dtype_device, expected_name_sts_dtype_device)), \
         f"mismatch of input tensors, expected {self.expected_name_sts_dtype_device} got {expected_name_sts_dtype_device}"
       for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels")
       for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
     elif self.cnt == 1:
       # jit capture
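
The "jit execs" line is deliberately emitted before the kernel loop ("print before we run" in the commit message), so the kernel count is already visible if one of the cached kernels then hangs or crashes. DEBUG is tinygrad's environment-driven verbosity level; a small sketch of the gate:

    from tinygrad.helpers import DEBUG  # integer verbosity, read from the DEBUG env var

    # enable with e.g.:  DEBUG=1 python examples/llama.py
    if DEBUG >= 1: print("verbose scheduling/JIT prints are on")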

@@ -238,6 +238,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
   # confirm everything was scheduled correctly
   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+  if DEBUG >= 1 and len(schedule) > 0: print(f"scheduled {len(schedule)} kernels")
   return schedule, var_vals
 
 def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
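
The matching "scheduled" print lives in create_schedule_with_vars, the point where pending LazyBuffers become a linear list of ScheduleItems, so its count directly reflects how well callers batched their realizes. A hedged sketch of driving the scheduler by hand (treat the import path as an assumption for this point in the tree):

    from tinygrad.tensor import Tensor
    from tinygrad.engine.schedule import create_schedule  # import path assumed for this era

    a, b = Tensor.rand(4), Tensor.rand(4)
    out = (a + b).sum()
    schedule = create_schedule([out.lazydata])  # with DEBUG=1: "scheduled N kernels"
    print(len(schedule))                        # kernels needed to realize out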