fix IMAGE=2 failed with NOOPT=1 (#2209)

* IMAGE=2 failed with NOOPT=1

* fix it: make NOOPT a ContextVar and run required_optimizations() on the NOOPT path
This commit is contained in:
chenyu 2023-11-05 16:16:37 -05:00 committed by GitHub
parent 680cbfdba4
commit 719a97b337
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 5 deletions

View File

@ -4,7 +4,7 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes, Context, NOOPT
from tinygrad.ops import Device
if CI:
@ -754,6 +754,11 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
def test_simple_conv2d_noopt(self):
# Re-runs test_simple_conv2d inside a Context(NOOPT=1) block.
# useful with IMAGE enabled
with Context(NOOPT=1):
self.test_simple_conv2d()
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],

View File

@ -60,7 +60,7 @@ class ContextVar:
def __gt__(self, x): return self.value > x
def __lt__(self, x): return self.value < x
DEBUG, IMAGE, BEAM = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0)
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
class Timing(contextlib.ContextDecorator):

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import importlib, inspect, functools, pathlib, itertools, random, math, collections
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer
from dataclasses import dataclass
@ -312,7 +312,7 @@ class Compiled:
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
if not getenv("NOOPT"):
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1 and not vars_from_ast(ast):
lins = [(("tc" if used_tensor_cores else "hc"), k)]
@ -328,6 +328,8 @@ class Compiled:
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
return self.to_program(k)
if getenv("ENABLE_METHOD_CACHE", 1):

View File

@ -61,7 +61,7 @@ class CStyleLanguage(NamedTuple):
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes._float4, "images must be float4"
assert output_dtype == dtypes._float4, f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"