diff --git a/.pylintrc b/.pylintrc index fe6b6ac4..00645b29 100644 --- a/.pylintrc +++ b/.pylintrc @@ -248,7 +248,7 @@ indent-after-paren=4 indent-string=' ' # Maximum number of characters on a single line. -max-line-length=100 +max-line-length=150 # Maximum number of lines in a module max-module-lines=1000 diff --git a/ruff.toml b/ruff.toml index b916c118..7ac0a4c5 100644 --- a/ruff.toml +++ b/ruff.toml @@ -12,6 +12,7 @@ select = [ "E272", # "E303", # "E304", + "E501", # "E502", "E702", "E703", @@ -22,6 +23,8 @@ select = [ "UP039", # unnecessary-class-parentheses ] +line-length = 150 + exclude = [ "disassemblers/", "docs/", diff --git a/setup.py b/setup.py index 7164c3f3..05b390ac 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,8 @@ setup(name='tinygrad', license='MIT', long_description=long_description, long_description_content_type='text/markdown', - packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.runtime', 'tinygrad.shape', 'tinygrad.features', 'tinygrad.features.graph'], + packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', + 'tinygrad.runtime', 'tinygrad.shape', 'tinygrad.features', 'tinygrad.features.graph'], classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License" diff --git a/sz.py b/sz.py index b64de656..c66aadd3 100755 --- a/sz.py +++ b/sz.py @@ -38,7 +38,7 @@ def gen_diff(table_old, table_new): file_stat_old = [stats for stats in table_old if file in stats] file_stat_new = [stats for stats in table_new if file in stats] if file_stat_new[0][1]-file_stat_old[0][1] != 0 or file_stat_new[0][2]-file_stat_old[0][2] != 0: - table.append([file_stat_new[0][0], file_stat_new[0][1], file_stat_new[0][1]-file_stat_old[0][1], file_stat_new[0][2], file_stat_new[0][2]-file_stat_old[0][2]]) + table.append([file_stat_new[0][0], file_stat_new[0][1], file_stat_new[0][1]-file_stat_old[0][1], file_stat_new[0][2], file_stat_new[0][2]-file_stat_old[0][2]]) # noqa: E501 return table def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff) @@ -58,7 +58,7 @@ if __name__ == "__main__": if len(sys.argv) == 3: print("### Changes") print("```") - print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", intfmt=(..., "d", "+d"), floatfmt=(..., ..., ..., ".1f", "+.1f"))+"\n") + print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", intfmt=(..., "d", "+d"), floatfmt=(..., ..., ..., ".1f", "+.1f"))+"\n") # noqa: E501 print(f"\ntotal lines changes: {display_diff(sum([x[2] for x in table]))}") print("```") else: diff --git a/test/external/external_llama_eval.py b/test/external/external_llama_eval.py index ce1cbd80..cf08b38d 100644 --- a/test/external/external_llama_eval.py +++ b/test/external/external_llama_eval.py @@ -97,6 +97,7 @@ if __name__ == '__main__': args = parser.parse_args() # run eval and exit - adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize, checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu") + adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize, + checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu") results = evaluator.evaluate(adaptor, tasks.get_task_dict(args.eval.split(",")), False, 0, args.limit) print(json.dumps(results, indent=2)) diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index 01742126..bc2de5bd 100644 --- 
a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -52,7 +52,7 @@ def benchmark_model(m, devices, validate_outs=False): onnx_model = onnx.load(fn) output_names = [out.name for out in onnx_model.graph.output] excluded = {inp.name for inp in onnx_model.graph.initializer} - input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} + input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501 input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded} #input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()} diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index a80a9e24..4a007c04 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -113,7 +113,9 @@ class TestOptBinOp(unittest.TestCase): def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1)) def test_no_binop_rerun_alt(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(16, 16, 1), lambda a,b: a*b) - def test_no_binop_rerun_reduce_broadcast(self): return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2) + def test_no_binop_rerun_reduce_broadcast(self): + return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2) + @unittest.skip("this test started failing with the new change, based movementop issue") def test_no_binop_rerun_transposed(self): return self._test_no_binop_rerun(lambda a,b: (a.T*b.T).T, lambda a,b: a*b) def test_no_binop_rerun_mid_reshape(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(256)+a.reshape(256)) diff --git a/test/external/external_test_yolov8.py b/test/external/external_test_yolov8.py index c98a4266..a215aa06 100644 --- a/test/external/external_test_yolov8.py +++ b/test/external/external_test_yolov8.py @@ -35,7 +35,7 @@ class TestYOLOv8(unittest.TestCase): def test_forward_pass_torch_onnx(self): variant = 'n' weights_location = fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors') - weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension + weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension # noqa: E501 weights_location_onnx = weights_location_pt.parent / f"yolov8{variant}.onnx" # the ultralytics export prints a lot of unnecessary things @@ -48,7 +48,7 @@ class TestYOLOv8(unittest.TestCase): state_dict = safe_load(weights_location) load_state_dict(TinyYolov8, state_dict) - image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)] + image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)] # noqa: 
E501 orig_image = [cv2.imdecode(image_location[0], 1)] input_image = preprocess(orig_image) diff --git a/test/extra/test_export_model.py b/test/extra/test_export_model.py index 675e46d0..4d0671c3 100644 --- a/test/extra/test_export_model.py +++ b/test/extra/test_export_model.py @@ -20,14 +20,14 @@ class TextModelExport(unittest.TestCase): prg, inp_sizes, _, _ = export_model(model, "", *inputs) prg = json.loads(prg) - assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}" + assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}" # noqa: E501 for i in range(len(inputs)): assert f"input{i}" in inp_sizes, f"input{i} not captured in inp_sizes" assert f"input{i}" in prg["buffers"], f"input{i} not captured in exported buffers" for i, exported_input in enumerate(prg["inputs"]): - assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}" + assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}" # noqa: E501 def test_multi_output_model_export(self): model = MockMultiOutputModel() @@ -36,14 +36,14 @@ class TextModelExport(unittest.TestCase): prg, _, out_sizes, _ = export_model(model, "", input) prg = json.loads(prg) - assert len(outputs) == len(prg["outputs"]) == len(out_sizes), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, inp_sizes={len(out_sizes)}" + assert len(outputs) == len(prg["outputs"]) == len(out_sizes), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, inp_sizes={len(out_sizes)}" # noqa: E501 for i in range(len(outputs)): assert f"output{i}" in out_sizes, f"output{i} not captured in out_sizes" assert f"output{i}" in prg["buffers"], f"output{i} not captured in exported buffers" for i, exported_output in enumerate(prg["outputs"]): - assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" + assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" # noqa: E501 if __name__ == '__main__': diff --git a/test/extra/test_lr_scheduler.py b/test/extra/test_lr_scheduler.py index 3bff3660..bc8792d1 100644 --- a/test/extra/test_lr_scheduler.py +++ b/test/extra/test_lr_scheduler.py @@ -100,7 +100,8 @@ class TestLrScheduler(unittest.TestCase): without = lr_scheduler_training() sched_fns = [MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR] argss = [{'milestones': [5, 7, 10, 15], 'gamma': 0.5}, {'factor': 0.5, 'patience': 2}, {'T_max': 25, 'eta_min': 0.001}, - {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'max_lr':1e-5, 'total_steps': 25}] + {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0, + 'max_lr':1e-5, 'total_steps': 25}] for sched_fn, args in zip(sched_fns, argss): with_sched = lr_scheduler_training(sched_fn, args) assert with_sched > without diff --git a/test/imported/test_indexing.py 
b/test/imported/test_indexing.py index d7d54398..4e391655 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -24,7 +24,7 @@ def consec(shape, start=1): def set_(reference: Tensor, shape, strides, offset): if reference.lazydata.base.realized is None: reference.realize() assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base" - strided = Tensor(LazyBuffer(device=reference.device, st=ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)), optype=None, op=None, dtype=reference.dtype, src=None, base=reference.lazydata.base)) + strided = Tensor(LazyBuffer(device=reference.device, st=ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)), optype=None, op=None, dtype=reference.dtype, src=None, base=reference.lazydata.base)) # noqa: E501 assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided" assert strided.lazydata in reference.lazydata.base.views, "base.views should contain strided.lazydata" return strided diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index 845dc36f..1782c51c 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -48,7 +48,8 @@ class TestOnnxModel(unittest.TestCase): mt2 = time.monotonic() tinygrad_out = tinygrad_out.numpy() et = time.monotonic() - if not CI: print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") + if not CI: + print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") if not CI: import cProfile @@ -100,7 +101,7 @@ class TestOnnxModel(unittest.TestCase): def test_efficientnet(self): input_name, input_new = "images:0", True - self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new) + self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new) # noqa: E501 def test_shufflenet(self): input_name, input_new = "gpu_0/data_0", False diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 003654cc..326f4dbe 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -32,7 +32,7 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jit assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB" assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels" if all_jitted: - assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" + assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501 class TestRealWorld(unittest.TestCase): def setUp(self): diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 323e4dc0..ea61c193 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -13,7 +13,7 @@ TEST_FILE_2 = str(pathlib.Path(__file__).parent / 
"whisper/test2.wav") TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length." # TODO this file will possibly not survive long. find another 1-2 minute sound file online to transcribe TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3' -TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." +TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501 @unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. 
GPU requires cl_khr_fp16") class TestWhisper(unittest.TestCase): diff --git a/test/test_dtype.py b/test/test_dtype.py index 4214ce22..a1acfca7 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 import unittest import numpy as np from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index c424a743..f52477a3 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -15,7 +15,8 @@ dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint dtypes_bool = (dtypes.bool,) binary_operations = [operator.add, operator.sub, operator.mul] integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor)] -unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin), (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)] +unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin), + (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)] # TODO: enable this (this is a dtype issue) #binary_operations.append(operator.truediv) @@ -57,7 +58,7 @@ def universal_test_unary(a, dtype, op): if not isinstance(op, tuple): op = (op, op) tensor_value = op[0](Tensor([a], dtype=dtype)).numpy() numpy_value = op[1](np.array([a]).astype(dtype.np)) - if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) + if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) # noqa: E501 else: np.testing.assert_equal(tensor_value, numpy_value) def universal_test_cast(a, in_dtype, dtype): @@ -130,7 +131,7 @@ class TestDTypeALU(unittest.TestCase): def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32) # Metal and CUDACPU behave differently than numpy in CI for overflows - @given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) + @given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) # noqa: E501 def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32) @given(ht.float32, st.sampled_from(dtypes_float+dtypes_int+dtypes_bool)) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 1f95f648..6dacb506 100644 --- 
a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 import numpy as np import unittest, os diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 87471048..2aaf55dd 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 import unittest from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import Opt, OptOps diff --git a/test/test_nn.py b/test/test_nn.py index 2ca44742..8a1b8593 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4,7 +4,7 @@ import numpy as np from tinygrad.helpers import CI from tinygrad.jit import TinyJit from tinygrad.tensor import Tensor, Device -from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm +from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm import torch import pytest diff --git a/test/test_ops.py b/test/test_ops.py index 8bd36c2a..feca869f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 import torch import time import math diff --git a/test/test_optim.py b/test/test_optim.py index df1e53f3..5bab58e0 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -63,8 +63,10 @@ class TestOptim(unittest.TestCase): def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0) def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4) - def test_multistep_sgd_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0) - def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4) + def test_multistep_sgd_nesterov_momentum_wd(self): + self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0) + def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): + self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4) def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0) def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4) diff --git a/test/test_randomness.py b/test/test_randomness.py index d9b5a907..f7954745 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 import math import unittest import numpy as np diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index fe6333e0..c9a14230 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -106,7 +106,7 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args): desc = "faster" if et_torch > et_tinygrad else "slower" flops = save_ops*1e-6 mem = save_mem*1e-6 - print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") + print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} 
GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501 np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3) def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x): diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 99f87822..02e129c7 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -136,11 +136,11 @@ class TestSymbolicReshape(unittest.TestCase): def test_symbolic_mask(self): # taken from gpt2 single kvcache # these two caused problems in gpt2 if reshape merged views - view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False) + view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False) # noqa: E501 new_shape = (1, 1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64) assert view.reshape(new_shape) is None - view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False) + view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False) # noqa: E501 new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64) assert view.reshape(new_shape) is None diff --git a/test/test_uops.py b/test/test_uops.py index 06bf73f5..1e0d275f 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 from typing import Optional, Tuple, Any, List import unittest, math import numpy as np diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 7f9739b0..a3fa2300 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# ruff: noqa: E501 import unittest import numpy as np from tinygrad.helpers import prod, DEBUG diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index ca0248f9..f0b61207 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -11,7 +11,8 @@ from dataclasses import dataclass from enum import Enum, auto class OptOps(Enum): - UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702 + UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto() # noqa: E702 + GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702 def __lt__(self, x:OptOps): return self.value < x.value @dataclass(frozen=True, order=True) @@ -29,19 +30,19 @@ class TensorCore: dtype_out: DType threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure upcast_dim: int # which TC dim to upcast - thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), 
upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim + thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501 thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim arch: Optional[str] = None def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>" tensor_cores: Dict[str, List[TensorCore]] = { "METAL": [ - TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), - TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), + TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 + TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 ], "HIP": [ - TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), - TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), + TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 + TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 ] } @@ -65,7 +66,7 @@ class LinearizerOptions(NamedTuple): class Kernel: def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None): - self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions()) + self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions()) # noqa: E501 self.ast = ast assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}" @@ -133,8 +134,8 @@ class Kernel: def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)] 
# TODO: these need more tests or it might silently be no-op - def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] - def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] + def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501 + def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501 def upcasted_axis(self, i:int): return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:], @@ -149,11 +150,11 @@ class Kernel: return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])] def get_upcast_dim(self, i:int) -> List[int]: - should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) + should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) # noqa: E501 return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] @property - def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) + def first_reduce(self) -> int:return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501 @property def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape @@ -168,7 +169,8 @@ class Kernel: def shape_len(self) -> int: return len(self.sts[0].shape) @property - def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] + def upcast_in_mid_reduce_axes(self) -> List[int]: + return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] @property def global_dims(self) -> int: return self.first_reduce-self.local_dims @@ -189,7 +191,7 @@ class Kernel: # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan) colors += ["cyan"] * self.local_dims # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green) - colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] + colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] # noqa: E501 # between first_reduce + group_for_reduce and upcasted, they are reduce (red) colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce))) # upcasted dimensions are reduce (magenta) or normal (yellow) @@ -198,7 +200,7 @@ class Kernel: return colors def colored_shape(self, 
pad:Optional[int]=None, dense=False) -> str: - ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) + ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) # noqa: E501 if pad: ret += ' '*(pad-ansilen(ret)) return ret @@ -267,7 +269,7 @@ class Kernel: can_merge = [] for j in range(len(shapes)): # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case - can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) + can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501 # more can merge than this mergeable = all(can_merge) and i != self.first_reduce for j in range(len(shapes)): @@ -299,8 +301,9 @@ class Kernel: if global_dims > 0: if global_max: tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else []) - if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None) - assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" + if max(global_max) < max(self.full_shape[:global_dims]): + self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None) + assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501 for i in range(global_dims-1): if i < len(global_max) and self.full_shape[i] > global_max[i]: order = list(range(len(self.full_shape))) @@ -333,7 +336,7 @@ class Kernel: if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue has_cast = tc.dtype_in != tc.dtype_out - if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue + if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue # noqa: E501 mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0] if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue @@ -341,10 +344,10 @@ class Kernel: if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.LOAD and mul_op.src[1].arg.dtype == tc.dtype_in): continue buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg)) buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() - axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] - axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] + axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) 
if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501 + axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501 - if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue + if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue # noqa: E501 if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) @@ -388,17 +391,17 @@ class Kernel: break # alias buffer - alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) + alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501 self.alias_buffer(buf0, alias_pattern) self.alias_buffer(buf1, alias_pattern) return True return False def apply_opt(self, opt:Opt): - assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" + assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501 self.applied_opts.append(opt) if opt.axis is not None: - axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) + axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) # noqa: E501 else: axis = -1 if opt.amt is not None: @@ -435,7 +438,7 @@ class Kernel: self.shift_to(axis, amt, insert_before=None) self.upcast() elif opt.op == OptOps.UPCASTMID: # white - assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" + assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" # noqa: E501 axes = self.sts[0].unit_stride_axes() assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" assert axes[0] == axis, "wrong axis" @@ -465,7 +468,7 @@ class Kernel: if self.bufs[0].dtype.__class__ is ImageDType: unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0] assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}" - if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: + if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501 self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) def hand_coded_optimizations(self): @@ -482,10 +485,10 @@ class Kernel: buf0_strides = 
self.sts[buf0].real_strides() buf1_strides = self.sts[buf1].real_strides() def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st)) - if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): + if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): # noqa: E501 for global_idx in range(self.global_dims): if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: - if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") + if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") # noqa: E501 if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) if MV_BLOCKSIZE > 1: @@ -496,7 +499,7 @@ class Kernel: if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]): # are we grouping? (requires local shape support) - if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: + if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501 # TODO: use 1024 if it's allowed in a smarter way for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts): @@ -504,7 +507,7 @@ class Kernel: break # are we upcasting in mid reduce? 
(only for images) - if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: + if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501 axes = self.sts[0].unit_stride_axes() assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" if self.sts[0].shape[axes[0]]%4 == 0: @@ -515,7 +518,7 @@ class Kernel: unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0] if buf.dtype.__class__ is ImageDType: #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" - if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: + if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501 if unit_stride_axes_mul_4[0] < self.first_reduce: self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) else: @@ -548,9 +551,9 @@ class Kernel: while prod(self.sts[0].shape[:self.first_reduce]) >= 1024: xb_choices = [] for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce - # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already - if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): - xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) + # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already + if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501 + xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501 if xb_choices: xb_choices = sorted(xb_choices) if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}") @@ -559,8 +562,8 @@ class Kernel: else: break - # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS - if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): + # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. 
+ if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501 if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) # if it's small, upcast a second reduce dimension too @@ -585,11 +588,11 @@ class Kernel: self.apply_opt(Opt(OptOps.NOLOCALS)) else: # prioritize making expand axes local - local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] + local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501 to_local: List[Tuple[int, int]] = [] for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): local_size = prod(sz for _, sz in to_local) - local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) + local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501 if local_sz is not None: to_local.append((axis, local_sz)) deleted_shape = 0 for axis, local_sz in sorted(to_local[:3]): diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 93821568..5e8acb12 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -25,10 +25,11 @@ class UOp: dtype: Optional[DType] vin: Tuple[UOp, ...] 
arg: Any - def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" + def __repr__(self): + return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0): - local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] + local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] # noqa: E501 if maxdim != 0 and len(local_dims) > maxdim: dd = local_idxs[maxdim-1] nli = [] @@ -44,7 +45,9 @@ class Linearizer(Kernel): return self.uop(UOps.ALU, dtype, (a, render_b), op) # NOTE: the consts have to be cached for deduping of downstream uops to work - def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before) + def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: + return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before) + def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val def get_reduce_acc(self, op, dtype:DType): @@ -56,8 +59,10 @@ class Linearizer(Kernel): DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV), ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD), LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool), - SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)), - AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } + SumNode: lambda self,ops,ctx: + functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)), + AndNode: lambda self,ops,ctx: + functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]: buf = self.bufs[i] @@ -87,7 +92,7 @@ class Linearizer(Kernel): invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0 for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)): this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid) - key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" + key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 if key not in self.load_cache: if acc is not None: self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False) @@ -95,12 
+100,12 @@ class Linearizer(Kernel): self.load_cache[key] = self.const(this_const, localtype) if valid.min == 0 and valid.max == 1: valid_rendered = valid.render(self.render_ops, self) - self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) + self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501 elif isinstance(buf.dtype, ImageDType): buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) - rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (image_idx[0].render(self.render_ops, self), image_idx[1].render(self.render_ops, self))) + rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (image_idx[0].render(self.render_ops, self), image_idx[1].render(self.render_ops, self))) # noqa: E501 valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, dtypes.float32.vec(4))) if valid.min == 0 else tuple() self.load_cache[key] = self.uop(UOps.LOAD, dtypes.float32.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) idx_small = idx%4 @@ -179,7 +184,7 @@ class Linearizer(Kernel): # add global buffers for i,buf in enumerate(self.bufs): if isinstance(buf, MemBuffer): - self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) + self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) # noqa: E501 # add var vals for var in vars_from_ast(self.ast): assert var.expr is not None @@ -190,7 +195,7 @@ class Linearizer(Kernel): # add a local buffer for multistage reduce. 
# TODO: use local alias if self.group_for_reduce: # TODO: the strides of this can be controlled - self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) + self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501 self.bufs.append(LocalBuffer("temp", self.sts[-1].size())) self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size()))) @@ -204,7 +209,7 @@ class Linearizer(Kernel): # define indexes global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0) - local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0) + local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0) # noqa: E501 full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]] upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]] @@ -212,7 +217,7 @@ class Linearizer(Kernel): def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]: new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, ( self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self), - self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} + self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501 self.loop_uops.update(new_loops) return tuple(new_loops.values()) @@ -221,11 +226,11 @@ class Linearizer(Kernel): self.local_size: Optional[List[int]] = None if self.dont_use_locals: self.global_size = [x.max+1 for x in loop_global_idxs][::-1] - self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) + self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501 elif self.opts.has_local: self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1] - self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) - self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) + self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in 
enumerate(loop_global_idxs)}) # noqa: E501 + self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501 else: render_loop(loop_global_idxs+loop_local_idxs) @@ -238,7 +243,7 @@ class Linearizer(Kernel): fake_reduce_idxs: List[Variable] = [] if self.reduceop is not None: # define indexes - reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] + reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] # noqa: E501 fake_reduce_idxs = [x*0 for x in reduce_idxs] # define accumulator @@ -275,7 +280,7 @@ class Linearizer(Kernel): buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())] if self.tensor_core: min_alias_idx = min(self.local_alias.keys()) - replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) + replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) # noqa: E501 for n in range(len(self.tensor_core.threads)): buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)): @@ -294,14 +299,14 @@ class Linearizer(Kernel): for y in range(by): for x in range(bx): for j in range(acc_reds): - op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] + op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] # noqa: E501 if self.opts.device != "HIP": ops = tuple(op1+op2+op3) else: ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)), self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)), self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3))) - ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) + ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) # noqa: E501 for z in range(cast(DType, ret.dtype).sz): acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx) i += wmma_sz[2] @@ -312,7 +317,7 @@ class Linearizer(Kernel): self.uop(UOps.BARRIER, None, (), cachable=False) # load earlybufs - loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) + loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) # noqa: E501 # run early AST (with reduce) self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) 
@@ -332,7 +337,7 @@ class Linearizer(Kernel): barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False) # create new late reduce local loops and replace local_idxs that have been used - end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] + end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] # noqa: E501 local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:] # if any group_for_reduce items aren't reduces, upcast them here @@ -348,7 +353,7 @@ class Linearizer(Kernel): # NOTE: this structure is the same as the reduce op above # define late accumulator - acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) + acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # noqa: E501 # late reduce loop loop_ctx = render_loop(end_local_idxs) @@ -357,13 +362,13 @@ class Linearizer(Kernel): loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) # there's no AST here (and there's no shape for the reduce LazyOp) - self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) + self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501 # end the late reduce loop self.load_cache.clear() # load latebufs - loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) + loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # noqa: E501 # run late AST (without the store) val = self.ast_parse(cast(LazyOp, self.ast.src[0]), acc, None, loaded_buffers) @@ -440,7 +445,7 @@ class Linearizer(Kernel): for u in self.uops: if u.uop == UOps.LOOP: # add END of loops after the last thing that (recursively) depends on them - self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) + self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501 elif u.uop == UOps.IF: # END any if statements at the end of the uops self.uop(UOps.END, None, (u,), cachable=False) @@ -448,7 +453,7 @@ class Linearizer(Kernel): # maybe graph the uops if DEBUG >= 5: for u in self.uops: - print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") + print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501 if getenv("GRAPHUOPS"): from tinygrad.graph import 
graph_uops graph_uops(self.uops) @@ -460,7 +465,7 @@ class Linearizer(Kernel): self.applied_opts_cache = self.applied_opts[:] return self - def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: + def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: # noqa: E501 key = (uop, dtype, vin, arg) if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:] if uop == UOps.ALU: # upcast vins to the same dtype @@ -471,10 +476,12 @@ class Linearizer(Kernel): if simplify: if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before) - if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before) + if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): + return self.const(vin[0].arg, dtype, insert_before) if uop == UOps.ALU: # rewrites. NOTE: the rewritten NEG op is still around... - if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before) + if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: + return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before) # constant folding if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before) if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop @@ -487,7 +494,8 @@ class Linearizer(Kernel): if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0] # When insert_before is set, need to check if the cached expr is valid with the given insert place. 
- if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and (insert_before is None or self.uops.index(expr) <= insert_before): return expr + if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and (insert_before is None or self.uops.index(expr) <= insert_before): + return expr ret = UOp(uop, dtype, vin, arg) if insert_before is not None: self.uops.insert(insert_before, ret) @@ -496,16 +504,16 @@ class Linearizer(Kernel): if cachable: self.saved_exprs[key] = ret return ret - def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]: + def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]: # noqa: E501 if x.op in BufferOps: return loaded_buffers[x.arg] - if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)] + if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)] # noqa: E501 if x.op in ReduceOps and not do_reduce: assert offs is None, "not available if we aren't doing reduce" return acc # MULACC fusion. TODO: this is copied from Interpreted if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg) - if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: + if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: # noqa: E501 x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg) values = [self.ast_parse(cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src] ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC} diff --git a/tinygrad/device.py b/tinygrad/device.py index 2f93adfa..215466f5 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -3,7 +3,8 @@ import numpy as np from collections import defaultdict from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable import importlib, inspect, functools, pathlib, time, re, ctypes -from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv, ImageDType +from tinygrad.helpers import DType, dtypes, ImageDType +from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv from tinygrad.shape.symbolic import Variable, sym_infer, sint from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast @@ -14,13 +15,13 @@ if TYPE_CHECKING: # **************** Device **************** class _Device: - def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] - def canonicalize(self, device:Optional[str]) -> str: return 
(device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT + def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501 + def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT # noqa: E501 def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix)) @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]: x = ix.split(":")[0].upper() - ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0] + ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0] # noqa: E501 if isinstance(ret, type): ret = ret(ix) return ret @functools.cached_property @@ -48,7 +49,7 @@ class JITRunner: def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: raise NotImplementedError("override this") -def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False): +def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False): # noqa: E501 if var_vals is None: var_vals = {} op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals) GlobalCounters.kernel_count += num_kernels @@ -57,8 +58,8 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option if et is not None: GlobalCounters.time_sum_s += et if DEBUG >= 2: ptm = (colored(f"{et*1e3:9.2f}ms", "RED") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else "" - print(f"{colored(f'** {device[:7]:7} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + - (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) + print(f"{colored(f'** {device[:7]:7} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 + (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501 # **************** Buffer / Allocator **************** @@ -85,13 +86,14 @@ class Buffer: def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, 
dtypes.from_np(x.dtype)).copyin(x.data) def toCPU(self) -> np.ndarray: # zero copy with as_buffer - if hasattr(self.allocator, 'as_buffer'): return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore + if hasattr(self.allocator, 'as_buffer'): + return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore ret = np.empty(self.size, self.dtype.np) if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf) return ret def _internal_buffer_copy(dest, src): - if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): + if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721 # fast path, used on HIP between GPUs # NOTE: it's important we use the dest device here to ensure the transfer is ready Device[src.device].synchronize() # TODO: async this @@ -124,7 +126,7 @@ class _BufferCopy(JITRunner): if wait or DEBUG >= 2: Device[dest.device].synchronize() et = time.perf_counter() - st - update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:7} <- {src.device[:7]:7}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, lra={"global_size": dest.size}, device=dest.device) + update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:7} <- {src.device[:7]:7}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, device=dest.device) # noqa: E501 BufferCopy = _BufferCopy() # TODO: size, dest, src are the same type. can we enforce this? @@ -232,7 +234,7 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret # **************** for Compiled Devices **************** class CompiledASTRunner(JITRunner): - def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): + def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): # noqa: E501 super().__init__() if DEBUG >= 4: print(prg) if global_size is not None: global_size = global_size + [1]*(3-len(global_size)) @@ -267,13 +269,13 @@ class CompiledASTRunner(JITRunner): if global_size: lra['global_size'] = global_size if local_size and 'local_size' not in lra: lra['local_size'] = local_size et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2) - update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run) + update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run) # noqa: E501 self.first_run = False return et class Compiled: def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None): - self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph + self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph # noqa: E501 def synchronize(self): pass # override this in your device def to_program(self, k:Linearizer) -> CompiledASTRunner: diff --git 
a/tinygrad/features/graph/cuda.py b/tinygrad/features/graph/cuda.py index 9832a41a..0af6af01 100644 --- a/tinygrad/features/graph/cuda.py +++ b/tinygrad/features/graph/cuda.py @@ -5,7 +5,7 @@ from tinygrad.helpers import init_c_var, encode_args_cuda_style from tinygrad.device import CompiledASTRunner, update_stats, Buffer from tinygrad.runtime.ops_cuda import check, cu_time_execution from tinygrad.shape.symbolic import Variable -from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException +from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException # noqa: E501 class CUDAGraph: def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): @@ -27,7 +27,7 @@ class CUDAGraph: prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None - c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) + c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501 c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config) graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params) @@ -55,7 +55,7 @@ class CUDAGraph: self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params)) et = self.graph_launch(self.instance, None, wait=wait) - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) + update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501 return et def __del__(self): @@ -64,9 +64,13 @@ class CUDAGraph: def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0)) def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) - def graph_instantiate(self, graph): return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0))) - def graph_add_kernel_node(self, graph, c_deps, c_node_params): return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) + def graph_instantiate(self, graph): + return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0))) + def graph_add_kernel_node(self, graph, c_deps, c_node_params): + return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) # noqa: E501 def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait) def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args)) - def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): return 
cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config) - def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size + def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): + return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config) + def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): + node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size diff --git a/tinygrad/features/graph/hip.py b/tinygrad/features/graph/hip.py index 3252ca6e..8d8c9c35 100644 --- a/tinygrad/features/graph/hip.py +++ b/tinygrad/features/graph/hip.py @@ -12,9 +12,13 @@ class HIPGraph(CUDAGraph): def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3)) def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0))) - def graph_instantiate(self, graph): return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0))) - def graph_add_kernel_node(self, graph, c_deps, c_params): return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) + def graph_instantiate(self, graph): + return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0))) + def graph_add_kernel_node(self, graph, c_deps, c_params): + return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) # noqa: E501 def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait) def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args)) - def build_kernel_node_params(self, prg, global_size, local_size, c_config): return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0) - def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size + def build_kernel_node_params(self, prg, global_size, local_size, c_config): + return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0) + def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): + node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size diff --git a/tinygrad/features/graph/metal.py b/tinygrad/features/graph/metal.py index b554c42e..a025af21 100644 --- a/tinygrad/features/graph/metal.py +++ b/tinygrad/features/graph/metal.py @@ -23,7 +23,7 @@ class MetalGraph: icb_descriptor.setInheritBuffers_(False) icb_descriptor.setInheritPipelineState_(False) icb_descriptor.setMaxKernelBufferBindCount_(31) - self.icb = 
self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) + self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) # noqa: E501 if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?") if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize) @@ -33,7 +33,7 @@ class MetalGraph: descriptor = Metal.MTLComputePipelineDescriptor.new() descriptor.setComputeFunction_(prg.clprg.fxn) descriptor.setSupportIndirectCommandBuffers_(True) - pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) + pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501 icb_command = self.icb.indirectComputeCommandAtIndex_(j) icb_command.setComputePipelineState_(pipeline_state) for i,b in enumerate(ji.rawbufs): @@ -59,7 +59,7 @@ class MetalGraph: self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i) for j in self.jc_idx_with_updatable_launch_dims: global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals) - self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) + self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501 if len(var_vals): self.int_buf_view[:] = list(var_vals.values()) command_buffer = self.device.mtl_queue.commandBuffer() encoder = command_buffer.computeCommandEncoder() @@ -74,5 +74,5 @@ class MetalGraph: else: self.device.mtl_buffers_in_flight.append(command_buffer) et = None - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) + update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501 return et \ No newline at end of file diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py index 82d92f44..4d786301 100644 --- a/tinygrad/features/image.py +++ b/tinygrad/features/image.py @@ -7,7 +7,7 @@ def image_dot(self, w): # NOTE: we use a 1x1 conv2d to do the matmul. 
mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) n1, n2 = len(self.shape), len(w.shape) assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" + assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501 bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) cin, cout = w.shape[-2], w.shape[-1] out_shape_t = self.shape[0:-2] + (cout,-1) diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 21da3511..4e973541 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -75,7 +75,7 @@ def compile_linearizer(dev:str, lin:Linearizer, name:Optional[str]=None) -> Tupl src, _ = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping return rdev.compiler(src), lin.global_size, lin.local_size -def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): +def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501 rdev = Device[dev] assert isinstance(rdev, Compiled) clprg = rdev.runtime(name, lib) @@ -124,21 +124,21 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea while not exiting: acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin] timed_lins: List[Tuple[Linearizer, float]] = [] - for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))): + for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501 if proc is None: continue lib, global_size, local_size = proc if lib in seen_libs: continue seen_libs.add(lib) - tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) # > 1 second, run one time + tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) timed_lins.append((acted_lins[i], min(tms))) - if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") + if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501 # done opts = sorted(timed_lins, key=lambda x: x[1]) exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1]) if not exiting: beam = opts[:amt] assert len(beam) > 0, "no BEAM items succeeded?!?" 
- if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) + if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501 if pool is not None: pool.close() # the pool is closed except KeyboardInterrupt as e: if pool is not None: pool.terminate() @@ -155,20 +155,20 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice def try_exec(local_size): try: - return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) + return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501 except Exception: return float('inf') ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))]) assert not math.isinf(ret[0]), "all optimize_local_size exec failed" return ret[1] -def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: - key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} +def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501 + key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501 if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)} lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin) - tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) + tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501 if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms) return min(tms) diff --git a/tinygrad/graph.py b/tinygrad/graph.py index 805c9ec8..956b7b47 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -14,8 +14,8 @@ cnts: Dict[OpType, int] = defaultdict(int) if DEBUG >= 2: def print_globalcounters(): if GlobalCounters.time_sum_s == 0: return - print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", - f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") + print(f"avg: 
{GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501 + f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501 atexit.register(print_globalcounters) if GRAPH: import networkx as nx @@ -52,7 +52,7 @@ def add_st_node(nmx, nmo, label, st:ShapeTracker): inter_node = node_count node_count += 1 offset = st.expr_node(NumNode(0))[0] - G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")) + G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")) # noqa: E501 G.add_edge(nmx, inter_node, color='#00000060') G.add_edge(inter_node, nmo, label=label, color='#00000060') @@ -69,7 +69,8 @@ def log_schedule_item(si: ScheduleItem): cnts[optype] += 1 if GRAPH: assert si.out.base == si.out, "all outputs based" - top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'} + top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", + MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'} # get inputs for shapetrackers input_to_st = defaultdict(list) @@ -89,13 +90,14 @@ def log_schedule_item(si: ScheduleItem): if nm(si.out) not in G.nodes: G.add_node(nm(si.out)) - G.nodes[nm(si.out)]['label'] = '"' + (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps or optype is BufferOps else "")+(f"\n{si.out.device}" if si.out.device != Device.DEFAULT else "") + '"' + G.nodes[nm(si.out)]['label'] = '"' + (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps or optype is BufferOps else "")+(f"\n{si.out.device}" if si.out.device != Device.DEFAULT else "") + '"' # noqa: E501 G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype] G.nodes[nm(si.out)]['color'] = 'black' G.nodes[nm(si.out)]['style'] = 'filled' def _tree(lazydata, prefix=""): - if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ") + if type(lazydata).__name__ == "LazyBuffer": + return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ") if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] childs = [_tree(c) for c in lazydata.src[:]] @@ -112,7 +114,7 @@ def graph_uops(uops:List[UOp]): G = nx.DiGraph() for u in uops: if u.uop == UOps.END: continue - G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) + G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501 for v in u.vin: 
G.add_edge(uops.index(v), uops.index(u)) GRAPHPATH = "/tmp/uops" nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot') diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 349e6456..25241680 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -21,7 +21,7 @@ def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python def all_same(items:List[T]): return all(x == items[0] for x in items) def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t) -def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line +def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s) def ansilen(s:str): return len(ansistrip(s)) def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x @@ -30,7 +30,7 @@ def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm) def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst def round_up(num, amt:int): return (num+amt-1)//amt * amt def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]: - assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" + assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501 return {k:v for d in ds for k,v in d.items()} def partition(lst:List[T], fxn:Callable[[T],bool]): a:List[T] = [] @@ -136,10 +136,10 @@ class PtrDType(DType): def __repr__(self): return f"ptr.{super().__repr__()}" class dtypes: - @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool - def is_int(x: DType)-> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @staticmethod def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64) + @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool + def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x) @staticmethod def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @staticmethod @@ -187,11 +187,13 @@ class dtypes: # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float16, dtypes.bfloat16], - dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, 
dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16], + dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], + dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], } @functools.lru_cache(None) -def _get_recursive_parents(dtype:DType) -> Set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} +def _get_recursive_parents(dtype:DType) -> Set[DType]: + return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} @functools.lru_cache(None) def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) @@ -247,7 +249,7 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any): ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys()) cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))") _db_tables.add(table) - cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) + cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501 conn.commit() cur.close() return val @@ -264,7 +266,7 @@ def diskcache(func): def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: if url.startswith("/") or url.startswith("."): return pathlib.Path(url) - fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) + fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501 if not fp.is_file() or not allow_caching: with request.urlopen(url, timeout=10) as r: assert r.status == 200 @@ -289,14 +291,14 @@ def cpu_time_execution(cb, enable): # TODO: make this work with read only memoryviews (if possible) def from_mv(mv, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type)) -def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) +def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501 @functools.lru_cache(maxsize=None) def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]): class CStruct(ctypes.Structure): _pack_, _fields_ = 1, fields return CStruct def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1] -def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), 
ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] +def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501 def flat_mv(mv:memoryview): if len(mv) == 0: return mv return mv.cast("B", shape=(mv.nbytes,)) @@ -305,10 +307,10 @@ def flat_mv(mv:memoryview): def pretty_ptx(s): # all expressions match `` and replace it with `color()` - s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers + s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501 s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions - s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers + s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501 s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives return s @@ -321,8 +323,8 @@ def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, return get_bytes(prog, get_code_size, get_code, check) def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]: - c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) - return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args + c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501 + return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501 def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]: if not enable: return cb() diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 94a27397..7caf04a0 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -15,7 +15,7 @@ class JitItem: rawbufs: List[Optional[Buffer]] def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]: - return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji 
in jit_cache], NumNode(0)) + return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0)) # noqa: E501 def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]: input_replace: Dict[Tuple[int, int], int] = {} for j,ji in enumerate(jit_cache): @@ -24,7 +24,7 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) - input_replace[(j,i)] = input_rawbuffers.index(a) return input_replace def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]: - return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] + return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] # noqa: E501 def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]: return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars] @@ -49,7 +49,7 @@ class TinyJit(Generic[ReturnType]): def __call__(self, *args, **kwargs) -> ReturnType: # all inputs (except const) are realized - input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} + input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501 expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()]) # get rawbuffers @@ -58,13 +58,13 @@ class TinyJit(Generic[ReturnType]): assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT" # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global - var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) + var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501 expected_vals = tuple(var_vals.keys()) if self.cnt >= 2: # jit exec assert self.expected_vals == expected_vals, "mismatch of var_vals" - assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" + assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" # noqa: E501 for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True) elif self.cnt == 1: @@ -74,20 +74,21 @@ class TinyJit(Generic[ReturnType]): self.ret = self.fxn(*args, **kwargs) self.jit_cache = CacheCollector.finish() assert len(self.jit_cache) != 0, "didn't JIT anything!" 
- assert len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) == len(input_rawbuffers), "some input tensors not found" # Do this check on an unmodified jit cache.
+ assert len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) == len(input_rawbuffers), "some input tensors not found"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# if your Device supports it, condense the items into a graph executor.
if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2:
- # Split JIT cache into batches for faster graph execution. This allows the accelerator to run some batches while subsequent graphs are still being updated.
- # JitItems that cannot be jitted (not CompiledASTRunner) are moved to the final jit cache and do not participate in the process of graph building.
+ # Split JIT cache into batches for faster graph execution.
+ # This allows the accelerator to run some batches while subsequent graphs are still being updated.
graphed_jit_cache, current_batch = [], []
for i,ji in enumerate(self.jit_cache):
# If the jit item can potentially be graphed, put it in a batch.
if isinstance(ji.prg, CompiledASTRunner): current_batch.append(ji)
- # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
+ # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
+ # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
+ if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not isinstance(ji.prg, CompiledASTRunner)): # noqa: E501
try: graphed_jit_cache.append(JitItem(make_graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers)))
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels")
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 44a7ea77..213b18eb 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -1,3 +1,5 @@
+# ruff: noqa: E501
+# TODO: replace with new lazy with <= 150 length lines
from __future__ import annotations
import sys, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 2309d6af..ae937da9 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -126,7 +126,7 @@ class Div(Function):
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
- grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+ grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501
# ************* ternary ops *************
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 79eda968..b5078f15 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -27,7 +27,7 @@ class BatchNorm2d:
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats:
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
- self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() )
+ self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() ) # noqa: E501
self.num_batches_tracked += 1
else:
batch_mean = self.running_mean
@@ -52,7 +52,8 @@ class Conv2d:
def __call__(self, x:Tensor):
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
- def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
+ def initialize_weight(self, out_channels, in_channels, groups):
+ return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
@@ -63,9 +64,10 @@ class ConvTranspose2d(Conv2d):
self.output_padding = output_padding
def __call__(self, x:Tensor):
- return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
+ return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups) # noqa: E501
- def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
+ def initialize_weight(self, out_channels, in_channels, groups):
+ return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
class Linear:
def __init__(self, in_features, out_features, bias=True):
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index bc97ea83..018cb8fc 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -5,7 +5,8 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap
from tinygrad.shape.view import strides_for_shape
-safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64, "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong}
+safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64,
+ "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
@@ -15,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
t, json_len, metadata = safe_load_metadata(fn)
- return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
+ return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"} # noqa: E501
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
headers, offset = {}, 0
@@ -49,9 +50,10 @@ def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values(
def load_state_dict(model, state_dict, strict=True, verbose=True):
start_mem_used = GlobalCounters.mem_used
- with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"):
+ with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
model_state_dict = get_state_dict(model)
- if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
+ if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
+ print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
if k not in state_dict and not strict:
@@ -90,8 +92,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
def __setstate__(self, state): self.tensor = state[0]
deserialized_objects: Dict[str, Any] = {}
- intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter":
Parameter} + intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, + "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter} whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed class Dummy: pass class TorchPickle(pickle.Unpickler): @@ -123,7 +125,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]: f = unwrap(tar.extractfile('tensors')) for _ in range(TorchPickle(f).load()): # num_tensors (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack(' LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg) + def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: + return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg) def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()] def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer': @@ -80,7 +81,8 @@ class LazyOp: def shrink(self, _): raise NotImplementedError def stride(self, _): raise NotImplementedError -def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr)) +def vars_from_ast(ast:LazyOp) -> List[Variable]: + return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr)) # **************** independent FlopCounter **************** @@ -97,12 +99,14 @@ class FlopCounter: return ret InterpretedFlopCounter: Dict[Op, Callable] = { - BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}), - BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops - **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, - **{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, + BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), + BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}), + BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501 + UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops + **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501 + **{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501 **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, 
self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, - TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, max(y.dtype, z.dtype), self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} + TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, max(y.dtype, z.dtype), self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501 @functools.lru_cache(None) def get_lazyop_info(ast:LazyOp) -> FlopCounter: diff --git a/tinygrad/realize.py b/tinygrad/realize.py index 3fdb5da1..0eba097e 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -12,7 +12,7 @@ class CustomOp(JITRunner): def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs) def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]: - assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" + assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" # noqa: E501 if si.ast.op is LoadOps.EMPTY: return None if si.ast.op is LoadOps.COPY: return BufferCopy if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 7f787638..3e7939ef 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -62,7 +62,7 @@ class CStyleLanguage(NamedTuple): if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16: return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" if output_dtype.sz > 1: - out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" + out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" # noqa: E501 else: out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" @@ -81,7 +81,7 @@ class CStyleLanguage(NamedTuple): return f"({cond})?({x}):{y}" def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str: - tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" + tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501 buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else self.arg_int_prefix if dtype == dtypes._arg_int32 else ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)] @@ -99,7 +99,7 @@ class CStyleLanguage(NamedTuple): if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16: return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});" if var_dtype.sz > 1: - return f"*(({self.smem_prefix if local and 
self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" + return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" # noqa: E501 return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: @@ -149,7 +149,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}") elif args[0] == "HIP": assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8" - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501 else: raise NotImplementedError(f"WMMA not implemented for {args}") elif uop == UOps.ALU: @@ -160,7 +160,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu else: val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype]) assert child_count[u] != 0, f"childless ALU op found {u}" - if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): # fix index rendering issue. fix clang nested max macro issue + # TODO: fix index rendering issue. fix clang nested max macro issue + if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};") @@ -272,7 +273,8 @@ class HIPLanguage(CStyleLanguage): __device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); } __device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); } __device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); } - typedef float float8 __attribute__((ext_vector_type(8))); __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; } + typedef float float8 __attribute__((ext_vector_type(8))); + __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; } extern "C" __global__ """ launch_bounds = True @@ -284,15 +286,22 @@ class HIPLanguage(CStyleLanguage): uses_ptr_arithmetic=True arg_int_prefix = "const int" half_prekernel = "#include <hip/hip_fp16.h>\n" + """ -typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4; __device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; } -typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; } - typedef _Float16 half16 __attribute__((ext_vector_type(16))); __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d, half e, half f, half g, half h, half i, half 
j, half k, half l) { return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; } +typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4; +__device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; } +typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; } + typedef _Float16 half16 __attribute__((ext_vector_type(16))); +__device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d, + half e, half f, half g, half h, half i, half j, half k, half l) { + return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; } __device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); } __device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); } -__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); } +__device__ float4 vload_half4(size_t offset, const half *p) { + return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); } __device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; } __device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; } -__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; } +__device__ void vstore_half4(float4 data, size_t offset, half *p) { + *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; } __device__ half exp2(half x) { return hexp2(x); } __device__ half log2(half x) { return hlog2(x); } __device__ half sin(half x) { return hsin(x); } @@ -317,7 +326,8 @@ __device__ bool operator<(const unsigned short &a, const half &b) { return __hlt gid = [f'blockIdx.{chr(120+i)}' for i in range(3)] lid = [f'threadIdx.{chr(120+i)}' for i in range(3)] xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)] - code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"} + code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", + TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"} HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage()) # TODO: how much of this can be merged with above? 
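The hunks above follow the same conventions applied throughout this patch: a long line is either wrapped so every physical line fits the new 150-character budget, kept intact with a per-line "# noqa: E501" suppression, or (as for tinygrad/lazy.py earlier in the patch) left alone under a file-level "# ruff: noqa: E501" directive until the module is rewritten. A minimal illustrative sketch of the three options, not part of the patch (the names dtype_names and priorities are hypothetical):
# illustrative sketch only, not part of the patch: handling ruff's E501 at a 150-character limit
# option 1: wrap the expression so each physical line stays under the limit
dtype_names = {"F16": "float16", "F32": "float32", "U8": "uint8", "I8": "int8",
               "I32": "int32", "I64": "int64", "F64": "double", "B": "bool"}
# option 2: keep one long line and suppress E501 for that line only
priorities = {name: idx for idx, name in enumerate(["float16", "float32", "uint8", "int8", "int32", "int64", "double", "bool", "short", "ushort", "uint", "ulong"])}  # noqa: E501
# option 3: exempt an entire module by placing "# ruff: noqa: E501" at the top of the file, as done for tinygrad/lazy.py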
@@ -328,7 +338,8 @@ class WGSLLanguage(CStyleLanguage): barrier="workgroupBarrier();" generic_var_prefix = "var " external_local_bufs = True - code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" } + code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", + TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" } type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"} def render_local(self, name: str, size: int): @@ -343,8 +354,8 @@ class WGSLLanguage(CStyleLanguage): local_size = local_size[::-1] if local_size else [1] bind_it = iter(range(len(bufs))) prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n" - prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs]) - prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" + prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs]) # noqa: E501 + prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501 return prg def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str: diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 04278d06..1a523db0 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -4,31 +4,37 @@ from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.helpers import DType, dtypes from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps -LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf +MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype) code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS), - UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - UnaryOps.SQRT: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - BinaryOps.ADD: 
lambda builder, x, y, var_dtype: builder.or_(x, y) if var_dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(var_dtype) else builder.fadd(x, y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.SUB: lambda builder, x, y, var_dtype: builder.sub(x, y) if dtypes.is_int(var_dtype) else builder.fsub(x, y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.MUL: lambda builder, x, y, var_dtype: builder.mul(x, y) if is_bool_or_unsigned(var_dtype) or dtypes.is_int(var_dtype) else builder.fmul(x, y, flags=LLVM_FAST_MATH_FLAGS), # TOOD should we use umul_with_overflow? - BinaryOps.DIV: lambda builder, x, y, var_dtype: builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y), - BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), + UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501 + UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS), + UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS), + UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS), + UnaryOps.SQRT: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS), + BinaryOps.ADD: lambda builder, x, y, var_dtype: builder.or_(x, y) if var_dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(var_dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.SUB: lambda builder, x, y, var_dtype: builder.sub(x, y) if dtypes.is_int(var_dtype) else builder.fsub(x, y, flags=MFLAGS), + BinaryOps.MUL: lambda builder, x, y, var_dtype: # TOOD should we use umul_with_overflow? 
+ builder.mul(x, y) if is_bool_or_unsigned(var_dtype) or dtypes.is_int(var_dtype) else builder.fmul(x, y, flags=MFLAGS), + BinaryOps.DIV: lambda builder, x, y, var_dtype: + builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=MFLAGS), + BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 + BinaryOps.MOD: lambda builder, x, y, var_dtype: + builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y), - TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS), - TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z) + TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS), + TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=MFLAGS), y, z) # noqa: E501 } -dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)} +dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), + dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), + dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), + dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)} def cast(bb, val, input_type, output_type, bitcast=False): if input_type == output_type: return val @@ -65,7 +71,8 @@ def cast(bb, val, input_type, output_type, bitcast=False): raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented") -def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args) +def const(args, dtype): + return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args) def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: # all llvm stuff goes into a module @@ -77,7 +84,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: # create llvm function 
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()] - func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) + func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) # noqa: E501 for a in func.args: if a.type.is_pointer: a.add_attribute("noalias") diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 1540aac6..62d3a797 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -5,13 +5,13 @@ from tinygrad.helpers import diskcache, cpu_time_execution from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage -CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n' +CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n' # noqa: E501 @diskcache def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes: # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here with tempfile.NamedTemporaryFile(delete=True) as output_file: - subprocess.check_output(args=('clang -shared -march=native -O2 -Wall -Werror -x c -fPIC - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8')) + subprocess.check_output(args=('clang -shared -march=native -O2 -Wall -Werror -x c -fPIC - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8')) # noqa: E501 return pathlib.Path(output_file.name).read_bytes() class ClangProgram: diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 63d1d7ba..6e74a127 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -27,13 +27,15 @@ def einsum_mulacc(einsum, get_strides, expand): numpy_fxn_for_op: Dict[Op, Callable] = { BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, - UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x), + UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), + UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x), BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) +def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) # noqa: E501 @diskcache -def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check) +def 
compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check) # noqa: E501 class CUDAProgram: def __init__(self, device:CUDADevice, name:str, lib:bytes): @@ -46,7 +46,7 @@ class CUDAProgram: def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False): if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context)) c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+tuple(vals)) - return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait) + return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait) # noqa: E501 class CUDAAllocator(LRUAllocator): def __init__(self, device:CUDADevice): diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 8fbb4a9e..2ab18fac 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -13,7 +13,8 @@ class UnderlyingDiskBuffer: if self.fd: self.fd.close() class DiskBuffer: - def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0): self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset + def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0): + self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset def __repr__(self): return f"<DiskBuffer size={self.size} dtype={self.dtype} offset={self.offset}>" def cast(self, arg:Tuple[DType, bool]): return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset) def as_strided(self, arg): diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 1f94565c..fa72f2cf 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -7,7 +7,8 @@ from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import OpenCLRenderer from tinygrad.device import Compiled, LRUAllocator -OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something +# see test/external/external_osx_profiling.py to determine this ratio. 
it's in like GPU clocks or something +OSX_TIMING_RATIO = (125/3) if OSX else 1.0 def check(status): if status != 0: raise RuntimeError(f"OpenCL Error {status}") @@ -16,22 +17,24 @@ def checked(ret, status): return (check(status.value), ret)[1] @diskcache def compile_cl(prg:str) -> bytes: assert CLDevice.compiler_context is not None, 'OpenCL requires a "compiler_context" to compile, init a device before you call this' - program = checked(cl.clCreateProgramWithSource(CLDevice.compiler_context.context, 1, to_char_p_p([prg_bytes := prg.encode()]), ctypes.byref(ctypes.c_size_t(len(prg_bytes))), ctypes.byref(status := ctypes.c_int32())), status) + program = checked(cl.clCreateProgramWithSource(CLDevice.compiler_context.context, 1, to_char_p_p([prg_bytes := prg.encode()]), + ctypes.byref(ctypes.c_size_t(len(prg_bytes))), ctypes.byref(status := ctypes.c_int32())), status) status = cl.clBuildProgram(program, 1, ctypes.byref(CLDevice.compiler_context.device_id), None, cl.clBuildProgram.argtypes[4](), None) if status != 0: - cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, ctypes.byref(log_size := ctypes.c_size_t())) - cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) + cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, ctypes.byref(log_size := ctypes.c_size_t())) # noqa: E501 + cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) # noqa: E501 raise RuntimeError(f"OpenCL Compile Error\n\n{ctypes.string_at(mstr, size=log_size.value).decode()}") - binary_sizes = init_c_var((ctypes.c_size_t * 1)(), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(x), ctypes.byref(x), None))) - binary = init_c_var(ctypes.create_string_buffer(binary_sizes[0]), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), ctypes.byref((ctypes.c_void_p * 1)(ctypes.addressof(x))), None))) + binary_sizes = init_c_var((ctypes.c_size_t * 1)(), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(x), ctypes.byref(x), None))) # noqa: E501 + binary = init_c_var(ctypes.create_string_buffer(binary_sizes[0]), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), ctypes.byref((ctypes.c_void_p * 1)(ctypes.addressof(x))), None))) # noqa: E501 check(cl.clReleaseProgram(program)) return bytes(binary) class CLProgram: def __init__(self, device:CLDevice, name:str, lib:bytes): self.device, self.name, self.lib = device, name, lib - self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(lib)), to_char_p_p([lib], ctypes.c_ubyte), - ctypes.byref(binary_status := ctypes.c_int32()), ctypes.byref(errcode_ret := ctypes.c_int32())), errcode_ret) + self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(lib)), + to_char_p_p([lib], ctypes.c_ubyte), ctypes.byref(binary_status := ctypes.c_int32()), + ctypes.byref(errcode_ret := ctypes.c_int32())), errcode_ret) check(binary_status.value) check(cl.clBuildProgram(self.program, 1, ctypes.byref(device.device_id), None, cl.clBuildProgram.argtypes[4](), None)) # NOTE: 
OSX requires this self.kernel = checked(cl.clCreateKernel(self.program, name.encode(), ctypes.byref(status := ctypes.c_int32())), status) @@ -40,16 +43,16 @@ class CLProgram: check(cl.clReleaseKernel(self.kernel)) check(cl.clReleaseProgram(self.program)) - def __call__(self, *bufs:cl.cl_mem, global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: + def __call__(self, *bufs:cl.cl_mem, global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501 for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)) for i,b in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(b))) if local_size is not None: global_size = tuple(int(g*l) for g,l in zip(global_size, local_size)) event = cl.cl_event() if wait else None - check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event)) + check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event)) # noqa: E501 if wait: check(cl.clWaitForEvents(1, ctypes.byref(event))) - start = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_START, ctypes.sizeof(x), ctypes.byref(x), None))) - end = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_END, ctypes.sizeof(x), ctypes.byref(x), None))) + start = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_START, ctypes.sizeof(x), ctypes.byref(x), None))) # noqa: E501 + end = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_END, ctypes.sizeof(x), ctypes.byref(x), None))) # noqa: E501 return float(end.value-start.value) * OSX_TIMING_RATIO * 1e-9 return None @@ -83,10 +86,10 @@ class CLDevice(Compiled): err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, ctypes.byref(num_devices)) if err == 0 and num_devices.value != 0: break if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices") - CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) + CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) # noqa: E501 self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])] - self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status) + self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status) # noqa: E501 if CLDevice.compiler_context is None: CLDevice.compiler_context = self self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status) self.pending_copyin: List[memoryview] = 
[] diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 54e401c3..2624d3e0 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -12,10 +12,11 @@ MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they comp def check(status): if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}") -def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) +# TODO: remove these helpers, they increase complexity +def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501 @diskcache -def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) +def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) # noqa: E501 class HIPProgram: def __init__(self, device:int, name:str, lib:bytes): @@ -36,7 +37,7 @@ class HIPProgram: def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False): if MOCKHIP: return float("inf") check(hip.hipSetDevice(self.device)) - return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait) + return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait) # noqa: E501 T = TypeVar("T") class HIPAllocator(LRUAllocator): @@ -63,10 +64,11 @@ class HIPDevice(Compiled): default_arch_name = "gfx1100" def __init__(self, device:str=""): self.device = int(device.split(":")[1]) if ":" in device else 0 - if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() + if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() # noqa: E501 from tinygrad.features.graph.hip import HIPGraph - super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph) + super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, + compile_hip, functools.partial(HIPProgram, self.device), HIPGraph) def synchronize(self): check(hip.hipSetDevice(self.device)) check(hip.hipDeviceSynchronize()) \ No newline at end of file diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 9f924026..4ea70905 100644 --- 
a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -63,4 +63,5 @@ class LLVMProgram: self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn) return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait) -LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram) +LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), + uops_to_llvm_ir, compile_llvm, LLVMProgram) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 1f832964..85a78775 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -32,7 +32,7 @@ class MetalProgram: self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None)) def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False): - assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" + assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(),f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" # noqa: E501 command_buffer = self.device.mtl_queue.commandBuffer() encoder = command_buffer.computeCommandEncoder() encoder.setComputePipelineState_(self.pipeline_state) @@ -81,7 +81,8 @@ class MetalDevice(Compiled): self.mtl_buffers_in_flight: List[Any] = [] self.mv_in_metal: List[memoryview] = [] from tinygrad.features.graph.metal import MetalGraph - super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self)) + super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer, + compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self)) def synchronize(self): for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted() self.mv_in_metal.clear() diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index a055dbd2..c8e4925f 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -7,7 +7,9 @@ from tinygrad.helpers import getenv, dtypes from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")) -type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16} +type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, + torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, + torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16} inverse_type_map = {v:k for k,v in type_map.items()} def output_type(x, y): return x.dtype if 
type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype @@ -27,7 +29,8 @@ torch_fxn_for_op: Dict[Op, Callable] = { #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]), BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).to(device), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin, - UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x), + UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), + UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x), BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x wgpu.GPUBuffer: return buf class WebGPUProgram: - def __init__(self, name:str, lib:bytes): self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib) # NOTE: this is the compiler + def __init__(self, name:str, lib:bytes): + self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib) # NOTE: this is the compiler def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): assert len(bufs) <= 8, "WEBGPU only supports 8 buffers" - binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))] - bindings = [{"binding": i, "resource": {"buffer": create_uniform(x) if i >= len(bufs) else x, "offset": 0, "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)] + binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))] # noqa: E501 + bindings = [{"binding": i, "resource": {"buffer": create_uniform(x) if i >= len(bufs) else x, "offset": 0, "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)] # noqa: E501 bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts) pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout]) bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings) @@ -29,11 +30,12 @@ class WebGPUProgram: wgpu_device.queue.submit([command_encoder.finish()]) class WebGpuAllocator(Allocator): - def _alloc(self, size: int): return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC) + def _alloc(self, size: int): + return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC) def copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src) def copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0) # TODO: remove this copy class WebGpuDevice(Compiled): def __init__(self, device:str): - super().__init__(WebGpuAllocator(), LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), - WGSLRenderer, lambda x: x, WebGPUProgram) + super().__init__(WebGpuAllocator(), LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], + 
global_max=[65535, 65535, 65535]), WGSLRenderer, lambda x: x, WebGPUProgram) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 9f19eaba..0670e165 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -32,7 +32,7 @@ def expr_node(view:View, idx:Optional[Node]=None) -> Node: # generate an expression if you have a variable or expression for each index def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node: assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}" - return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) + return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501 @functools.lru_cache(maxsize=None) def merge_views(vm2:View, vm1:View) -> Optional[View]: @@ -52,7 +52,8 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node: @dataclass(frozen=True) class ShapeTracker: views: Tuple[View, ...] - def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views" + def __post_init__(self): + assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views" @staticmethod def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),)) diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 3099339a..1797b929 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -185,7 +185,8 @@ class LtNode(OpNode): if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1) return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) class MulNode(OpNode): def __lt__(self, b: Union[Node, int]): @@ -201,7 +202,8 @@ class MulNode(OpNode): a = (self.a * (self.b%b)) return Node.__mod__(a, b) def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) class DivNode(OpNode): def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div @@ -219,7 +221,7 @@ class ModNode(OpNode): return Node.__floordiv__(self, b, factoring_allowed) def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 and isinstance(self.b, int) - return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) + return (0, 
self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501 def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b class RedNode(Node): @@ -335,9 +337,9 @@ sint = Union[Node, int] VariableOrNum = Union[Variable, NumNode] render_python: Dict[Type, Callable] = { - Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"), + Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"), # noqa: E501 NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}", - MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", + MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", # noqa: E501 DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})", ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})", LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})", diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index b7c79702..4f3d55b0 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -16,7 +16,7 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]: return filter_strides(shape, tuple(reversed(strides))) @functools.lru_cache(maxsize=None) -def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: +def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: # noqa: E501 # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...] 
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index b7c79702..4f3d55b0 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -16,7 +16,7 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
   return filter_strides(shape, tuple(reversed(strides)))

 @functools.lru_cache(maxsize=None)
-def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]:
+def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: # noqa: E501
   # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
   if not shape: return tuple()
   assert len(shape) == len(strides)
@@ -95,7 +95,7 @@ class View:
     new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
     new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
     new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
-    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None # noqa: E501
     return View.create(new_shape, new_strides, new_offset, new_mask)

   # MovementOps live here now
@@ -139,7 +139,7 @@ class View:
   def permute(self, axis: Tuple[int, ...]) -> View:
     assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
     assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
-    return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None)
+    return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None) # noqa: E501

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def stride(self, mul: Tuple[int, ...]) -> View:
@@ -148,7 +148,7 @@ class View:
     strides = tuple([z*m for z,m in zip(self.strides, mul)])
     new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
     offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
+    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None # noqa: E501
     return View.create(new_shape, strides, self.offset + offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -162,7 +162,8 @@ class View:
     # check for the same size
     if all_int(self.shape):
       assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
+        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
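For the view.py hunk above, the permute/stride arithmetic can be checked in isolation. A minimal sketch over bare (shape, strides, offset) tuples under a row-major assumption; this is not the View class itself.

def permute_view(shape, strides, axis):
  # same reindexing as View.permute, on plain tuples
  assert sorted(axis) == list(range(len(shape)))
  return tuple(shape[a] for a in axis), tuple(strides[a] for a in axis)

def stride_view(shape, strides, offset, mul):
  # same arithmetic as View.stride: a negative mul flips a dimension,
  # so the offset moves to that dimension's last element
  new_strides = tuple(z*m for z, m in zip(strides, mul))
  new_shape = tuple((s + abs(m) - 1)//abs(m) for s, m in zip(shape, mul))
  offset += sum((s-1)*z for s, z, m in zip(shape, strides, mul) if m < 0)
  return new_shape, new_strides, offset

# a contiguous (2, 3) view has row-major strides (3, 1); transposing just permutes them
assert permute_view((2, 3), (3, 1), (1, 0)) == ((3, 2), (1, 3))
# flipping the last axis of (2, 3): stride becomes -1 and the offset points at the old column 2
assert stride_view((2, 3), (3, 1), 0, (1, -1)) == ((2, 3), (3, -1), 2)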
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index a94a67e0..97f3b45d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -46,7 +46,7 @@ class Tensor:
   no_grad: ClassVar[bool] = False
   default_type: ClassVar[DType] = dtypes.float32

-  def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): # noqa: E501
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = Device.canonicalize(device)
     # tensors have gradients, buffers do not
@@ -116,7 +116,7 @@ class Tensor:
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert not x.requires_grad # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
-    if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
+    if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized # noqa: E501
     self.lazydata = x.lazydata
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
@@ -132,7 +132,7 @@ class Tensor:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
     if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
-    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
+    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape) # noqa: E501

   def to(self, device:Optional[str]) -> Tensor:
     if device is None or device == self.device: return self
@@ -151,7 +151,7 @@ class Tensor:
   @staticmethod
   def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
     assert isinstance(sz, int), f"cannot create with symbolic size {sz}"
-    return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
+    return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) # noqa: E501

   @staticmethod
   def empty(*shape, **kwargs):
@@ -168,7 +168,8 @@ class Tensor:
   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
+  def full(shape:Tuple[sint, ...], fill_value, **kwargs):
+    return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

   @staticmethod
   def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@@ -182,9 +183,11 @@ class Tensor:
     return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)

   @staticmethod
-  def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
+  def eye(dim:int, **kwargs):
+    return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)

-  def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
+  def full_like(self, fill_value, **kwargs):
+    return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
   def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
   def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
@@ -271,11 +274,14 @@ class Tensor:
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
-    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)]))
-  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
+    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) # noqa: E501
+  def expand(self, shape, *args) -> Tensor:
+    return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
+    if not any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)): return self
+    return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
   def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
     if all(x is None or x == (0,0) for x in arg): return self
     ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
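The reshape/expand/shrink hunk above mostly re-wraps long comprehensions; the -1/None resolution they implement looks like this on plain tuples (an illustrative sketch, not tinygrad code).

from math import prod

def resolve_reshape(shape, new_shape):
  # -1 is inferred from the remaining elements, None keeps the old dimension (mirrors the hunk's list comp)
  return tuple(-prod(shape) // prod(new_shape) if s == -1 else (s if s is not None else shape[i])
               for i, s in enumerate(new_shape))

def resolve_expand(shape, new_shape):
  # in expand, -1 means "keep this dimension as is"
  return tuple(x if x != -1 else s for s, x in zip(shape, new_shape))

assert resolve_reshape((2, 3, 4), (6, -1)) == (6, 4)
assert resolve_reshape((2, 3, 4), (None, 12)) == (2, 12)
assert resolve_expand((1, 3, 1), (4, -1, 5)) == (4, 3, 5)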
@@ -314,8 +320,9 @@ class Tensor:
     # 1. indices normalization and validation
     # treat internal tuples and lists as Tensors and standardize indices to list type
     if isinstance(indices, (tuple, list)):
-      if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)] # special case , a lil ugly
-      else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
+      # special case , a lil ugly
+      if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)] # noqa: E501
+      else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices] # noqa: E501
     else: indices = [indices]

     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
@@ -341,10 +348,12 @@ class Tensor:
     if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
     if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
     for dim in type_dim[int]:
-      if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]: raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")
+      if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]:
+        raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")

     # normalize! indices -> start, stop, strides
-    start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type]
+    # TODO: this line is completely unreadable
+    start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type] # noqa: E501

     # 2. basic indexing (no copy)
     # apply slices and flip where strides are negative
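The start/stop/strides line that the hunk itself flags as unreadable is, per dimension, just Python's slice.indices plus an int-to-slice conversion. A standalone sketch, where dim_size stands in for the sh of the hunk's zip:

def normalize_index(i, dim_size):
  if isinstance(i, slice): return i.indices(dim_size)      # slice -> concrete (start, stop, step)
  if isinstance(i, int):                                    # int -> a one-element slice
    normalized = i if i != -1 else dim_size - 1
    return slice(normalized, normalized + 1, 1).indices(dim_size)
  return (0, dim_size, 1)                                   # anything else (e.g. a Tensor index) keeps the whole dim here

assert normalize_index(slice(None, None, 2), 10) == (0, 10, 2)
assert normalize_index(slice(-3, None), 10) == (7, 10, 1)
assert normalize_index(-1, 10) == (9, 10, 1)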
@@ -380,7 +389,7 @@ class Tensor:
       # compute sum_dim, arange, and idx
       max_dim = max(i.ndim for i in idx)
       sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
-      arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))]
+      arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
       first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
       rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
       idx = first_idx + rest_idx
@@ -410,7 +419,7 @@ class Tensor:
     idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
-    return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
+    return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) # noqa: E501

   def cat(self, *args:Tensor, dim:int=0) -> Tensor:
     dim = (dim + len(self.shape)) if dim < 0 else dim
@@ -446,7 +455,7 @@ class Tensor:
   def squeeze(self, dim:Optional[int]=None) -> Tensor:
     if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
     if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
-    if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
+    if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})") # noqa: E501
     if dim < 0: dim += self.ndim
     return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
@@ -473,7 +482,8 @@ class Tensor:
     axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
     axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
     shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
-    if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
+    if 0 in self.shape and 0 not in shape:
+      return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)
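The _reduce hunk above only splits one long line; the shape bookkeeping it wraps is: normalize negative axes, collapse the reduced dimensions to 1, then drop them unless keepdim. In plain Python:

def reduce_shapes(shape, axis=None, keepdim=False):
  axis_ = list(range(len(shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
  axis_ = [x if x >= 0 else x + len(shape) for x in axis_]              # normalize negative axes
  inner = tuple(1 if i in axis_ else s for i, s in enumerate(shape))    # shape the reduce op produces
  out = inner if keepdim else tuple(s for i, s in enumerate(shape) if i not in axis_)
  return inner, out

assert reduce_shapes((2, 3, 4), axis=1) == ((2, 1, 4), (2, 4))
assert reduce_shapes((2, 3, 4), axis=-1, keepdim=True) == ((2, 3, 1), (2, 3, 1))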
@@ -504,11 +514,11 @@ class Tensor:
   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
+      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape) # noqa: E501
       return prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
-    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
+    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) # noqa: E501
     return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1

   def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@@ -547,7 +557,7 @@ class Tensor:
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
       e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
-      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
+      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)]) # noqa: E501
       # slide by dilation
       xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
       xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
@@ -565,8 +575,8 @@ class Tensor:
     return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])

   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
-  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
+  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501
+  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501

   def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
     HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
@@ -577,62 +587,70 @@ class Tensor:
     x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
     x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
     x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
-    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
+    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) # noqa: E501
     return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)

   wino = int(getenv("WINO", "0"))
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
-    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
-    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
-    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1])
+    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
+    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
+    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501

     # conv2d is a pooling op (with padding)
     x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
     rcout, oyx = cout//groups, x.shape[2:-len(HW)]
     if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
       # normal conv
-      x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
+      x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501

       # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
-      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
+      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) # noqa: E501
       return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))

     # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
-    def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat])
+    def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
     HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
     winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
     winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
-    winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time
+    winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time

     # todo: stride == dilation
     # use padding to round up to 4x4 output tiles
-    d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI)
-    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
+    # (bs, cin_, tyx, HWI)
+    d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
+    # move HW to the front: # (HWI, bs, cin_, tyx)
+    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward()
     tyx = d.shape[-len(HWI):] # dim of tiling

     g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front

     # compute 6x6 winograd tiles: GgGt, BtdB
-    gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
-    dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
+    # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
+    gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
+    # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
+    dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)

-    ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
+    # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
+    ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW)))

-    ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
-    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
+    # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
+    ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
+    # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
+    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))

     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()

   def dot(self, w:Tensor) -> Tensor:
     n1, n2 = len(self.shape), len(w.shape)
     assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
+    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
     return (x*w).sum(-1)

-  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
+    return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
   def cumsum(self, axis:int=0) -> Tensor:
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum
@@ -646,7 +664,8 @@ class Tensor:
     return fix(ret) + fix(base_add)

   @staticmethod
-  def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
+  def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor:
+    return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor:
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
@@ -711,7 +730,7 @@ class Tensor:
     x: Tensor = self
     if not isinstance(y, Tensor):
       if 0 in x.shape: return x, x.full_like(y)
-      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
+      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32) # noqa: E501
     if reverse: x, y = y, x
     if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
@@ -742,7 +761,7 @@ class Tensor:
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
   def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     x = self._to_float(x)
-    return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
+    return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     x = self._to_float(x)
     if x.__class__ is not Tensor and not reverse:
@@ -761,7 +780,7 @@ class Tensor:
     # we need 0 to be positive so we need to correct base_sign when the base is 0
     base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
     # inject nan if the base is negative and the power is not an integer
-    to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
+    to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501
     inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
     return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
@@ -833,7 +852,7 @@ class Tensor:
     mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
     return self * mask * (1/(1.0 - p))

-  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
     # NOTE: it works if key, value have symbolic shape
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
@@ -849,7 +868,7 @@ class Tensor:
   def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
     # NOTE: self is a logits input
     loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
+    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501
     y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     return self.log_softmax().mul(y).sum() / loss_mask.sum()

@@ -857,7 +876,8 @@ class Tensor:
   def cast(self, dtype:DType) -> Tensor:
     # hack for devices that don't support bfloat16
-    if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).cast(dtype)
+    if self.dtype == dtypes.bfloat16:
+      return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).cast(dtype)
     return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self

   def bitcast(self, dtype:DType) -> Tensor:
     assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
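The cast hunk above re-wraps the bfloat16 workaround (bitcast to uint16, widen to uint32, shift left by 16, bitcast to float32). The bit trick itself can be sanity-checked with numpy; numpy has no bfloat16, so this sketch starts from the raw uint16 payload of a bfloat16 value.

import numpy as np

bf16_bits = np.array([0x3FC0], dtype=np.uint16)             # bfloat16 bit pattern for 1.5
f32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # widen, shift into float32's high bits, reinterpret
assert f32[0] == np.float32(1.5)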