Initial commit for SeaDiff project code

228  utils/RGBuvHistBlock.py  Normal file
@@ -0,0 +1,228 @@
"""
##### Copyright 2021 Mahmoud Afifi.

If you find this code useful, please cite our paper:

Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.

@inproceedings{afifi2021histogan,
  title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
  Color Histograms},
  author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
  booktitle={CVPR},
  year={2021}
}
####
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

EPS = 1e-6


class RGBuvHistBlock(nn.Module):
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               hist_boundary=None, green_only=False, device='cuda'):
    """ Computes the RGB-uv histogram feature of a given image.
    Args:
      h: histogram dimension size (scalar). The default value is 64.
      insz: maximum size of the input image; if it is larger than this size,
        the image will be resized (scalar). Default value is 150 (i.e.,
        150 x 150 pixels).
      resizing: resizing method if applicable. Options are: 'interpolation' or
        'sampling'. Default is 'interpolation'.
      method: the method used to count the number of pixels for each bin in
        the histogram feature. Options are: 'thresholding', 'RBF' (radial
        basis function), or 'inverse-quadratic'. Default value is
        'inverse-quadratic'.
      sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is
        the sigma parameter of the kernel function. The default value is 0.02.
      intensity_scale: boolean variable to use the intensity scale (I_y in
        Equation 2). Default value is True.
      hist_boundary: a list of histogram boundary values. Default is [-3, 3].
      green_only: boolean variable to use only the log(g/r), log(g/b)
        channels. Default is False.

    Methods:
      forward: accepts an input image and returns its histogram feature. Note
        that unless the method is 'thresholding', this is a differentiable
        function and can be easily integrated with the loss function. As
        mentioned in the paper, the 'inverse-quadratic' kernel was found more
        stable than 'RBF' in our training.
    """
    super(RGBuvHistBlock, self).__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    self.green_only = green_only
    if hist_boundary is None:
      hist_boundary = [-3, 3]
    hist_boundary.sort()
    self.hist_boundary = hist_boundary
    if self.method == 'thresholding':
      self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    hists = torch.zeros((x_sampled.shape[0], 1 + int(not self.green_only) * 2,
                         self.h, self.h)).to(device=self.device)
    for l in range(L):
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      if self.intensity_scale:
        # intensity scale I_y = sqrt(R^2 + G^2 + B^2) (Equation 2 in the paper)
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1
      if not self.green_only:
        # log-chromaticity coordinates for the R channel: u = log(R/G), v = log(R/B)
        Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] +
                                                                   EPS), dim=1)
        Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] +
                                                                   EPS), dim=1)
        diff_u0 = abs(
          Iu0 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        diff_v0 = abs(
          Iv0 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        if self.method == 'thresholding':
          diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
          diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
        elif self.method == 'RBF':
          diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u0 = torch.exp(-diff_u0)  # Radial basis function
          diff_v0 = torch.exp(-diff_v0)
        elif self.method == 'inverse-quadratic':
          diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
          diff_v0 = 1 / (1 + diff_v0)
        else:
          raise Exception(
            f'Wrong kernel method. It should be either thresholding, RBF,'
            f' or inverse-quadratic. But the given value is {self.method}.')
        diff_u0 = diff_u0.type(torch.float32)
        diff_v0 = diff_v0.type(torch.float32)
        a = torch.t(Iy * diff_u0)
        hists[l, 0, :, :] = torch.mm(a, diff_v0)

      # G channel: u = log(G/R), v = log(G/B)
      Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u1 = abs(
        Iu1 - torch.unsqueeze(torch.tensor(np.linspace(
          self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
          dim=0).to(self.device))
      diff_v1 = abs(
        Iv1 - torch.unsqueeze(torch.tensor(np.linspace(
          self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
          dim=0).to(self.device))

      if self.method == 'thresholding':
        diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
        diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = torch.exp(-diff_u1)  # Gaussian
        diff_v1 = torch.exp(-diff_v1)
      elif self.method == 'inverse-quadratic':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
        diff_v1 = 1 / (1 + diff_v1)

      diff_u1 = diff_u1.type(torch.float32)
      diff_v1 = diff_v1.type(torch.float32)
      a = torch.t(Iy * diff_u1)
      if not self.green_only:
        hists[l, 1, :, :] = torch.mm(a, diff_v1)
      else:
        hists[l, 0, :, :] = torch.mm(a, diff_v1)

      if not self.green_only:
        # B channel: u = log(B/R), v = log(B/G)
        Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] +
                                                                   EPS), dim=1)
        Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] +
                                                                   EPS), dim=1)
        diff_u2 = abs(
          Iu2 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        diff_v2 = abs(
          Iv2 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        if self.method == 'thresholding':
          diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
          diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
        elif self.method == 'RBF':
          diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u2 = torch.exp(-diff_u2)  # Gaussian
          diff_v2 = torch.exp(-diff_v2)
        elif self.method == 'inverse-quadratic':
          diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
          diff_v2 = 1 / (1 + diff_v2)
        diff_u2 = diff_u2.type(torch.float32)
        diff_v2 = diff_v2.type(torch.float32)
        a = torch.t(Iy * diff_u2)
        hists[l, 2, :, :] = torch.mm(a, diff_v2)

    # normalization
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized
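As the docstring notes, the histogram is differentiable for the 'RBF' and 'inverse-quadratic' kernels, so it can sit directly inside a training loss. Below is a minimal sketch of a color-matching loss built on the block, following the Hellinger-style histogram distance used in the HistoGAN paper; the names generated, target, and model_output are hypothetical placeholders for image batches in [0, 1]:

import torch
from RGBuvHistBlock import RGBuvHistBlock

# h x h soft histogram over log-ratio chromaticity coordinates, weighted by the
# intensity Iy = sqrt(R^2 + G^2 + B^2), one plane per color channel.
hist_block = RGBuvHistBlock(h=64, insz=150, method='inverse-quadratic',
                            sigma=0.02, device='cuda')


def histogram_loss(generated, target):
  # generated, target: (N, 3, H, W) tensors in [0, 1], on the block's device.
  h_g = hist_block(generated)
  h_t = hist_block(target)
  # Hellinger-style distance between the two normalized histograms, averaged
  # over the batch; a small epsilon inside the square roots can help gradient
  # stability at empty bins.
  return (1 / 2 ** 0.5) * torch.sqrt(
      torch.sum((torch.sqrt(h_g) - torch.sqrt(h_t)) ** 2)) / h_g.shape[0]

# loss = histogram_loss(model_output, reference_images)
# loss.backward()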

BIN  utils/__pycache__/RGBuvHistBlock.cpython-39.pyc  Normal file  (Binary file not shown.)
BIN  utils/__pycache__/perceptual_loss.cpython-38.pyc  Normal file  (Binary file not shown.)
BIN  utils/__pycache__/perceptual_loss.cpython-39.pyc  Normal file  (Binary file not shown.)
BIN  utils/__pycache__/ssim_loss.cpython-38.pyc  Normal file  (Binary file not shown.)
BIN  utils/__pycache__/ssim_loss.cpython-39.pyc  Normal file  (Binary file not shown.)
BIN  utils/__pycache__/utils.cpython-38.pyc  Normal file  (Binary file not shown.)
BIN  utils/__pycache__/utils.cpython-39.pyc  Normal file  (Binary file not shown.)

50  utils/create_hist_sample.py  Normal file
@@ -0,0 +1,50 @@
"""
If you find this code useful, please cite our paper:

Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.

@inproceedings{afifi2021histogan,
  title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
  Color Histograms},
  author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
  booktitle={CVPR},
  year={2021}
}
"""
import os
from RGBuvHistBlock import RGBuvHistBlock
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from os.path import splitext, join, basename, exists


image_folder = ''   # path to the folder of input images
output_folder = ''  # path to the folder where histogram visualizations are saved
if not exists(output_folder):
  os.mkdir(output_folder)

torch.cuda.set_device(0)
histblock = RGBuvHistBlock(insz=336, h=336,
                           resizing='sampling',
                           method='inverse-quadratic',
                           sigma=0.02,
                           device=torch.cuda.current_device())
transform = transforms.Compose([transforms.Resize((336, 336)),
                                transforms.ToTensor()])

image_names = os.listdir(image_folder)
for filename in image_names:
  print(filename)
  img_hist = Image.open(os.path.join(image_folder, filename))
  img_hist = torch.unsqueeze(transform(img_hist), dim=0).to(
    device=torch.cuda.current_device())
  histogram = histblock(img_hist)
  # histogram = histogram.cpu().numpy()
  save_image(histogram * 255, os.path.join(output_folder, filename))
  # np.save(join(output_folder, basename(splitext(filename)[0]) + '.npy'), histogram)

BIN  utils/font/simhei.ttf  Normal file  (Binary file not shown.)
BIN  utils/font/times.ttf  Normal file  (Binary file not shown.)
BIN  utils/font/timesbd.ttf  Normal file  (Binary file not shown.)
BIN  utils/font/timesbi.ttf  Normal file  (Binary file not shown.)
BIN  utils/font/timesi.ttf  Normal file  (Binary file not shown.)
BIN  utils/font/方正仿宋_GBK.TTF  Normal file  (Binary file not shown.)
BIN  utils/font/楷体_GB2312.ttf  Normal file  (Binary file not shown.)
BIN  utils/font/青鸟华光简琥珀.ttf  Normal file  (Binary file not shown.)

40  utils/perceptual_loss.py  Normal file
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import vgg16


# --- Perceptual loss network --- #
class PerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        vgg_model = vgg16(pretrained=True).features[:16]
        vgg_model = vgg_model.cuda()
        # vgg_model = nn.DataParallel(vgg_model, device_ids=device_ids)
        for param in vgg_model.parameters():
            param.requires_grad = False
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, pred_im, gt):
        loss = []
        pred_im_features = self.output_features(pred_im)
        gt_features = self.output_features(gt)
        for pred_im_feature, gt_feature in zip(pred_im_features, gt_features):
            loss.append(F.mse_loss(pred_im_feature, gt_feature))

        return sum(loss) / len(loss)
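A minimal usage sketch for the VGG-16 feature loss above, assuming a CUDA device is available and the torchvision VGG-16 weights can be downloaded; pred and gt are hypothetical (N, 3, H, W) batches. The module feeds images to VGG-16 as-is, so any input normalization is left to the caller:

import torch
from perceptual_loss import PerceptualLoss

perceptual = PerceptualLoss()             # builds frozen VGG-16 features on the GPU

pred = torch.rand(2, 3, 256, 256).cuda()  # hypothetical network output in [0, 1]
gt = torch.rand(2, 3, 256, 256).cuda()    # hypothetical ground truth in [0, 1]

# mean MSE over the relu1_2, relu2_2, and relu3_3 activations of the two batches
loss_p = perceptual(pred, gt)
# total_loss = l1_loss + 0.04 * loss_p    # example weighting only; project-specific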

67  utils/ssim_loss.py  Normal file
@@ -0,0 +1,67 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIMLoss(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, loss_weight=1.0):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)
        self.loss_weight = loss_weight

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel
        loss = self.loss_weight * (1 - _ssim(img1,
                                             img2,
                                             window,
                                             self.window_size,
                                             channel,
                                             self.size_average))
        return loss
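A minimal usage sketch for the SSIM loss above; the C1/C2 constants assume inputs scaled to [0, 1], and the tensors here are hypothetical placeholders:

import torch
from ssim_loss import SSIMLoss

criterion = SSIMLoss(window_size=11, size_average=True, loss_weight=1.0)

pred = torch.rand(2, 3, 128, 128, requires_grad=True)  # hypothetical restored batch in [0, 1]
gt = torch.rand(2, 3, 128, 128)                        # hypothetical reference batch in [0, 1]

loss = criterion(pred, gt)  # loss_weight * (1 - SSIM); approaches 0 as the images match
loss.backward()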

89  utils/utils.py  Normal file
@@ -0,0 +1,89 @@
import os
import torch
import numpy as np
from PIL import Image, ImageFilter
from torchvision.transforms import ToTensor


def my_save_image(name, image_np, output_path=""):
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    p = np_to_pil(image_np)
    p.save(output_path + "{}".format(name))


def pil_to_np(img_PIL, with_transpose=True):
    """
    Converts image in PIL format to np.array.

    From H x W x C [0...255] to C x H x W [0..1]
    """
    ar = np.array(img_PIL)
    if len(ar.shape) == 3 and ar.shape[-1] == 4:
        ar = ar[:, :, :3]  # drop the alpha channel
    if with_transpose:
        if len(ar.shape) == 3:
            ar = ar.transpose(2, 0, 1)
        else:
            ar = ar[None, ...]

    return ar.astype(np.float32) / 255.


def np_to_pil(img_np):
    """
    Converts image in np.array format to PIL image.

    From C x H x W [0..1] to H x W x C [0...255]
    :param img_np:
    :return:
    """
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        assert img_np.shape[0] == 3, img_np.shape
        ar = ar.transpose(1, 2, 0)

    return Image.fromarray(ar)


def np_to_torch(img_np):
    """
    Converts image in numpy.array to torch.Tensor.

    From C x H x W [0..1] to 1 x C x H x W [0..1] (a batch dimension is added)

    :param img_np:
    :return:
    """
    return torch.from_numpy(img_np)[None, :]


def torch_to_np(img_var):
    """
    Converts an image in torch.Tensor format to np.array.

    The tensor is detached and moved to the CPU; dimensions are left
    unchanged (note: a leading batch dimension, if any, is NOT removed).
    :param img_var:
    :return:
    """
    return img_var.detach().cpu().numpy()


def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def get_A(x):
    x_np = np.clip(torch_to_np(x), 0, 1)
    x_pil = np_to_pil(x_np)
    w, h = x_pil.size  # PIL size is (width, height); only the average is used below
    windows = (h + w) / 2
    A = x_pil.filter(ImageFilter.GaussianBlur(windows))
    A = ToTensor()(A)
    return A
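A minimal round-trip sketch of the conversion helpers above, run from inside the utils folder (like create_hist_sample.py); the input file name and output folder are hypothetical:

from PIL import Image
from utils import pil_to_np, np_to_torch, torch_to_np, my_save_image

img_pil = Image.open('example.png').convert('RGB')  # hypothetical input image

img_np = pil_to_np(img_pil)        # (3, H, W) float32 in [0, 1]
img_t = np_to_torch(img_np)        # (1, 3, H, W) tensor, ready for a network

# ... run a model on img_t here ...

out_np = torch_to_np(img_t)[0]     # back to (3, H, W); [0] drops the batch dimension
my_save_image('result.png', out_np, output_path='results/')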