diff --git a/pkg/network/ebpf/c/protocols/classification/defs.h b/pkg/network/ebpf/c/protocols/classification/defs.h index 46dbfd8e7df51..235d638d2f5b5 100644 --- a/pkg/network/ebpf/c/protocols/classification/defs.h +++ b/pkg/network/ebpf/c/protocols/classification/defs.h @@ -76,7 +76,7 @@ typedef enum { // Each `protocol_t` entry is implicitly associated to a single // `protocol_layer_t` value (see notes above). // -//In order to determine which `protocol_layer_t` a `protocol_t` belongs to, +// In order to determine which `protocol_layer_t` a `protocol_t` belongs to, // users can call `get_protocol_layer` typedef enum { LAYER_UNKNOWN, @@ -103,7 +103,7 @@ typedef struct { // `protocol_stack_t` is embedded in the `conn_stats_t` type, which is used // across the whole NPM kernel code. If we added the 64-bit timestamp field // directly to `protocol_stack_t`, we would go from 4 bytes to 12 bytes, which -// bloats the eBPF stack size of some NPM probes. Using the wrapper type +// bloats the eBPF stack size of some NPM probes. Using the wrapper type // prevents that, because we pretty much only store the wrapper type in the // connection_protocol map, but elsewhere in the code we're still using // protocol_stack_t, so this is change is "transparent" to most of the code. 
@@ -123,6 +123,8 @@ typedef enum { CLASSIFICATION_GRPC_PROG, __PROG_ENCRYPTION, // Encryption classification programs go here + CLASSIFICATION_TLS_CLIENT_PROG, + CLASSIFICATION_TLS_SERVER_PROG, CLASSIFICATION_PROG_MAX, } classification_prog_t; diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index e3f44bfa16bdc..865288731ac7f 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -160,15 +160,26 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct protocol_t app_layer_proto = get_protocol_from_stack(protocol_stack, LAYER_APPLICATION); - if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(buffer, usm_ctx->buffer.size, skb_info.data_end)) { + tls_record_header_t tls_hdr = {0}; + + if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, skb_info.data_end, &tls_hdr)) { protocol_stack = get_or_create_protocol_stack(&usm_ctx->tuple); if (!protocol_stack) { return; } // TLS classification - update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); - // The connection is TLS encrypted, thus we cannot classify the protocol - // using the socket filter and therefore we can bail out; + if (tls_hdr.content_type != TLS_HANDSHAKE) { + // We can't classify TLS encrypted traffic further, so return early + update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); + return; + } + + // Parse TLS handshake payload + tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); + if (tags) { + // The packet is a TLS handshake, so trigger tail calls to extract metadata from the payload + goto next_program; + } return; } @@ -200,6 +211,58 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct 
classification_next_program(skb, usm_ctx);
 }
 
+__maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_handshake_client(struct __sk_buff *skb) {
+    usm_context_t *usm_ctx = usm_context(skb);
+    if (!usm_ctx) {
+        return;
+    }
+    tls_info_t *tls_info = get_tls_enhanced_tags(&usm_ctx->tuple);
+    if (!tls_info) {
+        goto next_program; // no tls_info for this tuple: nothing to populate in this program
+    }
+    __u32 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); // first byte after the record header
+    __u32 data_end = usm_ctx->skb_info.data_end;
+    if (!is_tls_handshake_client_hello(skb, offset, data_end)) {
+        goto next_program; // not a ClientHello: let the next classification program look at it
+    }
+    if (!parse_client_hello(skb, offset, data_end, tls_info)) {
+        return; // ClientHello could not be parsed: stop processing this packet
+    }
+
+next_program:
+    classification_next_program(skb, usm_ctx);
+}
+
+__maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_handshake_server(struct __sk_buff *skb) {
+    usm_context_t *usm_ctx = usm_context(skb);
+    if (!usm_ctx) {
+        return;
+    }
+    tls_info_t *tls_info = get_tls_enhanced_tags(&usm_ctx->tuple);
+    if (!tls_info) {
+        goto next_program; // no tls_info for this tuple: nothing to populate in this program
+    }
+    __u32 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); // first byte after the record header
+    __u32 data_end = usm_ctx->skb_info.data_end;
+    if (!is_tls_handshake_server_hello(skb, offset, data_end)) {
+        goto next_program; // not a ServerHello: let the next classification program look at it
+    }
+    if (!parse_server_hello(skb, offset, data_end, tls_info)) {
+        return; // ServerHello could not be parsed: stop processing this packet
+    }
+
+    protocol_stack_t *protocol_stack = get_protocol_stack_if_exists(&usm_ctx->tuple);
+    if (!protocol_stack) {
+        return;
+    }
+    update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); // ServerHello parsed: record TLS on the stack
+    // We can't classify TLS encrypted traffic further, so return early
+    return;
+
+next_program:
+    classification_next_program(skb, usm_ctx);
+}
+
 __maybe_unused static __always_inline void protocol_classifier_entrypoint_queues(struct __sk_buff *skb) {
     usm_context_t *usm_ctx = usm_context(skb);
     if (!usm_ctx) {
diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h
index 9e2ba628851c6..d7565977eb36d
100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -18,29 +18,7 @@ static __always_inline bool has_available_program(classification_prog_t current_ return true; } -#pragma clang diagnostic push -// The following check is ignored because *currently* there are no API or -// Encryption classification programs registerd. -// Therefore the enum containing all BPF programs looks like the following: -// -// typedef enum { -// CLASSIFICATION_PROG_UNKNOWN = 0, -// __PROG_APPLICATION, -// APPLICATION_PROG_A -// APPLICATION_PROG_B -// APPLICATION_PROG_C -// ... -// __PROG_API, -// // No programs here -// __PROG_ENCRYPTION, -// // No programs here -// CLASSIFICATION_PROG_MAX, -// } classification_prog_t; -// -// Which means that the following conditionals will always evaluate to false: -// a) current_program > __PROG_API && current_program < __PROG_ENCRYPTION -// b) current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX -#pragma clang diagnostic ignored "-Wtautological-overlap-compare" +// get_current_program_layer returns the layer bit of the current program static __always_inline u16 get_current_program_layer(classification_prog_t current_program) { if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { return LAYER_APPLICATION_BIT; @@ -56,20 +34,19 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre return 0; } -#pragma clang diagnostic pop static __always_inline classification_prog_t next_layer_entrypoint(usm_context_t *usm_ctx) { u16 to_skip = usm_ctx->routing_skip_layers; + if (!(to_skip&LAYER_ENCRYPTION_BIT)) { + return __PROG_ENCRYPTION+1; + } if (!(to_skip&LAYER_APPLICATION_BIT)) { return __PROG_APPLICATION+1; } if (!(to_skip&LAYER_API_BIT)) { return __PROG_API+1; } - if (!(to_skip&LAYER_ENCRYPTION_BIT)) { - return __PROG_ENCRYPTION+1; - } return CLASSIFICATION_PROG_UNKNOWN; } diff --git 
a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h
index 0f7f2cdf9ee56..813c548360a9a 100644
--- a/pkg/network/ebpf/c/protocols/tls/tls.h
+++ b/pkg/network/ebpf/c/protocols/tls/tls.h
@@ -1,9 +1,9 @@
 #ifndef __TLS_H
 #define __TLS_H
 
-#include "ktypes.h"
-#include "bpf_builtins.h"
+#include "tracer/tracer.h"
 
+// TLS version constants (SSL versions are deprecated, included for completeness)
 #define SSL_VERSION20 0x0200
 #define SSL_VERSION30 0x0300
 #define TLS_VERSION10 0x0301
@@ -11,122 +11,612 @@
 #define TLS_VERSION12 0x0303
 #define TLS_VERSION13 0x0304
 
-#define TLS_HANDSHAKE 0x16
-#define TLS_APPLICATION_DATA 0x17
+// TLS Content Types (https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer)
+#define TLS_HANDSHAKE 0x16
+#define TLS_APPLICATION_DATA 0x17
+#define TLS_CHANGE_CIPHER_SPEC 0x14
+#define TLS_ALERT 0x15
 
-/* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */
+// TLS Handshake Types
+#define TLS_HANDSHAKE_CLIENT_HELLO 0x01
+#define TLS_HANDSHAKE_SERVER_HELLO 0x02
+
+// Bitmask constants for offered versions (parenthesized so they stay intact inside larger expressions)
+#define TLS_VERSION10_BIT (1 << 0)
+#define TLS_VERSION11_BIT (1 << 1)
+#define TLS_VERSION12_BIT (1 << 2)
+#define TLS_VERSION13_BIT (1 << 3)
+// Maximum number of extensions to parse when looking for SUPPORTED_VERSIONS_EXTENSION
+#define MAX_EXTENSIONS 16
+// The supported_versions extension for TLS 1.3 is described in RFC 8446 Section 4.2.1
+#define SUPPORTED_VERSIONS_EXTENSION 0x002B
+
+// Maximum TLS record payload size (16 KB)
 #define TLS_MAX_PAYLOAD_LENGTH (1 << 14)
 
-// TLS record layer header structure
+// The following field lengths and message formats are defined by the TLS specifications
+// For TLS 1.2 (and earlier) see:
+// RFC 5246 - The Transport Layer Security (TLS) Protocol Version 1.2
+// https://tools.ietf.org/html/rfc5246
+// Particularly Section 7.4 details handshake messages and their fields, and Section 6.2.1
+// covers the TLS record layer.
+// For TLS 1.3, see: +// RFC 8446 - The Transport Layer Security (TLS) Protocol Version 1.3 +// https://tools.ietf.org/html/rfc8446 +// Many handshake structures are similar, but some extensions (like supported_versions) are defined here +#define TLS_HANDSHAKE_LENGTH 3 // Handshake length is 3 bytes (RFC 5246 Section 7.4) +#define TLS_HELLO_MESSAGE_HEADER_SIZE 4 // handshake_type(1) + length(3) +#define RANDOM_LENGTH 32 // Random field length in Client/Server Hello (RFC 5246 Section 7.4.1.2) +#define PROTOCOL_VERSION_LENGTH 2 // Protocol version field is 2 bytes (RFC 5246 Section 6.2.1) +#define SESSION_ID_LENGTH 1 // Session ID length field is 1 byte (RFC 5246 Section 7.4.1.2) +#define CIPHER_SUITES_LENGTH 2 // Cipher Suites length field is 2 bytes (RFC 5246 Section 7.4.1.2) +#define COMPRESSION_METHODS_LENGTH 1 // Compression Methods length field is 1 byte (RFC 5246 Section 7.4.1.2) +#define EXTENSION_TYPE_LENGTH 2 // Extension Type field is 2 bytes (RFC 5246 Section 7.4.1.4) +#define EXTENSION_LENGTH_FIELD 2 // Extension Length field is 2 bytes (RFC 5246 Section 7.4.1.4) + +// For single-byte fields (list lengths, etc.) 
+#define SINGLE_BYTE_LENGTH 1 + +// Minimum extension header length = Extension Type (2 bytes) + Extension Length (2 bytes) = 4 bytes +#define MIN_EXTENSION_HEADER_LENGTH (EXTENSION_TYPE_LENGTH + EXTENSION_LENGTH_FIELD) + +// Maximum number of supported versions we unroll for (all TLS versions) +#define MAX_SUPPORTED_VERSIONS 4 + +// TLS record layer header structure (RFC 5246) typedef struct { __u8 content_type; __u16 version; __u16 length; } __attribute__((packed)) tls_record_header_t; -typedef struct { +// Checks if the TLS version is valid +static __always_inline bool is_valid_tls_version(__u16 version) { + switch (version) { + case SSL_VERSION20: + case SSL_VERSION30: + case TLS_VERSION10: + case TLS_VERSION11: + case TLS_VERSION12: + case TLS_VERSION13: + return true; + default: + return false; + } +} + +// set_tls_offered_version sets the bit corresponding to the offered version in the offered_versions field of tls_info +static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 version) { + switch (version) { + case TLS_VERSION10: + tls_info->offered_versions |= TLS_VERSION10_BIT; + break; + case TLS_VERSION11: + tls_info->offered_versions |= TLS_VERSION11_BIT; + break; + case TLS_VERSION12: + tls_info->offered_versions |= TLS_VERSION12_BIT; + break; + case TLS_VERSION13: + tls_info->offered_versions |= TLS_VERSION13_BIT; + break; + default: + break; + } +} + +// TLS Record Header (RFC 5246 Section 6.2.1) +// +// +---------+---------+---------+-----------+ +// | type(1) | version(2) | length(2) | +// +---------+---------+---------+-----------+ +// type: 1 byte (TLS_CONTENT_TYPE) +// version: 2 bytes (e.g., 0x03 0x03 for TLS 1.2) +// length: 2 bytes (total number of payload bytes following this header) + +// read_tls_record_header reads the TLS record header from the packet +// Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 +// Validates the record header fields (content_type, version, 
length) and checks for correctness within packet bounds. +static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u32 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { + // Ensure there's enough space for TLS record header + if (header_offset + sizeof(tls_record_header_t) > data_end) { + return false; + } + + // Read TLS record header + if (bpf_skb_load_bytes(skb, header_offset, tls_hdr, sizeof(tls_record_header_t)) < 0) { + return false; + } + + // Convert fields to host byte order + tls_hdr->version = bpf_ntohs(tls_hdr->version); + tls_hdr->length = bpf_ntohs(tls_hdr->length); + + // Validate version and length + if (!is_valid_tls_version(tls_hdr->version)) { + return false; + } + if (tls_hdr->length > TLS_MAX_PAYLOAD_LENGTH) { + return false; + } + + // Ensure we don't read beyond the packet + return header_offset + sizeof(tls_record_header_t) + tls_hdr->length <= data_end; +} + +// TLS Handshake Message Header (RFC 5246 Section 7.4) +// +---------+---------+---------+---------+ +// | handshake_type(1) | length(3 bytes) | +// +---------+---------+---------+---------+ +// +// The handshake_type identifies the handshake message (e.g., ClientHello, ServerHello). +// length indicates the size of the handshake message that follows (not including these 4 bytes). + +// is_valid_tls_handshake checks if the TLS handshake message is valid +// The function expects the record to have already been validated. It further checks that the +// handshake_type and handshake_length are consistent. 
+static __always_inline bool is_valid_tls_handshake(struct __sk_buff *skb, __u32 header_offset, __u32 data_end, const tls_record_header_t *hdr) { + // At this point, we know from read_tls_record_header() that: + // - hdr->version is a valid TLS version + // - hdr->length fits entirely within the packet (header_offset + hdr->length <= data_end) + + __u32 handshake_offset = header_offset + sizeof(tls_record_header_t); + + // Ensure we don't read beyond the packet + if (handshake_offset + SINGLE_BYTE_LENGTH > data_end) { + return false; + } + // Read handshake_type (1 byte) __u8 handshake_type; - __u8 length[3]; - __u16 version; -} __attribute__((packed)) tls_hello_message_t; + if (bpf_skb_load_bytes(skb, handshake_offset, &handshake_type, SINGLE_BYTE_LENGTH) < 0) { + return false; + } -#define TLS_HANDSHAKE_CLIENT_HELLO 0x01 -#define TLS_HANDSHAKE_SERVER_HELLO 0x02 -// The size of the handshake type and the length. -#define TLS_HELLO_MESSAGE_HEADER_SIZE 4 + // Read handshake_length (3 bytes) + __u32 length_offset = handshake_offset + SINGLE_BYTE_LENGTH; + if (length_offset + TLS_HANDSHAKE_LENGTH > data_end) { + return false; + } + __u8 handshake_length_bytes[TLS_HANDSHAKE_LENGTH]; + if (bpf_skb_load_bytes(skb, length_offset, handshake_length_bytes, TLS_HANDSHAKE_LENGTH) < 0) { + return false; + } -// is_valid_tls_version checks if the given version is a valid TLS version as -// defined in the TLS specification. -static __always_inline bool is_valid_tls_version(__u16 version) { - switch (version) { - case SSL_VERSION20: - case SSL_VERSION30: - case TLS_VERSION10: - case TLS_VERSION11: - case TLS_VERSION12: - case TLS_VERSION13: + __u32 handshake_length = (handshake_length_bytes[0] << 16) | + (handshake_length_bytes[1] << 8) | + handshake_length_bytes[2]; + + // Verify that the handshake message length plus the 4-byte handshake header (1 byte type + 3 bytes length) + // matches the total length defined in the record header. 
+ // If handshake_length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length, the handshake message structure is inconsistent. + if (handshake_length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length) { + return false; + } + + // Check that the handshake_type is one of the expected values (ClientHello or ServerHello). + // This ensures we are dealing with a known handshake message type. + if (handshake_type != TLS_HANDSHAKE_CLIENT_HELLO && handshake_type != TLS_HANDSHAKE_SERVER_HELLO) { + return false; + } + + // At this point, we've confirmed: + // - The handshake message fits within the record. + // - The handshake_type is a known TLS Hello message. + // - The handshake_length matches the record header's length. + return true; +} + +// is_tls checks if the packet is a TLS packet by reading and validating the TLS record header +// Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 +// Validates that content_type matches known TLS types (Handshake, Application Data, etc.). +static __always_inline bool is_tls(struct __sk_buff *skb, __u32 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { + // Read and validate the TLS record header + if (!read_tls_record_header(skb, header_offset, data_end, tls_hdr)) { + return false; + } + + switch (tls_hdr->content_type) { + case TLS_HANDSHAKE: + return is_valid_tls_handshake(skb, header_offset, data_end, tls_hdr); + case TLS_APPLICATION_DATA: + case TLS_CHANGE_CIPHER_SPEC: + case TLS_ALERT: return true; + default: + return false; + } +} + +// parse_tls_handshake_header extracts handshake_length and protocol_version from a TLS handshake message +// References: +// - RFC 5246 Section 7.4 (Handshake Protocol Overview), https://tools.ietf.org/html/rfc5246#section-7.4 +// For ClientHello and ServerHello, this includes parsing the handshake type (skipped prior) and the 3-byte length field, followed by a 2-byte protocol version field. 
+static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 *handshake_length, __u16 *protocol_version) { + *offset += SINGLE_BYTE_LENGTH; // Move past handshake type (1 byte) + + // Read handshake length (3 bytes) + if (*offset + TLS_HANDSHAKE_LENGTH > data_end) { + return false; + } + __u8 handshake_length_bytes[TLS_HANDSHAKE_LENGTH]; + if (bpf_skb_load_bytes(skb, *offset, handshake_length_bytes, TLS_HANDSHAKE_LENGTH) < 0) { + return false; } + *handshake_length = (handshake_length_bytes[0] << 16) | + (handshake_length_bytes[1] << 8) | + handshake_length_bytes[2]; + *offset += TLS_HANDSHAKE_LENGTH; - return false; + // Ensure we don't read beyond the packet + if (*offset + *handshake_length > data_end) { + return false; + } + + // Read protocol version (2 bytes) + if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { + return false; + } + __u16 version; + if (bpf_skb_load_bytes(skb, *offset, &version, PROTOCOL_VERSION_LENGTH) < 0) { + return false; + } + *protocol_version = bpf_ntohs(version); + *offset += PROTOCOL_VERSION_LENGTH; + + return true; } -// is_valid_tls_app_data checks if the buffer is a valid TLS Application Data -// record header. The record header is considered valid if: -// - the TLS version field is a known SSL/TLS version -// - the payload length is below the maximum payload length defined in the -// standard. 
-// - the payload length + the size of the record header is less than the size -// of the skb -static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u32 buf_size, __u32 skb_len) { - return sizeof(*hdr) + hdr->length <= skb_len; +// skip_random_and_session_id Skips the Random (32 bytes) and the Session ID from the TLS Hello messages +// References: +// - RFC 5246 Section 7.4.1.2 (Client Hello and Server Hello): https://tools.ietf.org/html/rfc5246#section-7.4.1.2 +// ClientHello and ServerHello contain a "random" field (32 bytes) followed by a "session_id_length" (1 byte) +// and a session_id of that length. This helper increments the offset accordingly after reading and skipping these fields. +static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u32 *offset, __u32 data_end) { + // Skip Random (32 bytes) + *offset += RANDOM_LENGTH; + + // Read Session ID Length (1 byte) + if (*offset + SESSION_ID_LENGTH > data_end) { + return false; + } + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, *offset, &session_id_length, SESSION_ID_LENGTH) < 0) { + return false; + } + *offset += SESSION_ID_LENGTH; + + // Skip Session ID + *offset += session_id_length; + + // Ensure we don't read beyond the packet + return *offset <= data_end; +} + +// parse_supported_versions_extension looks for the supported_versions extension in the ClientHello or ServerHello and populates tags +// References: +// - For TLS 1.3 supported_versions extension: RFC 8446 Section 4.2.1: https://tools.ietf.org/html/rfc8446#section-4.2.1 +// For ClientHello this extension contains a list of supported versions (2 bytes each) preceded by a 1-byte length. +// supported_versions extension structure: +// +-----+--------------------+ +// | len(1) | versions(2 * N) | +// +-----+--------------------+ +// For ServerHello (TLS 1.3), it contains a single selected_version (2 bytes). 
+// +---------------------+
+// | selected_version(2) |
+// +---------------------+
+static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 extensions_end, tls_info_t *tags, bool is_client_hello) {
+    if (is_client_hello) {
+        // Read supported version list length (1 byte)
+        if (*offset + SINGLE_BYTE_LENGTH > data_end || *offset + SINGLE_BYTE_LENGTH > extensions_end) {
+            return false;
+        }
+        __u8 sv_list_length;
+        if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, SINGLE_BYTE_LENGTH) < 0) {
+            return false;
+        }
+        *offset += SINGLE_BYTE_LENGTH;
+
+        if (*offset + sv_list_length > data_end || *offset + sv_list_length > extensions_end) {
+            return false;
+        }
+
+        // Parse the list of supported versions (2 bytes each); at most MAX_SUPPORTED_VERSIONS entries are read
+        __u8 sv_offset = 0;
+        __u16 sv_version;
+        #pragma unroll(MAX_SUPPORTED_VERSIONS)
+        for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) {
+            if (sv_offset + 1 >= sv_list_length) {
+                break;
+            }
+            // Each supported version is 2 bytes
+            if (*offset + PROTOCOL_VERSION_LENGTH > data_end) {
+                return false;
+            }
+
+            if (bpf_skb_load_bytes(skb, *offset, &sv_version, PROTOCOL_VERSION_LENGTH) < 0) {
+                return false;
+            }
+            sv_version = bpf_ntohs(sv_version);
+            *offset += PROTOCOL_VERSION_LENGTH;
+
+            set_tls_offered_version(tags, sv_version);
+            sv_offset += PROTOCOL_VERSION_LENGTH;
+        }
+    } else {
+        // ServerHello
+        // The selected_version field is 2 bytes and must lie inside both the packet and the extensions block
+        if (*offset + PROTOCOL_VERSION_LENGTH > data_end || *offset + PROTOCOL_VERSION_LENGTH > extensions_end) {
+            return false;
+        }
+
+        // Read selected version (2 bytes)
+        __u16 selected_version;
+        if (bpf_skb_load_bytes(skb, *offset, &selected_version, PROTOCOL_VERSION_LENGTH) < 0) {
+            return false;
+        }
+        selected_version = bpf_ntohs(selected_version);
+        *offset += PROTOCOL_VERSION_LENGTH;
+
+        tags->chosen_version = selected_version;
+    }
+
+    return true;
+}
+
+// parse_tls_extensions parses TLS extensions in both ClientHello and ServerHello
+// References:
+// - RFC 5246 Section 7.4.1.4 (Hello Extensions): 
https://tools.ietf.org/html/rfc5246#section-7.4.1.4
+// - For TLS 1.3 supported_versions extension: RFC 8446 Section 4.2.1: https://tools.ietf.org/html/rfc8446#section-4.2.1
+// This function iterates over extensions, reading the extension_type and extension_length, and if it encounters
+// the supported_versions extension, it calls parse_supported_versions_extension to handle it.
+// ASCII snippet for a single extension:
+// +---------+---------+--------------------------------+
+// | ext_type(2) | ext_length(2) | ext_data(ext_length) |
+// +---------+---------+--------------------------------+
+// Multiple extensions are simply concatenated; at most MAX_EXTENSIONS of them are inspected.
+static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 extensions_end, tls_info_t *tags, bool is_client_hello) {
+    __u16 extension_type;
+    __u16 extension_length;
+
+    #pragma unroll(MAX_EXTENSIONS)
+    for (int i = 0; i < MAX_EXTENSIONS; i++) {
+        if (*offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) {
+            break;
+        }
+
+        // Read Extension Type (2 bytes)
+        if (bpf_skb_load_bytes(skb, *offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) {
+            return false;
+        }
+        extension_type = bpf_ntohs(extension_type);
+        *offset += EXTENSION_TYPE_LENGTH;
+
+        // Read Extension Length (2 bytes)
+        if (bpf_skb_load_bytes(skb, *offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) {
+            return false;
+        }
+        extension_length = bpf_ntohs(extension_length);
+        *offset += EXTENSION_LENGTH_FIELD;
+
+        __u32 extension_end = *offset + extension_length;
+        if (extension_end > data_end || extension_end > extensions_end) {
+            return false;
+        }
+
+        if (extension_type == SUPPORTED_VERSIONS_EXTENSION &&
+            !parse_supported_versions_extension(skb, offset, data_end, extensions_end, tags, is_client_hello)) {
+            return false;
+        }
+        // Always resume at the declared end of this extension: the supported_versions parser may stop
+        // early (it reads at most MAX_SUPPORTED_VERSIONS entries), which would misalign the next header
+        *offset = extension_end;
+
+        if (*offset >= extensions_end) {
+            break;
+        }
+    }
+
+    return true;
 }
-// is_tls_handshake checks if the 
given TLS message header is a valid TLS -// handshake message. The message is considered valid if: -// - The type matches CLIENT_HELLO or SERVER_HELLO -// - The version is a known SSL/TLS version -static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, const char *buf, __u32 buf_size) { - // Checking the buffer size contains at least the size of the tls record header and the tls hello message header. - if (sizeof(tls_record_header_t) + sizeof(tls_hello_message_t) > buf_size) { +// parse_client_hello parses the ClientHello message and populates tags +// Reference: RFC 5246 Section 7.4.1.2 (Client Hello), https://tools.ietf.org/html/rfc5246 +// Structure (simplified): +// handshake_type (1 byte), length (3 bytes), version (2 bytes), random(32 bytes), session_id_length(1 byte), session_id(variable), cipher_suites_length(2 bytes), cipher_suites(variable), compression_methods_length(1 byte), compression_methods(variable), extensions_length(2 bytes), extensions(variable) +// After the handshake header (handshake_type + length), the ClientHello fields are: +// +----------------------------+ +// | client_version (2) | +// +----------------------------+ +// | random (32) | +// +----------------------------+ +// | session_id_length (1) | +// | session_id (...) | +// +----------------------------+ +// | cipher_suites_length(2) | +// | cipher_suites(...) | +// +----------------------------+ +// | compression_methods_len(1) | +// | compression_methods(...) | +// +----------------------------+ +// | extensions_length (2) | +// | extensions(...) 
| +// +----------------------------+ +static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end, tls_info_t *tags) { + __u32 handshake_length; + __u16 client_version; + + if (!parse_tls_handshake_header(skb, &offset, data_end, &handshake_length, &client_version)) { + return false; + } + + set_tls_offered_version(tags, client_version); + + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. + if (client_version != TLS_VERSION12) { + return true; + } + + if (!skip_random_and_session_id(skb, &offset, data_end)) { + return false; + } + + // Read Cipher Suites Length (2 bytes) + if (offset + CIPHER_SUITES_LENGTH > data_end) { + return false; + } + __u16 cipher_suites_length; + if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, CIPHER_SUITES_LENGTH) < 0) { + return false; + } + cipher_suites_length = bpf_ntohs(cipher_suites_length); + offset += CIPHER_SUITES_LENGTH; + + // Skip Cipher Suites + offset += cipher_suites_length; + + // Read Compression Methods Length (1 byte) + if (offset + COMPRESSION_METHODS_LENGTH > data_end) { + return false; + } + __u8 compression_methods_length; + if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, COMPRESSION_METHODS_LENGTH) < 0) { return false; } - // Checking the tls record header length is greater than the tls hello message header length. - if (hdr->length < sizeof(tls_hello_message_t)) { + offset += COMPRESSION_METHODS_LENGTH; + + // Skip Compression Methods + offset += compression_methods_length; + + // Check if extensions are present + if (offset + EXTENSION_LENGTH_FIELD > data_end) { return false; } - // Getting the tls hello message header. 
- tls_hello_message_t msg = *(tls_hello_message_t *)(buf + sizeof(tls_record_header_t)); - // If the message is not a CLIENT_HELLO or SERVER_HELLO, we don't attempt to classify. - if (msg.handshake_type != TLS_HANDSHAKE_CLIENT_HELLO && msg.handshake_type != TLS_HANDSHAKE_SERVER_HELLO) { + // Read Extensions Length (2 bytes) + __u16 extensions_length; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, EXTENSION_LENGTH_FIELD) < 0) { return false; } - // Converting the fields to host byte order. - __u32 length = msg.length[0] << 16 | msg.length[1] << 8 | msg.length[2]; - // TLS handshake message length should be equal to the record header length minus the size of the hello message - // header. - if (length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length) { + extensions_length = bpf_ntohs(extensions_length); + offset += EXTENSION_LENGTH_FIELD; + + if (offset + extensions_length > data_end) { return false; } - msg.version = bpf_ntohs(msg.version); - return is_valid_tls_version(msg.version) && msg.version >= hdr->version; + __u32 extensions_end = offset + extensions_length; + + return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, true); } -// is_tls checks if the given buffer is a valid TLS record header. 
We are -// currently checking for two types of record headers: -// - TLS Handshake record headers -// - TLS Application Data record headers -static __always_inline bool is_tls(const char *buf, __u32 buf_size, __u32 skb_len) { - if (buf_size < sizeof(tls_record_header_t)) { +// parse_server_hello parses the ServerHello message and populates tags +// Reference: RFC 5246 Section 7.4.1.2 (Server Hello), https://tools.ietf.org/html/rfc5246 +// Structure (simplified): +// handshake_type(1), length(3), version(2), random(32), session_id_length(1), session_id(variable), cipher_suite(2), compression_method(1), extensions_length(2), extensions(variable) +// After the handshake header (handshake_type + length), the ServerHello fields are: +// +------------------------+ +// | server_version (2) | +// +------------------------+ +// | random (32) | +// +------------------------+ +// | session_id_length (1) | +// | session_id (...) | +// +------------------------+ +// | cipher_suite (2) | +// +------------------------+ +// | compression_method (1) | +// +------------------------+ +// | extensions_length(2) | +// | extensions(...) | +// +------------------------+ +static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end, tls_info_t *tags) { + __u32 handshake_length; + __u16 server_version; + + if (!parse_tls_handshake_header(skb, &offset, data_end, &handshake_length, &server_version)) { return false; } - // Copying struct to the stack, to avoid modifying the original buffer that will be used for other classifiers. - tls_record_header_t tls_record_header = *(tls_record_header_t *)buf; - // Converting the fields to host byte order. 
- tls_record_header.version = bpf_ntohs(tls_record_header.version); - tls_record_header.length = bpf_ntohs(tls_record_header.length); + // Set the version here and try to get the "real" version from the extensions if possible + // Note: In TLS 1.3, the server_version field is set to 1.2 + // The actual version is embedded in the supported_versions extension + tags->chosen_version = server_version; - // Checking the version in the record header. - if (!is_valid_tls_version(tls_record_header.version)) { + if (!skip_random_and_session_id(skb, &offset, data_end)) { return false; } - // Checking the length in the record header is not greater than the maximum payload length. - if (tls_record_header.length > TLS_MAX_PAYLOAD_LENGTH) { + // Read Cipher Suite (2 bytes) + if (offset + CIPHER_SUITES_LENGTH > data_end) { return false; } - switch (tls_record_header.content_type) { - case TLS_HANDSHAKE: - return is_tls_handshake(&tls_record_header, buf, buf_size); - case TLS_APPLICATION_DATA: - return is_valid_tls_app_data(&tls_record_header, buf_size, skb_len); + __u16 cipher_suite; + if (bpf_skb_load_bytes(skb, offset, &cipher_suite, CIPHER_SUITES_LENGTH) < 0) { + return false; + } + cipher_suite = bpf_ntohs(cipher_suite); + offset += CIPHER_SUITES_LENGTH; + + // Skip Compression Method (1 byte) + offset += COMPRESSION_METHODS_LENGTH; + + tags->cipher_suite = cipher_suite; + + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. 
+ if (tags->chosen_version != TLS_VERSION12) { + return true; + } + + if (offset + EXTENSION_LENGTH_FIELD > data_end) { + return false; + } + + // Read Extensions Length (2 bytes) + __u16 extensions_length; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, EXTENSION_LENGTH_FIELD) < 0) { + return false; + } + extensions_length = bpf_ntohs(extensions_length); + offset += EXTENSION_LENGTH_FIELD; + + __u32 handshake_end = offset + handshake_length; + if (offset + extensions_length > data_end || offset + extensions_length > handshake_end) { + return false; + } + + __u32 extensions_end = offset + extensions_length; + + return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, false); +} + +// is_tls_handshake_type checks if the handshake type at the given offset matches the expected type (e.g., ClientHello or ServerHello) +// References: +// - RFC 5246 Section 7.4 (Handshake Protocol Overview), https://tools.ietf.org/html/rfc5246#section-7.4 +// The handshake_type is a single byte enumerated value. 
+static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u32 offset, __u32 data_end, __u8 expected_handshake_type) { + // The handshake type is a single byte enumerated value + if (offset + SINGLE_BYTE_LENGTH > data_end) { + return false; + } + __u8 handshake_type; + if (bpf_skb_load_bytes(skb, offset, &handshake_type, SINGLE_BYTE_LENGTH) < 0) { + return false; } - return false; + return handshake_type == expected_handshake_type; +} + +// is_tls_handshake_client_hello checks if the packet is a TLS ClientHello message +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end) { + return is_tls_handshake_type(skb, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); +} + +// is_tls_handshake_server_hello checks if the packet is a TLS ServerHello message +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end) { + return is_tls_handshake_type(skb, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); } -#endif +#endif // __TLS_H diff --git a/pkg/network/ebpf/c/tracer.c b/pkg/network/ebpf/c/tracer.c index 66cefbee1d08d..5ff4ffa95f18b 100644 --- a/pkg/network/ebpf/c/tracer.c +++ b/pkg/network/ebpf/c/tracer.c @@ -33,6 +33,18 @@ int socket__classifier_entry(struct __sk_buff *skb) { return 0; } +SEC("socket/classifier_tls_handshake_client") +int socket__classifier_tls_handshake_client(struct __sk_buff *skb) { + protocol_classifier_entrypoint_tls_handshake_client(skb); + return 0; +} + +SEC("socket/classifier_tls_handshake_server") +int socket__classifier_tls_handshake_server(struct __sk_buff *skb) { + protocol_classifier_entrypoint_tls_handshake_server(skb); + return 0; +} + SEC("socket/classifier_queues") int socket__classifier_queues(struct __sk_buff *skb) { protocol_classifier_entrypoint_queues(skb); diff --git a/pkg/network/ebpf/c/tracer/events.h b/pkg/network/ebpf/c/tracer/events.h index 1de94d755b1f3..85f8ed48e1bab 100644 --- 
a/pkg/network/ebpf/c/tracer/events.h +++ b/pkg/network/ebpf/c/tracer/events.h @@ -23,6 +23,7 @@ static __always_inline void clean_protocol_classification(conn_tuple_t *tup) { conn_tuple.netns = 0; normalize_tuple(&conn_tuple); delete_protocol_stack(&conn_tuple, NULL, FLAG_TCP_CLOSE_DELETION); + bpf_map_delete_elem(&tls_enhanced_tags, &conn_tuple); conn_tuple_t *skb_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple); if (skb_tup_ptr == NULL) { @@ -31,6 +32,7 @@ static __always_inline void clean_protocol_classification(conn_tuple_t *tup) { conn_tuple_t skb_tup = *skb_tup_ptr; delete_protocol_stack(&skb_tup, NULL, FLAG_TCP_CLOSE_DELETION); + bpf_map_delete_elem(&tls_enhanced_tags, &skb_tup); bpf_map_delete_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple); } diff --git a/pkg/network/ebpf/c/tracer/maps.h b/pkg/network/ebpf/c/tracer/maps.h index e6123782f8ea5..0141a970d1e5e 100644 --- a/pkg/network/ebpf/c/tracer/maps.h +++ b/pkg/network/ebpf/c/tracer/maps.h @@ -132,4 +132,7 @@ BPF_HASH_MAP(tcp_close_args, __u64, conn_tuple_t, 1024) // by using tail call. BPF_PROG_ARRAY(tcp_close_progs, 1) +// Map to store extra information about TLS connections like version, cipher, etc. 
+BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_info_wrapper_t, 0) + #endif diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index 862d8e32f81f1..ca1d6d19e309e 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -23,6 +23,55 @@ static __always_inline __u64 offset_rtt(); static __always_inline __u64 offset_rtt_var(); #endif +static __always_inline tls_info_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { + conn_tuple_t normalized_tup = *tuple; + normalize_tuple(&normalized_tup); + tls_info_wrapper_t *wrapper = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!wrapper) { + return NULL; + } + wrapper->updated = bpf_ktime_get_ns(); + return &wrapper->info; +} + +static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { + tls_info_t *tags = get_tls_enhanced_tags(tuple); + if (!tags) { + conn_tuple_t normalized_tup = *tuple; + normalize_tuple(&normalized_tup); + tls_info_wrapper_t empty_tags_wrapper = {}; + empty_tags_wrapper.updated = bpf_ktime_get_ns(); + + bpf_map_update_with_telemetry(tls_enhanced_tags, &normalized_tup, &empty_tags_wrapper, BPF_ANY); + tls_info_wrapper_t *wrapper_ptr = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!wrapper_ptr) { + return NULL; + } + tags = &wrapper_ptr->info; + } + return tags; +} + +// merge_tls_info modifies `this` by merging it with `that` +static __always_inline void merge_tls_info(tls_info_t *this, tls_info_t *that) { + if (!this || !that) { + return; + } + + // Merge chosen_version if not already set + if (this->chosen_version == 0 && that->chosen_version != 0) { + this->chosen_version = that->chosen_version; + } + + // Merge cipher_suite if not already set + if (this->cipher_suite == 0 && that->cipher_suite != 0) { + this->cipher_suite = that->cipher_suite; + } + + // Merge offered_versions bitmask + this->offered_versions |= that->offered_versions; +} + static __always_inline 
conn_stats_ts_t *get_conn_stats(conn_tuple_t *t, struct sock *sk) { conn_stats_ts_t *cs = bpf_map_lookup_elem(&conn_stats, t); if (cs) { @@ -112,6 +161,9 @@ static __always_inline void update_protocol_classification_information(conn_tupl mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); + tls_info_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + merge_tls_info(&stats->tls_tags, tls_tags); + conn_tuple_t *cached_skb_conn_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple_copy); if (!cached_skb_conn_tup_ptr) { return; @@ -124,6 +176,9 @@ static __always_inline void update_protocol_classification_information(conn_tupl set_protocol_flag(protocol_stack, FLAG_NPM_ENABLED); mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); + + tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + merge_tls_info(&stats->tls_tags, tls_tags); } static __always_inline void determine_connection_direction(conn_tuple_t *t, conn_stats_ts_t *conn_stats) { diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index d688359dc9d79..c23c98b189162 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -29,6 +29,17 @@ typedef enum { #define CONN_DIRECTION_MASK 0b11 +typedef struct { + __u16 chosen_version; + __u16 cipher_suite; + __u8 offered_versions; +} tls_info_t; + +typedef struct { + tls_info_t info; + __u64 updated; +} tls_info_wrapper_t; + typedef struct { __u64 sent_bytes; __u64 recv_bytes; @@ -54,6 +65,7 @@ typedef struct { protocol_stack_t protocol_stack; __u8 flags; __u8 direction; + tls_info_t tls_tags; } conn_stats_ts_t; // Connection flags diff --git a/pkg/network/ebpf/kprobe_types.go b/pkg/network/ebpf/kprobe_types.go index 90d7eb1f331ae..ab5f569a9c3ec 100644 --- a/pkg/network/ebpf/kprobe_types.go +++ b/pkg/network/ebpf/kprobe_types.go @@ 
-30,6 +30,8 @@ type UDPRecvSock C.udp_recv_sock_t type BindSyscallArgs C.bind_syscall_args_t type ProtocolStack C.protocol_stack_t type ProtocolStackWrapper C.protocol_stack_wrapper_t +type TLSTags C.tls_info_t +type TLSTagsWrapper C.tls_info_wrapper_t // udp_recv_sock_t have *sock and *msghdr struct members, we make them opaque here type _Ctype_struct_sock uint64 @@ -63,7 +65,9 @@ const SizeofConn = C.sizeof_conn_t type ClassificationProgram = uint32 const ( - ClassificationQueues ClassificationProgram = C.CLASSIFICATION_QUEUES_PROG - ClassificationDBs ClassificationProgram = C.CLASSIFICATION_DBS_PROG - ClassificationGRPC ClassificationProgram = C.CLASSIFICATION_GRPC_PROG + ClassificationTLSClient ClassificationProgram = C.CLASSIFICATION_TLS_CLIENT_PROG + ClassificationTLSServer ClassificationProgram = C.CLASSIFICATION_TLS_SERVER_PROG + ClassificationQueues ClassificationProgram = C.CLASSIFICATION_QUEUES_PROG + ClassificationDBs ClassificationProgram = C.CLASSIFICATION_DBS_PROG + ClassificationGRPC ClassificationProgram = C.CLASSIFICATION_GRPC_PROG ) diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index bf9bbf38210f3..2953275bd5bfb 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -32,7 +32,7 @@ type ConnStats struct { Protocol_stack ProtocolStack Flags uint8 Direction uint8 - Pad_cgo_0 [6]byte + Tls_tags TLSTags } type Conn struct { Tup ConnTuple @@ -103,6 +103,16 @@ type ProtocolStackWrapper struct { Stack ProtocolStack Updated uint64 } +type TLSTags struct { + Chosen_version uint16 + Cipher_suite uint16 + Offered_versions uint8 + Pad_cgo_0 [1]byte +} +type TLSTagsWrapper struct { + Info TLSTags + Updated uint64 +} type _Ctype_struct_sock uint64 type _Ctype_struct_msghdr uint64 @@ -135,7 +145,9 @@ const SizeofConn = 0x78 type ClassificationProgram = uint32 const ( - ClassificationQueues ClassificationProgram = 0x2 - ClassificationDBs ClassificationProgram = 0x3 - 
ClassificationGRPC ClassificationProgram = 0x5 + ClassificationTLSClient ClassificationProgram = 0x7 + ClassificationTLSServer ClassificationProgram = 0x8 + ClassificationQueues ClassificationProgram = 0x2 + ClassificationDBs ClassificationProgram = 0x3 + ClassificationGRPC ClassificationProgram = 0x5 ) diff --git a/pkg/network/ebpf/probes/probes.go b/pkg/network/ebpf/probes/probes.go index cd72ca3ba1fe9..1f983ed25b691 100644 --- a/pkg/network/ebpf/probes/probes.go +++ b/pkg/network/ebpf/probes/probes.go @@ -27,6 +27,10 @@ const ( // ProtocolClassifierEntrySocketFilter runs a classifier algorithm as a socket filter ProtocolClassifierEntrySocketFilter ProbeFuncName = "socket__classifier_entry" + // ProtocolClassifierTLSClientSocketFilter runs classification rules for the TLS client hello packet + ProtocolClassifierTLSClientSocketFilter ProbeFuncName = "socket__classifier_tls_handshake_client" + // ProtocolClassifierTLSServerSocketFilter runs classification rules for the TLS server hello packet + ProtocolClassifierTLSServerSocketFilter ProbeFuncName = "socket__classifier_tls_handshake_server" // ProtocolClassifierQueuesSocketFilter runs a classification rules for Queue protocols. ProtocolClassifierQueuesSocketFilter ProbeFuncName = "socket__classifier_queues" // ProtocolClassifierDBsSocketFilter runs a classification rules for DB protocols. @@ -232,6 +236,8 @@ const ( ConnectionProtocolMap BPFMapName = "connection_protocol" // ConnectionTupleToSocketSKBConnMap is the map storing the connection tuple to socket skb conn tuple ConnectionTupleToSocketSKBConnMap BPFMapName = "conn_tuple_to_socket_skb_conn_tuple" + // EnhancedTLSTagsMap is the map storing additional tags for TLS connections (version, cipher, etc.) 
+ EnhancedTLSTagsMap BPFMapName = "tls_enhanced_tags" // ClassificationProgsMap is the map storing the programs to run on classification events ClassificationProgsMap BPFMapName = "classification_progs" // TCPCloseProgsMap is the map storing the programs to run on TCP close events diff --git a/pkg/network/encoding/encoding_test.go b/pkg/network/encoding/encoding_test.go index e0d0eacfde10c..bc75da14be473 100644 --- a/pkg/network/encoding/encoding_test.go +++ b/pkg/network/encoding/encoding_test.go @@ -29,6 +29,7 @@ import ( "github.com/DataDog/datadog-agent/pkg/network/protocols/http" "github.com/DataDog/datadog-agent/pkg/network/protocols/kafka" "github.com/DataDog/datadog-agent/pkg/network/protocols/telemetry" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -229,6 +230,7 @@ func TestSerialization(t *testing.T) { }, }, ProtocolStack: protocols.Stack{Application: protocols.HTTP}, + TLSTags: tls.Tags{ChosenVersion: 0, CipherSuite: 0, OfferedVersions: 0}, }, {ConnectionTuple: network.ConnectionTuple{ Source: util.AddressFromString("10.1.1.1"), @@ -241,6 +243,7 @@ func TestSerialization(t *testing.T) { Direction: network.LOCAL, StaticTags: tagOpenSSL | tagTLS, ProtocolStack: protocols.Stack{Application: protocols.HTTP2}, + TLSTags: tls.Tags{ChosenVersion: 0, CipherSuite: 0, OfferedVersions: 0}, DNSStats: map[dns.Hostname]map[dns.QueryType]dns.Stats{ dns.ToHostname("foo.com"): { dns.TypeA: { diff --git a/pkg/network/encoding/marshal/format.go b/pkg/network/encoding/marshal/format.go index 99cf5c4c75694..2e0d40955191a 100644 --- a/pkg/network/encoding/marshal/format.go +++ b/pkg/network/encoding/marshal/format.go @@ -120,9 +120,10 @@ func FormatConnection(builder *model.ConnectionBuilder, conn network.ConnectionS httpStaticTags, httpDynamicTags := httpEncoder.GetHTTPAggregationsAndTags(conn, builder) http2StaticTags, http2DynamicTags := http2Encoder.WriteHTTP2AggregationsAndTags(conn, builder) + 
tlsDynamicTags := conn.TLSTags.GetDynamicTags() staticTags := httpStaticTags | http2StaticTags - dynamicTags := mergeDynamicTags(httpDynamicTags, http2DynamicTags) + dynamicTags := mergeDynamicTags(httpDynamicTags, http2DynamicTags, tlsDynamicTags) staticTags |= kafkaEncoder.WriteKafkaAggregations(conn, builder) staticTags |= postgresEncoder.WritePostgresAggregations(conn, builder) diff --git a/pkg/network/event_common.go b/pkg/network/event_common.go index d8cf93f592535..83592686c96bf 100644 --- a/pkg/network/event_common.go +++ b/pkg/network/event_common.go @@ -23,6 +23,7 @@ import ( "github.com/DataDog/datadog-agent/pkg/network/protocols/kafka" "github.com/DataDog/datadog-agent/pkg/network/protocols/postgres" "github.com/DataDog/datadog-agent/pkg/network/protocols/redis" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -283,6 +284,7 @@ type ConnectionStats struct { RTTVar uint32 StaticTags uint64 ProtocolStack protocols.Stack + TLSTags tls.Tags // keep these fields last because they are 1 byte each and otherwise inflate the struct size due to alignment Direction ConnectionDirection diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go new file mode 100644 index 0000000000000..c3014c3f65f3e --- /dev/null +++ b/pkg/network/protocols/tls/types.go @@ -0,0 +1,135 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2024-present Datadog, Inc. 
+ +// Package tls contains definitions and methods related to tags parsed from the TLS handshake +package tls + +import ( + "crypto/tls" + "fmt" +) + +// Constants for tag keys +const ( + TagTLSVersion = "tls.version:" + TagTLSCipherSuiteID = "tls.cipher_suite_id:" + TagTLSClientVersion = "tls.client_version:" + version10 = "tls_1.0" + version11 = "tls_1.1" + version12 = "tls_1.2" + version13 = "tls_1.3" +) + +// Bitmask constants for Offered_versions matching kernelspace definitions +const ( + OfferedTLSVersion10 uint8 = 0x01 + OfferedTLSVersion11 uint8 = 0x02 + OfferedTLSVersion12 uint8 = 0x04 + OfferedTLSVersion13 uint8 = 0x08 +) + +// VersionTags maps TLS versions to tag names for server chosen version (exported for testing) +var VersionTags = map[uint16]string{ + tls.VersionTLS10: TagTLSVersion + version10, + tls.VersionTLS11: TagTLSVersion + version11, + tls.VersionTLS12: TagTLSVersion + version12, + tls.VersionTLS13: TagTLSVersion + version13, +} + +// ClientVersionTags maps TLS versions to tag names for client offered versions (exported for testing) +var ClientVersionTags = map[uint16]string{ + tls.VersionTLS10: TagTLSClientVersion + version10, + tls.VersionTLS11: TagTLSClientVersion + version11, + tls.VersionTLS12: TagTLSClientVersion + version12, + tls.VersionTLS13: TagTLSClientVersion + version13, +} + +// Mapping of offered version bitmasks to version constants +var offeredVersionBitmask = []struct { + bitMask uint8 + version uint16 +}{ + {OfferedTLSVersion10, tls.VersionTLS10}, + {OfferedTLSVersion11, tls.VersionTLS11}, + {OfferedTLSVersion12, tls.VersionTLS12}, + {OfferedTLSVersion13, tls.VersionTLS13}, +} + +// Tags holds the TLS tags. It is used to store the TLS version, cipher suite and offered versions. +// We can't use the struct from eBPF as the definition is shared with windows. 
+type Tags struct { + ChosenVersion uint16 + CipherSuite uint16 + OfferedVersions uint8 +} + +// MergeWith merges the tags from another Tags struct into this one +func (t *Tags) MergeWith(that Tags) { + if t.ChosenVersion == 0 { + t.ChosenVersion = that.ChosenVersion + } + if t.CipherSuite == 0 { + t.CipherSuite = that.CipherSuite + } + if t.OfferedVersions == 0 { + t.OfferedVersions = that.OfferedVersions + } + +} + +// IsEmpty returns true if all fields are zero +func (t *Tags) IsEmpty() bool { + if t == nil { + return true + } + return t.ChosenVersion == 0 && t.CipherSuite == 0 && t.OfferedVersions == 0 +} + +// String returns a string representation of the Tags struct +func (t *Tags) String() string { + return fmt.Sprintf("ChosenVersion: %d, CipherSuite: %d, OfferedVersions: %d", t.ChosenVersion, t.CipherSuite, t.OfferedVersions) +} + +// parseOfferedVersions parses the Offered_versions bitmask into a slice of version strings +func parseOfferedVersions(offeredVersions uint8) []string { + versions := make([]string, 0, len(offeredVersionBitmask)) + for _, ov := range offeredVersionBitmask { + if (offeredVersions & ov.bitMask) != 0 { + if name := ClientVersionTags[ov.version]; name != "" { + versions = append(versions, name) + } + } + } + return versions +} + +func hexCipherSuiteTag(cipherSuite uint16) string { + return fmt.Sprintf("%s0x%04X", TagTLSCipherSuiteID, cipherSuite) +} + +// GetDynamicTags generates dynamic tags based on TLS information +func (t *Tags) GetDynamicTags() map[string]struct{} { + if t.IsEmpty() { + return nil + } + tags := make(map[string]struct{}) + + // Server chosen version + if tag, ok := VersionTags[t.ChosenVersion]; ok { + tags[tag] = struct{}{} + } + + // Client offered versions + for _, versionName := range parseOfferedVersions(t.OfferedVersions) { + tags[versionName] = struct{}{} + } + + // Cipher suite ID as hex string + if t.CipherSuite != 0 { + tags[hexCipherSuiteTag(t.CipherSuite)] = struct{}{} + } + + return tags +} diff --git 
a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go new file mode 100644 index 0000000000000..979cc2bbdba32 --- /dev/null +++ b/pkg/network/protocols/tls/types_test.go @@ -0,0 +1,128 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2024-present Datadog, Inc. + +package tls + +import ( + "crypto/tls" + "fmt" + "reflect" + "testing" +) + +func TestParseOfferedVersions(t *testing.T) { + tests := []struct { + offeredVersions uint8 + expected []string + }{ + {0x00, []string{}}, // No versions offered + {OfferedTLSVersion10, []string{"tls.client_version:tls_1.0"}}, + {OfferedTLSVersion11, []string{"tls.client_version:tls_1.1"}}, + {OfferedTLSVersion12, []string{"tls.client_version:tls_1.2"}}, + {OfferedTLSVersion13, []string{"tls.client_version:tls_1.3"}}, + {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"tls.client_version:tls_1.0", "tls.client_version:tls_1.2"}}, + {OfferedTLSVersion11 | OfferedTLSVersion13, []string{"tls.client_version:tls_1.1", "tls.client_version:tls_1.3"}}, + {0xFF, []string{"tls.client_version:tls_1.0", "tls.client_version:tls_1.1", "tls.client_version:tls_1.2", "tls.client_version:tls_1.3"}}, // All bits set + {0x40, []string{}}, // Undefined bit set + {0x80, []string{}}, // Undefined bit set + } + + for _, test := range tests { + t.Run(fmt.Sprintf("OfferedVersions_0x%02X", test.offeredVersions), func(t *testing.T) { + result := parseOfferedVersions(test.offeredVersions) + if !reflect.DeepEqual(result, test.expected) { + t.Errorf("parseOfferedVersions(0x%02X) = %v; want %v", test.offeredVersions, result, test.expected) + } + }) + } +} + +func TestGetTLSDynamicTags(t *testing.T) { + tests := []struct { + name string + tlsTags *Tags + expected map[string]struct{} + }{ + { + name: "Nil_TLSTags", + tlsTags: nil, + expected: nil, + }, + { + 
name: "All_Fields_Populated", + tlsTags: &Tags{ + ChosenVersion: tls.VersionTLS12, + CipherSuite: 0x009C, + OfferedVersions: OfferedTLSVersion11 | OfferedTLSVersion12, + }, + expected: map[string]struct{}{ + "tls.version:tls_1.2": {}, + "tls.cipher_suite_id:0x009C": {}, + "tls.client_version:tls_1.1": {}, + "tls.client_version:tls_1.2": {}, + }, + }, + { + name: "Unknown_Chosen_Version", + tlsTags: &Tags{ + ChosenVersion: 0xFFFF, // Unknown version + CipherSuite: 0x00FF, + OfferedVersions: OfferedTLSVersion13, + }, + expected: map[string]struct{}{ + "tls.cipher_suite_id:0x00FF": {}, + "tls.client_version:tls_1.3": {}, + }, + }, + { + name: "No_Offered_Versions", + tlsTags: &Tags{ + ChosenVersion: tls.VersionTLS13, + CipherSuite: 0x1301, + OfferedVersions: 0x00, + }, + expected: map[string]struct{}{ + "tls.version:tls_1.3": {}, + "tls.cipher_suite_id:0x1301": {}, + }, + }, + { + name: "Zero_Cipher_Suite", + tlsTags: &Tags{ + ChosenVersion: tls.VersionTLS10, + OfferedVersions: OfferedTLSVersion10, + }, + expected: map[string]struct{}{ + "tls.version:tls_1.0": {}, + "tls.client_version:tls_1.0": {}, + }, + }, + { + name: "All_Bits_Set_In_Offered_Versions", + tlsTags: &Tags{ + ChosenVersion: tls.VersionTLS12, + CipherSuite: 0xC02F, + OfferedVersions: 0xFF, // All bits set + }, + expected: map[string]struct{}{ + "tls.version:tls_1.2": {}, + "tls.cipher_suite_id:0xC02F": {}, + "tls.client_version:tls_1.0": {}, + "tls.client_version:tls_1.1": {}, + "tls.client_version:tls_1.2": {}, + "tls.client_version:tls_1.3": {}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.tlsTags.GetDynamicTags() + if !reflect.DeepEqual(result, test.expected) { + t.Errorf("GetDynamicTags(%v) = %v; want %v", test.tlsTags, result, test.expected) + } + }) + } +} diff --git a/pkg/network/state.go b/pkg/network/state.go index 39094959989ac..48f07d83242d7 100644 --- a/pkg/network/state.go +++ b/pkg/network/state.go @@ -1420,6 +1420,7 @@ func (ac 
*aggregateConnection) merge(c *ConnectionStats) { } ac.ProtocolStack.MergeWith(c.ProtocolStack) + ac.TLSTags.MergeWith(c.TLSTags) if ac.DNSStats == nil { ac.DNSStats = c.DNSStats @@ -1483,6 +1484,7 @@ func (ns *networkState) mergeConnectionStats(a, b *ConnectionStats) (collision b } a.ProtocolStack.MergeWith(b.ProtocolStack) + a.TLSTags.MergeWith(b.TLSTags) return false } diff --git a/pkg/network/tracer/connection/ebpf_tracer.go b/pkg/network/tracer/connection/ebpf_tracer.go index d181471578d21..b4af3495f8e16 100644 --- a/pkg/network/tracer/connection/ebpf_tracer.go +++ b/pkg/network/tracer/connection/ebpf_tracer.go @@ -31,6 +31,7 @@ import ( netebpf "github.com/DataDog/datadog-agent/pkg/network/ebpf" "github.com/DataDog/datadog-agent/pkg/network/ebpf/probes" "github.com/DataDog/datadog-agent/pkg/network/protocols" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/fentry" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/util" @@ -45,6 +46,7 @@ const ( ) var tcpOngoingConnectMapTTL = 30 * time.Minute.Nanoseconds() +var tlsTagsMapTTL = 3 * time.Minute.Nanoseconds() var EbpfTracerTelemetry = struct { //nolint:revive // TODO connections telemetry.Gauge @@ -150,6 +152,8 @@ type ebpfTracer struct { // periodically clean the ongoing connection pid map ongoingConnectCleaner *ddebpf.MapCleaner[netebpf.SkpConn, netebpf.PidTs] + // periodically clean the enhanced TLS tags map + TLSTagsCleaner *ddebpf.MapCleaner[netebpf.ConnTuple, netebpf.TLSTagsWrapper] removeTuple *netebpf.ConnTuple @@ -187,6 +191,7 @@ func newEbpfTracer(config *config.Config, _ telemetryComponent.Component) (Trace probes.PortBindingsMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.UDPPortBindingsMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, 
probes.ConnectionProtocolMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, + probes.EnhancedTLSTagsMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.ConnectionTupleToSocketSKBConnMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.TCPOngoingConnectPid: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.TCPRecvMsgArgsMap: {MaxEntries: config.MaxTrackedConnections / 32, EditorFlag: manager.EditMaxEntries}, @@ -271,7 +276,7 @@ func newEbpfTracer(config *config.Config, _ telemetryComponent.Component) (Trace ch: newCookieHasher(), } - tr.setupMapCleaner(m) + tr.setupMapCleaners(m) tr.conns, err = maps.GetMap[netebpf.ConnTuple, netebpf.ConnStats](m, probes.ConnMap) if err != nil { @@ -350,6 +355,7 @@ func (t *ebpfTracer) Stop() { _ = t.m.Stop(manager.CleanAll) t.closeConsumer.Stop() t.ongoingConnectCleaner.Stop() + t.TLSTagsCleaner.Stop() if t.closeTracer != nil { t.closeTracer() } @@ -693,8 +699,14 @@ func (t *ebpfTracer) getTCPStats(stats *netebpf.TCPStats, tuple *netebpf.ConnTup return t.tcpStats.Lookup(tuple, stats) == nil } -// setupMapCleaner sets up a map cleaner for the tcp_ongoing_connect_pid map -func (t *ebpfTracer) setupMapCleaner(m *manager.Manager) { +// setupMapCleaners sets up the map cleaners for the eBPF maps +func (t *ebpfTracer) setupMapCleaners(m *manager.Manager) { + t.setupOngoingConnectMapCleaner(m) + t.setupTLSTagsMapCleaner(m) +} + +// setupOngoingConnectMapCleaner sets up a map cleaner for the tcp_ongoing_connect_pid map +func (t *ebpfTracer) setupOngoingConnectMapCleaner(m *manager.Manager) { tcpOngoingConnectPidMap, _, err := m.GetMap(probes.TCPOngoingConnectPid) if err != nil { log.Errorf("error getting %v map: %s", probes.TCPOngoingConnectPid, err) @@ -718,6 +730,28 @@ func (t *ebpfTracer) setupMapCleaner(m *manager.Manager) { t.ongoingConnectCleaner = tcpOngoingConnectPidCleaner } +// 
setupTLSTagsMapCleaner sets up a map cleaner for the tls_enhanced_tags map +func (t *ebpfTracer) setupTLSTagsMapCleaner(m *manager.Manager) { + TLSTagsMap, _, err := m.GetMap(probes.EnhancedTLSTagsMap) + if err != nil { + log.Errorf("error getting %v map: %s", probes.EnhancedTLSTagsMap, err) + return + } + + TLSTagsMapCleaner, err := ddebpf.NewMapCleaner[netebpf.ConnTuple, netebpf.TLSTagsWrapper](TLSTagsMap, 1024, probes.EnhancedTLSTagsMap, "npm_tracer") + if err != nil { + log.Errorf("error creating map cleaner: %s", err) + return + } + // slight jitter to avoid all maps being cleaned at the same time + TLSTagsMapCleaner.Clean(time.Second*70, nil, nil, func(now int64, _ netebpf.ConnTuple, val netebpf.TLSTagsWrapper) bool { + ts := int64(val.Updated) + return ts > 0 && now-ts > tlsTagsMapTTL + }) + + t.TLSTagsCleaner = TLSTagsMapCleaner +} + func populateConnStats(stats *network.ConnectionStats, t *netebpf.ConnTuple, s *netebpf.ConnStats, ch *cookieHasher) { *stats = network.ConnectionStats{ConnectionTuple: network.ConnectionTuple{ Pid: t.Pid, @@ -748,6 +782,12 @@ func populateConnStats(stats *network.ConnectionStats, t *netebpf.ConnTuple, s * Encryption: protocols.Encryption(s.Protocol_stack.Encryption), } + stats.TLSTags = tls.Tags{ + ChosenVersion: s.Tls_tags.Chosen_version, + CipherSuite: s.Tls_tags.Cipher_suite, + OfferedVersions: s.Tls_tags.Offered_versions, + } + if t.Type() == netebpf.TCP { stats.Type = network.TCP } else { diff --git a/pkg/network/tracer/connection/kprobe/config.go b/pkg/network/tracer/connection/kprobe/config.go index 880a2f0a5e838..b7c6bb2ff9986 100644 --- a/pkg/network/tracer/connection/kprobe/config.go +++ b/pkg/network/tracer/connection/kprobe/config.go @@ -58,6 +58,8 @@ func enabledProbes(c *config.Config, runtimeTracer, coreTracer bool) (map[probes if c.CollectTCPv4Conns || c.CollectTCPv6Conns { if ClassificationSupported(c) { enableProbe(enabled, probes.ProtocolClassifierEntrySocketFilter) + enableProbe(enabled, 
probes.ProtocolClassifierTLSClientSocketFilter) + enableProbe(enabled, probes.ProtocolClassifierTLSServerSocketFilter) enableProbe(enabled, probes.ProtocolClassifierQueuesSocketFilter) enableProbe(enabled, probes.ProtocolClassifierDBsSocketFilter) enableProbe(enabled, probes.ProtocolClassifierGRPCSocketFilter) diff --git a/pkg/network/tracer/connection/kprobe/manager.go b/pkg/network/tracer/connection/kprobe/manager.go index fb2ee4b7bd065..684a0de804638 100644 --- a/pkg/network/tracer/connection/kprobe/manager.go +++ b/pkg/network/tracer/connection/kprobe/manager.go @@ -19,6 +19,8 @@ import ( var mainProbes = []probes.ProbeFuncName{ probes.NetDevQueue, probes.ProtocolClassifierEntrySocketFilter, + probes.ProtocolClassifierTLSClientSocketFilter, + probes.ProtocolClassifierTLSServerSocketFilter, probes.ProtocolClassifierQueuesSocketFilter, probes.ProtocolClassifierDBsSocketFilter, probes.ProtocolClassifierGRPCSocketFilter, diff --git a/pkg/network/tracer/connection/kprobe/tracer.go b/pkg/network/tracer/connection/kprobe/tracer.go index 634642154e881..2dbfea6629c1d 100644 --- a/pkg/network/tracer/connection/kprobe/tracer.go +++ b/pkg/network/tracer/connection/kprobe/tracer.go @@ -45,6 +45,22 @@ var ( classificationMinimumKernel = kernel.VersionCode(4, 11, 0) protocolClassificationTailCalls = []manager.TailCallRoute{ + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSClient, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSClientSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSServer, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSServerSocketFilter, + UID: probeUID, + }, + }, { ProgArrayName: probes.ClassificationProgsMap, Key: netebpf.ClassificationQueues, @@ -90,8 +106,8 @@ var ( ) // ClassificationSupported returns true if the current kernel 
version supports the classification feature. -// The kernel has to be newer than 4.7.0 since we are using bpf_skb_load_bytes (4.5.0+) method to read from the socket -// filter, and a tracepoint (4.7.0+) +// The kernel has to be newer than 4.11.0 since we are using bpf_skb_load_bytes (4.5.0+) method which was added to +// socket filters in 4.11.0, and a tracepoint (4.7.0+) func ClassificationSupported(config *config.Config) bool { if !config.ProtocolClassificationEnabled { return false diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 46f33ac7fa405..3825b9d1bef07 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -11,6 +11,7 @@ import ( "bufio" "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -50,8 +51,12 @@ import ( "github.com/DataDog/datadog-agent/pkg/network/config/sysctl" "github.com/DataDog/datadog-agent/pkg/network/events" netlinktestutil "github.com/DataDog/datadog-agent/pkg/network/netlink/testutil" + "github.com/DataDog/datadog-agent/pkg/network/protocols" + usmtestutil "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" + ddtls "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/network/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection" + "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" "github.com/DataDog/datadog-agent/pkg/network/tracer/offsetguess" tracertestutil "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" @@ -2590,6 +2595,196 @@ func setupDropTrafficRule(tb testing.TB) (ns string) { return } +func (s *TracerSuite) TestTLSClassification() { + t := s.T() + cfg := testConfig() + + if !kprobe.ClassificationSupported(cfg) { + t.Skip("protocol classification not supported") + } + + tr := setupTracer(t, cfg) + + type tlsTest struct { + name string + 
postTracerSetup func(t *testing.T) (port uint16, scenario uint16) + validation func(t *testing.T, tr *Tracer, port uint16, scenario uint16) + } + + tests := make([]tlsTest, 0) + for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + scenario := scenario + tests = append(tests, tlsTest{ + name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), + postTracerSetup: func(t *testing.T) (uint16, uint16) { + srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:0", func(conn net.Conn) { + defer conn.Close() + _, err := io.Copy(conn, conn) + if err != nil { + fmt.Printf("Failed to echo data: %v\n", err) + return + } + }, scenario) + done := make(chan struct{}) + require.NoError(t, srv.Run(done)) + t.Cleanup(func() { close(done) }) + + // Retrieve the actual port assigned to the server + addr := srv.Address() + _, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + portInt, err := strconv.Atoi(portStr) + require.NoError(t, err) + port := uint16(portInt) + + tlsConfig := &tls.Config{ + MinVersion: scenario, + MaxVersion: scenario, + InsecureSkipVerify: true, + SessionTicketsDisabled: true, + ClientSessionCache: nil, + } + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer conn.Close() + + tlsConn := tls.Client(conn, tlsConfig) + require.NoError(t, tlsConn.Handshake()) + + return port, scenario + }, + validation: func(t *testing.T, tr *Tracer, port uint16, scenario uint16) { + require.EventuallyWithT(t, func(ct *assert.CollectT) { + require.True(ct, validateTLSTags(ct, tr, port, scenario), "TLS tags not set") + }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", port) + }, + }) + } + tests = append(tests, tlsTest{ + name: "Invalid-TLS-Handshake", + postTracerSetup: func(t *testing.T) (uint16, uint16) { + // server that accepts connections but does not perform TLS handshake + listener, err := net.Listen("tcp", "localhost:0") + 
require.NoError(t, err) + t.Cleanup(func() { listener.Close() }) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + _, _ = c.Read(buf) + // Do nothing with the data + }(conn) + } + }() + + // Retrieve the actual port from the listener address + addr := listener.Addr().String() + _, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + portInt, err := strconv.Atoi(portStr) + require.NoError(t, err) + port := uint16(portInt) + + // Client connects to the server + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer conn.Close() + + // Send invalid TLS handshake data + _, err = conn.Write([]byte("invalid TLS data")) + require.NoError(t, err) + + // Since this is invalid TLS, scenario can be set to something irrelevant, e.g., TLS.VersionTLS12 + // or just 0 since the validation doesn't rely on the scenario for this test. + return port, tls.VersionTLS12 + }, + validation: func(t *testing.T, tr *Tracer, port uint16, _ uint16) { + // Verify that no TLS tags are set for this connection + require.EventuallyWithT(t, func(ct *assert.CollectT) { + payload := getConnections(ct, tr) + for _, c := range payload.Conns { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { + t.Log("Unexpected TLS protocol detected for invalid handshake") + require.Fail(ct, "unexpected TLS tags") + } + } + }, 3*time.Second, 100*time.Millisecond) + }, + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if ebpftest.GetBuildMode() == ebpftest.Fentry { + t.Skip("protocol classification not supported for fentry tracer") + } + t.Cleanup(func() { + tr.RemoveClient(clientID) + _ = tr.Pause() + }) + tr.RemoveClient(clientID) + require.NoError(t, tr.RegisterClient(clientID)) + require.NoError(t, tr.Resume(), "enable probes - before post tracer") + port, scenario := tt.postTracerSetup(t) + require.NoError(t, tr.Pause(), "disable probes - 
after post tracer") + tt.validation(t, tr, port, scenario) + }) + } +} + +func validateTLSTags(t *assert.CollectT, tr *Tracer, port uint16, scenario uint16) bool { + payload := getConnections(t, tr) + for _, c := range payload.Conns { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { + tlsTags := c.TLSTags.GetDynamicTags() + + // Check that the cipher suite ID tag is present + cipherSuiteTagFound := false + for key := range tlsTags { + if strings.HasPrefix(key, ddtls.TagTLSCipherSuiteID) { + cipherSuiteTagFound = true + break + } + } + if !cipherSuiteTagFound { + return false + } + + // Check that the negotiated version tag is present + negotiatedVersionTag := ddtls.VersionTags[scenario] + if _, ok := tlsTags[negotiatedVersionTag]; !ok { + return false + } + + // Check that the client offered version tag is present + clientVersionTag := ddtls.ClientVersionTags[scenario] + if _, ok := tlsTags[clientVersionTag]; !ok { + return false + } + + if scenario == tls.VersionTLS13 { + expectedClientVersions := []string{ + ddtls.ClientVersionTags[tls.VersionTLS12], + ddtls.ClientVersionTags[tls.VersionTLS13], + } + for _, tag := range expectedClientVersions { + if _, ok := tlsTags[tag]; !ok { + return false + } + } + } + + return true + } + } + return false +} + func skipOnEbpflessNotSupported(t *testing.T, cfg *config.Config) { if cfg.EnableEbpfless { t.Skip("not supported on ebpf-less") diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index d50f4f299690c..f3117827905fc 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -1073,7 +1073,6 @@ func (s *TracerSuite) TestDNSStats() { func (s *TracerSuite) TestTCPEstablished() { t := s.T() - // Ensure closed connections are flushed as soon as possible cfg := testConfig() tr := setupTracer(t, cfg) diff --git a/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml 
b/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml new file mode 100644 index 0000000000000..14ed9c4ee8c35 --- /dev/null +++ b/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml @@ -0,0 +1,13 @@ +# Each section from every release note is combined when the +# CHANGELOG.rst is rendered. So the text needs to be worded so that +# it does not depend on any information only available in another +# section. This may mean repeating some details, but each section +# must be readable independently of the other. +# +# Each section note must be formatted as reStructuredText. +--- +features: + - | + The Agent now adds enhanced TLS tags, such as ``tls_version`` and ``tls_cipher``, to metrics. + This will allow you to filter and aggregate metrics based on the TLS version and cipher used in the connection. + The tags will be added in CNM and USM.