Commit f7b3507 (1 parent: 30c5a0e). Showing 1 changed file with 383 additions and 0 deletions.
@@ -0,0 +1,383 @@
import torch.nn as nn
import torch
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from utils import TrainClock, cycle, ensure_dirs, ensure_dir
import argparse
import h5py
import shutil
import json
import random
from plyfile import PlyData, PlyElement
import sys
sys.path.append("..")
from agent import BaseAgent
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule


def write_ply(points, filename, text=False):
    """ input: Nx3, write points to filename as PLY format. """
    points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
    vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    with open(filename, mode='wb') as f:
        PlyData([el], text=text).write(f)

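# A minimal usage sketch for write_ply, assuming an (N, 3) numpy array of xyz
# coordinates; the array and file names below are illustrative only.
# pts = np.random.rand(2048, 3).astype(np.float32)
# write_ply(pts, "sample_points.ply", text=True)
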
class Config(object):
    n_points = 2048
    batch_size = 128
    num_workers = 8
    nr_epochs = 200
    lr = 1e-4
    lr_step_size = 50
    # beta1 = 0.5
    grad_clip = None
    noise = 0.02

    save_frequency = 100
    val_frequency = 10

    def __init__(self, args):
        self.data_root = os.path.join(args.proj_dir, args.exp_name, "results/all_zs_ckpt{}.h5".format(args.ae_ckpt))
        self.exp_dir = os.path.join(args.proj_dir, args.exp_name, "pc2cad_tune_noise{}_{}_new".format(self.n_points, self.noise))
        print(self.exp_dir)
        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        self.gpu_ids = args.gpu_ids

        if (not args.test) and not args.cont and os.path.exists(self.exp_dir):
            response = input('Experiment log/model already exists, overwrite? (y/n) ')
            if response != 'y':
                exit()
            shutil.rmtree(self.exp_dir)
        ensure_dirs([self.log_dir, self.model_dir])
        if not args.test:
            os.system("cp pc2cad.py {}".format(self.exp_dir))
            with open('{}/config.txt'.format(self.exp_dir), 'w') as f:
                json.dump(self.__dict__, f, indent=2)

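# A minimal construction sketch, assuming the command-line flags defined at the
# bottom of this file; the proj_dir / exp_name / ae_ckpt values are illustrative only.
# example_args = argparse.Namespace(proj_dir="proj_log", exp_name="pretrained",
#                                   ae_ckpt="1000", gpu_ids="0", cont=False, test=True)
# example_cfg = Config(example_args)  # resolves data_root, exp_dir, log_dir, model_dir
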
class PointNet2(nn.Module):
    def __init__(self):
        super(PointNet2, self).__init__()

        self.use_xyz = True

        self._build_model()

    def _build_model(self):
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                npoint=512,
                radius=0.1,
                nsample=64,
                mlp=[0, 32, 32, 64],
                # bn=False,
                use_xyz=self.use_xyz,
            )
        )

        self.SA_modules.append(
            PointnetSAModule(
                npoint=256,
                radius=0.2,
                nsample=64,
                mlp=[64, 64, 64, 128],
                # bn=False,
                use_xyz=self.use_xyz,
            )
        )

        self.SA_modules.append(
            PointnetSAModule(
                npoint=128,
                radius=0.4,
                nsample=64,
                mlp=[128, 128, 128, 256],
                # bn=False,
                use_xyz=self.use_xyz,
            )
        )

        self.SA_modules.append(
            PointnetSAModule(
                mlp=[256, 256, 512, 1024],
                # bn=False,
                use_xyz=self.use_xyz
            )
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 256),
            nn.Tanh()
        )

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None

        return xyz, features

    def forward(self, pointcloud):
        r"""
        Forward pass of the network

        Parameters
        ----------
        pointcloud: Variable(torch.cuda.FloatTensor)
            (B, N, 3 + input_channels) tensor
            Point cloud to run predictions on.
            Each point in the point cloud MUST
            be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)

        for module in self.SA_modules:
            xyz, features = module(xyz, features)

        return self.fc_layer(features.squeeze(-1))

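# A minimal shape-check sketch, assuming a CUDA device and the pointnet2_ops
# extension are available; the tensor below is random dummy data.
# net = PointNet2().cuda()
# dummy_pc = torch.rand(4, 2048, 3).cuda()  # (B, N, 3) xyz-only point clouds
# codes = net(dummy_pc)                     # (B, 256) latent codes, squashed to (-1, 1) by Tanh
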
class EncoderPointNet(nn.Module):
    def __init__(self, n_filters=(128, 256, 512, 1024), bn=True):
        super(EncoderPointNet, self).__init__()
        self.n_filters = list(n_filters)  # + [latent_dim]
        # self.latent_dim = latent_dim

        model = []
        prev_nf = 3
        for idx, nf in enumerate(self.n_filters):
            conv_layer = nn.Conv1d(prev_nf, nf, kernel_size=1, stride=1)
            model.append(conv_layer)

            if bn:
                bn_layer = nn.BatchNorm1d(nf)
                model.append(bn_layer)

            act_layer = nn.LeakyReLU(inplace=True)
            model.append(act_layer)
            prev_nf = nf

        self.model = nn.Sequential(*model)

        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.model(x)
        x = torch.mean(x, dim=2)
        x = self.fc_layer(x)
        return x

class TrainAgent(BaseAgent):
    def build_net(self, config):
        net = PointNet2()
        if len(config.gpu_ids) > 1:
            net = nn.DataParallel(net)
        # net = EncoderPointNet()
        return net

    def set_loss_function(self):
        self.criterion = nn.MSELoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = torch.optim.Adam(self.net.parameters(), config.lr)  # , betas=(config.beta1, 0.9))
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, config.lr_step_size)

    def forward(self, data):
        points = data["points"].cuda()
        code = data["code"].cuda()

        pred_code = self.net(points)

        loss = self.criterion(pred_code, code)
        return pred_code, {"mse": loss}

def read_ply(path, with_normal=False):
    with open(path, 'rb') as f:
        plydata = PlyData.read(f)
        x = np.array(plydata['vertex']['x'])
        y = np.array(plydata['vertex']['y'])
        z = np.array(plydata['vertex']['z'])
        vertex = np.stack([x, y, z], axis=1)
        if with_normal:
            nx = np.array(plydata['vertex']['nx'])
            ny = np.array(plydata['vertex']['ny'])
            nz = np.array(plydata['vertex']['nz'])
            normals = np.stack([nx, ny, nz], axis=1)
    if with_normal:
        return np.concatenate([vertex, normals], axis=1)
    else:
        return vertex

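# A minimal usage sketch for read_ply; the file names are illustrative only.
# pc = read_ply("xyz_only.ply")                          # (N, 3) xyz
# pc_n = read_ply("with_normals.ply", with_normal=True)  # (N, 6) xyz followed by nx, ny, nz
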
class ShapeCodesDataset(Dataset):
    def __init__(self, phase, config):
        super(ShapeCodesDataset, self).__init__()
        self.n_points = config.n_points
        self.data_root = config.data_root
        # self.abc_root = "/mnt/disk6/wurundi/abc"
        self.abc_root = "/home/rundi/data/abc"
        self.pc_root = self.abc_root + "/pc_v5a_processed_merge"
        self.path = os.path.join(self.abc_root, "cad_e10_l6_c15_len60_min0_t100.json")
        with open(self.path, "r") as fp:
            self.all_data = json.load(fp)[phase]

        with h5py.File(self.data_root, 'r') as fp:
            self.zs = fp["{}_zs".format(phase)][:]

        self.noise = config.noise

    def __getitem__(self, index):
        data_id = self.all_data[index]
        pc_path = os.path.join(self.pc_root, data_id + '.ply')
        if not os.path.exists(pc_path):
            # fall back to the next sample when the point cloud file is missing
            return self.__getitem__(index + 1)
        pc_n = read_ply(pc_path, with_normal=True)
        pc = pc_n[:, :3]
        normal = pc_n[:, -3:]
        sample_idx = random.sample(list(range(pc.shape[0])), self.n_points)
        pc = pc[sample_idx]
        normal = normal[sample_idx]
        normal = normal / (np.linalg.norm(normal, axis=1, keepdims=True) + 1e-6)
        # perturb each point along its unit normal by a uniform offset in [-noise, noise]
        pc = pc + np.random.uniform(-self.noise, self.noise, (pc.shape[0], 1)) * normal
        pc = torch.tensor(pc, dtype=torch.float32)
        shape_code = torch.tensor(self.zs[index], dtype=torch.float32)
        return {"points": pc, "code": shape_code, "id": data_id}

    def __len__(self):
        return len(self.zs)

def get_dataloader(phase, config, shuffle=None):
    is_shuffle = phase == 'train' if shuffle is None else shuffle

    dataset = ShapeCodesDataset(phase, config)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=is_shuffle, num_workers=config.num_workers)
    return dataloader

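# A minimal batch-inspection sketch, assuming the ABC point clouds and the
# all_zs_ckpt*.h5 shape codes referenced by Config exist on disk (cfg is built below).
# example_loader = get_dataloader('train', cfg)
# batch = next(iter(example_loader))
# batch["points"].shape  # (batch_size, n_points, 3)
# batch["code"].shape    # (batch_size, 256)
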
parser = argparse.ArgumentParser()
# parser.add_argument('--proj_dir', type=str, default="/mnt/disk6/wurundi/cad_gen",
#                     help="path to project folder where models and logs will be saved")
parser.add_argument('--proj_dir', type=str, default="/home/rundi/project_log/cad_gen",
                    help="path to project folder where models and logs will be saved")
parser.add_argument('--exp_name', type=str, required=True, help="name of this experiment")
parser.add_argument('--ae_ckpt', type=str, required=True, help="desired checkpoint to restore")
parser.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint")
parser.add_argument('--ckpt', type=str, default='latest', required=False, help="desired checkpoint to restore")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--n_samples', type=int, default=100, help="number of samples to generate when testing")
parser.add_argument('-g', '--gpu_ids', type=str, default="0",
                    help="gpu to use, e.g. 0 or 0,1,2. CPU not supported.")
args = parser.parse_args()

if args.gpu_ids is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

cfg = Config(args)
print("data path:", cfg.data_root)
agent = TrainAgent(cfg)

if not args.test:
    # load from checkpoint if provided
    if args.cont:
        agent.load_ckpt(args.ckpt)
        # for g in agent.optimizer.param_groups:
        #     g['lr'] = 1e-5

    # create dataloader
    train_loader = get_dataloader('train', cfg)
    val_loader = get_dataloader('validation', cfg)
    val_loader = cycle(val_loader)

    # start training
    clock = agent.clock

    for e in range(clock.epoch, cfg.nr_epochs):
        # begin iteration
        pbar = tqdm(train_loader)
        for b, data in enumerate(pbar):
            # train step
            outputs, losses = agent.train_func(data)

            pbar.set_description("EPOCH[{}][{}]".format(e, b))
            pbar.set_postfix({k: v.item() for k, v in losses.items()})

            # validation step
            if clock.step % cfg.val_frequency == 0:
                data = next(val_loader)
                outputs, losses = agent.val_func(data)

            clock.tick()

        clock.tock()

        if clock.epoch % cfg.save_frequency == 0:
            agent.save_ckpt()

        # if clock.epoch % 10 == 0:
        agent.save_ckpt('latest')
else:
    # load trained weights
    agent.load_ckpt(args.ckpt)

    test_loader = get_dataloader('test', cfg)

    # save_dir = os.path.join(cfg.exp_dir, "results/fake_z_ckpt{}_num{}_pc".format(args.ckpt, args.n_samples))
    save_dir = os.path.join(cfg.exp_dir, "results/pc2cad_ckpt{}_num{}".format(args.ckpt, args.n_samples))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    all_zs = []
    all_ids = []
    pbar = tqdm(test_loader)
    cnt = 0
    for i, data in enumerate(pbar):
        with torch.no_grad():
            pred_z, _ = agent.forward(data)
            pred_z = pred_z.detach().cpu().numpy()
            # print(pred_z.shape)
            all_zs.append(pred_z)

        all_ids.extend(data['id'])
        pts = data['points'].detach().cpu().numpy()
        # for j in range(pred_z.shape[0]):
        #     save_path = os.path.join(save_dir, "{}.ply".format(data['id'][j]))
        #     write_ply(pts[j], save_path)
        # for j in range(pred_z.shape[0]):
        #     save_path = os.path.join(save_dir, "{}.h5".format(data['id'][j]))
        #     with h5py.File(save_path, 'w') as fp:
        #         fp.create_dataset("zs", data=pred_z[j])

        cnt += pred_z.shape[0]
        if cnt > args.n_samples:
            break

    all_zs = np.concatenate(all_zs, axis=0)
    # save generated z
    save_path = os.path.join(cfg.exp_dir, "results/pc2cad_z_ckpt{}_num{}.h5".format(args.ckpt, args.n_samples))
    ensure_dir(os.path.dirname(save_path))
    with h5py.File(save_path, 'w') as fp:
        fp.create_dataset("zs", shape=all_zs.shape, data=all_zs)

    save_path = os.path.join(cfg.exp_dir, "results/pc2cad_z_ckpt{}_num{}_ids.json".format(args.ckpt, args.n_samples))
    with open(save_path, 'w') as fp:
        json.dump(all_ids, fp)
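# Invocation sketch (the experiment name and checkpoint values are illustrative):
#   training: python pc2cad.py --exp_name pretrained --ae_ckpt 1000 -g 0
#   testing:  python pc2cad.py --exp_name pretrained --ae_ckpt 1000 --ckpt latest --test -g 0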