Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions github/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ func Provider() *schema.Provider {
"github_repository_milestone": resourceGithubRepositoryMilestone(),
"github_repository_project": resourceGithubRepositoryProject(),
"github_repository_pull_request": resourceGithubRepositoryPullRequest(),
"github_repository_pull_request_creation_policy": resourceGithubRepositoryPullRequestCreationPolicy(),
"github_repository_ruleset": resourceGithubRepositoryRuleset(),
"github_repository_topics": resourceGithubRepositoryTopics(),
"github_repository_webhook": resourceGithubRepositoryWebhook(),
Expand Down
120 changes: 120 additions & 0 deletions github/resource_github_repository_pull_request_creation_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package github

import (
"context"
"fmt"

"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
)

func resourceGithubRepositoryPullRequestCreationPolicy() *schema.Resource {
return &schema.Resource{
Description: "Manages the pull request creation policy for a repository.",
CreateContext: resourceGithubRepositoryPullRequestCreationPolicyCreate,
ReadContext: resourceGithubRepositoryPullRequestCreationPolicyRead,
UpdateContext: resourceGithubRepositoryPullRequestCreationPolicyUpdate,
DeleteContext: resourceGithubRepositoryPullRequestCreationPolicyDelete,
Importer: &schema.ResourceImporter{
StateContext: resourceGithubRepositoryPullRequestCreationPolicyImport,
},

Schema: map[string]*schema.Schema{
"repository": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Description: "The name of the GitHub repository.",
},
"policy": {
Type: schema.TypeString,
Required: true,
Description: "Controls who can create pull requests for the repository. Can be `all` or `collaborators_only`.",
ValidateDiagFunc: validation.ToDiagFunc(validation.StringInSlice([]string{"all", "collaborators_only"}, false)),
},
},
}
}

func resourceGithubRepositoryPullRequestCreationPolicyCreate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
repoName := d.Get("repository").(string)
policy := d.Get("policy").(string)

nodeID, err := getRepositoryID(repoName, meta)
if err != nil {
return diag.Errorf("error resolving repository node ID for %s: %s", repoName, err)
}

if err := updateRepositoryPullRequestCreationPolicy(ctx, nodeID, policy, meta); err != nil {
return diag.Errorf("error setting pull request creation policy for %s: %s", repoName, err)
}

d.SetId(repoName)
return resourceGithubRepositoryPullRequestCreationPolicyRead(ctx, d, meta)
}

func resourceGithubRepositoryPullRequestCreationPolicyRead(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
owner := meta.(*Owner).name
repoName := d.Id()

policy, err := getRepositoryPullRequestCreationPolicy(ctx, owner, repoName, meta)
if err != nil {
return diag.Errorf("error reading pull request creation policy for %s/%s: %s", owner, repoName, err)
}

if err := d.Set("policy", policy); err != nil {
return diag.FromErr(err)
}
if err := d.Set("repository", repoName); err != nil {
return diag.FromErr(err)
}

return nil
}

func resourceGithubRepositoryPullRequestCreationPolicyUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
repoName := d.Id()
policy := d.Get("policy").(string)

nodeID, err := getRepositoryID(repoName, meta)
if err != nil {
return diag.Errorf("error resolving repository node ID for %s: %s", repoName, err)
}

if err := updateRepositoryPullRequestCreationPolicy(ctx, nodeID, policy, meta); err != nil {
return diag.Errorf("error updating pull request creation policy for %s: %s", repoName, err)
}

return resourceGithubRepositoryPullRequestCreationPolicyRead(ctx, d, meta)
}

func resourceGithubRepositoryPullRequestCreationPolicyDelete(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
repoName := d.Id()

nodeID, err := getRepositoryID(repoName, meta)
if err != nil {
return diag.Errorf("error resolving repository node ID for %s: %s", repoName, err)
}

if err := updateRepositoryPullRequestCreationPolicy(ctx, nodeID, "all", meta); err != nil {
return diag.Errorf("error resetting pull request creation policy for %s: %s", repoName, err)
}

return nil
}

func resourceGithubRepositoryPullRequestCreationPolicyImport(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) {
repoName := d.Id()

if err := d.Set("repository", repoName); err != nil {
return nil, err
}

diags := resourceGithubRepositoryPullRequestCreationPolicyRead(ctx, d, meta)
if diags.HasError() {
return nil, fmt.Errorf("%s", diags[0].Summary)
}

return []*schema.ResourceData{d}, nil
}
100 changes: 100 additions & 0 deletions github/resource_github_repository_pull_request_creation_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package github

import (
"fmt"
"testing"

"github.com/hashicorp/terraform-plugin-testing/helper/acctest"
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
)

func TestAccGithubRepositoryPullRequestCreationPolicy(t *testing.T) {
t.Run("sets policy without error", func(t *testing.T) {
randomID := acctest.RandStringFromCharSet(5, acctest.CharSetAlphaNum)
repoName := fmt.Sprintf("%srepo-pr-policy-%s", testResourcePrefix, randomID)
initial := `policy = "collaborators_only"`
updated := `policy = "all"`

config := fmt.Sprintf(`
resource "github_repository" "test" {
name = "%s"
visibility = "private"
auto_init = true
}

resource "github_repository_pull_request_creation_policy" "test" {
repository = github_repository.test.name
%%s
}
`, repoName)

checks := map[string]resource.TestCheckFunc{
"before": resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr(
"github_repository_pull_request_creation_policy.test", "policy",
"collaborators_only",
),
),
"after": resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr(
"github_repository_pull_request_creation_policy.test", "policy",
"all",
),
),
}

resource.Test(t, resource.TestCase{
PreCheck: func() { skipUnauthenticated(t) },
ProviderFactories: providerFactories,
Steps: []resource.TestStep{
{
Config: fmt.Sprintf(config, initial),
Check: checks["before"],
},
{
Config: fmt.Sprintf(config, updated),
Check: checks["after"],
},
},
})
})

t.Run("imports without error", func(t *testing.T) {
randomID := acctest.RandStringFromCharSet(5, acctest.CharSetAlphaNum)
repoName := fmt.Sprintf("%srepo-pr-policy-%s", testResourcePrefix, randomID)

config := fmt.Sprintf(`
resource "github_repository" "test" {
name = "%s"
visibility = "private"
auto_init = true
}

resource "github_repository_pull_request_creation_policy" "test" {
repository = github_repository.test.name
policy = "collaborators_only"
}
`, repoName)

check := resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttrSet("github_repository_pull_request_creation_policy.test", "repository"),
resource.TestCheckResourceAttr("github_repository_pull_request_creation_policy.test", "policy", "collaborators_only"),
)

resource.Test(t, resource.TestCase{
PreCheck: func() { skipUnauthenticated(t) },
ProviderFactories: providerFactories,
Steps: []resource.TestStep{
{
Config: config,
Check: check,
},
{
ResourceName: "github_repository_pull_request_creation_policy.test",
ImportState: true,
ImportStateVerify: true,
},
},
})
})
}
87 changes: 87 additions & 0 deletions github/util_v4_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,30 @@ import (
"context"
"encoding/base64"
"errors"
"fmt"

"github.com/shurcooL/githubv4"
)

// PullRequestCreationPolicy mirrors the GitHub GraphQL enum type of the same
// name so we can query and mutate the field even when the vendored client
// model lags behind the live schema.
type PullRequestCreationPolicy string

const (
PullRequestCreationPolicyAll PullRequestCreationPolicy = "ALL"
PullRequestCreationPolicyCollaboratorsOnly PullRequestCreationPolicy = "COLLABORATORS_ONLY"
)

// UpdateRepositoryInput intentionally mirrors the GitHub GraphQL input type
// name so the graphql client emits the correct variable type in mutations.
// We only model the fields needed for pullRequestCreationPolicy updates.
type UpdateRepositoryInput struct {
RepositoryID githubv4.ID `json:"repositoryId"`
PullRequestCreationPolicy *PullRequestCreationPolicy `json:"pullRequestCreationPolicy,omitempty"`
ClientMutationID *githubv4.String `json:"clientMutationId,omitempty"`
}

func getRepositoryID(name string, meta any) (githubv4.ID, error) {
// Interpret `name` as a node ID
exists, nodeIDerr := repositoryNodeIDExists(name, meta)
Expand Down Expand Up @@ -65,6 +85,73 @@ func repositoryNodeIDExists(name string, meta any) (bool, error) {
return query.Node.ID.(string) == name, nil
}

func flattenPullRequestCreationPolicy(policy PullRequestCreationPolicy) (string, error) {
switch policy {
case PullRequestCreationPolicyAll:
return "all", nil
case PullRequestCreationPolicyCollaboratorsOnly:
return "collaborators_only", nil
case "":
return "", nil
default:
return "", fmt.Errorf("unsupported GraphQL pull request creation policy %q", policy)
}
}

func expandPullRequestCreationPolicy(policy string) (PullRequestCreationPolicy, error) {
switch policy {
case "all":
return PullRequestCreationPolicyAll, nil
case "collaborators_only":
return PullRequestCreationPolicyCollaboratorsOnly, nil
default:
return "", fmt.Errorf("unsupported Terraform pull request creation policy %q", policy)
}
}

func getRepositoryPullRequestCreationPolicy(ctx context.Context, owner, name string, meta any) (string, error) {
var query struct {
Repository struct {
PullRequestCreationPolicy PullRequestCreationPolicy
} `graphql:"repository(owner:$owner, name:$name)"`
}

variables := map[string]any{
"owner": githubv4.String(owner),
"name": githubv4.String(name),
}

client := meta.(*Owner).v4client
if err := client.Query(ctx, &query, variables); err != nil {
return "", err
}

return flattenPullRequestCreationPolicy(query.Repository.PullRequestCreationPolicy)
}

func updateRepositoryPullRequestCreationPolicy(ctx context.Context, repositoryID githubv4.ID, policy string, meta any) error {
expandedPolicy, err := expandPullRequestCreationPolicy(policy)
if err != nil {
return err
}

input := UpdateRepositoryInput{
RepositoryID: repositoryID,
PullRequestCreationPolicy: &expandedPolicy,
}

var mutation struct {
UpdateRepository struct {
Repository struct {
ID githubv4.ID
}
} `graphql:"updateRepository(input:$input)"`
}

client := meta.(*Owner).v4client
return client.Mutate(ctx, &mutation, input, nil)
}

// Maintain compatibility with deprecated Global ID format
// https://github.blog/2021-02-10-new-global-id-format-coming-to-graphql/
func repositoryLegacyNodeIDExists(name string, meta any) (bool, error) {
Expand Down
Loading