Skip to content

Commit

Permalink
Visualizer Ver1.0 (#113)
Browse files Browse the repository at this point in the history
* gitignore

* init

* [Draft]Complete basic structure, still met with type conversion problem

* Update core.pyx

* Visualizer python interface

* support newly added optypes

* Restore cutlass submodule to commit cc3c29a

* Repair minor issues including replace magic type numbers, useless var in func and remove block_dim

---------

Co-authored-by: Mengdi Wu <[email protected]>
Co-authored-by: Jianan Ji <[email protected]>
  • Loading branch information
3 people authored Oct 26, 2024
1 parent dbbc4f2 commit 187f7a5
Show file tree
Hide file tree
Showing 15 changed files with 1,041 additions and 22 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ dist/

# Cython
/python/mirage/_cython/core.cpp

# Mac OS .DS_Store
.DS_Store

# Visualizer results
*.png
*.dot
1 change: 1 addition & 0 deletions demo/reference_mugraphs/gated_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ def torch_gated_mlp(X, W1, W2):
mean_syn = curr_time / 1000
#print(timings)
print(mean_syn)
graph.visualize("gated_mlp")
3 changes: 2 additions & 1 deletion demo/reference_mugraphs/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import os
import torch
from mirage import visualizer

@torch.compile(backend="cudagraphs")
def torch_lora(X, W, A, B):
Expand Down Expand Up @@ -44,7 +45,7 @@ def optimize_lora(checkpoint):
curr_time = starter.elapsed_time(ender)
mean_syn = curr_time / 1000
print(mean_syn)
graph.visualize("lora")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
1 change: 1 addition & 0 deletions include/mirage/kernel/customized.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class KNCustomizedOp : public mirage::kernel::KNOperator {

public:
mirage::threadblock::Graph bgraph;
void get_bgraph(mirage::threadblock::Graph** bgraph);
};

} // namespace kernel
Expand Down
3 changes: 3 additions & 0 deletions include/mirage/kernel/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class KNOperator {
KNOperator(Graph *graph,
mirage::type::KNOperatorType _type,
std::vector<DTensor> const &inputs);
int get_input_dtensors(DTensor** inputs);
int get_output_dtensors(DTensor** inputs);

virtual ~KNOperator();
virtual bool profile(ProfileResult &result) = 0;
virtual bool fingerprint(void) = 0;
Expand Down
5 changes: 5 additions & 0 deletions include/mirage/threadblock/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class TBOperator {
TBOperator(Graph *graph,
mirage::type::TBOperatorType,
std::vector<STensor> const &inputs);
int get_input_stensors(STensor** inputs);
int get_output_stensors(STensor** inputs);

virtual ~TBOperator();

virtual operator json() const = 0;
Expand All @@ -57,6 +60,7 @@ class TBInputOp : public TBOperator {
~TBInputOp();

operator json() const override;
size_t get_dtensor_guid();

public:
mirage::kernel::DTensor dtensor;
Expand All @@ -74,6 +78,7 @@ class TBOutputOp : public TBOperator {
~TBOutputOp();

operator json() const override;
size_t get_dtensor_guid();

public:
mirage::kernel::DTensor dtensor;
Expand Down
108 changes: 100 additions & 8 deletions python/mirage/_cython/CCore.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,70 @@ cdef extern from "mirage/type.h" namespace "mirage::type":
TB_EPILOGUE_ALLREDUCE = 3101,
TB_EPILOGUE_ALLTOALL = 3102,
TB_EPILOGUE_INVALID = 3199,
cdef enum KNOperatorType:
KN_UNKOWN = 1000,
KN_INPUT_OP = 1001,
KN_OUTPUT_OP = 1002,
KN_MATMUL_OP = 1003,
# ElementUnary
KN_EXP_OP = 1100,
KN_SQUARE_OP = 1101,
KN_SQRT_OP = 1102,
KN_SILU_OP = 1103,
# ElementBinary
KN_ADD_OP = 1200,
KN_MUL_OP = 1201,
KN_DIV_OP = 1202,
# Reduction & Normalization
KN_REDUCTION_0_OP = 1300,
KN_REDUCTION_1_OP = 1301,
KN_REDUCTION_2_OP = 1302,
KN_RMS_NORM_OP = 1350,
# Communication
KN_ALLREDUCE_OP = 1400,
KN_CUSTOMIZED_OP = 1999,
cdef enum TBOperatorType:
TB_UNKOWN = 2000,
TB_INPUT_OP = 2001,
TB_OUTPUT_OP = 2002,
TB_MATMUL_OP = 2003,
# ElementUnary
TB_EXP_OP = 2100,
TB_SQUARE_OP = 2101,
TB_SQRT_OP = 2102,
TB_SILU_OP = 2103,
TB_MUL_SCALAR_OP = 2104,
# ElementBinary
TB_ADD_OP = 2200,
TB_MUL_OP = 2201,
TB_DIV_OP = 2202,
# Reduction and Normalization
TB_REDUCTION_FIRST_OP_ID = 2300,
TB_REDUCTION_0_OP = 2301,
TB_REDUCTION_1_OP = 2302,
TB_REDUCTION_2_OP = 2303,
TB_REDUCTION_0_TO_DIMX_OP = 2304,
TB_REDUCTION_1_TO_DIMX_OP = 2305,
TB_REDUCTION_2_TO_DIMX_OP = 2306,
TB_REDUCTION_LAST_OP_ID = 2349,
TB_RMS_NORM_OP = 2350,
# Concat
TB_CONCAT_FIRST_OP_ID = 2400,
TB_CONCAT_0_OP = 2400,
TB_CONCAT_1_OP = 2401,
TB_CONCAT_2_OP = 2402,
TB_CONCAT_LAST_OP_ID = 2410,
TB_CONCAT_THEN_MATMUL_OP = 2411,
# Forloop Accum
# LD indicates last dimension
TB_FORLOOP_ACCUM_FIRST_OP = 2500,
TB_FORLOOP_ACCUM_NO_RED_OP = 2500,
TB_FORLOOP_ACCUM_RED_LD_SUM_OP = 2501,
TB_FORLOOP_ACCUM_RED_LD_MEAN_OP = 2502,
TB_FORLOOP_ACCUM_RED_LD_RMS_OP = 2503,
TB_FORLOOP_ACCUM_REDTOX_LD_SUM_OP = 2504,
TB_FORLOOP_ACCUM_LAST_OP = 2599,
TB_CUSTOMIZED_OP = 2999

cdef extern from "mirage/layout.h" namespace "mirage::layout":
# This must be consistent with mirage/layout.h
Expand All @@ -63,10 +121,10 @@ cdef extern from "mirage/layout.h" namespace "mirage::layout":
SmemColumnMajor = 201,
SmemUnknownLayout = 299

cdef extern from "mirage/kernel/graph.h" namespace "mirage::kernel":
cdef cppclass KNOperator:
pass
ctypedef struct CppDTensor "mirage::kernel::DTensor":
cdef cppclass CppTBGraph "mirage::threadblock::Graph"

cdef extern from "mirage/kernel/device_tensor.h" namespace "mirage::kernel":
cdef struct CppDTensor "mirage::kernel::DTensor":
DataType data_type
DmemLayout layout
int num_dims
Expand All @@ -75,7 +133,19 @@ cdef extern from "mirage/kernel/graph.h" namespace "mirage::kernel":
#KNOperator *owner_op
#void *data_ptr
int owner_ts_idx
pass

cdef extern from "mirage/kernel/graph.h" namespace "mirage::kernel":

cdef cppclass CppKNOperator "mirage::kernel::KNOperator":
KNOperatorType op_type
vector[CppDTensor] input_tensors
vector[CppDTensor] output_tensors
int get_input_dtensors(CppDTensor** cinputs)
int get_output_dtensors(CppDTensor** cinputs)

cdef cppclass CppKNCustomizedOp "mirage::kernel::KNCustomizedOp"(CppKNOperator):
CppTBGraph bgraph
void get_bgraph(CppTBGraph** bgraph)

cdef cppclass CppKNGraph "mirage::kernel::Graph":
CppKNGraph()
Expand All @@ -99,16 +169,33 @@ cdef extern from "mirage/kernel/graph.h" namespace "mirage::kernel":
int get_input_dtensor_layout(const CppDTensor *input, int *strides)
void generate_triton_program(const char *filepath)
void generate_cuda_program(const char *filepath)
vector[CppKNOperator*] operators

cdef extern from "mirage/threadblock/graph.h" namespace "mirage::threadblock":
cdef cppclass TBOperator:
pass
ctypedef struct CppSTensor "mirage::threadblock::STensor":
DataType data_type
SmemLayout layout
int num_dims
int dim[4]
int owner_ts_id
int owner_ts_idx
size_t guid

cdef cppclass CppTBOperator "mirage::threadblock::TBOperator":
TBOperatorType op_type
vector[CppSTensor] input_tensors
vector[CppSTensor] output_tensors
int get_input_stensors(CppSTensor** cinputs)
int get_output_stensors(CppSTensor** cinputs)

cdef cppclass CppTBInputOp "mirage::threadblock::TBInputOp"(CppTBOperator):
int forloop_dim
int3 input_map
size_t get_dtensor_guid()

cdef cppclass CppTBOutputOp "mirage::threadblock::TBOutputOp"(CppTBOperator):
int forloop_dim
int3 output_map
size_t get_dtensor_guid()

cdef cppclass CppTBGraph "mirage::threadblock::Graph":
CppTBGraph(dim3 grid_dim,
Expand Down Expand Up @@ -142,6 +229,11 @@ cdef extern from "mirage/threadblock/graph.h" namespace "mirage::threadblock":
int dim)
CppSTensor* forloop_accum(const CppSTensor *A,
TBOperatorType optype)
dim3 grid_dim
dim3 block_dim
int forloop_range
int reduction_dimx
vector[CppTBOperator*] operators

cdef extern from "mirage/search/search_c.h" namespace "mirage::search_c":
ctypedef struct MInt3:
Expand Down
Loading

0 comments on commit 187f7a5

Please sign in to comment.