mirror of https://github.com/commaai/tinygrad.git
move print_diff to test/helpers (#7071)
This commit is contained in:
parent
1a45e94f5d
commit
09de958855
|
@ -1,17 +1,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import difflib, logging, traceback, subprocess
|
import traceback, subprocess
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from tinygrad.helpers import ContextVar, colored, getenv
|
from tinygrad.helpers import ContextVar, getenv
|
||||||
|
|
||||||
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
|
|
||||||
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
|
|
||||||
if unified:
|
|
||||||
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
|
|
||||||
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
|
|
||||||
else:
|
|
||||||
import ocdiff
|
|
||||||
diff = ocdiff.console_diff(str(s0), str(s1))
|
|
||||||
logging.info(diff)
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ProcessReplayContext:
|
class ProcessReplayContext:
|
||||||
|
|
|
@ -5,9 +5,10 @@ from typing import Callable, List, Tuple, Union, cast
|
||||||
from tinygrad.engine.schedule import full_ast_rewrite
|
from tinygrad.engine.schedule import full_ast_rewrite
|
||||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||||
from tinygrad.codegen.kernel import Kernel, Opt
|
from tinygrad.codegen.kernel import Kernel, Opt
|
||||||
from test.external.process_replay.helpers import ProcessReplayContext, print_diff
|
|
||||||
from tinygrad.ops import UOp
|
from tinygrad.ops import UOp
|
||||||
from tinygrad.renderer import Renderer
|
from tinygrad.renderer import Renderer
|
||||||
|
from test.helpers import print_diff
|
||||||
|
from test.external.process_replay.helpers import ProcessReplayContext
|
||||||
|
|
||||||
# *** process replay settings
|
# *** process replay settings
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
import sys, time
|
import sys, time, logging, difflib
|
||||||
from typing import Callable, Optional, Tuple, TypeVar
|
from typing import Callable, Optional, Tuple, TypeVar
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from test.external.process_replay.helpers import print_diff
|
|
||||||
from tinygrad import Tensor, Device, dtypes
|
from tinygrad import Tensor, Device, dtypes
|
||||||
from tinygrad.ops import UOp, UOps
|
from tinygrad.ops import UOp, UOps, sint
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.tensor import _to_np_dtype
|
from tinygrad.tensor import _to_np_dtype
|
||||||
from tinygrad.engine.realize import Runner
|
from tinygrad.engine.realize import Runner
|
||||||
from tinygrad.dtype import ConstType, DType
|
from tinygrad.dtype import ConstType, DType
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.helpers import Context, CI, OSX, getenv
|
from tinygrad.helpers import Context, CI, OSX, getenv, colored
|
||||||
from tinygrad.ops import sint
|
|
||||||
|
|
||||||
def derandomize_model(model):
|
def derandomize_model(model):
|
||||||
with Context(GRAPH=0):
|
with Context(GRAPH=0):
|
||||||
|
@ -56,8 +54,18 @@ def rand_for_dtype(dt:DType, size:int):
|
||||||
return np.random.choice([True, False], size=size)
|
return np.random.choice([True, False], size=size)
|
||||||
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
|
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
|
||||||
|
|
||||||
|
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
|
||||||
|
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||||
|
if unified:
|
||||||
|
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
|
||||||
|
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
|
||||||
|
else:
|
||||||
|
import ocdiff
|
||||||
|
diff = ocdiff.console_diff(str(s0), str(s1))
|
||||||
|
logging.info(diff)
|
||||||
|
|
||||||
def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
|
def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
|
||||||
if u1.key != u2.key:
|
if u1 is not u2:
|
||||||
print_diff(u1, u2)
|
print_diff(u1, u2)
|
||||||
raise AssertionError("uops aren't equal.")
|
raise AssertionError("uops aren't equal.")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue