Skip to content

Commit

Permalink
Merge pull request #57 from anh-tong/fix-ci-and-test
Browse files Browse the repository at this point in the history
Fix CI and log signature test in `word` mode
  • Loading branch information
anh-tong authored Mar 27, 2024
2 parents 539e160 + 28db27d commit 85b3a31
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 30 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ jobs:
python -m pip install --upgrade setuptools pip wheel
python -m pip install .[test]
- name: Install heavy deps for testing (signatory + torch)
- name: Install iisignature for testing
run: |
python -m pip install torch==1.9.0
python -m pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall
python -m pip install iisignature
- name: Test package
run: >-
Expand Down
79 changes: 75 additions & 4 deletions tests/test_against_iisignature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest
from numpy.random import default_rng

from signax import logsignature, signature
from signax import logsignature, multi_signature_combine, signature
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()

Expand All @@ -32,14 +33,84 @@ def test_signature(depth, length, dim, stream):
assert jnp.allclose(jax_signature, iis_signature)


def test_multi_signature_combine():
# iisignature does not support multiple signature combination
# we only test for the case of combining two signatures
# note: this test is passed in signatory before migrating to iisignature
n_signatures = 2
dim = 5
signatures = [
rng.standard_normal((n_signatures, dim)),
rng.standard_normal((n_signatures, dim, dim)),
rng.standard_normal((n_signatures, dim, dim, dim)),
]

jax_signatures = [jnp.array(x) for x in signatures]

jax_output = multi_signature_combine(jax_signatures)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_output])

iis_signatures = []
for i in range(n_signatures):
tensors = [np.asarray(x[i]) for x in signatures]
current = np.concatenate([t.flatten() for t in tensors])
current = current[None, :]
iis_signatures.append(current)

iis_output = iisignature.sigcombine(
iis_signatures[0],
iis_signatures[1],
dim,
len(signatures),
)
iis_output = jnp.array(iis_output)
assert jnp.allclose(jax_output, iis_output)


@pytest.mark.parametrize("stream", [True, False])
def test_signature_batch(stream):
depth = 3

# no remainder case
length = 1001
dim = 5
n_chunks = 10

path = rng.standard_normal((length, dim))
jax_signature = signature(
path, depth=depth, num_chunks=n_chunks, stream=stream, flatten=True
)

iis_signature = (
iisignature.sig(np.asarray(path), depth)
if not stream
else iisignature.sig(np.asarray(path), depth, 2)
)
iis_signature = jnp.asarray(iis_signature)

assert jnp.allclose(jax_signature, iis_signature)


@pytest.mark.parametrize(
("depth", "length", "dim", "stream"), [(1, 2, 2, False), (3, 3, 5, False)]
)
def test_logsignature(depth, length, dim, stream):
batch_size = 10
path = rng.standard_normal((batch_size, length, dim))
jax_logsignature = logsignature(path, depth=depth, stream=stream, flatten=True)
s = iisignature.prepare(dim, depth, "O")
iis_logsignature = iisignature.logsig(np.asarray(path), s)
iis_logsignature = jnp.asarray(iis_logsignature)

# get expanded version of log signature
s = iisignature.prepare(dim, depth, "x")
iis_logsignature = iisignature.logsig(np.asarray(path), s, "x")

def _compress(expanded_log_signature):
# convert expanded array as list of arrays
expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth)
indices = lyndon_words(depth, dim)
compressed = compress(expanded_log_signature, indices)
compressed = jnp.concatenate(compressed)
return compressed

iis_logsignature = jax.vmap(_compress)(iis_logsignature)

assert jnp.allclose(jax_logsignature, iis_logsignature, atol=5e-1, rtol=5e-1)
49 changes: 26 additions & 23 deletions tests/test_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import numpy as np
from numpy.random import default_rng

from signax import signature, signature_to_logsignature
from signax.tensor_ops import (
addcmul,
mult,
mult_fused_restricted_exp,
otimes,
restricted_exp,
)
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()

Expand Down Expand Up @@ -110,26 +112,27 @@ def test_mult():
assert jnp.allclose(iis_signature, jax_output)


# def test_log():
# """Test log via signature_to_logsignature"""
# depth = 4
# length, dim = 3, 2
# path = rng.standard_normal((length, dim))
# jax_path = jnp.array(path)
# jax_signature = signature(jax_path, depth, flatten=False)
# jax_logsignature = signature_to_logsignature(jax_signature)
# jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature])

# torch_signature = signatory.signature(
# torch.tensor(path)[None, ...],
# depth,
# )
# torch_logsignature = signatory.signature_to_logsignature(
# torch_signature,
# dim,
# depth,
# )

# torch_output = jnp.array(torch_logsignature.numpy())

# assert jnp.allclose(torch_output, jax_output)
def test_log():
"""Test log via signature_to_logsignature"""
depth = 4
length, dim = 3, 2
path = rng.standard_normal((length, dim))
jax_path = jnp.array(path)
jax_signature = signature(jax_path, depth, flatten=False)
jax_logsignature = signature_to_logsignature(jax_signature)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature])

s = iisignature.prepare(dim, depth, "x")
iis_logsignature = iisignature.logsig(np.asarray(path), s, "x")

def _compress(expanded_log_signature):
# convert expanded array as list of arrays
expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth)
indices = lyndon_words(depth, dim)
compressed = compress(expanded_log_signature, indices)
compressed = jnp.concatenate(compressed)
return compressed

iis_logsignature = _compress(iis_logsignature)

assert jnp.allclose(iis_logsignature, jax_output)

0 comments on commit 85b3a31

Please sign in to comment.