Skip to content

Commit 3bf2bcc

Browse files
fix: Code refactor with staged changes and auto-stage prompt.
1 parent 4ad9df9 commit 3bf2bcc

2 files changed

Lines changed: 80 additions & 51 deletions

File tree

client.go

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,6 @@ type ChatCompleteRequest struct {
2121
Messages []*Message `json:"messages"`
2222
}
2323

24-
type ChatCompleteChoice struct {
25-
Index int `json:"index"`
26-
Message struct {
27-
Role string `json:"role"`
28-
Content string `json:"content"`
29-
} `json:"message"`
30-
FinishReason string `json:"finish_reason"`
31-
}
32-
33-
type ChatCompleteResponse struct {
34-
Id string `json:"id"`
35-
Object string `json:"object"`
36-
Created int `json:"created"`
37-
Choices []ChatCompleteChoice `json:"choices"`
38-
}
39-
4024
type GptClient struct {
4125
apiKey string
4226
httpClient *http.Client
@@ -69,7 +53,7 @@ func (c *GptClient) ChatComplete(ctx context.Context, messages []*Message) (stri
6953
req.Header.Add("Content-Type", "application/json")
7054
req.Header.Add("Authorization", "Bearer "+c.apiKey)
7155

72-
res, err := http.DefaultClient.Do(req)
56+
res, err := c.httpClient.Do(req)
7357
if err != nil {
7458
return "", err
7559
}
@@ -91,21 +75,11 @@ func (c *GptClient) ChatComplete(ctx context.Context, messages []*Message) (stri
9175
return "", errors.New("failed to get response from OpenAI API: " + errorMessage)
9276
}
9377

94-
var response ChatCompleteResponse
95-
err = json.Unmarshal(body, &response)
96-
if err != nil {
97-
return "", err
98-
}
99-
100-
if len(response.Choices) == 0 {
101-
return "", errors.New("no choices returned from OpenAI API")
102-
}
103-
104-
firstChoice := response.Choices[0].Message.Content
105-
firstChoice = strings.TrimSpace(firstChoice)
106-
firstChoice = strings.Trim(firstChoice, `"`)
78+
answer := gjson.GetBytes(body, "choices.0.message.content").String()
79+
answer = strings.TrimSpace(answer)
80+
answer = strings.Trim(answer, `"`)
10781

108-
return firstChoice, nil
82+
return answer, nil
10983

11084
}
11185

main.go

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ var messages = []*Message{
1919
}
2020

2121
func main() {
22-
2322
// prepare the arguments
2423
apiKey := os.Getenv("OPENAI_API_KEY")
2524
if apiKey == "" {
@@ -32,23 +31,34 @@ func main() {
3231
// prepare the diff
3332
diff, err := getDiff()
3433
if err != nil {
35-
explain, err := explainError(context.Background(), client, err)
36-
if err != nil {
37-
printError(err.Error())
34+
if explain, explainErr := explainError(context.Background(), client, err); explainErr == nil {
35+
printError(explain)
3836
os.Exit(1)
3937
}
40-
printError(explain)
38+
39+
printError(err.Error())
4140
os.Exit(1)
4241
}
4342

4443
if diff == "" {
45-
if isDirty() {
46-
fmt.Println("Please stage your changes and try again")
47-
} else {
44+
if !isDirty() {
4845
fmt.Println("Nothing to commit, working tree clean")
46+
os.Exit(0)
47+
}
48+
49+
if shouldAutoStage := askForAutoStage(client); !shouldAutoStage {
50+
os.Exit(0)
4951
}
5052

51-
os.Exit(0)
53+
if err := gitAdd(); err != nil {
54+
if explain, explainErr := explainError(context.Background(), client, err); explainErr == nil {
55+
printError(explain)
56+
os.Exit(1)
57+
}
58+
59+
printError(err.Error())
60+
os.Exit(1)
61+
}
5262
}
5363

5464
commitMessage := ""
@@ -64,12 +74,12 @@ func main() {
6474
printNormal("Assistant: " + generateLoadingMessage())
6575
commitMessage, err = client.ChatComplete(ctx, messages)
6676
if err != nil {
67-
explain, err := explainError(ctx, client, err)
68-
if err != nil {
69-
printError(err.Error())
77+
if explain, explainErr := explainError(context.Background(), client, err); explainErr == nil {
78+
printError(explain)
7079
os.Exit(1)
7180
}
72-
printError(explain)
81+
82+
printError(err.Error())
7383
os.Exit(1)
7484
}
7585

@@ -91,12 +101,12 @@ func main() {
91101
reader := bufio.NewReader(os.Stdin)
92102
userRequest, err = reader.ReadString('\n')
93103
if err != nil {
94-
explain, err := explainError(ctx, client, err)
95-
if err != nil {
96-
printError(err.Error())
104+
if explain, explainErr := explainError(context.Background(), client, err); explainErr == nil {
105+
printError(explain)
97106
os.Exit(1)
98107
}
99-
printError(explain)
108+
109+
printError(err.Error())
100110
os.Exit(1)
101111
}
102112

@@ -145,6 +155,37 @@ func joinPrefix(prefix string, message string) string {
145155
return prefix + ": " + message
146156
}
147157

158+
func gitAdd() error {
159+
workingDir, err := os.Getwd()
160+
if err != nil {
161+
return err
162+
}
163+
164+
return executils.Run("git",
165+
executils.WithDir(workingDir),
166+
executils.WithArgs("add", "."),
167+
)
168+
}
169+
170+
func askForAutoStage(apiClient *GptClient) bool {
171+
fmt.Println("Assistant: Your working tree is dirty, but the stage is empty, do you want me to stage the changes first?")
172+
fmt.Print("You: ")
173+
reader := bufio.NewReader(os.Stdin)
174+
userRequest, err := reader.ReadString('\n')
175+
if err != nil {
176+
printError("failed to read user input: " + err.Error())
177+
os.Exit(1)
178+
}
179+
180+
userRequest = strings.TrimSpace(userRequest)
181+
182+
if userRequest == "" {
183+
return askForAutoStage(apiClient)
184+
}
185+
186+
return IsAgree(apiClient, userRequest)
187+
}
188+
148189
func askForPrefix() string {
149190
prefix := ""
150191
var err error
@@ -169,7 +210,7 @@ func explainError(ctx context.Context, apiClient *GptClient, userError error) (s
169210
response, err := apiClient.ChatComplete(ctx, []*Message{
170211
{
171212
Role: "system",
172-
Content: "You are a developer, explain the error to user: `" + userError.Error() + "`, only response the message:",
213+
Content: "You are a developer, explain the error to user: `" + userError.Error() + "`.",
173214
},
174215
})
175216
if err != nil {
@@ -200,11 +241,13 @@ func getDiff() (string, error) {
200241
}
201242

202243
out := strings.Builder{}
203-
executils.Run("git",
244+
if err := executils.Run("git",
204245
executils.WithDir(workingDir),
205246
executils.WithArgs("diff", "--cached", "--unified=0"),
206247
executils.WithStdOut(&out),
207-
)
248+
); err != nil {
249+
return "", err
250+
}
208251

209252
return strings.TrimSpace(out.String()), nil
210253
}
@@ -235,6 +278,12 @@ var agreeWords = []string{
235278
"agree",
236279
}
237280

281+
var disagreeWords = []string{
282+
"no",
283+
"n",
284+
"disagree",
285+
}
286+
238287
// IsAgree returns true if the user agrees with the commit message
239288
func IsAgree(c *GptClient, userResponse string) bool {
240289
for _, word := range agreeWords {
@@ -243,6 +292,12 @@ func IsAgree(c *GptClient, userResponse string) bool {
243292
}
244293
}
245294

295+
for _, word := range disagreeWords {
296+
if strings.HasPrefix(strings.ToLower(userResponse), word) {
297+
return false
298+
}
299+
}
300+
246301
message := []*Message{
247302
{
248303
Role: "user",

0 commit comments

Comments
 (0)