Skip to content

Latest commit

 

History

History
154 lines (131 loc) · 6.03 KB

File metadata and controls

154 lines (131 loc) · 6.03 KB

Table cars_data has columns such as id, weight.\nid is the primary key.\nTable model_list has columns such as maker, model.

Table model_list has columns such as model, maker.\nTable cars_data has columns such as id, weight.\nid is the primary key.

python evaluation.py --gold /Users/jizha/code/python/spider/dataset/spider/dev_gold.sql --pred /Users/jizha/code/python/test-suite-sql-eval/二轮测试_gpt4_choice.json --etype all --db /Users/jizha/code/python/spider/dataset/spider/database --table tables.json

python generate_question.py \
--data_type spider \
--split test \
--tokenizer gpt-3.5-turbo \
--max_seq_len 4096 \
--selector_type EUCDISMASKPRESKLSIMTHR \
--pre_test_result /Users/jizha/code/python/test-suite-sql-eval/随机列测试/union_test_20231201_random_table.sql \
--prompt_repr SQL \
--k_shot 9 \
--example_type QA
import argparse
import os
import json

import openai
from tqdm import tqdm

from llm.chatgpt import init_chatgpt, ask_llm
from utils.enums import LLM
from torch.utils.data import DataLoader

from utils.post_process import process_duplication, get_sqls
import concurrent.futures

# Name of the prepared-prompts file expected inside the --question directory.
QUESTION_FILE = "questions.json"


def gen_predict_sql(index, token_cnt, args, batch):
    """Ask the LLM for one batch of prompts and post-process the completions
    into executable SQL strings.

    Args:
        index: position of this batch within the caller's chunk; returned
            unchanged so the threaded caller can restore ordering.
        token_cnt: running token counter. NOTE: ints are passed by value, so
            the increment below is local only — the caller's counter is never
            updated. Kept for interface compatibility.
        args: parsed CLI namespace (uses model, temperature, n, batch_size,
            db_dir).
        batch: list of prompt strings sent to the model.

    Returns:
        (index, results): results is a list of SQL strings, each guaranteed
        to start with "SELECT".
    """
    try:
        res = ask_llm(args.model, batch, args.temperature, args.n)
    except openai.error.InvalidRequestError:
        # BUGFIX: original referenced undefined `i` (NameError) and assigned
        # res = "", which crashed on res["total_tokens"] below. Fall back to
        # empty completions that the code below turns into bare "SELECT".
        print(f"The {index}-th question has too much tokens! Return \"SELECT\" instead")
        res = {"response": [""] * max(args.n, 1), "total_tokens": 0}
    # parse result
    token_cnt += res["total_tokens"]  # local only; see docstring
    results = []
    if args.n == 1:
        for sql in res["response"]:
            # remove \n and extra spaces
            sql = " ".join(sql.replace("\n", " ").split())
            sql = process_duplication(sql)
            # python version should >= 3.8
            if sql.startswith("SELECT"):
                results.append(sql)
            elif sql.startswith(" "):
                results.append("SELECT" + sql)
            else:
                results.append("SELECT " + sql)
    else:
        # BUGFIX: original sliced db_ids with undefined `i` (NameError).
        # NOTE(review): this assumes `index` is the *global* batch index into
        # the module-level `db_ids`; the caller currently passes a
        # chunk-relative index — confirm and align with the caller's chunking.
        cur_db_ids = db_ids[index * args.batch_size: index * args.batch_size + len(batch)]
        for sqls, db_id in zip(res["response"], cur_db_ids):
            processed_sqls = []
            for sql in sqls:
                sql = " ".join(sql.replace("\n", " ").split())
                sql = process_duplication(sql)
                if sql.startswith("SELECT"):
                    pass
                elif sql.startswith(" "):
                    sql = "SELECT" + sql
                else:
                    sql = "SELECT " + sql
                processed_sqls.append(sql)
            result = {
                'db_id': db_id,
                'p_sqls': processed_sqls
            }
            final_sqls = get_sqls([result], args.n, args.db_dir)
            # BUGFIX: original overwrote `results` each iteration, keeping
            # only the last item's SQLs; accumulate instead.
            results.extend(final_sqls)
    return index, results


if __name__ == '__main__':
    # CLI entry point: load prepared prompts, fan batches out to the LLM in
    # threads of 10, and append the predicted SQL lines to a results file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--question", type=str)
    # SECURITY(review): hardcoded API credential default — rotate this key and
    # read it from an environment variable instead of the command line.
    parser.add_argument("--openai_api_key", type=str, default="eab38a33cc07467aae9b7d09783b75a8")
    parser.add_argument("--openai_group_id", type=str, default="luli.wjc")
    parser.add_argument("--openai_api_base", type=str,
                        default="https://codegencore.antgroup-inc.cn/api/chat/commonPower/v1")
    parser.add_argument("--model", type=str, choices=[LLM.TEXT_DAVINCI_003,
                                                      LLM.GPT_35_TURBO,
                                                      LLM.GPT_35_TURBO_0613,
                                                      LLM.TONG_YI_QIAN_WEN,
                                                      LLM.GPT_35_TURBO_16K,
                                                      LLM.GPT_4],
                        default=LLM.GPT_35_TURBO)
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=1000000)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--mini_index_path", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--n", type=int, default=1, help="Size of self-consistent set")
    parser.add_argument("--db_dir", type=str, default="dataset/spider/database")
    args = parser.parse_args()

    # check args: only models in LLM.BATCH_FORWARD may use batch_size > 1
    assert args.model in LLM.BATCH_FORWARD or \
           args.model not in LLM.BATCH_FORWARD and args.batch_size == 1, \
        f"{args.model} doesn't support batch_size > 1"

    questions_json = json.load(open(os.path.join(args.question, QUESTION_FILE), "r"))
    questions = [_["prompt"] for _ in questions_json["questions"]]
    db_ids = [_["db_id"] for _ in questions_json["questions"]]

    # init openai api
    init_chatgpt(args.openai_api_key, args.openai_group_id, args.openai_api_base, args.model)

    # Overwrite on a fresh run, append when resuming from a later index.
    if args.start_index == 0:
        mode = "w"
    else:
        mode = "a"

    if args.mini_index_path:
        # NOTE(review): this filters `questions` but not `db_ids`, so the
        # db_ids slicing in gen_predict_sql's n>1 path goes out of sync when
        # a mini index is used — confirm intended behavior.
        mini_index = json.load(open(args.mini_index_path, 'r'))
        questions = [questions[i] for i in mini_index]
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}_MINI.txt"
    else:
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}.txt"

    question_loader = DataLoader(questions, batch_size=args.batch_size, shuffle=False, drop_last=False)
    # BUGFIX: DataLoader objects are not subscriptable, so the original
    # `question_loader[i:up]` raised TypeError. Materialize the batches once
    # (order is preserved: shuffle=False) so chunked slicing works.
    batches = list(question_loader)

    # NOTE(review): token_cnt is an int passed by value into worker threads,
    # so increments made inside gen_predict_sql never reach this counter.
    token_cnt = 0
    results = []
    with open(out_file, mode) as f:
        # Process the batches in chunks of 10, one thread per batch.
        for i in tqdm(range(0, len(batches), 10)):
            up = min(i + 10, len(batches))
            result_temp = [""] * (up - i)
            future_list = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                question_batch = batches[i:up]
                for index, item in enumerate(question_batch):
                    future_list.append(executor.submit(gen_predict_sql, index, token_cnt, args, item))
                # Collect out of order; result_temp restores chunk order.
                for future in concurrent.futures.as_completed(future_list):
                    index, p_sqls = future.result()
                    result_temp[index] = p_sqls
            for item in result_temp:
                f.write("".join(item))
                results.extend(item)