diff --git a/jax/BUILD b/jax/BUILD index ef7f07f5e4a6..ca9b71851bed 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -623,9 +623,14 @@ pytype_strict_library( ":pallas_tpu_users", ], deps = [ - ":pallas", # buildcleaner: keep + ":pallas", # build_cleaner: keep ":tpu_custom_call", - "//jax/_src/pallas/mosaic", + "//jax/_src/pallas/mosaic:core", + "//jax/_src/pallas/mosaic:kernel_regeneration_util", + "//jax/_src/pallas/mosaic:lowering", + "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic:pipeline", + "//jax/_src/pallas/mosaic:primitives", ], ) @@ -663,8 +668,9 @@ pytype_strict_library( ], deps = [ ":pallas", - "//jax/_src/pallas/mosaic_gpu", - "//jax/_src/pallas/triton", + "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:primitives", ], ) diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 85ea5a13dba6..c0fa02131bc8 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -27,13 +27,13 @@ package( py_library( name = "pallas", - srcs = glob( - include = ["**/*.py"], - exclude = [ - "triton/*.py", - "mosaic/*.py", - ], - ), + srcs = [ + "__init__.py", + "core.py", + "pallas_call.py", + "primitives.py", + "utils.py", + ], deps = [ "//jax", "//jax:ad_util", @@ -46,21 +46,3 @@ py_library( "//jax/_src/lib", ] + py_deps("numpy"), ) - -py_library( - name = "gpu", - visibility = [], - deps = [ - ":pallas", - "//jax/_src/pallas/triton", - ], -) - -py_library( - name = "tpu", - visibility = [], - deps = [ - ":pallas", - "//jax/_src/pallas/mosaic", - ], -) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index b16a819ab914..968bc43e8c9e 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -15,11 +15,7 @@ # Package for Mosaic-specific Pallas extensions load("@rules_python//python:defs.bzl", "py_library") -load( - "//jaxlib:jax.bzl", - "py_deps", - "py_library_providing_imports_info", -) +load("//jaxlib:jax.bzl", "py_deps") package( default_applicable_licenses = [], @@ -28,20 +24,6 @@ package( ], ) -py_library_providing_imports_info( - name = "mosaic", - srcs = ["__init__.py"], - lib_rule = py_library, - deps = [ - ":core", - ":kernel_regeneration_util", - ":lowering", - ":pallas_call_registration", - ":pipeline", - ":primitives", - ], -) - py_library( name = "core", srcs = ["core.py"], diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py index 2ab64462b102..38d13f42da99 100644 --- a/jax/_src/pallas/mosaic/__init__.py +++ b/jax/_src/pallas/mosaic/__init__.py @@ -11,42 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Module for Mosaic lowering of Pallas call.""" - -from jax._src.pallas.mosaic import core -from jax._src.pallas.mosaic.core import dma_semaphore -from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore -from jax._src.pallas.mosaic.core import SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace -from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata -from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata -from jax._src.pallas.mosaic.lowering import LoweringException -from jax._src.pallas.mosaic.pipeline import BufferedRef -from jax._src.pallas.mosaic.pipeline import emit_pipeline -from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations -from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule -from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations -from jax._src.pallas.mosaic.primitives import async_copy -from jax._src.pallas.mosaic.primitives import async_remote_copy -from jax._src.pallas.mosaic.primitives import bitcast -from jax._src.pallas.mosaic.primitives import delay -from jax._src.pallas.mosaic.primitives import device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType -from jax._src.pallas.mosaic.primitives import get_barrier_semaphore -from jax._src.pallas.mosaic.primitives import make_async_copy -from jax._src.pallas.mosaic.primitives import make_async_remote_copy -from jax._src.pallas.mosaic.primitives import repeat -from jax._src.pallas.mosaic.primitives import roll -from jax._src.pallas.mosaic.primitives import run_scoped -from jax._src.pallas.mosaic.primitives import semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import prng_random_bits - -ANY = TPUMemorySpace.ANY -CMEM = TPUMemorySpace.CMEM -SMEM = TPUMemorySpace.SMEM -VMEM = TPUMemorySpace.VMEM diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index d8b4c6ed3497..370cbb713ac5 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -17,7 +17,6 @@ load( "//jaxlib:jax.bzl", "py_deps", - "py_library_providing_imports_info", "pytype_strict_library", ) @@ -28,18 +27,6 @@ package( ], ) -py_library_providing_imports_info( - name = "triton", - srcs = ["__init__.py"], - lib_rule = pytype_strict_library, - deps = [ - ":lowering", - ":pallas_call_registration", - ":primitives", - "//jax/_src/lib", - ], -) - pytype_strict_library( name = "primitives", srcs = ["primitives.py"], diff --git a/jax/_src/pallas/triton/__init__.py b/jax/_src/pallas/triton/__init__.py index adade4e8a72c..38d13f42da99 100644 --- a/jax/_src/pallas/triton/__init__.py +++ b/jax/_src/pallas/triton/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Triton-specific Pallas APIs.""" - -from jax._src.pallas.triton.primitives import approx_tanh -from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index 8047aeed245b..adade4e8a72c 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -14,5 +14,5 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.triton import approx_tanh -from jax._src.pallas.triton import elementwise_inline_asm +from jax._src.pallas.triton.primitives import approx_tanh +from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index fcfe560958d1..e1e215a85cc0 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -12,38 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains Mosaic specific Pallas functions.""" -from jax._src.pallas.mosaic import ANY -from jax._src.pallas.mosaic import CMEM -from jax._src.pallas.mosaic import PrefetchScalarGridSpec -from jax._src.pallas.mosaic import SMEM -from jax._src.pallas.mosaic import SemaphoreType -from jax._src.pallas.mosaic import TPUMemorySpace -from jax._src.pallas.mosaic import VMEM -from jax._src.pallas.mosaic import DeviceIdType -from jax._src.pallas.mosaic import async_copy -from jax._src.pallas.mosaic import async_remote_copy -from jax._src.pallas.mosaic import bitcast -from jax._src.pallas.mosaic import dma_semaphore -from jax._src.pallas.mosaic import delay -from jax._src.pallas.mosaic import device_id -from jax._src.pallas.mosaic import emit_pipeline_with_allocations -from jax._src.pallas.mosaic import emit_pipeline -from jax._src.pallas.mosaic import get_pipeline_schedule -from jax._src.pallas.mosaic import make_pipeline_allocations -from jax._src.pallas.mosaic import BufferedRef -from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata -from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata -from jax._src.pallas.mosaic import get_barrier_semaphore -from jax._src.pallas.mosaic import make_async_copy -from jax._src.pallas.mosaic import make_async_remote_copy -from jax._src.pallas.mosaic import repeat -from jax._src.pallas.mosaic import roll -from jax._src.pallas.mosaic import run_scoped -from jax._src.pallas.mosaic import semaphore -from jax._src.pallas.mosaic import semaphore_read -from jax._src.pallas.mosaic import semaphore_signal -from jax._src.pallas.mosaic import semaphore_wait +"""Mosaic-specific Pallas APIs.""" + +from jax._src.pallas.mosaic import core +from jax._src.pallas.mosaic.core import dma_semaphore +from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec +from jax._src.pallas.mosaic.core import semaphore +from jax._src.pallas.mosaic.core import SemaphoreType +from jax._src.pallas.mosaic.core import TPUMemorySpace +from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata +from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata +from jax._src.pallas.mosaic.lowering import LoweringException +from jax._src.pallas.mosaic.pipeline import BufferedRef +from jax._src.pallas.mosaic.pipeline import emit_pipeline +from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations +from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule +from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations +from jax._src.pallas.mosaic.primitives import async_copy +from jax._src.pallas.mosaic.primitives import async_remote_copy +from jax._src.pallas.mosaic.primitives import bitcast +from jax._src.pallas.mosaic.primitives import delay +from jax._src.pallas.mosaic.primitives import device_id +from jax._src.pallas.mosaic.primitives import DeviceIdType +from jax._src.pallas.mosaic.primitives import get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import make_async_copy +from jax._src.pallas.mosaic.primitives import make_async_remote_copy +from jax._src.pallas.mosaic.primitives import repeat +from jax._src.pallas.mosaic.primitives import roll +from jax._src.pallas.mosaic.primitives import run_scoped +from jax._src.pallas.mosaic.primitives import semaphore_read +from jax._src.pallas.mosaic.primitives import semaphore_signal +from jax._src.pallas.mosaic.primitives import semaphore_wait +from jax._src.pallas.mosaic.primitives import prng_seed +from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.tpu_custom_call import CostEstimate -from jax._src.pallas.mosaic import prng_seed -from jax._src.pallas.mosaic import prng_random_bits + +ANY = TPUMemorySpace.ANY +CMEM = TPUMemorySpace.CMEM +SMEM = TPUMemorySpace.SMEM +VMEM = TPUMemorySpace.VMEM