mirror of https://github.com/commaai/tinygrad.git
log resnet TRAIN_BEAM / EVAL_BEAM (#4181)
Also run the first eval in benchmark mode if either TRAIN_BEAM or EVAL_BEAM is positive.
This commit is contained in:
parent
9d2273235c
commit
d5b67c1ca3
|
@ -28,6 +28,9 @@ def train_resnet():
|
|||
print(f"Training on {GPUS}")
|
||||
for x in GPUS: Device[x]
|
||||
|
||||
TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
|
||||
EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
|
||||
|
||||
# ** model definition and initializers **
|
||||
num_classes = 1000
|
||||
resnet.Conv2d = Conv2dHeNormal
|
||||
|
@ -61,9 +64,11 @@ def train_resnet():
|
|||
steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
|
||||
|
||||
config["DEFAULT_FLOAT"] = dtypes.default_float.name
|
||||
config["BEAM"] = BEAM.value
|
||||
config["WINO"] = WINO.value
|
||||
config["SYNCBN"] = getenv("SYNCBN")
|
||||
config["BEAM"] = BEAM.value
|
||||
config["TRAIN_BEAM"] = TRAIN_BEAM
|
||||
config["EVAL_BEAM"] = EVAL_BEAM
|
||||
config["WINO"] = WINO.value
|
||||
config["SYNCBN"] = getenv("SYNCBN")
|
||||
|
||||
# ** Optimizer **
|
||||
skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
|
||||
|
@ -105,7 +110,7 @@ def train_resnet():
|
|||
def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
|
||||
@TinyJit
|
||||
def train_step(X, Y):
|
||||
with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
|
||||
with Context(BEAM=TRAIN_BEAM):
|
||||
optimizer_group.zero_grad()
|
||||
X = normalize(X)
|
||||
out = model.forward(X)
|
||||
|
@ -119,7 +124,7 @@ def train_resnet():
|
|||
|
||||
@TinyJit
|
||||
def eval_step(X, Y):
|
||||
with Context(BEAM=getenv("EVAL_BEAM", BEAM.value)):
|
||||
with Context(BEAM=EVAL_BEAM):
|
||||
X = normalize(X)
|
||||
out = model.forward(X)
|
||||
loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
|
||||
|
@ -177,7 +182,7 @@ def train_resnet():
|
|||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
# if we are doing beam search, run the first eval too
|
||||
if BEAM.value and e == start_epoch: break
|
||||
if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
|
||||
return
|
||||
|
||||
# ** eval loop **
|
||||
|
|
Loading…
Reference in New Issue