diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 4f1036546c6..b69dd66132f 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -176,6 +176,12 @@ func (hsdb *HSDatabase) GetNodeByMachineKey( hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getNodeByMachineKey(machineKey) +} + +func (hsdb *HSDatabase) getNodeByMachineKey( + machineKey key.MachinePublic, +) (*types.Node, error) { mach := types.Node{} if result := hsdb.db. Preload("AuthKey"). diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index d866c7cc601..a3be9f223e5 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -167,14 +167,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if !route.IsExitRoute() { - route.Enabled = false - route.IsPrimary = false - err = hsdb.db.Save(route).Error + err = hsdb.failoverRouteWithNotify(route) if err != nil { return err } - err = hsdb.failoverRouteWithNotify(route) + route.Enabled = false + route.IsPrimary = false + err = hsdb.db.Save(route).Error if err != nil { return err } @@ -229,14 +229,15 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if !route.IsExitRoute() { - if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { - return err - } - err := hsdb.failoverRouteWithNotify(route) if err != nil { return nil } + + if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { + return err + } + } else { routes, err := hsdb.getNodeRoutes(&node) @@ -489,8 +490,12 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { var nodes types.Nodes + log.Trace(). + Str("hostname", r.Node.Hostname). + Msg("loading machines with new primary routes from db") + for _, key := range changedKeys { - node, err := hsdb.GetNodeByMachineKey(key) + node, err := hsdb.getNodeByMachineKey(key) if err != nil { return err } @@ -498,11 +503,19 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { nodes = append(nodes, node) } + log.Trace(). + Str("hostname", r.Node.Hostname). + Msg("notifying peers about primary route change") + hsdb.notifier.NotifyAll(types.StateUpdate{ Type: types.StatePeerChanged, Changed: nodes, }) + log.Trace(). + Str("hostname", r.Node.Hostname). + Msg("notified peers about primary route change") + return nil } @@ -571,6 +584,10 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return nil, err } + log.Trace(). + Str("hostname", newPrimary.Node.Hostname). + Msg("removed primary from old route") + // Set primary for the new primary newPrimary.IsPrimary = true err = hsdb.db.Save(&newPrimary).Error @@ -580,6 +597,10 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return nil, err } + log.Trace(). + Str("hostname", newPrimary.Node.Hostname). + Msg("set primary to new route") + rKey, err := r.Node.MachinePublicKey() if err != nil { return nil, err diff --git a/integration/route_test.go b/integration/route_test.go index 449108a7c99..ac41d615bad 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -384,12 +384,17 @@ func TestHASubnetRouterFailover(t *testing.T) { // Verify that the client has routes from the primary machine srs1, err := subRouter1.Status() + srs2, err := subRouter2.Status() clientStatus, err := client.Status() assertNoErr(t, err) srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey] + assertNotNil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) + assert.Contains( t, srs1PeerStatus.PrimaryRoutes.AsSlice(), @@ -431,13 +436,15 @@ func TestHASubnetRouterFailover(t *testing.T) { // TODO(kradalby): Check client status // Route is expected to be on SR2 - srs2, err := subRouter2.Status() + srs2, err = subRouter2.Status() clientStatus, err = client.Status() assertNoErr(t, err) - srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey] + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assertNotNil(t, srs2PeerStatus.PrimaryRoutes) if srs2PeerStatus.PrimaryRoutes != nil { @@ -489,8 +496,10 @@ func TestHASubnetRouterFailover(t *testing.T) { clientStatus, err = client.Status() assertNoErr(t, err) + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assertNotNil(t, srs2PeerStatus.PrimaryRoutes) if srs2PeerStatus.PrimaryRoutes != nil { @@ -523,12 +532,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, routesAfter1Up, 2) - // Node 1 is not primary + // Node 1 is primary assert.Equal(t, true, routesAfter1Up[0].Advertised) assert.Equal(t, true, routesAfter1Up[0].Enabled) assert.Equal(t, true, routesAfter1Up[0].IsPrimary) - // Node 2 is primary + // Node 2 is not primary assert.Equal(t, true, routesAfter1Up[1].Advertised) assert.Equal(t, true, routesAfter1Up[1].Enabled) assert.Equal(t, false, routesAfter1Up[1].IsPrimary) @@ -538,8 +547,10 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) if srs1PeerStatus.PrimaryRoutes != nil { assert.Contains( @@ -586,8 +597,178 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) + + if srs1PeerStatus.PrimaryRoutes != nil { + assert.Contains( + t, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + ) + } + + // Disable the route of subnet router 1, making it failover to 2 + t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname()) + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "disable", + "--route", + fmt.Sprintf("%d", routesAfter2Up[0].Id), + }) + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + var routesAfterDisabling1 []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routesAfterDisabling1, + ) + assertNoErr(t, err) + assert.Len(t, routesAfterDisabling1, 2) + + // Node 1 is not primary + assert.Equal(t, true, routesAfterDisabling1[0].Advertised) + assert.Equal(t, false, routesAfterDisabling1[0].Enabled) + assert.Equal(t, false, routesAfterDisabling1[0].IsPrimary) + + // Node 2 is primary + assert.Equal(t, true, routesAfterDisabling1[1].Advertised) + assert.Equal(t, true, routesAfterDisabling1[1].Enabled) + assert.Equal(t, true, routesAfterDisabling1[1].IsPrimary) + + // Verify that the route is announced from subnet router 1 + clientStatus, err = client.Status() + assertNoErr(t, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) + assert.NotNil(t, srs2PeerStatus.PrimaryRoutes) + + if srs2PeerStatus.PrimaryRoutes != nil { + assert.Contains( + t, + srs2PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), + ) + } + + // enable the route of subnet router 1, no change expected + t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname()) + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "enable", + "--route", + fmt.Sprintf("%d", routesAfter2Up[0].Id), + }) + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + var routesAfterEnabling1 []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routesAfterEnabling1, + ) + assertNoErr(t, err) + assert.Len(t, routesAfterEnabling1, 2) + + // Node 1 is not primary + assert.Equal(t, true, routesAfterEnabling1[0].Advertised) + assert.Equal(t, true, routesAfterEnabling1[0].Enabled) + assert.Equal(t, false, routesAfterEnabling1[0].IsPrimary) + + // Node 2 is primary + assert.Equal(t, true, routesAfterEnabling1[1].Advertised) + assert.Equal(t, true, routesAfterEnabling1[1].Enabled) + assert.Equal(t, true, routesAfterEnabling1[1].IsPrimary) + + // Verify that the route is announced from subnet router 1 + clientStatus, err = client.Status() + assertNoErr(t, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) + assert.NotNil(t, srs2PeerStatus.PrimaryRoutes) + + if srs2PeerStatus.PrimaryRoutes != nil { + assert.Contains( + t, + srs2PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), + ) + } + + // delete the route of subnet router 2, failover to one expected + t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname()) + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "delete", + "--route", + fmt.Sprintf("%d", routesAfterEnabling1[1].Id), + }) + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + var routesAfterDeleting2 []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routesAfterDeleting2, + ) + assertNoErr(t, err) + assert.Len(t, routesAfterDeleting2, 1) + + t.Logf("routes after deleting2 %#v", routesAfterDeleting2) + + // Node 1 is primary + assert.Equal(t, true, routesAfterDeleting2[0].Advertised) + assert.Equal(t, true, routesAfterDeleting2[0].Enabled) + assert.Equal(t, true, routesAfterDeleting2[0].IsPrimary) + + // Verify that the route is announced from subnet router 1 + clientStatus, err = client.Status() + assertNoErr(t, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + + assertNotNil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) if srs1PeerStatus.PrimaryRoutes != nil { assert.Contains(