diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index bb88fc323e..d84e3953f5 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -12,6 +12,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/rs/zerolog/log" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" @@ -241,25 +242,25 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { // cannot be directly marshalled into database values are stored // correctly in the database. // This currently means storing the keys as strings. -func (n *Node) BeforeSave(tx *gorm.DB) (err error) { - n.MachineKeyDatabaseField = n.MachineKey.String() - n.NodeKeyDatabaseField = n.NodeKey.String() - n.DiscoKeyDatabaseField = n.DiscoKey.String() +func (node *Node) BeforeSave(tx *gorm.DB) error { + node.MachineKeyDatabaseField = node.MachineKey.String() + node.NodeKeyDatabaseField = node.NodeKey.String() + node.DiscoKeyDatabaseField = node.DiscoKey.String() var endpoints StringList - for _, addrPort := range n.Endpoints { + for _, addrPort := range node.Endpoints { endpoints = append(endpoints, addrPort.String()) } - n.EndpointsDatabaseField = endpoints + node.EndpointsDatabaseField = endpoints - hi, err := json.Marshal(n.Hostinfo) + hi, err := json.Marshal(node.Hostinfo) if err != nil { return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err) } - n.HostinfoDatabaseField = string(hi) + node.HostinfoDatabaseField = string(hi) - return + return nil } // AfterFind is a hook that ensures that Node objects fields that @@ -267,43 +268,45 @@ func (n *Node) BeforeSave(tx *gorm.DB) (err error) { // correctly. // This currently unmarshals all the keys, stored as strings, into // the proper types. -func (n *Node) AfterFind(tx *gorm.DB) (err error) { +func (node *Node) AfterFind(tx *gorm.DB) error { var machineKey key.MachinePublic - if err := machineKey.UnmarshalText([]byte(n.MachineKeyDatabaseField)); err != nil { + if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil { return fmt.Errorf("failed to unmarshal machine key from db: %w", err) } - n.MachineKey = machineKey + node.MachineKey = machineKey var nodeKey key.NodePublic - if err := nodeKey.UnmarshalText([]byte(n.NodeKeyDatabaseField)); err != nil { + if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil { return fmt.Errorf("failed to unmarshal node key from db: %w", err) } - n.NodeKey = nodeKey + node.NodeKey = nodeKey var discoKey key.DiscoPublic - if err := discoKey.UnmarshalText([]byte(n.DiscoKeyDatabaseField)); err != nil { + if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil { return fmt.Errorf("failed to unmarshal disco key from db: %w", err) } - n.DiscoKey = discoKey + node.DiscoKey = discoKey - var endpoints []netip.AddrPort - for _, ep := range n.EndpointsDatabaseField { + endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField)) + for idx, ep := range node.EndpointsDatabaseField { addrPort, err := netip.ParseAddrPort(ep) if err != nil { return fmt.Errorf("failed to parse endpoint from db: %w", err) } - endpoints = append(endpoints, addrPort) + endpoints[idx] = addrPort } - n.Endpoints = endpoints + node.Endpoints = endpoints var hi tailcfg.Hostinfo - if err := json.Unmarshal([]byte(n.HostinfoDatabaseField), &hi); err != nil { + if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil { + log.Trace().Err(err).Msgf("Hostinfo content: %s", node.HostinfoDatabaseField) + return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err) } - n.Hostinfo = &hi + node.Hostinfo = &hi - return + return nil } func (node *Node) Proto() *v1.Node {