diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 8df726752..8e8c98d13 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -151,12 +151,14 @@ end """ remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) + remove_edges(g::GNNGraph, p=0.5) -Remove specified edges from a GNNGraph. +Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability. # Arguments - `g`: The input graph from which edges will be removed. -- `edges_to_remove`: Vector of edge indices to be removed. +- `edges_to_remove`: Vector of edge indices to be removed. This argument is only required for the first method. +- `p`: Probability of removing each edge. This argument is only required for the second method and defaults to 0.5. # Returns A new GNNGraph with the specified edges removed. @@ -178,6 +180,14 @@ julia> g_new GNNGraph: num_nodes: 3 num_edges: 4 + +# Remove edges with a probability of 0.5 +julia> g_new = remove_edges(g, 0.5); + +julia> g_new +GNNGraph: + num_nodes: 3 + num_edges: 2 ``` """ function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) @@ -200,6 +210,13 @@ function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:In g.ndata, edata, g.gdata) end + +function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) + num_edges = g.num_edges + edges_to_remove = filter(_ -> rand() < p, 1:num_edges) + return remove_edges(g, edges_to_remove) +end + """ remove_multi_edges(g::GNNGraph; aggr=+) diff --git a/GNNGraphs/test/transform.jl b/GNNGraphs/test/transform.jl index 993ac714a..05413fd4f 100644 --- a/GNNGraphs/test/transform.jl +++ b/GNNGraphs/test/transform.jl @@ -126,6 +126,13 @@ end @test new_t == [4] @test new_w == [0.3] @test new_edata == ['c'] + + # drop with probability + gnew = remove_edges(g, Float32(1.0)) + @test gnew.num_edges == 0 + + gnew = remove_edges(g, Float32(0.0)) + @test gnew.num_edges == g.num_edges end end