Skip to content

Commit

Permalink
[Target] Allow spaces in target attributes (#8587)
Browse files Browse the repository at this point in the history
* [Target] Allow for spaces in target attributes.

Some target parameters, such as the device_name on vulkan, have spaces
in them.  This prevented round-trips between string and Target
objects, which can occur in some cases.

* [Vulkan] Fixed "device_name" property querying.

* [Target] Switched from escaped spaces to quoted spaces.

Instead of -attr=value\ with\ spaces, will instead be written as
-attr='value with spaces'.

Co-authored-by: Eric Lunderberg <[email protected]>
  • Loading branch information
Lunderberg and Lunderberg authored Aug 4, 2021
1 parent b9204cd commit d38bef5
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;
}
case kDeviceName:
*rv = prop.device_name;
*rv = String(prop.device_name);
break;

case kMaxClockRate:
Expand Down
106 changes: 81 additions & 25 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <algorithm>
#include <cctype>
#include <cstring>
#include <stack>

#include "../runtime/object_internal.h"
Expand Down Expand Up @@ -147,17 +148,83 @@ static int FindFirstSubstr(const std::string& str, const std::string& substr) {
}

static Optional<String> JoinString(const std::vector<String>& array, char separator) {
char escape = '\\';
char quote = '\'';

if (array.empty()) {
return NullOpt;
}

std::ostringstream os;
os << array[0];
for (size_t i = 1; i < array.size(); ++i) {
os << separator << array[i];

for (size_t i = 0; i < array.size(); ++i) {
if (i > 0) {
os << separator;
}

std::string str = array[i];

if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) {
os << str;
} else {
os << quote;
for (char c : str) {
if (c == separator || c == quote) {
os << escape;
}
os << c;
}
os << quote;
}
}
return String(os.str());
}

static std::vector<std::string> SplitString(const std::string& str, char separator) {
char escape = '\\';
char quote = '\'';

std::vector<std::string> output;

const char* start = str.data();
const char* end = start + str.size();
const char* pos = start;

std::stringstream current_word;

auto finish_word = [&]() {
std::string word = current_word.str();
if (word.size()) {
output.push_back(word);
current_word.str("");
}
};

bool pos_quoted = false;

while (pos < end) {
if ((*pos == separator) && !pos_quoted) {
finish_word();
pos++;
} else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) {
current_word << quote;
pos += 2;
} else if (*pos == quote) {
pos_quoted = !pos_quoted;
pos++;
} else {
current_word << *pos;
pos++;
}
}

ICHECK(!pos_quoted) << "Mismatched quotes '' in string";

finish_word();

return output;
}

static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key,
std::string* value) {
int pos;
Expand Down Expand Up @@ -207,9 +274,9 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi

ObjectRef TargetInternal::ParseType(const std::string& str,
const TargetKindNode::ValueTypeInfo& info) {
std::istringstream is(str);
if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer
std::istringstream is(str);
int v;
if (!(is >> v)) {
std::string lower(str.size(), '\x0');
Expand All @@ -226,19 +293,18 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
}
return Integer(v);
} else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing string
std::string v;
if (!(is >> v)) {
throw Error(": Cannot parse into type \"String\" from string: " + str);
}
return String(v);
// Parsing string, strip leading/trailing spaces
auto start = str.find_first_not_of(' ');
auto end = str.find_last_not_of(' ');
return String(str.substr(start, (end - start + 1)));

} else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing target
return Target(TargetInternal::FromString(str));
} else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
// Parsing array
std::vector<ObjectRef> result;
for (std::string substr; std::getline(is, substr, ',');) {
for (const std::string& substr : SplitString(str, ',')) {
try {
ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
result.push_back(parsed);
Expand Down Expand Up @@ -550,24 +616,14 @@ ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) {
}

ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string";
// Split the string by empty spaces
std::string name;
std::vector<std::string> options;
std::string str;
for (std::istringstream is(target_str); is >> str;) {
if (name.empty()) {
name = str;
} else {
options.push_back(str);
}
}
if (name.empty()) {
throw Error(": Cannot parse empty target string");
}
std::vector<std::string> options = SplitString(std::string(target_str), ' ');
std::string name = options[0];
// Create the target config
std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}};
TargetKind kind = GetTargetKind(name);
for (size_t iter = 0, end = options.size(); iter < end;) {
for (size_t iter = 1, end = options.size(); iter < end;) {
std::string key, value;
try {
// Parse key-value pair
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def test_target_string_parse():
assert tvm.target.arm_cpu().device_name == "arm_cpu"


def test_target_string_with_spaces():
target = tvm.target.Target(
"vulkan -device_name='Name of GPU with spaces' -device_type=discrete"
)
assert target.attrs["device_name"] == "Name of GPU with spaces"
assert target.attrs["device_type"] == "discrete"


def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()]
for tgt in targets:
Expand Down

0 comments on commit d38bef5

Please sign in to comment.