Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add complex dtype support for mean #850

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kgryte
Copy link
Contributor

@kgryte kgryte commented Oct 31, 2024

This PR

  • resolves RFC: add support for complex input to mean #846.
  • adds support for complex data types when computing the arithmetic mean.
  • includes special cases which deviate from NumPy et al, but are consistent with real-valued floating-point special cases as outlined in the specification. Namely, if the arithmetic mean is computed as a sum followed by scalar division, then NaN components should only propagate relative to the respective component. However, in NumPy, given array([ 1.+0.j, 2.+0.j, nan+0.j]), NumPy will return nan+nanj when invoking np.mean.

@kgryte kgryte added API extension Adds new functions or objects to the API. topic: Complex Data Types Complex number data types. topic: Statistics Statistics. Needs Review Pull request which needs review. labels Oct 31, 2024
@kgryte kgryte added this to the v2024 milestone Oct 31, 2024
@kgryte kgryte changed the title feat: add complex dtype support mean feat: add complex dtype support for mean Oct 31, 2024
@rgommers
Copy link
Member

rgommers commented Nov 3, 2024

  • includes special cases which deviate from NumPy et al, but are consistent with real-valued floating-point special cases as outlined in the specification. Namely, if the arithmetic mean is computed as a sum followed by scalar division, then NaN components should only propagate relative to the respective component. However, in NumPy, given array([ 1.+0.j, 2.+0.j, nan+0.j]), NumPy will return nan+nanj when invoking np.mean.

PyTorch, JAX and CuPy all yield the same nan + nanj result as NumPy. I don't think it's very useful to specify a special case like this, since it'll just be ignored and in practice be wrong. So I'd at least use "should" rather than "must".

It's possible the behavior in NumPy may change (following CPython's change for division): numpy/numpy#26560.

@asmeurer
Copy link
Member

asmeurer commented Nov 6, 2024

I thought we had decided at some point that the distinction between things like 1. + nanj and nan + nanj is generally undefined in the standard (or am I just remembering a similar discussion about infinites?).

I've certainly taken that view in some places when implement things, taking something like "the result should be nan" to mean isnan(x) should give True, which is the case if either or both components are nan.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API extension Adds new functions or objects to the API. Needs Review Pull request which needs review. topic: Complex Data Types Complex number data types. topic: Statistics Statistics.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RFC: add support for complex input to mean
3 participants