-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
50 lines (37 loc) · 1.79 KB
/
Copy pathutils.py
File metadata and controls
50 lines (37 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import dataclasses
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Union, get_args, get_origin

from huggingface_hub import login
from transformers import HfArgumentParser
class StarChatArgumentParser(HfArgumentParser):
    """``HfArgumentParser`` that reads a YAML config and applies ``--key=value`` overrides on top."""

    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
        """Parse a YAML config file, then override its values with command-line arguments.

        Args:
            yaml_arg: Path to the YAML config file; resolved with ``os.path.abspath``
                before being handed to ``parse_yaml_file``.
            other_args: Optional overrides of the form ``--key=value``. A key is applied
                to a dataclass only if that dataclass has an ``init`` field of that name.
                ``None`` or an empty list means "no overrides".

        Returns:
            One newly constructed instance per entry of ``self.dataclass_types`` (in the
            same order), built from the YAML values with the overrides applied.

        Raises:
            ValueError: if an override lacks ``=``, if a boolean override is not a
                recognised truth value, if an int/float override does not parse, or if
                one key matches fields in more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
        overrides = self._parse_overrides(other_args)

        outputs = []
        # Tracks which override keys were already consumed by an earlier dataclass.
        used_args = {}
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
            for arg, val in overrides.items():
                # Only apply overrides that name a field of this dataclass.
                if arg not in keys:
                    continue
                base_type = data_yaml.__dataclass_fields__[arg].type
                inputs[arg] = self._cast_value(val, base_type)
                # A key matching fields in two dataclasses would silently set both; refuse it.
                if arg in used_args:
                    raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
                used_args[arg] = val
            # Re-instantiate so the dataclass's own __init__/__post_init__ sees the overrides.
            outputs.append(data_class(**inputs))
        return outputs

    @staticmethod
    def _parse_overrides(other_args: Optional[List[str]]) -> Dict[str, str]:
        """Turn ``["--key=value", ...]`` into ``{"key": "value", ...}``; a later repeat of a key wins."""
        overrides = {}
        for arg in other_args or []:
            # Split on the first "=" only, so values may themselves contain "=".
            key, sep, value = arg.partition("=")
            if not sep:
                raise ValueError(f"Expected an argument of the form --key=value, got: {arg!r}")
            overrides[key.lstrip("-")] = value
        return overrides

    @staticmethod
    def _cast_value(val: str, base_type: object) -> object:
        """Convert the command-line string ``val`` to ``base_type`` where supported.

        Handles ``int``, ``float``, ``bool``, ``Optional[...]`` of those, and
        ``List[...]`` (comma-separated). Any other annotation gets ``val`` unchanged.
        """
        # Unwrap Optional[X] (== Union[X, None]) so optional scalar fields are cast too.
        if get_origin(base_type) is Union:
            non_none = [t for t in get_args(base_type) if t is not type(None)]
            if len(non_none) == 1:
                base_type = non_none[0]

        if base_type is bool:
            # bool("False") is True, so the text has to be interpreted explicitly.
            lowered = val.strip().lower()
            if lowered in ("true", "t", "yes", "y", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "0"):
                return False
            raise ValueError(f"Cannot interpret {val!r} as a boolean")

        if base_type in (int, float):
            return base_type(val)

        if get_origin(base_type) is list:
            # Comma-separated list, e.g. --splits=train,test; empty items are dropped.
            item_args = get_args(base_type)
            item_type = item_args[0] if item_args else str
            caster = item_type if item_type in (int, float, str) else str
            return [caster(item.strip()) for item in val.split(",") if item.strip()]

        # Strings and any unsupported annotation are passed through untouched.
        return val
def hf_login():
    """Authenticate with the Hugging Face Hub using the ``HF_TOKEN`` environment variable.

    When ``HF_TOKEN`` is absent from the environment this is a no-op; otherwise its
    value is passed to ``huggingface_hub.login``.
    """
    token = os.environ.get("HF_TOKEN")
    if token is None:
        # Nothing to authenticate with; leave Hub credentials as they are.
        return
    login(token=token)