Intgemm refactor (#762)
* clean-up type hierarchy
* towards offline hardware-specific packing
* call conversion explicitly
* simplify conversion
* remove backend modifications for now
* more clean-up
* get rid of Transpose10
* update changelog
* address review comments
* clean-up intgemm_interface.h
* align parameters
* add correct dispatching
* add comments
* minor formatting
emjotde authored Nov 14, 2020
1 parent 0efcdd7 commit a0c57e4
Showing 21 changed files with 371 additions and 778 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]

### Added
+ - Added `intgemm8(ssse3|avx|avx512)?`, `intgemm16(sse2|avx|avx512)?` types to marian-conv; they use the intgemm backend. The types intgemm8 and intgemm16 are hardware-agnostic, the others are hardware-specific.
+ - Shortlist is now always multiple-of-eight.
+ - Added intgemm 8/16-bit integer binary architecture-agnostic format.
- Add --train-embedder-rank for fine-tuning any encoder(-decoder) model for multi-lingual similarity via softmax-margin loss
- Add --logical-epoch that allows redefining the displayed epoch counter as a multiple of n data epochs, updates or labels. Also allows defining the width of the fractional part with a second argument.
- Add --metrics chrf for computing ChrF according to https://www.aclweb.org/anthology/W15-3049/ and SacreBLEU reference implementation
@@ -28,10 +31,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Training and scoring from STDIN
- Support for reading from TSV files from STDIN and other sources during training
and translation with options --tsv and --tsv-fields n.
- - Shortlist is now always multiple-of-eight.
- - Changed the `--optimize` switch to `--int16` and replaced the computational backend with intgemm.
- - Added `--gemm-precision` which specifies the numerical precision used for the GEMM computations. Valid values: `float32` (default), `int16`, `int8` and `int8shift`. Also added aliases for the latter three. All integer-based GEMMs use intgemm as the computational backend.
- - Added intgemm 8/16-bit integer binary architecture-agnostic format.
- Internal optional parameter in n-best list generation that skips empty hypotheses.
- Quantized training (fixed point or log-based quantization) with --quantize-bits N command

@@ -58,6 +57,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Improved handling for receiving SIGTERM during training. By default, SIGTERM triggers 'save (now) and exit'. Prior to this fix, batch pre-fetching did not check for this signal, potentially delaying exit considerably. It now pays attention to that. Also, the default behaviour of save-and-exit can now be disabled on the command line with --sigterm exit-immediately.

### Changed
+ - Removed the `--optimize` switch; instead we now determine the compute type based on the binary model.
- Updated SentencePiece repository to version 8336bbd0c1cfba02a879afe625bf1ddaf7cd93c5 from https://github.com/google/sentencepiece.
- Enabled compilation of SentencePiece by default since it no longer depends on protobuf.
- Changed default value of --sentencepiece-max-lines from 10000000 to 2000000 since apparently the new version doesn't sample automatically anymore (Not quite clear how that affects quality of the vocabulary).
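The changelog entry above distinguishes hardware-agnostic tags (`intgemm8`, `intgemm16`) from hardware-specific ones such as `intgemm8avx2`, which implies a resolution step at load time. The following minimal C++ sketch illustrates the idea; `CpuTier`, `detectCpuTier` and `resolve` are names invented here for illustration, not Marian's actual API.

// Illustrative subset of the type tags named above.
enum class Type {
  intgemm8,        // 8-bit, hardware-agnostic; resolved at load time
  intgemm8ssse3,   // 8-bit, prepacked for SSSE3
  intgemm8avx2,    // 8-bit, prepacked for AVX2
  intgemm8avx512,  // 8-bit, prepacked for AVX512
};

enum class CpuTier { ssse3, avx2, avx512 };

// Hypothetical feature probe; real code would query cpuid.
CpuTier detectCpuTier() { return CpuTier::avx2; }

// A hardware-agnostic tag resolves to whatever the current CPU supports;
// a hardware-specific tag is kept as-is and must match the hardware.
Type resolve(Type t) {
  if(t != Type::intgemm8)
    return t;
  switch(detectCpuTier()) {
    case CpuTier::avx512: return Type::intgemm8avx512;
    case CpuTier::avx2:   return Type::intgemm8avx2;
    default:              return Type::intgemm8ssse3;
  }
}

Under this reading, a model saved with a hardware-specific tag skips the resolution step but only runs on matching hardware, while an agnostic one pays a small packing cost at load time.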
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -92,7 +92,7 @@ if(MSVC)
# Or maybe use these?
set(INTRINSICS "/arch:AVX2")
# set(INTRINSICS "/arch:AVX512")
- # /bigobj is necessary for expression_operators.cpp. See https://stackoverflow.com/questions/15110580/penalty-of-the-msvs-compiler-flag-bigobj
+ # /bigobj is necessary for expression_operators.cpp. See https://stackoverflow.com/questions/15110580/penalty-of-the-msvs-compiler-flag-bigobj
set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS /bigobj ${DISABLE_GLOBALLY}")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} /MT /O2 ${INTRINSICS} /Zi /MP /GL /DNDEBUG")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
22 changes: 2 additions & 20 deletions src/command/marian_conv.cpp
@@ -36,23 +36,8 @@ int main(int argc, char** argv) {
  auto exportAs = options->get<std::string>("export-as");
  auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());

- auto saveGemmTypeStr = options->get<std::string>("gemm-type", "float32");
- Type saveGemmType;
- if(saveGemmTypeStr == "float32") {
-   saveGemmType = Type::float32;
- } else if(saveGemmTypeStr == "packed16") { // packed16 (fbgemm) only supports AVX2. AVX512 might be added later
-   saveGemmType = Type::packed16;
- } else if(saveGemmTypeStr == "packed8avx2") { // packed8 for AVX2 (fbgemm)
-   saveGemmType = Type::packed8avx2;
- } else if(saveGemmTypeStr == "packed8avx512") { // packed8 for AVX512 (fbgemm)
-   saveGemmType = Type::packed8avx512;
- } else if(saveGemmTypeStr == "intgemm8") { // intgemm 8-bit format
-   saveGemmType = Type::intgemm8;
- } else if(saveGemmTypeStr == "intgemm16") { // intgemm 16-bit format
-   saveGemmType = Type::intgemm16;
- } else {
-   ABORT("Unknown gemm-type: {}", saveGemmTypeStr);
- }
+ // We accept any type here and will later croak during packAndSave if the type cannot be used for conversion
+ Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));

  LOG(info, "Outputting {}, precision: {}", modelTo, saveGemmType);

@@ -63,9 +48,6 @@ int main(int argc, char** argv) {

  auto load = [&](Ptr<ExpressionGraph> graph) {
    graph->setDevice(CPU0);
-   graph->getBackend()->setInt8(false);  // Since we run graph->forward() we need to make sure it does not get converted to an intgemm format during it.
-   graph->getBackend()->setInt16(false); // We manually do the compression later.
-
    graph->load(modelFrom);
    graph->forward(); // run the initializers
  };
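The deleted ladder above mapped type names to `Type` values by hand; the refactor delegates this to `typeFromString` and defers validation to `packAndSave`. A self-contained sketch of what such a lookup does (simplified; the real `typeFromString` lives in Marian's type utilities and covers all supported types):

#include <stdexcept>
#include <string>
#include <unordered_map>

enum class Type { float32, packed16, packed8avx2, packed8avx512, intgemm8, intgemm16 };

// Simplified stand-in for typeFromString; shown only to illustrate why the
// hand-written if/else ladder became redundant.
Type typeFromStringSketch(const std::string& str) {
  static const std::unordered_map<std::string, Type> map = {
    {"float32", Type::float32},
    {"packed16", Type::packed16},
    {"packed8avx2", Type::packed8avx2},
    {"packed8avx512", Type::packed8avx512},
    {"intgemm8", Type::intgemm8},
    {"intgemm16", Type::intgemm16},
  };
  auto it = map.find(str);
  if(it == map.end())
    throw std::runtime_error("Unknown type: " + str);
  return it->second;
}

With this in place, marian-conv can accept any known type string and fail late, during conversion, instead of maintaining its own allow-list.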
12 changes: 0 additions & 12 deletions src/common/aliases.cpp
@@ -225,18 +225,6 @@ void ConfigParser::addAliases(cli::CLIWrapper& cli) {
config["valid-mini-batch"] = 8;
config["normalize"] = 1.0;
});
} else { // Only available during translation/scoring or server modes
cli.alias("int16", "true", [&](YAML::Node& config) {
config["gemm-precision"] = std::string("int16");
});

cli.alias("int8", "true", [&](YAML::Node& config) {
config["gemm-precision"] = std::string("int8");
});

cli.alias("int8shift", "true", [&](YAML::Node& config) {
config["gemm-precision"] = std::string("int8shift");
});
}
}

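For context, an alias in the deleted code is a trigger option plus a callback that rewrites the YAML config when the trigger fires. A minimal sketch of that pattern, assuming yaml-cpp and omitting Marian's CLIWrapper machinery:

#include <functional>
#include <iostream>
#include <yaml-cpp/yaml.h>

int main() {
  using AliasFn = std::function<void(YAML::Node&)>;

  // What registering cli.alias("int8", "true", ...) amounted to.
  AliasFn int8Alias = [](YAML::Node& config) {
    config["gemm-precision"] = std::string("int8");
  };

  YAML::Node config;
  int8Alias(config);            // what passing --int8 used to trigger
  std::cout << config << "\n";  // prints: gemm-precision: int8
}

Since the compute type is now read from the binary model, these translation-time aliases had nothing left to rewrite and were removed.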
7 changes: 4 additions & 3 deletions src/common/binary.cpp
@@ -59,7 +59,8 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {

  for(int i = 0; i < numHeaders; ++i) {
    if(items[i].mapped) { // memory-mapped, hence only set pointer
-     ABORT_IF(isIntgemm(items[i].type), "mmap format not supported for intgemm matrices");
+     // @TODO: verify this actually works for the hardware-specific ones like intgemm8avx2
+     ABORT_IF(items[i].type == Type::intgemm8 || items[i].type == Type::intgemm16, "mmap format not supported for hardware-agnostic intgemm matrices");
      items[i].ptr = get<char>(current, headers[i].dataLength);
    } else { // reading into item data
      size_t len = headers[i].dataLength;
@@ -69,9 +70,9 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
      // Reordering depends on the architecture (SSE/AVX2/AVX512) so we read in the quantized matrices and
      // then reorder them before adding them as a parameter in the graph.
      if (matchType<intgemm8>(items[i].type)) {
-       cpu::integer::prepareAndTransposeB<Type::int8>(items[i], ptr);
+       cpu::integer::prepareAndTransposeB<Type::intgemm8>(items[i], ptr);
      } else if (matchType<intgemm16>(items[i].type)) {
-       cpu::integer::prepareAndTransposeB<Type::int16>(items[i], ptr);
+       cpu::integer::prepareAndTransposeB<Type::intgemm16>(items[i], ptr);
      } else {
        std::copy(ptr, ptr + len, items[i].bytes.begin());
      }
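The mmap restriction above follows from what `prepareAndTransposeB` must do: hardware-agnostic intgemm payloads are reordered for the current CPU while being read in, which requires writing to the destination buffer, and a read-only memory mapping cannot be written. A self-contained sketch of this load-time branching, with invented stand-ins for Marian's `io::Item` and `prepareAndTransposeB`:

#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

enum class Type { float32, intgemm8, intgemm16 };

struct Item {
  Type type;
  bool mapped;              // true if backed by a read-only memory mapping
  const char* ptr = nullptr;
  std::vector<char> bytes;  // owned storage for the non-mmap path
};

// Stand-in for cpu::integer::prepareAndTransposeB: the payload is reordered
// into the layout the current CPU's kernels expect. Real code reorders; this
// sketch just copies to show where the write happens.
void prepareSketch(Item& item, const char* src, std::size_t len) {
  item.bytes.resize(len);
  std::memcpy(item.bytes.data(), src, len);
}

void loadItemSketch(Item& item, const char* src, std::size_t len) {
  bool agnostic = item.type == Type::intgemm8 || item.type == Type::intgemm16;
  if(item.mapped) {
    // Cannot reorder in place: the mapping is read-only.
    assert(!agnostic && "mmap not supported for hardware-agnostic intgemm");
    item.ptr = src;
  } else if(agnostic) {
    prepareSketch(item, src, len);  // reorder while reading in
  } else {
    item.bytes.assign(src, src + len);
  }
}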
17 changes: 0 additions & 17 deletions src/common/config_parser.cpp
@@ -670,7 +670,6 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
  addSuboptionsTSV(cli);
  addSuboptionsDevices(cli);
  addSuboptionsBatching(cli);
- addSuboptionsIntgemm(cli);

  cli.add<bool>("--fp16",
      "Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
@@ -735,10 +734,7 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
  addSuboptionsTSV(cli);
  addSuboptionsDevices(cli);
  addSuboptionsBatching(cli);
- addSuboptionsIntgemm(cli);

- cli.add<bool>("--optimize",
-     "Optimize speed aggressively sacrificing memory or precision");
  cli.add<bool>("--fp16",
      "Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
  cli.add<std::vector<std::string>>("--precision",
@@ -775,7 +771,6 @@ void ConfigParser::addOptionsEmbedding(cli::CLIWrapper& cli) {
  addSuboptionsTSV(cli);
  addSuboptionsDevices(cli);
  addSuboptionsBatching(cli);
- addSuboptionsIntgemm(cli);

  cli.add<bool>("--fp16",
      "Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
@@ -917,18 +912,6 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
  // clang-format on
}

- void ConfigParser::addSuboptionsIntgemm(cli::CLIWrapper& cli) {
-   // clang-format off
-   cli.add<bool>("--int16",
-       "Optimize speed aggressively sacrificing memory or precision by using 16bit integer GEMM with intgemm instead of floats. Only available on CPU. Corresponds to --gemm-precision int16");
-   cli.add<bool>("--int8",
-       "Optimize speed even more aggressively sacrificing memory or precision by using 8bit integer GEMM with intgemm instead of floats. Only available on CPU. Corresponds to --gemm-precision int8");
-   cli.add<bool>("--int8shift",
-       "Use a faster, shifted integer 8bit GEMM implementation. Corresponds to --gemm-precision int8shift");
-   cli.add<std::string>("--gemm-precision",
-       "Use lower precision for the GEMM operations only. Supported values: float32, int16, int8, int8shift", "float32");
- }

void ConfigParser::addSuboptionsQuantization(cli::CLIWrapper& cli) {
  // clang-format off
  // model quantization training
1 change: 0 additions & 1 deletion src/common/config_parser.h
@@ -138,7 +138,6 @@ class ConfigParser {
  void addSuboptionsInputLength(cli::CLIWrapper&);
  void addSuboptionsTSV(cli::CLIWrapper&);
  void addSuboptionsULR(cli::CLIWrapper&);
- void addSuboptionsIntgemm(cli::CLIWrapper&);
  void addSuboptionsQuantization(cli::CLIWrapper&);

  // Extract paths to all config files found in the config object.
Expand Down
(Diffs for the remaining 14 changed files are not shown.)
