Skip to content

Commit

Permalink
refactor(branch-protection): simplify functions by using generics
Browse files Browse the repository at this point in the history
Signed-off-by: Diogo Teles Sant'Anna <[email protected]>
  • Loading branch information
diogoteles08 committed Dec 8, 2023
1 parent 1dd2bbd commit 31b6e31
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 29 deletions.
47 changes: 25 additions & 22 deletions clients/githubrepo/branches.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ func copyNonAdminSettings(src interface{}, dst *clients.BranchProtectionRule) {

// Evaluate if we have data to infer that the project requires PRs to make changes. If we don't have data, we let
// the struct RequiredPullRequestReviews as nil
if valueOrDefault(v.RequiredApprovingReviewCount, 0) > 0 || valueOrDefault(v.RequiresCodeOwnerReviews, false) {
if valueOrZero(v.RequiredApprovingReviewCount) > 0 || valueOrZero(v.RequiresCodeOwnerReviews) {
if dst.RequiredPullRequestReviews == nil {
dst.RequiredPullRequestReviews = new(clients.PullRequestReviewRule)
}
Expand Down Expand Up @@ -505,29 +505,25 @@ nextRule:
return ret, nil
}

func initializedBoolRef(value bool) *bool {
return &value
}

func applyRepoRules(branchRef *clients.BranchRef, rules []*repoRuleSet) {
for _, r := range rules {
// Init values of base checkbox as if they're unchecked
translated := clients.BranchProtectionRule{
AllowDeletions: initializedBoolRef(true),
AllowForcePushes: initializedBoolRef(true),
RequireLinearHistory: initializedBoolRef(false),
AllowDeletions: asPtr(true),
AllowForcePushes: asPtr(true),
RequireLinearHistory: asPtr(false),
}

translated.EnforceAdmins = initializedBoolRef(len(r.BypassActors.Nodes) == 0)
translated.EnforceAdmins = asPtr(len(r.BypassActors.Nodes) == 0)

for _, rule := range r.Rules.Nodes {
switch rule.Type {
case ruleDeletion:
translated.AllowDeletions = initializedBoolRef(false)
translated.AllowDeletions = asPtr(false)
case ruleForcePush:
translated.AllowForcePushes = initializedBoolRef(false)
translated.AllowForcePushes = asPtr(false)
case ruleLinear:
translated.RequireLinearHistory = initializedBoolRef(true)
translated.RequireLinearHistory = asPtr(true)
case rulePullRequest:
translatePullRequestRepoRule(&translated, rule)
case ruleStatusCheck:
Expand All @@ -553,7 +549,7 @@ func translateRequiredStatusRepoRule(base *clients.BranchProtectionRule, rule *r
if len(statusParams.RequiredStatusChecks) == 0 {
return
}
base.CheckRules.RequiresStatusChecks = initializedBoolRef(true)
base.CheckRules.RequiresStatusChecks = asPtr(true)
base.CheckRules.UpToDateBeforeMerge = statusParams.StrictRequiredStatusChecksPolicy
for _, chk := range statusParams.RequiredStatusChecks {
if chk.Context == nil {
Expand Down Expand Up @@ -581,21 +577,21 @@ func mergeBranchProtectionRules(base, translated *clients.BranchProtectionRule)
// https://github.com/ossf/scorecard/issues/3480
base.EnforceAdmins = translated.EnforceAdmins
}
if base.RequireLastPushApproval == nil || valueOrDefault(translated.RequireLastPushApproval, false) {
if base.RequireLastPushApproval == nil || valueOrZero(translated.RequireLastPushApproval) {
base.RequireLastPushApproval = translated.RequireLastPushApproval
}
if base.RequireLinearHistory == nil || valueOrDefault(translated.RequireLinearHistory, false) {
if base.RequireLinearHistory == nil || valueOrZero(translated.RequireLinearHistory) {
base.RequireLinearHistory = translated.RequireLinearHistory
}
mergeCheckRules(&base.CheckRules, &translated.CheckRules)
mergePullRequestReviews(&base.RequiredPullRequestReviews, translated.RequiredPullRequestReviews)
}

func mergeCheckRules(base, translated *clients.StatusChecksRule) {
if base.UpToDateBeforeMerge == nil || valueOrDefault(translated.UpToDateBeforeMerge, false) {
if base.UpToDateBeforeMerge == nil || valueOrZero(translated.UpToDateBeforeMerge) {
base.UpToDateBeforeMerge = translated.UpToDateBeforeMerge
}
if base.RequiresStatusChecks == nil || valueOrDefault(translated.RequiresStatusChecks, false) {
if base.RequiresStatusChecks == nil || valueOrZero(translated.RequiresStatusChecks) {
base.RequiresStatusChecks = translated.RequiresStatusChecks
}
for _, context := range translated.Contexts {
Expand All @@ -618,20 +614,27 @@ func mergePullRequestReviews(base **clients.PullRequestReviewRule, translated *c
}

if (*base).RequiredApprovingReviewCount == nil ||
valueOrDefault((*base).RequiredApprovingReviewCount, 0) < valueOrDefault(translated.RequiredApprovingReviewCount, 0) {
valueOrZero((*base).RequiredApprovingReviewCount) < valueOrZero(translated.RequiredApprovingReviewCount) {
(*base).RequiredApprovingReviewCount = translated.RequiredApprovingReviewCount
}
if (*base).DismissStaleReviews == nil || valueOrDefault(translated.DismissStaleReviews, false) {
if (*base).DismissStaleReviews == nil || valueOrZero(translated.DismissStaleReviews) {
(*base).DismissStaleReviews = translated.DismissStaleReviews
}
if (*base).RequireCodeOwnerReviews == nil || valueOrDefault(translated.RequireCodeOwnerReviews, false) {
if (*base).RequireCodeOwnerReviews == nil || valueOrZero(translated.RequireCodeOwnerReviews) {
(*base).RequireCodeOwnerReviews = translated.RequireCodeOwnerReviews
}
}

func valueOrDefault[T any](ptr *T, defaultValue T) T {
// returns a pointer to the given value. Useful for constant values.
func asPtr[T any](value T) *T {
return &value
}

// returns the pointer's value if it exists, the type's zero-value otherwise.
func valueOrZero[T any](ptr *T) T {
if ptr == nil {
return defaultValue
var zero T
return zero
}
return *ptr
}
10 changes: 3 additions & 7 deletions clients/githubrepo/branches_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func Test_applyRepoRules(t *testing.T) {
Type: rulePullRequest,
Parameters: repoRulesParameters{
PullRequestParameters: pullRequestRuleParameters{
RequireLastPushApproval: initializedBoolRef(true),
RequireLastPushApproval: asPtr(true),
RequiredApprovingReviewCount: &zeroVal,
},
},
Expand Down Expand Up @@ -411,7 +411,7 @@ func Test_applyRepoRules(t *testing.T) {
StrictRequiredStatusChecksPolicy: &trueVal,
RequiredStatusChecks: []statusCheck{
{
Context: stringPtr("foo"),
Context: asPtr("foo"),
},
},
},
Expand Down Expand Up @@ -448,7 +448,7 @@ func Test_applyRepoRules(t *testing.T) {
StrictRequiredStatusChecksPolicy: &trueVal,
RequiredStatusChecks: []statusCheck{
{
Context: stringPtr("foo"),
Context: asPtr("foo"),
},
},
},
Expand Down Expand Up @@ -618,7 +618,3 @@ func Test_translationFromGithubAPIBranchProtectionData(t *testing.T) {
})
}
}

func stringPtr(s string) *string {
return &s
}

0 comments on commit 31b6e31

Please sign in to comment.