Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Allow spaces in target attributes #8587

Merged
merged 3 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;
}
case kDeviceName:
*rv = prop.device_name;
*rv = String(prop.device_name);
break;

case kMaxClockRate:
Expand Down
106 changes: 81 additions & 25 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <algorithm>
#include <cctype>
#include <cstring>
#include <stack>

#include "../runtime/object_internal.h"
Expand Down Expand Up @@ -146,17 +147,83 @@ static int FindFirstSubstr(const std::string& str, const std::string& substr) {
}

static Optional<String> JoinString(const std::vector<String>& array, char separator) {
char escape = '\\';
char quote = '\'';

if (array.empty()) {
return NullOpt;
}

std::ostringstream os;
os << array[0];
for (size_t i = 1; i < array.size(); ++i) {
os << separator << array[i];

for (size_t i = 0; i < array.size(); ++i) {
if (i > 0) {
os << separator;
}

std::string str = array[i];

if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) {
os << str;
} else {
os << quote;
for (char c : str) {
if (c == separator || c == quote) {
os << escape;
}
os << c;
}
os << quote;
}
}
return String(os.str());
}

static std::vector<std::string> SplitString(const std::string& str, char separator) {
char escape = '\\';
char quote = '\'';

std::vector<std::string> output;

const char* start = str.data();
const char* end = start + str.size();
const char* pos = start;

std::stringstream current_word;

auto finish_word = [&]() {
std::string word = current_word.str();
if (word.size()) {
output.push_back(word);
current_word.str("");
}
};

bool pos_quoted = false;

while (pos < end) {
if ((*pos == separator) && !pos_quoted) {
finish_word();
pos++;
} else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) {
current_word << quote;
pos += 2;
} else if (*pos == quote) {
pos_quoted = !pos_quoted;
pos++;
} else {
current_word << *pos;
pos++;
}
}

ICHECK(!pos_quoted) << "Mismatched quotes '' in string";

finish_word();

return output;
}

static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key,
std::string* value) {
int pos;
Expand Down Expand Up @@ -206,9 +273,9 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi

ObjectRef TargetInternal::ParseType(const std::string& str,
const TargetKindNode::ValueTypeInfo& info) {
std::istringstream is(str);
if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer
std::istringstream is(str);
int v;
if (!(is >> v)) {
std::string lower(str.size(), '\x0');
Expand All @@ -225,19 +292,18 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
}
return Integer(v);
} else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing string
std::string v;
if (!(is >> v)) {
throw Error(": Cannot parse into type \"String\" from string: " + str);
}
return String(v);
// Parsing string, strip leading/trailing spaces
auto start = str.find_first_not_of(' ');
auto end = str.find_last_not_of(' ');
return String(str.substr(start, (end - start + 1)));

} else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing target
return Target(TargetInternal::FromString(str));
} else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
// Parsing array
std::vector<ObjectRef> result;
for (std::string substr; std::getline(is, substr, ',');) {
for (const std::string& substr : SplitString(str, ',')) {
try {
ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
result.push_back(parsed);
Expand Down Expand Up @@ -549,24 +615,14 @@ ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) {
}

ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string";
// Split the string by empty spaces
std::string name;
std::vector<std::string> options;
std::string str;
for (std::istringstream is(target_str); is >> str;) {
if (name.empty()) {
name = str;
} else {
options.push_back(str);
}
}
if (name.empty()) {
throw Error(": Cannot parse empty target string");
}
std::vector<std::string> options = SplitString(std::string(target_str), ' ');
std::string name = options[0];
// Create the target config
std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}};
TargetKind kind = GetTargetKind(name);
for (size_t iter = 0, end = options.size(); iter < end;) {
for (size_t iter = 1, end = options.size(); iter < end;) {
std::string key, value;
try {
// Parse key-value pair
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def test_target_string_parse():
assert tvm.target.arm_cpu().device_name == "arm_cpu"


def test_target_string_with_spaces():
target = tvm.target.Target(
"vulkan -device_name='Name of GPU with spaces' -device_type=discrete"
)
assert target.attrs["device_name"] == "Name of GPU with spaces"
assert target.attrs["device_type"] == "discrete"


def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()]
for tgt in targets:
Expand Down