tiny import cleanup and fix typo (#6692)

chenyu 2024-09-23 21:48:23 -04:00 committed by GitHub
parent 02c0c09fb9
commit 31b9c74c77
4 changed files with 6 additions and 8 deletions

@@ -606,7 +606,7 @@ def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=N
   return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
 
 def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
-  # this is tanh approamixated
+  # this is tanh approximated
   return (x + bias).gelu()
 
 def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
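
For context, the comment fixed above refers to the tanh approximation of GELU. Below is a minimal standalone sketch (illustrative only, not part of this commit; it uses only the Python standard library, and the helper names are made up for this example) comparing the tanh-approximated form against the exact erf-based GELU:

import math

def gelu_tanh(x: float) -> float:
  # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
  return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x: float) -> float:
  # exact GELU: x * Phi(x), where Phi is the standard normal CDF
  return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

if __name__ == "__main__":
  for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={v:+.1f}  tanh approx={gelu_tanh(v):+.6f}  exact={gelu_exact(v):+.6f}")

The two forms agree closely over typical activation ranges, which is what the "tanh approximated" comment is pointing at.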

@@ -3,20 +3,20 @@ import itertools, functools
 from dataclasses import dataclass
 from collections import defaultdict
 from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
+from enum import Enum, auto
 
 from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, \
   graph_rewrite, PatternMatcher
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import ImageDType, PtrDType
-from tinygrad.helpers import _CURRENT_KERNEL, all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, AMX, round_up, all_int, \
-  get_contraction, to_function_name, diskcache_put
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, get_contraction, to_function_name, diskcache_put
+from tinygrad.helpers import _CURRENT_KERNEL, DEBUG, TC_OPT, USE_TC, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
 from tinygrad.codegen.lowerer import ast_to_uop
-from enum import Enum, auto
 
 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702

@@ -2,8 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
-from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
-from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
+from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
   GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap

@@ -1,7 +1,6 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import dataclasses
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
 from collections import defaultdict