diff --git a/Sources/starkbank-ecdsa/Curve.swift b/Sources/starkbank-ecdsa/Curve.swift index 36a114f..d6eddfc 100644 --- a/Sources/starkbank-ecdsa/Curve.swift +++ b/Sources/starkbank-ecdsa/Curve.swift @@ -35,7 +35,16 @@ public class CurveFp { - Returns: boolean */ func contains(p: Point) -> Bool { - return (p.y.power(2) - (p.x.power(3) + self.A * p.x + self.B)) % self.P == 0 + if (p.x < 0 || p.x >= self.P) { + return false + } + if (p.y < 0 || p.y >= self.P) { + return false + } + if ((p.y.power(2) - (p.x.power(3) + self.A * p.x + self.B)) % self.P != 0) { + return false + } + return true } func length() -> Int { diff --git a/Sources/starkbank-ecdsa/Ecdsa.swift b/Sources/starkbank-ecdsa/Ecdsa.swift index 3776e79..f1f1bbc 100644 --- a/Sources/starkbank-ecdsa/Ecdsa.swift +++ b/Sources/starkbank-ecdsa/Ecdsa.swift @@ -39,8 +39,10 @@ public class Ecdsa { let inv = Math.inv(s, curve.N) let u1 = Math.multiply(curve.G, (numberMessage * inv).modulus(curve.N), curve.N, curve.A, curve.P) let u2 = Math.multiply(publicKey.point, (r * inv).modulus(curve.N), curve.N, curve.A, curve.P) - let add = Math.add(u1, u2, curve.A, curve.P) - let modX = add.x.modulus(curve.N) - return r == modX + let v = Math.add(u1, u2, curve.A, curve.P) + if (v.isAtInfinity()) { + return false + } + return v.x.modulus(curve.N) == r } } diff --git a/Sources/starkbank-ecdsa/Point.swift b/Sources/starkbank-ecdsa/Point.swift index f843e03..35507d3 100644 --- a/Sources/starkbank-ecdsa/Point.swift +++ b/Sources/starkbank-ecdsa/Point.swift @@ -20,4 +20,8 @@ public class Point { self.y = y self.z = z } + + func isAtInfinity() -> Bool { + return self.y == BigInt(0) + } } diff --git a/Sources/starkbank-ecdsa/PublicKey.swift b/Sources/starkbank-ecdsa/PublicKey.swift index 71dad47..055a1cb 100644 --- a/Sources/starkbank-ecdsa/PublicKey.swift +++ b/Sources/starkbank-ecdsa/PublicKey.swift @@ -76,13 +76,26 @@ public class PublicKey { let point = Point(BinaryAscii.intFromHex(xs), BinaryAscii.intFromHex(ys)) - if (validatePoint && !curve.contains(p: point)) { + let publicKey = PublicKey(point: point, curve: curve) + if (!validatePoint) { + return publicKey + } + if (point.isAtInfinity()) { + throw Error.infinityError("Public Key point is at infinity") + } + if (!curve.contains(p: point)) { throw Error.pointError("Point ({x},{y}) is not valid for curve {name}" .replacingOccurrences(of: "{x}", with: String(point.x)) - .replacingOccurrences(of:"{y}", with: String(point.y)) - .replacingOccurrences(of:"{name}", with: curve.name)) + .replacingOccurrences(of: "{y}", with: String(point.y)) + .replacingOccurrences(of: "{name}", with: curve.name)) + } + if (!Math.multiply(point, curve.N, curve.N, curve.A, curve.P).isAtInfinity()) { + throw Error.pointError("Point ({x},{y}) * {name}.N is not at infinity" + .replacingOccurrences(of: "{x}", with: String(point.x)) + .replacingOccurrences(of: "{y}", with: String(point.y)) + .replacingOccurrences(of: "{name}", with: curve.name)) } - return PublicKey(point: point, curve: curve) + return publicKey } } diff --git a/Sources/starkbank-ecdsa/Utils/Error.swift b/Sources/starkbank-ecdsa/Utils/Error.swift index 5021d76..78ebab5 100644 --- a/Sources/starkbank-ecdsa/Utils/Error.swift +++ b/Sources/starkbank-ecdsa/Utils/Error.swift @@ -15,4 +15,5 @@ enum Error: Swift.Error { case pointError(String) case generationError(String) case invalidPath(String) + case infinityError(String) } diff --git a/Tests/starkbank-ecdsaTests/EcdsaTests.swift b/Tests/starkbank-ecdsaTests/EcdsaTests.swift index 95e7e1f..e107939 100644 --- a/Tests/starkbank-ecdsaTests/EcdsaTests.swift +++ b/Tests/starkbank-ecdsaTests/EcdsaTests.swift @@ -33,4 +33,12 @@ class EcdsaTests: XCTestCase { XCTAssertFalse(Ecdsa.verify(message: message2, signature: signature, publicKey: publicKey)) } + + func testSignatureZero() throws { + let privateKey = try PrivateKey() + let publicKey = privateKey.publicKey() + let message = "This is a text message" + + XCTAssertFalse(Ecdsa.verify(message: message, signature: Signature(0, 0), publicKey: publicKey)) + } }