diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index da26d231920..ecfe7d59672 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -37,11 +37,6 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Initialize Docker Buildx - uses: docker/setup-buildx-action@v3 - with: - install: true - - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 @@ -82,7 +77,12 @@ jobs: uses: huggingface/tailscale-action@main with: authkey: ${{ secrets.TAILSCALE_AUTHKEY }} - + slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }} + slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 @@ -148,6 +148,7 @@ jobs: DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} + network: host cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min @@ -209,12 +210,6 @@ jobs: run: | make install-integration-tests - - name: Tailscale - uses: huggingface/tailscale-action@main - if: needs.build-and-push.outputs.runs_on != 'amd-gpu-tgi' - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} - - name: Run tests run: | export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }} @@ -229,3 +224,9 @@ jobs: echo $SYSTEM pytest -s -vvvvv integration-tests + + - name: Tailscale Wait + if: ${{ failure() || runner.debug == '1' }} + uses: huggingface/tailscale-action@main + with: + waitForSSH: true diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 74479cc6c0b..d5ad9da3b80 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -33,9 +33,9 @@ jobs: - name: Install Rust uses: actions-rs/toolchain@v1 with: - # Released on: 02 May, 2024 - # https://releases.rs/docs/1.78.0/ - toolchain: 1.78.0 + # Released on: June 13, 2024 + # https://releases.rs/docs/1.79.0/ + toolchain: 1.79.0 override: true components: rustfmt, clippy - name: Install Protoc @@ -72,7 +72,7 @@ jobs: - name: Run server tests run: | pip install pytest - export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} + export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }} pytest -s -vv server/tests - name: Pre-commit checks run: | diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000000..b23f3150a5a --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +feedback@huggingface.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..d541e47f3dd --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,120 @@ + + +# Contribute to text-generation-inference + +Everyone is welcome to contribute, and we value everybody's contribution. Code +contributions are not the only way to help the community. Answering questions, helping +others, and improving the documentation are also immensely valuable. + +It also helps us if you spread the word! Reference the library in blog posts +about the awesome projects it made possible, shout out on Twitter every time it has +helped you, or simply ⭐️ the repository to say thank you. + +However you choose to contribute, please be mindful and respect our +[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md). + +**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).** + +## Ways to contribute + +There are several ways you can contribute to text-generation-inference. + +* Fix outstanding issues with the existing code. +* Submit issues related to bugs or desired new features. +* Contribute to the examples or to the documentation. + +> All contributions are equally valuable to the community. 🥰 + +## Fixing outstanding issues + +If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open +a Pull Request! + +## Submitting a bug-related issue or feature request + +Do your best to follow these guidelines when submitting a bug-related issue or a feature +request. It will make it easier for us to come back to you quickly and with good +feedback. + +### Did you find a bug? + +The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter. + +Before you report an issue, we would really appreciate it if you could **make sure the bug was not +already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the +library itself, and not your code. + +Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so +we can quickly resolve it: + +* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies). +* A short, self-contained, code snippet that allows us to reproduce the bug. +* The *full* traceback if an exception is raised. +* Attach any other additional information, like screenshots, you think may help. + +To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag: + +```bash +text-generation-launcher --env +``` + +This will precede the launch of the model with the information relative to your environment. We recommend pasting +that in your issue report. + +### Do you want a new feature? + +If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe: + +1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it + a feature related to something you need for a project? Is it something you worked on and think it could benefit + the community? + + Whatever it is, we'd love to hear about it! + +2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better + we'll be able to help you. +3. Provide a *code snippet* that demonstrates the feature's usage. +4. If the feature is related to a paper, please include a link. + +If your issue is well written we're already 80% of the way there by the time you create it. + +We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE) +to help you get started with your issue. + +## Do you want to implement a new model? + +New models are constantly released and if you want to implement a new model, please provide the following information: + +* A short description of the model and a link to the paper. +* Link to the implementation if it is open-sourced. +* Link to the model weights if they are available. + +If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference! + +## Do you want to add documentation? + +We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know +how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be +happy to make the changes or help you make a contribution if you're interested! + +## I want to become a maintainer of the project. How do I get there? + +TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have +motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference +service. + +If you are such an individual (or organization), please reach out to us and let's collaborate. diff --git a/Dockerfile b/Dockerfile index 1462833934c..c93372a2f1d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/Dockerfile_amd b/Dockerfile_amd index c79bc03c5b3..55da92046b7 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/Dockerfile_intel b/Dockerfile_intel index cb0e84bb23c..35362fc91cf 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,4 +1,4 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs index 48ac976a0c2..a0a9313a198 100644 --- a/benchmark/src/app.rs +++ b/benchmark/src/app.rs @@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Lowest: {:.2} {unit}", data.iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Highest: {:.2} {unit}", data.iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>( let min_latency: f64 = *latency_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_latency: f64 = *latency_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let min_throughput: f64 = *throughput_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_throughput: f64 = *throughput_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); // Char min max values let min_x = if zoom { diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index e18d7310a35..1585a25f4fc 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) { let min = data .iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max = data .iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); (average, *min, *max) } fn px(data: &[f64], p: u32) -> f64 { let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; - *data.get(i).unwrap_or(&std::f64::NAN) + *data.get(i).unwrap_or(&f64::NAN) } fn format_value(value: f64, unit: &'static str) -> String { diff --git a/benchmark/src/utils.rs b/benchmark/src/utils.rs index d096d65510f..20469991c39 100644 --- a/benchmark/src/utils.rs +++ b/benchmark/src/utils.rs @@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap Result<(), Box> { - println!("cargo:rerun-if-changed=../../proto/**"); + println!("cargo:rerun-if-changed=../../proto/"); fs::create_dir_all("src/v2/pb").unwrap_or(()); let mut config = prost_build::Config::new(); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 507ee859411..8c77896e9e0 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -# Released on: 02 May, 2024 -# https://releases.rs/docs/1.78.0/ -channel = "1.78.0" +# Released on: June 13, 2024 +# https://releases.rs/docs/1.79.0/ +channel = "1.79.0" components = ["rustfmt", "clippy"] diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index b1de58b2396..df5a8ae95f3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -40,31 +40,12 @@ def _load_gqa(config, prefix: str, weights): assert config.hidden_size % config.num_attention_heads == 0 assert config.num_attention_heads % weights.process_group.size() == 0 - weight = weights.get_multi_weights_col( + return TensorParallelColumnLinear.load_multi( + config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, - ) - - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - head_size = config.hidden_size // config.num_attention_heads - num_heads = config.num_attention_heads // weights.process_group.size() - num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ - (num_heads + 2 * num_key_value_heads) * head_size, - config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - - w = [ - weights.get_sharded(f"{p}.bias", dim=0) - for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] - ] - bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) - - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) + weights=weights, + bias=True, ) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 569b6925a0e..a0347cd8e73 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -240,7 +240,11 @@ async def serve_inner( interceptors=[ ExceptionInterceptor(), UDSOpenTelemetryAioServerInterceptor(), - ] + ], + options=[ + # Set the maximum possible message length: i32::MAX + ("grpc.max_receive_message_length", (1 << 31) - 1) + ], ) generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( TextGenerationService(model, Cache(), quantize, server_urls), server diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 0ac67bf0cb3..ea96c9ad1db 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -130,29 +130,57 @@ def get_sharded(self, tensor_name: str, dim: int): ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]): - slice_ = self._get_slice(name) - total_size = slice_.get_shape()[1] + def get_packed_sharded( + self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + ) -> torch.Tensor: + """ + Get a shard from a tensor that packs multiple tensors. + + When a tensor packs multiple tensors (such as QKV or an up + projection + gate projection), sharding with `get_sharded` is not + safe since it would not split the packed tensors across shards. + + This method shards a tensor, such that the packed tensors are + split across shards. + + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. + """ + slice_ = self._get_slice(tensor_name) + total_size = slice_.get_shape()[dim] block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes) world_size = self.process_group.size() rank = self.process_group.rank() - weights = [] + tensors = [] block_offset = 0 for block_size in block_sizes: assert ( block_size % world_size == 0 - ), f"Prepacked qkv cannot be sharded across {world_size} shards" + ), f"Prepacked tensor cannot be sharded across {world_size} shards" shard_block_size = block_size // world_size start = rank * shard_block_size stop = (rank + 1) * shard_block_size - weights.append(slice_[:, block_offset + start : block_offset + stop]) + if dim == 0: + tensor = slice_[block_offset + start : block_offset + stop] + elif dim == 1: + tensor = slice_[:, block_offset + start : block_offset + stop] + else: + raise NotImplementedError("Currently only dim=0 or dim=1 is supported") + tensors.append(tensor) block_offset += block_size + tensor = torch.cat(tensors, dim=dim) + tensor = tensor.to(device=self.device) - weight = torch.cat(weights, dim=1) - weight = weight.to(device=self.device) - return weight + # Avoid casting quantizer dtypes. + if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + tensor = tensor.to(dtype=self.dtype) + + return tensor def get_weights_col_packed_qkv( self, @@ -185,7 +213,9 @@ def get_weights_col_packed( from text_generation_server.layers.gptq import GPTQWeight try: - qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + qweight = self.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." @@ -193,8 +223,12 @@ def get_weights_col_packed( gptq_params = self._get_gptq_params() - qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes) - scales = self._get_qweight(f"{prefix}.scales", block_sizes) + qzeros = self.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) scales = scales.to(dtype=self.dtype) if quantize == "gptq" and gptq_params.quant_method == "gptq": @@ -237,13 +271,17 @@ def get_weights_col_packed( if quant_method == "gptq": gptq_params = self._get_gptq_params() try: - qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + qweight = self.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - scales = self._get_qweight(f"{prefix}.scales", block_sizes) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) g_idx = self.get_tensor(f"{prefix}.g_idx") weight = repack_gptq_for_marlin( qweight=qweight, @@ -257,34 +295,17 @@ def get_weights_col_packed( ) else: - B = self._get_qweight(f"{prefix}.B", block_sizes) - s = self._get_qweight(f"{prefix}.s", block_sizes) + B = self.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = self.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) weight = MarlinWeight(B=B, s=s) else: - slice_ = self._get_slice(f"{prefix}.weight") - total_size = slice_.get_shape()[0] - block_sizes = _blocks_to_block_sizes( - total_size=total_size, blocks=block_sizes + weight = self.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes ) - - world_size = self.process_group.size() - rank = self.process_group.rank() - - tensors = [] - block_offset = 0 - for block_size in block_sizes: - assert ( - block_size % world_size == 0 - ), f"Prepacked weights cannot be sharded across {world_size} shards" - shard_block_size = block_size // world_size - start = rank * shard_block_size - stop = (rank + 1) * shard_block_size - tensor = slice_[block_offset + start : block_offset + stop] - tensors.append(tensor) - block_offset += block_size - weight = torch.cat(tensors, dim=0) - weight = weight.to(device=self.device) - weight = weight.to(dtype=self.dtype) return weight def get_weights_col(self, prefix: str, quantize: str):