Skip to content

Commit

Permalink
Added support to map from rootGlobals to application globals
Browse files Browse the repository at this point in the history
  • Loading branch information
djfreels committed Oct 15, 2021
1 parent 3981366 commit eeab70f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 33 deletions.
2 changes: 1 addition & 1 deletion manual_tests/testData/metalus-common/steps.json

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ trait PipelineStepMapper {
def mapParameter(parameter: Parameter, pipelineContext: PipelineContext): Any = {
// Get the value/defaultValue for this parameter
val value = getParamValue(parameter)
val returnValue = if (value.isDefined) {
removeOptions(value) match {
val returnValue = value.map(removeOptions).flatMap {
case s: String =>
parameter.`type`.getOrElse("none").toLowerCase match {
case "script" =>
Expand All @@ -239,13 +238,11 @@ trait PipelineStepMapper {
case b: Boolean => Some(b)
case i: Int => Some(i)
case i: BigInt => Some(i.toInt)
case d: Double => Some(d)
case l: List[_] => handleListParameter(l, parameter, pipelineContext)
case m: Map[_, _] => handleMapParameter(m, parameter, pipelineContext)
case t => // Handle other types - This function may need to be reworked to support this so that it can be overridden
throw new RuntimeException(s"Unsupported value type ${t.getClass} for ${parameter.name.getOrElse("unknown")}!")
}
} else {
None
}

// use the first valid (non-empty) value found
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ object ApplicationUtils {
* @param pipelineListener An optional PipelineListener. This may be overridden by the application.
* @return An execution plan.
*/
//noinspection ScalaStyle
def createExecutionPlan(application: Application, globals: Option[Map[String, Any]], sparkConf: SparkConf,
pipelineListener: PipelineListener = PipelineListener(),
applicationTriggers: ApplicationTriggers = ApplicationTriggers(),
Expand All @@ -77,18 +78,23 @@ object ApplicationUtils {
logger.info(s"setting parquet dictionary enabled to ${applicationTriggers.parquetDictionaryEnabled.toString}")
sparkSession.sparkContext.hadoopConfiguration.set("parquet.enable.dictionary", applicationTriggers.parquetDictionaryEnabled.toString)
implicit val formats: Formats = getJson4sFormats(application.json4sSerializers)
val globalStepMapper = generateStepMapper(application.stepMapper, Some(PipelineStepMapper()),
applicationTriggers.validateArgumentTypes, credentialProvider)
val rootGlobals = globals.getOrElse(Map[String, Any]()) // Create the default globals
val defaultGlobals = generateGlobals(application.globals, rootGlobals, Some(rootGlobals))
val globalListener = generatePipelineListener(application.pipelineListener, Some(pipelineListener),
applicationTriggers.validateArgumentTypes, credentialProvider)
val globalSecurityManager = generateSecurityManager(application.securityManager, Some(PipelineSecurityManager()),
applicationTriggers.validateArgumentTypes, credentialProvider)
val globalStepMapper = generateStepMapper(application.stepMapper, Some(PipelineStepMapper()),
applicationTriggers.validateArgumentTypes, credentialProvider)
val globalPipelineParameters = generatePipelineParameters(application.pipelineParameters, Some(PipelineParameters()))
val pipelineManager = generatePipelineManager(application.pipelineManager,
Some(PipelineManager(application.pipelines.getOrElse(List[DefaultPipeline]()))),
applicationTriggers.validateArgumentTypes, credentialProvider).get
val initialContext = PipelineContext(Some(sparkConf), Some(sparkSession), Some(rootGlobals), globalSecurityManager.get,
globalPipelineParameters.get, application.stepPackages, globalStepMapper.get, globalListener,
Some(sparkSession.sparkContext.collectionAccumulator[PipelineStepMessage]("stepMessages")),
ExecutionAudit("root", AuditType.EXECUTION, Map[String, Any](), System.currentTimeMillis()), pipelineManager,
credentialProvider, Some(formats))
val defaultGlobals = generateGlobals(application.globals, rootGlobals , Some(rootGlobals), initialContext)
generateSparkListeners(application.sparkListeners,
applicationTriggers.validateArgumentTypes, credentialProvider).getOrElse(List()).foreach(sparkSession.sparkContext.addSparkListener)
addSparkListener(globalListener, sparkSession)
Expand All @@ -101,18 +107,17 @@ object ApplicationUtils {
}
generateSparkListeners(execution.sparkListeners,
applicationTriggers.validateArgumentTypes, credentialProvider).getOrElse(List()).foreach(sparkSession.sparkContext.addSparkListener)
val stepMapper = generateStepMapper(execution.stepMapper, globalStepMapper, applicationTriggers.validateArgumentTypes,
credentialProvider).get
// Extracting pipelines
val ctx = PipelineContext(Some(sparkConf),
Some(sparkSession),
generateGlobals(execution.globals, rootGlobals, defaultGlobals, execution.mergeGlobals.getOrElse(false)),
generateSecurityManager(execution.securityManager, globalSecurityManager,
val ctx = initialContext.copy(
globals = generateGlobals(execution.globals, rootGlobals, defaultGlobals, initialContext, execution.mergeGlobals.getOrElse(false)),
security = generateSecurityManager(execution.securityManager, globalSecurityManager,
applicationTriggers.validateArgumentTypes, credentialProvider).get,
generatePipelineParameters(execution.pipelineParameters, globalPipelineParameters).get, application.stepPackages,
generateStepMapper(execution.stepMapper, globalStepMapper, applicationTriggers.validateArgumentTypes,
credentialProvider).get, pipelineListener,
Some(sparkSession.sparkContext.collectionAccumulator[PipelineStepMessage]("stepMessages")),
ExecutionAudit("root", AuditType.EXECUTION, Map[String, Any](), System.currentTimeMillis()),
pipelineManager, credentialProvider, Some(formats))
parameters = generatePipelineParameters(execution.pipelineParameters, globalPipelineParameters).get,
parameterMapper = stepMapper,
pipelineListener = pipelineListener
)
PipelineExecution(execution.id.getOrElse(""),
generatePipelines(execution, application, pipelineManager), execution.initialPipelineId, ctx, execution.parents)
})
Expand All @@ -132,10 +137,12 @@ object ApplicationUtils {
execution: Execution,
pipelineExecution: PipelineExecution): PipelineExecution = {
implicit val formats: Formats = getJson4sFormats(application.json4sSerializers)
val defaultGlobals = generateGlobals(application.globals, rootGlobals.get, rootGlobals)
val initialContext = pipelineExecution.pipelineContext.copy(globals = rootGlobals)
val defaultGlobals = generateGlobals(application.globals, rootGlobals.get, rootGlobals, initialContext)
val globalPipelineParameters = generatePipelineParameters(application.pipelineParameters, Some(PipelineParameters()))
val ctx = pipelineExecution.pipelineContext
.copy(globals = generateGlobals(execution.globals, rootGlobals.get, defaultGlobals, execution.mergeGlobals.getOrElse(false)))
.copy(globals = generateGlobals(execution.globals, rootGlobals.get, defaultGlobals,
initialContext, execution.mergeGlobals.getOrElse(false)))
.copy(parameters = generatePipelineParameters(execution.pipelineParameters, globalPipelineParameters).get)
pipelineExecution.asInstanceOf[DefaultPipelineExecution].copy(pipelineContext = ctx)
}
Expand Down Expand Up @@ -266,18 +273,25 @@ object ApplicationUtils {
private def generateGlobals(globals: Option[Map[String, Any]],
rootGlobals: Map[String, Any],
defaultGlobals: Option[Map[String, Any]],
pipelineContext: PipelineContext,
merge: Boolean = false)(implicit formats: Formats): Option[Map[String, Any]] = {
if (globals.isEmpty) {
defaultGlobals
} else {
val baseGlobals = globals.get
val result = baseGlobals.foldLeft(rootGlobals)((rootMap, entry) => parseValue(rootMap, entry._1, entry._2))
Some(if (merge) {
globals.map { baseGlobals =>
val result = rootGlobals ++ baseGlobals.map{
case (key, m: Map[String, Any]) if m.contains("className") =>
key -> Parameter(Some("object"), Some(key), value = m.get("object"), className = m.get("className").map(_.toString))
case (key, l: List[Any]) => key -> Parameter(Some("list"), Some(key), value = Some(l))
case (key, value) => key -> Parameter(Some("text"), Some(key), value = Some(value))
}.map{
case ("GlobalLinks", p) => "GlobalLinks" -> p.value.get // skip global links
case (key, p) => key -> pipelineContext.parameterMapper.mapParameter(p, pipelineContext)
}
// val result = baseGlobals.foldLeft(rootGlobals)((rootMap, entry) => parseValue(rootMap, entry._1, entry._2))
if (merge) {
defaultGlobals.getOrElse(Map[String, Any]()) ++ result
} else {
result
})
}
}
}.orElse(defaultGlobals)
}

private def parseParameters(classInfo: ClassInfo, credentialProvider: Option[CredentialProvider])(implicit formats: Formats): Map[String, Any] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ class ApplicationTests extends FunSpec with BeforeAndAfterAll with Suite {
assert(globals.contains("rootLogLevel"))
assert(globals.contains("rootLogLevel"))
assert(globals.contains("number"))
assert(globals("number").asInstanceOf[BigInt] == 5)
assert(globals("number").asInstanceOf[Int] == 5)
assert(globals.contains("float"))
assert(globals("float").asInstanceOf[Double] == 1.5)
assert(globals.contains("string"))
Expand Down Expand Up @@ -452,7 +452,7 @@ class ApplicationTests extends FunSpec with BeforeAndAfterAll with Suite {
assert(globals.contains("rootLogLevel"))
assert(globals.contains("rootLogLevel"))
assert(globals.contains("number"))
assert(globals("number").asInstanceOf[BigInt] == 2)
assert(globals("number").asInstanceOf[Int] == 2)
assert(globals.contains("float"))
assert(globals("float").asInstanceOf[Double] == 3.5)
assert(globals.contains("string"))
Expand Down Expand Up @@ -492,7 +492,7 @@ class ApplicationTests extends FunSpec with BeforeAndAfterAll with Suite {
assert(globals1.contains("rootLogLevel"))
assert(globals1.contains("rootLogLevel"))
assert(globals1.contains("number"))
assert(globals1("number").asInstanceOf[BigInt] == 1)
assert(globals1("number").asInstanceOf[Int] == 1)
assert(globals1.contains("float"))
assert(globals1("float").asInstanceOf[Double] == 1.5)
assert(globals1.contains("string"))
Expand Down Expand Up @@ -565,7 +565,7 @@ class ApplicationTests extends FunSpec with BeforeAndAfterAll with Suite {
assert(globals.contains("rootLogLevel"))
assert(globals.contains("rootLogLevel"))
assert(globals.contains("number"))
assert(globals("number").asInstanceOf[BigInt] == 2)
assert(globals("number").asInstanceOf[Int] == 2)
assert(globals.contains("float"))
assert(globals("float").asInstanceOf[Double] == 3.5)
assert(globals.contains("string"))
Expand Down

0 comments on commit eeab70f

Please sign in to comment.