diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h new file mode 100644 index 0000000000000..45de0e949c4c7 --- /dev/null +++ b/include/tvm/relax/attrs/ccl.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/ccl.h + * \brief Attributes for ccl operators. + */ +#ifndef TVM_RELAX_ATTRS_CCL_H_ +#define TVM_RELAX_ATTRS_CCL_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in allreduce operators */ +struct AllReduceAttrs : public tvm::AttrsNode { + String op_type; /*!< \brief Name of the reduction to apply. */ + + TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") { + TVM_ATTR_FIELD(op_type).describe( + "The type of reduction operation to be applied to the input data. Supported values " + "are sum, prod, min, max and avg."); + } +}; // struct AllReduceAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_CCL_H_ diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h new file mode 100644 index 0000000000000..962bf5424596c --- /dev/null +++ b/include/tvm/runtime/disco/session.h @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file session.h + * \brief This file serves as the entry point of Disco and defines key data structures and + * interfaces. + * + * Disco is a distributed runtime that consists of a controller and a cluster of workers. The + * controller is responsible for managing the workers by broadcasting commands to all the workers + * together, and the workers are responsible for executing the commands. The controller and + * workers communicate with each other through a bi-directional channel. + * + * Different from a generic system, Disco is designed as a "single-program-multiple-data" (SPMD) + * runtime, which means that all the workers execute the same instruction at the same time, but the + * data they are working on may be different. For example, in data parallelism, each worker may + * work on a different batch of the data, but they all execute the same set of instructions. + * Therefore, if we imagine a virtual machine that executes the program, the structures of + * workers' register files could be considered as "identical" (single program) although the values + * may differ (multiple data). 
+ * + * **DRef.** Following the design above, consider the program in SPMD in a virtual ISA, then each + * worker is a virtual machine instance to execute the ISA maintaining its own register file. + * The controller denotes each of their register files with a unique integer "register id", + * and each worker uses this id to refer to the register file that resides on itself. + * DRef is a controller-side object backed by such a register id. The data it contains is not + * assumed to be directly accessible by the controller, with an exception for worker-0, which is a + * special worker that is always co-located with the controller. + * + * **Worker-0.** Worker-0 is a special worker that is always co-located with the controller. + * It is assumed that the controller can synchronize with and access the registers of worker-0. + * The Disco session provides multiple APIs to interact specifically with worker-0. + * To share data with other workers, a common paradigm in Disco is to copy data from the + * controller-side NDArray to worker-0, and then copy it to other workers using primitives on + * the data plane, for example, `broadcast` and `send`. + * + * **Control plane.** The controller broadcasts commands to all the workers as control signals. + * For example, the controller may ask all workers to load a library or call a function. + * Common control signals include: shutdown, retrieve a global PackedFunc, call a packed function, + * etc. The controller is assumed to keep a message channel to each worker to implement the + * broadcast behavior, and the message channel may vary depending on use cases. + * + * **Data plane.** The data channel is usually used to exchange data between workers, especially for + * tensor data which is usually large. For example, performing an allreduce operator for sharded + * matrix multiplication, or broadcasting for an input tensor. 
For efficiency, the data channel is + * usually backed by NCCL on NVIDIA GPUs, RCCL on AMD GPUs, or MPI on CPUs. + * + * **Session.** A Disco session is a primary interface to interact with the Disco runtime, serving + * as a global context that manages the controller and workers. It could be implemented as a + * multi-threaded pool of workers for single-node multi-GPU scenarios, or over TCP sockets for + * workloads that span over a cluster of nodes. + * + * **Channel.** Disco channel is a bi-directional communication channel between the controller and + * workers for exchanging control signals. It is no different from a generic RPC channel, but + * adopts TVM's PackedFunc calling convention to support polymorphic and variadic arguments. + */ +#ifndef TVM_RUNTIME_DISCO_SESSION_H_ +#define TVM_RUNTIME_DISCO_SESSION_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief All possible kinds of Disco commands. + */ +enum class DiscoAction : int32_t { + kShutDown = 0, + kKillReg = 1, + kGetGlobalFunc = 2, + kCallPacked = 3, + kSyncWorker = 4, + kCopyFromWorker0 = 5, + kCopyToWorker0 = 6, +}; + +/*! \brief Converts the enum class `DiscoAction` to string; unhandled values hit LOG(FATAL) */ +inline std::string DiscoAction2String(DiscoAction action) { + switch (action) { + case DiscoAction::kShutDown: + return "kShutDown"; + case DiscoAction::kKillReg: + return "kKillReg"; + case DiscoAction::kGetGlobalFunc: + return "kGetGlobalFunc"; + case DiscoAction::kCallPacked: + return "kCallPacked"; + case DiscoAction::kSyncWorker: + return "kSyncWorker"; + case DiscoAction::kCopyFromWorker0: + return "kCopyFromWorker0"; + case DiscoAction::kCopyToWorker0: + return "kCopyToWorker0"; + } + LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast(action); +} + +/*! + * \brief An object that exists on all workers. + * + * The controller assigns a unique "register id" to each object, and the worker uses this id to + * refer to the object residing on itself. 
+ */ +class DRefObj : public Object { + public: + /*! \brief Send deallocation command for `reg_id` */ + inline ~DRefObj(); + /*! + * \brief Get the value of a DRef from a remote worker. + * \param worker_id The id of the worker to be fetched from. + * \return The value of the register. + */ + inline TVMRetValue DebugGetFromRemote(int worker_id); + + static constexpr const char* _type_key = "runtime.disco.DRef"; + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef; + TVM_DECLARE_FINAL_OBJECT_INFO(DRefObj, Object); + + /*! \brief The id of the register */ + int64_t reg_id; + /*! \brief Back-pointer to the host controller session */ + ObjectRef session{nullptr}; +}; + +/*! + * \brief Managed reference to DRefObj. + * \sa DRefObj + * \note No public constructor is provided as it is not supposed to be directly created by users. + */ +class DRef : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj); +}; + +/*! + * \brief A Disco interactive session. It allows users to interact with the Disco command queue with + * various PackedFunc calling conventions. + */ +class SessionObj : public Object { + public: + virtual ~SessionObj() = default; + /*! + * \brief Call a PackedFunc on workers providing variadic arguments. + * \tparam Args In the variadic arguments, the supported types include: + * - integers and floating point numbers; + * - DataType; + * - Device; + * - std::string; + * - DRef. + * Examples of unsupported types: + * - NDArray, DLTensor; + * - TVM Objects, including PackedFunc, Module and String; + * \param func The function to be called. + * \param args The variadic arguments. + * \return A DRef to the return value of the function call + */ + template + DRef TVM_ALWAYS_INLINE CallPacked(const DRef& func, Args&&... args); + /*! \brief Get a global function on workers. */ + virtual DRef GetGlobalFunc(const std::string& name) = 0; + /*! 
+ * \brief Copy an NDArray from worker-0 to the controller-side NDArray (per the method name) + * \param host_array The controller-side array that receives the data (TODO confirm direction) + * \param remote_array The NDArray on worker-0 + */ + virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0; + /*! + * \brief Copy the controller-side NDArray to worker-0 (per the method name) + * \param host_array The controller-side array to be copied to worker-0 (TODO confirm direction) + * \param remote_array The NDArray on worker-0 + */ + virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0; + /*! + * \brief Synchronize the controller with a worker, and it will wait until the worker finishes + * executing this instruction. + * \param worker_id The id of the worker to be synced with. + * \note This function is usually used for worker-0, because it is the only worker that is + * assumed to collocate with the controller. Syncing with other workers may not be supported. + */ + virtual void SyncWorker(int worker_id) = 0; + /*! \brief Signal all the workers to shut down */ + virtual void Shutdown() = 0; + /*! + * \brief Get the value of a register from a remote worker. + * \param reg_id The id of the register to be fetched. + * \param worker_id The id of the worker to be fetched from. + * \return The value of the register. + */ + virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0; + + static constexpr const char* _type_key = "runtime.disco.Session"; + TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object); + + struct FFI; + friend struct SessionObj::FFI; + friend class DRefObj; + + protected: + /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */ + virtual void DeallocReg(int reg_id) = 0; + /*! \brief Call packed function on each worker using a packed sequence */ + virtual DRef CallWithPacked(const TVMArgs& args) = 0; +}; + +/*! + * \brief Managed reference to SessionObj + * \sa SessionObj + */ +class Session : public ObjectRef { + public: + /*! 
\brief Create a session backed by a thread pool of workers */ + static Session ThreadedSession(int num_workers); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); +}; + +/*! + * \brief A bi-directional channel for controller-worker communication, primarily used to transfer + * control messages but not data. NOTE(review): no virtual destructor is declared on this class. + */ +class DiscoChannel { + public: + /*! \brief Send a packed sequence to the receiver */ + virtual void Send(const TVMArgs& args) = 0; + /*! \brief Receive a packed sequence from worker */ + virtual TVMArgs Recv() = 0; + /*! \brief Reply a packed sequence to the sender */ + virtual void Reply(const TVMArgs& args) = 0; + /*! \brief Receive a reply from the worker */ + virtual TVMArgs RecvReply() = 0; +}; + +// Implementation details + +DRefObj::~DRefObj() { + if (this->session.defined()) { + Downcast(this->session)->DeallocReg(reg_id); + } +} + +TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) { + return Downcast(this->session)->DebugGetFromRemote(this->reg_id, worker_id); +} + +template +DRef SessionObj::CallPacked(const DRef& func, Args&&... args) { + constexpr int offset = 3; + constexpr int kNumArgs = offset + sizeof...(Args); + TVMValue values[kNumArgs]; + int type_codes[kNumArgs]; + PackArgs(values, type_codes, + /*.0=*/static_cast(DiscoAction::kCallPacked), // action + /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked + /*.2=*/func, // the function to be called + std::forward(args)...); + return this->CallWithPacked(TVMArgs(values, type_codes, kNumArgs)); +} + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_DISCO_SESSION_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 4fdca46e8d5b3..1fa875c22ee76 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -38,6 +38,7 @@ from . import image from . import memory from . import nn +from . import ccl # Register operator gradient functions from . 
import _op_gradient diff --git a/python/tvm/relax/op/ccl/__init__.py b/python/tvm/relax/op/ccl/__init__.py new file mode 100644 index 0000000000000..20746eb053136 --- /dev/null +++ b/python/tvm/relax/op/ccl/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""CCL related operators.""" +from .ccl import * diff --git a/python/tvm/relax/op/ccl/_ffi_api.py b/python/tvm/relax/op/ccl/_ffi_api.py new file mode 100644 index 0000000000000..cdf4687810613 --- /dev/null +++ b/python/tvm/relax/op/ccl/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operators serving for Collective Communications Library (CCL) operators""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.ccl", __name__) diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py new file mode 100644 index 0000000000000..7fede70543def --- /dev/null +++ b/python/tvm/relax/op/ccl/ccl.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Collective Communications Library (CCL) operators""" +from . import _ffi_api + + +def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name + """Allreduce operator + + Parameters + ---------- + x : relax.Expr + The input tensor. + op_type: str + The type of reduction operation to be applied to the input data. + Now "sum", "prod", "min", "max" and "avg" are supported. 
+ + Returns + ------- + result : relax.Expr + The result of allreduce. + """ + supported_op_types = ["sum", "prod", "min", "max", "avg"] + assert op_type in supported_op_types, ( + "Allreduce only supports limited reduction operations, " + f"including {supported_op_types}, but got {op_type}." + ) + return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py index 613bd8970f9e4..c851851ea90ec 100644 --- a/python/tvm/relax/transform/legalize_ops/__init__.py +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -16,6 +16,7 @@ # under the License. """Legalize high-level operator calls in Relax functions to call_tir.""" from . import binary +from . import ccl from . import create from . import datatype from . import grad diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py new file mode 100644 index 0000000000000..b1df10451800a --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/ccl.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""Default legalization function for ccl operators.""" +from ...block_builder import BlockBuilder +from ...expr import Call, Expr, ShapeExpr +from ...op import call_pure_packed +from .common import register_legalize + + +@register_legalize("relax.ccl.allreduce") +def _allreduce(_bb: BlockBuilder, call: Call) -> Expr: + op_type_str = call.attrs.op_type + op_type_map = { + "sum": 0, + "prod": 1, + "min": 2, + "max": 3, + "avg": 4, + } + if op_type_str not in op_type_map: + raise ValueError( + f"Unsupported reduction operation: {op_type_str}. " + f"Supported operations are {op_type_map.keys()}." + ) + return call_pure_packed( + "runtime.disco.allreduce", + call.args[0], + ShapeExpr([op_type_map[op_type_str]]), + sinfo_args=call.args[0].struct_info, + ) diff --git a/python/tvm/runtime/disco/__init__.py b/python/tvm/runtime/disco/__init__.py new file mode 100644 index 0000000000000..57c0548e2ed96 --- /dev/null +++ b/python/tvm/runtime/disco/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""TVM distributed runtime API.""" +from .session import DModule, DPackedFunc, DRef, Session, ThreadedSession diff --git a/python/tvm/runtime/disco/_ffi_api.py b/python/tvm/runtime/disco/_ffi_api.py new file mode 100644 index 0000000000000..340be86708db3 --- /dev/null +++ b/python/tvm/runtime/disco/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs from C++""" +from ..._ffi import _init_api + +_init_api("runtime.disco", __name__) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py new file mode 100644 index 0000000000000..6fafb9dbbc1b2 --- /dev/null +++ b/python/tvm/runtime/disco/session.py @@ -0,0 +1,301 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module defines a Session in Disco. Session is the primary interface that users interact +with the distributed runtime. +""" +from typing import Any, Callable, Optional, Sequence + +from ..._ffi import register_object +from ..._ffi.runtime_ctypes import Device +from ..ndarray import NDArray +from ..object import Object +from . import _ffi_api + + +@register_object("runtime.disco.DRef") +class DRef(Object): + """An object that exists on all workers. The controller process assigns a unique "register id" + to each object, and the worker process uses this id to refer to the object residing on itself. + """ + + @property + def session(self) -> "Session": + """Get the session that this DRef belongs to.""" + return _ffi_api.DRefSession(self) # type: ignore # pylint: disable=no-member + + def debug_get_from_remote(self, worker_id: int) -> Any: + """Get the value of a DRef from a remote worker. It is only used for debugging purposes. + + Parameters + ---------- + worker_id : int + The id of the worker to be fetched from. + + Returns + ------- + value : object + The value of the register. 
+ """ + return _ffi_api.DRefDebugGetFromRemote(self, worker_id) # type: ignore # pylint: disable=no-member + + +class DPackedFunc(DRef): + """A PackedFunc in a Disco session.""" + + def __init__(self, dref: DRef) -> None: + self.handle = dref.handle + dref.handle = None + + def __call__(self, *args) -> DRef: + return self.session.call_packed(self, *args) + + +class DModule(DRef): + """A Module in a Disco session.""" + + def __init__(self, dref: DRef) -> None: + self.handle = dref.handle + del dref.handle + + def __getitem__(self, name: str) -> DPackedFunc: + func = self.session._get_cached_method("runtime.ModuleGetFunction") + return DPackedFunc(func(self, name, False)) + + +@register_object("runtime.disco.Session") +class Session(Object): + """A Disco interactive session. It allows users to interact with the Disco command queue with + various PackedFunc calling convention.""" + + def _get_cached_method(self, name: str) -> Callable: + if not hasattr(self, "_cache"): + cache = self._cache = {} # pylint: disable=attribute-defined-outside-init + else: + cache = self._cache + if name not in cache: + func = cache[name] = self.get_global_func(name) + else: + func = cache[name] + return func + + def empty( + self, + shape: Sequence[int], + dtype: str, + device: Optional[Device] = None, + ) -> DRef: + """Create an empty NDArray on all workers and attach them to a DRef. + + Parameters + ---------- + shape : tuple of int + The shape of the NDArray. + dtype : str + The data type of the NDArray. + device : Optional[Device] = None + The device of the NDArray. + + Returns + ------- + array : DRef + The created NDArray. + """ + if device is None: + device = Device(device_type=0, device_id=0) + func = self._get_cached_method("runtime.disco.empty") + return func(*shape, dtype, device) + + def get_global_func(self, name: str) -> DRef: + """Get a global function on workers. + + Parameters + ---------- + name : str + The name of the global function. 
+ + Returns + ------- + func : DRef + The global packed function + """ + return DPackedFunc(_ffi_api.SessionGetGlobalFunc(self, name)) # type: ignore # pylint: disable=no-member + + def call_packed(self, func: DRef, *args) -> DRef: + """Call a PackedFunc on workers providing variadic arguments. + + Parameters + ---------- + func : DRef + The function to be called. + *args : various types + In the variadic arguments, the supported types include: + - integers and floating point numbers; + - DLDataType; + - DLDevice; + - str (std::string in C++); + - DRef. + + Returns + ------- + return_value : DRef + The return value of the function call. + + Notes + ----- + Examples of unsupported types: + - NDArray, DLTensor; + - TVM Objects, including PackedFunc, Module and String. + """ + return _ffi_api.SessionCallPacked(self, 0, 0, func, *args) # type: ignore # pylint: disable=no-member + + def sync_worker(self, worker_id: int = 0) -> None: + """Synchronize the controller with a worker, and it will wait until the worker finishes + executing this instruction. + + Parameters + ---------- + worker_id : int + The id of the worker to be synced with. + + Notes + ----- + This function is usually used for worker-0, because it is the only worker that is + assumed to collocate with the controller. Syncing with other workers may not be supported + and should only be used for debugging purposes. + """ + return _ffi_api.SessionSyncWorker(self, worker_id) # type: ignore # pylint: disable=no-member + + def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: + """Copy an NDArray from worker-0 to the controller-side NDArray. + + Parameters + ---------- + host_array : NDArray + The controller-side array that receives the data (TODO confirm direction). + remote_array : DRef + The NDArray on worker-0. 
+ """ + return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member + + def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: + """Copy the controller-side NDArray to worker-0. + + Parameters + ---------- + host_array : NDArray + The controller-side array to be copied to worker-0 (TODO confirm direction). + remote_array : DRef + The NDArray on worker-0. + """ + return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member + + def load_vm_module( + self, + path: str, + device: Optional[Device] = None, + ) -> DModule: + """Load a VM module from a file. + + Parameters + ---------- + path : str + The path to the VM module file. + device : Optional[Device] = None + The device to load the VM module to. Default to the default device of each worker. + + Returns + ------- + module : DModule + The loaded VM module. + """ + if device is None: + device = Device(device_type=0, device_id=0) + func = self._get_cached_method("runtime.disco.load_vm_module") + return DModule(func(path, device)) + + def init_ccl(self, api: str, *args): + """Initialize the underlying communication collective library. + + Parameters + ---------- + api : str + The name of the communication collective library. Currently supported libraries are: + - nccl + - rccl + (mpi is not accepted yet: the assertion in this method only allows nccl and rccl) + *args : various types + The arguments to be passed to the initialization function of the communication library. + """ + assert api in ("nccl", "rccl"), f"Unsupported CCL backend: {api}" + func = self.get_global_func(f"runtime.disco.{api}.init") + func(*args) + + def broadcast_from_worker0(self, array: DRef) -> None: + """Broadcast an array from worker-0 to all other workers. 
+ + Parameters + ---------- + array : DRef + The array to be broadcasted in-place + """ + func = self._get_cached_method("runtime.disco.broadcast_from_worker0") + return func(array) + + def allreduce( + self, + array: DRef, + op: str = "sum", # pylint: disable=invalid-name + ) -> DRef: + """Perform an allreduce operation on an array. + + Parameters + ---------- + array : DRef + The array to be reduced. + op : str = "sum" + The reduce operation to be performed. Available options are: + - "sum" + - "prod" + - "min" + - "max" + - "avg" + """ + func = self._get_cached_method("runtime.disco.allreduce") + if op not in REDUCE_OPS: + raise ValueError(f"Unsupported reduce op: {op}. Available ops are: {REDUCE_OPS.keys()}") + return func(array, REDUCE_OPS[op]) + + +@register_object("runtime.disco.ThreadedSession") +class ThreadedSession(Session): + """A Disco session backed by multi-threading.""" + + def __init__(self, num_workers: int) -> None: + """Create a disco session backed by multiple threads in the same process.""" + self.__init_handle_by_constructor__( + _ffi_api.SessionThreaded, # type: ignore # pylint: disable=no-member + num_workers, + ) + + +REDUCE_OPS = { + "sum": 0, + "prod": 1, + "min": 2, + "max": 3, + "avg": 4, +} diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 8a538c1868faa..8dcdaf2737e7c 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -150,6 +150,7 @@ zeros, zeros_like, nn, + ccl, ) from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo @@ -706,5 +707,6 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "zeros", "zeros_like", "nn", + "ccl", "erf", ] diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc new file mode 100644 index 0000000000000..40c897532ddbc --- /dev/null +++ b/src/relax/op/ccl/ccl.cc @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ccl.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.ccl.allreduce: call constructor, FFI registration and struct-info inference */ +TVM_REGISTER_NODE_TYPE(AllReduceAttrs); + +Expr allreduce(Expr x, String op_type) { + ObjectPtr attrs = make_object(); + attrs->op_type = std::move(op_type); + + static const Op& op = Op::Get("relax.ccl.allreduce"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ccl.allreduce").set_body_typed(allreduce); + +StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + return input_sinfo; /* the result reuses the struct info of the single tensor input */ +} + +TVM_REGISTER_OP("relax.ccl.allreduce") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "Input to which allreduce will be applied.") + .set_attr("FInferStructInfo", InferStructInfoAllReduce) + .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) + .set_attr("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h new file mode 100644 index 0000000000000..f87512c138c3b --- /dev/null +++ b/src/relax/op/ccl/ccl.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ccl.h
+ * \brief The functions to make Relax ccl operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_CCL_CCL_H_
+#define TVM_RELAX_OP_CCL_CCL_H_
+
+#include
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Create a call to the allreduce operator.
+ * \param data The input to which allreduce will be applied.
+ * \param op_type The type of reduction operation to apply.
+ * \return The resulting expression.
+ */
+Expr allreduce(Expr data, String op_type);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_CCL_CCL_H_
diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc
new file mode 100644
index 0000000000000..57b57b09f62ba
--- /dev/null
+++ b/src/runtime/disco/bcast_session.cc
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./bcast_session.h"
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief Private helpers shared by the member functions of BcastSessionObj below. */
+struct BcastSessionObj::Internal {
+  /*!
+   * \brief Pack `(action, reg_id, args...)` into a stack-allocated packed sequence and hand it
+   * to `self->BroadcastPacked`.
+   * \param self The session used for broadcasting.
+   * \param action The action, placed in slot 0 of the sequence.
+   * \param reg_id The register id, placed in slot 1 of the sequence.
+   * \param args Extra arguments, placed after the two leading slots.
+   */
+  template
+  static void TVM_ALWAYS_INLINE BroadcastUnpacked(BcastSessionObj* self, DiscoAction action,
+                                                  int64_t reg_id, Args&&... args) {
+    // Two leading slots (action, register id) plus one slot per extra argument.
+    constexpr int kNumArgs = 2 + sizeof...(Args);
+    TVMValue values[kNumArgs];
+    int type_codes[kNumArgs];
+    PackArgs(values, type_codes, static_cast(action), reg_id, std::forward(args)...);
+    self->BroadcastPacked(TVMArgs(values, type_codes, kNumArgs));
+  }
+
+  /*! \brief Create a DRef whose `reg_id` and `session` fields are set to the given values. */
+  static DRef MakeDRef(int reg_id, Session session) {
+    ObjectPtr p = make_object();
+    p->reg_id = reg_id;
+    p->session = session;
+    return DRef(std::move(p));
+  }
+};
+
+/*!
+ * \brief Allocate a register, broadcast `kGetGlobalFunc` with the function name, and return a
+ * DRef that refers to the allocated register.
+ */
+DRef BcastSessionObj::GetGlobalFunc(const std::string& name) {
+  int reg_id = AllocateReg();
+  BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kGetGlobalFunc, reg_id, name);
+  return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef(this));
+}
+
+/*!
+ * \brief Append `host_array` to the worker-0 host queue, then broadcast `kCopyFromWorker0` on the
+ * register of `remote_array`.
+ */
+void BcastSessionObj::CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) {
+  // The host array is queued before the command is broadcast.
+  this->AppendHostNDArray(host_array);
+  BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kCopyFromWorker0,
+                                               remote_array->reg_id);
+}
+
+/*!
+ * \brief Append `host_array` to the worker-0 host queue, then broadcast `kCopyToWorker0` on the
+ * register of `remote_array`.
+ */
+void BcastSessionObj::CopyToWorker0(const NDArray& host_array, const DRef& remote_array) {
+  this->AppendHostNDArray(host_array);
+  BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kCopyToWorker0,
+                                               remote_array->reg_id);
+}
+
+/*! \brief Broadcast `kShutDown`; the register id slot is filled with 0. */
+void BcastSessionObj::Shutdown() {
+  BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown, 0);
+}
+
+/*!
+ * \brief Broadcast `kSyncWorker` (the worker id travels in the register id slot), then wait for
+ * the reply of that worker and verify that it echoes the action and the worker id.
+ */
+void BcastSessionObj::SyncWorker(int worker_id) {
+  BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kSyncWorker, worker_id);
+  TVMArgs args = this->RecvReplyPacked(worker_id);
+  ICHECK_EQ(args.size(), 2);
+  DiscoAction action = static_cast(args[0].operator int());
+  int ret_worker_id = args[1];
+  ICHECK(action == DiscoAction::kSyncWorker);
+  ICHECK_EQ(ret_worker_id, worker_id);
+}
+
+/*!
+ * \brief Broadcast a function call to all workers.
+ *
+ * Slot 2 of `args` is read as the DRef of the function to call; the arguments of the call start
+ * at slot 3. The first three slots are overwritten in place with
+ * `(kCallPacked, <new register id>, <function register id>)` before broadcasting, so the
+ * buffers referenced by `args` are modified (see the `const_cast`s below).
+ * \param args The packed sequence described above.
+ * \return A DRef to the newly allocated register that is passed along with the call.
+ */
+DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) {
+  constexpr int offset = 3;
+  TVMValue* values = const_cast(args.values);
+  int* type_codes = const_cast(args.type_codes);
+  int num_args = args.num_args;
+  // Reject argument kinds that are not allowed in a broadcast call.
+  for (int i = offset; i < num_args; ++i) {
+    int type_code = type_codes[i];
+    CHECK(type_code != kTVMDLTensorHandle)
+        << "ValueError: Cannot pass DLTensor to Session.CallPacked";
+    CHECK(type_code != kTVMModuleHandle) << "ValueError: Cannot pass Module to Session.CallPacked ";
+    CHECK(type_code != kTVMPackedFuncHandle)
+        << "ValueError: Cannot pass PackedFunc to Session.CallPacked";
+    CHECK(type_code != kTVMNDArrayHandle)
+        << "ValueError: Cannot pass NDArray to Session.CallPacked";
+    CHECK(type_code != kTVMObjectRValueRefArg)
+        << "ValueError: Cannot pass RValue to Session.CallPacked";
+    // An object handle is accepted only when it passes the `IsObjectRef` test.
+    CHECK(type_code != kTVMObjectHandle || TVMArgValue(values[i], type_code).IsObjectRef())
+        << "ValueError: Cannot pass Object to Session.CallPacked";
+  }
+  int reg_id = AllocateReg();
+  {
+    TVMArgsSetter setter(values, type_codes);
+    // `func` is read before slot 2 is overwritten with its register id.
+    DRef func = args[2];
+    setter(0, static_cast(DiscoAction::kCallPacked));
+    setter(1, reg_id);
+    setter(2, func->reg_id);
+  }
+  this->BroadcastPacked(TVMArgs(values, type_codes, num_args));
+  return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef(this));
+}
+
+/*! \brief Broadcast `kKillReg` for the register and record its id in `free_regs_` for reuse. */
+void BcastSessionObj::DeallocReg(int reg_id) {
+  BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kKillReg, reg_id);
+  this->free_regs_.push_back(reg_id);
+}
+
+/*!
+ * \brief Return a register id: the most recently freed one if `free_regs_` is not empty,
+ * otherwise the current value of `reg_count_`, which is then incremented.
+ */
+int BcastSessionObj::AllocateReg() {
+  if (this->free_regs_.empty()) {
+    return this->reg_count_++;
+  }
+  int reg_id = this->free_regs_.back();
+  this->free_regs_.pop_back();
+  return reg_id;
+}
+
+/*! \brief Push `host_array` onto the worker-0 host array queue while holding the queue mutex. */
+void BcastSessionObj::AppendHostNDArray(const NDArray& host_array) {
+  std::lock_guard lock(worker_zero_data_.queue_mutex_);
+  worker_zero_data_.host_arrays.push(host_array);
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h
new file mode 100644
index 0000000000000..0221207b96f2c
--- /dev/null
+++ b/src/runtime/disco/bcast_session.h
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_BCAST_SESSION_H_
+#define TVM_RUNTIME_DISCO_BCAST_SESSION_H_
+
+#include
+
+#include
+#include
+
+#include "./worker.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief A Disco interactive session. It allows users to interact with the Disco command queue with
+ * various PackedFunc calling convention.
+ */
+class BcastSessionObj : public SessionObj {
+ public:
+  virtual ~BcastSessionObj() = default;
+
+  /*! \brief Broadcast a lookup of the global function `name`; returns a DRef to its register. */
+  DRef GetGlobalFunc(const std::string& name) override;
+  /*! \brief Queue `host_array` for worker-0 and broadcast a copy-from-worker-0 command. */
+  void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) override;
+  /*! \brief Queue `host_array` for worker-0 and broadcast a copy-to-worker-0 command. */
+  void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) override;
+  /*! \brief Broadcast a sync command and wait for the reply of worker `worker_id`. */
+  void SyncWorker(int worker_id) override;
+  /*! \brief Broadcast the shutdown command to all workers. */
+  void Shutdown() override;
+  /*! \brief Left pure virtual here; to be provided by the concrete session. */
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0;
+
+ protected:
+  /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */
+  void DeallocReg(int reg_id) override;
+  /*! \brief Call packed function on each worker using a packed sequence */
+  DRef CallWithPacked(const TVMArgs& args) override;
+  /*! \brief Allocate a register id, either from `free_regs_` or by incrementing `reg_count_` */
+  virtual int AllocateReg();
+  /*!
+   * \brief Append a controler-side NDArray to a special queue used to communicate with
+   * worker-0.
+   * \param host_array The array to be appended to worker-0
+   */
+  virtual void AppendHostNDArray(const NDArray& host_array);
+  /*!
+   * \brief Broadcast a command to all workers via TVM's PackedFunc calling convention.
+   * As part of the calling convention, the first argument in the packed sequence must be
+   * the action, and the second argument must be the register id.
+   * \param args The input arguments in TVM's PackedFunc calling convention
+   */
+  virtual void BroadcastPacked(const TVMArgs& args) = 0;
+  /*!
+   * \brief Receive a packed sequence from a worker. This function is usually called by the
+   * controler to communicate with worker-0, because the worker-0 is assumed to be always
+   * collocated with the controler. Receiving from other workers may not be supported.
+   * \param worker_id The id of the worker to receive the reply from.
+   * \return The packed sequence received.
+   */
+  virtual TVMArgs RecvReplyPacked(int worker_id) = 0;
+
+  /*! \brief A side channel to communicate with worker-0 */
+  WorkerZeroData worker_zero_data_;
+  /*! \brief Number of registers used, including those in `free_regs_` */
+  int reg_count_ = 1;
+  /*! \brief The register ids that have been deallocated */
+  std::vector free_regs_;
+
+  /*! \brief Implementation helpers; declared a friend so they can reach protected members. */
+  struct Internal;
+  friend struct Internal;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_DISCO_BCAST_SESSION_H_
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
new file mode 100644
index 0000000000000..656b359839a02
--- /dev/null
+++ b/src/runtime/disco/builtin.cc
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "./utils.h" +#include "./worker.h" + +namespace tvm { +namespace runtime { + +class DSOLibraryCache { + public: + Module Open(const std::string& library_path) { + std::lock_guard lock(mutex_); + Module& lib = cache_[library_path]; + if (!lib.defined()) { + lib = Module::LoadFromFile(library_path, ""); + } + return lib; + } + + std::unordered_map cache_; + std::mutex mutex_; +}; + +Module LoadVMModule(std::string path, Device device) { + static DSOLibraryCache cache; + Module dso_mod = cache.Open(path); + device = UseDefaultDeviceIfNone(device); + PackedFunc vm_load_executable = dso_mod.GetFunction("vm_load_executable"); + CHECK(vm_load_executable != nullptr) + << "ValueError: File `" << path + << "` is not built by RelaxVM, because `vm_load_executable` does not exist"; + Module mod = vm_load_executable(); + PackedFunc vm_initialization = mod.GetFunction("vm_initialization"); + CHECK(vm_initialization != nullptr) + << "ValueError: File `" << path + << "` is not built by RelaxVM, because `vm_initialization` does not exist"; + vm_initialization(static_cast(device.device_type), // + static_cast(device.device_id), // + static_cast(relax_vm::AllocatorType::kPooled), // + static_cast(kDLCPU), // + 0, // + static_cast(relax_vm::AllocatorType::kPooled)); + return mod; +} + +TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); + +TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body([](TVMArgs args, TVMRetValue* rv) -> void { + runtime::DataType dtype = args[args.num_args - 2]; + Device device = args[args.num_args - 1]; + int ndim = args.num_args - 2; + std::vector shape; + for (int i = 0; i < ndim; ++i) { + shape.push_back(args[i].operator int64_t()); + } + device = UseDefaultDeviceIfNone(device); + *rv = NDArray::Empty(ShapeTuple(shape), dtype, device); +}); + +TVM_REGISTER_GLOBAL("runtime.disco.allreduce") + 
.set_body_typed([](ObjectRef obj, TVMArgValue arg_op_type) -> TVMRetValue { + int op_type = -1; + if (arg_op_type.IsObjectRef()) { + ShapeTuple op = arg_op_type; + CHECK_EQ(op.size(), 1) << "ValueError: `op` must be an integer enum class"; + op_type = op[0]; + } else { + op_type = arg_op_type; + } + std::string ccl = DiscoWorker::ThreadLocal()->ccl; + std::string pf_name = "runtime.disco." + ccl + ".allreduce"; + const PackedFunc* pf = tvm::runtime::Registry::Get(pf_name); + CHECK(pf != nullptr) << "ValueError: Cannot find the allreduce function for " << ccl + << " via `" << pf_name << "`"; + return (*pf)(obj, op_type); + }); + +TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0") + .set_body([](TVMArgs args, TVMRetValue* rv) -> void { + std::string ccl = DiscoWorker::ThreadLocal()->ccl; + std::string pf_name = "runtime.disco." + ccl + ".broadcast_from_worker0"; + const PackedFunc* pf = tvm::runtime::Registry::Get(pf_name); + CHECK(pf != nullptr) << "ValueError: Cannot find the broadcast function for " << ccl + << " via `" << pf_name << "`"; + pf->CallPacked(args, rv); + }); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc new file mode 100644 index 0000000000000..c1bbfb72f855d --- /dev/null +++ b/src/runtime/disco/nccl/nccl.cc @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include +#include + +#include "../../cuda/cuda_common.h" +#include "./utils.h" + +namespace tvm { +namespace runtime { +namespace nccl { + +struct NCCLGlobalContext { + // TODO(@junrushao): support more flexible communicator pattern for generic SPMD usecases + std::vector device_ids; + std::vector communicators; + std::vector streams; + + ~NCCLGlobalContext() {} + + void Clear() { + for (ncclComm_t comm : this->communicators) { + NCCL_CALL(ncclCommDestroy(comm)); + } + device_ids.clear(); + communicators.clear(); + } + + static NCCLGlobalContext* Get() { + static NCCLGlobalContext ctx; + return &ctx; + } + + void Initialize(std::vector device_ids) { + DiscoWorker* worker = DiscoWorker::ThreadLocal(); + int num_workers = worker->num_workers; + CHECK_EQ(device_ids.size(), num_workers) + << "ValueError: There are " << num_workers << " worker(s), but " << device_ids.size() + << " device id(s) are provided."; + ncclUniqueId id; + NCCL_CALL(ncclGetUniqueId(&id)); + NCCL_CALL(ncclGroupStart()); + for (int worker_id = 0; worker_id < num_workers; ++worker_id) { + int device_id = device_ids[worker_id]; + ncclComm_t comm; + cudaStream_t stream; + CUDA_CALL(cudaSetDevice(device_id)); + CUDA_CALL(cudaStreamCreate(&stream)); + NCCL_CALL(ncclCommInitRank(&comm, num_workers, id, worker_id)); + this->streams.push_back(stream); + this->communicators.push_back(comm); + } + NCCL_CALL(ncclGroupEnd()); + this->device_ids = std::move(device_ids); + } + + static ncclComm_t ThreadLocalCommunicator() { + 
thread_local static ncclComm_t comm = + NCCLGlobalContext::Get()->communicators[DiscoWorker::ThreadLocal()->worker_id]; + return comm; + } + + static cudaStream_t ThreadLocalStream() { + thread_local static cudaStream_t stream = + NCCLGlobalContext::Get()->streams[DiscoWorker::ThreadLocal()->worker_id]; + return stream; + } +}; + +inline int64_t GetNumel(const ShapeTuple& shape) { + int64_t numel = 1; + for (int64_t d : shape) { + numel *= d; + } + return numel; +} + +NDArray AllReduce(NDArray send, int _reduce_kind) { + ShapeTuple shape = send.Shape(); + int64_t numel = GetNumel(shape); + NDArray recv = NDArray::Empty(shape, send->dtype, send->device); + ReduceKind reduce_kind = static_cast(_reduce_kind); + NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, + /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + /*op=*/AsNCCLRedOp(reduce_kind), + /*comm=*/NCCLGlobalContext::ThreadLocalCommunicator(), + /*stream=*/NCCLGlobalContext::ThreadLocalStream())); + return recv; +} + +void BroadcastFromZero(NDArray buffer) { + ShapeTuple shape = buffer.Shape(); + int64_t numel = GetNumel(shape); + NCCL_CALL(ncclBroadcast(buffer->data, buffer->data, numel, + /*datatype=*/AsNCCLDataType(DataType(buffer->dtype)), // + /*root=*/0, NCCLGlobalContext::ThreadLocalCommunicator(), + NCCLGlobalContext::ThreadLocalStream())); +} + +TVM_REGISTER_GLOBAL("runtime.disco.nccl.init").set_body([](TVMArgs args, TVMRetValue* rv) -> void { + // Parse the inputs into device ids + std::vector device_ids; + for (int i = 0; i < args.num_args; ++i) { + device_ids.push_back(args[i].operator int()); + } + // Set the `default_device` and `ccl` for the current worker + DiscoWorker* worker = DiscoWorker::ThreadLocal(); + worker->default_device = Device{DLDeviceType::kDLCUDA, device_ids[worker->worker_id]}; + worker->ccl = "nccl"; + // Setup global context only once + static std::once_flag flag; + std::call_once(flag, [&]() { NCCLGlobalContext::Get()->Initialize(device_ids); }); +}); + 
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce").set_body_typed(AllReduce); +TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0").set_body_typed(BroadcastFromZero); + +} // namespace nccl +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/disco/nccl/utils.h b/src/runtime/disco/nccl/utils.h new file mode 100644 index 0000000000000..4e5fb8cd74636 --- /dev/null +++ b/src/runtime/disco/nccl/utils.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_RUNTIME_DISCO_NCCL_UTILS_H_ +#define TVM_RUNTIME_DISCO_NCCL_UTILS_H_ + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace runtime { +namespace nccl { + +#define NCCL_CALL(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + LOG(FATAL) << "NCCLErrror: " << ncclGetErrorString(r); \ + } \ + } while (0) + +inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { + if (dtype == DataType::Int(8)) { + return ncclInt8; + } + if (dtype == DataType::UInt(8)) { + return ncclUint8; + } + if (dtype == DataType::Int(32)) { + return ncclInt32; + } + if (dtype == DataType::UInt(32)) { + return ncclUint32; + } + if (dtype == DataType::Int(64)) { + return ncclInt64; + } + if (dtype == DataType::UInt(64)) { + return ncclUint64; + } + if (dtype == DataType::Float(16)) { + return ncclFloat16; + } + if (dtype == DataType::Float(32)) { + return ncclFloat32; + } + if (dtype == DataType::Float(64)) { + return ncclFloat64; + } + if (dtype == DataType::BFloat(16)) { + return ncclBfloat16; + } + LOG(FATAL) << "ValueError: Unsupported data type " << dtype; +} + +inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) { + switch (kind) { + case ReduceKind::kSum: + return ncclSum; + case ReduceKind::kProd: + return ncclProd; + case ReduceKind::kMin: + return ncclMin; + case ReduceKind::kMax: + return ncclMax; + case ReduceKind::kAvg: + return ncclAvg; + } + LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast(kind); +} + +} // namespace nccl +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_DISCO_NCCL_UTILS_H_ diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc new file mode 100644 index 0000000000000..20c6e096afd48 --- /dev/null +++ b/src/runtime/disco/session.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include
+#include
+#include
+
+#include "./worker.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Helper used by the FFI registration at the end of this file.
+ * NOTE(review): presumably exists because `CallWithPacked` is not publicly accessible on
+ * SessionObj -- confirm against the declaration in session.h.
+ */
+struct SessionObj::FFI {
+  /*! \brief Forward the packed sequence to `sess->CallWithPacked`. */
+  static DRef CallWithPacked(Session sess, const TVMArgs& args) {
+    return sess->CallWithPacked(args);
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(DRefObj);
+TVM_REGISTER_OBJECT_TYPE(SessionObj);
+/* Global functions exposing the session and DRef interfaces under `runtime.disco.*`. */
+TVM_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession);
+TVM_REGISTER_GLOBAL("runtime.disco.DRefSession").set_body_typed([](DRef obj) {
+  return obj->session;
+});
+TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote")
+    .set_body_method(&DRefObj::DebugGetFromRemote);
+TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc")
+    .set_body_method(&SessionObj::GetGlobalFunc);
+TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0")
+    .set_body_method(&SessionObj::CopyFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0")
+    .set_body_method(&SessionObj::CopyToWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker")
+    .set_body_method(&SessionObj::SyncWorker);
+/* Argument 0 is the session itself; every following argument is forwarded to CallWithPacked. */
+TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked").set_body([](TVMArgs args, TVMRetValue* rv) {
+  Session self = args[0];
+  *rv = SessionObj::FFI::CallWithPacked(
+      self, TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1));
+});
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
new file mode 100644
index 0000000000000..6e256f2cf08a3
--- /dev/null
+++ b/src/runtime/disco/threaded_session.cc
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../../support/arena.h"
+#include "../../support/ring_buffer.h"
+#include "../minrpc/rpc_reference.h"
+#include "./bcast_session.h"
+#include "./worker.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief A message queue for passing packed sequences between threads. Messages are serialized
+ * into a ring buffer through `RPCReference`; a counter and a condition variable signal their
+ * arrival.
+ */
+class DiscoThreadedMessageQueue : public dmlc::Stream {
+ public:
+  /*! \brief Serialize `args` into the queue and wake up one waiting receiver. */
+  void Send(const TVMArgs& args) {
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
+    NotifyEnqueue();
+  }
+
+  /*!
+   * \brief Block until a message is available, then deserialize and return it.
+   * The DRef objects created by `ReadObject` for the previous message are released at the
+   * beginning of each call.
+   */
+  TVMArgs Recv() {
+    WaitDequeue();
+    this->dref_arena_.clear();
+    // Read the two leading fields of the message; the values read are not used afterwards.
+    uint64_t packet_nbytes = 0;
+    RPCCode code = RPCCode::kReturn;
+    this->Read(&packet_nbytes);
+    this->Read(&code);
+    TVMValue* values = nullptr;
+    int* type_codes = nullptr;
+    int num_args = 0;
+    RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
+    return TVMArgs(values, type_codes, num_args);
+  }
+
+ protected:
+  /*! \brief Increment the message counter under the mutex, then notify one waiter. */
+  void NotifyEnqueue() {
+    {
+      std::lock_guard lock{mutex_};
+      ++msg_cnt_;
+    }
+    condition_.notify_one();
+  }
+
+  /*! \brief Wait until the message counter is positive, then decrement it. */
+  void WaitDequeue() {
+    std::unique_lock lock(mutex_);
+    condition_.wait(lock, [this] { return msg_cnt_.load() > 0; });
+    --msg_cnt_;
+  }
+
+  /*
+   * The member functions from here to `ReadObject` are presumably the hooks that `RPCReference`
+   * (declared a friend below) calls while (de)serializing -- confirm in rpc_reference.h.
+   */
+
+  /*!
+   * \brief Reserve ring buffer capacity for what `bytes_available()` reports plus the incoming
+   * packet and an extra `sizeof(uint64_t)`.
+   */
+  void MessageStart(uint64_t packet_nbytes) {
+    std::lock_guard lock(mutex_);
+    size_t n = ring_buffer_.bytes_available();
+    n += packet_nbytes + sizeof(uint64_t);
+    this->ring_buffer_.Reserve(n);
+  }
+
+  /*! \brief Nothing to do once a message is complete. */
+  void MessageDone() {}
+
+  /*! \brief Abort with a fatal error describing `status`. */
+  void ThrowError(RPCServerStatus status) {
+    LOG(FATAL) << "InternalError: Unexpected error in RPC: " << RPCServerStatusToString(status);
+  }
+
+  /*! \brief Allocate `count` elements from `arena_`. */
+  template
+  T* ArenaAlloc(int count) {
+    static_assert(std::is_pod::value, "need to be trival");
+    return arena_.template allocate_(count);
+  }
+
+  /*! \brief Read `size` bytes from the ring buffer under the mutex; always returns `size`. */
+  size_t Read(void* data, size_t size) final {
+    std::lock_guard lock(mutex_);
+    ring_buffer_.Read(data, size);
+    return size;
+  }
+
+  /*! \brief Write `size` bytes into the ring buffer under the mutex. */
+  void Write(const void* data, size_t size) final {
+    std::lock_guard lock(mutex_);
+    ring_buffer_.Write(data, size);
+  }
+
+  /*!
+   * \brief Number of bytes reported for an object argument; unsupported objects are fatal.
+   * NOTE(review): `WriteObject` below emits the type index in addition to the register id, while
+   * only `sizeof(int64_t)` is reported here -- confirm whether the type index should be counted.
+   */
+  uint64_t GetObjectBytes(Object* obj) {
+    if (obj->IsInstance()) {
+      return sizeof(int64_t);
+    } else {
+      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
+                 << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
+    }
+  }
+
+  /*!
+   * \brief Serialize a supported object as `TypeIndex::kRuntimeDiscoDRef` followed by its
+   * register id; unsupported objects are fatal.
+   */
+  void WriteObject(Object* obj) {
+    if (obj->IsInstance()) {
+      int64_t reg_id = static_cast(obj)->reg_id;
+      this->Write(TypeIndex::kRuntimeDiscoDRef);
+      this->Write(reg_id);
+    } else {
+      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
+                 << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
+    }
+  }
+
+  /*!
+   * \brief Deserialize an object written by `WriteObject`. The created object carries the
+   * register id and a null session; it is kept alive in `dref_arena_` because `value` only
+   * receives a raw pointer to it.
+   */
+  void ReadObject(int* tcode, TVMValue* value) {
+    uint32_t type_index;
+    this->Read(&type_index);
+    if (type_index == TypeIndex::kRuntimeDiscoDRef) {
+      ObjectPtr dref = make_object();
+      this->Read(&dref->reg_id);
+      dref->session = Session{nullptr};
+      *tcode = kTVMObjectHandle;
+      value->v_handle = dref.get();
+      dref_arena_.push_back(std::move(dref));
+    } else {
+      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
+                 << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")";
+    }
+  }
+
+  using dmlc::Stream::Read;
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+  friend struct RPCReference;
+
+  /*! \brief Guards `ring_buffer_` and the updates of `msg_cnt_`. */
+  std::mutex mutex_;
+  /*! \brief Number of messages sent but not yet received. */
+  std::atomic msg_cnt_{0};
+  std::condition_variable condition_;
+
+  support::RingBuffer ring_buffer_;
+  support::Arena arena_;
+  /*! \brief Owns the objects created by `ReadObject` for the most recent message. */
+  std::vector> dref_arena_;
+};
+
+/*! \brief A channel made of two message queues, one for each direction. */
+class DiscoThreadChannel final : public DiscoChannel {
+ public:
+  /*! \brief Controler side: send a command to the worker. */
+  void Send(const TVMArgs& args) { controler_to_worker_.Send(args); }
+  /*! \brief Worker side: receive the next command from the controler. */
+  TVMArgs Recv() { return controler_to_worker_.Recv(); }
+  /*! \brief Worker side: send a reply to the controler. */
+  void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); }
+  /*! \brief Controler side: receive the next reply from the worker. */
+  TVMArgs RecvReply() { return worker_to_controler_.Recv(); }
+
+  DiscoThreadedMessageQueue controler_to_worker_;
+  DiscoThreadedMessageQueue worker_to_controler_;
+};
+
+/*! \brief A broadcast session whose workers each run `MainLoop` on their own `std::thread`. */
+class ThreadedSessionObj final : public BcastSessionObj {
+ public:
+  /*!
+   * \brief Create `num_workers` workers, each with its own channel and thread.
+   * Only worker 0 receives a pointer to `worker_zero_data_`; the others receive nullptr.
+   */
+  explicit ThreadedSessionObj(int num_workers) {
+    for (int i = 0; i < num_workers; ++i) {
+      std::unique_ptr channel = std::make_unique();
+      WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr;
+      workers_.emplace_back(std::make_unique(i, num_workers, data, channel.get()));
+      channels_.emplace_back(std::move(channel));
+      worker_threads_.emplace_back([worker = workers_.back().get()] { worker->MainLoop(); });
+    }
+  }
+
+  /*! \brief Broadcast the shutdown command, then join all worker threads. */
+  ~ThreadedSessionObj() {
+    this->Shutdown();
+    for (std::thread& worker : this->worker_threads_) {
+      worker.join();
+    }
+  }
+
+  /*!
+   * \brief Sync with worker `worker_id`, then read register `reg_id` directly from that worker's
+   * register file. Both lookups use `at`, so an out-of-range id throws.
+   */
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
+    this->SyncWorker(worker_id);
+    return this->workers_.at(worker_id)->register_file.at(reg_id);
+  }
+
+  /*! \brief Send the same packed sequence to every worker channel, in worker order. */
+  void BroadcastPacked(const TVMArgs& args) final {
+    for (const std::unique_ptr& channel : this->channels_) {
+      channel->Send(args);
+    }
+  }
+
+  /*! \brief Receive the next reply from the channel of worker `worker_id` (no bounds check). */
+  TVMArgs RecvReplyPacked(int worker_id) final { return channels_[worker_id]->RecvReply(); }
+
+  static constexpr const char* _type_key = "runtime.disco.ThreadedSession";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThreadedSessionObj, SessionObj);
+
+  /*! \brief One channel, worker and thread per worker id, in the same order. */
+  std::vector> channels_;
+  std::vector> workers_;
+  std::vector worker_threads_;
+};
+
+TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj);
+
+/*! \brief Create a session backed by `num_workers` worker threads. */
+Session Session::ThreadedSession(int num_workers) {
+  ObjectPtr n = make_object(num_workers);
+  return Session(std::move(n));
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/disco/utils.h b/src/runtime/disco/utils.h
new file mode 100644
index 0000000000000..9303ed15cb8c8
--- /dev/null
+++ b/src/runtime/disco/utils.h
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_UTILS_H_
+#define TVM_RUNTIME_DISCO_UTILS_H_
+
+#include
+#include
+
+#include
+
+#include "./worker.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Substitute the default device of the current thread's DiscoWorker when `device` has
+ * both `device_type == 0` and `device_id == 0`.
+ * \param device The device requested by the caller.
+ * \return The worker's `default_device` in the case above, otherwise `device` unchanged.
+ */
+inline Device UseDefaultDeviceIfNone(Device device) {
+  if (device.device_type == 0 && device.device_id == 0) {
+    return DiscoWorker::ThreadLocal()->default_device;
+  }
+  return device;
+}
+
+/*!
+ * \brief Possible kinds of reduction operations.
+ * The numeric values are the same as those of the `REDUCE_OPS` table on the Python side.
+ */
+enum class ReduceKind : int32_t {
+  kSum = 0,
+  kProd = 1,
+  kMin = 2,
+  kMax = 3,
+  kAvg = 4,
+};
+
+/*!
+ * \brief Converts `ReduceKind` to string
+ * \param kind The reduction kind.
+ * \return The name of the enumerator, e.g. "kSum"; a value outside the enum is a fatal error.
+ */
+inline std::string ReduceKind2String(ReduceKind kind) {
+  switch (kind) {
+    case ReduceKind::kSum:
+      return "kSum";
+    case ReduceKind::kProd:
+      return "kProd";
+    case ReduceKind::kMin:
+      return "kMin";
+    case ReduceKind::kMax:
+      return "kMax";
+    case ReduceKind::kAvg:
+      return "kAvg";
+  }
+  LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast(kind);
+}
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_DISCO_UTILS_H_
diff --git a/src/runtime/disco/worker.cc b/src/runtime/disco/worker.cc
new file mode 100644
index 0000000000000..b951f009acefe
--- /dev/null
+++ b/src/runtime/disco/worker.cc
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./worker.h" + +#include + +#include + +namespace tvm { +namespace runtime { + +struct ThreadLocalDiscoWorker { + DiscoWorker* worker; + + static ThreadLocalDiscoWorker* Get() { + thread_local static ThreadLocalDiscoWorker worker; + return &worker; + } +}; + +DiscoWorker* DiscoWorker::ThreadLocal() { return ThreadLocalDiscoWorker::Get()->worker; } + +struct DiscoWorker::Impl { + static void MainLoop(DiscoWorker* self) { + ThreadLocalDiscoWorker::Get()->worker = self; + LOG(INFO) << "[Thread " << std::this_thread::get_id() << "] Worker #" << self->worker_id + << " Launched"; + while (true) { + TVMArgs args = self->channel->Recv(); + DiscoAction action = static_cast(args[0].operator int()); + int64_t reg_id = args[1]; + switch (action) { + case DiscoAction::kShutDown: { + Shutdown(self); + return; + } + case DiscoAction::kKillReg: { + GetReg(self, reg_id) = nullptr; + break; + } + case DiscoAction::kGetGlobalFunc: { + GetGlobalFunc(self, reg_id, args[2]); + break; + } + case DiscoAction::kCallPacked: { + int func_reg_id = args[2]; + PackedFunc func = GetReg(self, func_reg_id); + CallPacked(self, reg_id, func, + TVMArgs(args.values + 3, args.type_codes + 3, args.num_args - 3)); + break; + } + case DiscoAction::kCopyFromWorker0: { + CopyFromWorker0(self, reg_id); + break; + } + case DiscoAction::kCopyToWorker0: { + CopyToWorker0(self, reg_id); + break; + } + case 
DiscoAction::kSyncWorker: { + SyncWorker(self, reg_id); + break; + } + } + } + } + + static void Shutdown(DiscoWorker* self) {} + + static void GetGlobalFunc(DiscoWorker* self, int reg_id, const std::string& name) { + const PackedFunc* pf = runtime::Registry::Get(name); + CHECK(pf) << "ValueError: Cannot find global function: " << name; + if (reg_id != 0) { + GetReg(self, reg_id) = *pf; + } + } + + static NDArray GetNDArrayFromHost(DiscoWorker* self) { + std::lock_guard lock(self->worker_zero_data->queue_mutex_); + NDArray array = self->worker_zero_data->host_arrays.front(); + self->worker_zero_data->host_arrays.pop(); + return array; + } + + static void CopyFromWorker0(DiscoWorker* self, int reg_id) { + if (self->worker_zero_data != nullptr) { + NDArray tgt = GetNDArrayFromHost(self); + NDArray src = GetReg(self, reg_id); + tgt.CopyFrom(src); + } + } + + static void CopyToWorker0(DiscoWorker* self, int reg_id) { + if (self->worker_zero_data != nullptr) { + NDArray src = GetNDArrayFromHost(self); + NDArray tgt = GetReg(self, reg_id); + tgt.CopyFrom(src); + } + } + + static void SyncWorker(DiscoWorker* self, int worker_id) { + if (worker_id == self->worker_id) { + TVMValue values[2]; + int type_codes[2]; + PackArgs(values, type_codes, static_cast(DiscoAction::kSyncWorker), worker_id); + self->channel->Reply(TVMArgs(values, type_codes, 2)); + } + } + + static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, PackedFunc func, + const TVMArgs& args) { + TVMValue* values = const_cast(args.values); + int* type_codes = const_cast(args.type_codes); + int num_args = args.num_args; + TVMArgsSetter setter(values, type_codes); + for (int i = 0; i < num_args; ++i) { + TVMArgValue val = TVMArgValue(values[i], type_codes[i]); + if (val.IsObjectRef()) { + DRef dref = val; + setter(i, GetReg(self, dref->reg_id)); + } + } + TVMRetValue rv; + func.CallPacked(TVMArgs(values, type_codes, num_args), &rv); + GetReg(self, ret_reg_id) = std::move(rv); + } + + static TVMRetValue& 
GetReg(DiscoWorker* self, int64_t reg_id) { + if (reg_id >= static_cast(self->register_file.size())) { + self->register_file.resize(reg_id + 1); + } + return self->register_file[reg_id]; + } +}; + +void DiscoWorker::MainLoop() { DiscoWorker::Impl::MainLoop(this); } + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/disco/worker.h b/src/runtime/disco/worker.h new file mode 100644 index 0000000000000..f10382b06877f --- /dev/null +++ b/src/runtime/disco/worker.h @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file worker.h + * \brief This file defines a worker in Disco. A worker can be launched in a separate thread or + * process as long as the channel supports bi-directional communication in-between the worker and + * the controler. + */ +#ifndef TVM_RUNTIME_DISCO_WORKER_H_ +#define TVM_RUNTIME_DISCO_WORKER_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief A special communication channel between controler and worker-0, + * assuming they are always collocated in the same process. + */ +class WorkerZeroData { + public: + /*! 
+ * \brief The host-side arrays to passed to worker-0 for special uses, for example, + * copy-to-worker0 and copy-from-worker0 + */ + std::queue host_arrays; + /*! \brief The mutex that guards `host_arrays` */ + std::mutex queue_mutex_; +}; + +/*! + * \brief A worker in Disco. It takes a channel to communication with the controler. + * The worker can be run in a separate thread or process as long as the channel supports + * bi-directional communication in-between. + */ +class DiscoWorker { + public: + /*! + * \brief Construct a worker. + * \param worker_id The id of the worker. + * \param worker_zero_data The data shared between worker-0 and the controler. It's a nullptr if + * the worker is not worker-0. + * \param channel The communication channel between the worker and the controler. + */ + explicit DiscoWorker(int worker_id, int num_workers, WorkerZeroData* worker_zero_data, + DiscoChannel* channel) + : worker_id(worker_id), + num_workers(num_workers), + default_device(Device{DLDeviceType::kDLCPU, 0}), + worker_zero_data(worker_zero_data), + channel(channel), + register_file{} {} + + /*! \brief Main loop of the worker */ + void MainLoop(); + /*! \brief Get the worker instance on the current thread */ + static DiscoWorker* ThreadLocal(); + + /*! \brief The id of the worker.*/ + int worker_id; + /*! \brief Total number of workers */ + int num_workers; + /*! \brief The default device to allocate data if not specified */ + Device default_device; + /*! \brief The name of the underlying collective communication library. */ + String ccl; + /*! + * \brief The data shared between worker-0 and the controler. It's a nullptr if + * the worker is not worker-0. + * \note This data structure is owned by the controler. + */ + WorkerZeroData* worker_zero_data; + /*! + * \brief The communication channel between the worker and the controler. + * \note This data structure is owned by the controler. + */ + DiscoChannel* channel; + /*! 
\brief The registers in the worker */ + std::vector register_file; + + struct Impl; + friend struct DiscoWorker::Impl; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_DISCO_WORKER_H_ diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py new file mode 100644 index 0000000000000..9fff3c9e33bdd --- /dev/null +++ b/tests/python/disco/test_nccl.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring
+"""Tests for NCCL"""
+import tempfile
+
+import numpy as np
+
+import tvm
+from tvm import dlight as dl
+from tvm import relax as rx
+from tvm.runtime import disco as di
+from tvm.runtime.relax_vm import VirtualMachine
+from tvm.script import relax as R
+
+
+def test_init():
+    """Create a two-worker threaded session and initialize NCCL on it."""
+    num_workers = 2
+    # Presumably the CUDA device ids assigned to worker 0 and worker 1 - TODO confirm.
+    devices = [1, 2]
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    sess.init_ccl("nccl", *devices)
+
+
+def test_allreduce():
+    """Check every allreduce kind against the matching numpy operation on two workers."""
+    num_workers = 2
+    devices = [1, 2]
+    sess = di.ThreadedSession(num_workers=num_workers)
+    sess.init_ccl("nccl", *devices)
+    d_array = sess.empty((3, 4), "float32")
+
+    array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+    array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
+
+    # Give each worker a different input so that the reduction result is distinguishable.
+    d_array.debug_get_from_remote(0).copyfrom(array_1)
+    d_array.debug_get_from_remote(1).copyfrom(array_2)
+
+    for op, np_op in [  # pylint: disable=invalid-name
+        ("sum", np.add),
+        ("prod", np.multiply),
+        ("min", np.minimum),
+        ("max", np.maximum),
+        ("avg", lambda a, b: (a + b) * 0.5),
+    ]:
+        result = sess.allreduce(d_array, op=op)
+        sess.sync_worker(0)
+        # Only worker 0's copy of the result is inspected.
+        result = result.debug_get_from_remote(0).numpy()
+        expected = np_op(array_1, array_2)
+        np.testing.assert_equal(result, expected)
+
+
+def test_broadcast_from_zero():
+    """Fill worker 0's array, broadcast it, and check that worker 1 holds the same data."""
+    num_workers = 2
+    devices = [1, 2]
+    array = np.arange(12, dtype="float32").reshape(3, 4)
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    sess.init_ccl("nccl", *devices)
+    d_array = sess.empty((3, 4), "float32")
+    d_array.debug_get_from_remote(0).copyfrom(array)
+
+    sess.broadcast_from_worker0(d_array)
+    result = d_array.debug_get_from_remote(1).numpy()
+    np.testing.assert_equal(result, array)
+
+
+def test_mlp():  # pylint: disable=too-many-locals
+    """Compare a two-way sharded MLP run on two workers with the unsharded MLP on one GPU."""
+    num_workers = 2
+    devices = [1, 2]
+
+    # pylint: disable=invalid-name
+    @tvm.script.ir_module
+    class MLP:  # pylint: disable=too-few-public-methods
+        @R.function
+        def main(
+            x: R.Tensor((128, 128), "float32"),
+            W1: R.Tensor((128, 128), "float32"),
+            W2: R.Tensor((128, 128), "float32"),
+        ) -> R.Tensor((128, 128), "float32"):
+            R.func_attr({"global_symbol": "main"})
+            with R.dataflow():
+                lv0: R.Tensor((128, 128), "float32") = R.matmul(x, W1)
+                lv1: R.Tensor((128, 128), "float32") = R.nn.gelu(lv0)
+                lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
+                R.output(lv2)
+            return lv2
+
+    @tvm.script.ir_module
+    class ShardedMLP:  # pylint: disable=too-few-public-methods
+        @R.function
+        def main(
+            x: R.Tensor((128, 128), "float32"),
+            W1: R.Tensor((128, 64), "float32"),  # shard along axis 1
+            W2: R.Tensor((64, 128), "float32"),  # shard along axis 0
+        ) -> R.Tensor((128, 128), "float32"):
+            R.func_attr({"global_symbol": "main"})
+            with R.dataflow():
+                lv0: R.Tensor((128, 64), "float32") = R.matmul(x, W1)
+                lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
+                lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
+                lv3: R.Tensor((128, 128), "float32") = R.ccl.allreduce(lv2, "sum")
+                R.output(lv3)
+            return lv3
+
+    # pylint: enable=invalid-name
+    target = tvm.target.Target(
+        {
+            "kind": "cuda",
+            "max_shared_memory_per_block": 49152,
+            "max_threads_per_block": 1024,
+            "thread_warp_size": 32,
+            "registers_per_block": 65536,
+            "arch": "sm_80",
+        }
+    )
+
+    # Run the "zero" pipeline and the default dlight GPU schedules under `target`, then build.
+    def relax_build(mod, target):
+        with target:
+            mod = rx.get_pipeline("zero")(mod)  # pylint: disable=no-value-for-parameter
+            mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+                dl.gpu.Matmul(),
+                dl.gpu.GEMV(),
+                dl.gpu.Reduction(),
+                dl.gpu.GeneralReduction(),
+                dl.gpu.Fallback(),
+            )(mod)
+            return rx.build(mod, target="cuda")
+
+    # pylint: disable=invalid-name
+    X = np.random.randn(128, 128).astype("float32")
+    W1 = np.random.randn(128, 128).astype("float32")
+    W2 = np.random.randn(128, 128).astype("float32")
+    # Reference result: the unsharded module executed on cuda(0).
+    Y_expected = VirtualMachine(relax_build(MLP, target), device=tvm.cuda(0))["main"](
+        tvm.nd.array(X, device=tvm.cuda(0)),
+        tvm.nd.array(W1, device=tvm.cuda(0)),
+        tvm.nd.array(W2, device=tvm.cuda(0)),
+    ).numpy()
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir + "/test.so"
+        relax_build(ShardedMLP, target).export_library(path)
+
+        sess = di.ThreadedSession(num_workers=num_workers)
+        sess.init_ccl("nccl", *devices)
+        mod = sess.load_vm_module(path)
+
+        d_X = sess.empty((128, 128), "float32")
+        d_W1 = sess.empty((128, 64), "float32")
+        d_W2 = sess.empty((64, 128), "float32")
+
+        # Both workers see the full input; W1 is split by columns and W2 by rows.
+        d_X.debug_get_from_remote(0).copyfrom(X)
+        d_X.debug_get_from_remote(1).copyfrom(X)
+        d_W1.debug_get_from_remote(0).copyfrom(W1[:, :64])
+        d_W1.debug_get_from_remote(1).copyfrom(W1[:, 64:])
+        d_W2.debug_get_from_remote(0).copyfrom(W2[:64, :])
+        d_W2.debug_get_from_remote(1).copyfrom(W2[64:, :])
+        d_Y = mod["main"](d_X, d_W1, d_W2)
+        Y_result = tvm.nd.empty((128, 128), "float32", device=tvm.cuda(0))
+        sess.copy_from_worker_0(Y_result, d_Y)
+        # Sync worker 0 before reading the host-side copy.
+        sess.sync_worker(0)
+        Y_result = Y_result.numpy()
+    # pylint: enable=invalid-name
+    np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    test_init()
+    test_broadcast_from_zero()
+    test_allreduce()
+    test_mlp()
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
new file mode 100644
index 0000000000000..ecff37afefc98
--- /dev/null
+++ b/tests/python/disco/test_session.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Basic tests for a Disco session"""
+# pylint: disable=missing-docstring
+import tempfile
+
+import numpy as np
+
+import tvm
+from tvm import relax as rx
+from tvm._ffi import register_func
+from tvm.runtime import disco as di
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def _numpy_to_worker_0(sess: di.Session, np_array: np.ndarray, device):
+    """Allocate a distributed array shaped like `np_array` and copy `np_array` to worker 0.
+
+    Returns the distributed array. No sync is issued here.
+    """
+    # NOTE(review): the dtype is hard-coded to "float32" rather than taken from `np_array`;
+    # every caller in this file passes float32 data.
+    x_array = sess.empty(np_array.shape, "float32", device=device)
+    host_array = tvm.nd.array(np_array, device=device)
+    sess.copy_to_worker_0(host_array, x_array)
+    return x_array
+
+
+def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
+    """Copy `remote_array` from worker 0 into a new CPU array and return it as numpy."""
+    host_array = tvm.nd.empty(shape, dtype, device=tvm.cpu())
+    sess.copy_from_worker_0(host_array, remote_array)
+    # Sync worker 0 before the host array is read.
+    sess.sync_worker(0)
+    return host_array.numpy()
+
+
+def test_int():
+    """Call a registered int function through the session and check the result on each worker."""
+    num_workers = 4
+
+    @register_func("tests.disco.add_one", override=True)
+    def add_one(x: int) -> int:  # pylint: disable=invalid-name
+        return x + 1
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one")
+    result: di.DRef = func(1)
+
+    for i in range(num_workers):
+        assert result.debug_get_from_remote(i) == 2
+
+
+def test_float():
+    """Same as test_int, with a float argument and a float result."""
+    num_workers = 4
+
+    @register_func("tests.disco.add_one_float", override=True)
+    def add_one(x: float):  # pylint: disable=invalid-name
+        return x + 0.5
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one_float")
+    result: di.DRef = func(1.5)
+
+    for i in range(num_workers):
+        assert result.debug_get_from_remote(i) == 2.0
+
+
+def test_ndarray():
+    """Send an array to worker 0, apply a registered function, and read the result back."""
+    num_workers = 4
+
+    @register_func("tests.disco.add_one_ndarray", override=True)
+    def add_one(x: tvm.runtime.NDArray) -> tvm.runtime.NDArray:  # pylint: disable=invalid-name
+        return tvm.nd.array(x.numpy() + 1)
+
+    device = tvm.cpu(0)
+    x_np = np.arange(6).astype("float32").reshape([2, 3])
+    y_np = np.arange(6).astype("float32").reshape([2, 3]) + 1
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    x_disc = _numpy_to_worker_0(sess, x_np, device=device)
+    y_disc = sess.get_global_func("tests.disco.add_one_ndarray")(x_disc)
+    y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype)
+    np.testing.assert_equal(y_nd, y_np)
+
+
+def test_string():
+    """Same as test_int, with a string argument and a string result."""
+    num_workers = 4
+
+    @register_func("tests.disco.str", override=True)
+    def my_str_func(x: str):  # pylint: disable=invalid-name
+        return x + "_suffix"
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    func: di.DPackedFunc = sess.get_global_func("tests.disco.str")
+    result: di.DRef = func("hello")
+
+    for i in range(num_workers):
+        assert result.debug_get_from_remote(i) == "hello_suffix"
+
+
+def test_vm_module():
+    """Build a transpose module, load it into the session, and check worker 0's output."""
+    num_workers = 4
+
+    # pylint: disable=invalid-name
+    @I.ir_module
+    class TestMod:
+        @T.prim_func
+        def transpose(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
+            for i, j in T.grid(16, 8):
+                with T.block("transpose"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vj, vi]
+
+        @R.function
+        def main(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor((16, 8), dtype="float32"):
+            cls = TestMod
+            with R.dataflow():
+                B = R.call_tir(cls.transpose, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32"))
+                R.output(B)
+            return B
+
+    # pylint: enable=invalid-name
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir + "/test.so"
+        device = tvm.cpu()
+        x_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
+        y_np = x_np.transpose()
+
+        rx.build(TestMod, target="llvm").export_library(path)
+        sess = di.ThreadedSession(num_workers=num_workers)
+        mod = sess.load_vm_module(path, device=device)
+
+        x_disc = _numpy_to_worker_0(sess, x_np, device=device)
+        y_disc = mod["main"](x_disc)
+        y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype)
+        np.testing.assert_equal(y_nd, y_np)
+
+
+def test_vm_multi_func():
+    """Chain two functions of one loaded module: transposing twice must give back the input."""
+    num_workers = 4
+
+    # NOTE(review): transpose_1 and transpose_2 below both declare the global symbol "main",
+    # while the test looks them up as "transpose_1" and "transpose_2" - confirm this is intended.
+    # pylint: disable=invalid-name
+    @I.ir_module
+    class TestMod:
+        @T.prim_func
+        def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
+            for i, j in T.grid(16, 8):
+                with T.block("transpose"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vj, vi]
+
+        @T.prim_func
+        def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")):
+            for i, j in T.grid(8, 16):
+                with T.block("transpose"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vj, vi]
+
+        @R.function
+        def transpose_1(
+            A: R.Tensor((8, 16), dtype="float32")
+        ) -> R.Tensor((16, 8), dtype="float32"):
+            R.func_attr({"global_symbol": "main"})
+            cls = TestMod
+            with R.dataflow():
+                B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32"))
+                R.output(B)
+            return B
+
+        @R.function
+        def transpose_2(
+            A: R.Tensor((16, 8), dtype="float32")
+        ) -> R.Tensor((8, 16), dtype="float32"):
+            R.func_attr({"global_symbol": "main"})
+            cls = TestMod
+            with R.dataflow():
+                B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32"))
+                R.output(B)
+            return B
+
+    # pylint: enable=invalid-name
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir + "/test.so"
+        device = tvm.cpu()
+        x_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
+        y_np = x_np.transpose()
+
+        rx.build(TestMod, target="llvm").export_library(path)
+        sess = di.ThreadedSession(num_workers=num_workers)
+        mod = sess.load_vm_module(path, device=device)
+
+        x_disc = _numpy_to_worker_0(sess, x_np, device=device)
+        y_disc = mod["transpose_1"](x_disc)
+        z_disc = mod["transpose_2"](y_disc)
+        y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype)
+        z_nd = _numpy_from_worker_0(sess, z_disc, shape=x_np.shape, dtype=x_np.dtype)
+        np.testing.assert_equal(y_nd, y_np)
+        np.testing.assert_equal(z_nd, x_np)
+
+
+if __name__ == "__main__":
+    test_int()
+    test_float()
+    test_string()
+    test_ndarray()
+    test_vm_module()
+    test_vm_multi_func()
diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py
new file mode 100644
index 0000000000000..fd25b393cbe24
--- /dev/null
+++ b/tests/python/relax/test_op_ccl.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    """The Python wrapper must build a call to the registered `relax.ccl.allreduce` op."""
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    assert relax.op.ccl.allreduce(x).op == Op.get("relax.ccl.allreduce")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    """Normalize `call` with `bb` and assert its struct info structurally equals the expected."""
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_allreduce_infer_struct_info():
+    """allreduce keeps the input's struct info, for known or unknown shape, ndim and dtype."""
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((3, 4)))
+
+    _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(
+        bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.ccl.allreduce(x3), relax.TensorStructInfo((2, 3), dtype=""))
+    _check_inference(bb, relax.op.ccl.allreduce(x4), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.ccl.allreduce(x5), relax.TensorStructInfo((3, 4), dtype=""))
+
+
+def test_allreduce_infer_struct_info_shape_symbolic():
+    """Symbolic shape dimensions are carried through to the result unchanged."""
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo((m, n), "float32"))
+    _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo((4, n), "float32"))
+
+
+def test_allreduce_infer_struct_info_shape_var():
+    """A shape given as a relax shape variable is carried through to the result unchanged."""
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo(s0, "float32"))
+    _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo(s1, "float32"))
+
+
+def test_allreduce_infer_struct_info_more_input_dtype():
+    """The result dtype follows the input dtype for float64, int8 and int64 inputs."""
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo((2, 3), "float64"))
+    _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo((2, 3), "int8"))
+    _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo((2, 3), "int64"))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
new file mode 100644
index 0000000000000..ef535bef533be
--- /dev/null
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
+from tvm.script import relax as R
+
+
+def test_allreduce():
+    """LegalizeOps must lower every `R.ccl.allreduce` to a `runtime.disco.allreduce` call.
+
+    In the expected module the reduction kind is passed as a one-element shape: 0 for "sum",
+    1 for "prod", 2 for "min", 3 for "max" and 4 for "avg".
+    """
+    # fmt: off
+    @tvm.script.ir_module
+    class AllReduce:
+        @R.function
+        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"):
+            gv0: R.Tensor((10, 10), "float32") = R.ccl.allreduce(x, "sum")
+            gv1: R.Tensor((10, 10), "float32") = R.ccl.allreduce(x, "prod")
+            gv2: R.Tensor((10, 10), "float32") = R.ccl.allreduce(x, "min")
+            gv3: R.Tensor((10, 10), "float32") = R.ccl.allreduce(x, "max")
+            gv4: R.Tensor((10, 10), "float32") = R.ccl.allreduce(x, "avg")
+            return x
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            gv0: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.allreduce", x, R.shape([0]), sinfo_args=R.Tensor((10, 10), dtype="float32"))
+            gv1: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.allreduce", x, R.shape([1]), sinfo_args=R.Tensor((10, 10), dtype="float32"))
+            gv2: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.allreduce", x, R.shape([2]), sinfo_args=R.Tensor((10, 10), dtype="float32"))
+            gv3: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.allreduce", x, R.shape([3]), sinfo_args=R.Tensor((10, 10), dtype="float32"))
+            gv4: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.allreduce", x, R.shape([4]), sinfo_args=R.Tensor((10, 10), dtype="float32"))
+            return x
+    # fmt: on
+
+    mod = LegalizeOps()(AllReduce)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()