Skip to content

Commit

Permalink
Removed the double re-exporting of Pallas GPU/TPU APIs
Browse files Browse the repository at this point in the history
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant
jax._src.pallas.{triton,mosaic} submodules.

PiperOrigin-RevId: 641875127
  • Loading branch information
superbobry authored and jax authors committed Jun 10, 2024
1 parent 3b4039c commit 5e7ad60
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 141 deletions.
14 changes: 10 additions & 4 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,14 @@ pytype_strict_library(
":pallas_tpu_users",
],
deps = [
":pallas", # buildcleaner: keep
":pallas", # build_cleaner: keep
":tpu_custom_call",
"//jax/_src/pallas/mosaic",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:kernel_regeneration_util",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic:pipeline",
"//jax/_src/pallas/mosaic:primitives",
],
)

Expand Down Expand Up @@ -663,8 +668,9 @@ pytype_strict_library(
],
deps = [
":pallas",
"//jax/_src/pallas/mosaic_gpu",
"//jax/_src/pallas/triton",
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:primitives",
],
)

Expand Down
32 changes: 7 additions & 25 deletions jax/_src/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ package(

py_library(
name = "pallas",
srcs = glob(
include = ["**/*.py"],
exclude = [
"triton/*.py",
"mosaic/*.py",
],
),
srcs = [
"__init__.py",
"core.py",
"pallas_call.py",
"primitives.py",
"utils.py",
],
deps = [
"//jax",
"//jax:ad_util",
Expand All @@ -46,21 +46,3 @@ py_library(
"//jax/_src/lib",
] + py_deps("numpy"),
)

py_library(
name = "gpu",
visibility = [],
deps = [
":pallas",
"//jax/_src/pallas/triton",
],
)

py_library(
name = "tpu",
visibility = [],
deps = [
":pallas",
"//jax/_src/pallas/mosaic",
],
)
20 changes: 1 addition & 19 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
# Package for Mosaic-specific Pallas extensions

load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
"py_library_providing_imports_info",
)
load("//jaxlib:jax.bzl", "py_deps")

package(
default_applicable_licenses = [],
Expand All @@ -28,20 +24,6 @@ package(
],
)

py_library_providing_imports_info(
name = "mosaic",
srcs = ["__init__.py"],
lib_rule = py_library,
deps = [
":core",
":kernel_regeneration_util",
":lowering",
":pallas_call_registration",
":pipeline",
":primitives",
],
)

py_library(
name = "core",
srcs = ["core.py"],
Expand Down
39 changes: 0 additions & 39 deletions jax/_src/pallas/mosaic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for Mosaic lowering of Pallas call."""

from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.pipeline import BufferedRef
from jax._src.pallas.mosaic.pipeline import emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast
from jax._src.pallas.mosaic.primitives import delay
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
from jax._src.pallas.mosaic.primitives import prng_seed
from jax._src.pallas.mosaic.primitives import prng_random_bits

ANY = TPUMemorySpace.ANY
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM
13 changes: 0 additions & 13 deletions jax/_src/pallas/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
load(
"//jaxlib:jax.bzl",
"py_deps",
"py_library_providing_imports_info",
"pytype_strict_library",
)

Expand All @@ -28,18 +27,6 @@ package(
],
)

py_library_providing_imports_info(
name = "triton",
srcs = ["__init__.py"],
lib_rule = pytype_strict_library,
deps = [
":lowering",
":pallas_call_registration",
":primitives",
"//jax/_src/lib",
],
)

pytype_strict_library(
name = "primitives",
srcs = ["primitives.py"],
Expand Down
5 changes: 0 additions & 5 deletions jax/_src/pallas/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triton-specific Pallas APIs."""

from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import elementwise_inline_asm
4 changes: 2 additions & 2 deletions jax/experimental/pallas/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@

"""Triton-specific Pallas APIs."""

from jax._src.pallas.triton import approx_tanh
from jax._src.pallas.triton import elementwise_inline_asm
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import elementwise_inline_asm
72 changes: 38 additions & 34 deletions jax/experimental/pallas/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Mosaic specific Pallas functions."""
from jax._src.pallas.mosaic import ANY
from jax._src.pallas.mosaic import CMEM
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
from jax._src.pallas.mosaic import SMEM
from jax._src.pallas.mosaic import SemaphoreType
from jax._src.pallas.mosaic import TPUMemorySpace
from jax._src.pallas.mosaic import VMEM
from jax._src.pallas.mosaic import DeviceIdType
from jax._src.pallas.mosaic import async_copy
from jax._src.pallas.mosaic import async_remote_copy
from jax._src.pallas.mosaic import bitcast
from jax._src.pallas.mosaic import dma_semaphore
from jax._src.pallas.mosaic import delay
from jax._src.pallas.mosaic import device_id
from jax._src.pallas.mosaic import emit_pipeline_with_allocations
from jax._src.pallas.mosaic import emit_pipeline
from jax._src.pallas.mosaic import get_pipeline_schedule
from jax._src.pallas.mosaic import make_pipeline_allocations
from jax._src.pallas.mosaic import BufferedRef
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic import get_barrier_semaphore
from jax._src.pallas.mosaic import make_async_copy
from jax._src.pallas.mosaic import make_async_remote_copy
from jax._src.pallas.mosaic import repeat
from jax._src.pallas.mosaic import roll
from jax._src.pallas.mosaic import run_scoped
from jax._src.pallas.mosaic import semaphore
from jax._src.pallas.mosaic import semaphore_read
from jax._src.pallas.mosaic import semaphore_signal
from jax._src.pallas.mosaic import semaphore_wait
"""Mosaic-specific Pallas APIs."""

from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.pipeline import BufferedRef
from jax._src.pallas.mosaic.pipeline import emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast
from jax._src.pallas.mosaic.primitives import delay
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
from jax._src.pallas.mosaic.primitives import prng_seed
from jax._src.pallas.mosaic.primitives import prng_random_bits
from jax._src.tpu_custom_call import CostEstimate
from jax._src.pallas.mosaic import prng_seed
from jax._src.pallas.mosaic import prng_random_bits

ANY = TPUMemorySpace.ANY
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM

0 comments on commit 5e7ad60

Please sign in to comment.