Fix fake dataloader (#5326)

This commit is contained in:
Elias Wahl 2024-07-08 15:07:44 +02:00 committed by GitHub
parent 6856f915d6
commit 73bddc44f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 11 deletions

View File

@@ -30,9 +30,9 @@ if __name__ == "__main__":
input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
masked_positions = Tensor.empty((BS, 512), dtype=dtypes.float32)
masked_lm_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
masked_lm_weights = Tensor.empty((BS, 512), dtype=dtypes.float32)
masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
# run model twice to get only what changes, these are the kernels of the model

View File

@@ -230,11 +230,11 @@ def get_data_bert(GPUS:list[str], it):
def get_fake_data_bert(GPUS:list[str], BS:int):
return {
"input_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"input_mask": Tensor.zeros((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
"segment_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"masked_lm_positions": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"masked_lm_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"masked_lm_weights": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"next_sentence_labels": Tensor.zeros((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
}

View File

@@ -590,7 +590,7 @@ def train_bert():
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
if INITMLPERF:
eval_data = get_fake_data_bert(GPUS, BS)
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
else:
eval_data = get_data_bert(GPUS, eval_it)
GlobalCounters.reset()