Skip to content

Commit

Permalink
add bf16 support (deepjavalibrary#700)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan authored May 10, 2023
1 parent c39ded1 commit 5f6f86a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/llm_inf2_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ jobs:
python3 llm/client.py stable-diffusion stable-diffusion-2.1-base-neuron
docker rm -f $(docker ps -aq)
sudo rm -rf models
- name: Test stable diffusion bf16 with handler
working-directory: tests/integration
run: |
rm -rf models
python3 llm/prepare.py transformers_neuronx stable-diffusion-2.1-base-neuron-bf16
./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models pytorch-inf2-2 \
serve
curl http://127.0.0.1:8080/models
python3 llm/client.py stable-diffusion stable-diffusion-2.1-base-neuron
docker rm -f $(docker ps -aq)
sudo rm -rf models
- name: On fail step
if: ${{ failure() }}
working-directory: tests/integration
Expand Down
25 changes: 14 additions & 11 deletions engines/python/setup/djl_python/stable_diffusion_inf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def forward(self,
timestep,
encoder_hidden_states,
cross_attention_kwargs=None):
sample = self.unetwrap(sample,
timestep.float().expand((sample.shape[0], )),
encoder_hidden_states)[0]
sample = self.unetwrap(
sample,
timestep.to(sample.dtype).expand((sample.shape[0], )),
encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)


Expand All @@ -78,10 +79,10 @@ def forward(self, emb, attention_mask=None):
def get_torch_dtype_from_str(dtype: str) -> "torch.dtype":
    """Map a short dtype name to the corresponding torch dtype.

    Args:
        dtype: short dtype name; NeuronX stable diffusion accepts
            "fp32" or "bf16" (fp16 is intentionally not supported on inf2
            for this pipeline — see the commit "add bf16 support").

    Returns:
        torch.float32 for "fp32", torch.bfloat16 for "bf16".

    Raises:
        ValueError: if ``dtype`` is not one of the supported names.
    """
    if dtype == "fp32":
        return torch.float32
    elif dtype == "bf16":
        return torch.bfloat16
    # Fall through: unrecognized name. Message enumerates the supported set
    # so operators can correct option.dtype in serving.properties quickly.
    raise ValueError(
        f"Invalid data type: {dtype}. NeuronX currently only supports fp32 and bf16 for stable diffusion"
    )


Expand Down Expand Up @@ -197,9 +198,11 @@ def runtime_compile(self):

self.pipeline.unet = NeuronUNet(UNetWrap(self.pipeline.unet))

sample_1b = torch.randn([1, 4, 64, 64])
timestep_1b = torch.tensor(999).float().expand((1, ))
encoder_hidden_states_1b = torch.randn([1, 77, 1024])
sample_1b = torch.randn([1, 4, 64, 64]).to(self.data_type)
timestep_1b = torch.tensor(999).float().expand(
(1, )).to(self.data_type)
encoder_hidden_states_1b = torch.randn([1, 77,
1024]).to(self.data_type)
example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b

logging.info("Compiling UNET...")
Expand All @@ -214,7 +217,7 @@ def runtime_compile(self):

logging.info("Compiling post_quant_conv_in...")
# Compile vae post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 64, 64])
post_quant_conv_in = torch.randn([1, 4, 64, 64]).to(self.data_type)
self.pipeline.vae.post_quant_conv = torch_neuronx.trace(
self.pipeline.vae.post_quant_conv,
post_quant_conv_in,
Expand All @@ -223,7 +226,7 @@ def runtime_compile(self):

logging.info("Compiling VAE Decoder...")
# Compile vae decoder
decoder_in = torch.randn([1, 4, 64, 64])
decoder_in = torch.randn([1, 4, 64, 64]).to(self.data_type)
self.pipeline.vae.decoder = torch_neuronx.trace(
self.pipeline.vae.decoder,
decoder_in,
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@
"option.model_id": "s3://djl-llm/stable-diffusion-2-1-base-compiled/",
"option.tensor_parallel_degree": 2,
"option.use_stable_diffusion": True
},
"stable-diffusion-2.1-base-neuron-bf16": {
"option.model_id": "s3://djl-llm/stable-diffusion-2-1-base-compiled-bf16/",
"option.tensor_parallel_degree": 2,
"option.dtype": "bf16",
"option.use_stable_diffusion": True
}
}

Expand Down

0 comments on commit 5f6f86a

Please sign in to comment.