Commit

Added V5E tpu and slices to accelerators (#2838)
* Added V5E tpu and slices to accelerators

Signed-off-by: pryce-turner <[email protected]>

* Added v5p and v6e

Signed-off-by: pryce-turner <[email protected]>

---------

Signed-off-by: pryce-turner <[email protected]>
pryce-turner authored Oct 24, 2024
1 parent 1a1ee53 commit 5add665
Showing 1 changed file with 194 additions and 0 deletions.
194 changes: 194 additions & 0 deletions flytekit/extras/accelerators.py
@@ -285,3 +285,197 @@ class _A100_80GB(_A100_80GB_Base):
#: .. autoclass:: _A100_80GB
#:    :members:
A100_80GB = _A100_80GB()


class _V5E_Base(MultiInstanceGPUAccelerator):
device = "tpu-v5-lite-podslice"


class _V5E(_V5E_Base):
"""
Slices of a `Google Cloud TPU v5e <https://cloud.google.com/tpu/docs/v5e>`_.
"""

    slice_1x1 = _V5E_Base.partitioned("1x1")
    """
    1x1 topology representing 1 TPU chip or 1/8 of a host.
    """
    slice_2x2 = _V5E_Base.partitioned("2x2")
    """
    2x2 topology representing 4 TPU chips or 1/2 of a host.
    """
    slice_2x4 = _V5E_Base.partitioned("2x4")
    """
    2x4 topology representing 8 TPU chips or 1 host.
    """
    slice_4x4 = _V5E_Base.partitioned("4x4")
    """
    4x4 topology representing 16 TPU chips or 2 hosts.
    """
    slice_4x8 = _V5E_Base.partitioned("4x8")
    """
    4x8 topology representing 32 TPU chips or 4 hosts.
    """
    slice_8x8 = _V5E_Base.partitioned("8x8")
    """
    8x8 topology representing 64 TPU chips or 8 hosts.
    """
    slice_8x16 = _V5E_Base.partitioned("8x16")
    """
    8x16 topology representing 128 TPU chips or 16 hosts.
    """
    slice_16x16 = _V5E_Base.partitioned("16x16")
    """
    16x16 topology representing 256 TPU chips or 32 hosts.
    """


#: Use this constant to specify that the task should run on V5E TPU.
#: `Google V5E Cloud TPU <https://cloud.google.com/tpu/docs/v5e>`_.
#:
#: Use pre-defined slices (as instance attributes). For example, to specify a 2x4 slice, use
#: ``V5E.slice_2x4``.
#: All available partitions are listed below:
#:
#: .. autoclass:: _V5E
#:    :members:
V5E = _V5E()
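
As a usage illustration, a minimal sketch (assuming flytekit's ``@task(accelerator=...)`` parameter; the ``train`` task is hypothetical):

    from flytekit import task
    from flytekit.extras.accelerators import V5E

    @task(accelerator=V5E.slice_2x4)
    def train() -> None:
        # Runs on a 2x4 v5e pod slice: 8 chips on a single host.
        ...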


class _V5P_Base(MultiInstanceGPUAccelerator):
device = "tpu-v5p-slice"


class _V5P(_V5P_Base):
"""
Slices of a `Google Cloud TPU v5p <https://cloud.google.com/tpu/docs/v5p>`_.
"""

    slice_2x2x1 = _V5P_Base.partitioned("2x2x1")
    """
    2x2x1 topology representing 8 TPU cores, 4 chips, 1 machine.
    """

    slice_2x2x2 = _V5P_Base.partitioned("2x2x2")
    """
    2x2x2 topology representing 16 TPU cores, 8 chips, 2 machines.
    """

    slice_2x4x4 = _V5P_Base.partitioned("2x4x4")
    """
    2x4x4 topology representing 64 TPU cores, 32 chips, 8 machines.
    """

    slice_4x4x4 = _V5P_Base.partitioned("4x4x4")
    """
    4x4x4 topology representing 128 TPU cores, 64 chips, 16 machines.
    """

    slice_4x4x8 = _V5P_Base.partitioned("4x4x8")
    """
    4x4x8 topology representing 256 TPU cores, 128 chips, 32 machines. Supports Twisted Topology.
    """

    slice_4x8x8 = _V5P_Base.partitioned("4x8x8")
    """
    4x8x8 topology representing 512 TPU cores, 256 chips, 64 machines. Supports Twisted Topology.
    """

    slice_8x8x8 = _V5P_Base.partitioned("8x8x8")
    """
    8x8x8 topology representing 1024 TPU cores, 512 chips, 128 machines.
    """

    slice_8x8x16 = _V5P_Base.partitioned("8x8x16")
    """
    8x8x16 topology representing 2048 TPU cores, 1024 chips, 256 machines. Supports Twisted Topology.
    """

    slice_8x16x16 = _V5P_Base.partitioned("8x16x16")
    """
    8x16x16 topology representing 4096 TPU cores, 2048 chips, 512 machines. Supports Twisted Topology.
    """

    slice_16x16x16 = _V5P_Base.partitioned("16x16x16")
    """
    16x16x16 topology representing 8192 TPU cores, 4096 chips, 1024 machines.
    """

    slice_16x16x24 = _V5P_Base.partitioned("16x16x24")
    """
    16x16x24 topology representing 12288 TPU cores, 6144 chips, 1536 machines.
    """


#: Use this constant to specify that the task should run on V5P TPU.
#: `Google V5P Cloud TPU <https://cloud.google.com/tpu/docs/v5p>`_.
#:
#: Use pre-defined slices (as instance attributes). For example, to specify a 2x4x4 slice, use
#: ``V5P.slice_2x4x4``.
#: All available partitions are listed below:
#:
#: .. autoclass:: _V5P
#:    :members:
V5P = _V5P()
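
The same pattern applies to v5p; for example, pinning a task to the 2x4x4 slice mentioned above (again a sketch, assuming the ``accelerator=`` task argument; ``train_large`` is hypothetical):

    from flytekit import task
    from flytekit.extras.accelerators import V5P

    @task(accelerator=V5P.slice_2x4x4)
    def train_large() -> None:
        # Runs on a 2x4x4 v5p slice: 32 chips across 8 machines.
        ...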


class _V6E_Base(MultiInstanceGPUAccelerator):
device = "tpu-v6e-slice"


class _V6E(_V6E_Base):
"""
Slices of a `Google Cloud TPU v6e <https://cloud.google.com/tpu/docs/v6e>`_.
"""

    slice_1x1 = _V6E_Base.partitioned("1x1")
    """
    1x1 topology representing 1 TPU core or 1/8 of a host.
    """

    slice_2x2 = _V6E_Base.partitioned("2x2")
    """
    2x2 topology representing 4 TPU cores or 1/2 of a host.
    """

    slice_2x4 = _V6E_Base.partitioned("2x4")
    """
    2x4 topology representing 8 TPU cores or 1 host.
    """

    slice_4x4 = _V6E_Base.partitioned("4x4")
    """
    4x4 topology representing 16 TPU cores or 2 hosts.
    """

    slice_4x8 = _V6E_Base.partitioned("4x8")
    """
    4x8 topology representing 32 TPU cores or 4 hosts.
    """

    slice_8x8 = _V6E_Base.partitioned("8x8")
    """
    8x8 topology representing 64 TPU cores or 8 hosts.
    """

    slice_8x16 = _V6E_Base.partitioned("8x16")
    """
    8x16 topology representing 128 TPU cores or 16 hosts.
    """

    slice_16x16 = _V6E_Base.partitioned("16x16")
    """
    16x16 topology representing 256 TPU cores or 32 hosts.
    """


#: Use this constant to specify that the task should run on V6E TPU.
#: `Google V6E Cloud TPU <https://cloud.google.com/tpu/docs/v6e>`_.
#:
#: Use pre-defined slices (as instance attributes). For example, to specify a 2x4 slice, use
#: ``V6E.slice_2x4``.
#: All available partitions are listed below:
#:
#: .. autoclass:: _V6E
#:    :members:
V6E = _V6E()
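
And likewise for v6e, here with the smallest slice (a sketch under the same ``accelerator=`` assumption; ``predict`` is hypothetical):

    from flytekit import task
    from flytekit.extras.accelerators import V6E

    @task(accelerator=V6E.slice_1x1)
    def predict() -> None:
        # Runs on a 1x1 v6e slice: a single core, 1/8 of a host.
        ...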
