mirror of https://github.com/commaai/tinygrad.git
ruff checks the max line length is 150 (#2734)
* ruff checks the max line length is 150
* fix tensor.py
* a lot more
* done
This commit is contained in:
parent 3635540ddb
commit 6d6eb9302d
@@ -248,7 +248,7 @@ indent-after-paren=4
 indent-string='  '
 
 # Maximum number of characters on a single line.
-max-line-length=100
+max-line-length=150
 
 # Maximum number of lines in a module
 max-module-lines=1000
@@ -12,6 +12,7 @@ select = [
   "E272",
 # "E303",
 # "E304",
+  "E501",
 # "E502",
   "E702",
   "E703",
@@ -22,6 +23,8 @@ select = [
   "UP039", # unnecessary-class-parentheses
 ]
 
+line-length = 150
+
 exclude = [
   "disassemblers/",
   "docs/",
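Every hunk below applies one of two fixes to satisfy the new 150-character limit: either a long statement is wrapped across physical lines, or it is kept intact and given a trailing "# noqa: E501" so ruff skips the length check for that one line. A minimal sketch of both styles, assuming ruff runs with line-length = 150 (make_adaptor and its parameters are hypothetical stand-ins, not a tinygrad API):

    # Option 1: wrap the call so every physical line stays under the limit.
    adaptor = make_adaptor(gen=args.gen, size=args.size, quantize=args.quantize,
                           checkpoint=args.weights, tokenizer=args.tokenizer)
    # Option 2: keep one physical line and suppress E501 for it alone (make_adaptor is hypothetical).
    adaptor = make_adaptor(gen=args.gen, size=args.size, quantize=args.quantize, checkpoint=args.weights, tokenizer=args.tokenizer)  # noqa: E501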
setup.py (3 changed lines)
@@ -14,7 +14,8 @@ setup(name='tinygrad',
       license='MIT',
       long_description=long_description,
       long_description_content_type='text/markdown',
-      packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.runtime', 'tinygrad.shape', 'tinygrad.features', 'tinygrad.features.graph'],
+      packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer',
+                  'tinygrad.runtime', 'tinygrad.shape', 'tinygrad.features', 'tinygrad.features.graph'],
       classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
sz.py (4 changed lines)
@@ -38,7 +38,7 @@ def gen_diff(table_old, table_new):
     file_stat_old = [stats for stats in table_old if file in stats]
     file_stat_new = [stats for stats in table_new if file in stats]
     if file_stat_new[0][1]-file_stat_old[0][1] != 0 or file_stat_new[0][2]-file_stat_old[0][2] != 0:
-      table.append([file_stat_new[0][0], file_stat_new[0][1], file_stat_new[0][1]-file_stat_old[0][1], file_stat_new[0][2], file_stat_new[0][2]-file_stat_old[0][2]])
+      table.append([file_stat_new[0][0], file_stat_new[0][1], file_stat_new[0][1]-file_stat_old[0][1], file_stat_new[0][2], file_stat_new[0][2]-file_stat_old[0][2]]) # noqa: E501
   return table
 
 def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff)
@@ -58,7 +58,7 @@ if __name__ == "__main__":
   if len(sys.argv) == 3:
     print("### Changes")
     print("```")
-    print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", intfmt=(..., "d", "+d"), floatfmt=(..., ..., ..., ".1f", "+.1f"))+"\n")
+    print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", intfmt=(..., "d", "+d"), floatfmt=(..., ..., ..., ".1f", "+.1f"))+"\n") # noqa: E501
     print(f"\ntotal lines changes: {display_diff(sum([x[2] for x in table]))}")
     print("```")
   else:
@@ -97,6 +97,7 @@ if __name__ == '__main__':
   args = parser.parse_args()
 
   # run eval and exit
-  adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize, checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu")
+  adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize,
+                         checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu")
   results = evaluator.evaluate(adaptor, tasks.get_task_dict(args.eval.split(",")), False, 0, args.limit)
   print(json.dumps(results, indent=2))
@@ -52,7 +52,7 @@ def benchmark_model(m, devices, validate_outs=False):
   onnx_model = onnx.load(fn)
   output_names = [out.name for out in onnx_model.graph.output]
   excluded = {inp.name for inp in onnx_model.graph.initializer}
-  input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded}
+  input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501
   input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
   #input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
   np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
@@ -113,7 +113,9 @@ class TestOptBinOp(unittest.TestCase):
 
   def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1))
   def test_no_binop_rerun_alt(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(16, 16, 1), lambda a,b: a*b)
-  def test_no_binop_rerun_reduce_broadcast(self): return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2)
+  def test_no_binop_rerun_reduce_broadcast(self):
+    return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2)
 
   @unittest.skip("this test started failing with the new change, based movementop issue")
   def test_no_binop_rerun_transposed(self): return self._test_no_binop_rerun(lambda a,b: (a.T*b.T).T, lambda a,b: a*b)
   def test_no_binop_rerun_mid_reshape(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(256)+a.reshape(256))
@@ -35,7 +35,7 @@ class TestYOLOv8(unittest.TestCase):
   def test_forward_pass_torch_onnx(self):
     variant = 'n'
     weights_location = fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors')
-    weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension
+    weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension # noqa: E501
     weights_location_onnx = weights_location_pt.parent / f"yolov8{variant}.onnx"
 
     # the ultralytics export prints a lot of unneccesary things
@@ -48,7 +48,7 @@ class TestYOLOv8(unittest.TestCase):
     state_dict = safe_load(weights_location)
     load_state_dict(TinyYolov8, state_dict)
 
-    image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)]
+    image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)] # noqa: E501
     orig_image = [cv2.imdecode(image_location[0], 1)]
 
     input_image = preprocess(orig_image)
@@ -20,14 +20,14 @@ class TextModelExport(unittest.TestCase):
     prg, inp_sizes, _, _ = export_model(model, "", *inputs)
     prg = json.loads(prg)
 
-    assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}"
+    assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}" # noqa: E501
 
     for i in range(len(inputs)):
       assert f"input{i}" in inp_sizes, f"input{i} not captured in inp_sizes"
       assert f"input{i}" in prg["buffers"], f"input{i} not captured in exported buffers"
 
     for i, exported_input in enumerate(prg["inputs"]):
-      assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}"
+      assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}" # noqa: E501
 
   def test_multi_output_model_export(self):
     model = MockMultiOutputModel()
@@ -36,14 +36,14 @@ class TextModelExport(unittest.TestCase):
     prg, _, out_sizes, _ = export_model(model, "", input)
     prg = json.loads(prg)
 
-    assert len(outputs) == len(prg["outputs"]) == len(out_sizes), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, inp_sizes={len(out_sizes)}"
+    assert len(outputs) == len(prg["outputs"]) == len(out_sizes), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, inp_sizes={len(out_sizes)}" # noqa: E501
 
     for i in range(len(outputs)):
       assert f"output{i}" in out_sizes, f"output{i} not captured in out_sizes"
       assert f"output{i}" in prg["buffers"], f"output{i} not captured in exported buffers"
 
     for i, exported_output in enumerate(prg["outputs"]):
-      assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}"
+      assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" # noqa: E501
 
 
 if __name__ == '__main__':
@@ -100,7 +100,8 @@ class TestLrScheduler(unittest.TestCase):
     without = lr_scheduler_training()
     sched_fns = [MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR]
     argss = [{'milestones': [5, 7, 10, 15], 'gamma': 0.5}, {'factor': 0.5, 'patience': 2}, {'T_max': 25, 'eta_min': 0.001},
-             {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'max_lr':1e-5, 'total_steps': 25}]
+             {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0,
+              'max_lr':1e-5, 'total_steps': 25}]
     for sched_fn, args in zip(sched_fns, argss):
       with_sched = lr_scheduler_training(sched_fn, args)
       assert with_sched > without
@@ -24,7 +24,7 @@ def consec(shape, start=1):
 def set_(reference: Tensor, shape, strides, offset):
   if reference.lazydata.base.realized is None: reference.realize()
   assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
-  strided = Tensor(LazyBuffer(device=reference.device, st=ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)), optype=None, op=None, dtype=reference.dtype, src=None, base=reference.lazydata.base))
+  strided = Tensor(LazyBuffer(device=reference.device, st=ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)), optype=None, op=None, dtype=reference.dtype, src=None, base=reference.lazydata.base)) # noqa: E501
   assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
   assert strided.lazydata in reference.lazydata.base.views, "base.views should contain strided.lazydata"
   return strided
@@ -48,7 +48,8 @@ class TestOnnxModel(unittest.TestCase):
     mt2 = time.monotonic()
     tinygrad_out = tinygrad_out.numpy()
     et = time.monotonic()
-    if not CI: print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
+    if not CI:
+      print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
 
     if not CI:
       import cProfile
@@ -100,7 +101,7 @@ class TestOnnxModel(unittest.TestCase):
 
   def test_efficientnet(self):
     input_name, input_new = "images:0", True
-    self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new)
+    self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new) # noqa: E501
 
   def test_shufflenet(self):
     input_name, input_new = "gpu_0/data_0", False
@@ -32,7 +32,7 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jit
   assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
   assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
   if all_jitted:
-    assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"
+    assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
 
 class TestRealWorld(unittest.TestCase):
   def setUp(self):
@@ -13,7 +13,7 @@ TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
 TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length."
 # TODO this file will possibly not survive long. find another 1-2 minute sound file online to transcribe
 TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
-TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time."
+TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501
 
 @unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. GPU reequires cl_khr_fp16")
 class TestWhisper(unittest.TestCase):
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 import unittest
 import numpy as np
 from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
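Several test files in this commit, starting with the hunk above, get a file-level pragma instead: a "# ruff: noqa: E501" line at the top of a file disables the E501 check for the whole file, whereas the trailing "# noqa: E501" used elsewhere suppresses it for a single line. A minimal sketch of the file-level form (the module contents are hypothetical):

    # ruff: noqa: E501
    # from here on, no line in this module is checked against line-length
    import unittest
    EXPECTED = "a long expected-output literal that may exceed 150 characters without triggering E501 anywhere in this file"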
@@ -15,7 +15,8 @@ dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint
 dtypes_bool = (dtypes.bool,)
 binary_operations = [operator.add, operator.sub, operator.mul]
 integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor)]
-unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin), (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
+unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
+                    (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
 
 # TODO: enable this (this is a dtype issue)
 #binary_operations.append(operator.truediv)
@@ -57,7 +58,7 @@ def universal_test_unary(a, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
   tensor_value = op[0](Tensor([a], dtype=dtype)).numpy()
   numpy_value = op[1](np.array([a]).astype(dtype.np))
-  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise)
+  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=5 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3, rtol=2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2) # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise) # noqa: E501
   else: np.testing.assert_equal(tensor_value, numpy_value)
 
 def universal_test_cast(a, in_dtype, dtype):
@@ -130,7 +131,7 @@ class TestDTypeALU(unittest.TestCase):
   def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
 
   # Metal and CUDACPU behave differently than numpy in CI for overflows
-  @given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations))
+  @given(st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, st.floats(width=32, min_value=0, max_value=10.0) if CI and (Device.DEFAULT == "METAL" or getenv("CUDACPU")) else ht.float32, ht.int32, st.sampled_from(binary_operations), st.sampled_from(integer_binary_operations)) # noqa: E501
   def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)
 
   @given(ht.float32, st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 import numpy as np
 import unittest, os
 
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 import unittest
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.search import Opt, OptOps
@@ -4,7 +4,7 @@ import numpy as np
 from tinygrad.helpers import CI
 from tinygrad.jit import TinyJit
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
+from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
 import torch
 import pytest
 
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 import torch
 import time
 import math
@@ -63,8 +63,10 @@ class TestOptim(unittest.TestCase):
 
   def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
   def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
-  def test_multistep_sgd_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
-  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
+  def test_multistep_sgd_nesterov_momentum_wd(self):
+    self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
+  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
+    self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
 
   def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
   def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 import math
 import unittest
 import numpy as np
@@ -106,7 +106,7 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):
   desc = "faster" if et_torch > et_tinygrad else "slower"
   flops = save_ops*1e-6
   mem = save_mem*1e-6
-  print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB")
+  print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
   np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3)
 
 def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
@@ -136,11 +136,11 @@ class TestSymbolicReshape(unittest.TestCase):
   def test_symbolic_mask(self):
     # taken from gpt2 single kvcache
     # these two caused problems in gpt2 if reshape merged views
-    view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False)
+    view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
     new_shape = (1, 1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64)
     assert view.reshape(new_shape) is None
 
-    view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False)
+    view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
     new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
     assert view.reshape(new_shape) is None
 
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 from typing import Optional, Tuple, Any, List
 import unittest, math
 import numpy as np
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+# ruff: noqa: E501
 import unittest
 import numpy as np
 from tinygrad.helpers import prod, DEBUG
@@ -11,7 +11,8 @@ from dataclasses import dataclass
 from enum import Enum, auto
 
 class OptOps(Enum):
-  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
+  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
 @dataclass(frozen=True, order=True)
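The "# noqa: E702" markers above suppress a different rule: E702 flags multiple statements on one line separated by semicolons, which the chained auto() assignments would trip regardless of line length. The hunk only re-wraps the enum onto two lines to satisfy the new 150-character limit, keeping the E702 suppressions in place. A standalone sketch of the pattern (Color is a hypothetical enum, not from tinygrad):

    from enum import Enum, auto

    class Color(Enum):
      RED = auto(); GREEN = auto(); BLUE = auto()  # noqa: E702  # semicolon-chained on purpose; E702 suppressed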
@@ -29,19 +30,19 @@ class TensorCore:
   dtype_out: DType
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
   upcast_dim: int # which TC dim to upcast
-  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim
+  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
   thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
   arch: Optional[str] = None
   def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
 
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [
-    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
-    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"),
+    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
+    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
   ],
   "HIP": [
-    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
-    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]),
+    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
   ]
 }
 
@@ -65,7 +66,7 @@ class LinearizerOptions(NamedTuple):
 
 class Kernel:
   def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
-    self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions())
+    self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions()) # noqa: E501
     self.ast = ast
     assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"
 
@@ -133,8 +134,8 @@ class Kernel:
   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
 
   # TODO: these need more tests or it might silently be no-op
-  def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
-  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
+  def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
+  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
 
   def upcasted_axis(self, i:int):
     return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
@@ -149,11 +150,11 @@ class Kernel:
     return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
 
   def get_upcast_dim(self, i:int) -> List[int]:
-    should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
+    should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) # noqa: E501
     return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
 
   @property
-  def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
+  def first_reduce(self) -> int:return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501
 
   @property
   def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
@@ -168,7 +169,8 @@ class Kernel:
   def shape_len(self) -> int: return len(self.sts[0].shape)
 
   @property
-  def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
+  def upcast_in_mid_reduce_axes(self) -> List[int]:
+    return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
 
   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims
@@ -189,7 +191,7 @@ class Kernel:
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
     # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
-    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
+    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] # noqa: E501
     # between first_reduce + group_for_reduce and upcasted, they are reduce (red)
     colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
@@ -198,7 +200,7 @@ class Kernel:
     return colors
 
   def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
-    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors()))
+    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) # noqa: E501
     if pad: ret += ' '*(pad-ansilen(ret))
     return ret
 
@@ -267,7 +269,7 @@ class Kernel:
       can_merge = []
       for j in range(len(shapes)):
         # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-        can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
+        can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501
       # more can merge than this
       mergeable = all(can_merge) and i != self.first_reduce
       for j in range(len(shapes)):
@@ -299,8 +301,9 @@ class Kernel:
     if global_dims > 0:
       if global_max:
         tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
-        if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
-        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
+        if max(global_max) < max(self.full_shape[:global_dims]):
+          self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
+        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
       for i in range(global_dims-1):
         if i < len(global_max) and self.full_shape[i] > global_max[i]:
           order = list(range(len(self.full_shape)))
@@ -333,7 +336,7 @@ class Kernel:
       if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue
       has_cast = tc.dtype_in != tc.dtype_out
 
-      if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
+      if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue # noqa: E501
       mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
 
       if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
@@ -341,10 +344,10 @@ class Kernel:
       if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.LOAD and mul_op.src[1].arg.dtype == tc.dtype_in): continue
       buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
       buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
-      axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
-      axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0]
+      axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
+      axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
 
-      if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue
+      if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue # noqa: E501
 
       if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
 
@@ -388,17 +391,17 @@ class Kernel:
           break
 
       # alias buffer
-      alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
+      alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
       self.alias_buffer(buf0, alias_pattern)
       self.alias_buffer(buf1, alias_pattern)
       return True
     return False
 
   def apply_opt(self, opt:Opt):
-    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
+    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
     self.applied_opts.append(opt)
     if opt.axis is not None:
-      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
+      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) # noqa: E501
     else:
      axis = -1
    if opt.amt is not None:
@@ -435,7 +438,7 @@ class Kernel:
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op == OptOps.UPCASTMID: # white
-      assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"
+      assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" # noqa: E501
       axes = self.sts[0].unit_stride_axes()
       assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
       assert axes[0] == axis, "wrong axis"
@@ -465,7 +468,7 @@ class Kernel:
     if self.bufs[0].dtype.__class__ is ImageDType:
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
       assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
-      if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
+      if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
         self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
 
   def hand_coded_optimizations(self):
@@ -482,10 +485,10 @@ class Kernel:
       buf0_strides = self.sts[buf0].real_strides()
       buf1_strides = self.sts[buf1].real_strides()
       def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st))
-      if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)):
+      if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): # noqa: E501
         for global_idx in range(self.global_dims):
           if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
-            if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
+            if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") # noqa: E501
             if MV_THREADS_PER_ROW > 1:
               self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
             if MV_BLOCKSIZE > 1:
@@ -496,7 +499,7 @@ class Kernel:
 
     if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
       # are we grouping? (requires local shape support)
-      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
+      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
         # TODO: use 1024 if it's allowed in a smarter way
         for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
           if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
@@ -504,7 +507,7 @@ class Kernel:
             break
 
       # are we upcasting in mid reduce? (only for images)
-      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
+      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
         axes = self.sts[0].unit_stride_axes()
         assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
         if self.sts[0].shape[axes[0]]%4 == 0:
@@ -515,7 +518,7 @@ class Kernel:
       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
       if buf.dtype.__class__ is ImageDType:
         #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
+        if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
@@ -548,9 +551,9 @@ class Kernel:
     while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
       xb_choices = []
       for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
-        # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
-        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
-          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
+        # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
+        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
+          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
       if xb_choices:
         xb_choices = sorted(xb_choices)
         if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
@@ -559,8 +562,8 @@ class Kernel:
       else:
         break
 
-    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64):
+    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
+    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
       if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
         self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
       # if it's small, upcast a second reduce dimension too
@@ -585,11 +588,11 @@ class Kernel:
       self.apply_opt(Opt(OptOps.NOLOCALS))
     else:
       # prioritize making expand axes local
-      local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
+      local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
       to_local: List[Tuple[int, int]] = []
       for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
         local_size = prod(sz for _, sz in to_local)
-        local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
+        local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
         if local_sz is not None: to_local.append((axis, local_sz))
       deleted_shape = 0
       for axis, local_sz in sorted(to_local[:3]):
|
||||
|
|
|
@ -25,10 +25,11 @@ class UOp:
|
|||
dtype: Optional[DType]
|
||||
vin: Tuple[UOp, ...]
|
||||
arg: Any
|
||||
def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
|
||||
def __repr__(self):
|
||||
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
|
||||
|
||||
def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
|
||||
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
|
||||
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] # noqa: E501
|
||||
if maxdim != 0 and len(local_dims) > maxdim:
|
||||
dd = local_idxs[maxdim-1]
|
||||
nli = []
|
||||
|
@ -44,7 +45,9 @@ class Linearizer(Kernel):
|
|||
return self.uop(UOps.ALU, dtype, (a, render_b), op)
|
||||
|
||||
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
||||
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
|
||||
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
|
||||
return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
|
||||
|
||||
def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
|
||||
|
||||
def get_reduce_acc(self, op, dtype:DType):
|
||||
|
@ -56,8 +59,10 @@ class Linearizer(Kernel):
|
|||
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
|
||||
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
|
||||
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
SumNode: lambda self,ops,ctx:
|
||||
functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx:
|
||||
functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
|
||||
buf = self.bufs[i]
|
||||
|
@ -87,7 +92,7 @@ class Linearizer(Kernel):
|
|||
invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
|
||||
for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)):
|
||||
this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
|
||||
|
@ -95,12 +100,12 @@ class Linearizer(Kernel):
|
|||
self.load_cache[key] = self.const(this_const, localtype)
|
||||
if valid.min == 0 and valid.max == 1:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)
|
||||
self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (image_idx[0].render(self.render_ops, self), image_idx[1].render(self.render_ops, self)))
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (image_idx[0].render(self.render_ops, self), image_idx[1].render(self.render_ops, self))) # noqa: E501
|
||||
valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, dtypes.float32.vec(4))) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, dtypes.float32.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
idx_small = idx%4
|
||||
|
@ -179,7 +184,7 @@ class Linearizer(Kernel):
|
|||
# add global buffers
|
||||
for i,buf in enumerate(self.bufs):
|
||||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) # noqa: E501
|
||||
# add var vals
|
||||
for var in vars_from_ast(self.ast):
|
||||
assert var.expr is not None
|
||||
|
@ -190,7 +195,7 @@ class Linearizer(Kernel):
|
|||
# add a local buffer for multistage reduce. # TODO: use local alias
|
||||
if self.group_for_reduce:
|
||||
# TODO: the strides of this can be controlled
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
|
||||
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
|
||||
|
||||
|
@ -204,7 +209,7 @@ class Linearizer(Kernel):
|
|||
|
||||
# define indexes
|
||||
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0)
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0) # noqa: E501
|
||||
full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
|
||||
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
||||
|
||||
|
@ -212,7 +217,7 @@ class Linearizer(Kernel):
|
|||
def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
|
||||
new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
||||
self.loop_uops.update(new_loops)
|
||||
return tuple(new_loops.values())
|
||||
|
||||
|
@@ -221,11 +226,11 @@ class Linearizer(Kernel):
self.local_size: Optional[List[int]] = None
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
elif self.opts.has_local:
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
else:
render_loop(loop_global_idxs+loop_local_idxs)
@@ -238,7 +243,7 @@ class Linearizer(Kernel):
fake_reduce_idxs: List[Variable] = []
if self.reduceop is not None:
# define indexes
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] # noqa: E501
fake_reduce_idxs = [x*0 for x in reduce_idxs]

# define accumulator
@@ -275,7 +280,7 @@ class Linearizer(Kernel):
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
if self.tensor_core:
min_alias_idx = min(self.local_alias.keys())
replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx])
replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) # noqa: E501
for n in range(len(self.tensor_core.threads)):
buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
@@ -294,14 +299,14 @@ class Linearizer(Kernel):
for y in range(by):
for x in range(bx):
for j in range(acc_reds):
op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]]
op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] # noqa: E501
if self.opts.device != "HIP":
ops = tuple(op1+op2+op3)
else:
ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)),
self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)),
self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) # noqa: E501
for z in range(cast(DType, ret.dtype).sz):
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
i += wmma_sz[2]
@@ -312,7 +317,7 @@ class Linearizer(Kernel):
self.uop(UOps.BARRIER, None, (), cachable=False)

# load earlybufs
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) # noqa: E501

# run early AST (with reduce)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
@@ -332,7 +337,7 @@ class Linearizer(Kernel):
barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)

# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] # noqa: E501
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]

# if any group_for_reduce items aren't reduces, upcast them here
@@ -348,7 +353,7 @@ class Linearizer(Kernel):
# NOTE: this structure is the same as the reduce op above

# define late accumulator
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype))
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # noqa: E501

# late reduce loop
loop_ctx = render_loop(end_local_idxs)
@@ -357,13 +362,13 @@ class Linearizer(Kernel):
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)

# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501

# end the late reduce loop
self.load_cache.clear()

# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # noqa: E501

# run late AST (without the store)
val = self.ast_parse(cast(LazyOp, self.ast.src[0]), acc, None, loaded_buffers)
@@ -440,7 +445,7 @@ class Linearizer(Kernel):
for u in self.uops:
if u.uop == UOps.LOOP:
# add END of loops after the last thing that (recursively) depends on them
self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1)
self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501
elif u.uop == UOps.IF:
# END any if statements at the end of the uops
self.uop(UOps.END, None, (u,), cachable=False)
@@ -448,7 +453,7 @@ class Linearizer(Kernel):
# maybe graph the uops
if DEBUG >= 5:
for u in self.uops:
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
if getenv("GRAPHUOPS"):
from tinygrad.graph import graph_uops
graph_uops(self.uops)
@@ -460,7 +465,7 @@ class Linearizer(Kernel):
self.applied_opts_cache = self.applied_opts[:]
return self

def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: # noqa: E501
key = (uop, dtype, vin, arg)
if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:]
if uop == UOps.ALU: # upcast vins to the same dtype
@@ -471,10 +476,12 @@ class Linearizer(Kernel):
if simplify:
if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before)
if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
return self.const(vin[0].arg, dtype, insert_before)
if uop == UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG:
return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
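The simplify rules in this hunk are classic peephole rewrites on the uop stream: NEG of a constant folds to a constant, ADD of a NEG becomes SUB, and a WHERE whose branches agree disappears. A minimal standalone sketch of the same idea, with hypothetical Var/Const/Alu node classes rather than tinygrad's real UOp API:

from dataclasses import dataclass
from typing import Tuple, Union

@dataclass(frozen=True)
class Var: name: str
@dataclass(frozen=True)
class Const: val: float
@dataclass(frozen=True)
class Alu:
  op: str
  src: Tuple['Node', ...]
Node = Union[Var, Const, Alu]

def simplify(n: Node) -> Node:
  if not isinstance(n, Alu): return n
  src = tuple(simplify(s) for s in n.src)  # simplify bottom-up, like uop() being called on already-built vins
  if n.op == "NEG" and isinstance(src[0], Const): return Const(-src[0].val)   # constant folding
  if n.op == "ADD" and isinstance(src[1], Alu) and src[1].op == "NEG":        # x + (-y) -> x - y
    return simplify(Alu("SUB", (src[0], src[1].src[0])))
  if n.op == "WHERE" and src[1] == src[2]: return src[1]                      # identical branches are a noop
  return Alu(n.op, src)

x, y = Var("x"), Var("y")
assert simplify(Alu("ADD", (x, Alu("NEG", (y,))))) == Alu("SUB", (x, y))
assert simplify(Alu("NEG", (Const(2.0),))) == Const(-2.0)
assert simplify(Alu("WHERE", (x, y, y))) == y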
@@ -487,7 +494,8 @@ class Linearizer(Kernel):
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]

# When insert_before is set, need to check if the cached expr is valid with the given insert place.
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and (insert_before is None or self.uops.index(expr) <= insert_before): return expr
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and (insert_before is None or self.uops.index(expr) <= insert_before):
return expr
ret = UOp(uop, dtype, vin, arg)
if insert_before is not None:
self.uops.insert(insert_before, ret)
@@ -496,16 +504,16 @@ class Linearizer(Kernel):
if cachable: self.saved_exprs[key] = ret
return ret

def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]: # noqa: E501
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)]
if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)] # noqa: E501
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
# MULACC fusion. TODO: this is copied from Interpreted
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: # noqa: E501
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
values = [self.ast_parse(cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
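Most hunks in this commit follow the same recipe: the over-long line is kept byte-for-byte and a trailing # noqa: E501 tells ruff to skip the line-length check for that line only. A rough sketch of the check being suppressed, assuming a much-simplified checker (ruff's real implementation lives in Rust and handles many more noqa forms):

def check_e501(source: str, max_line_length: int = 150):
  """Yield (lineno, message) for over-long lines, honoring a `# noqa: E501` suppression."""
  for i, line in enumerate(source.splitlines(), start=1):
    if len(line) <= max_line_length: continue
    # a targeted suppression comment disables just this rule on just this line
    if "# noqa: E501" in line or line.rstrip().endswith("# noqa"): continue
    yield i, f"E501 line too long ({len(line)} > {max_line_length} characters)"

long = "x = 1  " + "# pad" * 40
print(list(check_e501(long)))                     # flagged
print(list(check_e501(long + "  # noqa: E501")))  # suppressed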
@@ -3,7 +3,8 @@ import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv, ImageDType
from tinygrad.helpers import DType, dtypes, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast
@@ -14,13 +15,13 @@ if TYPE_CHECKING:
# **************** Device ****************

class _Device:
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT # noqa: E501
def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0]
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0] # noqa: E501
if isinstance(ret, type): ret = ret(ix)
return ret
@functools.cached_property
@@ -48,7 +49,7 @@ class JITRunner:
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")

def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False):
def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False): # noqa: E501
if var_vals is None: var_vals = {}
op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
GlobalCounters.kernel_count += num_kernels
@@ -57,8 +58,8 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 2:
ptm = (colored(f"{et*1e3:9.2f}ms", "RED") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
print(f"{colored(f'** {device[:7]:7} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
print(f"{colored(f'** {device[:7]:7} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
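The DEBUG>=2 print above derives its throughput numbers with plain arithmetic: GFLOPS is the op estimate divided by elapsed seconds times 1e9, and GB/s is the same with the byte estimate. A small worked example of that formula (the values are made up for illustration):

op_estimate, mem_estimate = 2_000_000_000, 400_000_000  # hypothetical ops and bytes for one kernel
et = 0.004                                              # elapsed time in seconds
gflops = op_estimate / (et * 1e9)                       # 2e9 / 4e6 -> 500.0
gbps = mem_estimate / (et * 1e9)                        # 4e8 / 4e6 -> 100.0
print(f"{gflops:8.2f} GFLOPS, {gbps:7.2f} GB/s")        #   500.00 GFLOPS,  100.00 GB/s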
# **************** Buffer / Allocator ****************
@@ -85,13 +86,14 @@ class Buffer:
def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
def toCPU(self) -> np.ndarray:
# zero copy with as_buffer
if hasattr(self.allocator, 'as_buffer'): return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
if hasattr(self.allocator, 'as_buffer'):
return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
ret = np.empty(self.size, self.dtype.np)
if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
return ret

def _internal_buffer_copy(dest, src):
if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721
# fast path, used on HIP between GPUs
# NOTE: it's important we use the dest device here to ensure the transfer is ready
Device[src.device].synchronize() # TODO: async this
@@ -124,7 +126,7 @@ class _BufferCopy(JITRunner):
if wait or DEBUG >= 2:
Device[dest.device].synchronize()
et = time.perf_counter() - st
update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:7} <- {src.device[:7]:7}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, lra={"global_size": dest.size}, device=dest.device)
update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:7} <- {src.device[:7]:7}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, device=dest.device) # noqa: E501
BufferCopy = _BufferCopy()

# TODO: size, dest, src are the same type. can we enforce this?
@@ -232,7 +234,7 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
# **************** for Compiled Devices ****************

class CompiledASTRunner(JITRunner):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): # noqa: E501
super().__init__()
if DEBUG >= 4: print(prg)
if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
@@ -267,13 +269,13 @@ class CompiledASTRunner(JITRunner):
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run)
update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run) # noqa: E501
self.first_run = False
return et

class Compiled:
def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None):
self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph
self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph # noqa: E501
def synchronize(self): pass # override this in your device

def to_program(self, k:Linearizer) -> CompiledASTRunner:
@@ -5,7 +5,7 @@ from tinygrad.helpers import init_c_var, encode_args_cuda_style
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException # noqa: E501

class CUDAGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -27,7 +27,7 @@ class CUDAGraph:
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)

c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info())
c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501
c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)
@@ -55,7 +55,7 @@ class CUDAGraph:
self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params))

et = self.graph_launch(self.instance, None, wait=wait)
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
return et

def __del__(self):
@@ -64,9 +64,13 @@ class CUDAGraph:

def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
def graph_instantiate(self, graph): return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))
def graph_add_kernel_node(self, graph, c_deps, c_node_params): return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params))))
def graph_instantiate(self, graph):
return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))
def graph_add_kernel_node(self, graph, c_deps, c_node_params):
return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) # noqa: E501
def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait)
def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args))
def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config)
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config):
return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config)
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
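HIPGraph below reuses CUDAGraph's whole capture-and-launch flow and only swaps the driver-level primitives: each graph_* method is a one-line adapter over the native API. A minimal sketch of that shape, with a toy backend standing in for the CUDA/HIP bindings:

class GraphBase:
  # the shared algorithm: build the graph once, then launch it as a unit
  def run(self, kernels):
    g = self.graph_create()
    for k in kernels: self.graph_add_kernel_node(g, k)
    return self.graph_launch(g)
  # backend hooks, overridden per driver
  def graph_create(self): raise NotImplementedError
  def graph_add_kernel_node(self, g, k): raise NotImplementedError
  def graph_launch(self, g): raise NotImplementedError

class ListGraph(GraphBase):
  # a toy "driver" that just records and replays launches
  def graph_create(self): return []
  def graph_add_kernel_node(self, g, k): g.append(k)
  def graph_launch(self, g): return [k() for k in g]

print(ListGraph().run([lambda: "k0", lambda: "k1"]))  # ['k0', 'k1']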
@@ -12,9 +12,13 @@ class HIPGraph(CUDAGraph):

def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
def graph_instantiate(self, graph): return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
def graph_add_kernel_node(self, graph, c_deps, c_params): return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params))))
def graph_instantiate(self, graph):
return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
def graph_add_kernel_node(self, graph, c_deps, c_params):
return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) # noqa: E501
def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)
def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args))
def build_kernel_node_params(self, prg, global_size, local_size, c_config): return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size
def build_kernel_node_params(self, prg, global_size, local_size, c_config):
return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size
@@ -23,7 +23,7 @@ class MetalGraph:
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) # noqa: E501
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")

if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
@@ -33,7 +33,7 @@ class MetalGraph:
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
for i,b in enumerate(ji.rawbufs):
@@ -59,7 +59,7 @@ class MetalGraph:
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
@@ -74,5 +74,5 @@ class MetalGraph:
else:
self.device.mtl_buffers_in_flight.append(command_buffer)
et = None
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
return et
@@ -7,7 +7,7 @@ def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
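The NOTE at the top of that hunk is checkable with a few lines of numpy: lay the k axis of an m-by-k matrix along the channel dimension, treat the k-by-n matrix as n pointwise (1x1) filters, and the convolution reproduces the matmul. A quick sketch, where einsum plays the role of a conv2d with a 1x1 kernel:

import numpy as np

m, k, n = 3, 4, 5
A, W = np.random.randn(m, k), np.random.randn(k, n)
x = A.T.reshape(1, k, m, 1)   # NCHW input (1,k,m,1): channels carry the shared k axis
w = W.T.reshape(n, k, 1, 1)   # n output channels, each a 1x1 filter over all k input channels
out = np.einsum('ichw,ocxy->iohw', x, w)  # a 1x1 conv2d is exactly this contraction over channels
assert np.allclose(A @ W, out[0, :, :, 0].T)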
@@ -75,7 +75,7 @@ def compile_linearizer(dev:str, lin:Linearizer, name:Optional[str]=None) -> Tupl
src, _ = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
return rdev.compiler(src), lin.global_size, lin.local_size

def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501
rdev = Device[dev]
assert isinstance(rdev, Compiled)
clprg = rdev.runtime(name, lib)
@@ -124,21 +124,21 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
while not exiting:
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
timed_lins: List[Tuple[Linearizer, float]] = []
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
if proc is None: continue
lib, global_size, local_size = proc
if lib in seen_libs: continue
seen_libs.add(lib)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) # > 1 second, run one time
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

# done
opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
if not exiting: beam = opts[:amt]
assert len(beam) > 0, "no BEAM items succeeded?!?"
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
if pool is not None: pool.close() # the pool is closed
except KeyboardInterrupt as e:
if pool is not None: pool.terminate()
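Stripped of the compilation and timing machinery, the loop above is a textbook beam search: expand every kernel in the beam, measure the candidates, keep the amt fastest, and stop as soon as the best candidate no longer beats the incumbent. A schematic version, with a hypothetical expand/measure pair in place of get_linearizer_actions and time_program:

def beam_search_sketch(start, expand, measure, amt: int):
  """Keep the `amt` best candidates per round; stop when nothing improves on the beam."""
  beam = [(start, measure(start))]
  while True:
    candidates = [c for s, _ in beam for c in expand(s)]
    opts = sorted(((c, measure(c)) for c in candidates), key=lambda x: x[1])
    if not opts or beam[0][1] <= opts[0][1]: return beam[0][0]  # same exit rule as above
    beam = opts[:amt]

# toy usage: "optimize" an integer toward 42 by +/-1 moves, cost = distance
print(beam_search_sketch(0, lambda s: [s - 1, s + 1], lambda s: abs(s - 42), amt=2))  # 42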
@@ -155,20 +155,20 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
def try_exec(local_size):
try:
return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
except Exception:
return float('inf')
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]

def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501

if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
@@ -14,8 +14,8 @@ cnts: Dict[OpType, int] = defaultdict(int)
if DEBUG >= 2:
def print_globalcounters():
if GlobalCounters.time_sum_s == 0: return
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
atexit.register(print_globalcounters)
if GRAPH:
import networkx as nx
@@ -52,7 +52,7 @@ def add_st_node(nmx, nmo, label, st:ShapeTracker):
inter_node = node_count
node_count += 1
offset = st.expr_node(NumNode(0))[0]
G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{offset}" if offset != 0 else ""))
G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")) # noqa: E501
G.add_edge(nmx, inter_node, color='#00000060')
G.add_edge(inter_node, nmo, label=label, color='#00000060')
@@ -69,7 +69,8 @@ def log_schedule_item(si: ScheduleItem):
cnts[optype] += 1
if GRAPH:
assert si.out.base == si.out, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}

# get inputs for shapetrackers
input_to_st = defaultdict(list)
@@ -89,13 +90,14 @@ def log_schedule_item(si: ScheduleItem):

if nm(si.out) not in G.nodes: G.add_node(nm(si.out))

G.nodes[nm(si.out)]['label'] = '"' + (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps or optype is BufferOps else "")+(f"\n{si.out.device}" if si.out.device != Device.DEFAULT else "") + '"'
G.nodes[nm(si.out)]['label'] = '"' + (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps or optype is BufferOps else "")+(f"\n{si.out.device}" if si.out.device != Device.DEFAULT else "") + '"' # noqa: E501
G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
G.nodes[nm(si.out)]['color'] = 'black'
G.nodes[nm(si.out)]['style'] = 'filled'

def _tree(lazydata, prefix=""):
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
if type(lazydata).__name__ == "LazyBuffer":
return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
childs = [_tree(c) for c in lazydata.src[:]]
@@ -112,7 +114,7 @@ def graph_uops(uops:List[UOp]):
G = nx.DiGraph()
for u in uops:
if u.uop == UOps.END: continue
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
GRAPHPATH = "/tmp/uops"
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
@@ -21,7 +21,7 @@ def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items:List[T]): return all(x == items[0] for x in items)
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s:str): return len(ansistrip(s))
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
@@ -30,7 +30,7 @@ def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
def round_up(num, amt:int): return (num+amt-1)//amt * amt
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
return {k:v for d in ds for k,v in d.items()}
def partition(lst:List[T], fxn:Callable[[T],bool]):
a:List[T] = []
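The merge_dicts assertion in that hunk is what lets the JIT combine var_vals coming from different tensors: a merge is only legal when every dict agrees on the keys it shares. A short usage sketch of the behavior:

merge_dicts([{"a": 1}, {"b": 2}, {"a": 1, "c": 3}])   # -> {'a': 1, 'b': 2, 'c': 3}
try:
  merge_dicts([{"a": 1}, {"a": 2}])                   # same key, different values
except AssertionError as e:
  print("rejected:", e)                               # "cannot merge, ... different values for the same key"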
@@ -136,10 +136,10 @@ class PtrDType(DType):
def __repr__(self): return f"ptr.{super().__repr__()}"

class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64)
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
@staticmethod
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
@@ -187,11 +187,13 @@ class dtypes:
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float16, dtypes.bfloat16],
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }

@functools.lru_cache(None)
def _get_recursive_parents(dtype:DType) -> Set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
def _get_recursive_parents(dtype:DType) -> Set[DType]:
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
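least_upper_dtype above is a least-upper-bound query on the promotion lattice: collect, for each input dtype, everything reachable upward through promo_lattice, intersect those sets, and take the minimum under the DType ordering. The same idea on a toy lattice of plain strings, with an explicit rank standing in for how DTypes sort:

lattice = {"bool": ["int32"], "int32": ["int64", "float32"], "int64": ["float64"], "float32": ["float64"], "float64": []}
rank = {"bool": 0, "int32": 1, "int64": 2, "float32": 3, "float64": 4}  # proxy for DType's sort order

def parents(d):
  # d plus everything promotable from d, transitively (mirrors _get_recursive_parents)
  return {d}.union(*[parents(p) for p in lattice[d]])

def least_upper(*ds):
  return min(set.intersection(*[parents(d) for d in ds]), key=rank.__getitem__)

print(least_upper("int32", "float32"))  # float32
print(least_upper("int64", "float32"))  # float64: the only common ancestor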
@@ -247,7 +249,7 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
_db_tables.add(table)
cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
conn.commit()
cur.close()
return val
@@ -264,7 +266,7 @@ def diskcache(func):

def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
if url.startswith("/") or url.startswith("."): return pathlib.Path(url)
fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest())
fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
if not fp.is_file() or not allow_caching:
with request.urlopen(url, timeout=10) as r:
assert r.status == 200
@@ -289,14 +291,14 @@ def cpu_time_execution(cb, enable):

# TODO: make this work with read only memoryviews (if possible)
def from_mv(mv, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
@functools.lru_cache(maxsize=None)
def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
class CStruct(ctypes.Structure):
_pack_, _fields_ = 1, fields
return CStruct
def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1]
def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
def flat_mv(mv:memoryview):
if len(mv) == 0: return mv
return mv.cast("B", shape=(mv.nbytes,))
@@ -305,10 +307,10 @@ def flat_mv(mv:memoryview):

def pretty_ptx(s):
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s
@@ -321,8 +323,8 @@ def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog,
return get_bytes(prog, get_code_size, get_code, check)

def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals)
return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args
c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501
return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501

def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
if not enable: return cb()
@@ -15,7 +15,7 @@ class JitItem:
rawbufs: List[Optional[Buffer]]

def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0))
return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0)) # noqa: E501
def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
@@ -24,7 +24,7 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))]
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] # noqa: E501
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
@@ -49,7 +49,7 @@ class TinyJit(Generic[ReturnType]):

def __call__(self, *args, **kwargs) -> ReturnType:
# all inputs (except const) are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501
expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()])

# get rawbuffers
@@ -58,13 +58,13 @@ class TinyJit(Generic[ReturnType]):
assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"

# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501
expected_vals = tuple(var_vals.keys())

if self.cnt >= 2:
# jit exec
assert self.expected_vals == expected_vals, "mismatch of var_vals"
assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" # noqa: E501
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
elif self.cnt == 1:
@ -74,20 +74,21 @@ class TinyJit(Generic[ReturnType]):
|
|||
self.ret = self.fxn(*args, **kwargs)
|
||||
self.jit_cache = CacheCollector.finish()
|
||||
assert len(self.jit_cache) != 0, "didn't JIT anything!"
|
||||
assert len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) == len(input_rawbuffers), "some input tensors not found" # Do this check on an unmodified jit cache.
|
||||
assert len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) == len(input_rawbuffers), "some input tensors not found"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
|
||||
# if your Device supports it, condense the items into a graph executor.
|
||||
if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2:
|
||||
# Split JIT cache into batches for faster graph execution. This allows the accelerator to run some batches while subsequent graphs are still being updated.
|
||||
# JitItems that cannot be jitted (not CompiledASTRunner) are moved to the final jit cache and do not participate in the process of graph building.
|
||||
# Split JIT cache into batches for faster graph execution.
|
||||
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
|
||||
graphed_jit_cache, current_batch = [], []
|
||||
for i,ji in enumerate(self.jit_cache):
|
||||
# If the jit item can potentially be graphed, put it in a batch.
|
||||
if isinstance(ji.prg, CompiledASTRunner): current_batch.append(ji)
|
||||
|
||||
# The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
|
||||
if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not isinstance(ji.prg, CompiledASTRunner)):
|
||||
# The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
|
||||
# (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
|
||||
if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not isinstance(ji.prg, CompiledASTRunner)): # noqa: E501
|
||||
try:
|
||||
graphed_jit_cache.append(JitItem(make_graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers)))
|
||||
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels")
|
||||
|
|
|
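For context on the capture logic above: TinyJit runs the wrapped function normally on the first call, captures its kernels on the second, and from the third call on replays the cache with input buffers swapped in via input_replace. A minimal usage sketch, assuming the decorator API at this commit (the import path is an assumption):

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()  # jitted functions should return realized Tensors

for i in range(4):
  # call 1 runs normally, call 2 captures kernels, calls 3+ replay the jit cache
  out = step(Tensor.rand(4, 4))    # input shapes/dtypes must match across calls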
@@ -1,3 +1,5 @@
# ruff: noqa: E501
# TODO: replace with new lazy with <= 150 length lines
from __future__ import annotations
import sys, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
@@ -126,7 +126,7 @@ class Div(Function):

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501

# ************* ternary ops *************
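For reference, the two expressions in Div.backward above are the elementwise quotient rule. With out = x/y and upstream gradient g = dL/dout:

\frac{\partial L}{\partial x} = \frac{g}{y}, \qquad \frac{\partial L}{\partial y} = -\frac{g\,x}{y^{2}}

The first e() chain is g DIV y; the second is NEG(g) MUL x DIV (y MUL y).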
@@ -27,7 +27,7 @@ class BatchNorm2d:
    # NOTE: wow, this is done all throughout training in most PyTorch models
    if self.track_running_stats:
      self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
      self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() )
      self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() ) # noqa: E501
      self.num_batches_tracked += 1
    else:
      batch_mean = self.running_mean

@@ -52,7 +52,8 @@ class Conv2d:
  def __call__(self, x:Tensor):
    return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

  def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
  def initialize_weight(self, out_channels, in_channels, groups):
    return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))

def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)

@@ -63,9 +64,10 @@ class ConvTranspose2d(Conv2d):
    self.output_padding = output_padding

  def __call__(self, x:Tensor):
    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups) # noqa: E501

  def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
  def initialize_weight(self, out_channels, in_channels, groups):
    return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))

class Linear:
  def __init__(self, in_features, out_features, bias=True):
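A note on the running_var line above: prod(y.shape)/(prod(y.shape) - y.shape[1]) equals n/(n-1), where n is the per-channel element count, so the biased batch variance becomes an unbiased estimate before entering the exponential moving average. A minimal NumPy sketch of the same update rule (names here are illustrative, not tinygrad's):

import numpy as np

def update_running_stats(running_mean, running_var, batch, momentum=0.1):
  # batch has shape (N, C, H, W); statistics are per channel, as in BatchNorm2d
  n = batch.size / batch.shape[1]            # elements per channel: N*H*W
  batch_mean = batch.mean(axis=(0, 2, 3))
  batch_var = batch.var(axis=(0, 2, 3))      # biased (divide-by-n) variance
  running_mean = (1 - momentum) * running_mean + momentum * batch_mean
  # n/(n-1) rescales the biased batch variance to an unbiased estimate
  running_var = (1 - momentum) * running_var + momentum * batch_var * n / (n - 1)
  return running_mean, running_var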
@@ -5,7 +5,8 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap
from tinygrad.shape.view import strides_for_shape

safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64, "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong}
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64,
               "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:

@@ -15,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:

def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  t, json_len, metadata = safe_load_metadata(fn)
  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"} # noqa: E501

def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
  headers, offset = {}, 0

@@ -49,9 +50,10 @@ def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values(

def load_state_dict(model, state_dict, strict=True, verbose=True):
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"):
  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
      if k not in state_dict and not strict:

@@ -90,8 +92,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: Dict[str, Any] = {}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64,
               "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):

@@ -123,7 +125,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
      f = unwrap(tar.extractfile('tensors'))
      for _ in range(TorchPickle(f).load()): # num_tensors
        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
        size, stride, storage_offset = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack('<q', f.read(8))[0]
        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
        storage_offset = struct.unpack('<q', f.read(8))[0]
        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
    else:
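safe_save/safe_load above implement the safetensors layout: an 8-byte little-endian header length, the JSON metadata, then the raw tensor bytes, which is why safe_load can slice every tensor out of one buffer at 8+json_len+offset. A hedged round-trip sketch using the functions from this diff (the file path is illustrative):

from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save, safe_load

tensors = {"weight": Tensor.rand(4, 4), "bias": Tensor.zeros(4)}
safe_save(tensors, "/tmp/model.safetensors")
loaded = safe_load("/tmp/model.safetensors")  # dict of Tensors sliced out of one buffer
assert loaded["weight"].shape == (4, 4)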
@@ -57,7 +57,8 @@ class LazyOp:
  def hash(self): return hash((self.op,self.src, self.arg))
  def __hash__(self): return self.hash

  def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg)
  def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp:
    return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg)
  def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]

  def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':

@@ -80,7 +81,8 @@ class LazyOp:
  def shrink(self, _): raise NotImplementedError
  def stride(self, _): raise NotImplementedError

def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))
def vars_from_ast(ast:LazyOp) -> List[Variable]:
  return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))

# **************** independent FlopCounter ****************

@@ -97,12 +99,14 @@ class FlopCounter:
    return ret

InterpretedFlopCounter: Dict[Op, Callable] = {
  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
  **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
  **{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}),
  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501
  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
  **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
  **{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
  **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, max(y.dtype, z.dtype), self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})}
  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, max(y.dtype, z.dtype), self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501

@functools.lru_cache(None)
def get_lazyop_info(ast:LazyOp) -> FlopCounter:
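The InterpretedFlopCounter table above folds a per-op cost over the AST; consume_flops zeroes a node's count once it is read, so a subtree shared by two parents is not double-counted. A toy standalone version of that idea (hypothetical names, not tinygrad's):

from math import prod

class Counter:
  def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
  def consume_flops(self):
    ret, self.flops = self.flops, 0  # hand the count to the consumer exactly once
    return ret

def binary_op(x: Counter, y: Counter) -> Counter:
  # one flop per output element, plus whatever the inputs still carry
  return Counter(x.shape, x.consume_flops() + y.consume_flops() + prod(x.shape))

c = binary_op(Counter((4, 4)), Counter((4, 4)))
assert c.flops == 16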
@@ -12,7 +12,7 @@ class CustomOp(JITRunner):
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)

def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
  assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
  assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" # noqa: E501
  if si.ast.op is LoadOps.EMPTY: return None
  if si.ast.op is LoadOps.COPY: return BufferCopy
  if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
|
|||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.sz > 1:
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" # noqa: E501
|
||||
else:
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
|
@ -81,7 +81,7 @@ class CStyleLanguage(NamedTuple):
|
|||
return f"({cond})?({x}):{y}"
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
||||
|
@ -99,7 +99,7 @@ class CStyleLanguage(NamedTuple):
|
|||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" # noqa: E501
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
||||
|
@ -149,7 +149,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
|
||||
elif args[0] == "HIP":
|
||||
assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501
|
||||
else:
|
||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
elif uop == UOps.ALU:
|
||||
|
@ -160,7 +160,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
else:
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): # fix index rendering issue. fix clang nested max macro issue
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
|
||||
r[u] = val
|
||||
else:
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
|
||||
|
@ -272,7 +273,8 @@ class HIPLanguage(CStyleLanguage):
|
|||
__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
|
||||
__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
|
||||
__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
|
||||
typedef float float8 __attribute__((ext_vector_type(8))); __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
|
||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
__device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
|
||||
extern "C" __global__
|
||||
"""
|
||||
launch_bounds = True
|
||||
|
@ -284,15 +286,22 @@ class HIPLanguage(CStyleLanguage):
|
|||
uses_ptr_arithmetic=True
|
||||
arg_int_prefix = "const int"
|
||||
half_prekernel = "#include <hip/hip_fp16.h>\n" + """
|
||||
typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4; __device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }
|
||||
typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16))); __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d, half e, half f, half g, half h, half i, half j, half k, half l) { return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
|
||||
typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4;
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }
|
||||
typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8;
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
__device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d,
|
||||
half e, half f, half g, half h, half i, half j, half k, half l) {
|
||||
return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
|
||||
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
|
||||
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
|
||||
__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
|
||||
__device__ float4 vload_half4(size_t offset, const half *p) {
|
||||
return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
|
||||
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
|
||||
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
|
||||
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
|
||||
__device__ void vstore_half4(float4 data, size_t offset, half *p) {
|
||||
*(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
|
||||
__device__ half exp2(half x) { return hexp2(x); }
|
||||
__device__ half log2(half x) { return hlog2(x); }
|
||||
__device__ half sin(half x) { return hsin(x); }
|
||||
|
@ -317,7 +326,8 @@ __device__ bool operator<(const unsigned short &a, const half &b) { return __hlt
|
|||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
|
||||
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})",
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"}
|
||||
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
||||
|
||||
# TODO: how much of this can be merged with above?
|
||||
|
@ -328,7 +338,8 @@ class WGSLLanguage(CStyleLanguage):
|
|||
barrier="workgroupBarrier();"
|
||||
generic_var_prefix = "var "
|
||||
external_local_bufs = True
|
||||
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" }
|
||||
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})",
|
||||
TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" }
|
||||
type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
|
||||
|
||||
def render_local(self, name: str, size: int):
|
||||
|
@ -343,8 +354,8 @@ class WGSLLanguage(CStyleLanguage):
|
|||
local_size = local_size[::-1] if local_size else [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
|
||||
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs]) # noqa: E501
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
|
||||
return prg
|
||||
|
||||
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
|
|
|
@@ -4,31 +4,37 @@ from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import DType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps

LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf

def is_bool_or_unsigned(dtype: DType):
  return dtype == dtypes.bool or dtypes.is_unsigned(dtype)

code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
  UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  UnaryOps.SQRT: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  BinaryOps.ADD: lambda builder, x, y, var_dtype: builder.or_(x, y) if var_dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(var_dtype) else builder.fadd(x, y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.SUB: lambda builder, x, y, var_dtype: builder.sub(x, y) if dtypes.is_int(var_dtype) else builder.fsub(x, y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.MUL: lambda builder, x, y, var_dtype: builder.mul(x, y) if is_bool_or_unsigned(var_dtype) or dtypes.is_int(var_dtype) else builder.fmul(x, y, flags=LLVM_FAST_MATH_FLAGS), # TODO should we use umul_with_overflow?
  BinaryOps.DIV: lambda builder, x, y, var_dtype: builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y),
  BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
  UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501
  UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
  UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
  UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
  UnaryOps.SQRT: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
  BinaryOps.ADD: lambda builder, x, y, var_dtype: builder.or_(x, y) if var_dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(var_dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
  BinaryOps.SUB: lambda builder, x, y, var_dtype: builder.sub(x, y) if dtypes.is_int(var_dtype) else builder.fsub(x, y, flags=MFLAGS),
  BinaryOps.MUL: lambda builder, x, y, var_dtype: # TODO should we use umul_with_overflow?
    builder.mul(x, y) if is_bool_or_unsigned(var_dtype) or dtypes.is_int(var_dtype) else builder.fmul(x, y, flags=MFLAGS),
  BinaryOps.DIV: lambda builder, x, y, var_dtype:
    builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=MFLAGS),
  BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
  BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
  BinaryOps.MOD: lambda builder, x, y, var_dtype:
    builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
  BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
  TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
  TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z)
  TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS),
  TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=MFLAGS), y, z) # noqa: E501
}

dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(),
                       dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64),
                       dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
                       dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}

def cast(bb, val, input_type, output_type, bitcast=False):
  if input_type == output_type: return val

@@ -65,7 +71,8 @@ def cast(bb, val, input_type, output_type, bitcast=False):

  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")

def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args)
def const(args, dtype):
  return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args)

def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
  # all llvm stuff goes into a module

@@ -77,7 +84,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:

  # create llvm function
  func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
  for a in func.args:
    if a.type.is_pointer: a.add_attribute("noalias")
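The MFLAGS tuple deliberately drops nnan and ninf from LLVM's fast-math set, so NaN and infinity still propagate through the generated code. A minimal llvmlite sketch of how these flags attach to an instruction (standalone illustration, not tinygrad code):

import llvmlite.ir as ir

MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc')
module = ir.Module(name="demo")
fnty = ir.FunctionType(ir.FloatType(), [ir.FloatType(), ir.FloatType()])
func = ir.Function(module, fnty, name="addf")
builder = ir.IRBuilder(func.append_basic_block(name="entry"))
builder.ret(builder.fadd(func.args[0], func.args[1], flags=MFLAGS))
print(module)  # the fadd is printed with "nsz arcp contract afn reassoc"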
@@ -5,13 +5,13 @@ from tinygrad.helpers import diskcache, cpu_time_execution
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n' # noqa: E501

@diskcache
def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes:
  # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
  with tempfile.NamedTemporaryFile(delete=True) as output_file:
    subprocess.check_output(args=('clang -shared -march=native -O2 -Wall -Werror -x c -fPIC - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8'))
    subprocess.check_output(args=('clang -shared -march=native -O2 -Wall -Werror -x c -fPIC - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8')) # noqa: E501
  return pathlib.Path(output_file.name).read_bytes()

class ClangProgram:
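compile_clang above shells C source through clang into a shared object and returns it as bytes, memoized on disk by @diskcache. A small usage sketch; the kernel source and the module path are illustrative assumptions:

from tinygrad.runtime.ops_clang import compile_clang

# illustrative kernel; the CLANG_PROGRAM_HEADER defines max/int64/half/uchar for it
lib = compile_clang("void add1(float* out, const float* a) { for (int i = 0; i < 4; i++) out[i] = a[i] + 1.0f; }")
print(type(lib), len(lib))  # raw shared-object bytes, ready to be loaded by ClangProgram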
@@ -27,13 +27,15 @@ def einsum_mulacc(einsum, get_strides, expand):
numpy_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
  UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
  UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
  BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), BinaryOps.XOR: lambda x, y: np.bitwise_xor(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False),
  BinaryOps.XOR: lambda x, y: np.bitwise_xor(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.max(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])), # noqa: E501
  MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
  TernaryOps.WHERE: np.where,
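The AS_STRIDED entry above rebuilds a view from (shape, strides, offset), with strides given in elements and scaled by itemsize before NumPy sees them. The same trick in plain NumPy:

import numpy as np

x = np.arange(6, dtype=np.float32)  # contiguous buffer [0, 1, 2, 3, 4, 5]
shape, strides_in_elems, offset = (2, 3), (3, 1), 0
view = np.ndarray(shape, buffer=np.require(x, requirements='C'), dtype=x.dtype,
                  offset=offset*x.dtype.itemsize,
                  strides=tuple(s*x.dtype.itemsize for s in strides_in_elems))
print(view)  # [[0. 1. 2.], [3. 4. 5.]] without copying the underlying data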
@@ -3,7 +3,7 @@ import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools
from pathlib import Path
from typing import Tuple, Optional
import gpuctypes.cuda as cuda
from tinygrad.helpers import DEBUG, getenv, diskcache, from_mv, init_c_var, pretty_ptx, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
from tinygrad.helpers import DEBUG, getenv, diskcache, from_mv, init_c_var, pretty_ptx, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import CUDARenderer

@@ -11,16 +11,16 @@ from tinygrad.renderer.cstyle import CUDARenderer
CUDACPU = getenv("CUDACPU") == 1
if CUDACPU:
  gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
  cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared)
  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
  cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) # noqa: E501

def check(status):
  if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}")
  if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501

def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable)
def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) # noqa: E501

@diskcache
def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check)
def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check) # noqa: E501

class CUDAProgram:
  def __init__(self, device:CUDADevice, name:str, lib:bytes):

@@ -46,7 +46,7 @@ class CUDAProgram:
  def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
    if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context))
    c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+tuple(vals))
    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)
    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait) # noqa: E501

class CUDAAllocator(LRUAllocator):
  def __init__(self, device:CUDADevice):
@@ -13,7 +13,8 @@ class UnderlyingDiskBuffer:
    if self.fd: self.fd.close()

class DiskBuffer:
  def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0): self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset
  def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0):
    self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset
  def __repr__(self): return f"<DiskBuffer size={self.size} dtype={self.dtype} offset={self.offset}>"
  def cast(self, arg:Tuple[DType, bool]): return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset)
  def as_strided(self, arg):
@@ -7,7 +7,8 @@ from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.device import Compiled, LRUAllocator

OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
OSX_TIMING_RATIO = (125/3) if OSX else 1.0

def check(status):
  if status != 0: raise RuntimeError(f"OpenCL Error {status}")

@@ -16,22 +17,24 @@ def checked(ret, status): return (check(status.value), ret)[1]
@diskcache
def compile_cl(prg:str) -> bytes:
  assert CLDevice.compiler_context is not None, 'OpenCL requires a "compiler_context" to compile, init a device before you call this'
  program = checked(cl.clCreateProgramWithSource(CLDevice.compiler_context.context, 1, to_char_p_p([prg_bytes := prg.encode()]), ctypes.byref(ctypes.c_size_t(len(prg_bytes))), ctypes.byref(status := ctypes.c_int32())), status)
  program = checked(cl.clCreateProgramWithSource(CLDevice.compiler_context.context, 1, to_char_p_p([prg_bytes := prg.encode()]),
                    ctypes.byref(ctypes.c_size_t(len(prg_bytes))), ctypes.byref(status := ctypes.c_int32())), status)
  status = cl.clBuildProgram(program, 1, ctypes.byref(CLDevice.compiler_context.device_id), None, cl.clBuildProgram.argtypes[4](), None)
  if status != 0:
    cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, ctypes.byref(log_size := ctypes.c_size_t()))
    cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None)
    cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, ctypes.byref(log_size := ctypes.c_size_t())) # noqa: E501
    cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) # noqa: E501
    raise RuntimeError(f"OpenCL Compile Error\n\n{ctypes.string_at(mstr, size=log_size.value).decode()}")
  binary_sizes = init_c_var((ctypes.c_size_t * 1)(), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(x), ctypes.byref(x), None)))
  binary = init_c_var(ctypes.create_string_buffer(binary_sizes[0]), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), ctypes.byref((ctypes.c_void_p * 1)(ctypes.addressof(x))), None)))
  binary_sizes = init_c_var((ctypes.c_size_t * 1)(), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(x), ctypes.byref(x), None))) # noqa: E501
  binary = init_c_var(ctypes.create_string_buffer(binary_sizes[0]), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), ctypes.byref((ctypes.c_void_p * 1)(ctypes.addressof(x))), None))) # noqa: E501
  check(cl.clReleaseProgram(program))
  return bytes(binary)

class CLProgram:
  def __init__(self, device:CLDevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib
    self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(lib)), to_char_p_p([lib], ctypes.c_ubyte),
                           ctypes.byref(binary_status := ctypes.c_int32()), ctypes.byref(errcode_ret := ctypes.c_int32())), errcode_ret)
    self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(lib)),
                           to_char_p_p([lib], ctypes.c_ubyte), ctypes.byref(binary_status := ctypes.c_int32()),
                           ctypes.byref(errcode_ret := ctypes.c_int32())), errcode_ret)
    check(binary_status.value)
    check(cl.clBuildProgram(self.program, 1, ctypes.byref(device.device_id), None, cl.clBuildProgram.argtypes[4](), None)) # NOTE: OSX requires this
    self.kernel = checked(cl.clCreateKernel(self.program, name.encode(), ctypes.byref(status := ctypes.c_int32())), status)

@@ -40,16 +43,16 @@ class CLProgram:
    check(cl.clReleaseKernel(self.kernel))
    check(cl.clReleaseProgram(self.program))

  def __call__(self, *bufs:cl.cl_mem, global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]:
  def __call__(self, *bufs:cl.cl_mem, global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
    for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
    for i,b in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(b)))
    if local_size is not None: global_size = tuple(int(g*l) for g,l in zip(global_size, local_size))
    event = cl.cl_event() if wait else None
    check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event))
    check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event)) # noqa: E501
    if wait:
      check(cl.clWaitForEvents(1, ctypes.byref(event)))
      start = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_START, ctypes.sizeof(x), ctypes.byref(x), None)))
      end = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_END, ctypes.sizeof(x), ctypes.byref(x), None)))
      start = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_START, ctypes.sizeof(x), ctypes.byref(x), None))) # noqa: E501
      end = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_END, ctypes.sizeof(x), ctypes.byref(x), None))) # noqa: E501
      return float(end.value-start.value) * OSX_TIMING_RATIO * 1e-9
    return None

@@ -83,10 +86,10 @@ class CLDevice(Compiled):
      err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, ctypes.byref(num_devices))
      if err == 0 and num_devices.value != 0: break
    if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
    CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))
    CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) # noqa: E501

    self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
    self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status)
    self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status) # noqa: E501
    if CLDevice.compiler_context is None: CLDevice.compiler_context = self
    self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status)
    self.pending_copyin: List[memoryview] = []
@@ -12,10 +12,11 @@ MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they comp
def check(status):
  if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")

def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable)
# TODO: remove these helpers, they increase complexity
def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501

@diskcache
def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check)
def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) # noqa: E501

class HIPProgram:
  def __init__(self, device:int, name:str, lib:bytes):

@@ -36,7 +37,7 @@ class HIPProgram:
  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
    if MOCKHIP: return float("inf")
    check(hip.hipSetDevice(self.device))
    return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait)
    return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait) # noqa: E501

T = TypeVar("T")
class HIPAllocator(LRUAllocator):

@@ -63,10 +64,11 @@ class HIPDevice(Compiled):
  default_arch_name = "gfx1100"
  def __init__(self, device:str=""):
    self.device = int(device.split(":")[1]) if ":" in device else 0
    if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()
    if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() # noqa: E501

    from tinygrad.features.graph.hip import HIPGraph
    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer,
                     compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
  def synchronize(self):
    check(hip.hipSetDevice(self.device))
    check(hip.hipDeviceSynchronize())
@@ -63,4 +63,5 @@ class LLVMProgram:
    self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
    return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)

LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram)
LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False),
                      uops_to_llvm_ir, compile_llvm, LLVMProgram)
@@ -32,7 +32,7 @@ class MetalProgram:
    self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))

  def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
    assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
    assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(),f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" # noqa: E501
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(self.pipeline_state)

@@ -81,7 +81,8 @@ class MetalDevice(Compiled):
    self.mtl_buffers_in_flight: List[Any] = []
    self.mv_in_metal: List[memoryview] = []
    from tinygrad.features.graph.metal import MetalGraph
    super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
    super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer,
                     compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
  def synchronize(self):
    for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
    self.mv_in_metal.clear()
@@ -7,7 +7,9 @@ from tinygrad.helpers import getenv, dtypes
 from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis

 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
-type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
+type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32,
+            torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64,
+            torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
 inverse_type_map = {v:k for k,v in type_map.items()}

 def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype

@@ -27,7 +29,8 @@ torch_fxn_for_op: Dict[Op, Callable] = {
   #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
   BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).to(device),
   UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
-  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
+  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
+  UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
   BinaryOps.ADD: lambda x,y: torch.add(*match_types(x, y)).type(output_type(x,y)),
   BinaryOps.SUB: lambda x,y: torch.sub(*match_types(x, y, disallow_bool=True)).type(output_type(x,y)),

@@ -38,7 +41,8 @@ torch_fxn_for_op: Dict[Op, Callable] = {
   ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
   MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), # pylint: disable=E1102
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(output_type(a,b)), lambda x: x.stride(), lambda x,s: x.expand(s)),
+  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(output_type(a,b)),
+                                   lambda x: x.stride(), lambda x,s: x.expand(s)),
   TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
 }
@@ -11,11 +11,12 @@ def create_uniform(val: int) -> wgpu.GPUBuffer:
   return buf

 class WebGPUProgram:
-  def __init__(self, name:str, lib:bytes): self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib) # NOTE: this is the compiler
+  def __init__(self, name:str, lib:bytes):
+    self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib) # NOTE: this is the compiler
   def __call__(self, *bufs, global_size, local_size, vals=(), wait=False):
     assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
-    binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))]
-    bindings = [{"binding": i, "resource": {"buffer": create_uniform(x) if i >= len(bufs) else x, "offset": 0, "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)]
+    binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))] # noqa: E501
+    bindings = [{"binding": i, "resource": {"buffer": create_uniform(x) if i >= len(bufs) else x, "offset": 0, "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)] # noqa: E501
     bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
     pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
     bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)

@@ -29,11 +30,12 @@ class WebGPUProgram:
     wgpu_device.queue.submit([command_encoder.finish()])

 class WebGpuAllocator(Allocator):
-  def _alloc(self, size: int): return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
+  def _alloc(self, size: int):
+    return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
   def copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src)
   def copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0) # TODO: remove this copy

 class WebGpuDevice(Compiled):
   def __init__(self, device:str):
-    super().__init__(WebGpuAllocator(), LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]),
-                     WGSLRenderer, lambda x: x, WebGPUProgram)
+    super().__init__(WebGpuAllocator(), LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64],
+                                                          global_max=[65535, 65535, 65535]), WGSLRenderer, lambda x: x, WebGPUProgram)
@@ -32,7 +32,7 @@ def expr_node(view:View, idx:Optional[Node]=None) -> Node:

 # generate an expression if you have a variable or expression for each index
 def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
   assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
-  return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0])
+  return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501

 @functools.lru_cache(maxsize=None)
 def merge_views(vm2:View, vm1:View) -> Optional[View]:

@@ -52,7 +52,8 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
 @dataclass(frozen=True)
 class ShapeTracker:
   views: Tuple[View, ...]
-  def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
+  def __post_init__(self):
+    assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"

   @staticmethod
   def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
@@ -185,7 +185,8 @@ class LtNode(OpNode):
     if isinstance(self.b, int):
       return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
     return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
-  def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
+  def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
+    return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))

 class MulNode(OpNode):
   def __lt__(self, b: Union[Node, int]):

@@ -201,7 +202,8 @@ class MulNode(OpNode):
     a = (self.a * (self.b%b))
     return Node.__mod__(a, b)
   def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
-  def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
+  def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
+    return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))

 class DivNode(OpNode):
   def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div

@@ -219,7 +221,7 @@ class ModNode(OpNode):
     return Node.__floordiv__(self, b, factoring_allowed)
   def get_bounds(self) -> Tuple[int, int]:
     assert self.a.min >= 0 and isinstance(self.b, int)
-    return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
+    return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b

 class RedNode(Node):
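An aside on the ModNode.get_bounds line above, since it is dense: for a non-negative node a with known integer bounds, a % b either stays within one "period" of b, in which case the residue bounds carry over, or it crosses a multiple of b, in which case the only safe answer is the full (0, b-1). A self-contained sketch of the same rule with a hypothetical helper name and a worked check:

  def mod_bounds(a_min, a_max, b):
    # mirrors ModNode.get_bounds: wrap -> (0, b-1), otherwise tighten to the residues
    assert a_min >= 0 and b > 0
    if a_max - a_min >= b or (a_min != a_max and a_min % b >= a_max % b): return (0, b - 1)
    return (a_min % b, a_max % b)

  assert mod_bounds(5, 6, 4) == (1, 2)  # 5%4=1, 6%4=2: no multiple of 4 is crossed
  assert mod_bounds(6, 9, 4) == (0, 3)  # 6%4=2 but 9%4=1: the range wraps past 8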
@@ -335,9 +337,9 @@ sint = Union[Node, int]
 VariableOrNum = Union[Variable, NumNode]

 render_python: Dict[Type, Callable] = {
-  Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"),
+  Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"), # noqa: E501
   NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
-  MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
+  MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", # noqa: E501
   DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
   ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
   LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
@@ -16,7 +16,7 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
   return filter_strides(shape, tuple(reversed(strides)))

 @functools.lru_cache(maxsize=None)
-def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]:
+def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: # noqa: E501
   # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
   if not shape: return tuple()
   assert len(shape) == len(strides)

@@ -95,7 +95,7 @@ class View:
     new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
     new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
     new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
-    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None # noqa: E501
     return View.create(new_shape, new_strides, new_offset, new_mask)

 # MovementOps live here now
@@ -139,7 +139,7 @@ class View:
   def permute(self, axis: Tuple[int, ...]) -> View:
     assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
     assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
-    return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None)
+    return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None) # noqa: E501

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def stride(self, mul: Tuple[int, ...]) -> View:
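For intuition on the permute line above: a View never moves data, so permuting just reorders the shape/strides (and mask) tuples in lockstep. A quick check of that bookkeeping, with values we picked for illustration:

  # a contiguous (2, 3, 4) view has strides (12, 4, 1); permute the axes as (2, 0, 1)
  shape, strides, axis = (2, 3, 4), (12, 4, 1), (2, 0, 1)
  assert tuple(shape[a] for a in axis) == (4, 2, 3)
  assert tuple(strides[a] for a in axis) == (1, 12, 4)  # same memory, new traversal order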
@@ -148,7 +148,7 @@ class View:
     strides = tuple([z*m for z,m in zip(self.strides, mul)])
     new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
     offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
+    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None # noqa: E501
     return View.create(new_shape, strides, self.offset + offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none

@@ -162,7 +162,8 @@ class View:
     # check for the same size
     if all_int(self.shape):
       assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
+        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
@@ -46,7 +46,7 @@ class Tensor:

   no_grad: ClassVar[bool] = False
   default_type: ClassVar[DType] = dtypes.float32
-  def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): # noqa: E501
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = Device.canonicalize(device)
     # tensors have gradients, buffers do not

@@ -116,7 +116,7 @@ class Tensor:
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert not x.requires_grad # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
-    if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
+    if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized # noqa: E501
     self.lazydata = x.lazydata
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)

@@ -132,7 +132,7 @@ class Tensor:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
     if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
-    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
+    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape) # noqa: E501

   def to(self, device:Optional[str]) -> Tensor:
     if device is None or device == self.device: return self

@@ -151,7 +151,7 @@ class Tensor:
   @staticmethod
   def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
     assert isinstance(sz, int), f"cannot create with symbolic size {sz}"
-    return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
+    return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) # noqa: E501

   @staticmethod
   def empty(*shape, **kwargs):

@@ -168,7 +168,8 @@ class Tensor:
   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
+  def full(shape:Tuple[sint, ...], fill_value, **kwargs):
+    return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

   @staticmethod
   def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@@ -182,9 +183,11 @@ class Tensor:
     return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)

   @staticmethod
-  def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
+  def eye(dim:int, **kwargs):
+    return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)

-  def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
+  def full_like(self, fill_value, **kwargs):
+    return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
   def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
   def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
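A quick aside on the eye() trick above, since it is easy to miss why it works: padding a column of ones to width dim+1 and flattening leaves exactly dim zeros between consecutive ones, which is the row-major layout of the identity matrix. A numpy emulation (ours, not from the diff) of the same pad/reshape/shrink pipeline:

  import numpy as np
  dim = 3
  x = np.ones((dim, 1))                       # Tensor.full((dim,1), 1)
  x = np.pad(x, ((0, 0), (0, dim)))           # .pad(((0,0),(0,dim))) -> shape (dim, dim+1)
  x = x.reshape(dim * (dim + 1))[:dim * dim]  # .reshape(dim*(dim+1)).shrink(((0,dim*dim),))
  assert (x.reshape(dim, dim) == np.eye(dim)).all()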
@@ -271,11 +274,14 @@ class Tensor:

   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
-    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)]))
-  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
+    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) # noqa: E501
+  def expand(self, shape, *args) -> Tensor:
+    return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
+    if not any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)): return self
+    return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
   def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
     if all(x is None or x == (0,0) for x in arg): return self
     ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
@@ -314,8 +320,9 @@ class Tensor:
     # 1. indices normalization and validation
     # treat internal tuples and lists as Tensors and standardize indices to list type
     if isinstance(indices, (tuple, list)):
-      if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)] # special case <indices: List[int]>, a lil ugly
-      else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
+      # special case <indices: List[int]>, a lil ugly
+      if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)] # noqa: E501
+      else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices] # noqa: E501
     else: indices = [indices]

     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)

@@ -341,10 +348,12 @@ class Tensor:
     if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
     if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
     for dim in type_dim[int]:
-      if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]: raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")
+      if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]:
+        raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")

     # normalize! indices -> start, stop, strides
-    start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type]
+    # TODO: this line is completely unreadable
+    start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type] # noqa: E501

     # 2. basic indexing (no copy)
     # apply slices and flip where strides are negative
@@ -380,7 +389,7 @@ class Tensor:
       # compute sum_dim, arange, and idx
       max_dim = max(i.ndim for i in idx)
       sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
-      arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))]
+      arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
       first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
       rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
       idx = first_idx + rest_idx

@@ -410,7 +419,7 @@ class Tensor:
     idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
-    return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
+    return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) # noqa: E501

   def cat(self, *args:Tensor, dim:int=0) -> Tensor:
     dim = (dim + len(self.shape)) if dim < 0 else dim
@@ -446,7 +455,7 @@ class Tensor:
   def squeeze(self, dim:Optional[int]=None) -> Tensor:
     if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
     if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
-    if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
+    if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})") # noqa: E501
     if dim < 0: dim += self.ndim
     return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])

@@ -473,7 +482,8 @@ class Tensor:
     axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
     axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
     shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
-    if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
+    if 0 in self.shape and 0 not in shape:
+      return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)
@@ -504,11 +514,11 @@ class Tensor:

   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
+      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape) # noqa: E501
       return prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
-    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
+    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) # noqa: E501
     return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
   def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@@ -547,7 +557,7 @@ class Tensor:
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
       e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
-      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
+      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)]) # noqa: E501
       # slide by dilation
      xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
      xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
@@ -565,8 +575,8 @@ class Tensor:
     return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])

   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
-  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
+  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501
+  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501

   def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
     HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
@@ -577,62 +587,70 @@ class Tensor:
     x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
     x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
     x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
-    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
+    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) # noqa: E501
     return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)

   wino = int(getenv("WINO", "0"))
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
-    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
-    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
-    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1])
+    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
+    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
+    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501

     # conv2d is a pooling op (with padding)
     x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
     rcout, oyx = cout//groups, x.shape[2:-len(HW)]
     if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
       # normal conv
-      x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
+      x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501

       # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
-      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
+      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) # noqa: E501
       return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))

     # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
-    def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat])
+    def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
     HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
     winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
     winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
-    winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time
+    winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time

     # todo: stride == dilation
     # use padding to round up to 4x4 output tiles
-    d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI)
-    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
+    # (bs, cin_, tyx, HWI)
+    d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
+    # move HW to the front: # (HWI, bs, cin_, tyx)
+    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward()
     tyx = d.shape[-len(HWI):] # dim of tiling

     g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front

     # compute 6x6 winograd tiles: GgGt, BtdB
-    gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
-    dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
+    # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
+    gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
+    # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
+    dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)

-    ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
+    # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
+    ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW)))

-    ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
-    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
+    # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
+    ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
+    # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
+    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))

     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
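(Aside, not part of the diff: the winograd_Bt / winograd_G / winograd_At constants above are the standard F(4x4, 3x3) transforms from the Lavin & Gray paper linked in the code; per 2D tile, with 3x3 filter g and 6x6 data tile d, the output tile is

  Y = A^T [ (G g G^T) \odot (B^T d B) ] A

which is exactly the "GgGt, BtdB" tiles computed above followed by the final At application.)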
   def dot(self, w:Tensor) -> Tensor:
     n1, n2 = len(self.shape), len(w.shape)
     assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
+    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
     return (x*w).sum(-1)

-  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
+    return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
   def cumsum(self, axis:int=0) -> Tensor:
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum
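The _cumsum one-liner above is worth unpacking: it left-pads the axis with n-1 zeros, takes n sliding windows of length n with _pool, and sums each window, so window i ends at element i and therefore sums everything up to it. A numpy emulation (ours, not from the diff) of that pad-pool-sum pipeline:

  import numpy as np
  x = np.array([1., 2., 3., 4.])
  n = len(x)
  padded = np.concatenate([np.zeros(n - 1), x])            # pad2d((n-1, 0))
  windows = np.stack([padded[i:i + n] for i in range(n)])  # _pool((n,)): n windows of length n
  assert (windows.sum(-1) == x.cumsum()).all()             # sum(-1) yields the cumulative sum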
@@ -646,7 +664,8 @@ class Tensor:
     return fix(ret) + fix(base_add)

   @staticmethod
-  def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
+  def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor:
+    return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor:
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
@@ -711,7 +730,7 @@ class Tensor:
     x: Tensor = self
     if not isinstance(y, Tensor):
       if 0 in x.shape: return x, x.full_like(y)
-      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
+      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32) # noqa: E501
     if reverse: x, y = y, x
     if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
@@ -742,7 +761,7 @@ class Tensor:
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
   def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     x = self._to_float(x)
-    return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
+    return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     x = self._to_float(x)
     if x.__class__ is not Tensor and not reverse:
@@ -761,7 +780,7 @@ class Tensor:
     # we need 0 to be positive so we need to correct base_sign when the base is 0
     base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
     # inject nan if the base is negative and the power is not an integer
-    to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
+    to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501
     inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
     return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
@@ -833,7 +852,7 @@ class Tensor:
     mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
     return self * mask * (1/(1.0 - p))

-  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
     # NOTE: it works if key, value have symbolic shape
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
@@ -849,7 +868,7 @@ class Tensor:
   def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
     # NOTE: self is a logits input
     loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
+    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501
     y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     return self.log_softmax().mul(y).sum() / loss_mask.sum()
@@ -857,7 +876,8 @@ class Tensor:

   def cast(self, dtype:DType) -> Tensor:
     # hack for devices that don't support bfloat16
-    if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).cast(dtype)
+    if self.dtype == dtypes.bfloat16:
+      return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).cast(dtype)
     return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
   def bitcast(self, dtype:DType) -> Tensor:
     assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"