
strange hanging behaviour with pmap #5065

Closed · Joshuaalbert opened this issue Dec 1, 2020 · 6 comments

@Joshuaalbert (Contributor)

When pmap is called repeatedly, it hangs at some point, and that point appears to be deterministic.

context

I'm using pmap (not soft_pmap) to distribute tasks over devices. I've created a chunked_pmap to do this:

import jax.numpy as jnp
from jax import pmap, local_device_count, tree_map, tree_multimap

def chunked_pmap(f, *args, chunksize=None):
    """
    Calls pmap on chunks of moderate work to be distributed over devices.
    Automatically handles non-dividing chunk sizes by adding filler elements.

    Args:
        f: callable
        *args: arguments to map down the first dimension
        chunksize: optional chunk size, else the number of local devices

    Returns: pytree of mapped results.
    """
    if chunksize is None:
        chunksize = local_device_count()
    N = args[0].shape[0]
    remainder = N % chunksize
    # Add filler elements (wrapping around to the start of each arg) so that
    # every chunk has exactly `chunksize` elements and tree_multimap works
    # nicely at the end.
    extra = 0
    if (remainder != 0) and (N > chunksize):
        extra = chunksize - remainder
        args = [jnp.concatenate([arg, arg[:extra]], axis=0) for arg in args]
        N = args[0].shape[0]
    results = []
    for start in range(0, N, chunksize):
        stop = min(start + chunksize, N)
        print("Starting slice {}:{}".format(start, stop))
        results.append(pmap(f)(*[arg[start:stop] for arg in args]))
    # Merge the chunked results.
    result = tree_multimap(lambda *args: jnp.concatenate(args, axis=0), *results)
    # Remove the filler elements.
    if extra != 0:
        result = tree_map(lambda x: x[:-extra], result)
    return result

problem specifics

I am using chunked_pmap to distribute moderate workloads to devices. I have a function unconstrained_solve and arguments for it that get distributed like this:

chunked_pmap(unconstrained_solve, random.split(random.PRNGKey(746583), T), Y_obs, chunksize=1)

When I run this with chunksize=1 (so that the problems just run sequentially), it hangs at chunk index 223 (out of many more).
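For reference, since the real unconstrained_solve is too large to include here, a minimal sketch of the call pattern with a hypothetical stand-in solve (dummy names, shapes, and problem count, not the actual workload) looks roughly like this:

import jax.numpy as jnp
from jax import random

# Hypothetical stand-in for the real unconstrained_solve (a ~0.1 s optimisation
# not shown in this issue); it only mirrors the per-chunk call signature.
def dummy_solve(key, y_obs):
    noise = random.normal(key, y_obs.shape)
    return jnp.sum((y_obs + noise) ** 2)

T = 1000  # assumed number of problems, for illustration only
Y_obs = jnp.ones((T, 8))  # assumed observation shape, for illustration only

result = chunked_pmap(dummy_solve,
                      random.split(random.PRNGKey(746583), T),
                      Y_obs,
                      chunksize=1)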

diagnosis effort

Below I describe the diagnostic angles I've tried, and I'd appreciate your input and help.

Is it a data/function problem or a pmap problem?

It seems to be a pmap problem: the problems being distributed by pmap run fine on their own. I have run unconstrained_solve on each slice of data and it computes nominally, i.e. this works:

from jax import jit, random

_unconstrained_solve = jit(unconstrained_solve)
for key, _Y_obs in zip(random.split(random.PRNGKey(746583), T), Y_obs):
    result1 = _unconstrained_solve(key, _Y_obs)
    result2 = unconstrained_solve(key, _Y_obs)
    assert result1 == result2

It works fine, as expected.

Is it dependent on chunksize?

The point at which it hangs depends on chunksize. I have 8 cores on my CPU, and here is what I get when I try different chunksizes:
When chunksize=1 it hangs on slice 223:224.
When chunksize=2 it hangs on slice 226:228.
When chunksize=3 it hangs on slice 315:318.
When chunksize=4 it hangs on slice 232:236.
When chunksize=5 it hangs on slice 235:240.
When chunksize=6 it hangs on slice 408:414.
When chunksize=7 it hangs on slice 441:448.
When chunksize=8 it hangs on slice 472:480.

Is there something special about the slices where it hangs?

I tried starting the distribution at later slices, but before the problematic ones, and I get these peculiar behaviours:

With chunksize=1 the problem was at slice 223:224, so I tried starting at slice 221:222:

chunked_pmap(unconstrained_solve, random.split(random.PRNGKey(746583), T)[221:], Y_obs[221:], chunksize=1)

This goes past 223 and hangs at 570:571. However, if I start at 222:223 (one slice later), it only makes it to 569:570 (one slice earlier) before hanging.

Other information

The numbers above are consistent: I've run the same code ~10 times and get the same hanging slices, so it seems to be deterministic (and therefore less likely to be a race condition).

The functions being distributed take approximately 0.1 seconds to run.

jax==0.2.6
jaxlib==0.1.57
@Joshuaalbert (Contributor, Author)

@mattjj any ideas what this could be, and is there something I can do to force pmap to play nicely when it's called rapidly?

@Joshuaalbert (Contributor, Author)

I guess this isn't receiving attention because it's too long and it's not clear where the problem stems from. I'm closing.

@hawkinsp (Collaborator) commented Dec 9, 2020

We would like to look into it, but we're busy and anything you can do to minimize the problem would help us out!

@hawkinsp (Collaborator) commented Dec 9, 2020

Rereading this, I don't think there's enough information to debug it. Can you give us something that we can run? I suspect you can strip out most of the math; it's probably only the communication structure that's relevant.

@Joshuaalbert (Contributor, Author)

@hawkinsp Hi Peter, I wasn't able to isolate it, for the reason described in #5117. I tried to isolate a small piece of code that replicated the problem: when I used disable_jit (so that I could put in print statements and find where the hanging begins), it strangely wouldn't hang anymore, nor would it hang with jit. That is the substance of #5117, which has a complete example showing the behaviour.
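For concreteness, the disable_jit attempt mentioned above amounts to something like the following sketch (assuming the chunked_pmap, unconstrained_solve, T and Y_obs from the issue body; the complete example lives in #5117):

import jax
from jax import random

# Run the same chunked_pmap call with jit disabled, so that Python-level
# print statements inside the solve execute eagerly and the last completed
# step is visible before any hang.
with jax.disable_jit():
    chunked_pmap(unconstrained_solve,
                 random.split(random.PRNGKey(746583), T),
                 Y_obs,
                 chunksize=1)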

@ethanabrooks
This is an issue for me as well. I'd appreciate any help resolving it.
