diff --git a/gtpv2/ie/ambr.go b/gtpv2/ie/ambr.go index c9a730a3..6911cadb 100644 --- a/gtpv2/ie/ambr.go +++ b/gtpv2/ie/ambr.go @@ -11,9 +11,6 @@ import ( // NewAggregateMaximumBitRate creates a new AggregateMaximumBitRate IE. func NewAggregateMaximumBitRate(up, down uint32) *IE { - // this is more efficient but removed for consistency with other structured IEs. - // return newUint64ValIE(AggregateMaximumBitRate, (uint64(up)<<32 | uint64(down))) - v := NewAggregateMaximumBitRateFields(up, down) b, err := v.Marshal() if err != nil { diff --git a/gtpv2/ie/apn-restriction.go b/gtpv2/ie/apn-restriction.go index f55e091c..98b21688 100644 --- a/gtpv2/ie/apn-restriction.go +++ b/gtpv2/ie/apn-restriction.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewAPNRestriction creates a new APNRestriction IE. func NewAPNRestriction(restriction uint8) *IE { - return newUint8ValIE(APNRestriction, restriction) + return NewUint8IE(APNRestriction, restriction) } // APNRestriction returns APNRestriction in uint8 if the type of IE matches. @@ -17,11 +15,7 @@ func (i *IE) APNRestriction() (uint8, error) { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustAPNRestriction returns APNRestriction in uint8, ignoring errors. @@ -30,3 +24,8 @@ func (i *IE) MustAPNRestriction() uint8 { v, _ := i.APNRestriction() return v } + +// RestrictionType returns RestrictionType in uint8 if the type of IE matches. +func (i *IE) RestrictionType() (uint8, error) { + return i.APNRestriction() +} diff --git a/gtpv2/ie/apn.go b/gtpv2/ie/apn.go index 520ed063..e49afbdb 100644 --- a/gtpv2/ie/apn.go +++ b/gtpv2/ie/apn.go @@ -4,22 +4,9 @@ package ie -import ( - "strings" -) - // NewAccessPointName creates a new AccessPointName IE. func NewAccessPointName(apn string) *IE { - i := New(AccessPointName, 0x00, make([]byte, len(apn)+1)) - var offset = 0 - for _, label := range strings.Split(apn, ".") { - l := len(label) - i.Payload[offset] = uint8(l) - copy(i.Payload[offset+1:], label) - offset += l + 1 - } - - return i + return NewFQDNIE(AccessPointName, apn) } // AccessPointName returns AccessPointName in string if the type of IE matches. @@ -27,25 +14,7 @@ func (i *IE) AccessPointName() (string, error) { if i.Type != AccessPointName { return "", &InvalidTypeError{Type: i.Type} } - - var ( - apn []string - offset int - ) - max := len(i.Payload) - for { - if offset >= max { - break - } - l := int(i.Payload[offset]) - if offset+l+1 > max { - break - } - apn = append(apn, string(i.Payload[offset+1:offset+l+1])) - offset += l + 1 - } - - return strings.Join(apn, "."), nil + return i.ValueAsFQDN() } // MustAccessPointName returns AccessPointName in string, ignoring errors. diff --git a/gtpv2/ie/arp.go b/gtpv2/ie/arp.go index a4c02380..0525692e 100644 --- a/gtpv2/ie/arp.go +++ b/gtpv2/ie/arp.go @@ -4,8 +4,6 @@ package ie -import "io" - // NewAllocationRetensionPriority creates a new AllocationRetensionPriority IE. func NewAllocationRetensionPriority(pci, pl, pvi uint8) *IE { i := New(AllocationRetensionPriority, 0x00, make([]byte, 1)) @@ -13,15 +11,12 @@ func NewAllocationRetensionPriority(pci, pl, pvi uint8) *IE { return i } +// AllocationRetensionPriority returns AllocationRetensionPriority in uint8 if the type of IE matches. func (i *IE) AllocationRetensionPriority() (uint8, error) { if i.Type != AllocationRetensionPriority { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // HasPVI reports whether an IE has PVI bit. diff --git a/gtpv2/ie/bearer-context.go b/gtpv2/ie/bearer-context.go index be858b62..2834f69b 100644 --- a/gtpv2/ie/bearer-context.go +++ b/gtpv2/ie/bearer-context.go @@ -8,13 +8,7 @@ import "io" // NewBearerContext creates a new BearerContext IE. func NewBearerContext(ies ...*IE) *IE { - var omitted []*IE - for _, ie := range ies { - if ie != nil { - omitted = append(omitted, ie) - } - } - return newGroupedIE(BearerContext, omitted...) + return NewGroupedIE(BearerContext, ies...) } // NewBearerContextWithinCreateBearerRequest creates a new BearerContext used within CreateBearerRequest. diff --git a/gtpv2/ie/bearer-flags.go b/gtpv2/ie/bearer-flags.go index d5ea51ee..a8586807 100644 --- a/gtpv2/ie/bearer-flags.go +++ b/gtpv2/ie/bearer-flags.go @@ -4,11 +4,6 @@ package ie -import ( - "fmt" - "io" -) - // NewBearerFlags creates a new BearerFlags IE. func NewBearerFlags(asi, vInd, vb, ppc uint8) *IE { i := New(BearerFlags, 0x00, make([]byte, 1)) @@ -20,15 +15,11 @@ func NewBearerFlags(asi, vInd, vb, ppc uint8) *IE { func (i *IE) BearerFlags() (uint8, error) { switch i.Type { case BearerFlags: - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() case BearerContext: ies, err := i.BearerContext() if err != nil { - return 0, fmt.Errorf("failed to retrieve BearerFlags: %w", err) + return 0, err } for _, child := range ies { @@ -92,54 +83,38 @@ func (i *IE) HasASI() bool { // ActivityStatusIndicator reports whether the bearer context is preserved in // the CN without corresponding Radio Access Bearer established. func (i *IE) ActivityStatusIndicator() bool { - if len(i.Payload) < 1 { - return false - } - switch i.Type { - case BearerFlags: - return i.Payload[0]&0x08 == 1 - default: + v, err := i.BearerFlags() + if err != nil { return false } + return v&0x08 == 1 } // VSRVCC reports whether this bearer is an IMS video bearer and is candidate // for PS-to-CS vSRVCC handover. func (i *IE) VSRVCC() bool { - if len(i.Payload) < 1 { - return false - } - switch i.Type { - case BearerFlags: - return i.Payload[0]&0x04 == 1 - default: + v, err := i.BearerFlags() + if err != nil { return false } + return v&0x04 == 1 } // VoiceBearer reports whether a voice bearer when doing PS-to-CS (v)SRVCC handover. func (i *IE) VoiceBearer() bool { - if len(i.Payload) < 1 { - return false - } - switch i.Type { - case BearerFlags: - return i.Payload[0]&0x02 == 1 - default: + v, err := i.BearerFlags() + if err != nil { return false } + return v&0x02 == 1 } // ProhibitPayloadCompression reports whether an SGSN should attempt to // compress the payload of user data when the users asks for it to be compressed. func (i *IE) ProhibitPayloadCompression() bool { - if len(i.Payload) < 1 { - return false - } - switch i.Type { - case BearerFlags: - return i.Payload[0]&0x01 == 1 - default: + v, err := i.BearerFlags() + if err != nil { return false } + return v&0x01 == 1 } diff --git a/gtpv2/ie/bearer-qos.go b/gtpv2/ie/bearer-qos.go index ddc1e105..52f57b8c 100644 --- a/gtpv2/ie/bearer-qos.go +++ b/gtpv2/ie/bearer-qos.go @@ -134,10 +134,7 @@ func (i *IE) QCILabel() (uint8, error) { } return i.Payload[1], nil case FlowQoS: - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - return i.Payload[0], nil + return i.ValueAsUint8() default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/cause.go b/gtpv2/ie/cause.go index 65e15e98..a32fbbea 100644 --- a/gtpv2/ie/cause.go +++ b/gtpv2/ie/cause.go @@ -28,11 +28,7 @@ func NewCause(cause uint8, pce, bce, cs uint8, offendingIE *IE) *IE { func (i *IE) Cause() (uint8, error) { switch i.Type { case Cause: - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() case BearerContext: ies, err := i.BearerContext() if err != nil { diff --git a/gtpv2/ie/charging-characteristics.go b/gtpv2/ie/charging-characteristics.go index 48a9f011..5eb5fbf0 100644 --- a/gtpv2/ie/charging-characteristics.go +++ b/gtpv2/ie/charging-characteristics.go @@ -4,14 +4,9 @@ package ie -import ( - "encoding/binary" - "io" -) - // NewChargingCharacteristics creates a new ChargingCharacteristics IE. func NewChargingCharacteristics(chr uint16) *IE { - return newUint16ValIE(ChargingCharacteristics, chr) + return NewUint16IE(ChargingCharacteristics, chr) } // ChargingCharacteristics returns the ChargingCharacteristics value in uint16 if the type of IE matches. @@ -19,11 +14,7 @@ func (i *IE) ChargingCharacteristics() (uint16, error) { if i.Type != ChargingCharacteristics { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 2 { - return 0, io.ErrUnexpectedEOF - } - - return binary.BigEndian.Uint16(i.Payload), nil + return i.ValueAsUint16() } // MustChargingCharacteristics returns ChargingCharacteristics in uint16, ignoring errors. diff --git a/gtpv2/ie/charging-id.go b/gtpv2/ie/charging-id.go index a018c48c..2942f4aa 100644 --- a/gtpv2/ie/charging-id.go +++ b/gtpv2/ie/charging-id.go @@ -4,30 +4,20 @@ package ie -import ( - "encoding/binary" - "fmt" - "io" -) - // NewChargingID creates a new ChargingID IE. func NewChargingID(id uint32) *IE { - return newUint32ValIE(ChargingID, id) + return NewUint32IE(ChargingID, id) } // ChargingID returns the ChargingID value in uint32 if the type of IE matches. func (i *IE) ChargingID() (uint32, error) { switch i.Type { case ChargingID: - if len(i.Payload) < 4 { - return 0, io.ErrUnexpectedEOF - } - - return binary.BigEndian.Uint32(i.Payload[:4]), nil + return i.ValueAsUint32() case BearerContext: ies, err := i.BearerContext() if err != nil { - return 0, fmt.Errorf("failed to retrieve ChargingID: %w", err) + return 0, err } for _, child := range ies { diff --git a/gtpv2/ie/cmi.go b/gtpv2/ie/cmi.go index df3c3221..eb60c4dc 100644 --- a/gtpv2/ie/cmi.go +++ b/gtpv2/ie/cmi.go @@ -8,23 +8,19 @@ import "io" // NewCSGMembershipIndication creates a new CSGMembershipIndication IE. func NewCSGMembershipIndication(cmi uint8) *IE { - return newUint8ValIE(CSGMembershipIndication, cmi) + return NewUint8IE(CSGMembershipIndication, cmi) } // CMI returns CMI in uint8 if the type of IE matches. func (i *IE) CMI() (uint8, error) { - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case CSGMembershipIndication: - return i.Payload[0] & 0x01, nil + return i.ValueAsUint8() case UserCSGInformation: if len(i.Payload) < 8 { return 0, io.ErrUnexpectedEOF } - return i.Payload[7] & 0x01, nil + return i.Payload[7], nil default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/csg-id.go b/gtpv2/ie/csg-id.go index 0e9fbf4f..b0e558a6 100644 --- a/gtpv2/ie/csg-id.go +++ b/gtpv2/ie/csg-id.go @@ -11,18 +11,14 @@ import ( // NewCSGID creates a new CSGID IE. func NewCSGID(id uint32) *IE { - return newUint32ValIE(CSGID, id&0x7ffffff) + return NewUint32IE(CSGID, id&0x7ffffff) } // CSGID returns CSGID in uint32 if the type of IE matches. func (i *IE) CSGID() (uint32, error) { - if len(i.Payload) < 4 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case CSGID: - return binary.BigEndian.Uint32(i.Payload[0:4]) & 0x7ffffff, nil + return i.ValueAsUint32() case UserCSGInformation: if len(i.Payload) < 7 { return 0, io.ErrUnexpectedEOF diff --git a/gtpv2/ie/delay-value.go b/gtpv2/ie/delay-value.go index de0b531f..35b8a76f 100644 --- a/gtpv2/ie/delay-value.go +++ b/gtpv2/ie/delay-value.go @@ -11,10 +11,20 @@ import ( // NewDelayValue creates a new DelayValue IE. func NewDelayValue(delay time.Duration) *IE { - return newUint8ValIE(DelayValue, uint8(delay.Seconds()*1000/50)) + return NewUint8IE(DelayValue, uint8(delay.Seconds()*1000/50)) +} + +// NewDelayValueRaw creates a new DelayValue IE from a uint8 value. +// +// The value should be in multiples of 50ms or zero. +func NewDelayValueRaw(delay uint8) *IE { + return NewUint8IE(DelayValue, delay) } // DelayValue returns DelayValue in time.Duration if the type of IE matches. +// +// The returned value is in time.Duration. To get the value in multiples of 50ms, +// use ValueAsUint8 or access Payload field directly instead. func (i *IE) DelayValue() (time.Duration, error) { if i.Type != DelayValue { return 0, &InvalidTypeError{Type: i.Type} diff --git a/gtpv2/ie/detach-type.go b/gtpv2/ie/detach-type.go index e20f4ae7..ac83cda0 100644 --- a/gtpv2/ie/detach-type.go +++ b/gtpv2/ie/detach-type.go @@ -4,13 +4,9 @@ package ie -import ( - "io" -) - // NewDetachType creates a new DetachType IE. func NewDetachType(dtype uint8) *IE { - return newUint8ValIE(DetachType, dtype) + return NewUint8IE(DetachType, dtype) } // DetachType returns DetachType in uint8 if the type of IE matches. @@ -18,11 +14,7 @@ func (i *IE) DetachType() (uint8, error) { if i.Type != DetachType { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustDetachType returns DetachType in uint8, ignoring errors. diff --git a/gtpv2/ie/ebi.go b/gtpv2/ie/ebi.go index 9fc35a78..4af88615 100644 --- a/gtpv2/ie/ebi.go +++ b/gtpv2/ie/ebi.go @@ -4,29 +4,20 @@ package ie -import ( - "fmt" - "io" -) - // NewEPSBearerID creates a new EPSBearerID IE. func NewEPSBearerID(ebi uint8) *IE { - return newUint8ValIE(EPSBearerID, ebi&0x0f) + return NewUint8IE(EPSBearerID, ebi&0x0f) } // EPSBearerID returns EPSBearerID if the type of IE matches. func (i *IE) EPSBearerID() (uint8, error) { switch i.Type { case EPSBearerID: - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() case BearerContext: ies, err := i.BearerContext() if err != nil { - return 0, fmt.Errorf("failed to retrieve EPSBearerID: %w", err) + return 0, err } for _, child := range ies { diff --git a/gtpv2/ie/epc-timer.go b/gtpv2/ie/epc-timer.go index fe83079e..13897e8d 100644 --- a/gtpv2/ie/epc-timer.go +++ b/gtpv2/ie/epc-timer.go @@ -11,7 +11,7 @@ import ( "time" ) -// NewEPCTimer creates a new Timer IE. +// NewEPCTimer creates a new EPCTimer IE. func NewEPCTimer(duration time.Duration) *IE { // 8.87 EPC Timer // Timer unit @@ -49,10 +49,18 @@ func NewEPCTimer(duration time.Duration) *IE { value = 0 } - return newUint8ValIE(EPCTimer, unit+(value&0x1f)) + return NewUint8IE(EPCTimer, unit+(value&0x1f)) +} + +// NewEPCTimerRaw creates a new EPCTimer IE from a uint8 value. +func NewEPCTimerRaw(duration uint8) *IE { + return NewUint8IE(EPCTimer, duration) } // EPCTimer returns EPCTimer in time.Duration if the type of IE matches. +// +// The returned value is in time.Duration. To get the raw value as uint8, +// use ValueAsUint8 or access Payload field directly instead. func (i *IE) EPCTimer() (time.Duration, error) { return i.Timer() } diff --git a/gtpv2/ie/fqdn.go b/gtpv2/ie/fqdn.go index 95312b55..d20dcf00 100644 --- a/gtpv2/ie/fqdn.go +++ b/gtpv2/ie/fqdn.go @@ -4,22 +4,9 @@ package ie -import ( - "strings" -) - // NewFullyQualifiedDomainName creates a new FullyQualifiedDomainName IE. func NewFullyQualifiedDomainName(fqdn string) *IE { - i := New(FullyQualifiedDomainName, 0x00, make([]byte, len(fqdn)+1)) - var offset = 0 - for _, label := range strings.Split(fqdn, ".") { - l := len(label) - i.Payload[offset] = uint8(l) - copy(i.Payload[offset+1:], label) - offset += l + 1 - } - - return i + return NewFQDNIE(FullyQualifiedDomainName, fqdn) } // FullyQualifiedDomainName returns FullyQualifiedDomainName in string if the type of IE matches. @@ -27,25 +14,7 @@ func (i *IE) FullyQualifiedDomainName() (string, error) { if i.Type != FullyQualifiedDomainName { return "", &InvalidTypeError{Type: i.Type} } - - var ( - fqdn []string - offset int - ) - max := len(i.Payload) - for { - if offset >= max { - break - } - l := int(i.Payload[offset]) - if offset+l+1 > max { - break - } - fqdn = append(fqdn, string(i.Payload[offset+1:offset+l+1])) - offset += l + 1 - } - - return strings.Join(fqdn, "."), nil + return i.ValueAsFQDN() } // MustFullyQualifiedDomainName returns FullyQualifiedDomainName in string, ignoring errors. diff --git a/gtpv2/ie/hop-counter.go b/gtpv2/ie/hop-counter.go index eb572771..6b79cb61 100644 --- a/gtpv2/ie/hop-counter.go +++ b/gtpv2/ie/hop-counter.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewHopCounter creates a new HopCounter IE. func NewHopCounter(hop uint8) *IE { - return newUint8ValIE(HopCounter, hop) + return NewUint8IE(HopCounter, hop) } // HopCounter returns HopCounter in uint8 if the type of IE matches. @@ -16,11 +14,7 @@ func (i *IE) HopCounter() (uint8, error) { if i.Type != HopCounter { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustHopCounter returns HopCounter in uint8, ignoring errors. diff --git a/gtpv2/ie/ie.go b/gtpv2/ie/ie.go index fe2333a6..9e6bdd06 100644 --- a/gtpv2/ie/ie.go +++ b/gtpv2/ie/ie.go @@ -11,6 +11,8 @@ import ( "encoding/binary" "fmt" "io" + + "github.com/wmnsk/go-gtp/utils" ) // IE definitions. @@ -321,90 +323,6 @@ func (i *IE) String() string { ) } -var grouped = []uint8{ - BearerContext, - PDNConnection, - OverloadControlInformation, - LoadControlInformation, - RemoteUEContext, - SCEFPDNConnection, - V2XContext, - PC5QoSParameters, -} - -// IsGrouped reports whether an IE is grouped type or not. -func (i *IE) IsGrouped() bool { - for _, itype := range grouped { - if i.Type == itype { - return true - } - } - return false -} - -// Add adds variable number of IEs to a IE if the IE is grouped type and update length. -// Otherwise, this does nothing(no errors). -func (i *IE) Add(ies ...*IE) { - if !i.IsGrouped() { - return - } - - i.Payload = nil - i.ChildIEs = append(i.ChildIEs, ies...) - for _, ie := range i.ChildIEs { - serialized, err := ie.Marshal() - if err != nil { - continue - } - i.Payload = append(i.Payload, serialized...) - } - i.SetLength() -} - -// Remove removes an IE looked up by type and instance. -func (i *IE) Remove(typ, instance uint8) { - if !i.IsGrouped() { - return - } - - i.Payload = nil - newChildren := make([]*IE, len(i.ChildIEs)) - idx := 0 - for _, ie := range i.ChildIEs { - if ie.Type == typ && ie.Instance() == instance { - newChildren = newChildren[:len(newChildren)-1] - continue - } - newChildren[idx] = ie - idx++ - - serialized, err := ie.Marshal() - if err != nil { - continue - } - i.Payload = append(i.Payload, serialized...) - } - i.ChildIEs = newChildren - i.SetLength() -} - -// FindByType returns IE looked up by type and instance. -// -// The program may be slower when calling this method multiple times -// because this ranges over a ChildIEs each time it is called. -func (i *IE) FindByType(typ, instance uint8) (*IE, error) { - if !i.IsGrouped() { - return nil, ErrInvalidType - } - - for _, ie := range i.ChildIEs { - if ie.Type == typ && ie.Instance() == instance { - return ie, nil - } - } - return nil, ErrIENotFound -} - // ParseMultiIEs decodes multiple IEs at a time. // This is easy and useful but slower than decoding one by one. // When you don't know the number of IEs, this is the only way to decode them. @@ -426,46 +344,106 @@ func ParseMultiIEs(b []byte) ([]*IE, error) { return ies, nil } -func newUint8ValIE(t, v uint8) *IE { +// NewUint8IE creates a new IE with uint8 value. +func NewUint8IE(t, v uint8) *IE { return New(t, 0x00, []byte{v}) } -func newUint16ValIE(t uint8, v uint16) *IE { +// NewUint16IE creates a new IE with uint16 value. +func NewUint16IE(t uint8, v uint16) *IE { i := New(t, 0x00, make([]byte, 2)) binary.BigEndian.PutUint16(i.Payload, v) return i } -func newUint32ValIE(t uint8, v uint32) *IE { +// NewUint32IE creates a new IE with uint32 value. +func NewUint32IE(t uint8, v uint32) *IE { i := New(t, 0x00, make([]byte, 4)) binary.BigEndian.PutUint32(i.Payload, v) return i } -// unused for now. -// func newUint64ValIE(t uint8, v uint64) *IE { -// i := New(t, 0x00, make([]byte, 8)) -// binary.BigEndian.PutUint64(i.Payload, v) -// return i -// } +// NewUint64IE creates a new IE with uint64 value. +func NewUint64IE(t uint8, v uint64) *IE { + i := New(t, 0x00, make([]byte, 8)) + binary.BigEndian.PutUint64(i.Payload, v) + return i +} -func newStringIE(t uint8, v string) *IE { +// NewStringIE creates a new IE with string value. +func NewStringIE(t uint8, v string) *IE { return New(t, 0x00, []byte(v)) } -func newGroupedIE(itype uint8, ies ...*IE) *IE { - i := New(itype, 0x00, make([]byte, 0)) - i.ChildIEs = ies - for _, ie := range i.ChildIEs { - serialized, err := ie.Marshal() - if err != nil { - return nil - } - i.Payload = append(i.Payload, serialized...) +// NewFQDNIE creates a new IE with FQDN value. +func NewFQDNIE(t uint8, v string) *IE { + return New(t, 0x00, utils.EncodeFQDN(v)) +} + +// ValueAsUint8 returns the value of IE as uint8. +func (i *IE) ValueAsUint8() (uint8, error) { + if i.IsGrouped() { + return 0, &InvalidTypeError{Type: i.Type} + } + if len(i.Payload) < 1 { + return 0, io.ErrUnexpectedEOF } - i.SetLength() - return i + return i.Payload[0], nil +} + +// ValueAsUint16 returns the value of IE as uint16. +func (i *IE) ValueAsUint16() (uint16, error) { + if i.IsGrouped() { + return 0, &InvalidTypeError{Type: i.Type} + } + if len(i.Payload) < 2 { + return 0, io.ErrUnexpectedEOF + } + + return binary.BigEndian.Uint16(i.Payload[0:2]), nil +} + +// ValueAsUint32 returns the value of IE as uint32. +func (i *IE) ValueAsUint32() (uint32, error) { + if i.IsGrouped() { + return 0, &InvalidTypeError{Type: i.Type} + } + if len(i.Payload) < 4 { + return 0, io.ErrUnexpectedEOF + } + + return binary.BigEndian.Uint32(i.Payload[0:4]), nil +} + +// ValueAsUint64 returns the value of IE as uint64. +func (i *IE) ValueAsUint64() (uint64, error) { + if i.IsGrouped() { + return 0, &InvalidTypeError{Type: i.Type} + } + if len(i.Payload) < 8 { + return 0, io.ErrUnexpectedEOF + } + + return binary.BigEndian.Uint64(i.Payload[0:8]), nil +} + +// ValueAsString returns the value of IE as string. +func (i *IE) ValueAsString() (string, error) { + if i.IsGrouped() { + return "", &InvalidTypeError{Type: i.Type} + } + + return string(i.Payload), nil +} + +// ValueAsFQDN returns the value of IE as string, decoded as FQDN. +func (i *IE) ValueAsFQDN() (string, error) { + if i.IsGrouped() { + return "", &InvalidTypeError{Type: i.Type} + } + + return utils.DecodeFQDN(i.Payload), nil } var ieTypeNameMap = map[uint8]string{ diff --git a/gtpv2/ie/ie_grouped.go b/gtpv2/ie/ie_grouped.go new file mode 100644 index 00000000..0adae7e1 --- /dev/null +++ b/gtpv2/ie/ie_grouped.go @@ -0,0 +1,148 @@ +// Copyright 2019-2023 go-gtp authors. All rights reserved. +// Use of this source code is governed by a MIT-style license that can be +// found in the LICENSE file. + +package ie + +import "sync" + +// NewGroupedIE creates a new IE with the given IEs. +// +// The IEs with nil value will be ignored. +func NewGroupedIE(itype uint8, ies ...*IE) *IE { + i := New(itype, 0x00, make([]byte, 0)) + for _, ie := range ies { + if ie == nil { + continue + } + + serialized, err := ie.Marshal() + if err != nil { + return nil + } + + i.Payload = append(i.Payload, serialized...) + i.ChildIEs = append(i.ChildIEs, ie) + } + i.SetLength() + + return i +} + +// We're using map to avoid iterating over a list. +// The value `true` is not actually used. +// TODO: consider using a slice with utils in slices package introduced in Go 1.21. +var ( + mu sync.RWMutex + defaultGroupedIEMap = map[uint8]bool{ + BearerContext: true, + PDNConnection: true, + OverloadControlInformation: true, + LoadControlInformation: true, + RemoteUEContext: true, + SCEFPDNConnection: true, + V2XContext: true, + PC5QoSParameters: true, + } + isGroupedFun = func(t uint8) bool { + mu.RLock() + defer mu.RUnlock() + _, ok := defaultGroupedIEMap[t] + return ok + } +) + +// SetIsGroupedFun sets a function to check if an IE is of grouped type or not. +func SetIsGroupedFun(fun func(t uint8) bool) { + mu.Lock() + defer mu.Unlock() + isGroupedFun = fun +} + +// AddGroupedIEType adds IE type(s) to the defaultGroupedIEMap. +// This is useful when you want to add new IE types to the defaultGroupedIEMap. +// +// See also SetIsGroupedFun(). +func AddGroupedIEType(ts ...uint8) { + mu.Lock() + defer mu.Unlock() + for _, t := range ts { + defaultGroupedIEMap[t] = true + } +} + +// IsGrouped reports whether an IE is grouped type or not. +// +// By default, this package determines if an IE is grouped type or not by checking +// if the IE type is in the defaultGroupedIEMap. +// You can change this entire behavior by calling SetIsGroupedFun(), or you can add +// new IE types to the defaultGroupedIEMap by calling AddGroupedIEType(). +func (i *IE) IsGrouped() bool { + return isGroupedFun(i.Type) +} + +// Add adds variable number of IEs to a grouped IE and update length of it. +// This does nothing if the type of the IE is not grouped (no errors). +func (i *IE) Add(ies ...*IE) { + if !i.IsGrouped() { + return + } + + for _, ie := range ies { + if ie == nil { + continue + } + i.ChildIEs = append(i.ChildIEs, ie) + + serialized, err := ie.Marshal() + if err != nil { + continue + } + i.Payload = append(i.Payload, serialized...) + } + i.SetLength() +} + +// Remove removes an IE looked up by type and instance. +func (i *IE) Remove(typ, instance uint8) { + if !i.IsGrouped() { + return + } + + i.Payload = nil + newChildren := make([]*IE, len(i.ChildIEs)) + idx := 0 + for _, ie := range i.ChildIEs { + if ie.Type == typ && ie.Instance() == instance { + newChildren = newChildren[:len(newChildren)-1] + continue + } + newChildren[idx] = ie + idx++ + + serialized, err := ie.Marshal() + if err != nil { + continue + } + i.Payload = append(i.Payload, serialized...) + } + i.ChildIEs = newChildren + i.SetLength() +} + +// FindByType returns IE looked up by type and instance. +// +// The program may be slower when calling this method multiple times +// because this ranges over a ChildIEs each time it is called. +func (i *IE) FindByType(typ, instance uint8) (*IE, error) { + if !i.IsGrouped() { + return nil, ErrInvalidType + } + + for _, ie := range i.ChildIEs { + if ie.Type == typ && ie.Instance() == instance { + return ie, nil + } + } + return nil, ErrIENotFound +} diff --git a/gtpv2/ie/ie_test.go b/gtpv2/ie/ie_test.go index d67b2f5b..f2f40209 100644 --- a/gtpv2/ie/ie_test.go +++ b/gtpv2/ie/ie_test.go @@ -727,6 +727,7 @@ func TestIEAddRemove(t *testing.T) { ie.NewMSISDN("819012345678"), ) i.Add(ie.NewAccessPointName("foo.example")) + i.Add(nil) // ignored added := ie.NewBearerContext( ie.NewIMSI("123451234567890").WithInstance(1), diff --git a/gtpv2/ie/integer-number.go b/gtpv2/ie/integer-number.go index b85872ec..259e2c24 100644 --- a/gtpv2/ie/integer-number.go +++ b/gtpv2/ie/integer-number.go @@ -4,28 +4,16 @@ package ie -import ( - "encoding/binary" - "io" -) - // NewIntegerNumber creates a new IntegerNumber IE. -func NewIntegerNumber(port uint16) *IE { - return newUint16ValIE(IntegerNumber, port) +func NewIntegerNumber(num uint16) *IE { + return NewUint16IE(IntegerNumber, num) } // IntegerNumber returns IntegerNumber in uint16 if the type of IE matches. func (i *IE) IntegerNumber() (uint16, error) { - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case IntegerNumber: - if len(i.Payload) < 2 { - return 0, io.ErrUnexpectedEOF - } - return binary.BigEndian.Uint16(i.Payload[0:2]), nil + return i.ValueAsUint16() default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/ldn.go b/gtpv2/ie/ldn.go index a532f1e9..dbaf08d0 100644 --- a/gtpv2/ie/ldn.go +++ b/gtpv2/ie/ldn.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewLocalDistinguishedName creates a new LocalDistinguishedName IE. func NewLocalDistinguishedName(name string) *IE { - return newStringIE(LocalDistinguishedName, name) + return NewStringIE(LocalDistinguishedName, name) } // LocalDistinguishedName returns LocalDistinguishedName in string if the type of IE matches. @@ -16,11 +14,7 @@ func (i *IE) LocalDistinguishedName() (string, error) { if i.Type != LocalDistinguishedName { return "", &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return "", io.ErrUnexpectedEOF - } - - return string(i.Payload), nil + return i.ValueAsString() } // MustLocalDistinguishedName returns LocalDistinguishedName in string, ignoring errors. diff --git a/gtpv2/ie/mbms-flags.go b/gtpv2/ie/mbms-flags.go index 3aaf5ed5..1af2f876 100644 --- a/gtpv2/ie/mbms-flags.go +++ b/gtpv2/ie/mbms-flags.go @@ -4,8 +4,6 @@ package ie -import "io" - // NewMBMSFlags creates a new MBMSFlags IE. func NewMBMSFlags(lmri, msri uint8) *IE { i := New(MBMSFlags, 0x00, make([]byte, 1)) @@ -18,11 +16,7 @@ func (i *IE) MBMSFlags() (uint8, error) { if i.Type != MBMSFlags { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustMBMSFlags returns MBMSFlags in uint8, ignoring errors. @@ -55,27 +49,19 @@ func (i *IE) HasLMRI() bool { // LocalMBMSBearerContextRelease reports whether the MBMS Session Stop Request // message is used to release the MBMS Bearer Context locally in the MME/SGSN. func (i *IE) LocalMBMSBearerContextRelease() bool { - if len(i.Payload) < 1 { - return false - } - switch i.Type { - case MBMSFlags: - return i.Payload[0]&0x02 == 1 - default: + v, err := i.MBMSFlags() + if err != nil { return false } + return v&0x02 == 1 } // MBMSSessionReEstablishment reports whether the MBMS Session Start Request // message is used to re-establish an MBMS session. func (i *IE) MBMSSessionReEstablishment() bool { - if len(i.Payload) < 1 { - return false - } - switch i.Type { - case MBMSFlags: - return i.Payload[0]&0x01 == 1 - default: + v, err := i.MBMSFlags() + if err != nil { return false } + return v&0x01 == 1 } diff --git a/gtpv2/ie/node-features.go b/gtpv2/ie/node-features.go index eb2ceb32..9a8f3110 100644 --- a/gtpv2/ie/node-features.go +++ b/gtpv2/ie/node-features.go @@ -4,13 +4,9 @@ package ie -import ( - "io" -) - // NewNodeFeatures creates a new NodeFeatures IE. func NewNodeFeatures(flags uint8) *IE { - return newUint8ValIE(NodeFeatures, flags) + return NewUint8IE(NodeFeatures, flags) } // NodeFeatures returns NodeFeatures in uint8 if the type of IE matches. @@ -18,12 +14,7 @@ func (i *IE) NodeFeatures() (uint8, error) { if i.Type != NodeFeatures { return 0, &InvalidTypeError{Type: i.Type} } - - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustNodeFeatures returns NodeFeatures in uint8 if the type of IE matches. diff --git a/gtpv2/ie/node-type.go b/gtpv2/ie/node-type.go index 1ce4c5ae..ab39a106 100644 --- a/gtpv2/ie/node-type.go +++ b/gtpv2/ie/node-type.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewNodeType creates a new NodeType IE. func NewNodeType(nodeType uint8) *IE { - return newUint8ValIE(NodeType, nodeType) + return NewUint8IE(NodeType, nodeType) } // NodeType returns NodeType in uint8 if the type of IE matches. @@ -16,11 +14,7 @@ func (i *IE) NodeType() (uint8, error) { if i.Type != NodeType { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustNodeType returns NodeType in uint8, ignoring errors. diff --git a/gtpv2/ie/p-tmsi.go b/gtpv2/ie/p-tmsi.go index 48bcb459..f23ab824 100644 --- a/gtpv2/ie/p-tmsi.go +++ b/gtpv2/ie/p-tmsi.go @@ -4,14 +4,9 @@ package ie -import ( - "encoding/binary" - "io" -) - // NewPacketTMSI creates a new PacketTMSI IE. func NewPacketTMSI(ptmsi uint32) *IE { - return newUint32ValIE(PacketTMSI, ptmsi) + return NewUint32IE(PacketTMSI, ptmsi) } // PacketTMSI returns PacketTMSI value in uint32 if type matches. @@ -19,11 +14,7 @@ func (i *IE) PacketTMSI() (uint32, error) { if i.Type != PacketTMSI { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 4 { - return 0, io.ErrUnexpectedEOF - } - - return binary.BigEndian.Uint32(i.Payload), nil + return i.ValueAsUint32() } // MustPacketTMSI returns PacketTMSI in uint32, ignoring errors. diff --git a/gtpv2/ie/pdn-type.go b/gtpv2/ie/pdn-type.go index db9da962..06b7133b 100644 --- a/gtpv2/ie/pdn-type.go +++ b/gtpv2/ie/pdn-type.go @@ -4,22 +4,16 @@ package ie -import "io" - // NewPDNType creates a new PDNType IE. func NewPDNType(pdn uint8) *IE { - return newUint8ValIE(PDNType, pdn) + return NewUint8IE(PDNType, pdn) } // PDNType returns the PDNType value in uint8 if the type of IE matches. func (i *IE) PDNType() (uint8, error) { - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case PDNType, PDNAddressAllocation: - return i.Payload[0], nil + return i.ValueAsUint8() default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/port-number.go b/gtpv2/ie/port-number.go index 0592e566..3fb9aab5 100644 --- a/gtpv2/ie/port-number.go +++ b/gtpv2/ie/port-number.go @@ -4,28 +4,16 @@ package ie -import ( - "encoding/binary" - "io" -) - // NewPortNumber creates a new PortNumber IE. func NewPortNumber(port uint16) *IE { - return newUint16ValIE(PortNumber, port) + return NewUint16IE(PortNumber, port) } // PortNumber returns PortNumber in uint16 if the type of IE matches. func (i *IE) PortNumber() (uint16, error) { - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case PortNumber: - if len(i.Payload) < 2 { - return 0, io.ErrUnexpectedEOF - } - return binary.BigEndian.Uint16(i.Payload[0:2]), nil + return i.ValueAsUint16() default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/pti.go b/gtpv2/ie/pti.go index a3d2fd1d..e647c161 100644 --- a/gtpv2/ie/pti.go +++ b/gtpv2/ie/pti.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewProcedureTransactionID creates a new ProcedureTransactionID IE. func NewProcedureTransactionID(pti uint8) *IE { - return newUint8ValIE(ProcedureTransactionID, pti) + return NewUint8IE(ProcedureTransactionID, pti) } // ProcedureTransactionID returns ProcedureTransactionID in uint8 if the type of IE matches. @@ -16,11 +14,7 @@ func (i *IE) ProcedureTransactionID() (uint8, error) { if i.Type != ProcedureTransactionID { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustProcedureTransactionID returns ProcedureTransactionID in uint8, ignoring errors. diff --git a/gtpv2/ie/rat-type.go b/gtpv2/ie/rat-type.go index 0b70d767..237203ef 100644 --- a/gtpv2/ie/rat-type.go +++ b/gtpv2/ie/rat-type.go @@ -4,22 +4,16 @@ package ie -import "io" - // NewRATType creates a new RATType IE. func NewRATType(rat uint8) *IE { - return newUint8ValIE(RATType, rat) + return NewUint8IE(RATType, rat) } // RATType returns RATType in uint8 if the type of IE matches. func (i *IE) RATType() (uint8, error) { - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case RATType: - return i.Payload[0], nil + return i.ValueAsUint8() default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/recovery.go b/gtpv2/ie/recovery.go index cf09e4a7..55f09ce6 100644 --- a/gtpv2/ie/recovery.go +++ b/gtpv2/ie/recovery.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewRecovery creates a new Recovery IE. func NewRecovery(recovery uint8) *IE { - return newUint8ValIE(Recovery, recovery) + return NewUint8IE(Recovery, recovery) } // Recovery returns Recovery value if the type of IE matches. @@ -16,10 +14,7 @@ func (i *IE) Recovery() (uint8, error) { if i.Type != Recovery { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - return i.Payload[0], nil + return i.ValueAsUint8() } // MustRecovery returns Recovery in uint8, ignoring errors. diff --git a/gtpv2/ie/rfsp-index.go b/gtpv2/ie/rfsp-index.go index 389ab3e2..5fae6b18 100644 --- a/gtpv2/ie/rfsp-index.go +++ b/gtpv2/ie/rfsp-index.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewRFSPIndex creates a new RFSPIndex IE. func NewRFSPIndex(idx uint8) *IE { - return newUint8ValIE(RFSPIndex, idx) + return NewUint8IE(RFSPIndex, idx) } // RFSPIndex returns RFSPIndex in uint8 if the type of IE matches. @@ -16,11 +14,7 @@ func (i *IE) RFSPIndex() (uint8, error) { if i.Type != RFSPIndex { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustRFSPIndex returns RFSPIndex in uint8, ignoring errors. diff --git a/gtpv2/ie/selection-mode.go b/gtpv2/ie/selection-mode.go index 658c58c9..7e9f7864 100644 --- a/gtpv2/ie/selection-mode.go +++ b/gtpv2/ie/selection-mode.go @@ -4,14 +4,9 @@ package ie -import "io" - // NewSelectionMode creates a new SelectionMode IE. -// -// Note that exactly one of the parameters should be set to true. -// Otherwise, you'll get the unexpected result. func NewSelectionMode(mode uint8) *IE { - return newUint8ValIE(SelectionMode, mode) + return NewUint8IE(SelectionMode, mode) } // SelectionMode returns SelectionMode value if the type of IE matches. @@ -19,11 +14,7 @@ func (i *IE) SelectionMode() (uint8, error) { if i.Type != SelectionMode { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustSelectionMode returns SelectionMode in uint8, ignoring errors. diff --git a/gtpv2/ie/service-indicator.go b/gtpv2/ie/service-indicator.go index 6d3e2fd8..71460b6c 100644 --- a/gtpv2/ie/service-indicator.go +++ b/gtpv2/ie/service-indicator.go @@ -4,11 +4,9 @@ package ie -import "io" - // NewServiceIndicator creates a new ServiceIndicator IE. func NewServiceIndicator(ind uint8) *IE { - return newUint8ValIE(ServiceIndicator, ind) + return NewUint8IE(ServiceIndicator, ind) } // ServiceIndicator returns ServiceIndicator in uint8 if the type of IE matches. @@ -16,11 +14,7 @@ func (i *IE) ServiceIndicator() (uint8, error) { if i.Type != ServiceIndicator { return 0, &InvalidTypeError{Type: i.Type} } - if len(i.Payload) < 1 { - return 0, io.ErrUnexpectedEOF - } - - return i.Payload[0], nil + return i.ValueAsUint8() } // MustServiceIndicator returns ServiceIndicator in uint8, ignoring errors. diff --git a/gtpv2/ie/tmsi.go b/gtpv2/ie/tmsi.go index 7bd7da1e..6b07c395 100644 --- a/gtpv2/ie/tmsi.go +++ b/gtpv2/ie/tmsi.go @@ -4,25 +4,16 @@ package ie -import ( - "encoding/binary" - "io" -) - // NewTMSI creates a new TMSI IE. func NewTMSI(tmsi uint32) *IE { - return newUint32ValIE(TMSI, tmsi) + return NewUint32IE(TMSI, tmsi) } // TMSI returns TMSI in uint32 if the type of IE matches. func (i *IE) TMSI() (uint32, error) { - if len(i.Payload) < 4 { - return 0, io.ErrUnexpectedEOF - } - switch i.Type { case TMSI: - return binary.BigEndian.Uint32(i.Payload), nil + return i.ValueAsUint32() default: return 0, &InvalidTypeError{Type: i.Type} } diff --git a/gtpv2/ie/uli-timestamp.go b/gtpv2/ie/uli-timestamp.go index 6e873f85..911de400 100644 --- a/gtpv2/ie/uli-timestamp.go +++ b/gtpv2/ie/uli-timestamp.go @@ -13,7 +13,7 @@ import ( // NewULITimestamp creates a new ULITimestamp IE. func NewULITimestamp(ts time.Time) *IE { u64sec := uint64(ts.Sub(time.Date(1900, time.January, 1, 0, 0, 0, 0, time.UTC))) / 1000000000 - return newUint32ValIE(ULITimestamp, uint32(u64sec)) + return NewUint32IE(ULITimestamp, uint32(u64sec)) } // Timestamp returns Timestamp in time.Time if the type of IE matches. diff --git a/utils/utils.go b/utils/utils.go index 1221d961..9f4fd390 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -8,6 +8,7 @@ package utils import ( "encoding/binary" "encoding/hex" + "strings" ) // StrToSwappedBytes returns swapped bits from a byte. @@ -142,3 +143,42 @@ func ParseECI(eci uint32) (enbID uint32, cellID uint8, err error) { enbID = binary.BigEndian.Uint32([]byte{0, buf[0], buf[1], buf[2]}) return } + +// EncodeFQDN encodes the given string as the Name Syntax defined +// in RFC 2181, RFC 1035 and RFC 1123. +func EncodeFQDN(fqdn string) []byte { + b := make([]byte, len(fqdn)+1) + + var offset = 0 + for _, label := range strings.Split(fqdn, ".") { + l := len(label) + b[offset] = uint8(l) + copy(b[offset+1:], label) + offset += l + 1 + } + + return b +} + +// DecodeFQDN decodes the given Name Syntax-encoded []byte as a string. +func DecodeFQDN(b []byte) string { + var ( + fqdn []string + offset int + ) + + max := len(b) + for { + if offset >= max { + break + } + l := int(b[offset]) + if offset+l+1 > max { + break + } + fqdn = append(fqdn, string(b[offset+1:offset+l+1])) + offset += l + 1 + } + + return strings.Join(fqdn, ".") +}