Skip to content

Commit 78ad55a

Browse files
committed
Enhance CLI with subcommands and provider logic
- Introduces `version` subcommand to display the app version - Prevents simultaneous use of subcommands and other flags - Implements `handle_subcommand` to manage subcommand execution - Adds type annotations for CLI arguments using `Annotated` - Modifies `get_api_key` to return `None` when no API key is found - Updates `get_ai_provider` to dynamically set the model for the ollama provider - Ensures a model name is set for remote providers via environment variable or CLI flag
1 parent 3b3b8e6 commit 78ad55a

3 files changed

Lines changed: 62 additions & 13 deletions

File tree

ai_commit/__init__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import argparse
22
from dataclasses import dataclass
3+
from typing import Annotated
34

45

56
@dataclass
67
class CLIArgs:
7-
remote: bool
8-
debug: bool
9-
model: str | None = None
8+
remote: Annotated[bool, "Use remote model for commit generation"]
9+
debug: Annotated[bool, "Run the CLI in debug mode"]
10+
model: Annotated[str | None, "Model to use for commit generation"] = None
11+
command: Annotated[str | None, "Sub-command passed to the CLI"] = None
1012

1113

1214
parser = argparse.ArgumentParser(
@@ -34,5 +36,15 @@ class CLIArgs:
3436
"-r", "--remote", help=remote_model_help, default=False, action="store_true"
3537
)
3638

39+
40+
version_parser = parser.add_subparsers(title="version", dest="command")
41+
version_parser.add_parser("version", help="Show app version")
3742
raw_args = parser.parse_args()
38-
cli_args = CLIArgs(remote=raw_args.remote, debug=raw_args.debug, model=raw_args.model)
43+
44+
45+
cli_args = CLIArgs(
46+
remote=raw_args.remote,
47+
debug=raw_args.debug,
48+
model=raw_args.model,
49+
command=raw_args.command,
50+
)

ai_commit/app.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import subprocess
55
import sys
66
import tempfile
7+
from importlib.metadata import PackageNotFoundError, version as package_version
78

89
from openai import OpenAI
910

10-
from ai_commit import cli_args as args
11+
from ai_commit import CLIArgs, cli_args as args
1112
from ai_commit.prompts import get_system_prompt
1213
from ai_commit.providers import get_ai_provider, provider_names
1314

@@ -19,7 +20,7 @@
1920
}
2021

2122

22-
def get_api_key() -> str:
23+
def get_api_key() -> str | None:
2324
"""Check if API keys are set in environment variables for remote models.
2425
If the model is run locally (via Ollama), then the 'ollama' key is returned back.
2526
@@ -108,7 +109,7 @@ def generate_commit_message(staged_changes: str, regenerate: bool = False) -> st
108109

109110
return commit_message
110111
except Exception as e:
111-
print(f"❌ Error generating commit message: {str(e)}")
112+
print(f"❌ Error generating commit message:\n{str(e)}")
112113
sys.exit(1)
113114

114115

@@ -163,7 +164,6 @@ def handle_editing(commit_message: str):
163164
def interaction_loop(staged_changes: str):
164165
commit_message = generate_commit_message(staged_changes)
165166
while True:
166-
167167
action = input(
168168
"\n\nProceed to commit? [y(yes) | n[no] | r(regenerate) | e(edit)] "
169169
)
@@ -197,8 +197,29 @@ def interaction_loop(staged_changes: str):
197197
break
198198

199199

200+
def get_version() -> str:
201+
try:
202+
version = package_version("ai-gen-commit")
203+
except PackageNotFoundError:
204+
version = "unknown"
205+
return version
206+
207+
208+
def handle_subcommand(cli_args: CLIArgs):
209+
if cli_args.command and (cli_args.model or cli_args.debug or cli_args.remote):
210+
print(
211+
f"❌ Error: cannot use subcommand {cli_args.command} and other flags together"
212+
)
213+
sys.exit(1)
214+
215+
if cli_args.command == "version":
216+
print(get_version())
217+
sys.exit(0)
218+
219+
200220
def run():
201221
try:
222+
handle_subcommand(args)
202223
run_command(commands["is_git_repo"])
203224
staged_changes = run_command(commands["get_stashed_changes"])
204225

ai_commit/providers.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
@dataclass
1010
class Provider:
1111
name: str
12-
model: str
12+
model: str | None
1313
base_url: str
1414

1515

1616
providers_mapping = {
1717
"ollama": Provider(
1818
name="ollama",
19-
model=get_model(),
19+
model=None,
2020
base_url=parse_host(os.getenv("OLLAMA_HOST")) + "/v1",
2121
),
2222
"openai": Provider(name="openai", model="gpt-4o", base_url=""),
@@ -61,10 +61,12 @@ class Provider:
6161

6262
def get_ai_provider() -> Provider | None:
6363
if not args.remote:
64-
return providers_mapping["ollama"]
64+
provider = providers_mapping["ollama"]
65+
if provider.model is None:
66+
provider.model = get_model()
67+
return provider
6568

6669
ai_provider = os.environ.get("AI_COMMIT_PROVIDER")
67-
ai_provider = ai_provider.lower() if ai_provider else None
6870
if not ai_provider:
6971
print(
7072
f"""
@@ -84,4 +86,18 @@ def get_ai_provider() -> Provider | None:
8486
)
8587
sys.exit(1)
8688

87-
return providers_mapping.get(ai_provider)
89+
provider = providers_mapping.get(ai_provider.lower())
90+
if not provider.model:
91+
print(
92+
f"""
93+
❌ No model name set for the provider: {provider.name}
94+
95+
Set model name using:
96+
export AI_COMMIT_MODEL=<model-from-provider>
97+
OR
98+
aic -m <model-from-provider>
99+
"""
100+
)
101+
sys.exit(1)
102+
103+
return providers_mapping.get(ai_provider.lower())

0 commit comments

Comments
 (0)