2021-05-13 14:48:51 +08:00
|
|
|
import sys
|
|
|
|
import random
|
|
|
|
import json
|
|
|
|
import numpy
|
2023-08-31 01:41:08 +08:00
|
|
|
from pathlib import Path
|
2023-02-11 02:09:37 +08:00
|
|
|
from PIL import Image
|
|
|
|
from tinygrad.tensor import Tensor
|
|
|
|
from tinygrad.nn.optim import SGD
|
2023-10-19 00:20:11 +08:00
|
|
|
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
|
2023-07-02 03:04:28 +08:00
|
|
|
from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
|
2021-05-13 14:48:51 +08:00
|
|
|
|
|
|
|
# Amount of context, in pixels, that the model erases from every edge of its input.
|
|
|
|
CONTEXT = 7  # also the border width used when the "samplify" command cuts training patches
|
|
|
|
|
|
|
|
def get_sample_count(samples_dir):
    """Return the sample count recorded for ``samples_dir``.

    The count is stored as an integer on the first line of
    ``samples_dir + "/sample_count.txt"``.

    Args:
        samples_dir: Path of the samples directory, as a string.

    Returns:
        The recorded count, or 0 when the counter file is missing,
        unreadable, or does not start with a valid integer.
    """
    try:
        # `with` closes the handle even when reading or parsing fails.
        with open(samples_dir + "/sample_count.txt", "r") as count_file:
            return int(count_file.readline())
    except (OSError, ValueError):
        # Missing or malformed counter: treat the directory as empty.
        # Narrowed from a bare `except:` so Ctrl-C and SystemExit propagate.
        return 0
|
|
|
|
|
|
|
|
def set_sample_count(samples_dir, sc):
    """Record ``sc`` as the sample count of ``samples_dir``.

    Overwrites ``samples_dir + "/sample_count.txt"`` with the string form
    of ``sc`` followed by a newline.
    """
    counter_path = samples_dir + "/sample_count.txt"
    with open(counter_path, "w") as counter:
        # print() emits str(sc) plus the trailing newline.
        print(sc, file=counter)
|
2021-05-13 14:48:51 +08:00
|
|
|
|
|
|
|
# No sub-command given: print the usage text for every command and exit with an error status.
# The counter file is named "sample_count.txt" here, matching get_sample_count/set_sample_count.
if len(sys.argv) < 2:
    print("python3 -m examples.vgg7 import MODELJSON MODEL")
    print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
    print(" into a safetensors file")
    print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
    print(" *this format is used by most other commands in this program*")
    print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
    print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
    print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
    print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
    print(" output image has 7 pixels removed on all edges")
    print(" do not run on large images, will have *hilarious* RAM use")
    print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT")
    print(" does the 'whole thing' (padding, tiling)")
    print(" safe for large images, etc.")
    print("python3 -m examples.vgg7 new MODEL")
    print(" creates a new model (experimental)")
    print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE")
    print(" trains a model (experimental)")
    print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
    print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
    print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
    print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
    print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
    print(" my_samples/0b.png is the first original image)")
    print(" in addition, SAMPLES_DIR/sample_count.txt indicates sample count")
    print(" won't pad or tile, so keep image sizes sane")
    print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
    print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
    print(" maintains/creates sample_count.txt automatically")
    print(" unlike training, IMG_A must be exactly half the size of IMG_B")
    sys.exit(1)
|
|
|
|
|
|
|
|
# The first CLI argument names the sub-command that the if/elif chain below dispatches on.
cmd = sys.argv[1]
|
2023-02-11 02:09:37 +08:00
|
|
|
# Single module-level model instance; load_and_save() and every sub-command operate on it.
vgg7 = Vgg7()
|
2021-05-13 14:48:51 +08:00
|
|
|
|
|
|
|
def nansbane(p):
    """Raise if tensor ``p`` contains a NaN.

    The tensor is converted with ``p.numpy()`` and only its minimum is
    inspected; numpy's ``min`` propagates NaN, so a NaN minimum means the
    tensor holds at least one NaN.

    Raises:
        Exception: when a NaN is detected.
    """
    lowest = numpy.min(p.numpy())
    if not numpy.isnan(lowest):
        return
    raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
|
|
|
|
|
|
|
|
def load_and_save(path, save):
    """Save the module-level ``vgg7`` model to ``path``, or load it from there.

    Args:
        path: Location of the safetensors file.
        save: When true, write the model's state dict to ``path``;
            otherwise read ``path`` and load it into the model.

    Every parameter goes through ``nansbane``: before writing when saving,
    and right after loading otherwise.
    """
    if not save:
        load_state_dict(vgg7, safe_load(path))

    # Runs after a load and before a save, so the ordering matches both modes.
    for param in vgg7.get_parameters():
        nansbane(param)

    if save:
        safe_save(get_state_dict(vgg7), path)
|
|
|
|
|
|
|
|
# Sub-command dispatch: `cmd` selects one branch, and each branch reads its own
# positional arguments from sys.argv[2:] in the order shown by the usage text.
if cmd == "import":
    src = sys.argv[2]
    model = sys.argv[3]

    # Close the JSON file deterministically instead of leaking the handle.
    with open(src, "rb") as json_file:
        vgg7.load_waifu2x_json(json.load(json_file))

    load_and_save(model, True)
elif cmd == "import_kinne":
    # tinygrad wasn't doing safetensors when this example was written
    # it's possible someone might have a model around using the resulting interim format
    src = sys.argv[2]
    model = sys.argv[3]

    # One raw little-endian float32 file per parameter, numbered in parameter order.
    for index, t in enumerate(vgg7.get_parameters()):
        fn = src + "/snoop_bin_" + str(index) + ".bin"
        t.assign(Tensor(numpy.fromfile(fn, "<f4")).reshape(shape=t.shape))

    load_and_save(model, True)
elif cmd == "execute":
    model = sys.argv[2]
    in_file = sys.argv[3]
    out_file = sys.argv[4]

    load_and_save(model, False)

    image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
elif cmd == "execute_full":
    model = sys.argv[2]
    in_file = sys.argv[3]
    out_file = sys.argv[4]

    load_and_save(model, False)

    image_save(out_file, vgg7.forward_tiled(image_load(in_file), 156))
elif cmd == "new":
    model = sys.argv[2]

    load_and_save(model, True)
elif cmd == "train":
    model = sys.argv[2]
    samples_base = sys.argv[3]
    samples_count = get_sample_count(samples_base)
    rounds = int(sys.argv[4])
    rounds_per_save = int(sys.argv[5])

    load_and_save(model, False)

    # Initialize sample probabilities.
    # This is used to try and get the network to focus on "interesting" samples,
    # which works nicely with the microsample system.
    sample_probs = None
    sample_probs_path = model + "_sample_probs.bin"
    try:
        # try to read...
        sample_probs = numpy.fromfile(sample_probs_path, "<f8")
        if sample_probs.shape[0] != samples_count:
            print("sample probs size != sample count - initializing")
            sample_probs = None
    except Exception:
        # it's fine (best-effort load; Ctrl-C / SystemExit are no longer swallowed)
        print("sample probs could not be loaded - initializing")

    if sample_probs is None:
        # This stupidly high amount is used to force an initial pass over all samples
        sample_probs = numpy.ones(samples_count) * 1000

    print("Training...")
    # Adam has a tendency to destroy the state of the network when restarted
    # Plus it's slower
    optim = SGD(vgg7.get_parameters())

    rnum = 0
    # The way the -1 option works is that rnum is never -1.
    while rnum != rounds:
        try:
            sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
        except Exception:
            # Fall back to a uniform pick. Only Exception is caught here, so a
            # Ctrl-C during an endless training run is not mistaken for a bad distribution.
            print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
            sample_idx = random.randint(0, samples_count - 1)

        x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
        y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")

        sample_x = Tensor(x_img, requires_grad = False)
        sample_y = Tensor(y_img, requires_grad = False)

        # magic code roughly from readme example
        # An explanation, in case anyone else has to go down this path:
        # This runs the actual network normally
        out = vgg7.forward(sample_x)
        # Subtraction determines error here (as this is an image, not classification).
        # *Abs is the important bit* - at least for me, anyway.
        # The training process seeks to minimize this 'loss' value.
        # Minimization of loss *tends towards negative infinity*, so without the abs,
        # or without an implicit abs (the mul in the README),
        # loss will always go haywire in one direction or another.
        # Mean determines how errors are treated.
        # Do not use Sum. I tried that. It worked while I was using 1x1 patches...
        # Then it went exponential.
        # Also, Mean goes *after* abs. I realize this should have been obvious to me.
        loss = sample_y.sub(out).abs().mean()
        # This is the bit where tinygrad works backward from the loss
        optim.zero_grad()
        loss.backward()
        # And this updates the parameters
        optim.step()

        # warning: used by sample probability adjuster
        loss_indicator = loss.max().numpy()
        print("Round " + str(rnum) + " : " + str(loss_indicator))

        # ROUNDS_SAVE == 0 would divide by zero in the modulo; in that case
        # skip the periodic saves and rely on the final save below.
        if rounds_per_save != 0 and (rnum % rounds_per_save) == 0:
            print("Saving")
            load_and_save(model, True)
            sample_probs.astype("<f8", "C").tofile(sample_probs_path)

        # Update round state
        # Number
        rnum = rnum + 1
        # Probability management
        # there must always be a probability, no matter how slim, even if loss goes to 0
        sample_probs[sample_idx] = max(loss_indicator, 1.e-10)

    # if we were told to save every round, we already saved
    if rounds_per_save != 1:
        print("Done with all rounds, saving")
        load_and_save(model, True)
        sample_probs.astype("<f8", "C").tofile(sample_probs_path)

elif cmd == "samplify":
    a_img = sys.argv[2]
    b_img = sys.argv[3]
    samples_base = sys.argv[4]
    sample_size = int(sys.argv[5])
    samples_count = get_sample_count(samples_base)

    # This bit is interesting because it actually does some work.
    # Not much, but some work.
    a_img = image_load(a_img)
    b_img = image_load(b_img)

    # as with the main library body,
    # Y X order is used here

    # size check before pre-upscaling is performed
    # (raised explicitly rather than asserted, so it still runs under `python -O`)
    if a_img.shape[2] != (b_img.shape[2] // 2) or a_img.shape[3] != (b_img.shape[3] // 2):
        raise ValueError("IMG_A must be exactly half the size of IMG_B")

    # pre-upscaling - this matches the sizes (and coordinates)
    a_img = a_img.repeat(2, 2).repeat(2, 3)

    samples_added = 0

    # actual patch extraction
    for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
        for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
            # this is a viable patch location, add it
            # note the ranges here:
            #  + there are always CONTEXT pixels *before* the point
            #  + with no subtraction at the end, there'd already be a pixel *at* the point,
            #     as ranges are exclusive
            #  + additionally, there are sample_size - 1 additional sample pixels
            #  + additionally, there are CONTEXT additional pixels
            #  + therefore there are CONTEXT + sample_size pixels *at & after* the point
            patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
            patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]

            image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
            image_save(f"{samples_base}/{str(samples_count)}b.png", patch_y)
            samples_count += 1
            samples_added += 1

    print(f"Added {str(samples_added)} samples")
    set_sample_count(samples_base, samples_count)

else:
    print("unknown command")
|