Initial commit for SeaDiff project code

commit a1c05872fe (parent b31cdfd067)
Author: Henry-Bi
Date:   2025-06-15 17:28:44 +08:00
170 changed files with 12855 additions and 0 deletions

utils/RGBuvHistBlock.py (new file, 228 lines)

@@ -0,0 +1,228 @@
"""
##### Copyright 2021 Mahmoud Afifi.
If you find this code useful, please cite our paper:
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.
@inproceedings{afifi2021histogan,
title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
Color Histograms},
author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
booktitle={CVPR},
year={2021}
}
####
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
EPS = 1e-6
class RGBuvHistBlock(nn.Module):
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               hist_boundary=None, green_only=False, device='cuda'):
    """ Computes the RGB-uv histogram feature of a given image.
    Args:
      h: histogram dimension size (scalar). The default value is 64.
      insz: maximum size of the input image; if it is larger than this size, the
        image will be resized (scalar). Default value is 150 (i.e., 150 x 150
        pixels).
      resizing: resizing method if applicable. Options are: 'interpolation' or
        'sampling'. Default is 'interpolation'.
      method: the method used to count the number of pixels for each bin in the
        histogram feature. Options are: 'thresholding', 'RBF' (radial basis
        function), or 'inverse-quadratic'. Default value is 'inverse-quadratic'.
      sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is
        the sigma parameter of the kernel function. The default value is 0.02.
      intensity_scale: boolean variable to use the intensity scale (I_y in
        Equation 2). Default value is True.
      hist_boundary: a list of histogram boundary values. Default is [-3, 3].
      green_only: boolean variable to use only the log(g/r), log(g/b) channels.
        Default is False.

    Methods:
      forward: accepts input image and returns its histogram feature. Note that
        unless the method is 'thresholding', this is a differentiable function
        and can be easily integrated with the loss function. As mentioned in the
        paper, the 'inverse-quadratic' was found more stable than 'RBF' in our
        training.
    """
    super(RGBuvHistBlock, self).__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    self.green_only = green_only
    if hist_boundary is None:
      hist_boundary = [-3, 3]
    hist_boundary.sort()
    self.hist_boundary = hist_boundary
    if self.method == 'thresholding':
      self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    hists = torch.zeros((x_sampled.shape[0], 1 + int(not self.green_only) * 2,
                         self.h, self.h)).to(device=self.device)
    for l in range(L):
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      if self.intensity_scale:
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1
      if not self.green_only:
        Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] +
                                                                   EPS), dim=1)
        Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] +
                                                                   EPS), dim=1)
        diff_u0 = abs(
          Iu0 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        diff_v0 = abs(
          Iv0 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        if self.method == 'thresholding':
          diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
          diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
        elif self.method == 'RBF':
          diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u0 = torch.exp(-diff_u0)  # Radial basis function
          diff_v0 = torch.exp(-diff_v0)
        elif self.method == 'inverse-quadratic':
          diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
          diff_v0 = 1 / (1 + diff_v0)
        else:
          raise Exception(
            f'Wrong kernel method. It should be either thresholding, RBF,'
            f' inverse-quadratic. But the given value is {self.method}.')
        diff_u0 = diff_u0.type(torch.float32)
        diff_v0 = diff_v0.type(torch.float32)
        a = torch.t(Iy * diff_u0)
        hists[l, 0, :, :] = torch.mm(a, diff_v0)

      Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u1 = abs(
        Iu1 - torch.unsqueeze(torch.tensor(np.linspace(
          self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
          dim=0).to(self.device))
      diff_v1 = abs(
        Iv1 - torch.unsqueeze(torch.tensor(np.linspace(
          self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
          dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
        diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = torch.exp(-diff_u1)  # Gaussian
        diff_v1 = torch.exp(-diff_v1)
      elif self.method == 'inverse-quadratic':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
        diff_v1 = 1 / (1 + diff_v1)
      diff_u1 = diff_u1.type(torch.float32)
      diff_v1 = diff_v1.type(torch.float32)
      a = torch.t(Iy * diff_u1)
      if not self.green_only:
        hists[l, 1, :, :] = torch.mm(a, diff_v1)
      else:
        hists[l, 0, :, :] = torch.mm(a, diff_v1)

      if not self.green_only:
        Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] +
                                                                   EPS), dim=1)
        Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] +
                                                                   EPS), dim=1)
        diff_u2 = abs(
          Iu2 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        diff_v2 = abs(
          Iv2 - torch.unsqueeze(torch.tensor(np.linspace(
            self.hist_boundary[0], self.hist_boundary[1], num=self.h)),
            dim=0).to(self.device))
        if self.method == 'thresholding':
          diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
          diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
        elif self.method == 'RBF':
          diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u2 = torch.exp(-diff_u2)  # Gaussian
          diff_v2 = torch.exp(-diff_v2)
        elif self.method == 'inverse-quadratic':
          diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                              2) / self.sigma ** 2
          diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
          diff_v2 = 1 / (1 + diff_v2)
        diff_u2 = diff_u2.type(torch.float32)
        diff_v2 = diff_v2.type(torch.float32)
        a = torch.t(Iy * diff_u2)
        hists[l, 2, :, :] = torch.mm(a, diff_v2)

    # normalization
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized
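A minimal usage sketch of the block as a differentiable color loss. The helper name histogram_loss and the Hellinger-distance form (following the HistoGAN paper) are illustrative assumptions rather than code from this commit, and it presumes two N x 3 x H x W tensors with values in [0, 1] on a CUDA device.

# Illustrative sketch only; assumes GPU tensors in [0, 1] and that utils/ is importable.
import torch
from utils.RGBuvHistBlock import RGBuvHistBlock

hist_block = RGBuvHistBlock(h=64, insz=150, method='inverse-quadratic',
                            sigma=0.02, device='cuda')

def histogram_loss(pred, target):
  # Hellinger distance between the two normalized RGB-uv histogram features.
  pred_hist = hist_block(pred)
  target_hist = hist_block(target)
  return (1 / 2 ** 0.5) * torch.sqrt(
    torch.sum((torch.sqrt(pred_hist) - torch.sqrt(target_hist)) ** 2))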


@@ -0,0 +1,50 @@
"""
If you find this code useful, please cite our paper:
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.
@inproceedings{afifi2021histogan,
title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
Color Histograms},
author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
booktitle={CVPR},
year={2021}
}
"""
import os
from RGBuvHistBlock import RGBuvHistBlock
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from os.path import splitext, join, basename, exists
image_folder = ''
output_folder = ''
if exists(output_folder) is False:
  os.mkdir(output_folder)

torch.cuda.set_device(0)
histblock = RGBuvHistBlock(insz=336, h=336,
                           resizing='sampling',
                           method='inverse-quadratic',
                           sigma=0.02,
                           device=torch.cuda.current_device())
transform = transforms.Compose([transforms.Resize((336, 336)),
                                transforms.ToTensor()])

image_names = os.listdir(image_folder)
for filename in image_names:
  print(filename)
  img_hist = Image.open(os.path.join(image_folder, filename))
  img_hist = torch.unsqueeze(transform(img_hist), dim=0).to(
    device=torch.cuda.current_device())
  histogram = histblock(img_hist)
  # histogram = histogram.cpu().numpy()
  save_image(histogram * 255, os.path.join(output_folder, filename))
  # np.save(join(output_dir, basename(splitext(filename)[0]) + '.npy'), histogram)

utils/font/simhei.ttf (new binary file)
utils/font/times.ttf (new binary file)
utils/font/timesbd.ttf (new binary file)
utils/font/timesbi.ttf (new binary file)
utils/font/timesi.ttf (new binary file)


utils/perceptual_loss.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16
# --- Perceptual loss network --- #
class PerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        vgg_model = vgg16(pretrained=True).features[:16]
        vgg_model = vgg_model.cuda()
        # vgg_model = nn.DataParallel(vgg_model, device_ids=device_ids)
        for param in vgg_model.parameters():
            param.requires_grad = False
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, pred_im, gt):
        loss = []
        pred_im_features = self.output_features(pred_im)
        gt_features = self.output_features(gt)
        for pred_im_feature, gt_feature in zip(pred_im_features, gt_features):
            loss.append(F.mse_loss(pred_im_feature, gt_feature))
        return sum(loss) / len(loss)
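A minimal usage sketch, assuming 1 x 3 x 256 x 256 float tensors in [0, 1]; the tensor shapes, the requires_grad flag, and the import path are illustrative, and the VGG backbone lives on CUDA, so inputs must be GPU tensors.

# Illustrative sketch only; not code from this commit.
import torch
from utils.perceptual_loss import PerceptualLoss  # assumes utils/ is importable as a package

criterion = PerceptualLoss()  # loads pretrained VGG-16 features up to relu3_3
pred = torch.rand(1, 3, 256, 256, device='cuda', requires_grad=True)
gt = torch.rand(1, 3, 256, 256, device='cuda')
loss = criterion(pred, gt)  # mean MSE over relu1_2, relu2_2 and relu3_3 features
loss.backward()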

utils/ssim_loss.py (new file, 67 lines)

@@ -0,0 +1,67 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                          for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIMLoss(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, loss_weight=1.0):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)
        self.loss_weight = loss_weight

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        loss = self.loss_weight * (1 - _ssim(img1,
                                             img2,
                                             window,
                                             self.window_size,
                                             channel,
                                             self.size_average))
        return loss
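SSIMLoss returns loss_weight * (1 - SSIM), so minimizing it drives the two images toward structural similarity. Below is a minimal usage sketch; the tensor shapes and import path are illustrative assumptions.

# Illustrative sketch only; not code from this commit.
import torch
from utils.ssim_loss import SSIMLoss  # assumes utils/ is importable as a package

criterion = SSIMLoss(window_size=11, loss_weight=1.0)
pred = torch.rand(2, 3, 128, 128, requires_grad=True)
gt = torch.rand(2, 3, 128, 128)
loss = criterion(pred, gt)  # 1 - SSIM; lower means more similar
loss.backward()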

utils/utils.py (new file, 89 lines)

@@ -0,0 +1,89 @@
import os
import torch
import numpy as np
from PIL import Image, ImageFilter
from torchvision.transforms import ToTensor
def my_save_image(name, image_np, output_path=""):
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    p = np_to_pil(image_np)
    p.save(output_path + "{}".format(name))


def pil_to_np(img_PIL, with_transpose=True):
    """
    Converts image in PIL format to np.array.
    From H x W x C [0...255] to C x H x W [0..1]
    """
    ar = np.array(img_PIL)
    if len(ar.shape) == 3 and ar.shape[-1] == 4:
        ar = ar[:, :, :3]  # drop the alpha channel
    if with_transpose:
        if len(ar.shape) == 3:
            ar = ar.transpose(2, 0, 1)
        else:
            ar = ar[None, ...]
    return ar.astype(np.float32) / 255.


def np_to_pil(img_np):
    """
    Converts image in np.array format to PIL image.
    From C x H x W [0..1] to H x W x C [0...255]
    :param img_np:
    :return:
    """
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        assert img_np.shape[0] == 3, img_np.shape
        ar = ar.transpose(1, 2, 0)
    return Image.fromarray(ar)


def np_to_torch(img_np):
    """
    Converts image in numpy.array to torch.Tensor.
    From C x H x W [0..1] to 1 x C x H x W [0..1] (adds a batch dimension)
    :param img_np:
    :return:
    """
    return torch.from_numpy(img_np)[None, :]


def torch_to_np(img_var):
    """
    Converts an image in torch.Tensor format to np.array.
    The shape is preserved (no dimensions are added or removed); values stay in [0..1].
    :param img_var:
    :return:
    """
    return img_var.detach().cpu().numpy()


def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def get_A(x):
    x_np = np.clip(torch_to_np(x), 0, 1)
    x_pil = np_to_pil(x_np)
    w, h = x_pil.size  # PIL reports (width, height); only the sum is used below
    windows = (h + w) / 2
    A = x_pil.filter(ImageFilter.GaussianBlur(windows))
    A = ToTensor()(A)
    return A
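A round-trip sketch through the conversion helpers above; the file names and import path are placeholders, not values from this commit.

# Illustrative sketch only; not code from this commit.
from PIL import Image
from utils.utils import pil_to_np, np_to_torch, torch_to_np, my_save_image  # assumed import path

img_pil = Image.open('input.png')             # placeholder path
img_np = pil_to_np(img_pil)                   # C x H x W float32 in [0, 1]
img_torch = np_to_torch(img_np)               # adds a leading batch dimension
img_back = torch_to_np(img_torch)[0]          # back to numpy; [0] drops the batch dimension
my_save_image('roundtrip.png', img_back, output_path='./results/')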