Make sharded Llama export test also compile to IREE module and verify numerics #237

Merged
268 changes: 234 additions & 34 deletions sharktank/tests/models/llama/sharded_llama_test.py
@@ -5,10 +5,20 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import unittest
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple, Union
import typing
import collections.abc
from collections import OrderedDict
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
import sharktank.ops as ops
from sharktank.types import Dataset
from sharktank.types import (
unbox_tensor,
ShardedTensor,
InferenceTensor,
DefaultPrimitiveTensor,
Dataset,
AnyTensor,
)
from sharktank.models.llama.testing import make_random_llama_theta
from sharktank.models.llama.sharding import shard_theta
from sharktank.layers.configs import LlamaHParams
@@ -18,6 +28,93 @@
import torch
from copy import deepcopy
from shark_turbine.aot import FxProgramsBuilder, export
import iree.runtime
from pathlib import Path


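# Create `device_count` logical HAL devices for `driver`, all backed by the
# first physical device the driver reports.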
def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]:
hal_driver = iree.runtime.get_driver(driver)
available_devices = hal_driver.query_available_devices()
# Back every logical device with the same physical device.
return [hal_driver.create_device(available_devices[0]) for _ in range(device_count)]


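# Load a compiled IREE module together with its sharded parameters (one
# "<stem>.rank{i}<suffix>" file per device) and return the VM module, context,
# and instance needed to invoke it.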
def load_iree_module(
module_path: str,
parameters_path: str,
devices: List[iree.runtime.HalDevice],
) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]:
params_path = Path(parameters_path)
# TODO: make IREE able to load the parameters from the top parameter file
# without having to specify the parameter file for each shard separately.
parameter_index = iree.runtime.ParameterIndex()
for i in range(len(devices)):
parameter_index.load(
file_path=str(
params_path.with_suffix(f".rank{i}{params_path.suffix}")
)
)
parameter_provider = parameter_index.create_provider(scope="model")
vm_instance = iree.runtime.VmInstance()
parameters_module = iree.runtime.create_io_parameters_module(
vm_instance, parameter_provider
)
vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
vm_context = iree.runtime.VmContext(
instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
)
return vm_module, vm_context, vm_instance


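# Look up `function_name` in the loaded module and invoke it with the given
# device arrays, normalizing a single-array result into a tuple.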
def run_iree_module_function(
module: iree.runtime.VmModule,
vm_context: iree.runtime.VmContext,
function_name: str,
args: List[iree.runtime.DeviceArray],
driver: str,
) -> List[iree.runtime.DeviceArray]:
vm_function = module.lookup_function(function_name)
invoker = iree.runtime.FunctionInvoker(
vm_context=vm_context,
# TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
# This works, but does not look right.
device=iree.runtime.get_device(driver, cache=False),
vm_function=vm_function,
)
res = invoker(*args)
if isinstance(res, iree.runtime.DeviceArray):
res = (res,)
return res


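# Flatten (possibly nested) torch tensors, primitive tensors, and sharded
# tensors into a flat list of iree.runtime.DeviceArray arguments, placing each
# shard on its corresponding device.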
def prepare_iree_module_function_args(
args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice]
) -> List[iree.runtime.DeviceArray]:
res = []
for arg in args:
if isinstance(arg, ShardedTensor):
assert len(devices) == len(arg.shards)
res.extend(
[
prepare_iree_module_function_args([shard], [device])[0]
for shard, device in zip(arg.shards, devices)
]
)
elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)):
res.append(
iree.runtime.asdevicearray(
devices[0], unbox_tensor(arg).to("cpu").numpy()
)
)
else:
assert isinstance(arg, collections.abc.Sequence)
res.extend(prepare_iree_module_function_args(arg, devices))
return res


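# Copy IREE device arrays back to the host as torch tensors.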
def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
return [torch.tensor(tensor.to_host()) for tensor in tensors]

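
# Illustrative only (not part of this change): a minimal sketch of how the
# helpers above chain together, assuming a module compiled to "program.vmfb",
# parameters saved as "parameters.irpa" (with per-rank "parameters.rank{i}.irpa"
# files next to it), and an exported entry point named "prefill".
def _example_run_sharded_prefill(prefill_args, shard_count=2):
    devices = get_iree_devices(driver="local-task", device_count=shard_count)
    module, vm_context, _vm_instance = load_iree_module(
        module_path="program.vmfb",
        parameters_path="parameters.irpa",
        devices=devices,
    )
    iree_args = prepare_iree_module_function_args(args=prefill_args, devices=devices)
    results = run_iree_module_function(
        module=module,
        vm_context=vm_context,
        function_name="prefill",
        args=iree_args,
        driver="local-task",
    )
    return iree_to_torch(*results)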

class ShardedLlamaTest(unittest.TestCase):
@@ -53,6 +150,8 @@ def setUp(self):
activation_dtype=self.dtype,
attention_dtype=self.dtype,
)
self.sharded_config = deepcopy(self.config)
self.sharded_config.tensor_parallelism_size = 2
self.theta = make_random_llama_theta(
config=self.config,
vocab_size=self.vocabulary_size,
@@ -61,7 +160,9 @@ def setUp(self):
[14, 9, self.block_seq_stride - 1], dtype=torch.int32
)

def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
def make_prefill_args(
self, model: PagedLlamaModelV1
) -> typing.OrderedDict[str, Any]:
batch_seq_len = round_up_to_multiple_of(
int(torch.max(self.prefill_seq_lens)), model.cache.pad_sequence_stride
)
@@ -79,16 +180,18 @@ def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
).view(self.batch_size, -1)
cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
return {
"tokens": token_ids,
"attention_mask": attention_mask,
"seq_block_ids": seq_block_ids,
"cache_state": cache_state,
}
return OrderedDict(
[
("tokens", token_ids),
("attention_mask", attention_mask),
("seq_block_ids", seq_block_ids),
("cache_state", cache_state),
]
)

def make_equal_unsharded_and_sharded_prefill_args(
self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
) -> Tuple[typing.OrderedDict[str, Any], typing.OrderedDict[str, Any]]:
prefill_args = self.make_prefill_args(model)
sharded_cache_state = sharded_model.cache.paged.allocate(
page_count=self.cache_page_count
@@ -103,7 +206,9 @@ def make_equal_unsharded_and_sharded_prefill_args(
sharded_prefill_args["cache_state"] = sharded_cache_state
return prefill_args, sharded_prefill_args

def make_decode_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
def make_decode_args(
self, model: PagedLlamaModelV1
) -> typing.OrderedDict[str, Any]:
start_positions = self.prefill_seq_lens.clone()
seq_lens = self.prefill_seq_lens + 1
batch_seq_len = round_up_to_multiple_of(
@@ -123,17 +228,19 @@ def make_decode_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
).view(self.batch_size, -1)
cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
return {
"tokens": decode_token_ids,
"attention_mask": attention_mask,
"start_positions": start_positions,
"seq_block_ids": seq_block_ids,
"cache_state": cache_state,
}
return OrderedDict(
[
("tokens", decode_token_ids),
("attention_mask", attention_mask),
("start_positions", start_positions),
("seq_block_ids", seq_block_ids),
("cache_state", cache_state),
]
)

def make_equal_unsharded_and_sharded_decode_args(
self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
):
) -> Tuple[typing.OrderedDict[str, Any], typing.OrderedDict[str, Any]]:
decode_args = self.make_decode_args(model)
sharded_decode_args = deepcopy(decode_args)
sharded_decode_args["cache_state"] = sharded_model.cache.paged.shard_state(
@@ -145,10 +252,8 @@ def testCompareToySizedModelToUnsharded(self):
"""Run a sharded variant of a toy model size and compare it against the
unsharded variant."""
model = PagedLlamaModelV1(self.theta, self.config)
sharded_config = deepcopy(self.config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(self.theta, sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config)
sharded_theta = shard_theta(self.theta, self.sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config)

# Verify prefill step.
(
@@ -194,20 +299,28 @@ def testCompareToySizedModelToUnsharded(self):
actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4
)

def testExportToySizedModelToMlir(self):
@unittest.skip(
(
"Before this does not crash at all we need "
"https://github.com/iree-org/iree/pull/18663 merged."
)
)
def testExportAndRunToySizedModelWithIree(self):
"""Test exporting to MLIR and compiling with IREE the sharded Llama model.
Test numerical accuracy of the IREE module against PyTorch."""

with tempfile.TemporaryDirectory() as temp_dir:
sharded_config = deepcopy(self.config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(self.theta, sharded_config)
sharded_theta = shard_theta(self.theta, self.sharded_config)
sharded_theta.rename_tensors_to_paths()
sharded_dataset = Dataset({}, sharded_theta)
parameters_path = f"{temp_dir}/parameters.irpa"
sharded_dataset.save(f"{temp_dir}/parameters.irpa")
sharded_dataset = Dataset.load(parameters_path, mmap=False)
sharded_parameters_path = f"{temp_dir}/parameters.irpa"
sharded_dataset.save(sharded_parameters_path)
sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False)
iree_driver = "local-task"

model = PagedLlamaModelV1(self.theta, self.config)
sharded_model = PagedLlamaModelV1(
sharded_dataset.root_theta, sharded_config
sharded_dataset.root_theta, self.sharded_config
)
sharded_fxb = FxProgramsBuilder(sharded_model)

@@ -222,9 +335,10 @@ def testExportToySizedModelToMlir(self):
def _(model, *args, **kwargs) -> torch.Tensor:
return model.prefill(*args, **kwargs)

_, sharded_decode_args = self.make_equal_unsharded_and_sharded_decode_args(
model, sharded_model
)
(
_,
sharded_decode_args,
) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)
# TODO: remove strict=False when
# https://github.com/pytorch/pytorch/issues/136757
# is resolved.
@@ -237,5 +351,91 @@ def _(model, *args, **kwargs) -> torch.Tensor:
def _(model, *args, **kwargs) -> torch.Tensor:
return model.decode(*args, **kwargs)

# Compile the IREE module.
output = export(sharded_fxb)
output.save_mlir(f"{temp_dir}/program.mlir")
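# One llvm-cpu target device per shard so that each rank of the sharded
# program is assigned its own (logical) device.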
output.session.set_flags(
*[
f"--iree-hal-target-device=llvm-cpu[{i}]"
for i in range(self.sharded_config.tensor_parallelism_size)
]
)
iree_module_path = f"{temp_dir}/program.vmfb"
output.compile(
save_to=iree_module_path,
target_backends=None,
)

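# Load the compiled module and its sharded parameters onto one logical device
# per shard.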
iree_devices = get_iree_devices(
driver=iree_driver,
device_count=self.sharded_config.tensor_parallelism_size,
)
iree_module, vm_context, vm_instance = load_iree_module(
module_path=iree_module_path,
devices=iree_devices,
parameters_path=sharded_parameters_path,
)

# Check IREE's prefill step is close to torch.
prefill_iree_args = prepare_iree_module_function_args(
args=deepcopy(sharded_prefill_args).values(), devices=iree_devices
)
prefill_iree_result = run_iree_module_function(
args=prefill_iree_args,
function_name="prefill",
module=iree_module,
vm_context=vm_context,
driver=iree_driver,
)
prefill_iree_result = iree_to_torch(*prefill_iree_result)
assert len(prefill_iree_result) == 1
expected_prefill_result = sharded_model.prefill(**sharded_prefill_args)
# TODO: The result is not entirely wrong, but investigate why the accuracy is
# this low for fp32 (atol=0.0011, rtol=0.013).
torch.testing.assert_close(
prefill_iree_result[0],
expected_prefill_result,
)
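# The cache state is mutated in place, so its shards are the trailing entries
# of the prepared argument list.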
prefill_iree_cache_state_shards = prefill_iree_args[
-self.config.tensor_parallelism_size - 1 :
]
prefill_iree_cache_state_shards = iree_to_torch(
*prefill_iree_cache_state_shards
)
for actual_cache_state_shard, expected_cache_state_shard in zip(
prefill_iree_cache_state_shards,
sharded_prefill_args["cache_state"][0].shards,
):
# TODO: debug inaccuracy.
torch.testing.assert_close(
actual_cache_state_shard, unbox_tensor(expected_cache_state_shard)
)

# Check IREE's decode step is close to torch.
decode_iree_args = prepare_iree_module_function_args(
args=deepcopy(sharded_decode_args).values(), devices=iree_devices
)
decode_iree_result = run_iree_module_function(
args=decode_iree_args,
function_name="decode",
module=iree_module,
vm_context=vm_context,
driver=iree_driver,
)
decode_iree_result = iree_to_torch(*decode_iree_result)
expected_decode_result = sharded_model.decode(**sharded_decode_args)
# TODO: debug inaccuracy.
torch.testing.assert_close(decode_iree_result[0], expected_decode_result)
decode_iree_cache_state_shards = decode_iree_args[
-self.config.tensor_parallelism_size - 1 :
]
decode_iree_cache_state_shards = iree_to_torch(
*decode_iree_cache_state_shards
)
for actual_cache_state_shard, expected_cache_state_shard in zip(
decode_iree_cache_state_shards,
sharded_decode_args["cache_state"][0].shards,
):
# TODO: debug inaccuracy.
torch.testing.assert_close(
actual_cache_state_shard, unbox_tensor(expected_cache_state_shard)
)