no load in INITMLPERF (#5957)

This commit is contained in:
Elias Wahl 2024-08-08 17:28:24 +02:00 committed by GitHub
parent 183c4c91a3
commit c9b4602854
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View File

@@ -1,5 +1,6 @@
from collections import OrderedDict
import unicodedata
from typing import Optional
import numpy as np
from tinygrad.nn import state
from tinygrad.tensor import Tensor, dtypes
@@ -207,7 +208,7 @@ def get_mlperf_bert_config():
"vocab_size": 30522
}
def get_mlperf_bert_model(checkpoint_path:str=""):
def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
from extra.models import bert
from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
@@ -220,8 +221,7 @@ def get_mlperf_bert_model(checkpoint_path:str=""):
if getenv("DISABLE_DROPOUT", 0):
config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
model = BertForPretraining(**config)
if checkpoint_path: model.load_from_pretrained(checkpoint_path)
return model
return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
def get_data_bert(GPUS:list[str], it):
data: dict[str, Tensor] = next(it)

View File

@@ -455,7 +455,7 @@ def train_bert():
Tensor.manual_seed(seed) # seed for weight initialization
model = get_mlperf_bert_model(init_ckpt)
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)