Skip to content

Commit

Permalink
Merge pull request #496 from emmaai/master
Browse files Browse the repository at this point in the history
Support multithreading
  • Loading branch information
FrancescAlted authored Sep 13, 2024
2 parents 7484828 + c4f527d commit a99412e
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 5 deletions.
152 changes: 152 additions & 0 deletions bench/large_array_vs_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#################################################################################
# To mimic the scenario where computation is I/O bound and constrained by memory.
#
# This is a much simplified version in which each chunk is computed in a loop
# and the expressions are evaluated in sequence, which is not true in reality.
# Nevertheless, numexpr outperforms numpy.
#################################################################################
"""
Benchmarking Expression 1:
NumPy time (threaded over 32 chunks with 2 threads): 4.612313 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 0.951172 seconds
numexpr speedup: 4.85x
----------------------------------------
Benchmarking Expression 2:
NumPy time (threaded over 32 chunks with 2 threads): 23.862752 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.182058 seconds
numexpr speedup: 10.94x
----------------------------------------
Benchmarking Expression 3:
NumPy time (threaded over 32 chunks with 2 threads): 20.594895 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.927881 seconds
numexpr speedup: 7.03x
----------------------------------------
Benchmarking Expression 4:
NumPy time (threaded over 32 chunks with 2 threads): 12.834101 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 5.392480 seconds
numexpr speedup: 2.38x
----------------------------------------
"""

import os

os.environ["NUMEXPR_NUM_THREADS"] = "16"
import numpy as np
import numexpr as ne
import timeit
import threading

array_size = 10**8
num_runs = 10
num_chunks = 32 # Number of chunks
num_threads = 2 # Number of threads constrained by how many chunks memory can hold

a = np.random.rand(array_size).reshape(10**4, -1)
b = np.random.rand(array_size).reshape(10**4, -1)
c = np.random.rand(array_size).reshape(10**4, -1)

chunk_size = array_size // num_chunks

expressions_numpy = [
lambda a, b, c: a + b * c,
lambda a, b, c: a**2 + b**2 - 2 * a * b * np.cos(c),
lambda a, b, c: np.sin(a) + np.log(b) * np.sqrt(c),
lambda a, b, c: np.exp(a) + np.tan(b) - np.sinh(c),
]

expressions_numexpr = [
"a + b * c",
"a**2 + b**2 - 2 * a * b * cos(c)",
"sin(a) + log(b) * sqrt(c)",
"exp(a) + tan(b) - sinh(c)",
]


def benchmark_numpy_chunk(func, a, b, c, results, indices):
for index in indices:
start = index * chunk_size
end = (index + 1) * chunk_size
time_taken = timeit.timeit(
lambda: func(a[start:end], b[start:end], c[start:end]), number=num_runs
)
results.append(time_taken)


def benchmark_numexpr_re_evaluate(expr, a, b, c, results, indices):
for index in indices:
start = index * chunk_size
end = (index + 1) * chunk_size
if index == 0:
# Evaluate the first chunk with evaluate
time_taken = timeit.timeit(
lambda: ne.evaluate(
expr,
local_dict={
"a": a[start:end],
"b": b[start:end],
"c": c[start:end],
},
),
number=num_runs,
)
else:
# Re-evaluate subsequent chunks with re_evaluate
time_taken = timeit.timeit(
lambda: ne.re_evaluate(
local_dict={"a": a[start:end], "b": b[start:end], "c": c[start:end]}
),
number=num_runs,
)
results.append(time_taken)


def run_benchmark_threaded():
chunk_indices = list(range(num_chunks))

for i in range(len(expressions_numpy)):
print(f"Benchmarking Expression {i+1}:")

results_numpy = []
results_numexpr = []

threads_numpy = []
for j in range(num_threads):
indices = chunk_indices[j::num_threads] # Distribute chunks across threads
thread = threading.Thread(
target=benchmark_numpy_chunk,
args=(expressions_numpy[i], a, b, c, results_numpy, indices),
)
threads_numpy.append(thread)
thread.start()

for thread in threads_numpy:
thread.join()

numpy_time = sum(results_numpy)
print(
f"NumPy time (threaded over {num_chunks} chunks with {num_threads} threads): {numpy_time:.6f} seconds"
)

threads_numexpr = []
for j in range(num_threads):
indices = chunk_indices[j::num_threads] # Distribute chunks across threads
thread = threading.Thread(
target=benchmark_numexpr_re_evaluate,
args=(expressions_numexpr[i], a, b, c, results_numexpr, indices),
)
threads_numexpr.append(thread)
thread.start()

for thread in threads_numexpr:
thread.join()

numexpr_time = sum(results_numexpr)
print(
f"numexpr time (threaded with re_evaluate over {num_chunks} chunks with {num_threads} threads): {numexpr_time:.6f} seconds"
)
print(f"numexpr speedup: {numpy_time / numexpr_time:.2f}x")
print("-" * 40)


if __name__ == "__main__":
run_benchmark_threaded()
8 changes: 3 additions & 5 deletions numexpr/necompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
from numexpr import interpreter, expressions, use_vml
from numexpr.utils import CacheDict
from numexpr.utils import CacheDict, ContextDict

# Declare a double type that does not exist in Python space
double = numpy.double
Expand Down Expand Up @@ -776,11 +776,9 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
# Dictionaries for caching variable names and compiled expressions
_names_cache = CacheDict(256)
_numexpr_cache = CacheDict(256)
_numexpr_last = {}
_numexpr_last = ContextDict()
evaluate_lock = threading.Lock()

# MAYBE: decorate this function to add attributes instead of having the
# _numexpr_last dictionary?
def validate(ex: str,
local_dict: Optional[Dict] = None,
global_dict: Optional[Dict] = None,
Expand Down Expand Up @@ -887,7 +885,7 @@ def validate(ex: str,
compiled_ex = _numexpr_cache[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
kwargs = {'out': out, 'order': order, 'casting': casting,
'ex_uses_vml': ex_uses_vml}
_numexpr_last = dict(ex=compiled_ex, argnames=names, kwargs=kwargs)
_numexpr_last.set(ex=compiled_ex, argnames=names, kwargs=kwargs)
except Exception as e:
return e
return None
Expand Down
72 changes: 72 additions & 0 deletions numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,7 @@ def run(self):
test.join()

def test_multithread(self):

import threading

# Running evaluate() from multiple threads shouldn't crash
Expand All @@ -1218,6 +1219,77 @@ def work(n):
for t in threads:
t.join()

def test_thread_safety(self):
"""
Expected output
When not safe (before the pr this test is commited)
AssertionError: Thread-0 failed: result does not match expected
When safe (after the pr this test is commited)
Should pass without failure
"""
import threading
import time

barrier = threading.Barrier(4)

# Function that each thread will run with different expressions
def thread_function(a_value, b_value, expression, expected_result, results, index):
validate(expression, local_dict={"a": a_value, "b": b_value})
# Wait for all threads to reach this point
# such that they all set _numexpr_last
barrier.wait()

# Simulate some work or a context switch delay
time.sleep(0.1)

result = re_evaluate(local_dict={"a": a_value, "b": b_value})
results[index] = np.array_equal(result, expected_result)

def test_thread_safety_with_numexpr():
num_threads = 4
array_size = 1000000

expressions = [
"a + b",
"a - b",
"a * b",
"a / b"
]

a_value = [np.full(array_size, i + 1) for i in range(num_threads)]
b_value = [np.full(array_size, (i + 1) * 2) for i in range(num_threads)]

expected_results = [
a_value[i] + b_value[i] if expr == "a + b" else
a_value[i] - b_value[i] if expr == "a - b" else
a_value[i] * b_value[i] if expr == "a * b" else
a_value[i] / b_value[i] if expr == "a / b" else None
for i, expr in enumerate(expressions)
]

results = [None] * num_threads
threads = []

# Create and start threads with different expressions
for i in range(num_threads):
thread = threading.Thread(
target=thread_function,
args=(a_value[i], b_value[i], expressions[i], expected_results[i], results, i)
)
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

for i in range(num_threads):
if not results[i]:
self.fail(f"Thread-{i} failed: result does not match expected")

test_thread_safety_with_numexpr()


# The worker function for the subprocess (needs to be here because Windows
# has problems pickling nested functions with the multiprocess module :-/)
Expand Down
81 changes: 81 additions & 0 deletions numexpr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import os
import subprocess
import contextvars

from numexpr.interpreter import _set_num_threads, _get_num_threads, MAX_THREADS
from numexpr import use_vml
Expand Down Expand Up @@ -226,3 +227,83 @@ def __setitem__(self, key, value):
super(CacheDict, self).__delitem__(k)
super(CacheDict, self).__setitem__(key, value)


class ContextDict:
"""
A context aware version dictionary
"""
def __init__(self):
self._context_data = contextvars.ContextVar('context_data', default={})

def set(self, key=None, value=None, **kwargs):
data = self._context_data.get().copy()

if key is not None:
data[key] = value

for k, v in kwargs.items():
data[k] = v

self._context_data.set(data)

def get(self, key, default=None):
data = self._context_data.get()
return data.get(key, default)

def delete(self, key):
data = self._context_data.get().copy()
if key in data:
del data[key]
self._context_data.set(data)

def clear(self):
self._context_data.set({})

def all(self):
return self._context_data.get()

def update(self, *args, **kwargs):
data = self._context_data.get().copy()

if args:
if len(args) > 1:
raise TypeError(f"update() takes at most 1 positional argument ({len(args)} given)")
other = args[0]
if isinstance(other, dict):
data.update(other)
else:
for k, v in other:
data[k] = v

data.update(kwargs)
self._context_data.set(data)

def keys(self):
return self._context_data.get().keys()

def values(self):
return self._context_data.get().values()

def items(self):
return self._context_data.get().items()

def __getitem__(self, key):
return self.get(key)

def __setitem__(self, key, value):
self.set(key, value)

def __delitem__(self, key):
self.delete(key)

def __contains__(self, key):
return key in self._context_data.get()

def __len__(self):
return len(self._context_data.get())

def __iter__(self):
return iter(self._context_data.get())

def __repr__(self):
return repr(self._context_data.get())

0 comments on commit a99412e

Please sign in to comment.