problems with numba ufunc + distributed #3450

Open
rabernat opened this issue Feb 6, 2020 · 29 comments · Fixed by numba/numba#9495

@rabernat

rabernat commented Feb 6, 2020

We have created a new software package called fastjmd95 that uses numba to accelerate computation of the ocean equation of state. Everything works fine with dask and a local scheduler. Now I want to run this code on a distributed dask cluster. It isn't working, I think because the workers are not able to deserialize the numba functions properly.

Original Full Example

This example with real data can be run on any pangeo cluster

from intake import open_catalog
from fastjmd95 import rho

cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
ds  = cat["SOSE"].to_dask()

rhonil = 1025
pa_to_dbar = 1.0/10000
p = ds.PHrefC * rhonil * pa_to_dbar
s = ds.SALT
t = ds.THETA
r = rho(s.data, t.data, 0)
# works fine with local scheduler
r_mean = r[:5].compute()

# now start distributed scheduler
from dask.distributed import Client
client = Client()
r_mean = r[:5].compute()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-7316322484d4> in <module>
----> 1 r_mean = r[:5].compute()

/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    163         dask.base.compute
    164         """
--> 165         (result,) = compute(self, traverse=False, **kwargs)
    166         return result
    167 

/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2571                     should_rejoin = False
   2572             try:
-> 2573                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2574             finally:
   2575                 for f in futures.values():

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1871                 direct=direct,
   1872                 local_worker=local_worker,
-> 1873                 asynchronous=asynchronous,
   1874             )
   1875 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    766         else:
    767             return sync(
--> 768                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    769             )
    770 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    332     if error[0]:
    333         typ, exc, tb = error[0]
--> 334         raise exc.with_traceback(tb)
    335     else:
    336         return result[0]

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in f()
    316             if callback_timeout is not None:
    317                 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 318             result[0] = yield future
    319         except Exception as exc:
    320             error[0] = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1727                             exc = CancelledError(key)
   1728                         else:
-> 1729                             raise exception.with_traceback(traceback)
   1730                         raise exc
   1731                     if errors == "skip":

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads()
     57 def loads(x):
     58     try:
---> 59         return pickle.loads(x)
     60     except Exception:
     61         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct()
    123     # scipy.special.expit for instance.
    124     mod = __import__(module, fromlist=[name])
--> 125     return getattr(mod, name)
    126 
    127 def _ufunc_reduce(func):

AttributeError: module '__main__' has no attribute 'rho'

Minimal Example

I believe this reproduces the core problem

import numpy as np
from numba import vectorize, float64, float32
import dask.array as dsa
from dask.distributed import Client
client = Client()

# define a numba ufunc
@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a**2

# verify that the client can run it
def try_numba_on_client():
    data = np.arange(5, dtype='f4')
    return test_numba(data)
client.run(try_numba_on_client)
# works, output is:
# > {'tcp://127.0.0.1:37583': array([ 0.,  1.,  4.,  9., 16.]),
# > 'tcp://127.0.0.1:44855': array([ 0.,  1.,  4.,  9., 16.])}

# use in a computation
data_dask = dsa.arange(5, dtype='f4')
test_numba(data_dask).compute()

At this point I get a KilledWorker error. In the worker log, I can see the following error:

distributed.worker - ERROR - module '__main__' has no attribute 'test_numba'
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/worker.py", line 905, in handle_scheduler
    comm, every_cycle=[self.ensure_communicating, self.ensure_computing]
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/core.py", line 456, in handle_stream
    msgs = await comm.read()
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/tcp.py", line 222, in read
    frames, deserialize=self.deserialize, deserializers=deserializers
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 69, in from_frames
    res = _from_frames()
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 55, in _from_frames
    frames, deserialize=deserialize, deserializers=deserializers
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/core.py", line 124, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 255, in deserialize
    deserializers=deserializers,
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 268, in deserialize
    return loads(header, frames)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 62, in pickle_loads
    return pickle.loads(b"".join(frames))
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py", line 59, in loads
    return pickle.loads(x)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py", line 125, in _ufunc_reconstruct
    return getattr(mod, name)
AttributeError: module '__main__' has no attribute 'test_numba'

The basic error appears to be the same as in the full example.

This seems like a pretty straightforward use of numba + distributed, and I assumed this sort of usage was supported. Am I missing something obvious?

Installed versions

I'm on dask 2.9.0 and numba 0.48.0.

>>> client.get_versions(check=True)
{'scheduler': {'host': (('python', '3.7.6.final.0'),
   ('python-bits', 64),
   ('OS', 'Linux'),
   ('OS-release', '4.19.76+'),
   ('machine', 'x86_64'),
   ('processor', 'x86_64'),
   ('byteorder', 'little'),
   ('LC_ALL', 'en_US.UTF-8'),
   ('LANG', 'en_US.UTF-8'),
   ('LOCALE', 'en_US.UTF-8')),
  'packages': {'required': (('dask', '2.9.0'),
    ('distributed', '2.9.0'),
    ('msgpack', '0.6.2'),
    ('cloudpickle', '1.2.2'),
    ('tornado', '6.0.3'),
    ('toolz', '0.10.0')),
   'optional': (('numpy', '1.17.3'),
    ('pandas', '0.25.3'),
    ('bokeh', '1.4.0'),
    ('lz4', '2.2.1'),
    ('dask_ml', '1.1.1'),
    ('blosc', '1.8.1'))}},
 'workers': {'tcp://10.32.181.10:45663': {'host': (('python', '3.7.6.final.0'),
    ('python-bits', 64),
    ('OS', 'Linux'),
    ('OS-release', '4.19.76+'),
    ('machine', 'x86_64'),
    ('processor', 'x86_64'),
    ('byteorder', 'little'),
    ('LC_ALL', 'en_US.UTF-8'),
    ('LANG', 'en_US.UTF-8'),
    ('LOCALE', 'en_US.UTF-8')),
   'packages': {'required': (('dask', '2.9.0'),
     ('distributed', '2.9.0'),
     ('msgpack', '0.6.2'),
     ('cloudpickle', '1.2.2'),
     ('tornado', '6.0.3'),
     ('toolz', '0.10.0')),
    'optional': (('numpy', '1.17.3'),
     ('pandas', '0.25.3'),
     ('bokeh', '1.4.0'),
     ('lz4', '2.2.1'),
     ('dask_ml', '1.1.1'),
     ('blosc', '1.8.1'))}},
  'tcp://10.32.181.11:37259': {'host': (('python', '3.7.6.final.0'),
    ('python-bits', 64),
    ('OS', 'Linux'),
    ('OS-release', '4.19.76+'),
    ('machine', 'x86_64'),
    ('processor', 'x86_64'),
    ('byteorder', 'little'),
    ('LC_ALL', 'en_US.UTF-8'),
    ('LANG', 'en_US.UTF-8'),
    ('LOCALE', 'en_US.UTF-8')),
   'packages': {'required': (('dask', '2.9.0'),
     ('distributed', '2.9.0'),
     ('msgpack', '0.6.2'),
     ('cloudpickle', '1.2.2'),
     ('tornado', '6.0.3'),
     ('toolz', '0.10.0')),
    'optional': (('numpy', '1.17.3'),
     ('pandas', '0.25.3'),
     ('bokeh', '1.4.0'),
     ('lz4', '2.2.1'),
     ('dask_ml', '1.1.1'),
     ('blosc', '1.8.1'))}}},
 'client': {'host': [('python', '3.7.6.final.0'),
   ('python-bits', 64),
   ('OS', 'Linux'),
   ('OS-release', '4.19.76+'),
   ('machine', 'x86_64'),
   ('processor', 'x86_64'),
   ('byteorder', 'little'),
   ('LC_ALL', 'en_US.UTF-8'),
   ('LANG', 'en_US.UTF-8'),
   ('LOCALE', 'en_US.UTF-8')],
  'packages': {'required': [('dask', '2.9.0'),
    ('distributed', '2.9.0'),
    ('msgpack', '0.6.2'),
    ('cloudpickle', '1.2.2'),
    ('tornado', '6.0.3'),
    ('toolz', '0.10.0')],
   'optional': [('numpy', '1.17.3'),
    ('pandas', '0.25.3'),
    ('bokeh', '1.4.0'),
    ('lz4', '2.2.1'),
    ('dask_ml', '1.1.1'),
    ('blosc', '1.8.1')]}}}
@rabernat
Author

rabernat commented Feb 6, 2020

And I just found numba/numba#4314, which seems to be the dual of this issue in the numba repo...

That issue suggests that the problem is just with dynamically defined functions (i.e. functions defined in the notebook interpreter), as in my minimal example. But I am still having the same problem with my full example, where the functions are defined in a module in an installed package.

@mrocklin
Member

mrocklin commented Feb 6, 2020

cc @seibert from numba and @TomAugspurger for the Anaconda/Pangeo connection

@rabernat
Author

rabernat commented Feb 7, 2020

Via gitter, @mrocklin pointed me to the official dask example with numba:
https://examples.dask.org/applications/stencils-with-numba.html

I have confirmed that the official example only works with a single multi-threaded worker. So the problem is related to multiprocessing and serialization.

More specifically, the cluster in that example is created like this:

from dask.distributed import Client, progress
client = Client(threads_per_worker=4,
                n_workers=1,
                processes=False,
                memory_limit='4GB')

If instead I change it to

client = Client(n_workers=4,
                memory_limit='4GB')

The workers die with the error message

File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py", line 59, in loads
    return pickle.loads(x)
File "/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py", line 125, in _ufunc_reconstruct
    return getattr(mod, name)
AttributeError: module '__main__' has no attribute 'smooth'

My conclusion from this is that numba functions are generally incompatible with distributed because they don't serialize correctly. Can anyone from numba provide some confirmation on this?

@rabernat
Author

rabernat commented Feb 7, 2020

Related issue in numba, closed a while back, suggesting this should be resolved: numba/numba#2943

@TomAugspurger
Member

Looking into this now.

But I am still having the same problem with my full example, where the functions are defined in a module in an installed package.

@rabernat as a quick workaround, you might be able to fix this by importing the functions like

from mypackage import test_numba
test_numba(...)

rather than

import mypackage

mypackage.test_numba(...)

Will figure out a proper fix.

@TomAugspurger
Member

TomAugspurger commented Feb 11, 2020

I have a slightly better understanding of the situation now. The call order is something like

numba_ufunc(dask_array) ->
  numba_ufunc.ufunc(dask_array) ->
    dask_array.__array_ufunc__(...)

The test_numba.ufunc is a NumPy ufunc that is (I think) dynamically generated by numba.

In [2]: b.test_numba
Out[2]: <numba._DUFunc 'test_numba'>

In [3]: b.test_numba.ufunc
Out[3]: <ufunc 'test_numba'>

In [4]: type(b.test_numba.ufunc)
Out[4]: numpy.ufunc

And that's what chokes up dask's serialization

In [9]: pickle.loads(pickle.dumps(b.test_numba))
Out[9]: <numba._DUFunc 'test_numba'>

In [10]: pickle.loads(pickle.dumps(b.test_numba.ufunc))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-10-cff7e39b4aa8> in <module>
----> 1 pickle.loads(pickle.dumps(b.test_numba.ufunc))

~/Envs/dask-dev/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct(module, name)
    130     # scipy.special.expit for instance.
    131     mod = __import__(module, fromlist=[name])
--> 132     return getattr(mod, name)
    133
    134 def _ufunc_reduce(func):

AttributeError: module '__main__' has no attribute 'test_numba'
Here b.py is:

# b.py
from numba import vectorize, float64, float32


@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a**2

Will start looking for solutions now.

@rabernat
Author

Thanks so much for looking into this. Before you dig too deep, it's worth confirming whether the problem in my toy example is indeed the same one as in the full example. The key difference is that in the full example, the function is defined in a module in a package, which is installed with pip.

@TomAugspurger
Member

TomAugspurger commented Feb 11, 2020

Thanks, my hope is that my file b.py sufficiently simulates a 3rd-party library function. I'm able to reproduce the same error with

import pickle
import fastjmd95

if __name__ == "__main__":
    pickle.loads(pickle.dumps(fastjmd95.rho.ufunc))

@TomAugspurger
Member

TomAugspurger commented Feb 11, 2020

@sklam if we wanted to support this, it seems like we'd need to ensure that DUFunc.ufunc is pickleable. IIUC, that's created at https://github.com/numba/numba/blob/fd8c232bb37a1945e8dc8becef9fe05fdd78c4cf/numba/np/ufunc/_internal.c#L227-L229. Does that sound right?

Or perhaps there's a deeper issue with the generated ufunc. Do you know why pickle.loads is getting into numpy/core/__init__.py in _ufunc_reconstruct? Are we missing a namespace on the generated ufunc or something like that?

@TomAugspurger
Member

FWIW, I can't reproduce the issue going through the ufunc tutorial in https://docs.scipy.org/doc/numpy/user/c-info.ufunc-tutorial.html.

@TomAugspurger
Member

FYI @rabernat, as a workaround you / your users can use Array.map_blocks(numba_ufunc). This avoids putting the problematic numba_ufunc.ufunc in the Dask graph, so serialization isn't a problem.

(That of course sacrifices writing generic code that works on ndarrays, DataArrays, & dask arrays, so I'll continue to look into this.)
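
For concreteness, a minimal sketch of the map_blocks workaround, reusing the test_numba ufunc from the minimal example above (my illustration, untested on a real cluster):

import numpy as np
import dask.array as dsa
from numba import vectorize, float64, float32

@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a**2

data_dask = dsa.arange(5, dtype='f4')

# map_blocks puts the DUFunc wrapper itself into the task graph, which
# cloudpickle can serialize, instead of the bare numpy ufunc object that
# test_numba(data_dask) would embed via __array_ufunc__
result = data_dask.map_blocks(test_numba).compute()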

@TomAugspurger
Member

Right now, I think handling this is either on Numba or (unfortunately) Numba's users.

I have a hack at https://gist.github.com/TomAugspurger/38c68595a91387926907a2436305c8c2 that nobody should really use, but is possibly an option for libraries like fastjmd95 (though it will need way more vetting). That gist includes a description of what's going on and why it does it.

Ideally, Numba could make things better by implementing custom pickling handlers for the DUFunc.ufunc objects. I'm not sure if this is possible though, since these are real numpy.ufunc instances, which numpy controls. They aren't associated with a __module__, so the usual pickle mechanisms fail. And I don't know if it'd be possible to make some kind of ufunc subclass, or otherwise override __reduce__, since this is all done in C against the CPython API. Hopefully @sklam has some thoughts on whether there are options in numba.
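
For concreteness, here is a rough sketch of what that kind of copyreg monkeypatch could look like (my reconstruction of the approach, not the gist verbatim; it assumes the registration lives inside the library, and uses fastjmd95.rho from this issue as the example):

import copyreg
import importlib
import numpy as np
import fastjmd95

# map each dynamically generated ufunc back to an importable path
_registry = {fastjmd95.rho.ufunc: ("fastjmd95", "rho")}

def _reconstruct(module, name):
    # re-import the wrapping DUFunc and unwrap its .ufunc attribute
    mod = importlib.import_module(module)
    return getattr(mod, name).ufunc

def _ufunc_reduce(func):
    if func in _registry:
        return _reconstruct, _registry[func]
    # fall back to numpy's default (module, name) reconstruction
    from pickle import whichmodule
    name = func.__name__
    return np.core._ufunc_reconstruct, (whichmodule(func, name), name)

# override how every numpy.ufunc instance is pickled
copyreg.pickle(np.ufunc, _ufunc_reduce)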

@mrocklin
Member

mrocklin commented Feb 12, 2020 via email

@TomAugspurger
Member

I think if Numba is unable to handle this, then it's worth proposing a change to NumPy to avoid the patching in https://gist.github.com/TomAugspurger/38c68595a91387926907a2436305c8c2.

@mrocklin
Member

mrocklin commented Feb 12, 2020 via email

@seberg

seberg commented Feb 12, 2020

Interesting issue, I have never thought about pickling of user ufuncs. It seems like NumPy would have to know the fully qualified module+name of the desired ufunc. Tom's solution seems reasonable to me, but not something that can be done in NumPy, since NumPy knows nothing about numba ufuncs.

It seems like something to keep in mind when we revise the UFunc API, which I assume we will have to do sooner rather than later. I guess we could add API right now in principle, although the solution from within numba seems simpler? It would have to be new C-API.

I wonder if it breaks strange things like unpickling of np.core.umath._ones_like? Which maybe does not matter...

@TomAugspurger
Member

Interesting issue, I have never thought about pickling of user ufuncs.

Just to clarify, this is only for user-defined ufuncs that don't have a __module__. Ones that do, like https://docs.scipy.org/doc/numpy/user/c-info.ufunc-tutorial.html, work just fine as is.

It seems like NumPy would have to know the fully qualified module+name of the desired ufunc

That's the main issue from NumPy's side. My _ufunc_reduce essentially hardcodes the 3rd party module name into the pickle stream. That's what keeps things relatively simple on the unpickling side (just a getattr(mod, name).ufunc to get the DUFunc.ufunc). I haven't tested, but I suspect that my hack won't work if two 3rd-party modules wanted to serialize numba functions. I'm not sure yet what a solution that's fit for NumPy would look like.
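
For contrast, a builtin ufunc round-trips fine, since whichmodule finds it under an importable module:

import pickle
import numpy as np

# np.add has an importable home, so numpy's (module, name)
# reconstruction succeeds
assert pickle.loads(pickle.dumps(np.add)) is np.add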

@seberg

seberg commented Feb 13, 2020

Ah OK, interesting, although I do not quite see how it actually knows the correct module :). Is there a way that the @vectorize could fix it up, or do we have to move that into the NumPy API in some future to make it work?

@TomAugspurger
Member

Is there a way that the @vectorize could fix it up[...]?

I'm not sure. AFAICT, we need to somehow tell the DUFunc.ufunc what its "module" is, so that it can customize its __reduce__ method. In pseudo-python

class DUFunc:
    def __init__(self, ...):
        self.ufunc = generate_ufunc(module=self.__class__.__module__)


def generate_ufunc(..., module):
    ufunc = ...
    ufunc.__reduce__ = custom_reduce(module)
    return ufunc

But the generate_ufunc is working with the CPython API in https://github.com/numba/numba/blob/master/numba/np/ufunc/_internal.c, and I have no idea if you can "subclass" numpy.ufunc to override the __reduce__.

@seberg

seberg commented Feb 13, 2020

Right, with the wrapping logic, I have no idea if there is any chance of that working. ufuncs cannot be subclassed right now.

@TomAugspurger
Member

Thanks. I'll attempt to summarize the current state of things.

It'd be nice to write numba_ufunc(dask_array), to retain symmetry with numba_ufunc(ndarray) / any object implementing __array_ufunc__. Currently that doesn't work well with distributed because numba_ufunc.__call__(x) does numba_ufunc.ufunc.__call__(x), so the dynamically generated numpy ufunc is passed to Dask, which stores it deep in the resulting task graph. This NumPy .ufunc doesn't have a module, so it can't be serialized by normal means.

As someone who doesn't know C and doesn't know how ufuncs are implemented, it'd be nice if NumPy used something like __reduce__ to pickle ufuncs (which would be helpful for NumPy too), and it'd be nice if numba could override this __reduce__ in their DUFunc.ufunc instance. I don't know how feasible that is.

@seberg

seberg commented Feb 18, 2020

I guess it's tricky, since the numpy ufunc does the __array_ufunc__ dispatching. Implementing __reduce__/__reduce_ex__ is not an issue as such (it is just the same in C), although we would definitely need to add C-API to do it (allowing subclassing could work; I am not sure if there would be any big implications if we could e.g. limit the subclassing to the Python side).

UFuncs do not have reasonable state you can store as such; in that regard they are more like builtin classes, which are also singletons with state set at import time. I think we would have to either provide a way to load it as a fully qualified path (basically just allowing __module__ to be set for pickle purposes), or maybe allow setting a parent, so that we can unpickle the parent first and then extract parent.ufunc. It might be nice if the ufunc doesn't really need to know about its parent though, since that creates reference cycles?

@TomAugspurger
Member

This issue came up again on the pangeo call today.

@sklam or @stuartarchibald, assuming this is difficult / impossible to solve on the NumPy side, do you have any guesses on how Numba could avoid creating the dynamically generated ufunc at https://github.com/numba/numba/blob/fd8c232bb37a1945e8dc8becef9fe05fdd78c4cf/numba/np/ufunc/_internal.c#L227-L229? The summary of the issue is at #3450 (comment).

@TomAugspurger
Member

@seberg I looked into this again today. I'm afraid that doing things properly by implementing some kind of overrideable ufunc.__reduce__ using the Python C API is a bit beyond me right now. What do you think about an explicit solution like the following:

  1. NumPy stores a global dict of {ufunc_instance: (module, name)}
  2. Dynamic ufuncs register themselves with NumPy.
diff --git a/numpy/core/__init__.py b/numpy/core/__init__.py
index c77885954..6f816747c 100644
--- a/numpy/core/__init__.py
+++ b/numpy/core/__init__.py
@@ -117,6 +117,20 @@ __all__ += einsumfunc.__all__
 #  Here are the loading and unloading functions
 # The name numpy.core._ufunc_reconstruct must be
 #   available for unpickling to work.
+
+
+_ufunc_modules = {}  # Dict[ufunc, Tuple[Callable, Tuple]]]
+
+def _register_ufunc(ufunc, reconstruct_function, reconstruct_args):
+    _ufunc_modules[ufunc] = (reconstruct_function, reconstruct_args)
+
+
+def _ufunc_reconstruct_registered(module, name):
+    import operator
+    mod = __import__(module)
+    return operator.attrgetter(name)(mod)
+
+
 def _ufunc_reconstruct(module, name):
     # The `fromlist` kwarg is required to ensure that `mod` points to the
     # inner-most module rather than the parent package when module name is
@@ -128,7 +142,11 @@ def _ufunc_reconstruct(module, name):
 def _ufunc_reduce(func):
     from pickle import whichmodule
     name = func.__name__
-    return _ufunc_reconstruct, (whichmodule(func, name), name)
+    if func in _ufunc_modules:
+        reconstruct_func, args = _ufunc_modules[func]
+        return reconstruct_func, args
+    else:
+        return _ufunc_reconstruct, (whichmodule(func, name), name)
 
 
 import copyreg

Then numba could use it like

import copyreg
import numpy as np
from numba import vectorize, float64, float32


@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a ** 2

# This would be done as part of numba's DUFunc.__init__, not by the user
def _reconstruct_func(module, name):
    import importlib
    import operator
    module = importlib.import_module(module)
    return operator.attrgetter(name)(module)

np.core._register_ufunc(test_numba.ufunc, _reconstruct_func, (__name__, "test_numba.ufunc"))

Do you see any hope for something like that being merged into NumPy? Or would we need to wait for a proper solution doing things in C?

@seberg

seberg commented Sep 10, 2020

@TomAugspurger, seems hackish, but maybe a band-aid is better than nothing.

However, I tried around a bit and I think we are missing how much better pickle has gotten, and how reliable it actually is. I.e. I think NumPy's approach is over-engineered, and that makes the solution harder than necessary. I tried modifying NumPy like this, but you can also do it manually:


import copyreg
import numpy as np
from distutils.version import LooseVersion

# only needed on NumPy versions without the improved pickling
if LooseVersion(np.__version__) < LooseVersion("1.20"):

    def _ufunc_reduce(func):
        # returning a string makes pickle store the ufunc by (dotted)
        # name and resolve it with getattr chains on load
        return func.__name__

    copyreg.pickle(np.ufunc, _ufunc_reduce, np.core._ufunc_reconstruct)

Now you need one more ingredient: test_numba.ufunc has to report its __name__ as "test_numba.ufunc" (a bit like a __qualname__). I tried this by hacking NumPy so that the ufunc name is mutable. __qualname__ would maybe be better, and I guess we could add a __qualname__ to ufuncs, but if printing the extra .ufunc seems OK, this solution is possible right now.

Now, overriding the ufunc pickling outside of NumPy seems pretty extreme, but I am not actually sure it's all that bad. I did not check, but I think the above replacement is effectively identical to what NumPy does, except that it supports attributes in a __qualname__-like fashion.
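
The dotted-name lookup this relies on is standard pickle behaviour in Python 3; a self-contained illustration (not ufunc-specific):

import pickle

class Outer:
    class Inner:
        pass

# classes pickle by qualified name; on load, pickle imports the module
# and walks the dotted path with getattr -- the same mechanism a
# "test_numba.ufunc" __name__ would exploit
assert pickle.loads(pickle.dumps(Outer.Inner)) is Outer.Inner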

seberg added a commit to seberg/numpy that referenced this issue Sep 10, 2020
This also allows at least in principle numba dynamically
generated ufuncs to be pickled (with some hacking), see:

dask/distributed#3450

If the name of the ufunc is set to a qualname, using this method,
pickle should be able to unpickle the ufunc correctly.
We may want to allow setting the module and qualname explicitly
on the ufunc object to remove the need for the custom pickler
completely.
@TomAugspurger
Member

I'm not sure I follow your example, sorry. Is there an else block missing?

But if I understand the spirit of the example, a (mutable?) ufunc.__qualname__ would perhaps work for the numba example.

@seberg

seberg commented Sep 10, 2020

@TomAugspurger, sorry, I put the if there just because I assume we can make this change in NumPy master, while numba could use the hack to support older NumPy versions.

A (semi) mutable __qualname__ would definitely be cleaner. But if Numba is OK with test_numba.ufunc.__name__ = "test_numba.ufunc", then this will work right now without any modification, I think.

@TomAugspurger
Member

Thanks for numpy/numpy#17289

So IIUC, we'd update somewhere around https://github.com/numpy/numpy/blob/7e9d603664edc756f555fecf8649bf888a46d47c/numpy/core/src/umath/ufunc_object.c#L6117-L6119 to make it settable? And when numba generates the ufunc dynamically, they'd set ufunc.__name__ to be __module__.<function_name>.ufunc? I might be able to manage that.

@seberg

seberg commented Sep 10, 2020

I am thinking that in theory, numba could set a different name here: https://github.com/numba/numba/blob/fd8c232bb37a1945e8dc8becef9fe05fdd78c4cf/numba/np/ufunc/_internal.c#L213

at least for now. It is all a bit problematic, I admit, if you want to not leak the name. (Numba can probably hack this, but I am not quite sure; it's all pretty brittle, and since Numba copies some NumPy code, in the end numba probably just needs to make sure it keeps up with any NumPy change.)

But yes, adding __module__ and __qualname__ as property attributes would seem like a good solution. I guess ideally those could only be set at construction time and not mutated later on, so that makes me slightly hesitant as to what the API should look like. But maybe we can be practical about it, unless numba is just fine with modifying the __name__ for now.
