diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 3577cbc3..140b6913 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -117,11 +117,11 @@ jobs:
     - name: Run 10 CIFAR training steps
       run: CUDA=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
     - name: Run 10 CIFAR training steps w HALF
-      run: CUDA=1 STEPS=10 HALF=1 python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+      run: CUDA=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
     - name: Run 10 CIFAR training steps w BF16
-      run: CUDA=1 STEPS=10 BF16=1 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+      run: CUDA=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     - name: Run full CIFAR training
-      run: time CUDA=1 HALF=1 LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+      run: time CUDA=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
     - uses: actions/upload-artifact@v4
       with:
         name: Speed (NVIDIA)
@@ -236,13 +236,13 @@ jobs:
     - name: Run 10 CIFAR training steps
      run: HSA=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
     - name: Run 10 CIFAR training steps w HALF
-      run: HSA=1 STEPS=10 HALF=1 python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+      run: HSA=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
     - name: Run 10 CIFAR training steps w BF16
-      run: HSA=1 STEPS=10 BF16=1 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+      run: HSA=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     - name: Run full CIFAR training w 1 GPU
-      run: time HSA=1 HALF=1 LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+      run: time HSA=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
     - name: Run full CIFAR training steps w 6 GPUS
-      run: time HSA=1 HALF=1 STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
+      run: time HSA=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.3 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
     - name: Run MLPerf resnet eval on training data
       run: time HSA=1 MODEL=resnet python3 examples/mlperf/model_eval.py
     - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
diff --git a/docs/env_vars.md b/docs/env_vars.md
index dba133b3..a342b560 100644
--- a/docs/env_vars.md
+++ b/docs/env_vars.md
@@ -45,6 +45,7 @@ BEAM | [#] | number of beams in kernel beam search
 GRAPH | [1] | create a graph of all operations (requires graphviz)
 GRAPHUOPS | [1] | create a graph of uops (requires graphviz and saves at /tmp/uops.{svg,dot})
 GRAPHPATH | [/path/to] | where to put the generated graph
+DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), default to FLOAT32
 IMAGE | [1-2] | enable 2d specific optimizations
 FLOAT16 | [1] | use float16 for images instead of float32
 DISALLOW_ASSIGN | [1] | disallow assignment of tensors
diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index f72ba7c0..8b024c22 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -20,13 +20,6 @@ GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
 assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
 assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
 
-if getenv("HALF"):
-  dtypes.default_float = dtypes.float16
-elif getenv("BF16"):
-  dtypes.default_float = dtypes.bfloat16
-else:
-  dtypes.default_float = dtypes.float32
-
 class UnsyncedBatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 0a8dc291..53566f6f 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -1,4 +1,4 @@
-import unittest, operator
+import unittest, operator, subprocess
 import numpy as np
 import torch
 from typing import Any, List
@@ -353,6 +353,22 @@ class TestTypeSpec(unittest.TestCase):
     dtypes.default_float = default_float
     assert dtypes.default_float == default_float
 
+  def test_env_set_default_float(self):
+    # check default
+    subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
+                   shell=True, check=True)
+    # check change
+    subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
+                   shell=True, check=True)
+    # check invalid
+    with self.assertRaises(subprocess.CalledProcessError):
+      subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
+                     shell=True, check=True)
+
+    with self.assertRaises(subprocess.CalledProcessError):
+      subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
+                     shell=True, check=True)
+
   @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_creation(self, default_int, default_float):
     dtypes.default_int, dtypes.default_float = default_int, default_float
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index d108e496..ee177e51 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -2,6 +2,7 @@ from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
 from dataclasses import dataclass
 import numpy as np # TODO: remove numpy
 import functools
+from tinygrad.helpers import getenv
 
 Scalar = Union[float, int, bool]
 
@@ -83,6 +84,10 @@ class dtypes:
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
 
+if (env_default_float := getenv("DEFAULT_FLOAT", "")):
+  dtypes.default_float = getattr(dtypes, env_default_float.lower())
+  assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
+
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
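
For illustration only (not part of the diff): a minimal sketch of how the new DEFAULT_FLOAT environment variable is expected to behave from a user's point of view, assuming tensors created without an explicit dtype pick up `dtypes.default_float`. The script name and tensor shape below are made up.

```python
# Hypothetical usage sketch. Run as:
#   DEFAULT_FLOAT=HALF python3 default_float_demo.py
# With DEFAULT_FLOAT unset, the default stays float32.
from tinygrad import Tensor, dtypes

# dtypes.default_float is resolved once, at import time, from the DEFAULT_FLOAT env var
print(dtypes.default_float)  # expected: dtypes.half when DEFAULT_FLOAT=HALF

# tensors created without an explicit dtype should use the default float dtype
t = Tensor.randn(4, 4)
assert t.dtype == dtypes.default_float
```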