diff --git a/.changeset/perfect-beers-wave.md b/.changeset/perfect-beers-wave.md new file mode 100644 index 0000000000..af19e981fa --- /dev/null +++ b/.changeset/perfect-beers-wave.md @@ -0,0 +1,5 @@ +--- +"@effect/schema": patch +--- + +resolve parse error when using `pick` with union of class schemas, closes #3751 diff --git a/packages/schema/src/AST.ts b/packages/schema/src/AST.ts index 312fbe8f77..535daaf24c 100644 --- a/packages/schema/src/AST.ts +++ b/packages/schema/src/AST.ts @@ -2067,51 +2067,57 @@ export const getNumberIndexedAccess = (ast: AST): AST => { throw new Error(errors_.getASTUnsupportedSchema(ast)) } +const getTypeLiteralPropertySignature = (ast: TypeLiteral, name: PropertyKey): PropertySignature | undefined => { + // from property signatures... + const ops = Arr.findFirst(ast.propertySignatures, (ps) => ps.name === name) + if (Option.isSome(ops)) { + return ops.value + } + + // from index signatures... + if (Predicate.isString(name)) { + let out: PropertySignature | undefined = undefined + for (const is of ast.indexSignatures) { + const parameterBase = getParameterBase(is.parameter) + switch (parameterBase._tag) { + case "TemplateLiteral": { + const regex = getTemplateLiteralRegExp(parameterBase) + if (regex.test(name)) { + return new PropertySignature(name, is.type, false, true) + } + break + } + case "StringKeyword": { + if (out === undefined) { + out = new PropertySignature(name, is.type, false, true) + } + } + } + } + if (out) { + return out + } + } else if (Predicate.isSymbol(name)) { + for (const is of ast.indexSignatures) { + const parameterBase = getParameterBase(is.parameter) + if (isSymbolKeyword(parameterBase)) { + return new PropertySignature(name, is.type, false, true) + } + } + } +} + /** @internal */ export const getPropertyKeyIndexedAccess = (ast: AST, name: PropertyKey): PropertySignature => { + const annotation = getSurrogateAnnotation(ast) + if (Option.isSome(annotation)) { + return getPropertyKeyIndexedAccess(annotation.value, name) + } switch (ast._tag) { - case "Declaration": { - const annotation = getSurrogateAnnotation(ast) - if (Option.isSome(annotation)) { - return getPropertyKeyIndexedAccess(annotation.value, name) - } - break - } case "TypeLiteral": { - const ops = Arr.findFirst(ast.propertySignatures, (ps) => ps.name === name) - if (Option.isSome(ops)) { - return ops.value - } else { - if (Predicate.isString(name)) { - let out: PropertySignature | undefined = undefined - for (const is of ast.indexSignatures) { - const parameterBase = getParameterBase(is.parameter) - switch (parameterBase._tag) { - case "TemplateLiteral": { - const regex = getTemplateLiteralRegExp(parameterBase) - if (regex.test(name)) { - return new PropertySignature(name, is.type, false, true) - } - break - } - case "StringKeyword": { - if (out === undefined) { - out = new PropertySignature(name, is.type, false, true) - } - } - } - } - if (out) { - return out - } - } else if (Predicate.isSymbol(name)) { - for (const is of ast.indexSignatures) { - const parameterBase = getParameterBase(is.parameter) - if (isSymbolKeyword(parameterBase)) { - return new PropertySignature(name, is.type, false, true) - } - } - } + const ps = getTypeLiteralPropertySignature(ast, name) + if (ps) { + return ps } break } @@ -2124,6 +2130,8 @@ export const getPropertyKeyIndexedAccess = (ast: AST, name: PropertyKey): Proper ) case "Suspend": return getPropertyKeyIndexedAccess(ast.f(), name) + case "Refinement": + return getPropertyKeyIndexedAccess(ast.from, name) } return new PropertySignature(name, neverKeyword, false, true) } @@ -2202,44 +2210,68 @@ export const record = (key: AST, value: AST): { * @since 0.67.0 */ export const pick = (ast: AST, keys: ReadonlyArray): TypeLiteral | Transformation => { - if (isTransformation(ast)) { - switch (ast.transformation._tag) { - case "ComposeTransformation": - return new Transformation( - pick(ast.from, keys), - pick(ast.to, keys), - composeTransformation - ) - case "TypeLiteralTransformation": { - const ts: Array = [] - const fromKeys: Array = [] - for (const k of keys) { - const t = ast.transformation.propertySignatureTransformations.find((t) => t.to === k) - if (t) { - ts.push(t) - fromKeys.push(t.from) - } else { - fromKeys.push(k) + const annotation = getSurrogateAnnotation(ast) + if (Option.isSome(annotation)) { + return pick(annotation.value, keys) + } + switch (ast._tag) { + case "TypeLiteral": { + const pss: Array = [] + const names: Record = {} + for (const ps of ast.propertySignatures) { + names[ps.name] = null + if (keys.includes(ps.name)) { + pss.push(ps) + } + } + for (const key of keys) { + if (!(key in names)) { + const ps = getTypeLiteralPropertySignature(ast, key) + if (ps) { + pss.push(ps) } } - return Arr.isNonEmptyReadonlyArray(ts) ? - new Transformation( - pick(ast.from, fromKeys), - pick(ast.to, keys), - new TypeLiteralTransformation(ts) - ) : - pick(ast.from, fromKeys) } - case "FinalTransformation": { - const annotation = getSurrogateAnnotation(ast) - if (Option.isSome(annotation)) { - return pick(annotation.value, keys) + return new TypeLiteral(pss, []) + } + case "Union": + return new TypeLiteral(keys.map((name) => getPropertyKeyIndexedAccess(ast, name)), []) + case "Suspend": + return pick(ast.f(), keys) + case "Refinement": + return pick(ast.from, keys) + case "Transformation": { + switch (ast.transformation._tag) { + case "ComposeTransformation": + return new Transformation( + pick(ast.from, keys), + pick(ast.to, keys), + composeTransformation + ) + case "TypeLiteralTransformation": { + const ts: Array = [] + const fromKeys: Array = [] + for (const k of keys) { + const t = ast.transformation.propertySignatureTransformations.find((t) => t.to === k) + if (t) { + ts.push(t) + fromKeys.push(t.from) + } else { + fromKeys.push(k) + } + } + return Arr.isNonEmptyReadonlyArray(ts) ? + new Transformation( + pick(ast.from, fromKeys), + pick(ast.to, keys), + new TypeLiteralTransformation(ts) + ) : + pick(ast.from, fromKeys) } - throw new Error(errors_.getASTUnsupportedSchema(ast)) } } } - return new TypeLiteral(keys.map((key) => getPropertyKeyIndexedAccess(ast, key)), []) + throw new Error(errors_.getASTUnsupportedSchema(ast)) } /** diff --git a/packages/schema/test/AST/pick.test.ts b/packages/schema/test/AST/pick.test.ts index 806e09c1f2..7f5701ddd6 100644 --- a/packages/schema/test/AST/pick.test.ts +++ b/packages/schema/test/AST/pick.test.ts @@ -3,12 +3,36 @@ import * as S from "@effect/schema/Schema" import { describe, expect, it } from "vitest" describe("pick", () => { - it("TypeLiteral", async () => { + it("refinement", async () => { + const schema = S.Struct({ a: S.NumberFromString, b: S.Number }).pipe(S.filter(() => true)) + const ast = schema.pipe(S.pick("a")).ast + expect(ast).toStrictEqual(S.Struct({ a: S.NumberFromString }).ast) + }) + + it("struct", async () => { const schema = S.Struct({ a: S.NumberFromString, b: S.Number }) const ast = schema.pipe(S.pick("a")).ast expect(ast).toStrictEqual(S.Struct({ a: S.NumberFromString }).ast) }) + it("struct + record", async () => { + const schema = S.Struct( + { a: S.NumberFromString, b: S.Number }, + S.Record({ key: S.String, value: S.Union(S.String, S.Number) }) + ) + const ast = schema.pipe(S.pick("a", "c")).ast + expect(ast).toStrictEqual(S.Struct({ a: S.NumberFromString, c: S.Union(S.String, S.Number) }).ast) + }) + + it("union", async () => { + const A = S.Struct({ a: S.String }) + const B = S.Struct({ a: S.Number }) + const schema = S.Union(A, B) + const pick = schema.pipe(S.pick("a")) + const ast = pick.ast + expect(ast).toStrictEqual(S.Struct({ a: S.Union(S.String, S.Number) }).ast) + }) + describe("transformation", () => { it("ComposeTransformation", async () => { const schema = S.compose( @@ -41,11 +65,22 @@ describe("pick", () => { }) }) - it("with SurrogateAnnotation", async () => { - class A extends S.Class("A")({ a: S.NumberFromString, b: S.Number }) {} - const schema = A - const ast = schema.pipe(S.pick("a")).ast - expect(ast).toStrictEqual(S.Struct({ a: S.NumberFromString }).ast) + describe("SurrogateAnnotation", () => { + it("a single Class", async () => { + class A extends S.Class("A")({ a: S.NumberFromString, b: S.Number }) {} + const schema = A + const ast = schema.pipe(S.pick("a")).ast + expect(ast).toStrictEqual(S.Struct({ a: S.NumberFromString }).ast) + }) + + it("a union of Classes", async () => { + class A extends S.Class("A")({ a: S.Number }) {} + class B extends S.Class("B")({ a: S.String }) {} + const schema = S.Union(A, B) + const pick = schema.pipe(S.pick("a")) + const ast = pick.ast + expect(ast).toStrictEqual(S.Struct({ a: S.Union(S.Number, S.String) }).ast) + }) }) }) }) diff --git a/packages/schema/test/Schema/pick.test.ts b/packages/schema/test/Schema/pick.test.ts index 7b582020a9..02392f48c6 100644 --- a/packages/schema/test/Schema/pick.test.ts +++ b/packages/schema/test/Schema/pick.test.ts @@ -13,19 +13,19 @@ describe("pick", () => { await Util.expectDecodeUnknownFailure( schema, null, - "Expected { readonly Symbol(@effect/schema/test/a): string; readonly b: NumberFromString }, actual null" + "Expected { readonly b: NumberFromString; readonly Symbol(@effect/schema/test/a): string }, actual null" ) await Util.expectDecodeUnknownFailure( schema, { [a]: "a" }, - `{ readonly Symbol(@effect/schema/test/a): string; readonly b: NumberFromString } + `{ readonly b: NumberFromString; readonly Symbol(@effect/schema/test/a): string } └─ ["b"] └─ is missing` ) await Util.expectDecodeUnknownFailure( schema, - { b: 1 }, - `{ readonly Symbol(@effect/schema/test/a): string; readonly b: NumberFromString } + { b: "1" }, + `{ readonly b: NumberFromString; readonly Symbol(@effect/schema/test/a): string } └─ [Symbol(@effect/schema/test/a)] └─ is missing` )