mirror of https://github.com/commaai/tinygrad.git
tiny import cleanup and fix typo (#6692)
This commit is contained in:
parent
02c0c09fb9
commit
31b9c74c77
|
@ -606,7 +606,7 @@ def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=N
|
|||
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
|
||||
|
||||
def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
|
||||
# this is tanh approamixated
|
||||
# this is tanh approximated
|
||||
return (x + bias).gelu()
|
||||
|
||||
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
|
||||
|
|
|
@ -3,20 +3,20 @@ import itertools, functools
|
|||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, \
|
||||
graph_rewrite, PatternMatcher
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
from tinygrad.helpers import _CURRENT_KERNEL, all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, AMX, round_up, all_int, \
|
||||
get_contraction, to_function_name, diskcache_put
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, get_contraction, to_function_name, diskcache_put
|
||||
from tinygrad.helpers import _CURRENT_KERNEL, DEBUG, TC_OPT, USE_TC, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||
from tinygrad.codegen.lowerer import ast_to_uop
|
||||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
|
|
|
@ -2,8 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
|
|||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
|
||||
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
|
||||
from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
|
||||
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite
|
||||
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
|
||||
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
|
||||
from collections import defaultdict
|
||||
|
|
Loading…
Reference in New Issue