Skip to content

Commit

Permalink
fix route table migration wiping routes 0.22 -> 0.23
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Aug 27, 2024
1 parent a68854a commit 7e9f321
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ jobs:

- name: Run tests
if: steps.changed-files.outputs.files == 'true'
run: nix develop --check
run: nix develop --command -- gotestsum
22 changes: 17 additions & 5 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func NewHeadscaleDatabase(
dbConn,
gormigrate.DefaultOptions,
[]*gormigrate.Migration{
// New migrations should be added as transactions at the end of this list.
// The initial commit here is quite messy, completely out of order and
// New migrations must be added as transactions at the end of this list.
// The initial migration here is quite messy, completely out of order and
// has no versioning and is the tech debt of not having versioned migrations
// prior to this point. This first migration is all DB changes to bring a DB
// up to 0.23.0.
Expand Down Expand Up @@ -123,9 +123,21 @@ func NewHeadscaleDatabase(
}
}

err = tx.AutoMigrate(&types.Route{})
if err != nil {
return err
// Only run automigrate Route table if it does not exist. It has only been
// changed ones, when machines where renamed to nodes, which is covered
// further up. This whole initial integration is a mess and if AutoMigrate
// is ran on a 0.22 to 0.23 update, it will wipe all the routes.
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
if err != nil {
return err
}
}
if !tx.Migrator().HasTable(&types.Route{}) {
err = tx.AutoMigrate(&types.Route{})
if err != nil {
return err
}
}

err = tx.AutoMigrate(&types.Node{})
Expand Down
168 changes: 168 additions & 0 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package db

import (
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

func TestMigrations(t *testing.T) {
ipp := func(p string) types.IPPrefix {
return types.IPPrefix(netip.MustParsePrefix(p))
}
r := func(id uint64, p string, a, e, i bool) types.Route {
return types.Route{
NodeID: id,
Prefix: ipp(p),
Advertised: a,
Enabled: e,
IsPrimary: i,
}
}
tests := []struct {
dbPath string
wantFunc func(*testing.T, *HSDatabase)
wantErr string
}{
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
})
assert.NoError(t, err)

assert.Len(t, routes, 10)
want := types.Routes{
r(1, "0.0.0.0/0", true, true, false),
r(1, "::/0", true, true, false),
r(1, "10.9.110.0/24", true, true, true),
r(26, "172.100.100.0/24", true, true, true),
r(26, "172.100.100.0/24", true, false, false),
r(31, "0.0.0.0/0", true, true, false),
r(31, "0.0.0.0/0", true, false, false),
r(31, "::/0", true, true, false),
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
},
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
})
assert.NoError(t, err)

assert.Len(t, routes, 4)
want := types.Routes{
// These routes exists, but have no nodes associated with them
// when the migration starts.
// r(1, "0.0.0.0/0", true, true, false),
// r(1, "::/0", true, true, false),
// r(3, "0.0.0.0/0", true, true, false),
// r(3, "::/0", true, true, false),
// r(5, "0.0.0.0/0", true, true, false),
// r(5, "::/0", true, true, false),
// r(6, "0.0.0.0/0", true, true, false),
// r(6, "::/0", true, true, false),
// r(6, "10.0.0.0/8", true, false, false),
// r(7, "0.0.0.0/0", true, true, false),
// r(7, "::/0", true, true, false),
// r(7, "10.0.0.0/8", true, false, false),
// r(9, "0.0.0.0/0", true, true, false),
// r(9, "::/0", true, true, false),
// r(9, "10.0.0.0/8", true, true, false),
// r(11, "0.0.0.0/0", true, true, false),
// r(11, "::/0", true, true, false),
// r(11, "10.0.0.0/8", true, true, true),
// r(12, "0.0.0.0/0", true, true, false),
// r(12, "::/0", true, true, false),
// r(12, "10.0.0.0/8", true, false, false),
//
// These nodes exists, so routes should be kept.
r(13, "10.0.0.0/8", true, false, false),
r(13, "0.0.0.0/0", true, true, false),
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
},
}

for _, tt := range tests {
t.Run(tt.dbPath, func(t *testing.T) {
dbPath, err := testCopyOfDatabase(tt.dbPath)
if err != nil {
t.Fatalf("copying db for test: %s", err)
}

hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
},
}, "")
if err != nil && tt.wantErr != err.Error() {
t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr)
}

if tt.wantFunc != nil {
tt.wantFunc(t, hsdb)
}
})
}
}

func testCopyOfDatabase(src string) (string, error) {
sourceFileStat, err := os.Stat(src)
if err != nil {
return "", err
}

if !sourceFileStat.Mode().IsRegular() {
return "", fmt.Errorf("%s is not a regular file", src)
}

source, err := os.Open(src)
if err != nil {
return "", err
}
defer source.Close()

tmpDir, err := os.MkdirTemp("", "hsdb-test-*")
if err != nil {
return "", err
}

fn := filepath.Base(src)
dst := filepath.Join(tmpDir, fn)

destination, err := os.Create(dst)
if err != nil {
return "", err
}
defer destination.Close()
_, err = io.Copy(destination, source)
return dst, err
}
7 changes: 3 additions & 4 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"fmt"
"net/netip"
"sort"
"sync"
"time"

"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
Expand Down Expand Up @@ -724,7 +724,7 @@ func ExpireExpiredNodes(tx *gorm.DB,
// It is used to delete ephemeral nodes that have disconnected and should be
// cleaned up.
type EphemeralGarbageCollector struct {
mu deadlock.Mutex
mu sync.Mutex

deleteFunc func(types.NodeID)
toBeDeleted map[types.NodeID]*time.Timer
Expand Down Expand Up @@ -752,10 +752,9 @@ func (e *EphemeralGarbageCollector) Close() {
// Schedule schedules a node for deletion after the expiry duration.
func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) {
e.mu.Lock()
defer e.mu.Unlock()

timer := time.NewTimer(expiry)
e.toBeDeleted[nodeID] = timer
e.mu.Unlock()

go func() {
select {
Expand Down
14 changes: 8 additions & 6 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,12 +609,14 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
})
go e.Start()

e.Schedule(1, 1*time.Second)
e.Schedule(2, 2*time.Second)
e.Schedule(3, 3*time.Second)
e.Schedule(4, 4*time.Second)
e.Cancel(2)
e.Cancel(4)
go e.Schedule(1, 1*time.Second)
go e.Schedule(2, 2*time.Second)
go e.Schedule(3, 3*time.Second)
go e.Schedule(4, 4*time.Second)

time.Sleep(time.Second)
go e.Cancel(2)
go e.Cancel(4)

time.Sleep(6 * time.Second)

Expand Down
Binary file not shown.
Binary file not shown.
6 changes: 5 additions & 1 deletion hscontrol/util/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"net/netip"

"github.com/google/go-cmp/cmp"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/views"
)

var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool {
Expand All @@ -31,6 +33,8 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
return x.String() == y.String()
})

var ViewSliceIPProtoComparer = cmp.Comparer(func(a, b views.Slice[ipproto.Proto]) bool { return views.SliceEqual(a, b) })

var Comparers []cmp.Option = []cmp.Option{
IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer,
IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, ViewSliceIPProtoComparer,
}
4 changes: 2 additions & 2 deletions integration/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ func TestSubnetRouteACL(t *testing.T) {
},
}

if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" {
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
}

Expand Down Expand Up @@ -1220,7 +1220,7 @@ func TestSubnetRouteACL(t *testing.T) {
},
}

if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" {
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
}
}

0 comments on commit 7e9f321

Please sign in to comment.