Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize reflection of F# types, part 2 #9784

Merged
merged 2 commits into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 121 additions & 13 deletions src/fsharp/FSharp.Core/reflect.fs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ module internal Impl =
| null -> None
| prop -> Some(fun (obj: obj) -> prop.GetValue (obj, instancePropertyFlags ||| bindingFlags, null, null, null))

//-----------------------------------------------------------------
// EXPRESSION TREE COMPILATION

let compilePropGetterFunc (prop: PropertyInfo) =
let param = Expression.Parameter (typeof<obj>, "param")

Expand All @@ -77,6 +80,84 @@ module internal Impl =
param)
expr.Compile ()

let compileRecordOrUnionCaseReaderFunc (typ, props: PropertyInfo[]) =
let param = Expression.Parameter (typeof<obj>, "param")
let typedParam = Expression.Variable typ

let expr =
Expression.Lambda<Func<obj, obj[]>> (
Expression.Block (
[ typedParam ],
Expression.Assign (typedParam, Expression.Convert (param, typ)),
Expression.NewArrayInit (typeof<obj>, [
for prop in props ->
Expression.Convert (Expression.Property (typedParam, prop), typeof<obj>) :> Expression
])
),
param)
expr.Compile ()

let compileRecordConstructorFunc (ctorInfo: ConstructorInfo) =
let ctorParams = ctorInfo.GetParameters ()
let paramArray = Expression.Parameter (typeof<obj[]>, "paramArray")

let expr =
Expression.Lambda<Func<obj[], obj>> (
Expression.Convert (
Expression.New (
ctorInfo,
[
for paramIndex in 0 .. ctorParams.Length - 1 do
let p = ctorParams.[paramIndex]

Expression.Convert (
Expression.ArrayAccess (paramArray, Expression.Constant paramIndex),
p.ParameterType
) :> Expression
]
),
typeof<obj>),
paramArray
)
expr.Compile ()

let compileUnionCaseConstructorFunc (methodInfo: MethodInfo) =
let methodParams = methodInfo.GetParameters ()
let paramArray = Expression.Parameter (typeof<obj[]>, "param")

let expr =
Expression.Lambda<Func<obj[], obj>> (
Expression.Convert (
Expression.Call (
methodInfo,
[
for paramIndex in 0 .. methodParams.Length - 1 do
let p = methodParams.[paramIndex]

Expression.Convert (
Expression.ArrayAccess (paramArray, Expression.Constant paramIndex),
p.ParameterType
) :> Expression
]
),
typeof<obj>),
paramArray
)
expr.Compile ()

let compileUnionTagReaderFunc (info: Choice<MethodInfo, PropertyInfo>) =
let param = Expression.Parameter (typeof<obj>, "param")
let tag =
match info with
| Choice1Of2 info -> Expression.Call (info, Expression.Convert (param, info.DeclaringType)) :> Expression
| Choice2Of2 info -> Expression.Property (Expression.Convert (param, info.DeclaringType), info) :> _

let expr =
Expression.Lambda<Func<obj, int>> (
tag,
param)
expr.Compile ()

//-----------------------------------------------------------------
// ATTRIBUTE DECOMPILATION

Expand Down Expand Up @@ -275,6 +356,12 @@ module internal Impl =
let props = fieldsPropsOfUnionCase (typ, tag, bindingFlags)
(fun (obj: obj) -> props |> Array.map (fun prop -> prop.GetValue (obj, bindingFlags, null, null, null)))

let getUnionCaseRecordReaderCompiled (typ: Type, tag: int, bindingFlags) =
let props = fieldsPropsOfUnionCase (typ, tag, bindingFlags)
let caseTyp = getUnionCaseTyp (typ, tag, bindingFlags)
let caseTyp = if isNull caseTyp then typ else caseTyp
compileRecordOrUnionCaseReaderFunc(caseTyp, props).Invoke

let getUnionTagReader (typ: Type, bindingFlags) : (obj -> int) =
if isOptionType typ then
(fun (obj: obj) -> match obj with null -> 0 | _ -> 1)
Expand All @@ -286,9 +373,22 @@ module internal Impl =
match getInstancePropertyReader (typ, "Tag", bindingFlags) with
| Some reader -> (fun (obj: obj) -> reader obj :?> int)
| None ->
(fun (obj: obj) ->
let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null)
m2b.Invoke(null, [|obj|]) :?> int)
let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null)
(fun (obj: obj) -> m2b.Invoke(null, [|obj|]) :?> int)
Comment on lines -289 to +377
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's probably no need to look up the method on every invocation. I didn't know how to test this specifically though. When is it ever the case that a DU doesn't have a Tag property?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A DU always has a tag as far as I know - struct/ref type, single case/multi case all have tag properties.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, then I'm not sure if this branch executes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Paul had PR to do tagless DUs a few (lot of) years ago.


let getUnionTagReaderCompiled (typ: Type, bindingFlags) : (obj -> int) =
if isOptionType typ then
(fun (obj: obj) -> match obj with null -> 0 | _ -> 1)
else
let tagMap = getUnionTypeTagNameMap (typ, bindingFlags)
if tagMap.Length <= 1 then
(fun (_obj: obj) -> 0)
else
match getInstancePropertyInfo (typ, "Tag", bindingFlags) with
| null ->
let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null)
compileUnionTagReaderFunc(Choice1Of2 m2b).Invoke
| info -> compileUnionTagReaderFunc(Choice2Of2 info).Invoke

let getUnionTagMemberInfo (typ: Type, bindingFlags) =
match getInstancePropertyInfo (typ, "Tag", bindingFlags) with
Expand All @@ -314,6 +414,10 @@ module internal Impl =
(fun args ->
meth.Invoke(null, BindingFlags.Static ||| BindingFlags.InvokeMethod ||| bindingFlags, null, args, null))

let getUnionCaseConstructorCompiled (typ: Type, tag: int, bindingFlags) =
let meth = getUnionCaseConstructorMethod (typ, tag, bindingFlags)
compileUnionCaseConstructorFunc(meth).Invoke

let checkUnionType (unionType, bindingFlags) =
checkNonNull "unionType" unionType
if not (isUnionType (unionType, bindingFlags)) then
Expand Down Expand Up @@ -599,9 +703,9 @@ module internal Impl =
let props = fieldPropsOfRecordType(typ, bindingFlags)
(fun (obj: obj) -> props |> Array.map (fun prop -> prop.GetValue (obj, null)))

let getRecordReaderFromFuncs(typ: Type, bindingFlags) =
let props = fieldPropsOfRecordType(typ, bindingFlags) |> Array.map compilePropGetterFunc
(fun (obj: obj) -> props |> Array.map (fun prop -> prop.Invoke obj))
let getRecordReaderCompiled(typ: Type, bindingFlags) =
let props = fieldPropsOfRecordType(typ, bindingFlags)
compileRecordOrUnionCaseReaderFunc(typ, props).Invoke

let getRecordConstructorMethod(typ: Type, bindingFlags) =
let props = fieldPropsOfRecordType(typ, bindingFlags)
Expand All @@ -616,6 +720,10 @@ module internal Impl =
(fun (args: obj[]) ->
ctor.Invoke(BindingFlags.InvokeMethod ||| BindingFlags.Instance ||| bindingFlags, null, args, null))

let getRecordConstructorCompiled(typ: Type, bindingFlags) =
let ctor = getRecordConstructorMethod(typ, bindingFlags)
compileRecordConstructorFunc(ctor).Invoke

/// EXCEPTION DECOMPILATION
// Check the base type - if it is also an F# type then
// for the moment we know it is a Discriminated Union
Expand Down Expand Up @@ -817,19 +925,19 @@ type FSharpValue =
invalidArg "record" (SR.GetString (SR.objIsNotARecord))
getRecordReader (typ, bindingFlags) record

static member PreComputeRecordFieldReader(info: PropertyInfo) =
static member PreComputeRecordFieldReader(info: PropertyInfo): obj -> obj =
checkNonNull "info" info
(fun (obj: obj) -> info.GetValue (obj, null))
compilePropGetterFunc(info).Invoke

static member PreComputeRecordReader(recordType: Type, ?bindingFlags) : (obj -> obj[]) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
checkRecordType ("recordType", recordType, bindingFlags)
getRecordReaderFromFuncs (recordType, bindingFlags)
getRecordReaderCompiled (recordType, bindingFlags)

static member PreComputeRecordConstructor(recordType: Type, ?bindingFlags) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
checkRecordType ("recordType", recordType, bindingFlags)
getRecordConstructor (recordType, bindingFlags)
getRecordConstructorCompiled (recordType, bindingFlags)

static member PreComputeRecordConstructorInfo(recordType: Type, ?bindingFlags) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
Expand Down Expand Up @@ -894,7 +1002,7 @@ type FSharpValue =
static member PreComputeUnionConstructor (unionCase: UnionCaseInfo, ?bindingFlags) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
checkNonNull "unionCase" unionCase
getUnionCaseConstructor (unionCase.DeclaringType, unionCase.Tag, bindingFlags)
getUnionCaseConstructorCompiled (unionCase.DeclaringType, unionCase.Tag, bindingFlags)

static member PreComputeUnionConstructorInfo(unionCase: UnionCaseInfo, ?bindingFlags) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
Expand Down Expand Up @@ -926,7 +1034,7 @@ type FSharpValue =
checkNonNull "unionType" unionType
let unionType = getTypeOfReprType (unionType, bindingFlags)
checkUnionType (unionType, bindingFlags)
getUnionTagReader (unionType, bindingFlags)
getUnionTagReaderCompiled (unionType, bindingFlags)

static member PreComputeUnionTagMemberInfo(unionType: Type, ?bindingFlags) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
Expand All @@ -939,7 +1047,7 @@ type FSharpValue =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
checkNonNull "unionCase" unionCase
let typ = unionCase.DeclaringType
getUnionCaseRecordReader (typ, unionCase.Tag, bindingFlags)
getUnionCaseRecordReaderCompiled (typ, unionCase.Tag, bindingFlags)

static member GetExceptionFields (exn: obj, ?bindingFlags) =
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ type FSharpValueTests() =
let discStructUnionCaseB = DiscStructUnionType.B(1)
let discStructUnionCaseC = DiscStructUnionType.C(1.0, "stringparam")

let optionSome = Some(3)
let optionNone: int option = None

let voptionSome = ValueSome("stringparam")
let voptionNone: string voption = ValueNone

let list1 = [ 1; 2 ]
let list2: int list = []

let fsharpDelegate1 = new FSharpDelegate(fun (x:int) -> "delegate1")
let fsharpDelegate2 = new FSharpDelegate(fun (x:int) -> "delegate2")
Expand Down Expand Up @@ -738,6 +746,24 @@ type FSharpValueTests() =
let (discUnionInfo, discvaluearray) = FSharpValue.GetUnionFields(discUnionRecCaseB, typeof<DiscUnionType<int>>)
let discUnionReader = FSharpValue.PreComputeUnionReader(discUnionInfo)
Assert.AreEqual(discUnionReader(box(discUnionRecCaseB)) , [| box 1; box(Some(discUnionCaseB)) |])

// Option
let (optionCaseInfo, _) = FSharpValue.GetUnionFields(optionSome, typeof<int option>)
let optionReader = FSharpValue.PreComputeUnionReader(optionCaseInfo)
Assert.AreEqual(optionReader(box(optionSome)), [| box 3 |])

let (optionCaseInfo, _) = FSharpValue.GetUnionFields(optionNone, typeof<int option>)
let optionReader = FSharpValue.PreComputeUnionReader(optionCaseInfo)
Assert.AreEqual(optionReader(box(optionNone)), [| |])

// List
let (listCaseInfo, _) = FSharpValue.GetUnionFields(list1, typeof<int list>)
let listReader = FSharpValue.PreComputeUnionReader(listCaseInfo)
Assert.AreEqual(listReader(box(list1)), [| box 1; box [ 2 ] |])

let (listCaseInfo, _) = FSharpValue.GetUnionFields(list2, typeof<int list>)
let listReader = FSharpValue.PreComputeUnionReader(listCaseInfo)
Assert.AreEqual(listReader(box(list2)), [| |])

[<Test>]
member __.PreComputeStructUnionReader() =
Expand All @@ -751,6 +777,15 @@ type FSharpValueTests() =
let (discUnionInfo, discvaluearray) = FSharpValue.GetUnionFields(discStructUnionCaseB, typeof<DiscStructUnionType<int>>)
let discUnionReader = FSharpValue.PreComputeUnionReader(discUnionInfo)
Assert.AreEqual(discUnionReader(box(discStructUnionCaseB)) , [| box 1|])

// Value Option
let (voptionCaseInfo, _) = FSharpValue.GetUnionFields(voptionSome, typeof<string voption>)
let voptionReader = FSharpValue.PreComputeUnionReader(voptionCaseInfo)
Assert.AreEqual(voptionReader(box(voptionSome)), [| box "stringparam" |])

let (voptionCaseInfo, _) = FSharpValue.GetUnionFields(voptionNone, typeof<string voption>)
let voptionReader = FSharpValue.PreComputeUnionReader(voptionCaseInfo)
Assert.AreEqual(voptionReader(box(voptionNone)), [| |])

[<Test>]
member __.PreComputeUnionTagMemberInfo() =
Expand Down Expand Up @@ -790,6 +825,16 @@ type FSharpValueTests() =
// DiscUnion
let discUnionTagReader = FSharpValue.PreComputeUnionTagReader(typeof<DiscUnionType<int>>)
Assert.AreEqual(discUnionTagReader(box(discUnionCaseB)), 1)

// Option
let optionTagReader = FSharpValue.PreComputeUnionTagReader(typeof<int option>)
Assert.AreEqual(optionTagReader(box(optionSome)), 1)
Assert.AreEqual(optionTagReader(box(optionNone)), 0)

// Value Option
let voptionTagReader = FSharpValue.PreComputeUnionTagReader(typeof<string voption>)
Assert.AreEqual(voptionTagReader(box(voptionSome)), 1)
Assert.AreEqual(voptionTagReader(box(voptionNone)), 0)

// null value
CheckThrowsArgumentException(fun () ->FSharpValue.PreComputeUnionTagReader(null)|> ignore)
Expand Down