move globalcounters to ops (#2960)

* move globalcounters to ops

* missed a few

* sick of that failing
This commit is contained in:
George Hotz 2024-01-01 14:21:02 -08:00 committed by GitHub
parent 8291986959
commit c81ce9643d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 52 additions and 282 deletions

View File

@ -5,8 +5,7 @@ from tqdm import trange
from extra.models.efficientnet import EfficientNet
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import getenv
from tinygrad.jit import CacheCollector

View File

@ -2,7 +2,7 @@
import argparse
from tqdm import trange
import numpy as np
from tinygrad import Device
from tinygrad import Device, GlobalCounters
from typing import Optional, Union
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear, LayerNorm
@ -10,7 +10,7 @@ from tinygrad.shape.symbolic import Variable
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, dtypes
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")

View File

@ -17,9 +17,8 @@ from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from tinygrad import nn
from tinygrad.nn.state import get_state_dict
from tinygrad.nn import optim
from tinygrad import Device
from tinygrad import Device, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters
from tinygrad.shape.symbolic import Node
from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit

View File

@ -8,10 +8,9 @@ import sys, argparse, json
import numpy as np
np.set_printoptions(linewidth=200)
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes, colored
from tinygrad import Device
from tinygrad import Device, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad.helpers import GlobalCounters
from extra.models.llama import Transformer, convert_from_huggingface
from sentencepiece import SentencePieceProcessor

View File

@ -1,9 +1,9 @@
import time
from pathlib import Path
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad import Tensor, GlobalCounters
from tinygrad.jit import TinyJit
from tinygrad.helpers import getenv, dtypes, GlobalCounters
from tinygrad.helpers import getenv, dtypes
from examples.mlperf import helpers
def eval_resnet():

View File

@ -1,6 +1,6 @@
# load each model here, quick benchmark
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters, getenv
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import getenv
import numpy as np
def test_model(model, *inputs):

View File

@ -10,8 +10,8 @@ from PIL import Image
import numpy as np
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch, colored
from tinygrad import Device, GlobalCounters
from tinygrad.helpers import dtypes, Timing, Context, getenv, fetch, colored
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.jit import TinyJit

View File

@ -1,173 +0,0 @@
//#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//#define prec4 half4
//#define read_imagep read_imageh
// prec4/read_imagep select the element type used for image reads; the fp16
// variant above is kept (commented out) for experimentation, float4 is active.
#define prec4 float4
#define read_imagep read_imagef
/*float4 read_imagep(image2d_t data, sampler_t smp, int2 idx) {
  return read_imagef(data, smp, idx);
}*/

// Hand-tuned reduce kernel; the name encodes the loop nest (32 x 16 x 16 over
// work-items, 64 reduce iterations, 4x4x4 unroll). Each work-item accumulates
// four float4 outputs with vector mad() from texel reads of data1 and data2,
// then writes a 4-wide strip of data0.
// NOTE(review): the expected image layouts — data1 (32, 4096, 4) and
// data2 (16, 256, 4) — come from the comments below and from the driver
// script that loads this file; confirm against the caller before changing
// any stride constant.
__kernel void r_32_16_16_64_4_4_4(write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2) {
  // unnormalized, clamped, nearest sampler: plain integer-coordinate texel fetches
  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
  int idx0 = get_global_id(2); /* 32 */
  int idx1 = get_global_id(1); /* 16 */
  int idx2 = get_global_id(0); /* 16 */
  // four float4 accumulators: this work-item's 4 output texels
  float4 acc0 = 0.0f;
  float4 acc1 = 0.0f;
  float4 acc2 = 0.0f;
  float4 acc3 = 0.0f;
  // idx0 is a global
  // (32, 4096, 4)
  // four read cursors into data1, offset 64 columns apart within idx1's 256-wide band
  int2 imi0 = (int2)(((idx1*256)),idx0);
  int2 imi1 = (int2)(((idx1*256)+64),idx0);
  int2 imi2 = (int2)(((idx1*256)+128),idx0);
  int2 imi3 = (int2)(((idx1*256)+192),idx0);
  // idx2 is a local
  // (16, 256, 4)
  // four read cursors into data2, one per column of the 4x4 tile
  int2 imi4 = (int2)(0,idx2);
  int2 imi5 = (int2)(1,idx2);
  int2 imi6 = (int2)(2,idx2);
  int2 imi7 = (int2)(3,idx2);
  // 11% faster
  //for(;imi4.x < 256;) {
  //#pragma unroll(2)
  for (int ridx0 = 0; ridx0 < 64; ++ridx0) {
    prec4 val0 = read_imagep(data1, smp, imi0);
    prec4 val1 = read_imagep(data1, smp, imi1);
    prec4 val2 = read_imagep(data1, smp, imi2);
    prec4 val3 = read_imagep(data1, smp, imi3);
    prec4 val4 = read_imagep(data2, smp, imi4);
    prec4 val5 = read_imagep(data2, smp, imi5);
    prec4 val6 = read_imagep(data2, smp, imi6);
    prec4 val7 = read_imagep(data2, smp, imi7);
    // advance read cursors: data1 moves 1 texel/iteration, data2 moves 4
    // (one full group of val4..val7 columns)
    imi0.x += 1;
    imi1.x += 1;
    imi2.x += 1;
    imi3.x += 1;
    imi4.x += 4;
    imi5.x += 4;
    imi6.x += 4;
    imi7.x += 4;
    // NOTE(review): the original file kept a large commented-out block here —
    // a fully scalarized per-component form of the MADs below (and a disabled
    // read_mem_fence(CLK_LOCAL_MEM_FENCE)) — preserved as an experiment log.
    // The vector mad() form below is the active implementation.
    acc0 = mad(val0.x, val4, acc0);
    acc0 = mad(val0.y, val5, acc0);
    acc0 = mad(val0.z, val6, acc0);
    acc0 = mad(val0.w, val7, acc0);
    acc1 = mad(val1.x, val4, acc1);
    acc1 = mad(val1.y, val5, acc1);
    acc1 = mad(val1.z, val6, acc1);
    acc1 = mad(val1.w, val7, acc1);
    acc2 = mad(val2.x, val4, acc2);
    acc2 = mad(val2.y, val5, acc2);
    acc2 = mad(val2.z, val6, acc2);
    acc2 = mad(val2.w, val7, acc2);
    acc3 = mad(val3.x, val4, acc3);
    acc3 = mad(val3.y, val5, acc3);
    acc3 = mad(val3.z, val6, acc3);
    acc3 = mad(val3.w, val7, acc3);
    // NOTE(review): further commented-out variants (vector FMA via `*`/`+` and
    // the remaining scalarized y/z/w component expansions) were kept here in
    // the original; they are alternative spellings of the same accumulation.
  }
  // write the four accumulated texels, 16 columns apart within idx1's 64-wide band
  write_imagef(data0, (int2)(((idx1*64)+idx2),idx0), acc0); //(float4)((acc0).x,(acc0).y,(acc0).z,(acc0).w));
  write_imagef(data0, (int2)(((idx1*64)+idx2+16),idx0), acc1); //(float4)((acc1).x,(acc1).y,(acc1).z,(acc1).w));
  write_imagef(data0, (int2)(((idx1*64)+idx2+32),idx0), acc2); //(float4)((acc2).x,(acc2).y,(acc2).z,(acc2).w));
  write_imagef(data0, (int2)(((idx1*64)+idx2+48),idx0), acc3); //(float4)((acc3).x,(acc3).y,(acc3).z,(acc3).w));
}

View File

@ -1,53 +0,0 @@
"""
61: op Conv shape [(1, 256, 32, 64), (64, 256, 1, 1), (64,)] opt {'dilations': (1, 1), 'group': 1, 'kernel_shape': (1, 1), 'pads': (0, 0, 0, 0), 'strides': (1, 1)}
62: op Mul shape [(1, 64, 32, 64), (64, 1, 1)] opt {}
63: op Add shape [(1, 64, 32, 64), (1, 64, 32, 64)] opt {}
64: op Conv shape [(1, 64, 32, 64), (64, 1, 3, 3), (64,)] opt {'dilations': (1, 1), 'group': 64, 'kernel_shape': (3, 3), 'pads': (1, 1, 1, 1), 'strides': (1, 1)}
65: op Conv shape [(1, 64, 32, 64), (64, 1, 7, 7), (64,)] opt {'dilations': (1, 1), 'group': 64, 'kernel_shape': (7, 7), 'pads': (3, 3, 3, 3), 'strides': (1, 1)}
66: op Conv shape [(1, 64, 32, 64), (256, 64, 1, 1), (256,)] opt {'dilations': (1, 1), 'group': 1, 'kernel_shape': (1, 1), 'pads': (0, 0, 0, 0), 'strides': (1, 1)}
"""
import pathlib
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.realize import run_schedule
from tinygrad.helpers import partition, GlobalCounters, Context, getenv, prod, dtypes
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
from tinygrad.ops import LoadOps, ReduceOps
def single_kernel():
  """Benchmark the hand-written conv1_reorder.cl kernel in isolation.

  Allocates three image-typed CL buffers matching the kernel's expected
  layouts, runs the kernel five times, prints the raw timings and a derived
  throughput figure, then hard-exits the process.
  """
  # single kernel
  # image sizes: out (32,1024,4), input (32,4096,4), weights (16,256,4)
  sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
  out = CLBuffer(prod(sz1), dtypes.imageh(sz1))
  x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
  w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
  # compile the hand-tuned kernel source that lives next to this script
  old = CLProgram("r_32_16_16_64_4_4_4", open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read())
  # launch with global size (1,1,32), local size (16,16,1); *1e6 assumes the
  # timed call returns seconds (so old_tms is in microseconds) — TODO confirm
  old_tms = [old([1,1,32], [16,16,1], out, x, w, wait=True)*1e6 for _ in range(5)]
  # 67.107 is presumably the kernel's MFLOP count (2*32*16*16*64*4*4*4 ≈ 67.1e6),
  # making the printed figure GFLOPS — TODO confirm
  print(old_tms, 67.107/min(old_tms)*1e3)
  exit(0)  # NOTE: terminates the process; callers never regain control
# CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py
if __name__ == "__main__":
  #single_kernel()
  # this is stage 1 in fastvits
  # build the five-conv stage: 1x1, 3x3 depthwise, 7x7 depthwise, 1x1, 1x1
  c1 = Conv2d(256, 64, (1,1), bias=False)
  c2 = Conv2d(64, 64, (3,3), groups=64, padding=1, bias=False)
  c3 = Conv2d(64, 64, (7,7), groups=64, padding=3, bias=False)
  c4 = Conv2d(64, 256, (1,1), bias=False)
  c5 = Conv2d(256, 64, (1,1), bias=False)
  # TODO: the elementwise ops shouldn't rerun with normal realize
  x = Tensor.randn(1, 256, 32, 64)
  out = x.sequential([c1,c2,c3,c4,c5])
  schedule = out.lazydata.schedule()
  # split the schedule: keep only reduce (conv) kernels in `schedule`, move
  # loads and pure-elementwise items to `schedule_input`
  schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps and any(y.op in ReduceOps for y in x.ast.lazyops))
  # realize all the prep work plus every conv before the one under test
  run_schedule(schedule_input)
  run_schedule(schedule[:getenv("CONV")])
  print("*** init done ***")
  # instrument only the CONV-th conv kernel, with DEBUG/BEAM taken from env
  GlobalCounters.reset()
  with Context(DEBUG=getenv("LATEDEBUG", 2), BEAM=getenv("LATEBEAM")):
    run_schedule(schedule[getenv("CONV"):getenv("CONV")+1])

View File

@ -3,10 +3,9 @@ import os
import numpy as np
import time, torch, torch.mps
from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad import Device
from tinygrad import Device, GlobalCounters
from tinygrad.helpers import colored, getenv, CI, flat_mv
import os

View File

@ -4,7 +4,7 @@ from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.device import Buffer
from tinygrad.helpers import GlobalCounters
from tinygrad import GlobalCounters
def print_objects():
#gc.collect()

View File

@ -13,8 +13,8 @@ import onnx
from tqdm import tqdm
from typing import Tuple, List, Optional, Dict
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import dtypes, partition, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
from tinygrad.realize import run_schedule, lower_schedule_item
from tinygrad.ops import LoadOps, ScheduleItem
Device.DEFAULT = "GPU"

View File

@ -12,10 +12,9 @@ import numpy as np
import unittest
from tinygrad.tensor import Tensor, Device
from tinygrad import nn
from tinygrad import nn, GlobalCounters
from tinygrad.helpers import getenv
from tinygrad.nn import optim
from tinygrad.helpers import GlobalCounters
#from tinygrad.lazy import PUSH_PERMUTES
PUSH_PERMUTES = False
from tinygrad.jit import CacheCollector

View File

@ -9,8 +9,7 @@ torch.set_num_threads(1)
import time
import numpy as np
np.set_printoptions(linewidth=160)
from tinygrad import Device
from tinygrad.helpers import GlobalCounters
from tinygrad import Device, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI

View File

@ -155,6 +155,7 @@ class TestFetch(unittest.TestCase):
def test_fetch_small(self):
assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0)
@unittest.skip("test is flaky")
def test_fetch_img(self):
img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190", allow_caching=False)
with Image.open(img) as pimg:

View File

@ -2,7 +2,5 @@ from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
from tinygrad.helpers import dtypes # noqa: F401
# NOTE: these should not be relied on to be stable
from tinygrad.device import Device # noqa: F401
from tinygrad.helpers import GlobalCounters # noqa: F401
from tinygrad.ops import GlobalCounters # noqa: F401
from tinygrad.device import Device # noqa: F401

View File

@ -4,9 +4,9 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.helpers import DType, dtypes, ImageDType, diskcache_get, diskcache_put
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast, GlobalCounters
if TYPE_CHECKING:
from tinygrad.codegen.linearizer import Linearizer

View File

@ -1,8 +1,8 @@
import os, atexit
from typing import List, Any
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp, GlobalCounters
from tinygrad.device import Device
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, getenv
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.symbolic import NumNode

View File

@ -199,15 +199,6 @@ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
class GlobalCounters:
  # Process-wide counters kept as class attributes — no instances are made.
  # NOTE(review): updated by runtime code elsewhere in the project.
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0  # NOTE: this is not reset
  @staticmethod
  # zeroes every counter except mem_used (deliberately preserved, per the NOTE above)
  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
# *** universal database cache ***
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
@ -295,16 +286,6 @@ def flat_mv(mv:memoryview):
# *** Helpers for CUDA-like APIs.
def pretty_ptx(s):
  """Return a copy of PTX assembly text *s* with syntax coloring applied.

  Each regex wraps one token class (identifiers, types, instructions,
  numeric literals, and `.`-prefixed directives) in `colored(...)`; the
  surrounding text is untouched.
  """
  # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
  s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
  s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
  return s
def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes:
check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "<null>".encode(), 0, None, None))
status = compile_prog(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))

View File

@ -2,7 +2,8 @@ import os, json, pathlib, zipfile, pickle, tarfile, struct
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap
from tinygrad.ops import GlobalCounters
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, CI, unwrap
from tinygrad.shape.view import strides_for_shape
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64,

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, ClassVar
import functools
from enum import Enum, auto
from tinygrad.helpers import dtypes, prod, DType, dedup
@ -97,3 +97,14 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
@functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
return run_ast(ast)
# **************** global state Counters ****************
class GlobalCounters:
  """Process-wide performance counters, stored as class attributes.

  No instances are ever created; code reads and writes the class-level
  fields directly.
  """
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0  # NOTE: this is not reset
  @staticmethod
  def reset():
    """Zero every counter except mem_used (deliberately kept, per the NOTE)."""
    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    GlobalCounters.time_sum_s = 0.0
    GlobalCounters.kernel_count = 0

View File

@ -1,8 +1,8 @@
from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import prod, GlobalCounters, colored
from tinygrad.helpers import prod, colored
from tinygrad.shape.symbolic import Variable
# *** schedule running ***

View File

@ -1,13 +1,23 @@
from __future__ import annotations
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
from pathlib import Path
from typing import Tuple, Optional
import gpuctypes.cuda as cuda
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, pretty_ptx, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, colored, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import CUDARenderer
def pretty_ptx(s):
  """Return a copy of PTX assembly text *s* with syntax coloring applied.

  Each regex wraps one token class (identifiers, types, instructions,
  numeric literals, and `.`-prefixed directives) in `colored(...)`; the
  surrounding text is untouched.
  """
  # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
  s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
  s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
  return s
CUDACPU = getenv("CUDACPU") == 1
if CUDACPU:
gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))