Skip to content

Commit

Permalink
feat: add default prompt support (#204)
Browse files Browse the repository at this point in the history
* feat: add default prompt support

Add a new option `-d` for adding default prompt to the user prompt.

Allow user to generate images without input (in such case they must toggle the `default` option)

Add a new short cut `default` for quick accessing default prompt without any user input.

* feat: use respond prompt when the backend is sd-webui

Use respond prompt to replace the original prompt as feedback, because the prompt may be affected by some plugins (e.g. dynamic-prompts)

* Feat: add `defaultPromptSw` config and remove the `default` shortcut

* Feat: separate the auth lv requirements for default and normal usage

* Refactor: remove `default` shortcut

* Chore: change some descriptions
  • Loading branch information
Ninzore authored Jul 25, 2023
1 parent 0e0c3a2 commit 5c8a76c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
22 changes: 21 additions & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ export interface PromptConfig {
basePrompt?: Computed<string>
negativePrompt?: Computed<string>
forbidden?: Computed<string>
defaultPromptSw?: boolean
defaultPrompt?: Computed<string>
placement?: Computed<'before' | 'after'>
latinOnly?: Computed<boolean>
translator?: boolean
Expand All @@ -153,6 +155,8 @@ export const PromptConfig: Schema<PromptConfig> = Schema.object({
basePrompt: Schema.computed(Schema.string().role('textarea'), options).description('默认附加的标签。').default('masterpiece, best quality'),
negativePrompt: Schema.computed(Schema.string().role('textarea'), options).description('默认附加的反向标签。').default(ucPreset),
forbidden: Schema.computed(Schema.string().role('textarea'), options).description('违禁词列表。请求中的违禁词将会被自动删除。').default(''),
defaultPromptSw: Schema.boolean().description('是否启用默认标签。').default(false),
defaultPrompt: Schema.string().role('textarea', options).description('默认标签,可以在用户无输入prompt时调用。可选在sd-webui中安装dynamic prompt插件,配合使用以达到随机标签效果。').default(''),
placement: Schema.computed(Schema.union([
Schema.const('before').description('置于最前'),
Schema.const('after').description('置于最后'),
Expand Down Expand Up @@ -202,6 +206,8 @@ export interface Config extends PromptConfig, ParamConfig {
token?: string
email?: string
password?: string
authLv?: Computed<number>
authLvDefault?: Computed<number>
output?: Computed<'minimal' | 'default' | 'verbose'>
features?: FeatureConfig
endpoint?: string
Expand Down Expand Up @@ -273,6 +279,11 @@ export const Config = Schema.intersect([
}),
]),

Schema.object({
authLv: Schema.computed(Schema.natural(), options).description('使用画图全部功能所需要的权限等级。').default(0),
authLvDefault: Schema.computed(Schema.natural(), options).description('使用默认参数生成所需要的权限等级。').default(0),
}).description('权限设置'),

Schema.object({
features: Schema.object({}),
}).description('功能设置'),
Expand Down Expand Up @@ -372,7 +383,15 @@ export function parseForbidden(input: string) {

const backslash = /@@__BACKSLASH__@@/g

export function parseInput(session: Session, input: string, config: Config, override: boolean): string[] {
export function parseInput(session: Session, input: string, config: Config, override: boolean, addDefault: boolean): string[] {
if (!input) {
return [
null,
[session.resolve(config.basePrompt), session.resolve(config.defaultPrompt)].join(','),
session.resolve(config.negativePrompt)
]
}

input = input
.replace(/\\\\/g, backslash.source)
.replace(//g, ',')
Expand Down Expand Up @@ -446,6 +465,7 @@ export function parseInput(session: Session, input: string, config: Config, over
if (!override) {
appendToList(positive, session.resolve(config.basePrompt))
appendToList(negative, session.resolve(config.negativePrompt))
if (addDefault) appendToList(positive, session.resolve(config.defaultPrompt))
}

return [null, positive.join(', '), negative.join(', ')]
Expand Down
33 changes: 26 additions & 7 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,22 @@ export function apply(ctx: Context, config: Config) {
.option('iterations', '-i <iterations:posint>', { fallback: 1, hidden: () => config.maxIterations <= 1 })
.option('batch', '-b <batch:option>', { fallback: 1, hidden: () => config.maxIterations <= 1 })
.action(async ({ session, options }, input) => {
if (!input?.trim()) return session.execute('help novelai')
if (config.defaultPromptSw) {
if (session.user.authority < session.resolve(config.authLvDefault)) {
return session.text('internal.low-authority')
}
if (session.user.authority < session.resolve(config.authLv)) {
input = ''
options = options.resolution ? { resolution: options.resolution } : {}
}
}
else if (
!config.defaultPromptSw &&
session.user.authority < session.resolve(config.authLv)
) return session.text('internal.low-auth')

const haveInput = input?.trim() ? true : false
if (!haveInput && !config.defaultPromptSw) return session.execute('help novelai')

// Check if the user is allowed to use this command.
// This code is originally written in the `resolution` function,
Expand All @@ -126,7 +141,7 @@ export function apply(ctx: Context, config: Config) {
const allowImage = useFilter(config.features.image)(session)

let imgUrl: string, image: ImageData
if (!restricted(session)) {
if (!restricted(session) && haveInput) {
input = h('', h.transform(h.parse(input), {
image(attrs) {
if (!allowImage) throw new SessionError('commands.novelai.messages.invalid-content')
Expand All @@ -144,11 +159,11 @@ export function apply(ctx: Context, config: Config) {
return session.text('.expect-prompt')
}
} else {
input = h('', h.transform(h.parse(input), {
input = haveInput ? h('', h.transform(h.parse(input), {
image(attrs) {
throw new SessionError('commands.novelai.messages.invalid-content')
},
})).toString(true)
})).toString(true) : input
delete options.enhance
delete options.steps
delete options.noise
Expand All @@ -160,15 +175,17 @@ export function apply(ctx: Context, config: Config) {
return session.text('.expect-image')
}

if (config.translator && ctx.translator && !options.noTranslator) {
if (haveInput && config.translator && ctx.translator && !options.noTranslator) {
try {
input = await ctx.translator.translate({ input, target: 'en' })
} catch (err) {
logger.warn(err)
}
}

const [errPath, prompt, uc] = parseInput(session, input, config, options.override)
const [errPath, prompt, uc] = parseInput(
session, input, config, options.override, config.defaultPromptSw
)
if (errPath) return session.text(errPath)

let token: string
Expand Down Expand Up @@ -352,6 +369,7 @@ export function apply(ctx: Context, config: Config) {
}
}

let finalPrompt = prompt
const iterate = async () => {
const request = async () => {
const res = await ctx.http.axios(trimSlash(config.endpoint) + path, {
Expand All @@ -365,6 +383,7 @@ export function apply(ctx: Context, config: Config) {
})

if (config.type === 'sd-webui') {
finalPrompt = (JSON.parse((res.data as StableDiffusionWebUI.Response).info)).prompt
return forceDataPrefix((res.data as StableDiffusionWebUI.Response).images[0])
}
if (config.type === 'stable-horde') {
Expand Down Expand Up @@ -436,7 +455,7 @@ export function apply(ctx: Context, config: Config) {
}
}
result.children.push(h('message', attrs, lines.join('\n')))
result.children.push(h('message', attrs, `prompt = ${prompt}`))
result.children.push(h('message', attrs, `prompt = ${finalPrompt}`))
if (output === 'verbose') {
result.children.push(h('message', attrs, `undesired = ${uc}`))
}
Expand Down

0 comments on commit 5c8a76c

Please sign in to comment.