Skip to content

Commit

Permalink
refactor(logic): harmonize term to/from bytes functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ccamel committed Jan 6, 2024
1 parent 0fe855c commit 1e262b0
Show file tree
Hide file tree
Showing 14 changed files with 637 additions and 355 deletions.
11 changes: 3 additions & 8 deletions x/logic/predicate/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func Bech32Address(vm *engine.VM, address, bech32 engine.Term, cont engine.Cont,
if err != nil {
return engine.Error(fmt.Errorf("bech32_address/2: failed to decode Bech32: %w", err))
}
pair := AtomPair.Apply(util.StringToTerm(h), BytesToList(a))
pair := AtomPair.Apply(util.StringToTerm(h), util.BytesToStringTermDefault(a))
return engine.Unify(vm, address, pair, cont, env)
default:
return engine.Error(fmt.Errorf("bech32_address/2: invalid Bech32 type: %T, should be Atom or Variable", b))
Expand All @@ -71,12 +71,7 @@ func addressPairToBech32(addressPair engine.Compound, env *engine.Env) (string,

switch a := env.Resolve(addressPair.Arg(1)).(type) {
case engine.Compound:
if a.Arity() != 2 || a.Functor().String() != "." {
return "", fmt.Errorf("address should be a List of bytes")
}

iter := engine.ListIterator{List: a, Env: env}
data, err := ListToBytes(iter, env)
data, err := util.StringTermToBytes(a, "", env)
if err != nil {
return "", fmt.Errorf("failed to convert term to bytes list: %w", err)
}
Expand All @@ -91,6 +86,6 @@ func addressPairToBech32(addressPair engine.Compound, env *engine.Env) (string,

return b, nil
default:
return "", fmt.Errorf("address should be a Pair with a List of bytes in arity 2, give %T", addressPair.Arg(1))
return "", fmt.Errorf("address should be a Pair with a List of bytes in arity 2, given: %T", addressPair.Arg(1))
}
}
6 changes: 3 additions & 3 deletions x/logic/predicate/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func TestBech32(t *testing.T) {
},
{
query: `bech32_address(-('okp4', ['8956',167,23,244,162,175,49,162,170,15,181,141,68,134,141,168,18,56,247,30]), Bech32).`,
wantError: fmt.Errorf("bech32_address/2: failed to convert term to bytes list: invalid term type in list at position 1: engine.Atom, only engine.Integer allowed"),
wantError: fmt.Errorf("bech32_address/2: failed to convert term to bytes list: invalid character_code '8956' value in list at position 1: should be a single character"),
wantSuccess: false,
},
{
Expand All @@ -99,12 +99,12 @@ func TestBech32(t *testing.T) {
},
{
query: `bech32_address(-('okp4', hey(2)), Bech32).`,
wantError: fmt.Errorf("bech32_address/2: address should be a List of bytes"),
wantError: fmt.Errorf("bech32_address/2: failed to convert term to bytes list: invalid compound term: expected a list of character_code or integer"),
wantSuccess: false,
},
{
query: `bech32_address(-('okp4', 'foo'), Bech32).`,
wantError: fmt.Errorf("bech32_address/2: address should be a Pair with a List of bytes in arity 2, give engine.Atom"),
wantError: fmt.Errorf("bech32_address/2: address should be a Pair with a List of bytes in arity 2, given: engine.Atom"),
wantSuccess: false,
},
{
Expand Down
21 changes: 21 additions & 0 deletions x/logic/predicate/atom.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ var (

// AtomNull is the term null.
AtomNull = engine.NewAtom("null")

// AtomEncoding is the term used to indicate the encoding type option.
AtomEncoding = engine.NewAtom("encoding")

// AtomUtf8 is the term used to indicate the UTF-8 encoding type option.
AtomUtf8 = engine.NewAtom("utf8")

// AtomHex is the term used to indicate the hexadecimal encoding type option.
AtomHex = engine.NewAtom("hex")

// AtomOctet is the term used to indicate the byte encoding type option.
AtomOctet = engine.NewAtom("octet")

// AtomCharset is the term used to indicate the charset encoding type option.
AtomCharset = engine.NewAtom("charset")

// AtomPadding is the term used to indicate the padding encoding type option.
AtomPadding = engine.NewAtom("padding")

// AtomAs is the term used to indicate the as encoding type option.
AtomAs = engine.NewAtom("as")
)

// MakeNull returns the compound term @(null).
Expand Down
4 changes: 2 additions & 2 deletions x/logic/predicate/bank.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ func fetchBalances(
func(ctx context.Context) *engine.Promise {
return engine.Unify(
vm,
Tuple(engine.NewAtom(address), CoinsToTerm(coins)),
Tuple(account, balances),
util.Tuple(engine.NewAtom(address), CoinsToTerm(coins)),
util.Tuple(account, balances),
cont,
env,
)
Expand Down
95 changes: 21 additions & 74 deletions x/logic/predicate/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package predicate

import (
"context"
"encoding/hex"
"fmt"
"slices"
"strings"
Expand Down Expand Up @@ -58,9 +57,9 @@ func CryptoDataHash(
algorithmOpt := engine.NewAtom("algorithm")

return engine.Delay(func(ctx context.Context) *engine.Promise {
algorithmAtom, err := getOptionAsAtomWithDefault(algorithmOpt, options, engine.NewAtom("sha256"), env, functor)
algorithmAtom, err := util.GetOptionAsAtomWithDefault(algorithmOpt, options, engine.NewAtom("sha256"), env)
if err != nil {
return engine.Error(err)
return engine.Error(fmt.Errorf("%s: %w", functor, err))
}
algorithm, err := util.ParseHashAlg(algorithmAtom.String())
if err != nil {
Expand All @@ -69,7 +68,7 @@ func CryptoDataHash(
algorithmAtom.String(),
util.HashAlgNames()))
}
decodedData, err := TermToBytes(data, options, AtomUtf8, env)
decodedData, err := termToBytes(data, options, AtomUtf8, env)
if err != nil {
return engine.Error(fmt.Errorf("%s: failed to decode data: %w", functor, err))
}
Expand All @@ -79,62 +78,7 @@ func CryptoDataHash(
return engine.Error(fmt.Errorf("%s: failed to hash data: %w", functor, err))
}

return engine.Unify(vm, hash, BytesToList(result), cont, env)
})
}

// HexBytes is a predicate that unifies hexadecimal encoded bytes to a list of bytes.
//
// The signature is as follows:
//
// hex_bytes(?Hex, ?Bytes) is det
//
// Where:
// - Hex is an Atom, string or list of characters in hexadecimal encoding.
// - Bytes is the list of numbers between 0 and 255 that represent the sequence of bytes.
//
// Examples:
//
// # Convert hexadecimal atom to list of bytes.
// - hex_bytes('2c26b46b68ffc68ff99b453c1d3041341342d706483bfa0f98a5e886266e7ae', Bytes).
func HexBytes(vm *engine.VM, hexa, bts engine.Term, cont engine.Cont, env *engine.Env) *engine.Promise {
return engine.Delay(func(ctx context.Context) *engine.Promise {
var result []byte

switch h := env.Resolve(hexa).(type) {
case engine.Variable:
case engine.Atom:
src := []byte(h.String())
result = make([]byte, hex.DecodedLen(len(src)))
_, err := hex.Decode(result, src)
if err != nil {
return engine.Error(fmt.Errorf("hex_bytes/2: failed decode hexadecimal %w", err))
}
default:
return engine.Error(fmt.Errorf("hex_bytes/2: invalid hex type: %T, should be Atom or Variable", h))
}

switch b := env.Resolve(bts).(type) {
case engine.Variable:
if result == nil {
return engine.Error(fmt.Errorf("hex_bytes/2: nil hexadecimal conversion in input"))
}
return engine.Unify(vm, bts, BytesToList(result), cont, env)
case engine.Compound:
if b.Arity() != 2 || b.Functor().String() != "." {
return engine.Error(fmt.Errorf("hex_bytes/2: bytes should be a List, give %T", b))
}
iter := engine.ListIterator{List: b, Env: env}

src, err := ListToBytes(iter, env)
if err != nil {
return engine.Error(fmt.Errorf("hex_bytes/2: failed convert list into bytes: %w", err))
}
dst := hex.EncodeToString(src)
return engine.Unify(vm, hexa, util.StringToTerm(dst), cont, env)
default:
return engine.Error(fmt.Errorf("hex_bytes/2: invalid hex type: %T, should be Variable or List", b))
}
return engine.Unify(vm, hash, util.BytesToStringTermDefault(result), cont, env)
})
}

Expand Down Expand Up @@ -227,7 +171,7 @@ func xVerify(functor string, key, data, sig, options engine.Term, defaultAlgo ut
if err != nil {
return engine.Error(fmt.Errorf("%s: %w", functor, err))
}
typeAtom, err := util.ResolveToAtom(env, typeTerm)
typeAtom, err := util.AssertAtom(env, typeTerm)
if err != nil {
return engine.Error(fmt.Errorf("%s: %w", functor, err))
}
Expand All @@ -239,17 +183,17 @@ func xVerify(functor string, key, data, sig, options engine.Term, defaultAlgo ut
strings.Join(util.Map(algos, func(a util.KeyAlg) string { return a.String() }), ", ")))
}

decodedKey, err := TermToBytes(key, AtomEncoding.Apply(AtomOctet), AtomHex, env)
decodedKey, err := termToBytes(key, AtomEncoding.Apply(AtomOctet), AtomHex, env)
if err != nil {
return engine.Error(fmt.Errorf("%s: failed to decode public key: %w", functor, err))
}

decodedData, err := TermToBytes(data, options, AtomHex, env)
decodedData, err := termToBytes(data, options, AtomHex, env)
if err != nil {
return engine.Error(fmt.Errorf("%s: failed to decode data: %w", functor, err))
}

decodedSignature, err := TermToBytes(sig, AtomEncoding.Apply(AtomOctet), AtomHex, env)
decodedSignature, err := termToBytes(sig, AtomEncoding.Apply(AtomOctet), AtomHex, env)
if err != nil {
return engine.Error(fmt.Errorf("%s: failed to decode signature: %w", functor, err))
}
Expand All @@ -267,19 +211,22 @@ func xVerify(functor string, key, data, sig, options engine.Term, defaultAlgo ut
})
}

// getOptionAsAtomWithDefault is a helper function that returns the value of the first option with the given name in the
// given options.
func getOptionAsAtomWithDefault(algorithmOpt engine.Atom, options engine.Term, defaultValue engine.Term, env *engine.Env,
functor string,
) (engine.Atom, error) {
term, err := util.GetOptionWithDefault(algorithmOpt, options, defaultValue, env)
func termToBytes(term, options, defaultEncoding engine.Term, env *engine.Env) ([]byte, error) {
encodingTerm, err := util.GetOptionWithDefault(AtomEncoding, options, defaultEncoding, env)
if err != nil {
return util.AtomEmpty, fmt.Errorf("%s: %w", functor, err)
return nil, err
}
atom, err := util.ResolveToAtom(env, term)
encodingAtom, err := util.AssertAtom(env, encodingTerm)
if err != nil {
return util.AtomEmpty, fmt.Errorf("%s: %w", functor, err)
return nil, err
}

return atom, nil
switch encodingAtom {
case AtomHex:
return util.TermHexToBytes(term, env)
case AtomOctet, AtomUtf8:
return util.StringTermToBytes(term, "", env)
default:
return nil, fmt.Errorf("invalid encoding: %s. Possible values: hex, octet", encodingAtom.String())
}
}
41 changes: 0 additions & 41 deletions x/logic/predicate/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,47 +29,6 @@ func TestCryptoOperations(t *testing.T) {
wantError error
wantSuccess bool
}{
{
query: `hex_bytes(Hex,
[44,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`,
wantResult: []types.TermResults{{"Hex": "'2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae'"}},
wantSuccess: true,
},
{
query: `hex_bytes('2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae', Bytes).`,
wantResult: []types.TermResults{{
"Bytes": "[44,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]",
}},
wantSuccess: true,
},
{
query: `hex_bytes('2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae',
[44,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`,
wantResult: []types.TermResults{{}},
wantSuccess: true,
},
{
query: `hex_bytes('3c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae',
[44,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`,
wantSuccess: false,
},
{
query: `hex_bytes('fail',
[44,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`,
wantError: fmt.Errorf("hex_bytes/2: failed decode hexadecimal encoding/hex: invalid byte: U+0069 'i'"),
wantSuccess: false,
},
{
query: `hex_bytes('2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae',
[45,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`,
wantSuccess: false,
},
{
query: `hex_bytes('2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae',
[345,38,'hey',107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`,
wantSuccess: false,
wantError: fmt.Errorf("hex_bytes/2: failed convert list into bytes: invalid integer value in list at position 1: 345 is out of byte range (0-255)"),
},
{
program: `test(Hex) :- crypto_data_hash('hello world', Hash, []), hex_bytes(Hex, Hash).`,
query: `test(Hex).`,
Expand Down
2 changes: 1 addition & 1 deletion x/logic/predicate/did.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func processSegment(segments engine.Compound, segmentNumber uint8, fn func(segme
if _, ok := term.(engine.Variable); ok {
return nil
}
segment, err := util.ResolveToAtom(env, segments.Arg(int(segmentNumber)))
segment, err := util.AssertAtom(env, segments.Arg(int(segmentNumber)))
if err != nil {
return fmt.Errorf("failed to resolve atom at segment %d: %w", segmentNumber, err)
}
Expand Down
61 changes: 61 additions & 0 deletions x/logic/predicate/encoding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package predicate

import (
"context"
"encoding/hex"
"fmt"

"github.com/ichiban/prolog/engine"

"github.com/okp4/okp4d/x/logic/util"
)

// HexBytes is a predicate that unifies hexadecimal encoded bytes to a list of bytes.
//
// The signature is as follows:
//
// hex_bytes(?Hex, ?Bytes) is det
//
// Where:
// - Hex is an Atom, string or list of characters in hexadecimal encoding.
// - Bytes is the list of numbers between 0 and 255 that represent the sequence of bytes.
//
// Examples:
//
// # Convert hexadecimal atom to list of bytes.
// - hex_bytes('2c26b46b68ffc68ff99b453c1d3041341342d706483bfa0f98a5e886266e7ae', Bytes).
func HexBytes(vm *engine.VM, hexa, bts engine.Term, cont engine.Cont, env *engine.Env) *engine.Promise {
return engine.Delay(func(ctx context.Context) *engine.Promise {
var result []byte

switch h := env.Resolve(hexa).(type) {
case engine.Variable:
case engine.Atom:
src := []byte(h.String())
result = make([]byte, hex.DecodedLen(len(src)))
_, err := hex.Decode(result, src)
if err != nil {
return engine.Error(fmt.Errorf("hex_bytes/2: failed decode hexadecimal %w", err))
}
default:
return engine.Error(fmt.Errorf("hex_bytes/2: invalid hex type: %T, should be Atom or Variable", h))
}

switch b := env.Resolve(bts).(type) {
case engine.Variable:
if result == nil {
return engine.Error(fmt.Errorf("hex_bytes/2: nil hexadecimal conversion in input"))
}
return engine.Unify(vm, bts, util.BytesToStringTermDefault(result), cont, env)
case engine.Compound:
src, err := util.StringTermToBytes(b, "", env)
if err != nil {
return engine.Error(fmt.Errorf("hex_bytes/2: %w", err))
}
dst := hex.EncodeToString(src)
return engine.Unify(vm, hexa, util.StringToTerm(dst), cont, env)
default:
return engine.Error(fmt.Errorf("hex_bytes/2: invalid hex type: %T, should be Variable or List", b))
}
})
}
Loading

0 comments on commit 1e262b0

Please sign in to comment.