Skip to content

Commit

Permalink
Merge pull request #2 from andompesta/add-trt-support-cli-conflict
Browse files Browse the repository at this point in the history
Add TRT support; resolve CLI conflict
  • Loading branch information
andompesta authored Nov 26, 2024
2 parents c7fdb64 + a5986b5 commit 74c4c7a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/flux/trt/engine/vae_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def __init__(

def __call__(
self,
x: torch.Tensor,
latent: torch.Tensor,
) -> torch.Tensor:
return self.decode(x)
return self.decode(latent)

def decode(self, z: torch.Tensor) -> torch.Tensor:
z = z.to(dtype=self.tensors["latent"].dtype)
Expand Down
6 changes: 1 addition & 5 deletions src/flux/trt/exporter/vae_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
compression_factor=compression_factor,
scale_factor=model.params.scale_factor,
shift_factor=model.params.shift_factor,
model=model,
model=model.decoder, # we need to trace only the decoder
fp16=fp16,
tf32=tf32,
bf16=bf16,
Expand All @@ -55,10 +55,6 @@ def __init__(
# set proper dtype
self.prepare_model()

def get_model(self) -> torch.nn.Module:
self.model.forward = self.model.decode
return self.model

def get_input_names(self):
return ["latent"]

Expand Down

0 comments on commit 74c4c7a

Please sign in to comment.