
SQL not output correctly #1

Open
ZhengyiWang opened this issue Apr 11, 2024 · 5 comments

Comments

@ZhengyiWang

After running the example, the model just outputs 4:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
model_path = "Chat2DB/Chat2DB-SQL-7B" # This can be replaced with your local model path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16, use_cache=True)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=100)
prompt = "### Database Schema\n\n['CREATE TABLE \"stadium\" (\\n\"Stadium_ID\" int,\\n\"Location\" text,\\n\"Name\" text,\\n\"Capacity\" int,\\n\"Highest\" int,\\n\"Lowest\" int,\\n\"Average\" int,\\nPRIMARY KEY (\"Stadium_ID\")\\n);', 'CREATE TABLE \"singer\" (\\n\"Singer_ID\" int,\\n\"Name\" text,\\n\"Country\" text,\\n\"Song_Name\" text,\\n\"Song_release_year\" text,\\n\"Age\" int,\\n\"Is_male\" bool,\\nPRIMARY KEY (\"Singer_ID\")\\n);', 'CREATE TABLE \"concert\" (\\n\"concert_ID\" int,\\n\"concert_Name\" text,\\n\"Theme\" text,\\n\"Stadium_ID\" text,\\n\"Year\" text,\\nPRIMARY KEY (\"concert_ID\"),\\nFOREIGN KEY (\"Stadium_ID\") REFERENCES \"stadium\"(\"Stadium_ID\")\\n);', 'CREATE TABLE \"singer_in_concert\" (\\n\"concert_ID\" int,\\n\"Singer_ID\" text,\\nPRIMARY KEY (\"concert_ID\",\"Singer_ID\"),\\nFOREIGN KEY (\"concert_ID\") REFERENCES \"concert\"(\"concert_ID\"),\\nFOREIGN KEY (\"Singer_ID\") REFERENCES \"singer\"(\"Singer_ID\")\\n);']\n\n\n### Task \n\nBased on the provided database schema information, How many singers do we have?[SQL]\n"
response = pipe(prompt)[0]["generated_text"]
print(response)

I also tried other prompt examples. The model still does not output SQL directly the way Chat2DB AI does.

@lordk911

Same problem here:

>>> print(response)


### Answer 
255

@baisui1981

Where does this model come from?

@loli0123456789

What do the maintainers say?

@ls25145

ls25145 commented Jul 3, 2024

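  1. You need to install the dependency `accelerate`, which `device_map="auto"` requires (in Jupyter):
!pip install accelerate
# The package only takes effect after restarting the kernel
import os
os._exit(0)
  2. The example code contains an extra blank line, which causes a syntax error.

  3. When calling the pipeline, add a parameter:
    change pipe(prompt) to pipe(prompt, pad_token_id=pipe.tokenizer.eos_token_id)
    otherwise this message is printed:

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  4. The prompt given in the example does not produce SQL. For some reason its CREATE TABLE statements are wrapped in square brackets and single quotes, which isn't JSON syntax either.
    After I removed the brackets and single quotes, the model generated SQL.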
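Point 4 can be reproduced without loading the model. The bracketed, single-quoted schema in the original example looks like what Python produces for `str()` of a list of DDL strings: brackets, quotes, and every newline escaped as a literal backslash-n. That is presumably how the README prompt was generated, though this is an assumption. The sketch below compares that form with joining the statements as plain text, which is what the modified prompt below does. The names `ddl_statements`, `schema_as_list_repr`, `schema_as_plain_text` and `build_prompt` are hypothetical, and the schema is abbreviated to two tables.

```python
# Hypothetical, abbreviated DDL list (two of the four tables from this thread).
ddl_statements = [
    'CREATE TABLE "stadium" (\n"Stadium_ID" int,\n"Name" text,\nPRIMARY KEY ("Stadium_ID")\n);',
    'CREATE TABLE "singer" (\n"Singer_ID" int,\n"Name" text,\nPRIMARY KEY ("Singer_ID")\n);',
]


def schema_as_list_repr(ddls):
    # str() of a list gives its repr: brackets, quotes, and each newline as a literal backslash-n
    return str(ddls)


def schema_as_plain_text(ddls):
    # Plain DDL statements separated by blank lines, as in the modified prompt
    return "\n\n".join(ddls)


def build_prompt(ddls, question):
    # Hypothetical helper following the "### Database Schema" / "### Task" layout used in this thread
    return (
        "### Database Schema\n\n"
        + schema_as_plain_text(ddls)
        + "\n\n\n### Task \n\n"
        + f"Based on the provided database schema information, {question}[SQL]\n"
    )


print(schema_as_list_repr(ddl_statements))
print(build_prompt(ddl_statements, "How many singers do we have?"))
```

The output of `build_prompt(...)` can be passed to `pipe(..., pad_token_id=pipe.tokenizer.eos_token_id)` in place of a hand-written string. Its whitespace differs slightly from the triple-quoted prompt below.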

Overall, the developers don't seem to have put much care into this. The instructions look like they were thrown together.

Here is the modified code, for reference:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
model_path = "Chat2DB/Chat2DB-SQL-7B" # This can be replaced with your local model path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16, use_cache=True)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=100)
prompt = """
### Database Schema

CREATE TABLE "stadium" (
"Stadium_ID" int,
"Location" text,
"Name" text,
"Capacity" int,
"Highest" int,
"Lowest" int,
"Average" int,
PRIMARY KEY ("Stadium_ID")
);

CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Country" text,
"Song_Name" text,
"Song_release_year" text,
"Age" int,
"Is_male" bool,
PRIMARY KEY ("Singer_ID")
);

CREATE TABLE "concert" (
"concert_ID" int,
"concert_Name" text,
"Theme" text,
"Stadium_ID" text,
"Year" text,
PRIMARY KEY ("concert_ID"),
FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")
);

CREATE TABLE "singer_in_concert" (
"concert_ID" int,
"Singer_ID" text,
PRIMARY KEY ("concert_ID","Singer_ID"),
FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);


### Task 

Based on the provided database schema information, How many singers do we have?[SQL]

"""
response = pipe(prompt,pad_token_id=pipe.tokenizer.eos_token_id)[0]["generated_text"]
print(response)
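In this thread, the model fails by answering with a number ("4", "255") instead of a query. A cheap way to catch that is to check whether the generated text compiles against the schema. The sketch below does this with Python's built-in `sqlite3` module. It assumes the model emits SQLite-compatible SQL, and `is_valid_sql` and `SCHEMA_DDL` are hypothetical names, not part of Chat2DB or transformers. It only checks that the query parses and references existing tables and columns, not that it answers the question.

```python
import sqlite3

# DDL from the prompt above; these statements happen to be valid SQLite.
SCHEMA_DDL = """
CREATE TABLE "stadium" (
"Stadium_ID" int,
"Location" text,
"Name" text,
"Capacity" int,
"Highest" int,
"Lowest" int,
"Average" int,
PRIMARY KEY ("Stadium_ID")
);
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Country" text,
"Song_Name" text,
"Song_release_year" text,
"Age" int,
"Is_male" bool,
PRIMARY KEY ("Singer_ID")
);
CREATE TABLE "concert" (
"concert_ID" int,
"concert_Name" text,
"Theme" text,
"Stadium_ID" text,
"Year" text,
PRIMARY KEY ("concert_ID"),
FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")
);
CREATE TABLE "singer_in_concert" (
"concert_ID" int,
"Singer_ID" text,
PRIMARY KEY ("concert_ID","Singer_ID"),
FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);
"""


def is_valid_sql(ddl: str, sql: str) -> bool:
    """Return True if `sql` compiles against the tables defined in `ddl` (SQLite dialect)."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(ddl)                  # create the (empty) tables
        conn.execute("EXPLAIN " + sql.strip())   # compile the query; no rows are needed
        return True
    except sqlite3.Error:                        # syntax errors, unknown tables or columns, ...
        return False
    finally:
        conn.close()


for candidate in ["SELECT count(*) FROM singer", "4", "### Answer \n255"]:
    print(repr(candidate), is_valid_sql(SCHEMA_DDL, candidate))
```

With the modified code above, `is_valid_sql(SCHEMA_DDL, response)` would flag a bare-number answer before it is used anywhere.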

@Valdanitooooo

Is the prompt template different from the one code-llama-instruct uses?


6 participants