Skip to content

Commit

Permalink
Merge pull request #1021 from microsoft/failfast
Browse files Browse the repository at this point in the history
FailFast instead of throw from non-HRESULT returning CCW methods
  • Loading branch information
AArnott authored Aug 11, 2023
2 parents 204a18a + 4119ae5 commit 9904847
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static AttributeTargetSpecifierSyntax AttributeTargetSpecifier(SyntaxToken identifier) => SyntaxFactory.AttributeTargetSpecifier(identifier, TokenWithSpace(SyntaxKind.ColonToken));

internal static ThrowStatementSyntax ThrowStatement() => SyntaxFactory.ThrowStatement(default, Token(SyntaxKind.ThrowKeyword), null, Semicolon);

internal static ThrowStatementSyntax ThrowStatement(ExpressionSyntax expression) => SyntaxFactory.ThrowStatement(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression, Semicolon);

internal static ThrowExpressionSyntax ThrowExpression(ExpressionSyntax expression) => SyntaxFactory.ThrowExpression(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression);
Expand Down
31 changes: 23 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -425,20 +425,35 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
//// hr.ThrowOnFailure();
: ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrLocal, HRThrowOnFailureMethodName)));

//// catch (Exception ex) { return (HRESULT)ex.HResult; }
IdentifierNameSyntax exLocal = IdentifierName("ex");
CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, Block().AddStatements(
ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult)))))));
BlockSyntax catchBlock = Block();
if (hrReturnType)
{
//// return (HRESULT)ex.HResult;
catchBlock = catchBlock.AddStatements(ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult))))));
}
else
{
//// Environment.FailFast("COM object threw an exception from a non-HRESULT returning method.", ex);
//// throw;
catchBlock = catchBlock.AddStatements(
ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("global::System.Environment"), IdentifierName(nameof(Environment.FailFast))),
ArgumentList().AddArguments(
Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal("COM object threw an exception from a non-HRESULT returning method."))),
Argument(exLocal)))),
ThrowStatement());
}

//// catch (Exception ex) {
CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, catchBlock);

BlockSyntax tryBlock = Block().AddStatements(
hrDecl,
ifNullReturnStatement).AddStatements(thunkInvokeAndReturn);

BlockSyntax ccwBody = hrReturnType
//// try { ... } catch { ... }
? Block().AddStatements(TryStatement(tryBlock, new SyntaxList<CatchClauseSyntax>(catchClause), null))
//// { .... } // any exception is thrown back to native code.
: tryBlock;
//// try { ... } catch { ... }
BlockSyntax ccwBody = Block().AddStatements(TryStatement(tryBlock, new SyntaxList<CatchClauseSyntax>(catchClause), null));

//// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
//// private static HRESULT Clone(IEnumEventObject* @this, IEnumEventObject** ppInterface)
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ internal WhitespaceRewriter()

public override SyntaxNode? VisitTryStatement(TryStatementSyntax node) => base.VisitTryStatement(this.WithIndentingTrivia(node));

public override SyntaxNode? VisitThrowStatement(ThrowStatementSyntax node) => base.VisitThrowStatement(this.WithIndentingTrivia(node));

public override SyntaxNode? VisitCatchClause(CatchClauseSyntax node) => base.VisitCatchClause(this.WithIndentingTrivia(node));

public override SyntaxNode? VisitFinallyClause(FinallyClauseSyntax node) => base.VisitFinallyClause(this.WithIndentingTrivia(node));
Expand Down
16 changes: 16 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,22 @@ public void InterestingComInterfaces(
this.AssertNoDiagnostics();
}

[Fact]
public void EnvironmentFailFast()
{
this.compilation = this.starterCompilations["net7.0"];
this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false });

// Emit something into the Environment namespace, to invite collisions.
Assert.True(this.generator.TryGenerate("ENCLAVE_IDENTITY", CancellationToken.None));

// Emit the interface that can require Environment.FailFast.
Assert.True(this.generator.TryGenerate("ITypeInfo", CancellationToken.None));

this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}

[Fact]
public void ComOutPtrTypedAsOutObject()
{
Expand Down

0 comments on commit 9904847

Please sign in to comment.