From 09de958855e7b08647d2883934d5bdc910e89ace Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:00:39 +0300 Subject: [PATCH] move print_diff to test/helpers (#7071) --- test/external/process_replay/helpers.py | 14 ++----------- .../external/process_replay/process_replay.py | 3 ++- test/helpers.py | 20 +++++++++++++------ 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/test/external/process_replay/helpers.py b/test/external/process_replay/helpers.py index b8b9880c..ca7a2d59 100644 --- a/test/external/process_replay/helpers.py +++ b/test/external/process_replay/helpers.py @@ -1,17 +1,7 @@ from dataclasses import dataclass -import difflib, logging, traceback, subprocess +import traceback, subprocess from typing import Dict, Optional -from tinygrad.helpers import ContextVar, colored, getenv - -def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)): - if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s") - if unified: - lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines())) - diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines) - else: - import ocdiff - diff = ocdiff.console_diff(str(s0), str(s1)) - logging.info(diff) +from tinygrad.helpers import ContextVar, getenv @dataclass(frozen=True) class ProcessReplayContext: diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 85c753db..2f59c7fc 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -5,9 +5,10 @@ from typing import Callable, List, Tuple, Union, cast from tinygrad.engine.schedule import full_ast_rewrite from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.codegen.kernel import Kernel, Opt -from test.external.process_replay.helpers import ProcessReplayContext, print_diff from tinygrad.ops import UOp from tinygrad.renderer import Renderer +from test.helpers import print_diff +from test.external.process_replay.helpers import ProcessReplayContext # *** process replay settings diff --git a/test/helpers.py b/test/helpers.py index ac566493..2e0c92fb 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,16 +1,14 @@ -import sys, time +import sys, time, logging, difflib from typing import Callable, Optional, Tuple, TypeVar import numpy as np -from test.external.process_replay.helpers import print_diff from tinygrad import Tensor, Device, dtypes -from tinygrad.ops import UOp, UOps +from tinygrad.ops import UOp, UOps, sint from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.tensor import _to_np_dtype from tinygrad.engine.realize import Runner from tinygrad.dtype import ConstType, DType from tinygrad.nn.state import get_parameters -from tinygrad.helpers import Context, CI, OSX, getenv -from tinygrad.ops import sint +from tinygrad.helpers import Context, CI, OSX, getenv, colored def derandomize_model(model): with Context(GRAPH=0): @@ -56,8 +54,18 @@ def rand_for_dtype(dt:DType, size:int): return np.random.choice([True, False], size=size) return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) +def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)): + if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s") + if unified: + lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines())) + diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines) + else: + import ocdiff + diff = ocdiff.console_diff(str(s0), str(s1)) + logging.info(diff) + def assert_equiv_uops(u1:UOp, u2:UOp) -> None: - if u1.key != u2.key: + if u1 is not u2: print_diff(u1, u2) raise AssertionError("uops aren't equal.")