diff --git a/src/mixtures/em.jl b/src/mixtures/em.jl index 5f6a7e30..14e158d8 100644 --- a/src/mixtures/em.jl +++ b/src/mixtures/em.jl @@ -52,13 +52,13 @@ function clustering(data, mix_num::Int64; maxiter=200)::Vector return [data] end data = Matrix(data) - R = kmeans(data, mix_num; maxiter=maxiter) + R = kmeans(data', mix_num; maxiter=maxiter) @assert nclusters(R) == mix_num a = assignments(R) clustered_data = Vector() for k in 1 : mix_num - push!(clustered_data, DataFrame(data[:, findall(x -> x == k, a)]')) + push!(clustered_data, DataFrame(data[findall(x -> x == k, a), :])) end return clustered_data