diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/Engine.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/Engine.scala index 4e77482f85b..2deb125110b 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/Engine.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/Engine.scala @@ -531,13 +531,21 @@ object Engine { val master = conf.get("spark.master", null) if (master.toLowerCase.startsWith("local")) { // Spark local mode + // patternLocalN example: local[4] val patternLocalN = "local\\[(\\d+)\\]".r + // patternLocalNF example: local[4,2] + val patternLocalNF = "local\\[(\\d+),\\s*(\\d+)\\]".r + // patternLocalStar example: local[*] val patternLocalStar = "local\\[\\*\\]".r + // patternLocalStarF example: local[*,4] + val patternLocalStarF = "local\\[\\*,\\s*(\\d+)\\]".r master match { case patternLocalN(n) => Some(1, n.toInt) + case patternLocalNF(n, f) => Some(1, n.toInt) case patternLocalStar(_*) => Some(1, getNumMachineCores) + case patternLocalStarF(_*) => Some(1, getNumMachineCores) case _ => - Log4Error.invalidOperationError(false, s"Can't parser master $master") + Log4Error.invalidOperationError(false, s"Can't parse master $master") Some(1, 0) } } else if (master.toLowerCase.startsWith("spark")) { diff --git a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/utils/EngineSpec.scala b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/utils/EngineSpec.scala index 3e9a5b52051..2137eb7b53d 100644 --- a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/utils/EngineSpec.scala +++ b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/utils/EngineSpec.scala @@ -128,6 +128,17 @@ nExecutor should be(1) } + "sparkExecutorAndCore" should "parse local[*,4]" in { + val conf = Engine.createSparkConf().setAppName("EngineSpecTest").setMaster("local[*,4]") + val (nExecutor, _) = Engine.parseExecutorAndCore(conf).get + nExecutor should be(1) + } + + "sparkExecutorAndCore" should "parse local[4,2]" in { + val conf = Engine.createSparkConf().setAppName("EngineSpecTest").setMaster("local[4,2]") + Engine.parseExecutorAndCore(conf) should be(Some(1, 4)) + } + "readConf" should "be right" in { val conf = Engine.readConf val target = Map(