Skip to content

Commit dbaa329

Browse files
support hive create function (#500)
* support hive create function: Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-CreateFunction * Pylint * func: get_switch_by_create_query set to private method
1 parent ece5ace commit dbaa329

3 files changed

Lines changed: 39 additions & 1 deletion

File tree

sql_metadata/keywords_lists.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class TokenType(str, Enum):
108108
"CREATETABLE": QueryType.CREATE,
109109
"ALTERTABLE": QueryType.ALTER,
110110
"DROPTABLE": QueryType.DROP,
111+
"CREATEFUNCTION": QueryType.CREATE,
111112
}
112113

113114
# all the keywords we care for - rest is ignored in assigning

sql_metadata/parser.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def query_type(self) -> str:
114114
)
115115
.position
116116
)
117-
if tokens[index].normalized in ["CREATE", "ALTER", "DROP"]:
117+
if tokens[index].normalized == "CREATE":
118+
switch = self._get_switch_by_create_query(tokens, index)
119+
elif tokens[index].normalized in ("ALTER", "DROP"):
118120
switch = tokens[index].normalized + tokens[index + 1].normalized
119121
else:
120122
switch = tokens[index].normalized
@@ -1079,3 +1081,19 @@ def _flatten_sqlparse(self):
10791081
yield tok
10801082
else:
10811083
yield token
1084+
1085+
@staticmethod
1086+
def _get_switch_by_create_query(tokens: List[SQLToken], index: int) -> str:
1087+
"""
1088+
Return the switch that creates query type.
1089+
"""
1090+
switch = tokens[index].normalized + tokens[index + 1].normalized
1091+
1092+
# Hive CREATE FUNCTION
1093+
if any(
1094+
index + i < len(tokens) and tokens[index + i].normalized == "FUNCTION"
1095+
for i in (1, 2)
1096+
):
1097+
switch = "CREATEFUNCTION"
1098+
1099+
return switch

test/test_query_type.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,22 @@ def test_multiple_redundant_parentheses_create():
9393
"""
9494
parser = Parser(query)
9595
assert parser.query_type == QueryType.CREATE
96+
97+
98+
def test_hive_create_function():
99+
query = """
100+
CREATE FUNCTION simple_udf AS 'com.example.hive.udf.SimpleUDF'
101+
USING JAR 'hdfs:///user/hive/udfs/simple-udf.jar'
102+
WITH SERDEPROPERTIES (
103+
"hive.udf.param1"="value1",
104+
"hive.udf.param2"="value2"
105+
);
106+
"""
107+
parser = Parser(query)
108+
assert parser.query_type == QueryType.CREATE
109+
110+
query = """
111+
CREATE TEMPORARY FUNCTION myudf AS 'com.udf.myudf';
112+
"""
113+
parser = Parser(query)
114+
assert parser.query_type == QueryType.CREATE

0 commit comments

Comments
 (0)