From ca40da451f896acb404d56dc5cf7c424f88c1893 Mon Sep 17 00:00:00 2001
From: Mike Rouleau <44850854+mikerouleau@users.noreply.github.com>
Date: Sat, 9 Apr 2022 20:08:58 +0000
Subject: [PATCH] add LinearAlgebra.normalize fallback for scalars (#44835)

fix NaN == NaN error
---
 stdlib/LinearAlgebra/src/generic.jl  | 20 ++++++++++++++++----
 stdlib/LinearAlgebra/test/generic.jl |  6 ++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl
index 2449a78fda317..c79849535ad0a 100644
--- a/stdlib/LinearAlgebra/src/generic.jl
+++ b/stdlib/LinearAlgebra/src/generic.jl
@@ -1790,11 +1790,12 @@ end
 end
 
 """
-    normalize(a::AbstractArray, p::Real=2)
+    normalize(a, p::Real=2)
 
-Normalize the array `a` so that its `p`-norm equals unity,
-i.e. `norm(a, p) == 1`.
-See also [`normalize!`](@ref) and [`norm`](@ref).
+Normalize `a` so that its `p`-norm equals unity,
+i.e. `norm(a, p) == 1`. For scalars, this is similar to sign(a),
+except normalize(0) = NaN.
+See also [`normalize!`](@ref), [`norm`](@ref), and [`sign`](@ref).
 
 # Examples
 ```jldoctest
@@ -1831,6 +1832,14 @@ julia> normalize(a)
  0.154303  0.308607  0.617213
  0.154303  0.308607  0.617213
 
+julia> normalize(3, 1)
+1.0
+
+julia> normalize(-8, 1)
+-1.0
+
+julia> normalize(0, 1)
+NaN
 ```
 """
 function normalize(a::AbstractArray, p::Real = 2)
@@ -1843,3 +1852,6 @@ function normalize(a::AbstractArray, p::Real = 2)
         return T[]
     end
 end
+
+normalize(x) = x / norm(x)
+normalize(x, p::Real) = x / norm(x, p)
diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl
index 69f2fff00755f..8ca829184314b 100644
--- a/stdlib/LinearAlgebra/test/generic.jl
+++ b/stdlib/LinearAlgebra/test/generic.jl
@@ -375,6 +375,12 @@ end
     @test typeof(normalize([1 2 3; 4 5 6])) == Array{Float64,2}
 end
 
+@testset "normalize for scalars" begin
+    @test normalize(8.0) == 1.0
+    @test normalize(-3.0) == -1.0
+    @test isnan(normalize(0.0))
+end
+
 @testset "Issue #30466" begin
     @test norm([typemin(Int), typemin(Int)], Inf) == -float(typemin(Int))
     @test norm([typemin(Int), typemin(Int)], 1) == -2float(typemin(Int))