diff --git a/FourierGrid/run_gtk_analysis.py b/FourierGrid/run_gtk_analysis.py
index d1d8ef9..c87f237 100644
--- a/FourierGrid/run_gtk_analysis.py
+++ b/FourierGrid/run_gtk_analysis.py
@@ -7,15 +7,17 @@ import time
 import torch
 # import random
-from jax import random
+# from jax import random
 import torch.nn as nn
 import numpy as np
+from scipy.special import jv
+from scipy.ndimage import gaussian_filter1d
 
-class VG(nn.Module):
-    # the VG operator
+class VoxelGrid(nn.Module):
+    # the VoxelGrid operator
     def __init__(self, grid_len=1000):
-        super(VG, self).__init__()
+        super(VoxelGrid, self).__init__()
         self.grid_len = grid_len
         self.interval_num = grid_len - 1
         axis_coord = np.array([0 + i * 1 / grid_len for i in range(grid_len)])
@@ -39,7 +41,7 @@ def forward(self,):
             # calculate GTK
             real_x = idx / data_point_num
             left_grid = int(real_x // (1 / self.grid_len))
             right_grid = left_grid + 1
-            if left_grid > 0:
+            if left_grid >= 0:
                 jacobian_y_w[idx][left_grid] = abs(real_x - right_grid * 1 / self.grid_len) * self.grid_len
             if right_grid < self.grid_len:
                 jacobian_y_w[idx][right_grid] = abs(real_x - left_grid * 1 / self.grid_len) * self.grid_len
@@ -77,10 +79,10 @@ def one_d_regress(self, x_train, x_test, y_train, y_test_gt):
         return train_loss, test_loss, y_test
 
 
-class FG(nn.Module):
-    # the FG operator
+class FourierGrid(nn.Module):
+    # the FourierGrid operator
     def __init__(self, grid_len=1000, band_num=10):
-        super(FG, self).__init__()
+        super(FourierGrid, self).__init__()
         self.grid_len = grid_len
         self.interval_num = self.grid_len - 1
         self.band_num = band_num
@@ -160,8 +162,34 @@ def one_d_regress(self, x_train, x_test, y_train, y_test_gt):
         test_loss = np.mean(test_loss)
         return train_loss, test_loss, y_test
 
+
+# build models and train them
+def train_model(one_model):
+    # train and evaluate the given grid model (VoxelGrid or FourierGrid)
+    optimizer = torch.optim.Adam(one_model.parameters(), lr=lr)
+    iterations = 150
+    epoch_iter = tqdm(range(iterations))
+    for epoch in epoch_iter:
+        optimizer.zero_grad()  # reset accumulated gradients
+        train_loss, test_loss, test_y = one_model.one_d_regress(x_train, x_test, y_train, y_test_gt)
+        train_loss.backward()  # compute gradients by backpropagation
+        optimizer.step()  # gradient step: theta_new = theta_old - alpha * dJ/dtheta
+        epoch_iter.set_description(f"Training loss: {train_loss.item()}; Testing Loss: {test_loss}")
+    return train_loss, test_loss, test_y
+
+
+def get_fg_gtk_spectrum_by_band_num(band_num):
+    test_fg = FourierGrid(grid_len=grid_len, band_num=band_num * 2)
+    fg_gtk = test_fg()
+    # fg_gtk = (fg_gtk - fg_gtk.min()) / (fg_gtk.max() - fg_gtk.min())
+    fg_gtk_spectrum = 10**fplot(fg_gtk)
+    fg_plot = gaussian_filter1d(fg_gtk_spectrum[0], sigma=2)
+    return fg_plot
+
+
 # hyperparameters
-title_offset = -0.4
+title_offset = -0.29
+bbox_offset = 1.44
 data_point_num = 100
 grid_len = 10
 freq_num = 10
@@ -171,35 +199,62 @@ def one_d_regress(self, x_train, x_test, y_train, y_test_gt):
                      [0.0943, 0.5937, 0.8793],
                      [0.3936, 0.2946, 0.6330],
                      [0.7123, 0.2705, 0.3795]])
-linewidth = 3
+linewidth = 1.0
 line_alpha = .8
+title_font_size = 7.4
+legend_font_size = 6
+label_size = 7
+# matplotlib.rcParams["font.family"] = 'Arial'
+matplotlib.rcParams['xtick.labelsize'] = label_size
+matplotlib.rcParams['ytick.labelsize'] = label_size
 
 # begin plot
-fig3 = plt.figure(constrained_layout=True, figsize=(4, 2))
-gs = fig3.add_gridspec(1, 2, width_ratios=[1, 1])
+fig3 = plt.figure(constrained_layout=True, figsize=(4, 4))
+gs = fig3.add_gridspec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])
 
 # 100 * 100 datapoints, 10*10 params (grid_len=10)
-test_vg = VG(grid_len=grid_len * freq_num)
+test_vg = VoxelGrid(grid_len=grid_len * freq_num)
 vg_gtk = test_vg()
-vg_gtk = (vg_gtk - vg_gtk.min()) / vg_gtk.max()
+vg_gtk_normalized = (vg_gtk - vg_gtk.min()) / (vg_gtk.max() - vg_gtk.min())
 ax = fig3.add_subplot(gs[0, 0])
-ax.imshow(vg_gtk)
+ax.imshow(vg_gtk_normalized)
 ax.set_xticks([*range(0, 100, 20)] + [100])
 ax.set_yticks([*range(0, 100, 20)] + [100])
-ax.set_title('(a) VG GTK', y=title_offset)
+ax.grid(linestyle='--', linewidth=0.3)
+ax.set_title('(a) VoxelGrid GTK', y=title_offset, fontsize=title_font_size)
 
 ax = fig3.add_subplot(gs[0, 1])
-test_fg = FG(grid_len=grid_len, band_num=freq_num)
+test_fg = FourierGrid(grid_len=grid_len, band_num=freq_num)
 fg_gtk = test_fg()
-fg_gtk = (fg_gtk - fg_gtk.min()) / fg_gtk.max()
+fg_gtk = (fg_gtk - fg_gtk.min()) / (fg_gtk.max() - fg_gtk.min())
 ax.imshow(fg_gtk)
 ax.set_xticks([*range(0, 100, 20)] + [100])
 ax.set_yticks([*range(0, 100, 20)] + [100])
-ax.set_title('(b) FG GTK', y=title_offset)
-
-# generate figures
-plt.savefig("figures/vg_fg_gtk.jpg", dpi=800)
-plt.savefig("figures/vg_fg_gtk.pdf", format="pdf")
-pdb.set_trace()
+ax.grid(linestyle='--', linewidth=0.3)
+ax.set_title('(b) FourierGrid GTK', y=title_offset, fontsize=title_font_size)
+
+ax = fig3.add_subplot(gs[1, 0])
+w_vg, v_vg = np.linalg.eig(vg_gtk)
+w_fg, v_fg = np.linalg.eig(fg_gtk)
+fplot = lambda x: np.fft.fftshift(np.log10(np.abs(np.fft.fft(x))))
+vg_gtk_spectrum = 10**fplot(vg_gtk)
+vg_plot = gaussian_filter1d(vg_gtk_spectrum[0], sigma=2)
+
+fg_gtk_plot_1 = get_fg_gtk_spectrum_by_band_num(band_num=1)
+fg_gtk_plot_5 = get_fg_gtk_spectrum_by_band_num(band_num=5)
+fg_gtk_plot_10 = get_fg_gtk_spectrum_by_band_num(band_num=10)
+plt.autoscale(enable=True, axis='x', tight=True)
+# plt.plot(vg_plot, label='VoxelGrid', color=colors_k[0], alpha=line_alpha, linewidth=linewidth)
+plt.semilogy(np.append(vg_plot, vg_plot[0]), label='VoxelGrid', color=colors_k[0], alpha=line_alpha, linewidth=linewidth)
+# plt.semilogy(fg_gtk_plot_1, label='FourierGrid (l=1)', color=colors_k[2], alpha=line_alpha, linewidth=linewidth)
+plt.semilogy(np.append(fg_gtk_plot_1, fg_gtk_plot_1[0]), label='FourierGrid (l=1)', color=colors_k[2], alpha=line_alpha, linewidth=linewidth)
+# plt.semilogy(fg_gtk_plot_5, label='FourierGrid (l=5)', color=colors_k[3], alpha=line_alpha, linewidth=linewidth)
+plt.semilogy(np.append(fg_gtk_plot_5, fg_gtk_plot_5[0]), label='FourierGrid (l=5)', color=colors_k[3], alpha=line_alpha, linewidth=linewidth)
+# plt.semilogy(fg_gtk_plot_10, label='FourierGrid (l=10)', color=colors_k[4], alpha=line_alpha, linewidth=linewidth)
+plt.semilogy(np.append(fg_gtk_plot_10, fg_gtk_plot_10[0]), label='FourierGrid (l=10)', color=colors_k[4], alpha=line_alpha, linewidth=linewidth)
+plt.xticks([0, 25, 50, 75, 100], [r'$-\pi$', r'$-\pi/2$', '$0$', r'$\pi/2$', r'$\pi$'])
+ax.set_yticks([0.1, 1, 10, 100])
+ax.legend(loc='upper left', bbox_to_anchor=(-0.01, bbox_offset), handlelength=1, fontsize=legend_font_size, fancybox=False, ncol=1)
+ax.set_title('(c) GTK Fourier Spectrum', y=title_offset, fontsize=title_font_size)
 
 
 def sample_random_signal(key, decay_vec):
@@ -219,6 +274,14 @@ def sample_random_powerlaw(key, N, power):
     return sample_random_signal(key, decay_vec)
 
 
+def get_sine_signal():
+    return np.array([np.sin(x / (train_num*sample_interval) * 2 * np.pi) for x in range(train_num*sample_interval)])
+
+
+def get_bessel_signal():
+    # return np.array([np.exp(x / train_num*sample_interval) for x in range(train_num*sample_interval)])
+    return np.array([jv(1, x / 4) for x in range(train_num*sample_interval)])
+
 ## Fitting experiments
 # hyperparameters
 rand_key = np.array([0, 0], dtype=np.uint32)
@@ -230,62 +293,40 @@ def sample_random_powerlaw(key, N, power):
 # setup data
 x_test = np.float32(np.linspace(0, 1., train_num * sample_interval, endpoint=False))
-x_train = x_test[::sample_interval]
-# s = sample_random_powerlaw(rand_key, train_num * sample_interval, data_power)
-signal = np.array([np.sin(x / (train_num*sample_interval) * 2 * np.pi) for x in range(train_num*sample_interval)])
+x_train = x_test[0:len(x_test):sample_interval]
+
+# signal = get_sine_signal()
+signal = get_bessel_signal()
 signal = (signal-signal.min()) / (signal.max()-signal.min())
-y_train = signal[::sample_interval]
+y_train = signal[0:len(x_test):sample_interval]
 y_test_gt = signal
 
-# build models and train them
-def train_model(one_model):
-    # training and testing VG
-    optimizer = torch.optim.Adam(one_model.parameters(), lr=lr)
-    iterations = 150
-    epoch_iter = tqdm(range(iterations))
-    for epoch in epoch_iter:
-        optimizer.zero_grad()  # to make the gradients zero
-        train_loss, test_loss, test_y = one_model.one_d_regress(x_train, x_test, y_train, y_test_gt)
-        train_loss.backward()  # This is for computing gradients using backward propagation
-        optimizer.step()  # This is equivalent to: theta_new = theta_old - alpha * derivative of J w.r.t theta
-        epoch_iter.set_description(f"Training loss: {train_loss.item()}; Testing Loss: {test_loss}")
-    return train_loss, test_loss, test_y
-
 freq_num = 3
-test_vg_small = VG(grid_len=10 * freq_num)
-test_vg_large = VG(grid_len=100 * freq_num)
-test_fg_small = FG(grid_len=10, band_num=freq_num)
-test_fg_large = FG(grid_len=100, band_num=freq_num)
+test_vg_small = VoxelGrid(grid_len=10 * freq_num)
+test_vg_large = VoxelGrid(grid_len=100 * freq_num)
+test_fg_small = FourierGrid(grid_len=10, band_num=freq_num)
+test_fg_large = FourierGrid(grid_len=100, band_num=freq_num)
 train_loss, test_loss, test_y_vg_small = train_model(test_vg_small)
 train_loss, test_loss, test_y_fg_small = train_model(test_fg_small)
-
-ax = fig3.add_subplot(gs[0, 2])
+ax = fig3.add_subplot(gs[1, 1])
 ax.plot(x_test, signal, label='Target signal', color='k', linewidth=1, alpha=line_alpha, zorder=1)
-ax.plot(x_test, test_y_vg_small, label='Learned by VG', color=colors_k[1], linewidth=1, alpha=line_alpha, zorder=1)
-ax.plot(x_test, test_y_fg_small, label='Learned by FG', color=colors_k[2], linewidth=1, alpha=line_alpha, zorder=1)
+ax.plot(x_test, test_y_vg_small, label='Learned by VoxelGrid', color=colors_k[0], linewidth=1, alpha=line_alpha, zorder=1)
+ax.plot(x_test, test_y_fg_small, label='Learned by FourierGrid', color=colors_k[3], linewidth=1, alpha=line_alpha, zorder=1)
 ax.scatter(x_train, y_train, color='w', edgecolors='k', linewidths=1, s=20, linewidth=1, label='Training points', zorder=2)
-ax.set_title('(c) 1D Regression', y=title_offset)
+ax.set_title('(d) 1D Regression', y=title_offset, fontsize=title_font_size)
 ax.set_xticks(np.linspace(0.0, 1.0, num=5, endpoint=True))
+ax.legend(loc='upper left', bbox_to_anchor=(-0.01, bbox_offset), handlelength=1, fontsize=legend_font_size, fancybox=False, ncol=1)
 
-# ax.set_xticks([])
-# ax.set_yticks([])
-# ax.legend(loc='upper right', ncol=2)
-ax.legend(loc='upper left', bbox_to_anchor=(1.03, 0.78), handlelength=1)
-
-# plt.xlabel('x')
-# plt.ylabel('y')
-# plt.grid(True, which='both', alpha=.3)
-
-
-# generate figures
-plt.savefig("figures/final_vg_fg.jpg", dpi=800) -plt.savefig("figures/final_vg_fg.pdf", format="pdf") +print("Plotting figures!") +plt.savefig("figures/vg_fg_gtk.jpg", dpi=300) # for example +plt.savefig("figures/vg_fg_gtk.pdf", format="pdf") pdb.set_trace() + # # unused codes # fplot = lambda x : np.fft.fftshift(np.log10(np.abs(np.fft.fft(x)))) @@ -294,8 +335,8 @@ def train_model(one_model): # fg_spec = 10**fplot(fg_gtk) # w_vg, v_vg = np.linalg.eig(vg_gtk) # w_fg, v_fg = np.linalg.eig(fg_gtk) -# plt.semilogy(vg_spec[0], label="VG", color=colors_k[0], alpha=line_alpha, linewidth=linewidth) -# plt.semilogy(fg_spec[0], label="FG", color=colors_k[1], alpha=line_alpha, linewidth=linewidth) +# plt.semilogy(vg_spec[0], label="VoxelGrid", color=colors_k[0], alpha=line_alpha, linewidth=linewidth) +# plt.semilogy(fg_spec[0], label="FourierGrid", color=colors_k[1], alpha=line_alpha, linewidth=linewidth) # # ax.plot(np.linspace(-.5, .5, 10100, endpoint=True), np.append(vg_spec, vg_spec[0]), label="vg", color=colors_k[0], alpha=line_alpha, linewidth=linewidth) # ax.set_title('(c) GTK Fourier spectrum', y=title_offset) diff --git a/block_nerf/block_nerf_lightning.py b/block_nerf/block_nerf_lightning.py new file mode 100644 index 0000000..2c83cdf --- /dev/null +++ b/block_nerf/block_nerf_lightning.py @@ -0,0 +1,152 @@ +from pytorch_lightning import LightningModule, Trainer +import torch +import os +from collections import defaultdict +from torch.utils.data import DataLoader +from block_nerf.waymo_dataset import * +from block_nerf.block_nerf_model import * +from block_nerf.rendering import * +from block_nerf.metrics import * +from block_nerf.block_visualize import * +from block_nerf.learning_utils import * + +class Block_NeRF_System(LightningModule): + def __init__(self, hparams): + super(Block_NeRF_System, self).__init__() + self.hyper_params = hparams + self.save_hyperparameters(hparams) + self.loss = BlockNeRFLoss(1e-2) #hparams['Visi_loss'] + + self.xyz_IPE = InterPosEmbedding(hparams['N_IPE_xyz']) # xyz的L=10 + self.dir_exposure_PE = PosEmbedding( + hparams['N_PE_dir_exposure']) # dir的L=4 + self.embedding_appearance = torch.nn.Embedding( + hparams['N_vocab'], hparams['N_appearance']) + + self.Embedding = {'IPE': self.xyz_IPE, + 'PE': self.dir_exposure_PE, + 'appearance': self.embedding_appearance} + + self.Block_NeRF = Block_NeRF(in_channel_xyz=6 * hparams['N_IPE_xyz'], + in_channel_dir=6 * + hparams['N_PE_dir_exposure'], + in_channel_exposure=2 * + hparams['N_PE_dir_exposure'], + in_channel_appearance=hparams['N_appearance']) + + self.Visibility = Visibility(in_channel_xyz=6 * hparams['N_IPE_xyz'], + in_channel_dir=6 * hparams['N_PE_dir_exposure']) + + self.models_to_train = [] + self.models_to_train += [self.embedding_appearance] + self.models_to_train += [self.Block_NeRF] + self.models_to_train += [self.Visibility] + + def forward(self, rays, ts): + B = rays.shape[0] + model = { + "block_model": self.Block_NeRF, + "visibility_model": self.Visibility + } + + results = defaultdict(list) + for i in range(0, B, self.hparams['chunk']): + rendered_ray_chunks = render_rays(model, self.Embedding, + rays[i:i + self.hparams['chunk']], + ts[i:i + self.hparams['chunk']], + N_samples=self.hparams['N_samples'], + N_importance=self.hparams['N_importance'], + chunk=self.hparams['chunk'], + type="train", + use_disp=self.hparams['use_disp'] + ) + for k, v in rendered_ray_chunks.items(): + results[k] += [v] + + for k, v in results.items(): + results[k] = torch.cat(v, 0) + + return results + + def setup(self, 
stage): + self.train_dataset = WaymoDataset(root_dir=self.hparams['root_dir'], + split='train', + block=self.hparams['block_index'], + img_downscale=self.hparams['img_downscale'], + near=self.hparams['near'], + far=self.hparams['far']) + self.val_dataset = WaymoDataset(root_dir=self.hparams['root_dir'], + split='val', + block=self.hparams['block_index'], + img_downscale=self.hparams['img_downscale'], + near=self.hparams['near'], + far=self.hparams['far']) + + def configure_optimizers(self): + self.optimizer = get_optimizer(self.hparams, self.models_to_train) + scheduler = get_scheduler(self.hparams, self.optimizer) + return [self.optimizer], [scheduler] + + def train_dataloader(self): + return DataLoader(self.train_dataset, + shuffle=True, + num_workers=8, + batch_size=self.hparams['batch_size'], + pin_memory=True) + + def val_dataloader(self): + return DataLoader(self.val_dataset, + shuffle=False, + num_workers=8, + batch_size=1, + pin_memory=True) + + def training_step(self, batch, batch_nb): + rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts'] + results = self(rays, ts) + loss_d = self.loss(results, rgbs) + loss = sum(l for l in loss_d.values()) + + with torch.no_grad(): + psnr_ = psnr(results['rgb_fine'], rgbs) + + self.log('lr', get_learning_rate(self.optimizer)) + self.log('train/loss', loss) + for k, v in loss_d.items(): + self.log(f'train/{k}', v, prog_bar=True) + self.log('train/psnr', psnr_, prog_bar=True) + + return loss + + def validation_step(self, batch, batch_nb): # validate at each epoch + rays, rgbs, ts = batch['rays'].squeeze(), batch['rgbs'].squeeze(), batch['ts'].squeeze() + W,H=batch['w_h'] + results = self(rays, ts) + loss_d = self.loss(results, rgbs) + loss = sum(l for l in loss_d.values()) + + if batch_nb == 0: + img = results[f'rgb_fine'].view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) + img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) + depth = visualize_depth(results[f'depth_fine'].view(H, W)) # (3, H, W) + stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W) + #stack = torch.stack([img_gt, img]) # (3, 3, H, W) + # todo: recheck this, * 255? 
+            self.logger.experiment.add_images('val/GT_pred_depth',
+                                              stack, self.global_step)
+
+        psnr_ = psnr(results['rgb_fine'], rgbs)
+
+        log = {'val_loss': loss}
+        for k, v in loss_d.items():
+            log[f'val_{k}'] = v
+        log['val_psnr'] = psnr_
+
+        return log
+
+    def validation_epoch_end(self, outputs):
+        mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean()
+
+        self.log('val/loss', mean_loss)
+        self.log('val/psnr', mean_psnr, prog_bar=True)
\ No newline at end of file
diff --git a/block_nerf/block_nerf_model.py b/block_nerf/block_nerf_model.py
index 0f582e4..973d82d 100644
--- a/block_nerf/block_nerf_model.py
+++ b/block_nerf/block_nerf_model.py
@@ -1,157 +1,176 @@
-from pytorch_lightning import LightningModule, Trainer
 import torch
-import os
-from collections import defaultdict
-from torch.utils.data import DataLoader
-from block_nerf.waymo_dataset import *
-from block_nerf.block_nerf_model import *
-from block_nerf.rendering import *
-from block_nerf.metrics import *
-from block_nerf.block_visualize import *
-from block_nerf.learning_utils import *
-
-
-class Block_NeRF_System(LightningModule):
-
-    def __init__(self, hparams):
-        super(Block_NeRF_System, self).__init__()
-        self.hyper_params = hparams
-        self.save_hyperparameters(hparams)
-        self.loss = BlockNeRFLoss(1e-2)  # hparams['Visi_loss']
-
-        self.xyz_IPE = InterPosEmbedding(hparams['N_IPE_xyz'])
-        self.dir_exposure_PE = PosEmbedding(hparams['N_PE_dir_exposure'
-                ])
-        self.embedding_appearance = torch.nn.Embedding(hparams['N_vocab'
-                ], hparams['N_appearance'])
-
-        self.Embedding = {'IPE': self.xyz_IPE,
-                          'PE': self.dir_exposure_PE,
-                          'appearance': self.embedding_appearance}
-
-        self.Block_NeRF = Block_NeRF(in_channel_xyz=6
-                * hparams['N_IPE_xyz'], in_channel_dir=6
-                * hparams['N_PE_dir_exposure'], in_channel_exposure=2
-                * hparams['N_PE_dir_exposure'],
-                in_channel_appearance=hparams['N_appearance'])
-
-        self.Visibility = Visibility(in_channel_xyz=6
-                * hparams['N_IPE_xyz'], in_channel_dir=6
-                * hparams['N_PE_dir_exposure'])
-
-        self.models_to_train = []
-        self.models_to_train += [self.embedding_appearance]
-        self.models_to_train += [self.Block_NeRF]
-        self.models_to_train += [self.Visibility]
-
-    def forward(self, rays, ts):
-        B = rays.shape[0]
-        model = {'block_model': self.Block_NeRF,
-                 'visibility_model': self.Visibility}
-
-        results = defaultdict(list)
-        for i in range(0, B, self.hparams['chunk']):
-            rendered_ray_chunks = render_rays(
-                model,
-                self.Embedding,
-                rays[i:i + self.hparams['chunk']],
-                ts[i:i + self.hparams['chunk']],
-                N_samples=self.hparams['N_samples'],
-                N_importance=self.hparams['N_importance'],
-                chunk=self.hparams['chunk'],
-                type='train',
-                use_disp=self.hparams['use_disp'],
-                )
-            for (k, v) in rendered_ray_chunks.items():
-                results[k] += [v]
-
-        for (k, v) in results.items():
-            results[k] = torch.cat(v, 0)
-
-        return results
-
-    def setup(self, stage):
-        self.train_dataset = WaymoDataset(
-            root_dir=self.hparams['root_dir'],
-            split='train',
-            block=self.hparams['block_index'],
-            img_downscale=self.hparams['img_downscale'],
-            near=self.hparams['near'],
-            far=self.hparams['far'],
-            )
-        self.val_dataset = WaymoDataset(
-            root_dir=self.hparams['root_dir'],
-            split='val',
-            block=self.hparams['block_index'],
-            img_downscale=self.hparams['img_downscale'],
-            near=self.hparams['near'],
-            far=self.hparams['far'],
-            )
-
-    def configure_optimizers(self):
-        self.optimizer = get_optimizer(self.hparams,
-                self.models_to_train)
-        scheduler = get_scheduler(self.hparams, self.optimizer)
-        return ([self.optimizer], [scheduler])
-
-    def train_dataloader(self):
-        return DataLoader(self.train_dataset, shuffle=True,
-                          num_workers=8,
-                          batch_size=self.hparams['batch_size'],
-                          pin_memory=True)
-
-    def val_dataloader(self):
-        return DataLoader(self.val_dataset, shuffle=False,
-                          num_workers=8, batch_size=1, pin_memory=True)
-
-    def training_step(self, batch, batch_nb):
-        (rays, rgbs, ts) = (batch['rays'], batch['rgbs'], batch['ts'])
-        results = self(rays, ts)
-        loss_d = self.loss(results, rgbs)
-        loss = sum(l for l in loss_d.values())
-
-        with torch.no_grad():
-            psnr_ = psnr(results['rgb_fine'], rgbs)
-
-        self.log('lr', get_learning_rate(self.optimizer))
-        self.log('train/loss', loss)
-
-        # for k, v in loss_d.items():
-        #     self.log(f'train/{k}', v, prog_bar=True)
-
-        self.log('train/psnr', psnr_, prog_bar=True)
+from torch import nn
+
+
+class BlockNeRFLoss(nn.Module):
+    def __init__(self, lambda_mu=0.01, Visi_loss=1e-2):
+        super(BlockNeRFLoss, self).__init__()
+        self.lambda_mu = lambda_mu
+        self.Visi_loss = Visi_loss
+
+    def forward(self, inputs, targets):
+        loss = {}
+        # RGB
+        loss['rgb_coarse'] = self.lambda_mu * ((inputs['rgb_coarse'] - targets[..., :3]) ** 2).mean()
+        loss['rgb_fine'] = ((inputs['rgb_fine'] - targets[..., :3]) ** 2).mean()
+        # visibility
+        loss["transmittance_coarse"] = self.lambda_mu * self.Visi_loss * ((inputs['transmittance_coarse_real'].detach() - inputs['transmittance_coarse_vis'].squeeze()) ** 2).mean()
+        loss["transmittance_fine"] = self.Visi_loss * ((inputs['transmittance_fine_real'].detach() - inputs['transmittance_fine_vis'].squeeze()) ** 2).mean()
         return loss
 
-    def validation_step(self, batch, batch_nb):
-        (rays, rgbs, ts) = (batch['rays'].squeeze(), batch['rgbs'
-                            ].squeeze(), batch['ts'].squeeze())
-        (W, H) = batch['w_h']
-        results = self(rays, ts)
-        loss_d = self.loss(results, rgbs)
-        loss = sum(l for l in loss_d.values())
-
-        if batch_nb == 0:
-            img = results['rgb_fine'].view(H, W, 3).permute(2, 0,
-                    1).cpu()  # (3, H, W)
-            img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu()  # (3, H, W)
-            depth = visualize_depth(results['depth_fine'].view(H, W))  # (3, H, W)
-            stack = torch.stack([img_gt, img, depth])  # (3, 3, H, W)
-            self.logger.experiment.add_images('val/GT_pred_depth',
-                    stack, self.global_step)
-
-        psnr_ = psnr(results['rgb_fine'], rgbs)
-
-        log = {'val_loss': loss}
-        for (k, v) in loss_d.items():
-            log['val_' + str(k)] = v
-        log['val_psnr'] = psnr_
-
-        return log
-
-    def validation_epoch_end(self, outputs):
-        mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean()
-
-        self.log('val/loss', mean_loss)
-        self.log('val/psnr', mean_psnr, prog_bar=True)
+
+class InterPosEmbedding(nn.Module):
+    def __init__(self, N_freqs=10):
+        super(InterPosEmbedding, self).__init__()
+        self.N_freqs = N_freqs
+        self.funcs = [torch.sin, torch.cos]
+
+        # [2^0, 2^1, ..., 2^(n-1)]: for sin
+        self.freq_band_1 = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
+        # [4^0, 4^1, ..., 4^(n-1)]: for diag(∑)
+        self.freq_band_2 = self.freq_band_1 ** 2
+
+    def forward(self, mu, diagE):
+        sin_out = []
+        sin_cos = []
+        for freq in self.freq_band_1:
+            for func in self.funcs:
+                sin_cos.append(func(freq * mu))
+            sin_out.append(sin_cos)
+            sin_cos = []
+        # sin_out: list of [sin(mu), cos(mu)] pairs, one per frequency
+        diag_out = []
+        for freq in self.freq_band_2:
+            diag_out.append(freq * diagE)
+        # diag_out: list of 4^(L-1) * diag(∑)
+        out = []
+        for sc_γ, diag_Eγ in zip(sin_out, diag_out):
+            # torch.exp(-0.5 * x_var) * torch.sin(x)
+            for sin_cos in sc_γ:  # [sin, cos]
+                out.append(sin_cos * torch.exp(-0.5 * diag_Eγ))
+        return torch.cat(out, -1)
+
+
+class PosEmbedding(nn.Module):
+    def __init__(self, N_freqs):
+        """
+        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
+        in_channels: number of input channels (3 for both xyz and direction)
+        """
+        super().__init__()
+        self.N_freqs = N_freqs
+        self.funcs = [torch.sin, torch.cos]
+        # [2^0, 2^1, ..., 2^(n-1)]
+        self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
+
+    def forward(self, x):
+        out = []
+        for freq in self.freq_bands:  # [2^0, 2^1, ..., 2^(n-1)]
+            for func in self.funcs:
+                out += [func(freq * x)]
+        # xyz -> 63 channels, dir -> 27 channels
+        return torch.cat(out, -1)
+
+
+class Block_NeRF(nn.Module):
+    def __init__(self, D=8, W=256, skips=[4],
+                 in_channel_xyz=60, in_channel_dir=24,
+                 in_channel_exposure=8,  # exposure is in 1d and dirs are in 3d
+                 in_channel_appearance=32,
+                 add_apperance=True,
+                 add_exposure=True):
+        # input: [xyz, dir, exposure, appearance] channels
+        super(Block_NeRF, self).__init__()
+        self.D = D
+        self.W = W
+        self.skips = skips
+        self.in_channel_xyz = in_channel_xyz
+        self.in_channel_dir = in_channel_dir
+        self.in_channel_exposure = in_channel_exposure
+        self.in_channel_appearance = in_channel_appearance
+        self.add_appearance = add_apperance
+        self.add_exposure = add_exposure
+
+        for i in range(D):
+            if i == 0:
+                layer = nn.Linear(in_channel_xyz, W)
+            elif i in skips:
+                layer = nn.Linear(W + in_channel_xyz, W)
+            else:
+                layer = nn.Linear(W, W)
+            layer = nn.Sequential(layer, nn.ReLU(True))
+            setattr(self, f'xyz_encoding_{i + 1}', layer)
+        self.xyz_encoding_final = nn.Linear(W, W)
+
+        input_channel = W + in_channel_dir
+        if add_apperance:
+            input_channel += in_channel_appearance
+        if add_exposure:
+            input_channel += in_channel_exposure
+        # three layers of width W // 2 = 128
+        self.dir_encoding = nn.Sequential(  # RGB is determined by direction, exposure and appearance
+            nn.Linear(input_channel, W // 2), nn.ReLU(True),
+            nn.Linear(W // 2, W // 2), nn.ReLU(True),
+            nn.Linear(W // 2, W // 2), nn.ReLU(True)
+        )
+
+        self.static_sigma = nn.Sequential(nn.Linear(W, 1), nn.Softplus())
+        self.static_rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid())
+
+    def forward(self, x, sigma_only=False):
+        if sigma_only:
+            input_xyz = x
+        else:
+            input_xyz, input_dir, input_exp, input_appear = torch.split(
+                x, [self.in_channel_xyz, self.in_channel_dir,
+                    self.in_channel_exposure, self.in_channel_appearance], dim=-1)
+        xyz = input_xyz
+        for i in range(self.D):
+            if i in self.skips:
+                xyz = torch.cat([xyz, input_xyz], dim=-1)
+            xyz = getattr(self, f'xyz_encoding_{i + 1}')(xyz)
+
+        static_sigma = self.static_sigma(xyz)
+        if sigma_only:
+            return static_sigma
+
+        xyz_feature = self.xyz_encoding_final(xyz)
+        input_xyz_feature = torch.cat([xyz_feature, input_dir], dim=-1)
+        if self.add_exposure:
+            input_xyz_feature = torch.cat([input_xyz_feature, input_exp], dim=-1)
+        if self.add_appearance:
+            input_xyz_feature = torch.cat([input_xyz_feature, input_appear], dim=-1)
+
+        dir_encoding = self.dir_encoding(input_xyz_feature)
+
+        static_rgb = self.static_rgb(dir_encoding)
+        static_rgb_sigma = torch.cat([static_rgb, static_sigma], dim=-1)
+
+        return static_rgb_sigma
+
+
+class Visibility(nn.Module):
+    def __init__(self,
+                 in_channel_xyz=60, in_channel_dir=24,
+                 W=128):
+        super(Visibility, self).__init__()
+        self.in_channel_xyz = in_channel_xyz
+        self.in_channel_dir = in_channel_dir
+
+        self.vis_encoding = nn.Sequential(
+            nn.Linear(in_channel_xyz + in_channel_dir, W), nn.ReLU(True),
+            nn.Linear(W, W), nn.ReLU(True),
+            nn.Linear(W, W), nn.ReLU(True),
+            nn.Linear(W, W), nn.ReLU(True),
+        )
+        self.visibility = nn.Sequential(nn.Linear(W, 1), nn.Softplus())
+
+    def forward(self, x):
+        vis_encode = self.vis_encoding(x)
+        visibility = self.visibility(vis_encode)
+        return visibility
diff --git a/eval_block_nerf.py b/eval_block_nerf.py
index d45f2f5..27d578e 100644
--- a/eval_block_nerf.py
+++ b/eval_block_nerf.py
@@ -134,6 +134,7 @@ def Inverse_Interpolation(model_result, W_H):
 
 
 if __name__ == '__main__':
+    print("Warning: this old implementation of BlockNeRF will be deprecated in the next version!")
     torch.cuda.empty_cache()
     hparams = get_hparams()
     os.makedirs(hparams['save_path'], exist_ok=True)
diff --git a/figures/vg_fg_gtk.jpg b/figures/vg_fg_gtk.jpg
index 1b2849e..c77e3a1 100644
Binary files a/figures/vg_fg_gtk.jpg and b/figures/vg_fg_gtk.jpg differ
diff --git a/figures/vg_fg_gtk.pdf b/figures/vg_fg_gtk.pdf
index bd67a44..872a48c 100644
Binary files a/figures/vg_fg_gtk.pdf and b/figures/vg_fg_gtk.pdf differ
diff --git a/train_block_nerf.py b/train_block_nerf.py
index 14c49be..4b18482 100644
--- a/train_block_nerf.py
+++ b/train_block_nerf.py
@@ -100,6 +100,7 @@ def get_opts():
 
 
 def main(hparams):
+    print("Warning: this old implementation of BlockNeRF will be deprecated in the next version!")
     hparams['block_index'] = 'block_' + str(hparams['block_index'])
     system = Block_NeRF_System(hparams)
     checkpoint_callback = \