diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index e74b67c2..65b41cc6 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -13,6 +13,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/configureenvvars" "github.com/brevdev/brev-cli/pkg/cmd/connect" "github.com/brevdev/brev-cli/pkg/cmd/copy" + "github.com/brevdev/brev-cli/pkg/cmd/credit" "github.com/brevdev/brev-cli/pkg/cmd/delete" "github.com/brevdev/brev-cli/pkg/cmd/deregister" "github.com/brevdev/brev-cli/pkg/cmd/enablessh" @@ -288,6 +289,7 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(org.NewCmdOrg(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(invite.NewCmdInvite(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(redeem.NewCmdRedeem(t, loginCmdStore, noLoginCmdStore)) + cmd.AddCommand(credit.NewCmdCredit(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(portforward.NewCmdPortForwardSSH(loginCmdStore, t)) cmd.AddCommand(login.NewCmdLogin(t, noLoginCmdStore, loginAuth)) cmd.AddCommand(logout.NewCmdLogout(loginAuth, noLoginCmdStore)) diff --git a/pkg/cmd/credit/credit.go b/pkg/cmd/credit/credit.go new file mode 100644 index 00000000..7b94de30 --- /dev/null +++ b/pkg/cmd/credit/credit.go @@ -0,0 +1,108 @@ +package credit + +import ( + "fmt" + + "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" + "github.com/brevdev/brev-cli/pkg/cmd/completions" + "github.com/brevdev/brev-cli/pkg/cmdcontext" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +type CreditStore interface { + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) + GetBillingProfile(organizationID string) (*store.BillingProfile, error) + completions.CompletionStore +} + +func NewCmdCredit(t *terminal.Terminal, creditStore CreditStore, noCreditStore CreditStore) *cobra.Command { + var orgFlag string + + cmd := &cobra.Command{ + Annotations: map[string]string{"organization": ""}, + Use: "credit", + DisableFlagsInUseLine: true, + Short: "Show active organization's credit balance", + Long: "Print the credit balance for the active organization.", + Example: ` + brev credit + brev credit --org myorg + `, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + err := cmdcontext.InvokeParentPersistentPreRun(cmd, args) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + return nil + }, + Args: cmderrors.TransformToValidationError(cobra.NoArgs), + RunE: func(cmd *cobra.Command, args []string) error { + err := RunCredit(t, creditStore, orgFlag) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization (will override active org)") + err := cmd.RegisterFlagCompletionFunc("org", completions.GetOrgsNameCompletionHandler(noCreditStore, t)) + if err != nil { + breverrors.GetDefaultErrorReporter().ReportError(breverrors.WrapAndTrace(err)) + fmt.Print(breverrors.WrapAndTrace(err)) + } + + return cmd +} + +func RunCredit(t *terminal.Terminal, creditStore CreditStore, orgFlag string) error { + var org *entity.Organization + if orgFlag != "" { + orgs, err := creditStore.GetOrganizations(&store.GetOrganizationsOptions{Name: orgFlag}) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if len(orgs) == 0 { + return fmt.Errorf("no org found with name %s", orgFlag) + } else if len(orgs) > 1 { + return fmt.Errorf("more than one org found with name %s", orgFlag) + } + + org = &orgs[0] + } else { + currOrg, err := creditStore.GetActiveOrganizationOrDefault() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if currOrg == nil { + return fmt.Errorf("no orgs exist") + } + org = currOrg + } + + profile, err := creditStore.GetBillingProfile(org.ID) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if profile == nil || profile.CreditDetails == nil || profile.CreditDetails.RemainingCredits == nil { + return fmt.Errorf("failed to retrieve credit balance") + } + + remainingCents := *profile.CreditDetails.RemainingCredits + dollars := float64(remainingCents) / 100.0 + + t.Vprint(t.Green("✓ Retrieved credit balance\n")) + t.Vprintf(" Organization: %s\n", org.Name) + t.Vprintf(" ID: %s\n", org.ID) + t.Vprintf(" Credits: $%.2f\n", dollars) + + return nil +} diff --git a/pkg/cmd/credit/credit_test.go b/pkg/cmd/credit/credit_test.go new file mode 100644 index 00000000..20897c1d --- /dev/null +++ b/pkg/cmd/credit/credit_test.go @@ -0,0 +1,91 @@ +package credit + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeCreditStore struct { + org *entity.Organization + profile *store.BillingProfile + orgs []entity.Organization +} + +func (f fakeCreditStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return f.org, nil +} + +func (f fakeCreditStore) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) { + return f.orgs, nil +} + +func (f fakeCreditStore) GetCurrentUser() (*entity.User, error) { + return nil, nil +} + +func (f fakeCreditStore) GetAuthTokens() (*entity.AuthTokens, error) { + return nil, nil +} + +func (f fakeCreditStore) GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + return nil, nil +} + +func (f fakeCreditStore) GetBillingProfile(organizationID string) (*store.BillingProfile, error) { + return f.profile, nil +} + +func TestRunCreditUsesActiveOrgBalance(t *testing.T) { + org := &entity.Organization{ID: "org-1", Name: "test-org"} + remaining := int64(500) + profile := &store.BillingProfile{CreditDetails: &store.CreditDetails{RemainingCredits: &remaining}} + store := fakeCreditStore{org: org, profile: profile} + + err := RunCredit(terminal.New(), store, "") + require.NoError(t, err) + assert.Equal(t, org, store.org) + assert.NotNil(t, store.profile) + assert.NotNil(t, store.profile.CreditDetails) + assert.Equal(t, remaining, *store.profile.CreditDetails.RemainingCredits) +} + +func TestRunCreditUsesNamedOrg(t *testing.T) { + org := &entity.Organization{ID: "org-2", Name: "named-org"} + remaining := int64(1234) + profile := &store.BillingProfile{CreditDetails: &store.CreditDetails{RemainingCredits: &remaining}} + store := fakeCreditStore{org: org, profile: profile, orgs: []entity.Organization{*org}} + + err := RunCredit(terminal.New(), store, "named-org") + require.NoError(t, err) + assert.Equal(t, org, store.org) + assert.NotNil(t, store.profile) + assert.Equal(t, remaining, *store.profile.CreditDetails.RemainingCredits) +} + +func TestRunCreditFailsWhenCreditBalanceMissing(t *testing.T) { + org := &entity.Organization{ID: "org-1", Name: "test-org"} + + tests := []struct { + name string + profile *store.BillingProfile + }{ + {name: "billing profile missing", profile: nil}, + {name: "credit details missing", profile: &store.BillingProfile{CreditDetails: nil}}, + {name: "remaining credits missing", profile: &store.BillingProfile{CreditDetails: &store.CreditDetails{RemainingCredits: nil}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := fakeCreditStore{org: org, profile: tt.profile} + + err := RunCredit(terminal.New(), store, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to retrieve credit balance") + }) + } +} diff --git a/pkg/store/organization.go b/pkg/store/organization.go index f41c760c..98e13a9b 100644 --- a/pkg/store/organization.go +++ b/pkg/store/organization.go @@ -308,3 +308,34 @@ func (s AuthHTTPStore) RedeemCouponCode(organizationID string, code string) (*Re return &result, nil } + +type CreditDetails struct { + RemainingCredits *int64 `json:"remaining_credits"` + TopUpAmount *int64 `json:"top_up_amount"` + TopUpThreshold *int64 `json:"top_up_threshold"` +} + +type BillingProfile struct { + BillingProfileID string `json:"billing_profile_id"` + CreditDetails *CreditDetails `json:"credit_details"` +} + +// GetBillingProfile fetches the billing profile for the given organization. +// remaining_credits is returned in cents (server-side). Caller should divide by 100 for dollars. +func (s AuthHTTPStore) GetBillingProfile(organizationID string) (*BillingProfile, error) { + var result BillingProfile + path := orgPath + "/" + organizationID + "/billingprofile" + + res, err := s.authHTTPClient.restyClient.R(). + SetHeader("Content-Type", "application/json"). + SetResult(&result). + Get(path) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + return &result, nil +}