Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reproducing Random graph matching Results #22

Open
BarakeelFanseuKamhoua opened this issue Sep 25, 2024 · 0 comments
Open

Reproducing Random graph matching Results #22

BarakeelFanseuKamhoua opened this issue Sep 25, 2024 · 0 comments

Comments

@BarakeelFanseuKamhoua
Copy link

BarakeelFanseuKamhoua commented Sep 25, 2024

Hello Rusty,
Thanks for the amazing paper. Unfortunately, I am not able to reproduce the results on Erdős–Rényi random graph matching using softmax for the different noise levels up to 0.5, as shown in Figure 2 (a, b, c, and d) of the paper. My results are orders of magnitude lower when using the torch model implementation and the random graph generation below:

Generate Erdos-Renyi graphs

def to_onehot(mat):
    """One-hot encode an array of non-negative integer-valued labels.

    Args:
        mat: Array of labels (e.g. node degrees). It may be 1-D or a column
            vector; it is flattened before encoding.

    Returns:
        Integer array of shape ``(mat.size, mat.shape[0])`` with a single 1
        per row, placed in the column given by that row's label.
    """
    k = mat.shape[0]
    # Flatten first: indexing with a (n, 1) column next to a (n,) row index
    # would broadcast to (n, n) and mark every label in every row.
    labels = np.asarray(mat).ravel().astype(int)
    encoded_arr = np.zeros((labels.size, k), dtype=int)
    encoded_arr[np.arange(labels.size), labels] = 1
    return encoded_arr

def compute_evecs(A, k, type=1):
    """Build node features from a square adjacency matrix.

    Args:
        A: Square adjacency matrix.
        k: Number of singular vectors to keep (only used when ``type == 1``).
        type: ``1`` selects spectral features; any other value selects
            one-hot degree features.

    Returns:
        For ``type == 1``, the element-wise absolute value of the ``k`` left
        singular vectors returned by ``svds`` (largest-magnitude singular
        values). Otherwise, the one-hot encoding of each node's row sum.
    """
    if type == 1:
        # NOTE(review): `sp` is presumably `scipy.sparse` with
        # `scipy.sparse.linalg` imported -- confirm against the file's imports.
        left_vectors = sp.linalg.svds(A, k, which="LM")[0]
        # The sign of a singular vector is arbitrary, so keep magnitudes only.
        return np.abs(left_vectors)
    # Row sums as a flat vector, so each node is encoded by its own degree.
    degrees = A @ np.ones(A.shape[0])
    return to_onehot(degrees)

def generate_er(n, p, sigma, learned=False, feat_type=1, feat_num=20):
    """Sample a pair of correlated Erdos-Renyi graphs plus node features.

    A parent graph is drawn with edge probability ``p / s`` and sub-sampled
    twice, independently, with keep-probability ``s = 1 - sigma**2 * (1 - p)``,
    so each child graph has edge probability ``p``. The second child is then
    relabelled by a random permutation.

    Args:
        n: Number of nodes.
        p: Edge probability of each child graph.
        sigma: Noise level; larger values make the two children less alike.
        learned: If true, return the ground truth as an ``(n, 2)`` array of
            index pairs instead of a permutation matrix.
        feat_type: Forwarded to ``compute_evecs`` as its ``type`` argument.
        feat_num: Forwarded to ``compute_evecs`` as its ``k`` argument.

    Returns:
        Tuple ``(A, B, A0, B0, x_A0, x_B0, P_rnd, P_orig)`` where ``A0``/``B0``
        are the 0/1 adjacency matrices (``B0`` already permuted), ``A``/``B``
        are their centred and scaled versions, ``x_A0``/``x_B0`` are node
        features, ``P_orig`` is the permutation matrix and ``P_rnd`` is either
        ``P_orig`` or its non-zero index pairs (see ``learned``).
    """

    def _symmetric_mask(prob):
        # Symmetric boolean matrix with an empty diagonal, Bernoulli(prob) entries.
        lower = np.tril(np.random.uniform(size=(n, n)) < prob, -1)
        return lower + lower.T

    s = 1 - (sigma ** 2) * (1 - p)
    G = _symmetric_mask(p / s)
    Z1 = _symmetric_mask(s)
    Z2 = _symmetric_mask(s)
    A0 = (G * Z1).astype(float)
    B0 = (G * Z2).astype(float)

    # Hide the correspondence by relabelling the second graph's nodes.
    idx = np.random.permutation(n)
    P_orig = np.eye(n)[:, idx]
    B0 = P_orig @ B0 @ P_orig.T

    # Centre by the edge probability and scale by sqrt(n * p * (1 - p)).
    scale = np.sqrt(n * p * (1 - p))
    A = (A0 - p) / scale
    B = (B0 - p) / scale

    x_A0 = compute_evecs(A0, feat_num, feat_type)
    x_B0 = compute_evecs(B0, feat_num, feat_type)

    if learned:
        # Each row is (row, col) of a non-zero entry of P_orig, i.e.
        # (node of B0, matching node of A0), ordered by the B0 node index.
        # NOTE(review): confirm the consumer expects pairs in this order.
        P_rnd = np.array(P_orig.nonzero()).T
    else:
        P_rnd = P_orig
    return A, B, A0, B0, x_A0, x_B0, P_rnd, P_orig

Model

class DGMC(torch.nn.Module):
    """Dense-batch graph matching model with iterative consensus refinement.

    ``psi_1`` embeds the nodes of both graphs; their inner products give the
    initial correspondence scores. ``psi_2`` then propagates random node
    indicators through both graphs for ``num_steps`` rounds, and an MLP turns
    the pairwise differences of its outputs into score updates.

    Args:
        psi_1: Module called as ``psi_1(adj, x)``; must return a 3-D tensor
            ``(batch, nodes, channels)`` and provide ``reset_parameters()``.
        psi_2: Module called as ``psi_2(adj, r)``; must expose ``input_dim``
            and ``output_dim`` and provide ``reset_parameters()``.
        num_steps: Number of refinement rounds.
        k: Stored on the instance; not used by any method of this class.
        detach: If true, the ``psi_1`` embeddings are detached from the
            autograd graph before scoring.
    """

    def __init__(self, psi_1, psi_2, num_steps, k=-1, detach=False):
        super().__init__()
        self.psi_1 = psi_1
        self.psi_2 = psi_2
        self.num_steps = num_steps
        self.k = k
        self.detach = detach
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(psi_2.output_dim, psi_2.output_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(psi_2.output_dim, 1),
        )

    def reset_parameters(self):
        """Re-initialise ``psi_1``, ``psi_2`` and every resettable MLP layer."""
        self.psi_1.reset_parameters()
        self.psi_2.reset_parameters()
        for layer in self.mlp:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward(self, x_s, adj_s, x_t, adj_t, y=None):
        """Compute initial and refined soft correspondences.

        Args:
            x_s, adj_s: Source node features and adjacency, passed to the psi
                modules as ``psi(adj, x)``.
            x_t, adj_t: Target node features and adjacency.
            y: Accepted for interface compatibility; not used here.

        Returns:
            ``(S_0, S_L)``: row-wise softmax of the scores before and after
            the refinement rounds, each of shape ``(B, N_s, N_t)``.
        """
        h_s = self.psi_1(adj_s, x_s)
        h_t = self.psi_1(adj_t, x_t)
        if self.detach:
            h_s, h_t = h_s.detach(), h_t.detach()

        B, N_s, _ = h_s.size()
        N_t = h_t.size(1)
        R_in, R_out = self.psi_2.input_dim, self.psi_2.output_dim

        S_hat = h_s @ h_t.transpose(-1, -2)
        S_0 = torch.softmax(S_hat, dim=-1)

        for _ in range(self.num_steps):
            S = torch.softmax(S_hat, dim=-1)
            # Random indicators on the source nodes, carried over to the
            # target nodes through the current soft correspondence.
            r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype,
                              device=h_s.device)
            r_t = S.transpose(-1, -2) @ r_s
            o_s = self.psi_2(adj_s, r_s)
            o_t = self.psi_2(adj_t, r_t)
            # Pairwise differences between every source and target node.
            D = o_s.view(B, N_s, 1, R_out) - o_t.view(B, 1, N_t, R_out)
            S_hat = S_hat + self.mlp(D).squeeze(-1)

        S_L = torch.softmax(S_hat, dim=-1)
        return S_0, S_L

    def loss(self, S, y, reduction='mean', EPS=1e-8):
        """Negative log-likelihood of the ground-truth correspondences.

        Args:
            S: Correspondence tensor of shape ``(B, N_s, N_t)``.
            y: Integer tensor of shape ``(B, M, 2)`` holding
                ``(source, target)`` index pairs.
            reduction: ``'none'`` returns the ``(B, M)`` per-pair losses,
                ``'sum'`` their total, and ``'mean'`` the total divided by
                the batch size ``B``.
            EPS: Added inside the logarithm for numerical stability.
        """
        assert reduction in ['none', 'mean', 'sum']
        batch = torch.arange(S.size(0), device=y.device).unsqueeze(-1)
        val = S[batch, y[..., 0], y[..., 1]]
        nll = -torch.log(val + EPS)
        if reduction == 'none':
            return nll
        total = nll.sum()
        return total / y.size(0) if reduction == 'mean' else total

    @torch.no_grad()
    def acc(self, S, y, reduction='mean'):
        """Top-1 matching accuracy.

        For every ground-truth pair in ``y`` (shape ``(B, M, 2)``), the row
        of ``S`` belonging to the source node is arg-maxed and compared with
        the target index.

        Returns:
            The fraction of correct pairs for ``'mean'``, or their count for
            ``'sum'``.
        """
        assert reduction in ['mean', 'sum']
        batch = torch.arange(S.size(0), device=y.device).unsqueeze(-1)
        # Select rows by source index so the pairs need not be sorted.
        pred = S[batch, y[..., 0]].argmax(dim=-1)
        correct = (pred == y[..., 1]).sum().item()
        if reduction == 'sum':
            return correct
        return correct / y[..., 1].numel()

    @torch.no_grad()
    def hits_at_k(self, k, S, y, reduction='mean'):
        """Hits@k: whether the true target is among the ``k`` best scores.

        Args:
            k: Number of top-scoring candidates to consider per source node.
            S: Correspondence tensor of shape ``(B, N_s, N_t)``.
            y: Integer tensor of shape ``(B, M, 2)`` holding
                ``(source, target)`` index pairs.
            reduction: ``'mean'`` for the hit rate, ``'sum'`` for the count.
        """
        assert reduction in ['mean', 'sum']
        batch = torch.arange(S.size(0), device=y.device).unsqueeze(-1)
        # Select rows by source index so the pairs need not be sorted.
        ranked = S[batch, y[..., 0]].argsort(dim=-1, descending=True)
        top_k = ranked[..., :k]
        correct = (top_k == y[..., 1].unsqueeze(-1)).sum().item()
        if reduction == 'sum':
            return correct
        return correct / y[..., 1].numel()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant