From 48dcc5f92f09665fe57473dc1fa45a73bdaadaec Mon Sep 17 00:00:00 2001 From: arda-argmax Date: Tue, 27 Aug 2024 00:09:50 -0700 Subject: [PATCH 1/4] fix layer_norm for x.shape[0]>1 --- python/src/diffusionkit/mlx/mmdit.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/src/diffusionkit/mlx/mmdit.py b/python/src/diffusionkit/mlx/mmdit.py index df65b22..65390e1 100644 --- a/python/src/diffusionkit/mlx/mmdit.py +++ b/python/src/diffusionkit/mlx/mmdit.py @@ -940,11 +940,12 @@ def affine_transform( norm_module: nn.Module = None, ) -> mx.array: """Affine transformation (Used for Adaptive LayerNorm Modulation)""" - if norm_module is not None: - return mx.fast.layer_norm( - x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps - ) - return x * (1.0 + residual_scale) + shift + if x.shape[0] == 1: + if norm_module is not None: + return mx.fast.layer_norm( + x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps + ) + return norm_module(x) * (1.0 + residual_scale) + shift def unpatchify( From 294ea0ffe58a6e4ace0d20361d8096332fcbd194 Mon Sep 17 00:00:00 2001 From: arda-argmax Date: Tue, 27 Aug 2024 00:56:14 -0700 Subject: [PATCH 2/4] edge case fix --- python/src/diffusionkit/mlx/mmdit.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/src/diffusionkit/mlx/mmdit.py b/python/src/diffusionkit/mlx/mmdit.py index 65390e1..d735e72 100644 --- a/python/src/diffusionkit/mlx/mmdit.py +++ b/python/src/diffusionkit/mlx/mmdit.py @@ -945,7 +945,10 @@ def affine_transform( return mx.fast.layer_norm( x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps ) - return norm_module(x) * (1.0 + residual_scale) + shift + if norm_module is not None: + return norm_module(x) * (1.0 + residual_scale) + shift + else: + return x * (1.0 + residual_scale) + shift def unpatchify( From d6d85bfc8905ce05b86af2db8ef1cabb38caf770 Mon Sep 17 00:00:00 2001 From: arda-argmax Date: Tue, 27 Aug 2024 01:06:06 -0700 Subject: [PATCH 3/4] fixed nested if --- python/src/diffusionkit/mlx/mmdit.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/src/diffusionkit/mlx/mmdit.py b/python/src/diffusionkit/mlx/mmdit.py index d735e72..0bb4dac 100644 --- a/python/src/diffusionkit/mlx/mmdit.py +++ b/python/src/diffusionkit/mlx/mmdit.py @@ -940,12 +940,11 @@ def affine_transform( norm_module: nn.Module = None, ) -> mx.array: """Affine transformation (Used for Adaptive LayerNorm Modulation)""" - if x.shape[0] == 1: - if norm_module is not None: - return mx.fast.layer_norm( - x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps - ) - if norm_module is not None: + if x.shape[0] == 1 and norm_module is not None: + return mx.fast.layer_norm( + x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps + ) + elif norm_module is not None: return norm_module(x) * (1.0 + residual_scale) + shift else: return x * (1.0 + residual_scale) + shift From ad0ad240895d000c7131d753ea7ac60dd424aa12 Mon Sep 17 00:00:00 2001 From: arda-argmax Date: Tue, 27 Aug 2024 01:14:39 -0700 Subject: [PATCH 4/4] update version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2df073f..addace9 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import find_packages, setup from setuptools.command.install import install -VERSION = "0.3.2" +VERSION = "0.3.3" class VersionInstallCommand(install):