The Training Time Problem(Solved) #2

Open
bohanfeng opened this issue Apr 1, 2021 · 2 comments

bohanfeng commented Apr 1, 2021

No description provided.

bohanfeng changed the title from "The Training Time Problem" to "The Training Time Problem(Solved)" on Apr 11, 2021
QingbiaoLi (Collaborator) commented:

Hello Bohan,

Thanks for your email.

Regarding the training time, you can try the following methods:

  1. Store the data on an SSD.
  2. Fold the number of agents N into the batch size B to speed up training a bit, as in the forward pass below.
    def forward(self, inputTensor):
        # inputTensor: B x N x C x W x H (batch, agents, channels, width, height)
        (B, N, C, W, H) = inputTensor.shape
        # Fold the agent dimension into the batch dimension so the CNN encoder
        # processes all B*N observations in a single forward pass.
        input_currentAgent = inputTensor.reshape(B * N, C, W, H).to(self.config.device)
        featureMap = self.ConvLayers(input_currentAgent)
        featureMapFlatten = featureMap.view(featureMap.size(0), -1)
        compressfeature = self.compressMLP(featureMapFlatten)
        # Unfold back to B x N x F, then permute to B x F x N for graph filtering.
        extractFeatureMap_old = compressfeature.reshape(B, N, self.numFeatures2Share)
        extractFeatureMap = extractFeatureMap_old.permute([0, 2, 1])
        # DCP
        for l in range(self.L):
            # Graph filtering stage: the (sequential) GFL holds two modules per
            # layer (graph filter and nonlinearity), so the graph filter of
            # layer l sits at index 2*l; attach the GSO to it.
            self.GFL[2 * l].addGSO(self.S)
        # B x F x N -> B x G x N
        sharedFeature = self.GFL(extractFeatureMap)
        (_, num_G, _) = sharedFeature.shape
        # Fold agents into the batch again for the per-agent action MLP.
        sharedFeature_permute = sharedFeature.permute([0, 2, 1])
        sharedFeature_stack = sharedFeature_permute.reshape(B * N, num_G)
        action_predict = self.actionsMLP(sharedFeature_stack)
        return action_predict
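
For reference, here is a minimal, self-contained sketch of why folding N into B helps: one batched convolution over B*N observations replaces a Python-level loop over the N agents, and both give the same result. This is not code from the repository; the shapes and the stand-in conv layer are made-up assumptions for illustration only.

import torch
import torch.nn as nn

B, N, C, W, H = 4, 3, 3, 11, 11        # batch, agents, channels, width, height (illustrative)
conv = nn.Conv2d(C, 8, kernel_size=3)  # hypothetical stand-in for self.ConvLayers
x = torch.randn(B, N, C, W, H)

# Folded version: a single forward pass over B*N stacked observations.
folded = conv(x.reshape(B * N, C, W, H)).reshape(B, N, 8, W - 2, H - 2)

# Per-agent loop: N separate forward passes (slower, same numbers).
looped = torch.stack([conv(x[:, n]) for n in range(N)], dim=1)

assert torch.allclose(folded, looped, atol=1e-6)

The batched call keeps the GPU busy with one large kernel launch instead of N small ones, which is where the speed-up comes from.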

QingbiaoLi reopened this on Apr 11, 2021
bohanfeng (Author) commented:

Thanks very much!
