diff --git a/pkg/tfgen/installation_docs.go b/pkg/tfgen/installation_docs.go index 7216a8593c..d7116d02ee 100644 --- a/pkg/tfgen/installation_docs.go +++ b/pkg/tfgen/installation_docs.go @@ -18,6 +18,7 @@ import ( "github.com/pulumi/pulumi-terraform-bridge/v3/pkg/tfbridge" "github.com/pulumi/pulumi-terraform-bridge/v3/pkg/tfgen/parse" "github.com/pulumi/pulumi-terraform-bridge/v3/pkg/tfgen/parse/section" + "github.com/pulumi/pulumi/sdk/v3/go/common/util/contract" ) func plainDocsParser(docFile *DocFile, g *Generator) ([]byte, error) { @@ -261,22 +262,20 @@ var _ parser.ASTTransformer = sectionSkipper{} func (t sectionSkipper) Transform(node *ast.Document, reader text.Reader, pc parser.Context) { source := reader.Source() - // All headings are first children of the ast.Document node. We will walk over - // them and remove any that match the header content we do not want, along with - // their associated [section.Section]. - for currentChild := node.FirstChild(); currentChild != nil; { - if section, ok := currentChild.(*section.Section); ok { + err := ast.Walk(node, func(n ast.Node, entering bool) (ast.WalkStatus, error) { + if section, ok := n.(*section.Section); ok && !entering { headerText := section.FirstChild().(*ast.Heading).Text(source) if t.shouldSkipHeader(string(headerText)) { - currentChild = section.NextSibling() parent := section.Parent() + if parent == nil { + panic("PARENT IS NIL") + } parent.RemoveChild(parent, section) - continue } } - // Move to next node in base case. - currentChild = currentChild.NextSibling() - } + return ast.WalkContinue, nil + }) + contract.AssertNoErrorf(err, "impossible") } // SkipSectionByHeaderContent removes headers where shouldSkipHeader(header) returns true, diff --git a/pkg/tfgen/installation_docs_test.go b/pkg/tfgen/installation_docs_test.go index 8033ca9422..2ea8e0a90c 100644 --- a/pkg/tfgen/installation_docs_test.go +++ b/pkg/tfgen/installation_docs_test.go @@ -302,8 +302,8 @@ func TestSkipSectionHeaderByContent(t *testing.T) { tc := testCase{ name: "Skips Section With Unwanted Header", headerToSkip: "Debugging Provider Output Using Logs", - input: readfile(t, "test_data/skip-sections-by-header/input.md"), - expected: readfile(t, "test_data/skip-sections-by-header/actual.md"), + input: readTestFile(t, "skip-sections-by-header/input.md"), + expected: readTestFile(t, "skip-sections-by-header/actual.md"), } t.Run(tc.name, func(t *testing.T) { diff --git a/pkg/tfgen/parse/section/section.go b/pkg/tfgen/parse/section/section.go index 92384caeb8..7d85337d75 100644 --- a/pkg/tfgen/parse/section/section.go +++ b/pkg/tfgen/parse/section/section.go @@ -56,27 +56,37 @@ func (s *Section) Dump(source []byte, level int) { func (s *Section) Kind() ast.NodeKind { return Kind } func (s sectionParser) Transform(node *ast.Document, reader text.Reader, pc parser.Context) { - for node := node.FirstChild(); node != nil; node = node.NextSibling() { + s.transform(node, reader, pc, false) +} + +func (s sectionParser) transform(node ast.Node, reader text.Reader, pc parser.Context, skipFirst bool) { + parent := node + node = node.FirstChild() + if skipFirst { + node = node.NextSibling() + } + for node != nil { heading, ok := node.(*ast.Heading) if !ok { + node = node.NextSibling() continue } + node = heading.NextSibling() - parent := heading.Parent() section := &Section{} - node = section - c := heading.NextSibling() parent.ReplaceChild(parent, heading, section) section.AppendChild(section, heading) - for c != nil { - if child, ok := c.(*ast.Heading); ok && child.Level >= heading.Level { + for node != nil { + if child, ok := node.(*ast.Heading); ok && child.Level <= heading.Level { break } - child := c + child := node // We are going to add c to section - c = c.NextSibling() + node = node.NextSibling() section.AppendChild(section, child) } + s.transform(section, reader, pc, true) + } } diff --git a/pkg/tfgen/parse/section/section_test.go b/pkg/tfgen/parse/section/section_test.go index a8e8151776..3d783b5ce5 100644 --- a/pkg/tfgen/parse/section/section_test.go +++ b/pkg/tfgen/parse/section/section_test.go @@ -78,6 +78,31 @@ content (again) content (again) `), }, + { + input: ` + +Hi + +# 1 + +content + +## 2 + +nested content +`, + walk: func(src []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + s, ok := node.(*section.Section) + if !ok || !entering { + return ast.WalkContinue, nil + } + if string(s.FirstChild().(*ast.Heading).Text(src)) == "1" { + s.Parent().RemoveChild(s.Parent(), s) + } + return ast.WalkContinue, nil + }, + expected: autogold.Expect("Hi\n"), + }, { input: `# I am a provider