env var to change default float (#3902)

* env var to change default float to fp16 or bf16

looking for standard names for these. we already have FLOAT16, which makes IMAGE use float16 instead of float32, and HALF, which the examples use to convert weights.

working on defaulting to bf16 too; it currently fails to compile:
```
RuntimeError: compile failed: <null>(6): error: identifier "__bf16" is undefined
    __bf16 cast0 = (nv_bfloat16)(val0);
```

removed the old HALF / BF16 handling from the cifar example, it now goes through DEFAULT_FLOAT (see the usage sketch after the notes below)

* DEFAULT_FLOAT

* default of default

* unit test

* don't check default

* tests work on linux
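
For reference, a minimal usage sketch (not part of the commit; assumes a local tinygrad install). DEFAULT_FLOAT is parsed once when tinygrad.dtype is imported, so it has to be in the environment before the import; the sketch runs each case in a fresh interpreter:
```
import os, subprocess, sys

# each case runs in a fresh interpreter because DEFAULT_FLOAT is read at import time
cases = {"": "dtypes.float32", "HALF": "dtypes.half", "BFLOAT16": "dtypes.bfloat16"}
for env_val, expected in cases.items():
  env = {k: v for k, v in os.environ.items() if k != "DEFAULT_FLOAT"}
  if env_val: env["DEFAULT_FLOAT"] = env_val
  code = f"from tinygrad import dtypes; assert dtypes.default_float == {expected}"
  subprocess.run([sys.executable, "-c", code], env=env, check=True)
print("DEFAULT_FLOAT behaves as expected")
```
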
chenyu 2024-03-24 20:33:57 -04:00 committed by GitHub
parent 03899a74bb
commit 83f39a8ceb
5 changed files with 30 additions and 15 deletions


@@ -117,11 +117,11 @@ jobs:
 - name: Run 10 CIFAR training steps
   run: CUDA=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
 - name: Run 10 CIFAR training steps w HALF
-  run: CUDA=1 STEPS=10 HALF=1 python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+  run: CUDA=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
 - name: Run 10 CIFAR training steps w BF16
-  run: CUDA=1 STEPS=10 BF16=1 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+  run: CUDA=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
 - name: Run full CIFAR training
-  run: time CUDA=1 HALF=1 LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+  run: time CUDA=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
 - uses: actions/upload-artifact@v4
   with:
     name: Speed (NVIDIA)
@@ -236,13 +236,13 @@ jobs:
 - name: Run 10 CIFAR training steps
   run: HSA=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
 - name: Run 10 CIFAR training steps w HALF
-  run: HSA=1 STEPS=10 HALF=1 python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+  run: HSA=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
 - name: Run 10 CIFAR training steps w BF16
-  run: HSA=1 STEPS=10 BF16=1 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+  run: HSA=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
 - name: Run full CIFAR training w 1 GPU
-  run: time HSA=1 HALF=1 LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+  run: time HSA=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
 - name: Run full CIFAR training steps w 6 GPUS
-  run: time HSA=1 HALF=1 STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
+  run: time HSA=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
 - name: Run MLPerf resnet eval on training data
   run: time HSA=1 MODEL=resnet python3 examples/mlperf/model_eval.py
 - name: Run 10 MLPerf ResNet50 training steps (1 gpu)


@@ -45,6 +45,7 @@ BEAM | [#] | number of beams in kernel beam search
 GRAPH | [1] | create a graph of all operations (requires graphviz)
 GRAPHUOPS | [1] | create a graph of uops (requires graphviz and saves at /tmp/uops.{svg,dot})
 GRAPHPATH | [/path/to] | where to put the generated graph
+DEFAULT_FLOAT | [HALF, ...] | specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), defaults to FLOAT32
 IMAGE | [1-2] | enable 2d specific optimizations
 FLOAT16 | [1] | use float16 for images instead of float32
 DISALLOW_ASSIGN | [1] | disallow assignment of tensors
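
To make the new table entry concrete, here is a small sketch (not part of the diff; the file name is illustrative). Tensors created from Python floats without an explicit dtype pick up dtypes.default_float:
```
# run as: DEFAULT_FLOAT=BFLOAT16 python3 default_float_demo.py
from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0, 3.0])             # no explicit dtype, so default_float is used
assert t.dtype == dtypes.default_float  # bfloat16 here; float32 when DEFAULT_FLOAT is unset
print(t.dtype)
```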


@@ -20,13 +20,6 @@ GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
 assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
 assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
-if getenv("HALF"):
-  dtypes.default_float = dtypes.float16
-elif getenv("BF16"):
-  dtypes.default_float = dtypes.bfloat16
-else:
-  dtypes.default_float = dtypes.float32
 class UnsyncedBatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum


@@ -1,4 +1,4 @@
-import unittest, operator
+import unittest, operator, subprocess
 import numpy as np
 import torch
 from typing import Any, List
@@ -353,6 +353,22 @@ class TestTypeSpec(unittest.TestCase):
     dtypes.default_float = default_float
     assert dtypes.default_float == default_float
+  def test_env_set_default_float(self):
+    # check default
+    subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
+                   shell=True, check=True)
+    # check change
+    subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
+                   shell=True, check=True)
+    # check invalid
+    with self.assertRaises(subprocess.CalledProcessError):
+      subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
+                     shell=True, check=True)
+    with self.assertRaises(subprocess.CalledProcessError):
+      subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
+                     shell=True, check=True)
   @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_creation(self, default_int, default_float):
     dtypes.default_int, dtypes.default_float = default_int, default_float


@@ -2,6 +2,7 @@ from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
 from dataclasses import dataclass
 import numpy as np # TODO: remove numpy
 import functools
+from tinygrad.helpers import getenv
 Scalar = Union[float, int, bool]
@@ -83,6 +84,10 @@ class dtypes:
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
+if (env_default_float := getenv("DEFAULT_FLOAT", "")):
+  dtypes.default_float = getattr(dtypes, env_default_float.lower())
+  assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],