move print_diff to test/helpers (#7071)

This commit is contained in:
qazal 2024-10-15 22:00:39 +03:00 committed by GitHub
parent 1a45e94f5d
commit 09de958855
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 19 deletions

View File

@ -1,17 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import difflib, logging, traceback, subprocess import traceback, subprocess
from typing import Dict, Optional from typing import Dict, Optional
from tinygrad.helpers import ContextVar, colored, getenv from tinygrad.helpers import ContextVar, getenv
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
if unified:
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
else:
import ocdiff
diff = ocdiff.console_diff(str(s0), str(s1))
logging.info(diff)
@dataclass(frozen=True) @dataclass(frozen=True)
class ProcessReplayContext: class ProcessReplayContext:

View File

@ -5,9 +5,10 @@ from typing import Callable, List, Tuple, Union, cast
from tinygrad.engine.schedule import full_ast_rewrite from tinygrad.engine.schedule import full_ast_rewrite
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.codegen.kernel import Kernel, Opt from tinygrad.codegen.kernel import Kernel, Opt
from test.external.process_replay.helpers import ProcessReplayContext, print_diff
from tinygrad.ops import UOp from tinygrad.ops import UOp
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from test.helpers import print_diff
from test.external.process_replay.helpers import ProcessReplayContext
# *** process replay settings # *** process replay settings

View File

@ -1,16 +1,14 @@
import sys, time import sys, time, logging, difflib
from typing import Callable, Optional, Tuple, TypeVar from typing import Callable, Optional, Tuple, TypeVar
import numpy as np import numpy as np
from test.external.process_replay.helpers import print_diff
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOp, UOps from tinygrad.ops import UOp, UOps, sint
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import _to_np_dtype from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.helpers import Context, CI, OSX, getenv from tinygrad.helpers import Context, CI, OSX, getenv, colored
from tinygrad.ops import sint
def derandomize_model(model): def derandomize_model(model):
with Context(GRAPH=0): with Context(GRAPH=0):
@ -56,8 +54,18 @@ def rand_for_dtype(dt:DType, size:int):
return np.random.choice([True, False], size=size) return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
if unified:
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
else:
import ocdiff
diff = ocdiff.console_diff(str(s0), str(s1))
logging.info(diff)
def assert_equiv_uops(u1:UOp, u2:UOp) -> None: def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
if u1.key != u2.key: if u1 is not u2:
print_diff(u1, u2) print_diff(u1, u2)
raise AssertionError("uops aren't equal.") raise AssertionError("uops aren't equal.")