move dtypes to dtype.py (#2964)

* move dtypes to dtype.py

* fix urllib
George Hotz 2024-01-01 14:58:48 -08:00 committed by GitHub
parent fadaa2ec28
commit a280cfe169
74 changed files with 219 additions and 289 deletions
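
Almost every hunk below applies the same mechanical change: the dtype machinery moves out of `tinygrad/helpers.py` into the new `tinygrad/dtype.py`, and call sites update their imports (plus the small urllib import note in helpers.py). A hedged migration sketch with a hypothetical call site; the import paths themselves are taken from the hunks below:

```python
# old (pre-#2964): dtype machinery lived in tinygrad/helpers.py
#   from tinygrad.helpers import dtypes, DType, ImageDType
# new: it lives in tinygrad/dtype.py, and dtypes is still re-exported from the package root
from tinygrad.dtype import DType, ImageDType  # the classes now live here
from tinygrad import Tensor, dtypes           # dtypes via tinygrad/__init__.py

t = Tensor([1, 2, 3], dtype=dtypes.int32)     # call sites are unchanged, only the imports move
assert isinstance(t.dtype, DType) and t.dtype.itemsize == 4
```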

View File

@ -67,7 +67,7 @@ class Function:
# %%
# == LazyBuffer (in tinygrad/lazy.py, code 5/10) ==
from tinygrad.helpers import DType
from tinygrad.dtype import DType
# this is where the properties live that you thought were a part of Tensor
# LazyBuffer is like a Tensor without derivatives, at the mlop layer

View File

@ -37,7 +37,7 @@ print("******** second, the Device ***********")
DEVICE = "CLANG" # NOTE: you can change this!
import struct
from tinygrad.helpers import dtypes
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device
from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps
from tinygrad.shape.shapetracker import ShapeTracker

View File

@ -55,7 +55,7 @@ There are even more of these factory methods, you can find them in the [tensor.p
All the tensor creation methods can take a `dtype` argument to specify the data type of the tensor.
```python
from tinygrad.helpers import dtypes
from tinygrad.dtype import dtypes
t3 = Tensor([1, 2, 3, 4, 5], dtype=dtypes.int32)
```
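
A brief follow-up example (not part of the original doc, just a hedged illustration) showing that the requested dtype sticks and carries through to numpy:

```python
from tinygrad import Tensor
from tinygrad.dtype import dtypes

t3 = Tensor([1, 2, 3, 4, 5], dtype=dtypes.int32)
assert t3.dtype == dtypes.int32
print(t3.numpy().dtype)  # int32
```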

View File

@ -1,7 +1,6 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn, Variable
from tinygrad.helpers import dtypes # TODO: wouldn't need this if argmax returned the right dtype
from tinygrad import Tensor, TinyJit, nn, Variable, dtypes
import gymnasium as gym
from tqdm import trange
import numpy as np # TODO: remove numpy import

View File

@ -16,8 +16,8 @@ from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, ge
from whisper import init_whisper, transcribe_waveform
from sentencepiece import SentencePieceProcessor
from tinygrad.helpers import Timing, dtypes, fetch
from tinygrad.tensor import Tensor
from tinygrad.helpers import Timing, fetch
from tinygrad import Tensor, dtypes
# Whisper constants
RATE = 16000

View File

@ -1,7 +1,5 @@
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad import Device
from tinygrad import Device, dtypes, Tensor
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
def bit_extract(x, s, e) -> Tensor:

View File

@ -4,13 +4,13 @@ from tqdm import trange
import numpy as np
from tinygrad import Device, GlobalCounters
from typing import Optional, Union
from tinygrad.tensor import Tensor
from tinygrad import Tensor, dtypes
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, dtypes
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
# setup for distributed
from extra import dist
from tinygrad.helpers import getenv, dtypes
from tinygrad.helpers import getenv
if __name__ == "__main__":
if getenv("DIST"):
dist.preinit()
@ -14,7 +14,7 @@ import random, time
import numpy as np
from typing import Any, Dict, Optional, SupportsIndex
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from tinygrad import nn
from tinygrad import nn, dtypes
from tinygrad.nn.state import get_state_dict
from tinygrad.nn import optim
from tinygrad import Device, GlobalCounters

View File

@ -7,8 +7,8 @@ from pathlib import Path
import sys, argparse, json
import numpy as np
np.set_printoptions(linewidth=200)
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes, colored
from tinygrad import Device, GlobalCounters
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, colored
from tinygrad import Device, GlobalCounters, dtypes
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from extra.models.llama import Transformer, convert_from_huggingface

View File

@ -1,9 +1,9 @@
import time
from pathlib import Path
import numpy as np
from tinygrad import Tensor, GlobalCounters
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.jit import TinyJit
from tinygrad.helpers import getenv, dtypes
from tinygrad.helpers import getenv
from examples.mlperf import helpers
def eval_resnet():

View File

@ -4,9 +4,8 @@ import sys, logging, time, io, math, argparse, operator, numpy as np
from functools import partial, reduce
from pathlib import Path
from typing import Tuple, Optional, Type
from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, getenv
from tinygrad import nn, dtypes, Tensor
from tinygrad.helpers import getenv
from tinygrad.nn.state import torch_load
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, download_if_not_present, get_hparams_from_file, load_checkpoint, weight_norm, HParams
from examples.sovits_helpers import preprocess

View File

@ -1,7 +1,6 @@
import math
from typing import Optional, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad import Tensor, dtypes
import librosa
import soundfile
import numpy as np

View File

@ -9,9 +9,8 @@ from collections import namedtuple
from PIL import Image
import numpy as np
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad import Device, GlobalCounters
from tinygrad.helpers import dtypes, Timing, Context, getenv, fetch, colored
from tinygrad import Device, GlobalCounters, dtypes, Tensor
from tinygrad.helpers import Timing, Context, getenv, fetch, colored
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.jit import TinyJit

View File

@ -5,8 +5,8 @@ from phonemizer.punctuation import Punctuation
from functools import reduce
from pathlib import Path
from typing import List
from tinygrad import nn
from tinygrad.helpers import dtypes, fetch
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn.state import torch_load
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

View File

@ -2,7 +2,7 @@
import numpy as np
import pickle
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer
from tinygrad.helpers import dtypes
from tinygrad import dtypes
from tqdm import trange, tqdm
from matplotlib import pyplot as plt

View File

@ -1,7 +1,8 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math

View File

@ -1,9 +1,10 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes, CI
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

View File

@ -2,7 +2,7 @@ from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

View File

@ -1,6 +1,6 @@
import yaml
from typing import Tuple, Set, Dict
from tinygrad.helpers import dtypes
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.linearizer import UOps
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps

View File

@ -1,7 +1,7 @@
import os, gzip, tarfile, pickle
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, fetch
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
def fetch_mnist(tensors=False):
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()

View File

@ -1,5 +1,4 @@
from tinygrad.helpers import dtypes
from tinygrad.tensor import Tensor
from tinygrad import Tensor, dtypes
from extra.datasets.imagenet import iterate, get_val_files
if __name__ == "__main__":

View File

@ -1,5 +1,5 @@
from typing import Tuple, Dict, List
from tinygrad.helpers import DType
from tinygrad.dtype import DType
from tinygrad.tensor import Device, Tensor
from tinygrad.jit import TinyJit
from tinygrad.nn.state import get_state_dict

View File

@ -1,82 +0,0 @@
old = """__kernel void re_S256_16_8( write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2, __global float* data3 ) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int idx2 = get_global_id(0); /* 4 */
int idx1 = get_global_id(1); /* 16 */
int idx0 = get_global_id(2); /* 256 */
float acc0 = 0.0f;
for (int idx3 = 0; idx3 < 8; idx3++) {
float4 val1_0 = read_imagef(data1, smp, (int2)(((idx1*8)+idx3), 0)) /* (1, 128, 4) */;
float4 val2_0 = read_imagef(data2, smp, (int2)(((idx1*32)+(idx3*4)+idx2), idx0)) /* (256, 512, 4) */;
acc0+=(val1_0.x*val2_0.x);
acc0+=(val1_0.y*val2_0.y);
acc0+=(val1_0.z*val2_0.z);
acc0+=(val1_0.w*val2_0.w);
}
__local float temp[64];
temp[((idx1*4)+idx2)] = acc0;
barrier(CLK_LOCAL_MEM_FENCE);
if (((idx1*4)+idx2) == 0) {
float4 output0 = (float4)(0.0f,0.0f,0.0f,0.0f);
for (int mid = 0; mid < 16; mid++) {
float4 val5_0 = ((__local float4*)temp)[mid];
output0.x+=val5_0.x;
output0.y+=val5_0.y;
output0.z+=val5_0.z;
output0.w+=val5_0.w;
}
float4 val3_0 = ((__global float4*)data3)[idx0];
write_imagef(data0, (int2)(idx0, 0), (float4)(max((output0.x+val3_0.x),(0.0f)),max((output0.y+val3_0.y),(0.0f)),max((output0.z+val3_0.z),(0.0f)),max((output0.w+val3_0.w),(0.0f)))); /* (1, 256, 4) */
}
}"""
new = """__kernel void r_256_16_4_8_4(write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2, const __global float* data3) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__attribute__ ((aligned (16))) __local float temp[64];
int gidx0 = get_group_id(0); /* 256 */
int lidx1 = get_local_id(1); /* 16 */
int lidx2 = get_local_id(0); /* 4 */
float acc0 = 0.0f;
for (int ridx0 = 0; ridx0 < 8; ++ridx0) {
float4 val0 = read_imagef(data1, smp, (int2)(((lidx1*8)+ridx0),0));
float4 val1 = read_imagef(data2, smp, (int2)(((lidx1*32)+lidx2+(ridx0*4)),gidx0));
acc0 = (((val0).x*(val1).x)+acc0);
acc0 = (((val0).y*(val1).y)+acc0);
acc0 = (((val0).z*(val1).z)+acc0);
acc0 = (((val0).w*(val1).w)+acc0);
}
temp[(lidx1*4)+lidx2] = acc0;
barrier(CLK_LOCAL_MEM_FENCE);
float4 acc1 = (float4)(0.0f,0.0f,0.0f,0.0f);
for (int ridx1 = 0; ridx1 < 16; ++ridx1) {
float4 val2 = (float4)(*((__local float4*)(temp+ridx1*4)));
(acc1).x = ((val2).x+(acc1).x);
(acc1).y = ((val2).y+(acc1).y);
(acc1).z = ((val2).z+(acc1).z);
(acc1).w = ((val2).w+(acc1).w);
}
float4 val3 = (float4)(*((__global float4*)(data3+gidx0*4)));
write_imagef(data0, (int2)(gidx0,0), (float4)(max(((acc1).x+(val3).x),0.0f),max(((acc1).y+(val3).y),0.0f),max(((acc1).z+(val3).z),0.0f),max(((acc1).w+(val3).w),0.0f)));
}"""
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
from tinygrad.helpers import dtypes, prod
if __name__ == "__main__":
out = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4)))
x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4)))
w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4)))
b = CLBuffer(1024, dtypes.float)
old = CLProgram("re_S256_16_8", old)
new = CLProgram("r_256_16_4_8_4", new)
old_tms = []
new_tms = []
for i in range(5):
old_tms.append(old([1,1,256], [4,16,1], out, x, w, b, wait=True))
new_tms.append(new([256,1,1], [4,16,1], out, x, w, b, wait=True))
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")

View File

@ -1,6 +1,7 @@
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv, prod, flat_mv
from tinygrad import dtypes
from tinygrad.helpers import getenv, prod, flat_mv
from tinygrad.runtime.ops_hip import HIPAllocator, HIPProgram, compile_hip
# AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048

View File

@ -2,8 +2,8 @@ import os
os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv, flat_mv
from tinygrad import Device
from tinygrad import Device, dtypes
from tinygrad.helpers import getenv, flat_mv
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, compile_metal
N = getenv("N", 2048)

View File

@ -5,14 +5,14 @@ import time, torch, torch.mps
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad import Device, GlobalCounters
from tinygrad import Device, GlobalCounters, dtypes
from tinygrad.helpers import colored, getenv, CI, flat_mv
import os
os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv
from tinygrad.helpers import getenv
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, compile_metal
N = 16384

View File

@ -1,7 +1,6 @@
import numpy as np
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
N = getenv("N", 4096)
CNT = getenv("CNT", 10)

View File

@ -3,9 +3,8 @@ import math
import os
import numpy as np
from pathlib import Path
from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, get_child, fetch
from tinygrad import nn, Tensor, dtypes
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
from extra.models.resnet import ResNet
from extra.models.retinanet import nms as _box_nms

View File

@ -2,8 +2,8 @@ from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, DEBUG, dtypes
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG
from typing import List, Dict
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors
try:

View File

@ -1,7 +1,8 @@
import functools, io, math
from typing import Union, Tuple, Optional, List, Any
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from tinygrad import Tensor, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, flatten
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto

View File

@ -12,7 +12,7 @@ from tinygrad.helpers import getenv
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

View File

@ -5,7 +5,7 @@ import numpy as np
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

View File

@ -1,6 +1,6 @@
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

View File

@ -9,7 +9,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_lo
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

View File

@ -13,8 +13,9 @@ import onnx
from tqdm import tqdm
from typing import Tuple, List, Optional, Dict
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import dtypes, partition, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.helpers import partition, Context, fetch, getenv, GRAPH, DEBUG
from tinygrad.realize import run_schedule, lower_schedule_item
from tinygrad.ops import LoadOps, ScheduleItem
Device.DEFAULT = "GPU"

View File

@ -1,5 +1,5 @@
from tinygrad.runtime.ops_gpu import CLProgram, CL, CLBuffer
from tinygrad.helpers import dtypes
from tinygrad import dtypes
import time
N = 1000000

View File

@ -1,9 +1,9 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad import Tensor, dtypes
from tinygrad.jit import TinyJit
from tinygrad.helpers import dtypes, CI
from tinygrad.helpers import CI
from test.helpers import derandomize_model
from examples.llama import Transformer

View File

@ -4,8 +4,8 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit
from tinygrad import Device, GlobalCounters
from tinygrad.helpers import CI, dtypes
from tinygrad import Device, GlobalCounters, dtypes
from tinygrad.helpers import CI
from tinygrad.shape.symbolic import Variable
from test.helpers import derandomize_model

View File

@ -2,8 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.helpers import dtypes
from tinygrad import Device, dtypes
N = 200 # has to be bigger than the cache to fail

View File

@ -4,7 +4,8 @@
import unittest
import numpy as np
from typing import Optional, Tuple
from tinygrad.helpers import prod, dtypes
from tinygrad.helpers import prod
from tinygrad.dtype import dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers

View File

@ -2,7 +2,8 @@ import unittest
import numpy as np
import torch
import operator
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, least_upper_float, temp, least_upper_dtype
from tinygrad.helpers import CI, getenv, DEBUG, OSX, temp
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List

View File

@ -4,7 +4,8 @@ from tinygrad import Tensor, dtypes, Device
import operator
import numpy as np
from hypothesis import given, strategies as st, settings
from tinygrad.helpers import CI, getenv, DType, OSX
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv, OSX
from tinygrad.ops import UnaryOps, get_lazyop_info
settings.register_profile("my_profile", max_examples=200, deadline=None)

View File

@ -1,7 +1,7 @@
import unittest
import numpy as np
from tinygrad import Device, dtypes, Tensor, Variable
from tinygrad.helpers import ImageDType
from tinygrad.dtype import ImageDType
from tinygrad.features.image import to_image_idx
@unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")

View File

@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor
# ruff: noqa: F401
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import dtypes
from tinygrad import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

View File

@ -12,7 +12,8 @@ from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, c
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.realize import run_schedule
from tinygrad.helpers import dtypes, prod
from tinygrad.helpers import prod
from tinygrad.dtype import dtypes
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
class TestLinearizer(unittest.TestCase):

View File

@ -2,13 +2,12 @@
import unittest
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import Opt, OptOps
from tinygrad import Device
from tinygrad import Device, dtypes
from tinygrad.helpers import OSX, CI
from test.external.fuzz_linearizer import run_linearizer
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, get_lazyop_info
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
inf, nan = float('inf'), float('nan')

View File

@ -5,8 +5,8 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad import Device
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad import Device, dtypes
if CI:
import warnings

View File

@ -3,9 +3,7 @@ import math
import unittest
import numpy as np
import torch
from tinygrad.tensor import Tensor
import tinygrad.nn as nn
from tinygrad.helpers import dtypes
from tinygrad import nn, dtypes, Tensor
from functools import partial
# https://gist.github.com/devries/11405101

View File

@ -7,10 +7,10 @@ from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.helpers import DEBUG, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad import nn
from tinygrad import nn, dtypes
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
seen = set()

View File

@ -1,7 +1,7 @@
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, dtypes
from tinygrad import Device
from tinygrad.helpers import CI
from tinygrad import Device, dtypes
# similar to test/external/external_test_gpu_ast.py, but universal
@unittest.skipIf(Device.DEFAULT == "CUDA" and CI, "slow on CUDA CI")

View File

@ -2,8 +2,8 @@ import numpy as np
import torch
import unittest, copy
import mmap
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import dtypes, temp
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import temp
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
x_init = np.random.randn(1,3).astype(np.float32)

View File

@ -2,7 +2,8 @@
from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.helpers import getenv
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python
import unittest
from tinygrad import dtypes
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes
class TestFlopCounter(unittest.TestCase):
def setUp(self):

View File

@ -1,7 +1,8 @@
import unittest
import numpy as np
from PIL import Image
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import Context, ContextVar, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten
from tinygrad.shape.symbolic import Variable, NumNode
VARIABLE = ContextVar("VARIABLE", 0)

View File

@ -1,6 +1,6 @@
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
from tinygrad.helpers import dtypes # noqa: F401
from tinygrad.dtype import dtypes # noqa: F401
from tinygrad.ops import GlobalCounters # noqa: F401
from tinygrad.device import Device # noqa: F401

View File

@ -3,7 +3,8 @@ import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, vars_from_ast
from tinygrad.device import Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up
from tinygrad.dtype import dtypes, ImageDType, DType
from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import View, strides_for_shape

View File

@ -5,7 +5,8 @@ from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name, flatten
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode

View File

@ -3,8 +3,8 @@ import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.helpers import DType, dtypes, ImageDType, diskcache_get, diskcache_put
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast, GlobalCounters

tinygrad/dtype.py (new file, 99 additions)
View File

@ -0,0 +1,99 @@
from typing import NamedTuple, Final, Optional, ClassVar, Set, Tuple, Dict
import numpy as np  # TODO: remove numpy
import functools

# TODO: migrate this from NamedTuple -> dataclass
class DType(NamedTuple):
  priority: int  # this determines when things get upcasted
  itemsize: int
  name: str
  np: Optional[type]  # TODO: someday this will be removed with the "remove numpy" project
  sz: int = 1
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
  def vec(self, sz:int):
    assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
    return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
  def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self

# dependent typing?
class ImageDType(DType):
  def __new__(cls, priority, itemsize, name, np, shape, base):
    return super().__new__(cls, priority, itemsize, name, np)
  def __init__(self, priority, itemsize, name, np, shape, base):
    self.shape: Tuple[int, ...] = shape  # arbitrary arg for the dtype, used in image for the shape
    self.base: DType = base
    super().__init__()
  def scalar(self): return self.base
  def vec(self, sz:int): return self.base.vec(sz)
  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
  # TODO: fix this to not need these
  def __hash__(self): return hash((super().__hash__(), self.shape))
  def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
  def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape

class PtrDType(DType):
  def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
  def __repr__(self): return f"ptr.{super().__repr__()}"

class dtypes:
  @staticmethod
  def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
  @staticmethod  # static methods on top, or bool in the type info will refer to dtypes.bool
  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
  @staticmethod
  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  @staticmethod
  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
  @staticmethod  # NOTE: isinstance(True, int) is True in python
  def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
  @staticmethod
  def fields() -> Dict[str, DType]: return DTYPES_DICT
  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
  int8: Final[DType] = DType(1, 1, "char", np.int8)
  uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
  int16: Final[DType] = DType(3, 2, "short", np.int16)
  uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
  int32: Final[DType] = DType(5, 4, "int", np.int32)
  uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
  int64: Final[DType] = DType(7, 8, "long", np.int64)
  uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
  float16: Final[DType] = DType(9, 2, "half", np.float16)
  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
  bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
  float32: Final[DType] = DType(11, 4, "float", np.float32)
  float64: Final[DType] = DType(12, 8, "double", np.float64)

  # dtype aliases
  half = float16; float = float32; double = float64  # noqa: E702
  uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64  # noqa: E702
  char = int8; short = int16; int = int32; long = int64  # noqa: E702

  # NOTE: these are image dtypes
  @staticmethod
  def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)
  @staticmethod
  def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)

  default_float: ClassVar[DType] = float32
  default_int: ClassVar[DType] = int32

# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
  dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float16, dtypes.bfloat16],
  dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
  dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }

@functools.lru_cache(None)
def _get_recursive_parents(dtype:DType) -> Set[DType]:
  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}

@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType:
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)

# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
  not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
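
Since tinygrad/dtype.py is the one genuinely new file in this commit, a small usage sketch may help; it assumes the tinygrad checkout at this commit and exercises only what is defined above:

```python
from tinygrad.dtype import dtypes, least_upper_dtype, ImageDType

# plain DTypes are NamedTuples carrying a priority, an item size, and a C-style name
assert dtypes.float32.itemsize == 4 and dtypes.float32.name == "float"

# vectorized dtypes round-trip back to their scalar base
f4 = dtypes.float32.vec(4)
assert f4.sz == 4 and f4.scalar() == dtypes.float32

# promotion follows the JAX-style lattice defined above
assert least_upper_dtype(dtypes.int32, dtypes.float16) == dtypes.float16
assert least_upper_dtype(dtypes.uint8, dtypes.int8) == dtypes.int16

# image dtypes carry a shape and use float32 as their scalar base
img = dtypes.imageh((1, 128, 4))
assert isinstance(img, ImageDType) and img.shape == (1, 128, 4) and img.scalar() == dtypes.float32
```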

View File

@ -1,5 +1,6 @@
from typing import Tuple
from tinygrad.helpers import prod, IMAGE, getenv, dtypes, DEBUG
from tinygrad.helpers import prod, IMAGE, getenv, DEBUG
from tinygrad.dtype import dtypes
# *** image Tensor function replacements ***

View File

@ -2,7 +2,8 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time, multiprocessing, traceback, signal
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.ops import MemBuffer, vars_from_ast
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.shape.symbolic import sym_infer
from collections import defaultdict

View File

@ -1,9 +1,8 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
import numpy as np
from urllib import request
from urllib import request # NOTE: this has to be imported specifically
from tqdm import tqdm
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Set, Sequence
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
@ -101,104 +100,6 @@ class Profiling(contextlib.ContextDecorator):
self.pr.disable()
pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac)
# **** tinygrad now supports dtypes! *****
# TODO: migrate this from NamedTuple -> dataclass
class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
sz: int = 1
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
# dependent typing?
class ImageDType(DType):
def __new__(cls, priority, itemsize, name, np, shape, base):
return super().__new__(cls, priority, itemsize, name, np)
def __init__(self, priority, itemsize, name, np, shape, base):
self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
self.base: DType = base
super().__init__()
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# TODO: fix this to not need these
def __hash__(self): return hash((super().__hash__(), self.shape))
def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape
class PtrDType(DType):
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
def __repr__(self): return f"ptr.{super().__repr__()}"
class dtypes:
@staticmethod
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
@staticmethod
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod # NOTE: isinstance(True, int) is True in python
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
int8: Final[DType] = DType(1, 1, "char", np.int8)
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
int16: Final[DType] = DType(3, 2, "short", np.int16)
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
int32: Final[DType] = DType(5, 4, "int", np.int32)
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
int64: Final[DType] = DType(7, 8, "long", np.int64)
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
float16: Final[DType] = DType(9, 2, "half", np.float16)
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
float32: Final[DType] = DType(11, 4, "float", np.float32)
float64: Final[DType] = DType(12, 8, "double", np.float64)
# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
char = int8; short = int16; int = int32; long = int64 # noqa: E702
# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float16, dtypes.bfloat16],
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
@functools.lru_cache(None)
def _get_recursive_parents(dtype:DType) -> Set[DType]:
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType:
return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
# *** universal database cache ***
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools, operator
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int, Context, GRAPH
from tinygrad.dtype import DType
from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH
from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker

View File

@ -2,7 +2,8 @@ from __future__ import annotations
import sys, math
import numpy as np
from typing import Union, Optional, Any, Tuple, List, Set, Dict
from tinygrad.helpers import prod, dtypes, DType, merge_dicts, flatten, getenv, dedup, ImageDType, DEBUG, all_int, all_same
from tinygrad.dtype import dtypes, DType, ImageDType
from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps
from tinygrad.ops import Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem, vars_from_ast
from tinygrad.shape.symbolic import sint, Variable

View File

@ -1,6 +1,7 @@
import math
from typing import Tuple, Optional
from tinygrad.helpers import argsort, DType
from tinygrad.helpers import argsort
from tinygrad.dtype import DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer

View File

@ -3,7 +3,8 @@ from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, CI, unwrap
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap
from tinygrad.shape.view import strides_for_shape
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64,

View File

@ -2,7 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, ClassVar
import functools
from enum import Enum, auto
from tinygrad.helpers import dtypes, prod, DType, dedup
from tinygrad.helpers import prod, dedup
from tinygrad.dtype import dtypes, DType
from tinygrad.shape.symbolic import Variable
from dataclasses import dataclass

View File

@ -3,7 +3,8 @@ import math, functools
from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, prod, DType, PtrDType, strip_parens, getenv
from tinygrad.helpers import prod, strip_parens, getenv
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
class CStyleLanguage(NamedTuple):
size_prefix: str = "int"

View File

@ -1,7 +1,7 @@
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
from llvmlite import ir
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import DType, PtrDType, dtypes
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf

View File

@ -1,7 +1,8 @@
from typing import List, Any, Dict, cast, Optional
import numpy as np
import Metal
from tinygrad.helpers import dtypes, dedup, unwrap2
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
from tinygrad.shape.symbolic import Variable

View File

@ -1,6 +1,7 @@
import os, mmap, _posixshmem
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType, OSX, dtypes
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import prod, OSX
from tinygrad.device import Interpreted, Allocator
from tinygrad.ops import Op, MovementOps, UnaryOps
from tinygrad.shape.view import strides_for_shape

View File

@ -2,7 +2,8 @@ from __future__ import annotations
from typing import Tuple, Optional, List
import ctypes, functools
import gpuctypes.opencl as cl
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, ImageDType, DEBUG
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
from tinygrad.dtype import ImageDType
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.device import Compiled, LRUAllocator

View File

@ -3,7 +3,8 @@ import numpy as np
from typing import Dict, Callable
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator
from tinygrad.helpers import getenv, dtypes, flatten
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, flatten
from tinygrad.runtime.ops_cpu import einsum_mulacc, reduce_axis
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))

View File

@ -7,7 +7,7 @@ from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from tinygrad.helpers import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.ops import LoadOps