Skip to content

Commit

Permalink
[Arith] Add IntegerSetNode to represent Presburger Set
Browse files Browse the repository at this point in the history
  • Loading branch information
Min Chen committed Apr 22, 2023
1 parent f4b53fb commit 745ecc2
Show file tree
Hide file tree
Showing 10 changed files with 625 additions and 102 deletions.
4 changes: 4 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ set(USE_MICRO_STANDALONE_RUNTIME OFF)
# - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is available.
set(USE_LLVM OFF)

# Whether use MLIR to help analyze, requires USE_LLVM is enabled
# Possible values: ON/OFF
set(USE_MLIR OFF)

#---------------------------------------------
# Contrib libraries
#---------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions cmake/modules/LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN})
message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
# Set flags that are only needed for LLVM target
add_definitions(-DTVM_LLVM_VERSION=${TVM_LLVM_VERSION})
if (${TVM_MLIR_VERSION})
add_definitions(-DTVM_MLIR_VERSION=${TVM_MLIR_VERSION})
endif()
tvm_file_glob(GLOB COMPILER_LLVM_SRCS src/target/llvm/*.cc)
list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS})
list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS})
Expand Down
10 changes: 10 additions & 0 deletions cmake/utils/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ macro(find_llvm use_llvm)
string(REPLACE "$" ${__llvm_prefix} __lib_with_prefix "${__flag}")
list(APPEND LLVM_LIBS "${__lib_with_prefix}")
endforeach()
if (${USE_MLIR})
if (EXISTS "${__llvm_libdir}/libMLIRPresburger.a")
if (EXISTS "${__llvm_libdir}/libMLIRSupport.a")
message(STATUS "Found MLIR")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
endif()
endif()
endif()
separate_arguments(__llvm_system_libs)
foreach(__flag IN ITEMS ${__llvm_system_libs})
# If the library file ends in .lib try to
Expand Down
101 changes: 101 additions & 0 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,107 @@ class AttrVisitor {
//! \endcond
};

// Attr getter.
class AttrGetter : public AttrVisitor {
public:
const String& skey;
runtime::TVMRetValue* ret;

AttrGetter(const String& skey, runtime::TVMRetValue* ret) : skey(skey), ret(ret) {}

bool found_ref_object{false};

void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, int64_t* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, uint64_t* value) final {
ICHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
<< "cannot return too big constant";
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, int* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, bool* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, void** value) final {
if (skey == key) *ret = static_cast<void*>(value[0]);
}
void Visit(const char* key, DataType* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, std::string* value) final {
if (skey == key) *ret = value[0];
}

void Visit(const char* key, runtime::NDArray* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
void Visit(const char* key, runtime::ObjectRef* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
};

class NodeAttrSetter : public AttrVisitor {
public:
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;

void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
void Visit(const char* key, std::string* value) final {
*value = GetAttr(key).operator std::string();
}
void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator ObjectRef();
}

runtime::TVMArgValue GetAttr(const char* key) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
runtime::TVMArgValue v = it->second;
attrs.erase(it);
return v;
}
};

// List names;
class AttrDir : public AttrVisitor {
public:
std::vector<std::string>* names;

void Visit(const char* key, double* value) final { names->push_back(key); }
void Visit(const char* key, int64_t* value) final { names->push_back(key); }
void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
void Visit(const char* key, bool* value) final { names->push_back(key); }
void Visit(const char* key, int* value) final { names->push_back(key); }
void Visit(const char* key, void** value) final { names->push_back(key); }
void Visit(const char* key, DataType* value) final { names->push_back(key); }
void Visit(const char* key, std::string* value) final { names->push_back(key); }
void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); }
void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); }
};

/*!
* \brief Virtual function table to support IR/AST node reflection.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .int_set import (
IntSet,
IntervalSet,
IntegerSet,
estimate_region_lower_bound,
estimate_region_strict_bound,
estimate_region_upper_bound,
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)


@tvm._ffi.register_object("arith.IntegerSet")
class IntegerSet(IntSet):
"""Represent of Presburger Set
Parameters
----------
constraint : PrimExpr
The constraint expression.
domain_vars : List[PrimExpr]
The domain vars of Presburger Set.
"""

def __init__(self, constraint, domain_vars):
self.__init_handle_by_constructor__(_ffi_api.IntegerSet, constraint, domain_vars)


def estimate_region_lower_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Some subregion may be discarded during the lower-bound analysis.
Expand Down
Loading

0 comments on commit 745ecc2

Please sign in to comment.