-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
poetry.lock | ||
/dist/ |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from pcr import __version__ | ||
from pcr.model.PCRNetwork import PCRNetwork | ||
from pcr.pcn_training_config import Config | ||
from pcr.misc import download_checkpoint | ||
import torch | ||
|
||
|
||
def test_version(): | ||
assert __version__ == '0.1.0' | ||
|
||
id = '2kkeig53' | ||
version = 'v9' | ||
|
||
model = PCRNetwork(Config.Model) | ||
ckpt_path = download_checkpoint(id, version) | ||
model.load_state_dict(torch.load(ckpt_path)) | ||
|
||
partial = torch.rand((10000, 3)) | ||
complete = model(partial) | ||
print(complete) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = '0.1.0' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import requests | ||
# import torch | ||
# import os | ||
# import wandb | ||
# from pathlib import Path | ||
|
||
|
||
def onnx_minimum(x1, x2): | ||
return torch.where(x2 < x1, x2, x1) | ||
|
||
|
||
def fp_sampling(points, num: int): | ||
batch_size = points.shape[0] | ||
# TODO use onnx_cdists just to export to onnx, otherwise use torch.cdist | ||
# D = onnx_cdists(points, points) | ||
D = torch.cdist(points, points) | ||
# By default, takes the first point in the list to be the | ||
# first point in the permutation, but could be random | ||
res = torch.zeros((batch_size, 1), dtype=torch.int32, device=points.device) | ||
ds = D[:, 0, :] | ||
for i in range(1, num): | ||
idx = ds.max(dim=1)[1] | ||
res = torch.cat([res, idx.unsqueeze(1).to(torch.int32)], dim=1) | ||
ds = onnx_minimum(ds, D[torch.arange(batch_size), idx, :]) | ||
|
||
return res | ||
|
||
|
||
# def download_checkpoint(id, version): | ||
# ckpt = f'model-{id}:{version}' | ||
# project = 'pcr-grasping' | ||
# | ||
# ckpt_path = f'artifacts/{ckpt}/model.ckpt' if os.name != 'nt' else\ | ||
# f'artifacts/{ckpt}/model.ckpt'.replace(':', '-') | ||
# | ||
# if not Path(ckpt_path).exists(): | ||
# run = wandb.init(id=id, settings=wandb.Settings(start_method="spawn")) | ||
# run.use_artifact(f'rosasco/{project}/{ckpt}', type='model').download(f'artifacts/{ckpt}/') | ||
# wandb.finish(exit_code=0) | ||
# | ||
# return ckpt_path | ||
# | ||
# | ||
# import requests | ||
|
||
def download_checkpoint(id, destination): | ||
URL = f'https://drive.google.com/u/0/uc?id={id}&export=download' | ||
|
||
session = requests.Session() | ||
|
||
response = session.get(URL, stream = True) | ||
|
||
token = None | ||
for key, value in response.cookies.items(): | ||
if key.startswith('download_warning'): | ||
token = value | ||
break | ||
|
||
if token: | ||
params = { 'id' : id, 'confirm' : token } | ||
response = session.get(URL, params = params, stream = True) | ||
|
||
CHUNK_SIZE = 32768 | ||
with open(destination, "wb") as f: | ||
for chunk in response.iter_content(CHUNK_SIZE): | ||
if chunk: # filter out keep-alive new chunks | ||
f.write(chunk) | ||
|
||
if __name__ == '__main__': | ||
download_checkpoint('1vAxN2MF7sWayeG1uvh2Jc6Gr4WHkCX6X', './checkpoint') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from torch import nn | ||
from .MLP import MLP | ||
from .Transformer import PCTransformer | ||
|
||
|
||
class BackBone(nn.Module): | ||
def __init__(self, config): | ||
super().__init__() | ||
self.embed_dim = config.embed_dim | ||
self.knn_layer = config.knn_layer | ||
|
||
self.transformer = PCTransformer(in_chans=config.n_channels, | ||
embed_dim=config.embed_dim, | ||
depth=config.encoder_depth, | ||
mlp_ratio=config.mlp_ratio, | ||
qkv_bias=config.qkv_bias, | ||
knn_layer=config.knn_layer, | ||
num_heads=config.num_heads, | ||
attn_drop_rate=config.attn_drop_rate, | ||
drop_rate=config.drop_rate, | ||
qk_scale=config.qk_scale, | ||
out_size=config.out_size) | ||
|
||
# Select between deep feature extractor and not | ||
if config.use_deep_weights_generator: | ||
generator = MLP | ||
else: | ||
generator = nn.Linear | ||
|
||
# Select the right dimension for linear layers | ||
global_size = config.out_size | ||
if config.use_object_id: | ||
global_size = config.out_size * 2 | ||
|
||
# Generate first weight, bias and scale of the input layer of the implicit function | ||
self.output = nn.ModuleList([nn.ModuleList([ | ||
generator(global_size, config.hidden_dim * 3), | ||
generator(global_size, config.hidden_dim), | ||
generator(global_size, config.hidden_dim)])]) | ||
|
||
# Generate weights, biases and scales of the hidden layers of the implicit function | ||
for _ in range(config.depth): | ||
self.output.append(nn.ModuleList([ | ||
generator(global_size, config.hidden_dim * config.hidden_dim), | ||
generator(global_size, config.hidden_dim), | ||
generator(global_size, config.hidden_dim) | ||
])) | ||
# Generate weights, biases and scales of the output layer of the implicit function | ||
self.output.append(nn.ModuleList([ | ||
generator(global_size, config.hidden_dim), | ||
generator(global_size, 1), | ||
generator(global_size, 1), | ||
])) | ||
pass | ||
|
||
def _init_weights(self, m): | ||
if isinstance(m, nn.LayerNorm) or isinstance(m, nn.GroupNorm) or isinstance(m, nn.BatchNorm1d): | ||
nn.init.constant_(m.bias, 0) | ||
nn.init.constant_(m.weight, 1.0) | ||
elif isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
nn.init.xavier_normal_(m.weight.data, gain=1) | ||
|
||
if isinstance(m, nn.Linear) and m.bias is not None: | ||
nn.init.constant_(m.bias, 0) | ||
|
||
def forward(self, xyz, object_id=None): | ||
|
||
global_feature = self.transformer(xyz) # B M C and B M 3 | ||
|
||
fast_weights = [] | ||
for layer in self.output: | ||
fast_weights.append([ly(global_feature) for ly in layer]) | ||
|
||
# return fast_weights, global_feature | ||
return fast_weights, global_feature | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
from torch.nn import BCEWithLogitsLoss | ||
from torch.optim import Adam | ||
from pcr.pcn_training_config import Config | ||
|
||
|
||
class Decoder: | ||
def __init__(self, sdf, num_points, thr, itr): | ||
self.num_points = num_points | ||
self.thr = thr | ||
self.itr = itr | ||
|
||
self.sdf = sdf | ||
|
||
def __call__(self, fast_weights): | ||
with torch.enable_grad(): | ||
old = self.sdf.training | ||
self.sdf.eval() | ||
|
||
batch_size = fast_weights[0][0].shape[0] | ||
refined_pred = torch.tensor(torch.randn(batch_size, self.num_points, 3).cpu().detach().numpy() * 1, device=Config.General.device, | ||
requires_grad=True) | ||
|
||
loss_function = BCEWithLogitsLoss(reduction='mean') | ||
optim = Adam([refined_pred], lr=0.1) | ||
|
||
c1, c2, c3, c4 = 1, 0, 0, 0 #1, 0, 0 1, 1e3, 0 # 0, 1e4, 5e2 | ||
new_points = [[] for _ in range(batch_size)] | ||
# refined_pred.detach().clone() | ||
for step in range(self.itr): | ||
results = self.sdf(refined_pred, fast_weights) | ||
|
||
for i in range(batch_size): | ||
points = refined_pred[i].detach().clone()[(torch.sigmoid(results[i]).squeeze() >= self.thr), :] | ||
preds = torch.sigmoid(results[i]).detach().clone()[torch.sigmoid(results[i]).squeeze() >= self.thr] | ||
new_points[i] += [torch.cat([points, preds], dim=1)] | ||
|
||
gt = torch.ones_like(results[..., 0], dtype=torch.float32) | ||
gt[:, :] = 1 | ||
loss1 = c1 * loss_function(results[..., 0], gt) | ||
|
||
loss_value = loss1 | ||
|
||
self.sdf.zero_grad() | ||
optim.zero_grad() | ||
loss_value.backward(inputs=[refined_pred]) | ||
optim.step() | ||
|
||
selected = [torch.cat(points).squeeze() for points in new_points] | ||
res = torch.zeros([batch_size, self.num_points, 4], device=Config.General.device) | ||
for i, s in enumerate(selected): | ||
# torch.sum(torch.sum(torch.cat([s > 0.5, s < -0.5], dim=1), dim=1) != 0) | ||
k = min(s.size(0), self.num_points) | ||
perm = torch.randperm(s.size(0)) | ||
res[i][:k] = s[perm[:k]] | ||
|
||
self.sdf.train(old) | ||
return res[..., :3], res[..., -1] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class ImplicitFunction(nn.Module): | ||
|
||
def __init__(self, config, params=None): | ||
super().__init__() | ||
self.params = params | ||
self.relu = nn.LeakyReLU(0.2) | ||
# self.dropout = nn.Dropout(0.5) | ||
self.hidden_dim = config.hidden_dim | ||
|
||
def set_params(self, params): | ||
self.params = params | ||
|
||
def forward(self, points, params=None): | ||
if params is not None: | ||
self.params = params | ||
|
||
if self.params is None: | ||
raise ValueError('Can not run forward on uninitialized implicit function') | ||
|
||
x = points | ||
# TODO: I just added unsqueeze(1), reshape(-1) and bmm and everything works (or did I introduce some kind of bug?) | ||
weights, scales, biases = self.params[0] | ||
weights = weights.reshape(-1, 3, self.hidden_dim) | ||
scales = scales.unsqueeze(1) | ||
biases = biases.unsqueeze(1) | ||
|
||
x = torch.bmm(x, weights) * scales + biases | ||
# x = self.dropout(x) | ||
x = self.relu(x) | ||
|
||
for layer in self.params[1:-1]: | ||
weights, scales, biases = layer | ||
|
||
weights = weights.reshape(-1, self.hidden_dim, self.hidden_dim) | ||
scales = scales.unsqueeze(1) | ||
biases = biases.unsqueeze(1) | ||
|
||
x = torch.bmm(x, weights) * scales + biases | ||
# x = self.dropout(x) | ||
x = self.relu(x) | ||
|
||
weights, scales, biases = self.params[-1] | ||
|
||
weights = weights.reshape(-1, self.hidden_dim, 1) | ||
scales = scales.unsqueeze(1) | ||
biases = biases.unsqueeze(1) | ||
|
||
x = torch.bmm(x, weights) * scales + biases | ||
|
||
return x |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class MLP(nn.Module): | ||
def __init__(self, input_size, output_size, hidden_size=None): | ||
super().__init__() | ||
if hidden_size is None: | ||
hidden_size = (input_size + output_size) // 2 | ||
self.layers = nn.Sequential( | ||
nn.Linear(input_size, hidden_size), | ||
torch.nn.GELU(), | ||
torch.nn.LayerNorm(hidden_size), | ||
nn.Linear(hidden_size, output_size) | ||
) | ||
|
||
def forward(self, x): | ||
return self.layers(x) |