mirror of https://github.com/commaai/tinygrad.git
Fix fake dataloader (#5326)
This commit is contained in:
parent 6856f915d6
commit 73bddc44f6
@@ -30,9 +30,9 @@ if __name__ == "__main__":
   input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
   segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
   attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
-  masked_positions = Tensor.empty((BS, 512), dtype=dtypes.float32)
-  masked_lm_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
-  masked_lm_weights = Tensor.empty((BS, 512), dtype=dtypes.float32)
+  masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
+  masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
+  masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
   next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
 
   # run model twice to get only what changes, these are the kernels of the model
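The shape change from (BS, 512) to (BS, 76) matches how BERT pretraining batches are laid out: the masked-LM tensors carry one entry per masked position, capped at max_predictions_per_seq (76 in the MLPerf BERT reference for 512-token sequences), not one per token. A minimal sketch of the intended shapes, assuming tinygrad's public Tensor/dtypes API (the names below are illustrative, not from this commit):

  from tinygrad import Tensor, dtypes

  BS, SEQ_LEN, MAX_PREDS = 4, 512, 76  # MAX_PREDS = max_predictions_per_seq

  # per-token tensors: one entry for each of the 512 tokens
  input_ids = Tensor.empty((BS, SEQ_LEN), dtype=dtypes.float32)
  # masked-LM tensors: one entry per masked position, capped at 76
  masked_positions = Tensor.empty((BS, MAX_PREDS), dtype=dtypes.float32)
  assert masked_positions.shape == (4, 76)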
@@ -230,11 +230,11 @@ def get_data_bert(GPUS:list[str], it):
 
 def get_fake_data_bert(GPUS:list[str], BS:int):
   return {
-    "input_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "input_mask": Tensor.zeros((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
-    "segment_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "masked_lm_positions": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "masked_lm_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "masked_lm_weights": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "next_sentence_labels": Tensor.zeros((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
+    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
   }
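Two things change in get_fake_data_bert: the masked-LM entries shrink to 76 columns as above, and Tensor.zeros becomes Tensor.empty. A plausible reading (not stated in the commit): for fake benchmark data only shape, dtype, and sharding matter, and Tensor.empty just reserves a buffer without the fill kernel that Tensor.zeros would schedule. A hedged sketch of the difference, assuming tinygrad's public API:

  from tinygrad import Tensor, dtypes

  # zeros builds a lazy graph that must run a fill kernel when realized
  z = Tensor.zeros((4, 512), dtype=dtypes.float32).contiguous()
  # empty only reserves (uninitialized) device memory; no kernel is needed
  e = Tensor.empty((4, 512), dtype=dtypes.float32)
  print(z.shape == e.shape)  # True: same shape/dtype, different init cost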
@@ -590,7 +590,7 @@ def train_bert():
     for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
       if INITMLPERF:
-        eval_data = get_fake_data_bert(GPUS, BS)
+        eval_data = get_fake_data_bert(GPUS, EVAL_BS)
       else:
         eval_data = get_data_bert(GPUS, eval_it)
       GlobalCounters.reset()
 
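The last hunk is the dataloader fix the commit title refers to: the eval loop built its fake batch with the training batch size BS, so the leading dimension disagreed with EVAL_BS whenever the two differ. A toy illustration (the sizes are assumptions, not values from this commit):

  BS, EVAL_BS = 96, 16                 # hypothetical train/eval batch sizes
  old_shape = (BS, 512)                # what get_fake_data_bert(GPUS, BS) produced
  new_shape = (EVAL_BS, 512)           # what the eval step actually consumes
  assert old_shape[0] != new_shape[0]  # the mismatch this one-line fix removes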