diff --git a/pkg/api/networkservice/ipcontext_helpers.go b/pkg/api/networkservice/ipcontext_helpers.go index a237395..8db408b 100644 --- a/pkg/api/networkservice/ipcontext_helpers.go +++ b/pkg/api/networkservice/ipcontext_helpers.go @@ -33,25 +33,13 @@ func (i *IPContext) GetDstIPNets() []*net.IPNet { } // GetDstRoutesWithExplicitNextHop - returns routes with the Route.NextHop explicitly set to the first IP address -// of the same address family (IPv4 or IPv6) from SrcIPAddrs if and only if Route.NextHop was initially nil +// of the same address family (IPv4 or IPv6) from SrcIPAddrs if and only if Route.NextHop was initially nil and +// the Route.NextHop is not in DstIPAddrs func (i *IPContext) GetDstRoutesWithExplicitNextHop() (routes []*Route) { srcIPs := i.GetSrcIPNets() - ipv6NextHop := getNextHop(filterIPNetByFamily(IpFamily_IPV6, srcIPs)) - ipv4NextHop := getNextHop(filterIPNetByFamily(IpFamily_IPV4, srcIPs)) - for _, route := range i.GetDstRoutes() { - if route.GetPrefixIPNet() != nil && route.GetNextHopIP() == nil { - if ipv4NextHop != nil && route.GetPrefixIPNet().IP.To4() != nil { - route = route.Clone() - route.NextHop = ipv4NextHop.String() - } - if ipv6NextHop != nil && route.GetPrefixIPNet().IP.To4() == nil { - route = route.Clone() - route.NextHop = ipv6NextHop.String() - } - } - routes = append(routes, route) - } - return routes + ipv6Nets := filterIPNetByFamily(IpFamily_IPV6, srcIPs) + ipv4Nets := filterIPNetByFamily(IpFamily_IPV4, srcIPs) + return getRoutesWithExplicitNextHop(i.GetDstRoutes(), ipv4Nets, ipv6Nets) } // GetSrcRoutesWithExplicitNextHop - returns routes with the Route.NextHop explicitly set to the first IP address @@ -59,22 +47,9 @@ func (i *IPContext) GetDstRoutesWithExplicitNextHop() (routes []*Route) { func (i *IPContext) GetSrcRoutesWithExplicitNextHop() (routes []*Route) { // Set nextHop for any Route that is missing them dstIPs := i.GetDstIPNets() - ipv6NextHop := getNextHop(filterIPNetByFamily(IpFamily_IPV6, dstIPs)) - ipv4NextHop := getNextHop(filterIPNetByFamily(IpFamily_IPV4, dstIPs)) - for _, route := range i.GetSrcRoutes() { - if route.GetPrefixIPNet() != nil && route.GetNextHopIP() == nil { - if ipv4NextHop != nil && route.GetPrefixIPNet().IP.To4() != nil { - route = route.Clone() - route.NextHop = ipv4NextHop.String() - } - if ipv6NextHop != nil && route.GetPrefixIPNet().IP.To4() == nil { - route = route.Clone() - route.NextHop = ipv6NextHop.String() - } - } - routes = append(routes, route) - } - return routes + ipv6Nets := filterIPNetByFamily(IpFamily_IPV6, dstIPs) + ipv4Nets := filterIPNetByFamily(IpFamily_IPV4, dstIPs) + return getRoutesWithExplicitNextHop(i.GetSrcRoutes(), ipv4Nets, ipv6Nets) } // GetSrcIPRoutes - returns routes for any SrcIPs that are not contained in the prefixes of at least one DstIP @@ -109,6 +84,25 @@ func (i *IPContext) GetDstIPRoutes() (routes []*Route) { return routes } +func getRoutesWithExplicitNextHop(inRoutes []*Route, toIPv4Nets, toIPv6Nets []*net.IPNet) (routes []*Route) { + ipv6NextHop := getNextHop(toIPv6Nets) + ipv4NextHop := getNextHop(toIPv4Nets) + for _, route := range inRoutes { + if route.GetPrefixIPNet() != nil && route.GetNextHopIP() == nil { + if ipv4NextHop != nil && route.GetPrefixIPNet().IP.To4() != nil && !contains(toIPv4Nets, route.GetPrefixIPNet().IP) { + route = route.Clone() + route.NextHop = ipv4NextHop.String() + } + if ipv6NextHop != nil && route.GetPrefixIPNet().IP.To4() == nil && !contains(toIPv6Nets, route.GetPrefixIPNet().IP) { + route = route.Clone() + route.NextHop = ipv6NextHop.String() + } + } + routes = append(routes, route) + } + return routes +} + func getNextHop(ipNets []*net.IPNet) net.IP { if len(ipNets) > 0 && ipNets[0] != nil { return ipNets[0].IP @@ -119,10 +113,10 @@ func getNextHop(ipNets []*net.IPNet) net.IP { func filterIPNetByFamily(family IpFamily_Family, ipNets []*net.IPNet) []*net.IPNet { var rv []*net.IPNet for _, ipNet := range ipNets { - if ipNet != nil && family == IpFamily_IPV4 && ipNet.IP.To4() == nil { + if ipNet != nil && family == IpFamily_IPV4 && ipNet.IP.To4() != nil { rv = append(rv, ipNet) } - if ipNet != nil && family == IpFamily_IPV6 && ipNet.IP.To4() != nil { + if ipNet != nil && family == IpFamily_IPV6 && ipNet.IP.To4() == nil { rv = append(rv, ipNet) } }