Skip to content

Commit

Permalink
Add generation of wrappers for keeping record references (finos#82)
Browse files Browse the repository at this point in the history
* Add generation of wrappers for keeping record references

Now we don't remove parameters representing records from functions.
This fixes problems where a column may lead to ambiguities.

* Code review suggestions
  • Loading branch information
sfc-gh-lfallasavendano committed Jan 8, 2024
1 parent 3d62b91 commit 010a323
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 81 deletions.
1 change: 1 addition & 0 deletions src/Morphir/Scala/AST.elm
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ type Value
| CommentedValue Value String
| ForComp (List Generator) Value
| TypeAscripted Value Type
| New Path Name (List ArgValue)


{-| -}
Expand Down
3 changes: 3 additions & 0 deletions src/Morphir/Scala/PrettyPrinter.elm
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ mapValue opt value =

Apply funValue argValues ->
mapValue opt funValue ++ argValueBlock opt argValues

New path name argValues ->
"new" ++ " " ++ (dotSep <| prefixKeywords (path ++ [ name ])) ++ argValueBlock opt argValues

UnOp op right ->
op ++ mapValue opt right
Expand Down
50 changes: 38 additions & 12 deletions src/Morphir/Snowpark/AccessElementMapping.elm
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Morphir.Snowpark.AccessElementMapping exposing (
{-| This module contains functions to generate code like `a.b` or `a`.
|-}


import Dict exposing (Dict)
import Morphir.IR.Name as Name
import Morphir.IR.Type as IrType
import Morphir.Scala.AST as Scala
Expand All @@ -27,32 +27,58 @@ import Morphir.IR.Value as Value
import Morphir.Snowpark.MappingContext exposing (isAnonymousRecordWithSimpleTypes)
import Morphir.Snowpark.Constants exposing (applySnowparkFunc)
import Morphir.Snowpark.MappingContext exposing (isUnionTypeWithParams)
import Morphir.IR.Value as Value
import String exposing (replace)


{-| Look up the local "columns" object registered for the type of `value`.

The type attribute of `value` is inspected: when it is a type reference whose
name is a key of `ctx.dataFrameColumnsObjects`, the string stored for that key
is returned. Any other kind of type, or an unregistered type name, gives `Nothing`.
-}
checkForDataFrameVariableReference : Value ta (IrType.Type ()) -> ValueMappingContext -> Maybe String
checkForDataFrameVariableReference value ctx =
    let
        lookupByTypeName valueType =
            case valueType of
                IrType.Reference _ typeName _ ->
                    Dict.get typeName ctx.dataFrameColumnsObjects

                _ ->
                    Nothing
    in
    value
        |> Value.valueAttribute
        |> lookupByTypeName

mapFieldAccess : va -> (Value ta (IrType.Type ())) -> Name.Name -> ValueMappingContext -> Scala.Value
mapFieldAccess _ value name ctx =
(let
simpleFieldName = name |> Name.toCamelCase
valueIsFunctionParameter =
valueIsFunctionParameter =
case value of
Value.Variable _ varName -> (varName, List.member varName ctx.parameters)
_ -> (Name.fromString "a",False)
Value.Variable _ varName ->
if List.member varName ctx.parameters then
Just <| Name.toCamelCase varName
else
Nothing
_ ->
Nothing
valueIsDataFrameColumnAccess =
case (value, checkForDataFrameVariableReference value ctx) of
(Value.Variable _ _, Just replacement) ->
Just replacement
_ ->
Nothing
in
case (isValueReferenceToSimpleTypesRecord value ctx.typesContextInfo, valueIsFunctionParameter, value) of
(_, (paramName, True), _) ->
Scala.Ref [paramName |> Name.toCamelCase] simpleFieldName
(Just (path, refererName), (_, False), _) ->
case (isValueReferenceToSimpleTypesRecord value ctx.typesContextInfo, valueIsFunctionParameter, valueIsDataFrameColumnAccess) of
(_,Just replacement, _ ) ->
Scala.Ref [replacement] simpleFieldName
(_,_, Just replacement) ->
Scala.Ref [replacement] simpleFieldName
(Just (path, refererName), Nothing, Nothing) ->
Scala.Ref (path ++ [refererName |> Name.toTitleCase]) simpleFieldName
_ ->
(if isAnonymousRecordWithSimpleTypes (value |> Value.valueAttribute) ctx.typesContextInfo then
applySnowparkFunc "col" [Scala.Literal (Scala.StringLit (Name.toCamelCase name)) ]
else
Scala.Literal (Scala.StringLit "Field access to not converted")))

mapVariableAccess : (IrType.Type a) -> Name.Name -> ValueMappingContext -> Scala.Value
mapVariableAccess _ name ctx =
case getReplacementForIdentifier name ctx of
Just replacement ->
mapVariableAccess : Name.Name -> (Value ta (IrType.Type ())) -> ValueMappingContext -> Scala.Value
mapVariableAccess name nameAccess ctx =
case (getReplacementForIdentifier name ctx, checkForDataFrameVariableReference nameAccess ctx) of
(Just replacement, _) ->
replacement
(_, Just replacementStr) ->
Scala.Variable replacementStr
_ ->
Scala.Variable (name |> Name.toCamelCase)

Expand Down
71 changes: 43 additions & 28 deletions src/Morphir/Snowpark/Backend.elm
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import Morphir.Snowpark.PatternMatchMapping exposing (mapPatternMatch)
import Morphir.Snowpark.Constants as Constants
import Morphir.Scala.Common exposing (scalaKeywords)
import Morphir.Scala.Common exposing (javaObjectMethods)
import Morphir.Snowpark.MappingContext exposing (isCandidateForDataFrame)
import Morphir.IR.FQName as FQName

type alias Options =
{}
Expand Down Expand Up @@ -138,16 +140,13 @@ mapModuleDefinition currentPackagePath currentModulePath accessControlledModuleD
mapFunctionDefinition : Name.Name -> AccessControlled (Documented (Value.Definition () (Type ()))) -> Path -> MappingContextInfo () -> Scala.MemberDecl
mapFunctionDefinition functionName body currentPackagePath mappingCtx =
let
(parameters, localVariables) = processParameters body.value.value.inputTypes mappingCtx
parameters = processParameters body.value.value.inputTypes mappingCtx
parameterNames = body.value.value.inputTypes |> List.map (\(name, _, _) -> name)
valueMappingContext = { emptyValueMappingContext | typesContextInfo = mappingCtx, parameters = parameterNames, packagePath = currentPackagePath}
bodyCandidate = mapFunctionBody body.value.value valueMappingContext

resultingBody = case (localVariables, bodyCandidate) of
([], _) -> bodyCandidate
(declarations, Just bodyToUse) -> Just (Scala.Block declarations bodyToUse)
(_, _) -> Nothing

localDeclarations =
body.value.value.inputTypes
|> List.filterMap (checkForDataFrameColumndsDeclaration mappingCtx)
bodyCandidate = mapFunctionBody body.value.value (includeDataFrameInfo localDeclarations valueMappingContext)
returnTypeToGenerate = mapTypeReference body.value.value.outputType mappingCtx
in
Scala.FunctionDecl
Expand All @@ -158,27 +157,52 @@ mapFunctionDefinition functionName body currentPackagePath mappingCtx =
, returnType =
Just returnTypeToGenerate
, body =
resultingBody
case (localDeclarations |> List.map Tuple.first, bodyCandidate) of
([], Just _) -> bodyCandidate
(declarations, Just bodyToUse) -> Just (Scala.Block declarations bodyToUse)
(_, _) -> Nothing
}

generateLocalVariableForRowRecord : MappingContextInfo () -> ( Name.Name, Type (), Type a ) -> Scala.MemberDecl
generateLocalVariableForRowRecord ctx (name, _, tpe) =
{-| Register generated column-wrapper variables in the value mapping context.

Each entry of `declInfos` carries a `(variableName, typeName)` pair next to its
declaration. The pairs are stored in `dataFrameColumnsObjects`, keyed by type name.

Entries already present in the context win over new ones with the same key
(`Dict.union` prefers its first argument). When several new entries share a
type name, the last one in the list is kept (`Dict.fromList`).
-}
includeDataFrameInfo : List (Scala.MemberDecl, (String, FQName.FQName)) -> ValueMappingContext -> ValueMappingContext
includeDataFrameInfo declInfos ctx =
    let
        toDictEntry ( _, ( varName, typeFullName ) ) =
            ( typeFullName, varName )

        columnsObjectsByType =
            Dict.fromList (List.map toDictEntry declInfos)
    in
    { ctx | dataFrameColumnsObjects = Dict.union ctx.dataFrameColumnsObjects columnsObjectsByType }

{-| Decide whether a function parameter needs a local "columns" wrapper variable.

A parameter qualifies when its type is a reference with exactly one type
argument, and that argument is a reference to a record of simple types (as
reported by `isTypeReferenceToSimpleTypesRecord`). For such a parameter the
result holds:

  - the declaration of a local value named `<parameterName>Columns`, and
  - the pair `(<parameterName>Columns, typeName)` used to register that
    variable in the value mapping context.

Every other parameter yields `Nothing`.

The record check is required because `generateLocalVariableForDataFrameColumns`
falls back to a `Scala.Unit` value when the argument type is not such a record.
Without the check a meaningless variable would be declared and registered as
the columns object for that type.
-}
checkForDataFrameColumndsDeclaration : MappingContextInfo () -> ( Name.Name, va, Type a ) -> Maybe (Scala.MemberDecl, (String, FQName.FQName))
checkForDataFrameColumndsDeclaration ctx (name, _, tpe) =
    let
        varNewName =
            Name.toCamelCase name ++ "Columns"
    in
    case tpe of
        Type.Reference _ _ [ (Type.Reference _ typeName _) as argType ] ->
            -- Only element types that resolve to a simple-types record get a wrapper.
            isTypeReferenceToSimpleTypesRecord argType ctx
                |> Maybe.map
                    (\_ ->
                        ( generateLocalVariableForDataFrameColumns ctx ( varNewName, name, argType )
                        , ( varNewName, typeName )
                        )
                    )

        _ ->
            Nothing

generateLocalVariableForDataFrameColumns : MappingContextInfo () -> ( String, Name.Name, Type a ) -> Scala.MemberDecl
generateLocalVariableForDataFrameColumns ctx (name, originalName, tpe) =
let
nameInfo = isTypeReferenceToSimpleTypesRecord tpe ctx
nameInfo =
isTypeReferenceToSimpleTypesRecord tpe ctx
typeNameInfo = Maybe.map
(\(typePath, simpleTypeName) -> Just (Scala.TypeRef typePath (simpleTypeName |> Name.toTitleCase) ))
nameInfo
objectReference = Maybe.map
(\(typePath, simpleTypeName) -> Scala.Ref typePath (simpleTypeName |> Name.toTitleCase))
(\(typePath, simpleTypeName) ->
Scala.New typePath ((simpleTypeName |> Name.toTitleCase) ++ "Wrapper") [Scala.ArgValue Nothing (Scala.Variable (Name.toCamelCase originalName ))] )
nameInfo
in
Scala.ValueDecl {
modifiers = []
, pattern = (Scala.NamedMatch (name |> Name.toCamelCase))
, pattern = (Scala.NamedMatch name)
, valueType = Maybe.withDefault Nothing typeNameInfo
, value = Maybe.withDefault Scala.Unit objectReference
}

{-| Build the argument group for one function parameter.

The result is a one-element list holding an argument declaration with no
modifiers and no default value. Its type comes from `mapTypeReference` and its
name from `generateParameterName`. The second tuple component is ignored.
-}
generateArgumentDeclarationForFunction : MappingContextInfo () -> ( Name.Name, Type (), Type () ) -> List Scala.ArgDecl
generateArgumentDeclarationForFunction ctx ( name, _, tpe ) =
    let
        scalaType =
            mapTypeReference tpe ctx

        scalaName =
            generateParameterName name
    in
    List.singleton (Scala.ArgDecl [] scalaType scalaName Nothing)
Expand All @@ -194,18 +218,9 @@ generateParameterName name =
scalaName


processParameters : List ( Name.Name, Type (), Type () ) -> MappingContextInfo () -> (List (List Scala.ArgDecl), List (Scala.MemberDecl) )
processParameters : List ( Name.Name, Type (), Type () ) -> MappingContextInfo () -> List (List Scala.ArgDecl)
processParameters inputTypes ctx =
let
recWithSimpleTypesPred : ( Name.Name, Type (), Type a ) -> Bool
recWithSimpleTypesPred tpe =
case tpe of
(_, _, Type.Reference _ fqName []) -> not(isRecordWithSimpleTypes fqName ctx)
_ -> True
(typesToProcess, typesWithRecords) = inputTypes |> List.partition recWithSimpleTypesPred
in
(typesToProcess |> List.map (generateArgumentDeclarationForFunction ctx),
typesWithRecords |> List.map (generateLocalVariableForRowRecord ctx))
inputTypes |> List.map (generateArgumentDeclarationForFunction ctx)


mapFunctionBody : Value.Definition ta (Type ()) -> ValueMappingContext -> Maybe Scala.Value
Expand All @@ -219,8 +234,8 @@ mapValue value ctx =
mapLiteral tpe literal
Field tpe val name ->
mapFieldAccess tpe val name ctx
Variable tpe name ->
mapVariableAccess tpe name ctx
Variable _ name as varAccess ->
mapVariableAccess name varAccess ctx
Constructor tpe name ->
mapConstructorAccess tpe name ctx
List _ values ->
Expand Down
24 changes: 12 additions & 12 deletions src/Morphir/Snowpark/MapFunctionsMapping.elm
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Morphir.Snowpark.ReferenceUtils exposing (scalaPathToModule)
import Morphir.Visual.BoolOperatorTree exposing (functionName)
import Morphir.IR.FQName as FQName
import Morphir.Snowpark.MappingContext exposing (isTypeRefToRecordWithSimpleTypes)
import Morphir.Snowpark.TypeRefMapping exposing (generateRecordTypeWrapperExpression)

type alias MapValueType ta = ValueIR.Value ta (TypeIR.Type ()) -> ValueMappingContext -> Scala.Value

Expand Down Expand Up @@ -116,12 +117,8 @@ tryToConvertUserFunctionCall (func, args) mapValue ctx =
funcReference =
Scala.Ref (scalaPathToModule functionName)
(functionName |> FQName.getLocalName |> Name.toCamelCase)
-- this controversial code removes dataframe arguments
argFilteringCriteria =
(\arg -> not (isTypeRefToRecordWithSimpleTypes (ValueIR.valueAttribute arg) ctx.typesContextInfo))
argsToUse =
args
|> List.filter argFilteringCriteria
|> List.map (\arg -> mapValue arg ctx)
|> List.map (Scala.ArgValue Nothing)
in
Expand All @@ -138,7 +135,7 @@ tryToConvertUserFunctionCall (func, args) mapValue ctx =
argsToUse =
args
|> List.indexedMap (\i arg -> ("field" ++ (String.fromInt i), mapValue arg ctx))
|> List.concatMap (\(field, value) -> [Constants.applySnowparkFunc "lit" [Scala.Literal (Scala.StringLit field), value]])
|> List.concatMap (\(field, value) -> [Constants.applySnowparkFunc "lit" [Scala.Literal (Scala.StringLit field)], value])
tag = [ Constants.applySnowparkFunc "lit" [Scala.Literal (Scala.StringLit "__tag")],
Constants.applySnowparkFunc "lit" [ Scala.Literal (Scala.StringLit <| ( constructorName |> FQName.getLocalName |> Name.toTitleCase))]]
in Constants.applySnowparkFunc "object_construct" (tag ++ argsToUse)
Expand Down Expand Up @@ -203,13 +200,16 @@ generateForListFilter predicate sourceRelation ctx mapValue =
in
if isCandidateForDataFrame (valueAttribute sourceRelation) ctx.typesContextInfo then
case predicate of
ValueIR.Lambda _ _ binExpr ->
generateFilterCall <| mapValue binExpr ctx
ValueIR.Reference _ functionName ->
if isLocalFunctionName functionName ctx then
generateFilterCall <| Scala.Ref (scalaPathToModule functionName) (functionName |> FQName.getLocalName |> Name.toCamelCase)
else
Scala.Literal (Scala.StringLit ("Unsupported filter function scenario2" ))
ValueIR.Lambda _ _ bodyExpr ->
generateFilterCall <| mapValue bodyExpr ctx
ValueIR.Reference (TypeIR.Function _ fromType _) functionName ->
case (isLocalFunctionName functionName ctx, generateRecordTypeWrapperExpression fromType ctx) of
(True, Just typeRefExpr) ->
(generateFilterCall <|
Scala.Apply (Scala.Ref (scalaPathToModule functionName) (functionName |> FQName.getLocalName |> Name.toCamelCase))
[Scala.ArgValue Nothing typeRefExpr])
_ ->
Scala.Literal (Scala.StringLit ("Unsupported filter function scenario2" ))
_ ->
Scala.Literal (Scala.StringLit ("Unsupported filter function scenario" ))
else
Expand Down
15 changes: 13 additions & 2 deletions src/Morphir/Snowpark/MappingContext.elm
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ module Morphir.Snowpark.MappingContext exposing
, isDataFrameFriendlyType
, isLocalFunctionName
, isTypeRefToRecordWithSimpleTypes
, isAliasedBasicType )
, isAliasedBasicType
, getLocalVariableIfDataFrameReference )

{-| This module contains functions to collect information about type definitions in a distribution.
It classifies type definitions in the following kinds:
Expand Down Expand Up @@ -56,13 +57,15 @@ type alias ValueMappingContext =
, typesContextInfo : MappingContextInfo ()
, inlinedIds: Dict Name Scala.Value
, packagePath: Path.Path
, dataFrameColumnsObjects: Dict FQName String
}

{-| Starting point for building a `ValueMappingContext`.

It has no parameters, no inlined identifiers, no registered data frame column
objects, the empty types context, and the package path built from `"default"`.
-}
emptyValueMappingContext : ValueMappingContext
emptyValueMappingContext =
    { parameters = []
    , typesContextInfo = emptyContext
    , inlinedIds = Dict.empty
    , packagePath = Path.fromString "default"
    , dataFrameColumnsObjects = Dict.empty
    }

getReplacementForIdentifier : Name -> ValueMappingContext -> Maybe Scala.Value
Expand Down Expand Up @@ -122,7 +125,7 @@ isTypeAlias name ctx =
case Dict.get name ctx of
Just (TypeClassified (TypeAlias _)) -> True
_ -> False

isCandidateForDataFrame : (Type ()) -> MappingContextInfo () -> Bool
isCandidateForDataFrame typeRef ctx =
case typeRef of
Expand All @@ -137,6 +140,14 @@ isCandidateForDataFrame typeRef ctx =
|> List.all (\{tpe} -> isDataFrameFriendlyType tpe ctx )
_ -> False

{-| Find the local variable registered as the columns object for a type.

When `tpe` is a type reference, its name is looked up in
`ctx.dataFrameColumnsObjects` and the stored variable name, if any, is returned.
Types that are not references always give `Nothing`.
-}
getLocalVariableIfDataFrameReference : Type.Type () -> ValueMappingContext -> Maybe String
getLocalVariableIfDataFrameReference tpe ctx =
    let
        registeredVariableFor typeName =
            Dict.get typeName ctx.dataFrameColumnsObjects
    in
    case tpe of
        Type.Reference _ typeName _ ->
            registeredVariableFor typeName

        _ ->
            Nothing

isAnonymousRecordWithSimpleTypes : Type.Type () -> MappingContextInfo () -> Bool
isAnonymousRecordWithSimpleTypes tpe ctx =
case tpe of
Expand Down
54 changes: 50 additions & 4 deletions src/Morphir/Snowpark/RecordWrapperGenerator.elm
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ processUnionTypeDeclaration name constructors noParams =
processRecordDeclaration : Name -> String -> (List (Field a)) -> Bool -> List (Scala.Documented (Scala.Annotated Scala.TypeDecl))
processRecordDeclaration name doc fields recordWithSimpleTypes =
if recordWithSimpleTypes then
[
traitForRecordWrapper name doc fields,
objectForRecordWrapper name fields
]
[ traitForRecordWrapper name doc fields
, objectForRecordWrapper name fields
, classForRecordWrapper name fields]
else
[]

Expand Down Expand Up @@ -106,6 +105,40 @@ generateTraitMember field =
}))


{-| Generate the `<RecordName>Wrapper` class declaration for a record.

The class has no modifiers, type arguments or body. It takes a single
constructor argument `df`, typed as `typeRefForSnowparkType "DataFrame"`, and
extends the type named after the record in title case. It has one member per
record field, produced by `generateWrapperClassMember`. The declaration carries
no documentation and no annotations.
-}
classForRecordWrapper : Name -> (List (Field a)) -> (Scala.Documented (Scala.Annotated Scala.TypeDecl))
classForRecordWrapper name fields =
    let
        recordTypeName =
            toTitleCase name

        dataFrameCtorArg =
            { modifiers = []
            , tpe = typeRefForSnowparkType "DataFrame"
            , name = "df"
            , defaultValue = Nothing
            }

        wrapperClass =
            Scala.Class
                { modifiers = []
                , name = recordTypeName ++ "Wrapper"
                , typeArgs = []
                , ctorArgs = [ [ dataFrameCtorArg ] ]
                , members = List.map generateWrapperClassMember fields
                , extends = [ Scala.TypeRef [] recordTypeName ]
                , body = []
                }
    in
    Scala.Documented Nothing (Scala.Annotated [] wrapperClass)

objectForRecordWrapper : Name -> (List (Field a)) -> (Scala.Documented (Scala.Annotated Scala.TypeDecl))
objectForRecordWrapper name fields =
let
Expand Down Expand Up @@ -144,6 +177,19 @@ generateObjectMember field =
}))


{-| Generate the wrapper-class member for one record field.

The member is a function with no arguments, named after the field in camel
case. Its return type is `typeRefForSnowparkType "Column"`, and its body applies
the variable `df` to the camel-cased field name as a string literal.
-}
generateWrapperClassMember : (Field a) -> (Scala.Annotated Scala.MemberDecl)
generateWrapperClassMember field =
    let
        columnName =
            Name.toCamelCase field.name

        columnLookup =
            Scala.Apply (Scala.Variable "df")
                [ Scala.ArgValue Nothing (Scala.Literal (Scala.StringLit columnName)) ]
    in
    Scala.Annotated []
        (Scala.FunctionDecl
            { modifiers = []
            , name = columnName
            , typeArgs = []
            , args = []
            , returnType = Just (typeRefForSnowparkType "Column")
            , body = Just columnLookup
            }
        )

generateUnionTypeNameMember : String -> (Scala.Annotated Scala.MemberDecl)
generateUnionTypeNameMember optionName =
Expand Down
Loading

0 comments on commit 010a323

Please sign in to comment.