mirror of https://github.com/commaai/tinygrad.git
vgg7 (image upscaling) implementation - not the best, but it works (#255)

* vgg7 implementation - not the best, but it works
* VGG7 implementation: Spread nansbane to deter NaNs, maybe improved training experience
* VGG7 implementation: Fix training, for real this time - results actually attempt to approximate the input
* VGG7 implementation: Sample probability management
parent 81bf933a91 · commit 2653d33292
examples/vgg7.py
@@ -0,0 +1,251 @@
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.optim import SGD
import extra.waifu2x
from extra.kinne import KinneDir
import sys
import os
import random
import json
import numpy

# amount of context erased by model
CONTEXT = 7

def get_sample_count(samples_dir):
  try:
    samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
    v = samples_dir_count_file.readline()
    samples_dir_count_file.close()
    return int(v)
  except:
    return 0

def set_sample_count(samples_dir, sc):
  samples_dir_count_file = open(samples_dir + "/sample_count.txt", "w")
  samples_dir_count_file.write(str(sc) + "\n")
  samples_dir_count_file.close()

if len(sys.argv) < 2:
  print("python3 -m examples.vgg7 import MODELJSON MODELDIR")
  print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
  print(" into a directory of float binaries along with a meta.txt file containing tensor sizes")
  print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
  print(" *this format is used by all other commands in this program*")
  print("python3 -m examples.vgg7 execute MODELDIR IMG_IN IMG_OUT")
  print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
  print(" output image has 7 pixels removed on all edges")
  print(" do not run on large images, will have *hilarious* RAM use")
  print("python3 -m examples.vgg7 execute_full MODELDIR IMG_IN IMG_OUT")
  print(" does the 'whole thing' (padding, tiling)")
  print(" safe for large images, etc.")
  print("python3 -m examples.vgg7 new MODELDIR")
  print(" creates a new model (experimental)")
  print("python3 -m examples.vgg7 train MODELDIR SAMPLES_DIR ROUNDS ROUNDS_SAVE")
  print(" trains a model (experimental)")
  print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
  print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
  print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
  print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
  print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
  print("  my_samples/0b.png is the first original image)")
  print(" in addition, SAMPLES_DIR/sample_count.txt indicates sample count")
  print(" won't pad or tile, so keep image sizes sane")
  print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
  print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
  print(" maintains/creates sample_count.txt automatically")
  print(" unlike training, IMG_A must be exactly half the size of IMG_B")
  sys.exit(1)

cmd = sys.argv[1]
vgg7 = extra.waifu2x.Vgg7()

def nansbane(p):
  if numpy.isnan(numpy.min(p.data)):
    raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")

def load_and_save(path, save):
  if save:
    for v in vgg7.get_parameters():
      nansbane(v)
  kn = KinneDir(path, save)
  kn.parameters(vgg7.get_parameters())
  kn.close()
  if not save:
    for v in vgg7.get_parameters():
      nansbane(v)

if cmd == "import":
|
||||
src = sys.argv[2]
|
||||
model = sys.argv[3]
|
||||
|
||||
vgg7.load_waifu2x_json(json.load(open(src, "rb")))
|
||||
|
||||
os.mkdir(model)
|
||||
load_and_save(model, True)
|
||||
elif cmd == "execute":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
out_file = sys.argv[4]
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
extra.waifu2x.image_save(out_file, vgg7.forward(Tensor(extra.waifu2x.image_load(in_file))).data)
|
||||
elif cmd == "execute_full":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
out_file = sys.argv[4]
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
extra.waifu2x.image_save(out_file, vgg7.forward_tiled(extra.waifu2x.image_load(in_file), 156))
|
||||
elif cmd == "new":
|
||||
model = sys.argv[2]
|
||||
|
||||
os.mkdir(model)
|
||||
load_and_save(model, True)
|
||||
elif cmd == "train":
|
||||
model = sys.argv[2]
|
||||
samples_base = sys.argv[3]
|
||||
samples_count = get_sample_count(samples_base)
|
||||
rounds = int(sys.argv[4])
|
||||
rounds_per_save = int(sys.argv[5])
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
# Initialize sample probabilities.
|
||||
# This is used to try and get the network to focus on "interesting" samples,
|
||||
# which works nicely with the microsample system.
|
||||
sample_probs = None
|
||||
sample_probs_path = model + "/sample_probs.bin"
|
||||
try:
|
||||
# try to read...
|
||||
sample_probs = numpy.fromfile(sample_probs_path, "<f8")
|
||||
if sample_probs.shape[0] != samples_count:
|
||||
print("sample probs size != sample count - initializing")
|
||||
sample_probs = None
|
||||
except:
|
||||
# it's fine
|
||||
print("sample probs could not be loaded - initializing")
|
||||
pass
|
||||
|
||||
if sample_probs is None:
|
||||
# This stupidly high amount is used to force an initial pass over all samples
|
||||
sample_probs = numpy.ones(samples_count) * 1000
|
||||
|
||||
print("Training...")
|
||||
# Adam has a tendency to destroy the state of the network when restarted
|
||||
# Plus it's slower
|
||||
optim = SGD(vgg7.get_parameters())
|
||||
|
||||
rnum = 0
|
||||
while True:
|
||||
# The way the -1 option works is that rnum is never -1.
|
||||
if rnum == rounds:
|
||||
break
|
||||
|
||||
sample_idx = 0
|
||||
try:
|
||||
sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
|
||||
except:
|
||||
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
|
||||
sample_idx = random.randint(0, samples_count - 1)
|
||||
|
||||
x_img = extra.waifu2x.image_load(samples_base + "/" + str(sample_idx) + "a.png")
|
||||
y_img = extra.waifu2x.image_load(samples_base + "/" + str(sample_idx) + "b.png")
|
||||
|
||||
sample_x = Tensor(x_img, requires_grad = False)
|
||||
sample_y = Tensor(y_img, requires_grad = False)
|
||||
|
||||
# magic code roughly from readme example
|
||||
# An explaination, in case anyone else has to go down this path:
|
||||
# This runs the actual network normally
|
||||
out = vgg7.forward(sample_x)
|
||||
# Subtraction determines error here (as this is an image, not classification).
|
||||
# *Abs is the important bit* - at least for me, anyway.
|
||||
# The training process seeks to minimize this 'loss' value.
|
||||
# Minimization of loss *tends towards negative infinity*, so without the abs,
|
||||
# or without an implicit abs (the mul in the README),
|
||||
# loss will always go haywire in one direction or another.
|
||||
# Mean determines how errors are treated.
|
||||
# Do not use Sum. I tried that. It worked while I was using 1x1 patches...
|
||||
# Then it went exponential.
|
||||
# Also, Mean goes *after* abs. I realize this should have been obvious to me.
|
||||
loss = sample_y.sub(out).abs().mean()
|
||||
# This is the bit where tinygrad works backward from the loss
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
# And this updates the parameters
|
||||
optim.step()
|
||||
|
||||
# warning: used by sample probability adjuster
|
||||
loss_indicator = loss.max().data[0]
|
||||
print("Round " + str(rnum) + " : " + str(loss_indicator))
|
||||
|
||||
if (rnum % rounds_per_save) == 0:
|
||||
print("Saving")
|
||||
load_and_save(model, True)
|
||||
sample_probs.astype("<f8", "C").tofile(sample_probs_path)
|
||||
|
||||
# Update round state
|
||||
# Number
|
||||
rnum = rnum + 1
|
||||
# Probability management
|
||||
# there must always be a probability, no matter how slim, even if loss goes to 0
|
||||
sample_probs[sample_idx] = max(loss_indicator, 1.e-10)
|
||||
|
||||
# if we were told to save every round, we already saved
|
||||
if rounds_per_save != 1:
|
||||
print("Done with all rounds, saving")
|
||||
load_and_save(model, True)
|
||||
sample_probs.astype("<f8", "C").tofile(sample_probs_path)
|
||||
|
||||
elif cmd == "samplify":
|
||||
a_img = sys.argv[2]
|
||||
b_img = sys.argv[3]
|
||||
samples_base = sys.argv[4]
|
||||
sample_size = int(sys.argv[5])
|
||||
samples_count = get_sample_count(samples_base)
|
||||
|
||||
# This bit is interesting because it actually does some work.
|
||||
# Not much, but some work.
|
||||
a_img = extra.waifu2x.image_load(a_img)
|
||||
b_img = extra.waifu2x.image_load(b_img)
|
||||
|
||||
# as with the main library body,
|
||||
# Y X order is used here
|
||||
|
||||
# assertion before pre-upscaling is performed
|
||||
assert a_img.shape[2] == (b_img.shape[2] // 2)
|
||||
assert a_img.shape[3] == (b_img.shape[3] // 2)
|
||||
|
||||
# pre-upscaling - this matches the sizes (and coordinates)
|
||||
a_img = a_img.repeat(2, 2).repeat(2, 3)
|
||||
|
||||
samples_added = 0
|
||||
|
||||
# actual patch extraction
|
||||
for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
|
||||
for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
|
||||
# this is a viable patch location, add it
|
||||
# note the ranges here:
|
||||
# + there are always CONTEXT pixels *before* the point
|
||||
# + with no subtraction at the end, there'd already be a pixel *at* the point,
|
||||
# as ranges are exclusive
|
||||
# + additionally, there are sample_size - 1 additional sample pixels
|
||||
# + additionally, there are CONTEXT additional pixels
|
||||
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
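      # (a worked example of the above, assuming a hypothetical 64x64 b_img and SIZE=16:
      #  posy and posx each run over range(7, 42, 16) = 7, 23, 39, and each
      #  patch_x is 30x30 (16 + 2*7) while each patch_y is 16x16)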
      patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
      patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]

      extra.waifu2x.image_save(samples_base + "/" + str(samples_count) + "a.png", patch_x)
      extra.waifu2x.image_save(samples_base + "/" + str(samples_count) + "b.png", patch_y)
      samples_count += 1
      samples_added += 1

  print("Added " + str(samples_added) + " samples")
  set_sample_count(samples_base, samples_count)

else:
  print("unknown command")

extra/kinne.py
@@ -0,0 +1,76 @@
from tinygrad.tensor import Tensor
import numpy
import os

# Format Details:
# A KINNE parameter set is stored as a set of files named "snoop_bin_*.bin",
# where the * is a number starting at 0.
# Each file is simply raw little-endian floats,
# as readable by: numpy.fromfile(path, "<f4")
# and as writable by: t.data.astype("<f4", "C").tofile(path)
# This format is intended to be extremely simple to get into literally anything.
# It is not intended to be structural or efficient - reloading a network when
# unnecessary is inefficient anyway.
# Ultimately, the idea behind this is as a format that, while it will always
# require code to implement, requires as little code as possible, and therefore
# works as a suitable interchange for any situation.
# To add to the usability of the format, some informal metadata is provided,
# in "meta.txt", which provides human-readable shape information.
# This is intended to help with debugging other implementations of the network,
# by providing concrete human-readable information on tensor shapes.
# It is NOT meant to be read by machines.
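#
# A minimal round-trip sketch of the format described above, using only numpy
# ("weights" here is a hypothetical array standing in for a tensor's .data):
#   weights = numpy.ones((32, 3, 3, 3), "float32")
#   weights.astype("<f4", "C").tofile("models/abc/snoop_bin_0.bin")
#   restored = numpy.fromfile("models/abc/snoop_bin_0.bin", "<f4").reshape(weights.shape)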

class KinneDir:
  """
  A KinneDir is an intermediate object used to save or load a model.
  """

  def __init__(self, base: str, save: bool):
    """
    Opens a new KINNE directory with the given base path.
    If save is true, the directory is created if possible.
    (This does not create parents.)
    Save being true or false determines if tensors are loaded or saved.
    The base path is of the form "models/abc" - no trailing slash.
    It is important that if you wish to save in the current directory,
    you use ".", not the empty string.
    """
    if save:
      try:
        os.mkdir(base)
      except:
        # Silence the exception - the directory may already exist.
        pass
    self.base = base + "/snoop_bin_"
    self.next_part_index = 0
    self.save = save
    if save:
      self.metadata = open(base + "/meta.txt", "w")

  def parameter(self, t: Tensor):
    """
    parameter loads or saves a parameter, given as a tensor.
    """
    path = self.base + str(self.next_part_index) + ".bin"
    if self.save:
      t.data.astype("<f4", "C").tofile(path)
      self.metadata.write(str(self.next_part_index) + ": " + str(t.shape) + "\n")
    else:
      t.assign(Tensor(numpy.fromfile(path, "<f4")).reshape(shape=t.shape))
    self.next_part_index += 1

  def parameters(self, params):
    """
    parameters loads or saves a sequence of parameters.
    It's intended for easily attaching to an existing model,
    assuming that your parameters list orders are consistent.
    (In other words, usage with tinygrad.utils.get_parameters isn't advised -
    it's too 'implicit'.)
    """
    for t in params:
      self.parameter(t)

  def close(self):
    if self.save:
      self.metadata.close()
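
# Typical usage, mirroring examples/vgg7.py's load_and_save helper
# (save=True writes snoop_bin_*.bin plus meta.txt; save=False reads them back):
#   kn = KinneDir("models/abc", True)
#   kn.parameters(model.get_parameters())
#   kn.close()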

extra/waifu2x.py
@@ -0,0 +1,178 @@
# Implementation of waifu2x vgg7 in tinygrad.
# Obviously, not developed, supported, etc. by the original waifu2x author(s).

import numpy
from tinygrad.tensor import Tensor
from PIL import Image

# File Formats

# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
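# (e.g. the first vgg7 layer maps 3 channels to 32, so its weight tensor is
#  (32,3,3,3); a hypothetical 640x480 RGB input image is (1,3,480,640))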

def image_load(path) -> numpy.ndarray:
  """
  Loads an image in the shape expected by other functions in this module.
  Doesn't Tensor it, in case you need to do further work with it.
  """
  # file
  na = numpy.array(Image.open(path))
  # fix shape
  na = numpy.moveaxis(na, [2,0,1], [0,1,2])
  # shape is now (3,h,w), add 1
  na = na.reshape(1,3,na.shape[1],na.shape[2])
  # change type
  na = na.astype("float32") / 255.0
  return na

def image_save(path, na: numpy.ndarray):
  """
  Saves an image of the shape expected by other functions in this module.
  However, note this expects a numpy array.
  """
  # change type
  na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8")
  # shape is now (1,3,h,w), remove 1
  na = na.reshape(3,na.shape[2],na.shape[3])
  # fix shape
  na = numpy.moveaxis(na, [0,1,2], [2,0,1])
  # shape is now (h,w,3)
  # file
  Image.fromarray(na).save(path)

# The Model

class Conv3x3Biased:
  """
  A 3x3 convolution layer with some utility functions.
  """
  def __init__(self, inC, outC, last = False):
    # Massively overstate the weights to get them to be focused on,
    # since otherwise the biases overrule everything
    self.weight = Tensor.uniform(outC, inC, 3, 3) * 16.0
    # Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is?
    # Values-wise, entirely different blatant cheat.
    # In most cases, use uniform bias, but tiny.
    # For the last layer, use just 0.5, constant.
    if last:
      self.bias = Tensor.zeros(1, outC, 1, 1) + 0.5
    else:
      self.bias = Tensor.uniform(1, outC, 1, 1)

  def forward(self, x):
    # You might be thinking, "but what about padding?"
    # Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it.
    return x.conv2d(self.weight).add(self.bias)

  def get_parameters(self) -> list:
    return [self.weight, self.bias]

  def load_waifu2x_json(self, layer: dict):
    # Weights in this file are outChannel,inChannel,X,Y.
    # Not outChannel,inChannel,Y,X.
    # Therefore, transpose it before assignment.
    # I have long since forgotten how I worked this out.
    self.weight.assign(Tensor(layer["weight"]).reshape(shape=self.weight.shape).transpose(order=(0, 1, 3, 2)))
    self.bias.assign(Tensor(layer["bias"]).reshape(shape=self.bias.shape))

class Vgg7:
  """
  The 'vgg7' waifu2x network.
  Lower quality and slower than even upconv7 (never mind cunet), but is very easy to implement and test.
  """

  def __init__(self):
    self.conv1 = Conv3x3Biased(3, 32)
    self.conv2 = Conv3x3Biased(32, 32)
    self.conv3 = Conv3x3Biased(32, 64)
    self.conv4 = Conv3x3Biased(64, 64)
    self.conv5 = Conv3x3Biased(64, 128)
    self.conv6 = Conv3x3Biased(128, 128)
    self.conv7 = Conv3x3Biased(128, 3, True)

  def forward(self, x):
    """
    Forward pass: Actually runs the network.
    Input format: (1, 3, Y, X)
    Output format: (1, 3, Y - 14, X - 14)
    (the - 14 represents the 7-pixel context border that is lost)
    """
    x = self.conv1.forward(x).leakyrelu(0.1)
    x = self.conv2.forward(x).leakyrelu(0.1)
    x = self.conv3.forward(x).leakyrelu(0.1)
    x = self.conv4.forward(x).leakyrelu(0.1)
    x = self.conv5.forward(x).leakyrelu(0.1)
    x = self.conv6.forward(x).leakyrelu(0.1)
    x = self.conv7.forward(x)
    return x

  def get_parameters(self) -> list:
    return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()

  def load_waifu2x_json(self, data: list):
    """
    Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json
    data (passed in) is assumed to be the output of json.load or some similar on such a file
    """
    self.conv1.load_waifu2x_json(data[0])
    self.conv2.load_waifu2x_json(data[1])
    self.conv3.load_waifu2x_json(data[2])
    self.conv4.load_waifu2x_json(data[3])
    self.conv5.load_waifu2x_json(data[4])
    self.conv6.load_waifu2x_json(data[5])
    self.conv7.load_waifu2x_json(data[6])

  def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
    """
    Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
    Note that you really shouldn't try to run anything not (1, 3, *, *) through this.
    """
    # Constant that only really gets repeated a ton here.
    context = 7
    context2 = context + context
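    # (e.g. with the tile_size of 156 passed by examples/vgg7.py's execute_full,
    #  each 156x156 input tile yields a usable 142x142 output tile)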

    # Notably, numpy is used here because it makes this fine manipulation a lot simpler.
    # Scaling first - repeat on axis 2 and axis 3 (Y & X)
    image = image.repeat(2, 2).repeat(2, 3)

    # Resulting image buffer. This is made before the input is padded,
    # since the input has the padded shape right now.
    image_out = numpy.zeros(image.shape)

    # Padding next. Note that this padding is done on the whole image.
    # Padding the tiles would lose critical context, cause seams, etc.
    image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")

    # Now for tiling.
    # The output tile size is the usable output from an input tile (tile_size).
    # As such, the tiles overlap.
    out_tile_size = tile_size - context2
    for out_y in range(0, image_out.shape[2], out_tile_size):
      for out_x in range(0, image_out.shape[3], out_tile_size):
        # Input is sourced from the same coordinates, but some stuff ought to be
        # noted here for future reference:
        # + out_x/y's equivalent position w/ the padding is out_x + context.
        # + The output, however, is without context. Input needs context.
        # + Therefore, the input rectangle is expanded on all sides by context.
        # + Therefore, the input position has the context subtracted again.
        # + Therefore:
        in_y = out_y
        in_x = out_x
        # not shown: in_w/in_h = tile_size (as opposed to out_tile_size)
        # Extract tile.
        # Note that numpy will auto-crop this at the bottom-right.
        # This will never be a problem, as tiles are specifically chosen within the padded section.
        tile = image[:, :, in_y:in_y + tile_size, in_x:in_x + tile_size]
        # Extracted tile dimensions -> output dimensions
        # This is important because of said cropping, otherwise it'd be interior tile size.
        out_h = tile.shape[2] - context2
        out_w = tile.shape[3] - context2
        # Process tile.
        tile_t = Tensor(tile)
        tile_fwd_t = self.forward(tile_t)
        # Replace tile.
        image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.data

    return image_out