mirror of https://github.com/commaai/tinygrad.git
no load in INITMLPERF (#5957)
This commit is contained in:
parent
183c4c91a3
commit
c9b4602854
|
@ -1,5 +1,6 @@
|
|||
from collections import OrderedDict
|
||||
import unicodedata
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from tinygrad.nn import state
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
|
@ -207,7 +208,7 @@ def get_mlperf_bert_config():
|
|||
"vocab_size": 30522
|
||||
}
|
||||
|
||||
def get_mlperf_bert_model(checkpoint_path:str=""):
|
||||
def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
|
||||
from extra.models import bert
|
||||
from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
|
||||
|
||||
|
@ -220,8 +221,7 @@ def get_mlperf_bert_model(checkpoint_path:str=""):
|
|||
if getenv("DISABLE_DROPOUT", 0):
|
||||
config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
|
||||
model = BertForPretraining(**config)
|
||||
if checkpoint_path: model.load_from_pretrained(checkpoint_path)
|
||||
return model
|
||||
return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
|
||||
|
||||
def get_data_bert(GPUS:list[str], it):
|
||||
data: dict[str, Tensor] = next(it)
|
||||
|
|
|
@ -455,7 +455,7 @@ def train_bert():
|
|||
|
||||
Tensor.manual_seed(seed) # seed for weight initialization
|
||||
|
||||
model = get_mlperf_bert_model(init_ckpt)
|
||||
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
x.realize().to_(GPUS)
|
||||
|
|
Loading…
Reference in New Issue