From c9b46028546638a7eb3dd56326478b04be9b341f Mon Sep 17 00:00:00 2001
From: Elias Wahl <82230675+Eliulm@users.noreply.github.com>
Date: Thu, 8 Aug 2024 17:28:24 +0200
Subject: [PATCH] no load in INITMLPERF (#5957)

---
 examples/mlperf/helpers.py     | 6 +++---
 examples/mlperf/model_train.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py
index dab413e5..d232edfd 100644
--- a/examples/mlperf/helpers.py
+++ b/examples/mlperf/helpers.py
@@ -1,5 +1,6 @@
 from collections import OrderedDict
 import unicodedata
+from typing import Optional
 import numpy as np
 from tinygrad.nn import state
 from tinygrad.tensor import Tensor, dtypes
@@ -207,7 +208,7 @@ def get_mlperf_bert_config():
     "vocab_size": 30522
   }
 
-def get_mlperf_bert_model(checkpoint_path:str=""):
+def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
   from extra.models import bert
   from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
 
@@ -220,8 +221,7 @@ def get_mlperf_bert_model(checkpoint_path:str=""):
   if getenv("DISABLE_DROPOUT", 0): config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
 
   model = BertForPretraining(**config)
-  if checkpoint_path: model.load_from_pretrained(checkpoint_path)
-  return model
+  return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
 
 def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 83020968..4ef4d1f2 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -455,7 +455,7 @@ def train_bert():
 
   Tensor.manual_seed(seed)  # seed for weight initialization
 
-  model = get_mlperf_bert_model(init_ckpt)
+  model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
 
   for _, x in get_state_dict(model).items():
     x.realize().to_(GPUS)
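
Note (not part of the patch): a minimal sketch of the new calling convention. With checkpoint_path now Optional[str]=None, passing None returns a freshly initialized BertForPretraining and load_from_pretrained is never called. INITMLPERF and init_ckpt below are hypothetical stand-ins for the values set earlier in model_train.py.

# Sketch only; assumes the patched helpers.py above.
from examples.mlperf.helpers import get_mlperf_bert_model

INITMLPERF = True                   # hypothetical: MLPerf-compliant run, weights stay randomly initialized
init_ckpt = "/path/to/init_ckpt"    # hypothetical checkpoint path, only used when INITMLPERF is unset

# Mirrors the patched call in train_bert(): None skips the checkpoint load entirely.
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)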