Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #32

Merged
merged 9 commits into from
Sep 5, 2023
Merged

Dev #32

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
437 changes: 437 additions & 0 deletions examples/2_forward/code.ipynb

Large diffs are not rendered by default.

111 changes: 101 additions & 10 deletions irtk/connectors/psdr_jit_connector.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
from ..connector import Connector
from ..scene import *
from ..config import *
from ..io import write_mesh
from collections import OrderedDict

import drjit
import psdr_jit
from drjit.scalar import Array3f
from drjit.cuda import Array3f as Vector3fC, Array3i as Vector3iC
from drjit.cuda.ad import Array3f as Vector3fD, Float32 as FloatD, Matrix4f as Matrix4fD, Matrix3f as Matrix3fD
from drjit.cuda.ad import Array3f as Vector3fD, Array1f as Vector1fD, Float32 as FloatD, Matrix4f as Matrix4fD, Matrix3f as Matrix3fD
from drjit.cuda.ad import Float32 as FloatD
import torch

import os

class PSDRJITConnector(Connector, connector_name='psdr_jit'):

backend = 'torch'
device = 'cuda'
ftype = torch.float32
itype = torch.long

def __init__(self):
super().__init__()

Expand Down Expand Up @@ -73,7 +67,7 @@ def renderC(self, scene, render_options, sensor_ids=[0], integrator_id=0):

images = []
for sensor_id in sensor_ids:
image = torch.zeros((h * w, c)).to(device).to(ftype)
image = to_torch_f(torch.zeros((h * w, c)))
for i in range(npass):
image_pass = integrator.renderC(cache['scene'], sensor_id).torch()
image += image_pass / npass
Expand Down Expand Up @@ -106,11 +100,48 @@ def renderD(self, image_grads, scene, render_options, sensor_ids=[0], integrator
drjit.backward(tmp)

for param_grad, drjit_param in zip(param_grads, drjit_params):
grad = drjit.grad(drjit_param).torch().to(device).to(ftype)
grad = to_torch_f(drjit.grad(drjit_param).torch())
grad = torch.nan_to_num(grad).reshape(param_grad.shape)
param_grad += grad

return param_grads
return param_grads

def forward_ad_mesh_translation(self, mesh_id, scene, render_options, sensor_ids=[0], integrator_id=0):
cache, drjit_params = self.update_scene_objects(scene, render_options)
assert len(drjit_params) == 0

P = FloatD(0.)
drjit.enable_grad(P)
psdr_mesh = cache['scene'].param_map[cache['name_map'][mesh_id]]
psdr_mesh.set_transform(Matrix4fD([[1.,0.,0.,P],[0.,1.,0.,0.],[0.,0.,1.,0.],[0.,0.,0.,1.],]))

cache['scene'].configure(sensor_ids)

npass = render_options['npass']
h, w, c = cache['film']['shape']
if type(integrator_id) == int:
integrator = list(cache['integrators'].values())[integrator_id]
elif type(integrator_id) == str:
integrator = cache['integrators'][integrator_id]
else:
raise RuntimeError('integrator_id is invalid: {integrator_id}')


image = to_torch_f(torch.zeros((h * w, c)))
grad_image = to_torch_f(torch.zeros((h * w, c)))

for j in range(npass):
drjit_image = integrator.renderD(cache['scene'], sensor_ids[0])
image += to_torch_f(drjit_image.torch()) / npass

drjit.set_grad(P, 1.0)
drjit.forward_to(drjit_image)
drjit_grad_image = drjit.grad(drjit_image)
grad_image += to_torch_f(drjit_grad_image.torch()) / npass

image = image.reshape(h, w, c)
grad_image = grad_image.reshape(h, w, c)
return image, grad_image

@PSDRJITConnector.register(Integrator)
def process_integrator(name, scene):
Expand Down Expand Up @@ -417,4 +448,64 @@ def enable_grad(drjit_param):
if param_name == 'radiance':
enable_grad(psdr_emitter.radiance.data)

return drjit_params


# Scene components specfic to psdr-jit
class MicrofacetBRDFPerVertex(ParamGroup):

def __init__(self, d, s, r):
super().__init__()

self.add_param('d', to_torch_f(d), is_tensor=True, is_diff=True, help_msg='diffuse reflectance')
self.add_param('s', to_torch_f(s), is_tensor=True, is_diff=True, help_msg='specular reflectance')
self.add_param('r', to_torch_f(r), is_tensor=True, is_diff=True, help_msg='roughness')

@PSDRJITConnector.register(MicrofacetBRDFPerVertex)
def process_microfacet_brdf_per_vertex(name, scene):
brdf = scene[name]
cache = scene.cached['psdr_jit']
psdr_scene = cache['scene']

# Create the object if it has not been created
if name not in cache['name_map']:
d = Vector3fD(brdf['d'])
s = Vector3fD(brdf['s'])
r = Vector1fD(brdf['r'])

psdr_bsdf = psdr_jit.MicrofacetBSDFPerVertex(s, d, r)
psdr_scene.add_BSDF(psdr_bsdf, name)
cache['name_map'][name] = f"BSDF[id={name}]"

psdr_brdf = psdr_scene.param_map[cache['name_map'][name]]

# Update parameters
updated = brdf.get_updated()
if len(updated) > 0:
for param_name in updated:
if param_name == 'd':
psdr_brdf.diffuseReflectance = Vector3fD(brdf['d'])
elif param_name == 's':
psdr_brdf.specularReflectance = Vector3fD(brdf['s'])
elif param_name == 'r':
psdr_brdf.roughness= Vector1fD(brdf['r'])
brdf.params[param_name]['updated'] = False

# Enable grad for parameters requiring grad
drjit_params = []

def enable_grad(drjit_param):
drjit.enable_grad(drjit_param)
drjit_params.append(drjit_param)

requiring_grad = brdf.get_requiring_grad()
if len(requiring_grad) > 0:
for param_name in requiring_grad:
if param_name == 'd':
enable_grad(psdr_brdf.diffuseReflectance)
elif param_name == 's':
enable_grad(psdr_brdf.specularReflectance)
elif param_name == 'r':
enable_grad(psdr_brdf.roughness)

return drjit_params
Loading